Files
tailscale/tsnet/tsnet_test.go
T
Mario Minardi 02af7c963c tsnet: allow for automatic ID token generation
Allow for optionally specifying an audience for tsnet. This is passed
to the underlying identity federation logic to allow for tsnet auth to
use automatic ID token generation for authentication.

Updates https://github.com/tailscale/corp/issues/33316

Signed-off-by: Mario Minardi <mario@tailscale.com>
2026-01-14 09:02:43 -07:00

1736 lines
47 KiB
Go

// Copyright (c) Tailscale Inc & AUTHORS
// SPDX-License-Identifier: BSD-3-Clause
package tsnet
import (
"bufio"
"bytes"
"context"
"crypto/ecdsa"
"crypto/elliptic"
"crypto/rand"
"crypto/tls"
"crypto/x509"
"crypto/x509/pkix"
"encoding/json"
"errors"
"flag"
"fmt"
"io"
"log"
"math/big"
"net"
"net/http"
"net/http/httptest"
"net/netip"
"os"
"path/filepath"
"reflect"
"runtime"
"strings"
"sync"
"sync/atomic"
"testing"
"time"
"github.com/google/go-cmp/cmp"
dto "github.com/prometheus/client_model/go"
"github.com/prometheus/common/expfmt"
"golang.org/x/net/proxy"
"tailscale.com/client/local"
"tailscale.com/cmd/testwrapper/flakytest"
"tailscale.com/internal/client/tailscale"
"tailscale.com/ipn"
"tailscale.com/ipn/store/mem"
"tailscale.com/net/netns"
"tailscale.com/tailcfg"
"tailscale.com/tstest"
"tailscale.com/tstest/deptest"
"tailscale.com/tstest/integration"
"tailscale.com/tstest/integration/testcontrol"
"tailscale.com/types/key"
"tailscale.com/types/logger"
"tailscale.com/util/must"
)
// TestListener_Server guards the Server accessor on the listener type.
// Some external applications rely on that method both to tell a
// tsnet.Listener apart from other net.Listeners and to reach the Server
// behind it, so it must keep returning the Server the listener was built with.
func TestListener_Server(t *testing.T) {
	srv := new(Server)
	l := listener{s: srv}
	if got := l.Server(); got != srv {
		t.Errorf("listener.Server() returned %v, want %v", got, srv)
	}
}
// TestListenerPort feeds Server.Listen a table of network/address pairs and
// checks which ones are rejected.
//
// Each Server has its initOnce consumed up front with a sentinel initErr.
// The test treats that sentinel coming back from Listen as "no error"; any
// other non-nil error counts as a rejection of the arguments.
func TestListenerPort(t *testing.T) {
	errSentinel := errors.New("sentinel start error")

	type listenCase struct {
		network string
		addr    string
		wantErr bool
	}
	cases := []listenCase{
		{network: "tcp", addr: ":80", wantErr: false},
		{network: "foo", addr: ":80", wantErr: true},
		{network: "tcp", addr: ":http", wantErr: false},  // built-in name to Go; doesn't require cgo, /etc/services
		{network: "tcp", addr: ":https", wantErr: false}, // built-in name to Go; doesn't require cgo, /etc/services
		{network: "tcp", addr: ":gibberishsdlkfj", wantErr: true},
		{network: "tcp", addr: ":%!d(string=80)", wantErr: true}, // issue 6201
		{network: "udp", addr: ":80", wantErr: false},
		{network: "udp", addr: "100.102.104.108:80", wantErr: false},
		{network: "udp", addr: "not-an-ip:80", wantErr: true},
		{network: "udp4", addr: ":80", wantErr: false},
		{network: "udp4", addr: "100.102.104.108:80", wantErr: false},
		{network: "udp4", addr: "not-an-ip:80", wantErr: true},

		// Verify network type matches IP
		{network: "tcp4", addr: "1.2.3.4:80", wantErr: false},
		{network: "tcp6", addr: "1.2.3.4:80", wantErr: true},
		{network: "tcp4", addr: "[12::34]:80", wantErr: true},
		{network: "tcp6", addr: "[12::34]:80", wantErr: false},
	}

	for _, tc := range cases {
		srv := &Server{}
		srv.initOnce.Do(func() { srv.initErr = errSentinel })

		_, err := srv.Listen(tc.network, tc.addr)
		rejected := err != nil && err != errSentinel
		if rejected == tc.wantErr {
			continue
		}
		t.Errorf("Listen(%q, %q) error = %v, want %v", tc.network, tc.addr, rejected, tc.wantErr)
	}
}
// Command-line switches that turn on extra logging from the test helpers.
var (
	// verboseDERP enables printing of DERP and STUN logs.
	verboseDERP = flag.Bool("verbose-derp", false, "if set, print DERP and STUN logs")
	// verboseNodes enables printing of tsnet.Server logs.
	verboseNodes = flag.Bool("verbose-nodes", false, "if set, print tsnet.Server logs")
)
// startControl starts a testcontrol.Server behind an httptest server and
// returns its URL along with the server itself.
//
// The control server is given the DERP map produced by
// integration.RunDERPAndSTUN on 127.0.0.1. netns is disabled for the
// duration of the test and re-enabled during cleanup; the HTTP test server is
// closed during cleanup as well.
func startControl(t *testing.T) (controlURL string, control *testcontrol.Server) {
	// Corp#4520: don't use netns for tests.
	netns.SetEnabled(false)
	t.Cleanup(func() { netns.SetEnabled(true) })

	// DERP/STUN logs are discarded unless -verbose-derp was passed.
	derpLogf := logger.Discard
	if *verboseDERP {
		derpLogf = t.Logf
	}

	control = &testcontrol.Server{
		DERPMap:        integration.RunDERPAndSTUN(t, derpLogf, "127.0.0.1"),
		DNSConfig:      &tailcfg.DNSConfig{Proxied: true},
		MagicDNSDomain: "tail-scale.ts.net",
		Logf:           t.Logf,
	}

	ts := httptest.NewUnstartedServer(control)
	control.HTTPTestServer = ts
	ts.Start()
	t.Cleanup(ts.Close)

	controlURL = ts.URL
	t.Logf("testcontrol listening on %s", controlURL)
	return controlURL, control
}
// testCertIssuer is a throwaway, in-memory certificate authority for tests.
// It holds a self-signed root and issues one leaf certificate per requested
// TLS server name, caching each leaf so repeated handshakes reuse it.
type testCertIssuer struct {
	mu     sync.Mutex
	certs  map[string]*tls.Certificate // issued leaves keyed by server name; guarded by mu
	serial int64                       // last serial number handed out; guarded by mu

	root    *x509.Certificate
	rootKey *ecdsa.PrivateKey
}

// newCertIssuer generates a fresh P-256 root key and a self-signed root
// certificate valid for one hour. It panics on any failure, since it runs
// while package-level variables are initialized.
func newCertIssuer() *testCertIssuer {
	rootKey, err := ecdsa.GenerateKey(elliptic.P256(), rand.Reader)
	if err != nil {
		panic(err)
	}
	const rootSerial = 1
	t := &x509.Certificate{
		SerialNumber: big.NewInt(rootSerial),
		Subject: pkix.Name{
			CommonName: "root",
		},
		NotBefore:             time.Now(),
		NotAfter:              time.Now().Add(time.Hour),
		IsCA:                  true,
		ExtKeyUsage:           []x509.ExtKeyUsage{x509.ExtKeyUsageServerAuth},
		KeyUsage:              x509.KeyUsageCertSign,
		BasicConstraintsValid: true,
	}
	// Template and parent are the same certificate: the root is self-signed.
	rootDER, err := x509.CreateCertificate(rand.Reader, t, t, &rootKey.PublicKey, rootKey)
	if err != nil {
		panic(err)
	}
	rootCA, err := x509.ParseCertificate(rootDER)
	if err != nil {
		panic(err)
	}
	return &testCertIssuer{
		certs:   make(map[string]*tls.Certificate),
		serial:  rootSerial,
		root:    rootCA,
		rootKey: rootKey,
	}
}

// getCert returns a certificate for chi.ServerName, issuing and caching a new
// leaf signed by the root on first use. Its signature matches
// tls.Config.GetCertificate.
//
// Every leaf gets its own serial number, distinct from the root's, so that no
// two certificates from this issuer share an (issuer, serial) pair.
func (tci *testCertIssuer) getCert(chi *tls.ClientHelloInfo) (*tls.Certificate, error) {
	tci.mu.Lock()
	defer tci.mu.Unlock()
	cert, ok := tci.certs[chi.ServerName]
	if ok {
		return cert, nil
	}

	certPrivKey, err := ecdsa.GenerateKey(elliptic.P256(), rand.Reader)
	if err != nil {
		return nil, err
	}
	tci.serial++
	certTmpl := &x509.Certificate{
		SerialNumber: big.NewInt(tci.serial),
		DNSNames:     []string{chi.ServerName},
		NotBefore:    time.Now(),
		NotAfter:     time.Now().Add(time.Hour),
	}
	certDER, err := x509.CreateCertificate(rand.Reader, certTmpl, tci.root, &certPrivKey.PublicKey, tci.rootKey)
	if err != nil {
		return nil, err
	}
	cert = &tls.Certificate{
		// Leaf first, followed by the root that signed it.
		Certificate: [][]byte{certDER, tci.root.Raw},
		PrivateKey:  certPrivKey,
	}
	tci.certs[chi.ServerName] = cert
	return cert, nil
}

// Pool returns a new certificate pool containing only the issuer's root, for
// use as the RootCAs of a TLS client.
func (tci *testCertIssuer) Pool() *x509.CertPool {
	p := x509.NewCertPool()
	p.AddCert(tci.root)
	return p
}

// testCertRoot is the issuer shared by all tests in this file.
var testCertRoot = newCertIssuer()
// startServer creates an ephemeral tsnet Server named hostname that talks to
// controlURL, brings it up, and returns the Server together with its first
// Tailscale IP and its node public key.
//
// State lives in an in-memory store under a per-test temporary directory, and
// certificates come from testCertRoot. The Server is closed during test
// cleanup. Any setup failure is fatal to the test.
func startServer(t *testing.T, ctx context.Context, controlURL, hostname string) (*Server, netip.Addr, key.NodePublic) {
	t.Helper()

	tmp := filepath.Join(t.TempDir(), hostname)
	if err := os.MkdirAll(tmp, 0755); err != nil {
		t.Fatal(err)
	}

	s := &Server{
		Dir:               tmp,
		ControlURL:        controlURL,
		Hostname:          hostname,
		Store:             new(mem.Store),
		Ephemeral:         true,
		getCertForTesting: testCertRoot.getCert,
	}
	if *verboseNodes {
		s.Logf = t.Logf
	}
	t.Cleanup(func() { s.Close() })

	status, err := s.Up(ctx)
	if err != nil {
		t.Fatal(err)
	}
	// Fail with a readable message instead of an index-out-of-range panic.
	if len(status.TailscaleIPs) == 0 {
		t.Fatalf("node %q came up without any Tailscale IPs", hostname)
	}
	return s, status.TailscaleIPs[0], status.Self.PublicKey
}
// TestDialBlocks checks that Dial works on a Server that was never explicitly
// brought up: s2 is only constructed, and its first Dial must succeed in
// reaching a listener on the already-running s1.
func TestDialBlocks(t *testing.T) {
	tstest.Shard(t)
	tstest.ResourceCheck(t)
	ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second)
	defer cancel()

	controlURL, _ := startControl(t)

	// Make one tsnet that blocks until it's up.
	s1, _, _ := startServer(t, ctx, controlURL, "s1")

	ln, err := s1.Listen("tcp", ":8080")
	if err != nil {
		t.Fatal(err)
	}
	defer ln.Close()

	// Then make another tsnet node that will only be woken up
	// upon the first dial.
	tmp := filepath.Join(t.TempDir(), "s2")
	if err := os.MkdirAll(tmp, 0755); err != nil {
		t.Fatal(err)
	}
	s2 := &Server{
		Dir:               tmp,
		ControlURL:        controlURL,
		Hostname:          "s2",
		Store:             new(mem.Store),
		Ephemeral:         true,
		getCertForTesting: testCertRoot.getCert,
	}
	if *verboseNodes {
		s2.Logf = log.Printf
	}
	t.Cleanup(func() { s2.Close() })

	c, err := s2.Dial(ctx, "tcp", "s1:8080")
	if err != nil {
		t.Fatal(err)
	}
	defer c.Close()
}
// TestConn tests basic TCP connections between two tsnet Servers, s1 and s2:
//
//   - s1, a subnet router, first listens on its TCP :8081.
//   - s2 can connect to s1:8081
//   - s2 cannot connect to s1:8082 (no listener)
//   - s2 can dial through the subnet router functionality (getting a synthetic RST
//     that we verify we generated & saw)
func TestConn(t *testing.T) {
	tstest.Shard(t)
	tstest.ResourceCheck(t)
	ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second)
	defer cancel()

	controlURL, control := startControl(t)
	s1, s1ip, s1PubKey := startServer(t, ctx, controlURL, "s1")

	// Track whether we saw an attempted dial to 192.0.2.1:8081.
	var saw192DocNetDial atomic.Bool
	s1.RegisterFallbackTCPHandler(func(src, dst netip.AddrPort) (handler func(net.Conn), intercept bool) {
		t.Logf("s1: fallback TCP handler called for %v -> %v", src, dst)
		if dst.String() == "192.0.2.1:8081" {
			saw192DocNetDial.Store(true)
		}
		return nil, true // nil handler but intercept=true means to send RST
	})

	lc1 := must.Get(s1.LocalClient())

	must.Get(lc1.EditPrefs(ctx, &ipn.MaskedPrefs{
		Prefs: ipn.Prefs{
			AdvertiseRoutes: []netip.Prefix{netip.MustParsePrefix("192.0.2.0/24")},
		},
		AdvertiseRoutesSet: true,
	}))
	control.SetSubnetRoutes(s1PubKey, []netip.Prefix{netip.MustParsePrefix("192.0.2.0/24")})

	// Start s2 after s1 is fully set up, including advertising its routes,
	// otherwise the test is flaky if the test starts dialing through s2 before
	// our test control server has told s2 about s1's routes.
	s2, _, _ := startServer(t, ctx, controlURL, "s2")
	lc2 := must.Get(s2.LocalClient())
	must.Get(lc2.EditPrefs(ctx, &ipn.MaskedPrefs{
		Prefs: ipn.Prefs{
			RouteAll: true,
		},
		RouteAllSet: true,
	}))

	// ping to make sure the connection is up.
	res, err := lc2.Ping(ctx, s1ip, tailcfg.PingTSMP)
	if err != nil {
		t.Fatal(err)
	}
	t.Logf("ping success: %#+v", res)

	// pass some data through TCP.
	ln, err := s1.Listen("tcp", ":8081")
	if err != nil {
		t.Fatal(err)
	}
	defer ln.Close()

	s1Conns := make(chan net.Conn)
	go func() {
		for {
			conn, err := ln.Accept()
			if err != nil {
				// The deferred ln.Close above runs before the deferred
				// cancel, so at teardown Accept can fail while ctx is
				// still live. A closed listener is a normal way for this
				// loop to end and must not be reported as a failure.
				if ctx.Err() != nil || errors.Is(err, net.ErrClosed) {
					return
				}
				t.Errorf("s1.Accept: %v", err)
				return
			}
			select {
			case s1Conns <- conn:
			case <-ctx.Done():
				conn.Close()
			}
		}
	}()

	w, err := s2.Dial(ctx, "tcp", fmt.Sprintf("%s:8081", s1ip))
	if err != nil {
		t.Fatal(err)
	}
	defer w.Close()

	want := "hello"
	if _, err := io.WriteString(w, want); err != nil {
		t.Fatal(err)
	}

	select {
	case <-time.After(5 * time.Second):
		t.Fatal("timeout waiting for connection")
	case r := <-s1Conns:
		got := make([]byte, len(want))
		_, err := io.ReadAtLeast(r, got, len(got))
		r.Close()
		if err != nil {
			t.Fatal(err)
		}
		t.Logf("got: %q", got)
		if string(got) != want {
			t.Errorf("got %q, want %q", got, want)
		}
	}

	// Dial a non-existent port on s1 and expect it to fail.
	_, err = s2.Dial(ctx, "tcp", fmt.Sprintf("%s:8082", s1ip)) // some random port
	if err == nil {
		t.Fatalf("unexpected success; should have seen a connection refused error")
	}
	t.Logf("got expected failure: %v", err)

	// s1 is a subnet router for TEST-NET-1 (192.0.2.0/24). Let's dial to that
	// subnet from s2 to ensure a listener without an IP address (i.e. our
	// ":8081" listen above) only matches destination IPs corresponding to the
	// s1 node's IP addresses, and not to any random IP of a subnet it's routing.
	//
	// The RegisterFallbackTCPHandler on s1 above handles sending a RST when the
	// TCP SYN arrives from s2. But we bound it to 5 seconds lest a regression
	// like tailscale/tailscale#17805 recur.
	s2dialer := s2.Sys().Dialer.Get()
	s2dialer.SetSystemDialerForTest(func(ctx context.Context, netw, addr string) (net.Conn, error) {
		t.Logf("s2: unexpected system dial called for %s %s", netw, addr)
		return nil, fmt.Errorf("system dialer called unexpectedly for %s %s", netw, addr)
	})
	docCtx, docCancel := context.WithTimeout(ctx, 5*time.Second)
	defer docCancel()
	_, err = s2.Dial(docCtx, "tcp", "192.0.2.1:8081")
	if err == nil {
		t.Fatalf("unexpected success; should have seen a connection refused error")
	}
	if !saw192DocNetDial.Load() {
		t.Errorf("expected s1's fallback TCP handler to have been called for 192.0.2.1:8081")
	}
}
// TestLoopbackLocalAPI exercises the access checks on the LocalAPI exposed
// through Server.Loopback. The status endpoint is requested four times with
// every combination of the Sec-Tailscale header and basic-auth credential;
// only the request carrying both is expected to return 200.
func TestLoopbackLocalAPI(t *testing.T) {
	flakytest.Mark(t, "https://github.com/tailscale/tailscale/issues/8557")
	tstest.Shard(t)
	tstest.ResourceCheck(t)
	ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second)
	defer cancel()

	controlURL, _ := startControl(t)
	s1, _, _ := startServer(t, ctx, controlURL, "s1")

	addr, proxyCred, localAPICred, err := s1.Loopback()
	if err != nil {
		t.Fatal(err)
	}
	if proxyCred == localAPICred {
		t.Fatal("proxy password matches local API password, they should be different")
	}

	url := "http://" + addr + "/localapi/v0/status"

	// statusOf performs one GET against url, optionally adding the
	// Sec-Tailscale header and/or the LocalAPI basic-auth credential, and
	// returns the HTTP status code. Transport-level failures are fatal.
	statusOf := func(withSecHeader, withBasicAuth bool) int {
		req, err := http.NewRequestWithContext(ctx, "GET", url, nil)
		if err != nil {
			t.Fatal(err)
		}
		if withSecHeader {
			req.Header.Set("Sec-Tailscale", "localapi")
		}
		if withBasicAuth {
			req.SetBasicAuth("", localAPICred)
		}
		res, err := http.DefaultClient.Do(req)
		if err != nil {
			t.Fatal(err)
		}
		res.Body.Close()
		return res.StatusCode
	}

	// Neither header nor credential.
	if code := statusOf(false, false); code != 403 {
		t.Errorf("GET %s returned %d, want 403 without Sec- header", url, code)
	}

	// Header only.
	if code := statusOf(true, false); code != 401 {
		t.Errorf("GET %s returned %d, want 401 without basic auth", url, code)
	}

	// Credential only.
	if code := statusOf(false, true); code != 403 {
		t.Errorf("GET %s returned %d, want 403 without Sec- header", url, code)
	}

	// Header and credential.
	if code := statusOf(true, true); code != 200 {
		t.Errorf("GET /status returned %d, want 200", code)
	}
}
// TestLoopbackSOCKS5 dials s1 through the SOCKS5 proxy that s2 exposes via
// Server.Loopback and checks that bytes written by the proxy client arrive
// on s1's listener.
func TestLoopbackSOCKS5(t *testing.T) {
	flakytest.Mark(t, "https://github.com/tailscale/tailscale/issues/8198")
	tstest.Shard(t)
	ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second)
	defer cancel()

	controlURL, _ := startControl(t)
	s1, s1ip, _ := startServer(t, ctx, controlURL, "s1")
	s2, _, _ := startServer(t, ctx, controlURL, "s2")

	addr, proxyCred, _, err := s2.Loopback()
	if err != nil {
		t.Fatal(err)
	}

	ln, err := s1.Listen("tcp", ":8081")
	if err != nil {
		t.Fatal(err)
	}
	defer ln.Close()

	auth := &proxy.Auth{User: "tsnet", Password: proxyCred}
	socksDialer, err := proxy.SOCKS5("tcp", addr, auth, proxy.Direct)
	if err != nil {
		t.Fatal(err)
	}

	w, err := socksDialer.Dial("tcp", fmt.Sprintf("%s:8081", s1ip))
	if err != nil {
		t.Fatal(err)
	}
	// Release the proxied connection when the test ends.
	defer w.Close()

	r, err := ln.Accept()
	if err != nil {
		t.Fatal(err)
	}
	// Release the accepted side as well.
	defer r.Close()

	want := "hello"
	if _, err := io.WriteString(w, want); err != nil {
		t.Fatal(err)
	}

	got := make([]byte, len(want))
	if _, err := io.ReadAtLeast(r, got, len(got)); err != nil {
		t.Fatal(err)
	}
	t.Logf("got: %q", got)
	if string(got) != want {
		t.Errorf("got %q, want %q", got, want)
	}
}
// TestTailscaleIPs checks that Server.TailscaleIPs reports the same IPv4 and
// IPv6 addresses that Server.Up returned in its status.
func TestTailscaleIPs(t *testing.T) {
	tstest.Shard(t)
	controlURL, _ := startControl(t)

	tmp := t.TempDir()
	tmps1 := filepath.Join(tmp, "s1")
	if err := os.MkdirAll(tmps1, 0755); err != nil {
		t.Fatal(err)
	}
	s1 := &Server{
		Dir:        tmps1,
		ControlURL: controlURL,
		Hostname:   "s1",
		Store:      new(mem.Store),
		Ephemeral:  true,
	}
	defer s1.Close()

	ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
	defer cancel()

	s1status, err := s1.Up(ctx)
	if err != nil {
		t.Fatal(err)
	}

	// Pick the v4 and v6 addresses out of the status; if several of one
	// family are present, the last one wins.
	var upIP4, upIP6 netip.Addr
	for _, ip := range s1status.TailscaleIPs {
		if ip.Is6() {
			upIP6 = ip
		}
		if ip.Is4() {
			upIP4 = ip
		}
	}

	sIP4, sIP6 := s1.TailscaleIPs()
	if !(upIP4 == sIP4 && upIP6 == sIP6) {
		// Left pair is what TailscaleIPs returned, right pair is what Up
		// reported, matching the wording of the message.
		t.Errorf("s1.TailscaleIPs returned a different result than S1.Up, (%s, %s) != (%s, %s)",
			sIP4, sIP6, upIP4, upIP6)
	}
}
// TestListenerCleanup is a regression test: closing a Server while one of its
// listeners is still open must not deadlock, and the listener must behave
// sanely afterwards.
func TestListenerCleanup(t *testing.T) {
	tstest.Shard(t)
	ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second)
	defer cancel()

	controlURL, _ := startControl(t)
	srv, _, _ := startServer(t, ctx, controlURL, "s1")

	ln, err := srv.Listen("tcp", ":8081")
	if err != nil {
		t.Fatal(err)
	}

	// Close the server first, with the listener still open.
	if err := srv.Close(); err != nil {
		t.Fatal(err)
	}

	// Closing the listener afterwards must report that it is already closed.
	if err := ln.Close(); !errors.Is(err, net.ErrClosed) {
		t.Fatalf("second ln.Close error: %v, want net.ErrClosed", err)
	}

	// Verify that handling a connection from gVisor (from a packet arriving)
	// after a listener closed doesn't panic (previously: sending on a closed
	// channel) or hang.
	tracked := new(closeTrackConn)
	ln.(*listener).handle(tracked)
	if !tracked.closed {
		t.Errorf("c.closed = false, want true")
	}
}
// closeTrackConn is a stub net.Conn that only remembers whether Close was
// called. The embedded net.Conn is left nil, so any other method panics.
type closeTrackConn struct {
	net.Conn
	closed bool
}

// Close marks the connection as closed and always succeeds.
func (c *closeTrackConn) Close() error {
	c.closed = true
	return nil
}
// TestStartStopStartGetsSameIP tests
// https://github.com/tailscale/tailscale/issues/6973 -- that we can start a
// tsnet server, stop it, and restart it, even on Windows. Both runs share the
// same state directory and must come up with the same Tailscale IPs.
func TestStartStopStartGetsSameIP(t *testing.T) {
	tstest.Shard(t)
	controlURL, _ := startControl(t)

	tmp := t.TempDir()
	tmps1 := filepath.Join(tmp, "s1")
	if err := os.MkdirAll(tmps1, 0755); err != nil {
		t.Fatal(err)
	}

	// newServer builds a Server over the shared state directory; it is
	// called once per run.
	newServer := func() *Server {
		return &Server{
			Dir:        tmps1,
			ControlURL: controlURL,
			Hostname:   "s1",
			Logf:       tstest.WhileTestRunningLogger(t),
		}
	}
	s1 := newServer()
	defer s1.Close()

	ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
	defer cancel()

	s1status, err := s1.Up(ctx)
	if err != nil {
		t.Fatal(err)
	}

	firstIPs := s1status.TailscaleIPs
	t.Logf("IPs: %v", firstIPs)

	if err := s1.Close(); err != nil {
		t.Fatalf("Close: %v", err)
	}

	// Second run: a fresh Server value over the same directory.
	s2 := newServer()
	defer s2.Close()

	s2status, err := s2.Up(ctx)
	if err != nil {
		t.Fatalf("second Up: %v", err)
	}

	secondIPs := s2status.TailscaleIPs
	t.Logf("IPs: %v", secondIPs)

	if !reflect.DeepEqual(firstIPs, secondIPs) {
		t.Fatalf("got %v but later %v", firstIPs, secondIPs)
	}
}
// TestFunnel serves HTTPS on a Funnel listener on s1 and fetches it from s2
// over a hand-rolled ingress connection (see dialIngressConn). The server's
// ConnContext hook checks that each accepted connection is a TLS connection
// wrapping an *ipn.FunnelConn with the expected source and target.
func TestFunnel(t *testing.T) {
	tstest.Shard(t)
	ctx, dialCancel := context.WithTimeout(context.Background(), 30*time.Second)
	defer dialCancel()

	controlURL, _ := startControl(t)
	s1, _, _ := startServer(t, ctx, controlURL, "s1")
	s2, _, _ := startServer(t, ctx, controlURL, "s2")

	ln := must.Get(s1.ListenFunnel("tcp", ":443"))
	defer ln.Close()
	wantSrcAddrPort := netip.MustParseAddrPort("127.0.0.1:1234")
	wantTarget := ipn.HostPort("s1.tail-scale.ts.net:443")
	srv := &http.Server{
		ConnContext: func(ctx context.Context, c net.Conn) context.Context {
			tc, ok := c.(*tls.Conn)
			if !ok {
				t.Errorf("ConnContext called with non-TLS conn: %T", c)
				// tc is nil here; going on to tc.NetConn() would panic.
				return ctx
			}
			inner := tc.NetConn()
			if fc, ok := inner.(*ipn.FunnelConn); !ok {
				// Report the type of the wrapped conn, which is what failed
				// the assertion.
				t.Errorf("ConnContext called with non-FunnelConn: %T", inner)
			} else if fc.Src != wantSrcAddrPort {
				t.Errorf("ConnContext called with wrong SrcAddrPort; got %v, want %v", fc.Src, wantSrcAddrPort)
			} else if fc.Target != wantTarget {
				t.Errorf("ConnContext called with wrong Target; got %q, want %q", fc.Target, wantTarget)
			}
			return ctx
		},
		Handler: http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
			fmt.Fprintf(w, "hello")
		}),
	}
	// Serve returns once ln is closed by the deferred ln.Close above.
	go srv.Serve(ln)

	c := &http.Client{
		Transport: &http.Transport{
			// Every connection goes through s1's ingress endpoint, dialed
			// from s2, regardless of the requested address.
			DialContext: func(ctx context.Context, network, addr string) (net.Conn, error) {
				return dialIngressConn(s2, s1, addr)
			},
			TLSClientConfig: &tls.Config{
				RootCAs: testCertRoot.Pool(),
			},
		},
	}
	resp, err := c.Get("https://s1.tail-scale.ts.net:443")
	if err != nil {
		t.Fatal(err)
	}
	defer resp.Body.Close()
	if resp.StatusCode != 200 {
		t.Errorf("unexpected status code: %v", resp.StatusCode)
		return
	}
	body, err := io.ReadAll(resp.Body)
	if err != nil {
		t.Fatal(err)
	}
	if string(body) != "hello" {
		t.Errorf("unexpected body: %q", body)
	}
}
// TestFunnelClose checks that a listener from ListenFunnel cleans up after
// itself on Close: serve-config changes it made are undone, while config it
// did not create is left alone.
func TestFunnelClose(t *testing.T) {
	// serveConfigJSON renders a serve config as indented JSON so that two
	// configs can be diffed as text.
	serveConfigJSON := func(t *testing.T, sc ipn.ServeConfigView) string {
		t.Helper()
		return string(must.Get(json.MarshalIndent(sc, "", "\t")))
	}

	t.Run("simple", func(t *testing.T) {
		controlURL, _ := startControl(t)
		srv, _, _ := startServer(t, t.Context(), controlURL, "s")

		original := srv.lb.ServeConfig()

		funnelLn := must.Get(srv.ListenFunnel("tcp", ":443"))
		funnelLn.Close()

		final := srv.lb.ServeConfig()
		got, want := serveConfigJSON(t, final), serveConfigJSON(t, original)
		if diff := cmp.Diff(got, want); diff != "" {
			t.Fatalf("expected serve config to be unchanged after close (-got, +want):\n%s", diff)
		}
	})

	// Closing the listener shouldn't clear out config that predates it.
	t.Run("no_clobbering", func(t *testing.T) {
		controlURL, _ := startControl(t)
		srv, _, _ := startServer(t, t.Context(), controlURL, "s")

		// To obtain config the listener might want to clobber, we:
		//   - run a listener
		//   - grab the config
		//   - close the listener (clearing config)
		funnelLn := must.Get(srv.ListenFunnel("tcp", ":443"))
		preexisting := srv.lb.ServeConfig()
		funnelLn.Close()

		// Now we manually write the config to the local backend (it should have
		// been cleared), run the listener again, and close it again.
		must.Do(srv.lb.SetServeConfig(preexisting.AsStruct(), ""))
		funnelLn = must.Get(srv.ListenFunnel("tcp", ":443"))
		funnelLn.Close()

		// The config should not have been cleared this time since it predated
		// the most recent run.
		final := srv.lb.ServeConfig()
		got, want := serveConfigJSON(t, final), serveConfigJSON(t, preexisting)
		if diff := cmp.Diff(got, want); diff != "" {
			t.Fatalf("expected existing config to remain intact (-got, +want):\n%s", diff)
		}
	})

	// Closing one listener shouldn't affect config for another listener.
	t.Run("two_listeners", func(t *testing.T) {
		controlURL, _ := startControl(t)
		srv, _, _ := startServer(t, t.Context(), controlURL, "s1")

		// Start a listener on 443.
		first := must.Get(srv.ListenFunnel("tcp", ":443"))
		defer first.Close()

		// Save the serve config for this original listener.
		withFirstOnly := srv.lb.ServeConfig()

		// Now start and close a new listener on a different port.
		second := must.Get(srv.ListenFunnel("tcp", ":8080"))
		second.Close()

		// The serve config for the original listener should be intact.
		final := srv.lb.ServeConfig()
		got, want := serveConfigJSON(t, final), serveConfigJSON(t, withFirstOnly)
		if diff := cmp.Diff(got, want); diff != "" {
			t.Fatalf("expected existing config to remain intact (-got, +want):\n%s", diff)
		}
	})

	// It should be possible to close a listener and free system resources even
	// when the Server has been closed (or the listener should be automatically
	// closed).
	t.Run("after_server_close", func(t *testing.T) {
		controlURL, _ := startControl(t)
		srv, _, _ := startServer(t, t.Context(), controlURL, "s")

		funnelLn := must.Get(srv.ListenFunnel("tcp", ":443"))

		// Close the server, then close the listener.
		must.Do(srv.Close())
		// We don't care whether we get an error from the listener closing.
		funnelLn.Close()

		// The listener should immediately return an error indicating closure.
		_, err := funnelLn.Accept()
		// Looking for a string in the error sucks, but it's supposed to stay
		// consistent:
		// https://github.com/golang/go/blob/108b333d510c1f60877ac917375d7931791acfe6/src/internal/poll/fd.go#L20-L24
		if err == nil || !strings.Contains(err.Error(), "use of closed network connection") {
			t.Fatal("expected listener to be closed, got:", err)
		}
	})
}
// TestListenerClose checks that closing a listener unblocks a pending Accept
// with an error that matches net.ErrClosed.
func TestListenerClose(t *testing.T) {
	tstest.Shard(t)
	// Bound server startup like the other tests in this file do, so a stuck
	// Up fails this test instead of hanging until the global test timeout.
	ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second)
	defer cancel()

	controlURL, _ := startControl(t)
	s1, _, _ := startServer(t, ctx, controlURL, "s1")

	ln, err := s1.Listen("tcp", ":8080")
	if err != nil {
		t.Fatal(err)
	}

	// Buffered so the goroutine can finish even if the test has already
	// given up waiting.
	errc := make(chan error, 1)
	go func() {
		c, err := ln.Accept()
		if c != nil {
			c.Close()
		}
		errc <- err
	}()

	ln.Close()

	select {
	case err := <-errc:
		if !errors.Is(err, net.ErrClosed) {
			t.Errorf("unexpected error: %v", err)
		}
	case <-time.After(10 * time.Second):
		t.Fatal("timeout waiting for Accept to return")
	}
}
// dialIngressConn opens a connection from one Server to the PeerAPI of
// another and upgrades it through the /v0/ingress endpoint, announcing
// 127.0.0.1:1234 as the ingress source and target as the ingress target.
//
// On success the returned conn reads through the buffered reader that
// consumed the HTTP response, so bytes that arrived right behind the 101
// response are not lost. On any failure the underlying connection is closed
// and an error is returned.
func dialIngressConn(from, to *Server, target string) (net.Conn, error) {
	toLC, err := to.LocalClient()
	if err != nil {
		return nil, err
	}
	toStatus, err := toLC.StatusWithoutPeers(context.Background())
	if err != nil {
		return nil, err
	}
	// Index 1 is used as the IPv6 PeerAPI URL; make sure it exists rather
	// than panicking on a short slice.
	if len(toStatus.Self.PeerAPIURL) < 2 {
		return nil, fmt.Errorf("want at least 2 PeerAPIURLs, got %q", toStatus.Self.PeerAPIURL)
	}
	peer6 := toStatus.Self.PeerAPIURL[1] // IPv6
	toPeerAPI, ok := strings.CutPrefix(peer6, "http://")
	if !ok {
		return nil, fmt.Errorf("unexpected PeerAPIURL %q", peer6)
	}

	dialCtx, dialCancel := context.WithTimeout(context.Background(), 30*time.Second)
	outConn, err := from.Dial(dialCtx, "tcp", toPeerAPI)
	dialCancel()
	if err != nil {
		return nil, err
	}
	// From here on, every early return must release outConn.
	upgraded := false
	defer func() {
		if !upgraded {
			outConn.Close()
		}
	}()

	req, err := http.NewRequest("POST", "/v0/ingress", nil)
	if err != nil {
		return nil, err
	}
	req.Host = toPeerAPI
	req.Header.Set("Tailscale-Ingress-Src", "127.0.0.1:1234")
	req.Header.Set("Tailscale-Ingress-Target", target)
	if err := req.Write(outConn); err != nil {
		return nil, err
	}

	br := bufio.NewReader(outConn)
	res, err := http.ReadResponse(br, req)
	if err != nil {
		return nil, err
	}
	defer res.Body.Close() // just to appease vet
	if res.StatusCode != 101 {
		return nil, fmt.Errorf("unexpected status code: %v", res.StatusCode)
	}
	upgraded = true
	return &bufferedConn{outConn, br}, nil
}
// bufferedConn is a net.Conn whose reads are served from a bufio.Reader
// instead of going to the embedded connection directly. All other methods are
// those of the embedded net.Conn.
type bufferedConn struct {
	net.Conn
	reader *bufio.Reader
}

// Read fills p from the buffered reader.
func (bc *bufferedConn) Read(p []byte) (n int, err error) {
	n, err = bc.reader.Read(p)
	return n, err
}
// TestFallbackTCPHandler checks the registration lifecycle of a fallback TCP
// handler: while registered it is consulted for a dial to a port nobody
// listens on, and after deregistration it is no longer called. The handler
// never intercepts, so both dials are expected to fail.
func TestFallbackTCPHandler(t *testing.T) {
	tstest.Shard(t)
	tstest.ResourceCheck(t)
	ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second)
	defer cancel()

	controlURL, _ := startControl(t)
	s1, s1ip, _ := startServer(t, ctx, controlURL, "s1")
	s2, _, _ := startServer(t, ctx, controlURL, "s2")

	lc2, err := s2.LocalClient()
	if err != nil {
		t.Fatal(err)
	}

	// ping to make sure the connection is up.
	pong, err := lc2.Ping(ctx, s1ip, tailcfg.PingICMP)
	if err != nil {
		t.Fatal(err)
	}
	t.Logf("ping success: %#+v", pong)

	// Counts how many times s1 consulted the fallback handler.
	var fallbackCalls atomic.Int32
	deregister := s1.RegisterFallbackTCPHandler(func(src, dst netip.AddrPort) (handler func(net.Conn), intercept bool) {
		fallbackCalls.Add(1)
		return nil, false
	})

	// First dial: handler is registered, so the counter must reach 1.
	_, err = s2.Dial(ctx, "tcp", fmt.Sprintf("%s:8081", s1ip))
	if err == nil {
		t.Fatal("Expected dial error because fallback handler did not intercept")
	}
	if n := fallbackCalls.Load(); n != 1 {
		t.Errorf("s1TcpConnCount = %d, want %d", n, 1)
	}

	deregister()

	// Second dial: handler is gone, so the counter must stay at 1.
	_, err = s2.Dial(ctx, "tcp", fmt.Sprintf("%s:8081", s1ip))
	if err == nil {
		t.Fatal("Expected dial error because nothing would intercept")
	}
	if n := fallbackCalls.Load(); n != 1 {
		t.Errorf("s1TcpConnCount = %d, want %d", n, 1)
	}
}
// TestCapturePcap turns on pcap capture on two nodes, pings one from the
// other, and then waits for both capture files to grow beyond the size the
// test treats as a bare pcap header.
func TestCapturePcap(t *testing.T) {
	tstest.Shard(t)
	const timeLimit = 120
	ctx, cancel := context.WithTimeout(context.Background(), timeLimit*time.Second)
	defer cancel()

	pcapDir := t.TempDir()
	s1Pcap := filepath.Join(pcapDir, "s1.pcap")
	s2Pcap := filepath.Join(pcapDir, "s2.pcap")

	controlURL, _ := startControl(t)
	s1, s1ip, _ := startServer(t, ctx, controlURL, "s1")
	s2, _, _ := startServer(t, ctx, controlURL, "s2")
	s1.CapturePcap(ctx, s1Pcap)
	s2.CapturePcap(ctx, s2Pcap)

	lc2, err := s2.LocalClient()
	if err != nil {
		t.Fatal(err)
	}

	// send a packet which both nodes will capture
	pong, err := lc2.Ping(ctx, s1ip, tailcfg.PingICMP)
	if err != nil {
		t.Fatal(err)
	}
	t.Logf("ping success: %#+v", pong)

	// sizeOf reports the size of the named file, or 0 if it cannot be
	// stat'ed (for example because it does not exist yet).
	sizeOf := func(name string) int64 {
		if fi, err := os.Stat(name); err == nil {
			return fi.Size()
		}
		return 0
	}

	const pcapHeaderSize = 24

	// bothCaptured reports whether each file holds more than a header.
	bothCaptured := func() bool {
		return sizeOf(s1Pcap) > pcapHeaderSize && sizeOf(s2Pcap) > pcapHeaderSize
	}

	// there is a lag before the io.Copy writes a packet to the pcap files
	for attempt := 0; attempt < timeLimit*10; attempt++ {
		time.Sleep(100 * time.Millisecond)
		if bothCaptured() {
			break
		}
	}

	if size := sizeOf(s1Pcap); size <= pcapHeaderSize {
		t.Errorf("s1 pcap file size = %d, want > pcapHeaderSize(%d)", size, pcapHeaderSize)
	}
	if size := sizeOf(s2Pcap); size <= pcapHeaderSize {
		t.Errorf("s2 pcap file size = %d, want > pcapHeaderSize(%d)", size, pcapHeaderSize)
	}
}
// TestUDPConn sends one UDP datagram from s2 to a packet listener on s1,
// checks its payload and source address, then sends a reply back and checks
// that s2 receives it.
func TestUDPConn(t *testing.T) {
	tstest.Shard(t)
	tstest.ResourceCheck(t)
	ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second)
	defer cancel()

	controlURL, _ := startControl(t)
	s1, s1ip, _ := startServer(t, ctx, controlURL, "s1")
	s2, s2ip, _ := startServer(t, ctx, controlURL, "s2")

	lc2, err := s2.LocalClient()
	if err != nil {
		t.Fatal(err)
	}

	// ping to make sure the connection is up.
	pong, err := lc2.Ping(ctx, s1ip, tailcfg.PingICMP)
	if err != nil {
		t.Fatal(err)
	}
	t.Logf("ping success: %#+v", pong)

	const (
		request  = "hello"
		response = "world"
	)

	pc := must.Get(s1.ListenPacket("udp", fmt.Sprintf("%s:8081", s1ip)))
	defer pc.Close()

	// Dial to s1 from s2
	w, err := s2.Dial(ctx, "udp", fmt.Sprintf("%s:8081", s1ip))
	if err != nil {
		t.Fatal(err)
	}
	defer w.Close()

	// Send a packet from s2 to s1
	if _, err := io.WriteString(w, request); err != nil {
		t.Fatal(err)
	}

	// Receive the packet on s1
	inbound := make([]byte, 1024)
	n, from, err := pc.ReadFrom(inbound)
	if err != nil {
		t.Fatal(err)
	}
	inbound = inbound[:n]
	t.Logf("got: %q", inbound)
	if string(inbound) != request {
		t.Errorf("got %q, want %q", inbound, request)
	}
	if from.(*net.UDPAddr).AddrPort().Addr() != s2ip {
		t.Errorf("got from %v, want %v", from, s2ip)
	}

	// Write a response back to s2
	if _, err := pc.WriteTo([]byte(response), from); err != nil {
		t.Fatal(err)
	}

	// Receive the response on s2
	reply := make([]byte, 1024)
	n, err = w.Read(reply)
	if err != nil {
		t.Fatal(err)
	}
	reply = reply[:n]
	t.Logf("got: %q", reply)
	if string(reply) != response {
		t.Errorf("got %q, want world", reply)
	}
}
// parseMetrics parses metrics in the text exposition format and flattens them
// into a map from "name{label=\"value\",...}" (see promMetricLabelsStr) to
// value. Counter and gauge values are copied; any other metric type is
// recorded with the value 0.
func parseMetrics(m []byte) (map[string]float64, error) {
	var parser expfmt.TextParser
	families, err := parser.TextToMetricFamilies(bytes.NewReader(m))
	if err != nil {
		return nil, err
	}

	flat := make(map[string]float64)
	for _, family := range families {
		for _, sample := range family.Metric {
			var value float64
			switch family.GetType() {
			case dto.MetricType_COUNTER:
				value = sample.GetCounter().GetValue()
			case dto.MetricType_GAUGE:
				value = sample.GetGauge().GetValue()
			}
			key := family.GetName() + promMetricLabelsStr(sample.GetLabel())
			flat[key] = value
		}
	}
	return flat, nil
}
// promMetricLabelsStr renders label pairs as `{name="value",...}` in the
// order given, with values quoted via %q. It returns the empty string when
// there are no labels.
func promMetricLabelsStr(labels []*dto.LabelPair) string {
	if len(labels) == 0 {
		return ""
	}
	var sb strings.Builder
	sb.WriteByte('{')
	for i, pair := range labels {
		if i > 0 {
			sb.WriteByte(',')
		}
		fmt.Fprintf(&sb, "%s=%q", pair.GetName(), pair.GetValue())
	}
	sb.WriteByte('}')
	return sb.String()
}
// sendData transfers bytesCount bytes of 'A' over TCP from s2 to s1: s1
// listens on s1ip:8081, s2 dials it and writes the payload, and a receiver
// goroutine on the s1 side reads and validates every byte.
//
// It returns nil once the receiver has seen all bytesCount bytes, or the
// first dial, write, accept, read, or validation error. Progress is reported
// through logf. s2ip is unused. A failure to listen on s1 panics via must.Get.
func sendData(logf func(format string, args ...any), ctx context.Context, bytesCount int, s1, s2 *Server, s1ip, s2ip netip.Addr) error {
	ln := must.Get(s1.Listen("tcp", fmt.Sprintf("%s:8081", s1ip)))
	defer ln.Close()

	// Dial to s1 from s2
	w, err := s2.Dial(ctx, "tcp", fmt.Sprintf("%s:8081", s1ip))
	if err != nil {
		return err
	}
	defer w.Close()

	// Buffered and never closed: the receiver can always deliver its single
	// result without blocking and without risking a send on a closed channel.
	recvErr := make(chan error, 1)
	go func() {
		recvErr <- func() error {
			conn, err := ln.Accept()
			if err != nil {
				return err
			}
			defer conn.Close()
			// This side only reads, so the read deadline is what bounds it.
			conn.SetReadDeadline(time.Now().Add(30 * time.Second))

			buf := make([]byte, bytesCount)
			total := 0
			recvStart := time.Now()
			for total < bytesCount {
				n, err := conn.Read(buf)
				if err != nil {
					return fmt.Errorf("failed reading packet, %s", err)
				}
				total += n
				logf("received %d/%d bytes, %.2f %%", total, bytesCount, (float64(total) / (float64(bytesCount)) * 100))

				// Validate the received bytes to be the same as the sent bytes.
				for _, b := range buf[:n] {
					if b != 'A' {
						return fmt.Errorf("received unexpected byte: %c", b)
					}
				}
			}
			logf("all received, took: %s", time.Since(recvStart).String())
			return nil
		}()
	}()

	sendStart := time.Now()
	w.SetWriteDeadline(time.Now().Add(30 * time.Second))
	if _, err := w.Write(bytes.Repeat([]byte("A"), bytesCount)); err != nil {
		// Tear both ends down so the receiver unblocks, then wait for it so
		// that none of its logf calls outlive this function. The deferred
		// Close calls above then become harmless repeats.
		w.Close()
		ln.Close()
		<-recvErr
		return err
	}
	logf("all sent (%s), waiting for all packets (%d) to be received", time.Since(sendStart).String(), bytesCount)

	return <-recvErr
}
// TestUserMetricsByteCounters pushes 1 MiB through a TCP connection between
// two nodes and checks the direct_ipv4 byte counters in the user metrics:
// inbound on s1 and outbound on s2 must each be at least the payload size and
// at most 15% above it.
func TestUserMetricsByteCounters(t *testing.T) {
	tstest.Shard(t)
	ctx, cancel := context.WithTimeout(context.Background(), 120*time.Second)
	defer cancel()

	controlURL, _ := startControl(t)
	s1, s1ip, _ := startServer(t, ctx, controlURL, "s1")
	defer s1.Close()
	s2, s2ip, _ := startServer(t, ctx, controlURL, "s2")
	defer s2.Close()

	lc1, err := s1.LocalClient()
	if err != nil {
		t.Fatal(err)
	}
	lc2, err := s2.LocalClient()
	if err != nil {
		t.Fatal(err)
	}

	// Force an update to the netmap to ensure that the metrics are up-to-date.
	s1.lb.DebugForceNetmapUpdate()
	s2.lb.DebugForceNetmapUpdate()

	// Wait for both nodes to have a peer in their netmap.
	waitForCondition(t, "waiting for netmaps to contain peer", 90*time.Second, func() bool {
		pollCtx, pollCancel := context.WithTimeout(context.Background(), 5*time.Second)
		defer pollCancel()

		st1, err := lc1.Status(pollCtx)
		if err != nil {
			t.Logf("getting status: %s", err)
			return false
		}
		st2, err := lc2.Status(pollCtx)
		if err != nil {
			t.Logf("getting status: %s", err)
			return false
		}
		return len(st1.Peers()) > 0 && len(st2.Peers()) > 0
	})

	// ping to make sure the connection is up.
	pong, err := lc2.Ping(ctx, s1ip, tailcfg.PingICMP)
	if err != nil {
		t.Fatalf("pinging: %s", err)
	}
	t.Logf("ping success: %#+v", pong)

	mustDirect(t, t.Logf, lc1, lc2)

	// 1 MiB of payload.
	bytesToSend := 1 << 20

	// Generate the traffic whose size the counters are checked against.
	start := time.Now()
	if err := sendData(t.Logf, ctx, bytesToSend, s1, s2, s1ip, s2ip); err != nil {
		t.Fatalf("Failed to send packets: %v", err)
	}
	t.Logf("Sent %d bytes from s1 to s2 in %s", bytesToSend, time.Since(start).String())

	// Allow the metrics for the bytes sent to be off by 15%.
	bytesSentTolerance := 1.15
	lowerBound := float64(bytesToSend)
	upperBound := float64(bytesToSend) * bytesSentTolerance

	// s1: inbound counter.
	ctxLc, cancelLc := context.WithTimeout(context.Background(), 5*time.Second)
	defer cancelLc()
	metrics1, err := lc1.UserMetrics(ctxLc)
	if err != nil {
		t.Fatal(err)
	}
	parsedMetrics1, err := parseMetrics(metrics1)
	if err != nil {
		t.Fatal(err)
	}
	t.Logf("Metrics1:\n%s\n", metrics1)

	inboundBytes1 := parsedMetrics1[`tailscaled_inbound_bytes_total{path="direct_ipv4"}`]
	// The counter must cover at least the payload...
	if inboundBytes1 < lowerBound {
		t.Errorf(`metrics1, tailscaled_inbound_bytes_total{path="direct_ipv4"}: expected higher (or equal) than %d, got: %f`, bytesToSend, inboundBytes1)
	}
	// ...without overshooting it by more than the tolerance.
	if inboundBytes1 > upperBound {
		t.Errorf(`metrics1, tailscaled_inbound_bytes_total{path="direct_ipv4"}: expected lower than %f, got: %f`, upperBound, inboundBytes1)
	}

	// s2: outbound counter.
	metrics2, err := lc2.UserMetrics(ctx)
	if err != nil {
		t.Fatal(err)
	}
	parsedMetrics2, err := parseMetrics(metrics2)
	if err != nil {
		t.Fatal(err)
	}
	t.Logf("Metrics2:\n%s\n", metrics2)

	outboundBytes2 := parsedMetrics2[`tailscaled_outbound_bytes_total{path="direct_ipv4"}`]
	// The counter must cover at least the payload...
	if outboundBytes2 < lowerBound {
		t.Errorf(`metrics2, tailscaled_outbound_bytes_total{path="direct_ipv4"}: expected higher (or equal) than %d, got: %f`, bytesToSend, outboundBytes2)
	}
	// ...without overshooting it by more than the tolerance.
	if outboundBytes2 > upperBound {
		t.Errorf(`metrics2, tailscaled_outbound_bytes_total{path="direct_ipv4"}: expected lower than %f, got: %f`, upperBound, outboundBytes2)
	}
}
// TestUserMetricsRouteGauges starts two tsnet servers, has s1 advertise a set
// of routes of which the control server approves a subset, and then checks the
// tailscaled_advertised_routes and tailscaled_approved_routes gauges reported
// by each node's user metrics. s2 advertises nothing and is expected to report
// zero for both gauges.
func TestUserMetricsRouteGauges(t *testing.T) {
	tstest.Shard(t)

	// Windows does not seem to support or report back routes when running in
	// userspace via tsnet. So, we skip this check on Windows.
	// TODO(kradalby): Figure out if this is correct.
	if runtime.GOOS == "windows" {
		t.Skipf("skipping on windows")
	}

	ctx, cancel := context.WithTimeout(context.Background(), 120*time.Second)
	defer cancel()

	controlURL, control := startControl(t)
	s1, _, s1PubKey := startServer(t, ctx, controlURL, "s1")
	defer s1.Close()
	s2, _, _ := startServer(t, ctx, controlURL, "s2")
	defer s2.Close()

	// s1 advertises three subnet routes plus the IPv4 default route.
	advertised := []netip.Prefix{
		netip.MustParsePrefix("192.0.2.0/24"),
		netip.MustParsePrefix("192.0.3.0/24"),
		netip.MustParsePrefix("192.0.5.1/32"),
		netip.MustParsePrefix("0.0.0.0/0"),
	}
	s1.lb.EditPrefs(&ipn.MaskedPrefs{
		Prefs:              ipn.Prefs{AdvertiseRoutes: advertised},
		AdvertiseRoutesSet: true,
	})

	// Control approves everything except 192.0.3.0/24.
	approved := []netip.Prefix{
		netip.MustParsePrefix("192.0.2.0/24"),
		netip.MustParsePrefix("192.0.5.1/32"),
		netip.MustParsePrefix("0.0.0.0/0"),
	}
	control.SetSubnetRoutes(s1PubKey, approved)

	lc1, err := s1.LocalClient()
	if err != nil {
		t.Fatal(err)
	}
	lc2, err := s2.LocalClient()
	if err != nil {
		t.Fatal(err)
	}

	// Kick both backends so that the netmap (and therefore the metrics)
	// reflects the current state.
	s1.lb.DebugForceNetmapUpdate()
	s2.lb.DebugForceNetmapUpdate()

	wantApproved := float64(2)

	// The approved routes have to reach node 1 before its metrics can be
	// trusted. PrimaryRoutes also carries the exit node route, which the
	// metric does not count, hence the +1.
	routesPropagated := func() bool {
		statusCtx, statusCancel := context.WithTimeout(context.Background(), 5*time.Second)
		defer statusCancel()

		st, err := lc1.Status(statusCtx)
		if err != nil {
			t.Logf("getting status: %s", err)
			return false
		}
		primary := st.Self.PrimaryRoutes
		if primary == nil {
			return false
		}
		return primary.Len() == int(wantApproved)+1
	}
	waitForCondition(t, "primary routes available for node1", 90*time.Second, routesPropagated)

	metricsCtx, metricsCancel := context.WithTimeout(context.Background(), 5*time.Second)
	defer metricsCancel()

	metrics1, err := lc1.UserMetrics(metricsCtx)
	if err != nil {
		t.Fatal(err)
	}
	parsedMetrics1, err := parseMetrics(metrics1)
	if err != nil {
		t.Fatal(err)
	}
	t.Logf("Metrics1:\n%s\n", metrics1)

	// s1 advertises four prefixes, but the gauge is expected to read 3:
	// - 192.0.2.0/24
	// - 192.0.3.0/24
	// - 192.0.5.1/32
	// NOTE(review): 0.0.0.0/0 is presumably left out because it is an exit
	// node route rather than a subnet route - confirm against the metric.
	wantAdvertised := 3.0
	if got := parsedMetrics1["tailscaled_advertised_routes"]; got != wantAdvertised {
		t.Errorf("metrics1, tailscaled_advertised_routes: got %v, want %v", got, wantAdvertised)
	}

	// Control approved two subnet routes for s1:
	// - 192.0.2.0/24
	// - 192.0.5.1/32
	if got := parsedMetrics1["tailscaled_approved_routes"]; got != wantApproved {
		t.Errorf("metrics1, tailscaled_approved_routes: got %v, want %v", got, wantApproved)
	}

	metrics2, err := lc2.UserMetrics(ctx)
	if err != nil {
		t.Fatal(err)
	}
	parsedMetrics2, err := parseMetrics(metrics2)
	if err != nil {
		t.Fatal(err)
	}
	t.Logf("Metrics2:\n%s\n", metrics2)

	// s2 advertises nothing ...
	wantNone := 0.0
	if got := parsedMetrics2["tailscaled_advertised_routes"]; got != wantNone {
		t.Errorf("metrics2, tailscaled_advertised_routes: got %v, want %v", got, wantNone)
	}
	// ... and consequently has nothing approved.
	if got := parsedMetrics2["tailscaled_approved_routes"]; got != wantNone {
		t.Errorf("metrics2, tailscaled_approved_routes: got %v, want %v", got, wantNone)
	}
}
// waitForCondition polls f until it reports true or waitTime has elapsed,
// pausing one second after every unsuccessful attempt. If the deadline passes
// without f succeeding, the test is failed immediately with msg.
func waitForCondition(t *testing.T, msg string, waitTime time.Duration, f func() bool) {
	t.Helper()

	deadline := time.Now().Add(waitTime)
	for time.Now().Before(deadline) {
		if f() {
			return
		}
		time.Sleep(1 * time.Second)
	}

	t.Fatalf("waiting for condition: %s", msg)
}
// mustDirect ensures there is a direct connection between LocalClient 1 and 2.
//
// It polls the status of both clients every 10ms for up to 30 seconds and
// returns as soon as lc1 reports a non-empty CurAddr for lc2's node. While
// waiting, it logs the candidate addresses at most once per second via logf.
// If no direct path shows up before the deadline, the test is marked as failed
// with t.Error (it is not aborted).
func mustDirect(t *testing.T, logf logger.Logf, lc1, lc2 *local.Client) {
	t.Helper()
	lastLog := time.Now().Add(-time.Minute)

	// hasDirectPath performs a single probe. It is a separate function so
	// that the deferred cancel runs at the end of every attempt; deferring
	// inside the polling loop itself would keep every per-attempt context
	// (and its timer) alive until mustDirect returns.
	hasDirectPath := func() bool {
		ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
		defer cancel()

		status1, err := lc1.Status(ctx)
		if err != nil {
			return false
		}
		status2, err := lc2.Status(ctx)
		if err != nil {
			return false
		}

		pst := status1.Peer[status2.Self.PublicKey]
		if pst == nil {
			// lc1 does not know about lc2 yet; a missing map entry
			// must not be dereferenced. Try again on the next poll.
			return false
		}
		if pst.CurAddr != "" {
			logf("direct link %s->%s found with addr %s", status1.Self.HostName, status2.Self.HostName, pst.CurAddr)
			return true
		}
		// Rate-limit the progress log to one line per second.
		if now := time.Now(); now.Sub(lastLog) > time.Second {
			logf("no direct path %s->%s yet, addrs %v", status1.Self.HostName, status2.Self.HostName, pst.Addrs)
			lastLog = now
		}
		return false
	}

	// See https://github.com/tailscale/tailscale/issues/654
	// and https://github.com/tailscale/tailscale/issues/3247 for discussions of this deadline.
	for deadline := time.Now().Add(30 * time.Second); time.Now().Before(deadline); time.Sleep(10 * time.Millisecond) {
		if hasDirectPath() {
			return
		}
	}
	t.Error("magicsock did not find a direct path from lc1 to lc2")
}
// TestDeps runs the dependency checker for linux/amd64 and reports an error
// for every dependency whose import path contains "portlist".
func TestDeps(t *testing.T) {
	tstest.Shard(t)

	checker := deptest.DepChecker{
		GOOS:   "linux",
		GOARCH: "amd64",
		OnDep: func(dep string) {
			if !strings.Contains(dep, "portlist") {
				return
			}
			t.Errorf("unexpected dep: %q", dep)
		},
	}
	checker.Check(t)
}
// TestResolveAuthKey is a table-driven test for Server.resolveAuthKey.
//
// Each case configures a Server with some combination of AuthKey,
// ClientSecret, ClientID, IDToken and Audience, optionally installs fake
// resolvers through the HookResolveAuthKey (OAuth) and
// HookResolveAuthKeyViaWIF (workload identity federation) test hooks, and then
// checks either the returned auth key or that the returned error contains
// wantErrContains. A hook is only installed when the matching *Available flag
// is set, so cases can also exercise the "no resolver available" paths.
func TestResolveAuthKey(t *testing.T) {
	tests := []struct {
		name            string
		authKey         string
		clientSecret    string
		clientID        string
		idToken         string
		audience        string
		oauthAvailable  bool // install resolveViaOAuth as the OAuth hook
		wifAvailable    bool // install resolveViaWIF as the WIF hook
		resolveViaOAuth func(ctx context.Context, clientSecret string, tags []string) (string, error)
		resolveViaWIF   func(ctx context.Context, baseURL, clientID, idToken, audience string, tags []string) (string, error)
		wantAuthKey     string
		wantErrContains string // if non-empty, an error containing this text is expected
	}{
		{
			name:           "successful resolution via OAuth client secret",
			clientSecret:   "tskey-client-secret-123",
			oauthAvailable: true,
			resolveViaOAuth: func(ctx context.Context, clientSecret string, tags []string) (string, error) {
				if clientSecret != "tskey-client-secret-123" {
					return "", fmt.Errorf("unexpected client secret: %s", clientSecret)
				}
				return "tskey-auth-via-oauth", nil
			},
			wantAuthKey:     "tskey-auth-via-oauth",
			wantErrContains: "",
		},
		{
			name:           "failing resolution via OAuth client secret",
			clientSecret:   "tskey-client-secret-123",
			oauthAvailable: true,
			resolveViaOAuth: func(ctx context.Context, clientSecret string, tags []string) (string, error) {
				return "", fmt.Errorf("resolution failed")
			},
			wantErrContains: "resolution failed",
		},
		{
			name:         "successful resolution via federated ID token",
			clientID:     "client-id-123",
			idToken:      "id-token-456",
			wifAvailable: true,
			resolveViaWIF: func(ctx context.Context, baseURL, clientID, idToken, audience string, tags []string) (string, error) {
				if clientID != "client-id-123" {
					return "", fmt.Errorf("unexpected client ID: %s", clientID)
				}
				if idToken != "id-token-456" {
					return "", fmt.Errorf("unexpected ID token: %s", idToken)
				}
				return "tskey-auth-via-wif", nil
			},
			wantAuthKey:     "tskey-auth-via-wif",
			wantErrContains: "",
		},
		{
			name:         "successful resolution via federated audience",
			clientID:     "client-id-123",
			audience:     "api.tailscale.com",
			wifAvailable: true,
			resolveViaWIF: func(ctx context.Context, baseURL, clientID, idToken, audience string, tags []string) (string, error) {
				if clientID != "client-id-123" {
					return "", fmt.Errorf("unexpected client ID: %s", clientID)
				}
				if audience != "api.tailscale.com" {
					// Report the value that actually mismatched so a
					// failure here points at the audience, not the
					// (empty) ID token.
					return "", fmt.Errorf("unexpected audience: %s", audience)
				}
				return "tskey-auth-via-wif", nil
			},
			wantAuthKey:     "tskey-auth-via-wif",
			wantErrContains: "",
		},
		{
			name:         "failing resolution via federated ID token",
			clientID:     "client-id-123",
			idToken:      "id-token-456",
			wifAvailable: true,
			resolveViaWIF: func(ctx context.Context, baseURL, clientID, idToken, audience string, tags []string) (string, error) {
				return "", fmt.Errorf("resolution failed")
			},
			wantErrContains: "resolution failed",
		},
		{
			name:         "empty client ID with ID token",
			clientID:     "",
			idToken:      "id-token-456",
			wifAvailable: true,
			resolveViaWIF: func(ctx context.Context, baseURL, clientID, idToken, audience string, tags []string) (string, error) {
				return "", fmt.Errorf("should not be called")
			},
			wantErrContains: "empty",
		},
		{
			name:         "empty client ID with audience",
			clientID:     "",
			audience:     "api.tailscale.com",
			wifAvailable: true,
			resolveViaWIF: func(ctx context.Context, baseURL, clientID, idToken, audience string, tags []string) (string, error) {
				return "", fmt.Errorf("should not be called")
			},
			wantErrContains: "empty",
		},
		{
			name:         "empty ID token",
			clientID:     "client-id-123",
			idToken:      "",
			wifAvailable: true,
			resolveViaWIF: func(ctx context.Context, baseURL, clientID, idToken, audience string, tags []string) (string, error) {
				return "", fmt.Errorf("should not be called")
			},
			wantErrContains: "empty",
		},
		{
			name:         "audience with ID token",
			clientID:     "client-id-123",
			idToken:      "id-token-456",
			audience:     "api.tailscale.com",
			wifAvailable: true,
			resolveViaWIF: func(ctx context.Context, baseURL, clientID, idToken, audience string, tags []string) (string, error) {
				return "", fmt.Errorf("should not be called")
			},
			wantErrContains: "only one of ID token and audience",
		},
		{
			name:           "workload identity resolution skipped if resolution via OAuth token succeeds",
			clientSecret:   "tskey-client-secret-123",
			oauthAvailable: true,
			resolveViaOAuth: func(ctx context.Context, clientSecret string, tags []string) (string, error) {
				if clientSecret != "tskey-client-secret-123" {
					return "", fmt.Errorf("unexpected client secret: %s", clientSecret)
				}
				return "tskey-auth-via-oauth", nil
			},
			wifAvailable: true,
			resolveViaWIF: func(ctx context.Context, baseURL, clientID, idToken, audience string, tags []string) (string, error) {
				return "", fmt.Errorf("should not be called")
			},
			wantAuthKey:     "tskey-auth-via-oauth",
			wantErrContains: "",
		},
		{
			name:           "workload identity resolution skipped if resolution via OAuth token fails",
			clientID:       "tskey-client-id-123",
			idToken:        "",
			oauthAvailable: true,
			resolveViaOAuth: func(ctx context.Context, clientSecret string, tags []string) (string, error) {
				return "", fmt.Errorf("resolution failed")
			},
			wifAvailable: true,
			resolveViaWIF: func(ctx context.Context, baseURL, clientID, idToken, audience string, tags []string) (string, error) {
				return "", fmt.Errorf("should not be called")
			},
			wantErrContains: "failed",
		},
		{
			name:            "authkey set and no resolution available",
			authKey:         "tskey-auth-123",
			oauthAvailable:  false,
			wifAvailable:    false,
			wantAuthKey:     "tskey-auth-123",
			wantErrContains: "",
		},
		{
			name:            "no authkey set and no resolution available",
			oauthAvailable:  false,
			wifAvailable:    false,
			wantAuthKey:     "",
			wantErrContains: "",
		},
		{
			name:           "authkey is client secret and resolution via OAuth client secret succeeds",
			authKey:        "tskey-client-secret-123",
			oauthAvailable: true,
			resolveViaOAuth: func(ctx context.Context, clientSecret string, tags []string) (string, error) {
				if clientSecret != "tskey-client-secret-123" {
					return "", fmt.Errorf("unexpected client secret: %s", clientSecret)
				}
				return "tskey-auth-via-oauth", nil
			},
			wantAuthKey:     "tskey-auth-via-oauth",
			wantErrContains: "",
		},
		{
			name:           "authkey is client secret but resolution via OAuth client secret fails",
			authKey:        "tskey-client-secret-123",
			oauthAvailable: true,
			resolveViaOAuth: func(ctx context.Context, clientSecret string, tags []string) (string, error) {
				return "", fmt.Errorf("resolution failed")
			},
			wantErrContains: "resolution failed",
		},
	}

	for _, tt := range tests {
		t.Run(tt.name, func(t *testing.T) {
			// Install the fake resolvers for this case only; SetForTest
			// returns the function that restores the previous hook.
			if tt.oauthAvailable {
				t.Cleanup(tailscale.HookResolveAuthKey.SetForTest(tt.resolveViaOAuth))
			}
			if tt.wifAvailable {
				t.Cleanup(tailscale.HookResolveAuthKeyViaWIF.SetForTest(tt.resolveViaWIF))
			}

			s := &Server{
				AuthKey:      tt.authKey,
				ClientSecret: tt.clientSecret,
				ClientID:     tt.clientID,
				IDToken:      tt.idToken,
				Audience:     tt.audience,
				ControlURL:   "https://control.example.com",
			}
			// The server is never started, so set shutdownCtx by hand
			// before calling resolveAuthKey.
			s.shutdownCtx = context.Background()

			gotAuthKey, err := s.resolveAuthKey()

			if tt.wantErrContains != "" {
				if err == nil {
					t.Errorf("expected error but got none")
					return
				}
				if !strings.Contains(err.Error(), tt.wantErrContains) {
					t.Errorf("expected error containing %q but got error: %v", tt.wantErrContains, err)
				}
				return
			}

			if err != nil {
				t.Errorf("resolveAuthKey expected no error but got error: %v", err)
				return
			}

			if gotAuthKey != tt.wantAuthKey {
				t.Errorf("resolveAuthKey() = %q, want %q", gotAuthKey, tt.wantAuthKey)
			}
		})
	}
}