diff --git a/ipn/ipnlocal/netmapcache/netmapcache.go b/ipn/ipnlocal/netmapcache/netmapcache.go index b12443b99..1b8347f0b 100644 --- a/ipn/ipnlocal/netmapcache/netmapcache.go +++ b/ipn/ipnlocal/netmapcache/netmapcache.go @@ -27,6 +27,7 @@ import ( "tailscale.com/types/netmap" "tailscale.com/util/mak" "tailscale.com/util/set" + "tailscale.com/wgengine/filter" ) var ( @@ -45,17 +46,17 @@ var ( type Cache struct { store Store - // wantKeys records the storage keys from the last write or load of a cached + // wantKeys records the cache keys from the last write or load of a cached // netmap. This is used to prune keys that are no longer referenced after an // update. - wantKeys set.Set[string] + wantKeys set.Set[cacheKey] // lastWrote records the last values written to each stored key. // // TODO(creachadair): This is meant to avoid disk writes, but I'm not // convinced we need it. Or maybe just track hashes of the content rather // than caching a complete copy. - lastWrote map[string]lastWrote + lastWrote map[cacheKey]lastWrote } // NewCache constructs a new empty [Cache] from the given [Store]. @@ -66,8 +67,8 @@ func NewCache(s Store) *Cache { } return &Cache{ store: s, - wantKeys: make(set.Set[string]), - lastWrote: make(map[string]lastWrote), + wantKeys: make(set.Set[cacheKey]), + lastWrote: make(map[cacheKey]lastWrote), } } @@ -76,7 +77,7 @@ type lastWrote struct { at time.Time } -func (c *Cache) writeJSON(ctx context.Context, key string, v any) error { +func (c *Cache) writeJSON(ctx context.Context, key cacheKey, v any) error { j, err := jsonv1.Marshal(v) if err != nil { return fmt.Errorf("JSON marshalling %q: %w", key, err) @@ -90,7 +91,7 @@ func (c *Cache) writeJSON(ctx context.Context, key string, v any) error { return nil } - if err := c.store.Store(ctx, key, j); err != nil { + if err := c.store.Store(ctx, string(key), j); err != nil { return err } @@ -110,11 +111,12 @@ func (c *Cache) removeUnwantedKeys(ctx context.Context) error { errs = append(errs, err) break } - if !c.wantKeys.Contains(key) { + ckey := cacheKey(key) + if !c.wantKeys.Contains(ckey) { if err := c.store.Remove(ctx, key); err != nil { errs = append(errs, fmt.Errorf("remove key %q: %w", key, err)) } - delete(c.lastWrote, key) // even if removal failed, we don't want it + delete(c.lastWrote, ckey) // even if removal failed, we don't want it } } return errors.Join(errs...) @@ -177,6 +179,20 @@ func (s FileStore) Remove(ctx context.Context, key string) error { return err } +// cacheKey is a type wrapper for strings used as cache keys. +type cacheKey string + +const ( + selfKey cacheKey = "self" + miscKey cacheKey = "msic" + dnsKey cacheKey = "dns" + derpMapKey cacheKey = "derpmap" + peerKeyPrefix cacheKey = "peer-" // + stable ID + userKeyPrefix cacheKey = "user-" // + profile ID + sshPolicyKey cacheKey = "ssh" + packetFilterKey cacheKey = "filter" +) + // Store records nm in the cache, replacing any previously-cached values. func (c *Cache) Store(ctx context.Context, nm *netmap.NetworkMap) error { if !buildfeatures.HasCacheNetMap || nm == nil || nm.Cached { @@ -187,7 +203,7 @@ func (c *Cache) Store(ctx context.Context, nm *netmap.NetworkMap) error { } clear(c.wantKeys) - if err := c.writeJSON(ctx, "misc", netmapMisc{ + if err := c.writeJSON(ctx, miscKey, netmapMisc{ MachineKey: &nm.MachineKey, CollectServices: &nm.CollectServices, DisplayMessages: &nm.DisplayMessages, @@ -198,33 +214,36 @@ func (c *Cache) Store(ctx context.Context, nm *netmap.NetworkMap) error { }); err != nil { return err } - if err := c.writeJSON(ctx, "dns", netmapDNS{DNS: &nm.DNS}); err != nil { + if err := c.writeJSON(ctx, dnsKey, netmapDNS{DNS: &nm.DNS}); err != nil { return err } - if err := c.writeJSON(ctx, "derpmap", netmapDERPMap{DERPMap: &nm.DERPMap}); err != nil { + if err := c.writeJSON(ctx, derpMapKey, netmapDERPMap{DERPMap: &nm.DERPMap}); err != nil { return err } - if err := c.writeJSON(ctx, "self", netmapNode{Node: &nm.SelfNode}); err != nil { + if err := c.writeJSON(ctx, selfKey, netmapNode{Node: &nm.SelfNode}); err != nil { return err // N.B. The NodeKey and AllCaps fields can be recovered from SelfNode on // load, and do not need to be stored separately. } for _, p := range nm.Peers { - key := fmt.Sprintf("peer-%s", p.StableID()) + key := peerKeyPrefix + cacheKey(p.StableID()) if err := c.writeJSON(ctx, key, netmapNode{Node: &p}); err != nil { return err } } for uid, u := range nm.UserProfiles { - key := fmt.Sprintf("user-%d", uid) - if err := c.writeJSON(ctx, key, netmapUserProfile{UserProfile: &u}); err != nil { + key := fmt.Sprintf("%s%d", userKeyPrefix, uid) + if err := c.writeJSON(ctx, cacheKey(key), netmapUserProfile{UserProfile: &u}); err != nil { return err } } + if err := c.writeJSON(ctx, packetFilterKey, netmapPacketFilter{Rules: &nm.PacketFilterRules}); err != nil { + return err + } if buildfeatures.HasSSH && nm.SSHPolicy != nil { - if err := c.writeJSON(ctx, "ssh", netmapSSH{SSHPolicy: &nm.SSHPolicy}); err != nil { + if err := c.writeJSON(ctx, sshPolicyKey, netmapSSH{SSHPolicy: &nm.SSHPolicy}); err != nil { return err } } @@ -244,12 +263,12 @@ func (c *Cache) Load(ctx context.Context) (*netmap.NetworkMap, error) { // At minimum, we require that the cache contain a "self" node, or the data // are not usable. - if self, err := c.store.Load(ctx, "self"); errors.Is(err, ErrKeyNotFound) { + if self, err := c.store.Load(ctx, string(selfKey)); errors.Is(err, ErrKeyNotFound) { return nil, ErrCacheNotAvailable } else if err := jsonv1.Unmarshal(self, &netmapNode{Node: &nm.SelfNode}); err != nil { return nil, err } - c.wantKeys.Add("self") + c.wantKeys.Add(selfKey) // If we successfully recovered a SelfNode, pull out its related fields. if s := nm.SelfNode; s.Valid() { @@ -266,7 +285,7 @@ func (c *Cache) Load(ctx context.Context) (*netmap.NetworkMap, error) { // Unmarshal the contents of each specified cache entry directly into the // fields of the output. See the comment in types.go for more detail. - if err := c.readJSON(ctx, "misc", &netmapMisc{ + if err := c.readJSON(ctx, miscKey, &netmapMisc{ MachineKey: &nm.MachineKey, CollectServices: &nm.CollectServices, DisplayMessages: &nm.DisplayMessages, @@ -278,43 +297,52 @@ func (c *Cache) Load(ctx context.Context) (*netmap.NetworkMap, error) { return nil, err } - if err := c.readJSON(ctx, "dns", &netmapDNS{DNS: &nm.DNS}); err != nil { + if err := c.readJSON(ctx, dnsKey, &netmapDNS{DNS: &nm.DNS}); err != nil { return nil, err } - if err := c.readJSON(ctx, "derpmap", &netmapDERPMap{DERPMap: &nm.DERPMap}); err != nil { + if err := c.readJSON(ctx, derpMapKey, &netmapDERPMap{DERPMap: &nm.DERPMap}); err != nil { return nil, err } - for key, err := range c.store.List(ctx, "peer-") { + for key, err := range c.store.List(ctx, string(peerKeyPrefix)) { if err != nil { return nil, err } var peer tailcfg.NodeView - if err := c.readJSON(ctx, key, &netmapNode{Node: &peer}); err != nil { + if err := c.readJSON(ctx, cacheKey(key), &netmapNode{Node: &peer}); err != nil { return nil, err } nm.Peers = append(nm.Peers, peer) } slices.SortFunc(nm.Peers, func(a, b tailcfg.NodeView) int { return cmp.Compare(a.ID(), b.ID()) }) - for key, err := range c.store.List(ctx, "user-") { + for key, err := range c.store.List(ctx, string(userKeyPrefix)) { if err != nil { return nil, err } var up tailcfg.UserProfileView - if err := c.readJSON(ctx, key, &netmapUserProfile{UserProfile: &up}); err != nil { + if err := c.readJSON(ctx, cacheKey(key), &netmapUserProfile{UserProfile: &up}); err != nil { return nil, err } mak.Set(&nm.UserProfiles, up.ID(), up) } - if err := c.readJSON(ctx, "ssh", &netmapSSH{SSHPolicy: &nm.SSHPolicy}); err != nil { + if err := c.readJSON(ctx, sshPolicyKey, &netmapSSH{SSHPolicy: &nm.SSHPolicy}); err != nil { return nil, err } + if err := c.readJSON(ctx, packetFilterKey, &netmapPacketFilter{Rules: &nm.PacketFilterRules}); err != nil { + return nil, err + } else if r := nm.PacketFilterRules; r.Len() != 0 { + // Reconstitute packet match expressions from the filter rules, + nm.PacketFilter, err = filter.MatchesFromFilterRules(r.AsSlice()) + if err != nil { + return nil, err + } + } return &nm, nil } -func (c *Cache) readJSON(ctx context.Context, key string, value any) error { - data, err := c.store.Load(ctx, key) +func (c *Cache) readJSON(ctx context.Context, key cacheKey, value any) error { + data, err := c.store.Load(ctx, string(key)) if errors.Is(err, ErrKeyNotFound) { return nil } else if err != nil { diff --git a/ipn/ipnlocal/netmapcache/netmapcache_test.go b/ipn/ipnlocal/netmapcache/netmapcache_test.go index b31db2d5e..b5a46d298 100644 --- a/ipn/ipnlocal/netmapcache/netmapcache_test.go +++ b/ipn/ipnlocal/netmapcache/netmapcache_test.go @@ -11,6 +11,7 @@ import ( "fmt" "iter" "maps" + "net/netip" "os" "reflect" "slices" @@ -23,10 +24,13 @@ import ( "tailscale.com/ipn/ipnlocal/netmapcache" "tailscale.com/tailcfg" "tailscale.com/tka" + "tailscale.com/types/ipproto" "tailscale.com/types/key" "tailscale.com/types/netmap" "tailscale.com/types/views" "tailscale.com/util/set" + "tailscale.com/wgengine/filter" + "tailscale.com/wgengine/filter/filtertype" ) // Input values for valid-looking placeholder values for keys, hashes, etc. @@ -68,6 +72,27 @@ func init() { panic(fmt.Sprintf("invalid test AUM hash %q: %v", testAUMHashString, err)) } + pfRules := []tailcfg.FilterRule{ + { + SrcIPs: []string{"192.168.0.0/16"}, + DstPorts: []tailcfg.NetPortRange{{ + IP: "*", + Ports: tailcfg.PortRange{First: 2000, Last: 9999}, + }}, + IPProto: []int{1, 6, 17}, // ICMPv4, TCP, UDP + CapGrant: []tailcfg.CapGrant{{ + Dsts: []netip.Prefix{netip.MustParsePrefix("192.168.4.0/24")}, + CapMap: tailcfg.PeerCapMap{ + "tailscale.com/testcap": []tailcfg.RawMessage{`"apple"`, `"pear"`}, + }, + }}, + }, + } + pfMatch, err := filter.MatchesFromFilterRules(pfRules) + if err != nil { + panic(fmt.Sprintf("invalid packet filter rules: %v", err)) + } + // The following network map must have a non-zero non-empty value for every // field that is to be stored in the cache. The test checks for this using // reflection, as a way to ensure that new fields added to the type are @@ -79,8 +104,9 @@ func init() { testMap = &netmap.NetworkMap{ Cached: false, // not cached, this is metadata for the cache machinery - PacketFilter: nil, // not cached - PacketFilterRules: views.Slice[tailcfg.FilterRule]{}, // not cached + // These two fields must contain compatible data. + PacketFilterRules: views.SliceOf(pfRules), + PacketFilter: pfMatch, // Fields stored under the "self" key. // Note that SelfNode must have a valid user in order to be considered @@ -235,7 +261,7 @@ func TestInvalidCache(t *testing.T) { // skippedMapFields are the names of fields that should not be considered by // network map caching, and thus skipped when comparing test results. var skippedMapFields = []string{ - "Cached", "PacketFilter", "PacketFilterRules", + "Cached", } // checkFieldCoverage logs an error in t if any of the fields of nm are zero @@ -366,6 +392,27 @@ func (t testStore) Remove(_ context.Context, key string) error { delete(t, key); func diffNetMaps(got, want *netmap.NetworkMap) string { return cmp.Diff(got, want, cmpopts.IgnoreFields(netmap.NetworkMap{}, skippedMapFields...), - cmpopts.EquateComparable(key.NodePublic{}, key.MachinePublic{}), + cmpopts.IgnoreFields(filtertype.Match{}, "SrcsContains"), // function pointer + cmpopts.EquateComparable(key.NodePublic{}, key.MachinePublic{}, netip.Prefix{}), + cmp.Comparer(eqViewsSlice(eqFilterRules)), + cmp.Comparer(eqViewsSlice(func(a, b ipproto.Proto) bool { return a == b })), ) } + +func eqViewsSlice[T any](eqVal func(x, y T) bool) func(a, b views.Slice[T]) bool { + return func(a, b views.Slice[T]) bool { + if a.Len() != b.Len() { + return false + } + for i := range a.Len() { + if !eqVal(a.At(i), b.At(i)) { + return false + } + } + return true + } +} + +func eqFilterRules(a, b tailcfg.FilterRule) bool { + return cmp.Equal(a, b, cmpopts.EquateComparable(netip.Prefix{})) +} diff --git a/ipn/ipnlocal/netmapcache/types.go b/ipn/ipnlocal/netmapcache/types.go index 2fb5a1575..c9f9efc1e 100644 --- a/ipn/ipnlocal/netmapcache/types.go +++ b/ipn/ipnlocal/netmapcache/types.go @@ -7,6 +7,7 @@ import ( "tailscale.com/tailcfg" "tailscale.com/tka" "tailscale.com/types/key" + "tailscale.com/types/views" ) // The fields in the following wrapper types are all pointers, even when their @@ -50,3 +51,9 @@ type netmapNode struct { type netmapUserProfile struct { UserProfile *tailcfg.UserProfileView } + +type netmapPacketFilter struct { + Rules *views.Slice[tailcfg.FilterRule] + + // Match expressions are derived from the rules. +}