derp/derpserver: implement hierarchical token bucket rate limiting

By adding a server-global parent bucket. Per-client rate limiting is
subject to the parent bucket if global rate limiting is enabled.

This implementation is experimental, and all related APIs should be
considered unstable.

Updates tailscale/corp#40291

Signed-off-by: Jordan Whited <jordan@tailscale.com>
This commit is contained in:
Jordan Whited
2026-04-13 19:34:34 -07:00
committed by Jordan Whited
parent 5eb0b4be31
commit d8190e0de5
2 changed files with 243 additions and 152 deletions
+155 -90
View File
@@ -965,11 +965,28 @@ func TestPerClientRateLimit(t *testing.T) {
c := &sclient{
ctx: ctx,
}
c.recvLim.Store(rate.NewLimiter(rate.Limit(minRateLimitTokenBucketSize), minRateLimitTokenBucketSize))
lim := &parentChildTokenBuckets{
// Set parent limit to half of child to enable verification of
// rate limiting across both layers with a single sclient.
parent: rate.NewLimiter(rate.Limit(minRateLimitTokenBucketSize)/2, minRateLimitTokenBucketSize),
child: rate.NewLimiter(rate.Limit(minRateLimitTokenBucketSize), minRateLimitTokenBucketSize),
}
c.recvLim.Store(lim)
wantTokens := func(t *testing.T, wantParentTokens, wantChildTokens float64) {
t.Helper()
if lim.parent.Tokens() != wantParentTokens {
t.Fatalf("want parent tokens: %v got: %v", wantParentTokens, lim.parent.Tokens())
}
if lim.child.Tokens() != wantChildTokens {
t.Fatalf("want child tokens: %v got: %v", wantChildTokens, lim.child.Tokens())
}
}
// First call within burst should not block.
c.rateLimit(minRateLimitTokenBucketSize)
wantTokens(t, 0, 0)
// Next call exceeds burst, should block until tokens replenish.
done := make(chan error, 1)
go func() {
@@ -984,7 +1001,21 @@ func TestPerClientRateLimit(t *testing.T) {
default:
}
// Advance time by 1 second
// Advance time by 1 second, the goroutine should still be blocked
// on the parent bucket (negative tokens).
time.Sleep(1 * time.Second)
synctest.Wait()
select {
case err := <-done:
t.Fatalf("rateLimit should have blocked, but returned: %v", err)
default:
}
// Verify the parent bucket fills at half the rate of the child.
wantTokens(t, -(minRateLimitTokenBucketSize / 2), 0)
// Advance time by another second, parent should have enough tokens
// to unblock.
time.Sleep(1 * time.Second)
synctest.Wait()
@@ -996,6 +1027,8 @@ func TestPerClientRateLimit(t *testing.T) {
default:
t.Fatal("rateLimit should have unblocked after 1s")
}
wantTokens(t, 0, minRateLimitTokenBucketSize)
})
})
@@ -1006,7 +1039,11 @@ func TestPerClientRateLimit(t *testing.T) {
c := &sclient{
ctx: ctx,
}
c.recvLim.Store(rate.NewLimiter(rate.Limit(minRateLimitTokenBucketSize), minRateLimitTokenBucketSize))
lim := &parentChildTokenBuckets{
child: rate.NewLimiter(rate.Limit(minRateLimitTokenBucketSize), minRateLimitTokenBucketSize),
parent: rate.NewLimiter(rate.Limit(minRateLimitTokenBucketSize), minRateLimitTokenBucketSize),
}
c.recvLim.Store(lim)
// Exhaust burst.
if err := c.rateLimit(minRateLimitTokenBucketSize); err != nil {
@@ -1052,18 +1089,38 @@ func TestPerClientRateLimit(t *testing.T) {
t.Run("zero_config_no_limiter", func(t *testing.T) {
s := New(key.NewNode(), logger.Discard)
defer s.Close()
if s.perClientRecvBytesPerSec != 0 {
t.Errorf("expected zero rate limit, got %d", s.perClientRecvBytesPerSec)
if !reflect.DeepEqual(s.rateConfig, RateConfig{}) {
t.Errorf("expected zero rate limit, got %+v", s.rateConfig)
}
})
}
func TestUpdatePerClientRateLimit(t *testing.T) {
func verifyLimiter(t *testing.T, lim *parentChildTokenBuckets, wantRateConfig RateConfig) {
t.Helper()
if got := lim.child.Limit(); got != rate.Limit(wantRateConfig.PerClientRateLimitBytesPerSec) {
t.Errorf("client rate limit = %v; want %d", got, wantRateConfig.PerClientRateLimitBytesPerSec)
}
if got := lim.child.Burst(); got != int(wantRateConfig.PerClientRateBurstBytes) {
t.Errorf("client burst = %v; want %d", got, wantRateConfig.PerClientRateBurstBytes)
}
if got := lim.parent.Limit(); got != rate.Limit(wantRateConfig.GlobalRateLimitBytesPerSec) {
t.Errorf("global rate limit = %v, want %d", got, wantRateConfig.GlobalRateLimitBytesPerSec)
}
if got := lim.parent.Burst(); got != int(wantRateConfig.GlobalRateBurstBytes) {
t.Errorf("global burst = %v, want %d", got, wantRateConfig.GlobalRateBurstBytes)
}
}
func TestUpdateRateLimits(t *testing.T) {
const (
testBurst1 = derp.MaxPacketSize * 2
testRate1 = 1000
testBurst2 = derp.MaxPacketSize * 4
testRate2 = 5000
testClientBurst1 = minRateLimitTokenBucketSize + 1
testClientRate1 = minRateLimitTokenBucketSize + 2
testClientBurst2 = minRateLimitTokenBucketSize + 3
testClientRate2 = minRateLimitTokenBucketSize + 4
testGlobalBurst1 = minRateLimitTokenBucketSize + 5
testGlobalRate1 = minRateLimitTokenBucketSize + 6
testGlobalBurst2 = minRateLimitTokenBucketSize + 7
testGlobalRate2 = minRateLimitTokenBucketSize + 8
)
s := New(key.NewNode(), t.Logf)
@@ -1084,52 +1141,46 @@ func TestUpdatePerClientRateLimit(t *testing.T) {
s.clients[clientKey] = cs
s.mu.Unlock()
s.UpdatePerClientRateLimit(testRate1, testBurst1)
rc := RateConfig{
PerClientRateLimitBytesPerSec: testClientRate1,
PerClientRateBurstBytes: testClientBurst1,
GlobalRateLimitBytesPerSec: testGlobalRate1,
GlobalRateBurstBytes: testGlobalBurst1,
}
s.UpdateRateLimits(rc)
lim := c.recvLim.Load()
if lim == nil {
t.Fatal("expected non-nil limiter after update")
}
if got := lim.Limit(); got != rate.Limit(testRate1) {
t.Errorf("rate limit = %v; want %d", got, testRate1)
}
if got := lim.Burst(); got != int(testBurst1) {
t.Errorf("burst = %v; want %d", got, testBurst1)
}
verifyLimiter(t, lim, rc)
// Verify server fields updated.
s.mu.Lock()
if s.perClientRecvBytesPerSec != testRate1 {
t.Errorf("server rate = %d; want %d", s.perClientRecvBytesPerSec, testRate1)
}
if s.perClientRecvBurst != testBurst1 {
t.Errorf("server burst = %d; want %d", s.perClientRecvBurst, testBurst1)
if !reflect.DeepEqual(s.rateConfig, rc) {
t.Errorf("s.rateConfig = %+v; want %+v", s.rateConfig, rc)
}
s.mu.Unlock()
// Update again with different nonzero values. This exercises the
// in-place update path (existing limiter is reused, not recreated).
prevLim := c.recvLim.Load()
s.UpdatePerClientRateLimit(testRate2, testBurst2)
// Update again with different nonzero values.
rc = RateConfig{
PerClientRateLimitBytesPerSec: testClientRate2,
PerClientRateBurstBytes: testClientBurst2,
GlobalRateLimitBytesPerSec: testGlobalRate2,
GlobalRateBurstBytes: testGlobalBurst2,
}
s.UpdateRateLimits(rc)
lim = c.recvLim.Load()
if lim == nil {
t.Fatal("expected non-nil limiter after in-place update")
}
if lim != prevLim {
t.Error("expected same limiter pointer after in-place update")
}
if got := lim.Limit(); got != rate.Limit(testRate2) {
t.Errorf("rate limit after in-place update = %v; want %d", got, testRate2)
}
if got := lim.Burst(); got != int(testBurst2) {
t.Errorf("burst after in-place update = %v; want %d", got, testBurst2)
t.Fatal("expected non-nil limiter")
}
verifyLimiter(t, lim, rc)
// Disable rate limiting (set to 0).
s.UpdatePerClientRateLimit(0, 0)
s.UpdateRateLimits(RateConfig{})
if got := c.recvLim.Load(); got != nil {
t.Errorf("expected nil limiter after disable, got limit=%v", got.Limit())
t.Errorf("expected nil limiter after disable, got limit=%v", got.child.Limit())
}
// Mesh peer should always have nil limiter regardless of update.
@@ -1147,22 +1198,23 @@ func TestUpdatePerClientRateLimit(t *testing.T) {
s.clients[meshKey] = meshCS
s.mu.Unlock()
s.UpdatePerClientRateLimit(testRate2, testBurst2)
rc = RateConfig{
PerClientRateLimitBytesPerSec: testClientRate2,
PerClientRateBurstBytes: testClientBurst2,
GlobalRateLimitBytesPerSec: testGlobalRate2,
GlobalRateBurstBytes: testGlobalBurst2,
}
s.UpdateRateLimits(rc)
if got := meshClient.recvLim.Load(); got != nil {
t.Errorf("mesh peer should have nil limiter, got limit=%v", got.Limit())
t.Errorf("mesh peer should have nil limiter, got limit=%v", got.child.Limit())
}
// Non-mesh client should be updated.
lim = c.recvLim.Load()
if lim == nil {
t.Fatal("expected non-nil limiter for non-mesh client")
}
if got := lim.Limit(); got != rate.Limit(testRate2) {
t.Errorf("rate limit = %v; want %d", got, testRate2)
}
if got := lim.Burst(); got != int(testBurst2) {
t.Errorf("burst = %v; want %d", got, testBurst2)
}
verifyLimiter(t, lim, rc)
// Verify dup clients are also updated.
dupKey := key.NewNode().Public()
@@ -1175,32 +1227,40 @@ func TestUpdatePerClientRateLimit(t *testing.T) {
s.clients[dupKey] = dupCS
s.mu.Unlock()
s.UpdatePerClientRateLimit(testRate1, testBurst1)
rc = RateConfig{
GlobalRateLimitBytesPerSec: testGlobalRate1,
GlobalRateBurstBytes: testGlobalBurst1,
PerClientRateLimitBytesPerSec: testClientRate1,
PerClientRateBurstBytes: testClientBurst1,
}
s.UpdateRateLimits(rc)
for i, d := range []*sclient{d1, d2} {
dl := d.recvLim.Load()
if dl == nil {
t.Fatalf("dup client %d: expected non-nil limiter", i)
}
if got := dl.Limit(); got != rate.Limit(testRate1) {
t.Errorf("dup client %d: rate = %v; want %d", i, got, testRate1)
}
if got := dl.Burst(); got != int(testBurst1) {
t.Errorf("dup client %d: burst = %v; want %d", i, got, testBurst1)
}
verifyLimiter(t, dl, rc)
}
}
func TestLoadRateConfig(t *testing.T) {
for _, tt := range []struct {
name string
json string
wantRate uint64
wantBurst uint64
name string
json string
wantRateConfig RateConfig
}{
{"both_set", `{"PerClientRateLimitBytesPerSec": 1250000, "PerClientRateBurstBytes": 2500000}`, 1250000, 2500000},
{"rate_only", `{"PerClientRateLimitBytesPerSec": 500000}`, 500000, 0},
{"zeros", `{"PerClientRateLimitBytesPerSec": 0, "PerClientRateBurstBytes": 0}`, 0, 0},
{"empty_json", `{}`, 0, 0},
{"all_set", `{"PerClientRateLimitBytesPerSec": 1, "PerClientRateBurstBytes": 2, "GlobalRateLimitBytesPerSec": 3, "GlobalRateBurstBytes": 4}`, RateConfig{
PerClientRateLimitBytesPerSec: 1,
PerClientRateBurstBytes: 2,
GlobalRateLimitBytesPerSec: 3,
GlobalRateBurstBytes: 4,
}},
{"rate_only", `{"PerClientRateLimitBytesPerSec": 1, "GlobalRateLimitBytesPerSec": 3}`, RateConfig{
PerClientRateLimitBytesPerSec: 1,
GlobalRateLimitBytesPerSec: 3,
}},
{"zeros", `{"PerClientRateLimitBytesPerSec": 0, "PerClientRateBurstBytes": 0, "GlobalRateLimitBytesPerSec": 0, "GlobalRateBurstBytes": 0}`, RateConfig{}},
{"empty_json", `{}`, RateConfig{}},
} {
t.Run(tt.name, func(t *testing.T) {
f := filepath.Join(t.TempDir(), "rate.json")
@@ -1209,13 +1269,10 @@ func TestLoadRateConfig(t *testing.T) {
}
rc, err := LoadRateConfig(f)
if err != nil {
t.Fatalf("unexpected error: %v", err)
t.Fatal(err)
}
if rc.PerClientRateLimitBytesPerSec != tt.wantRate {
t.Errorf("rate = %d; want %d", rc.PerClientRateLimitBytesPerSec, tt.wantRate)
}
if rc.PerClientRateBurstBytes != tt.wantBurst {
t.Errorf("burst = %d; want %d", rc.PerClientRateBurstBytes, tt.wantBurst)
if !reflect.DeepEqual(rc, tt.wantRateConfig) {
t.Errorf("rate config = %v want %v", rc, tt.wantRateConfig)
}
})
}
@@ -1223,7 +1280,7 @@ func TestLoadRateConfig(t *testing.T) {
for _, tt := range []struct {
name string
path string
content string // written to path if non-empty; path used as-is if empty
content string // written to loaded path if non-empty; path used as-is if empty
}{
{"empty_path", "", ""},
{"missing_file", filepath.Join(t.TempDir(), "nonexistent.json"), ""},
@@ -1267,47 +1324,51 @@ func TestLoadAndApplyRateConfig(t *testing.T) {
s.clients[clientKey] = cs
s.mu.Unlock()
f := writeConfig(t, `{"PerClientRateLimitBytesPerSec": 1250000, "PerClientRateBurstBytes": 2500000}`)
f := writeConfig(t, fmt.Sprintf(`{"PerClientRateLimitBytesPerSec": %d, "PerClientRateBurstBytes": %d, "GlobalRateLimitBytesPerSec": %d, "GlobalRateBurstBytes": %d}`,
minRateLimitTokenBucketSize, minRateLimitTokenBucketSize+1, minRateLimitTokenBucketSize+2, minRateLimitTokenBucketSize+3))
if err := s.LoadAndApplyRateConfig(f); err != nil {
t.Fatalf("LoadAndApplyRateConfig: %v", err)
}
// Verify server fields.
wantRateConfig := RateConfig{
PerClientRateLimitBytesPerSec: minRateLimitTokenBucketSize,
PerClientRateBurstBytes: minRateLimitTokenBucketSize + 1,
GlobalRateLimitBytesPerSec: minRateLimitTokenBucketSize + 2,
GlobalRateBurstBytes: minRateLimitTokenBucketSize + 3,
}
s.mu.Lock()
gotRate := s.perClientRecvBytesPerSec
gotBurst := s.perClientRecvBurst
if !reflect.DeepEqual(s.rateConfig, wantRateConfig) {
t.Errorf("s.rateConfig = %+v; want %+v", s.rateConfig, wantRateConfig)
}
s.mu.Unlock()
if gotRate != 1250000 {
t.Errorf("server rate = %d; want 1250000", gotRate)
}
if gotBurst != 2500000 {
t.Errorf("server burst = %d; want 2500000", gotBurst)
}
// Verify client limiter.
lim := c.recvLim.Load()
if lim == nil {
t.Fatal("expected non-nil limiter")
}
if got := lim.Limit(); got != rate.Limit(1250000) {
t.Errorf("client rate = %v; want 1250000", got)
}
verifyLimiter(t, lim, wantRateConfig)
})
t.Run("burst_is_at_least_max_packet_size", func(t *testing.T) {
t.Run("burst_is_at_least_minRateLimitTokenBucketSize", func(t *testing.T) {
s := New(key.NewNode(), t.Logf)
defer s.Close()
f := writeConfig(t, `{"PerClientRateLimitBytesPerSec": 1250000, "PerClientRateBurstBytes": 10}`)
f := writeConfig(t, `{"PerClientRateLimitBytesPerSec": 1250000, "PerClientRateBurstBytes": 10, "GlobalRateLimitBytesPerSec": 1250000, "GlobalRateBurstBytes": 10}`)
if err := s.LoadAndApplyRateConfig(f); err != nil {
t.Fatalf("LoadAndApplyRateConfig: %v", err)
}
s.mu.Lock()
gotBurst := s.perClientRecvBurst
gotClientBurst := s.rateConfig.PerClientRateBurstBytes
gotGlobalBurst := s.rateConfig.GlobalRateBurstBytes
s.mu.Unlock()
if gotBurst != minRateLimitTokenBucketSize {
t.Errorf("burst = %d; want at least %d", gotBurst, minRateLimitTokenBucketSize)
if gotClientBurst != minRateLimitTokenBucketSize {
t.Errorf("client burst = %d; want %d", gotClientBurst, minRateLimitTokenBucketSize)
}
if gotGlobalBurst != minRateLimitTokenBucketSize {
t.Errorf("global burst = %d; want %d", gotGlobalBurst, minRateLimitTokenBucketSize)
}
})
@@ -1315,10 +1376,15 @@ func TestLoadAndApplyRateConfig(t *testing.T) {
s := New(key.NewNode(), t.Logf)
defer s.Close()
f := writeConfig(t, `{"PerClientRateLimitBytesPerSec": 1250000, "PerClientRateBurstBytes": 2500000}`)
f := writeConfig(t, `{"PerClientRateLimitBytesPerSec": 1250000, "PerClientRateBurstBytes": 2500000, "GlobalRateLimitBytesPerSec": 12500000, "GlobalRateBurstBytes": 25000000}`)
if err := s.LoadAndApplyRateConfig(f); err != nil {
t.Fatal(err)
}
s.mu.Lock()
if reflect.DeepEqual(s.rateConfig, RateConfig{}) {
t.Error("s.rateConfig is zero val; want nonzero rates")
}
s.mu.Unlock()
if err := os.WriteFile(f, []byte(`{}`), 0644); err != nil {
t.Fatal(err)
@@ -1328,11 +1394,10 @@ func TestLoadAndApplyRateConfig(t *testing.T) {
}
s.mu.Lock()
gotRate := s.perClientRecvBytesPerSec
s.mu.Unlock()
if gotRate != 0 {
t.Errorf("rate = %d; want 0 (unlimited)", gotRate)
if !reflect.DeepEqual(s.rateConfig, RateConfig{}) {
t.Errorf("s.rateConfig = %+v; want %+v", s.rateConfig, RateConfig{})
}
s.mu.Unlock()
})
t.Run("propagates_errors", func(t *testing.T) {