From d8190e0de56fd580d3330c0e789d1eee9eee34d4 Mon Sep 17 00:00:00 2001 From: Jordan Whited Date: Mon, 13 Apr 2026 19:34:34 -0700 Subject: [PATCH] derp/derpserver: implement hierarchical token bucket rate limiting By adding a server-global parent bucket. Per-client rate limiting is subject to the parent bucket if global rate limiting is enabled. This implementation is experimental, and all related APIs should be considered unstable. Updates tailscale/corp#40291 Signed-off-by: Jordan Whited --- derp/derpserver/derpserver.go | 150 ++++++++++-------- derp/derpserver/derpserver_test.go | 245 ++++++++++++++++++----------- 2 files changed, 243 insertions(+), 152 deletions(-) diff --git a/derp/derpserver/derpserver.go b/derp/derpserver/derpserver.go index f05f94f09..40d8d3c51 100644 --- a/derp/derpserver/derpserver.go +++ b/derp/derpserver/derpserver.go @@ -181,7 +181,11 @@ type Server struct { verifyClientsURL string verifyClientsURLFailOpen bool - mu syncs.Mutex + perClientSendQueueDepth int // Sets the client send queue depth for the server. + tcpWriteTimeout time.Duration + clock tstime.Clock + + mu syncs.Mutex // guards the following fields closed bool netConns map[derp.Conn]chan struct{} // chan is closed when conn closes clients map[key.NodePublic]*clientSet @@ -197,25 +201,10 @@ type Server struct { // is gone from the region, we notify all of these watchers, // calling their funcs in a new goroutine. peerGoneWatchers map[key.NodePublic]set.HandleSet[func(key.NodePublic)] - // maps from netip.AddrPort to a client's public key - keyOfAddr map[netip.AddrPort]key.NodePublic - - // Sets the client send queue depth for the server. - perClientSendQueueDepth int - - tcpWriteTimeout time.Duration - - // perClientRecvBytesPerSec is the rate limit for receiving data from - // a single client connection, in bytes per second. 0 means unlimited. - // Mesh peers are exempt from this limit. - perClientRecvBytesPerSec uint64 - // perClientRecvBurst is the burst size in bytes for the per-client - // receive rate limiter. Always at least [derp.MaxPacketSize] when - // set via [Server.UpdatePerClientRateLimit]. - perClientRecvBurst uint64 - - clock tstime.Clock + keyOfAddr map[netip.AddrPort]key.NodePublic + rateConfig RateConfig // server-global and per-client DERP frame rate limiting config + recvLim *xrate.Limiter // server-global DERP frame receive limiter } // clientSet represents 1 or more *sclients. @@ -523,16 +512,29 @@ func (s *Server) SetTCPWriteTimeout(d time.Duration) { // our token bucket calls for. const minRateLimitTokenBucketSize = derp.MaxPacketSize + derp.KeyLen -// RateConfig is a JSON-serializable configuration for per-client rate limits. -// Values are in bytes. +// RateConfig is a JSON-serializable configuration for rate limits. Values are +// in bytes. type RateConfig struct { // PerClientRateLimitBytesPerSec represents the per-client - // rate limit in bytes per second. A zero value disables rate-limiting. + // rate limit in bytes per second. A zero value disables rate limiting. PerClientRateLimitBytesPerSec uint64 `json:",omitzero"` // PerClientRateBurstBytes represents the per-client token bucket depth, // or burst, in bytes. Any value lower than [minRateLimitTokenBucketSize] - // will be increased to [minRateLimitTokenBucketSize] before application. + // will be increased to [minRateLimitTokenBucketSize] before application. Only + // relevant if PerClientRateLimitBytesPerSec is nonzero. PerClientRateBurstBytes uint64 `json:",omitzero"` + // GlobalRateLimitBytesPerSec represents the global rate limit in bytes per + // second. A zero value disables global rate limiting, but per-client (PerClient...) + // configuration may still apply. Only relevant if PerClientRateLimitBytesPerSec + // is nonzero. If GlobalRateLimitBytesPerSec is nonzero and less than + // PerClientRateLimitBytesPerSec, then GlobalRateLimitBytesPerSec will be set + // equal to PerClientRateLimitBytesPerSec before application. + GlobalRateLimitBytesPerSec uint64 `json:",omitzero"` + // GlobalRateBurstBytes represents the global token bucket depth, or burst, + // in bytes. Any value lower than [minRateLimitTokenBucketSize] will be increased to + // [minRateLimitTokenBucketSize] before application. Only relevant if + // PerClientRateLimitBytesPerSec and GlobalRateLimitBytesPerSec are nonzero. + GlobalRateBurstBytes uint64 `json:",omitzero"` } // LoadRateConfig reads and JSON-unmarshals a [RateConfig] from the file at path. @@ -542,41 +544,56 @@ func LoadRateConfig(path string) (RateConfig, error) { } b, err := os.ReadFile(path) if err != nil { - return RateConfig{}, fmt.Errorf("reading rate config: %w", err) + return RateConfig{}, fmt.Errorf("error reading rate config: %w", err) } var rc RateConfig if err := json.Unmarshal(b, &rc); err != nil { - return RateConfig{}, fmt.Errorf("parsing rate config: %w", err) + return RateConfig{}, fmt.Errorf("error parsing rate config: %w", err) } return rc, nil } // LoadAndApplyRateConfig reads a [RateConfig] from the file at path and -// applies it to the server via [Server.UpdatePerClientRateLimit]. +// applies it to the server via [Server.UpdateRateLimits]. func (s *Server) LoadAndApplyRateConfig(path string) error { rc, err := LoadRateConfig(path) if err != nil { return err } - s.UpdatePerClientRateLimit(rc.PerClientRateLimitBytesPerSec, rc.PerClientRateBurstBytes) - s.logf("rate config applied: rate=%d bytes/sec, burst=%d bytes", rc.PerClientRateLimitBytesPerSec, rc.PerClientRateBurstBytes) + applied := s.UpdateRateLimits(rc) + s.logf("rate config applied: global-rate=%d bytes/sec global-burst=%d bytes client-rate=%d bytes/sec, client-burst=%d bytes", + applied.GlobalRateLimitBytesPerSec, applied.GlobalRateBurstBytes, applied.PerClientRateLimitBytesPerSec, applied.PerClientRateBurstBytes) return nil } -// UpdatePerClientRateLimit sets the per-client receive rate limit in bytes per -// second and the burst size in bytes, updating all existing client connections. -// The applied burst will be at least [minRateLimitTokenBucketSize]. If bytesPerSec is -// 0, rate limiting is disabled. Mesh peers are always exempt from rate limiting. -func (s *Server) UpdatePerClientRateLimit(bytesPerSec, burst uint64) { +// UpdateRateLimits sets the receive rate limits, updating all existing client +// connections. It returns the applied config, which may differ from rc. If the +// per-client rate limit is 0, rate limiting is disabled. Mesh peers are always +// exempt from rate limiting. +func (s *Server) UpdateRateLimits(rc RateConfig) (applied RateConfig) { s.mu.Lock() defer s.mu.Unlock() - s.perClientRecvBytesPerSec = bytesPerSec - s.perClientRecvBurst = max(burst, minRateLimitTokenBucketSize) + if rc.PerClientRateLimitBytesPerSec == 0 { + // if per-client is disabled, all rate limiting is disabled + rc = RateConfig{} + } + if rc.PerClientRateLimitBytesPerSec != 0 { + rc.PerClientRateBurstBytes = max(rc.PerClientRateBurstBytes, minRateLimitTokenBucketSize) + } + if rc.GlobalRateLimitBytesPerSec != 0 { + rc.GlobalRateLimitBytesPerSec = max(rc.GlobalRateLimitBytesPerSec, rc.PerClientRateLimitBytesPerSec) + rc.GlobalRateBurstBytes = max(rc.GlobalRateBurstBytes, minRateLimitTokenBucketSize) + s.recvLim = xrate.NewLimiter(xrate.Limit(rc.GlobalRateLimitBytesPerSec), int(rc.GlobalRateBurstBytes)) + } else { + s.recvLim = nil + } + s.rateConfig = rc for _, cs := range s.clients { cs.ForeachClient(func(c *sclient) { - c.setRateLimit(s.perClientRecvBytesPerSec, s.perClientRecvBurst) + c.setRateLimit(rc.PerClientRateLimitBytesPerSec, rc.PerClientRateBurstBytes, s.recvLim) }) } + return rc } // HasMeshKey reports whether the server is configured with a mesh key. @@ -741,7 +758,7 @@ func (s *Server) registerClient(c *sclient) { s.mu.Lock() defer s.mu.Unlock() - c.setRateLimit(s.perClientRecvBytesPerSec, s.perClientRecvBurst) + c.setRateLimit(s.rateConfig.PerClientRateLimitBytesPerSec, s.rateConfig.PerClientRateBurstBytes, s.recvLim) cs, ok := s.clients[c.key] if !ok { @@ -1100,7 +1117,7 @@ func (c *sclient) run(ctx context.Context) error { } return fmt.Errorf("client %s: readFrameHeader: %w", c.key.ShortString(), err) } - // Rate limit by DERP frame length (fl), which excludes TLS protocol and + // Rate-limit by DERP frame length (fl), which excludes TLS protocol and // DERP frame length field overheads. // Note: meshed clients are exempt from rate limits. if err := c.rateLimit(int(fl)); err != nil { @@ -1316,22 +1333,18 @@ func (c *sclient) handleFrameSendPacket(_ derp.FrameType, fl uint32) error { return c.sendPkt(dst, p) } -// setRateLimit updates the per-client receive rate limiter. -// When bytesPerSec is 0 or the client is a mesh peer, the limiter is -// set to nil so that [sclient.rateLimit] is a no-op. -func (c *sclient) setRateLimit(bytesPerSec uint64, burst uint64) { +// setRateLimit updates the receive rate limiter. When bytesPerSec is 0 or the +// client is a mesh peer, the limiter is set to nil so that [sclient.rateLimit] is a no-op. +func (c *sclient) setRateLimit(bytesPerSec, burst uint64, parent *xrate.Limiter) { if bytesPerSec == 0 || c.canMesh { c.recvLim.Store(nil) return } - if lim := c.recvLim.Load(); lim != nil { - // Update in place. SetBurst before SetLimit to avoid a transient - // state where a new higher rate exceeds the old lower burst. - lim.SetBurst(int(burst)) - lim.SetLimit(xrate.Limit(bytesPerSec)) - return + child := xrate.NewLimiter(xrate.Limit(bytesPerSec), int(burst)) + lim := &parentChildTokenBuckets{ + parent: parent, + child: child, } - lim := xrate.NewLimiter(xrate.Limit(bytesPerSec), int(burst)) c.recvLim.Store(lim) } @@ -1350,12 +1363,18 @@ func (c *sclient) rateLimit(n int) error { // [minRateLimitTokenBucketSize]. // // While we could call WaitN multiple times and/or more precisely for - // lim.Burst(), it's better to return early as a larger DERP frame is: - // 1. unexpected - // 2. only partially read off the socket (bufio) + // lim.Burst(), it's better to return early as a larger DERP frame: + // 1. is unexpected + // 2. is only partially read off the socket (bufio) // 3. would cause the connection to close shortly after rate limiting, anyway. clampedN := min(n, minRateLimitTokenBucketSize) - return lim.WaitN(c.ctx, clampedN) + err := lim.child.WaitN(c.ctx, clampedN) + if err != nil { + return err + } + if lim.parent != nil { + return lim.parent.WaitN(c.ctx, clampedN) + } } return nil } @@ -1792,15 +1811,22 @@ type sclient struct { // through us with a peer we have no record of. peerGoneLim *rate.Limiter - // recvLim is the per-connection receive rate limiter. When rate - // limiting is enabled for a non-mesh client, it points to an - // [xrate.Limiter]. When rate limiting is disabled or the client is a - // mesh peer, it is nil and [sclient.rateLimit] is a no-op. - // Updated atomically by [sclient.setRateLimitLocked] so that - // [sclient.rateLimit] can load it without holding Server.mu. - // TODO(mikeodr): update to use mono time, requires updates - // to tstime/rate.Limiter - recvLim atomic.Pointer[xrate.Limiter] + // recvLim is the receive rate limiter. When rate limiting is enabled for a + // non-mesh client, it points to a [parentChildTokenBuckets]. When rate limiting + // is disabled or the client is a mesh peer, it is nil and [sclient.rateLimit] + // is a no-op. Updated atomically by [sclient.setRateLimit] so that + // [sclient.rateLimit] can load it without holding [Server.mu]. + recvLim atomic.Pointer[parentChildTokenBuckets] +} + +// parentChildTokenBuckets contains a parent and child token bucket for the +// purpose of applying in a hierarchical topology. +// +// TODO: consider porting the required APIs from [xrate.Limiter] to [rate.Limiter], +// which is already optimized to use [mono.Time]. +type parentChildTokenBuckets struct { + parent *xrate.Limiter // parent may be nil + child *xrate.Limiter // child is always non-nil } func (c *sclient) presentFlags() derp.PeerPresentFlags { diff --git a/derp/derpserver/derpserver_test.go b/derp/derpserver/derpserver_test.go index d7281bee3..9bd631f3b 100644 --- a/derp/derpserver/derpserver_test.go +++ b/derp/derpserver/derpserver_test.go @@ -965,11 +965,28 @@ func TestPerClientRateLimit(t *testing.T) { c := &sclient{ ctx: ctx, } - c.recvLim.Store(rate.NewLimiter(rate.Limit(minRateLimitTokenBucketSize), minRateLimitTokenBucketSize)) + lim := &parentChildTokenBuckets{ + // Set parent limit to half of child to enable verification of + // rate limiting across both layers with a single sclient. + parent: rate.NewLimiter(rate.Limit(minRateLimitTokenBucketSize)/2, minRateLimitTokenBucketSize), + child: rate.NewLimiter(rate.Limit(minRateLimitTokenBucketSize), minRateLimitTokenBucketSize), + } + c.recvLim.Store(lim) + wantTokens := func(t *testing.T, wantParentTokens, wantChildTokens float64) { + t.Helper() + if lim.parent.Tokens() != wantParentTokens { + t.Fatalf("want parent tokens: %v got: %v", wantParentTokens, lim.parent.Tokens()) + } + if lim.child.Tokens() != wantChildTokens { + t.Fatalf("want child tokens: %v got: %v", wantChildTokens, lim.child.Tokens()) + } + } // First call within burst should not block. c.rateLimit(minRateLimitTokenBucketSize) + wantTokens(t, 0, 0) + // Next call exceeds burst, should block until tokens replenish. done := make(chan error, 1) go func() { @@ -984,7 +1001,21 @@ func TestPerClientRateLimit(t *testing.T) { default: } - // Advance time by 1 second + // Advance time by 1 second, the goroutine should still be blocked + // on the parent bucket (negative tokens). + time.Sleep(1 * time.Second) + synctest.Wait() + select { + case err := <-done: + t.Fatalf("rateLimit should have blocked, but returned: %v", err) + default: + } + + // Verify the parent bucket fills at half the rate of the child. + wantTokens(t, -(minRateLimitTokenBucketSize / 2), 0) + + // Advance time by another second, parent should have enough tokens + // to unblock. time.Sleep(1 * time.Second) synctest.Wait() @@ -996,6 +1027,8 @@ func TestPerClientRateLimit(t *testing.T) { default: t.Fatal("rateLimit should have unblocked after 1s") } + + wantTokens(t, 0, minRateLimitTokenBucketSize) }) }) @@ -1006,7 +1039,11 @@ func TestPerClientRateLimit(t *testing.T) { c := &sclient{ ctx: ctx, } - c.recvLim.Store(rate.NewLimiter(rate.Limit(minRateLimitTokenBucketSize), minRateLimitTokenBucketSize)) + lim := &parentChildTokenBuckets{ + child: rate.NewLimiter(rate.Limit(minRateLimitTokenBucketSize), minRateLimitTokenBucketSize), + parent: rate.NewLimiter(rate.Limit(minRateLimitTokenBucketSize), minRateLimitTokenBucketSize), + } + c.recvLim.Store(lim) // Exhaust burst. if err := c.rateLimit(minRateLimitTokenBucketSize); err != nil { @@ -1052,18 +1089,38 @@ func TestPerClientRateLimit(t *testing.T) { t.Run("zero_config_no_limiter", func(t *testing.T) { s := New(key.NewNode(), logger.Discard) defer s.Close() - if s.perClientRecvBytesPerSec != 0 { - t.Errorf("expected zero rate limit, got %d", s.perClientRecvBytesPerSec) + if !reflect.DeepEqual(s.rateConfig, RateConfig{}) { + t.Errorf("expected zero rate limit, got %+v", s.rateConfig) } }) } -func TestUpdatePerClientRateLimit(t *testing.T) { +func verifyLimiter(t *testing.T, lim *parentChildTokenBuckets, wantRateConfig RateConfig) { + t.Helper() + if got := lim.child.Limit(); got != rate.Limit(wantRateConfig.PerClientRateLimitBytesPerSec) { + t.Errorf("client rate limit = %v; want %d", got, wantRateConfig.PerClientRateLimitBytesPerSec) + } + if got := lim.child.Burst(); got != int(wantRateConfig.PerClientRateBurstBytes) { + t.Errorf("client burst = %v; want %d", got, wantRateConfig.PerClientRateBurstBytes) + } + if got := lim.parent.Limit(); got != rate.Limit(wantRateConfig.GlobalRateLimitBytesPerSec) { + t.Errorf("global rate limit = %v, want %d", got, wantRateConfig.GlobalRateLimitBytesPerSec) + } + if got := lim.parent.Burst(); got != int(wantRateConfig.GlobalRateBurstBytes) { + t.Errorf("global burst = %v, want %d", got, wantRateConfig.GlobalRateBurstBytes) + } +} + +func TestUpdateRateLimits(t *testing.T) { const ( - testBurst1 = derp.MaxPacketSize * 2 - testRate1 = 1000 - testBurst2 = derp.MaxPacketSize * 4 - testRate2 = 5000 + testClientBurst1 = minRateLimitTokenBucketSize + 1 + testClientRate1 = minRateLimitTokenBucketSize + 2 + testClientBurst2 = minRateLimitTokenBucketSize + 3 + testClientRate2 = minRateLimitTokenBucketSize + 4 + testGlobalBurst1 = minRateLimitTokenBucketSize + 5 + testGlobalRate1 = minRateLimitTokenBucketSize + 6 + testGlobalBurst2 = minRateLimitTokenBucketSize + 7 + testGlobalRate2 = minRateLimitTokenBucketSize + 8 ) s := New(key.NewNode(), t.Logf) @@ -1084,52 +1141,46 @@ func TestUpdatePerClientRateLimit(t *testing.T) { s.clients[clientKey] = cs s.mu.Unlock() - s.UpdatePerClientRateLimit(testRate1, testBurst1) + rc := RateConfig{ + PerClientRateLimitBytesPerSec: testClientRate1, + PerClientRateBurstBytes: testClientBurst1, + GlobalRateLimitBytesPerSec: testGlobalRate1, + GlobalRateBurstBytes: testGlobalBurst1, + } + s.UpdateRateLimits(rc) lim := c.recvLim.Load() if lim == nil { t.Fatal("expected non-nil limiter after update") } - if got := lim.Limit(); got != rate.Limit(testRate1) { - t.Errorf("rate limit = %v; want %d", got, testRate1) - } - if got := lim.Burst(); got != int(testBurst1) { - t.Errorf("burst = %v; want %d", got, testBurst1) - } + verifyLimiter(t, lim, rc) // Verify server fields updated. s.mu.Lock() - if s.perClientRecvBytesPerSec != testRate1 { - t.Errorf("server rate = %d; want %d", s.perClientRecvBytesPerSec, testRate1) - } - if s.perClientRecvBurst != testBurst1 { - t.Errorf("server burst = %d; want %d", s.perClientRecvBurst, testBurst1) + if !reflect.DeepEqual(s.rateConfig, rc) { + t.Errorf("s.rateConfig = %+v; want %+v", s.rateConfig, rc) } s.mu.Unlock() - // Update again with different nonzero values. This exercises the - // in-place update path (existing limiter is reused, not recreated). - prevLim := c.recvLim.Load() - s.UpdatePerClientRateLimit(testRate2, testBurst2) + // Update again with different nonzero values. + rc = RateConfig{ + PerClientRateLimitBytesPerSec: testClientRate2, + PerClientRateBurstBytes: testClientBurst2, + GlobalRateLimitBytesPerSec: testGlobalRate2, + GlobalRateBurstBytes: testGlobalBurst2, + } + s.UpdateRateLimits(rc) lim = c.recvLim.Load() if lim == nil { - t.Fatal("expected non-nil limiter after in-place update") - } - if lim != prevLim { - t.Error("expected same limiter pointer after in-place update") - } - if got := lim.Limit(); got != rate.Limit(testRate2) { - t.Errorf("rate limit after in-place update = %v; want %d", got, testRate2) - } - if got := lim.Burst(); got != int(testBurst2) { - t.Errorf("burst after in-place update = %v; want %d", got, testBurst2) + t.Fatal("expected non-nil limiter") } + verifyLimiter(t, lim, rc) // Disable rate limiting (set to 0). - s.UpdatePerClientRateLimit(0, 0) + s.UpdateRateLimits(RateConfig{}) if got := c.recvLim.Load(); got != nil { - t.Errorf("expected nil limiter after disable, got limit=%v", got.Limit()) + t.Errorf("expected nil limiter after disable, got limit=%v", got.child.Limit()) } // Mesh peer should always have nil limiter regardless of update. @@ -1147,22 +1198,23 @@ func TestUpdatePerClientRateLimit(t *testing.T) { s.clients[meshKey] = meshCS s.mu.Unlock() - s.UpdatePerClientRateLimit(testRate2, testBurst2) + rc = RateConfig{ + PerClientRateLimitBytesPerSec: testClientRate2, + PerClientRateBurstBytes: testClientBurst2, + GlobalRateLimitBytesPerSec: testGlobalRate2, + GlobalRateBurstBytes: testGlobalBurst2, + } + s.UpdateRateLimits(rc) if got := meshClient.recvLim.Load(); got != nil { - t.Errorf("mesh peer should have nil limiter, got limit=%v", got.Limit()) + t.Errorf("mesh peer should have nil limiter, got limit=%v", got.child.Limit()) } // Non-mesh client should be updated. lim = c.recvLim.Load() if lim == nil { t.Fatal("expected non-nil limiter for non-mesh client") } - if got := lim.Limit(); got != rate.Limit(testRate2) { - t.Errorf("rate limit = %v; want %d", got, testRate2) - } - if got := lim.Burst(); got != int(testBurst2) { - t.Errorf("burst = %v; want %d", got, testBurst2) - } + verifyLimiter(t, lim, rc) // Verify dup clients are also updated. dupKey := key.NewNode().Public() @@ -1175,32 +1227,40 @@ func TestUpdatePerClientRateLimit(t *testing.T) { s.clients[dupKey] = dupCS s.mu.Unlock() - s.UpdatePerClientRateLimit(testRate1, testBurst1) + rc = RateConfig{ + GlobalRateLimitBytesPerSec: testGlobalRate1, + GlobalRateBurstBytes: testGlobalBurst1, + PerClientRateLimitBytesPerSec: testClientRate1, + PerClientRateBurstBytes: testClientBurst1, + } + s.UpdateRateLimits(rc) for i, d := range []*sclient{d1, d2} { dl := d.recvLim.Load() if dl == nil { t.Fatalf("dup client %d: expected non-nil limiter", i) } - if got := dl.Limit(); got != rate.Limit(testRate1) { - t.Errorf("dup client %d: rate = %v; want %d", i, got, testRate1) - } - if got := dl.Burst(); got != int(testBurst1) { - t.Errorf("dup client %d: burst = %v; want %d", i, got, testBurst1) - } + verifyLimiter(t, dl, rc) } } func TestLoadRateConfig(t *testing.T) { for _, tt := range []struct { - name string - json string - wantRate uint64 - wantBurst uint64 + name string + json string + wantRateConfig RateConfig }{ - {"both_set", `{"PerClientRateLimitBytesPerSec": 1250000, "PerClientRateBurstBytes": 2500000}`, 1250000, 2500000}, - {"rate_only", `{"PerClientRateLimitBytesPerSec": 500000}`, 500000, 0}, - {"zeros", `{"PerClientRateLimitBytesPerSec": 0, "PerClientRateBurstBytes": 0}`, 0, 0}, - {"empty_json", `{}`, 0, 0}, + {"all_set", `{"PerClientRateLimitBytesPerSec": 1, "PerClientRateBurstBytes": 2, "GlobalRateLimitBytesPerSec": 3, "GlobalRateBurstBytes": 4}`, RateConfig{ + PerClientRateLimitBytesPerSec: 1, + PerClientRateBurstBytes: 2, + GlobalRateLimitBytesPerSec: 3, + GlobalRateBurstBytes: 4, + }}, + {"rate_only", `{"PerClientRateLimitBytesPerSec": 1, "GlobalRateLimitBytesPerSec": 3}`, RateConfig{ + PerClientRateLimitBytesPerSec: 1, + GlobalRateLimitBytesPerSec: 3, + }}, + {"zeros", `{"PerClientRateLimitBytesPerSec": 0, "PerClientRateBurstBytes": 0, "GlobalRateLimitBytesPerSec": 0, "GlobalRateBurstBytes": 0}`, RateConfig{}}, + {"empty_json", `{}`, RateConfig{}}, } { t.Run(tt.name, func(t *testing.T) { f := filepath.Join(t.TempDir(), "rate.json") @@ -1209,13 +1269,10 @@ func TestLoadRateConfig(t *testing.T) { } rc, err := LoadRateConfig(f) if err != nil { - t.Fatalf("unexpected error: %v", err) + t.Fatal(err) } - if rc.PerClientRateLimitBytesPerSec != tt.wantRate { - t.Errorf("rate = %d; want %d", rc.PerClientRateLimitBytesPerSec, tt.wantRate) - } - if rc.PerClientRateBurstBytes != tt.wantBurst { - t.Errorf("burst = %d; want %d", rc.PerClientRateBurstBytes, tt.wantBurst) + if !reflect.DeepEqual(rc, tt.wantRateConfig) { + t.Errorf("rate config = %v want %v", rc, tt.wantRateConfig) } }) } @@ -1223,7 +1280,7 @@ func TestLoadRateConfig(t *testing.T) { for _, tt := range []struct { name string path string - content string // written to path if non-empty; path used as-is if empty + content string // written to loaded path if non-empty; path used as-is if empty }{ {"empty_path", "", ""}, {"missing_file", filepath.Join(t.TempDir(), "nonexistent.json"), ""}, @@ -1267,47 +1324,51 @@ func TestLoadAndApplyRateConfig(t *testing.T) { s.clients[clientKey] = cs s.mu.Unlock() - f := writeConfig(t, `{"PerClientRateLimitBytesPerSec": 1250000, "PerClientRateBurstBytes": 2500000}`) + f := writeConfig(t, fmt.Sprintf(`{"PerClientRateLimitBytesPerSec": %d, "PerClientRateBurstBytes": %d, "GlobalRateLimitBytesPerSec": %d, "GlobalRateBurstBytes": %d}`, + minRateLimitTokenBucketSize, minRateLimitTokenBucketSize+1, minRateLimitTokenBucketSize+2, minRateLimitTokenBucketSize+3)) if err := s.LoadAndApplyRateConfig(f); err != nil { t.Fatalf("LoadAndApplyRateConfig: %v", err) } // Verify server fields. + wantRateConfig := RateConfig{ + PerClientRateLimitBytesPerSec: minRateLimitTokenBucketSize, + PerClientRateBurstBytes: minRateLimitTokenBucketSize + 1, + GlobalRateLimitBytesPerSec: minRateLimitTokenBucketSize + 2, + GlobalRateBurstBytes: minRateLimitTokenBucketSize + 3, + } s.mu.Lock() - gotRate := s.perClientRecvBytesPerSec - gotBurst := s.perClientRecvBurst + if !reflect.DeepEqual(s.rateConfig, wantRateConfig) { + t.Errorf("s.rateConfig = %+v; want %+v", s.rateConfig, wantRateConfig) + } s.mu.Unlock() - if gotRate != 1250000 { - t.Errorf("server rate = %d; want 1250000", gotRate) - } - if gotBurst != 2500000 { - t.Errorf("server burst = %d; want 2500000", gotBurst) - } // Verify client limiter. lim := c.recvLim.Load() if lim == nil { t.Fatal("expected non-nil limiter") } - if got := lim.Limit(); got != rate.Limit(1250000) { - t.Errorf("client rate = %v; want 1250000", got) - } + verifyLimiter(t, lim, wantRateConfig) }) - t.Run("burst_is_at_least_max_packet_size", func(t *testing.T) { + t.Run("burst_is_at_least_minRateLimitTokenBucketSize", func(t *testing.T) { s := New(key.NewNode(), t.Logf) defer s.Close() - f := writeConfig(t, `{"PerClientRateLimitBytesPerSec": 1250000, "PerClientRateBurstBytes": 10}`) + f := writeConfig(t, `{"PerClientRateLimitBytesPerSec": 1250000, "PerClientRateBurstBytes": 10, "GlobalRateLimitBytesPerSec": 1250000, "GlobalRateBurstBytes": 10}`) if err := s.LoadAndApplyRateConfig(f); err != nil { t.Fatalf("LoadAndApplyRateConfig: %v", err) } s.mu.Lock() - gotBurst := s.perClientRecvBurst + gotClientBurst := s.rateConfig.PerClientRateBurstBytes + gotGlobalBurst := s.rateConfig.GlobalRateBurstBytes s.mu.Unlock() - if gotBurst != minRateLimitTokenBucketSize { - t.Errorf("burst = %d; want at least %d", gotBurst, minRateLimitTokenBucketSize) + if gotClientBurst != minRateLimitTokenBucketSize { + t.Errorf("client burst = %d; want %d", gotClientBurst, minRateLimitTokenBucketSize) + } + if gotGlobalBurst != minRateLimitTokenBucketSize { + t.Errorf("global burst = %d; want %d", gotGlobalBurst, minRateLimitTokenBucketSize) } }) @@ -1315,10 +1376,15 @@ func TestLoadAndApplyRateConfig(t *testing.T) { s := New(key.NewNode(), t.Logf) defer s.Close() - f := writeConfig(t, `{"PerClientRateLimitBytesPerSec": 1250000, "PerClientRateBurstBytes": 2500000}`) + f := writeConfig(t, `{"PerClientRateLimitBytesPerSec": 1250000, "PerClientRateBurstBytes": 2500000, "GlobalRateLimitBytesPerSec": 12500000, "GlobalRateBurstBytes": 25000000}`) if err := s.LoadAndApplyRateConfig(f); err != nil { t.Fatal(err) } + s.mu.Lock() + if reflect.DeepEqual(s.rateConfig, RateConfig{}) { + t.Error("s.rateConfig is zero val; want nonzero rates") + } + s.mu.Unlock() if err := os.WriteFile(f, []byte(`{}`), 0644); err != nil { t.Fatal(err) @@ -1328,11 +1394,10 @@ func TestLoadAndApplyRateConfig(t *testing.T) { } s.mu.Lock() - gotRate := s.perClientRecvBytesPerSec - s.mu.Unlock() - if gotRate != 0 { - t.Errorf("rate = %d; want 0 (unlimited)", gotRate) + if !reflect.DeepEqual(s.rateConfig, RateConfig{}) { + t.Errorf("s.rateConfig = %+v; want %+v", s.rateConfig, RateConfig{}) } + s.mu.Unlock() }) t.Run("propagates_errors", func(t *testing.T) {