This is an integration test that covers all the code in Direct, Auto, and LocalBackend that processes NetMaps and creates a Filter. The test uses tsnet as a convenient proxy for setting up all the client pieces correctly, but is not actually a test specific to tsnet. Updates tailscale/corp#20514 Signed-off-by: James Sanderson <jsanderson@tailscale.com>main
parent
5be6ff9b62
commit
85a7abef0c
@ -0,0 +1,248 @@ |
||||
// Copyright (c) Tailscale Inc & AUTHORS
|
||||
// SPDX-License-Identifier: BSD-3-Clause
|
||||
|
||||
package tsnet |
||||
|
||||
import ( |
||||
"context" |
||||
"fmt" |
||||
"net/netip" |
||||
"testing" |
||||
"time" |
||||
|
||||
"tailscale.com/ipn" |
||||
"tailscale.com/tailcfg" |
||||
"tailscale.com/types/ipproto" |
||||
"tailscale.com/types/key" |
||||
"tailscale.com/types/netmap" |
||||
"tailscale.com/util/must" |
||||
"tailscale.com/wgengine/filter" |
||||
) |
||||
|
||||
// waitFor blocks until a NetMap is seen on the IPN bus that satisfies the given
|
||||
// function f. Note: has no timeout, should be called with a ctx that has an
|
||||
// appropriate timeout set.
|
||||
func waitFor(t testing.TB, ctx context.Context, s *Server, f func(*netmap.NetworkMap) bool) error { |
||||
t.Helper() |
||||
watcher, err := s.localClient.WatchIPNBus(ctx, ipn.NotifyInitialNetMap) |
||||
if err != nil { |
||||
t.Fatalf("error watching IPN bus: %s", err) |
||||
} |
||||
defer watcher.Close() |
||||
|
||||
for { |
||||
n, err := watcher.Next() |
||||
if err != nil { |
||||
return fmt.Errorf("getting next ipn.Notify from IPN bus: %w", err) |
||||
} |
||||
if n.NetMap != nil { |
||||
if f(n.NetMap) { |
||||
return nil |
||||
} |
||||
} |
||||
} |
||||
} |
||||
|
||||
// TestPacketFilterFromNetmap tests all of the client code for processing
|
||||
// netmaps and turning them into packet filters together. Only the control-plane
|
||||
// side is mocked out.
|
||||
func TestPacketFilterFromNetmap(t *testing.T) { |
||||
t.Parallel() |
||||
|
||||
var key key.NodePublic |
||||
must.Do(key.UnmarshalText([]byte("nodekey:5c8f86d5fc70d924e55f02446165a5dae8f822994ad26bcf4b08fd841f9bf261"))) |
||||
|
||||
type check struct { |
||||
src string |
||||
dst string |
||||
port uint16 |
||||
want filter.Response |
||||
} |
||||
|
||||
tests := []struct { |
||||
name string |
||||
mapResponse *tailcfg.MapResponse |
||||
waitTest func(*netmap.NetworkMap) bool |
||||
|
||||
incrementalMapResponse *tailcfg.MapResponse // optional
|
||||
incrementalWaitTest func(*netmap.NetworkMap) bool // optional
|
||||
|
||||
checks []check |
||||
}{ |
||||
{ |
||||
name: "IP_based_peers", |
||||
mapResponse: &tailcfg.MapResponse{ |
||||
Node: &tailcfg.Node{ |
||||
Addresses: []netip.Prefix{netip.MustParsePrefix("1.1.1.1/32")}, |
||||
}, |
||||
Peers: []*tailcfg.Node{{ |
||||
ID: 2, |
||||
Name: "foo", |
||||
Key: key, |
||||
Addresses: []netip.Prefix{netip.MustParsePrefix("2.2.2.2/32")}, |
||||
CapMap: nil, |
||||
}}, |
||||
PacketFilter: []tailcfg.FilterRule{{ |
||||
SrcIPs: []string{"2.2.2.2/32"}, |
||||
DstPorts: []tailcfg.NetPortRange{{ |
||||
IP: "1.1.1.1/32", |
||||
Ports: tailcfg.PortRange{ |
||||
First: 22, |
||||
Last: 22, |
||||
}, |
||||
}}, |
||||
IPProto: []int{int(ipproto.TCP)}, |
||||
}}, |
||||
}, |
||||
waitTest: func(nm *netmap.NetworkMap) bool { |
||||
return len(nm.Peers) > 0 |
||||
}, |
||||
checks: []check{ |
||||
{src: "2.2.2.2", dst: "1.1.1.1", port: 22, want: filter.Accept}, |
||||
{src: "2.2.2.2", dst: "1.1.1.1", port: 23, want: filter.Drop}, // different port
|
||||
{src: "3.3.3.3", dst: "1.1.1.1", port: 22, want: filter.Drop}, // different src
|
||||
{src: "2.2.2.2", dst: "1.1.1.2", port: 22, want: filter.Drop}, // different dst
|
||||
}, |
||||
}, |
||||
{ |
||||
name: "capmap_based_peers", |
||||
mapResponse: &tailcfg.MapResponse{ |
||||
Node: &tailcfg.Node{ |
||||
Addresses: []netip.Prefix{netip.MustParsePrefix("1.1.1.1/32")}, |
||||
}, |
||||
Peers: []*tailcfg.Node{{ |
||||
ID: 2, |
||||
Name: "foo", |
||||
Key: key, |
||||
Addresses: []netip.Prefix{netip.MustParsePrefix("2.2.2.2/32")}, |
||||
CapMap: tailcfg.NodeCapMap{"X": nil}, |
||||
}}, |
||||
PacketFilter: []tailcfg.FilterRule{{ |
||||
SrcIPs: []string{"cap:X"}, |
||||
DstPorts: []tailcfg.NetPortRange{{ |
||||
IP: "1.1.1.1/32", |
||||
Ports: tailcfg.PortRange{ |
||||
First: 22, |
||||
Last: 22, |
||||
}, |
||||
}}, |
||||
IPProto: []int{int(ipproto.TCP)}, |
||||
}}, |
||||
}, |
||||
waitTest: func(nm *netmap.NetworkMap) bool { |
||||
return len(nm.Peers) > 0 |
||||
}, |
||||
checks: []check{ |
||||
{src: "2.2.2.2", dst: "1.1.1.1", port: 22, want: filter.Accept}, |
||||
{src: "2.2.2.2", dst: "1.1.1.1", port: 23, want: filter.Drop}, // different port
|
||||
{src: "3.3.3.3", dst: "1.1.1.1", port: 22, want: filter.Drop}, // different src
|
||||
{src: "2.2.2.2", dst: "1.1.1.2", port: 22, want: filter.Drop}, // different dst
|
||||
}, |
||||
}, |
||||
{ |
||||
name: "capmap_based_peers_changed", |
||||
mapResponse: &tailcfg.MapResponse{ |
||||
Node: &tailcfg.Node{ |
||||
Addresses: []netip.Prefix{netip.MustParsePrefix("1.1.1.1/32")}, |
||||
CapMap: tailcfg.NodeCapMap{"X-sigil": nil}, |
||||
}, |
||||
PacketFilter: []tailcfg.FilterRule{{ |
||||
SrcIPs: []string{"cap:label-1"}, |
||||
DstPorts: []tailcfg.NetPortRange{{ |
||||
IP: "1.1.1.1/32", |
||||
Ports: tailcfg.PortRange{ |
||||
First: 22, |
||||
Last: 22, |
||||
}, |
||||
}}, |
||||
IPProto: []int{int(ipproto.TCP)}, |
||||
}}, |
||||
}, |
||||
waitTest: func(nm *netmap.NetworkMap) bool { |
||||
return nm.SelfNode.HasCap("X-sigil") |
||||
}, |
||||
incrementalMapResponse: &tailcfg.MapResponse{ |
||||
PeersChanged: []*tailcfg.Node{{ |
||||
ID: 2, |
||||
Name: "foo", |
||||
Key: key, |
||||
Addresses: []netip.Prefix{netip.MustParsePrefix("2.2.2.2/32")}, |
||||
CapMap: tailcfg.NodeCapMap{"label-1": nil}, |
||||
}}, |
||||
}, |
||||
incrementalWaitTest: func(nm *netmap.NetworkMap) bool { |
||||
return len(nm.Peers) > 0 |
||||
}, |
||||
checks: []check{ |
||||
{src: "2.2.2.2", dst: "1.1.1.1", port: 22, want: filter.Accept}, |
||||
{src: "2.2.2.2", dst: "1.1.1.1", port: 23, want: filter.Drop}, // different port
|
||||
{src: "3.3.3.3", dst: "1.1.1.1", port: 22, want: filter.Drop}, // different src
|
||||
{src: "2.2.2.2", dst: "1.1.1.2", port: 22, want: filter.Drop}, // different dst
|
||||
}, |
||||
}, |
||||
} |
||||
for _, test := range tests { |
||||
t.Run(test.name, func(t *testing.T) { |
||||
t.Parallel() |
||||
ctx, cancel := context.WithTimeout(t.Context(), 30*time.Second) |
||||
defer cancel() |
||||
|
||||
controlURL, c := startControl(t) |
||||
s, _, pubKey := startServer(t, ctx, controlURL, "node") |
||||
|
||||
if test.waitTest(s.lb.NetMap()) { |
||||
t.Fatal("waitTest already passes before sending initial netmap: this will be flaky") |
||||
} |
||||
|
||||
if !c.AddRawMapResponse(pubKey, test.mapResponse) { |
||||
t.Fatalf("could not send map response to %s", pubKey) |
||||
} |
||||
|
||||
if err := waitFor(t, ctx, s, test.waitTest); err != nil { |
||||
t.Fatalf("waitFor: %s", err) |
||||
} |
||||
|
||||
pf := s.lb.GetFilterForTest() |
||||
|
||||
for _, check := range test.checks { |
||||
got := pf.Check(netip.MustParseAddr(check.src), netip.MustParseAddr(check.dst), check.port, ipproto.TCP) |
||||
|
||||
want := check.want |
||||
if test.incrementalMapResponse != nil { |
||||
want = filter.Drop |
||||
} |
||||
if got != want { |
||||
t.Errorf("check %s -> %s:%d, got: %s, want: %s", check.src, check.dst, check.port, got, want) |
||||
} |
||||
} |
||||
|
||||
if test.incrementalMapResponse != nil { |
||||
if test.incrementalWaitTest == nil { |
||||
t.Fatal("incrementalWaitTest must be set if incrementalMapResponse is set") |
||||
} |
||||
|
||||
if test.incrementalWaitTest(s.lb.NetMap()) { |
||||
t.Fatal("incrementalWaitTest already passes before sending incremental netmap: this will be flaky") |
||||
} |
||||
|
||||
if !c.AddRawMapResponse(pubKey, test.incrementalMapResponse) { |
||||
t.Fatalf("could not send map response to %s", pubKey) |
||||
} |
||||
|
||||
if err := waitFor(t, ctx, s, test.incrementalWaitTest); err != nil { |
||||
t.Fatalf("waitFor: %s", err) |
||||
} |
||||
|
||||
pf := s.lb.GetFilterForTest() |
||||
|
||||
for _, check := range test.checks { |
||||
got := pf.Check(netip.MustParseAddr(check.src), netip.MustParseAddr(check.dst), check.port, ipproto.TCP) |
||||
if got != check.want { |
||||
t.Errorf("check %s -> %s:%d, got: %s, want: %s", check.src, check.dst, check.port, got, check.want) |
||||
} |
||||
} |
||||
} |
||||
|
||||
}) |
||||
} |
||||
} |
||||
Loading…
Reference in new issue