In this PR we add syspolicy/rsop package that facilitates policy source registration and provides access to the resultant policy merged from all registered sources for a given scope. Updates #12687 Signed-off-by: Nick Khyl <nickk@tailscale.com>main
parent
2aa9125ac4
commit
ff5f233c3a
@ -0,0 +1,107 @@ |
||||
// Copyright (c) Tailscale Inc & AUTHORS
|
||||
// SPDX-License-Identifier: BSD-3-Clause
|
||||
|
||||
package rsop |
||||
|
||||
import ( |
||||
"reflect" |
||||
"slices" |
||||
"sync" |
||||
"time" |
||||
|
||||
"tailscale.com/util/set" |
||||
"tailscale.com/util/syspolicy/internal/loggerx" |
||||
"tailscale.com/util/syspolicy/setting" |
||||
) |
||||
|
||||
// Change represents a change from the Old to the New value of type T.
|
||||
type Change[T any] struct { |
||||
New, Old T |
||||
} |
||||
|
||||
// PolicyChangeCallback is a function called whenever a policy changes.
|
||||
type PolicyChangeCallback func(*PolicyChange) |
||||
|
||||
// PolicyChange describes a policy change.
|
||||
type PolicyChange struct { |
||||
snapshots Change[*setting.Snapshot] |
||||
} |
||||
|
||||
// New returns the [setting.Snapshot] after the change.
|
||||
func (c PolicyChange) New() *setting.Snapshot { |
||||
return c.snapshots.New |
||||
} |
||||
|
||||
// Old returns the [setting.Snapshot] before the change.
|
||||
func (c PolicyChange) Old() *setting.Snapshot { |
||||
return c.snapshots.Old |
||||
} |
||||
|
||||
// HasChanged reports whether a policy setting with the specified [setting.Key], has changed.
|
||||
func (c PolicyChange) HasChanged(key setting.Key) bool { |
||||
new, newErr := c.snapshots.New.GetErr(key) |
||||
old, oldErr := c.snapshots.Old.GetErr(key) |
||||
if newErr != nil && oldErr != nil { |
||||
return false |
||||
} |
||||
if newErr != nil || oldErr != nil { |
||||
return true |
||||
} |
||||
switch newVal := new.(type) { |
||||
case bool, uint64, string, setting.Visibility, setting.PreferenceOption, time.Duration: |
||||
return newVal != old |
||||
case []string: |
||||
oldVal, ok := old.([]string) |
||||
return !ok || !slices.Equal(newVal, oldVal) |
||||
default: |
||||
loggerx.Errorf("[unexpected] %q has an unsupported value type: %T", key, newVal) |
||||
return !reflect.DeepEqual(new, old) |
||||
} |
||||
} |
||||
|
||||
// policyChangeCallbacks are the callbacks to invoke when the effective policy changes.
|
||||
// It is safe for concurrent use.
|
||||
type policyChangeCallbacks struct { |
||||
mu sync.Mutex |
||||
cbs set.HandleSet[PolicyChangeCallback] |
||||
} |
||||
|
||||
// Register adds the specified callback to be invoked whenever the policy changes.
|
||||
func (c *policyChangeCallbacks) Register(callback PolicyChangeCallback) (unregister func()) { |
||||
c.mu.Lock() |
||||
handle := c.cbs.Add(callback) |
||||
c.mu.Unlock() |
||||
return func() { |
||||
c.mu.Lock() |
||||
delete(c.cbs, handle) |
||||
c.mu.Unlock() |
||||
} |
||||
} |
||||
|
||||
// Invoke calls the registered callback functions with the specified policy change info.
|
||||
func (c *policyChangeCallbacks) Invoke(snapshots Change[*setting.Snapshot]) { |
||||
var wg sync.WaitGroup |
||||
defer wg.Wait() |
||||
|
||||
c.mu.Lock() |
||||
defer c.mu.Unlock() |
||||
|
||||
wg.Add(len(c.cbs)) |
||||
change := &PolicyChange{snapshots: snapshots} |
||||
for _, cb := range c.cbs { |
||||
go func() { |
||||
defer wg.Done() |
||||
cb(change) |
||||
}() |
||||
} |
||||
} |
||||
|
||||
// Close awaits the completion of active callbacks and prevents any further invocations.
|
||||
func (c *policyChangeCallbacks) Close() { |
||||
c.mu.Lock() |
||||
defer c.mu.Unlock() |
||||
if c.cbs != nil { |
||||
clear(c.cbs) |
||||
c.cbs = nil |
||||
} |
||||
} |
||||
@ -0,0 +1,449 @@ |
||||
// Copyright (c) Tailscale Inc & AUTHORS
|
||||
// SPDX-License-Identifier: BSD-3-Clause
|
||||
|
||||
package rsop |
||||
|
||||
import ( |
||||
"errors" |
||||
"fmt" |
||||
"slices" |
||||
"sync" |
||||
"sync/atomic" |
||||
"time" |
||||
|
||||
"tailscale.com/util/syspolicy/internal/loggerx" |
||||
"tailscale.com/util/syspolicy/setting" |
||||
|
||||
"tailscale.com/util/syspolicy/source" |
||||
) |
||||
|
||||
// ErrPolicyClosed is returned by [Policy.Reload], [Policy.addSource],
|
||||
// [Policy.removeSource] and [Policy.replaceSource] if the policy has been closed.
|
||||
var ErrPolicyClosed = errors.New("effective policy closed") |
||||
|
||||
// The minimum and maximum wait times after detecting a policy change
|
||||
// before reloading the policy. This only affects policy reloads triggered
|
||||
// by a change in the underlying [source.Store] and does not impact
|
||||
// synchronous, caller-initiated reloads, such as when [Policy.Reload] is called.
|
||||
//
|
||||
// Policy changes occurring within [policyReloadMinDelay] of each other
|
||||
// will be batched together, resulting in a single policy reload
|
||||
// no later than [policyReloadMaxDelay] after the first detected change.
|
||||
// In other words, the effective policy will be reloaded no more often than once
|
||||
// every 5 seconds, but at most 15 seconds after an underlying [source.Store]
|
||||
// has issued a policy change callback.
|
||||
//
|
||||
// See [Policy.watchReload].
|
||||
var ( |
||||
policyReloadMinDelay = 5 * time.Second |
||||
policyReloadMaxDelay = 15 * time.Second |
||||
) |
||||
|
||||
// Policy provides access to the current effective [setting.Snapshot] for a given
|
||||
// scope and allows to reload it from the underlying [source.Store] list. It also allows to
|
||||
// subscribe and receive a callback whenever the effective [setting.Snapshot] is changed.
|
||||
//
|
||||
// It is safe for concurrent use.
|
||||
type Policy struct { |
||||
scope setting.PolicyScope |
||||
|
||||
reloadCh chan reloadRequest // 1-buffered; written to when a policy reload is required
|
||||
closeCh chan struct{} // closed to signal that the Policy is being closed
|
||||
doneCh chan struct{} // closed by [Policy.closeInternal]
|
||||
|
||||
// effective is the most recent version of the [setting.Snapshot]
|
||||
// containing policy settings merged from all applicable sources.
|
||||
effective atomic.Pointer[setting.Snapshot] |
||||
|
||||
changeCallbacks policyChangeCallbacks |
||||
|
||||
mu sync.Mutex |
||||
watcherStarted bool // whether [Policy.watchReload] was started
|
||||
sources source.ReadableSources |
||||
closing bool // whether [Policy.Close] was called (even if we're still closing)
|
||||
} |
||||
|
||||
// newPolicy returns a new [Policy] for the specified [setting.PolicyScope]
|
||||
// that tracks changes and merges policy settings read from the specified sources.
|
||||
func newPolicy(scope setting.PolicyScope, sources ...*source.Source) (_ *Policy, err error) { |
||||
readableSources := make(source.ReadableSources, 0, len(sources)) |
||||
defer func() { |
||||
if err != nil { |
||||
readableSources.Close() |
||||
} |
||||
}() |
||||
for _, s := range sources { |
||||
reader, err := s.Reader() |
||||
if err != nil { |
||||
return nil, fmt.Errorf("failed to get a store reader: %w", err) |
||||
} |
||||
session, err := reader.OpenSession() |
||||
if err != nil { |
||||
return nil, fmt.Errorf("failed to open a reading session: %w", err) |
||||
} |
||||
readableSources = append(readableSources, source.ReadableSource{Source: s, ReadingSession: session}) |
||||
} |
||||
|
||||
// Sort policy sources by their precedence from lower to higher.
|
||||
// For example, {UserPolicy},{ProfilePolicy},{DevicePolicy}.
|
||||
readableSources.StableSort() |
||||
|
||||
p := &Policy{ |
||||
scope: scope, |
||||
sources: readableSources, |
||||
reloadCh: make(chan reloadRequest, 1), |
||||
closeCh: make(chan struct{}), |
||||
doneCh: make(chan struct{}), |
||||
} |
||||
if _, err := p.reloadNow(false); err != nil { |
||||
p.Close() |
||||
return nil, err |
||||
} |
||||
p.startWatchReloadIfNeeded() |
||||
return p, nil |
||||
} |
||||
|
||||
// IsValid reports whether p is in a valid state and has not been closed.
|
||||
//
|
||||
// Since p's state can be changed by other goroutines at any time, this should
|
||||
// only be used as an optimization.
|
||||
func (p *Policy) IsValid() bool { |
||||
select { |
||||
case <-p.closeCh: |
||||
return false |
||||
default: |
||||
return true |
||||
} |
||||
} |
||||
|
||||
// Scope returns the [setting.PolicyScope] that this policy applies to.
|
||||
func (p *Policy) Scope() setting.PolicyScope { |
||||
return p.scope |
||||
} |
||||
|
||||
// Get returns the effective [setting.Snapshot].
|
||||
func (p *Policy) Get() *setting.Snapshot { |
||||
return p.effective.Load() |
||||
} |
||||
|
||||
// RegisterChangeCallback adds a function to be called whenever the effective
|
||||
// policy changes. The returned function can be used to unregister the callback.
|
||||
func (p *Policy) RegisterChangeCallback(callback PolicyChangeCallback) (unregister func()) { |
||||
return p.changeCallbacks.Register(callback) |
||||
} |
||||
|
||||
// Reload synchronously re-reads policy settings from the underlying list of policy sources,
|
||||
// constructing a new merged [setting.Snapshot] even if the policy remains unchanged.
|
||||
// In most scenarios, there's no need to re-read the policy manually.
|
||||
// Instead, it is recommended to register a policy change callback, or to use
|
||||
// the most recent [setting.Snapshot] returned by the [Policy.Get] method.
|
||||
//
|
||||
// It must not be called with p.mu held.
|
||||
func (p *Policy) Reload() (*setting.Snapshot, error) { |
||||
return p.reload(true) |
||||
} |
||||
|
||||
// reload is like Reload, but allows to specify whether to re-read policy settings
|
||||
// from unchanged policy sources.
|
||||
//
|
||||
// It must not be called with p.mu held.
|
||||
func (p *Policy) reload(force bool) (*setting.Snapshot, error) { |
||||
if !p.startWatchReloadIfNeeded() { |
||||
return p.Get(), nil |
||||
} |
||||
|
||||
respCh := make(chan reloadResponse, 1) |
||||
select { |
||||
case p.reloadCh <- reloadRequest{force: force, respCh: respCh}: |
||||
// continue
|
||||
case <-p.closeCh: |
||||
return nil, ErrPolicyClosed |
||||
} |
||||
select { |
||||
case resp := <-respCh: |
||||
return resp.policy, resp.err |
||||
case <-p.closeCh: |
||||
return nil, ErrPolicyClosed |
||||
} |
||||
} |
||||
|
||||
// reloadAsync requests an asynchronous background policy reload.
|
||||
// The policy will be reloaded no later than in [policyReloadMaxDelay].
|
||||
//
|
||||
// It must not be called with p.mu held.
|
||||
func (p *Policy) reloadAsync() { |
||||
if !p.startWatchReloadIfNeeded() { |
||||
return |
||||
} |
||||
select { |
||||
case p.reloadCh <- reloadRequest{}: |
||||
// Sent.
|
||||
default: |
||||
// A reload request is already en route.
|
||||
} |
||||
} |
||||
|
||||
// reloadNow loads and merges policies from all sources, updating the effective policy.
|
||||
// If the force parameter is true, it forcibly reloads policies
|
||||
// from the underlying policy store, even if no policy changes were detected.
|
||||
//
|
||||
// Except for the initial policy reload during the [Policy] creation,
|
||||
// this method should only be called from the [Policy.watchReload] goroutine.
|
||||
func (p *Policy) reloadNow(force bool) (*setting.Snapshot, error) { |
||||
new, err := p.readAndMerge(force) |
||||
if err != nil { |
||||
return nil, err |
||||
} |
||||
old := p.effective.Swap(new) |
||||
// A nil old value indicates the initial policy load rather than a policy change.
|
||||
// Additionally, we should not invoke the policy change callbacks unless the
|
||||
// policy items have actually changed.
|
||||
if old != nil && !old.EqualItems(new) { |
||||
snapshots := Change[*setting.Snapshot]{New: new, Old: old} |
||||
p.changeCallbacks.Invoke(snapshots) |
||||
} |
||||
return new, nil |
||||
} |
||||
|
||||
// Done returns a channel that is closed when the [Policy] is closed.
|
||||
func (p *Policy) Done() <-chan struct{} { |
||||
return p.doneCh |
||||
} |
||||
|
||||
// readAndMerge reads and merges policy settings from all applicable sources,
|
||||
// returning a [setting.Snapshot] with the merged result.
|
||||
// If the force parameter is true, it re-reads policy settings from each source
|
||||
// even if no policy change was observed, and returns an error if the read
|
||||
// operation fails.
|
||||
func (p *Policy) readAndMerge(force bool) (*setting.Snapshot, error) { |
||||
p.mu.Lock() |
||||
defer p.mu.Unlock() |
||||
// Start with an empty policy in the target scope.
|
||||
effective := setting.NewSnapshot(nil, setting.SummaryWith(p.scope)) |
||||
// Then merge policy settings from all sources.
|
||||
// Policy sources with the highest precedence (e.g., the device policy) are merged last,
|
||||
// overriding any conflicting policy settings with lower precedence.
|
||||
for _, s := range p.sources { |
||||
var policy *setting.Snapshot |
||||
if force { |
||||
var err error |
||||
if policy, err = s.ReadSettings(); err != nil { |
||||
return nil, err |
||||
} |
||||
} else { |
||||
policy = s.GetSettings() |
||||
} |
||||
effective = setting.MergeSnapshots(effective, policy) |
||||
} |
||||
return effective, nil |
||||
} |
||||
|
||||
// addSource adds the specified source to the list of sources used by p,
|
||||
// and triggers a synchronous policy refresh. It returns an error
|
||||
// if the source is not a valid source for this effective policy,
|
||||
// or if the effective policy is being closed,
|
||||
// or if policy refresh fails with an error.
|
||||
func (p *Policy) addSource(source *source.Source) error { |
||||
return p.applySourcesChange(source, nil) |
||||
} |
||||
|
||||
// removeSource removes the specified source from the list of sources used by p,
|
||||
// and triggers a synchronous policy refresh. It returns an error if the
|
||||
// effective policy is being closed, or if policy refresh fails with an error.
|
||||
func (p *Policy) removeSource(source *source.Source) error { |
||||
return p.applySourcesChange(nil, source) |
||||
} |
||||
|
||||
// replaceSource replaces the old source with the new source atomically,
|
||||
// and triggers a synchronous policy refresh. It returns an error
|
||||
// if the source is not a valid source for this effective policy,
|
||||
// or if the effective policy is being closed,
|
||||
// or if policy refresh fails with an error.
|
||||
func (p *Policy) replaceSource(old, new *source.Source) error { |
||||
return p.applySourcesChange(new, old) |
||||
} |
||||
|
||||
func (p *Policy) applySourcesChange(toAdd, toRemove *source.Source) error { |
||||
if toAdd == toRemove { |
||||
return nil |
||||
} |
||||
if toAdd != nil && !toAdd.Scope().Contains(p.scope) { |
||||
return errors.New("scope mismatch") |
||||
} |
||||
|
||||
changed, err := func() (changed bool, err error) { |
||||
p.mu.Lock() |
||||
defer p.mu.Unlock() |
||||
if toAdd != nil && !p.sources.Contains(toAdd) { |
||||
reader, err := toAdd.Reader() |
||||
if err != nil { |
||||
return false, fmt.Errorf("failed to get a store reader: %w", err) |
||||
} |
||||
session, err := reader.OpenSession() |
||||
if err != nil { |
||||
return false, fmt.Errorf("failed to open a reading session: %w", err) |
||||
} |
||||
|
||||
addAt := p.sources.InsertionIndexOf(toAdd) |
||||
toAdd := source.ReadableSource{ |
||||
Source: toAdd, |
||||
ReadingSession: session, |
||||
} |
||||
p.sources = slices.Insert(p.sources, addAt, toAdd) |
||||
go p.watchPolicyChanges(toAdd) |
||||
changed = true |
||||
} |
||||
if toRemove != nil { |
||||
if deleteAt := p.sources.IndexOf(toRemove); deleteAt != -1 { |
||||
p.sources.DeleteAt(deleteAt) |
||||
changed = true |
||||
} |
||||
} |
||||
return changed, nil |
||||
}() |
||||
if changed { |
||||
_, err = p.reload(false) |
||||
} |
||||
return err // may be nil or non-nil
|
||||
} |
||||
|
||||
func (p *Policy) watchPolicyChanges(s source.ReadableSource) { |
||||
for { |
||||
select { |
||||
case _, ok := <-s.ReadingSession.PolicyChanged(): |
||||
if !ok { |
||||
p.mu.Lock() |
||||
abruptlyClosed := slices.Contains(p.sources, s) |
||||
p.mu.Unlock() |
||||
if abruptlyClosed { |
||||
// The underlying [source.Source] was closed abruptly without
|
||||
// being properly removed or replaced by another policy source.
|
||||
// We can't keep this [Policy] up to date, so we should close it.
|
||||
p.Close() |
||||
} |
||||
return |
||||
} |
||||
// The PolicyChanged channel was signaled.
|
||||
// Request an asynchronous policy reload.
|
||||
p.reloadAsync() |
||||
case <-p.closeCh: |
||||
// The [Policy] is being closed.
|
||||
return |
||||
} |
||||
} |
||||
} |
||||
|
||||
// startWatchReloadIfNeeded starts [Policy.watchReload] in a new goroutine
|
||||
// if the list of policy sources is not empty, it hasn't been started yet,
|
||||
// and the [Policy] is not being closed.
|
||||
// It reports whether [Policy.watchReload] has ever been started.
|
||||
//
|
||||
// It must not be called with p.mu held.
|
||||
func (p *Policy) startWatchReloadIfNeeded() bool { |
||||
p.mu.Lock() |
||||
defer p.mu.Unlock() |
||||
if len(p.sources) != 0 && !p.watcherStarted && !p.closing { |
||||
go p.watchReload() |
||||
for i := range p.sources { |
||||
go p.watchPolicyChanges(p.sources[i]) |
||||
} |
||||
p.watcherStarted = true |
||||
} |
||||
return p.watcherStarted |
||||
} |
||||
|
||||
// reloadRequest describes a policy reload request.
|
||||
type reloadRequest struct { |
||||
// force policy reload regardless of whether a policy change was detected.
|
||||
force bool |
||||
// respCh is an optional channel. If non-nil, it makes the reload request
|
||||
// synchronous and receives the result.
|
||||
respCh chan<- reloadResponse |
||||
} |
||||
|
||||
// reloadResponse is a result of a synchronous policy reload.
|
||||
type reloadResponse struct { |
||||
policy *setting.Snapshot |
||||
err error |
||||
} |
||||
|
||||
// watchReload processes incoming synchronous and asynchronous policy reload requests.
|
||||
//
|
||||
// Synchronous requests (with a non-nil respCh) are served immediately.
|
||||
//
|
||||
// Asynchronous requests are debounced and throttled: they are executed at least
|
||||
// [policyReloadMinDelay] after the last request, but no later than [policyReloadMaxDelay]
|
||||
// after the first request in a batch.
|
||||
func (p *Policy) watchReload() { |
||||
defer p.closeInternal() |
||||
|
||||
force := false // whether a forced refresh was requested
|
||||
var delayCh, timeoutCh <-chan time.Time |
||||
reload := func(respCh chan<- reloadResponse) { |
||||
delayCh, timeoutCh = nil, nil |
||||
policy, err := p.reloadNow(force) |
||||
if err != nil { |
||||
loggerx.Errorf("%v policy reload failed: %v\n", p.scope, err) |
||||
} |
||||
if respCh != nil { |
||||
respCh <- reloadResponse{policy: policy, err: err} |
||||
} |
||||
force = false |
||||
} |
||||
|
||||
loop: |
||||
for { |
||||
select { |
||||
case req := <-p.reloadCh: |
||||
if req.force { |
||||
force = true |
||||
} |
||||
if req.respCh != nil { |
||||
reload(req.respCh) |
||||
continue |
||||
} |
||||
if delayCh == nil { |
||||
timeoutCh = time.After(policyReloadMinDelay) |
||||
} |
||||
delayCh = time.After(policyReloadMaxDelay) |
||||
case <-delayCh: |
||||
reload(nil) |
||||
case <-timeoutCh: |
||||
reload(nil) |
||||
case <-p.closeCh: |
||||
break loop |
||||
} |
||||
} |
||||
} |
||||
|
||||
func (p *Policy) closeInternal() { |
||||
p.mu.Lock() |
||||
defer p.mu.Unlock() |
||||
p.sources.Close() |
||||
p.changeCallbacks.Close() |
||||
close(p.doneCh) |
||||
deletePolicy(p) |
||||
} |
||||
|
||||
// Close initiates the closing of the policy.
|
||||
// The [Policy.Done] channel is closed to signal that the operation has been completed.
|
||||
func (p *Policy) Close() { |
||||
p.mu.Lock() |
||||
alreadyClosing := p.closing |
||||
watcherStarted := p.watcherStarted |
||||
p.closing = true |
||||
p.mu.Unlock() |
||||
|
||||
if alreadyClosing { |
||||
return |
||||
} |
||||
|
||||
close(p.closeCh) |
||||
if !watcherStarted { |
||||
// Normally, closing p.closeCh signals [Policy.watchReload] to exit,
|
||||
// and [Policy.closeInternal] performs the actual closing when
|
||||
// [Policy.watchReload] returns. However, if the watcher was never
|
||||
// started, we need to call [Policy.closeInternal] manually.
|
||||
go p.closeInternal() |
||||
} |
||||
} |
||||
@ -0,0 +1,986 @@ |
||||
// Copyright (c) Tailscale Inc & AUTHORS
|
||||
// SPDX-License-Identifier: BSD-3-Clause
|
||||
|
||||
package rsop |
||||
|
||||
import ( |
||||
"errors" |
||||
"slices" |
||||
"sort" |
||||
"strconv" |
||||
"sync" |
||||
"testing" |
||||
"time" |
||||
|
||||
"github.com/google/go-cmp/cmp" |
||||
"github.com/google/go-cmp/cmp/cmpopts" |
||||
"tailscale.com/tstest" |
||||
"tailscale.com/util/syspolicy/setting" |
||||
|
||||
"tailscale.com/util/syspolicy/source" |
||||
) |
||||
|
||||
func TestGetEffectivePolicyNoSource(t *testing.T) { |
||||
tests := []struct { |
||||
name string |
||||
scope setting.PolicyScope |
||||
}{ |
||||
{ |
||||
name: "DevicePolicy", |
||||
scope: setting.DeviceScope, |
||||
}, |
||||
{ |
||||
name: "CurrentProfilePolicy", |
||||
scope: setting.CurrentProfileScope, |
||||
}, |
||||
{ |
||||
name: "CurrentUserPolicy", |
||||
scope: setting.CurrentUserScope, |
||||
}, |
||||
} |
||||
|
||||
for _, tt := range tests { |
||||
t.Run(tt.name, func(t *testing.T) { |
||||
var policy *Policy |
||||
t.Cleanup(func() { |
||||
if policy != nil { |
||||
policy.Close() |
||||
<-policy.Done() |
||||
} |
||||
}) |
||||
|
||||
// Make sure we don't create any goroutines.
|
||||
// We intentionally call ResourceCheck after t.Cleanup, so that when the test exits,
|
||||
// the resource check runs before the test cleanup closes the policy.
|
||||
// This helps to report any unexpectedly created goroutines.
|
||||
// The goal is to ensure that using the syspolicy package, and particularly
|
||||
// the rsop sub-package, is not wasteful and does not create unnecessary goroutines
|
||||
// on platforms without registered policy sources.
|
||||
tstest.ResourceCheck(t) |
||||
|
||||
policy, err := PolicyFor(tt.scope) |
||||
if err != nil { |
||||
t.Fatalf("Failed to get effective policy for %v: %v", tt.scope, err) |
||||
} |
||||
|
||||
if got := policy.Get(); got.Len() != 0 { |
||||
t.Errorf("Snapshot: got %v; want empty", got) |
||||
} |
||||
|
||||
if got, err := policy.Reload(); err != nil { |
||||
t.Errorf("Reload failed: %v", err) |
||||
} else if got.Len() != 0 { |
||||
t.Errorf("Snapshot: got %v; want empty", got) |
||||
} |
||||
}) |
||||
} |
||||
} |
||||
|
||||
func TestRegisterSourceAndGetEffectivePolicy(t *testing.T) { |
||||
type sourceConfig struct { |
||||
name string |
||||
scope setting.PolicyScope |
||||
settingKey setting.Key |
||||
settingValue string |
||||
wantEffective bool |
||||
} |
||||
tests := []struct { |
||||
name string |
||||
scope setting.PolicyScope |
||||
initialSources []sourceConfig |
||||
additionalSources []sourceConfig |
||||
wantSnapshot *setting.Snapshot |
||||
}{ |
||||
{ |
||||
name: "DevicePolicy/NoSources", |
||||
scope: setting.DeviceScope, |
||||
wantSnapshot: setting.NewSnapshot(nil, setting.DeviceScope), |
||||
}, |
||||
{ |
||||
name: "UserScope/NoSources", |
||||
scope: setting.CurrentUserScope, |
||||
wantSnapshot: setting.NewSnapshot(nil, setting.CurrentUserScope), |
||||
}, |
||||
{ |
||||
name: "DevicePolicy/OneInitialSource", |
||||
scope: setting.DeviceScope, |
||||
initialSources: []sourceConfig{ |
||||
{ |
||||
name: "TestSourceA", |
||||
scope: setting.DeviceScope, |
||||
settingKey: "TestKeyA", |
||||
settingValue: "TestValueA", |
||||
wantEffective: true, |
||||
}, |
||||
}, |
||||
wantSnapshot: setting.NewSnapshot(map[setting.Key]setting.RawItem{ |
||||
"TestKeyA": setting.RawItemWith("TestValueA", nil, setting.NewNamedOrigin("TestSourceA", setting.DeviceScope)), |
||||
}, setting.NewNamedOrigin("TestSourceA", setting.DeviceScope)), |
||||
}, |
||||
{ |
||||
name: "DevicePolicy/OneAdditionalSource", |
||||
scope: setting.DeviceScope, |
||||
additionalSources: []sourceConfig{ |
||||
{ |
||||
name: "TestSourceA", |
||||
scope: setting.DeviceScope, |
||||
settingKey: "TestKeyA", |
||||
settingValue: "TestValueA", |
||||
wantEffective: true, |
||||
}, |
||||
}, |
||||
wantSnapshot: setting.NewSnapshot(map[setting.Key]setting.RawItem{ |
||||
"TestKeyA": setting.RawItemWith("TestValueA", nil, setting.NewNamedOrigin("TestSourceA", setting.DeviceScope)), |
||||
}, setting.NewNamedOrigin("TestSourceA", setting.DeviceScope)), |
||||
}, |
||||
{ |
||||
name: "DevicePolicy/ManyInitialSources/NoConflicts", |
||||
scope: setting.DeviceScope, |
||||
initialSources: []sourceConfig{ |
||||
{ |
||||
name: "TestSourceA", |
||||
scope: setting.DeviceScope, |
||||
settingKey: "TestKeyA", |
||||
settingValue: "TestValueA", |
||||
wantEffective: true, |
||||
}, |
||||
{ |
||||
name: "TestSourceB", |
||||
scope: setting.DeviceScope, |
||||
settingKey: "TestKeyB", |
||||
settingValue: "TestValueB", |
||||
wantEffective: true, |
||||
}, |
||||
{ |
||||
name: "TestSourceC", |
||||
scope: setting.DeviceScope, |
||||
settingKey: "TestKeyC", |
||||
settingValue: "TestValueC", |
||||
wantEffective: true, |
||||
}, |
||||
}, |
||||
wantSnapshot: setting.NewSnapshot(map[setting.Key]setting.RawItem{ |
||||
"TestKeyA": setting.RawItemWith("TestValueA", nil, setting.NewNamedOrigin("TestSourceA", setting.DeviceScope)), |
||||
"TestKeyB": setting.RawItemWith("TestValueB", nil, setting.NewNamedOrigin("TestSourceB", setting.DeviceScope)), |
||||
"TestKeyC": setting.RawItemWith("TestValueC", nil, setting.NewNamedOrigin("TestSourceC", setting.DeviceScope)), |
||||
}, setting.DeviceScope), |
||||
}, |
||||
{ |
||||
name: "DevicePolicy/ManyInitialSources/Conflicts", |
||||
scope: setting.DeviceScope, |
||||
initialSources: []sourceConfig{ |
||||
{ |
||||
name: "TestSourceA", |
||||
scope: setting.DeviceScope, |
||||
settingKey: "TestKeyA", |
||||
settingValue: "TestValueA", |
||||
wantEffective: true, |
||||
}, |
||||
{ |
||||
name: "TestSourceB", |
||||
scope: setting.DeviceScope, |
||||
settingKey: "TestKeyB", |
||||
settingValue: "TestValueB", |
||||
wantEffective: true, |
||||
}, |
||||
{ |
||||
name: "TestSourceC", |
||||
scope: setting.DeviceScope, |
||||
settingKey: "TestKeyA", |
||||
settingValue: "TestValueC", |
||||
wantEffective: true, |
||||
}, |
||||
}, |
||||
wantSnapshot: setting.NewSnapshot(map[setting.Key]setting.RawItem{ |
||||
"TestKeyA": setting.RawItemWith("TestValueC", nil, setting.NewNamedOrigin("TestSourceC", setting.DeviceScope)), |
||||
"TestKeyB": setting.RawItemWith("TestValueB", nil, setting.NewNamedOrigin("TestSourceB", setting.DeviceScope)), |
||||
}, setting.DeviceScope), |
||||
}, |
||||
{ |
||||
name: "DevicePolicy/MixedSources/Conflicts", |
||||
scope: setting.DeviceScope, |
||||
initialSources: []sourceConfig{ |
||||
{ |
||||
name: "TestSourceA", |
||||
scope: setting.DeviceScope, |
||||
settingKey: "TestKeyA", |
||||
settingValue: "TestValueA", |
||||
wantEffective: true, |
||||
}, |
||||
{ |
||||
name: "TestSourceB", |
||||
scope: setting.DeviceScope, |
||||
settingKey: "TestKeyB", |
||||
settingValue: "TestValueB", |
||||
wantEffective: true, |
||||
}, |
||||
{ |
||||
name: "TestSourceC", |
||||
scope: setting.DeviceScope, |
||||
settingKey: "TestKeyA", |
||||
settingValue: "TestValueC", |
||||
wantEffective: true, |
||||
}, |
||||
}, |
||||
additionalSources: []sourceConfig{ |
||||
{ |
||||
name: "TestSourceD", |
||||
scope: setting.DeviceScope, |
||||
settingKey: "TestKeyA", |
||||
settingValue: "TestValueD", |
||||
wantEffective: true, |
||||
}, |
||||
{ |
||||
name: "TestSourceE", |
||||
scope: setting.DeviceScope, |
||||
settingKey: "TestKeyC", |
||||
settingValue: "TestValueE", |
||||
wantEffective: true, |
||||
}, |
||||
{ |
||||
name: "TestSourceF", |
||||
scope: setting.DeviceScope, |
||||
settingKey: "TestKeyA", |
||||
settingValue: "TestValueF", |
||||
wantEffective: true, |
||||
}, |
||||
}, |
||||
wantSnapshot: setting.NewSnapshot(map[setting.Key]setting.RawItem{ |
||||
"TestKeyA": setting.RawItemWith("TestValueF", nil, setting.NewNamedOrigin("TestSourceF", setting.DeviceScope)), |
||||
"TestKeyB": setting.RawItemWith("TestValueB", nil, setting.NewNamedOrigin("TestSourceB", setting.DeviceScope)), |
||||
"TestKeyC": setting.RawItemWith("TestValueE", nil, setting.NewNamedOrigin("TestSourceE", setting.DeviceScope)), |
||||
}, setting.DeviceScope), |
||||
}, |
||||
{ |
||||
name: "UserScope/Init-DeviceSource", |
||||
scope: setting.CurrentUserScope, |
||||
initialSources: []sourceConfig{ |
||||
{ |
||||
name: "TestSourceDevice", |
||||
scope: setting.DeviceScope, |
||||
settingKey: "TestKeyA", |
||||
settingValue: "DeviceValue", |
||||
wantEffective: true, |
||||
}, |
||||
}, |
||||
wantSnapshot: setting.NewSnapshot(map[setting.Key]setting.RawItem{ |
||||
"TestKeyA": setting.RawItemWith("DeviceValue", nil, setting.NewNamedOrigin("TestSourceDevice", setting.DeviceScope)), |
||||
}, setting.CurrentUserScope, setting.NewNamedOrigin("TestSourceDevice", setting.DeviceScope)), |
||||
}, |
||||
{ |
||||
name: "UserScope/Init-DeviceSource/Add-UserSource", |
||||
scope: setting.CurrentUserScope, |
||||
initialSources: []sourceConfig{ |
||||
{ |
||||
name: "TestSourceDevice", |
||||
scope: setting.DeviceScope, |
||||
settingKey: "TestKeyA", |
||||
settingValue: "DeviceValue", |
||||
wantEffective: true, |
||||
}, |
||||
}, |
||||
additionalSources: []sourceConfig{ |
||||
{ |
||||
name: "TestSourceUser", |
||||
scope: setting.CurrentUserScope, |
||||
settingKey: "TestKeyB", |
||||
settingValue: "UserValue", |
||||
wantEffective: true, |
||||
}, |
||||
}, |
||||
wantSnapshot: setting.NewSnapshot(map[setting.Key]setting.RawItem{ |
||||
"TestKeyA": setting.RawItemWith("DeviceValue", nil, setting.NewNamedOrigin("TestSourceDevice", setting.DeviceScope)), |
||||
"TestKeyB": setting.RawItemWith("UserValue", nil, setting.NewNamedOrigin("TestSourceUser", setting.CurrentUserScope)), |
||||
}, setting.CurrentUserScope), |
||||
}, |
||||
{ |
||||
name: "UserScope/Init-DeviceSource/Add-UserSource-and-ProfileSource", |
||||
scope: setting.CurrentUserScope, |
||||
initialSources: []sourceConfig{ |
||||
{ |
||||
name: "TestSourceDevice", |
||||
scope: setting.DeviceScope, |
||||
settingKey: "TestKeyA", |
||||
settingValue: "DeviceValue", |
||||
wantEffective: true, |
||||
}, |
||||
}, |
||||
additionalSources: []sourceConfig{ |
||||
{ |
||||
name: "TestSourceProfile", |
||||
scope: setting.CurrentProfileScope, |
||||
settingKey: "TestKeyB", |
||||
settingValue: "ProfileValue", |
||||
wantEffective: true, |
||||
}, |
||||
{ |
||||
name: "TestSourceUser", |
||||
scope: setting.CurrentUserScope, |
||||
settingKey: "TestKeyB", |
||||
settingValue: "UserValue", |
||||
wantEffective: true, |
||||
}, |
||||
}, |
||||
wantSnapshot: setting.NewSnapshot(map[setting.Key]setting.RawItem{ |
||||
"TestKeyA": setting.RawItemWith("DeviceValue", nil, setting.NewNamedOrigin("TestSourceDevice", setting.DeviceScope)), |
||||
"TestKeyB": setting.RawItemWith("ProfileValue", nil, setting.NewNamedOrigin("TestSourceProfile", setting.CurrentProfileScope)), |
||||
}, setting.CurrentUserScope), |
||||
}, |
||||
{ |
||||
name: "DevicePolicy/User-Source-does-not-apply", |
||||
scope: setting.DeviceScope, |
||||
initialSources: []sourceConfig{ |
||||
{ |
||||
name: "TestSourceDevice", |
||||
scope: setting.DeviceScope, |
||||
settingKey: "TestKeyA", |
||||
settingValue: "DeviceValue", |
||||
wantEffective: true, |
||||
}, |
||||
}, |
||||
additionalSources: []sourceConfig{ |
||||
{ |
||||
name: "TestSourceUser", |
||||
scope: setting.CurrentUserScope, |
||||
settingKey: "TestKeyA", |
||||
settingValue: "UserValue", |
||||
wantEffective: false, // Registering a user source should have no impact on the device policy.
|
||||
}, |
||||
}, |
||||
wantSnapshot: setting.NewSnapshot(map[setting.Key]setting.RawItem{ |
||||
"TestKeyA": setting.RawItemWith("DeviceValue", nil, setting.NewNamedOrigin("TestSourceDevice", setting.DeviceScope)), |
||||
}, setting.NewNamedOrigin("TestSourceDevice", setting.DeviceScope)), |
||||
}, |
||||
} |
||||
for _, tt := range tests { |
||||
t.Run(tt.name, func(t *testing.T) { |
||||
// Register all settings that we use in this test.
|
||||
var definitions []*setting.Definition |
||||
for _, source := range slices.Concat(tt.initialSources, tt.additionalSources) { |
||||
definitions = append(definitions, setting.NewDefinition(source.settingKey, tt.scope.Kind(), setting.StringValue)) |
||||
} |
||||
if err := setting.SetDefinitionsForTest(t, definitions...); err != nil { |
||||
t.Fatalf("SetDefinitionsForTest failed: %v", err) |
||||
} |
||||
|
||||
// Add the initial policy sources.
|
||||
var wantSources []*source.Source |
||||
for _, s := range tt.initialSources { |
||||
store := source.NewTestStoreOf(t, source.TestSettingOf(s.settingKey, s.settingValue)) |
||||
source := source.NewSource(s.name, s.scope, store) |
||||
if err := registerSource(source); err != nil { |
||||
t.Fatalf("Failed to register policy source: %v", source) |
||||
} |
||||
if s.wantEffective { |
||||
wantSources = append(wantSources, source) |
||||
} |
||||
t.Cleanup(func() { unregisterSource(source) }) |
||||
} |
||||
|
||||
// Retrieve the effective policy.
|
||||
policy, err := policyForTest(t, tt.scope) |
||||
if err != nil { |
||||
t.Fatalf("Failed to get effective policy for %v: %v", tt.scope, err) |
||||
} |
||||
|
||||
checkPolicySources(t, policy, wantSources) |
||||
|
||||
// Add additional setting sources.
|
||||
for _, s := range tt.additionalSources { |
||||
store := source.NewTestStoreOf(t, source.TestSettingOf(s.settingKey, s.settingValue)) |
||||
source := source.NewSource(s.name, s.scope, store) |
||||
if err := registerSource(source); err != nil { |
||||
t.Fatalf("Failed to register additional policy source: %v", source) |
||||
} |
||||
if s.wantEffective { |
||||
wantSources = append(wantSources, source) |
||||
} |
||||
t.Cleanup(func() { unregisterSource(source) }) |
||||
} |
||||
|
||||
checkPolicySources(t, policy, wantSources) |
||||
|
||||
// Verify the final effective settings snapshots.
|
||||
if got := policy.Get(); !got.Equal(tt.wantSnapshot) { |
||||
t.Errorf("Snapshot: got %v; want %v", got, tt.wantSnapshot) |
||||
} |
||||
}) |
||||
} |
||||
} |
||||
|
||||
func TestPolicyFor(t *testing.T) { |
||||
tests := []struct { |
||||
name string |
||||
scopeA, scopeB setting.PolicyScope |
||||
closePolicy bool // indicates whether to close policyA before retrieving policyB
|
||||
wantSame bool // specifies whether policyA and policyB should reference the same [Policy] instance
|
||||
}{ |
||||
{ |
||||
name: "Device/Device", |
||||
scopeA: setting.DeviceScope, |
||||
scopeB: setting.DeviceScope, |
||||
wantSame: true, |
||||
}, |
||||
{ |
||||
name: "Device/CurrentProfile", |
||||
scopeA: setting.DeviceScope, |
||||
scopeB: setting.CurrentProfileScope, |
||||
wantSame: false, |
||||
}, |
||||
{ |
||||
name: "Device/CurrentUser", |
||||
scopeA: setting.DeviceScope, |
||||
scopeB: setting.CurrentUserScope, |
||||
wantSame: false, |
||||
}, |
||||
{ |
||||
name: "CurrentProfile/CurrentProfile", |
||||
scopeA: setting.CurrentProfileScope, |
||||
scopeB: setting.CurrentProfileScope, |
||||
wantSame: true, |
||||
}, |
||||
{ |
||||
name: "CurrentProfile/CurrentUser", |
||||
scopeA: setting.CurrentProfileScope, |
||||
scopeB: setting.CurrentUserScope, |
||||
wantSame: false, |
||||
}, |
||||
{ |
||||
name: "CurrentUser/CurrentUser", |
||||
scopeA: setting.CurrentUserScope, |
||||
scopeB: setting.CurrentUserScope, |
||||
wantSame: true, |
||||
}, |
||||
{ |
||||
name: "UserA/UserA", |
||||
scopeA: setting.UserScopeOf("UserA"), |
||||
scopeB: setting.UserScopeOf("UserA"), |
||||
wantSame: true, |
||||
}, |
||||
{ |
||||
name: "UserA/UserB", |
||||
scopeA: setting.UserScopeOf("UserA"), |
||||
scopeB: setting.UserScopeOf("UserB"), |
||||
wantSame: false, |
||||
}, |
||||
{ |
||||
name: "New-after-close", |
||||
scopeA: setting.DeviceScope, |
||||
scopeB: setting.DeviceScope, |
||||
closePolicy: true, |
||||
wantSame: false, |
||||
}, |
||||
} |
||||
for _, tt := range tests { |
||||
t.Run(tt.name, func(t *testing.T) { |
||||
policyA, err := policyForTest(t, tt.scopeA) |
||||
if err != nil { |
||||
t.Fatalf("Failed to get effective policy for %v: %v", tt.scopeA, err) |
||||
} |
||||
|
||||
if tt.closePolicy { |
||||
policyA.Close() |
||||
} |
||||
|
||||
policyB, err := policyForTest(t, tt.scopeB) |
||||
if err != nil { |
||||
t.Fatalf("Failed to get effective policy for %v: %v", tt.scopeB, err) |
||||
} |
||||
|
||||
if gotSame := policyA == policyB; gotSame != tt.wantSame { |
||||
t.Fatalf("Got same: %v; want same %v", gotSame, tt.wantSame) |
||||
} |
||||
}) |
||||
} |
||||
} |
||||
|
||||
func TestPolicyChangeHasChanged(t *testing.T) { |
||||
tests := []struct { |
||||
name string |
||||
old, new map[setting.Key]setting.RawItem |
||||
wantChanged []setting.Key |
||||
wantUnchanged []setting.Key |
||||
}{ |
||||
{ |
||||
name: "String-Settings", |
||||
old: map[setting.Key]setting.RawItem{ |
||||
"ChangedSetting": setting.RawItemOf("Old"), |
||||
"UnchangedSetting": setting.RawItemOf("Value"), |
||||
}, |
||||
new: map[setting.Key]setting.RawItem{ |
||||
"ChangedSetting": setting.RawItemOf("New"), |
||||
"UnchangedSetting": setting.RawItemOf("Value"), |
||||
}, |
||||
wantChanged: []setting.Key{"ChangedSetting"}, |
||||
wantUnchanged: []setting.Key{"UnchangedSetting"}, |
||||
}, |
||||
{ |
||||
name: "UInt64-Settings", |
||||
old: map[setting.Key]setting.RawItem{ |
||||
"ChangedSetting": setting.RawItemOf(uint64(0)), |
||||
"UnchangedSetting": setting.RawItemOf(uint64(42)), |
||||
}, |
||||
new: map[setting.Key]setting.RawItem{ |
||||
"ChangedSetting": setting.RawItemOf(uint64(1)), |
||||
"UnchangedSetting": setting.RawItemOf(uint64(42)), |
||||
}, |
||||
wantChanged: []setting.Key{"ChangedSetting"}, |
||||
wantUnchanged: []setting.Key{"UnchangedSetting"}, |
||||
}, |
||||
{ |
||||
name: "StringSlice-Settings", |
||||
old: map[setting.Key]setting.RawItem{ |
||||
"ChangedSetting": setting.RawItemOf([]string{"Chicago"}), |
||||
"UnchangedSetting": setting.RawItemOf([]string{"String1", "String2"}), |
||||
}, |
||||
new: map[setting.Key]setting.RawItem{ |
||||
"ChangedSetting": setting.RawItemOf([]string{"New York"}), |
||||
"UnchangedSetting": setting.RawItemOf([]string{"String1", "String2"}), |
||||
}, |
||||
wantChanged: []setting.Key{"ChangedSetting"}, |
||||
wantUnchanged: []setting.Key{"UnchangedSetting"}, |
||||
}, |
||||
{ |
||||
name: "Int8-Settings", // We don't have actual int8 settings, but this should still work.
|
||||
old: map[setting.Key]setting.RawItem{ |
||||
"ChangedSetting": setting.RawItemOf(int8(0)), |
||||
"UnchangedSetting": setting.RawItemOf(int8(42)), |
||||
}, |
||||
new: map[setting.Key]setting.RawItem{ |
||||
"ChangedSetting": setting.RawItemOf(int8(1)), |
||||
"UnchangedSetting": setting.RawItemOf(int8(42)), |
||||
}, |
||||
wantChanged: []setting.Key{"ChangedSetting"}, |
||||
wantUnchanged: []setting.Key{"UnchangedSetting"}, |
||||
}, |
||||
} |
||||
for _, tt := range tests { |
||||
t.Run(tt.name, func(t *testing.T) { |
||||
old := setting.NewSnapshot(tt.old) |
||||
new := setting.NewSnapshot(tt.new) |
||||
change := PolicyChange{Change[*setting.Snapshot]{old, new}} |
||||
for _, wantChanged := range tt.wantChanged { |
||||
if !change.HasChanged(wantChanged) { |
||||
t.Errorf("%q changed: got false; want true", wantChanged) |
||||
} |
||||
} |
||||
for _, wantUnchanged := range tt.wantUnchanged { |
||||
if change.HasChanged(wantUnchanged) { |
||||
t.Errorf("%q unchanged: got true; want false", wantUnchanged) |
||||
} |
||||
} |
||||
}) |
||||
} |
||||
} |
||||
|
||||
func TestChangePolicySetting(t *testing.T) { |
||||
setForTest(t, &policyReloadMinDelay, 100*time.Millisecond) |
||||
setForTest(t, &policyReloadMaxDelay, 500*time.Millisecond) |
||||
|
||||
// Register policy settings used in this test.
|
||||
settingA := setting.NewDefinition("TestSettingA", setting.DeviceSetting, setting.StringValue) |
||||
settingB := setting.NewDefinition("TestSettingB", setting.DeviceSetting, setting.StringValue) |
||||
if err := setting.SetDefinitionsForTest(t, settingA, settingB); err != nil { |
||||
t.Fatalf("SetDefinitionsForTest failed: %v", err) |
||||
} |
||||
|
||||
// Register a test policy store and create a effective policy that reads the policy settings from it.
|
||||
store := source.NewTestStoreOf[string](t) |
||||
if _, err := RegisterStoreForTest(t, "TestSource", setting.DeviceScope, store); err != nil { |
||||
t.Fatalf("Failed to register policy store: %v", err) |
||||
} |
||||
policy, err := policyForTest(t, setting.DeviceScope) |
||||
if err != nil { |
||||
t.Fatalf("Failed to get effective policy: %v", err) |
||||
} |
||||
|
||||
// The policy setting is not configured yet.
|
||||
if _, ok := policy.Get().GetSetting(settingA.Key()); ok { |
||||
t.Fatalf("Policy setting %q unexpectedly exists", settingA.Key()) |
||||
} |
||||
|
||||
// Subscribe to the policy change callback...
|
||||
policyChanged := make(chan *PolicyChange) |
||||
unregister := policy.RegisterChangeCallback(func(pc *PolicyChange) { policyChanged <- pc }) |
||||
t.Cleanup(unregister) |
||||
|
||||
// ...make the change, and measure the time between initiating the change
|
||||
// and receiving the callback.
|
||||
start := time.Now() |
||||
const wantValueA = "TestValueA" |
||||
store.SetStrings(source.TestSettingOf(settingA.Key(), wantValueA)) |
||||
change := <-policyChanged |
||||
gotDelay := time.Since(start) |
||||
|
||||
// Ensure there is at least a [policyReloadMinDelay] delay between
|
||||
// a change and the policy reload along with the callback invocation.
|
||||
// This prevents reloading policy settings too frequently
|
||||
// when multiple settings change within a short period of time.
|
||||
if gotDelay < policyReloadMinDelay { |
||||
t.Errorf("Delay: got %v; want >= %v", gotDelay, policyReloadMinDelay) |
||||
} |
||||
|
||||
// Verify that the [PolicyChange] passed to the policy change callback
|
||||
// contains the correct information regarding the policy setting changes.
|
||||
if !change.HasChanged(settingA.Key()) { |
||||
t.Errorf("Policy setting %q has not changed", settingA.Key()) |
||||
} |
||||
if change.HasChanged(settingB.Key()) { |
||||
t.Errorf("Policy setting %q was unexpectedly changed", settingB.Key()) |
||||
} |
||||
if _, ok := change.Old().GetSetting(settingA.Key()); ok { |
||||
t.Fatalf("Policy setting %q unexpectedly exists", settingA.Key()) |
||||
} |
||||
if gotValue := change.New().Get(settingA.Key()); gotValue != wantValueA { |
||||
t.Errorf("Policy setting %q: got %q; want %q", settingA.Key(), gotValue, wantValueA) |
||||
} |
||||
|
||||
// And also verify that the current (most recent) [setting.Snapshot]
|
||||
// includes the change we just made.
|
||||
if gotValue := policy.Get().Get(settingA.Key()); gotValue != wantValueA { |
||||
t.Errorf("Policy setting %q: got %q; want %q", settingA.Key(), gotValue, wantValueA) |
||||
} |
||||
|
||||
// Now, let's change another policy setting value N times.
|
||||
const N = 10 |
||||
wantValueB := strconv.Itoa(N) |
||||
start = time.Now() |
||||
for i := range N { |
||||
store.SetStrings(source.TestSettingOf(settingB.Key(), strconv.Itoa(i+1))) |
||||
} |
||||
|
||||
// The callback should be invoked only once, even though the policy setting
|
||||
// has changed N times.
|
||||
change = <-policyChanged |
||||
gotDelay = time.Since(start) |
||||
gotCallbacks := 1 |
||||
drain: |
||||
for { |
||||
select { |
||||
case <-policyChanged: |
||||
gotCallbacks++ |
||||
case <-time.After(policyReloadMaxDelay): |
||||
break drain |
||||
} |
||||
} |
||||
if wantCallbacks := 1; gotCallbacks > wantCallbacks { |
||||
t.Errorf("Callbacks: got %d; want %d", gotCallbacks, wantCallbacks) |
||||
} |
||||
|
||||
// Additionally, the policy change callback should be received no sooner
|
||||
// than [policyReloadMinDelay] and no later than [policyReloadMaxDelay].
|
||||
if gotDelay < policyReloadMinDelay || gotDelay > policyReloadMaxDelay { |
||||
t.Errorf("Delay: got %v; want >= %v && <= %v", gotDelay, policyReloadMinDelay, policyReloadMaxDelay) |
||||
} |
||||
|
||||
// Verify that the [PolicyChange] received via the callback
|
||||
// contains the final policy setting value.
|
||||
if !change.HasChanged(settingB.Key()) { |
||||
t.Errorf("Policy setting %q has not changed", settingB.Key()) |
||||
} |
||||
if change.HasChanged(settingA.Key()) { |
||||
t.Errorf("Policy setting %q was unexpectedly changed", settingA.Key()) |
||||
} |
||||
if _, ok := change.Old().GetSetting(settingB.Key()); ok { |
||||
t.Fatalf("Policy setting %q unexpectedly exists", settingB.Key()) |
||||
} |
||||
if gotValue := change.New().Get(settingB.Key()); gotValue != wantValueB { |
||||
t.Errorf("Policy setting %q: got %q; want %q", settingB.Key(), gotValue, wantValueB) |
||||
} |
||||
|
||||
// Lastly, if a policy store issues a change notification, but the effective policy
|
||||
// remains unchanged, the [Policy] should ignore it without invoking the change callbacks.
|
||||
store.NotifyPolicyChanged() |
||||
select { |
||||
case <-policyChanged: |
||||
t.Fatal("Unexpected policy changed notification") |
||||
case <-time.After(policyReloadMaxDelay): |
||||
} |
||||
} |
||||
|
||||
func TestClosePolicySource(t *testing.T) { |
||||
testSetting := setting.NewDefinition("TestSetting", setting.DeviceSetting, setting.StringValue) |
||||
if err := setting.SetDefinitionsForTest(t, testSetting); err != nil { |
||||
t.Fatalf("SetDefinitionsForTest failed: %v", err) |
||||
} |
||||
|
||||
wantSettingValue := "TestValue" |
||||
store := source.NewTestStoreOf(t, source.TestSettingOf(testSetting.Key(), wantSettingValue)) |
||||
if _, err := RegisterStoreForTest(t, "TestSource", setting.DeviceScope, store); err != nil { |
||||
t.Fatalf("Failed to register policy store: %v", err) |
||||
} |
||||
policy, err := policyForTest(t, setting.DeviceScope) |
||||
if err != nil { |
||||
t.Fatalf("Failed to get effective policy: %v", err) |
||||
} |
||||
|
||||
initialSnapshot, err := policy.Reload() |
||||
if err != nil { |
||||
t.Fatalf("Failed to reload policy: %v", err) |
||||
} |
||||
if gotSettingValue, err := initialSnapshot.GetErr(testSetting.Key()); err != nil { |
||||
t.Fatalf("Failed to get %q setting value: %v", testSetting.Key(), err) |
||||
} else if gotSettingValue != wantSettingValue { |
||||
t.Fatalf("Setting %q: got %q; want %q", testSetting.Key(), gotSettingValue, wantSettingValue) |
||||
} |
||||
|
||||
store.Close() |
||||
|
||||
// Closing a policy source abruptly without removing it first should invalidate and close the policy.
|
||||
<-policy.Done() |
||||
if policy.IsValid() { |
||||
t.Fatal("The policy was not properly closed") |
||||
} |
||||
|
||||
// The resulting policy snapshot should remain valid and unchanged.
|
||||
finalSnapshot := policy.Get() |
||||
if !finalSnapshot.Equal(initialSnapshot) { |
||||
t.Fatal("Policy snapshot has changed") |
||||
} |
||||
if gotSettingValue, err := finalSnapshot.GetErr(testSetting.Key()); err != nil { |
||||
t.Fatalf("Failed to get final %q setting value: %v", testSetting.Key(), err) |
||||
} else if gotSettingValue != wantSettingValue { |
||||
t.Fatalf("Setting %q: got %q; want %q", testSetting.Key(), gotSettingValue, wantSettingValue) |
||||
} |
||||
|
||||
// However, any further requests to reload the policy should fail.
|
||||
if _, err := policy.Reload(); err == nil || !errors.Is(err, ErrPolicyClosed) { |
||||
t.Fatalf("Reload: gotErr: %v; wantErr: %v", err, ErrPolicyClosed) |
||||
} |
||||
} |
||||
|
||||
func TestRemovePolicySource(t *testing.T) { |
||||
// Register policy settings used in this test.
|
||||
settingA := setting.NewDefinition("TestSettingA", setting.DeviceSetting, setting.StringValue) |
||||
settingB := setting.NewDefinition("TestSettingB", setting.DeviceSetting, setting.StringValue) |
||||
if err := setting.SetDefinitionsForTest(t, settingA, settingB); err != nil { |
||||
t.Fatalf("SetDefinitionsForTest failed: %v", err) |
||||
} |
||||
|
||||
// Register two policy stores.
|
||||
storeA := source.NewTestStoreOf(t, source.TestSettingOf(settingA.Key(), "A")) |
||||
storeRegA, err := RegisterStoreForTest(t, "TestSourceA", setting.DeviceScope, storeA) |
||||
if err != nil { |
||||
t.Fatalf("Failed to register policy store A: %v", err) |
||||
} |
||||
storeB := source.NewTestStoreOf(t, source.TestSettingOf(settingB.Key(), "B")) |
||||
storeRegB, err := RegisterStoreForTest(t, "TestSourceB", setting.DeviceScope, storeB) |
||||
if err != nil { |
||||
t.Fatalf("Failed to register policy store A: %v", err) |
||||
} |
||||
|
||||
// Create a effective [Policy] that reads policy settings from the two stores.
|
||||
policy, err := policyForTest(t, setting.DeviceScope) |
||||
if err != nil { |
||||
t.Fatalf("Failed to get effective policy: %v", err) |
||||
} |
||||
|
||||
// Verify that the [Policy] uses both stores and includes policy settings from each.
|
||||
if gotSources, wantSources := len(policy.sources), 2; gotSources != wantSources { |
||||
t.Fatalf("Policy Sources: got %v; want %v", gotSources, wantSources) |
||||
} |
||||
if got, want := policy.Get().Get(settingA.Key()), "A"; got != want { |
||||
t.Fatalf("Setting %q: got %q; want %q", settingA.Key(), got, want) |
||||
} |
||||
if got, want := policy.Get().Get(settingB.Key()), "B"; got != want { |
||||
t.Fatalf("Setting %q: got %q; want %q", settingB.Key(), got, want) |
||||
} |
||||
|
||||
// Unregister Store A and verify that the effective policy remains valid.
|
||||
// It should no longer use the removed store or include any policy settings from it.
|
||||
if err := storeRegA.Unregister(); err != nil { |
||||
t.Fatalf("Failed to unregister Store A: %v", err) |
||||
} |
||||
if !policy.IsValid() { |
||||
t.Fatalf("Policy was unexpectedly closed") |
||||
} |
||||
if gotSources, wantSources := len(policy.sources), 1; gotSources != wantSources { |
||||
t.Fatalf("Policy Sources: got %v; want %v", gotSources, wantSources) |
||||
} |
||||
if got, want := policy.Get().Get(settingA.Key()), any(nil); got != want { |
||||
t.Fatalf("Setting %q: got %q; want %q", settingA.Key(), got, want) |
||||
} |
||||
if got, want := policy.Get().Get(settingB.Key()), "B"; got != want { |
||||
t.Fatalf("Setting %q: got %q; want %q", settingB.Key(), got, want) |
||||
} |
||||
|
||||
// Unregister Store B and verify that the effective policy is still valid.
|
||||
// However, it should be empty since there are no associated policy sources.
|
||||
if err := storeRegB.Unregister(); err != nil { |
||||
t.Fatalf("Failed to unregister Store B: %v", err) |
||||
} |
||||
if !policy.IsValid() { |
||||
t.Fatalf("Policy was unexpectedly closed") |
||||
} |
||||
if gotSources, wantSources := len(policy.sources), 0; gotSources != wantSources { |
||||
t.Fatalf("Policy Sources: got %v; want %v", gotSources, wantSources) |
||||
} |
||||
if got := policy.Get(); got.Len() != 0 { |
||||
t.Fatalf("Settings: got %v; want {Empty}", got) |
||||
} |
||||
} |
||||
|
||||
func TestReplacePolicySource(t *testing.T) { |
||||
setForTest(t, &policyReloadMinDelay, 100*time.Millisecond) |
||||
setForTest(t, &policyReloadMaxDelay, 500*time.Millisecond) |
||||
|
||||
// Register policy settings used in this test.
|
||||
testSetting := setting.NewDefinition("TestSettingA", setting.DeviceSetting, setting.StringValue) |
||||
if err := setting.SetDefinitionsForTest(t, testSetting); err != nil { |
||||
t.Fatalf("SetDefinitionsForTest failed: %v", err) |
||||
} |
||||
|
||||
// Create two policy stores.
|
||||
initialStore := source.NewTestStoreOf(t, source.TestSettingOf(testSetting.Key(), "InitialValue")) |
||||
newStore := source.NewTestStoreOf(t, source.TestSettingOf(testSetting.Key(), "NewValue")) |
||||
unchangedStore := source.NewTestStoreOf(t, source.TestSettingOf(testSetting.Key(), "NewValue")) |
||||
|
||||
// Register the initial store and create a effective [Policy] that reads policy settings from it.
|
||||
reg, err := RegisterStoreForTest(t, "TestStore", setting.DeviceScope, initialStore) |
||||
if err != nil { |
||||
t.Fatalf("Failed to register the initial store: %v", err) |
||||
} |
||||
policy, err := policyForTest(t, setting.DeviceScope) |
||||
if err != nil { |
||||
t.Fatalf("Failed to get effective policy: %v", err) |
||||
} |
||||
|
||||
// Verify that the test setting has its initial value.
|
||||
if got, want := policy.Get().Get(testSetting.Key()), "InitialValue"; got != want { |
||||
t.Fatalf("Setting %q: got %q; want %q", testSetting.Key(), got, want) |
||||
} |
||||
|
||||
// Subscribe to the policy change callback.
|
||||
policyChanged := make(chan *PolicyChange, 1) |
||||
unregister := policy.RegisterChangeCallback(func(pc *PolicyChange) { policyChanged <- pc }) |
||||
t.Cleanup(unregister) |
||||
|
||||
// Now, let's replace the initial store with the new store.
|
||||
reg, err = reg.ReplaceStore(newStore) |
||||
if err != nil { |
||||
t.Fatalf("Failed to replace the policy store: %v", err) |
||||
} |
||||
t.Cleanup(func() { reg.Unregister() }) |
||||
|
||||
// We should receive a policy change notification as the setting value has changed.
|
||||
<-policyChanged |
||||
|
||||
// Verify that the test setting has the new value.
|
||||
if got, want := policy.Get().Get(testSetting.Key()), "NewValue"; got != want { |
||||
t.Fatalf("Setting %q: got %q; want %q", testSetting.Key(), got, want) |
||||
} |
||||
|
||||
// Replacing a policy store with an identical one containing the same
|
||||
// values for the same settings should not be considered a policy change.
|
||||
reg, err = reg.ReplaceStore(unchangedStore) |
||||
if err != nil { |
||||
t.Fatalf("Failed to replace the policy store: %v", err) |
||||
} |
||||
t.Cleanup(func() { reg.Unregister() }) |
||||
|
||||
select { |
||||
case <-policyChanged: |
||||
t.Fatal("Unexpected policy changed notification") |
||||
default: |
||||
<-time.After(policyReloadMaxDelay) |
||||
} |
||||
} |
||||
|
||||
func TestAddClosedPolicySource(t *testing.T) { |
||||
store := source.NewTestStoreOf[string](t) |
||||
if _, err := RegisterStoreForTest(t, "TestSource", setting.DeviceScope, store); err != nil { |
||||
t.Fatalf("Failed to register policy store: %v", err) |
||||
} |
||||
store.Close() |
||||
|
||||
_, err := policyForTest(t, setting.DeviceScope) |
||||
if err == nil || !errors.Is(err, source.ErrStoreClosed) { |
||||
t.Fatalf("got: %v; want: %v", err, source.ErrStoreClosed) |
||||
} |
||||
} |
||||
|
||||
func TestClosePolicyMoreThanOnce(t *testing.T) { |
||||
tests := []struct { |
||||
name string |
||||
numSources int |
||||
}{ |
||||
{ |
||||
name: "NoSources", |
||||
numSources: 0, |
||||
}, |
||||
{ |
||||
name: "OneSource", |
||||
numSources: 1, |
||||
}, |
||||
{ |
||||
name: "ManySources", |
||||
numSources: 10, |
||||
}, |
||||
} |
||||
for _, tt := range tests { |
||||
t.Run(tt.name, func(t *testing.T) { |
||||
for i := range tt.numSources { |
||||
store := source.NewTestStoreOf[string](t) |
||||
if _, err := RegisterStoreForTest(t, "TestSource #"+strconv.Itoa(i), setting.DeviceScope, store); err != nil { |
||||
t.Fatalf("Failed to register policy store: %v", err) |
||||
} |
||||
} |
||||
|
||||
policy, err := policyForTest(t, setting.DeviceScope) |
||||
if err != nil { |
||||
t.Fatalf("failed to get effective policy: %v", err) |
||||
} |
||||
|
||||
const N = 10000 |
||||
var wg sync.WaitGroup |
||||
for range N { |
||||
wg.Add(1) |
||||
go func() { |
||||
wg.Done() |
||||
policy.Close() |
||||
<-policy.Done() |
||||
}() |
||||
} |
||||
wg.Wait() |
||||
}) |
||||
} |
||||
} |
||||
|
||||
func checkPolicySources(tb testing.TB, gotPolicy *Policy, wantSources []*source.Source) { |
||||
tb.Helper() |
||||
sort.SliceStable(wantSources, func(i, j int) bool { |
||||
return wantSources[i].Compare(wantSources[j]) < 0 |
||||
}) |
||||
gotSources := make([]*source.Source, len(gotPolicy.sources)) |
||||
for i := range gotPolicy.sources { |
||||
gotSources[i] = gotPolicy.sources[i].Source |
||||
} |
||||
type sourceSummary struct{ Name, Scope string } |
||||
toSourceSummary := cmp.Transformer("source", func(s *source.Source) sourceSummary { return sourceSummary{s.Name(), s.Scope().String()} }) |
||||
if diff := cmp.Diff(wantSources, gotSources, toSourceSummary, cmpopts.EquateEmpty()); diff != "" { |
||||
tb.Errorf("Policy Sources mismatch: %v", diff) |
||||
} |
||||
} |
||||
|
||||
// policyForTest is like [PolicyFor], but it deletes the policy
|
||||
// when tb and all its subtests complete.
|
||||
func policyForTest(tb testing.TB, target setting.PolicyScope) (*Policy, error) { |
||||
tb.Helper() |
||||
|
||||
policy, err := PolicyFor(target) |
||||
if err != nil { |
||||
return nil, err |
||||
} |
||||
tb.Cleanup(func() { |
||||
policy.Close() |
||||
<-policy.Done() |
||||
deletePolicy(policy) |
||||
}) |
||||
return policy, nil |
||||
} |
||||
|
||||
func setForTest[T any](tb testing.TB, target *T, newValue T) { |
||||
oldValue := *target |
||||
tb.Cleanup(func() { *target = oldValue }) |
||||
*target = newValue |
||||
} |
||||
@ -0,0 +1,174 @@ |
||||
// Copyright (c) Tailscale Inc & AUTHORS
|
||||
// SPDX-License-Identifier: BSD-3-Clause
|
||||
|
||||
// Package rsop facilitates [source.Store] registration via [RegisterStore]
|
||||
// and provides access to the effective policy merged from all registered sources
|
||||
// via [PolicyFor].
|
||||
package rsop |
||||
|
||||
import ( |
||||
"errors" |
||||
"fmt" |
||||
"slices" |
||||
"sync" |
||||
|
||||
"tailscale.com/syncs" |
||||
"tailscale.com/util/slicesx" |
||||
"tailscale.com/util/syspolicy/internal" |
||||
"tailscale.com/util/syspolicy/setting" |
||||
"tailscale.com/util/syspolicy/source" |
||||
) |
||||
|
||||
var ( |
||||
policyMu sync.Mutex // protects [policySources] and [effectivePolicies]
|
||||
policySources []*source.Source // all registered policy sources
|
||||
effectivePolicies []*Policy // all active (non-closed) effective policies returned by [PolicyFor]
|
||||
|
||||
// effectivePolicyLRU is an LRU cache of [Policy] by [setting.Scope].
|
||||
// Although there could be multiple [setting.PolicyScope] instances with the same [setting.Scope],
|
||||
// such as two user scopes for different users, there is only one [setting.DeviceScope], only one
|
||||
// [setting.CurrentProfileScope], and in most cases, only one active user scope.
|
||||
// Therefore, cache misses that require falling back to [effectivePolicies] are extremely rare.
|
||||
// It's a fixed-size array of atomic values and can be accessed without [policyMu] held.
|
||||
effectivePolicyLRU [setting.NumScopes]syncs.AtomicValue[*Policy] |
||||
) |
||||
|
||||
// PolicyFor returns the [Policy] for the specified scope,
|
||||
// creating it from the registered [source.Store]s if it doesn't already exist.
|
||||
func PolicyFor(scope setting.PolicyScope) (*Policy, error) { |
||||
if err := internal.Init.Do(); err != nil { |
||||
return nil, err |
||||
} |
||||
policy := effectivePolicyLRU[scope.Kind()].Load() |
||||
if policy != nil && policy.Scope() == scope && policy.IsValid() { |
||||
return policy, nil |
||||
} |
||||
return policyForSlow(scope) |
||||
} |
||||
|
||||
func policyForSlow(scope setting.PolicyScope) (policy *Policy, err error) { |
||||
defer func() { |
||||
// Always update the LRU cache on exit if we found (or created)
|
||||
// a policy for the specified scope.
|
||||
if policy != nil { |
||||
effectivePolicyLRU[scope.Kind()].Store(policy) |
||||
} |
||||
}() |
||||
|
||||
policyMu.Lock() |
||||
defer policyMu.Unlock() |
||||
if policy, ok := findPolicyByScopeLocked(scope); ok { |
||||
return policy, nil |
||||
} |
||||
|
||||
// If there is no existing effective policy for the specified scope,
|
||||
// we need to create one using the policy sources registered for that scope.
|
||||
sources := slicesx.Filter(nil, policySources, func(source *source.Source) bool { |
||||
return source.Scope().Contains(scope) |
||||
}) |
||||
policy, err = newPolicy(scope, sources...) |
||||
if err != nil { |
||||
return nil, err |
||||
} |
||||
effectivePolicies = append(effectivePolicies, policy) |
||||
return policy, nil |
||||
} |
||||
|
||||
// findPolicyByScopeLocked returns a policy with the specified scope and true if
|
||||
// one exists in the [effectivePolicies] list, otherwise it returns nil, false.
|
||||
// [policyMu] must be held.
|
||||
func findPolicyByScopeLocked(target setting.PolicyScope) (policy *Policy, ok bool) { |
||||
for _, policy := range effectivePolicies { |
||||
if policy.Scope() == target && policy.IsValid() { |
||||
return policy, true |
||||
} |
||||
} |
||||
return nil, false |
||||
} |
||||
|
||||
// deletePolicy deletes the specified effective policy from [effectivePolicies]
|
||||
// and [effectivePolicyLRU].
|
||||
func deletePolicy(policy *Policy) { |
||||
policyMu.Lock() |
||||
defer policyMu.Unlock() |
||||
if i := slices.Index(effectivePolicies, policy); i != -1 { |
||||
effectivePolicies = slices.Delete(effectivePolicies, i, i+1) |
||||
} |
||||
effectivePolicyLRU[policy.Scope().Kind()].CompareAndSwap(policy, nil) |
||||
} |
||||
|
||||
// registerSource registers the specified [source.Source] to be used by the package.
|
||||
// It updates existing [Policy]s returned by [PolicyFor] to use this source if
|
||||
// they are within the source's [setting.PolicyScope].
|
||||
func registerSource(source *source.Source) error { |
||||
policyMu.Lock() |
||||
defer policyMu.Unlock() |
||||
if slices.Contains(policySources, source) { |
||||
// already registered
|
||||
return nil |
||||
} |
||||
policySources = append(policySources, source) |
||||
return forEachEffectivePolicyLocked(func(policy *Policy) error { |
||||
if !source.Scope().Contains(policy.Scope()) { |
||||
// Policy settings in the specified source do not apply
|
||||
// to the scope of this effective policy.
|
||||
// For example, a user policy source is being registered
|
||||
// while the effective policy is for the device (or another user).
|
||||
return nil |
||||
} |
||||
return policy.addSource(source) |
||||
}) |
||||
} |
||||
|
||||
// replaceSource is like [unregisterSource](old) followed by [registerSource](new),
|
||||
// but performed atomically: the effective policy will contain settings
|
||||
// either from the old source or the new source, never both and never neither.
|
||||
func replaceSource(old, new *source.Source) error { |
||||
policyMu.Lock() |
||||
defer policyMu.Unlock() |
||||
oldIndex := slices.Index(policySources, old) |
||||
if oldIndex == -1 { |
||||
return fmt.Errorf("the source is not registered: %v", old) |
||||
} |
||||
policySources[oldIndex] = new |
||||
return forEachEffectivePolicyLocked(func(policy *Policy) error { |
||||
if !old.Scope().Contains(policy.Scope()) || !new.Scope().Contains(policy.Scope()) { |
||||
return nil |
||||
} |
||||
return policy.replaceSource(old, new) |
||||
}) |
||||
} |
||||
|
||||
// unregisterSource unregisters the specified [source.Source],
|
||||
// so that it won't be used by any new or existing [Policy].
|
||||
func unregisterSource(source *source.Source) error { |
||||
policyMu.Lock() |
||||
defer policyMu.Unlock() |
||||
index := slices.Index(policySources, source) |
||||
if index == -1 { |
||||
return nil |
||||
} |
||||
policySources = slices.Delete(policySources, index, index+1) |
||||
return forEachEffectivePolicyLocked(func(policy *Policy) error { |
||||
if !source.Scope().Contains(policy.Scope()) { |
||||
return nil |
||||
} |
||||
return policy.removeSource(source) |
||||
}) |
||||
} |
||||
|
||||
// forEachEffectivePolicyLocked calls fn for every non-closed [Policy] in [effectivePolicies].
|
||||
// It accumulates the returned errors and returns an error that wraps all errors returned by fn.
|
||||
// The [policyMu] mutex must be held while this function is executed.
|
||||
func forEachEffectivePolicyLocked(fn func(p *Policy) error) error { |
||||
var errs []error |
||||
for _, policy := range effectivePolicies { |
||||
if policy.IsValid() { |
||||
err := fn(policy) |
||||
if err != nil && !errors.Is(err, ErrPolicyClosed) { |
||||
errs = append(errs, err) |
||||
} |
||||
} |
||||
} |
||||
return errors.Join(errs...) |
||||
} |
||||
@ -0,0 +1,94 @@ |
||||
// Copyright (c) Tailscale Inc & AUTHORS
|
||||
// SPDX-License-Identifier: BSD-3-Clause
|
||||
|
||||
package rsop |
||||
|
||||
import ( |
||||
"errors" |
||||
"sync" |
||||
"sync/atomic" |
||||
|
||||
"tailscale.com/util/syspolicy/internal" |
||||
"tailscale.com/util/syspolicy/setting" |
||||
"tailscale.com/util/syspolicy/source" |
||||
) |
||||
|
||||
// ErrAlreadyConsumed is the error returned when [StoreRegistration.ReplaceStore]
|
||||
// or [StoreRegistration.Unregister] is called more than once.
|
||||
var ErrAlreadyConsumed = errors.New("the store registration is no longer valid") |
||||
|
||||
// StoreRegistration is a [source.Store] registered for use in the specified scope.
|
||||
// It can be used to unregister the store, or replace it with another one.
|
||||
type StoreRegistration struct { |
||||
source *source.Source |
||||
m sync.Mutex // protects the [StoreRegistration.consumeSlow] path
|
||||
consumed atomic.Bool // can be read without holding m, but must be written with m held
|
||||
} |
||||
|
||||
// RegisterStore registers a new policy [source.Store] with the specified name and [setting.PolicyScope].
|
||||
func RegisterStore(name string, scope setting.PolicyScope, store source.Store) (*StoreRegistration, error) { |
||||
return newStoreRegistration(name, scope, store) |
||||
} |
||||
|
||||
// RegisterStoreForTest is like [RegisterStore], but unregisters the store when
|
||||
// tb and all its subtests complete.
|
||||
func RegisterStoreForTest(tb internal.TB, name string, scope setting.PolicyScope, store source.Store) (*StoreRegistration, error) { |
||||
reg, err := RegisterStore(name, scope, store) |
||||
if err == nil { |
||||
tb.Cleanup(func() { |
||||
if err := reg.Unregister(); err != nil && !errors.Is(err, ErrAlreadyConsumed) { |
||||
tb.Fatalf("Unregister failed: %v", err) |
||||
} |
||||
}) |
||||
} |
||||
return reg, err // may be nil or non-nil
|
||||
} |
||||
|
||||
func newStoreRegistration(name string, scope setting.PolicyScope, store source.Store) (*StoreRegistration, error) { |
||||
source := source.NewSource(name, scope, store) |
||||
if err := registerSource(source); err != nil { |
||||
return nil, err |
||||
} |
||||
return &StoreRegistration{source: source}, nil |
||||
} |
||||
|
||||
// ReplaceStore replaces the registered store with the new one,
|
||||
// returning a new [StoreRegistration] or an error.
|
||||
func (r *StoreRegistration) ReplaceStore(new source.Store) (*StoreRegistration, error) { |
||||
var res *StoreRegistration |
||||
err := r.consume(func() error { |
||||
newSource := source.NewSource(r.source.Name(), r.source.Scope(), new) |
||||
if err := replaceSource(r.source, newSource); err != nil { |
||||
return err |
||||
} |
||||
res = &StoreRegistration{source: newSource} |
||||
return nil |
||||
}) |
||||
return res, err |
||||
} |
||||
|
||||
// Unregister reverts the registration.
|
||||
func (r *StoreRegistration) Unregister() error { |
||||
return r.consume(func() error { return unregisterSource(r.source) }) |
||||
} |
||||
|
||||
// consume invokes fn, consuming r if no error is returned.
|
||||
// It returns [ErrAlreadyConsumed] on subsequent calls after the first successful call.
|
||||
func (r *StoreRegistration) consume(fn func() error) (err error) { |
||||
if r.consumed.Load() { |
||||
return ErrAlreadyConsumed |
||||
} |
||||
return r.consumeSlow(fn) |
||||
} |
||||
|
||||
func (r *StoreRegistration) consumeSlow(fn func() error) (err error) { |
||||
r.m.Lock() |
||||
defer r.m.Unlock() |
||||
if r.consumed.Load() { |
||||
return ErrAlreadyConsumed |
||||
} |
||||
if err = fn(); err == nil { |
||||
r.consumed.Store(true) |
||||
} |
||||
return err // may be nil or non-nil
|
||||
} |
||||
Loading…
Reference in new issue