tsdns: implement reverse DNS lookups, canonicalize names everywhere. (#640)
Signed-off-by: Dmytro Shynkevych <dmytro@tailscale.com>
This commit is contained in:
committed by
GitHub
parent
41c4560592
commit
78c2e1ff83
+191
-28
@@ -9,6 +9,7 @@ package tsdns
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"encoding/hex"
|
||||
"errors"
|
||||
"sync"
|
||||
"time"
|
||||
@@ -84,7 +85,7 @@ type Resolver struct {
|
||||
dialer netns.Dialer
|
||||
|
||||
// mu guards the following fields from being updated while used.
|
||||
mu sync.RWMutex
|
||||
mu sync.Mutex
|
||||
// dnsMap is the map most recently received from the control server.
|
||||
dnsMap *Map
|
||||
// nameservers is the list of nameserver addresses that should be used
|
||||
@@ -94,15 +95,15 @@ type Resolver struct {
|
||||
}
|
||||
|
||||
// NewResolver constructs a resolver associated with the given root domain.
|
||||
// The root domain must be in canonical form (with a trailing period).
|
||||
func NewResolver(logf logger.Logf, rootDomain string) *Resolver {
|
||||
r := &Resolver{
|
||||
logf: logger.WithPrefix(logf, "tsdns: "),
|
||||
queue: make(chan Packet, queueSize),
|
||||
responses: make(chan Packet),
|
||||
errors: make(chan error),
|
||||
closed: make(chan struct{}),
|
||||
// Conform to the name format dnsmessage uses (trailing period, bytes).
|
||||
rootDomain: []byte(rootDomain + "."),
|
||||
logf: logger.WithPrefix(logf, "tsdns: "),
|
||||
queue: make(chan Packet, queueSize),
|
||||
responses: make(chan Packet),
|
||||
errors: make(chan error),
|
||||
closed: make(chan struct{}),
|
||||
rootDomain: []byte(rootDomain),
|
||||
dialer: netns.NewDialer(),
|
||||
}
|
||||
|
||||
@@ -173,22 +174,40 @@ func (r *Resolver) NextResponse() (Packet, error) {
|
||||
}
|
||||
|
||||
// Resolve maps a given domain name to the IP address of the host that owns it.
|
||||
// The domain name must not have a trailing period.
|
||||
// The domain name must be in canonical form (with a trailing period).
|
||||
func (r *Resolver) Resolve(domain string) (netaddr.IP, dns.RCode, error) {
|
||||
r.mu.RLock()
|
||||
if r.dnsMap == nil {
|
||||
r.mu.RUnlock()
|
||||
r.mu.Lock()
|
||||
dnsMap := r.dnsMap
|
||||
r.mu.Unlock()
|
||||
|
||||
if dnsMap == nil {
|
||||
return netaddr.IP{}, dns.RCodeServerFailure, errMapNotSet
|
||||
}
|
||||
addr, found := r.dnsMap.nameToIP[domain]
|
||||
r.mu.RUnlock()
|
||||
|
||||
addr, found := dnsMap.nameToIP[domain]
|
||||
if !found {
|
||||
return netaddr.IP{}, dns.RCodeNameError, nil
|
||||
}
|
||||
return addr, dns.RCodeSuccess, nil
|
||||
}
|
||||
|
||||
// ResolveReverse returns the unique domain name that maps to the given address.
|
||||
// The returned domain name is in canonical form (with a trailing period).
|
||||
func (r *Resolver) ResolveReverse(ip netaddr.IP) (string, dns.RCode, error) {
|
||||
r.mu.Lock()
|
||||
dnsMap := r.dnsMap
|
||||
r.mu.Unlock()
|
||||
|
||||
if dnsMap == nil {
|
||||
return "", dns.RCodeServerFailure, errMapNotSet
|
||||
}
|
||||
name, found := dnsMap.ipToName[ip]
|
||||
if !found {
|
||||
return "", dns.RCodeNameError, nil
|
||||
}
|
||||
return name, dns.RCodeSuccess, nil
|
||||
}
|
||||
|
||||
func (r *Resolver) poll() {
|
||||
defer r.pollGroup.Done()
|
||||
|
||||
@@ -253,9 +272,9 @@ func (r *Resolver) queryServer(ctx context.Context, server string, query []byte)
|
||||
|
||||
// delegate forwards the query to all upstream nameservers and returns the first response.
|
||||
func (r *Resolver) delegate(query []byte) ([]byte, error) {
|
||||
r.mu.RLock()
|
||||
r.mu.Lock()
|
||||
nameservers := r.nameservers
|
||||
r.mu.RUnlock()
|
||||
r.mu.Unlock()
|
||||
|
||||
if len(nameservers) == 0 {
|
||||
return nil, errNoNameservers
|
||||
@@ -301,8 +320,10 @@ func (r *Resolver) delegate(query []byte) ([]byte, error) {
|
||||
type response struct {
|
||||
Header dns.Header
|
||||
Question dns.Question
|
||||
Name string
|
||||
IP netaddr.IP
|
||||
// Name is the response to a PTR query.
|
||||
Name string
|
||||
// IP is the response to an A, AAAA, or ANY query.
|
||||
IP netaddr.IP
|
||||
}
|
||||
|
||||
// parseQuery parses the query in given packet into a response struct.
|
||||
@@ -359,6 +380,25 @@ func marshalAAAARecord(name dns.Name, ip netaddr.IP, builder *dns.Builder) error
|
||||
return builder.AAAAResource(answerHeader, answer)
|
||||
}
|
||||
|
||||
// marshalPTRRecord serializes a PTR record into an active builder.
|
||||
// The caller may continue using the builder following the call.
|
||||
func marshalPTRRecord(queryName dns.Name, name string, builder *dns.Builder) error {
|
||||
var answer dns.PTRResource
|
||||
var err error
|
||||
|
||||
answerHeader := dns.ResourceHeader{
|
||||
Name: queryName,
|
||||
Type: dns.TypePTR,
|
||||
Class: dns.ClassINET,
|
||||
TTL: uint32(defaultTTL / time.Second),
|
||||
}
|
||||
answer.PTR, err = dns.NewName(name)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
return builder.PTRResource(answerHeader, answer)
|
||||
}
|
||||
|
||||
// marshalResponse serializes the DNS response into a new buffer.
|
||||
func marshalResponse(resp *response) ([]byte, error) {
|
||||
resp.Header.Response = true
|
||||
@@ -389,10 +429,15 @@ func marshalResponse(resp *response) ([]byte, error) {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if resp.IP.Is4() {
|
||||
err = marshalARecord(resp.Question.Name, resp.IP, &builder)
|
||||
} else {
|
||||
err = marshalAAAARecord(resp.Question.Name, resp.IP, &builder)
|
||||
switch resp.Question.Type {
|
||||
case dns.TypeA, dns.TypeAAAA, dns.TypeALL:
|
||||
if resp.IP.Is4() {
|
||||
err = marshalARecord(resp.Question.Name, resp.IP, &builder)
|
||||
} else {
|
||||
err = marshalAAAARecord(resp.Question.Name, resp.IP, &builder)
|
||||
}
|
||||
case dns.TypePTR:
|
||||
err = marshalPTRRecord(resp.Question.Name, resp.Name, &builder)
|
||||
}
|
||||
if err != nil {
|
||||
return nil, err
|
||||
@@ -401,6 +446,120 @@ func marshalResponse(resp *response) ([]byte, error) {
|
||||
return builder.Finish()
|
||||
}
|
||||
|
||||
var (
|
||||
rdnsv4Suffix = []byte(".in-addr.arpa.")
|
||||
rdnsv6Suffix = []byte(".ip6.arpa.")
|
||||
)
|
||||
|
||||
// ptrNameToIPv4 transforms a PTR name representing an IPv4 address to said address.
|
||||
// Such names are IPv4 labels in reverse order followed by .in-addr.arpa.
|
||||
// For example,
|
||||
// 4.3.2.1.in-addr.arpa
|
||||
// is transformed to
|
||||
// 1.2.3.4
|
||||
func rdnsNameToIPv4(name []byte) (ip netaddr.IP, ok bool) {
|
||||
name = bytes.TrimSuffix(name, rdnsv4Suffix)
|
||||
ip, err := netaddr.ParseIP(string(name))
|
||||
if err != nil {
|
||||
return netaddr.IP{}, false
|
||||
}
|
||||
if !ip.Is4() {
|
||||
return netaddr.IP{}, false
|
||||
}
|
||||
b := ip.As4()
|
||||
return netaddr.IPv4(b[3], b[2], b[1], b[0]), true
|
||||
}
|
||||
|
||||
// ptrNameToIPv6 transforms a PTR name representing an IPv6 address to said address.
|
||||
// Such names are dot-separated nibbles in reverse order followed by .ip6.arpa.
|
||||
// For example,
|
||||
// b.a.9.8.7.6.5.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.8.b.d.0.1.0.0.2.ip6.arpa.
|
||||
// is transformed to
|
||||
// 2001:db8::567:89ab
|
||||
func rdnsNameToIPv6(name []byte) (ip netaddr.IP, ok bool) {
|
||||
var b [32]byte
|
||||
var ipb [16]byte
|
||||
|
||||
name = bytes.TrimSuffix(name, rdnsv6Suffix)
|
||||
// 32 nibbles and 31 dots between them.
|
||||
if len(name) != 63 {
|
||||
return netaddr.IP{}, false
|
||||
}
|
||||
|
||||
// Dots and hex digits alternate.
|
||||
prevDot := true
|
||||
// i ranges over name backward; j ranges over b forward.
|
||||
for i, j := len(name)-1, 0; i >= 0; i-- {
|
||||
thisDot := (name[i] == '.')
|
||||
if prevDot == thisDot {
|
||||
return netaddr.IP{}, false
|
||||
}
|
||||
prevDot = thisDot
|
||||
|
||||
if !thisDot {
|
||||
// This is safe assuming alternation.
|
||||
// We do not check that non-dots are hex digits: hex.Decode below will do that.
|
||||
b[j] = name[i]
|
||||
j++
|
||||
}
|
||||
}
|
||||
|
||||
_, err := hex.Decode(ipb[:], b[:])
|
||||
if err != nil {
|
||||
return netaddr.IP{}, false
|
||||
}
|
||||
|
||||
return netaddr.IPFrom16(ipb), true
|
||||
}
|
||||
|
||||
// respondReverse returns a DNS response to a PTR query.
|
||||
// It is assumed that resp.Question is populated by respond before this is called.
|
||||
func (r *Resolver) respondReverse(query []byte, resp *response) ([]byte, error) {
|
||||
name := resp.Question.Name.Data[:resp.Question.Name.Length]
|
||||
|
||||
shouldDelegate := false
|
||||
|
||||
var ip netaddr.IP
|
||||
var ok bool
|
||||
var err error
|
||||
switch {
|
||||
case bytes.HasSuffix(name, rdnsv4Suffix):
|
||||
ip, ok = rdnsNameToIPv4(name)
|
||||
case bytes.HasSuffix(name, rdnsv6Suffix):
|
||||
ip, ok = rdnsNameToIPv6(name)
|
||||
default:
|
||||
shouldDelegate = true
|
||||
}
|
||||
|
||||
// It is more likely that we failed in parsing the name than that it is actually malformed.
|
||||
// To avoid frustrating users, just log and delegate.
|
||||
if !ok {
|
||||
// Without this conversion, escape analysis rules that resp escapes.
|
||||
r.logf("parsing rdns: malformed name: %s", resp.Question.Name.String())
|
||||
shouldDelegate = true
|
||||
}
|
||||
|
||||
if !shouldDelegate {
|
||||
resp.Name, resp.Header.RCode, err = r.ResolveReverse(ip)
|
||||
if err != nil {
|
||||
r.logf("resolving rdns: %v", ip, err)
|
||||
}
|
||||
shouldDelegate = (resp.Header.RCode == dns.RCodeNameError)
|
||||
}
|
||||
|
||||
if shouldDelegate {
|
||||
out, err := r.delegate(query)
|
||||
if err != nil {
|
||||
r.logf("delegating rdns: %v", err)
|
||||
resp.Header.RCode = dns.RCodeServerFailure
|
||||
return marshalResponse(resp)
|
||||
}
|
||||
return out, nil
|
||||
}
|
||||
|
||||
return marshalResponse(resp)
|
||||
}
|
||||
|
||||
// respond returns a DNS response to query.
|
||||
func (r *Resolver) respond(query []byte) ([]byte, error) {
|
||||
resp := new(response)
|
||||
@@ -416,7 +575,14 @@ func (r *Resolver) respond(query []byte) ([]byte, error) {
|
||||
return marshalResponse(resp)
|
||||
}
|
||||
|
||||
// Delegate only when not a subdomain of rootDomain.
|
||||
// Always try to handle reverse lookups; delegate inside when not found.
|
||||
// This way, queries for exitent nodes do not leak,
|
||||
// but we behave gracefully if non-Tailscale nodes exist in CGNATRange.
|
||||
if resp.Question.Type == dns.TypePTR {
|
||||
return r.respondReverse(query, resp)
|
||||
}
|
||||
|
||||
// Delegate forward lookups when not a subdomain of rootDomain.
|
||||
// We do this on bytes because Name.String() allocates.
|
||||
rawName := resp.Question.Name.Data[:resp.Question.Name.Length]
|
||||
if !bytes.HasSuffix(rawName, r.rootDomain) {
|
||||
@@ -431,11 +597,8 @@ func (r *Resolver) respond(query []byte) ([]byte, error) {
|
||||
|
||||
switch resp.Question.Type {
|
||||
case dns.TypeA, dns.TypeAAAA, dns.TypeALL:
|
||||
domain := resp.Question.Name.String()
|
||||
// Strip off the trailing period.
|
||||
// This is safe: Name is guaranteed to have a trailing period by construction.
|
||||
domain = domain[:len(domain)-1]
|
||||
resp.IP, resp.Header.RCode, err = r.Resolve(domain)
|
||||
name := resp.Question.Name.String()
|
||||
resp.IP, resp.Header.RCode, err = r.Resolve(name)
|
||||
default:
|
||||
resp.Header.RCode = dns.RCodeNotImplemented
|
||||
err = errNotImplemented
|
||||
|
||||
Reference in New Issue
Block a user