control/tsp, cmd/tsp: add low-level Tailscale protocol client and tool

Add a new control/tsp package providing a client for speaking the
Tailscale protocol to a coordination server over Noise, along with a
cmd/tsp binary exposing it as a low-level composable tool for
generating keys, registering nodes, and issuing map requests.

Previously developed out-of-tree at github.com/bradfitz/tsp; imported
here without git history.

Updates #12542

Change-Id: I6ad21143c4aefe8939d4a46ae65b2184173bf69f
Signed-off-by: Brad Fitzpatrick <bradfitz@tailscale.com>
This commit is contained in:
Brad Fitzpatrick
2026-04-16 21:15:51 +00:00
committed by Brad Fitzpatrick
parent 69572c7435
commit 50d7176333
7 changed files with 1710 additions and 0 deletions
+339
View File
@@ -0,0 +1,339 @@
// Copyright (c) Tailscale Inc & contributors
// SPDX-License-Identifier: BSD-3-Clause
package tsp
import (
"bytes"
"cmp"
"context"
"encoding/binary"
"encoding/json"
"errors"
"fmt"
"io"
"net/http"
"strings"
"sync"
"sync/atomic"
"github.com/klauspost/compress/zstd"
"tailscale.com/control/ts2021"
"tailscale.com/tailcfg"
"tailscale.com/types/key"
)
// errSessionClosed is returned by [MapSession.Next] and
// [MapSession.NextInto] when called after [MapSession.Close].
var errSessionClosed = errors.New("tsp: map session closed")
// DefaultMaxMessageSize is the default cap, in bytes, on the size of a
// single compressed map response frame. See [MapOpts.MaxMessageSize].
const DefaultMaxMessageSize = 4 << 20
// zstdDecoderPool is a pool of *zstd.Decoder reused across MapSessions to
// amortize the cost of setting up zstd state. Decoders are returned via
// [MapSession.Close]; entries are reclaimed by the runtime under memory
// pressure via sync.Pool semantics.
var zstdDecoderPool sync.Pool // of *zstd.Decoder
// MapOpts contains options for sending a map request.
type MapOpts struct {
// NodeKey is the node's private key. Required.
NodeKey key.NodePrivate
// Hostinfo is the host information to send. Optional;
// if nil, a minimal default is used.
Hostinfo *tailcfg.Hostinfo
// Stream is whether to receive multiple MapResponses over
// the same HTTP connection.
Stream bool
// OmitPeers is whether the client is okay with the Peers list
// being omitted in the response.
OmitPeers bool
// MaxMessageSize is the maximum size in bytes of any single
// compressed map response frame on the wire. If zero,
// [DefaultMaxMessageSize] is used.
MaxMessageSize int64
}
// framedReader is an io.Reader that consumes a stream of length-prefixed
// frames (each a little-endian uint32 length followed by that many bytes)
// from r and yields only the frame payloads back-to-back.
//
// This lets us feed the concatenated zstd frames from our wire protocol
// into a single streaming zstd decoder. Zstd's file format permits
// concatenation (RFC 8478 §2), and klauspost's decoder handles it
// transparently.
//
// If onNewFrame is non-nil, it is called after each new 4-byte length
// header is successfully read. Used to reset the per-message decoded-size
// budget downstream.
type framedReader struct {
r io.Reader
maxSize int64 // per-frame compressed-size cap
remain int // bytes remaining in the current frame
onNewFrame func()
}
func (f *framedReader) Read(p []byte) (int, error) {
if f.remain == 0 {
var hdr [4]byte
if _, err := io.ReadFull(f.r, hdr[:]); err != nil {
return 0, err
}
sz := int64(binary.LittleEndian.Uint32(hdr[:]))
if sz == 0 {
return 0, fmt.Errorf("map response: zero-length frame")
}
if sz > f.maxSize {
return 0, fmt.Errorf("map response frame size %d exceeds max %d", sz, f.maxSize)
}
f.remain = int(sz)
if f.onNewFrame != nil {
f.onNewFrame()
}
}
if len(p) > f.remain {
p = p[:f.remain]
}
n, err := f.r.Read(p)
f.remain -= n
return n, err
}
// boundedReader is an io.Reader that yields at most remain bytes from r
// before returning an error. Call reset to raise the budget back to max,
// typically at a new message boundary.
//
// Used to cap the decoded size of a single map response so a malicious
// server can't send a small zstd frame that explodes into gigabytes of
// junk for the json.Decoder to consume.
type boundedReader struct {
r io.Reader
max int64
remain int64
}
func (b *boundedReader) Read(p []byte) (int, error) {
if b.remain <= 0 {
return 0, fmt.Errorf("map response decoded size exceeds max %d", b.max)
}
if int64(len(p)) > b.remain {
p = p[:b.remain]
}
n, err := b.r.Read(p)
b.remain -= int64(n)
return n, err
}
func (b *boundedReader) reset() { b.remain = b.max }
// MapSession wraps an in-progress map response stream. Call Next to read
// each MapResponse. Call Close when done.
type MapSession struct {
res *http.Response
stream bool
noiseDoer func(*http.Request) (*http.Response, error)
// inNext detects concurrent NextInto callers. It CAS-flips
// false→true on entry and back to false on exit; a failed CAS
// panics, akin to how the Go runtime detects concurrent map
// access. It does not serialize Close vs. NextInto; that's
// nextMu's job.
inNext atomic.Bool
// nextMu is held while [MapSession.NextInto] is running jdec.Decode,
// so that Close can wait for an in-flight Decode to unwind before it
// touches zdec (Reset, pool-Put) and avoids racing with the running
// Read chain that Decode drives.
nextMu sync.Mutex
read int // guarded by nextMu
closed bool // guarded by nextMu
zdec *zstd.Decoder // reads from a framedReader wrapping res.Body
jdec *json.Decoder // reads decompressed JSON from zdec
closeOnce sync.Once
closeErr error
}
// NoiseRoundTrip sends an HTTP request over the Noise channel used by this map session.
func (s *MapSession) NoiseRoundTrip(req *http.Request) (*http.Response, error) {
return s.noiseDoer(req)
}
// Next reads and returns the next MapResponse from the stream.
// For non-streaming sessions, the first call returns the single response
// and subsequent calls return io.EOF.
// For streaming sessions, Next blocks until the next response arrives
// or the server closes the connection.
//
// Each call allocates a fresh MapResponse. Callers that want to amortize
// the allocation across calls can use [MapSession.NextInto].
//
// Next and NextInto are not safe to call concurrently from multiple
// goroutines on the same [MapSession]; a concurrent call panics, akin
// to the Go runtime's concurrent map access detection. [MapSession.Close]
// may be called concurrently to abort an in-flight Next.
func (s *MapSession) Next() (*tailcfg.MapResponse, error) {
var resp tailcfg.MapResponse
if err := s.NextInto(&resp); err != nil {
return nil, err
}
return &resp, nil
}
// NextInto is like [MapSession.Next] but decodes the next MapResponse into
// the caller-supplied *resp rather than allocating a new one. The pointer's
// pointee is zeroed before decoding so fields from a prior response do not
// persist.
//
// For non-streaming sessions, the first call decodes the single response
// and subsequent calls return io.EOF.
// For streaming sessions, NextInto blocks until the next response arrives
// or the server closes the connection.
//
// See [MapSession.Next] for concurrency rules; those apply to NextInto too.
func (s *MapSession) NextInto(resp *tailcfg.MapResponse) error {
if !s.inNext.CompareAndSwap(false, true) {
panic("tsp: invalid concurrent call to MapSession.Next/NextInto")
}
defer s.inNext.Store(false)
s.nextMu.Lock()
defer s.nextMu.Unlock()
if s.closed {
return errSessionClosed
}
if !s.stream && s.read > 0 {
return io.EOF
}
*resp = tailcfg.MapResponse{}
if err := s.jdec.Decode(resp); err != nil {
return err
}
s.read++
return nil
}
// Close returns the session's zstd decoder to the pool and closes the
// underlying HTTP response body. It is safe to call Close multiple times
// and from multiple goroutines, including while a [MapSession.Next] or
// [MapSession.NextInto] call is in flight on another goroutine (which
// will return an error once the body close propagates).
func (s *MapSession) Close() error {
// Callers are likely to race a deferred Close with a time.AfterFunc
// timeout (or similar) Close that aborts a hung Next. Without the
// Once, both Closes would Put the same *zstd.Decoder into the pool,
// corrupting it, and the Reset/Put in one would race with the
// zdec.Read that the hung Next is driving.
//
// Ordering inside the Once: close the body first to unblock any
// in-flight NextInto (its Read chain ends at res.Body and will
// return an error once it's closed). That lets NextInto unwind and
// release nextMu. Only then do we take nextMu ourselves and touch
// zdec, which is safe because no goroutine is still reading from
// it. Acquiring nextMu before closing the body would deadlock
// against a hung NextInto.
s.closeOnce.Do(func() {
s.closeErr = s.res.Body.Close()
s.nextMu.Lock()
defer s.nextMu.Unlock()
s.closed = true
s.zdec.Reset(nil)
zstdDecoderPool.Put(s.zdec)
})
return s.closeErr
}
// Map sends a map request to the coordination server and returns a MapSession
// for reading the framed, zstd-compressed response(s).
func (c *Client) Map(ctx context.Context, opts MapOpts) (*MapSession, error) {
if opts.NodeKey.IsZero() {
return nil, fmt.Errorf("NodeKey is required")
}
hi := opts.Hostinfo
if hi == nil {
hi = defaultHostinfo()
}
mapReq := tailcfg.MapRequest{
Version: tailcfg.CurrentCapabilityVersion,
NodeKey: opts.NodeKey.Public(),
Hostinfo: hi,
Stream: opts.Stream,
Compress: "zstd",
OmitPeers: opts.OmitPeers,
// Streaming requires the server to track us as "connected",
// which in turn requires ReadOnly=false. Non-streaming polls
// stay ReadOnly to minimize side effects.
ReadOnly: !opts.Stream,
}
body, err := json.Marshal(mapReq)
if err != nil {
return nil, fmt.Errorf("encoding map request: %w", err)
}
nc, err := c.noiseClient(ctx)
if err != nil {
return nil, fmt.Errorf("establishing noise connection: %w", err)
}
url := c.serverURL + "/machine/map"
url = strings.Replace(url, "http:", "https:", 1)
req, err := http.NewRequestWithContext(ctx, "POST", url, bytes.NewReader(body))
if err != nil {
return nil, fmt.Errorf("creating map request: %w", err)
}
ts2021.AddLBHeader(req, opts.NodeKey.Public())
res, err := nc.Do(req)
if err != nil {
return nil, fmt.Errorf("map request: %w", err)
}
if res.StatusCode != 200 {
msg, _ := io.ReadAll(res.Body)
res.Body.Close()
return nil, fmt.Errorf("map request: http %d: %.200s",
res.StatusCode, strings.TrimSpace(string(msg)))
}
maxMessageSize := cmp.Or(opts.MaxMessageSize, DefaultMaxMessageSize)
bounded := &boundedReader{max: maxMessageSize, remain: maxMessageSize}
fr := &framedReader{
r: res.Body,
maxSize: maxMessageSize,
onNewFrame: bounded.reset,
}
zdec, _ := zstdDecoderPool.Get().(*zstd.Decoder)
if zdec != nil {
if err := zdec.Reset(fr); err != nil {
// Reset can fail if the previous stream is in a bad state; drop
// the decoder and create a fresh one.
zdec = nil
}
}
if zdec == nil {
zdec, err = zstd.NewReader(fr, zstd.WithDecoderConcurrency(1))
if err != nil {
res.Body.Close()
return nil, fmt.Errorf("creating zstd decoder: %w", err)
}
}
bounded.r = zdec
return &MapSession{
res: res,
stream: opts.Stream,
noiseDoer: nc.Do,
zdec: zdec,
jdec: json.NewDecoder(bounded),
}, nil
}
+270
View File
@@ -0,0 +1,270 @@
// Copyright (c) Tailscale Inc & contributors
// SPDX-License-Identifier: BSD-3-Clause
package tsp
import (
"bytes"
"context"
"encoding/binary"
"encoding/json"
"fmt"
"io"
"net/http/httptest"
"strings"
"testing"
"time"
"github.com/klauspost/compress/zstd"
"tailscale.com/tailcfg"
"tailscale.com/tstest/integration/testcontrol"
"tailscale.com/types/key"
)
func TestMapAgainstTestControl(t *testing.T) {
ctrl := &testcontrol.Server{}
ctrl.HTTPTestServer = httptest.NewUnstartedServer(ctrl)
ctrl.HTTPTestServer.Start()
t.Cleanup(ctrl.HTTPTestServer.Close)
baseURL := ctrl.HTTPTestServer.URL
ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second)
defer cancel()
serverKey, err := DiscoverServerKey(ctx, baseURL)
if err != nil {
t.Fatalf("DiscoverServerKey: %v", err)
}
register := func(hostname string) (nodeKey key.NodePrivate, machineKey key.MachinePrivate) {
t.Helper()
nodeKey = key.NewNode()
machineKey = key.NewMachine()
c, err := NewClient(ClientOpts{
ServerURL: baseURL,
MachineKey: machineKey,
})
if err != nil {
t.Fatalf("NewClient %s: %v", hostname, err)
}
defer c.Close()
c.SetControlPublicKey(serverKey)
if _, err := c.Register(ctx, RegisterOpts{
NodeKey: nodeKey,
Hostinfo: &tailcfg.Hostinfo{Hostname: hostname},
}); err != nil {
t.Fatalf("Register %s: %v", hostname, err)
}
return nodeKey, machineKey
}
nodeKeyA, machineKeyA := register("a")
nodeKeyB, _ := register("b")
clientA, err := NewClient(ClientOpts{
ServerURL: baseURL,
MachineKey: machineKeyA,
})
if err != nil {
t.Fatalf("NewClient A: %v", err)
}
defer clientA.Close()
clientA.SetControlPublicKey(serverKey)
session, err := clientA.Map(ctx, MapOpts{
NodeKey: nodeKeyA,
Hostinfo: &tailcfg.Hostinfo{Hostname: "a"},
Stream: true,
})
if err != nil {
t.Fatalf("Map: %v", err)
}
defer session.Close()
// nextNonKeepalive returns the next non-keepalive MapResponse, to keep
// the test robust if a server-side keepalive arrives mid-test.
nextNonKeepalive := func() *tailcfg.MapResponse {
t.Helper()
for {
resp, err := session.Next()
if err != nil {
t.Fatalf("session.Next: %v", err)
}
if resp.KeepAlive {
continue
}
return resp
}
}
// First MapResponse: expect node A as self and node B in Peers.
first := nextNonKeepalive()
if first.Node == nil {
t.Fatal("first response has nil Node")
}
if got, want := first.Node.Key, nodeKeyA.Public(); got != want {
t.Errorf("first Node.Key = %v, want %v", got, want)
}
var foundB bool
for _, p := range first.Peers {
if p.Key == nodeKeyB.Public() {
foundB = true
break
}
}
if !foundB {
t.Errorf("peer B (%v) not in first response's Peers (%d peers)", nodeKeyB.Public(), len(first.Peers))
}
// Inject raw MapResponses and verify they come out the reader, in order.
// msgToSend is single-slot, so we must consume each before injecting the next.
for i := range 3 {
want := fmt.Sprintf("injected-%d.example.com", i)
inject := &tailcfg.MapResponse{Domain: want}
if !ctrl.AddRawMapResponse(nodeKeyA.Public(), inject) {
t.Fatalf("AddRawMapResponse %d: node not connected", i)
}
got := nextNonKeepalive()
if got.Domain != want {
t.Errorf("injected %d: got Domain=%q, want %q", i, got.Domain, want)
}
}
}
// newTestPipeline builds the same framedReader → zstd → boundedReader →
// json.Decoder pipeline that [Client.Map] builds for a live session, but
// feeds it from a raw byte slice. Returned jdec can be used with Decode to
// pull out MapResponses.
func newTestPipeline(t testing.TB, wire []byte, maxMessageSize int64) *json.Decoder {
t.Helper()
bounded := &boundedReader{max: maxMessageSize, remain: maxMessageSize}
fr := &framedReader{
r: bytes.NewReader(wire),
maxSize: maxMessageSize,
onNewFrame: bounded.reset,
}
zdec, err := zstd.NewReader(fr, zstd.WithDecoderConcurrency(1))
if err != nil {
t.Fatalf("zstd.NewReader: %v", err)
}
t.Cleanup(zdec.Close)
bounded.r = zdec
return json.NewDecoder(bounded)
}
// zstdFrame returns a zstd-compressed frame of b.
func zstdFrame(t testing.TB, b []byte) []byte {
t.Helper()
enc, err := zstd.NewWriter(io.Discard, zstd.WithEncoderConcurrency(1))
if err != nil {
t.Fatalf("zstd.NewWriter: %v", err)
}
defer enc.Close()
return enc.EncodeAll(b, nil)
}
// wireFrame writes a 4-byte little-endian length prefix plus payload to buf.
func wireFrame(buf *bytes.Buffer, payload []byte) {
var hdr [4]byte
binary.LittleEndian.PutUint32(hdr[:], uint32(len(payload)))
buf.Write(hdr[:])
buf.Write(payload)
}
// TestMapFrameSizeTooLarge verifies that a 4-byte length prefix claiming
// a frame larger than the configured cap is rejected before any payload
// bytes are read from the stream.
func TestMapFrameSizeTooLarge(t *testing.T) {
const max = 4 << 20
var wire bytes.Buffer
var hdr [4]byte
binary.LittleEndian.PutUint32(hdr[:], (max + 1))
wire.Write(hdr[:])
jdec := newTestPipeline(t, wire.Bytes(), max)
var resp tailcfg.MapResponse
err := jdec.Decode(&resp)
if err == nil {
t.Fatal("Decode: got nil error, want frame-too-large")
}
if !strings.Contains(err.Error(), "exceeds max") {
t.Errorf("Decode error = %q, want one containing %q", err, "exceeds max")
}
}
// TestMapDecodedSizeTooLarge verifies that a small on-wire frame (well
// under the cap) which decompresses into a huge JSON payload is rejected.
// This is the "zstd bomb" case: a tiny compressed frame that would
// explode into a huge decoded payload for json.Decoder to consume.
func TestMapDecodedSizeTooLarge(t *testing.T) {
const max = 4 << 20
big := strings.Repeat("a", 5<<20) // 5 MiB of 'a'
raw, err := json.Marshal(&tailcfg.MapResponse{Domain: big})
if err != nil {
t.Fatal(err)
}
if int64(len(raw)) <= max {
t.Fatalf("raw JSON unexpectedly small: %d", len(raw))
}
compressed := zstdFrame(t, raw)
if int64(len(compressed)) >= max {
t.Fatalf("compressed too large (%d); test needs a more compressible payload", len(compressed))
}
var wire bytes.Buffer
wireFrame(&wire, compressed)
jdec := newTestPipeline(t, wire.Bytes(), max)
var resp tailcfg.MapResponse
err = jdec.Decode(&resp)
if err == nil {
t.Fatal("Decode: got nil error, want decoded-size-exceeded")
}
if !strings.Contains(err.Error(), "exceeds max") {
t.Errorf("Decode error = %q, want one containing %q", err, "exceeds max")
}
}
// TestMapBudgetResetsBetweenFrames verifies that the per-message decoded
// budget is reset at each new frame boundary. Two consecutive 3-MiB frames
// should both decode successfully under a 4-MiB per-frame cap. Without the
// reset, the second frame would fail (remaining budget after frame 1 =
// 4MiB - 3MiB = 1MiB, and we'd try to read 3MiB more).
func TestMapBudgetResetsBetweenFrames(t *testing.T) {
const max = 4 << 20
payload := strings.Repeat("a", 3<<20)
r1 := &tailcfg.MapResponse{Domain: payload + "-one"}
r2 := &tailcfg.MapResponse{Domain: payload + "-two"}
var wire bytes.Buffer
for _, r := range []*tailcfg.MapResponse{r1, r2} {
raw, err := json.Marshal(r)
if err != nil {
t.Fatal(err)
}
if int64(len(raw)) >= max {
t.Fatalf("raw JSON size %d >= max %d; would fail budget check by itself", len(raw), max)
}
compressed := zstdFrame(t, raw)
if int64(len(compressed)) >= max {
t.Fatalf("compressed size %d >= max %d", len(compressed), max)
}
wireFrame(&wire, compressed)
}
jdec := newTestPipeline(t, wire.Bytes(), max)
var got1, got2 tailcfg.MapResponse
if err := jdec.Decode(&got1); err != nil {
t.Fatalf("first Decode: %v", err)
}
if got1.Domain != r1.Domain {
t.Errorf("first Domain mismatch (len %d vs %d)", len(got1.Domain), len(r1.Domain))
}
if err := jdec.Decode(&got2); err != nil {
t.Fatalf("second Decode: %v", err)
}
if got2.Domain != r2.Domain {
t.Errorf("second Domain mismatch (len %d vs %d)", len(got2.Domain), len(r2.Domain))
}
}
+105
View File
@@ -0,0 +1,105 @@
// Copyright (c) Tailscale Inc & contributors
// SPDX-License-Identifier: BSD-3-Clause
package tsp
import (
"encoding/json"
"fmt"
"os"
"tailscale.com/types/key"
)
// ServerInfo identifies a coordination server by its URL and Noise public key.
type ServerInfo struct {
// URL is the base URL of the coordination server, without any path
// (e.g. "https://controlplane.tailscale.com").
//
// There is no default value; a URL must always be supplied.
URL string `json:"server_url"`
// Key is the server's Noise public key, used to establish an encrypted
// channel between the client and the coordination server.
Key key.MachinePublic `json:"server_key"`
}
// NodeFile is the JSON structure for a node credentials file. It contains
// the private keys that authenticate a node to a coordination server.
//
// Example:
//
// {
// "node_key": "privkey:...",
// "machine_key": "privkey:...",
// "server_url": "https://controlplane.tailscale.com",
// "server_key": "mkey:..."
// }
//
// Note that node and machine private keys share the same "privkey:"
// textual form; they are disambiguated by the surrounding JSON field
// names rather than by any prefix in the key itself.
type NodeFile struct {
// NodeKey is the node's WireGuard private key. The corresponding
// public key identifies this node to other peers.
NodeKey key.NodePrivate `json:"node_key"`
// MachineKey is the machine's private key. It authenticates this
// machine to the coordination server over Noise.
MachineKey key.MachinePrivate `json:"machine_key"`
ServerInfo // server_url and server_key
}
// ReadNodeFile reads and parses a node JSON file.
func ReadNodeFile(path string) (NodeFile, error) {
data, err := os.ReadFile(path)
if err != nil {
return NodeFile{}, err
}
var nf NodeFile
if err := json.Unmarshal(data, &nf); err != nil {
return NodeFile{}, fmt.Errorf("parsing node file %q: %w", path, err)
}
return nf, nil
}
// WriteNodeFile writes a node JSON file. The file is created with mode 0600.
func WriteNodeFile(path string, nf NodeFile) error {
if err := nf.Check(); err != nil {
return fmt.Errorf("invalid NodeFile: %w", err)
}
return os.WriteFile(path, nf.AsJSON(), 0600)
}
// AsJSON returns nf as a pretty-printed JSON object, terminated by a newline.
//
// It always succeeds and always returns a valid JSON object. It does not
// validate that the fields of nf are non-zero; it is the caller's
// responsibility to call [NodeFile.Check] first if they want to reject
// incomplete NodeFiles.
func (nf NodeFile) AsJSON() []byte {
out, err := json.MarshalIndent(nf, "", " ")
if err != nil {
panic(fmt.Sprintf("NodeFile.AsJSON: %v", err)) // unreachable: all fields marshal successfully
}
return append(out, '\n')
}
// Check reports whether nf has all required fields set.
// It returns an error describing the first zero-valued field, if any.
func (nf NodeFile) Check() error {
if nf.NodeKey.IsZero() {
return fmt.Errorf("node_key is missing")
}
if nf.MachineKey.IsZero() {
return fmt.Errorf("machine_key is missing")
}
if nf.URL == "" {
return fmt.Errorf("server_url is missing")
}
if nf.ServerInfo.Key.IsZero() {
return fmt.Errorf("server_key is missing")
}
return nil
}
+116
View File
@@ -0,0 +1,116 @@
// Copyright (c) Tailscale Inc & contributors
// SPDX-License-Identifier: BSD-3-Clause
package tsp
import (
"encoding/json"
"os"
"path/filepath"
"testing"
"tailscale.com/types/key"
)
func TestNodeFileRoundTrip(t *testing.T) {
dir := t.TempDir()
path := filepath.Join(dir, "node.json")
nf := NodeFile{
NodeKey: key.NewNode(),
MachineKey: key.NewMachine(),
ServerInfo: ServerInfo{
URL: "https://controlplane.tailscale.com",
Key: key.NewMachine().Public(),
},
}
if err := WriteNodeFile(path, nf); err != nil {
t.Fatalf("WriteNodeFile: %v", err)
}
got, err := ReadNodeFile(path)
if err != nil {
t.Fatalf("ReadNodeFile: %v", err)
}
if !got.NodeKey.Equal(nf.NodeKey) {
t.Errorf("node key mismatch")
}
if !got.MachineKey.Equal(nf.MachineKey) {
t.Errorf("machine key mismatch")
}
if got.URL != nf.URL {
t.Errorf("server URL = %q, want %q", got.URL, nf.URL)
}
if got.ServerInfo.Key != nf.ServerInfo.Key {
t.Errorf("server key mismatch")
}
}
// TestNodeFileFormat verifies that ReadNodeFile can parse a fixed JSON literal,
// ensuring we don't accidentally change the on-disk format.
func TestNodeFileFormat(t *testing.T) {
const fileContents = `{
"node_key": "privkey:0123456789abcdef0123456789abcdef0123456789abcdef0123456789abcdef",
"machine_key": "privkey:fedcba9876543210fedcba9876543210fedcba9876543210fedcba9876543210",
"server_url": "https://controlplane.tailscale.com",
"server_key": "mkey:1111111111111111111111111111111111111111111111111111111111111111"
}`
dir := t.TempDir()
path := filepath.Join(dir, "node.json")
if err := os.WriteFile(path, []byte(fileContents), 0600); err != nil {
t.Fatal(err)
}
nf, err := ReadNodeFile(path)
if err != nil {
t.Fatalf("ReadNodeFile: %v", err)
}
if nf.NodeKey.IsZero() {
t.Error("node key is zero")
}
if nf.MachineKey.IsZero() {
t.Error("machine key is zero")
}
if nf.URL != "https://controlplane.tailscale.com" {
t.Errorf("server URL = %q", nf.URL)
}
if nf.ServerInfo.Key.IsZero() {
t.Error("server key is zero")
}
}
// TestNodeFileWriteFormat verifies that WriteNodeFile produces the expected
// JSON field names.
func TestNodeFileWriteFormat(t *testing.T) {
dir := t.TempDir()
path := filepath.Join(dir, "node.json")
nf := NodeFile{
NodeKey: key.NewNode(),
MachineKey: key.NewMachine(),
ServerInfo: ServerInfo{
URL: "https://example.com",
Key: key.NewMachine().Public(),
},
}
if err := WriteNodeFile(path, nf); err != nil {
t.Fatalf("WriteNodeFile: %v", err)
}
data, err := os.ReadFile(path)
if err != nil {
t.Fatal(err)
}
var raw map[string]json.RawMessage
if err := json.Unmarshal(data, &raw); err != nil {
t.Fatalf("parsing written JSON: %v", err)
}
for _, field := range []string{"node_key", "machine_key", "server_url", "server_key"} {
if _, ok := raw[field]; !ok {
t.Errorf("missing JSON field %q in written file", field)
}
}
}
+116
View File
@@ -0,0 +1,116 @@
// Copyright (c) Tailscale Inc & contributors
// SPDX-License-Identifier: BSD-3-Clause
package tsp
import (
"bytes"
"cmp"
"context"
"encoding/json"
"fmt"
"io"
"net/http"
"strings"
"tailscale.com/control/ts2021"
"tailscale.com/tailcfg"
"tailscale.com/types/key"
)
// RegisterOpts contains options for registering a node.
type RegisterOpts struct {
// NodeKey is the node's private key. Required.
NodeKey key.NodePrivate
// Hostinfo is the host information to send. Optional;
// if nil, a minimal default is used.
Hostinfo *tailcfg.Hostinfo
// Ephemeral marks the node as ephemeral.
Ephemeral bool
// AuthKey is a pre-authorized auth key.
AuthKey string
// Tags is a list of ACL tags to request.
Tags []string
// MaxResponseSize is the maximum size in bytes of the register
// response body. If zero, [DefaultMaxMessageSize] is used.
MaxResponseSize int64
}
// Register sends a registration request to the coordination server
// and returns the response.
func (c *Client) Register(ctx context.Context, opts RegisterOpts) (*tailcfg.RegisterResponse, error) {
hi := opts.Hostinfo
if hi == nil {
hi = defaultHostinfo()
}
if len(opts.Tags) > 0 {
hi.RequestTags = opts.Tags
}
regReq := tailcfg.RegisterRequest{
Version: tailcfg.CurrentCapabilityVersion,
NodeKey: opts.NodeKey.Public(),
Hostinfo: hi,
Ephemeral: opts.Ephemeral,
}
if opts.AuthKey != "" {
regReq.Auth = &tailcfg.RegisterResponseAuth{
AuthKey: opts.AuthKey,
}
}
body, err := json.Marshal(regReq)
if err != nil {
return nil, fmt.Errorf("encoding register request: %w", err)
}
nc, err := c.noiseClient(ctx)
if err != nil {
return nil, fmt.Errorf("establishing noise connection: %w", err)
}
url := c.serverURL + "/machine/register"
url = strings.Replace(url, "http:", "https:", 1)
req, err := http.NewRequestWithContext(ctx, "POST", url, bytes.NewReader(body))
if err != nil {
return nil, fmt.Errorf("creating register request: %w", err)
}
ts2021.AddLBHeader(req, opts.NodeKey.Public())
res, err := nc.Do(req)
if err != nil {
return nil, fmt.Errorf("register request: %w", err)
}
defer res.Body.Close()
maxResponseSize := cmp.Or(opts.MaxResponseSize, DefaultMaxMessageSize)
if res.StatusCode != 200 {
msg, _ := io.ReadAll(io.LimitReader(res.Body, maxResponseSize))
return nil, fmt.Errorf("register request: http %d: %.200s",
res.StatusCode, strings.TrimSpace(string(msg)))
}
// Read up to maxResponseSize+1 so we can distinguish "exactly at cap" from
// "over the cap" rather than relying on a truncated json parse error.
data, err := io.ReadAll(io.LimitReader(res.Body, maxResponseSize+1))
if err != nil {
return nil, fmt.Errorf("reading register response: %w", err)
}
if int64(len(data)) > maxResponseSize {
return nil, fmt.Errorf("register response exceeds max %d", maxResponseSize)
}
var resp tailcfg.RegisterResponse
if err := json.Unmarshal(data, &resp); err != nil {
return nil, fmt.Errorf("decoding register response: %w", err)
}
if resp.Error != "" {
return nil, fmt.Errorf("register: %s", resp.Error)
}
return &resp, nil
}
+251
View File
@@ -0,0 +1,251 @@
// Copyright (c) Tailscale Inc & contributors
// SPDX-License-Identifier: BSD-3-Clause
// Package tsp provides a client for speaking the Tailscale protocol
// to a coordination server over Noise.
package tsp
import (
"bufio"
"bytes"
"cmp"
"context"
"encoding/json"
"fmt"
"io"
"net"
"net/http"
"strconv"
"sync"
"tailscale.com/control/ts2021"
"tailscale.com/ipn"
"tailscale.com/net/tsdial"
"tailscale.com/tailcfg"
"tailscale.com/types/key"
"tailscale.com/types/logger"
"tailscale.com/version"
)
// DefaultServerURL is the default coordination server base URL,
// used when ClientOpts.ServerURL is empty.
const DefaultServerURL = ipn.DefaultControlURL
// ClientOpts contains options for creating a new Client.
type ClientOpts struct {
// ServerURL is the base URL of the coordination server
// (e.g. "https://controlplane.tailscale.com").
// If empty, DefaultServerURL is used.
ServerURL string
// MachineKey is this node's machine private key. Required.
MachineKey key.MachinePrivate
// Logf is the log function. If nil, logger.Discard is used.
Logf logger.Logf
}
// Client is a Tailscale protocol client that speaks to a coordination
// server over Noise.
type Client struct {
opts ClientOpts
serverURL string
logf logger.Logf
mu sync.Mutex
nc *ts2021.Client // nil until noiseClient called
serverPub key.MachinePublic // zero until set or discovered
}
// NewClient creates a new Client configured to talk to the coordination server
// specified in opts. It performs no I/O; the server's public key is discovered
// lazily on first use or can be set explicitly via SetControlPublicKey.
func NewClient(opts ClientOpts) (*Client, error) {
if opts.MachineKey.IsZero() {
return nil, fmt.Errorf("MachineKey is required")
}
logf := opts.Logf
if logf == nil {
logf = logger.Discard
}
return &Client{
opts: opts,
serverURL: cmp.Or(opts.ServerURL, DefaultServerURL),
logf: logf,
}, nil
}
// SetControlPublicKey sets the server's public key, bypassing lazy discovery.
// Any existing noise client is invalidated and will be re-created on next use.
func (c *Client) SetControlPublicKey(k key.MachinePublic) {
c.mu.Lock()
defer c.mu.Unlock()
c.serverPub = k
c.nc = nil
}
// DiscoverServerKey fetches the server's public key from the coordination
// server and stores it for subsequent use. Any existing noise client is
// invalidated.
func (c *Client) DiscoverServerKey(ctx context.Context) (key.MachinePublic, error) {
k, err := DiscoverServerKey(ctx, c.serverURL)
if err != nil {
return key.MachinePublic{}, err
}
c.mu.Lock()
defer c.mu.Unlock()
c.serverPub = k
c.nc = nil
return k, nil
}
// DiscoverServerKey fetches the coordination server's public key from the
// given server URL. It is a standalone function that requires no client state.
func DiscoverServerKey(ctx context.Context, serverURL string) (key.MachinePublic, error) {
serverURL = cmp.Or(serverURL, DefaultServerURL)
keysURL := serverURL + "/key?v=" + strconv.Itoa(int(tailcfg.CurrentCapabilityVersion))
req, err := http.NewRequestWithContext(ctx, "GET", keysURL, nil)
if err != nil {
return key.MachinePublic{}, fmt.Errorf("creating key request: %w", err)
}
res, err := http.DefaultClient.Do(req)
if err != nil {
return key.MachinePublic{}, fmt.Errorf("fetching server key: %w", err)
}
defer res.Body.Close()
if res.StatusCode != 200 {
return key.MachinePublic{}, fmt.Errorf("fetching server key: %s", res.Status)
}
var keys struct {
PublicKey key.MachinePublic
}
if err := json.NewDecoder(res.Body).Decode(&keys); err != nil {
return key.MachinePublic{}, fmt.Errorf("decoding server key: %w", err)
}
return keys.PublicKey, nil
}
// noiseClient returns the ts2021 noise client, creating it lazily if needed.
// If the server's public key is not yet known, it is discovered via HTTP.
func (c *Client) noiseClient(ctx context.Context) (*ts2021.Client, error) {
c.mu.Lock()
defer c.mu.Unlock()
if c.nc != nil {
return c.nc, nil
}
if c.serverPub.IsZero() {
// Discover server key without holding the lock, to avoid blocking
// other callers during the HTTP request.
c.mu.Unlock()
k, err := DiscoverServerKey(ctx, c.serverURL)
c.mu.Lock()
if err != nil {
return nil, err
}
// Re-check: another goroutine may have set it while we were unlocked.
if c.serverPub.IsZero() {
c.serverPub = k
}
// If nc was created by another goroutine while unlocked, use it.
if c.nc != nil {
return c.nc, nil
}
}
nc, err := ts2021.NewClient(ts2021.ClientOpts{
ServerURL: c.serverURL,
PrivKey: c.opts.MachineKey,
ServerPubKey: c.serverPub,
Dialer: tsdial.NewFromFuncForDebug(c.logf, (&net.Dialer{}).DialContext),
Logf: c.logf,
})
if err != nil {
return nil, fmt.Errorf("creating noise client: %w", err)
}
c.nc = nc
return nc, nil
}
// AnswerC2NPing handles a c2n PingRequest from the control plane by parsing the
// embedded HTTP request in the payload, routing it locally, and POSTing the HTTP
// response back to pr.URL using doNoiseRequest. The POST is done in a new
// goroutine so this method does not block.
//
// It reports whether the ping was handled. Unhandled pings (nil pr, non-c2n
// types, or unrecognized c2n paths) return false.
func (c *Client) AnswerC2NPing(ctx context.Context, pr *tailcfg.PingRequest, doNoiseRequest func(*http.Request) (*http.Response, error)) (handled bool) {
if pr == nil || pr.Types != "c2n" {
return false
}
// Parse the HTTP request from the payload.
httpReq, err := http.ReadRequest(bufio.NewReader(bytes.NewReader(pr.Payload)))
if err != nil {
c.logf("parsing c2n ping payload: %v", err)
return false
}
// Route the request locally.
var httpResp *http.Response
switch httpReq.URL.Path {
case "/echo":
body, _ := io.ReadAll(httpReq.Body)
httpResp = &http.Response{
StatusCode: 200,
Status: "200 OK",
Proto: "HTTP/1.1",
ProtoMajor: 1,
ProtoMinor: 1,
Header: http.Header{},
Body: io.NopCloser(bytes.NewReader(body)),
ContentLength: int64(len(body)),
}
default:
c.logf("ignoring c2n ping request for unhandled path %q", httpReq.URL.Path)
return false
}
// Serialize the HTTP response.
var buf bytes.Buffer
if err := httpResp.Write(&buf); err != nil {
c.logf("serializing c2n ping response: %v", err)
return false
}
// Send the response back to the control plane over the Noise channel.
go func() {
req, err := http.NewRequestWithContext(ctx, "POST", pr.URL, &buf)
if err != nil {
c.logf("creating c2n ping reply request: %v", err)
return
}
resp, err := doNoiseRequest(req)
if err != nil {
c.logf("sending c2n ping reply: %v", err)
return
}
resp.Body.Close()
}()
return true
}
// Close closes the client and releases resources.
func (c *Client) Close() error {
c.mu.Lock()
nc := c.nc
c.nc = nil
c.mu.Unlock()
if nc != nil {
nc.Close()
}
return nil
}
func defaultHostinfo() *tailcfg.Hostinfo {
return &tailcfg.Hostinfo{
OS: version.OS(),
IPNVersion: version.Long(),
}
}