Also adds tests, because the logging handler is acquiring a fair number of branches. Signed-off-by: David Anderson <dave@natulte.net>main
parent
2e43cd3f95
commit
12a6626a94
@ -0,0 +1,254 @@ |
||||
// Copyright (c) 2020 Tailscale Inc & AUTHORS All rights reserved.
|
||||
// Use of this source code is governed by a BSD-style
|
||||
// license that can be found in the LICENSE file.
|
||||
|
||||
package tsweb |
||||
|
||||
import ( |
||||
"bufio" |
||||
"context" |
||||
"errors" |
||||
"net" |
||||
"net/http" |
||||
"net/http/httptest" |
||||
"testing" |
||||
"time" |
||||
|
||||
"github.com/google/go-cmp/cmp" |
||||
"tailscale.com/testy" |
||||
) |
||||
|
||||
type noopHijacker struct { |
||||
*httptest.ResponseRecorder |
||||
hijacked bool |
||||
} |
||||
|
||||
func (h *noopHijacker) Hijack() (net.Conn, *bufio.ReadWriter, error) { |
||||
// Hijack "successfully" but don't bother returning a conn.
|
||||
h.hijacked = true |
||||
return nil, nil, nil |
||||
} |
||||
|
||||
func TestStdHandler(t *testing.T) { |
||||
var ( |
||||
handlerCode = func(code int) Handler { |
||||
return HandlerFunc(func(w http.ResponseWriter, r *http.Request) error { |
||||
w.WriteHeader(code) |
||||
return nil |
||||
}) |
||||
} |
||||
handlerErr = func(code int, err error) Handler { |
||||
return HandlerFunc(func(w http.ResponseWriter, r *http.Request) error { |
||||
if code != 0 { |
||||
w.WriteHeader(code) |
||||
} |
||||
return err |
||||
}) |
||||
} |
||||
|
||||
req = func(ctx context.Context, url string) *http.Request { |
||||
ret, err := http.NewRequestWithContext(ctx, "GET", url, nil) |
||||
if err != nil { |
||||
panic(err) |
||||
} |
||||
return ret |
||||
} |
||||
|
||||
testErr = errors.New("test error") |
||||
bgCtx = context.Background() |
||||
// canceledCtx, cancel = context.WithCancel(bgCtx)
|
||||
clock = testy.Clock{ |
||||
Start: time.Now(), |
||||
Step: time.Second, |
||||
} |
||||
) |
||||
// cancel()
|
||||
|
||||
tests := []struct { |
||||
name string |
||||
h Handler |
||||
r *http.Request |
||||
wantCode int |
||||
wantLog AccessLogRecord |
||||
}{ |
||||
{ |
||||
name: "handler returns 200", |
||||
h: handlerCode(200), |
||||
r: req(bgCtx, "http://example.com/"), |
||||
wantCode: 200, |
||||
wantLog: AccessLogRecord{ |
||||
When: clock.Start, |
||||
Seconds: 1.0, |
||||
Proto: "HTTP/1.1", |
||||
TLS: false, |
||||
Host: "example.com", |
||||
Method: "GET", |
||||
Code: 200, |
||||
RequestURI: "/", |
||||
}, |
||||
}, |
||||
|
||||
{ |
||||
name: "handler returns 404", |
||||
h: handlerCode(404), |
||||
r: req(bgCtx, "http://example.com/foo"), |
||||
wantCode: 404, |
||||
wantLog: AccessLogRecord{ |
||||
When: clock.Start, |
||||
Seconds: 1.0, |
||||
Proto: "HTTP/1.1", |
||||
Host: "example.com", |
||||
Method: "GET", |
||||
RequestURI: "/foo", |
||||
Code: 404, |
||||
}, |
||||
}, |
||||
|
||||
{ |
||||
name: "handler returns 404 via HTTPError", |
||||
h: handlerErr(0, Error(404, "not found", testErr)), |
||||
r: req(bgCtx, "http://example.com/foo"), |
||||
wantCode: 404, |
||||
wantLog: AccessLogRecord{ |
||||
When: clock.Start, |
||||
Seconds: 1.0, |
||||
Proto: "HTTP/1.1", |
||||
Host: "example.com", |
||||
Method: "GET", |
||||
RequestURI: "/foo", |
||||
Err: testErr.Error(), |
||||
Code: 404, |
||||
}, |
||||
}, |
||||
|
||||
{ |
||||
name: "handler returns generic error", |
||||
h: handlerErr(0, testErr), |
||||
r: req(bgCtx, "http://example.com/foo"), |
||||
wantCode: 500, |
||||
wantLog: AccessLogRecord{ |
||||
When: clock.Start, |
||||
Seconds: 1.0, |
||||
Proto: "HTTP/1.1", |
||||
Host: "example.com", |
||||
Method: "GET", |
||||
RequestURI: "/foo", |
||||
Err: testErr.Error(), |
||||
Code: 500, |
||||
}, |
||||
}, |
||||
|
||||
{ |
||||
name: "handler returns error after writing response", |
||||
h: handlerErr(200, testErr), |
||||
r: req(bgCtx, "http://example.com/foo"), |
||||
wantCode: 200, |
||||
wantLog: AccessLogRecord{ |
||||
When: clock.Start, |
||||
Seconds: 1.0, |
||||
Proto: "HTTP/1.1", |
||||
Host: "example.com", |
||||
Method: "GET", |
||||
RequestURI: "/foo", |
||||
Err: testErr.Error(), |
||||
Code: 200, |
||||
}, |
||||
}, |
||||
|
||||
{ |
||||
name: "handler returns HTTPError after writing response", |
||||
h: handlerErr(200, Error(404, "not found", testErr)), |
||||
r: req(bgCtx, "http://example.com/foo"), |
||||
wantCode: 200, |
||||
wantLog: AccessLogRecord{ |
||||
When: clock.Start, |
||||
Seconds: 1.0, |
||||
Proto: "HTTP/1.1", |
||||
Host: "example.com", |
||||
Method: "GET", |
||||
RequestURI: "/foo", |
||||
Err: testErr.Error(), |
||||
Code: 200, |
||||
}, |
||||
}, |
||||
|
||||
{ |
||||
name: "handler does nothing", |
||||
h: HandlerFunc(func(http.ResponseWriter, *http.Request) error { return nil }), |
||||
r: req(bgCtx, "http://example.com/foo"), |
||||
wantCode: 500, |
||||
wantLog: AccessLogRecord{ |
||||
When: clock.Start, |
||||
Seconds: 1.0, |
||||
Proto: "HTTP/1.1", |
||||
Host: "example.com", |
||||
Method: "GET", |
||||
RequestURI: "/foo", |
||||
Code: 500, |
||||
Err: "[unexpected] handler did not respond to the client", |
||||
}, |
||||
}, |
||||
|
||||
{ |
||||
name: "handler hijacks conn", |
||||
h: HandlerFunc(func(w http.ResponseWriter, r *http.Request) error { |
||||
_, _, err := w.(http.Hijacker).Hijack() |
||||
if err != nil { |
||||
t.Errorf("couldn't hijack: %v", err) |
||||
} |
||||
return err |
||||
}), |
||||
r: req(bgCtx, "http://example.com/foo"), |
||||
wantCode: 0, |
||||
wantLog: AccessLogRecord{ |
||||
When: clock.Start, |
||||
Seconds: 1.0, |
||||
|
||||
Proto: "HTTP/1.1", |
||||
Host: "example.com", |
||||
Method: "GET", |
||||
RequestURI: "/foo", |
||||
Code: 101, |
||||
}, |
||||
}, |
||||
} |
||||
|
||||
for _, test := range tests { |
||||
t.Run(test.name, func(t *testing.T) { |
||||
var logs []AccessLogRecord |
||||
logf := func(fmt string, args ...interface{}) { |
||||
if fmt == "%s" { |
||||
logs = append(logs, args[0].(AccessLogRecord)) |
||||
} |
||||
t.Logf(fmt, args...) |
||||
} |
||||
|
||||
clock.Reset() |
||||
|
||||
rec := noopHijacker{httptest.NewRecorder(), false} |
||||
// ResponseRecorder defaults Code to 200, grump.
|
||||
rec.Code = 0 |
||||
h := stdHandler(test.h, logf, clock.Now) |
||||
h.ServeHTTP(&rec, test.r) |
||||
if rec.Code != test.wantCode { |
||||
t.Errorf("HTTP code = %v, want %v", rec.Code, test.wantCode) |
||||
} |
||||
if !rec.hijacked && !rec.Flushed { |
||||
t.Errorf("handler didn't flush") |
||||
} |
||||
if len(logs) != 1 { |
||||
t.Errorf("handler didn't write a request log") |
||||
return |
||||
} |
||||
errTransform := cmp.Transformer("err", func(e error) string { |
||||
if e == nil { |
||||
return "" |
||||
} |
||||
return e.Error() |
||||
}) |
||||
if diff := cmp.Diff(logs[0], test.wantLog, errTransform); diff != "" { |
||||
t.Errorf("handler wrote incorrect request log (-got+want):\n%s", diff) |
||||
} |
||||
}) |
||||
} |
||||
} |
||||
Loading…
Reference in new issue