diff --git a/net/dnscache/dnscache.go b/net/dnscache/dnscache.go index 835158de7..cf1ef5817 100644 --- a/net/dnscache/dnscache.go +++ b/net/dnscache/dnscache.go @@ -279,8 +279,9 @@ type DialContextFunc func(ctx context.Context, network, address string) (net.Con // Dialer returns a wrapped DialContext func that uses the provided dnsCache. func Dialer(fwd DialContextFunc, dnsCache *Resolver) DialContextFunc { d := &dialer{ - fwd: fwd, - dnsCache: dnsCache, + fwd: fwd, + dnsCache: dnsCache, + pastConnect: map[netaddr.IP]time.Time{}, } return d.DialContext } @@ -289,6 +290,9 @@ func Dialer(fwd DialContextFunc, dnsCache *Resolver) DialContextFunc { type dialer struct { fwd DialContextFunc dnsCache *Resolver + + mu sync.Mutex + pastConnect map[netaddr.IP]time.Time } func (d *dialer) DialContext(ctx context.Context, network, address string) (retConn net.Conn, ret error) { @@ -306,8 +310,9 @@ func (d *dialer) DialContext(ctx context.Context, network, address string) (retC port: port, } defer func() { - // On any failure, assume our DNS is wrong and try our fallback, if any. - if ret == nil || d.dnsCache.LookupIPFallback == nil { + // On failure, consider that our DNS might be wrong and ask the DNS fallback mechanism for + // some other IPs to try. + if ret == nil || d.dnsCache.LookupIPFallback == nil || dc.dnsWasTrustworthy() { return } ips, err := d.dnsCache.LookupIPFallback(ctx, host) @@ -328,17 +333,23 @@ func (d *dialer) DialContext(ctx context.Context, network, address string) (retC } i4s := v4addrs(allIPs) if len(i4s) < 2 { - dst := net.JoinHostPort(ip.String(), port) if debug { - log.Printf("dnscache: dialing %s, %s for %s", network, dst, address) + log.Printf("dnscache: dialing %s, %s for %s", network, ip, address) } - c, err := d.fwd(ctx, network, dst) - if err == nil || ctx.Err() != nil || ip6 == nil { + ipNA, ok := netaddr.FromStdIP(ip) + if !ok { + return nil, fmt.Errorf("invalid IP %q", ip) + } + c, err := dc.dialOne(ctx, ipNA) + if err == nil || ctx.Err() != nil { return c, err } - // Fall back to trying IPv6. - dst = net.JoinHostPort(ip6.String(), port) - return d.fwd(ctx, network, dst) + // Fall back to trying IPv6, if any. + ip6NA, ok := netaddr.FromStdIP(ip6) + if !ok { + return nil, err + } + return dc.dialOne(ctx, ip6NA) } // Multiple IPv4 candidates, and 0+ IPv6. @@ -350,6 +361,77 @@ func (d *dialer) DialContext(ctx context.Context, network, address string) (retC type dialCall struct { d *dialer network, address, host, port string + + mu sync.Mutex // lock ordering: dialer.mu, then dialCall.mu + fails map[netaddr.IP]error // set of IPs that failed to dial thus far +} + +// dnsWasTrustworthy reports whether we think the IP address(es) we +// tried (and failed) to dial were probably the correct IPs. Currently +// the heuristic is whether they ever worked previously. +func (dc *dialCall) dnsWasTrustworthy() bool { + dc.d.mu.Lock() + defer dc.d.mu.Unlock() + dc.mu.Lock() + defer dc.mu.Unlock() + + if len(dc.fails) == 0 { + // No information. + return false + } + + // If any of the IPs we failed to dial worked previously in + // this dialer, assume the DNS is fine. + for ip := range dc.fails { + if _, ok := dc.d.pastConnect[ip]; ok { + return true + } + } + return false +} + +func (dc *dialCall) dialOne(ctx context.Context, ip netaddr.IP) (net.Conn, error) { + c, err := dc.d.fwd(ctx, dc.network, net.JoinHostPort(ip.String(), dc.port)) + dc.noteDialResult(ip, err) + return c, err +} + +// noteDialResult records that a dial to ip either succeeded or +// failed. +func (dc *dialCall) noteDialResult(ip netaddr.IP, err error) { + if err == nil { + d := dc.d + d.mu.Lock() + defer d.mu.Unlock() + d.pastConnect[ip] = time.Now() + return + } + dc.mu.Lock() + defer dc.mu.Unlock() + if dc.fails == nil { + dc.fails = map[netaddr.IP]error{} + } + dc.fails[ip] = err +} + +// uniqueIPs returns a possibly-mutated subslice of ips, filtering out +// dups and ones that have already failed previously. +func (dc *dialCall) uniqueIPs(ips []netaddr.IP) (ret []netaddr.IP) { + dc.mu.Lock() + defer dc.mu.Unlock() + seen := map[netaddr.IP]bool{} + ret = ips[:0] + for _, ip := range ips { + if seen[ip] { + continue + } + seen[ip] = true + if dc.fails[ip] != nil { + continue + } + ret = append(ret, ip) + } + return ret } // fallbackDelay is how long to wait between trying subsequent @@ -360,11 +442,6 @@ const fallbackDelay = 300 * time.Millisecond // raceDial tries to dial port on each ip in ips, starting a new race // dial every fallbackDelay apart, returning whichever completes first. func (dc *dialCall) raceDial(ctx context.Context, ips []netaddr.IP) (net.Conn, error) { - var ( - fwd = dc.d.fwd - network = dc.network - port = dc.port - ) ctx, cancel := context.WithCancel(ctx) defer cancel() @@ -375,6 +452,14 @@ func (dc *dialCall) raceDial(ctx context.Context, ips []netaddr.IP) (net.Conn, e resc := make(chan res) // must be unbuffered failBoost := make(chan struct{}) // best effort send on dial failure + // Remove IPs that we tried & failed to dial previously + // (such as when we're being called after a dnsfallback lookup and get + // the same results) + ips = dc.uniqueIPs(ips) + if len(ips) == 0 { + return nil, errors.New("no IPs") + } + go func() { for i, ip := range ips { if i != 0 { @@ -389,7 +474,7 @@ func (dc *dialCall) raceDial(ctx context.Context, ips []netaddr.IP) (net.Conn, e } } go func(ip netaddr.IP) { - c, err := fwd(ctx, network, net.JoinHostPort(ip.String(), port)) + c, err := dc.dialOne(ctx, ip) if err != nil { // Best effort wake-up a pending dial. // e.g. IPv4 dials failing quickly on an IPv6-only system. diff --git a/net/dnscache/dnscache_test.go b/net/dnscache/dnscache_test.go index 096049ccf..10cfd5398 100644 --- a/net/dnscache/dnscache_test.go +++ b/net/dnscache/dnscache_test.go @@ -6,10 +6,14 @@ package dnscache import ( "context" + "errors" "flag" "net" + "reflect" "testing" "time" + + "inet.af/netaddr" ) var dialTest = flag.String("dial-test", "", "if non-empty, addr:port to test dial") @@ -31,3 +35,78 @@ func TestDialer(t *testing.T) { t.Logf("dialed in %v", time.Since(t0)) c.Close() } + +func TestDialCall_DNSWasTrustworthy(t *testing.T) { + type step struct { + ip netaddr.IP // IP we pretended to dial + err error // the dial error or nil for success + } + mustIP := netaddr.MustParseIP + errFail := errors.New("some connect failure") + tests := []struct { + name string + steps []step + want bool + }{ + { + name: "no-info", + want: false, + }, + { + name: "previous-dial", + steps: []step{ + {mustIP("2003::1"), nil}, + {mustIP("2003::1"), errFail}, + }, + want: true, + }, + { + name: "no-previous-dial", + steps: []step{ + {mustIP("2003::1"), errFail}, + }, + want: false, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + d := &dialer{ + pastConnect: map[netaddr.IP]time.Time{}, + } + dc := &dialCall{ + d: d, + } + for _, st := range tt.steps { + dc.noteDialResult(st.ip, st.err) + } + got := dc.dnsWasTrustworthy() + if got != tt.want { + t.Errorf("got %v; want %v", got, tt.want) + } + }) + } +} + +func TestDialCall_uniqueIPs(t *testing.T) { + dc := &dialCall{} + mustIP := netaddr.MustParseIP + errFail := errors.New("some connect failure") + dc.noteDialResult(mustIP("2003::1"), errFail) + dc.noteDialResult(mustIP("2003::2"), errFail) + got := dc.uniqueIPs([]netaddr.IP{ + mustIP("2003::1"), + mustIP("2003::2"), + mustIP("2003::2"), + mustIP("2003::3"), + mustIP("2003::3"), + mustIP("2003::4"), + mustIP("2003::4"), + }) + want := []netaddr.IP{ + mustIP("2003::3"), + mustIP("2003::4"), + } + if !reflect.DeepEqual(got, want) { + t.Errorf("got %v; want %v", got, want) + } +}