If multiple certificates match when selecting a certificate, use the one issued the most recently (as determined by the NotBefore timestamp). This also adds some tests for the function that performs that comparison. Updates tailscale/coral#6 Signed-off-by: Adrian Dewhurst <adrian@tailscale.com>main
parent
a80cef0c13
commit
adda2d2a51
@ -0,0 +1,238 @@ |
||||
// Copyright (c) 2022 Tailscale Inc & AUTHORS All rights reserved.
|
||||
// Use of this source code is governed by a BSD-style
|
||||
// license that can be found in the LICENSE file.
|
||||
|
||||
//go:build windows && cgo
|
||||
// +build windows,cgo
|
||||
|
||||
package controlclient |
||||
|
||||
import ( |
||||
"crypto" |
||||
"crypto/x509" |
||||
"crypto/x509/pkix" |
||||
"errors" |
||||
"reflect" |
||||
"testing" |
||||
"time" |
||||
|
||||
"github.com/tailscale/certstore" |
||||
) |
||||
|
||||
const ( |
||||
testRootCommonName = "testroot" |
||||
testRootSubject = "CN=testroot" |
||||
) |
||||
|
||||
type testIdentity struct { |
||||
chain []*x509.Certificate |
||||
} |
||||
|
||||
func makeChain(rootCommonName string, notBefore, notAfter time.Time) []*x509.Certificate { |
||||
return []*x509.Certificate{ |
||||
{ |
||||
NotBefore: notBefore, |
||||
NotAfter: notAfter, |
||||
PublicKeyAlgorithm: x509.RSA, |
||||
}, |
||||
{ |
||||
Subject: pkix.Name{ |
||||
CommonName: rootCommonName, |
||||
}, |
||||
PublicKeyAlgorithm: x509.RSA, |
||||
}, |
||||
} |
||||
} |
||||
|
||||
func (t *testIdentity) Certificate() (*x509.Certificate, error) { |
||||
return t.chain[0], nil |
||||
} |
||||
|
||||
func (t *testIdentity) CertificateChain() ([]*x509.Certificate, error) { |
||||
return t.chain, nil |
||||
} |
||||
|
||||
func (t *testIdentity) Signer() (crypto.Signer, error) { |
||||
return nil, errors.New("not implemented") |
||||
} |
||||
|
||||
func (t *testIdentity) Delete() error { |
||||
return errors.New("not implemented") |
||||
} |
||||
|
||||
func (t *testIdentity) Close() {} |
||||
|
||||
func TestSelectIdentityFromSlice(t *testing.T) { |
||||
var times []time.Time |
||||
for _, ts := range []string{ |
||||
"2000-01-01T00:00:00Z", |
||||
"2001-01-01T00:00:00Z", |
||||
"2002-01-01T00:00:00Z", |
||||
"2003-01-01T00:00:00Z", |
||||
} { |
||||
tm, err := time.Parse(time.RFC3339, ts) |
||||
if err != nil { |
||||
t.Fatal(err) |
||||
} |
||||
times = append(times, tm) |
||||
} |
||||
|
||||
tests := []struct { |
||||
name string |
||||
subject string |
||||
ids []certstore.Identity |
||||
now time.Time |
||||
// wantIndex is an index into ids, or -1 for nil.
|
||||
wantIndex int |
||||
}{ |
||||
{ |
||||
name: "single unexpired identity", |
||||
subject: testRootSubject, |
||||
ids: []certstore.Identity{ |
||||
&testIdentity{ |
||||
chain: makeChain(testRootCommonName, times[0], times[2]), |
||||
}, |
||||
}, |
||||
now: times[1], |
||||
wantIndex: 0, |
||||
}, |
||||
{ |
||||
name: "single expired identity", |
||||
subject: testRootSubject, |
||||
ids: []certstore.Identity{ |
||||
&testIdentity{ |
||||
chain: makeChain(testRootCommonName, times[0], times[1]), |
||||
}, |
||||
}, |
||||
now: times[2], |
||||
wantIndex: -1, |
||||
}, |
||||
{ |
||||
name: "unrelated ids", |
||||
subject: testRootSubject, |
||||
ids: []certstore.Identity{ |
||||
&testIdentity{ |
||||
chain: makeChain("something", times[0], times[2]), |
||||
}, |
||||
&testIdentity{ |
||||
chain: makeChain(testRootCommonName, times[0], times[2]), |
||||
}, |
||||
&testIdentity{ |
||||
chain: makeChain("else", times[0], times[2]), |
||||
}, |
||||
}, |
||||
now: times[1], |
||||
wantIndex: 1, |
||||
}, |
||||
{ |
||||
name: "expired with unrelated ids", |
||||
subject: testRootSubject, |
||||
ids: []certstore.Identity{ |
||||
&testIdentity{ |
||||
chain: makeChain("something", times[0], times[3]), |
||||
}, |
||||
&testIdentity{ |
||||
chain: makeChain(testRootCommonName, times[0], times[1]), |
||||
}, |
||||
&testIdentity{ |
||||
chain: makeChain("else", times[0], times[3]), |
||||
}, |
||||
}, |
||||
now: times[2], |
||||
wantIndex: -1, |
||||
}, |
||||
{ |
||||
name: "one expired", |
||||
subject: testRootSubject, |
||||
ids: []certstore.Identity{ |
||||
&testIdentity{ |
||||
chain: makeChain(testRootCommonName, times[0], times[1]), |
||||
}, |
||||
&testIdentity{ |
||||
chain: makeChain(testRootCommonName, times[1], times[3]), |
||||
}, |
||||
}, |
||||
now: times[2], |
||||
wantIndex: 1, |
||||
}, |
||||
{ |
||||
name: "two certs both unexpired", |
||||
subject: testRootSubject, |
||||
ids: []certstore.Identity{ |
||||
&testIdentity{ |
||||
chain: makeChain(testRootCommonName, times[0], times[3]), |
||||
}, |
||||
&testIdentity{ |
||||
chain: makeChain(testRootCommonName, times[1], times[3]), |
||||
}, |
||||
}, |
||||
now: times[2], |
||||
wantIndex: 1, |
||||
}, |
||||
{ |
||||
name: "two unexpired one expired", |
||||
subject: testRootSubject, |
||||
ids: []certstore.Identity{ |
||||
&testIdentity{ |
||||
chain: makeChain(testRootCommonName, times[0], times[3]), |
||||
}, |
||||
&testIdentity{ |
||||
chain: makeChain(testRootCommonName, times[1], times[3]), |
||||
}, |
||||
&testIdentity{ |
||||
chain: makeChain(testRootCommonName, times[0], times[1]), |
||||
}, |
||||
}, |
||||
now: times[2], |
||||
wantIndex: 1, |
||||
}, |
||||
} |
||||
for _, tt := range tests { |
||||
t.Run(tt.name, func(t *testing.T) { |
||||
gotId, gotChain := selectIdentityFromSlice(tt.subject, tt.ids, tt.now) |
||||
|
||||
if gotId == nil && gotChain != nil { |
||||
t.Error("id is nil: got non-nil chain, want nil chain") |
||||
return |
||||
} |
||||
if gotId != nil && gotChain == nil { |
||||
t.Error("id is not nil: got nil chain, want non-nil chain") |
||||
return |
||||
} |
||||
if tt.wantIndex == -1 { |
||||
if gotId != nil { |
||||
t.Error("got non-nil id, want nil id") |
||||
} |
||||
return |
||||
} |
||||
if gotId == nil { |
||||
t.Error("got nil id, want non-nil id") |
||||
return |
||||
} |
||||
if gotId != tt.ids[tt.wantIndex] { |
||||
found := -1 |
||||
for i := range tt.ids { |
||||
if tt.ids[i] == gotId { |
||||
found = i |
||||
break |
||||
} |
||||
} |
||||
if found == -1 { |
||||
t.Errorf("got unknown id, want id at index %v", tt.wantIndex) |
||||
} else { |
||||
t.Errorf("got id at index %v, want id at index %v", found, tt.wantIndex) |
||||
} |
||||
} |
||||
|
||||
tid, ok := tt.ids[tt.wantIndex].(*testIdentity) |
||||
if !ok { |
||||
t.Error("got non-testIdentity, want testIdentity") |
||||
return |
||||
} |
||||
|
||||
if !reflect.DeepEqual(tid.chain, gotChain) { |
||||
t.Errorf("got unknown chain, want chain from id at index %v", tt.wantIndex) |
||||
} |
||||
}) |
||||
} |
||||
} |
||||
Loading…
Reference in new issue