Replace our ratelimiter with standard rate package (#359)
* Replace our ratelimiter with standard rate package Signed-off-by: Wendi Yu <wendi.yu@yahoo.ca>main
parent
b01db109f5
commit
499c8fcbb3
@ -1,81 +0,0 @@ |
||||
// Copyright (c) 2020 Tailscale Inc & AUTHORS All rights reserved.
|
||||
// Use of this source code is governed by a BSD-style
|
||||
// license that can be found in the LICENSE file.
|
||||
|
||||
package ratelimit |
||||
|
||||
import ( |
||||
"sync" |
||||
"time" |
||||
|
||||
"tailscale.com/types/structs" |
||||
) |
||||
|
||||
type Bucket struct { |
||||
_ structs.Incomparable |
||||
mu sync.Mutex |
||||
FillInterval time.Duration |
||||
Burst int |
||||
v int |
||||
quitCh chan struct{} |
||||
started bool |
||||
closed bool |
||||
} |
||||
|
||||
func (b *Bucket) startLocked() { |
||||
b.v = b.Burst |
||||
b.quitCh = make(chan struct{}) |
||||
b.started = true |
||||
|
||||
t := time.NewTicker(b.FillInterval) |
||||
go func() { |
||||
for { |
||||
select { |
||||
case <-b.quitCh: |
||||
return |
||||
case <-t.C: |
||||
b.tick() |
||||
} |
||||
} |
||||
}() |
||||
} |
||||
|
||||
func (b *Bucket) tick() { |
||||
b.mu.Lock() |
||||
defer b.mu.Unlock() |
||||
|
||||
if b.v < b.Burst { |
||||
b.v++ |
||||
} |
||||
} |
||||
|
||||
func (b *Bucket) Close() { |
||||
b.mu.Lock() |
||||
if !b.started { |
||||
b.closed = true |
||||
b.mu.Unlock() |
||||
return |
||||
} |
||||
if b.closed { |
||||
b.mu.Unlock() |
||||
return |
||||
} |
||||
b.closed = true |
||||
b.mu.Unlock() |
||||
|
||||
b.quitCh <- struct{}{} |
||||
} |
||||
|
||||
func (b *Bucket) TryGet() int { |
||||
b.mu.Lock() |
||||
defer b.mu.Unlock() |
||||
|
||||
if !b.started { |
||||
b.startLocked() |
||||
} |
||||
if b.v > 0 { |
||||
b.v-- |
||||
return b.v + 1 |
||||
} |
||||
return 0 |
||||
} |
||||
@ -1,28 +0,0 @@ |
||||
// Copyright (c) 2020 Tailscale Inc & AUTHORS All rights reserved.
|
||||
// Use of this source code is governed by a BSD-style
|
||||
// license that can be found in the LICENSE file.
|
||||
|
||||
package ratelimit |
||||
|
||||
import ( |
||||
"testing" |
||||
"time" |
||||
) |
||||
|
||||
func TestBucket(t *testing.T) { |
||||
b := Bucket{ |
||||
FillInterval: time.Second, |
||||
Burst: 3, |
||||
} |
||||
expect := []int{3, 2, 1, 0, 0} |
||||
for i, want := range expect { |
||||
got := b.TryGet() |
||||
if want != got { |
||||
t.Errorf("#%d want=%d got=%d\n", i, want, got) |
||||
} |
||||
} |
||||
b.tick() |
||||
if want, got := 1, b.TryGet(); want != got { |
||||
t.Errorf("after tick: want=%d got=%d\n", want, got) |
||||
} |
||||
} |
||||
Loading…
Reference in new issue