From ce7789071f86d867c3928ea67183eb9139c7227e Mon Sep 17 00:00:00 2001 From: Michael Ben-Ami Date: Fri, 6 Mar 2026 15:45:00 -0500 Subject: [PATCH] feature/conn25: add NATing support with flow caching Introduce a datapathHandler that implements hooks that will receive packets from the tstun.Wrapper. This commit does not wire those up just yet. Perform DNAT from Magic IP to Transit IP on outbound flows on clients, and reverse SNAT in the reverse direction. Perform DNAT from Transit IP to final destination IP on outbound flows on connectors, and reverse SNAT in the reverse direction. Introduce FlowTable to cache validated flows by 5-tuple for fast lookups after the first packet. Flow expiration is not covered, and is intended as future work before the feature is officially released. Fixes tailscale/corp#34249 Fixes tailscale/corp#35995 Co-authored-by: Fran Bull Signed-off-by: Michael Ben-Ami --- cmd/tailscaled/depaware.txt | 2 +- feature/conn25/datapath.go | 242 +++++++++++++++++++++ feature/conn25/datapath_test.go | 361 +++++++++++++++++++++++++++++++ feature/conn25/flowtable.go | 149 +++++++++++++ feature/conn25/flowtable_test.go | 125 +++++++++++ 5 files changed, 878 insertions(+), 1 deletion(-) create mode 100644 feature/conn25/datapath.go create mode 100644 feature/conn25/datapath_test.go create mode 100644 feature/conn25/flowtable.go create mode 100644 feature/conn25/flowtable_test.go diff --git a/cmd/tailscaled/depaware.txt b/cmd/tailscaled/depaware.txt index c34bd490a..a7ecc865c 100644 --- a/cmd/tailscaled/depaware.txt +++ b/cmd/tailscaled/depaware.txt @@ -362,7 +362,7 @@ tailscale.com/cmd/tailscaled dependencies: (generated by github.com/tailscale/de tailscale.com/net/netutil from tailscale.com/client/local+ tailscale.com/net/netx from tailscale.com/control/controlclient+ tailscale.com/net/packet from tailscale.com/feature/capture+ - tailscale.com/net/packet/checksum from tailscale.com/net/tstun + tailscale.com/net/packet/checksum from 
tailscale.com/net/tstun+ tailscale.com/net/ping from tailscale.com/net/netcheck+ tailscale.com/net/portmapper from tailscale.com/feature/portmapper+ tailscale.com/net/portmapper/portmappertype from tailscale.com/feature/portmapper+ diff --git a/feature/conn25/datapath.go b/feature/conn25/datapath.go new file mode 100644 index 000000000..cc45edf63 --- /dev/null +++ b/feature/conn25/datapath.go @@ -0,0 +1,242 @@ +// Copyright (c) Tailscale Inc & contributors +// SPDX-License-Identifier: BSD-3-Clause + +package conn25 + +import ( + "errors" + "net/netip" + + "tailscale.com/envknob" + "tailscale.com/net/flowtrack" + "tailscale.com/net/packet" + "tailscale.com/net/packet/checksum" + "tailscale.com/types/ipproto" + "tailscale.com/types/logger" + "tailscale.com/wgengine/filter" +) + +var ( + ErrUnmappedMagicIP = errors.New("unmapped magic IP") + ErrUnmappedSrcAndTransitIP = errors.New("unmapped src and transit IP") +) + +// IPMapper provides methods for mapping special app connector IPs to each other +// in aid of performing DNAT and SNAT on app connector packets. +type IPMapper interface { + // ClientTransitIPForMagicIP returns a Transit IP for the given magicIP on a client. + // If the magicIP is within a configured Magic IP range for an app on the client, + // but not mapped to an active Transit IP, implementations should return [ErrUnmappedMagicIP]. + // If magicIP is not within a configured Magic IP range, i.e. it is not actually a Magic IP, + // implementations should return a nil error, and a zero-value [netip.Addr] to indicate + // this potentially valid, non-app-connector traffic. + ClientTransitIPForMagicIP(magicIP netip.Addr) (netip.Addr, error) + + // ConnectorRealIPForTransitIPConnection returns a real destination IP for the given + // srcIP and transitIP on a connector. 
If the transitIP is within a configured Transit IP
+	// range for an app on the connector, but not mapped to the client at srcIP, implementations
+	// should return [ErrUnmappedSrcAndTransitIP]. If the transitIP is not within a configured
+	// Transit IP range, i.e. it is not actually a Transit IP, implementations should return
+	// a nil error, and a zero-value [netip.Addr] to indicate this is potentially valid, non-app-connector
+	// traffic.
+	ConnectorRealIPForTransitIPConnection(srcIP netip.Addr, transitIP netip.Addr) (netip.Addr, error)
+}
+
+// datapathHandler handles packets from the datapath,
+// performing appropriate NAT operations to support Connectors 2025.
+// It maintains [FlowTable] caches for fast lookups of established flows.
+//
+// When hooked into the main datapath filter chain in [tstun], the datapathHandler
+// will see every packet on the node, regardless of whether it is relevant to
+// app connector operations. In the common case of non-connector traffic, it
+// passes the packet through unmodified.
+//
+// It classifies each packet based on the presence of special Magic IPs or
+// Transit IPs, and determines whether the packet is flowing through a "client"
+// (the node with the application that starts the connection), or a "connector"
+// (the node that connects to the internet-hosted destination). On the client,
+// outbound connections are DNATed from Magic IP to Transit IP, and return
+// traffic is SNATed from Transit IP to Magic IP. On the connector, outbound
+// connections are DNATed from Transit IP to real IP, and return traffic is
+// SNATed from real IP to Transit IP.
+//
+// There are two exposed methods, one for handling packets from the tun device,
+// and one for handling packets from WireGuard, but through the use of flow tables,
+// we can handle four cases: client outbound, client return, connector outbound,
+// connector return.
The first packet goes through IPMapper, which is where Connectors +// 2025 authoritative state is stored. For valid packets relevant to connectors, +// a bidirectional flow entry is installed, so that subsequent packets (and all return traffic) +// hit that cache. Only outbound (towards internet) packets create new flows; return (from internet) +// packets either match a cached entry or pass through. +// +// We check the cache before IPMapper both for performance, and so that existing flows stay alive +// even if address mappings change mid-flow. +type datapathHandler struct { + ipMapper IPMapper + + // Flow caches. One for the client, and one for the connector. + clientFlowTable *FlowTable + connectorFlowTable *FlowTable + + logf logger.Logf + debugLogging bool +} + +func newDatapathHandler(ipMapper IPMapper, logf logger.Logf) *datapathHandler { + return &datapathHandler{ + ipMapper: ipMapper, + + // TODO(mzb): Figure out sensible default max size for flow tables. + // Don't do any LRU eviction until we figure out deletion and expiration. + clientFlowTable: NewFlowTable(0), + connectorFlowTable: NewFlowTable(0), + logf: logf, + debugLogging: envknob.Bool("TS_CONN25_DATAPATH_DEBUG"), + } +} + +// HandlePacketFromWireGuard inspects packets coming from WireGuard, and performs +// appropriate DNAT or SNAT actions for Connectors 2025. Returning [filter.Accept] signals +// that the packet should pass through subsequent stages of the datapath pipeline. +// Returning [filter.Drop] signals the packet should be dropped. This method handles all +// packets coming from WireGuard, on both connectors, and clients of connectors. +func (dh *datapathHandler) HandlePacketFromWireGuard(p *packet.Parsed) filter.Response { + // TODO(tailscale/corp#38764): Support other protocols, like ICMP for error messages. + if p.IPProto != ipproto.TCP && p.IPProto != ipproto.UDP { + return filter.Accept + } + + // Check if this is an existing (return) flow on a client. 
+ // If found, perform the action for the existing client flow and return. + existing, ok := dh.clientFlowTable.LookupFromWireGuard(flowtrack.MakeTuple(p.IPProto, p.Src, p.Dst)) + if ok { + existing.Action(p) + return filter.Accept + } + + // Check if this is an existing connector outbound flow. + // If found, perform the action for the existing connector outbound flow and return. + existing, ok = dh.connectorFlowTable.LookupFromWireGuard(flowtrack.MakeTuple(p.IPProto, p.Src, p.Dst)) + if ok { + existing.Action(p) + return filter.Accept + } + + // The flow was not found in either flow table. Since the packet came in + // from WireGuard, it can only be a new flow on the connector, + // other (non-app-connector) traffic, or broken app-connector traffic + // that needs to be re-established by a new outbound packet. + transitIP := p.Dst.Addr() + realIP, err := dh.ipMapper.ConnectorRealIPForTransitIPConnection(p.Src.Addr(), transitIP) + if err != nil { + if errors.Is(err, ErrUnmappedSrcAndTransitIP) { + // TODO(tailscale/corp#34256): This path should deliver an ICMP error to the client. + return filter.Drop + } + dh.debugLogf("error mapping src and transit IP, passing packet unmodified: %v", err) + return filter.Accept + } + + // If this is normal non-app-connector traffic, forward it along unmodified. + if !realIP.IsValid() { + return filter.Accept + } + + // This is a new outbound flow on a connector. Install a DNAT TransitIP-to-RealIP action + // for the outgoing direction, and an SNAT RealIP-to-TransitIP action for the + // return direction. 
+ outgoing := FlowData{ + Tuple: flowtrack.MakeTuple(p.IPProto, p.Src, p.Dst), + Action: dh.dnatAction(realIP), + } + incoming := FlowData{ + Tuple: flowtrack.MakeTuple(p.IPProto, netip.AddrPortFrom(realIP, p.Dst.Port()), p.Src), + Action: dh.snatAction(transitIP), + } + if err := dh.connectorFlowTable.NewFlowFromWireGuard(outgoing, incoming); err != nil { + dh.debugLogf("error installing flow, passing packet unmodified: %v", err) + return filter.Accept + } + outgoing.Action(p) + return filter.Accept +} + +// HandlePacketFromTunDevice inspects packets coming from the tun device, and performs +// appropriate DNAT or SNAT actions for Connectors 2025. Returning [filter.Accept] signals +// that the packet should pass through subsequent stages of the datapath pipeline. +// Returning [filter.Drop] signals the packet should be dropped. This method handles all +// packets coming from the tun device, on both connectors, and clients of connectors. +func (dh *datapathHandler) HandlePacketFromTunDevice(p *packet.Parsed) filter.Response { + // TODO(tailscale/corp#38764): Support other protocols, like ICMP for error messages. + if p.IPProto != ipproto.TCP && p.IPProto != ipproto.UDP { + return filter.Accept + } + + // Check if this is an existing client outbound flow. + // If found, perform the action for the existing client flow and return. + existing, ok := dh.clientFlowTable.LookupFromTunDevice(flowtrack.MakeTuple(p.IPProto, p.Src, p.Dst)) + if ok { + existing.Action(p) + return filter.Accept + } + + // Check if this is an existing connector return flow. + // If found, perform the action for the existing connector return flow and return. + existing, ok = dh.connectorFlowTable.LookupFromTunDevice(flowtrack.MakeTuple(p.IPProto, p.Src, p.Dst)) + if ok { + existing.Action(p) + return filter.Accept + } + + // The flow was not found in either flow table. 
Since the packet came in on the + // tun device, it can only be a new client flow, other (non-app-connector) traffic, + // or broken return app-connector traffic on a connector, which needs to be re-established + // with a new outbound packet. + magicIP := p.Dst.Addr() + transitIP, err := dh.ipMapper.ClientTransitIPForMagicIP(magicIP) + if err != nil { + if errors.Is(err, ErrUnmappedMagicIP) { + // TODO(tailscale/corp#34257): This path should deliver an ICMP error to the client. + return filter.Drop + } + dh.debugLogf("error mapping magic IP, passing packet unmodified: %v", err) + return filter.Accept + } + + // If this is normal non-app-connector traffic, forward it along unmodified. + if !transitIP.IsValid() { + return filter.Accept + } + + // This is a new outbound client flow. Install a DNAT MagicIP-to-TransitIP action + // for the outgoing direction, and an SNAT TransitIP-to-MagicIP action for the + // return direction. + outgoing := FlowData{ + Tuple: flowtrack.MakeTuple(p.IPProto, p.Src, p.Dst), + Action: dh.dnatAction(transitIP), + } + incoming := FlowData{ + Tuple: flowtrack.MakeTuple(p.IPProto, netip.AddrPortFrom(transitIP, p.Dst.Port()), p.Src), + Action: dh.snatAction(magicIP), + } + if err := dh.clientFlowTable.NewFlowFromTunDevice(outgoing, incoming); err != nil { + dh.debugLogf("error installing flow from tun device, passing packet unmodified: %v", err) + return filter.Accept + } + outgoing.Action(p) + return filter.Accept +} + +func (dh *datapathHandler) dnatAction(to netip.Addr) PacketAction { + return PacketAction(func(p *packet.Parsed) { checksum.UpdateDstAddr(p, to) }) +} + +func (dh *datapathHandler) snatAction(to netip.Addr) PacketAction { + return PacketAction(func(p *packet.Parsed) { checksum.UpdateSrcAddr(p, to) }) +} + +func (dh *datapathHandler) debugLogf(msg string, args ...any) { + if dh.debugLogging { + dh.logf(msg, args...) 
+ } +} diff --git a/feature/conn25/datapath_test.go b/feature/conn25/datapath_test.go new file mode 100644 index 000000000..a4a3363b7 --- /dev/null +++ b/feature/conn25/datapath_test.go @@ -0,0 +1,361 @@ +// Copyright (c) Tailscale Inc & contributors +// SPDX-License-Identifier: BSD-3-Clause + +package conn25 + +import ( + "errors" + "net/netip" + "testing" + + "tailscale.com/net/packet" + "tailscale.com/types/ipproto" + "tailscale.com/wgengine/filter" +) + +type testConn25 struct { + clientTransitIPForMagicIPFn func(netip.Addr) (netip.Addr, error) + connectorRealIPForTransitIPConnectionFn func(netip.Addr, netip.Addr) (netip.Addr, error) +} + +func (tc *testConn25) ClientTransitIPForMagicIP(magicIP netip.Addr) (netip.Addr, error) { + return tc.clientTransitIPForMagicIPFn(magicIP) +} + +func (tc *testConn25) ConnectorRealIPForTransitIPConnection(srcIP netip.Addr, transitIP netip.Addr) (netip.Addr, error) { + return tc.connectorRealIPForTransitIPConnectionFn(srcIP, transitIP) +} + +func TestHandlePacketFromTunDevice(t *testing.T) { + clientSrcIP := netip.MustParseAddr("100.70.0.1") + magicIP := netip.MustParseAddr("10.64.0.1") + unusedMagicIP := netip.MustParseAddr("10.64.0.2") + transitIP := netip.MustParseAddr("169.254.0.1") + realIP := netip.MustParseAddr("240.64.0.1") + + clientPort := uint16(1234) + serverPort := uint16(80) + + tests := []struct { + description string + p *packet.Parsed + throwMappingErr bool + expectedSrc netip.AddrPort + expectedDst netip.AddrPort + expectedFilterResponse filter.Response + }{ + { + description: "accept-and-nat-new-client-flow-mapped-magic-ip", + p: &packet.Parsed{ + Src: netip.AddrPortFrom(clientSrcIP, clientPort), + Dst: netip.AddrPortFrom(magicIP, serverPort), + }, + expectedSrc: netip.AddrPortFrom(clientSrcIP, clientPort), + expectedDst: netip.AddrPortFrom(transitIP, serverPort), + expectedFilterResponse: filter.Accept, + }, + { + description: "drop-unmapped-magic-ip", + p: &packet.Parsed{ + Src: 
netip.AddrPortFrom(clientSrcIP, clientPort), + Dst: netip.AddrPortFrom(unusedMagicIP, serverPort), + }, + expectedSrc: netip.AddrPortFrom(clientSrcIP, clientPort), + expectedDst: netip.AddrPortFrom(unusedMagicIP, serverPort), + expectedFilterResponse: filter.Drop, + }, + { + description: "accept-dont-nat-other-mapping-error", + p: &packet.Parsed{ + Src: netip.AddrPortFrom(clientSrcIP, clientPort), + Dst: netip.AddrPortFrom(magicIP, serverPort), + }, + throwMappingErr: true, + expectedSrc: netip.AddrPortFrom(clientSrcIP, clientPort), + expectedDst: netip.AddrPortFrom(magicIP, serverPort), + expectedFilterResponse: filter.Accept, + }, + { + description: "accept-dont-nat-uninteresting-client-side", + p: &packet.Parsed{ + Src: netip.AddrPortFrom(clientSrcIP, clientPort), + Dst: netip.AddrPortFrom(realIP, serverPort), + }, + expectedSrc: netip.AddrPortFrom(clientSrcIP, clientPort), + expectedDst: netip.AddrPortFrom(realIP, serverPort), + expectedFilterResponse: filter.Accept, + }, + { + description: "accept-dont-nat-uninteresting-connector-side", + p: &packet.Parsed{ + Src: netip.AddrPortFrom(realIP, serverPort), + Dst: netip.AddrPortFrom(clientSrcIP, clientPort), + }, + expectedSrc: netip.AddrPortFrom(realIP, serverPort), + expectedDst: netip.AddrPortFrom(clientSrcIP, clientPort), + expectedFilterResponse: filter.Accept, + }, + } + + for _, tt := range tests { + t.Run(tt.description, func(t *testing.T) { + mock := &testConn25{} + mock.clientTransitIPForMagicIPFn = func(mip netip.Addr) (netip.Addr, error) { + if tt.throwMappingErr { + return netip.Addr{}, errors.New("synthetic mapping error") + } + if mip == magicIP { + return transitIP, nil + } + if mip == unusedMagicIP { + return netip.Addr{}, ErrUnmappedMagicIP + } + return netip.Addr{}, nil + } + dph := newDatapathHandler(mock, nil) + + tt.p.IPProto = ipproto.UDP + tt.p.IPVersion = 4 + tt.p.StuffForTesting(40) + + if want, got := tt.expectedFilterResponse, dph.HandlePacketFromTunDevice(tt.p); want != got { + 
t.Errorf("unexpected filter response: want %v, got %v", want, got) + } + if want, got := tt.expectedSrc, tt.p.Src; want != got { + t.Errorf("unexpected packet src: want %v, got %v", want, got) + } + if want, got := tt.expectedDst, tt.p.Dst; want != got { + t.Errorf("unexpected packet dst: want %v, got %v", want, got) + } + }) + } +} + +func TestHandlePacketFromWireGuard(t *testing.T) { + clientSrcIP := netip.MustParseAddr("100.70.0.1") + unknownSrcIP := netip.MustParseAddr("100.99.99.99") + transitIP := netip.MustParseAddr("169.254.0.1") + realIP := netip.MustParseAddr("240.64.0.1") + + clientPort := uint16(1234) + serverPort := uint16(80) + + tests := []struct { + description string + p *packet.Parsed + throwMappingErr bool + expectedSrc netip.AddrPort + expectedDst netip.AddrPort + expectedFilterResponse filter.Response + }{ + { + description: "accept-and-nat-new-connector-flow-mapped-src-and-transit-ip", + p: &packet.Parsed{ + Src: netip.AddrPortFrom(clientSrcIP, clientPort), + Dst: netip.AddrPortFrom(transitIP, serverPort), + }, + expectedSrc: netip.AddrPortFrom(clientSrcIP, clientPort), + expectedDst: netip.AddrPortFrom(realIP, serverPort), + expectedFilterResponse: filter.Accept, + }, + { + description: "drop-unmapped-src-and-transit-ip", + p: &packet.Parsed{ + Src: netip.AddrPortFrom(unknownSrcIP, clientPort), + Dst: netip.AddrPortFrom(transitIP, serverPort), + }, + expectedSrc: netip.AddrPortFrom(unknownSrcIP, clientPort), + expectedDst: netip.AddrPortFrom(transitIP, serverPort), + expectedFilterResponse: filter.Drop, + }, + { + description: "accept-dont-nat-other-mapping-error", + p: &packet.Parsed{ + Src: netip.AddrPortFrom(clientSrcIP, clientPort), + Dst: netip.AddrPortFrom(transitIP, serverPort), + }, + throwMappingErr: true, + expectedSrc: netip.AddrPortFrom(clientSrcIP, clientPort), + expectedDst: netip.AddrPortFrom(transitIP, serverPort), + expectedFilterResponse: filter.Accept, + }, + { + description: "accept-dont-nat-uninteresting-connector-side", 
+ p: &packet.Parsed{ + Src: netip.AddrPortFrom(clientSrcIP, clientPort), + Dst: netip.AddrPortFrom(realIP, serverPort), + }, + expectedSrc: netip.AddrPortFrom(clientSrcIP, clientPort), + expectedDst: netip.AddrPortFrom(realIP, serverPort), + expectedFilterResponse: filter.Accept, + }, + { + description: "accept-dont-nat-uninteresting-client-side", + p: &packet.Parsed{ + Src: netip.AddrPortFrom(realIP, serverPort), + Dst: netip.AddrPortFrom(clientSrcIP, clientPort), + }, + expectedSrc: netip.AddrPortFrom(realIP, serverPort), + expectedDst: netip.AddrPortFrom(clientSrcIP, clientPort), + expectedFilterResponse: filter.Accept, + }, + } + + for _, tt := range tests { + t.Run(tt.description, func(t *testing.T) { + mock := &testConn25{} + mock.connectorRealIPForTransitIPConnectionFn = func(src, tip netip.Addr) (netip.Addr, error) { + if tt.throwMappingErr { + return netip.Addr{}, errors.New("synthetic mapping error") + } + if tip == transitIP { + if src == clientSrcIP { + return realIP, nil + } else { + return netip.Addr{}, ErrUnmappedSrcAndTransitIP + } + } + return netip.Addr{}, nil + } + dph := newDatapathHandler(mock, nil) + + tt.p.IPProto = ipproto.UDP + tt.p.IPVersion = 4 + tt.p.StuffForTesting(40) + + if want, got := tt.expectedFilterResponse, dph.HandlePacketFromWireGuard(tt.p); want != got { + t.Errorf("unexpected filter response: want %v, got %v", want, got) + } + if want, got := tt.expectedSrc, tt.p.Src; want != got { + t.Errorf("unexpected packet src: want %v, got %v", want, got) + } + if want, got := tt.expectedDst, tt.p.Dst; want != got { + t.Errorf("unexpected packet dst: want %v, got %v", want, got) + } + }) + } +} + +func TestClientFlowCache(t *testing.T) { + getTransitIPCalled := false + + clientSrcIP := netip.MustParseAddr("100.70.0.1") + magicIP := netip.MustParseAddr("10.64.0.1") + transitIP := netip.MustParseAddr("169.254.0.1") + + clientPort := uint16(1234) + serverPort := uint16(80) + + mock := &testConn25{} + mock.clientTransitIPForMagicIPFn = 
func(mip netip.Addr) (netip.Addr, error) {
+		if getTransitIPCalled {
+			t.Errorf("ClientTransitIPForMagicIP unexpectedly called more than once")
+		}
+		getTransitIPCalled = true
+		return transitIP, nil
+	}
+	dph := newDatapathHandler(mock, nil)
+
+	outgoing := packet.Parsed{
+		IPProto:   ipproto.UDP,
+		IPVersion: 4,
+		Src:       netip.AddrPortFrom(clientSrcIP, clientPort),
+		Dst:       netip.AddrPortFrom(magicIP, serverPort),
+	}
+	outgoing.StuffForTesting(40)
+
+	o1 := outgoing
+	if dph.HandlePacketFromTunDevice(&o1) != filter.Accept {
+		t.Errorf("first call to HandlePacketFromTunDevice was not accepted")
+	}
+	if want, got := netip.AddrPortFrom(transitIP, serverPort), o1.Dst; want != got {
+		t.Errorf("unexpected packet dst after first call: want %v, got %v", want, got)
+	}
+	// The second call should use the cache.
+	o2 := outgoing
+	if dph.HandlePacketFromTunDevice(&o2) != filter.Accept {
+		t.Errorf("second call to HandlePacketFromTunDevice was not accepted")
+	}
+	if want, got := netip.AddrPortFrom(transitIP, serverPort), o2.Dst; want != got {
+		t.Errorf("unexpected packet dst after second call: want %v, got %v", want, got)
+	}
+
+	// Return traffic should have the Transit IP as the source,
+	// and be SNATed to the Magic IP.
+ incoming := &packet.Parsed{ + IPProto: ipproto.UDP, + IPVersion: 4, + Src: netip.AddrPortFrom(transitIP, serverPort), + Dst: netip.AddrPortFrom(clientSrcIP, clientPort), + } + incoming.StuffForTesting(40) + + if dph.HandlePacketFromWireGuard(incoming) != filter.Accept { + t.Errorf("call to HandlePacketFromWireGuard was not accepted") + } + if want, got := netip.AddrPortFrom(magicIP, serverPort), incoming.Src; want != got { + t.Errorf("unexpected packet src after second call: want %v, got %v", want, got) + } +} + +func TestConnectorFlowCache(t *testing.T) { + getRealIPCalled := false + + clientSrcIP := netip.MustParseAddr("100.70.0.1") + transitIP := netip.MustParseAddr("169.254.0.1") + realIP := netip.MustParseAddr("240.64.0.1") + + clientPort := uint16(1234) + serverPort := uint16(80) + + mock := &testConn25{} + mock.connectorRealIPForTransitIPConnectionFn = func(src, tip netip.Addr) (netip.Addr, error) { + if getRealIPCalled { + t.Errorf("ConnectorRealIPForTransitIPConnection unexpectedly called more than once") + } + getRealIPCalled = true + return realIP, nil + } + dph := newDatapathHandler(mock, nil) + + outgoing := packet.Parsed{ + IPProto: ipproto.UDP, + IPVersion: 4, + Src: netip.AddrPortFrom(clientSrcIP, clientPort), + Dst: netip.AddrPortFrom(transitIP, serverPort), + } + outgoing.StuffForTesting(40) + + o1 := outgoing + if dph.HandlePacketFromWireGuard(&o1) != filter.Accept { + t.Errorf("first call to HandlePacketFromWireGuard was not accepted") + } + if want, got := netip.AddrPortFrom(realIP, serverPort), o1.Dst; want != got { + t.Errorf("unexpected packet dst after first call: want %v, got %v", want, got) + } + // The second call should use the cache. 
+ o2 := outgoing + if dph.HandlePacketFromWireGuard(&o2) != filter.Accept { + t.Errorf("second call to HandlePacketFromWireGuard was not accepted") + } + if want, got := netip.AddrPortFrom(realIP, serverPort), o2.Dst; want != got { + t.Errorf("unexpected packet dst after second call: want %v, got %v", want, got) + } + + // Return traffic should have the Real IP as the source, + // and be SNATed to the Transit IP. + incoming := &packet.Parsed{ + IPProto: ipproto.UDP, + IPVersion: 4, + Src: netip.AddrPortFrom(realIP, serverPort), + Dst: netip.AddrPortFrom(clientSrcIP, clientPort), + } + incoming.StuffForTesting(40) + + if dph.HandlePacketFromTunDevice(incoming) != filter.Accept { + t.Errorf("call to HandlePacketFromTunDevice was not accepted") + } + if want, got := netip.AddrPortFrom(transitIP, serverPort), incoming.Src; want != got { + t.Errorf("unexpected packet src after second call: want %v, got %v", want, got) + } +} diff --git a/feature/conn25/flowtable.go b/feature/conn25/flowtable.go new file mode 100644 index 000000000..27486ded9 --- /dev/null +++ b/feature/conn25/flowtable.go @@ -0,0 +1,149 @@ +// Copyright (c) Tailscale Inc & contributors +// SPDX-License-Identifier: BSD-3-Clause + +package conn25 + +import ( + "errors" + "sync" + + "tailscale.com/net/flowtrack" + "tailscale.com/net/packet" +) + +// PacketAction may modify the packet. +type PacketAction func(*packet.Parsed) + +// FlowData is an entry stored in the [FlowTable]. +type FlowData struct { + Tuple flowtrack.Tuple + Action PacketAction +} + +// Origin is used to track the direction of a flow. +type Origin uint8 + +const ( + // FromTun indicates the flow is from the tun device. + FromTun Origin = iota + + // FromWireGuard indicates the flow is from the WireGuard tunnel. + FromWireGuard +) + +type cachedFlow struct { + flow FlowData + paired flowtrack.Tuple // tuple for the other direction +} + +// FlowTable stores and retrieves [FlowData] that can be looked up +// by 5-tuple. 
New entries specify the tuple to use for both directions
+// of traffic flow. The underlying cache is LRU, and the maximum number
+// of entries is specified in calls to [NewFlowTable]. FlowTable has
+// its own mutex and is safe for concurrent use.
+type FlowTable struct {
+	mu           sync.Mutex
+	fromTunCache *flowtrack.Cache[cachedFlow] // guarded by mu
+	fromWGCache  *flowtrack.Cache[cachedFlow] // guarded by mu
+}
+
+// NewFlowTable returns a [FlowTable] with maxEntries maximum entries.
+// A maxEntries of 0 indicates no maximum. See also [FlowTable].
+func NewFlowTable(maxEntries int) *FlowTable {
+	return &FlowTable{
+		fromTunCache: &flowtrack.Cache[cachedFlow]{
+			MaxEntries: maxEntries,
+		},
+		fromWGCache: &flowtrack.Cache[cachedFlow]{
+			MaxEntries: maxEntries,
+		},
+	}
+}
+
+// LookupFromTunDevice looks up a [FlowData] entry that is valid to run for packets
+// observed as coming from the tun device. The tuple must match the direction it was
+// stored with.
+func (t *FlowTable) LookupFromTunDevice(k flowtrack.Tuple) (FlowData, bool) {
+	return t.lookup(k, FromTun)
+}
+
+// LookupFromWireGuard looks up a [FlowData] entry that is valid to run for packets
+// observed as coming from the WireGuard tunnel. The tuple must match the direction it was
+// stored with.
+func (t *FlowTable) LookupFromWireGuard(k flowtrack.Tuple) (FlowData, bool) {
+	return t.lookup(k, FromWireGuard)
+}
+
+func (t *FlowTable) lookup(k flowtrack.Tuple, want Origin) (FlowData, bool) {
+	var cache *flowtrack.Cache[cachedFlow]
+	switch want {
+	case FromTun:
+		cache = t.fromTunCache
+	case FromWireGuard:
+		cache = t.fromWGCache
+	default:
+		return FlowData{}, false
+	}
+
+	t.mu.Lock()
+	defer t.mu.Unlock()
+
+	v, ok := cache.Get(k)
+	if !ok {
+		return FlowData{}, false
+	}
+	return v.flow, true
+}
+
+// NewFlowFromTunDevice installs (or overwrites) both the forward and return entries.
+// The forward tuple is tagged as FromTun, and the return tuple is tagged as FromWireGuard.
+// If overwriting, it removes the old paired tuple for the forward key to avoid stale reverse mappings. +func (t *FlowTable) NewFlowFromTunDevice(fwd, rev FlowData) error { + return t.newFlow(FromTun, fwd, rev) +} + +// NewFlowFromWireGuard installs (or overwrites) both the forward and return entries. +// The forward tuple is tagged as FromWireGuard, and the return tuple is tagged as FromTun. +// If overwriting, it removes the old paired tuple for the forward key to avoid stale reverse mappings. +func (t *FlowTable) NewFlowFromWireGuard(fwd, rev FlowData) error { + return t.newFlow(FromWireGuard, fwd, rev) +} + +func (t *FlowTable) newFlow(fwdOrigin Origin, fwd, rev FlowData) error { + if fwd.Action == nil || rev.Action == nil { + return errors.New("nil action received for flow") + } + + var fwdCache, revCache *flowtrack.Cache[cachedFlow] + switch fwdOrigin { + case FromTun: + fwdCache, revCache = t.fromTunCache, t.fromWGCache + case FromWireGuard: + fwdCache, revCache = t.fromWGCache, t.fromTunCache + default: + return errors.New("newFlow called with unknown direction") + } + + t.mu.Lock() + defer t.mu.Unlock() + + // If overwriting an existing entry, remove its previously-paired mapping so + // we don't leave stale tuples around. 
+ if old, ok := fwdCache.Get(fwd.Tuple); ok { + revCache.Remove(old.paired) + } + if old, ok := revCache.Get(rev.Tuple); ok { + fwdCache.Remove(old.paired) + } + + fwdCache.Add(fwd.Tuple, cachedFlow{ + flow: fwd, + paired: rev.Tuple, + }) + revCache.Add(rev.Tuple, cachedFlow{ + flow: rev, + paired: fwd.Tuple, + }) + + return nil +} diff --git a/feature/conn25/flowtable_test.go b/feature/conn25/flowtable_test.go new file mode 100644 index 000000000..8c3cd63a2 --- /dev/null +++ b/feature/conn25/flowtable_test.go @@ -0,0 +1,125 @@ +// Copyright (c) Tailscale Inc & contributors +// SPDX-License-Identifier: BSD-3-Clause + +package conn25 + +import ( + "net/netip" + "testing" + + "tailscale.com/net/flowtrack" + "tailscale.com/net/packet" + "tailscale.com/types/ipproto" +) + +func TestFlowTable(t *testing.T) { + ft := NewFlowTable(0) + + fwdTuple := flowtrack.MakeTuple( + ipproto.UDP, + netip.MustParseAddrPort("1.2.3.4:1000"), + netip.MustParseAddrPort("4.3.2.1:80"), + ) + // Reverse tuple is defined by caller. Doesn't have to be mirror image of fwd. + // To account for intentional modifications, like NAT. + revTuple := flowtrack.MakeTuple( + ipproto.UDP, + netip.MustParseAddrPort("4.3.2.2:80"), + netip.MustParseAddrPort("1.2.3.4:1000"), + ) + + fwdAction, revAction := 0, 0 + fwdData := FlowData{ + Tuple: fwdTuple, + Action: func(_ *packet.Parsed) { fwdAction++ }, + } + revData := FlowData{ + Tuple: revTuple, + Action: func(_ *packet.Parsed) { revAction++ }, + } + + // For this test setup, from the tun device will be "forward", + // and from WG will be "reverse". + if err := ft.NewFlowFromTunDevice(fwdData, revData); err != nil { + t.Fatalf("got non-nil error for new flow from tun device") + } + + // Test basic lookups. 
+ lookupFwd, ok := ft.LookupFromTunDevice(fwdTuple) + if !ok { + t.Fatalf("got not found on first lookup from tun device") + } + lookupFwd.Action(nil) + if fwdAction != 1 { + t.Errorf("action for fwd tuple key was not executed") + } + + lookupRev, ok := ft.LookupFromWireGuard(revTuple) + if !ok { + t.Fatalf("got not found on first lookup from WireGuard") + } + lookupRev.Action(nil) + if revAction != 1 { + t.Errorf("action for rev tuple key was not executed") + } + + // Test not found error. + notFoundTuple := flowtrack.MakeTuple( + ipproto.UDP, + netip.MustParseAddrPort("1.2.3.4:1000"), + netip.MustParseAddrPort("4.0.4.4:80"), + ) + if _, ok := ft.LookupFromTunDevice(notFoundTuple); ok { + t.Errorf("expected not found for foreign tuple") + } + + // Wrong direction is also not found. + if _, ok := ft.LookupFromWireGuard(fwdTuple); ok { + t.Errorf("expected not found for wrong direction tuple") + } + + // Overwriting forward tuple removes its reverse pair as well. + newRevData := FlowData{ + Tuple: flowtrack.MakeTuple( + ipproto.UDP, + netip.MustParseAddrPort("9.9.9.9:99"), + netip.MustParseAddrPort("8.8.8.8:88"), + ), + Action: func(_ *packet.Parsed) {}, + } + if err := ft.NewFlowFromTunDevice( + fwdData, + newRevData, + ); err != nil { + t.Fatalf("got non-nil error for new flow from tun device") + } + if _, ok := ft.LookupFromWireGuard(revTuple); ok { + t.Errorf("expected not found for removed reverse tuple") + } + + // Overwriting reverse tuple removes its forward pair as well. + if err := ft.NewFlowFromTunDevice( + FlowData{ + Tuple: flowtrack.MakeTuple( + ipproto.UDP, + netip.MustParseAddrPort("8.8.8.8:88"), + netip.MustParseAddrPort("9.9.9.9:99"), + ), + Action: func(_ *packet.Parsed) {}, + }, + newRevData, // This is the same "reverse" data installed in previous test. 
+ ); err != nil { + t.Fatalf("got non-nil error for new flow from tun device") + } + if _, ok := ft.LookupFromTunDevice(fwdTuple); ok { + t.Errorf("expected not found for removed forward tuple") + } + + // Nil action returns an error. + if err := ft.NewFlowFromTunDevice( + FlowData{}, + FlowData{}, + ); err == nil { + t.Errorf("expected non-nil error for nil data") + } +}