diff --git a/tsconsensus/authorization.go b/tsconsensus/authorization.go index 6261a8f1d..017c9e807 100644 --- a/tsconsensus/authorization.go +++ b/tsconsensus/authorization.go @@ -17,6 +17,10 @@ import ( "tailscale.com/util/set" ) +// defaultStatusCacheTimeout is the duration after which cached status will be +// disregarded. See tailscaleStatusGetter.cacheTimeout. +const defaultStatusCacheTimeout = time.Second + type statusGetter interface { getStatus(context.Context) (*ipnstate.Status, error) } @@ -24,6 +28,10 @@ type statusGetter interface { type tailscaleStatusGetter struct { ts *tsnet.Server + // cacheTimeout is used to determine when the cached status should be + // disregarded and a new status fetched. Zero means ignore the cache. + cacheTimeout time.Duration + mu sync.Mutex // protects the following lastStatus *ipnstate.Status lastStatusTime time.Time @@ -40,7 +48,7 @@ func (sg *tailscaleStatusGetter) fetchStatus(ctx context.Context) (*ipnstate.Sta func (sg *tailscaleStatusGetter) getStatus(ctx context.Context) (*ipnstate.Status, error) { sg.mu.Lock() defer sg.mu.Unlock() - if sg.lastStatus != nil && time.Since(sg.lastStatusTime) < 1*time.Second { + if sg.lastStatus != nil && time.Since(sg.lastStatusTime) < sg.cacheTimeout { return sg.lastStatus, nil } status, err := sg.fetchStatus(ctx) @@ -61,14 +69,23 @@ type authorization struct { } func newAuthorization(ts *tsnet.Server, tag string) *authorization { + return newAuthorizationWithCacheTimeout(ts, tag, defaultStatusCacheTimeout) +} + +func newAuthorizationWithCacheTimeout(ts *tsnet.Server, tag string, cacheTimeout time.Duration) *authorization { return &authorization{ sg: &tailscaleStatusGetter{ - ts: ts, + ts: ts, + cacheTimeout: cacheTimeout, }, tag: tag, } } +func newAuthorizationForTest(ts *tsnet.Server, tag string) *authorization { + return newAuthorizationWithCacheTimeout(ts, tag, 0) +} + func (a *authorization) Refresh(ctx context.Context) error { tStatus, err := a.sg.getStatus(ctx) if err != nil { diff --git a/tsconsensus/tsconsensus_test.go b/tsconsensus/tsconsensus_test.go index 2199a0c6b..8897db119 100644 --- a/tsconsensus/tsconsensus_test.go +++ b/tsconsensus/tsconsensus_test.go @@ -642,7 +642,7 @@ func TestOnlyTaggedPeersCanBeDialed(t *testing.T) { // make a StreamLayer for ps[0] ts := ps[0].ts - auth := newAuthorization(ts, clusterTag) + auth := newAuthorizationForTest(ts, clusterTag) port := 19841 lns := make([]net.Listener, 3) @@ -692,10 +692,12 @@ func TestOnlyTaggedPeersCanBeDialed(t *testing.T) { conn.Close() _, err = sl.Dial(a2, 2*time.Second) + if err == nil { + t.Fatal("expected dial error to untagged node, got none") + } if err.Error() != "dial: peer is not allowed" { t.Fatalf("expected dial: peer is not allowed, got: %v", err) } - } func TestOnlyTaggedPeersCanJoin(t *testing.T) {