|
|
|
|
@ -5,11 +5,8 @@ |
|
|
|
|
package stunner |
|
|
|
|
|
|
|
|
|
import ( |
|
|
|
|
"bytes" |
|
|
|
|
"context" |
|
|
|
|
"crypto/rand" |
|
|
|
|
"fmt" |
|
|
|
|
"log" |
|
|
|
|
"net" |
|
|
|
|
"strconv" |
|
|
|
|
"sync" |
|
|
|
|
@ -29,79 +26,114 @@ import ( |
|
|
|
|
// for the connection. (An endpoint may be reported multiple times if
|
|
|
|
|
// multiple servers are provided.)
|
|
|
|
|
type Stunner struct { |
|
|
|
|
Send func([]byte, net.Addr) (int, error) // sends a packet
|
|
|
|
|
Endpoint func(endpoint string) // reports an endpoint
|
|
|
|
|
Servers []string // STUN servers to contact
|
|
|
|
|
// Send sends a packet.
|
|
|
|
|
// It will typically be a PacketConn.WriteTo method value.
|
|
|
|
|
Send func([]byte, net.Addr) (int, error) // sends a packet
|
|
|
|
|
|
|
|
|
|
// Endpoint is called whenever a STUN response is received.
|
|
|
|
|
// The server is the STUN server that replied, endpoint is the ip:port
|
|
|
|
|
// from the STUN response, and d is the duration that the STUN request
|
|
|
|
|
// took on the wire (not including DNS lookup time.
|
|
|
|
|
Endpoint func(server, endpoint string, d time.Duration) |
|
|
|
|
|
|
|
|
|
Servers []string // STUN servers to contact
|
|
|
|
|
|
|
|
|
|
// Resolver optionally specifies a resolver to use for DNS lookups.
|
|
|
|
|
// If nil, net.DefaultResolver is used.
|
|
|
|
|
Resolver *net.Resolver |
|
|
|
|
Logf func(format string, args ...interface{}) |
|
|
|
|
|
|
|
|
|
// Logf optionally specifies a log function. If nil, logging is disabled.
|
|
|
|
|
Logf func(format string, args ...interface{}) |
|
|
|
|
|
|
|
|
|
// OnlyIPv6 controls whether IPv6 is exclusively used.
|
|
|
|
|
// If false, only IPv4 is used. There is currently no mixed mode.
|
|
|
|
|
OnlyIPv6 bool |
|
|
|
|
|
|
|
|
|
// sessions tracks the state of each server.
|
|
|
|
|
// It's keyed by the STUN server (from the Servers field).
|
|
|
|
|
sessions map[string]*session |
|
|
|
|
|
|
|
|
|
mu sync.Mutex |
|
|
|
|
inFlight map[stun.TxID]request |
|
|
|
|
} |
|
|
|
|
|
|
|
|
|
func (s *Stunner) addTX(tx stun.TxID, server string) { |
|
|
|
|
s.mu.Lock() |
|
|
|
|
defer s.mu.Unlock() |
|
|
|
|
if s.inFlight == nil { |
|
|
|
|
s.inFlight = make(map[stun.TxID]request) |
|
|
|
|
} |
|
|
|
|
s.inFlight[tx] = request{sent: time.Now(), server: server} |
|
|
|
|
} |
|
|
|
|
|
|
|
|
|
func (s *Stunner) removeTX(tx stun.TxID) (request, bool) { |
|
|
|
|
s.mu.Lock() |
|
|
|
|
defer s.mu.Unlock() |
|
|
|
|
r, ok := s.inFlight[tx] |
|
|
|
|
delete(s.inFlight, tx) |
|
|
|
|
return r, ok |
|
|
|
|
} |
|
|
|
|
|
|
|
|
|
type request struct { |
|
|
|
|
sent time.Time |
|
|
|
|
server string |
|
|
|
|
} |
|
|
|
|
|
|
|
|
|
type session struct { |
|
|
|
|
replied chan struct{} // closed when server responds
|
|
|
|
|
tIDs []stun.TxID // transaction IDs sent to a server
|
|
|
|
|
ctx context.Context // closed via call to done when reply received
|
|
|
|
|
cancel context.CancelFunc |
|
|
|
|
} |
|
|
|
|
|
|
|
|
|
func (s *Stunner) logf(format string, args ...interface{}) { |
|
|
|
|
if s.Logf != nil { |
|
|
|
|
s.Logf(format, args...) |
|
|
|
|
} |
|
|
|
|
} |
|
|
|
|
|
|
|
|
|
// Receive delivers a STUN packet to the stunner.
|
|
|
|
|
func (s *Stunner) Receive(p []byte, fromAddr *net.UDPAddr) { |
|
|
|
|
if !stun.Is(p) { |
|
|
|
|
log.Println("stunner: received non-STUN packet") |
|
|
|
|
s.logf("stunner: received non-STUN packet") |
|
|
|
|
return |
|
|
|
|
} |
|
|
|
|
|
|
|
|
|
responseTID, addr, port, err := stun.ParseResponse(p) |
|
|
|
|
now := time.Now() |
|
|
|
|
tx, addr, port, err := stun.ParseResponse(p) |
|
|
|
|
if err != nil { |
|
|
|
|
log.Printf("stunner: received bad STUN response: %v", err) |
|
|
|
|
s.logf("stunner: received bad STUN response: %v", err) |
|
|
|
|
return |
|
|
|
|
} |
|
|
|
|
r, ok := s.removeTX(tx) |
|
|
|
|
if !ok { |
|
|
|
|
s.logf("stunner: got STUN packet for unknown TxID %x", tx) |
|
|
|
|
return |
|
|
|
|
} |
|
|
|
|
d := now.Sub(r.sent) |
|
|
|
|
|
|
|
|
|
session := s.sessions[r.server] |
|
|
|
|
if session != nil { |
|
|
|
|
host := net.JoinHostPort(net.IP(addr).String(), fmt.Sprint(port)) |
|
|
|
|
s.logf("STUN server %s reports public endpoint %s after %v", r.server, host, d) |
|
|
|
|
s.Endpoint(r.server, host, d) |
|
|
|
|
session.cancel() |
|
|
|
|
} |
|
|
|
|
} |
|
|
|
|
|
|
|
|
|
// Accept any of the tIDs from any of the active sessions.
|
|
|
|
|
for server, session := range s.sessions { |
|
|
|
|
for _, tID := range session.tIDs { |
|
|
|
|
if bytes.Equal(tID[:], responseTID[:]) { |
|
|
|
|
select { |
|
|
|
|
case <-session.replied: |
|
|
|
|
return // already got a reply from this server
|
|
|
|
|
default: |
|
|
|
|
} |
|
|
|
|
close(session.replied) |
|
|
|
|
|
|
|
|
|
// TODO(crawshaw): use different endpoints returned from
|
|
|
|
|
// different STUN servers to detect NAT types.
|
|
|
|
|
portStr := fmt.Sprintf("%d", port) |
|
|
|
|
host := net.JoinHostPort(net.IP(addr).String(), portStr) |
|
|
|
|
if s.Logf != nil { |
|
|
|
|
s.Logf("STUN server %s reports public endpoint %s", server, host) |
|
|
|
|
} |
|
|
|
|
s.Endpoint(host) |
|
|
|
|
return |
|
|
|
|
} |
|
|
|
|
} |
|
|
|
|
func (s *Stunner) resolver() *net.Resolver { |
|
|
|
|
if s.Resolver != nil { |
|
|
|
|
return s.Resolver |
|
|
|
|
} |
|
|
|
|
log.Printf("stunner: received STUN packet for unknown transaction: %x", responseTID) |
|
|
|
|
return net.DefaultResolver |
|
|
|
|
} |
|
|
|
|
|
|
|
|
|
// Run starts a Stunner and blocks until all servers either respond
|
|
|
|
|
// or are tried multiple times and timeout.
|
|
|
|
|
func (s *Stunner) Run(ctx context.Context) error { |
|
|
|
|
if s.Resolver == nil { |
|
|
|
|
s.Resolver = net.DefaultResolver |
|
|
|
|
} |
|
|
|
|
s.sessions = map[string]*session{} |
|
|
|
|
for _, server := range s.Servers { |
|
|
|
|
// Generate the transaction IDs for this session.
|
|
|
|
|
tIDs := make([]stun.TxID, len(retryDurations)) |
|
|
|
|
for i := range tIDs { |
|
|
|
|
if _, err := rand.Read(tIDs[i][:]); err != nil { |
|
|
|
|
return fmt.Errorf("stunner: rand failed: %v", err) |
|
|
|
|
} |
|
|
|
|
} |
|
|
|
|
if s.sessions == nil { |
|
|
|
|
s.sessions = make(map[string]*session) |
|
|
|
|
} |
|
|
|
|
sctx, cancel := context.WithCancel(ctx) |
|
|
|
|
s.sessions[server] = &session{ |
|
|
|
|
replied: make(chan struct{}), |
|
|
|
|
tIDs: tIDs, |
|
|
|
|
ctx: sctx, |
|
|
|
|
cancel: cancel, |
|
|
|
|
} |
|
|
|
|
} |
|
|
|
|
// after this point, the s.sessions map is read-only
|
|
|
|
|
@ -124,30 +156,26 @@ func (s *Stunner) runServer(ctx context.Context, server string) { |
|
|
|
|
|
|
|
|
|
for i, d := range retryDurations { |
|
|
|
|
ctx, cancel := context.WithTimeout(ctx, d) |
|
|
|
|
err := s.sendSTUN(ctx, session.tIDs[i], server) |
|
|
|
|
err := s.sendSTUN(ctx, server) |
|
|
|
|
if err != nil { |
|
|
|
|
if s.Logf != nil { |
|
|
|
|
s.Logf("stunner: %s: %v", server, err) |
|
|
|
|
} |
|
|
|
|
s.logf("stunner: %s: %v", server, err) |
|
|
|
|
} |
|
|
|
|
|
|
|
|
|
select { |
|
|
|
|
case <-ctx.Done(): |
|
|
|
|
cancel() |
|
|
|
|
case <-session.replied: |
|
|
|
|
case <-session.ctx.Done(): |
|
|
|
|
cancel() |
|
|
|
|
if i > 0 && s.Logf != nil { |
|
|
|
|
s.Logf("stunner: slow STUN response from %s: %d retries", server, i) |
|
|
|
|
if i > 0 { |
|
|
|
|
s.logf("stunner: slow STUN response from %s: %d retries", server, i) |
|
|
|
|
} |
|
|
|
|
return |
|
|
|
|
} |
|
|
|
|
} |
|
|
|
|
if s.Logf != nil { |
|
|
|
|
s.Logf("stunner: no STUN response from %s", server) |
|
|
|
|
} |
|
|
|
|
s.logf("stunner: no STUN response from %s", server) |
|
|
|
|
} |
|
|
|
|
|
|
|
|
|
func (s *Stunner) sendSTUN(ctx context.Context, tID stun.TxID, server string) error { |
|
|
|
|
func (s *Stunner) sendSTUN(ctx context.Context, server string) error { |
|
|
|
|
host, port, err := net.SplitHostPort(server) |
|
|
|
|
if err != nil { |
|
|
|
|
return err |
|
|
|
|
@ -161,23 +189,35 @@ func (s *Stunner) sendSTUN(ctx context.Context, tID stun.TxID, server string) er |
|
|
|
|
} |
|
|
|
|
addr := &net.UDPAddr{Port: addrPort} |
|
|
|
|
|
|
|
|
|
ipAddrs, err := s.Resolver.LookupIPAddr(ctx, host) |
|
|
|
|
ipAddrs, err := s.resolver().LookupIPAddr(ctx, host) |
|
|
|
|
if err != nil { |
|
|
|
|
return fmt.Errorf("lookup ip addr: %v", err) |
|
|
|
|
} |
|
|
|
|
for _, ipAddr := range ipAddrs { |
|
|
|
|
if ip4 := ipAddr.IP.To4(); ip4 != nil { |
|
|
|
|
ip4 := ipAddr.IP.To4() |
|
|
|
|
if ip4 != nil { |
|
|
|
|
if s.OnlyIPv6 { |
|
|
|
|
continue |
|
|
|
|
} |
|
|
|
|
addr.IP = ip4 |
|
|
|
|
addr.Zone = ipAddr.Zone |
|
|
|
|
break |
|
|
|
|
} else if s.OnlyIPv6 { |
|
|
|
|
addr.IP = ipAddr.IP |
|
|
|
|
addr.Zone = ipAddr.Zone |
|
|
|
|
} |
|
|
|
|
} |
|
|
|
|
if addr.IP == nil { |
|
|
|
|
if s.OnlyIPv6 { |
|
|
|
|
return fmt.Errorf("cannot resolve any ipv6 addresses for %s, got: %v", server, ipAddrs) |
|
|
|
|
} |
|
|
|
|
return fmt.Errorf("cannot resolve any ipv4 addresses for %s, got: %v", server, ipAddrs) |
|
|
|
|
} |
|
|
|
|
|
|
|
|
|
req := stun.Request(tID) |
|
|
|
|
if _, err := s.Send(req, addr); err != nil { |
|
|
|
|
txID := stun.NewTxID() |
|
|
|
|
req := stun.Request(txID) |
|
|
|
|
s.addTX(txID, server) |
|
|
|
|
_, err = s.Send(req, addr) |
|
|
|
|
if err != nil { |
|
|
|
|
return fmt.Errorf("send: %v", err) |
|
|
|
|
} |
|
|
|
|
return nil |
|
|
|
|
|