tsnet: add support for a user-supplied tun.Device

tsnet users can now provide a tun.Device, including any custom
implementation that conforms to the interface.

netstack has a new option CheckLocalTransportEndpoints that when used
alongside a TUN enables netstack listens and dials to correctly capture
traffic associated with those sockets. tsnet with a TUN sets this
option, while all other builds leave this at false to preserve existing
performance.

Updates #18423

Signed-off-by: James Tucker <james@tailscale.com>
This commit is contained in:
James Tucker
2026-01-15 20:35:41 -08:00
committed by James Tucker
parent c062230cce
commit 63d563e734
3 changed files with 842 additions and 5 deletions
+85 -1
View File
@@ -165,6 +165,17 @@ type Impl struct {
// over the UDP flow.
GetUDPHandlerForFlow func(src, dst netip.AddrPort) (handler func(nettype.ConnPacketConn), intercept bool)
// CheckLocalTransportEndpoints, if true, causes netstack to check if gVisor
// has a registered endpoint for incoming packets to local IPs. This is used
// by tsnet to intercept packets for registered listeners and outbound
// connections when ProcessLocalIPs is false (i.e., when using a TUN).
// It can only be set before calling Start.
// TODO(raggi): refactor the way we handle both CheckLocalTransportEndpoints
// and the earlier netstack registrations for serve, funnel, peerAPI and so
// on. Currently this optimizes away cost for tailscaled in TUN mode, while
// enabling extension support when using tsnet in TUN mode. See #18423.
CheckLocalTransportEndpoints bool
// ProcessLocalIPs is whether netstack should handle incoming
// traffic directed at the Node.Addresses (local IPs).
// It can only be set before calling Start.
@@ -1109,6 +1120,45 @@ func (ns *Impl) shouldProcessInbound(p *packet.Parsed, t *tstun.Wrapper) bool {
if ns.ProcessSubnets && !isLocal {
return true
}
if isLocal && ns.CheckLocalTransportEndpoints {
// Handle packets to registered listeners and replies to outbound
// connections by checking if gVisor has a registered endpoint.
// This covers TCP listeners, UDP listeners, and outbound TCP replies.
if p.IPProto == ipproto.TCP || p.IPProto == ipproto.UDP {
var netProto tcpip.NetworkProtocolNumber
var id stack.TransportEndpointID
if p.Dst.Addr().Is4() {
netProto = ipv4.ProtocolNumber
id = stack.TransportEndpointID{
LocalAddress: tcpip.AddrFrom4(p.Dst.Addr().As4()),
LocalPort: p.Dst.Port(),
RemoteAddress: tcpip.AddrFrom4(p.Src.Addr().As4()),
RemotePort: p.Src.Port(),
}
} else {
netProto = ipv6.ProtocolNumber
id = stack.TransportEndpointID{
LocalAddress: tcpip.AddrFrom16(p.Dst.Addr().As16()),
LocalPort: p.Dst.Port(),
RemoteAddress: tcpip.AddrFrom16(p.Src.Addr().As16()),
RemotePort: p.Src.Port(),
}
}
var transProto tcpip.TransportProtocolNumber
if p.IPProto == ipproto.TCP {
transProto = tcp.ProtocolNumber
} else {
transProto = udp.ProtocolNumber
}
ep := ns.ipstack.FindTransportEndpoint(netProto, transProto, id, nicID)
if debugNetstack() {
ns.logf("[v2] FindTransportEndpoint: id=%+v found=%v", id, ep != nil)
}
if ep != nil {
return true
}
}
}
return false
}
@@ -1575,7 +1625,7 @@ func (ns *Impl) forwardTCP(getClient func(...tcpip.SettableSocketOption) *gonet.
func (ns *Impl) ListenPacket(network, address string) (net.PacketConn, error) {
ap, err := netip.ParseAddrPort(address)
if err != nil {
return nil, fmt.Errorf("netstack: ParseAddrPort(%q): %v", address, err)
return nil, fmt.Errorf("netstack: ParseAddrPort(%q): %w", address, err)
}
var networkProto tcpip.NetworkProtocolNumber
@@ -1612,6 +1662,40 @@ func (ns *Impl) ListenPacket(network, address string) (net.PacketConn, error) {
return gonet.NewUDPConn(&wq, ep), nil
}
// ListenTCP listens for TCP connections on the given address.
func (ns *Impl) ListenTCP(network, address string) (*gonet.TCPListener, error) {
ap, err := netip.ParseAddrPort(address)
if err != nil {
return nil, fmt.Errorf("netstack: ParseAddrPort(%q): %w", address, err)
}
var networkProto tcpip.NetworkProtocolNumber
switch network {
case "tcp4":
networkProto = ipv4.ProtocolNumber
if ap.Addr().IsValid() && !ap.Addr().Is4() {
return nil, fmt.Errorf("netstack: tcp4 requires an IPv4 address")
}
case "tcp6":
networkProto = ipv6.ProtocolNumber
if ap.Addr().IsValid() && !ap.Addr().Is6() {
return nil, fmt.Errorf("netstack: tcp6 requires an IPv6 address")
}
default:
return nil, fmt.Errorf("netstack: unsupported network %q", network)
}
localAddress := tcpip.FullAddress{
NIC: nicID,
Port: ap.Port(),
}
if ap.Addr().IsValid() && !ap.Addr().IsUnspecified() {
localAddress.Addr = tcpip.AddrFromSlice(ap.Addr().AsSlice())
}
return gonet.ListenTCP(ns.ipstack, localAddress, networkProto)
}
func (ns *Impl) acceptUDP(r *udp.ForwarderRequest) {
sess := r.ID()
if debugNetstack() {