net/dns/resolver: treat DNS REFUSED responses as soft errors in forwarder race (#19053)

When racing multiple upstream DNS resolvers, a REFUSED (RCode 5) response
from a broken or misconfigured resolver could win the race and be returned
to the client before healthier resolvers had a chance to respond with a
valid answer. This caused complete DNS failure in cases where, e.g., a
broken upstream resolver returned REFUSED quickly while a working resolver
(such as 1.1.1.1) was still responding.

Previously, only SERVFAIL (RCode 2) was treated as a soft error; a REFUSED
response was passed through like any successful answer and could win the
race immediately. This change also treats REFUSED as a soft error in the
UDP and TCP forwarding paths, so the race continues until a better answer
arrives. If all resolvers refuse, the first REFUSED response is returned
to the client.

Additionally, SERVFAIL responses from upstream resolvers are now returned
verbatim to the client rather than replaced with a locally synthesized
packet. Synthesized SERVFAIL responses were authoritative and guaranteed
to include a question section echoing the original query; upstream
responses carry no such guarantees but may include extended error
information (e.g. RFC 8914 extended DNS errors) that would otherwise
be lost.

Fixes #19024

Signed-off-by: Brendan Creane <bcreane@gmail.com>
main
Brendan Creane 4 weeks ago committed by GitHub
parent 04ef9d80b5
commit 0b4c0f2080
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
  1. 110
      net/dns/resolver/forwarder.go
  2. 102
      net/dns/resolver/forwarder_test.go
  3. 32
      net/dns/resolver/tsdns.go
  4. 8
      net/dns/resolver/tsdns_test.go

@ -739,6 +739,27 @@ type truncatedResponseError struct {
func (tr truncatedResponseError) Error() string { return "response truncated" }
// rcodeResponseError is returned when an upstream DNS server responds with an
// rcode that is treated as a soft error (currently REFUSED and SERVFAIL). The
// response bytes are preserved so they can be returned to the client rather
// than synthesizing a new response.
//
// Values of this type are built with positional composite literals in
// sendUDP and sendTCP, so the field order below must not change.
type rcodeResponseError struct {
	// rcode is the response code extracted from the upstream response.
	rcode dns.RCode
	// res holds the raw upstream response bytes, kept so the forwarder can
	// hand them to the client verbatim if no better answer arrives.
	res []byte
}
// Error implements the error interface. It reports the message of the
// sentinel error that Unwrap maps r.rcode to.
//
// Unwrap returns nil for any rcode that has no sentinel; calling Error on
// that nil result would panic, so a generic message naming the rcode is
// returned instead.
func (r rcodeResponseError) Error() string {
	if err := r.Unwrap(); err != nil {
		return err.Error()
	}
	return "response code indicates soft error: " + r.rcode.String()
}
// Unwrap exposes the sentinel error matching r.rcode so callers can test the
// failure class with errors.Is: errRefused for REFUSED, errServerFailure for
// SERVFAIL. Any other rcode has no sentinel and yields nil.
func (r rcodeResponseError) Unwrap() error {
	if r.rcode == dns.RCodeRefused {
		return errRefused
	}
	if r.rcode == dns.RCodeServerFailure {
		return errServerFailure
	}
	return nil
}
// Sentinel errors describing why an upstream DNS response was rejected.
var (
	// errRefused marks an upstream response whose rcode was REFUSED.
	errRefused = errors.New("response code indicates refusal")
	// errServerFailure marks an upstream response whose rcode was SERVFAIL.
	errServerFailure = errors.New("response code indicates server issue")
	// errTxIDMismatch marks a response whose transaction ID did not match.
	errTxIDMismatch = errors.New("txid doesn't match")
)
@ -812,10 +833,16 @@ func (f *forwarder) sendUDP(ctx context.Context, fq *forwardQuery, rr resolverAn
rcode := getRCode(out)
// don't forward transient errors back to the client when the server fails
if rcode == dns.RCodeServerFailure {
f.logf("recv: response code indicating server failure: %d", rcode)
switch rcode {
case dns.RCodeServerFailure:
f.logf("sendUDP: response code indicating server failure: %d", rcode)
metricDNSFwdUDPErrorServer.Add(1)
return nil, errServerFailure
return nil, rcodeResponseError{dns.RCodeServerFailure, out}
case dns.RCodeRefused:
// treat REFUSED as a soft error so other resolvers in the race can respond
f.logf("sendUDP: response code indicating refusal: %d", rcode)
metricDNSFwdUDPErrorRefused.Add(1)
return nil, rcodeResponseError{dns.RCodeRefused, out}
}
// Set the truncated bit if buffer was truncated during read and the flag isn't already set
@ -951,10 +978,16 @@ func (f *forwarder) sendTCP(ctx context.Context, fq *forwardQuery, rr resolverAn
rcode := getRCode(out)
// don't forward transient errors back to the client when the server fails
if rcode == dns.RCodeServerFailure {
switch rcode {
case dns.RCodeServerFailure:
f.logf("sendTCP: response code indicating server failure: %d", rcode)
metricDNSFwdTCPErrorServer.Add(1)
return nil, errServerFailure
return nil, rcodeResponseError{dns.RCodeServerFailure, out}
case dns.RCodeRefused:
// treat REFUSED as a soft error so other resolvers in the race can respond
f.logf("sendTCP: response code indicating refusal: %d", rcode)
metricDNSFwdTCPErrorRefused.Add(1)
return nil, rcodeResponseError{dns.RCodeRefused, out}
}
// TODO(andrew): do we need to do this?
@ -1128,6 +1161,7 @@ func (f *forwarder) forwardWithDestChan(ctx context.Context, query packet, respo
var firstErr error
var numErr int
var sawNonRefused bool
for {
select {
case v := <-resc:
@ -1147,32 +1181,56 @@ func (f *forwarder) forwardWithDestChan(ctx context.Context, query packet, respo
if firstErr == nil {
firstErr = err
}
if !errors.Is(err, errRefused) {
sawNonRefused = true
}
numErr++
if numErr == len(resolvers) {
if errors.Is(firstErr, errServerFailure) {
res, err := servfailResponse(query)
if err != nil {
f.logf("building servfail response: %v", err)
var res packet
if sawNonRefused {
// At least one server failed with SERVFAIL or a transport error
// (e.g. network failure, TxID mismatch, unsupported resolver type).
// All such errors map to SERVFAIL at the client level.
// Prefer returning the upstream SERVFAIL bytes from firstErr if
// available; otherwise synthesize a SERVFAIL response. Note the
// rcode guard: firstErr may be a REFUSED rcodeResponseError if it
// arrived before the SERVFAIL that set sawNonRefused.
if rcodeErr, ok := errors.AsType[rcodeResponseError](firstErr); ok && rcodeErr.rcode == dns.RCodeServerFailure {
res = packet{rcodeErr.res, query.family, query.addr}
} else {
r, err := servfailResponse(query)
if err != nil {
f.logf("building servfail response: %v", err)
return firstErr
}
res = r
}
} else {
// !sawNonRefused means every error was an rcodeResponseError with rcode REFUSED,
// so firstErr is guaranteed to wrap one.
rcodeErr, ok := errors.AsType[rcodeResponseError](firstErr)
if !ok {
f.logf("unexpected: all errors were REFUSED but firstErr is not rcodeResponseError: %v", firstErr)
return firstErr
}
select {
case <-ctx.Done():
metricDNSFwdErrorContext.Add(1)
metricDNSFwdErrorContextGotError.Add(1)
var resolverAddrs []string
for _, rr := range resolvers {
resolverAddrs = append(resolverAddrs, rr.name.Addr)
}
if f.acceptDNS {
f.health.SetUnhealthy(dnsForwarderFailing, health.Args{health.ArgDNSServers: strings.Join(resolverAddrs, ",")})
}
case responseChan <- res:
if f.verboseFwd {
f.logf("forwarder response(%d, %v, %d) = %d, %v", fq.txid, typ, len(domain), len(res.bs), firstErr)
}
return nil
res = packet{rcodeErr.res, query.family, query.addr}
}
select {
case <-ctx.Done():
metricDNSFwdErrorContext.Add(1)
metricDNSFwdErrorContextGotError.Add(1)
var resolverAddrs []string
for _, rr := range resolvers {
resolverAddrs = append(resolverAddrs, rr.name.Addr)
}
if f.acceptDNS {
f.health.SetUnhealthy(dnsForwarderFailing, health.Args{health.ArgDNSServers: strings.Join(resolverAddrs, ",")})
}
case responseChan <- res:
if f.verboseFwd {
f.logf("forwarder response(%d, %v, %d) = %d, %v", fq.txid, typ, len(domain), len(res.bs), firstErr)
}
return nil
}
return firstErr
}

@ -1162,8 +1162,19 @@ func TestForwarderWithManyResolvers(t *testing.T) {
},
},
{
name: "Refused",
responses: [][]byte{ // All upstream servers return different failures.
name: "AllRefused",
responses: [][]byte{ // All upstream servers return REFUSED.
makeTestResponse(t, domain, dns.RCodeRefused),
makeTestResponse(t, domain, dns.RCodeRefused),
makeTestResponse(t, domain, dns.RCodeRefused),
},
wantResponses: [][]byte{ // When all refuse, return REFUSED to the client.
makeTestResponse(t, domain, dns.RCodeRefused),
},
},
{
name: "Refused+Success",
responses: [][]byte{ // Some upstream servers refuse, but one succeeds.
makeTestResponse(t, domain, dns.RCodeRefused),
makeTestResponse(t, domain, dns.RCodeRefused),
makeTestResponse(t, domain, dns.RCodeRefused),
@ -1171,21 +1182,30 @@ func TestForwarderWithManyResolvers(t *testing.T) {
makeTestResponse(t, domain, dns.RCodeRefused),
makeTestResponse(t, domain, dns.RCodeSuccess, netip.MustParseAddr("127.0.0.1")),
},
wantResponses: [][]byte{ // Refused is not considered to be an error and can be forwarded.
makeTestResponse(t, domain, dns.RCodeRefused),
wantResponses: [][]byte{ // Refused is treated as a soft error; the Success response should win.
makeTestResponse(t, domain, dns.RCodeSuccess, netip.MustParseAddr("127.0.0.1")),
},
},
{
name: "Refused+ServFail",
responses: [][]byte{ // Some servers refuse, at least one fails.
makeTestResponse(t, domain, dns.RCodeRefused),
makeTestResponse(t, domain, dns.RCodeServerFailure),
makeTestResponse(t, domain, dns.RCodeRefused),
},
wantResponses: [][]byte{ // Any non-REFUSED failure triggers SERVFAIL regardless of arrival order.
makeTestResponse(t, domain, dns.RCodeServerFailure),
},
},
{
name: "MixFail",
responses: [][]byte{ // All upstream servers return different failures.
responses: [][]byte{ // Upstream servers return different failures.
makeTestResponse(t, domain, dns.RCodeServerFailure),
makeTestResponse(t, domain, dns.RCodeNameError),
makeTestResponse(t, domain, dns.RCodeRefused),
},
wantResponses: [][]byte{ // Both NXDomain and Refused can be forwarded.
wantResponses: [][]byte{ // SERVFAIL and REFUSED are soft errors; NXDOMAIN wins.
makeTestResponse(t, domain, dns.RCodeNameError),
makeTestResponse(t, domain, dns.RCodeRefused),
},
},
}
@ -1297,3 +1317,71 @@ func TestForwarderVerboseLogs(t *testing.T) {
t.Errorf("expected forwarding log, got:\n%s", logStr)
}
}
// TestForwarderHealthOnContextExpiry exercises the case where every upstream
// resolver fails and the context is cancelled before the forwarder can hand
// its response to the caller. The dnsForwarderFailing warnable must end up
// unhealthy exactly when acceptDNS is enabled.
func TestForwarderHealthOnContextExpiry(t *testing.T) {
	const domain = "health-test.example.com."

	cases := []struct {
		name          string
		acceptDNS     bool
		wantUnhealthy bool
	}{
		{name: "acceptDNS=true", acceptDNS: true, wantUnhealthy: true},
		{name: "acceptDNS=false", acceptDNS: false, wantUnhealthy: false},
	}

	for _, tc := range cases {
		t.Run(tc.name, func(t *testing.T) {
			request := makeTestRequest(t, domain, dns.TypeA, 0)

			logf := tstest.WhileTestRunningLogger(t)
			bus := eventbustest.NewBus(t)
			netMon, err := netmon.New(bus, logf)
			if err != nil {
				t.Fatal(err)
			}

			var dialer tsdial.Dialer
			dialer.SetNetMon(netMon)
			dialer.SetBus(bus)

			tracker := health.NewTracker(bus)
			fwd := newForwarder(logf, netMon, nil, &dialer, tracker, nil)
			fwd.acceptDNS = tc.acceptDNS

			// Start two upstream servers that both answer SERVFAIL, so every
			// resolver in the race reports an error.
			noop := func(bool, []byte) {}
			var resolvers []resolverAndDelay
			for i := 0; i < 2; i++ {
				port := runDNSServer(t, nil, makeTestResponse(t, domain, dns.RCodeServerFailure), noop)
				resolvers = append(resolvers, resolverAndDelay{
					name: &dnstype.Resolver{Addr: fmt.Sprintf("127.0.0.1:%d", port)},
				})
			}

			query := packet{
				bs:     request,
				family: "udp",
				addr:   netip.MustParseAddrPort("127.0.0.1:12345"),
			}

			// Nobody reads from this unbuffered channel, so the forwarder's
			// send blocks until ctx is cancelled; that forces the ctx.Done
			// branch and its SetUnhealthy call.
			responseChan := make(chan packet)

			// Delay cancellation so the DNS servers have time to respond and
			// their errors can be collected first, leaving
			// forwardWithDestChan blocked on responseChan.
			ctx, cancel := context.WithCancel(context.Background())
			time.AfterFunc(50*time.Millisecond, cancel)

			fwd.forwardWithDestChan(ctx, query, responseChan, resolvers...)

			if got := tracker.IsUnhealthy(dnsForwarderFailing); got != tc.wantUnhealthy {
				t.Errorf("IsUnhealthy = %v, want %v", got, tc.wantUnhealthy)
			}
		})
	}
}

@ -1402,21 +1402,23 @@ var (
metricDNSFwdErrorType = clientmetric.NewCounter("dns_query_fwd_error_type")
metricDNSFwdTruncated = clientmetric.NewCounter("dns_query_fwd_truncated")
metricDNSFwdUDP = clientmetric.NewCounter("dns_query_fwd_udp") // on entry
metricDNSFwdUDPWrote = clientmetric.NewCounter("dns_query_fwd_udp_wrote") // sent UDP packet
metricDNSFwdUDPErrorWrite = clientmetric.NewCounter("dns_query_fwd_udp_error_write")
metricDNSFwdUDPErrorServer = clientmetric.NewCounter("dns_query_fwd_udp_error_server")
metricDNSFwdUDPErrorTxID = clientmetric.NewCounter("dns_query_fwd_udp_error_txid")
metricDNSFwdUDPErrorRead = clientmetric.NewCounter("dns_query_fwd_udp_error_read")
metricDNSFwdUDPSuccess = clientmetric.NewCounter("dns_query_fwd_udp_success")
metricDNSFwdTCP = clientmetric.NewCounter("dns_query_fwd_tcp") // on entry
metricDNSFwdTCPWrote = clientmetric.NewCounter("dns_query_fwd_tcp_wrote") // sent TCP packet
metricDNSFwdTCPErrorWrite = clientmetric.NewCounter("dns_query_fwd_tcp_error_write")
metricDNSFwdTCPErrorServer = clientmetric.NewCounter("dns_query_fwd_tcp_error_server")
metricDNSFwdTCPErrorTxID = clientmetric.NewCounter("dns_query_fwd_tcp_error_txid")
metricDNSFwdTCPErrorRead = clientmetric.NewCounter("dns_query_fwd_tcp_error_read")
metricDNSFwdTCPSuccess = clientmetric.NewCounter("dns_query_fwd_tcp_success")
metricDNSFwdUDP = clientmetric.NewCounter("dns_query_fwd_udp") // on entry
metricDNSFwdUDPWrote = clientmetric.NewCounter("dns_query_fwd_udp_wrote") // sent UDP packet
metricDNSFwdUDPErrorWrite = clientmetric.NewCounter("dns_query_fwd_udp_error_write")
metricDNSFwdUDPErrorServer = clientmetric.NewCounter("dns_query_fwd_udp_error_server")
metricDNSFwdUDPErrorRefused = clientmetric.NewCounter("dns_query_fwd_udp_error_refused")
metricDNSFwdUDPErrorTxID = clientmetric.NewCounter("dns_query_fwd_udp_error_txid")
metricDNSFwdUDPErrorRead = clientmetric.NewCounter("dns_query_fwd_udp_error_read")
metricDNSFwdUDPSuccess = clientmetric.NewCounter("dns_query_fwd_udp_success")
metricDNSFwdTCP = clientmetric.NewCounter("dns_query_fwd_tcp") // on entry
metricDNSFwdTCPWrote = clientmetric.NewCounter("dns_query_fwd_tcp_wrote") // sent TCP packet
metricDNSFwdTCPErrorWrite = clientmetric.NewCounter("dns_query_fwd_tcp_error_write")
metricDNSFwdTCPErrorServer = clientmetric.NewCounter("dns_query_fwd_tcp_error_server")
metricDNSFwdTCPErrorRefused = clientmetric.NewCounter("dns_query_fwd_tcp_error_refused")
metricDNSFwdTCPErrorTxID = clientmetric.NewCounter("dns_query_fwd_tcp_error_txid")
metricDNSFwdTCPErrorRead = clientmetric.NewCounter("dns_query_fwd_tcp_error_read")
metricDNSFwdTCPSuccess = clientmetric.NewCounter("dns_query_fwd_tcp_success")
metricDNSFwdDoH = clientmetric.NewCounter("dns_query_fwd_doh")
metricDNSFwdDoHErrorStatus = clientmetric.NewCounter("dns_query_fwd_doh_error_status")

@ -1557,15 +1557,13 @@ func TestServfail(t *testing.T) {
t.Fatalf("err = %v, want nil", err)
}
// The upstream server's SERVFAIL bytes are returned directly.
wantPkt := []byte{
0x00, 0x00, // transaction id: 0
0x84, 0x02, // flags: response, authoritative, error: servfail
0x00, 0x01, // one question
0x00, 0x02, // flags: error: servfail
0x00, 0x00, // no questions (upstream sent a minimal response)
0x00, 0x00, // no answers
0x00, 0x00, 0x00, 0x00, // no authority or additional RRs
// Question:
0x04, 0x74, 0x65, 0x73, 0x74, 0x04, 0x73, 0x69, 0x74, 0x65, 0x00, // name
0x00, 0x01, 0x00, 0x01, // type A, class IN
}
if !bytes.Equal(pkt, wantPkt) {

Loading…
Cancel
Save