tsweb: JSONHandler using reflect (#437)
Updates #395 #437 Signed-off-by: Zijie Lu <zijie@tailscale.com>main
parent
059b1d10bb
commit
1d2e497d47
@ -0,0 +1,125 @@ |
||||
// Copyright (c) 2020 Tailscale Inc & AUTHORS All rights reserved.
|
||||
// Use of this source code is governed by a BSD-style
|
||||
// license that can be found in the LICENSE file.
|
||||
|
||||
package tsweb |
||||
|
||||
import ( |
||||
"encoding/json" |
||||
"net/http" |
||||
"reflect" |
||||
) |
||||
|
||||
type response struct { |
||||
Status string `json:"status"` |
||||
Error string `json:"error,omitempty"` |
||||
Data interface{} `json:"data,omitempty"` |
||||
} |
||||
|
||||
func responseSuccess(data interface{}) *response { |
||||
return &response{ |
||||
Status: "success", |
||||
Data: data, |
||||
} |
||||
} |
||||
|
||||
func responseError(e string) *response { |
||||
return &response{ |
||||
Status: "error", |
||||
Error: e, |
||||
} |
||||
} |
||||
|
||||
func writeResponse(w http.ResponseWriter, s int, resp *response) { |
||||
b, _ := json.Marshal(resp) |
||||
w.WriteHeader(s) |
||||
w.Header().Set("Content-Type", "application/json") |
||||
w.Write(b) |
||||
} |
||||
|
||||
func checkFn(t reflect.Type) { |
||||
h := reflect.TypeOf(http.HandlerFunc(nil)) |
||||
switch t.NumIn() { |
||||
case 2, 3: |
||||
if !t.In(0).AssignableTo(h.In(0)) { |
||||
panic("first argument must be http.ResponseWriter") |
||||
} |
||||
if !t.In(1).AssignableTo(h.In(1)) { |
||||
panic("second argument must be *http.Request") |
||||
} |
||||
default: |
||||
panic("JSONHandler: number of input parameter should be 2 or 3") |
||||
} |
||||
|
||||
switch t.NumOut() { |
||||
case 1: |
||||
if !t.Out(0).Implements(reflect.TypeOf((*error)(nil)).Elem()) { |
||||
panic("return value must be error") |
||||
} |
||||
case 2: |
||||
if !t.Out(1).Implements(reflect.TypeOf((*error)(nil)).Elem()) { |
||||
panic("second return value must be error") |
||||
} |
||||
default: |
||||
panic("JSONHandler: number of return values should be 1 or 2") |
||||
} |
||||
} |
||||
|
||||
// JSONHandler wraps an HTTP handler function with a version that automatically
|
||||
// unmarshals and marshals requests and responses respectively into fn's arguments
|
||||
// and results.
|
||||
//
|
||||
// The fn parameter is a function. It must take two or three input arguments.
|
||||
// The first two arguments must be http.ResponseWriter and *http.Request.
|
||||
// The optional third argument can be of any type representing the JSON input.
|
||||
// The function's results can be either (error) or (T, error), where T is the
|
||||
// JSON-marshalled result type.
|
||||
//
|
||||
// For example:
|
||||
// fn := func(w http.ResponseWriter, r *http.Request, in *Req) (*Res, error) { ... }
|
||||
func JSONHandler(fn interface{}) http.Handler { |
||||
v := reflect.ValueOf(fn) |
||||
t := v.Type() |
||||
checkFn(t) |
||||
|
||||
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { |
||||
wv := reflect.ValueOf(w) |
||||
rv := reflect.ValueOf(r) |
||||
var vs []reflect.Value |
||||
|
||||
switch t.NumIn() { |
||||
case 2: |
||||
vs = v.Call([]reflect.Value{wv, rv}) |
||||
case 3: |
||||
dv := reflect.New(t.In(2)) |
||||
err := json.NewDecoder(r.Body).Decode(dv.Interface()) |
||||
if err != nil { |
||||
writeResponse(w, http.StatusBadRequest, responseError("bad json")) |
||||
return |
||||
} |
||||
vs = v.Call([]reflect.Value{wv, rv, dv.Elem()}) |
||||
default: |
||||
panic("JSONHandler: number of input parameter should be 2 or 3") |
||||
} |
||||
|
||||
switch len(vs) { |
||||
case 1: |
||||
// todo support other error types
|
||||
if vs[0].IsNil() { |
||||
writeResponse(w, http.StatusOK, responseSuccess(nil)) |
||||
} else { |
||||
err := vs[0].Interface().(error) |
||||
writeResponse(w, http.StatusBadRequest, responseError(err.Error())) |
||||
} |
||||
case 2: |
||||
if vs[1].IsNil() { |
||||
writeResponse(w, http.StatusOK, responseSuccess(vs[0].Interface())) |
||||
} else { |
||||
err := vs[1].Interface().(error) |
||||
writeResponse(w, http.StatusBadRequest, responseError(err.Error())) |
||||
} |
||||
default: |
||||
panic("JSONHandler: number of return values should be 1 or 2") |
||||
} |
||||
}) |
||||
} |
||||
@ -0,0 +1,175 @@ |
||||
// Copyright (c) 2020 Tailscale Inc & AUTHORS All rights reserved.
|
||||
// Use of this source code is governed by a BSD-style
|
||||
// license that can be found in the LICENSE file.
|
||||
|
||||
package tsweb |
||||
|
||||
import ( |
||||
"encoding/json" |
||||
"errors" |
||||
"net/http" |
||||
"net/http/httptest" |
||||
"strings" |
||||
"testing" |
||||
) |
||||
|
||||
type Data struct { |
||||
Name string |
||||
Price int |
||||
} |
||||
|
||||
type Response struct { |
||||
Status string |
||||
Error string |
||||
Data *Data |
||||
} |
||||
|
||||
func TestNewJSONHandler(t *testing.T) { |
||||
checkStatus := func(w *httptest.ResponseRecorder, status string) *Response { |
||||
d := &Response{ |
||||
Data: &Data{}, |
||||
} |
||||
|
||||
t.Logf("%s", w.Body.Bytes()) |
||||
err := json.Unmarshal(w.Body.Bytes(), d) |
||||
if err != nil { |
||||
t.Logf(err.Error()) |
||||
return nil |
||||
} |
||||
|
||||
if d.Status == status { |
||||
t.Logf("ok: %s", d.Status) |
||||
} else { |
||||
t.Fatalf("wrong status: %s %s", d.Status, status) |
||||
} |
||||
|
||||
return d |
||||
} |
||||
|
||||
// 2 1
|
||||
h21 := JSONHandler(func(w http.ResponseWriter, r *http.Request) error { |
||||
return nil |
||||
}) |
||||
|
||||
t.Run("2 1 simple", func(t *testing.T) { |
||||
w := httptest.NewRecorder() |
||||
r := httptest.NewRequest("GET", "/", nil) |
||||
h21.ServeHTTP(w, r) |
||||
checkStatus(w, "success") |
||||
}) |
||||
|
||||
// 2 2
|
||||
h22 := JSONHandler(func(w http.ResponseWriter, r *http.Request) (*Data, error) { |
||||
return &Data{Name: "tailscale"}, nil |
||||
}) |
||||
t.Run("2 2 get data", func(t *testing.T) { |
||||
w := httptest.NewRecorder() |
||||
r := httptest.NewRequest("GET", "/", nil) |
||||
h22.ServeHTTP(w, r) |
||||
checkStatus(w, "success") |
||||
}) |
||||
|
||||
// 3 1
|
||||
h31 := JSONHandler(func(w http.ResponseWriter, r *http.Request, d *Data) error { |
||||
if d.Name == "" { |
||||
return errors.New("name is empty") |
||||
} |
||||
|
||||
return nil |
||||
}) |
||||
t.Run("3 1 post data", func(t *testing.T) { |
||||
w := httptest.NewRecorder() |
||||
r := httptest.NewRequest("POST", "/", strings.NewReader(`{"Name": "tailscale"}`)) |
||||
h31.ServeHTTP(w, r) |
||||
checkStatus(w, "success") |
||||
}) |
||||
|
||||
t.Run("3 1 bad json", func(t *testing.T) { |
||||
w := httptest.NewRecorder() |
||||
r := httptest.NewRequest("POST", "/", strings.NewReader(`{`)) |
||||
h31.ServeHTTP(w, r) |
||||
checkStatus(w, "error") |
||||
}) |
||||
|
||||
t.Run("3 1 post data error", func(t *testing.T) { |
||||
w := httptest.NewRecorder() |
||||
r := httptest.NewRequest("POST", "/", strings.NewReader(`{}`)) |
||||
h31.ServeHTTP(w, r) |
||||
resp := checkStatus(w, "error") |
||||
if resp.Error != "name is empty" { |
||||
t.Fatalf("wrong error") |
||||
} |
||||
}) |
||||
|
||||
// 3 2
|
||||
h32 := JSONHandler(func(w http.ResponseWriter, r *http.Request, d *Data) (*Data, error) { |
||||
if d.Price == 0 { |
||||
return nil, errors.New("price is empty") |
||||
} |
||||
|
||||
return &Data{Price: d.Price * 2}, nil |
||||
}) |
||||
t.Run("3 2 post data", func(t *testing.T) { |
||||
w := httptest.NewRecorder() |
||||
r := httptest.NewRequest("POST", "/", strings.NewReader(`{"Price": 10}`)) |
||||
h32.ServeHTTP(w, r) |
||||
resp := checkStatus(w, "success") |
||||
t.Log(resp.Data) |
||||
if resp.Data.Price != 20 { |
||||
t.Fatalf("wrong price: %d %d", resp.Data.Price, 10) |
||||
} |
||||
}) |
||||
|
||||
t.Run("3 2 post data error", func(t *testing.T) { |
||||
w := httptest.NewRecorder() |
||||
r := httptest.NewRequest("POST", "/", strings.NewReader(`{}`)) |
||||
h32.ServeHTTP(w, r) |
||||
resp := checkStatus(w, "error") |
||||
if resp.Error != "price is empty" { |
||||
t.Fatalf("wrong error") |
||||
} |
||||
}) |
||||
|
||||
// fn check
|
||||
shouldPanic := func() { |
||||
r := recover() |
||||
if r == nil { |
||||
t.Fatalf("should panic") |
||||
} |
||||
t.Log(r) |
||||
} |
||||
|
||||
t.Run("2 0 panic", func(t *testing.T) { |
||||
defer shouldPanic() |
||||
JSONHandler(func(w http.ResponseWriter, r *http.Request) {}) |
||||
}) |
||||
|
||||
t.Run("2 1 panic return value", func(t *testing.T) { |
||||
defer shouldPanic() |
||||
JSONHandler(func(w http.ResponseWriter, r *http.Request) string { |
||||
return "" |
||||
}) |
||||
}) |
||||
|
||||
t.Run("2 1 panic arguments", func(t *testing.T) { |
||||
defer shouldPanic() |
||||
JSONHandler(func(r *http.Request, w http.ResponseWriter) error { |
||||
return nil |
||||
}) |
||||
}) |
||||
|
||||
t.Run("3 1 panic arguments", func(t *testing.T) { |
||||
defer shouldPanic() |
||||
JSONHandler(func(name string, r *http.Request, w http.ResponseWriter) error { |
||||
return nil |
||||
}) |
||||
}) |
||||
|
||||
t.Run("3 2 panic return value", func(t *testing.T) { |
||||
defer shouldPanic() |
||||
//lint:ignore ST1008 intentional
|
||||
JSONHandler(func(name string, r *http.Request, w http.ResponseWriter) (error, string) { |
||||
return nil, "panic" |
||||
}) |
||||
}) |
||||
} |
||||
Loading…
Reference in new issue