all: make use of ctxkey everywhere (#10846)

Also perform minor cleanups on the ctxkey package itself.
Provide guidance on when to use ctxkey.Key[T] over ctxkey.New.
Also, allow for interface kinds because the value wrapping trick
also happens to fix edge cases with interfaces in Go.

Updates #cleanup

Signed-off-by: Joe Tsai <joetsai@digital-static.net>
This commit is contained in:
Joe Tsai
2024-01-16 13:56:23 -08:00
committed by GitHub
parent 7732377cd7
commit c25968e1c5
13 changed files with 97 additions and 85 deletions
+8 -10
View File
@@ -8,6 +8,7 @@ import (
"net/http"
"github.com/google/uuid"
"tailscale.com/util/ctxkey"
)
// RequestID is an opaque identifier for a HTTP request, used to correlate
@@ -24,6 +25,9 @@ import (
// opaque string. The current implementation uses a UUID.
type RequestID string
// RequestIDKey stores and loads [RequestID] values within a [context.Context].
var RequestIDKey ctxkey.Key[RequestID]
// RequestIDHeader is a custom HTTP header that the WithRequestID middleware
// uses to determine whether to re-use a given request ID from the client
// or generate a new one.
@@ -42,22 +46,16 @@ func SetRequestID(h http.Handler) http.Handler {
// transitions if needed.
id = "REQ-1" + uuid.NewString()
}
ctx := withRequestID(r.Context(), RequestID(id))
ctx := RequestIDKey.WithValue(r.Context(), RequestID(id))
r = r.WithContext(ctx)
h.ServeHTTP(w, r)
})
}
type requestIDKey struct{}
// RequestIDFromContext retrieves the RequestID from context that can be set by
// the SetRequestID function.
//
// Deprecated: Use [RequestIDKey.Value] instead.
func RequestIDFromContext(ctx context.Context) RequestID {
val, _ := ctx.Value(requestIDKey{}).(RequestID)
return val
}
// withRequestID sets the given request id value in the given context.
func withRequestID(ctx context.Context, rid RequestID) context.Context {
return context.WithValue(ctx, requestIDKey{}, rid)
return RequestIDKey.Value(ctx)
}
+7 -7
View File
@@ -166,7 +166,7 @@ func TestStdHandler(t *testing.T) {
{
name: "handler returns 404 via HTTPError with request ID",
rh: handlerErr(0, Error(404, "not found", testErr)),
r: req(withRequestID(bgCtx, exampleRequestID), "http://example.com/foo"),
r: req(RequestIDKey.WithValue(bgCtx, exampleRequestID), "http://example.com/foo"),
wantCode: 404,
wantLog: AccessLogRecord{
When: startTime,
@@ -203,7 +203,7 @@ func TestStdHandler(t *testing.T) {
{
name: "handler returns 404 with request ID and nil child error",
rh: handlerErr(0, Error(404, "not found", nil)),
r: req(withRequestID(bgCtx, exampleRequestID), "http://example.com/foo"),
r: req(RequestIDKey.WithValue(bgCtx, exampleRequestID), "http://example.com/foo"),
wantCode: 404,
wantLog: AccessLogRecord{
When: startTime,
@@ -240,7 +240,7 @@ func TestStdHandler(t *testing.T) {
{
name: "handler returns user-visible error with request ID",
rh: handlerErr(0, vizerror.New("visible error")),
r: req(withRequestID(bgCtx, exampleRequestID), "http://example.com/foo"),
r: req(RequestIDKey.WithValue(bgCtx, exampleRequestID), "http://example.com/foo"),
wantCode: 500,
wantLog: AccessLogRecord{
When: startTime,
@@ -277,7 +277,7 @@ func TestStdHandler(t *testing.T) {
{
name: "handler returns user-visible error wrapped by private error with request ID",
rh: handlerErr(0, fmt.Errorf("private internal error: %w", vizerror.New("visible error"))),
r: req(withRequestID(bgCtx, exampleRequestID), "http://example.com/foo"),
r: req(RequestIDKey.WithValue(bgCtx, exampleRequestID), "http://example.com/foo"),
wantCode: 500,
wantLog: AccessLogRecord{
When: startTime,
@@ -314,7 +314,7 @@ func TestStdHandler(t *testing.T) {
{
name: "handler returns generic error with request ID",
rh: handlerErr(0, testErr),
r: req(withRequestID(bgCtx, exampleRequestID), "http://example.com/foo"),
r: req(RequestIDKey.WithValue(bgCtx, exampleRequestID), "http://example.com/foo"),
wantCode: 500,
wantLog: AccessLogRecord{
When: startTime,
@@ -350,7 +350,7 @@ func TestStdHandler(t *testing.T) {
{
name: "handler returns error after writing response with request ID",
rh: handlerErr(200, testErr),
r: req(withRequestID(bgCtx, exampleRequestID), "http://example.com/foo"),
r: req(RequestIDKey.WithValue(bgCtx, exampleRequestID), "http://example.com/foo"),
wantCode: 200,
wantLog: AccessLogRecord{
When: startTime,
@@ -446,7 +446,7 @@ func TestStdHandler(t *testing.T) {
{
name: "error handler gets run with request ID",
rh: handlerErr(0, Error(404, "not found", nil)), // status code changed in errHandler
r: req(withRequestID(bgCtx, exampleRequestID), "http://example.com/"),
r: req(RequestIDKey.WithValue(bgCtx, exampleRequestID), "http://example.com/"),
wantCode: 200,
errHandler: func(w http.ResponseWriter, r *http.Request, e HTTPError) {
requestID := RequestIDFromContext(r.Context())