The c2n handling code was using the Go httptest package's ResponseRecorder code but that's in a test package which brings in Go's test certs, etc. This forks the httptest recorder type into its own package that only has the recorder and adds a test that we don't re-introduce a dependency on httptest. Updates #12614 Change-Id: I3546f49972981e21813ece9064cc2be0b74f4b16 Signed-off-by: Brad Fitzpatrick <bradfitz@tailscale.com>main
parent
8c925899e1
commit
bce05ec6c3
@ -0,0 +1,258 @@ |
||||
// Copyright 2011 The Go Authors. All rights reserved.
|
||||
// Use of this source code is governed by a BSD-style
|
||||
// license that can be found in the LICENSE file.
|
||||
|
||||
// Package httprec is a copy of the Go standard library's httptest.ResponseRecorder
|
||||
// type, which we want to use in non-test code without pulling in the rest of
|
||||
// the httptest package and its test certs, etc.
|
||||
package httprec |
||||
|
||||
import ( |
||||
"bytes" |
||||
"fmt" |
||||
"io" |
||||
"net/http" |
||||
"net/textproto" |
||||
"strconv" |
||||
"strings" |
||||
|
||||
"golang.org/x/net/http/httpguts" |
||||
) |
||||
|
||||
// ResponseRecorder is an implementation of [http.ResponseWriter] that
|
||||
// records its mutations for later inspection in tests.
|
||||
type ResponseRecorder struct { |
||||
// Code is the HTTP response code set by WriteHeader.
|
||||
//
|
||||
// Note that if a Handler never calls WriteHeader or Write,
|
||||
// this might end up being 0, rather than the implicit
|
||||
// http.StatusOK. To get the implicit value, use the Result
|
||||
// method.
|
||||
Code int |
||||
|
||||
// HeaderMap contains the headers explicitly set by the Handler.
|
||||
// It is an internal detail.
|
||||
//
|
||||
// Deprecated: HeaderMap exists for historical compatibility
|
||||
// and should not be used. To access the headers returned by a handler,
|
||||
// use the Response.Header map as returned by the Result method.
|
||||
HeaderMap http.Header |
||||
|
||||
// Body is the buffer to which the Handler's Write calls are sent.
|
||||
// If nil, the Writes are silently discarded.
|
||||
Body *bytes.Buffer |
||||
|
||||
// Flushed is whether the Handler called Flush.
|
||||
Flushed bool |
||||
|
||||
result *http.Response // cache of Result's return value
|
||||
snapHeader http.Header // snapshot of HeaderMap at first Write
|
||||
wroteHeader bool |
||||
} |
||||
|
||||
// NewRecorder returns an initialized [ResponseRecorder].
|
||||
func NewRecorder() *ResponseRecorder { |
||||
return &ResponseRecorder{ |
||||
HeaderMap: make(http.Header), |
||||
Body: new(bytes.Buffer), |
||||
Code: 200, |
||||
} |
||||
} |
||||
|
||||
// DefaultRemoteAddr is the default remote address to return in RemoteAddr if
|
||||
// an explicit DefaultRemoteAddr isn't set on [ResponseRecorder].
|
||||
const DefaultRemoteAddr = "1.2.3.4" |
||||
|
||||
// Header implements [http.ResponseWriter]. It returns the response
|
||||
// headers to mutate within a handler. To test the headers that were
|
||||
// written after a handler completes, use the [ResponseRecorder.Result] method and see
|
||||
// the returned Response value's Header.
|
||||
func (rw *ResponseRecorder) Header() http.Header { |
||||
m := rw.HeaderMap |
||||
if m == nil { |
||||
m = make(http.Header) |
||||
rw.HeaderMap = m |
||||
} |
||||
return m |
||||
} |
||||
|
||||
// writeHeader writes a header if it was not written yet and
|
||||
// detects Content-Type if needed.
|
||||
//
|
||||
// bytes or str are the beginning of the response body.
|
||||
// We pass both to avoid unnecessarily generate garbage
|
||||
// in rw.WriteString which was created for performance reasons.
|
||||
// Non-nil bytes win.
|
||||
func (rw *ResponseRecorder) writeHeader(b []byte, str string) { |
||||
if rw.wroteHeader { |
||||
return |
||||
} |
||||
if len(str) > 512 { |
||||
str = str[:512] |
||||
} |
||||
|
||||
m := rw.Header() |
||||
|
||||
_, hasType := m["Content-Type"] |
||||
hasTE := m.Get("Transfer-Encoding") != "" |
||||
if !hasType && !hasTE { |
||||
if b == nil { |
||||
b = []byte(str) |
||||
} |
||||
m.Set("Content-Type", http.DetectContentType(b)) |
||||
} |
||||
|
||||
rw.WriteHeader(200) |
||||
} |
||||
|
||||
// Write implements http.ResponseWriter. The data in buf is written to
|
||||
// rw.Body, if not nil.
|
||||
func (rw *ResponseRecorder) Write(buf []byte) (int, error) { |
||||
rw.writeHeader(buf, "") |
||||
if rw.Body != nil { |
||||
rw.Body.Write(buf) |
||||
} |
||||
return len(buf), nil |
||||
} |
||||
|
||||
// WriteString implements [io.StringWriter]. The data in str is written
|
||||
// to rw.Body, if not nil.
|
||||
func (rw *ResponseRecorder) WriteString(str string) (int, error) { |
||||
rw.writeHeader(nil, str) |
||||
if rw.Body != nil { |
||||
rw.Body.WriteString(str) |
||||
} |
||||
return len(str), nil |
||||
} |
||||
|
||||
func checkWriteHeaderCode(code int) { |
||||
// Issue 22880: require valid WriteHeader status codes.
|
||||
// For now we only enforce that it's three digits.
|
||||
// In the future we might block things over 599 (600 and above aren't defined
|
||||
// at https://httpwg.org/specs/rfc7231.html#status.codes)
|
||||
// and we might block under 200 (once we have more mature 1xx support).
|
||||
// But for now any three digits.
|
||||
//
|
||||
// We used to send "HTTP/1.1 000 0" on the wire in responses but there's
|
||||
// no equivalent bogus thing we can realistically send in HTTP/2,
|
||||
// so we'll consistently panic instead and help people find their bugs
|
||||
// early. (We can't return an error from WriteHeader even if we wanted to.)
|
||||
if code < 100 || code > 999 { |
||||
panic(fmt.Sprintf("invalid WriteHeader code %v", code)) |
||||
} |
||||
} |
||||
|
||||
// WriteHeader implements [http.ResponseWriter].
|
||||
func (rw *ResponseRecorder) WriteHeader(code int) { |
||||
if rw.wroteHeader { |
||||
return |
||||
} |
||||
|
||||
checkWriteHeaderCode(code) |
||||
rw.Code = code |
||||
rw.wroteHeader = true |
||||
if rw.HeaderMap == nil { |
||||
rw.HeaderMap = make(http.Header) |
||||
} |
||||
rw.snapHeader = rw.HeaderMap.Clone() |
||||
} |
||||
|
||||
// Flush implements [http.Flusher]. To test whether Flush was
|
||||
// called, see rw.Flushed.
|
||||
func (rw *ResponseRecorder) Flush() { |
||||
if !rw.wroteHeader { |
||||
rw.WriteHeader(200) |
||||
} |
||||
rw.Flushed = true |
||||
} |
||||
|
||||
// Result returns the response generated by the handler.
|
||||
//
|
||||
// The returned Response will have at least its StatusCode,
|
||||
// Header, Body, and optionally Trailer populated.
|
||||
// More fields may be populated in the future, so callers should
|
||||
// not DeepEqual the result in tests.
|
||||
//
|
||||
// The Response.Header is a snapshot of the headers at the time of the
|
||||
// first write call, or at the time of this call, if the handler never
|
||||
// did a write.
|
||||
//
|
||||
// The Response.Body is guaranteed to be non-nil and Body.Read call is
|
||||
// guaranteed to not return any error other than [io.EOF].
|
||||
//
|
||||
// Result must only be called after the handler has finished running.
|
||||
func (rw *ResponseRecorder) Result() *http.Response { |
||||
if rw.result != nil { |
||||
return rw.result |
||||
} |
||||
if rw.snapHeader == nil { |
||||
rw.snapHeader = rw.HeaderMap.Clone() |
||||
} |
||||
res := &http.Response{ |
||||
Proto: "HTTP/1.1", |
||||
ProtoMajor: 1, |
||||
ProtoMinor: 1, |
||||
StatusCode: rw.Code, |
||||
Header: rw.snapHeader, |
||||
} |
||||
rw.result = res |
||||
if res.StatusCode == 0 { |
||||
res.StatusCode = 200 |
||||
} |
||||
res.Status = fmt.Sprintf("%03d %s", res.StatusCode, http.StatusText(res.StatusCode)) |
||||
if rw.Body != nil { |
||||
res.Body = io.NopCloser(bytes.NewReader(rw.Body.Bytes())) |
||||
} else { |
||||
res.Body = http.NoBody |
||||
} |
||||
res.ContentLength = parseContentLength(res.Header.Get("Content-Length")) |
||||
|
||||
if trailers, ok := rw.snapHeader["Trailer"]; ok { |
||||
res.Trailer = make(http.Header, len(trailers)) |
||||
for _, k := range trailers { |
||||
for _, k := range strings.Split(k, ",") { |
||||
k = http.CanonicalHeaderKey(textproto.TrimString(k)) |
||||
if !httpguts.ValidTrailerHeader(k) { |
||||
// Ignore since forbidden by RFC 7230, section 4.1.2.
|
||||
continue |
||||
} |
||||
vv, ok := rw.HeaderMap[k] |
||||
if !ok { |
||||
continue |
||||
} |
||||
vv2 := make([]string, len(vv)) |
||||
copy(vv2, vv) |
||||
res.Trailer[k] = vv2 |
||||
} |
||||
} |
||||
} |
||||
for k, vv := range rw.HeaderMap { |
||||
if !strings.HasPrefix(k, http.TrailerPrefix) { |
||||
continue |
||||
} |
||||
if res.Trailer == nil { |
||||
res.Trailer = make(http.Header) |
||||
} |
||||
for _, v := range vv { |
||||
res.Trailer.Add(strings.TrimPrefix(k, http.TrailerPrefix), v) |
||||
} |
||||
} |
||||
return res |
||||
} |
||||
|
||||
// parseContentLength trims whitespace from s and returns -1 if no value
|
||||
// is set, or the value if it's >= 0.
|
||||
//
|
||||
// This a modified version of same function found in net/http/transfer.go. This
|
||||
// one just ignores an invalid header.
|
||||
func parseContentLength(cl string) int64 { |
||||
cl = textproto.TrimString(cl) |
||||
if cl == "" { |
||||
return -1 |
||||
} |
||||
n, err := strconv.ParseUint(cl, 10, 63) |
||||
if err != nil { |
||||
return -1 |
||||
} |
||||
return int64(n) |
||||
} |
||||
Loading…
Reference in new issue