diff --git a/net/tstun/wrap_test.go b/net/tstun/wrap_test.go index 8515cb8f0..1744fc302 100644 --- a/net/tstun/wrap_test.go +++ b/net/tstun/wrap_test.go @@ -41,6 +41,7 @@ import ( "tailscale.com/util/must" "tailscale.com/util/usermetric" "tailscale.com/wgengine/filter" + "tailscale.com/wgengine/netstack/gro" "tailscale.com/wgengine/wgcfg" ) @@ -991,3 +992,67 @@ func TestTSMPDisco(t *testing.T) { } }) } + +func TestInterceptOrdering(t *testing.T) { + bus := eventbustest.NewBus(t) + chtun, tun := newChannelTUN(t.Logf, bus, true) + defer tun.Close() + + var seq uint8 + orderedFilterFn := func(expected uint8) FilterFunc { + return func(_ *packet.Parsed, _ *Wrapper) filter.Response { + seq++ + if expected != seq { + t.Errorf("got sequence %d; want %d", seq, expected) + } + return filter.Accept + } + } + + ordereredGROFilterFn := func(expected uint8) GROFilterFunc { + return func(_ *packet.Parsed, _ *Wrapper, _ *gro.GRO) (filter.Response, *gro.GRO) { + seq++ + if expected != seq { + t.Errorf("got sequence %d; want %d", seq, expected) + } + return filter.Accept, nil + } + } + + // As the number of inbound intercepts change, + // this value should change. + numInboundIntercepts := uint8(3) + + tun.PreFilterPacketInboundFromWireGuard = orderedFilterFn(1) + tun.PostFilterPacketInboundFromWireGuardAppConnector = orderedFilterFn(2) + tun.PostFilterPacketInboundFromWireGuard = ordereredGROFilterFn(3) + + // Write the packet. + go func() { <-chtun.Inbound }() // Simulate tun device receiving. + packet := [][]byte{udp4("5.6.7.8", "1.2.3.4", 89, 89)} + tun.Write(packet, 0) + + if seq != numInboundIntercepts { + t.Errorf("got number of intercepts run in Write(): %d; want: %d", seq, numInboundIntercepts) + } + + // As the number of inbound intercepts change, + // this value should change. + numOutboundIntercepts := uint8(4) + + seq = 0 + tun.PreFilterPacketOutboundToWireGuardNetstackIntercept = ordereredGROFilterFn(1) + tun.PreFilterPacketOutboundToWireGuardEngineIntercept = orderedFilterFn(2) + tun.PreFilterPacketOutboundToWireGuardAppConnectorIntercept = orderedFilterFn(3) + tun.PostFilterPacketOutboundToWireGuard = orderedFilterFn(4) + + // Read the packet. + var buf [MaxPacketSize]byte + sizes := make([]int, 1) + chtun.Outbound <- udp4("1.2.3.4", "5.6.7.8", 98, 98) // Simulate tun device sending. + tun.Read([][]byte{buf[:]}, sizes, 0) + + if seq != numOutboundIntercepts { + t.Errorf("got number of intercepts run in Read(): %d; want: %d", seq, numOutboundIntercepts) + } +}