control/controlclient: remove x/net/http2, use net/http
Saves 352 KB, removing one of our two HTTP/2 implementations linked into the binary. Fixes #17305 Updates #15015 Change-Id: I53a04b1f2687dca73c8541949465038b69aa6ade Signed-off-by: Brad Fitzpatrick <bradfitz@tailscale.com>
This commit is contained in:
committed by
Brad Fitzpatrick
parent
c45f8813b4
commit
1d93bdce20
@@ -0,0 +1,289 @@
|
||||
// Copyright (c) Tailscale Inc & AUTHORS
|
||||
// SPDX-License-Identifier: BSD-3-Clause
|
||||
|
||||
package ts2021
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"cmp"
|
||||
"context"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
"log"
|
||||
"math"
|
||||
"net"
|
||||
"net/http"
|
||||
"net/netip"
|
||||
"net/url"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"tailscale.com/control/controlhttp"
|
||||
"tailscale.com/health"
|
||||
"tailscale.com/net/dnscache"
|
||||
"tailscale.com/net/netmon"
|
||||
"tailscale.com/net/tsdial"
|
||||
"tailscale.com/tailcfg"
|
||||
"tailscale.com/tstime"
|
||||
"tailscale.com/types/key"
|
||||
"tailscale.com/types/logger"
|
||||
)
|
||||
|
||||
// Client provides a http.Client to connect to tailcontrol over
|
||||
// the ts2021 protocol.
|
||||
type Client struct {
|
||||
// Client is an HTTP client to talk to the coordination server.
|
||||
// It automatically makes a new Noise connection as needed.
|
||||
*http.Client
|
||||
|
||||
logf logger.Logf // non-nil
|
||||
opts ClientOpts
|
||||
host string // the host part of serverURL
|
||||
httpPort string // the default port to dial
|
||||
httpsPort string // the fallback Noise-over-https port or empty if none
|
||||
|
||||
// mu protects the following
|
||||
mu sync.Mutex
|
||||
closed bool
|
||||
}
|
||||
|
||||
// ClientOpts contains options for the [NewClient] function. All fields are
|
||||
// required unless otherwise specified.
|
||||
type ClientOpts struct {
|
||||
// ServerURL is the URL of the server to connect to.
|
||||
ServerURL string
|
||||
|
||||
// PrivKey is this node's private key.
|
||||
PrivKey key.MachinePrivate
|
||||
|
||||
// ServerPubKey is the public key of the server.
|
||||
// It is of the form https://<host>:<port> (no trailing slash).
|
||||
ServerPubKey key.MachinePublic
|
||||
|
||||
// Dialer's SystemDial function is used to connect to the server.
|
||||
Dialer *tsdial.Dialer
|
||||
|
||||
// Optional fields follow
|
||||
|
||||
// Logf is the log function to use.
|
||||
// If nil, log.Printf is used.
|
||||
Logf logger.Logf
|
||||
|
||||
// NetMon is the network monitor that will be used to get the
|
||||
// network interface state. This field can be nil; if so, the current
|
||||
// state will be looked up dynamically.
|
||||
NetMon *netmon.Monitor
|
||||
|
||||
// DNSCache is the caching Resolver to use to connect to the server.
|
||||
//
|
||||
// This field can be nil.
|
||||
DNSCache *dnscache.Resolver
|
||||
|
||||
// HealthTracker, if non-nil, is the health tracker to use.
|
||||
HealthTracker *health.Tracker
|
||||
|
||||
// DialPlan, if set, is a function that should return an explicit plan
|
||||
// on how to connect to the server.
|
||||
DialPlan func() *tailcfg.ControlDialPlan
|
||||
|
||||
// ProtocolVersion, if non-zero, specifies an alternate
|
||||
// protocol version to use instead of the default,
|
||||
// of [tailcfg.CurrentCapabilityVersion].
|
||||
ProtocolVersion uint16
|
||||
}
|
||||
|
||||
// NewClient returns a new noiseClient for the provided server and machine key.
|
||||
//
|
||||
// netMon may be nil, if non-nil it's used to do faster interface lookups.
|
||||
// dialPlan may be nil
|
||||
func NewClient(opts ClientOpts) (*Client, error) {
|
||||
logf := opts.Logf
|
||||
if logf == nil {
|
||||
logf = log.Printf
|
||||
}
|
||||
if opts.ServerURL == "" {
|
||||
return nil, errors.New("ServerURL is required")
|
||||
}
|
||||
if opts.PrivKey.IsZero() {
|
||||
return nil, errors.New("PrivKey is required")
|
||||
}
|
||||
if opts.ServerPubKey.IsZero() {
|
||||
return nil, errors.New("ServerPubKey is required")
|
||||
}
|
||||
if opts.Dialer == nil {
|
||||
return nil, errors.New("Dialer is required")
|
||||
}
|
||||
|
||||
u, err := url.Parse(opts.ServerURL)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("invalid ClientOpts.ServerURL: %w", err)
|
||||
}
|
||||
if u.Scheme != "http" && u.Scheme != "https" {
|
||||
return nil, errors.New("invalid ServerURL scheme, must be http or https")
|
||||
}
|
||||
|
||||
httpPort, httpsPort := "80", "443"
|
||||
addr, _ := netip.ParseAddr(u.Hostname())
|
||||
isPrivateHost := addr.IsPrivate() || addr.IsLoopback() || u.Hostname() == "localhost"
|
||||
if port := u.Port(); port != "" {
|
||||
// If there is an explicit port specified, entirely rely on the scheme,
|
||||
// unless it's http with a private host in which case we never try using HTTPS.
|
||||
if u.Scheme == "https" {
|
||||
httpPort = ""
|
||||
httpsPort = port
|
||||
} else if u.Scheme == "http" {
|
||||
httpPort = port
|
||||
httpsPort = "443"
|
||||
if isPrivateHost {
|
||||
logf("setting empty HTTPS port with http scheme and private host %s", u.Hostname())
|
||||
httpsPort = ""
|
||||
}
|
||||
}
|
||||
} else if u.Scheme == "http" && isPrivateHost {
|
||||
// Whenever the scheme is http and the hostname is an IP address, do not set the HTTPS port,
|
||||
// as there cannot be a TLS certificate issued for an IP, unless it's a public IP.
|
||||
httpPort = "80"
|
||||
httpsPort = ""
|
||||
}
|
||||
|
||||
np := &Client{
|
||||
opts: opts,
|
||||
host: u.Hostname(),
|
||||
httpPort: httpPort,
|
||||
httpsPort: httpsPort,
|
||||
logf: logf,
|
||||
}
|
||||
|
||||
tr := &http.Transport{
|
||||
Protocols: new(http.Protocols),
|
||||
MaxConnsPerHost: 1,
|
||||
}
|
||||
// We force only HTTP/2 for this transport, which is what the control server
|
||||
// speaks inside the ts2021 Noise encryption. But Go doesn't know about that,
|
||||
// so we use "SetUnencryptedHTTP2" even though it's actually encrypted.
|
||||
tr.Protocols.SetUnencryptedHTTP2(true)
|
||||
tr.DialTLSContext = func(ctx context.Context, network, addr string) (net.Conn, error) {
|
||||
return np.dial(ctx)
|
||||
}
|
||||
|
||||
np.Client = &http.Client{Transport: tr}
|
||||
return np, nil
|
||||
}
|
||||
|
||||
// Close closes all the underlying noise connections.
|
||||
// It is a no-op and returns nil if the connection is already closed.
|
||||
func (nc *Client) Close() error {
|
||||
nc.mu.Lock()
|
||||
defer nc.mu.Unlock()
|
||||
nc.closed = true
|
||||
nc.Client.CloseIdleConnections()
|
||||
return nil
|
||||
}
|
||||
|
||||
// dial opens a new connection to tailcontrol, fetching the server noise key
|
||||
// if not cached.
|
||||
func (nc *Client) dial(ctx context.Context) (*Conn, error) {
|
||||
if tailcfg.CurrentCapabilityVersion > math.MaxUint16 {
|
||||
// Panic, because a test should have started failing several
|
||||
// thousand version numbers before getting to this point.
|
||||
panic("capability version is too high to fit in the wire protocol")
|
||||
}
|
||||
|
||||
var dialPlan *tailcfg.ControlDialPlan
|
||||
if nc.opts.DialPlan != nil {
|
||||
dialPlan = nc.opts.DialPlan()
|
||||
}
|
||||
|
||||
// If we have a dial plan, then set our timeout as slightly longer than
|
||||
// the maximum amount of time contained therein; we assume that
|
||||
// explicit instructions on timeouts are more useful than a single
|
||||
// hard-coded timeout.
|
||||
//
|
||||
// The default value of 5 is chosen so that, when there's no dial plan,
|
||||
// we retain the previous behaviour of 10 seconds end-to-end timeout.
|
||||
timeoutSec := 5.0
|
||||
if dialPlan != nil {
|
||||
for _, c := range dialPlan.Candidates {
|
||||
if v := c.DialStartDelaySec + c.DialTimeoutSec; v > timeoutSec {
|
||||
timeoutSec = v
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// After we establish a connection, we need some time to actually
|
||||
// upgrade it into a Noise connection. With a ballpark worst-case RTT
|
||||
// of 1000ms, give ourselves an extra 5 seconds to complete the
|
||||
// handshake.
|
||||
timeoutSec += 5
|
||||
|
||||
// Be extremely defensive and ensure that the timeout is in the range
|
||||
// [5, 60] seconds (e.g. if we accidentally get a negative number).
|
||||
if timeoutSec > 60 {
|
||||
timeoutSec = 60
|
||||
} else if timeoutSec < 5 {
|
||||
timeoutSec = 5
|
||||
}
|
||||
|
||||
timeout := time.Duration(timeoutSec * float64(time.Second))
|
||||
ctx, cancel := context.WithTimeout(ctx, timeout)
|
||||
defer cancel()
|
||||
|
||||
chd := &controlhttp.Dialer{
|
||||
Hostname: nc.host,
|
||||
HTTPPort: nc.httpPort,
|
||||
HTTPSPort: cmp.Or(nc.httpsPort, controlhttp.NoPort),
|
||||
MachineKey: nc.opts.PrivKey,
|
||||
ControlKey: nc.opts.ServerPubKey,
|
||||
ProtocolVersion: cmp.Or(nc.opts.ProtocolVersion, uint16(tailcfg.CurrentCapabilityVersion)),
|
||||
Dialer: nc.opts.Dialer.SystemDial,
|
||||
DNSCache: nc.opts.DNSCache,
|
||||
DialPlan: dialPlan,
|
||||
Logf: nc.logf,
|
||||
NetMon: nc.opts.NetMon,
|
||||
HealthTracker: nc.opts.HealthTracker,
|
||||
Clock: tstime.StdClock{},
|
||||
}
|
||||
clientConn, err := chd.Dial(ctx)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
ncc := NewConn(clientConn.Conn)
|
||||
|
||||
nc.mu.Lock()
|
||||
if nc.closed {
|
||||
nc.mu.Unlock()
|
||||
ncc.Close() // Needs to be called without holding the lock.
|
||||
return nil, errors.New("noise client closed")
|
||||
}
|
||||
defer nc.mu.Unlock()
|
||||
return ncc, nil
|
||||
}
|
||||
|
||||
// post does a POST to the control server at the given path, JSON-encoding body.
|
||||
// The provided nodeKey is an optional load balancing hint.
|
||||
func (nc *Client) Post(ctx context.Context, path string, nodeKey key.NodePublic, body any) (*http.Response, error) {
|
||||
return nc.DoWithBody(ctx, "POST", path, nodeKey, body)
|
||||
}
|
||||
|
||||
func (nc *Client) DoWithBody(ctx context.Context, method, path string, nodeKey key.NodePublic, body any) (*http.Response, error) {
|
||||
jbody, err := json.Marshal(body)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
req, err := http.NewRequestWithContext(ctx, method, "https://"+nc.host+path, bytes.NewReader(jbody))
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
AddLBHeader(req, nodeKey)
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
return nc.Do(req)
|
||||
}
|
||||
|
||||
// AddLBHeader adds the load balancer header to req if nodeKey is non-zero.
|
||||
func AddLBHeader(req *http.Request, nodeKey key.NodePublic) {
|
||||
if !nodeKey.IsZero() {
|
||||
req.Header.Add(tailcfg.LBHeader, nodeKey.String())
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,363 @@
|
||||
// Copyright (c) Tailscale Inc & AUTHORS
|
||||
// SPDX-License-Identifier: BSD-3-Clause
|
||||
|
||||
package ts2021
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/binary"
|
||||
"encoding/json"
|
||||
"io"
|
||||
"math"
|
||||
"net/http"
|
||||
"net/http/httptrace"
|
||||
"sync/atomic"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"golang.org/x/net/http2"
|
||||
"tailscale.com/control/controlhttp/controlhttpserver"
|
||||
"tailscale.com/net/netmon"
|
||||
"tailscale.com/net/tsdial"
|
||||
"tailscale.com/tailcfg"
|
||||
"tailscale.com/tstest/nettest"
|
||||
"tailscale.com/types/key"
|
||||
"tailscale.com/types/logger"
|
||||
"tailscale.com/util/must"
|
||||
)
|
||||
|
||||
// maxAllowedNoiseVersion is the highest we expect the Tailscale
|
||||
// capability version to ever get. It's a value close to 2^16, but
|
||||
// with enough leeway that we get a very early warning that it's time
|
||||
// to rework the wire protocol to allow larger versions, while still
|
||||
// giving us headroom to bump this test and fix the build.
|
||||
//
|
||||
// Code elsewhere in the client will panic() if the tailcfg capability
|
||||
// version exceeds 16 bits, so take a failure of this test seriously.
|
||||
const maxAllowedNoiseVersion = math.MaxUint16 - 5000
|
||||
|
||||
func TestNoiseVersion(t *testing.T) {
|
||||
if tailcfg.CurrentCapabilityVersion > maxAllowedNoiseVersion {
|
||||
t.Fatalf("tailcfg.CurrentCapabilityVersion is %d, want <=%d", tailcfg.CurrentCapabilityVersion, maxAllowedNoiseVersion)
|
||||
}
|
||||
}
|
||||
|
||||
type noiseClientTest struct {
|
||||
sendEarlyPayload bool
|
||||
}
|
||||
|
||||
func TestNoiseClientHTTP2Upgrade(t *testing.T) {
|
||||
noiseClientTest{}.run(t)
|
||||
}
|
||||
|
||||
func TestNoiseClientHTTP2Upgrade_earlyPayload(t *testing.T) {
|
||||
noiseClientTest{
|
||||
sendEarlyPayload: true,
|
||||
}.run(t)
|
||||
}
|
||||
|
||||
var (
|
||||
testPrivKey = key.NewMachine()
|
||||
testServerPub = key.NewMachine().Public()
|
||||
)
|
||||
|
||||
func makeClientWithURL(t *testing.T, url string) *Client {
|
||||
nc, err := NewClient(ClientOpts{
|
||||
Logf: t.Logf,
|
||||
PrivKey: testPrivKey,
|
||||
ServerPubKey: testServerPub,
|
||||
ServerURL: url,
|
||||
Dialer: tsdial.NewDialer(netmon.NewStatic()),
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
t.Cleanup(func() { nc.Close() })
|
||||
return nc
|
||||
}
|
||||
|
||||
func TestNoiseClientPortsAreSet(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
url string
|
||||
wantHTTPS string
|
||||
wantHTTP string
|
||||
}{
|
||||
{
|
||||
name: "https-url",
|
||||
url: "https://example.com",
|
||||
wantHTTPS: "443",
|
||||
wantHTTP: "80",
|
||||
},
|
||||
{
|
||||
name: "http-url",
|
||||
url: "http://example.com",
|
||||
wantHTTPS: "443", // TODO(bradfitz): questionable; change?
|
||||
wantHTTP: "80",
|
||||
},
|
||||
{
|
||||
name: "https-url-custom-port",
|
||||
url: "https://example.com:123",
|
||||
wantHTTPS: "123",
|
||||
wantHTTP: "",
|
||||
},
|
||||
{
|
||||
name: "http-url-custom-port",
|
||||
url: "http://example.com:123",
|
||||
wantHTTPS: "443", // TODO(bradfitz): questionable; change?
|
||||
wantHTTP: "123",
|
||||
},
|
||||
{
|
||||
name: "http-loopback-no-port",
|
||||
url: "http://127.0.0.1",
|
||||
wantHTTPS: "",
|
||||
wantHTTP: "80",
|
||||
},
|
||||
{
|
||||
name: "http-loopback-custom-port",
|
||||
url: "http://127.0.0.1:8080",
|
||||
wantHTTPS: "",
|
||||
wantHTTP: "8080",
|
||||
},
|
||||
{
|
||||
name: "http-localhost-no-port",
|
||||
url: "http://localhost",
|
||||
wantHTTPS: "",
|
||||
wantHTTP: "80",
|
||||
},
|
||||
{
|
||||
name: "http-localhost-custom-port",
|
||||
url: "http://localhost:8080",
|
||||
wantHTTPS: "",
|
||||
wantHTTP: "8080",
|
||||
},
|
||||
{
|
||||
name: "http-private-ip-no-port",
|
||||
url: "http://192.168.2.3",
|
||||
wantHTTPS: "",
|
||||
wantHTTP: "80",
|
||||
},
|
||||
{
|
||||
name: "http-private-ip-custom-port",
|
||||
url: "http://192.168.2.3:8080",
|
||||
wantHTTPS: "",
|
||||
wantHTTP: "8080",
|
||||
},
|
||||
{
|
||||
name: "http-public-ip",
|
||||
url: "http://1.2.3.4",
|
||||
wantHTTPS: "443", // TODO(bradfitz): questionable; change?
|
||||
wantHTTP: "80",
|
||||
},
|
||||
{
|
||||
name: "http-public-ip-custom-port",
|
||||
url: "http://1.2.3.4:8080",
|
||||
wantHTTPS: "443", // TODO(bradfitz): questionable; change?
|
||||
wantHTTP: "8080",
|
||||
},
|
||||
{
|
||||
name: "https-public-ip",
|
||||
url: "https://1.2.3.4",
|
||||
wantHTTPS: "443",
|
||||
wantHTTP: "80",
|
||||
},
|
||||
{
|
||||
name: "https-public-ip-custom-port",
|
||||
url: "https://1.2.3.4:8080",
|
||||
wantHTTPS: "8080",
|
||||
wantHTTP: "",
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
nc := makeClientWithURL(t, tt.url)
|
||||
if nc.httpsPort != tt.wantHTTPS {
|
||||
t.Errorf("nc.httpsPort = %q; want %q", nc.httpsPort, tt.wantHTTPS)
|
||||
}
|
||||
if nc.httpPort != tt.wantHTTP {
|
||||
t.Errorf("nc.httpPort = %q; want %q", nc.httpPort, tt.wantHTTP)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func (tt noiseClientTest) run(t *testing.T) {
|
||||
serverPrivate := key.NewMachine()
|
||||
clientPrivate := key.NewMachine()
|
||||
chalPrivate := key.NewChallenge()
|
||||
|
||||
const msg = "Hello, client"
|
||||
h2 := &http2.Server{}
|
||||
nw := nettest.GetNetwork(t)
|
||||
hs := nettest.NewHTTPServer(nw, &Upgrader{
|
||||
h2srv: h2,
|
||||
noiseKeyPriv: serverPrivate,
|
||||
sendEarlyPayload: tt.sendEarlyPayload,
|
||||
challenge: chalPrivate,
|
||||
httpBaseConfig: &http.Server{
|
||||
Handler: http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
w.Header().Set("Content-Type", "text/plain")
|
||||
io.WriteString(w, msg)
|
||||
}),
|
||||
},
|
||||
})
|
||||
defer hs.Close()
|
||||
|
||||
dialer := tsdial.NewDialer(netmon.NewStatic())
|
||||
if nettest.PreferMemNetwork() {
|
||||
dialer.SetSystemDialerForTest(nw.Dial)
|
||||
}
|
||||
|
||||
nc, err := NewClient(ClientOpts{
|
||||
PrivKey: clientPrivate,
|
||||
ServerPubKey: serverPrivate.Public(),
|
||||
ServerURL: hs.URL,
|
||||
Dialer: dialer,
|
||||
Logf: t.Logf,
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
var sawConn atomic.Bool
|
||||
trace := httptrace.WithClientTrace(t.Context(), &httptrace.ClientTrace{
|
||||
GotConn: func(ci httptrace.GotConnInfo) {
|
||||
ncc, ok := ci.Conn.(*Conn)
|
||||
if !ok {
|
||||
// This trace hook sees two dials: the lower-level controlhttp upgrade's
|
||||
// dial (a tsdial.sysConn), and then the *ts2021.Conn we want.
|
||||
// Ignore the first one.
|
||||
return
|
||||
}
|
||||
sawConn.Store(true)
|
||||
|
||||
ctx, cancel := context.WithTimeout(t.Context(), 5*time.Second)
|
||||
defer cancel()
|
||||
|
||||
payload, err := ncc.GetEarlyPayload(ctx)
|
||||
if err != nil {
|
||||
t.Errorf("GetEarlyPayload: %v", err)
|
||||
return
|
||||
}
|
||||
|
||||
gotNonNil := payload != nil
|
||||
if gotNonNil != tt.sendEarlyPayload {
|
||||
t.Errorf("sendEarlyPayload = %v but got earlyPayload = %T", tt.sendEarlyPayload, payload)
|
||||
}
|
||||
if payload != nil {
|
||||
if payload.NodeKeyChallenge != chalPrivate.Public() {
|
||||
t.Errorf("earlyPayload.NodeKeyChallenge = %v; want %v", payload.NodeKeyChallenge, chalPrivate.Public())
|
||||
}
|
||||
}
|
||||
},
|
||||
})
|
||||
req := must.Get(http.NewRequestWithContext(trace, "GET", "https://unused.example/", nil))
|
||||
|
||||
checkRes := func(t *testing.T, res *http.Response) {
|
||||
t.Helper()
|
||||
defer res.Body.Close()
|
||||
all, err := io.ReadAll(res.Body)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if string(all) != msg {
|
||||
t.Errorf("got response %q; want %q", all, msg)
|
||||
}
|
||||
}
|
||||
|
||||
// Verify we can do HTTP/2 against that conn.
|
||||
res, err := nc.Do(req)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
checkRes(t, res)
|
||||
|
||||
if !sawConn.Load() {
|
||||
t.Error("ClientTrace.GotConn never saw the *ts2021.Conn")
|
||||
}
|
||||
|
||||
// And try using the high-level nc.post API as well.
|
||||
res, err = nc.Post(context.Background(), "/", key.NodePublic{}, nil)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
checkRes(t, res)
|
||||
}
|
||||
|
||||
// Upgrader is an http.Handler that hijacks and upgrades POST-with-Upgrade
|
||||
// request to a Tailscale 2021 connection, then hands the resulting
|
||||
// controlbase.Conn off to h2srv.
|
||||
type Upgrader struct {
|
||||
// h2srv is that will handle requests after the
|
||||
// connection has been upgraded to HTTP/2-over-noise.
|
||||
h2srv *http2.Server
|
||||
|
||||
// httpBaseConfig is the http1 server config that h2srv is
|
||||
// associated with.
|
||||
httpBaseConfig *http.Server
|
||||
|
||||
logf logger.Logf
|
||||
|
||||
noiseKeyPriv key.MachinePrivate
|
||||
challenge key.ChallengePrivate
|
||||
|
||||
sendEarlyPayload bool
|
||||
}
|
||||
|
||||
func (up *Upgrader) ServeHTTP(w http.ResponseWriter, r *http.Request) {
|
||||
if up == nil || up.h2srv == nil {
|
||||
http.Error(w, "invalid server config", http.StatusServiceUnavailable)
|
||||
return
|
||||
}
|
||||
if r.URL.Path != "/ts2021" {
|
||||
http.Error(w, "ts2021 upgrader installed at wrong path", http.StatusBadGateway)
|
||||
return
|
||||
}
|
||||
if up.noiseKeyPriv.IsZero() {
|
||||
http.Error(w, "keys not available", http.StatusServiceUnavailable)
|
||||
return
|
||||
}
|
||||
|
||||
earlyWriteFn := func(protocolVersion int, w io.Writer) error {
|
||||
if !up.sendEarlyPayload {
|
||||
return nil
|
||||
}
|
||||
earlyJSON, err := json.Marshal(&tailcfg.EarlyNoise{
|
||||
NodeKeyChallenge: up.challenge.Public(),
|
||||
})
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
// 5 bytes that won't be mistaken for an HTTP/2 frame:
|
||||
// https://httpwg.org/specs/rfc7540.html#rfc.section.4.1 (Especially not
|
||||
// an HTTP/2 settings frame, which isn't of type 'T')
|
||||
var notH2Frame [5]byte
|
||||
copy(notH2Frame[:], EarlyPayloadMagic)
|
||||
var lenBuf [4]byte
|
||||
binary.BigEndian.PutUint32(lenBuf[:], uint32(len(earlyJSON)))
|
||||
// These writes are all buffered by caller, so fine to do them
|
||||
// separately:
|
||||
if _, err := w.Write(notH2Frame[:]); err != nil {
|
||||
return err
|
||||
}
|
||||
if _, err := w.Write(lenBuf[:]); err != nil {
|
||||
return err
|
||||
}
|
||||
if _, err := w.Write(earlyJSON[:]); err != nil {
|
||||
return err
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
cbConn, err := controlhttpserver.AcceptHTTP(r.Context(), w, r, up.noiseKeyPriv, earlyWriteFn)
|
||||
if err != nil {
|
||||
up.logf("controlhttp: Accept: %v", err)
|
||||
return
|
||||
}
|
||||
defer cbConn.Close()
|
||||
|
||||
up.h2srv.ServeConn(cbConn, &http2.ServeConnOpts{
|
||||
BaseConfig: up.httpBaseConfig,
|
||||
})
|
||||
}
|
||||
+18
-39
@@ -13,10 +13,8 @@ import (
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"io"
|
||||
"net/http"
|
||||
"sync"
|
||||
|
||||
"golang.org/x/net/http2"
|
||||
"tailscale.com/control/controlbase"
|
||||
"tailscale.com/tailcfg"
|
||||
)
|
||||
@@ -27,11 +25,11 @@ import (
|
||||
// the pool when the connection is closed, properly handles an optional "early
|
||||
// payload" that's sent prior to beginning the HTTP/2 session, and provides a
|
||||
// way to return a connection to a pool when the connection is closed.
|
||||
//
|
||||
// Use [NewConn] to build a new Conn if you want [Conn.GetEarlyPayload] to work.
|
||||
// Otherwise making a Conn directly, only setting Conn, is fine.
|
||||
type Conn struct {
|
||||
*controlbase.Conn
|
||||
id int
|
||||
onClose func(int)
|
||||
h2cc *http2.ClientConn
|
||||
|
||||
readHeaderOnce sync.Once // guards init of reader field
|
||||
reader io.Reader // (effectively Conn.Reader after header)
|
||||
@@ -40,31 +38,18 @@ type Conn struct {
|
||||
earlyPayloadErr error
|
||||
}
|
||||
|
||||
// New creates a new Conn that wraps the given controlbase.Conn.
|
||||
// NewConn creates a new Conn that wraps the given controlbase.Conn.
|
||||
//
|
||||
// h2t is the HTTP/2 transport to use for the connection; a new
|
||||
// http2.ClientConn will be created that reads from the returned Conn.
|
||||
//
|
||||
// connID should be a unique ID for this connection. When the Conn is closed,
|
||||
// the onClose function will be called with the connID if it is non-nil.
|
||||
func New(conn *controlbase.Conn, h2t *http2.Transport, connID int, onClose func(int)) (*Conn, error) {
|
||||
ncc := &Conn{
|
||||
func NewConn(conn *controlbase.Conn) *Conn {
|
||||
return &Conn{
|
||||
Conn: conn,
|
||||
id: connID,
|
||||
onClose: onClose,
|
||||
earlyPayloadReady: make(chan struct{}),
|
||||
}
|
||||
h2cc, err := h2t.NewClientConn(ncc)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
ncc.h2cc = h2cc
|
||||
return ncc, nil
|
||||
}
|
||||
|
||||
// RoundTrip implements the http.RoundTripper interface.
|
||||
func (c *Conn) RoundTrip(r *http.Request) (*http.Response, error) {
|
||||
return c.h2cc.RoundTrip(r)
|
||||
}
|
||||
|
||||
// GetEarlyPayload waits for the early Noise payload to arrive.
|
||||
@@ -74,6 +59,15 @@ func (c *Conn) RoundTrip(r *http.Request) (*http.Response, error) {
|
||||
// early Noise payload is ready (if any) and will return the same result for
|
||||
// the lifetime of the Conn.
|
||||
func (c *Conn) GetEarlyPayload(ctx context.Context) (*tailcfg.EarlyNoise, error) {
|
||||
if c.earlyPayloadReady == nil {
|
||||
return nil, errors.New("Conn was not created with NewConn; early payload not supported")
|
||||
}
|
||||
select {
|
||||
case <-c.earlyPayloadReady:
|
||||
return c.earlyPayload, c.earlyPayloadErr
|
||||
default:
|
||||
go c.readHeaderOnce.Do(c.readHeader)
|
||||
}
|
||||
select {
|
||||
case <-c.earlyPayloadReady:
|
||||
return c.earlyPayload, c.earlyPayloadErr
|
||||
@@ -82,12 +76,6 @@ func (c *Conn) GetEarlyPayload(ctx context.Context) (*tailcfg.EarlyNoise, error)
|
||||
}
|
||||
}
|
||||
|
||||
// CanTakeNewRequest reports whether the underlying HTTP/2 connection can take
|
||||
// a new request, meaning it has not been closed or received or sent a GOAWAY.
|
||||
func (c *Conn) CanTakeNewRequest() bool {
|
||||
return c.h2cc.CanTakeNewRequest()
|
||||
}
|
||||
|
||||
// The first 9 bytes from the server to client over Noise are either an HTTP/2
|
||||
// settings frame (a normal HTTP/2 setup) or, as we added later, an "early payload"
|
||||
// header that's also 9 bytes long: 5 bytes (EarlyPayloadMagic) followed by 4 bytes
|
||||
@@ -122,7 +110,9 @@ func (c *Conn) Read(p []byte) (n int, err error) {
|
||||
// c.earlyPayload, closing c.earlyPayloadReady, and initializing c.reader for
|
||||
// future reads.
|
||||
func (c *Conn) readHeader() {
|
||||
defer close(c.earlyPayloadReady)
|
||||
if c.earlyPayloadReady != nil {
|
||||
defer close(c.earlyPayloadReady)
|
||||
}
|
||||
|
||||
setErr := func(err error) {
|
||||
c.reader = returnErrReader{err}
|
||||
@@ -156,14 +146,3 @@ func (c *Conn) readHeader() {
|
||||
}
|
||||
c.reader = c.Conn
|
||||
}
|
||||
|
||||
// Close closes the connection.
|
||||
func (c *Conn) Close() error {
|
||||
if err := c.Conn.Close(); err != nil {
|
||||
return err
|
||||
}
|
||||
if c.onClose != nil {
|
||||
c.onClose(c.id)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user