tka: add attack-scenario unit tests, defensive checks, resolve TODOs

Signed-off-by: Tom DNetto <tom@tailscale.com>
This commit is contained in:
Tom DNetto
2022-08-03 11:00:24 -07:00
committed by Tom
parent 4001d0bf25
commit c13fab2a67
3 changed files with 189 additions and 57 deletions
+58
View File
@@ -147,6 +147,9 @@ func (s State) applyVerifiedAUM(update AUM) (State, error) {
return update.State.cloneForUpdate(&update), nil
case AUMAddKey:
if update.Key == nil {
return State{}, errors.New("no key to add provided")
}
if _, err := s.GetKey(update.Key.ID()); err == nil {
return State{}, errors.New("key already exists")
}
@@ -202,3 +205,58 @@ func (s State) applyVerifiedAUM(update AUM) (State, error) {
return State{}, fmt.Errorf("unhandled message: %v", update.MessageKind)
}
}
// Upper bound on checkpoint elements, chosen arbitrarily. Intended to
// cap out insanely large AUMs.
const (
maxDisablementSecrets = 32
maxKeys = 512
)
// staticValidateCheckpoint validates that the state is well-formed for
// inclusion in a checkpoint AUM.
func (s *State) staticValidateCheckpoint() error {
if s.LastAUMHash != nil {
return errors.New("cannot specify a parent AUM")
}
if len(s.DisablementSecrets) == 0 {
return errors.New("at least one disablement secret required")
}
if numDS := len(s.DisablementSecrets); numDS > maxDisablementSecrets {
return fmt.Errorf("too many disablement secrets (%d, max %d)", numDS, maxDisablementSecrets)
}
for i, ds := range s.DisablementSecrets {
if len(ds) != disablementLength {
return fmt.Errorf("disablement[%d]: invalid length (got %d, want %d)", i, len(ds), disablementLength)
}
for j, ds2 := range s.DisablementSecrets {
if i == j {
continue
}
if bytes.Equal(ds, ds2) {
return fmt.Errorf("disablement[%d]: duplicates disablement[%d]", i, j)
}
}
}
if len(s.Keys) == 0 {
return errors.New("at least one key is required")
}
if numKeys := len(s.Keys); numKeys > maxKeys {
return fmt.Errorf("too many keys (%d, max %d)", numKeys, maxKeys)
}
for i, k := range s.Keys {
if err := k.StaticValidate(); err != nil {
return fmt.Errorf("key[%d]: %v", i, err)
}
for j, k2 := range s.Keys {
if i == j {
continue
}
if bytes.Equal(k.ID(), k2.ID()) {
return fmt.Errorf("key[%d]: duplicates key[%d]", i, j)
}
}
}
return nil
}