tka: truncate long rotation signature chains

When a rotation signature chain reaches a certain size, remove the
oldest rotation signature from the chain before wrapping it in a new
rotation signature.

Since all previous rotation signatures are signed by the same wrapping
pubkey (node's own tailnet lock key), the node can re-construct the
chain, re-signing previous rotation signatures. This will satisfy the
existing certificate validation logic.

Updates #13185

Signed-off-by: Anton Tolchanov <anton@tailscale.com>
This commit is contained in:
Anton Tolchanov
2024-08-19 19:32:14 +01:00
committed by Anton Tolchanov
parent bcc47d91ca
commit fd6686d81a
4 changed files with 221 additions and 11 deletions
+134
View File
@@ -9,7 +9,9 @@ import (
"testing"
"github.com/google/go-cmp/cmp"
"github.com/google/go-cmp/cmp/cmpopts"
"tailscale.com/types/key"
"tailscale.com/types/tkatype"
)
func TestSigDirect(t *testing.T) {
@@ -74,6 +76,9 @@ func TestSigNested(t *testing.T) {
if err := nestedSig.verifySignature(oldNode.Public(), k); err != nil {
t.Fatalf("verifySignature(oldNode) failed: %v", err)
}
if l := sigChainLength(nestedSig); l != 1 {
t.Errorf("nestedSig chain length = %v, want 1", l)
}
// The signature authorizing the rotation, signed by the
// rotation key & embedding the original signature.
@@ -88,6 +93,9 @@ func TestSigNested(t *testing.T) {
if err := sig.verifySignature(node.Public(), k); err != nil {
t.Fatalf("verifySignature(node) failed: %v", err)
}
if l := sigChainLength(sig); l != 2 {
t.Errorf("sig chain length = %v, want 2", l)
}
// Test verification fails if the wrong verification key is provided
kBad := Key{Kind: Key25519, Public: []byte{1, 2, 3, 4}, Votes: 2}
@@ -497,3 +505,129 @@ func TestDecodeWrappedAuthkey(t *testing.T) {
}
}
func TestResignNKS(t *testing.T) {
// Tailnet lock keypair of a signing node.
authPub, authPriv := testingKey25519(t, 1)
authKey := Key{Kind: Key25519, Public: authPub, Votes: 2}
// Node's own tailnet lock key used to sign rotation signatures.
tlPriv := key.NewNLPrivate()
// The original (oldest) node key, signed by a signing node.
origNode := key.NewNode()
origPub, _ := origNode.Public().MarshalBinary()
// The original signature for the old node key, signed by
// the network-lock key.
directSig := NodeKeySignature{
SigKind: SigDirect,
KeyID: authKey.MustID(),
Pubkey: origPub,
WrappingPubkey: tlPriv.Public().Verifier(),
}
sigHash := directSig.SigHash()
directSig.Signature = ed25519.Sign(authPriv, sigHash[:])
if err := directSig.verifySignature(origNode.Public(), authKey); err != nil {
t.Fatalf("verifySignature(origNode) failed: %v", err)
}
// Generate a bunch of node keys to be used by tests.
var nodeKeys []key.NodePublic
for range 20 {
n := key.NewNode()
nodeKeys = append(nodeKeys, n.Public())
}
// mkSig creates a signature chain starting with a direct signature
// with rotation signatures matching provided keys (from the nodeKeys slice).
mkSig := func(prevKeyIDs ...int) tkatype.MarshaledSignature {
sig := &directSig
for _, i := range prevKeyIDs {
pk, _ := nodeKeys[i].MarshalBinary()
sig = &NodeKeySignature{
SigKind: SigRotation,
Pubkey: pk,
Nested: sig,
}
var err error
sig.Signature, err = tlPriv.SignNKS(sig.SigHash())
if err != nil {
t.Error(err)
}
}
return sig.Serialize()
}
tests := []struct {
name string
oldSig tkatype.MarshaledSignature
wantPrevNodeKeys []key.NodePublic
}{
{
name: "first-rotation",
oldSig: directSig.Serialize(),
wantPrevNodeKeys: []key.NodePublic{origNode.Public()},
},
{
name: "second-rotation",
oldSig: mkSig(0),
wantPrevNodeKeys: []key.NodePublic{nodeKeys[0], origNode.Public()},
},
{
name: "truncate-chain",
oldSig: mkSig(0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14),
wantPrevNodeKeys: []key.NodePublic{
nodeKeys[14],
nodeKeys[13],
nodeKeys[12],
nodeKeys[11],
nodeKeys[10],
nodeKeys[9],
nodeKeys[8],
nodeKeys[7],
nodeKeys[6],
nodeKeys[5],
nodeKeys[4],
nodeKeys[3],
nodeKeys[2],
nodeKeys[1],
origNode.Public(),
},
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
newNode := key.NewNode()
got, err := ResignNKS(tlPriv, newNode.Public(), tt.oldSig)
if err != nil {
t.Fatalf("ResignNKS() error = %v", err)
}
var gotSig NodeKeySignature
if err := gotSig.Unserialize(got); err != nil {
t.Fatalf("Unserialize() failed: %v", err)
}
if err := gotSig.verifySignature(newNode.Public(), authKey); err != nil {
t.Errorf("verifySignature(newNode) error: %v", err)
}
rd, err := gotSig.rotationDetails()
if err != nil {
t.Fatalf("rotationDetails() error = %v", err)
}
if sigChainLength(gotSig) != len(tt.wantPrevNodeKeys)+1 {
t.Errorf("sigChainLength() = %v, want %v", sigChainLength(gotSig), len(tt.wantPrevNodeKeys)+1)
}
if diff := cmp.Diff(tt.wantPrevNodeKeys, rd.PrevNodeKeys, cmpopts.EquateComparable(key.NodePublic{})); diff != "" {
t.Errorf("PrevNodeKeys mismatch (-want +got):\n%s", diff)
}
})
}
}
func sigChainLength(s NodeKeySignature) int {
if s.Nested != nil {
return 1 + sigChainLength(*s.Nested)
}
return 1
}