wgengine/magicsock,control/controlclient: do not overwrite discokey with old key (#18606)

When a client starts up without being able to connect to control, it
sends its discoKey to other nodes it wants to communicate with over
TSMP. This disco key will be a newer key than the one control knows
about.

If the client that can connect to control gets a full netmap, ensure
that the disco key for the node not connected to control is not
overwritten with the stale key control knows about.

This is implemented by keeping track of the mapSession and using it for
the discokey injection when it is available. This ensures that we are not
constantly resetting the WireGuard connection when getting stale
keys from control.

This is implemented as:
 - If the key is received via TSMP:
   - Set lastSeen for the peer to now()
   - Set online for the peer to false
 - When processing new keys, only accept keys where either:
   - Peer is online
   - lastSeen is newer than existing last seen

If the mapSession is not available (i.e., we are not yet connected to
control), the disco key injection is punted down to magicsock.

Ideally, we will want to have mapSession be long lived at some point in
the near future so we only need to inject keys in one location and then
also use that for testing and loading the cache, but that is a yak for
another PR.

Updates #12639

Signed-off-by: Claus Lensbøl <claus@tailscale.com>
This commit is contained in:
Claus Lensbøl
2026-03-20 08:56:27 -04:00
committed by GitHub
parent ca9aa20255
commit 85bb5f84a5
15 changed files with 346 additions and 46 deletions
+14 -8
View File
@@ -91,7 +91,7 @@ func (c *Auto) updateRoutine() {
bo.BackOff(ctx, err)
continue
}
bo.BackOff(ctx, nil)
bo.Reset()
c.direct.logf("[v1] successful lite map update in %v", d)
lastUpdateGenInformed = gen
@@ -382,7 +382,7 @@ func (c *Auto) authRoutine() {
// backoff to avoid a busy loop.
bo.BackOff(ctx, errors.New("login URL not changing"))
} else {
bo.BackOff(ctx, nil)
bo.Reset()
}
continue
}
@@ -397,7 +397,7 @@ func (c *Auto) authRoutine() {
c.sendStatus("authRoutine-success", nil, "", nil)
c.restartMap()
bo.BackOff(ctx, nil)
bo.Reset()
}
}
@@ -446,13 +446,14 @@ func (mrs mapRoutineState) UpdateFullNetmap(nm *netmap.NetworkMap) {
c.expiry = nm.SelfKeyExpiry()
stillAuthed := c.loggedIn
c.logf("[v1] mapRoutine: netmap received: loggedIn=%v inMapPoll=true", stillAuthed)
// Reset the backoff timer if we got a netmap.
mrs.bo.Reset()
c.mu.Unlock()
if stillAuthed {
c.sendStatus("mapRoutine-got-netmap", nil, "", nm)
}
// Reset the backoff timer if we got a netmap.
mrs.bo.Reset()
}
func (mrs mapRoutineState) UpdateNetmapDelta(muts []netmap.NodeMutation) bool {
@@ -526,13 +527,18 @@ func (c *Auto) mapRoutine() {
c.mu.Lock()
c.inMapPoll = false
paused := c.paused
c.mu.Unlock()
if paused {
mrs.bo.BackOff(ctx, nil)
c.logf("mapRoutine: paused")
mrs.bo.Reset()
} else {
mrs.bo.BackOff(ctx, err)
}
c.mu.Unlock()
// Now safe to call functions that might acquire the mutex
if paused {
c.logf("mapRoutine: paused")
} else {
report(err, "PollNetMap")
}
}
+60 -8
View File
@@ -46,6 +46,7 @@ import (
"tailscale.com/tailcfg"
"tailscale.com/tka"
"tailscale.com/tstime"
"tailscale.com/types/events"
"tailscale.com/types/key"
"tailscale.com/types/logger"
"tailscale.com/types/netmap"
@@ -106,8 +107,9 @@ type Direct struct {
netinfo *tailcfg.NetInfo
endpoints []tailcfg.Endpoint
tkaHead string
lastPingURL string // last PingRequest.URL received, for dup suppression
connectionHandleForTest string // sent in MapRequest.ConnectionHandleForTest
lastPingURL string // last PingRequest.URL received, for dup suppression
connectionHandleForTest string // sent in MapRequest.ConnectionHandleForTest
streamingMapSession *mapSession // the one streaming mapSession instance
controlClientID int64 // Random ID used to differentiate clients for consumers of messages.
}
@@ -348,6 +350,38 @@ func NewDirect(opts Options) (*Direct, error) {
c.clientVersionPub = eventbus.Publish[tailcfg.ClientVersion](c.busClient)
c.autoUpdatePub = eventbus.Publish[AutoUpdate](c.busClient)
c.controlTimePub = eventbus.Publish[ControlTime](c.busClient)
discoKeyPub := eventbus.Publish[events.PeerDiscoKeyUpdate](c.busClient)
eventbus.SubscribeFunc(c.busClient, func(update events.DiscoKeyAdvertisement) {
c.mu.Lock()
defer c.mu.Unlock()
c.logf("controlclient direct: got TSMP disco key advertisement from %v via eventbus", update.Src)
if c.streamingMapSession != nil {
nm := c.streamingMapSession.netmap()
peer, ok := nm.PeerByTailscaleIP(update.Src)
if !ok {
return
}
c.logf("controlclient direct: updating discoKey for %v via mapSession", update.Src)
// If we update without error, return. If the err indicates that the
// mapSession has gone away, we want to fall back to pushing the key
// further down the chain.
if err := c.streamingMapSession.updateDiscoForNode(
peer.ID(), update.Key, time.Now(), false); err == nil ||
!errors.Is(err, ErrChangeQueueClosed) {
return
}
}
// We need to push the update further down the chain. Either because we do
// not have a mapSession (we are not connected to control) or because the
// mapSession queue has closed.
c.logf("controlclient direct: updating discoKey for %v via magicsock", update.Src)
discoKeyPub.Publish(events.PeerDiscoKeyUpdate{
Src: update.Src,
Key: update.Key,
})
})
return c, nil
}
@@ -821,21 +855,28 @@ func (c *Direct) PollNetMap(ctx context.Context, nu NetmapUpdater) error {
return c.sendMapRequest(ctx, true, nu)
}
// rememberLastNetmapUpdater is a container that remembers the last netmap
// update it observed. It is used by tests and [NetmapFromMapResponseForDebug].
// It will report only the first netmap seen.
type rememberLastNetmapUpdater struct {
	last *netmap.NetworkMap // most recently received netmap; nil until the first update arrives
	done chan any           // signaled (one send per UpdateFullNetmap call) so callers can wait for delivery
}
// UpdateFullNetmap records nm as the last observed netmap and then signals
// nu.done so a goroutine waiting on the channel can proceed. nu.last is set
// before the send so the waiter observes it after receiving. The channel is
// created unbuffered by callers, so the send blocks until a receiver is ready.
func (nu *rememberLastNetmapUpdater) UpdateFullNetmap(nm *netmap.NetworkMap) {
	nu.last = nm
	nu.done <- nil
}
// FetchNetMapForTest fetches the netmap once.
func (c *Direct) FetchNetMapForTest(ctx context.Context) (*netmap.NetworkMap, error) {
var nu rememberLastNetmapUpdater
nu.done = make(chan any)
err := c.sendMapRequest(ctx, false, &nu)
if err == nil && nu.last == nil {
return nil, errors.New("[unexpected] sendMapRequest success without callback")
}
<-nu.done
return nu.last, err
}
@@ -1080,8 +1121,18 @@ func (c *Direct) sendMapRequest(ctx context.Context, isStreaming bool, nu Netmap
return nil
}
if c.streamingMapSession != nil {
panic("mapSession is already set")
}
sess := newMapSession(persist.PrivateNodeKey(), nu, c.controlKnobs)
defer sess.Close()
c.streamingMapSession = sess
defer func() {
sess.Close()
c.mu.Lock()
c.streamingMapSession = nil
c.mu.Unlock()
}()
sess.cancel = cancel
sess.logf = c.logf
sess.vlogf = vlogf
@@ -1235,7 +1286,7 @@ func NetmapFromMapResponseForDebug(ctx context.Context, pr persist.PersistView,
return nil, errors.New("PersistView invalid")
}
nu := &rememberLastNetmapUpdater{}
nu := &rememberLastNetmapUpdater{done: make(chan any)}
sess := newMapSession(pr.PrivateNodeKey(), nu, nil)
defer sess.Close()
@@ -1243,6 +1294,7 @@ func NetmapFromMapResponseForDebug(ctx context.Context, pr persist.PersistView,
return nil, fmt.Errorf("HandleNonKeepAliveMapResponse: %w", err)
}
<-nu.done
return sess.netmap(), nil
}
@@ -1303,10 +1355,10 @@ var jsonEscapedZero = []byte(`\u0000`)
const justKeepAliveStr = `{"KeepAlive":true}`
// decodeMsg is responsible for uncompressing msg and unmarshaling into v.
func (sess *mapSession) decodeMsg(compressedMsg []byte, v *tailcfg.MapResponse) error {
func (ms *mapSession) decodeMsg(compressedMsg []byte, v *tailcfg.MapResponse) error {
// Fast path for common case of keep-alive message.
// See tailscale/tailscale#17343.
if sess.keepAliveZ != nil && bytes.Equal(compressedMsg, sess.keepAliveZ) {
if ms.keepAliveZ != nil && bytes.Equal(compressedMsg, ms.keepAliveZ) {
v.KeepAlive = true
return nil
}
@@ -1315,7 +1367,7 @@ func (sess *mapSession) decodeMsg(compressedMsg []byte, v *tailcfg.MapResponse)
if err != nil {
return err
}
sess.ztdDecodesForTest++
ms.ztdDecodesForTest++
if DevKnob.DumpNetMaps() {
var buf bytes.Buffer
@@ -1330,7 +1382,7 @@ func (sess *mapSession) decodeMsg(compressedMsg []byte, v *tailcfg.MapResponse)
return fmt.Errorf("response: %v", err)
}
if v.KeepAlive && string(b) == justKeepAliveStr {
sess.keepAliveZ = compressedMsg
ms.keepAliveZ = compressedMsg
}
return nil
}
+127 -2
View File
@@ -9,6 +9,7 @@ import (
"crypto/sha256"
"encoding/hex"
"encoding/json"
"errors"
"io"
"maps"
"net"
@@ -96,6 +97,10 @@ type mapSession struct {
lastPopBrowserURL string
lastTKAInfo *tailcfg.TKAInfo
lastNetmapSummary string // from NetworkMap.VeryConcise
cqmu sync.Mutex
changeQueue chan (*tailcfg.MapResponse)
changeQueueClosed bool
processQueue sync.WaitGroup
}
// newMapSession returns a mostly unconfigured new mapSession.
@@ -118,11 +123,48 @@ func newMapSession(privateNodeKey key.NodePrivate, nu NetmapUpdater, controlKnob
cancel: func() {},
onDebug: func(context.Context, *tailcfg.Debug) error { return nil },
onSelfNodeChanged: func(*netmap.NetworkMap) {},
changeQueue: make(chan *tailcfg.MapResponse),
changeQueueClosed: false,
}
ms.sessionAliveCtx, ms.sessionAliveCtxClose = context.WithCancel(context.Background())
ms.processQueue.Add(1)
go ms.run()
return ms
}
// run processes the mapSession's queue of tailcfg.MapResponse values one by
// one until the session is closed (Close cancels sessionAliveCtx).
// When the session is closed, the queue is marked closed under cqmu (so no
// new items can be enqueued) and any remaining items are drained and
// processed before the goroutine exits; processQueue.Done signals completion
// to Close's Wait.
func (ms *mapSession) run() {
	defer ms.processQueue.Done()
	for {
		select {
		case change := <-ms.changeQueue:
			ms.handleNonKeepAliveMapResponse(ms.sessionAliveCtx, change)
		case <-ms.sessionAliveCtx.Done():
			// Drain any remaining items in the queue before exiting.
			// Lock the queue during this time to avoid updates through other channels
			// to be overwritten. This is especially relevant for calls to
			// updateDiscoForNode.
			//
			// NOTE(review): this Lock can only be acquired once no sender is
			// holding cqmu. Verify that senders cannot block indefinitely on
			// the unbuffered changeQueue send while holding cqmu after
			// cancellation — otherwise this Lock is never acquired and the
			// shutdown deadlocks.
			ms.cqmu.Lock()
			ms.changeQueueClosed = true
			ms.cqmu.Unlock()
			for {
				select {
				case change := <-ms.changeQueue:
					ms.handleNonKeepAliveMapResponse(ms.sessionAliveCtx, change)
				default:
					// Queue is empty, close it and exit
					close(ms.changeQueue)
					return
				}
			}
		}
	}
}
// occasionallyPrintSummary logs summary at most once very 5 minutes. The
// summary is the Netmap.VeryConcise result from the last received map response.
func (ms *mapSession) occasionallyPrintSummary(summary string) {
@@ -143,9 +185,48 @@ func (ms *mapSession) clock() tstime.Clock {
func (ms *mapSession) Close() {
ms.sessionAliveCtxClose()
ms.processQueue.Wait()
}
// HandleNonKeepAliveMapResponse handles a non-KeepAlive MapResponse (full or
// ErrChangeQueueClosed is returned when a change cannot be enqueued because
// the mapSession has been (or is being) closed.
var ErrChangeQueueClosed = errors.New("change queue closed")

// updateDiscoForNode enqueues a synthetic PeersChangedPatch that updates the
// disco key, last-seen time, and online state of the node with the given ID.
// It is used to inject disco keys learned out-of-band (e.g. over TSMP) into
// the netmap without waiting for control.
//
// It returns ErrChangeQueueClosed if the session is shutting down, in which
// case the update was not enqueued and the caller is expected to deliver the
// key by other means.
func (ms *mapSession) updateDiscoForNode(id tailcfg.NodeID, key key.DiscoPublic, lastSeen time.Time, online bool) error {
	ms.cqmu.Lock()
	if ms.changeQueueClosed {
		ms.cqmu.Unlock()
		// Wait for the run goroutine to finish draining so the caller
		// observes a fully settled netmap before falling back.
		ms.processQueue.Wait()
		return ErrChangeQueueClosed
	}
	resp := &tailcfg.MapResponse{
		PeersChangedPatch: []*tailcfg.PeerChange{{
			NodeID:   id,
			LastSeen: &lastSeen,
			Online:   &online,
			DiscoKey: &key,
		}},
	}
	// Send while still holding cqmu so the queue cannot be marked closed (and
	// later closed) underneath us. Also watch for session cancellation: if
	// the run goroutine has already stopped receiving, blocking forever on
	// the unbuffered send here while holding cqmu would deadlock against
	// run's drain path, which must acquire cqmu before it can drain.
	select {
	case ms.changeQueue <- resp:
		ms.cqmu.Unlock()
		return nil
	case <-ms.sessionAliveCtx.Done():
		ms.cqmu.Unlock()
		ms.processQueue.Wait()
		return ErrChangeQueueClosed
	}
}
// HandleNonKeepAliveMapResponse enqueues resp for asynchronous processing by
// the mapSession's run goroutine. It returns ErrChangeQueueClosed if the
// session is shutting down, in which case resp was not enqueued.
func (ms *mapSession) HandleNonKeepAliveMapResponse(ctx context.Context, resp *tailcfg.MapResponse) error {
	ms.cqmu.Lock()
	if ms.changeQueueClosed {
		ms.cqmu.Unlock()
		// Wait for the run goroutine to finish draining before reporting
		// the closed queue to the caller.
		ms.processQueue.Wait()
		return ErrChangeQueueClosed
	}
	// Send while still holding cqmu so the queue cannot be marked closed (and
	// later closed) underneath us. Also watch for session cancellation: if
	// the run goroutine has already stopped receiving, blocking forever on
	// the unbuffered send while holding cqmu would deadlock against run's
	// drain path, which must acquire cqmu before it can drain.
	select {
	case ms.changeQueue <- resp:
		ms.cqmu.Unlock()
		return nil
	case <-ms.sessionAliveCtx.Done():
		ms.cqmu.Unlock()
		ms.processQueue.Wait()
		return ErrChangeQueueClosed
	}
}
// handleNonKeepAliveMapResponse handles a non-KeepAlive MapResponse (full or
// incremental).
//
// All fields that are valid on a KeepAlive MapResponse have already been
@@ -153,7 +234,7 @@ func (ms *mapSession) Close() {
//
// TODO(bradfitz): make this handle all fields later. For now (2023-08-20) this
// is [re]factoring progress enough.
func (ms *mapSession) HandleNonKeepAliveMapResponse(ctx context.Context, resp *tailcfg.MapResponse) error {
func (ms *mapSession) handleNonKeepAliveMapResponse(ctx context.Context, resp *tailcfg.MapResponse) error {
if debug := resp.Debug; debug != nil {
if err := ms.onDebug(ctx, debug); err != nil {
return err
@@ -199,6 +280,8 @@ func (ms *mapSession) HandleNonKeepAliveMapResponse(ctx context.Context, resp *t
ms.patchifyPeersChanged(resp)
ms.removeUnwantedDiscoUpdates(resp)
ms.updateStateFromResponse(resp)
if ms.tryHandleIncrementally(resp) {
@@ -281,6 +364,48 @@ type updateStats struct {
changed int
}
// removeUnwantedDiscoUpdates goes over the patchified updates and rejects
// items where the node is offline and has last been seen no later than the
// recorded last-seen time. This prevents a stale disco key from control from
// overwriting a fresher key learned out-of-band (e.g. via TSMP, which marks
// the peer offline with lastSeen=now).
func (ms *mapSession) removeUnwantedDiscoUpdates(resp *tailcfg.MapResponse) {
	existingMap := ms.netmap()
	// Filter in place: acceptedDiscoUpdates shares resp.PeersChangedPatch's
	// backing array.
	acceptedDiscoUpdates := resp.PeersChangedPatch[:0]
	for _, change := range resp.PeersChangedPatch {
		// Accept if:
		// - DiscoKey is nil and did not change.
		// - Fields we rely on for rejection is missing.
		if change.DiscoKey == nil || change.Online == nil || change.LastSeen == nil {
			acceptedDiscoUpdates = append(acceptedDiscoUpdates, change)
			continue
		}
		// Accept if:
		// - Node is online.
		if *change.Online {
			acceptedDiscoUpdates = append(acceptedDiscoUpdates, change)
			continue
		}
		peerIdx := existingMap.PeerIndexByNodeID(change.NodeID)
		// Accept if:
		// - Cannot find the peer, don't have enough data
		if peerIdx < 0 {
			acceptedDiscoUpdates = append(acceptedDiscoUpdates, change)
			continue
		}
		existingNode := existingMap.Peers[peerIdx]
		// Accept if:
		// - lastSeen moved forward in time.
		//
		// NOTE(review): when the existing node has no recorded LastSeen,
		// GetOk fails and the change is silently dropped. Confirm that
		// rejecting here — rather than accepting like the peer-not-found
		// "don't have enough data" case above — is the intended
		// conservative default.
		if existingLastSeen, ok := existingNode.LastSeen().GetOk(); ok &&
			change.LastSeen.After(existingLastSeen) {
			acceptedDiscoUpdates = append(acceptedDiscoUpdates, change)
		}
	}
	resp.PeersChangedPatch = acceptedDiscoUpdates
}
// updateStateFromResponse updates ms from res. It takes ownership of res.
func (ms *mapSession) updateStateFromResponse(resp *tailcfg.MapResponse) {
ms.updatePeersStateFromResponse(resp)
+86
View File
@@ -623,6 +623,90 @@ func TestNetmapForResponse(t *testing.T) {
})
}
// TestUpdateDiscoForNode exercises mapSession.updateDiscoForNode: it seeds a
// session with one peer holding an old disco key, injects a new key through
// the change queue, and checks whether the key was applied. Per
// removeUnwantedDiscoUpdates, the update is accepted when the patch reports
// the peer online, or when its lastSeen is newer than the stored one.
func TestUpdateDiscoForNode(t *testing.T) {
	tests := []struct {
		name            string
		initialOnline   bool
		initialLastSeen time.Time
		updateOnline    bool
		updateLastSeen  time.Time
		wantUpdate      bool
	}{
		{
			// Offline patch but with a fresher lastSeen: accepted.
			name:            "newer_key_not_online",
			initialOnline:   true,
			initialLastSeen: time.Unix(1, 0),
			updateOnline:    false,
			updateLastSeen:  time.Now(),
			wantUpdate:      true,
		},
		{
			// Online patch with a fresher lastSeen: accepted.
			name:            "newer_key_online",
			initialOnline:   true,
			initialLastSeen: time.Unix(1, 0),
			updateOnline:    true,
			updateLastSeen:  time.Now(),
			wantUpdate:      true,
		},
		{
			// Offline patch with an older lastSeen: rejected.
			name:            "older_key_not_online",
			initialOnline:   false,
			initialLastSeen: time.Now(),
			updateOnline:    false,
			updateLastSeen:  time.Unix(1, 0),
			wantUpdate:      false,
		},
		{
			// Online wins even with an older lastSeen: accepted.
			name:            "older_key_online",
			initialOnline:   false,
			initialLastSeen: time.Now(),
			updateOnline:    true,
			updateLastSeen:  time.Unix(1, 0),
			wantUpdate:      true,
		},
	}
	for _, tt := range tests {
		t.Run(tt.name, func(t *testing.T) {
			nu := &rememberLastNetmapUpdater{
				done: make(chan any),
			}
			ms := newTestMapSession(t, nu)
			oldKey := key.NewDisco()
			// Insert existing node
			node := tailcfg.Node{
				ID:       1,
				DiscoKey: oldKey.Public(),
				Online:   &tt.initialOnline,
				LastSeen: &tt.initialLastSeen,
			}
			if nm := ms.netmapForResponse(&tailcfg.MapResponse{
				Peers: []*tailcfg.Node{&node},
			}); len(nm.Peers) != 1 {
				t.Fatalf("node not inserted")
			}
			newKey := key.NewDisco()
			// Enqueue the disco update, then wait for the run goroutine to
			// process it and deliver the resulting netmap to the updater.
			ms.updateDiscoForNode(node.ID, newKey.Public(), tt.updateLastSeen, tt.updateOnline)
			<-nu.done
			nm := ms.netmap()
			peerIdx := nm.PeerIndexByNodeID(node.ID)
			if peerIdx == -1 {
				t.Fatal("node not found")
			}
			// The key changed iff the patch was accepted.
			updated := nm.Peers[peerIdx].DiscoKey().Compare(newKey.Public()) == 0
			if updated != tt.wantUpdate {
				t.Fatalf("Disco key update: %t, wanted update: %t", updated, tt.wantUpdate)
			}
		})
	}
}
func first[T any](s []T) T {
if len(s) == 0 {
var zero T
@@ -1098,6 +1182,8 @@ func BenchmarkMapSessionDelta(b *testing.B) {
ctx := context.Background()
nu := &countingNetmapUpdater{}
ms := newTestMapSession(b, nu)
// Disable log output for benchmarks to avoid races
ms.logf = func(string, ...any) {}
res := &tailcfg.MapResponse{
Node: &tailcfg.Node{
ID: 1,