Initial commit
This commit is contained in:
174
pipe/backend/backend.go
Normal file
174
pipe/backend/backend.go
Normal file
@@ -0,0 +1,174 @@
|
||||
// Copyright 2015 Bytemark Computer Consulting Ltd. All rights reserved
|
||||
// Licensed under the GNU General Public License, version 2. See the LICENSE
|
||||
// file for more details
|
||||
|
||||
// Handler for the PowerDNS pipebackend protocol, as documented here:
|
||||
// https://doc.powerdns.com/md/authoritative/backend-pipe/
|
||||
//
|
||||
// Can speak all three (at time of writing) protocol versions.
|
||||
//
|
||||
// Usage:
|
||||
//
|
||||
// backend := backend.New(
|
||||
|
||||
package backend
|
||||
|
||||
import (
|
||||
"bufio"
|
||||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
"strconv"
|
||||
"strings"
|
||||
)
|
||||
|
||||
type Backend struct {
|
||||
// The text the backend will serve to a successful hello message
|
||||
Banner string
|
||||
|
||||
// The protocol version negotiated with the remote end
|
||||
ProtocolVersion int
|
||||
|
||||
io *bufio.ReadWriter
|
||||
}
|
||||
|
||||
// A callback of this type is executed whenever a query is received. If an error
|
||||
// is returned, the responses are ignored and the error text is returned to the
|
||||
// backend. Otherwise, the responses are serialised and sent back in order.
|
||||
type Callback func(b *Backend, q *Query) ([]*Response, error)
|
||||
|
||||
// Build a new backend object. The banner is reported to the client upon
|
||||
// successful negotiation; the io can be anything.
|
||||
func New(r io.Reader, w io.Writer, banner string) *Backend {
|
||||
io := bufio.NewReadWriter(
|
||||
bufio.NewReader(r),
|
||||
bufio.NewWriter(w),
|
||||
)
|
||||
return &Backend{Banner: banner, io: io}
|
||||
}
|
||||
|
||||
// Does initial handshake with peer. Returns nil, or an error
|
||||
// Note that the pipebackend protocol documentation states that if negotiation
|
||||
// fails, the process should retry, not exit itself.
|
||||
func (b *Backend) Negotiate() error {
|
||||
hello, err := b.io.ReadString('\n')
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// We're not interested in the trailing newlines
|
||||
parts := strings.Split(strings.TrimRight(hello, "\r\n"), "\t")
|
||||
|
||||
if len(parts) != 2 || parts[0] != "HELO" {
|
||||
return errors.New("Bad hello from client")
|
||||
}
|
||||
|
||||
version, err := strconv.Atoi(parts[1])
|
||||
if version < 1 || version > 3 || err != nil {
|
||||
return errors.New("Unknown protocol version requested")
|
||||
}
|
||||
|
||||
_, err = b.io.WriteString(fmt.Sprintf("OK\t%s\n", b.Banner))
|
||||
if err == nil {
|
||||
err = b.io.Flush()
|
||||
}
|
||||
|
||||
if err == nil {
|
||||
b.ProtocolVersion = version
|
||||
}
|
||||
return err
|
||||
}
|
||||
|
||||
func (b *Backend) handleQ(data string, callback Callback) ([]*Response, error) {
|
||||
query := Query{ProtocolVersion: b.ProtocolVersion}
|
||||
|
||||
err := query.fromData(data)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return callback(b, &query)
|
||||
}
|
||||
|
||||
// TODO
|
||||
func (b *Backend) handleAXFR() ([]*Response, error) {
|
||||
return nil, errors.New("AXFR requests not supported")
|
||||
}
|
||||
|
||||
// Reads lines in a loop, processing them by executing the provided callback
|
||||
// and writing appropriate output in response, sequentially, until we hit an
|
||||
// error or our IO hits EOF
|
||||
func (b *Backend) Run(callback Callback) error {
|
||||
responses := make([]*Response, 0)
|
||||
|
||||
for {
|
||||
line, err := b.io.ReadString('\n')
|
||||
if err != nil {
|
||||
if err == io.EOF {
|
||||
return nil
|
||||
}
|
||||
return err
|
||||
}
|
||||
parts := strings.SplitN(strings.TrimRight(line, "\n"), "\t", 2)
|
||||
if len(parts) == 2 {
|
||||
|
||||
}
|
||||
|
||||
switch parts[0] {
|
||||
case "Q":
|
||||
responses, err = b.handleQ(parts[1], callback)
|
||||
case "PING":
|
||||
responses, err = nil, nil // We just need to return END
|
||||
case "AXFR":
|
||||
responses, err = b.handleAXFR()
|
||||
default:
|
||||
responses, err = nil, errors.New("Bad command")
|
||||
}
|
||||
|
||||
if err != nil {
|
||||
// avoid protocol errors
|
||||
clean := strings.Replace(err.Error(), "\n", " ", -1)
|
||||
msg := fmt.Sprintf("LOG\tError handling line: %s\nFAIL\n", clean)
|
||||
_, err := b.io.WriteString(msg)
|
||||
if err != nil {
|
||||
return fmt.Errorf("%s while writing FAIL response", err)
|
||||
}
|
||||
//
|
||||
err = b.io.Flush()
|
||||
if err != nil {
|
||||
return fmt.Errorf("%s while flushing FAIL response", err)
|
||||
}
|
||||
continue
|
||||
}
|
||||
|
||||
// DATA (if there are any records to return)
|
||||
for _, response := range responses {
|
||||
// Always output a line of the right protocol version
|
||||
// TODO: panic if it's set to a wrong non-zero value?
|
||||
response.ProtocolVersion = b.ProtocolVersion
|
||||
data, err := response.String()
|
||||
if err != nil {
|
||||
data = "LOG\tError serialising response: " + err.Error() + "\n"
|
||||
}
|
||||
_, err = b.io.WriteString(data)
|
||||
if err != nil {
|
||||
return fmt.Errorf("%s while writing DATA response", err)
|
||||
}
|
||||
}
|
||||
|
||||
// END
|
||||
_, err = b.io.WriteString("END\n")
|
||||
if err == nil {
|
||||
err = b.io.Flush()
|
||||
}
|
||||
|
||||
if err != nil {
|
||||
return fmt.Errorf("%s while writing END", err)
|
||||
}
|
||||
}
|
||||
|
||||
// We should never hit this at the moment.
|
||||
// TODO: graceful signal handling - intercept and ensure current query
|
||||
// completes before breaking the above loop and returning nil here
|
||||
return nil
|
||||
}
|
169
pipe/backend/backend_test.go
Normal file
169
pipe/backend/backend_test.go
Normal file
@@ -0,0 +1,169 @@
|
||||
package backend_test
|
||||
|
||||
import (
|
||||
h "../test_helpers"
|
||||
"bytes"
|
||||
"errors"
|
||||
"fmt"
|
||||
. "github.com/BytemarkHosting/go-pdns/pipe/backend"
|
||||
"strings"
|
||||
"testing"
|
||||
)
|
||||
|
||||
// Test serializing Query & Response instances - we use them in the tests
|
||||
func TestQueryStringV1(t *testing.T) {
|
||||
h.AssertEqualString(
|
||||
t, "Q\texample.com\tIN\tANY\t-1\t127.0.0.2\n",
|
||||
h.FakeQueryString(t, 1), "V1 query serialisation problem",
|
||||
)
|
||||
}
|
||||
|
||||
func TestQueryStringV2(t *testing.T) {
|
||||
h.AssertEqualString(
|
||||
t, "Q\texample.com\tIN\tANY\t-1\t127.0.0.2\t127.0.0.1\n",
|
||||
h.FakeQueryString(t, 2), "V2 query serialisation problem",
|
||||
)
|
||||
}
|
||||
|
||||
func TestQueryStringV3(t *testing.T) {
|
||||
h.AssertEqualString(
|
||||
t, "Q\texample.com\tIN\tANY\t-1\t127.0.0.2\t127.0.0.1\t127.0.0.3\n",
|
||||
h.FakeQueryString(t, 3), "V3 query serialisation problem",
|
||||
)
|
||||
}
|
||||
|
||||
// Ensure we serialize Response instances correctly - we use them in the tests
|
||||
func TestResponseStringV1andV2(t *testing.T) {
|
||||
exemplar := "DATA\texample.com\tIN\tANY\t3600\t-1\tfoo\n"
|
||||
|
||||
h.AssertEqualString(t, exemplar, h.FakeResponseString(t, 1), "V1 response serialisation problem")
|
||||
h.AssertEqualString(t, exemplar, h.FakeResponseString(t, 2), "V2 response serialisation problem")
|
||||
}
|
||||
|
||||
func TestResponseStringV3(t *testing.T) {
|
||||
h.AssertEqualString(
|
||||
t, "DATA\t24\tauth\texample.com\tIN\tANY\t3600\t-1\tfoo\n",
|
||||
h.FakeResponseString(t, 3), "V3 response serialisation problem",
|
||||
)
|
||||
}
|
||||
|
||||
func BuildAndNegotiate(t *testing.T, protoVersion int) (*Backend, *bytes.Buffer, *bytes.Buffer) {
|
||||
r := bytes.NewBufferString(fmt.Sprintf("HELO\t%d\n", protoVersion))
|
||||
w := &bytes.Buffer{}
|
||||
b := New(r, w, "Testing Backend")
|
||||
|
||||
h.RefuteError(t, b.Negotiate(), "Negotiation failed")
|
||||
h.AssertEqualInt(t, protoVersion, b.ProtocolVersion, "Bad protocol version")
|
||||
h.AssertEqualString(t, "OK\tTesting Backend\n", w.String(), "Bad response to HELO")
|
||||
w.Reset()
|
||||
|
||||
return b, r, w
|
||||
}
|
||||
|
||||
func AssertRun(t *testing.T, b *Backend, f Callback) {
|
||||
err := b.Run(f)
|
||||
h.RefuteError(t, err, "Running backend")
|
||||
}
|
||||
|
||||
func AssertProtocolVersionNegotiation(t *testing.T, protoVersion int) {
|
||||
b, r, w := BuildAndNegotiate(t, protoVersion)
|
||||
r.WriteString(h.FakeQueryString(t, protoVersion))
|
||||
|
||||
// We also test that the negotiated version can be handled
|
||||
AssertRun(t, b, h.EmptyDispatch)
|
||||
h.AssertEqualString(t, "END\n", w.String(), "Unexpected response")
|
||||
w.Reset()
|
||||
|
||||
exp := fmt.Sprintf(
|
||||
"LOG\tError handling line: v%d query should have %d data parts\nFAIL\n",
|
||||
protoVersion, protoVersion+4,
|
||||
)
|
||||
|
||||
// Test a short Q
|
||||
r.WriteString("Q\tfoo\n")
|
||||
AssertRun(t, b, h.EmptyDispatch)
|
||||
h.AssertEqualString(t, exp, w.String(), "Unexpected response")
|
||||
w.Reset()
|
||||
|
||||
// test a long Q
|
||||
long := strings.TrimRight(h.FakeQueryString(t, protoVersion), "\n") + "\tfoo\n"
|
||||
r.WriteString(long)
|
||||
AssertRun(t, b, h.EmptyDispatch)
|
||||
h.AssertEqualString(t, exp, w.String(), "Unexpected response")
|
||||
}
|
||||
|
||||
func TestNegotiatedVersion1(t *testing.T) {
|
||||
AssertProtocolVersionNegotiation(t, 1)
|
||||
}
|
||||
func TestNegotiatedVersion2(t *testing.T) {
|
||||
AssertProtocolVersionNegotiation(t, 2)
|
||||
}
|
||||
|
||||
func TestNegotiatedVersion3(t *testing.T) {
|
||||
AssertProtocolVersionNegotiation(t, 3)
|
||||
}
|
||||
|
||||
func TestQueriesArePassedToDispatcher(t *testing.T) {
|
||||
b, r, w := BuildAndNegotiate(t, 3)
|
||||
var outQ *Query
|
||||
runs := 0
|
||||
|
||||
expected := h.FakeQueryString(t, 3)
|
||||
r.WriteString(expected)
|
||||
|
||||
AssertRun(t, b, func(b *Backend, q *Query) ([]*Response, error) {
|
||||
runs = runs + 1
|
||||
outQ = q
|
||||
return nil, nil
|
||||
})
|
||||
h.AssertEqualInt(t, 1, runs, "Exactly one dispatch callback expected")
|
||||
|
||||
txt, _ := outQ.String()
|
||||
h.AssertEqualString(t, expected, txt, "Wrong query dispatched")
|
||||
h.AssertEqualString(t, "END\n", w.String(), "Unexpected response")
|
||||
}
|
||||
|
||||
func TestResponsesFromDispatcherArePassedToAsker(t *testing.T) {
|
||||
b, r, w := BuildAndNegotiate(t, 3)
|
||||
r.WriteString(h.FakeQueryString(t, 3))
|
||||
|
||||
fr := h.FakeResponse(3)
|
||||
AssertRun(t, b, func(b *Backend, q *Query) ([]*Response, error) {
|
||||
return []*Response{fr, fr}, nil
|
||||
})
|
||||
|
||||
exp := fmt.Sprintf("%s%sEND\n", h.FakeResponseString(t, 3), h.FakeResponseString(t, 3))
|
||||
h.AssertEqualString(t, exp, w.String(), "Bad response")
|
||||
}
|
||||
|
||||
func TestErrorFromDispatcherSuppressesResponses(t *testing.T) {
|
||||
b, r, w := BuildAndNegotiate(t, 3)
|
||||
r.WriteString(h.FakeQueryString(t, 3))
|
||||
|
||||
fr := h.FakeResponse(3)
|
||||
AssertRun(t, b, func(b *Backend, q *Query) ([]*Response, error) {
|
||||
return []*Response{fr, fr}, errors.New("Test\nerror")
|
||||
})
|
||||
h.AssertEqualString(t, "LOG\tError handling line: Test error\nFAIL\n", w.String(), "Bad response")
|
||||
}
|
||||
|
||||
func TestHandlesPing(t *testing.T) {
|
||||
b, r, w := BuildAndNegotiate(t, 3)
|
||||
r.WriteString("PING\n")
|
||||
AssertRun(t, b, h.EmptyDispatch)
|
||||
h.AssertEqualString(t, "END\n", w.String(), "Bad response")
|
||||
}
|
||||
|
||||
func TestAXFRIsTODO(t *testing.T) {
|
||||
b, r, w := BuildAndNegotiate(t, 3)
|
||||
r.WriteString("AXFR\n")
|
||||
AssertRun(t, b, h.EmptyDispatch)
|
||||
h.AssertEqualString(t, "LOG\tError handling line: AXFR requests not supported\nFAIL\n", w.String(), "Bad response")
|
||||
}
|
||||
|
||||
func TestUnknownCommand(t *testing.T) {
|
||||
b, r, w := BuildAndNegotiate(t, 3)
|
||||
r.WriteString("GOGOGO\n")
|
||||
AssertRun(t, b, h.EmptyDispatch)
|
||||
h.AssertEqualString(t, "LOG\tError handling line: Bad command\nFAIL\n", w.String(), "Bad response")
|
||||
}
|
85
pipe/backend/query.go
Normal file
85
pipe/backend/query.go
Normal file
@@ -0,0 +1,85 @@
|
||||
package backend
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"fmt"
|
||||
"strings"
|
||||
)
|
||||
|
||||
// Represents a query received by the backend. Certain fields may be blank,
|
||||
// depending on negotiated protocol version - which is stored in ProtocolVersion
|
||||
//
|
||||
// QName, QClass, QType, Id and RemoteIpAddress are present in all versions
|
||||
// LocalIpAddress was added in version 2
|
||||
// EdnsSubnetAddress was added in version 3
|
||||
type Query struct {
|
||||
ProtocolVersion int
|
||||
QName string
|
||||
QClass string
|
||||
QType string
|
||||
Id string
|
||||
RemoteIpAddress string
|
||||
LocalIpAddress string
|
||||
EdnsSubnetAddress string
|
||||
}
|
||||
|
||||
func (q *Query) fromData(data string) (err error) {
|
||||
parts := strings.Split(data, "\t")
|
||||
|
||||
// ugh
|
||||
switch q.ProtocolVersion {
|
||||
case 1:
|
||||
if len(parts) != 5 {
|
||||
return errors.New("v1 query should have 5 data parts")
|
||||
}
|
||||
case 2:
|
||||
if len(parts) != 6 {
|
||||
return errors.New("v2 query should have 6 data parts")
|
||||
}
|
||||
q.LocalIpAddress = parts[5]
|
||||
case 3:
|
||||
if len(parts) != 7 {
|
||||
return errors.New("v3 query should have 7 data parts")
|
||||
}
|
||||
q.LocalIpAddress = parts[5]
|
||||
q.EdnsSubnetAddress = parts[6]
|
||||
default:
|
||||
return errors.New("Unknown protocol version in query")
|
||||
}
|
||||
|
||||
// common
|
||||
q.QName = parts[0]
|
||||
q.QClass = parts[1]
|
||||
q.QType = parts[2]
|
||||
q.Id = parts[3]
|
||||
q.RemoteIpAddress = parts[4]
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (q *Query) String() (string, error) {
|
||||
if q.ProtocolVersion < 1 || q.ProtocolVersion > 3 {
|
||||
return "", errors.New("Unknown protocol version in query")
|
||||
}
|
||||
switch q.ProtocolVersion {
|
||||
case 1:
|
||||
return fmt.Sprintf(
|
||||
"Q\t%s\t%s\t%s\t%s\t%s\n",
|
||||
q.QName, q.QClass, q.QType, q.Id, q.RemoteIpAddress,
|
||||
), nil
|
||||
case 2:
|
||||
return fmt.Sprintf(
|
||||
"Q\t%s\t%s\t%s\t%s\t%s\t%s\n",
|
||||
q.QName, q.QClass, q.QType, q.Id, q.RemoteIpAddress,
|
||||
q.LocalIpAddress,
|
||||
), nil
|
||||
case 3:
|
||||
return fmt.Sprintf(
|
||||
"Q\t%s\t%s\t%s\t%s\t%s\t%s\t%s\n",
|
||||
q.QName, q.QClass, q.QType, q.Id, q.RemoteIpAddress,
|
||||
q.LocalIpAddress, q.EdnsSubnetAddress,
|
||||
), nil
|
||||
}
|
||||
|
||||
return "", errors.New("Unknown protocol version in query")
|
||||
}
|
41
pipe/backend/response.go
Normal file
41
pipe/backend/response.go
Normal file
@@ -0,0 +1,41 @@
|
||||
package backend
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"fmt"
|
||||
)
|
||||
|
||||
// A response to be sent back in answer to a query. Again, some fields may be
|
||||
// blank, depending on protocol version.
|
||||
//
|
||||
// QName, QClass, QType, TTL, Id and Content are present in all versions
|
||||
// No additions in version 2
|
||||
// ScopeBits and Auth were added in version 3
|
||||
type Response struct {
|
||||
ProtocolVersion int
|
||||
ScopeBits string
|
||||
Auth string
|
||||
QName string
|
||||
QClass string
|
||||
QType string
|
||||
TTL string
|
||||
Id string
|
||||
Content string
|
||||
}
|
||||
|
||||
// Gives the response in a serialized form suitable for squirting on the wire
|
||||
func (r *Response) String() (string, error) {
|
||||
switch r.ProtocolVersion {
|
||||
case 1, 2:
|
||||
return fmt.Sprintf(
|
||||
"DATA\t%s\t%s\t%s\t%s\t%s\t%s\n",
|
||||
r.QName, r.QClass, r.QType, r.TTL, r.Id, r.Content,
|
||||
), nil
|
||||
case 3:
|
||||
return fmt.Sprintf(
|
||||
"DATA\t%s\t%s\t%s\t%s\t%s\t%s\t%s\t%s\n",
|
||||
r.ScopeBits, r.Auth, r.QName, r.QClass, r.QType, r.TTL, r.Id, r.Content,
|
||||
), nil
|
||||
}
|
||||
return "", errors.New("Unknown protocol version in response")
|
||||
}
|
Reference in New Issue
Block a user