Co-authored-by: Maisem Ali <maisem@tailscale.com> Co-authored-by: Brad Fitzpatrick <bradfitz@tailscale.com> Signed-off-by: David Anderson <danderson@tailscale.com>main
parent
5bca44d572
commit
9e6b4d7ad8
@ -0,0 +1,88 @@ |
||||
// Copyright (c) Tailscale Inc & AUTHORS
|
||||
// SPDX-License-Identifier: BSD-3-Clause
|
||||
|
||||
// Package lazy provides types for lazily initialized values.
|
||||
package lazy |
||||
|
||||
import "sync" |
||||
|
||||
// SyncValue is a lazily computed value.
|
||||
//
|
||||
// Use either Get or GetErr, depending on whether your fill function returns an
|
||||
// error.
|
||||
//
|
||||
// Recursive use of a SyncValue from its own fill function will deadlock.
|
||||
//
|
||||
// SyncValue is safe for concurrent use.
|
||||
type SyncValue[T any] struct { |
||||
once sync.Once |
||||
v T |
||||
err error |
||||
} |
||||
|
||||
// Set attempts to set z's value to val, and reports whether it succeeded.
|
||||
// Set only succeeds if none of Get/GetErr/Set have been called before.
|
||||
func (z *SyncValue[T]) Set(val T) bool { |
||||
var wasSet bool |
||||
z.once.Do(func() { |
||||
z.v = val |
||||
wasSet = true |
||||
}) |
||||
return wasSet |
||||
} |
||||
|
||||
// MustSet sets z's value to val, or panics if z already has a value.
|
||||
func (z *SyncValue[T]) MustSet(val T) { |
||||
if !z.Set(val) { |
||||
panic("Set after already filled") |
||||
} |
||||
} |
||||
|
||||
// Get returns z's value, calling fill to compute it if necessary.
|
||||
// f is called at most once.
|
||||
func (z *SyncValue[T]) Get(fill func() T) T { |
||||
z.once.Do(func() { z.v = fill() }) |
||||
return z.v |
||||
} |
||||
|
||||
// GetErr returns z's value, calling fill to compute it if necessary.
|
||||
// f is called at most once, and z remembers both of fill's outputs.
|
||||
func (z *SyncValue[T]) GetErr(fill func() (T, error)) (T, error) { |
||||
z.once.Do(func() { z.v, z.err = fill() }) |
||||
return z.v, z.err |
||||
} |
||||
|
||||
// SyncFunc wraps a function to make it lazy.
|
||||
//
|
||||
// The returned function calls fill the first time it's called, and returns
|
||||
// fill's result on every subsequent call.
|
||||
//
|
||||
// The returned function is safe for concurrent use.
|
||||
func SyncFunc[T any](fill func() T) func() T { |
||||
var ( |
||||
once sync.Once |
||||
v T |
||||
) |
||||
return func() T { |
||||
once.Do(func() { v = fill() }) |
||||
return v |
||||
} |
||||
} |
||||
|
||||
// SyncFuncErr wraps a function to make it lazy.
|
||||
//
|
||||
// The returned function calls fill the first time it's called, and returns
|
||||
// fill's results on every subsequent call.
|
||||
//
|
||||
// The returned function is safe for concurrent use.
|
||||
func SyncFuncErr[T any](fill func() (T, error)) func() (T, error) { |
||||
var ( |
||||
once sync.Once |
||||
v T |
||||
err error |
||||
) |
||||
return func() (T, error) { |
||||
once.Do(func() { v, err = fill() }) |
||||
return v, err |
||||
} |
||||
} |
||||
@ -0,0 +1,150 @@ |
||||
// Copyright (c) Tailscale Inc & AUTHORS
|
||||
// SPDX-License-Identifier: BSD-3-Clause
|
||||
|
||||
package lazy |
||||
|
||||
import ( |
||||
"errors" |
||||
"sync" |
||||
"testing" |
||||
) |
||||
|
||||
func TestSyncValue(t *testing.T) { |
||||
var lt SyncValue[int] |
||||
n := int(testing.AllocsPerRun(1000, func() { |
||||
got := lt.Get(fortyTwo) |
||||
if got != 42 { |
||||
t.Fatalf("got %v; want 42", got) |
||||
} |
||||
})) |
||||
if n != 0 { |
||||
t.Errorf("allocs = %v; want 0", n) |
||||
} |
||||
} |
||||
|
||||
func TestSyncValueErr(t *testing.T) { |
||||
var lt SyncValue[int] |
||||
n := int(testing.AllocsPerRun(1000, func() { |
||||
got, err := lt.GetErr(func() (int, error) { |
||||
return 42, nil |
||||
}) |
||||
if got != 42 || err != nil { |
||||
t.Fatalf("got %v, %v; want 42, nil", got, err) |
||||
} |
||||
})) |
||||
if n != 0 { |
||||
t.Errorf("allocs = %v; want 0", n) |
||||
} |
||||
|
||||
var lterr SyncValue[int] |
||||
wantErr := errors.New("test error") |
||||
n = int(testing.AllocsPerRun(1000, func() { |
||||
got, err := lterr.GetErr(func() (int, error) { |
||||
return 0, wantErr |
||||
}) |
||||
if got != 0 || err != wantErr { |
||||
t.Fatalf("got %v, %v; want 0, %v", got, err, wantErr) |
||||
} |
||||
})) |
||||
if n != 0 { |
||||
t.Errorf("allocs = %v; want 0", n) |
||||
} |
||||
} |
||||
|
||||
func TestSyncValueSet(t *testing.T) { |
||||
var lt SyncValue[int] |
||||
if !lt.Set(42) { |
||||
t.Fatalf("Set failed") |
||||
} |
||||
if lt.Set(43) { |
||||
t.Fatalf("Set succeeded after first Set") |
||||
} |
||||
n := int(testing.AllocsPerRun(1000, func() { |
||||
got := lt.Get(fortyTwo) |
||||
if got != 42 { |
||||
t.Fatalf("got %v; want 42", got) |
||||
} |
||||
})) |
||||
if n != 0 { |
||||
t.Errorf("allocs = %v; want 0", n) |
||||
} |
||||
} |
||||
|
||||
func TestSyncValueMustSet(t *testing.T) { |
||||
var lt SyncValue[int] |
||||
lt.MustSet(42) |
||||
defer func() { |
||||
if e := recover(); e == nil { |
||||
t.Errorf("unexpected success; want panic") |
||||
} |
||||
}() |
||||
lt.MustSet(43) |
||||
} |
||||
|
||||
func TestSyncValueConcurrent(t *testing.T) { |
||||
var ( |
||||
lt SyncValue[int] |
||||
wg sync.WaitGroup |
||||
start = make(chan struct{}) |
||||
routines = 10000 |
||||
) |
||||
wg.Add(routines) |
||||
for i := 0; i < routines; i++ { |
||||
go func() { |
||||
defer wg.Done() |
||||
// Every goroutine waits for the go signal, so that more of them
|
||||
// have a chance to race on the initial Get than with sequential
|
||||
// goroutine starts.
|
||||
<-start |
||||
got := lt.Get(fortyTwo) |
||||
if got != 42 { |
||||
t.Errorf("got %v; want 42", got) |
||||
} |
||||
}() |
||||
} |
||||
close(start) |
||||
wg.Wait() |
||||
} |
||||
|
||||
func TestSyncFunc(t *testing.T) { |
||||
f := SyncFunc(fortyTwo) |
||||
|
||||
n := int(testing.AllocsPerRun(1000, func() { |
||||
got := f() |
||||
if got != 42 { |
||||
t.Fatalf("got %v; want 42", got) |
||||
} |
||||
})) |
||||
if n != 0 { |
||||
t.Errorf("allocs = %v; want 0", n) |
||||
} |
||||
} |
||||
|
||||
func TestSyncFuncErr(t *testing.T) { |
||||
f := SyncFuncErr(func() (int, error) { |
||||
return 42, nil |
||||
}) |
||||
n := int(testing.AllocsPerRun(1000, func() { |
||||
got, err := f() |
||||
if got != 42 || err != nil { |
||||
t.Fatalf("got %v, %v; want 42, nil", got, err) |
||||
} |
||||
})) |
||||
if n != 0 { |
||||
t.Errorf("allocs = %v; want 0", n) |
||||
} |
||||
|
||||
wantErr := errors.New("test error") |
||||
f = SyncFuncErr(func() (int, error) { |
||||
return 0, wantErr |
||||
}) |
||||
n = int(testing.AllocsPerRun(1000, func() { |
||||
got, err := f() |
||||
if got != 0 || err != wantErr { |
||||
t.Fatalf("got %v, %v; want 0, %v", got, err, wantErr) |
||||
} |
||||
})) |
||||
if n != 0 { |
||||
t.Errorf("allocs = %v; want 0", n) |
||||
} |
||||
} |
||||
@ -0,0 +1,99 @@ |
||||
// Copyright (c) Tailscale Inc & AUTHORS
|
||||
// SPDX-License-Identifier: BSD-3-Clause
|
||||
|
||||
package lazy |
||||
|
||||
// GValue is a lazily computed value.
|
||||
//
|
||||
// Use either Get or GetErr, depending on whether your fill function returns an
|
||||
// error.
|
||||
//
|
||||
// Recursive use of a GValue from its own fill function will panic.
|
||||
//
|
||||
// GValue is not safe for concurrent use. (Mnemonic: G is for one Goroutine,
|
||||
// which isn't strictly true if you provide your own synchronization between
|
||||
// goroutines, but in practice most of our callers have been using it within
|
||||
// a single goroutine.)
|
||||
type GValue[T any] struct { |
||||
done bool |
||||
calling bool |
||||
V T |
||||
err error |
||||
} |
||||
|
||||
// Set attempts to set z's value to val, and reports whether it succeeded.
|
||||
// Set only succeeds if none of Get/GetErr/Set have been called before.
|
||||
func (z *GValue[T]) Set(v T) bool { |
||||
if z.done { |
||||
return false |
||||
} |
||||
if z.calling { |
||||
panic("Set while Get fill is running") |
||||
} |
||||
z.V = v |
||||
z.done = true |
||||
return true |
||||
} |
||||
|
||||
// MustSet sets z's value to val, or panics if z already has a value.
|
||||
func (z *GValue[T]) MustSet(val T) { |
||||
if !z.Set(val) { |
||||
panic("Set after already filled") |
||||
} |
||||
} |
||||
|
||||
// Get returns z's value, calling fill to compute it if necessary.
|
||||
// f is called at most once.
|
||||
func (z *GValue[T]) Get(fill func() T) T { |
||||
if !z.done { |
||||
if z.calling { |
||||
panic("recursive lazy fill") |
||||
} |
||||
z.calling = true |
||||
z.V = fill() |
||||
z.done = true |
||||
z.calling = false |
||||
} |
||||
return z.V |
||||
} |
||||
|
||||
// GetErr returns z's value, calling fill to compute it if necessary.
|
||||
// f is called at most once, and z remembers both of fill's outputs.
|
||||
func (z *GValue[T]) GetErr(fill func() (T, error)) (T, error) { |
||||
if !z.done { |
||||
if z.calling { |
||||
panic("recursive lazy fill") |
||||
} |
||||
z.calling = true |
||||
z.V, z.err = fill() |
||||
z.done = true |
||||
z.calling = false |
||||
} |
||||
return z.V, z.err |
||||
} |
||||
|
||||
// GFunc wraps a function to make it lazy.
|
||||
//
|
||||
// The returned function calls fill the first time it's called, and returns
|
||||
// fill's result on every subsequent call.
|
||||
//
|
||||
// The returned function is not safe for concurrent use.
|
||||
func GFunc[T any](fill func() T) func() T { |
||||
var v GValue[T] |
||||
return func() T { |
||||
return v.Get(fill) |
||||
} |
||||
} |
||||
|
||||
// SyncFuncErr wraps a function to make it lazy.
|
||||
//
|
||||
// The returned function calls fill the first time it's called, and returns
|
||||
// fill's results on every subsequent call.
|
||||
//
|
||||
// The returned function is not safe for concurrent use.
|
||||
func GFuncErr[T any](fill func() (T, error)) func() (T, error) { |
||||
var v GValue[T] |
||||
return func() (T, error) { |
||||
return v.GetErr(fill) |
||||
} |
||||
} |
||||
@ -0,0 +1,140 @@ |
||||
// Copyright (c) Tailscale Inc & AUTHORS
|
||||
// SPDX-License-Identifier: BSD-3-Clause
|
||||
|
||||
package lazy |
||||
|
||||
import ( |
||||
"errors" |
||||
"testing" |
||||
) |
||||
|
||||
func fortyTwo() int { return 42 } |
||||
|
||||
func TestGValue(t *testing.T) { |
||||
var lt GValue[int] |
||||
n := int(testing.AllocsPerRun(1000, func() { |
||||
got := lt.Get(fortyTwo) |
||||
if got != 42 { |
||||
t.Fatalf("got %v; want 42", got) |
||||
} |
||||
})) |
||||
if n != 0 { |
||||
t.Errorf("allocs = %v; want 0", n) |
||||
} |
||||
} |
||||
|
||||
func TestGValueErr(t *testing.T) { |
||||
var lt GValue[int] |
||||
n := int(testing.AllocsPerRun(1000, func() { |
||||
got, err := lt.GetErr(func() (int, error) { |
||||
return 42, nil |
||||
}) |
||||
if got != 42 || err != nil { |
||||
t.Fatalf("got %v, %v; want 42, nil", got, err) |
||||
} |
||||
})) |
||||
if n != 0 { |
||||
t.Errorf("allocs = %v; want 0", n) |
||||
} |
||||
|
||||
var lterr GValue[int] |
||||
wantErr := errors.New("test error") |
||||
n = int(testing.AllocsPerRun(1000, func() { |
||||
got, err := lterr.GetErr(func() (int, error) { |
||||
return 0, wantErr |
||||
}) |
||||
if got != 0 || err != wantErr { |
||||
t.Fatalf("got %v, %v; want 0, %v", got, err, wantErr) |
||||
} |
||||
})) |
||||
if n != 0 { |
||||
t.Errorf("allocs = %v; want 0", n) |
||||
} |
||||
} |
||||
|
||||
func TestGValueSet(t *testing.T) { |
||||
var lt GValue[int] |
||||
if !lt.Set(42) { |
||||
t.Fatalf("Set failed") |
||||
} |
||||
if lt.Set(43) { |
||||
t.Fatalf("Set succeeded after first Set") |
||||
} |
||||
n := int(testing.AllocsPerRun(1000, func() { |
||||
got := lt.Get(fortyTwo) |
||||
if got != 42 { |
||||
t.Fatalf("got %v; want 42", got) |
||||
} |
||||
})) |
||||
if n != 0 { |
||||
t.Errorf("allocs = %v; want 0", n) |
||||
} |
||||
} |
||||
|
||||
func TestGValueMustSet(t *testing.T) { |
||||
var lt GValue[int] |
||||
lt.MustSet(42) |
||||
defer func() { |
||||
if e := recover(); e == nil { |
||||
t.Errorf("unexpected success; want panic") |
||||
} |
||||
}() |
||||
lt.MustSet(43) |
||||
} |
||||
|
||||
func TestGValueRecursivePanic(t *testing.T) { |
||||
defer func() { |
||||
if e := recover(); e != nil { |
||||
t.Logf("got panic, as expected") |
||||
} else { |
||||
t.Errorf("unexpected success; want panic") |
||||
} |
||||
}() |
||||
v := GValue[int]{} |
||||
v.Get(func() int { |
||||
return v.Get(func() int { return 42 }) |
||||
}) |
||||
} |
||||
|
||||
func TestGFunc(t *testing.T) { |
||||
f := GFunc(fortyTwo) |
||||
|
||||
n := int(testing.AllocsPerRun(1000, func() { |
||||
got := f() |
||||
if got != 42 { |
||||
t.Fatalf("got %v; want 42", got) |
||||
} |
||||
})) |
||||
if n != 0 { |
||||
t.Errorf("allocs = %v; want 0", n) |
||||
} |
||||
} |
||||
|
||||
func TestGFuncErr(t *testing.T) { |
||||
f := GFuncErr(func() (int, error) { |
||||
return 42, nil |
||||
}) |
||||
n := int(testing.AllocsPerRun(1000, func() { |
||||
got, err := f() |
||||
if got != 42 || err != nil { |
||||
t.Fatalf("got %v, %v; want 42, nil", got, err) |
||||
} |
||||
})) |
||||
if n != 0 { |
||||
t.Errorf("allocs = %v; want 0", n) |
||||
} |
||||
|
||||
wantErr := errors.New("test error") |
||||
f = GFuncErr(func() (int, error) { |
||||
return 0, wantErr |
||||
}) |
||||
n = int(testing.AllocsPerRun(1000, func() { |
||||
got, err := f() |
||||
if got != 0 || err != wantErr { |
||||
t.Fatalf("got %v, %v; want 0, %v", got, err, wantErr) |
||||
} |
||||
})) |
||||
if n != 0 { |
||||
t.Errorf("allocs = %v; want 0", n) |
||||
} |
||||
} |
||||
Loading…
Reference in new issue