This exports a number of things from the derp (generic + client) package to be used by the new derpserver package, as now used by cmd/derper. And then enough other misc changes to lock in that cmd/tailscaled can be configured to not bring in tailscale.com/client/local. (The webclient in particular, even when disabled, was bringing it in, so that's now fixed) Fixes #17257 Change-Id: I88b6c7958643fb54f386dd900bddf73d2d4d96d5 Signed-off-by: Brad Fitzpatrick <bradfitz@tailscale.com>main
parent
df747f1c1b
commit
21dc5f4e21
@ -0,0 +1,21 @@ |
||||
// Copyright (c) Tailscale Inc & AUTHORS
|
||||
// SPDX-License-Identifier: BSD-3-Clause
|
||||
|
||||
//go:build !ts_omit_webclient
|
||||
|
||||
package main |
||||
|
||||
import ( |
||||
"tailscale.com/client/local" |
||||
"tailscale.com/ipn/ipnlocal" |
||||
"tailscale.com/paths" |
||||
) |
||||
|
||||
func init() { |
||||
hookConfigureWebClient.Set(func(lb *ipnlocal.LocalBackend) { |
||||
lb.ConfigureWebClient(&local.Client{ |
||||
Socket: args.socketpath, |
||||
UseSocketOnly: args.socketpath != paths.DefaultTailscaledSocket(), |
||||
}) |
||||
}) |
||||
} |
||||
@ -0,0 +1,235 @@ |
||||
// Copyright (c) Tailscale Inc & AUTHORS
|
||||
// SPDX-License-Identifier: BSD-3-Clause
|
||||
|
||||
package derp |
||||
|
||||
import ( |
||||
"bufio" |
||||
"bytes" |
||||
"io" |
||||
"net" |
||||
"reflect" |
||||
"sync" |
||||
"testing" |
||||
"time" |
||||
|
||||
"tailscale.com/tstest" |
||||
"tailscale.com/types/key" |
||||
) |
||||
|
||||
type dummyNetConn struct { |
||||
net.Conn |
||||
} |
||||
|
||||
func (dummyNetConn) SetReadDeadline(time.Time) error { return nil } |
||||
|
||||
func TestClientRecv(t *testing.T) { |
||||
tests := []struct { |
||||
name string |
||||
input []byte |
||||
want any |
||||
}{ |
||||
{ |
||||
name: "ping", |
||||
input: []byte{ |
||||
byte(FramePing), 0, 0, 0, 8, |
||||
1, 2, 3, 4, 5, 6, 7, 8, |
||||
}, |
||||
want: PingMessage{1, 2, 3, 4, 5, 6, 7, 8}, |
||||
}, |
||||
{ |
||||
name: "pong", |
||||
input: []byte{ |
||||
byte(FramePong), 0, 0, 0, 8, |
||||
1, 2, 3, 4, 5, 6, 7, 8, |
||||
}, |
||||
want: PongMessage{1, 2, 3, 4, 5, 6, 7, 8}, |
||||
}, |
||||
{ |
||||
name: "health_bad", |
||||
input: []byte{ |
||||
byte(FrameHealth), 0, 0, 0, 3, |
||||
byte('B'), byte('A'), byte('D'), |
||||
}, |
||||
want: HealthMessage{Problem: "BAD"}, |
||||
}, |
||||
{ |
||||
name: "health_ok", |
||||
input: []byte{ |
||||
byte(FrameHealth), 0, 0, 0, 0, |
||||
}, |
||||
want: HealthMessage{}, |
||||
}, |
||||
{ |
||||
name: "server_restarting", |
||||
input: []byte{ |
||||
byte(FrameRestarting), 0, 0, 0, 8, |
||||
0, 0, 0, 1, |
||||
0, 0, 0, 2, |
||||
}, |
||||
want: ServerRestartingMessage{ |
||||
ReconnectIn: 1 * time.Millisecond, |
||||
TryFor: 2 * time.Millisecond, |
||||
}, |
||||
}, |
||||
} |
||||
for _, tt := range tests { |
||||
t.Run(tt.name, func(t *testing.T) { |
||||
c := &Client{ |
||||
nc: dummyNetConn{}, |
||||
br: bufio.NewReader(bytes.NewReader(tt.input)), |
||||
logf: t.Logf, |
||||
clock: &tstest.Clock{}, |
||||
} |
||||
got, err := c.Recv() |
||||
if err != nil { |
||||
t.Fatal(err) |
||||
} |
||||
if !reflect.DeepEqual(got, tt.want) { |
||||
t.Errorf("got %#v; want %#v", got, tt.want) |
||||
} |
||||
}) |
||||
} |
||||
} |
||||
|
||||
func TestClientSendPing(t *testing.T) { |
||||
var buf bytes.Buffer |
||||
c := &Client{ |
||||
bw: bufio.NewWriter(&buf), |
||||
} |
||||
if err := c.SendPing([8]byte{1, 2, 3, 4, 5, 6, 7, 8}); err != nil { |
||||
t.Fatal(err) |
||||
} |
||||
want := []byte{ |
||||
byte(FramePing), 0, 0, 0, 8, |
||||
1, 2, 3, 4, 5, 6, 7, 8, |
||||
} |
||||
if !bytes.Equal(buf.Bytes(), want) { |
||||
t.Errorf("unexpected output\nwrote: % 02x\n want: % 02x", buf.Bytes(), want) |
||||
} |
||||
} |
||||
|
||||
func TestClientSendPong(t *testing.T) { |
||||
var buf bytes.Buffer |
||||
c := &Client{ |
||||
bw: bufio.NewWriter(&buf), |
||||
} |
||||
if err := c.SendPong([8]byte{1, 2, 3, 4, 5, 6, 7, 8}); err != nil { |
||||
t.Fatal(err) |
||||
} |
||||
want := []byte{ |
||||
byte(FramePong), 0, 0, 0, 8, |
||||
1, 2, 3, 4, 5, 6, 7, 8, |
||||
} |
||||
if !bytes.Equal(buf.Bytes(), want) { |
||||
t.Errorf("unexpected output\nwrote: % 02x\n want: % 02x", buf.Bytes(), want) |
||||
} |
||||
} |
||||
|
||||
func BenchmarkWriteUint32(b *testing.B) { |
||||
w := bufio.NewWriter(io.Discard) |
||||
b.ReportAllocs() |
||||
b.ResetTimer() |
||||
for range b.N { |
||||
writeUint32(w, 0x0ba3a) |
||||
} |
||||
} |
||||
|
||||
type nopRead struct{} |
||||
|
||||
func (r nopRead) Read(p []byte) (int, error) { |
||||
return len(p), nil |
||||
} |
||||
|
||||
var sinkU32 uint32 |
||||
|
||||
func BenchmarkReadUint32(b *testing.B) { |
||||
r := bufio.NewReader(nopRead{}) |
||||
var err error |
||||
b.ReportAllocs() |
||||
b.ResetTimer() |
||||
for range b.N { |
||||
sinkU32, err = readUint32(r) |
||||
if err != nil { |
||||
b.Fatal(err) |
||||
} |
||||
} |
||||
} |
||||
|
||||
type countWriter struct { |
||||
mu sync.Mutex |
||||
writes int |
||||
bytes int64 |
||||
} |
||||
|
||||
func (w *countWriter) Write(p []byte) (n int, err error) { |
||||
w.mu.Lock() |
||||
defer w.mu.Unlock() |
||||
w.writes++ |
||||
w.bytes += int64(len(p)) |
||||
return len(p), nil |
||||
} |
||||
|
||||
func (w *countWriter) Stats() (writes int, bytes int64) { |
||||
w.mu.Lock() |
||||
defer w.mu.Unlock() |
||||
return w.writes, w.bytes |
||||
} |
||||
|
||||
func (w *countWriter) ResetStats() { |
||||
w.mu.Lock() |
||||
defer w.mu.Unlock() |
||||
w.writes, w.bytes = 0, 0 |
||||
} |
||||
|
||||
func TestClientSendRateLimiting(t *testing.T) { |
||||
cw := new(countWriter) |
||||
c := &Client{ |
||||
bw: bufio.NewWriter(cw), |
||||
clock: &tstest.Clock{}, |
||||
} |
||||
c.setSendRateLimiter(ServerInfoMessage{}) |
||||
|
||||
pkt := make([]byte, 1000) |
||||
if err := c.send(key.NodePublic{}, pkt); err != nil { |
||||
t.Fatal(err) |
||||
} |
||||
writes1, bytes1 := cw.Stats() |
||||
if writes1 != 1 { |
||||
t.Errorf("writes = %v, want 1", writes1) |
||||
} |
||||
|
||||
// Flood should all succeed.
|
||||
cw.ResetStats() |
||||
for range 1000 { |
||||
if err := c.send(key.NodePublic{}, pkt); err != nil { |
||||
t.Fatal(err) |
||||
} |
||||
} |
||||
writes1K, bytes1K := cw.Stats() |
||||
if writes1K != 1000 { |
||||
t.Logf("writes = %v; want 1000", writes1K) |
||||
} |
||||
if got, want := bytes1K, bytes1*1000; got != want { |
||||
t.Logf("bytes = %v; want %v", got, want) |
||||
} |
||||
|
||||
// Set a rate limiter
|
||||
cw.ResetStats() |
||||
c.setSendRateLimiter(ServerInfoMessage{ |
||||
TokenBucketBytesPerSecond: 1, |
||||
TokenBucketBytesBurst: int(bytes1 * 2), |
||||
}) |
||||
for range 1000 { |
||||
if err := c.send(key.NodePublic{}, pkt); err != nil { |
||||
t.Fatal(err) |
||||
} |
||||
} |
||||
writesLimited, bytesLimited := cw.Stats() |
||||
if writesLimited == 0 || writesLimited == writes1K { |
||||
t.Errorf("limited conn's write count = %v; want non-zero, less than 1k", writesLimited) |
||||
} |
||||
if bytesLimited < bytes1*2 || bytesLimited >= bytes1K { |
||||
t.Errorf("limited conn's bytes count = %v; want >=%v, <%v", bytesLimited, bytes1K*2, bytes1K) |
||||
} |
||||
} |
||||
File diff suppressed because it is too large
Load Diff
@ -0,0 +1,24 @@ |
||||
// Copyright (c) Tailscale Inc & AUTHORS
|
||||
// SPDX-License-Identifier: BSD-3-Clause
|
||||
|
||||
package derphttp |
||||
|
||||
func SetTestHookWatchLookConnectResult(f func(connectError error, wasSelfConnect bool) (keepRunning bool)) { |
||||
testHookWatchLookConnectResult = f |
||||
} |
||||
|
||||
// breakConnection breaks the connection, which should trigger a reconnect.
|
||||
func (c *Client) BreakConnection(brokenClient *Client) { |
||||
c.mu.Lock() |
||||
defer c.mu.Unlock() |
||||
if c.client != brokenClient.client { |
||||
return |
||||
} |
||||
if c.netConn != nil { |
||||
c.netConn.Close() |
||||
c.netConn = nil |
||||
} |
||||
c.client = nil |
||||
} |
||||
|
||||
var RetryInterval = &retryInterval |
||||
@ -0,0 +1,782 @@ |
||||
// Copyright (c) Tailscale Inc & AUTHORS
|
||||
// SPDX-License-Identifier: BSD-3-Clause
|
||||
|
||||
package derpserver |
||||
|
||||
import ( |
||||
"bufio" |
||||
"cmp" |
||||
"context" |
||||
"crypto/x509" |
||||
"encoding/asn1" |
||||
"expvar" |
||||
"fmt" |
||||
"log" |
||||
"net" |
||||
"os" |
||||
"reflect" |
||||
"strconv" |
||||
"sync" |
||||
"testing" |
||||
"time" |
||||
|
||||
qt "github.com/frankban/quicktest" |
||||
"go4.org/mem" |
||||
"golang.org/x/time/rate" |
||||
"tailscale.com/derp" |
||||
"tailscale.com/derp/derpconst" |
||||
"tailscale.com/types/key" |
||||
"tailscale.com/types/logger" |
||||
) |
||||
|
||||
const testMeshKey = "0123456789abcdef0123456789abcdef0123456789abcdef0123456789abcdef" |
||||
|
||||
func TestSetMeshKey(t *testing.T) { |
||||
for name, tt := range map[string]struct { |
||||
key string |
||||
want key.DERPMesh |
||||
wantErr bool |
||||
}{ |
||||
"clobber": { |
||||
key: testMeshKey, |
||||
wantErr: false, |
||||
}, |
||||
"invalid": { |
||||
key: "badf00d", |
||||
wantErr: true, |
||||
}, |
||||
} { |
||||
t.Run(name, func(t *testing.T) { |
||||
s := &Server{} |
||||
|
||||
err := s.SetMeshKey(tt.key) |
||||
if tt.wantErr { |
||||
if err == nil { |
||||
t.Fatalf("expected err") |
||||
} |
||||
return |
||||
} |
||||
if err != nil { |
||||
t.Fatalf("unexpected err: %v", err) |
||||
} |
||||
|
||||
want, err := key.ParseDERPMesh(tt.key) |
||||
if err != nil { |
||||
t.Fatal(err) |
||||
} |
||||
if !s.meshKey.Equal(want) { |
||||
t.Fatalf("got %v, want %v", s.meshKey, want) |
||||
} |
||||
}) |
||||
} |
||||
} |
||||
|
||||
func TestIsMeshPeer(t *testing.T) { |
||||
s := &Server{} |
||||
err := s.SetMeshKey(testMeshKey) |
||||
if err != nil { |
||||
t.Fatal(err) |
||||
} |
||||
for name, tt := range map[string]struct { |
||||
want bool |
||||
meshKey string |
||||
wantAllocs float64 |
||||
}{ |
||||
"nil": { |
||||
want: false, |
||||
wantAllocs: 0, |
||||
}, |
||||
"mismatch": { |
||||
meshKey: "6d529e9d4ef632d22d4a4214cb49da8f1ba1b72697061fb24e312984c35ec8d8", |
||||
want: false, |
||||
wantAllocs: 1, |
||||
}, |
||||
"match": { |
||||
meshKey: testMeshKey, |
||||
want: true, |
||||
wantAllocs: 0, |
||||
}, |
||||
} { |
||||
t.Run(name, func(t *testing.T) { |
||||
var got bool |
||||
var mKey key.DERPMesh |
||||
if tt.meshKey != "" { |
||||
mKey, err = key.ParseDERPMesh(tt.meshKey) |
||||
if err != nil { |
||||
t.Fatalf("ParseDERPMesh(%q) failed: %v", tt.meshKey, err) |
||||
} |
||||
} |
||||
|
||||
info := derp.ClientInfo{ |
||||
MeshKey: mKey, |
||||
} |
||||
allocs := testing.AllocsPerRun(1, func() { |
||||
got = s.isMeshPeer(&info) |
||||
}) |
||||
if got != tt.want { |
||||
t.Fatalf("got %t, want %t: info = %#v", got, tt.want, info) |
||||
} |
||||
|
||||
if allocs != tt.wantAllocs && tt.want { |
||||
t.Errorf("%f allocations, want %f", allocs, tt.wantAllocs) |
||||
} |
||||
}) |
||||
} |
||||
} |
||||
|
||||
type testFwd int |
||||
|
||||
func (testFwd) ForwardPacket(key.NodePublic, key.NodePublic, []byte) error { |
||||
panic("not called in tests") |
||||
} |
||||
func (testFwd) String() string { |
||||
panic("not called in tests") |
||||
} |
||||
|
||||
func pubAll(b byte) (ret key.NodePublic) { |
||||
var bs [32]byte |
||||
for i := range bs { |
||||
bs[i] = b |
||||
} |
||||
return key.NodePublicFromRaw32(mem.B(bs[:])) |
||||
} |
||||
|
||||
func TestForwarderRegistration(t *testing.T) { |
||||
s := &Server{ |
||||
clients: make(map[key.NodePublic]*clientSet), |
||||
clientsMesh: map[key.NodePublic]PacketForwarder{}, |
||||
} |
||||
want := func(want map[key.NodePublic]PacketForwarder) { |
||||
t.Helper() |
||||
if got := s.clientsMesh; !reflect.DeepEqual(got, want) { |
||||
t.Fatalf("mismatch\n got: %v\nwant: %v\n", got, want) |
||||
} |
||||
} |
||||
wantCounter := func(c *expvar.Int, want int) { |
||||
t.Helper() |
||||
if got := c.Value(); got != int64(want) { |
||||
t.Errorf("counter = %v; want %v", got, want) |
||||
} |
||||
} |
||||
singleClient := func(c *sclient) *clientSet { |
||||
cs := &clientSet{} |
||||
cs.activeClient.Store(c) |
||||
return cs |
||||
} |
||||
|
||||
u1 := pubAll(1) |
||||
u2 := pubAll(2) |
||||
u3 := pubAll(3) |
||||
|
||||
s.AddPacketForwarder(u1, testFwd(1)) |
||||
s.AddPacketForwarder(u2, testFwd(2)) |
||||
want(map[key.NodePublic]PacketForwarder{ |
||||
u1: testFwd(1), |
||||
u2: testFwd(2), |
||||
}) |
||||
|
||||
// Verify a remove of non-registered forwarder is no-op.
|
||||
s.RemovePacketForwarder(u2, testFwd(999)) |
||||
want(map[key.NodePublic]PacketForwarder{ |
||||
u1: testFwd(1), |
||||
u2: testFwd(2), |
||||
}) |
||||
|
||||
// Verify a remove of non-registered user is no-op.
|
||||
s.RemovePacketForwarder(u3, testFwd(1)) |
||||
want(map[key.NodePublic]PacketForwarder{ |
||||
u1: testFwd(1), |
||||
u2: testFwd(2), |
||||
}) |
||||
|
||||
// Actual removal.
|
||||
s.RemovePacketForwarder(u2, testFwd(2)) |
||||
want(map[key.NodePublic]PacketForwarder{ |
||||
u1: testFwd(1), |
||||
}) |
||||
|
||||
// Adding a dup for a user.
|
||||
wantCounter(&s.multiForwarderCreated, 0) |
||||
s.AddPacketForwarder(u1, testFwd(100)) |
||||
s.AddPacketForwarder(u1, testFwd(100)) // dup to trigger dup path
|
||||
want(map[key.NodePublic]PacketForwarder{ |
||||
u1: newMultiForwarder(testFwd(1), testFwd(100)), |
||||
}) |
||||
wantCounter(&s.multiForwarderCreated, 1) |
||||
|
||||
// Removing a forwarder in a multi set that doesn't exist; does nothing.
|
||||
s.RemovePacketForwarder(u1, testFwd(55)) |
||||
want(map[key.NodePublic]PacketForwarder{ |
||||
u1: newMultiForwarder(testFwd(1), testFwd(100)), |
||||
}) |
||||
|
||||
// Removing a forwarder in a multi set that does exist should collapse it away
|
||||
// from being a multiForwarder.
|
||||
wantCounter(&s.multiForwarderDeleted, 0) |
||||
s.RemovePacketForwarder(u1, testFwd(1)) |
||||
want(map[key.NodePublic]PacketForwarder{ |
||||
u1: testFwd(100), |
||||
}) |
||||
wantCounter(&s.multiForwarderDeleted, 1) |
||||
|
||||
// Removing an entry for a client that's still connected locally should result
|
||||
// in a nil forwarder.
|
||||
u1c := &sclient{ |
||||
key: u1, |
||||
logf: logger.Discard, |
||||
} |
||||
s.clients[u1] = singleClient(u1c) |
||||
s.RemovePacketForwarder(u1, testFwd(100)) |
||||
want(map[key.NodePublic]PacketForwarder{ |
||||
u1: nil, |
||||
}) |
||||
|
||||
// But once that client disconnects, it should go away.
|
||||
s.unregisterClient(u1c) |
||||
want(map[key.NodePublic]PacketForwarder{}) |
||||
|
||||
// But if it already has a forwarder, it's not removed.
|
||||
s.AddPacketForwarder(u1, testFwd(2)) |
||||
s.unregisterClient(u1c) |
||||
want(map[key.NodePublic]PacketForwarder{ |
||||
u1: testFwd(2), |
||||
}) |
||||
|
||||
// Now pretend u1 was already connected locally (so clientsMesh[u1] is nil), and then we heard
|
||||
// that they're also connected to a peer of ours. That shouldn't transition the forwarder
|
||||
// from nil to the new one, not a multiForwarder.
|
||||
s.clients[u1] = singleClient(u1c) |
||||
s.clientsMesh[u1] = nil |
||||
want(map[key.NodePublic]PacketForwarder{ |
||||
u1: nil, |
||||
}) |
||||
s.AddPacketForwarder(u1, testFwd(3)) |
||||
want(map[key.NodePublic]PacketForwarder{ |
||||
u1: testFwd(3), |
||||
}) |
||||
} |
||||
|
||||
type channelFwd struct { |
||||
// id is to ensure that different instances that reference the
|
||||
// same channel are not equal, as they are used as keys in the
|
||||
// multiForwarder map.
|
||||
id int |
||||
c chan []byte |
||||
} |
||||
|
||||
func (f channelFwd) String() string { return "" } |
||||
func (f channelFwd) ForwardPacket(_ key.NodePublic, _ key.NodePublic, packet []byte) error { |
||||
f.c <- packet |
||||
return nil |
||||
} |
||||
|
||||
func TestMultiForwarder(t *testing.T) { |
||||
received := 0 |
||||
var wg sync.WaitGroup |
||||
ch := make(chan []byte) |
||||
ctx, cancel := context.WithCancel(context.Background()) |
||||
|
||||
s := &Server{ |
||||
clients: make(map[key.NodePublic]*clientSet), |
||||
clientsMesh: map[key.NodePublic]PacketForwarder{}, |
||||
} |
||||
u := pubAll(1) |
||||
s.AddPacketForwarder(u, channelFwd{1, ch}) |
||||
|
||||
wg.Add(2) |
||||
go func() { |
||||
defer wg.Done() |
||||
for { |
||||
select { |
||||
case <-ch: |
||||
received += 1 |
||||
case <-ctx.Done(): |
||||
return |
||||
} |
||||
} |
||||
}() |
||||
go func() { |
||||
defer wg.Done() |
||||
for { |
||||
s.AddPacketForwarder(u, channelFwd{2, ch}) |
||||
s.AddPacketForwarder(u, channelFwd{3, ch}) |
||||
s.RemovePacketForwarder(u, channelFwd{2, ch}) |
||||
s.RemovePacketForwarder(u, channelFwd{1, ch}) |
||||
s.AddPacketForwarder(u, channelFwd{1, ch}) |
||||
s.RemovePacketForwarder(u, channelFwd{3, ch}) |
||||
if ctx.Err() != nil { |
||||
return |
||||
} |
||||
} |
||||
}() |
||||
|
||||
// Number of messages is chosen arbitrarily, just for this loop to
|
||||
// run long enough concurrently with {Add,Remove}PacketForwarder loop above.
|
||||
numMsgs := 5000 |
||||
var fwd PacketForwarder |
||||
for i := range numMsgs { |
||||
s.mu.Lock() |
||||
fwd = s.clientsMesh[u] |
||||
s.mu.Unlock() |
||||
fwd.ForwardPacket(u, u, []byte(strconv.Itoa(i))) |
||||
} |
||||
|
||||
cancel() |
||||
wg.Wait() |
||||
if received != numMsgs { |
||||
t.Errorf("expected %d messages to be forwarded; got %d", numMsgs, received) |
||||
} |
||||
} |
||||
func TestMetaCert(t *testing.T) { |
||||
priv := key.NewNode() |
||||
pub := priv.Public() |
||||
s := NewServer(priv, t.Logf) |
||||
|
||||
certBytes := s.MetaCert() |
||||
cert, err := x509.ParseCertificate(certBytes) |
||||
if err != nil { |
||||
log.Fatal(err) |
||||
} |
||||
if fmt.Sprint(cert.SerialNumber) != fmt.Sprint(derp.ProtocolVersion) { |
||||
t.Errorf("serial = %v; want %v", cert.SerialNumber, derp.ProtocolVersion) |
||||
} |
||||
if g, w := cert.Subject.CommonName, derpconst.MetaCertCommonNamePrefix+pub.UntypedHexString(); g != w { |
||||
t.Errorf("CommonName = %q; want %q", g, w) |
||||
} |
||||
if n := len(cert.Extensions); n != 1 { |
||||
t.Fatalf("got %d extensions; want 1", n) |
||||
} |
||||
|
||||
// oidExtensionBasicConstraints is the Basic Constraints ID copied
|
||||
// from the x509 package.
|
||||
oidExtensionBasicConstraints := asn1.ObjectIdentifier{2, 5, 29, 19} |
||||
|
||||
if id := cert.Extensions[0].Id; !id.Equal(oidExtensionBasicConstraints) { |
||||
t.Errorf("extension ID = %v; want %v", id, oidExtensionBasicConstraints) |
||||
} |
||||
} |
||||
|
||||
func TestServerDupClients(t *testing.T) { |
||||
serverPriv := key.NewNode() |
||||
var s *Server |
||||
|
||||
clientPriv := key.NewNode() |
||||
clientPub := clientPriv.Public() |
||||
|
||||
var c1, c2, c3 *sclient |
||||
var clientName map[*sclient]string |
||||
|
||||
// run starts a new test case and resets clients back to their zero values.
|
||||
run := func(name string, dupPolicy dupPolicy, f func(t *testing.T)) { |
||||
s = NewServer(serverPriv, t.Logf) |
||||
s.dupPolicy = dupPolicy |
||||
c1 = &sclient{key: clientPub, logf: logger.WithPrefix(t.Logf, "c1: ")} |
||||
c2 = &sclient{key: clientPub, logf: logger.WithPrefix(t.Logf, "c2: ")} |
||||
c3 = &sclient{key: clientPub, logf: logger.WithPrefix(t.Logf, "c3: ")} |
||||
clientName = map[*sclient]string{ |
||||
c1: "c1", |
||||
c2: "c2", |
||||
c3: "c3", |
||||
} |
||||
t.Run(name, f) |
||||
} |
||||
runBothWays := func(name string, f func(t *testing.T)) { |
||||
run(name+"_disablefighters", disableFighters, f) |
||||
run(name+"_lastwriteractive", lastWriterIsActive, f) |
||||
} |
||||
wantSingleClient := func(t *testing.T, want *sclient) { |
||||
t.Helper() |
||||
got, ok := s.clients[want.key] |
||||
if !ok { |
||||
t.Error("no clients for key") |
||||
return |
||||
} |
||||
if got.dup != nil { |
||||
t.Errorf("unexpected dup set for single client") |
||||
} |
||||
cur := got.activeClient.Load() |
||||
if cur != want { |
||||
t.Errorf("active client = %q; want %q", clientName[cur], clientName[want]) |
||||
} |
||||
if cur != nil { |
||||
if cur.isDup.Load() { |
||||
t.Errorf("unexpected isDup on singleClient") |
||||
} |
||||
if cur.isDisabled.Load() { |
||||
t.Errorf("unexpected isDisabled on singleClient") |
||||
} |
||||
} |
||||
} |
||||
wantNoClient := func(t *testing.T) { |
||||
t.Helper() |
||||
_, ok := s.clients[clientPub] |
||||
if !ok { |
||||
// Good
|
||||
return |
||||
} |
||||
t.Errorf("got client; want empty") |
||||
} |
||||
wantDupSet := func(t *testing.T) *dupClientSet { |
||||
t.Helper() |
||||
cs, ok := s.clients[clientPub] |
||||
if !ok { |
||||
t.Fatal("no set for key; want dup set") |
||||
return nil |
||||
} |
||||
if cs.dup != nil { |
||||
return cs.dup |
||||
} |
||||
t.Fatalf("no dup set for key; want dup set") |
||||
return nil |
||||
} |
||||
wantActive := func(t *testing.T, want *sclient) { |
||||
t.Helper() |
||||
set, ok := s.clients[clientPub] |
||||
if !ok { |
||||
t.Error("no set for key") |
||||
return |
||||
} |
||||
got := set.activeClient.Load() |
||||
if got != want { |
||||
t.Errorf("active client = %q; want %q", clientName[got], clientName[want]) |
||||
} |
||||
} |
||||
checkDup := func(t *testing.T, c *sclient, want bool) { |
||||
t.Helper() |
||||
if got := c.isDup.Load(); got != want { |
||||
t.Errorf("client %q isDup = %v; want %v", clientName[c], got, want) |
||||
} |
||||
} |
||||
checkDisabled := func(t *testing.T, c *sclient, want bool) { |
||||
t.Helper() |
||||
if got := c.isDisabled.Load(); got != want { |
||||
t.Errorf("client %q isDisabled = %v; want %v", clientName[c], got, want) |
||||
} |
||||
} |
||||
wantDupConns := func(t *testing.T, want int) { |
||||
t.Helper() |
||||
if got := s.dupClientConns.Value(); got != int64(want) { |
||||
t.Errorf("dupClientConns = %v; want %v", got, want) |
||||
} |
||||
} |
||||
wantDupKeys := func(t *testing.T, want int) { |
||||
t.Helper() |
||||
if got := s.dupClientKeys.Value(); got != int64(want) { |
||||
t.Errorf("dupClientKeys = %v; want %v", got, want) |
||||
} |
||||
} |
||||
|
||||
// Common case: a single client comes and goes, with no dups.
|
||||
runBothWays("one_comes_and_goes", func(t *testing.T) { |
||||
wantNoClient(t) |
||||
s.registerClient(c1) |
||||
wantSingleClient(t, c1) |
||||
s.unregisterClient(c1) |
||||
wantNoClient(t) |
||||
}) |
||||
|
||||
// A still somewhat common case: a single client was
|
||||
// connected and then their wifi dies or laptop closes
|
||||
// or they switch networks and connect from a
|
||||
// different network. They have two connections but
|
||||
// it's not very bad. Only their new one is
|
||||
// active. The last one, being dead, doesn't send and
|
||||
// thus the new one doesn't get disabled.
|
||||
runBothWays("small_overlap_replacement", func(t *testing.T) { |
||||
wantNoClient(t) |
||||
s.registerClient(c1) |
||||
wantSingleClient(t, c1) |
||||
wantActive(t, c1) |
||||
wantDupKeys(t, 0) |
||||
wantDupKeys(t, 0) |
||||
|
||||
s.registerClient(c2) // wifi dies; c2 replacement connects
|
||||
wantDupSet(t) |
||||
wantDupConns(t, 2) |
||||
wantDupKeys(t, 1) |
||||
checkDup(t, c1, true) |
||||
checkDup(t, c2, true) |
||||
checkDisabled(t, c1, false) |
||||
checkDisabled(t, c2, false) |
||||
wantActive(t, c2) // sends go to the replacement
|
||||
|
||||
s.unregisterClient(c1) // c1 finally times out
|
||||
wantSingleClient(t, c2) |
||||
checkDup(t, c2, false) // c2 is longer a dup
|
||||
wantActive(t, c2) |
||||
wantDupConns(t, 0) |
||||
wantDupKeys(t, 0) |
||||
}) |
||||
|
||||
// Key cloning situation with concurrent clients, both trying
|
||||
// to write.
|
||||
run("concurrent_dups_get_disabled", disableFighters, func(t *testing.T) { |
||||
wantNoClient(t) |
||||
s.registerClient(c1) |
||||
wantSingleClient(t, c1) |
||||
wantActive(t, c1) |
||||
s.registerClient(c2) |
||||
wantDupSet(t) |
||||
wantDupKeys(t, 1) |
||||
wantDupConns(t, 2) |
||||
wantActive(t, c2) |
||||
checkDup(t, c1, true) |
||||
checkDup(t, c2, true) |
||||
checkDisabled(t, c1, false) |
||||
checkDisabled(t, c2, false) |
||||
|
||||
s.noteClientActivity(c2) |
||||
checkDisabled(t, c1, false) |
||||
checkDisabled(t, c2, false) |
||||
s.noteClientActivity(c1) |
||||
checkDisabled(t, c1, true) |
||||
checkDisabled(t, c2, true) |
||||
wantActive(t, nil) |
||||
|
||||
s.registerClient(c3) |
||||
wantActive(t, c3) |
||||
checkDisabled(t, c3, false) |
||||
wantDupKeys(t, 1) |
||||
wantDupConns(t, 3) |
||||
|
||||
s.unregisterClient(c3) |
||||
wantActive(t, nil) |
||||
wantDupKeys(t, 1) |
||||
wantDupConns(t, 2) |
||||
|
||||
s.unregisterClient(c2) |
||||
wantSingleClient(t, c1) |
||||
wantDupKeys(t, 0) |
||||
wantDupConns(t, 0) |
||||
}) |
||||
|
||||
// Key cloning with an A->B->C->A series instead.
|
||||
run("concurrent_dups_three_parties", disableFighters, func(t *testing.T) { |
||||
wantNoClient(t) |
||||
s.registerClient(c1) |
||||
s.registerClient(c2) |
||||
s.registerClient(c3) |
||||
s.noteClientActivity(c1) |
||||
checkDisabled(t, c1, true) |
||||
checkDisabled(t, c2, true) |
||||
checkDisabled(t, c3, true) |
||||
wantActive(t, nil) |
||||
}) |
||||
|
||||
run("activity_promotes_primary_when_nil", disableFighters, func(t *testing.T) { |
||||
wantNoClient(t) |
||||
|
||||
// Last registered client is the active one...
|
||||
s.registerClient(c1) |
||||
wantActive(t, c1) |
||||
s.registerClient(c2) |
||||
wantActive(t, c2) |
||||
s.registerClient(c3) |
||||
s.noteClientActivity(c2) |
||||
wantActive(t, c3) |
||||
|
||||
// But if the last one goes away, the one with the
|
||||
// most recent activity wins.
|
||||
s.unregisterClient(c3) |
||||
wantActive(t, c2) |
||||
}) |
||||
|
||||
run("concurrent_dups_three_parties_last_writer", lastWriterIsActive, func(t *testing.T) { |
||||
wantNoClient(t) |
||||
|
||||
s.registerClient(c1) |
||||
wantActive(t, c1) |
||||
s.registerClient(c2) |
||||
wantActive(t, c2) |
||||
|
||||
s.noteClientActivity(c1) |
||||
checkDisabled(t, c1, false) |
||||
checkDisabled(t, c2, false) |
||||
wantActive(t, c1) |
||||
|
||||
s.noteClientActivity(c2) |
||||
checkDisabled(t, c1, false) |
||||
checkDisabled(t, c2, false) |
||||
wantActive(t, c2) |
||||
|
||||
s.unregisterClient(c2) |
||||
checkDisabled(t, c1, false) |
||||
wantActive(t, c1) |
||||
}) |
||||
} |
||||
|
||||
func TestLimiter(t *testing.T) { |
||||
rl := rate.NewLimiter(rate.Every(time.Minute), 100) |
||||
for i := range 200 { |
||||
r := rl.Reserve() |
||||
d := r.Delay() |
||||
t.Logf("i=%d, allow=%v, d=%v", i, r.OK(), d) |
||||
} |
||||
} |
||||
|
||||
// BenchmarkConcurrentStreams exercises mutex contention on a
|
||||
// single Server instance with multiple concurrent client flows.
|
||||
func BenchmarkConcurrentStreams(b *testing.B) { |
||||
serverPrivateKey := key.NewNode() |
||||
s := NewServer(serverPrivateKey, logger.Discard) |
||||
defer s.Close() |
||||
|
||||
ln, err := net.Listen("tcp", "127.0.0.1:0") |
||||
if err != nil { |
||||
b.Fatal(err) |
||||
} |
||||
defer ln.Close() |
||||
|
||||
ctx, cancel := context.WithCancel(context.Background()) |
||||
defer cancel() |
||||
|
||||
go func() { |
||||
for ctx.Err() == nil { |
||||
connIn, err := ln.Accept() |
||||
if err != nil { |
||||
if ctx.Err() != nil { |
||||
return |
||||
} |
||||
b.Error(err) |
||||
return |
||||
} |
||||
|
||||
brwServer := bufio.NewReadWriter(bufio.NewReader(connIn), bufio.NewWriter(connIn)) |
||||
go s.Accept(ctx, connIn, brwServer, "test-client") |
||||
} |
||||
}() |
||||
|
||||
newClient := func(t testing.TB) *derp.Client { |
||||
t.Helper() |
||||
connOut, err := net.Dial("tcp", ln.Addr().String()) |
||||
if err != nil { |
||||
b.Fatal(err) |
||||
} |
||||
t.Cleanup(func() { connOut.Close() }) |
||||
|
||||
k := key.NewNode() |
||||
|
||||
brw := bufio.NewReadWriter(bufio.NewReader(connOut), bufio.NewWriter(connOut)) |
||||
client, err := derp.NewClient(k, connOut, brw, logger.Discard) |
||||
if err != nil { |
||||
b.Fatalf("client: %v", err) |
||||
} |
||||
return client |
||||
} |
||||
|
||||
b.RunParallel(func(pb *testing.PB) { |
||||
c1, c2 := newClient(b), newClient(b) |
||||
const packetSize = 100 |
||||
msg := make([]byte, packetSize) |
||||
for pb.Next() { |
||||
if err := c1.Send(c2.PublicKey(), msg); err != nil { |
||||
b.Fatal(err) |
||||
} |
||||
_, err := c2.Recv() |
||||
if err != nil { |
||||
return |
||||
} |
||||
} |
||||
}) |
||||
} |
||||
|
||||
func BenchmarkSendRecv(b *testing.B) { |
||||
for _, size := range []int{10, 100, 1000, 10000} { |
||||
b.Run(fmt.Sprintf("msgsize=%d", size), func(b *testing.B) { benchmarkSendRecvSize(b, size) }) |
||||
} |
||||
} |
||||
|
||||
func benchmarkSendRecvSize(b *testing.B, packetSize int) { |
||||
serverPrivateKey := key.NewNode() |
||||
s := NewServer(serverPrivateKey, logger.Discard) |
||||
defer s.Close() |
||||
|
||||
k := key.NewNode() |
||||
clientKey := k.Public() |
||||
|
||||
ln, err := net.Listen("tcp", "127.0.0.1:0") |
||||
if err != nil { |
||||
b.Fatal(err) |
||||
} |
||||
defer ln.Close() |
||||
|
||||
connOut, err := net.Dial("tcp", ln.Addr().String()) |
||||
if err != nil { |
||||
b.Fatal(err) |
||||
} |
||||
defer connOut.Close() |
||||
|
||||
connIn, err := ln.Accept() |
||||
if err != nil { |
||||
b.Fatal(err) |
||||
} |
||||
defer connIn.Close() |
||||
|
||||
brwServer := bufio.NewReadWriter(bufio.NewReader(connIn), bufio.NewWriter(connIn)) |
||||
|
||||
ctx, cancel := context.WithCancel(context.Background()) |
||||
defer cancel() |
||||
go s.Accept(ctx, connIn, brwServer, "test-client") |
||||
|
||||
brw := bufio.NewReadWriter(bufio.NewReader(connOut), bufio.NewWriter(connOut)) |
||||
client, err := derp.NewClient(k, connOut, brw, logger.Discard) |
||||
if err != nil { |
||||
b.Fatalf("client: %v", err) |
||||
} |
||||
|
||||
go func() { |
||||
for { |
||||
_, err := client.Recv() |
||||
if err != nil { |
||||
return |
||||
} |
||||
} |
||||
}() |
||||
|
||||
msg := make([]byte, packetSize) |
||||
b.SetBytes(int64(len(msg))) |
||||
b.ReportAllocs() |
||||
b.ResetTimer() |
||||
for range b.N { |
||||
if err := client.Send(clientKey, msg); err != nil { |
||||
b.Fatal(err) |
||||
} |
||||
} |
||||
} |
||||
|
||||
func TestParseSSOutput(t *testing.T) { |
||||
contents, err := os.ReadFile("testdata/example_ss.txt") |
||||
if err != nil { |
||||
t.Errorf("os.ReadFile(example_ss.txt) failed: %v", err) |
||||
} |
||||
seen := parseSSOutput(string(contents)) |
||||
if len(seen) == 0 { |
||||
t.Errorf("parseSSOutput expected non-empty map") |
||||
} |
||||
} |
||||
|
||||
func TestGetPerClientSendQueueDepth(t *testing.T) { |
||||
c := qt.New(t) |
||||
envKey := "TS_DEBUG_DERP_PER_CLIENT_SEND_QUEUE_DEPTH" |
||||
|
||||
testCases := []struct { |
||||
envVal string |
||||
want int |
||||
}{ |
||||
// Empty case, envknob treats empty as missing also.
|
||||
{ |
||||
"", defaultPerClientSendQueueDepth, |
||||
}, |
||||
{ |
||||
"64", 64, |
||||
}, |
||||
} |
||||
|
||||
for _, tc := range testCases { |
||||
t.Run(cmp.Or(tc.envVal, "empty"), func(t *testing.T) { |
||||
t.Setenv(envKey, tc.envVal) |
||||
val := getPerClientSendQueueDepth() |
||||
c.Assert(val, qt.Equals, tc.want) |
||||
}) |
||||
} |
||||
} |
||||
@ -0,0 +1,10 @@ |
||||
// Copyright (c) Tailscale Inc & AUTHORS
|
||||
// SPDX-License-Identifier: BSD-3-Clause
|
||||
|
||||
package derp |
||||
|
||||
import "time" |
||||
|
||||
func (c *Client) RecvTimeoutForTest(timeout time.Duration) (m ReceivedMessage, err error) { |
||||
return c.recvTimeout(timeout) |
||||
} |
||||
Loading…
Reference in new issue