From 0b4c0f208049358391b66756e4decb4b0f72177f Mon Sep 17 00:00:00 2001
From: Brendan Creane
Date: Mon, 23 Mar 2026 10:40:05 -0700
Subject: [PATCH] net/dns/resolver: treat DNS REFUSED responses as soft errors in forwarder race (#19053)

When racing multiple upstream DNS resolvers, a REFUSED (RCode 5)
response from a broken or misconfigured resolver could win the race and
be returned to the client before healthier resolvers had a chance to
respond with a valid answer. This caused complete DNS failure in cases
where, e.g., a broken upstream resolver returned REFUSED quickly while a
working resolver (such as 1.1.1.1) was still responding.

Previously, only SERVFAIL (RCode 2) was treated as a soft error. REFUSED
responses were returned as successful bytes and could win the race
immediately. This change also treats REFUSED as a soft error in the UDP
and TCP forwarding paths, so the race continues until a better answer
arrives. If all resolvers refuse, the first REFUSED response is returned
to the client.

Additionally, SERVFAIL responses from upstream resolvers are now
returned verbatim to the client rather than replaced with a locally
synthesized packet. Synthesized SERVFAIL responses were authoritative
and guaranteed to include a question section echoing the original query;
upstream responses carry no such guarantees but may include extended
error information (e.g. RFC 8914 extended DNS errors) that would
otherwise be lost.

Fixes #19024 Signed-off-by: Brendan Creane --- net/dns/resolver/forwarder.go | 110 ++++++++++++++++++++++------- net/dns/resolver/forwarder_test.go | 102 ++++++++++++++++++++++++-- net/dns/resolver/tsdns.go | 32 +++++---- net/dns/resolver/tsdns_test.go | 8 +-- 4 files changed, 199 insertions(+), 53 deletions(-) diff --git a/net/dns/resolver/forwarder.go b/net/dns/resolver/forwarder.go index ca1599589..ed7ff78f7 100644 --- a/net/dns/resolver/forwarder.go +++ b/net/dns/resolver/forwarder.go @@ -739,6 +739,27 @@ type truncatedResponseError struct { func (tr truncatedResponseError) Error() string { return "response truncated" } +// rcodeResponseError is returned when an upstream DNS server responds with an +// rcode that is treated as a soft error (currently REFUSED and SERVFAIL). The +// response bytes are preserved so they can be returned to the client rather +// than synthesizing a new response. +type rcodeResponseError struct { + rcode dns.RCode + res []byte +} + +func (r rcodeResponseError) Error() string { return r.Unwrap().Error() } +func (r rcodeResponseError) Unwrap() error { + switch r.rcode { + case dns.RCodeRefused: + return errRefused + case dns.RCodeServerFailure: + return errServerFailure + } + return nil +} + +var errRefused = errors.New("response code indicates refusal") var errServerFailure = errors.New("response code indicates server issue") var errTxIDMismatch = errors.New("txid doesn't match") @@ -812,10 +833,16 @@ func (f *forwarder) sendUDP(ctx context.Context, fq *forwardQuery, rr resolverAn rcode := getRCode(out) // don't forward transient errors back to the client when the server fails - if rcode == dns.RCodeServerFailure { - f.logf("recv: response code indicating server failure: %d", rcode) + switch rcode { + case dns.RCodeServerFailure: + f.logf("sendUDP: response code indicating server failure: %d", rcode) metricDNSFwdUDPErrorServer.Add(1) - return nil, errServerFailure + return nil, rcodeResponseError{dns.RCodeServerFailure, out} + case 
dns.RCodeRefused: + // treat REFUSED as a soft error so other resolvers in the race can respond + f.logf("sendUDP: response code indicating refusal: %d", rcode) + metricDNSFwdUDPErrorRefused.Add(1) + return nil, rcodeResponseError{dns.RCodeRefused, out} } // Set the truncated bit if buffer was truncated during read and the flag isn't already set @@ -951,10 +978,16 @@ func (f *forwarder) sendTCP(ctx context.Context, fq *forwardQuery, rr resolverAn rcode := getRCode(out) // don't forward transient errors back to the client when the server fails - if rcode == dns.RCodeServerFailure { + switch rcode { + case dns.RCodeServerFailure: f.logf("sendTCP: response code indicating server failure: %d", rcode) metricDNSFwdTCPErrorServer.Add(1) - return nil, errServerFailure + return nil, rcodeResponseError{dns.RCodeServerFailure, out} + case dns.RCodeRefused: + // treat REFUSED as a soft error so other resolvers in the race can respond + f.logf("sendTCP: response code indicating refusal: %d", rcode) + metricDNSFwdTCPErrorRefused.Add(1) + return nil, rcodeResponseError{dns.RCodeRefused, out} } // TODO(andrew): do we need to do this? @@ -1128,6 +1161,7 @@ func (f *forwarder) forwardWithDestChan(ctx context.Context, query packet, respo var firstErr error var numErr int + var sawNonRefused bool for { select { case v := <-resc: @@ -1147,32 +1181,56 @@ func (f *forwarder) forwardWithDestChan(ctx context.Context, query packet, respo if firstErr == nil { firstErr = err } + if !errors.Is(err, errRefused) { + sawNonRefused = true + } numErr++ if numErr == len(resolvers) { - if errors.Is(firstErr, errServerFailure) { - res, err := servfailResponse(query) - if err != nil { - f.logf("building servfail response: %v", err) + var res packet + if sawNonRefused { + // At least one server failed with SERVFAIL or a transport error + // (e.g. network failure, TxID mismatch, unsupported resolver type). + // All such errors map to SERVFAIL at the client level. 
+ // Prefer returning the upstream SERVFAIL bytes from firstErr if + // available; otherwise synthesize a SERVFAIL response. Note the + // rcode guard: firstErr may be a REFUSED rcodeResponseError if it + // arrived before the SERVFAIL that set sawNonRefused. + if rcodeErr, ok := errors.AsType[rcodeResponseError](firstErr); ok && rcodeErr.rcode == dns.RCodeServerFailure { + res = packet{rcodeErr.res, query.family, query.addr} + } else { + r, err := servfailResponse(query) + if err != nil { + f.logf("building servfail response: %v", err) + return firstErr + } + res = r + } + } else { + // !sawNonRefused means every error was an rcodeResponseError with rcode REFUSED, + // so firstErr is guaranteed to wrap one. + rcodeErr, ok := errors.AsType[rcodeResponseError](firstErr) + if !ok { + f.logf("unexpected: all errors were REFUSED but firstErr is not rcodeResponseError: %v", firstErr) return firstErr } - - select { - case <-ctx.Done(): - metricDNSFwdErrorContext.Add(1) - metricDNSFwdErrorContextGotError.Add(1) - var resolverAddrs []string - for _, rr := range resolvers { - resolverAddrs = append(resolverAddrs, rr.name.Addr) - } - if f.acceptDNS { - f.health.SetUnhealthy(dnsForwarderFailing, health.Args{health.ArgDNSServers: strings.Join(resolverAddrs, ",")}) - } - case responseChan <- res: - if f.verboseFwd { - f.logf("forwarder response(%d, %v, %d) = %d, %v", fq.txid, typ, len(domain), len(res.bs), firstErr) - } - return nil + res = packet{rcodeErr.res, query.family, query.addr} + } + select { + case <-ctx.Done(): + metricDNSFwdErrorContext.Add(1) + metricDNSFwdErrorContextGotError.Add(1) + var resolverAddrs []string + for _, rr := range resolvers { + resolverAddrs = append(resolverAddrs, rr.name.Addr) + } + if f.acceptDNS { + f.health.SetUnhealthy(dnsForwarderFailing, health.Args{health.ArgDNSServers: strings.Join(resolverAddrs, ",")}) + } + case responseChan <- res: + if f.verboseFwd { + f.logf("forwarder response(%d, %v, %d) = %d, %v", fq.txid, typ, len(domain), 
len(res.bs), firstErr) } + return nil } return firstErr } diff --git a/net/dns/resolver/forwarder_test.go b/net/dns/resolver/forwarder_test.go index 3ddb47433..670e8fe2b 100644 --- a/net/dns/resolver/forwarder_test.go +++ b/net/dns/resolver/forwarder_test.go @@ -1162,8 +1162,19 @@ func TestForwarderWithManyResolvers(t *testing.T) { }, }, { - name: "Refused", - responses: [][]byte{ // All upstream servers return different failures. + name: "AllRefused", + responses: [][]byte{ // All upstream servers return REFUSED. + makeTestResponse(t, domain, dns.RCodeRefused), + makeTestResponse(t, domain, dns.RCodeRefused), + makeTestResponse(t, domain, dns.RCodeRefused), + }, + wantResponses: [][]byte{ // When all refuse, return REFUSED to the client. + makeTestResponse(t, domain, dns.RCodeRefused), + }, + }, + { + name: "Refused+Success", + responses: [][]byte{ // Some upstream servers refuse, but one succeeds. makeTestResponse(t, domain, dns.RCodeRefused), makeTestResponse(t, domain, dns.RCodeRefused), makeTestResponse(t, domain, dns.RCodeRefused), @@ -1171,21 +1182,30 @@ func TestForwarderWithManyResolvers(t *testing.T) { makeTestResponse(t, domain, dns.RCodeRefused), makeTestResponse(t, domain, dns.RCodeSuccess, netip.MustParseAddr("127.0.0.1")), }, - wantResponses: [][]byte{ // Refused is not considered to be an error and can be forwarded. - makeTestResponse(t, domain, dns.RCodeRefused), + wantResponses: [][]byte{ // Refused is treated as a soft error; the Success response should win. makeTestResponse(t, domain, dns.RCodeSuccess, netip.MustParseAddr("127.0.0.1")), }, }, + { + name: "Refused+ServFail", + responses: [][]byte{ // Some servers refuse, at least one fails. + makeTestResponse(t, domain, dns.RCodeRefused), + makeTestResponse(t, domain, dns.RCodeServerFailure), + makeTestResponse(t, domain, dns.RCodeRefused), + }, + wantResponses: [][]byte{ // Any non-REFUSED failure triggers SERVFAIL regardless of arrival order. 
+ makeTestResponse(t, domain, dns.RCodeServerFailure), + }, + }, { name: "MixFail", - responses: [][]byte{ // All upstream servers return different failures. + responses: [][]byte{ // Upstream servers return different failures. makeTestResponse(t, domain, dns.RCodeServerFailure), makeTestResponse(t, domain, dns.RCodeNameError), makeTestResponse(t, domain, dns.RCodeRefused), }, - wantResponses: [][]byte{ // Both NXDomain and Refused can be forwarded. + wantResponses: [][]byte{ // SERVFAIL and REFUSED are soft errors; NXDOMAIN wins. makeTestResponse(t, domain, dns.RCodeNameError), - makeTestResponse(t, domain, dns.RCodeRefused), }, }, } @@ -1297,3 +1317,71 @@ func TestForwarderVerboseLogs(t *testing.T) { t.Errorf("expected forwarding log, got:\n%s", logStr) } } + +// TestForwarderHealthOnContextExpiry verifies that when all resolvers fail and +// the context expires before the response can be sent, the health tracker is +// set unhealthy if and only if acceptDNS is true. +func TestForwarderHealthOnContextExpiry(t *testing.T) { + const domain = "health-test.example.com." 
+ + tests := []struct { + name string + acceptDNS bool + wantUnhealthy bool + }{ + {"acceptDNS=true", true, true}, + {"acceptDNS=false", false, false}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + request := makeTestRequest(t, domain, dns.TypeA, 0) + logf := tstest.WhileTestRunningLogger(t) + bus := eventbustest.NewBus(t) + netMon, err := netmon.New(bus, logf) + if err != nil { + t.Fatal(err) + } + + var dialer tsdial.Dialer + dialer.SetNetMon(netMon) + dialer.SetBus(bus) + + ht := health.NewTracker(bus) + fwd := newForwarder(logf, netMon, nil, &dialer, ht, nil) + fwd.acceptDNS = tt.acceptDNS + + port1 := runDNSServer(t, nil, makeTestResponse(t, domain, dns.RCodeServerFailure), func(bool, []byte) {}) + port2 := runDNSServer(t, nil, makeTestResponse(t, domain, dns.RCodeServerFailure), func(bool, []byte) {}) + + resolvers := []resolverAndDelay{ + {name: &dnstype.Resolver{Addr: fmt.Sprintf("127.0.0.1:%d", port1)}}, + {name: &dnstype.Resolver{Addr: fmt.Sprintf("127.0.0.1:%d", port2)}}, + } + + rpkt := packet{ + bs: request, + family: "udp", + addr: netip.MustParseAddrPort("127.0.0.1:12345"), + } + + // Use an unbuffered responseChan so the send blocks, forcing the + // ctx.Done path and the SetUnhealthy call. + responseChan := make(chan packet) + + ctx, cancel := context.WithCancel(context.Background()) + // Cancel after DNS servers have had time to respond and their errors + // collected, leaving forwardWithDestChan blocked on responseChan. + go func() { + time.Sleep(50 * time.Millisecond) + cancel() + }() + + fwd.forwardWithDestChan(ctx, rpkt, responseChan, resolvers...) 
+ + if got := ht.IsUnhealthy(dnsForwarderFailing); got != tt.wantUnhealthy { + t.Errorf("IsUnhealthy = %v, want %v", got, tt.wantUnhealthy) + } + }) + } +} diff --git a/net/dns/resolver/tsdns.go b/net/dns/resolver/tsdns.go index 53f130a8a..01f0c8a63 100644 --- a/net/dns/resolver/tsdns.go +++ b/net/dns/resolver/tsdns.go @@ -1402,21 +1402,23 @@ var ( metricDNSFwdErrorType = clientmetric.NewCounter("dns_query_fwd_error_type") metricDNSFwdTruncated = clientmetric.NewCounter("dns_query_fwd_truncated") - metricDNSFwdUDP = clientmetric.NewCounter("dns_query_fwd_udp") // on entry - metricDNSFwdUDPWrote = clientmetric.NewCounter("dns_query_fwd_udp_wrote") // sent UDP packet - metricDNSFwdUDPErrorWrite = clientmetric.NewCounter("dns_query_fwd_udp_error_write") - metricDNSFwdUDPErrorServer = clientmetric.NewCounter("dns_query_fwd_udp_error_server") - metricDNSFwdUDPErrorTxID = clientmetric.NewCounter("dns_query_fwd_udp_error_txid") - metricDNSFwdUDPErrorRead = clientmetric.NewCounter("dns_query_fwd_udp_error_read") - metricDNSFwdUDPSuccess = clientmetric.NewCounter("dns_query_fwd_udp_success") - - metricDNSFwdTCP = clientmetric.NewCounter("dns_query_fwd_tcp") // on entry - metricDNSFwdTCPWrote = clientmetric.NewCounter("dns_query_fwd_tcp_wrote") // sent TCP packet - metricDNSFwdTCPErrorWrite = clientmetric.NewCounter("dns_query_fwd_tcp_error_write") - metricDNSFwdTCPErrorServer = clientmetric.NewCounter("dns_query_fwd_tcp_error_server") - metricDNSFwdTCPErrorTxID = clientmetric.NewCounter("dns_query_fwd_tcp_error_txid") - metricDNSFwdTCPErrorRead = clientmetric.NewCounter("dns_query_fwd_tcp_error_read") - metricDNSFwdTCPSuccess = clientmetric.NewCounter("dns_query_fwd_tcp_success") + metricDNSFwdUDP = clientmetric.NewCounter("dns_query_fwd_udp") // on entry + metricDNSFwdUDPWrote = clientmetric.NewCounter("dns_query_fwd_udp_wrote") // sent UDP packet + metricDNSFwdUDPErrorWrite = clientmetric.NewCounter("dns_query_fwd_udp_error_write") + metricDNSFwdUDPErrorServer = 
clientmetric.NewCounter("dns_query_fwd_udp_error_server") + metricDNSFwdUDPErrorRefused = clientmetric.NewCounter("dns_query_fwd_udp_error_refused") + metricDNSFwdUDPErrorTxID = clientmetric.NewCounter("dns_query_fwd_udp_error_txid") + metricDNSFwdUDPErrorRead = clientmetric.NewCounter("dns_query_fwd_udp_error_read") + metricDNSFwdUDPSuccess = clientmetric.NewCounter("dns_query_fwd_udp_success") + + metricDNSFwdTCP = clientmetric.NewCounter("dns_query_fwd_tcp") // on entry + metricDNSFwdTCPWrote = clientmetric.NewCounter("dns_query_fwd_tcp_wrote") // sent TCP packet + metricDNSFwdTCPErrorWrite = clientmetric.NewCounter("dns_query_fwd_tcp_error_write") + metricDNSFwdTCPErrorServer = clientmetric.NewCounter("dns_query_fwd_tcp_error_server") + metricDNSFwdTCPErrorRefused = clientmetric.NewCounter("dns_query_fwd_tcp_error_refused") + metricDNSFwdTCPErrorTxID = clientmetric.NewCounter("dns_query_fwd_tcp_error_txid") + metricDNSFwdTCPErrorRead = clientmetric.NewCounter("dns_query_fwd_tcp_error_read") + metricDNSFwdTCPSuccess = clientmetric.NewCounter("dns_query_fwd_tcp_success") metricDNSFwdDoH = clientmetric.NewCounter("dns_query_fwd_doh") metricDNSFwdDoHErrorStatus = clientmetric.NewCounter("dns_query_fwd_doh_error_status") diff --git a/net/dns/resolver/tsdns_test.go b/net/dns/resolver/tsdns_test.go index 8ee22dd13..381ceedb4 100644 --- a/net/dns/resolver/tsdns_test.go +++ b/net/dns/resolver/tsdns_test.go @@ -1557,15 +1557,13 @@ func TestServfail(t *testing.T) { t.Fatalf("err = %v, want nil", err) } + // The upstream server's SERVFAIL bytes are returned directly. 
wantPkt := []byte{ 0x00, 0x00, // transaction id: 0 - 0x84, 0x02, // flags: response, authoritative, error: servfail - 0x00, 0x01, // one question + 0x00, 0x02, // flags: error: servfail + 0x00, 0x00, // no questions (upstream sent a minimal response) 0x00, 0x00, // no answers 0x00, 0x00, 0x00, 0x00, // no authority or additional RRs - // Question: - 0x04, 0x74, 0x65, 0x73, 0x74, 0x04, 0x73, 0x69, 0x74, 0x65, 0x00, // name - 0x00, 0x01, 0x00, 0x01, // type A, class IN } if !bytes.Equal(pkt, wantPkt) {