tsnet: add support for a user-supplied tun.Device
tsnet users can now provide a tun.Device, including any custom implementation that conforms to the interface. netstack has a new option CheckLocalTransportEndpoints that when used alongside a TUN enables netstack listens and dials to correctly capture traffic associated with those sockets. tsnet with a TUN sets this option, while all other builds leave this at false to preserve existing performance. Updates #18423 Signed-off-by: James Tucker <james@tailscale.com>
This commit is contained in:
committed by
James Tucker
parent
c062230cce
commit
63d563e734
+84
-4
@@ -26,6 +26,7 @@ import (
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/tailscale/wireguard-go/tun"
|
||||
"tailscale.com/client/local"
|
||||
"tailscale.com/control/controlclient"
|
||||
"tailscale.com/envknob"
|
||||
@@ -167,6 +168,11 @@ type Server struct {
|
||||
// that the control server will allow the node to adopt that tag.
|
||||
AdvertiseTags []string
|
||||
|
||||
// Tun, if non-nil, specifies a custom tun.Device to use for packet I/O.
|
||||
//
|
||||
// This field must be set before calling Start.
|
||||
Tun tun.Device
|
||||
|
||||
initOnce sync.Once
|
||||
initErr error
|
||||
lb *ipnlocal.LocalBackend
|
||||
@@ -659,6 +665,7 @@ func (s *Server) start() (reterr error) {
|
||||
s.dialer = &tsdial.Dialer{Logf: tsLogf} // mutated below (before used)
|
||||
s.dialer.SetBus(sys.Bus.Get())
|
||||
eng, err := wgengine.NewUserspaceEngine(tsLogf, wgengine.Config{
|
||||
Tun: s.Tun,
|
||||
EventBus: sys.Bus.Get(),
|
||||
ListenPort: s.Port,
|
||||
NetMon: s.netMon,
|
||||
@@ -682,8 +689,16 @@ func (s *Server) start() (reterr error) {
|
||||
}
|
||||
sys.Tun.Get().Start()
|
||||
sys.Set(ns)
|
||||
ns.ProcessLocalIPs = true
|
||||
ns.ProcessSubnets = true
|
||||
if s.Tun == nil {
|
||||
// Only process packets in netstack when using the default fake TUN.
|
||||
// When a TUN is provided, let packets flow through it instead.
|
||||
ns.ProcessLocalIPs = true
|
||||
ns.ProcessSubnets = true
|
||||
} else {
|
||||
// When using a TUN, check gVisor for registered endpoints to handle
|
||||
// packets for tsnet listeners and outbound connection replies.
|
||||
ns.CheckLocalTransportEndpoints = true
|
||||
}
|
||||
ns.GetTCPHandlerForFlow = s.getTCPHandlerForFlow
|
||||
ns.GetUDPHandlerForFlow = s.getUDPHandlerForFlow
|
||||
s.netstack = ns
|
||||
@@ -1072,10 +1087,34 @@ func (s *Server) ListenPacket(network, addr string) (net.PacketConn, error) {
|
||||
network = "udp6"
|
||||
}
|
||||
}
|
||||
if err := s.Start(); err != nil {
|
||||
|
||||
netLn, err := s.listen(network, addr, listenOnTailnet)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return s.netstack.ListenPacket(network, ap.String())
|
||||
ln := netLn.(*listener)
|
||||
|
||||
pc, err := s.netstack.ListenPacket(network, ap.String())
|
||||
if err != nil {
|
||||
ln.Close()
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return &udpPacketConn{
|
||||
PacketConn: pc,
|
||||
ln: ln,
|
||||
}, nil
|
||||
}
|
||||
|
||||
// udpPacketConn wraps a net.PacketConn to unregister from s.listeners on Close.
|
||||
type udpPacketConn struct {
|
||||
net.PacketConn
|
||||
ln *listener
|
||||
}
|
||||
|
||||
func (c *udpPacketConn) Close() error {
|
||||
c.ln.Close()
|
||||
return c.PacketConn.Close()
|
||||
}
|
||||
|
||||
// ListenTLS announces only on the Tailscale network.
|
||||
@@ -1611,10 +1650,37 @@ func (s *Server) listen(network, addr string, lnOn listenOn) (net.Listener, erro
|
||||
closedc: make(chan struct{}),
|
||||
conn: make(chan net.Conn),
|
||||
}
|
||||
|
||||
// When using a TUN with TCP, create a gVisor TCP listener.
|
||||
if s.Tun != nil && (network == "" || network == "tcp" || network == "tcp4" || network == "tcp6") {
|
||||
var nsNetwork string
|
||||
nsAddr := host
|
||||
switch {
|
||||
case network == "tcp4" || network == "tcp6":
|
||||
nsNetwork = network
|
||||
case host.Addr().Is4():
|
||||
nsNetwork = "tcp4"
|
||||
case host.Addr().Is6():
|
||||
nsNetwork = "tcp6"
|
||||
default:
|
||||
// Wildcard address: use tcp6 for dual-stack (accepts both v4 and v6).
|
||||
nsNetwork = "tcp6"
|
||||
nsAddr = netip.AddrPortFrom(netip.IPv6Unspecified(), host.Port())
|
||||
}
|
||||
gonetLn, err := s.netstack.ListenTCP(nsNetwork, nsAddr.String())
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("tsnet: %w", err)
|
||||
}
|
||||
ln.gonetLn = gonetLn
|
||||
}
|
||||
|
||||
s.mu.Lock()
|
||||
for _, key := range keys {
|
||||
if _, ok := s.listeners[key]; ok {
|
||||
s.mu.Unlock()
|
||||
if ln.gonetLn != nil {
|
||||
ln.gonetLn.Close()
|
||||
}
|
||||
return nil, fmt.Errorf("tsnet: listener already open for %s, %s", network, addr)
|
||||
}
|
||||
}
|
||||
@@ -1684,9 +1750,17 @@ type listener struct {
|
||||
conn chan net.Conn // unbuffered, never closed
|
||||
closedc chan struct{} // closed on [listener.Close]
|
||||
closed bool // guarded by s.mu
|
||||
|
||||
// gonetLn, if set, is the gonet.Listener that handles new connections.
|
||||
// gonetLn is set by [listen] when a TUN is in use and terminates the listener.
|
||||
// gonetLn is nil when TUN is nil.
|
||||
gonetLn net.Listener
|
||||
}
|
||||
|
||||
func (ln *listener) Accept() (net.Conn, error) {
|
||||
if ln.gonetLn != nil {
|
||||
return ln.gonetLn.Accept()
|
||||
}
|
||||
select {
|
||||
case c := <-ln.conn:
|
||||
return c, nil
|
||||
@@ -1696,6 +1770,9 @@ func (ln *listener) Accept() (net.Conn, error) {
|
||||
}
|
||||
|
||||
func (ln *listener) Addr() net.Addr {
|
||||
if ln.gonetLn != nil {
|
||||
return ln.gonetLn.Addr()
|
||||
}
|
||||
return addr{
|
||||
network: ln.keys[0].network,
|
||||
addr: ln.addr,
|
||||
@@ -1721,6 +1798,9 @@ func (ln *listener) closeLocked() error {
|
||||
}
|
||||
close(ln.closedc)
|
||||
ln.closed = true
|
||||
if ln.gonetLn != nil {
|
||||
ln.gonetLn.Close()
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
|
||||
@@ -39,6 +39,7 @@ import (
|
||||
"github.com/google/go-cmp/cmp"
|
||||
dto "github.com/prometheus/client_model/go"
|
||||
"github.com/prometheus/common/expfmt"
|
||||
"github.com/tailscale/wireguard-go/tun"
|
||||
"golang.org/x/net/proxy"
|
||||
|
||||
"tailscale.com/client/local"
|
||||
@@ -48,11 +49,13 @@ import (
|
||||
"tailscale.com/ipn/ipnlocal"
|
||||
"tailscale.com/ipn/store/mem"
|
||||
"tailscale.com/net/netns"
|
||||
"tailscale.com/net/packet"
|
||||
"tailscale.com/tailcfg"
|
||||
"tailscale.com/tstest"
|
||||
"tailscale.com/tstest/deptest"
|
||||
"tailscale.com/tstest/integration"
|
||||
"tailscale.com/tstest/integration/testcontrol"
|
||||
"tailscale.com/types/ipproto"
|
||||
"tailscale.com/types/key"
|
||||
"tailscale.com/types/logger"
|
||||
"tailscale.com/types/views"
|
||||
@@ -1860,6 +1863,676 @@ func mustDirect(t *testing.T, logf logger.Logf, lc1, lc2 *local.Client) {
|
||||
t.Error("magicsock did not find a direct path from lc1 to lc2")
|
||||
}
|
||||
|
||||
// chanTUN is a tun.Device for testing that uses channels for packet I/O.
|
||||
// Inbound receives packets written to the TUN (from the perspective of the network stack).
|
||||
// Outbound is for injecting packets to be read from the TUN.
|
||||
type chanTUN struct {
|
||||
Inbound chan []byte // packets written to TUN
|
||||
Outbound chan []byte // packets to read from TUN
|
||||
closed chan struct{}
|
||||
events chan tun.Event
|
||||
}
|
||||
|
||||
func newChanTUN() *chanTUN {
|
||||
t := &chanTUN{
|
||||
Inbound: make(chan []byte, 10),
|
||||
Outbound: make(chan []byte, 10),
|
||||
closed: make(chan struct{}),
|
||||
events: make(chan tun.Event, 1),
|
||||
}
|
||||
t.events <- tun.EventUp
|
||||
return t
|
||||
}
|
||||
|
||||
func (t *chanTUN) File() *os.File { panic("not implemented") }
|
||||
|
||||
func (t *chanTUN) Close() error {
|
||||
select {
|
||||
case <-t.closed:
|
||||
default:
|
||||
close(t.closed)
|
||||
close(t.Inbound)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (t *chanTUN) Read(bufs [][]byte, sizes []int, offset int) (int, error) {
|
||||
select {
|
||||
case <-t.closed:
|
||||
return 0, io.EOF
|
||||
case pkt := <-t.Outbound:
|
||||
sizes[0] = copy(bufs[0][offset:], pkt)
|
||||
return 1, nil
|
||||
}
|
||||
}
|
||||
|
||||
func (t *chanTUN) Write(bufs [][]byte, offset int) (int, error) {
|
||||
for _, buf := range bufs {
|
||||
pkt := buf[offset:]
|
||||
if len(pkt) == 0 {
|
||||
continue
|
||||
}
|
||||
select {
|
||||
case <-t.closed:
|
||||
return 0, errors.New("closed")
|
||||
case t.Inbound <- slices.Clone(pkt):
|
||||
}
|
||||
}
|
||||
return len(bufs), nil
|
||||
}
|
||||
|
||||
func (t *chanTUN) MTU() (int, error) { return 1280, nil }
|
||||
func (t *chanTUN) Name() (string, error) { return "chantun", nil }
|
||||
func (t *chanTUN) Events() <-chan tun.Event { return t.events }
|
||||
func (t *chanTUN) BatchSize() int { return 1 }
|
||||
|
||||
// listenTest provides common setup for listener and TUN tests.
|
||||
type listenTest struct {
|
||||
s1, s2 *Server
|
||||
s1ip4, s1ip6 netip.Addr
|
||||
s2ip4, s2ip6 netip.Addr
|
||||
tun *chanTUN // nil for netstack mode
|
||||
}
|
||||
|
||||
// setupListenTest creates two tsnet servers for testing.
|
||||
// If useTUN is true, s2 uses a chanTUN; otherwise it uses netstack only.
|
||||
func setupListenTest(t *testing.T, useTUN bool) *listenTest {
|
||||
t.Helper()
|
||||
tstest.Shard(t)
|
||||
tstest.ResourceCheck(t)
|
||||
ctx := t.Context()
|
||||
controlURL, _ := startControl(t)
|
||||
s1, _, _ := startServer(t, ctx, controlURL, "s1")
|
||||
|
||||
tmp := filepath.Join(t.TempDir(), "s2")
|
||||
must.Do(os.MkdirAll(tmp, 0755))
|
||||
s2 := &Server{
|
||||
Dir: tmp,
|
||||
ControlURL: controlURL,
|
||||
Hostname: "s2",
|
||||
Store: new(mem.Store),
|
||||
Ephemeral: true,
|
||||
}
|
||||
|
||||
var tun *chanTUN
|
||||
if useTUN {
|
||||
tun = newChanTUN()
|
||||
s2.Tun = tun
|
||||
}
|
||||
|
||||
if *verboseNodes {
|
||||
s2.Logf = t.Logf
|
||||
}
|
||||
t.Cleanup(func() { s2.Close() })
|
||||
|
||||
s2status, err := s2.Up(ctx)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
s1ip4, s1ip6 := s1.TailscaleIPs()
|
||||
s2ip4 := s2status.TailscaleIPs[0]
|
||||
var s2ip6 netip.Addr
|
||||
if len(s2status.TailscaleIPs) > 1 {
|
||||
s2ip6 = s2status.TailscaleIPs[1]
|
||||
}
|
||||
|
||||
lc1 := must.Get(s1.LocalClient())
|
||||
must.Get(lc1.Ping(ctx, s2ip4, tailcfg.PingTSMP))
|
||||
|
||||
return &listenTest{
|
||||
s1: s1,
|
||||
s2: s2,
|
||||
s1ip4: s1ip4,
|
||||
s1ip6: s1ip6,
|
||||
s2ip4: s2ip4,
|
||||
s2ip6: s2ip6,
|
||||
tun: tun,
|
||||
}
|
||||
}
|
||||
|
||||
// echoUDP returns an IP packet with src/dst and ports swapped, with checksums recomputed.
|
||||
func echoUDP(pkt []byte) []byte {
|
||||
var p packet.Parsed
|
||||
p.Decode(pkt)
|
||||
if p.IPProto != ipproto.UDP {
|
||||
return nil
|
||||
}
|
||||
switch p.IPVersion {
|
||||
case 4:
|
||||
h := p.UDP4Header()
|
||||
h.ToResponse()
|
||||
return packet.Generate(h, p.Payload())
|
||||
case 6:
|
||||
h := packet.UDP6Header{
|
||||
IP6Header: p.IP6Header(),
|
||||
SrcPort: p.Src.Port(),
|
||||
DstPort: p.Dst.Port(),
|
||||
}
|
||||
h.ToResponse()
|
||||
return packet.Generate(h, p.Payload())
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func TestTUN(t *testing.T) {
|
||||
tt := setupListenTest(t, true)
|
||||
|
||||
go func() {
|
||||
for pkt := range tt.tun.Inbound {
|
||||
var p packet.Parsed
|
||||
p.Decode(pkt)
|
||||
if p.Dst.Port() == 9999 {
|
||||
tt.tun.Outbound <- echoUDP(pkt)
|
||||
}
|
||||
}
|
||||
}()
|
||||
|
||||
test := func(t *testing.T, s2ip netip.Addr) {
|
||||
conn, err := tt.s1.Dial(t.Context(), "udp", netip.AddrPortFrom(s2ip, 9999).String())
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
defer conn.Close()
|
||||
|
||||
want := "hello from s1"
|
||||
if _, err := conn.Write([]byte(want)); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
conn.SetReadDeadline(time.Now().Add(5 * time.Second))
|
||||
got := make([]byte, 1024)
|
||||
n, err := conn.Read(got)
|
||||
if err != nil {
|
||||
t.Fatalf("reading echo response: %v", err)
|
||||
}
|
||||
if string(got[:n]) != want {
|
||||
t.Errorf("got %q, want %q", got[:n], want)
|
||||
}
|
||||
}
|
||||
|
||||
t.Run("IPv4", func(t *testing.T) { test(t, tt.s2ip4) })
|
||||
t.Run("IPv6", func(t *testing.T) { test(t, tt.s2ip6) })
|
||||
}
|
||||
|
||||
// TestTUNDNS tests that a TUN can send DNS queries to quad-100 and receive
|
||||
// responses. This verifies that handleLocalPackets intercepts outbound traffic
|
||||
// to the service IP.
|
||||
func TestTUNDNS(t *testing.T) {
|
||||
tt := setupListenTest(t, true)
|
||||
|
||||
test := func(t *testing.T, srcIP netip.Addr, serviceIP netip.Addr) {
|
||||
tt.tun.Outbound <- buildDNSQuery("s2", srcIP)
|
||||
|
||||
ipVersion := uint8(4)
|
||||
if srcIP.Is6() {
|
||||
ipVersion = 6
|
||||
}
|
||||
for {
|
||||
select {
|
||||
case pkt := <-tt.tun.Inbound:
|
||||
var p packet.Parsed
|
||||
p.Decode(pkt)
|
||||
if p.IPVersion != ipVersion || p.IPProto != ipproto.UDP {
|
||||
continue
|
||||
}
|
||||
if p.Src.Addr() == serviceIP && p.Src.Port() == 53 {
|
||||
if len(p.Payload()) < 12 {
|
||||
t.Fatalf("DNS response too short: %d bytes", len(p.Payload()))
|
||||
}
|
||||
return // success
|
||||
}
|
||||
case <-time.After(5 * time.Second):
|
||||
t.Fatal("timeout waiting for DNS response")
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
t.Run("IPv4", func(t *testing.T) {
|
||||
test(t, tt.s2ip4, netip.MustParseAddr("100.100.100.100"))
|
||||
})
|
||||
t.Run("IPv6", func(t *testing.T) {
|
||||
test(t, tt.s2ip6, netip.MustParseAddr("fd7a:115c:a1e0::53"))
|
||||
})
|
||||
}
|
||||
|
||||
// TestListenPacket tests UDP listeners (ListenPacket) in both netstack and TUN modes.
|
||||
func TestListenPacket(t *testing.T) {
|
||||
testListenPacket := func(t *testing.T, lt *listenTest, listenIP netip.Addr) {
|
||||
pc, err := lt.s2.ListenPacket("udp", netip.AddrPortFrom(listenIP, 0).String())
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
defer pc.Close()
|
||||
|
||||
echoErr := make(chan error, 1)
|
||||
go func() {
|
||||
buf := make([]byte, 1500)
|
||||
n, addr, err := pc.ReadFrom(buf)
|
||||
if err != nil {
|
||||
echoErr <- err
|
||||
return
|
||||
}
|
||||
_, err = pc.WriteTo(buf[:n], addr)
|
||||
if err != nil {
|
||||
echoErr <- err
|
||||
return
|
||||
}
|
||||
}()
|
||||
|
||||
conn, err := lt.s1.Dial(t.Context(), "udp", pc.LocalAddr().String())
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
defer conn.Close()
|
||||
|
||||
want := "hello udp"
|
||||
if _, err := conn.Write([]byte(want)); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
conn.SetReadDeadline(time.Now().Add(5 * time.Second))
|
||||
got := make([]byte, 1024)
|
||||
n, err := conn.Read(got)
|
||||
if err != nil {
|
||||
select {
|
||||
case e := <-echoErr:
|
||||
t.Fatalf("echo error: %v; read error: %v", e, err)
|
||||
default:
|
||||
t.Fatalf("Read failed: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
if string(got[:n]) != want {
|
||||
t.Errorf("got %q, want %q", got[:n], want)
|
||||
}
|
||||
}
|
||||
|
||||
t.Run("Netstack", func(t *testing.T) {
|
||||
lt := setupListenTest(t, false)
|
||||
t.Run("IPv4", func(t *testing.T) { testListenPacket(t, lt, lt.s2ip4) })
|
||||
t.Run("IPv6", func(t *testing.T) { testListenPacket(t, lt, lt.s2ip6) })
|
||||
})
|
||||
|
||||
t.Run("TUN", func(t *testing.T) {
|
||||
lt := setupListenTest(t, true)
|
||||
t.Run("IPv4", func(t *testing.T) { testListenPacket(t, lt, lt.s2ip4) })
|
||||
t.Run("IPv6", func(t *testing.T) { testListenPacket(t, lt, lt.s2ip6) })
|
||||
})
|
||||
}
|
||||
|
||||
// TestListenTCP tests TCP listeners with concrete addresses in both netstack
|
||||
// and TUN modes.
|
||||
func TestListenTCP(t *testing.T) {
|
||||
testListenTCP := func(t *testing.T, lt *listenTest, listenIP netip.Addr) {
|
||||
ln, err := lt.s2.Listen("tcp", netip.AddrPortFrom(listenIP, 0).String())
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
defer ln.Close()
|
||||
|
||||
echoErr := make(chan error, 1)
|
||||
go func() {
|
||||
conn, err := ln.Accept()
|
||||
if err != nil {
|
||||
echoErr <- err
|
||||
return
|
||||
}
|
||||
defer conn.Close()
|
||||
buf := make([]byte, 1024)
|
||||
n, err := conn.Read(buf)
|
||||
if err != nil {
|
||||
echoErr <- err
|
||||
return
|
||||
}
|
||||
_, err = conn.Write(buf[:n])
|
||||
if err != nil {
|
||||
echoErr <- err
|
||||
return
|
||||
}
|
||||
}()
|
||||
|
||||
conn, err := lt.s1.Dial(t.Context(), "tcp", ln.Addr().String())
|
||||
if err != nil {
|
||||
t.Fatalf("Dial failed: %v", err)
|
||||
}
|
||||
defer conn.Close()
|
||||
|
||||
want := "hello tcp"
|
||||
if _, err := conn.Write([]byte(want)); err != nil {
|
||||
t.Fatalf("Write failed: %v", err)
|
||||
}
|
||||
|
||||
conn.SetReadDeadline(time.Now().Add(5 * time.Second))
|
||||
got := make([]byte, 1024)
|
||||
n, err := conn.Read(got)
|
||||
if err != nil {
|
||||
select {
|
||||
case e := <-echoErr:
|
||||
t.Fatalf("echo error: %v; read error: %v", e, err)
|
||||
default:
|
||||
t.Fatalf("Read failed: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
if string(got[:n]) != want {
|
||||
t.Errorf("got %q, want %q", got[:n], want)
|
||||
}
|
||||
}
|
||||
|
||||
t.Run("Netstack", func(t *testing.T) {
|
||||
lt := setupListenTest(t, false)
|
||||
t.Run("IPv4", func(t *testing.T) { testListenTCP(t, lt, lt.s2ip4) })
|
||||
t.Run("IPv6", func(t *testing.T) { testListenTCP(t, lt, lt.s2ip6) })
|
||||
})
|
||||
|
||||
t.Run("TUN", func(t *testing.T) {
|
||||
lt := setupListenTest(t, true)
|
||||
t.Run("IPv4", func(t *testing.T) { testListenTCP(t, lt, lt.s2ip4) })
|
||||
t.Run("IPv6", func(t *testing.T) { testListenTCP(t, lt, lt.s2ip6) })
|
||||
})
|
||||
}
|
||||
|
||||
// TestListenTCPDualStack tests TCP listeners with wildcard addresses (dual-stack)
|
||||
// in both netstack and TUN modes.
|
||||
func TestListenTCPDualStack(t *testing.T) {
|
||||
testListenTCPDualStack := func(t *testing.T, lt *listenTest, dialIP netip.Addr) {
|
||||
ln, err := lt.s2.Listen("tcp", ":0")
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
defer ln.Close()
|
||||
|
||||
_, portStr, err := net.SplitHostPort(ln.Addr().String())
|
||||
if err != nil {
|
||||
t.Fatalf("parsing listener address %q: %v", ln.Addr().String(), err)
|
||||
}
|
||||
|
||||
echoErr := make(chan error, 1)
|
||||
go func() {
|
||||
conn, err := ln.Accept()
|
||||
if err != nil {
|
||||
echoErr <- err
|
||||
return
|
||||
}
|
||||
defer conn.Close()
|
||||
buf := make([]byte, 1024)
|
||||
n, err := conn.Read(buf)
|
||||
if err != nil {
|
||||
echoErr <- err
|
||||
return
|
||||
}
|
||||
_, err = conn.Write(buf[:n])
|
||||
if err != nil {
|
||||
echoErr <- err
|
||||
return
|
||||
}
|
||||
}()
|
||||
|
||||
dialAddr := net.JoinHostPort(dialIP.String(), portStr)
|
||||
conn, err := lt.s1.Dial(t.Context(), "tcp", dialAddr)
|
||||
if err != nil {
|
||||
t.Fatalf("Dial(%q) failed: %v", dialAddr, err)
|
||||
}
|
||||
defer conn.Close()
|
||||
|
||||
want := "hello tcp dualstack"
|
||||
if _, err := conn.Write([]byte(want)); err != nil {
|
||||
t.Fatalf("Write failed: %v", err)
|
||||
}
|
||||
|
||||
conn.SetReadDeadline(time.Now().Add(5 * time.Second))
|
||||
got := make([]byte, 1024)
|
||||
n, err := conn.Read(got)
|
||||
if err != nil {
|
||||
select {
|
||||
case e := <-echoErr:
|
||||
t.Fatalf("echo error: %v; read error: %v", e, err)
|
||||
default:
|
||||
t.Fatalf("Read failed: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
if string(got[:n]) != want {
|
||||
t.Errorf("got %q, want %q", got[:n], want)
|
||||
}
|
||||
}
|
||||
|
||||
t.Run("Netstack", func(t *testing.T) {
|
||||
lt := setupListenTest(t, false)
|
||||
t.Run("DialIPv4", func(t *testing.T) { testListenTCPDualStack(t, lt, lt.s2ip4) })
|
||||
t.Run("DialIPv6", func(t *testing.T) { testListenTCPDualStack(t, lt, lt.s2ip6) })
|
||||
})
|
||||
|
||||
t.Run("TUN", func(t *testing.T) {
|
||||
lt := setupListenTest(t, true)
|
||||
t.Run("DialIPv4", func(t *testing.T) { testListenTCPDualStack(t, lt, lt.s2ip4) })
|
||||
t.Run("DialIPv6", func(t *testing.T) { testListenTCPDualStack(t, lt, lt.s2ip6) })
|
||||
})
|
||||
}
|
||||
|
||||
// TestDialTCP tests TCP dialing from s2 to s1 in both netstack and TUN modes.
|
||||
// In TUN mode, this verifies that outbound TCP connections and their replies
|
||||
// are handled by netstack without packets escaping to the TUN.
|
||||
func TestDialTCP(t *testing.T) {
|
||||
testDialTCP := func(t *testing.T, lt *listenTest, listenIP netip.Addr) {
|
||||
ln, err := lt.s1.Listen("tcp", netip.AddrPortFrom(listenIP, 0).String())
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
defer ln.Close()
|
||||
|
||||
echoErr := make(chan error, 1)
|
||||
go func() {
|
||||
conn, err := ln.Accept()
|
||||
if err != nil {
|
||||
echoErr <- err
|
||||
return
|
||||
}
|
||||
defer conn.Close()
|
||||
buf := make([]byte, 1024)
|
||||
n, err := conn.Read(buf)
|
||||
if err != nil {
|
||||
echoErr <- err
|
||||
return
|
||||
}
|
||||
_, err = conn.Write(buf[:n])
|
||||
if err != nil {
|
||||
echoErr <- err
|
||||
return
|
||||
}
|
||||
}()
|
||||
|
||||
conn, err := lt.s2.Dial(t.Context(), "tcp", ln.Addr().String())
|
||||
if err != nil {
|
||||
t.Fatalf("Dial failed: %v", err)
|
||||
}
|
||||
defer conn.Close()
|
||||
|
||||
want := "hello tcp dial"
|
||||
if _, err := conn.Write([]byte(want)); err != nil {
|
||||
t.Fatalf("Write failed: %v", err)
|
||||
}
|
||||
|
||||
conn.SetReadDeadline(time.Now().Add(5 * time.Second))
|
||||
got := make([]byte, 1024)
|
||||
n, err := conn.Read(got)
|
||||
if err != nil {
|
||||
select {
|
||||
case e := <-echoErr:
|
||||
t.Fatalf("echo error: %v; read error: %v", e, err)
|
||||
default:
|
||||
t.Fatalf("Read failed: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
if string(got[:n]) != want {
|
||||
t.Errorf("got %q, want %q", got[:n], want)
|
||||
}
|
||||
}
|
||||
|
||||
t.Run("Netstack", func(t *testing.T) {
|
||||
lt := setupListenTest(t, false)
|
||||
t.Run("IPv4", func(t *testing.T) { testDialTCP(t, lt, lt.s1ip4) })
|
||||
t.Run("IPv6", func(t *testing.T) { testDialTCP(t, lt, lt.s1ip6) })
|
||||
})
|
||||
|
||||
t.Run("TUN", func(t *testing.T) {
|
||||
lt := setupListenTest(t, true)
|
||||
|
||||
var escapedTCPPackets atomic.Int32
|
||||
var wg sync.WaitGroup
|
||||
wg.Go(func() {
|
||||
for pkt := range lt.tun.Inbound {
|
||||
var p packet.Parsed
|
||||
p.Decode(pkt)
|
||||
if p.IPProto == ipproto.TCP {
|
||||
escapedTCPPackets.Add(1)
|
||||
t.Logf("TCP packet escaped to TUN: %v -> %v", p.Src, p.Dst)
|
||||
}
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("IPv4", func(t *testing.T) { testDialTCP(t, lt, lt.s1ip4) })
|
||||
t.Run("IPv6", func(t *testing.T) { testDialTCP(t, lt, lt.s1ip6) })
|
||||
|
||||
lt.tun.Close()
|
||||
wg.Wait()
|
||||
if escaped := escapedTCPPackets.Load(); escaped > 0 {
|
||||
t.Errorf("%d TCP packets escaped to TUN", escaped)
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
// TestDialUDP tests UDP dialing from s2 to s1 in both netstack and TUN modes.
|
||||
// In TUN mode, this verifies that outbound UDP connections register endpoints
|
||||
// with gVisor, allowing reply packets to be routed through netstack instead of
|
||||
// escaping to the TUN.
|
||||
func TestDialUDP(t *testing.T) {
|
||||
testDialUDP := func(t *testing.T, lt *listenTest, listenIP netip.Addr) {
|
||||
pc, err := lt.s1.ListenPacket("udp", netip.AddrPortFrom(listenIP, 0).String())
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
defer pc.Close()
|
||||
|
||||
echoErr := make(chan error, 1)
|
||||
go func() {
|
||||
buf := make([]byte, 1500)
|
||||
n, addr, err := pc.ReadFrom(buf)
|
||||
if err != nil {
|
||||
echoErr <- err
|
||||
return
|
||||
}
|
||||
_, err = pc.WriteTo(buf[:n], addr)
|
||||
if err != nil {
|
||||
echoErr <- err
|
||||
return
|
||||
}
|
||||
}()
|
||||
|
||||
conn, err := lt.s2.Dial(t.Context(), "udp", pc.LocalAddr().String())
|
||||
if err != nil {
|
||||
t.Fatalf("Dial failed: %v", err)
|
||||
}
|
||||
defer conn.Close()
|
||||
|
||||
want := "hello udp dial"
|
||||
if _, err := conn.Write([]byte(want)); err != nil {
|
||||
t.Fatalf("Write failed: %v", err)
|
||||
}
|
||||
|
||||
conn.SetReadDeadline(time.Now().Add(5 * time.Second))
|
||||
got := make([]byte, 1024)
|
||||
n, err := conn.Read(got)
|
||||
if err != nil {
|
||||
select {
|
||||
case e := <-echoErr:
|
||||
t.Fatalf("echo error: %v; read error: %v", e, err)
|
||||
default:
|
||||
t.Fatalf("Read failed: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
if string(got[:n]) != want {
|
||||
t.Errorf("got %q, want %q", got[:n], want)
|
||||
}
|
||||
}
|
||||
|
||||
t.Run("Netstack", func(t *testing.T) {
|
||||
lt := setupListenTest(t, false)
|
||||
t.Run("IPv4", func(t *testing.T) { testDialUDP(t, lt, lt.s1ip4) })
|
||||
t.Run("IPv6", func(t *testing.T) { testDialUDP(t, lt, lt.s1ip6) })
|
||||
})
|
||||
|
||||
t.Run("TUN", func(t *testing.T) {
|
||||
lt := setupListenTest(t, true)
|
||||
|
||||
var escapedUDPPackets atomic.Int32
|
||||
var wg sync.WaitGroup
|
||||
wg.Go(func() {
|
||||
for pkt := range lt.tun.Inbound {
|
||||
var p packet.Parsed
|
||||
p.Decode(pkt)
|
||||
if p.IPProto == ipproto.UDP {
|
||||
escapedUDPPackets.Add(1)
|
||||
t.Logf("UDP packet escaped to TUN: %v -> %v", p.Src, p.Dst)
|
||||
}
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("IPv4", func(t *testing.T) { testDialUDP(t, lt, lt.s1ip4) })
|
||||
t.Run("IPv6", func(t *testing.T) { testDialUDP(t, lt, lt.s1ip6) })
|
||||
|
||||
lt.tun.Close()
|
||||
wg.Wait()
|
||||
if escaped := escapedUDPPackets.Load(); escaped > 0 {
|
||||
t.Errorf("%d UDP packets escaped to TUN", escaped)
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
// buildDNSQuery builds a UDP/IP packet containing a DNS query for name to the
|
||||
// Tailscale service IP (100.100.100.100 for IPv4, fd7a:115c:a1e0::53 for IPv6).
|
||||
func buildDNSQuery(name string, srcIP netip.Addr) []byte {
|
||||
qtype := byte(0x01) // Type A for IPv4
|
||||
if srcIP.Is6() {
|
||||
qtype = 0x1c // Type AAAA for IPv6
|
||||
}
|
||||
dns := []byte{
|
||||
0x12, 0x34, // ID
|
||||
0x01, 0x00, // Flags: standard query, recursion desired
|
||||
0x00, 0x01, // QDCOUNT: 1
|
||||
0x00, 0x00, 0x00, 0x00, 0x00, 0x00, // ANCOUNT, NSCOUNT, ARCOUNT
|
||||
}
|
||||
for _, label := range strings.Split(name, ".") {
|
||||
dns = append(dns, byte(len(label)))
|
||||
dns = append(dns, label...)
|
||||
}
|
||||
dns = append(dns, 0x00, 0x00, qtype, 0x00, 0x01) // null, Type A/AAAA, Class IN
|
||||
|
||||
if srcIP.Is4() {
|
||||
h := packet.UDP4Header{
|
||||
IP4Header: packet.IP4Header{
|
||||
Src: srcIP,
|
||||
Dst: netip.MustParseAddr("100.100.100.100"),
|
||||
},
|
||||
SrcPort: 12345,
|
||||
DstPort: 53,
|
||||
}
|
||||
return packet.Generate(h, dns)
|
||||
}
|
||||
h := packet.UDP6Header{
|
||||
IP6Header: packet.IP6Header{
|
||||
Src: srcIP,
|
||||
Dst: netip.MustParseAddr("fd7a:115c:a1e0::53"),
|
||||
},
|
||||
SrcPort: 12345,
|
||||
DstPort: 53,
|
||||
}
|
||||
return packet.Generate(h, dns)
|
||||
}
|
||||
|
||||
func TestDeps(t *testing.T) {
|
||||
tstest.Shard(t)
|
||||
deptest.DepChecker{
|
||||
|
||||
Reference in New Issue
Block a user