diff --git a/tsnet/tsnet.go b/tsnet/tsnet.go index 416c90750..776854e22 100644 --- a/tsnet/tsnet.go +++ b/tsnet/tsnet.go @@ -196,6 +196,7 @@ type Server struct { mu sync.Mutex listeners map[listenKey]*listener + nextEphemeralPort uint16 // next port to try in ephemeral range; 0 means use ephemeralPortFirst fallbackTCPHandlers set.HandleSet[FallbackTCPHandler] dialer *tsdial.Dialer closeOnce sync.Once @@ -1099,16 +1100,27 @@ func (s *Server) ListenPacket(network, addr string) (net.PacketConn, error) { network = "udp6" } } + if err := s.Start(); err != nil { + return nil, err + } - netLn, err := s.listen(network, addr, listenOnTailnet) + // Create the gVisor PacketConn first so it can handle port 0 allocation. + pc, err := s.netstack.ListenPacket(network, ap.String()) if err != nil { return nil, err } - ln := netLn.(*listener) - pc, err := s.netstack.ListenPacket(network, ap.String()) + // If port 0 was requested, use the port gVisor assigned. + if ap.Port() == 0 { + if p := portFromAddr(pc.LocalAddr()); p != 0 { + ap = netip.AddrPortFrom(ap.Addr(), p) + addr = ap.String() + } + } + + ln, err := s.registerListener(network, addr, ap, listenOnTailnet, nil) if err != nil { - ln.Close() + pc.Close() return nil, err } @@ -1621,6 +1633,11 @@ func resolveListenAddr(network, addr string) (netip.AddrPort, error) { if err != nil { return zero, fmt.Errorf("invalid Listen addr %q; host part must be empty or IP literal", host) } + // Normalize unspecified addresses (0.0.0.0, ::) to the zero value, + // equivalent to an empty host, so they match the node's own IPs. + if bindHostOrZero.IsUnspecified() { + return netip.AddrPortFrom(netip.Addr{}, uint16(port)), nil + } if strings.HasSuffix(network, "4") && !bindHostOrZero.Is4() { return zero, fmt.Errorf("invalid non-IPv4 addr %v for network %q", host, network) } @@ -1630,6 +1647,17 @@ func resolveListenAddr(network, addr string) (netip.AddrPort, error) { return netip.AddrPortFrom(bindHostOrZero, uint16(port)), nil } +// ephemeral port range for non-TUN listeners requesting port 0. This range is +// chosen to reduce the probability of collision with host listeners, avoiding +// both the typical ephemeral range, and privilege listener ranges. Collisions +// may still occur and could for example shadow host sockets in a netstack+TUN +// situation, the range here is a UX improvement, not a guarantee that +// application authors will never have to consider these cases. +const ( + ephemeralPortFirst = 10002 + ephemeralPortLast = 19999 +) + func (s *Server) listen(network, addr string, lnOn listenOn) (net.Listener, error) { switch network { case "", "tcp", "tcp4", "tcp6", "udp", "udp4", "udp6": @@ -1643,6 +1671,76 @@ func (s *Server) listen(network, addr string, lnOn listenOn) (net.Listener, erro if err := s.Start(); err != nil { return nil, err } + + isTCP := network == "" || network == "tcp" || network == "tcp4" || network == "tcp6" + + // When using a TUN with TCP, create a gVisor TCP listener. + // gVisor handles port 0 allocation natively. + var gonetLn net.Listener + if s.Tun != nil && isTCP { + gonetLn, err = s.listenTCP(network, host) + if err != nil { + return nil, err + } + // If port 0 was requested, update host to the port gVisor assigned + // so that the listenKey uses the real port. + if host.Port() == 0 { + if p := portFromAddr(gonetLn.Addr()); p != 0 { + host = netip.AddrPortFrom(host.Addr(), p) + addr = listenAddr(host) + } + } + } + + ln, err := s.registerListener(network, addr, host, lnOn, gonetLn) + if err != nil { + if gonetLn != nil { + gonetLn.Close() + } + return nil, err + } + return ln, nil +} + +// listenTCP creates a gVisor TCP listener for TUN mode. +func (s *Server) listenTCP(network string, host netip.AddrPort) (net.Listener, error) { + var nsNetwork string + nsAddr := host + switch { + case network == "tcp4" || network == "tcp6": + nsNetwork = network + case host.Addr().Is4(): + nsNetwork = "tcp4" + case host.Addr().Is6(): + nsNetwork = "tcp6" + default: + // Wildcard address: use tcp6 for dual-stack (accepts both v4 and v6). + nsNetwork = "tcp6" + nsAddr = netip.AddrPortFrom(netip.IPv6Unspecified(), host.Port()) + } + ln, err := s.netstack.ListenTCP(nsNetwork, nsAddr.String()) + if err != nil { + return nil, fmt.Errorf("tsnet: %w", err) + } + return ln, nil +} + +// registerListener allocates a port (if 0) and registers the listener in +// s.listeners under s.mu. +func (s *Server) registerListener(network, addr string, host netip.AddrPort, lnOn listenOn, gonetLn net.Listener) (*listener, error) { + s.mu.Lock() + defer s.mu.Unlock() + + // Allocate an ephemeral port for non-TUN listeners requesting port 0. + if host.Port() == 0 && gonetLn == nil { + p, ok := s.allocEphemeralLocked(network, host.Addr(), lnOn) + if !ok { + return nil, errors.New("tsnet: no available port in ephemeral range") + } + host = netip.AddrPortFrom(host.Addr(), p) + addr = listenAddr(host) + } + var keys []listenKey switch lnOn { case listenOnTailnet: @@ -1654,58 +1752,93 @@ func (s *Server) listen(network, addr string, lnOn listenOn) (net.Listener, erro keys = append(keys, listenKey{network, host.Addr(), host.Port(), true}) } - ln := &listener{ - s: s, - keys: keys, - addr: addr, - - closedc: make(chan struct{}), - conn: make(chan net.Conn), - } - - // When using a TUN with TCP, create a gVisor TCP listener. - if s.Tun != nil && (network == "" || network == "tcp" || network == "tcp4" || network == "tcp6") { - var nsNetwork string - nsAddr := host - switch { - case network == "tcp4" || network == "tcp6": - nsNetwork = network - case host.Addr().Is4(): - nsNetwork = "tcp4" - case host.Addr().Is6(): - nsNetwork = "tcp6" - default: - // Wildcard address: use tcp6 for dual-stack (accepts both v4 and v6). - nsNetwork = "tcp6" - nsAddr = netip.AddrPortFrom(netip.IPv6Unspecified(), host.Port()) - } - gonetLn, err := s.netstack.ListenTCP(nsNetwork, nsAddr.String()) - if err != nil { - return nil, fmt.Errorf("tsnet: %w", err) - } - ln.gonetLn = gonetLn - } - - s.mu.Lock() for _, key := range keys { if _, ok := s.listeners[key]; ok { - s.mu.Unlock() - if ln.gonetLn != nil { - ln.gonetLn.Close() - } return nil, fmt.Errorf("tsnet: listener already open for %s, %s", network, addr) } } + + ln := &listener{ + s: s, + keys: keys, + addr: addr, + closedc: make(chan struct{}), + conn: make(chan net.Conn), + gonetLn: gonetLn, + } if s.listeners == nil { s.listeners = make(map[listenKey]*listener) } for _, key := range keys { s.listeners[key] = ln } - s.mu.Unlock() return ln, nil } +// allocEphemeralLocked finds an unused port in [ephemeralPortFirst, +// ephemeralPortLast] that does not collide with any existing listener for the +// given network, host, and listenOn. s.mu must be held. +func (s *Server) allocEphemeralLocked(network string, host netip.Addr, lnOn listenOn) (uint16, bool) { + if s.nextEphemeralPort < ephemeralPortFirst || s.nextEphemeralPort > ephemeralPortLast { + s.nextEphemeralPort = ephemeralPortFirst + } + start := s.nextEphemeralPort + for { + p := s.nextEphemeralPort + s.nextEphemeralPort++ + if s.nextEphemeralPort > ephemeralPortLast { + s.nextEphemeralPort = ephemeralPortFirst + } + if !s.portInUseLocked(network, host, p, lnOn) { + return p, true + } + if s.nextEphemeralPort == start { + return 0, false + } + } +} + +// portInUseLocked reports whether any listenKey for the given network, host, +// port, and listenOn already exists in s.listeners. +func (s *Server) portInUseLocked(network string, host netip.Addr, port uint16, lnOn listenOn) bool { + switch lnOn { + case listenOnTailnet: + _, ok := s.listeners[listenKey{network, host, port, false}] + return ok + case listenOnFunnel: + _, ok := s.listeners[listenKey{network, host, port, true}] + return ok + case listenOnBoth: + _, ok1 := s.listeners[listenKey{network, host, port, false}] + _, ok2 := s.listeners[listenKey{network, host, port, true}] + return ok1 || ok2 + } + return false +} + +// listenAddr formats host as a listen address string. +// If host has no IP, it returns ":port". +func listenAddr(host netip.AddrPort) string { + if !host.Addr().IsValid() { + return ":" + strconv.Itoa(int(host.Port())) + } + return host.String() +} + +// portFromAddr extracts the port from a net.Addr, or returns 0. +func portFromAddr(a net.Addr) uint16 { + switch v := a.(type) { + case *net.TCPAddr: + return uint16(v.Port) + case *net.UDPAddr: + return uint16(v.Port) + } + if ap, err := netip.ParseAddrPort(a.String()); err == nil { + return ap.Port() + } + return 0 +} + // GetRootPath returns the root path of the tsnet server. // This is where the state file and other data is stored. func (s *Server) GetRootPath() string { diff --git a/tsnet/tsnet_test.go b/tsnet/tsnet_test.go index 9481defae..266a60f78 100644 --- a/tsnet/tsnet_test.go +++ b/tsnet/tsnet_test.go @@ -112,6 +112,86 @@ func TestListenerPort(t *testing.T) { } } +func TestResolveListenAddrUnspecified(t *testing.T) { + tests := []struct { + name string + network string + addr string + wantIP netip.Addr + }{ + {"empty_host", "tcp", ":80", netip.Addr{}}, + {"ipv4_unspecified", "tcp", "0.0.0.0:80", netip.Addr{}}, + {"ipv6_unspecified", "tcp", "[::]:80", netip.Addr{}}, + {"specific_ipv4", "tcp", "100.64.0.1:80", netip.MustParseAddr("100.64.0.1")}, + {"specific_ipv6", "tcp6", "[fd7a:115c:a1e0::1]:80", netip.MustParseAddr("fd7a:115c:a1e0::1")}, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got, err := resolveListenAddr(tt.network, tt.addr) + if err != nil { + t.Fatal(err) + } + if got.Addr() != tt.wantIP { + t.Errorf("Addr() = %v, want %v", got.Addr(), tt.wantIP) + } + }) + } +} + +func TestAllocEphemeral(t *testing.T) { + s := &Server{listeners: make(map[listenKey]*listener)} + + // Sequential allocations should return unique ports in range. + var ports []uint16 + for range 5 { + s.mu.Lock() + p, ok := s.allocEphemeralLocked("tcp", netip.Addr{}, listenOnTailnet) + s.mu.Unlock() + if !ok { + t.Fatal("allocEphemeralLocked failed unexpectedly") + } + if p < ephemeralPortFirst || p > ephemeralPortLast { + t.Errorf("port %d outside [%d, %d]", p, ephemeralPortFirst, ephemeralPortLast) + } + for _, prev := range ports { + if p == prev { + t.Errorf("duplicate port %d", p) + } + } + ports = append(ports, p) + // Occupy the port so the next call skips it. + s.listeners[listenKey{"tcp", netip.Addr{}, p, false}] = &listener{} + } + + // Verify skip over occupied port. + s.mu.Lock() + next := s.nextEphemeralPort + if next < ephemeralPortFirst || next > ephemeralPortLast { + next = ephemeralPortFirst + } + s.listeners[listenKey{"tcp", netip.Addr{}, next, false}] = &listener{} + p, ok := s.allocEphemeralLocked("tcp", netip.Addr{}, listenOnTailnet) + s.mu.Unlock() + if !ok { + t.Fatal("allocEphemeralLocked failed after skip") + } + if p == next { + t.Errorf("should have skipped occupied port %d", next) + } + + // Wrap-around. + s.mu.Lock() + s.nextEphemeralPort = ephemeralPortLast + p, ok = s.allocEphemeralLocked("tcp", netip.Addr{}, listenOnTailnet) + s.mu.Unlock() + if !ok { + t.Fatal("allocEphemeralLocked failed at wrap") + } + if p < ephemeralPortFirst || p > ephemeralPortLast { + t.Errorf("port %d outside range after wrap", p) + } +} + var verboseDERP = flag.Bool("verbose-derp", false, "if set, print DERP and STUN logs") var verboseNodes = flag.Bool("verbose-nodes", false, "if set, print tsnet.Server logs") @@ -2869,3 +2949,159 @@ func TestSelfDial(t *testing.T) { t.Errorf("server->client: got %q, want %q", gotReply, reply) } } + +// TestListenUnspecifiedAddr verifies that listening on 0.0.0.0 or [::] works +// the same as listening on an empty host (":port"), accepting connections +// destined to the node's Tailscale IPs. +func TestListenUnspecifiedAddr(t *testing.T) { + testUnspec := func(t *testing.T, lt *listenTest, addr, dialPort string) { + ln, err := lt.s2.Listen("tcp", addr) + if err != nil { + t.Fatal(err) + } + defer ln.Close() + + echoErr := make(chan error, 1) + go func() { + conn, err := ln.Accept() + if err != nil { + echoErr <- err + return + } + defer conn.Close() + buf := make([]byte, 1024) + n, err := conn.Read(buf) + if err != nil { + echoErr <- err + return + } + _, err = conn.Write(buf[:n]) + echoErr <- err + }() + + dialAddr := net.JoinHostPort(lt.s2ip4.String(), dialPort) + conn, err := lt.s1.Dial(t.Context(), "tcp", dialAddr) + if err != nil { + t.Fatalf("Dial(%q) failed: %v", dialAddr, err) + } + defer conn.Close() + want := "hello unspec" + if _, err := conn.Write([]byte(want)); err != nil { + t.Fatalf("Write failed: %v", err) + } + conn.SetReadDeadline(time.Now().Add(5 * time.Second)) + got := make([]byte, 1024) + n, err := conn.Read(got) + if err != nil { + t.Fatalf("Read failed: %v", err) + } + if string(got[:n]) != want { + t.Errorf("got %q, want %q", got[:n], want) + } + if err := <-echoErr; err != nil { + t.Fatalf("echo error: %v", err) + } + } + + t.Run("Netstack", func(t *testing.T) { + lt := setupTwoClientTest(t, false) + t.Run("0.0.0.0", func(t *testing.T) { testUnspec(t, lt, "0.0.0.0:8080", "8080") }) + t.Run("::", func(t *testing.T) { testUnspec(t, lt, "[::]:8081", "8081") }) + }) + t.Run("TUN", func(t *testing.T) { + lt := setupTwoClientTest(t, true) + t.Run("0.0.0.0", func(t *testing.T) { testUnspec(t, lt, "0.0.0.0:8080", "8080") }) + t.Run("::", func(t *testing.T) { testUnspec(t, lt, "[::]:8081", "8081") }) + }) +} + +// TestListenMultipleEphemeralPorts verifies that calling Listen with port 0 +// multiple times allocates distinct ports, each of which can receive +// connections independently. +func TestListenMultipleEphemeralPorts(t *testing.T) { + testMultipleEphemeral := func(t *testing.T, lt *listenTest) { + const n = 3 + listeners := make([]net.Listener, n) + ports := make([]string, n) + for i := range n { + ln, err := lt.s2.Listen("tcp", ":0") + if err != nil { + t.Fatal(err) + } + t.Cleanup(func() { ln.Close() }) + _, portStr, err := net.SplitHostPort(ln.Addr().String()) + if err != nil { + t.Fatalf("parsing Addr %q: %v", ln.Addr(), err) + } + if portStr == "0" { + t.Fatal("Addr() returned port 0; expected allocated port") + } + for j := range i { + if ports[j] == portStr { + t.Fatalf("listeners %d and %d both got port %s", j, i, portStr) + } + } + listeners[i] = ln + ports[i] = portStr + } + + // Verify each listener independently accepts connections. + for i := range n { + echoErr := make(chan error, 1) + go func() { + conn, err := listeners[i].Accept() + if err != nil { + echoErr <- err + return + } + defer conn.Close() + buf := make([]byte, 1024) + rn, err := conn.Read(buf) + if err != nil { + echoErr <- err + return + } + _, err = conn.Write(buf[:rn]) + echoErr <- err + }() + + dialAddr := net.JoinHostPort(lt.s2ip4.String(), ports[i]) + conn, err := lt.s1.Dial(t.Context(), "tcp", dialAddr) + if err != nil { + t.Fatalf("listener %d: Dial(%q) failed: %v", i, dialAddr, err) + } + want := fmt.Sprintf("hello port %d", i) + if _, err := conn.Write([]byte(want)); err != nil { + conn.Close() + t.Fatalf("listener %d: Write failed: %v", i, err) + } + conn.SetReadDeadline(time.Now().Add(5 * time.Second)) + got := make([]byte, 1024) + rn, err := conn.Read(got) + conn.Close() + if err != nil { + select { + case e := <-echoErr: + t.Fatalf("listener %d: echo error: %v; read error: %v", i, e, err) + default: + t.Fatalf("listener %d: Read failed: %v", i, err) + } + } + if string(got[:rn]) != want { + t.Errorf("listener %d: got %q, want %q", i, got[:rn], want) + } + if err := <-echoErr; err != nil { + t.Fatalf("listener %d: echo error: %v", i, err) + } + } + } + + t.Run("Netstack", func(t *testing.T) { + lt := setupTwoClientTest(t, false) + testMultipleEphemeral(t, lt) + }) + t.Run("TUN", func(t *testing.T) { + lt := setupTwoClientTest(t, true) + testMultipleEphemeral(t, lt) + }) +}