derp: prevent concurrent access to multiForwarder map

Instead of iterating over the map to determine the preferred forwarder
on every packet (which could happen concurrently with map mutations),
store it separately in an atomic variable.

Fixes #6445

Signed-off-by: Anton Tolchanov <anton@tailscale.com>
This commit is contained in:
Anton Tolchanov
2022-11-22 16:13:53 +00:00
committed by Anton Tolchanov
parent 6e33d2da2b
commit 6cc6c70d70
2 changed files with 136 additions and 39 deletions
+73 -8
View File
@@ -19,6 +19,7 @@ import (
"net"
"os"
"reflect"
"strconv"
"sync"
"testing"
"time"
@@ -723,20 +724,14 @@ func TestForwarderRegistration(t *testing.T) {
s.AddPacketForwarder(u1, testFwd(100))
s.AddPacketForwarder(u1, testFwd(100)) // dup to trigger dup path
want(map[key.NodePublic]PacketForwarder{
u1: multiForwarder{
testFwd(1): 1,
testFwd(100): 2,
},
u1: newMultiForwarder(testFwd(1), testFwd(100)),
})
wantCounter(&s.multiForwarderCreated, 1)
// Removing a forwarder in a multi set that doesn't exist; does nothing.
s.RemovePacketForwarder(u1, testFwd(55))
want(map[key.NodePublic]PacketForwarder{
u1: multiForwarder{
testFwd(1): 1,
testFwd(100): 2,
},
u1: newMultiForwarder(testFwd(1), testFwd(100)),
})
// Removing a forwarder in a multi set that does exist should collapse it away
@@ -785,6 +780,76 @@ func TestForwarderRegistration(t *testing.T) {
})
}
type channelFwd struct {
// id is to ensure that different instances that reference the
// same channel are not equal, as they are used as keys in the
// multiForwarder map.
id int
c chan []byte
}
func (f channelFwd) ForwardPacket(_ key.NodePublic, _ key.NodePublic, packet []byte) error {
f.c <- packet
return nil
}
func TestMultiForwarder(t *testing.T) {
received := 0
var wg sync.WaitGroup
ch := make(chan []byte)
ctx, cancel := context.WithCancel(context.Background())
s := &Server{
clients: make(map[key.NodePublic]clientSet),
clientsMesh: map[key.NodePublic]PacketForwarder{},
}
u := pubAll(1)
s.AddPacketForwarder(u, channelFwd{1, ch})
wg.Add(2)
go func() {
defer wg.Done()
for {
select {
case <-ch:
received += 1
case <-ctx.Done():
return
}
}
}()
go func() {
defer wg.Done()
for {
s.AddPacketForwarder(u, channelFwd{2, ch})
s.AddPacketForwarder(u, channelFwd{3, ch})
s.RemovePacketForwarder(u, channelFwd{2, ch})
s.RemovePacketForwarder(u, channelFwd{1, ch})
s.AddPacketForwarder(u, channelFwd{1, ch})
s.RemovePacketForwarder(u, channelFwd{3, ch})
if ctx.Err() != nil {
return
}
}
}()
// Number of messages is chosen arbitrarily, just for this loop to
// run long enough concurrently with {Add,Remove}PacketForwarder loop above.
numMsgs := 5000
var fwd PacketForwarder
for i := 0; i < numMsgs; i++ {
s.mu.Lock()
fwd = s.clientsMesh[u]
s.mu.Unlock()
fwd.ForwardPacket(u, u, []byte(strconv.Itoa(i)))
}
cancel()
wg.Wait()
if received != numMsgs {
t.Errorf("expected %d messages to be forwarded; got %d", numMsgs, received)
}
}
func TestMetaCert(t *testing.T) {
priv := key.NewNode()
pub := priv.Public()