tsdns: dual resolution mode, IPv6 support (#526)

This change adds to tsdns the ability to delegate lookups to upstream nameservers.
This is crucial for setting Magic DNS as the system resolver.

Signed-off-by: Dmytro Shynkevych <dmytro@tailscale.com>
This commit is contained in:
Dmytro Shynkevych
2020-07-07 15:25:32 -04:00
committed by GitHub
parent ce1b52bb71
commit 67ebba90e1
7 changed files with 555 additions and 257 deletions
+317 -149
View File
@@ -7,128 +7,319 @@
package tsdns
import (
"encoding/binary"
"bytes"
"context"
"errors"
"strings"
"sync"
"time"
dns "golang.org/x/net/dns/dnsmessage"
"inet.af/netaddr"
"tailscale.com/net/netns"
"tailscale.com/types/logger"
"tailscale.com/wgengine/packet"
)
// maxResponseSize is the maximum size of a response from a Resolver.
const maxResponseSize = 512
// queueSize is the maximal number of DNS requests that can be pending at a time.
// If EnqueueRequest is called when this many requests are already pending,
// the request will be dropped to avoid blocking the caller.
const queueSize = 8
// delegateTimeout is the maximal amount of time Resolver will wait
// for upstream nameservers to process a query.
const delegateTimeout = 5 * time.Second
// defaultTTL is the TTL of all responses from Resolver.
const defaultTTL = 600 * time.Second
// ErrClosed indicates that the resolver has been closed and readers should exit.
var ErrClosed = errors.New("closed")
var (
errAllFailed = errors.New("all upstream nameservers failed")
errFullQueue = errors.New("request queue full")
errMapNotSet = errors.New("domain map not set")
errNoSuchDomain = errors.New("domain does not exist")
errNotImplemented = errors.New("query type not implemented")
errNotOurName = errors.New("not an *.ipn.dev domain")
errNotOurQuery = errors.New("query not for this resolver")
errNotQuery = errors.New("not a DNS query")
errSmallBuffer = errors.New("response buffer too small")
)
var (
defaultIP = packet.IP(binary.BigEndian.Uint32([]byte{100, 100, 100, 100}))
defaultPort = uint16(53)
)
// Map is all the data Resolver needs to resolve DNS queries.
// Map is all the data Resolver needs to resolve DNS queries within the Tailscale network.
type Map struct {
// domainToIP is a mapping of Tailscale domains to their IP addresses.
// For example, monitoring.ipn.dev -> 100.64.0.1.
// For example, monitoring.tailscale.us -> 100.64.0.1.
domainToIP map[string]netaddr.IP
}
// NewMap returns a new Map with domain to address mapping given by domainToIP.
// It takes ownership of the provided map.
func NewMap(domainToIP map[string]netaddr.IP) *Map {
return &Map{
domainToIP: domainToIP,
}
return &Map{domainToIP: domainToIP}
}
// Resolver is a DNS resolver for domain names of the form *.ipn.dev.
// Packet represents a DNS payload together with the address of its origin.
type Packet struct {
// Payload is the application layer DNS payload.
// Resolver assumes ownership of the request payload when it is enqueued
// and cedes ownership of the response payload when it is returned from NextResponse.
Payload []byte
// Addr is the source address for a request and the destination address for a response.
Addr netaddr.IPPort
}
// Resolver is a DNS resolver for nodes on the Tailscale network,
// associating them with domain names of the form <mynode>.<mydomain>.<root>.
// If it is asked to resolve a domain that is not of that form,
// it delegates to upstream nameservers if any are set.
type Resolver struct {
logf logger.Logf
// ip is the IP on which the resolver is listening.
ip packet.IP
// port is the port on which the resolver is listening.
port uint16
// The asynchronous interface is due to the fact that resolution may potentially
// block for a long time (if the upstream nameserver is slow to reach).
// queue is a buffered channel holding DNS requests queued for resolution.
queue chan Packet
// responses is an unbuffered channel to which responses are sent.
responses chan Packet
// errors is an unbuffered channel to which errors are sent.
errors chan error
// closed notifies the poll goroutines to stop.
closed chan struct{}
// pollGroup signals when all poll goroutines have stopped.
pollGroup sync.WaitGroup
// rootDomain is <root> in <mynode>.<mydomain>.<root>.
rootDomain []byte
// dialer is the netns.Dialer used for delegation.
dialer netns.Dialer
// mu guards the following fields from being updated while used.
mu sync.Mutex
mu sync.RWMutex
// dnsMap is the map most recently received from the control server.
dnsMap *Map
// nameservers is the list of nameserver addresses that should be used
// if the received query is not for a Tailscale node.
// The addresses are strings of the form ip:port, as expected by Dial.
nameservers []string
}
// NewResolver constructs a resolver with default parameters.
func NewResolver(logf logger.Logf) *Resolver {
// NewResolver constructs a resolver associated with the given root domain.
func NewResolver(logf logger.Logf, rootDomain string) *Resolver {
r := &Resolver{
logf: logf,
ip: defaultIP,
port: defaultPort,
logf: logger.WithPrefix(logf, "tsdns: "),
queue: make(chan Packet, queueSize),
responses: make(chan Packet),
errors: make(chan error),
closed: make(chan struct{}),
// Conform to the name format dnsmessage uses (trailing period, bytes).
rootDomain: []byte(rootDomain + "."),
dialer: netns.NewDialer(),
}
return r
}
// AcceptsPacket determines if the given packet is
// directed to this resolver (by ip and port).
// We also require that UDP be used to simplify things for now.
func (r *Resolver) AcceptsPacket(in *packet.ParsedPacket) bool {
return in.DstIP == r.ip && in.DstPort == r.port && in.IPProto == packet.UDP
func (r *Resolver) Start() {
// TODO(dmytro): spawn more than one goroutine? They block on delegation.
r.pollGroup.Add(1)
go r.poll()
}
// SetMap sets the resolver's DNS map.
// Close shuts down the resolver and ensures poll goroutines have exited.
// The Resolver cannot be used again after Close is called.
func (r *Resolver) Close() {
select {
case <-r.closed:
return
default:
// continue
}
close(r.closed)
r.pollGroup.Wait()
}
// SetMap sets the resolver's DNS map, taking ownership of it.
func (r *Resolver) SetMap(m *Map) {
r.mu.Lock()
r.dnsMap = m
r.mu.Unlock()
}
// Resolve maps a given domain name to the IP address of the host that owns it.
func (r *Resolver) Resolve(domain string) (netaddr.IP, dns.RCode, error) {
// If not a subdomain of ipn.dev, then we must refuse this query.
// We do this before checking the map to distinguish beween nonexistent domains
// and misdirected queries.
if !strings.HasSuffix(domain, ".ipn.dev") {
return netaddr.IP{}, dns.RCodeRefused, errNotOurName
}
// SetUpstreamNameservers sets the addresses of the resolver's
// upstream nameservers, taking ownership of the argument.
// The addresses should be strings of the form ip:port,
// matching what Dial("udp", addr) expects as addr.
func (r *Resolver) SetNameservers(nameservers []string) {
r.mu.Lock()
r.nameservers = nameservers
r.mu.Unlock()
}
// EnqueueRequest places the given DNS request in the resolver's queue.
// It takes ownership of the payload and does not block.
// If the queue is full, the request will be dropped and an error will be returned.
func (r *Resolver) EnqueueRequest(request Packet) error {
select {
case r.queue <- request:
return nil
default:
return errFullQueue
}
}
// NextResponse returns a DNS response to a previously enqueued request.
// It blocks until a response is available and gives up ownership of the response payload.
func (r *Resolver) NextResponse() (Packet, error) {
select {
case resp := <-r.responses:
return resp, nil
case err := <-r.errors:
return Packet{}, err
case <-r.closed:
return Packet{}, ErrClosed
}
}
// Resolve maps a given domain name to the IP address of the host that owns it.
// The domain name must not have a trailing period.
func (r *Resolver) Resolve(domain string) (netaddr.IP, dns.RCode, error) {
r.mu.RLock()
if r.dnsMap == nil {
r.mu.Unlock()
r.mu.RUnlock()
return netaddr.IP{}, dns.RCodeServerFailure, errMapNotSet
}
addr, found := r.dnsMap.domainToIP[domain]
r.mu.Unlock()
r.mu.RUnlock()
if !found {
return netaddr.IP{}, dns.RCodeNameError, errNoSuchDomain
return netaddr.IP{}, dns.RCodeNameError, nil
}
return addr, dns.RCodeSuccess, nil
}
func (r *Resolver) poll() {
defer r.pollGroup.Done()
var (
packet Packet
err error
)
for {
select {
case packet = <-r.queue:
// continue
case <-r.closed:
return
}
packet.Payload, err = r.respond(packet.Payload)
if err != nil {
select {
case r.errors <- err:
// continue
case <-r.closed:
return
}
} else {
select {
case r.responses <- packet:
// continue
case <-r.closed:
return
}
}
}
}
// queryServer obtains a DNS response by querying the given server.
func (r *Resolver) queryServer(ctx context.Context, server string, query []byte) ([]byte, error) {
conn, err := r.dialer.DialContext(ctx, "udp", server)
if err != nil {
return nil, err
}
defer conn.Close()
// Interrupt the current operation when the context is cancelled.
go func() {
<-ctx.Done()
conn.SetDeadline(time.Unix(1, 0))
}()
_, err = conn.Write(query)
if err != nil {
return nil, err
}
out := make([]byte, maxResponseSize)
n, err := conn.Read(out)
if err != nil {
return nil, err
}
return out[:n], nil
}
// delegate forwards the query to all upstream nameservers and returns the first response.
func (r *Resolver) delegate(query []byte) ([]byte, error) {
r.mu.RLock()
nameservers := r.nameservers
r.mu.RUnlock()
if len(r.nameservers) == 0 {
return nil, errAllFailed
}
ctx, cancel := context.WithTimeout(context.Background(), delegateTimeout)
defer cancel()
// Common case, don't spawn goroutines.
if len(nameservers) == 1 {
return r.queryServer(ctx, nameservers[0], query)
}
datach := make(chan []byte)
for _, server := range nameservers {
go func(s string) {
resp, err := r.queryServer(ctx, s, query)
// Only print errors not due to cancelation after first response.
if err != nil && ctx.Err() != context.Canceled {
r.logf("querying %s: %v", s, err)
}
datach <- resp
}(server)
}
var response []byte
for range nameservers {
cur := <-datach
if cur != nil && response == nil {
// Received first successful response
response = cur
cancel()
}
}
if response == nil {
return nil, errAllFailed
}
return response, nil
}
type response struct {
Header dns.Header
ResourceHeader dns.ResourceHeader
Question dns.Question
// TODO(dmytro): support IPv6.
IP netaddr.IP
Header dns.Header
Question dns.Question
Name string
IP netaddr.IP
}
// parseQuery parses the query in given packet into a response struct.
func (r *Resolver) parseQuery(query *packet.ParsedPacket, resp *response) error {
func (r *Resolver) parseQuery(query []byte, resp *response) error {
var parser dns.Parser
var err error
resp.Header, err = parser.Start(query.Payload())
resp.Header, err = parser.Start(query)
if err != nil {
return err
}
@@ -145,146 +336,123 @@ func (r *Resolver) parseQuery(query *packet.ParsedPacket, resp *response) error
return nil
}
// makeResponse resolves the question stored in resp and sets the answer fields.
func (r *Resolver) makeResponse(resp *response) error {
var err error
name := resp.Question.Name.String()
if len(name) > 0 {
name = name[:len(name)-1]
}
if resp.Question.Type == dns.TypeA {
// Remove final dot from name: *.ipn.dev. -> *.ipn.dev
resp.IP, resp.Header.RCode, err = r.Resolve(name)
} else {
resp.Header.RCode = dns.RCodeNotImplemented
err = errNotImplemented
}
return err
}
// marshalAnswer serializes the answer record into an active builder.
// marshalARecord serializes an A record into an active builder.
// The caller may continue using the builder following the call.
func marshalAnswer(resp *response, builder *dns.Builder) error {
func marshalARecord(name dns.Name, ip netaddr.IP, builder *dns.Builder) error {
var answer dns.AResource
err := builder.StartAnswers()
if err != nil {
return err
}
answerHeader := dns.ResourceHeader{
Name: resp.Question.Name,
Name: name,
Type: dns.TypeA,
Class: dns.ClassINET,
TTL: uint32(defaultTTL / time.Second),
}
ip := resp.IP.As16()
copy(answer.A[:], ip[12:])
ipbytes := ip.As4()
copy(answer.A[:], ipbytes[:])
return builder.AResource(answerHeader, answer)
}
// marshalResponse serializes the DNS response into an active builder.
// marshalAAAARecord serializes an AAAA record into an active builder.
// The caller may continue using the builder following the call.
func marshalResponse(resp *response, builder *dns.Builder) error {
err := builder.StartQuestions()
if err != nil {
return err
}
func marshalAAAARecord(name dns.Name, ip netaddr.IP, builder *dns.Builder) error {
var answer dns.AAAAResource
err = builder.Question(resp.Question)
if err != nil {
return err
answerHeader := dns.ResourceHeader{
Name: name,
Type: dns.TypeAAAA,
Class: dns.ClassINET,
TTL: uint32(defaultTTL / time.Second),
}
if resp.Header.RCode == dns.RCodeSuccess {
err = marshalAnswer(resp, builder)
if err != nil {
return err
}
}
return nil
ipbytes := ip.As16()
copy(answer.AAAA[:], ipbytes[:])
return builder.AAAAResource(answerHeader, answer)
}
// marshalReponsePacket marshals a full DNS packet (including headers)
// representing resp, which is a response to query, into buf.
// It returns buf trimmed to the length of the response packet.
func marshalResponsePacket(query *packet.ParsedPacket, resp *response, buf []byte) ([]byte, error) {
udpHeader := query.UDPHeader()
udpHeader.ToResponse()
offset := udpHeader.Len()
// marshalResponse serializes the DNS response into a new buffer.
func marshalResponse(resp *response) ([]byte, error) {
resp.Header.Response = true
resp.Header.Authoritative = true
if resp.Header.RecursionDesired {
resp.Header.RecursionAvailable = true
}
// dns.Builder appends to the passed buffer (without reallocation when possible),
// so we pass in a zero-length slice starting at the point it should start writing.
builder := dns.NewBuilder(buf[offset:offset], resp.Header)
builder := dns.NewBuilder(nil, resp.Header)
err := marshalResponse(resp, &builder)
err := builder.StartQuestions()
if err != nil {
return nil, err
}
// rbuf is the response slice with the correct length starting at offset.
rbuf, err := builder.Finish()
err = builder.Question(resp.Question)
if err != nil {
return nil, err
}
end := offset + len(rbuf)
err = udpHeader.Marshal(buf[:end])
// Only successful responses contain answers.
if resp.Header.RCode != dns.RCodeSuccess {
return builder.Finish()
}
err = builder.StartAnswers()
if err != nil {
return nil, err
}
return buf[:end], nil
if resp.IP.Is4() {
err = marshalARecord(resp.Question.Name, resp.IP, &builder)
} else {
err = marshalAAAARecord(resp.Question.Name, resp.IP, &builder)
}
if err != nil {
return nil, err
}
return builder.Finish()
}
// Respond writes a response to query into buf and returns buf trimmed to the response length.
// It is assumed that r.AcceptsPacket(query) is true.
func (r *Resolver) Respond(query *packet.ParsedPacket, buf []byte) ([]byte, error) {
var resp response
var err error
// respond returns a DNS response to query.
func (r *Resolver) respond(query []byte) ([]byte, error) {
resp := new(response)
// 0. Verify that contract is upheld.
if !r.AcceptsPacket(query) {
return nil, errNotOurQuery
}
// A DNS response is at least as long as the query
if len(buf) < len(query.Buffer()) {
return nil, errSmallBuffer
}
// 1. Parse query packet.
err = r.parseQuery(query, &resp)
// ParseQuery is sufficiently fast to run on every DNS packet.
// This is considerably simpler than extracting the name by hand
// to shave off microseconds in case of delegation.
err := r.parseQuery(query, resp)
// We will not return this error: it is the sender's fault.
if err != nil {
r.logf("tsdns: error during query parsing: %v", err)
r.logf("parsing query: %v", err)
resp.Header.RCode = dns.RCodeFormatError
return marshalResponsePacket(query, &resp, buf)
return marshalResponse(resp)
}
// 2. Service the query.
err = r.makeResponse(&resp)
// Delegate only when not a subdomain of rootDomain.
// We do this on bytes because Name.String() allocates.
rawName := resp.Question.Name.Data[:resp.Question.Name.Length]
if !bytes.HasSuffix(rawName, r.rootDomain) {
out, err := r.delegate(query)
if err != nil {
r.logf("delegating: %v", err)
resp.Header.RCode = dns.RCodeServerFailure
return marshalResponse(resp)
}
return out, nil
}
switch resp.Question.Type {
case dns.TypeA, dns.TypeAAAA:
domain := resp.Question.Name.String()
// Strip off the trailing period.
// This is safe: Name is guaranteed to have a trailing period by construction.
domain = domain[:len(domain)-1]
resp.IP, resp.Header.RCode, err = r.Resolve(domain)
default:
resp.Header.RCode = dns.RCodeNotImplemented
err = errNotImplemented
}
// We will not return this error: it is the sender's fault.
if err != nil {
r.logf("tsdns: error during name resolution: %v", err)
return marshalResponsePacket(query, &resp, buf)
}
// For now, we require IPv4 in all cases.
// If we somehow came up with a non-IPv4 address, it's our fault.
if !resp.IP.Is4() {
resp.Header.RCode = dns.RCodeServerFailure
r.logf("tsdns: error during name resolution: IPv6 address: %v", resp.IP)
r.logf("resolving: %v", err)
}
// 3. Serialize the response.
return marshalResponsePacket(query, &resp, buf)
return marshalResponse(resp)
}