Initial commit

This commit is contained in:
Nick Thomas
2015-01-10 02:01:24 +00:00
commit 259ffdc8ed
12 changed files with 2585 additions and 0 deletions

174
pipe/backend/backend.go Normal file
View File

@@ -0,0 +1,174 @@
// Copyright 2015 Bytemark Computer Consulting Ltd. All rights reserved
// Licensed under the GNU General Public License, version 2. See the LICENSE
// file for more details
// Handler for the PowerDNS pipebackend protocol, as documented here:
// https://doc.powerdns.com/md/authoritative/backend-pipe/
//
// Can speak all three (at time of writing) protocol versions.
//
// Usage:
//
// backend := backend.New(
package backend
import (
"bufio"
"errors"
"fmt"
"io"
"strconv"
"strings"
)
type Backend struct {
// The text the backend will serve to a successful hello message
Banner string
// The protocol version negotiated with the remote end
ProtocolVersion int
io *bufio.ReadWriter
}
// A callback of this type is executed whenever a query is received. If an error
// is returned, the responses are ignored and the error text is returned to the
// backend. Otherwise, the responses are serialised and sent back in order.
type Callback func(b *Backend, q *Query) ([]*Response, error)
// Build a new backend object. The banner is reported to the client upon
// successful negotiation; the io can be anything.
func New(r io.Reader, w io.Writer, banner string) *Backend {
io := bufio.NewReadWriter(
bufio.NewReader(r),
bufio.NewWriter(w),
)
return &Backend{Banner: banner, io: io}
}
// Does initial handshake with peer. Returns nil, or an error
// Note that the pipebackend protocol documentation states that if negotiation
// fails, the process should retry, not exit itself.
func (b *Backend) Negotiate() error {
hello, err := b.io.ReadString('\n')
if err != nil {
return err
}
// We're not interested in the trailing newlines
parts := strings.Split(strings.TrimRight(hello, "\r\n"), "\t")
if len(parts) != 2 || parts[0] != "HELO" {
return errors.New("Bad hello from client")
}
version, err := strconv.Atoi(parts[1])
if version < 1 || version > 3 || err != nil {
return errors.New("Unknown protocol version requested")
}
_, err = b.io.WriteString(fmt.Sprintf("OK\t%s\n", b.Banner))
if err == nil {
err = b.io.Flush()
}
if err == nil {
b.ProtocolVersion = version
}
return err
}
func (b *Backend) handleQ(data string, callback Callback) ([]*Response, error) {
query := Query{ProtocolVersion: b.ProtocolVersion}
err := query.fromData(data)
if err != nil {
return nil, err
}
return callback(b, &query)
}
// TODO
func (b *Backend) handleAXFR() ([]*Response, error) {
return nil, errors.New("AXFR requests not supported")
}
// Reads lines in a loop, processing them by executing the provided callback
// and writing appropriate output in response, sequentially, until we hit an
// error or our IO hits EOF
func (b *Backend) Run(callback Callback) error {
responses := make([]*Response, 0)
for {
line, err := b.io.ReadString('\n')
if err != nil {
if err == io.EOF {
return nil
}
return err
}
parts := strings.SplitN(strings.TrimRight(line, "\n"), "\t", 2)
if len(parts) == 2 {
}
switch parts[0] {
case "Q":
responses, err = b.handleQ(parts[1], callback)
case "PING":
responses, err = nil, nil // We just need to return END
case "AXFR":
responses, err = b.handleAXFR()
default:
responses, err = nil, errors.New("Bad command")
}
if err != nil {
// avoid protocol errors
clean := strings.Replace(err.Error(), "\n", " ", -1)
msg := fmt.Sprintf("LOG\tError handling line: %s\nFAIL\n", clean)
_, err := b.io.WriteString(msg)
if err != nil {
return fmt.Errorf("%s while writing FAIL response", err)
}
//
err = b.io.Flush()
if err != nil {
return fmt.Errorf("%s while flushing FAIL response", err)
}
continue
}
// DATA (if there are any records to return)
for _, response := range responses {
// Always output a line of the right protocol version
// TODO: panic if it's set to a wrong non-zero value?
response.ProtocolVersion = b.ProtocolVersion
data, err := response.String()
if err != nil {
data = "LOG\tError serialising response: " + err.Error() + "\n"
}
_, err = b.io.WriteString(data)
if err != nil {
return fmt.Errorf("%s while writing DATA response", err)
}
}
// END
_, err = b.io.WriteString("END\n")
if err == nil {
err = b.io.Flush()
}
if err != nil {
return fmt.Errorf("%s while writing END", err)
}
}
// We should never hit this at the moment.
// TODO: graceful signal handling - intercept and ensure current query
// completes before breaking the above loop and returning nil here
return nil
}

View File

@@ -0,0 +1,169 @@
package backend_test
import (
h "../test_helpers"
"bytes"
"errors"
"fmt"
. "github.com/BytemarkHosting/go-pdns/pipe/backend"
"strings"
"testing"
)
// Test serializing Query & Response instances - we use them in the tests
func TestQueryStringV1(t *testing.T) {
h.AssertEqualString(
t, "Q\texample.com\tIN\tANY\t-1\t127.0.0.2\n",
h.FakeQueryString(t, 1), "V1 query serialisation problem",
)
}
func TestQueryStringV2(t *testing.T) {
h.AssertEqualString(
t, "Q\texample.com\tIN\tANY\t-1\t127.0.0.2\t127.0.0.1\n",
h.FakeQueryString(t, 2), "V2 query serialisation problem",
)
}
func TestQueryStringV3(t *testing.T) {
h.AssertEqualString(
t, "Q\texample.com\tIN\tANY\t-1\t127.0.0.2\t127.0.0.1\t127.0.0.3\n",
h.FakeQueryString(t, 3), "V3 query serialisation problem",
)
}
// Ensure we serialize Response instances correctly - we use them in the tests
func TestResponseStringV1andV2(t *testing.T) {
exemplar := "DATA\texample.com\tIN\tANY\t3600\t-1\tfoo\n"
h.AssertEqualString(t, exemplar, h.FakeResponseString(t, 1), "V1 response serialisation problem")
h.AssertEqualString(t, exemplar, h.FakeResponseString(t, 2), "V2 response serialisation problem")
}
func TestResponseStringV3(t *testing.T) {
h.AssertEqualString(
t, "DATA\t24\tauth\texample.com\tIN\tANY\t3600\t-1\tfoo\n",
h.FakeResponseString(t, 3), "V3 response serialisation problem",
)
}
func BuildAndNegotiate(t *testing.T, protoVersion int) (*Backend, *bytes.Buffer, *bytes.Buffer) {
r := bytes.NewBufferString(fmt.Sprintf("HELO\t%d\n", protoVersion))
w := &bytes.Buffer{}
b := New(r, w, "Testing Backend")
h.RefuteError(t, b.Negotiate(), "Negotiation failed")
h.AssertEqualInt(t, protoVersion, b.ProtocolVersion, "Bad protocol version")
h.AssertEqualString(t, "OK\tTesting Backend\n", w.String(), "Bad response to HELO")
w.Reset()
return b, r, w
}
func AssertRun(t *testing.T, b *Backend, f Callback) {
err := b.Run(f)
h.RefuteError(t, err, "Running backend")
}
func AssertProtocolVersionNegotiation(t *testing.T, protoVersion int) {
b, r, w := BuildAndNegotiate(t, protoVersion)
r.WriteString(h.FakeQueryString(t, protoVersion))
// We also test that the negotiated version can be handled
AssertRun(t, b, h.EmptyDispatch)
h.AssertEqualString(t, "END\n", w.String(), "Unexpected response")
w.Reset()
exp := fmt.Sprintf(
"LOG\tError handling line: v%d query should have %d data parts\nFAIL\n",
protoVersion, protoVersion+4,
)
// Test a short Q
r.WriteString("Q\tfoo\n")
AssertRun(t, b, h.EmptyDispatch)
h.AssertEqualString(t, exp, w.String(), "Unexpected response")
w.Reset()
// test a long Q
long := strings.TrimRight(h.FakeQueryString(t, protoVersion), "\n") + "\tfoo\n"
r.WriteString(long)
AssertRun(t, b, h.EmptyDispatch)
h.AssertEqualString(t, exp, w.String(), "Unexpected response")
}
func TestNegotiatedVersion1(t *testing.T) {
AssertProtocolVersionNegotiation(t, 1)
}
func TestNegotiatedVersion2(t *testing.T) {
AssertProtocolVersionNegotiation(t, 2)
}
func TestNegotiatedVersion3(t *testing.T) {
AssertProtocolVersionNegotiation(t, 3)
}
func TestQueriesArePassedToDispatcher(t *testing.T) {
b, r, w := BuildAndNegotiate(t, 3)
var outQ *Query
runs := 0
expected := h.FakeQueryString(t, 3)
r.WriteString(expected)
AssertRun(t, b, func(b *Backend, q *Query) ([]*Response, error) {
runs = runs + 1
outQ = q
return nil, nil
})
h.AssertEqualInt(t, 1, runs, "Exactly one dispatch callback expected")
txt, _ := outQ.String()
h.AssertEqualString(t, expected, txt, "Wrong query dispatched")
h.AssertEqualString(t, "END\n", w.String(), "Unexpected response")
}
func TestResponsesFromDispatcherArePassedToAsker(t *testing.T) {
b, r, w := BuildAndNegotiate(t, 3)
r.WriteString(h.FakeQueryString(t, 3))
fr := h.FakeResponse(3)
AssertRun(t, b, func(b *Backend, q *Query) ([]*Response, error) {
return []*Response{fr, fr}, nil
})
exp := fmt.Sprintf("%s%sEND\n", h.FakeResponseString(t, 3), h.FakeResponseString(t, 3))
h.AssertEqualString(t, exp, w.String(), "Bad response")
}
func TestErrorFromDispatcherSuppressesResponses(t *testing.T) {
b, r, w := BuildAndNegotiate(t, 3)
r.WriteString(h.FakeQueryString(t, 3))
fr := h.FakeResponse(3)
AssertRun(t, b, func(b *Backend, q *Query) ([]*Response, error) {
return []*Response{fr, fr}, errors.New("Test\nerror")
})
h.AssertEqualString(t, "LOG\tError handling line: Test error\nFAIL\n", w.String(), "Bad response")
}
func TestHandlesPing(t *testing.T) {
b, r, w := BuildAndNegotiate(t, 3)
r.WriteString("PING\n")
AssertRun(t, b, h.EmptyDispatch)
h.AssertEqualString(t, "END\n", w.String(), "Bad response")
}
func TestAXFRIsTODO(t *testing.T) {
b, r, w := BuildAndNegotiate(t, 3)
r.WriteString("AXFR\n")
AssertRun(t, b, h.EmptyDispatch)
h.AssertEqualString(t, "LOG\tError handling line: AXFR requests not supported\nFAIL\n", w.String(), "Bad response")
}
func TestUnknownCommand(t *testing.T) {
b, r, w := BuildAndNegotiate(t, 3)
r.WriteString("GOGOGO\n")
AssertRun(t, b, h.EmptyDispatch)
h.AssertEqualString(t, "LOG\tError handling line: Bad command\nFAIL\n", w.String(), "Bad response")
}

85
pipe/backend/query.go Normal file
View File

@@ -0,0 +1,85 @@
package backend
import (
"errors"
"fmt"
"strings"
)
// Represents a query received by the backend. Certain fields may be blank,
// depending on negotiated protocol version - which is stored in ProtocolVersion
//
// QName, QClass, QType, Id and RemoteIpAddress are present in all versions
// LocalIpAddress was added in version 2
// EdnsSubnetAddress was added in version 3
type Query struct {
ProtocolVersion int
QName string
QClass string
QType string
Id string
RemoteIpAddress string
LocalIpAddress string
EdnsSubnetAddress string
}
func (q *Query) fromData(data string) (err error) {
parts := strings.Split(data, "\t")
// ugh
switch q.ProtocolVersion {
case 1:
if len(parts) != 5 {
return errors.New("v1 query should have 5 data parts")
}
case 2:
if len(parts) != 6 {
return errors.New("v2 query should have 6 data parts")
}
q.LocalIpAddress = parts[5]
case 3:
if len(parts) != 7 {
return errors.New("v3 query should have 7 data parts")
}
q.LocalIpAddress = parts[5]
q.EdnsSubnetAddress = parts[6]
default:
return errors.New("Unknown protocol version in query")
}
// common
q.QName = parts[0]
q.QClass = parts[1]
q.QType = parts[2]
q.Id = parts[3]
q.RemoteIpAddress = parts[4]
return nil
}
func (q *Query) String() (string, error) {
if q.ProtocolVersion < 1 || q.ProtocolVersion > 3 {
return "", errors.New("Unknown protocol version in query")
}
switch q.ProtocolVersion {
case 1:
return fmt.Sprintf(
"Q\t%s\t%s\t%s\t%s\t%s\n",
q.QName, q.QClass, q.QType, q.Id, q.RemoteIpAddress,
), nil
case 2:
return fmt.Sprintf(
"Q\t%s\t%s\t%s\t%s\t%s\t%s\n",
q.QName, q.QClass, q.QType, q.Id, q.RemoteIpAddress,
q.LocalIpAddress,
), nil
case 3:
return fmt.Sprintf(
"Q\t%s\t%s\t%s\t%s\t%s\t%s\t%s\n",
q.QName, q.QClass, q.QType, q.Id, q.RemoteIpAddress,
q.LocalIpAddress, q.EdnsSubnetAddress,
), nil
}
return "", errors.New("Unknown protocol version in query")
}

41
pipe/backend/response.go Normal file
View File

@@ -0,0 +1,41 @@
package backend
import (
"errors"
"fmt"
)
// A response to be sent back in answer to a query. Again, some fields may be
// blank, depending on protocol version.
//
// QName, QClass, QType, TTL, Id and Content are present in all versions
// No additions in version 2
// ScopeBits and Auth were added in version 3
type Response struct {
ProtocolVersion int
ScopeBits string
Auth string
QName string
QClass string
QType string
TTL string
Id string
Content string
}
// Gives the response in a serialized form suitable for squirting on the wire
func (r *Response) String() (string, error) {
switch r.ProtocolVersion {
case 1, 2:
return fmt.Sprintf(
"DATA\t%s\t%s\t%s\t%s\t%s\t%s\n",
r.QName, r.QClass, r.QType, r.TTL, r.Id, r.Content,
), nil
case 3:
return fmt.Sprintf(
"DATA\t%s\t%s\t%s\t%s\t%s\t%s\t%s\t%s\n",
r.ScopeBits, r.Auth, r.QName, r.QClass, r.QType, r.TTL, r.Id, r.Content,
), nil
}
return "", errors.New("Unknown protocol version in response")
}

59
pipe/dsl/context.go Normal file
View File

@@ -0,0 +1,59 @@
package dsl
import (
"github.com/BytemarkHosting/go-pdns/pipe/backend"
"strconv"
)
// Callbacks are run with a context instance, which allows them to accumulate
// answers while maintaining a short type signature. It will also make
// concurrent callbacks easier, when we handle that, but for now one context
// is maintained across all callbacks for a particular query
type Context struct {
// Replies that don't specify a TTL will be given this instead.
DefaultTTL int
// The query that triggered this callback run. Note that its QType
// member may be "ANY"
Query *backend.Query
// QType this callback is being run as. Matches the qtype field given
// with the callback at the time DSL.Register was called
QType string
// The callback is registered with a regexp; if that regexp contains
// any match groups, then the matched text is placed here.
Matches []string
// Set this if an error has been encountered; no more callbacks will be
// run, and the error text (only) will be reported to the backend.
Error error
// Answers to be sent to the backend are stored here. Context.Reply()
// calls, etc, generate answers and put them here, for instance.
// If multiple callbacks are being run, then later callbacks will be
// able to see the answers earlier ones generated (for now)
Answers []*backend.Response
}
// Add an answer, using default QName and TTL for the query
func (c *Context) Reply(content string) {
c.ReplyExtra(c.Query.QName, content, c.DefaultTTL)
}
// Add an answer, using the default QName but specifying a particular TTL
func (c *Context) ReplyTTL(content string, ttl int) {
c.ReplyExtra(c.Query.QName, content, ttl)
}
// Add an answer, specifying both QName and TTL.
func (c *Context) ReplyExtra(qname, content string, ttl int) {
c.Answers = append(c.Answers, &backend.Response{
QName: qname,
QClass: c.Query.QClass,
QType: c.QType, // q.Query.QType may == "ANY"
Id: c.Query.Id,
Content: content,
TTL: strconv.Itoa(ttl),
})
}

223
pipe/dsl/dsl.go Normal file
View File

@@ -0,0 +1,223 @@
// Copyright 2015 Bytemark Computer Consulting Ltd. All rights reserved
// Licensed under the GNU General Public License, version 2. See the LICENSE
// file for more details
// Simple DSL for pipebackend. Usage:
//
// // Create new handle. You have to specify a default TTL here.
// x := dsl.New()
// root := regexp.QuoteMeta("example.com")
//
// // most zones need SOA + NS records
// x.SOA(root, func(c *dsl.Context) {
// c.Reply("ns1.example.com hostmaster.example.com 1 3600 1800 86400 3600")
// })
//
// // Zones need NS records too. All replies will be returned
// x.NS(root, func(c *dsl.Context) {
// c.Reply("ns1.example.com")
// c.Reply("ns2.example.com")
// c.Reply("ns3.example.com")
// })
//
// // You don't have to use anonymous functions, of course
// func answer(c *dsl.Context) {
// switch c.Query.QType {
// case "A" : c.Reply("169.254.0.1")
// case "AAAA": c.Reply("fe80::1" )
// }
// }
// x.A(root, answer)
// x.AAAA(root, answer)
//
// // Setting c.Error at any point will suppress *all* replies from being
// // sent back. Instead, a FAIL response with the c.Error.Error() as the
// // content is returned to powerdns
// x.SSHFP(root, func(c *dsl.Context) {
// c.Reply("1 1 f1d2d2f924e986ac86fdf7b36c94bcdf32beec15")
// c.Error = errors.New("Don't use SSHFP on unsigned zones")
// c.Reply("1 2 e242ed3bffccdf271b7fbaf34ed72d089537b42f")
// })
//
// // You can do anything in a callback, but be aware that powerdns has a
// // time limit on responses and there is no request concurrency within a
// // single pipe connection. pdns achieves concurrency through multiple
// // backend connections instead
// c := make(chan string)
// x.MX(root, func(c *dsl.Context) {
// c.Reply(<-c)
// })
//
// // If your regexp includes capture groups, they are quoted back to you.
// // Here's a simple DNS echo server. Note the use of ReplyExtra to allow
// // a non-default TTL to be set.
// //
// // Don't forget: DNS is supposed to be case-insensitive. Be careful.
// c.TXT(`(.*)\.` + root, func(c *dsl.Context) {
// c.ReplyTTL(c.Query.QName, c.Matches[0], 0)
// })
//
// // Dispatch is up to you. It will probably look like this, but you
// // might want to add logging around the request or something more
// // complicated (different DSL instance depending on backend version?)
// func doit(b *backend.Backend, q *backend.Query) ([]*backend.Response, error) {
// if q.QClass == "IN" {
// return x.Lookup(q)
// }
// return nil, errors.New("Only IN QClass is supported")
// }
//
// pipe := backend.New( r, w, "Example backend" )
// err1 := pipe.Negotiate() // do check for errors
// err2 := pipe.Run(doit)
//
//
//
package dsl
import (
"github.com/BytemarkHosting/go-pdns/pipe/backend"
"regexp"
)
// Instances of this struct are used to hold onto registered callbacks, etc.
type DSL struct {
callbacks map[string][]callbackNode
qtypeSort []string
defaultTTL int
beforeCallback Callback
}
// Get a new builder with a default TTL of one hour
func New() *DSL {
return NewWithTTL(3600)
}
// Get a new builder, specifying a default TTL explicitly
func NewWithTTL(ttl int) *DSL {
return &DSL{
callbacks: make(map[string][]callbackNode),
qtypeSort: make([]string, 0),
defaultTTL: ttl,
}
}
// Callbacks are registered against the DSL instance and run against incoming
// queries if the regexp they are registered with matches the QName of the query
type Callback func(c *Context)
type callbackNode struct {
matcher *regexp.Regexp
fn Callback
}
// Register a callback to run before every request. Set c.Error to halt
// processing, or mutate the context however you like.
func (d *DSL) Before(f Callback) {
d.beforeCallback = f
}
// Register a callback to be run whenever a query with a QName matching the
// regular expression comes in. The regex is provided as a string (matcher)
// to keep ordinary invocations short; it's compiled immediately with
// regexp.MustCompile. Don't forget to anchor your regexes!
//
// If match groups are included in the regex, then any matched text is placed in
// the Context the callback receives.
//
// Callbacks are run with slightly obtuse ordering: all callbacks of a qtype
// are run in the order they were registered. We iterate the list of qtypes
// in the order that a callback with a matching qtype was *first* registered.
// If your pdns server has the "noshuffle" configuration directive, the order
// will be reflected in the responses returned by it; future concurrent DSL
// should maintain this ordering.
func (d *DSL) Register(qtype string, re *regexp.Regexp, f Callback) {
// Maintain our obtuse sense of order
alreadyIn := false
for _, prospect := range d.qtypeSort {
if prospect == qtype {
alreadyIn = true
break
}
}
if !alreadyIn {
d.qtypeSort = append(d.qtypeSort, qtype)
}
node := callbackNode{matcher: re, fn: f}
d.callbacks[qtype] = append(d.callbacks[qtype], node)
}
// Once we're concurrent, this method will create the context and return it
func (d *DSL) runNode(c *Context, node *callbackNode) {
matches := node.matcher.FindStringSubmatch(c.Query.QName)
if matches != nil && len(matches) > 0 {
// Probably unnecessary, but ensure that the previous value of
// Matches is preserved. This could also be = nil
oldmatches := c.Matches
defer func(c *Context) { c.Matches = oldmatches }(c)
// The first match is the whole thing, followed by the capture
// groups. We're only interested in the latter.
c.Matches = matches[1:]
if d.beforeCallback != nil {
d.beforeCallback(c)
}
if c.Error == nil {
node.fn(c)
}
}
}
// Run all registered callbacks against the query. If any callbacks report an
// error, we halt and return the error only (partially constructed responses are
// discarded).
//
// For now, callbacks are run sequentially, rather than in parallel. There could
// be a speedup to running each callback in its own goroutine. Currently, all
// callbacks share the same context instance; we'd have to change that if we
// ran them in parallel.
func (d *DSL) Lookup(q *backend.Query) ([]*backend.Response, error) {
c := Context{
DefaultTTL: d.defaultTTL,
Query: q,
Answers: make([]*backend.Response, 0),
Error: nil,
}
var runOn []string
if q.QType == "ANY" {
runOn = d.qtypeSort
} else {
runOn = []string{q.QType}
}
for _, qtype := range runOn {
c.QType = qtype
for _, node := range d.callbacks[qtype] {
d.runNode(&c, &node)
if c.Error != nil {
return nil, c.Error
}
}
}
return c.Answers, nil
}
// Reports the registered callbacks, in order. Handy for testing or status.
func (d *DSL) String() string {
out := ""
for _, qtype := range d.qtypeSort {
out = out + qtype + "\t:"
for _, node := range d.callbacks[qtype] {
out = out + "\t" + node.matcher.String() + "\n"
}
}
return out
}

1163
pipe/dsl/dsl_helpers.go Normal file

File diff suppressed because it is too large Load Diff

181
pipe/dsl/dsl_test.go Normal file
View File

@@ -0,0 +1,181 @@
package dsl_test
import (
"errors"
"fmt"
"github.com/BytemarkHosting/go-pdns/pipe/backend"
. "github.com/BytemarkHosting/go-pdns/pipe/dsl"
h "github.com/BytemarkHosting/go-pdns/pipe/test_helpers"
"strings"
"testing"
)
func ReplyHandler(x string) func(c *Context) {
return func(c *Context) { c.Reply(x) }
}
func NullHandler(c *Context) {}
var ErrorReplyError = errors.New("Foo")
func ErrorReplyHandler(c *Context) { c.Error = ErrorReplyError }
func SOAQuery() *backend.Query {
q := h.FakeQuery(3)
q.QName = "example.com"
q.QType = "SOA"
return q
}
func AssertTableEntry(t *testing.T, d *DSL, qtype, matcher, msg string) {
table := qtype + "\t:\t" + matcher + "\n"
h.AssertEqualString(t, table, d.String(), msg)
}
func AssertLookup(t *testing.T, d *DSL, q *backend.Query, n int, err error) []*backend.Response {
rsp, rspErr := d.Lookup(q)
if err == nil {
h.RefuteError(t, err, "Lookup shouldn't return error")
} else if rspErr == nil {
t.Logf("Expected error %s but no error was returned", err)
t.FailNow()
} else {
h.AssertEqualString(t, err.Error(), rspErr.Error(), "Expected error not returned")
}
rspStrs := []string{}
for _, r := range rsp {
r.ProtocolVersion = 1
str, err := r.String()
h.RefuteError(t, err, "sanity")
rspStrs = append(rspStrs, str)
}
h.AssertEqualInt(t, n, len(rsp), fmt.Sprintf("One response expected, got:\n%s", strings.Join(rspStrs, "")))
return rsp
}
// Don't test all the autogenerated helpers explicitly, just a common one.
func TestAutogeneratedExampleRegistersCorrectCallbackWhenRun(t *testing.T) {
d := New()
d.SOA(`example\.com`, NullHandler)
AssertTableEntry(t, d, "SOA", `^(?i)example\.com$`, "SOA callback not registered")
}
func TestDefaultTTLFromNewIsOneHour(t *testing.T) {
d := New()
d.SOA(`*`, ReplyHandler("Foo"))
rsp := AssertLookup(t, d, SOAQuery(), 1, nil)
h.AssertEqualString(t, "3600", rsp[0].TTL, "Default TTL not honoured")
}
func TestAlternativeTTLCanBeSpecifiedUsingNewWithTTL(t *testing.T) {
d := NewWithTTL(86400)
d.SOA(`*`, ReplyHandler("Foo"))
rsp := AssertLookup(t, d, SOAQuery(), 1, nil)
h.AssertEqualString(t, "86400", rsp[0].TTL, "Custom TTL not honoured")
}
func TestBeforeCallbackIsCalledIfSpecified(t *testing.T) {
d := New()
d.Before(ReplyHandler("Before"))
d.SOA(`*`, ReplyHandler("SOA 1"))
d.SOA(`*`, ReplyHandler("SOA 2"))
rsp := AssertLookup(t, d, SOAQuery(), 4, nil)
h.AssertEqualString(t, "Before", rsp[0].Content, "First Before not called")
h.AssertEqualString(t, "SOA 1", rsp[1].Content, "First SOA not called")
h.AssertEqualString(t, "Before", rsp[2].Content, "Second Before not called")
h.AssertEqualString(t, "SOA 2", rsp[3].Content, "Second SOA not called")
}
func TestLookupCallbackOrderIsDefined(t *testing.T) {
d := New()
d.SOA(`*`, ReplyHandler("SOA 1"))
d.MX(`*`, ReplyHandler("MX 1"))
d.SOA(`*`, ReplyHandler("SOA 2"))
d.SOA(`*`, ReplyHandler("SOA 3"))
d.AAAA(`*`, ReplyHandler("AAAA 1"))
d.A(`*`, ReplyHandler("AAAA 2"))
rsp := AssertLookup(t, d, h.FakeQuery(1), 6, nil)
msg := "Order is wrong"
h.AssertEqualString(t, "SOA 1", rsp[0].Content, msg)
h.AssertEqualString(t, "SOA 2", rsp[1].Content, msg)
h.AssertEqualString(t, "SOA 3", rsp[2].Content, msg)
h.AssertEqualString(t, "MX 1", rsp[3].Content, msg)
h.AssertEqualString(t, "AAAA 1", rsp[4].Content, msg)
h.AssertEqualString(t, "AAAA 2", rsp[5].Content, msg)
}
func TestOnlyMatchingCallbacksAreRun(t *testing.T) {
d := New()
d.SOA(`example\.com`, ReplyHandler("Good"))
d.SOA(`example\.org`, ReplyHandler("Bad"))
rsp := AssertLookup(t, d, SOAQuery(), 1, nil)
h.AssertEqualString(t, "Good", rsp[0].Content, "Wrong callback run")
}
func TestOnlyCallbacksOfTheRightQTypeAreRun(t *testing.T) {
d := New()
d.SOA(`example\.com`, ReplyHandler("Good"))
d.NS(`example\.com`, ReplyHandler("Bad"))
rsp := AssertLookup(t, d, SOAQuery(), 1, nil)
h.AssertEqualString(t, "Good", rsp[0].Content, "Wrong callback run")
}
func TestCaptureGroupsArePutIntoContextMatches(t *testing.T) {
d := New()
var captures []string
d.SOA(`([a-z]{3})\.([a-z]{3})\.([a-z]{3})\.example\.com`, func(c *Context) {
captures = c.Matches
c.Reply("OK")
})
q := SOAQuery()
q.QName = "BAZ.bar.foo.example.com"
rsp := AssertLookup(t, d, q, 1, nil)
h.AssertEqualInt(t, 3, len(captures), "Wrong number of match groups returned")
h.AssertEqualString(t, "OK", rsp[0].Content, "Wrong callback run?")
// case-insensitive, so we don't mangle these
h.AssertEqualString(t, "BAZ", captures[0], "Part 1 not captured")
h.AssertEqualString(t, "bar", captures[1], "Part 2 not captured")
h.AssertEqualString(t, "foo", captures[2], "Part 3 not captured")
}
func TestBeforeCanModifyCaptureGroups(t *testing.T) {
d := New()
d.Before(func(c *Context) { c.Matches = append(c.Matches, "modify-cg") })
d.SOA(`*`, func(c *Context) { c.Reply(c.Matches[0]) })
rsp := AssertLookup(t, d, SOAQuery(), 1, nil)
h.AssertEqualString(t, "modify-cg", rsp[0].Content, "Capture modification lost")
}
func TestChangesToMatchesInOrdinaryCallbacksDoNotPersist(t *testing.T) {
d := New()
var ok bool
d.SOA(`*`, func(c *Context) { c.Matches = append(c.Matches, "Bad") })
d.SOA(`*`, func(c *Context) { ok = (len(c.Matches) == 0) })
AssertLookup(t, d, SOAQuery(), 0, nil)
h.Assert(t, ok, "Change was persisted")
}
func TestCallbackReturningErrorIsReported(t *testing.T) {
d := New()
d.SOA(`*`, ErrorReplyHandler)
AssertLookup(t, d, SOAQuery(), 0, ErrorReplyError)
}
func TestCallbackReturningErrorBlanksAnswers(t *testing.T) {
d := New()
d.SOA(`*`, ReplyHandler("Foo"))
d.SOA(`*`, ErrorReplyHandler)
AssertLookup(t, d, SOAQuery(), 0, ErrorReplyError)
}
func TestCallbackReturningErrorStopsLaterCallbacksFromRunning(t *testing.T) {
d := New()
ok := true
d.SOA(`*`, ErrorReplyHandler)
d.SOA(`*`, func(c *Context) { ok = false })
AssertLookup(t, d, SOAQuery(), 0, ErrorReplyError)
h.Assert(t, ok, "Later callback was run")
}

View File

@@ -0,0 +1,41 @@
#!/bin/sh
#
records=`curl -sf http://www.iana.org/assignments/dns-parameters/dns-parameters-4.csv | cut -f1 -d',' | sort | uniq | awk '/^[A-Z][A-Z]*$/ { print $0 }'` || exit 1
echo "// AUTOGENERATED helper methods for IANA-registered RRTYPES. Do not edit."
echo "// See generate_dsl_helpers.sh for details"
echo "package dsl"
echo "
import (
\"fmt\"
\"regexp\"
)
"
# A list of RRTypes might come in handy, you never know
echo "var RRTypes = []string{"
for record in $records; do
echo " \"${record}\","
done
echo "}"
for record in $records; do
echo "
// Helper function to register a callback for ${record} queries.
// The matcher is given as a string, which is compiled to a regular expression
// (using regexp.MustCompile) with the following rules:
//
// * The regexp is anchored to the start of the match string(\"^\" at start)
// * The case-insensitivity option is added \"(?i)\"
// * The regexp is anchored to the end of the match string (\"$\" at end)
//
// If any of these options are unwelcome, you can use the DSL.Register and pass
// a regexp and the \"${record}\" string directly.
func (d *DSL) ${record}(matcher string, f Callback) {
re := regexp.MustCompile(fmt.Sprintf(\"^(?i)%s$\", matcher))
d.Register(\"${record}\", re, f)
}"
done

View File

@@ -0,0 +1,98 @@
package test_helpers
import(
"fmt"
"github.com/BytemarkHosting/go-pdns/pipe/backend"
"testing"
)
func EmptyDispatch(b *backend.Backend, q *backend.Query)([]*backend.Response,error) {
return nil, nil
}
func AssertEqualString(t *testing.T, a, b, msg string) {
if a != b {
t.Logf(fmt.Sprintf("%s:'\n%s\n'\nshould be the same as:'\n%s\n'", msg, a, b))
t.FailNow()
}
}
func AssertEqualInt(t *testing.T, a, b int, msg string) {
if a != b {
t.Logf(fmt.Sprintf("%s: %d should == %d", msg, a, b))
t.FailNow()
}
}
func RefuteError(t *testing.T, err error, msg string) {
if err != nil {
t.Logf("Error: %s was expected to be nil", msg, err)
t.FailNow()
}
}
func Assert(t *testing.T, condition bool, msg string) {
if !condition {
t.Log(msg)
t.FailNow()
}
}
func FakeQuery(protoVersion int) *backend.Query {
if protoVersion < 1 || protoVersion > 3 {
panic("Invalid protoVersion")
}
q := backend.Query{
ProtocolVersion: protoVersion,
QClass: "IN",
QType: "ANY",
QName: "example.com",
Id: "-1",
RemoteIpAddress: "127.0.0.2",
}
if protoVersion > 1 {
q.LocalIpAddress = "127.0.0.1"
}
if protoVersion > 2 {
q.EdnsSubnetAddress = "127.0.0.3"
}
return &q
}
func FakeResponse(protoVersion int) *backend.Response {
if protoVersion < 1 || protoVersion > 3 {
panic("Invalid protoVersion")
}
r := backend.Response{
ProtocolVersion: protoVersion,
QClass: "IN",
QType: "ANY",
QName: "example.com",
Id: "-1",
TTL: "3600",
Content: "foo",
}
if protoVersion > 2 {
r.ScopeBits = "24"
r.Auth = "auth"
}
return &r
}
func FakeQueryString(t *testing.T, protoVersion int) string {
str, err := FakeQuery(protoVersion).String()
RefuteError(t, err, "Failed to serialise test query")
return str
}
func FakeResponseString(t *testing.T, protoVersion int) string {
str, err := FakeResponse(protoVersion).String()
RefuteError(t, err, "Failed to serialise test response")
return str
}