diff --git a/control/controlclient/direct.go b/control/controlclient/direct.go index 3b4e6ba9b..529db1874 100644 --- a/control/controlclient/direct.go +++ b/control/controlclient/direct.go @@ -362,26 +362,24 @@ func NewDirect(opts Options) (*Direct, error) { discoKeyPub := eventbus.Publish[events.PeerDiscoKeyUpdate](c.busClient) eventbus.SubscribeFunc(c.busClient, func(update events.DiscoKeyAdvertisement) { c.logf("controlclient direct: got TSMP disco key advertisement from %v via eventbus", update.Src) - var nm *netmap.NetworkMap + var peerID tailcfg.NodeID + var peerKey key.NodePublic + var ok bool c.mu.Lock() sess := c.streamingMapSession + c.mu.Unlock() if sess != nil { - nm = c.streamingMapSession.netmap() + peerID, peerKey, ok = sess.PeerIDAndKeyByTailscaleIP(update.Src) } - c.mu.Unlock() - if sess != nil { - peer, ok := nm.PeerByTailscaleIP(update.Src) - if !ok { - return - } + if sess != nil && ok { c.logf("controlclient direct: updating discoKey for %v via mapSession", update.Src) // If we update without error, return. If the err indicates that the // mapSession has gone away, we want to fall back to pushing the key // further down the chain. if err := sess.updateDiscoForNode( - peer.ID(), peer.Key(), update.Key, time.Now(), false); err == nil || + peerID, peerKey, update.Key, time.Now(), false); err == nil || !errors.Is(err, ErrChangeQueueClosed) { return } diff --git a/control/controlclient/map.go b/control/controlclient/map.go index 18e79ebc6..4c58ae8af 100644 --- a/control/controlclient/map.go +++ b/control/controlclient/map.go @@ -13,6 +13,7 @@ import ( "io" "maps" "net" + "net/netip" "reflect" "runtime" "runtime/debug" @@ -86,7 +87,6 @@ type mapSession struct { lastPrintMap time.Time lastNode tailcfg.NodeView lastCapSet set.Set[tailcfg.NodeCapability] - peers map[tailcfg.NodeID]tailcfg.NodeView lastDNSConfig *tailcfg.DNSConfig lastDERPMap *tailcfg.DERPMap lastUserProfile map[tailcfg.UserID]tailcfg.UserProfileView @@ -106,6 +106,10 @@ type mapSession struct { changeQueue chan responseWithSource changeQueueClosed bool processQueue sync.WaitGroup + + // mu protects the peers map. + peersMu sync.RWMutex + peers map[tailcfg.NodeID]tailcfg.NodeView } // newMapSession returns a mostly unconfigured new mapSession. @@ -675,6 +679,9 @@ var ( // updatePeersStateFromResponseres updates ms.peers from resp. // It takes ownership of resp. func (ms *mapSession) updatePeersStateFromResponse(resp *tailcfg.MapResponse) (stats updateStats) { + ms.peersMu.Lock() + defer ms.peersMu.Unlock() + if ms.peers == nil { ms.peers = make(map[tailcfg.NodeID]tailcfg.NodeView) } @@ -854,6 +861,9 @@ func getNodeFields() []string { // It returns ok=false if a patch can't be made, (V, ok) on a delta, or (nil, // true) if all the fields were identical (a zero change). func (ms *mapSession) patchifyPeer(n *tailcfg.Node) (_ *tailcfg.PeerChange, ok bool) { + ms.peersMu.RLock() + defer ms.peersMu.RUnlock() + was, ok := ms.peers[n.ID] if !ok { return nil, false @@ -1056,7 +1066,28 @@ func peerChangeDiff(was tailcfg.NodeView, n *tailcfg.Node) (_ *tailcfg.PeerChang return ret, true } +// PeerIDAndKeyByTailscaleIP returns the node ID and node Key from the peers +// map without touching the netmap itself. The implementation mirrors the +// implementation of [netmap.PeerByTailscaleIP]. +func (ms *mapSession) PeerIDAndKeyByTailscaleIP(ip netip.Addr) (tailcfg.NodeID, key.NodePublic, bool) { + ms.peersMu.RLock() + defer ms.peersMu.RUnlock() + for _, n := range ms.peers { + ad := n.Addresses() + for i := range ad.Len() { + a := ad.At(i) + if a.Addr() == ip { + return n.ID(), n.Key(), true + } + } + } + return 0, key.NodePublic{}, false +} + func (ms *mapSession) sortedPeers() []tailcfg.NodeView { + ms.peersMu.RLock() + defer ms.peersMu.RUnlock() + ret := slicesx.MapValues(ms.peers) slices.SortFunc(ret, func(a, b tailcfg.NodeView) int { return cmp.Compare(a.ID(), b.ID()) diff --git a/control/controlclient/map_test.go b/control/controlclient/map_test.go index fff5c7131..5eee931f3 100644 --- a/control/controlclient/map_test.go +++ b/control/controlclient/map_test.go @@ -1781,3 +1781,61 @@ func TestPathDiscokeyerImplementations(t *testing.T) { t.Error("wgengine.watchdogEngine must implement patchDiscoKeyer") } } + +func TestPeerIDAndKeyByTailscaleIP(t *testing.T) { + peerKey1 := key.NewNode().Public() + peerKey2 := key.NewNode().Public() + + peer1 := &tailcfg.Node{ + ID: 1, + Key: peerKey1, + Addresses: []netip.Prefix{netip.MustParsePrefix("100.64.0.1/32")}, + } + peer2 := &tailcfg.Node{ + ID: 2, + Key: peerKey2, + Addresses: []netip.Prefix{ + netip.MustParsePrefix("100.64.0.2/32"), + netip.MustParsePrefix("fd7a:115c::2/128"), + }, + } + + ms := newTestMapSession(t, nil) + ms.updateStateFromResponse(&tailcfg.MapResponse{ + Node: new(tailcfg.Node), + Peers: []*tailcfg.Node{peer1, peer2}, + }) + + t.Run("known_ip_peer1", func(t *testing.T) { + gotID, gotKey, ok := ms.PeerIDAndKeyByTailscaleIP(netip.MustParseAddr("100.64.0.1")) + if !ok { + t.Fatal("PeerIDAndKeyByTailscaleIP returned ok=false, want true") + } + if gotID != peer1.ID { + t.Errorf("NodeID = %v, want %v", gotID, peer1.ID) + } + if gotKey != peerKey1 { + t.Errorf("NodePublic = %v, want %v", gotKey, peerKey1) + } + }) + + t.Run("known_ip_peer2_v6", func(t *testing.T) { + gotID, gotKey, ok := ms.PeerIDAndKeyByTailscaleIP(netip.MustParseAddr("fd7a:115c::2")) + if !ok { + t.Fatal("PeerIDAndKeyByTailscaleIP returned ok=false, want true") + } + if gotID != peer2.ID { + t.Errorf("NodeID = %v, want %v", gotID, peer2.ID) + } + if gotKey != peerKey2 { + t.Errorf("NodePublic = %v, want %v", gotKey, peerKey2) + } + }) + + t.Run("unknown_ip", func(t *testing.T) { + gotID, gotKey, ok := ms.PeerIDAndKeyByTailscaleIP(netip.MustParseAddr("100.64.0.99")) + if ok { + t.Errorf("PeerIDAndKeyByTailscaleIP returned ok=true for unknown IP, got id=%v key=%v", gotID, gotKey) + } + }) +}