ipn/ipnlocal: proxy h2c grpc using net/http.Transport instead of x/net/http2
(Kinda related: #17351) Updates #17305 Change-Id: I47df2612732a5713577164e74652bc9fa3cd14b3 Signed-off-by: Brad Fitzpatrick <bradfitz@tailscale.com>
This commit is contained in:
committed by
Brad Fitzpatrick
parent
3f5c560fd4
commit
2c956e30be
+11
-11
@@ -34,7 +34,6 @@ import (
|
|||||||
"unicode/utf8"
|
"unicode/utf8"
|
||||||
|
|
||||||
"go4.org/mem"
|
"go4.org/mem"
|
||||||
"golang.org/x/net/http2"
|
|
||||||
"tailscale.com/ipn"
|
"tailscale.com/ipn"
|
||||||
"tailscale.com/net/netutil"
|
"tailscale.com/net/netutil"
|
||||||
"tailscale.com/syncs"
|
"tailscale.com/syncs"
|
||||||
@@ -761,8 +760,8 @@ type reverseProxy struct {
|
|||||||
insecure bool
|
insecure bool
|
||||||
backend string
|
backend string
|
||||||
lb *LocalBackend
|
lb *LocalBackend
|
||||||
httpTransport lazy.SyncValue[*http.Transport] // transport for non-h2c backends
|
httpTransport lazy.SyncValue[*http.Transport] // transport for non-h2c backends
|
||||||
h2cTransport lazy.SyncValue[*http2.Transport] // transport for h2c backends
|
h2cTransport lazy.SyncValue[*http.Transport] // transport for h2c backends
|
||||||
// closed tracks whether proxy is closed/currently closing.
|
// closed tracks whether proxy is closed/currently closing.
|
||||||
closed atomic.Bool
|
closed atomic.Bool
|
||||||
}
|
}
|
||||||
@@ -770,9 +769,7 @@ type reverseProxy struct {
|
|||||||
// close ensures that any open backend connections get closed.
|
// close ensures that any open backend connections get closed.
|
||||||
func (rp *reverseProxy) close() {
|
func (rp *reverseProxy) close() {
|
||||||
rp.closed.Store(true)
|
rp.closed.Store(true)
|
||||||
if h2cT := rp.h2cTransport.Get(func() *http2.Transport {
|
if h2cT := rp.h2cTransport.Get(func() *http.Transport { return nil }); h2cT != nil {
|
||||||
return nil
|
|
||||||
}); h2cT != nil {
|
|
||||||
h2cT.CloseIdleConnections()
|
h2cT.CloseIdleConnections()
|
||||||
}
|
}
|
||||||
if httpTransport := rp.httpTransport.Get(func() *http.Transport {
|
if httpTransport := rp.httpTransport.Get(func() *http.Transport {
|
||||||
@@ -843,14 +840,17 @@ func (rp *reverseProxy) getTransport() *http.Transport {
|
|||||||
|
|
||||||
// getH2CTransport returns the Transport used for GRPC requests to the backend.
|
// getH2CTransport returns the Transport used for GRPC requests to the backend.
|
||||||
// The Transport gets created lazily, at most once.
|
// The Transport gets created lazily, at most once.
|
||||||
func (rp *reverseProxy) getH2CTransport() *http2.Transport {
|
func (rp *reverseProxy) getH2CTransport() http.RoundTripper {
|
||||||
return rp.h2cTransport.Get(func() *http2.Transport {
|
return rp.h2cTransport.Get(func() *http.Transport {
|
||||||
return &http2.Transport{
|
var p http.Protocols
|
||||||
AllowHTTP: true,
|
p.SetUnencryptedHTTP2(true)
|
||||||
DialTLSContext: func(ctx context.Context, network string, addr string, _ *tls.Config) (net.Conn, error) {
|
tr := &http.Transport{
|
||||||
|
Protocols: &p,
|
||||||
|
DialTLSContext: func(ctx context.Context, network string, addr string) (net.Conn, error) {
|
||||||
return rp.lb.dialer.SystemDial(ctx, "tcp", rp.url.Host)
|
return rp.lb.dialer.SystemDial(ctx, "tcp", rp.url.Host)
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
|
return tr
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -15,6 +15,7 @@ import (
|
|||||||
"encoding/json"
|
"encoding/json"
|
||||||
"errors"
|
"errors"
|
||||||
"fmt"
|
"fmt"
|
||||||
|
"io"
|
||||||
"net/http"
|
"net/http"
|
||||||
"net/http/httptest"
|
"net/http/httptest"
|
||||||
"net/netip"
|
"net/netip"
|
||||||
@@ -881,7 +882,7 @@ func mustCreateURL(t *testing.T, u string) url.URL {
|
|||||||
|
|
||||||
func newTestBackend(t *testing.T, opts ...any) *LocalBackend {
|
func newTestBackend(t *testing.T, opts ...any) *LocalBackend {
|
||||||
var logf logger.Logf = logger.Discard
|
var logf logger.Logf = logger.Discard
|
||||||
const debug = true
|
const debug = false
|
||||||
if debug {
|
if debug {
|
||||||
logf = logger.WithPrefix(tstest.WhileTestRunningLogger(t), "... ")
|
logf = logger.WithPrefix(tstest.WhileTestRunningLogger(t), "... ")
|
||||||
}
|
}
|
||||||
@@ -1085,3 +1086,88 @@ func TestEncTailscaleHeaderValue(t *testing.T) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestServeGRPCProxy(t *testing.T) {
|
||||||
|
const msg = "some-response\n"
|
||||||
|
backend := httptest.NewUnstartedServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
w.Header().Set("Path-Was", r.RequestURI)
|
||||||
|
w.Header().Set("Proto-Was", r.Proto)
|
||||||
|
io.WriteString(w, msg)
|
||||||
|
}))
|
||||||
|
backend.EnableHTTP2 = true
|
||||||
|
backend.Config.Protocols = new(http.Protocols)
|
||||||
|
backend.Config.Protocols.SetHTTP1(true)
|
||||||
|
backend.Config.Protocols.SetUnencryptedHTTP2(true)
|
||||||
|
backend.Start()
|
||||||
|
defer backend.Close()
|
||||||
|
|
||||||
|
backendURL := must.Get(url.Parse(backend.URL))
|
||||||
|
|
||||||
|
lb := newTestBackend(t)
|
||||||
|
rp := &reverseProxy{
|
||||||
|
logf: t.Logf,
|
||||||
|
url: backendURL,
|
||||||
|
backend: backend.URL,
|
||||||
|
lb: lb,
|
||||||
|
}
|
||||||
|
|
||||||
|
req := func(method, urlStr string, opt ...any) *http.Request {
|
||||||
|
req := httptest.NewRequest(method, urlStr, nil)
|
||||||
|
for _, o := range opt {
|
||||||
|
switch v := o.(type) {
|
||||||
|
case int:
|
||||||
|
req.ProtoMajor = v
|
||||||
|
case string:
|
||||||
|
req.Header.Set("Content-Type", v)
|
||||||
|
default:
|
||||||
|
panic(fmt.Sprintf("unsupported option type %T", v))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return req
|
||||||
|
}
|
||||||
|
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
req *http.Request
|
||||||
|
wantPath string
|
||||||
|
wantProto string
|
||||||
|
wantBody string
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
name: "non-gRPC",
|
||||||
|
req: req("GET", "http://foo/bar"),
|
||||||
|
wantPath: "/bar",
|
||||||
|
wantProto: "HTTP/1.1",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "gRPC-but-not-http2",
|
||||||
|
req: req("GET", "http://foo/bar", "application/grpc"),
|
||||||
|
wantPath: "/bar",
|
||||||
|
wantProto: "HTTP/1.1",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "gRPC--http2",
|
||||||
|
req: req("GET", "http://foo/bar", 2, "application/grpc"),
|
||||||
|
wantPath: "/bar",
|
||||||
|
wantProto: "HTTP/2.0",
|
||||||
|
},
|
||||||
|
}
|
||||||
|
for _, tt := range tests {
|
||||||
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
rec := httptest.NewRecorder()
|
||||||
|
rp.ServeHTTP(rec, tt.req)
|
||||||
|
|
||||||
|
res := rec.Result()
|
||||||
|
got := must.Get(io.ReadAll(res.Body))
|
||||||
|
if got, want := res.Header.Get("Path-Was"), tt.wantPath; want != got {
|
||||||
|
t.Errorf("Path-Was %q, want %q", got, want)
|
||||||
|
}
|
||||||
|
if got, want := res.Header.Get("Proto-Was"), tt.wantProto; want != got {
|
||||||
|
t.Errorf("Proto-Was %q, want %q", got, want)
|
||||||
|
}
|
||||||
|
if string(got) != msg {
|
||||||
|
t.Errorf("got body %q, want %q", got, msg)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|||||||
Reference in New Issue
Block a user