control/tsp, cmd/tsp: add low-level Tailscale protocol client and tool
Add a new control/tsp package providing a client for speaking the Tailscale protocol to a coordination server over Noise, along with a cmd/tsp binary exposing it as a low-level composable tool for generating keys, registering nodes, and issuing map requests. Previously developed out-of-tree at github.com/bradfitz/tsp; imported here without git history. Updates #12542 Change-Id: I6ad21143c4aefe8939d4a46ae65b2184173bf69f Signed-off-by: Brad Fitzpatrick <bradfitz@tailscale.com>
This commit is contained in:
committed by
Brad Fitzpatrick
parent
69572c7435
commit
50d7176333
@@ -0,0 +1,339 @@
|
||||
// Copyright (c) Tailscale Inc & contributors
|
||||
// SPDX-License-Identifier: BSD-3-Clause
|
||||
|
||||
package tsp
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"cmp"
|
||||
"context"
|
||||
"encoding/binary"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
"net/http"
|
||||
"strings"
|
||||
"sync"
|
||||
"sync/atomic"
|
||||
|
||||
"github.com/klauspost/compress/zstd"
|
||||
"tailscale.com/control/ts2021"
|
||||
"tailscale.com/tailcfg"
|
||||
"tailscale.com/types/key"
|
||||
)
|
||||
|
||||
// errSessionClosed is returned by [MapSession.Next] and
|
||||
// [MapSession.NextInto] when called after [MapSession.Close].
|
||||
var errSessionClosed = errors.New("tsp: map session closed")
|
||||
|
||||
// DefaultMaxMessageSize is the default cap, in bytes, on the size of a
|
||||
// single compressed map response frame. See [MapOpts.MaxMessageSize].
|
||||
const DefaultMaxMessageSize = 4 << 20
|
||||
|
||||
// zstdDecoderPool is a pool of *zstd.Decoder reused across MapSessions to
|
||||
// amortize the cost of setting up zstd state. Decoders are returned via
|
||||
// [MapSession.Close]; entries are reclaimed by the runtime under memory
|
||||
// pressure via sync.Pool semantics.
|
||||
var zstdDecoderPool sync.Pool // of *zstd.Decoder
|
||||
|
||||
// MapOpts contains options for sending a map request.
|
||||
type MapOpts struct {
|
||||
// NodeKey is the node's private key. Required.
|
||||
NodeKey key.NodePrivate
|
||||
|
||||
// Hostinfo is the host information to send. Optional;
|
||||
// if nil, a minimal default is used.
|
||||
Hostinfo *tailcfg.Hostinfo
|
||||
|
||||
// Stream is whether to receive multiple MapResponses over
|
||||
// the same HTTP connection.
|
||||
Stream bool
|
||||
|
||||
// OmitPeers is whether the client is okay with the Peers list
|
||||
// being omitted in the response.
|
||||
OmitPeers bool
|
||||
|
||||
// MaxMessageSize is the maximum size in bytes of any single
|
||||
// compressed map response frame on the wire. If zero,
|
||||
// [DefaultMaxMessageSize] is used.
|
||||
MaxMessageSize int64
|
||||
}
|
||||
|
||||
// framedReader is an io.Reader that consumes a stream of length-prefixed
|
||||
// frames (each a little-endian uint32 length followed by that many bytes)
|
||||
// from r and yields only the frame payloads back-to-back.
|
||||
//
|
||||
// This lets us feed the concatenated zstd frames from our wire protocol
|
||||
// into a single streaming zstd decoder. Zstd's file format permits
|
||||
// concatenation (RFC 8478 §2), and klauspost's decoder handles it
|
||||
// transparently.
|
||||
//
|
||||
// If onNewFrame is non-nil, it is called after each new 4-byte length
|
||||
// header is successfully read. Used to reset the per-message decoded-size
|
||||
// budget downstream.
|
||||
type framedReader struct {
|
||||
r io.Reader
|
||||
maxSize int64 // per-frame compressed-size cap
|
||||
remain int // bytes remaining in the current frame
|
||||
onNewFrame func()
|
||||
}
|
||||
|
||||
func (f *framedReader) Read(p []byte) (int, error) {
|
||||
if f.remain == 0 {
|
||||
var hdr [4]byte
|
||||
if _, err := io.ReadFull(f.r, hdr[:]); err != nil {
|
||||
return 0, err
|
||||
}
|
||||
sz := int64(binary.LittleEndian.Uint32(hdr[:]))
|
||||
if sz == 0 {
|
||||
return 0, fmt.Errorf("map response: zero-length frame")
|
||||
}
|
||||
if sz > f.maxSize {
|
||||
return 0, fmt.Errorf("map response frame size %d exceeds max %d", sz, f.maxSize)
|
||||
}
|
||||
f.remain = int(sz)
|
||||
if f.onNewFrame != nil {
|
||||
f.onNewFrame()
|
||||
}
|
||||
}
|
||||
if len(p) > f.remain {
|
||||
p = p[:f.remain]
|
||||
}
|
||||
n, err := f.r.Read(p)
|
||||
f.remain -= n
|
||||
return n, err
|
||||
}
|
||||
|
||||
// boundedReader is an io.Reader that yields at most remain bytes from r
|
||||
// before returning an error. Call reset to raise the budget back to max,
|
||||
// typically at a new message boundary.
|
||||
//
|
||||
// Used to cap the decoded size of a single map response so a malicious
|
||||
// server can't send a small zstd frame that explodes into gigabytes of
|
||||
// junk for the json.Decoder to consume.
|
||||
type boundedReader struct {
|
||||
r io.Reader
|
||||
max int64
|
||||
remain int64
|
||||
}
|
||||
|
||||
func (b *boundedReader) Read(p []byte) (int, error) {
|
||||
if b.remain <= 0 {
|
||||
return 0, fmt.Errorf("map response decoded size exceeds max %d", b.max)
|
||||
}
|
||||
if int64(len(p)) > b.remain {
|
||||
p = p[:b.remain]
|
||||
}
|
||||
n, err := b.r.Read(p)
|
||||
b.remain -= int64(n)
|
||||
return n, err
|
||||
}
|
||||
|
||||
func (b *boundedReader) reset() { b.remain = b.max }
|
||||
|
||||
// MapSession wraps an in-progress map response stream. Call Next to read
|
||||
// each MapResponse. Call Close when done.
|
||||
type MapSession struct {
|
||||
res *http.Response
|
||||
stream bool
|
||||
noiseDoer func(*http.Request) (*http.Response, error)
|
||||
|
||||
// inNext detects concurrent NextInto callers. It CAS-flips
|
||||
// false→true on entry and back to false on exit; a failed CAS
|
||||
// panics, akin to how the Go runtime detects concurrent map
|
||||
// access. It does not serialize Close vs. NextInto; that's
|
||||
// nextMu's job.
|
||||
inNext atomic.Bool
|
||||
|
||||
// nextMu is held while [MapSession.NextInto] is running jdec.Decode,
|
||||
// so that Close can wait for an in-flight Decode to unwind before it
|
||||
// touches zdec (Reset, pool-Put) and avoids racing with the running
|
||||
// Read chain that Decode drives.
|
||||
nextMu sync.Mutex
|
||||
read int // guarded by nextMu
|
||||
closed bool // guarded by nextMu
|
||||
zdec *zstd.Decoder // reads from a framedReader wrapping res.Body
|
||||
jdec *json.Decoder // reads decompressed JSON from zdec
|
||||
|
||||
closeOnce sync.Once
|
||||
closeErr error
|
||||
}
|
||||
|
||||
// NoiseRoundTrip sends an HTTP request over the Noise channel used by this map session.
|
||||
func (s *MapSession) NoiseRoundTrip(req *http.Request) (*http.Response, error) {
|
||||
return s.noiseDoer(req)
|
||||
}
|
||||
|
||||
// Next reads and returns the next MapResponse from the stream.
|
||||
// For non-streaming sessions, the first call returns the single response
|
||||
// and subsequent calls return io.EOF.
|
||||
// For streaming sessions, Next blocks until the next response arrives
|
||||
// or the server closes the connection.
|
||||
//
|
||||
// Each call allocates a fresh MapResponse. Callers that want to amortize
|
||||
// the allocation across calls can use [MapSession.NextInto].
|
||||
//
|
||||
// Next and NextInto are not safe to call concurrently from multiple
|
||||
// goroutines on the same [MapSession]; a concurrent call panics, akin
|
||||
// to the Go runtime's concurrent map access detection. [MapSession.Close]
|
||||
// may be called concurrently to abort an in-flight Next.
|
||||
func (s *MapSession) Next() (*tailcfg.MapResponse, error) {
|
||||
var resp tailcfg.MapResponse
|
||||
if err := s.NextInto(&resp); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return &resp, nil
|
||||
}
|
||||
|
||||
// NextInto is like [MapSession.Next] but decodes the next MapResponse into
|
||||
// the caller-supplied *resp rather than allocating a new one. The pointer's
|
||||
// pointee is zeroed before decoding so fields from a prior response do not
|
||||
// persist.
|
||||
//
|
||||
// For non-streaming sessions, the first call decodes the single response
|
||||
// and subsequent calls return io.EOF.
|
||||
// For streaming sessions, NextInto blocks until the next response arrives
|
||||
// or the server closes the connection.
|
||||
//
|
||||
// See [MapSession.Next] for concurrency rules; those apply to NextInto too.
|
||||
func (s *MapSession) NextInto(resp *tailcfg.MapResponse) error {
|
||||
if !s.inNext.CompareAndSwap(false, true) {
|
||||
panic("tsp: invalid concurrent call to MapSession.Next/NextInto")
|
||||
}
|
||||
defer s.inNext.Store(false)
|
||||
|
||||
s.nextMu.Lock()
|
||||
defer s.nextMu.Unlock()
|
||||
if s.closed {
|
||||
return errSessionClosed
|
||||
}
|
||||
if !s.stream && s.read > 0 {
|
||||
return io.EOF
|
||||
}
|
||||
*resp = tailcfg.MapResponse{}
|
||||
if err := s.jdec.Decode(resp); err != nil {
|
||||
return err
|
||||
}
|
||||
s.read++
|
||||
return nil
|
||||
}
|
||||
|
||||
// Close returns the session's zstd decoder to the pool and closes the
|
||||
// underlying HTTP response body. It is safe to call Close multiple times
|
||||
// and from multiple goroutines, including while a [MapSession.Next] or
|
||||
// [MapSession.NextInto] call is in flight on another goroutine (which
|
||||
// will return an error once the body close propagates).
|
||||
func (s *MapSession) Close() error {
|
||||
// Callers are likely to race a deferred Close with a time.AfterFunc
|
||||
// timeout (or similar) Close that aborts a hung Next. Without the
|
||||
// Once, both Closes would Put the same *zstd.Decoder into the pool,
|
||||
// corrupting it, and the Reset/Put in one would race with the
|
||||
// zdec.Read that the hung Next is driving.
|
||||
//
|
||||
// Ordering inside the Once: close the body first to unblock any
|
||||
// in-flight NextInto (its Read chain ends at res.Body and will
|
||||
// return an error once it's closed). That lets NextInto unwind and
|
||||
// release nextMu. Only then do we take nextMu ourselves and touch
|
||||
// zdec, which is safe because no goroutine is still reading from
|
||||
// it. Acquiring nextMu before closing the body would deadlock
|
||||
// against a hung NextInto.
|
||||
s.closeOnce.Do(func() {
|
||||
s.closeErr = s.res.Body.Close()
|
||||
s.nextMu.Lock()
|
||||
defer s.nextMu.Unlock()
|
||||
s.closed = true
|
||||
s.zdec.Reset(nil)
|
||||
zstdDecoderPool.Put(s.zdec)
|
||||
})
|
||||
return s.closeErr
|
||||
}
|
||||
|
||||
// Map sends a map request to the coordination server and returns a MapSession
|
||||
// for reading the framed, zstd-compressed response(s).
|
||||
func (c *Client) Map(ctx context.Context, opts MapOpts) (*MapSession, error) {
|
||||
if opts.NodeKey.IsZero() {
|
||||
return nil, fmt.Errorf("NodeKey is required")
|
||||
}
|
||||
|
||||
hi := opts.Hostinfo
|
||||
if hi == nil {
|
||||
hi = defaultHostinfo()
|
||||
}
|
||||
|
||||
mapReq := tailcfg.MapRequest{
|
||||
Version: tailcfg.CurrentCapabilityVersion,
|
||||
NodeKey: opts.NodeKey.Public(),
|
||||
Hostinfo: hi,
|
||||
Stream: opts.Stream,
|
||||
Compress: "zstd",
|
||||
OmitPeers: opts.OmitPeers,
|
||||
// Streaming requires the server to track us as "connected",
|
||||
// which in turn requires ReadOnly=false. Non-streaming polls
|
||||
// stay ReadOnly to minimize side effects.
|
||||
ReadOnly: !opts.Stream,
|
||||
}
|
||||
|
||||
body, err := json.Marshal(mapReq)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("encoding map request: %w", err)
|
||||
}
|
||||
|
||||
nc, err := c.noiseClient(ctx)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("establishing noise connection: %w", err)
|
||||
}
|
||||
|
||||
url := c.serverURL + "/machine/map"
|
||||
url = strings.Replace(url, "http:", "https:", 1)
|
||||
req, err := http.NewRequestWithContext(ctx, "POST", url, bytes.NewReader(body))
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("creating map request: %w", err)
|
||||
}
|
||||
ts2021.AddLBHeader(req, opts.NodeKey.Public())
|
||||
|
||||
res, err := nc.Do(req)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("map request: %w", err)
|
||||
}
|
||||
|
||||
if res.StatusCode != 200 {
|
||||
msg, _ := io.ReadAll(res.Body)
|
||||
res.Body.Close()
|
||||
return nil, fmt.Errorf("map request: http %d: %.200s",
|
||||
res.StatusCode, strings.TrimSpace(string(msg)))
|
||||
}
|
||||
|
||||
maxMessageSize := cmp.Or(opts.MaxMessageSize, DefaultMaxMessageSize)
|
||||
bounded := &boundedReader{max: maxMessageSize, remain: maxMessageSize}
|
||||
fr := &framedReader{
|
||||
r: res.Body,
|
||||
maxSize: maxMessageSize,
|
||||
onNewFrame: bounded.reset,
|
||||
}
|
||||
|
||||
zdec, _ := zstdDecoderPool.Get().(*zstd.Decoder)
|
||||
if zdec != nil {
|
||||
if err := zdec.Reset(fr); err != nil {
|
||||
// Reset can fail if the previous stream is in a bad state; drop
|
||||
// the decoder and create a fresh one.
|
||||
zdec = nil
|
||||
}
|
||||
}
|
||||
if zdec == nil {
|
||||
zdec, err = zstd.NewReader(fr, zstd.WithDecoderConcurrency(1))
|
||||
if err != nil {
|
||||
res.Body.Close()
|
||||
return nil, fmt.Errorf("creating zstd decoder: %w", err)
|
||||
}
|
||||
}
|
||||
bounded.r = zdec
|
||||
|
||||
return &MapSession{
|
||||
res: res,
|
||||
stream: opts.Stream,
|
||||
noiseDoer: nc.Do,
|
||||
zdec: zdec,
|
||||
jdec: json.NewDecoder(bounded),
|
||||
}, nil
|
||||
}
|
||||
@@ -0,0 +1,270 @@
|
||||
// Copyright (c) Tailscale Inc & contributors
|
||||
// SPDX-License-Identifier: BSD-3-Clause
|
||||
|
||||
package tsp
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"encoding/binary"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"io"
|
||||
"net/http/httptest"
|
||||
"strings"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/klauspost/compress/zstd"
|
||||
"tailscale.com/tailcfg"
|
||||
"tailscale.com/tstest/integration/testcontrol"
|
||||
"tailscale.com/types/key"
|
||||
)
|
||||
|
||||
func TestMapAgainstTestControl(t *testing.T) {
|
||||
ctrl := &testcontrol.Server{}
|
||||
ctrl.HTTPTestServer = httptest.NewUnstartedServer(ctrl)
|
||||
ctrl.HTTPTestServer.Start()
|
||||
t.Cleanup(ctrl.HTTPTestServer.Close)
|
||||
baseURL := ctrl.HTTPTestServer.URL
|
||||
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second)
|
||||
defer cancel()
|
||||
|
||||
serverKey, err := DiscoverServerKey(ctx, baseURL)
|
||||
if err != nil {
|
||||
t.Fatalf("DiscoverServerKey: %v", err)
|
||||
}
|
||||
|
||||
register := func(hostname string) (nodeKey key.NodePrivate, machineKey key.MachinePrivate) {
|
||||
t.Helper()
|
||||
nodeKey = key.NewNode()
|
||||
machineKey = key.NewMachine()
|
||||
c, err := NewClient(ClientOpts{
|
||||
ServerURL: baseURL,
|
||||
MachineKey: machineKey,
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatalf("NewClient %s: %v", hostname, err)
|
||||
}
|
||||
defer c.Close()
|
||||
c.SetControlPublicKey(serverKey)
|
||||
if _, err := c.Register(ctx, RegisterOpts{
|
||||
NodeKey: nodeKey,
|
||||
Hostinfo: &tailcfg.Hostinfo{Hostname: hostname},
|
||||
}); err != nil {
|
||||
t.Fatalf("Register %s: %v", hostname, err)
|
||||
}
|
||||
return nodeKey, machineKey
|
||||
}
|
||||
|
||||
nodeKeyA, machineKeyA := register("a")
|
||||
nodeKeyB, _ := register("b")
|
||||
|
||||
clientA, err := NewClient(ClientOpts{
|
||||
ServerURL: baseURL,
|
||||
MachineKey: machineKeyA,
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatalf("NewClient A: %v", err)
|
||||
}
|
||||
defer clientA.Close()
|
||||
clientA.SetControlPublicKey(serverKey)
|
||||
|
||||
session, err := clientA.Map(ctx, MapOpts{
|
||||
NodeKey: nodeKeyA,
|
||||
Hostinfo: &tailcfg.Hostinfo{Hostname: "a"},
|
||||
Stream: true,
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatalf("Map: %v", err)
|
||||
}
|
||||
defer session.Close()
|
||||
|
||||
// nextNonKeepalive returns the next non-keepalive MapResponse, to keep
|
||||
// the test robust if a server-side keepalive arrives mid-test.
|
||||
nextNonKeepalive := func() *tailcfg.MapResponse {
|
||||
t.Helper()
|
||||
for {
|
||||
resp, err := session.Next()
|
||||
if err != nil {
|
||||
t.Fatalf("session.Next: %v", err)
|
||||
}
|
||||
if resp.KeepAlive {
|
||||
continue
|
||||
}
|
||||
return resp
|
||||
}
|
||||
}
|
||||
|
||||
// First MapResponse: expect node A as self and node B in Peers.
|
||||
first := nextNonKeepalive()
|
||||
if first.Node == nil {
|
||||
t.Fatal("first response has nil Node")
|
||||
}
|
||||
if got, want := first.Node.Key, nodeKeyA.Public(); got != want {
|
||||
t.Errorf("first Node.Key = %v, want %v", got, want)
|
||||
}
|
||||
var foundB bool
|
||||
for _, p := range first.Peers {
|
||||
if p.Key == nodeKeyB.Public() {
|
||||
foundB = true
|
||||
break
|
||||
}
|
||||
}
|
||||
if !foundB {
|
||||
t.Errorf("peer B (%v) not in first response's Peers (%d peers)", nodeKeyB.Public(), len(first.Peers))
|
||||
}
|
||||
|
||||
// Inject raw MapResponses and verify they come out the reader, in order.
|
||||
// msgToSend is single-slot, so we must consume each before injecting the next.
|
||||
for i := range 3 {
|
||||
want := fmt.Sprintf("injected-%d.example.com", i)
|
||||
inject := &tailcfg.MapResponse{Domain: want}
|
||||
if !ctrl.AddRawMapResponse(nodeKeyA.Public(), inject) {
|
||||
t.Fatalf("AddRawMapResponse %d: node not connected", i)
|
||||
}
|
||||
got := nextNonKeepalive()
|
||||
if got.Domain != want {
|
||||
t.Errorf("injected %d: got Domain=%q, want %q", i, got.Domain, want)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// newTestPipeline builds the same framedReader → zstd → boundedReader →
|
||||
// json.Decoder pipeline that [Client.Map] builds for a live session, but
|
||||
// feeds it from a raw byte slice. Returned jdec can be used with Decode to
|
||||
// pull out MapResponses.
|
||||
func newTestPipeline(t testing.TB, wire []byte, maxMessageSize int64) *json.Decoder {
|
||||
t.Helper()
|
||||
bounded := &boundedReader{max: maxMessageSize, remain: maxMessageSize}
|
||||
fr := &framedReader{
|
||||
r: bytes.NewReader(wire),
|
||||
maxSize: maxMessageSize,
|
||||
onNewFrame: bounded.reset,
|
||||
}
|
||||
zdec, err := zstd.NewReader(fr, zstd.WithDecoderConcurrency(1))
|
||||
if err != nil {
|
||||
t.Fatalf("zstd.NewReader: %v", err)
|
||||
}
|
||||
t.Cleanup(zdec.Close)
|
||||
bounded.r = zdec
|
||||
return json.NewDecoder(bounded)
|
||||
}
|
||||
|
||||
// zstdFrame returns a zstd-compressed frame of b.
|
||||
func zstdFrame(t testing.TB, b []byte) []byte {
|
||||
t.Helper()
|
||||
enc, err := zstd.NewWriter(io.Discard, zstd.WithEncoderConcurrency(1))
|
||||
if err != nil {
|
||||
t.Fatalf("zstd.NewWriter: %v", err)
|
||||
}
|
||||
defer enc.Close()
|
||||
return enc.EncodeAll(b, nil)
|
||||
}
|
||||
|
||||
// wireFrame writes a 4-byte little-endian length prefix plus payload to buf.
|
||||
func wireFrame(buf *bytes.Buffer, payload []byte) {
|
||||
var hdr [4]byte
|
||||
binary.LittleEndian.PutUint32(hdr[:], uint32(len(payload)))
|
||||
buf.Write(hdr[:])
|
||||
buf.Write(payload)
|
||||
}
|
||||
|
||||
// TestMapFrameSizeTooLarge verifies that a 4-byte length prefix claiming
|
||||
// a frame larger than the configured cap is rejected before any payload
|
||||
// bytes are read from the stream.
|
||||
func TestMapFrameSizeTooLarge(t *testing.T) {
|
||||
const max = 4 << 20
|
||||
var wire bytes.Buffer
|
||||
var hdr [4]byte
|
||||
binary.LittleEndian.PutUint32(hdr[:], (max + 1))
|
||||
wire.Write(hdr[:])
|
||||
|
||||
jdec := newTestPipeline(t, wire.Bytes(), max)
|
||||
var resp tailcfg.MapResponse
|
||||
err := jdec.Decode(&resp)
|
||||
if err == nil {
|
||||
t.Fatal("Decode: got nil error, want frame-too-large")
|
||||
}
|
||||
if !strings.Contains(err.Error(), "exceeds max") {
|
||||
t.Errorf("Decode error = %q, want one containing %q", err, "exceeds max")
|
||||
}
|
||||
}
|
||||
|
||||
// TestMapDecodedSizeTooLarge verifies that a small on-wire frame (well
|
||||
// under the cap) which decompresses into a huge JSON payload is rejected.
|
||||
// This is the "zstd bomb" case: a tiny compressed frame that would
|
||||
// explode into a huge decoded payload for json.Decoder to consume.
|
||||
func TestMapDecodedSizeTooLarge(t *testing.T) {
|
||||
const max = 4 << 20
|
||||
big := strings.Repeat("a", 5<<20) // 5 MiB of 'a'
|
||||
raw, err := json.Marshal(&tailcfg.MapResponse{Domain: big})
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if int64(len(raw)) <= max {
|
||||
t.Fatalf("raw JSON unexpectedly small: %d", len(raw))
|
||||
}
|
||||
compressed := zstdFrame(t, raw)
|
||||
if int64(len(compressed)) >= max {
|
||||
t.Fatalf("compressed too large (%d); test needs a more compressible payload", len(compressed))
|
||||
}
|
||||
|
||||
var wire bytes.Buffer
|
||||
wireFrame(&wire, compressed)
|
||||
|
||||
jdec := newTestPipeline(t, wire.Bytes(), max)
|
||||
var resp tailcfg.MapResponse
|
||||
err = jdec.Decode(&resp)
|
||||
if err == nil {
|
||||
t.Fatal("Decode: got nil error, want decoded-size-exceeded")
|
||||
}
|
||||
if !strings.Contains(err.Error(), "exceeds max") {
|
||||
t.Errorf("Decode error = %q, want one containing %q", err, "exceeds max")
|
||||
}
|
||||
}
|
||||
|
||||
// TestMapBudgetResetsBetweenFrames verifies that the per-message decoded
|
||||
// budget is reset at each new frame boundary. Two consecutive 3-MiB frames
|
||||
// should both decode successfully under a 4-MiB per-frame cap. Without the
|
||||
// reset, the second frame would fail (remaining budget after frame 1 =
|
||||
// 4MiB - 3MiB = 1MiB, and we'd try to read 3MiB more).
|
||||
func TestMapBudgetResetsBetweenFrames(t *testing.T) {
|
||||
const max = 4 << 20
|
||||
payload := strings.Repeat("a", 3<<20)
|
||||
r1 := &tailcfg.MapResponse{Domain: payload + "-one"}
|
||||
r2 := &tailcfg.MapResponse{Domain: payload + "-two"}
|
||||
|
||||
var wire bytes.Buffer
|
||||
for _, r := range []*tailcfg.MapResponse{r1, r2} {
|
||||
raw, err := json.Marshal(r)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if int64(len(raw)) >= max {
|
||||
t.Fatalf("raw JSON size %d >= max %d; would fail budget check by itself", len(raw), max)
|
||||
}
|
||||
compressed := zstdFrame(t, raw)
|
||||
if int64(len(compressed)) >= max {
|
||||
t.Fatalf("compressed size %d >= max %d", len(compressed), max)
|
||||
}
|
||||
wireFrame(&wire, compressed)
|
||||
}
|
||||
|
||||
jdec := newTestPipeline(t, wire.Bytes(), max)
|
||||
|
||||
var got1, got2 tailcfg.MapResponse
|
||||
if err := jdec.Decode(&got1); err != nil {
|
||||
t.Fatalf("first Decode: %v", err)
|
||||
}
|
||||
if got1.Domain != r1.Domain {
|
||||
t.Errorf("first Domain mismatch (len %d vs %d)", len(got1.Domain), len(r1.Domain))
|
||||
}
|
||||
if err := jdec.Decode(&got2); err != nil {
|
||||
t.Fatalf("second Decode: %v", err)
|
||||
}
|
||||
if got2.Domain != r2.Domain {
|
||||
t.Errorf("second Domain mismatch (len %d vs %d)", len(got2.Domain), len(r2.Domain))
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,105 @@
|
||||
// Copyright (c) Tailscale Inc & contributors
|
||||
// SPDX-License-Identifier: BSD-3-Clause
|
||||
|
||||
package tsp
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"os"
|
||||
|
||||
"tailscale.com/types/key"
|
||||
)
|
||||
|
||||
// ServerInfo identifies a coordination server by its URL and Noise public key.
|
||||
type ServerInfo struct {
|
||||
// URL is the base URL of the coordination server, without any path
|
||||
// (e.g. "https://controlplane.tailscale.com").
|
||||
//
|
||||
// There is no default value; a URL must always be supplied.
|
||||
URL string `json:"server_url"`
|
||||
|
||||
// Key is the server's Noise public key, used to establish an encrypted
|
||||
// channel between the client and the coordination server.
|
||||
Key key.MachinePublic `json:"server_key"`
|
||||
}
|
||||
|
||||
// NodeFile is the JSON structure for a node credentials file. It contains
|
||||
// the private keys that authenticate a node to a coordination server.
|
||||
//
|
||||
// Example:
|
||||
//
|
||||
// {
|
||||
// "node_key": "privkey:...",
|
||||
// "machine_key": "privkey:...",
|
||||
// "server_url": "https://controlplane.tailscale.com",
|
||||
// "server_key": "mkey:..."
|
||||
// }
|
||||
//
|
||||
// Note that node and machine private keys share the same "privkey:"
|
||||
// textual form; they are disambiguated by the surrounding JSON field
|
||||
// names rather than by any prefix in the key itself.
|
||||
type NodeFile struct {
|
||||
// NodeKey is the node's WireGuard private key. The corresponding
|
||||
// public key identifies this node to other peers.
|
||||
NodeKey key.NodePrivate `json:"node_key"`
|
||||
|
||||
// MachineKey is the machine's private key. It authenticates this
|
||||
// machine to the coordination server over Noise.
|
||||
MachineKey key.MachinePrivate `json:"machine_key"`
|
||||
|
||||
ServerInfo // server_url and server_key
|
||||
}
|
||||
|
||||
// ReadNodeFile reads and parses a node JSON file.
|
||||
func ReadNodeFile(path string) (NodeFile, error) {
|
||||
data, err := os.ReadFile(path)
|
||||
if err != nil {
|
||||
return NodeFile{}, err
|
||||
}
|
||||
var nf NodeFile
|
||||
if err := json.Unmarshal(data, &nf); err != nil {
|
||||
return NodeFile{}, fmt.Errorf("parsing node file %q: %w", path, err)
|
||||
}
|
||||
return nf, nil
|
||||
}
|
||||
|
||||
// WriteNodeFile writes a node JSON file. The file is created with mode 0600.
|
||||
func WriteNodeFile(path string, nf NodeFile) error {
|
||||
if err := nf.Check(); err != nil {
|
||||
return fmt.Errorf("invalid NodeFile: %w", err)
|
||||
}
|
||||
return os.WriteFile(path, nf.AsJSON(), 0600)
|
||||
}
|
||||
|
||||
// AsJSON returns nf as a pretty-printed JSON object, terminated by a newline.
|
||||
//
|
||||
// It always succeeds and always returns a valid JSON object. It does not
|
||||
// validate that the fields of nf are non-zero; it is the caller's
|
||||
// responsibility to call [NodeFile.Check] first if they want to reject
|
||||
// incomplete NodeFiles.
|
||||
func (nf NodeFile) AsJSON() []byte {
|
||||
out, err := json.MarshalIndent(nf, "", " ")
|
||||
if err != nil {
|
||||
panic(fmt.Sprintf("NodeFile.AsJSON: %v", err)) // unreachable: all fields marshal successfully
|
||||
}
|
||||
return append(out, '\n')
|
||||
}
|
||||
|
||||
// Check reports whether nf has all required fields set.
|
||||
// It returns an error describing the first zero-valued field, if any.
|
||||
func (nf NodeFile) Check() error {
|
||||
if nf.NodeKey.IsZero() {
|
||||
return fmt.Errorf("node_key is missing")
|
||||
}
|
||||
if nf.MachineKey.IsZero() {
|
||||
return fmt.Errorf("machine_key is missing")
|
||||
}
|
||||
if nf.URL == "" {
|
||||
return fmt.Errorf("server_url is missing")
|
||||
}
|
||||
if nf.ServerInfo.Key.IsZero() {
|
||||
return fmt.Errorf("server_key is missing")
|
||||
}
|
||||
return nil
|
||||
}
|
||||
@@ -0,0 +1,116 @@
|
||||
// Copyright (c) Tailscale Inc & contributors
|
||||
// SPDX-License-Identifier: BSD-3-Clause
|
||||
|
||||
package tsp
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"testing"
|
||||
|
||||
"tailscale.com/types/key"
|
||||
)
|
||||
|
||||
func TestNodeFileRoundTrip(t *testing.T) {
|
||||
dir := t.TempDir()
|
||||
path := filepath.Join(dir, "node.json")
|
||||
|
||||
nf := NodeFile{
|
||||
NodeKey: key.NewNode(),
|
||||
MachineKey: key.NewMachine(),
|
||||
ServerInfo: ServerInfo{
|
||||
URL: "https://controlplane.tailscale.com",
|
||||
Key: key.NewMachine().Public(),
|
||||
},
|
||||
}
|
||||
|
||||
if err := WriteNodeFile(path, nf); err != nil {
|
||||
t.Fatalf("WriteNodeFile: %v", err)
|
||||
}
|
||||
|
||||
got, err := ReadNodeFile(path)
|
||||
if err != nil {
|
||||
t.Fatalf("ReadNodeFile: %v", err)
|
||||
}
|
||||
if !got.NodeKey.Equal(nf.NodeKey) {
|
||||
t.Errorf("node key mismatch")
|
||||
}
|
||||
if !got.MachineKey.Equal(nf.MachineKey) {
|
||||
t.Errorf("machine key mismatch")
|
||||
}
|
||||
if got.URL != nf.URL {
|
||||
t.Errorf("server URL = %q, want %q", got.URL, nf.URL)
|
||||
}
|
||||
if got.ServerInfo.Key != nf.ServerInfo.Key {
|
||||
t.Errorf("server key mismatch")
|
||||
}
|
||||
}
|
||||
|
||||
// TestNodeFileFormat verifies that ReadNodeFile can parse a fixed JSON literal,
|
||||
// ensuring we don't accidentally change the on-disk format.
|
||||
func TestNodeFileFormat(t *testing.T) {
|
||||
const fileContents = `{
|
||||
"node_key": "privkey:0123456789abcdef0123456789abcdef0123456789abcdef0123456789abcdef",
|
||||
"machine_key": "privkey:fedcba9876543210fedcba9876543210fedcba9876543210fedcba9876543210",
|
||||
"server_url": "https://controlplane.tailscale.com",
|
||||
"server_key": "mkey:1111111111111111111111111111111111111111111111111111111111111111"
|
||||
}`
|
||||
dir := t.TempDir()
|
||||
path := filepath.Join(dir, "node.json")
|
||||
if err := os.WriteFile(path, []byte(fileContents), 0600); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
nf, err := ReadNodeFile(path)
|
||||
if err != nil {
|
||||
t.Fatalf("ReadNodeFile: %v", err)
|
||||
}
|
||||
if nf.NodeKey.IsZero() {
|
||||
t.Error("node key is zero")
|
||||
}
|
||||
if nf.MachineKey.IsZero() {
|
||||
t.Error("machine key is zero")
|
||||
}
|
||||
if nf.URL != "https://controlplane.tailscale.com" {
|
||||
t.Errorf("server URL = %q", nf.URL)
|
||||
}
|
||||
if nf.ServerInfo.Key.IsZero() {
|
||||
t.Error("server key is zero")
|
||||
}
|
||||
}
|
||||
|
||||
// TestNodeFileWriteFormat verifies that WriteNodeFile produces the expected
|
||||
// JSON field names.
|
||||
func TestNodeFileWriteFormat(t *testing.T) {
|
||||
dir := t.TempDir()
|
||||
path := filepath.Join(dir, "node.json")
|
||||
|
||||
nf := NodeFile{
|
||||
NodeKey: key.NewNode(),
|
||||
MachineKey: key.NewMachine(),
|
||||
ServerInfo: ServerInfo{
|
||||
URL: "https://example.com",
|
||||
Key: key.NewMachine().Public(),
|
||||
},
|
||||
}
|
||||
|
||||
if err := WriteNodeFile(path, nf); err != nil {
|
||||
t.Fatalf("WriteNodeFile: %v", err)
|
||||
}
|
||||
|
||||
data, err := os.ReadFile(path)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
var raw map[string]json.RawMessage
|
||||
if err := json.Unmarshal(data, &raw); err != nil {
|
||||
t.Fatalf("parsing written JSON: %v", err)
|
||||
}
|
||||
for _, field := range []string{"node_key", "machine_key", "server_url", "server_key"} {
|
||||
if _, ok := raw[field]; !ok {
|
||||
t.Errorf("missing JSON field %q in written file", field)
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,116 @@
|
||||
// Copyright (c) Tailscale Inc & contributors
|
||||
// SPDX-License-Identifier: BSD-3-Clause
|
||||
|
||||
package tsp
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"cmp"
|
||||
"context"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"io"
|
||||
"net/http"
|
||||
"strings"
|
||||
|
||||
"tailscale.com/control/ts2021"
|
||||
"tailscale.com/tailcfg"
|
||||
"tailscale.com/types/key"
|
||||
)
|
||||
|
||||
// RegisterOpts contains options for registering a node.
|
||||
type RegisterOpts struct {
|
||||
// NodeKey is the node's private key. Required.
|
||||
NodeKey key.NodePrivate
|
||||
|
||||
// Hostinfo is the host information to send. Optional;
|
||||
// if nil, a minimal default is used.
|
||||
Hostinfo *tailcfg.Hostinfo
|
||||
|
||||
// Ephemeral marks the node as ephemeral.
|
||||
Ephemeral bool
|
||||
|
||||
// AuthKey is a pre-authorized auth key.
|
||||
AuthKey string
|
||||
|
||||
// Tags is a list of ACL tags to request.
|
||||
Tags []string
|
||||
|
||||
// MaxResponseSize is the maximum size in bytes of the register
|
||||
// response body. If zero, [DefaultMaxMessageSize] is used.
|
||||
MaxResponseSize int64
|
||||
}
|
||||
|
||||
// Register sends a registration request to the coordination server
|
||||
// and returns the response.
|
||||
func (c *Client) Register(ctx context.Context, opts RegisterOpts) (*tailcfg.RegisterResponse, error) {
|
||||
hi := opts.Hostinfo
|
||||
if hi == nil {
|
||||
hi = defaultHostinfo()
|
||||
}
|
||||
if len(opts.Tags) > 0 {
|
||||
hi.RequestTags = opts.Tags
|
||||
}
|
||||
|
||||
regReq := tailcfg.RegisterRequest{
|
||||
Version: tailcfg.CurrentCapabilityVersion,
|
||||
NodeKey: opts.NodeKey.Public(),
|
||||
Hostinfo: hi,
|
||||
Ephemeral: opts.Ephemeral,
|
||||
}
|
||||
if opts.AuthKey != "" {
|
||||
regReq.Auth = &tailcfg.RegisterResponseAuth{
|
||||
AuthKey: opts.AuthKey,
|
||||
}
|
||||
}
|
||||
|
||||
body, err := json.Marshal(regReq)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("encoding register request: %w", err)
|
||||
}
|
||||
|
||||
nc, err := c.noiseClient(ctx)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("establishing noise connection: %w", err)
|
||||
}
|
||||
|
||||
url := c.serverURL + "/machine/register"
|
||||
url = strings.Replace(url, "http:", "https:", 1)
|
||||
req, err := http.NewRequestWithContext(ctx, "POST", url, bytes.NewReader(body))
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("creating register request: %w", err)
|
||||
}
|
||||
ts2021.AddLBHeader(req, opts.NodeKey.Public())
|
||||
|
||||
res, err := nc.Do(req)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("register request: %w", err)
|
||||
}
|
||||
defer res.Body.Close()
|
||||
|
||||
maxResponseSize := cmp.Or(opts.MaxResponseSize, DefaultMaxMessageSize)
|
||||
|
||||
if res.StatusCode != 200 {
|
||||
msg, _ := io.ReadAll(io.LimitReader(res.Body, maxResponseSize))
|
||||
return nil, fmt.Errorf("register request: http %d: %.200s",
|
||||
res.StatusCode, strings.TrimSpace(string(msg)))
|
||||
}
|
||||
|
||||
// Read up to maxResponseSize+1 so we can distinguish "exactly at cap" from
|
||||
// "over the cap" rather than relying on a truncated json parse error.
|
||||
data, err := io.ReadAll(io.LimitReader(res.Body, maxResponseSize+1))
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("reading register response: %w", err)
|
||||
}
|
||||
if int64(len(data)) > maxResponseSize {
|
||||
return nil, fmt.Errorf("register response exceeds max %d", maxResponseSize)
|
||||
}
|
||||
var resp tailcfg.RegisterResponse
|
||||
if err := json.Unmarshal(data, &resp); err != nil {
|
||||
return nil, fmt.Errorf("decoding register response: %w", err)
|
||||
}
|
||||
if resp.Error != "" {
|
||||
return nil, fmt.Errorf("register: %s", resp.Error)
|
||||
}
|
||||
return &resp, nil
|
||||
}
|
||||
@@ -0,0 +1,251 @@
|
||||
// Copyright (c) Tailscale Inc & contributors
|
||||
// SPDX-License-Identifier: BSD-3-Clause
|
||||
|
||||
// Package tsp provides a client for speaking the Tailscale protocol
|
||||
// to a coordination server over Noise.
|
||||
package tsp
|
||||
|
||||
import (
|
||||
"bufio"
|
||||
"bytes"
|
||||
"cmp"
|
||||
"context"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"io"
|
||||
"net"
|
||||
"net/http"
|
||||
"strconv"
|
||||
"sync"
|
||||
|
||||
"tailscale.com/control/ts2021"
|
||||
"tailscale.com/ipn"
|
||||
"tailscale.com/net/tsdial"
|
||||
"tailscale.com/tailcfg"
|
||||
"tailscale.com/types/key"
|
||||
"tailscale.com/types/logger"
|
||||
"tailscale.com/version"
|
||||
)
|
||||
|
||||
// DefaultServerURL is the default coordination server base URL,
|
||||
// used when ClientOpts.ServerURL is empty.
|
||||
const DefaultServerURL = ipn.DefaultControlURL
|
||||
|
||||
// ClientOpts contains options for creating a new Client.
|
||||
type ClientOpts struct {
|
||||
// ServerURL is the base URL of the coordination server
|
||||
// (e.g. "https://controlplane.tailscale.com").
|
||||
// If empty, DefaultServerURL is used.
|
||||
ServerURL string
|
||||
|
||||
// MachineKey is this node's machine private key. Required.
|
||||
MachineKey key.MachinePrivate
|
||||
|
||||
// Logf is the log function. If nil, logger.Discard is used.
|
||||
Logf logger.Logf
|
||||
}
|
||||
|
||||
// Client is a Tailscale protocol client that speaks to a coordination
|
||||
// server over Noise.
|
||||
type Client struct {
|
||||
opts ClientOpts
|
||||
serverURL string
|
||||
logf logger.Logf
|
||||
|
||||
mu sync.Mutex
|
||||
nc *ts2021.Client // nil until noiseClient called
|
||||
serverPub key.MachinePublic // zero until set or discovered
|
||||
}
|
||||
|
||||
// NewClient creates a new Client configured to talk to the coordination server
|
||||
// specified in opts. It performs no I/O; the server's public key is discovered
|
||||
// lazily on first use or can be set explicitly via SetControlPublicKey.
|
||||
func NewClient(opts ClientOpts) (*Client, error) {
|
||||
if opts.MachineKey.IsZero() {
|
||||
return nil, fmt.Errorf("MachineKey is required")
|
||||
}
|
||||
logf := opts.Logf
|
||||
if logf == nil {
|
||||
logf = logger.Discard
|
||||
}
|
||||
return &Client{
|
||||
opts: opts,
|
||||
serverURL: cmp.Or(opts.ServerURL, DefaultServerURL),
|
||||
logf: logf,
|
||||
}, nil
|
||||
}
|
||||
|
||||
// SetControlPublicKey sets the server's public key, bypassing lazy discovery.
|
||||
// Any existing noise client is invalidated and will be re-created on next use.
|
||||
func (c *Client) SetControlPublicKey(k key.MachinePublic) {
|
||||
c.mu.Lock()
|
||||
defer c.mu.Unlock()
|
||||
c.serverPub = k
|
||||
c.nc = nil
|
||||
}
|
||||
|
||||
// DiscoverServerKey fetches the server's public key from the coordination
|
||||
// server and stores it for subsequent use. Any existing noise client is
|
||||
// invalidated.
|
||||
func (c *Client) DiscoverServerKey(ctx context.Context) (key.MachinePublic, error) {
|
||||
k, err := DiscoverServerKey(ctx, c.serverURL)
|
||||
if err != nil {
|
||||
return key.MachinePublic{}, err
|
||||
}
|
||||
c.mu.Lock()
|
||||
defer c.mu.Unlock()
|
||||
c.serverPub = k
|
||||
c.nc = nil
|
||||
return k, nil
|
||||
}
|
||||
|
||||
// DiscoverServerKey fetches the coordination server's public key from the
|
||||
// given server URL. It is a standalone function that requires no client state.
|
||||
func DiscoverServerKey(ctx context.Context, serverURL string) (key.MachinePublic, error) {
|
||||
serverURL = cmp.Or(serverURL, DefaultServerURL)
|
||||
keysURL := serverURL + "/key?v=" + strconv.Itoa(int(tailcfg.CurrentCapabilityVersion))
|
||||
req, err := http.NewRequestWithContext(ctx, "GET", keysURL, nil)
|
||||
if err != nil {
|
||||
return key.MachinePublic{}, fmt.Errorf("creating key request: %w", err)
|
||||
}
|
||||
res, err := http.DefaultClient.Do(req)
|
||||
if err != nil {
|
||||
return key.MachinePublic{}, fmt.Errorf("fetching server key: %w", err)
|
||||
}
|
||||
defer res.Body.Close()
|
||||
if res.StatusCode != 200 {
|
||||
return key.MachinePublic{}, fmt.Errorf("fetching server key: %s", res.Status)
|
||||
}
|
||||
var keys struct {
|
||||
PublicKey key.MachinePublic
|
||||
}
|
||||
if err := json.NewDecoder(res.Body).Decode(&keys); err != nil {
|
||||
return key.MachinePublic{}, fmt.Errorf("decoding server key: %w", err)
|
||||
}
|
||||
return keys.PublicKey, nil
|
||||
}
|
||||
|
||||
// noiseClient returns the ts2021 noise client, creating it lazily if needed.
|
||||
// If the server's public key is not yet known, it is discovered via HTTP.
|
||||
func (c *Client) noiseClient(ctx context.Context) (*ts2021.Client, error) {
|
||||
c.mu.Lock()
|
||||
defer c.mu.Unlock()
|
||||
|
||||
if c.nc != nil {
|
||||
return c.nc, nil
|
||||
}
|
||||
|
||||
if c.serverPub.IsZero() {
|
||||
// Discover server key without holding the lock, to avoid blocking
|
||||
// other callers during the HTTP request.
|
||||
c.mu.Unlock()
|
||||
k, err := DiscoverServerKey(ctx, c.serverURL)
|
||||
c.mu.Lock()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
// Re-check: another goroutine may have set it while we were unlocked.
|
||||
if c.serverPub.IsZero() {
|
||||
c.serverPub = k
|
||||
}
|
||||
// If nc was created by another goroutine while unlocked, use it.
|
||||
if c.nc != nil {
|
||||
return c.nc, nil
|
||||
}
|
||||
}
|
||||
|
||||
nc, err := ts2021.NewClient(ts2021.ClientOpts{
|
||||
ServerURL: c.serverURL,
|
||||
PrivKey: c.opts.MachineKey,
|
||||
ServerPubKey: c.serverPub,
|
||||
Dialer: tsdial.NewFromFuncForDebug(c.logf, (&net.Dialer{}).DialContext),
|
||||
Logf: c.logf,
|
||||
})
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("creating noise client: %w", err)
|
||||
}
|
||||
c.nc = nc
|
||||
return nc, nil
|
||||
}
|
||||
|
||||
// AnswerC2NPing handles a c2n PingRequest from the control plane by parsing the
|
||||
// embedded HTTP request in the payload, routing it locally, and POSTing the HTTP
|
||||
// response back to pr.URL using doNoiseRequest. The POST is done in a new
|
||||
// goroutine so this method does not block.
|
||||
//
|
||||
// It reports whether the ping was handled. Unhandled pings (nil pr, non-c2n
|
||||
// types, or unrecognized c2n paths) return false.
|
||||
func (c *Client) AnswerC2NPing(ctx context.Context, pr *tailcfg.PingRequest, doNoiseRequest func(*http.Request) (*http.Response, error)) (handled bool) {
|
||||
if pr == nil || pr.Types != "c2n" {
|
||||
return false
|
||||
}
|
||||
|
||||
// Parse the HTTP request from the payload.
|
||||
httpReq, err := http.ReadRequest(bufio.NewReader(bytes.NewReader(pr.Payload)))
|
||||
if err != nil {
|
||||
c.logf("parsing c2n ping payload: %v", err)
|
||||
return false
|
||||
}
|
||||
|
||||
// Route the request locally.
|
||||
var httpResp *http.Response
|
||||
switch httpReq.URL.Path {
|
||||
case "/echo":
|
||||
body, _ := io.ReadAll(httpReq.Body)
|
||||
httpResp = &http.Response{
|
||||
StatusCode: 200,
|
||||
Status: "200 OK",
|
||||
Proto: "HTTP/1.1",
|
||||
ProtoMajor: 1,
|
||||
ProtoMinor: 1,
|
||||
Header: http.Header{},
|
||||
Body: io.NopCloser(bytes.NewReader(body)),
|
||||
ContentLength: int64(len(body)),
|
||||
}
|
||||
default:
|
||||
c.logf("ignoring c2n ping request for unhandled path %q", httpReq.URL.Path)
|
||||
return false
|
||||
}
|
||||
|
||||
// Serialize the HTTP response.
|
||||
var buf bytes.Buffer
|
||||
if err := httpResp.Write(&buf); err != nil {
|
||||
c.logf("serializing c2n ping response: %v", err)
|
||||
return false
|
||||
}
|
||||
|
||||
// Send the response back to the control plane over the Noise channel.
|
||||
go func() {
|
||||
req, err := http.NewRequestWithContext(ctx, "POST", pr.URL, &buf)
|
||||
if err != nil {
|
||||
c.logf("creating c2n ping reply request: %v", err)
|
||||
return
|
||||
}
|
||||
resp, err := doNoiseRequest(req)
|
||||
if err != nil {
|
||||
c.logf("sending c2n ping reply: %v", err)
|
||||
return
|
||||
}
|
||||
resp.Body.Close()
|
||||
}()
|
||||
return true
|
||||
}
|
||||
|
||||
// Close closes the client and releases resources.
|
||||
func (c *Client) Close() error {
|
||||
c.mu.Lock()
|
||||
nc := c.nc
|
||||
c.nc = nil
|
||||
c.mu.Unlock()
|
||||
if nc != nil {
|
||||
nc.Close()
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func defaultHostinfo() *tailcfg.Hostinfo {
|
||||
return &tailcfg.Hostinfo{
|
||||
OS: version.OS(),
|
||||
IPNVersion: version.Long(),
|
||||
}
|
||||
}
|
||||
Reference in New Issue
Block a user