net/udprelay: replace map+sync.Mutex with sync.Map for VNI lookup

This commit also introduces a sync.Mutex for guarding mutable fields
on serverEndpoint, now that they are no longer guarded by the sync.Mutex
in Server.

These changes reduce lock contention and, as a result, increase aggregate
throughput under high flow count load. A benchmark on Linux with AWS
c8gn instances showed a ~30% increase in aggregate throughput (37Gb/s
vs 28Gb/s) for 12 tailscaled flows.

Updates tailscale/corp#35264

Signed-off-by: Jordan Whited <jordan@tailscale.com>
This commit is contained in:
Jordan Whited
2025-12-15 12:14:34 -08:00
committed by Jordan Whited
parent 951d711054
commit a663639bea
2 changed files with 93 additions and 82 deletions
+81 -69
View File
@@ -77,8 +77,8 @@ type Server struct {
closeCh chan struct{} closeCh chan struct{}
netChecker *netcheck.Client netChecker *netcheck.Client
mu sync.Mutex // guards the following fields mu sync.Mutex // guards the following fields
macSecrets [][blake2s.Size]byte // [0] is most recent, max 2 elements macSecrets views.Slice[[blake2s.Size]byte] // [0] is most recent, max 2 elements
macSecretRotatedAt mono.Time macSecretRotatedAt mono.Time
derpMap *tailcfg.DERPMap derpMap *tailcfg.DERPMap
onlyStaticAddrPorts bool // no dynamic addr port discovery when set onlyStaticAddrPorts bool // no dynamic addr port discovery when set
@@ -87,8 +87,11 @@ type Server struct {
closed bool closed bool
lamportID uint64 lamportID uint64
nextVNI uint32 nextVNI uint32
byVNI map[uint32]*serverEndpoint // serverEndpointByVNI is consistent with serverEndpointByDisco while mu is
byDisco map[key.SortedPairOfDiscoPublic]*serverEndpoint // held, i.e. mu must be held around write ops. Read ops in performance
// sensitive paths, e.g. packet forwarding, do not need to acquire mu.
serverEndpointByVNI sync.Map // key is uint32 (Geneve VNI), value is [*serverEndpoint]
serverEndpointByDisco map[key.SortedPairOfDiscoPublic]*serverEndpoint
} }
const macSecretRotationInterval = time.Minute * 2 const macSecretRotationInterval = time.Minute * 2
@@ -100,23 +103,23 @@ const (
) )
// serverEndpoint contains Server-internal [endpoint.ServerEndpoint] state. // serverEndpoint contains Server-internal [endpoint.ServerEndpoint] state.
// serverEndpoint methods are not thread-safe.
type serverEndpoint struct { type serverEndpoint struct {
// discoPubKeys contains the key.DiscoPublic of the served clients. The // discoPubKeys contains the key.DiscoPublic of the served clients. The
// indexing of this array aligns with the following fields, e.g. // indexing of this array aligns with the following fields, e.g.
// discoSharedSecrets[0] is the shared secret to use when sealing // discoSharedSecrets[0] is the shared secret to use when sealing
// Disco protocol messages for transmission towards discoPubKeys[0]. // Disco protocol messages for transmission towards discoPubKeys[0].
discoPubKeys key.SortedPairOfDiscoPublic discoPubKeys key.SortedPairOfDiscoPublic
discoSharedSecrets [2]key.DiscoShared discoSharedSecrets [2]key.DiscoShared
lamportID uint64
vni uint32
allocatedAt mono.Time
mu sync.Mutex // guards the following fields
inProgressGeneration [2]uint32 // or zero if a handshake has never started, or has just completed inProgressGeneration [2]uint32 // or zero if a handshake has never started, or has just completed
boundAddrPorts [2]netip.AddrPort // or zero value if a handshake has never completed for that relay leg boundAddrPorts [2]netip.AddrPort // or zero value if a handshake has never completed for that relay leg
lastSeen [2]mono.Time lastSeen [2]mono.Time
packetsRx [2]uint64 // num packets received from/sent by each client after they are bound packetsRx [2]uint64 // num packets received from/sent by each client after they are bound
bytesRx [2]uint64 // num bytes received from/sent by each client after they are bound bytesRx [2]uint64 // num bytes received from/sent by each client after they are bound
lamportID uint64
vni uint32
allocatedAt mono.Time
} }
func blakeMACFromBindMsg(blakeKey [blake2s.Size]byte, src netip.AddrPort, msg disco.BindUDPRelayEndpointCommon) ([blake2s.Size]byte, error) { func blakeMACFromBindMsg(blakeKey [blake2s.Size]byte, src netip.AddrPort, msg disco.BindUDPRelayEndpointCommon) ([blake2s.Size]byte, error) {
@@ -141,7 +144,10 @@ func blakeMACFromBindMsg(blakeKey [blake2s.Size]byte, src netip.AddrPort, msg di
return out, nil return out, nil
} }
func (e *serverEndpoint) handleDiscoControlMsg(from netip.AddrPort, senderIndex int, discoMsg disco.Message, serverDisco key.DiscoPublic, macSecrets [][blake2s.Size]byte, now mono.Time) (write []byte, to netip.AddrPort) { func (e *serverEndpoint) handleDiscoControlMsg(from netip.AddrPort, senderIndex int, discoMsg disco.Message, serverDisco key.DiscoPublic, macSecrets views.Slice[[blake2s.Size]byte], now mono.Time) (write []byte, to netip.AddrPort) {
e.mu.Lock()
defer e.mu.Unlock()
if senderIndex != 0 && senderIndex != 1 { if senderIndex != 0 && senderIndex != 1 {
return nil, netip.AddrPort{} return nil, netip.AddrPort{}
} }
@@ -186,7 +192,7 @@ func (e *serverEndpoint) handleDiscoControlMsg(from netip.AddrPort, senderIndex
} }
reply = append(reply, disco.Magic...) reply = append(reply, disco.Magic...)
reply = serverDisco.AppendTo(reply) reply = serverDisco.AppendTo(reply)
mac, err := blakeMACFromBindMsg(macSecrets[0], from, m.BindUDPRelayEndpointCommon) mac, err := blakeMACFromBindMsg(macSecrets.At(0), from, m.BindUDPRelayEndpointCommon)
if err != nil { if err != nil {
return nil, netip.AddrPort{} return nil, netip.AddrPort{}
} }
@@ -206,7 +212,7 @@ func (e *serverEndpoint) handleDiscoControlMsg(from netip.AddrPort, senderIndex
// silently drop // silently drop
return nil, netip.AddrPort{} return nil, netip.AddrPort{}
} }
for _, macSecret := range macSecrets { for _, macSecret := range macSecrets.All() {
mac, err := blakeMACFromBindMsg(macSecret, from, discoMsg.BindUDPRelayEndpointCommon) mac, err := blakeMACFromBindMsg(macSecret, from, discoMsg.BindUDPRelayEndpointCommon)
if err != nil { if err != nil {
// silently drop // silently drop
@@ -230,7 +236,7 @@ func (e *serverEndpoint) handleDiscoControlMsg(from netip.AddrPort, senderIndex
} }
} }
func (e *serverEndpoint) handleSealedDiscoControlMsg(from netip.AddrPort, b []byte, serverDisco key.DiscoPublic, macSecrets [][blake2s.Size]byte, now mono.Time) (write []byte, to netip.AddrPort) { func (e *serverEndpoint) handleSealedDiscoControlMsg(from netip.AddrPort, b []byte, serverDisco key.DiscoPublic, macSecrets views.Slice[[blake2s.Size]byte], now mono.Time) (write []byte, to netip.AddrPort) {
senderRaw, isDiscoMsg := disco.Source(b) senderRaw, isDiscoMsg := disco.Source(b)
if !isDiscoMsg { if !isDiscoMsg {
// Not a Disco message // Not a Disco message
@@ -265,7 +271,9 @@ func (e *serverEndpoint) handleSealedDiscoControlMsg(from netip.AddrPort, b []by
} }
func (e *serverEndpoint) handleDataPacket(from netip.AddrPort, b []byte, now mono.Time) (write []byte, to netip.AddrPort) { func (e *serverEndpoint) handleDataPacket(from netip.AddrPort, b []byte, now mono.Time) (write []byte, to netip.AddrPort) {
if !e.isBound() { e.mu.Lock()
defer e.mu.Unlock()
if !e.isBoundLocked() {
// not a control packet, but serverEndpoint isn't bound // not a control packet, but serverEndpoint isn't bound
return nil, netip.AddrPort{} return nil, netip.AddrPort{}
} }
@@ -287,7 +295,9 @@ func (e *serverEndpoint) handleDataPacket(from netip.AddrPort, b []byte, now mon
} }
func (e *serverEndpoint) isExpired(now mono.Time, bindLifetime, steadyStateLifetime time.Duration) bool { func (e *serverEndpoint) isExpired(now mono.Time, bindLifetime, steadyStateLifetime time.Duration) bool {
if !e.isBound() { e.mu.Lock()
defer e.mu.Unlock()
if !e.isBoundLocked() {
if now.Sub(e.allocatedAt) > bindLifetime { if now.Sub(e.allocatedAt) > bindLifetime {
return true return true
} }
@@ -299,9 +309,9 @@ func (e *serverEndpoint) isExpired(now mono.Time, bindLifetime, steadyStateLifet
return false return false
} }
// isBound returns true if both clients have completed a 3-way handshake, // isBoundLocked returns true if both clients have completed a 3-way handshake,
// otherwise false. // otherwise false.
func (e *serverEndpoint) isBound() bool { func (e *serverEndpoint) isBoundLocked() bool {
return e.boundAddrPorts[0].IsValid() && return e.boundAddrPorts[0].IsValid() &&
e.boundAddrPorts[1].IsValid() e.boundAddrPorts[1].IsValid()
} }
@@ -313,15 +323,14 @@ func (e *serverEndpoint) isBound() bool {
// used. // used.
func NewServer(logf logger.Logf, port uint16, onlyStaticAddrPorts bool) (s *Server, err error) { func NewServer(logf logger.Logf, port uint16, onlyStaticAddrPorts bool) (s *Server, err error) {
s = &Server{ s = &Server{
logf: logf, logf: logf,
disco: key.NewDisco(), disco: key.NewDisco(),
bindLifetime: defaultBindLifetime, bindLifetime: defaultBindLifetime,
steadyStateLifetime: defaultSteadyStateLifetime, steadyStateLifetime: defaultSteadyStateLifetime,
closeCh: make(chan struct{}), closeCh: make(chan struct{}),
onlyStaticAddrPorts: onlyStaticAddrPorts, onlyStaticAddrPorts: onlyStaticAddrPorts,
byDisco: make(map[key.SortedPairOfDiscoPublic]*serverEndpoint), serverEndpointByDisco: make(map[key.SortedPairOfDiscoPublic]*serverEndpoint),
nextVNI: minVNI, nextVNI: minVNI,
byVNI: make(map[uint32]*serverEndpoint),
} }
s.discoPublic = s.disco.Public() s.discoPublic = s.disco.Public()
@@ -640,8 +649,8 @@ func (s *Server) Close() error {
// acquire s.mu. // acquire s.mu.
s.mu.Lock() s.mu.Lock()
defer s.mu.Unlock() defer s.mu.Unlock()
clear(s.byVNI) s.serverEndpointByVNI.Clear()
clear(s.byDisco) clear(s.serverEndpointByDisco)
s.closed = true s.closed = true
s.bus.Close() s.bus.Close()
}) })
@@ -659,10 +668,10 @@ func (s *Server) endpointGCLoop() {
// holding s.mu for the duration. Keep it simple (and slow) for now. // holding s.mu for the duration. Keep it simple (and slow) for now.
s.mu.Lock() s.mu.Lock()
defer s.mu.Unlock() defer s.mu.Unlock()
for k, v := range s.byDisco { for k, v := range s.serverEndpointByDisco {
if v.isExpired(now, s.bindLifetime, s.steadyStateLifetime) { if v.isExpired(now, s.bindLifetime, s.steadyStateLifetime) {
delete(s.byDisco, k) delete(s.serverEndpointByDisco, k)
delete(s.byVNI, v.vni) s.serverEndpointByVNI.Delete(v.vni)
} }
} }
} }
@@ -690,12 +699,7 @@ func (s *Server) handlePacket(from netip.AddrPort, b []byte) (write []byte, to n
if err != nil { if err != nil {
return nil, netip.AddrPort{} return nil, netip.AddrPort{}
} }
// TODO: consider performance implications of holding s.mu for the remainder e, ok := s.serverEndpointByVNI.Load(gh.VNI.Get())
// of this method, which does a bunch of disco/crypto work depending. Keep
// it simple (and slow) for now.
s.mu.Lock()
defer s.mu.Unlock()
e, ok := s.byVNI[gh.VNI.Get()]
if !ok { if !ok {
// unknown VNI // unknown VNI
return nil, netip.AddrPort{} return nil, netip.AddrPort{}
@@ -708,27 +712,36 @@ func (s *Server) handlePacket(from netip.AddrPort, b []byte) (write []byte, to n
return nil, netip.AddrPort{} return nil, netip.AddrPort{}
} }
msg := b[packet.GeneveFixedHeaderLength:] msg := b[packet.GeneveFixedHeaderLength:]
s.maybeRotateMACSecretLocked(now) secrets := s.getMACSecrets(now)
return e.handleSealedDiscoControlMsg(from, msg, s.discoPublic, s.macSecrets, now) return e.(*serverEndpoint).handleSealedDiscoControlMsg(from, msg, s.discoPublic, secrets, now)
} }
return e.handleDataPacket(from, b, now) return e.(*serverEndpoint).handleDataPacket(from, b, now)
}
func (s *Server) getMACSecrets(now mono.Time) views.Slice[[blake2s.Size]byte] {
s.mu.Lock()
defer s.mu.Unlock()
s.maybeRotateMACSecretLocked(now)
return s.macSecrets
} }
func (s *Server) maybeRotateMACSecretLocked(now mono.Time) { func (s *Server) maybeRotateMACSecretLocked(now mono.Time) {
if !s.macSecretRotatedAt.IsZero() && now.Sub(s.macSecretRotatedAt) < macSecretRotationInterval { if !s.macSecretRotatedAt.IsZero() && now.Sub(s.macSecretRotatedAt) < macSecretRotationInterval {
return return
} }
switch len(s.macSecrets) { secrets := s.macSecrets.AsSlice()
switch len(secrets) {
case 0: case 0:
s.macSecrets = make([][blake2s.Size]byte, 1, 2) secrets = make([][blake2s.Size]byte, 1, 2)
case 1: case 1:
s.macSecrets = append(s.macSecrets, [blake2s.Size]byte{}) secrets = append(secrets, [blake2s.Size]byte{})
fallthrough fallthrough
case 2: case 2:
s.macSecrets[1] = s.macSecrets[0] secrets[1] = secrets[0]
} }
rand.Read(s.macSecrets[0][:]) rand.Read(secrets[0][:])
s.macSecretRotatedAt = now s.macSecretRotatedAt = now
s.macSecrets = views.SliceOf(secrets)
return return
} }
@@ -838,7 +851,7 @@ func (s *Server) getNextVNILocked() (uint32, error) {
} else { } else {
s.nextVNI++ s.nextVNI++
} }
_, ok := s.byVNI[vni] _, ok := s.serverEndpointByVNI.Load(vni)
if !ok { if !ok {
return vni, nil return vni, nil
} }
@@ -877,7 +890,7 @@ func (s *Server) AllocateEndpoint(discoA, discoB key.DiscoPublic) (endpoint.Serv
} }
pair := key.NewSortedPairOfDiscoPublic(discoA, discoB) pair := key.NewSortedPairOfDiscoPublic(discoA, discoB)
e, ok := s.byDisco[pair] e, ok := s.serverEndpointByDisco[pair]
if ok { if ok {
// Return the existing allocation. Clients can resolve duplicate // Return the existing allocation. Clients can resolve duplicate
// [endpoint.ServerEndpoint]'s via [endpoint.ServerEndpoint.LamportID]. // [endpoint.ServerEndpoint]'s via [endpoint.ServerEndpoint.LamportID].
@@ -915,8 +928,8 @@ func (s *Server) AllocateEndpoint(discoA, discoB key.DiscoPublic) (endpoint.Serv
e.discoSharedSecrets[0] = s.disco.Shared(e.discoPubKeys.Get()[0]) e.discoSharedSecrets[0] = s.disco.Shared(e.discoPubKeys.Get()[0])
e.discoSharedSecrets[1] = s.disco.Shared(e.discoPubKeys.Get()[1]) e.discoSharedSecrets[1] = s.disco.Shared(e.discoPubKeys.Get()[1])
s.byDisco[pair] = e s.serverEndpointByDisco[pair] = e
s.byVNI[e.vni] = e s.serverEndpointByVNI.Store(e.vni, e)
s.logf("allocated endpoint vni=%d lamportID=%d disco[0]=%v disco[1]=%v", e.vni, e.lamportID, pair.Get()[0].ShortString(), pair.Get()[1].ShortString()) s.logf("allocated endpoint vni=%d lamportID=%d disco[0]=%v disco[1]=%v", e.vni, e.lamportID, pair.Get()[0].ShortString(), pair.Get()[1].ShortString())
return endpoint.ServerEndpoint{ return endpoint.ServerEndpoint{
@@ -930,19 +943,19 @@ func (s *Server) AllocateEndpoint(discoA, discoB key.DiscoPublic) (endpoint.Serv
}, nil }, nil
} }
// extractClientInfo constructs a [status.ClientInfo] for one of the two peer // extractClientInfo constructs a [status.ClientInfo] for both relay clients
// relay clients involved in this session. // involved in this session.
func extractClientInfo(idx int, ep *serverEndpoint) status.ClientInfo { func (e *serverEndpoint) extractClientInfo() [2]status.ClientInfo {
if idx != 0 && idx != 1 { e.mu.Lock()
panic(fmt.Sprintf("idx passed to extractClientInfo() must be 0 or 1; got %d", idx)) defer e.mu.Unlock()
} ret := [2]status.ClientInfo{}
for i := range e.boundAddrPorts {
return status.ClientInfo{ ret[i].Endpoint = e.boundAddrPorts[i]
Endpoint: ep.boundAddrPorts[idx], ret[i].ShortDisco = e.discoPubKeys.Get()[i].ShortString()
ShortDisco: ep.discoPubKeys.Get()[idx].ShortString(), ret[i].PacketsTx = e.packetsRx[i]
PacketsTx: ep.packetsRx[idx], ret[i].BytesTx = e.bytesRx[i]
BytesTx: ep.bytesRx[idx],
} }
return ret
} }
// GetSessions returns a slice of peer relay session statuses, with each // GetSessions returns a slice of peer relay session statuses, with each
@@ -955,14 +968,13 @@ func (s *Server) GetSessions() []status.ServerSession {
if s.closed { if s.closed {
return nil return nil
} }
var sessions = make([]status.ServerSession, 0, len(s.byDisco)) var sessions = make([]status.ServerSession, 0, len(s.serverEndpointByDisco))
for _, se := range s.byDisco { for _, se := range s.serverEndpointByDisco {
c1 := extractClientInfo(0, se) clientInfos := se.extractClientInfo()
c2 := extractClientInfo(1, se)
sessions = append(sessions, status.ServerSession{ sessions = append(sessions, status.ServerSession{
VNI: se.vni, VNI: se.vni,
Client1: c1, Client1: clientInfos[0],
Client2: c2, Client2: clientInfos[1],
}) })
} }
return sessions return sessions
+12 -13
View File
@@ -339,19 +339,18 @@ func TestServer_getNextVNILocked(t *testing.T) {
c := qt.New(t) c := qt.New(t)
s := &Server{ s := &Server{
nextVNI: minVNI, nextVNI: minVNI,
byVNI: make(map[uint32]*serverEndpoint),
} }
for i := uint64(0); i < uint64(totalPossibleVNI); i++ { for i := uint64(0); i < uint64(totalPossibleVNI); i++ {
vni, err := s.getNextVNILocked() vni, err := s.getNextVNILocked()
if err != nil { // using quicktest here triples test time if err != nil { // using quicktest here triples test time
t.Fatal(err) t.Fatal(err)
} }
s.byVNI[vni] = nil s.serverEndpointByVNI.Store(vni, nil)
} }
c.Assert(s.nextVNI, qt.Equals, minVNI) c.Assert(s.nextVNI, qt.Equals, minVNI)
_, err := s.getNextVNILocked() _, err := s.getNextVNILocked()
c.Assert(err, qt.IsNotNil) c.Assert(err, qt.IsNotNil)
delete(s.byVNI, minVNI) s.serverEndpointByVNI.Delete(minVNI)
_, err = s.getNextVNILocked() _, err = s.getNextVNILocked()
c.Assert(err, qt.IsNil) c.Assert(err, qt.IsNil)
} }
@@ -455,17 +454,17 @@ func TestServer_maybeRotateMACSecretLocked(t *testing.T) {
s := &Server{} s := &Server{}
start := mono.Now() start := mono.Now()
s.maybeRotateMACSecretLocked(start) s.maybeRotateMACSecretLocked(start)
qt.Assert(t, len(s.macSecrets), qt.Equals, 1) qt.Assert(t, s.macSecrets.Len(), qt.Equals, 1)
macSecret := s.macSecrets[0] macSecret := s.macSecrets.At(0)
s.maybeRotateMACSecretLocked(start.Add(macSecretRotationInterval - time.Nanosecond)) s.maybeRotateMACSecretLocked(start.Add(macSecretRotationInterval - time.Nanosecond))
qt.Assert(t, len(s.macSecrets), qt.Equals, 1) qt.Assert(t, s.macSecrets.Len(), qt.Equals, 1)
qt.Assert(t, s.macSecrets[0], qt.Equals, macSecret) qt.Assert(t, s.macSecrets.At(0), qt.Equals, macSecret)
s.maybeRotateMACSecretLocked(start.Add(macSecretRotationInterval)) s.maybeRotateMACSecretLocked(start.Add(macSecretRotationInterval))
qt.Assert(t, len(s.macSecrets), qt.Equals, 2) qt.Assert(t, s.macSecrets.Len(), qt.Equals, 2)
qt.Assert(t, s.macSecrets[1], qt.Equals, macSecret) qt.Assert(t, s.macSecrets.At(1), qt.Equals, macSecret)
qt.Assert(t, s.macSecrets[0], qt.Not(qt.Equals), s.macSecrets[1]) qt.Assert(t, s.macSecrets.At(0), qt.Not(qt.Equals), s.macSecrets.At(1))
s.maybeRotateMACSecretLocked(s.macSecretRotatedAt.Add(macSecretRotationInterval)) s.maybeRotateMACSecretLocked(s.macSecretRotatedAt.Add(macSecretRotationInterval))
qt.Assert(t, macSecret, qt.Not(qt.Equals), s.macSecrets[0]) qt.Assert(t, macSecret, qt.Not(qt.Equals), s.macSecrets.At(0))
qt.Assert(t, macSecret, qt.Not(qt.Equals), s.macSecrets[1]) qt.Assert(t, macSecret, qt.Not(qt.Equals), s.macSecrets.At(1))
qt.Assert(t, s.macSecrets[0], qt.Not(qt.Equals), s.macSecrets[1]) qt.Assert(t, s.macSecrets.At(0), qt.Not(qt.Equals), s.macSecrets.At(1))
} }