derp: prevent concurrent access to multiForwarder map
Instead of iterating over the map to determine the preferred forwarder on every packet (which could happen concurrently with map mutations), store it separately in an atomic variable. Fixes #6445 Signed-off-by: Anton Tolchanov <anton@tailscale.com>
This commit is contained in:
committed by
Anton Tolchanov
parent
6e33d2da2b
commit
6cc6c70d70
+73
-8
@@ -19,6 +19,7 @@ import (
|
||||
"net"
|
||||
"os"
|
||||
"reflect"
|
||||
"strconv"
|
||||
"sync"
|
||||
"testing"
|
||||
"time"
|
||||
@@ -723,20 +724,14 @@ func TestForwarderRegistration(t *testing.T) {
|
||||
s.AddPacketForwarder(u1, testFwd(100))
|
||||
s.AddPacketForwarder(u1, testFwd(100)) // dup to trigger dup path
|
||||
want(map[key.NodePublic]PacketForwarder{
|
||||
u1: multiForwarder{
|
||||
testFwd(1): 1,
|
||||
testFwd(100): 2,
|
||||
},
|
||||
u1: newMultiForwarder(testFwd(1), testFwd(100)),
|
||||
})
|
||||
wantCounter(&s.multiForwarderCreated, 1)
|
||||
|
||||
// Removing a forwarder in a multi set that doesn't exist; does nothing.
|
||||
s.RemovePacketForwarder(u1, testFwd(55))
|
||||
want(map[key.NodePublic]PacketForwarder{
|
||||
u1: multiForwarder{
|
||||
testFwd(1): 1,
|
||||
testFwd(100): 2,
|
||||
},
|
||||
u1: newMultiForwarder(testFwd(1), testFwd(100)),
|
||||
})
|
||||
|
||||
// Removing a forwarder in a multi set that does exist should collapse it away
|
||||
@@ -785,6 +780,76 @@ func TestForwarderRegistration(t *testing.T) {
|
||||
})
|
||||
}
|
||||
|
||||
type channelFwd struct {
|
||||
// id is to ensure that different instances that reference the
|
||||
// same channel are not equal, as they are used as keys in the
|
||||
// multiForwarder map.
|
||||
id int
|
||||
c chan []byte
|
||||
}
|
||||
|
||||
func (f channelFwd) ForwardPacket(_ key.NodePublic, _ key.NodePublic, packet []byte) error {
|
||||
f.c <- packet
|
||||
return nil
|
||||
}
|
||||
|
||||
func TestMultiForwarder(t *testing.T) {
|
||||
received := 0
|
||||
var wg sync.WaitGroup
|
||||
ch := make(chan []byte)
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
|
||||
s := &Server{
|
||||
clients: make(map[key.NodePublic]clientSet),
|
||||
clientsMesh: map[key.NodePublic]PacketForwarder{},
|
||||
}
|
||||
u := pubAll(1)
|
||||
s.AddPacketForwarder(u, channelFwd{1, ch})
|
||||
|
||||
wg.Add(2)
|
||||
go func() {
|
||||
defer wg.Done()
|
||||
for {
|
||||
select {
|
||||
case <-ch:
|
||||
received += 1
|
||||
case <-ctx.Done():
|
||||
return
|
||||
}
|
||||
}
|
||||
}()
|
||||
go func() {
|
||||
defer wg.Done()
|
||||
for {
|
||||
s.AddPacketForwarder(u, channelFwd{2, ch})
|
||||
s.AddPacketForwarder(u, channelFwd{3, ch})
|
||||
s.RemovePacketForwarder(u, channelFwd{2, ch})
|
||||
s.RemovePacketForwarder(u, channelFwd{1, ch})
|
||||
s.AddPacketForwarder(u, channelFwd{1, ch})
|
||||
s.RemovePacketForwarder(u, channelFwd{3, ch})
|
||||
if ctx.Err() != nil {
|
||||
return
|
||||
}
|
||||
}
|
||||
}()
|
||||
|
||||
// Number of messages is chosen arbitrarily, just for this loop to
|
||||
// run long enough concurrently with {Add,Remove}PacketForwarder loop above.
|
||||
numMsgs := 5000
|
||||
var fwd PacketForwarder
|
||||
for i := 0; i < numMsgs; i++ {
|
||||
s.mu.Lock()
|
||||
fwd = s.clientsMesh[u]
|
||||
s.mu.Unlock()
|
||||
fwd.ForwardPacket(u, u, []byte(strconv.Itoa(i)))
|
||||
}
|
||||
|
||||
cancel()
|
||||
wg.Wait()
|
||||
if received != numMsgs {
|
||||
t.Errorf("expected %d messages to be forwarded; got %d", numMsgs, received)
|
||||
}
|
||||
}
|
||||
func TestMetaCert(t *testing.T) {
|
||||
priv := key.NewNode()
|
||||
pub := priv.Public()
|
||||
|
||||
Reference in New Issue
Block a user