|
|
|
|
@ -22,11 +22,6 @@ import ( |
|
|
|
|
// headerBytes is the number of bytes in a DNS message header.
|
|
|
|
|
const headerBytes = 12 |
|
|
|
|
|
|
|
|
|
// forwardQueueSize is the maximal number of requests that can be pending delegation.
|
|
|
|
|
// Note that this is distinct from the number of requests that are pending a response,
|
|
|
|
|
// which is not limited (except by txid collisions).
|
|
|
|
|
const forwardQueueSize = 64 |
|
|
|
|
|
|
|
|
|
// connCount is the number of UDP connections to use for forwarding.
|
|
|
|
|
const connCount = 32 |
|
|
|
|
|
|
|
|
|
@ -138,7 +133,6 @@ func newForwarder(logf logger.Logf, responses chan Packet) *forwarder { |
|
|
|
|
return &forwarder{ |
|
|
|
|
logf: logger.WithPrefix(logf, "forward: "), |
|
|
|
|
responses: responses, |
|
|
|
|
queue: make(chan forwardedPacket, forwardQueueSize), |
|
|
|
|
closed: make(chan struct{}), |
|
|
|
|
conns: make([]*net.UDPConn, connCount), |
|
|
|
|
txMap: make(map[txid]forwardingRecord), |
|
|
|
|
@ -155,11 +149,10 @@ func (f *forwarder) Start() error { |
|
|
|
|
} |
|
|
|
|
} |
|
|
|
|
|
|
|
|
|
f.wg.Add(connCount + 2) |
|
|
|
|
f.wg.Add(connCount + 1) |
|
|
|
|
for idx, conn := range f.conns { |
|
|
|
|
go f.recv(uint16(idx), conn) |
|
|
|
|
} |
|
|
|
|
go f.send() |
|
|
|
|
go f.cleanMap() |
|
|
|
|
|
|
|
|
|
return nil |
|
|
|
|
@ -191,28 +184,13 @@ func (f *forwarder) setUpstreams(upstreams []net.Addr) { |
|
|
|
|
f.mu.Unlock() |
|
|
|
|
} |
|
|
|
|
|
|
|
|
|
func (f *forwarder) send() { |
|
|
|
|
defer f.wg.Done() |
|
|
|
|
|
|
|
|
|
var packet forwardedPacket |
|
|
|
|
for { |
|
|
|
|
select { |
|
|
|
|
case <-f.closed: |
|
|
|
|
return |
|
|
|
|
case packet = <-f.queue: |
|
|
|
|
// continue
|
|
|
|
|
} |
|
|
|
|
|
|
|
|
|
connIdx := rand.Intn(connCount) |
|
|
|
|
conn := f.conns[connIdx] |
|
|
|
|
_, err := conn.WriteTo(packet.payload, packet.dst) |
|
|
|
|
if err != nil { |
|
|
|
|
// Do not log errors due to expired deadline.
|
|
|
|
|
if !errors.Is(err, os.ErrDeadlineExceeded) { |
|
|
|
|
f.logf("send: %v", err) |
|
|
|
|
} |
|
|
|
|
return |
|
|
|
|
} |
|
|
|
|
func (f *forwarder) send(packet []byte, dst net.Addr) { |
|
|
|
|
connIdx := rand.Intn(connCount) |
|
|
|
|
conn := f.conns[connIdx] |
|
|
|
|
_, err := conn.WriteTo(packet, dst) |
|
|
|
|
// Do not log errors due to expired deadline.
|
|
|
|
|
if err != nil && !errors.Is(err, os.ErrDeadlineExceeded) { |
|
|
|
|
f.logf("send: %v", err) |
|
|
|
|
} |
|
|
|
|
} |
|
|
|
|
|
|
|
|
|
@ -308,17 +286,8 @@ func (f *forwarder) forward(query Packet) error { |
|
|
|
|
|
|
|
|
|
f.mu.Unlock() |
|
|
|
|
|
|
|
|
|
packet := forwardedPacket{ |
|
|
|
|
payload: query.Payload, |
|
|
|
|
} |
|
|
|
|
for _, upstream := range upstreams { |
|
|
|
|
packet.dst = upstream |
|
|
|
|
select { |
|
|
|
|
case <-f.closed: |
|
|
|
|
return ErrClosed |
|
|
|
|
case f.queue <- packet: |
|
|
|
|
// continue
|
|
|
|
|
} |
|
|
|
|
f.send(query.Payload, upstream) |
|
|
|
|
} |
|
|
|
|
|
|
|
|
|
return nil |
|
|
|
|
|