This is mostly code movement from the wireguard-go repo.
Most of the new wgcfg package corresponds to the wireguard-go wgcfg package.
wgengine/wgcfg/device{_test}.go was device/config{_test}.go.
There were substantive but simple changes to device_test.go to remove
internal package device references.
The API of device.Config (now wgcfg.DeviceConfig) grew an error return;
we previously logged the error and threw it away.
Signed-off-by: Josh Bleecher Snyder <josh@tailscale.com>
main
parent
0bc73f8e4f
commit
fe7c3e9c17
@ -0,0 +1,67 @@ |
||||
// Copyright (c) 2021 Tailscale Inc & AUTHORS All rights reserved.
|
||||
// Use of this source code is governed by a BSD-style
|
||||
// license that can be found in the LICENSE file.
|
||||
|
||||
// Package wgcfg has types and a parser for representing WireGuard config.
|
||||
package wgcfg |
||||
|
||||
import ( |
||||
"inet.af/netaddr" |
||||
) |
||||
|
||||
// Config is a WireGuard configuration.
|
||||
// It only supports the set of things Tailscale uses.
|
||||
type Config struct { |
||||
Name string |
||||
PrivateKey PrivateKey |
||||
Addresses []netaddr.IPPrefix |
||||
ListenPort uint16 |
||||
MTU uint16 |
||||
DNS []netaddr.IP |
||||
Peers []Peer |
||||
} |
||||
|
||||
type Peer struct { |
||||
PublicKey Key |
||||
AllowedIPs []netaddr.IPPrefix |
||||
Endpoints string // comma-separated host/port pairs: "1.2.3.4:56,[::]:80"
|
||||
PersistentKeepalive uint16 |
||||
} |
||||
|
||||
// Copy makes a deep copy of Config.
|
||||
// The result aliases no memory with the original.
|
||||
func (cfg Config) Copy() Config { |
||||
res := cfg |
||||
if res.Addresses != nil { |
||||
res.Addresses = append([]netaddr.IPPrefix{}, res.Addresses...) |
||||
} |
||||
if res.DNS != nil { |
||||
res.DNS = append([]netaddr.IP{}, res.DNS...) |
||||
} |
||||
peers := make([]Peer, 0, len(res.Peers)) |
||||
for _, peer := range res.Peers { |
||||
peers = append(peers, peer.Copy()) |
||||
} |
||||
res.Peers = peers |
||||
return res |
||||
} |
||||
|
||||
// Copy makes a deep copy of Peer.
|
||||
// The result aliases no memory with the original.
|
||||
func (peer Peer) Copy() Peer { |
||||
res := peer |
||||
if res.AllowedIPs != nil { |
||||
res.AllowedIPs = append([]netaddr.IPPrefix{}, res.AllowedIPs...) |
||||
} |
||||
return res |
||||
} |
||||
|
||||
// PeerWithKey returns the Peer with key k and reports whether it was found.
|
||||
func (config Config) PeerWithKey(k Key) (Peer, bool) { |
||||
for _, p := range config.Peers { |
||||
if p.PublicKey == k { |
||||
return p, true |
||||
} |
||||
} |
||||
return Peer{}, false |
||||
} |
||||
@ -0,0 +1,61 @@ |
||||
// Copyright (c) 2021 Tailscale Inc & AUTHORS All rights reserved.
|
||||
// Use of this source code is governed by a BSD-style
|
||||
// license that can be found in the LICENSE file.
|
||||
|
||||
package wgcfg |
||||
|
||||
import ( |
||||
"io" |
||||
"sort" |
||||
|
||||
"github.com/tailscale/wireguard-go/device" |
||||
"tailscale.com/types/logger" |
||||
) |
||||
|
||||
func DeviceConfig(d *device.Device) (*Config, error) { |
||||
r, w := io.Pipe() |
||||
errc := make(chan error, 1) |
||||
go func() { |
||||
errc <- d.IpcGetOperation(w) |
||||
w.Close() |
||||
}() |
||||
cfg, err := FromUAPI(r) |
||||
if err != nil { |
||||
return nil, err |
||||
} |
||||
if err := <-errc; err != nil { |
||||
return nil, err |
||||
} |
||||
|
||||
sort.Slice(cfg.Peers, func(i, j int) bool { |
||||
return cfg.Peers[i].PublicKey.LessThan(&cfg.Peers[j].PublicKey) |
||||
}) |
||||
return cfg, nil |
||||
} |
||||
|
||||
// ReconfigDevice replaces the existing device configuration with cfg.
|
||||
func ReconfigDevice(d *device.Device, cfg *Config, logf logger.Logf) (err error) { |
||||
defer func() { |
||||
if err != nil { |
||||
logf("wgcfg.Reconfig failed: %v", err) |
||||
} |
||||
}() |
||||
|
||||
prev, err := DeviceConfig(d) |
||||
if err != nil { |
||||
return err |
||||
} |
||||
|
||||
r, w := io.Pipe() |
||||
errc := make(chan error) |
||||
go func() { |
||||
errc <- d.IpcSetOperation(r) |
||||
}() |
||||
|
||||
err = cfg.ToUAPI(w, prev) |
||||
if err != nil { |
||||
return err |
||||
} |
||||
w.Close() |
||||
return <-errc |
||||
} |
||||
@ -0,0 +1,242 @@ |
||||
// Copyright (c) 2021 Tailscale Inc & AUTHORS All rights reserved.
|
||||
// Use of this source code is governed by a BSD-style
|
||||
// license that can be found in the LICENSE file.
|
||||
|
||||
package wgcfg |
||||
|
||||
import ( |
||||
"bufio" |
||||
"bytes" |
||||
"io" |
||||
"os" |
||||
"sort" |
||||
"strings" |
||||
"sync" |
||||
"testing" |
||||
|
||||
"github.com/tailscale/wireguard-go/device" |
||||
"github.com/tailscale/wireguard-go/tun" |
||||
"inet.af/netaddr" |
||||
"tailscale.com/types/wgkey" |
||||
) |
||||
|
||||
func TestDeviceConfig(t *testing.T) { |
||||
newPrivateKey := func() (Key, PrivateKey) { |
||||
t.Helper() |
||||
pk, err := wgkey.NewPrivate() |
||||
if err != nil { |
||||
t.Fatal(err) |
||||
} |
||||
return Key(pk.Public()), PrivateKey(pk) |
||||
} |
||||
k1, pk1 := newPrivateKey() |
||||
ip1 := netaddr.MustParseIPPrefix("10.0.0.1/32") |
||||
|
||||
k2, pk2 := newPrivateKey() |
||||
ip2 := netaddr.MustParseIPPrefix("10.0.0.2/32") |
||||
|
||||
k3, _ := newPrivateKey() |
||||
ip3 := netaddr.MustParseIPPrefix("10.0.0.3/32") |
||||
|
||||
cfg1 := &Config{ |
||||
PrivateKey: PrivateKey(pk1), |
||||
Peers: []Peer{{ |
||||
PublicKey: k2, |
||||
AllowedIPs: []netaddr.IPPrefix{ip2}, |
||||
}}, |
||||
} |
||||
|
||||
cfg2 := &Config{ |
||||
PrivateKey: PrivateKey(pk2), |
||||
Peers: []Peer{{ |
||||
PublicKey: k1, |
||||
AllowedIPs: []netaddr.IPPrefix{ip1}, |
||||
PersistentKeepalive: 5, |
||||
}}, |
||||
} |
||||
|
||||
device1 := device.NewDevice(newNilTun(), &device.DeviceOptions{ |
||||
Logger: device.NewLogger(device.LogLevelError, "device1"), |
||||
}) |
||||
device2 := device.NewDevice(newNilTun(), &device.DeviceOptions{ |
||||
Logger: device.NewLogger(device.LogLevelError, "device2"), |
||||
}) |
||||
defer device1.Close() |
||||
defer device2.Close() |
||||
|
||||
cmp := func(t *testing.T, d *device.Device, want *Config) { |
||||
t.Helper() |
||||
got, err := DeviceConfig(d) |
||||
if err != nil { |
||||
t.Fatal(err) |
||||
} |
||||
prev := new(Config) |
||||
gotbuf := new(strings.Builder) |
||||
err = got.ToUAPI(gotbuf, prev) |
||||
gotStr := gotbuf.String() |
||||
if err != nil { |
||||
t.Errorf("got.ToUAPI(): error: %v", err) |
||||
return |
||||
} |
||||
wantbuf := new(strings.Builder) |
||||
err = want.ToUAPI(wantbuf, prev) |
||||
wantStr := wantbuf.String() |
||||
if err != nil { |
||||
t.Errorf("want.ToUAPI(): error: %v", err) |
||||
return |
||||
} |
||||
if gotStr != wantStr { |
||||
buf := new(bytes.Buffer) |
||||
w := bufio.NewWriter(buf) |
||||
if err := d.IpcGetOperation(w); err != nil { |
||||
t.Errorf("on error, could not IpcGetOperation: %v", err) |
||||
} |
||||
w.Flush() |
||||
t.Errorf("cfg:\n%s\n---- want:\n%s\n---- uapi:\n%s", gotStr, wantStr, buf.String()) |
||||
} |
||||
} |
||||
|
||||
t.Run("device1 config", func(t *testing.T) { |
||||
if err := ReconfigDevice(device1, cfg1, t.Logf); err != nil { |
||||
t.Fatal(err) |
||||
} |
||||
cmp(t, device1, cfg1) |
||||
}) |
||||
|
||||
t.Run("device2 config", func(t *testing.T) { |
||||
if err := ReconfigDevice(device2, cfg2, t.Logf); err != nil { |
||||
t.Fatal(err) |
||||
} |
||||
cmp(t, device2, cfg2) |
||||
}) |
||||
|
||||
// This is only to test that Config and Reconfig are properly synchronized.
|
||||
t.Run("device2 config/reconfig", func(t *testing.T) { |
||||
var wg sync.WaitGroup |
||||
wg.Add(2) |
||||
|
||||
go func() { |
||||
ReconfigDevice(device2, cfg2, t.Logf) |
||||
wg.Done() |
||||
}() |
||||
|
||||
go func() { |
||||
DeviceConfig(device2) |
||||
wg.Done() |
||||
}() |
||||
|
||||
wg.Wait() |
||||
}) |
||||
|
||||
t.Run("device1 modify peer", func(t *testing.T) { |
||||
cfg1.Peers[0].Endpoints = "1.2.3.4:12345" |
||||
if err := ReconfigDevice(device1, cfg1, t.Logf); err != nil { |
||||
t.Fatal(err) |
||||
} |
||||
cmp(t, device1, cfg1) |
||||
}) |
||||
|
||||
t.Run("device1 replace endpoint", func(t *testing.T) { |
||||
cfg1.Peers[0].Endpoints = "1.1.1.1:123" |
||||
if err := ReconfigDevice(device1, cfg1, t.Logf); err != nil { |
||||
t.Fatal(err) |
||||
} |
||||
cmp(t, device1, cfg1) |
||||
}) |
||||
|
||||
t.Run("device1 add new peer", func(t *testing.T) { |
||||
cfg1.Peers = append(cfg1.Peers, Peer{ |
||||
PublicKey: k3, |
||||
AllowedIPs: []netaddr.IPPrefix{ip3}, |
||||
}) |
||||
sort.Slice(cfg1.Peers, func(i, j int) bool { |
||||
return cfg1.Peers[i].PublicKey.LessThan(&cfg1.Peers[j].PublicKey) |
||||
}) |
||||
|
||||
origCfg, err := DeviceConfig(device1) |
||||
if err != nil { |
||||
t.Fatal(err) |
||||
} |
||||
|
||||
if err := ReconfigDevice(device1, cfg1, t.Logf); err != nil { |
||||
t.Fatal(err) |
||||
} |
||||
cmp(t, device1, cfg1) |
||||
|
||||
newCfg, err := DeviceConfig(device1) |
||||
if err != nil { |
||||
t.Fatal(err) |
||||
} |
||||
|
||||
peer0 := func(cfg *Config) Peer { |
||||
p, ok := cfg.PeerWithKey(k2) |
||||
if !ok { |
||||
t.Helper() |
||||
t.Fatal("failed to look up peer 2") |
||||
} |
||||
return p |
||||
} |
||||
peersEqual := func(p, q Peer) bool { |
||||
return p.PublicKey == q.PublicKey && p.PersistentKeepalive == q.PersistentKeepalive && |
||||
p.Endpoints == q.Endpoints && cidrsEqual(p.AllowedIPs, q.AllowedIPs) |
||||
} |
||||
if !peersEqual(peer0(origCfg), peer0(newCfg)) { |
||||
t.Error("reconfig modified old peer") |
||||
} |
||||
}) |
||||
|
||||
t.Run("device1 remove peer", func(t *testing.T) { |
||||
removeKey := cfg1.Peers[len(cfg1.Peers)-1].PublicKey |
||||
cfg1.Peers = cfg1.Peers[:len(cfg1.Peers)-1] |
||||
|
||||
if err := ReconfigDevice(device1, cfg1, t.Logf); err != nil { |
||||
t.Fatal(err) |
||||
} |
||||
cmp(t, device1, cfg1) |
||||
|
||||
newCfg, err := DeviceConfig(device1) |
||||
if err != nil { |
||||
t.Fatal(err) |
||||
} |
||||
|
||||
_, ok := newCfg.PeerWithKey(removeKey) |
||||
if ok { |
||||
t.Error("reconfig failed to remove peer") |
||||
} |
||||
}) |
||||
} |
||||
|
||||
// TODO: replace with a loopback tunnel
|
||||
type nilTun struct { |
||||
events chan tun.Event |
||||
closed chan struct{} |
||||
} |
||||
|
||||
func newNilTun() tun.Device { |
||||
return &nilTun{ |
||||
events: make(chan tun.Event), |
||||
closed: make(chan struct{}), |
||||
} |
||||
} |
||||
|
||||
func (t *nilTun) File() *os.File { return nil } |
||||
func (t *nilTun) Flush() error { return nil } |
||||
func (t *nilTun) MTU() (int, error) { return 1420, nil } |
||||
func (t *nilTun) Name() (string, error) { return "niltun", nil } |
||||
func (t *nilTun) Events() chan tun.Event { return t.events } |
||||
|
||||
func (t *nilTun) Read(data []byte, offset int) (int, error) { |
||||
<-t.closed |
||||
return 0, io.EOF |
||||
} |
||||
|
||||
func (t *nilTun) Write(data []byte, offset int) (int, error) { |
||||
<-t.closed |
||||
return 0, io.EOF |
||||
} |
||||
|
||||
func (t *nilTun) Close() error { |
||||
close(t.events) |
||||
close(t.closed) |
||||
return nil |
||||
} |
||||
@ -0,0 +1,240 @@ |
||||
// Copyright (c) 2021 Tailscale Inc & AUTHORS All rights reserved.
|
||||
// Use of this source code is governed by a BSD-style
|
||||
// license that can be found in the LICENSE file.
|
||||
|
||||
package wgcfg |
||||
|
||||
import ( |
||||
"bytes" |
||||
"crypto/rand" |
||||
"crypto/subtle" |
||||
"encoding/base64" |
||||
"encoding/hex" |
||||
"errors" |
||||
"fmt" |
||||
"strings" |
||||
|
||||
"golang.org/x/crypto/chacha20poly1305" |
||||
"golang.org/x/crypto/curve25519" |
||||
) |
||||
|
||||
const KeySize = 32 |
||||
|
||||
// Key is curve25519 key.
|
||||
// It is used by WireGuard to represent public and preshared keys.
|
||||
type Key [KeySize]byte |
||||
|
||||
// NewPresharedKey generates a new random key.
|
||||
func NewPresharedKey() (*Key, error) { |
||||
var k [KeySize]byte |
||||
_, err := rand.Read(k[:]) |
||||
if err != nil { |
||||
return nil, err |
||||
} |
||||
return (*Key)(&k), nil |
||||
} |
||||
|
||||
func ParseKey(b64 string) (*Key, error) { return parseKeyBase64(base64.StdEncoding, b64) } |
||||
|
||||
func ParseHexKey(s string) (Key, error) { |
||||
b, err := hex.DecodeString(s) |
||||
if err != nil { |
||||
return Key{}, &ParseError{"invalid hex key: " + err.Error(), s} |
||||
} |
||||
if len(b) != KeySize { |
||||
return Key{}, &ParseError{fmt.Sprintf("invalid hex key length: %d", len(b)), s} |
||||
} |
||||
|
||||
var key Key |
||||
copy(key[:], b) |
||||
return key, nil |
||||
} |
||||
|
||||
func ParsePrivateHexKey(v string) (PrivateKey, error) { |
||||
k, err := ParseHexKey(v) |
||||
if err != nil { |
||||
return PrivateKey{}, err |
||||
} |
||||
pk := PrivateKey(k) |
||||
if pk.IsZero() { |
||||
// Do not clamp a zero key, pass the zero through
|
||||
// (much like NaN propagation) so that IsZero reports
|
||||
// a useful result.
|
||||
return pk, nil |
||||
} |
||||
pk.clamp() |
||||
return pk, nil |
||||
} |
||||
|
||||
func (k Key) Base64() string { return base64.StdEncoding.EncodeToString(k[:]) } |
||||
func (k Key) String() string { return k.ShortString() } |
||||
func (k Key) HexString() string { return hex.EncodeToString(k[:]) } |
||||
func (k Key) Equal(k2 Key) bool { return subtle.ConstantTimeCompare(k[:], k2[:]) == 1 } |
||||
|
||||
func (k *Key) ShortString() string { |
||||
long := k.Base64() |
||||
return "[" + long[0:5] + "]" |
||||
} |
||||
|
||||
func (k *Key) IsZero() bool { |
||||
if k == nil { |
||||
return true |
||||
} |
||||
var zeros Key |
||||
return subtle.ConstantTimeCompare(zeros[:], k[:]) == 1 |
||||
} |
||||
|
||||
func (k *Key) MarshalJSON() ([]byte, error) { |
||||
if k == nil { |
||||
return []byte("null"), nil |
||||
} |
||||
buf := new(bytes.Buffer) |
||||
fmt.Fprintf(buf, `"%x"`, k[:]) |
||||
return buf.Bytes(), nil |
||||
} |
||||
|
||||
func (k *Key) UnmarshalJSON(b []byte) error { |
||||
if k == nil { |
||||
return errors.New("wgcfg.Key: UnmarshalJSON on nil pointer") |
||||
} |
||||
if len(b) < 3 || b[0] != '"' || b[len(b)-1] != '"' { |
||||
return errors.New("wgcfg.Key: UnmarshalJSON not given a string") |
||||
} |
||||
b = b[1 : len(b)-1] |
||||
key, err := ParseHexKey(string(b)) |
||||
if err != nil { |
||||
return fmt.Errorf("wgcfg.Key: UnmarshalJSON: %v", err) |
||||
} |
||||
copy(k[:], key[:]) |
||||
return nil |
||||
} |
||||
|
||||
func (a *Key) LessThan(b *Key) bool { |
||||
for i := range a { |
||||
if a[i] < b[i] { |
||||
return true |
||||
} else if a[i] > b[i] { |
||||
return false |
||||
} |
||||
} |
||||
return false |
||||
} |
||||
|
||||
// PrivateKey is curve25519 key.
|
||||
// It is used by WireGuard to represent private keys.
|
||||
type PrivateKey [KeySize]byte |
||||
|
||||
// NewPrivateKey generates a new curve25519 secret key.
|
||||
// It conforms to the format described on https://cr.yp.to/ecdh.html.
|
||||
func NewPrivateKey() (PrivateKey, error) { |
||||
k, err := NewPresharedKey() |
||||
if err != nil { |
||||
return PrivateKey{}, err |
||||
} |
||||
k[0] &= 248 |
||||
k[31] = (k[31] & 127) | 64 |
||||
return (PrivateKey)(*k), nil |
||||
} |
||||
|
||||
func ParsePrivateKey(b64 string) (*PrivateKey, error) { |
||||
k, err := parseKeyBase64(base64.StdEncoding, b64) |
||||
return (*PrivateKey)(k), err |
||||
} |
||||
|
||||
func (k *PrivateKey) String() string { return base64.StdEncoding.EncodeToString(k[:]) } |
||||
func (k *PrivateKey) HexString() string { return hex.EncodeToString(k[:]) } |
||||
func (k *PrivateKey) Equal(k2 PrivateKey) bool { return subtle.ConstantTimeCompare(k[:], k2[:]) == 1 } |
||||
|
||||
func (k *PrivateKey) IsZero() bool { |
||||
pk := Key(*k) |
||||
return pk.IsZero() |
||||
} |
||||
|
||||
func (k *PrivateKey) clamp() { |
||||
k[0] &= 248 |
||||
k[31] = (k[31] & 127) | 64 |
||||
} |
||||
|
||||
// Public computes the public key matching this curve25519 secret key.
|
||||
func (k *PrivateKey) Public() Key { |
||||
pk := Key(*k) |
||||
if pk.IsZero() { |
||||
panic("Tried to generate emptyPrivateKey.Public()") |
||||
} |
||||
var p [KeySize]byte |
||||
curve25519.ScalarBaseMult(&p, (*[KeySize]byte)(k)) |
||||
return (Key)(p) |
||||
} |
||||
|
||||
func (k PrivateKey) MarshalText() ([]byte, error) { |
||||
buf := new(bytes.Buffer) |
||||
fmt.Fprintf(buf, `privkey:%x`, k[:]) |
||||
return buf.Bytes(), nil |
||||
} |
||||
|
||||
func (k *PrivateKey) UnmarshalText(b []byte) error { |
||||
s := string(b) |
||||
if !strings.HasPrefix(s, `privkey:`) { |
||||
return errors.New("wgcfg.PrivateKey: UnmarshalText not given a private-key string") |
||||
} |
||||
s = strings.TrimPrefix(s, `privkey:`) |
||||
key, err := ParseHexKey(s) |
||||
if err != nil { |
||||
return fmt.Errorf("wgcfg.PrivateKey: UnmarshalText: %v", err) |
||||
} |
||||
copy(k[:], key[:]) |
||||
return nil |
||||
} |
||||
|
||||
func (k PrivateKey) SharedSecret(pub Key) (ss [KeySize]byte) { |
||||
apk := (*[KeySize]byte)(&pub) |
||||
ask := (*[KeySize]byte)(&k) |
||||
curve25519.ScalarMult(&ss, ask, apk) //lint:ignore SA1019 Jason says this is OK; match wireguard-go exactyl
|
||||
return ss |
||||
} |
||||
|
||||
func parseKeyBase64(enc *base64.Encoding, s string) (*Key, error) { |
||||
k, err := enc.DecodeString(s) |
||||
if err != nil { |
||||
return nil, &ParseError{"Invalid key: " + err.Error(), s} |
||||
} |
||||
if len(k) != KeySize { |
||||
return nil, &ParseError{"Keys must decode to exactly 32 bytes", s} |
||||
} |
||||
var key Key |
||||
copy(key[:], k) |
||||
return &key, nil |
||||
} |
||||
|
||||
func ParseSymmetricKey(b64 string) (SymmetricKey, error) { |
||||
k, err := parseKeyBase64(base64.StdEncoding, b64) |
||||
if err != nil { |
||||
return SymmetricKey{}, err |
||||
} |
||||
return SymmetricKey(*k), nil |
||||
} |
||||
|
||||
func ParseSymmetricHexKey(s string) (SymmetricKey, error) { |
||||
b, err := hex.DecodeString(s) |
||||
if err != nil { |
||||
return SymmetricKey{}, &ParseError{"invalid symmetric hex key: " + err.Error(), s} |
||||
} |
||||
if len(b) != chacha20poly1305.KeySize { |
||||
return SymmetricKey{}, &ParseError{fmt.Sprintf("invalid symmetric hex key length: %d", len(b)), s} |
||||
} |
||||
var key SymmetricKey |
||||
copy(key[:], b) |
||||
return key, nil |
||||
} |
||||
|
||||
// SymmetricKey is a chacha20poly1305 key.
|
||||
// It is used by WireGuard to represent pre-shared symmetric keys.
|
||||
type SymmetricKey [chacha20poly1305.KeySize]byte |
||||
|
||||
func (k SymmetricKey) Base64() string { return base64.StdEncoding.EncodeToString(k[:]) } |
||||
func (k SymmetricKey) String() string { return "sym:" + k.Base64()[:8] } |
||||
func (k SymmetricKey) HexString() string { return hex.EncodeToString(k[:]) } |
||||
func (k SymmetricKey) IsZero() bool { return k.Equal(SymmetricKey{}) } |
||||
func (k SymmetricKey) Equal(k2 SymmetricKey) bool { |
||||
return subtle.ConstantTimeCompare(k[:], k2[:]) == 1 |
||||
} |
||||
@ -0,0 +1,111 @@ |
||||
// Copyright (c) 2021 Tailscale Inc & AUTHORS All rights reserved.
|
||||
// Use of this source code is governed by a BSD-style
|
||||
// license that can be found in the LICENSE file.
|
||||
|
||||
package wgcfg |
||||
|
||||
import ( |
||||
"bytes" |
||||
"testing" |
||||
) |
||||
|
||||
func TestKeyBasics(t *testing.T) { |
||||
k1, err := NewPresharedKey() |
||||
if err != nil { |
||||
t.Fatal(err) |
||||
} |
||||
|
||||
b, err := k1.MarshalJSON() |
||||
if err != nil { |
||||
t.Fatal(err) |
||||
} |
||||
|
||||
t.Run("JSON round-trip", func(t *testing.T) { |
||||
// should preserve the keys
|
||||
k2 := new(Key) |
||||
if err := k2.UnmarshalJSON(b); err != nil { |
||||
t.Fatal(err) |
||||
} |
||||
if !bytes.Equal(k1[:], k2[:]) { |
||||
t.Fatalf("k1 %v != k2 %v", k1[:], k2[:]) |
||||
} |
||||
if b1, b2 := k1.String(), k2.String(); b1 != b2 { |
||||
t.Fatalf("base64-encoded keys do not match: %s, %s", b1, b2) |
||||
} |
||||
}) |
||||
|
||||
t.Run("JSON incompatible with PrivateKey", func(t *testing.T) { |
||||
k2 := new(PrivateKey) |
||||
if err := k2.UnmarshalText(b); err == nil { |
||||
t.Fatalf("successfully decoded key as private key") |
||||
} |
||||
}) |
||||
|
||||
t.Run("second key", func(t *testing.T) { |
||||
// A second call to NewPresharedKey should make a new key.
|
||||
k3, err := NewPresharedKey() |
||||
if err != nil { |
||||
t.Fatal(err) |
||||
} |
||||
if bytes.Equal(k1[:], k3[:]) { |
||||
t.Fatalf("k1 %v == k3 %v", k1[:], k3[:]) |
||||
} |
||||
// Check for obvious comparables to make sure we are not generating bad strings somewhere.
|
||||
if b1, b2 := k1.String(), k3.String(); b1 == b2 { |
||||
t.Fatalf("base64-encoded keys match: %s, %s", b1, b2) |
||||
} |
||||
}) |
||||
} |
||||
func TestPrivateKeyBasics(t *testing.T) { |
||||
pri, err := NewPrivateKey() |
||||
if err != nil { |
||||
t.Fatal(err) |
||||
} |
||||
|
||||
b, err := pri.MarshalText() |
||||
if err != nil { |
||||
t.Fatal(err) |
||||
} |
||||
|
||||
t.Run("JSON round-trip", func(t *testing.T) { |
||||
// should preserve the keys
|
||||
pri2 := new(PrivateKey) |
||||
if err := pri2.UnmarshalText(b); err != nil { |
||||
t.Fatal(err) |
||||
} |
||||
if !bytes.Equal(pri[:], pri2[:]) { |
||||
t.Fatalf("pri %v != pri2 %v", pri[:], pri2[:]) |
||||
} |
||||
if b1, b2 := pri.String(), pri2.String(); b1 != b2 { |
||||
t.Fatalf("base64-encoded keys do not match: %s, %s", b1, b2) |
||||
} |
||||
if pub1, pub2 := pri.Public().String(), pri2.Public().String(); pub1 != pub2 { |
||||
t.Fatalf("base64-encoded public keys do not match: %s, %s", pub1, pub2) |
||||
} |
||||
}) |
||||
|
||||
t.Run("JSON incompatible with Key", func(t *testing.T) { |
||||
k2 := new(Key) |
||||
if err := k2.UnmarshalJSON(b); err == nil { |
||||
t.Fatalf("successfully decoded private key as key") |
||||
} |
||||
}) |
||||
|
||||
t.Run("second key", func(t *testing.T) { |
||||
// A second call to New should make a new key.
|
||||
pri3, err := NewPrivateKey() |
||||
if err != nil { |
||||
t.Fatal(err) |
||||
} |
||||
if bytes.Equal(pri[:], pri3[:]) { |
||||
t.Fatalf("pri %v == pri3 %v", pri[:], pri3[:]) |
||||
} |
||||
// Check for obvious comparables to make sure we are not generating bad strings somewhere.
|
||||
if b1, b2 := pri.String(), pri3.String(); b1 == b2 { |
||||
t.Fatalf("base64-encoded keys match: %s, %s", b1, b2) |
||||
} |
||||
if pub1, pub2 := pri.Public().String(), pri3.Public().String(); pub1 == pub2 { |
||||
t.Fatalf("base64-encoded public keys match: %s, %s", pub1, pub2) |
||||
} |
||||
}) |
||||
} |
||||
@ -0,0 +1,197 @@ |
||||
// Copyright (c) 2021 Tailscale Inc & AUTHORS All rights reserved.
|
||||
// Use of this source code is governed by a BSD-style
|
||||
// license that can be found in the LICENSE file.
|
||||
|
||||
package wgcfg |
||||
|
||||
import ( |
||||
"bufio" |
||||
"encoding/hex" |
||||
"fmt" |
||||
"io" |
||||
"net" |
||||
"strconv" |
||||
"strings" |
||||
|
||||
"inet.af/netaddr" |
||||
) |
||||
|
||||
type ParseError struct { |
||||
why string |
||||
offender string |
||||
} |
||||
|
||||
func (e *ParseError) Error() string { |
||||
return fmt.Sprintf("%s: ‘%s’", e.why, e.offender) |
||||
} |
||||
|
||||
func validateEndpoints(s string) error { |
||||
vals := strings.Split(s, ",") |
||||
for _, val := range vals { |
||||
_, _, err := parseEndpoint(val) |
||||
if err != nil { |
||||
return err |
||||
} |
||||
} |
||||
return nil |
||||
} |
||||
|
||||
func parseEndpoint(s string) (host string, port uint16, err error) { |
||||
i := strings.LastIndexByte(s, ':') |
||||
if i < 0 { |
||||
return "", 0, &ParseError{"Missing port from endpoint", s} |
||||
} |
||||
host, portStr := s[:i], s[i+1:] |
||||
if len(host) < 1 { |
||||
return "", 0, &ParseError{"Invalid endpoint host", host} |
||||
} |
||||
uport, err := strconv.ParseUint(portStr, 10, 16) |
||||
if err != nil { |
||||
return "", 0, err |
||||
} |
||||
hostColon := strings.IndexByte(host, ':') |
||||
if host[0] == '[' || host[len(host)-1] == ']' || hostColon > 0 { |
||||
err := &ParseError{"Brackets must contain an IPv6 address", host} |
||||
if len(host) > 3 && host[0] == '[' && host[len(host)-1] == ']' && hostColon > 0 { |
||||
maybeV6 := net.ParseIP(host[1 : len(host)-1]) |
||||
if maybeV6 == nil || len(maybeV6) != net.IPv6len { |
||||
return "", 0, err |
||||
} |
||||
} else { |
||||
return "", 0, err |
||||
} |
||||
host = host[1 : len(host)-1] |
||||
} |
||||
return host, uint16(uport), nil |
||||
} |
||||
|
||||
func parseKeyHex(s string) (*Key, error) { |
||||
k, err := hex.DecodeString(s) |
||||
if err != nil { |
||||
return nil, &ParseError{"Invalid key: " + err.Error(), s} |
||||
} |
||||
if len(k) != KeySize { |
||||
return nil, &ParseError{"Keys must decode to exactly 32 bytes", s} |
||||
} |
||||
var key Key |
||||
copy(key[:], k) |
||||
return &key, nil |
||||
} |
||||
|
||||
// FromUAPI generates a Config from r.
|
||||
// r should be generated by calling device.IpcGetOperation;
|
||||
// it is not compatible with other uapi streams.
|
||||
func FromUAPI(r io.Reader) (*Config, error) { |
||||
cfg := new(Config) |
||||
var peer *Peer // current peer being operated on
|
||||
deviceConfig := true |
||||
|
||||
scanner := bufio.NewScanner(r) |
||||
for scanner.Scan() { |
||||
line := scanner.Text() |
||||
if line == "" { |
||||
continue |
||||
} |
||||
parts := strings.Split(line, "=") |
||||
if len(parts) != 2 { |
||||
return nil, fmt.Errorf("failed to parse line %q, found %d =-separated parts, want 2", line, len(parts)) |
||||
} |
||||
key := parts[0] |
||||
value := parts[1] |
||||
|
||||
if key == "public_key" { |
||||
if deviceConfig { |
||||
deviceConfig = false |
||||
} |
||||
// Load/create the peer we are now configuring.
|
||||
var err error |
||||
peer, err = cfg.handlePublicKeyLine(value) |
||||
if err != nil { |
||||
return nil, err |
||||
} |
||||
continue |
||||
} |
||||
|
||||
var err error |
||||
if deviceConfig { |
||||
err = cfg.handleDeviceLine(key, value) |
||||
} else { |
||||
err = cfg.handlePeerLine(peer, key, value) |
||||
} |
||||
if err != nil { |
||||
return nil, err |
||||
} |
||||
} |
||||
|
||||
if err := scanner.Err(); err != nil { |
||||
return nil, err |
||||
} |
||||
|
||||
return cfg, nil |
||||
} |
||||
|
||||
func (cfg *Config) handleDeviceLine(key, value string) error { |
||||
switch key { |
||||
case "private_key": |
||||
k, err := parseKeyHex(value) |
||||
if err != nil { |
||||
return err |
||||
} |
||||
// wireguard-go guarantees not to send zero value; private keys are already clamped.
|
||||
cfg.PrivateKey = PrivateKey(*k) |
||||
case "listen_port": |
||||
port, err := strconv.ParseUint(value, 10, 16) |
||||
if err != nil { |
||||
return fmt.Errorf("failed to parse listen_port: %w", err) |
||||
} |
||||
cfg.ListenPort = uint16(port) |
||||
case "fwmark": |
||||
// ignore
|
||||
default: |
||||
return fmt.Errorf("unexpected IpcGetOperation key: %v", key) |
||||
} |
||||
return nil |
||||
} |
||||
|
||||
func (cfg *Config) handlePublicKeyLine(value string) (*Peer, error) { |
||||
k, err := parseKeyHex(value) |
||||
if err != nil { |
||||
return nil, err |
||||
} |
||||
cfg.Peers = append(cfg.Peers, Peer{}) |
||||
peer := &cfg.Peers[len(cfg.Peers)-1] |
||||
peer.PublicKey = *k |
||||
return peer, nil |
||||
} |
||||
|
||||
func (cfg *Config) handlePeerLine(peer *Peer, key, value string) error { |
||||
switch key { |
||||
case "endpoint": |
||||
err := validateEndpoints(value) |
||||
if err != nil { |
||||
return err |
||||
} |
||||
peer.Endpoints = value |
||||
case "persistent_keepalive_interval": |
||||
n, err := strconv.ParseUint(value, 10, 16) |
||||
if err != nil { |
||||
return err |
||||
} |
||||
peer.PersistentKeepalive = uint16(n) |
||||
case "allowed_ip": |
||||
ipp, err := netaddr.ParseIPPrefix(value) |
||||
if err != nil { |
||||
return err |
||||
} |
||||
peer.AllowedIPs = append(peer.AllowedIPs, ipp) |
||||
case "protocol_version": |
||||
if value != "1" { |
||||
return fmt.Errorf("invalid protocol version: %v", value) |
||||
} |
||||
case "preshared_key", "last_handshake_time_sec", "last_handshake_time_nsec", "tx_bytes", "rx_bytes": |
||||
// ignore
|
||||
default: |
||||
return fmt.Errorf("unexpected IpcGetOperation key: %v", key) |
||||
} |
||||
return nil |
||||
} |
||||
@ -0,0 +1,55 @@ |
||||
// Copyright (c) 2021 Tailscale Inc & AUTHORS All rights reserved.
|
||||
// Use of this source code is governed by a BSD-style
|
||||
// license that can be found in the LICENSE file.
|
||||
|
||||
package wgcfg |
||||
|
||||
import ( |
||||
"reflect" |
||||
"runtime" |
||||
"testing" |
||||
) |
||||
|
||||
func noError(t *testing.T, err error) bool { |
||||
if err == nil { |
||||
return true |
||||
} |
||||
_, fn, line, _ := runtime.Caller(1) |
||||
t.Errorf("Error at %s:%d: %#v", fn, line, err) |
||||
return false |
||||
} |
||||
|
||||
func equal(t *testing.T, expected, actual interface{}) bool { |
||||
if reflect.DeepEqual(expected, actual) { |
||||
return true |
||||
} |
||||
_, fn, line, _ := runtime.Caller(1) |
||||
t.Errorf("Failed equals at %s:%d\nactual %#v\nexpected %#v", fn, line, actual, expected) |
||||
return false |
||||
} |
||||
|
||||
func TestParseEndpoint(t *testing.T) { |
||||
_, _, err := parseEndpoint("[192.168.42.0:]:51880") |
||||
if err == nil { |
||||
t.Error("Error was expected") |
||||
} |
||||
host, port, err := parseEndpoint("192.168.42.0:51880") |
||||
if noError(t, err) { |
||||
equal(t, "192.168.42.0", host) |
||||
equal(t, uint16(51880), port) |
||||
} |
||||
host, port, err = parseEndpoint("test.wireguard.com:18981") |
||||
if noError(t, err) { |
||||
equal(t, "test.wireguard.com", host) |
||||
equal(t, uint16(18981), port) |
||||
} |
||||
host, port, err = parseEndpoint("[2607:5300:60:6b0::c05f:543]:2468") |
||||
if noError(t, err) { |
||||
equal(t, "2607:5300:60:6b0::c05f:543", host) |
||||
equal(t, uint16(2468), port) |
||||
} |
||||
_, _, err = parseEndpoint("[::::::invalid:18981") |
||||
if err == nil { |
||||
t.Error("Error was expected") |
||||
} |
||||
} |
||||
@ -0,0 +1,141 @@ |
||||
// Copyright (c) 2021 Tailscale Inc & AUTHORS All rights reserved.
|
||||
// Use of this source code is governed by a BSD-style
|
||||
// license that can be found in the LICENSE file.
|
||||
|
||||
package wgcfg |
||||
|
||||
import ( |
||||
"fmt" |
||||
"io" |
||||
"sort" |
||||
"strconv" |
||||
"strings" |
||||
|
||||
"inet.af/netaddr" |
||||
) |
||||
|
||||
// ToUAPI writes cfg in UAPI format to w.
|
||||
// Prev is the previous device Config.
|
||||
// Prev is required so that we can remove now-defunct peers
|
||||
// without having to remove and re-add all peers.
|
||||
func (cfg *Config) ToUAPI(w io.Writer, prev *Config) error { |
||||
var stickyErr error |
||||
set := func(key, value string) { |
||||
if stickyErr != nil { |
||||
return |
||||
} |
||||
_, err := fmt.Fprintf(w, "%s=%s\n", key, value) |
||||
if err != nil { |
||||
stickyErr = err |
||||
} |
||||
} |
||||
setUint16 := func(key string, value uint16) { |
||||
set(key, strconv.FormatUint(uint64(value), 10)) |
||||
} |
||||
setPeer := func(peer Peer) { |
||||
set("public_key", peer.PublicKey.HexString()) |
||||
} |
||||
|
||||
// Device config.
|
||||
if prev.PrivateKey != cfg.PrivateKey { |
||||
set("private_key", cfg.PrivateKey.HexString()) |
||||
} |
||||
if prev.ListenPort != cfg.ListenPort { |
||||
setUint16("listen_port", cfg.ListenPort) |
||||
} |
||||
|
||||
old := make(map[Key]Peer) |
||||
for _, p := range prev.Peers { |
||||
old[p.PublicKey] = p |
||||
} |
||||
|
||||
// Add/configure all new peers.
|
||||
for _, p := range cfg.Peers { |
||||
oldPeer := old[p.PublicKey] |
||||
setPeer(p) |
||||
set("protocol_version", "1") |
||||
|
||||
if !endpointsEqual(oldPeer.Endpoints, p.Endpoints) { |
||||
set("endpoint", p.Endpoints) |
||||
} |
||||
|
||||
// TODO: replace_allowed_ips is expensive.
|
||||
// If p.AllowedIPs is a strict superset of oldPeer.AllowedIPs,
|
||||
// then skip replace_allowed_ips and instead add only
|
||||
// the new ipps with allowed_ip.
|
||||
if !cidrsEqual(oldPeer.AllowedIPs, p.AllowedIPs) { |
||||
set("replace_allowed_ips", "true") |
||||
for _, ipp := range p.AllowedIPs { |
||||
set("allowed_ip", ipp.String()) |
||||
} |
||||
} |
||||
|
||||
// Set PersistentKeepalive after the peer is otherwise configured,
|
||||
// because it can trigger handshake packets.
|
||||
if oldPeer.PersistentKeepalive != p.PersistentKeepalive { |
||||
setUint16("persistent_keepalive_interval", p.PersistentKeepalive) |
||||
} |
||||
} |
||||
|
||||
// Remove peers that were present but should no longer be.
|
||||
for _, p := range cfg.Peers { |
||||
delete(old, p.PublicKey) |
||||
} |
||||
for _, p := range old { |
||||
setPeer(p) |
||||
set("remove", "true") |
||||
} |
||||
|
||||
if stickyErr != nil { |
||||
stickyErr = fmt.Errorf("ToUAPI: %w", stickyErr) |
||||
} |
||||
return stickyErr |
||||
} |
||||
|
||||
func endpointsEqual(x, y string) bool { |
||||
// Cheap comparisons.
|
||||
if x == y { |
||||
return true |
||||
} |
||||
xs := strings.Split(x, ",") |
||||
ys := strings.Split(y, ",") |
||||
if len(xs) != len(ys) { |
||||
return false |
||||
} |
||||
// Otherwise, see if they're the same, but out of order.
|
||||
sort.Strings(xs) |
||||
sort.Strings(ys) |
||||
x = strings.Join(xs, ",") |
||||
y = strings.Join(ys, ",") |
||||
return x == y |
||||
} |
||||
|
||||
func cidrsEqual(x, y []netaddr.IPPrefix) bool { |
||||
// TODO: re-implement using netaddr.IPSet.Equal.
|
||||
if len(x) != len(y) { |
||||
return false |
||||
} |
||||
// First see if they're equal in order, without allocating.
|
||||
exact := true |
||||
for i := range x { |
||||
if x[i] != y[i] { |
||||
exact = false |
||||
break |
||||
} |
||||
} |
||||
if exact { |
||||
return true |
||||
} |
||||
|
||||
// Otherwise, see if they're the same, but out of order.
|
||||
m := make(map[netaddr.IPPrefix]bool) |
||||
for _, v := range x { |
||||
m[v] = true |
||||
} |
||||
for _, v := range y { |
||||
if !m[v] { |
||||
return false |
||||
} |
||||
} |
||||
return true |
||||
} |
||||
Loading…
Reference in new issue