diff --git a/cmd/tailscaled/depaware.txt b/cmd/tailscaled/depaware.txt index c34bd490a..a7ecc865c 100644 --- a/cmd/tailscaled/depaware.txt +++ b/cmd/tailscaled/depaware.txt @@ -362,7 +362,7 @@ tailscale.com/cmd/tailscaled dependencies: (generated by github.com/tailscale/de tailscale.com/net/netutil from tailscale.com/client/local+ tailscale.com/net/netx from tailscale.com/control/controlclient+ tailscale.com/net/packet from tailscale.com/feature/capture+ - tailscale.com/net/packet/checksum from tailscale.com/net/tstun + tailscale.com/net/packet/checksum from tailscale.com/net/tstun+ tailscale.com/net/ping from tailscale.com/net/netcheck+ tailscale.com/net/portmapper from tailscale.com/feature/portmapper+ tailscale.com/net/portmapper/portmappertype from tailscale.com/feature/portmapper+ diff --git a/feature/conn25/datapath.go b/feature/conn25/datapath.go new file mode 100644 index 000000000..cc45edf63 --- /dev/null +++ b/feature/conn25/datapath.go @@ -0,0 +1,242 @@ +// Copyright (c) Tailscale Inc & contributors +// SPDX-License-Identifier: BSD-3-Clause + +package conn25 + +import ( + "errors" + "net/netip" + + "tailscale.com/envknob" + "tailscale.com/net/flowtrack" + "tailscale.com/net/packet" + "tailscale.com/net/packet/checksum" + "tailscale.com/types/ipproto" + "tailscale.com/types/logger" + "tailscale.com/wgengine/filter" +) + +var ( + ErrUnmappedMagicIP = errors.New("unmapped magic IP") + ErrUnmappedSrcAndTransitIP = errors.New("unmapped src and transit IP") +) + +// IPMapper provides methods for mapping special app connector IPs to each other +// in aid of performing DNAT and SNAT on app connector packets. +type IPMapper interface { + // ClientTransitIPForMagicIP returns a Transit IP for the given magicIP on a client. + // If the magicIP is within a configured Magic IP range for an app on the client, + // but not mapped to an active Transit IP, implementations should return [ErrUnmappedMagicIP]. + // If magicIP is not within a configured Magic IP range, i.e. 
it is not actually a Magic IP, + // implementations should return a nil error, and a zero-value [netip.Addr] to indicate + // this is potentially valid, non-app-connector traffic. + ClientTransitIPForMagicIP(magicIP netip.Addr) (netip.Addr, error) + + // ConnectorRealIPForTransitIPConnection returns a real destination IP for the given + // srcIP and transitIP on a connector. If the transitIP is within a configured Transit IP + // range for an app on the connector, but not mapped to the client at srcIP, implementations + // should return [ErrUnmappedSrcAndTransitIP]. If the transitIP is not within a configured + // Transit IP range, i.e. it is not actually a Transit IP, implementations should return + // a nil error, and a zero-value [netip.Addr] to indicate this is potentially valid, non-app-connector + // traffic. + ConnectorRealIPForTransitIPConnection(srcIP netip.Addr, transitIP netip.Addr) (netip.Addr, error) +} + +// datapathHandler handles packets from the datapath, +// performing appropriate NAT operations to support Connectors 2025. +// It maintains [FlowTable] caches for fast lookups of established flows. +// +// When hooked into the main datapath filter chain in [tstun], the datapathHandler +// will see every packet on the node, regardless of whether it is relevant to +// app connector operations. In the common case of non-connector traffic, it +// passes the packet through unmodified. +// +// It classifies each packet based on the presence of special Magic IPs or +// Transit IPs, and determines whether the packet is flowing through a "client" +// (the node with the application that starts the connection), or a "connector" +// (the node that connects to the internet-hosted destination). On the client, +// outbound connections are DNATed from Magic IP to Transit IP, and return +// traffic is SNATed from Transit IP to Magic IP. 
On the connector, outbound +// connections are DNATed from Transit IP to real IP, and return traffic is +// SNATed from real IP to Transit IP. +// +// There are two exposed methods, one for handling packets from the tun device, +// and one for handling packets from WireGuard, but through the use of flow tables, +// we can handle four cases: client outbound, client return, connector outbound, +// connector return. The first packet goes through IPMapper, which is where Connectors +// 2025 authoritative state is stored. For valid packets relevant to connectors, +// a bidirectional flow entry is installed, so that subsequent packets (and all return traffic) +// hit that cache. Only outbound (towards internet) packets create new flows; return (from internet) +// packets either match a cached entry or pass through. +// +// We check the cache before IPMapper both for performance, and so that existing flows stay alive +// even if address mappings change mid-flow. +type datapathHandler struct { + ipMapper IPMapper + + // Flow caches. One for the client, and one for the connector. + clientFlowTable *FlowTable + connectorFlowTable *FlowTable + + logf logger.Logf + debugLogging bool +} + +func newDatapathHandler(ipMapper IPMapper, logf logger.Logf) *datapathHandler { + return &datapathHandler{ + ipMapper: ipMapper, + + // TODO(mzb): Figure out sensible default max size for flow tables. + // Don't do any LRU eviction until we figure out deletion and expiration. + clientFlowTable: NewFlowTable(0), + connectorFlowTable: NewFlowTable(0), + logf: logf, + debugLogging: envknob.Bool("TS_CONN25_DATAPATH_DEBUG"), + } +} + +// HandlePacketFromWireGuard inspects packets coming from WireGuard, and performs +// appropriate DNAT or SNAT actions for Connectors 2025. Returning [filter.Accept] signals +// that the packet should pass through subsequent stages of the datapath pipeline. +// Returning [filter.Drop] signals the packet should be dropped. 
This method handles all +// packets coming from WireGuard, on both connectors, and clients of connectors. +func (dh *datapathHandler) HandlePacketFromWireGuard(p *packet.Parsed) filter.Response { + // TODO(tailscale/corp#38764): Support other protocols, like ICMP for error messages. + if p.IPProto != ipproto.TCP && p.IPProto != ipproto.UDP { + return filter.Accept + } + + // Check if this is an existing (return) flow on a client. + // If found, perform the action for the existing client flow and return. + existing, ok := dh.clientFlowTable.LookupFromWireGuard(flowtrack.MakeTuple(p.IPProto, p.Src, p.Dst)) + if ok { + existing.Action(p) + return filter.Accept + } + + // Check if this is an existing connector outbound flow. + // If found, perform the action for the existing connector outbound flow and return. + existing, ok = dh.connectorFlowTable.LookupFromWireGuard(flowtrack.MakeTuple(p.IPProto, p.Src, p.Dst)) + if ok { + existing.Action(p) + return filter.Accept + } + + // The flow was not found in either flow table. Since the packet came in + // from WireGuard, it can only be a new flow on the connector, + // other (non-app-connector) traffic, or broken app-connector traffic + // that needs to be re-established by a new outbound packet. + transitIP := p.Dst.Addr() + realIP, err := dh.ipMapper.ConnectorRealIPForTransitIPConnection(p.Src.Addr(), transitIP) + if err != nil { + if errors.Is(err, ErrUnmappedSrcAndTransitIP) { + // TODO(tailscale/corp#34256): This path should deliver an ICMP error to the client. + return filter.Drop + } + dh.debugLogf("error mapping src and transit IP, passing packet unmodified: %v", err) + return filter.Accept + } + + // If this is normal non-app-connector traffic, forward it along unmodified. + if !realIP.IsValid() { + return filter.Accept + } + + // This is a new outbound flow on a connector. 
Install a DNAT TransitIP-to-RealIP action + // for the outgoing direction, and an SNAT RealIP-to-TransitIP action for the + // return direction. + outgoing := FlowData{ + Tuple: flowtrack.MakeTuple(p.IPProto, p.Src, p.Dst), + Action: dh.dnatAction(realIP), + } + incoming := FlowData{ + Tuple: flowtrack.MakeTuple(p.IPProto, netip.AddrPortFrom(realIP, p.Dst.Port()), p.Src), + Action: dh.snatAction(transitIP), + } + if err := dh.connectorFlowTable.NewFlowFromWireGuard(outgoing, incoming); err != nil { + dh.debugLogf("error installing flow, passing packet unmodified: %v", err) + return filter.Accept + } + outgoing.Action(p) + return filter.Accept +} + +// HandlePacketFromTunDevice inspects packets coming from the tun device, and performs +// appropriate DNAT or SNAT actions for Connectors 2025. Returning [filter.Accept] signals +// that the packet should pass through subsequent stages of the datapath pipeline. +// Returning [filter.Drop] signals the packet should be dropped. This method handles all +// packets coming from the tun device, on both connectors, and clients of connectors. +func (dh *datapathHandler) HandlePacketFromTunDevice(p *packet.Parsed) filter.Response { + // TODO(tailscale/corp#38764): Support other protocols, like ICMP for error messages. + if p.IPProto != ipproto.TCP && p.IPProto != ipproto.UDP { + return filter.Accept + } + + // Check if this is an existing client outbound flow. + // If found, perform the action for the existing client flow and return. + existing, ok := dh.clientFlowTable.LookupFromTunDevice(flowtrack.MakeTuple(p.IPProto, p.Src, p.Dst)) + if ok { + existing.Action(p) + return filter.Accept + } + + // Check if this is an existing connector return flow. + // If found, perform the action for the existing connector return flow and return. 
+ existing, ok = dh.connectorFlowTable.LookupFromTunDevice(flowtrack.MakeTuple(p.IPProto, p.Src, p.Dst)) + if ok { + existing.Action(p) + return filter.Accept + } + + // The flow was not found in either flow table. Since the packet came in on the + // tun device, it can only be a new client flow, other (non-app-connector) traffic, + // or broken return app-connector traffic on a connector, which needs to be re-established + // with a new outbound packet. + magicIP := p.Dst.Addr() + transitIP, err := dh.ipMapper.ClientTransitIPForMagicIP(magicIP) + if err != nil { + if errors.Is(err, ErrUnmappedMagicIP) { + // TODO(tailscale/corp#34257): This path should deliver an ICMP error to the client. + return filter.Drop + } + dh.debugLogf("error mapping magic IP, passing packet unmodified: %v", err) + return filter.Accept + } + + // If this is normal non-app-connector traffic, forward it along unmodified. + if !transitIP.IsValid() { + return filter.Accept + } + + // This is a new outbound client flow. Install a DNAT MagicIP-to-TransitIP action + // for the outgoing direction, and an SNAT TransitIP-to-MagicIP action for the + // return direction. 
+ outgoing := FlowData{ + Tuple: flowtrack.MakeTuple(p.IPProto, p.Src, p.Dst), + Action: dh.dnatAction(transitIP), + } + incoming := FlowData{ + Tuple: flowtrack.MakeTuple(p.IPProto, netip.AddrPortFrom(transitIP, p.Dst.Port()), p.Src), + Action: dh.snatAction(magicIP), + } + if err := dh.clientFlowTable.NewFlowFromTunDevice(outgoing, incoming); err != nil { + dh.debugLogf("error installing flow from tun device, passing packet unmodified: %v", err) + return filter.Accept + } + outgoing.Action(p) + return filter.Accept +} + +func (dh *datapathHandler) dnatAction(to netip.Addr) PacketAction { + return PacketAction(func(p *packet.Parsed) { checksum.UpdateDstAddr(p, to) }) +} + +func (dh *datapathHandler) snatAction(to netip.Addr) PacketAction { + return PacketAction(func(p *packet.Parsed) { checksum.UpdateSrcAddr(p, to) }) +} + +func (dh *datapathHandler) debugLogf(msg string, args ...any) { + if dh.debugLogging && dh.logf != nil { + dh.logf(msg, args...) + } +} diff --git a/feature/conn25/datapath_test.go b/feature/conn25/datapath_test.go new file mode 100644 index 000000000..a4a3363b7 --- /dev/null +++ b/feature/conn25/datapath_test.go @@ -0,0 +1,361 @@ +// Copyright (c) Tailscale Inc & contributors +// SPDX-License-Identifier: BSD-3-Clause + +package conn25 + +import ( + "errors" + "net/netip" + "testing" + + "tailscale.com/net/packet" + "tailscale.com/types/ipproto" + "tailscale.com/wgengine/filter" +) + +type testConn25 struct { + clientTransitIPForMagicIPFn func(netip.Addr) (netip.Addr, error) + connectorRealIPForTransitIPConnectionFn func(netip.Addr, netip.Addr) (netip.Addr, error) +} + +func (tc *testConn25) ClientTransitIPForMagicIP(magicIP netip.Addr) (netip.Addr, error) { + return tc.clientTransitIPForMagicIPFn(magicIP) +} + +func (tc *testConn25) ConnectorRealIPForTransitIPConnection(srcIP netip.Addr, transitIP netip.Addr) (netip.Addr, error) { + return tc.connectorRealIPForTransitIPConnectionFn(srcIP, transitIP) +} + +func TestHandlePacketFromTunDevice(t *testing.T) { + 
clientSrcIP := netip.MustParseAddr("100.70.0.1") + magicIP := netip.MustParseAddr("10.64.0.1") + unusedMagicIP := netip.MustParseAddr("10.64.0.2") + transitIP := netip.MustParseAddr("169.254.0.1") + realIP := netip.MustParseAddr("240.64.0.1") + + clientPort := uint16(1234) + serverPort := uint16(80) + + tests := []struct { + description string + p *packet.Parsed + throwMappingErr bool + expectedSrc netip.AddrPort + expectedDst netip.AddrPort + expectedFilterResponse filter.Response + }{ + { + description: "accept-and-nat-new-client-flow-mapped-magic-ip", + p: &packet.Parsed{ + Src: netip.AddrPortFrom(clientSrcIP, clientPort), + Dst: netip.AddrPortFrom(magicIP, serverPort), + }, + expectedSrc: netip.AddrPortFrom(clientSrcIP, clientPort), + expectedDst: netip.AddrPortFrom(transitIP, serverPort), + expectedFilterResponse: filter.Accept, + }, + { + description: "drop-unmapped-magic-ip", + p: &packet.Parsed{ + Src: netip.AddrPortFrom(clientSrcIP, clientPort), + Dst: netip.AddrPortFrom(unusedMagicIP, serverPort), + }, + expectedSrc: netip.AddrPortFrom(clientSrcIP, clientPort), + expectedDst: netip.AddrPortFrom(unusedMagicIP, serverPort), + expectedFilterResponse: filter.Drop, + }, + { + description: "accept-dont-nat-other-mapping-error", + p: &packet.Parsed{ + Src: netip.AddrPortFrom(clientSrcIP, clientPort), + Dst: netip.AddrPortFrom(magicIP, serverPort), + }, + throwMappingErr: true, + expectedSrc: netip.AddrPortFrom(clientSrcIP, clientPort), + expectedDst: netip.AddrPortFrom(magicIP, serverPort), + expectedFilterResponse: filter.Accept, + }, + { + description: "accept-dont-nat-uninteresting-client-side", + p: &packet.Parsed{ + Src: netip.AddrPortFrom(clientSrcIP, clientPort), + Dst: netip.AddrPortFrom(realIP, serverPort), + }, + expectedSrc: netip.AddrPortFrom(clientSrcIP, clientPort), + expectedDst: netip.AddrPortFrom(realIP, serverPort), + expectedFilterResponse: filter.Accept, + }, + { + description: "accept-dont-nat-uninteresting-connector-side", + p: 
&packet.Parsed{ + Src: netip.AddrPortFrom(realIP, serverPort), + Dst: netip.AddrPortFrom(clientSrcIP, clientPort), + }, + expectedSrc: netip.AddrPortFrom(realIP, serverPort), + expectedDst: netip.AddrPortFrom(clientSrcIP, clientPort), + expectedFilterResponse: filter.Accept, + }, + } + + for _, tt := range tests { + t.Run(tt.description, func(t *testing.T) { + mock := &testConn25{} + mock.clientTransitIPForMagicIPFn = func(mip netip.Addr) (netip.Addr, error) { + if tt.throwMappingErr { + return netip.Addr{}, errors.New("synthetic mapping error") + } + if mip == magicIP { + return transitIP, nil + } + if mip == unusedMagicIP { + return netip.Addr{}, ErrUnmappedMagicIP + } + return netip.Addr{}, nil + } + dph := newDatapathHandler(mock, nil) + + tt.p.IPProto = ipproto.UDP + tt.p.IPVersion = 4 + tt.p.StuffForTesting(40) + + if want, got := tt.expectedFilterResponse, dph.HandlePacketFromTunDevice(tt.p); want != got { + t.Errorf("unexpected filter response: want %v, got %v", want, got) + } + if want, got := tt.expectedSrc, tt.p.Src; want != got { + t.Errorf("unexpected packet src: want %v, got %v", want, got) + } + if want, got := tt.expectedDst, tt.p.Dst; want != got { + t.Errorf("unexpected packet dst: want %v, got %v", want, got) + } + }) + } +} + +func TestHandlePacketFromWireGuard(t *testing.T) { + clientSrcIP := netip.MustParseAddr("100.70.0.1") + unknownSrcIP := netip.MustParseAddr("100.99.99.99") + transitIP := netip.MustParseAddr("169.254.0.1") + realIP := netip.MustParseAddr("240.64.0.1") + + clientPort := uint16(1234) + serverPort := uint16(80) + + tests := []struct { + description string + p *packet.Parsed + throwMappingErr bool + expectedSrc netip.AddrPort + expectedDst netip.AddrPort + expectedFilterResponse filter.Response + }{ + { + description: "accept-and-nat-new-connector-flow-mapped-src-and-transit-ip", + p: &packet.Parsed{ + Src: netip.AddrPortFrom(clientSrcIP, clientPort), + Dst: netip.AddrPortFrom(transitIP, serverPort), + }, + expectedSrc: 
netip.AddrPortFrom(clientSrcIP, clientPort), + expectedDst: netip.AddrPortFrom(realIP, serverPort), + expectedFilterResponse: filter.Accept, + }, + { + description: "drop-unmapped-src-and-transit-ip", + p: &packet.Parsed{ + Src: netip.AddrPortFrom(unknownSrcIP, clientPort), + Dst: netip.AddrPortFrom(transitIP, serverPort), + }, + expectedSrc: netip.AddrPortFrom(unknownSrcIP, clientPort), + expectedDst: netip.AddrPortFrom(transitIP, serverPort), + expectedFilterResponse: filter.Drop, + }, + { + description: "accept-dont-nat-other-mapping-error", + p: &packet.Parsed{ + Src: netip.AddrPortFrom(clientSrcIP, clientPort), + Dst: netip.AddrPortFrom(transitIP, serverPort), + }, + throwMappingErr: true, + expectedSrc: netip.AddrPortFrom(clientSrcIP, clientPort), + expectedDst: netip.AddrPortFrom(transitIP, serverPort), + expectedFilterResponse: filter.Accept, + }, + { + description: "accept-dont-nat-uninteresting-connector-side", + p: &packet.Parsed{ + Src: netip.AddrPortFrom(clientSrcIP, clientPort), + Dst: netip.AddrPortFrom(realIP, serverPort), + }, + expectedSrc: netip.AddrPortFrom(clientSrcIP, clientPort), + expectedDst: netip.AddrPortFrom(realIP, serverPort), + expectedFilterResponse: filter.Accept, + }, + { + description: "accept-dont-nat-uninteresting-client-side", + p: &packet.Parsed{ + Src: netip.AddrPortFrom(realIP, serverPort), + Dst: netip.AddrPortFrom(clientSrcIP, clientPort), + }, + expectedSrc: netip.AddrPortFrom(realIP, serverPort), + expectedDst: netip.AddrPortFrom(clientSrcIP, clientPort), + expectedFilterResponse: filter.Accept, + }, + } + + for _, tt := range tests { + t.Run(tt.description, func(t *testing.T) { + mock := &testConn25{} + mock.connectorRealIPForTransitIPConnectionFn = func(src, tip netip.Addr) (netip.Addr, error) { + if tt.throwMappingErr { + return netip.Addr{}, errors.New("synthetic mapping error") + } + if tip == transitIP { + if src == clientSrcIP { + return realIP, nil + } else { + return netip.Addr{}, ErrUnmappedSrcAndTransitIP + } 
+ } + return netip.Addr{}, nil + } + dph := newDatapathHandler(mock, nil) + + tt.p.IPProto = ipproto.UDP + tt.p.IPVersion = 4 + tt.p.StuffForTesting(40) + + if want, got := tt.expectedFilterResponse, dph.HandlePacketFromWireGuard(tt.p); want != got { + t.Errorf("unexpected filter response: want %v, got %v", want, got) + } + if want, got := tt.expectedSrc, tt.p.Src; want != got { + t.Errorf("unexpected packet src: want %v, got %v", want, got) + } + if want, got := tt.expectedDst, tt.p.Dst; want != got { + t.Errorf("unexpected packet dst: want %v, got %v", want, got) + } + }) + } +} + +func TestClientFlowCache(t *testing.T) { + getTransitIPCalled := false + + clientSrcIP := netip.MustParseAddr("100.70.0.1") + magicIP := netip.MustParseAddr("10.64.0.1") + transitIP := netip.MustParseAddr("169.254.0.1") + + clientPort := uint16(1234) + serverPort := uint16(80) + + mock := &testConn25{} + mock.clientTransitIPForMagicIPFn = func(mip netip.Addr) (netip.Addr, error) { + if getTransitIPCalled { + t.Errorf("ClientTransitIPForMagicIP unexpectedly called more than once") + } + getTransitIPCalled = true + return transitIP, nil + } + dph := newDatapathHandler(mock, nil) + + outgoing := packet.Parsed{ + IPProto: ipproto.UDP, + IPVersion: 4, + Src: netip.AddrPortFrom(clientSrcIP, clientPort), + Dst: netip.AddrPortFrom(magicIP, serverPort), + } + outgoing.StuffForTesting(40) + + o1 := outgoing + if dph.HandlePacketFromTunDevice(&o1) != filter.Accept { + t.Errorf("first call to HandlePacketFromTunDevice was not accepted") + } + if want, got := netip.AddrPortFrom(transitIP, serverPort), o1.Dst; want != got { + t.Errorf("unexpected packet dst after first call: want %v, got %v", want, got) + } + // The second call should use the cache. 
+ o2 := outgoing + if dph.HandlePacketFromTunDevice(&o2) != filter.Accept { + t.Errorf("second call to HandlePacketFromTunDevice was not accepted") + } + if want, got := netip.AddrPortFrom(transitIP, serverPort), o2.Dst; want != got { + t.Errorf("unexpected packet dst after second call: want %v, got %v", want, got) + } + + // Return traffic should have the Transit IP as the source, + // and be SNATed to the Magic IP. + incoming := &packet.Parsed{ + IPProto: ipproto.UDP, + IPVersion: 4, + Src: netip.AddrPortFrom(transitIP, serverPort), + Dst: netip.AddrPortFrom(clientSrcIP, clientPort), + } + incoming.StuffForTesting(40) + + if dph.HandlePacketFromWireGuard(incoming) != filter.Accept { + t.Errorf("call to HandlePacketFromWireGuard was not accepted") + } + if want, got := netip.AddrPortFrom(magicIP, serverPort), incoming.Src; want != got { + t.Errorf("unexpected packet src after second call: want %v, got %v", want, got) + } +} + +func TestConnectorFlowCache(t *testing.T) { + getRealIPCalled := false + + clientSrcIP := netip.MustParseAddr("100.70.0.1") + transitIP := netip.MustParseAddr("169.254.0.1") + realIP := netip.MustParseAddr("240.64.0.1") + + clientPort := uint16(1234) + serverPort := uint16(80) + + mock := &testConn25{} + mock.connectorRealIPForTransitIPConnectionFn = func(src, tip netip.Addr) (netip.Addr, error) { + if getRealIPCalled { + t.Errorf("ConnectorRealIPForTransitIPConnection unexpectedly called more than once") + } + getRealIPCalled = true + return realIP, nil + } + dph := newDatapathHandler(mock, nil) + + outgoing := packet.Parsed{ + IPProto: ipproto.UDP, + IPVersion: 4, + Src: netip.AddrPortFrom(clientSrcIP, clientPort), + Dst: netip.AddrPortFrom(transitIP, serverPort), + } + outgoing.StuffForTesting(40) + + o1 := outgoing + if dph.HandlePacketFromWireGuard(&o1) != filter.Accept { + t.Errorf("first call to HandlePacketFromWireGuard was not accepted") + } + if want, got := netip.AddrPortFrom(realIP, serverPort), o1.Dst; want != got { + 
t.Errorf("unexpected packet dst after first call: want %v, got %v", want, got) + } + // The second call should use the cache. + o2 := outgoing + if dph.HandlePacketFromWireGuard(&o2) != filter.Accept { + t.Errorf("second call to HandlePacketFromWireGuard was not accepted") + } + if want, got := netip.AddrPortFrom(realIP, serverPort), o2.Dst; want != got { + t.Errorf("unexpected packet dst after second call: want %v, got %v", want, got) + } + + // Return traffic should have the Real IP as the source, + // and be SNATed to the Transit IP. + incoming := &packet.Parsed{ + IPProto: ipproto.UDP, + IPVersion: 4, + Src: netip.AddrPortFrom(realIP, serverPort), + Dst: netip.AddrPortFrom(clientSrcIP, clientPort), + } + incoming.StuffForTesting(40) + + if dph.HandlePacketFromTunDevice(incoming) != filter.Accept { + t.Errorf("call to HandlePacketFromTunDevice was not accepted") + } + if want, got := netip.AddrPortFrom(transitIP, serverPort), incoming.Src; want != got { + t.Errorf("unexpected packet src after second call: want %v, got %v", want, got) + } +} diff --git a/feature/conn25/flowtable.go b/feature/conn25/flowtable.go new file mode 100644 index 000000000..27486ded9 --- /dev/null +++ b/feature/conn25/flowtable.go @@ -0,0 +1,149 @@ +// Copyright (c) Tailscale Inc & contributors +// SPDX-License-Identifier: BSD-3-Clause + +package conn25 + +import ( + "errors" + "sync" + + "tailscale.com/net/flowtrack" + "tailscale.com/net/packet" +) + +// PacketAction may modify the packet. +type PacketAction func(*packet.Parsed) + +// FlowData is an entry stored in the [FlowTable]. +type FlowData struct { + Tuple flowtrack.Tuple + Action PacketAction +} + +// Origin is used to track the direction of a flow. +type Origin uint8 + +const ( + // FromTun indicates the flow is from the tun device. + FromTun Origin = iota + + // FromWireGuard indicates the flow is from the WireGuard tunnel. 
+ FromWireGuard +) + +type cachedFlow struct { + flow FlowData + paired flowtrack.Tuple // tuple for the other direction +} + +// FlowTable stores and retrieves [FlowData] that can be looked up +// by 5-tuple. New entries specify the tuple to use for both directions +// of traffic flow. The underlying cache is LRU, and the maximum number +// of entries is specified in calls to [NewFlowTable]. FlowTable has +// its own mutex and is safe for concurrent use. +type FlowTable struct { + mu sync.Mutex + fromTunCache *flowtrack.Cache[cachedFlow] // guarded by mu + fromWGCache *flowtrack.Cache[cachedFlow] // guarded by mu +} + +// NewFlowTable returns a [FlowTable] with at most maxEntries entries. +// A maxEntries of 0 indicates no maximum. See also [FlowTable]. +func NewFlowTable(maxEntries int) *FlowTable { + return &FlowTable{ + fromTunCache: &flowtrack.Cache[cachedFlow]{ + MaxEntries: maxEntries, + }, + fromWGCache: &flowtrack.Cache[cachedFlow]{ + MaxEntries: maxEntries, + }, + } +} + +// LookupFromTunDevice looks up a [FlowData] entry that is valid to run for packets +// observed as coming from the tun device. The tuple must match the direction it was +// stored with. +func (t *FlowTable) LookupFromTunDevice(k flowtrack.Tuple) (FlowData, bool) { + return t.lookup(k, FromTun) +} + +// LookupFromWireGuard looks up a [FlowData] entry that is valid to run for packets +// observed as coming from the WireGuard tunnel. The tuple must match the direction it was +// stored with. 
+func (t *FlowTable) LookupFromWireGuard(k flowtrack.Tuple) (FlowData, bool) { + return t.lookup(k, FromWireGuard) +} + +func (t *FlowTable) lookup(k flowtrack.Tuple, want Origin) (FlowData, bool) { + var cache *flowtrack.Cache[cachedFlow] + switch want { + case FromTun: + cache = t.fromTunCache + case FromWireGuard: + cache = t.fromWGCache + default: + return FlowData{}, false + } + + t.mu.Lock() + defer t.mu.Unlock() + + v, ok := cache.Get(k) + if !ok { + return FlowData{}, false + } + return v.flow, true +} + +// NewFlowFromTunDevice installs (or overwrites) both the forward and return entries. +// The forward tuple is tagged as FromTun, and the return tuple is tagged as FromWireGuard. +// If overwriting, it removes the tuples previously paired with the forward and return keys to avoid stale mappings. +func (t *FlowTable) NewFlowFromTunDevice(fwd, rev FlowData) error { + return t.newFlow(FromTun, fwd, rev) +} + +// NewFlowFromWireGuard installs (or overwrites) both the forward and return entries. +// The forward tuple is tagged as FromWireGuard, and the return tuple is tagged as FromTun. +// If overwriting, it removes the tuples previously paired with the forward and return keys to avoid stale mappings. +func (t *FlowTable) NewFlowFromWireGuard(fwd, rev FlowData) error { + return t.newFlow(FromWireGuard, fwd, rev) +} + +func (t *FlowTable) newFlow(fwdOrigin Origin, fwd, rev FlowData) error { + if fwd.Action == nil || rev.Action == nil { + return errors.New("nil action received for flow") + } + + var fwdCache, revCache *flowtrack.Cache[cachedFlow] + switch fwdOrigin { + case FromTun: + fwdCache, revCache = t.fromTunCache, t.fromWGCache + case FromWireGuard: + fwdCache, revCache = t.fromWGCache, t.fromTunCache + default: + return errors.New("newFlow called with unknown direction") + } + + t.mu.Lock() + defer t.mu.Unlock() + + // If overwriting an existing entry, remove its previously-paired mapping so + // we don't leave stale tuples around. 
+ if old, ok := fwdCache.Get(fwd.Tuple); ok { + revCache.Remove(old.paired) + } + if old, ok := revCache.Get(rev.Tuple); ok { + fwdCache.Remove(old.paired) + } + + fwdCache.Add(fwd.Tuple, cachedFlow{ + flow: fwd, + paired: rev.Tuple, + }) + revCache.Add(rev.Tuple, cachedFlow{ + flow: rev, + paired: fwd.Tuple, + }) + + return nil +} diff --git a/feature/conn25/flowtable_test.go b/feature/conn25/flowtable_test.go new file mode 100644 index 000000000..8c3cd63a2 --- /dev/null +++ b/feature/conn25/flowtable_test.go @@ -0,0 +1,125 @@ +// Copyright (c) Tailscale Inc & contributors +// SPDX-License-Identifier: BSD-3-Clause + +package conn25 + +import ( + "net/netip" + "testing" + + "tailscale.com/net/flowtrack" + "tailscale.com/net/packet" + "tailscale.com/types/ipproto" +) + +func TestFlowTable(t *testing.T) { + ft := NewFlowTable(0) + + fwdTuple := flowtrack.MakeTuple( + ipproto.UDP, + netip.MustParseAddrPort("1.2.3.4:1000"), + netip.MustParseAddrPort("4.3.2.1:80"), + ) + // The reverse tuple is defined by the caller. It doesn't have to be the mirror image of fwd, + // to account for intentional modifications, like NAT. + revTuple := flowtrack.MakeTuple( + ipproto.UDP, + netip.MustParseAddrPort("4.3.2.2:80"), + netip.MustParseAddrPort("1.2.3.4:1000"), + ) + + fwdAction, revAction := 0, 0 + fwdData := FlowData{ + Tuple: fwdTuple, + Action: func(_ *packet.Parsed) { fwdAction++ }, + } + revData := FlowData{ + Tuple: revTuple, + Action: func(_ *packet.Parsed) { revAction++ }, + } + + // For this test setup, from the tun device will be "forward", + // and from WG will be "reverse". + if err := ft.NewFlowFromTunDevice(fwdData, revData); err != nil { + t.Fatalf("got non-nil error for new flow from tun device") + } + + // Test basic lookups. 
+ lookupFwd, ok := ft.LookupFromTunDevice(fwdTuple) + if !ok { + t.Fatalf("got not found on first lookup from tun device") + } + lookupFwd.Action(nil) + if fwdAction != 1 { + t.Errorf("action for fwd tuple key was not executed") + } + + lookupRev, ok := ft.LookupFromWireGuard(revTuple) + if !ok { + t.Fatalf("got not found on first lookup from WireGuard") + } + lookupRev.Action(nil) + if revAction != 1 { + t.Errorf("action for rev tuple key was not executed") + } + + // Test not found error. + notFoundTuple := flowtrack.MakeTuple( + ipproto.UDP, + netip.MustParseAddrPort("1.2.3.4:1000"), + netip.MustParseAddrPort("4.0.4.4:80"), + ) + if _, ok := ft.LookupFromTunDevice(notFoundTuple); ok { + t.Errorf("expected not found for foreign tuple") + } + + // Wrong direction is also not found. + if _, ok := ft.LookupFromWireGuard(fwdTuple); ok { + t.Errorf("expected not found for wrong direction tuple") + } + + // Overwriting forward tuple removes its reverse pair as well. + newRevData := FlowData{ + Tuple: flowtrack.MakeTuple( + ipproto.UDP, + netip.MustParseAddrPort("9.9.9.9:99"), + netip.MustParseAddrPort("8.8.8.8:88"), + ), + Action: func(_ *packet.Parsed) {}, + } + if err := ft.NewFlowFromTunDevice( + fwdData, + newRevData, + ); err != nil { + t.Fatalf("got non-nil error for new flow from tun device") + } + if _, ok := ft.LookupFromWireGuard(revTuple); ok { + t.Errorf("expected not found for removed reverse tuple") + } + + // Overwriting reverse tuple removes its forward pair as well. + if err := ft.NewFlowFromTunDevice( + FlowData{ + Tuple: flowtrack.MakeTuple( + ipproto.UDP, + netip.MustParseAddrPort("8.8.8.8:88"), + netip.MustParseAddrPort("9.9.9.9:99"), + ), + Action: func(_ *packet.Parsed) {}, + }, + newRevData, // This is the same "reverse" data installed in previous test. 
+ ); err != nil { + t.Fatalf("got non-nil error for new flow from tun device") + } + if _, ok := ft.LookupFromTunDevice(fwdTuple); ok { + t.Errorf("expected not found for removed forward tuple") + } + + // Nil action returns an error. + if err := ft.NewFlowFromTunDevice( + FlowData{}, + FlowData{}, + ); err == nil { + t.Errorf("expected non-nil error for nil data") + } +}