|
|
|
|
@ -9,41 +9,30 @@ import ( |
|
|
|
|
"context" |
|
|
|
|
"encoding/binary" |
|
|
|
|
"errors" |
|
|
|
|
"fmt" |
|
|
|
|
"hash/crc32" |
|
|
|
|
"io" |
|
|
|
|
"math/rand" |
|
|
|
|
"net" |
|
|
|
|
"sync" |
|
|
|
|
"syscall" |
|
|
|
|
"time" |
|
|
|
|
|
|
|
|
|
dns "golang.org/x/net/dns/dnsmessage" |
|
|
|
|
"inet.af/netaddr" |
|
|
|
|
"tailscale.com/logtail/backoff" |
|
|
|
|
"tailscale.com/types/logger" |
|
|
|
|
"tailscale.com/util/dnsname" |
|
|
|
|
"tailscale.com/wgengine/monitor" |
|
|
|
|
) |
|
|
|
|
|
|
|
|
|
// headerBytes is the number of bytes in a DNS message header.
|
|
|
|
|
const headerBytes = 12 |
|
|
|
|
|
|
|
|
|
// connCount is the number of UDP connections to use for forwarding.
|
|
|
|
|
const connCount = 32 |
|
|
|
|
|
|
|
|
|
const ( |
|
|
|
|
// cleanupInterval is the interval between purged of timed-out entries from txMap.
|
|
|
|
|
cleanupInterval = 30 * time.Second |
|
|
|
|
// responseTimeout is the maximal amount of time to wait for a DNS response.
|
|
|
|
|
responseTimeout = 5 * time.Second |
|
|
|
|
) |
|
|
|
|
|
|
|
|
|
var errNoUpstreams = errors.New("upstream nameservers not set") |
|
|
|
|
|
|
|
|
|
type forwardingRecord struct { |
|
|
|
|
src netaddr.IPPort |
|
|
|
|
createdAt time.Time |
|
|
|
|
} |
|
|
|
|
|
|
|
|
|
// txid identifies a DNS transaction.
|
|
|
|
|
//
|
|
|
|
|
// As the standard DNS Request ID is only 16 bits, we extend it:
|
|
|
|
|
@ -100,178 +89,164 @@ func getTxID(packet []byte) txid { |
|
|
|
|
} |
|
|
|
|
|
|
|
|
|
type route struct { |
|
|
|
|
suffix dnsname.FQDN |
|
|
|
|
resolvers []netaddr.IPPort |
|
|
|
|
Suffix dnsname.FQDN |
|
|
|
|
Resolvers []netaddr.IPPort |
|
|
|
|
} |
|
|
|
|
|
|
|
|
|
// forwarder forwards DNS packets to a number of upstream nameservers.
|
|
|
|
|
type forwarder struct { |
|
|
|
|
logf logger.Logf |
|
|
|
|
logf logger.Logf |
|
|
|
|
linkMon *monitor.Mon |
|
|
|
|
linkSel ForwardLinkSelector |
|
|
|
|
|
|
|
|
|
ctx context.Context // good until Close
|
|
|
|
|
ctxCancel context.CancelFunc // closes ctx
|
|
|
|
|
|
|
|
|
|
// responses is a channel by which responses are returned.
|
|
|
|
|
responses chan packet |
|
|
|
|
// closed signals all goroutines to stop.
|
|
|
|
|
closed chan struct{} |
|
|
|
|
// wg signals when all goroutines have stopped.
|
|
|
|
|
wg sync.WaitGroup |
|
|
|
|
|
|
|
|
|
// conns are the UDP connections used for forwarding.
|
|
|
|
|
// A random one is selected for each request, regardless of the target upstream.
|
|
|
|
|
conns []*fwdConn |
|
|
|
|
|
|
|
|
|
mu sync.Mutex |
|
|
|
|
// routes are per-suffix resolvers to use.
|
|
|
|
|
routes []route // most specific routes first
|
|
|
|
|
txMap map[txid]forwardingRecord // txids to in-flight requests
|
|
|
|
|
|
|
|
|
|
mu sync.Mutex // guards following
|
|
|
|
|
|
|
|
|
|
// routes are per-suffix resolvers to use, with
|
|
|
|
|
// the most specific routes first.
|
|
|
|
|
routes []route |
|
|
|
|
} |
|
|
|
|
|
|
|
|
|
func init() { |
|
|
|
|
rand.Seed(time.Now().UnixNano()) |
|
|
|
|
} |
|
|
|
|
|
|
|
|
|
func newForwarder(logf logger.Logf, responses chan packet) *forwarder { |
|
|
|
|
ret := &forwarder{ |
|
|
|
|
func newForwarder(logf logger.Logf, responses chan packet, linkMon *monitor.Mon, linkSel ForwardLinkSelector) *forwarder { |
|
|
|
|
f := &forwarder{ |
|
|
|
|
logf: logger.WithPrefix(logf, "forward: "), |
|
|
|
|
linkMon: linkMon, |
|
|
|
|
linkSel: linkSel, |
|
|
|
|
responses: responses, |
|
|
|
|
closed: make(chan struct{}), |
|
|
|
|
conns: make([]*fwdConn, connCount), |
|
|
|
|
txMap: make(map[txid]forwardingRecord), |
|
|
|
|
} |
|
|
|
|
|
|
|
|
|
ret.wg.Add(connCount + 1) |
|
|
|
|
for idx := range ret.conns { |
|
|
|
|
ret.conns[idx] = newFwdConn(ret.logf, idx) |
|
|
|
|
go ret.recv(ret.conns[idx]) |
|
|
|
|
} |
|
|
|
|
go ret.cleanMap() |
|
|
|
|
|
|
|
|
|
return ret |
|
|
|
|
} |
|
|
|
|
|
|
|
|
|
func (f *forwarder) Close() { |
|
|
|
|
select { |
|
|
|
|
case <-f.closed: |
|
|
|
|
return |
|
|
|
|
default: |
|
|
|
|
// continue
|
|
|
|
|
} |
|
|
|
|
close(f.closed) |
|
|
|
|
|
|
|
|
|
for _, conn := range f.conns { |
|
|
|
|
conn.close() |
|
|
|
|
} |
|
|
|
|
|
|
|
|
|
f.wg.Wait() |
|
|
|
|
f.ctx, f.ctxCancel = context.WithCancel(context.Background()) |
|
|
|
|
return f |
|
|
|
|
} |
|
|
|
|
|
|
|
|
|
func (f *forwarder) rebindFromNetworkChange() { |
|
|
|
|
for _, c := range f.conns { |
|
|
|
|
c.mu.Lock() |
|
|
|
|
c.reconnectLocked() |
|
|
|
|
c.mu.Unlock() |
|
|
|
|
} |
|
|
|
|
func (f *forwarder) Close() error { |
|
|
|
|
f.ctxCancel() |
|
|
|
|
return nil |
|
|
|
|
} |
|
|
|
|
|
|
|
|
|
func (f *forwarder) setRoutes(routes []route) { |
|
|
|
|
f.mu.Lock() |
|
|
|
|
defer f.mu.Unlock() |
|
|
|
|
f.routes = routes |
|
|
|
|
f.mu.Unlock() |
|
|
|
|
} |
|
|
|
|
|
|
|
|
|
// send sends packet to dst. It is best effort.
|
|
|
|
|
func (f *forwarder) send(packet []byte, dst netaddr.IPPort) { |
|
|
|
|
connIdx := rand.Intn(connCount) |
|
|
|
|
conn := f.conns[connIdx] |
|
|
|
|
conn.send(packet, dst) |
|
|
|
|
} |
|
|
|
|
var stdNetPacketListener packetListener = new(net.ListenConfig) |
|
|
|
|
|
|
|
|
|
func (f *forwarder) recv(conn *fwdConn) { |
|
|
|
|
defer f.wg.Done() |
|
|
|
|
type packetListener interface { |
|
|
|
|
ListenPacket(ctx context.Context, network, address string) (net.PacketConn, error) |
|
|
|
|
} |
|
|
|
|
|
|
|
|
|
for { |
|
|
|
|
select { |
|
|
|
|
case <-f.closed: |
|
|
|
|
return |
|
|
|
|
default: |
|
|
|
|
} |
|
|
|
|
// The 1 extra byte is to detect packet truncation.
|
|
|
|
|
out := make([]byte, maxResponseBytes+1) |
|
|
|
|
n := conn.read(out) |
|
|
|
|
var truncated bool |
|
|
|
|
if n > maxResponseBytes { |
|
|
|
|
n = maxResponseBytes |
|
|
|
|
truncated = true |
|
|
|
|
} |
|
|
|
|
if n == 0 { |
|
|
|
|
continue |
|
|
|
|
} |
|
|
|
|
if n < headerBytes { |
|
|
|
|
f.logf("recv: packet too small (%d bytes)", n) |
|
|
|
|
} |
|
|
|
|
func (f *forwarder) packetListener(ip netaddr.IP) (packetListener, error) { |
|
|
|
|
if f.linkSel == nil || initListenConfig == nil { |
|
|
|
|
return stdNetPacketListener, nil |
|
|
|
|
} |
|
|
|
|
linkName := f.linkSel.PickLink(ip) |
|
|
|
|
if linkName == "" { |
|
|
|
|
return stdNetPacketListener, nil |
|
|
|
|
} |
|
|
|
|
lc := new(net.ListenConfig) |
|
|
|
|
if err := initListenConfig(lc, f.linkMon, linkName); err != nil { |
|
|
|
|
return nil, err |
|
|
|
|
} |
|
|
|
|
return lc, nil |
|
|
|
|
} |
|
|
|
|
|
|
|
|
|
out = out[:n] |
|
|
|
|
txid := getTxID(out) |
|
|
|
|
// send sends packet to dst. It is best effort.
|
|
|
|
|
//
|
|
|
|
|
// send expects the reply to have the same txid as txidOut.
|
|
|
|
|
//
|
|
|
|
|
// The provided closeOnCtxDone lets send register values to Close if
|
|
|
|
|
// the caller's ctx expires. This avoids send from allocating its own
|
|
|
|
|
// waiting goroutine to interrupt the ReadFrom, as memory is tight on
|
|
|
|
|
// iOS and we want the number of pending DNS lookups to be bursty
|
|
|
|
|
// without too much associated goroutine/memory cost.
|
|
|
|
|
func (f *forwarder) send(ctx context.Context, txidOut txid, closeOnCtxDone *closePool, packet []byte, dst netaddr.IPPort) ([]byte, error) { |
|
|
|
|
// TODO(bradfitz): if dst.IP is 8.8.8.8 or 8.8.4.4 or 1.1.1.1, etc, or
|
|
|
|
|
// something dynamically probed earlier to support DoH or DoT,
|
|
|
|
|
// do that here instead.
|
|
|
|
|
|
|
|
|
|
ln, err := f.packetListener(dst.IP()) |
|
|
|
|
if err != nil { |
|
|
|
|
return nil, err |
|
|
|
|
} |
|
|
|
|
conn, err := ln.ListenPacket(ctx, "udp", ":0") |
|
|
|
|
if err != nil { |
|
|
|
|
f.logf("ListenPacket failed: %v", err) |
|
|
|
|
return nil, err |
|
|
|
|
} |
|
|
|
|
defer conn.Close() |
|
|
|
|
|
|
|
|
|
if truncated { |
|
|
|
|
const dnsFlagTruncated = 0x200 |
|
|
|
|
flags := binary.BigEndian.Uint16(out[2:4]) |
|
|
|
|
flags |= dnsFlagTruncated |
|
|
|
|
binary.BigEndian.PutUint16(out[2:4], flags) |
|
|
|
|
closeOnCtxDone.Add(conn) |
|
|
|
|
defer closeOnCtxDone.Remove(conn) |
|
|
|
|
|
|
|
|
|
// TODO(#2067): Remove any incomplete records? RFC 1035 section 6.2
|
|
|
|
|
// states that truncation should head drop so that the authority
|
|
|
|
|
// section can be preserved if possible. However, the UDP read with
|
|
|
|
|
// a too-small buffer has already dropped the end, so that's the
|
|
|
|
|
// best we can do.
|
|
|
|
|
if _, err := conn.WriteTo(packet, dst.UDPAddr()); err != nil { |
|
|
|
|
if err := ctx.Err(); err != nil { |
|
|
|
|
return nil, err |
|
|
|
|
} |
|
|
|
|
return nil, err |
|
|
|
|
} |
|
|
|
|
|
|
|
|
|
f.mu.Lock() |
|
|
|
|
|
|
|
|
|
record, found := f.txMap[txid] |
|
|
|
|
// At most one nameserver will return a response:
|
|
|
|
|
// the first one to do so will delete txid from the map.
|
|
|
|
|
if !found { |
|
|
|
|
f.mu.Unlock() |
|
|
|
|
continue |
|
|
|
|
// The 1 extra byte is to detect packet truncation.
|
|
|
|
|
out := make([]byte, maxResponseBytes+1) |
|
|
|
|
n, _, err := conn.ReadFrom(out) |
|
|
|
|
if err != nil { |
|
|
|
|
if err := ctx.Err(); err != nil { |
|
|
|
|
return nil, err |
|
|
|
|
} |
|
|
|
|
delete(f.txMap, txid) |
|
|
|
|
|
|
|
|
|
f.mu.Unlock() |
|
|
|
|
|
|
|
|
|
pkt := packet{out, record.src} |
|
|
|
|
select { |
|
|
|
|
case <-f.closed: |
|
|
|
|
return |
|
|
|
|
case f.responses <- pkt: |
|
|
|
|
// continue
|
|
|
|
|
if packetWasTruncated(err) { |
|
|
|
|
err = nil |
|
|
|
|
} else { |
|
|
|
|
return nil, err |
|
|
|
|
} |
|
|
|
|
} |
|
|
|
|
} |
|
|
|
|
|
|
|
|
|
// cleanMap periodically deletes timed-out forwarding records from f.txMap to bound growth.
|
|
|
|
|
func (f *forwarder) cleanMap() { |
|
|
|
|
defer f.wg.Done() |
|
|
|
|
truncated := n > maxResponseBytes |
|
|
|
|
if truncated { |
|
|
|
|
n = maxResponseBytes |
|
|
|
|
} |
|
|
|
|
if n < headerBytes { |
|
|
|
|
f.logf("recv: packet too small (%d bytes)", n) |
|
|
|
|
} |
|
|
|
|
out = out[:n] |
|
|
|
|
txid := getTxID(out) |
|
|
|
|
if txid != txidOut { |
|
|
|
|
return nil, errors.New("txid doesn't match") |
|
|
|
|
} |
|
|
|
|
|
|
|
|
|
t := time.NewTicker(cleanupInterval) |
|
|
|
|
defer t.Stop() |
|
|
|
|
if truncated { |
|
|
|
|
const dnsFlagTruncated = 0x200 |
|
|
|
|
flags := binary.BigEndian.Uint16(out[2:4]) |
|
|
|
|
flags |= dnsFlagTruncated |
|
|
|
|
binary.BigEndian.PutUint16(out[2:4], flags) |
|
|
|
|
|
|
|
|
|
// TODO(#2067): Remove any incomplete records? RFC 1035 section 6.2
|
|
|
|
|
// states that truncation should head drop so that the authority
|
|
|
|
|
// section can be preserved if possible. However, the UDP read with
|
|
|
|
|
// a too-small buffer has already dropped the end, so that's the
|
|
|
|
|
// best we can do.
|
|
|
|
|
} |
|
|
|
|
|
|
|
|
|
var now time.Time |
|
|
|
|
for { |
|
|
|
|
select { |
|
|
|
|
case <-f.closed: |
|
|
|
|
return |
|
|
|
|
case now = <-t.C: |
|
|
|
|
// continue
|
|
|
|
|
} |
|
|
|
|
return out, nil |
|
|
|
|
} |
|
|
|
|
|
|
|
|
|
f.mu.Lock() |
|
|
|
|
for k, v := range f.txMap { |
|
|
|
|
if now.Sub(v.createdAt) > responseTimeout { |
|
|
|
|
delete(f.txMap, k) |
|
|
|
|
} |
|
|
|
|
// resolvers returns the resolvers to use for domain.
|
|
|
|
|
func (f *forwarder) resolvers(domain dnsname.FQDN) []netaddr.IPPort { |
|
|
|
|
f.mu.Lock() |
|
|
|
|
routes := f.routes |
|
|
|
|
f.mu.Unlock() |
|
|
|
|
for _, route := range routes { |
|
|
|
|
if route.Suffix == "." || route.Suffix.Contains(domain) { |
|
|
|
|
return route.Resolvers |
|
|
|
|
} |
|
|
|
|
f.mu.Unlock() |
|
|
|
|
} |
|
|
|
|
return nil |
|
|
|
|
} |
|
|
|
|
|
|
|
|
|
// forward forwards the query to all upstream nameservers and returns the first response.
|
|
|
|
|
@ -283,225 +258,60 @@ func (f *forwarder) forward(query packet) error { |
|
|
|
|
|
|
|
|
|
txid := getTxID(query.bs) |
|
|
|
|
|
|
|
|
|
f.mu.Lock() |
|
|
|
|
routes := f.routes |
|
|
|
|
f.mu.Unlock() |
|
|
|
|
|
|
|
|
|
var resolvers []netaddr.IPPort |
|
|
|
|
for _, route := range routes { |
|
|
|
|
if route.suffix != "." && !route.suffix.Contains(domain) { |
|
|
|
|
continue |
|
|
|
|
} |
|
|
|
|
resolvers = route.resolvers |
|
|
|
|
break |
|
|
|
|
} |
|
|
|
|
resolvers := f.resolvers(domain) |
|
|
|
|
if len(resolvers) == 0 { |
|
|
|
|
return errNoUpstreams |
|
|
|
|
} |
|
|
|
|
|
|
|
|
|
f.mu.Lock() |
|
|
|
|
f.txMap[txid] = forwardingRecord{ |
|
|
|
|
src: query.addr, |
|
|
|
|
createdAt: time.Now(), |
|
|
|
|
} |
|
|
|
|
f.mu.Unlock() |
|
|
|
|
|
|
|
|
|
// TODO(#2066): EDNS size clamping
|
|
|
|
|
|
|
|
|
|
for _, resolver := range resolvers { |
|
|
|
|
f.send(query.bs, resolver) |
|
|
|
|
} |
|
|
|
|
|
|
|
|
|
return nil |
|
|
|
|
} |
|
|
|
|
|
|
|
|
|
// A fwdConn manages a single connection used to forward DNS requests.
|
|
|
|
|
// Net link changes can cause a *net.UDPConn to become permanently unusable, particularly on macOS.
|
|
|
|
|
// fwdConn detects such situations and transparently creates new connections.
|
|
|
|
|
type fwdConn struct { |
|
|
|
|
// logf allows a fwdConn to log.
|
|
|
|
|
logf logger.Logf |
|
|
|
|
|
|
|
|
|
// change allows calls to read to block until a the network connection has been replaced.
|
|
|
|
|
change *sync.Cond |
|
|
|
|
|
|
|
|
|
// mu protects fields that follow it; it is also change's Locker.
|
|
|
|
|
mu sync.Mutex |
|
|
|
|
// closed tracks whether fwdConn has been permanently closed.
|
|
|
|
|
closed bool |
|
|
|
|
// conn is the current active connection.
|
|
|
|
|
conn net.PacketConn |
|
|
|
|
} |
|
|
|
|
|
|
|
|
|
func newFwdConn(logf logger.Logf, idx int) *fwdConn { |
|
|
|
|
c := new(fwdConn) |
|
|
|
|
c.logf = logger.WithPrefix(logf, fmt.Sprintf("fwdConn %d: ", idx)) |
|
|
|
|
c.change = sync.NewCond(&c.mu) |
|
|
|
|
// c.conn is created lazily in send
|
|
|
|
|
return c |
|
|
|
|
} |
|
|
|
|
|
|
|
|
|
// send sends packet to dst using c's connection.
|
|
|
|
|
// It is best effort. It is UDP, after all. Failures are logged.
|
|
|
|
|
func (c *fwdConn) send(packet []byte, dst netaddr.IPPort) { |
|
|
|
|
var b *backoff.Backoff // lazily initialized, since it is not needed in the common case
|
|
|
|
|
backOff := func(err error) { |
|
|
|
|
if b == nil { |
|
|
|
|
b = backoff.NewBackoff("dns-fwdConn-send", c.logf, 30*time.Second) |
|
|
|
|
} |
|
|
|
|
b.BackOff(context.Background(), err) |
|
|
|
|
} |
|
|
|
|
|
|
|
|
|
for { |
|
|
|
|
// Gather the current connection.
|
|
|
|
|
// We can't hold the lock while we call WriteTo.
|
|
|
|
|
c.mu.Lock() |
|
|
|
|
conn := c.conn |
|
|
|
|
closed := c.closed |
|
|
|
|
if closed { |
|
|
|
|
c.mu.Unlock() |
|
|
|
|
return |
|
|
|
|
} |
|
|
|
|
if conn == nil { |
|
|
|
|
c.reconnectLocked() |
|
|
|
|
c.mu.Unlock() |
|
|
|
|
continue |
|
|
|
|
} |
|
|
|
|
c.mu.Unlock() |
|
|
|
|
|
|
|
|
|
_, err := conn.WriteTo(packet, dst.UDPAddr()) |
|
|
|
|
if err == nil { |
|
|
|
|
// Success
|
|
|
|
|
return |
|
|
|
|
} |
|
|
|
|
if errors.Is(err, net.ErrClosed) { |
|
|
|
|
// We intentionally closed this connection.
|
|
|
|
|
// It has been replaced by a new connection. Try again.
|
|
|
|
|
continue |
|
|
|
|
} |
|
|
|
|
// Something else went wrong.
|
|
|
|
|
// We have three choices here: try again, give up, or create a new connection.
|
|
|
|
|
var opErr *net.OpError |
|
|
|
|
if !errors.As(err, &opErr) { |
|
|
|
|
// Weird. All errors from the net package should be *net.OpError. Bail.
|
|
|
|
|
c.logf("send: non-*net.OpErr %v (%T)", err, err) |
|
|
|
|
return |
|
|
|
|
} |
|
|
|
|
if opErr.Temporary() || opErr.Timeout() { |
|
|
|
|
// I doubt that either of these can happen (this is UDP),
|
|
|
|
|
// but go ahead and try again.
|
|
|
|
|
backOff(err) |
|
|
|
|
continue |
|
|
|
|
} |
|
|
|
|
if errors.Is(err, syscall.EHOSTUNREACH) { |
|
|
|
|
// "No route to host." The network stack is fine, but
|
|
|
|
|
// can't talk to this destination. Not much we can do
|
|
|
|
|
// about that, don't spam logs.
|
|
|
|
|
return |
|
|
|
|
} |
|
|
|
|
if networkIsDown(err) { |
|
|
|
|
// Fail.
|
|
|
|
|
c.logf("send: network is down") |
|
|
|
|
return |
|
|
|
|
} |
|
|
|
|
if networkIsUnreachable(err) { |
|
|
|
|
// This can be caused by a link change.
|
|
|
|
|
// Replace the existing connection with a new one.
|
|
|
|
|
c.mu.Lock() |
|
|
|
|
// It's possible that multiple senders discovered simultaneously
|
|
|
|
|
// that the network is unreachable. Avoid reconnecting multiple times:
|
|
|
|
|
// Only reconnect if the current connection is the one that we
|
|
|
|
|
// discovered to be problematic.
|
|
|
|
|
if c.conn == conn { |
|
|
|
|
backOff(err) |
|
|
|
|
c.reconnectLocked() |
|
|
|
|
closeOnCtxDone := new(closePool) |
|
|
|
|
defer closeOnCtxDone.Close() |
|
|
|
|
|
|
|
|
|
ctx, cancel := context.WithTimeout(f.ctx, responseTimeout) |
|
|
|
|
defer cancel() |
|
|
|
|
|
|
|
|
|
resc := make(chan []byte, 1) |
|
|
|
|
var ( |
|
|
|
|
mu sync.Mutex |
|
|
|
|
firstErr error |
|
|
|
|
) |
|
|
|
|
|
|
|
|
|
for _, ipp := range resolvers { |
|
|
|
|
go func(ipp netaddr.IPPort) { |
|
|
|
|
resb, err := f.send(ctx, txid, closeOnCtxDone, query.bs, ipp) |
|
|
|
|
if err != nil { |
|
|
|
|
mu.Lock() |
|
|
|
|
defer mu.Unlock() |
|
|
|
|
if firstErr == nil { |
|
|
|
|
firstErr = err |
|
|
|
|
} |
|
|
|
|
return |
|
|
|
|
} |
|
|
|
|
c.mu.Unlock() |
|
|
|
|
// Try again with our new network connection.
|
|
|
|
|
continue |
|
|
|
|
} |
|
|
|
|
// Unrecognized error. Fail.
|
|
|
|
|
c.logf("send: unrecognized error: %v", err) |
|
|
|
|
return |
|
|
|
|
select { |
|
|
|
|
case resc <- resb: |
|
|
|
|
default: |
|
|
|
|
} |
|
|
|
|
}(ipp) |
|
|
|
|
} |
|
|
|
|
} |
|
|
|
|
|
|
|
|
|
// read waits for a response from c's connection.
|
|
|
|
|
// It returns the number of bytes read, which may be 0
|
|
|
|
|
// in case of an error or a closed connection.
|
|
|
|
|
func (c *fwdConn) read(out []byte) int { |
|
|
|
|
for { |
|
|
|
|
// Gather the current connection.
|
|
|
|
|
// We can't hold the lock while we call ReadFrom.
|
|
|
|
|
c.mu.Lock() |
|
|
|
|
conn := c.conn |
|
|
|
|
closed := c.closed |
|
|
|
|
if closed { |
|
|
|
|
c.mu.Unlock() |
|
|
|
|
return 0 |
|
|
|
|
} |
|
|
|
|
if conn == nil { |
|
|
|
|
// There is no current connection.
|
|
|
|
|
// Wait for the connection to change, then try again.
|
|
|
|
|
c.change.Wait() |
|
|
|
|
c.mu.Unlock() |
|
|
|
|
continue |
|
|
|
|
} |
|
|
|
|
c.mu.Unlock() |
|
|
|
|
|
|
|
|
|
n, _, err := conn.ReadFrom(out) |
|
|
|
|
if err == nil || packetWasTruncated(err) { |
|
|
|
|
// Success.
|
|
|
|
|
return n |
|
|
|
|
select { |
|
|
|
|
case v := <-resc: |
|
|
|
|
select { |
|
|
|
|
case <-ctx.Done(): |
|
|
|
|
return ctx.Err() |
|
|
|
|
case f.responses <- packet{v, query.addr}: |
|
|
|
|
return nil |
|
|
|
|
} |
|
|
|
|
if errors.Is(err, net.ErrClosed) { |
|
|
|
|
// We intentionally closed this connection.
|
|
|
|
|
// It has been replaced by a new connection. Try again.
|
|
|
|
|
continue |
|
|
|
|
case <-ctx.Done(): |
|
|
|
|
mu.Lock() |
|
|
|
|
defer mu.Unlock() |
|
|
|
|
if firstErr != nil { |
|
|
|
|
return firstErr |
|
|
|
|
} |
|
|
|
|
|
|
|
|
|
c.logf("read: unrecognized error: %v", err) |
|
|
|
|
return 0 |
|
|
|
|
return ctx.Err() |
|
|
|
|
} |
|
|
|
|
} |
|
|
|
|
|
|
|
|
|
// reconnectLocked replaces the current connection with a new one.
|
|
|
|
|
// c.mu must be locked.
|
|
|
|
|
func (c *fwdConn) reconnectLocked() { |
|
|
|
|
c.closeConnLocked() |
|
|
|
|
// Make a new connection.
|
|
|
|
|
conn, err := net.ListenPacket("udp", "") |
|
|
|
|
if err != nil { |
|
|
|
|
c.logf("ListenPacket failed: %v", err) |
|
|
|
|
} else { |
|
|
|
|
c.conn = conn |
|
|
|
|
} |
|
|
|
|
// Broadcast that a new connection is available.
|
|
|
|
|
c.change.Broadcast() |
|
|
|
|
} |
|
|
|
|
|
|
|
|
|
// closeCurrentConn closes the current connection.
|
|
|
|
|
// c.mu must be locked.
|
|
|
|
|
func (c *fwdConn) closeConnLocked() { |
|
|
|
|
if c.conn == nil { |
|
|
|
|
return |
|
|
|
|
} |
|
|
|
|
c.conn.Close() // unblocks all readers/writers, they'll pick up the next connection.
|
|
|
|
|
c.conn = nil |
|
|
|
|
} |
|
|
|
|
|
|
|
|
|
// close permanently closes c.
|
|
|
|
|
func (c *fwdConn) close() { |
|
|
|
|
c.mu.Lock() |
|
|
|
|
defer c.mu.Unlock() |
|
|
|
|
if c.closed { |
|
|
|
|
return |
|
|
|
|
} |
|
|
|
|
c.closed = true |
|
|
|
|
c.closeConnLocked() |
|
|
|
|
// Unblock any remaining readers.
|
|
|
|
|
c.change.Broadcast() |
|
|
|
|
} |
|
|
|
|
var initListenConfig func(_ *net.ListenConfig, _ *monitor.Mon, tunName string) error |
|
|
|
|
|
|
|
|
|
// nameFromQuery extracts the normalized query name from bs.
|
|
|
|
|
func nameFromQuery(bs []byte) (dnsname.FQDN, error) { |
|
|
|
|
@ -523,3 +333,48 @@ func nameFromQuery(bs []byte) (dnsname.FQDN, error) { |
|
|
|
|
n := q.Name.Data[:q.Name.Length] |
|
|
|
|
return dnsname.ToFQDN(rawNameToLower(n)) |
|
|
|
|
} |
|
|
|
|
|
|
|
|
|
// closePool is a dynamic set of io.Closers to close as a group.
|
|
|
|
|
// It's intended to be Closed at most once.
|
|
|
|
|
//
|
|
|
|
|
// The zero value is ready for use.
|
|
|
|
|
type closePool struct { |
|
|
|
|
mu sync.Mutex |
|
|
|
|
m map[io.Closer]bool |
|
|
|
|
closed bool |
|
|
|
|
} |
|
|
|
|
|
|
|
|
|
func (p *closePool) Add(c io.Closer) { |
|
|
|
|
p.mu.Lock() |
|
|
|
|
defer p.mu.Unlock() |
|
|
|
|
if p.closed { |
|
|
|
|
c.Close() |
|
|
|
|
return |
|
|
|
|
} |
|
|
|
|
if p.m == nil { |
|
|
|
|
p.m = map[io.Closer]bool{} |
|
|
|
|
} |
|
|
|
|
p.m[c] = true |
|
|
|
|
} |
|
|
|
|
|
|
|
|
|
func (p *closePool) Remove(c io.Closer) { |
|
|
|
|
p.mu.Lock() |
|
|
|
|
defer p.mu.Unlock() |
|
|
|
|
if p.closed { |
|
|
|
|
return |
|
|
|
|
} |
|
|
|
|
delete(p.m, c) |
|
|
|
|
} |
|
|
|
|
|
|
|
|
|
func (p *closePool) Close() error { |
|
|
|
|
p.mu.Lock() |
|
|
|
|
defer p.mu.Unlock() |
|
|
|
|
if p.closed { |
|
|
|
|
return nil |
|
|
|
|
} |
|
|
|
|
p.closed = true |
|
|
|
|
for c := range p.m { |
|
|
|
|
c.Close() |
|
|
|
|
} |
|
|
|
|
return nil |
|
|
|
|
} |
|
|
|
|
|