net/tstun: add TSMPDiscoAdvertisement to TSMPPing (#17995)
Adds a new types of TSMP messages for advertising disco keys keys to/from a peer, and implements the advertising triggered by a TSMP ping. Needed as part of the effort to cache the netmap and still let clients connect without control being reachable. Updates #12639 Signed-off-by: Claus Lensbøl <claus@tailscale.com> Co-authored-by: James Tucker <james@tailscale.com>
This commit is contained in:
@@ -15,7 +15,9 @@ import (
|
||||
"fmt"
|
||||
"net/netip"
|
||||
|
||||
"go4.org/mem"
|
||||
"tailscale.com/types/ipproto"
|
||||
"tailscale.com/types/key"
|
||||
)
|
||||
|
||||
const minTSMPSize = 7 // the rejected body is 7 bytes
|
||||
@@ -72,6 +74,9 @@ const (
|
||||
|
||||
// TSMPTypePong is the type byte for a TailscalePongResponse.
|
||||
TSMPTypePong TSMPType = 'o'
|
||||
|
||||
// TSPMTypeDiscoAdvertisement is the type byte for sending disco keys
|
||||
TSMPTypeDiscoAdvertisement TSMPType = 'a'
|
||||
)
|
||||
|
||||
type TailscaleRejectReason byte
|
||||
@@ -259,3 +264,53 @@ func (h TSMPPongReply) Marshal(buf []byte) error {
|
||||
binary.BigEndian.PutUint16(buf[9:11], h.PeerAPIPort)
|
||||
return nil
|
||||
}
|
||||
|
||||
// TSMPDiscoKeyAdvertisement is a TSMP message that's used for distributing Disco Keys.
|
||||
//
|
||||
// On the wire, after the IP header, it's currently 33 bytes:
|
||||
// - 'a' (TSMPTypeDiscoAdvertisement)
|
||||
// - 32 disco key bytes
|
||||
type TSMPDiscoKeyAdvertisement struct {
|
||||
Src, Dst netip.Addr
|
||||
Key key.DiscoPublic
|
||||
}
|
||||
|
||||
func (ka *TSMPDiscoKeyAdvertisement) Marshal() ([]byte, error) {
|
||||
var iph Header
|
||||
if ka.Src.Is4() {
|
||||
iph = IP4Header{
|
||||
IPProto: ipproto.TSMP,
|
||||
Src: ka.Src,
|
||||
Dst: ka.Dst,
|
||||
}
|
||||
} else {
|
||||
iph = IP6Header{
|
||||
IPProto: ipproto.TSMP,
|
||||
Src: ka.Src,
|
||||
Dst: ka.Dst,
|
||||
}
|
||||
}
|
||||
payload := make([]byte, 0, 33)
|
||||
payload = append(payload, byte(TSMPTypeDiscoAdvertisement))
|
||||
payload = ka.Key.AppendTo(payload)
|
||||
if len(payload) != 33 {
|
||||
// Mostly to safeguard against ourselves changing this in the future.
|
||||
return []byte{}, fmt.Errorf("expected payload length 33, got %d", len(payload))
|
||||
}
|
||||
|
||||
return Generate(iph, payload), nil
|
||||
}
|
||||
|
||||
func (pp *Parsed) AsTSMPDiscoAdvertisement() (tka TSMPDiscoKeyAdvertisement, ok bool) {
|
||||
if pp.IPProto != ipproto.TSMP {
|
||||
return
|
||||
}
|
||||
p := pp.Payload()
|
||||
if len(p) < 33 || p[0] != byte(TSMPTypeDiscoAdvertisement) {
|
||||
return
|
||||
}
|
||||
tka.Src = pp.Src.Addr()
|
||||
tka.Key = key.DiscoPublicFromRaw32(mem.B(p[1:33]))
|
||||
|
||||
return tka, true
|
||||
}
|
||||
|
||||
@@ -4,8 +4,14 @@
|
||||
package packet
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"encoding/hex"
|
||||
"net/netip"
|
||||
"slices"
|
||||
"testing"
|
||||
|
||||
"go4.org/mem"
|
||||
"tailscale.com/types/key"
|
||||
)
|
||||
|
||||
func TestTailscaleRejectedHeader(t *testing.T) {
|
||||
@@ -71,3 +77,62 @@ func TestTailscaleRejectedHeader(t *testing.T) {
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestTSMPDiscoKeyAdvertisementMarshal(t *testing.T) {
|
||||
var (
|
||||
// IPv4: Ver(4)Len(5), TOS, Len(53), ID, Flags, TTL(64), Proto(99), Cksum
|
||||
headerV4, _ = hex.DecodeString("45000035000000004063705d")
|
||||
// IPv6: Ver(6)TCFlow, Len(33), NextHdr(99), HopLim(64)
|
||||
headerV6, _ = hex.DecodeString("6000000000216340")
|
||||
|
||||
packetType = []byte{'a'}
|
||||
testKey = bytes.Repeat([]byte{'a'}, 32)
|
||||
|
||||
// IPs
|
||||
srcV4 = netip.MustParseAddr("1.2.3.4")
|
||||
dstV4 = netip.MustParseAddr("4.3.2.1")
|
||||
srcV6 = netip.MustParseAddr("2001:db8::1")
|
||||
dstV6 = netip.MustParseAddr("2001:db8::2")
|
||||
)
|
||||
|
||||
join := func(parts ...[]byte) []byte {
|
||||
return bytes.Join(parts, nil)
|
||||
}
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
tka TSMPDiscoKeyAdvertisement
|
||||
want []byte
|
||||
}{
|
||||
{
|
||||
name: "v4Header",
|
||||
tka: TSMPDiscoKeyAdvertisement{
|
||||
Src: srcV4,
|
||||
Dst: dstV4,
|
||||
Key: key.DiscoPublicFromRaw32(mem.B(testKey)),
|
||||
},
|
||||
want: join(headerV4, srcV4.AsSlice(), dstV4.AsSlice(), packetType, testKey),
|
||||
},
|
||||
{
|
||||
name: "v6Header",
|
||||
tka: TSMPDiscoKeyAdvertisement{
|
||||
Src: srcV6,
|
||||
Dst: dstV6,
|
||||
Key: key.DiscoPublicFromRaw32(mem.B(testKey)),
|
||||
},
|
||||
want: join(headerV6, srcV6.AsSlice(), dstV6.AsSlice(), packetType, testKey),
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
got, err := tt.tka.Marshal()
|
||||
if err != nil {
|
||||
t.Errorf("error mashalling TSMPDiscoAdvertisement: %s", err)
|
||||
}
|
||||
if !slices.Equal(got, tt.want) {
|
||||
t.Errorf("error mashalling TSMPDiscoAdvertisement, expected: \n%x, \ngot:\n%x", tt.want, got)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
+24
-5
@@ -34,6 +34,7 @@ import (
|
||||
"tailscale.com/types/logger"
|
||||
"tailscale.com/types/netlogfunc"
|
||||
"tailscale.com/util/clientmetric"
|
||||
"tailscale.com/util/eventbus"
|
||||
"tailscale.com/util/usermetric"
|
||||
"tailscale.com/wgengine/filter"
|
||||
"tailscale.com/wgengine/netstack/gro"
|
||||
@@ -209,6 +210,9 @@ type Wrapper struct {
|
||||
captureHook syncs.AtomicValue[packet.CaptureCallback]
|
||||
|
||||
metrics *metrics
|
||||
|
||||
eventClient *eventbus.Client
|
||||
discoKeyAdvertisementPub *eventbus.Publisher[DiscoKeyAdvertisement]
|
||||
}
|
||||
|
||||
type metrics struct {
|
||||
@@ -254,15 +258,15 @@ func (w *Wrapper) Start() {
|
||||
close(w.startCh)
|
||||
}
|
||||
|
||||
func WrapTAP(logf logger.Logf, tdev tun.Device, m *usermetric.Registry) *Wrapper {
|
||||
return wrap(logf, tdev, true, m)
|
||||
func WrapTAP(logf logger.Logf, tdev tun.Device, m *usermetric.Registry, bus *eventbus.Bus) *Wrapper {
|
||||
return wrap(logf, tdev, true, m, bus)
|
||||
}
|
||||
|
||||
func Wrap(logf logger.Logf, tdev tun.Device, m *usermetric.Registry) *Wrapper {
|
||||
return wrap(logf, tdev, false, m)
|
||||
func Wrap(logf logger.Logf, tdev tun.Device, m *usermetric.Registry, bus *eventbus.Bus) *Wrapper {
|
||||
return wrap(logf, tdev, false, m, bus)
|
||||
}
|
||||
|
||||
func wrap(logf logger.Logf, tdev tun.Device, isTAP bool, m *usermetric.Registry) *Wrapper {
|
||||
func wrap(logf logger.Logf, tdev tun.Device, isTAP bool, m *usermetric.Registry, bus *eventbus.Bus) *Wrapper {
|
||||
logf = logger.WithPrefix(logf, "tstun: ")
|
||||
w := &Wrapper{
|
||||
logf: logf,
|
||||
@@ -283,6 +287,9 @@ func wrap(logf logger.Logf, tdev tun.Device, isTAP bool, m *usermetric.Registry)
|
||||
metrics: registerMetrics(m),
|
||||
}
|
||||
|
||||
w.eventClient = bus.Client("net.tstun")
|
||||
w.discoKeyAdvertisementPub = eventbus.Publish[DiscoKeyAdvertisement](w.eventClient)
|
||||
|
||||
w.vectorBuffer = make([][]byte, tdev.BatchSize())
|
||||
for i := range w.vectorBuffer {
|
||||
w.vectorBuffer[i] = make([]byte, maxBufferSize)
|
||||
@@ -357,6 +364,7 @@ func (t *Wrapper) Close() error {
|
||||
close(t.vectorOutbound)
|
||||
t.outboundMu.Unlock()
|
||||
err = t.tdev.Close()
|
||||
t.eventClient.Close()
|
||||
})
|
||||
return err
|
||||
}
|
||||
@@ -1118,6 +1126,11 @@ func (t *Wrapper) injectedRead(res tunInjectedRead, outBuffs [][]byte, sizes []i
|
||||
return n, err
|
||||
}
|
||||
|
||||
type DiscoKeyAdvertisement struct {
|
||||
Src netip.Addr
|
||||
Key key.DiscoPublic
|
||||
}
|
||||
|
||||
func (t *Wrapper) filterPacketInboundFromWireGuard(p *packet.Parsed, captHook packet.CaptureCallback, pc *peerConfigTable, gro *gro.GRO) (filter.Response, *gro.GRO) {
|
||||
if captHook != nil {
|
||||
captHook(packet.FromPeer, t.now(), p.Buffer(), p.CaptureMeta)
|
||||
@@ -1128,6 +1141,12 @@ func (t *Wrapper) filterPacketInboundFromWireGuard(p *packet.Parsed, captHook pa
|
||||
t.noteActivity()
|
||||
t.injectOutboundPong(p, pingReq)
|
||||
return filter.DropSilently, gro
|
||||
} else if discoKeyAdvert, ok := p.AsTSMPDiscoAdvertisement(); ok {
|
||||
t.discoKeyAdvertisementPub.Publish(DiscoKeyAdvertisement{
|
||||
Src: discoKeyAdvert.Src,
|
||||
Key: discoKeyAdvert.Key,
|
||||
})
|
||||
return filter.DropSilently, gro
|
||||
} else if data, ok := p.AsTSMPPong(); ok {
|
||||
if f := t.OnTSMPPongReceived; f != nil {
|
||||
f(data)
|
||||
|
||||
+48
-14
@@ -36,6 +36,8 @@ import (
|
||||
"tailscale.com/types/netlogtype"
|
||||
"tailscale.com/types/ptr"
|
||||
"tailscale.com/types/views"
|
||||
"tailscale.com/util/eventbus"
|
||||
"tailscale.com/util/eventbus/eventbustest"
|
||||
"tailscale.com/util/must"
|
||||
"tailscale.com/util/usermetric"
|
||||
"tailscale.com/wgengine/filter"
|
||||
@@ -170,10 +172,10 @@ func setfilter(logf logger.Logf, tun *Wrapper) {
|
||||
tun.SetFilter(filter.New(matches, nil, ipSet, ipSet, nil, logf))
|
||||
}
|
||||
|
||||
func newChannelTUN(logf logger.Logf, secure bool) (*tuntest.ChannelTUN, *Wrapper) {
|
||||
func newChannelTUN(logf logger.Logf, bus *eventbus.Bus, secure bool) (*tuntest.ChannelTUN, *Wrapper) {
|
||||
chtun := tuntest.NewChannelTUN()
|
||||
reg := new(usermetric.Registry)
|
||||
tun := Wrap(logf, chtun.TUN(), reg)
|
||||
tun := Wrap(logf, chtun.TUN(), reg, bus)
|
||||
if secure {
|
||||
setfilter(logf, tun)
|
||||
} else {
|
||||
@@ -183,10 +185,10 @@ func newChannelTUN(logf logger.Logf, secure bool) (*tuntest.ChannelTUN, *Wrapper
|
||||
return chtun, tun
|
||||
}
|
||||
|
||||
func newFakeTUN(logf logger.Logf, secure bool) (*fakeTUN, *Wrapper) {
|
||||
func newFakeTUN(logf logger.Logf, bus *eventbus.Bus, secure bool) (*fakeTUN, *Wrapper) {
|
||||
ftun := NewFake()
|
||||
reg := new(usermetric.Registry)
|
||||
tun := Wrap(logf, ftun, reg)
|
||||
tun := Wrap(logf, ftun, reg, bus)
|
||||
if secure {
|
||||
setfilter(logf, tun)
|
||||
} else {
|
||||
@@ -196,7 +198,8 @@ func newFakeTUN(logf logger.Logf, secure bool) (*fakeTUN, *Wrapper) {
|
||||
}
|
||||
|
||||
func TestReadAndInject(t *testing.T) {
|
||||
chtun, tun := newChannelTUN(t.Logf, false)
|
||||
bus := eventbustest.NewBus(t)
|
||||
chtun, tun := newChannelTUN(t.Logf, bus, false)
|
||||
defer tun.Close()
|
||||
|
||||
const size = 2 // all payloads have this size
|
||||
@@ -221,7 +224,7 @@ func TestReadAndInject(t *testing.T) {
|
||||
}
|
||||
|
||||
var buf [MaxPacketSize]byte
|
||||
var seen = make(map[string]bool)
|
||||
seen := make(map[string]bool)
|
||||
sizes := make([]int, 1)
|
||||
// We expect the same packets back, in no particular order.
|
||||
for i := range len(written) + len(injected) {
|
||||
@@ -257,7 +260,8 @@ func TestReadAndInject(t *testing.T) {
|
||||
}
|
||||
|
||||
func TestWriteAndInject(t *testing.T) {
|
||||
chtun, tun := newChannelTUN(t.Logf, false)
|
||||
bus := eventbustest.NewBus(t)
|
||||
chtun, tun := newChannelTUN(t.Logf, bus, false)
|
||||
defer tun.Close()
|
||||
|
||||
written := []string{"w0", "w1"}
|
||||
@@ -316,8 +320,8 @@ func mustHexDecode(s string) []byte {
|
||||
}
|
||||
|
||||
func TestFilter(t *testing.T) {
|
||||
|
||||
chtun, tun := newChannelTUN(t.Logf, true)
|
||||
bus := eventbustest.NewBus(t)
|
||||
chtun, tun := newChannelTUN(t.Logf, bus, true)
|
||||
defer tun.Close()
|
||||
|
||||
// Reset the metrics before test. These are global
|
||||
@@ -462,7 +466,8 @@ func assertMetricPackets(t *testing.T, metricName string, want, got int64) {
|
||||
}
|
||||
|
||||
func TestAllocs(t *testing.T) {
|
||||
ftun, tun := newFakeTUN(t.Logf, false)
|
||||
bus := eventbustest.NewBus(t)
|
||||
ftun, tun := newFakeTUN(t.Logf, bus, false)
|
||||
defer tun.Close()
|
||||
|
||||
buf := [][]byte{{0x00}}
|
||||
@@ -473,14 +478,14 @@ func TestAllocs(t *testing.T) {
|
||||
return
|
||||
}
|
||||
})
|
||||
|
||||
if err != nil {
|
||||
t.Error(err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestClose(t *testing.T) {
|
||||
ftun, tun := newFakeTUN(t.Logf, false)
|
||||
bus := eventbustest.NewBus(t)
|
||||
ftun, tun := newFakeTUN(t.Logf, bus, false)
|
||||
|
||||
data := [][]byte{udp4("1.2.3.4", "5.6.7.8", 98, 98)}
|
||||
_, err := ftun.Write(data, 0)
|
||||
@@ -497,7 +502,8 @@ func TestClose(t *testing.T) {
|
||||
|
||||
func BenchmarkWrite(b *testing.B) {
|
||||
b.ReportAllocs()
|
||||
ftun, tun := newFakeTUN(b.Logf, true)
|
||||
bus := eventbustest.NewBus(b)
|
||||
ftun, tun := newFakeTUN(b.Logf, bus, true)
|
||||
defer tun.Close()
|
||||
|
||||
packet := [][]byte{udp4("5.6.7.8", "1.2.3.4", 89, 89)}
|
||||
@@ -887,7 +893,8 @@ func TestCaptureHook(t *testing.T) {
|
||||
|
||||
now := time.Unix(1682085856, 0)
|
||||
|
||||
_, w := newFakeTUN(t.Logf, true)
|
||||
bus := eventbustest.NewBus(t)
|
||||
_, w := newFakeTUN(t.Logf, bus, true)
|
||||
w.timeNow = func() time.Time {
|
||||
return now
|
||||
}
|
||||
@@ -957,3 +964,30 @@ func TestCaptureHook(t *testing.T) {
|
||||
captured, want)
|
||||
}
|
||||
}
|
||||
|
||||
func TestTSMPDisco(t *testing.T) {
|
||||
t.Run("IPv6DiscoAdvert", func(t *testing.T) {
|
||||
src := netip.MustParseAddr("2001:db8::1")
|
||||
dst := netip.MustParseAddr("2001:db8::2")
|
||||
discoKey := key.NewDisco()
|
||||
buf, _ := (&packet.TSMPDiscoKeyAdvertisement{
|
||||
Src: src,
|
||||
Dst: dst,
|
||||
Key: discoKey.Public(),
|
||||
}).Marshal()
|
||||
|
||||
var p packet.Parsed
|
||||
p.Decode(buf)
|
||||
|
||||
tda, ok := p.AsTSMPDiscoAdvertisement()
|
||||
if !ok {
|
||||
t.Error("Unable to parse message as TSMPDiscoAdversitement")
|
||||
}
|
||||
if tda.Src != src {
|
||||
t.Errorf("Src address did not match, expected %v, got %v", src, tda.Src)
|
||||
}
|
||||
if !reflect.DeepEqual(tda.Key, discoKey.Public()) {
|
||||
t.Errorf("Key did not match, expected %q, got %q", discoKey.Public(), tda.Key)
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user