types/key,wgengine/magicsock,control/controlclient,ipn: add debug disco key rotation

Adds the ability to rotate discovery keys on running clients, needed for
testing upcoming disco key distribution changes.

Introduces magicsock.discoAtomic, an atomic container for a disco private
key, public key, and the public key's ShortString, replacing the prior
separate immutable fields on magicsock.Conn.

magicsock.Conn has a new RotateDiscoKey method, and access to this is
provided via localapi and a CLI debug command.

Note that this implementation is primarily for testing as it stands, and
regular use should likely introduce an additional mechanism that allows
the old key to be used for some time, to provide a seamless key rotation
rather than one that invalidates all sessions.

Updates tailscale/corp#34037

Signed-off-by: James Tucker <james@tailscale.com>
This commit is contained in:
James Tucker
2025-11-03 16:41:37 -08:00
committed by James Tucker
parent da508c504d
commit c09c95ef67
16 changed files with 375 additions and 37 deletions
+58
View File
@@ -0,0 +1,58 @@
// Copyright (c) Tailscale Inc & AUTHORS
// SPDX-License-Identifier: BSD-3-Clause
package magicsock
import (
"sync/atomic"
"tailscale.com/types/key"
)
// discoKeyPair groups a disco private key with the two values derived from
// it, so that all three can be published together behind a single pointer
// (see discoAtomic.pair).
type discoKeyPair struct {
private key.DiscoPrivate
public key.DiscoPublic // private.Public()
short string // public.ShortString()
}
// discoAtomic is an atomic container for a disco private key, public key, and
// the public key's ShortString. The private and public keys are always kept
// synchronized.
//
// The zero value is not ready for use: the accessors dereference the stored
// pointer, which is nil until the first call to [discoAtomic.Set].
type discoAtomic struct {
// pair is replaced as a whole by Set, so one Load observes a private key,
// public key, and short string that belong together.
pair atomic.Pointer[discoKeyPair]
}
// Pair returns the private key and its matching public key, both taken from
// the same stored snapshot. Callers that need the two keys to correspond
// should use Pair rather than separate calls to Private and Public, which
// may straddle a Set.
func (dk *discoAtomic) Pair() (key.DiscoPrivate, key.DiscoPublic) {
	cur := dk.pair.Load()
	priv, pub := cur.private, cur.public
	return priv, pub
}
// Private returns the currently stored private key.
func (dk *discoAtomic) Private() key.DiscoPrivate {
	cur := dk.pair.Load()
	return cur.private
}
// Public returns the currently stored public key.
func (dk *discoAtomic) Public() key.DiscoPublic {
	cur := dk.pair.Load()
	return cur.public
}
// Short returns the cached short string of the currently stored public key
// (see [key.DiscoPublic.ShortString]).
func (dk *discoAtomic) Short() string {
	cur := dk.pair.Load()
	return cur.short
}
// Set installs private as the current key. The public key and its short
// string are derived from private and stored alongside it in one new
// snapshot, which then replaces the previous one with a single Store.
func (dk *discoAtomic) Set(private key.DiscoPrivate) {
	next := &discoKeyPair{
		private: private,
		public:  private.Public(),
	}
	next.short = next.public.ShortString()
	dk.pair.Store(next)
}
+70
View File
@@ -0,0 +1,70 @@
// Copyright (c) Tailscale Inc & AUTHORS
// SPDX-License-Identifier: BSD-3-Clause
package magicsock
import (
"testing"
"tailscale.com/types/key"
)
// TestDiscoAtomic verifies that, after Set, the accessors of discoAtomic
// return non-zero values that are consistent with one another: the public key
// is derived from the private key, the short string is derived from the
// public key, and Pair agrees with Private and Public.
func TestDiscoAtomic(t *testing.T) {
	var dk discoAtomic
	dk.Set(key.NewDisco())

	private := dk.Private()
	public := dk.Public()
	short := dk.Short()

	if private.IsZero() {
		t.Fatal("discoAtomic private key should not be zero")
	}
	if public.IsZero() {
		t.Fatal("discoAtomic public key should not be zero")
	}
	if short == "" {
		t.Fatal("discoAtomic short string should not be empty")
	}

	if public != private.Public() {
		t.Fatal("discoAtomic public key doesn't match private key")
	}
	if short != public.ShortString() {
		t.Fatal("discoAtomic short string doesn't match public key")
	}

	gotPrivate, gotPublic := dk.Pair()
	if !gotPrivate.Equal(private) {
		t.Fatal("Pair() returned different private key")
	}
	if gotPublic != public {
		t.Fatal("Pair() returned different public key")
	}
}
// TestDiscoAtomicSet verifies that a second Set replaces the stored keys: the
// values read afterwards differ from the earlier ones and match the newly
// supplied private key and its derived public key.
func TestDiscoAtomicSet(t *testing.T) {
	var dk discoAtomic
	dk.Set(key.NewDisco())
	oldPrivate := dk.Private()
	oldPublic := dk.Public()

	newPrivate := key.NewDisco()
	dk.Set(newPrivate)

	currentPrivate := dk.Private()
	currentPublic := dk.Public()

	if currentPrivate.Equal(oldPrivate) {
		t.Fatal("discoAtomic private key should have changed after Set")
	}
	if currentPublic == oldPublic {
		t.Fatal("discoAtomic public key should have changed after Set")
	}
	if !currentPrivate.Equal(newPrivate) {
		t.Fatal("discoAtomic private key doesn't match the set key")
	}
	if currentPublic != newPrivate.Public() {
		t.Fatal("discoAtomic public key doesn't match the one derived from the set private key")
	}
}
+2 -2
View File
@@ -697,7 +697,7 @@ func (de *endpoint) maybeProbeUDPLifetimeLocked() (afterInactivityFor time.Durat
// shuffling probing probability where the local node ends up with a large
// key value lexicographically relative to the other nodes it tends to
// communicate with. If de's disco key changes, the cycle will reset.
if de.c.discoPublic.Compare(epDisco.key) >= 0 {
if de.c.discoAtomic.Public().Compare(epDisco.key) >= 0 {
// lower disco pub key node probes higher
return afterInactivityFor, false
}
@@ -1739,7 +1739,7 @@ func (de *endpoint) handlePongConnLocked(m *disco.Pong, di *discoInfo, src epAdd
}
if sp.purpose != pingHeartbeat && sp.purpose != pingHeartbeatForUDPLifetime {
de.c.dlogf("[v1] magicsock: disco: %v<-%v (%v, %v) got pong tx=%x latency=%v pktlen=%v pong.src=%v%v", de.c.discoShort, de.discoShort(), de.publicKey.ShortString(), src, m.TxID[:6], latency.Round(time.Millisecond), pktLen, m.Src, logger.ArgWriter(func(bw *bufio.Writer) {
de.c.dlogf("[v1] magicsock: disco: %v<-%v (%v, %v) got pong tx=%x latency=%v pktlen=%v pong.src=%v%v", de.c.discoAtomic.Short(), de.discoShort(), de.publicKey.ShortString(), src, m.TxID[:6], latency.Round(time.Millisecond), pktLen, m.Src, logger.ArgWriter(func(bw *bufio.Writer) {
if sp.to != src {
fmt.Fprintf(bw, " ping.to=%v", sp.to)
}
+20 -5
View File
@@ -146,15 +146,22 @@ func TestProbeUDPLifetimeConfig_Valid(t *testing.T) {
}
func Test_endpoint_maybeProbeUDPLifetimeLocked(t *testing.T) {
var lowerPriv, higherPriv key.DiscoPrivate
var lower, higher key.DiscoPublic
a := key.NewDisco().Public()
b := key.NewDisco().Public()
privA := key.NewDisco()
privB := key.NewDisco()
a := privA.Public()
b := privB.Public()
if a.String() < b.String() {
lower = a
higher = b
lowerPriv = privA
higherPriv = privB
} else {
lower = b
higher = a
lowerPriv = privB
higherPriv = privA
}
addr := addrQuality{epAddr: epAddr{ap: netip.MustParseAddrPort("1.1.1.1:1")}}
newProbeUDPLifetime := func() *probeUDPLifetime {
@@ -281,10 +288,18 @@ func Test_endpoint_maybeProbeUDPLifetimeLocked(t *testing.T) {
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
c := &Conn{}
if tt.localDisco.IsZero() {
c.discoAtomic.Set(key.NewDisco())
} else if tt.localDisco.Compare(lower) == 0 {
c.discoAtomic.Set(lowerPriv)
} else if tt.localDisco.Compare(higher) == 0 {
c.discoAtomic.Set(higherPriv)
} else {
t.Fatalf("unexpected localDisco value")
}
de := &endpoint{
c: &Conn{
discoPublic: tt.localDisco,
},
c: c,
bestAddr: tt.bestAddr,
}
if tt.remoteDisco != nil {
+40 -23
View File
@@ -273,14 +273,8 @@ type Conn struct {
// channel operations and goroutine creation.
hasPeerRelayServers atomic.Bool
// discoPrivate is the private naclbox key used for active
// discovery traffic. It is always present, and immutable.
discoPrivate key.DiscoPrivate
// public of discoPrivate. It is always present and immutable.
discoPublic key.DiscoPublic
// ShortString of discoPublic (to save logging work later). It is always
// present and immutable.
discoShort string
// discoAtomic is the current disco private and public keypair for this conn.
discoAtomic discoAtomic
// ============================================================
// mu guards all following fields; see userspaceEngine lock
@@ -603,11 +597,9 @@ func newConn(logf logger.Logf) *Conn {
peerLastDerp: make(map[key.NodePublic]int),
peerMap: newPeerMap(),
discoInfo: make(map[key.DiscoPublic]*discoInfo),
discoPrivate: discoPrivate,
discoPublic: discoPrivate.Public(),
cloudInfo: newCloudInfo(logf),
}
c.discoShort = c.discoPublic.ShortString()
c.discoAtomic.Set(discoPrivate)
c.bind = &connBind{Conn: c, closed: true}
c.receiveBatchPool = sync.Pool{New: func() any {
msgs := make([]ipv6.Message, c.bind.BatchSize())
@@ -635,7 +627,7 @@ func (c *Conn) onUDPRelayAllocResp(allocResp UDPRelayAllocResp) {
// now versus taking a network round-trip through DERP.
selfNodeKey := c.publicKeyAtomic.Load()
if selfNodeKey.Compare(allocResp.ReqRxFromNodeKey) == 0 &&
allocResp.ReqRxFromDiscoKey.Compare(c.discoPublic) == 0 {
allocResp.ReqRxFromDiscoKey.Compare(c.discoAtomic.Public()) == 0 {
c.relayManager.handleRxDiscoMsg(c, allocResp.Message, selfNodeKey, allocResp.ReqRxFromDiscoKey, epAddr{})
metricLocalDiscoAllocUDPRelayEndpointResponse.Add(1)
}
@@ -765,7 +757,7 @@ func NewConn(opts Options) (*Conn, error) {
c.logf("[v1] couldn't create raw v6 disco listener, using regular listener instead: %v", err)
}
c.logf("magicsock: disco key = %v", c.discoShort)
c.logf("magicsock: disco key = %v", c.discoAtomic.Short())
return c, nil
}
@@ -1244,7 +1236,32 @@ func (c *Conn) GetEndpointChanges(peer tailcfg.NodeView) ([]EndpointChange, erro
// DiscoPublicKey returns the discovery public key.
func (c *Conn) DiscoPublicKey() key.DiscoPublic {
return c.discoPublic
return c.discoAtomic.Public()
}
// RotateDiscoKey generates a new discovery key pair and updates the connection
// to use it. This invalidates all existing disco sessions and will cause peers
// to re-establish discovery sessions with the new key.
//
// This is primarily for debugging and testing purposes. A future enhancement
// should provide a mechanism for seamless rotation by supporting short-term
// use of the old key.
func (c *Conn) RotateDiscoKey() {
	// Generate the key before taking c.mu to keep the critical section short.
	newPrivate := key.NewDisco()

	c.mu.Lock()
	// Read the outgoing key under c.mu, so that when rotations run
	// concurrently each one logs the key it actually replaced.
	oldShort := c.discoAtomic.Short()
	c.discoAtomic.Set(newPrivate)
	newShort := c.discoAtomic.Short()
	// Cached discoInfo entries carry shared keys computed from the previous
	// private key, so start over with an empty map.
	c.discoInfo = make(map[key.DiscoPublic]*discoInfo)
	connCtx := c.connCtx
	c.mu.Unlock()

	c.logf("magicsock: rotated disco key from %v to %v", oldShort, newShort)

	// Only trigger a ReSTUN when a connCtx has been set on this Conn.
	if connCtx != nil {
		c.ReSTUN("disco-key-rotation")
	}
}
// determineEndpoints returns the machine's endpoint addresses. It does a STUN
@@ -1914,7 +1931,7 @@ func (c *Conn) sendDiscoAllocateUDPRelayEndpointRequest(dst epAddr, dstKey key.N
if isDERP && dstKey.Compare(selfNodeKey) == 0 {
c.allocRelayEndpointPub.Publish(UDPRelayAllocReq{
RxFromNodeKey: selfNodeKey,
RxFromDiscoKey: c.discoPublic,
RxFromDiscoKey: c.discoAtomic.Public(),
Message: allocReq,
})
metricLocalDiscoAllocUDPRelayEndpointRequest.Add(1)
@@ -1985,7 +2002,7 @@ func (c *Conn) sendDiscoMessage(dst epAddr, dstKey key.NodePublic, dstDisco key.
}
}
pkt = append(pkt, disco.Magic...)
pkt = c.discoPublic.AppendTo(pkt)
pkt = c.discoAtomic.Public().AppendTo(pkt)
if isDERP {
metricSendDiscoDERP.Add(1)
@@ -2003,7 +2020,7 @@ func (c *Conn) sendDiscoMessage(dst epAddr, dstKey key.NodePublic, dstDisco key.
if !dstKey.IsZero() {
node = dstKey.ShortString()
}
c.dlogf("[v1] magicsock: disco: %v->%v (%v, %v) sent %v len %v\n", c.discoShort, dstDisco.ShortString(), node, derpStr(dst.String()), disco.MessageSummary(m), len(pkt))
c.dlogf("[v1] magicsock: disco: %v->%v (%v, %v) sent %v len %v\n", c.discoAtomic.Short(), dstDisco.ShortString(), node, derpStr(dst.String()), disco.MessageSummary(m), len(pkt))
}
if isDERP {
metricSentDiscoDERP.Add(1)
@@ -2352,13 +2369,13 @@ func (c *Conn) handleDiscoMessage(msg []byte, src epAddr, shouldBeRelayHandshake
}
if isVia {
c.dlogf("[v1] magicsock: disco: %v<-%v via %v (%v, %v) got call-me-maybe-via, %d endpoints",
c.discoShort, epDisco.short, via.ServerDisco.ShortString(),
c.discoAtomic.Short(), epDisco.short, via.ServerDisco.ShortString(),
ep.publicKey.ShortString(), derpStr(src.String()),
len(via.AddrPorts))
c.relayManager.handleCallMeMaybeVia(ep, lastBest, lastBestIsTrusted, via)
} else {
c.dlogf("[v1] magicsock: disco: %v<-%v (%v, %v) got call-me-maybe, %d endpoints",
c.discoShort, epDisco.short,
c.discoAtomic.Short(), epDisco.short,
ep.publicKey.ShortString(), derpStr(src.String()),
len(cmm.MyNumber))
go ep.handleCallMeMaybe(cmm)
@@ -2404,7 +2421,7 @@ func (c *Conn) handleDiscoMessage(msg []byte, src epAddr, shouldBeRelayHandshake
if isResp {
c.dlogf("[v1] magicsock: disco: %v<-%v (%v, %v) got %s, %d endpoints",
c.discoShort, epDisco.short,
c.discoAtomic.Short(), epDisco.short,
ep.publicKey.ShortString(), derpStr(src.String()),
msgType,
len(resp.AddrPorts))
@@ -2418,7 +2435,7 @@ func (c *Conn) handleDiscoMessage(msg []byte, src epAddr, shouldBeRelayHandshake
return
} else {
c.dlogf("[v1] magicsock: disco: %v<-%v (%v, %v) got %s disco[0]=%v disco[1]=%v",
c.discoShort, epDisco.short,
c.discoAtomic.Short(), epDisco.short,
ep.publicKey.ShortString(), derpStr(src.String()),
msgType,
req.ClientDisco[0].ShortString(), req.ClientDisco[1].ShortString())
@@ -2583,7 +2600,7 @@ func (c *Conn) handlePingLocked(dm *disco.Ping, src epAddr, di *discoInfo, derpN
if numNodes > 1 {
pingNodeSrcStr = "[one-of-multi]"
}
c.dlogf("[v1] magicsock: disco: %v<-%v (%v, %v) got ping tx=%x padding=%v", c.discoShort, di.discoShort, pingNodeSrcStr, src, dm.TxID[:6], dm.Padding)
c.dlogf("[v1] magicsock: disco: %v<-%v (%v, %v) got ping tx=%x padding=%v", c.discoAtomic.Short(), di.discoShort, pingNodeSrcStr, src, dm.TxID[:6], dm.Padding)
}
ipDst := src
@@ -2656,7 +2673,7 @@ func (c *Conn) discoInfoForKnownPeerLocked(k key.DiscoPublic) *discoInfo {
di = &discoInfo{
discoKey: k,
discoShort: k.ShortString(),
sharedKey: c.discoPrivate.Shared(k),
sharedKey: c.discoAtomic.Private().Shared(k),
}
c.discoInfo[k] = di
}
+70
View File
@@ -4235,3 +4235,73 @@ func Test_lazyEndpoint_FromPeer(t *testing.T) {
})
}
}
// TestRotateDiscoKey verifies that RotateDiscoKey replaces the Conn's disco
// key with a different, internally consistent key pair and empties the
// discoInfo cache.
func TestRotateDiscoKey(t *testing.T) {
	c := newConn(t.Logf)

	oldPrivate, oldPublic := c.discoAtomic.Pair()
	oldShort := c.discoAtomic.Short()

	if oldPublic != oldPrivate.Public() {
		t.Fatalf("old public key doesn't match old private key")
	}
	if oldShort != oldPublic.ShortString() {
		t.Fatalf("old short string doesn't match old public key")
	}

	// Seed one discoInfo entry so the test can observe the cache being
	// cleared. The length is captured under the lock and asserted after
	// unlocking: t.Fatalf exits the goroutine and would otherwise leave c.mu
	// held.
	testDiscoKey := key.NewDisco().Public()
	c.mu.Lock()
	c.discoInfo[testDiscoKey] = &discoInfo{
		discoKey:   testDiscoKey,
		discoShort: testDiscoKey.ShortString(),
	}
	seeded := len(c.discoInfo)
	c.mu.Unlock()
	if seeded != 1 {
		t.Fatalf("expected 1 discoInfo entry, got %d", seeded)
	}

	c.RotateDiscoKey()

	newPrivate, newPublic := c.discoAtomic.Pair()
	newShort := c.discoAtomic.Short()

	if newPublic.Compare(oldPublic) == 0 {
		t.Fatalf("disco key didn't change after rotation")
	}
	if newShort == oldShort {
		t.Fatalf("short string didn't change after rotation")
	}

	if newPublic != newPrivate.Public() {
		t.Fatalf("new public key doesn't match new private key")
	}
	if newShort != newPublic.ShortString() {
		t.Fatalf("new short string doesn't match new public key")
	}

	c.mu.Lock()
	remaining := len(c.discoInfo)
	c.mu.Unlock()
	if remaining != 0 {
		t.Fatalf("expected discoInfo to be cleared, got %d entries", remaining)
	}
}
// TestRotateDiscoKeyMultipleTimes rotates the disco key four times in a row
// and checks that each rotation yields a public key different from the
// initial key and from every key produced by an earlier rotation.
func TestRotateDiscoKeyMultipleTimes(t *testing.T) {
	const rotations = 4
	c := newConn(t.Logf)

	// seen[0] is the initial key; seen[n] is the key after the n-th rotation.
	seen := make([]key.DiscoPublic, 0, rotations+1)
	seen = append(seen, c.discoAtomic.Public())

	for round := 1; round <= rotations; round++ {
		c.RotateDiscoKey()
		cur := c.discoAtomic.Public()

		for prev := range seen {
			if cur.Compare(seen[prev]) == 0 {
				t.Fatalf("rotation %d produced same key as rotation %d", round, prev)
			}
		}

		seen = append(seen, cur)
	}
}
+2 -2
View File
@@ -361,7 +361,7 @@ func (r *relayManager) ensureDiscoInfoFor(work *relayHandshakeWork) {
di.di = &discoInfo{
discoKey: work.se.ServerDisco,
discoShort: work.se.ServerDisco.ShortString(),
sharedKey: work.wlb.ep.c.discoPrivate.Shared(work.se.ServerDisco),
sharedKey: work.wlb.ep.c.discoAtomic.Private().Shared(work.se.ServerDisco),
}
}
}
@@ -1031,7 +1031,7 @@ func (r *relayManager) allocateAllServersRunLoop(wlb endpointWithLastBest) {
if remoteDisco == nil {
return
}
discoKeys := key.NewSortedPairOfDiscoPublic(wlb.ep.c.discoPublic, remoteDisco.key)
discoKeys := key.NewSortedPairOfDiscoPublic(wlb.ep.c.discoAtomic.Public(), remoteDisco.key)
for _, v := range r.serversByNodeKey {
byDiscoKeys, ok := r.allocWorkByDiscoKeysByServerNodeKey[v.nodeKey]
if !ok {
+6 -2
View File
@@ -22,11 +22,15 @@ func TestRelayManagerInitAndIdle(t *testing.T) {
<-rm.runLoopStoppedCh
rm = relayManager{}
rm.handleCallMeMaybeVia(&endpoint{c: &Conn{discoPrivate: key.NewDisco()}}, addrQuality{}, false, &disco.CallMeMaybeVia{UDPRelayEndpoint: disco.UDPRelayEndpoint{ServerDisco: key.NewDisco().Public()}})
c1 := &Conn{}
c1.discoAtomic.Set(key.NewDisco())
rm.handleCallMeMaybeVia(&endpoint{c: c1}, addrQuality{}, false, &disco.CallMeMaybeVia{UDPRelayEndpoint: disco.UDPRelayEndpoint{ServerDisco: key.NewDisco().Public()}})
<-rm.runLoopStoppedCh
rm = relayManager{}
rm.handleRxDiscoMsg(&Conn{discoPrivate: key.NewDisco()}, &disco.BindUDPRelayEndpointChallenge{}, key.NodePublic{}, key.DiscoPublic{}, epAddr{})
c2 := &Conn{}
c2.discoAtomic.Set(key.NewDisco())
rm.handleRxDiscoMsg(c2, &disco.BindUDPRelayEndpointChallenge{}, key.NodePublic{}, key.DiscoPublic{}, epAddr{})
<-rm.runLoopStoppedCh
rm = relayManager{}