net/tstun: add TSMPDiscoAdvertisement to TSMPPing (#17995)

Adds a new types of TSMP messages for advertising disco keys keys
to/from a peer, and implements the advertising triggered by a TSMP ping.

Needed as part of the effort to cache the netmap and still let clients
connect without control being reachable.

Updates #12639

Signed-off-by: Claus Lensbøl <claus@tailscale.com>
Co-authored-by: James Tucker <james@tailscale.com>
This commit is contained in:
Claus Lensbøl
2025-11-25 21:35:38 +01:00
committed by GitHub
parent b38dd1ae06
commit c54d243690
7 changed files with 280 additions and 25 deletions
+55
View File
@@ -15,7 +15,9 @@ import (
"fmt"
"net/netip"
"go4.org/mem"
"tailscale.com/types/ipproto"
"tailscale.com/types/key"
)
const minTSMPSize = 7 // the rejected body is 7 bytes
@@ -72,6 +74,9 @@ const (
// TSMPTypePong is the type byte for a TailscalePongResponse.
TSMPTypePong TSMPType = 'o'
// TSPMTypeDiscoAdvertisement is the type byte for sending disco keys
TSMPTypeDiscoAdvertisement TSMPType = 'a'
)
type TailscaleRejectReason byte
@@ -259,3 +264,53 @@ func (h TSMPPongReply) Marshal(buf []byte) error {
binary.BigEndian.PutUint16(buf[9:11], h.PeerAPIPort)
return nil
}
// TSMPDiscoKeyAdvertisement is a TSMP message that's used for distributing Disco Keys.
//
// On the wire, after the IP header, it's currently 33 bytes:
// - 'a' (TSMPTypeDiscoAdvertisement)
// - 32 disco key bytes
type TSMPDiscoKeyAdvertisement struct {
Src, Dst netip.Addr
Key key.DiscoPublic
}
func (ka *TSMPDiscoKeyAdvertisement) Marshal() ([]byte, error) {
var iph Header
if ka.Src.Is4() {
iph = IP4Header{
IPProto: ipproto.TSMP,
Src: ka.Src,
Dst: ka.Dst,
}
} else {
iph = IP6Header{
IPProto: ipproto.TSMP,
Src: ka.Src,
Dst: ka.Dst,
}
}
payload := make([]byte, 0, 33)
payload = append(payload, byte(TSMPTypeDiscoAdvertisement))
payload = ka.Key.AppendTo(payload)
if len(payload) != 33 {
// Mostly to safeguard against ourselves changing this in the future.
return []byte{}, fmt.Errorf("expected payload length 33, got %d", len(payload))
}
return Generate(iph, payload), nil
}
func (pp *Parsed) AsTSMPDiscoAdvertisement() (tka TSMPDiscoKeyAdvertisement, ok bool) {
if pp.IPProto != ipproto.TSMP {
return
}
p := pp.Payload()
if len(p) < 33 || p[0] != byte(TSMPTypeDiscoAdvertisement) {
return
}
tka.Src = pp.Src.Addr()
tka.Key = key.DiscoPublicFromRaw32(mem.B(p[1:33]))
return tka, true
}
+65
View File
@@ -4,8 +4,14 @@
package packet
import (
"bytes"
"encoding/hex"
"net/netip"
"slices"
"testing"
"go4.org/mem"
"tailscale.com/types/key"
)
func TestTailscaleRejectedHeader(t *testing.T) {
@@ -71,3 +77,62 @@ func TestTailscaleRejectedHeader(t *testing.T) {
}
}
}
func TestTSMPDiscoKeyAdvertisementMarshal(t *testing.T) {
var (
// IPv4: Ver(4)Len(5), TOS, Len(53), ID, Flags, TTL(64), Proto(99), Cksum
headerV4, _ = hex.DecodeString("45000035000000004063705d")
// IPv6: Ver(6)TCFlow, Len(33), NextHdr(99), HopLim(64)
headerV6, _ = hex.DecodeString("6000000000216340")
packetType = []byte{'a'}
testKey = bytes.Repeat([]byte{'a'}, 32)
// IPs
srcV4 = netip.MustParseAddr("1.2.3.4")
dstV4 = netip.MustParseAddr("4.3.2.1")
srcV6 = netip.MustParseAddr("2001:db8::1")
dstV6 = netip.MustParseAddr("2001:db8::2")
)
join := func(parts ...[]byte) []byte {
return bytes.Join(parts, nil)
}
tests := []struct {
name string
tka TSMPDiscoKeyAdvertisement
want []byte
}{
{
name: "v4Header",
tka: TSMPDiscoKeyAdvertisement{
Src: srcV4,
Dst: dstV4,
Key: key.DiscoPublicFromRaw32(mem.B(testKey)),
},
want: join(headerV4, srcV4.AsSlice(), dstV4.AsSlice(), packetType, testKey),
},
{
name: "v6Header",
tka: TSMPDiscoKeyAdvertisement{
Src: srcV6,
Dst: dstV6,
Key: key.DiscoPublicFromRaw32(mem.B(testKey)),
},
want: join(headerV6, srcV6.AsSlice(), dstV6.AsSlice(), packetType, testKey),
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
got, err := tt.tka.Marshal()
if err != nil {
t.Errorf("error mashalling TSMPDiscoAdvertisement: %s", err)
}
if !slices.Equal(got, tt.want) {
t.Errorf("error mashalling TSMPDiscoAdvertisement, expected: \n%x, \ngot:\n%x", tt.want, got)
}
})
}
}
+24 -5
View File
@@ -34,6 +34,7 @@ import (
"tailscale.com/types/logger"
"tailscale.com/types/netlogfunc"
"tailscale.com/util/clientmetric"
"tailscale.com/util/eventbus"
"tailscale.com/util/usermetric"
"tailscale.com/wgengine/filter"
"tailscale.com/wgengine/netstack/gro"
@@ -209,6 +210,9 @@ type Wrapper struct {
captureHook syncs.AtomicValue[packet.CaptureCallback]
metrics *metrics
eventClient *eventbus.Client
discoKeyAdvertisementPub *eventbus.Publisher[DiscoKeyAdvertisement]
}
type metrics struct {
@@ -254,15 +258,15 @@ func (w *Wrapper) Start() {
close(w.startCh)
}
func WrapTAP(logf logger.Logf, tdev tun.Device, m *usermetric.Registry) *Wrapper {
return wrap(logf, tdev, true, m)
func WrapTAP(logf logger.Logf, tdev tun.Device, m *usermetric.Registry, bus *eventbus.Bus) *Wrapper {
return wrap(logf, tdev, true, m, bus)
}
func Wrap(logf logger.Logf, tdev tun.Device, m *usermetric.Registry) *Wrapper {
return wrap(logf, tdev, false, m)
func Wrap(logf logger.Logf, tdev tun.Device, m *usermetric.Registry, bus *eventbus.Bus) *Wrapper {
return wrap(logf, tdev, false, m, bus)
}
func wrap(logf logger.Logf, tdev tun.Device, isTAP bool, m *usermetric.Registry) *Wrapper {
func wrap(logf logger.Logf, tdev tun.Device, isTAP bool, m *usermetric.Registry, bus *eventbus.Bus) *Wrapper {
logf = logger.WithPrefix(logf, "tstun: ")
w := &Wrapper{
logf: logf,
@@ -283,6 +287,9 @@ func wrap(logf logger.Logf, tdev tun.Device, isTAP bool, m *usermetric.Registry)
metrics: registerMetrics(m),
}
w.eventClient = bus.Client("net.tstun")
w.discoKeyAdvertisementPub = eventbus.Publish[DiscoKeyAdvertisement](w.eventClient)
w.vectorBuffer = make([][]byte, tdev.BatchSize())
for i := range w.vectorBuffer {
w.vectorBuffer[i] = make([]byte, maxBufferSize)
@@ -357,6 +364,7 @@ func (t *Wrapper) Close() error {
close(t.vectorOutbound)
t.outboundMu.Unlock()
err = t.tdev.Close()
t.eventClient.Close()
})
return err
}
@@ -1118,6 +1126,11 @@ func (t *Wrapper) injectedRead(res tunInjectedRead, outBuffs [][]byte, sizes []i
return n, err
}
type DiscoKeyAdvertisement struct {
Src netip.Addr
Key key.DiscoPublic
}
func (t *Wrapper) filterPacketInboundFromWireGuard(p *packet.Parsed, captHook packet.CaptureCallback, pc *peerConfigTable, gro *gro.GRO) (filter.Response, *gro.GRO) {
if captHook != nil {
captHook(packet.FromPeer, t.now(), p.Buffer(), p.CaptureMeta)
@@ -1128,6 +1141,12 @@ func (t *Wrapper) filterPacketInboundFromWireGuard(p *packet.Parsed, captHook pa
t.noteActivity()
t.injectOutboundPong(p, pingReq)
return filter.DropSilently, gro
} else if discoKeyAdvert, ok := p.AsTSMPDiscoAdvertisement(); ok {
t.discoKeyAdvertisementPub.Publish(DiscoKeyAdvertisement{
Src: discoKeyAdvert.Src,
Key: discoKeyAdvert.Key,
})
return filter.DropSilently, gro
} else if data, ok := p.AsTSMPPong(); ok {
if f := t.OnTSMPPongReceived; f != nil {
f(data)
+48 -14
View File
@@ -36,6 +36,8 @@ import (
"tailscale.com/types/netlogtype"
"tailscale.com/types/ptr"
"tailscale.com/types/views"
"tailscale.com/util/eventbus"
"tailscale.com/util/eventbus/eventbustest"
"tailscale.com/util/must"
"tailscale.com/util/usermetric"
"tailscale.com/wgengine/filter"
@@ -170,10 +172,10 @@ func setfilter(logf logger.Logf, tun *Wrapper) {
tun.SetFilter(filter.New(matches, nil, ipSet, ipSet, nil, logf))
}
func newChannelTUN(logf logger.Logf, secure bool) (*tuntest.ChannelTUN, *Wrapper) {
func newChannelTUN(logf logger.Logf, bus *eventbus.Bus, secure bool) (*tuntest.ChannelTUN, *Wrapper) {
chtun := tuntest.NewChannelTUN()
reg := new(usermetric.Registry)
tun := Wrap(logf, chtun.TUN(), reg)
tun := Wrap(logf, chtun.TUN(), reg, bus)
if secure {
setfilter(logf, tun)
} else {
@@ -183,10 +185,10 @@ func newChannelTUN(logf logger.Logf, secure bool) (*tuntest.ChannelTUN, *Wrapper
return chtun, tun
}
func newFakeTUN(logf logger.Logf, secure bool) (*fakeTUN, *Wrapper) {
func newFakeTUN(logf logger.Logf, bus *eventbus.Bus, secure bool) (*fakeTUN, *Wrapper) {
ftun := NewFake()
reg := new(usermetric.Registry)
tun := Wrap(logf, ftun, reg)
tun := Wrap(logf, ftun, reg, bus)
if secure {
setfilter(logf, tun)
} else {
@@ -196,7 +198,8 @@ func newFakeTUN(logf logger.Logf, secure bool) (*fakeTUN, *Wrapper) {
}
func TestReadAndInject(t *testing.T) {
chtun, tun := newChannelTUN(t.Logf, false)
bus := eventbustest.NewBus(t)
chtun, tun := newChannelTUN(t.Logf, bus, false)
defer tun.Close()
const size = 2 // all payloads have this size
@@ -221,7 +224,7 @@ func TestReadAndInject(t *testing.T) {
}
var buf [MaxPacketSize]byte
var seen = make(map[string]bool)
seen := make(map[string]bool)
sizes := make([]int, 1)
// We expect the same packets back, in no particular order.
for i := range len(written) + len(injected) {
@@ -257,7 +260,8 @@ func TestReadAndInject(t *testing.T) {
}
func TestWriteAndInject(t *testing.T) {
chtun, tun := newChannelTUN(t.Logf, false)
bus := eventbustest.NewBus(t)
chtun, tun := newChannelTUN(t.Logf, bus, false)
defer tun.Close()
written := []string{"w0", "w1"}
@@ -316,8 +320,8 @@ func mustHexDecode(s string) []byte {
}
func TestFilter(t *testing.T) {
chtun, tun := newChannelTUN(t.Logf, true)
bus := eventbustest.NewBus(t)
chtun, tun := newChannelTUN(t.Logf, bus, true)
defer tun.Close()
// Reset the metrics before test. These are global
@@ -462,7 +466,8 @@ func assertMetricPackets(t *testing.T, metricName string, want, got int64) {
}
func TestAllocs(t *testing.T) {
ftun, tun := newFakeTUN(t.Logf, false)
bus := eventbustest.NewBus(t)
ftun, tun := newFakeTUN(t.Logf, bus, false)
defer tun.Close()
buf := [][]byte{{0x00}}
@@ -473,14 +478,14 @@ func TestAllocs(t *testing.T) {
return
}
})
if err != nil {
t.Error(err)
}
}
func TestClose(t *testing.T) {
ftun, tun := newFakeTUN(t.Logf, false)
bus := eventbustest.NewBus(t)
ftun, tun := newFakeTUN(t.Logf, bus, false)
data := [][]byte{udp4("1.2.3.4", "5.6.7.8", 98, 98)}
_, err := ftun.Write(data, 0)
@@ -497,7 +502,8 @@ func TestClose(t *testing.T) {
func BenchmarkWrite(b *testing.B) {
b.ReportAllocs()
ftun, tun := newFakeTUN(b.Logf, true)
bus := eventbustest.NewBus(b)
ftun, tun := newFakeTUN(b.Logf, bus, true)
defer tun.Close()
packet := [][]byte{udp4("5.6.7.8", "1.2.3.4", 89, 89)}
@@ -887,7 +893,8 @@ func TestCaptureHook(t *testing.T) {
now := time.Unix(1682085856, 0)
_, w := newFakeTUN(t.Logf, true)
bus := eventbustest.NewBus(t)
_, w := newFakeTUN(t.Logf, bus, true)
w.timeNow = func() time.Time {
return now
}
@@ -957,3 +964,30 @@ func TestCaptureHook(t *testing.T) {
captured, want)
}
}
func TestTSMPDisco(t *testing.T) {
t.Run("IPv6DiscoAdvert", func(t *testing.T) {
src := netip.MustParseAddr("2001:db8::1")
dst := netip.MustParseAddr("2001:db8::2")
discoKey := key.NewDisco()
buf, _ := (&packet.TSMPDiscoKeyAdvertisement{
Src: src,
Dst: dst,
Key: discoKey.Public(),
}).Marshal()
var p packet.Parsed
p.Decode(buf)
tda, ok := p.AsTSMPDiscoAdvertisement()
if !ok {
t.Error("Unable to parse message as TSMPDiscoAdversitement")
}
if tda.Src != src {
t.Errorf("Src address did not match, expected %v, got %v", src, tda.Src)
}
if !reflect.DeepEqual(tda.Key, discoKey.Public()) {
t.Errorf("Key did not match, expected %q, got %q", discoKey.Public(), tda.Key)
}
})
}