util/ctxkey: add package for type-safe context keys (#10841)
The lack of type-safety in context.WithValue leads to the common pattern
of defining of package-scoped type to ensure global uniqueness:
type fooKey struct{}
func withFoo(ctx context, v Foo) context.Context {
return context.WithValue(ctx, fooKey{}, v)
}
func fooValue(ctx context) Foo {
v, _ := ctx.Value(fooKey{}).(Foo)
return v
}
where usage becomes:
ctx = withFoo(ctx, foo)
foo := fooValue(ctx)
With many different context keys, this can be quite tedious.
Using generics, we can simplify this as:
var fooKey = ctxkey.New("mypkg.fooKey", Foo{})
where usage becomes:
ctx = fooKey.WithValue(ctx, foo)
foo := fooKey.Value(ctx)
See https://go.dev/issue/49189
Updates #cleanup
Signed-off-by: Joe Tsai <joetsai@digital-static.net>
main
parent
c9fd166cc6
commit
241a541864
@ -0,0 +1,139 @@ |
||||
// Copyright (c) Tailscale Inc & AUTHORS
|
||||
// SPDX-License-Identifier: BSD-3-Clause
|
||||
|
||||
// ctxkey provides type-safe key-value pairs for use with [context.Context].
|
||||
//
|
||||
// Example usage:
|
||||
//
|
||||
// // Create a context key.
|
||||
// var TimeoutKey = ctxkey.New("fsrv.Timeout", 5*time.Second)
|
||||
//
|
||||
// // Store a context value.
|
||||
// ctx = fsrv.TimeoutKey.WithValue(ctx, 10*time.Second)
|
||||
//
|
||||
// // Load a context value.
|
||||
// timeout := fsrv.TimeoutKey.Value(ctx)
|
||||
// ... // use timeout of type time.Duration
|
||||
//
|
||||
// This is inspired by https://go.dev/issue/49189.
|
||||
package ctxkey |
||||
|
||||
import ( |
||||
"context" |
||||
"fmt" |
||||
"reflect" |
||||
) |
||||
|
||||
// Key is a generic key type associated with a specific value type.
|
||||
//
|
||||
// A zero Key is valid where the Value type itself is used as the context key.
|
||||
// This pattern should only be used with locally declared Go types.
|
||||
// The Value type must not be an interface type.
|
||||
//
|
||||
// Example usage:
|
||||
//
|
||||
// type peerInfo struct { ... } // peerInfo is an unexported type
|
||||
// var peerInfoKey = ctxkey.Key[peerInfo]
|
||||
// ctx = peerInfoKey.WithValue(ctx, info) // store a context value
|
||||
// info = peerInfoKey.Value(ctx) // load a context value
|
||||
//
|
||||
// In general, any exported keys should be produced using [New].
|
||||
type Key[Value any] struct { |
||||
name *stringer[string] |
||||
defVal *Value |
||||
} |
||||
|
||||
// New constructs a new context key with an associated value type
|
||||
// where the default value for an unpopulated value is the provided value.
|
||||
//
|
||||
// The provided name is an arbitrary name only used for human debugging.
|
||||
// As a convention, it is recommended that the name be the dot-delimited
|
||||
// combination of the package name of the caller with the variable name.
|
||||
// Every key is unique, even if provided the same name.
|
||||
//
|
||||
// Example usage:
|
||||
//
|
||||
// package mapreduce
|
||||
// var NumWorkersKey = ctxkey.New("mapreduce.NumWorkers", runtime.NumCPU())
|
||||
func New[Value any](name string, defaultValue Value) Key[Value] { |
||||
if name == "" { |
||||
var v Value |
||||
name = reflect.TypeOf(v).String() // TODO(https://go.dev/issue/60088): Use reflect.TypeFor.
|
||||
} |
||||
var defVal *Value |
||||
switch v := reflect.ValueOf(&defaultValue).Elem(); { |
||||
case v.Kind() == reflect.Interface: |
||||
panic(fmt.Sprintf("value type %v must not be an interface", v.Type())) |
||||
case !v.IsZero(): |
||||
defVal = &defaultValue |
||||
} |
||||
// Allocate a *stringer to ensure that every invocation of New
|
||||
// creates a universally unique context key even for the same name.
|
||||
return Key[Value]{name: &stringer[string]{name}, defVal: defVal} |
||||
} |
||||
|
||||
// contextKey returns the context key to use.
|
||||
func (key Key[Value]) contextKey() any { |
||||
if key.name == nil { |
||||
// Use the reflect.Type of the Value (implies key not created by New).
|
||||
var v Value |
||||
t := reflect.TypeOf(v) |
||||
if t == nil { |
||||
panic(fmt.Sprintf("value type %v must not be an interface", reflect.TypeOf(&v).Elem())) |
||||
} |
||||
return t |
||||
} else { |
||||
// Use the name pointer directly (implies key created by New).
|
||||
return key.name |
||||
} |
||||
} |
||||
|
||||
// WithValue returns a copy of parent in which the value associated with key is val.
|
||||
//
|
||||
// It is a type-safe equivalent of [context.WithValue].
|
||||
func (key Key[Value]) WithValue(parent context.Context, val Value) context.Context { |
||||
return context.WithValue(parent, key.contextKey(), stringer[Value]{val}) |
||||
} |
||||
|
||||
// ValueOk returns the value in the context associated with this key
|
||||
// and also reports whether it was present.
|
||||
// If the value is not present, it returns the default value.
|
||||
func (key Key[Value]) ValueOk(ctx context.Context) (v Value, ok bool) { |
||||
vv, ok := ctx.Value(key.contextKey()).(stringer[Value]) |
||||
if !ok && key.defVal != nil { |
||||
vv.v = *key.defVal |
||||
} |
||||
return vv.v, ok |
||||
} |
||||
|
||||
// Value returns the value in the context associated with this key.
|
||||
// If the value is not present, it returns the default value.
|
||||
func (key Key[Value]) Value(ctx context.Context) (v Value) { |
||||
v, _ = key.ValueOk(ctx) |
||||
return v |
||||
} |
||||
|
||||
// Has reports whether the context has a value for this key.
|
||||
func (key Key[Value]) Has(ctx context.Context) (ok bool) { |
||||
_, ok = key.ValueOk(ctx) |
||||
return ok |
||||
} |
||||
|
||||
// String returns the name of the key.
|
||||
func (key Key[Value]) String() string { |
||||
if key.name == nil { |
||||
var v Value |
||||
return reflect.TypeOf(v).String() // TODO(https://go.dev/issue/60088): Use reflect.TypeFor.
|
||||
} |
||||
return key.name.String() |
||||
} |
||||
|
||||
// stringer implements [fmt.Stringer] on a generic T.
|
||||
//
|
||||
// This assists in debugging such that printing a context prints key and value.
|
||||
// Note that the [context] package lacks a dependency on [reflect],
|
||||
// so it cannot print arbitrary values. By implementing [fmt.Stringer],
|
||||
// we functionally teach a context how to print itself.
|
||||
type stringer[T any] struct{ v T } |
||||
|
||||
func (v stringer[T]) String() string { return fmt.Sprint(v.v) } |
||||
@ -0,0 +1,82 @@ |
||||
// Copyright (c) Tailscale Inc & AUTHORS
|
||||
// SPDX-License-Identifier: BSD-3-Clause
|
||||
|
||||
package ctxkey |
||||
|
||||
import ( |
||||
"context" |
||||
"fmt" |
||||
"regexp" |
||||
"testing" |
||||
"time" |
||||
|
||||
qt "github.com/frankban/quicktest" |
||||
) |
||||
|
||||
func TestKey(t *testing.T) { |
||||
c := qt.New(t) |
||||
ctx := context.Background() |
||||
|
||||
// Test keys with the same name as being distinct.
|
||||
k1 := New("same.Name", "") |
||||
c.Assert(k1.String(), qt.Equals, "same.Name") |
||||
k2 := New("same.Name", "") |
||||
c.Assert(k2.String(), qt.Equals, "same.Name") |
||||
c.Assert(k1 == k2, qt.Equals, false) |
||||
ctx = k1.WithValue(ctx, "hello") |
||||
c.Assert(k1.Has(ctx), qt.Equals, true) |
||||
c.Assert(k1.Value(ctx), qt.Equals, "hello") |
||||
c.Assert(k2.Has(ctx), qt.Equals, false) |
||||
c.Assert(k2.Value(ctx), qt.Equals, "") |
||||
ctx = k2.WithValue(ctx, "goodbye") |
||||
c.Assert(k1.Has(ctx), qt.Equals, true) |
||||
c.Assert(k1.Value(ctx), qt.Equals, "hello") |
||||
c.Assert(k2.Has(ctx), qt.Equals, true) |
||||
c.Assert(k2.Value(ctx), qt.Equals, "goodbye") |
||||
|
||||
// Test default value.
|
||||
k3 := New("mapreduce.Timeout", time.Hour) |
||||
c.Assert(k3.Has(ctx), qt.Equals, false) |
||||
c.Assert(k3.Value(ctx), qt.Equals, time.Hour) |
||||
ctx = k3.WithValue(ctx, time.Minute) |
||||
c.Assert(k3.Has(ctx), qt.Equals, true) |
||||
c.Assert(k3.Value(ctx), qt.Equals, time.Minute) |
||||
|
||||
// Test incomparable value.
|
||||
k4 := New("slice", []int(nil)) |
||||
c.Assert(k4.Has(ctx), qt.Equals, false) |
||||
c.Assert(k4.Value(ctx), qt.DeepEquals, []int(nil)) |
||||
ctx = k4.WithValue(ctx, []int{1, 2, 3}) |
||||
c.Assert(k4.Has(ctx), qt.Equals, true) |
||||
c.Assert(k4.Value(ctx), qt.DeepEquals, []int{1, 2, 3}) |
||||
|
||||
// Accessors should be allocation free.
|
||||
c.Assert(testing.AllocsPerRun(100, func() { |
||||
k1.Value(ctx) |
||||
k1.Has(ctx) |
||||
k1.ValueOk(ctx) |
||||
}), qt.Equals, 0.0) |
||||
|
||||
// Test keys that are created without New.
|
||||
var k5 Key[string] |
||||
c.Assert(k5.String(), qt.Equals, "string") |
||||
c.Assert(k1 == k5, qt.Equals, false) // should be different from key created by New
|
||||
c.Assert(k5.Has(ctx), qt.Equals, false) |
||||
ctx = k5.WithValue(ctx, "fizz") |
||||
c.Assert(k5.Value(ctx), qt.Equals, "fizz") |
||||
var k6 Key[string] |
||||
c.Assert(k6.String(), qt.Equals, "string") |
||||
c.Assert(k5 == k6, qt.Equals, true) |
||||
c.Assert(k6.Has(ctx), qt.Equals, true) |
||||
ctx = k6.WithValue(ctx, "fizz") |
||||
} |
||||
|
||||
func TestStringer(t *testing.T) { |
||||
t.SkipNow() // TODO(https://go.dev/cl/555697): Enable this after fix is merged upstream.
|
||||
c := qt.New(t) |
||||
ctx := context.Background() |
||||
c.Assert(fmt.Sprint(New("foo.Bar", "").WithValue(ctx, "baz")), qt.Matches, regexp.MustCompile("foo.Bar.*baz")) |
||||
c.Assert(fmt.Sprint(New("", []int{}).WithValue(ctx, []int{1, 2, 3})), qt.Matches, regexp.MustCompile(fmt.Sprintf("%[1]T.*%[1]v", []int{1, 2, 3}))) |
||||
c.Assert(fmt.Sprint(New("", 0).WithValue(ctx, 5)), qt.Matches, regexp.MustCompile("int.*5")) |
||||
c.Assert(fmt.Sprint(Key[time.Duration]{}.WithValue(ctx, time.Hour)), qt.Matches, regexp.MustCompile(fmt.Sprintf("%[1]T.*%[1]v", time.Hour))) |
||||
} |
||||
Loading…
Reference in new issue