|
|
|
|
@ -15,6 +15,7 @@ import ( |
|
|
|
|
"encoding/json" |
|
|
|
|
"errors" |
|
|
|
|
"fmt" |
|
|
|
|
"io" |
|
|
|
|
"net/http" |
|
|
|
|
"net/http/httptest" |
|
|
|
|
"net/netip" |
|
|
|
|
@ -881,7 +882,7 @@ func mustCreateURL(t *testing.T, u string) url.URL { |
|
|
|
|
|
|
|
|
|
func newTestBackend(t *testing.T, opts ...any) *LocalBackend { |
|
|
|
|
var logf logger.Logf = logger.Discard |
|
|
|
|
const debug = true |
|
|
|
|
const debug = false |
|
|
|
|
if debug { |
|
|
|
|
logf = logger.WithPrefix(tstest.WhileTestRunningLogger(t), "... ") |
|
|
|
|
} |
|
|
|
|
@ -1085,3 +1086,88 @@ func TestEncTailscaleHeaderValue(t *testing.T) { |
|
|
|
|
} |
|
|
|
|
} |
|
|
|
|
} |
|
|
|
|
|
|
|
|
|
func TestServeGRPCProxy(t *testing.T) { |
|
|
|
|
const msg = "some-response\n" |
|
|
|
|
backend := httptest.NewUnstartedServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { |
|
|
|
|
w.Header().Set("Path-Was", r.RequestURI) |
|
|
|
|
w.Header().Set("Proto-Was", r.Proto) |
|
|
|
|
io.WriteString(w, msg) |
|
|
|
|
})) |
|
|
|
|
backend.EnableHTTP2 = true |
|
|
|
|
backend.Config.Protocols = new(http.Protocols) |
|
|
|
|
backend.Config.Protocols.SetHTTP1(true) |
|
|
|
|
backend.Config.Protocols.SetUnencryptedHTTP2(true) |
|
|
|
|
backend.Start() |
|
|
|
|
defer backend.Close() |
|
|
|
|
|
|
|
|
|
backendURL := must.Get(url.Parse(backend.URL)) |
|
|
|
|
|
|
|
|
|
lb := newTestBackend(t) |
|
|
|
|
rp := &reverseProxy{ |
|
|
|
|
logf: t.Logf, |
|
|
|
|
url: backendURL, |
|
|
|
|
backend: backend.URL, |
|
|
|
|
lb: lb, |
|
|
|
|
} |
|
|
|
|
|
|
|
|
|
req := func(method, urlStr string, opt ...any) *http.Request { |
|
|
|
|
req := httptest.NewRequest(method, urlStr, nil) |
|
|
|
|
for _, o := range opt { |
|
|
|
|
switch v := o.(type) { |
|
|
|
|
case int: |
|
|
|
|
req.ProtoMajor = v |
|
|
|
|
case string: |
|
|
|
|
req.Header.Set("Content-Type", v) |
|
|
|
|
default: |
|
|
|
|
panic(fmt.Sprintf("unsupported option type %T", v)) |
|
|
|
|
} |
|
|
|
|
} |
|
|
|
|
return req |
|
|
|
|
} |
|
|
|
|
|
|
|
|
|
tests := []struct { |
|
|
|
|
name string |
|
|
|
|
req *http.Request |
|
|
|
|
wantPath string |
|
|
|
|
wantProto string |
|
|
|
|
wantBody string |
|
|
|
|
}{ |
|
|
|
|
{ |
|
|
|
|
name: "non-gRPC", |
|
|
|
|
req: req("GET", "http://foo/bar"), |
|
|
|
|
wantPath: "/bar", |
|
|
|
|
wantProto: "HTTP/1.1", |
|
|
|
|
}, |
|
|
|
|
{ |
|
|
|
|
name: "gRPC-but-not-http2", |
|
|
|
|
req: req("GET", "http://foo/bar", "application/grpc"), |
|
|
|
|
wantPath: "/bar", |
|
|
|
|
wantProto: "HTTP/1.1", |
|
|
|
|
}, |
|
|
|
|
{ |
|
|
|
|
name: "gRPC--http2", |
|
|
|
|
req: req("GET", "http://foo/bar", 2, "application/grpc"), |
|
|
|
|
wantPath: "/bar", |
|
|
|
|
wantProto: "HTTP/2.0", |
|
|
|
|
}, |
|
|
|
|
} |
|
|
|
|
for _, tt := range tests { |
|
|
|
|
t.Run(tt.name, func(t *testing.T) { |
|
|
|
|
rec := httptest.NewRecorder() |
|
|
|
|
rp.ServeHTTP(rec, tt.req) |
|
|
|
|
|
|
|
|
|
res := rec.Result() |
|
|
|
|
got := must.Get(io.ReadAll(res.Body)) |
|
|
|
|
if got, want := res.Header.Get("Path-Was"), tt.wantPath; want != got { |
|
|
|
|
t.Errorf("Path-Was %q, want %q", got, want) |
|
|
|
|
} |
|
|
|
|
if got, want := res.Header.Get("Proto-Was"), tt.wantProto; want != got { |
|
|
|
|
t.Errorf("Proto-Was %q, want %q", got, want) |
|
|
|
|
} |
|
|
|
|
if string(got) != msg { |
|
|
|
|
t.Errorf("got body %q, want %q", got, msg) |
|
|
|
|
} |
|
|
|
|
}) |
|
|
|
|
} |
|
|
|
|
} |
|
|
|
|
|