tsnet: add tests to TestListenService for user-supplied TUN devices

This resolves a gap in test coverage, ensuring Server.ListenService
functions as expected in combination with user-supplied TUN devices

Fixes tailscale/corp#36603

Co-authored-by: Harry Harpham <harry@tailscale.com>
Signed-off-by: Harry Harpham <harry@tailscale.com>
This commit is contained in:
James Tucker
2026-01-29 14:25:32 -08:00
committed by Harry Harpham
parent 5edfa6f9a8
commit 569caefeb5
+102 -91
View File
@@ -1141,83 +1141,91 @@ func TestListenService(t *testing.T) {
// This ends up also testing the Service forwarding logic in // This ends up also testing the Service forwarding logic in
// LocalBackend, but that's useful too. // LocalBackend, but that's useful too.
t.Run(tt.name, func(t *testing.T) { t.Run(tt.name, func(t *testing.T) {
ctx := t.Context() // We run each test with and without a TUN device ([Server.Tun]).
// Note that this TUN device is distinct from TUN mode for Services.
doTest := func(t *testing.T, withTUNDevice bool) {
ctx := t.Context()
controlURL, control := startControl(t) lt := setupTwoClientTest(t, withTUNDevice)
serviceHost, _, _ := startServer(t, ctx, controlURL, "service-host") serviceHost := lt.s2
serviceClient, _, _ := startServer(t, ctx, controlURL, "service-client") serviceClient := lt.s1
control := lt.control
const serviceName = tailcfg.ServiceName("svc:foo") const serviceName = tailcfg.ServiceName("svc:foo")
const serviceVIP = "100.11.22.33" const serviceVIP = "100.11.22.33"
// == Set up necessary state in our mock == // == Set up necessary state in our mock ==
// The Service host must have the 'service-host' capability, which // The Service host must have the 'service-host' capability, which
// is a mapping from the Service name to the Service VIP. // is a mapping from the Service name to the Service VIP.
var serviceHostCaps map[tailcfg.ServiceName]views.Slice[netip.Addr] var serviceHostCaps map[tailcfg.ServiceName]views.Slice[netip.Addr]
mak.Set(&serviceHostCaps, serviceName, views.SliceOf([]netip.Addr{netip.MustParseAddr(serviceVIP)})) mak.Set(&serviceHostCaps, serviceName, views.SliceOf([]netip.Addr{netip.MustParseAddr(serviceVIP)}))
j := must.Get(json.Marshal(serviceHostCaps)) j := must.Get(json.Marshal(serviceHostCaps))
cm := serviceHost.lb.NetMap().SelfNode.CapMap().AsMap() cm := serviceHost.lb.NetMap().SelfNode.CapMap().AsMap()
mak.Set(&cm, tailcfg.NodeAttrServiceHost, []tailcfg.RawMessage{tailcfg.RawMessage(j)}) mak.Set(&cm, tailcfg.NodeAttrServiceHost, []tailcfg.RawMessage{tailcfg.RawMessage(j)})
control.SetNodeCapMap(serviceHost.lb.NodeKey(), cm) control.SetNodeCapMap(serviceHost.lb.NodeKey(), cm)
// The Service host must be allowed to advertise the Service VIP. // The Service host must be allowed to advertise the Service VIP.
control.SetSubnetRoutes(serviceHost.lb.NodeKey(), []netip.Prefix{ control.SetSubnetRoutes(serviceHost.lb.NodeKey(), []netip.Prefix{
netip.MustParsePrefix(serviceVIP + `/32`), netip.MustParsePrefix(serviceVIP + `/32`),
})
// The Service host must be a tagged node (any tag will do).
serviceHostNode := control.Node(serviceHost.lb.NodeKey())
serviceHostNode.Tags = append(serviceHostNode.Tags, "some-tag")
control.UpdateNode(serviceHostNode)
// The service client must accept routes advertised by other nodes
// (RouteAll is equivalent to --accept-routes).
must.Get(serviceClient.localClient.EditPrefs(ctx, &ipn.MaskedPrefs{
RouteAllSet: true,
Prefs: ipn.Prefs{
RouteAll: true,
},
}))
// Set up DNS for our Service.
control.AddDNSRecords(tailcfg.DNSRecord{
Name: serviceName.WithoutPrefix() + "." + control.MagicDNSDomain,
Value: serviceVIP,
})
if tt.extraSetup != nil {
tt.extraSetup(t, control)
}
// Force netmap updates to avoid race conditions. The nodes need to
// see our control updates before we can start the test.
must.Do(control.ForceNetmapUpdate(ctx, serviceHost.lb.NodeKey()))
must.Do(control.ForceNetmapUpdate(ctx, serviceClient.lb.NodeKey()))
netmapUpToDate := func(s *Server) bool {
nm := s.lb.NetMap()
return slices.ContainsFunc(nm.DNS.ExtraRecords, func(r tailcfg.DNSRecord) bool {
return r.Value == serviceVIP
}) })
}
for !netmapUpToDate(serviceClient) { // The Service host must be a tagged node (any tag will do).
time.Sleep(10 * time.Millisecond) serviceHostNode := control.Node(serviceHost.lb.NodeKey())
} serviceHostNode.Tags = append(serviceHostNode.Tags, "some-tag")
for !netmapUpToDate(serviceHost) { control.UpdateNode(serviceHostNode)
time.Sleep(10 * time.Millisecond)
// The service client must accept routes advertised by other nodes
// (RouteAll is equivalent to --accept-routes).
must.Get(serviceClient.localClient.EditPrefs(ctx, &ipn.MaskedPrefs{
RouteAllSet: true,
Prefs: ipn.Prefs{
RouteAll: true,
},
}))
// Set up DNS for our Service.
control.AddDNSRecords(tailcfg.DNSRecord{
Name: serviceName.WithoutPrefix() + "." + control.MagicDNSDomain,
Value: serviceVIP,
})
if tt.extraSetup != nil {
tt.extraSetup(t, control)
}
// Force netmap updates to avoid race conditions. The nodes need to
// see our control updates before we can start the test.
must.Do(control.ForceNetmapUpdate(ctx, serviceHost.lb.NodeKey()))
must.Do(control.ForceNetmapUpdate(ctx, serviceClient.lb.NodeKey()))
netmapUpToDate := func(s *Server) bool {
nm := s.lb.NetMap()
return slices.ContainsFunc(nm.DNS.ExtraRecords, func(r tailcfg.DNSRecord) bool {
return r.Value == serviceVIP
})
}
for !netmapUpToDate(serviceClient) {
time.Sleep(10 * time.Millisecond)
}
for !netmapUpToDate(serviceHost) {
time.Sleep(10 * time.Millisecond)
}
// == Done setting up mock state ==
// Start the Service listeners.
listeners := make([]*ServiceListener, 0, len(tt.modes))
for _, input := range tt.modes {
ln := must.Get(serviceHost.ListenService(serviceName.String(), input))
defer ln.Close()
listeners = append(listeners, ln)
}
tt.run(t, listeners, serviceClient)
} }
// == Done setting up mock state == t.Run("TUN", func(t *testing.T) { doTest(t, true) })
t.Run("netstack", func(t *testing.T) { doTest(t, false) })
// Start the Service listeners.
listeners := make([]*ServiceListener, 0, len(tt.modes))
for _, input := range tt.modes {
ln := must.Get(serviceHost.ListenService(serviceName.String(), input))
defer ln.Close()
listeners = append(listeners, ln)
}
tt.run(t, listeners, serviceClient)
}) })
} }
} }
@@ -1928,20 +1936,21 @@ func (t *chanTUN) BatchSize() int { return 1 }
// listenTest provides common setup for listener and TUN tests. // listenTest provides common setup for listener and TUN tests.
type listenTest struct { type listenTest struct {
control *testcontrol.Server
s1, s2 *Server s1, s2 *Server
s1ip4, s1ip6 netip.Addr s1ip4, s1ip6 netip.Addr
s2ip4, s2ip6 netip.Addr s2ip4, s2ip6 netip.Addr
tun *chanTUN // nil for netstack mode tun *chanTUN // nil for netstack mode
} }
// setupListenTest creates two tsnet servers for testing. // setupTwoClientTest creates two tsnet servers for testing.
// If useTUN is true, s2 uses a chanTUN; otherwise it uses netstack only. // If useTUN is true, s2 uses a chanTUN; otherwise it uses netstack only.
func setupListenTest(t *testing.T, useTUN bool) *listenTest { func setupTwoClientTest(t *testing.T, useTUN bool) *listenTest {
t.Helper() t.Helper()
tstest.Shard(t) tstest.Shard(t)
tstest.ResourceCheck(t) tstest.ResourceCheck(t)
ctx := t.Context() ctx := t.Context()
controlURL, _ := startControl(t) controlURL, control := startControl(t)
s1, _, _ := startServer(t, ctx, controlURL, "s1") s1, _, _ := startServer(t, ctx, controlURL, "s1")
tmp := filepath.Join(t.TempDir(), "s2") tmp := filepath.Join(t.TempDir(), "s2")
@@ -1969,6 +1978,7 @@ func setupListenTest(t *testing.T, useTUN bool) *listenTest {
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
} }
s2.lb.ConfigureCertsForTest(testCertRoot.getCert)
s1ip4, s1ip6 := s1.TailscaleIPs() s1ip4, s1ip6 := s1.TailscaleIPs()
s2ip4 := s2status.TailscaleIPs[0] s2ip4 := s2status.TailscaleIPs[0]
@@ -1981,13 +1991,14 @@ func setupListenTest(t *testing.T, useTUN bool) *listenTest {
must.Get(lc1.Ping(ctx, s2ip4, tailcfg.PingTSMP)) must.Get(lc1.Ping(ctx, s2ip4, tailcfg.PingTSMP))
return &listenTest{ return &listenTest{
s1: s1, control: control,
s2: s2, s1: s1,
s1ip4: s1ip4, s2: s2,
s1ip6: s1ip6, s1ip4: s1ip4,
s2ip4: s2ip4, s1ip6: s1ip6,
s2ip6: s2ip6, s2ip4: s2ip4,
tun: tun, s2ip6: s2ip6,
tun: tun,
} }
} }
@@ -2016,7 +2027,7 @@ func echoUDP(pkt []byte) []byte {
} }
func TestTUN(t *testing.T) { func TestTUN(t *testing.T) {
tt := setupListenTest(t, true) tt := setupTwoClientTest(t, true)
go func() { go func() {
for pkt := range tt.tun.Inbound { for pkt := range tt.tun.Inbound {
@@ -2059,7 +2070,7 @@ func TestTUN(t *testing.T) {
// responses. This verifies that handleLocalPackets intercepts outbound traffic // responses. This verifies that handleLocalPackets intercepts outbound traffic
// to the service IP. // to the service IP.
func TestTUNDNS(t *testing.T) { func TestTUNDNS(t *testing.T) {
tt := setupListenTest(t, true) tt := setupTwoClientTest(t, true)
test := func(t *testing.T, srcIP netip.Addr, serviceIP netip.Addr) { test := func(t *testing.T, srcIP netip.Addr, serviceIP netip.Addr) {
tt.tun.Outbound <- buildDNSQuery("s2", srcIP) tt.tun.Outbound <- buildDNSQuery("s2", srcIP)
@@ -2149,13 +2160,13 @@ func TestListenPacket(t *testing.T) {
} }
t.Run("Netstack", func(t *testing.T) { t.Run("Netstack", func(t *testing.T) {
lt := setupListenTest(t, false) lt := setupTwoClientTest(t, false)
t.Run("IPv4", func(t *testing.T) { testListenPacket(t, lt, lt.s2ip4) }) t.Run("IPv4", func(t *testing.T) { testListenPacket(t, lt, lt.s2ip4) })
t.Run("IPv6", func(t *testing.T) { testListenPacket(t, lt, lt.s2ip6) }) t.Run("IPv6", func(t *testing.T) { testListenPacket(t, lt, lt.s2ip6) })
}) })
t.Run("TUN", func(t *testing.T) { t.Run("TUN", func(t *testing.T) {
lt := setupListenTest(t, true) lt := setupTwoClientTest(t, true)
t.Run("IPv4", func(t *testing.T) { testListenPacket(t, lt, lt.s2ip4) }) t.Run("IPv4", func(t *testing.T) { testListenPacket(t, lt, lt.s2ip4) })
t.Run("IPv6", func(t *testing.T) { testListenPacket(t, lt, lt.s2ip6) }) t.Run("IPv6", func(t *testing.T) { testListenPacket(t, lt, lt.s2ip6) })
}) })
@@ -2221,13 +2232,13 @@ func TestListenTCP(t *testing.T) {
} }
t.Run("Netstack", func(t *testing.T) { t.Run("Netstack", func(t *testing.T) {
lt := setupListenTest(t, false) lt := setupTwoClientTest(t, false)
t.Run("IPv4", func(t *testing.T) { testListenTCP(t, lt, lt.s2ip4) }) t.Run("IPv4", func(t *testing.T) { testListenTCP(t, lt, lt.s2ip4) })
t.Run("IPv6", func(t *testing.T) { testListenTCP(t, lt, lt.s2ip6) }) t.Run("IPv6", func(t *testing.T) { testListenTCP(t, lt, lt.s2ip6) })
}) })
t.Run("TUN", func(t *testing.T) { t.Run("TUN", func(t *testing.T) {
lt := setupListenTest(t, true) lt := setupTwoClientTest(t, true)
t.Run("IPv4", func(t *testing.T) { testListenTCP(t, lt, lt.s2ip4) }) t.Run("IPv4", func(t *testing.T) { testListenTCP(t, lt, lt.s2ip4) })
t.Run("IPv6", func(t *testing.T) { testListenTCP(t, lt, lt.s2ip6) }) t.Run("IPv6", func(t *testing.T) { testListenTCP(t, lt, lt.s2ip6) })
}) })
@@ -2299,13 +2310,13 @@ func TestListenTCPDualStack(t *testing.T) {
} }
t.Run("Netstack", func(t *testing.T) { t.Run("Netstack", func(t *testing.T) {
lt := setupListenTest(t, false) lt := setupTwoClientTest(t, false)
t.Run("DialIPv4", func(t *testing.T) { testListenTCPDualStack(t, lt, lt.s2ip4) }) t.Run("DialIPv4", func(t *testing.T) { testListenTCPDualStack(t, lt, lt.s2ip4) })
t.Run("DialIPv6", func(t *testing.T) { testListenTCPDualStack(t, lt, lt.s2ip6) }) t.Run("DialIPv6", func(t *testing.T) { testListenTCPDualStack(t, lt, lt.s2ip6) })
}) })
t.Run("TUN", func(t *testing.T) { t.Run("TUN", func(t *testing.T) {
lt := setupListenTest(t, true) lt := setupTwoClientTest(t, true)
t.Run("DialIPv4", func(t *testing.T) { testListenTCPDualStack(t, lt, lt.s2ip4) }) t.Run("DialIPv4", func(t *testing.T) { testListenTCPDualStack(t, lt, lt.s2ip4) })
t.Run("DialIPv6", func(t *testing.T) { testListenTCPDualStack(t, lt, lt.s2ip6) }) t.Run("DialIPv6", func(t *testing.T) { testListenTCPDualStack(t, lt, lt.s2ip6) })
}) })
@@ -2372,13 +2383,13 @@ func TestDialTCP(t *testing.T) {
} }
t.Run("Netstack", func(t *testing.T) { t.Run("Netstack", func(t *testing.T) {
lt := setupListenTest(t, false) lt := setupTwoClientTest(t, false)
t.Run("IPv4", func(t *testing.T) { testDialTCP(t, lt, lt.s1ip4) }) t.Run("IPv4", func(t *testing.T) { testDialTCP(t, lt, lt.s1ip4) })
t.Run("IPv6", func(t *testing.T) { testDialTCP(t, lt, lt.s1ip6) }) t.Run("IPv6", func(t *testing.T) { testDialTCP(t, lt, lt.s1ip6) })
}) })
t.Run("TUN", func(t *testing.T) { t.Run("TUN", func(t *testing.T) {
lt := setupListenTest(t, true) lt := setupTwoClientTest(t, true)
var escapedTCPPackets atomic.Int32 var escapedTCPPackets atomic.Int32
var wg sync.WaitGroup var wg sync.WaitGroup
@@ -2460,13 +2471,13 @@ func TestDialUDP(t *testing.T) {
} }
t.Run("Netstack", func(t *testing.T) { t.Run("Netstack", func(t *testing.T) {
lt := setupListenTest(t, false) lt := setupTwoClientTest(t, false)
t.Run("IPv4", func(t *testing.T) { testDialUDP(t, lt, lt.s1ip4) }) t.Run("IPv4", func(t *testing.T) { testDialUDP(t, lt, lt.s1ip4) })
t.Run("IPv6", func(t *testing.T) { testDialUDP(t, lt, lt.s1ip6) }) t.Run("IPv6", func(t *testing.T) { testDialUDP(t, lt, lt.s1ip6) })
}) })
t.Run("TUN", func(t *testing.T) { t.Run("TUN", func(t *testing.T) {
lt := setupListenTest(t, true) lt := setupTwoClientTest(t, true)
var escapedUDPPackets atomic.Int32 var escapedUDPPackets atomic.Int32
var wg sync.WaitGroup var wg sync.WaitGroup