tsdns: fix accidental rejection of all non-{A, AAAA} questions.

This is a bug introduced in a903d6c2ed.

Signed-off-by: Dmytro Shynkevych <dmytro@tailscale.com>
This commit is contained in:
Dmytro Shynkevych
2020-08-27 00:07:15 -04:00
parent 28f9cd06f5
commit bc34788e65
3 changed files with 110 additions and 64 deletions
+80 -45
View File
@@ -48,49 +48,64 @@ func dnspacket(domain string, tp dns.Type) []byte {
return payload
}
func extractipcode(response []byte) (netaddr.IP, dns.RCode, error) {
var ip netaddr.IP
type dnsResponse struct {
ip netaddr.IP
name string
rcode dns.RCode
}
func unpackResponse(payload []byte) (dnsResponse, error) {
var response dnsResponse
var parser dns.Parser
h, err := parser.Start(response)
h, err := parser.Start(payload)
if err != nil {
return ip, 0, err
return response, err
}
if !h.Response {
return ip, 0, errors.New("not a response")
return response, errors.New("not a response")
}
if h.RCode != dns.RCodeSuccess {
return ip, h.RCode, nil
response.rcode = h.RCode
if response.rcode != dns.RCodeSuccess {
return response, nil
}
err = parser.SkipAllQuestions()
if err != nil {
return ip, 0, err
return response, err
}
ah, err := parser.AnswerHeader()
if err != nil {
return ip, 0, err
return response, err
}
switch ah.Type {
case dns.TypeA:
res, err := parser.AResource()
if err != nil {
return ip, 0, err
return response, err
}
ip = netaddr.IPv4(res.A[0], res.A[1], res.A[2], res.A[3])
response.ip = netaddr.IPv4(res.A[0], res.A[1], res.A[2], res.A[3])
case dns.TypeAAAA:
res, err := parser.AAAAResource()
if err != nil {
return ip, 0, err
return response, err
}
ip = netaddr.IPv6Raw(res.AAAA)
response.ip = netaddr.IPv6Raw(res.AAAA)
case dns.TypeNS:
res, err := parser.NSResource()
if err != nil {
return response, err
}
response.name = res.NS.String()
default:
return ip, 0, errors.New("type not in {A, AAAA}")
return response, errors.New("type not in {A, AAAA, NS}")
}
return ip, h.RCode, nil
return response, nil
}
func syncRespond(r *Resolver, query []byte) ([]byte, error) {
@@ -188,20 +203,21 @@ func TestResolve(t *testing.T) {
defer r.Close()
tests := []struct {
name string
domain string
ip netaddr.IP
code dns.RCode
name string
qname string
qtype dns.Type
ip netaddr.IP
code dns.RCode
}{
{"ipv4", "test1.ipn.dev.", testipv4, dns.RCodeSuccess},
{"ipv6", "test2.ipn.dev.", testipv6, dns.RCodeSuccess},
{"nxdomain", "test3.ipn.dev.", netaddr.IP{}, dns.RCodeNameError},
{"foreign domain", "google.com.", netaddr.IP{}, dns.RCodeRefused},
{"ipv4", "test1.ipn.dev.", dns.TypeA, testipv4, dns.RCodeSuccess},
{"ipv6", "test2.ipn.dev.", dns.TypeAAAA, testipv6, dns.RCodeSuccess},
{"nxdomain", "test3.ipn.dev.", dns.TypeA, netaddr.IP{}, dns.RCodeNameError},
{"foreign domain", "google.com.", dns.TypeA, netaddr.IP{}, dns.RCodeRefused},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
ip, code, err := r.Resolve(tt.domain)
ip, code, err := r.Resolve(tt.qname, tt.qtype)
if err != nil {
t.Errorf("err = %v; want nil", err)
}
@@ -256,7 +272,7 @@ func TestDelegate(t *testing.T) {
rc := tstest.NewResourceCheck()
defer rc.Assert(t)
dnsHandleFunc("test.site.", resolveToIP(testipv4, testipv6))
dnsHandleFunc("test.site.", resolveToIP(testipv4, testipv6, "dns.test.site."))
dnsHandleFunc("nxdomain.site.", resolveToNXDOMAIN)
v4server, v4errch := serveDNS("127.0.0.1:0")
@@ -296,40 +312,59 @@ func TestDelegate(t *testing.T) {
defer r.Close()
tests := []struct {
name string
query []byte
ip netaddr.IP
code dns.RCode
title string
query []byte
response dnsResponse
}{
{"ipv4", dnspacket("test.site.", dns.TypeA), testipv4, dns.RCodeSuccess},
{"ipv6", dnspacket("test.site.", dns.TypeAAAA), testipv6, dns.RCodeSuccess},
{"nxdomain", dnspacket("nxdomain.site.", dns.TypeA), netaddr.IP{}, dns.RCodeNameError},
{
"ipv4",
dnspacket("test.site.", dns.TypeA),
dnsResponse{ip: testipv4, rcode: dns.RCodeSuccess},
},
{
"ipv6",
dnspacket("test.site.", dns.TypeAAAA),
dnsResponse{ip: testipv6, rcode: dns.RCodeSuccess},
},
{
"ns",
dnspacket("test.site.", dns.TypeNS),
dnsResponse{name: "dns.test.site.", rcode: dns.RCodeSuccess},
},
{
"nxdomain",
dnspacket("nxdomain.site.", dns.TypeA),
dnsResponse{rcode: dns.RCodeNameError},
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
resp, err := syncRespond(r, tt.query)
t.Run(tt.title, func(t *testing.T) {
payload, err := syncRespond(r, tt.query)
if err != nil {
t.Errorf("err = %v; want nil", err)
return
}
ip, code, err := extractipcode(resp)
response, err := unpackResponse(payload)
if err != nil {
t.Errorf("extract: err = %v; want nil (in %x)", err, resp)
t.Errorf("extract: err = %v; want nil (in %x)", err, payload)
return
}
if code != tt.code {
t.Errorf("code = %v; want %v", code, tt.code)
if response.rcode != tt.response.rcode {
t.Errorf("rcode = %v; want %v", response.rcode, tt.response.rcode)
}
if ip != tt.ip {
t.Errorf("ip = %v; want %v", ip, tt.ip)
if response.ip != tt.response.ip {
t.Errorf("ip = %v; want %v", response.ip, tt.response.ip)
}
if response.name != tt.response.name {
t.Errorf("name = %v; want %v", response.name, tt.response.name)
}
})
}
}
func TestDelegateCollision(t *testing.T) {
dnsHandleFunc("test.site.", resolveToIP(testipv4, testipv6))
dnsHandleFunc("test.site.", resolveToIP(testipv4, testipv6, "dns.test.site."))
server, errch := serveDNS("127.0.0.1:0")
defer func() {
@@ -425,13 +460,13 @@ func TestConcurrentSetMap(t *testing.T) {
}()
go func() {
defer wg.Done()
r.Resolve("test1.ipn.dev")
r.Resolve("test1.ipn.dev", dns.TypeA)
}()
wg.Wait()
}
func TestConcurrentSetUpstreams(t *testing.T) {
dnsHandleFunc("test.site.", resolveToIP(testipv4, testipv6))
dnsHandleFunc("test.site.", resolveToIP(testipv4, testipv6, "dns.test.site."))
server, errch := serveDNS("127.0.0.1:0")
defer func() {
@@ -570,7 +605,7 @@ func TestFull(t *testing.T) {
{"ipv6", dnspacket("test2.ipn.dev.", dns.TypeAAAA), ipv6Response},
{"upper", dnspacket("TEST1.IPN.DEV.", dns.TypeA), ipv4UppercaseResponse},
{"ptr", dnspacket("4.3.2.1.in-addr.arpa.", dns.TypePTR), ptrResponse},
{"error", dnspacket("test3.ipn.dev.", dns.TypeA), nxdomainResponse},
{"nxdomain", dnspacket("test3.ipn.dev.", dns.TypeA), nxdomainResponse},
}
for _, tt := range tests {
@@ -619,7 +654,7 @@ func TestAllocs(t *testing.T) {
}
func BenchmarkFull(b *testing.B) {
dnsHandleFunc("test.site.", resolveToIP(testipv4, testipv6))
dnsHandleFunc("test.site.", resolveToIP(testipv4, testipv6, "dns.test.site."))
server, errch := serveDNS("127.0.0.1:0")
defer func() {