wgengine/netlog: merge connstats into package (#17557)

Merge the connstats package into the netlog package
and unexport all of its declarations.

Remove the buildfeatures.HasConnStats and use HasNetLog instead.

Updates tailscale/corp#33352

Signed-off-by: Joe Tsai <joetsai@digital-static.net>
This commit is contained in:
Joe Tsai
2025-10-16 00:07:29 -07:00
committed by GitHub
parent e75f13bd93
commit e804b64358
14 changed files with 43 additions and 104 deletions
-224
View File
@@ -1,224 +0,0 @@
// Copyright (c) Tailscale Inc & AUTHORS
// SPDX-License-Identifier: BSD-3-Clause
//go:build !ts_omit_connstats
// Package connstats maintains statistics about connections
// flowing through a TUN device (which operate at the IP layer).
package connstats
import (
"context"
"net/netip"
"sync"
"time"
"golang.org/x/sync/errgroup"
"tailscale.com/net/packet"
"tailscale.com/net/tsaddr"
"tailscale.com/types/ipproto"
"tailscale.com/types/netlogtype"
)
// Statistics maintains counters for every connection.
// All methods are safe for concurrent use.
// The zero value is ready for use.
type Statistics struct {
maxConns int // immutable once set
mu sync.Mutex
connCnts
connCntsCh chan connCnts
shutdownCtx context.Context
shutdown context.CancelFunc
group errgroup.Group
}
type connCnts struct {
start time.Time
end time.Time
virtual map[netlogtype.Connection]netlogtype.Counts
physical map[netlogtype.Connection]netlogtype.Counts
}
// NewStatistics creates a data structure for tracking connection statistics
// that periodically dumps the virtual and physical connection counts
// depending on whether the maxPeriod or maxConns is exceeded.
// The dump function is called from a single goroutine.
// Shutdown must be called to cleanup resources.
func NewStatistics(maxPeriod time.Duration, maxConns int, dump func(start, end time.Time, virtual, physical map[netlogtype.Connection]netlogtype.Counts)) *Statistics {
s := &Statistics{maxConns: maxConns}
s.connCntsCh = make(chan connCnts, 256)
s.shutdownCtx, s.shutdown = context.WithCancel(context.Background())
s.group.Go(func() error {
// TODO(joetsai): Using a ticker is problematic on mobile platforms
// where waking up a process every maxPeriod when there is no activity
// is a drain on battery life. Switch this instead to instead use
// a time.Timer that is triggered upon network activity.
ticker := new(time.Ticker)
if maxPeriod > 0 {
ticker = time.NewTicker(maxPeriod)
defer ticker.Stop()
}
for {
var cc connCnts
select {
case cc = <-s.connCntsCh:
case <-ticker.C:
cc = s.extract()
case <-s.shutdownCtx.Done():
cc = s.extract()
}
if len(cc.virtual)+len(cc.physical) > 0 && dump != nil {
dump(cc.start, cc.end, cc.virtual, cc.physical)
}
if s.shutdownCtx.Err() != nil {
return nil
}
}
})
return s
}
// UpdateTxVirtual updates the counters for a transmitted IP packet
// The source and destination of the packet directly correspond with
// the source and destination in netlogtype.Connection.
func (s *Statistics) UpdateTxVirtual(b []byte) {
var p packet.Parsed
p.Decode(b)
s.UpdateVirtual(p.IPProto, p.Src, p.Dst, 1, len(b), false)
}
// UpdateRxVirtual updates the counters for a received IP packet.
// The source and destination of the packet are inverted with respect to
// the source and destination in netlogtype.Connection.
func (s *Statistics) UpdateRxVirtual(b []byte) {
var p packet.Parsed
p.Decode(b)
s.UpdateVirtual(p.IPProto, p.Dst, p.Src, 1, len(b), true)
}
var (
tailscaleServiceIPv4 = tsaddr.TailscaleServiceIP()
tailscaleServiceIPv6 = tsaddr.TailscaleServiceIPv6()
)
func (s *Statistics) UpdateVirtual(proto ipproto.Proto, src, dst netip.AddrPort, packets, bytes int, receive bool) {
// Network logging is defined as traffic between two Tailscale nodes.
// Traffic with the internal Tailscale service is not with another node
// and should not be logged. It also happens to be a high volume
// amount of discrete traffic flows (e.g., DNS lookups).
switch dst.Addr() {
case tailscaleServiceIPv4, tailscaleServiceIPv6:
return
}
conn := netlogtype.Connection{Proto: proto, Src: src, Dst: dst}
s.mu.Lock()
defer s.mu.Unlock()
cnts, found := s.virtual[conn]
if !found && !s.preInsertConn() {
return
}
if receive {
cnts.RxPackets += uint64(packets)
cnts.RxBytes += uint64(bytes)
} else {
cnts.TxPackets += uint64(packets)
cnts.TxBytes += uint64(bytes)
}
s.virtual[conn] = cnts
}
// UpdateTxPhysical updates the counters for zero or more transmitted wireguard packets.
// The src is always a Tailscale IP address, representing some remote peer.
// The dst is a remote IP address and port that corresponds
// with some physical peer backing the Tailscale IP address.
func (s *Statistics) UpdateTxPhysical(src netip.Addr, dst netip.AddrPort, packets, bytes int) {
s.UpdatePhysical(0, netip.AddrPortFrom(src, 0), dst, packets, bytes, false)
}
// UpdateRxPhysical updates the counters for zero or more received wireguard packets.
// The src is always a Tailscale IP address, representing some remote peer.
// The dst is a remote IP address and port that corresponds
// with some physical peer backing the Tailscale IP address.
func (s *Statistics) UpdateRxPhysical(src netip.Addr, dst netip.AddrPort, packets, bytes int) {
s.UpdatePhysical(0, netip.AddrPortFrom(src, 0), dst, packets, bytes, true)
}
func (s *Statistics) UpdatePhysical(proto ipproto.Proto, src, dst netip.AddrPort, packets, bytes int, receive bool) {
conn := netlogtype.Connection{Proto: proto, Src: src, Dst: dst}
s.mu.Lock()
defer s.mu.Unlock()
cnts, found := s.physical[conn]
if !found && !s.preInsertConn() {
return
}
if receive {
cnts.RxPackets += uint64(packets)
cnts.RxBytes += uint64(bytes)
} else {
cnts.TxPackets += uint64(packets)
cnts.TxBytes += uint64(bytes)
}
s.physical[conn] = cnts
}
// preInsertConn updates the maps to handle insertion of a new connection.
// It reports false if insertion is not allowed (i.e., after shutdown).
func (s *Statistics) preInsertConn() bool {
// Check whether insertion of a new connection will exceed maxConns.
if len(s.virtual)+len(s.physical) == s.maxConns && s.maxConns > 0 {
// Extract the current statistics and send it to the serializer.
// Avoid blocking the network packet handling path.
select {
case s.connCntsCh <- s.extractLocked():
default:
// TODO(joetsai): Log that we are dropping an entire connCounts.
}
}
// Initialize the maps if nil.
if s.virtual == nil && s.physical == nil {
s.start = time.Now().UTC()
s.virtual = make(map[netlogtype.Connection]netlogtype.Counts)
s.physical = make(map[netlogtype.Connection]netlogtype.Counts)
}
return s.shutdownCtx.Err() == nil
}
func (s *Statistics) extract() connCnts {
s.mu.Lock()
defer s.mu.Unlock()
return s.extractLocked()
}
func (s *Statistics) extractLocked() connCnts {
if len(s.virtual)+len(s.physical) == 0 {
return connCnts{}
}
s.end = time.Now().UTC()
cc := s.connCnts
s.connCnts = connCnts{}
return cc
}
// TestExtract synchronously extracts the current network statistics map
// and resets the counters. This should only be used for testing purposes.
func (s *Statistics) TestExtract() (virtual, physical map[netlogtype.Connection]netlogtype.Counts) {
cc := s.extract()
return cc.virtual, cc.physical
}
// Shutdown performs a final flush of statistics.
// Statistics for any subsequent calls to Update will be dropped.
// It is safe to call Shutdown concurrently and repeatedly.
func (s *Statistics) Shutdown(context.Context) error {
s.shutdown()
return s.group.Wait()
}
-24
View File
@@ -1,24 +0,0 @@
// Copyright (c) Tailscale Inc & AUTHORS
// SPDX-License-Identifier: BSD-3-Clause
//go:build ts_omit_connstats
package connstats
import (
"context"
"net/netip"
"time"
)
type Statistics struct{}
func NewStatistics(maxPeriod time.Duration, maxConns int, dump func(start, end time.Time, virtual, physical any)) *Statistics {
return &Statistics{}
}
func (s *Statistics) UpdateTxVirtual(b []byte) {}
func (s *Statistics) UpdateRxVirtual(b []byte) {}
func (s *Statistics) UpdateTxPhysical(src netip.Addr, dst netip.AddrPort, packets, bytes int) {}
func (s *Statistics) UpdateRxPhysical(src netip.Addr, dst netip.AddrPort, packets, bytes int) {}
func (s *Statistics) Shutdown(context.Context) error { return nil }
-235
View File
@@ -1,235 +0,0 @@
// Copyright (c) Tailscale Inc & AUTHORS
// SPDX-License-Identifier: BSD-3-Clause
package connstats
import (
"context"
"encoding/binary"
"fmt"
"math/rand"
"net/netip"
"runtime"
"sync"
"testing"
"time"
qt "github.com/frankban/quicktest"
"tailscale.com/cmd/testwrapper/flakytest"
"tailscale.com/types/ipproto"
"tailscale.com/types/netlogtype"
)
func testPacketV4(proto ipproto.Proto, srcAddr, dstAddr [4]byte, srcPort, dstPort, size uint16) (out []byte) {
var ipHdr [20]byte
ipHdr[0] = 4<<4 | 5
binary.BigEndian.PutUint16(ipHdr[2:], size)
ipHdr[9] = byte(proto)
*(*[4]byte)(ipHdr[12:]) = srcAddr
*(*[4]byte)(ipHdr[16:]) = dstAddr
out = append(out, ipHdr[:]...)
switch proto {
case ipproto.TCP:
var tcpHdr [20]byte
binary.BigEndian.PutUint16(tcpHdr[0:], srcPort)
binary.BigEndian.PutUint16(tcpHdr[2:], dstPort)
out = append(out, tcpHdr[:]...)
case ipproto.UDP:
var udpHdr [8]byte
binary.BigEndian.PutUint16(udpHdr[0:], srcPort)
binary.BigEndian.PutUint16(udpHdr[2:], dstPort)
out = append(out, udpHdr[:]...)
default:
panic(fmt.Sprintf("unknown proto: %d", proto))
}
return append(out, make([]byte, int(size)-len(out))...)
}
// TestInterval ensures that we receive at least one call to `dump` using only
// maxPeriod.
func TestInterval(t *testing.T) {
c := qt.New(t)
const maxPeriod = 10 * time.Millisecond
const maxConns = 2048
gotDump := make(chan struct{}, 1)
stats := NewStatistics(maxPeriod, maxConns, func(_, _ time.Time, _, _ map[netlogtype.Connection]netlogtype.Counts) {
select {
case gotDump <- struct{}{}:
default:
}
})
defer stats.Shutdown(context.Background())
srcAddr := netip.AddrFrom4([4]byte{192, 168, 0, byte(rand.Intn(16))})
dstAddr := netip.AddrFrom4([4]byte{192, 168, 0, byte(rand.Intn(16))})
srcPort := uint16(rand.Intn(16))
dstPort := uint16(rand.Intn(16))
size := uint16(64 + rand.Intn(1024))
p := testPacketV4(ipproto.TCP, srcAddr.As4(), dstAddr.As4(), srcPort, dstPort, size)
stats.UpdateRxVirtual(p)
ctx, cancel := context.WithTimeout(context.Background(), time.Second)
defer cancel()
select {
case <-ctx.Done():
c.Fatal("didn't receive dump within context deadline")
case <-gotDump:
}
}
func TestConcurrent(t *testing.T) {
flakytest.Mark(t, "https://github.com/tailscale/tailscale/issues/7030")
c := qt.New(t)
const maxPeriod = 10 * time.Millisecond
const maxConns = 10
virtualAggregate := make(map[netlogtype.Connection]netlogtype.Counts)
stats := NewStatistics(maxPeriod, maxConns, func(start, end time.Time, virtual, physical map[netlogtype.Connection]netlogtype.Counts) {
c.Assert(start.IsZero(), qt.IsFalse)
c.Assert(end.IsZero(), qt.IsFalse)
c.Assert(end.Before(start), qt.IsFalse)
c.Assert(len(virtual) > 0 && len(virtual) <= maxConns, qt.IsTrue)
c.Assert(len(physical) == 0, qt.IsTrue)
for conn, cnts := range virtual {
virtualAggregate[conn] = virtualAggregate[conn].Add(cnts)
}
})
defer stats.Shutdown(context.Background())
var wants []map[netlogtype.Connection]netlogtype.Counts
gots := make([]map[netlogtype.Connection]netlogtype.Counts, runtime.NumCPU())
var group sync.WaitGroup
for i := range gots {
group.Add(1)
go func(i int) {
defer group.Done()
gots[i] = make(map[netlogtype.Connection]netlogtype.Counts)
rn := rand.New(rand.NewSource(time.Now().UnixNano()))
var p []byte
var t netlogtype.Connection
for j := 0; j < 1000; j++ {
delay := rn.Intn(10000)
if p == nil || rn.Intn(64) == 0 {
proto := ipproto.TCP
if rn.Intn(2) == 0 {
proto = ipproto.UDP
}
srcAddr := netip.AddrFrom4([4]byte{192, 168, 0, byte(rand.Intn(16))})
dstAddr := netip.AddrFrom4([4]byte{192, 168, 0, byte(rand.Intn(16))})
srcPort := uint16(rand.Intn(16))
dstPort := uint16(rand.Intn(16))
size := uint16(64 + rand.Intn(1024))
p = testPacketV4(proto, srcAddr.As4(), dstAddr.As4(), srcPort, dstPort, size)
t = netlogtype.Connection{Proto: proto, Src: netip.AddrPortFrom(srcAddr, srcPort), Dst: netip.AddrPortFrom(dstAddr, dstPort)}
}
t2 := t
receive := rn.Intn(2) == 0
if receive {
t2.Src, t2.Dst = t2.Dst, t2.Src
}
cnts := gots[i][t2]
if receive {
stats.UpdateRxVirtual(p)
cnts.RxPackets++
cnts.RxBytes += uint64(len(p))
} else {
cnts.TxPackets++
cnts.TxBytes += uint64(len(p))
stats.UpdateTxVirtual(p)
}
gots[i][t2] = cnts
time.Sleep(time.Duration(rn.Intn(1 + delay)))
}
}(i)
}
group.Wait()
c.Assert(stats.Shutdown(context.Background()), qt.IsNil)
wants = append(wants, virtualAggregate)
got := make(map[netlogtype.Connection]netlogtype.Counts)
want := make(map[netlogtype.Connection]netlogtype.Counts)
mergeMaps(got, gots...)
mergeMaps(want, wants...)
c.Assert(got, qt.DeepEquals, want)
}
func mergeMaps(dst map[netlogtype.Connection]netlogtype.Counts, srcs ...map[netlogtype.Connection]netlogtype.Counts) {
for _, src := range srcs {
for conn, cnts := range src {
dst[conn] = dst[conn].Add(cnts)
}
}
}
func Benchmark(b *testing.B) {
// TODO: Test IPv6 packets?
b.Run("SingleRoutine/SameConn", func(b *testing.B) {
p := testPacketV4(ipproto.UDP, [4]byte{192, 168, 0, 1}, [4]byte{192, 168, 0, 2}, 123, 456, 789)
b.ResetTimer()
b.ReportAllocs()
for range b.N {
s := NewStatistics(0, 0, nil)
for j := 0; j < 1e3; j++ {
s.UpdateTxVirtual(p)
}
}
})
b.Run("SingleRoutine/UniqueConns", func(b *testing.B) {
p := testPacketV4(ipproto.UDP, [4]byte{}, [4]byte{}, 0, 0, 789)
b.ResetTimer()
b.ReportAllocs()
for range b.N {
s := NewStatistics(0, 0, nil)
for j := 0; j < 1e3; j++ {
binary.BigEndian.PutUint32(p[20:], uint32(j)) // unique port combination
s.UpdateTxVirtual(p)
}
}
})
b.Run("MultiRoutine/SameConn", func(b *testing.B) {
p := testPacketV4(ipproto.UDP, [4]byte{192, 168, 0, 1}, [4]byte{192, 168, 0, 2}, 123, 456, 789)
b.ResetTimer()
b.ReportAllocs()
for range b.N {
s := NewStatistics(0, 0, nil)
var group sync.WaitGroup
for j := 0; j < runtime.NumCPU(); j++ {
group.Add(1)
go func() {
defer group.Done()
for k := 0; k < 1e3; k++ {
s.UpdateTxVirtual(p)
}
}()
}
group.Wait()
}
})
b.Run("MultiRoutine/UniqueConns", func(b *testing.B) {
ps := make([][]byte, runtime.NumCPU())
for i := range ps {
ps[i] = testPacketV4(ipproto.UDP, [4]byte{192, 168, 0, 1}, [4]byte{192, 168, 0, 2}, 0, 0, 789)
}
b.ResetTimer()
b.ReportAllocs()
for range b.N {
s := NewStatistics(0, 0, nil)
var group sync.WaitGroup
for j := 0; j < runtime.NumCPU(); j++ {
group.Add(1)
go func(j int) {
defer group.Done()
p := ps[j]
j *= 1e3
for k := 0; k < 1e3; k++ {
binary.BigEndian.PutUint32(p[20:], uint32(j+k)) // unique port combination
s.UpdateTxVirtual(p)
}
}(j)
}
group.Wait()
}
})
}
+4 -4
View File
@@ -976,7 +976,7 @@ func (t *Wrapper) Read(buffs [][]byte, sizes []int, offset int) (int, error) {
panic(fmt.Sprintf("short copy: %d != %d", n, len(data)-res.dataOffset))
}
sizes[buffsPos] = n
if buildfeatures.HasConnStats {
if buildfeatures.HasNetLog {
if update := t.connCounter.Load(); update != nil {
updateConnCounter(update, p.Buffer(), false)
}
@@ -1105,7 +1105,7 @@ func (t *Wrapper) injectedRead(res tunInjectedRead, outBuffs [][]byte, sizes []i
n, err = tun.GSOSplit(pkt, gsoOptions, outBuffs, sizes, offset)
}
if buildfeatures.HasConnStats {
if buildfeatures.HasNetLog {
if update := t.connCounter.Load(); update != nil {
for i := 0; i < n; i++ {
updateConnCounter(update, outBuffs[i][offset:offset+sizes[i]], false)
@@ -1275,7 +1275,7 @@ func (t *Wrapper) Write(buffs [][]byte, offset int) (int, error) {
}
func (t *Wrapper) tdevWrite(buffs [][]byte, offset int) (int, error) {
if buildfeatures.HasConnStats {
if buildfeatures.HasNetLog {
if update := t.connCounter.Load(); update != nil {
for i := range buffs {
updateConnCounter(update, buffs[i][offset:], true)
@@ -1501,7 +1501,7 @@ func (t *Wrapper) Unwrap() tun.Device {
// SetConnectionCounter specifies a per-connection statistics aggregator.
// Nil may be specified to disable statistics gathering.
func (t *Wrapper) SetConnectionCounter(fn netlogfunc.ConnectionCounter) {
if buildfeatures.HasConnStats {
if buildfeatures.HasNetLog {
t.connCounter.Store(fn)
}
}
+1 -1
View File
@@ -380,7 +380,7 @@ func TestFilter(t *testing.T) {
tunStats := stats.Clone()
stats.Reset()
if len(tunStats) > 0 {
t.Errorf("connstats.Statistics.Extract = %v, want {}", tunStats)
t.Errorf("netlogtype.CountsByConnection = %v, want {}", tunStats)
}
if tt.dir == in {