|
|
|
|
@ -5,6 +5,7 @@ |
|
|
|
|
package magicsock |
|
|
|
|
|
|
|
|
|
import ( |
|
|
|
|
"bytes" |
|
|
|
|
"fmt" |
|
|
|
|
"log" |
|
|
|
|
"net" |
|
|
|
|
@ -13,6 +14,9 @@ import ( |
|
|
|
|
"testing" |
|
|
|
|
"time" |
|
|
|
|
|
|
|
|
|
"github.com/tailscale/wireguard-go/device" |
|
|
|
|
"github.com/tailscale/wireguard-go/tun/tuntest" |
|
|
|
|
"github.com/tailscale/wireguard-go/wgcfg" |
|
|
|
|
"tailscale.com/stun" |
|
|
|
|
) |
|
|
|
|
|
|
|
|
|
@ -179,3 +183,174 @@ func runSTUN(pc net.PacketConn, stats *stunStats) { |
|
|
|
|
} |
|
|
|
|
} |
|
|
|
|
} |
|
|
|
|
|
|
|
|
|
func makeConfigs(t *testing.T, ports []uint16) []wgcfg.Config { |
|
|
|
|
t.Helper() |
|
|
|
|
|
|
|
|
|
var privKeys []wgcfg.PrivateKey |
|
|
|
|
var addresses [][]wgcfg.CIDR |
|
|
|
|
|
|
|
|
|
for i := range ports { |
|
|
|
|
privKey, err := wgcfg.NewPrivateKey() |
|
|
|
|
if err != nil { |
|
|
|
|
t.Fatal(err) |
|
|
|
|
} |
|
|
|
|
privKeys = append(privKeys, privKey) |
|
|
|
|
|
|
|
|
|
addresses = append(addresses, []wgcfg.CIDR{ |
|
|
|
|
parseCIDR(t, fmt.Sprintf("1.0.0.%d/32", i+1)), |
|
|
|
|
}) |
|
|
|
|
} |
|
|
|
|
|
|
|
|
|
var cfgs []wgcfg.Config |
|
|
|
|
for i, port := range ports { |
|
|
|
|
cfg := wgcfg.Config{ |
|
|
|
|
Name: fmt.Sprintf("peer%d", i+1), |
|
|
|
|
PrivateKey: privKeys[i], |
|
|
|
|
Addresses: addresses[i], |
|
|
|
|
ListenPort: port, |
|
|
|
|
} |
|
|
|
|
for peerNum, port := range ports { |
|
|
|
|
if peerNum == i { |
|
|
|
|
continue |
|
|
|
|
} |
|
|
|
|
peer := wgcfg.Peer{ |
|
|
|
|
PublicKey: privKeys[peerNum].Public(), |
|
|
|
|
AllowedIPs: addresses[peerNum], |
|
|
|
|
Endpoints: []wgcfg.Endpoint{{ |
|
|
|
|
Host: "127.0.0.1", |
|
|
|
|
Port: port, |
|
|
|
|
}}, |
|
|
|
|
} |
|
|
|
|
cfg.Peers = append(cfg.Peers, peer) |
|
|
|
|
} |
|
|
|
|
cfgs = append(cfgs, cfg) |
|
|
|
|
} |
|
|
|
|
return cfgs |
|
|
|
|
} |
|
|
|
|
|
|
|
|
|
func parseCIDR(t *testing.T, addr string) wgcfg.CIDR { |
|
|
|
|
t.Helper() |
|
|
|
|
cidr, err := wgcfg.ParseCIDR(addr) |
|
|
|
|
if err != nil { |
|
|
|
|
t.Fatal(err) |
|
|
|
|
} |
|
|
|
|
return *cidr |
|
|
|
|
} |
|
|
|
|
|
|
|
|
|
func TestTwoDevicePing(t *testing.T) { |
|
|
|
|
stunAddr, stunCleanupFn := serveSTUN(t) |
|
|
|
|
defer stunCleanupFn() |
|
|
|
|
|
|
|
|
|
epCh1 := make(chan []string, 16) |
|
|
|
|
conn1, err := Listen(Options{ |
|
|
|
|
STUN: []string{stunAddr.String()}, |
|
|
|
|
EndpointsFunc: func(eps []string) { |
|
|
|
|
epCh1 <- eps |
|
|
|
|
}, |
|
|
|
|
}) |
|
|
|
|
if err != nil { |
|
|
|
|
t.Fatal(err) |
|
|
|
|
} |
|
|
|
|
defer conn1.Close() |
|
|
|
|
|
|
|
|
|
epCh2 := make(chan []string, 16) |
|
|
|
|
conn2, err := Listen(Options{ |
|
|
|
|
STUN: []string{stunAddr.String()}, |
|
|
|
|
EndpointsFunc: func(eps []string) { |
|
|
|
|
epCh2 <- eps |
|
|
|
|
}, |
|
|
|
|
}) |
|
|
|
|
if err != nil { |
|
|
|
|
t.Fatal(err) |
|
|
|
|
} |
|
|
|
|
defer conn2.Close() |
|
|
|
|
|
|
|
|
|
ports := []uint16{conn1.LocalPort(), conn2.LocalPort()} |
|
|
|
|
cfgs := makeConfigs(t, ports) |
|
|
|
|
|
|
|
|
|
uapi1, _ := cfgs[0].ToUAPI() |
|
|
|
|
t.Logf("cfg0: %v", uapi1) |
|
|
|
|
uapi2, _ := cfgs[1].ToUAPI() |
|
|
|
|
t.Logf("cfg1: %v", uapi2) |
|
|
|
|
|
|
|
|
|
tun1 := tuntest.NewChannelTUN() |
|
|
|
|
dev1 := device.NewDevice(tun1.TUN(), &device.DeviceOptions{ |
|
|
|
|
Logger: device.NewLogger(device.LogLevelDebug, "dev1: "), |
|
|
|
|
CreateEndpoint: conn1.CreateEndpoint, |
|
|
|
|
CreateBind: conn1.CreateBind, |
|
|
|
|
SkipBindUpdate: true, |
|
|
|
|
}) |
|
|
|
|
dev1.Up() |
|
|
|
|
//defer dev1.Close() TODO(crawshaw): this hangs
|
|
|
|
|
if err := dev1.Reconfig(&cfgs[0]); err != nil { |
|
|
|
|
t.Fatal(err) |
|
|
|
|
} |
|
|
|
|
|
|
|
|
|
tun2 := tuntest.NewChannelTUN() |
|
|
|
|
dev2 := device.NewDevice(tun2.TUN(), &device.DeviceOptions{ |
|
|
|
|
Logger: device.NewLogger(device.LogLevelDebug, "dev2: "), |
|
|
|
|
CreateEndpoint: conn2.CreateEndpoint, |
|
|
|
|
CreateBind: conn2.CreateBind, |
|
|
|
|
SkipBindUpdate: true, |
|
|
|
|
}) |
|
|
|
|
dev2.Up() |
|
|
|
|
//defer dev2.Close() TODO(crawshaw): this hangs
|
|
|
|
|
if err := dev2.Reconfig(&cfgs[1]); err != nil { |
|
|
|
|
t.Fatal(err) |
|
|
|
|
} |
|
|
|
|
|
|
|
|
|
ping1 := func(t *testing.T) { |
|
|
|
|
t.Helper() |
|
|
|
|
|
|
|
|
|
msg2to1 := tuntest.Ping(net.ParseIP("1.0.0.1"), net.ParseIP("1.0.0.2")) |
|
|
|
|
tun2.Outbound <- msg2to1 |
|
|
|
|
select { |
|
|
|
|
case msgRecv := <-tun1.Inbound: |
|
|
|
|
if !bytes.Equal(msg2to1, msgRecv) { |
|
|
|
|
t.Error("ping did not transit correctly") |
|
|
|
|
} |
|
|
|
|
case <-time.After(1 * time.Second): |
|
|
|
|
t.Error("ping did not transit") |
|
|
|
|
} |
|
|
|
|
} |
|
|
|
|
ping2 := func(t *testing.T) { |
|
|
|
|
t.Helper() |
|
|
|
|
|
|
|
|
|
msg1to2 := tuntest.Ping(net.ParseIP("1.0.0.2"), net.ParseIP("1.0.0.1")) |
|
|
|
|
tun1.Outbound <- msg1to2 |
|
|
|
|
select { |
|
|
|
|
case msgRecv := <-tun2.Inbound: |
|
|
|
|
if !bytes.Equal(msg1to2, msgRecv) { |
|
|
|
|
t.Error("return ping did not transit correctly") |
|
|
|
|
} |
|
|
|
|
case <-time.After(1 * time.Second): |
|
|
|
|
t.Error("return ping did not transit") |
|
|
|
|
} |
|
|
|
|
} |
|
|
|
|
|
|
|
|
|
t.Run("ping 1.0.0.1", func(t *testing.T) { ping1(t) }) |
|
|
|
|
t.Run("ping 1.0.0.2", func(t *testing.T) { ping2(t) }) |
|
|
|
|
t.Run("ping 1.0.0.2 via SendPacket", func(t *testing.T) { |
|
|
|
|
msg1to2 := tuntest.Ping(net.ParseIP("1.0.0.2"), net.ParseIP("1.0.0.1")) |
|
|
|
|
if err := dev1.SendPacket(msg1to2); err != nil { |
|
|
|
|
t.Fatal(err) |
|
|
|
|
} |
|
|
|
|
select { |
|
|
|
|
case msgRecv := <-tun2.Inbound: |
|
|
|
|
if !bytes.Equal(msg1to2, msgRecv) { |
|
|
|
|
t.Error("return ping did not transit correctly") |
|
|
|
|
} |
|
|
|
|
case <-time.After(1 * time.Second): |
|
|
|
|
t.Error("return ping did not transit") |
|
|
|
|
} |
|
|
|
|
}) |
|
|
|
|
|
|
|
|
|
t.Run("no-op dev1 reconfig", func(t *testing.T) { |
|
|
|
|
if err := dev1.Reconfig(&cfgs[0]); err != nil { |
|
|
|
|
t.Fatal(err) |
|
|
|
|
} |
|
|
|
|
ping1(t) |
|
|
|
|
ping2(t) |
|
|
|
|
}) |
|
|
|
|
} |
|
|
|
|
|