wgengine/magicsock: add address selection for wireguard only endpoints (#7979)

This change introduces address selection for wireguard only endpoints.
If an endpoint has not been used before, an address is randomly selected
to be used based on information we know about, such as if they are able
to use IPv4 or IPv6. When an address is initially selected, we also
initiate a new ICMP ping to the endpoint's addresses to determine which
endpoint offers the best latency. This information is then used to
update which endpoint we should be using based on the best possible
route. If the latency is the same for an IPv4 and an IPv6 address, IPv6
will be used.

Updates #7826

Signed-off-by: Charlotte Brandhorst-Satzkorn <charlotte@tailscale.com>
This commit is contained in:
Charlotte Brandhorst-Satzkorn
2023-05-02 17:49:56 -07:00
committed by GitHub
parent c1e6888fc7
commit ddb4040aa0
3 changed files with 714 additions and 74 deletions
+474 -14
View File
@@ -23,6 +23,7 @@ import (
"strconv"
"strings"
"sync"
"sync/atomic"
"testing"
"time"
"unsafe"
@@ -33,6 +34,8 @@ import (
"go4.org/mem"
"golang.org/x/exp/maps"
"golang.org/x/exp/slices"
"golang.org/x/net/icmp"
"golang.org/x/net/ipv4"
"golang.org/x/net/ipv6"
"tailscale.com/cmd/testwrapper/flakytest"
"tailscale.com/derp"
@@ -42,11 +45,13 @@ import (
"tailscale.com/net/connstats"
"tailscale.com/net/netaddr"
"tailscale.com/net/packet"
"tailscale.com/net/ping"
"tailscale.com/net/stun/stuntest"
"tailscale.com/net/tstun"
"tailscale.com/tailcfg"
"tailscale.com/tstest"
"tailscale.com/tstest/natlab"
"tailscale.com/tstime/mono"
"tailscale.com/types/key"
"tailscale.com/types/logger"
"tailscale.com/types/netlogtype"
@@ -2117,9 +2122,8 @@ func Test_batchingUDPConn_coalesceMessages(t *testing.T) {
}
// newWireguard starts up a new wireguard-go device attached to a test tun, and
// returns the device, tun and netpoint address. To add peers call device.IpcSet
// with UAPI instructions.
func newWireguard(t *testing.T, uapi string, aips []netip.Prefix) (*device.Device, *tuntest.ChannelTUN, netip.AddrPort) {
// returns the device, tun and endpoint port. To add peers call device.IpcSet with UAPI instructions.
func newWireguard(t *testing.T, uapi string, aips []netip.Prefix) (*device.Device, *tuntest.ChannelTUN, uint16) {
wgtun := tuntest.NewChannelTUN()
wglogf := func(f string, args ...any) {
t.Logf("wg-go: "+f, args...)
@@ -2138,8 +2142,7 @@ func newWireguard(t *testing.T, uapi string, aips []netip.Prefix) (*device.Devic
t.Fatal(err)
}
var wgEp netip.AddrPort
var port uint16
s, err := wgdev.IpcGet()
if err != nil {
t.Fatal(err)
@@ -2151,17 +2154,16 @@ func newWireguard(t *testing.T, uapi string, aips []netip.Prefix) (*device.Devic
}
k, v, _ := strings.Cut(line, "=")
if k == "listen_port" {
wgEp = netip.MustParseAddrPort("127.0.0.1:" + v)
p, err := strconv.ParseUint(v, 10, 16)
if err != nil {
panic(err)
}
port = uint16(p)
break
}
}
if !wgEp.IsValid() {
t.Fatalf("failed to get endpoint out of wg-go")
}
t.Logf("wg-go endpoint: %s", wgEp)
return wgdev, wgtun, wgEp
return wgdev, wgtun, port
}
func TestIsWireGuardOnlyPeer(t *testing.T) {
@@ -2176,8 +2178,9 @@ func TestIsWireGuardOnlyPeer(t *testing.T) {
uapi := fmt.Sprintf("private_key=%s\npublic_key=%s\nallowed_ip=%s\n\n",
wgkey.UntypedHexString(), tskey.Public().UntypedHexString(), tsaip.String())
wgdev, wgtun, wgEp := newWireguard(t, uapi, []netip.Prefix{wgaip})
wgdev, wgtun, port := newWireguard(t, uapi, []netip.Prefix{wgaip})
defer wgdev.Close()
wgEp := netip.AddrPortFrom(netip.MustParseAddr("127.0.0.1"), port)
m := newMagicStackWithKey(t, t.Logf, localhostListener{}, derpMap, tskey)
defer m.Close()
@@ -2233,8 +2236,9 @@ func TestIsWireGuardOnlyPeerWithMasquerade(t *testing.T) {
uapi := fmt.Sprintf("private_key=%s\npublic_key=%s\nallowed_ip=%s\n\n",
wgkey.UntypedHexString(), tskey.Public().UntypedHexString(), masqip.String())
wgdev, wgtun, wgEp := newWireguard(t, uapi, []netip.Prefix{wgaip})
wgdev, wgtun, port := newWireguard(t, uapi, []netip.Prefix{wgaip})
defer wgdev.Close()
wgEp := netip.AddrPortFrom(netip.MustParseAddr("127.0.0.1"), port)
m := newMagicStackWithKey(t, t.Logf, localhostListener{}, derpMap, tskey)
defer m.Close()
@@ -2397,3 +2401,459 @@ func TestEndpointTracker(t *testing.T) {
}
}
}
// applyNetworkMap is a test helper that installs nm on the magicStack's
// connection and then reconfigures the stack's internal WireGuard device from
// that same network map. Any error fails the calling test immediately.
func applyNetworkMap(t *testing.T, m *magicStack, nm *netmap.NetworkMap) {
	t.Helper()

	m.conn.SetNetworkMap(nm)

	// Force IPv6 to be treated as unusable so that it cannot be the cause of
	// test failures.
	m.conn.noV6.Store(true)

	// Derive the WireGuard config for the tailscale internal WireGuard device
	// from the network map.
	wgCfg, err := nmcfg.WGCfg(nm, t.Logf, netmap.AllowSingleHosts|netmap.AllowSubnetRoutes, "")
	if err != nil {
		t.Fatal(err)
	}

	// Push the derived config into the tailscale internal WireGuard device.
	err = m.Reconfig(wgCfg)
	if err != nil {
		t.Fatal(err)
	}
}
// TestIsWireGuardOnlyPickEndpointByPing checks address selection for a
// WireGuard-only peer that advertises two IPv4 endpoints and one IPv6
// endpoint:
//
//  1. The first packet sent to the peer must arrive over a usable (IPv4)
//     address, set a short-lived bestAddr, and record one pong per IPv4
//     endpoint.
//  2. After the recorded latencies are rewritten and the trust window is
//     expired, the next packet must cause the lowest-latency endpoint to be
//     picked as bestAddr with a long trust window.
//
// The IPv6 endpoint must never receive traffic, because applyNetworkMap marks
// IPv6 as unusable.
func TestIsWireGuardOnlyPickEndpointByPing(t *testing.T) {
	clock := &tstest.Clock{}
	derpMap, cleanup := runDERPAndStun(t, t.Logf, localhostListener{}, netaddr.IPv4(127, 0, 0, 1))
	defer cleanup()

	// Create a TS client.
	tskey := key.NewNode()
	tsaip := netip.MustParsePrefix("100.111.222.111/32")

	// Create a WireGuard only client.
	wgkey := key.NewNode()
	wgaip := netip.MustParsePrefix("100.222.111.222/32")

	uapi := fmt.Sprintf("private_key=%s\npublic_key=%s\nallowed_ip=%s\n\n",
		wgkey.UntypedHexString(), tskey.Public().UntypedHexString(), tsaip.String())

	wgdev, wgtun, port := newWireguard(t, uapi, []netip.Prefix{wgaip})
	defer wgdev.Close()
	wgEp := netip.AddrPortFrom(netip.MustParseAddr("127.0.0.1"), port)
	wgEp2 := netip.AddrPortFrom(netip.MustParseAddr("127.0.0.2"), port)

	m := newMagicStackWithKey(t, t.Logf, localhostListener{}, derpMap, tskey)
	defer m.Close()

	pr := newPingResponder(t)
	// Get a destination address which includes a port, so that UDP packets flow
	// to the correct place, the mockPinger will use this to direct port-less
	// pings to this place.
	pingDest := pr.LocalAddr()

	// Create and start the pinger that is used for the
	// wireguard only endpoint pings
	p, closeP := mockPinger(t, clock, pingDest)
	defer closeP()
	m.conn.wgPinger.Set(p)

	// Create an IPv6 endpoint which should not receive any traffic.
	v6, err := net.ListenUDP("udp6", &net.UDPAddr{IP: net.ParseIP("::"), Port: 0})
	if err != nil {
		t.Fatal(err)
	}
	badEpRecv := make(chan []byte)
	go func() {
		defer v6.Close()
		for {
			b := make([]byte, 1500)
			n, _, err := v6.ReadFrom(b)
			if err != nil {
				close(badEpRecv)
				return
			}
			badEpRecv <- b[:n]
		}
	}()
	wgEpV6 := netip.MustParseAddrPort(v6.LocalAddr().String())

	nm := &netmap.NetworkMap{
		Name:       "ts",
		PrivateKey: m.privateKey,
		NodeKey:    m.privateKey.Public(),
		Addresses:  []netip.Prefix{tsaip},
		Peers: []*tailcfg.Node{
			{
				Key:             wgkey.Public(),
				Endpoints:       []string{wgEp.String(), wgEp2.String(), wgEpV6.String()},
				IsWireGuardOnly: true,
				Addresses:       []netip.Prefix{wgaip},
				AllowedIPs:      []netip.Prefix{wgaip},
			},
		},
	}

	applyNetworkMap(t, m, nm)

	// sendAndExpect pushes one ping packet into the tailscale TUN and waits
	// for the identical packet to come out of the WireGuard peer's TUN. It
	// fails the test if the packet shows up on the IPv6 listener instead, or
	// if nothing arrives within recvTimeout.
	const recvTimeout = 5 * time.Second
	sendAndExpect := func() {
		t.Helper()
		buf := tuntest.Ping(wgaip.Addr(), tsaip.Addr())
		m.tun.Outbound <- buf

		select {
		case got := <-wgtun.Inbound:
			if !bytes.Equal(got, buf) {
				t.Errorf("got unexpected packet: %x", got)
			}
		case <-badEpRecv:
			t.Fatal("got packet on bad endpoint")
		case <-time.After(recvTimeout):
			t.Fatalf("no packet after %v", recvTimeout)
		}
	}

	sendAndExpect()

	pi, ok := m.conn.peerMap.byNodeKey[wgkey.Public()]
	if !ok {
		t.Fatal("wgkey doesn't exist in peer map")
	}

	// Check that we got a valid address set on the first send - this
	// will be randomly selected, but because we have noV6 set to true,
	// it will be the IPv4 address.
	if !pi.ep.bestAddr.Addr().IsValid() {
		t.Fatal("bestaddr was nil")
	}

	if pi.ep.trustBestAddrUntil.Before(mono.Now().Add(14 * time.Second)) {
		t.Errorf("trustBestAddrUntil time wasn't set to 15 seconds in the future: got %v", pi.ep.trustBestAddrUntil)
	}

	for ipp, state := range pi.ep.endpointState {
		switch ipp {
		case wgEp, wgEp2:
			// Stop here rather than continue: the latency rewrite below
			// indexes into recentPongs and would panic on an empty slice.
			if len(state.recentPongs) != 1 {
				t.Fatalf("IPv4 address %v did not have a recentPong entry: got %v, want %v", ipp, len(state.recentPongs), 1)
			}
			if ipp == wgEp {
				// Set the latency extremely low so we choose this endpoint
				// during the next addrForSendLocked call.
				state.recentPongs[state.recentPong].latency = time.Nanosecond
			} else {
				// Set the latency extremely high so we don't choose this
				// endpoint during the next addrForSendLocked call.
				state.recentPongs[state.recentPong].latency = time.Second
			}
		case wgEpV6:
			if len(state.recentPongs) != 0 {
				t.Fatal("IPv6 should not have recentPong: IPv6 is not useable")
			}
		}
	}

	// Set trustBestAddrUntil to the past, so addrForSendLocked goes through
	// the latency selection flow.
	pi.ep.trustBestAddrUntil = mono.Now().Add(-time.Second)

	sendAndExpect()

	// Check that we have responded to a WireGuard only ping twice.
	// NOTE(review): responseCount is written by the responder goroutine
	// without synchronization; this read relies on the packet round trip above
	// having completed.
	if pr.responseCount != 2 {
		t.Fatalf("pingresponder response count was not 2: got %d", pr.responseCount)
	}

	pi, ok = m.conn.peerMap.byNodeKey[wgkey.Public()]
	if !ok {
		t.Fatal("wgkey doesn't exist in peer map")
	}

	if !pi.ep.bestAddr.Addr().IsValid() {
		t.Error("no bestAddr address was set")
	}

	if pi.ep.bestAddr.Addr() != wgEp.Addr() {
		t.Errorf("bestAddr was not set to the expected IPv4 address: got %v, want %v", pi.ep.bestAddr.Addr().String(), wgEp.Addr())
	}

	if pi.ep.trustBestAddrUntil.IsZero() {
		t.Fatal("trustBestAddrUntil was not set")
	}

	if pi.ep.trustBestAddrUntil.Before(mono.Now().Add(55 * time.Minute)) {
		// Set to 55 minutes in case of sloooow tests.
		t.Errorf("trustBestAddrUntil time wasn't set to an hour in the future: got %v", pi.ep.trustBestAddrUntil)
	}
}
// udpingPacketConn wraps a net.PacketConn and, in WriteTo, rewrites
// *net.IPAddr destinations (the kind used for ICMP) into *net.UDPAddr
// destinations. A test that means to send ICMP traffic therefore sends UDP
// traffic instead, without the higher level Pinger noticing the difference.
type udpingPacketConn struct {
	net.PacketConn

	// destPort is the UDP port placed into every rewritten destination. The
	// test configures it to be the port of the peer expected to respond to a
	// ping.
	destPort uint16
}

// WriteTo sends body through the wrapped connection to the IP and zone of
// dest, using destPort as the UDP port. Only *net.IPAddr destinations are
// supported; any other address type yields an error and nothing is written.
func (u *udpingPacketConn) WriteTo(body []byte, dest net.Addr) (int, error) {
	ipAddr, isIP := dest.(*net.IPAddr)
	if !isIP {
		return 0, fmt.Errorf("unimplemented udpingPacketConn for %T", dest)
	}

	target := &net.UDPAddr{
		IP:   ipAddr.IP,
		Port: int(u.destPort),
		Zone: ipAddr.Zone,
	}
	return u.PacketConn.WriteTo(body, target)
}
// mockListenPacketer hands out pre-built packet connections in place of real
// ICMP sockets: conn4 for the "ip4:icmp" network and conn6 for "ip6:icmp".
type mockListenPacketer struct {
	conn4, conn6 net.PacketConn
}

// ListenPacket returns the stored connection matching typ. The ctx and addr
// arguments are ignored. Any network other than "ip4:icmp" or "ip6:icmp"
// results in an error.
func (mlp *mockListenPacketer) ListenPacket(ctx context.Context, typ string, addr string) (net.PacketConn, error) {
	if typ == "ip4:icmp" {
		return mlp.conn4, nil
	}
	if typ == "ip6:icmp" {
		return mlp.conn6, nil
	}
	return nil, fmt.Errorf("unimplemented ListenPacketForTesting for %s", typ)
}
// mockPinger builds a ping.Pinger whose "ICMP" sockets are really UDP sockets
// that send every ping to the port of dest. It returns the pinger together
// with a function that closes it, reporting a close error as a test error.
//
// The clock argument is accepted but not used by this helper.
func mockPinger(t *testing.T, clock *tstest.Clock, dest net.Addr) (*ping.Pinger, func()) {
	// Only the port of dest is needed; the IP comes from each ping's target.
	destPort := netip.MustParseAddrPort(dest.String()).Port()

	// UDP is used in tests so that they can run without being root. The
	// responder on the other side still answers with a real ICMP echo reply
	// payload, so the Pinger cannot tell the difference.
	udp4, err := net.ListenPacket("udp4", "127.0.0.1:0")
	if err != nil {
		t.Fatalf("net.ListenPacket: %v", err)
	}

	udp6, err := net.ListenPacket("udp6", "[::]:0")
	if err != nil {
		t.Fatalf("net.ListenPacket: %v", err)
	}

	listener := &mockListenPacketer{
		conn4: &udpingPacketConn{PacketConn: udp4, destPort: destPort},
		conn6: &udpingPacketConn{PacketConn: udp6, destPort: destPort},
	}

	pinger := ping.New(context.Background(), t.Logf, listener)

	return pinger, func() {
		if err := pinger.Close(); err != nil {
			t.Errorf("error on close: %v", err)
		}
	}
}
// pingResponder reads datagrams from its embedded PacketConn, parses each one
// as an ICMP message and writes an ICMP echo reply with the same code and
// body back to the sender.
type pingResponder struct {
	net.PacketConn

	// running is set before start is launched and cleared by stop; start
	// checks it before each read.
	running atomic.Bool

	// responseCount is the number of replies written so far. It is
	// incremented by start without any synchronization.
	// NOTE(review): reading it from another goroutine while start is running
	// is a data race; callers should only read it once traffic has settled.
	responseCount int
}

// start runs the responder loop until running is cleared or reading from the
// connection fails (which is what happens once stop closes it). It panics if
// a received datagram is not a parseable ICMP message, or if a reply cannot
// be marshaled or written.
func (p *pingResponder) start() {
	buf := make([]byte, 1500)

	for p.running.Load() {
		n, addr, err := p.PacketConn.ReadFrom(buf)
		if err != nil {
			return
		}

		// 1 is the IANA protocol number for ICMP over IPv4.
		m, err := icmp.ParseMessage(1, buf[:n])
		if err != nil {
			// m is nil when parsing fails, so report the raw datagram and the
			// parse error instead of formatting m.
			panic(fmt.Sprintf("got a non-ICMP message: %x: %v", buf[:n], err))
		}

		r := icmp.Message{
			Type: ipv4.ICMPTypeEchoReply,
			Code: m.Code,
			Body: m.Body,
		}

		b, err := r.Marshal(nil)
		if err != nil {
			panic(err)
		}

		if _, err := p.PacketConn.WriteTo(b, addr); err != nil {
			panic(err)
		}

		p.responseCount++
	}
}

// stop clears the running flag and closes the underlying connection, which
// unblocks a pending read in start so that its loop exits.
func (p *pingResponder) stop() {
	p.running.Store(false)
	p.Close()
}
// newPingResponder opens a UDP socket on the wildcard address with a
// system-chosen port, starts a pingResponder serving it in a new goroutine,
// and registers the responder's stop method as a test cleanup. The test fails
// immediately if the socket cannot be opened.
func newPingResponder(t *testing.T) *pingResponder {
	t.Helper()

	// global binds should be both IPv4 and IPv6 (if our test platforms don't,
	// we might need to bind two sockets instead)
	pc, err := net.ListenPacket("udp", ":")
	if err != nil {
		t.Fatal(err)
	}

	responder := &pingResponder{PacketConn: pc}
	// The flag must be set before the loop starts, or start would exit at
	// once.
	responder.running.Store(true)

	go responder.start()
	t.Cleanup(responder.stop)

	return responder
}
// TestAddrForSendLockedForWireGuardOnly exercises addrForSendLocked on
// WireGuard-only endpoints. Each case runs as its own subtest, in two phases:
//
//  1. With no latency information recorded, the call must return some valid
//     address and report whether a WireGuard ping should be sent
//     (sendWGPing).
//  2. After the per-endpoint latencies from the case are recorded, a call two
//     minutes later must return the expected address (want), must not ask
//     for another ping, and must have stored that address as bestAddr.
func TestAddrForSendLockedForWireGuardOnly(t *testing.T) {
	testTime := mono.Now()

	type endpointDetails struct {
		addrPort netip.AddrPort
		latency  time.Duration
	}

	wgTests := []struct {
		name       string
		noV4       bool
		noV6       bool
		sendWGPing bool
		ep         []endpointDetails
		want       netip.AddrPort
	}{
		{
			name:       "choose lowest latency for useable IPv4 and IPv6",
			sendWGPing: true,
			ep: []endpointDetails{
				{
					addrPort: netip.MustParseAddrPort("1.1.1.1:111"),
					latency:  100 * time.Millisecond,
				},
				{
					addrPort: netip.MustParseAddrPort("[2345:0425:2CA1:0000:0000:0567:5673:23b5]:222"),
					latency:  10 * time.Millisecond,
				},
			},
			want: netip.MustParseAddrPort("[2345:0425:2CA1:0000:0000:0567:5673:23b5]:222"),
		},
		{
			name:       "choose IPv4 when IPv6 is not useable",
			sendWGPing: false,
			noV6:       true,
			ep: []endpointDetails{
				{
					addrPort: netip.MustParseAddrPort("1.1.1.1:111"),
					latency:  100 * time.Millisecond,
				},
				{
					addrPort: netip.MustParseAddrPort("[1::1]:567"),
				},
			},
			want: netip.MustParseAddrPort("1.1.1.1:111"),
		},
		{
			name:       "choose IPv6 when IPv4 is not useable",
			sendWGPing: false,
			noV4:       true,
			ep: []endpointDetails{
				{
					addrPort: netip.MustParseAddrPort("1.1.1.1:111"),
				},
				{
					addrPort: netip.MustParseAddrPort("[1::1]:567"),
					latency:  100 * time.Millisecond,
				},
			},
			want: netip.MustParseAddrPort("[1::1]:567"),
		},
		{
			name:       "choose IPv6 address when latency is the same for v4 and v6",
			sendWGPing: true,
			ep: []endpointDetails{
				{
					addrPort: netip.MustParseAddrPort("1.1.1.1:111"),
					latency:  100 * time.Millisecond,
				},
				{
					addrPort: netip.MustParseAddrPort("[1::1]:567"),
					latency:  100 * time.Millisecond,
				},
			},
			want: netip.MustParseAddrPort("[1::1]:567"),
		},
	}

	for _, test := range wgTests {
		// Run every case as a named subtest so a failure identifies the case
		// and does not stop the remaining cases from running.
		t.Run(test.name, func(t *testing.T) {
			de := &endpoint{
				isWireguardOnly: true,
				endpointState:   map[netip.AddrPort]*endpointState{},
				// The zero Conn already has noV4 and noV6 unset; they are
				// configured per case just below.
				c: &Conn{},
			}
			de.c.noV4.Store(test.noV4)
			de.c.noV6.Store(test.noV6)

			for _, epd := range test.ep {
				de.endpointState[epd.addrPort] = &endpointState{}
			}

			// Phase 1: no latency data is known yet.
			udpAddr, _, shouldPing := de.addrForSendLocked(testTime)
			if !udpAddr.IsValid() {
				t.Error("udpAddr returned is not valid")
			}
			if shouldPing != test.sendWGPing {
				t.Errorf("addrForSendLocked did not indicate correct ping state; got %v, want %v", shouldPing, test.sendWGPing)
			}

			// Record the latency for each endpoint as a single pong reply.
			for _, epd := range test.ep {
				state, ok := de.endpointState[epd.addrPort]
				if !ok {
					// A missing entry means state is nil; stop before it is
					// dereferenced below.
					t.Fatalf("addr %v does not exist in endpoint state map", epd.addrPort)
				}
				if latency, ok := state.latencyLocked(); ok {
					t.Errorf("latency was set for %v: %v", epd.addrPort, latency)
				}
				state.recentPongs = append(state.recentPongs, pongReply{
					latency: epd.latency,
				})
				state.recentPong = 0
			}

			// Phase 2: latency data is available and time has moved on.
			udpAddr, _, shouldPing = de.addrForSendLocked(testTime.Add(2 * time.Minute))
			if udpAddr != test.want {
				t.Errorf("udpAddr returned is not expected: got %v, want %v", udpAddr, test.want)
			}
			if shouldPing {
				t.Error("addrForSendLocked should not indicate ping is required")
			}
			if de.bestAddr.AddrPort != test.want {
				t.Errorf("bestAddr.AddrPort is not as expected: got %v, want %v", de.bestAddr.AddrPort, test.want)
			}
		})
	}
}