tstest/integration: add integration test for Tailnet Lock

This patch adds an integration test for Tailnet Lock, checking that a node can't
talk to peers in the tailnet until it is signed.

This patch also introduces a new package `tstest/tkatest`, which has some helpers
for constructing a mock control server that responds to TKA requests. This allows
us to reduce boilerplate in the IPN tests.

Updates tailscale/corp#33599

Signed-off-by: Alex Chan <alexc@tailscale.com>
This commit is contained in:
Alex Chan
2025-11-19 09:41:43 +00:00
committed by Alex Chan
parent 824027305a
commit b7658a4ad2
7 changed files with 574 additions and 287 deletions
+41 -1
View File
@@ -918,7 +918,7 @@ func (n *TestNode) Ping(otherNode *TestNode) error {
t := n.env.t
ip := otherNode.AwaitIP4().String()
t.Logf("Running ping %v (from %v)...", ip, n.AwaitIP4())
return n.Tailscale("ping", ip).Run()
return n.Tailscale("ping", "--timeout=1s", ip).Run()
}
// AwaitListening waits for the tailscaled to be serving local clients
@@ -1077,6 +1077,46 @@ func (n *TestNode) MustStatus() *ipnstate.Status {
return st
}
// PublicKey returns the hex-encoded public key of this node,
// e.g. `nodekey:123456abc`
func (n *TestNode) PublicKey() string {
	t := n.env.t
	t.Helper()
	// Ask the CLI for the full status and pick out just the Self key.
	out, err := n.Tailscale("status", "--json").CombinedOutput()
	if err != nil {
		t.Fatalf("running `tailscale status`: %v, %s", err, out)
	}
	var status struct {
		Self struct{ PublicKey string }
	}
	if err := json.Unmarshal(out, &status); err != nil {
		t.Fatalf("decoding `tailscale status` JSON: %v\njson:\n%s", err, out)
	}
	return status.Self.PublicKey
}
// NLPublicKey returns the hex-encoded network lock public key of
// this node, e.g. `tlpub:123456abc`
func (n *TestNode) NLPublicKey() string {
	t := n.env.t
	t.Helper()
	// Ask the CLI for the Tailnet Lock status and pick out the public key.
	out, err := n.Tailscale("lock", "status", "--json").CombinedOutput()
	if err != nil {
		t.Fatalf("running `tailscale lock status`: %v, %s", err, out)
	}
	var status struct {
		PublicKey string `json:"PublicKey"`
	}
	if err := json.Unmarshal(out, &status); err != nil {
		t.Fatalf("decoding `tailscale lock status` JSON: %v\njson:\n%s", err, out)
	}
	return status.PublicKey
}
// trafficTrap is an HTTP proxy handler to note whether any
// HTTP traffic tries to leave localhost from tailscaled. We don't
// expect any, so any request triggers a failure.
+74 -1
View File
@@ -2253,7 +2253,7 @@ func TestC2NDebugNetmap(t *testing.T) {
}
}
func TestNetworkLock(t *testing.T) {
func TestTailnetLock(t *testing.T) {
// If you run `tailscale lock log` on a node where Tailnet Lock isn't
// enabled, you get an error explaining that.
@@ -2291,6 +2291,79 @@ func TestNetworkLock(t *testing.T) {
t.Fatalf("stderr: want %q, got %q", wantErr, errBuf.String())
}
})
// If you create a tailnet with two signed nodes and one unsigned,
// the signed nodes can talk to each other but the unsigned node cannot
// talk to anybody.
t.Run("node-connectivity", func(t *testing.T) {
	tstest.Shard(t)
	t.Parallel()
	env := NewTestEnv(t)
	// Advertise the Tailnet Lock capability to every node registered
	// with this test control server.
	env.Control.DefaultNodeCapabilities = &tailcfg.NodeCapMap{
		tailcfg.CapabilityTailnetLock: []tailcfg.RawMessage{},
	}
	// Start two nodes which will be our signing nodes.
	signing1 := NewTestNode(t, env)
	signing2 := NewTestNode(t, env)
	nodes := []*TestNode{signing1, signing2}
	for _, n := range nodes {
		d := n.StartDaemon()
		defer d.MustCleanShutdown(t)
		n.MustUp()
		n.AwaitRunning()
	}
	// Initiate Tailnet Lock with the two signing nodes.
	// --gen-disablements 10 pre-generates disablement secrets;
	// --confirm skips the interactive confirmation prompt.
	initCmd := signing1.Tailscale("lock", "init",
		"--gen-disablements", "10",
		"--confirm",
		signing1.NLPublicKey(), signing2.NLPublicKey(),
	)
	out, err := initCmd.CombinedOutput()
	if err != nil {
		t.Fatalf("init command failed: %q\noutput=%v", err, string(out))
	}
	// Check that the two signing nodes can ping each other
	if err := signing1.Ping(signing2); err != nil {
		t.Fatalf("ping signing1 -> signing2: %v", err)
	}
	if err := signing2.Ping(signing1); err != nil {
		t.Fatalf("ping signing2 -> signing1: %v", err)
	}
	// Create and start a third node
	node3 := NewTestNode(t, env)
	d3 := node3.StartDaemon()
	defer d3.MustCleanShutdown(t)
	node3.MustUp()
	node3.AwaitRunning()
	// node3 is not yet signed, so pings in both directions must fail.
	if err := signing1.Ping(node3); err == nil {
		t.Fatal("ping signing1 -> node3: expected err, but succeeded")
	}
	if err := node3.Ping(signing1); err == nil {
		t.Fatal("ping node3 -> signing1: expected err, but succeeded")
	}
	// Sign node3, and check the nodes can now talk to each other
	signCmd := signing1.Tailscale("lock", "sign", node3.PublicKey())
	out, err = signCmd.CombinedOutput()
	if err != nil {
		t.Fatalf("sign command failed: %q\noutput = %v", err, string(out))
	}
	if err := signing1.Ping(node3); err != nil {
		t.Fatalf("ping signing1 -> node3: expected success, got err: %v", err)
	}
	if err := node3.Ping(signing1); err != nil {
		t.Fatalf("ping node3 -> signing1: expected success, got err: %v", err)
	}
})
}
func TestNodeWithBadStateFile(t *testing.T) {
+149 -1
View File
@@ -33,6 +33,8 @@ import (
"tailscale.com/net/tsaddr"
"tailscale.com/syncs"
"tailscale.com/tailcfg"
"tailscale.com/tka"
"tailscale.com/tstest/tkatest"
"tailscale.com/types/key"
"tailscale.com/types/logger"
"tailscale.com/types/opt"
@@ -123,6 +125,10 @@ type Server struct {
nodeKeyAuthed set.Set[key.NodePublic]
msgToSend map[key.NodePublic]any // value is *tailcfg.PingRequest or entire *tailcfg.MapResponse
allExpired bool // All nodes will be told their node key is expired.
// tkaStorage records the Tailnet Lock state, if any.
// If nil, Tailnet Lock is not enabled in the Tailnet.
tkaStorage tka.CompactableChonk
}
// BaseURL returns the server's base URL, without trailing slash.
@@ -329,6 +335,7 @@ func (s *Server) initMux() {
w.WriteHeader(http.StatusNoContent)
})
s.mux.HandleFunc("/key", s.serveKey)
s.mux.HandleFunc("/machine/tka/", s.serveTKA)
s.mux.HandleFunc("/machine/", s.serveMachine)
s.mux.HandleFunc("/ts2021", s.serveNoiseUpgrade)
s.mux.HandleFunc("/c2n/", s.serveC2N)
@@ -439,7 +446,7 @@ func (s *Server) serveKey(w http.ResponseWriter, r *http.Request) {
func (s *Server) serveMachine(w http.ResponseWriter, r *http.Request) {
if r.Method != "POST" {
http.Error(w, "POST required", 400)
http.Error(w, "POST required for serveMachine", 400)
return
}
ctx := r.Context()
@@ -861,6 +868,132 @@ func (s *Server) serveRegister(w http.ResponseWriter, r *http.Request, mkey key.
w.Write(res)
}
// serveTKA routes Tailnet Key Authority (TKA) requests under
// /machine/tka/ to their individual handlers.
func (s *Server) serveTKA(w http.ResponseWriter, r *http.Request) {
	if r.Method != "GET" {
		http.Error(w, "GET required for serveTKA", 400)
		return
	}
	handlers := map[string]http.HandlerFunc{
		"/machine/tka/init/begin":  s.serveTKAInitBegin,
		"/machine/tka/init/finish": s.serveTKAInitFinish,
		"/machine/tka/bootstrap":   s.serveTKABootstrap,
		"/machine/tka/sync/offer":  s.serveTKASyncOffer,
		"/machine/tka/sign":        s.serveTKASign,
	}
	h, ok := handlers[r.URL.Path]
	if !ok {
		h = s.serveUnhandled
	}
	h(w, r)
}
// serveTKAInitBegin handles /machine/tka/init/begin: it accepts the
// genesis AUM from the node running `tailscale lock init` and seeds the
// server's TKA storage with it.
func (s *Server) serveTKAInitBegin(w http.ResponseWriter, r *http.Request) {
	s.mu.Lock()
	defer s.mu.Unlock()
	nodes := maps.Values(s.nodes)
	genesisAUM, err := tkatest.HandleTKAInitBegin(w, r, nodes)
	if err != nil {
		// Panic from a fresh goroutine so net/http's handler panic
		// recovery can't swallow it and the test process fails loudly.
		go panic(fmt.Sprintf("HandleTKAInitBegin: %v", err))
		// Must return: genesisAUM is nil on error, and the panic above
		// fires asynchronously, so falling through would nil-deref below.
		return
	}
	s.tkaStorage = tka.ChonkMem()
	s.tkaStorage.CommitVerifiedAUMs([]tka.AUM{*genesisAUM})
}
// serveTKAInitFinish handles /machine/tka/init/finish: it records the
// node-key signatures produced during `tailscale lock init` on the
// corresponding nodes.
func (s *Server) serveTKAInitFinish(w http.ResponseWriter, r *http.Request) {
	signatures, err := tkatest.HandleTKAInitFinish(w, r)
	if err != nil {
		// Panic from a fresh goroutine so net/http's handler panic
		// recovery can't swallow it and the test process fails loudly.
		go panic(fmt.Sprintf("HandleTKAInitFinish: %v", err))
		// Must return: signatures is not valid on error, and the panic
		// above fires asynchronously.
		return
	}
	s.mu.Lock()
	defer s.mu.Unlock()
	// Apply the signatures to each of the nodes. Because s.nodes is keyed
	// by public key instead of node ID, we have to do this inefficiently.
	//
	// We only have small tailnets in the integration tests, so this isn't
	// much of an issue.
	for nodeID, sig := range signatures {
		for _, n := range s.nodes {
			if n.ID == nodeID {
				n.KeySignature = sig
			}
		}
	}
}
// serveTKABootstrap handles /machine/tka/bootstrap: it returns the
// tailnet's genesis AUM to a node bootstrapping its TKA state.
func (s *Server) serveTKABootstrap(w http.ResponseWriter, r *http.Request) {
	s.mu.Lock()
	defer s.mu.Unlock()
	if s.tkaStorage == nil {
		http.Error(w, "no TKA state when calling serveTKABootstrap", 400)
		return
	}
	hashes, err := s.tkaStorage.AllAUMs()
	if err != nil {
		http.Error(w, "unable to retrieve all AUMs from TKA state", 500)
		return
	}
	// The genesis AUM is the one with no parent.
	var genesis *tka.AUM
	for _, hash := range hashes {
		candidate := must.Get(s.tkaStorage.AUM(hash))
		if _, hasParent := candidate.Parent(); !hasParent {
			genesis = &candidate
			break
		}
	}
	if genesis == nil {
		http.Error(w, "unable to find genesis AUM in TKA state", 500)
		return
	}
	resp := tailcfg.TKABootstrapResponse{
		GenesisAUM: genesis.Serialize(),
	}
	if _, err := tkatest.HandleTKABootstrap(w, r, resp); err != nil {
		// Panic from a fresh goroutine so net/http's handler panic
		// recovery can't swallow it.
		go panic(fmt.Sprintf("HandleTKABootstrap: %v", err))
	}
}
// serveTKASyncOffer handles /machine/tka/sync/offer: it exchanges sync
// offers with a node and sends it any AUMs it is missing.
func (s *Server) serveTKASyncOffer(w http.ResponseWriter, r *http.Request) {
	s.mu.Lock()
	defer s.mu.Unlock()
	authority, err := tka.Open(s.tkaStorage)
	if err != nil {
		// Panic from a fresh goroutine so net/http's handler panic
		// recovery can't swallow it.
		go panic(fmt.Sprintf("serveTKASyncOffer: tka.Open: %v", err))
		// Must return: authority is nil on error, and the panic above
		// fires asynchronously, so falling through would nil-deref.
		return
	}
	if err := tkatest.HandleTKASyncOffer(w, r, authority, s.tkaStorage); err != nil {
		go panic(fmt.Sprintf("HandleTKASyncOffer: %v", err))
	}
}
// serveTKASign handles /machine/tka/sign: it verifies a node-key
// signature, records it on the signed node, and wakes up map pollers so
// peers learn about the newly-signed node.
func (s *Server) serveTKASign(w http.ResponseWriter, r *http.Request) {
	s.mu.Lock()
	defer s.mu.Unlock()
	authority, err := tka.Open(s.tkaStorage)
	if err != nil {
		// Panic from a fresh goroutine so net/http's handler panic
		// recovery can't swallow it.
		go panic(fmt.Sprintf("serveTKASign: tka.Open: %v", err))
		// Must return: authority is nil on error.
		return
	}
	sig, keyBeingSigned, err := tkatest.HandleTKASign(w, r, authority)
	if err != nil {
		go panic(fmt.Sprintf("HandleTKASign: %v", err))
		// Must return: sig and keyBeingSigned are nil on error.
		return
	}
	// Guard the map lookup: an unknown key would otherwise nil-deref.
	node, ok := s.nodes[*keyBeingSigned]
	if !ok {
		go panic(fmt.Sprintf("serveTKASign: no node with key %v", *keyBeingSigned))
		return
	}
	node.KeySignature = *sig
	s.updateLocked("TKASign", s.nodeIDsLocked(0))
}
// updateType indicates why a long-polling map request is being woken
// up for an update.
type updateType int
@@ -1197,6 +1330,21 @@ func (s *Server) MapResponse(req *tailcfg.MapRequest) (res *tailcfg.MapResponse,
v6Prefix,
}
// If the server is tracking TKA state, and there's a single TKA head,
// add it to the MapResponse.
if s.tkaStorage != nil {
heads, err := s.tkaStorage.Heads()
if err != nil {
log.Printf("unable to get TKA heads: %v", err)
} else if len(heads) != 1 {
log.Printf("unable to get single TKA head, got %v", heads)
} else {
res.TKAInfo = &tailcfg.TKAInfo{
Head: heads[0].Hash().String(),
}
}
}
s.mu.Lock()
defer s.mu.Unlock()
res.Node.PrimaryRoutes = s.nodeSubnetRoutes[nk]
+220
View File
@@ -0,0 +1,220 @@
// Copyright (c) Tailscale Inc & AUTHORS
// SPDX-License-Identifier: BSD-3-Clause
// tkatest has functions for creating a mock control server that responds
// to TKA endpoints.
package tkatest
import (
"encoding/json"
"errors"
"fmt"
"iter"
"log"
"net/http"
"tailscale.com/tailcfg"
"tailscale.com/tka"
"tailscale.com/types/key"
"tailscale.com/types/tkatype"
)
func serverError(w http.ResponseWriter, format string, a ...any) error {
err := fmt.Sprintf(format, a...)
http.Error(w, err, 500)
log.Printf("returning HTTP 500 error: %v", err)
return errors.New(err)
}
func userError(w http.ResponseWriter, format string, a ...any) error {
err := fmt.Sprintf(format, a...)
http.Error(w, err, 400)
return errors.New(err)
}
// HandleTKAInitBegin handles a request to /machine/tka/init/begin.
//
// If the request contains a valid genesis AUM, it sends a response to the
// client, and returns the AUM to the caller.
func HandleTKAInitBegin(w http.ResponseWriter, r *http.Request, nodes iter.Seq[*tailcfg.Node]) (*tka.AUM, error) {
	// Decode into a value, not a pointer: decoding a JSON "null" body
	// into **T leaves the pointer nil and would panic on the field
	// access below.
	var req tailcfg.TKAInitBeginRequest
	if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
		return nil, userError(w, "Decode: %v", err)
	}
	var aum tka.AUM
	if err := aum.Unserialize(req.GenesisAUM); err != nil {
		return nil, userError(w, "invalid genesis AUM: %v", err)
	}
	// Ask the client to produce a signature for every node in the tailnet.
	beginResp := tailcfg.TKAInitBeginResponse{}
	for n := range nodes {
		beginResp.NeedSignatures = append(
			beginResp.NeedSignatures,
			tailcfg.TKASignInfo{
				NodeID:     n.ID,
				NodePublic: n.Key,
			},
		)
	}
	w.WriteHeader(200)
	if err := json.NewEncoder(w).Encode(beginResp); err != nil {
		return nil, serverError(w, "Encode: %v", err)
	}
	return &aum, nil
}
// HandleTKAInitFinish handles a request to /machine/tka/init/finish.
//
// It sends a response to the client, and gives the caller a list of node
// signatures to apply.
//
// This method assumes that the node signatures are valid, and does not
// verify them with the supplied public key.
func HandleTKAInitFinish(w http.ResponseWriter, r *http.Request) (map[tailcfg.NodeID]tkatype.MarshaledSignature, error) {
	// Decode into a value, not a pointer: decoding a JSON "null" body
	// into **T leaves the pointer nil and would panic on the field
	// access below.
	var req tailcfg.TKAInitFinishRequest
	if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
		return nil, userError(w, "Decode: %v", err)
	}
	w.WriteHeader(200)
	w.Write([]byte("{}"))
	return req.Signatures, nil
}
// HandleTKABootstrap handles a request to /tka/bootstrap.
//
// If the request is valid, it sends a response to the client, and returns
// the parsed request to the caller.
func HandleTKABootstrap(w http.ResponseWriter, r *http.Request, resp tailcfg.TKABootstrapResponse) (*tailcfg.TKABootstrapRequest, error) {
	var req tailcfg.TKABootstrapRequest
	if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
		return nil, userError(w, "Decode: %v", err)
	}
	// Refuse clients speaking a different protocol version.
	if want := tailcfg.CurrentCapabilityVersion; req.Version != want {
		return nil, userError(w, "bootstrap CapVer = %v, want %v", req.Version, want)
	}
	w.WriteHeader(200)
	if err := json.NewEncoder(w).Encode(resp); err != nil {
		return nil, serverError(w, "Encode: %v", err)
	}
	return &req, nil
}
// HandleTKASyncOffer handles a request to /machine/tka/sync/offer.
//
// It compares the node's sync offer against the authority's state and
// responds with the authority's own offer plus any AUMs the node is
// missing.
func HandleTKASyncOffer(w http.ResponseWriter, r *http.Request, authority *tka.Authority, chonk tka.Chonk) error {
	body := new(tailcfg.TKASyncOfferRequest)
	if err := json.NewDecoder(r.Body).Decode(body); err != nil {
		return userError(w, "Decode: %v", err)
	}
	log.Printf("got sync offer:\n%+v", body)
	// Reconstruct the node's offer from its wire encoding.
	nodeOffer, err := tka.ToSyncOffer(body.Head, body.Ancestors)
	if err != nil {
		return userError(w, "ToSyncOffer: %v", err)
	}
	controlOffer, err := authority.SyncOffer(chonk)
	if err != nil {
		return serverError(w, "authority.SyncOffer: %v", err)
	}
	// Compute the AUMs the node lacks, based on its offer.
	sendAUMs, err := authority.MissingAUMs(chonk, nodeOffer)
	if err != nil {
		return serverError(w, "authority.MissingAUMs: %v", err)
	}
	// Convert control's offer back to its wire encoding for the response.
	head, ancestors, err := tka.FromSyncOffer(controlOffer)
	if err != nil {
		return serverError(w, "FromSyncOffer: %v", err)
	}
	resp := tailcfg.TKASyncOfferResponse{
		Head:        head,
		Ancestors:   ancestors,
		MissingAUMs: make([]tkatype.MarshaledAUM, len(sendAUMs)),
	}
	for i, a := range sendAUMs {
		resp.MissingAUMs[i] = a.Serialize()
	}
	log.Printf("responding to sync offer with:\n%+v", resp)
	w.WriteHeader(200)
	if err := json.NewEncoder(w).Encode(resp); err != nil {
		return serverError(w, "Encode: %v", err)
	}
	return nil
}
// HandleTKASign handles a request to /machine/tka/sign.
//
// If the signature request is valid, it sends a response to the client, and
// gives the caller the signature and public key of the node being signed.
func HandleTKASign(w http.ResponseWriter, r *http.Request, authority *tka.Authority) (*tkatype.MarshaledSignature, *key.NodePublic, error) {
	req := new(tailcfg.TKASubmitSignatureRequest)
	if err := json.NewDecoder(r.Body).Decode(req); err != nil {
		return nil, nil, userError(w, "Decode: %v", err)
	}
	// Refuse clients speaking a different protocol version.
	if req.Version != tailcfg.CurrentCapabilityVersion {
		return nil, nil, userError(w, "sign CapVer = %v, want %v", req.Version, tailcfg.CurrentCapabilityVersion)
	}
	// Parse the signature so we can extract the node key it covers.
	var sig tka.NodeKeySignature
	if err := sig.Unserialize(req.Signature); err != nil {
		return nil, nil, userError(w, "malformed signature: %v", err)
	}
	var keyBeingSigned key.NodePublic
	if err := keyBeingSigned.UnmarshalBinary(sig.Pubkey); err != nil {
		return nil, nil, userError(w, "malformed signature pubkey: %v", err)
	}
	// Verify the signature against the authority before accepting it.
	if err := authority.NodeKeyAuthorized(keyBeingSigned, req.Signature); err != nil {
		return nil, nil, userError(w, "signature does not verify: %v", err)
	}
	w.WriteHeader(200)
	if err := json.NewEncoder(w).Encode(tailcfg.TKASubmitSignatureResponse{}); err != nil {
		return nil, nil, serverError(w, "Encode: %v", err)
	}
	return &req.Signature, &keyBeingSigned, nil
}
// HandleTKASyncSend handles a request to /machine/tka/send.
//
// If the request is valid, it adds the new AUMs to the authority, and sends
// a response to the client with the new head.
func HandleTKASyncSend(w http.ResponseWriter, r *http.Request, authority *tka.Authority, chonk tka.Chonk) error {
	body := new(tailcfg.TKASyncSendRequest)
	if err := json.NewDecoder(r.Body).Decode(body); err != nil {
		return userError(w, "Decode: %v", err)
	}
	log.Printf("got sync send:\n%+v", body)
	// NOTE(review): remoteHead is parsed for validation only and not
	// otherwise used below — confirm that's intended.
	var remoteHead tka.AUMHash
	if err := remoteHead.UnmarshalText([]byte(body.Head)); err != nil {
		return userError(w, "head unmarshal: %v", err)
	}
	// Decode each AUM the node sent from its wire encoding.
	toApply := make([]tka.AUM, len(body.MissingAUMs))
	for i, a := range body.MissingAUMs {
		if err := toApply[i].Unserialize(a); err != nil {
			return userError(w, "decoding missingAUM[%d]: %v", i, err)
		}
	}
	if len(toApply) > 0 {
		if err := authority.Inform(chonk, toApply); err != nil {
			return serverError(w, "control.Inform(%+v) failed: %v", toApply, err)
		}
	}
	// Respond with the authority's new head so the node can confirm
	// both sides converged.
	head, err := authority.Head().MarshalText()
	if err != nil {
		return serverError(w, "head marshal: %v", err)
	}
	resp := tailcfg.TKASyncSendResponse{
		Head: string(head),
	}
	w.WriteHeader(200)
	if err := json.NewEncoder(w).Encode(resp); err != nil {
		return serverError(w, "Encode: %v", err)
	}
	return nil
}