util/syspolicy: add caching handler (#10288)
Fixes tailscale/corp#15850 Co-authored-by: Adrian Dewhurst <adrian@tailscale.com> Signed-off-by: Claire Wang <claire@tailscale.com>main
parent
719ee4415e
commit
b8a2aedccd
@ -0,0 +1,98 @@ |
||||
// Copyright (c) Tailscale Inc & AUTHORS
|
||||
// SPDX-License-Identifier: BSD-3-Clause
|
||||
|
||||
package syspolicy |
||||
|
||||
import ( |
||||
"errors" |
||||
"sync" |
||||
) |
||||
|
||||
// CachingHandler is a handler that reads policies from an underlying handler the first time each key is requested
|
||||
// and permanently caches the result unless there is an error. If there is an ErrNoSuchKey error, that result is cached,
|
||||
// otherwise the actual error is returned and the next read for that key will retry using the handler.
|
||||
type CachingHandler struct { |
||||
mu sync.Mutex |
||||
strings map[string]string |
||||
uint64s map[string]uint64 |
||||
bools map[string]bool |
||||
notFound map[string]bool |
||||
handler Handler |
||||
} |
||||
|
||||
// NewCachingHandler creates a CachingHandler given a handler.
|
||||
func NewCachingHandler(handler Handler) *CachingHandler { |
||||
return &CachingHandler{ |
||||
handler: handler, |
||||
strings: make(map[string]string), |
||||
uint64s: make(map[string]uint64), |
||||
bools: make(map[string]bool), |
||||
notFound: make(map[string]bool), |
||||
} |
||||
} |
||||
|
||||
// ReadString reads the policy settings value string given the key.
|
||||
// ReadString first reads from the handler's cache before resorting to using the handler.
|
||||
func (ch *CachingHandler) ReadString(key string) (string, error) { |
||||
ch.mu.Lock() |
||||
defer ch.mu.Unlock() |
||||
if val, ok := ch.strings[key]; ok { |
||||
return val, nil |
||||
} |
||||
if notFound := ch.notFound[key]; notFound { |
||||
return "", ErrNoSuchKey |
||||
} |
||||
val, err := ch.handler.ReadString(key) |
||||
if errors.Is(err, ErrNoSuchKey) { |
||||
ch.notFound[key] = true |
||||
return "", err |
||||
} else if err != nil { |
||||
return "", err |
||||
} |
||||
ch.strings[key] = val |
||||
return val, nil |
||||
} |
||||
|
||||
// ReadUInt64 reads the policy settings uint64 value given the key.
|
||||
// ReadUInt64 first reads from the handler's cache before resorting to using the handler.
|
||||
func (ch *CachingHandler) ReadUInt64(key string) (uint64, error) { |
||||
ch.mu.Lock() |
||||
defer ch.mu.Unlock() |
||||
if val, ok := ch.uint64s[key]; ok { |
||||
return val, nil |
||||
} |
||||
if notFound := ch.notFound[key]; notFound { |
||||
return 0, ErrNoSuchKey |
||||
} |
||||
val, err := ch.handler.ReadUInt64(key) |
||||
if errors.Is(err, ErrNoSuchKey) { |
||||
ch.notFound[key] = true |
||||
return 0, err |
||||
} else if err != nil { |
||||
return 0, err |
||||
} |
||||
ch.uint64s[key] = val |
||||
return val, nil |
||||
} |
||||
|
||||
// ReadBoolean reads the policy settings boolean value given the key.
|
||||
// ReadBoolean first reads from the handler's cache before resorting to using the handler.
|
||||
func (ch *CachingHandler) ReadBoolean(key string) (bool, error) { |
||||
ch.mu.Lock() |
||||
defer ch.mu.Unlock() |
||||
if val, ok := ch.bools[key]; ok { |
||||
return val, nil |
||||
} |
||||
if notFound := ch.notFound[key]; notFound { |
||||
return false, ErrNoSuchKey |
||||
} |
||||
val, err := ch.handler.ReadBoolean(key) |
||||
if errors.Is(err, ErrNoSuchKey) { |
||||
ch.notFound[key] = true |
||||
return false, err |
||||
} else if err != nil { |
||||
return false, err |
||||
} |
||||
ch.bools[key] = val |
||||
return val, nil |
||||
} |
||||
@ -0,0 +1,262 @@ |
||||
// Copyright (c) Tailscale Inc & AUTHORS
|
||||
// SPDX-License-Identifier: BSD-3-Clause
|
||||
|
||||
package syspolicy |
||||
|
||||
import ( |
||||
"testing" |
||||
) |
||||
|
||||
func TestHandlerReadString(t *testing.T) { |
||||
tests := []struct { |
||||
name string |
||||
key string |
||||
handlerKey Key |
||||
handlerValue string |
||||
handlerError error |
||||
preserveHandler bool |
||||
wantValue string |
||||
wantErr error |
||||
strings map[string]string |
||||
expectedCalls int |
||||
}{ |
||||
{ |
||||
name: "read existing cached values", |
||||
key: "test", |
||||
handlerKey: "do not read", |
||||
strings: map[string]string{"test": "foo"}, |
||||
wantValue: "foo", |
||||
expectedCalls: 0, |
||||
}, |
||||
{ |
||||
name: "read existing values not cached", |
||||
key: "test", |
||||
handlerKey: "test", |
||||
handlerValue: "foo", |
||||
wantValue: "foo", |
||||
expectedCalls: 1, |
||||
}, |
||||
{ |
||||
name: "error no such key", |
||||
key: "test", |
||||
handlerKey: "test", |
||||
handlerError: ErrNoSuchKey, |
||||
wantErr: ErrNoSuchKey, |
||||
expectedCalls: 1, |
||||
}, |
||||
{ |
||||
name: "other error", |
||||
key: "test", |
||||
handlerKey: "test", |
||||
handlerError: someOtherError, |
||||
wantErr: someOtherError, |
||||
preserveHandler: true, |
||||
expectedCalls: 2, |
||||
}, |
||||
} |
||||
|
||||
for _, tt := range tests { |
||||
t.Run(tt.name, func(t *testing.T) { |
||||
testHandler := &testHandler{ |
||||
t: t, |
||||
key: tt.handlerKey, |
||||
s: tt.handlerValue, |
||||
err: tt.handlerError, |
||||
} |
||||
cache := NewCachingHandler(testHandler) |
||||
if tt.strings != nil { |
||||
cache.strings = tt.strings |
||||
} |
||||
got, err := cache.ReadString(tt.key) |
||||
if err != tt.wantErr { |
||||
t.Errorf("err=%v want %v", err, tt.wantErr) |
||||
} |
||||
if got != tt.wantValue { |
||||
t.Errorf("got %v want %v", got, cache.strings[tt.key]) |
||||
} |
||||
if !tt.preserveHandler { |
||||
testHandler.key, testHandler.s, testHandler.err = "do not read", "", nil |
||||
} |
||||
got, err = cache.ReadString(tt.key) |
||||
if err != tt.wantErr { |
||||
t.Errorf("repeat err=%v want %v", err, tt.wantErr) |
||||
} |
||||
if got != tt.wantValue { |
||||
t.Errorf("repeat got %v want %v", got, cache.strings[tt.key]) |
||||
} |
||||
if testHandler.calls != tt.expectedCalls { |
||||
t.Errorf("calls=%v want %v", testHandler.calls, tt.expectedCalls) |
||||
} |
||||
}) |
||||
} |
||||
} |
||||
|
||||
func TestHandlerReadUint64(t *testing.T) { |
||||
tests := []struct { |
||||
name string |
||||
key string |
||||
handlerKey Key |
||||
handlerValue uint64 |
||||
handlerError error |
||||
preserveHandler bool |
||||
wantValue uint64 |
||||
wantErr error |
||||
uint64s map[string]uint64 |
||||
expectedCalls int |
||||
}{ |
||||
{ |
||||
name: "read existing cached values", |
||||
key: "test", |
||||
handlerKey: "do not read", |
||||
uint64s: map[string]uint64{"test": 1}, |
||||
wantValue: 1, |
||||
expectedCalls: 0, |
||||
}, |
||||
{ |
||||
name: "read existing values not cached", |
||||
key: "test", |
||||
handlerKey: "test", |
||||
handlerValue: 1, |
||||
wantValue: 1, |
||||
expectedCalls: 1, |
||||
}, |
||||
{ |
||||
name: "error no such key", |
||||
key: "test", |
||||
handlerKey: "test", |
||||
handlerError: ErrNoSuchKey, |
||||
wantErr: ErrNoSuchKey, |
||||
expectedCalls: 1, |
||||
}, |
||||
{ |
||||
name: "other error", |
||||
key: "test", |
||||
handlerKey: "test", |
||||
handlerError: someOtherError, |
||||
wantErr: someOtherError, |
||||
preserveHandler: true, |
||||
expectedCalls: 2, |
||||
}, |
||||
} |
||||
|
||||
for _, tt := range tests { |
||||
t.Run(tt.name, func(t *testing.T) { |
||||
testHandler := &testHandler{ |
||||
t: t, |
||||
key: tt.handlerKey, |
||||
u64: tt.handlerValue, |
||||
err: tt.handlerError, |
||||
} |
||||
cache := NewCachingHandler(testHandler) |
||||
if tt.uint64s != nil { |
||||
cache.uint64s = tt.uint64s |
||||
} |
||||
got, err := cache.ReadUInt64(tt.key) |
||||
if err != tt.wantErr { |
||||
t.Errorf("err=%v want %v", err, tt.wantErr) |
||||
} |
||||
if got != tt.wantValue { |
||||
t.Errorf("got %v want %v", got, cache.strings[tt.key]) |
||||
} |
||||
if !tt.preserveHandler { |
||||
testHandler.key, testHandler.s, testHandler.err = "do not read", "", nil |
||||
} |
||||
got, err = cache.ReadUInt64(tt.key) |
||||
if err != tt.wantErr { |
||||
t.Errorf("repeat err=%v want %v", err, tt.wantErr) |
||||
} |
||||
if got != tt.wantValue { |
||||
t.Errorf("repeat got %v want %v", got, cache.strings[tt.key]) |
||||
} |
||||
if testHandler.calls != tt.expectedCalls { |
||||
t.Errorf("calls=%v want %v", testHandler.calls, tt.expectedCalls) |
||||
} |
||||
}) |
||||
} |
||||
|
||||
} |
||||
|
||||
func TestHandlerReadBool(t *testing.T) { |
||||
tests := []struct { |
||||
name string |
||||
key string |
||||
handlerKey Key |
||||
handlerValue bool |
||||
handlerError error |
||||
preserveHandler bool |
||||
wantValue bool |
||||
wantErr error |
||||
bools map[string]bool |
||||
expectedCalls int |
||||
}{ |
||||
{ |
||||
name: "read existing cached values", |
||||
key: "test", |
||||
handlerKey: "do not read", |
||||
bools: map[string]bool{"test": true}, |
||||
wantValue: true, |
||||
expectedCalls: 0, |
||||
}, |
||||
{ |
||||
name: "read existing values not cached", |
||||
key: "test", |
||||
handlerKey: "test", |
||||
handlerValue: true, |
||||
wantValue: true, |
||||
expectedCalls: 1, |
||||
}, |
||||
{ |
||||
name: "error no such key", |
||||
key: "test", |
||||
handlerKey: "test", |
||||
handlerError: ErrNoSuchKey, |
||||
wantErr: ErrNoSuchKey, |
||||
expectedCalls: 1, |
||||
}, |
||||
{ |
||||
name: "other error", |
||||
key: "test", |
||||
handlerKey: "test", |
||||
handlerError: someOtherError, |
||||
wantErr: someOtherError, |
||||
preserveHandler: true, |
||||
expectedCalls: 2, |
||||
}, |
||||
} |
||||
|
||||
for _, tt := range tests { |
||||
t.Run(tt.name, func(t *testing.T) { |
||||
testHandler := &testHandler{ |
||||
t: t, |
||||
key: tt.handlerKey, |
||||
b: tt.handlerValue, |
||||
err: tt.handlerError, |
||||
} |
||||
cache := NewCachingHandler(testHandler) |
||||
if tt.bools != nil { |
||||
cache.bools = tt.bools |
||||
} |
||||
got, err := cache.ReadBoolean(tt.key) |
||||
if err != tt.wantErr { |
||||
t.Errorf("err=%v want %v", err, tt.wantErr) |
||||
} |
||||
if got != tt.wantValue { |
||||
t.Errorf("got %v want %v", got, cache.strings[tt.key]) |
||||
} |
||||
if !tt.preserveHandler { |
||||
testHandler.key, testHandler.s, testHandler.err = "do not read", "", nil |
||||
} |
||||
got, err = cache.ReadBoolean(tt.key) |
||||
if err != tt.wantErr { |
||||
t.Errorf("repeat err=%v want %v", err, tt.wantErr) |
||||
} |
||||
if got != tt.wantValue { |
||||
t.Errorf("repeat got %v want %v", got, cache.strings[tt.key]) |
||||
} |
||||
if testHandler.calls != tt.expectedCalls { |
||||
t.Errorf("calls=%v want %v", testHandler.calls, tt.expectedCalls) |
||||
} |
||||
}) |
||||
} |
||||
|
||||
} |
||||
Loading…
Reference in new issue