diff --git a/control/controlclient/auto.go b/control/controlclient/auto.go index 195525228..b087e1444 100644 --- a/control/controlclient/auto.go +++ b/control/controlclient/auto.go @@ -478,6 +478,27 @@ func (mrs mapRoutineState) UpdateNetmapDelta(muts []netmap.NodeMutation) bool { return err == nil && ok } +var _ patchDiscoKeyer = mapRoutineState{} + +func (mrs mapRoutineState) PatchDiscoKey(pub key.NodePublic, disco key.DiscoPublic) { + c := mrs.c + c.mu.Lock() + goodState := c.loggedIn && c.inMapPoll + dun, ok := c.observer.(patchDiscoKeyer) + c.mu.Unlock() + + if !goodState || !ok { + return + } + + ctx, cancel := context.WithTimeout(c.mapCtx, 2*time.Second) + defer cancel() + + c.observerQueue.RunSync(ctx, func() { + dun.PatchDiscoKey(pub, disco) + }) +} + // mapRoutine is responsible for keeping a read-only streaming connection to the // control server, and keeping the netmap up to date. func (c *Auto) mapRoutine() { diff --git a/control/controlclient/direct.go b/control/controlclient/direct.go index dc3ebd300..1569d7517 100644 --- a/control/controlclient/direct.go +++ b/control/controlclient/direct.go @@ -228,6 +228,15 @@ type NetmapDeltaUpdater interface { UpdateNetmapDelta([]netmap.NodeMutation) (ok bool) } +// patchDiscoKeyer is an optional interface that can be implemented by an [Observer] to be +// notified about node disco keys received out-of-band from control, via +// existing connection state. +type patchDiscoKeyer interface { + // PatchDiscoKey reports to the receiver that the specified disco key + // for node was obtained out-of-band from control. + PatchDiscoKey(key.NodePublic, key.DiscoPublic) +} + var nextControlClientID atomic.Int64 // NewDirect returns a new Direct client. @@ -367,7 +376,7 @@ func NewDirect(opts Options) (*Direct, error) { // mapSession has gone away, we want to fall back to pushing the key // further down the chain. 
if err := c.streamingMapSession.updateDiscoForNode( - peer.ID(), update.Key, time.Now(), false); err == nil || + peer.ID(), peer.Key(), update.Key, time.Now(), false); err == nil || !errors.Is(err, ErrChangeQueueClosed) { return } @@ -377,10 +386,7 @@ func NewDirect(opts Options) (*Direct, error) { // not have a mapSession (we are not connected to control) or because the // mapSession queue has closed. c.logf("controlclient direct: updating discoKey for %v via magicsock", update.Src) - discoKeyPub.Publish(events.PeerDiscoKeyUpdate{ - Src: update.Src, - Key: update.Key, - }) + discoKeyPub.Publish(events.PeerDiscoKeyUpdate(update)) }) return c, nil @@ -859,8 +865,10 @@ func (c *Direct) PollNetMap(ctx context.Context, nu NetmapUpdater) error { // update it observed. It is used by tests and [NetmapFromMapResponseForDebug]. // It will report only the first netmap seen. type rememberLastNetmapUpdater struct { - last *netmap.NetworkMap - done chan any + last *netmap.NetworkMap + lastTSMPKey key.NodePublic + lastTSMPDisco key.DiscoPublic + done chan any } func (nu *rememberLastNetmapUpdater) UpdateFullNetmap(nm *netmap.NetworkMap) { @@ -871,6 +879,11 @@ func (nu *rememberLastNetmapUpdater) UpdateFullNetmap(nm *netmap.NetworkMap) { } } +func (nu *rememberLastNetmapUpdater) PatchDiscoKey(key key.NodePublic, disco key.DiscoPublic) { + nu.lastTSMPKey = key + nu.lastTSMPDisco = disco +} + // FetchNetMapForTest fetches the netmap once. func (c *Direct) FetchNetMapForTest(ctx context.Context) (*netmap.NetworkMap, error) { var nu rememberLastNetmapUpdater diff --git a/control/controlclient/map.go b/control/controlclient/map.go index 1a0ab0037..c08a54ac4 100644 --- a/control/controlclient/map.go +++ b/control/controlclient/map.go @@ -37,6 +37,11 @@ import ( "tailscale.com/wgengine/filter" ) +type responseWithSource struct { + response *tailcfg.MapResponse + viaTSMP bool +} + // mapSession holds the state over a long-polled "map" request to the // control plane. 
// @@ -98,7 +103,7 @@ type mapSession struct { lastTKAInfo *tailcfg.TKAInfo lastNetmapSummary string // from NetworkMap.VeryConcise cqmu sync.Mutex - changeQueue chan (*tailcfg.MapResponse) + changeQueue chan responseWithSource changeQueueClosed bool processQueue sync.WaitGroup } @@ -123,7 +128,7 @@ func newMapSession(privateNodeKey key.NodePrivate, nu NetmapUpdater, controlKnob cancel: func() {}, onDebug: func(context.Context, *tailcfg.Debug) error { return nil }, onSelfNodeChanged: func(*netmap.NetworkMap) {}, - changeQueue: make(chan *tailcfg.MapResponse), + changeQueue: make(chan responseWithSource), changeQueueClosed: false, } ms.sessionAliveCtx, ms.sessionAliveCtxClose = context.WithCancel(context.Background()) @@ -142,7 +147,7 @@ func (ms *mapSession) run() { for { select { case change := <-ms.changeQueue: - ms.handleNonKeepAliveMapResponse(ms.sessionAliveCtx, change) + ms.handleNonKeepAliveMapResponse(ms.sessionAliveCtx, change.response, change.viaTSMP) case <-ms.sessionAliveCtx.Done(): // Drain any remaining items in the queue before exiting. 
// Lock the queue during this time to avoid updates through other channels @@ -154,7 +159,7 @@ func (ms *mapSession) run() { for { select { case change := <-ms.changeQueue: - ms.handleNonKeepAliveMapResponse(ms.sessionAliveCtx, change) + ms.handleNonKeepAliveMapResponse(ms.sessionAliveCtx, change.response, change.viaTSMP) default: // Queue is empty, close it and exit close(ms.changeQueue) @@ -190,7 +195,7 @@ func (ms *mapSession) Close() { var ErrChangeQueueClosed = errors.New("change queue closed") -func (ms *mapSession) updateDiscoForNode(id tailcfg.NodeID, key key.DiscoPublic, lastSeen time.Time, online bool) error { +func (ms *mapSession) updateDiscoForNode(id tailcfg.NodeID, key key.NodePublic, discoKey key.DiscoPublic, lastSeen time.Time, online bool) error { ms.cqmu.Lock() if ms.changeQueueClosed { @@ -199,13 +204,17 @@ func (ms *mapSession) updateDiscoForNode(id tailcfg.NodeID, key key.DiscoPublic, return ErrChangeQueueClosed } - resp := &tailcfg.MapResponse{ - PeersChangedPatch: []*tailcfg.PeerChange{{ - NodeID: id, - LastSeen: &lastSeen, - Online: &online, - DiscoKey: &key, - }}, + resp := responseWithSource{ + response: &tailcfg.MapResponse{ + PeersChangedPatch: []*tailcfg.PeerChange{{ + NodeID: id, + Key: &key, + LastSeen: &lastSeen, + Online: &online, + DiscoKey: &discoKey, + }}, + }, + viaTSMP: true, } ms.changeQueue <- resp ms.cqmu.Unlock() @@ -221,7 +230,12 @@ func (ms *mapSession) HandleNonKeepAliveMapResponse(ctx context.Context, resp *t return ErrChangeQueueClosed } - ms.changeQueue <- resp + change := responseWithSource{ + response: resp, + viaTSMP: false, + } + + ms.changeQueue <- change ms.cqmu.Unlock() return nil } @@ -234,7 +248,7 @@ func (ms *mapSession) HandleNonKeepAliveMapResponse(ctx context.Context, resp *t // // TODO(bradfitz): make this handle all fields later. For now (2023-08-20) this // is [re]factoring progress enough. 
-func (ms *mapSession) handleNonKeepAliveMapResponse(ctx context.Context, resp *tailcfg.MapResponse) error { +func (ms *mapSession) handleNonKeepAliveMapResponse(ctx context.Context, resp *tailcfg.MapResponse, viaTSMP bool) error { if debug := resp.Debug; debug != nil { if err := ms.onDebug(ctx, debug); err != nil { return err @@ -284,6 +298,13 @@ func (ms *mapSession) handleNonKeepAliveMapResponse(ctx context.Context, resp *t ms.updateStateFromResponse(resp) + // If the source was learned via TSMP, the updated disco key needs to be marked in + // userspaceEngine as an update that should not reconfigure the wireguard + // connection. + if viaTSMP { + ms.tryMarkDiscoAsLearnedFromTSMP(resp) + } + if ms.tryHandleIncrementally(resp) { ms.occasionallyPrintSummary(ms.lastNetmapSummary) return nil @@ -312,6 +333,21 @@ func (ms *mapSession) handleNonKeepAliveMapResponse(ctx context.Context, resp *t return nil } +func (ms *mapSession) tryMarkDiscoAsLearnedFromTSMP(res *tailcfg.MapResponse) { + dun, ok := ms.netmapUpdater.(patchDiscoKeyer) + if !ok { + return + } + + // In reality we should never really have more than one change here over TSMP. + for _, change := range res.PeersChangedPatch { + if change == nil || change.DiscoKey == nil || change.Key == nil { + continue + } + dun.PatchDiscoKey(*change.Key, *change.DiscoKey) + } +} + // upgradeNode upgrades Node fields from the server into the modern forms // not using deprecated fields. 
func upgradeNode(n *tailcfg.Node) { diff --git a/control/controlclient/map_test.go b/control/controlclient/map_test.go index 154b9742e..7b99ae7b8 100644 --- a/control/controlclient/map_test.go +++ b/control/controlclient/map_test.go @@ -33,7 +33,9 @@ import ( "tailscale.com/util/eventbus/eventbustest" "tailscale.com/util/mak" "tailscale.com/util/must" + "tailscale.com/util/usermetric" "tailscale.com/util/zstdframe" + "tailscale.com/wgengine" ) func eps(s ...string) []netip.AddrPort { @@ -678,6 +680,7 @@ func TestUpdateDiscoForNode(t *testing.T) { // Insert existing node node := tailcfg.Node{ ID: 1, + Key: key.NewNode().Public(), DiscoKey: oldKey.Public(), Online: &tt.initialOnline, LastSeen: &tt.initialLastSeen, @@ -690,7 +693,7 @@ func TestUpdateDiscoForNode(t *testing.T) { } newKey := key.NewDisco() - ms.updateDiscoForNode(node.ID, newKey.Public(), tt.updateLastSeen, tt.updateOnline) + ms.updateDiscoForNode(node.ID, node.Key, newKey.Public(), tt.updateLastSeen, tt.updateOnline) <-nu.done nm := ms.netmap() @@ -707,6 +710,82 @@ func TestUpdateDiscoForNode(t *testing.T) { } } +func TestUpdateDiscoForNodeCallback(t *testing.T) { + t.Run("key_wired_through_to_updater", func(t *testing.T) { + nu := &rememberLastNetmapUpdater{ + done: make(chan any, 1), + } + ms := newTestMapSession(t, nu) + + oldKey := key.NewDisco() + + // Insert existing node + node := tailcfg.Node{ + ID: 1, + Key: key.NewNode().Public(), + DiscoKey: oldKey.Public(), + Online: new(false), + LastSeen: new(time.Unix(1, 0)), + } + + if nm := ms.netmapForResponse(&tailcfg.MapResponse{ + Peers: []*tailcfg.Node{&node}, + }); len(nm.Peers) != 1 { + t.Fatalf("node not inserted") + } + + newKey := key.NewDisco() + ms.updateDiscoForNode(node.ID, node.Key, newKey.Public(), time.Now(), false) + <-nu.done + + if nu.lastTSMPKey != node.Key || nu.lastTSMPDisco != newKey.Public() { + t.Fatalf("expected [%s]=%s, got [%s]=%s", node.Key, newKey.Public(), + nu.lastTSMPKey, nu.lastTSMPDisco) + } + }) + 
t.Run("key_not_wired_through_to_updater", func(t *testing.T) { + nu := &rememberLastNetmapUpdater{ + done: make(chan any, 1), + } + ms := newTestMapSession(t, nu) + + oldKey := key.NewDisco() + + // Insert existing node + node := tailcfg.Node{ + ID: 1, + Key: key.NewNode().Public(), + DiscoKey: oldKey.Public(), + Online: new(false), + LastSeen: new(time.Unix(1, 0)), + } + + if nm := ms.netmapForResponse(&tailcfg.MapResponse{ + Peers: []*tailcfg.Node{&node}, + }); len(nm.Peers) != 1 { + t.Fatalf("node not inserted") + } + + newKey := key.NewDisco().Public() + resp := &tailcfg.MapResponse{ + PeersChangedPatch: []*tailcfg.PeerChange{{ + NodeID: node.ID, + Key: &node.Key, + LastSeen: new(time.Now()), + Online: new(true), + DiscoKey: &newKey, + }}, + } + ms.HandleNonKeepAliveMapResponse(t.Context(), resp) + <-nu.done + + if !nu.lastTSMPKey.IsZero() || !nu.lastTSMPDisco.IsZero() { + t.Fatalf("expected zero keys, got [%s]=%s", + nu.lastTSMPKey, nu.lastTSMPDisco) + } + }) +} + func first[T any](s []T) T { if len(s) == 0 { var zero T @@ -1568,3 +1647,22 @@ func TestLearnZstdOfKeepAlive(t *testing.T) { t.Fatalf("got %d zstd decodes; want %d", got, want) } } + +func TestPathDiscokeyerImplementations(t *testing.T) { + bus := eventbustest.NewBus(t) + ht := health.NewTracker(bus) + reg := new(usermetric.Registry) + e, err := wgengine.NewFakeUserspaceEngine(t.Logf, 0, ht, reg, bus) + if err != nil { + t.Fatal(err) + } + t.Cleanup(e.Close) + if _, ok := e.(patchDiscoKeyer); !ok { + t.Error("wgengine.userspaceEngine must implement patchDiscoKeyer") + } + + wd := wgengine.NewWatchdog(e) + if _, ok := wd.(patchDiscoKeyer); !ok { + t.Error("wgengine.watchdogEngine must implement patchDiscoKeyer") + } +} diff --git a/ipn/ipnlocal/local.go b/ipn/ipnlocal/local.go index 845317c4a..5c25583e5 100644 --- a/ipn/ipnlocal/local.go +++ b/ipn/ipnlocal/local.go @@ -1857,6 +1857,18 @@ func (b *LocalBackend) setControlClientStatusLocked(c controlclient.Client, st c b.authReconfigLocked() } +func (b 
*LocalBackend) PatchDiscoKey(pub key.NodePublic, disco key.DiscoPublic) { + // PatchDiscoKey mirrors the implementation of [controlclient.patchDiscoKeyer]. + // It is implemented here to avoid the dependency edge to controlclient, but must be kept + // in sync with the original implementation. + type patchDiscoKeyer interface { + PatchDiscoKey(key.NodePublic, key.DiscoPublic) + } + if e, ok := b.e.(patchDiscoKeyer); ok { + e.PatchDiscoKey(pub, disco) + } +} + type preferencePolicyInfo struct { key pkey.Key get func(ipn.PrefsView) bool diff --git a/wgengine/userspace.go b/wgengine/userspace.go index 5670541af..5b81206d0 100644 --- a/wgengine/userspace.go +++ b/wgengine/userspace.go @@ -121,7 +121,8 @@ type userspaceEngine struct { birdClient BIRDClient // or nil controlKnobs *controlknobs.Knobs // or nil - testMaybeReconfigHook func() // for tests; if non-nil, fires if maybeReconfigWireguardLocked called + testMaybeReconfigHook func() // for tests; if non-nil, fires if maybeReconfigWireguardLocked called + testDiscoChangedHook func(map[key.NodePublic]bool) // for tests; if non-nil, fires after assembling discoChanged map // isLocalAddr reports the whether an IP is assigned to the local // tunnel interface. It's used to reflect local packets @@ -167,6 +168,10 @@ type userspaceEngine struct { // networkLogger logs statistics about network connections. networkLogger netlog.Logger + // tsmpLearnedDisco tracks per node key if a peer disco key was learned via TSMP. + // wgLock must be held when using this map. + tsmpLearnedDisco map[key.NodePublic]key.DiscoPublic + // Lock ordering: magicsock.Conn.mu, wgLock, then mu. 
} @@ -1028,6 +1033,12 @@ func (e *userspaceEngine) ResetAndStop() (*Status, error) { } } +func (e *userspaceEngine) PatchDiscoKey(pub key.NodePublic, disco key.DiscoPublic) { + e.wgLock.Lock() + defer e.wgLock.Unlock() + mak.Set(&e.tsmpLearnedDisco, pub, disco) +} + func (e *userspaceEngine) Reconfig(cfg *wgcfg.Config, routerCfg *router.Config, dnsCfg *dns.Config) error { if routerCfg == nil { panic("routerCfg must not be nil") @@ -1119,14 +1130,31 @@ func (e *userspaceEngine) Reconfig(cfg *wgcfg.Config, routerCfg *router.Config, if p.DiscoKey.IsZero() { continue } + + // If the key changed, mark the connection for reconfiguration. pub := p.PublicKey if old, ok := prevEP[pub]; ok && old != p.DiscoKey { + // If the disco key was learned via TSMP, we do not need to reset the + // wireguard config as the new key was received over an existing wireguard + // connection. + if discoTSMP, okTSMP := e.tsmpLearnedDisco[p.PublicKey]; okTSMP && + discoTSMP == p.DiscoKey { + delete(e.tsmpLearnedDisco, p.PublicKey) + e.logf("wgengine: Skipping reconfig (TSMP key): %s changed from %q to %q", pub.ShortString(), old, p.DiscoKey) + continue + } + discoChanged[pub] = true e.logf("wgengine: Reconfig: %s changed from %q to %q", pub.ShortString(), old, p.DiscoKey) } } } + // For tests: report which disco connections need to be changed. + if e.testDiscoChangedHook != nil { + e.testDiscoChangedHook(discoChanged) + } + e.lastCfgFull = *cfg.Clone() // Tell magicsock about the new (or initial) private key @@ -1144,6 +1172,13 @@ func (e *userspaceEngine) Reconfig(cfg *wgcfg.Config, routerCfg *router.Config, return err } + // Clean up the map of TSMP marks for peers that no longer exist in the config. + for nodeKey := range e.tsmpLearnedDisco { + if !peerSet.Contains(nodeKey) { + delete(e.tsmpLearnedDisco, nodeKey) + } + } + // Shutdown the network logger because the IDs changed. // Let it be started back up by subsequent logic. 
if buildfeatures.HasNetLog && netLogIDsChanged && e.networkLogger.Running() { diff --git a/wgengine/userspace_test.go b/wgengine/userspace_test.go index 18d870af1..210c528b3 100644 --- a/wgengine/userspace_test.go +++ b/wgengine/userspace_test.go @@ -164,6 +164,79 @@ func TestUserspaceEngineReconfig(t *testing.T) { } } +func TestUserspaceEngineTSMPLearned(t *testing.T) { + bus := eventbustest.NewBus(t) + + ht := health.NewTracker(bus) + reg := new(usermetric.Registry) + e, err := NewFakeUserspaceEngine(t.Logf, 0, ht, reg, bus) + if err != nil { + t.Fatal(err) + } + t.Cleanup(e.Close) + ue := e.(*userspaceEngine) + + discoChangedChan := make(chan map[key.NodePublic]bool, 1) + ue.testDiscoChangedHook = func(m map[key.NodePublic]bool) { + discoChangedChan <- m + } + + routerCfg := &router.Config{} + + keyChanges := []struct { + tsmp bool + inMap bool + }{ + {tsmp: false, inMap: false}, + {tsmp: true, inMap: false}, + {tsmp: false, inMap: true}, + } + + nkHex := "aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa" + for _, change := range keyChanges { + oldDisco := key.NewDisco() + nm := &netmap.NetworkMap{ + Peers: nodeViews([]*tailcfg.Node{ + { + ID: 1, + Key: nkFromHex(nkHex), + DiscoKey: oldDisco.Public(), + }, + }), + } + nk, err := key.ParseNodePublicUntyped(mem.S(nkHex)) + if err != nil { + t.Fatal(err) + } + e.SetNetworkMap(nm) + + newDisco := key.NewDisco() + cfg := &wgcfg.Config{ + Peers: []wgcfg.Peer{ + { + PublicKey: nk, + DiscoKey: newDisco.Public(), + }, + }, + } + + if change.tsmp { + ue.PatchDiscoKey(nk, newDisco.Public()) + } + err = e.Reconfig(cfg, routerCfg, &dns.Config{}) + if err != nil { + t.Fatal(err) + } + + changeMap := <-discoChangedChan + + if _, ok := changeMap[nk]; ok != change.inMap { + t.Fatalf("expect key %v in map %v to be %t, got %t", nk, changeMap, + change.inMap, ok) + } + } +} + func TestUserspaceEnginePortReconfig(t *testing.T) { flakytest.Mark(t, "https://github.com/tailscale/tailscale/issues/2855") const defaultPort 
= 49983 diff --git a/wgengine/watchdog.go b/wgengine/watchdog.go index f12b1c19e..4bb320b4b 100644 --- a/wgengine/watchdog.go +++ b/wgengine/watchdog.go @@ -242,3 +242,15 @@ func (e *watchdogEngine) InstallCaptureHook(cb packet.CaptureCallback) { func (e *watchdogEngine) PeerByKey(pubKey key.NodePublic) (_ wgint.Peer, ok bool) { return e.wrap.PeerByKey(pubKey) } + +func (e *watchdogEngine) PatchDiscoKey(pub key.NodePublic, disco key.DiscoPublic) { + // PatchDiscoKey mirrors the implementation of [controlclient.patchDiscoKeyer]. + // It is implemented here to avoid the dependency edge to controlclient, but must be kept + // in sync with the original implementation. + type patchDiscoKeyer interface { + PatchDiscoKey(key.NodePublic, key.DiscoPublic) + } + if n, ok := e.wrap.(patchDiscoKeyer); ok { + n.PatchDiscoKey(pub, disco) + } +}