expvarx.SafeFunc wraps an expvar.Func with a time limit. On reaching the time limit, calls to Value return nil, and no new concurrent calls to the underlying expvar.Func will be started until the call completes. Updates tailscale/corp#16999 Signed-off-by: James Tucker <james@tailscale.com>main
parent
fd94d96e2b
commit
0f3b2e7b86
@ -0,0 +1,89 @@ |
||||
// Copyright (c) Tailscale Inc & AUTHORS
|
||||
// SPDX-License-Identifier: BSD-3-Clause
|
||||
|
||||
// Package expvarx provides some extensions to the [expvar] package.
|
||||
package expvarx |
||||
|
||||
import ( |
||||
"encoding/json" |
||||
"expvar" |
||||
"sync" |
||||
"time" |
||||
|
||||
"tailscale.com/types/lazy" |
||||
) |
||||
|
||||
// SafeFunc is a wrapper around [expvar.Func] that guards against unbounded call
|
||||
// time and ensures that only a single call is in progress at any given time.
|
||||
type SafeFunc struct { |
||||
f expvar.Func |
||||
limit time.Duration |
||||
onSlow func(time.Duration, any) |
||||
|
||||
mu sync.Mutex |
||||
inflight *lazy.SyncValue[any] |
||||
} |
||||
|
||||
// NewSafeFunc returns a new SafeFunc that wraps f.
|
||||
// If f takes longer than limit to execute then Value calls return nil.
|
||||
// If onSlow is non-nil, it is called when f takes longer than limit to execute.
|
||||
// onSlow is called with the duration of the slow call and the final computed
|
||||
// value.
|
||||
func NewSafeFunc(f expvar.Func, limit time.Duration, onSlow func(time.Duration, any)) *SafeFunc { |
||||
return &SafeFunc{f: f, limit: limit, onSlow: onSlow} |
||||
} |
||||
|
||||
// Value acts similarly to [expvar.Func.Value], but if the underlying function
|
||||
// takes longer than the configured limit, all callers will receive nil until
|
||||
// the underlying operation completes. On completion of the underlying
|
||||
// operation, the onSlow callback is called if set.
|
||||
func (s *SafeFunc) Value() any { |
||||
s.mu.Lock() |
||||
|
||||
if s.inflight == nil { |
||||
s.inflight = new(lazy.SyncValue[any]) |
||||
} |
||||
var inflight = s.inflight |
||||
s.mu.Unlock() |
||||
|
||||
// inflight ensures that only a single work routine is spawned at any given
|
||||
// time, but if the routine takes too long inflight is populated with a nil
|
||||
// result. The long running computed value is lost forever.
|
||||
return inflight.Get(func() any { |
||||
start := time.Now() |
||||
result := make(chan any, 1) |
||||
|
||||
// work is spawned in routine so that the caller can timeout.
|
||||
go func() { |
||||
// Allow new work to be started after this work completes
|
||||
defer func() { |
||||
s.mu.Lock() |
||||
s.inflight = nil |
||||
s.mu.Unlock() |
||||
|
||||
}() |
||||
|
||||
v := s.f.Value() |
||||
result <- v |
||||
}() |
||||
|
||||
select { |
||||
case v := <-result: |
||||
return v |
||||
case <-time.After(s.limit): |
||||
if s.onSlow != nil { |
||||
go func() { |
||||
s.onSlow(time.Since(start), <-result) |
||||
}() |
||||
} |
||||
return nil |
||||
} |
||||
}) |
||||
} |
||||
|
||||
// String implements stringer in the same pattern as [expvar.Func], calling
|
||||
// Value and serializing the result as JSON, ignoring errors.
|
||||
func (s *SafeFunc) String() string { |
||||
v, _ := json.Marshal(s.Value()) |
||||
return string(v) |
||||
} |
||||
@ -0,0 +1,137 @@ |
||||
// Copyright (c) Tailscale Inc & AUTHORS
|
||||
// SPDX-License-Identifier: BSD-3-Clause
|
||||
|
||||
package expvarx |
||||
|
||||
import ( |
||||
"expvar" |
||||
"fmt" |
||||
"sync" |
||||
"sync/atomic" |
||||
"testing" |
||||
"time" |
||||
) |
||||
|
||||
func ExampleNewSafeFunc() { |
||||
// An artificial blocker to emulate a slow operation.
|
||||
blocker := make(chan struct{}) |
||||
|
||||
// limit is the amount of time a call can take before Value returns nil. No
|
||||
// new calls to the unsafe func will be started until the slow call
|
||||
// completes, at which point onSlow will be called.
|
||||
limit := time.Millisecond |
||||
|
||||
// onSlow is called with the final call duration and the final value in the
|
||||
// event a slow call.
|
||||
onSlow := func(d time.Duration, v any) { |
||||
_ = d // d contains the time the call took
|
||||
_ = v // v contains the final value computed by the slow call
|
||||
fmt.Println("slow call!") |
||||
} |
||||
|
||||
// An unsafe expvar.Func that blocks on the blocker channel.
|
||||
unsafeFunc := expvar.Func(func() any { |
||||
for range blocker { |
||||
} |
||||
return "hello world" |
||||
}) |
||||
|
||||
// f implements the same interface as expvar.Func, but returns nil values
|
||||
// when the unsafe func is too slow.
|
||||
f := NewSafeFunc(unsafeFunc, limit, onSlow) |
||||
|
||||
fmt.Println(f.Value()) |
||||
fmt.Println(f.Value()) |
||||
close(blocker) |
||||
time.Sleep(time.Millisecond) |
||||
fmt.Println(f.Value()) |
||||
// Output: <nil>
|
||||
// <nil>
|
||||
// slow call!
|
||||
// hello world
|
||||
} |
||||
|
||||
func TestSafeFuncHappyPath(t *testing.T) { |
||||
var count int |
||||
f := NewSafeFunc(expvar.Func(func() any { |
||||
count++ |
||||
return count |
||||
}), time.Millisecond, nil) |
||||
|
||||
if got, want := f.Value(), 1; got != want { |
||||
t.Errorf("got %v, want %v", got, want) |
||||
} |
||||
if got, want := f.Value(), 2; got != want { |
||||
t.Errorf("got %v, want %v", got, want) |
||||
} |
||||
} |
||||
|
||||
func TestSafeFuncSlow(t *testing.T) { |
||||
var count int |
||||
blocker := make(chan struct{}) |
||||
var wg sync.WaitGroup |
||||
wg.Add(1) |
||||
f := NewSafeFunc(expvar.Func(func() any { |
||||
defer wg.Done() |
||||
count++ |
||||
<-blocker |
||||
return count |
||||
}), time.Millisecond, nil) |
||||
|
||||
if got := f.Value(); got != nil { |
||||
t.Errorf("got %v; want nil", got) |
||||
} |
||||
if got := f.Value(); got != nil { |
||||
t.Errorf("got %v; want nil", got) |
||||
} |
||||
|
||||
close(blocker) |
||||
wg.Wait() |
||||
|
||||
if count != 1 { |
||||
t.Errorf("got count=%d; want 1", count) |
||||
} |
||||
} |
||||
|
||||
func TestSafeFuncSlowOnSlow(t *testing.T) { |
||||
var count int |
||||
blocker := make(chan struct{}) |
||||
var wg sync.WaitGroup |
||||
wg.Add(2) |
||||
var slowDuration atomic.Pointer[time.Duration] |
||||
var slowCallCount atomic.Int32 |
||||
var slowValue atomic.Value |
||||
f := NewSafeFunc(expvar.Func(func() any { |
||||
defer wg.Done() |
||||
count++ |
||||
<-blocker |
||||
return count |
||||
}), time.Millisecond, func(d time.Duration, v any) { |
||||
defer wg.Done() |
||||
slowDuration.Store(&d) |
||||
slowCallCount.Add(1) |
||||
slowValue.Store(v) |
||||
}) |
||||
|
||||
for i := 0; i < 10; i++ { |
||||
if got := f.Value(); got != nil { |
||||
t.Fatalf("got value=%v; want nil", got) |
||||
} |
||||
} |
||||
|
||||
close(blocker) |
||||
wg.Wait() |
||||
|
||||
if count != 1 { |
||||
t.Errorf("got count=%d; want 1", count) |
||||
} |
||||
if got, want := *slowDuration.Load(), 1*time.Millisecond; got < want { |
||||
t.Errorf("got slowDuration=%v; want at least %d", got, want) |
||||
} |
||||
if got, want := slowCallCount.Load(), int32(1); got != want { |
||||
t.Errorf("got slowCallCount=%d; want %d", got, want) |
||||
} |
||||
if got, want := slowValue.Load().(int), 1; got != want { |
||||
t.Errorf("got slowValue=%d, want %d", got, want) |
||||
} |
||||
} |
||||
Loading…
Reference in new issue