net/art: allow non-pointers as values
Values are still turned into pointers internally to maintain the invariants of strideTable, but from the user's perspective it's now possible to tbl.Insert(pfx, true) rather than tbl.Insert(pfx, ptr.To(true)). Updates #7781 Signed-off-by: David Anderson <danderson@tailscale.com>
This commit is contained in:
committed by
Dave Anderson
parent
bc0eb6b914
commit
e92adfe5e4
@@ -8,12 +8,12 @@ import (
|
||||
"fmt"
|
||||
"math/rand"
|
||||
"net/netip"
|
||||
"runtime"
|
||||
"sort"
|
||||
"strings"
|
||||
"testing"
|
||||
|
||||
"github.com/google/go-cmp/cmp"
|
||||
"tailscale.com/types/ptr"
|
||||
)
|
||||
|
||||
func TestInversePrefix(t *testing.T) {
|
||||
@@ -65,10 +65,10 @@ func TestStrideTableInsert(t *testing.T) {
|
||||
|
||||
for i := 0; i < 256; i++ {
|
||||
addr := uint8(i)
|
||||
slowVal := slow.get(addr)
|
||||
fastVal := fast.get(addr)
|
||||
if slowVal != fastVal {
|
||||
t.Fatalf("strideTable.get(%d) = %v, want %v", addr, *fastVal, *slowVal)
|
||||
slowVal, slowOK := slow.get(addr)
|
||||
fastVal, fastOK := fast.get(addr)
|
||||
if !getsEqual(fastVal, fastOK, slowVal, slowOK) {
|
||||
t.Fatalf("strideTable.get(%d) = (%v, %v), want (%v, %v)", addr, fastVal, fastOK, slowVal, slowOK)
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -91,10 +91,14 @@ func TestStrideTableInsertShuffled(t *testing.T) {
|
||||
|
||||
zero := 0
|
||||
rt := strideTable[int]{}
|
||||
// strideTable has a value interface, but internally has to keep
|
||||
// track of distinct routes even if they all have the same
|
||||
// value. rtZero uses the same value for all routes, and expects
|
||||
// correct behavior.
|
||||
rtZero := strideTable[int]{}
|
||||
for _, route := range routes {
|
||||
rt.insert(route.addr, route.len, route.val)
|
||||
rtZero.insert(route.addr, route.len, &zero)
|
||||
rtZero.insert(route.addr, route.len, zero)
|
||||
}
|
||||
|
||||
// Order of insertion should not affect the final shape of the stride table.
|
||||
@@ -105,15 +109,15 @@ func TestStrideTableInsertShuffled(t *testing.T) {
|
||||
for _, route := range routes2 {
|
||||
rt2.insert(route.addr, route.len, route.val)
|
||||
}
|
||||
if diff := cmp.Diff(rt, rt2, cmpDiffOpts...); diff != "" {
|
||||
if diff := cmp.Diff(rt.tableDebugString(), rt2.tableDebugString()); diff != "" {
|
||||
t.Errorf("tables ended up different with different insertion order (-got+want):\n%s\n\nOrder 1: %v\nOrder 2: %v", diff, formatSlowEntriesShort(routes), formatSlowEntriesShort(routes2))
|
||||
}
|
||||
|
||||
rtZero2 := strideTable[int]{}
|
||||
for _, route := range routes2 {
|
||||
rtZero2.insert(route.addr, route.len, &zero)
|
||||
rtZero2.insert(route.addr, route.len, zero)
|
||||
}
|
||||
if diff := cmp.Diff(rtZero, rtZero2, cmpDiffOpts...); diff != "" {
|
||||
if diff := cmp.Diff(rtZero.tableDebugString(), rtZero2.tableDebugString(), cmpDiffOpts...); diff != "" {
|
||||
t.Errorf("tables with identical vals ended up different with different insertion order (-got+want):\n%s\n\nOrder 1: %v\nOrder 2: %v", diff, formatSlowEntriesShort(routes), formatSlowEntriesShort(routes2))
|
||||
}
|
||||
}
|
||||
@@ -150,10 +154,10 @@ func TestStrideTableDelete(t *testing.T) {
|
||||
|
||||
for i := 0; i < 256; i++ {
|
||||
addr := uint8(i)
|
||||
slowVal := slow.get(addr)
|
||||
fastVal := fast.get(addr)
|
||||
if slowVal != fastVal {
|
||||
t.Fatalf("strideTable.get(%d) = %v, want %v", addr, *fastVal, *slowVal)
|
||||
slowVal, slowOK := slow.get(addr)
|
||||
fastVal, fastOK := fast.get(addr)
|
||||
if !getsEqual(fastVal, fastOK, slowVal, slowOK) {
|
||||
t.Fatalf("strideTable.get(%d) = (%v, %v), want (%v, %v)", addr, fastVal, fastOK, slowVal, slowOK)
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -168,10 +172,14 @@ func TestStrideTableDeleteShuffle(t *testing.T) {
|
||||
|
||||
zero := 0
|
||||
rt := strideTable[int]{}
|
||||
// strideTable has a value interface, but internally has to keep
|
||||
// track of distinct routes even if they all have the same
|
||||
// value. rtZero uses the same value for all routes, and expects
|
||||
// correct behavior.
|
||||
rtZero := strideTable[int]{}
|
||||
for _, route := range routes {
|
||||
rt.insert(route.addr, route.len, route.val)
|
||||
rtZero.insert(route.addr, route.len, &zero)
|
||||
rtZero.insert(route.addr, route.len, zero)
|
||||
}
|
||||
for _, route := range toDelete {
|
||||
rt.delete(route.addr, route.len)
|
||||
@@ -189,18 +197,18 @@ func TestStrideTableDeleteShuffle(t *testing.T) {
|
||||
for _, route := range toDelete2 {
|
||||
rt2.delete(route.addr, route.len)
|
||||
}
|
||||
if diff := cmp.Diff(rt, rt2, cmpDiffOpts...); diff != "" {
|
||||
if diff := cmp.Diff(rt.tableDebugString(), rt2.tableDebugString(), cmpDiffOpts...); diff != "" {
|
||||
t.Errorf("tables ended up different with different deletion order (-got+want):\n%s\n\nOrder 1: %v\nOrder 2: %v", diff, formatSlowEntriesShort(toDelete), formatSlowEntriesShort(toDelete2))
|
||||
}
|
||||
|
||||
rtZero2 := strideTable[int]{}
|
||||
for _, route := range routes {
|
||||
rtZero2.insert(route.addr, route.len, &zero)
|
||||
rtZero2.insert(route.addr, route.len, zero)
|
||||
}
|
||||
for _, route := range toDelete2 {
|
||||
rtZero2.delete(route.addr, route.len)
|
||||
}
|
||||
if diff := cmp.Diff(rtZero, rtZero2, cmpDiffOpts...); diff != "" {
|
||||
if diff := cmp.Diff(rtZero.tableDebugString(), rtZero2.tableDebugString(), cmpDiffOpts...); diff != "" {
|
||||
t.Errorf("tables with identical vals ended up different with different deletion order (-got+want):\n%s\n\nOrder 1: %v\nOrder 2: %v", diff, formatSlowEntriesShort(toDelete), formatSlowEntriesShort(toDelete2))
|
||||
}
|
||||
}
|
||||
@@ -218,31 +226,35 @@ func forStrideCountAndOrdering(b *testing.B, fn func(b *testing.B, routes []slow
|
||||
routes := shufflePrefixes(allPrefixes())
|
||||
for _, nroutes := range strideRouteCount {
|
||||
b.Run(fmt.Sprint(nroutes), func(b *testing.B) {
|
||||
routes := append([]slowEntry[int](nil), routes[:nroutes]...)
|
||||
b.Run("random_order", func(b *testing.B) {
|
||||
runAndRecord := func(b *testing.B) {
|
||||
b.ReportAllocs()
|
||||
var startMem, endMem runtime.MemStats
|
||||
runtime.ReadMemStats(&startMem)
|
||||
fn(b, routes)
|
||||
})
|
||||
runtime.ReadMemStats(&endMem)
|
||||
ops := float64(b.N) * float64(len(routes))
|
||||
allocs := float64(endMem.Mallocs - startMem.Mallocs)
|
||||
bytes := float64(endMem.TotalAlloc - startMem.TotalAlloc)
|
||||
b.ReportMetric(roundFloat64(allocs/ops), "allocs/op")
|
||||
b.ReportMetric(roundFloat64(bytes/ops), "B/op")
|
||||
}
|
||||
|
||||
routes := append([]slowEntry[int](nil), routes[:nroutes]...)
|
||||
b.Run("random_order", runAndRecord)
|
||||
sort.Slice(routes, func(i, j int) bool {
|
||||
if routes[i].len < routes[j].len {
|
||||
return true
|
||||
}
|
||||
return routes[i].addr < routes[j].addr
|
||||
})
|
||||
b.Run("largest_first", func(b *testing.B) {
|
||||
b.ReportAllocs()
|
||||
fn(b, routes)
|
||||
})
|
||||
b.Run("largest_first", runAndRecord)
|
||||
sort.Slice(routes, func(i, j int) bool {
|
||||
if routes[j].len < routes[i].len {
|
||||
return true
|
||||
}
|
||||
return routes[j].addr < routes[i].addr
|
||||
})
|
||||
b.Run("smallest_first", func(b *testing.B) {
|
||||
b.ReportAllocs()
|
||||
fn(b, routes)
|
||||
})
|
||||
b.Run("smallest_first", runAndRecord)
|
||||
})
|
||||
}
|
||||
}
|
||||
@@ -253,7 +265,7 @@ func BenchmarkStrideTableInsertion(b *testing.B) {
|
||||
for i := 0; i < b.N; i++ {
|
||||
var rt strideTable[int]
|
||||
for _, route := range routes {
|
||||
rt.insert(route.addr, route.len, &val)
|
||||
rt.insert(route.addr, route.len, val)
|
||||
}
|
||||
}
|
||||
inserts := float64(b.N) * float64(len(routes))
|
||||
@@ -269,7 +281,7 @@ func BenchmarkStrideTableDeletion(b *testing.B) {
|
||||
val := 0
|
||||
var rt strideTable[int]
|
||||
for _, route := range routes {
|
||||
rt.insert(route.addr, route.len, &val)
|
||||
rt.insert(route.addr, route.len, val)
|
||||
}
|
||||
|
||||
b.ResetTimer()
|
||||
@@ -287,7 +299,7 @@ func BenchmarkStrideTableDeletion(b *testing.B) {
|
||||
})
|
||||
}
|
||||
|
||||
var writeSink *int
|
||||
var writeSink int
|
||||
|
||||
func BenchmarkStrideTableGet(b *testing.B) {
|
||||
// No need to forCountAndOrdering here, route lookup time is independent of
|
||||
@@ -300,7 +312,7 @@ func BenchmarkStrideTableGet(b *testing.B) {
|
||||
|
||||
b.ResetTimer()
|
||||
for i := 0; i < b.N; i++ {
|
||||
writeSink = rt.get(uint8(i))
|
||||
writeSink, _ = rt.get(uint8(i))
|
||||
}
|
||||
gets := float64(b.N)
|
||||
elapsedSec := b.Elapsed().Seconds()
|
||||
@@ -318,7 +330,7 @@ type slowTable[T any] struct {
|
||||
type slowEntry[T any] struct {
|
||||
addr uint8
|
||||
len int
|
||||
val *T
|
||||
val T
|
||||
}
|
||||
|
||||
func (t *slowTable[T]) String() string {
|
||||
@@ -331,13 +343,14 @@ func (t *slowTable[T]) String() string {
|
||||
})
|
||||
var ret bytes.Buffer
|
||||
for _, pfx := range pfxs {
|
||||
fmt.Fprintf(&ret, "%3d/%d (%08b/%08b) = %v\n", pfx.addr, pfx.len, pfx.addr, pfxMask(pfx.len), *pfx.val)
|
||||
fmt.Fprintf(&ret, "%3d/%d (%08b/%08b) = %v\n", pfx.addr, pfx.len, pfx.addr, pfxMask(pfx.len), pfx.val)
|
||||
}
|
||||
return ret.String()
|
||||
}
|
||||
|
||||
func (t *slowTable[T]) insert(addr uint8, prefixLen int, val *T) {
|
||||
func (t *slowTable[T]) insert(addr uint8, prefixLen int, val T) {
|
||||
t.delete(addr, prefixLen) // no-op if prefix doesn't exist
|
||||
|
||||
t.prefixes = append(t.prefixes, slowEntry[T]{addr, prefixLen, val})
|
||||
}
|
||||
|
||||
@@ -352,18 +365,15 @@ func (t *slowTable[T]) delete(addr uint8, prefixLen int) {
|
||||
t.prefixes = pfx
|
||||
}
|
||||
|
||||
func (t *slowTable[T]) get(addr uint8) *T {
|
||||
var (
|
||||
ret *T
|
||||
curLen = -1
|
||||
)
|
||||
func (t *slowTable[T]) get(addr uint8) (ret T, ok bool) {
|
||||
var curLen = -1
|
||||
for _, e := range t.prefixes {
|
||||
if addr&pfxMask(e.len) == e.addr && e.len >= curLen {
|
||||
ret = e.val
|
||||
curLen = e.len
|
||||
}
|
||||
}
|
||||
return ret
|
||||
return ret, curLen != -1
|
||||
}
|
||||
|
||||
func pfxMask(pfxLen int) uint8 {
|
||||
@@ -374,7 +384,7 @@ func allPrefixes() []slowEntry[int] {
|
||||
ret := make([]slowEntry[int], 0, lastHostIndex)
|
||||
for i := 1; i < lastHostIndex+1; i++ {
|
||||
a, l := inversePrefixIndex(i)
|
||||
ret = append(ret, slowEntry[int]{a, l, ptr.To(i)})
|
||||
ret = append(ret, slowEntry[int]{a, l, i})
|
||||
}
|
||||
return ret
|
||||
}
|
||||
@@ -393,6 +403,15 @@ func formatSlowEntriesShort[T any](ents []slowEntry[T]) string {
|
||||
}
|
||||
|
||||
var cmpDiffOpts = []cmp.Option{
|
||||
cmp.AllowUnexported(strideTable[int]{}, strideEntry[int]{}),
|
||||
cmp.Comparer(func(a, b netip.Prefix) bool { return a == b }),
|
||||
}
|
||||
|
||||
func getsEqual[T comparable](a T, aOK bool, b T, bOK bool) bool {
|
||||
if !aOK && !bOK {
|
||||
return true
|
||||
}
|
||||
if aOK != bOK {
|
||||
return false
|
||||
}
|
||||
return a == b
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user