net/art: allow non-pointers as values

Values are still turned into pointers internally to maintain the
invariants of strideTable, but from the user's perspective it's
now possible to tbl.Insert(pfx, true) rather than
tbl.Insert(pfx, ptr.To(true)).

Updates #7781

Signed-off-by: David Anderson <danderson@tailscale.com>
This commit is contained in:
David Anderson
2023-08-16 15:51:57 -07:00
committed by Dave Anderson
parent bc0eb6b914
commit e92adfe5e4
4 changed files with 246 additions and 220 deletions
+62 -43
View File
@@ -8,12 +8,12 @@ import (
"fmt"
"math/rand"
"net/netip"
"runtime"
"sort"
"strings"
"testing"
"github.com/google/go-cmp/cmp"
"tailscale.com/types/ptr"
)
func TestInversePrefix(t *testing.T) {
@@ -65,10 +65,10 @@ func TestStrideTableInsert(t *testing.T) {
for i := 0; i < 256; i++ {
addr := uint8(i)
slowVal := slow.get(addr)
fastVal := fast.get(addr)
if slowVal != fastVal {
t.Fatalf("strideTable.get(%d) = %v, want %v", addr, *fastVal, *slowVal)
slowVal, slowOK := slow.get(addr)
fastVal, fastOK := fast.get(addr)
if !getsEqual(fastVal, fastOK, slowVal, slowOK) {
t.Fatalf("strideTable.get(%d) = (%v, %v), want (%v, %v)", addr, fastVal, fastOK, slowVal, slowOK)
}
}
}
@@ -91,10 +91,14 @@ func TestStrideTableInsertShuffled(t *testing.T) {
zero := 0
rt := strideTable[int]{}
// strideTable has a value interface, but internally has to keep
// track of distinct routes even if they all have the same
// value. rtZero uses the same value for all routes, and expects
// correct behavior.
rtZero := strideTable[int]{}
for _, route := range routes {
rt.insert(route.addr, route.len, route.val)
rtZero.insert(route.addr, route.len, &zero)
rtZero.insert(route.addr, route.len, zero)
}
// Order of insertion should not affect the final shape of the stride table.
@@ -105,15 +109,15 @@ func TestStrideTableInsertShuffled(t *testing.T) {
for _, route := range routes2 {
rt2.insert(route.addr, route.len, route.val)
}
if diff := cmp.Diff(rt, rt2, cmpDiffOpts...); diff != "" {
if diff := cmp.Diff(rt.tableDebugString(), rt2.tableDebugString()); diff != "" {
t.Errorf("tables ended up different with different insertion order (-got+want):\n%s\n\nOrder 1: %v\nOrder 2: %v", diff, formatSlowEntriesShort(routes), formatSlowEntriesShort(routes2))
}
rtZero2 := strideTable[int]{}
for _, route := range routes2 {
rtZero2.insert(route.addr, route.len, &zero)
rtZero2.insert(route.addr, route.len, zero)
}
if diff := cmp.Diff(rtZero, rtZero2, cmpDiffOpts...); diff != "" {
if diff := cmp.Diff(rtZero.tableDebugString(), rtZero2.tableDebugString(), cmpDiffOpts...); diff != "" {
t.Errorf("tables with identical vals ended up different with different insertion order (-got+want):\n%s\n\nOrder 1: %v\nOrder 2: %v", diff, formatSlowEntriesShort(routes), formatSlowEntriesShort(routes2))
}
}
@@ -150,10 +154,10 @@ func TestStrideTableDelete(t *testing.T) {
for i := 0; i < 256; i++ {
addr := uint8(i)
slowVal := slow.get(addr)
fastVal := fast.get(addr)
if slowVal != fastVal {
t.Fatalf("strideTable.get(%d) = %v, want %v", addr, *fastVal, *slowVal)
slowVal, slowOK := slow.get(addr)
fastVal, fastOK := fast.get(addr)
if !getsEqual(fastVal, fastOK, slowVal, slowOK) {
t.Fatalf("strideTable.get(%d) = (%v, %v), want (%v, %v)", addr, fastVal, fastOK, slowVal, slowOK)
}
}
}
@@ -168,10 +172,14 @@ func TestStrideTableDeleteShuffle(t *testing.T) {
zero := 0
rt := strideTable[int]{}
// strideTable has a value interface, but internally has to keep
// track of distinct routes even if they all have the same
// value. rtZero uses the same value for all routes, and expects
// correct behavior.
rtZero := strideTable[int]{}
for _, route := range routes {
rt.insert(route.addr, route.len, route.val)
rtZero.insert(route.addr, route.len, &zero)
rtZero.insert(route.addr, route.len, zero)
}
for _, route := range toDelete {
rt.delete(route.addr, route.len)
@@ -189,18 +197,18 @@ func TestStrideTableDeleteShuffle(t *testing.T) {
for _, route := range toDelete2 {
rt2.delete(route.addr, route.len)
}
if diff := cmp.Diff(rt, rt2, cmpDiffOpts...); diff != "" {
if diff := cmp.Diff(rt.tableDebugString(), rt2.tableDebugString(), cmpDiffOpts...); diff != "" {
t.Errorf("tables ended up different with different deletion order (-got+want):\n%s\n\nOrder 1: %v\nOrder 2: %v", diff, formatSlowEntriesShort(toDelete), formatSlowEntriesShort(toDelete2))
}
rtZero2 := strideTable[int]{}
for _, route := range routes {
rtZero2.insert(route.addr, route.len, &zero)
rtZero2.insert(route.addr, route.len, zero)
}
for _, route := range toDelete2 {
rtZero2.delete(route.addr, route.len)
}
if diff := cmp.Diff(rtZero, rtZero2, cmpDiffOpts...); diff != "" {
if diff := cmp.Diff(rtZero.tableDebugString(), rtZero2.tableDebugString(), cmpDiffOpts...); diff != "" {
t.Errorf("tables with identical vals ended up different with different deletion order (-got+want):\n%s\n\nOrder 1: %v\nOrder 2: %v", diff, formatSlowEntriesShort(toDelete), formatSlowEntriesShort(toDelete2))
}
}
@@ -218,31 +226,35 @@ func forStrideCountAndOrdering(b *testing.B, fn func(b *testing.B, routes []slow
routes := shufflePrefixes(allPrefixes())
for _, nroutes := range strideRouteCount {
b.Run(fmt.Sprint(nroutes), func(b *testing.B) {
routes := append([]slowEntry[int](nil), routes[:nroutes]...)
b.Run("random_order", func(b *testing.B) {
runAndRecord := func(b *testing.B) {
b.ReportAllocs()
var startMem, endMem runtime.MemStats
runtime.ReadMemStats(&startMem)
fn(b, routes)
})
runtime.ReadMemStats(&endMem)
ops := float64(b.N) * float64(len(routes))
allocs := float64(endMem.Mallocs - startMem.Mallocs)
bytes := float64(endMem.TotalAlloc - startMem.TotalAlloc)
b.ReportMetric(roundFloat64(allocs/ops), "allocs/op")
b.ReportMetric(roundFloat64(bytes/ops), "B/op")
}
routes := append([]slowEntry[int](nil), routes[:nroutes]...)
b.Run("random_order", runAndRecord)
sort.Slice(routes, func(i, j int) bool {
if routes[i].len < routes[j].len {
return true
}
return routes[i].addr < routes[j].addr
})
b.Run("largest_first", func(b *testing.B) {
b.ReportAllocs()
fn(b, routes)
})
b.Run("largest_first", runAndRecord)
sort.Slice(routes, func(i, j int) bool {
if routes[j].len < routes[i].len {
return true
}
return routes[j].addr < routes[i].addr
})
b.Run("smallest_first", func(b *testing.B) {
b.ReportAllocs()
fn(b, routes)
})
b.Run("smallest_first", runAndRecord)
})
}
}
@@ -253,7 +265,7 @@ func BenchmarkStrideTableInsertion(b *testing.B) {
for i := 0; i < b.N; i++ {
var rt strideTable[int]
for _, route := range routes {
rt.insert(route.addr, route.len, &val)
rt.insert(route.addr, route.len, val)
}
}
inserts := float64(b.N) * float64(len(routes))
@@ -269,7 +281,7 @@ func BenchmarkStrideTableDeletion(b *testing.B) {
val := 0
var rt strideTable[int]
for _, route := range routes {
rt.insert(route.addr, route.len, &val)
rt.insert(route.addr, route.len, val)
}
b.ResetTimer()
@@ -287,7 +299,7 @@ func BenchmarkStrideTableDeletion(b *testing.B) {
})
}
var writeSink *int
var writeSink int
func BenchmarkStrideTableGet(b *testing.B) {
// No need to forCountAndOrdering here, route lookup time is independent of
@@ -300,7 +312,7 @@ func BenchmarkStrideTableGet(b *testing.B) {
b.ResetTimer()
for i := 0; i < b.N; i++ {
writeSink = rt.get(uint8(i))
writeSink, _ = rt.get(uint8(i))
}
gets := float64(b.N)
elapsedSec := b.Elapsed().Seconds()
@@ -318,7 +330,7 @@ type slowTable[T any] struct {
type slowEntry[T any] struct {
addr uint8
len int
val *T
val T
}
func (t *slowTable[T]) String() string {
@@ -331,13 +343,14 @@ func (t *slowTable[T]) String() string {
})
var ret bytes.Buffer
for _, pfx := range pfxs {
fmt.Fprintf(&ret, "%3d/%d (%08b/%08b) = %v\n", pfx.addr, pfx.len, pfx.addr, pfxMask(pfx.len), *pfx.val)
fmt.Fprintf(&ret, "%3d/%d (%08b/%08b) = %v\n", pfx.addr, pfx.len, pfx.addr, pfxMask(pfx.len), pfx.val)
}
return ret.String()
}
func (t *slowTable[T]) insert(addr uint8, prefixLen int, val *T) {
func (t *slowTable[T]) insert(addr uint8, prefixLen int, val T) {
t.delete(addr, prefixLen) // no-op if prefix doesn't exist
t.prefixes = append(t.prefixes, slowEntry[T]{addr, prefixLen, val})
}
@@ -352,18 +365,15 @@ func (t *slowTable[T]) delete(addr uint8, prefixLen int) {
t.prefixes = pfx
}
func (t *slowTable[T]) get(addr uint8) *T {
var (
ret *T
curLen = -1
)
func (t *slowTable[T]) get(addr uint8) (ret T, ok bool) {
var curLen = -1
for _, e := range t.prefixes {
if addr&pfxMask(e.len) == e.addr && e.len >= curLen {
ret = e.val
curLen = e.len
}
}
return ret
return ret, curLen != -1
}
func pfxMask(pfxLen int) uint8 {
@@ -374,7 +384,7 @@ func allPrefixes() []slowEntry[int] {
ret := make([]slowEntry[int], 0, lastHostIndex)
for i := 1; i < lastHostIndex+1; i++ {
a, l := inversePrefixIndex(i)
ret = append(ret, slowEntry[int]{a, l, ptr.To(i)})
ret = append(ret, slowEntry[int]{a, l, i})
}
return ret
}
@@ -393,6 +403,15 @@ func formatSlowEntriesShort[T any](ents []slowEntry[T]) string {
}
var cmpDiffOpts = []cmp.Option{
cmp.AllowUnexported(strideTable[int]{}, strideEntry[int]{}),
cmp.Comparer(func(a, b netip.Prefix) bool { return a == b }),
}
func getsEqual[T comparable](a T, aOK bool, b T, bOK bool) bool {
if !aOK && !bOK {
return true
}
if aOK != bOK {
return false
}
return a == b
}