From 4f1406f05ae7a1ea4c79d12587bdb3156bb2e12e Mon Sep 17 00:00:00 2001 From: "M. J. Fromberger" Date: Fri, 13 Feb 2026 10:59:43 -0800 Subject: [PATCH] ipn/ipnlocal/netmapcache: include packet filters in the cache (#18715) Store packet filter rules in the cache. The match expressions are derived from the filter rules, so these do not need to be stored explicitly, but ensure they are properly reconstructed when the cache is read back. Update the tests to include these fields, and provide representative values. Updates #12639 Change-Id: I9bdb972a86d2c6387177d393ada1f54805a2448b Signed-off-by: M. J. Fromberger --- ipn/ipnlocal/netmapcache/netmapcache.go | 86 +++++++++++++------- ipn/ipnlocal/netmapcache/netmapcache_test.go | 55 ++++++++++++- ipn/ipnlocal/netmapcache/types.go | 7 ++ 3 files changed, 115 insertions(+), 33 deletions(-) diff --git a/ipn/ipnlocal/netmapcache/netmapcache.go b/ipn/ipnlocal/netmapcache/netmapcache.go index b12443b99..1b8347f0b 100644 --- a/ipn/ipnlocal/netmapcache/netmapcache.go +++ b/ipn/ipnlocal/netmapcache/netmapcache.go @@ -27,6 +27,7 @@ import ( "tailscale.com/types/netmap" "tailscale.com/util/mak" "tailscale.com/util/set" + "tailscale.com/wgengine/filter" ) var ( @@ -45,17 +46,17 @@ var ( type Cache struct { store Store - // wantKeys records the storage keys from the last write or load of a cached + // wantKeys records the cache keys from the last write or load of a cached // netmap. This is used to prune keys that are no longer referenced after an // update. - wantKeys set.Set[string] + wantKeys set.Set[cacheKey] // lastWrote records the last values written to each stored key. // // TODO(creachadair): This is meant to avoid disk writes, but I'm not // convinced we need it. Or maybe just track hashes of the content rather // than caching a complete copy. - lastWrote map[string]lastWrote + lastWrote map[cacheKey]lastWrote } // NewCache constructs a new empty [Cache] from the given [Store]. @@ -66,8 +67,8 @@ func NewCache(s Store) *Cache { } return &Cache{ store: s, - wantKeys: make(set.Set[string]), - lastWrote: make(map[string]lastWrote), + wantKeys: make(set.Set[cacheKey]), + lastWrote: make(map[cacheKey]lastWrote), } } @@ -76,7 +77,7 @@ type lastWrote struct { at time.Time } -func (c *Cache) writeJSON(ctx context.Context, key string, v any) error { +func (c *Cache) writeJSON(ctx context.Context, key cacheKey, v any) error { j, err := jsonv1.Marshal(v) if err != nil { return fmt.Errorf("JSON marshalling %q: %w", key, err) @@ -90,7 +91,7 @@ func (c *Cache) writeJSON(ctx context.Context, key string, v any) error { return nil } - if err := c.store.Store(ctx, key, j); err != nil { + if err := c.store.Store(ctx, string(key), j); err != nil { return err } @@ -110,11 +111,12 @@ func (c *Cache) removeUnwantedKeys(ctx context.Context) error { errs = append(errs, err) break } - if !c.wantKeys.Contains(key) { + ckey := cacheKey(key) + if !c.wantKeys.Contains(ckey) { if err := c.store.Remove(ctx, key); err != nil { errs = append(errs, fmt.Errorf("remove key %q: %w", key, err)) } - delete(c.lastWrote, key) // even if removal failed, we don't want it + delete(c.lastWrote, ckey) // even if removal failed, we don't want it } } return errors.Join(errs...) @@ -177,6 +179,20 @@ func (s FileStore) Remove(ctx context.Context, key string) error { return err } +// cacheKey is a type wrapper for strings used as cache keys. +type cacheKey string + +const ( + selfKey cacheKey = "self" + miscKey cacheKey = "msic" + dnsKey cacheKey = "dns" + derpMapKey cacheKey = "derpmap" + peerKeyPrefix cacheKey = "peer-" // + stable ID + userKeyPrefix cacheKey = "user-" // + profile ID + sshPolicyKey cacheKey = "ssh" + packetFilterKey cacheKey = "filter" +) + // Store records nm in the cache, replacing any previously-cached values. func (c *Cache) Store(ctx context.Context, nm *netmap.NetworkMap) error { if !buildfeatures.HasCacheNetMap || nm == nil || nm.Cached { @@ -187,7 +203,7 @@ func (c *Cache) Store(ctx context.Context, nm *netmap.NetworkMap) error { } clear(c.wantKeys) - if err := c.writeJSON(ctx, "misc", netmapMisc{ + if err := c.writeJSON(ctx, miscKey, netmapMisc{ MachineKey: &nm.MachineKey, CollectServices: &nm.CollectServices, DisplayMessages: &nm.DisplayMessages, @@ -198,33 +214,36 @@ func (c *Cache) Store(ctx context.Context, nm *netmap.NetworkMap) error { }); err != nil { return err } - if err := c.writeJSON(ctx, "dns", netmapDNS{DNS: &nm.DNS}); err != nil { + if err := c.writeJSON(ctx, dnsKey, netmapDNS{DNS: &nm.DNS}); err != nil { return err } - if err := c.writeJSON(ctx, "derpmap", netmapDERPMap{DERPMap: &nm.DERPMap}); err != nil { + if err := c.writeJSON(ctx, derpMapKey, netmapDERPMap{DERPMap: &nm.DERPMap}); err != nil { return err } - if err := c.writeJSON(ctx, "self", netmapNode{Node: &nm.SelfNode}); err != nil { + if err := c.writeJSON(ctx, selfKey, netmapNode{Node: &nm.SelfNode}); err != nil { return err // N.B. The NodeKey and AllCaps fields can be recovered from SelfNode on // load, and do not need to be stored separately. } for _, p := range nm.Peers { - key := fmt.Sprintf("peer-%s", p.StableID()) + key := peerKeyPrefix + cacheKey(p.StableID()) if err := c.writeJSON(ctx, key, netmapNode{Node: &p}); err != nil { return err } } for uid, u := range nm.UserProfiles { - key := fmt.Sprintf("user-%d", uid) - if err := c.writeJSON(ctx, key, netmapUserProfile{UserProfile: &u}); err != nil { + key := fmt.Sprintf("%s%d", userKeyPrefix, uid) + if err := c.writeJSON(ctx, cacheKey(key), netmapUserProfile{UserProfile: &u}); err != nil { return err } } + if err := c.writeJSON(ctx, packetFilterKey, netmapPacketFilter{Rules: &nm.PacketFilterRules}); err != nil { + return err + } if buildfeatures.HasSSH && nm.SSHPolicy != nil { - if err := c.writeJSON(ctx, "ssh", netmapSSH{SSHPolicy: &nm.SSHPolicy}); err != nil { + if err := c.writeJSON(ctx, sshPolicyKey, netmapSSH{SSHPolicy: &nm.SSHPolicy}); err != nil { return err } } @@ -244,12 +263,12 @@ func (c *Cache) Load(ctx context.Context) (*netmap.NetworkMap, error) { // At minimum, we require that the cache contain a "self" node, or the data // are not usable. - if self, err := c.store.Load(ctx, "self"); errors.Is(err, ErrKeyNotFound) { + if self, err := c.store.Load(ctx, string(selfKey)); errors.Is(err, ErrKeyNotFound) { return nil, ErrCacheNotAvailable } else if err := jsonv1.Unmarshal(self, &netmapNode{Node: &nm.SelfNode}); err != nil { return nil, err } - c.wantKeys.Add("self") + c.wantKeys.Add(selfKey) // If we successfully recovered a SelfNode, pull out its related fields. if s := nm.SelfNode; s.Valid() { @@ -266,7 +285,7 @@ func (c *Cache) Load(ctx context.Context) (*netmap.NetworkMap, error) { // Unmarshal the contents of each specified cache entry directly into the // fields of the output. See the comment in types.go for more detail. - if err := c.readJSON(ctx, "misc", &netmapMisc{ + if err := c.readJSON(ctx, miscKey, &netmapMisc{ MachineKey: &nm.MachineKey, CollectServices: &nm.CollectServices, DisplayMessages: &nm.DisplayMessages, @@ -278,43 +297,52 @@ func (c *Cache) Load(ctx context.Context) (*netmap.NetworkMap, error) { return nil, err } - if err := c.readJSON(ctx, "dns", &netmapDNS{DNS: &nm.DNS}); err != nil { + if err := c.readJSON(ctx, dnsKey, &netmapDNS{DNS: &nm.DNS}); err != nil { return nil, err } - if err := c.readJSON(ctx, "derpmap", &netmapDERPMap{DERPMap: &nm.DERPMap}); err != nil { + if err := c.readJSON(ctx, derpMapKey, &netmapDERPMap{DERPMap: &nm.DERPMap}); err != nil { return nil, err } - for key, err := range c.store.List(ctx, "peer-") { + for key, err := range c.store.List(ctx, string(peerKeyPrefix)) { if err != nil { return nil, err } var peer tailcfg.NodeView - if err := c.readJSON(ctx, key, &netmapNode{Node: &peer}); err != nil { + if err := c.readJSON(ctx, cacheKey(key), &netmapNode{Node: &peer}); err != nil { return nil, err } nm.Peers = append(nm.Peers, peer) } slices.SortFunc(nm.Peers, func(a, b tailcfg.NodeView) int { return cmp.Compare(a.ID(), b.ID()) }) - for key, err := range c.store.List(ctx, "user-") { + for key, err := range c.store.List(ctx, string(userKeyPrefix)) { if err != nil { return nil, err } var up tailcfg.UserProfileView - if err := c.readJSON(ctx, key, &netmapUserProfile{UserProfile: &up}); err != nil { + if err := c.readJSON(ctx, cacheKey(key), &netmapUserProfile{UserProfile: &up}); err != nil { return nil, err } mak.Set(&nm.UserProfiles, up.ID(), up) } - if err := c.readJSON(ctx, "ssh", &netmapSSH{SSHPolicy: &nm.SSHPolicy}); err != nil { + if err := c.readJSON(ctx, sshPolicyKey, &netmapSSH{SSHPolicy: &nm.SSHPolicy}); err != nil { return nil, err } + if err := c.readJSON(ctx, packetFilterKey, &netmapPacketFilter{Rules: &nm.PacketFilterRules}); err != nil { + return nil, err + } else if r := nm.PacketFilterRules; r.Len() != 0 { + // Reconstitute packet match expressions from the filter rules, + nm.PacketFilter, err = filter.MatchesFromFilterRules(r.AsSlice()) + if err != nil { + return nil, err + } + } return &nm, nil } -func (c *Cache) readJSON(ctx context.Context, key string, value any) error { - data, err := c.store.Load(ctx, key) +func (c *Cache) readJSON(ctx context.Context, key cacheKey, value any) error { + data, err := c.store.Load(ctx, string(key)) if errors.Is(err, ErrKeyNotFound) { return nil } else if err != nil { diff --git a/ipn/ipnlocal/netmapcache/netmapcache_test.go b/ipn/ipnlocal/netmapcache/netmapcache_test.go index b31db2d5e..b5a46d298 100644 --- a/ipn/ipnlocal/netmapcache/netmapcache_test.go +++ b/ipn/ipnlocal/netmapcache/netmapcache_test.go @@ -11,6 +11,7 @@ import ( "fmt" "iter" "maps" + "net/netip" "os" "reflect" "slices" @@ -23,10 +24,13 @@ import ( "tailscale.com/ipn/ipnlocal/netmapcache" "tailscale.com/tailcfg" "tailscale.com/tka" + "tailscale.com/types/ipproto" "tailscale.com/types/key" "tailscale.com/types/netmap" "tailscale.com/types/views" "tailscale.com/util/set" + "tailscale.com/wgengine/filter" + "tailscale.com/wgengine/filter/filtertype" ) // Input values for valid-looking placeholder values for keys, hashes, etc. @@ -68,6 +72,27 @@ func init() { panic(fmt.Sprintf("invalid test AUM hash %q: %v", testAUMHashString, err)) } + pfRules := []tailcfg.FilterRule{ + { + SrcIPs: []string{"192.168.0.0/16"}, + DstPorts: []tailcfg.NetPortRange{{ + IP: "*", + Ports: tailcfg.PortRange{First: 2000, Last: 9999}, + }}, + IPProto: []int{1, 6, 17}, // ICMPv4, TCP, UDP + CapGrant: []tailcfg.CapGrant{{ + Dsts: []netip.Prefix{netip.MustParsePrefix("192.168.4.0/24")}, + CapMap: tailcfg.PeerCapMap{ + "tailscale.com/testcap": []tailcfg.RawMessage{`"apple"`, `"pear"`}, + }, + }}, + }, + } + pfMatch, err := filter.MatchesFromFilterRules(pfRules) + if err != nil { + panic(fmt.Sprintf("invalid packet filter rules: %v", err)) + } + // The following network map must have a non-zero non-empty value for every // field that is to be stored in the cache. The test checks for this using // reflection, as a way to ensure that new fields added to the type are @@ -79,8 +104,9 @@ func init() { testMap = &netmap.NetworkMap{ Cached: false, // not cached, this is metadata for the cache machinery - PacketFilter: nil, // not cached - PacketFilterRules: views.Slice[tailcfg.FilterRule]{}, // not cached + // These two fields must contain compatible data. + PacketFilterRules: views.SliceOf(pfRules), + PacketFilter: pfMatch, // Fields stored under the "self" key. // Note that SelfNode must have a valid user in order to be considered @@ -235,7 +261,7 @@ func TestInvalidCache(t *testing.T) { // skippedMapFields are the names of fields that should not be considered by // network map caching, and thus skipped when comparing test results. var skippedMapFields = []string{ - "Cached", "PacketFilter", "PacketFilterRules", + "Cached", } // checkFieldCoverage logs an error in t if any of the fields of nm are zero @@ -366,6 +392,27 @@ func (t testStore) Remove(_ context.Context, key string) error { delete(t, key); func diffNetMaps(got, want *netmap.NetworkMap) string { return cmp.Diff(got, want, cmpopts.IgnoreFields(netmap.NetworkMap{}, skippedMapFields...), - cmpopts.EquateComparable(key.NodePublic{}, key.MachinePublic{}), + cmpopts.IgnoreFields(filtertype.Match{}, "SrcsContains"), // function pointer + cmpopts.EquateComparable(key.NodePublic{}, key.MachinePublic{}, netip.Prefix{}), + cmp.Comparer(eqViewsSlice(eqFilterRules)), + cmp.Comparer(eqViewsSlice(func(a, b ipproto.Proto) bool { return a == b })), ) } + +func eqViewsSlice[T any](eqVal func(x, y T) bool) func(a, b views.Slice[T]) bool { + return func(a, b views.Slice[T]) bool { + if a.Len() != b.Len() { + return false + } + for i := range a.Len() { + if !eqVal(a.At(i), b.At(i)) { + return false + } + } + return true + } +} + +func eqFilterRules(a, b tailcfg.FilterRule) bool { + return cmp.Equal(a, b, cmpopts.EquateComparable(netip.Prefix{})) +} diff --git a/ipn/ipnlocal/netmapcache/types.go b/ipn/ipnlocal/netmapcache/types.go index 2fb5a1575..c9f9efc1e 100644 --- a/ipn/ipnlocal/netmapcache/types.go +++ b/ipn/ipnlocal/netmapcache/types.go @@ -7,6 +7,7 @@ import ( "tailscale.com/tailcfg" "tailscale.com/tka" "tailscale.com/types/key" + "tailscale.com/types/views" ) // The fields in the following wrapper types are all pointers, even when their @@ -50,3 +51,9 @@ type netmapNode struct { type netmapUserProfile struct { UserProfile *tailcfg.UserProfileView } + +type netmapPacketFilter struct { + Rules *views.Slice[tailcfg.FilterRule] + + // Match expressions are derived from the rules. +}