This package uses a count-min sketch and a heap to track the top K items in a stream of data. Tracking a new item and adding a count to an existing item both require no memory allocations and is at worst O(log(k)) complexity. Change-Id: I0553381be3fef2470897e2bd806d43396f2dbb36 Signed-off-by: Andrew Dunham <andrew@du.nham.ca>main
parent
7ad2bb87a6
commit
b7104cde4a
@ -0,0 +1,261 @@ |
||||
// Copyright (c) Tailscale Inc & AUTHORS
|
||||
// SPDX-License-Identifier: BSD-3-Clause
|
||||
|
||||
// Package topk defines a count-min sketch and a cheap probabilistic top-K data
|
||||
// structure that uses the count-min sketch to track the top K items in
|
||||
// constant memory and O(log(k)) time.
|
||||
package topk |
||||
|
||||
import ( |
||||
"container/heap" |
||||
"hash/maphash" |
||||
"math" |
||||
"slices" |
||||
"sync" |
||||
) |
||||
|
||||
// TopK is a probabilistic counter of the top K items, using a count-min sketch
|
||||
// to keep track of item counts and a heap to track the top K of them.
|
||||
type TopK[T any] struct { |
||||
heap minHeap[T] |
||||
k int |
||||
sf SerializeFunc[T] |
||||
cms CountMinSketch |
||||
} |
||||
|
||||
// HashFunc is responsible for providing a []byte serialization of a value,
|
||||
// appended to the provided byte slice. This is used for hashing the value when
|
||||
// adding to a CountMinSketch.
|
||||
type SerializeFunc[T any] func([]byte, T) []byte |
||||
|
||||
// New creates a new TopK that stores k values. Parameters for the underlying
|
||||
// count-min sketch are chosen for a 0.1% error rate and a 0.1% probability of
|
||||
// error.
|
||||
func New[T any](k int, sf SerializeFunc[T]) *TopK[T] { |
||||
hashes, buckets := PickParams(0.001, 0.001) |
||||
return NewWithParams(k, sf, hashes, buckets) |
||||
} |
||||
|
||||
// NewWithParams creates a new TopK that stores k values, and additionally
|
||||
// allows customizing the parameters for the underlying count-min sketch.
|
||||
func NewWithParams[T any](k int, sf SerializeFunc[T], numHashes, numCols int) *TopK[T] { |
||||
ret := &TopK[T]{ |
||||
heap: make(minHeap[T], 0, k), |
||||
k: k, |
||||
sf: sf, |
||||
} |
||||
ret.cms.init(numHashes, numCols) |
||||
return ret |
||||
} |
||||
|
||||
// Add calls AddN(val, 1).
|
||||
func (tk *TopK[T]) Add(val T) uint64 { |
||||
return tk.AddN(val, 1) |
||||
} |
||||
|
||||
var hashPool = &sync.Pool{ |
||||
New: func() any { |
||||
buf := make([]byte, 0, 128) |
||||
return &buf |
||||
}, |
||||
} |
||||
|
||||
// AddN adds the given item to the set with the provided count, returning the
|
||||
// new estimated count.
|
||||
func (tk *TopK[T]) AddN(val T, count uint64) uint64 { |
||||
buf := hashPool.Get().(*[]byte) |
||||
defer hashPool.Put(buf) |
||||
ser := tk.sf((*buf)[:0], val) |
||||
|
||||
vcount := tk.cms.AddN(ser, count) |
||||
|
||||
// If we don't have a full heap, just push it.
|
||||
if len(tk.heap) < tk.k { |
||||
heap.Push(&tk.heap, mhValue[T]{ |
||||
count: vcount, |
||||
val: val, |
||||
}) |
||||
return vcount |
||||
} |
||||
|
||||
// If this item's count surpasses the heap's minimum, update the heap.
|
||||
if vcount > tk.heap[0].count { |
||||
tk.heap[0] = mhValue[T]{ |
||||
count: vcount, |
||||
val: val, |
||||
} |
||||
heap.Fix(&tk.heap, 0) |
||||
} |
||||
return vcount |
||||
} |
||||
|
||||
// Top returns the estimated top K items as stored by this TopK.
|
||||
func (tk *TopK[T]) Top() []T { |
||||
ret := make([]T, 0, tk.k) |
||||
for _, item := range tk.heap { |
||||
ret = append(ret, item.val) |
||||
} |
||||
return ret |
||||
} |
||||
|
||||
// AppendTop appends the estimated top K items as stored by this TopK to the
|
||||
// provided slice, allocating only if the slice does not have enough capacity
|
||||
// to store all items. The provided slice can be nil.
|
||||
func (tk *TopK[T]) AppendTop(sl []T) []T { |
||||
sl = slices.Grow(sl, tk.k) |
||||
for _, item := range tk.heap { |
||||
sl = append(sl, item.val) |
||||
} |
||||
return sl |
||||
} |
||||
|
||||
// CountMinSketch implements a count-min sketch, a probabilistic data structure
|
||||
// that tracks the frequency of events in a stream of data.
|
||||
//
|
||||
// See: https://en.wikipedia.org/wiki/Count%E2%80%93min_sketch
|
||||
type CountMinSketch struct { |
||||
hashes []maphash.Seed |
||||
nbuckets int |
||||
matrix []uint64 |
||||
} |
||||
|
||||
// NewCountMinSketch creates a new CountMinSketch with the provided number of
|
||||
// hashes and buckets. Hashes and buckets are often called "depth" and "width",
|
||||
// or "d" and "w", respectively.
|
||||
func NewCountMinSketch(hashes, buckets int) *CountMinSketch { |
||||
ret := &CountMinSketch{} |
||||
ret.init(hashes, buckets) |
||||
return ret |
||||
} |
||||
|
||||
// PickParams provides good parameters for 'hashes' and 'buckets' when
|
||||
// constructing a CountMinSketch, given an estimated total number of counts
|
||||
// (i.e. the sum of all counts ever stored), the error factor ϵ as a float
|
||||
// (e.g. 1% is 0.001), and the probability factor δ.
|
||||
//
|
||||
// Parameters are chosen such that with a probability of 1−δ, the error is at
|
||||
// most ϵ∗totalCount. Or, in other words: if N is the true count of an event,
|
||||
// E is the estimate given by a sketch and T the total count of items in the
|
||||
// sketch, E ≤ N + T*ϵ with probability (1 - δ).
|
||||
func PickParams(err, probability float64) (hashes, buckets int) { |
||||
d := math.Ceil(math.Log(1 / probability)) |
||||
w := math.Ceil(math.E / err) |
||||
|
||||
return int(d), int(w) |
||||
} |
||||
|
||||
func (cms *CountMinSketch) init(hashes, buckets int) { |
||||
for i := 0; i < hashes; i++ { |
||||
cms.hashes = append(cms.hashes, maphash.MakeSeed()) |
||||
} |
||||
|
||||
// Need a matrix of hashes * buckets to store counts
|
||||
cms.nbuckets = buckets |
||||
cms.matrix = make([]uint64, hashes*buckets) |
||||
} |
||||
|
||||
// Add calls AddN(val, 1).
|
||||
func (cms *CountMinSketch) Add(val []byte) uint64 { |
||||
return cms.AddN(val, 1) |
||||
} |
||||
|
||||
// AddN increments the count for the given value by the provided count,
|
||||
// returning the new count.
|
||||
func (cms *CountMinSketch) AddN(val []byte, count uint64) uint64 { |
||||
var ( |
||||
mh maphash.Hash |
||||
ret uint64 = math.MaxUint64 |
||||
) |
||||
for i, seed := range cms.hashes { |
||||
mh.SetSeed(seed) |
||||
|
||||
// Generate a hash for this value using Lemire's alternative to modular reduction:
|
||||
// https://lemire.me/blog/2016/06/27/a-fast-alternative-to-the-modulo-reduction/
|
||||
mh.Write(val) |
||||
hash := mh.Sum64() |
||||
hash = multiplyHigh64(hash, uint64(cms.nbuckets)) |
||||
|
||||
// The index in our matrix is (i * buckets) to move "down" i
|
||||
// rows in our matrix to the row for this hash, plus 'hash' to
|
||||
// move inside this row.
|
||||
idx := (i * cms.nbuckets) + int(hash) |
||||
|
||||
// Add to this row
|
||||
cms.matrix[idx] += count |
||||
ret = min(ret, cms.matrix[idx]) |
||||
} |
||||
return ret |
||||
} |
||||
|
||||
// Get returns the count for the provided value.
|
||||
func (cms *CountMinSketch) Get(val []byte) uint64 { |
||||
var ( |
||||
mh maphash.Hash |
||||
ret uint64 = math.MaxUint64 |
||||
) |
||||
for i, seed := range cms.hashes { |
||||
mh.SetSeed(seed) |
||||
|
||||
// Generate a hash for this value using Lemire's alternative to modular reduction:
|
||||
// https://lemire.me/blog/2016/06/27/a-fast-alternative-to-the-modulo-reduction/
|
||||
mh.Write(val) |
||||
hash := mh.Sum64() |
||||
hash = multiplyHigh64(hash, uint64(cms.nbuckets)) |
||||
|
||||
// The index in our matrix is (i * buckets) to move "down" i
|
||||
// rows in our matrix to the row for this hash, plus 'hash' to
|
||||
// move inside this row.
|
||||
idx := (i * cms.nbuckets) + int(hash) |
||||
|
||||
// Select the minimal value among all rows
|
||||
ret = min(ret, cms.matrix[idx]) |
||||
} |
||||
return ret |
||||
} |
||||
|
||||
// multiplyHigh64 implements (x * y) >> 64 "the long way" without access to a
|
||||
// 128-bit type. This function is adapted from something similar in Tensorflow:
|
||||
//
|
||||
// https://github.com/tensorflow/tensorflow/commit/a47a300185026fe7829990def9113bf3a5109fed
|
||||
//
|
||||
// TODO(andrew-d): this could be replaced with a single "MULX" instruction on
|
||||
// x86_64 platforms, which we can do if this ever turns out to be a performance
|
||||
// bottleneck.
|
||||
func multiplyHigh64(x, y uint64) uint64 { |
||||
x_lo := x & 0xffffffff |
||||
x_hi := x >> 32 |
||||
buckets_lo := y & 0xffffffff |
||||
buckets_hi := y >> 32 |
||||
prod_hi := x_hi * buckets_hi |
||||
prod_lo := x_lo * buckets_lo |
||||
prod_mid1 := x_hi * buckets_lo |
||||
prod_mid2 := x_lo * buckets_hi |
||||
carry := ((prod_mid1 & 0xffffffff) + (prod_mid2 & 0xffffffff) + (prod_lo >> 32)) >> 32 |
||||
return prod_hi + (prod_mid1 >> 32) + (prod_mid2 >> 32) + carry |
||||
} |
||||
|
||||
type mhValue[T any] struct { |
||||
count uint64 |
||||
val T |
||||
} |
||||
|
||||
// An minHeap is a min-heap of ints and associated values.
|
||||
type minHeap[T any] []mhValue[T] |
||||
|
||||
func (h minHeap[T]) Len() int { return len(h) } |
||||
func (h minHeap[T]) Less(i, j int) bool { return h[i].count < h[j].count } |
||||
func (h minHeap[T]) Swap(i, j int) { h[i], h[j] = h[j], h[i] } |
||||
|
||||
func (h *minHeap[T]) Push(x any) { |
||||
// Push and Pop use pointer receivers because they modify the slice's length,
|
||||
// not just its contents.
|
||||
*h = append(*h, x.(mhValue[T])) |
||||
} |
||||
|
||||
func (h *minHeap[T]) Pop() any { |
||||
old := *h |
||||
n := len(old) |
||||
x := old[n-1] |
||||
*h = old[0 : n-1] |
||||
return x |
||||
} |
||||
@ -0,0 +1,135 @@ |
||||
// Copyright (c) Tailscale Inc & AUTHORS
|
||||
// SPDX-License-Identifier: BSD-3-Clause
|
||||
|
||||
package topk |
||||
|
||||
import ( |
||||
"encoding/binary" |
||||
"fmt" |
||||
"slices" |
||||
"testing" |
||||
) |
||||
|
||||
func TestCountMinSketch(t *testing.T) { |
||||
cms := NewCountMinSketch(4, 10) |
||||
items := []string{"foo", "bar", "baz", "asdf", "quux"} |
||||
for _, item := range items { |
||||
cms.Add([]byte(item)) |
||||
} |
||||
for _, item := range items { |
||||
count := cms.Get([]byte(item)) |
||||
if count < 1 { |
||||
t.Errorf("item %q should have count >= 1", item) |
||||
} else if count > 1 { |
||||
t.Logf("item %q has count > 1: %d", item, count) |
||||
} |
||||
} |
||||
|
||||
// Test that an item that's *not* in the set has a value lower than the
|
||||
// total number of items we inserted (in the case that all items
|
||||
// collided).
|
||||
noItemCount := cms.Get([]byte("doesn't exist")) |
||||
if noItemCount > uint64(len(items)) { |
||||
t.Errorf("expected nonexistent item to have value < %d; got %d", len(items), noItemCount) |
||||
} |
||||
} |
||||
|
||||
func TestTopK(t *testing.T) { |
||||
// This is probabilistic, so we're going to try 10 times to get the
|
||||
// "right" value; the likelihood that we fail on all attempts is
|
||||
// vanishingly small since the number of hash buckets is drastically
|
||||
// larger than the number of items we're inserting.
|
||||
var ( |
||||
got []int |
||||
want = []int{5, 6, 7, 8, 9} |
||||
) |
||||
for try := 0; try < 10; try++ { |
||||
topk := NewWithParams[int](5, func(in []byte, val int) []byte { |
||||
return binary.LittleEndian.AppendUint64(in, uint64(val)) |
||||
}, 4, 1000) |
||||
|
||||
// Add the first 10 integers with counts equal to 2x their value
|
||||
for i := 0; i < 10; i++ { |
||||
topk.AddN(i, uint64(i*2)) |
||||
} |
||||
|
||||
got = topk.Top() |
||||
t.Logf("top K items: %+v", got) |
||||
slices.Sort(got) |
||||
|
||||
if slices.Equal(got, want) { |
||||
// All good!
|
||||
return |
||||
} |
||||
|
||||
// continue and retry or fail
|
||||
} |
||||
|
||||
t.Errorf("top K mismatch\ngot: %v\nwant: %v", got, want) |
||||
} |
||||
|
||||
func TestPickParams(t *testing.T) { |
||||
hashes, buckets := PickParams( |
||||
0.001, // 0.1% error rate
|
||||
0.001, // 0.1% chance of having an error, or 99.9% chance of not having an error
|
||||
) |
||||
t.Logf("hashes = %d, buckets = %d", hashes, buckets) |
||||
} |
||||
|
||||
func BenchmarkCountMinSketch(b *testing.B) { |
||||
cms := NewCountMinSketch(PickParams(0.001, 0.001)) |
||||
b.ResetTimer() |
||||
b.ReportAllocs() |
||||
|
||||
var enc [8]byte |
||||
for i := 0; i < b.N; i++ { |
||||
binary.LittleEndian.PutUint64(enc[:], uint64(i)) |
||||
cms.Add(enc[:]) |
||||
} |
||||
} |
||||
|
||||
func BenchmarkTopK(b *testing.B) { |
||||
for _, n := range []int{ |
||||
10, |
||||
128, |
||||
256, |
||||
1024, |
||||
8192, |
||||
} { |
||||
b.Run(fmt.Sprintf("Top%d", n), func(b *testing.B) { |
||||
out := make([]int, 0, n) |
||||
topk := New[int](n, func(in []byte, val int) []byte { |
||||
return binary.LittleEndian.AppendUint64(in, uint64(val)) |
||||
}) |
||||
b.ResetTimer() |
||||
b.ReportAllocs() |
||||
|
||||
for i := 0; i < b.N; i++ { |
||||
topk.Add(i) |
||||
} |
||||
out = topk.AppendTop(out[:0]) // should not allocate
|
||||
_ = out // appease linter
|
||||
}) |
||||
} |
||||
} |
||||
|
||||
func TestMultiplyHigh64(t *testing.T) { |
||||
testCases := []struct { |
||||
x, y uint64 |
||||
want uint64 |
||||
}{ |
||||
{0, 0, 0}, |
||||
{0xffffffff, 0xffffffff, 0}, |
||||
{0x2, 0xf000000000000000, 1}, |
||||
{0x3, 0xf000000000000000, 2}, |
||||
{0x3, 0xf000000000000001, 2}, |
||||
{0x3, 0xffffffffffffffff, 2}, |
||||
{0xffffffffffffffff, 0xffffffffffffffff, 0xfffffffffffffffe}, |
||||
} |
||||
for _, tc := range testCases { |
||||
got := multiplyHigh64(tc.x, tc.y) |
||||
if got != tc.want { |
||||
t.Errorf("got multiplyHigh64(%x, %x) = %x, want %x", tc.x, tc.y, got, tc.want) |
||||
} |
||||
} |
||||
} |
||||
Loading…
Reference in new issue