diff --git a/net/dns/resolver/forwarder.go b/net/dns/resolver/forwarder.go
index ca1599589..ed7ff78f7 100644
--- a/net/dns/resolver/forwarder.go
+++ b/net/dns/resolver/forwarder.go
@@ -739,6 +739,38 @@ type truncatedResponseError struct {
 
 func (tr truncatedResponseError) Error() string { return "response truncated" }
 
+// rcodeResponseError is returned when an upstream DNS server responds with an
+// rcode that is treated as a soft error (currently REFUSED and SERVFAIL). The
+// response bytes are preserved so they can be returned to the client rather
+// than synthesizing a new response.
+type rcodeResponseError struct {
+	rcode dns.RCode
+	res   []byte
+}
+
+// Error returns the message of the sentinel error that corresponds to r.rcode.
+// Unwrap returns nil for an rcode without a sentinel, so that case is guarded
+// here with a fixed message instead of calling Error on a nil error.
+func (r rcodeResponseError) Error() string {
+	if err := r.Unwrap(); err != nil {
+		return err.Error()
+	}
+	return "response code indicates unexpected error"
+}
+
+// Unwrap maps r.rcode to errRefused or errServerFailure so that errors.Is can
+// match those sentinels. It returns nil for any other rcode.
+func (r rcodeResponseError) Unwrap() error {
+	switch r.rcode {
+	case dns.RCodeRefused:
+		return errRefused
+	case dns.RCodeServerFailure:
+		return errServerFailure
+	}
+	return nil
+}
+
+var errRefused = errors.New("response code indicates refusal")
 var errServerFailure = errors.New("response code indicates server issue")
 var errTxIDMismatch = errors.New("txid doesn't match")
 
@@ -812,10 +844,16 @@ func (f *forwarder) sendUDP(ctx context.Context, fq *forwardQuery, rr resolverAn
 
 	rcode := getRCode(out)
 	// don't forward transient errors back to the client when the server fails
-	if rcode == dns.RCodeServerFailure {
-		f.logf("recv: response code indicating server failure: %d", rcode)
+	switch rcode {
+	case dns.RCodeServerFailure:
+		f.logf("sendUDP: response code indicating server failure: %d", rcode)
 		metricDNSFwdUDPErrorServer.Add(1)
-		return nil, errServerFailure
+		return nil, rcodeResponseError{dns.RCodeServerFailure, out}
+	case dns.RCodeRefused:
+		// treat REFUSED as a soft error so other resolvers in the race can respond
+		f.logf("sendUDP: response code indicating refusal: %d", rcode)
+		metricDNSFwdUDPErrorRefused.Add(1)
+		return nil, rcodeResponseError{dns.RCodeRefused, out}
 	}
 
 	// Set the truncated bit if buffer was truncated during
read and the flag isn't already set @@ -951,10 +978,16 @@ func (f *forwarder) sendTCP(ctx context.Context, fq *forwardQuery, rr resolverAn rcode := getRCode(out) // don't forward transient errors back to the client when the server fails - if rcode == dns.RCodeServerFailure { + switch rcode { + case dns.RCodeServerFailure: f.logf("sendTCP: response code indicating server failure: %d", rcode) metricDNSFwdTCPErrorServer.Add(1) - return nil, errServerFailure + return nil, rcodeResponseError{dns.RCodeServerFailure, out} + case dns.RCodeRefused: + // treat REFUSED as a soft error so other resolvers in the race can respond + f.logf("sendTCP: response code indicating refusal: %d", rcode) + metricDNSFwdTCPErrorRefused.Add(1) + return nil, rcodeResponseError{dns.RCodeRefused, out} } // TODO(andrew): do we need to do this? @@ -1128,6 +1161,7 @@ func (f *forwarder) forwardWithDestChan(ctx context.Context, query packet, respo var firstErr error var numErr int + var sawNonRefused bool for { select { case v := <-resc: @@ -1147,32 +1181,56 @@ func (f *forwarder) forwardWithDestChan(ctx context.Context, query packet, respo if firstErr == nil { firstErr = err } + if !errors.Is(err, errRefused) { + sawNonRefused = true + } numErr++ if numErr == len(resolvers) { - if errors.Is(firstErr, errServerFailure) { - res, err := servfailResponse(query) - if err != nil { - f.logf("building servfail response: %v", err) + var res packet + if sawNonRefused { + // At least one server failed with SERVFAIL or a transport error + // (e.g. network failure, TxID mismatch, unsupported resolver type). + // All such errors map to SERVFAIL at the client level. + // Prefer returning the upstream SERVFAIL bytes from firstErr if + // available; otherwise synthesize a SERVFAIL response. Note the + // rcode guard: firstErr may be a REFUSED rcodeResponseError if it + // arrived before the SERVFAIL that set sawNonRefused. 
+ if rcodeErr, ok := errors.AsType[rcodeResponseError](firstErr); ok && rcodeErr.rcode == dns.RCodeServerFailure { + res = packet{rcodeErr.res, query.family, query.addr} + } else { + r, err := servfailResponse(query) + if err != nil { + f.logf("building servfail response: %v", err) + return firstErr + } + res = r + } + } else { + // !sawNonRefused means every error was an rcodeResponseError with rcode REFUSED, + // so firstErr is guaranteed to wrap one. + rcodeErr, ok := errors.AsType[rcodeResponseError](firstErr) + if !ok { + f.logf("unexpected: all errors were REFUSED but firstErr is not rcodeResponseError: %v", firstErr) return firstErr } - - select { - case <-ctx.Done(): - metricDNSFwdErrorContext.Add(1) - metricDNSFwdErrorContextGotError.Add(1) - var resolverAddrs []string - for _, rr := range resolvers { - resolverAddrs = append(resolverAddrs, rr.name.Addr) - } - if f.acceptDNS { - f.health.SetUnhealthy(dnsForwarderFailing, health.Args{health.ArgDNSServers: strings.Join(resolverAddrs, ",")}) - } - case responseChan <- res: - if f.verboseFwd { - f.logf("forwarder response(%d, %v, %d) = %d, %v", fq.txid, typ, len(domain), len(res.bs), firstErr) - } - return nil + res = packet{rcodeErr.res, query.family, query.addr} + } + select { + case <-ctx.Done(): + metricDNSFwdErrorContext.Add(1) + metricDNSFwdErrorContextGotError.Add(1) + var resolverAddrs []string + for _, rr := range resolvers { + resolverAddrs = append(resolverAddrs, rr.name.Addr) + } + if f.acceptDNS { + f.health.SetUnhealthy(dnsForwarderFailing, health.Args{health.ArgDNSServers: strings.Join(resolverAddrs, ",")}) + } + case responseChan <- res: + if f.verboseFwd { + f.logf("forwarder response(%d, %v, %d) = %d, %v", fq.txid, typ, len(domain), len(res.bs), firstErr) } + return nil } return firstErr } diff --git a/net/dns/resolver/forwarder_test.go b/net/dns/resolver/forwarder_test.go index 3ddb47433..670e8fe2b 100644 --- a/net/dns/resolver/forwarder_test.go +++ b/net/dns/resolver/forwarder_test.go @@ 
-1162,8 +1162,19 @@ func TestForwarderWithManyResolvers(t *testing.T) { }, }, { - name: "Refused", - responses: [][]byte{ // All upstream servers return different failures. + name: "AllRefused", + responses: [][]byte{ // All upstream servers return REFUSED. + makeTestResponse(t, domain, dns.RCodeRefused), + makeTestResponse(t, domain, dns.RCodeRefused), + makeTestResponse(t, domain, dns.RCodeRefused), + }, + wantResponses: [][]byte{ // When all refuse, return REFUSED to the client. + makeTestResponse(t, domain, dns.RCodeRefused), + }, + }, + { + name: "Refused+Success", + responses: [][]byte{ // Some upstream servers refuse, but one succeeds. makeTestResponse(t, domain, dns.RCodeRefused), makeTestResponse(t, domain, dns.RCodeRefused), makeTestResponse(t, domain, dns.RCodeRefused), @@ -1171,21 +1182,30 @@ func TestForwarderWithManyResolvers(t *testing.T) { makeTestResponse(t, domain, dns.RCodeRefused), makeTestResponse(t, domain, dns.RCodeSuccess, netip.MustParseAddr("127.0.0.1")), }, - wantResponses: [][]byte{ // Refused is not considered to be an error and can be forwarded. - makeTestResponse(t, domain, dns.RCodeRefused), + wantResponses: [][]byte{ // Refused is treated as a soft error; the Success response should win. makeTestResponse(t, domain, dns.RCodeSuccess, netip.MustParseAddr("127.0.0.1")), }, }, + { + name: "Refused+ServFail", + responses: [][]byte{ // Some servers refuse, at least one fails. + makeTestResponse(t, domain, dns.RCodeRefused), + makeTestResponse(t, domain, dns.RCodeServerFailure), + makeTestResponse(t, domain, dns.RCodeRefused), + }, + wantResponses: [][]byte{ // Any non-REFUSED failure triggers SERVFAIL regardless of arrival order. + makeTestResponse(t, domain, dns.RCodeServerFailure), + }, + }, { name: "MixFail", - responses: [][]byte{ // All upstream servers return different failures. + responses: [][]byte{ // Upstream servers return different failures. 
makeTestResponse(t, domain, dns.RCodeServerFailure), makeTestResponse(t, domain, dns.RCodeNameError), makeTestResponse(t, domain, dns.RCodeRefused), }, - wantResponses: [][]byte{ // Both NXDomain and Refused can be forwarded. + wantResponses: [][]byte{ // SERVFAIL and REFUSED are soft errors; NXDOMAIN wins. makeTestResponse(t, domain, dns.RCodeNameError), - makeTestResponse(t, domain, dns.RCodeRefused), }, }, } @@ -1297,3 +1317,71 @@ func TestForwarderVerboseLogs(t *testing.T) { t.Errorf("expected forwarding log, got:\n%s", logStr) } } + +// TestForwarderHealthOnContextExpiry verifies that when all resolvers fail and +// the context expires before the response can be sent, the health tracker is +// set unhealthy if and only if acceptDNS is true. +func TestForwarderHealthOnContextExpiry(t *testing.T) { + const domain = "health-test.example.com." + + tests := []struct { + name string + acceptDNS bool + wantUnhealthy bool + }{ + {"acceptDNS=true", true, true}, + {"acceptDNS=false", false, false}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + request := makeTestRequest(t, domain, dns.TypeA, 0) + logf := tstest.WhileTestRunningLogger(t) + bus := eventbustest.NewBus(t) + netMon, err := netmon.New(bus, logf) + if err != nil { + t.Fatal(err) + } + + var dialer tsdial.Dialer + dialer.SetNetMon(netMon) + dialer.SetBus(bus) + + ht := health.NewTracker(bus) + fwd := newForwarder(logf, netMon, nil, &dialer, ht, nil) + fwd.acceptDNS = tt.acceptDNS + + port1 := runDNSServer(t, nil, makeTestResponse(t, domain, dns.RCodeServerFailure), func(bool, []byte) {}) + port2 := runDNSServer(t, nil, makeTestResponse(t, domain, dns.RCodeServerFailure), func(bool, []byte) {}) + + resolvers := []resolverAndDelay{ + {name: &dnstype.Resolver{Addr: fmt.Sprintf("127.0.0.1:%d", port1)}}, + {name: &dnstype.Resolver{Addr: fmt.Sprintf("127.0.0.1:%d", port2)}}, + } + + rpkt := packet{ + bs: request, + family: "udp", + addr: netip.MustParseAddrPort("127.0.0.1:12345"), + } + + 
// Use an unbuffered responseChan so the send blocks, forcing the + // ctx.Done path and the SetUnhealthy call. + responseChan := make(chan packet) + + ctx, cancel := context.WithCancel(context.Background()) + // Cancel after DNS servers have had time to respond and their errors + // collected, leaving forwardWithDestChan blocked on responseChan. + go func() { + time.Sleep(50 * time.Millisecond) + cancel() + }() + + fwd.forwardWithDestChan(ctx, rpkt, responseChan, resolvers...) + + if got := ht.IsUnhealthy(dnsForwarderFailing); got != tt.wantUnhealthy { + t.Errorf("IsUnhealthy = %v, want %v", got, tt.wantUnhealthy) + } + }) + } +} diff --git a/net/dns/resolver/tsdns.go b/net/dns/resolver/tsdns.go index 53f130a8a..01f0c8a63 100644 --- a/net/dns/resolver/tsdns.go +++ b/net/dns/resolver/tsdns.go @@ -1402,21 +1402,23 @@ var ( metricDNSFwdErrorType = clientmetric.NewCounter("dns_query_fwd_error_type") metricDNSFwdTruncated = clientmetric.NewCounter("dns_query_fwd_truncated") - metricDNSFwdUDP = clientmetric.NewCounter("dns_query_fwd_udp") // on entry - metricDNSFwdUDPWrote = clientmetric.NewCounter("dns_query_fwd_udp_wrote") // sent UDP packet - metricDNSFwdUDPErrorWrite = clientmetric.NewCounter("dns_query_fwd_udp_error_write") - metricDNSFwdUDPErrorServer = clientmetric.NewCounter("dns_query_fwd_udp_error_server") - metricDNSFwdUDPErrorTxID = clientmetric.NewCounter("dns_query_fwd_udp_error_txid") - metricDNSFwdUDPErrorRead = clientmetric.NewCounter("dns_query_fwd_udp_error_read") - metricDNSFwdUDPSuccess = clientmetric.NewCounter("dns_query_fwd_udp_success") - - metricDNSFwdTCP = clientmetric.NewCounter("dns_query_fwd_tcp") // on entry - metricDNSFwdTCPWrote = clientmetric.NewCounter("dns_query_fwd_tcp_wrote") // sent TCP packet - metricDNSFwdTCPErrorWrite = clientmetric.NewCounter("dns_query_fwd_tcp_error_write") - metricDNSFwdTCPErrorServer = clientmetric.NewCounter("dns_query_fwd_tcp_error_server") - metricDNSFwdTCPErrorTxID = 
clientmetric.NewCounter("dns_query_fwd_tcp_error_txid") - metricDNSFwdTCPErrorRead = clientmetric.NewCounter("dns_query_fwd_tcp_error_read") - metricDNSFwdTCPSuccess = clientmetric.NewCounter("dns_query_fwd_tcp_success") + metricDNSFwdUDP = clientmetric.NewCounter("dns_query_fwd_udp") // on entry + metricDNSFwdUDPWrote = clientmetric.NewCounter("dns_query_fwd_udp_wrote") // sent UDP packet + metricDNSFwdUDPErrorWrite = clientmetric.NewCounter("dns_query_fwd_udp_error_write") + metricDNSFwdUDPErrorServer = clientmetric.NewCounter("dns_query_fwd_udp_error_server") + metricDNSFwdUDPErrorRefused = clientmetric.NewCounter("dns_query_fwd_udp_error_refused") + metricDNSFwdUDPErrorTxID = clientmetric.NewCounter("dns_query_fwd_udp_error_txid") + metricDNSFwdUDPErrorRead = clientmetric.NewCounter("dns_query_fwd_udp_error_read") + metricDNSFwdUDPSuccess = clientmetric.NewCounter("dns_query_fwd_udp_success") + + metricDNSFwdTCP = clientmetric.NewCounter("dns_query_fwd_tcp") // on entry + metricDNSFwdTCPWrote = clientmetric.NewCounter("dns_query_fwd_tcp_wrote") // sent TCP packet + metricDNSFwdTCPErrorWrite = clientmetric.NewCounter("dns_query_fwd_tcp_error_write") + metricDNSFwdTCPErrorServer = clientmetric.NewCounter("dns_query_fwd_tcp_error_server") + metricDNSFwdTCPErrorRefused = clientmetric.NewCounter("dns_query_fwd_tcp_error_refused") + metricDNSFwdTCPErrorTxID = clientmetric.NewCounter("dns_query_fwd_tcp_error_txid") + metricDNSFwdTCPErrorRead = clientmetric.NewCounter("dns_query_fwd_tcp_error_read") + metricDNSFwdTCPSuccess = clientmetric.NewCounter("dns_query_fwd_tcp_success") metricDNSFwdDoH = clientmetric.NewCounter("dns_query_fwd_doh") metricDNSFwdDoHErrorStatus = clientmetric.NewCounter("dns_query_fwd_doh_error_status") diff --git a/net/dns/resolver/tsdns_test.go b/net/dns/resolver/tsdns_test.go index 8ee22dd13..381ceedb4 100644 --- a/net/dns/resolver/tsdns_test.go +++ b/net/dns/resolver/tsdns_test.go @@ -1557,15 +1557,13 @@ func TestServfail(t *testing.T) { 
t.Fatalf("err = %v, want nil", err) } + // The upstream server's SERVFAIL bytes are returned directly. wantPkt := []byte{ 0x00, 0x00, // transaction id: 0 - 0x84, 0x02, // flags: response, authoritative, error: servfail - 0x00, 0x01, // one question + 0x00, 0x02, // flags: error: servfail + 0x00, 0x00, // no questions (upstream sent a minimal response) 0x00, 0x00, // no answers 0x00, 0x00, 0x00, 0x00, // no authority or additional RRs - // Question: - 0x04, 0x74, 0x65, 0x73, 0x74, 0x04, 0x73, 0x69, 0x74, 0x65, 0x00, // name - 0x00, 0x01, 0x00, 0x01, // type A, class IN } if !bytes.Equal(pkt, wantPkt) {