diff --git a/cmd/tsp/tsp.go b/cmd/tsp/tsp.go new file mode 100644 index 000000000..a59b352d5 --- /dev/null +++ b/cmd/tsp/tsp.go @@ -0,0 +1,513 @@ +// Copyright (c) Tailscale Inc & contributors +// SPDX-License-Identifier: BSD-3-Clause + +// Program tsp is a low-level Tailscale protocol tool for performing +// composable building block operations like generating keys and +// registering nodes. +package main + +import ( + "bytes" + "cmp" + "context" + "encoding/json" + "errors" + "flag" + "fmt" + "io" + "os" + "reflect" + "strings" + + "github.com/peterbourgon/ff/v3/ffcli" + "tailscale.com/control/tsp" + "tailscale.com/hostinfo" + "tailscale.com/tailcfg" + "tailscale.com/types/key" +) + +var globalArgs struct { + // serverURL is the base URL of the coordination server (-s flag). + // If empty, tsp.DefaultServerURL is used. + serverURL string + + // controlKeyFile is a path to a file containing the server's + // MachinePublic key in MarshalText form (--control-key flag). + // When set, server key discovery is skipped. + controlKeyFile string +} + +func main() { + args := os.Args[1:] + if err := rootCmd.Parse(args); err != nil { + fmt.Fprintln(os.Stderr, err.Error()) + os.Exit(1) + } + err := rootCmd.Run(context.Background()) + if errors.Is(err, flag.ErrHelp) { + os.Exit(0) + } + if err != nil { + fmt.Fprintln(os.Stderr, err.Error()) + os.Exit(1) + } +} + +var rootCmd = &ffcli.Command{ + Name: "tsp", + ShortUsage: "tsp [-s url] [flags]", + ShortHelp: "Low-level Tailscale protocol tool.", + FlagSet: (func() *flag.FlagSet { + fs := flag.NewFlagSet("tsp", flag.ExitOnError) + fs.StringVar(&globalArgs.serverURL, "s", "", "base URL of coordination server (default: "+tsp.DefaultServerURL+")") + fs.StringVar(&globalArgs.controlKeyFile, "control-key", "", "file containing the server's public key (skips discovery)") + return fs + })(), + Subcommands: []*ffcli.Command{ + newMachineKeyCmd, + newNodeKeyCmd, + newNodeCmd, + registerCmd, + mapCmd, + discoverServerKeyCmd, + }, + Exec: func(ctx context.Context, args []string) error { + return flag.ErrHelp + }, +} + +var newMachineKeyArgs struct { + output string +} + +var newMachineKeyCmd = &ffcli.Command{ + Name: "new-machine-key", + ShortUsage: "tsp new-machine-key [-o file]", + ShortHelp: "Generate a new machine key.", + FlagSet: (func() *flag.FlagSet { + fs := flag.NewFlagSet("new-machine-key", flag.ExitOnError) + fs.StringVar(&newMachineKeyArgs.output, "o", "", "output file (default: stdout)") + return fs + })(), + Exec: runNewMachineKey, +} + +func runNewMachineKey(ctx context.Context, args []string) error { + k := key.NewMachine() + text, err := k.MarshalText() + if err != nil { + return err + } + text = append(text, '\n') + return writeOutput(newMachineKeyArgs.output, text) +} + +var newNodeKeyArgs struct { + output string +} + +var newNodeKeyCmd = &ffcli.Command{ + Name: "new-node-key", + ShortUsage: "tsp new-node-key [-o file]", + ShortHelp: "Generate a new node key.", + FlagSet: (func() *flag.FlagSet { + fs := flag.NewFlagSet("new-node-key", flag.ExitOnError) + fs.StringVar(&newNodeKeyArgs.output, "o", "", "output file (default: stdout)") + return fs + })(), + Exec: runNewNodeKey, +} + +func runNewNodeKey(ctx context.Context, args []string) error { + k := key.NewNode() + text, err := k.MarshalText() + if err != nil { + return err + } + text = append(text, '\n') + return writeOutput(newNodeKeyArgs.output, text) +} + +var discoverServerKeyArgs struct { + output string +} + +var discoverServerKeyCmd = &ffcli.Command{ + Name: "discover-server-key", + ShortUsage: "tsp [-s url] discover-server-key [-o file]", + ShortHelp: "Discover and print the coordination server's public key.", + FlagSet: (func() *flag.FlagSet { + fs := flag.NewFlagSet("discover-server-key", flag.ExitOnError) + fs.StringVar(&discoverServerKeyArgs.output, "o", "", "output file (default: stdout)") + return fs + })(), + Exec: runDiscoverServerKey, +} + +func runDiscoverServerKey(ctx context.Context, args []string) error { + k, err := tsp.DiscoverServerKey(ctx, globalArgs.serverURL) + if err != nil { + return err + } + text, err := k.MarshalText() + if err != nil { + return fmt.Errorf("marshaling server key: %w", err) + } + text = append(text, '\n') + return writeOutput(discoverServerKeyArgs.output, text) +} + +var newNodeArgs struct { + nodeKeyFile string + machineKeyFile string + output string +} + +var newNodeCmd = &ffcli.Command{ + Name: "new-node", + ShortUsage: "tsp [-s url] [--control-key file] new-node [-n node-key-file] [-m machine-key-file] [-o output]", + ShortHelp: "Generate a new node JSON file with keys and server info.", + FlagSet: (func() *flag.FlagSet { + fs := flag.NewFlagSet("new-node", flag.ExitOnError) + fs.StringVar(&newNodeArgs.nodeKeyFile, "n", "", "existing node key file (default: generate new)") + fs.StringVar(&newNodeArgs.machineKeyFile, "m", "", "existing machine key file (default: generate new)") + fs.StringVar(&newNodeArgs.output, "o", "", "output file (default: stdout)") + return fs + })(), + Exec: runNewNode, +} + +func runNewNode(ctx context.Context, args []string) error { + var nodeKey key.NodePrivate + if newNodeArgs.nodeKeyFile != "" { + var err error + nodeKey, err = readNodeKeyFile(newNodeArgs.nodeKeyFile) + if err != nil { + return fmt.Errorf("reading node key: %w", err) + } + } else { + nodeKey = key.NewNode() + } + + var machineKey key.MachinePrivate + if newNodeArgs.machineKeyFile != "" { + var err error + machineKey, err = readMachineKeyFile(newNodeArgs.machineKeyFile) + if err != nil { + return fmt.Errorf("reading machine key: %w", err) + } + } else { + machineKey = key.NewMachine() + } + + serverURL := cmp.Or(globalArgs.serverURL, tsp.DefaultServerURL) + + var serverKey key.MachinePublic + if globalArgs.controlKeyFile != "" { + var err error + serverKey, err = readControlKeyFile(globalArgs.controlKeyFile) + if err != nil { + return fmt.Errorf("reading control key: %w", err) + } + } else { + var err error + serverKey, err = tsp.DiscoverServerKey(ctx, serverURL) + if err != nil { + return fmt.Errorf("discovering server key: %w", err) + } + } + + nf := tsp.NodeFile{ + NodeKey: nodeKey, + MachineKey: machineKey, + ServerInfo: tsp.ServerInfo{URL: serverURL, Key: serverKey}, + } + + out, err := json.MarshalIndent(nf, "", " ") + if err != nil { + return fmt.Errorf("encoding node file: %w", err) + } + out = append(out, '\n') + return writeOutput(newNodeArgs.output, out) +} + +var registerArgs struct { + nodeFile string + output string + hostname string + ephemeral bool + authKey string + tags string +} + +var registerCmd = &ffcli.Command{ + Name: "register", + ShortUsage: "tsp [-s url] register -n [flags]", + ShortHelp: "Register a node key with a coordination server.", + FlagSet: (func() *flag.FlagSet { + fs := flag.NewFlagSet("register", flag.ExitOnError) + fs.StringVar(®isterArgs.nodeFile, "n", "", "node JSON file (required)") + fs.StringVar(®isterArgs.output, "o", "", "output file (default: stdout)") + fs.StringVar(®isterArgs.hostname, "hostname", "", "hostname to register") + fs.BoolVar(®isterArgs.ephemeral, "ephemeral", false, "register as ephemeral node") + fs.StringVar(®isterArgs.authKey, "auth-key", "", "pre-authorized auth key or file containing one") + fs.StringVar(®isterArgs.tags, "tags", "", "comma-separated ACL tags") + return fs + })(), + Exec: runRegister, +} + +func runRegister(ctx context.Context, args []string) error { + if registerArgs.nodeFile == "" { + return fmt.Errorf("flag -n (node file) is required") + } + + nf, err := tsp.ReadNodeFile(registerArgs.nodeFile) + if err != nil { + return fmt.Errorf("reading node file: %w", err) + } + + hi := hostinfo.New() + if registerArgs.hostname != "" { + hi.Hostname = registerArgs.hostname + } + + var tags []string + if registerArgs.tags != "" { + tags = strings.Split(registerArgs.tags, ",") + } + + authKey, err := resolveAuthKey(registerArgs.authKey) + if err != nil { + return err + } + + client, err := tsp.NewClient(tsp.ClientOpts{ + ServerURL: cmp.Or(globalArgs.serverURL, nf.URL), + MachineKey: nf.MachineKey, + }) + if err != nil { + return fmt.Errorf("creating client: %w", err) + } + defer client.Close() + + if globalArgs.controlKeyFile != "" { + controlKey, err := readControlKeyFile(globalArgs.controlKeyFile) + if err != nil { + return fmt.Errorf("reading control key: %w", err) + } + client.SetControlPublicKey(controlKey) + } else { + client.SetControlPublicKey(nf.ServerInfo.Key) + } + + resp, err := client.Register(ctx, tsp.RegisterOpts{ + NodeKey: nf.NodeKey, + Hostinfo: hi, + Ephemeral: registerArgs.ephemeral, + AuthKey: authKey, + Tags: tags, + }) + if err != nil { + return err + } + + out, err := json.MarshalIndent(resp, "", " ") + if err != nil { + return fmt.Errorf("encoding response: %w", err) + } + out = append(out, '\n') + + if err := writeOutput(registerArgs.output, out); err != nil { + return err + } + + if resp.AuthURL != "" { + fmt.Fprintf(os.Stderr, "AuthURL: %s\n", resp.AuthURL) + } + return nil +} + +var mapArgs struct { + nodeFile string + stream bool + peers bool + quiet bool + output string +} + +var mapCmd = &ffcli.Command{ + Name: "map", + ShortUsage: "tsp [-s url] map -n [-stream]", + ShortHelp: "Send a map request to the coordination server.", + FlagSet: (func() *flag.FlagSet { + fs := flag.NewFlagSet("map", flag.ExitOnError) + fs.StringVar(&mapArgs.nodeFile, "n", "", "node JSON file (required)") + fs.BoolVar(&mapArgs.stream, "stream", false, "stream map responses") + fs.BoolVar(&mapArgs.peers, "peers", true, "include peers in map response") + fs.BoolVar(&mapArgs.quiet, "quiet", true, "suppress keepalives and handled c2n ping requests from output") + fs.StringVar(&mapArgs.output, "o", "", "output file (default: stdout)") + return fs + })(), + Exec: runMap, +} + +func runMap(ctx context.Context, args []string) error { + if mapArgs.nodeFile == "" { + return fmt.Errorf("flag -n (node file) is required") + } + + nf, err := tsp.ReadNodeFile(mapArgs.nodeFile) + if err != nil { + return fmt.Errorf("reading node file: %w", err) + } + + if globalArgs.serverURL != "" && globalArgs.serverURL != nf.URL { + return fmt.Errorf("server URL mismatch: -s flag is %q but node file is for %q", globalArgs.serverURL, nf.URL) + } + + hi := hostinfo.New() + + client, err := tsp.NewClient(tsp.ClientOpts{ + ServerURL: cmp.Or(globalArgs.serverURL, nf.URL), + MachineKey: nf.MachineKey, + }) + if err != nil { + return fmt.Errorf("creating client: %w", err) + } + defer client.Close() + + if globalArgs.controlKeyFile != "" { + controlKey, err := readControlKeyFile(globalArgs.controlKeyFile) + if err != nil { + return fmt.Errorf("reading control key: %w", err) + } + client.SetControlPublicKey(controlKey) + } else { + client.SetControlPublicKey(nf.ServerInfo.Key) + } + + session, err := client.Map(ctx, tsp.MapOpts{ + NodeKey: nf.NodeKey, + Hostinfo: hi, + Stream: mapArgs.stream, + OmitPeers: !mapArgs.peers, + }) + if err != nil { + return err + } + defer session.Close() + + gotResponse := false + for { + resp, err := session.Next() + if err == io.EOF { + if !gotResponse { + return fmt.Errorf("server returned no map response") + } + return nil + } + if err != nil { + return fmt.Errorf("reading map response: %w", err) + } + gotResponse = true + + if pr := resp.PingRequest; pr != nil && pr.Types == "c2n" { + if client.AnswerC2NPing(ctx, pr, session.NoiseRoundTrip) && mapArgs.quiet { + resp.PingRequest = nil + } + } + if mapArgs.quiet { + resp.KeepAlive = false + } + + if isZeroMapResponse(resp) { + continue + } + + out, err := json.MarshalIndent(resp, "", " ") + if err != nil { + return fmt.Errorf("encoding response: %w", err) + } + out = append(out, '\n') + if err := writeOutput(mapArgs.output, out); err != nil { + return err + } + } +} + +// readMachineKeyFile reads a machine private key from a file. +func readMachineKeyFile(path string) (key.MachinePrivate, error) { + data, err := os.ReadFile(path) + if err != nil { + return key.MachinePrivate{}, err + } + var k key.MachinePrivate + if err := k.UnmarshalText(bytes.TrimSpace(data)); err != nil { + return key.MachinePrivate{}, fmt.Errorf("parsing machine key from %q: %w", path, err) + } + return k, nil +} + +// readNodeKeyFile reads a node private key from a file. +func readNodeKeyFile(path string) (key.NodePrivate, error) { + data, err := os.ReadFile(path) + if err != nil { + return key.NodePrivate{}, err + } + var k key.NodePrivate + if err := k.UnmarshalText(bytes.TrimSpace(data)); err != nil { + return key.NodePrivate{}, fmt.Errorf("parsing node key from %q: %w", path, err) + } + return k, nil +} + +// readControlKeyFile reads a file containing a server's MachinePublic key +// in its MarshalText form (e.g. "mkey:..."). +func readControlKeyFile(path string) (key.MachinePublic, error) { + data, err := os.ReadFile(path) + if err != nil { + return key.MachinePublic{}, err + } + var k key.MachinePublic + if err := k.UnmarshalText(bytes.TrimSpace(data)); err != nil { + return key.MachinePublic{}, fmt.Errorf("parsing control key from %q: %w", path, err) + } + return k, nil +} + +// resolveAuthKey returns the auth key from v. If v is empty, it returns "". +// If v starts with "tskey-", it's used directly. Otherwise v is treated as a +// filename and its contents are read and trimmed. +func resolveAuthKey(v string) (string, error) { + if v == "" { + return "", nil + } + if strings.HasPrefix(strings.TrimSpace(v), "tskey-") { + return strings.TrimSpace(v), nil + } + data, err := os.ReadFile(v) + if err != nil { + return "", fmt.Errorf("reading auth key file: %w", err) + } + return strings.TrimSpace(string(data)), nil +} + +func writeOutput(path string, data []byte) error { + if path == "" { + _, err := os.Stdout.Write(data) + return err + } + return os.WriteFile(path, data, 0600) +} + +// isZeroMapResponse reports whether all fields of resp are zero values. +func isZeroMapResponse(resp *tailcfg.MapResponse) bool { + v := reflect.ValueOf(*resp) + for i := range v.NumField() { + if !v.Field(i).IsZero() { + return false + } + } + return true +} diff --git a/control/tsp/map.go b/control/tsp/map.go new file mode 100644 index 000000000..96531255b --- /dev/null +++ b/control/tsp/map.go @@ -0,0 +1,339 @@ +// Copyright (c) Tailscale Inc & contributors +// SPDX-License-Identifier: BSD-3-Clause + +package tsp + +import ( + "bytes" + "cmp" + "context" + "encoding/binary" + "encoding/json" + "errors" + "fmt" + "io" + "net/http" + "strings" + "sync" + "sync/atomic" + + "github.com/klauspost/compress/zstd" + "tailscale.com/control/ts2021" + "tailscale.com/tailcfg" + "tailscale.com/types/key" +) + +// errSessionClosed is returned by [MapSession.Next] and +// [MapSession.NextInto] when called after [MapSession.Close]. +var errSessionClosed = errors.New("tsp: map session closed") + +// DefaultMaxMessageSize is the default cap, in bytes, on the size of a +// single compressed map response frame. See [MapOpts.MaxMessageSize]. +const DefaultMaxMessageSize = 4 << 20 + +// zstdDecoderPool is a pool of *zstd.Decoder reused across MapSessions to +// amortize the cost of setting up zstd state. Decoders are returned via +// [MapSession.Close]; entries are reclaimed by the runtime under memory +// pressure via sync.Pool semantics. +var zstdDecoderPool sync.Pool // of *zstd.Decoder + +// MapOpts contains options for sending a map request. +type MapOpts struct { + // NodeKey is the node's private key. Required. + NodeKey key.NodePrivate + + // Hostinfo is the host information to send. Optional; + // if nil, a minimal default is used. + Hostinfo *tailcfg.Hostinfo + + // Stream is whether to receive multiple MapResponses over + // the same HTTP connection. + Stream bool + + // OmitPeers is whether the client is okay with the Peers list + // being omitted in the response. + OmitPeers bool + + // MaxMessageSize is the maximum size in bytes of any single + // compressed map response frame on the wire. If zero, + // [DefaultMaxMessageSize] is used. + MaxMessageSize int64 +} + +// framedReader is an io.Reader that consumes a stream of length-prefixed +// frames (each a little-endian uint32 length followed by that many bytes) +// from r and yields only the frame payloads back-to-back. +// +// This lets us feed the concatenated zstd frames from our wire protocol +// into a single streaming zstd decoder. Zstd's file format permits +// concatenation (RFC 8478 §2), and klauspost's decoder handles it +// transparently. +// +// If onNewFrame is non-nil, it is called after each new 4-byte length +// header is successfully read. Used to reset the per-message decoded-size +// budget downstream. +type framedReader struct { + r io.Reader + maxSize int64 // per-frame compressed-size cap + remain int // bytes remaining in the current frame + onNewFrame func() +} + +func (f *framedReader) Read(p []byte) (int, error) { + if f.remain == 0 { + var hdr [4]byte + if _, err := io.ReadFull(f.r, hdr[:]); err != nil { + return 0, err + } + sz := int64(binary.LittleEndian.Uint32(hdr[:])) + if sz == 0 { + return 0, fmt.Errorf("map response: zero-length frame") + } + if sz > f.maxSize { + return 0, fmt.Errorf("map response frame size %d exceeds max %d", sz, f.maxSize) + } + f.remain = int(sz) + if f.onNewFrame != nil { + f.onNewFrame() + } + } + if len(p) > f.remain { + p = p[:f.remain] + } + n, err := f.r.Read(p) + f.remain -= n + return n, err +} + +// boundedReader is an io.Reader that yields at most remain bytes from r +// before returning an error. Call reset to raise the budget back to max, +// typically at a new message boundary. +// +// Used to cap the decoded size of a single map response so a malicious +// server can't send a small zstd frame that explodes into gigabytes of +// junk for the json.Decoder to consume. +type boundedReader struct { + r io.Reader + max int64 + remain int64 +} + +func (b *boundedReader) Read(p []byte) (int, error) { + if b.remain <= 0 { + return 0, fmt.Errorf("map response decoded size exceeds max %d", b.max) + } + if int64(len(p)) > b.remain { + p = p[:b.remain] + } + n, err := b.r.Read(p) + b.remain -= int64(n) + return n, err +} + +func (b *boundedReader) reset() { b.remain = b.max } + +// MapSession wraps an in-progress map response stream. Call Next to read +// each MapResponse. Call Close when done. +type MapSession struct { + res *http.Response + stream bool + noiseDoer func(*http.Request) (*http.Response, error) + + // inNext detects concurrent NextInto callers. It CAS-flips + // false→true on entry and back to false on exit; a failed CAS + // panics, akin to how the Go runtime detects concurrent map + // access. It does not serialize Close vs. NextInto; that's + // nextMu's job. + inNext atomic.Bool + + // nextMu is held while [MapSession.NextInto] is running jdec.Decode, + // so that Close can wait for an in-flight Decode to unwind before it + // touches zdec (Reset, pool-Put) and avoids racing with the running + // Read chain that Decode drives. + nextMu sync.Mutex + read int // guarded by nextMu + closed bool // guarded by nextMu + zdec *zstd.Decoder // reads from a framedReader wrapping res.Body + jdec *json.Decoder // reads decompressed JSON from zdec + + closeOnce sync.Once + closeErr error +} + +// NoiseRoundTrip sends an HTTP request over the Noise channel used by this map session. +func (s *MapSession) NoiseRoundTrip(req *http.Request) (*http.Response, error) { + return s.noiseDoer(req) +} + +// Next reads and returns the next MapResponse from the stream. +// For non-streaming sessions, the first call returns the single response +// and subsequent calls return io.EOF. +// For streaming sessions, Next blocks until the next response arrives +// or the server closes the connection. +// +// Each call allocates a fresh MapResponse. Callers that want to amortize +// the allocation across calls can use [MapSession.NextInto]. +// +// Next and NextInto are not safe to call concurrently from multiple +// goroutines on the same [MapSession]; a concurrent call panics, akin +// to the Go runtime's concurrent map access detection. [MapSession.Close] +// may be called concurrently to abort an in-flight Next. +func (s *MapSession) Next() (*tailcfg.MapResponse, error) { + var resp tailcfg.MapResponse + if err := s.NextInto(&resp); err != nil { + return nil, err + } + return &resp, nil +} + +// NextInto is like [MapSession.Next] but decodes the next MapResponse into +// the caller-supplied *resp rather than allocating a new one. The pointer's +// pointee is zeroed before decoding so fields from a prior response do not +// persist. +// +// For non-streaming sessions, the first call decodes the single response +// and subsequent calls return io.EOF. +// For streaming sessions, NextInto blocks until the next response arrives +// or the server closes the connection. +// +// See [MapSession.Next] for concurrency rules; those apply to NextInto too. +func (s *MapSession) NextInto(resp *tailcfg.MapResponse) error { + if !s.inNext.CompareAndSwap(false, true) { + panic("tsp: invalid concurrent call to MapSession.Next/NextInto") + } + defer s.inNext.Store(false) + + s.nextMu.Lock() + defer s.nextMu.Unlock() + if s.closed { + return errSessionClosed + } + if !s.stream && s.read > 0 { + return io.EOF + } + *resp = tailcfg.MapResponse{} + if err := s.jdec.Decode(resp); err != nil { + return err + } + s.read++ + return nil +} + +// Close returns the session's zstd decoder to the pool and closes the +// underlying HTTP response body. It is safe to call Close multiple times +// and from multiple goroutines, including while a [MapSession.Next] or +// [MapSession.NextInto] call is in flight on another goroutine (which +// will return an error once the body close propagates). +func (s *MapSession) Close() error { + // Callers are likely to race a deferred Close with a time.AfterFunc + // timeout (or similar) Close that aborts a hung Next. Without the + // Once, both Closes would Put the same *zstd.Decoder into the pool, + // corrupting it, and the Reset/Put in one would race with the + // zdec.Read that the hung Next is driving. + // + // Ordering inside the Once: close the body first to unblock any + // in-flight NextInto (its Read chain ends at res.Body and will + // return an error once it's closed). That lets NextInto unwind and + // release nextMu. Only then do we take nextMu ourselves and touch + // zdec, which is safe because no goroutine is still reading from + // it. Acquiring nextMu before closing the body would deadlock + // against a hung NextInto. + s.closeOnce.Do(func() { + s.closeErr = s.res.Body.Close() + s.nextMu.Lock() + defer s.nextMu.Unlock() + s.closed = true + s.zdec.Reset(nil) + zstdDecoderPool.Put(s.zdec) + }) + return s.closeErr +} + +// Map sends a map request to the coordination server and returns a MapSession +// for reading the framed, zstd-compressed response(s). +func (c *Client) Map(ctx context.Context, opts MapOpts) (*MapSession, error) { + if opts.NodeKey.IsZero() { + return nil, fmt.Errorf("NodeKey is required") + } + + hi := opts.Hostinfo + if hi == nil { + hi = defaultHostinfo() + } + + mapReq := tailcfg.MapRequest{ + Version: tailcfg.CurrentCapabilityVersion, + NodeKey: opts.NodeKey.Public(), + Hostinfo: hi, + Stream: opts.Stream, + Compress: "zstd", + OmitPeers: opts.OmitPeers, + // Streaming requires the server to track us as "connected", + // which in turn requires ReadOnly=false. Non-streaming polls + // stay ReadOnly to minimize side effects. + ReadOnly: !opts.Stream, + } + + body, err := json.Marshal(mapReq) + if err != nil { + return nil, fmt.Errorf("encoding map request: %w", err) + } + + nc, err := c.noiseClient(ctx) + if err != nil { + return nil, fmt.Errorf("establishing noise connection: %w", err) + } + + url := c.serverURL + "/machine/map" + url = strings.Replace(url, "http:", "https:", 1) + req, err := http.NewRequestWithContext(ctx, "POST", url, bytes.NewReader(body)) + if err != nil { + return nil, fmt.Errorf("creating map request: %w", err) + } + ts2021.AddLBHeader(req, opts.NodeKey.Public()) + + res, err := nc.Do(req) + if err != nil { + return nil, fmt.Errorf("map request: %w", err) + } + + if res.StatusCode != 200 { + msg, _ := io.ReadAll(res.Body) + res.Body.Close() + return nil, fmt.Errorf("map request: http %d: %.200s", + res.StatusCode, strings.TrimSpace(string(msg))) + } + + maxMessageSize := cmp.Or(opts.MaxMessageSize, DefaultMaxMessageSize) + bounded := &boundedReader{max: maxMessageSize, remain: maxMessageSize} + fr := &framedReader{ + r: res.Body, + maxSize: maxMessageSize, + onNewFrame: bounded.reset, + } + + zdec, _ := zstdDecoderPool.Get().(*zstd.Decoder) + if zdec != nil { + if err := zdec.Reset(fr); err != nil { + // Reset can fail if the previous stream is in a bad state; drop + // the decoder and create a fresh one. + zdec = nil + } + } + if zdec == nil { + zdec, err = zstd.NewReader(fr, zstd.WithDecoderConcurrency(1)) + if err != nil { + res.Body.Close() + return nil, fmt.Errorf("creating zstd decoder: %w", err) + } + } + bounded.r = zdec + + return &MapSession{ + res: res, + stream: opts.Stream, + noiseDoer: nc.Do, + zdec: zdec, + jdec: json.NewDecoder(bounded), + }, nil +} diff --git a/control/tsp/map_test.go b/control/tsp/map_test.go new file mode 100644 index 000000000..15b32dd36 --- /dev/null +++ b/control/tsp/map_test.go @@ -0,0 +1,270 @@ +// Copyright (c) Tailscale Inc & contributors +// SPDX-License-Identifier: BSD-3-Clause + +package tsp + +import ( + "bytes" + "context" + "encoding/binary" + "encoding/json" + "fmt" + "io" + "net/http/httptest" + "strings" + "testing" + "time" + + "github.com/klauspost/compress/zstd" + "tailscale.com/tailcfg" + "tailscale.com/tstest/integration/testcontrol" + "tailscale.com/types/key" +) + +func TestMapAgainstTestControl(t *testing.T) { + ctrl := &testcontrol.Server{} + ctrl.HTTPTestServer = httptest.NewUnstartedServer(ctrl) + ctrl.HTTPTestServer.Start() + t.Cleanup(ctrl.HTTPTestServer.Close) + baseURL := ctrl.HTTPTestServer.URL + + ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second) + defer cancel() + + serverKey, err := DiscoverServerKey(ctx, baseURL) + if err != nil { + t.Fatalf("DiscoverServerKey: %v", err) + } + + register := func(hostname string) (nodeKey key.NodePrivate, machineKey key.MachinePrivate) { + t.Helper() + nodeKey = key.NewNode() + machineKey = key.NewMachine() + c, err := NewClient(ClientOpts{ + ServerURL: baseURL, + MachineKey: machineKey, + }) + if err != nil { + t.Fatalf("NewClient %s: %v", hostname, err) + } + defer c.Close() + c.SetControlPublicKey(serverKey) + if _, err := c.Register(ctx, RegisterOpts{ + NodeKey: nodeKey, + Hostinfo: &tailcfg.Hostinfo{Hostname: hostname}, + }); err != nil { + t.Fatalf("Register %s: %v", hostname, err) + } + return nodeKey, machineKey + } + + nodeKeyA, machineKeyA := register("a") + nodeKeyB, _ := register("b") + + clientA, err := NewClient(ClientOpts{ + ServerURL: baseURL, + MachineKey: machineKeyA, + }) + if err != nil { + t.Fatalf("NewClient A: %v", err) + } + defer clientA.Close() + clientA.SetControlPublicKey(serverKey) + + session, err := clientA.Map(ctx, MapOpts{ + NodeKey: nodeKeyA, + Hostinfo: &tailcfg.Hostinfo{Hostname: "a"}, + Stream: true, + }) + if err != nil { + t.Fatalf("Map: %v", err) + } + defer session.Close() + + // nextNonKeepalive returns the next non-keepalive MapResponse, to keep + // the test robust if a server-side keepalive arrives mid-test. + nextNonKeepalive := func() *tailcfg.MapResponse { + t.Helper() + for { + resp, err := session.Next() + if err != nil { + t.Fatalf("session.Next: %v", err) + } + if resp.KeepAlive { + continue + } + return resp + } + } + + // First MapResponse: expect node A as self and node B in Peers. + first := nextNonKeepalive() + if first.Node == nil { + t.Fatal("first response has nil Node") + } + if got, want := first.Node.Key, nodeKeyA.Public(); got != want { + t.Errorf("first Node.Key = %v, want %v", got, want) + } + var foundB bool + for _, p := range first.Peers { + if p.Key == nodeKeyB.Public() { + foundB = true + break + } + } + if !foundB { + t.Errorf("peer B (%v) not in first response's Peers (%d peers)", nodeKeyB.Public(), len(first.Peers)) + } + + // Inject raw MapResponses and verify they come out the reader, in order. + // msgToSend is single-slot, so we must consume each before injecting the next. + for i := range 3 { + want := fmt.Sprintf("injected-%d.example.com", i) + inject := &tailcfg.MapResponse{Domain: want} + if !ctrl.AddRawMapResponse(nodeKeyA.Public(), inject) { + t.Fatalf("AddRawMapResponse %d: node not connected", i) + } + got := nextNonKeepalive() + if got.Domain != want { + t.Errorf("injected %d: got Domain=%q, want %q", i, got.Domain, want) + } + } +} + +// newTestPipeline builds the same framedReader → zstd → boundedReader → +// json.Decoder pipeline that [Client.Map] builds for a live session, but +// feeds it from a raw byte slice. Returned jdec can be used with Decode to +// pull out MapResponses. +func newTestPipeline(t testing.TB, wire []byte, maxMessageSize int64) *json.Decoder { + t.Helper() + bounded := &boundedReader{max: maxMessageSize, remain: maxMessageSize} + fr := &framedReader{ + r: bytes.NewReader(wire), + maxSize: maxMessageSize, + onNewFrame: bounded.reset, + } + zdec, err := zstd.NewReader(fr, zstd.WithDecoderConcurrency(1)) + if err != nil { + t.Fatalf("zstd.NewReader: %v", err) + } + t.Cleanup(zdec.Close) + bounded.r = zdec + return json.NewDecoder(bounded) +} + +// zstdFrame returns a zstd-compressed frame of b. +func zstdFrame(t testing.TB, b []byte) []byte { + t.Helper() + enc, err := zstd.NewWriter(io.Discard, zstd.WithEncoderConcurrency(1)) + if err != nil { + t.Fatalf("zstd.NewWriter: %v", err) + } + defer enc.Close() + return enc.EncodeAll(b, nil) +} + +// wireFrame writes a 4-byte little-endian length prefix plus payload to buf. +func wireFrame(buf *bytes.Buffer, payload []byte) { + var hdr [4]byte + binary.LittleEndian.PutUint32(hdr[:], uint32(len(payload))) + buf.Write(hdr[:]) + buf.Write(payload) +} + +// TestMapFrameSizeTooLarge verifies that a 4-byte length prefix claiming +// a frame larger than the configured cap is rejected before any payload +// bytes are read from the stream. +func TestMapFrameSizeTooLarge(t *testing.T) { + const max = 4 << 20 + var wire bytes.Buffer + var hdr [4]byte + binary.LittleEndian.PutUint32(hdr[:], (max + 1)) + wire.Write(hdr[:]) + + jdec := newTestPipeline(t, wire.Bytes(), max) + var resp tailcfg.MapResponse + err := jdec.Decode(&resp) + if err == nil { + t.Fatal("Decode: got nil error, want frame-too-large") + } + if !strings.Contains(err.Error(), "exceeds max") { + t.Errorf("Decode error = %q, want one containing %q", err, "exceeds max") + } +} + +// TestMapDecodedSizeTooLarge verifies that a small on-wire frame (well +// under the cap) which decompresses into a huge JSON payload is rejected. +// This is the "zstd bomb" case: a tiny compressed frame that would +// explode into a huge decoded payload for json.Decoder to consume. +func TestMapDecodedSizeTooLarge(t *testing.T) { + const max = 4 << 20 + big := strings.Repeat("a", 5<<20) // 5 MiB of 'a' + raw, err := json.Marshal(&tailcfg.MapResponse{Domain: big}) + if err != nil { + t.Fatal(err) + } + if int64(len(raw)) <= max { + t.Fatalf("raw JSON unexpectedly small: %d", len(raw)) + } + compressed := zstdFrame(t, raw) + if int64(len(compressed)) >= max { + t.Fatalf("compressed too large (%d); test needs a more compressible payload", len(compressed)) + } + + var wire bytes.Buffer + wireFrame(&wire, compressed) + + jdec := newTestPipeline(t, wire.Bytes(), max) + var resp tailcfg.MapResponse + err = jdec.Decode(&resp) + if err == nil { + t.Fatal("Decode: got nil error, want decoded-size-exceeded") + } + if !strings.Contains(err.Error(), "exceeds max") { + t.Errorf("Decode error = %q, want one containing %q", err, "exceeds max") + } +} + +// TestMapBudgetResetsBetweenFrames verifies that the per-message decoded +// budget is reset at each new frame boundary. Two consecutive 3-MiB frames +// should both decode successfully under a 4-MiB per-frame cap. Without the +// reset, the second frame would fail (remaining budget after frame 1 = +// 4MiB - 3MiB = 1MiB, and we'd try to read 3MiB more). +func TestMapBudgetResetsBetweenFrames(t *testing.T) { + const max = 4 << 20 + payload := strings.Repeat("a", 3<<20) + r1 := &tailcfg.MapResponse{Domain: payload + "-one"} + r2 := &tailcfg.MapResponse{Domain: payload + "-two"} + + var wire bytes.Buffer + for _, r := range []*tailcfg.MapResponse{r1, r2} { + raw, err := json.Marshal(r) + if err != nil { + t.Fatal(err) + } + if int64(len(raw)) >= max { + t.Fatalf("raw JSON size %d >= max %d; would fail budget check by itself", len(raw), max) + } + compressed := zstdFrame(t, raw) + if int64(len(compressed)) >= max { + t.Fatalf("compressed size %d >= max %d", len(compressed), max) + } + wireFrame(&wire, compressed) + } + + jdec := newTestPipeline(t, wire.Bytes(), max) + + var got1, got2 tailcfg.MapResponse + if err := jdec.Decode(&got1); err != nil { + t.Fatalf("first Decode: %v", err) + } + if got1.Domain != r1.Domain { + t.Errorf("first Domain mismatch (len %d vs %d)", len(got1.Domain), len(r1.Domain)) + } + if err := jdec.Decode(&got2); err != nil { + t.Fatalf("second Decode: %v", err) + } + if got2.Domain != r2.Domain { + t.Errorf("second Domain mismatch (len %d vs %d)", len(got2.Domain), len(r2.Domain)) + } +} diff --git a/control/tsp/nodefile.go b/control/tsp/nodefile.go new file mode 100644 index 000000000..8cae11ba9 --- /dev/null +++ b/control/tsp/nodefile.go @@ -0,0 +1,105 @@ +// Copyright (c) Tailscale Inc & contributors +// SPDX-License-Identifier: BSD-3-Clause + +package tsp + +import ( + "encoding/json" + "fmt" + "os" + + "tailscale.com/types/key" +) + +// ServerInfo identifies a coordination server by its URL and Noise public key. +type ServerInfo struct { + // URL is the base URL of the coordination server, without any path + // (e.g. "https://controlplane.tailscale.com"). + // + // There is no default value; a URL must always be supplied. + URL string `json:"server_url"` + + // Key is the server's Noise public key, used to establish an encrypted + // channel between the client and the coordination server. + Key key.MachinePublic `json:"server_key"` +} + +// NodeFile is the JSON structure for a node credentials file. It contains +// the private keys that authenticate a node to a coordination server. +// +// Example: +// +// { +// "node_key": "privkey:...", +// "machine_key": "privkey:...", +// "server_url": "https://controlplane.tailscale.com", +// "server_key": "mkey:..." +// } +// +// Note that node and machine private keys share the same "privkey:" +// textual form; they are disambiguated by the surrounding JSON field +// names rather than by any prefix in the key itself. +type NodeFile struct { + // NodeKey is the node's WireGuard private key. The corresponding + // public key identifies this node to other peers. + NodeKey key.NodePrivate `json:"node_key"` + + // MachineKey is the machine's private key. It authenticates this + // machine to the coordination server over Noise. + MachineKey key.MachinePrivate `json:"machine_key"` + + ServerInfo // server_url and server_key +} + +// ReadNodeFile reads and parses a node JSON file. +func ReadNodeFile(path string) (NodeFile, error) { + data, err := os.ReadFile(path) + if err != nil { + return NodeFile{}, err + } + var nf NodeFile + if err := json.Unmarshal(data, &nf); err != nil { + return NodeFile{}, fmt.Errorf("parsing node file %q: %w", path, err) + } + return nf, nil +} + +// WriteNodeFile writes a node JSON file. The file is created with mode 0600. +func WriteNodeFile(path string, nf NodeFile) error { + if err := nf.Check(); err != nil { + return fmt.Errorf("invalid NodeFile: %w", err) + } + return os.WriteFile(path, nf.AsJSON(), 0600) +} + +// AsJSON returns nf as a pretty-printed JSON object, terminated by a newline. +// +// It always succeeds and always returns a valid JSON object. It does not +// validate that the fields of nf are non-zero; it is the caller's +// responsibility to call [NodeFile.Check] first if they want to reject +// incomplete NodeFiles. +func (nf NodeFile) AsJSON() []byte { + out, err := json.MarshalIndent(nf, "", " ") + if err != nil { + panic(fmt.Sprintf("NodeFile.AsJSON: %v", err)) // unreachable: all fields marshal successfully + } + return append(out, '\n') +} + +// Check reports whether nf has all required fields set. +// It returns an error describing the first zero-valued field, if any. +func (nf NodeFile) Check() error { + if nf.NodeKey.IsZero() { + return fmt.Errorf("node_key is missing") + } + if nf.MachineKey.IsZero() { + return fmt.Errorf("machine_key is missing") + } + if nf.URL == "" { + return fmt.Errorf("server_url is missing") + } + if nf.ServerInfo.Key.IsZero() { + return fmt.Errorf("server_key is missing") + } + return nil +} diff --git a/control/tsp/nodefile_test.go b/control/tsp/nodefile_test.go new file mode 100644 index 000000000..4a019f25f --- /dev/null +++ b/control/tsp/nodefile_test.go @@ -0,0 +1,116 @@ +// Copyright (c) Tailscale Inc & contributors +// SPDX-License-Identifier: BSD-3-Clause + +package tsp + +import ( + "encoding/json" + "os" + "path/filepath" + "testing" + + "tailscale.com/types/key" +) + +func TestNodeFileRoundTrip(t *testing.T) { + dir := t.TempDir() + path := filepath.Join(dir, "node.json") + + nf := NodeFile{ + NodeKey: key.NewNode(), + MachineKey: key.NewMachine(), + ServerInfo: ServerInfo{ + URL: "https://controlplane.tailscale.com", + Key: key.NewMachine().Public(), + }, + } + + if err := WriteNodeFile(path, nf); err != nil { + t.Fatalf("WriteNodeFile: %v", err) + } + + got, err := ReadNodeFile(path) + if err != nil { + t.Fatalf("ReadNodeFile: %v", err) + } + if !got.NodeKey.Equal(nf.NodeKey) { + t.Errorf("node key mismatch") + } + if !got.MachineKey.Equal(nf.MachineKey) { + t.Errorf("machine key mismatch") + } + if got.URL != nf.URL { + t.Errorf("server URL = %q, want %q", got.URL, nf.URL) + } + if got.ServerInfo.Key != nf.ServerInfo.Key { + t.Errorf("server key mismatch") + } +} + +// TestNodeFileFormat verifies that ReadNodeFile can parse a fixed JSON literal, +// ensuring we don't accidentally change the on-disk format. +func TestNodeFileFormat(t *testing.T) { + const fileContents = `{ + "node_key": "privkey:0123456789abcdef0123456789abcdef0123456789abcdef0123456789abcdef", + "machine_key": "privkey:fedcba9876543210fedcba9876543210fedcba9876543210fedcba9876543210", + "server_url": "https://controlplane.tailscale.com", + "server_key": "mkey:1111111111111111111111111111111111111111111111111111111111111111" +}` + dir := t.TempDir() + path := filepath.Join(dir, "node.json") + if err := os.WriteFile(path, []byte(fileContents), 0600); err != nil { + t.Fatal(err) + } + + nf, err := ReadNodeFile(path) + if err != nil { + t.Fatalf("ReadNodeFile: %v", err) + } + if nf.NodeKey.IsZero() { + t.Error("node key is zero") + } + if nf.MachineKey.IsZero() { + t.Error("machine key is zero") + } + if nf.URL != "https://controlplane.tailscale.com" { + t.Errorf("server URL = %q", nf.URL) + } + if nf.ServerInfo.Key.IsZero() { + t.Error("server key is zero") + } +} + +// TestNodeFileWriteFormat verifies that WriteNodeFile produces the expected +// JSON field names. +func TestNodeFileWriteFormat(t *testing.T) { + dir := t.TempDir() + path := filepath.Join(dir, "node.json") + + nf := NodeFile{ + NodeKey: key.NewNode(), + MachineKey: key.NewMachine(), + ServerInfo: ServerInfo{ + URL: "https://example.com", + Key: key.NewMachine().Public(), + }, + } + + if err := WriteNodeFile(path, nf); err != nil { + t.Fatalf("WriteNodeFile: %v", err) + } + + data, err := os.ReadFile(path) + if err != nil { + t.Fatal(err) + } + + var raw map[string]json.RawMessage + if err := json.Unmarshal(data, &raw); err != nil { + t.Fatalf("parsing written JSON: %v", err) + } + for _, field := range []string{"node_key", "machine_key", "server_url", "server_key"} { + if _, ok := raw[field]; !ok { + t.Errorf("missing JSON field %q in written file", field) + } + } +} diff --git a/control/tsp/register.go b/control/tsp/register.go new file mode 100644 index 000000000..0d2baf75f --- /dev/null +++ b/control/tsp/register.go @@ -0,0 +1,116 @@ +// Copyright (c) Tailscale Inc & contributors +// SPDX-License-Identifier: BSD-3-Clause + +package tsp + +import ( + "bytes" + "cmp" + "context" + "encoding/json" + "fmt" + "io" + "net/http" + "strings" + + "tailscale.com/control/ts2021" + "tailscale.com/tailcfg" + "tailscale.com/types/key" +) + +// RegisterOpts contains options for registering a node. +type RegisterOpts struct { + // NodeKey is the node's private key. Required. + NodeKey key.NodePrivate + + // Hostinfo is the host information to send. Optional; + // if nil, a minimal default is used. + Hostinfo *tailcfg.Hostinfo + + // Ephemeral marks the node as ephemeral. + Ephemeral bool + + // AuthKey is a pre-authorized auth key. + AuthKey string + + // Tags is a list of ACL tags to request. + Tags []string + + // MaxResponseSize is the maximum size in bytes of the register + // response body. If zero, [DefaultMaxMessageSize] is used. + MaxResponseSize int64 +} + +// Register sends a registration request to the coordination server +// and returns the response. +func (c *Client) Register(ctx context.Context, opts RegisterOpts) (*tailcfg.RegisterResponse, error) { + hi := opts.Hostinfo + if hi == nil { + hi = defaultHostinfo() + } + if len(opts.Tags) > 0 { + hi.RequestTags = opts.Tags + } + + regReq := tailcfg.RegisterRequest{ + Version: tailcfg.CurrentCapabilityVersion, + NodeKey: opts.NodeKey.Public(), + Hostinfo: hi, + Ephemeral: opts.Ephemeral, + } + if opts.AuthKey != "" { + regReq.Auth = &tailcfg.RegisterResponseAuth{ + AuthKey: opts.AuthKey, + } + } + + body, err := json.Marshal(regReq) + if err != nil { + return nil, fmt.Errorf("encoding register request: %w", err) + } + + nc, err := c.noiseClient(ctx) + if err != nil { + return nil, fmt.Errorf("establishing noise connection: %w", err) + } + + url := c.serverURL + "/machine/register" + url = strings.Replace(url, "http:", "https:", 1) + req, err := http.NewRequestWithContext(ctx, "POST", url, bytes.NewReader(body)) + if err != nil { + return nil, fmt.Errorf("creating register request: %w", err) + } + ts2021.AddLBHeader(req, opts.NodeKey.Public()) + + res, err := nc.Do(req) + if err != nil { + return nil, fmt.Errorf("register request: %w", err) + } + defer res.Body.Close() + + maxResponseSize := cmp.Or(opts.MaxResponseSize, DefaultMaxMessageSize) + + if res.StatusCode != 200 { + msg, _ := io.ReadAll(io.LimitReader(res.Body, maxResponseSize)) + return nil, fmt.Errorf("register request: http %d: %.200s", + res.StatusCode, strings.TrimSpace(string(msg))) + } + + // Read up to maxResponseSize+1 so we can distinguish "exactly at cap" from + // "over the cap" rather than relying on a truncated json parse error. + data, err := io.ReadAll(io.LimitReader(res.Body, maxResponseSize+1)) + if err != nil { + return nil, fmt.Errorf("reading register response: %w", err) + } + if int64(len(data)) > maxResponseSize { + return nil, fmt.Errorf("register response exceeds max %d", maxResponseSize) + } + var resp tailcfg.RegisterResponse + if err := json.Unmarshal(data, &resp); err != nil { + return nil, fmt.Errorf("decoding register response: %w", err) + } + if resp.Error != "" { + return nil, fmt.Errorf("register: %s", resp.Error) + } + return &resp, nil +} diff --git a/control/tsp/tsp.go b/control/tsp/tsp.go new file mode 100644 index 000000000..a75cc7d0e --- /dev/null +++ b/control/tsp/tsp.go @@ -0,0 +1,251 @@ +// Copyright (c) Tailscale Inc & contributors +// SPDX-License-Identifier: BSD-3-Clause + +// Package tsp provides a client for speaking the Tailscale protocol +// to a coordination server over Noise. +package tsp + +import ( + "bufio" + "bytes" + "cmp" + "context" + "encoding/json" + "fmt" + "io" + "net" + "net/http" + "strconv" + "sync" + + "tailscale.com/control/ts2021" + "tailscale.com/ipn" + "tailscale.com/net/tsdial" + "tailscale.com/tailcfg" + "tailscale.com/types/key" + "tailscale.com/types/logger" + "tailscale.com/version" +) + +// DefaultServerURL is the default coordination server base URL, +// used when ClientOpts.ServerURL is empty. +const DefaultServerURL = ipn.DefaultControlURL + +// ClientOpts contains options for creating a new Client. +type ClientOpts struct { + // ServerURL is the base URL of the coordination server + // (e.g. "https://controlplane.tailscale.com"). + // If empty, DefaultServerURL is used. + ServerURL string + + // MachineKey is this node's machine private key. Required. + MachineKey key.MachinePrivate + + // Logf is the log function. If nil, logger.Discard is used. + Logf logger.Logf +} + +// Client is a Tailscale protocol client that speaks to a coordination +// server over Noise. +type Client struct { + opts ClientOpts + serverURL string + logf logger.Logf + + mu sync.Mutex + nc *ts2021.Client // nil until noiseClient called + serverPub key.MachinePublic // zero until set or discovered +} + +// NewClient creates a new Client configured to talk to the coordination server +// specified in opts. It performs no I/O; the server's public key is discovered +// lazily on first use or can be set explicitly via SetControlPublicKey. +func NewClient(opts ClientOpts) (*Client, error) { + if opts.MachineKey.IsZero() { + return nil, fmt.Errorf("MachineKey is required") + } + logf := opts.Logf + if logf == nil { + logf = logger.Discard + } + return &Client{ + opts: opts, + serverURL: cmp.Or(opts.ServerURL, DefaultServerURL), + logf: logf, + }, nil +} + +// SetControlPublicKey sets the server's public key, bypassing lazy discovery. +// Any existing noise client is invalidated and will be re-created on next use. +func (c *Client) SetControlPublicKey(k key.MachinePublic) { + c.mu.Lock() + defer c.mu.Unlock() + c.serverPub = k + c.nc = nil +} + +// DiscoverServerKey fetches the server's public key from the coordination +// server and stores it for subsequent use. Any existing noise client is +// invalidated. +func (c *Client) DiscoverServerKey(ctx context.Context) (key.MachinePublic, error) { + k, err := DiscoverServerKey(ctx, c.serverURL) + if err != nil { + return key.MachinePublic{}, err + } + c.mu.Lock() + defer c.mu.Unlock() + c.serverPub = k + c.nc = nil + return k, nil +} + +// DiscoverServerKey fetches the coordination server's public key from the +// given server URL. It is a standalone function that requires no client state. +func DiscoverServerKey(ctx context.Context, serverURL string) (key.MachinePublic, error) { + serverURL = cmp.Or(serverURL, DefaultServerURL) + keysURL := serverURL + "/key?v=" + strconv.Itoa(int(tailcfg.CurrentCapabilityVersion)) + req, err := http.NewRequestWithContext(ctx, "GET", keysURL, nil) + if err != nil { + return key.MachinePublic{}, fmt.Errorf("creating key request: %w", err) + } + res, err := http.DefaultClient.Do(req) + if err != nil { + return key.MachinePublic{}, fmt.Errorf("fetching server key: %w", err) + } + defer res.Body.Close() + if res.StatusCode != 200 { + return key.MachinePublic{}, fmt.Errorf("fetching server key: %s", res.Status) + } + var keys struct { + PublicKey key.MachinePublic + } + if err := json.NewDecoder(res.Body).Decode(&keys); err != nil { + return key.MachinePublic{}, fmt.Errorf("decoding server key: %w", err) + } + return keys.PublicKey, nil +} + +// noiseClient returns the ts2021 noise client, creating it lazily if needed. +// If the server's public key is not yet known, it is discovered via HTTP. +func (c *Client) noiseClient(ctx context.Context) (*ts2021.Client, error) { + c.mu.Lock() + defer c.mu.Unlock() + + if c.nc != nil { + return c.nc, nil + } + + if c.serverPub.IsZero() { + // Discover server key without holding the lock, to avoid blocking + // other callers during the HTTP request. + c.mu.Unlock() + k, err := DiscoverServerKey(ctx, c.serverURL) + c.mu.Lock() + if err != nil { + return nil, err + } + // Re-check: another goroutine may have set it while we were unlocked. + if c.serverPub.IsZero() { + c.serverPub = k + } + // If nc was created by another goroutine while unlocked, use it. + if c.nc != nil { + return c.nc, nil + } + } + + nc, err := ts2021.NewClient(ts2021.ClientOpts{ + ServerURL: c.serverURL, + PrivKey: c.opts.MachineKey, + ServerPubKey: c.serverPub, + Dialer: tsdial.NewFromFuncForDebug(c.logf, (&net.Dialer{}).DialContext), + Logf: c.logf, + }) + if err != nil { + return nil, fmt.Errorf("creating noise client: %w", err) + } + c.nc = nc + return nc, nil +} + +// AnswerC2NPing handles a c2n PingRequest from the control plane by parsing the +// embedded HTTP request in the payload, routing it locally, and POSTing the HTTP +// response back to pr.URL using doNoiseRequest. The POST is done in a new +// goroutine so this method does not block. +// +// It reports whether the ping was handled. Unhandled pings (nil pr, non-c2n +// types, or unrecognized c2n paths) return false. +func (c *Client) AnswerC2NPing(ctx context.Context, pr *tailcfg.PingRequest, doNoiseRequest func(*http.Request) (*http.Response, error)) (handled bool) { + if pr == nil || pr.Types != "c2n" { + return false + } + + // Parse the HTTP request from the payload. + httpReq, err := http.ReadRequest(bufio.NewReader(bytes.NewReader(pr.Payload))) + if err != nil { + c.logf("parsing c2n ping payload: %v", err) + return false + } + + // Route the request locally. + var httpResp *http.Response + switch httpReq.URL.Path { + case "/echo": + body, _ := io.ReadAll(httpReq.Body) + httpResp = &http.Response{ + StatusCode: 200, + Status: "200 OK", + Proto: "HTTP/1.1", + ProtoMajor: 1, + ProtoMinor: 1, + Header: http.Header{}, + Body: io.NopCloser(bytes.NewReader(body)), + ContentLength: int64(len(body)), + } + default: + c.logf("ignoring c2n ping request for unhandled path %q", httpReq.URL.Path) + return false + } + + // Serialize the HTTP response. + var buf bytes.Buffer + if err := httpResp.Write(&buf); err != nil { + c.logf("serializing c2n ping response: %v", err) + return false + } + + // Send the response back to the control plane over the Noise channel. + go func() { + req, err := http.NewRequestWithContext(ctx, "POST", pr.URL, &buf) + if err != nil { + c.logf("creating c2n ping reply request: %v", err) + return + } + resp, err := doNoiseRequest(req) + if err != nil { + c.logf("sending c2n ping reply: %v", err) + return + } + resp.Body.Close() + }() + return true +} + +// Close closes the client and releases resources. +func (c *Client) Close() error { + c.mu.Lock() + nc := c.nc + c.nc = nil + c.mu.Unlock() + if nc != nil { + nc.Close() + } + return nil +} + +func defaultHostinfo() *tailcfg.Hostinfo { + return &tailcfg.Hostinfo{ + OS: version.OS(), + IPNVersion: version.Long(), + } +}