ipn/ipnlocal: convert more tests to use policytest, de-global-ify
Now that we have policytest and the policyclient.Client interface, we can de-global-ify many of the tests, letting them run concurrently with each other, and just removing global variable complexity. This does ~half of the LocalBackend ones. Updates #16998 Change-Id: Iece754e1ef4e49744ccd967fa83629d0dca6f66a Signed-off-by: Brad Fitzpatrick <bradfitz@tailscale.com>
This commit is contained in:
committed by
Brad Fitzpatrick
parent
21f21bd2a2
commit
d06d9007a6
+48
-48
@@ -2881,20 +2881,16 @@ func TestSetExitNodeIDPolicy(t *testing.T) {
|
|||||||
},
|
},
|
||||||
}
|
}
|
||||||
|
|
||||||
syspolicy.RegisterWellKnownSettingsForTest(t)
|
|
||||||
|
|
||||||
for _, test := range tests {
|
for _, test := range tests {
|
||||||
t.Run(test.name, func(t *testing.T) {
|
t.Run(test.name, func(t *testing.T) {
|
||||||
b := newTestBackend(t)
|
var polc policytest.Config
|
||||||
|
|
||||||
policyStore := source.NewTestStore(t)
|
|
||||||
if test.exitNodeIDKey {
|
if test.exitNodeIDKey {
|
||||||
policyStore.SetStrings(source.TestSettingOf(pkey.ExitNodeID, test.exitNodeID))
|
polc.Set(pkey.ExitNodeID, test.exitNodeID)
|
||||||
}
|
}
|
||||||
if test.exitNodeIPKey {
|
if test.exitNodeIPKey {
|
||||||
policyStore.SetStrings(source.TestSettingOf(pkey.ExitNodeIP, test.exitNodeIP))
|
polc.Set(pkey.ExitNodeIP, test.exitNodeIP)
|
||||||
}
|
}
|
||||||
syspolicy.MustRegisterStoreForTest(t, "TestStore", setting.DeviceScope, policyStore)
|
b := newTestBackend(t, polc)
|
||||||
|
|
||||||
if test.nm == nil {
|
if test.nm == nil {
|
||||||
test.nm = new(netmap.NetworkMap)
|
test.nm = new(netmap.NetworkMap)
|
||||||
@@ -3026,15 +3022,13 @@ func TestUpdateNetmapDeltaAutoExitNode(t *testing.T) {
|
|||||||
},
|
},
|
||||||
}
|
}
|
||||||
|
|
||||||
syspolicy.RegisterWellKnownSettingsForTest(t)
|
|
||||||
policyStore := source.NewTestStoreOf(t, source.TestSettingOf(
|
|
||||||
pkey.ExitNodeID, "auto:any",
|
|
||||||
))
|
|
||||||
syspolicy.MustRegisterStoreForTest(t, "TestStore", setting.DeviceScope, policyStore)
|
|
||||||
|
|
||||||
for _, tt := range tests {
|
for _, tt := range tests {
|
||||||
t.Run(tt.name, func(t *testing.T) {
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
b := newTestLocalBackend(t)
|
sys := tsd.NewSystem()
|
||||||
|
sys.PolicyClient.Set(policytest.Config{
|
||||||
|
pkey.ExitNodeID: "auto:any",
|
||||||
|
})
|
||||||
|
b := newTestLocalBackendWithSys(t, sys)
|
||||||
b.currentNode().SetNetMap(tt.netmap)
|
b.currentNode().SetNetMap(tt.netmap)
|
||||||
b.lastSuggestedExitNode = tt.lastSuggestedExitNode
|
b.lastSuggestedExitNode = tt.lastSuggestedExitNode
|
||||||
b.sys.MagicSock.Get().SetLastNetcheckReportForTest(b.ctx, tt.report)
|
b.sys.MagicSock.Get().SetLastNetcheckReportForTest(b.ctx, tt.report)
|
||||||
@@ -3094,7 +3088,13 @@ func TestUpdateNetmapDeltaAutoExitNode(t *testing.T) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func TestAutoExitNodeSetNetInfoCallback(t *testing.T) {
|
func TestAutoExitNodeSetNetInfoCallback(t *testing.T) {
|
||||||
b := newTestLocalBackend(t)
|
polc := policytest.Config{
|
||||||
|
pkey.ExitNodeID: "auto:any",
|
||||||
|
}
|
||||||
|
sys := tsd.NewSystem()
|
||||||
|
sys.PolicyClient.Set(polc)
|
||||||
|
|
||||||
|
b := newTestLocalBackendWithSys(t, sys)
|
||||||
hi := hostinfo.New()
|
hi := hostinfo.New()
|
||||||
ni := tailcfg.NetInfo{LinkType: "wired"}
|
ni := tailcfg.NetInfo{LinkType: "wired"}
|
||||||
hi.NetInfo = &ni
|
hi.NetInfo = &ni
|
||||||
@@ -3106,16 +3106,12 @@ func TestAutoExitNodeSetNetInfoCallback(t *testing.T) {
|
|||||||
GetMachinePrivateKey: func() (key.MachinePrivate, error) {
|
GetMachinePrivateKey: func() (key.MachinePrivate, error) {
|
||||||
return k, nil
|
return k, nil
|
||||||
},
|
},
|
||||||
Dialer: tsdial.NewDialer(netmon.NewStatic()),
|
Dialer: tsdial.NewDialer(netmon.NewStatic()),
|
||||||
Logf: b.logf,
|
Logf: b.logf,
|
||||||
|
PolicyClient: polc,
|
||||||
}
|
}
|
||||||
cc = newClient(t, opts)
|
cc = newClient(t, opts)
|
||||||
b.cc = cc
|
b.cc = cc
|
||||||
syspolicy.RegisterWellKnownSettingsForTest(t)
|
|
||||||
policyStore := source.NewTestStoreOf(t, source.TestSettingOf(
|
|
||||||
pkey.ExitNodeID, "auto:any",
|
|
||||||
))
|
|
||||||
syspolicy.MustRegisterStoreForTest(t, "TestStore", setting.DeviceScope, policyStore)
|
|
||||||
peer1 := makePeer(1, withCap(26), withDERP(3), withSuggest(), withExitRoutes())
|
peer1 := makePeer(1, withCap(26), withDERP(3), withSuggest(), withExitRoutes())
|
||||||
peer2 := makePeer(2, withCap(26), withDERP(2), withSuggest(), withExitRoutes())
|
peer2 := makePeer(2, withCap(26), withDERP(2), withSuggest(), withExitRoutes())
|
||||||
selfNode := tailcfg.Node{
|
selfNode := tailcfg.Node{
|
||||||
@@ -3219,12 +3215,14 @@ func TestSetControlClientStatusAutoExitNode(t *testing.T) {
|
|||||||
},
|
},
|
||||||
DERPMap: derpMap,
|
DERPMap: derpMap,
|
||||||
}
|
}
|
||||||
b := newTestLocalBackend(t)
|
|
||||||
syspolicy.RegisterWellKnownSettingsForTest(t)
|
polc := policytest.Config{
|
||||||
policyStore := source.NewTestStoreOf(t, source.TestSettingOf(
|
pkey.ExitNodeID: "auto:any",
|
||||||
pkey.ExitNodeID, "auto:any",
|
}
|
||||||
))
|
sys := tsd.NewSystem()
|
||||||
syspolicy.MustRegisterStoreForTest(t, "TestStore", setting.DeviceScope, policyStore)
|
sys.PolicyClient.Set(polc)
|
||||||
|
|
||||||
|
b := newTestLocalBackendWithSys(t, sys)
|
||||||
b.currentNode().SetNetMap(nm)
|
b.currentNode().SetNetMap(nm)
|
||||||
// Peer 2 should be the initial exit node, as it's better than peer 1
|
// Peer 2 should be the initial exit node, as it's better than peer 1
|
||||||
// in terms of latency and DERP region.
|
// in terms of latency and DERP region.
|
||||||
@@ -3461,21 +3459,20 @@ func TestApplySysPolicy(t *testing.T) {
|
|||||||
},
|
},
|
||||||
}
|
}
|
||||||
|
|
||||||
syspolicy.RegisterWellKnownSettingsForTest(t)
|
|
||||||
|
|
||||||
for _, tt := range tests {
|
for _, tt := range tests {
|
||||||
t.Run(tt.name, func(t *testing.T) {
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
settings := make([]source.TestSetting[string], 0, len(tt.stringPolicies))
|
var polc policytest.Config
|
||||||
for p, v := range tt.stringPolicies {
|
for k, v := range tt.stringPolicies {
|
||||||
settings = append(settings, source.TestSettingOf(p, v))
|
polc.Set(k, v)
|
||||||
}
|
}
|
||||||
policyStore := source.NewTestStoreOf(t, settings...)
|
|
||||||
syspolicy.MustRegisterStoreForTest(t, "TestStore", setting.DeviceScope, policyStore)
|
|
||||||
|
|
||||||
t.Run("unit", func(t *testing.T) {
|
t.Run("unit", func(t *testing.T) {
|
||||||
prefs := tt.prefs.Clone()
|
prefs := tt.prefs.Clone()
|
||||||
|
|
||||||
lb := newTestLocalBackend(t)
|
sys := tsd.NewSystem()
|
||||||
|
sys.PolicyClient.Set(polc)
|
||||||
|
|
||||||
|
lb := newTestLocalBackendWithSys(t, sys)
|
||||||
gotAnyChange := lb.applySysPolicyLocked(prefs)
|
gotAnyChange := lb.applySysPolicyLocked(prefs)
|
||||||
|
|
||||||
if gotAnyChange && prefs.Equals(&tt.prefs) {
|
if gotAnyChange && prefs.Equals(&tt.prefs) {
|
||||||
@@ -3508,7 +3505,7 @@ func TestApplySysPolicy(t *testing.T) {
|
|||||||
pm := must.Get(newProfileManager(new(mem.Store), t.Logf, new(health.Tracker)))
|
pm := must.Get(newProfileManager(new(mem.Store), t.Logf, new(health.Tracker)))
|
||||||
pm.prefs = usePrefs.View()
|
pm.prefs = usePrefs.View()
|
||||||
|
|
||||||
b := newTestBackend(t)
|
b := newTestBackend(t, polc)
|
||||||
b.mu.Lock()
|
b.mu.Lock()
|
||||||
b.pm = pm
|
b.pm = pm
|
||||||
b.mu.Unlock()
|
b.mu.Unlock()
|
||||||
@@ -3607,24 +3604,26 @@ func TestPreferencePolicyInfo(t *testing.T) {
|
|||||||
},
|
},
|
||||||
}
|
}
|
||||||
|
|
||||||
syspolicy.RegisterWellKnownSettingsForTest(t)
|
|
||||||
|
|
||||||
for _, tt := range tests {
|
for _, tt := range tests {
|
||||||
t.Run(tt.name, func(t *testing.T) {
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
for _, pp := range preferencePolicies {
|
for _, pp := range preferencePolicies {
|
||||||
t.Run(string(pp.key), func(t *testing.T) {
|
t.Run(string(pp.key), func(t *testing.T) {
|
||||||
s := source.TestSetting[string]{
|
t.Parallel()
|
||||||
Key: pp.key,
|
|
||||||
Error: tt.policyError,
|
var polc policytest.Config
|
||||||
Value: tt.policyValue,
|
if tt.policyError != nil {
|
||||||
|
polc.Set(pp.key, tt.policyError)
|
||||||
|
} else {
|
||||||
|
polc.Set(pp.key, tt.policyValue)
|
||||||
}
|
}
|
||||||
policyStore := source.NewTestStoreOf(t, s)
|
|
||||||
syspolicy.MustRegisterStoreForTest(t, "TestStore", setting.DeviceScope, policyStore)
|
|
||||||
|
|
||||||
prefs := defaultPrefs.AsStruct()
|
prefs := defaultPrefs.AsStruct()
|
||||||
pp.set(prefs, tt.initialValue)
|
pp.set(prefs, tt.initialValue)
|
||||||
|
|
||||||
lb := newTestLocalBackend(t)
|
sys := tsd.NewSystem()
|
||||||
|
sys.PolicyClient.Set(polc)
|
||||||
|
|
||||||
|
lb := newTestLocalBackendWithSys(t, sys)
|
||||||
gotAnyChange := lb.applySysPolicyLocked(prefs)
|
gotAnyChange := lb.applySysPolicyLocked(prefs)
|
||||||
|
|
||||||
if gotAnyChange != tt.wantChange {
|
if gotAnyChange != tt.wantChange {
|
||||||
@@ -6534,7 +6533,8 @@ func TestUpdatePrefsOnSysPolicyChange(t *testing.T) {
|
|||||||
store := source.NewTestStoreOf[string](t)
|
store := source.NewTestStoreOf[string](t)
|
||||||
syspolicy.MustRegisterStoreForTest(t, "TestSource", setting.DeviceScope, store)
|
syspolicy.MustRegisterStoreForTest(t, "TestSource", setting.DeviceScope, store)
|
||||||
|
|
||||||
lb := newLocalBackendWithTestControl(t, enableLogging, func(tb testing.TB, opts controlclient.Options) controlclient.Client {
|
sys := tsd.NewSystem()
|
||||||
|
lb := newLocalBackendWithSysAndTestControl(t, enableLogging, sys, func(tb testing.TB, opts controlclient.Options) controlclient.Client {
|
||||||
return newClient(tb, opts)
|
return newClient(tb, opts)
|
||||||
})
|
})
|
||||||
if tt.initialPrefs != nil {
|
if tt.initialPrefs != nil {
|
||||||
|
|||||||
@@ -35,6 +35,7 @@ import (
|
|||||||
"tailscale.com/types/netmap"
|
"tailscale.com/types/netmap"
|
||||||
"tailscale.com/util/mak"
|
"tailscale.com/util/mak"
|
||||||
"tailscale.com/util/must"
|
"tailscale.com/util/must"
|
||||||
|
"tailscale.com/util/syspolicy/policyclient"
|
||||||
"tailscale.com/wgengine"
|
"tailscale.com/wgengine"
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -870,7 +871,7 @@ func mustCreateURL(t *testing.T, u string) url.URL {
|
|||||||
return *uParsed
|
return *uParsed
|
||||||
}
|
}
|
||||||
|
|
||||||
func newTestBackend(t *testing.T) *LocalBackend {
|
func newTestBackend(t *testing.T, opts ...any) *LocalBackend {
|
||||||
var logf logger.Logf = logger.Discard
|
var logf logger.Logf = logger.Discard
|
||||||
const debug = true
|
const debug = true
|
||||||
if debug {
|
if debug {
|
||||||
@@ -878,6 +879,16 @@ func newTestBackend(t *testing.T) *LocalBackend {
|
|||||||
}
|
}
|
||||||
|
|
||||||
sys := tsd.NewSystem()
|
sys := tsd.NewSystem()
|
||||||
|
|
||||||
|
for _, o := range opts {
|
||||||
|
switch v := o.(type) {
|
||||||
|
case policyclient.Client:
|
||||||
|
sys.PolicyClient.Set(v)
|
||||||
|
default:
|
||||||
|
panic(fmt.Sprintf("unsupported option type %T", v))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
e, err := wgengine.NewUserspaceEngine(logf, wgengine.Config{
|
e, err := wgengine.NewUserspaceEngine(logf, wgengine.Config{
|
||||||
SetSubsystem: sys.Set,
|
SetSubsystem: sys.Set,
|
||||||
HealthTracker: sys.HealthTracker(),
|
HealthTracker: sys.HealthTracker(),
|
||||||
|
|||||||
@@ -19,7 +19,12 @@ import (
|
|||||||
// It is used for testing purposes to simulate policy client behavior.
|
// It is used for testing purposes to simulate policy client behavior.
|
||||||
//
|
//
|
||||||
// It panics if a value is Set with one type and then accessed with a different
|
// It panics if a value is Set with one type and then accessed with a different
|
||||||
// expected type.
|
// expected type and/or value. Some accessors such as GetPreferenceOption and
|
||||||
|
// GetVisibility support either a ptype.PreferenceOption/ptype.Visibility in the
|
||||||
|
// map, or the string representation as supported by their UnmarshalText
|
||||||
|
// methods.
|
||||||
|
//
|
||||||
|
// The map value may be an error to return that error value from the accessor.
|
||||||
type Config map[pkey.Key]any
|
type Config map[pkey.Key]any
|
||||||
|
|
||||||
var _ policyclient.Client = Config{}
|
var _ policyclient.Client = Config{}
|
||||||
@@ -33,70 +38,108 @@ func (c *Config) Set(key pkey.Key, value any) {
|
|||||||
|
|
||||||
func (c Config) GetStringArray(key pkey.Key, defaultVal []string) ([]string, error) {
|
func (c Config) GetStringArray(key pkey.Key, defaultVal []string) ([]string, error) {
|
||||||
if val, ok := c[key]; ok {
|
if val, ok := c[key]; ok {
|
||||||
if arr, ok := val.([]string); ok {
|
switch val := val.(type) {
|
||||||
return arr, nil
|
case []string:
|
||||||
|
return val, nil
|
||||||
|
case error:
|
||||||
|
return nil, val
|
||||||
|
default:
|
||||||
|
panic(fmt.Sprintf("key %s is not a []string; got %T", key, val))
|
||||||
}
|
}
|
||||||
panic(fmt.Sprintf("key %s is not a []string", key))
|
|
||||||
}
|
}
|
||||||
return defaultVal, nil
|
return defaultVal, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (c Config) GetString(key pkey.Key, defaultVal string) (string, error) {
|
func (c Config) GetString(key pkey.Key, defaultVal string) (string, error) {
|
||||||
if val, ok := c[key]; ok {
|
if val, ok := c[key]; ok {
|
||||||
if str, ok := val.(string); ok {
|
switch val := val.(type) {
|
||||||
return str, nil
|
case string:
|
||||||
|
return val, nil
|
||||||
|
case error:
|
||||||
|
return "", val
|
||||||
|
default:
|
||||||
|
panic(fmt.Sprintf("key %s is not a string; got %T", key, val))
|
||||||
}
|
}
|
||||||
panic(fmt.Sprintf("key %s is not a string", key))
|
|
||||||
}
|
}
|
||||||
return defaultVal, nil
|
return defaultVal, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (c Config) GetBoolean(key pkey.Key, defaultVal bool) (bool, error) {
|
func (c Config) GetBoolean(key pkey.Key, defaultVal bool) (bool, error) {
|
||||||
if val, ok := c[key]; ok {
|
if val, ok := c[key]; ok {
|
||||||
if b, ok := val.(bool); ok {
|
switch val := val.(type) {
|
||||||
return b, nil
|
case bool:
|
||||||
|
return val, nil
|
||||||
|
case error:
|
||||||
|
return false, val
|
||||||
|
default:
|
||||||
|
panic(fmt.Sprintf("key %s is not a bool; got %T", key, val))
|
||||||
}
|
}
|
||||||
panic(fmt.Sprintf("key %s is not a bool", key))
|
|
||||||
}
|
}
|
||||||
return defaultVal, nil
|
return defaultVal, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (c Config) GetUint64(key pkey.Key, defaultVal uint64) (uint64, error) {
|
func (c Config) GetUint64(key pkey.Key, defaultVal uint64) (uint64, error) {
|
||||||
if val, ok := c[key]; ok {
|
if val, ok := c[key]; ok {
|
||||||
if u, ok := val.(uint64); ok {
|
switch val := val.(type) {
|
||||||
return u, nil
|
case uint64:
|
||||||
|
return val, nil
|
||||||
|
case error:
|
||||||
|
return 0, val
|
||||||
|
default:
|
||||||
|
panic(fmt.Sprintf("key %s is not a uint64; got %T", key, val))
|
||||||
}
|
}
|
||||||
panic(fmt.Sprintf("key %s is not a uint64", key))
|
|
||||||
}
|
}
|
||||||
return defaultVal, nil
|
return defaultVal, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (c Config) GetDuration(key pkey.Key, defaultVal time.Duration) (time.Duration, error) {
|
func (c Config) GetDuration(key pkey.Key, defaultVal time.Duration) (time.Duration, error) {
|
||||||
if val, ok := c[key]; ok {
|
if val, ok := c[key]; ok {
|
||||||
if d, ok := val.(time.Duration); ok {
|
switch val := val.(type) {
|
||||||
return d, nil
|
case time.Duration:
|
||||||
|
return val, nil
|
||||||
|
case error:
|
||||||
|
return 0, val
|
||||||
|
default:
|
||||||
|
panic(fmt.Sprintf("key %s is not a time.Duration; got %T", key, val))
|
||||||
}
|
}
|
||||||
panic(fmt.Sprintf("key %s is not a time.Duration", key))
|
|
||||||
}
|
}
|
||||||
return defaultVal, nil
|
return defaultVal, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (c Config) GetPreferenceOption(key pkey.Key, defaultVal ptype.PreferenceOption) (ptype.PreferenceOption, error) {
|
func (c Config) GetPreferenceOption(key pkey.Key, defaultVal ptype.PreferenceOption) (ptype.PreferenceOption, error) {
|
||||||
if val, ok := c[key]; ok {
|
if val, ok := c[key]; ok {
|
||||||
if p, ok := val.(ptype.PreferenceOption); ok {
|
switch val := val.(type) {
|
||||||
return p, nil
|
case ptype.PreferenceOption:
|
||||||
|
return val, nil
|
||||||
|
case error:
|
||||||
|
var zero ptype.PreferenceOption
|
||||||
|
return zero, val
|
||||||
|
case string:
|
||||||
|
var p ptype.PreferenceOption
|
||||||
|
err := p.UnmarshalText(([]byte)(val))
|
||||||
|
return p, err
|
||||||
|
default:
|
||||||
|
panic(fmt.Sprintf("key %s is not a ptype.PreferenceOption", key))
|
||||||
}
|
}
|
||||||
panic(fmt.Sprintf("key %s is not a ptype.PreferenceOption", key))
|
|
||||||
}
|
}
|
||||||
return defaultVal, nil
|
return defaultVal, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (c Config) GetVisibility(key pkey.Key) (ptype.Visibility, error) {
|
func (c Config) GetVisibility(key pkey.Key) (ptype.Visibility, error) {
|
||||||
if val, ok := c[key]; ok {
|
if val, ok := c[key]; ok {
|
||||||
if p, ok := val.(ptype.Visibility); ok {
|
switch val := val.(type) {
|
||||||
return p, nil
|
case ptype.Visibility:
|
||||||
|
return val, nil
|
||||||
|
case error:
|
||||||
|
var zero ptype.Visibility
|
||||||
|
return zero, val
|
||||||
|
case string:
|
||||||
|
var p ptype.Visibility
|
||||||
|
err := p.UnmarshalText(([]byte)(val))
|
||||||
|
return p, err
|
||||||
|
default:
|
||||||
|
panic(fmt.Sprintf("key %s is not a ptype.Visibility", key))
|
||||||
}
|
}
|
||||||
panic(fmt.Sprintf("key %s is not a ptype.Visibility", key))
|
|
||||||
}
|
}
|
||||||
return ptype.Visibility(ptype.ShowChoiceByPolicy), nil
|
return ptype.Visibility(ptype.ShowChoiceByPolicy), nil
|
||||||
}
|
}
|
||||||
|
|||||||
Reference in New Issue
Block a user