all: rename variables with lowercase-l/uppercase-I
See http://go/no-ell Signed-off-by: Alex Chan <alexc@tailscale.com> Updates #cleanup Change-Id: I8c976b51ce7a60f06315048b1920516129cc1d5d
This commit is contained in:
+34
-34
@@ -94,59 +94,59 @@ type bucket struct {
|
||||
|
||||
// Allow charges the key one token (up to the overdraft limit), and
|
||||
// reports whether the key can perform an action.
|
||||
func (l *Limiter[K]) Allow(key K) bool {
|
||||
return l.allow(key, time.Now())
|
||||
func (lm *Limiter[K]) Allow(key K) bool {
|
||||
return lm.allow(key, time.Now())
|
||||
}
|
||||
|
||||
func (l *Limiter[K]) allow(key K, now time.Time) bool {
|
||||
l.mu.Lock()
|
||||
defer l.mu.Unlock()
|
||||
return l.allowBucketLocked(l.getBucketLocked(key, now), now)
|
||||
func (lm *Limiter[K]) allow(key K, now time.Time) bool {
|
||||
lm.mu.Lock()
|
||||
defer lm.mu.Unlock()
|
||||
return lm.allowBucketLocked(lm.getBucketLocked(key, now), now)
|
||||
}
|
||||
|
||||
func (l *Limiter[K]) getBucketLocked(key K, now time.Time) *bucket {
|
||||
if l.cache == nil {
|
||||
l.cache = &lru.Cache[K, *bucket]{MaxEntries: l.Size}
|
||||
} else if b := l.cache.Get(key); b != nil {
|
||||
func (lm *Limiter[K]) getBucketLocked(key K, now time.Time) *bucket {
|
||||
if lm.cache == nil {
|
||||
lm.cache = &lru.Cache[K, *bucket]{MaxEntries: lm.Size}
|
||||
} else if b := lm.cache.Get(key); b != nil {
|
||||
return b
|
||||
}
|
||||
b := &bucket{
|
||||
cur: l.Max,
|
||||
lastUpdate: now.Truncate(l.RefillInterval),
|
||||
cur: lm.Max,
|
||||
lastUpdate: now.Truncate(lm.RefillInterval),
|
||||
}
|
||||
l.cache.Set(key, b)
|
||||
lm.cache.Set(key, b)
|
||||
return b
|
||||
}
|
||||
|
||||
func (l *Limiter[K]) allowBucketLocked(b *bucket, now time.Time) bool {
|
||||
func (lm *Limiter[K]) allowBucketLocked(b *bucket, now time.Time) bool {
|
||||
// Only update the bucket quota if needed to process request.
|
||||
if b.cur <= 0 {
|
||||
l.updateBucketLocked(b, now)
|
||||
lm.updateBucketLocked(b, now)
|
||||
}
|
||||
ret := b.cur > 0
|
||||
if b.cur > -l.Overdraft {
|
||||
if b.cur > -lm.Overdraft {
|
||||
b.cur--
|
||||
}
|
||||
return ret
|
||||
}
|
||||
|
||||
func (l *Limiter[K]) updateBucketLocked(b *bucket, now time.Time) {
|
||||
now = now.Truncate(l.RefillInterval)
|
||||
func (lm *Limiter[K]) updateBucketLocked(b *bucket, now time.Time) {
|
||||
now = now.Truncate(lm.RefillInterval)
|
||||
if now.Before(b.lastUpdate) {
|
||||
return
|
||||
}
|
||||
timeDelta := max(now.Sub(b.lastUpdate), 0)
|
||||
tokenDelta := int64(timeDelta / l.RefillInterval)
|
||||
b.cur = min(b.cur+tokenDelta, l.Max)
|
||||
tokenDelta := int64(timeDelta / lm.RefillInterval)
|
||||
b.cur = min(b.cur+tokenDelta, lm.Max)
|
||||
b.lastUpdate = now
|
||||
}
|
||||
|
||||
// peekForTest returns the number of tokens for key, also reporting
|
||||
// whether key was present.
|
||||
func (l *Limiter[K]) tokensForTest(key K) (int64, bool) {
|
||||
l.mu.Lock()
|
||||
defer l.mu.Unlock()
|
||||
if b, ok := l.cache.PeekOk(key); ok {
|
||||
func (lm *Limiter[K]) tokensForTest(key K) (int64, bool) {
|
||||
lm.mu.Lock()
|
||||
defer lm.mu.Unlock()
|
||||
if b, ok := lm.cache.PeekOk(key); ok {
|
||||
return b.cur, true
|
||||
}
|
||||
return 0, false
|
||||
@@ -159,12 +159,12 @@ func (l *Limiter[K]) tokensForTest(key K) (int64, bool) {
|
||||
// DumpHTML blocks other callers of the limiter while it collects the
|
||||
// state for dumping. It should not be called on large limiters
|
||||
// involved in hot codepaths.
|
||||
func (l *Limiter[K]) DumpHTML(w io.Writer, onlyLimited bool) {
|
||||
l.dumpHTML(w, onlyLimited, time.Now())
|
||||
func (lm *Limiter[K]) DumpHTML(w io.Writer, onlyLimited bool) {
|
||||
lm.dumpHTML(w, onlyLimited, time.Now())
|
||||
}
|
||||
|
||||
func (l *Limiter[K]) dumpHTML(w io.Writer, onlyLimited bool, now time.Time) {
|
||||
dump := l.collectDump(now)
|
||||
func (lm *Limiter[K]) dumpHTML(w io.Writer, onlyLimited bool, now time.Time) {
|
||||
dump := lm.collectDump(now)
|
||||
io.WriteString(w, "<table><tr><th>Key</th><th>Tokens</th></tr>")
|
||||
for _, line := range dump {
|
||||
if onlyLimited && line.Tokens > 0 {
|
||||
@@ -183,13 +183,13 @@ func (l *Limiter[K]) dumpHTML(w io.Writer, onlyLimited bool, now time.Time) {
|
||||
}
|
||||
|
||||
// collectDump grabs a copy of the limiter state needed by DumpHTML.
|
||||
func (l *Limiter[K]) collectDump(now time.Time) []dumpEntry[K] {
|
||||
l.mu.Lock()
|
||||
defer l.mu.Unlock()
|
||||
func (lm *Limiter[K]) collectDump(now time.Time) []dumpEntry[K] {
|
||||
lm.mu.Lock()
|
||||
defer lm.mu.Unlock()
|
||||
|
||||
ret := make([]dumpEntry[K], 0, l.cache.Len())
|
||||
l.cache.ForEach(func(k K, v *bucket) {
|
||||
l.updateBucketLocked(v, now) // so stats are accurate
|
||||
ret := make([]dumpEntry[K], 0, lm.cache.Len())
|
||||
lm.cache.ForEach(func(k K, v *bucket) {
|
||||
lm.updateBucketLocked(v, now) // so stats are accurate
|
||||
ret = append(ret, dumpEntry[K]{k, v.cur})
|
||||
})
|
||||
return ret
|
||||
|
||||
@@ -16,7 +16,7 @@ const testRefillInterval = time.Second
|
||||
|
||||
func TestLimiter(t *testing.T) {
|
||||
// 1qps, burst of 10, 2 keys tracked
|
||||
l := &Limiter[string]{
|
||||
limiter := &Limiter[string]{
|
||||
Size: 2,
|
||||
Max: 10,
|
||||
RefillInterval: testRefillInterval,
|
||||
@@ -24,48 +24,48 @@ func TestLimiter(t *testing.T) {
|
||||
|
||||
// Consume entire burst
|
||||
now := time.Now().Truncate(testRefillInterval)
|
||||
allowed(t, l, "foo", 10, now)
|
||||
denied(t, l, "foo", 1, now)
|
||||
hasTokens(t, l, "foo", 0)
|
||||
allowed(t, limiter, "foo", 10, now)
|
||||
denied(t, limiter, "foo", 1, now)
|
||||
hasTokens(t, limiter, "foo", 0)
|
||||
|
||||
allowed(t, l, "bar", 10, now)
|
||||
denied(t, l, "bar", 1, now)
|
||||
hasTokens(t, l, "bar", 0)
|
||||
allowed(t, limiter, "bar", 10, now)
|
||||
denied(t, limiter, "bar", 1, now)
|
||||
hasTokens(t, limiter, "bar", 0)
|
||||
|
||||
// Refill 1 token for both foo and bar
|
||||
now = now.Add(time.Second + time.Millisecond)
|
||||
allowed(t, l, "foo", 1, now)
|
||||
denied(t, l, "foo", 1, now)
|
||||
hasTokens(t, l, "foo", 0)
|
||||
allowed(t, limiter, "foo", 1, now)
|
||||
denied(t, limiter, "foo", 1, now)
|
||||
hasTokens(t, limiter, "foo", 0)
|
||||
|
||||
allowed(t, l, "bar", 1, now)
|
||||
denied(t, l, "bar", 1, now)
|
||||
hasTokens(t, l, "bar", 0)
|
||||
allowed(t, limiter, "bar", 1, now)
|
||||
denied(t, limiter, "bar", 1, now)
|
||||
hasTokens(t, limiter, "bar", 0)
|
||||
|
||||
// Refill 2 tokens for foo and bar
|
||||
now = now.Add(2*time.Second + time.Millisecond)
|
||||
allowed(t, l, "foo", 2, now)
|
||||
denied(t, l, "foo", 1, now)
|
||||
hasTokens(t, l, "foo", 0)
|
||||
allowed(t, limiter, "foo", 2, now)
|
||||
denied(t, limiter, "foo", 1, now)
|
||||
hasTokens(t, limiter, "foo", 0)
|
||||
|
||||
allowed(t, l, "bar", 2, now)
|
||||
denied(t, l, "bar", 1, now)
|
||||
hasTokens(t, l, "bar", 0)
|
||||
allowed(t, limiter, "bar", 2, now)
|
||||
denied(t, limiter, "bar", 1, now)
|
||||
hasTokens(t, limiter, "bar", 0)
|
||||
|
||||
// qux can burst 10, evicts foo so it can immediately burst 10 again too
|
||||
allowed(t, l, "qux", 10, now)
|
||||
denied(t, l, "qux", 1, now)
|
||||
notInLimiter(t, l, "foo")
|
||||
denied(t, l, "bar", 1, now) // refresh bar so foo lookup doesn't evict it - still throttled
|
||||
allowed(t, limiter, "qux", 10, now)
|
||||
denied(t, limiter, "qux", 1, now)
|
||||
notInLimiter(t, limiter, "foo")
|
||||
denied(t, limiter, "bar", 1, now) // refresh bar so foo lookup doesn't evict it - still throttled
|
||||
|
||||
allowed(t, l, "foo", 10, now)
|
||||
denied(t, l, "foo", 1, now)
|
||||
hasTokens(t, l, "foo", 0)
|
||||
allowed(t, limiter, "foo", 10, now)
|
||||
denied(t, limiter, "foo", 1, now)
|
||||
hasTokens(t, limiter, "foo", 0)
|
||||
}
|
||||
|
||||
func TestLimiterOverdraft(t *testing.T) {
|
||||
// 1qps, burst of 10, overdraft of 2, 2 keys tracked
|
||||
l := &Limiter[string]{
|
||||
limiter := &Limiter[string]{
|
||||
Size: 2,
|
||||
Max: 10,
|
||||
Overdraft: 2,
|
||||
@@ -74,51 +74,51 @@ func TestLimiterOverdraft(t *testing.T) {
|
||||
|
||||
// Consume entire burst, go 1 into debt
|
||||
now := time.Now().Truncate(testRefillInterval).Add(time.Millisecond)
|
||||
allowed(t, l, "foo", 10, now)
|
||||
denied(t, l, "foo", 1, now)
|
||||
hasTokens(t, l, "foo", -1)
|
||||
allowed(t, limiter, "foo", 10, now)
|
||||
denied(t, limiter, "foo", 1, now)
|
||||
hasTokens(t, limiter, "foo", -1)
|
||||
|
||||
allowed(t, l, "bar", 10, now)
|
||||
denied(t, l, "bar", 1, now)
|
||||
hasTokens(t, l, "bar", -1)
|
||||
allowed(t, limiter, "bar", 10, now)
|
||||
denied(t, limiter, "bar", 1, now)
|
||||
hasTokens(t, limiter, "bar", -1)
|
||||
|
||||
// Refill 1 token for both foo and bar.
|
||||
// Still denied, still in debt.
|
||||
now = now.Add(time.Second)
|
||||
denied(t, l, "foo", 1, now)
|
||||
hasTokens(t, l, "foo", -1)
|
||||
denied(t, l, "bar", 1, now)
|
||||
hasTokens(t, l, "bar", -1)
|
||||
denied(t, limiter, "foo", 1, now)
|
||||
hasTokens(t, limiter, "foo", -1)
|
||||
denied(t, limiter, "bar", 1, now)
|
||||
hasTokens(t, limiter, "bar", -1)
|
||||
|
||||
// Refill 2 tokens for foo and bar (1 available after debt), try
|
||||
// to consume 4. Overdraft is capped to 2.
|
||||
now = now.Add(2 * time.Second)
|
||||
allowed(t, l, "foo", 1, now)
|
||||
denied(t, l, "foo", 3, now)
|
||||
hasTokens(t, l, "foo", -2)
|
||||
allowed(t, limiter, "foo", 1, now)
|
||||
denied(t, limiter, "foo", 3, now)
|
||||
hasTokens(t, limiter, "foo", -2)
|
||||
|
||||
allowed(t, l, "bar", 1, now)
|
||||
denied(t, l, "bar", 3, now)
|
||||
hasTokens(t, l, "bar", -2)
|
||||
allowed(t, limiter, "bar", 1, now)
|
||||
denied(t, limiter, "bar", 3, now)
|
||||
hasTokens(t, limiter, "bar", -2)
|
||||
|
||||
// Refill 1, not enough to allow.
|
||||
now = now.Add(time.Second)
|
||||
denied(t, l, "foo", 1, now)
|
||||
hasTokens(t, l, "foo", -2)
|
||||
denied(t, l, "bar", 1, now)
|
||||
hasTokens(t, l, "bar", -2)
|
||||
denied(t, limiter, "foo", 1, now)
|
||||
hasTokens(t, limiter, "foo", -2)
|
||||
denied(t, limiter, "bar", 1, now)
|
||||
hasTokens(t, limiter, "bar", -2)
|
||||
|
||||
// qux evicts foo, foo can immediately burst 10 again.
|
||||
allowed(t, l, "qux", 1, now)
|
||||
hasTokens(t, l, "qux", 9)
|
||||
notInLimiter(t, l, "foo")
|
||||
allowed(t, l, "foo", 10, now)
|
||||
denied(t, l, "foo", 1, now)
|
||||
hasTokens(t, l, "foo", -1)
|
||||
allowed(t, limiter, "qux", 1, now)
|
||||
hasTokens(t, limiter, "qux", 9)
|
||||
notInLimiter(t, limiter, "foo")
|
||||
allowed(t, limiter, "foo", 10, now)
|
||||
denied(t, limiter, "foo", 1, now)
|
||||
hasTokens(t, limiter, "foo", -1)
|
||||
}
|
||||
|
||||
func TestDumpHTML(t *testing.T) {
|
||||
l := &Limiter[string]{
|
||||
limiter := &Limiter[string]{
|
||||
Size: 3,
|
||||
Max: 10,
|
||||
Overdraft: 10,
|
||||
@@ -126,13 +126,13 @@ func TestDumpHTML(t *testing.T) {
|
||||
}
|
||||
|
||||
now := time.Now().Truncate(testRefillInterval).Add(time.Millisecond)
|
||||
allowed(t, l, "foo", 10, now)
|
||||
denied(t, l, "foo", 2, now)
|
||||
allowed(t, l, "bar", 4, now)
|
||||
allowed(t, l, "qux", 1, now)
|
||||
allowed(t, limiter, "foo", 10, now)
|
||||
denied(t, limiter, "foo", 2, now)
|
||||
allowed(t, limiter, "bar", 4, now)
|
||||
allowed(t, limiter, "qux", 1, now)
|
||||
|
||||
var out bytes.Buffer
|
||||
l.DumpHTML(&out, false)
|
||||
limiter.DumpHTML(&out, false)
|
||||
want := strings.Join([]string{
|
||||
"<table>",
|
||||
"<tr><th>Key</th><th>Tokens</th></tr>",
|
||||
@@ -146,7 +146,7 @@ func TestDumpHTML(t *testing.T) {
|
||||
}
|
||||
|
||||
out.Reset()
|
||||
l.DumpHTML(&out, true)
|
||||
limiter.DumpHTML(&out, true)
|
||||
want = strings.Join([]string{
|
||||
"<table>",
|
||||
"<tr><th>Key</th><th>Tokens</th></tr>",
|
||||
@@ -161,7 +161,7 @@ func TestDumpHTML(t *testing.T) {
|
||||
// organically.
|
||||
now = now.Add(3 * time.Second)
|
||||
out.Reset()
|
||||
l.dumpHTML(&out, false, now)
|
||||
limiter.dumpHTML(&out, false, now)
|
||||
want = strings.Join([]string{
|
||||
"<table>",
|
||||
"<tr><th>Key</th><th>Tokens</th></tr>",
|
||||
@@ -175,29 +175,29 @@ func TestDumpHTML(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
func allowed(t *testing.T, l *Limiter[string], key string, count int, now time.Time) {
|
||||
func allowed(t *testing.T, limiter *Limiter[string], key string, count int, now time.Time) {
|
||||
t.Helper()
|
||||
for i := range count {
|
||||
if !l.allow(key, now) {
|
||||
toks, ok := l.tokensForTest(key)
|
||||
if !limiter.allow(key, now) {
|
||||
toks, ok := limiter.tokensForTest(key)
|
||||
t.Errorf("after %d times: allow(%q, %q) = false, want true (%d tokens available, in cache = %v)", i, key, now, toks, ok)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func denied(t *testing.T, l *Limiter[string], key string, count int, now time.Time) {
|
||||
func denied(t *testing.T, limiter *Limiter[string], key string, count int, now time.Time) {
|
||||
t.Helper()
|
||||
for i := range count {
|
||||
if l.allow(key, now) {
|
||||
toks, ok := l.tokensForTest(key)
|
||||
if limiter.allow(key, now) {
|
||||
toks, ok := limiter.tokensForTest(key)
|
||||
t.Errorf("after %d times: allow(%q, %q) = true, want false (%d tokens available, in cache = %v)", i, key, now, toks, ok)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func hasTokens(t *testing.T, l *Limiter[string], key string, want int64) {
|
||||
func hasTokens(t *testing.T, limiter *Limiter[string], key string, want int64) {
|
||||
t.Helper()
|
||||
got, ok := l.tokensForTest(key)
|
||||
got, ok := limiter.tokensForTest(key)
|
||||
if !ok {
|
||||
t.Errorf("key %q missing from limiter", key)
|
||||
} else if got != want {
|
||||
@@ -205,9 +205,9 @@ func hasTokens(t *testing.T, l *Limiter[string], key string, want int64) {
|
||||
}
|
||||
}
|
||||
|
||||
func notInLimiter(t *testing.T, l *Limiter[string], key string) {
|
||||
func notInLimiter(t *testing.T, limiter *Limiter[string], key string) {
|
||||
t.Helper()
|
||||
if tokens, ok := l.tokensForTest(key); ok {
|
||||
if tokens, ok := limiter.tokensForTest(key); ok {
|
||||
t.Errorf("key %q unexpectedly tracked by limiter, with %d tokens", key, tokens)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -85,7 +85,7 @@ type tableDetector interface {
|
||||
type linuxFWDetector struct{}
|
||||
|
||||
// iptDetect returns the number of iptables rules in the current namespace.
|
||||
func (l linuxFWDetector) iptDetect() (int, error) {
|
||||
func (ld linuxFWDetector) iptDetect() (int, error) {
|
||||
return detectIptables()
|
||||
}
|
||||
|
||||
@@ -96,7 +96,7 @@ var hookDetectNetfilter feature.Hook[func() (int, error)]
|
||||
var ErrUnsupported = errors.New("linuxfw:unsupported")
|
||||
|
||||
// nftDetect returns the number of nftables rules in the current namespace.
|
||||
func (l linuxFWDetector) nftDetect() (int, error) {
|
||||
func (ld linuxFWDetector) nftDetect() (int, error) {
|
||||
if f, ok := hookDetectNetfilter.GetOk(); ok {
|
||||
return f()
|
||||
}
|
||||
|
||||
@@ -84,8 +84,8 @@ func TestStressEvictions(t *testing.T) {
|
||||
for range numProbes {
|
||||
v := vals[rand.Intn(len(vals))]
|
||||
c.Set(v, true)
|
||||
if l := c.Len(); l > cacheSize {
|
||||
t.Fatalf("Cache size now %d, want max %d", l, cacheSize)
|
||||
if ln := c.Len(); ln > cacheSize {
|
||||
t.Fatalf("Cache size now %d, want max %d", ln, cacheSize)
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -119,8 +119,8 @@ func TestStressBatchedEvictions(t *testing.T) {
|
||||
c.DeleteOldest()
|
||||
}
|
||||
}
|
||||
if l := c.Len(); l > cacheSizeMax {
|
||||
t.Fatalf("Cache size now %d, want max %d", l, cacheSizeMax)
|
||||
if ln := c.Len(); ln > cacheSizeMax {
|
||||
t.Fatalf("Cache size now %d, want max %d", ln, cacheSizeMax)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -322,33 +322,33 @@ func Definitions() ([]*Definition, error) {
|
||||
type PlatformList []string
|
||||
|
||||
// Has reports whether l contains the target platform.
|
||||
func (l PlatformList) Has(target string) bool {
|
||||
if len(l) == 0 {
|
||||
func (ls PlatformList) Has(target string) bool {
|
||||
if len(ls) == 0 {
|
||||
return true
|
||||
}
|
||||
return slices.ContainsFunc(l, func(os string) bool {
|
||||
return slices.ContainsFunc(ls, func(os string) bool {
|
||||
return strings.EqualFold(os, target)
|
||||
})
|
||||
}
|
||||
|
||||
// HasCurrent is like Has, but for the current platform.
|
||||
func (l PlatformList) HasCurrent() bool {
|
||||
return l.Has(internal.OS())
|
||||
func (ls PlatformList) HasCurrent() bool {
|
||||
return ls.Has(internal.OS())
|
||||
}
|
||||
|
||||
// mergeFrom merges l2 into l. Since an empty list indicates no platform restrictions,
|
||||
// if either l or l2 is empty, the merged result in l will also be empty.
|
||||
func (l *PlatformList) mergeFrom(l2 PlatformList) {
|
||||
func (ls *PlatformList) mergeFrom(l2 PlatformList) {
|
||||
switch {
|
||||
case len(*l) == 0:
|
||||
case len(*ls) == 0:
|
||||
// No-op. An empty list indicates no platform restrictions.
|
||||
case len(l2) == 0:
|
||||
// Merging with an empty list results in an empty list.
|
||||
*l = l2
|
||||
*ls = l2
|
||||
default:
|
||||
// Append, sort and dedup.
|
||||
*l = append(*l, l2...)
|
||||
slices.Sort(*l)
|
||||
*l = slices.Compact(*l)
|
||||
*ls = append(*ls, l2...)
|
||||
slices.Sort(*ls)
|
||||
*ls = slices.Compact(*ls)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -311,8 +311,8 @@ func TestListSettingDefinitions(t *testing.T) {
|
||||
t.Fatalf("SetDefinitionsForTest failed: %v", err)
|
||||
}
|
||||
|
||||
cmp := func(l, r *Definition) int {
|
||||
return strings.Compare(string(l.Key()), string(r.Key()))
|
||||
cmp := func(a, b *Definition) int {
|
||||
return strings.Compare(string(a.Key()), string(b.Key()))
|
||||
}
|
||||
want := append([]*Definition{}, definitions...)
|
||||
slices.SortFunc(want, cmp)
|
||||
|
||||
@@ -182,16 +182,16 @@ func doWithMachinePolicyLocked(t *testing.T, f func()) {
|
||||
f()
|
||||
}
|
||||
|
||||
func doWithCustomEnterLeaveFuncs(t *testing.T, f func(l *PolicyLock), enter func(bool) (policyLockHandle, error), leave func(policyLockHandle) error) {
|
||||
func doWithCustomEnterLeaveFuncs(t *testing.T, f func(*PolicyLock), enter func(bool) (policyLockHandle, error), leave func(policyLockHandle) error) {
|
||||
t.Helper()
|
||||
|
||||
l := NewMachinePolicyLock()
|
||||
l.enterFn, l.leaveFn = enter, leave
|
||||
lock := NewMachinePolicyLock()
|
||||
lock.enterFn, lock.leaveFn = enter, leave
|
||||
t.Cleanup(func() {
|
||||
if err := l.Close(); err != nil {
|
||||
if err := lock.Close(); err != nil {
|
||||
t.Fatalf("(*PolicyLock).Close failed: %v", err)
|
||||
}
|
||||
})
|
||||
|
||||
f(l)
|
||||
f(lock)
|
||||
}
|
||||
|
||||
@@ -127,32 +127,32 @@ func NewUserPolicyLock(token windows.Token) (*PolicyLock, error) {
|
||||
return lock, nil
|
||||
}
|
||||
|
||||
// Lock locks l.
|
||||
// It returns [ErrInvalidLockState] if l has a zero value or has already been closed,
|
||||
// Lock locks lk.
|
||||
// It returns [ErrInvalidLockState] if lk has a zero value or has already been closed,
|
||||
// [ErrLockRestricted] if the lock cannot be acquired due to a restriction in place,
|
||||
// or a [syscall.Errno] if the underlying Group Policy lock cannot be acquired.
|
||||
//
|
||||
// As a special case, it fails with [windows.ERROR_ACCESS_DENIED]
|
||||
// if l is a user policy lock, and the corresponding user is not logged in
|
||||
// if lk is a user policy lock, and the corresponding user is not logged in
|
||||
// interactively at the time of the call.
|
||||
func (l *PolicyLock) Lock() error {
|
||||
func (lk *PolicyLock) Lock() error {
|
||||
if policyLockRestricted.Load() > 0 {
|
||||
return ErrLockRestricted
|
||||
}
|
||||
|
||||
l.mu.Lock()
|
||||
defer l.mu.Unlock()
|
||||
if l.lockCnt.Add(2)&1 == 0 {
|
||||
lk.mu.Lock()
|
||||
defer lk.mu.Unlock()
|
||||
if lk.lockCnt.Add(2)&1 == 0 {
|
||||
// The lock cannot be acquired because it has either never been properly
|
||||
// created or its Close method has already been called. However, we need
|
||||
// to call Unlock to both decrement lockCnt and leave the underlying
|
||||
// CriticalPolicySection if we won the race with another goroutine and
|
||||
// now own the lock.
|
||||
l.Unlock()
|
||||
lk.Unlock()
|
||||
return ErrInvalidLockState
|
||||
}
|
||||
|
||||
if l.handle != 0 {
|
||||
if lk.handle != 0 {
|
||||
// The underlying CriticalPolicySection is already acquired.
|
||||
// It is an R-Lock (with the W-counterpart owned by the Group Policy service),
|
||||
// meaning that it can be acquired by multiple readers simultaneously.
|
||||
@@ -160,20 +160,20 @@ func (l *PolicyLock) Lock() error {
|
||||
return nil
|
||||
}
|
||||
|
||||
return l.lockSlow()
|
||||
return lk.lockSlow()
|
||||
}
|
||||
|
||||
// lockSlow calls enterCriticalPolicySection to acquire the underlying GP read lock.
|
||||
// It waits for either the lock to be acquired, or for the Close method to be called.
|
||||
//
|
||||
// l.mu must be held.
|
||||
func (l *PolicyLock) lockSlow() (err error) {
|
||||
func (lk *PolicyLock) lockSlow() (err error) {
|
||||
defer func() {
|
||||
if err != nil {
|
||||
// Decrement the counter if the lock cannot be acquired,
|
||||
// and complete the pending close request if we're the last owner.
|
||||
if l.lockCnt.Add(-2) == 0 {
|
||||
l.closeInternal()
|
||||
if lk.lockCnt.Add(-2) == 0 {
|
||||
lk.closeInternal()
|
||||
}
|
||||
}
|
||||
}()
|
||||
@@ -190,12 +190,12 @@ func (l *PolicyLock) lockSlow() (err error) {
|
||||
resultCh := make(chan policyLockResult)
|
||||
|
||||
go func() {
|
||||
closing := l.closing
|
||||
if l.scope == UserPolicy && l.token != 0 {
|
||||
closing := lk.closing
|
||||
if lk.scope == UserPolicy && lk.token != 0 {
|
||||
// Impersonate the user whose critical policy section we want to acquire.
|
||||
runtime.LockOSThread()
|
||||
defer runtime.UnlockOSThread()
|
||||
if err := impersonateLoggedOnUser(l.token); err != nil {
|
||||
if err := impersonateLoggedOnUser(lk.token); err != nil {
|
||||
initCh <- err
|
||||
return
|
||||
}
|
||||
@@ -209,10 +209,10 @@ func (l *PolicyLock) lockSlow() (err error) {
|
||||
close(initCh)
|
||||
|
||||
var machine bool
|
||||
if l.scope == MachinePolicy {
|
||||
if lk.scope == MachinePolicy {
|
||||
machine = true
|
||||
}
|
||||
handle, err := l.enterFn(machine)
|
||||
handle, err := lk.enterFn(machine)
|
||||
|
||||
send_result:
|
||||
for {
|
||||
@@ -226,7 +226,7 @@ func (l *PolicyLock) lockSlow() (err error) {
|
||||
// The lock is being closed, and we lost the race to l.closing
|
||||
// it the calling goroutine.
|
||||
if err == nil {
|
||||
l.leaveFn(handle)
|
||||
lk.leaveFn(handle)
|
||||
}
|
||||
break send_result
|
||||
default:
|
||||
@@ -247,21 +247,21 @@ func (l *PolicyLock) lockSlow() (err error) {
|
||||
select {
|
||||
case result := <-resultCh:
|
||||
if result.err == nil {
|
||||
l.handle = result.handle
|
||||
lk.handle = result.handle
|
||||
}
|
||||
return result.err
|
||||
case <-l.closing:
|
||||
case <-lk.closing:
|
||||
return ErrInvalidLockState
|
||||
}
|
||||
}
|
||||
|
||||
// Unlock unlocks l.
|
||||
// It panics if l is not locked on entry to Unlock.
|
||||
func (l *PolicyLock) Unlock() {
|
||||
l.mu.Lock()
|
||||
defer l.mu.Unlock()
|
||||
func (lk *PolicyLock) Unlock() {
|
||||
lk.mu.Lock()
|
||||
defer lk.mu.Unlock()
|
||||
|
||||
lockCnt := l.lockCnt.Add(-2)
|
||||
lockCnt := lk.lockCnt.Add(-2)
|
||||
if lockCnt < 0 {
|
||||
panic("negative lockCnt")
|
||||
}
|
||||
@@ -273,33 +273,33 @@ func (l *PolicyLock) Unlock() {
|
||||
return
|
||||
}
|
||||
|
||||
if l.handle != 0 {
|
||||
if lk.handle != 0 {
|
||||
// Impersonation is not required to unlock a critical policy section.
|
||||
// The handle we pass determines which mutex will be unlocked.
|
||||
leaveCriticalPolicySection(l.handle)
|
||||
l.handle = 0
|
||||
leaveCriticalPolicySection(lk.handle)
|
||||
lk.handle = 0
|
||||
}
|
||||
|
||||
if lockCnt == 0 {
|
||||
// Complete the pending close request if there's no more readers.
|
||||
l.closeInternal()
|
||||
lk.closeInternal()
|
||||
}
|
||||
}
|
||||
|
||||
// Close releases resources associated with l.
|
||||
// It is a no-op for the machine policy lock.
|
||||
func (l *PolicyLock) Close() error {
|
||||
lockCnt := l.lockCnt.Load()
|
||||
func (lk *PolicyLock) Close() error {
|
||||
lockCnt := lk.lockCnt.Load()
|
||||
if lockCnt&1 == 0 {
|
||||
// The lock has never been initialized, or close has already been called.
|
||||
return nil
|
||||
}
|
||||
|
||||
close(l.closing)
|
||||
close(lk.closing)
|
||||
|
||||
// Unset the LSB to indicate a pending close request.
|
||||
for !l.lockCnt.CompareAndSwap(lockCnt, lockCnt&^int32(1)) {
|
||||
lockCnt = l.lockCnt.Load()
|
||||
for !lk.lockCnt.CompareAndSwap(lockCnt, lockCnt&^int32(1)) {
|
||||
lockCnt = lk.lockCnt.Load()
|
||||
}
|
||||
|
||||
if lockCnt != 0 {
|
||||
@@ -307,16 +307,16 @@ func (l *PolicyLock) Close() error {
|
||||
return nil
|
||||
}
|
||||
|
||||
return l.closeInternal()
|
||||
return lk.closeInternal()
|
||||
}
|
||||
|
||||
func (l *PolicyLock) closeInternal() error {
|
||||
if l.token != 0 {
|
||||
if err := l.token.Close(); err != nil {
|
||||
func (lk *PolicyLock) closeInternal() error {
|
||||
if lk.token != 0 {
|
||||
if err := lk.token.Close(); err != nil {
|
||||
return err
|
||||
}
|
||||
l.token = 0
|
||||
lk.token = 0
|
||||
}
|
||||
l.closing = nil
|
||||
lk.closing = nil
|
||||
return nil
|
||||
}
|
||||
|
||||
@@ -256,8 +256,8 @@ func checkDomainAccount(username string) (sanitizedUserName string, isDomainAcco
|
||||
// errors.Is to check for it. When capLevel == CapCreateProcess, the logon
|
||||
// enforces the user's logon hours policy (when present).
|
||||
func (ls *lsaSession) logonAs(srcName string, u *user.User, capLevel CapabilityLevel) (token windows.Token, err error) {
|
||||
if l := len(srcName); l == 0 || l > _TOKEN_SOURCE_LENGTH {
|
||||
return 0, fmt.Errorf("%w, actual length is %d", ErrBadSrcName, l)
|
||||
if ln := len(srcName); ln == 0 || ln > _TOKEN_SOURCE_LENGTH {
|
||||
return 0, fmt.Errorf("%w, actual length is %d", ErrBadSrcName, ln)
|
||||
}
|
||||
if err := checkASCII(srcName); err != nil {
|
||||
return 0, fmt.Errorf("%w: %v", ErrBadSrcName, err)
|
||||
|
||||
@@ -938,10 +938,10 @@ func mergeEnv(existingEnv []string, extraEnv map[string]string) []string {
|
||||
result = append(result, strings.Join([]string{k, v}, "="))
|
||||
}
|
||||
|
||||
slices.SortFunc(result, func(l, r string) int {
|
||||
kl, _, _ := strings.Cut(l, "=")
|
||||
kr, _, _ := strings.Cut(r, "=")
|
||||
return strings.Compare(kl, kr)
|
||||
slices.SortFunc(result, func(a, b string) int {
|
||||
ka, _, _ := strings.Cut(a, "=")
|
||||
kb, _, _ := strings.Cut(b, "=")
|
||||
return strings.Compare(ka, kb)
|
||||
})
|
||||
return result
|
||||
}
|
||||
|
||||
@@ -83,8 +83,8 @@ func (sib *StartupInfoBuilder) Resolve() (startupInfo *windows.StartupInfo, inhe
|
||||
// Always create a Unicode environment.
|
||||
createProcessFlags = windows.CREATE_UNICODE_ENVIRONMENT
|
||||
|
||||
if l := uint32(len(sib.attrs)); l > 0 {
|
||||
attrCont, err := windows.NewProcThreadAttributeList(l)
|
||||
if ln := uint32(len(sib.attrs)); ln > 0 {
|
||||
attrCont, err := windows.NewProcThreadAttributeList(ln)
|
||||
if err != nil {
|
||||
return nil, false, 0, err
|
||||
}
|
||||
|
||||
@@ -68,8 +68,8 @@ func checkContiguousBuffer[T any, BU BufUnit](t *testing.T, extra []BU, pt *T, p
|
||||
if gotLen := int(ptLen); gotLen != expectedLen {
|
||||
t.Errorf("allocation length got %d, want %d", gotLen, expectedLen)
|
||||
}
|
||||
if l := len(slcs); l != 1 {
|
||||
t.Errorf("len(slcs) got %d, want 1", l)
|
||||
if ln := len(slcs); ln != 1 {
|
||||
t.Errorf("len(slcs) got %d, want 1", ln)
|
||||
}
|
||||
if len(extra) == 0 && slcs[0] != nil {
|
||||
t.Error("slcs[0] got non-nil, want nil")
|
||||
|
||||
Reference in New Issue
Block a user