diff --git a/.github/workflows/installer.yml b/.github/workflows/installer.yml index 6fc8913c4..9b49c4c07 100644 --- a/.github/workflows/installer.yml +++ b/.github/workflows/installer.yml @@ -37,8 +37,6 @@ jobs: - "elementary/docker:stable" - "elementary/docker:unstable" - "parrotsec/core:latest" - - "kalilinux/kali-rolling" - - "kalilinux/kali-dev" - "oraclelinux:9" - "oraclelinux:8" - "fedora:latest" @@ -61,6 +59,9 @@ jobs: - { image: "debian:stable-slim", deps: "curl" } - { image: "ubuntu:24.04", deps: "curl" } - { image: "fedora:latest", deps: "curl" } + # Kali doesn't have ca-certificates installed by default anymore + - { image: "kalilinux/kali-dev", "deps": "curl ca-certificates"} + - { image: "kalilinux/kali-rolling", "deps": "curl ca-certificates"} # Test TAILSCALE_VERSION pinning on a subset of distros. # Skip Alpine as community repos don't reliably keep old versions. - { image: "debian:stable-slim", deps: "curl", version: "1.80.0" } diff --git a/.github/workflows/natlab-integrationtest.yml b/.github/workflows/natlab-basic.yml similarity index 52% rename from .github/workflows/natlab-integrationtest.yml rename to .github/workflows/natlab-basic.yml index 162153cb2..1a19acfb8 100644 --- a/.github/workflows/natlab-integrationtest.yml +++ b/.github/workflows/natlab-basic.yml @@ -1,6 +1,7 @@ -# Run some natlab integration tests. +# Run a single natlab smoke test on every PR. The full natlab suite +# is opt-in and lives in .github/workflows/natlab-test.yml. # See https://github.com/tailscale/tailscale/issues/13038 -name: "natlab-integrationtest" +name: "natlab-basic" concurrency: group: ${{ github.workflow }}-${{ github.head_ref || github.run_id }} @@ -17,17 +18,28 @@ on: branches: - "main" jobs: - natlab-integrationtest: + EasyEasy: runs-on: ubuntu-latest steps: - name: Check out code uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # v6.0.2 + - name: Enable KVM + run: | + echo 'KERNEL=="kvm", GROUP="kvm", MODE="0666", OPTIONS+="static_node=kvm"' | sudo tee /etc/udev/rules.d/99-kvm4all.rules + sudo udevadm control --reload-rules + sudo udevadm trigger --name-match=kvm - name: Install qemu run: | sudo rm -f /var/lib/man-db/auto-update sudo apt-get -y update sudo apt-get -y remove man-db sudo apt-get install -y qemu-system-x86 qemu-utils + - name: Build VM image + # The test will build this if missing, but we do it explicitly + # to avoid cutting into the go test -timeout budget, and to + # fail earlier with a clearer error if the image build breaks. + run: | + make -C gokrazy natlab - name: Run natlab integration tests run: | - ./tool/go test -v -run=^TestEasyEasy$ -timeout=3m -count=1 ./tstest/integration/nat --run-vm-tests + ./tool/go test -v -run=^TestEasyEasy$ -timeout=3m -count=1 ./tstest/natlab/vmtest --run-vm-tests diff --git a/.github/workflows/natlab-test.yml b/.github/workflows/natlab-test.yml new file mode 100644 index 000000000..4f53c4ce4 --- /dev/null +++ b/.github/workflows/natlab-test.yml @@ -0,0 +1,182 @@ +# Run the full natlab/vmtest opt-in test suite. These tests boot QEMU VMs +# (gokrazy, Ubuntu, FreeBSD) and exercise vnet-driven networking scenarios. +# They are gated behind --run-vm-tests because they need KVM and are slow. +# +# This workflow runs: +# - on demand (workflow_dispatch) +# - on PRs that carry the "run-natlab-tests" label +# - on main, every 12 hours, via cron +# +# Layout: +# - "prepare" builds the gokrazy VM image, downloads the cloud images +# (Ubuntu, FreeBSD), and discovers every Test* function in the two +# opt-in packages. +# - "test" is a per-TestFoo matrix that depends on prepare. Each matrix +# job restores the shared caches and runs a single test. Adding a new +# TestFoo automatically gets its own job — no workflow edits needed. +# +# A separate workflow (.github/workflows/natlab-basic.yml) runs a single +# canary natlab test on every PR; this one runs the full suite. +name: "natlab-test" + +concurrency: + group: ${{ github.workflow }}-${{ github.head_ref || github.run_id }} + cancel-in-progress: true + +on: + workflow_dispatch: + pull_request: + types: [labeled, synchronize, reopened] + schedule: + # Every 12 hours, off-the-hour to avoid GitHub's :00 cron-stampede window. + - cron: "23 3,15 * * *" + +jobs: + # prepare warms the per-workflow-run caches (gokrazy image, cloud VM + # images) and emits the dynamic matrix of test names. By doing the work + # once here, the matrix test jobs never race to rebuild or re-download + # the same artifacts on a cold cache. + prepare: + if: | + github.event_name == 'workflow_dispatch' || + github.event_name == 'schedule' || + (github.event_name == 'pull_request' && contains(github.event.pull_request.labels.*.name, 'run-natlab-tests')) + runs-on: ubuntu-latest + timeout-minutes: 30 + outputs: + matrix: ${{ steps.list.outputs.matrix }} + steps: + - name: Check out code + uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # v6.0.2 + + # The cloud VM image cache is keyed only on images.go (image URLs and + # SHAs), so it survives across workflow runs and is invalidated only + # when a new image source is added. + - name: Cache cloud VM images + uses: actions/cache@668228422ae6a00e4ad889ee87cd7109ec5666a7 # v5.0.4 + with: + path: ~/.cache/tailscale/vmtest/images + key: natlab-vmimages-${{ hashFiles('tstest/natlab/vmtest/images.go') }} + + # The gokrazy VM image is keyed by github.sha. That means we rebuild + # it once per commit but matrix test jobs in the same run all share + # the result. Per-PR re-runs of the same sha (e.g. a rerun-failed) + # also get the cache. + - name: Cache gokrazy VM image + id: gokrazy-cache + uses: actions/cache@668228422ae6a00e4ad889ee87cd7109ec5666a7 # v5.0.4 + with: + path: gokrazy/natlabapp.qcow2 + key: natlab-gokrazy-${{ github.sha }} + + # qemu-utils provides qemu-img, which the gokrazy Makefile uses to + # convert natlabapp.img to qcow2. Only install if we need it (cache + # miss); the test matrix jobs install qemu separately for the runtime. + - name: Install qemu-utils + if: steps.gokrazy-cache.outputs.cache-hit != 'true' + run: | + sudo rm -f /var/lib/man-db/auto-update + sudo apt-get -y update + sudo apt-get -y remove man-db + sudo apt-get install -y qemu-utils + + - name: Download cloud VM images + # natlabprep is idempotent: it checks the cache before downloading. + run: | + ./tool/go run ./tstest/natlab/vmtest/cmd/natlabprep + + - name: Build gokrazy VM image + if: steps.gokrazy-cache.outputs.cache-hit != 'true' + run: | + make -C gokrazy natlab + + - name: Discover tests + id: list + # Grep the test files directly rather than invoking `go test -list` + # so we don't pay the cost of compiling the test binaries here. The + # only test functions in these packages use the canonical + # `func TestFoo(t *testing.T)` signature. + # + # exclude is the set of tests that need special invocation + # (extra flags, a specific environment) and don't fit the + # single-test-per-matrix-job model. They stay runnable locally. + run: | + set -euo pipefail + exclude='^(TestGrid)$' + tmp=$(mktemp) + for pkg_dir in tstest/natlab/vmtest tstest/integration/nat; do + pkg="./${pkg_dir}/" + for f in "${pkg_dir}"/*_test.go; do + [ -e "$f" ] || continue + grep -hE '^func Test[A-Z][A-Za-z0-9_]*\(t \*testing\.T\)' "$f" \ + | sed -E 's/^func (Test[A-Za-z0-9_]+).*/\1/' \ + | grep -vE "$exclude" \ + | while read -r t; do + jq -nc --arg pkg "$pkg" --arg test "$t" \ + '{pkg: $pkg, test: $test}' >> "$tmp" + done + done + done + matrix=$(jq -s -c . "$tmp") + echo "matrix=${matrix}" >> "$GITHUB_OUTPUT" + echo "Discovered tests:" + jq . "$tmp" + + test: + needs: prepare + runs-on: ubuntu-latest + timeout-minutes: 20 + name: "${{ matrix.test }}" + strategy: + fail-fast: false + matrix: + include: ${{ fromJson(needs.prepare.outputs.matrix) }} + steps: + - name: Check out code + uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # v6.0.2 + + - name: Enable KVM + run: | + echo 'KERNEL=="kvm", GROUP="kvm", MODE="0666", OPTIONS+="static_node=kvm"' | sudo tee /etc/udev/rules.d/99-kvm4all.rules + sudo udevadm control --reload-rules + sudo udevadm trigger --name-match=kvm + + - name: Install qemu + run: | + sudo rm -f /var/lib/man-db/auto-update + sudo apt-get -y update + sudo apt-get -y remove man-db + sudo apt-get install -y qemu-system-x86 qemu-utils + + # restore-only: prepare is the single writer of these caches, so + # matrix jobs don't write back. fail-on-cache-miss would be too + # strict for the gokrazy cache (e.g. a non-fatal cache eviction + # between prepare and us); we just rebuild on miss instead. + - name: Restore cloud VM images + uses: actions/cache/restore@668228422ae6a00e4ad889ee87cd7109ec5666a7 # v5.0.4 + with: + path: ~/.cache/tailscale/vmtest/images + key: natlab-vmimages-${{ hashFiles('tstest/natlab/vmtest/images.go') }} + + - name: Restore gokrazy VM image + uses: actions/cache/restore@668228422ae6a00e4ad889ee87cd7109ec5666a7 # v5.0.4 + with: + path: gokrazy/natlabapp.qcow2 + key: natlab-gokrazy-${{ github.sha }} + + # The gokrazy-based tests boot the kernel directly from + # vmlinuz that ships in the tailscale/gokrazy-kernel module. + # Tests look it up under GOMODCACHE via findKernelPath, so the + # module has to be present even though no Go source imports it + # in the test package itself. + - name: Download gokrazy-kernel module + run: | + ./tool/go mod download github.com/tailscale/gokrazy-kernel + + - name: Run ${{ matrix.test }} + # Per-test timeout is well above the few-minute typical runtime + # but small enough that a stuck test fails fast instead of holding + # the runner for the job's 20-minute budget. + run: | + ./tool/go test -v -timeout=15m -count=1 ${{ matrix.pkg }} \ + -run='^${{ matrix.test }}$' --run-vm-tests diff --git a/.github/workflows/ssh-integrationtest.yml b/.github/workflows/ssh-integrationtest.yml index afe2dd2f7..84432cd72 100644 --- a/.github/workflows/ssh-integrationtest.yml +++ b/.github/workflows/ssh-integrationtest.yml @@ -1,5 +1,5 @@ -# Run the ssh integration tests with `make sshintegrationtest`. -# These tests can also be running locally. +# Run the ssh integration tests in various Docker containers. +# These tests can also be run locally via `make sshintegrationtest`. name: "ssh-integrationtest" concurrency: @@ -15,9 +15,25 @@ on: jobs: ssh-integrationtest: runs-on: ubuntu-latest + strategy: + fail-fast: false + matrix: + include: + - base: "ubuntu:focal" + tag: "ssh-ubuntu-focal" + - base: "ubuntu:jammy" + tag: "ssh-ubuntu-jammy" + - base: "ubuntu:noble" + tag: "ssh-ubuntu-noble" + - base: "alpine:latest" + tag: "ssh-alpine-latest" steps: - name: Check out code uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # v6.0.2 - - name: Run SSH integration tests + - name: Build test binaries run: | - make sshintegrationtest \ No newline at end of file + GOOS=linux GOARCH=amd64 CGO_ENABLED=0 ./tool/go test -tags integrationtest -c ./ssh/tailssh -o ssh/tailssh/testcontainers/tailssh.test + GOOS=linux GOARCH=amd64 CGO_ENABLED=0 ./tool/go build -o ssh/tailssh/testcontainers/tailscaled ./cmd/tailscaled + - name: Run SSH integration tests (${{ matrix.base }}) + run: | + docker build --build-arg="BASE=${{ matrix.base }}" -t "${{ matrix.tag }}" ssh/tailssh/testcontainers diff --git a/.github/workflows/test.yml b/.github/workflows/test.yml index 38ebd1291..ded7873aa 100644 --- a/.github/workflows/test.yml +++ b/.github/workflows/test.yml @@ -361,7 +361,7 @@ jobs: run: chown -R $(id -u):$(id -g) $PWD - name: privileged tests working-directory: src - run: ./tool/go test ./util/linuxfw ./derp/xdp + run: ./tool/go test $(./tool/go run ./tool/listpkgs --has-root-tests) vm: needs: gomod-cache @@ -787,6 +787,14 @@ jobs: echo echo git diff --name-only --exit-code || (echo "The files above need updating. Please run 'go generate'."; exit 1) + - name: check that 'genreadme' is clean + working-directory: src + run: | + ./tool/go run ./misc/genreadme + git add -N . # ensure untracked files are noticed + echo + echo + git diff --name-only --exit-code || (echo "The files above need updating. Please run './tool/go run ./misc/genreadme'."; exit 1) make_tidy: runs-on: ubuntu-24.04 diff --git a/.github/workflows/update-flake.yml b/.github/workflows/update-flake.yml index 1304fb222..ce77cf651 100644 --- a/.github/workflows/update-flake.yml +++ b/.github/workflows/update-flake.yml @@ -23,8 +23,8 @@ jobs: - name: Check out code uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # v6.0.2 - - name: Run update-flakes - run: ./update-flake.sh + - name: Run updateflakes + run: ./tool/go run ./tool/updateflakes - name: Get access token uses: actions/create-github-app-token@f8d387b68d61c58ab83c6c016672934102569859 # v3.0.0 @@ -41,8 +41,8 @@ jobs: author: Flakes Updater committer: Flakes Updater branch: flakes - commit-message: "go.mod.sri: update SRI hash for go.mod changes" - title: "go.mod.sri: update SRI hash for go.mod changes" + commit-message: "flakehashes.json: update SRI hash for go.mod changes" + title: "flakehashes.json: update SRI hash for go.mod changes" body: Triggered by ${{ github.repository }}@${{ github.sha }} signoff: true delete-branch: true diff --git a/.gitignore b/.gitignore index 4bfabc80f..e1f6be02e 100644 --- a/.gitignore +++ b/.gitignore @@ -1,12 +1,15 @@ # Binaries for programs and plugins *~ *.tmp -*.exe *.dll *.so *.dylib *.spk +*.exe +# tool/go.exe is built specially and committed. +!/tool/go.exe + cmd/tailscale/tailscale cmd/tailscaled/tailscaled ssh/tailssh/testcontainers/tailscaled diff --git a/Makefile b/Makefile index b78ef0469..0efd57fb4 100644 --- a/Makefile +++ b/Makefile @@ -10,7 +10,7 @@ vet: ## Run go vet tidy: ## Run go mod tidy and update nix flake hashes ./tool/go mod tidy - ./update-flake.sh + ./tool/go run ./tool/updateflakes lint: ## Run golangci-lint ./tool/go run github.com/golangci/golangci-lint/cmd/golangci-lint run @@ -137,10 +137,12 @@ publishdevproxy: check-image-repo ## Build and publish k8s-proxy image to locati sshintegrationtest: ## Run the SSH integration tests in various Docker containers @GOOS=linux GOARCH=amd64 CGO_ENABLED=0 ./tool/go test -tags integrationtest -c ./ssh/tailssh -o ssh/tailssh/testcontainers/tailssh.test && \ GOOS=linux GOARCH=amd64 CGO_ENABLED=0 ./tool/go build -o ssh/tailssh/testcontainers/tailscaled ./cmd/tailscaled && \ - echo "Testing on ubuntu:focal" && docker build --build-arg="BASE=ubuntu:focal" -t ssh-ubuntu-focal ssh/tailssh/testcontainers && \ - echo "Testing on ubuntu:jammy" && docker build --build-arg="BASE=ubuntu:jammy" -t ssh-ubuntu-jammy ssh/tailssh/testcontainers && \ - echo "Testing on ubuntu:noble" && docker build --build-arg="BASE=ubuntu:noble" -t ssh-ubuntu-noble ssh/tailssh/testcontainers && \ - echo "Testing on alpine:latest" && docker build --build-arg="BASE=alpine:latest" -t ssh-alpine-latest ssh/tailssh/testcontainers + echo "Testing on ubuntu:focal, ubuntu:jammy, ubuntu:noble, alpine:latest (in parallel)" && \ + docker build --build-arg="BASE=ubuntu:focal" -t ssh-ubuntu-focal ssh/tailssh/testcontainers & \ + docker build --build-arg="BASE=ubuntu:jammy" -t ssh-ubuntu-jammy ssh/tailssh/testcontainers & \ + docker build --build-arg="BASE=ubuntu:noble" -t ssh-ubuntu-noble ssh/tailssh/testcontainers & \ + docker build --build-arg="BASE=alpine:latest" -t ssh-alpine-latest ssh/tailssh/testcontainers & \ + wait .PHONY: generate generate: ## Generate code diff --git a/VERSION.txt b/VERSION.txt index acbb747ac..9eb2e1ff9 100644 --- a/VERSION.txt +++ b/VERSION.txt @@ -1 +1 @@ -1.97.0 +1.99.0 diff --git a/appc/appconnector_test.go b/appc/appconnector_test.go index d14ef68fc..c58aa8041 100644 --- a/appc/appconnector_test.go +++ b/appc/appconnector_test.go @@ -736,6 +736,7 @@ func TestRateLogger(t *testing.T) { } func TestRouteStoreMetrics(t *testing.T) { + clientmetric.ResetForTest(t) metricStoreRoutes(1, 1) metricStoreRoutes(1, 1) // the 1 buckets value should be 2 metricStoreRoutes(5, 5) // the 5 buckets value should be 1 diff --git a/appc/conn25.go b/appc/conn25.go index fd1748fa6..62cb70017 100644 --- a/appc/conn25.go +++ b/appc/conn25.go @@ -6,6 +6,7 @@ package appc import ( "cmp" "slices" + "strings" "tailscale.com/ipn/ipnext" "tailscale.com/tailcfg" @@ -16,7 +17,7 @@ import ( const AppConnectorsExperimentalAttrName = "tailscale.com/app-connectors-experimental" -func isEligibleConnector(peer tailcfg.NodeView) bool { +func isPeerEligibleConnector(peer tailcfg.NodeView) bool { if !peer.Valid() || !peer.Hostinfo().Valid() { return false } @@ -39,7 +40,7 @@ func sortByPreference(ns []tailcfg.NodeView) { func PickConnector(nb ipnext.NodeBackend, app appctype.Conn25Attr) []tailcfg.NodeView { appTagsSet := set.SetOf(app.Connectors) matches := nb.AppendMatchingPeers(nil, func(n tailcfg.NodeView) bool { - if !isEligibleConnector(n) { + if !isPeerEligibleConnector(n) { return false } for _, t := range n.Tags().All() { @@ -55,7 +56,7 @@ func PickConnector(nb ipnext.NodeBackend, app appctype.Conn25Attr) []tailcfg.Nod // PickSplitDNSPeers looks at the netmap peers capabilities and finds which peers // want to be connectors for which domains. -func PickSplitDNSPeers(hasCap func(c tailcfg.NodeCapability) bool, self tailcfg.NodeView, peers map[tailcfg.NodeID]tailcfg.NodeView) map[string][]tailcfg.NodeView { +func PickSplitDNSPeers(hasCap func(c tailcfg.NodeCapability) bool, self tailcfg.NodeView, peers map[tailcfg.NodeID]tailcfg.NodeView, isSelfEligibleConnector bool) map[string][]tailcfg.NodeView { var m map[string][]tailcfg.NodeView if !hasCap(AppConnectorsExperimentalAttrName) { return m @@ -64,22 +65,43 @@ func PickSplitDNSPeers(hasCap func(c tailcfg.NodeCapability) bool, self tailcfg. if err != nil { return m } - tagToDomain := make(map[string][]string) + + // We strip the leading *. from any domains because the OS treats all domains + // that we pass to it as wildcard domains, and the OS would treat the * character + // as a literal domain component instead of treating it as a wildcard. + // We also use a Set to deduplicate the domains we pass to the OS in case removing + // the *. prefix resulted in duplicate entries. + tagToDomain := make(map[string]set.Set[string]) + selfTags := set.SetOf(self.Tags().AsSlice()) + selfRoutedDomains := set.Set[string]{} for _, app := range apps { + domains := make(set.Set[string]) + for _, domain := range app.Domains { + domains.Add(strings.ToLower(strings.TrimPrefix(domain, "*."))) + } for _, tag := range app.Connectors { - tagToDomain[tag] = append(tagToDomain[tag], app.Domains...) + if tagToDomain[tag] == nil { + tagToDomain[tag] = set.Set[string]{} + } + tagToDomain[tag].AddSet(domains) + if isSelfEligibleConnector && selfTags.Contains(tag) { + selfRoutedDomains.AddSet(domains) + } } } // NodeIDs are Comparable, and we have a map of NodeID to NodeView anyway, so // use a Set of NodeIDs to deduplicate, and populate into a []NodeView later. var work map[string]set.Set[tailcfg.NodeID] for _, peer := range peers { - if !isEligibleConnector(peer) { + if !isPeerEligibleConnector(peer) { continue } for _, t := range peer.Tags().All() { domains := tagToDomain[t] - for _, domain := range domains { + for domain := range domains { + if selfRoutedDomains.Contains(domain) { + continue + } if work[domain] == nil { mak.Set(&work, domain, set.Set[tailcfg.NodeID]{}) } diff --git a/appc/conn25_test.go b/appc/conn25_test.go index fc14caf36..dd98312ca 100644 --- a/appc/conn25_test.go +++ b/appc/conn25_test.go @@ -32,6 +32,8 @@ func TestPickSplitDNSPeers(t *testing.T) { appTwoBytes := getBytesForAttr("app2", []string{"a.example.com"}, []string{"tag:two"}) appThreeBytes := getBytesForAttr("app3", []string{"woo.b.example.com", "hoo.b.example.com"}, []string{"tag:three1", "tag:three2"}) appFourBytes := getBytesForAttr("app4", []string{"woo.b.example.com", "c.example.com"}, []string{"tag:four1", "tag:four2"}) + appFiveBytes := getBytesForAttr("app5", []string{"*.example.com", "example.com"}, []string{"tag:one"}) + appSixBytes := getBytesForAttr("app6", []string{"*.Example.com", "EXAMPLE.com", "EXAMPLE.COM"}, []string{"tag:one"}) makeNodeView := func(id tailcfg.NodeID, name string, tags []string) tailcfg.NodeView { return (&tailcfg.Node{ @@ -47,10 +49,12 @@ func TestPickSplitDNSPeers(t *testing.T) { nvp4 := makeNodeView(4, "p4", []string{"tag:two", "tag:three2", "tag:four2"}) for _, tt := range []struct { - name string - want map[string][]tailcfg.NodeView - peers []tailcfg.NodeView - config []tailcfg.RawMessage + name string + peers []tailcfg.NodeView + config []tailcfg.RawMessage + isEligibleConnector bool + selfTags []string + want map[string][]tailcfg.NodeView }{ { name: "empty", @@ -111,6 +115,128 @@ func TestPickSplitDNSPeers(t *testing.T) { "c.example.com": {nvp2, nvp4}, }, }, + { + name: "self-connector-exclude-self-domains", + config: []tailcfg.RawMessage{ + tailcfg.RawMessage(appOneBytes), + tailcfg.RawMessage(appTwoBytes), + tailcfg.RawMessage(appThreeBytes), + tailcfg.RawMessage(appFourBytes), + }, + peers: []tailcfg.NodeView{ + nvp1, + nvp2, + nvp3, + nvp4, + }, + isEligibleConnector: true, + selfTags: []string{"tag:three1"}, + want: map[string][]tailcfg.NodeView{ + // woo.b.example.com and hoo.b.example.com are covered + // by tag:three1, and so is this self-node. + // So those domains should not be routed to peers. + // woo.b.example.com is also covered by another tag, + // but still not included since this connector can route to it. + "example.com": {nvp1}, + "a.example.com": {nvp3, nvp4}, + "c.example.com": {nvp2, nvp4}, + }, + }, + { + name: "self-eligible-connector-no-matching-tag-include-all-domains", + config: []tailcfg.RawMessage{ + tailcfg.RawMessage(appOneBytes), + tailcfg.RawMessage(appTwoBytes), + tailcfg.RawMessage(appThreeBytes), + tailcfg.RawMessage(appFourBytes), + }, + peers: []tailcfg.NodeView{ + nvp1, + nvp2, + nvp3, + nvp4, + }, + isEligibleConnector: true, + selfTags: []string{"tag:unrelated"}, + want: map[string][]tailcfg.NodeView{ + // Self has prefs set but no tags matching any app, + // so no domains are self-routed and all appear. + "example.com": {nvp1}, + "a.example.com": {nvp3, nvp4}, + "woo.b.example.com": {nvp2, nvp3, nvp4}, + "hoo.b.example.com": {nvp3, nvp4}, + "c.example.com": {nvp2, nvp4}, + }, + }, + { + name: "self-not-eligible-connector-but-tagged-include-all-domains", + config: []tailcfg.RawMessage{ + tailcfg.RawMessage(appOneBytes), + tailcfg.RawMessage(appTwoBytes), + tailcfg.RawMessage(appThreeBytes), + tailcfg.RawMessage(appFourBytes), + }, + peers: []tailcfg.NodeView{ + nvp1, + nvp2, + nvp3, + nvp4, + }, + selfTags: []string{"tag:three1"}, + want: map[string][]tailcfg.NodeView{ + // Even though this self node has a tag for an app + // the prefs don't advertise as connector, so + // should still route through other connectors. + "example.com": {nvp1}, + "a.example.com": {nvp3, nvp4}, + "woo.b.example.com": {nvp2, nvp3, nvp4}, + "hoo.b.example.com": {nvp3, nvp4}, + "c.example.com": {nvp2, nvp4}, + }, + }, + { + name: "wildcards-are-stripped-and-deduped", + config: []tailcfg.RawMessage{ + tailcfg.RawMessage(appOneBytes), + tailcfg.RawMessage(appFiveBytes), + }, + peers: []tailcfg.NodeView{ + nvp1, + }, + want: map[string][]tailcfg.NodeView{ + // All the domains should be normalized to example.com + "example.com": {nvp1}, + }, + }, + { + name: "domains-are-normalized-and-deduped", + config: []tailcfg.RawMessage{ + tailcfg.RawMessage(appSixBytes), + }, + peers: []tailcfg.NodeView{ + nvp1, + }, + want: map[string][]tailcfg.NodeView{ + // All the domains should be normalized to example.com + "example.com": {nvp1}, + }, + }, + { + name: "sub-domains-and-top-domains-do-not-collide", + config: []tailcfg.RawMessage{ + tailcfg.RawMessage(appTwoBytes), + tailcfg.RawMessage(appFiveBytes), + }, + peers: []tailcfg.NodeView{ + nvp1, + nvp3, + }, + want: map[string][]tailcfg.NodeView{ + // The sub.example.com should remain distinct from example.com + "example.com": {nvp1}, + "a.example.com": {nvp3}, + }, + }, } { t.Run(tt.name, func(t *testing.T) { selfNode := &tailcfg.Node{} @@ -119,6 +245,7 @@ func TestPickSplitDNSPeers(t *testing.T) { tailcfg.NodeCapability(AppConnectorsExperimentalAttrName): tt.config, } } + selfNode.Tags = append(selfNode.Tags, tt.selfTags...) selfView := selfNode.View() peers := map[tailcfg.NodeID]tailcfg.NodeView{} for _, p := range tt.peers { @@ -126,7 +253,8 @@ func TestPickSplitDNSPeers(t *testing.T) { } got := PickSplitDNSPeers(func(_ tailcfg.NodeCapability) bool { return true - }, selfView, peers) + }, selfView, peers, tt.isEligibleConnector) + if !reflect.DeepEqual(got, tt.want) { t.Fatalf("got %v, want %v", got, tt.want) } diff --git a/cache_key_test.go b/cache_key_test.go new file mode 100644 index 000000000..43de02e13 --- /dev/null +++ b/cache_key_test.go @@ -0,0 +1,57 @@ +// Copyright (c) Tailscale Inc & contributors +// SPDX-License-Identifier: BSD-3-Clause + +package tailscaleroot + +import ( + "os" + "os/exec" + "strings" + "testing" + + "tailscale.com/util/cibuild" +) + +// TestTsgoRevInCacheKey verifies that the Tailscale Go toolchain's git +// revision (from go.toolchain.rev) is blended into Go build cache keys. +// Without this, bumping the toolchain to a new commit that doesn't change +// the Go version number would silently reuse stale cached build artifacts. +// +// See https://github.com/tailscale/tailscale/issues/36589. +func TestTsgoRevInCacheKey(t *testing.T) { + goRoot := goEnv(t, "GOROOT") + isTsgo := strings.Contains(goRoot, "/.cache/tsgo/") + if !cibuild.OnTailscaleCI() && !isTsgo { + t.Skip("skipping; not in Tailscale CI and not using the Tailscale Go toolchain") + } + + rev := strings.TrimSpace(GoToolchainRev) + if rev == "" { + t.Fatal("go.toolchain.rev is empty") + } + + // Build the small stdlib "errors" package with GODEBUG=gocachehash=1, + // which causes cmd/go to log its cache key computations to stderr. + cmd := exec.Command("go", "build", "errors") + cmd.Env = append(os.Environ(), "GODEBUG=gocachehash=1") + out, err := cmd.CombinedOutput() + if err != nil { + t.Fatalf("go build errors failed: %v\n%s", err, out) + } + + // The cache key output should contain the toolchain rev alongside the + // Go version, e.g.: + // HASH[moduleIndex]: "go1.26.2 dfe2a5fd8ee2e68b08ce5ff259269f50ecadf2f4" + if !strings.Contains(string(out), rev) { + t.Errorf("go.toolchain.rev %q not found in GODEBUG=gocachehash=1 output:\n%s", rev, out) + } +} + +func goEnv(t *testing.T, key string) string { + t.Helper() + out, err := exec.Command("go", "env", key).Output() + if err != nil { + t.Fatalf("go env %s: %v", key, err) + } + return strings.TrimSpace(string(out)) +} diff --git a/client/local/local.go b/client/local/local.go index e72589306..1a2d7342b 100644 --- a/client/local/local.go +++ b/client/local/local.go @@ -327,6 +327,35 @@ func (lc *Client) WhoIs(ctx context.Context, remoteAddr string) (*apitype.WhoIsR return decodeJSON[*apitype.WhoIsResponse](body) } +// WhoIsForService is like [Client.WhoIs] but scopes the returned CapMap to +// capabilities that apply to the named VIP service. This enables per-service +// capability resolution on hosts that advertise multiple VIP services. +func (lc *Client) WhoIsForService(ctx context.Context, remoteAddr string, svcName tailcfg.ServiceName) (*apitype.WhoIsResponse, error) { + body, err := lc.get200(ctx, "/localapi/v0/whois?addr="+url.QueryEscape(remoteAddr)+"&svc_name="+url.QueryEscape(string(svcName))) + if err != nil { + if hs, ok := err.(httpStatusError); ok && hs.HTTPStatus == http.StatusNotFound { + return nil, ErrPeerNotFound + } + return nil, err + } + return decodeJSON[*apitype.WhoIsResponse](body) +} + +// WhoIsForIP is like [Client.WhoIs] but scopes the returned CapMap to +// capabilities that apply to the given destination IP. The IP may be a +// VIP service address, the node's own tailnet address, or any other +// routable IP the node handles. +func (lc *Client) WhoIsForIP(ctx context.Context, remoteAddr string, dst netip.Addr) (*apitype.WhoIsResponse, error) { + body, err := lc.get200(ctx, "/localapi/v0/whois?addr="+url.QueryEscape(remoteAddr)+"&dst_ip="+url.QueryEscape(dst.String())) + if err != nil { + if hs, ok := err.(httpStatusError); ok && hs.HTTPStatus == http.StatusNotFound { + return nil, ErrPeerNotFound + } + return nil, err + } + return decodeJSON[*apitype.WhoIsResponse](body) +} + // ErrPeerNotFound is returned by [Client.WhoIs], [Client.WhoIsNodeKey] and // [Client.WhoIsProto] when a peer is not found. var ErrPeerNotFound = errors.New("peer not found") @@ -607,6 +636,24 @@ func (lc *Client) DebugResultJSON(ctx context.Context, action string) (any, erro return x, nil } +// GetDebugResultJSON invokes a debug action and decodes the JSON response +// into a value of type T. It avoids the marshal/unmarshal roundtrip that +// callers of [Client.DebugResultJSON] otherwise need to do to get a typed +// value. +// +// These are development tools and subject to change or removal over time. +func GetDebugResultJSON[T any](ctx context.Context, lc *Client, action string) (T, error) { + var v T + body, err := lc.send(ctx, "POST", "/localapi/v0/debug?action="+url.QueryEscape(action), 200, nil) + if err != nil { + return v, fmt.Errorf("error %w: %s", err, body) + } + if err := json.Unmarshal(body, &v); err != nil { + return v, err + } + return v, nil +} + // QueryOptionalFeatures queries the optional features supported by the Tailscale daemon. func (lc *Client) QueryOptionalFeatures(ctx context.Context) (*apitype.OptionalFeatures, error) { body, err := lc.send(ctx, "POST", "/localapi/v0/debug-optional-features", 200, nil) @@ -972,6 +1019,19 @@ func (lc *Client) UserDial(ctx context.Context, network, host string, port uint1 if res.StatusCode != http.StatusSwitchingProtocols { body, _ := io.ReadAll(res.Body) res.Body.Close() + if res.StatusCode == http.StatusOK && res.Header.Get("Dial-Self") == "true" { + // Server told us to dial the address ourselves rather than + // proxying through the daemon. This happens for non-Tailscale + // addresses where the daemon shouldn't dial as root on the + // client's behalf. The server provides the resolved address + // to avoid a TOCTOU race with DNS re-resolution. + addr := res.Header.Get("Dial-Addr") + if addr == "" { + return nil, errors.New("server returned Dial-Self without Dial-Addr") + } + var d net.Dialer + return d.DialContext(ctx, network, addr) + } return nil, fmt.Errorf("unexpected HTTP response: %s, %s", res.Status, body) } // From here on, the underlying net.Conn is ours to use, but there @@ -1009,6 +1069,44 @@ func (lc *Client) CurrentDERPMap(ctx context.Context) (*tailcfg.DERPMap, error) return &derpMap, nil } +// CertDomains returns the list of domains for which the local tailscaled can +// fetch TLS certificates, equivalent to the DNS.CertDomains field of the +// current netmap. The returned list is sorted in ascending order, and is +// empty if no netmap has been received yet. +func (lc *Client) CertDomains(ctx context.Context) ([]string, error) { + body, err := lc.get200(ctx, "/localapi/v0/cert-domains") + if err != nil { + return nil, err + } + return decodeJSON[[]string](body) +} + +// DNSConfig returns the [tailcfg.DNSConfig] from the current netmap. +// It returns an error if no netmap has been received yet. +// It is intended for callers that need fields like ExtraRecords or CertDomains +// without pulling the rest of the netmap. +func (lc *Client) DNSConfig(ctx context.Context) (*tailcfg.DNSConfig, error) { + body, err := lc.get200(ctx, "/localapi/v0/dns-config") + if err != nil { + return nil, err + } + return decodeJSON[*tailcfg.DNSConfig](body) +} + +// PeerByID returns a peer's current full [tailcfg.Node] looked up by its +// [tailcfg.NodeID], in O(1) time on the daemon side. It returns an error +// if no peer with that NodeID is in the current netmap. +// +// It is intended for callers that need the latest state of a single peer +// without fetching the entire netmap. +func (lc *Client) PeerByID(ctx context.Context, id tailcfg.NodeID) (*tailcfg.Node, error) { + body, err := lc.get200(ctx, "/localapi/v0/peer-by-id?id="+strconv.FormatInt(int64(id), 10)) + if err != nil { + return nil, err + } + return decodeJSON[*tailcfg.Node](body) +} + // PingOpts contains options for the ping request. // // The zero value is valid, which means to use defaults. @@ -1422,3 +1520,13 @@ func (lc *Client) GetAppConnectorRouteInfo(ctx context.Context) (appctype.RouteI } return decodeJSON[appctype.RouteInfo](body) } + +// GetServices returns the Services visible to this node, +// including their names, IP addresses, and ports, keyed by service name. +func (lc *Client) GetServices(ctx context.Context) (map[tailcfg.ServiceName]tailcfg.ServiceDetails, error) { + body, err := lc.get200(ctx, "/localapi/v0/services") + if err != nil { + return nil, err + } + return decodeJSON[map[tailcfg.ServiceName]tailcfg.ServiceDetails](body) +} diff --git a/client/local/local_test.go b/client/local/local_test.go index a5377fbd6..58a87b224 100644 --- a/client/local/local_test.go +++ b/client/local/local_test.go @@ -61,6 +61,57 @@ func TestWhoIsPeerNotFound(t *testing.T) { } } +func TestUserDialSelf(t *testing.T) { + // Start a real TCP listener that the client should dial directly + // when the server tells it to dial-self. + ln, err := net.Listen("tcp", "127.0.0.1:0") + if err != nil { + t.Fatal(err) + } + defer ln.Close() + go func() { + for { + c, err := ln.Accept() + if err != nil { + return + } + c.Write([]byte("hello")) + c.Close() + } + }() + targetAddr := ln.Addr().(*net.TCPAddr) + + // Mock LocalAPI server that returns Dial-Self response. + nw := nettest.GetNetwork(t) + ts := nettest.NewHTTPServer(nw, http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.Header().Set("Dial-Self", "true") + w.Header().Set("Dial-Addr", targetAddr.String()) + w.WriteHeader(http.StatusOK) + })) + defer ts.Close() + + lc := &Client{ + Dial: func(ctx context.Context, network, addr string) (net.Conn, error) { + return nw.Dial(ctx, network, ts.Listener.Addr().String()) + }, + } + + conn, err := lc.UserDial(context.Background(), "tcp", targetAddr.IP.String(), uint16(targetAddr.Port)) + if err != nil { + t.Fatalf("UserDial: %v", err) + } + defer conn.Close() + + buf := make([]byte, 5) + n, err := conn.Read(buf) + if err != nil { + t.Fatalf("Read: %v", err) + } + if got := string(buf[:n]); got != "hello" { + t.Errorf("got %q, want %q", got, "hello") + } +} + func TestDeps(t *testing.T) { deptest.DepChecker{ BadDeps: map[string]string{ diff --git a/client/local/tailnetlock.go b/client/local/tailnetlock.go index 5af90eb16..54e795833 100644 --- a/client/local/tailnetlock.go +++ b/client/local/tailnetlock.go @@ -117,7 +117,7 @@ func (lc *Client) NetworkLockAffectedSigs(ctx context.Context, keyID tkatype.Key return decodeJSON[[]tkatype.MarshaledSignature](body) } -// NetworkLockLog returns up to maxEntries number of changes to network-lock state. +// NetworkLockLog returns up to maxEntries number of changes to tailnet-lock state. func (lc *Client) NetworkLockLog(ctx context.Context, maxEntries int) ([]ipnstate.NetworkLockUpdate, error) { v := url.Values{} v.Set("limit", fmt.Sprint(maxEntries)) @@ -128,7 +128,7 @@ func (lc *Client) NetworkLockLog(ctx context.Context, maxEntries int) ([]ipnstat return decodeJSON[[]ipnstate.NetworkLockUpdate](body) } -// NetworkLockForceLocalDisable forcibly shuts down network lock on this node. +// NetworkLockForceLocalDisable forcibly shuts down tailnet lock on this node. func (lc *Client) NetworkLockForceLocalDisable(ctx context.Context) error { // This endpoint expects an empty JSON stanza as the payload. var b bytes.Buffer @@ -142,7 +142,7 @@ func (lc *Client) NetworkLockForceLocalDisable(ctx context.Context) error { return nil } -// NetworkLockVerifySigningDeeplink verifies the network lock deeplink contained +// NetworkLockVerifySigningDeeplink verifies the tailnet lock deeplink contained // in url and returns information extracted from it. func (lc *Client) NetworkLockVerifySigningDeeplink(ctx context.Context, url string) (*tka.DeeplinkValidationResult, error) { vr := struct { @@ -193,7 +193,7 @@ func (lc *Client) NetworkLockSubmitRecoveryAUM(ctx context.Context, aum tka.AUM) return nil } -// NetworkLockDisable shuts down network-lock across the tailnet. +// NetworkLockDisable shuts down tailnet-lock across the tailnet. func (lc *Client) NetworkLockDisable(ctx context.Context, secret []byte) error { if _, err := lc.send(ctx, "POST", "/localapi/v0/tka/disable", 200, bytes.NewReader(secret)); err != nil { return fmt.Errorf("error: %w", err) diff --git a/client/systray/logo.go b/client/systray/logo.go index a0f8bf7d0..334cd7917 100644 --- a/client/systray/logo.go +++ b/client/systray/logo.go @@ -11,6 +11,7 @@ import ( "image" "image/color" "image/png" + "log" "runtime" "sync" "time" @@ -204,12 +205,49 @@ var ( ) var ( - bg = color.NRGBA{0, 0, 0, 255} - fg = color.NRGBA{255, 255, 255, 255} - gray = color.NRGBA{255, 255, 255, 102} - red = color.NRGBA{229, 111, 74, 255} + black = color.NRGBA{0, 0, 0, 255} + white = color.NRGBA{255, 255, 255, 255} + darkGray = color.NRGBA{102, 102, 102, 255} + lightGray = color.NRGBA{153, 153, 153, 255} + red = color.NRGBA{229, 111, 74, 255} + transparent = color.NRGBA{} + + // default values to dark theme + bg = black + fg = white + gray = darkGray ) +// SetTheme sets the color theme of the systray icon. +// +// Supported themes are: +// - dark - white and gray dots over black background +// - dark:nobg - white and grey dots over transparent background +// - light - black and gray dots over white background +// - light:nobg - black and grey dots over transparent background +func SetTheme(theme string) { + switch theme { + case "dark": + bg = black + fg = white + gray = darkGray + case "dark:nobg": + bg = transparent + fg = white + gray = darkGray + case "light": + bg = white + fg = black + gray = lightGray + case "light:nobg": + bg = transparent + fg = black + gray = lightGray + default: + log.Printf("unknown theme: %q", theme) + } +} + // render returns a PNG image of the logo. func (logo tsLogo) render() *bytes.Buffer { const borderUnits = 1 diff --git a/client/systray/startup-creator.go b/client/systray/startup-creator.go index 369190012..02a018099 100644 --- a/client/systray/startup-creator.go +++ b/client/systray/startup-creator.go @@ -3,7 +3,6 @@ //go:build cgo || !darwin -// Package systray provides a minimal Tailscale systray application. package systray import ( diff --git a/client/systray/systray.go b/client/systray/systray.go index 65c1bec20..d0287e647 100644 --- a/client/systray/systray.go +++ b/client/systray/systray.go @@ -621,11 +621,9 @@ func (menu *Menu) rebuildExitNodeMenu(ctx context.Context) { title += strings.Split(sugg.Name, ".")[0] } menu.exitNodes.AddSeparator() - rm := menu.exitNodes.AddSubMenuItemCheckbox(title, "", false) + active := recommendedIsActive(status, sugg.ID, sugg.Location.CountryCode(), sugg.Location.City()) + rm := menu.exitNodes.AddSubMenuItemCheckbox(title, "", active) setExitNodeOnClick(rm, sugg.ID) - if status.ExitNodeStatus != nil && sugg.ID == status.ExitNodeStatus.ID { - rm.Check() - } } } @@ -647,13 +645,11 @@ func (menu *Menu) rebuildExitNodeMenu(ctx context.Context) { if !ps.Online { name += " (offline)" } - sm := menu.exitNodes.AddSubMenuItemCheckbox(name, "", false) + active := status.ExitNodeStatus != nil && ps.ID == status.ExitNodeStatus.ID + sm := menu.exitNodes.AddSubMenuItemCheckbox(name, "", active) if !ps.Online { sm.Disable() } - if status.ExitNodeStatus != nil && ps.ID == status.ExitNodeStatus.ID { - sm.Check() - } setExitNodeOnClick(sm, ps.ID) } } @@ -743,6 +739,30 @@ func (mc *mvCountry) sortedCities() []*mvCity { return cities } +// recommendedIsActive reports whether the suggested exit node corresponds to +// the currently active exit node in status. +func recommendedIsActive(status *ipnstate.Status, suggID tailcfg.StableNodeID, suggCountry, suggCity string) bool { + if status == nil || status.ExitNodeStatus == nil || status.ExitNodeStatus.ID.IsZero() { + return false + } + if suggID == status.ExitNodeStatus.ID { + return true + } + if suggCountry == "" || suggCity == "" { + return false + } + for _, p := range status.Peer { + if p.ID != status.ExitNodeStatus.ID { + continue + } + if loc := p.Location; loc != nil && loc.CountryCode == suggCountry && loc.City == suggCity { + return true + } + return false + } + return false +} + // countryFlag takes a 2-character ASCII string and returns the corresponding emoji flag. // It returns the empty string on error. func countryFlag(code string) string { diff --git a/client/systray/systray_test.go b/client/systray/systray_test.go new file mode 100644 index 000000000..6b8ce8b95 --- /dev/null +++ b/client/systray/systray_test.go @@ -0,0 +1,120 @@ +// Copyright (c) Tailscale Inc & contributors +// SPDX-License-Identifier: BSD-3-Clause + +//go:build cgo || !darwin + +package systray + +import ( + "testing" + + "tailscale.com/ipn/ipnstate" + "tailscale.com/tailcfg" + "tailscale.com/types/key" +) + +func TestRecommendedIsActive(t *testing.T) { + t.Parallel() + + const ( + activeID = tailcfg.StableNodeID("active") + suggID = tailcfg.StableNodeID("suggestion") + ) + usNYC := &tailcfg.Location{CountryCode: "US", City: "New York"} + usCHI := &tailcfg.Location{CountryCode: "US", City: "Chicago"} + seSTO := &tailcfg.Location{CountryCode: "SE", City: "Stockholm"} + + statusWith := func(activePeer *ipnstate.PeerStatus) *ipnstate.Status { + s := &ipnstate.Status{ + ExitNodeStatus: &ipnstate.ExitNodeStatus{ID: activeID}, + } + if activePeer != nil { + s.Peer = map[key.NodePublic]*ipnstate.PeerStatus{{}: activePeer} + } + return s + } + + tests := []struct { + name string + status *ipnstate.Status + suggID tailcfg.StableNodeID + suggCountry string + suggCity string + isActive bool + }{ + { + name: "nil_status", + status: nil, + suggID: suggID, + }, + { + name: "no_exit_node", + status: &ipnstate.Status{}, + suggID: suggID, + }, + { + name: "exit_node_id_is_zero", + status: &ipnstate.Status{ExitNodeStatus: &ipnstate.ExitNodeStatus{}}, + suggID: suggID, + }, + { + name: "exact_id_match_short-circuits", + status: statusWith(&ipnstate.PeerStatus{ID: activeID, Location: usCHI}), + suggID: activeID, + suggCountry: "US", + suggCity: "New York", + isActive: true, + }, + { + name: "id_mismatch_but_same_city", + status: statusWith(&ipnstate.PeerStatus{ID: activeID, Location: usNYC}), + suggID: suggID, + suggCountry: "US", + suggCity: "New York", + isActive: true, + }, + { + name: "different_city", + status: statusWith(&ipnstate.PeerStatus{ID: activeID, Location: usCHI}), + suggID: suggID, + suggCountry: "US", + suggCity: "New York", + }, + { + name: "different_country", + status: statusWith(&ipnstate.PeerStatus{ID: activeID, Location: seSTO}), + suggID: suggID, + suggCountry: "US", + suggCity: "New York", + }, + { + name: "id_mismatch_suggestion_has_no_location", + status: statusWith(&ipnstate.PeerStatus{ID: activeID, Location: usNYC}), + suggID: suggID, + }, + { + name: "id_mismatch_active_peer_has_no_location", + status: statusWith(&ipnstate.PeerStatus{ID: activeID}), + suggID: suggID, + suggCountry: "US", + suggCity: "New York", + }, + { + name: "active_peer_not_in_status", + status: statusWith(nil), + suggID: suggID, + suggCountry: "US", + suggCity: "New York", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + t.Parallel() + isExitNodeActive := recommendedIsActive(tt.status, tt.suggID, tt.suggCountry, tt.suggCity) + if isExitNodeActive != tt.isActive { + t.Errorf("recommendedIsActive; got %v, want %v", isExitNodeActive, tt.isActive) + } + }) + } +} diff --git a/client/web/web.go b/client/web/web.go index 3e5fa4b54..95259ef1a 100644 --- a/client/web/web.go +++ b/client/web/web.go @@ -35,8 +35,10 @@ import ( "tailscale.com/net/netutil" "tailscale.com/net/tsaddr" "tailscale.com/tailcfg" + "tailscale.com/tsweb" "tailscale.com/types/logger" "tailscale.com/types/views" + "tailscale.com/util/ctxkey" "tailscale.com/util/httpm" "tailscale.com/util/syspolicy/policyclient" "tailscale.com/version" @@ -527,45 +529,40 @@ func (s *Server) serveLoginAPI(w http.ResponseWriter, r *http.Request) { } } -type apiHandler[data any] struct { - s *Server - w http.ResponseWriter - r *http.Request - - // permissionCheck allows for defining whether a requesting peer's - // capabilities grant them access to make the given data update. - // If permissionCheck reports false, the request fails as unauthorized. - permissionCheck func(data data, peer peerCapabilities) bool -} - -// newHandler constructs a new api handler which restricts the given request -// to the specified permission check. If the permission check fails for -// the peer associated with the request, an unauthorized error is returned -// to the client. -func newHandler[data any](s *Server, w http.ResponseWriter, r *http.Request, permissionCheck func(data data, peer peerCapabilities) bool) *apiHandler[data] { - return &apiHandler[data]{ - s: s, - w: w, - r: r, - permissionCheck: permissionCheck, +// handleJSON manages decoding the request's body JSON as data and passing it +// on to the provided handler function. +func handleJSON[data any](h func(ctx context.Context, data data) error) http.HandlerFunc { + return func(w http.ResponseWriter, r *http.Request) { + defer r.Body.Close() + var body data + if err := json.NewDecoder(r.Body).Decode(&body); err != nil { + http.Error(w, err.Error(), http.StatusInternalServerError) + return + } + if err := h(r.Context(), body); err != nil { + if httpErr, ok := errors.AsType[tsweb.HTTPError](err); ok { + tsweb.WriteHTTPError(w, r, httpErr) + } else { + http.Error(w, err.Error(), http.StatusInternalServerError) + } + return + } + w.WriteHeader(http.StatusOK) } } -// alwaysAllowed can be passed as the permissionCheck argument to newHandler -// for requests that are always allowed to complete regardless of a peer's -// capabilities. -func alwaysAllowed[data any](_ data, _ peerCapabilities) bool { return true } +var contextKeyPeer = ctxkey.New("peer-capabilities", peerCapabilities{}) -func (a *apiHandler[data]) getPeer() (peerCapabilities, error) { +func (s *Server) setPeer(r *http.Request) (*http.Request, error) { // TODO(tailscale/corp#16695,sonia): We also call StatusWithoutPeers and // WhoIs when originally checking for a session from authorizeRequest. // Would be nice if we could pipe those through to here so we don't end // up having to re-call them to grab the peer capabilities. - status, err := a.s.lc.StatusWithoutPeers(a.r.Context()) + status, err := s.lc.StatusWithoutPeers(r.Context()) if err != nil { return nil, err } - whois, err := a.s.lc.WhoIs(a.r.Context(), a.r.RemoteAddr) + whois, err := s.lc.WhoIs(r.Context(), r.RemoteAddr) if err != nil { return nil, err } @@ -573,56 +570,11 @@ func (a *apiHandler[data]) getPeer() (peerCapabilities, error) { if err != nil { return nil, err } - return peer, nil + return r.WithContext(contextKeyPeer.WithValue(r.Context(), peer)), nil } -type noBodyData any // empty type, for use from serveAPI for endpoints with empty body - -// handle runs the given handler if the source peer satisfies the -// constraints for running this request. -// -// handle is expected for use when `data` type is empty, or set to -// `noBodyData` in practice. For requests that expect JSON body data -// to be attached, use handleJSON instead. -func (a *apiHandler[data]) handle(h http.HandlerFunc) { - peer, err := a.getPeer() - if err != nil { - http.Error(a.w, err.Error(), http.StatusInternalServerError) - return - } - var body data // not used - if !a.permissionCheck(body, peer) { - http.Error(a.w, "not allowed", http.StatusUnauthorized) - return - } - h(a.w, a.r) -} - -// handleJSON manages decoding the request's body JSON and passing -// it on to the provided function if the source peer satisfies the -// constraints for running this request. -func (a *apiHandler[data]) handleJSON(h func(ctx context.Context, data data) error) { - defer a.r.Body.Close() - var body data - if err := json.NewDecoder(a.r.Body).Decode(&body); err != nil { - http.Error(a.w, err.Error(), http.StatusInternalServerError) - return - } - peer, err := a.getPeer() - if err != nil { - http.Error(a.w, err.Error(), http.StatusInternalServerError) - return - } - if !a.permissionCheck(body, peer) { - http.Error(a.w, "not allowed", http.StatusUnauthorized) - return - } - - if err := h(a.r.Context(), body); err != nil { - http.Error(a.w, err.Error(), http.StatusInternalServerError) - return - } - a.w.WriteHeader(http.StatusOK) +func (s *Server) getPeer(ctx context.Context) peerCapabilities { + return contextKeyPeer.Value(ctx) } // serveAPI serves requests for the web client api. @@ -637,67 +589,44 @@ func (s *Server) serveAPI(w http.ResponseWriter, r *http.Request) { } } + var err error + r, err = s.setPeer(r) + if err != nil { + http.Error(w, err.Error(), http.StatusUnauthorized) + return + } + path := strings.TrimPrefix(r.URL.Path, "/api") switch { case path == "/data" && r.Method == httpm.GET: - newHandler[noBodyData](s, w, r, alwaysAllowed). - handle(s.serveGetNodeData) + s.serveGetNodeData(w, r) return case path == "/exit-nodes" && r.Method == httpm.GET: - newHandler[noBodyData](s, w, r, alwaysAllowed). - handle(s.serveGetExitNodes) + s.serveGetExitNodes(w, r) return case path == "/routes" && r.Method == httpm.POST: - peerAllowed := func(d postRoutesRequest, p peerCapabilities) bool { - if d.SetExitNode && !p.canEdit(capFeatureExitNodes) { - return false - } else if d.SetRoutes && !p.canEdit(capFeatureSubnets) { - return false - } - return true - } - newHandler[postRoutesRequest](s, w, r, peerAllowed). - handleJSON(s.servePostRoutes) + handleJSON[postRoutesRequest](s.servePostRoutes)(w, r) return case path == "/device-details-click" && r.Method == httpm.POST: - newHandler[noBodyData](s, w, r, alwaysAllowed). - handle(s.serveDeviceDetailsClick) + s.serveDeviceDetailsClick(w, r) return case path == "/local/v0/logout" && r.Method == httpm.POST: - peerAllowed := func(_ noBodyData, peer peerCapabilities) bool { - return peer.canEdit(capFeatureAccount) - } - newHandler[noBodyData](s, w, r, peerAllowed). - handle(s.proxyRequestToLocalAPI) + s.proxyRequestToLocalAPI(w, r) return case path == "/local/v0/prefs" && r.Method == httpm.PATCH: - peerAllowed := func(data maskedPrefs, peer peerCapabilities) bool { - if data.RunSSHSet && !peer.canEdit(capFeatureSSH) { - return false - } - return true - } - newHandler[maskedPrefs](s, w, r, peerAllowed). - handleJSON(s.serveUpdatePrefs) + handleJSON[maskedPrefs](s.serveUpdatePrefs)(w, r) return case path == "/local/v0/update/check" && r.Method == httpm.GET: - newHandler[noBodyData](s, w, r, alwaysAllowed). - handle(s.proxyRequestToLocalAPI) + s.proxyRequestToLocalAPI(w, r) return case path == "/local/v0/update/check" && r.Method == httpm.POST: - peerAllowed := func(_ noBodyData, peer peerCapabilities) bool { - return peer.canEdit(capFeatureAccount) - } - newHandler[noBodyData](s, w, r, peerAllowed). - handle(s.proxyRequestToLocalAPI) + s.proxyRequestToLocalAPI(w, r) return case path == "/local/v0/update/progress" && r.Method == httpm.POST: - newHandler[noBodyData](s, w, r, alwaysAllowed). - handle(s.proxyRequestToLocalAPI) + s.proxyRequestToLocalAPI(w, r) return case path == "/local/v0/upload-client-metrics" && r.Method == httpm.POST: - newHandler[noBodyData](s, w, r, alwaysAllowed). - handle(s.proxyRequestToLocalAPI) + s.proxyRequestToLocalAPI(w, r) return } http.Error(w, "invalid endpoint", http.StatusNotFound) @@ -1122,6 +1051,11 @@ type maskedPrefs struct { } func (s *Server) serveUpdatePrefs(ctx context.Context, prefs maskedPrefs) error { + peer := s.getPeer(ctx) + if prefs.RunSSHSet && !peer.canEdit(capFeatureSSH) { + return tsweb.Error(http.StatusUnauthorized, "RunSSHSet not allowed", nil) + } + _, err := s.lc.EditPrefs(ctx, &ipn.MaskedPrefs{ RunSSHSet: prefs.RunSSHSet, Prefs: ipn.Prefs{ @@ -1140,6 +1074,17 @@ type postRoutesRequest struct { } func (s *Server) servePostRoutes(ctx context.Context, data postRoutesRequest) error { + if !data.SetExitNode && !data.SetRoutes { + return tsweb.Error(http.StatusBadRequest, "must specify SetExitNode or SetRoutes", nil) + } + peer := s.getPeer(ctx) + if data.SetExitNode && !peer.canEdit(capFeatureExitNodes) { + return tsweb.Error(http.StatusUnauthorized, "SetExitNode not allowed", nil) + } + if data.SetRoutes && !peer.canEdit(capFeatureSubnets) { + return tsweb.Error(http.StatusUnauthorized, "SetRoutes not allowed", nil) + } + prefs, err := s.lc.GetPrefs(ctx) if err != nil { return err @@ -1153,13 +1098,14 @@ func (s *Server) servePostRoutes(ctx context.Context, data postRoutesRequest) er } currNonExitRoutes = append(currNonExitRoutes, r.String()) } - // Set non-edited fields to their current values. - if data.SetExitNode { - data.AdvertiseRoutes = currNonExitRoutes - } else if data.SetRoutes { + // For each group of fields not being set, preserve the current prefs. + if !data.SetExitNode { data.AdvertiseExitNode = currAdvertisingExitNode data.UseExitNode = prefs.ExitNodeID } + if !data.SetRoutes { + data.AdvertiseRoutes = currNonExitRoutes + } // Calculate routes. routesStr := strings.Join(data.AdvertiseRoutes, ",") @@ -1336,6 +1282,19 @@ func (s *Server) proxyRequestToLocalAPI(w http.ResponseWriter, r *http.Request) return } + switch path { + case "/v0/logout": + if !s.getPeer(r.Context()).canEdit(capFeatureAccount) { + http.Error(w, "not allowed", http.StatusUnauthorized) + return + } + case "/v0/update/check": + if r.Method == httpm.POST && !s.getPeer(r.Context()).canEdit(capFeatureAccount) { + http.Error(w, "not allowed", http.StatusUnauthorized) + return + } + } + localAPIURL := "http://" + apitype.LocalAPIHost + "/localapi" + path req, err := http.NewRequestWithContext(r.Context(), r.Method, localAPIURL, r.Body) if err != nil { diff --git a/client/web/web_test.go b/client/web/web_test.go index 032cd5222..51b6a8ac5 100644 --- a/client/web/web_test.go +++ b/client/web/web_test.go @@ -191,7 +191,7 @@ func TestServeAPI(t *testing.T) { reqBody: "{\"setExitNode\":true}", tests: []requestTest{{ remoteIP: remoteIPWithNoCapabilities, - wantResponse: "not allowed", + wantResponse: "SetExitNode not allowed", wantStatus: http.StatusUnauthorized, }, { remoteIP: remoteIPWithAllCapabilities, @@ -204,7 +204,7 @@ func TestServeAPI(t *testing.T) { reqContentType: "application/json", tests: []requestTest{{ remoteIP: remoteIPWithNoCapabilities, - wantResponse: "not allowed", + wantResponse: "RunSSHSet not allowed", wantStatus: http.StatusUnauthorized, }, { remoteIP: remoteIPWithAllCapabilities, @@ -1604,3 +1604,149 @@ func TestCSRFProtect(t *testing.T) { }) } } + +func TestServePostRoutes(t *testing.T) { + existingExitNodeID := tailcfg.StableNodeID("existing-exit-node") + existingRoute := netip.MustParsePrefix("192.168.1.0/24") + + existingPrefs := &ipn.Prefs{ + ExitNodeID: existingExitNodeID, + AdvertiseRoutes: []netip.Prefix{existingRoute}, + } + + tests := []struct { + name string + data postRoutesRequest + peerCaps peerCapabilities + wantErr bool + wantEditPrefs bool // whether EditPrefs (PATCH /prefs) should be called + wantExitNodeID tailcfg.StableNodeID + wantRoutes []netip.Prefix + }{ + { + name: "empty-request", + data: postRoutesRequest{}, + peerCaps: peerCapabilities{capFeatureExitNodes: true, capFeatureSubnets: true}, + wantErr: true, + wantEditPrefs: false, + }, + { + name: "SetExitNode-only", + data: postRoutesRequest{ + SetExitNode: true, + UseExitNode: "new-exit-node", + }, + peerCaps: peerCapabilities{capFeatureExitNodes: true, capFeatureSubnets: true}, + wantEditPrefs: true, + wantExitNodeID: "new-exit-node", + wantRoutes: []netip.Prefix{existingRoute}, + }, + { + name: "SetExitNode-not-allowed", + data: postRoutesRequest{ + SetExitNode: true, + UseExitNode: "new-exit-node", + }, + peerCaps: peerCapabilities{capFeatureSubnets: true}, + wantErr: true, + }, + { + name: "SetRoutes-only", + data: postRoutesRequest{ + SetRoutes: true, + AdvertiseRoutes: []string{"10.0.0.0/8"}, + }, + peerCaps: peerCapabilities{capFeatureExitNodes: true, capFeatureSubnets: true}, + wantEditPrefs: true, + wantExitNodeID: existingExitNodeID, + wantRoutes: []netip.Prefix{netip.MustParsePrefix("10.0.0.0/8")}, + }, + { + name: "SetRoutes-not-allowed", + data: postRoutesRequest{ + SetRoutes: true, + AdvertiseRoutes: []string{"10.0.0.0/8"}, + }, + peerCaps: peerCapabilities{capFeatureExitNodes: true}, + wantErr: true, + }, + { + name: "SetExitNode-and-SetRoutes", + data: postRoutesRequest{ + SetExitNode: true, + SetRoutes: true, + UseExitNode: "new-exit-node", + AdvertiseRoutes: []string{"10.0.0.0/8"}, + }, + peerCaps: peerCapabilities{capFeatureExitNodes: true, capFeatureSubnets: true}, + wantEditPrefs: true, + wantExitNodeID: "new-exit-node", + wantRoutes: []netip.Prefix{netip.MustParsePrefix("10.0.0.0/8")}, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + var gotPrefs *ipn.MaskedPrefs + + lal := memnet.Listen("local-tailscaled.sock:80") + defer lal.Close() + + localapi := &http.Server{Handler: http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + if r.URL.Path != "/localapi/v0/prefs" { + t.Errorf("unexpected localapi call to %q", r.URL.Path) + http.Error(w, "unexpected localapi call", http.StatusInternalServerError) + return + } + switch r.Method { + case httpm.GET: + writeJSON(w, existingPrefs) + case httpm.PATCH: + var mp ipn.MaskedPrefs + if err := json.NewDecoder(r.Body).Decode(&mp); err != nil { + http.Error(w, err.Error(), http.StatusBadRequest) + return + } + gotPrefs = &mp + writeJSON(w, gotPrefs.Prefs) + default: + t.Errorf("unexpected method %q on /prefs", r.Method) + http.Error(w, "unexpected method", http.StatusMethodNotAllowed) + } + })} + defer localapi.Close() + go localapi.Serve(lal) + + s := &Server{ + mode: ManageServerMode, + lc: &local.Client{Dial: lal.Dial}, + } + + ctx := contextKeyPeer.WithValue(t.Context(), tt.peerCaps) + err := s.servePostRoutes(ctx, tt.data) + + if tt.wantErr { + if err == nil { + t.Error("wanted error, got nil") + } + if gotPrefs != nil { + t.Error("EditPrefs should not have been called on error") + } + return + } + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + + if gotPrefs == nil { + t.Fatal("expected EditPrefs to be called") + } + if diff := cmp.Diff(tt.wantExitNodeID, gotPrefs.ExitNodeID); diff != "" { + t.Errorf("ExitNodeID mismatch (-want +got):\n%s", diff) + } + if diff := cmp.Diff(tt.wantRoutes, gotPrefs.AdvertiseRoutes, cmp.Comparer(func(a, b netip.Prefix) bool { return a.Compare(b) == 0 })); diff != "" { + t.Errorf("AdvertiseRoutes mismatch (-want +got):\n%s", diff) + } + }) + } +} diff --git a/clientupdate/clientupdate_windows.go b/clientupdate/clientupdate_windows.go index 70a3c5091..50b77c38b 100644 --- a/clientupdate/clientupdate_windows.go +++ b/clientupdate/clientupdate_windows.go @@ -38,12 +38,12 @@ const ( updaterPrefix = "tailscale-updater" ) -func makeSelfCopy() (origPathExe, tmpPathExe string, err error) { - selfExe, err := os.Executable() +func makeCmdTailscaleCopy() (origPathExe, tmpPathExe string, err error) { + srcExe, err := findCmdTailscale() if err != nil { return "", "", err } - f, err := os.Open(selfExe) + f, err := os.Open(srcExe) if err != nil { return "", "", err } @@ -59,7 +59,25 @@ func makeSelfCopy() (origPathExe, tmpPathExe string, err error) { f2.Close() return "", "", err } - return selfExe, f2.Name(), f2.Close() + return srcExe, f2.Name(), f2.Close() +} + +// findCmdTailscale returns the path to the binary that should be copied for the update +// re-execution. The copy is re-executed with "update" as a subcommand, so it must be +// a binary that handles "update" (ie tailscale.exe, not tailscaled.exe) +func findCmdTailscale() (string, error) { + selfExe, err := os.Executable() + if err != nil { + return "", err + } + if strings.EqualFold(filepath.Base(selfExe), "tailscale.exe") { + return selfExe, nil + } + ts := filepath.Join(filepath.Dir(selfExe), "tailscale.exe") + if _, err := os.Stat(ts); err != nil { + return "", fmt.Errorf("cannot find tailscale.exe alongside %s: %w", selfExe, err) + } + return ts, nil } func markTempFileWindows(name string) error { @@ -159,14 +177,14 @@ you can run the command prompt as Administrator one of these ways: up.Logf("making tailscale.exe copy to switch to...") up.cleanupOldDownloads(filepath.Join(os.TempDir(), updaterPrefix+"-*.exe")) - _, selfCopy, err := makeSelfCopy() + _, cmdTailscaleCopy, err := makeCmdTailscaleCopy() if err != nil { return err } - defer os.Remove(selfCopy) + defer os.Remove(cmdTailscaleCopy) up.Logf("running tailscale.exe copy for final install...") - cmd := exec.Command(selfCopy, "update") + cmd := exec.Command(cmdTailscaleCopy, "update") cmd.Env = append(os.Environ(), winMSIEnv+"="+msiTarget, winVersionEnv+"="+ver) cmd.Stdout = up.Stderr cmd.Stderr = up.Stderr diff --git a/cmd/cloner/cloner.go b/cmd/cloner/cloner.go index ab4a7b22f..8b4cacf7a 100644 --- a/cmd/cloner/cloner.go +++ b/cmd/cloner/cloner.go @@ -143,25 +143,9 @@ func gen(buf *bytes.Buffer, it *codegen.ImportTracker, typ *types.Named) { writef("if src.%s != nil {", fname) writef("dst.%s = make([]%s, len(src.%s))", fname, n, fname) writef("for i := range dst.%s {", fname) - if ptr, isPtr := ft.Elem().(*types.Pointer); isPtr { - writef("if src.%s[i] == nil { dst.%s[i] = nil } else {", fname, fname) - if codegen.ContainsPointers(ptr.Elem()) { - if _, isIface := ptr.Elem().Underlying().(*types.Interface); isIface { - writef("\tdst.%s[i] = new((*src.%s[i]).Clone())", fname, fname) - } else { - writef("\tdst.%s[i] = src.%s[i].Clone()", fname, fname) - } - } else { - writef("\tdst.%s[i] = new(*src.%s[i])", fname, fname) - } - writef("}") - } else if ft.Elem().String() == "encoding/json.RawMessage" { - writef("\tdst.%s[i] = append(src.%s[i][:0:0], src.%s[i]...)", fname, fname, fname) - } else if _, isIface := ft.Elem().Underlying().(*types.Interface); isIface { - writef("\tdst.%s[i] = src.%s[i].Clone()", fname, fname) - } else { - writef("\tdst.%s[i] = *src.%s[i].Clone()", fname, fname) - } + writeSliceElemClone(writef, ft.Elem(), + fmt.Sprintf("src.%s[i]", fname), + fmt.Sprintf("dst.%s[i]", fname)) writef("}") writef("}") } else { @@ -189,11 +173,28 @@ func gen(buf *bytes.Buffer, it *codegen.ImportTracker, typ *types.Named) { n := it.QualifiedName(sliceType.Elem()) writef("if dst.%s != nil {", fname) writef("\tdst.%s = map[%s]%s{}", fname, it.QualifiedName(ft.Key()), it.QualifiedName(elem)) - writef("\tfor k := range src.%s {", fname) - // use zero-length slice instead of nil to ensure - // the key is always copied. - writef("\t\tdst.%s[k] = append([]%s{}, src.%s[k]...)", fname, n, fname) - writef("\t}") + if codegen.ContainsPointers(sliceType.Elem()) { + writef("\tfor k, sv := range src.%s {", fname) + writef("\t\tif sv == nil {") + writef("\t\t\tdst.%s[k] = nil", fname) + writef("\t\t\tcontinue") + writef("\t\t}") + writef("\t\tdst.%s[k] = make([]%s, len(sv))", fname, n) + writef("\t\tfor i := range sv {") + innerWritef := func(format string, args ...any) { + writef("\t\t"+format, args...) + } + writeSliceElemClone(innerWritef, sliceType.Elem(), + "sv[i]", fmt.Sprintf("dst.%s[k][i]", fname)) + writef("\t\t}") + writef("\t}") + } else { + writef("\tfor k := range src.%s {", fname) + // use zero-length slice instead of nil to ensure + // the key is always copied. + writef("\t\tdst.%s[k] = append([]%s{}, src.%s[k]...)", fname, n, fname) + writef("\t}") + } writef("}") } else if codegen.IsViewType(elem) || !codegen.ContainsPointers(elem) { // If the map values are view types (which are @@ -242,6 +243,31 @@ func gen(buf *bytes.Buffer, it *codegen.ImportTracker, typ *types.Named) { buf.Write(codegen.AssertStructUnchanged(t, name, typeParams, "Clone", it)) } +// writeSliceElemClone generates code to deep-clone a single slice element +// from srcExpr to dstExpr. It handles pointer, json.RawMessage, interface, +// and named struct element types. +func writeSliceElemClone(writef func(string, ...any), elemType types.Type, srcExpr, dstExpr string) { + if ptr, isPtr := elemType.(*types.Pointer); isPtr { + writef("if %s == nil { %s = nil } else {", srcExpr, dstExpr) + if codegen.ContainsPointers(ptr.Elem()) { + if _, isIface := ptr.Elem().Underlying().(*types.Interface); isIface { + writef("\t%s = new((*%s).Clone())", dstExpr, srcExpr) + } else { + writef("\t%s = %s.Clone()", dstExpr, srcExpr) + } + } else { + writef("\t%s = new(*%s)", dstExpr, srcExpr) + } + writef("}") + } else if elemType.String() == "encoding/json.RawMessage" { + writef("%s = append(%s[:0:0], %s...)", dstExpr, srcExpr, srcExpr) + } else if _, isIface := elemType.Underlying().(*types.Interface); isIface { + writef("%s = %s.Clone()", dstExpr, srcExpr) + } else { + writef("%s = *%s.Clone()", dstExpr, srcExpr) + } +} + // hasBasicUnderlying reports true when typ.Underlying() is a slice or a map. func hasBasicUnderlying(typ types.Type) bool { switch typ.Underlying().(type) { diff --git a/cmd/cloner/cloner_test.go b/cmd/cloner/cloner_test.go index c0a946480..f8beb4a88 100644 --- a/cmd/cloner/cloner_test.go +++ b/cmd/cloner/cloner_test.go @@ -7,6 +7,7 @@ import ( "reflect" "testing" + "github.com/google/go-cmp/cmp" "tailscale.com/cmd/cloner/clonerex" ) @@ -182,6 +183,46 @@ func TestNamedMapContainer(t *testing.T) { } } +func TestMapSlicePointerContainer(t *testing.T) { + num := 42 + orig := &clonerex.MapSlicePointerContainer{ + Routes: map[string][]*clonerex.SliceContainer{ + "route1": { + {Slice: []*int{&num}}, + {Slice: []*int{&num, &num}}, + }, + "route2": { + {Slice: []*int{&num}}, + }, + }, + } + + cloned := orig.Clone() + if !reflect.DeepEqual(orig, cloned) { + t.Errorf("Clone() = %v, want %v", cloned, orig) + } + + // Mutate cloned.Routes pointer values + *cloned.Routes["route1"][0].Slice[0] = 999 + if *orig.Routes["route1"][0].Slice[0] == 999 { + t.Errorf("Clone() aliased memory in Routes: original was modified") + } +} + +func TestMapSlicePointerContainerNilValue(t *testing.T) { + num := 7 + orig := &clonerex.MapSlicePointerContainer{ + Routes: map[string][]*clonerex.SliceContainer{ + "nil-value": nil, + "non-nil": {{Slice: []*int{&num}}}, + }, + } + cloned := orig.Clone() + if diff := cmp.Diff(orig.Routes, cloned.Routes); diff != "" { + t.Errorf("Clone() Routes mismatch (-orig +cloned):\n%s", diff) + } +} + func TestDeeplyNestedMap(t *testing.T) { num := 123 orig := &clonerex.DeeplyNestedMap{ diff --git a/cmd/cloner/clonerex/clonerex.go b/cmd/cloner/clonerex/clonerex.go index d17dbefc5..41626d3ae 100644 --- a/cmd/cloner/clonerex/clonerex.go +++ b/cmd/cloner/clonerex/clonerex.go @@ -1,7 +1,7 @@ // Copyright (c) Tailscale Inc & contributors // SPDX-License-Identifier: BSD-3-Clause -//go:generate go run tailscale.com/cmd/cloner -clonefunc=true -type SliceContainer,InterfaceContainer,MapWithPointers,DeeplyNestedMap,NamedMapContainer +//go:generate go run tailscale.com/cmd/cloner -clonefunc=true -type SliceContainer,InterfaceContainer,MapWithPointers,DeeplyNestedMap,NamedMapContainer,MapSlicePointerContainer // Package clonerex is an example package for the cloner tool. package clonerex @@ -60,6 +60,13 @@ type NamedMapContainer struct { Attrs NamedMap } +// MapSlicePointerContainer has a map whose values are slices of pointers. +// This tests that the cloner deep-clones the pointer elements in the slice, +// not just the slice itself (which would leave aliased pointers). +type MapSlicePointerContainer struct { + Routes map[string][]*SliceContainer +} + // DeeplyNestedMap tests arbitrary depth of map nesting (3+ levels) type DeeplyNestedMap struct { ThreeLevels map[string]map[string]map[string]int diff --git a/cmd/cloner/clonerex/clonerex_clone.go b/cmd/cloner/clonerex/clonerex_clone.go index 7d94688a3..9a4413177 100644 --- a/cmd/cloner/clonerex/clonerex_clone.go +++ b/cmd/cloner/clonerex/clonerex_clone.go @@ -176,9 +176,42 @@ var _NamedMapContainerCloneNeedsRegeneration = NamedMapContainer(struct { Attrs NamedMap }{}) +// Clone makes a deep copy of MapSlicePointerContainer. +// The result aliases no memory with the original. +func (src *MapSlicePointerContainer) Clone() *MapSlicePointerContainer { + if src == nil { + return nil + } + dst := new(MapSlicePointerContainer) + *dst = *src + if dst.Routes != nil { + dst.Routes = map[string][]*SliceContainer{} + for k, sv := range src.Routes { + if sv == nil { + dst.Routes[k] = nil + continue + } + dst.Routes[k] = make([]*SliceContainer, len(sv)) + for i := range sv { + if sv[i] == nil { + dst.Routes[k][i] = nil + } else { + dst.Routes[k][i] = sv[i].Clone() + } + } + } + } + return dst +} + +// A compilation failure here means this code must be regenerated, with the command at the top of this file. +var _MapSlicePointerContainerCloneNeedsRegeneration = MapSlicePointerContainer(struct { + Routes map[string][]*SliceContainer +}{}) + // Clone duplicates src into dst and reports whether it succeeded. // To succeed, must be of types <*T, *T> or <*T, **T>, -// where T is one of SliceContainer,InterfaceContainer,MapWithPointers,DeeplyNestedMap,NamedMapContainer. +// where T is one of SliceContainer,InterfaceContainer,MapWithPointers,DeeplyNestedMap,NamedMapContainer,MapSlicePointerContainer. func Clone(dst, src any) bool { switch src := src.(type) { case *SliceContainer: @@ -226,6 +259,15 @@ func Clone(dst, src any) bool { *dst = src.Clone() return true } + case *MapSlicePointerContainer: + switch dst := dst.(type) { + case *MapSlicePointerContainer: + *dst = *src.Clone() + return true + case **MapSlicePointerContainer: + *dst = src.Clone() + return true + } } return false } diff --git a/cmd/containerboot/egressservices.go b/cmd/containerboot/egressservices.go index e60d65c04..abde12523 100644 --- a/cmd/containerboot/egressservices.go +++ b/cmd/containerboot/egressservices.go @@ -22,11 +22,12 @@ import ( "time" "github.com/fsnotify/fsnotify" + "tailscale.com/client/local" - "tailscale.com/ipn" "tailscale.com/kube/egressservices" "tailscale.com/kube/kubeclient" "tailscale.com/kube/kubetypes" + "tailscale.com/types/netmap" "tailscale.com/util/httpm" "tailscale.com/util/linuxfw" "tailscale.com/util/mak" @@ -54,7 +55,7 @@ type egressProxy struct { tsClient *local.Client // never nil - netmapChan chan ipn.Notify // chan to receive netmap updates on + netmapChan chan *netmap.NetworkMap // chan to receive netmap updates on podIPv4 string // never empty string, currently only IPv4 is supported @@ -86,7 +87,7 @@ type httpClient interface { // - the mounted egress config has changed // - the proxy's tailnet IP addresses have changed // - tailnet IPs have changed for any backend targets specified by tailnet FQDN -func (ep *egressProxy) run(ctx context.Context, n ipn.Notify, opts egressProxyRunOpts) error { +func (ep *egressProxy) run(ctx context.Context, nm *netmap.NetworkMap, opts egressProxyRunOpts) error { ep.configure(opts) var tickChan <-chan time.Time var eventChan <-chan fsnotify.Event @@ -105,7 +106,7 @@ func (ep *egressProxy) run(ctx context.Context, n ipn.Notify, opts egressProxyRu eventChan = w.Events } - if err := ep.sync(ctx, n); err != nil { + if err := ep.sync(ctx, nm); err != nil { return err } for { @@ -116,14 +117,14 @@ func (ep *egressProxy) run(ctx context.Context, n ipn.Notify, opts egressProxyRu log.Printf("periodic sync, ensuring firewall config is up to date...") case <-eventChan: log.Printf("config file change detected, ensuring firewall config is up to date...") - case n = <-ep.netmapChan: - shouldResync := ep.shouldResync(n) + case nm = <-ep.netmapChan: + shouldResync := ep.shouldResync(nm) if !shouldResync { continue } log.Printf("netmap change detected, ensuring firewall config is up to date...") } - if err := ep.sync(ctx, n); err != nil { + if err := ep.sync(ctx, nm); err != nil { return fmt.Errorf("error syncing egress service config: %w", err) } } @@ -135,7 +136,7 @@ type egressProxyRunOpts struct { kc kubeclient.Client tsClient *local.Client stateSecret string - netmapChan chan ipn.Notify + netmapChan chan *netmap.NetworkMap podIPv4 string tailnetAddrs []netip.Prefix } @@ -164,7 +165,7 @@ func (ep *egressProxy) configure(opts egressProxyRunOpts) { // any firewall rules need to be updated. Currently using status in state Secret as a reference for what is the current // firewall configuration is good enough because - the status is keyed by the Pod IP - we crash the Pod on errors such // as failed firewall update -func (ep *egressProxy) sync(ctx context.Context, n ipn.Notify) error { +func (ep *egressProxy) sync(ctx context.Context, nm *netmap.NetworkMap) error { cfgs, err := ep.getConfigs() if err != nil { return fmt.Errorf("error retrieving egress service configs: %w", err) @@ -173,12 +174,12 @@ func (ep *egressProxy) sync(ctx context.Context, n ipn.Notify) error { if err != nil { return fmt.Errorf("error retrieving current egress proxy status: %w", err) } - newStatus, err := ep.syncEgressConfigs(cfgs, status, n) + newStatus, err := ep.syncEgressConfigs(cfgs, status, nm) if err != nil { return fmt.Errorf("error syncing egress service configs: %w", err) } if !servicesStatusIsEqual(newStatus, status) { - if err := ep.setStatus(ctx, newStatus, n); err != nil { + if err := ep.setStatus(ctx, newStatus, nm); err != nil { return fmt.Errorf("error setting egress proxy status: %w", err) } } @@ -187,14 +188,14 @@ func (ep *egressProxy) sync(ctx context.Context, n ipn.Notify) error { // addrsHaveChanged returns true if the provided netmap update contains tailnet address change for this proxy node. // Netmap must not be nil. -func (ep *egressProxy) addrsHaveChanged(n ipn.Notify) bool { - return !reflect.DeepEqual(ep.tailnetAddrs, n.NetMap.SelfNode.Addresses()) +func (ep *egressProxy) addrsHaveChanged(nm *netmap.NetworkMap) bool { + return !reflect.DeepEqual(ep.tailnetAddrs, nm.SelfNode.Addresses()) } // syncEgressConfigs adds and deletes firewall rules to match the desired // configuration. It uses the provided status to determine what is currently // applied and updates the status after a successful sync. -func (ep *egressProxy) syncEgressConfigs(cfgs *egressservices.Configs, status *egressservices.Status, n ipn.Notify) (*egressservices.Status, error) { +func (ep *egressProxy) syncEgressConfigs(cfgs egressservices.Configs, status *egressservices.Status, nm *netmap.NetworkMap) (*egressservices.Status, error) { if !(wantsServicesConfigured(cfgs) || hasServicesConfigured(status)) { return nil, nil } @@ -212,8 +213,8 @@ func (ep *egressProxy) syncEgressConfigs(cfgs *egressservices.Configs, status *e // Add new services, update rules for any that have changed. rulesPerSvcToAdd := make(map[string][]rule, 0) rulesPerSvcToDelete := make(map[string][]rule, 0) - for svcName, cfg := range *cfgs { - tailnetTargetIPs, err := ep.tailnetTargetIPsForSvc(cfg, n) + for svcName, cfg := range cfgs { + tailnetTargetIPs, err := ep.tailnetTargetIPsForSvc(cfg, nm) if err != nil { return nil, fmt.Errorf("error determining tailnet target IPs: %w", err) } @@ -228,12 +229,12 @@ func (ep *egressProxy) syncEgressConfigs(cfgs *egressservices.Configs, status *e if len(rulesToDelete) != 0 { mak.Set(&rulesPerSvcToDelete, svcName, rulesToDelete) } - if len(rulesToAdd) != 0 || ep.addrsHaveChanged(n) { + if len(rulesToAdd) != 0 || ep.addrsHaveChanged(nm) { // For each tailnet target, set up SNAT from the local tailnet device address of the matching // family. for _, t := range tailnetTargetIPs { var local netip.Addr - for _, pfx := range n.NetMap.SelfNode.Addresses().All() { + for _, pfx := range nm.SelfNode.Addresses().All() { if !pfx.IsSingleIP() { continue } @@ -352,7 +353,7 @@ func updatesForCfg(svcName string, cfg egressservices.Config, status *egressserv // deleteUnneccessaryServices ensure that any services found on status, but not // present in config are deleted. -func (ep *egressProxy) deleteUnnecessaryServices(cfgs *egressservices.Configs, status *egressservices.Status) error { +func (ep *egressProxy) deleteUnnecessaryServices(cfgs egressservices.Configs, status *egressservices.Status) error { if !hasServicesConfigured(status) { return nil } @@ -367,7 +368,7 @@ func (ep *egressProxy) deleteUnnecessaryServices(cfgs *egressservices.Configs, s } for svcName, svc := range status.Services { - if _, ok := (*cfgs)[svcName]; !ok { + if _, ok := cfgs[svcName]; !ok { log.Printf("service %s is no longer required, deleting", svcName) if err := ensureServiceDeleted(svcName, svc, ep.nfr); err != nil { return fmt.Errorf("error deleting service %s: %w", svcName, err) @@ -379,7 +380,7 @@ func (ep *egressProxy) deleteUnnecessaryServices(cfgs *egressservices.Configs, s } // getConfigs gets the mounted egress service configuration. -func (ep *egressProxy) getConfigs() (*egressservices.Configs, error) { +func (ep *egressProxy) getConfigs() (egressservices.Configs, error) { svcsCfg := filepath.Join(ep.cfgPath, egressservices.KeyEgressServices) j, err := os.ReadFile(svcsCfg) if os.IsNotExist(err) { @@ -391,7 +392,7 @@ func (ep *egressProxy) getConfigs() (*egressservices.Configs, error) { if len(j) == 0 || string(j) == "" { return nil, nil } - cfg := &egressservices.Configs{} + cfg := egressservices.Configs{} if err := json.Unmarshal(j, &cfg); err != nil { return nil, err } @@ -423,7 +424,7 @@ func (ep *egressProxy) getStatus(ctx context.Context) (*egressservices.Status, e // setStatus writes egress proxy's currently configured firewall to the state // Secret and updates proxy's tailnet addresses. -func (ep *egressProxy) setStatus(ctx context.Context, status *egressservices.Status, n ipn.Notify) error { +func (ep *egressProxy) setStatus(ctx context.Context, status *egressservices.Status, nm *netmap.NetworkMap) error { // Pod IP is used to determine if a stored status applies to THIS proxy Pod. if status == nil { status = &egressservices.Status{} @@ -446,7 +447,7 @@ func (ep *egressProxy) setStatus(ctx context.Context, status *egressservices.Sta if err := ep.kc.JSONPatchResource(ctx, ep.stateSecret, kubeclient.TypeSecrets, []kubeclient.JSONPatch{patch}); err != nil { return fmt.Errorf("error patching state Secret: %w", err) } - ep.tailnetAddrs = n.NetMap.SelfNode.Addresses().AsSlice() + ep.tailnetAddrs = nm.SelfNode.Addresses().AsSlice() return nil } @@ -456,7 +457,7 @@ func (ep *egressProxy) setStatus(ctx context.Context, status *egressservices.Sta // FQDN, resolve the FQDN and return the resolved IPs. It checks if the // netfilter runner supports IPv6 NAT and skips any IPv6 addresses if it // doesn't. -func (ep *egressProxy) tailnetTargetIPsForSvc(svc egressservices.Config, n ipn.Notify) (addrs []netip.Addr, err error) { +func (ep *egressProxy) tailnetTargetIPsForSvc(svc egressservices.Config, nm *netmap.NetworkMap) (addrs []netip.Addr, err error) { if svc.TailnetTarget.IP != "" { addr, err := netip.ParseAddr(svc.TailnetTarget.IP) if err != nil { @@ -472,11 +473,11 @@ func (ep *egressProxy) tailnetTargetIPsForSvc(svc egressservices.Config, n ipn.N if svc.TailnetTarget.FQDN == "" { return nil, errors.New("unexpected egress service config- neither tailnet target IP nor FQDN is set") } - if n.NetMap == nil { + if nm == nil { log.Printf("netmap is not available, unable to determine backend addresses for %s", svc.TailnetTarget.FQDN) return addrs, nil } - egressAddrs, err := resolveTailnetFQDN(n.NetMap, svc.TailnetTarget.FQDN) + egressAddrs, err := resolveTailnetFQDN(nm, svc.TailnetTarget.FQDN) if err != nil { log.Printf("error fetching backend addresses for %q: %v", svc.TailnetTarget.FQDN, err) return addrs, nil @@ -502,22 +503,22 @@ func (ep *egressProxy) tailnetTargetIPsForSvc(svc egressservices.Config, n ipn.N // shouldResync parses netmap update and returns true if the update contains // changes for which the egress proxy's firewall should be reconfigured. -func (ep *egressProxy) shouldResync(n ipn.Notify) bool { - if n.NetMap == nil { +func (ep *egressProxy) shouldResync(nm *netmap.NetworkMap) bool { + if nm == nil { return false } // If proxy's tailnet addresses have changed, resync. - if !reflect.DeepEqual(n.NetMap.SelfNode.Addresses().AsSlice(), ep.tailnetAddrs) { + if !reflect.DeepEqual(nm.SelfNode.Addresses().AsSlice(), ep.tailnetAddrs) { log.Printf("node addresses have changed, trigger egress config resync") - ep.tailnetAddrs = n.NetMap.SelfNode.Addresses().AsSlice() + ep.tailnetAddrs = nm.SelfNode.Addresses().AsSlice() return true } // If the IPs for any of the egress services configured via FQDN have // changed, resync. for fqdn, ips := range ep.targetFQDNs { - for _, nn := range n.NetMap.Peers { + for _, nn := range nm.Peers { if equalFQDNs(nn.Name(), fqdn) { if !reflect.DeepEqual(ips, nn.Addresses().AsSlice()) { log.Printf("backend addresses for egress target %q have changed old IPs %v, new IPs %v trigger egress config resync", nn.Name(), ips, nn.Addresses().AsSlice()) @@ -602,8 +603,8 @@ type rule struct { protocol string } -func wantsServicesConfigured(cfgs *egressservices.Configs) bool { - return cfgs != nil && len(*cfgs) != 0 +func wantsServicesConfigured(cfgs egressservices.Configs) bool { + return cfgs != nil && len(cfgs) != 0 } func hasServicesConfigured(status *egressservices.Status) bool { @@ -657,13 +658,13 @@ func (ep *egressProxy) ServeHTTP(w http.ResponseWriter, r *http.Request) { // would normally be this Pod. When this Pod is being deleted, the operator should have removed it from the Service // backends and eventually kube proxy routing rules should be updated to no longer route traffic for the Service to this // Pod. -func (ep *egressProxy) waitTillSafeToShutdown(ctx context.Context, cfgs *egressservices.Configs, hp int) { - if cfgs == nil || len(*cfgs) == 0 { // avoid sleeping if no services are configured +func (ep *egressProxy) waitTillSafeToShutdown(ctx context.Context, cfgs egressservices.Configs, hp int) { + if cfgs == nil || len(cfgs) == 0 { // avoid sleeping if no services are configured return } log.Printf("Ensuring that cluster traffic for egress targets is no longer routed via this Pod...") var wg sync.WaitGroup - for s, cfg := range *cfgs { + for s, cfg := range cfgs { hep := cfg.HealthCheckEndpoint if hep == "" { log.Printf("Tailnet target %q does not have a cluster healthcheck specified, unable to verify if cluster traffic for the target is still routed via this Pod", s) diff --git a/cmd/containerboot/egressservices_test.go b/cmd/containerboot/egressservices_test.go index 0d8504bda..b30765f19 100644 --- a/cmd/containerboot/egressservices_test.go +++ b/cmd/containerboot/egressservices_test.go @@ -255,13 +255,13 @@ func TestWaitTillSafeToShutdown(t *testing.T) { for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - cfgs := &egressservices.Configs{} + cfgs := egressservices.Configs{} switches := make(map[string]int) for svc, callsToSwitch := range tt.services { endpoint := fmt.Sprintf("http://%s.local", svc) if tt.healthCheckSet { - (*cfgs)[svc] = egressservices.Config{ + cfgs[svc] = egressservices.Config{ HealthCheckEndpoint: endpoint, } } diff --git a/cmd/containerboot/kube.go b/cmd/containerboot/kube.go index 73f5819b4..3e97710da 100644 --- a/cmd/containerboot/kube.go +++ b/cmd/containerboot/kube.go @@ -21,6 +21,7 @@ import ( "github.com/fsnotify/fsnotify" "tailscale.com/client/local" "tailscale.com/ipn" + "tailscale.com/kube/authkey" "tailscale.com/kube/egressservices" "tailscale.com/kube/ingressservices" "tailscale.com/kube/kubeapi" @@ -32,7 +33,6 @@ import ( ) const fieldManager = "tailscale-container" -const kubeletMountedConfigLn = "..data" // kubeClient is a wrapper around Tailscale's internal kube client that knows how to talk to the kube API server. We use // this rather than any of the upstream Kubernetes client libaries to avoid extra imports. @@ -127,6 +127,9 @@ func (kc *kubeClient) deleteAuthKey(ctx context.Context) error { // resetContainerbootState resets state from previous runs of containerboot to // ensure the operator doesn't use stale state when a Pod is first recreated. +// +// Device identity keys (device_id, device_fqdn, device_ips) are preserved so +// the operator can clean up the old device from the control plane. func (kc *kubeClient) resetContainerbootState(ctx context.Context, podUID string, tailscaledConfigAuthkey string) error { existingSecret, err := kc.GetSecret(ctx, kc.stateSecret) switch { @@ -139,12 +142,7 @@ func (kc *kubeClient) resetContainerbootState(ctx context.Context, podUID string s := &kubeapi.Secret{ Data: map[string][]byte{ - kubetypes.KeyCapVer: fmt.Appendf(nil, "%d", tailcfg.CurrentCapabilityVersion), - - // TODO(tomhjp): Perhaps shouldn't clear device ID and use a different signal, as this could leak tailnet devices. - kubetypes.KeyDeviceID: nil, - kubetypes.KeyDeviceFQDN: nil, - kubetypes.KeyDeviceIPs: nil, + kubetypes.KeyCapVer: fmt.Appendf(nil, "%d", tailcfg.CurrentCapabilityVersion), kubetypes.KeyHTTPSEndpoint: nil, egressservices.KeyEgressServices: nil, ingressservices.IngressConfigKey: nil, @@ -169,47 +167,18 @@ func (kc *kubeClient) setAndWaitForAuthKeyReissue(ctx context.Context, client *l return fmt.Errorf("error disconnecting from control: %w", err) } - err = kc.setReissueAuthKey(ctx, tailscaledConfigAuthKey) + err = authkey.SetReissueAuthKey(ctx, kc.Client, kc.stateSecret, tailscaledConfigAuthKey, authkey.TailscaleContainerFieldManager) if err != nil { return fmt.Errorf("failed to set reissue_authkey in Kubernetes Secret: %w", err) } - err = kc.waitForAuthKeyReissue(ctx, cfg.TailscaledConfigFilePath, tailscaledConfigAuthKey, 10*time.Minute) - if err != nil { - return fmt.Errorf("failed to receive new auth key: %w", err) + clearFn := func(ctx context.Context) error { + return authkey.ClearReissueAuthKey(ctx, kc.Client, kc.stateSecret, authkey.TailscaleContainerFieldManager) } - return nil -} - -func (kc *kubeClient) setReissueAuthKey(ctx context.Context, authKey string) error { - s := &kubeapi.Secret{ - Data: map[string][]byte{ - kubetypes.KeyReissueAuthkey: []byte(authKey), - }, - } - - log.Printf("Requesting a new auth key from operator") - return kc.StrategicMergePatchSecret(ctx, kc.stateSecret, s, fieldManager) -} - -func (kc *kubeClient) waitForAuthKeyReissue(ctx context.Context, configPath string, oldAuthKey string, maxWait time.Duration) error { - log.Printf("Waiting for operator to provide new auth key (max wait: %v)", maxWait) - - ctx, cancel := context.WithTimeout(ctx, maxWait) - defer cancel() - - tailscaledCfgDir := filepath.Dir(configPath) - toWatch := filepath.Join(tailscaledCfgDir, kubeletMountedConfigLn) - - var ( - pollTicker <-chan time.Time - eventChan <-chan fsnotify.Event - ) - - pollInterval := 5 * time.Second - - // Try to use fsnotify for faster notification + getAuthKey := func() string { return authkey.AuthKeyFromConfig(cfg.TailscaledConfigFilePath) } + tailscaledCfgDir := filepath.Dir(cfg.TailscaledConfigFilePath) + var notify <-chan struct{} if w, err := fsnotify.NewWatcher(); err != nil { log.Printf("auth key reissue: fsnotify unavailable, using polling: %v", err) } else if err := w.Add(tailscaledCfgDir); err != nil { @@ -217,54 +186,28 @@ func (kc *kubeClient) waitForAuthKeyReissue(ctx context.Context, configPath stri log.Printf("auth key reissue: fsnotify watch failed, using polling: %v", err) } else { defer w.Close() + ch := make(chan struct{}, 1) + toWatch := filepath.Join(tailscaledCfgDir, "..data") + go func() { + for ev := range w.Events { + if ev.Name == toWatch { + select { + case ch <- struct{}{}: + default: + } + } + } + }() + notify = ch log.Printf("auth key reissue: watching for config changes via fsnotify") - eventChan = w.Events } - // still keep polling if using fsnotify, for logging and in case fsnotify fails - pt := time.NewTicker(pollInterval) - defer pt.Stop() - pollTicker = pt.C - - start := time.Now() - - for { - select { - case <-ctx.Done(): - return fmt.Errorf("timeout waiting for auth key reissue after %v", maxWait) - case <-pollTicker: // Waits for polling tick, continues when received - case event := <-eventChan: - if event.Name != toWatch { - continue - } - } - - newAuthKey := authkeyFromTailscaledConfig(configPath) - if newAuthKey != "" && newAuthKey != oldAuthKey { - log.Printf("New auth key received from operator after %v", time.Since(start).Round(time.Second)) - - if err := kc.clearReissueAuthKeyRequest(ctx); err != nil { - log.Printf("Warning: failed to clear reissue request: %v", err) - } - - return nil - } - - if eventChan == nil && pollTicker != nil { - log.Printf("Waiting for new auth key from operator (%v elapsed)", time.Since(start).Round(time.Second)) - } + err = authkey.WaitForAuthKeyReissue(ctx, tailscaledConfigAuthKey, 10*time.Minute, getAuthKey, clearFn, notify) + if err != nil { + return fmt.Errorf("failed to receive new auth key: %w", err) } -} -// clearReissueAuthKeyRequest removes the reissue_authkey marker from the Secret -// to signal to the operator that we've successfully received the new key. -func (kc *kubeClient) clearReissueAuthKeyRequest(ctx context.Context) error { - s := &kubeapi.Secret{ - Data: map[string][]byte{ - kubetypes.KeyReissueAuthkey: nil, - }, - } - return kc.StrategicMergePatchSecret(ctx, kc.stateSecret, s, fieldManager) + return nil } // waitForConsistentState waits for tailscaled to finish writing state if it diff --git a/cmd/containerboot/kube_test.go b/cmd/containerboot/kube_test.go index b2e89a36c..fec0b74f7 100644 --- a/cmd/containerboot/kube_test.go +++ b/cmd/containerboot/kube_test.go @@ -257,12 +257,8 @@ func TestResetContainerbootState(t *testing.T) { authkey: "new-authkey", initial: map[string][]byte{}, expected: map[string][]byte{ - kubetypes.KeyCapVer: capver, - kubetypes.KeyPodUID: []byte("1234"), - // Cleared keys. - kubetypes.KeyDeviceID: nil, - kubetypes.KeyDeviceFQDN: nil, - kubetypes.KeyDeviceIPs: nil, + kubetypes.KeyCapVer: capver, + kubetypes.KeyPodUID: []byte("1234"), kubetypes.KeyHTTPSEndpoint: nil, egressservices.KeyEgressServices: nil, ingressservices.IngressConfigKey: nil, @@ -271,11 +267,7 @@ func TestResetContainerbootState(t *testing.T) { "empty_initial_no_pod_uid": { initial: map[string][]byte{}, expected: map[string][]byte{ - kubetypes.KeyCapVer: capver, - // Cleared keys. - kubetypes.KeyDeviceID: nil, - kubetypes.KeyDeviceFQDN: nil, - kubetypes.KeyDeviceIPs: nil, + kubetypes.KeyCapVer: capver, kubetypes.KeyHTTPSEndpoint: nil, egressservices.KeyEgressServices: nil, ingressservices.IngressConfigKey: nil, @@ -303,9 +295,6 @@ func TestResetContainerbootState(t *testing.T) { kubetypes.KeyCapVer: capver, kubetypes.KeyPodUID: []byte("1234"), // Cleared keys. - kubetypes.KeyDeviceID: nil, - kubetypes.KeyDeviceFQDN: nil, - kubetypes.KeyDeviceIPs: nil, kubetypes.KeyHTTPSEndpoint: nil, egressservices.KeyEgressServices: nil, ingressservices.IngressConfigKey: nil, @@ -321,9 +310,6 @@ func TestResetContainerbootState(t *testing.T) { kubetypes.KeyCapVer: capver, kubetypes.KeyReissueAuthkey: nil, // Cleared keys. - kubetypes.KeyDeviceID: nil, - kubetypes.KeyDeviceFQDN: nil, - kubetypes.KeyDeviceIPs: nil, kubetypes.KeyHTTPSEndpoint: nil, egressservices.KeyEgressServices: nil, ingressservices.IngressConfigKey: nil, @@ -338,9 +324,6 @@ func TestResetContainerbootState(t *testing.T) { kubetypes.KeyCapVer: capver, // reissue_authkey not cleared. // Cleared keys. - kubetypes.KeyDeviceID: nil, - kubetypes.KeyDeviceFQDN: nil, - kubetypes.KeyDeviceIPs: nil, kubetypes.KeyHTTPSEndpoint: nil, egressservices.KeyEgressServices: nil, ingressservices.IngressConfigKey: nil, @@ -355,9 +338,6 @@ func TestResetContainerbootState(t *testing.T) { kubetypes.KeyCapVer: capver, // reissue_authkey not cleared. // Cleared keys. - kubetypes.KeyDeviceID: nil, - kubetypes.KeyDeviceFQDN: nil, - kubetypes.KeyDeviceIPs: nil, kubetypes.KeyHTTPSEndpoint: nil, egressservices.KeyEgressServices: nil, ingressservices.IngressConfigKey: nil, diff --git a/cmd/containerboot/main.go b/cmd/containerboot/main.go index e80192a31..1a11c3150 100644 --- a/cmd/containerboot/main.go +++ b/cmd/containerboot/main.go @@ -137,10 +137,11 @@ import ( "golang.org/x/sys/unix" + "tailscale.com/client/local" "tailscale.com/health" "tailscale.com/ipn" - "tailscale.com/ipn/conffile" kubeutils "tailscale.com/k8s-operator" + "tailscale.com/kube/authkey" healthz "tailscale.com/kube/health" "tailscale.com/kube/kubetypes" klc "tailscale.com/kube/localclient" @@ -209,7 +210,7 @@ func run() error { var tailscaledConfigAuthkey string if isOneStepConfig(cfg) { - tailscaledConfigAuthkey = authkeyFromTailscaledConfig(cfg.TailscaledConfigFilePath) + tailscaledConfigAuthkey = authkey.AuthKeyFromConfig(cfg.TailscaledConfigFilePath) } var kc *kubeClient @@ -374,7 +375,7 @@ authLoop: if hasKubeStateStore(cfg) { log.Printf("Auth key missing or invalid (NeedsLogin state), disconnecting from control and requesting new key from operator") - err := kc.setAndWaitForAuthKeyReissue(bootCtx, client, cfg, tailscaledConfigAuthkey) + err := kc.setAndWaitForAuthKeyReissue(ctx, client, cfg, tailscaledConfigAuthkey) if err != nil { return fmt.Errorf("failed to get a reissued authkey: %w", err) } @@ -414,7 +415,7 @@ authLoop: if isOneStepConfig(cfg) && hasKubeStateStore(cfg) { log.Printf("Auth key failed to authenticate (may be expired or single-use), disconnecting from control and requesting new key from operator") - err := kc.setAndWaitForAuthKeyReissue(bootCtx, client, cfg, tailscaledConfigAuthkey) + err := kc.setAndWaitForAuthKeyReissue(ctx, client, cfg, tailscaledConfigAuthkey) if err != nil { return fmt.Errorf("failed to get a reissued authkey: %w", err) } @@ -536,7 +537,7 @@ authLoop: failedResolveAttempts++ } - var egressSvcsNotify chan ipn.Notify + var egressSvcsNotify chan *netmap.NetworkMap notifyChan := make(chan ipn.Notify) errChan := make(chan error) go func() { @@ -550,10 +551,17 @@ authLoop: } } }() + // Peer set changes (Add/Remove) no longer ride on the IPN bus; poll + // periodically so egress FQDN resolution and peer-aware work picks + // them up. SelfChange covers prompt self changes. + const peerPollInterval = 15 * time.Second + peerPoll := time.NewTicker(peerPollInterval) + defer peerPoll.Stop() var wg sync.WaitGroup runLoop: for { + var processNetmap bool select { case <-ctx.Done(): // Although killTailscaled() is deferred earlier, if we @@ -566,6 +574,8 @@ runLoop: return fmt.Errorf("failed to read from tailscaled: %w", err) case err := <-cfgWatchErrChan: return fmt.Errorf("failed to watch tailscaled config: %w", err) + case <-peerPoll.C: + processNetmap = true case n := <-notifyChan: // TODO: (ChaosInTheCRD) Add node removed check when supported by ipn if n.State != nil && *n.State != ipn.Running { @@ -576,235 +586,8 @@ runLoop: // whereupon we'll go through initial auth again. return fmt.Errorf("tailscaled left running state (now in state %q), exiting", *n.State) } - if n.NetMap != nil { - addrs = n.NetMap.SelfNode.Addresses().AsSlice() - newCurrentIPs := deephash.Hash(&addrs) - ipsHaveChanged := newCurrentIPs != currentIPs - - // Store device ID in a Kubernetes Secret before - // setting up any routing rules. This ensures - // that, for containerboot instances that are - // Kubernetes operator proxies, the operator is - // able to retrieve the device ID from the - // Kubernetes Secret to clean up tailnet nodes - // for proxies whose route setup continuously - // fails. - deviceID := n.NetMap.SelfNode.StableID() - if hasKubeStateStore(cfg) && deephash.Update(¤tDeviceID, &deviceID) { - if err := kc.storeDeviceID(ctx, n.NetMap.SelfNode.StableID()); err != nil { - return fmt.Errorf("storing device ID in Kubernetes Secret: %w", err) - } - } - if cfg.TailnetTargetFQDN != "" { - egressAddrs, err := resolveTailnetFQDN(n.NetMap, cfg.TailnetTargetFQDN) - if err != nil { - log.Print(err.Error()) - break - } - - newCurentEgressIPs := deephash.Hash(&egressAddrs) - egressIPsHaveChanged := newCurentEgressIPs != currentEgressIPs - // The firewall rules get (re-)installed: - // - on startup - // - when the tailnet IPs of the tailnet target have changed - // - when the tailnet IPs of this node have changed - if (egressIPsHaveChanged || ipsHaveChanged) && len(egressAddrs) != 0 { - var rulesInstalled bool - for _, egressAddr := range egressAddrs { - ea := egressAddr.Addr() - if ea.Is4() || (ea.Is6() && nfr.HasIPV6NAT()) { - rulesInstalled = true - log.Printf("Installing forwarding rules for destination %v", ea.String()) - if err := installEgressForwardingRule(ctx, ea.String(), addrs, nfr); err != nil { - return fmt.Errorf("installing egress proxy rules for destination %s: %v", ea.String(), err) - } - } - } - if !rulesInstalled { - return fmt.Errorf("no forwarding rules for egress addresses %v, host supports IPv6: %v", egressAddrs, nfr.HasIPV6NAT()) - } - } - currentEgressIPs = newCurentEgressIPs - } - if cfg.ProxyTargetIP != "" && len(addrs) != 0 && ipsHaveChanged { - log.Printf("Installing proxy rules") - if err := installIngressForwardingRule(ctx, cfg.ProxyTargetIP, addrs, nfr); err != nil { - return fmt.Errorf("installing ingress proxy rules: %w", err) - } - } - if cfg.ProxyTargetDNSName != "" && len(addrs) != 0 && ipsHaveChanged { - newBackendAddrs, err := resolveDNS(ctx, cfg.ProxyTargetDNSName) - if err != nil { - log.Printf("[unexpected] error resolving DNS name %s: %v", cfg.ProxyTargetDNSName, err) - resetTimer(true) - continue - } - backendsHaveChanged := !(slices.EqualFunc(backendAddrs, newBackendAddrs, func(ip1 net.IP, ip2 net.IP) bool { - return slices.ContainsFunc(newBackendAddrs, func(ip net.IP) bool { return ip.Equal(ip1) }) - })) - if backendsHaveChanged { - log.Printf("installing ingress proxy rules for backends %v", newBackendAddrs) - if err := installIngressForwardingRuleForDNSTarget(ctx, newBackendAddrs, addrs, nfr); err != nil { - return fmt.Errorf("error installing ingress proxy rules: %w", err) - } - } - resetTimer(false) - backendAddrs = newBackendAddrs - } - if cfg.ServeConfigPath != "" { - cd := certDomainFromNetmap(n.NetMap) - if cd == "" { - cd = kubetypes.ValueNoHTTPS - } - prev := certDomain.Swap(new(cd)) - if prev == nil || *prev != cd { - select { - case certDomainChanged <- true: - default: - } - } - } - if cfg.TailnetTargetIP != "" && ipsHaveChanged && len(addrs) != 0 { - log.Printf("Installing forwarding rules for destination %v", cfg.TailnetTargetIP) - if err := installEgressForwardingRule(ctx, cfg.TailnetTargetIP, addrs, nfr); err != nil { - return fmt.Errorf("installing egress proxy rules: %w", err) - } - } - // If this is a L7 cluster ingress proxy (set up - // by Kubernetes operator) and proxying of - // cluster traffic to the ingress target is - // enabled, set up proxy rule each time the - // tailnet IPs of this node change (including - // the first time they become available). - if cfg.AllowProxyingClusterTrafficViaIngress && cfg.ServeConfigPath != "" && ipsHaveChanged && len(addrs) != 0 { - log.Printf("installing rules to forward traffic for %s to node's tailnet IP", cfg.PodIP) - if err := installTSForwardingRuleForDestination(ctx, cfg.PodIP, addrs, nfr); err != nil { - return fmt.Errorf("installing rules to forward traffic to node's tailnet IP: %w", err) - } - } - currentIPs = newCurrentIPs - - // Only store device FQDN and IP addresses to - // Kubernetes Secret when any required proxy - // route setup has succeeded. IPs and FQDN are - // read from the Secret by the Tailscale - // Kubernetes operator and, for some proxy - // types, such as Tailscale Ingress, advertized - // on the Ingress status. Writing them to the - // Secret only after the proxy routing has been - // set up ensures that the operator does not - // advertize endpoints of broken proxies. - // TODO (irbekrm): instead of using the IP and FQDN, have some other mechanism for the proxy signal that it is 'Ready'. - deviceEndpoints := []any{n.NetMap.SelfNode.Name(), n.NetMap.SelfNode.Addresses()} - if hasKubeStateStore(cfg) && deephash.Update(¤tDeviceEndpoints, &deviceEndpoints) { - if err := kc.storeDeviceEndpoints(ctx, n.NetMap.SelfNode.Name(), n.NetMap.SelfNode.Addresses().AsSlice()); err != nil { - return fmt.Errorf("storing device IPs and FQDN in Kubernetes Secret: %w", err) - } - } - - if healthCheck != nil { - healthCheck.Update(len(addrs) != 0) - } - - var prevServeConfig *ipn.ServeConfig - if getAutoAdvertiseBool() { - prevServeConfig, err = client.GetServeConfig(ctx) - if err != nil { - return fmt.Errorf("autoadvertisement: failed to get serve config: %w", err) - } - - err = refreshAdvertiseServices(ctx, prevServeConfig, klc.New(client)) - if err != nil { - return fmt.Errorf("autoadvertisement: failed to refresh advertise services: %w", err) - } - } - - if cfg.ServeConfigPath != "" { - triggerWatchServeConfigChanges.Do(func() { - go watchServeConfigChanges(ctx, certDomainChanged, certDomain, client, kc, cfg, prevServeConfig) - }) - } - - if egressSvcsNotify != nil { - egressSvcsNotify <- n - } - } - if !startupTasksDone { - // For containerboot instances that act as TCP proxies (proxying traffic to an endpoint - // passed via one of the env vars that containerboot reads) and store state in a - // Kubernetes Secret, we consider startup tasks done at the point when device info has - // been successfully stored to state Secret. For all other containerboot instances, if - // we just get to this point the startup tasks can be considered done. - if !isL3Proxy(cfg) || !hasKubeStateStore(cfg) || (currentDeviceEndpoints != deephash.Sum{} && currentDeviceID != deephash.Sum{}) { - // This log message is used in tests to detect when all - // post-auth configuration is done. - log.Println("Startup complete, waiting for shutdown signal") - startupTasksDone = true - - // Configure egress proxy. Egress proxy will set up firewall rules to proxy - // traffic to tailnet targets configured in the provided configuration file. It - // will then continuously monitor the config file and netmap updates and - // reconfigure the firewall rules as needed. If any of its operations fail, it - // will crash this node. - if cfg.EgressProxiesCfgPath != "" { - log.Printf("configuring egress proxy using configuration file at %s", cfg.EgressProxiesCfgPath) - egressSvcsNotify = make(chan ipn.Notify) - opts := egressProxyRunOpts{ - cfgPath: cfg.EgressProxiesCfgPath, - nfr: nfr, - kc: kc, - tsClient: client, - stateSecret: cfg.KubeSecret, - netmapChan: egressSvcsNotify, - podIPv4: cfg.PodIPv4, - tailnetAddrs: addrs, - } - go func() { - if err := ep.run(ctx, n, opts); err != nil { - egressSvcsErrorChan <- err - } - }() - } - ip := ingressProxy{} - if cfg.IngressProxiesCfgPath != "" { - log.Printf("configuring ingress proxy using configuration file at %s", cfg.IngressProxiesCfgPath) - opts := ingressProxyOpts{ - cfgPath: cfg.IngressProxiesCfgPath, - nfr: nfr, - kc: kc, - stateSecret: cfg.KubeSecret, - podIPv4: cfg.PodIPv4, - podIPv6: cfg.PodIPv6, - } - go func() { - if err := ip.run(ctx, opts); err != nil { - ingressSvcsErrorChan <- err - } - }() - } - - // Wait on tailscaled process. It won't be cleaned up by default when the - // container exits as it is not PID1. TODO (irbekrm): perhaps we can replace the - // reaper by a running cmd.Wait in a goroutine immediately after starting - // tailscaled? - reaper := func() { - defer wg.Done() - for { - var status unix.WaitStatus - _, err := unix.Wait4(daemonProcess.Pid, &status, 0, nil) - if errors.Is(err, unix.EINTR) { - continue - } - if err != nil { - log.Fatalf("Waiting for tailscaled to exit: %v", err) - } - log.Print("tailscaled exited") - os.Exit(0) - } - } - wg.Add(1) - go reaper() - } + if n.SelfChange != nil { + processNetmap = true } case <-tc: newBackendAddrs, err := resolveDNS(ctx, cfg.ProxyTargetDNSName) @@ -824,11 +607,250 @@ runLoop: } backendAddrs = newBackendAddrs resetTimer(false) + continue case e := <-egressSvcsErrorChan: return fmt.Errorf("egress proxy failed: %v", e) case e := <-ingressSvcsErrorChan: return fmt.Errorf("ingress proxy failed: %v", e) } + if !processNetmap { + continue + } + nm, err := fetchNetMap(ctx, client) + if err != nil { + log.Printf("error fetching netmap: %v", err) + continue + } + if nm != nil { + addrs = nm.SelfNode.Addresses().AsSlice() + newCurrentIPs := deephash.Hash(&addrs) + ipsHaveChanged := newCurrentIPs != currentIPs + + // Store device ID in a Kubernetes Secret before + // setting up any routing rules. This ensures + // that, for containerboot instances that are + // Kubernetes operator proxies, the operator is + // able to retrieve the device ID from the + // Kubernetes Secret to clean up tailnet nodes + // for proxies whose route setup continuously + // fails. + deviceID := nm.SelfNode.StableID() + if hasKubeStateStore(cfg) && deephash.Update(¤tDeviceID, &deviceID) { + if err := kc.storeDeviceID(ctx, nm.SelfNode.StableID()); err != nil { + return fmt.Errorf("storing device ID in Kubernetes Secret: %w", err) + } + } + if cfg.TailnetTargetFQDN != "" { + egressAddrs, err := resolveTailnetFQDN(nm, cfg.TailnetTargetFQDN) + if err != nil { + log.Print(err.Error()) + break + } + + newCurentEgressIPs := deephash.Hash(&egressAddrs) + egressIPsHaveChanged := newCurentEgressIPs != currentEgressIPs + // The firewall rules get (re-)installed: + // - on startup + // - when the tailnet IPs of the tailnet target have changed + // - when the tailnet IPs of this node have changed + if (egressIPsHaveChanged || ipsHaveChanged) && len(egressAddrs) != 0 { + var rulesInstalled bool + for _, egressAddr := range egressAddrs { + ea := egressAddr.Addr() + if ea.Is4() || (ea.Is6() && nfr.HasIPV6NAT()) { + rulesInstalled = true + log.Printf("Installing forwarding rules for destination %v", ea.String()) + if err := installEgressForwardingRule(ctx, ea.String(), addrs, nfr); err != nil { + return fmt.Errorf("installing egress proxy rules for destination %s: %v", ea.String(), err) + } + } + } + if !rulesInstalled { + return fmt.Errorf("no forwarding rules for egress addresses %v, host supports IPv6: %v", egressAddrs, nfr.HasIPV6NAT()) + } + } + currentEgressIPs = newCurentEgressIPs + } + if cfg.ProxyTargetIP != "" && len(addrs) != 0 && ipsHaveChanged { + log.Printf("Installing proxy rules") + if err := installIngressForwardingRule(ctx, cfg.ProxyTargetIP, addrs, nfr); err != nil { + return fmt.Errorf("installing ingress proxy rules: %w", err) + } + } + if cfg.ProxyTargetDNSName != "" && len(addrs) != 0 && ipsHaveChanged { + newBackendAddrs, err := resolveDNS(ctx, cfg.ProxyTargetDNSName) + if err != nil { + log.Printf("[unexpected] error resolving DNS name %s: %v", cfg.ProxyTargetDNSName, err) + resetTimer(true) + continue + } + backendsHaveChanged := !(slices.EqualFunc(backendAddrs, newBackendAddrs, func(ip1 net.IP, ip2 net.IP) bool { + return slices.ContainsFunc(newBackendAddrs, func(ip net.IP) bool { return ip.Equal(ip1) }) + })) + if backendsHaveChanged { + log.Printf("installing ingress proxy rules for backends %v", newBackendAddrs) + if err := installIngressForwardingRuleForDNSTarget(ctx, newBackendAddrs, addrs, nfr); err != nil { + return fmt.Errorf("error installing ingress proxy rules: %w", err) + } + } + resetTimer(false) + backendAddrs = newBackendAddrs + } + if cfg.ServeConfigPath != "" { + cd := certDomainFromNetmap(nm) + if cd == "" { + cd = kubetypes.ValueNoHTTPS + } + prev := certDomain.Swap(new(cd)) + if prev == nil || *prev != cd { + select { + case certDomainChanged <- true: + default: + } + } + } + if cfg.TailnetTargetIP != "" && ipsHaveChanged && len(addrs) != 0 { + log.Printf("Installing forwarding rules for destination %v", cfg.TailnetTargetIP) + if err := installEgressForwardingRule(ctx, cfg.TailnetTargetIP, addrs, nfr); err != nil { + return fmt.Errorf("installing egress proxy rules: %w", err) + } + } + // If this is a L7 cluster ingress proxy (set up + // by Kubernetes operator) and proxying of + // cluster traffic to the ingress target is + // enabled, set up proxy rule each time the + // tailnet IPs of this node change (including + // the first time they become available). + if cfg.AllowProxyingClusterTrafficViaIngress && cfg.ServeConfigPath != "" && ipsHaveChanged && len(addrs) != 0 { + log.Printf("installing rules to forward traffic for %s to node's tailnet IP", cfg.PodIP) + if err := installTSForwardingRuleForDestination(ctx, cfg.PodIP, addrs, nfr); err != nil { + return fmt.Errorf("installing rules to forward traffic to node's tailnet IP: %w", err) + } + } + currentIPs = newCurrentIPs + + // Only store device FQDN and IP addresses to + // Kubernetes Secret when any required proxy + // route setup has succeeded. IPs and FQDN are + // read from the Secret by the Tailscale + // Kubernetes operator and, for some proxy + // types, such as Tailscale Ingress, advertized + // on the Ingress status. Writing them to the + // Secret only after the proxy routing has been + // set up ensures that the operator does not + // advertize endpoints of broken proxies. + // TODO (irbekrm): instead of using the IP and FQDN, have some other mechanism for the proxy signal that it is 'Ready'. + deviceEndpoints := []any{nm.SelfNode.Name(), nm.SelfNode.Addresses()} + if hasKubeStateStore(cfg) && deephash.Update(¤tDeviceEndpoints, &deviceEndpoints) { + if err := kc.storeDeviceEndpoints(ctx, nm.SelfNode.Name(), nm.SelfNode.Addresses().AsSlice()); err != nil { + return fmt.Errorf("storing device IPs and FQDN in Kubernetes Secret: %w", err) + } + } + + if healthCheck != nil { + healthCheck.Update(len(addrs) != 0) + } + + var prevServeConfig *ipn.ServeConfig + if getAutoAdvertiseBool() { + prevServeConfig, err = client.GetServeConfig(ctx) + if err != nil { + return fmt.Errorf("autoadvertisement: failed to get serve config: %w", err) + } + + err = refreshAdvertiseServices(ctx, prevServeConfig, klc.New(client)) + if err != nil { + return fmt.Errorf("autoadvertisement: failed to refresh advertise services: %w", err) + } + } + + if cfg.ServeConfigPath != "" { + triggerWatchServeConfigChanges.Do(func() { + go watchServeConfigChanges(ctx, certDomainChanged, certDomain, client, kc, cfg, prevServeConfig) + }) + } + + if egressSvcsNotify != nil { + egressSvcsNotify <- nm + } + } + if !startupTasksDone { + // For containerboot instances that act as TCP proxies (proxying traffic to an endpoint + // passed via one of the env vars that containerboot reads) and store state in a + // Kubernetes Secret, we consider startup tasks done at the point when device info has + // been successfully stored to state Secret. For all other containerboot instances, if + // we just get to this point the startup tasks can be considered done. + if !isL3Proxy(cfg) || !hasKubeStateStore(cfg) || (currentDeviceEndpoints != deephash.Sum{} && currentDeviceID != deephash.Sum{}) { + // This log message is used in tests to detect when all + // post-auth configuration is done. + log.Println("Startup complete, waiting for shutdown signal") + startupTasksDone = true + + // Configure egress proxy. Egress proxy will set up firewall rules to proxy + // traffic to tailnet targets configured in the provided configuration file. It + // will then continuously monitor the config file and netmap updates and + // reconfigure the firewall rules as needed. If any of its operations fail, it + // will crash this node. + if cfg.EgressProxiesCfgPath != "" { + log.Printf("configuring egress proxy using configuration file at %s", cfg.EgressProxiesCfgPath) + egressSvcsNotify = make(chan *netmap.NetworkMap) + opts := egressProxyRunOpts{ + cfgPath: cfg.EgressProxiesCfgPath, + nfr: nfr, + kc: kc, + tsClient: client, + stateSecret: cfg.KubeSecret, + netmapChan: egressSvcsNotify, + podIPv4: cfg.PodIPv4, + tailnetAddrs: addrs, + } + go func() { + if err := ep.run(ctx, nm, opts); err != nil { + egressSvcsErrorChan <- err + } + }() + } + ip := ingressProxy{} + if cfg.IngressProxiesCfgPath != "" { + log.Printf("configuring ingress proxy using configuration file at %s", cfg.IngressProxiesCfgPath) + opts := ingressProxyOpts{ + cfgPath: cfg.IngressProxiesCfgPath, + nfr: nfr, + kc: kc, + stateSecret: cfg.KubeSecret, + podIPv4: cfg.PodIPv4, + podIPv6: cfg.PodIPv6, + } + go func() { + if err := ip.run(ctx, opts); err != nil { + ingressSvcsErrorChan <- err + } + }() + } + + // Wait on tailscaled process. It won't be cleaned up by default when the + // container exits as it is not PID1. TODO (irbekrm): perhaps we can replace the + // reaper by a running cmd.Wait in a goroutine immediately after starting + // tailscaled? + reaper := func() { + defer wg.Done() + for { + var status unix.WaitStatus + _, err := unix.Wait4(daemonProcess.Pid, &status, 0, nil) + if errors.Is(err, unix.EINTR) { + continue + } + if err != nil { + log.Fatalf("Waiting for tailscaled to exit: %v", err) + } + log.Print("tailscaled exited") + os.Exit(0) + } + } + wg.Add(1) + go reaper() + } + } } wg.Wait() @@ -963,6 +985,15 @@ func runHTTPServer(mux *http.ServeMux, addr string) (close func() error) { } } +// fetchNetMap fetches the current netmap from tailscaled via the +// "current-netmap" localapi debug action. The debug action's payload +// shape is intentionally not part of any stable API; containerboot +// reads its own internal-package types out of it. New external consumers +// should not rely on this — see [local.Client.Status] and friends. +func fetchNetMap(ctx context.Context, lc *local.Client) (*netmap.NetworkMap, error) { + return local.GetDebugResultJSON[*netmap.NetworkMap](ctx, lc, "current-netmap") +} + // resolveTailnetFQDN resolves a tailnet FQDN to a list of IP prefixes, which // can be either a peer device or a Tailscale Service. func resolveTailnetFQDN(nm *netmap.NetworkMap, fqdn string) ([]netip.Prefix, error) { @@ -1024,11 +1055,3 @@ func serviceIPsFromNetMap(nm *netmap.NetworkMap, fqdn dnsname.FQDN) []netip.Pref return prefixes } - -func authkeyFromTailscaledConfig(path string) string { - if cfg, err := conffile.Load(path); err == nil && cfg.Parsed.AuthKey != nil { - return *cfg.Parsed.AuthKey - } - - return "" -} diff --git a/cmd/containerboot/main_test.go b/cmd/containerboot/main_test.go index 5ea402f66..40f575250 100644 --- a/cmd/containerboot/main_test.go +++ b/cmd/containerboot/main_test.go @@ -32,6 +32,7 @@ import ( "github.com/google/go-cmp/cmp" "golang.org/x/sys/unix" + "tailscale.com/cmd/testwrapper/flakytest" "tailscale.com/health" "tailscale.com/ipn" "tailscale.com/kube/egressservices" @@ -45,6 +46,7 @@ import ( const configFileAuthKey = "some-auth-key" func TestContainerBoot(t *testing.T) { + flakytest.Mark(t, "https://github.com/tailscale/tailscale/issues/19380") boot := filepath.Join(t.TempDir(), "containerboot") if err := exec.Command("go", "build", "-ldflags", "-X main.testSleepDuration=1ms", "-o", boot, "tailscale.com/cmd/containerboot").Run(); err != nil { t.Fatalf("Building containerboot: %v", err) @@ -69,6 +71,12 @@ func TestContainerBoot(t *testing.T) { // Waits below to be true before proceeding to the next phase. Notify *ipn.Notify + // If non-nil, install this NetMap on the fake LocalAPI before + // sending Notify. This is the replacement for the old + // Notify.NetMap field; reactive consumers fetch the current + // netmap via /localapi/v0/netmap on their own. + NetMap *netmap.NetworkMap + // WantCmds is the commands that containerboot should run in this phase. WantCmds []string @@ -103,12 +111,10 @@ func TestContainerBoot(t *testing.T) { } runningNotify := &ipn.Notify{ State: new(ipn.Running), - NetMap: &netmap.NetworkMap{ - SelfNode: (&tailcfg.Node{ - StableID: tailcfg.StableNodeID("myID"), - Name: "test-node.test.ts.net.", - Addresses: []netip.Prefix{netip.MustParsePrefix("100.64.0.1/32")}, - }).View(), + SelfChange: &tailcfg.Node{ + StableID: tailcfg.StableNodeID("myID"), + Name: "test-node.test.ts.net.", + Addresses: []netip.Prefix{netip.MustParsePrefix("100.64.0.1/32")}, }, } type testCase struct { @@ -381,19 +387,24 @@ func TestContainerBoot(t *testing.T) { { Notify: &ipn.Notify{ State: new(ipn.Running), - NetMap: &netmap.NetworkMap{ - SelfNode: (&tailcfg.Node{ - StableID: tailcfg.StableNodeID("myID"), - Name: "test-node.test.ts.net.", - Addresses: []netip.Prefix{netip.MustParsePrefix("100.64.0.1/32")}, + SelfChange: &tailcfg.Node{ + StableID: tailcfg.StableNodeID("myID"), + Name: "test-node.test.ts.net.", + Addresses: []netip.Prefix{netip.MustParsePrefix("100.64.0.1/32")}, + }, + }, + NetMap: &netmap.NetworkMap{ + SelfNode: (&tailcfg.Node{ + StableID: tailcfg.StableNodeID("myID"), + Name: "test-node.test.ts.net.", + Addresses: []netip.Prefix{netip.MustParsePrefix("100.64.0.1/32")}, + }).View(), + Peers: []tailcfg.NodeView{ + (&tailcfg.Node{ + StableID: tailcfg.StableNodeID("ipv6ID"), + Name: "ipv6-node.test.ts.net.", + Addresses: []netip.Prefix{netip.MustParsePrefix("::1/128")}, }).View(), - Peers: []tailcfg.NodeView{ - (&tailcfg.Node{ - StableID: tailcfg.StableNodeID("ipv6ID"), - Name: "ipv6-node.test.ts.net.", - Addresses: []netip.Prefix{netip.MustParsePrefix("::1/128")}, - }).View(), - }, }, }, WantLog: "no forwarding rules for egress addresses [::1/128], host supports IPv6: false", @@ -629,14 +640,19 @@ func TestContainerBoot(t *testing.T) { { Notify: &ipn.Notify{ State: new(ipn.Running), - NetMap: &netmap.NetworkMap{ - SelfNode: (&tailcfg.Node{ - StableID: tailcfg.StableNodeID("newID"), - Name: "new-name.test.ts.net.", - Addresses: []netip.Prefix{netip.MustParsePrefix("100.64.0.1/32")}, - }).View(), + SelfChange: &tailcfg.Node{ + StableID: tailcfg.StableNodeID("newID"), + Name: "new-name.test.ts.net.", + Addresses: []netip.Prefix{netip.MustParsePrefix("100.64.0.1/32")}, }, }, + NetMap: &netmap.NetworkMap{ + SelfNode: (&tailcfg.Node{ + StableID: tailcfg.StableNodeID("newID"), + Name: "new-name.test.ts.net.", + Addresses: []netip.Prefix{netip.MustParsePrefix("100.64.0.1/32")}, + }).View(), + }, WantKubeSecret: map[string]string{ "authkey": "tskey-key", "device_fqdn": "new-name.test.ts.net.", @@ -1093,19 +1109,24 @@ func TestContainerBoot(t *testing.T) { { Notify: &ipn.Notify{ State: new(ipn.Running), - NetMap: &netmap.NetworkMap{ - SelfNode: (&tailcfg.Node{ - StableID: tailcfg.StableNodeID("myID"), - Name: "test-node.test.ts.net.", - Addresses: []netip.Prefix{netip.MustParsePrefix("100.64.0.1/32")}, + SelfChange: &tailcfg.Node{ + StableID: tailcfg.StableNodeID("myID"), + Name: "test-node.test.ts.net.", + Addresses: []netip.Prefix{netip.MustParsePrefix("100.64.0.1/32")}, + }, + }, + NetMap: &netmap.NetworkMap{ + SelfNode: (&tailcfg.Node{ + StableID: tailcfg.StableNodeID("myID"), + Name: "test-node.test.ts.net.", + Addresses: []netip.Prefix{netip.MustParsePrefix("100.64.0.1/32")}, + }).View(), + Peers: []tailcfg.NodeView{ + (&tailcfg.Node{ + StableID: tailcfg.StableNodeID("fooID"), + Name: "foo.tailnetxyz.ts.net.", + Addresses: []netip.Prefix{netip.MustParsePrefix("100.64.0.2/32")}, }).View(), - Peers: []tailcfg.NodeView{ - (&tailcfg.Node{ - StableID: tailcfg.StableNodeID("fooID"), - Name: "foo.tailnetxyz.ts.net.", - Addresses: []netip.Prefix{netip.MustParsePrefix("100.64.0.2/32")}, - }).View(), - }, }, }, WantKubeSecret: map[string]string{ @@ -1274,6 +1295,18 @@ func TestContainerBoot(t *testing.T) { t.Fatalf("phase %d: updating mtime for %q: %v", i, path, err) } } + nmForFake := p.NetMap + if nmForFake == nil && p.Notify != nil && p.Notify.SelfChange != nil { + // Synthesize a minimal netmap from SelfChange so + // containerboot's NetMap() fetch returns + // something usable when the test only set Notify. + nmForFake = &netmap.NetworkMap{ + SelfNode: p.Notify.SelfChange.View(), + } + } + if nmForFake != nil { + env.lapi.SetNetMap(nmForFake) + } env.lapi.Notify(p.Notify) if p.Signal != nil { cmd.Process.Signal(*p.Signal) @@ -1466,6 +1499,7 @@ type localAPI struct { sync.Mutex cond *sync.Cond notify *ipn.Notify + netmap *netmap.NetworkMap // served by /localapi/v0/netmap } func (lc *localAPI) Start() error { @@ -1502,8 +1536,44 @@ func (lc *localAPI) Notify(n *ipn.Notify) { lc.cond.Broadcast() } +// SetNetMap installs the netmap that the fake /localapi/v0/netmap endpoint +// will return. +func (lc *localAPI) SetNetMap(nm *netmap.NetworkMap) { + lc.Lock() + defer lc.Unlock() + lc.netmap = nm +} + func (lc *localAPI) ServeHTTP(w http.ResponseWriter, r *http.Request) { switch r.URL.Path { + case "/localapi/v0/netmap": + w.Header().Set("Content-Type", "application/json") + lc.Lock() + nm := lc.netmap + lc.Unlock() + if nm == nil { + http.Error(w, "no netmap", http.StatusServiceUnavailable) + return + } + json.NewEncoder(w).Encode(nm) + return + case "/localapi/v0/debug": + // containerboot fetches the netmap via the "current-netmap" + // debug action; serve it like /localapi/v0/netmap above. + if r.URL.Query().Get("action") != "current-netmap" { + http.Error(w, "unsupported debug action", http.StatusNotFound) + return + } + w.Header().Set("Content-Type", "application/json") + lc.Lock() + nm := lc.netmap + lc.Unlock() + if nm == nil { + http.Error(w, "no netmap", http.StatusServiceUnavailable) + return + } + json.NewEncoder(w).Encode(nm) + return case "/localapi/v0/serve-config": switch r.Method { case "GET": diff --git a/cmd/derper/bootstrap_dns_test.go b/cmd/derper/bootstrap_dns_test.go index 5b765f6d3..2055b9751 100644 --- a/cmd/derper/bootstrap_dns_test.go +++ b/cmd/derper/bootstrap_dns_test.go @@ -41,8 +41,28 @@ func (b *bitbucketResponseWriter) Write(p []byte) (int, error) { return len(p), func (b *bitbucketResponseWriter) WriteHeader(statusCode int) {} +// setDNSCache sets the published DNS cache for tests. +func setDNSCache(tb testing.TB, m *dnsEntryMap) { + tb.Helper() + j, err := json.Marshal(m.IPs) + if err != nil { + tb.Fatal(err) + } + tstest.AssertNotParallel(tb) + dnsCache.Store(m) + dnsCacheBytes.Store(j) + tb.Cleanup(func() { + dnsCache.Store(nil) + dnsCacheBytes.Store(nil) + }) +} + func getBootstrapDNS(t *testing.T, q string) map[string][]net.IP { t.Helper() + tstest.AssertNotParallel(t) + if dnsCache.Load() == nil { + t.Fatal("dnsCache not initialized; call setDNSCache before getBootstrapDNS") + } req, _ := http.NewRequest("GET", "https://localhost/bootstrap-dns?q="+url.QueryEscape(q), nil) w := httptest.NewRecorder() handleBootstrapDNS(w, req) @@ -100,7 +120,8 @@ func TestUnpublishedDNS(t *testing.T) { } } -func resetMetrics() { +func resetMetrics(tb testing.TB) { + tstest.AssertNotParallel(tb) publishedDNSHits.Set(0) publishedDNSMisses.Set(0) unpublishedDNSHits.Set(0) @@ -114,8 +135,7 @@ func TestUnpublishedDNSEmptyList(t *testing.T) { pub := &dnsEntryMap{ IPs: map[string][]net.IP{"tailscale.com": {net.IPv4(10, 10, 10, 10)}}, } - dnsCache.Store(pub) - dnsCacheBytes.Store([]byte(`{"tailscale.com":["10.10.10.10"]}`)) + setDNSCache(t, pub) unpublishedDNSCache.Store(&dnsEntryMap{ IPs: map[string][]net.IP{ @@ -131,7 +151,7 @@ func TestUnpublishedDNSEmptyList(t *testing.T) { t.Run("CacheMiss", func(t *testing.T) { // One domain in map but empty, one not in map at all for _, q := range []string{"log.tailscale.com", "login.tailscale.com"} { - resetMetrics() + resetMetrics(t) ips := getBootstrapDNS(t, q) // Expected our public map to be returned on a cache miss @@ -149,7 +169,7 @@ func TestUnpublishedDNSEmptyList(t *testing.T) { // Verify that we do get a valid response and metric. t.Run("CacheHit", func(t *testing.T) { - resetMetrics() + resetMetrics(t) ips := getBootstrapDNS(t, "controlplane.tailscale.com") want := map[string][]net.IP{"controlplane.tailscale.com": {net.IPv4(1, 2, 3, 4)}} if !reflect.DeepEqual(ips, want) { @@ -166,8 +186,10 @@ func TestUnpublishedDNSEmptyList(t *testing.T) { } func TestLookupMetric(t *testing.T) { + setDNSCache(t, &dnsEntryMap{}) + d := []string{"a.io", "b.io", "c.io", "d.io", "e.io", "e.io", "e.io", "a.io"} - resetMetrics() + resetMetrics(t) for _, q := range d { _ = getBootstrapDNS(t, q) } diff --git a/cmd/derper/depaware.txt b/cmd/derper/depaware.txt index ec59c7264..c927335e4 100644 --- a/cmd/derper/depaware.txt +++ b/cmd/derper/depaware.txt @@ -20,6 +20,7 @@ tailscale.com/cmd/derper dependencies: (generated by github.com/tailscale/depawa github.com/go-json-experiment/json/internal/jsonopts from github.com/go-json-experiment/json+ github.com/go-json-experiment/json/internal/jsonwire from github.com/go-json-experiment/json+ github.com/go-json-experiment/json/jsontext from github.com/go-json-experiment/json+ + 💣 github.com/go4org/hashtriemap from tailscale.com/derp/derpserver github.com/golang/groupcache/lru from tailscale.com/net/dnscache github.com/hdevalence/ed25519consensus from tailscale.com/tka L 💣 github.com/jsimonetti/rtnetlink from tailscale.com/net/netmon @@ -310,7 +311,7 @@ tailscale.com/cmd/derper dependencies: (generated by github.com/tailscale/depawa hash from crypto+ hash/crc32 from compress/gzip+ hash/fnv from google.golang.org/protobuf/internal/detrand - hash/maphash from go4.org/mem + hash/maphash from go4.org/mem+ html from net/http/pprof+ html/template from tailscale.com/cmd/derper+ internal/abi from crypto/x509/internal/macos+ diff --git a/cmd/derper/derper.go b/cmd/derper/derper.go index 429aff361..745d887f8 100644 --- a/cmd/derper/derper.go +++ b/cmd/derper/derper.go @@ -87,8 +87,7 @@ var ( acceptConnLimit = flag.Float64("accept-connection-limit", math.Inf(+1), "rate limit for accepting new connection") acceptConnBurst = flag.Int("accept-connection-burst", math.MaxInt, "burst limit for accepting new connection") - perClientRateLimit = flag.Uint("per-client-rate-limit", 0, "per-client receive rate limit in bytes/sec; 0 means unlimited. Mesh peers are exempt.") - perClientRateBurst = flag.Uint("per-client-rate-burst", 0, "per-client receive rate burst in bytes; 0 defaults to 2x the rate limit (only relevant when using nonzero --per-client-rate-limit)") + rateConfigPath = flag.String("rate-config", "", "if non-empty, path to JSON rate limit config file. Rate limiting is experimental and subject to change. Configuration is reloaded on SIGHUP.") // tcpKeepAlive is intentionally long, to reduce battery cost. There is an L7 keepalive on a higher frequency schedule. tcpKeepAlive = flag.Duration("tcp-keepalive-time", 10*time.Minute, "TCP keepalive time") @@ -195,12 +194,11 @@ func main() { s.SetVerifyClientURL(*verifyClientURL) s.SetVerifyClientURLFailOpen(*verifyFailOpen) s.SetTCPWriteTimeout(*tcpWriteTimeout) - if *perClientRateLimit > 0 { - burst := *perClientRateBurst - if burst < 1 { - burst = *perClientRateLimit * 2 + if *rateConfigPath != "" { + if err := s.LoadAndApplyRateConfig(*rateConfigPath); err != nil { + log.Fatalf("derper: loading rate config: %v", err) } - s.SetPerClientRateLimit(*perClientRateLimit, burst) + go watchRateConfig(ctx, s, *rateConfigPath) } var meshKey string @@ -254,7 +252,7 @@ func main() { if err := startMesh(s); err != nil { log.Fatalf("startMesh: %v", err) } - expvar.Publish("derp", s.ExpVar()) + expvar.Publish("derp", s.ExpVar(*rateConfigPath != "")) handleHome, ok := getHomeHandler(*flagHome) if !ok { @@ -436,6 +434,27 @@ func main() { } } +// watchRateConfig listens for SIGHUP signals and reloads the rate config +// file on each signal, applying it to the server. It returns when ctx is done. +func watchRateConfig(ctx context.Context, s *derpserver.Server, path string) { + sighup := make(chan os.Signal, 1) + signal.Notify(sighup, syscall.SIGHUP) + defer signal.Stop(sighup) + for { + select { + case <-ctx.Done(): + return + case <-sighup: + log.Printf("derper: received SIGHUP, reloading rate config from %s", path) + if err := s.LoadAndApplyRateConfig(path); err != nil { + log.Printf("derper: rate config reload failed: %v", err) + continue + } + log.Printf("derper: rate config reloaded successfully") + } + } +} + var validProdHostname = regexp.MustCompile(`^derp([^.]*)\.tailscale\.com\.?$`) func prodAutocertHostPolicy(_ context.Context, host string) error { diff --git a/cmd/gitops-pusher/gitops-pusher.go b/cmd/gitops-pusher/gitops-pusher.go index 11448e30d..9ea115a15 100644 --- a/cmd/gitops-pusher/gitops-pusher.go +++ b/cmd/gitops-pusher/gitops-pusher.go @@ -26,7 +26,7 @@ import ( "github.com/tailscale/hujson" "golang.org/x/oauth2/clientcredentials" tsclient "tailscale.com/client/tailscale" - _ "tailscale.com/feature/condregister/identityfederation" + _ "tailscale.com/feature/identityfederation" "tailscale.com/internal/client/tailscale" "tailscale.com/util/httpm" ) diff --git a/cmd/hello/hello.go b/cmd/hello/hello.go index 710de49cd..45eb7751c 100644 --- a/cmd/hello/hello.go +++ b/cmd/hello/hello.go @@ -5,212 +5,16 @@ package main // import "tailscale.com/cmd/hello" import ( - "context" - "crypto/tls" - _ "embed" - "encoding/json" - "errors" - "flag" - "html/template" "log" - "net/http" - "os" - "strings" - "time" - "tailscale.com/client/local" - "tailscale.com/client/tailscale/apitype" - "tailscale.com/tailcfg" + "tailscale.com/cmd/hello/helloserver" ) -var ( - httpAddr = flag.String("http", ":80", "address to run an HTTP server on, or empty for none") - httpsAddr = flag.String("https", ":443", "address to run an HTTPS server on, or empty for none") - testIP = flag.String("test-ip", "", "if non-empty, look up IP and exit before running a server") -) - -//go:embed hello.tmpl.html -var embeddedTemplate string - -var localClient local.Client - func main() { - flag.Parse() - if *testIP != "" { - res, err := localClient.WhoIs(context.Background(), *testIP) - if err != nil { - log.Fatal(err) - } - e := json.NewEncoder(os.Stdout) - e.SetIndent("", "\t") - e.Encode(res) - return + s := &helloserver.Server{ + HTTPAddr: ":80", + HTTPSAddr: ":443", } - if devMode() { - // Parse it optimistically - var err error - tmpl, err = template.New("home").Parse(embeddedTemplate) - if err != nil { - log.Printf("ignoring template error in dev mode: %v", err) - } - } else { - if embeddedTemplate == "" { - log.Fatalf("embeddedTemplate is empty; must be build with Go 1.16+") - } - tmpl = template.Must(template.New("home").Parse(embeddedTemplate)) - } - - http.HandleFunc("/", root) log.Printf("Starting hello server.") - - errc := make(chan error, 1) - if *httpAddr != "" { - log.Printf("running HTTP server on %s", *httpAddr) - go func() { - errc <- http.ListenAndServe(*httpAddr, nil) - }() - } - if *httpsAddr != "" { - log.Printf("running HTTPS server on %s", *httpsAddr) - go func() { - hs := &http.Server{ - Addr: *httpsAddr, - TLSConfig: &tls.Config{ - GetCertificate: func(hi *tls.ClientHelloInfo) (*tls.Certificate, error) { - switch hi.ServerName { - case "hello.ts.net": - return localClient.GetCertificate(hi) - case "hello.ipn.dev": - c, err := tls.LoadX509KeyPair( - "/etc/hello/hello.ipn.dev.crt", - "/etc/hello/hello.ipn.dev.key", - ) - if err != nil { - return nil, err - } - return &c, nil - } - return nil, errors.New("invalid SNI name") - }, - }, - IdleTimeout: 30 * time.Second, - ReadHeaderTimeout: 20 * time.Second, - MaxHeaderBytes: 10 << 10, - } - errc <- hs.ListenAndServeTLS("", "") - }() - } - log.Fatal(<-errc) -} - -func devMode() bool { return *httpsAddr == "" && *httpAddr != "" } - -func getTmpl() (*template.Template, error) { - if devMode() { - tmplData, err := os.ReadFile("hello.tmpl.html") - if os.IsNotExist(err) { - log.Printf("using baked-in template in dev mode; can't find hello.tmpl.html in current directory") - return tmpl, nil - } - return template.New("home").Parse(string(tmplData)) - } - return tmpl, nil -} - -// tmpl is the template used in prod mode. -// In dev mode it's only used if the template file doesn't exist on disk. -// It's initialized by main after flag parsing. -var tmpl *template.Template - -type tmplData struct { - DisplayName string // "Foo Barberson" - LoginName string // "foo@bar.com" - ProfilePicURL string // "https://..." - MachineName string // "imac5k" - MachineOS string // "Linux" - IP string // "100.2.3.4" -} - -func tailscaleIP(who *apitype.WhoIsResponse) string { - if who == nil { - return "" - } - vals, err := tailcfg.UnmarshalNodeCapJSON[string](who.Node.CapMap, tailcfg.NodeAttrNativeIPV4) - if err == nil && len(vals) > 0 { - return vals[0] - } - for _, nodeIP := range who.Node.Addresses { - if nodeIP.Addr().Is4() && nodeIP.IsSingleIP() { - return nodeIP.Addr().String() - } - } - for _, nodeIP := range who.Node.Addresses { - if nodeIP.IsSingleIP() { - return nodeIP.Addr().String() - } - } - return "" -} - -func root(w http.ResponseWriter, r *http.Request) { - if r.TLS == nil && *httpsAddr != "" { - host := r.Host - if strings.Contains(r.Host, "100.101.102.103") || - strings.Contains(r.Host, "hello.ipn.dev") { - host = "hello.ts.net" - } - http.Redirect(w, r, "https://"+host, http.StatusFound) - return - } - if r.RequestURI != "/" { - http.Redirect(w, r, "/", http.StatusFound) - return - } - if r.TLS != nil && *httpsAddr != "" && strings.Contains(r.Host, "hello.ipn.dev") { - http.Redirect(w, r, "https://hello.ts.net", http.StatusFound) - return - } - tmpl, err := getTmpl() - if err != nil { - w.Header().Set("Content-Type", "text/plain") - http.Error(w, "template error: "+err.Error(), 500) - return - } - - who, err := localClient.WhoIs(r.Context(), r.RemoteAddr) - var data tmplData - if err != nil { - if devMode() { - log.Printf("warning: using fake data in dev mode due to whois lookup error: %v", err) - data = tmplData{ - DisplayName: "Taily Scalerson", - LoginName: "taily@scaler.son", - ProfilePicURL: "https://placekitten.com/200/200", - MachineName: "scaled", - MachineOS: "Linux", - IP: "100.1.2.3", - } - } else { - log.Printf("whois(%q) error: %v", r.RemoteAddr, err) - http.Error(w, "Your Tailscale works, but we failed to look you up.", 500) - return - } - } else { - data = tmplData{ - DisplayName: who.UserProfile.DisplayName, - LoginName: who.UserProfile.LoginName, - ProfilePicURL: who.UserProfile.ProfilePicURL, - MachineName: firstLabel(who.Node.ComputedName), - MachineOS: who.Node.Hostinfo.OS(), - IP: tailscaleIP(who), - } - } - w.Header().Set("Content-Type", "text/html; charset=utf-8") - tmpl.Execute(w, data) -} - -// firstLabel s up until the first period, if any. -func firstLabel(s string) string { - s, _, _ = strings.Cut(s, ".") - return s + log.Fatal(s.Run()) } diff --git a/cmd/hello/hello.tmpl.html b/cmd/hello/hello.tmpl.html deleted file mode 100644 index 3ecd1b58a..000000000 --- a/cmd/hello/hello.tmpl.html +++ /dev/null @@ -1,438 +0,0 @@ - - - - - - Hello from Tailscale - - - - -
- -
-

You're connected over Tailscale!

-

This device is signed in as…

-
-
-
- - - -
-
-
-
- {{ with .DisplayName }} -

{{.}}

- {{ end }} -
{{.LoginName}}
-
-
-
-
- - - - - - -

{{.MachineName}}

-
-
{{.IP}}
-
-
- -
- - diff --git a/cmd/hello/helloserver/hello.tmpl.html b/cmd/hello/helloserver/hello.tmpl.html new file mode 100644 index 000000000..0f74d116f --- /dev/null +++ b/cmd/hello/helloserver/hello.tmpl.html @@ -0,0 +1,71 @@ + + + + + + Hello from Tailscale + + + + +
+ +
+

You're connected over Tailscale!

+

This device is signed in as…

+
+
+
+ + + +
+
+
+ Profile picture +
+
+ {{ with .DisplayName }} +

{{.}}

+ {{ end }} +
{{.LoginName}}
+
+
+
+
+ + + + + + +

{{.MachineName}}

+
+
{{.IP}}
+
+
+ +
+ + diff --git a/cmd/hello/helloserver/helloserver.go b/cmd/hello/helloserver/helloserver.go new file mode 100644 index 000000000..41e7dbce2 --- /dev/null +++ b/cmd/hello/helloserver/helloserver.go @@ -0,0 +1,157 @@ +// Copyright (c) Tailscale Inc & contributors +// SPDX-License-Identifier: BSD-3-Clause + +// Package helloserver implements the HTTP server behind hello.ts.net. +package helloserver + +import ( + "crypto/tls" + "embed" + "html/template" + "log" + "net/http" + "strings" + "time" + + "tailscale.com/client/local" + "tailscale.com/client/tailscale/apitype" + "tailscale.com/tailcfg" +) + +//go:embed hello.tmpl.html +var embeddedTemplate string + +//go:embed static/* +var staticFiles embed.FS + +var staticHandler = http.FileServerFS(staticFiles) + +var tmpl = template.Must(template.New("home").Parse(embeddedTemplate)) + +// Server is an HTTP server for hello.ts.net. +// +// The zero value is not valid; populate at least one of HTTPAddr or HTTPSAddr +// before calling Run. +type Server struct { + // HTTPAddr is the address to run an HTTP server on, or empty for none. + HTTPAddr string + + // HTTPSAddr is the address to run an HTTPS server on, or empty for none. + HTTPSAddr string + + // LocalClient is used to look up the identity of incoming requests and + // to obtain TLS certificates. If nil, the zero value of local.Client is + // used. + LocalClient *local.Client +} + +func (s *Server) localClient() *local.Client { + if s.LocalClient != nil { + return s.LocalClient + } + return &local.Client{} +} + +// Run starts the configured HTTP and HTTPS servers and blocks until one of +// them returns an error. +func (s *Server) Run() error { + errc := make(chan error, 1) + if s.HTTPAddr != "" { + log.Printf("running HTTP server on %s", s.HTTPAddr) + go func() { + errc <- http.ListenAndServe(s.HTTPAddr, s) + }() + } + if s.HTTPSAddr != "" { + log.Printf("running HTTPS server on %s", s.HTTPSAddr) + go func() { + hs := &http.Server{ + Addr: s.HTTPSAddr, + Handler: s, + TLSConfig: &tls.Config{ + GetCertificate: s.localClient().GetCertificate, + }, + IdleTimeout: 30 * time.Second, + ReadHeaderTimeout: 20 * time.Second, + MaxHeaderBytes: 10 << 10, + } + errc <- hs.ListenAndServeTLS("", "") + }() + } + return <-errc +} + +type tmplData struct { + DisplayName string // "Foo Barberson" + LoginName string // "foo@bar.com" + ProfilePicURL string // "https://..." + MachineName string // "imac5k" + MachineOS string // "Linux" + IP string // "100.2.3.4" +} + +func tailscaleIP(who *apitype.WhoIsResponse) string { + if who == nil { + return "" + } + vals, err := tailcfg.UnmarshalNodeCapJSON[string](who.Node.CapMap, tailcfg.NodeAttrNativeIPV4) + if err == nil && len(vals) > 0 { + return vals[0] + } + for _, nodeIP := range who.Node.Addresses { + if nodeIP.Addr().Is4() && nodeIP.IsSingleIP() { + return nodeIP.Addr().String() + } + } + for _, nodeIP := range who.Node.Addresses { + if nodeIP.IsSingleIP() { + return nodeIP.Addr().String() + } + } + return "" +} + +// ServeHTTP implements http.Handler. +func (s *Server) ServeHTTP(w http.ResponseWriter, r *http.Request) { + if r.TLS == nil && s.HTTPSAddr != "" { + host := r.Host + if strings.Contains(r.Host, "100.101.102.103") { + host = "hello.ts.net" + } + http.Redirect(w, r, "https://"+host, http.StatusFound) + return + } + + if strings.HasPrefix(r.RequestURI, "/static/") { + staticHandler.ServeHTTP(w, r) + return + } + + if r.RequestURI != "/" { + http.Redirect(w, r, "/", http.StatusFound) + return + } + + who, err := s.localClient().WhoIs(r.Context(), r.RemoteAddr) + if err != nil { + log.Printf("whois(%q) error: %v", r.RemoteAddr, err) + http.Error(w, "Your Tailscale works, but we failed to look you up.", 500) + return + } + data := tmplData{ + DisplayName: who.UserProfile.DisplayName, + LoginName: who.UserProfile.LoginName, + ProfilePicURL: who.UserProfile.ProfilePicURL, + MachineName: firstLabel(who.Node.ComputedName), + MachineOS: who.Node.Hostinfo.OS(), + IP: tailscaleIP(who), + } + w.Header().Set("Content-Type", "text/html; charset=utf-8") + tmpl.Execute(w, data) +} + +// firstLabel returns s up until the first period, if any. +func firstLabel(s string) string { + s, _, _ = strings.Cut(s, ".") + return s +} diff --git a/cmd/hello/helloserver/static/script.js b/cmd/hello/helloserver/static/script.js new file mode 100644 index 000000000..db9bcd0f3 --- /dev/null +++ b/cmd/hello/helloserver/static/script.js @@ -0,0 +1,12 @@ +(function () { + var lastSeen = localStorage.getItem("lastSeen"); + if (!lastSeen) { + document.body.classList.add("animate"); + window.addEventListener("load", function () { + setTimeout(function () { + document.body.classList.add("animating"); + localStorage.setItem("lastSeen", Date.now()); + }, 100); + }); + } +})(); diff --git a/cmd/hello/helloserver/static/style.css b/cmd/hello/helloserver/static/style.css new file mode 100644 index 000000000..8ad55edc6 --- /dev/null +++ b/cmd/hello/helloserver/static/style.css @@ -0,0 +1,366 @@ +html, +body { + margin: 0; + padding: 0; +} + +body { + font-family: Inter, -apple-system, BlinkMacSystemFont, "Segoe UI", Roboto, Helvetica, Arial, sans-serif; + font-size: 100%; + -webkit-font-smoothing: antialiased; + -moz-osx-font-smoothing: grayscale; +} + +html, +body, +main { + height: 100%; +} + +*, +::before, +::after { + box-sizing: border-box; + border-width: 0; + border-style: solid; + border-color: #dad6d5; +} + +h1, +h2, +h3, +h4, +h5, +h6 { + margin: 0; + font-size: 1rem; + font-weight: inherit; +} + +a { + color: inherit; +} + +p { + margin: 0; +} + +main { + display: flex; + flex-direction: column; + justify-content: center; + align-items: center; + max-width: 24rem; + width: 95%; + margin-left: auto; + margin-right: auto; +} + +.p-2 { + padding: 0.5rem; +} + +.p-4 { + padding: 1rem; +} + +.px-2 { + padding-left: 0.5rem; + padding-right: 0.5rem; +} + +.pl-3 { + padding-left: 0.75rem; +} + +.pr-3 { + padding-right: 0.75rem; +} + +.pt-4 { + padding-top: 1rem; +} + +.mr-2 { + margin-right: 0.5rem; +; +} + +.mb-1 { + margin-bottom: 0.25rem; +} + +.mb-2 { + margin-bottom: 0.5rem; +} + +.mb-4 { + margin-bottom: 1rem; +} + +.mb-6 { + margin-bottom: 1.5rem; +} + +.mb-8 { + margin-bottom: 2rem; +} + +.mb-12 { + margin-bottom: 3rem; +} + +.width-full { + width: 100%; +} + +.min-width-0 { + min-width: 0; +} + +.rounded-lg { + border-radius: 0.5rem; +} + +.relative { + position: relative; +} + +.flex { + display: flex; +} + +.justify-between { + justify-content: space-between; +} + +.items-center { + align-items: center; +} + +.border { + border-width: 1px; +} + +.border-t-1 { + border-top-width: 1px; +} + +.border-gray-100 { + border-color: #f7f5f4; +} + +.border-gray-200 { + border-color: #eeebea; +} + +.border-gray-300 { + border-color: #dad6d5; +} + +.bg-white { + background-color: white; +} + +.bg-gray-0 { + background-color: #faf9f8; +} + +.bg-gray-100 { + background-color: #f7f5f4; +} + +.text-green-600 { + color: #0d4b3b; +} + +.text-blue-600 { + color: #3f5db3; +} + +.hover\:text-blue-800:hover { + color: #253570; +} + +.text-gray-600 { + color: #444342; +} + +.text-gray-700 { + color: #2e2d2d; +} + +.text-gray-800 { + color: #232222; +} + +.text-center { + text-align: center; +} + +.text-sm { + font-size: 0.875rem; +} + +.font-title { + font-size: 1.25rem; + letter-spacing: -0.025em; +} + +.font-semibold { + font-weight: 600; +} + +.font-medium { + font-weight: 500; +} + +.font-regular { + font-weight: 400; +} + +.truncate { + overflow: hidden; + text-overflow: ellipsis; + white-space: nowrap; +} + +.overflow-hidden { + overflow: hidden; +} + +.profile-pic { + width: 2.5rem; + height: 2.5rem; + background-size: cover; + margin-right: 0.5rem; + flex-shrink: 0; +} + +.profile-pic-img { + width: 100%; + height: 100%; + object-fit: cover; + display: block; + border-radius: 9999px; +} + +.panel { + box-shadow: 0 20px 25px -5px rgba(0, 0, 0, 0.1), 0 10px 10px -5px rgba(0, 0, 0, 0.04); +} + +.animate .panel { + transform: translateY(10%); + box-shadow: 0 20px 25px -5px rgba(0, 0, 0, 0.0), 0 10px 10px -5px rgba(0, 0, 0, 0.0); + transition: transform 1200ms ease, opacity 1200ms ease, box-shadow 1200ms ease; +} + +.animate .panel-interior { + opacity: 0.0; + transition: opacity 1200ms ease; +} + +.animate .logo { + transform: translateY(2rem); + opacity: 0.0; + transition: transform 1200ms ease, opacity 1200ms ease; +} + +.animate .header-title { + transform: translateY(1.6rem); + opacity: 0.0; + transition: transform 1200ms ease, opacity 1200ms ease; +} + +.animate .header-text { + transform: translateY(1.2rem); + opacity: 0.0; + transition: transform 1200ms ease, opacity 1200ms ease; +} + +.animate .footer { + transform: translateY(-0.5rem); + opacity: 0.0; + transition: transform 1200ms ease, opacity 1200ms ease; +} + +.animating .panel { + transform: translateY(0); + opacity: 1.0; + box-shadow: 0 20px 25px -5px rgba(0, 0, 0, 0.1), 0 10px 10px -5px rgba(0, 0, 0, 0.04); +} + +.animating .panel-interior { + opacity: 1.0; +} + +.animating .spinner { + opacity: 0.0; +} + +.animating .logo, +.animating .header-title, +.animating .header-text, +.animating .footer { + transform: translateY(0); + opacity: 1.0; +} + +.spinner { + display: inline-flex; + position: absolute; + top: 50%; + left: 50%; + transform: translate(-50%, -50%); + align-items: center; + transition: opacity 200ms ease; +} + +.spinner span { + display: inline-block; + background-color: currentColor; + border-radius: 9999px; + animation-name: loading-dots-blink; + animation-duration: 1.4s; + animation-iteration-count: infinite; + animation-fill-mode: both; + width: 0.35em; + height: 0.35em; + margin: 0 0.15em; +} + +.spinner span:nth-child(2) { + animation-delay: 200ms; +} + +.spinner span:nth-child(3) { + animation-delay: 400ms; +} + +.spinner { + display: none; +} + +.animate .spinner { + display: inline-flex; +} + +@keyframes loading-dots-blink { + 0% { + opacity: 0.2; + } + 20% { + opacity: 1; + } + 100% { + opacity: 0.2; + } +} + +@media (prefers-reduced-motion) { + * { + animation-duration: 0ms !important; + transition-duration: 0ms !important; + transition-delay: 0ms !important; + } +} diff --git a/cmd/k8s-operator/depaware.txt b/cmd/k8s-operator/depaware.txt index 2b6884683..12073da0b 100644 --- a/cmd/k8s-operator/depaware.txt +++ b/cmd/k8s-operator/depaware.txt @@ -6,77 +6,6 @@ tailscale.com/cmd/k8s-operator dependencies: (generated by github.com/tailscale/ W 💣 github.com/alexbrainman/sspi from github.com/alexbrainman/sspi/internal/common+ W github.com/alexbrainman/sspi/internal/common from github.com/alexbrainman/sspi/negotiate W 💣 github.com/alexbrainman/sspi/negotiate from tailscale.com/net/tshttpproxy - github.com/aws/aws-sdk-go-v2/aws from github.com/aws/aws-sdk-go-v2/aws/defaults+ - github.com/aws/aws-sdk-go-v2/aws/defaults from github.com/aws/aws-sdk-go-v2/service/sso+ - github.com/aws/aws-sdk-go-v2/aws/middleware from github.com/aws/aws-sdk-go-v2/aws/retry+ - github.com/aws/aws-sdk-go-v2/aws/protocol/query from github.com/aws/aws-sdk-go-v2/service/sts - github.com/aws/aws-sdk-go-v2/aws/protocol/restjson from github.com/aws/aws-sdk-go-v2/service/sso+ - github.com/aws/aws-sdk-go-v2/aws/protocol/xml from github.com/aws/aws-sdk-go-v2/service/sts - github.com/aws/aws-sdk-go-v2/aws/ratelimit from github.com/aws/aws-sdk-go-v2/aws/retry - github.com/aws/aws-sdk-go-v2/aws/retry from github.com/aws/aws-sdk-go-v2/credentials/endpointcreds/internal/client+ - github.com/aws/aws-sdk-go-v2/aws/signer/internal/v4 from github.com/aws/aws-sdk-go-v2/aws/signer/v4 - github.com/aws/aws-sdk-go-v2/aws/signer/v4 from github.com/aws/aws-sdk-go-v2/internal/auth/smithy+ - github.com/aws/aws-sdk-go-v2/aws/transport/http from github.com/aws/aws-sdk-go-v2/config+ - github.com/aws/aws-sdk-go-v2/config from tailscale.com/wif - github.com/aws/aws-sdk-go-v2/credentials from github.com/aws/aws-sdk-go-v2/config - github.com/aws/aws-sdk-go-v2/credentials/ec2rolecreds from github.com/aws/aws-sdk-go-v2/config - github.com/aws/aws-sdk-go-v2/credentials/endpointcreds from github.com/aws/aws-sdk-go-v2/config - github.com/aws/aws-sdk-go-v2/credentials/endpointcreds/internal/client from github.com/aws/aws-sdk-go-v2/credentials/endpointcreds - github.com/aws/aws-sdk-go-v2/credentials/processcreds from github.com/aws/aws-sdk-go-v2/config - github.com/aws/aws-sdk-go-v2/credentials/ssocreds from github.com/aws/aws-sdk-go-v2/config - github.com/aws/aws-sdk-go-v2/credentials/stscreds from github.com/aws/aws-sdk-go-v2/config - github.com/aws/aws-sdk-go-v2/feature/ec2/imds from github.com/aws/aws-sdk-go-v2/config+ - github.com/aws/aws-sdk-go-v2/feature/ec2/imds/internal/config from github.com/aws/aws-sdk-go-v2/feature/ec2/imds - github.com/aws/aws-sdk-go-v2/internal/auth from github.com/aws/aws-sdk-go-v2/aws/signer/v4+ - github.com/aws/aws-sdk-go-v2/internal/auth/smithy from github.com/aws/aws-sdk-go-v2/service/sso+ - github.com/aws/aws-sdk-go-v2/internal/configsources from github.com/aws/aws-sdk-go-v2/service/sso+ - github.com/aws/aws-sdk-go-v2/internal/context from github.com/aws/aws-sdk-go-v2/aws/retry+ - github.com/aws/aws-sdk-go-v2/internal/endpoints from github.com/aws/aws-sdk-go-v2/service/sso+ - github.com/aws/aws-sdk-go-v2/internal/endpoints/awsrulesfn from github.com/aws/aws-sdk-go-v2/service/sso+ - github.com/aws/aws-sdk-go-v2/internal/endpoints/v2 from github.com/aws/aws-sdk-go-v2/service/sso/internal/endpoints+ - github.com/aws/aws-sdk-go-v2/internal/ini from github.com/aws/aws-sdk-go-v2/config - github.com/aws/aws-sdk-go-v2/internal/middleware from github.com/aws/aws-sdk-go-v2/service/sso+ - github.com/aws/aws-sdk-go-v2/internal/rand from github.com/aws/aws-sdk-go-v2/aws+ - github.com/aws/aws-sdk-go-v2/internal/sdk from github.com/aws/aws-sdk-go-v2/aws+ - github.com/aws/aws-sdk-go-v2/internal/sdkio from github.com/aws/aws-sdk-go-v2/credentials/processcreds - github.com/aws/aws-sdk-go-v2/internal/shareddefaults from github.com/aws/aws-sdk-go-v2/config+ - github.com/aws/aws-sdk-go-v2/internal/strings from github.com/aws/aws-sdk-go-v2/aws/signer/internal/v4 - github.com/aws/aws-sdk-go-v2/internal/sync/singleflight from github.com/aws/aws-sdk-go-v2/aws - github.com/aws/aws-sdk-go-v2/internal/timeconv from github.com/aws/aws-sdk-go-v2/aws/retry - github.com/aws/aws-sdk-go-v2/service/internal/accept-encoding from github.com/aws/aws-sdk-go-v2/service/sts - github.com/aws/aws-sdk-go-v2/service/internal/presigned-url from github.com/aws/aws-sdk-go-v2/service/sts - github.com/aws/aws-sdk-go-v2/service/sso from github.com/aws/aws-sdk-go-v2/config+ - github.com/aws/aws-sdk-go-v2/service/sso/internal/endpoints from github.com/aws/aws-sdk-go-v2/service/sso - github.com/aws/aws-sdk-go-v2/service/sso/types from github.com/aws/aws-sdk-go-v2/service/sso - github.com/aws/aws-sdk-go-v2/service/ssooidc from github.com/aws/aws-sdk-go-v2/config+ - github.com/aws/aws-sdk-go-v2/service/ssooidc/internal/endpoints from github.com/aws/aws-sdk-go-v2/service/ssooidc - github.com/aws/aws-sdk-go-v2/service/ssooidc/types from github.com/aws/aws-sdk-go-v2/service/ssooidc - github.com/aws/aws-sdk-go-v2/service/sts from github.com/aws/aws-sdk-go-v2/config+ - github.com/aws/aws-sdk-go-v2/service/sts/internal/endpoints from github.com/aws/aws-sdk-go-v2/service/sts - github.com/aws/aws-sdk-go-v2/service/sts/types from github.com/aws/aws-sdk-go-v2/credentials/stscreds+ - github.com/aws/smithy-go from github.com/aws/aws-sdk-go-v2/aws/protocol/restjson+ - github.com/aws/smithy-go/auth from github.com/aws/aws-sdk-go-v2/internal/auth+ - github.com/aws/smithy-go/auth/bearer from github.com/aws/aws-sdk-go-v2/aws+ - github.com/aws/smithy-go/context from github.com/aws/smithy-go/auth/bearer - github.com/aws/smithy-go/document from github.com/aws/aws-sdk-go-v2/service/sso+ - github.com/aws/smithy-go/encoding from github.com/aws/smithy-go/encoding/json+ - github.com/aws/smithy-go/encoding/httpbinding from github.com/aws/aws-sdk-go-v2/aws/protocol/query+ - github.com/aws/smithy-go/encoding/json from github.com/aws/aws-sdk-go-v2/service/ssooidc - github.com/aws/smithy-go/encoding/xml from github.com/aws/aws-sdk-go-v2/service/sts - github.com/aws/smithy-go/endpoints from github.com/aws/aws-sdk-go-v2/service/sso+ - github.com/aws/smithy-go/endpoints/private/rulesfn from github.com/aws/aws-sdk-go-v2/service/sts - github.com/aws/smithy-go/internal/sync/singleflight from github.com/aws/smithy-go/auth/bearer - github.com/aws/smithy-go/io from github.com/aws/aws-sdk-go-v2/feature/ec2/imds+ - github.com/aws/smithy-go/logging from github.com/aws/aws-sdk-go-v2/aws+ - github.com/aws/smithy-go/metrics from github.com/aws/aws-sdk-go-v2/aws/retry+ - github.com/aws/smithy-go/middleware from github.com/aws/aws-sdk-go-v2/aws+ - github.com/aws/smithy-go/private/requestcompression from github.com/aws/aws-sdk-go-v2/config - github.com/aws/smithy-go/ptr from github.com/aws/aws-sdk-go-v2/aws+ - github.com/aws/smithy-go/rand from github.com/aws/aws-sdk-go-v2/aws/middleware - github.com/aws/smithy-go/time from github.com/aws/aws-sdk-go-v2/service/sso+ - github.com/aws/smithy-go/tracing from github.com/aws/aws-sdk-go-v2/aws/middleware+ - github.com/aws/smithy-go/transport/http from github.com/aws/aws-sdk-go-v2/aws+ - github.com/aws/smithy-go/transport/http/internal/io from github.com/aws/smithy-go/transport/http github.com/beorn7/perks/quantile from github.com/prometheus/client_golang/prometheus github.com/blang/semver/v4 from k8s.io/component-base/metrics 💣 github.com/cespare/xxhash/v2 from github.com/prometheus/client_golang/prometheus+ @@ -130,7 +59,7 @@ tailscale.com/cmd/k8s-operator dependencies: (generated by github.com/tailscale/ github.com/google/gnostic-models/jsonschema from github.com/google/gnostic-models/compiler github.com/google/gnostic-models/openapiv2 from k8s.io/client-go/discovery+ github.com/google/gnostic-models/openapiv3 from k8s.io/kube-openapi/pkg/handler3+ - github.com/google/uuid from github.com/prometheus-community/pro-bing+ + github.com/google/uuid from k8s.io/apimachinery/pkg/util/uuid+ github.com/hdevalence/ed25519consensus from tailscale.com/tka github.com/huin/goupnp from github.com/huin/goupnp/dcps/internetgateway2+ github.com/huin/goupnp/dcps/internetgateway2 from tailscale.com/net/portmapper @@ -164,7 +93,6 @@ tailscale.com/cmd/k8s-operator dependencies: (generated by github.com/tailscale/ github.com/pires/go-proxyproto from tailscale.com/ipn/ipnlocal+ github.com/pkg/errors from github.com/evanphx/json-patch/v5+ github.com/pmezard/go-difflib/difflib from k8s.io/apimachinery/pkg/util/diff - D github.com/prometheus-community/pro-bing from tailscale.com/wgengine/netstack github.com/prometheus/client_golang/internal/github.com/golang/gddo/httputil from github.com/prometheus/client_golang/prometheus/promhttp github.com/prometheus/client_golang/internal/github.com/golang/gddo/httputil/header from github.com/prometheus/client_golang/internal/github.com/golang/gddo/httputil 💣 github.com/prometheus/client_golang/prometheus from github.com/prometheus/client_golang/prometheus/collectors+ @@ -180,7 +108,7 @@ tailscale.com/cmd/k8s-operator dependencies: (generated by github.com/tailscale/ LD github.com/prometheus/procfs/internal/util from github.com/prometheus/procfs L 💣 github.com/safchain/ethtool from tailscale.com/net/netkernelconf github.com/spf13/pflag from k8s.io/client-go/tools/clientcmd+ - W 💣 github.com/tailscale/certstore from tailscale.com/control/controlclient + DW 💣 github.com/tailscale/certstore from tailscale.com/control/controlclient W 💣 github.com/tailscale/go-winio from tailscale.com/safesocket W 💣 github.com/tailscale/go-winio/internal/fs from github.com/tailscale/go-winio W 💣 github.com/tailscale/go-winio/internal/socket from github.com/tailscale/go-winio @@ -805,11 +733,9 @@ tailscale.com/cmd/k8s-operator dependencies: (generated by github.com/tailscale/ tailscale.com/feature/buildfeatures from tailscale.com/wgengine/magicsock+ tailscale.com/feature/c2n from tailscale.com/tsnet tailscale.com/feature/condlite/expvar from tailscale.com/wgengine/magicsock - tailscale.com/feature/condregister/identityfederation from tailscale.com/tsnet tailscale.com/feature/condregister/oauthkey from tailscale.com/tsnet tailscale.com/feature/condregister/portmapper from tailscale.com/tsnet tailscale.com/feature/condregister/useproxy from tailscale.com/tsnet - tailscale.com/feature/identityfederation from tailscale.com/feature/condregister/identityfederation tailscale.com/feature/oauthkey from tailscale.com/feature/condregister/oauthkey tailscale.com/feature/portmapper from tailscale.com/feature/condregister/portmapper tailscale.com/feature/syspolicy from tailscale.com/logpolicy @@ -817,7 +743,7 @@ tailscale.com/cmd/k8s-operator dependencies: (generated by github.com/tailscale/ tailscale.com/health from tailscale.com/control/controlclient+ tailscale.com/health/healthmsg from tailscale.com/ipn/ipnlocal tailscale.com/hostinfo from tailscale.com/client/web+ - tailscale.com/internal/client/tailscale from tailscale.com/feature/identityfederation+ + tailscale.com/internal/client/tailscale from tailscale.com/feature/oauthkey+ tailscale.com/ipn from tailscale.com/client/local+ tailscale.com/ipn/conffile from tailscale.com/ipn/ipnlocal+ 💣 tailscale.com/ipn/ipnauth from tailscale.com/ipn/ipnlocal+ @@ -910,7 +836,7 @@ tailscale.com/cmd/k8s-operator dependencies: (generated by github.com/tailscale/ tailscale.com/tstime from tailscale.com/cmd/k8s-operator+ tailscale.com/tstime/mono from tailscale.com/net/tstun+ tailscale.com/tstime/rate from tailscale.com/wgengine/filter - tailscale.com/tsweb from tailscale.com/util/eventbus + tailscale.com/tsweb from tailscale.com/util/eventbus+ tailscale.com/tsweb/varz from tailscale.com/util/usermetric+ tailscale.com/types/appctype from tailscale.com/ipn/ipnlocal+ tailscale.com/types/bools from tailscale.com/tsnet+ @@ -1000,7 +926,6 @@ tailscale.com/cmd/k8s-operator dependencies: (generated by github.com/tailscale/ tailscale.com/wgengine/wgcfg/nmcfg from tailscale.com/ipn/ipnlocal 💣 tailscale.com/wgengine/wgint from tailscale.com/wgengine+ tailscale.com/wgengine/wglog from tailscale.com/wgengine - tailscale.com/wif from tailscale.com/feature/identityfederation golang.org/x/crypto/argon2 from tailscale.com/tka golang.org/x/crypto/blake2b from golang.org/x/crypto/argon2+ golang.org/x/crypto/blake2s from github.com/tailscale/wireguard-go/device+ @@ -1023,14 +948,15 @@ tailscale.com/cmd/k8s-operator dependencies: (generated by github.com/tailscale/ golang.org/x/net/http/httpproxy from tailscale.com/net/tshttpproxy golang.org/x/net/http2 from k8s.io/apimachinery/pkg/util/net+ golang.org/x/net/http2/hpack from golang.org/x/net/http2+ - golang.org/x/net/icmp from github.com/prometheus-community/pro-bing+ + golang.org/x/net/icmp from tailscale.com/net/ping golang.org/x/net/idna from golang.org/x/net/http/httpguts+ golang.org/x/net/internal/httpcommon from golang.org/x/net/http2 + golang.org/x/net/internal/httpsfv from golang.org/x/net/http2 golang.org/x/net/internal/iana from golang.org/x/net/icmp+ - golang.org/x/net/internal/socket from golang.org/x/net/icmp+ + golang.org/x/net/internal/socket from golang.org/x/net/ipv4+ golang.org/x/net/internal/socks from golang.org/x/net/proxy - golang.org/x/net/ipv4 from github.com/prometheus-community/pro-bing+ - golang.org/x/net/ipv6 from github.com/prometheus-community/pro-bing+ + golang.org/x/net/ipv4 from github.com/tailscale/wireguard-go/conn+ + golang.org/x/net/ipv6 from github.com/tailscale/wireguard-go/conn+ golang.org/x/net/proxy from tailscale.com/net/netns D golang.org/x/net/route from tailscale.com/net/netmon+ golang.org/x/net/websocket from tailscale.com/k8s-operator/sessionrecording/ws @@ -1137,7 +1063,7 @@ tailscale.com/cmd/k8s-operator dependencies: (generated by github.com/tailscale/ crypto/sha3 from crypto/internal/fips140hash+ crypto/sha512 from crypto/ecdsa+ crypto/subtle from crypto/cipher+ - crypto/tls from github.com/prometheus-community/pro-bing+ + crypto/tls from github.com/prometheus/client_golang/prometheus/promhttp+ crypto/tls/internal/fips140tls from crypto/tls crypto/x509 from crypto/tls+ D crypto/x509/internal/macos from crypto/x509 @@ -1246,7 +1172,7 @@ tailscale.com/cmd/k8s-operator dependencies: (generated by github.com/tailscale/ mime/quotedprintable from mime/multipart net from crypto/tls+ net/http from expvar+ - net/http/httptrace from github.com/prometheus-community/pro-bing+ + net/http/httptrace from github.com/prometheus/client_golang/prometheus/promhttp+ net/http/httputil from tailscale.com/client/web+ net/http/internal from net/http+ net/http/internal/ascii from net/http+ diff --git a/cmd/k8s-operator/deploy/chart/templates/deployment.yaml b/cmd/k8s-operator/deploy/chart/templates/deployment.yaml index 0c0cb64cb..feffd03a3 100644 --- a/cmd/k8s-operator/deploy/chart/templates/deployment.yaml +++ b/cmd/k8s-operator/deploy/chart/templates/deployment.yaml @@ -146,3 +146,6 @@ spec: tolerations: {{- toYaml . | nindent 8 }} {{- end }} + {{- with .Values.operatorConfig.priorityClassName }} + priorityClassName: {{ . }} + {{- end }} diff --git a/cmd/k8s-operator/deploy/chart/values.yaml b/cmd/k8s-operator/deploy/chart/values.yaml index 8517d77aa..7cebb2ef2 100644 --- a/cmd/k8s-operator/deploy/chart/values.yaml +++ b/cmd/k8s-operator/deploy/chart/values.yaml @@ -72,6 +72,8 @@ operatorConfig: affinity: {} + priorityClassName: "" + podSecurityContext: {} securityContext: {} diff --git a/cmd/k8s-operator/deploy/crds/tailscale.com_dnsconfigs.yaml b/cmd/k8s-operator/deploy/crds/tailscale.com_dnsconfigs.yaml index a819aa651..4d6422ede 100644 --- a/cmd/k8s-operator/deploy/crds/tailscale.com_dnsconfigs.yaml +++ b/cmd/k8s-operator/deploy/crds/tailscale.com_dnsconfigs.yaml @@ -104,6 +104,884 @@ spec: description: Pod configuration. type: object properties: + affinity: + description: If specified, applies affinity rules to the pods deployed by the DNSConfig resource. + type: object + properties: + nodeAffinity: + description: Describes node affinity scheduling rules for the pod. + type: object + properties: + preferredDuringSchedulingIgnoredDuringExecution: + description: |- + The scheduler will prefer to schedule pods to nodes that satisfy + the affinity expressions specified by this field, but it may choose + a node that violates one or more of the expressions. The node that is + most preferred is the one with the greatest sum of weights, i.e. + for each node that meets all of the scheduling requirements (resource + request, requiredDuringScheduling affinity expressions, etc.), + compute a sum by iterating through the elements of this field and adding + "weight" to the sum if the node matches the corresponding matchExpressions; the + node(s) with the highest sum are the most preferred. + type: array + items: + description: |- + An empty preferred scheduling term matches all objects with implicit weight 0 + (i.e. it's a no-op). A null preferred scheduling term matches no objects (i.e. is also a no-op). + type: object + required: + - preference + - weight + properties: + preference: + description: A node selector term, associated with the corresponding weight. + type: object + properties: + matchExpressions: + description: A list of node selector requirements by node's labels. + type: array + items: + description: |- + A node selector requirement is a selector that contains values, a key, and an operator + that relates the key and values. + type: object + required: + - key + - operator + properties: + key: + description: The label key that the selector applies to. + type: string + operator: + description: |- + Represents a key's relationship to a set of values. + Valid operators are In, NotIn, Exists, DoesNotExist. Gt, and Lt. + type: string + values: + description: |- + An array of string values. If the operator is In or NotIn, + the values array must be non-empty. If the operator is Exists or DoesNotExist, + the values array must be empty. If the operator is Gt or Lt, the values + array must have a single element, which will be interpreted as an integer. + This array is replaced during a strategic merge patch. + type: array + items: + type: string + x-kubernetes-list-type: atomic + x-kubernetes-list-type: atomic + matchFields: + description: A list of node selector requirements by node's fields. + type: array + items: + description: |- + A node selector requirement is a selector that contains values, a key, and an operator + that relates the key and values. + type: object + required: + - key + - operator + properties: + key: + description: The label key that the selector applies to. + type: string + operator: + description: |- + Represents a key's relationship to a set of values. + Valid operators are In, NotIn, Exists, DoesNotExist. Gt, and Lt. + type: string + values: + description: |- + An array of string values. If the operator is In or NotIn, + the values array must be non-empty. If the operator is Exists or DoesNotExist, + the values array must be empty. If the operator is Gt or Lt, the values + array must have a single element, which will be interpreted as an integer. + This array is replaced during a strategic merge patch. + type: array + items: + type: string + x-kubernetes-list-type: atomic + x-kubernetes-list-type: atomic + x-kubernetes-map-type: atomic + weight: + description: Weight associated with matching the corresponding nodeSelectorTerm, in the range 1-100. + type: integer + format: int32 + x-kubernetes-list-type: atomic + requiredDuringSchedulingIgnoredDuringExecution: + description: |- + If the affinity requirements specified by this field are not met at + scheduling time, the pod will not be scheduled onto the node. + If the affinity requirements specified by this field cease to be met + at some point during pod execution (e.g. due to an update), the system + may or may not try to eventually evict the pod from its node. + type: object + required: + - nodeSelectorTerms + properties: + nodeSelectorTerms: + description: Required. A list of node selector terms. The terms are ORed. + type: array + items: + description: |- + A null or empty node selector term matches no objects. The requirements of + them are ANDed. + The TopologySelectorTerm type implements a subset of the NodeSelectorTerm. + type: object + properties: + matchExpressions: + description: A list of node selector requirements by node's labels. + type: array + items: + description: |- + A node selector requirement is a selector that contains values, a key, and an operator + that relates the key and values. + type: object + required: + - key + - operator + properties: + key: + description: The label key that the selector applies to. + type: string + operator: + description: |- + Represents a key's relationship to a set of values. + Valid operators are In, NotIn, Exists, DoesNotExist. Gt, and Lt. + type: string + values: + description: |- + An array of string values. If the operator is In or NotIn, + the values array must be non-empty. If the operator is Exists or DoesNotExist, + the values array must be empty. If the operator is Gt or Lt, the values + array must have a single element, which will be interpreted as an integer. + This array is replaced during a strategic merge patch. + type: array + items: + type: string + x-kubernetes-list-type: atomic + x-kubernetes-list-type: atomic + matchFields: + description: A list of node selector requirements by node's fields. + type: array + items: + description: |- + A node selector requirement is a selector that contains values, a key, and an operator + that relates the key and values. + type: object + required: + - key + - operator + properties: + key: + description: The label key that the selector applies to. + type: string + operator: + description: |- + Represents a key's relationship to a set of values. + Valid operators are In, NotIn, Exists, DoesNotExist. Gt, and Lt. + type: string + values: + description: |- + An array of string values. If the operator is In or NotIn, + the values array must be non-empty. If the operator is Exists or DoesNotExist, + the values array must be empty. If the operator is Gt or Lt, the values + array must have a single element, which will be interpreted as an integer. + This array is replaced during a strategic merge patch. + type: array + items: + type: string + x-kubernetes-list-type: atomic + x-kubernetes-list-type: atomic + x-kubernetes-map-type: atomic + x-kubernetes-list-type: atomic + x-kubernetes-map-type: atomic + podAffinity: + description: Describes pod affinity scheduling rules (e.g. co-locate this pod in the same node, zone, etc. as some other pod(s)). + type: object + properties: + preferredDuringSchedulingIgnoredDuringExecution: + description: |- + The scheduler will prefer to schedule pods to nodes that satisfy + the affinity expressions specified by this field, but it may choose + a node that violates one or more of the expressions. The node that is + most preferred is the one with the greatest sum of weights, i.e. + for each node that meets all of the scheduling requirements (resource + request, requiredDuringScheduling affinity expressions, etc.), + compute a sum by iterating through the elements of this field and adding + "weight" to the sum if the node has pods which matches the corresponding podAffinityTerm; the + node(s) with the highest sum are the most preferred. + type: array + items: + description: The weights of all of the matched WeightedPodAffinityTerm fields are added per-node to find the most preferred node(s) + type: object + required: + - podAffinityTerm + - weight + properties: + podAffinityTerm: + description: Required. A pod affinity term, associated with the corresponding weight. + type: object + required: + - topologyKey + properties: + labelSelector: + description: |- + A label query over a set of resources, in this case pods. + If it's null, this PodAffinityTerm matches with no Pods. + type: object + properties: + matchExpressions: + description: matchExpressions is a list of label selector requirements. The requirements are ANDed. + type: array + items: + description: |- + A label selector requirement is a selector that contains values, a key, and an operator that + relates the key and values. + type: object + required: + - key + - operator + properties: + key: + description: key is the label key that the selector applies to. + type: string + operator: + description: |- + operator represents a key's relationship to a set of values. + Valid operators are In, NotIn, Exists and DoesNotExist. + type: string + values: + description: |- + values is an array of string values. If the operator is In or NotIn, + the values array must be non-empty. If the operator is Exists or DoesNotExist, + the values array must be empty. This array is replaced during a strategic + merge patch. + type: array + items: + type: string + x-kubernetes-list-type: atomic + x-kubernetes-list-type: atomic + matchLabels: + description: |- + matchLabels is a map of {key,value} pairs. A single {key,value} in the matchLabels + map is equivalent to an element of matchExpressions, whose key field is "key", the + operator is "In", and the values array contains only "value". The requirements are ANDed. + type: object + additionalProperties: + type: string + x-kubernetes-map-type: atomic + matchLabelKeys: + description: |- + MatchLabelKeys is a set of pod label keys to select which pods will + be taken into consideration. The keys are used to lookup values from the + incoming pod labels, those key-value labels are merged with `labelSelector` as `key in (value)` + to select the group of existing pods which pods will be taken into consideration + for the incoming pod's pod (anti) affinity. Keys that don't exist in the incoming + pod labels will be ignored. The default value is empty. + The same key is forbidden to exist in both matchLabelKeys and labelSelector. + Also, matchLabelKeys cannot be set when labelSelector isn't set. + type: array + items: + type: string + x-kubernetes-list-type: atomic + mismatchLabelKeys: + description: |- + MismatchLabelKeys is a set of pod label keys to select which pods will + be taken into consideration. The keys are used to lookup values from the + incoming pod labels, those key-value labels are merged with `labelSelector` as `key notin (value)` + to select the group of existing pods which pods will be taken into consideration + for the incoming pod's pod (anti) affinity. Keys that don't exist in the incoming + pod labels will be ignored. The default value is empty. + The same key is forbidden to exist in both mismatchLabelKeys and labelSelector. + Also, mismatchLabelKeys cannot be set when labelSelector isn't set. + type: array + items: + type: string + x-kubernetes-list-type: atomic + namespaceSelector: + description: |- + A label query over the set of namespaces that the term applies to. + The term is applied to the union of the namespaces selected by this field + and the ones listed in the namespaces field. + null selector and null or empty namespaces list means "this pod's namespace". + An empty selector ({}) matches all namespaces. + type: object + properties: + matchExpressions: + description: matchExpressions is a list of label selector requirements. The requirements are ANDed. + type: array + items: + description: |- + A label selector requirement is a selector that contains values, a key, and an operator that + relates the key and values. + type: object + required: + - key + - operator + properties: + key: + description: key is the label key that the selector applies to. + type: string + operator: + description: |- + operator represents a key's relationship to a set of values. + Valid operators are In, NotIn, Exists and DoesNotExist. + type: string + values: + description: |- + values is an array of string values. If the operator is In or NotIn, + the values array must be non-empty. If the operator is Exists or DoesNotExist, + the values array must be empty. This array is replaced during a strategic + merge patch. + type: array + items: + type: string + x-kubernetes-list-type: atomic + x-kubernetes-list-type: atomic + matchLabels: + description: |- + matchLabels is a map of {key,value} pairs. A single {key,value} in the matchLabels + map is equivalent to an element of matchExpressions, whose key field is "key", the + operator is "In", and the values array contains only "value". The requirements are ANDed. + type: object + additionalProperties: + type: string + x-kubernetes-map-type: atomic + namespaces: + description: |- + namespaces specifies a static list of namespace names that the term applies to. + The term is applied to the union of the namespaces listed in this field + and the ones selected by namespaceSelector. + null or empty namespaces list and null namespaceSelector means "this pod's namespace". + type: array + items: + type: string + x-kubernetes-list-type: atomic + topologyKey: + description: |- + This pod should be co-located (affinity) or not co-located (anti-affinity) with the pods matching + the labelSelector in the specified namespaces, where co-located is defined as running on a node + whose value of the label with key topologyKey matches that of any node on which any of the + selected pods is running. + Empty topologyKey is not allowed. + type: string + weight: + description: |- + weight associated with matching the corresponding podAffinityTerm, + in the range 1-100. + type: integer + format: int32 + x-kubernetes-list-type: atomic + requiredDuringSchedulingIgnoredDuringExecution: + description: |- + If the affinity requirements specified by this field are not met at + scheduling time, the pod will not be scheduled onto the node. + If the affinity requirements specified by this field cease to be met + at some point during pod execution (e.g. due to a pod label update), the + system may or may not try to eventually evict the pod from its node. + When there are multiple elements, the lists of nodes corresponding to each + podAffinityTerm are intersected, i.e. all terms must be satisfied. + type: array + items: + description: |- + Defines a set of pods (namely those matching the labelSelector + relative to the given namespace(s)) that this pod should be + co-located (affinity) or not co-located (anti-affinity) with, + where co-located is defined as running on a node whose value of + the label with key matches that of any node on which + a pod of the set of pods is running + type: object + required: + - topologyKey + properties: + labelSelector: + description: |- + A label query over a set of resources, in this case pods. + If it's null, this PodAffinityTerm matches with no Pods. + type: object + properties: + matchExpressions: + description: matchExpressions is a list of label selector requirements. The requirements are ANDed. + type: array + items: + description: |- + A label selector requirement is a selector that contains values, a key, and an operator that + relates the key and values. + type: object + required: + - key + - operator + properties: + key: + description: key is the label key that the selector applies to. + type: string + operator: + description: |- + operator represents a key's relationship to a set of values. + Valid operators are In, NotIn, Exists and DoesNotExist. + type: string + values: + description: |- + values is an array of string values. If the operator is In or NotIn, + the values array must be non-empty. If the operator is Exists or DoesNotExist, + the values array must be empty. This array is replaced during a strategic + merge patch. + type: array + items: + type: string + x-kubernetes-list-type: atomic + x-kubernetes-list-type: atomic + matchLabels: + description: |- + matchLabels is a map of {key,value} pairs. A single {key,value} in the matchLabels + map is equivalent to an element of matchExpressions, whose key field is "key", the + operator is "In", and the values array contains only "value". The requirements are ANDed. + type: object + additionalProperties: + type: string + x-kubernetes-map-type: atomic + matchLabelKeys: + description: |- + MatchLabelKeys is a set of pod label keys to select which pods will + be taken into consideration. The keys are used to lookup values from the + incoming pod labels, those key-value labels are merged with `labelSelector` as `key in (value)` + to select the group of existing pods which pods will be taken into consideration + for the incoming pod's pod (anti) affinity. Keys that don't exist in the incoming + pod labels will be ignored. The default value is empty. + The same key is forbidden to exist in both matchLabelKeys and labelSelector. + Also, matchLabelKeys cannot be set when labelSelector isn't set. + type: array + items: + type: string + x-kubernetes-list-type: atomic + mismatchLabelKeys: + description: |- + MismatchLabelKeys is a set of pod label keys to select which pods will + be taken into consideration. The keys are used to lookup values from the + incoming pod labels, those key-value labels are merged with `labelSelector` as `key notin (value)` + to select the group of existing pods which pods will be taken into consideration + for the incoming pod's pod (anti) affinity. Keys that don't exist in the incoming + pod labels will be ignored. The default value is empty. + The same key is forbidden to exist in both mismatchLabelKeys and labelSelector. + Also, mismatchLabelKeys cannot be set when labelSelector isn't set. + type: array + items: + type: string + x-kubernetes-list-type: atomic + namespaceSelector: + description: |- + A label query over the set of namespaces that the term applies to. + The term is applied to the union of the namespaces selected by this field + and the ones listed in the namespaces field. + null selector and null or empty namespaces list means "this pod's namespace". + An empty selector ({}) matches all namespaces. + type: object + properties: + matchExpressions: + description: matchExpressions is a list of label selector requirements. The requirements are ANDed. + type: array + items: + description: |- + A label selector requirement is a selector that contains values, a key, and an operator that + relates the key and values. + type: object + required: + - key + - operator + properties: + key: + description: key is the label key that the selector applies to. + type: string + operator: + description: |- + operator represents a key's relationship to a set of values. + Valid operators are In, NotIn, Exists and DoesNotExist. + type: string + values: + description: |- + values is an array of string values. If the operator is In or NotIn, + the values array must be non-empty. If the operator is Exists or DoesNotExist, + the values array must be empty. This array is replaced during a strategic + merge patch. + type: array + items: + type: string + x-kubernetes-list-type: atomic + x-kubernetes-list-type: atomic + matchLabels: + description: |- + matchLabels is a map of {key,value} pairs. A single {key,value} in the matchLabels + map is equivalent to an element of matchExpressions, whose key field is "key", the + operator is "In", and the values array contains only "value". The requirements are ANDed. + type: object + additionalProperties: + type: string + x-kubernetes-map-type: atomic + namespaces: + description: |- + namespaces specifies a static list of namespace names that the term applies to. + The term is applied to the union of the namespaces listed in this field + and the ones selected by namespaceSelector. + null or empty namespaces list and null namespaceSelector means "this pod's namespace". + type: array + items: + type: string + x-kubernetes-list-type: atomic + topologyKey: + description: |- + This pod should be co-located (affinity) or not co-located (anti-affinity) with the pods matching + the labelSelector in the specified namespaces, where co-located is defined as running on a node + whose value of the label with key topologyKey matches that of any node on which any of the + selected pods is running. + Empty topologyKey is not allowed. + type: string + x-kubernetes-list-type: atomic + podAntiAffinity: + description: Describes pod anti-affinity scheduling rules (e.g. avoid putting this pod in the same node, zone, etc. as some other pod(s)). + type: object + properties: + preferredDuringSchedulingIgnoredDuringExecution: + description: |- + The scheduler will prefer to schedule pods to nodes that satisfy + the anti-affinity expressions specified by this field, but it may choose + a node that violates one or more of the expressions. The node that is + most preferred is the one with the greatest sum of weights, i.e. + for each node that meets all of the scheduling requirements (resource + request, requiredDuringScheduling anti-affinity expressions, etc.), + compute a sum by iterating through the elements of this field and subtracting + "weight" from the sum if the node has pods which matches the corresponding podAffinityTerm; the + node(s) with the highest sum are the most preferred. + type: array + items: + description: The weights of all of the matched WeightedPodAffinityTerm fields are added per-node to find the most preferred node(s) + type: object + required: + - podAffinityTerm + - weight + properties: + podAffinityTerm: + description: Required. A pod affinity term, associated with the corresponding weight. + type: object + required: + - topologyKey + properties: + labelSelector: + description: |- + A label query over a set of resources, in this case pods. + If it's null, this PodAffinityTerm matches with no Pods. + type: object + properties: + matchExpressions: + description: matchExpressions is a list of label selector requirements. The requirements are ANDed. + type: array + items: + description: |- + A label selector requirement is a selector that contains values, a key, and an operator that + relates the key and values. + type: object + required: + - key + - operator + properties: + key: + description: key is the label key that the selector applies to. + type: string + operator: + description: |- + operator represents a key's relationship to a set of values. + Valid operators are In, NotIn, Exists and DoesNotExist. + type: string + values: + description: |- + values is an array of string values. If the operator is In or NotIn, + the values array must be non-empty. If the operator is Exists or DoesNotExist, + the values array must be empty. This array is replaced during a strategic + merge patch. + type: array + items: + type: string + x-kubernetes-list-type: atomic + x-kubernetes-list-type: atomic + matchLabels: + description: |- + matchLabels is a map of {key,value} pairs. A single {key,value} in the matchLabels + map is equivalent to an element of matchExpressions, whose key field is "key", the + operator is "In", and the values array contains only "value". The requirements are ANDed. + type: object + additionalProperties: + type: string + x-kubernetes-map-type: atomic + matchLabelKeys: + description: |- + MatchLabelKeys is a set of pod label keys to select which pods will + be taken into consideration. The keys are used to lookup values from the + incoming pod labels, those key-value labels are merged with `labelSelector` as `key in (value)` + to select the group of existing pods which pods will be taken into consideration + for the incoming pod's pod (anti) affinity. Keys that don't exist in the incoming + pod labels will be ignored. The default value is empty. + The same key is forbidden to exist in both matchLabelKeys and labelSelector. + Also, matchLabelKeys cannot be set when labelSelector isn't set. + type: array + items: + type: string + x-kubernetes-list-type: atomic + mismatchLabelKeys: + description: |- + MismatchLabelKeys is a set of pod label keys to select which pods will + be taken into consideration. The keys are used to lookup values from the + incoming pod labels, those key-value labels are merged with `labelSelector` as `key notin (value)` + to select the group of existing pods which pods will be taken into consideration + for the incoming pod's pod (anti) affinity. Keys that don't exist in the incoming + pod labels will be ignored. The default value is empty. + The same key is forbidden to exist in both mismatchLabelKeys and labelSelector. + Also, mismatchLabelKeys cannot be set when labelSelector isn't set. + type: array + items: + type: string + x-kubernetes-list-type: atomic + namespaceSelector: + description: |- + A label query over the set of namespaces that the term applies to. + The term is applied to the union of the namespaces selected by this field + and the ones listed in the namespaces field. + null selector and null or empty namespaces list means "this pod's namespace". + An empty selector ({}) matches all namespaces. + type: object + properties: + matchExpressions: + description: matchExpressions is a list of label selector requirements. The requirements are ANDed. + type: array + items: + description: |- + A label selector requirement is a selector that contains values, a key, and an operator that + relates the key and values. + type: object + required: + - key + - operator + properties: + key: + description: key is the label key that the selector applies to. + type: string + operator: + description: |- + operator represents a key's relationship to a set of values. + Valid operators are In, NotIn, Exists and DoesNotExist. + type: string + values: + description: |- + values is an array of string values. If the operator is In or NotIn, + the values array must be non-empty. If the operator is Exists or DoesNotExist, + the values array must be empty. This array is replaced during a strategic + merge patch. + type: array + items: + type: string + x-kubernetes-list-type: atomic + x-kubernetes-list-type: atomic + matchLabels: + description: |- + matchLabels is a map of {key,value} pairs. A single {key,value} in the matchLabels + map is equivalent to an element of matchExpressions, whose key field is "key", the + operator is "In", and the values array contains only "value". The requirements are ANDed. + type: object + additionalProperties: + type: string + x-kubernetes-map-type: atomic + namespaces: + description: |- + namespaces specifies a static list of namespace names that the term applies to. + The term is applied to the union of the namespaces listed in this field + and the ones selected by namespaceSelector. + null or empty namespaces list and null namespaceSelector means "this pod's namespace". + type: array + items: + type: string + x-kubernetes-list-type: atomic + topologyKey: + description: |- + This pod should be co-located (affinity) or not co-located (anti-affinity) with the pods matching + the labelSelector in the specified namespaces, where co-located is defined as running on a node + whose value of the label with key topologyKey matches that of any node on which any of the + selected pods is running. + Empty topologyKey is not allowed. + type: string + weight: + description: |- + weight associated with matching the corresponding podAffinityTerm, + in the range 1-100. + type: integer + format: int32 + x-kubernetes-list-type: atomic + requiredDuringSchedulingIgnoredDuringExecution: + description: |- + If the anti-affinity requirements specified by this field are not met at + scheduling time, the pod will not be scheduled onto the node. + If the anti-affinity requirements specified by this field cease to be met + at some point during pod execution (e.g. due to a pod label update), the + system may or may not try to eventually evict the pod from its node. + When there are multiple elements, the lists of nodes corresponding to each + podAffinityTerm are intersected, i.e. all terms must be satisfied. + type: array + items: + description: |- + Defines a set of pods (namely those matching the labelSelector + relative to the given namespace(s)) that this pod should be + co-located (affinity) or not co-located (anti-affinity) with, + where co-located is defined as running on a node whose value of + the label with key matches that of any node on which + a pod of the set of pods is running + type: object + required: + - topologyKey + properties: + labelSelector: + description: |- + A label query over a set of resources, in this case pods. + If it's null, this PodAffinityTerm matches with no Pods. + type: object + properties: + matchExpressions: + description: matchExpressions is a list of label selector requirements. The requirements are ANDed. + type: array + items: + description: |- + A label selector requirement is a selector that contains values, a key, and an operator that + relates the key and values. + type: object + required: + - key + - operator + properties: + key: + description: key is the label key that the selector applies to. + type: string + operator: + description: |- + operator represents a key's relationship to a set of values. + Valid operators are In, NotIn, Exists and DoesNotExist. + type: string + values: + description: |- + values is an array of string values. If the operator is In or NotIn, + the values array must be non-empty. If the operator is Exists or DoesNotExist, + the values array must be empty. This array is replaced during a strategic + merge patch. + type: array + items: + type: string + x-kubernetes-list-type: atomic + x-kubernetes-list-type: atomic + matchLabels: + description: |- + matchLabels is a map of {key,value} pairs. A single {key,value} in the matchLabels + map is equivalent to an element of matchExpressions, whose key field is "key", the + operator is "In", and the values array contains only "value". The requirements are ANDed. + type: object + additionalProperties: + type: string + x-kubernetes-map-type: atomic + matchLabelKeys: + description: |- + MatchLabelKeys is a set of pod label keys to select which pods will + be taken into consideration. The keys are used to lookup values from the + incoming pod labels, those key-value labels are merged with `labelSelector` as `key in (value)` + to select the group of existing pods which pods will be taken into consideration + for the incoming pod's pod (anti) affinity. Keys that don't exist in the incoming + pod labels will be ignored. The default value is empty. + The same key is forbidden to exist in both matchLabelKeys and labelSelector. + Also, matchLabelKeys cannot be set when labelSelector isn't set. + type: array + items: + type: string + x-kubernetes-list-type: atomic + mismatchLabelKeys: + description: |- + MismatchLabelKeys is a set of pod label keys to select which pods will + be taken into consideration. The keys are used to lookup values from the + incoming pod labels, those key-value labels are merged with `labelSelector` as `key notin (value)` + to select the group of existing pods which pods will be taken into consideration + for the incoming pod's pod (anti) affinity. Keys that don't exist in the incoming + pod labels will be ignored. The default value is empty. + The same key is forbidden to exist in both mismatchLabelKeys and labelSelector. + Also, mismatchLabelKeys cannot be set when labelSelector isn't set. + type: array + items: + type: string + x-kubernetes-list-type: atomic + namespaceSelector: + description: |- + A label query over the set of namespaces that the term applies to. + The term is applied to the union of the namespaces selected by this field + and the ones listed in the namespaces field. + null selector and null or empty namespaces list means "this pod's namespace". + An empty selector ({}) matches all namespaces. + type: object + properties: + matchExpressions: + description: matchExpressions is a list of label selector requirements. The requirements are ANDed. + type: array + items: + description: |- + A label selector requirement is a selector that contains values, a key, and an operator that + relates the key and values. + type: object + required: + - key + - operator + properties: + key: + description: key is the label key that the selector applies to. + type: string + operator: + description: |- + operator represents a key's relationship to a set of values. + Valid operators are In, NotIn, Exists and DoesNotExist. + type: string + values: + description: |- + values is an array of string values. If the operator is In or NotIn, + the values array must be non-empty. If the operator is Exists or DoesNotExist, + the values array must be empty. This array is replaced during a strategic + merge patch. + type: array + items: + type: string + x-kubernetes-list-type: atomic + x-kubernetes-list-type: atomic + matchLabels: + description: |- + matchLabels is a map of {key,value} pairs. A single {key,value} in the matchLabels + map is equivalent to an element of matchExpressions, whose key field is "key", the + operator is "In", and the values array contains only "value". The requirements are ANDed. + type: object + additionalProperties: + type: string + x-kubernetes-map-type: atomic + namespaces: + description: |- + namespaces specifies a static list of namespace names that the term applies to. + The term is applied to the union of the namespaces listed in this field + and the ones selected by namespaceSelector. + null or empty namespaces list and null namespaceSelector means "this pod's namespace". + type: array + items: + type: string + x-kubernetes-list-type: atomic + topologyKey: + description: |- + This pod should be co-located (affinity) or not co-located (anti-affinity) with the pods matching + the labelSelector in the specified namespaces, where co-located is defined as running on a node + whose value of the label with key topologyKey matches that of any node on which any of the + selected pods is running. + Empty topologyKey is not allowed. + type: string + x-kubernetes-list-type: atomic + nodeSelector: + description: If specified, applies node selector rules to the pods deployed by the DNSConfig resource. + type: object + additionalProperties: + type: string tolerations: description: If specified, applies tolerations to the pods deployed by the DNSConfig resource. type: array diff --git a/cmd/k8s-operator/deploy/manifests/operator.yaml b/cmd/k8s-operator/deploy/manifests/operator.yaml index 597641bde..07c9f3af3 100644 --- a/cmd/k8s-operator/deploy/manifests/operator.yaml +++ b/cmd/k8s-operator/deploy/manifests/operator.yaml @@ -442,6 +442,884 @@ spec: pod: description: Pod configuration. properties: + affinity: + description: If specified, applies affinity rules to the pods deployed by the DNSConfig resource. + properties: + nodeAffinity: + description: Describes node affinity scheduling rules for the pod. + properties: + preferredDuringSchedulingIgnoredDuringExecution: + description: |- + The scheduler will prefer to schedule pods to nodes that satisfy + the affinity expressions specified by this field, but it may choose + a node that violates one or more of the expressions. The node that is + most preferred is the one with the greatest sum of weights, i.e. + for each node that meets all of the scheduling requirements (resource + request, requiredDuringScheduling affinity expressions, etc.), + compute a sum by iterating through the elements of this field and adding + "weight" to the sum if the node matches the corresponding matchExpressions; the + node(s) with the highest sum are the most preferred. + items: + description: |- + An empty preferred scheduling term matches all objects with implicit weight 0 + (i.e. it's a no-op). A null preferred scheduling term matches no objects (i.e. is also a no-op). + properties: + preference: + description: A node selector term, associated with the corresponding weight. + properties: + matchExpressions: + description: A list of node selector requirements by node's labels. + items: + description: |- + A node selector requirement is a selector that contains values, a key, and an operator + that relates the key and values. + properties: + key: + description: The label key that the selector applies to. + type: string + operator: + description: |- + Represents a key's relationship to a set of values. + Valid operators are In, NotIn, Exists, DoesNotExist. Gt, and Lt. + type: string + values: + description: |- + An array of string values. If the operator is In or NotIn, + the values array must be non-empty. If the operator is Exists or DoesNotExist, + the values array must be empty. If the operator is Gt or Lt, the values + array must have a single element, which will be interpreted as an integer. + This array is replaced during a strategic merge patch. + items: + type: string + type: array + x-kubernetes-list-type: atomic + required: + - key + - operator + type: object + type: array + x-kubernetes-list-type: atomic + matchFields: + description: A list of node selector requirements by node's fields. + items: + description: |- + A node selector requirement is a selector that contains values, a key, and an operator + that relates the key and values. + properties: + key: + description: The label key that the selector applies to. + type: string + operator: + description: |- + Represents a key's relationship to a set of values. + Valid operators are In, NotIn, Exists, DoesNotExist. Gt, and Lt. + type: string + values: + description: |- + An array of string values. If the operator is In or NotIn, + the values array must be non-empty. If the operator is Exists or DoesNotExist, + the values array must be empty. If the operator is Gt or Lt, the values + array must have a single element, which will be interpreted as an integer. + This array is replaced during a strategic merge patch. + items: + type: string + type: array + x-kubernetes-list-type: atomic + required: + - key + - operator + type: object + type: array + x-kubernetes-list-type: atomic + type: object + x-kubernetes-map-type: atomic + weight: + description: Weight associated with matching the corresponding nodeSelectorTerm, in the range 1-100. + format: int32 + type: integer + required: + - preference + - weight + type: object + type: array + x-kubernetes-list-type: atomic + requiredDuringSchedulingIgnoredDuringExecution: + description: |- + If the affinity requirements specified by this field are not met at + scheduling time, the pod will not be scheduled onto the node. + If the affinity requirements specified by this field cease to be met + at some point during pod execution (e.g. due to an update), the system + may or may not try to eventually evict the pod from its node. + properties: + nodeSelectorTerms: + description: Required. A list of node selector terms. The terms are ORed. + items: + description: |- + A null or empty node selector term matches no objects. The requirements of + them are ANDed. + The TopologySelectorTerm type implements a subset of the NodeSelectorTerm. + properties: + matchExpressions: + description: A list of node selector requirements by node's labels. + items: + description: |- + A node selector requirement is a selector that contains values, a key, and an operator + that relates the key and values. + properties: + key: + description: The label key that the selector applies to. + type: string + operator: + description: |- + Represents a key's relationship to a set of values. + Valid operators are In, NotIn, Exists, DoesNotExist. Gt, and Lt. + type: string + values: + description: |- + An array of string values. If the operator is In or NotIn, + the values array must be non-empty. If the operator is Exists or DoesNotExist, + the values array must be empty. If the operator is Gt or Lt, the values + array must have a single element, which will be interpreted as an integer. + This array is replaced during a strategic merge patch. + items: + type: string + type: array + x-kubernetes-list-type: atomic + required: + - key + - operator + type: object + type: array + x-kubernetes-list-type: atomic + matchFields: + description: A list of node selector requirements by node's fields. + items: + description: |- + A node selector requirement is a selector that contains values, a key, and an operator + that relates the key and values. + properties: + key: + description: The label key that the selector applies to. + type: string + operator: + description: |- + Represents a key's relationship to a set of values. + Valid operators are In, NotIn, Exists, DoesNotExist. Gt, and Lt. + type: string + values: + description: |- + An array of string values. If the operator is In or NotIn, + the values array must be non-empty. If the operator is Exists or DoesNotExist, + the values array must be empty. If the operator is Gt or Lt, the values + array must have a single element, which will be interpreted as an integer. + This array is replaced during a strategic merge patch. + items: + type: string + type: array + x-kubernetes-list-type: atomic + required: + - key + - operator + type: object + type: array + x-kubernetes-list-type: atomic + type: object + x-kubernetes-map-type: atomic + type: array + x-kubernetes-list-type: atomic + required: + - nodeSelectorTerms + type: object + x-kubernetes-map-type: atomic + type: object + podAffinity: + description: Describes pod affinity scheduling rules (e.g. co-locate this pod in the same node, zone, etc. as some other pod(s)). + properties: + preferredDuringSchedulingIgnoredDuringExecution: + description: |- + The scheduler will prefer to schedule pods to nodes that satisfy + the affinity expressions specified by this field, but it may choose + a node that violates one or more of the expressions. The node that is + most preferred is the one with the greatest sum of weights, i.e. + for each node that meets all of the scheduling requirements (resource + request, requiredDuringScheduling affinity expressions, etc.), + compute a sum by iterating through the elements of this field and adding + "weight" to the sum if the node has pods which matches the corresponding podAffinityTerm; the + node(s) with the highest sum are the most preferred. + items: + description: The weights of all of the matched WeightedPodAffinityTerm fields are added per-node to find the most preferred node(s) + properties: + podAffinityTerm: + description: Required. A pod affinity term, associated with the corresponding weight. + properties: + labelSelector: + description: |- + A label query over a set of resources, in this case pods. + If it's null, this PodAffinityTerm matches with no Pods. + properties: + matchExpressions: + description: matchExpressions is a list of label selector requirements. The requirements are ANDed. + items: + description: |- + A label selector requirement is a selector that contains values, a key, and an operator that + relates the key and values. + properties: + key: + description: key is the label key that the selector applies to. + type: string + operator: + description: |- + operator represents a key's relationship to a set of values. + Valid operators are In, NotIn, Exists and DoesNotExist. + type: string + values: + description: |- + values is an array of string values. If the operator is In or NotIn, + the values array must be non-empty. If the operator is Exists or DoesNotExist, + the values array must be empty. This array is replaced during a strategic + merge patch. + items: + type: string + type: array + x-kubernetes-list-type: atomic + required: + - key + - operator + type: object + type: array + x-kubernetes-list-type: atomic + matchLabels: + additionalProperties: + type: string + description: |- + matchLabels is a map of {key,value} pairs. A single {key,value} in the matchLabels + map is equivalent to an element of matchExpressions, whose key field is "key", the + operator is "In", and the values array contains only "value". The requirements are ANDed. + type: object + type: object + x-kubernetes-map-type: atomic + matchLabelKeys: + description: |- + MatchLabelKeys is a set of pod label keys to select which pods will + be taken into consideration. The keys are used to lookup values from the + incoming pod labels, those key-value labels are merged with `labelSelector` as `key in (value)` + to select the group of existing pods which pods will be taken into consideration + for the incoming pod's pod (anti) affinity. Keys that don't exist in the incoming + pod labels will be ignored. The default value is empty. + The same key is forbidden to exist in both matchLabelKeys and labelSelector. + Also, matchLabelKeys cannot be set when labelSelector isn't set. + items: + type: string + type: array + x-kubernetes-list-type: atomic + mismatchLabelKeys: + description: |- + MismatchLabelKeys is a set of pod label keys to select which pods will + be taken into consideration. The keys are used to lookup values from the + incoming pod labels, those key-value labels are merged with `labelSelector` as `key notin (value)` + to select the group of existing pods which pods will be taken into consideration + for the incoming pod's pod (anti) affinity. Keys that don't exist in the incoming + pod labels will be ignored. The default value is empty. + The same key is forbidden to exist in both mismatchLabelKeys and labelSelector. + Also, mismatchLabelKeys cannot be set when labelSelector isn't set. + items: + type: string + type: array + x-kubernetes-list-type: atomic + namespaceSelector: + description: |- + A label query over the set of namespaces that the term applies to. + The term is applied to the union of the namespaces selected by this field + and the ones listed in the namespaces field. + null selector and null or empty namespaces list means "this pod's namespace". + An empty selector ({}) matches all namespaces. + properties: + matchExpressions: + description: matchExpressions is a list of label selector requirements. The requirements are ANDed. + items: + description: |- + A label selector requirement is a selector that contains values, a key, and an operator that + relates the key and values. + properties: + key: + description: key is the label key that the selector applies to. + type: string + operator: + description: |- + operator represents a key's relationship to a set of values. + Valid operators are In, NotIn, Exists and DoesNotExist. + type: string + values: + description: |- + values is an array of string values. If the operator is In or NotIn, + the values array must be non-empty. If the operator is Exists or DoesNotExist, + the values array must be empty. This array is replaced during a strategic + merge patch. + items: + type: string + type: array + x-kubernetes-list-type: atomic + required: + - key + - operator + type: object + type: array + x-kubernetes-list-type: atomic + matchLabels: + additionalProperties: + type: string + description: |- + matchLabels is a map of {key,value} pairs. A single {key,value} in the matchLabels + map is equivalent to an element of matchExpressions, whose key field is "key", the + operator is "In", and the values array contains only "value". The requirements are ANDed. + type: object + type: object + x-kubernetes-map-type: atomic + namespaces: + description: |- + namespaces specifies a static list of namespace names that the term applies to. + The term is applied to the union of the namespaces listed in this field + and the ones selected by namespaceSelector. + null or empty namespaces list and null namespaceSelector means "this pod's namespace". + items: + type: string + type: array + x-kubernetes-list-type: atomic + topologyKey: + description: |- + This pod should be co-located (affinity) or not co-located (anti-affinity) with the pods matching + the labelSelector in the specified namespaces, where co-located is defined as running on a node + whose value of the label with key topologyKey matches that of any node on which any of the + selected pods is running. + Empty topologyKey is not allowed. + type: string + required: + - topologyKey + type: object + weight: + description: |- + weight associated with matching the corresponding podAffinityTerm, + in the range 1-100. + format: int32 + type: integer + required: + - podAffinityTerm + - weight + type: object + type: array + x-kubernetes-list-type: atomic + requiredDuringSchedulingIgnoredDuringExecution: + description: |- + If the affinity requirements specified by this field are not met at + scheduling time, the pod will not be scheduled onto the node. + If the affinity requirements specified by this field cease to be met + at some point during pod execution (e.g. due to a pod label update), the + system may or may not try to eventually evict the pod from its node. + When there are multiple elements, the lists of nodes corresponding to each + podAffinityTerm are intersected, i.e. all terms must be satisfied. + items: + description: |- + Defines a set of pods (namely those matching the labelSelector + relative to the given namespace(s)) that this pod should be + co-located (affinity) or not co-located (anti-affinity) with, + where co-located is defined as running on a node whose value of + the label with key matches that of any node on which + a pod of the set of pods is running + properties: + labelSelector: + description: |- + A label query over a set of resources, in this case pods. + If it's null, this PodAffinityTerm matches with no Pods. + properties: + matchExpressions: + description: matchExpressions is a list of label selector requirements. The requirements are ANDed. + items: + description: |- + A label selector requirement is a selector that contains values, a key, and an operator that + relates the key and values. + properties: + key: + description: key is the label key that the selector applies to. + type: string + operator: + description: |- + operator represents a key's relationship to a set of values. + Valid operators are In, NotIn, Exists and DoesNotExist. + type: string + values: + description: |- + values is an array of string values. If the operator is In or NotIn, + the values array must be non-empty. If the operator is Exists or DoesNotExist, + the values array must be empty. This array is replaced during a strategic + merge patch. + items: + type: string + type: array + x-kubernetes-list-type: atomic + required: + - key + - operator + type: object + type: array + x-kubernetes-list-type: atomic + matchLabels: + additionalProperties: + type: string + description: |- + matchLabels is a map of {key,value} pairs. A single {key,value} in the matchLabels + map is equivalent to an element of matchExpressions, whose key field is "key", the + operator is "In", and the values array contains only "value". The requirements are ANDed. + type: object + type: object + x-kubernetes-map-type: atomic + matchLabelKeys: + description: |- + MatchLabelKeys is a set of pod label keys to select which pods will + be taken into consideration. The keys are used to lookup values from the + incoming pod labels, those key-value labels are merged with `labelSelector` as `key in (value)` + to select the group of existing pods which pods will be taken into consideration + for the incoming pod's pod (anti) affinity. Keys that don't exist in the incoming + pod labels will be ignored. The default value is empty. + The same key is forbidden to exist in both matchLabelKeys and labelSelector. + Also, matchLabelKeys cannot be set when labelSelector isn't set. + items: + type: string + type: array + x-kubernetes-list-type: atomic + mismatchLabelKeys: + description: |- + MismatchLabelKeys is a set of pod label keys to select which pods will + be taken into consideration. The keys are used to lookup values from the + incoming pod labels, those key-value labels are merged with `labelSelector` as `key notin (value)` + to select the group of existing pods which pods will be taken into consideration + for the incoming pod's pod (anti) affinity. Keys that don't exist in the incoming + pod labels will be ignored. The default value is empty. + The same key is forbidden to exist in both mismatchLabelKeys and labelSelector. + Also, mismatchLabelKeys cannot be set when labelSelector isn't set. + items: + type: string + type: array + x-kubernetes-list-type: atomic + namespaceSelector: + description: |- + A label query over the set of namespaces that the term applies to. + The term is applied to the union of the namespaces selected by this field + and the ones listed in the namespaces field. + null selector and null or empty namespaces list means "this pod's namespace". + An empty selector ({}) matches all namespaces. + properties: + matchExpressions: + description: matchExpressions is a list of label selector requirements. The requirements are ANDed. + items: + description: |- + A label selector requirement is a selector that contains values, a key, and an operator that + relates the key and values. + properties: + key: + description: key is the label key that the selector applies to. + type: string + operator: + description: |- + operator represents a key's relationship to a set of values. + Valid operators are In, NotIn, Exists and DoesNotExist. + type: string + values: + description: |- + values is an array of string values. If the operator is In or NotIn, + the values array must be non-empty. If the operator is Exists or DoesNotExist, + the values array must be empty. This array is replaced during a strategic + merge patch. + items: + type: string + type: array + x-kubernetes-list-type: atomic + required: + - key + - operator + type: object + type: array + x-kubernetes-list-type: atomic + matchLabels: + additionalProperties: + type: string + description: |- + matchLabels is a map of {key,value} pairs. A single {key,value} in the matchLabels + map is equivalent to an element of matchExpressions, whose key field is "key", the + operator is "In", and the values array contains only "value". The requirements are ANDed. + type: object + type: object + x-kubernetes-map-type: atomic + namespaces: + description: |- + namespaces specifies a static list of namespace names that the term applies to. + The term is applied to the union of the namespaces listed in this field + and the ones selected by namespaceSelector. + null or empty namespaces list and null namespaceSelector means "this pod's namespace". + items: + type: string + type: array + x-kubernetes-list-type: atomic + topologyKey: + description: |- + This pod should be co-located (affinity) or not co-located (anti-affinity) with the pods matching + the labelSelector in the specified namespaces, where co-located is defined as running on a node + whose value of the label with key topologyKey matches that of any node on which any of the + selected pods is running. + Empty topologyKey is not allowed. + type: string + required: + - topologyKey + type: object + type: array + x-kubernetes-list-type: atomic + type: object + podAntiAffinity: + description: Describes pod anti-affinity scheduling rules (e.g. avoid putting this pod in the same node, zone, etc. as some other pod(s)). + properties: + preferredDuringSchedulingIgnoredDuringExecution: + description: |- + The scheduler will prefer to schedule pods to nodes that satisfy + the anti-affinity expressions specified by this field, but it may choose + a node that violates one or more of the expressions. The node that is + most preferred is the one with the greatest sum of weights, i.e. + for each node that meets all of the scheduling requirements (resource + request, requiredDuringScheduling anti-affinity expressions, etc.), + compute a sum by iterating through the elements of this field and subtracting + "weight" from the sum if the node has pods which matches the corresponding podAffinityTerm; the + node(s) with the highest sum are the most preferred. + items: + description: The weights of all of the matched WeightedPodAffinityTerm fields are added per-node to find the most preferred node(s) + properties: + podAffinityTerm: + description: Required. A pod affinity term, associated with the corresponding weight. + properties: + labelSelector: + description: |- + A label query over a set of resources, in this case pods. + If it's null, this PodAffinityTerm matches with no Pods. + properties: + matchExpressions: + description: matchExpressions is a list of label selector requirements. The requirements are ANDed. + items: + description: |- + A label selector requirement is a selector that contains values, a key, and an operator that + relates the key and values. + properties: + key: + description: key is the label key that the selector applies to. + type: string + operator: + description: |- + operator represents a key's relationship to a set of values. + Valid operators are In, NotIn, Exists and DoesNotExist. + type: string + values: + description: |- + values is an array of string values. If the operator is In or NotIn, + the values array must be non-empty. If the operator is Exists or DoesNotExist, + the values array must be empty. This array is replaced during a strategic + merge patch. + items: + type: string + type: array + x-kubernetes-list-type: atomic + required: + - key + - operator + type: object + type: array + x-kubernetes-list-type: atomic + matchLabels: + additionalProperties: + type: string + description: |- + matchLabels is a map of {key,value} pairs. A single {key,value} in the matchLabels + map is equivalent to an element of matchExpressions, whose key field is "key", the + operator is "In", and the values array contains only "value". The requirements are ANDed. + type: object + type: object + x-kubernetes-map-type: atomic + matchLabelKeys: + description: |- + MatchLabelKeys is a set of pod label keys to select which pods will + be taken into consideration. The keys are used to lookup values from the + incoming pod labels, those key-value labels are merged with `labelSelector` as `key in (value)` + to select the group of existing pods which pods will be taken into consideration + for the incoming pod's pod (anti) affinity. Keys that don't exist in the incoming + pod labels will be ignored. The default value is empty. + The same key is forbidden to exist in both matchLabelKeys and labelSelector. + Also, matchLabelKeys cannot be set when labelSelector isn't set. + items: + type: string + type: array + x-kubernetes-list-type: atomic + mismatchLabelKeys: + description: |- + MismatchLabelKeys is a set of pod label keys to select which pods will + be taken into consideration. The keys are used to lookup values from the + incoming pod labels, those key-value labels are merged with `labelSelector` as `key notin (value)` + to select the group of existing pods which pods will be taken into consideration + for the incoming pod's pod (anti) affinity. Keys that don't exist in the incoming + pod labels will be ignored. The default value is empty. + The same key is forbidden to exist in both mismatchLabelKeys and labelSelector. + Also, mismatchLabelKeys cannot be set when labelSelector isn't set. + items: + type: string + type: array + x-kubernetes-list-type: atomic + namespaceSelector: + description: |- + A label query over the set of namespaces that the term applies to. + The term is applied to the union of the namespaces selected by this field + and the ones listed in the namespaces field. + null selector and null or empty namespaces list means "this pod's namespace". + An empty selector ({}) matches all namespaces. + properties: + matchExpressions: + description: matchExpressions is a list of label selector requirements. The requirements are ANDed. + items: + description: |- + A label selector requirement is a selector that contains values, a key, and an operator that + relates the key and values. + properties: + key: + description: key is the label key that the selector applies to. + type: string + operator: + description: |- + operator represents a key's relationship to a set of values. + Valid operators are In, NotIn, Exists and DoesNotExist. + type: string + values: + description: |- + values is an array of string values. If the operator is In or NotIn, + the values array must be non-empty. If the operator is Exists or DoesNotExist, + the values array must be empty. This array is replaced during a strategic + merge patch. + items: + type: string + type: array + x-kubernetes-list-type: atomic + required: + - key + - operator + type: object + type: array + x-kubernetes-list-type: atomic + matchLabels: + additionalProperties: + type: string + description: |- + matchLabels is a map of {key,value} pairs. A single {key,value} in the matchLabels + map is equivalent to an element of matchExpressions, whose key field is "key", the + operator is "In", and the values array contains only "value". The requirements are ANDed. + type: object + type: object + x-kubernetes-map-type: atomic + namespaces: + description: |- + namespaces specifies a static list of namespace names that the term applies to. + The term is applied to the union of the namespaces listed in this field + and the ones selected by namespaceSelector. + null or empty namespaces list and null namespaceSelector means "this pod's namespace". + items: + type: string + type: array + x-kubernetes-list-type: atomic + topologyKey: + description: |- + This pod should be co-located (affinity) or not co-located (anti-affinity) with the pods matching + the labelSelector in the specified namespaces, where co-located is defined as running on a node + whose value of the label with key topologyKey matches that of any node on which any of the + selected pods is running. + Empty topologyKey is not allowed. + type: string + required: + - topologyKey + type: object + weight: + description: |- + weight associated with matching the corresponding podAffinityTerm, + in the range 1-100. + format: int32 + type: integer + required: + - podAffinityTerm + - weight + type: object + type: array + x-kubernetes-list-type: atomic + requiredDuringSchedulingIgnoredDuringExecution: + description: |- + If the anti-affinity requirements specified by this field are not met at + scheduling time, the pod will not be scheduled onto the node. + If the anti-affinity requirements specified by this field cease to be met + at some point during pod execution (e.g. due to a pod label update), the + system may or may not try to eventually evict the pod from its node. + When there are multiple elements, the lists of nodes corresponding to each + podAffinityTerm are intersected, i.e. all terms must be satisfied. + items: + description: |- + Defines a set of pods (namely those matching the labelSelector + relative to the given namespace(s)) that this pod should be + co-located (affinity) or not co-located (anti-affinity) with, + where co-located is defined as running on a node whose value of + the label with key matches that of any node on which + a pod of the set of pods is running + properties: + labelSelector: + description: |- + A label query over a set of resources, in this case pods. + If it's null, this PodAffinityTerm matches with no Pods. + properties: + matchExpressions: + description: matchExpressions is a list of label selector requirements. The requirements are ANDed. + items: + description: |- + A label selector requirement is a selector that contains values, a key, and an operator that + relates the key and values. + properties: + key: + description: key is the label key that the selector applies to. + type: string + operator: + description: |- + operator represents a key's relationship to a set of values. + Valid operators are In, NotIn, Exists and DoesNotExist. + type: string + values: + description: |- + values is an array of string values. If the operator is In or NotIn, + the values array must be non-empty. If the operator is Exists or DoesNotExist, + the values array must be empty. This array is replaced during a strategic + merge patch. + items: + type: string + type: array + x-kubernetes-list-type: atomic + required: + - key + - operator + type: object + type: array + x-kubernetes-list-type: atomic + matchLabels: + additionalProperties: + type: string + description: |- + matchLabels is a map of {key,value} pairs. A single {key,value} in the matchLabels + map is equivalent to an element of matchExpressions, whose key field is "key", the + operator is "In", and the values array contains only "value". The requirements are ANDed. + type: object + type: object + x-kubernetes-map-type: atomic + matchLabelKeys: + description: |- + MatchLabelKeys is a set of pod label keys to select which pods will + be taken into consideration. The keys are used to lookup values from the + incoming pod labels, those key-value labels are merged with `labelSelector` as `key in (value)` + to select the group of existing pods which pods will be taken into consideration + for the incoming pod's pod (anti) affinity. Keys that don't exist in the incoming + pod labels will be ignored. The default value is empty. + The same key is forbidden to exist in both matchLabelKeys and labelSelector. + Also, matchLabelKeys cannot be set when labelSelector isn't set. + items: + type: string + type: array + x-kubernetes-list-type: atomic + mismatchLabelKeys: + description: |- + MismatchLabelKeys is a set of pod label keys to select which pods will + be taken into consideration. The keys are used to lookup values from the + incoming pod labels, those key-value labels are merged with `labelSelector` as `key notin (value)` + to select the group of existing pods which pods will be taken into consideration + for the incoming pod's pod (anti) affinity. Keys that don't exist in the incoming + pod labels will be ignored. The default value is empty. + The same key is forbidden to exist in both mismatchLabelKeys and labelSelector. + Also, mismatchLabelKeys cannot be set when labelSelector isn't set. + items: + type: string + type: array + x-kubernetes-list-type: atomic + namespaceSelector: + description: |- + A label query over the set of namespaces that the term applies to. + The term is applied to the union of the namespaces selected by this field + and the ones listed in the namespaces field. + null selector and null or empty namespaces list means "this pod's namespace". + An empty selector ({}) matches all namespaces. + properties: + matchExpressions: + description: matchExpressions is a list of label selector requirements. The requirements are ANDed. + items: + description: |- + A label selector requirement is a selector that contains values, a key, and an operator that + relates the key and values. + properties: + key: + description: key is the label key that the selector applies to. + type: string + operator: + description: |- + operator represents a key's relationship to a set of values. + Valid operators are In, NotIn, Exists and DoesNotExist. + type: string + values: + description: |- + values is an array of string values. If the operator is In or NotIn, + the values array must be non-empty. If the operator is Exists or DoesNotExist, + the values array must be empty. This array is replaced during a strategic + merge patch. + items: + type: string + type: array + x-kubernetes-list-type: atomic + required: + - key + - operator + type: object + type: array + x-kubernetes-list-type: atomic + matchLabels: + additionalProperties: + type: string + description: |- + matchLabels is a map of {key,value} pairs. A single {key,value} in the matchLabels + map is equivalent to an element of matchExpressions, whose key field is "key", the + operator is "In", and the values array contains only "value". The requirements are ANDed. + type: object + type: object + x-kubernetes-map-type: atomic + namespaces: + description: |- + namespaces specifies a static list of namespace names that the term applies to. + The term is applied to the union of the namespaces listed in this field + and the ones selected by namespaceSelector. + null or empty namespaces list and null namespaceSelector means "this pod's namespace". + items: + type: string + type: array + x-kubernetes-list-type: atomic + topologyKey: + description: |- + This pod should be co-located (affinity) or not co-located (anti-affinity) with the pods matching + the labelSelector in the specified namespaces, where co-located is defined as running on a node + whose value of the label with key topologyKey matches that of any node on which any of the + selected pods is running. + Empty topologyKey is not allowed. + type: string + required: + - topologyKey + type: object + type: array + x-kubernetes-list-type: atomic + type: object + type: object + nodeSelector: + additionalProperties: + type: string + description: If specified, applies node selector rules to the pods deployed by the DNSConfig resource. + type: object tolerations: description: If specified, applies tolerations to the pods deployed by the DNSConfig resource. items: diff --git a/cmd/k8s-operator/e2e/ingress_test.go b/cmd/k8s-operator/e2e/ingress_test.go index 4eb813a77..bef24ca5a 100644 --- a/cmd/k8s-operator/e2e/ingress_test.go +++ b/cmd/k8s-operator/e2e/ingress_test.go @@ -17,9 +17,11 @@ import ( metav1 "k8s.io/apimachinery/pkg/apis/meta/v1" "sigs.k8s.io/controller-runtime/pkg/client" + "tailscale.com/client/tailscale/v2" kube "tailscale.com/k8s-operator" tsapi "tailscale.com/k8s-operator/apis/v1alpha1" "tailscale.com/kube/kubetypes" + "tailscale.com/tsnet" "tailscale.com/tstest" "tailscale.com/util/httpm" ) @@ -31,12 +33,12 @@ func TestL3Ingress(t *testing.T) { } // Apply nginx - createAndCleanup(t, kubeClient, nginxDeployment(ns, "nginx")) + nginx := nginxDeployment(ns) + createAndCleanup(t, kubeClient, nginx) // Apply service to expose it as ingress - name := generateName("test-ingress") svc := &corev1.Service{ ObjectMeta: metav1.ObjectMeta{ - Name: name, + Name: generateName("test-ingress"), Namespace: ns, Annotations: map[string]string{ "tailscale.com/expose": "true", @@ -44,7 +46,7 @@ func TestL3Ingress(t *testing.T) { }, Spec: corev1.ServiceSpec{ Selector: map[string]string{ - "app.kubernetes.io/name": "nginx", + "app.kubernetes.io/name": nginx.Name, }, Ports: []corev1.ServicePort{ { @@ -58,7 +60,7 @@ func TestL3Ingress(t *testing.T) { createAndCleanup(t, kubeClient, svc) if err := tstest.WaitFor(time.Minute, func() error { - maybeReadySvc := &corev1.Service{ObjectMeta: objectMeta(ns, name)} + maybeReadySvc := &corev1.Service{ObjectMeta: objectMeta(ns, svc.Name)} if err := get(t.Context(), kubeClient, maybeReadySvc); err != nil { return err } @@ -79,7 +81,7 @@ func TestL3Ingress(t *testing.T) { if err := kubeClient.List(t.Context(), &secrets, client.InNamespace("tailscale"), client.MatchingLabels{ - "tailscale.com/parent-resource": name, + "tailscale.com/parent-resource": svc.Name, "tailscale.com/parent-resource-ns": ns, }, ); err != nil { @@ -109,33 +111,34 @@ func TestL3HAIngress(t *testing.T) { } // Apply nginx. - createAndCleanup(t, kubeClient, nginxDeployment(ns, "nginx")) + nginx := nginxDeployment(ns) + createAndCleanup(t, kubeClient, nginx) // Create an ingress ProxyGroup. - createAndCleanup(t, kubeClient, &tsapi.ProxyGroup{ + pg := &tsapi.ProxyGroup{ ObjectMeta: metav1.ObjectMeta{ - Name: "ingress", + Name: generateName("ingress"), }, Spec: tsapi.ProxyGroupSpec{ Type: tsapi.ProxyGroupTypeIngress, }, - }) + } + createAndCleanup(t, kubeClient, pg) // Apply a Service to expose nginx via the ProxyGroup. - name := generateName("test-ingress") svc := &corev1.Service{ ObjectMeta: metav1.ObjectMeta{ - Name: name, + Name: generateName("test-ingress"), Namespace: ns, Annotations: map[string]string{ - "tailscale.com/proxy-group": "ingress", + "tailscale.com/proxy-group": pg.Name, }, }, Spec: corev1.ServiceSpec{ Type: corev1.ServiceTypeLoadBalancer, LoadBalancerClass: new("tailscale"), Selector: map[string]string{ - "app.kubernetes.io/name": "nginx", + "app.kubernetes.io/name": nginx.Name, }, Ports: []corev1.ServicePort{ { @@ -150,12 +153,12 @@ func TestL3HAIngress(t *testing.T) { var svcIPv4 string forceReconcile := triggerReconcile(t, - client.ObjectKey{Namespace: ns, Name: name}, + client.ObjectKey{Namespace: ns, Name: svc.Name}, &corev1.Service{}, 30*time.Second) // Wait for Service to be ready if err := tstest.WaitFor(5*time.Minute, func() error { - maybeReadySvc := &corev1.Service{ObjectMeta: objectMeta(ns, name)} + maybeReadySvc := &corev1.Service{ObjectMeta: objectMeta(ns, svc.Name)} forceReconcile() if err := get(t.Context(), kubeClient, maybeReadySvc); err != nil { return err @@ -186,15 +189,16 @@ func TestL7Ingress(t *testing.T) { } // Apply nginx Deployment and Service. - createAndCleanup(t, kubeClient, nginxDeployment(ns, "nginx")) + nginx := nginxDeployment(ns) + createAndCleanup(t, kubeClient, nginx) createAndCleanup(t, kubeClient, &corev1.Service{ ObjectMeta: metav1.ObjectMeta{ - Name: "nginx", + Name: nginx.Name, Namespace: ns, }, Spec: corev1.ServiceSpec{ Selector: map[string]string{ - "app.kubernetes.io/name": "nginx", + "app.kubernetes.io/name": nginx.Name, }, Ports: []corev1.ServicePort{ { @@ -206,13 +210,12 @@ func TestL7Ingress(t *testing.T) { }) // Apply Ingress to expose nginx. - name := generateName("test-ingress") - ingress := l7Ingress(ns, name, map[string]string{}) + ingress := l7Ingress(ns, nginx.Name, map[string]string{}) createAndCleanup(t, kubeClient, ingress) t.Log("Waiting for the Ingress to be ready...") - hostname, err := waitForIngressHostname(t, ns, name) + hostname, err := waitForIngressHostname(t, ns, ingress.Name) if err != nil { t.Fatalf("error waiting for Ingress hostname: %v", err) } @@ -228,15 +231,16 @@ func TestL7HAIngress(t *testing.T) { } // Apply nginx Deployment and Service. - createAndCleanup(t, kubeClient, nginxDeployment(ns, "nginx")) + nginx := nginxDeployment(ns) + createAndCleanup(t, kubeClient, nginx) createAndCleanup(t, kubeClient, &corev1.Service{ ObjectMeta: metav1.ObjectMeta{ - Name: "nginx", + Name: nginx.Name, Namespace: ns, }, Spec: corev1.ServiceSpec{ Selector: map[string]string{ - "app.kubernetes.io/name": "nginx", + "app.kubernetes.io/name": nginx.Name, }, Ports: []corev1.ServicePort{ { @@ -248,23 +252,23 @@ func TestL7HAIngress(t *testing.T) { }) // Create ProxyGroup that the Ingress will reference. - createAndCleanup(t, kubeClient, &tsapi.ProxyGroup{ + pg := &tsapi.ProxyGroup{ ObjectMeta: metav1.ObjectMeta{ - Name: "ingress", + Name: generateName("ingress"), }, Spec: tsapi.ProxyGroupSpec{ Type: tsapi.ProxyGroupTypeIngress, }, - }) + } + createAndCleanup(t, kubeClient, pg) // Apply Ingress to expose nginx. - name := generateName("test-ingress") - ingress := l7Ingress(ns, name, map[string]string{"tailscale.com/proxy-group": "ingress"}) + ingress := l7Ingress(ns, nginx.Name, map[string]string{"tailscale.com/proxy-group": pg.Name}) createAndCleanup(t, kubeClient, ingress) t.Log("Waiting for the Ingress to be ready...") - hostname, err := waitForIngressHostname(t, ns, name) + hostname, err := waitForIngressHostname(t, ns, ingress.Name) if err != nil { t.Fatalf("error waiting for Ingress hostname: %v", err) } @@ -274,7 +278,88 @@ func TestL7HAIngress(t *testing.T) { } } -func l7Ingress(namespace, name string, annotations map[string]string) *networkingv1.Ingress { +func TestL7HAIngressMultiTailnet(t *testing.T) { + if tnClient == nil || secondTNClient == nil { + t.Skip("TestL7HAIngressMultiTailnet requires a working tailnet client for a first and second tailnet") + } + + // Apply nginx Deployment and Service. + nginx := nginxDeployment(ns) + createAndCleanup(t, kubeClient, nginx) + createAndCleanup(t, kubeClient, &corev1.Service{ + ObjectMeta: metav1.ObjectMeta{ + Name: nginx.Name, + Namespace: ns, + }, + Spec: corev1.ServiceSpec{ + Selector: map[string]string{ + "app.kubernetes.io/name": nginx.Name, + }, + Ports: []corev1.ServicePort{ + { + Name: "http", + Port: 80, + }, + }, + }, + }) + + // Create Ingress ProxyGroup for each Tailnet. + firstTailnetPG := &tsapi.ProxyGroup{ + ObjectMeta: metav1.ObjectMeta{ + Name: generateName("first-tailnet"), + }, + Spec: tsapi.ProxyGroupSpec{ + Type: tsapi.ProxyGroupTypeIngress, + }, + } + createAndCleanup(t, kubeClient, firstTailnetPG) + secondTailnetPG := &tsapi.ProxyGroup{ + ObjectMeta: metav1.ObjectMeta{ + Name: generateName("second-tailnet"), + }, + Spec: tsapi.ProxyGroupSpec{ + Type: tsapi.ProxyGroupTypeIngress, + Tailnet: "second-tailnet", + }, + } + createAndCleanup(t, kubeClient, secondTailnetPG) + + if err := verifyProxyGroupTailnet(t, firstTailnetPG, tnClient); err != nil { + t.Fatalf("verifying ProxyGroup %s is registered to the correct tailnet: %v", firstTailnetPG.Name, err) + } + if err := verifyProxyGroupTailnet(t, secondTailnetPG, secondTNClient); err != nil { + t.Fatalf("verifying ProxyGroup %s is registered to the correct tailnet: %v", secondTailnetPG.Name, err) + } + + // Apply Ingress to expose nginx. + ingress := l7Ingress(ns, nginx.Name, map[string]string{ + "tailscale.com/proxy-group": secondTailnetPG.Name, + }) + createAndCleanup(t, kubeClient, ingress) + + // Check that the tailscale (VIP) Service has been created in the expected Tailnet. + svcName := "svc:" + ingress.Name + if err := tstest.WaitFor(3*time.Minute, func() error { + _, err := secondTSClient.VIPServices().Get(t.Context(), svcName) + if tailscale.IsNotFound(err) { + return fmt.Errorf("Tailscale service %q not yet in expected tailnet", svcName) + } + return err + }); err != nil { + t.Fatalf("Tailscale service %q never appeared in expected tailnet: %v", svcName, err) + } + hostname, err := waitForIngressHostname(t, ns, ingress.Name) + if err != nil { + t.Fatalf("error waiting for Ingress hostname: %v", err) + } + if err := testIngressIsReachable(t, newHTTPClient(secondTNClient), fmt.Sprintf("https://%s:443", hostname)); err != nil { + t.Fatal(err) + } +} + +func l7Ingress(namespace, svc string, annotations map[string]string) *networkingv1.Ingress { + name := generateName("test-ingress") ingress := &networkingv1.Ingress{ ObjectMeta: metav1.ObjectMeta{ Name: name, @@ -296,7 +381,7 @@ func l7Ingress(namespace, name string, annotations map[string]string) *networkin PathType: new(networkingv1.PathTypePrefix), Backend: networkingv1.IngressBackend{ Service: &networkingv1.IngressServiceBackend{ - Name: "nginx", + Name: svc, Port: networkingv1.ServiceBackendPort{ Number: 80, }, @@ -313,26 +398,27 @@ func l7Ingress(namespace, name string, annotations map[string]string) *networkin return ingress } -func nginxDeployment(namespace, name string) *appsv1.Deployment { +func nginxDeployment(namespace string) *appsv1.Deployment { + name := generateName("nginx") return &appsv1.Deployment{ ObjectMeta: metav1.ObjectMeta{ Name: name, Namespace: namespace, Labels: map[string]string{ - "app.kubernetes.io/name": "nginx", + "app.kubernetes.io/name": name, }, }, Spec: appsv1.DeploymentSpec{ Replicas: new(int32(1)), Selector: &metav1.LabelSelector{ MatchLabels: map[string]string{ - "app.kubernetes.io/name": "nginx", + "app.kubernetes.io/name": name, }, }, Template: corev1.PodTemplateSpec{ ObjectMeta: metav1.ObjectMeta{ Labels: map[string]string{ - "app.kubernetes.io/name": "nginx", + "app.kubernetes.io/name": name, }, }, Spec: corev1.PodSpec{ @@ -406,6 +492,56 @@ func testIngressIsReachable(t *testing.T, httpClient *http.Client, url string) e return nil } +// verifyProxyGroupTailnet verifies that a ProxyGroup is registered to the correct tailnet. +// This is done by getting the expected tailnet domain for the tailnet client, +// and comparing this with the actual device fqdn in the ProxyGroup state secret. +func verifyProxyGroupTailnet(t *testing.T, pg *tsapi.ProxyGroup, cl *tsnet.Server) error { + t.Helper() + // Determine the expected tailnet Magic DNS Name. + lc, err := cl.LocalClient() + if err != nil { + return err + } + status, err := lc.Status(t.Context()) + if err != nil { + return err + } + _, expectedTailnet, ok := strings.Cut(strings.TrimSuffix(status.Self.DNSName, "."), ".") + if !ok { + return fmt.Errorf("unexpected DNSName format %q", status.Self.DNSName) + } + // Read the device FQDN from the first state secret for the ProxyGroup, + // and verify that this matches the expected tailnet. + if err := tstest.WaitFor(3*time.Minute, func() error { + var secrets corev1.SecretList + if err := kubeClient.List(t.Context(), &secrets, + client.InNamespace("tailscale"), + client.MatchingLabels{ + kubetypes.LabelSecretType: kubetypes.LabelSecretTypeState, + "tailscale.com/parent-resource-type": "proxygroup", + "tailscale.com/parent-resource": pg.Name, + }, + ); err != nil { + return err + } + if len(secrets.Items) == 0 { + return fmt.Errorf("no state secrets found for ProxyGroup %q yet", pg.Name) + } + fqdn := strings.TrimSuffix(string(secrets.Items[0].Data[kubetypes.KeyDeviceFQDN]), ".") + _, tailnet, ok := strings.Cut(fqdn, ".") + if !ok { + return fmt.Errorf("ProxyGroup %q: device FQDN %q has no domain yet", pg.Name, fqdn) + } + if tailnet != expectedTailnet { + return fmt.Errorf("ProxyGroup %q on wrong tailnet: got domain %q, want %q", pg.Name, tailnet, expectedTailnet) + } + return nil + }); err != nil { + return fmt.Errorf("ProxyGroup %q not on expected tailnet: %v", pg.Name, err) + } + return nil +} + func waitForIngressHostname(t *testing.T, namespace, name string) (string, error) { t.Helper() var hostname string diff --git a/cmd/k8s-operator/e2e/main_test.go b/cmd/k8s-operator/e2e/main_test.go index 02f614014..9eab9e301 100644 --- a/cmd/k8s-operator/e2e/main_test.go +++ b/cmd/k8s-operator/e2e/main_test.go @@ -54,7 +54,7 @@ func createAndCleanup(t *testing.T, cl client.Client, obj client.Object) { t.Cleanup(func() { // Use context.Background() for cleanup, as t.Context() is cancelled // just before cleanup functions are called. - if err = cl.Delete(context.Background(), obj); err != nil { + if err := cl.Delete(context.Background(), obj); err != nil { t.Errorf("error cleaning up %s %s/%s: %s", obj.GetObjectKind().GroupVersionKind(), obj.GetNamespace(), obj.GetName(), err) } }) @@ -69,7 +69,7 @@ func createAndCleanupErr(t *testing.T, cl client.Client, obj client.Object) erro } t.Cleanup(func() { - if err = cl.Delete(context.Background(), obj); err != nil { + if err := cl.Delete(context.Background(), obj); err != nil { t.Errorf("error cleaning up %s %s/%s: %s", obj.GetObjectKind().GroupVersionKind(), obj.GetNamespace(), obj.GetName(), err) } }) diff --git a/cmd/k8s-operator/e2e/setup.go b/cmd/k8s-operator/e2e/setup.go index 55d4652d8..0d4ca80ad 100644 --- a/cmd/k8s-operator/e2e/setup.go +++ b/cmd/k8s-operator/e2e/setup.go @@ -4,6 +4,7 @@ package e2e import ( + "bytes" "context" "crypto/rand" "crypto/tls" @@ -39,6 +40,7 @@ import ( "helm.sh/helm/v3/pkg/release" "helm.sh/helm/v3/pkg/storage/driver" corev1 "k8s.io/api/core/v1" + apierrors "k8s.io/apimachinery/pkg/api/errors" metav1 "k8s.io/apimachinery/pkg/apis/meta/v1" "k8s.io/apimachinery/pkg/watch" "k8s.io/client-go/rest" @@ -69,10 +71,13 @@ const ( ) var ( - tsClient *tailscale.Client // For API calls to control. - tnClient *tsnet.Server // For testing real tailnet traffic. - restCfg *rest.Config // For constructing a client-go client if necessary. - kubeClient client.WithWatch // For k8s API calls. + tsClient *tailscale.Client // For API calls to control. + tnClient *tsnet.Server // For testing real tailnet traffic on first tailnet. + secondTSClient *tailscale.Client // For API calls to the secondary tailnet (_second_tailnet). + secondTNClient *tsnet.Server // For testing real tailnet traffic on second tailnet. + restCfg *rest.Config // For constructing a client-go client if necessary. + kubeClient client.WithWatch // For k8s API calls. + clusterLoginServer string //go:embed certs/pebble.minica.crt pebbleMiniCACert []byte @@ -157,11 +162,11 @@ func runTests(m *testing.M) (int, error) { } var ( - clusterLoginServer string // Login server from cluster Pod point of view. - clientID, clientSecret string // OAuth client for the operator to use. + clientID, clientSecret string // OAuth client for the first tailnet (for the operator to use). caPaths []string // Extra CA cert file paths to add to images. - certsDir = filepath.Join(tmp, "certs") // Directory containing extra CA certs to add to images. + certsDir = filepath.Join(tmp, "certs") // Directory containing extra CA certs to add to images. + secondClientID, secondClientSecret string // OAuth client for the second tailnet (for the operator to use). ) if *fDevcontrol { // Deploy pebble and get its certs. @@ -279,7 +284,7 @@ func runTests(m *testing.M) (int, error) { return 0, fmt.Errorf("failed to set policy file: %w", err) } - logger.Infof("ACLs configured") + logger.Info("ACLs configured for first tailnet") key, err := tsClient.Keys().CreateOAuthClient(ctx, tailscale.CreateOAuthClientRequest{ Scopes: []string{"auth_keys", "devices:core", "services"}, @@ -287,36 +292,77 @@ func runTests(m *testing.M) (int, error) { Description: "k8s-operator client for e2e tests", }) if err != nil { - return 0, fmt.Errorf("failed to marshal OAuth client creation request: %w", err) + return 0, fmt.Errorf("failed to create OAuth client for first tailnet: %w", err) } - clientID = key.ID clientSecret = key.Key + + logger.Info("OAuth credentials set for first tailnet") + + // Create second tailnet. The bootstrap credentials returned have 'all' permissions- + // they are used for administrative actions and to create a separately scoped + // Oauth client for the k8s operator. + bootstrapClient, err := createTailnet(ctx, tsClient) + if err != nil { + return 0, fmt.Errorf("failed to create second tailnet: %w", err) + } + + // Set HTTPS on second tailnet. + err = bootstrapClient.TailnetSettings().Update(ctx, tailscale.UpdateTailnetSettingsRequest{HTTPSEnabled: new(true)}) + if err != nil { + return 0, fmt.Errorf("failed to configure https for second tailnet: %w", err) + } + logger.Info("HTTPS settings configured for second tailnet") + + // Set ACLs for second tailnet. + if err = bootstrapClient.PolicyFile().Set(ctx, string(requiredACLs), ""); err != nil { + return 0, fmt.Errorf("failed to set policy file: %w", err) + } + + logger.Info("ACLs configured for second tailnet") + + // Create an OAuth client for the second tailnet to be used + // by the k8s-operator. + secondKey, err := bootstrapClient.Keys().CreateOAuthClient(ctx, tailscale.CreateOAuthClientRequest{ + Scopes: []string{"auth_keys", "devices:core", "services"}, + Tags: []string{"tag:k8s-operator"}, + Description: "k8s-operator client for e2e tests", + }) + if err != nil { + return 0, fmt.Errorf("failed to create OAuth client for second tailnet: %w", err) + } + secondClientID = secondKey.ID + secondClientSecret = secondKey.Key + + secondTSClient, err = tailscaleClientFromSecret(ctx, "http://localhost:31544", secondClientID, secondClientSecret) + if err != nil { + return 0, fmt.Errorf("failed to set up second tailnet client: %w", err) + } + } else { clientSecret = os.Getenv("TS_API_CLIENT_SECRET") if clientSecret == "" { return 0, fmt.Errorf("must use --devcontrol or set TS_API_CLIENT_SECRET to an OAuth client suitable for the operator") } - // Format is "tskey-client--". - parts := strings.Split(clientSecret, "-") - if len(parts) != 4 { - return 0, fmt.Errorf("TS_API_CLIENT_SECRET is not valid") - } - clientID = parts[2] - credentials := clientcredentials.Config{ - ClientID: clientID, - ClientSecret: clientSecret, - TokenURL: fmt.Sprintf("%s/api/v2/oauth/token", ipn.DefaultControlURL), - Scopes: []string{"auth_keys"}, - } - tk, err := credentials.Token(ctx) + clientID, err = clientIDFromSecret(clientSecret) if err != nil { - return 0, fmt.Errorf("failed to get OAuth token: %w", err) + return 0, fmt.Errorf("failed to get client id from secret: %w", err) } - // An access token will last for an hour which is plenty of time for - // the tests to run. No need for token refresh logic. - tsClient = &tailscale.Client{ - APIKey: tk.AccessToken, + tsClient, err = tailscaleClientFromSecret(ctx, ipn.DefaultControlURL, clientID, clientSecret) + if err != nil { + return 0, fmt.Errorf("failed to set up first tailnet client: %w", err) + } + secondClientSecret = os.Getenv("SECOND_TS_API_CLIENT_SECRET") + if secondClientSecret == "" { + return 0, fmt.Errorf("must use --devcontrol or set SECOND_TS_API_CLIENT_SECRET to an OAuth client suitable for the operator") + } + secondClientID, err = clientIDFromSecret(secondClientSecret) + if err != nil { + return 0, fmt.Errorf("failed to get client id from secret: %w", err) + } + secondTSClient, err = tailscaleClientFromSecret(ctx, ipn.DefaultControlURL, secondClientID, secondClientSecret) + if err != nil { + return 0, fmt.Errorf("failed to set up second tailnet client: %w", err) } } @@ -446,10 +492,16 @@ func runTests(m *testing.M) (int, error) { authKey, err := tsClient.Keys().CreateAuthKey(ctx, tailscale.CreateKeyRequest{Capabilities: caps}) if err != nil { - return 0, err + return 0, fmt.Errorf("failed to create auth key for first tailnet: %w", err) } defer tsClient.Keys().Delete(context.Background(), authKey.ID) + secondAuthKey, err := secondTSClient.Keys().CreateAuthKey(ctx, tailscale.CreateKeyRequest{Capabilities: caps}) + if err != nil { + return 0, fmt.Errorf("failed to create auth key for second tailnet: %w", err) + } + defer secondTSClient.Keys().Delete(context.Background(), secondAuthKey.ID) + tnClient = &tsnet.Server{ ControlURL: tsClient.BaseURL.String(), Hostname: "test-proxy", @@ -463,9 +515,64 @@ func runTests(m *testing.M) (int, error) { } defer tnClient.Close() + secondTNClient = &tsnet.Server{ + ControlURL: secondTSClient.BaseURL.String(), + Hostname: "test-proxy", + Ephemeral: true, + Store: &mem.Store{}, + AuthKey: secondAuthKey.Key, + } + _, err = secondTNClient.Up(ctx) + if err != nil { + return 0, err + } + defer secondTNClient.Close() + + // Create the tailnet Secret in the tailscale namespace. + secret := &corev1.Secret{ + ObjectMeta: metav1.ObjectMeta{ + Name: "second-tailnet-credentials", + Namespace: "tailscale", + }, + Data: map[string][]byte{ + "client_id": []byte(secondClientID), + "client_secret": []byte(secondClientSecret), + }, + } + if err := createOrUpdate(ctx, kubeClient, secret); err != nil { + return 0, fmt.Errorf("failed to create second-tailnet-credentials Secret: %w", err) + } + defer kubeClient.Delete(context.Background(), secret) + + // Create the Tailnet resource. + tn := &tsapi.Tailnet{ + ObjectMeta: metav1.ObjectMeta{ + Name: "second-tailnet", + }, + Spec: tsapi.TailnetSpec{ + LoginURL: clusterLoginServer, + Credentials: tsapi.TailnetCredentials{ + SecretName: "second-tailnet-credentials", + }, + }, + } + if err := createOrUpdate(ctx, kubeClient, tn); err != nil { + return 0, fmt.Errorf("failed to create second-tailnet Tailnet: %w", err) + } + defer kubeClient.Delete(context.Background(), tn) + return m.Run(), nil } +func clientIDFromSecret(clientSecret string) (string, error) { + // Format is "tskey-client--". + parts := strings.Split(clientSecret, "-") + if len(parts) != 4 { + return "", fmt.Errorf("secret is not valid") + } + return parts[2], nil +} + func upgraderOrInstaller(cfg *action.Configuration, releaseName string) helmInstallerFunc { hist := action.NewHistory(cfg) hist.Max = 1 @@ -724,3 +831,65 @@ func buildImage(ctx context.Context, dir, repo, target, tag string, extraCACerts return nil } + +func createOrUpdate(ctx context.Context, cl client.Client, obj client.Object) error { + if err := cl.Create(ctx, obj); err != nil { + if !apierrors.IsAlreadyExists(err) { + return err + } + return cl.Update(ctx, obj) + } + return nil +} + +// createTailnet creates a new tailnet and returns a tailscale.Client +// authenticated against it using the bootstrap credentials included in the +// creation response. +func createTailnet(ctx context.Context, tsClient *tailscale.Client) (*tailscale.Client, error) { + tailnetName := fmt.Sprintf("second-tailnet-%d", time.Now().Unix()) + body, err := json.Marshal(map[string]any{"displayName": tailnetName}) + if err != nil { + return nil, fmt.Errorf("failed to marshal tailnet creation request: %w", err) + } + // TODO(beckypauley): change to use a method on tailscale.Client once this is available. + req, _ := http.NewRequestWithContext(ctx, "POST", tsClient.BaseURL.String()+"/api/v2/organizations/-/tailnets", bytes.NewBuffer(body)) + req.Header.Set("Authorization", fmt.Sprintf("Bearer %s", tsClient.APIKey)) + resp, err := tsClient.HTTP.Do(req) + if err != nil { + return nil, fmt.Errorf("failed to create tailnet: %w", err) + } + defer resp.Body.Close() + if resp.StatusCode != http.StatusOK { + b, _ := io.ReadAll(resp.Body) + return nil, fmt.Errorf("HTTP %d creating tailnet: %s", resp.StatusCode, string(b)) + } + var result struct { + OauthClient struct { + ID string `json:"id"` + Secret string `json:"secret"` + } `json:"oauthClient"` + } + if err := json.NewDecoder(resp.Body).Decode(&result); err != nil { + return nil, fmt.Errorf("failed to decode response: %w", err) + } + return tailscaleClientFromSecret(ctx, tsClient.BaseURL.String(), result.OauthClient.ID, result.OauthClient.Secret) +} + +// tailscaleClientFromSecret exchanges OAuth client credentials for an access token and +// returns a tailscale.Client configured to use it. The token is valid for +// one hour, which is sufficient for the tests to run. No need for refresh logic. +func tailscaleClientFromSecret(ctx context.Context, baseURL, clientID, clientSecret string) (*tailscale.Client, error) { + cfg := clientcredentials.Config{ + ClientID: clientID, + ClientSecret: clientSecret, + TokenURL: fmt.Sprintf("%s/api/v2/oauth/token", baseURL), + } + tk, err := cfg.Token(ctx) + if err != nil { + return nil, fmt.Errorf("failed to get OAuth token for client %q: %w", clientID, err) + } + return &tailscale.Client{ + APIKey: tk.AccessToken, + BaseURL: must.Get(url.Parse(baseURL)), + }, nil +} diff --git a/cmd/k8s-operator/egress-eps.go b/cmd/k8s-operator/egress-eps.go index a248ed888..9f8510165 100644 --- a/cmd/k8s-operator/egress-eps.go +++ b/cmd/k8s-operator/egress-eps.go @@ -20,6 +20,7 @@ import ( metav1 "k8s.io/apimachinery/pkg/apis/meta/v1" "sigs.k8s.io/controller-runtime/pkg/client" "sigs.k8s.io/controller-runtime/pkg/reconcile" + "tailscale.com/kube/egressservices" ) @@ -90,7 +91,7 @@ func (er *egressEpsReconciler) Reconcile(ctx context.Context, req reconcile.Requ lg.Debugf("No egress config found, likely because ProxyGroup has not been created") return res, nil } - cfg, ok := (*cfgs)[tailnetSvc] + cfg, ok := cfgs[tailnetSvc] if !ok { lg.Infof("[unexpected] configuration for tailnet service %s not found", tailnetSvc) return res, nil diff --git a/cmd/k8s-operator/egress-services.go b/cmd/k8s-operator/egress-services.go index 4949db80a..b9a3f8eab 100644 --- a/cmd/k8s-operator/egress-services.go +++ b/cmd/k8s-operator/egress-services.go @@ -30,6 +30,7 @@ import ( "k8s.io/client-go/tools/record" "sigs.k8s.io/controller-runtime/pkg/client" "sigs.k8s.io/controller-runtime/pkg/reconcile" + tsoperator "tailscale.com/k8s-operator" tsapi "tailscale.com/k8s-operator/apis/v1alpha1" "tailscale.com/kube/egressservices" @@ -347,11 +348,11 @@ func (esr *egressSvcsReconciler) provision(ctx context.Context, proxyGroupName s return nil, false, nil } tailnetSvc := tailnetSvcName(svc) - gotCfg := (*cfgs)[tailnetSvc] + gotCfg := cfgs[tailnetSvc] wantsCfg := egressSvcCfg(svc, clusterIPSvc, esr.tsNamespace, lg) if !reflect.DeepEqual(gotCfg, wantsCfg) { lg.Debugf("updating egress services ConfigMap %s", cm.Name) - mak.Set(cfgs, tailnetSvc, wantsCfg) + mak.Set(&cfgs, tailnetSvc, wantsCfg) bs, err := json.Marshal(cfgs) if err != nil { return nil, false, fmt.Errorf("error marshalling egress services configs: %w", err) @@ -485,19 +486,19 @@ func (esr *egressSvcsReconciler) ensureEgressSvcCfgDeleted(ctx context.Context, lggr.Debugf("ConfigMap does not contain egress service configs") return nil } - cfgs := &egressservices.Configs{} - if err := json.Unmarshal(bs, cfgs); err != nil { + cfgs := egressservices.Configs{} + if err := json.Unmarshal(bs, &cfgs); err != nil { return fmt.Errorf("error unmarshalling egress services configs") } tailnetSvc := tailnetSvcName(svc) - _, ok := (*cfgs)[tailnetSvc] + _, ok := cfgs[tailnetSvc] if !ok { lggr.Debugf("ConfigMap does not contain egress service config, likely because it was already deleted") return nil } - lggr.Infof("before deleting config %+#v", *cfgs) - delete(*cfgs, tailnetSvc) - lggr.Infof("after deleting config %+#v", *cfgs) + lggr.Infof("before deleting config %+#v", cfgs) + delete(cfgs, tailnetSvc) + lggr.Infof("after deleting config %+#v", cfgs) bs, err := json.Marshal(cfgs) if err != nil { return fmt.Errorf("error marshalling egress services configs: %w", err) @@ -649,7 +650,7 @@ func isEgressSvcForProxyGroup(obj client.Object) bool { // egressSvcConfig returns a ConfigMap that contains egress services configuration for the provided ProxyGroup as well // as unmarshalled configuration from the ConfigMap. -func egressSvcsConfigs(ctx context.Context, cl client.Client, proxyGroupName, tsNamespace string) (cm *corev1.ConfigMap, cfgs *egressservices.Configs, err error) { +func egressSvcsConfigs(ctx context.Context, cl client.Client, proxyGroupName, tsNamespace string) (cm *corev1.ConfigMap, cfgs egressservices.Configs, err error) { name := pgEgressCMName(proxyGroupName) cm = &corev1.ConfigMap{ ObjectMeta: metav1.ObjectMeta{ @@ -664,9 +665,9 @@ func egressSvcsConfigs(ctx context.Context, cl client.Client, proxyGroupName, ts if err != nil { return nil, nil, fmt.Errorf("error retrieving egress services ConfigMap %s: %v", name, err) } - cfgs = &egressservices.Configs{} + cfgs = egressservices.Configs{} if len(cm.BinaryData[egressservices.KeyEgressServices]) != 0 { - if err := json.Unmarshal(cm.BinaryData[egressservices.KeyEgressServices], cfgs); err != nil { + if err := json.Unmarshal(cm.BinaryData[egressservices.KeyEgressServices], &cfgs); err != nil { return nil, nil, fmt.Errorf("error unmarshaling egress services config %v: %w", cm.BinaryData[egressservices.KeyEgressServices], err) } } diff --git a/cmd/k8s-operator/egress-services_test.go b/cmd/k8s-operator/egress-services_test.go index 8443a1573..a7dd79f7f 100644 --- a/cmd/k8s-operator/egress-services_test.go +++ b/cmd/k8s-operator/egress-services_test.go @@ -21,6 +21,7 @@ import ( "k8s.io/apimachinery/pkg/util/intstr" "sigs.k8s.io/controller-runtime/pkg/client" "sigs.k8s.io/controller-runtime/pkg/client/fake" + tsapi "tailscale.com/k8s-operator/apis/v1alpha1" "tailscale.com/kube/egressservices" "tailscale.com/tstest" @@ -284,11 +285,11 @@ func configFromCM(t *testing.T, cm *corev1.ConfigMap, svcName string) *egressser if !ok { return nil } - cfgs := &egressservices.Configs{} - if err := json.Unmarshal(cfgBs, cfgs); err != nil { + cfgs := egressservices.Configs{} + if err := json.Unmarshal(cfgBs, &cfgs); err != nil { t.Fatalf("error unmarshalling config: %v", err) } - cfg, ok := (*cfgs)[svcName] + cfg, ok := cfgs[svcName] if ok { return &cfg } diff --git a/cmd/k8s-operator/ingress-for-pg.go b/cmd/k8s-operator/ingress-for-pg.go index 37d0ed014..d6872f680 100644 --- a/cmd/k8s-operator/ingress-for-pg.go +++ b/cmd/k8s-operator/ingress-for-pg.go @@ -1081,7 +1081,7 @@ func certResourceLabels(pgName, domain string) map[string]string { return map[string]string{ kubetypes.LabelManaged: "true", labelProxyGroup: pgName, - labelDomain: domain, + labelDomain: tsoperator.TruncateLabelValue(domain), } } diff --git a/cmd/k8s-operator/metrics_resources.go b/cmd/k8s-operator/metrics_resources.go index c7c329a7e..4384f4cba 100644 --- a/cmd/k8s-operator/metrics_resources.go +++ b/cmd/k8s-operator/metrics_resources.go @@ -19,6 +19,7 @@ import ( "k8s.io/apimachinery/pkg/runtime" "k8s.io/apimachinery/pkg/types" "sigs.k8s.io/controller-runtime/pkg/client" + kube "tailscale.com/k8s-operator" tsapi "tailscale.com/k8s-operator/apis/v1alpha1" "tailscale.com/kube/kubetypes" ) @@ -227,13 +228,13 @@ func metricsResourceLabels(opts *metricsOpts) map[string]string { kubetypes.LabelManaged: "true", labelMetricsTarget: opts.proxyStsName, labelPromProxyType: opts.proxyType, - labelPromProxyParentName: opts.proxyLabels[LabelParentName], + labelPromProxyParentName: kube.TruncateLabelValue(opts.proxyLabels[LabelParentName]), } // Include namespace label for proxies created for a namespaced type. if isNamespacedProxyType(opts.proxyType) { - lbls[labelPromProxyParentNamespace] = opts.proxyLabels[LabelParentNamespace] + lbls[labelPromProxyParentNamespace] = kube.TruncateLabelValue(opts.proxyLabels[LabelParentNamespace]) } - lbls[labelPromJob] = promJobName(opts) + lbls[labelPromJob] = kube.TruncateLabelValue(promJobName(opts)) return lbls } @@ -250,11 +251,11 @@ func promJobName(opts *metricsOpts) string { func metricsSvcSelector(proxyLabels map[string]string, proxyType string) map[string]string { sel := map[string]string{ labelPromProxyType: proxyType, - labelPromProxyParentName: proxyLabels[LabelParentName], + labelPromProxyParentName: kube.TruncateLabelValue(proxyLabels[LabelParentName]), } // Include namespace label for proxies created for a namespaced type. if isNamespacedProxyType(proxyType) { - sel[labelPromProxyParentNamespace] = proxyLabels[LabelParentNamespace] + sel[labelPromProxyParentNamespace] = kube.TruncateLabelValue(proxyLabels[LabelParentNamespace]) } return sel } diff --git a/cmd/k8s-operator/nameserver.go b/cmd/k8s-operator/nameserver.go index 869e5bb26..f5565e5d3 100644 --- a/cmd/k8s-operator/nameserver.go +++ b/cmd/k8s-operator/nameserver.go @@ -190,6 +190,8 @@ func (a *NameserverReconciler) maybeProvision(ctx context.Context, tsDNSCfg *tsa } if tsDNSCfg.Spec.Nameserver.Pod != nil { dCfg.tolerations = tsDNSCfg.Spec.Nameserver.Pod.Tolerations + dCfg.affinity = tsDNSCfg.Spec.Nameserver.Pod.Affinity + dCfg.nodeSelector = tsDNSCfg.Spec.Nameserver.Pod.NodeSelector } for _, deployable := range []deployable{saDeployable, deployDeployable, svcDeployable, cmDeployable} { @@ -217,14 +219,16 @@ type deployable struct { } type deployConfig struct { - replicas int32 - imageRepo string - imageTag string - labels map[string]string - ownerRefs []metav1.OwnerReference - namespace string - clusterIP string - tolerations []corev1.Toleration + replicas int32 + imageRepo string + imageTag string + labels map[string]string + ownerRefs []metav1.OwnerReference + namespace string + clusterIP string + tolerations []corev1.Toleration + affinity *corev1.Affinity + nodeSelector map[string]string } var ( @@ -250,6 +254,8 @@ var ( d.ObjectMeta.Labels = cfg.labels d.ObjectMeta.OwnerReferences = cfg.ownerRefs d.Spec.Template.Spec.Tolerations = cfg.tolerations + d.Spec.Template.Spec.Affinity = cfg.affinity + d.Spec.Template.Spec.NodeSelector = cfg.nodeSelector updateF := func(oldD *appsv1.Deployment) { oldD.Spec = d.Spec } diff --git a/cmd/k8s-operator/nameserver_test.go b/cmd/k8s-operator/nameserver_test.go index e35c72fc0..3ec00d5ed 100644 --- a/cmd/k8s-operator/nameserver_test.go +++ b/cmd/k8s-operator/nameserver_test.go @@ -43,6 +43,9 @@ func TestNameserverReconciler(t *testing.T) { ClusterIP: "5.4.3.2", }, Pod: &tsapi.NameserverPod{ + NodeSelector: map[string]string{ + "foo": "bar", + }, Tolerations: []corev1.Toleration{ { Key: "some-key", @@ -51,6 +54,23 @@ func TestNameserverReconciler(t *testing.T) { Effect: corev1.TaintEffectNoSchedule, }, }, + Affinity: &corev1.Affinity{ + NodeAffinity: &corev1.NodeAffinity{ + RequiredDuringSchedulingIgnoredDuringExecution: &corev1.NodeSelector{ + NodeSelectorTerms: []corev1.NodeSelectorTerm{ + { + MatchExpressions: []corev1.NodeSelectorRequirement{ + { + Key: "some-key", + Operator: corev1.NodeSelectorOpIn, + Values: []string{"some-value"}, + }, + }, + }, + }, + }, + }, + }, }, }, }, @@ -97,6 +117,26 @@ func TestNameserverReconciler(t *testing.T) { Effect: corev1.TaintEffectNoSchedule, }, } + wantsDeploy.Spec.Template.Spec.Affinity = &corev1.Affinity{ + NodeAffinity: &corev1.NodeAffinity{ + RequiredDuringSchedulingIgnoredDuringExecution: &corev1.NodeSelector{ + NodeSelectorTerms: []corev1.NodeSelectorTerm{ + { + MatchExpressions: []corev1.NodeSelectorRequirement{ + { + Key: "some-key", + Operator: corev1.NodeSelectorOpIn, + Values: []string{"some-value"}, + }, + }, + }, + }, + }, + }, + } + wantsDeploy.Spec.Template.Spec.NodeSelector = map[string]string{ + "foo": "bar", + } expectEqual(t, fc, wantsDeploy) }) diff --git a/cmd/k8s-operator/operator.go b/cmd/k8s-operator/operator.go index c0ef96a68..9f9c71997 100644 --- a/cmd/k8s-operator/operator.go +++ b/cmd/k8s-operator/operator.go @@ -692,12 +692,14 @@ func runReconcilers(opts reconcilerOpts) { Watches(&rbacv1.Role{}, recorderFilter). Watches(&rbacv1.RoleBinding{}, recorderFilter). Complete(&RecorderReconciler{ - recorder: eventRecorder, - tsNamespace: opts.tailscaleNamespace, - Client: mgr.GetClient(), - log: opts.log.Named("recorder-reconciler"), - clock: tstime.DefaultClock{}, - clients: clients, + recorder: eventRecorder, + tsNamespace: opts.tailscaleNamespace, + Client: mgr.GetClient(), + log: opts.log.Named("recorder-reconciler"), + clock: tstime.DefaultClock{}, + clients: clients, + authKeyRateLimits: make(map[string]*rate.Limiter), + authKeyReissuing: make(map[string]bool), }) if err != nil { startlog.Fatalf("could not create Recorder reconciler: %v", err) diff --git a/cmd/k8s-operator/proxygroup.go b/cmd/k8s-operator/proxygroup.go index 4bd015701..9df8460b7 100644 --- a/cmd/k8s-operator/proxygroup.go +++ b/cmd/k8s-operator/proxygroup.go @@ -1160,6 +1160,9 @@ func (r *ProxyGroupReconciler) ensureStateRemovedForProxyGroup(pg *tsapi.ProxyGr gaugeIngressProxyGroupResources.Set(int64(r.ingressProxyGroups.Len())) gaugeAPIServerProxyGroupResources.Set(int64(r.apiServerProxyGroups.Len())) delete(r.authKeyRateLimits, pg.Name) + for i := range pgReplicas(pg) { + delete(r.authKeyReissuing, pgStateSecretName(pg.Name, i)) + } } func pgTailscaledConfig(pg *tsapi.ProxyGroup, loginServer string, pc *tsapi.ProxyClass, idx int32, authKey *string, staticEndpoints []netip.AddrPort, oldAdvertiseServices []string) (tailscaledConfigs, error) { diff --git a/cmd/k8s-operator/tsrecorder.go b/cmd/k8s-operator/tsrecorder.go index 881d82354..86669d212 100644 --- a/cmd/k8s-operator/tsrecorder.go +++ b/cmd/k8s-operator/tsrecorder.go @@ -14,9 +14,11 @@ import ( "strconv" "strings" "sync" + "time" "go.uber.org/zap" xslices "golang.org/x/exp/slices" + "golang.org/x/time/rate" appsv1 "k8s.io/api/apps/v1" corev1 "k8s.io/api/core/v1" rbacv1 "k8s.io/api/rbac/v1" @@ -57,14 +59,15 @@ var gaugeRecorderResources = clientmetric.NewGauge(kubetypes.MetricRecorderCount // Recorder CRs. type RecorderReconciler struct { client.Client - log *zap.SugaredLogger - recorder record.EventRecorder - clock tstime.Clock - clients ClientProvider - tsNamespace string - - mu sync.Mutex // protects following - recorders set.Slice[types.UID] // for recorders gauge + log *zap.SugaredLogger + recorder record.EventRecorder + clock tstime.Clock + clients ClientProvider + tsNamespace string + authKeyRateLimits map[string]*rate.Limiter // per-Recorder rate limiters for auth key re-issuance. + authKeyReissuing map[string]bool + mu sync.Mutex // protects following + recorders set.Slice[types.UID] // for recorders gauge } func (r *RecorderReconciler) logger(name string) *zap.SugaredLogger { @@ -164,9 +167,23 @@ func (r *RecorderReconciler) Reconcile(ctx context.Context, req reconcile.Reques func (r *RecorderReconciler) maybeProvision(ctx context.Context, tsClient tsclient.Client, tsr *tsapi.Recorder) error { logger := r.logger(tsr.Name) + var replicas int32 = 1 + if tsr.Spec.Replicas != nil { + replicas = *tsr.Spec.Replicas + } + r.mu.Lock() r.recorders.Add(tsr.UID) gaugeRecorderResources.Set(int64(r.recorders.Len())) + if _, ok := r.authKeyRateLimits[tsr.Name]; !ok { + r.authKeyRateLimits[tsr.Name] = rate.NewLimiter(rate.Every(30*time.Second), int(replicas)) + } + for replica := range replicas { + name := fmt.Sprintf("%s-%d", tsr.Name, replica) + if _, ok := r.authKeyReissuing[name]; !ok { + r.authKeyReissuing[name] = false + } + } r.mu.Unlock() if err := r.ensureAuthSecretsCreated(ctx, tsClient, tsr); err != nil { @@ -174,11 +191,6 @@ func (r *RecorderReconciler) maybeProvision(ctx context.Context, tsClient tsclie } // State Secrets are pre-created so we can use the Recorder CR as its owner ref. - var replicas int32 = 1 - if tsr.Spec.Replicas != nil { - replicas = *tsr.Spec.Replicas - } - for replica := range replicas { sec := tsrStateSecret(tsr, r.tsNamespace, replica) _, err := createOrUpdate(ctx, r.Client, r.tsNamespace, sec, func(s *corev1.Secret) { @@ -423,6 +435,10 @@ func (r *RecorderReconciler) maybeCleanup(ctx context.Context, tsr *tsapi.Record r.mu.Lock() r.recorders.Remove(tsr.UID) gaugeRecorderResources.Set(int64(r.recorders.Len())) + delete(r.authKeyRateLimits, tsr.Name) + for replica := range replicas { + delete(r.authKeyReissuing, fmt.Sprintf("%s-%d", tsr.Name, replica)) + } r.mu.Unlock() return true, nil @@ -447,28 +463,122 @@ func (r *RecorderReconciler) ensureAuthSecretsCreated(ctx context.Context, tsCli Name: fmt.Sprintf("%s-auth-%d", tsr.Name, replica), } - err := r.Get(ctx, key, &corev1.Secret{}) + existingSecret := &corev1.Secret{} + err := r.Get(ctx, key, existingSecret) switch { case err == nil: - logger.Debugf("auth Secret %q already exists", key.Name) + reissue, err := r.shouldReissueAuthKey(ctx, tsClient, tsr, replica, existingSecret) + if err != nil { + return fmt.Errorf("error checking auth key reissue for replica %d: %w", replica, err) + } + if !reissue { + logger.Debugf("auth Secret %q already exists, no reissue needed", key.Name) + continue + } + authKey, err := newAuthKey(ctx, tsClient, tags.Stringify()) + if err != nil { + return err + } + existingSecret.Data["authkey"] = []byte(authKey) + if err = r.Update(ctx, existingSecret); err != nil { + return err + } continue - case !apierrors.IsNotFound(err): + case apierrors.IsNotFound(err): + authKey, err := newAuthKey(ctx, tsClient, tags.Stringify()) + if err != nil { + return err + } + if err := r.Create(ctx, tsrAuthSecret(tsr, r.tsNamespace, authKey, replica)); err != nil { + return err + } + default: return fmt.Errorf("failed to get Secret %q: %w", key.Name, err) } - - authKey, err := newAuthKey(ctx, tsClient, tags.Stringify()) - if err != nil { - return err - } - - if err = r.Create(ctx, tsrAuthSecret(tsr, r.tsNamespace, authKey, replica)); err != nil { - return err - } } return nil } +// shouldReissueAuthKey returns true if the proxy needs a new auth key. It +// tracks in-flight reissues via authKeyReissuing to avoid duplicate API calls +// across reconciles. +func (r *RecorderReconciler) shouldReissueAuthKey(ctx context.Context, tsClient tsclient.Client, tsr *tsapi.Recorder, replica int32, authSecret *corev1.Secret) (shouldReissue bool, err error) { + stateSecret, err := r.getStateSecret(ctx, tsr.Name, replica) + if err != nil || stateSecret == nil { + return false, err + } + + stateSecretName := fmt.Sprintf("%s-%d", tsr.Name, replica) + + r.mu.Lock() + reissuing := r.authKeyReissuing[stateSecretName] + r.mu.Unlock() + + if reissuing { + _, requestStillPresent := stateSecret.Data[kubetypes.KeyReissueAuthkey] + if !requestStillPresent { + r.mu.Lock() + r.authKeyReissuing[stateSecretName] = false + r.mu.Unlock() + r.log.Debugf("auth key reissue completed for %q", stateSecretName) + return false, nil + } + r.log.Debugf("auth key already in process of re-issuance for %q, waiting", stateSecretName) + return false, nil + } + + defer func() { + r.mu.Lock() + r.authKeyReissuing[stateSecretName] = shouldReissue + r.mu.Unlock() + }() + + brokenAuthkey, ok := stateSecret.Data[kubetypes.KeyReissueAuthkey] + if !ok { + return false, nil + } + + cfgAuthKey := string(authSecret.Data["authkey"]) + empty := cfgAuthKey == "" + broken := cfgAuthKey == string(brokenAuthkey) + + if !empty && !broken { + return false, nil + } + + lim := r.authKeyRateLimits[tsr.Name] + if !lim.Allow() { + r.log.Debugf("auth key re-issuance rate limit exceeded, limit: %.2f, burst: %d, tokens: %.2f", + lim.Limit(), lim.Burst(), lim.Tokens()) + return false, fmt.Errorf("auth key re-issuance rate limit exceeded for Recorder %q, will retry with backoff", tsr.Name) + } + + r.log.Infof("Recorder replica %s failing to auth; attempting cleanup and new key", stateSecretName) + if tsID := stateSecret.Data[kubetypes.KeyDeviceID]; len(tsID) > 0 { + id := tailcfg.StableNodeID(tsID) + if err := r.ensureDeviceDeleted(ctx, tsClient, id, r.log); err != nil { + return false, err + } + } + + return true, nil +} + +func (r *RecorderReconciler) ensureDeviceDeleted(ctx context.Context, tsClient tsclient.Client, id tailcfg.StableNodeID, logger *zap.SugaredLogger) error { + logger.Debugf("deleting device %s from control", string(id)) + err := tsClient.Devices().Delete(ctx, string(id)) + switch { + case tailscale.IsNotFound(err): + logger.Debugf("device %s not found, likely because it has already been deleted from control", string(id)) + case err != nil: + return fmt.Errorf("error deleting device: %w", err) + default: + logger.Debugf("device %s deleted from control", string(id)) + } + return nil +} + func (r *RecorderReconciler) validate(ctx context.Context, tsr *tsapi.Recorder) error { if !tsr.Spec.EnableUI && tsr.Spec.Storage.S3 == nil { return errors.New("must either enable UI or use S3 storage to ensure recordings are accessible") diff --git a/cmd/k8s-operator/tsrecorder_test.go b/cmd/k8s-operator/tsrecorder_test.go index 6bd47e07b..8f189728c 100644 --- a/cmd/k8s-operator/tsrecorder_test.go +++ b/cmd/k8s-operator/tsrecorder_test.go @@ -14,6 +14,7 @@ import ( "github.com/google/go-cmp/cmp" "go.uber.org/zap" + "golang.org/x/time/rate" appsv1 "k8s.io/api/apps/v1" corev1 "k8s.io/api/core/v1" rbacv1 "k8s.io/api/rbac/v1" @@ -55,12 +56,14 @@ func TestRecorder(t *testing.T) { fr := record.NewFakeRecorder(2) cl := tstest.NewClock(tstest.ClockOpts{}) reconciler := &RecorderReconciler{ - tsNamespace: tsNamespace, - Client: fc, - clients: tsclient.NewProvider(tsClient), - recorder: fr, - log: zl.Sugar(), - clock: cl, + tsNamespace: tsNamespace, + Client: fc, + clients: tsclient.NewProvider(tsClient), + recorder: fr, + log: zl.Sugar(), + clock: cl, + authKeyRateLimits: make(map[string]*rate.Limiter), + authKeyReissuing: make(map[string]bool), } t.Run("invalid_spec_gives_an_error_condition", func(t *testing.T) { diff --git a/cmd/k8s-proxy/k8s-proxy.go b/cmd/k8s-proxy/k8s-proxy.go index 38a86a5e0..673493f58 100644 --- a/cmd/k8s-proxy/k8s-proxy.go +++ b/cmd/k8s-proxy/k8s-proxy.go @@ -31,6 +31,7 @@ import ( "k8s.io/utils/strings/slices" "tailscale.com/client/local" "tailscale.com/cmd/k8s-proxy/internal/config" + "tailscale.com/health" "tailscale.com/hostinfo" "tailscale.com/ipn" "tailscale.com/ipn/store" @@ -41,6 +42,7 @@ import ( "tailscale.com/kube/certs" healthz "tailscale.com/kube/health" "tailscale.com/kube/k8s-proxy/conf" + "tailscale.com/kube/kubeclient" "tailscale.com/kube/kubetypes" klc "tailscale.com/kube/localclient" "tailscale.com/kube/metrics" @@ -171,10 +173,31 @@ func run(logger *zap.SugaredLogger) error { // If Pod UID unset, assume we're running outside of a cluster/not managed // by the operator, so no need to set additional state keys. + var kc kubeclient.Client + var stateSecretName string if podUID != "" { if err := state.SetInitialKeys(st, podUID); err != nil { return fmt.Errorf("error setting initial state: %w", err) } + + if cfg.Parsed.State != nil { + if name, ok := strings.CutPrefix(*cfg.Parsed.State, "kube:"); ok { + stateSecretName = name + + kc, err = kubeclient.New(k8sProxyFieldManager) + if err != nil { + return err + } + + var configAuthKey string + if cfg.Parsed.AuthKey != nil { + configAuthKey = *cfg.Parsed.AuthKey + } + if err := resetState(ctx, kc, stateSecretName, podUID, configAuthKey); err != nil { + return fmt.Errorf("error resetting state: %w", err) + } + } + } } var authKey string @@ -197,23 +220,69 @@ func run(logger *zap.SugaredLogger) error { ts.Hostname = *cfg.Parsed.Hostname } - // Make sure we crash loop if Up doesn't complete in reasonable time. - upCtx, upCancel := context.WithTimeout(ctx, time.Minute) - defer upCancel() - if _, err := ts.Up(upCtx); err != nil { - return fmt.Errorf("error starting tailscale server: %w", err) - } - defer ts.Close() lc, err := ts.LocalClient() if err != nil { return fmt.Errorf("error getting local client: %w", err) } - // Setup for updating state keys. + // Make sure we crash loop if Up doesn't complete in reasonable time. + upCtx, upCancel := context.WithTimeout(ctx, 30*time.Second) + defer upCancel() + + // ts.Up() deliberately ignores NeedsLogin because it fires transiently + // during normal auth-key login. We can watch for the login-state health + // warning here though, which only fires on terminal auth failure, and + // cancel early. + go func() { + w, err := lc.WatchIPNBus(upCtx, ipn.NotifyInitialHealthState) + if err != nil { + return + } + defer w.Close() + for { + n, err := w.Next() + if err != nil { + logger.Debugf("failed to process message from ipn bus: %s", err.Error()) + return + } + if n.Health != nil { + if _, ok := n.Health.Warnings[health.LoginStateWarnable.Code]; ok { + upCancel() + return + } + } + } + }() + + if _, err := ts.Up(upCtx); err != nil { + if kc != nil && stateSecretName != "" { + return handleAuthKeyReissue(ctx, lc, kc, stateSecretName, authKey, cfgChan, logger) + } + return err + } + + defer ts.Close() + + reissueCh := make(chan struct{}, 1) if podUID != "" { group.Go(func() error { return state.KeepKeysUpdated(ctx, st, klc.New(lc)) }) + + if kc != nil && stateSecretName != "" { + needsReissue, err := checkInitialAuthState(ctx, lc) + if err != nil { + return fmt.Errorf("error checking initial auth state: %w", err) + } + if needsReissue { + logger.Info("Auth key missing or invalid after startup, requesting new key from operator") + return handleAuthKeyReissue(ctx, lc, kc, stateSecretName, authKey, cfgChan, logger) + } + + group.Go(func() error { + return monitorAuthHealth(ctx, lc, reissueCh, logger) + }) + } } if cfg.Parsed.HealthCheckEnabled.EqualBool(true) || cfg.Parsed.MetricsEnabled.EqualBool(true) { @@ -362,6 +431,8 @@ func run(logger *zap.SugaredLogger) error { } cfgLogger.Infof("Config reloaded") + case <-reissueCh: + return handleAuthKeyReissue(ctx, lc, kc, stateSecretName, authKey, cfgChan, logger) } } } diff --git a/cmd/k8s-proxy/kube.go b/cmd/k8s-proxy/kube.go new file mode 100644 index 000000000..1d9348f1a --- /dev/null +++ b/cmd/k8s-proxy/kube.go @@ -0,0 +1,161 @@ +// Copyright (c) Tailscale Inc & contributors +// SPDX-License-Identifier: BSD-3-Clause + +//go:build !plan9 + +package main + +import ( + "context" + "fmt" + "strings" + "sync" + "time" + + "go.uber.org/zap" + "tailscale.com/client/local" + "tailscale.com/health" + "tailscale.com/ipn" + "tailscale.com/kube/authkey" + "tailscale.com/kube/k8s-proxy/conf" + "tailscale.com/kube/kubeapi" + "tailscale.com/kube/kubeclient" + "tailscale.com/kube/kubetypes" + "tailscale.com/tailcfg" +) + +const k8sProxyFieldManager = "tailscale-k8s-proxy" + +// resetState clears k8s-proxy state from previous runs and sets +// initial values. This ensures the operator doesn't use stale state when a Pod +// is first recreated. +// +// It also clears the reissue_authkey marker if the operator has actioned it +// (i.e., the config now has a different auth key than what was marked for +// reissue). +func resetState(ctx context.Context, kc kubeclient.Client, stateSecretName string, podUID string, configAuthKey string) error { + existingSecret, err := kc.GetSecret(ctx, stateSecretName) + switch { + case kubeclient.IsNotFoundErr(err): + return nil + case err != nil: + return fmt.Errorf("failed to read state Secret %q to reset state: %w", stateSecretName, err) + } + + s := &kubeapi.Secret{ + Data: map[string][]byte{ + kubetypes.KeyCapVer: fmt.Appendf(nil, "%d", tailcfg.CurrentCapabilityVersion), + }, + } + if podUID != "" { + s.Data[kubetypes.KeyPodUID] = []byte(podUID) + } + + // Only clear reissue_authkey if the operator has actioned it. + brokenAuthkey, ok := existingSecret.Data[kubetypes.KeyReissueAuthkey] + if ok && configAuthKey != "" && string(brokenAuthkey) != configAuthKey { + s.Data[kubetypes.KeyReissueAuthkey] = nil + } + + return kc.StrategicMergePatchSecret(ctx, stateSecretName, s, k8sProxyFieldManager) +} + +// needsAuthKeyReissue reports whether the given backend state and health +// warnings indicate a terminal auth failure requiring a new key from the +// operator. +func needsAuthKeyReissue(backendState string, healthWarnings []string) bool { + if backendState == ipn.NeedsLogin.String() { + return true + } + loginWarnableCode := string(health.LoginStateWarnable.Code) + for _, h := range healthWarnings { + if strings.Contains(h, loginWarnableCode) { + return true + } + } + return false +} + +// checkInitialAuthState checks if the tsnet server is in an auth failure state +// immediately after coming up. Returns true if auth key reissue is needed. +func checkInitialAuthState(ctx context.Context, lc *local.Client) (bool, error) { + status, err := lc.Status(ctx) + if err != nil { + return false, fmt.Errorf("error getting status: %w", err) + } + return needsAuthKeyReissue(status.BackendState, status.Health), nil +} + +// monitorAuthHealth watches the IPN bus for auth failures and triggers reissue +// when needed. Runs until context is cancelled or auth failure is detected. +func monitorAuthHealth(ctx context.Context, lc *local.Client, reissueCh chan<- struct{}, logger *zap.SugaredLogger) error { + w, err := lc.WatchIPNBus(ctx, ipn.NotifyInitialHealthState) + if err != nil { + return fmt.Errorf("failed to watch IPN bus for auth health: %w", err) + } + defer w.Close() + + for { + if ctx.Err() != nil { + return ctx.Err() + } + n, err := w.Next() + if err != nil { + return err + } + if n.Health != nil { + if _, ok := n.Health.Warnings[health.LoginStateWarnable.Code]; ok { + logger.Info("Auth key failed to authenticate (may be expired or single-use), requesting new key from operator") + select { + case reissueCh <- struct{}{}: + case <-ctx.Done(): + } + return nil + } + } + } +} + +// handleAuthKeyReissue orchestrates the auth key reissue flow: +// 1. Disconnect from control +// 2. Set reissue marker in state Secret +// 3. Wait for operator to provide new key +// 4. Exit cleanly (Kubernetes will restart the pod with the new key) +func handleAuthKeyReissue(ctx context.Context, lc *local.Client, kc kubeclient.Client, stateSecretName string, currentAuthKey string, cfgChan <-chan *conf.Config, logger *zap.SugaredLogger) error { + if err := lc.DisconnectControl(ctx); err != nil { + return fmt.Errorf("error disconnecting from control: %w", err) + } + if err := authkey.SetReissueAuthKey(ctx, kc, stateSecretName, currentAuthKey, k8sProxyFieldManager); err != nil { + return fmt.Errorf("failed to set reissue_authkey in Kubernetes Secret: %w", err) + } + + var mu sync.Mutex + var latestAuthKey string + notify := make(chan struct{}, 1) + + // we use this go func to abstract away conf.Config from the shared function + go func() { + for cfg := range cfgChan { + if cfg.Parsed.AuthKey != nil { + mu.Lock() + latestAuthKey = *cfg.Parsed.AuthKey + mu.Unlock() + select { + case notify <- struct{}{}: + default: + } + } + } + }() + + getAuthKey := func() string { + mu.Lock() + defer mu.Unlock() + return latestAuthKey + } + clearFn := func(ctx context.Context) error { + return authkey.ClearReissueAuthKey(ctx, kc, stateSecretName, k8sProxyFieldManager) + } + + return authkey.WaitForAuthKeyReissue(ctx, currentAuthKey, 10*time.Minute, getAuthKey, clearFn, notify) +} diff --git a/cmd/k8s-proxy/kube_test.go b/cmd/k8s-proxy/kube_test.go new file mode 100644 index 000000000..c7e0f33d0 --- /dev/null +++ b/cmd/k8s-proxy/kube_test.go @@ -0,0 +1,141 @@ +// Copyright (c) Tailscale Inc & contributors +// SPDX-License-Identifier: BSD-3-Clause + +//go:build !plan9 + +package main + +import ( + "context" + "fmt" + "testing" + + "github.com/google/go-cmp/cmp" + "tailscale.com/health" + "tailscale.com/kube/kubeapi" + "tailscale.com/kube/kubeclient" + "tailscale.com/kube/kubetypes" + "tailscale.com/tailcfg" +) + +func TestResetState(t *testing.T) { + tests := []struct { + name string + existingData map[string][]byte + podUID string + configAuthKey string + wantPatched map[string][]byte + }{ + { + name: "sets_capver_and_pod_uid", + existingData: map[string][]byte{ + kubetypes.KeyDeviceID: []byte("device-123"), + kubetypes.KeyDeviceFQDN: []byte("node.tailnet"), + kubetypes.KeyDeviceIPs: []byte(`["100.64.0.1"]`), + }, + podUID: "pod-123", + configAuthKey: "new-key", + wantPatched: map[string][]byte{ + kubetypes.KeyPodUID: []byte("pod-123"), + }, + }, + { + name: "clears_reissue_marker_when_actioned", + existingData: map[string][]byte{ + kubetypes.KeyReissueAuthkey: []byte("old-key"), + }, + podUID: "pod-123", + configAuthKey: "new-key", + wantPatched: map[string][]byte{ + kubetypes.KeyPodUID: []byte("pod-123"), + kubetypes.KeyReissueAuthkey: nil, + }, + }, + { + name: "keeps_reissue_marker_when_not_actioned", + existingData: map[string][]byte{ + kubetypes.KeyReissueAuthkey: []byte("old-key"), + }, + podUID: "pod-123", + configAuthKey: "old-key", + wantPatched: map[string][]byte{ + kubetypes.KeyPodUID: []byte("pod-123"), + }, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + tt.wantPatched[kubetypes.KeyCapVer] = fmt.Appendf(nil, "%d", tailcfg.CurrentCapabilityVersion) + + var patched map[string][]byte + kc := &kubeclient.FakeClient{ + GetSecretImpl: func(ctx context.Context, name string) (*kubeapi.Secret, error) { + return &kubeapi.Secret{Data: tt.existingData}, nil + }, + StrategicMergePatchSecretImpl: func(ctx context.Context, name string, s *kubeapi.Secret, fm string) error { + patched = s.Data + return nil + }, + } + + err := resetState(context.Background(), kc, "test-secret", tt.podUID, tt.configAuthKey) + if err != nil { + t.Fatalf("resetState() error = %v", err) + } + + if diff := cmp.Diff(tt.wantPatched, patched); diff != "" { + t.Errorf("resetState() mismatch (-want +got):\n%s", diff) + } + }) + } +} + +func TestNeedsAuthKeyReissue(t *testing.T) { + loginWarnableCode := string(health.LoginStateWarnable.Code) + + tests := []struct { + name string + backendState string + health []string + want bool + }{ + { + name: "running_healthy", + backendState: "Running", + want: false, + }, + { + name: "needs_login", + backendState: "NeedsLogin", + want: true, + }, + { + name: "running_with_login_warning", + backendState: "Running", + health: []string{"warning: " + loginWarnableCode + ": you are logged out"}, + want: true, + }, + { + name: "running_with_unrelated_warning", + backendState: "Running", + health: []string{"dns-not-working"}, + want: false, + }, + { + name: "running_no_warnings", + backendState: "Running", + health: nil, + want: false, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got := needsAuthKeyReissue(tt.backendState, tt.health) + if got != tt.want { + t.Errorf("needsAuthKeyReissue() = %v, want %v", got, tt.want) + } + }) + } +} diff --git a/cmd/nardump/nardump.go b/cmd/nardump/nardump.go index c8db24cb6..38a2a6731 100644 --- a/cmd/nardump/nardump.go +++ b/cmd/nardump/nardump.go @@ -9,22 +9,13 @@ // git-pull-oss.sh having Nix available. package main -// For the format, see: -// See https://gist.github.com/jbeda/5c79d2b1434f0018d693 - import ( - "bufio" - "crypto/sha256" - "encoding/base64" - "encoding/binary" "flag" "fmt" - "io" - "io/fs" "log" "os" - "path" - "sort" + + "tailscale.com/cmd/nardump/nardump" ) var sri = flag.Bool("sri", false, "print SRI") @@ -34,167 +25,16 @@ func main() { if flag.NArg() != 1 { log.Fatal("usage: nardump ") } - arg := flag.Arg(0) - if err := os.Chdir(arg); err != nil { - log.Fatal(err) - } + fsys := os.DirFS(flag.Arg(0)) if *sri { - hash := sha256.New() - if err := writeNAR(hash, os.DirFS(".")); err != nil { + s, err := nardump.SRI(fsys) + if err != nil { log.Fatal(err) } - fmt.Printf("sha256-%s\n", base64.StdEncoding.EncodeToString(hash.Sum(nil))) + fmt.Println(s) return } - bw := bufio.NewWriter(os.Stdout) - if err := writeNAR(bw, os.DirFS(".")); err != nil { + if err := nardump.WriteNAR(os.Stdout, fsys); err != nil { log.Fatal(err) } - bw.Flush() -} - -// writeNARError is a sentinel panic type that's recovered by writeNAR -// and converted into the wrapped error. -type writeNARError struct{ err error } - -// narWriter writes NAR files. -type narWriter struct { - w io.Writer - fs fs.FS -} - -// writeNAR writes a NAR file to w from the root of fs. -func writeNAR(w io.Writer, fs fs.FS) (err error) { - defer func() { - if e := recover(); e != nil { - if we, ok := e.(writeNARError); ok { - err = we.err - return - } - panic(e) - } - }() - nw := &narWriter{w: w, fs: fs} - nw.str("nix-archive-1") - return nw.writeDir(".") -} - -func (nw *narWriter) writeDir(dirPath string) error { - ents, err := fs.ReadDir(nw.fs, dirPath) - if err != nil { - return err - } - sort.Slice(ents, func(i, j int) bool { - return ents[i].Name() < ents[j].Name() - }) - nw.str("(") - nw.str("type") - nw.str("directory") - for _, ent := range ents { - nw.str("entry") - nw.str("(") - nw.str("name") - nw.str(ent.Name()) - nw.str("node") - mode := ent.Type() - sub := path.Join(dirPath, ent.Name()) - var err error - switch { - case mode.IsDir(): - err = nw.writeDir(sub) - case mode.IsRegular(): - err = nw.writeRegular(sub) - case mode&os.ModeSymlink != 0: - err = nw.writeSymlink(sub) - default: - return fmt.Errorf("unsupported file type %v at %q", sub, mode) - } - if err != nil { - return err - } - nw.str(")") - } - nw.str(")") - return nil -} - -func (nw *narWriter) writeRegular(path string) error { - nw.str("(") - nw.str("type") - nw.str("regular") - fi, err := fs.Stat(nw.fs, path) - if err != nil { - return err - } - if fi.Mode()&0111 != 0 { - nw.str("executable") - nw.str("") - } - contents, err := fs.ReadFile(nw.fs, path) - if err != nil { - return err - } - nw.str("contents") - if err := writeBytes(nw.w, contents); err != nil { - return err - } - nw.str(")") - return nil -} - -func (nw *narWriter) writeSymlink(path string) error { - nw.str("(") - nw.str("type") - nw.str("symlink") - nw.str("target") - // broken symlinks are valid in a nar - // given we do os.chdir(dir) and os.dirfs(".") above - // readlink now resolves relative links even if they are broken - link, err := os.Readlink(path) - if err != nil { - return err - } - nw.str(link) - nw.str(")") - return nil -} - -func (nw *narWriter) str(s string) { - if err := writeString(nw.w, s); err != nil { - panic(writeNARError{err}) - } -} - -func writeString(w io.Writer, s string) error { - var buf [8]byte - binary.LittleEndian.PutUint64(buf[:], uint64(len(s))) - if _, err := w.Write(buf[:]); err != nil { - return err - } - if _, err := io.WriteString(w, s); err != nil { - return err - } - return writePad(w, len(s)) -} - -func writeBytes(w io.Writer, b []byte) error { - var buf [8]byte - binary.LittleEndian.PutUint64(buf[:], uint64(len(b))) - if _, err := w.Write(buf[:]); err != nil { - return err - } - if _, err := w.Write(b); err != nil { - return err - } - return writePad(w, len(b)) -} - -func writePad(w io.Writer, n int) error { - pad := n % 8 - if pad == 0 { - return nil - } - var zeroes [8]byte - _, err := w.Write(zeroes[:8-pad]) - return err } diff --git a/cmd/nardump/nardump/nardump.go b/cmd/nardump/nardump/nardump.go new file mode 100644 index 000000000..ab9ff1f3c --- /dev/null +++ b/cmd/nardump/nardump/nardump.go @@ -0,0 +1,193 @@ +// Copyright (c) Tailscale Inc & contributors +// SPDX-License-Identifier: BSD-3-Clause + +// Package nardump writes a NAR (Nix Archive) representation of an +// fs.FS to an io.Writer, or summarizes it as a Subresource Integrity +// hash, as used by Nix flake.nix vendor and toolchain hashes. +// +// For the format, see: +// https://gist.github.com/jbeda/5c79d2b1434f0018d693 +package nardump + +import ( + "bufio" + "crypto/sha256" + "encoding/base64" + "encoding/binary" + "fmt" + "io" + "io/fs" + "path" + "sort" +) + +// WriteNAR writes a NAR-encoded representation of fsys, rooted at +// the FS root, to w. +// +// The encoder issues many small writes; if w is not already a +// *bufio.Writer, WriteNAR wraps it in one and flushes on return so +// the caller doesn't have to. +// +// fsys must implement fs.ReadLinkFS to encode any symlinks it +// contains; os.DirFS satisfies this on Go 1.25+. +func WriteNAR(w io.Writer, fsys fs.FS) (err error) { + defer func() { + if e := recover(); e != nil { + if we, ok := e.(writeNARError); ok { + err = we.err + return + } + panic(e) + } + }() + bw, ok := w.(*bufio.Writer) + if !ok { + bw = bufio.NewWriter(w) + defer func() { + if flushErr := bw.Flush(); err == nil { + err = flushErr + } + }() + } + nw := &narWriter{w: bw, fs: fsys} + nw.str("nix-archive-1") + return nw.writeDir(".") +} + +// SRI returns the Subresource Integrity hash of the NAR encoding of +// fsys, in the form "sha256-". This is the format Nix +// expects for vendorHash and similar fields. +func SRI(fsys fs.FS) (string, error) { + h := sha256.New() + if err := WriteNAR(h, fsys); err != nil { + return "", err + } + return "sha256-" + base64.StdEncoding.EncodeToString(h.Sum(nil)), nil +} + +// writeNARError is a sentinel panic type that's recovered by +// WriteNAR and converted into the wrapped error. +type writeNARError struct{ err error } + +// narWriter writes NAR files. +type narWriter struct { + w io.Writer + fs fs.FS +} + +func (nw *narWriter) writeDir(dirPath string) error { + ents, err := fs.ReadDir(nw.fs, dirPath) + if err != nil { + return err + } + sort.Slice(ents, func(i, j int) bool { + return ents[i].Name() < ents[j].Name() + }) + nw.str("(") + nw.str("type") + nw.str("directory") + for _, ent := range ents { + nw.str("entry") + nw.str("(") + nw.str("name") + nw.str(ent.Name()) + nw.str("node") + mode := ent.Type() + sub := path.Join(dirPath, ent.Name()) + var err error + switch { + case mode.IsDir(): + err = nw.writeDir(sub) + case mode.IsRegular(): + err = nw.writeRegular(sub) + case mode&fs.ModeSymlink != 0: + err = nw.writeSymlink(sub) + default: + return fmt.Errorf("unsupported file type %v at %q", sub, mode) + } + if err != nil { + return err + } + nw.str(")") + } + nw.str(")") + return nil +} + +func (nw *narWriter) writeRegular(p string) error { + nw.str("(") + nw.str("type") + nw.str("regular") + fi, err := fs.Stat(nw.fs, p) + if err != nil { + return err + } + if fi.Mode()&0111 != 0 { + nw.str("executable") + nw.str("") + } + contents, err := fs.ReadFile(nw.fs, p) + if err != nil { + return err + } + nw.str("contents") + if err := writeBytes(nw.w, contents); err != nil { + return err + } + nw.str(")") + return nil +} + +func (nw *narWriter) writeSymlink(p string) error { + nw.str("(") + nw.str("type") + nw.str("symlink") + nw.str("target") + link, err := fs.ReadLink(nw.fs, p) + if err != nil { + return err + } + nw.str(link) + nw.str(")") + return nil +} + +func (nw *narWriter) str(s string) { + if err := writeString(nw.w, s); err != nil { + panic(writeNARError{err}) + } +} + +func writeString(w io.Writer, s string) error { + var buf [8]byte + binary.LittleEndian.PutUint64(buf[:], uint64(len(s))) + if _, err := w.Write(buf[:]); err != nil { + return err + } + if _, err := io.WriteString(w, s); err != nil { + return err + } + return writePad(w, len(s)) +} + +func writeBytes(w io.Writer, b []byte) error { + var buf [8]byte + binary.LittleEndian.PutUint64(buf[:], uint64(len(b))) + if _, err := w.Write(buf[:]); err != nil { + return err + } + if _, err := w.Write(b); err != nil { + return err + } + return writePad(w, len(b)) +} + +func writePad(w io.Writer, n int) error { + pad := n % 8 + if pad == 0 { + return nil + } + var zeroes [8]byte + _, err := w.Write(zeroes[:8-pad]) + return err +} diff --git a/cmd/nardump/nardump/nardump_test.go b/cmd/nardump/nardump/nardump_test.go new file mode 100644 index 000000000..16b690ee2 --- /dev/null +++ b/cmd/nardump/nardump/nardump_test.go @@ -0,0 +1,55 @@ +// Copyright (c) Tailscale Inc & contributors +// SPDX-License-Identifier: BSD-3-Clause + +package nardump + +import ( + "crypto/sha256" + "fmt" + "os" + "path/filepath" + "runtime" + "testing" +) + +// setupTmpdir sets up a known golden layout, covering all allowed file/folder types in a nar. +func setupTmpdir(t *testing.T) string { + t.Helper() + tmpdir := t.TempDir() + must := func(err error) { + t.Helper() + if err != nil { + t.Fatal(err) + } + } + must(os.MkdirAll(filepath.Join(tmpdir, "sub/dir"), 0755)) + must(os.Symlink("brokenfile", filepath.Join(tmpdir, "brokenlink"))) + must(os.Symlink("sub/dir", filepath.Join(tmpdir, "dirl"))) + must(os.Symlink("/abs/nonexistentdir", filepath.Join(tmpdir, "dirb"))) + f, err := os.Create(filepath.Join(tmpdir, "sub/dir/file1")) + must(err) + f.Close() + f, err = os.Create(filepath.Join(tmpdir, "file2m")) + must(err) + must(f.Truncate(2 * 1024 * 1024)) + f.Close() + must(os.Symlink("../file2m", filepath.Join(tmpdir, "sub/goodlink"))) + return tmpdir +} + +func TestWriteNAR(t *testing.T) { + if runtime.GOOS == "windows" { + // Skip test on Windows as the Nix package manager is not supported on this platform + t.Skip("nix package manager is not available on Windows") + } + dir := setupTmpdir(t) + // obtained via `nix-store --dump /tmp/... | sha256sum` of the above test dir + const expected = "727613a36f41030e93a4abf2649c3ec64a2757ccff364e3f6f7d544eb976e442" + h := sha256.New() + if err := WriteNAR(h, os.DirFS(dir)); err != nil { + t.Fatal(err) + } + if got := fmt.Sprintf("%x", h.Sum(nil)); got != expected { + t.Fatalf("sha256sum of nar: got %s, want %s", got, expected) + } +} diff --git a/cmd/nardump/nardump_test.go b/cmd/nardump/nardump_test.go deleted file mode 100644 index c1ca825e1..000000000 --- a/cmd/nardump/nardump_test.go +++ /dev/null @@ -1,52 +0,0 @@ -// Copyright (c) Tailscale Inc & contributors -// SPDX-License-Identifier: BSD-3-Clause - -package main - -import ( - "crypto/sha256" - "fmt" - "os" - "runtime" - "testing" -) - -// setupTmpdir sets up a known golden layout, covering all allowed file/folder types in a nar -func setupTmpdir(t *testing.T) string { - tmpdir := t.TempDir() - pwd, _ := os.Getwd() - os.Chdir(tmpdir) - defer os.Chdir(pwd) - os.MkdirAll("sub/dir", 0755) - os.Symlink("brokenfile", "brokenlink") - os.Symlink("sub/dir", "dirl") - os.Symlink("/abs/nonexistentdir", "dirb") - os.Create("sub/dir/file1") - f, _ := os.Create("file2m") - _ = f.Truncate(2 * 1024 * 1024) - f.Close() - os.Symlink("../file2m", "sub/goodlink") - return tmpdir -} - -func TestWriteNar(t *testing.T) { - if runtime.GOOS == "windows" { - // Skip test on Windows as the Nix package manager is not supported on this platform - t.Skip("nix package manager is not available on Windows") - } - dir := setupTmpdir(t) - t.Run("nar", func(t *testing.T) { - // obtained via `nix-store --dump /tmp/... | sha256sum` of the above test dir - expected := "727613a36f41030e93a4abf2649c3ec64a2757ccff364e3f6f7d544eb976e442" - h := sha256.New() - os.Chdir(dir) - err := writeNAR(h, os.DirFS(".")) - if err != nil { - t.Fatal(err) - } - hash := fmt.Sprintf("%x", h.Sum(nil)) - if expected != hash { - t.Fatal("sha256sum of nar not matched", hash, expected) - } - }) -} diff --git a/cmd/pgproxy/pgproxy.go b/cmd/pgproxy/pgproxy.go index ded6fa695..a138eacdc 100644 --- a/cmd/pgproxy/pgproxy.go +++ b/cmd/pgproxy/pgproxy.go @@ -291,7 +291,7 @@ func (p *proxy) serve(sessionID int64, c net.Conn) error { Certificates: p.downstreamCert, MinVersion: tls.VersionTLS12, }) - if err = uptc.HandshakeContext(ctx); err != nil { + if err = s.HandshakeContext(ctx); err != nil { p.errors.Add("client-tls", 1) return fmt.Errorf("client TLS handshake: %v", err) } diff --git a/cmd/sniproxy/sniproxy.go b/cmd/sniproxy/sniproxy.go index bd95cc113..f7ebc6aba 100644 --- a/cmd/sniproxy/sniproxy.go +++ b/cmd/sniproxy/sniproxy.go @@ -138,9 +138,9 @@ func run(ctx context.Context, ts *tsnet.Server, wgPort int, hostname string, pro } // Finally, start mainloop to configure app connector based on information - // in the netmap. - // We set the NotifyInitialNetMap flag so we will always get woken with the - // current netmap, before only being woken on changes. + // in the self node's CapMap. We set NotifyInitialNetMap so the first + // Notify carries the current self node (now via Notify.SelfChange); + // subsequent self changes wake us up too. bus, err := lc.WatchIPNBus(ctx, ipn.NotifyWatchEngineUpdates|ipn.NotifyInitialNetMap) if err != nil { log.Fatalf("watching IPN bus: %v", err) @@ -155,28 +155,30 @@ func run(ctx context.Context, ts *tsnet.Server, wgPort int, hostname string, pro log.Fatalf("reading IPN bus: %v", err) } - // NetMap contains app-connector configuration - if nm := msg.NetMap; nm != nil && nm.SelfNode.Valid() { - var c appctype.AppConnectorConfig - nmConf, err := tailcfg.UnmarshalNodeCapViewJSON[appctype.AppConnectorConfig](nm.SelfNode.CapMap(), configCapKey) - if err != nil { - log.Printf("failed to read app connector configuration from coordination server: %v", err) - } else if len(nmConf) > 0 { - c = nmConf[0] - } - - if c.AdvertiseRoutes { - if err := s.advertiseRoutesFromConfig(ctx, &c); err != nil { - log.Printf("failed to advertise routes: %v", err) - } - } - - // Backwards compatibility: combine any configuration from control with flags specified - // on the command line. This is intentionally done after we advertise any routes - // because its never correct to advertise the nodes native IP addresses. - s.mergeConfigFromFlags(&c, ports, forwards) - s.srv.Configure(&c) + self := msg.SelfChange + if self == nil { + continue } + var c appctype.AppConnectorConfig + // View() lets us reuse the existing CapView decoder. + nmConf, err := tailcfg.UnmarshalNodeCapViewJSON[appctype.AppConnectorConfig](self.View().CapMap(), configCapKey) + if err != nil { + log.Printf("failed to read app connector configuration from coordination server: %v", err) + } else if len(nmConf) > 0 { + c = nmConf[0] + } + + if c.AdvertiseRoutes { + if err := s.advertiseRoutesFromConfig(ctx, &c); err != nil { + log.Printf("failed to advertise routes: %v", err) + } + } + + // Backwards compatibility: combine any configuration from control with flags specified + // on the command line. This is intentionally done after we advertise any routes + // because its never correct to advertise the nodes native IP addresses. + s.mergeConfigFromFlags(&c, ports, forwards) + s.srv.Configure(&c) } } diff --git a/cmd/systray/systray.go b/cmd/systray/systray.go index 9dc35f142..68a339782 100644 --- a/cmd/systray/systray.go +++ b/cmd/systray/systray.go @@ -15,9 +15,11 @@ import ( ) var socket = flag.String("socket", paths.DefaultTailscaledSocket(), "path to tailscaled socket") +var theme = flag.String("theme", "dark", "color theme for Tailscale icon: dark, dark:nobg, light, light:nobg") func main() { flag.Parse() lc := &local.Client{Socket: *socket} + systray.SetTheme(*theme) new(systray.Menu).Run(lc) } diff --git a/cmd/tailscale/cli/cli.go b/cmd/tailscale/cli/cli.go index 8a2c2b9ef..16a14461c 100644 --- a/cmd/tailscale/cli/cli.go +++ b/cmd/tailscale/cli/cli.go @@ -28,6 +28,7 @@ import ( "tailscale.com/feature" "tailscale.com/paths" "tailscale.com/util/slicesx" + "tailscale.com/util/testenv" "tailscale.com/version/distro" ) @@ -92,8 +93,8 @@ var localClient = local.Client{ Socket: paths.DefaultTailscaledSocket(), } -// Run runs the CLI. The args do not include the binary name. -func Run(args []string) (err error) { +// RunWithContext runs the CLI. The args do not include the binary name. +func RunWithContext(ctx context.Context, args []string) (err error) { if runtime.GOOS == "linux" && os.Getenv("GOKRAZY_FIRST_START") == "1" && distro.Get() == distro.Gokrazy && os.Getppid() == 1 && len(args) == 0 { // We're running on gokrazy and the user did not specify 'up'. // Don't run the tailscale CLI and spam logs with usage; just exit. @@ -163,7 +164,7 @@ func Run(args []string) (err error) { return } - err = rootCmd.Run(context.Background()) + err = rootCmd.Run(ctx) if local.IsAccessDeniedError(err) && os.Getuid() != 0 && runtime.GOOS != "windows" { return fmt.Errorf("%v\n\nUse 'sudo tailscale %s'.\nTo not require root, use 'sudo tailscale set --operator=$USER' once.", err, strings.Join(args, " ")) } @@ -173,6 +174,11 @@ func Run(args []string) (err error) { return err } +// Run is equivalent to calling [RunWithContext] with the background context. +func Run(args []string) (err error) { + return RunWithContext(context.Background(), args) +} + type onceFlagValue struct { flag.Value set bool @@ -194,17 +200,39 @@ func (v *onceFlagValue) IsBoolFlag() bool { return ok && bf.IsBoolFlag() } -// noDupFlagify modifies c recursively to make all the -// flag values be wrappers that permit setting the value -// at most once. -func noDupFlagify(c *ffcli.Command) { - if c.FlagSet != nil { - c.FlagSet.VisitAll(func(f *flag.Flag) { - f.Value = &onceFlagValue{Value: f.Value} - }) +// noDupFlagify modifies c recursively to make all the flag values be +// wrappers that permit setting the value at most once. If tb is +// non-nil, the original values are restored when the test completes. +func noDupFlagify(c *ffcli.Command, tb testenv.TB) { + if tb == nil && testenv.InTest() { + return } - for _, sub := range c.Subcommands { - noDupFlagify(sub) + type restore struct { + f *flag.Flag + v flag.Value + } + var restores []restore + var walk func(*ffcli.Command) + walk = func(c *ffcli.Command) { + if c.FlagSet != nil { + c.FlagSet.VisitAll(func(f *flag.Flag) { + if tb != nil { + restores = append(restores, restore{f, f.Value}) + } + f.Value = &onceFlagValue{Value: f.Value} + }) + } + for _, sub := range c.Subcommands { + walk(sub) + } + } + walk(c) + if tb != nil { + tb.Cleanup(func() { + for _, r := range restores { + r.f.Value = r.v + } + }) } } @@ -221,7 +249,7 @@ var ( _ func() *ffcli.Command ) -func newRootCmd() *ffcli.Command { +func newRootCmd(tb ...testenv.TB) *ffcli.Command { rootfs := newFlagSet("tailscale") rootfs.Func("socket", "path to tailscaled socket", func(s string) error { localClient.Socket = s @@ -303,7 +331,11 @@ change in the future. }) ffcomplete.Inject(rootCmd, func(c *ffcli.Command) { c.LongHelp = hidden + c.LongHelp }, usageFunc) - noDupFlagify(rootCmd) + var t testenv.TB + if len(tb) > 0 { + t = tb[0] + } + noDupFlagify(rootCmd, t) return rootCmd } diff --git a/cmd/tailscale/cli/cli_test.go b/cmd/tailscale/cli/cli_test.go index f95d84695..d2df825d3 100644 --- a/cmd/tailscale/cli/cli_test.go +++ b/cmd/tailscale/cli/cli_test.go @@ -779,11 +779,43 @@ func TestPrefsFromUpArgs(t *testing.T) { wantErr: `--exit-node-allow-lan-access can only be used with --exit-node`, }, { - name: "error_tag_prefix", + name: "error_tag_bad_prefix", args: upArgsT{ - advertiseTags: "foo", + advertiseTags: "notatag:foo", + }, + wantErr: `tag: "notatag:foo": tags must start with 'tag:'`, + }, + { + name: "tag_auto_prefix", + args: upArgsFromOSArgs("linux", "--advertise-tags=foo,bar"), + want: &ipn.Prefs{ + ControlURL: ipn.DefaultControlURL, + WantRunning: true, + CorpDNS: true, + AdvertiseTags: []string{"tag:foo", "tag:bar"}, + NoSNAT: false, + NoStatefulFiltering: "true", + NetfilterMode: preftype.NetfilterOn, + AutoUpdate: ipn.AutoUpdatePrefs{ + Check: true, + }, + }, + }, + { + name: "tag_mixed_prefix", + args: upArgsFromOSArgs("linux", "--advertise-tags=tag:foo,bar"), + want: &ipn.Prefs{ + ControlURL: ipn.DefaultControlURL, + WantRunning: true, + CorpDNS: true, + AdvertiseTags: []string{"tag:foo", "tag:bar"}, + NoSNAT: false, + NoStatefulFiltering: "true", + NetfilterMode: preftype.NetfilterOn, + AutoUpdate: ipn.AutoUpdatePrefs{ + Check: true, + }, }, - wantErr: `tag: "foo": tags must start with 'tag:'`, }, { name: "error_long_hostname", @@ -1618,7 +1650,7 @@ func TestNoDups(t *testing.T) { } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - cmd := newRootCmd() + cmd := newRootCmd(t) makeQuietContinueOnError(cmd) err := cmd.Parse(tt.args) if got := fmt.Sprint(err); got != tt.want { diff --git a/cmd/tailscale/cli/configure-kube.go b/cmd/tailscale/cli/configure-kube.go index 3dcec250f..8160025c6 100644 --- a/cmd/tailscale/cli/configure-kube.go +++ b/cmd/tailscale/cli/configure-kube.go @@ -20,10 +20,8 @@ import ( "github.com/peterbourgon/ff/v3/ffcli" "k8s.io/client-go/util/homedir" "sigs.k8s.io/yaml" - "tailscale.com/ipn" "tailscale.com/ipn/ipnstate" "tailscale.com/tailcfg" - "tailscale.com/types/netmap" "tailscale.com/util/dnsname" "tailscale.com/version" ) @@ -98,12 +96,12 @@ func runConfigureKubeconfig(ctx context.Context, args []string) error { if st.BackendState != "Running" { return errors.New("Tailscale is not running") } - nm, err := getNetMap(ctx) + dnsCfg, err := getDNSConfig(ctx) if err != nil { return err } - targetFQDN, err := nodeOrServiceDNSNameFromArg(st, nm, hostOrFQDNOrIP) + targetFQDN, err := nodeOrServiceDNSNameFromArg(st, dnsCfg, hostOrFQDNOrIP) if err != nil { return err } @@ -240,14 +238,14 @@ func setKubeconfigForPeer(scheme, fqdn, filePath string) error { // nodeOrServiceDNSNameFromArg returns the PeerStatus.DNSName value from a peer // in st that matches the input arg which can be a base name, full DNS name, or // an IP. If none is found, it looks for a Tailscale Service -func nodeOrServiceDNSNameFromArg(st *ipnstate.Status, nm *netmap.NetworkMap, arg string) (string, error) { +func nodeOrServiceDNSNameFromArg(st *ipnstate.Status, dns *tailcfg.DNSConfig, arg string) (string, error) { // First check for a node DNS name. if dnsName, ok := nodeDNSNameFromArg(st, arg); ok { return dnsName, nil } // If not found, check for a Tailscale Service DNS name. - rec, ok := serviceDNSRecordFromNetMap(nm, arg) + rec, ok := serviceDNSRecordFromDNSConfig(dns, arg) if !ok { return "", fmt.Errorf("no peer found for %q", arg) } @@ -269,25 +267,13 @@ func nodeOrServiceDNSNameFromArg(st *ipnstate.Status, nm *netmap.NetworkMap, arg return "", fmt.Errorf("%q is in MagicDNS, but is not currently reachable on any known peer", arg) } -func getNetMap(ctx context.Context) (*netmap.NetworkMap, error) { +func getDNSConfig(ctx context.Context) (*tailcfg.DNSConfig, error) { ctx, cancel := context.WithTimeout(ctx, 5*time.Second) defer cancel() - - watcher, err := localClient.WatchIPNBus(ctx, ipn.NotifyInitialNetMap) - if err != nil { - return nil, err - } - defer watcher.Close() - - n, err := watcher.Next() - if err != nil { - return nil, err - } - - return n.NetMap, nil + return localClient.DNSConfig(ctx) } -func serviceDNSRecordFromNetMap(nm *netmap.NetworkMap, arg string) (rec tailcfg.DNSRecord, ok bool) { +func serviceDNSRecordFromDNSConfig(dns *tailcfg.DNSConfig, arg string) (rec tailcfg.DNSRecord, ok bool) { argIP, _ := netip.ParseAddr(arg) argFQDN, err := dnsname.ToFQDN(arg) argFQDNValid := err == nil @@ -295,7 +281,7 @@ func serviceDNSRecordFromNetMap(nm *netmap.NetworkMap, arg string) (rec tailcfg. return rec, false } - for _, rec := range nm.DNS.ExtraRecords { + for _, rec := range dns.ExtraRecords { if argIP.IsValid() { recIP, _ := netip.ParseAddr(rec.Value) if recIP == argIP { diff --git a/cmd/tailscale/cli/configure_linux.go b/cmd/tailscale/cli/configure_linux.go index 9ba3b8e87..da0444908 100644 --- a/cmd/tailscale/cli/configure_linux.go +++ b/cmd/tailscale/cli/configure_linux.go @@ -18,7 +18,7 @@ func init() { maybeSystrayCmd = systrayConfigCmd } -var systrayArgs struct { +var configSystrayArgs struct { initSystem string installStartup bool } @@ -32,7 +32,7 @@ func systrayConfigCmd() *ffcli.Command { Exec: configureSystray, FlagSet: (func() *flag.FlagSet { fs := newFlagSet("systray") - fs.StringVar(&systrayArgs.initSystem, "enable-startup", "", + fs.StringVar(&configSystrayArgs.initSystem, "enable-startup", "", "Install startup script for init system. Currently supported systems are [systemd, freedesktop].") return fs })(), @@ -40,8 +40,8 @@ func systrayConfigCmd() *ffcli.Command { } func configureSystray(_ context.Context, _ []string) error { - if systrayArgs.initSystem != "" { - if err := systray.InstallStartupScript(systrayArgs.initSystem); err != nil { + if configSystrayArgs.initSystem != "" { + if err := systray.InstallStartupScript(configSystrayArgs.initSystem); err != nil { fmt.Printf("%s\n\n", err.Error()) return flag.ErrHelp } diff --git a/cmd/tailscale/cli/debug.go b/cmd/tailscale/cli/debug.go index 944f99f91..3531172bb 100644 --- a/cmd/tailscale/cli/debug.go +++ b/cmd/tailscale/cli/debug.go @@ -670,18 +670,11 @@ func runNetmap(ctx context.Context, args []string) error { ctx, cancel := context.WithTimeout(ctx, 5*time.Second) defer cancel() - var mask ipn.NotifyWatchOpt = ipn.NotifyInitialNetMap - watcher, err := localClient.WatchIPNBus(ctx, mask) + raw, err := localClient.DebugResultJSON(ctx, "current-netmap") if err != nil { return err } - defer watcher.Close() - - n, err := watcher.Next() - if err != nil { - return err - } - j, _ := json.MarshalIndent(n.NetMap, "", "\t") + j, _ := json.MarshalIndent(raw, "", "\t") fmt.Printf("%s\n", j) return nil } diff --git a/cmd/tailscale/cli/dns-status.go b/cmd/tailscale/cli/dns-status.go index 66a5e21d8..91a62f996 100644 --- a/cmd/tailscale/cli/dns-status.go +++ b/cmd/tailscale/cli/dns-status.go @@ -14,9 +14,7 @@ import ( "github.com/peterbourgon/ff/v3/ffcli" "tailscale.com/cmd/tailscale/cli/jsonoutput" - "tailscale.com/ipn" "tailscale.com/types/dnstype" - "tailscale.com/types/netmap" ) var dnsStatusCmd = &ffcli.Command{ @@ -120,11 +118,10 @@ func runDNSStatus(ctx context.Context, args []string) error { SelfDNSName: s.Self.DNSName, } - netMap, err := fetchNetMap() + dnsConfig, err := localClient.DNSConfig(ctx) if err != nil { - return fmt.Errorf("failed to fetch network map: %w", err) + return fmt.Errorf("failed to fetch DNS config: %w", err) } - dnsConfig := netMap.DNS for _, r := range dnsConfig.Resolvers { data.Resolvers = append(data.Resolvers, makeDNSResolverInfo(r)) @@ -357,19 +354,3 @@ func formatDNSStatusText(data *jsonoutput.DNSStatusResult, all bool) string { fmt.Fprintf(&sb, "[this is a preliminary version of this command; the output format may change in the future]\n") return sb.String() } - -func fetchNetMap() (netMap *netmap.NetworkMap, err error) { - w, err := localClient.WatchIPNBus(context.Background(), ipn.NotifyInitialNetMap) - if err != nil { - return nil, err - } - defer w.Close() - notify, err := w.Next() - if err != nil { - return nil, err - } - if notify.NetMap == nil { - return nil, fmt.Errorf("no network map yet available, please try again later") - } - return notify.NetMap, nil -} diff --git a/cmd/tailscale/cli/file.go b/cmd/tailscale/cli/file.go index e7406bee3..489c83deb 100644 --- a/cmd/tailscale/cli/file.go +++ b/cmd/tailscale/cli/file.go @@ -32,6 +32,7 @@ import ( "tailscale.com/client/tailscale/apitype" "tailscale.com/cmd/tailscale/cli/ffcomplete" "tailscale.com/envknob" + "tailscale.com/ipn" "tailscale.com/ipn/ipnstate" "tailscale.com/net/tsaddr" "tailscale.com/tailcfg" @@ -78,14 +79,16 @@ var fileCpCmd = &ffcli.Command{ fs.StringVar(&cpArgs.name, "name", "", "alternate filename to use, especially useful when is \"-\" (stdin)") fs.BoolVar(&cpArgs.verbose, "verbose", false, "verbose output") fs.BoolVar(&cpArgs.targets, "targets", false, "list possible file cp targets") + fs.DurationVar(&cpArgs.updateInterval, "update-interval", 250*time.Millisecond, "how often to repaint the progress line; zero or negative disables progress display entirely") return fs })(), } var cpArgs struct { - name string - verbose bool - targets bool + name string + verbose bool + targets bool + updateInterval time.Duration } func runCp(ctx context.Context, args []string) error { @@ -119,9 +122,6 @@ func runCp(ctx context.Context, args []string) error { if err != nil { return fmt.Errorf("can't send to %s: %v", target, err) } - if isOffline { - fmt.Fprintf(Stderr, "# warning: %s is offline\n", target) - } if len(files) > 1 { if cpArgs.name != "" { @@ -132,7 +132,51 @@ func runCp(ctx context.Context, args []string) error { } } - for _, fileArg := range files { + // outFiles tracks per-name push state, populated by a goroutine subscribed + // to the IPN bus. tailscaled's OutgoingFile.Sent is the bytes-pulled-toward- + // peerAPI signal; it stays at 0 until the peerAPI request body is actually + // being read, which is what we want both for the progress display and for + // disarming the offline warning. The CLI's local-side bytes counter would + // say "100% sent" the moment net/http buffers a small body into the local + // unix-socket conn to tailscaled, well before the peer has heard a thing. + type pushState struct { + sent atomic.Int64 + warnTimer *time.Timer // disarmed on first byte sent to peerAPI; nil after + } + var ( + outMu sync.Mutex + outFiles = map[string]*pushState{} // keyed by file name + ) + + busCtx, cancelBus := context.WithCancel(ctx) + defer cancelBus() + go watchOutgoingFiles(busCtx, stableID, func(name string, sent int64) { + outMu.Lock() + ps := outFiles[name] + outMu.Unlock() + if ps == nil { + return + } + // Only ever advance ps.sent forward. Bus updates can arrive late + // (after the success path below has already written contentLength + // to ps.sent for an instant final-100% paint), so we'd otherwise + // regress the count and the progress printer would compute a + // negative delta on its next tick. + for { + old := ps.sent.Load() + if sent <= old { + return + } + if ps.sent.CompareAndSwap(old, sent) { + if old == 0 && ps.warnTimer != nil { + ps.warnTimer.Stop() + } + return + } + } + }) + + for i, fileArg := range files { var fileContents *countingReader var name = cpArgs.name var contentLength int64 = -1 @@ -175,16 +219,57 @@ func runCp(ctx context.Context, args []string) error { log.Printf("sending %q to %v/%v/%v ...", name, target, ip, stableID) } + // Register this file with the watcher and, for the first file only, + // arm a timer that warns the user if no bytes have flowed to peerAPI + // after a few seconds. The watcher disarms it on first byte; PushFile + // returning also disarms it (cleanup, below). We don't gate on the + // netmap's Online bit (which can lag reality), but we do use it to + // pick between two warning messages. + ps := &pushState{} + if i == 0 { + ps.warnTimer = time.AfterFunc(3*time.Second, func() { + // vtRestartLine clears whatever (possibly progress) was on + // the current line, then we print the warning + \n so the + // next progress redraw lands on a fresh line below. + const vtRestartLine = "\r\x1b[K" + if isOffline { + fmt.Fprintf(Stderr, "%s# warning: %s is reportedly offline; trying anyway\n", vtRestartLine, target) + } else { + fmt.Fprintf(Stderr, "%s# warning: %s is not replying; trying anyway\n", vtRestartLine, target) + } + }) + } + outMu.Lock() + outFiles[name] = ps + outMu.Unlock() + var group sync.WaitGroup ctxProgress, cancelProgress := context.WithCancel(ctx) defer cancelProgress() - if isatty.IsTerminal(os.Stderr.Fd()) { - group.Go(func() { progressPrinter(ctxProgress, name, fileContents.n.Load, contentLength) }) + if cpArgs.updateInterval > 0 && isatty.IsTerminal(os.Stderr.Fd()) { + group.Go(func() { + progressPrinter(ctxProgress, name, ps.sent.Load, contentLength, cpArgs.updateInterval) + }) } err := localClient.PushFile(ctx, stableID, contentLength, name, fileContents) + if err == nil { + // PushFile can finish faster than the IPN bus delivers a final + // OutgoingFile update, leaving the progress display stuck at 0%. + // Synthesize a "fully done" count before stopping the printer so + // its final paint shows 100%. For stdin (contentLength == -1) we + // don't know the size, so fall back to the local read count. + if contentLength >= 0 { + ps.sent.Store(contentLength) + } else { + ps.sent.Store(fileContents.n.Load()) + } + } cancelProgress() group.Wait() // wait for progress printer to stop before reporting the error + if ps.warnTimer != nil { + ps.warnTimer.Stop() + } if err != nil { return err } @@ -195,15 +280,71 @@ func runCp(ctx context.Context, args []string) error { return nil } -func progressPrinter(ctx context.Context, name string, contentCount func() int64, contentLength int64) { +// watchOutgoingFiles subscribes to the IPN bus and invokes onUpdate once +// per OutgoingFile event for files going to peer. It runs until ctx is +// done (which runCp does on return) and is best-effort: if the bus +// subscription fails for any reason, onUpdate simply isn't called and the +// caller's progress display stays at 0 — exactly the right degradation, +// since the warning timer will then fire on its normal 3-second deadline. +func watchOutgoingFiles(ctx context.Context, peer tailcfg.StableNodeID, onUpdate func(name string, sent int64)) { + // NotifyPeerChanges opts in to per-peer add/remove notifications so the + // bus stays responsive without us also subscribing to the full NetMap, + // which we don't read here. + w, err := localClient.WatchIPNBus(ctx, ipn.NotifyInitialOutgoingFiles|ipn.NotifyPeerChanges) + if err != nil { + return + } + defer w.Close() + for { + n, err := w.Next() + if err != nil { + return + } + for _, of := range n.OutgoingFiles { + if of.PeerID != peer { + continue + } + // tailscaled keeps Finished entries in its OutgoingFiles map + // across PushFile calls (see feature/taildrop/ext.go), so a + // re-send of the same filename will see both the old completed + // (Sent == DeclaredSize) entry and the new in-progress one. + // Without this filter the watcher's monotonic CAS would latch + // onto the old entry's max value and the new transfer would + // appear stuck at 100% from the first bus tick. + if of.Finished { + continue + } + onUpdate(of.Name, of.Sent) + } + } +} + +// progressPrinter repaints a single-line transfer progress display every +// interval. interval must be > 0; runCp's caller gates on the +// --update-interval flag and skips invoking us when it's <= 0. +// +// It returns when ctx is done OR when it detects the transfer is stuck — +// "stuck" being: contentCount has equalled contentLength with a near-zero +// rate for >2 seconds. The stuck case prints a final newline so subsequent +// output (e.g. an error from PushFile) lands on a fresh line below the +// frozen progress line, instead of being painted over by it. +func progressPrinter(ctx context.Context, name string, contentCount func() int64, contentLength int64, interval time.Duration) { var rateValueFast, rateValueSlow tsrate.Value - rateValueFast.HalfLife = 1 * time.Second // fast response for rate measurement - rateValueSlow.HalfLife = 10 * time.Second // slow response for ETA measurement + // tailscaled emits OutgoingFile.Sent updates at ~1 Hz, so most printer + // ticks see no delta. With too short a half-life the displayed rate + // roughly halves between updates and doubles back when one arrives, + // looking jumpy. 5s keeps the swing under ~15% while still settling + // within a few seconds of a real change. + rateValueFast.HalfLife = 5 * time.Second // smoothed rate for display + rateValueSlow.HalfLife = 10 * time.Second // even slower, for ETA measurement var prevContentCount int64 print := func() { currContentCount := contentCount() - rateValueFast.Add(float64(currContentCount - prevContentCount)) - rateValueSlow.Add(float64(currContentCount - prevContentCount)) + // Clamp so a regression (which shouldn't happen, but tsrate.Value.Add + // panics on a negative count) can't take down the CLI. + delta := max(currContentCount-prevContentCount, 0) + rateValueFast.Add(float64(delta)) + rateValueSlow.Add(float64(delta)) prevContentCount = currContentCount const vtRestartLine = "\r\x1b[K" @@ -215,16 +356,23 @@ func progressPrinter(ctx context.Context, name string, contentCount func() int64 if contentLength >= 0 { currContentCount = min(currContentCount, contentLength) // cap at 100% ratioRemain := float64(currContentCount) / float64(contentLength) - bytesRemain := float64(contentLength - currContentCount) - secsRemain := bytesRemain / rateValueSlow.Rate() - secs := int(min(max(0, secsRemain), 99*60*60+59+60+59)) + etaStr := "ETA -" + if rate := rateValueSlow.Rate(); rate > 0 { + bytesRemain := float64(contentLength - currContentCount) + secsRemain := bytesRemain / rate + secs := int(min(max(0, secsRemain), 99*60*60+59+60+59)) + etaStr = fmt.Sprintf("ETA %02d:%02d:%02d", secs/60/60, (secs/60)%60, secs%60) + } fmt.Fprintf(os.Stderr, " %s %s", leftPad(fmt.Sprintf("%0.2f%%", 100.0*ratioRemain), len("100.00%")), - fmt.Sprintf("ETA %02d:%02d:%02d", secs/60/60, (secs/60)%60, secs%60)) + etaStr) } } - tc := time.NewTicker(250 * time.Millisecond) + const stuckAfter = 2 * time.Second + var fullStartedAt time.Time // when we first observed currCount==contentLength with ~zero rate + + tc := time.NewTicker(interval) defer tc.Stop() print() for { @@ -235,6 +383,24 @@ func progressPrinter(ctx context.Context, name string, contentCount func() int64 return case <-tc.C: print() + if contentLength < 0 { + continue + } + currCount := contentCount() + rate := rateValueFast.Rate() + if currCount >= contentLength && rate < 1 { + if fullStartedAt.IsZero() { + fullStartedAt = time.Now() + } else if time.Since(fullStartedAt) >= stuckAfter { + // Transfer is stuck at 100% with no movement. Stop + // repainting so we don't keep clobbering anything the + // rest of runCp prints (warnings, errors). + fmt.Fprintln(os.Stderr) + return + } + } else { + fullStartedAt = time.Time{} + } } } } @@ -328,7 +494,10 @@ peerLoop: return "", isOffline, errors.New("cannot send files: missing required Taildrop capability") case ipnstate.TaildropTargetOffline: - return "", isOffline, errors.New("cannot send files: peer is offline") + // Don't gate on the server-reported Online bit (which lags reality + // and isn't always accurate). runCp probes reachability itself with + // TSMP pings. + return foundPeer.ID, isOffline, nil case ipnstate.TaildropTargetNoPeerInfo: return "", isOffline, errors.New("cannot send files: invalid or unrecognized peer") diff --git a/cmd/tailscale/cli/jsonoutput/network-lock-log.go b/cmd/tailscale/cli/jsonoutput/network-lock-log.go index c7c16e223..779a99883 100644 --- a/cmd/tailscale/cli/jsonoutput/network-lock-log.go +++ b/cmd/tailscale/cli/jsonoutput/network-lock-log.go @@ -159,7 +159,7 @@ type expandedAUMV1 struct { } // tkaKeyV1 is the expanded version of a [tka.Key], which describes -// the public components of a key known to network-lock. +// the public components of a key known to tailnet-lock. type tkaKeyV1 struct { Kind string `json:"Kind,omitzero"` diff --git a/cmd/tailscale/cli/jsonoutput/network-lock-status.go b/cmd/tailscale/cli/jsonoutput/network-lock-status.go index a1d95b871..fce2276ef 100644 --- a/cmd/tailscale/cli/jsonoutput/network-lock-status.go +++ b/cmd/tailscale/cli/jsonoutput/network-lock-status.go @@ -116,7 +116,7 @@ type tailnetLockStatusV1Base struct { // Enabled is true if Tailnet Lock is enabled. Enabled bool - // PublicKey describes the node's network-lock public key. + // PublicKey describes the node's tailnet-lock public key. PublicKey string `json:"PublicKey,omitzero"` // NodeKey describes the node's current node-key. This field is not @@ -144,7 +144,7 @@ type tailnetLockEnabledStatusV1 struct { NodeKeySignature *tkaNodeKeySignatureV1 // TrustedKeys describes the keys currently trusted to make changes - // to network-lock. + // to tailnet-lock. TrustedKeys []tkaKeyV1 // VisiblePeers describes peers which are visible in the netmap that diff --git a/cmd/tailscale/cli/serve_legacy.go b/cmd/tailscale/cli/serve_legacy.go index 837d88513..635bcfa3d 100644 --- a/cmd/tailscale/cli/serve_legacy.go +++ b/cmd/tailscale/cli/serve_legacy.go @@ -848,10 +848,10 @@ func (e *serveEnv) enableFeatureInteractive(ctx context.Context, feature string, e.lc.IncrementCounter(ctx, fmt.Sprintf("%s_enablement_lost_connection", feature), 1) return err } - if nm := n.NetMap; nm != nil && nm.SelfNode.Valid() { + if self := n.SelfChange; self != nil { gotAll := true for _, c := range caps { - if !nm.SelfNode.HasCap(c) { + if _, has := self.CapMap[c]; !has { // The feature is not yet enabled. // Continue blocking until it is. gotAll = false diff --git a/cmd/tailscale/cli/systray.go b/cmd/tailscale/cli/systray.go index ca0840fe9..07de5c786 100644 --- a/cmd/tailscale/cli/systray.go +++ b/cmd/tailscale/cli/systray.go @@ -7,6 +7,7 @@ package cli import ( "context" + "flag" "github.com/peterbourgon/ff/v3/ffcli" "tailscale.com/client/systray" @@ -17,10 +18,20 @@ var systrayCmd = &ffcli.Command{ ShortUsage: "tailscale systray", ShortHelp: "Run a systray application to manage Tailscale", LongHelp: "Run a systray application to manage Tailscale.", - Exec: runSystray, + FlagSet: (func() *flag.FlagSet { + fs := newFlagSet("systray") + fs.StringVar(&systrayArgs.theme, "theme", "dark", "color theme for Tailscale icon: dark, dark:nobg, light, light:nobg") + return fs + })(), + Exec: runSystray, +} + +var systrayArgs struct { + theme string } func runSystray(ctx context.Context, _ []string) error { + systray.SetTheme(systrayArgs.theme) new(systray.Menu).Run(&localClient) return nil } diff --git a/cmd/tailscale/cli/up.go b/cmd/tailscale/cli/up.go index 586df07bb..fed7de9ae 100644 --- a/cmd/tailscale/cli/up.go +++ b/cmd/tailscale/cli/up.go @@ -113,12 +113,12 @@ func newUpFlagSet(goos string, upArgs *upArgsT, cmd string) *flag.FlagSet { upf.BoolVar(&upArgs.exitNodeAllowLANAccess, "exit-node-allow-lan-access", false, "Allow direct access to the local network when routing traffic via an exit node") upf.BoolVar(&upArgs.shieldsUp, "shields-up", false, "don't allow incoming connections") upf.BoolVar(&upArgs.runSSH, "ssh", false, "run an SSH server, permitting access per tailnet admin's declared policy") - upf.StringVar(&upArgs.advertiseTags, "advertise-tags", "", "comma-separated ACL tags to request; each must start with \"tag:\" (e.g. \"tag:eng,tag:montreal,tag:ssh\")") + upf.StringVar(&upArgs.advertiseTags, "advertise-tags", "", "comma-separated ACL tags to request (e.g. \"tag:eng,tag:montreal,tag:ssh\"); the \"tag:\" prefix is optional and added automatically when omitted (e.g. \"eng,montreal,ssh\")") upf.StringVar(&upArgs.hostname, "hostname", "", "hostname to use instead of the one provided by the OS") upf.StringVar(&upArgs.advertiseRoutes, "advertise-routes", "", "routes to advertise to other nodes (comma-separated, e.g. \"10.0.0.0/8,192.168.0.0/24\") or empty string to not advertise routes") upf.BoolVar(&upArgs.advertiseConnector, "advertise-connector", false, "advertise this node as an app connector") upf.BoolVar(&upArgs.advertiseDefaultRoute, "advertise-exit-node", false, "offer to be an exit node for internet traffic for the tailnet") - upf.BoolVar(&upArgs.postureChecking, "report-posture", false, hidden+"allow management plane to gather device posture information") + upf.BoolVar(&upArgs.postureChecking, "report-posture", false, "allow management plane to gather device posture information") if safesocket.GOOSUsesPeerCreds(goos) { upf.StringVar(&upArgs.opUser, "operator", "", "Unix username to allow to operate on tailscaled without sudo") @@ -309,9 +309,15 @@ func prefsFromUpArgs(upArgs upArgsT, warnf logger.Logf, st *ipnstate.Status, goo var tags []string if upArgs.advertiseTags != "" { tags = strings.Split(upArgs.advertiseTags, ",") - for _, tag := range tags { - err := tailcfg.CheckTag(tag) - if err != nil { + for i, tag := range tags { + // Allow users to omit the "tag:" prefix; if the tag has no + // colon at all, add it for them. Tags with a colon must be + // fully qualified ("tag:foo") and are validated as-is. + if !strings.Contains(tag, ":") { + tag = "tag:" + tag + tags[i] = tag + } + if err := tailcfg.CheckTag(tag); err != nil { return nil, fmt.Errorf("tag: %q: %s", tag, err) } } @@ -726,7 +732,7 @@ func runUp(ctx context.Context, cmd string, args []string, upArgs upArgsT) (retE if s := n.State; s != nil { ipnIsRunning = *s == ipn.Running } - if n.NetMap != nil && n.NetMap.NodeKey != origNodeKey { + if n.SelfChange != nil && n.SelfChange.Key != origNodeKey { waitingForKeyChange = false } if ipnIsRunning && !waitingForKeyChange { diff --git a/cmd/tailscale/depaware.txt b/cmd/tailscale/depaware.txt index 01d3f418f..d23ab1f4f 100644 --- a/cmd/tailscale/depaware.txt +++ b/cmd/tailscale/depaware.txt @@ -239,7 +239,7 @@ tailscale.com/cmd/tailscale dependencies: (generated by github.com/tailscale/dep tailscale.com/tstime from tailscale.com/control/controlhttp+ tailscale.com/tstime/mono from tailscale.com/tstime/rate tailscale.com/tstime/rate from tailscale.com/cmd/tailscale/cli - tailscale.com/tsweb from tailscale.com/util/eventbus + tailscale.com/tsweb from tailscale.com/util/eventbus+ tailscale.com/tsweb/varz from tailscale.com/util/usermetric+ tailscale.com/types/appctype from tailscale.com/client/local+ tailscale.com/types/dnstype from tailscale.com/tailcfg+ @@ -331,7 +331,7 @@ tailscale.com/cmd/tailscale dependencies: (generated by github.com/tailscale/dep golang.org/x/net/icmp from tailscale.com/net/ping golang.org/x/net/idna from golang.org/x/net/http/httpproxy+ golang.org/x/net/internal/iana from golang.org/x/net/icmp+ - golang.org/x/net/internal/socket from golang.org/x/net/icmp+ + golang.org/x/net/internal/socket from golang.org/x/net/ipv4+ golang.org/x/net/internal/socks from golang.org/x/net/proxy golang.org/x/net/ipv4 from golang.org/x/net/icmp+ golang.org/x/net/ipv6 from golang.org/x/net/icmp+ diff --git a/cmd/tailscaled/depaware-min.txt b/cmd/tailscaled/depaware-min.txt index 94ebf144b..8f0c34cf1 100644 --- a/cmd/tailscaled/depaware-min.txt +++ b/cmd/tailscaled/depaware-min.txt @@ -219,7 +219,7 @@ tailscale.com/cmd/tailscaled dependencies: (generated by github.com/tailscale/de golang.org/x/net/icmp from tailscale.com/net/ping golang.org/x/net/idna from golang.org/x/net/http/httpguts golang.org/x/net/internal/iana from golang.org/x/net/icmp+ - golang.org/x/net/internal/socket from golang.org/x/net/icmp+ + golang.org/x/net/internal/socket from golang.org/x/net/ipv4+ golang.org/x/net/ipv4 from github.com/tailscale/wireguard-go/conn+ golang.org/x/net/ipv6 from github.com/tailscale/wireguard-go/conn+ golang.org/x/sync/errgroup from github.com/mdlayher/socket diff --git a/cmd/tailscaled/depaware-minbox.txt b/cmd/tailscaled/depaware-minbox.txt index e518613f8..994310d60 100644 --- a/cmd/tailscaled/depaware-minbox.txt +++ b/cmd/tailscaled/depaware-minbox.txt @@ -240,7 +240,7 @@ tailscale.com/cmd/tailscaled dependencies: (generated by github.com/tailscale/de golang.org/x/net/icmp from tailscale.com/net/ping golang.org/x/net/idna from golang.org/x/net/http/httpguts+ golang.org/x/net/internal/iana from golang.org/x/net/icmp+ - golang.org/x/net/internal/socket from golang.org/x/net/icmp+ + golang.org/x/net/internal/socket from golang.org/x/net/ipv4+ golang.org/x/net/ipv4 from github.com/tailscale/wireguard-go/conn+ golang.org/x/net/ipv6 from github.com/tailscale/wireguard-go/conn+ golang.org/x/sync/errgroup from github.com/mdlayher/socket diff --git a/cmd/tailscaled/depaware.txt b/cmd/tailscaled/depaware.txt index 678d72560..7e0e95be8 100644 --- a/cmd/tailscaled/depaware.txt +++ b/cmd/tailscaled/depaware.txt @@ -130,7 +130,7 @@ tailscale.com/cmd/tailscaled dependencies: (generated by github.com/tailscale/de L github.com/google/nftables/expr from github.com/google/nftables+ L github.com/google/nftables/internal/parseexprfunc from github.com/google/nftables+ L github.com/google/nftables/xt from github.com/google/nftables/expr+ - DW github.com/google/uuid from tailscale.com/clientupdate+ + W github.com/google/uuid from tailscale.com/clientupdate github.com/hdevalence/ed25519consensus from tailscale.com/clientupdate/distsign+ github.com/huin/goupnp from github.com/huin/goupnp/dcps/internetgateway2+ github.com/huin/goupnp/dcps/internetgateway2 from tailscale.com/net/portmapper @@ -173,9 +173,8 @@ tailscale.com/cmd/tailscaled dependencies: (generated by github.com/tailscale/de github.com/pires/go-proxyproto from tailscale.com/ipn/ipnlocal LD github.com/pkg/sftp from tailscale.com/ssh/tailssh LD github.com/pkg/sftp/internal/encoding/ssh/filexfer from github.com/pkg/sftp - D github.com/prometheus-community/pro-bing from tailscale.com/wgengine/netstack L 💣 github.com/safchain/ethtool from tailscale.com/net/netkernelconf+ - W 💣 github.com/tailscale/certstore from tailscale.com/control/controlclient + DW 💣 github.com/tailscale/certstore from tailscale.com/control/controlclient LD github.com/tailscale/gliderssh from tailscale.com/ssh/tailssh W 💣 github.com/tailscale/go-winio from tailscale.com/safesocket W 💣 github.com/tailscale/go-winio/internal/fs from github.com/tailscale/go-winio @@ -259,6 +258,7 @@ tailscale.com/cmd/tailscaled dependencies: (generated by github.com/tailscale/de tailscale.com/client/web from tailscale.com/ipn/ipnlocal tailscale.com/clientupdate from tailscale.com/feature/clientupdate LW tailscale.com/clientupdate/distsign from tailscale.com/clientupdate + tailscale.com/cmd/tailscale/cli/jsonoutput from tailscale.com/feature/tailnetlock tailscale.com/cmd/tailscaled/childproc from tailscale.com/cmd/tailscaled+ tailscale.com/cmd/tailscaled/tailscaledhooks from tailscale.com/cmd/tailscaled+ tailscale.com/control/controlbase from tailscale.com/control/controlhttp+ @@ -303,10 +303,12 @@ tailscale.com/cmd/tailscaled dependencies: (generated by github.com/tailscale/de tailscale.com/feature/portmapper from tailscale.com/feature/condregister/portmapper tailscale.com/feature/posture from tailscale.com/feature/condregister tailscale.com/feature/relayserver from tailscale.com/feature/condregister + tailscale.com/feature/routecheck from tailscale.com/feature/condregister L tailscale.com/feature/sdnotify from tailscale.com/feature/condregister LD tailscale.com/feature/ssh from tailscale.com/cmd/tailscaled tailscale.com/feature/syspolicy from tailscale.com/feature/condregister+ tailscale.com/feature/taildrop from tailscale.com/feature/condregister + tailscale.com/feature/tailnetlock from tailscale.com/feature/condregister L tailscale.com/feature/tap from tailscale.com/feature/condregister tailscale.com/feature/tpm from tailscale.com/feature/condregister L 💣 tailscale.com/feature/tundevstats from tailscale.com/feature/condregister @@ -402,7 +404,7 @@ tailscale.com/cmd/tailscaled dependencies: (generated by github.com/tailscale/de tailscale.com/tstime from tailscale.com/control/controlclient+ tailscale.com/tstime/mono from tailscale.com/net/tstun+ tailscale.com/tstime/rate from tailscale.com/wgengine/filter - tailscale.com/tsweb from tailscale.com/util/eventbus + tailscale.com/tsweb from tailscale.com/util/eventbus+ tailscale.com/tsweb/varz from tailscale.com/cmd/tailscaled+ tailscale.com/types/appctype from tailscale.com/ipn/ipnlocal+ tailscale.com/types/bools from tailscale.com/wgengine/netlog @@ -525,13 +527,13 @@ tailscale.com/cmd/tailscaled dependencies: (generated by github.com/tailscale/de golang.org/x/net/dns/dnsmessage from tailscale.com/appc+ golang.org/x/net/http/httpguts from tailscale.com/ipn/ipnlocal golang.org/x/net/http/httpproxy from tailscale.com/net/tshttpproxy - golang.org/x/net/icmp from tailscale.com/net/ping+ + golang.org/x/net/icmp from tailscale.com/net/ping golang.org/x/net/idna from golang.org/x/net/http/httpguts+ golang.org/x/net/internal/iana from golang.org/x/net/icmp+ - golang.org/x/net/internal/socket from golang.org/x/net/icmp+ + golang.org/x/net/internal/socket from golang.org/x/net/ipv4+ golang.org/x/net/internal/socks from golang.org/x/net/proxy - golang.org/x/net/ipv4 from github.com/prometheus-community/pro-bing+ - golang.org/x/net/ipv6 from github.com/prometheus-community/pro-bing+ + golang.org/x/net/ipv4 from github.com/tailscale/wireguard-go/conn+ + golang.org/x/net/ipv6 from github.com/tailscale/wireguard-go/conn+ golang.org/x/net/proxy from tailscale.com/net/netns D golang.org/x/net/route from tailscale.com/net/netmon+ golang.org/x/sync/errgroup from github.com/mdlayher/socket+ @@ -642,7 +644,7 @@ tailscale.com/cmd/tailscaled dependencies: (generated by github.com/tailscale/de crypto/x509 from crypto/tls+ D crypto/x509/internal/macos from crypto/x509 crypto/x509/pkix from crypto/x509+ - DW database/sql/driver from github.com/google/uuid + W database/sql/driver from github.com/google/uuid W debug/dwarf from debug/pe W debug/pe from github.com/dblohm7/wingoes/pe embed from github.com/tailscale/web-client-prebuilt+ @@ -732,7 +734,7 @@ tailscale.com/cmd/tailscaled dependencies: (generated by github.com/tailscale/de mime/quotedprintable from mime/multipart net from crypto/tls+ net/http from expvar+ - net/http/httptrace from github.com/prometheus-community/pro-bing+ + net/http/httptrace from github.com/aws/smithy-go/transport/http+ net/http/httputil from github.com/aws/smithy-go/transport/http+ net/http/internal from net/http+ net/http/internal/ascii from net/http+ diff --git a/cmd/tailscaled/deps_test.go b/cmd/tailscaled/deps_test.go index be4f65a7d..e91509765 100644 --- a/cmd/tailscaled/deps_test.go +++ b/cmd/tailscaled/deps_test.go @@ -202,6 +202,19 @@ func TestOmitPortlist(t *testing.T) { }.Check(t) } +func TestOmitRouteCheck(t *testing.T) { + deptest.DepChecker{ + GOOS: "linux", + GOARCH: "amd64", + Tags: "ts_omit_routecheck,ts_include_cli", + OnDep: func(dep string) { + if strings.Contains(dep, "routecheck") { + t.Errorf("unexpected dep: %q", dep) + } + }, + }.Check(t) +} + func TestOmitGRO(t *testing.T) { deptest.DepChecker{ GOOS: "linux", diff --git a/cmd/tailscaled/tailscaled.go b/cmd/tailscaled/tailscaled.go index fe18731ae..9ecf84055 100644 --- a/cmd/tailscaled/tailscaled.go +++ b/cmd/tailscaled/tailscaled.go @@ -828,7 +828,6 @@ func tryEngine(logf logger.Logf, sys *tsd.System, name string) (onlyNetstack boo if err != nil { return onlyNetstack, err } - e = wgengine.NewWatchdog(e) sys.Set(e) sys.NetstackRouter.Set(netstackSubnetRouter) diff --git a/cmd/testwrapper/testwrapper.go b/cmd/testwrapper/testwrapper.go index 204409a63..34338fff2 100644 --- a/cmd/testwrapper/testwrapper.go +++ b/cmd/testwrapper/testwrapper.go @@ -267,7 +267,7 @@ func main() { if cached { lastCol = "(cached)" } else { - lastCol = fmt.Sprintf("%.3f", testDur.Seconds()) + lastCol = fmt.Sprintf("%.3fs", testDur.Seconds()) } fmt.Printf("%s\t%s\t%v\n", outcome, pkg, lastCol) } diff --git a/cmd/tsconnect/wasm/wasm_js.go b/cmd/tsconnect/wasm/wasm_js.go index 71e8476a0..f58e4201a 100644 --- a/cmd/tsconnect/wasm/wasm_js.go +++ b/cmd/tsconnect/wasm/wasm_js.go @@ -258,44 +258,50 @@ func (i *jsIPN) run(jsCallbacks js.Value) { if n.State != nil { notifyState(*n.State) } - if nm := n.NetMap; nm != nil { - jsNetMap := jsNetMap{ - Self: jsNetMapSelfNode{ - jsNetMapNode: jsNetMapNode{ - Name: nm.SelfName(), - Addresses: mapSliceView(nm.GetAddresses(), func(a netip.Prefix) string { return a.Addr().String() }), - NodeKey: nm.NodeKey.String(), - MachineKey: nm.MachineKey.String(), - }, - MachineStatus: jsMachineStatus[nm.GetMachineStatus()], - }, - Peers: mapSlice(nm.Peers, func(p tailcfg.NodeView) jsNetMapPeerNode { - name := p.Name() - if name == "" { - // In practice this should only happen for Hello. - name = p.Hostinfo().Hostname() - } - addrs := make([]string, p.Addresses().Len()) - for i, ap := range p.Addresses().All() { - addrs[i] = ap.Addr().String() - } - return jsNetMapPeerNode{ + if n.SelfChange != nil { + // Self changed: rebuild the JS-side NetMap snapshot. Peers + // don't ride on the bus anymore, so fetch them on demand + // from LocalBackend. + nm := i.lb.NetMapWithPeers() + if nm != nil { + jsNetMap := jsNetMap{ + Self: jsNetMapSelfNode{ jsNetMapNode: jsNetMapNode{ - Name: name, - Addresses: addrs, - MachineKey: p.Machine().String(), - NodeKey: p.Key().String(), + Name: nm.SelfName(), + Addresses: mapSliceView(nm.GetAddresses(), func(a netip.Prefix) string { return a.Addr().String() }), + NodeKey: nm.NodeKey.String(), + MachineKey: nm.MachineKey.String(), }, - Online: p.Online().Clone(), - TailscaleSSHEnabled: p.Hostinfo().TailscaleSSHEnabled(), - } - }), - LockedOut: nm.TKAEnabled && nm.SelfNode.KeySignature().Len() == 0, - } - if jsonNetMap, err := json.Marshal(jsNetMap); err == nil { - jsCallbacks.Call("notifyNetMap", string(jsonNetMap)) - } else { - log.Printf("Could not generate JSON netmap: %v", err) + MachineStatus: jsMachineStatus[nm.GetMachineStatus()], + }, + Peers: mapSlice(nm.Peers, func(p tailcfg.NodeView) jsNetMapPeerNode { + name := p.Name() + if name == "" { + // In practice this should only happen for Hello. + name = p.Hostinfo().Hostname() + } + addrs := make([]string, p.Addresses().Len()) + for i, ap := range p.Addresses().All() { + addrs[i] = ap.Addr().String() + } + return jsNetMapPeerNode{ + jsNetMapNode: jsNetMapNode{ + Name: name, + Addresses: addrs, + MachineKey: p.Machine().String(), + NodeKey: p.Key().String(), + }, + Online: p.Online().Clone(), + TailscaleSSHEnabled: p.Hostinfo().TailscaleSSHEnabled(), + } + }), + LockedOut: nm.TKAEnabled && nm.SelfNode.KeySignature().Len() == 0, + } + if jsonNetMap, err := json.Marshal(jsNetMap); err == nil { + jsCallbacks.Call("notifyNetMap", string(jsonNetMap)) + } else { + log.Printf("Could not generate JSON netmap: %v", err) + } } } if n.BrowseToURL != nil { diff --git a/cmd/tsidp/depaware.txt b/cmd/tsidp/depaware.txt index 360437860..cf1a4c279 100644 --- a/cmd/tsidp/depaware.txt +++ b/cmd/tsidp/depaware.txt @@ -6,77 +6,6 @@ tailscale.com/cmd/tsidp dependencies: (generated by github.com/tailscale/depawar W 💣 github.com/alexbrainman/sspi from github.com/alexbrainman/sspi/internal/common+ W github.com/alexbrainman/sspi/internal/common from github.com/alexbrainman/sspi/negotiate W 💣 github.com/alexbrainman/sspi/negotiate from tailscale.com/net/tshttpproxy - github.com/aws/aws-sdk-go-v2/aws from github.com/aws/aws-sdk-go-v2/aws/defaults+ - github.com/aws/aws-sdk-go-v2/aws/defaults from github.com/aws/aws-sdk-go-v2/service/sso+ - github.com/aws/aws-sdk-go-v2/aws/middleware from github.com/aws/aws-sdk-go-v2/aws/retry+ - github.com/aws/aws-sdk-go-v2/aws/protocol/query from github.com/aws/aws-sdk-go-v2/service/sts - github.com/aws/aws-sdk-go-v2/aws/protocol/restjson from github.com/aws/aws-sdk-go-v2/service/sso+ - github.com/aws/aws-sdk-go-v2/aws/protocol/xml from github.com/aws/aws-sdk-go-v2/service/sts - github.com/aws/aws-sdk-go-v2/aws/ratelimit from github.com/aws/aws-sdk-go-v2/aws/retry - github.com/aws/aws-sdk-go-v2/aws/retry from github.com/aws/aws-sdk-go-v2/credentials/endpointcreds/internal/client+ - github.com/aws/aws-sdk-go-v2/aws/signer/internal/v4 from github.com/aws/aws-sdk-go-v2/aws/signer/v4 - github.com/aws/aws-sdk-go-v2/aws/signer/v4 from github.com/aws/aws-sdk-go-v2/internal/auth/smithy+ - github.com/aws/aws-sdk-go-v2/aws/transport/http from github.com/aws/aws-sdk-go-v2/config+ - github.com/aws/aws-sdk-go-v2/config from tailscale.com/wif - github.com/aws/aws-sdk-go-v2/credentials from github.com/aws/aws-sdk-go-v2/config - github.com/aws/aws-sdk-go-v2/credentials/ec2rolecreds from github.com/aws/aws-sdk-go-v2/config - github.com/aws/aws-sdk-go-v2/credentials/endpointcreds from github.com/aws/aws-sdk-go-v2/config - github.com/aws/aws-sdk-go-v2/credentials/endpointcreds/internal/client from github.com/aws/aws-sdk-go-v2/credentials/endpointcreds - github.com/aws/aws-sdk-go-v2/credentials/processcreds from github.com/aws/aws-sdk-go-v2/config - github.com/aws/aws-sdk-go-v2/credentials/ssocreds from github.com/aws/aws-sdk-go-v2/config - github.com/aws/aws-sdk-go-v2/credentials/stscreds from github.com/aws/aws-sdk-go-v2/config - github.com/aws/aws-sdk-go-v2/feature/ec2/imds from github.com/aws/aws-sdk-go-v2/config+ - github.com/aws/aws-sdk-go-v2/feature/ec2/imds/internal/config from github.com/aws/aws-sdk-go-v2/feature/ec2/imds - github.com/aws/aws-sdk-go-v2/internal/auth from github.com/aws/aws-sdk-go-v2/aws/signer/v4+ - github.com/aws/aws-sdk-go-v2/internal/auth/smithy from github.com/aws/aws-sdk-go-v2/service/sso+ - github.com/aws/aws-sdk-go-v2/internal/configsources from github.com/aws/aws-sdk-go-v2/service/sso+ - github.com/aws/aws-sdk-go-v2/internal/context from github.com/aws/aws-sdk-go-v2/aws/retry+ - github.com/aws/aws-sdk-go-v2/internal/endpoints from github.com/aws/aws-sdk-go-v2/service/sso+ - github.com/aws/aws-sdk-go-v2/internal/endpoints/awsrulesfn from github.com/aws/aws-sdk-go-v2/service/sso+ - github.com/aws/aws-sdk-go-v2/internal/endpoints/v2 from github.com/aws/aws-sdk-go-v2/service/sso/internal/endpoints+ - github.com/aws/aws-sdk-go-v2/internal/ini from github.com/aws/aws-sdk-go-v2/config - github.com/aws/aws-sdk-go-v2/internal/middleware from github.com/aws/aws-sdk-go-v2/service/sso+ - github.com/aws/aws-sdk-go-v2/internal/rand from github.com/aws/aws-sdk-go-v2/aws+ - github.com/aws/aws-sdk-go-v2/internal/sdk from github.com/aws/aws-sdk-go-v2/aws+ - github.com/aws/aws-sdk-go-v2/internal/sdkio from github.com/aws/aws-sdk-go-v2/credentials/processcreds - github.com/aws/aws-sdk-go-v2/internal/shareddefaults from github.com/aws/aws-sdk-go-v2/config+ - github.com/aws/aws-sdk-go-v2/internal/strings from github.com/aws/aws-sdk-go-v2/aws/signer/internal/v4 - github.com/aws/aws-sdk-go-v2/internal/sync/singleflight from github.com/aws/aws-sdk-go-v2/aws - github.com/aws/aws-sdk-go-v2/internal/timeconv from github.com/aws/aws-sdk-go-v2/aws/retry - github.com/aws/aws-sdk-go-v2/service/internal/accept-encoding from github.com/aws/aws-sdk-go-v2/service/sts - github.com/aws/aws-sdk-go-v2/service/internal/presigned-url from github.com/aws/aws-sdk-go-v2/service/sts - github.com/aws/aws-sdk-go-v2/service/sso from github.com/aws/aws-sdk-go-v2/config+ - github.com/aws/aws-sdk-go-v2/service/sso/internal/endpoints from github.com/aws/aws-sdk-go-v2/service/sso - github.com/aws/aws-sdk-go-v2/service/sso/types from github.com/aws/aws-sdk-go-v2/service/sso - github.com/aws/aws-sdk-go-v2/service/ssooidc from github.com/aws/aws-sdk-go-v2/config+ - github.com/aws/aws-sdk-go-v2/service/ssooidc/internal/endpoints from github.com/aws/aws-sdk-go-v2/service/ssooidc - github.com/aws/aws-sdk-go-v2/service/ssooidc/types from github.com/aws/aws-sdk-go-v2/service/ssooidc - github.com/aws/aws-sdk-go-v2/service/sts from github.com/aws/aws-sdk-go-v2/config+ - github.com/aws/aws-sdk-go-v2/service/sts/internal/endpoints from github.com/aws/aws-sdk-go-v2/service/sts - github.com/aws/aws-sdk-go-v2/service/sts/types from github.com/aws/aws-sdk-go-v2/credentials/stscreds+ - github.com/aws/smithy-go from github.com/aws/aws-sdk-go-v2/aws/protocol/restjson+ - github.com/aws/smithy-go/auth from github.com/aws/aws-sdk-go-v2/internal/auth+ - github.com/aws/smithy-go/auth/bearer from github.com/aws/aws-sdk-go-v2/aws+ - github.com/aws/smithy-go/context from github.com/aws/smithy-go/auth/bearer - github.com/aws/smithy-go/document from github.com/aws/aws-sdk-go-v2/service/sso+ - github.com/aws/smithy-go/encoding from github.com/aws/smithy-go/encoding/json+ - github.com/aws/smithy-go/encoding/httpbinding from github.com/aws/aws-sdk-go-v2/aws/protocol/query+ - github.com/aws/smithy-go/encoding/json from github.com/aws/aws-sdk-go-v2/service/ssooidc - github.com/aws/smithy-go/encoding/xml from github.com/aws/aws-sdk-go-v2/service/sts - github.com/aws/smithy-go/endpoints from github.com/aws/aws-sdk-go-v2/service/sso+ - github.com/aws/smithy-go/endpoints/private/rulesfn from github.com/aws/aws-sdk-go-v2/service/sts - github.com/aws/smithy-go/internal/sync/singleflight from github.com/aws/smithy-go/auth/bearer - github.com/aws/smithy-go/io from github.com/aws/aws-sdk-go-v2/feature/ec2/imds+ - github.com/aws/smithy-go/logging from github.com/aws/aws-sdk-go-v2/aws+ - github.com/aws/smithy-go/metrics from github.com/aws/aws-sdk-go-v2/aws/retry+ - github.com/aws/smithy-go/middleware from github.com/aws/aws-sdk-go-v2/aws+ - github.com/aws/smithy-go/private/requestcompression from github.com/aws/aws-sdk-go-v2/config - github.com/aws/smithy-go/ptr from github.com/aws/aws-sdk-go-v2/aws+ - github.com/aws/smithy-go/rand from github.com/aws/aws-sdk-go-v2/aws/middleware - github.com/aws/smithy-go/time from github.com/aws/aws-sdk-go-v2/service/sso+ - github.com/aws/smithy-go/tracing from github.com/aws/aws-sdk-go-v2/aws/middleware+ - github.com/aws/smithy-go/transport/http from github.com/aws/aws-sdk-go-v2/aws+ - github.com/aws/smithy-go/transport/http/internal/io from github.com/aws/smithy-go/transport/http github.com/coder/websocket from tailscale.com/util/eventbus github.com/coder/websocket/internal/errd from github.com/coder/websocket github.com/coder/websocket/internal/util from github.com/coder/websocket @@ -105,7 +34,6 @@ tailscale.com/cmd/tsidp dependencies: (generated by github.com/tailscale/depawar L 💣 github.com/godbus/dbus/v5 from tailscale.com/net/dns github.com/golang/groupcache/lru from tailscale.com/net/dnscache github.com/google/btree from gvisor.dev/gvisor/pkg/tcpip/transport/tcp - D github.com/google/uuid from github.com/prometheus-community/pro-bing github.com/hdevalence/ed25519consensus from tailscale.com/tka github.com/huin/goupnp from github.com/huin/goupnp/dcps/internetgateway2+ github.com/huin/goupnp/dcps/internetgateway2 from tailscale.com/net/portmapper @@ -128,9 +56,8 @@ tailscale.com/cmd/tsidp dependencies: (generated by github.com/tailscale/depawar L 💣 github.com/mdlayher/socket from github.com/mdlayher/netlink+ 💣 github.com/mitchellh/go-ps from tailscale.com/safesocket github.com/pires/go-proxyproto from tailscale.com/ipn/ipnlocal - D github.com/prometheus-community/pro-bing from tailscale.com/wgengine/netstack L 💣 github.com/safchain/ethtool from tailscale.com/net/netkernelconf - W 💣 github.com/tailscale/certstore from tailscale.com/control/controlclient + DW 💣 github.com/tailscale/certstore from tailscale.com/control/controlclient W 💣 github.com/tailscale/go-winio from tailscale.com/safesocket W 💣 github.com/tailscale/go-winio/internal/fs from github.com/tailscale/go-winio W 💣 github.com/tailscale/go-winio/internal/socket from github.com/tailscale/go-winio @@ -223,11 +150,9 @@ tailscale.com/cmd/tsidp dependencies: (generated by github.com/tailscale/depawar tailscale.com/feature/buildfeatures from tailscale.com/wgengine/magicsock+ tailscale.com/feature/c2n from tailscale.com/tsnet tailscale.com/feature/condlite/expvar from tailscale.com/wgengine/magicsock - tailscale.com/feature/condregister/identityfederation from tailscale.com/tsnet tailscale.com/feature/condregister/oauthkey from tailscale.com/tsnet tailscale.com/feature/condregister/portmapper from tailscale.com/tsnet tailscale.com/feature/condregister/useproxy from tailscale.com/tsnet - tailscale.com/feature/identityfederation from tailscale.com/feature/condregister/identityfederation tailscale.com/feature/oauthkey from tailscale.com/feature/condregister/oauthkey tailscale.com/feature/portmapper from tailscale.com/feature/condregister/portmapper tailscale.com/feature/syspolicy from tailscale.com/logpolicy @@ -309,7 +234,7 @@ tailscale.com/cmd/tsidp dependencies: (generated by github.com/tailscale/depawar tailscale.com/tstime from tailscale.com/control/controlclient+ tailscale.com/tstime/mono from tailscale.com/net/tstun+ tailscale.com/tstime/rate from tailscale.com/wgengine/filter - tailscale.com/tsweb from tailscale.com/util/eventbus + tailscale.com/tsweb from tailscale.com/util/eventbus+ tailscale.com/tsweb/varz from tailscale.com/tsweb+ tailscale.com/types/appctype from tailscale.com/ipn/ipnlocal+ tailscale.com/types/bools from tailscale.com/tsnet+ @@ -399,7 +324,6 @@ tailscale.com/cmd/tsidp dependencies: (generated by github.com/tailscale/depawar tailscale.com/wgengine/wgcfg/nmcfg from tailscale.com/ipn/ipnlocal 💣 tailscale.com/wgengine/wgint from tailscale.com/wgengine+ tailscale.com/wgengine/wglog from tailscale.com/wgengine - tailscale.com/wif from tailscale.com/feature/identityfederation golang.org/x/crypto/argon2 from tailscale.com/tka golang.org/x/crypto/blake2b from golang.org/x/crypto/argon2+ golang.org/x/crypto/blake2s from github.com/tailscale/wireguard-go/device+ @@ -421,16 +345,16 @@ tailscale.com/cmd/tsidp dependencies: (generated by github.com/tailscale/depawar golang.org/x/net/dns/dnsmessage from tailscale.com/appc+ golang.org/x/net/http/httpguts from tailscale.com/ipn/ipnlocal golang.org/x/net/http/httpproxy from tailscale.com/net/tshttpproxy - golang.org/x/net/icmp from github.com/prometheus-community/pro-bing+ + golang.org/x/net/icmp from tailscale.com/net/ping golang.org/x/net/idna from golang.org/x/net/http/httpguts+ golang.org/x/net/internal/iana from golang.org/x/net/icmp+ - golang.org/x/net/internal/socket from golang.org/x/net/icmp+ + golang.org/x/net/internal/socket from golang.org/x/net/ipv4+ golang.org/x/net/internal/socks from golang.org/x/net/proxy - golang.org/x/net/ipv4 from github.com/prometheus-community/pro-bing+ - golang.org/x/net/ipv6 from github.com/prometheus-community/pro-bing+ + golang.org/x/net/ipv4 from github.com/tailscale/wireguard-go/conn+ + golang.org/x/net/ipv6 from github.com/tailscale/wireguard-go/conn+ golang.org/x/net/proxy from tailscale.com/net/netns D golang.org/x/net/route from tailscale.com/net/netmon+ - golang.org/x/oauth2 from golang.org/x/oauth2/clientcredentials+ + golang.org/x/oauth2 from golang.org/x/oauth2/clientcredentials golang.org/x/oauth2/clientcredentials from tailscale.com/feature/oauthkey golang.org/x/oauth2/internal from golang.org/x/oauth2+ golang.org/x/sync/errgroup from github.com/mdlayher/socket+ @@ -533,12 +457,11 @@ tailscale.com/cmd/tsidp dependencies: (generated by github.com/tailscale/depawar crypto/sha3 from crypto/internal/fips140hash+ crypto/sha512 from crypto/ecdsa+ crypto/subtle from crypto/cipher+ - crypto/tls from github.com/prometheus-community/pro-bing+ + crypto/tls from net/http+ crypto/tls/internal/fips140tls from crypto/tls crypto/x509 from crypto/tls+ D crypto/x509/internal/macos from crypto/x509 crypto/x509/pkix from crypto/x509+ - D database/sql/driver from github.com/google/uuid W debug/dwarf from debug/pe W debug/pe from github.com/dblohm7/wingoes/pe embed from github.com/tailscale/web-client-prebuilt+ @@ -627,7 +550,7 @@ tailscale.com/cmd/tsidp dependencies: (generated by github.com/tailscale/depawar mime/quotedprintable from mime/multipart net from crypto/tls+ net/http from expvar+ - net/http/httptrace from github.com/prometheus-community/pro-bing+ + net/http/httptrace from net/http+ net/http/httputil from tailscale.com/client/web+ net/http/internal from net/http+ net/http/internal/ascii from net/http+ @@ -642,7 +565,7 @@ tailscale.com/cmd/tsidp dependencies: (generated by github.com/tailscale/depawar os/user from github.com/godbus/dbus/v5+ path from debug/dwarf+ path/filepath from crypto/x509+ - reflect from database/sql/driver+ + reflect from encoding/asn1+ regexp from github.com/huin/goupnp/httpu+ regexp/syntax from regexp runtime from crypto/internal/fips140+ diff --git a/cmd/tsnet-proxy/tsnet-proxy.go b/cmd/tsnet-proxy/tsnet-proxy.go new file mode 100644 index 000000000..0a83fd1a8 --- /dev/null +++ b/cmd/tsnet-proxy/tsnet-proxy.go @@ -0,0 +1,173 @@ +// Copyright (c) Tailscale Inc & contributors +// SPDX-License-Identifier: BSD-3-Clause + +// The tsnet-proxy command exposes a local port on the tailnet under a +// chosen hostname. By default it proxies raw TCP; pass --http to reverse +// proxy as HTTP, or --https to reverse proxy as HTTPS with an auto-issued +// Tailscale cert. Both HTTP modes inject Tailscale-User-* identity headers +// from WhoIs. +// +// Arguments are [tailnet]: local is the port on localhost +// to proxy to and tailnet is the port to expose on the tailnet. If tailnet +// is omitted, it defaults to 443 for --https, 80 for --http, and the local +// port otherwise. +// +// go run ./cmd/tsnet-proxy myapp 8080 # raw TCP, tailnet :8080 +// go run ./cmd/tsnet-proxy myapp 22 2222 # raw TCP, tailnet :2222 +// go run ./cmd/tsnet-proxy --http myapp 8080 # tailnet :80 +// go run ./cmd/tsnet-proxy --https myapp 8080 # tailnet :443 +// +// Or run directly from the module, no checkout required: +// +// go run tailscale.com/cmd/tsnet-proxy@latest myapp 8080 +package main + +import ( + "flag" + "fmt" + "io" + "log" + "mime" + "net" + "net/http" + "net/http/httputil" + "net/url" + "os" + "strconv" + "unicode/utf8" + + "tailscale.com/client/local" + "tailscale.com/tsnet" +) + +func main() { + asHTTP := flag.Bool("http", false, "reverse proxy as HTTP and inject Tailscale-User-* headers") + asHTTPS := flag.Bool("https", false, "reverse proxy as HTTPS with an auto-issued Tailscale cert; implies --http") + dir := flag.String("dir", "", "directory to persist tsnet state (default: per-user config dir)") + verbose := flag.Bool("v", false, "verbose tsnet backend logs") + flag.Usage = func() { + fmt.Fprintf(flag.CommandLine.Output(), "usage: %s [flags] [tailnet]\n", flag.CommandLine.Name()) + flag.PrintDefaults() + } + flag.Parse() + + if n := flag.NArg(); n != 2 && n != 3 { + flag.Usage() + os.Exit(2) + } + name := flag.Arg(0) + localPort, err := parsePort(flag.Arg(1)) + if err != nil { + log.Fatalf("invalid local port %q: %v", flag.Arg(1), err) + } + tailnetPort := defaultTailnetPort(localPort, *asHTTP, *asHTTPS) + if flag.NArg() == 3 { + tailnetPort, err = parsePort(flag.Arg(2)) + if err != nil { + log.Fatalf("invalid tailnet port %q: %v", flag.Arg(2), err) + } + } + + target := "localhost:" + strconv.Itoa(localPort) + addr := ":" + strconv.Itoa(tailnetPort) + + s := &tsnet.Server{Hostname: name, Dir: *dir} + if *verbose { + s.Logf = log.Printf + } + defer s.Close() + + var ln net.Listener + if *asHTTPS { + ln, err = s.ListenTLS("tcp", addr) + } else { + ln, err = s.Listen("tcp", addr) + } + if err != nil { + log.Fatal(err) + } + defer ln.Close() + + log.Printf("proxying %s -> %s on tailnet", target, name+addr) + + if *asHTTP || *asHTTPS { + lc, err := s.LocalClient() + if err != nil { + log.Fatal(err) + } + targetURL := &url.URL{Scheme: "http", Host: target} + rp := &httputil.ReverseProxy{ + Rewrite: func(r *httputil.ProxyRequest) { + r.SetURL(targetURL) + r.SetXForwarded() + addTailscaleIdentityHeaders(lc, r) + }, + } + log.Fatal(http.Serve(ln, rp)) + } + + for { + c, err := ln.Accept() + if err != nil { + log.Fatal(err) + } + go proxyTCP(c, target) + } +} + +func parsePort(s string) (int, error) { + p, err := strconv.Atoi(s) + if err != nil || p <= 0 || p > 65535 { + return 0, fmt.Errorf("bad port") + } + return p, nil +} + +// defaultTailnetPort returns the tailnet port when the user didn't +// specify one: 443 for HTTPS, 80 for HTTP, else the local port. +func defaultTailnetPort(local int, asHTTP, asHTTPS bool) int { + switch { + case asHTTPS: + return 443 + case asHTTP: + return 80 + } + return local +} + +func proxyTCP(c net.Conn, target string) { + defer c.Close() + d, err := net.Dial("tcp", target) + if err != nil { + log.Printf("dial %s: %v", target, err) + return + } + defer d.Close() + go io.Copy(d, c) + io.Copy(c, d) +} + +func addTailscaleIdentityHeaders(lc *local.Client, r *httputil.ProxyRequest) { + r.Out.Header.Del("Tailscale-User-Login") + r.Out.Header.Del("Tailscale-User-Name") + r.Out.Header.Del("Tailscale-User-Profile-Pic") + r.Out.Header.Del("Tailscale-Funnel-Request") + r.Out.Header.Del("Tailscale-Headers-Info") + + who, err := lc.WhoIs(r.In.Context(), r.In.RemoteAddr) + if err != nil || who == nil || who.Node.IsTagged() { + return + } + r.Out.Header.Set("Tailscale-User-Login", encHeader(who.UserProfile.LoginName)) + r.Out.Header.Set("Tailscale-User-Name", encHeader(who.UserProfile.DisplayName)) + r.Out.Header.Set("Tailscale-User-Profile-Pic", who.UserProfile.ProfilePicURL) +} + +// encHeader mirrors the encoding tailscaled's serve path applies to +// user-provided strings destined for HTTP headers. +func encHeader(v string) string { + if !utf8.ValidString(v) { + return "" + } + return mime.QEncoding.Encode("utf-8", v) +} diff --git a/cmd/tsp/tsp.go b/cmd/tsp/tsp.go new file mode 100644 index 000000000..a59b352d5 --- /dev/null +++ b/cmd/tsp/tsp.go @@ -0,0 +1,513 @@ +// Copyright (c) Tailscale Inc & contributors +// SPDX-License-Identifier: BSD-3-Clause + +// Program tsp is a low-level Tailscale protocol tool for performing +// composable building block operations like generating keys and +// registering nodes. +package main + +import ( + "bytes" + "cmp" + "context" + "encoding/json" + "errors" + "flag" + "fmt" + "io" + "os" + "reflect" + "strings" + + "github.com/peterbourgon/ff/v3/ffcli" + "tailscale.com/control/tsp" + "tailscale.com/hostinfo" + "tailscale.com/tailcfg" + "tailscale.com/types/key" +) + +var globalArgs struct { + // serverURL is the base URL of the coordination server (-s flag). + // If empty, tsp.DefaultServerURL is used. + serverURL string + + // controlKeyFile is a path to a file containing the server's + // MachinePublic key in MarshalText form (--control-key flag). + // When set, server key discovery is skipped. + controlKeyFile string +} + +func main() { + args := os.Args[1:] + if err := rootCmd.Parse(args); err != nil { + fmt.Fprintln(os.Stderr, err.Error()) + os.Exit(1) + } + err := rootCmd.Run(context.Background()) + if errors.Is(err, flag.ErrHelp) { + os.Exit(0) + } + if err != nil { + fmt.Fprintln(os.Stderr, err.Error()) + os.Exit(1) + } +} + +var rootCmd = &ffcli.Command{ + Name: "tsp", + ShortUsage: "tsp [-s url] [flags]", + ShortHelp: "Low-level Tailscale protocol tool.", + FlagSet: (func() *flag.FlagSet { + fs := flag.NewFlagSet("tsp", flag.ExitOnError) + fs.StringVar(&globalArgs.serverURL, "s", "", "base URL of coordination server (default: "+tsp.DefaultServerURL+")") + fs.StringVar(&globalArgs.controlKeyFile, "control-key", "", "file containing the server's public key (skips discovery)") + return fs + })(), + Subcommands: []*ffcli.Command{ + newMachineKeyCmd, + newNodeKeyCmd, + newNodeCmd, + registerCmd, + mapCmd, + discoverServerKeyCmd, + }, + Exec: func(ctx context.Context, args []string) error { + return flag.ErrHelp + }, +} + +var newMachineKeyArgs struct { + output string +} + +var newMachineKeyCmd = &ffcli.Command{ + Name: "new-machine-key", + ShortUsage: "tsp new-machine-key [-o file]", + ShortHelp: "Generate a new machine key.", + FlagSet: (func() *flag.FlagSet { + fs := flag.NewFlagSet("new-machine-key", flag.ExitOnError) + fs.StringVar(&newMachineKeyArgs.output, "o", "", "output file (default: stdout)") + return fs + })(), + Exec: runNewMachineKey, +} + +func runNewMachineKey(ctx context.Context, args []string) error { + k := key.NewMachine() + text, err := k.MarshalText() + if err != nil { + return err + } + text = append(text, '\n') + return writeOutput(newMachineKeyArgs.output, text) +} + +var newNodeKeyArgs struct { + output string +} + +var newNodeKeyCmd = &ffcli.Command{ + Name: "new-node-key", + ShortUsage: "tsp new-node-key [-o file]", + ShortHelp: "Generate a new node key.", + FlagSet: (func() *flag.FlagSet { + fs := flag.NewFlagSet("new-node-key", flag.ExitOnError) + fs.StringVar(&newNodeKeyArgs.output, "o", "", "output file (default: stdout)") + return fs + })(), + Exec: runNewNodeKey, +} + +func runNewNodeKey(ctx context.Context, args []string) error { + k := key.NewNode() + text, err := k.MarshalText() + if err != nil { + return err + } + text = append(text, '\n') + return writeOutput(newNodeKeyArgs.output, text) +} + +var discoverServerKeyArgs struct { + output string +} + +var discoverServerKeyCmd = &ffcli.Command{ + Name: "discover-server-key", + ShortUsage: "tsp [-s url] discover-server-key [-o file]", + ShortHelp: "Discover and print the coordination server's public key.", + FlagSet: (func() *flag.FlagSet { + fs := flag.NewFlagSet("discover-server-key", flag.ExitOnError) + fs.StringVar(&discoverServerKeyArgs.output, "o", "", "output file (default: stdout)") + return fs + })(), + Exec: runDiscoverServerKey, +} + +func runDiscoverServerKey(ctx context.Context, args []string) error { + k, err := tsp.DiscoverServerKey(ctx, globalArgs.serverURL) + if err != nil { + return err + } + text, err := k.MarshalText() + if err != nil { + return fmt.Errorf("marshaling server key: %w", err) + } + text = append(text, '\n') + return writeOutput(discoverServerKeyArgs.output, text) +} + +var newNodeArgs struct { + nodeKeyFile string + machineKeyFile string + output string +} + +var newNodeCmd = &ffcli.Command{ + Name: "new-node", + ShortUsage: "tsp [-s url] [--control-key file] new-node [-n node-key-file] [-m machine-key-file] [-o output]", + ShortHelp: "Generate a new node JSON file with keys and server info.", + FlagSet: (func() *flag.FlagSet { + fs := flag.NewFlagSet("new-node", flag.ExitOnError) + fs.StringVar(&newNodeArgs.nodeKeyFile, "n", "", "existing node key file (default: generate new)") + fs.StringVar(&newNodeArgs.machineKeyFile, "m", "", "existing machine key file (default: generate new)") + fs.StringVar(&newNodeArgs.output, "o", "", "output file (default: stdout)") + return fs + })(), + Exec: runNewNode, +} + +func runNewNode(ctx context.Context, args []string) error { + var nodeKey key.NodePrivate + if newNodeArgs.nodeKeyFile != "" { + var err error + nodeKey, err = readNodeKeyFile(newNodeArgs.nodeKeyFile) + if err != nil { + return fmt.Errorf("reading node key: %w", err) + } + } else { + nodeKey = key.NewNode() + } + + var machineKey key.MachinePrivate + if newNodeArgs.machineKeyFile != "" { + var err error + machineKey, err = readMachineKeyFile(newNodeArgs.machineKeyFile) + if err != nil { + return fmt.Errorf("reading machine key: %w", err) + } + } else { + machineKey = key.NewMachine() + } + + serverURL := cmp.Or(globalArgs.serverURL, tsp.DefaultServerURL) + + var serverKey key.MachinePublic + if globalArgs.controlKeyFile != "" { + var err error + serverKey, err = readControlKeyFile(globalArgs.controlKeyFile) + if err != nil { + return fmt.Errorf("reading control key: %w", err) + } + } else { + var err error + serverKey, err = tsp.DiscoverServerKey(ctx, serverURL) + if err != nil { + return fmt.Errorf("discovering server key: %w", err) + } + } + + nf := tsp.NodeFile{ + NodeKey: nodeKey, + MachineKey: machineKey, + ServerInfo: tsp.ServerInfo{URL: serverURL, Key: serverKey}, + } + + out, err := json.MarshalIndent(nf, "", " ") + if err != nil { + return fmt.Errorf("encoding node file: %w", err) + } + out = append(out, '\n') + return writeOutput(newNodeArgs.output, out) +} + +var registerArgs struct { + nodeFile string + output string + hostname string + ephemeral bool + authKey string + tags string +} + +var registerCmd = &ffcli.Command{ + Name: "register", + ShortUsage: "tsp [-s url] register -n [flags]", + ShortHelp: "Register a node key with a coordination server.", + FlagSet: (func() *flag.FlagSet { + fs := flag.NewFlagSet("register", flag.ExitOnError) + fs.StringVar(®isterArgs.nodeFile, "n", "", "node JSON file (required)") + fs.StringVar(®isterArgs.output, "o", "", "output file (default: stdout)") + fs.StringVar(®isterArgs.hostname, "hostname", "", "hostname to register") + fs.BoolVar(®isterArgs.ephemeral, "ephemeral", false, "register as ephemeral node") + fs.StringVar(®isterArgs.authKey, "auth-key", "", "pre-authorized auth key or file containing one") + fs.StringVar(®isterArgs.tags, "tags", "", "comma-separated ACL tags") + return fs + })(), + Exec: runRegister, +} + +func runRegister(ctx context.Context, args []string) error { + if registerArgs.nodeFile == "" { + return fmt.Errorf("flag -n (node file) is required") + } + + nf, err := tsp.ReadNodeFile(registerArgs.nodeFile) + if err != nil { + return fmt.Errorf("reading node file: %w", err) + } + + hi := hostinfo.New() + if registerArgs.hostname != "" { + hi.Hostname = registerArgs.hostname + } + + var tags []string + if registerArgs.tags != "" { + tags = strings.Split(registerArgs.tags, ",") + } + + authKey, err := resolveAuthKey(registerArgs.authKey) + if err != nil { + return err + } + + client, err := tsp.NewClient(tsp.ClientOpts{ + ServerURL: cmp.Or(globalArgs.serverURL, nf.URL), + MachineKey: nf.MachineKey, + }) + if err != nil { + return fmt.Errorf("creating client: %w", err) + } + defer client.Close() + + if globalArgs.controlKeyFile != "" { + controlKey, err := readControlKeyFile(globalArgs.controlKeyFile) + if err != nil { + return fmt.Errorf("reading control key: %w", err) + } + client.SetControlPublicKey(controlKey) + } else { + client.SetControlPublicKey(nf.ServerInfo.Key) + } + + resp, err := client.Register(ctx, tsp.RegisterOpts{ + NodeKey: nf.NodeKey, + Hostinfo: hi, + Ephemeral: registerArgs.ephemeral, + AuthKey: authKey, + Tags: tags, + }) + if err != nil { + return err + } + + out, err := json.MarshalIndent(resp, "", " ") + if err != nil { + return fmt.Errorf("encoding response: %w", err) + } + out = append(out, '\n') + + if err := writeOutput(registerArgs.output, out); err != nil { + return err + } + + if resp.AuthURL != "" { + fmt.Fprintf(os.Stderr, "AuthURL: %s\n", resp.AuthURL) + } + return nil +} + +var mapArgs struct { + nodeFile string + stream bool + peers bool + quiet bool + output string +} + +var mapCmd = &ffcli.Command{ + Name: "map", + ShortUsage: "tsp [-s url] map -n [-stream]", + ShortHelp: "Send a map request to the coordination server.", + FlagSet: (func() *flag.FlagSet { + fs := flag.NewFlagSet("map", flag.ExitOnError) + fs.StringVar(&mapArgs.nodeFile, "n", "", "node JSON file (required)") + fs.BoolVar(&mapArgs.stream, "stream", false, "stream map responses") + fs.BoolVar(&mapArgs.peers, "peers", true, "include peers in map response") + fs.BoolVar(&mapArgs.quiet, "quiet", true, "suppress keepalives and handled c2n ping requests from output") + fs.StringVar(&mapArgs.output, "o", "", "output file (default: stdout)") + return fs + })(), + Exec: runMap, +} + +func runMap(ctx context.Context, args []string) error { + if mapArgs.nodeFile == "" { + return fmt.Errorf("flag -n (node file) is required") + } + + nf, err := tsp.ReadNodeFile(mapArgs.nodeFile) + if err != nil { + return fmt.Errorf("reading node file: %w", err) + } + + if globalArgs.serverURL != "" && globalArgs.serverURL != nf.URL { + return fmt.Errorf("server URL mismatch: -s flag is %q but node file is for %q", globalArgs.serverURL, nf.URL) + } + + hi := hostinfo.New() + + client, err := tsp.NewClient(tsp.ClientOpts{ + ServerURL: cmp.Or(globalArgs.serverURL, nf.URL), + MachineKey: nf.MachineKey, + }) + if err != nil { + return fmt.Errorf("creating client: %w", err) + } + defer client.Close() + + if globalArgs.controlKeyFile != "" { + controlKey, err := readControlKeyFile(globalArgs.controlKeyFile) + if err != nil { + return fmt.Errorf("reading control key: %w", err) + } + client.SetControlPublicKey(controlKey) + } else { + client.SetControlPublicKey(nf.ServerInfo.Key) + } + + session, err := client.Map(ctx, tsp.MapOpts{ + NodeKey: nf.NodeKey, + Hostinfo: hi, + Stream: mapArgs.stream, + OmitPeers: !mapArgs.peers, + }) + if err != nil { + return err + } + defer session.Close() + + gotResponse := false + for { + resp, err := session.Next() + if err == io.EOF { + if !gotResponse { + return fmt.Errorf("server returned no map response") + } + return nil + } + if err != nil { + return fmt.Errorf("reading map response: %w", err) + } + gotResponse = true + + if pr := resp.PingRequest; pr != nil && pr.Types == "c2n" { + if client.AnswerC2NPing(ctx, pr, session.NoiseRoundTrip) && mapArgs.quiet { + resp.PingRequest = nil + } + } + if mapArgs.quiet { + resp.KeepAlive = false + } + + if isZeroMapResponse(resp) { + continue + } + + out, err := json.MarshalIndent(resp, "", " ") + if err != nil { + return fmt.Errorf("encoding response: %w", err) + } + out = append(out, '\n') + if err := writeOutput(mapArgs.output, out); err != nil { + return err + } + } +} + +// readMachineKeyFile reads a machine private key from a file. +func readMachineKeyFile(path string) (key.MachinePrivate, error) { + data, err := os.ReadFile(path) + if err != nil { + return key.MachinePrivate{}, err + } + var k key.MachinePrivate + if err := k.UnmarshalText(bytes.TrimSpace(data)); err != nil { + return key.MachinePrivate{}, fmt.Errorf("parsing machine key from %q: %w", path, err) + } + return k, nil +} + +// readNodeKeyFile reads a node private key from a file. +func readNodeKeyFile(path string) (key.NodePrivate, error) { + data, err := os.ReadFile(path) + if err != nil { + return key.NodePrivate{}, err + } + var k key.NodePrivate + if err := k.UnmarshalText(bytes.TrimSpace(data)); err != nil { + return key.NodePrivate{}, fmt.Errorf("parsing node key from %q: %w", path, err) + } + return k, nil +} + +// readControlKeyFile reads a file containing a server's MachinePublic key +// in its MarshalText form (e.g. "mkey:..."). +func readControlKeyFile(path string) (key.MachinePublic, error) { + data, err := os.ReadFile(path) + if err != nil { + return key.MachinePublic{}, err + } + var k key.MachinePublic + if err := k.UnmarshalText(bytes.TrimSpace(data)); err != nil { + return key.MachinePublic{}, fmt.Errorf("parsing control key from %q: %w", path, err) + } + return k, nil +} + +// resolveAuthKey returns the auth key from v. If v is empty, it returns "". +// If v starts with "tskey-", it's used directly. Otherwise v is treated as a +// filename and its contents are read and trimmed. +func resolveAuthKey(v string) (string, error) { + if v == "" { + return "", nil + } + if strings.HasPrefix(strings.TrimSpace(v), "tskey-") { + return strings.TrimSpace(v), nil + } + data, err := os.ReadFile(v) + if err != nil { + return "", fmt.Errorf("reading auth key file: %w", err) + } + return strings.TrimSpace(string(data)), nil +} + +func writeOutput(path string, data []byte) error { + if path == "" { + _, err := os.Stdout.Write(data) + return err + } + return os.WriteFile(path, data, 0600) +} + +// isZeroMapResponse reports whether all fields of resp are zero values. +func isZeroMapResponse(resp *tailcfg.MapResponse) bool { + v := reflect.ValueOf(*resp) + for i := range v.NumField() { + if !v.Field(i).IsZero() { + return false + } + } + return true +} diff --git a/cmd/tta/bypass_linux.go b/cmd/tta/bypass_linux.go new file mode 100644 index 000000000..868cd716f --- /dev/null +++ b/cmd/tta/bypass_linux.go @@ -0,0 +1,39 @@ +// Copyright (c) Tailscale Inc & contributors +// SPDX-License-Identifier: BSD-3-Clause + +package main + +import ( + "fmt" + "syscall" + + "golang.org/x/sys/unix" + "tailscale.com/net/netmon" +) + +// bypassControlFunc is set as net.Dialer.Control so that sockets dialed by +// TTA bypass tailscaled's policy routing. Without it, sockets opened before +// tailscaled installs an exit-node route would have their packets rerouted +// via the exit node when the route is later installed, breaking the +// existing connection. +// +// We bind the socket to the default route's interface (typically the VM's +// LAN-facing NIC) rather than relying on the bypass fwmark. The fwmark +// approach is conditional on tailscaled having configured SO_MARK-based +// policy routing; binding to the underlying interface is unconditional. +func bypassControlFunc(network, address string, c syscall.RawConn) error { + ifc, err := netmon.DefaultRouteInterface() + if err != nil { + return fmt.Errorf("netmon.DefaultRouteInterface: %w", err) + } + var sockErr error + if err := c.Control(func(fd uintptr) { + sockErr = unix.SetsockoptString(int(fd), unix.SOL_SOCKET, unix.SO_BINDTODEVICE, ifc) + }); err != nil { + return err + } + if sockErr != nil { + return fmt.Errorf("setting SO_BINDTODEVICE on %q: %w", ifc, sockErr) + } + return nil +} diff --git a/cmd/tta/bypass_other.go b/cmd/tta/bypass_other.go new file mode 100644 index 000000000..e6b453f49 --- /dev/null +++ b/cmd/tta/bypass_other.go @@ -0,0 +1,14 @@ +// Copyright (c) Tailscale Inc & contributors +// SPDX-License-Identifier: BSD-3-Clause + +//go:build !linux + +package main + +import "syscall" + +// bypassControlFunc is a no-op on non-Linux platforms; SO_MARK is a Linux +// concept and exit-node routing only matters here for Linux VMs in vmtest. +func bypassControlFunc(network, address string, c syscall.RawConn) error { + return nil +} diff --git a/cmd/tta/ipassign_darwin.go b/cmd/tta/ipassign_darwin.go new file mode 100644 index 000000000..69a178956 --- /dev/null +++ b/cmd/tta/ipassign_darwin.go @@ -0,0 +1,135 @@ +// Copyright (c) Tailscale Inc & contributors +// SPDX-License-Identifier: BSD-3-Clause + +//go:build darwin + +package main + +import ( + "encoding/json" + "fmt" + "log" + "net" + "os/exec" + "strconv" + "time" + "unsafe" + + "golang.org/x/sys/unix" + "tailscale.com/tstest/natlab/vnet" +) + +const ( + afVSOCK = 40 // AF_VSOCK on macOS + vmaddrCIDHost = 2 // VMADDR_CID_HOST + vsockPort = 51011 // port for IP assignment protocol +) + +// sockaddrVM is the Go equivalent of struct sockaddr_vm from . +type sockaddrVM struct { + Len uint8 + Family uint8 + Reserved1 uint16 + Port uint32 + CID uint32 +} + +type netConfig struct { + IP string `json:"ip"` + Mask string `json:"mask"` + GW string `json:"gw"` +} + +// startIPAssignLoop starts a background goroutine that polls the host +// via the virtio socket for an IP assignment. When the host responds +// with a JSON config (rather than "wait"), TTA sets the IP statically +// using ifconfig and stops polling. +func startIPAssignLoop() { + go ipAssignLoop() +} + +func ipAssignLoop() { + log.Printf("ipassign: starting vsock poll loop") + var lastErr string + for attempt := 0; ; attempt++ { + resp, err := askHostForIP() + if err != nil { + if e := err.Error(); e != lastErr { + log.Printf("ipassign: attempt %d: %v", attempt, err) + lastErr = e + } + time.Sleep(500 * time.Millisecond) + continue + } + if resp == "wait" { + time.Sleep(500 * time.Millisecond) + continue + } + var nc netConfig + if err := json.Unmarshal([]byte(resp), &nc); err != nil { + log.Printf("ipassign: bad config: %v", err) + time.Sleep(500 * time.Millisecond) + continue + } + if err := setStaticIP(nc); err != nil { + log.Printf("ipassign: %v", err) + time.Sleep(500 * time.Millisecond) + continue + } + log.Printf("ipassign: configured en0 with %s/%s gw %s", nc.IP, nc.Mask, nc.GW) + + // Switch the driver address from the DNS name to the IP directly + // (avoids DNS resolution delay) and kick the dial-out loop so it + // retries immediately with the new address. + ipAddr := net.JoinHostPort(vnet.TestDriverIPv4().String(), strconv.Itoa(vnet.TestDriverPort)) + *driverAddr = ipAddr + log.Printf("ipassign: switched driver addr to %s", ipAddr) + resetDialCancels() + return + } +} + +// askHostForIP connects to the host via AF_VSOCK and reads the response. +func askHostForIP() (string, error) { + fd, err := unix.Socket(afVSOCK, unix.SOCK_STREAM, 0) + if err != nil { + return "", fmt.Errorf("socket: %w", err) + } + defer unix.Close(fd) + + // Set a short connect+read timeout via SO_RCVTIMEO. + tv := unix.Timeval{Sec: 1} + unix.SetsockoptTimeval(fd, unix.SOL_SOCKET, unix.SO_RCVTIMEO, &tv) + + addr := sockaddrVM{ + Len: uint8(unsafe.Sizeof(sockaddrVM{})), + Family: afVSOCK, + Port: vsockPort, + CID: vmaddrCIDHost, + } + _, _, errno := unix.RawSyscall(unix.SYS_CONNECT, uintptr(fd), + uintptr(unsafe.Pointer(&addr)), unsafe.Sizeof(addr)) + if errno != 0 { + return "", fmt.Errorf("connect: %w", errno) + } + + var buf [1024]byte + n, err := unix.Read(fd, buf[:]) + if err != nil { + return "", fmt.Errorf("read: %w", err) + } + return string(buf[:n]), nil +} + +// setStaticIP configures en0 with a static IP address and default route. +func setStaticIP(nc netConfig) error { + out, err := exec.Command("ifconfig", "en0", nc.IP, "netmask", nc.Mask, "up").CombinedOutput() + if err != nil { + return fmt.Errorf("ifconfig: %v: %s", err, out) + } + out, err = exec.Command("route", "add", "default", nc.GW).CombinedOutput() + if err != nil { + return fmt.Errorf("route add: %v: %s", err, out) + } + return nil +} diff --git a/cmd/tta/ipassign_other.go b/cmd/tta/ipassign_other.go new file mode 100644 index 000000000..dc331b5e0 --- /dev/null +++ b/cmd/tta/ipassign_other.go @@ -0,0 +1,14 @@ +// Copyright (c) Tailscale Inc & contributors +// SPDX-License-Identifier: BSD-3-Clause + +//go:build !darwin + +package main + +// startIPAssignLoop is a no-op on non-macOS platforms. +// macOS VMs use vsock-based IP assignment to bypass slow DHCP. +func startIPAssignLoop() {} + +// Reference resetDialCancels to prevent unused-function lint errors. +// It's called from ipassign_darwin.go on macOS builds. +var _ = resetDialCancels diff --git a/cmd/tta/restart_tailscaled_linux.go b/cmd/tta/restart_tailscaled_linux.go new file mode 100644 index 000000000..accf2b404 --- /dev/null +++ b/cmd/tta/restart_tailscaled_linux.go @@ -0,0 +1,47 @@ +// Copyright (c) Tailscale Inc & contributors +// SPDX-License-Identifier: BSD-3-Clause + +package main + +import ( + "fmt" + "os" + "strconv" + "strings" +) + +func init() { + restartTailscaled = restartTailscaledLinux +} + +// restartTailscaledLinux finds the tailscaled process by walking /proc and +// sends it SIGKILL. On gokrazy, the supervisor will restart tailscaled within +// a few seconds. The PID of the process that was killed is returned. +func restartTailscaledLinux() (int, error) { + ents, err := os.ReadDir("/proc") + if err != nil { + return 0, err + } + for _, e := range ents { + pid, err := strconv.Atoi(e.Name()) + if err != nil { + continue + } + comm, err := os.ReadFile("/proc/" + e.Name() + "/comm") + if err != nil { + continue + } + if strings.TrimSpace(string(comm)) != "tailscaled" { + continue + } + proc, err := os.FindProcess(pid) + if err != nil { + return 0, err + } + if err := proc.Kill(); err != nil { + return 0, fmt.Errorf("killing tailscaled pid %d: %w", pid, err) + } + return pid, nil + } + return 0, fmt.Errorf("tailscaled process not found in /proc") +} diff --git a/cmd/tta/tta.go b/cmd/tta/tta.go index cf5dc4162..5dd1eddb9 100644 --- a/cmd/tta/tta.go +++ b/cmd/tta/tta.go @@ -24,7 +24,9 @@ import ( "net/url" "os" "os/exec" + "path/filepath" "regexp" + "runtime" "strconv" "strings" "sync" @@ -103,6 +105,10 @@ func main() { } flag.Parse() + // On macOS VMs, start polling the host via vsock for an IP assignment. + // This bypasses DHCP for near-instant network configuration. + startIPAssignLoop() + debug := false if distro.Get() == distro.Gokrazy { cmdLine, _ := os.ReadFile("/proc/cmdline") @@ -202,6 +208,9 @@ func main() { if routes := r.URL.Query().Get("advertise-routes"); routes != "" { args = append(args, "--advertise-routes="+routes) } + if snat := r.URL.Query().Get("snat-subnet-routes"); snat != "" { + args = append(args, "--snat-subnet-routes="+snat) + } serveCmd(w, "tailscale", args...) }) ttaMux.HandleFunc("/ip", func(w http.ResponseWriter, r *http.Request) { @@ -222,6 +231,20 @@ func main() { serveCmd(w, "ping", "-c", "4", "-W", "1", host) } }) + ttaMux.HandleFunc("/add-route", func(w http.ResponseWriter, r *http.Request) { + prefix := r.URL.Query().Get("prefix") + via := r.URL.Query().Get("via") + if prefix == "" || via == "" { + http.Error(w, "missing prefix or via", http.StatusBadRequest) + return + } + switch runtime.GOOS { + case "linux": + serveCmd(w, "ip", "route", "add", prefix, "via", via) + default: + http.Error(w, "add-route not supported on "+runtime.GOOS, http.StatusNotImplemented) + } + }) ttaMux.HandleFunc("/start-webserver", func(w http.ResponseWriter, r *http.Request) { port := r.URL.Query().Get("port") name := r.URL.Query().Get("name") @@ -236,7 +259,8 @@ func main() { go func() { mux := http.NewServeMux() mux.HandleFunc("/", func(w http.ResponseWriter, r *http.Request) { - fmt.Fprintf(w, "Hello world I am %s", name) + host, _, _ := net.SplitHostPort(r.RemoteAddr) + fmt.Fprintf(w, "Hello world I am %s from %s", name, host) }) if err := http.ListenAndServe(":"+port, mux); err != nil { log.Printf("webserver on :%s failed: %v", port, err) @@ -244,6 +268,72 @@ func main() { }() io.WriteString(w, "OK\n") }) + ttaMux.HandleFunc("/taildrop-send", func(w http.ResponseWriter, r *http.Request) { + to := r.URL.Query().Get("to") // peer's Tailscale IP + name := r.URL.Query().Get("name") + if to == "" || name == "" { + http.Error(w, "missing to or name", http.StatusBadRequest) + return + } + if strings.ContainsAny(name, "/\\") { + http.Error(w, "bad name", http.StatusBadRequest) + return + } + dir, err := os.MkdirTemp("", "taildrop-send-") + if err != nil { + http.Error(w, err.Error(), http.StatusInternalServerError) + return + } + defer os.RemoveAll(dir) + path := filepath.Join(dir, name) + f, err := os.Create(path) + if err != nil { + http.Error(w, err.Error(), http.StatusInternalServerError) + return + } + if _, err := io.Copy(f, r.Body); err != nil { + f.Close() + http.Error(w, err.Error(), http.StatusInternalServerError) + return + } + if err := f.Close(); err != nil { + http.Error(w, err.Error(), http.StatusInternalServerError) + return + } + serveCmd(w, "tailscale", "file", "cp", path, to+":") + }) + ttaMux.HandleFunc("/taildrop-recv", func(w http.ResponseWriter, r *http.Request) { + dir, err := os.MkdirTemp("", "taildrop-recv-") + if err != nil { + http.Error(w, err.Error(), http.StatusInternalServerError) + return + } + defer os.RemoveAll(dir) + ctx, cancel := context.WithTimeout(r.Context(), 60*time.Second) + defer cancel() + cmd := exec.CommandContext(ctx, absify("tailscale"), "file", "get", "--wait", dir) + if out, err := cmd.CombinedOutput(); err != nil { + http.Error(w, fmt.Sprintf("tailscale file get: %v\n%s", err, out), http.StatusInternalServerError) + return + } + ents, err := os.ReadDir(dir) + if err != nil { + http.Error(w, err.Error(), http.StatusInternalServerError) + return + } + if len(ents) != 1 { + http.Error(w, fmt.Sprintf("got %d files, want 1", len(ents)), http.StatusInternalServerError) + return + } + data, err := os.ReadFile(filepath.Join(dir, ents[0].Name())) + if err != nil { + http.Error(w, err.Error(), http.StatusInternalServerError) + return + } + w.Header().Set("Taildrop-Filename", ents[0].Name()) + w.Header().Set("Content-Type", "application/octet-stream") + w.Write(data) + }) ttaMux.HandleFunc("/http-get", func(w http.ResponseWriter, r *http.Request) { targetURL := r.URL.Query().Get("url") if targetURL == "" { @@ -294,6 +384,25 @@ func main() { io.Copy(w, resp.Body) }) ttaMux.HandleFunc("/fw", addFirewallHandler) + ttaMux.HandleFunc("/wg-server-up", func(w http.ResponseWriter, r *http.Request) { + if wgServerUp == nil { + http.Error(w, "wg-server-up not supported on this platform", http.StatusNotImplemented) + return + } + wgServerUp(w, r) + }) + ttaMux.HandleFunc("/restart-tailscaled", func(w http.ResponseWriter, r *http.Request) { + if restartTailscaled == nil { + http.Error(w, "restart-tailscaled not supported on this platform", http.StatusNotImplemented) + return + } + pid, err := restartTailscaled() + if err != nil { + http.Error(w, err.Error(), http.StatusInternalServerError) + return + } + fmt.Fprintf(w, "killed tailscaled pid %d (supervisor will respawn)\n", pid) + }) ttaMux.HandleFunc("/logs", func(w http.ResponseWriter, r *http.Request) { logBuf.mu.Lock() defer logBuf.mu.Unlock() @@ -315,10 +424,48 @@ func main() { revSt.runDialOutLoop(conns) } +// dialCancels tracks cancel funcs for in-flight connect() and sleep contexts. +// resetDialCancels cancels them all so the dial loop retries immediately. +var ( + dialCancelMu sync.Mutex + dialCancels set.HandleSet[context.CancelFunc] +) + +// registerDialCancel adds a cancel func and returns a handle for removal. +func registerDialCancel(cancel context.CancelFunc) set.Handle { + dialCancelMu.Lock() + defer dialCancelMu.Unlock() + return dialCancels.Add(cancel) +} + +// unregisterDialCancel removes a previously registered cancel func. +func unregisterDialCancel(h set.Handle) { + dialCancelMu.Lock() + defer dialCancelMu.Unlock() + delete(dialCancels, h) +} + +// resetDialCancels cancels all in-flight connect and sleep contexts, +// causing the dial loop to retry immediately with the updated driver address. +func resetDialCancels() { + dialCancelMu.Lock() + defer dialCancelMu.Unlock() + for h, cancel := range dialCancels { + cancel() + delete(dialCancels, h) + } +} + func connect() (net.Conn, error) { - var d net.Dialer + d := net.Dialer{ + Control: bypassControlFunc, + } ctx, cancel := context.WithTimeout(context.Background(), 3*time.Second) - defer cancel() + h := registerDialCancel(cancel) + defer func() { + cancel() + unregisterDialCancel(h) + }() c, err := d.DialContext(ctx, "tcp", *driverAddr) if err != nil { return nil, err @@ -415,7 +562,11 @@ func (s *revDialState) runDialOutLoop(conns chan<- net.Conn) { log.Printf("[dial-driver] connect failure: %v", s) } lastErr = s - time.Sleep(time.Second) + sleepCtx, sleepCancel := context.WithTimeout(context.Background(), time.Second) + h := registerDialCancel(sleepCancel) + <-sleepCtx.Done() + sleepCancel() + unregisterDialCancel(h) continue } if !connected { @@ -456,6 +607,16 @@ func addFirewallHandler(w http.ResponseWriter, r *http.Request) { var addFirewall func() error // set by fw_linux.go +// wgServerUp brings up a userspace WireGuard "Mullvad-style" exit-node +// server on this VM. It is set by wgserver_linux.go and is nil on +// non-Linux. +var wgServerUp func(w http.ResponseWriter, r *http.Request) + +// restartTailscaled sends SIGKILL to the local tailscaled process so the +// gokrazy supervisor restarts it. It is set by restart_tailscaled_linux.go +// and is nil on non-Linux. +var restartTailscaled func() (pid int, err error) + // logBuffer is a bytes.Buffer that is safe for concurrent use // intended to capture early logs from the process, even if // gokrazy's syslog streaming isn't working or yet working. diff --git a/cmd/tta/wgserver_linux.go b/cmd/tta/wgserver_linux.go new file mode 100644 index 000000000..10d6bbe28 --- /dev/null +++ b/cmd/tta/wgserver_linux.go @@ -0,0 +1,155 @@ +// Copyright (c) Tailscale Inc & contributors +// SPDX-License-Identifier: BSD-3-Clause + +package main + +import ( + "cmp" + "crypto/rand" + "encoding/base64" + "encoding/hex" + "fmt" + "log" + "net/http" + "os" + "os/exec" + "sync" + + "github.com/tailscale/wireguard-go/conn" + "github.com/tailscale/wireguard-go/device" + "github.com/tailscale/wireguard-go/tun" + "golang.org/x/crypto/curve25519" + "tailscale.com/wgengine/wgcfg" +) + +func init() { + wgServerUp = wgServerUpLinux +} + +var ( + wgServerMu sync.Mutex + wgServerDev *device.Device // retained so the goroutines stay alive +) + +// wgServerUpLinux brings up a userspace WireGuard interface on the local VM +// configured as a single-peer "Mullvad-style" exit node, then sets up the +// kernel-side IP/forwarding/MASQUERADE so that decrypted traffic from the +// peer egresses to the test internet. +// +// Required URL query parameters: +// - addr: CIDR for the WG interface (e.g. "10.64.0.1/24") +// - listen-port: WG listen port +// - peer-pub-b64: base64-encoded 32-byte WG public key of the only peer +// - peer-allowed-ip: prefix the peer is allowed to source from +// (e.g. "10.64.0.2/32") +// - masq-src: prefix to MASQUERADE on egress (e.g. "10.64.0.0/24") +// +// Optional: +// - name: TUN device name (default "wg0") +// +// On success, it writes "PUBKEY=\n" — the freshly generated public +// key the caller must pin as the peer's WG public key. +func wgServerUpLinux(w http.ResponseWriter, r *http.Request) { + wgServerMu.Lock() + defer wgServerMu.Unlock() + if wgServerDev != nil { + http.Error(w, "wg server already up", http.StatusConflict) + return + } + + q := r.URL.Query() + name := cmp.Or(q.Get("name"), "wg0") + addr := q.Get("addr") + listenPort := q.Get("listen-port") + peerPubB64 := q.Get("peer-pub-b64") + peerAllowedIP := q.Get("peer-allowed-ip") + masqSrc := q.Get("masq-src") + for _, kv := range []struct{ k, v string }{ + {"addr", addr}, + {"listen-port", listenPort}, + {"peer-pub-b64", peerPubB64}, + {"peer-allowed-ip", peerAllowedIP}, + {"masq-src", masqSrc}, + } { + if kv.v == "" { + http.Error(w, "missing "+kv.k, http.StatusBadRequest) + return + } + } + + peerPub, err := base64.StdEncoding.DecodeString(peerPubB64) + if err != nil || len(peerPub) != 32 { + http.Error(w, fmt.Sprintf("bad peer-pub-b64: %v (len=%d)", err, len(peerPub)), http.StatusBadRequest) + return + } + + var priv [32]byte + if _, err := rand.Read(priv[:]); err != nil { + http.Error(w, "rand: "+err.Error(), http.StatusInternalServerError) + return + } + // X25519 key clamping. + priv[0] &= 248 + priv[31] = (priv[31] & 127) | 64 + + pub, err := curve25519.X25519(priv[:], curve25519.Basepoint) + if err != nil { + http.Error(w, "deriving pubkey: "+err.Error(), http.StatusInternalServerError) + return + } + + tdev, err := tun.CreateTUN(name, device.DefaultMTU) + if err != nil { + http.Error(w, "tun.CreateTUN: "+err.Error(), http.StatusInternalServerError) + return + } + wglog := &device.Logger{ + Verbosef: func(string, ...any) {}, + Errorf: func(f string, a ...any) { log.Printf("wg-server: "+f, a...) }, + } + dev := wgcfg.NewDevice(tdev, conn.NewDefaultBind(), wglog) + + uapi := fmt.Sprintf("private_key=%s\nlisten_port=%s\npublic_key=%s\nallowed_ip=%s\n", + hex.EncodeToString(priv[:]), listenPort, + hex.EncodeToString(peerPub), peerAllowedIP) + if err := dev.IpcSet(uapi); err != nil { + dev.Close() + http.Error(w, "IpcSet: "+err.Error(), http.StatusInternalServerError) + return + } + if err := dev.Up(); err != nil { + dev.Close() + http.Error(w, "dev.Up: "+err.Error(), http.StatusInternalServerError) + return + } + + steps := []struct { + why string + exec []string + file struct{ path, data string } + }{ + {why: "ip addr add", exec: []string{"ip", "addr", "add", addr, "dev", name}}, + {why: "ip link up", exec: []string{"ip", "link", "set", name, "up"}}, + {why: "enable forwarding", file: struct{ path, data string }{"/proc/sys/net/ipv4/ip_forward", "1\n"}}, + {why: "FORWARD policy", exec: []string{"iptables", "-P", "FORWARD", "ACCEPT"}}, + {why: "MASQUERADE", exec: []string{"iptables", "-t", "nat", "-A", "POSTROUTING", "-s", masqSrc, "-j", "MASQUERADE"}}, + } + for _, s := range steps { + if s.file.path != "" { + if err := os.WriteFile(s.file.path, []byte(s.file.data), 0644); err != nil { + dev.Close() + http.Error(w, fmt.Sprintf("%s: %v", s.why, err), http.StatusInternalServerError) + return + } + continue + } + if out, err := exec.Command(s.exec[0], s.exec[1:]...).CombinedOutput(); err != nil { + dev.Close() + http.Error(w, fmt.Sprintf("%s: %v: %s", s.why, err, out), http.StatusInternalServerError) + return + } + } + + wgServerDev = dev + fmt.Fprintf(w, "PUBKEY=%s\n", base64.StdEncoding.EncodeToString(pub)) +} diff --git a/cmd/vet/lowerell/analyzer.go b/cmd/vet/lowerell/analyzer.go new file mode 100644 index 000000000..a62f79bdc --- /dev/null +++ b/cmd/vet/lowerell/analyzer.go @@ -0,0 +1,132 @@ +// Copyright (c) Tailscale Inc & contributors +// SPDX-License-Identifier: BSD-3-Clause + +// Package lowerell forbids variables named "l" (lowercase ell) or "I" +// (uppercase i), because they are hard to distinguish from the digit +// "1" and from each other in too many fonts. +package lowerell + +import ( + "go/ast" + "go/token" + + "golang.org/x/tools/go/analysis" +) + +// Analyzer reports variables named "l" (lowercase ell) or "I" (uppercase i). +var Analyzer = &analysis.Analyzer{ + Name: "lowerell", + Doc: `forbid variables named "l" (lowercase ell) or "I" (uppercase i), which are hard to distinguish from "1"`, + Run: run, +} + +// messages maps a banned identifier name to the diagnostic shown to users. +// Each message names the specific symbol that triggered it, so the +// reader does not have to guess which of "l" or "I" they typed. +var messages = map[string]string{ + "l": `do not use "l" (lowercase ell) as a variable name; it is hard to distinguish from "1" and "I" in too many fonts; see https://github.com/tailscale/tailscale/issues/19631`, + "I": `do not use "I" (uppercase i) as a variable name; it is hard to distinguish from "1" and "l" in too many fonts; see https://github.com/tailscale/tailscale/issues/19631`, +} + +// reported tracks identifier positions already reported, to avoid duplicate +// diagnostics when the same declaration is reachable from multiple AST nodes. +type reportedSet map[token.Pos]bool + +func (rs reportedSet) check(pass *analysis.Pass, ident *ast.Ident) { + if ident == nil { + return + } + msg, ok := messages[ident.Name] + if !ok { + return + } + if rs[ident.Pos()] { + return + } + rs[ident.Pos()] = true + pass.Reportf(ident.Pos(), "%s", msg) +} + +func (rs reportedSet) checkFieldList(pass *analysis.Pass, fl *ast.FieldList) { + if fl == nil { + return + } + for _, f := range fl.List { + for _, n := range f.Names { + rs.check(pass, n) + } + } +} + +func run(pass *analysis.Pass) (any, error) { + rs := reportedSet{} + + for _, file := range pass.Files { + ast.Inspect(file, func(n ast.Node) bool { + switch n := n.(type) { + case *ast.FuncDecl: + // Receiver name. + rs.checkFieldList(pass, n.Recv) + // Parameters, results, and type parameters + // are checked via the FuncType case below. + + case *ast.FuncType: + rs.checkFieldList(pass, n.TypeParams) + rs.checkFieldList(pass, n.Params) + rs.checkFieldList(pass, n.Results) + + case *ast.StructType: + rs.checkFieldList(pass, n.Fields) + + case *ast.GenDecl: + if n.Tok != token.VAR && n.Tok != token.CONST { + return true + } + for _, spec := range n.Specs { + vs, ok := spec.(*ast.ValueSpec) + if !ok { + continue + } + for _, name := range vs.Names { + rs.check(pass, name) + } + } + + case *ast.AssignStmt: + if n.Tok != token.DEFINE { + return true + } + for _, lhs := range n.Lhs { + if id, ok := lhs.(*ast.Ident); ok { + rs.check(pass, id) + } + } + + case *ast.RangeStmt: + if n.Tok != token.DEFINE { + return true + } + if id, ok := n.Key.(*ast.Ident); ok { + rs.check(pass, id) + } + if id, ok := n.Value.(*ast.Ident); ok { + rs.check(pass, id) + } + + case *ast.TypeSwitchStmt: + // switch l := x.(type) { ... } + as, ok := n.Assign.(*ast.AssignStmt) + if !ok || as.Tok != token.DEFINE { + return true + } + for _, lhs := range as.Lhs { + if id, ok := lhs.(*ast.Ident); ok { + rs.check(pass, id) + } + } + } + return true + }) + } + return nil, nil +} diff --git a/cmd/vet/lowerell/analyzer_test.go b/cmd/vet/lowerell/analyzer_test.go new file mode 100644 index 000000000..c566c2ec4 --- /dev/null +++ b/cmd/vet/lowerell/analyzer_test.go @@ -0,0 +1,15 @@ +// Copyright (c) Tailscale Inc & contributors +// SPDX-License-Identifier: BSD-3-Clause + +package lowerell + +import ( + "testing" + + "golang.org/x/tools/go/analysis/analysistest" +) + +func TestAnalyzer(t *testing.T) { + testdata := analysistest.TestData() + analysistest.Run(t, testdata, Analyzer, "example") +} diff --git a/cmd/vet/lowerell/testdata/src/example/example.go b/cmd/vet/lowerell/testdata/src/example/example.go new file mode 100644 index 000000000..c67c19781 --- /dev/null +++ b/cmd/vet/lowerell/testdata/src/example/example.go @@ -0,0 +1,100 @@ +// Copyright (c) Tailscale Inc & contributors +// SPDX-License-Identifier: BSD-3-Clause + +package example + +import "sync" + +// Bad: var declarations. +var l int // want `do not use "l"` +var I int // want `do not use "I"` + +// OK: variables named "ll", "II", "i" are fine. +var ( + ll int + II int + i int +) + +// Bad: const declaration in a function scope. +func F0() { + const l = 3 // want `do not use "l"` + const I = 4 // want `do not use "I"` + _ = l + _ = I +} + +// Bad: function parameters. +func F1a(l int) {} // want `do not use "l"` +func F1b(I int) {} // want `do not use "I"` + +// Bad: named return values. +func F2a() (l int) { return } // want `do not use "l"` +func F2b() (I int) { return } // want `do not use "I"` + +// Bad: receiver names. +type T struct{} + +func (l *T) Ml() {} // want `do not use "l"` +func (I *T) MI() {} // want `do not use "I"` + +// Bad: struct fields. +type S struct { + l int // want `do not use "l"` + I int // want `do not use "I"` +} + +// Bad: short variable declarations. +func F3() { + l := 1 // want `do not use "l"` + I := 2 // want `do not use "I"` + _ = l + _ = I +} + +// Bad: var statement inside a function. +func F4() { + var l int // want `do not use "l"` + var I int // want `do not use "I"` + _ = l + _ = I +} + +// Bad: range key/value. +func F5(xs []int) { + for l, v := range xs { // want `do not use "l"` + _ = l + _ = v + } + for _, I := range xs { // want `do not use "I"` + _ = I + } +} + +// Bad: type parameters. +func F6a[l any](x l) l { return x } // want `do not use "l"` +func F6b[I any](x I) I { return x } // want `do not use "I"` + +// Bad: type switch guards. +func F7(x any) { + switch l := x.(type) { // want `do not use "l"` + case int: + _ = l + } + switch I := x.(type) { // want `do not use "I"` + case int: + _ = I + } +} + +// OK: clean code with no banned variables. +func F8() { + count := 0 + for i := 0; i < 10; i++ { + count++ + } + _ = count +} + +// OK: sync.Mutex named "mu". +var mu sync.Mutex diff --git a/cmd/vet/vet.go b/cmd/vet/vet.go index 4087f7073..38ffdb6d0 100644 --- a/cmd/vet/vet.go +++ b/cmd/vet/vet.go @@ -9,6 +9,7 @@ import ( "golang.org/x/tools/go/analysis/unitchecker" "tailscale.com/cmd/vet/jsontags" + "tailscale.com/cmd/vet/lowerell" "tailscale.com/cmd/vet/subtestnames" ) @@ -21,5 +22,5 @@ func init() { } func main() { - unitchecker.Main(jsontags.Analyzer, subtestnames.Analyzer) + unitchecker.Main(jsontags.Analyzer, lowerell.Analyzer, subtestnames.Analyzer) } diff --git a/cmd/viewer/tests/tests_clone.go b/cmd/viewer/tests/tests_clone.go index 08cae87e2..bc576ec97 100644 --- a/cmd/viewer/tests/tests_clone.go +++ b/cmd/viewer/tests/tests_clone.go @@ -96,14 +96,36 @@ func (src *Map) Clone() *Map { dst.StructWithoutPtr = maps.Clone(src.StructWithoutPtr) if dst.SlicesWithPtrs != nil { dst.SlicesWithPtrs = map[string][]*StructWithPtrs{} - for k := range src.SlicesWithPtrs { - dst.SlicesWithPtrs[k] = append([]*StructWithPtrs{}, src.SlicesWithPtrs[k]...) + for k, sv := range src.SlicesWithPtrs { + if sv == nil { + dst.SlicesWithPtrs[k] = nil + continue + } + dst.SlicesWithPtrs[k] = make([]*StructWithPtrs, len(sv)) + for i := range sv { + if sv[i] == nil { + dst.SlicesWithPtrs[k][i] = nil + } else { + dst.SlicesWithPtrs[k][i] = sv[i].Clone() + } + } } } if dst.SlicesWithoutPtrs != nil { dst.SlicesWithoutPtrs = map[string][]*StructWithoutPtrs{} - for k := range src.SlicesWithoutPtrs { - dst.SlicesWithoutPtrs[k] = append([]*StructWithoutPtrs{}, src.SlicesWithoutPtrs[k]...) + for k, sv := range src.SlicesWithoutPtrs { + if sv == nil { + dst.SlicesWithoutPtrs[k] = nil + continue + } + dst.SlicesWithoutPtrs[k] = make([]*StructWithoutPtrs, len(sv)) + for i := range sv { + if sv[i] == nil { + dst.SlicesWithoutPtrs[k][i] = nil + } else { + dst.SlicesWithoutPtrs[k][i] = new(*sv[i]) + } + } } } dst.StructWithoutPtrKey = maps.Clone(src.StructWithoutPtrKey) @@ -115,8 +137,19 @@ func (src *Map) Clone() *Map { } if dst.SliceIntPtr != nil { dst.SliceIntPtr = map[string][]*int{} - for k := range src.SliceIntPtr { - dst.SliceIntPtr[k] = append([]*int{}, src.SliceIntPtr[k]...) + for k, sv := range src.SliceIntPtr { + if sv == nil { + dst.SliceIntPtr[k] = nil + continue + } + dst.SliceIntPtr[k] = make([]*int, len(sv)) + for i := range sv { + if sv[i] == nil { + dst.SliceIntPtr[k][i] = nil + } else { + dst.SliceIntPtr[k][i] = new(*sv[i]) + } + } } } dst.PointerKey = maps.Clone(src.PointerKey) @@ -399,8 +432,15 @@ func (src *GenericCloneableStruct[T, V]) Clone() *GenericCloneableStruct[T, V] { } if dst.SliceMap != nil { dst.SliceMap = map[string][]T{} - for k := range src.SliceMap { - dst.SliceMap[k] = append([]T{}, src.SliceMap[k]...) + for k, sv := range src.SliceMap { + if sv == nil { + dst.SliceMap[k] = nil + continue + } + dst.SliceMap[k] = make([]T, len(sv)) + for i := range sv { + dst.SliceMap[k][i] = sv[i].Clone() + } } } return dst @@ -500,14 +540,36 @@ func (src *StructWithTypeAliasFields) Clone() *StructWithTypeAliasFields { } if dst.MapOfSlicesWithPtrs != nil { dst.MapOfSlicesWithPtrs = map[string][]*StructWithPtrsAlias{} - for k := range src.MapOfSlicesWithPtrs { - dst.MapOfSlicesWithPtrs[k] = append([]*StructWithPtrsAlias{}, src.MapOfSlicesWithPtrs[k]...) + for k, sv := range src.MapOfSlicesWithPtrs { + if sv == nil { + dst.MapOfSlicesWithPtrs[k] = nil + continue + } + dst.MapOfSlicesWithPtrs[k] = make([]*StructWithPtrsAlias, len(sv)) + for i := range sv { + if sv[i] == nil { + dst.MapOfSlicesWithPtrs[k][i] = nil + } else { + dst.MapOfSlicesWithPtrs[k][i] = sv[i].Clone() + } + } } } if dst.MapOfSlicesWithoutPtrs != nil { dst.MapOfSlicesWithoutPtrs = map[string][]*StructWithoutPtrsAlias{} - for k := range src.MapOfSlicesWithoutPtrs { - dst.MapOfSlicesWithoutPtrs[k] = append([]*StructWithoutPtrsAlias{}, src.MapOfSlicesWithoutPtrs[k]...) + for k, sv := range src.MapOfSlicesWithoutPtrs { + if sv == nil { + dst.MapOfSlicesWithoutPtrs[k] = nil + continue + } + dst.MapOfSlicesWithoutPtrs[k] = make([]*StructWithoutPtrsAlias, len(sv)) + for i := range sv { + if sv[i] == nil { + dst.MapOfSlicesWithoutPtrs[k][i] = nil + } else { + dst.MapOfSlicesWithoutPtrs[k][i] = new(*sv[i]) + } + } } } return dst diff --git a/control/controlclient/auto.go b/control/controlclient/auto.go index b087e1444..05c7552c8 100644 --- a/control/controlclient/auto.go +++ b/control/controlclient/auto.go @@ -356,7 +356,15 @@ func (c *Auto) authRoutine() { if err != nil { c.direct.health.SetAuthRoutineInError(err) report(err, f) - bo.BackOff(ctx, err) + if rle, ok := errors.AsType[*rateLimitError](err); ok { + c.logf("authRoutine: %s", rle) + select { + case <-ctx.Done(): + case <-time.After(rle.retryAfter): + } + } else { + bo.BackOff(ctx, err) + } continue } if url != "" { diff --git a/control/controlclient/controlclient_test.go b/control/controlclient/controlclient_test.go index 2205a0eb3..5c25af0f4 100644 --- a/control/controlclient/controlclient_test.go +++ b/control/controlclient/controlclient_test.go @@ -406,6 +406,118 @@ func testHTTPS(t *testing.T, withProxy bool) { } } +// TestRegisterRateLimited verifies that the client correctly handles 429 +// responses to registration requests by parsing the Retry-After header +// and returning a rateLimitError. +func TestRegisterRateLimited(t *testing.T) { + bakedroots.ResetForTest(t, tlstest.TestRootCA()) + + bus := eventbustest.NewBus(t) + + controlLn, err := tls.Listen("tcp", "127.0.0.1:0", tlstest.ControlPlane.ServerTLSConfig()) + if err != nil { + t.Fatal(err) + } + defer controlLn.Close() + + var registerAttempts atomic.Int64 + tc := &testcontrol.Server{ + Logf: tstest.WhileTestRunningLogger(t), + MaybeRateLimitRegister: func() (bool, string, string) { + if registerAttempts.Add(1) == 1 { + return true, "30", "try again later" + } + return false, "", "" + }, + } + controlSrv := &http.Server{ + Handler: tc, + ErrorLog: logger.StdLogger(t.Logf), + } + go controlSrv.Serve(controlLn) + + const fakeControlIP = "1.2.3.4" + + dialer := &tsdial.Dialer{} + dialer.SetNetMon(netmon.NewStatic()) + dialer.SetBus(bus) + dialer.SetSystemDialerForTest(func(ctx context.Context, network, addr string) (net.Conn, error) { + host, _, err := net.SplitHostPort(addr) + if err != nil { + return nil, fmt.Errorf("SplitHostPort(%q): %v", addr, err) + } + var d net.Dialer + if host == fakeControlIP { + return d.DialContext(ctx, network, controlLn.Addr().String()) + } + return nil, fmt.Errorf("unexpected dial to %q", addr) + }) + + opts := Options{ + Persist: persist.Persist{}, + GetMachinePrivateKey: func() (key.MachinePrivate, error) { + return key.NewMachine(), nil + }, + ServerURL: "https://controlplane.tstest", + Clock: tstime.StdClock{}, + Hostinfo: &tailcfg.Hostinfo{ + BackendLogID: "test-backend-log-id", + }, + DiscoPublicKey: key.NewDisco().Public(), + Logf: t.Logf, + HealthTracker: health.NewTracker(bus), + PopBrowserURL: func(url string) { + t.Logf("PopBrowserURL: %q", url) + }, + Dialer: dialer, + Bus: bus, + } + d, err := NewDirect(opts) + if err != nil { + t.Fatalf("NewDirect: %v", err) + } + + d.dnsCache.LookupIPForTest = func(ctx context.Context, host string) ([]netip.Addr, error) { + if host == "controlplane.tstest" { + return []netip.Addr{netip.MustParseAddr(fakeControlIP)}, nil + } + t.Errorf("unexpected DNS query for %q", host) + return nil, fmt.Errorf("unexpected DNS lookup for %q", host) + } + + ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second) + defer cancel() + + // First attempt should get a 429 and return a rateLimitError. + _, err = d.TryLogin(ctx, LoginEphemeral) + if err == nil { + t.Fatal("expected rate limit error on first attempt, got nil") + } + var rle *rateLimitError + if !errors.As(err, &rle) { + t.Fatalf("expected *rateLimitError, got %T: %v", err, err) + } + if rle.retryAfter != 30*time.Second { + t.Errorf("retryAfter = %v, want 30s", rle.retryAfter) + } + if rle.msg != "try again later" { + t.Errorf("msg = %q, want %q", rle.msg, "try again later") + } + + // Second attempt should succeed (server no longer rate-limiting). + url, err := d.TryLogin(ctx, LoginEphemeral) + if err != nil { + t.Fatalf("TryLogin after rate limit: %v", err) + } + if url != "" { + t.Errorf("got URL %q, want empty", url) + } + + if got := registerAttempts.Load(); got != 2 { + t.Errorf("register attempts = %d, want 2", got) + } +} + func connectProxyTo(t testing.TB, target, backendAddrPort string, reqs *atomic.Int64) http.Handler { return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { if r.RequestURI != target { diff --git a/control/controlclient/direct.go b/control/controlclient/direct.go index d873cc745..032999cb9 100644 --- a/control/controlclient/direct.go +++ b/control/controlclient/direct.go @@ -17,6 +17,7 @@ import ( "fmt" "io" "log" + "math/rand/v2" "net" "net/http" "net/netip" @@ -24,6 +25,7 @@ import ( "reflect" "runtime" "slices" + "strconv" "strings" "sync/atomic" "time" @@ -575,6 +577,37 @@ var macOSScreenTime = health.Register(&health.Warnable{ ImpactsConnectivity: true, }) +type rateLimitError struct { + msg string + retryAfter time.Duration +} + +func (e *rateLimitError) Error() string { + return fmt.Sprintf("rate limited: %s (retry after %v)", e.msg, e.retryAfter) +} + +func parseRateLimitError(res *http.Response) *rateLimitError { + msg, _ := io.ReadAll(res.Body) + res.Body.Close() + + ret := &rateLimitError{ + msg: strings.TrimSpace(string(msg)), + } + + v := res.Header.Get("Retry-After") + if i, err := strconv.Atoi(v); err == nil { + ret.retryAfter = time.Duration(i) * time.Second + } else if t, err := http.ParseTime(v); err == nil { + ret.retryAfter = time.Until(t) + } + + // If the server didn't give us a valid Retry-After, default to 10s. + if ret.retryAfter <= 0 || ret.retryAfter > time.Hour { + ret.retryAfter = 5*time.Second + rand.N(5*time.Second) + } + return ret +} + func (c *Direct) doLogin(ctx context.Context, opt loginOpt) (mustRegen bool, newURL string, nks tkatype.MarshaledSignature, err error) { if c.panicOnUse { panic("tainted client") @@ -769,6 +802,12 @@ func (c *Direct) doLogin(ctx context.Context, opt loginOpt) (mustRegen bool, new if err != nil { return regen, opt.URL, nil, fmt.Errorf("register request: %w", err) } + // Handle 429 Too Many Requests with a specific error type that includes the retry-after duration. + if res.StatusCode == 429 { + rle := parseRateLimitError(res) + msg := fmt.Sprintf("node registration rate limited; will retry after %v", rle.retryAfter) + return false, "", nil, vizerror.WrapWithMessage(rle, msg) + } if res.StatusCode != 200 { msg, _ := io.ReadAll(res.Body) res.Body.Close() diff --git a/control/controlclient/direct_test.go b/control/controlclient/direct_test.go index d10b346ae..98741482f 100644 --- a/control/controlclient/direct_test.go +++ b/control/controlclient/direct_test.go @@ -5,9 +5,11 @@ package controlclient import ( "encoding/json" + "errors" "net/http" "net/http/httptest" "net/netip" + "strings" "testing" "time" @@ -126,6 +128,109 @@ func fakeEndpoints(ports ...uint16) (ret []tailcfg.Endpoint) { return } +func TestParseRateLimitError(t *testing.T) { + tests := []struct { + name string + statusCode int + body string + retryAfter string // Retry-After header value + wantMsg string + wantMin time.Duration // minimum expected retryAfter + wantMax time.Duration // maximum expected retryAfter + }{ + { + name: "retry-after-seconds", + statusCode: 429, + body: "too many requests", + retryAfter: "30", + wantMsg: "too many requests", + wantMin: 30 * time.Second, + wantMax: 30 * time.Second, + }, + { + name: "no-retry-after-header", + statusCode: 429, + body: "slow down", + retryAfter: "", + wantMsg: "slow down", + wantMin: 5 * time.Second, + wantMax: 10 * time.Second, + }, + { + name: "unparseable-retry-after", + statusCode: 429, + body: "rate limited", + retryAfter: "not-a-number", + wantMsg: "rate limited", + wantMin: 5 * time.Second, + wantMax: 10 * time.Second, + }, + { + name: "empty-body", + statusCode: 429, + body: "", + retryAfter: "5", + wantMsg: "", + wantMin: 5 * time.Second, + wantMax: 5 * time.Second, + }, + { + name: "body-with-whitespace", + statusCode: 429, + body: " too many requests \n", + retryAfter: "10", + wantMsg: "too many requests", + wantMin: 10 * time.Second, + wantMax: 10 * time.Second, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + rec := httptest.NewRecorder() + if tt.retryAfter != "" { + rec.Header().Set("Retry-After", tt.retryAfter) + } + rec.WriteHeader(tt.statusCode) + rec.Body.WriteString(tt.body) + res := rec.Result() + + err := parseRateLimitError(res) + if err == nil { + t.Fatal("expected non-nil error") + } + + var rle *rateLimitError + if !errors.As(err, &rle) { + t.Fatalf("error is not a *rateLimitError: %T", err) + } + if rle.msg != tt.wantMsg { + t.Errorf("msg = %q, want %q", rle.msg, tt.wantMsg) + } + if rle.retryAfter < tt.wantMin || rle.retryAfter > tt.wantMax { + t.Errorf("retryAfter = %v, want between %v and %v", rle.retryAfter, tt.wantMin, tt.wantMax) + } + + // Verify the Error() string contains useful information. + errStr := err.Error() + if !strings.Contains(errStr, "rate limited") { + t.Errorf("Error() = %q, want it to contain 'rate limited'", errStr) + } + }) + } +} + +func TestRateLimitErrorIsError(t *testing.T) { + err := &rateLimitError{msg: "test", retryAfter: 5 * time.Second} + var target *rateLimitError + if !errors.As(err, &target) { + t.Fatal("errors.As should match *rateLimitError") + } + if target.retryAfter != 5*time.Second { + t.Errorf("retryAfter = %v, want 5s", target.retryAfter) + } +} + func TestTsmpPing(t *testing.T) { hi := hostinfo.New() ni := tailcfg.NetInfo{LinkType: "wired"} diff --git a/control/controlclient/map.go b/control/controlclient/map.go index c8cbdbce5..34b5ecc55 100644 --- a/control/controlclient/map.go +++ b/control/controlclient/map.go @@ -298,7 +298,12 @@ func (ms *mapSession) handleNonKeepAliveMapResponse(ctx context.Context, resp *t ms.patchifyPeersChanged(resp) - ms.removeUnwantedDiscoUpdates(resp) + ms.removeUnwantedDiscoUpdates(resp, viaTSMP) + + // TSMP learned key was rejected, no need to do any more work in the engine. + if viaTSMP && len(resp.PeersChangedPatch) == 0 { + return nil + } ms.removeUnwantedDiscoUpdatesFromFullNetmapUpdate(resp) ms.updateStateFromResponse(resp) @@ -407,7 +412,7 @@ type updateStats struct { // removeUnwantedDiscoUpdates goes over the patchified updates and reject items // where the node is offline and has last been seen before the recorded last seen. -func (ms *mapSession) removeUnwantedDiscoUpdates(resp *tailcfg.MapResponse) { +func (ms *mapSession) removeUnwantedDiscoUpdates(resp *tailcfg.MapResponse, viaTSMP bool) { ms.peersMu.RLock() defer ms.peersMu.RUnlock() @@ -422,6 +427,30 @@ func (ms *mapSession) removeUnwantedDiscoUpdates(resp *tailcfg.MapResponse) { continue } + existingNode, ok := ms.peers[change.NodeID] + // Accept if: + // - Cannot find the peer, don't have enough data. + if !ok { + acceptedDiscoUpdates = append(acceptedDiscoUpdates, change) + continue + } + + // Reject if: + // - key was learned via tsmp AND, + // - existing node is online AND, + // - key did not change. + // Here to avoid a deeper reconfig in the case where we get a TSMP key + // exchange while that node is already in a connected state (from the view + // of the control plane). This is meant to keep the node stable, avoiding a + // reconfiguration of the node deeper down in the engine. + // With this, we are avoiding updating the LastSeen and Online fields from + // TSMP updates when that is not relevant, overall making the connection + // state change less, and updating the engine less. + if viaTSMP && existingNode.Online().Get() && + *change.DiscoKey == existingNode.DiscoKey() { + continue + } + // Accept if: // - Node is online. if *change.Online { @@ -429,17 +458,10 @@ func (ms *mapSession) removeUnwantedDiscoUpdates(resp *tailcfg.MapResponse) { continue } - existingNode, ok := ms.peers[change.NodeID] // Accept if: - // - Cannot find the peer, don't have enough data - if !ok { - acceptedDiscoUpdates = append(acceptedDiscoUpdates, change) - continue - } - - // Accept if: - // - lastSeen moved forward in time. - if existingLastSeen, ok := existingNode.LastSeen().GetOk(); ok && + // - if we don't have a last seen to compare against on the existing node. + // - OR lastSeen moved forward in time. + if existingLastSeen, ok := existingNode.LastSeen().GetOk(); !ok || change.LastSeen.After(existingLastSeen) { acceptedDiscoUpdates = append(acceptedDiscoUpdates, change) } @@ -497,8 +519,13 @@ func (ms *mapSession) removeUnwantedDiscoUpdatesFromFullNetmapUpdate(resp *tailc continue } - // Overwrite the key in the full netmap update. + // Overwrite the key and last seen in the full netmap update. peer.DiscoKey = existingNode.DiscoKey() + if t, ok := existingNode.LastSeen().GetOk(); ok { + peer.LastSeen = new(t) + } else { + peer.LastSeen = nil + } } } @@ -812,13 +839,22 @@ func (ms *mapSession) addUserProfile(nm *netmap.NetworkMap, userID tailcfg.UserI } var debugPatchifyPeer = envknob.RegisterBool("TS_DEBUG_PATCHIFY_PEER") +var debugPatchifyPeerMiss = envknob.RegisterBool("TS_DEBUG_PATCHIFY_PEER_MISS") + +// patchifyMissOnFalse, if non-nil, is called with the field name when +// patchifyPeer fails. It is set by an init func in map_debug.go. +var patchifyMissOnFalse func(string) // patchifyPeersChanged mutates resp to promote PeersChanged entries to PeersChangedPatch // when possible. func (ms *mapSession) patchifyPeersChanged(resp *tailcfg.MapResponse) { + var onFalse func(string) + if debugPatchifyPeerMiss() { + onFalse = patchifyMissOnFalse + } filtered := resp.PeersChanged[:0] for _, n := range resp.PeersChanged { - if p, ok := ms.patchifyPeer(n); ok { + if p, ok := ms.patchifyPeer(n, onFalse); ok { patchifiedPeer.Add(1) if debugPatchifyPeer() { patchj, _ := json.Marshal(p) @@ -856,21 +892,27 @@ func getNodeFields() []string { // // It returns ok=false if a patch can't be made, (V, ok) on a delta, or (nil, // true) if all the fields were identical (a zero change). -func (ms *mapSession) patchifyPeer(n *tailcfg.Node) (_ *tailcfg.PeerChange, ok bool) { +func (ms *mapSession) patchifyPeer(n *tailcfg.Node, onFalse func(string)) (_ *tailcfg.PeerChange, ok bool) { ms.peersMu.RLock() defer ms.peersMu.RUnlock() was, ok := ms.peers[n.ID] if !ok { + if onFalse != nil { + onFalse("peer_not_found") + } return nil, false } - return peerChangeDiff(was, n) + return peerChangeDiff(was, n, onFalse) } // peerChangeDiff returns the difference from 'was' to 'n', if possible. // // It returns (nil, true) if the fields were identical. -func peerChangeDiff(was tailcfg.NodeView, n *tailcfg.Node) (_ *tailcfg.PeerChange, ok bool) { +func peerChangeDiff(was tailcfg.NodeView, n *tailcfg.Node, onFalse func(string)) (_ *tailcfg.PeerChange, ok bool) { + if onFalse == nil { + onFalse = func(string) {} + } var ret *tailcfg.PeerChange pc := func() *tailcfg.PeerChange { if ret == nil { @@ -894,22 +936,27 @@ func peerChangeDiff(was tailcfg.NodeView, n *tailcfg.Node) (_ *tailcfg.PeerChang // And it was never sent by any known control server. case "ID": if was.ID() != n.ID { + onFalse(field) return nil, false } case "StableID": if was.StableID() != n.StableID { + onFalse(field) return nil, false } case "Name": if was.Name() != n.Name { + onFalse(field) return nil, false } case "User": if was.User() != n.User { + onFalse(field) return nil, false } case "Sharer": if was.Sharer() != n.Sharer { + onFalse(field) return nil, false } case "Key": @@ -926,6 +973,7 @@ func peerChangeDiff(was tailcfg.NodeView, n *tailcfg.Node) (_ *tailcfg.PeerChang } case "Machine": if was.Machine() != n.Machine { + onFalse(field) return nil, false } case "DiscoKey": @@ -934,10 +982,12 @@ func peerChangeDiff(was tailcfg.NodeView, n *tailcfg.Node) (_ *tailcfg.PeerChang } case "Addresses": if !views.SliceEqual(was.Addresses(), views.SliceOf(n.Addresses)) { + onFalse(field) return nil, false } case "AllowedIPs": if !views.SliceEqual(was.AllowedIPs(), views.SliceOf(n.AllowedIPs)) { + onFalse(field) return nil, false } case "Endpoints": @@ -957,13 +1007,16 @@ func peerChangeDiff(was tailcfg.NodeView, n *tailcfg.Node) (_ *tailcfg.PeerChang continue } if !was.Hostinfo().Valid() || !n.Hostinfo.Valid() { + onFalse(field) return nil, false } if !was.Hostinfo().Equal(n.Hostinfo) { + onFalse(field) return nil, false } case "Created": if !was.Created().Equal(n.Created) { + onFalse(field) return nil, false } case "Cap": @@ -991,10 +1044,12 @@ func peerChangeDiff(was tailcfg.NodeView, n *tailcfg.Node) (_ *tailcfg.PeerChang } case "Tags": if !views.SliceEqual(was.Tags(), views.SliceOf(n.Tags)) { + onFalse(field) return nil, false } case "PrimaryRoutes": if !views.SliceEqual(was.PrimaryRoutes(), views.SliceOf(n.PrimaryRoutes)) { + onFalse(field) return nil, false } case "Online": @@ -1007,22 +1062,27 @@ func peerChangeDiff(was tailcfg.NodeView, n *tailcfg.Node) (_ *tailcfg.PeerChang } case "MachineAuthorized": if was.MachineAuthorized() != n.MachineAuthorized { + onFalse(field) return nil, false } case "UnsignedPeerAPIOnly": if was.UnsignedPeerAPIOnly() != n.UnsignedPeerAPIOnly { + onFalse(field) return nil, false } case "IsWireGuardOnly": if was.IsWireGuardOnly() != n.IsWireGuardOnly { + onFalse(field) return nil, false } case "IsJailed": if was.IsJailed() != n.IsJailed { + onFalse(field) return nil, false } case "Expired": if was.Expired() != n.Expired { + onFalse(field) return nil, false } case "SelfNodeV4MasqAddrForThisPeer": @@ -1031,6 +1091,7 @@ func peerChangeDiff(was tailcfg.NodeView, n *tailcfg.Node) (_ *tailcfg.PeerChang continue } if va, ok := va.GetOk(); !ok || vb == nil || va != *vb { + onFalse(field) return nil, false } case "SelfNodeV6MasqAddrForThisPeer": @@ -1039,17 +1100,20 @@ func peerChangeDiff(was tailcfg.NodeView, n *tailcfg.Node) (_ *tailcfg.PeerChang continue } if va, ok := va.GetOk(); !ok || vb == nil || va != *vb { + onFalse(field) return nil, false } case "ExitNodeDNSResolvers": va, vb := was.ExitNodeDNSResolvers(), views.SliceOfViews(n.ExitNodeDNSResolvers) if va.Len() != vb.Len() { + onFalse(field) return nil, false } for i := range va.Len() { if !va.At(i).Equal(vb.At(i)) { + onFalse(field) return nil, false } } diff --git a/control/controlclient/map_debug.go b/control/controlclient/map_debug.go new file mode 100644 index 000000000..2d6012211 --- /dev/null +++ b/control/controlclient/map_debug.go @@ -0,0 +1,16 @@ +// Copyright (c) Tailscale Inc & contributors +// SPDX-License-Identifier: BSD-3-Clause + +//go:build !ts_omit_debug + +package controlclient + +import "tailscale.com/metrics" + +var patchifyMissStats = metrics.NewLabelMap("counter_patchify_miss", "why") + +func init() { + patchifyMissOnFalse = func(field string) { + patchifyMissStats.Add(field, 1) + } +} diff --git a/control/controlclient/map_test.go b/control/controlclient/map_test.go index f057345d9..be28685b3 100644 --- a/control/controlclient/map_test.go +++ b/control/controlclient/map_test.go @@ -14,6 +14,7 @@ import ( "strings" "sync/atomic" "testing" + "testing/synctest" "time" "github.com/google/go-cmp/cmp" @@ -630,81 +631,160 @@ func TestUpdateDiscoForNode(t *testing.T) { name string initialOnline bool initialLastSeen time.Time + updateDiscoKey bool updateOnline bool updateLastSeen time.Time wantUpdate bool + wantKeyChanged bool }{ { name: "newer_key_not_online", initialOnline: true, initialLastSeen: time.Unix(1, 0), + updateDiscoKey: true, updateOnline: false, updateLastSeen: time.Now(), wantUpdate: true, + wantKeyChanged: true, }, { name: "newer_key_online", initialOnline: true, initialLastSeen: time.Unix(1, 0), + updateDiscoKey: true, updateOnline: true, updateLastSeen: time.Now(), wantUpdate: true, + wantKeyChanged: true, }, { name: "older_key_not_online", initialOnline: false, initialLastSeen: time.Now(), + updateDiscoKey: true, updateOnline: false, updateLastSeen: time.Unix(1, 0), wantUpdate: false, + wantKeyChanged: false, }, { name: "older_key_online", initialOnline: false, initialLastSeen: time.Now(), + updateDiscoKey: true, updateOnline: true, updateLastSeen: time.Unix(1, 0), wantUpdate: true, + wantKeyChanged: true, + }, + { + name: "same_newer_key_not_online", + initialOnline: true, + initialLastSeen: time.Unix(1, 0), + updateDiscoKey: false, + updateOnline: false, + updateLastSeen: time.Now(), + wantUpdate: false, + wantKeyChanged: false, + }, + { + name: "same_newer_key_online", + initialOnline: true, + initialLastSeen: time.Unix(1, 0), + updateDiscoKey: false, + updateOnline: true, + updateLastSeen: time.Now(), + wantUpdate: false, + wantKeyChanged: false, + }, + { + name: "same_older_key_not_online", + initialOnline: false, + initialLastSeen: time.Now(), + updateDiscoKey: false, + updateOnline: false, + updateLastSeen: time.Unix(1, 0), + wantUpdate: false, + wantKeyChanged: false, + }, + { + name: "same_older_key_online", + initialOnline: false, + initialLastSeen: time.Now(), + updateDiscoKey: false, + updateOnline: true, + updateLastSeen: time.Unix(1, 0), + wantUpdate: true, + wantKeyChanged: false, + }, + { + name: "no_initial_last_seen", + initialOnline: false, + updateDiscoKey: true, + updateOnline: false, + updateLastSeen: time.Now(), + wantUpdate: true, + wantKeyChanged: true, }, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - nu := &rememberLastNetmapUpdater{ - done: make(chan any, 1), - } - ms := newTestMapSession(t, nu) + synctest.Test(t, func(*testing.T) { + nu := &rememberLastNetmapUpdater{ + done: make(chan any, 1), + } + ms := newTestMapSession(t, nu) + defer ms.Close() - oldKey := key.NewDisco() + oldKey := key.NewDisco() - // Insert existing node - node := tailcfg.Node{ - ID: 1, - Key: key.NewNode().Public(), - DiscoKey: oldKey.Public(), - Online: &tt.initialOnline, - LastSeen: &tt.initialLastSeen, - } + // Insert existing node + node := tailcfg.Node{ + ID: 1, + Key: key.NewNode().Public(), + DiscoKey: oldKey.Public(), + Online: &tt.initialOnline, + } + if !tt.initialLastSeen.IsZero() { + node.LastSeen = &tt.initialLastSeen + } - if nm := ms.netmapForResponse(&tailcfg.MapResponse{ - Peers: []*tailcfg.Node{&node}, - }); len(nm.Peers) != 1 { - t.Fatalf("node not inserted") - } + if nm := ms.netmapForResponse(&tailcfg.MapResponse{ + Peers: []*tailcfg.Node{&node}, + }); len(nm.Peers) != 1 { + t.Fatalf("node not inserted") + } - newKey := key.NewDisco() - ms.updateDiscoForNode(node.ID, node.Key, newKey.Public(), tt.updateLastSeen, tt.updateOnline) - <-nu.done + newKey := oldKey.Public() + if tt.updateDiscoKey { + newKey = key.NewDisco().Public() + } + ms.updateDiscoForNode(node.ID, node.Key, newKey, tt.updateLastSeen, tt.updateOnline) - peer, ok := ms.peers[node.ID] - if !ok { - t.Fatal("node not found") - } + // We have an early escape that would not trigger the netmap updater. + synctest.Wait() + select { + case <-nu.done: + if !tt.wantUpdate { + t.Errorf("did not expect update, got: %v", nu.last) + } + default: + if tt.wantUpdate { + t.Errorf("expected update, did not get any") + } + } - updated := peer.DiscoKey().Compare(newKey.Public()) == 0 - if updated != tt.wantUpdate { - t.Fatalf("Disco key update: %t, wanted update: %t", updated, tt.wantUpdate) - } + peer, ok := ms.peers[node.ID] + if !ok { + t.Fatal("node not found") + } + + keyChanged := peer.DiscoKey().Compare(oldKey.Public()) != 0 + if keyChanged != tt.wantKeyChanged { + t.Errorf("Disco key update: %t, wanted update: %t", keyChanged, tt.wantKeyChanged) + } + }) }) } } @@ -831,6 +911,14 @@ func TestUpdateDiscoForNodeCallbackWithFullNetmap(t *testing.T) { updateLastSeen: now, expectNewDisco: true, }, + { + name: "local-lastseen-preserved-after-first-reconnect", + initialOnline: false, + initialLastSeen: now, + updateOnline: false, + updateLastSeen: now, + expectNewDisco: false, + }, } for _, tt := range tests { @@ -1115,17 +1203,20 @@ func TestPeerChangeDiff(t *testing.T) { a: &tailcfg.Node{ID: 1, CapMap: tailcfg.NodeCapMap{tailcfg.CapabilityAdmin: nil}}, b: &tailcfg.Node{ID: 1, CapMap: tailcfg.NodeCapMap{tailcfg.CapabilityAdmin: nil, tailcfg.CapabilityDebug: nil}}, want: &tailcfg.PeerChange{NodeID: 1, CapMap: tailcfg.NodeCapMap{tailcfg.CapabilityAdmin: nil, tailcfg.CapabilityDebug: nil}}, - }, { + }, + { name: "patch-capmap-remove-key", a: &tailcfg.Node{ID: 1, CapMap: tailcfg.NodeCapMap{tailcfg.CapabilityAdmin: nil}}, b: &tailcfg.Node{ID: 1, CapMap: tailcfg.NodeCapMap{}}, want: &tailcfg.PeerChange{NodeID: 1, CapMap: tailcfg.NodeCapMap{}}, - }, { + }, + { name: "patch-capmap-remove-as-nil", a: &tailcfg.Node{ID: 1, CapMap: tailcfg.NodeCapMap{tailcfg.CapabilityAdmin: nil}}, b: &tailcfg.Node{ID: 1}, want: &tailcfg.PeerChange{NodeID: 1, CapMap: tailcfg.NodeCapMap{}}, - }, { + }, + { name: "patch-capmap-add-key-to-empty-map", a: &tailcfg.Node{ID: 1}, b: &tailcfg.Node{ID: 1, CapMap: tailcfg.NodeCapMap{tailcfg.CapabilityAdmin: nil}}, @@ -1140,7 +1231,7 @@ func TestPeerChangeDiff(t *testing.T) { } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - pc, ok := peerChangeDiff(tt.a.View(), tt.b) + pc, ok := peerChangeDiff(tt.a.View(), tt.b, nil) if tt.wantEqual { if !ok || pc != nil { t.Errorf("got (%p, %v); want (nil, true); pc=%v", pc, ok, logger.AsJSON(pc)) @@ -1161,7 +1252,7 @@ func TestPeerChangeDiffAllocs(t *testing.T) { a := &tailcfg.Node{ID: 1} b := &tailcfg.Node{ID: 1} n := testing.AllocsPerRun(10000, func() { - diff, ok := peerChangeDiff(a.View(), b) + diff, ok := peerChangeDiff(a.View(), b, nil) if !ok || diff != nil { t.Fatalf("unexpected result: (%s, %v)", logger.AsJSON(diff), ok) } @@ -1366,7 +1457,6 @@ func TestUpgradeNode(t *testing.T) { } }) } - } func BenchmarkMapSessionDelta(b *testing.B) { @@ -1774,11 +1864,6 @@ func TestPathDiscokeyerImplementations(t *testing.T) { if _, ok := e.(patchDiscoKeyer); !ok { t.Error("wgengine.userspaceEngine must implement patchDiscoKeyer") } - - wd := wgengine.NewWatchdog(e) - if _, ok := wd.(patchDiscoKeyer); !ok { - t.Error("wgengine.watchdogEngine must implement patchDiscoKeyer") - } } func TestPeerIDAndKeyByTailscaleIP(t *testing.T) { @@ -1838,3 +1923,101 @@ func TestPeerIDAndKeyByTailscaleIP(t *testing.T) { } }) } + +func TestRemoveUnwantedDiscoUpdates(t *testing.T) { + tests := []struct { + name string + viaTSMP bool + existingOnline bool + sameKey bool + newerLastSeen bool + wantAccepted bool + }{ + { + name: "tsmp_online_peer_same_key", + viaTSMP: true, + existingOnline: true, + sameKey: true, + newerLastSeen: true, + wantAccepted: false, + }, + { + name: "not_tsmp_online_peer_same_key", + viaTSMP: false, + existingOnline: true, + sameKey: true, + newerLastSeen: true, + wantAccepted: true, + }, + { + name: "tsmp_offline_peer_same_key", + viaTSMP: true, + existingOnline: false, + sameKey: true, + newerLastSeen: true, + wantAccepted: true, + }, + { + name: "tsmp_online_peer_diff_key", + viaTSMP: true, + existingOnline: true, + sameKey: false, + newerLastSeen: true, + wantAccepted: true, + }, + { + name: "tsmp_online_peer_same_key_old_lastseen", + viaTSMP: true, + existingOnline: true, + sameKey: true, + newerLastSeen: false, + wantAccepted: false, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + ms := newTestMapSession(t, &rememberLastNetmapUpdater{done: make(chan any, 1)}) + + existingKey := key.NewDisco().Public() + existingOnline := tt.existingOnline + initialLastSeen := time.Unix(1, 0) + + ms.updateStateFromResponse(&tailcfg.MapResponse{ + Peers: []*tailcfg.Node{{ + ID: 1, + Key: key.NewNode().Public(), + DiscoKey: existingKey, + Online: &existingOnline, + LastSeen: &initialLastSeen, + }}, + }) + + changeKey := existingKey + if !tt.sameKey { + changeKey = key.NewDisco().Public() + } + changeOnline := false // must be false to reach the new guard + updateLastSeen := time.Unix(2, 0) + if !tt.newerLastSeen { + updateLastSeen = time.Unix(0, 0) + } + + resp := &tailcfg.MapResponse{ + PeersChangedPatch: []*tailcfg.PeerChange{{ + NodeID: 1, + DiscoKey: &changeKey, + Online: &changeOnline, + LastSeen: &updateLastSeen, + }}, + } + + ms.removeUnwantedDiscoUpdates(resp, tt.viaTSMP) + + got := len(resp.PeersChangedPatch) > 0 + if got != tt.wantAccepted { + t.Errorf("accepted=%v, want %v", got, tt.wantAccepted) + } + }) + } +} diff --git a/control/controlclient/sign_supported.go b/control/controlclient/sign_supported.go index ea6fa28e3..f3340d5a6 100644 --- a/control/controlclient/sign_supported.go +++ b/control/controlclient/sign_supported.go @@ -1,9 +1,7 @@ // Copyright (c) Tailscale Inc & contributors // SPDX-License-Identifier: BSD-3-Clause -//go:build windows - -// darwin,cgo is also supported by certstore but untested, so it is not enabled. +//go:build windows || (darwin && !ios && cgo) package controlclient diff --git a/control/controlclient/sign_unsupported.go b/control/controlclient/sign_unsupported.go index ff830282e..a371cbaf1 100644 --- a/control/controlclient/sign_unsupported.go +++ b/control/controlclient/sign_unsupported.go @@ -1,7 +1,7 @@ // Copyright (c) Tailscale Inc & contributors // SPDX-License-Identifier: BSD-3-Clause -//go:build !windows +//go:build (!windows && !(darwin && cgo)) || ios package controlclient diff --git a/control/controlhttp/client_js.go b/control/controlhttp/client_js.go index a3ce7ffe5..e9a4e7dbb 100644 --- a/control/controlhttp/client_js.go +++ b/control/controlhttp/client_js.go @@ -32,7 +32,9 @@ func (d *Dialer) Dial(ctx context.Context) (*ClientConn, error) { host := d.Hostname // If using a custom control server (on a non-standard port), prefer that. // This mirrors the port selection in newNoiseClient from noise.go. - if d.HTTPPort != "" && d.HTTPPort != "80" && d.HTTPSPort == "443" { + // Also use ws:// when HTTPS is explicitly disabled (NoPort), which happens + // for http:// URLs with private hostnames (e.g. http://localhost:31544). + if d.HTTPPort != "" && d.HTTPPort != "80" && (d.HTTPSPort == "443" || d.HTTPSPort == NoPort) { wsScheme = "ws" host = net.JoinHostPort(host, d.HTTPPort) } diff --git a/control/controlknobs/controlknobs.go b/control/controlknobs/controlknobs.go index 14f30d9ce..77a496349 100644 --- a/control/controlknobs/controlknobs.go +++ b/control/controlknobs/controlknobs.go @@ -21,11 +21,6 @@ type Knobs struct { // DisableUPnP indicates whether to attempt UPnP mapping. DisableUPnP atomic.Bool - // KeepFullWGConfig is whether we should disable the lazy wireguard - // programming and instead give WireGuard the full netmap always, even for - // idle peers. - KeepFullWGConfig atomic.Bool - // RandomizeClientPort is whether control says we should randomize // the client port. RandomizeClientPort atomic.Bool @@ -62,12 +57,6 @@ type Knobs struct { // netfiltering, unless overridden by the user. LinuxForceNfTables atomic.Bool - // SeamlessKeyRenewal is whether to renew node keys without breaking connections. - // This is enabled by default in 1.90 and later, but we but we can remotely disable - // it from the control plane if there's a problem. - // http://go/seamless-key-renewal - SeamlessKeyRenewal atomic.Bool - // ProbeUDPLifetime is whether the node should probe UDP path lifetime on // the tail end of an active direct connection in magicsock. ProbeUDPLifetime atomic.Bool @@ -131,7 +120,6 @@ func (k *Knobs) UpdateFromNodeAttributes(capMap tailcfg.NodeCapMap) { } has := capMap.Contains var ( - keepFullWG = has(tailcfg.NodeAttrDebugDisableWGTrim) disableUPnP = has(tailcfg.NodeAttrDisableUPnP) randomizeClientPort = has(tailcfg.NodeAttrRandomizeClientPort) disableDeltaUpdates = has(tailcfg.NodeAttrDisableDeltaUpdates) @@ -142,8 +130,6 @@ func (k *Knobs) UpdateFromNodeAttributes(capMap tailcfg.NodeCapMap) { silentDisco = has(tailcfg.NodeAttrSilentDisco) forceIPTables = has(tailcfg.NodeAttrLinuxMustUseIPTables) forceNfTables = has(tailcfg.NodeAttrLinuxMustUseNfTables) - seamlessKeyRenewal = has(tailcfg.NodeAttrSeamlessKeyRenewal) - disableSeamlessKeyRenewal = has(tailcfg.NodeAttrDisableSeamlessKeyRenewal) probeUDPLifetime = has(tailcfg.NodeAttrProbeUDPLifetime) appCStoreRoutes = has(tailcfg.NodeAttrStoreAppCRoutes) userDialUseRoutes = has(tailcfg.NodeAttrUserDialUseRoutes) @@ -161,7 +147,6 @@ func (k *Knobs) UpdateFromNodeAttributes(capMap tailcfg.NodeCapMap) { oneCGNAT.Set(false) } - k.KeepFullWGConfig.Store(keepFullWG) k.DisableUPnP.Store(disableUPnP) k.RandomizeClientPort.Store(randomizeClientPort) k.OneCGNAT.Store(oneCGNAT) @@ -181,21 +166,6 @@ func (k *Knobs) UpdateFromNodeAttributes(capMap tailcfg.NodeCapMap) { k.DisableSkipStatusQueue.Store(disableSkipStatusQueue) k.DisableHostsFileUpdates.Store(disableHostsFileUpdates) k.ForceRegisterMagicDNSIPv4Only.Store(forceRegisterMagicDNSIPv4Only) - - // If both attributes are present, then "enable" should win. This reflects - // the history of seamless key renewal. - // - // Before 1.90, seamless was a private alpha, opt-in feature. Devices would - // only seamless do if customers opted in using the seamless renewal attr. - // - // In 1.90 and later, seamless is the default behaviour, and devices will use - // seamless unless explicitly told not to by control (e.g. if we discover - // a bug and want clients to use the prior behaviour). - // - // If a customer has opted in to the pre-1.90 seamless implementation, we - // don't want to switch it off for them -- we only want to switch it off for - // devices that haven't opted in. - k.SeamlessKeyRenewal.Store(seamlessKeyRenewal || !disableSeamlessKeyRenewal) } // AsDebugJSON returns k as something that can be marshalled with json.Marshal diff --git a/control/tsp/map.go b/control/tsp/map.go new file mode 100644 index 000000000..961c5dd57 --- /dev/null +++ b/control/tsp/map.go @@ -0,0 +1,415 @@ +// Copyright (c) Tailscale Inc & contributors +// SPDX-License-Identifier: BSD-3-Clause + +package tsp + +import ( + "bytes" + "cmp" + "context" + "encoding/binary" + "encoding/json" + "errors" + "fmt" + "io" + "net/http" + "strings" + "sync" + "sync/atomic" + + "github.com/klauspost/compress/zstd" + "tailscale.com/control/ts2021" + "tailscale.com/tailcfg" + "tailscale.com/types/key" +) + +// errSessionClosed is returned by [MapSession.Next] and +// [MapSession.NextInto] when called after [MapSession.Close]. +var errSessionClosed = errors.New("tsp: map session closed") + +// DefaultMaxMessageSize is the default cap, in bytes, on the size of a +// single compressed map response frame. See [MapOpts.MaxMessageSize]. +const DefaultMaxMessageSize = 4 << 20 + +// zstdDecoderPool is a pool of *zstd.Decoder reused across MapSessions to +// amortize the cost of setting up zstd state. Decoders are returned via +// [MapSession.Close]; entries are reclaimed by the runtime under memory +// pressure via sync.Pool semantics. +var zstdDecoderPool sync.Pool // of *zstd.Decoder + +// MapOpts contains options for sending a map request. +type MapOpts struct { + // NodeKey is the node's private key. Required. + NodeKey key.NodePrivate + + // Hostinfo is the host information to send. Optional; + // if nil, a minimal default is used. + Hostinfo *tailcfg.Hostinfo + + // Stream is whether to receive multiple MapResponses over + // the same HTTP connection. + Stream bool + + // OmitPeers is whether the client is okay with the Peers list + // being omitted in the response. + OmitPeers bool + + // MaxMessageSize is the maximum size in bytes of any single + // compressed map response frame on the wire. If zero, + // [DefaultMaxMessageSize] is used. + MaxMessageSize int64 +} + +// framedReader is an io.Reader that consumes a stream of length-prefixed +// frames (each a little-endian uint32 length followed by that many bytes) +// from r and yields only the frame payloads back-to-back. +// +// This lets us feed the concatenated zstd frames from our wire protocol +// into a single streaming zstd decoder. Zstd's file format permits +// concatenation (RFC 8478 §2), and klauspost's decoder handles it +// transparently. +// +// If onNewFrame is non-nil, it is called after each new 4-byte length +// header is successfully read. Used to reset the per-message decoded-size +// budget downstream. +type framedReader struct { + r io.Reader + maxSize int64 // per-frame compressed-size cap + remain int // bytes remaining in the current frame + onNewFrame func() +} + +func (f *framedReader) Read(p []byte) (int, error) { + if f.remain == 0 { + var hdr [4]byte + if _, err := io.ReadFull(f.r, hdr[:]); err != nil { + return 0, err + } + sz := int64(binary.LittleEndian.Uint32(hdr[:])) + if sz == 0 { + return 0, fmt.Errorf("map response: zero-length frame") + } + if sz > f.maxSize { + return 0, fmt.Errorf("map response frame size %d exceeds max %d", sz, f.maxSize) + } + f.remain = int(sz) + if f.onNewFrame != nil { + f.onNewFrame() + } + } + if len(p) > f.remain { + p = p[:f.remain] + } + n, err := f.r.Read(p) + f.remain -= n + return n, err +} + +// boundedReader is an io.Reader that yields at most remain bytes from r +// before returning an error. Call reset to raise the budget back to max, +// typically at a new message boundary. +// +// Used to cap the decoded size of a single map response so a malicious +// server can't send a small zstd frame that explodes into gigabytes of +// junk for the json.Decoder to consume. +type boundedReader struct { + r io.Reader + max int64 + remain int64 +} + +func (b *boundedReader) Read(p []byte) (int, error) { + if b.remain <= 0 { + return 0, fmt.Errorf("map response decoded size exceeds max %d", b.max) + } + if int64(len(p)) > b.remain { + p = p[:b.remain] + } + n, err := b.r.Read(p) + b.remain -= int64(n) + return n, err +} + +func (b *boundedReader) reset() { b.remain = b.max } + +// MapSession wraps an in-progress map response stream. Call Next to read +// each MapResponse. Call Close when done. +type MapSession struct { + res *http.Response + stream bool + noiseDoer func(*http.Request) (*http.Response, error) + + // inNext detects concurrent NextInto callers. It CAS-flips + // false→true on entry and back to false on exit; a failed CAS + // panics, akin to how the Go runtime detects concurrent map + // access. It does not serialize Close vs. NextInto; that's + // nextMu's job. + inNext atomic.Bool + + // nextMu is held while [MapSession.NextInto] is running jdec.Decode, + // so that Close can wait for an in-flight Decode to unwind before it + // touches zdec (Reset, pool-Put) and avoids racing with the running + // Read chain that Decode drives. + nextMu sync.Mutex + read int // guarded by nextMu + closed bool // guarded by nextMu + zdec *zstd.Decoder // reads from a framedReader wrapping res.Body + jdec *json.Decoder // reads decompressed JSON from zdec + + closeOnce sync.Once + closeErr error +} + +// NoiseRoundTrip sends an HTTP request over the Noise channel used by this map session. +func (s *MapSession) NoiseRoundTrip(req *http.Request) (*http.Response, error) { + return s.noiseDoer(req) +} + +// Next reads and returns the next MapResponse from the stream. +// For non-streaming sessions, the first call returns the single response +// and subsequent calls return io.EOF. +// For streaming sessions, Next blocks until the next response arrives +// or the server closes the connection. +// +// Each call allocates a fresh MapResponse. Callers that want to amortize +// the allocation across calls can use [MapSession.NextInto]. +// +// Next and NextInto are not safe to call concurrently from multiple +// goroutines on the same [MapSession]; a concurrent call panics, akin +// to the Go runtime's concurrent map access detection. [MapSession.Close] +// may be called concurrently to abort an in-flight Next. +func (s *MapSession) Next() (*tailcfg.MapResponse, error) { + var resp tailcfg.MapResponse + if err := s.NextInto(&resp); err != nil { + return nil, err + } + return &resp, nil +} + +// NextInto is like [MapSession.Next] but decodes the next MapResponse into +// the caller-supplied *resp rather than allocating a new one. The pointer's +// pointee is zeroed before decoding so fields from a prior response do not +// persist. +// +// For non-streaming sessions, the first call decodes the single response +// and subsequent calls return io.EOF. +// For streaming sessions, NextInto blocks until the next response arrives +// or the server closes the connection. +// +// See [MapSession.Next] for concurrency rules; those apply to NextInto too. +func (s *MapSession) NextInto(resp *tailcfg.MapResponse) error { + if !s.inNext.CompareAndSwap(false, true) { + panic("tsp: invalid concurrent call to MapSession.Next/NextInto") + } + defer s.inNext.Store(false) + + s.nextMu.Lock() + defer s.nextMu.Unlock() + if s.closed { + return errSessionClosed + } + if !s.stream && s.read > 0 { + return io.EOF + } + *resp = tailcfg.MapResponse{} + if err := s.jdec.Decode(resp); err != nil { + return err + } + s.read++ + return nil +} + +// Close returns the session's zstd decoder to the pool and closes the +// underlying HTTP response body. It is safe to call Close multiple times +// and from multiple goroutines, including while a [MapSession.Next] or +// [MapSession.NextInto] call is in flight on another goroutine (which +// will return an error once the body close propagates). +func (s *MapSession) Close() error { + // Callers are likely to race a deferred Close with a time.AfterFunc + // timeout (or similar) Close that aborts a hung Next. Without the + // Once, both Closes would Put the same *zstd.Decoder into the pool, + // corrupting it, and the Reset/Put in one would race with the + // zdec.Read that the hung Next is driving. + // + // Ordering inside the Once: close the body first to unblock any + // in-flight NextInto (its Read chain ends at res.Body and will + // return an error once it's closed). That lets NextInto unwind and + // release nextMu. Only then do we take nextMu ourselves and touch + // zdec, which is safe because no goroutine is still reading from + // it. Acquiring nextMu before closing the body would deadlock + // against a hung NextInto. + s.closeOnce.Do(func() { + s.closeErr = s.res.Body.Close() + s.nextMu.Lock() + defer s.nextMu.Unlock() + s.closed = true + s.zdec.Reset(nil) + zstdDecoderPool.Put(s.zdec) + }) + return s.closeErr +} + +// SendMapUpdateOpts contains options for [Client.SendMapUpdate]. +type SendMapUpdateOpts struct { + // NodeKey is the node's private key. Required. + NodeKey key.NodePrivate + + // DiscoKey, if non-zero, is the node's disco public key. + // Peers use it to verify disco pings from this node, which is + // what enables direct (non-DERP) paths. + DiscoKey key.DiscoPublic + + // Hostinfo is the host information to send. Optional; + // if nil, a minimal default is used. + Hostinfo *tailcfg.Hostinfo +} + +// SendMapUpdate sends a one-shot, non-streaming MapRequest to push small +// updates (such as the node's endpoints, hostinfo, or disco public key) to the +// coordination server without starting or disturbing a streaming map session. +func (c *Client) SendMapUpdate(ctx context.Context, opts SendMapUpdateOpts) error { + if opts.NodeKey.IsZero() { + return fmt.Errorf("NodeKey is required") + } + + hi := opts.Hostinfo + if hi == nil { + hi = defaultHostinfo() + } + + mapReq := tailcfg.MapRequest{ + Version: tailcfg.CurrentCapabilityVersion, + NodeKey: opts.NodeKey.Public(), + DiscoKey: opts.DiscoKey, + Hostinfo: hi, + Compress: "zstd", + + // A lite update that lets the server persist our state without breaking + // any existing streaming map session. See the [tailcfg.MapResponse] + // OmitPeers docs. + OmitPeers: true, + Stream: false, + ReadOnly: false, + } + + body, err := json.Marshal(mapReq) + if err != nil { + return fmt.Errorf("encoding map request: %w", err) + } + + nc, err := c.noiseClient(ctx) + if err != nil { + return fmt.Errorf("establishing noise connection: %w", err) + } + + url := c.serverURL + "/machine/map" + url = strings.Replace(url, "http:", "https:", 1) + req, err := http.NewRequestWithContext(ctx, "POST", url, bytes.NewReader(body)) + if err != nil { + return fmt.Errorf("creating map request: %w", err) + } + ts2021.AddLBHeader(req, opts.NodeKey.Public()) + + res, err := nc.Do(req) + if err != nil { + return fmt.Errorf("map request: %w", err) + } + defer res.Body.Close() + + if res.StatusCode != 200 { + msg, _ := io.ReadAll(res.Body) + return fmt.Errorf("map request: http %d: %.200s", + res.StatusCode, strings.TrimSpace(string(msg))) + } + io.Copy(io.Discard, res.Body) + return nil +} + +// Map sends a map request to the coordination server and returns a MapSession +// for reading the framed, zstd-compressed response(s). +func (c *Client) Map(ctx context.Context, opts MapOpts) (*MapSession, error) { + if opts.NodeKey.IsZero() { + return nil, fmt.Errorf("NodeKey is required") + } + + hi := opts.Hostinfo + if hi == nil { + hi = defaultHostinfo() + } + + mapReq := tailcfg.MapRequest{ + Version: tailcfg.CurrentCapabilityVersion, + NodeKey: opts.NodeKey.Public(), + Hostinfo: hi, + Stream: opts.Stream, + Compress: "zstd", + OmitPeers: opts.OmitPeers, + // Streaming requires the server to track us as "connected", + // which in turn requires ReadOnly=false. Non-streaming polls + // stay ReadOnly to minimize side effects. + ReadOnly: !opts.Stream, + } + + body, err := json.Marshal(mapReq) + if err != nil { + return nil, fmt.Errorf("encoding map request: %w", err) + } + + nc, err := c.noiseClient(ctx) + if err != nil { + return nil, fmt.Errorf("establishing noise connection: %w", err) + } + + url := c.serverURL + "/machine/map" + url = strings.Replace(url, "http:", "https:", 1) + req, err := http.NewRequestWithContext(ctx, "POST", url, bytes.NewReader(body)) + if err != nil { + return nil, fmt.Errorf("creating map request: %w", err) + } + ts2021.AddLBHeader(req, opts.NodeKey.Public()) + + res, err := nc.Do(req) + if err != nil { + return nil, fmt.Errorf("map request: %w", err) + } + + if res.StatusCode != 200 { + msg, _ := io.ReadAll(res.Body) + res.Body.Close() + return nil, fmt.Errorf("map request: http %d: %.200s", + res.StatusCode, strings.TrimSpace(string(msg))) + } + + maxMessageSize := cmp.Or(opts.MaxMessageSize, DefaultMaxMessageSize) + bounded := &boundedReader{max: maxMessageSize, remain: maxMessageSize} + fr := &framedReader{ + r: res.Body, + maxSize: maxMessageSize, + onNewFrame: bounded.reset, + } + + zdec, _ := zstdDecoderPool.Get().(*zstd.Decoder) + if zdec != nil { + if err := zdec.Reset(fr); err != nil { + // Reset can fail if the previous stream is in a bad state; drop + // the decoder and create a fresh one. + zdec = nil + } + } + if zdec == nil { + zdec, err = zstd.NewReader(fr, zstd.WithDecoderConcurrency(1)) + if err != nil { + res.Body.Close() + return nil, fmt.Errorf("creating zstd decoder: %w", err) + } + } + bounded.r = zdec + + return &MapSession{ + res: res, + stream: opts.Stream, + noiseDoer: nc.Do, + zdec: zdec, + jdec: json.NewDecoder(bounded), + }, nil +} diff --git a/control/tsp/map_test.go b/control/tsp/map_test.go new file mode 100644 index 000000000..ddfde3971 --- /dev/null +++ b/control/tsp/map_test.go @@ -0,0 +1,409 @@ +// Copyright (c) Tailscale Inc & contributors +// SPDX-License-Identifier: BSD-3-Clause + +package tsp + +import ( + "bytes" + "context" + "encoding/binary" + "encoding/json" + "fmt" + "io" + "net/http/httptest" + "strings" + "testing" + "time" + + "github.com/klauspost/compress/zstd" + "tailscale.com/health" + "tailscale.com/tailcfg" + "tailscale.com/tstest/integration/testcontrol" + "tailscale.com/types/key" +) + +func TestMapAgainstTestControl(t *testing.T) { + ctrl := &testcontrol.Server{} + ctrl.HTTPTestServer = httptest.NewUnstartedServer(ctrl) + ctrl.HTTPTestServer.Start() + t.Cleanup(ctrl.HTTPTestServer.Close) + baseURL := ctrl.HTTPTestServer.URL + + ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second) + defer cancel() + + ht := new(health.Tracker) + + serverKey, err := DiscoverServerKey(ctx, baseURL) + if err != nil { + t.Fatalf("DiscoverServerKey: %v", err) + } + + register := func(hostname string) (nodeKey key.NodePrivate, machineKey key.MachinePrivate) { + t.Helper() + nodeKey = key.NewNode() + machineKey = key.NewMachine() + c, err := NewClient(ClientOpts{ + ServerURL: baseURL, + MachineKey: machineKey, + HealthTracker: ht, + }) + if err != nil { + t.Fatalf("NewClient %s: %v", hostname, err) + } + defer c.Close() + c.SetControlPublicKey(serverKey) + if _, err := c.Register(ctx, RegisterOpts{ + NodeKey: nodeKey, + Hostinfo: &tailcfg.Hostinfo{Hostname: hostname}, + }); err != nil { + t.Fatalf("Register %s: %v", hostname, err) + } + return nodeKey, machineKey + } + + nodeKeyA, machineKeyA := register("a") + nodeKeyB, _ := register("b") + + clientA, err := NewClient(ClientOpts{ + ServerURL: baseURL, + MachineKey: machineKeyA, + HealthTracker: ht, + }) + if err != nil { + t.Fatalf("NewClient A: %v", err) + } + defer clientA.Close() + clientA.SetControlPublicKey(serverKey) + + session, err := clientA.Map(ctx, MapOpts{ + NodeKey: nodeKeyA, + Hostinfo: &tailcfg.Hostinfo{Hostname: "a"}, + Stream: true, + }) + if err != nil { + t.Fatalf("Map: %v", err) + } + defer session.Close() + + // nextNonKeepalive returns the next non-keepalive MapResponse, to keep + // the test robust if a server-side keepalive arrives mid-test. + nextNonKeepalive := func() *tailcfg.MapResponse { + t.Helper() + for { + resp, err := session.Next() + if err != nil { + t.Fatalf("session.Next: %v", err) + } + if resp.KeepAlive { + continue + } + return resp + } + } + + // First MapResponse: expect node A as self and node B in Peers. + first := nextNonKeepalive() + if first.Node == nil { + t.Fatal("first response has nil Node") + } + if got, want := first.Node.Key, nodeKeyA.Public(); got != want { + t.Errorf("first Node.Key = %v, want %v", got, want) + } + var foundB bool + for _, p := range first.Peers { + if p.Key == nodeKeyB.Public() { + foundB = true + break + } + } + if !foundB { + t.Errorf("peer B (%v) not in first response's Peers (%d peers)", nodeKeyB.Public(), len(first.Peers)) + } + + // Inject raw MapResponses and verify they come out the reader, in order. + // msgToSend is single-slot, so we must consume each before injecting the next. + for i := range 3 { + want := fmt.Sprintf("injected-%d.example.com", i) + inject := &tailcfg.MapResponse{Domain: want} + if !ctrl.AddRawMapResponse(nodeKeyA.Public(), inject) { + t.Fatalf("AddRawMapResponse %d: node not connected", i) + } + got := nextNonKeepalive() + if got.Domain != want { + t.Errorf("injected %d: got Domain=%q, want %q", i, got.Domain, want) + } + } +} + +// TestSendMapUpdateAgainstTestControl verifies that a [Client.SendMapUpdate] +// call from one node lands on the coordination server and that peer nodes +// subsequently observe the updated DiscoKey via their own streaming map poll. +func TestSendMapUpdateAgainstTestControl(t *testing.T) { + ctrl := &testcontrol.Server{} + ctrl.HTTPTestServer = httptest.NewUnstartedServer(ctrl) + ctrl.HTTPTestServer.Start() + t.Cleanup(ctrl.HTTPTestServer.Close) + baseURL := ctrl.HTTPTestServer.URL + + ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second) + defer cancel() + + ht := new(health.Tracker) + + serverKey, err := DiscoverServerKey(ctx, baseURL) + if err != nil { + t.Fatalf("DiscoverServerKey: %v", err) + } + + register := func(hostname string) (nodeKey key.NodePrivate, machineKey key.MachinePrivate) { + t.Helper() + nodeKey = key.NewNode() + machineKey = key.NewMachine() + c, err := NewClient(ClientOpts{ + ServerURL: baseURL, + MachineKey: machineKey, + HealthTracker: ht, + }) + if err != nil { + t.Fatalf("NewClient %s: %v", hostname, err) + } + defer c.Close() + c.SetControlPublicKey(serverKey) + if _, err := c.Register(ctx, RegisterOpts{ + NodeKey: nodeKey, + Hostinfo: &tailcfg.Hostinfo{Hostname: hostname}, + }); err != nil { + t.Fatalf("Register %s: %v", hostname, err) + } + return nodeKey, machineKey + } + + nodeKeyA, machineKeyA := register("a") + nodeKeyB, machineKeyB := register("b") + + // B starts a streaming map poll so we can observe updates about peer A. + clientB, err := NewClient(ClientOpts{ + ServerURL: baseURL, + MachineKey: machineKeyB, + HealthTracker: ht, + }) + if err != nil { + t.Fatalf("NewClient B: %v", err) + } + defer clientB.Close() + clientB.SetControlPublicKey(serverKey) + + session, err := clientB.Map(ctx, MapOpts{ + NodeKey: nodeKeyB, + Hostinfo: &tailcfg.Hostinfo{Hostname: "b"}, + Stream: true, + }) + if err != nil { + t.Fatalf("Map B: %v", err) + } + defer session.Close() + + nextNonKeepalive := func() *tailcfg.MapResponse { + t.Helper() + for { + resp, err := session.Next() + if err != nil { + t.Fatalf("session.Next: %v", err) + } + if resp.KeepAlive { + continue + } + return resp + } + } + + // Drain B's initial MapResponse. A should be present as a peer with a + // zero DiscoKey (it never pushed one). + first := nextNonKeepalive() + var initialA *tailcfg.Node + for _, p := range first.Peers { + if p.Key == nodeKeyA.Public() { + initialA = p + break + } + } + if initialA == nil { + t.Fatalf("peer A (%v) not in B's first MapResponse", nodeKeyA.Public()) + } + if !initialA.DiscoKey.IsZero() { + t.Fatalf("peer A initial DiscoKey = %v, want zero", initialA.DiscoKey) + } + + // A pushes its disco key via SendMapUpdate. + clientA, err := NewClient(ClientOpts{ + ServerURL: baseURL, + MachineKey: machineKeyA, + HealthTracker: ht, + }) + if err != nil { + t.Fatalf("NewClient A: %v", err) + } + defer clientA.Close() + clientA.SetControlPublicKey(serverKey) + + wantDisco := key.NewDisco().Public() + if err := clientA.SendMapUpdate(ctx, SendMapUpdateOpts{ + NodeKey: nodeKeyA, + DiscoKey: wantDisco, + Hostinfo: &tailcfg.Hostinfo{Hostname: "a"}, + }); err != nil { + t.Fatalf("SendMapUpdate: %v", err) + } + + // B should now observe A's new DiscoKey in a subsequent MapResponse. + for { + resp := nextNonKeepalive() + for _, p := range resp.Peers { + if p.Key != nodeKeyA.Public() { + continue + } + if p.DiscoKey == wantDisco { + return // success + } + } + } +} + +// newTestPipeline builds the same framedReader → zstd → boundedReader → +// json.Decoder pipeline that [Client.Map] builds for a live session, but +// feeds it from a raw byte slice. Returned jdec can be used with Decode to +// pull out MapResponses. +func newTestPipeline(t testing.TB, wire []byte, maxMessageSize int64) *json.Decoder { + t.Helper() + bounded := &boundedReader{max: maxMessageSize, remain: maxMessageSize} + fr := &framedReader{ + r: bytes.NewReader(wire), + maxSize: maxMessageSize, + onNewFrame: bounded.reset, + } + zdec, err := zstd.NewReader(fr, zstd.WithDecoderConcurrency(1)) + if err != nil { + t.Fatalf("zstd.NewReader: %v", err) + } + t.Cleanup(zdec.Close) + bounded.r = zdec + return json.NewDecoder(bounded) +} + +// zstdFrame returns a zstd-compressed frame of b. +func zstdFrame(t testing.TB, b []byte) []byte { + t.Helper() + enc, err := zstd.NewWriter(io.Discard, zstd.WithEncoderConcurrency(1)) + if err != nil { + t.Fatalf("zstd.NewWriter: %v", err) + } + defer enc.Close() + return enc.EncodeAll(b, nil) +} + +// wireFrame writes a 4-byte little-endian length prefix plus payload to buf. +func wireFrame(buf *bytes.Buffer, payload []byte) { + var hdr [4]byte + binary.LittleEndian.PutUint32(hdr[:], uint32(len(payload))) + buf.Write(hdr[:]) + buf.Write(payload) +} + +// TestMapFrameSizeTooLarge verifies that a 4-byte length prefix claiming +// a frame larger than the configured cap is rejected before any payload +// bytes are read from the stream. +func TestMapFrameSizeTooLarge(t *testing.T) { + const max = 4 << 20 + var wire bytes.Buffer + var hdr [4]byte + binary.LittleEndian.PutUint32(hdr[:], (max + 1)) + wire.Write(hdr[:]) + + jdec := newTestPipeline(t, wire.Bytes(), max) + var resp tailcfg.MapResponse + err := jdec.Decode(&resp) + if err == nil { + t.Fatal("Decode: got nil error, want frame-too-large") + } + if !strings.Contains(err.Error(), "exceeds max") { + t.Errorf("Decode error = %q, want one containing %q", err, "exceeds max") + } +} + +// TestMapDecodedSizeTooLarge verifies that a small on-wire frame (well +// under the cap) which decompresses into a huge JSON payload is rejected. +// This is the "zstd bomb" case: a tiny compressed frame that would +// explode into a huge decoded payload for json.Decoder to consume. +func TestMapDecodedSizeTooLarge(t *testing.T) { + const max = 4 << 20 + big := strings.Repeat("a", 5<<20) // 5 MiB of 'a' + raw, err := json.Marshal(&tailcfg.MapResponse{Domain: big}) + if err != nil { + t.Fatal(err) + } + if int64(len(raw)) <= max { + t.Fatalf("raw JSON unexpectedly small: %d", len(raw)) + } + compressed := zstdFrame(t, raw) + if int64(len(compressed)) >= max { + t.Fatalf("compressed too large (%d); test needs a more compressible payload", len(compressed)) + } + + var wire bytes.Buffer + wireFrame(&wire, compressed) + + jdec := newTestPipeline(t, wire.Bytes(), max) + var resp tailcfg.MapResponse + err = jdec.Decode(&resp) + if err == nil { + t.Fatal("Decode: got nil error, want decoded-size-exceeded") + } + if !strings.Contains(err.Error(), "exceeds max") { + t.Errorf("Decode error = %q, want one containing %q", err, "exceeds max") + } +} + +// TestMapBudgetResetsBetweenFrames verifies that the per-message decoded +// budget is reset at each new frame boundary. Two consecutive 3-MiB frames +// should both decode successfully under a 4-MiB per-frame cap. Without the +// reset, the second frame would fail (remaining budget after frame 1 = +// 4MiB - 3MiB = 1MiB, and we'd try to read 3MiB more). +func TestMapBudgetResetsBetweenFrames(t *testing.T) { + const max = 4 << 20 + payload := strings.Repeat("a", 3<<20) + r1 := &tailcfg.MapResponse{Domain: payload + "-one"} + r2 := &tailcfg.MapResponse{Domain: payload + "-two"} + + var wire bytes.Buffer + for _, r := range []*tailcfg.MapResponse{r1, r2} { + raw, err := json.Marshal(r) + if err != nil { + t.Fatal(err) + } + if int64(len(raw)) >= max { + t.Fatalf("raw JSON size %d >= max %d; would fail budget check by itself", len(raw), max) + } + compressed := zstdFrame(t, raw) + if int64(len(compressed)) >= max { + t.Fatalf("compressed size %d >= max %d", len(compressed), max) + } + wireFrame(&wire, compressed) + } + + jdec := newTestPipeline(t, wire.Bytes(), max) + + var got1, got2 tailcfg.MapResponse + if err := jdec.Decode(&got1); err != nil { + t.Fatalf("first Decode: %v", err) + } + if got1.Domain != r1.Domain { + t.Errorf("first Domain mismatch (len %d vs %d)", len(got1.Domain), len(r1.Domain)) + } + if err := jdec.Decode(&got2); err != nil { + t.Fatalf("second Decode: %v", err) + } + if got2.Domain != r2.Domain { + t.Errorf("second Domain mismatch (len %d vs %d)", len(got2.Domain), len(r2.Domain)) + } +} diff --git a/control/tsp/nodefile.go b/control/tsp/nodefile.go new file mode 100644 index 000000000..8cae11ba9 --- /dev/null +++ b/control/tsp/nodefile.go @@ -0,0 +1,105 @@ +// Copyright (c) Tailscale Inc & contributors +// SPDX-License-Identifier: BSD-3-Clause + +package tsp + +import ( + "encoding/json" + "fmt" + "os" + + "tailscale.com/types/key" +) + +// ServerInfo identifies a coordination server by its URL and Noise public key. +type ServerInfo struct { + // URL is the base URL of the coordination server, without any path + // (e.g. "https://controlplane.tailscale.com"). + // + // There is no default value; a URL must always be supplied. + URL string `json:"server_url"` + + // Key is the server's Noise public key, used to establish an encrypted + // channel between the client and the coordination server. + Key key.MachinePublic `json:"server_key"` +} + +// NodeFile is the JSON structure for a node credentials file. It contains +// the private keys that authenticate a node to a coordination server. +// +// Example: +// +// { +// "node_key": "privkey:...", +// "machine_key": "privkey:...", +// "server_url": "https://controlplane.tailscale.com", +// "server_key": "mkey:..." +// } +// +// Note that node and machine private keys share the same "privkey:" +// textual form; they are disambiguated by the surrounding JSON field +// names rather than by any prefix in the key itself. +type NodeFile struct { + // NodeKey is the node's WireGuard private key. The corresponding + // public key identifies this node to other peers. + NodeKey key.NodePrivate `json:"node_key"` + + // MachineKey is the machine's private key. It authenticates this + // machine to the coordination server over Noise. + MachineKey key.MachinePrivate `json:"machine_key"` + + ServerInfo // server_url and server_key +} + +// ReadNodeFile reads and parses a node JSON file. +func ReadNodeFile(path string) (NodeFile, error) { + data, err := os.ReadFile(path) + if err != nil { + return NodeFile{}, err + } + var nf NodeFile + if err := json.Unmarshal(data, &nf); err != nil { + return NodeFile{}, fmt.Errorf("parsing node file %q: %w", path, err) + } + return nf, nil +} + +// WriteNodeFile writes a node JSON file. The file is created with mode 0600. +func WriteNodeFile(path string, nf NodeFile) error { + if err := nf.Check(); err != nil { + return fmt.Errorf("invalid NodeFile: %w", err) + } + return os.WriteFile(path, nf.AsJSON(), 0600) +} + +// AsJSON returns nf as a pretty-printed JSON object, terminated by a newline. +// +// It always succeeds and always returns a valid JSON object. It does not +// validate that the fields of nf are non-zero; it is the caller's +// responsibility to call [NodeFile.Check] first if they want to reject +// incomplete NodeFiles. +func (nf NodeFile) AsJSON() []byte { + out, err := json.MarshalIndent(nf, "", " ") + if err != nil { + panic(fmt.Sprintf("NodeFile.AsJSON: %v", err)) // unreachable: all fields marshal successfully + } + return append(out, '\n') +} + +// Check reports whether nf has all required fields set. +// It returns an error describing the first zero-valued field, if any. +func (nf NodeFile) Check() error { + if nf.NodeKey.IsZero() { + return fmt.Errorf("node_key is missing") + } + if nf.MachineKey.IsZero() { + return fmt.Errorf("machine_key is missing") + } + if nf.URL == "" { + return fmt.Errorf("server_url is missing") + } + if nf.ServerInfo.Key.IsZero() { + return fmt.Errorf("server_key is missing") + } + return nil +} diff --git a/control/tsp/nodefile_test.go b/control/tsp/nodefile_test.go new file mode 100644 index 000000000..4a019f25f --- /dev/null +++ b/control/tsp/nodefile_test.go @@ -0,0 +1,116 @@ +// Copyright (c) Tailscale Inc & contributors +// SPDX-License-Identifier: BSD-3-Clause + +package tsp + +import ( + "encoding/json" + "os" + "path/filepath" + "testing" + + "tailscale.com/types/key" +) + +func TestNodeFileRoundTrip(t *testing.T) { + dir := t.TempDir() + path := filepath.Join(dir, "node.json") + + nf := NodeFile{ + NodeKey: key.NewNode(), + MachineKey: key.NewMachine(), + ServerInfo: ServerInfo{ + URL: "https://controlplane.tailscale.com", + Key: key.NewMachine().Public(), + }, + } + + if err := WriteNodeFile(path, nf); err != nil { + t.Fatalf("WriteNodeFile: %v", err) + } + + got, err := ReadNodeFile(path) + if err != nil { + t.Fatalf("ReadNodeFile: %v", err) + } + if !got.NodeKey.Equal(nf.NodeKey) { + t.Errorf("node key mismatch") + } + if !got.MachineKey.Equal(nf.MachineKey) { + t.Errorf("machine key mismatch") + } + if got.URL != nf.URL { + t.Errorf("server URL = %q, want %q", got.URL, nf.URL) + } + if got.ServerInfo.Key != nf.ServerInfo.Key { + t.Errorf("server key mismatch") + } +} + +// TestNodeFileFormat verifies that ReadNodeFile can parse a fixed JSON literal, +// ensuring we don't accidentally change the on-disk format. +func TestNodeFileFormat(t *testing.T) { + const fileContents = `{ + "node_key": "privkey:0123456789abcdef0123456789abcdef0123456789abcdef0123456789abcdef", + "machine_key": "privkey:fedcba9876543210fedcba9876543210fedcba9876543210fedcba9876543210", + "server_url": "https://controlplane.tailscale.com", + "server_key": "mkey:1111111111111111111111111111111111111111111111111111111111111111" +}` + dir := t.TempDir() + path := filepath.Join(dir, "node.json") + if err := os.WriteFile(path, []byte(fileContents), 0600); err != nil { + t.Fatal(err) + } + + nf, err := ReadNodeFile(path) + if err != nil { + t.Fatalf("ReadNodeFile: %v", err) + } + if nf.NodeKey.IsZero() { + t.Error("node key is zero") + } + if nf.MachineKey.IsZero() { + t.Error("machine key is zero") + } + if nf.URL != "https://controlplane.tailscale.com" { + t.Errorf("server URL = %q", nf.URL) + } + if nf.ServerInfo.Key.IsZero() { + t.Error("server key is zero") + } +} + +// TestNodeFileWriteFormat verifies that WriteNodeFile produces the expected +// JSON field names. +func TestNodeFileWriteFormat(t *testing.T) { + dir := t.TempDir() + path := filepath.Join(dir, "node.json") + + nf := NodeFile{ + NodeKey: key.NewNode(), + MachineKey: key.NewMachine(), + ServerInfo: ServerInfo{ + URL: "https://example.com", + Key: key.NewMachine().Public(), + }, + } + + if err := WriteNodeFile(path, nf); err != nil { + t.Fatalf("WriteNodeFile: %v", err) + } + + data, err := os.ReadFile(path) + if err != nil { + t.Fatal(err) + } + + var raw map[string]json.RawMessage + if err := json.Unmarshal(data, &raw); err != nil { + t.Fatalf("parsing written JSON: %v", err) + } + for _, field := range []string{"node_key", "machine_key", "server_url", "server_key"} { + if _, ok := raw[field]; !ok { + t.Errorf("missing JSON field %q in written file", field) + } + } +} diff --git a/control/tsp/register.go b/control/tsp/register.go new file mode 100644 index 000000000..0d2baf75f --- /dev/null +++ b/control/tsp/register.go @@ -0,0 +1,116 @@ +// Copyright (c) Tailscale Inc & contributors +// SPDX-License-Identifier: BSD-3-Clause + +package tsp + +import ( + "bytes" + "cmp" + "context" + "encoding/json" + "fmt" + "io" + "net/http" + "strings" + + "tailscale.com/control/ts2021" + "tailscale.com/tailcfg" + "tailscale.com/types/key" +) + +// RegisterOpts contains options for registering a node. +type RegisterOpts struct { + // NodeKey is the node's private key. Required. + NodeKey key.NodePrivate + + // Hostinfo is the host information to send. Optional; + // if nil, a minimal default is used. + Hostinfo *tailcfg.Hostinfo + + // Ephemeral marks the node as ephemeral. + Ephemeral bool + + // AuthKey is a pre-authorized auth key. + AuthKey string + + // Tags is a list of ACL tags to request. + Tags []string + + // MaxResponseSize is the maximum size in bytes of the register + // response body. If zero, [DefaultMaxMessageSize] is used. + MaxResponseSize int64 +} + +// Register sends a registration request to the coordination server +// and returns the response. +func (c *Client) Register(ctx context.Context, opts RegisterOpts) (*tailcfg.RegisterResponse, error) { + hi := opts.Hostinfo + if hi == nil { + hi = defaultHostinfo() + } + if len(opts.Tags) > 0 { + hi.RequestTags = opts.Tags + } + + regReq := tailcfg.RegisterRequest{ + Version: tailcfg.CurrentCapabilityVersion, + NodeKey: opts.NodeKey.Public(), + Hostinfo: hi, + Ephemeral: opts.Ephemeral, + } + if opts.AuthKey != "" { + regReq.Auth = &tailcfg.RegisterResponseAuth{ + AuthKey: opts.AuthKey, + } + } + + body, err := json.Marshal(regReq) + if err != nil { + return nil, fmt.Errorf("encoding register request: %w", err) + } + + nc, err := c.noiseClient(ctx) + if err != nil { + return nil, fmt.Errorf("establishing noise connection: %w", err) + } + + url := c.serverURL + "/machine/register" + url = strings.Replace(url, "http:", "https:", 1) + req, err := http.NewRequestWithContext(ctx, "POST", url, bytes.NewReader(body)) + if err != nil { + return nil, fmt.Errorf("creating register request: %w", err) + } + ts2021.AddLBHeader(req, opts.NodeKey.Public()) + + res, err := nc.Do(req) + if err != nil { + return nil, fmt.Errorf("register request: %w", err) + } + defer res.Body.Close() + + maxResponseSize := cmp.Or(opts.MaxResponseSize, DefaultMaxMessageSize) + + if res.StatusCode != 200 { + msg, _ := io.ReadAll(io.LimitReader(res.Body, maxResponseSize)) + return nil, fmt.Errorf("register request: http %d: %.200s", + res.StatusCode, strings.TrimSpace(string(msg))) + } + + // Read up to maxResponseSize+1 so we can distinguish "exactly at cap" from + // "over the cap" rather than relying on a truncated json parse error. + data, err := io.ReadAll(io.LimitReader(res.Body, maxResponseSize+1)) + if err != nil { + return nil, fmt.Errorf("reading register response: %w", err) + } + if int64(len(data)) > maxResponseSize { + return nil, fmt.Errorf("register response exceeds max %d", maxResponseSize) + } + var resp tailcfg.RegisterResponse + if err := json.Unmarshal(data, &resp); err != nil { + return nil, fmt.Errorf("decoding register response: %w", err) + } + if resp.Error != "" { + return nil, fmt.Errorf("register: %s", resp.Error) + } + return &resp, nil +} diff --git a/control/tsp/tsp.go b/control/tsp/tsp.go new file mode 100644 index 000000000..23f2fc261 --- /dev/null +++ b/control/tsp/tsp.go @@ -0,0 +1,257 @@ +// Copyright (c) Tailscale Inc & contributors +// SPDX-License-Identifier: BSD-3-Clause + +// Package tsp provides a client for speaking the Tailscale protocol +// to a coordination server over Noise. +package tsp + +import ( + "bufio" + "bytes" + "cmp" + "context" + "encoding/json" + "fmt" + "io" + "net" + "net/http" + "strconv" + "sync" + + "tailscale.com/control/ts2021" + "tailscale.com/health" + "tailscale.com/ipn" + "tailscale.com/net/tsdial" + "tailscale.com/tailcfg" + "tailscale.com/types/key" + "tailscale.com/types/logger" + "tailscale.com/version" +) + +// DefaultServerURL is the default coordination server base URL, +// used when ClientOpts.ServerURL is empty. +const DefaultServerURL = ipn.DefaultControlURL + +// ClientOpts contains options for creating a new Client. +type ClientOpts struct { + // ServerURL is the base URL of the coordination server + // (e.g. "https://controlplane.tailscale.com"). + // If empty, DefaultServerURL is used. + ServerURL string + + // MachineKey is this node's machine private key. Required. + MachineKey key.MachinePrivate + + // Logf is the log function. If nil, logger.Discard is used. + Logf logger.Logf + + // HealthTracker, if non-nil, is the health tracker passed through + // to the underlying noise client. May be nil. + HealthTracker *health.Tracker +} + +// Client is a Tailscale protocol client that speaks to a coordination +// server over Noise. +type Client struct { + opts ClientOpts + serverURL string + logf logger.Logf + + mu sync.Mutex + nc *ts2021.Client // nil until noiseClient called + serverPub key.MachinePublic // zero until set or discovered +} + +// NewClient creates a new Client configured to talk to the coordination server +// specified in opts. It performs no I/O; the server's public key is discovered +// lazily on first use or can be set explicitly via SetControlPublicKey. +func NewClient(opts ClientOpts) (*Client, error) { + if opts.MachineKey.IsZero() { + return nil, fmt.Errorf("MachineKey is required") + } + logf := opts.Logf + if logf == nil { + logf = logger.Discard + } + return &Client{ + opts: opts, + serverURL: cmp.Or(opts.ServerURL, DefaultServerURL), + logf: logf, + }, nil +} + +// SetControlPublicKey sets the server's public key, bypassing lazy discovery. +// Any existing noise client is invalidated and will be re-created on next use. +func (c *Client) SetControlPublicKey(k key.MachinePublic) { + c.mu.Lock() + defer c.mu.Unlock() + c.serverPub = k + c.nc = nil +} + +// DiscoverServerKey fetches the server's public key from the coordination +// server and stores it for subsequent use. Any existing noise client is +// invalidated. +func (c *Client) DiscoverServerKey(ctx context.Context) (key.MachinePublic, error) { + k, err := DiscoverServerKey(ctx, c.serverURL) + if err != nil { + return key.MachinePublic{}, err + } + c.mu.Lock() + defer c.mu.Unlock() + c.serverPub = k + c.nc = nil + return k, nil +} + +// DiscoverServerKey fetches the coordination server's public key from the +// given server URL. It is a standalone function that requires no client state. +func DiscoverServerKey(ctx context.Context, serverURL string) (key.MachinePublic, error) { + serverURL = cmp.Or(serverURL, DefaultServerURL) + keysURL := serverURL + "/key?v=" + strconv.Itoa(int(tailcfg.CurrentCapabilityVersion)) + req, err := http.NewRequestWithContext(ctx, "GET", keysURL, nil) + if err != nil { + return key.MachinePublic{}, fmt.Errorf("creating key request: %w", err) + } + res, err := http.DefaultClient.Do(req) + if err != nil { + return key.MachinePublic{}, fmt.Errorf("fetching server key: %w", err) + } + defer res.Body.Close() + if res.StatusCode != 200 { + return key.MachinePublic{}, fmt.Errorf("fetching server key: %s", res.Status) + } + var keys struct { + PublicKey key.MachinePublic + } + if err := json.NewDecoder(res.Body).Decode(&keys); err != nil { + return key.MachinePublic{}, fmt.Errorf("decoding server key: %w", err) + } + return keys.PublicKey, nil +} + +// noiseClient returns the ts2021 noise client, creating it lazily if needed. +// If the server's public key is not yet known, it is discovered via HTTP. +func (c *Client) noiseClient(ctx context.Context) (*ts2021.Client, error) { + c.mu.Lock() + defer c.mu.Unlock() + + if c.nc != nil { + return c.nc, nil + } + + if c.serverPub.IsZero() { + // Discover server key without holding the lock, to avoid blocking + // other callers during the HTTP request. + c.mu.Unlock() + k, err := DiscoverServerKey(ctx, c.serverURL) + c.mu.Lock() + if err != nil { + return nil, err + } + // Re-check: another goroutine may have set it while we were unlocked. + if c.serverPub.IsZero() { + c.serverPub = k + } + // If nc was created by another goroutine while unlocked, use it. + if c.nc != nil { + return c.nc, nil + } + } + + nc, err := ts2021.NewClient(ts2021.ClientOpts{ + ServerURL: c.serverURL, + PrivKey: c.opts.MachineKey, + ServerPubKey: c.serverPub, + Dialer: tsdial.NewFromFuncForDebug(c.logf, (&net.Dialer{}).DialContext), + Logf: c.logf, + HealthTracker: c.opts.HealthTracker, + }) + if err != nil { + return nil, fmt.Errorf("creating noise client: %w", err) + } + c.nc = nc + return nc, nil +} + +// AnswerC2NPing handles a c2n PingRequest from the control plane by parsing the +// embedded HTTP request in the payload, routing it locally, and POSTing the HTTP +// response back to pr.URL using doNoiseRequest. The POST is done in a new +// goroutine so this method does not block. +// +// It reports whether the ping was handled. Unhandled pings (nil pr, non-c2n +// types, or unrecognized c2n paths) return false. +func (c *Client) AnswerC2NPing(ctx context.Context, pr *tailcfg.PingRequest, doNoiseRequest func(*http.Request) (*http.Response, error)) (handled bool) { + if pr == nil || pr.Types != "c2n" { + return false + } + + // Parse the HTTP request from the payload. + httpReq, err := http.ReadRequest(bufio.NewReader(bytes.NewReader(pr.Payload))) + if err != nil { + c.logf("parsing c2n ping payload: %v", err) + return false + } + + // Route the request locally. + var httpResp *http.Response + switch httpReq.URL.Path { + case "/echo": + body, _ := io.ReadAll(httpReq.Body) + httpResp = &http.Response{ + StatusCode: 200, + Status: "200 OK", + Proto: "HTTP/1.1", + ProtoMajor: 1, + ProtoMinor: 1, + Header: http.Header{}, + Body: io.NopCloser(bytes.NewReader(body)), + ContentLength: int64(len(body)), + } + default: + c.logf("ignoring c2n ping request for unhandled path %q", httpReq.URL.Path) + return false + } + + // Serialize the HTTP response. + var buf bytes.Buffer + if err := httpResp.Write(&buf); err != nil { + c.logf("serializing c2n ping response: %v", err) + return false + } + + // Send the response back to the control plane over the Noise channel. + go func() { + req, err := http.NewRequestWithContext(ctx, "POST", pr.URL, &buf) + if err != nil { + c.logf("creating c2n ping reply request: %v", err) + return + } + resp, err := doNoiseRequest(req) + if err != nil { + c.logf("sending c2n ping reply: %v", err) + return + } + resp.Body.Close() + }() + return true +} + +// Close closes the client and releases resources. +func (c *Client) Close() error { + c.mu.Lock() + nc := c.nc + c.nc = nil + c.mu.Unlock() + if nc != nil { + nc.Close() + } + return nil +} + +func defaultHostinfo() *tailcfg.Hostinfo { + return &tailcfg.Hostinfo{ + OS: version.OS(), + IPNVersion: version.Long(), + } +} diff --git a/derp/derp_test.go b/derp/derp_test.go index 24d509944..0edbaff17 100644 --- a/derp/derp_test.go +++ b/derp/derp_test.go @@ -353,7 +353,7 @@ func TestSendRecv(t *testing.T) { } } - serverMetrics := s.ExpVar().(*metrics.Set) + serverMetrics := s.ExpVar(false).(*metrics.Set) wantActive := func(total, home int64) { t.Helper() diff --git a/derp/derpserver/derpserver.go b/derp/derpserver/derpserver.go index ae8e9d433..947f4b005 100644 --- a/derp/derpserver/derpserver.go +++ b/derp/derpserver/derpserver.go @@ -38,6 +38,7 @@ import ( "time" "github.com/axiomhq/hyperloglog" + "github.com/go4org/hashtriemap" "go4.org/mem" "golang.org/x/sync/errgroup" xrate "golang.org/x/time/rate" @@ -172,6 +173,8 @@ type Server struct { meshUpdateBatchSize *metrics.Histogram meshUpdateLoopCount *metrics.Histogram bufferedWriteFrames *metrics.Histogram // how many sendLoop frames (or groups of related frames) get written per flush + rateLimitPerClientWaited expvar.Int // number of times per-client rate limit caused a wait + // TODO(illotum): add metrics for rate limited wait time, consider total seconds vs a histogram. // verifyClientsLocalTailscaled only accepts client connections to the DERP // server if the clientKey is a known peer in the network, as specified by a @@ -181,11 +184,23 @@ type Server struct { verifyClientsURL string verifyClientsURLFailOpen bool - mu syncs.Mutex + perClientSendQueueDepth int // Sets the client send queue depth for the server. + tcpWriteTimeout time.Duration + clock tstime.Clock + + mu syncs.Mutex // guards the following fields closed bool netConns map[derp.Conn]chan struct{} // chan is closed when conn closes - clients map[key.NodePublic]*clientSet - watchers set.Set[*sclient] // mesh peers + // clients holds the set of clients connected locally to this server, + // keyed by their public key. Writes happen under Server.mu so they + // stay consistent with clientsMesh, watchers, dup tracking, and the + // numLocalClientKeys counter. Reads on the packet send hot path + // are performed lock-free; see lookupDest. + clients hashtriemap.HashTrieMap[key.NodePublic, *clientSet] + // numLocalClientKeys is the number of distinct keys in clients. + // HashTrieMap has no Len, so the count is tracked here. + numLocalClientKeys int + watchers set.Set[*sclient] // mesh peers // clientsMesh tracks all clients in the cluster, both locally // and to mesh peers. If the value is nil, that means the // peer is only local (and thus in the clients Map, but not @@ -197,25 +212,9 @@ type Server struct { // is gone from the region, we notify all of these watchers, // calling their funcs in a new goroutine. peerGoneWatchers map[key.NodePublic]set.HandleSet[func(key.NodePublic)] - // maps from netip.AddrPort to a client's public key - keyOfAddr map[netip.AddrPort]key.NodePublic - - // Sets the client send queue depth for the server. - perClientSendQueueDepth int - - tcpWriteTimeout time.Duration - - // perClientRecvBytesPerSec is the rate limit for receiving data from - // a single client connection, in bytes per second. 0 means unlimited. - // Mesh peers are exempt from this limit. - perClientRecvBytesPerSec uint - // perClientRecvBurst is the burst size in bytes for the per-client - // receive rate limiter. perClientRecvBurst is only relevant when - // perClientRecvBytesPerSec is nonzero. - perClientRecvBurst uint - - clock tstime.Clock + keyOfAddr map[netip.AddrPort]key.NodePublic + rateConfig RateConfig // per-client DERP frame rate limiting config } // clientSet represents 1 or more *sclients. @@ -239,10 +238,6 @@ type clientSet struct { // activeClient holds the currently active connection for the set. It's nil // if there are no connections or the connection is disabled. // - // A pointer to a clientSet can be held by peers for long periods of time - // without holding Server.mu to avoid mutex contention on Server.mu, only - // re-acquiring the mutex and checking the clients map if activeClient is - // nil. activeClient atomic.Pointer[sclient] // dup is non-nil if there are multiple connections for the @@ -378,7 +373,6 @@ func New(privateKey key.NodePrivate, logf logger.Logf) *Server { logf: logf, limitedLogf: logger.RateLimitedFn(logf, 30*time.Second, 5, 100), packetsRecvByKind: metrics.LabelMap{Label: "kind"}, - clients: map[key.NodePublic]*clientSet{}, clientsMesh: map[key.NodePublic]PacketForwarder{}, netConns: map[derp.Conn]chan struct{}{}, memSys0: ms.Sys, @@ -518,14 +512,77 @@ func (s *Server) SetTCPWriteTimeout(d time.Duration) { s.tcpWriteTimeout = d } -// SetPerClientRateLimit sets the per-client receive rate limit in bytes per -// second and the burst size in bytes. Mesh peers are exempt from this limit. -// The burst is at least [derp.MaxPacketSize], or burst, if burst is greater -// than [derp.MaxPacketSize]. This ensures at least a full packet can -// be received in a burst, even if the rate limit is low. -func (s *Server) SetPerClientRateLimit(bytesPerSec, burst uint) { - s.perClientRecvBytesPerSec = bytesPerSec - s.perClientRecvBurst = max(burst, derp.MaxPacketSize) +// minRateLimitTokenBucketSize represents the minimum size of a token bucket +// applied for the purposes of rate limiting a DERP connection per received DERP +// frame. +// +// Note: The DERP protocol supports frames larger than this ([math.MaxUint32] length), +// but a [derp.FrameSendPacket] cannot exceed this value, which is what we optimize +// our token bucket calls for. +const minRateLimitTokenBucketSize = derp.MaxPacketSize + derp.KeyLen + +// RateConfig is a JSON-serializable configuration for rate limits. Values are +// in bytes. +type RateConfig struct { + // PerClientRateLimitBytesPerSec represents the per-client + // rate limit in bytes per second. A zero value disables all rate limiting. + PerClientRateLimitBytesPerSec uint64 `json:",omitzero"` + // PerClientRateBurstBytes represents the per-client token bucket depth, + // or burst, in bytes. Any value lower than [minRateLimitTokenBucketSize] + // will be increased to [minRateLimitTokenBucketSize] before application. Only + // relevant if PerClientRateLimitBytesPerSec is nonzero. + PerClientRateBurstBytes uint64 `json:",omitzero"` +} + +// LoadRateConfig reads and JSON-unmarshals a [RateConfig] from the file at path. +func LoadRateConfig(path string) (RateConfig, error) { + if path == "" { + return RateConfig{}, errors.New("rate config path is empty") + } + b, err := os.ReadFile(path) + if err != nil { + return RateConfig{}, fmt.Errorf("error reading rate config: %w", err) + } + var rc RateConfig + if err := json.Unmarshal(b, &rc); err != nil { + return RateConfig{}, fmt.Errorf("error parsing rate config: %w", err) + } + return rc, nil +} + +// LoadAndApplyRateConfig reads a [RateConfig] from the file at path and +// applies it to the server via [Server.UpdateRateLimits]. +func (s *Server) LoadAndApplyRateConfig(path string) error { + rc, err := LoadRateConfig(path) + if err != nil { + return err + } + applied := s.UpdateRateLimits(rc) + s.logf("rate config applied: client-rate=%d bytes/sec, client-burst=%d bytes", + applied.PerClientRateLimitBytesPerSec, applied.PerClientRateBurstBytes) + return nil +} + +// UpdateRateLimits sets the receive rate limits, updating all existing client +// connections. It returns the applied config, which may differ from rc. If the +// per-client rate limits is 0, rate limiting is disabled. Mesh peers are always +// exempt from rate limiting. +func (s *Server) UpdateRateLimits(rc RateConfig) (applied RateConfig) { + s.mu.Lock() + defer s.mu.Unlock() + if rc.PerClientRateLimitBytesPerSec == 0 { + // all rate limiting is disabled + rc = RateConfig{} + } else { + rc.PerClientRateBurstBytes = max(rc.PerClientRateBurstBytes, minRateLimitTokenBucketSize) + } + s.rateConfig = rc + for _, cs := range s.clients.All() { + cs.ForeachClient(func(c *sclient) { + c.setRateLimit(rc.PerClientRateLimitBytesPerSec, rc.PerClientRateBurstBytes) + }) + } + return rc } // HasMeshKey reports whether the server is configured with a mesh key. @@ -577,7 +634,7 @@ func (s *Server) isClosed() bool { func (s *Server) IsClientConnectedForTest(k key.NodePublic) bool { s.mu.Lock() defer s.mu.Unlock() - x, ok := s.clients[k] + x, ok := s.clients.Load(k) if !ok { return false } @@ -690,11 +747,14 @@ func (s *Server) registerClient(c *sclient) { s.mu.Lock() defer s.mu.Unlock() - cs, ok := s.clients[c.key] + c.setRateLimit(s.rateConfig.PerClientRateLimitBytesPerSec, s.rateConfig.PerClientRateBurstBytes) + + cs, ok := s.clients.Load(c.key) if !ok { c.debugLogf("register single client") cs = &clientSet{} - s.clients[c.key] = cs + s.clients.Store(c.key, cs) + s.numLocalClientKeys++ } was := cs.activeClient.Load() if was == nil { @@ -760,7 +820,7 @@ func (s *Server) unregisterClient(c *sclient) { s.mu.Lock() defer s.mu.Unlock() - set, ok := s.clients[c.key] + set, ok := s.clients.Load(c.key) if !ok { c.logf("[unexpected]; clients map is empty") return @@ -780,7 +840,9 @@ func (s *Server) unregisterClient(c *sclient) { } c.debugLogf("removed connection") set.activeClient.Store(nil) - delete(s.clients, c.key) + if s.clients.CompareAndDelete(c.key, set) { + s.numLocalClientKeys-- + } if v, ok := s.clientsMesh[c.key]; ok && v == nil { delete(s.clientsMesh, c.key) s.notePeerGoneFromRegionLocked(c.key) @@ -911,7 +973,7 @@ func (s *Server) addWatcher(c *sclient) { defer s.mu.Unlock() // Queue messages for each already-connected client. - for peer, clientSet := range s.clients { + for peer, clientSet := range s.clients.All() { ac := clientSet.activeClient.Load() if ac == nil { continue @@ -975,9 +1037,6 @@ func (s *Server) accept(ctx context.Context, nc derp.Conn, brw *bufio.ReadWriter peerGoneLim: rate.NewLimiter(rate.Every(time.Second), 3), } - if s.perClientRecvBytesPerSec > 0 && !c.canMesh { - c.recvLim = xrate.NewLimiter(xrate.Limit(s.perClientRecvBytesPerSec), int(s.perClientRecvBurst)) - } if c.canMesh { c.meshUpdate = make(chan struct{}, 1) // must be buffered; >1 is fine but wasteful } @@ -1050,10 +1109,9 @@ func (c *sclient) run(ctx context.Context) error { } return fmt.Errorf("client %s: readFrameHeader: %w", c.key.ShortString(), err) } - // Rate limit by DERP frame length (fl), which excludes DERP - // and TLS protocol overheads. + // Rate-limit by DERP frame length (fl), which excludes TLS protocol and + // DERP frame length field overheads. // Note: meshed clients are exempt from rate limits. - // meshed clients are exempt from rate limits if err := c.rateLimit(int(fl)); err != nil { return err // context canceled, connection closing } @@ -1153,7 +1211,7 @@ func (c *sclient) handleFrameClosePeer(ft derp.FrameType, fl uint32) error { s.mu.Lock() defer s.mu.Unlock() - if set, ok := s.clients[targetKey]; ok { + if set, ok := s.clients.Load(targetKey); ok { if set.Len() == 1 { c.logf("frameClosePeer closing peer %x", targetKey) } else { @@ -1183,15 +1241,10 @@ func (c *sclient) handleFrameForwardPacket(_ derp.FrameType, fl uint32) error { } s.packetsForwardedIn.Add(1) - var dstLen int - var dst *sclient - - s.mu.Lock() - if set, ok := s.clients[dstKey]; ok { - dstLen = set.Len() - dst = set.activeClient.Load() - } - s.mu.Unlock() + // Use the same lock-free fast path as the local send path. The mesh + // forwarder return is intentionally discarded: we never re-forward an + // already-forwarded packet. + dst, _, dstLen := c.lookupDest(dstKey) if dst == nil { reason := dropReasonUnknownDestOnFwd @@ -1213,6 +1266,40 @@ func (c *sclient) handleFrameForwardPacket(_ derp.FrameType, fl uint32) error { }) } +// lookupDest returns the local client, mesh forwarder, or duplicate-client +// count for dst. dstLen is only meaningful when the returned local client is +// nil; when a local client is returned, dstLen is just non-zero. +// +// The fast path reads Server.clients lock-free: if a *clientSet is present +// for dst and has an active client, we return that without taking Server.mu. +// Misses, inactive clientSets, duplicate-client accounting, and mesh +// forwarder lookups fall through to a slow path under Server.mu. At most +// one local client and PacketForwarder can be non-nil: local clients win +// over mesh forwarding, and mesh forwarding is considered only when there +// is no local clientSet. +func (c *sclient) lookupDest(dst key.NodePublic) (_ *sclient, fwd PacketForwarder, dstLen int) { + s := c.s + if set, ok := s.clients.Load(dst); ok { + if dst := set.activeClient.Load(); dst != nil { + return dst, nil, 1 + } + } + // Slow path: no active local client. Take Server.mu to read the + // duplicate-client count and clientsMesh consistently. + s.mu.Lock() + defer s.mu.Unlock() + if set, ok := s.clients.Load(dst); ok { + if dst := set.activeClient.Load(); dst != nil { + return dst, nil, 1 + } + dstLen = set.Len() + } + if dstLen < 1 { + fwd = s.clientsMesh[dst] + } + return nil, fwd, dstLen +} + // handleFrameSendPacket reads a "send packet" frame from the client. func (c *sclient) handleFrameSendPacket(_ derp.FrameType, fl uint32) error { s := c.s @@ -1222,19 +1309,7 @@ func (c *sclient) handleFrameSendPacket(_ derp.FrameType, fl uint32) error { return fmt.Errorf("client %v: recvPacket: %v", c.key, err) } - var fwd PacketForwarder - var dstLen int - var dst *sclient - - s.mu.Lock() - if set, ok := s.clients[dstKey]; ok { - dstLen = set.Len() - dst = set.activeClient.Load() - } - if dst == nil && dstLen < 1 { - fwd = s.clientsMesh[dstKey] - } - s.mu.Unlock() + dst, fwd, dstLen := c.lookupDest(dstKey) if dst == nil { if fwd != nil { @@ -1267,19 +1342,78 @@ func (c *sclient) handleFrameSendPacket(_ derp.FrameType, fl uint32) error { return c.sendPkt(dst, p) } -// rateLimit applies the per-client receive rate limit, if configured. +// setRateLimit updates the receive rate limiter. When bytesPerSec is 0, or the +// client is a mesh peer, the limiter is set to nil so that [sclient.rateLimit] is a no-op. +func (c *sclient) setRateLimit(bytesPerSec, burst uint64) { + if c.canMesh || bytesPerSec == 0 { + c.recvLim.Store(nil) + return + } + limiter := xrate.NewLimiter(xrate.Limit(bytesPerSec), int(burst)) + c.recvLim.Store(limiter) +} + +// rateLimitWait is a reimplementation of [xrate.Limiter.WaitN] via [xrate.Limiter.ReserveN]. +// It returns the duration waited for tokens to become available. +func rateLimitWait(ctx context.Context, lim *xrate.Limiter, n int, now time.Time, newTimer func(time.Duration) (<-chan time.Time, func() bool)) (time.Duration, error) { + r := lim.ReserveN(now, n) + if !r.OK() { + return 0, fmt.Errorf("rate: Wait(n=%d) exceeds limiter's burst %d", n, lim.Burst()) + } + delay := r.DelayFrom(now) + if delay == 0 { + return 0, nil + } + ch, stop := newTimer(delay) + defer stop() + select { + case <-ch: + // Note: We return the predicted delay as wall-clock duration. May be not the same. + return delay, nil + case <-ctx.Done(): + r.Cancel() + return 0, ctx.Err() + } +} + +// rateLimit applies the receive rate limit. // By limiting here we prevent reading from the buffered reader // [sclient.br] if the limit has been exceeded. Any reads done here provide space // within the buffered reader to fill back in with data from // the TCP socket. Pacing reads acts as a form of natural // backpressure via TCP flow control. -// meshed clients are exempt from rate limits. +// When rate limiting is disabled or the client is a mesh peer, recvLim is nil +// and this is a no-op. func (c *sclient) rateLimit(n int) error { - if c.recvLim == nil || c.canMesh { - return nil + if lim := c.recvLim.Load(); lim != nil { + newTimer := func(d time.Duration) (<-chan time.Time, func() bool) { + tc, ch := c.s.clock.NewTimer(d) + return ch, tc.Stop + } + // If n exceeds the capacity of the bucket, then WaitN will return + // an error and consume zero tokens. To prevent this, clamp n to + // [minRateLimitTokenBucketSize]. + // + // While we could call WaitN multiple times and/or more precisely for + // lim.Burst(), it's better to return early as a larger DERP frame: + // 1. is unexpected + // 2. is only partially read off the socket (bufio) + // 3. would cause the connection to close shortly after rate limiting, anyway. + clampedN := min(n, minRateLimitTokenBucketSize) + now := c.s.clock.Now() + var ( + durationWaited time.Duration + err error + ) + durationWaited, err = rateLimitWait(c.ctx, lim, clampedN, now, newTimer) + if err != nil { + return err + } + if durationWaited > 0 { + c.s.rateLimitPerClientWaited.Add(1) + } } - - return c.recvLim.WaitN(c.ctx, n) + return nil } func (c *sclient) debugLogf(format string, v ...any) { @@ -1507,7 +1641,7 @@ func (s *Server) noteClientActivity(c *sclient) { s.mu.Lock() defer s.mu.Unlock() - cs, ok := s.clients[c.key] + cs, ok := s.clients.Load(c.key) if !ok { return } @@ -1714,10 +1848,15 @@ type sclient struct { // through us with a peer we have no record of. peerGoneLim *rate.Limiter - // recvLim is the per-connection receive rate limiter. If non-nil, - // the server calls WaitN per received DERP frame in order to - // apply TCP backpressure and throttle the sender. - recvLim *xrate.Limiter + // recvLim is the receive rate limiter. When rate limiting is enabled for a + // non-mesh client, it points to a [xrate.Limiter]. When rate limiting + // is disabled or the client is a mesh peer, it is nil and [sclient.rateLimit] + // is a no-op. Updated atomically by [sclient.setRateLimit] so that + // [sclient.rateLimit] can load it without holding [Server.mu]. + // + // TODO: consider porting the required APIs from [xrate.Limiter] to [rate.Limiter], + // which is already optimized to use [mono.Time]. + recvLim atomic.Pointer[xrate.Limiter] } func (c *sclient) presentFlags() derp.PeerPresentFlags { @@ -2168,7 +2307,7 @@ func (s *Server) RemovePacketForwarder(dst key.NodePublic, fwd PacketForwarder) return } - if _, isLocal := s.clients[dst]; isLocal { + if _, isLocal := s.clients.Load(dst); isLocal { s.clientsMesh[dst] = nil } else { delete(s.clientsMesh, dst) @@ -2258,7 +2397,7 @@ func (s *Server) expVarFunc(f func() any) expvar.Func { } // ExpVar returns an expvar variable suitable for registering with expvar.Publish. -func (s *Server) ExpVar() expvar.Var { +func (s *Server) ExpVar(rateLimitEnabled bool) expvar.Var { m := new(metrics.Set) m.Set("gauge_memstats_sys0", expvar.Func(func() any { return int64(s.memSys0) })) m.Set("gauge_watchers", s.expVarFunc(func() any { return len(s.watchers) })) @@ -2267,8 +2406,8 @@ func (s *Server) ExpVar() expvar.Var { m.Set("gauge_current_home_connections", &s.curHomeClients) m.Set("gauge_current_notideal_connections", &s.curClientsNotIdeal) m.Set("gauge_clients_total", s.expVarFunc(func() any { return len(s.clientsMesh) })) - m.Set("gauge_clients_local", s.expVarFunc(func() any { return len(s.clients) })) - m.Set("gauge_clients_remote", s.expVarFunc(func() any { return len(s.clientsMesh) - len(s.clients) })) + m.Set("gauge_clients_local", s.expVarFunc(func() any { return s.numLocalClientKeys })) + m.Set("gauge_clients_remote", s.expVarFunc(func() any { return len(s.clientsMesh) - s.numLocalClientKeys })) m.Set("gauge_current_dup_client_keys", &s.dupClientKeys) m.Set("gauge_current_dup_client_conns", &s.dupClientConns) m.Set("counter_total_dup_client_conns", &s.dupClientConnTotal) @@ -2301,6 +2440,18 @@ func (s *Server) ExpVar() expvar.Var { var expvarVersion expvar.String expvarVersion.Set(version.Long()) m.Set("version", &expvarVersion) + if rateLimitEnabled { + // Rate limiting is currently experimental, its APIs are unstable, and it must + // be opted-in via --rate-config. Therefore, we only publish related metrics + // on demand, to avoid polluting uninterested metrics consumers. + m.Set("rate_limit_per_client_bytes_per_second", s.expVarFunc(func() any { + return s.rateConfig.PerClientRateLimitBytesPerSec + })) + m.Set("rate_limit_per_client_burst_bytes", s.expVarFunc(func() any { + return s.rateConfig.PerClientRateBurstBytes + })) + m.Set("rate_limit_per_client_waited", &s.rateLimitPerClientWaited) + } return m } @@ -2313,7 +2464,7 @@ func (s *Server) ConsistencyCheck() error { var nilMeshNotInClient int for k, f := range s.clientsMesh { if f == nil { - if _, ok := s.clients[k]; !ok { + if _, ok := s.clients.Load(k); !ok { nilMeshNotInClient++ } } @@ -2323,7 +2474,7 @@ func (s *Server) ConsistencyCheck() error { } var clientNotInMesh int - for k := range s.clients { + for k := range s.clients.All() { if _, ok := s.clientsMesh[k]; !ok { clientNotInMesh++ } @@ -2332,10 +2483,10 @@ func (s *Server) ConsistencyCheck() error { errs = append(errs, fmt.Sprintf("%d s.clients keys not in s.clientsMesh", clientNotInMesh)) } - if s.curClients.Value() != int64(len(s.clients)) { + if s.curClients.Value() != int64(s.numLocalClientKeys) { errs = append(errs, fmt.Sprintf("expvar connections = %d != clients map says of %d", s.curClients.Value(), - len(s.clients))) + s.numLocalClientKeys)) } if s.verifyClientsLocalTailscaled { @@ -2431,7 +2582,7 @@ func (s *Server) ServeDebugTraffic(w http.ResponseWriter, r *http.Request) { if prev.Sent < next.Sent || prev.Recv < next.Recv { if pkey, ok := s.keyOfAddr[k]; ok { next.Key = pkey - if cs, ok := s.clients[pkey]; ok { + if cs, ok := s.clients.Load(pkey); ok { if c := cs.activeClient.Load(); c != nil { next.UniqueSenders = c.EstimatedUniqueSenders() } diff --git a/derp/derpserver/derpserver_test.go b/derp/derpserver/derpserver_test.go index 3fb4b838e..13826819b 100644 --- a/derp/derpserver/derpserver_test.go +++ b/derp/derpserver/derpserver_test.go @@ -15,6 +15,7 @@ import ( "log" "net" "os" + "path/filepath" "reflect" "strconv" "sync" @@ -28,8 +29,10 @@ import ( "golang.org/x/time/rate" "tailscale.com/derp" "tailscale.com/derp/derpconst" + "tailscale.com/tstime" "tailscale.com/types/key" "tailscale.com/types/logger" + "tailscale.com/util/set" ) const testMeshKey = "0123456789abcdef0123456789abcdef0123456789abcdef0123456789abcdef" @@ -146,7 +149,6 @@ func pubAll(b byte) (ret key.NodePublic) { func TestForwarderRegistration(t *testing.T) { s := &Server{ - clients: make(map[key.NodePublic]*clientSet), clientsMesh: map[key.NodePublic]PacketForwarder{}, } want := func(want map[key.NodePublic]PacketForwarder) { @@ -228,7 +230,7 @@ func TestForwarderRegistration(t *testing.T) { key: u1, logf: logger.Discard, } - s.clients[u1] = singleClient(u1c) + s.clients.Store(u1, singleClient(u1c)) s.RemovePacketForwarder(u1, testFwd(100)) want(map[key.NodePublic]PacketForwarder{ u1: nil, @@ -248,7 +250,7 @@ func TestForwarderRegistration(t *testing.T) { // Now pretend u1 was already connected locally (so clientsMesh[u1] is nil), and then we heard // that they're also connected to a peer of ours. That shouldn't transition the forwarder // from nil to the new one, not a multiForwarder. - s.clients[u1] = singleClient(u1c) + s.clients.Store(u1, singleClient(u1c)) s.clientsMesh[u1] = nil want(map[key.NodePublic]PacketForwarder{ u1: nil, @@ -280,7 +282,6 @@ func TestMultiForwarder(t *testing.T) { ctx, cancel := context.WithCancel(context.Background()) s := &Server{ - clients: make(map[key.NodePublic]*clientSet), clientsMesh: map[key.NodePublic]PacketForwarder{}, } u := pubAll(1) @@ -389,7 +390,7 @@ func TestServerDupClients(t *testing.T) { } wantSingleClient := func(t *testing.T, want *sclient) { t.Helper() - got, ok := s.clients[want.key] + got, ok := s.clients.Load(want.key) if !ok { t.Error("no clients for key") return @@ -412,7 +413,7 @@ func TestServerDupClients(t *testing.T) { } wantNoClient := func(t *testing.T) { t.Helper() - _, ok := s.clients[clientPub] + _, ok := s.clients.Load(clientPub) if !ok { // Good return @@ -421,7 +422,7 @@ func TestServerDupClients(t *testing.T) { } wantDupSet := func(t *testing.T) *dupClientSet { t.Helper() - cs, ok := s.clients[clientPub] + cs, ok := s.clients.Load(clientPub) if !ok { t.Fatal("no set for key; want dup set") return nil @@ -434,7 +435,7 @@ func TestServerDupClients(t *testing.T) { } wantActive := func(t *testing.T, want *sclient) { t.Helper() - set, ok := s.clients[clientPub] + set, ok := s.clients.Load(clientPub) if !ok { t.Error("no set for key") return @@ -775,7 +776,7 @@ func TestServeDebugTrafficUniqueSenders(t *testing.T) { s.mu.Lock() cs := &clientSet{} cs.activeClient.Store(c) - s.clients[clientKey] = cs + s.clients.Store(clientKey, cs) s.mu.Unlock() estimate := c.EstimatedUniqueSenders() @@ -957,25 +958,34 @@ func BenchmarkHyperLogLogEstimate(b *testing.B) { func TestPerClientRateLimit(t *testing.T) { t.Run("throttled", func(t *testing.T) { synctest.Test(t, func(t *testing.T) { - // 100 bytes/sec with a burst of 100 bytes. - const bytesPerSec = 100 - const burst = 100 - ctx, cancel := context.WithCancel(context.Background()) t.Cleanup(cancel) + s := New(key.NewNode(), logger.Discard) + defer s.Close() + c := &sclient{ - ctx: ctx, - recvLim: rate.NewLimiter(rate.Limit(bytesPerSec), burst), + ctx: ctx, + s: s, + } + lim := rate.NewLimiter(rate.Limit(minRateLimitTokenBucketSize), minRateLimitTokenBucketSize) + c.recvLim.Store(lim) + wantTokens := func(t *testing.T, wantTokens float64) { + t.Helper() + if lim.Tokens() != wantTokens { + t.Fatalf("want tokens: %v got: %v", wantTokens, lim.Tokens()) + } } // First call within burst should not block. - c.rateLimit(burst) + c.rateLimit(minRateLimitTokenBucketSize) + + wantTokens(t, 0) // Next call exceeds burst, should block until tokens replenish. done := make(chan error, 1) go func() { - done <- c.rateLimit(burst) + done <- c.rateLimit(minRateLimitTokenBucketSize) }() // After settling, the goroutine should be blocked (no result yet). @@ -986,7 +996,7 @@ func TestPerClientRateLimit(t *testing.T) { default: } - // Advance time by 1 second; 100 bytes/sec * 1s = 100 bytes = burst. + // Advance time by 1 second, the goroutine should be unblocked time.Sleep(1 * time.Second) synctest.Wait() @@ -998,6 +1008,13 @@ func TestPerClientRateLimit(t *testing.T) { default: t.Fatal("rateLimit should have unblocked after 1s") } + + wantTokens(t, 0) + + // The second rateLimit call had to wait + if got := s.rateLimitPerClientWaited.Value(); got != 1 { + t.Fatalf("rateLimitPerClientWaited = %d, want 1", got) + } }) }) @@ -1005,19 +1022,24 @@ func TestPerClientRateLimit(t *testing.T) { synctest.Test(t, func(t *testing.T) { ctx, cancel := context.WithCancel(context.Background()) + s := New(key.NewNode(), logger.Discard) + defer s.Close() + c := &sclient{ - ctx: ctx, - recvLim: rate.NewLimiter(rate.Limit(100), 100), + ctx: ctx, + s: s, } + lim := rate.NewLimiter(rate.Limit(minRateLimitTokenBucketSize), minRateLimitTokenBucketSize) + c.recvLim.Store(lim) // Exhaust burst. - if err := c.rateLimit(100); err != nil { + if err := c.rateLimit(minRateLimitTokenBucketSize); err != nil { t.Fatalf("rateLimit: %v", err) } done := make(chan error, 1) go func() { - done <- c.rateLimit(100) + done <- c.rateLimit(minRateLimitTokenBucketSize) }() synctest.Wait() @@ -1040,37 +1062,520 @@ func TestPerClientRateLimit(t *testing.T) { ctx, cancel := context.WithCancel(context.Background()) t.Cleanup(cancel) + // Mesh peers have nil recvLim, so rate limiting is a no-op. c := &sclient{ ctx: ctx, canMesh: true, - recvLim: rate.NewLimiter(rate.Limit(1), 1), // would block immediately if not exempt } - // rateLimit should be a no-op for mesh peers. if err := c.rateLimit(1000); err != nil { t.Fatalf("mesh peer rateLimit should be no-op: %v", err) } }) - t.Run("nil_limiter_no_op", func(t *testing.T) { - ctx, cancel := context.WithCancel(context.Background()) - t.Cleanup(cancel) - - c := &sclient{ - ctx: ctx, - } - - // rateLimit with nil recvLim should be a no-op. - if err := c.rateLimit(1000); err != nil { - t.Fatalf("nil limiter rateLimit should be no-op: %v", err) - } - }) - t.Run("zero_config_no_limiter", func(t *testing.T) { s := New(key.NewNode(), logger.Discard) defer s.Close() - if s.perClientRecvBytesPerSec != 0 { - t.Errorf("expected zero rate limit, got %d", s.perClientRecvBytesPerSec) + if !reflect.DeepEqual(s.rateConfig, RateConfig{}) { + t.Errorf("expected zero rate limit, got %+v", s.rateConfig) + } + }) +} + +// zeroTimer returns a timer that fires immediately. +func zeroTimer(_ time.Duration) (<-chan time.Time, func() bool) { + t := time.NewTimer(0) + return t.C, t.Stop +} + +// neverTimer returns a timer that never fires. +func neverTimer(_ time.Duration) (<-chan time.Time, func() bool) { + return make(chan time.Time), func() bool { return false } +} + +func TestRateLimitWait(t *testing.T) { + ctx := context.Background() + + t.Run("no_wait", func(t *testing.T) { + lim := rate.NewLimiter(10, 10) + waited, err := rateLimitWait(ctx, lim, 5, time.Now(), zeroTimer) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + if waited != 0 { + t.Fatalf("waited = %v, want 0", waited) + } + }) + + t.Run("wait_for_tokens", func(t *testing.T) { + lim := rate.NewLimiter(10, 10) + now := time.Now() + waited, err := rateLimitWait(ctx, lim, 10, now, zeroTimer) // exhaust all tokens + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + if waited != 0 { + t.Fatalf("waited = %v, want 0", waited) + } + waited, err = rateLimitWait(ctx, lim, 10, now, zeroTimer) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + if waited == 0 { + t.Fatal("waited = 0, want > 0") + } + }) + + t.Run("context_canceled", func(t *testing.T) { + lim := rate.NewLimiter(10, 10) + now := time.Now() + _, err := rateLimitWait(ctx, lim, 10, now, zeroTimer) // exhaust all tokens + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + canceled, cancel := context.WithCancel(ctx) // cancel context so the select picks ctx.Done() + cancel() + waited, err := rateLimitWait(canceled, lim, 10, now, neverTimer) // neverTimer to only unblock via context + if err == nil { + t.Fatal("expected error from canceled context") + } + if waited != 0 { + t.Fatalf("waited = %v, want 0", waited) + } + }) + + t.Run("n_exceeds_burst", func(t *testing.T) { + lim := rate.NewLimiter(10, 5) + waited, err := rateLimitWait(ctx, lim, 10, time.Now(), zeroTimer) + if err == nil { + t.Fatal("expected error when n > burst") + } + if waited != 0 { + t.Fatalf("waited = %v, want 0", waited) + } + }) +} + +func verifyLimiter(t *testing.T, lim *rate.Limiter, wantRateConfig RateConfig) { + t.Helper() + if got := lim.Limit(); got != rate.Limit(wantRateConfig.PerClientRateLimitBytesPerSec) { + t.Errorf("client rate limit = %v; want %d", got, wantRateConfig.PerClientRateLimitBytesPerSec) + } + if got := lim.Burst(); got != int(wantRateConfig.PerClientRateBurstBytes) { + t.Errorf("client burst = %v; want %d", got, wantRateConfig.PerClientRateBurstBytes) + } +} + +func TestUpdateRateLimits(t *testing.T) { + const ( + testClientBurst1 = minRateLimitTokenBucketSize + 1 + testClientRate1 = minRateLimitTokenBucketSize + 2 + testClientBurst2 = minRateLimitTokenBucketSize + 3 + testClientRate2 = minRateLimitTokenBucketSize + 4 + ) + + s := New(key.NewNode(), t.Logf) + defer s.Close() + + // Create a non-mesh client with no initial limiter. + clientKey := key.NewNode().Public() + c := &sclient{ + key: clientKey, + s: s, + logf: logger.Discard, + canMesh: false, + } + cs := &clientSet{} + cs.activeClient.Store(c) + + s.mu.Lock() + s.clients.Store(clientKey, cs) + s.mu.Unlock() + + rc := RateConfig{ + PerClientRateLimitBytesPerSec: testClientRate1, + PerClientRateBurstBytes: testClientBurst1, + } + s.UpdateRateLimits(rc) + + lim := c.recvLim.Load() + if lim == nil { + t.Fatal("expected non-nil limiter after update") + } + verifyLimiter(t, lim, rc) + + // Verify server fields updated. + s.mu.Lock() + if !reflect.DeepEqual(s.rateConfig, rc) { + t.Errorf("s.rateConfig = %+v; want %+v", s.rateConfig, rc) + } + s.mu.Unlock() + + // Update again with different nonzero values. + rc = RateConfig{ + PerClientRateLimitBytesPerSec: testClientRate2, + PerClientRateBurstBytes: testClientBurst2, + } + s.UpdateRateLimits(rc) + lim = c.recvLim.Load() + if lim == nil { + t.Fatal("expected non-nil limiter") + } + verifyLimiter(t, lim, rc) + + // Disable rate limiting (set to 0). + s.UpdateRateLimits(RateConfig{}) + + if got := c.recvLim.Load(); got != nil { + t.Errorf("expected nil limiter after disable, got limit=%v", got.Limit()) + } + + // Mesh peer should always have nil limiter regardless of update. + meshKey := key.NewNode().Public() + meshClient := &sclient{ + key: meshKey, + s: s, + logf: logger.Discard, + canMesh: true, + } + meshCS := &clientSet{} + meshCS.activeClient.Store(meshClient) + + s.mu.Lock() + s.clients.Store(meshKey, meshCS) + s.mu.Unlock() + + rc = RateConfig{ + PerClientRateLimitBytesPerSec: testClientRate2, + PerClientRateBurstBytes: testClientBurst2, + } + s.UpdateRateLimits(rc) + + if got := meshClient.recvLim.Load(); got != nil { + t.Errorf("mesh peer should have nil limiter, got limit=%v", got.Limit()) + } + // Non-mesh client should be updated. + lim = c.recvLim.Load() + if lim == nil { + t.Fatal("expected non-nil limiter for non-mesh client") + } + verifyLimiter(t, lim, rc) + + // Verify dup clients are also updated. + dupKey := key.NewNode().Public() + d1 := &sclient{key: dupKey, s: s, logf: logger.Discard} + d2 := &sclient{key: dupKey, s: s, logf: logger.Discard} + dupCS := &clientSet{} + dupCS.activeClient.Store(d1) + dupCS.dup = &dupClientSet{set: set.Of(d1, d2)} + s.mu.Lock() + s.clients.Store(dupKey, dupCS) + s.mu.Unlock() + + rc = RateConfig{ + PerClientRateLimitBytesPerSec: testClientRate1, + PerClientRateBurstBytes: testClientBurst1, + } + s.UpdateRateLimits(rc) + for i, d := range []*sclient{d1, d2} { + dl := d.recvLim.Load() + if dl == nil { + t.Fatalf("dup client %d: expected non-nil limiter", i) + } + verifyLimiter(t, dl, rc) + } +} + +func TestLoadRateConfig(t *testing.T) { + for _, tt := range []struct { + name string + json string + wantRateConfig RateConfig + }{ + {"all_set", `{"PerClientRateLimitBytesPerSec": 1, "PerClientRateBurstBytes": 2}`, RateConfig{ + PerClientRateLimitBytesPerSec: 1, + PerClientRateBurstBytes: 2, + }}, + {"rate_only", `{"PerClientRateLimitBytesPerSec": 1}`, RateConfig{ + PerClientRateLimitBytesPerSec: 1, + }}, + {"zeros", `{"PerClientRateLimitBytesPerSec": 0, "PerClientRateBurstBytes": 0}`, RateConfig{}}, + {"empty_json", `{}`, RateConfig{}}, + } { + t.Run(tt.name, func(t *testing.T) { + f := filepath.Join(t.TempDir(), "rate.json") + if err := os.WriteFile(f, []byte(tt.json), 0644); err != nil { + t.Fatal(err) + } + rc, err := LoadRateConfig(f) + if err != nil { + t.Fatal(err) + } + if !reflect.DeepEqual(rc, tt.wantRateConfig) { + t.Errorf("rate config = %v want %v", rc, tt.wantRateConfig) + } + }) + } + + for _, tt := range []struct { + name string + path string + content string // written to loaded path if non-empty; path used as-is if empty + }{ + {"empty_path", "", ""}, + {"missing_file", filepath.Join(t.TempDir(), "nonexistent.json"), ""}, + {"invalid_json", "", "not json"}, + } { + t.Run(tt.name, func(t *testing.T) { + path := tt.path + if tt.content != "" { + path = filepath.Join(t.TempDir(), "rate.json") + if err := os.WriteFile(path, []byte(tt.content), 0644); err != nil { + t.Fatal(err) + } + } + _, err := LoadRateConfig(path) + if err == nil { + t.Fatal("expected error") + } + }) + } +} + +func TestLoadAndApplyRateConfig(t *testing.T) { + writeConfig := func(t *testing.T, json string) string { + t.Helper() + f := filepath.Join(t.TempDir(), "rate.json") + if err := os.WriteFile(f, []byte(json), 0644); err != nil { + t.Fatal(err) + } + return f + } + + t.Run("applies_and_updates_clients", func(t *testing.T) { + s := New(key.NewNode(), t.Logf) + defer s.Close() + + clientKey := key.NewNode().Public() + c := &sclient{key: clientKey, s: s, logf: logger.Discard} + cs := &clientSet{} + cs.activeClient.Store(c) + s.mu.Lock() + s.clients.Store(clientKey, cs) + s.mu.Unlock() + + f := writeConfig(t, fmt.Sprintf(`{"PerClientRateLimitBytesPerSec": %d, "PerClientRateBurstBytes": %d}`, + minRateLimitTokenBucketSize, minRateLimitTokenBucketSize+1)) + if err := s.LoadAndApplyRateConfig(f); err != nil { + t.Fatalf("LoadAndApplyRateConfig: %v", err) + } + + // Verify server fields. + wantRateConfig := RateConfig{ + PerClientRateLimitBytesPerSec: minRateLimitTokenBucketSize, + PerClientRateBurstBytes: minRateLimitTokenBucketSize + 1, + } + s.mu.Lock() + if !reflect.DeepEqual(s.rateConfig, wantRateConfig) { + t.Errorf("s.rateConfig = %+v; want %+v", s.rateConfig, wantRateConfig) + } + s.mu.Unlock() + + // Verify client limiter. + lim := c.recvLim.Load() + if lim == nil { + t.Fatal("expected non-nil limiter") + } + verifyLimiter(t, lim, wantRateConfig) + }) + + t.Run("burst_is_at_least_minRateLimitTokenBucketSize", func(t *testing.T) { + s := New(key.NewNode(), t.Logf) + defer s.Close() + + f := writeConfig(t, `{"PerClientRateLimitBytesPerSec": 1250000, "PerClientRateBurstBytes": 10}`) + if err := s.LoadAndApplyRateConfig(f); err != nil { + t.Fatalf("LoadAndApplyRateConfig: %v", err) + } + + s.mu.Lock() + gotClientBurst := s.rateConfig.PerClientRateBurstBytes + s.mu.Unlock() + if gotClientBurst != minRateLimitTokenBucketSize { + t.Errorf("client burst = %d; want %d", gotClientBurst, minRateLimitTokenBucketSize) + } + }) + + t.Run("reload_disables_limiting", func(t *testing.T) { + s := New(key.NewNode(), t.Logf) + defer s.Close() + + f := writeConfig(t, `{"PerClientRateLimitBytesPerSec": 1250000, "PerClientRateBurstBytes": 2500000}`) + if err := s.LoadAndApplyRateConfig(f); err != nil { + t.Fatal(err) + } + s.mu.Lock() + if reflect.DeepEqual(s.rateConfig, RateConfig{}) { + t.Error("s.rateConfig is zero val; want nonzero rates") + } + s.mu.Unlock() + + if err := os.WriteFile(f, []byte(`{}`), 0644); err != nil { + t.Fatal(err) + } + if err := s.LoadAndApplyRateConfig(f); err != nil { + t.Fatal(err) + } + + s.mu.Lock() + if !reflect.DeepEqual(s.rateConfig, RateConfig{}) { + t.Errorf("s.rateConfig = %+v; want %+v", s.rateConfig, RateConfig{}) + } + s.mu.Unlock() + }) + + t.Run("propagates_errors", func(t *testing.T) { + s := New(key.NewNode(), t.Logf) + defer s.Close() + + if err := s.LoadAndApplyRateConfig(filepath.Join(t.TempDir(), "nonexistent.json")); err == nil { + t.Fatal("expected error") + } + }) +} + +func TestLookupDestHashTrieFastPath(t *testing.T) { + s := &Server{ + clientsMesh: map[key.NodePublic]PacketForwarder{}, + clock: tstime.StdClock{}, + } + src := pubAll(1) + dst := pubAll(2) + dstClient := &sclient{key: dst} + cs := &clientSet{} + cs.activeClient.Store(dstClient) + s.clients.Store(dst, cs) + + c := &sclient{s: s, key: src} + got, fwd, dstLen := c.lookupDest(dst) + if got != dstClient || fwd != nil || dstLen != 1 { + t.Fatalf("lookupDest = (%v, %v, %d), want (%v, nil, 1)", got, fwd, dstLen, dstClient) + } + + // This must not deadlock while s.mu is held; the hashtrie fast path + // should not acquire Server.mu. + s.mu.Lock() + got, _, _ = c.lookupDest(dst) + s.mu.Unlock() + if got != dstClient { + t.Fatalf("lookupDest got %v, want %v", got, dstClient) + } +} + +func TestLookupDestHashTrieFallsBackForForwarder(t *testing.T) { + s := &Server{ + clientsMesh: map[key.NodePublic]PacketForwarder{}, + clock: tstime.StdClock{}, + } + src := pubAll(1) + dst := pubAll(2) + c := &sclient{s: s, key: src} + + s.clientsMesh[dst] = testFwd(1) + got, fwd, dstLen := c.lookupDest(dst) + if got != nil || fwd != testFwd(1) || dstLen != 0 { + t.Fatalf("lookupDest = (%v, %v, %d), want (nil, testFwd(1), 0)", got, fwd, dstLen) + } +} + +func TestLookupDestHashTrieIgnoresInactiveSet(t *testing.T) { + s := &Server{ + clientsMesh: map[key.NodePublic]PacketForwarder{}, + clock: tstime.StdClock{}, + } + src := pubAll(1) + dst := pubAll(2) + c := &sclient{s: s, key: src} + + // A clientSet with no activeClient (a transient state during + // register/unregister) must not be returned by the fast path. + cs := &clientSet{} + s.clients.Store(dst, cs) + + got, fwd, dstLen := c.lookupDest(dst) + if got != nil || fwd != nil || dstLen != 0 { + t.Fatalf("lookupDest with inactive set = (%v, %v, %d), want (nil, nil, 0)", got, fwd, dstLen) + } + + // Setting activeClient on the same in-map entry must make the next + // fast-path lookup observe it. + newClient := &sclient{key: dst} + cs.activeClient.Store(newClient) + got, fwd, dstLen = c.lookupDest(dst) + if got != newClient || fwd != nil || dstLen != 1 { + t.Fatalf("lookupDest after activation = (%v, %v, %d), want (%v, nil, 1)", got, fwd, dstLen, newClient) + } +} + +func TestLookupDestHashTrieNoAlloc(t *testing.T) { + s := &Server{ + clientsMesh: map[key.NodePublic]PacketForwarder{}, + clock: tstime.StdClock{}, + } + var dstKeys [4]key.NodePublic + var dstClients [4]*sclient + for i := range dstKeys { + dstKeys[i] = pubAll(byte(i + 2)) + dstClients[i] = &sclient{key: dstKeys[i]} + cs := &clientSet{} + cs.activeClient.Store(dstClients[i]) + s.clients.Store(dstKeys[i], cs) + } + c := &sclient{s: s, key: pubAll(1)} + + var i int + var got *sclient + allocs := testing.AllocsPerRun(1000, func() { + idx := i & (len(dstKeys) - 1) + got, _, _ = c.lookupDest(dstKeys[idx]) + i++ + }) + if got == nil { + t.Fatal("lookupDest returned nil") + } + if allocs != 0 { + t.Fatalf("lookupDest allocated %v times per run, want 0", allocs) + } +} + +func BenchmarkLookupDestHashTrie(b *testing.B) { + s := &Server{ + clientsMesh: map[key.NodePublic]PacketForwarder{}, + clock: tstime.StdClock{}, + } + var dstKeys [4]key.NodePublic + var dstClients [4]*sclient + for i := range dstKeys { + dstKeys[i] = pubAll(byte(i + 2)) + dstClients[i] = &sclient{key: dstKeys[i]} + cs := &clientSet{} + cs.activeClient.Store(dstClients[i]) + s.clients.Store(dstKeys[i], cs) + } + + b.ReportAllocs() + b.SetParallelism(32) + b.ResetTimer() + b.RunParallel(func(pb *testing.PB) { + c := &sclient{s: s, key: pubAll(1)} + var i int + for pb.Next() { + idx := i & (len(dstKeys) - 1) + got, fwd, dstLen := c.lookupDest(dstKeys[idx]) + if got != dstClients[idx] || fwd != nil { + b.Fatalf("lookupDest = (%v, %v, %d), want (%v, nil, _)", got, fwd, dstLen, dstClients[idx]) + } + i++ } }) } diff --git a/derp/xdp/xdp_linux_test.go b/derp/xdp/xdp_linux_test.go index cb59721f7..d8de2bf62 100644 --- a/derp/xdp/xdp_linux_test.go +++ b/derp/xdp/xdp_linux_test.go @@ -18,6 +18,7 @@ import ( "gvisor.dev/gvisor/pkg/tcpip/checksum" "gvisor.dev/gvisor/pkg/tcpip/header" "tailscale.com/net/stun" + "tailscale.com/tstest" ) type xdpAction uint32 @@ -271,6 +272,7 @@ func getIPv6STUNBindingResp() []byte { } func TestXDP(t *testing.T) { + tstest.RequireRoot(t) ipv4STUNBindingReqTX := getIPv4STUNBindingReq(nil) ipv6STUNBindingReqTX := getIPv6STUNBindingReq(nil) @@ -957,10 +959,6 @@ func TestXDP(t *testing.T) { server, err := NewSTUNServer(&STUNServerConfig{DeviceName: "fake", DstPort: defaultSTUNPort}, &noAttachOption{}) if err != nil { - if errors.Is(err, unix.EPERM) { - // TODO(jwhited): get this running - t.Skip("skipping due to EPERM error; test requires elevated privileges") - } t.Fatalf("error constructing STUN server: %v", err) } defer server.Close() diff --git a/drive/driveimpl/drive_test.go b/drive/driveimpl/drive_test.go index 8f9b43d6b..185ae2a9c 100644 --- a/drive/driveimpl/drive_test.go +++ b/drive/driveimpl/drive_test.go @@ -239,7 +239,7 @@ func TestLOCK(t *testing.T) { } u := fmt.Sprintf("http://%s/%s/%s/%s/%s", - s.local.l.Addr(), + s.local.ln.Addr(), url.PathEscape(domain), url.PathEscape(remote1), url.PathEscape(share11), @@ -365,7 +365,7 @@ func TestUNLOCK(t *testing.T) { } u := fmt.Sprintf("http://%s/%s/%s/%s/%s", - s.local.l.Addr(), + s.local.ln.Addr(), url.PathEscape(domain), url.PathEscape(remote1), url.PathEscape(share11), @@ -428,12 +428,12 @@ func TestUNLOCK(t *testing.T) { } type local struct { - l net.Listener + ln net.Listener fs *FileSystemForLocal } type remote struct { - l net.Listener + ln net.Listener fs *FileSystemForRemote fileServer *FileServer shares map[string]string @@ -487,7 +487,7 @@ func newSystem(t *testing.T) *system { client.SetTransport(&http.Transport{DisableKeepAlives: true}) s := &system{ t: t, - local: &local{l: ln, fs: fs}, + local: &local{ln: ln, fs: fs}, client: client, remotes: make(map[string]*remote), } @@ -510,7 +510,7 @@ func (s *system) addRemote(name string) string { s.t.Logf("FileServer for %v listening at %s", name, fileServer.Addr()) r := &remote{ - l: ln, + ln: ln, fileServer: fileServer, fs: NewFileSystemForRemote(log.Printf), shares: make(map[string]string), @@ -524,7 +524,7 @@ func (s *system) addRemote(name string) string { for name, r := range s.remotes { remotes = append(remotes, &drive.Remote{ Name: name, - URL: func() string { return fmt.Sprintf("http://%s", r.l.Addr()) }, + URL: func() string { return fmt.Sprintf("http://%s", r.ln.Addr()) }, }) } s.local.fs.SetRemotes( @@ -683,7 +683,7 @@ func (s *system) stop() { s.t.Fatalf("failed to Close fs: %s", err) } - err = s.local.l.Close() + err = s.local.ln.Close() if err != nil { s.t.Fatalf("failed to Close listener: %s", err) } @@ -694,7 +694,7 @@ func (s *system) stop() { s.t.Fatalf("failed to Close remote fs: %s", err) } - err = r.l.Close() + err = r.ln.Close() if err != nil { s.t.Fatalf("failed to Close remote listener: %s", err) } diff --git a/envknob/logknob/logknob.go b/envknob/logknob/logknob.go deleted file mode 100644 index bc6e8c362..000000000 --- a/envknob/logknob/logknob.go +++ /dev/null @@ -1,83 +0,0 @@ -// Copyright (c) Tailscale Inc & contributors -// SPDX-License-Identifier: BSD-3-Clause - -// Package logknob provides a helpful wrapper that allows enabling logging -// based on either an envknob or other methods of enablement. -package logknob - -import ( - "sync/atomic" - - "tailscale.com/envknob" - "tailscale.com/tailcfg" - "tailscale.com/types/logger" -) - -// TODO(andrew-d): should we have a package-global registry of logknobs? It -// would allow us to update from a netmap in a central location, which might be -// reason enough to do it... - -// LogKnob allows configuring verbose logging, with multiple ways to enable. It -// supports enabling logging via envknob, via atomic boolean (for use in e.g. -// c2n log level changes), and via capabilities from a NetMap (so users can -// enable logging via the ACL JSON). -type LogKnob struct { - capName tailcfg.NodeCapability - cap atomic.Bool - env func() bool - manual atomic.Bool -} - -// NewLogKnob creates a new LogKnob, with the provided environment variable -// name and/or NetMap capability. -func NewLogKnob(env string, cap tailcfg.NodeCapability) *LogKnob { - if env == "" && cap == "" { - panic("must provide either an environment variable or capability") - } - - lk := &LogKnob{ - capName: cap, - } - if env != "" { - lk.env = envknob.RegisterBool(env) - } else { - lk.env = func() bool { return false } - } - return lk -} - -// Set will cause logs to be printed when called with Set(true). When called -// with Set(false), logs will not be printed due to an earlier call of -// Set(true), but may be printed due to either the envknob and/or capability of -// this LogKnob. -func (lk *LogKnob) Set(v bool) { - lk.manual.Store(v) -} - -// NetMap is an interface for the parts of netmap.NetworkMap that we care -// about; we use this rather than a concrete type to avoid a circular -// dependency. -type NetMap interface { - HasSelfCapability(tailcfg.NodeCapability) bool -} - -// UpdateFromNetMap will enable logging if the SelfNode in the provided NetMap -// contains the capability provided for this LogKnob. -func (lk *LogKnob) UpdateFromNetMap(nm NetMap) { - if lk.capName == "" { - return - } - lk.cap.Store(nm.HasSelfCapability(lk.capName)) -} - -// Do will call log with the provided format and arguments if any of the -// configured methods for enabling logging are true. -func (lk *LogKnob) Do(log logger.Logf, format string, args ...any) { - if lk.shouldLog() { - log(format, args...) - } -} - -func (lk *LogKnob) shouldLog() bool { - return lk.manual.Load() || lk.env() || lk.cap.Load() -} diff --git a/envknob/logknob/logknob_test.go b/envknob/logknob/logknob_test.go deleted file mode 100644 index 9e7ab8aef..000000000 --- a/envknob/logknob/logknob_test.go +++ /dev/null @@ -1,99 +0,0 @@ -// Copyright (c) Tailscale Inc & contributors -// SPDX-License-Identifier: BSD-3-Clause - -package logknob - -import ( - "bytes" - "fmt" - "testing" - - "tailscale.com/envknob" - "tailscale.com/tailcfg" - "tailscale.com/types/netmap" - "tailscale.com/util/set" -) - -var testKnob = NewLogKnob( - "TS_TEST_LOGKNOB", - "https://tailscale.com/cap/testing", -) - -// Static type assertion for our interface type. -var _ NetMap = &netmap.NetworkMap{} - -func TestLogKnob(t *testing.T) { - t.Run("Default", func(t *testing.T) { - if testKnob.shouldLog() { - t.Errorf("expected default shouldLog()=false") - } - assertNoLogs(t) - }) - t.Run("Manual", func(t *testing.T) { - t.Cleanup(func() { testKnob.Set(false) }) - - assertNoLogs(t) - testKnob.Set(true) - if !testKnob.shouldLog() { - t.Errorf("expected shouldLog()=true") - } - assertLogs(t) - }) - t.Run("Env", func(t *testing.T) { - t.Cleanup(func() { - envknob.Setenv("TS_TEST_LOGKNOB", "") - }) - - assertNoLogs(t) - if testKnob.shouldLog() { - t.Errorf("expected default shouldLog()=false") - } - - envknob.Setenv("TS_TEST_LOGKNOB", "true") - if !testKnob.shouldLog() { - t.Errorf("expected shouldLog()=true") - } - assertLogs(t) - }) - t.Run("NetMap", func(t *testing.T) { - t.Cleanup(func() { testKnob.cap.Store(false) }) - - assertNoLogs(t) - if testKnob.shouldLog() { - t.Errorf("expected default shouldLog()=false") - } - - testKnob.UpdateFromNetMap(&netmap.NetworkMap{ - AllCaps: set.Of(tailcfg.NodeCapability("https://tailscale.com/cap/testing")), - }) - if !testKnob.shouldLog() { - t.Errorf("expected shouldLog()=true") - } - assertLogs(t) - }) -} - -func assertLogs(t *testing.T) { - var buf bytes.Buffer - logf := func(format string, args ...any) { - fmt.Fprintf(&buf, format, args...) - } - - testKnob.Do(logf, "hello %s", "world") - const want = "hello world" - if got := buf.String(); got != want { - t.Errorf("got %q, want %q", got, want) - } -} - -func assertNoLogs(t *testing.T) { - var buf bytes.Buffer - logf := func(format string, args ...any) { - fmt.Fprintf(&buf, format, args...) - } - - testKnob.Do(logf, "hello %s", "world") - if got := buf.String(); got != "" { - t.Errorf("expected no logs, but got: %q", got) - } -} diff --git a/feature/buildfeatures/feature_lazywg_disabled.go b/feature/buildfeatures/feature_lazywg_disabled.go deleted file mode 100644 index af1ad388c..000000000 --- a/feature/buildfeatures/feature_lazywg_disabled.go +++ /dev/null @@ -1,13 +0,0 @@ -// Copyright (c) Tailscale Inc & contributors -// SPDX-License-Identifier: BSD-3-Clause - -// Code generated by gen.go; DO NOT EDIT. - -//go:build ts_omit_lazywg - -package buildfeatures - -// HasLazyWG is whether the binary was built with support for modular feature "Lazy WireGuard configuration for memory-constrained devices with large netmaps". -// Specifically, it's whether the binary was NOT built with the "ts_omit_lazywg" build tag. -// It's a const so it can be used for dead code elimination. -const HasLazyWG = false diff --git a/feature/buildfeatures/feature_lazywg_enabled.go b/feature/buildfeatures/feature_lazywg_enabled.go deleted file mode 100644 index f2d6a10f8..000000000 --- a/feature/buildfeatures/feature_lazywg_enabled.go +++ /dev/null @@ -1,13 +0,0 @@ -// Copyright (c) Tailscale Inc & contributors -// SPDX-License-Identifier: BSD-3-Clause - -// Code generated by gen.go; DO NOT EDIT. - -//go:build !ts_omit_lazywg - -package buildfeatures - -// HasLazyWG is whether the binary was built with support for modular feature "Lazy WireGuard configuration for memory-constrained devices with large netmaps". -// Specifically, it's whether the binary was NOT built with the "ts_omit_lazywg" build tag. -// It's a const so it can be used for dead code elimination. -const HasLazyWG = true diff --git a/feature/buildfeatures/feature_routecheck_disabled.go b/feature/buildfeatures/feature_routecheck_disabled.go new file mode 100644 index 000000000..e728fc91d --- /dev/null +++ b/feature/buildfeatures/feature_routecheck_disabled.go @@ -0,0 +1,13 @@ +// Copyright (c) Tailscale Inc & contributors +// SPDX-License-Identifier: BSD-3-Clause + +// Code generated by gen.go; DO NOT EDIT. + +//go:build ts_omit_routecheck + +package buildfeatures + +// HasRouteCheck is whether the binary was built with support for modular feature "Support checking the reachability of overlapping routers, for choosing between multiple network paths to the same IP address". +// Specifically, it's whether the binary was NOT built with the "ts_omit_routecheck" build tag. +// It's a const so it can be used for dead code elimination. +const HasRouteCheck = false diff --git a/feature/buildfeatures/feature_routecheck_enabled.go b/feature/buildfeatures/feature_routecheck_enabled.go new file mode 100644 index 000000000..34fffb835 --- /dev/null +++ b/feature/buildfeatures/feature_routecheck_enabled.go @@ -0,0 +1,13 @@ +// Copyright (c) Tailscale Inc & contributors +// SPDX-License-Identifier: BSD-3-Clause + +// Code generated by gen.go; DO NOT EDIT. + +//go:build !ts_omit_routecheck + +package buildfeatures + +// HasRouteCheck is whether the binary was built with support for modular feature "Support checking the reachability of overlapping routers, for choosing between multiple network paths to the same IP address". +// Specifically, it's whether the binary was NOT built with the "ts_omit_routecheck" build tag. +// It's a const so it can be used for dead code elimination. +const HasRouteCheck = true diff --git a/feature/clientupdate/clientupdate.go b/feature/clientupdate/clientupdate.go index d47d04815..999dd7920 100644 --- a/feature/clientupdate/clientupdate.go +++ b/feature/clientupdate/clientupdate.go @@ -163,6 +163,7 @@ func (e *extension) DoSelfUpdate() { }) if err != nil { e.pushSelfUpdateProgress(ipnstate.NewUpdateProgress(ipnstate.UpdateFailed, err.Error())) + return } err = up.Update() if err != nil { diff --git a/feature/condlite/expvar/omit.go b/feature/condlite/expvar/omit.go index b5481695c..188de2af2 100644 --- a/feature/condlite/expvar/omit.go +++ b/feature/condlite/expvar/omit.go @@ -3,7 +3,6 @@ //go:build ts_omit_debug && ts_omit_clientmetrics && ts_omit_usermetrics -// excluding the package from builds. package expvar type Int int64 diff --git a/feature/condregister/maybe_routecheck.go b/feature/condregister/maybe_routecheck.go new file mode 100644 index 000000000..2d98c43db --- /dev/null +++ b/feature/condregister/maybe_routecheck.go @@ -0,0 +1,8 @@ +// Copyright (c) Tailscale Inc & contributors +// SPDX-License-Identifier: BSD-3-Clause + +//go:build !ts_omit_routecheck + +package condregister + +import _ "tailscale.com/feature/routecheck" diff --git a/feature/condregister/maybe_tailnetlock.go b/feature/condregister/maybe_tailnetlock.go new file mode 100644 index 000000000..80a3dffe3 --- /dev/null +++ b/feature/condregister/maybe_tailnetlock.go @@ -0,0 +1,8 @@ +// Copyright (c) Tailscale Inc & contributors +// SPDX-License-Identifier: BSD-3-Clause + +//go:build !ts_omit_tailnetlock + +package condregister + +import _ "tailscale.com/feature/tailnetlock" diff --git a/feature/conn25/addrAssignments.go b/feature/conn25/addrAssignments.go new file mode 100644 index 000000000..6d8a87dbb --- /dev/null +++ b/feature/conn25/addrAssignments.go @@ -0,0 +1,133 @@ +// Copyright (c) Tailscale Inc & contributors +// SPDX-License-Identifier: BSD-3-Clause + +package conn25 + +import ( + "container/heap" + "errors" + "net/netip" + "time" + + "tailscale.com/tstime" + "tailscale.com/util/dnsname" + "tailscale.com/util/mak" +) + +// domainDst is a key for looking up an existing address assignment by the +// DNS response domain and destination IP pair. +type domainDst struct { + domain dnsname.FQDN + dst netip.Addr +} + +// addrAssignments is the collection of addrs assigned by this client +// supporting lookup by magic IP, transit IP or domain+dst, or to lookup all +// transit IPs associated with a given connector (identified by its node key). +type addrAssignments struct { + byMagicIP map[netip.Addr]*addrs + byTransitIP map[netip.Addr]*addrs + byDomainDst map[domainDst]*addrs + byExpiresAt addrsHeap + clock tstime.Clock +} + +const defaultExpiry = 48 * time.Hour + +func (a *addrAssignments) insert(as *addrs) error { + return a.insertWithExpiry(as, defaultExpiry) +} + +func (a *addrAssignments) insertWithExpiry(as *addrs, d time.Duration) error { + now := a.clock.Now() + if !as.expiresAt.IsZero() && !as.expiresAt.Before(now) { + return errors.New("expiresAt already set") + } + // we don't expect for addresses to be reused before expiry + if existing, ok := a.byMagicIP[as.magic]; ok { + if !existing.expiresAt.Before(now) { + return errors.New("byMagicIP key exists") + } + } + ddst := domainDst{domain: as.domain, dst: as.dst} + if existing, ok := a.byDomainDst[ddst]; ok { + if !existing.expiresAt.Before(now) { + return errors.New("byDomainDst key exists") + } + } + if existing, ok := a.byTransitIP[as.transit]; ok { + if !existing.expiresAt.Before(now) { + return errors.New("byTransitIP key exists") + } + } + as.expiresAt = now.Add(d) + mak.Set(&a.byMagicIP, as.magic, as) + mak.Set(&a.byTransitIP, as.transit, as) + mak.Set(&a.byDomainDst, ddst, as) + heap.Push(&a.byExpiresAt, as) + return nil +} + +func (a *addrAssignments) lookupByDomainDst(domain dnsname.FQDN, dst netip.Addr) (*addrs, bool) { + v, ok := a.byDomainDst[domainDst{domain: domain, dst: dst}] + if !ok || v.expiresAt.Before(a.clock.Now()) { + return &addrs{}, false + } + return v, true +} + +func (a *addrAssignments) lookupByMagicIP(mip netip.Addr) (*addrs, bool) { + v, ok := a.byMagicIP[mip] + if !ok || v.expiresAt.Before(a.clock.Now()) { + return &addrs{}, false + } + return v, true +} + +func (a *addrAssignments) lookupByTransitIP(tip netip.Addr) (*addrs, bool) { + v, ok := a.byTransitIP[tip] + if !ok || v.expiresAt.Before(a.clock.Now()) { + return &addrs{}, false + } + return v, true +} + +// popExpired returns the member of addrAssignments that expired earliest, +// or an invalid addrs if there are no expired members of addrAssignments. +func (a *addrAssignments) popExpired(now time.Time) *addrs { + if a.byExpiresAt.Len() == 0 { + return &addrs{} + } + if !a.byExpiresAt.peek().expiresAt.Before(now) { + return &addrs{} + } + v := heap.Pop(&a.byExpiresAt).(*addrs) + delete(a.byMagicIP, v.magic) + delete(a.byTransitIP, v.transit) + dd := domainDst{domain: v.domain, dst: v.dst} + delete(a.byDomainDst, dd) + return v +} + +type addrsHeap []*addrs + +func (h addrsHeap) Len() int { return len(h) } +func (h addrsHeap) Less(i, j int) bool { return h[i].expiresAt.Before(h[j].expiresAt) } +func (h addrsHeap) Swap(i, j int) { h[i], h[j] = h[j], h[i] } +func (h *addrsHeap) Push(x any) { + as, ok := x.(*addrs) + if !ok { + panic("unexpected not an addrs") + } + *h = append(*h, as) +} +func (h *addrsHeap) Pop() any { + old := *h + n := len(old) + x := old[n-1] + *h = old[0 : n-1] + return x +} +func (h addrsHeap) peek() *addrs { + return (h)[0] +} diff --git a/feature/conn25/addrAssignments_test.go b/feature/conn25/addrAssignments_test.go new file mode 100644 index 000000000..7dd984453 --- /dev/null +++ b/feature/conn25/addrAssignments_test.go @@ -0,0 +1,149 @@ +// Copyright (c) Tailscale Inc & contributors +// SPDX-License-Identifier: BSD-3-Clause + +package conn25 + +import ( + "fmt" + "net/netip" + "testing" + "time" + + "github.com/google/go-cmp/cmp" + "github.com/google/go-cmp/cmp/cmpopts" + "tailscale.com/tstest" +) + +func TestAssignmentsExpire(t *testing.T) { + clock := tstest.NewClock(tstest.ClockOpts{Start: time.Now()}) + assignments := addrAssignments{clock: clock} + as := &addrs{ + dst: netip.MustParseAddr("0.0.0.1"), + magic: netip.MustParseAddr("0.0.0.2"), + transit: netip.MustParseAddr("0.0.0.3"), + app: "a", + domain: "example.com.", + } + err := assignments.insert(as) + if err != nil { + t.Fatal(err) + } + // Time has not passed since the insert, the assignment should be returned. + foundAs, ok := assignments.lookupByMagicIP(as.magic) + if !ok { + t.Fatal("expected to find") + } + if foundAs.dst != as.dst { + t.Fatalf("want %v; got %v", as.dst, foundAs.dst) + } + // and we cannot insert over the addresses + err = assignments.insert(as) + if err == nil { + t.Fatal("expected an error but got nil") + } + // After a time greater than the default expiry passes, the assignment should + // not be returned. + clock.Advance(defaultExpiry * 2) + foundAsAfter, okAfter := assignments.lookupByMagicIP(as.magic) + if okAfter { + t.Fatal("expected not to find (expired)") + } + if foundAsAfter.isValid() { + t.Fatal("expected zero val") + } + // Now we can reuse the addresses + err = assignments.insert(as) + if err != nil { + t.Fatal(err) + } + foundAs, ok = assignments.lookupByMagicIP(as.magic) + if !ok { + t.Fatal("expected to find") + } + if foundAs.dst != as.dst { + t.Fatalf("want %v; got %v", as.dst, foundAs.dst) + } + if !foundAs.expiresAt.After(clock.Now()) { + t.Fatalf("expected foundAs to expire after now") + } +} + +func TestPopExpired(t *testing.T) { + clock := tstest.NewClock(tstest.ClockOpts{Start: time.Now()}) + assignments := addrAssignments{clock: clock} + makeAndAddAddrs := func(n int) *addrs { + t.Helper() + as := &addrs{ + dst: netip.MustParseAddr(fmt.Sprintf("0.0.1.%d", n)), + magic: netip.MustParseAddr(fmt.Sprintf("0.0.2.%d", n)), + transit: netip.MustParseAddr(fmt.Sprintf("0.0.3.%d", n)), + app: "a", + domain: "example.com.", + } + err := assignments.insert(as) + if err != nil { + t.Fatal(err) + } + return as + } + // cmp.Diff addrs ignoring expiresAt + doDiff := func(want, got *addrs) string { + t.Helper() + return cmp.Diff( + want, + got, + cmp.AllowUnexported(addrs{}), + cmpopts.EquateComparable(netip.Addr{}), + cmpopts.IgnoreFields(addrs{}, "expiresAt"), + ) + } + testAddrs := []*addrs{} + for i := range 2 { + testAddrs = append(testAddrs, makeAndAddAddrs(i+1)) + clock.Advance(1 * time.Second) + } + if len(assignments.byMagicIP) != 2 { + t.Fatalf("test setup wrong") + } + + nn := assignments.popExpired(clock.Now()) + want := &addrs{} // invalid addr + if diff := doDiff(want, nn); diff != "" { + t.Fatalf("only expired addresses are removed: %s", diff) + } + if len(assignments.byMagicIP) != 2 { + t.Fatalf("nothing should have been removed") + } + if nn.isValid() { + t.Fatal("empty addrs should be invalid") + } + + clock.Advance(2 * defaultExpiry) // all addrs are now expired + + want = testAddrs[0] + nn = assignments.popExpired(clock.Now()) + if diff := doDiff(want, nn); diff != "" { + t.Fatal(diff) + } + if len(assignments.byMagicIP) != 1 { + t.Fatalf("an assignment should have been removed") + } + + want = testAddrs[1] + nn = assignments.popExpired(clock.Now()) + if diff := doDiff(want, nn); diff != "" { + t.Fatal(diff) + } + if len(assignments.byMagicIP) != 0 { + t.Fatalf("an assignment should have been removed") + } + + want = &addrs{} + nn = assignments.popExpired(clock.Now()) + if diff := doDiff(want, nn); diff != "" { + t.Fatal(diff) + } + if len(assignments.byMagicIP) != 0 { + t.Fatalf("there should have been no change") + } +} diff --git a/feature/conn25/conn25.go b/feature/conn25/conn25.go index e716c09d0..cc1b38a71 100644 --- a/feature/conn25/conn25.go +++ b/feature/conn25/conn25.go @@ -19,18 +19,22 @@ import ( "slices" "strings" "sync" + "sync/atomic" + "time" "go4.org/netipx" "golang.org/x/net/dns/dnsmessage" "tailscale.com/appc" "tailscale.com/envknob" "tailscale.com/feature" + "tailscale.com/ipn" "tailscale.com/ipn/ipnext" "tailscale.com/ipn/ipnlocal" "tailscale.com/net/packet" "tailscale.com/net/tsaddr" "tailscale.com/net/tstun" "tailscale.com/tailcfg" + "tailscale.com/tstime" "tailscale.com/types/appctype" "tailscale.com/types/key" "tailscale.com/types/logger" @@ -120,10 +124,12 @@ func (e *extension) Init(host ipnext.Host) error { } e.host = host - dph := newDatapathHandler(e.conn25, e.conn25.client.logf) + dph := newDatapathHandler(e.conn25, e.conn25.logf) if err := e.installHooks(dph); err != nil { return err } + profile, prefs := e.host.Profiles().CurrentProfileState() + e.profileStateChange(profile, prefs, false) ctx, cancel := context.WithCancelCause(context.Background()) e.ctxCancel = cancel @@ -167,10 +173,13 @@ func (e *extension) installHooks(dph *datapathHandler) error { } // Manage how we react to changes to the current node, - // including property changes (e.g. HostInfo, Capabilities, CapMap) - // and profile switches. + // including property changes (e.g. HostInfo, Capabilities, CapMap). e.host.Hooks().OnSelfChange.Add(e.onSelfChange) + // Manage how we react profile state changes, which include + // prefs changes. + e.host.Hooks().ProfileStateChange.Add(e.profileStateChange) + // Allow the client to send packets with Transit IP destinations // in the link-local space. e.host.Hooks().Filter.LinkLocalAllowHooks.Add(func(p packet.Parsed) (bool, string) { @@ -219,17 +228,40 @@ func (e *extension) installHooks(dph *datapathHandler) error { // ClientTransitIPForMagicIP implements [IPMapper]. func (c *Conn25) ClientTransitIPForMagicIP(m netip.Addr) (netip.Addr, error) { - return c.client.transitIPForMagicIP(m) + if addr, ok := c.client.transitIPForMagicIP(m); ok { + return addr, nil + } + cfg, ok := c.getConfig() + if !ok { + return netip.Addr{}, nil + } + if !cfg.ipSets.v4Magic.Contains(m) && !cfg.ipSets.v6Magic.Contains(m) { + return netip.Addr{}, nil + } + return netip.Addr{}, ErrUnmappedMagicIP } // ConnectorRealIPForTransitIPConnection implements [IPMapper]. func (c *Conn25) ConnectorRealIPForTransitIPConnection(src, transit netip.Addr) (netip.Addr, error) { - return c.connector.realIPForTransitIPConnection(src, transit) + if addr, ok := c.connector.realIPForTransitIPConnection(src, transit); ok { + return addr, nil + } + cfg, ok := c.getConfig() + if !ok { + return netip.Addr{}, nil + } + if !cfg.ipSets.v4Transit.Contains(transit) && !cfg.ipSets.v6Transit.Contains(transit) { + return netip.Addr{}, nil + } + return netip.Addr{}, ErrUnmappedSrcAndTransitIP } func (e *extension) getMagicRange() views.Slice[netip.Prefix] { - cfg := e.conn25.client.getConfig() - return views.SliceOf(slices.Concat(cfg.v4MagicIPSet.Prefixes(), cfg.v6MagicIPSet.Prefixes())) + cfg, ok := e.conn25.getConfig() + if !ok { + return views.Slice[netip.Prefix]{} + } + return views.SliceOf(slices.Concat(cfg.ipSets.v4Magic.Prefixes(), cfg.ipSets.v6Magic.Prefixes())) } // Shutdown implements [ipnlocal.Extension]. @@ -264,12 +296,21 @@ func (e *extension) handleConnectorTransitIP(h ipnlocal.PeerAPIHandler, w http.R w.Write(bs) } +// onSelfChange implements the [ipnext.Hooks.OnSelfChange] hook. func (e *extension) onSelfChange(selfNode tailcfg.NodeView) { - err := e.conn25.reconfig(selfNode) + cfg, err := configFromNodeView(selfNode) if err != nil { - e.conn25.client.logf("error during Reconfig onSelfChange: %v", err) + e.conn25.logf("error generating config from self node view: %v", err) return } + e.conn25.reconfig(cfg) +} + +// profileStateChange implements the [ipnext.Hooks.ProfileStateChange] hook. +func (e *extension) profileStateChange(loginProfile ipn.LoginProfileView, prefs ipn.PrefsView, sameNode bool) { + // TODO(mzb): Handle node changes. Wipe out all config? + // We'll need to look at the ordering of this hook and onSelfChange. + e.conn25.prefsAdvertiseConnector.Store(prefs.AppConnector().Advertise) } func (e *extension) extraWireGuardAllowedIPs(k key.NodePublic) views.Slice[netip.Prefix] { @@ -283,22 +324,41 @@ type appAddr struct { // Conn25 holds state for routing traffic for a domain via a connector. type Conn25 struct { - client *client - connector *connector + config atomic.Pointer[config] + prefsAdvertiseConnector atomic.Bool + logf logger.Logf + client *client + connector *connector +} + +func (c *Conn25) getConfig() (*config, bool) { + cfg := c.config.Load() + return cfg, cfg.isConfigured } func (c *Conn25) isConfigured() bool { - return c.client.isConfigured() + _, ok := c.getConfig() + return ok } func newConn25(logf logger.Logf) *Conn25 { c := &Conn25{ - client: &client{ - logf: logf, - addrsCh: make(chan addrs, 64), - }, + logf: logf, connector: &connector{logf: logf}, } + c.config.Store(&config{}) // initialize with empty to avoid nil checks + c.client = &client{ + logf: logf, + addrsCh: make(chan addrs, 64), + assignments: addrAssignments{clock: tstime.StdClock{}}, + getIPSets: func() ipSets { + cfg, ok := c.getConfig() + if !ok { + return emptyIPSets() + } + return cfg.ipSets + }, + } return c } @@ -310,24 +370,9 @@ func ipSetFromIPRanges(rs []netipx.IPRange) (*netipx.IPSet, error) { return b.IPSet() } -func (c *Conn25) reconfig(selfNode tailcfg.NodeView) error { - cfg, err := configFromNodeView(selfNode) - if err != nil { - return err - } - if err := c.client.reconfig(cfg); err != nil { - return err - } - if err := c.connector.reconfig(cfg); err != nil { - return err - } - return nil -} - -// mapDNSResponse parses and inspects the DNS response, and uses the -// contents to assign addresses for connecting. -func (c *Conn25) mapDNSResponse(buf []byte) []byte { - return c.client.mapDNSResponse(buf) +func (c *Conn25) reconfig(cfg *config) { + c.config.Store(cfg) + c.client.reconfig() } const dupeTransitIPMessage = "Duplicate transit address in ConnectorTransitIPRequest" @@ -341,6 +386,21 @@ const unknownAppNameMessage = "The App name in the request does not match a conf // family of the transitIP). If a peer has stored this mapping in the connector, // Conn25 will route traffic to TransitIPs to DestinationIPs for that peer. func (c *Conn25) handleConnectorTransitIPRequest(n tailcfg.NodeView, ctipr ConnectorTransitIPRequest) ConnectorTransitIPResponse { + resp := ConnectorTransitIPResponse{} + cfg, ok := c.getConfig() + if !ok { + // TODO(mzb): If this node is no longer configured at the + // the time of this call, perhaps there should be a top-level + // error, instead of error-per-TransitIP? + for range ctipr.TransitIPs { + resp.TransitIPs = append(resp.TransitIPs, TransitIPResponse{ + Code: UnknownAppName, + Message: unknownAppNameMessage, + }) + } + return resp + } + var peerIPv4, peerIPv6 netip.Addr for _, ip := range n.Addresses().All() { if !ip.IsSingleIP() || !tsaddr.IsTailscaleIP(ip.Addr()) { @@ -353,7 +413,6 @@ func (c *Conn25) handleConnectorTransitIPRequest(n tailcfg.NodeView, ctipr Conne } } - resp := ConnectorTransitIPResponse{} seen := map[netip.Addr]bool{} for _, each := range ctipr.TransitIPs { if seen[each.TransitIP] { @@ -361,10 +420,20 @@ func (c *Conn25) handleConnectorTransitIPRequest(n tailcfg.NodeView, ctipr Conne Code: DuplicateTransitIP, Message: dupeTransitIPMessage, }) - c.connector.logf("[Unexpected] peer attempt to map a transit IP reused a transitIP: node: %s, IP: %v", + c.logf("[Unexpected] peer attempt to map a transit IP reused a transitIP: node: %s, IP: %v", n.StableID(), each.TransitIP) continue } + + if _, ok := cfg.appsByName[each.App]; !ok { + resp.TransitIPs = append(resp.TransitIPs, TransitIPResponse{ + Code: UnknownAppName, + Message: unknownAppNameMessage, + }) + c.logf("[Unexpected] peer attempt to map a transit IP referenced unknown app: node: %s, app: %q", + n.StableID(), each.App) + continue + } tipresp := c.connector.handleTransitIPRequest(n, peerIPv4, peerIPv6, each) seen[each.TransitIP] = true resp.TransitIPs = append(resp.TransitIPs, tipresp) @@ -397,12 +466,6 @@ func (c *connector) handleTransitIPRequest(n tailcfg.NodeView, peerV4 netip.Addr c.mu.Lock() defer c.mu.Unlock() - if _, ok := c.config.appsByName[tipr.App]; !ok { - c.logf("[Unexpected] peer attempt to map a transit IP referenced unknown app: node: %s, app: %q", - n.StableID(), tipr.App) - return TransitIPResponse{Code: UnknownAppName, Message: unknownAppNameMessage} - } - if c.transitIPs == nil { c.transitIPs = make(map[netip.Addr]map[netip.Addr]appAddr) } @@ -485,74 +548,105 @@ type ConnectorTransitIPResponse struct { const AppConnectorsExperimentalAttrName = "tailscale.com/app-connectors-experimental" -// config holds the config from the policy and lookups derived from that. -// config is not safe for concurrent use. -type config struct { - isConfigured bool - apps []appctype.Conn25Attr - appsByName map[string]appctype.Conn25Attr - appNamesByDomain map[dnsname.FQDN][]string - selfRoutedDomains set.Set[dnsname.FQDN] - v4TransitIPSet netipx.IPSet - v4MagicIPSet netipx.IPSet - v6TransitIPSet netipx.IPSet - v6MagicIPSet netipx.IPSet +// ipSets wraps all the IPSets the config needs. +type ipSets struct { + v4Transit *netipx.IPSet + v4Magic *netipx.IPSet + v6Transit *netipx.IPSet + v6Magic *netipx.IPSet } -func configFromNodeView(n tailcfg.NodeView) (config, error) { +func emptyIPSets() ipSets { + return ipSets{ + v4Transit: &netipx.IPSet{}, + v4Magic: &netipx.IPSet{}, + v6Transit: &netipx.IPSet{}, + v6Magic: &netipx.IPSet{}, + } +} + +// config holds the config derived from the self node view, +// which includes the policy. +// config is not safe for concurrent use. +type config struct { + isConfigured bool + apps []appctype.Conn25Attr + appsByName map[string]appctype.Conn25Attr + appNamesByDomain map[dnsname.FQDN][]string + appNamesByWCDomain map[dnsname.FQDN][]string + selfAppNames set.Set[string] + ipSets ipSets +} + +func configFromNodeView(n tailcfg.NodeView) (*config, error) { apps, err := tailcfg.UnmarshalNodeCapViewJSON[appctype.Conn25Attr](n.CapMap(), AppConnectorsExperimentalAttrName) if err != nil { - return config{}, err + return &config{}, err } if len(apps) == 0 { - return config{}, nil + return &config{}, nil } selfTags := set.SetOf(n.Tags().AsSlice()) - cfg := config{ - isConfigured: true, - apps: apps, - appsByName: map[string]appctype.Conn25Attr{}, - appNamesByDomain: map[dnsname.FQDN][]string{}, - selfRoutedDomains: set.Set[dnsname.FQDN]{}, + cfg := &config{ + isConfigured: true, + apps: apps, + appsByName: map[string]appctype.Conn25Attr{}, + appNamesByDomain: map[dnsname.FQDN][]string{}, + appNamesByWCDomain: map[dnsname.FQDN][]string{}, + selfAppNames: set.Set[string]{}, + ipSets: emptyIPSets(), } for _, app := range apps { - selfMatchesTags := slices.ContainsFunc(app.Connectors, selfTags.Contains) + normalizedDomains := set.Set[dnsname.FQDN]{} + normalizedWCDomains := set.Set[dnsname.FQDN]{} for _, d := range app.Domains { - fqdn, err := normalizeDNSName(d) + domain, isWild := strings.CutPrefix(d, "*.") + fqdn, err := normalizeDNSName(domain) if err != nil { - return config{}, err + return &config{}, err } - mak.Set(&cfg.appNamesByDomain, fqdn, append(cfg.appNamesByDomain[fqdn], app.Name)) - if selfMatchesTags { - cfg.selfRoutedDomains.Add(fqdn) + if isWild && !normalizedWCDomains.Contains(fqdn) { + normalizedWCDomains.Add(fqdn) + mak.Set(&cfg.appNamesByWCDomain, fqdn, append(cfg.appNamesByWCDomain[fqdn], app.Name)) + } else if !isWild && !normalizedDomains.Contains(fqdn) { + normalizedDomains.Add(fqdn) + mak.Set(&cfg.appNamesByDomain, fqdn, append(cfg.appNamesByDomain[fqdn], app.Name)) } } mak.Set(&cfg.appsByName, app.Name, app) + if slices.ContainsFunc(app.Connectors, selfTags.Contains) { + cfg.selfAppNames.Add(app.Name) + } + } + // TODO(fran) 2026-03-18 we don't yet have a proper way to communicate the // global IP pool config. For now just take it from the first app. if len(apps) != 0 { app := apps[0] v4Mipp, err := ipSetFromIPRanges(app.V4MagicIPPool) if err != nil { - return config{}, err + return &config{}, err } v4Tipp, err := ipSetFromIPRanges(app.V4TransitIPPool) if err != nil { - return config{}, err + return &config{}, err } v6Mipp, err := ipSetFromIPRanges(app.V6MagicIPPool) if err != nil { - return config{}, err + return &config{}, err } v6Tipp, err := ipSetFromIPRanges(app.V6TransitIPPool) if err != nil { - return config{}, err + return &config{}, err } - cfg.v4MagicIPSet = *v4Mipp - cfg.v4TransitIPSet = *v4Tipp - cfg.v6MagicIPSet = *v6Mipp - cfg.v6TransitIPSet = *v6Tipp + ipSets := ipSets{ + v4Magic: v4Mipp, + v4Transit: v4Tipp, + v6Magic: v6Mipp, + v6Transit: v6Tipp, + } + cfg.ipSets = ipSets } return cfg, nil } @@ -562,8 +656,9 @@ func configFromNodeView(n tailcfg.NodeView) (config, error) { // connectors. // It's safe for concurrent use. type client struct { - logf logger.Logf - addrsCh chan addrs + logf logger.Logf + addrsCh chan addrs + getIPSets func() ipSets mu sync.Mutex // protects the fields below v4MagicIPPool *ippool @@ -571,28 +666,19 @@ type client struct { v6MagicIPPool *ippool v6TransitIPPool *ippool assignments addrAssignments - config config -} - -func (c *client) getConfig() config { - c.mu.Lock() - defer c.mu.Unlock() - return c.config + byConnKey map[key.NodePublic]set.Set[netip.Prefix] } // transitIPForMagicIP is part of the implementation of the IPMapper interface for dataflows lookups. // See also [IPMapper.ClientTransitIPForMagicIP]. -func (c *client) transitIPForMagicIP(magicIP netip.Addr) (netip.Addr, error) { +func (c *client) transitIPForMagicIP(magicIP netip.Addr) (netip.Addr, bool) { c.mu.Lock() defer c.mu.Unlock() v, ok := c.assignments.lookupByMagicIP(magicIP) if ok { - return v.transit, nil + return v.transit, true } - if !c.config.v4MagicIPSet.Contains(magicIP) && !c.config.v6MagicIPSet.Contains(magicIP) { - return netip.Addr{}, nil - } - return netip.Addr{}, ErrUnmappedMagicIP + return netip.Addr{}, false } // linkLocalAllow returns true if the provided packet with a link-local Dst address has a @@ -615,88 +701,121 @@ func (c *client) isKnownTransitIP(tip netip.Addr) bool { return ok } -func (c *client) isConfigured() bool { +func (c *client) reconfig() { c.mu.Lock() defer c.mu.Unlock() - return c.config.isConfigured + + ipSets := c.getIPSets() + + c.v4MagicIPPool = c.v4MagicIPPool.reconfig(ipSets.v4Magic) + c.v4TransitIPPool = c.v4TransitIPPool.reconfig(ipSets.v4Transit) + c.v6MagicIPPool = c.v6MagicIPPool.reconfig(ipSets.v6Magic) + c.v6TransitIPPool = c.v6TransitIPPool.reconfig(ipSets.v6Transit) } -func (c *client) reconfig(newCfg config) error { - c.mu.Lock() - defer c.mu.Unlock() +// getAppsForConnectorDomain returns the slice of app names which match the +// provided domain. Apps which match the domain exactly are preferred, +// otherwise the list of apps comes from the wildcard domain which matches +// the longest suffix of the specified domain. A nil or empty slice is returned +// if no match is found or if the list of matching apps would contain an app +// which is being handled by the self-node's connector. +func (cfg *config) getAppsForConnectorDomain(domain dnsname.FQDN, prefsAdvertiseConnector bool) []string { + // Lookup exact matches first + appNames := cfg.appNamesByDomain[domain] + if len(appNames) == 0 { + // No exact match, check wildcard domains + // We have made the decision that wildcards will match the base domain. + // So example.com will be a match for *.example.com, because we think that + // this is most likely what users will expect. + for d := domain; d != ""; d = d.Parent() { + if appNames = cfg.appNamesByWCDomain[d]; len(appNames) > 0 { + break + } + } + } - c.config = newCfg - - c.v4MagicIPPool = newIPPool(&(newCfg.v4MagicIPSet)) - c.v4TransitIPPool = newIPPool(&(newCfg.v4TransitIPSet)) - c.v6MagicIPPool = newIPPool(&(newCfg.v6MagicIPSet)) - c.v6TransitIPPool = newIPPool(&(newCfg.v6TransitIPSet)) - return nil -} - -func (c *client) isConnectorDomain(domain dnsname.FQDN) bool { - c.mu.Lock() - defer c.mu.Unlock() - appNames, ok := c.config.appNamesByDomain[domain] - return ok && len(appNames) > 0 + // If we have a candidate match, make sure that no candidate app is pointing + // at a connector on the self-node. + if len(appNames) == 0 || (prefsAdvertiseConnector && slices.ContainsFunc(appNames, cfg.selfAppNames.Contains)) { + return nil + } + return appNames } // reserveAddresses tries to make an assignment of addrs from the address pools // for this domain+dst address, so that this client can use conn25 connectors. +// The name of the matching app is also provided, no validation is done to check whether or not +// the app name refers to a configured app. // It checks that this domain should be routed and that this client is not itself a connector for the domain // and generally if it is valid to make the assignment. -func (c *client) reserveAddresses(domain dnsname.FQDN, dst netip.Addr) (addrs, error) { +func (c *client) reserveAddresses(appName string, domain dnsname.FQDN, dst netip.Addr) (*addrs, error) { if !dst.IsValid() { - return addrs{}, errors.New("dst is not valid") + return nil, errors.New("dst is not valid") } c.mu.Lock() defer c.mu.Unlock() if existing, ok := c.assignments.lookupByDomainDst(domain, dst); ok { return existing, nil } - appNames, _ := c.config.appNamesByDomain[domain] - if len(appNames) == 0 { - return addrs{}, fmt.Errorf("no app names found for domain %q", domain) + + // Before we check out more addresses from the pools try to return some. + // Trying to return any number greater than 1 will cause the number of + // addresses used to trend down in general. But as we have 2 different + // pools for the different IP versions, use a number a bit higher than + // 2 to try and process bursty behavior faster. + now := c.assignments.clock.Now() + for range 10 { + a := c.assignments.popExpired(now) + if !a.isValid() { + break + } + if a.is4() { + c.v4MagicIPPool.returnAddr(a.magic) + c.v4TransitIPPool.returnAddr(a.transit) + } else if a.is6() { + c.v6MagicIPPool.returnAddr(a.magic) + c.v6TransitIPPool.returnAddr(a.transit) + } else { + return nil, errors.New("unexpected neither 4 nor 6") + } } - // only reserve for first app - app := appNames[0] var mip, tip netip.Addr var err error if dst.Is4() { mip, err = c.v4MagicIPPool.next() if err != nil { - return addrs{}, err + return nil, err } tip, err = c.v4TransitIPPool.next() if err != nil { - return addrs{}, err + return nil, err } } else if dst.Is6() { mip, err = c.v6MagicIPPool.next() if err != nil { - return addrs{}, err + return nil, err } tip, err = c.v6TransitIPPool.next() if err != nil { - return addrs{}, err + return nil, err } } else { - return addrs{}, errors.New("unexpected neither 4 nor 6") + return nil, errors.New("unexpected neither 4 nor 6") } - as := addrs{ + as := &addrs{ dst: dst, magic: mip, transit: tip, - app: app, + app: appName, domain: domain, } if err := c.assignments.insert(as); err != nil { - return addrs{}, err + return nil, err } err = c.enqueueAddressAssignment(as) if err != nil { - return addrs{}, err + return nil, err } return as, nil } @@ -708,7 +827,7 @@ func (c *client) addTransitIPForConnector(tip netip.Addr, conn tailcfg.NodeView) c.mu.Lock() defer c.mu.Unlock() - return c.assignments.insertTransitConnMapping(tip, conn.Key()) + return c.insertTransitConnMapping(tip, conn.Key()) } func (e *extension) sendLoop(ctx context.Context) { @@ -718,7 +837,7 @@ func (e *extension) sendLoop(ctx context.Context) { return case as := <-e.conn25.client.addrsCh: if err := e.handleAddressAssignment(ctx, as); err != nil { - e.conn25.client.logf("error handling transit IP assignment (app: %s, mip: %v, src: %v): %v", as.app, as.magic, as.dst, err) + e.conn25.logf("error handling transit IP assignment (app: %s, mip: %v, src: %v): %v", as.app, as.magic, as.dst, err) } } } @@ -738,11 +857,11 @@ func (e *extension) handleAddressAssignment(ctx context.Context, as addrs) error return nil } -func (c *client) enqueueAddressAssignment(addrs addrs) error { +func (c *client) enqueueAddressAssignment(addrs *addrs) error { select { // TODO(fran) investigate the value of waiting for multiple addresses and sending them // in one ConnectorTransitIPRequest - case c.addrsCh <- addrs: + case c.addrsCh <- *addrs: return nil default: c.logf("address assignment queue full, dropping transit assignment for %v", addrs.domain) @@ -753,7 +872,7 @@ func (c *client) enqueueAddressAssignment(addrs addrs) error { func (c *client) extraWireGuardAllowedIPs(k key.NodePublic) views.Slice[netip.Prefix] { c.mu.Lock() defer c.mu.Unlock() - tips, ok := c.assignments.lookupTransitIPsByConnKey(k) + tips, ok := c.lookupTransitIPsByConnKey(k) if !ok { return views.Slice[netip.Prefix]{} } @@ -802,9 +921,13 @@ func makePeerAPIReq(ctx context.Context, httpClient *http.Client, urlBase string } func (e *extension) sendAddressAssignment(ctx context.Context, as addrs) (tailcfg.NodeView, error) { - app, ok := e.conn25.client.getConfig().appsByName[as.app] + cfg, ok := e.conn25.getConfig() if !ok { - e.conn25.client.logf("App not found for app: %s (domain: %s)", as.app, as.domain) + return tailcfg.NodeView{}, errors.New("not configured") + } + app, ok := cfg.appsByName[as.app] + if !ok { + e.conn25.logf("App not found for app: %s (domain: %s)", as.app, as.domain) return tailcfg.NodeView{}, errors.New("app not found") } @@ -856,7 +979,10 @@ func makeServFail(logf logger.Logf, h dnsmessage.Header, q dnsmessage.Question) return bs } -func (c *client) mapDNSResponse(buf []byte) []byte { +// mapDNSResponse parses and inspects the DNS response. If the domain +// is determined to belong to app this node is client for, it assigns addresses +// for connecting and rewrites the response to contain Magic IPs. +func (c *Conn25) mapDNSResponse(buf []byte) []byte { var p dnsmessage.Parser hdr, err := p.Start(buf) if err != nil { @@ -881,19 +1007,30 @@ func (c *client) mapDNSResponse(buf []byte) []byte { if err != nil { return buf } - if !c.isConnectorDomain(queriedDomain) { + + cfg, ok := c.getConfig() + if !ok { return buf } + appNames := cfg.getAppsForConnectorDomain(queriedDomain, c.prefsAdvertiseConnector.Load()) + if len(appNames) == 0 { + return buf + } + + // There is guaranteed to be at least one matching app, so just take the first one for now + appName := appNames[0] + // Now we know this is a dns response we think we should rewrite, we're going to provide our response which // currently means we will: // * write the questions through as they are // * not send through the additional section // * provide our answers, or no answers if we don't handle those answers (possibly in the future we should write through answers for eg TypeTXT) var answers []dnsResponseRewrite + var cnameChain map[dnsname.FQDN]dnsname.FQDN if question.Type != dnsmessage.TypeA && question.Type != dnsmessage.TypeAAAA { c.logf("mapping dns response for connector domain, unsupported type: %v", question.Type) - newBuf, err := c.rewriteDNSResponse(hdr, questions, answers) + newBuf, err := c.client.rewriteDNSResponse(appName, hdr, questions, answers) if err != nil { c.logf("error writing empty response for unsupported type: %v", err) return makeServFail(c.logf, hdr, question) @@ -920,15 +1057,32 @@ func (c *client) mapDNSResponse(buf []byte) []byte { } switch h.Type { case dnsmessage.TypeCNAME: - // An A record was asked for, and the answer is a CNAME, this answer will tell us which domain it's a CNAME for - // and a subsequent answer should tell us what the target domains address is (or possibly another CNAME). Drop - // this for now (2026-03-11) but in the near future we should collapse the CNAME chain and map to the ultimate - // destination address (see eg appc/{appconnector,observe}.go). - c.logf("not yet implemented CNAME answer: %v", queriedDomain) - if err := p.SkipAnswer(); err != nil { + // A DNS response with CNAME records might look a bit like + // + // a.example.com. CNAME b.example.com. + // b.example.com. CNAME example.com. + // example.com. A 1.1.1.1 + // + // We don't return CNAME records for our domains. We use them to build a + // cname chain so we can rewrite the final A/AAAA record to eg: + // + // a.example.com A (some magic IP that is associated with 1.1.1.1) + r, err := p.CNAMEResource() + if err != nil { c.logf("error parsing dns response: %v", err) return makeServFail(c.logf, hdr, question) } + src, err := normalizeDNSName(h.Name.String()) + if err != nil { + c.logf("bad dnsname: %v", err) + return makeServFail(c.logf, hdr, question) + } + target, err := normalizeDNSName(r.CNAME.String()) + if err != nil { + c.logf("bad dnsname: %v", err) + return makeServFail(c.logf, hdr, question) + } + mak.Set(&cnameChain, src, target) case dnsmessage.TypeA, dnsmessage.TypeAAAA: if h.Type != question.Type { // would not expect a v4 response to a v6 question or vice versa, don't add a rewrite for this. @@ -938,19 +1092,40 @@ func (c *client) mapDNSResponse(buf []byte) []byte { } continue } - domain, err := normalizeDNSName(h.Name.String()) + answerDomain, err := normalizeDNSName(h.Name.String()) if err != nil { c.logf("bad dnsname: %v", err) return makeServFail(c.logf, hdr, question) } - // answers should be for the domain that was queried - if domain != queriedDomain { - c.logf("unexpected domain for connector domain dns response: %v %v", queriedDomain, domain) - if err := p.SkipAnswer(); err != nil { - c.logf("error parsing dns response: %v", err) - return makeServFail(c.logf, hdr, question) + // If answerDomain is not the same domain as the domain that was queried for, + // try to walk down the cname chain from the queried domain until we find the answerDomain. + // If we can't, skip the answer. + // If we can, then we will rewrite the dns response to an A/AAAA record pointing + // the queriedDomain to the magic IP. + if answerDomain != queriedDomain { + d := queriedDomain + found := false + seen := set.Set[dnsname.FQDN]{} // avoid following cname record loops + for { + target, ok := cnameChain[d] + if !ok || seen.Contains(target) { + break + } + if target == answerDomain { + found = true + break + } + seen.Add(target) + d = target + } + if !found { + c.logf("unexpected domain for connector domain dns response: %v %v", queriedDomain, answerDomain) + if err := p.SkipAnswer(); err != nil { + c.logf("error parsing dns response: %v", err) + return makeServFail(c.logf, hdr, question) + } + continue } - continue } var dstAddr netip.Addr if h.Type == dnsmessage.TypeA { @@ -968,7 +1143,7 @@ func (c *client) mapDNSResponse(buf []byte) []byte { } dstAddr = netip.AddrFrom16(r.AAAA) } - answers = append(answers, dnsResponseRewrite{domain: domain, dst: dstAddr}) + answers = append(answers, dnsResponseRewrite{domain: queriedDomain, dst: dstAddr}) default: // we already checked the question was for a supported type, this answer is unexpected c.logf("unexpected type for connector domain dns response: %v %v", queriedDomain, h.Type) @@ -978,7 +1153,7 @@ func (c *client) mapDNSResponse(buf []byte) []byte { } } } - newBuf, err := c.rewriteDNSResponse(hdr, questions, answers) + newBuf, err := c.client.rewriteDNSResponse(appName, hdr, questions, answers) if err != nil { c.logf("error rewriting dns response: %v", err) return makeServFail(c.logf, hdr, question) @@ -986,7 +1161,7 @@ func (c *client) mapDNSResponse(buf []byte) []byte { return newBuf } -func (c *client) rewriteDNSResponse(hdr dnsmessage.Header, questions []dnsmessage.Question, answers []dnsResponseRewrite) ([]byte, error) { +func (c *client) rewriteDNSResponse(appName string, hdr dnsmessage.Header, questions []dnsmessage.Question, answers []dnsResponseRewrite) ([]byte, error) { b := dnsmessage.NewBuilder(nil, hdr) b.EnableCompression() if err := b.StartQuestions(); err != nil { @@ -1003,7 +1178,7 @@ func (c *client) rewriteDNSResponse(hdr dnsmessage.Header, questions []dnsmessag // make an answer for each rewrite for _, rw := range answers { - as, err := c.reserveAddresses(rw.domain, rw.dst) + as, err := c.reserveAddresses(appName, rw.domain, rw.dst) if err != nil { return nil, err } @@ -1044,22 +1219,18 @@ type connector struct { // transitIPs is a map of connector client peer IP -> client transitIPs that we update as connector client peers instruct us to, and then use to route traffic to its destination on behalf of connector clients. // Note that each peer could potentially have two maps: one for its IPv4 address, and one for its IPv6 address. The transit IPs map for a given peer IP will contain transit IPs of the same family as the peer's IP. transitIPs map[netip.Addr]map[netip.Addr]appAddr - config config } // realIPForTransitIPConnection is part of the implementation of the IPMapper interface for dataflows lookups. // See also [IPMapper.ConnectorRealIPForTransitIPConnection]. -func (c *connector) realIPForTransitIPConnection(srcIP netip.Addr, transitIP netip.Addr) (netip.Addr, error) { +func (c *connector) realIPForTransitIPConnection(srcIP netip.Addr, transitIP netip.Addr) (netip.Addr, bool) { c.mu.Lock() defer c.mu.Unlock() v, ok := c.lookupBySrcIPAndTransitIP(srcIP, transitIP) if ok { - return v.addr, nil + return v.addr, true } - if !c.config.v4TransitIPSet.Contains(transitIP) && !c.config.v6TransitIPSet.Contains(transitIP) { - return netip.Addr{}, nil - } - return netip.Addr{}, ErrUnmappedSrcAndTransitIP + return netip.Addr{}, false } const packetFilterAllowReason = "app connector transit IP" @@ -1085,73 +1256,36 @@ func (c *connector) lookupBySrcIPAndTransitIP(srcIP, transitIP netip.Addr) (appA return v, ok } -func (c *connector) reconfig(newCfg config) error { - c.mu.Lock() - defer c.mu.Unlock() - c.config = newCfg - return nil -} - type addrs struct { - dst netip.Addr - magic netip.Addr - transit netip.Addr - domain dnsname.FQDN - app string + dst netip.Addr + magic netip.Addr + transit netip.Addr + domain dnsname.FQDN + app string + expiresAt time.Time } -func (c addrs) isValid() bool { - return c.dst.IsValid() +func (as addrs) isValid() bool { + return as.dst.IsValid() } -// domainDst is a key for looking up an existing address assignment by the -// DNS response domain and destination IP pair. -type domainDst struct { - domain dnsname.FQDN - dst netip.Addr +func (as addrs) is4() bool { + return as.dst.Is4() } -// addrAssignments is the collection of addrs assigned by this client -// supporting lookup by magic IP, transit IP or domain+dst, or to lookup all -// transit IPs associated with a given connector (identified by its node key). -// byConnKey stores netip.Prefix versions of the transit IPs for use in the -// WireGuard hooks. -type addrAssignments struct { - byMagicIP map[netip.Addr]addrs - byTransitIP map[netip.Addr]addrs - byDomainDst map[domainDst]addrs - byConnKey map[key.NodePublic]set.Set[netip.Prefix] -} - -func (a *addrAssignments) insert(as addrs) error { - // we likely will want to allow overwriting in the future when we - // have address expiry, but for now this should not happen - if _, ok := a.byMagicIP[as.magic]; ok { - return errors.New("byMagicIP key exists") - } - ddst := domainDst{domain: as.domain, dst: as.dst} - if _, ok := a.byDomainDst[ddst]; ok { - return errors.New("byDomainDst key exists") - } - if _, ok := a.byTransitIP[as.transit]; ok { - return errors.New("byTransitIP key exists") - } - - mak.Set(&a.byMagicIP, as.magic, as) - mak.Set(&a.byTransitIP, as.transit, as) - mak.Set(&a.byDomainDst, ddst, as) - return nil +func (as addrs) is6() bool { + return as.dst.Is6() } // insertTransitConnMapping adds an entry to the byConnKey map // for the provided transitIP (as a prefix). // The provided transitIP must already be present in the byTransitIP map. -func (a *addrAssignments) insertTransitConnMapping(tip netip.Addr, connKey key.NodePublic) error { - if _, ok := a.lookupByTransitIP(tip); !ok { +func (c *client) insertTransitConnMapping(tip netip.Addr, connKey key.NodePublic) error { + if _, ok := c.assignments.lookupByTransitIP(tip); !ok { return errors.New("transit IP is not already known") } - ctips, ok := a.byConnKey[connKey] + ctips, ok := c.byConnKey[connKey] tipp := netip.PrefixFrom(tip, tip.BitLen()) if ok { if ctips.Contains(tipp) { @@ -1159,32 +1293,17 @@ func (a *addrAssignments) insertTransitConnMapping(tip netip.Addr, connKey key.N } } else { ctips.Make() - mak.Set(&a.byConnKey, connKey, ctips) + mak.Set(&c.byConnKey, connKey, ctips) } ctips.Add(tipp) return nil } -func (a *addrAssignments) lookupByDomainDst(domain dnsname.FQDN, dst netip.Addr) (addrs, bool) { - v, ok := a.byDomainDst[domainDst{domain: domain, dst: dst}] - return v, ok -} - -func (a *addrAssignments) lookupByMagicIP(mip netip.Addr) (addrs, bool) { - v, ok := a.byMagicIP[mip] - return v, ok -} - -func (a *addrAssignments) lookupByTransitIP(tip netip.Addr) (addrs, bool) { - v, ok := a.byTransitIP[tip] - return v, ok -} - // lookupTransitIPsByConnKey returns a slice containing the transit IPs (as netipPrefix) // associated with the given connector (identified by node key), or (nil, false) if there is no entry // for the given key. -func (a *addrAssignments) lookupTransitIPsByConnKey(k key.NodePublic) ([]netip.Prefix, bool) { - s, ok := a.byConnKey[k] +func (c *client) lookupTransitIPsByConnKey(k key.NodePublic) ([]netip.Prefix, bool) { + s, ok := c.byConnKey[k] if !ok { return nil, false } diff --git a/feature/conn25/conn25_test.go b/feature/conn25/conn25_test.go index 5f136c556..b3414d3da 100644 --- a/feature/conn25/conn25_test.go +++ b/feature/conn25/conn25_test.go @@ -5,6 +5,7 @@ package conn25 import ( "encoding/json" + "errors" "net/http" "net/http/httptest" "net/netip" @@ -17,6 +18,7 @@ import ( "go4.org/mem" "go4.org/netipx" "golang.org/x/net/dns/dnsmessage" + "tailscale.com/ipn" "tailscale.com/ipn/ipnext" "tailscale.com/net/dns" "tailscale.com/net/packet" @@ -338,9 +340,10 @@ func TestHandleConnectorTransitIPRequest(t *testing.T) { // Use the same Conn25 for each request in the test and seed it with a test app name. c := newConn25(logger.Discard) - c.connector.config = config{ - appsByName: map[string]appctype.Conn25Attr{appName: {}}, - } + c.reconfig(&config{ + isConfigured: true, + appsByName: map[string]appctype.Conn25Attr{appName: {}}, + }) for i, peer := range tt.ctipReqPeers { req := tt.ctipReqs[i] @@ -390,15 +393,19 @@ func TestHandleConnectorTransitIPRequest(t *testing.T) { func TestReserveIPs(t *testing.T) { c := newConn25(logger.Discard) - c.client.v4MagicIPPool = newIPPool(mustIPSetFromPrefix("100.64.0.0/24")) - c.client.v6MagicIPPool = newIPPool(mustIPSetFromPrefix("fd7a:115c:a1e0:a99c:0100::/80")) - c.client.v4TransitIPPool = newIPPool(mustIPSetFromPrefix("169.254.0.0/24")) - c.client.v6TransitIPPool = newIPPool(mustIPSetFromPrefix("fd7a:115c:a1e0:a99c:0200::/80")) - app := "a" + const appName = "a" domainStr := "example.com." - mbd := map[dnsname.FQDN][]string{} - mbd["example.com."] = []string{app} - c.client.config.appNamesByDomain = mbd + cfg := &config{ + isConfigured: true, + appsByName: map[string]appctype.Conn25Attr{appName: {}}, + ipSets: ipSets{ + v4Magic: mustIPSetFromPrefix("100.64.0.0/24"), + v6Magic: mustIPSetFromPrefix("fd7a:115c:a1e0:a99c:0100::/80"), + v4Transit: mustIPSetFromPrefix("169.254.0.0/24"), + v6Transit: mustIPSetFromPrefix("fd7a:115c:a1e0:a99c:0200::/80"), + }, + } + c.reconfig(cfg) domain := must.Get(dnsname.ToFQDN(domainStr)) for _, tt := range []struct { @@ -421,7 +428,7 @@ func TestReserveIPs(t *testing.T) { }, } { t.Run(tt.name, func(t *testing.T) { - addrs, err := c.client.reserveAddresses(domain, tt.dst) + addrs, err := c.client.reserveAddresses(appName, domain, tt.dst) if err != nil { t.Fatal(err) } @@ -434,8 +441,8 @@ func TestReserveIPs(t *testing.T) { if tt.wantTransit != addrs.transit { t.Errorf("want %v, got %v", tt.wantTransit, addrs.transit) } - if app != addrs.app { - t.Errorf("want %s, got %s", app, addrs.app) + if appName != addrs.app { + t.Errorf("want %s, got %s", appName, addrs.app) } if domain != addrs.domain { t.Errorf("want %s, got %s", domain, addrs.domain) @@ -453,29 +460,40 @@ func TestReconfig(t *testing.T) { } c := newConn25(logger.Discard) + if c.isConfigured() { + t.Fatal("expected Conn25 isConfigured() to report unconfigured before reconfig") + } + sn := (&tailcfg.Node{ CapMap: capMap, }).View() + cfg := mustConfig(t, sn) + c.reconfig(cfg) - err := c.reconfig(sn) - if err != nil { - t.Fatal(err) + if !c.isConfigured() { + t.Fatal("expected Conn25 isConfigured() to report configured after reconfig") } - if len(c.client.config.apps) != 1 || c.client.config.apps[0].Name != "app1" { - t.Fatalf("want apps to have one entry 'app1', got %v", c.client.config.apps) + cfg, ok := c.getConfig() + if !ok { + t.Fatal("expected Conn25 getConfig() to report configured after reconfig") + } + + if len(cfg.apps) != 1 || cfg.apps[0].Name != "app1" { + t.Fatalf("want apps to have one entry 'app1', got %v", cfg.apps) } } -func TestConfigReconfig(t *testing.T) { +func TestConfigFromNodeView(t *testing.T) { for _, tt := range []struct { - name string - rawCfg string - cfg []appctype.Conn25Attr - tags []string - wantErr bool - wantAppsByDomain map[dnsname.FQDN][]string - wantSelfRoutedDomains set.Set[dnsname.FQDN] + name string + rawCfg string + cfg []appctype.Conn25Attr + tags []string + wantErr bool + wantAppsByDomain map[dnsname.FQDN][]string + wantAppsByWCDomain map[dnsname.FQDN][]string + wantSelfAppNames set.Set[string] }{ { name: "bad-config", @@ -493,10 +511,11 @@ func TestConfigReconfig(t *testing.T) { "a.example.com.": {"one"}, "b.example.com.": {"two"}, }, - wantSelfRoutedDomains: set.SetOf([]dnsname.FQDN{"a.example.com."}), + wantAppsByWCDomain: map[dnsname.FQDN][]string{}, + wantSelfAppNames: set.SetOf([]string{"one"}), }, { - name: "more-complex", + name: "more-complex-with-connector-self-domains", cfg: []appctype.Conn25Attr{ {Name: "one", Domains: []string{"1.a.example.com", "1.b.example.com"}, Connectors: []string{"tag:one", "tag:onea"}}, {Name: "two", Domains: []string{"2.b.example.com", "2.c.example.com"}, Connectors: []string{"tag:two", "tag:twoa"}}, @@ -513,7 +532,63 @@ func TestConfigReconfig(t *testing.T) { "4.b.example.com.": {"four"}, "4.d.example.com.": {"four"}, }, - wantSelfRoutedDomains: set.SetOf([]dnsname.FQDN{"1.a.example.com.", "1.b.example.com.", "4.b.example.com.", "4.d.example.com."}), + wantAppsByWCDomain: map[dnsname.FQDN][]string{}, + wantSelfAppNames: set.SetOf([]string{"one", "four"}), + }, + { + name: "eligible-connector-no-matching-tag-no-self-domains", + cfg: []appctype.Conn25Attr{ + {Name: "one", Domains: []string{"a.example.com"}, Connectors: []string{"tag:one"}}, + {Name: "two", Domains: []string{"b.example.com"}, Connectors: []string{"tag:two"}}, + }, + tags: []string{"tag:unrelated"}, + wantAppsByDomain: map[dnsname.FQDN][]string{ + "a.example.com.": {"one"}, + "b.example.com.": {"two"}, + }, + wantAppsByWCDomain: map[dnsname.FQDN][]string{}}, + { + name: "wildcard-collapse-and-deduplication", + cfg: []appctype.Conn25Attr{ + {Name: "one", Domains: []string{"*.example.com", "example.com"}, Connectors: []string{"tag:one"}}, + {Name: "two", Domains: []string{"example.com", "sub.example.com"}, Connectors: []string{"tag:two"}}, + }, + tags: []string{"tag:one", "tag:two"}, + wantAppsByDomain: map[dnsname.FQDN][]string{ + "example.com.": {"one", "two"}, + "sub.example.com.": {"two"}, + }, + wantAppsByWCDomain: map[dnsname.FQDN][]string{ + "example.com.": {"one"}, + }, + wantSelfAppNames: set.SetOf([]string{"one", "two"}), + }, + { + // Domain names that differ only in case must be treated as the same + // domain and the app name must appear exactly once in appNamesByDomain, + // not once per case variant. + name: "case-variant-exact-domains-deduplicated-within-app", + cfg: []appctype.Conn25Attr{ + {Name: "one", Domains: []string{"EXAMPLE.com", "example.COM", "Example.COM"}, Connectors: []string{"tag:one"}}, + }, + tags: []string{"tag:one"}, + wantAppsByDomain: map[dnsname.FQDN][]string{ + "example.com.": {"one"}, + }, + wantAppsByWCDomain: map[dnsname.FQDN][]string{}, + wantSelfAppNames: set.SetOf([]string{"one"}), + }, + { + // Same as above but for wildcard domains: *.EXAMPLE.com and *.example.COM + // must collapse to a single entry in appNamesByWCDomain. + name: "case-variant-wildcard-domains-deduplicated-within-app", + cfg: []appctype.Conn25Attr{ + {Name: "one", Domains: []string{"*.EXAMPLE.com", "*.example.COM"}, Connectors: []string{"tag:one"}}, + }, + tags: []string{"tag:one"}, + wantAppsByDomain: map[dnsname.FQDN][]string{}, + wantAppsByWCDomain: map[dnsname.FQDN][]string{"example.com.": {"one"}}, + wantSelfAppNames: set.SetOf([]string{"one"}), }, } { t.Run(tt.name, func(t *testing.T) { @@ -535,6 +610,7 @@ func TestConfigReconfig(t *testing.T) { CapMap: capMap, Tags: tt.tags, }).View() + c, err := configFromNodeView(sn) if (err != nil) != tt.wantErr { t.Fatalf("wantErr: %t, err: %v", tt.wantErr, err) @@ -542,8 +618,114 @@ func TestConfigReconfig(t *testing.T) { if diff := cmp.Diff(tt.wantAppsByDomain, c.appNamesByDomain); diff != "" { t.Errorf("appsByDomain diff (-want, +got):\n%s", diff) } - if diff := cmp.Diff(tt.wantSelfRoutedDomains, c.selfRoutedDomains); diff != "" { - t.Errorf("selfRoutedDomains diff (-want, +got):\n%s", diff) + if diff := cmp.Diff(tt.wantAppsByWCDomain, c.appNamesByWCDomain); diff != "" { + t.Errorf("appsByWCDomain diff (-want, +got):\n%s", diff) + } + if diff := cmp.Diff(tt.wantSelfAppNames, c.selfAppNames); diff != "" { + t.Errorf("selfAppNames diff (-want, +got):\n%s", diff) + } + }) + } +} + +func TestGetAppsForDomainName(t *testing.T) { + defaultSN := makeSelfNode( + t, + []appctype.Conn25Attr{ + {Name: "one", Domains: []string{"*.example.com", "example.com"}, Connectors: []string{"tag:one"}}, + {Name: "two", Domains: []string{"sub.example.com", "example.com"}, Connectors: []string{"tag:two"}}, + {Name: "three", Domains: []string{"*.sub.example.com"}, Connectors: []string{"tag:three"}}, + {Name: "four", Domains: []string{"a.sub.example.com"}, Connectors: []string{"tag:four"}}, + {Name: "self-routed", Domains: []string{"*.wildcard.com", "exact-match.com"}, Connectors: []string{"tag:self-routed"}}, + }, + []string{"tag:self-routed"}, + ) + + for _, tt := range []struct { + name string + isConnector bool + domain dnsname.FQDN + wantApps []string + }{ + { + name: "no-match", + domain: "nomatch.com.", + wantApps: nil, + }, + { + name: "exact-match", + domain: "example.com.", + wantApps: []string{"one", "two"}, + }, + { + name: "wildcard-subdomain-match", + domain: "a.example.com.", + wantApps: []string{"one"}, + }, + { + name: "exact-subdomain-match", + domain: "sub.example.com.", + wantApps: []string{"two"}, + }, + { + name: "wildcard-sub-of-subdomain-match", + domain: "b.sub.example.com.", + wantApps: []string{"three"}, + }, + { + name: "exact-sub-of-subdomain-match", + domain: "a.sub.example.com.", + wantApps: []string{"four"}, + }, + { + name: "exact-domain-matches-wildcard", + domain: "wildcard.com.", + wantApps: []string{"self-routed"}, + }, + { + name: "self-routed-exact-domain-suppressed", + isConnector: true, + domain: "exact-match.com.", + wantApps: nil, + }, + { + // Self node is an eligible connector for "wildcard-self-app" via + // *.wildcard.com, so the wildcard match must also be suppressed. + name: "self-routed-wildcard-domain-suppressed", + isConnector: true, + domain: "sub.wildcard.com.", + wantApps: nil, + }, + { + // "other-app" is not on a self-connector tag, so it must not be suppressed. + name: "non-self-routed-domain-not-suppressed", + isConnector: true, + domain: "example.com.", + wantApps: []string{"one", "two"}, + }, + { + // Even though the app's connector tag matches the self node's tags, + // if the node is not an eligible connector (Advertise=false) then + // isSelfRoutedApp returns false and the domain is forwarded normally. + name: "not-eligible-connector-not-suppressed", + domain: "exact-match.com.", + wantApps: []string{"self-routed"}, + }, + } { + t.Run(tt.name, func(t *testing.T) { + c := newConn25(logger.Discard) + if tt.isConnector { + c.prefsAdvertiseConnector.Store(true) + } + cfg := mustConfig(t, defaultSN) + c.reconfig(cfg) + cfg, ok := c.getConfig() + if !ok { + t.Fatal("could not get config") + } + gotApps := cfg.getAppsForConnectorDomain(tt.domain, tt.isConnector) + if diff := cmp.Diff(tt.wantApps, gotApps); diff != "" { + t.Errorf("unexpected appNames result: diff (-want, +got):\n%s", diff) } }) } @@ -562,12 +744,26 @@ func makeSelfNode(t *testing.T, attrs []appctype.Conn25Attr, tags []string) tail capMap := tailcfg.NodeCapMap{ tailcfg.NodeCapability(AppConnectorsExperimentalAttrName): cfg, } + return (&tailcfg.Node{ CapMap: capMap, Tags: tags, }).View() } +var ( + testPrefsNotConnector = (&ipn.Prefs{AppConnector: ipn.AppConnectorPrefs{Advertise: false}}).View() +) + +func mustConfig(t *testing.T, selfNode tailcfg.NodeView) *config { + t.Helper() + cfg, err := configFromNodeView(selfNode) + if err != nil { + t.Fatal(err) + } + return cfg +} + func v4RangeFrom(from, to string) netipx.IPRange { return netipx.IPRangeFrom( netip.MustParseAddr("100.64.0."+from), @@ -681,6 +877,12 @@ func makeDNSResponseForSections(t *testing.T, questions []dnsmessage.Question, a t.Fatalf("unexpected answer type, update test") } b.AAAAResource(ans.Header, *body) + case dnsmessage.TypeCNAME: + body, ok := (ans.Body).(*dnsmessage.CNAMEResource) + if !ok { + t.Fatalf("unexpected answer type, update test") + } + b.CNAMEResource(ans.Header, *body) default: t.Fatalf("unhandled answer type, update test: %v", ans.Header.Type) } @@ -706,18 +908,22 @@ func makeDNSResponseForSections(t *testing.T, questions []dnsmessage.Question, a func TestMapDNSResponseAssignsAddrs(t *testing.T) { for _, tt := range []struct { - name string - domain string - v4Addrs []*dnsmessage.AResource - v6Addrs []*dnsmessage.AAAAResource - wantByMagicIP map[netip.Addr]addrs + name string + appDomains []string + domain string + v4Addrs []*dnsmessage.AResource + v6Addrs []*dnsmessage.AAAAResource + selfTags []string + isEligibleConnector bool + wantByMagicIP map[netip.Addr]*addrs }{ { - name: "one-ip-matches", - domain: "example.com.", - v4Addrs: []*dnsmessage.AResource{{A: [4]byte{1, 0, 0, 0}}}, + name: "one-ip-matches", + appDomains: []string{"example.com"}, + domain: "example.com.", + v4Addrs: []*dnsmessage.AResource{{A: [4]byte{1, 0, 0, 0}}}, // these are 'expected' because they are the beginning of the provided pools - wantByMagicIP: map[netip.Addr]addrs{ + wantByMagicIP: map[netip.Addr]*addrs{ netip.MustParseAddr("100.64.0.0"): { domain: "example.com.", dst: netip.MustParseAddr("1.0.0.0"), @@ -728,13 +934,14 @@ func TestMapDNSResponseAssignsAddrs(t *testing.T) { }, }, { - name: "v6-ip-matches", - domain: "example.com.", + name: "v6-ip-matches", + appDomains: []string{"example.com"}, + domain: "example.com.", v6Addrs: []*dnsmessage.AAAAResource{ {AAAA: [16]byte{0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1}}, {AAAA: [16]byte{0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 2}}, }, - wantByMagicIP: map[netip.Addr]addrs{ + wantByMagicIP: map[netip.Addr]*addrs{ netip.MustParseAddr("fd7a:115c:a1e0:a99c::"): { domain: "example.com.", dst: netip.MustParseAddr("::1"), @@ -752,13 +959,14 @@ func TestMapDNSResponseAssignsAddrs(t *testing.T) { }, }, { - name: "multiple-ip-matches", - domain: "example.com.", + name: "multiple-ip-matches", + appDomains: []string{"example.com"}, + domain: "example.com.", v4Addrs: []*dnsmessage.AResource{ {A: [4]byte{1, 0, 0, 0}}, {A: [4]byte{2, 0, 0, 0}}, }, - wantByMagicIP: map[netip.Addr]addrs{ + wantByMagicIP: map[netip.Addr]*addrs{ netip.MustParseAddr("100.64.0.0"): { domain: "example.com.", dst: netip.MustParseAddr("1.0.0.0"), @@ -776,13 +984,107 @@ func TestMapDNSResponseAssignsAddrs(t *testing.T) { }, }, { - name: "no-domain-match", - domain: "x.example.com.", + name: "no-domain-match", + appDomains: []string{"foo.example.com"}, + domain: "bad.example.com.", v4Addrs: []*dnsmessage.AResource{ {A: [4]byte{1, 0, 0, 0}}, {A: [4]byte{2, 0, 0, 0}}, }, }, + { + name: "no-rewrite-self-routed-domain", + appDomains: []string{"example.com"}, + domain: "example.com.", + v4Addrs: []*dnsmessage.AResource{{A: [4]byte{1, 0, 0, 0}}}, + selfTags: []string{"tag:woo"}, + isEligibleConnector: true, + }, + { + name: "rewrite-tagged-but-not-eligible-connector", + appDomains: []string{"example.com"}, + domain: "example.com.", + v4Addrs: []*dnsmessage.AResource{{A: [4]byte{1, 0, 0, 0}}}, + selfTags: []string{"tag:woo"}, + // isEligibleConnector is false: tag matches but prefs not set, + // so DNS response should be rewritten normally. + wantByMagicIP: map[netip.Addr]*addrs{ + netip.MustParseAddr("100.64.0.0"): { + domain: "example.com.", + dst: netip.MustParseAddr("1.0.0.0"), + magic: netip.MustParseAddr("100.64.0.0"), + transit: netip.MustParseAddr("100.64.0.40"), + app: "app1", + }, + }, + }, + { + name: "rewrite-eligible-connector-no-matching-tag", + appDomains: []string{"example.com"}, + domain: "example.com.", + v4Addrs: []*dnsmessage.AResource{{A: [4]byte{1, 0, 0, 0}}}, + selfTags: []string{"tag:unrelated"}, + isEligibleConnector: true, + // isEligibleConnector is true but tag doesn't match the app, + // so DNS response should be rewritten normally. + wantByMagicIP: map[netip.Addr]*addrs{ + netip.MustParseAddr("100.64.0.0"): { + domain: "example.com.", + dst: netip.MustParseAddr("1.0.0.0"), + magic: netip.MustParseAddr("100.64.0.0"), + transit: netip.MustParseAddr("100.64.0.40"), + app: "app1", + }, + }, + }, + { + name: "subdomain-matches-wildcard", + appDomains: []string{"*.example.com"}, + domain: "sub.example.com.", + v4Addrs: []*dnsmessage.AResource{{A: [4]byte{1, 0, 0, 0}}}, + // these are 'expected' because they are the beginning of the provided pools + wantByMagicIP: map[netip.Addr]*addrs{ + netip.MustParseAddr("100.64.0.0"): { + domain: "sub.example.com.", + dst: netip.MustParseAddr("1.0.0.0"), + magic: netip.MustParseAddr("100.64.0.0"), + transit: netip.MustParseAddr("100.64.0.40"), + app: "app1", + }, + }, + }, + { + name: "exact-subdomain-matches", + appDomains: []string{"example.com", "sub.example.com"}, + domain: "sub.example.com.", + v4Addrs: []*dnsmessage.AResource{{A: [4]byte{1, 0, 0, 0}}}, + // these are 'expected' because they are the beginning of the provided pools + wantByMagicIP: map[netip.Addr]*addrs{ + netip.MustParseAddr("100.64.0.0"): { + domain: "sub.example.com.", + dst: netip.MustParseAddr("1.0.0.0"), + magic: netip.MustParseAddr("100.64.0.0"), + transit: netip.MustParseAddr("100.64.0.40"), + app: "app1", + }, + }, + }, + { + name: "wildcard-subdomain-matches-subdomain", + appDomains: []string{"example.com", "*.sub.example.com"}, + domain: "a.sub.example.com.", + v4Addrs: []*dnsmessage.AResource{{A: [4]byte{1, 0, 0, 0}}}, + // these are 'expected' because they are the beginning of the provided pools + wantByMagicIP: map[netip.Addr]*addrs{ + netip.MustParseAddr("100.64.0.0"): { + domain: "a.sub.example.com.", + dst: netip.MustParseAddr("1.0.0.0"), + magic: netip.MustParseAddr("100.64.0.0"), + transit: netip.MustParseAddr("100.64.0.40"), + app: "app1", + }, + }, + }, } { t.Run(tt.name, func(t *testing.T) { var dnsResp []byte @@ -794,23 +1096,55 @@ func TestMapDNSResponseAssignsAddrs(t *testing.T) { sn := makeSelfNode(t, []appctype.Conn25Attr{{ Name: "app1", Connectors: []string{"tag:woo"}, - Domains: []string{"example.com"}, + Domains: tt.appDomains, V4MagicIPPool: []netipx.IPRange{v4RangeFrom("0", "10"), v4RangeFrom("20", "30")}, V6MagicIPPool: []netipx.IPRange{v6RangeFrom("0", "10"), v6RangeFrom("20", "30")}, V4TransitIPPool: []netipx.IPRange{v4RangeFrom("40", "50")}, V6TransitIPPool: []netipx.IPRange{v6RangeFrom("40", "50")}, - }}, []string{}) + }}, tt.selfTags) + c := newConn25(logger.Discard) - c.reconfig(sn) + cfg := mustConfig(t, sn) + c.reconfig(cfg) + c.prefsAdvertiseConnector.Store(tt.isEligibleConnector) c.mapDNSResponse(dnsResp) - if diff := cmp.Diff(tt.wantByMagicIP, c.client.assignments.byMagicIP, cmpopts.EquateComparable(addrs{}, netip.Addr{})); diff != "" { + if diff := cmp.Diff( + tt.wantByMagicIP, + c.client.assignments.byMagicIP, + cmp.AllowUnexported(addrs{}), + cmpopts.IgnoreFields(addrs{}, "expiresAt"), + cmpopts.EquateComparable(netip.Addr{}), + ); diff != "" { t.Errorf("byMagicIP diff (-want, +got):\n%s", diff) } }) } } +func TestNormalizedDNSNames(t *testing.T) { + tests := []struct { + name string + domain string + want dnsname.FQDN + }{ + {name: "no-change", domain: "example.com.", want: "example.com."}, + {name: "mixed-case", domain: "eXAmPle.COM", want: "example.com."}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got, err := normalizeDNSName(tt.domain) + if err != nil { + t.Errorf("unexpected error %v", err) + } + if got != tt.want { + t.Errorf("Unexpected result, want %q, got %q", tt.want, got) + } + }) + } +} + func TestReserveAddressesDeduplicated(t *testing.T) { for _, tt := range []struct { name string @@ -826,30 +1160,33 @@ func TestReserveAddressesDeduplicated(t *testing.T) { }, } { t.Run(tt.name, func(t *testing.T) { - c := newConn25(logger.Discard) - c.client.v4MagicIPPool = newIPPool(mustIPSetFromPrefix("100.64.0.0/24")) - c.client.v6MagicIPPool = newIPPool(mustIPSetFromPrefix("fd7a:115c:a1e0:a99c:0100::/80")) - c.client.v4TransitIPPool = newIPPool(mustIPSetFromPrefix("169.254.0.0/24")) - c.client.v6TransitIPPool = newIPPool(mustIPSetFromPrefix("fd7a:115c:a1e0:a99c:0200::/80")) - c.client.config.appNamesByDomain = map[dnsname.FQDN][]string{"example.com.": {"a"}} + const appName = "a" + conn25 := newConn25(t.Logf) + c := conn25.client + c.v4MagicIPPool = newIPPool(mustIPSetFromPrefix("100.64.0.0/24")) + c.v6MagicIPPool = newIPPool(mustIPSetFromPrefix("fd7a:115c:a1e0:a99c:0100::/80")) + c.v4TransitIPPool = newIPPool(mustIPSetFromPrefix("169.254.0.0/24")) + c.v6TransitIPPool = newIPPool(mustIPSetFromPrefix("fd7a:115c:a1e0:a99c:0200::/80")) - first, err := c.client.reserveAddresses("example.com.", tt.dst) + first, err := c.reserveAddresses(appName, "example.com.", tt.dst) if err != nil { t.Fatal(err) } - second, err := c.client.reserveAddresses("example.com.", tt.dst) + second, err := c.reserveAddresses(appName, "example.com.", tt.dst) if err != nil { t.Fatal(err) } if first != second { - t.Errorf("expected same addrs on repeated call, got first=%v second=%v", first, second) + // reserveAddresses should return the existing entry when called for a domain that already has assigned addrs + t.Fatalf("want first==second, got first: %v, second: %v", first, second) } - if got := len(c.client.assignments.byMagicIP); got != 1 { + + if got := len(c.assignments.byMagicIP); got != 1 { t.Errorf("want 1 entry in byMagicIP, got %d", got) } - if got := len(c.client.assignments.byDomainDst); got != 1 { + if got := len(c.assignments.byDomainDst); got != 1 { t.Errorf("want 1 entry in byDomainDst, got %d", got) } @@ -880,16 +1217,28 @@ func (nb *testNodeBackend) PeerAPIBase(p tailcfg.NodeView) string { return nb.peerAPIURL } +type testProfileServices struct { + ipnext.ProfileServices + prefs ipn.PrefsView +} + +func (p *testProfileServices) CurrentPrefs() ipn.PrefsView { return p.prefs } +func (p *testProfileServices) CurrentProfileState() (ipn.LoginProfileView, ipn.PrefsView) { + return ipn.LoginProfileView{}, p.prefs +} + type testHost struct { ipnext.Host nb ipnext.NodeBackend hooks ipnext.Hooks + prefs ipn.PrefsView authReconfigAsync func() } -func (h *testHost) NodeBackend() ipnext.NodeBackend { return h.nb } -func (h *testHost) Hooks() *ipnext.Hooks { return &h.hooks } -func (h *testHost) AuthReconfigAsync() { h.authReconfigAsync() } +func (h *testHost) NodeBackend() ipnext.NodeBackend { return h.nb } +func (h *testHost) Hooks() *ipnext.Hooks { return &h.hooks } +func (h *testHost) Profiles() ipnext.ProfileServices { return &testProfileServices{prefs: h.prefs} } +func (h *testHost) AuthReconfigAsync() { h.authReconfigAsync() } type testSafeBackend struct { ipnext.SafeBackend @@ -950,6 +1299,7 @@ func TestAddressAssignmentIsHandled(t *testing.T) { peers: []tailcfg.NodeView{connectorPeer}, peerAPIURL: peersAPI.URL, }, + prefs: testPrefsNotConnector, authReconfigAsync: func() { authReconfigAsyncCalled <- struct{}{} }, @@ -963,12 +1313,11 @@ func TestAddressAssignmentIsHandled(t *testing.T) { Connectors: []string{"tag:woo"}, Domains: []string{"example.com"}, }}, []string{}) - err := ext.conn25.reconfig(sn) - if err != nil { - t.Fatal(err) - } - as := addrs{ + cfg := mustConfig(t, sn) + ext.conn25.reconfig(cfg) + + as := &addrs{ dst: netip.MustParseAddr("1.2.3.4"), magic: netip.MustParseAddr("100.64.0.0"), transit: netip.MustParseAddr("169.254.0.1"), @@ -1046,6 +1395,8 @@ func TestMapDNSResponseRewritesResponses(t *testing.T) { V6TransitIPPool: []netipx.IPRange{netipx.IPRangeFrom(netip.MustParseAddr("2606:4700::6813:100"), netip.MustParseAddr("2606:4700::6813:1ff"))}, }}, []string{}) + cfg := mustConfig(t, sn) + compareToRecords := func(t *testing.T, resources []dnsmessage.Resource, want []netip.Addr) { t.Helper() var got []netip.Addr @@ -1320,12 +1671,178 @@ func TestMapDNSResponseRewritesResponses(t *testing.T) { ), assertFx: assertParsesToAnswers([]netip.Addr{netip.MustParseAddr("2606:4700::6812:100")}), }, + { + name: "cname-resolves-to-magic-ip", + toMap: makeDNSResponseForSections(t, + []dnsmessage.Question{{Name: dnsMessageName, Type: dnsmessage.TypeA, Class: dnsmessage.ClassINET}}, + []dnsmessage.Resource{ + { + Header: dnsmessage.ResourceHeader{ + Name: dnsMessageName, + Type: dnsmessage.TypeCNAME, + Class: dnsmessage.ClassINET, + }, + Body: &dnsmessage.CNAMEResource{CNAME: dnsmessage.MustNewName("a.example.com.")}, + }, + { + Header: dnsmessage.ResourceHeader{ + Name: dnsmessage.MustNewName("a.example.com."), + Type: dnsmessage.TypeCNAME, + Class: dnsmessage.ClassINET, + }, + Body: &dnsmessage.CNAMEResource{CNAME: dnsmessage.MustNewName("b.example.com.")}, + }, + { + Header: dnsmessage.ResourceHeader{ + Name: dnsmessage.MustNewName("b.example.com."), + Type: dnsmessage.TypeCNAME, + Class: dnsmessage.ClassINET, + }, + Body: &dnsmessage.CNAMEResource{CNAME: dnsmessage.MustNewName("c.example.com.")}, + }, + { + Header: dnsmessage.ResourceHeader{ + Name: dnsmessage.MustNewName("c.example.com."), + Type: dnsmessage.TypeA, + Class: dnsmessage.ClassINET, + }, + Body: &dnsmessage.AResource{A: netip.MustParseAddr("1.2.3.4").As4()}, + }, + }, + nil, + ), + assertFx: assertParsesToAnswers([]netip.Addr{netip.MustParseAddr("100.64.0.0")}), + }, + { + name: "cname-aaaa-resolves-to-magic-ip", + toMap: makeDNSResponseForSections(t, + []dnsmessage.Question{ + { + Name: dnsMessageName, + Type: dnsmessage.TypeAAAA, + Class: dnsmessage.ClassINET, + }, + }, + []dnsmessage.Resource{ + { + Header: dnsmessage.ResourceHeader{ + Name: dnsMessageName, + Type: dnsmessage.TypeCNAME, + Class: dnsmessage.ClassINET, + }, + Body: &dnsmessage.CNAMEResource{CNAME: dnsmessage.MustNewName("cdn.example.net.")}, + }, + { + Header: dnsmessage.ResourceHeader{ + Name: dnsmessage.MustNewName("cdn.example.net."), + Type: dnsmessage.TypeAAAA, + Class: dnsmessage.ClassINET, + }, + Body: &dnsmessage.AAAAResource{AAAA: netip.MustParseAddr("2606:4700::6812:1a78").As16()}, + }, + }, + nil, + ), + assertFx: assertParsesToAnswers([]netip.Addr{netip.MustParseAddr("2606:4700::6812:100")}), + }, + { + name: "cname-broken-chain-skips-answer", + toMap: makeDNSResponseForSections(t, + []dnsmessage.Question{{Name: dnsMessageName, Type: dnsmessage.TypeA, Class: dnsmessage.ClassINET}}, + []dnsmessage.Resource{ + { + Header: dnsmessage.ResourceHeader{ + Name: dnsMessageName, + Type: dnsmessage.TypeCNAME, + Class: dnsmessage.ClassINET, + }, + Body: &dnsmessage.CNAMEResource{CNAME: dnsmessage.MustNewName("cdn.example.net.")}, + }, + { + Header: dnsmessage.ResourceHeader{ + Name: dnsmessage.MustNewName("unrelated.com."), + Type: dnsmessage.TypeA, + Class: dnsmessage.ClassINET, + }, + Body: &dnsmessage.AResource{A: netip.MustParseAddr("1.2.3.4").As4()}, + }, + }, + nil, + ), + assertFx: assertParsesToAnswers(nil), + }, + { + name: "cname-multi-source-same-target", + toMap: makeDNSResponseForSections(t, + []dnsmessage.Question{{Name: dnsMessageName, Type: dnsmessage.TypeA, Class: dnsmessage.ClassINET}}, + []dnsmessage.Resource{ + { + Header: dnsmessage.ResourceHeader{ + Name: dnsMessageName, + Type: dnsmessage.TypeCNAME, + Class: dnsmessage.ClassINET, + }, + Body: &dnsmessage.CNAMEResource{CNAME: dnsmessage.MustNewName("z.example.com.")}, + }, + { + Header: dnsmessage.ResourceHeader{ + Name: dnsmessage.MustNewName("a.example.com."), + Type: dnsmessage.TypeCNAME, + Class: dnsmessage.ClassINET, + }, + Body: &dnsmessage.CNAMEResource{CNAME: dnsmessage.MustNewName("z.example.com.")}, + }, + { + Header: dnsmessage.ResourceHeader{ + Name: dnsmessage.MustNewName("z.example.com."), + Type: dnsmessage.TypeA, + Class: dnsmessage.ClassINET, + }, + Body: &dnsmessage.AResource{A: netip.MustParseAddr("1.2.3.4").As4()}, + }, + }, + nil, + ), + assertFx: assertParsesToAnswers([]netip.Addr{netip.MustParseAddr("100.64.0.0")}), + }, + { + name: "cname-has-loop", + toMap: makeDNSResponseForSections(t, + []dnsmessage.Question{{Name: dnsMessageName, Type: dnsmessage.TypeA, Class: dnsmessage.ClassINET}}, + []dnsmessage.Resource{ + { + Header: dnsmessage.ResourceHeader{ + Name: dnsMessageName, + Type: dnsmessage.TypeCNAME, + Class: dnsmessage.ClassINET, + }, + Body: &dnsmessage.CNAMEResource{CNAME: dnsmessage.MustNewName("a.example.com.")}, + }, + { + Header: dnsmessage.ResourceHeader{ + Name: dnsmessage.MustNewName("a.example.com."), + Type: dnsmessage.TypeCNAME, + Class: dnsmessage.ClassINET, + }, + Body: &dnsmessage.CNAMEResource{CNAME: dnsMessageName}, + }, + { + Header: dnsmessage.ResourceHeader{ + Name: dnsmessage.MustNewName("z.example.com."), + Type: dnsmessage.TypeA, + Class: dnsmessage.ClassINET, + }, + Body: &dnsmessage.AResource{A: netip.MustParseAddr("1.2.3.4").As4()}, + }, + }, + nil, + ), + assertFx: assertParsesToAnswers(nil), + }, } { t.Run(tt.name, func(t *testing.T) { c := newConn25(logger.Discard) - if err := c.reconfig(sn); err != nil { - t.Fatal(err) - } + c.reconfig(cfg) bs := c.mapDNSResponse(tt.toMap) tt.assertFx(t, bs) }) @@ -1376,6 +1893,7 @@ func TestHandleAddressAssignmentStoresTransitIPs(t *testing.T) { peers: connectorPeers, peerAPIURL: peersAPI.URL, }, + prefs: testPrefsNotConnector, authReconfigAsync: func() { authReconfigAsyncCalled <- struct{}{} }, @@ -1396,10 +1914,9 @@ func TestHandleAddressAssignmentStoresTransitIPs(t *testing.T) { Domains: []string{"hoo.example.com"}, }, }, []string{}) - err := ext.conn25.reconfig(sn) - if err != nil { - t.Fatal(err) - } + + cfg := mustConfig(t, sn) + ext.conn25.reconfig(cfg) type lookup struct { connKey key.NodePublic @@ -1416,12 +1933,12 @@ func TestHandleAddressAssignmentStoresTransitIPs(t *testing.T) { // and then does the lookups. steps := []struct { name string - as addrs + as *addrs lookups []lookup }{ { name: "step-1-conn1-tip1", - as: addrs{ + as: &addrs{ dst: netip.MustParseAddr("1.2.3.1"), magic: netip.MustParseAddr("100.64.0.1"), transit: transitIPs[0].Addr(), @@ -1445,7 +1962,7 @@ func TestHandleAddressAssignmentStoresTransitIPs(t *testing.T) { }, { name: "step-2-conn1-tip2", - as: addrs{ + as: &addrs{ dst: netip.MustParseAddr("1.2.3.2"), magic: netip.MustParseAddr("100.64.0.2"), transit: transitIPs[1].Addr(), @@ -1465,7 +1982,7 @@ func TestHandleAddressAssignmentStoresTransitIPs(t *testing.T) { }, { name: "step-3-conn2-tip1", - as: addrs{ + as: &addrs{ dst: netip.MustParseAddr("1.2.3.3"), magic: netip.MustParseAddr("100.64.0.3"), transit: transitIPs[2].Addr(), @@ -1510,7 +2027,7 @@ func TestHandleAddressAssignmentStoresTransitIPs(t *testing.T) { // Check that each of the lookups behaves as expected for i, lu := range tt.lookups { - got, ok := ext.conn25.client.assignments.lookupTransitIPsByConnKey(lu.connKey) + got, ok := ext.conn25.client.lookupTransitIPsByConnKey(lu.connKey) if ok != lu.expectedOk { t.Fatalf("unexpected ok result at index %d wanted %v, got %v", i, lu.expectedOk, ok) } @@ -1526,7 +2043,7 @@ func TestHandleAddressAssignmentStoresTransitIPs(t *testing.T) { func TestTransitIPConnMapping(t *testing.T) { conn25 := newConn25(t.Logf) - as := addrs{ + as := &addrs{ dst: netip.MustParseAddr("1.2.3.1"), magic: netip.MustParseAddr("100.64.0.1"), transit: netip.MustParseAddr("169.254.0.1"), @@ -1576,6 +2093,7 @@ func TestClientTransitIPForMagicIP(t *testing.T) { V4MagicIPPool: []netipx.IPRange{v4RangeFrom("0", "10")}, // 100.64.0.0 - 100.64.0.10 V6MagicIPPool: []netipx.IPRange{v6RangeFrom("0", "10")}, }}, []string{}) + cfg := mustConfig(t, sn) mappedMip := netip.MustParseAddr("100.64.0.0") mappedTip := netip.MustParseAddr("169.0.0.0") @@ -1634,24 +2152,23 @@ func TestClientTransitIPForMagicIP(t *testing.T) { } { t.Run(tt.name, func(t *testing.T) { c := newConn25(t.Logf) - if err := c.reconfig(sn); err != nil { - t.Fatal(err) - } - if err := c.client.assignments.insert(addrs{ + c.reconfig(cfg) + + if err := c.client.assignments.insert(&addrs{ magic: mappedMip, transit: mappedTip, dst: dst, }); err != nil { t.Fatal(err) } - if err := c.client.assignments.insert(addrs{ + if err := c.client.assignments.insert(&addrs{ magic: v6MappedMip, transit: v6MappedTip, dst: v6Dst, }); err != nil { t.Fatal(err) } - tip, err := c.client.transitIPForMagicIP(tt.mip) + tip, err := c.ClientTransitIPForMagicIP(tt.mip) if tip != tt.wantTip { t.Fatalf("checking transit ip: want %v, got %v", tt.wantTip, tip) } @@ -1666,6 +2183,8 @@ func TestConnectorRealIPForTransitIPConnection(t *testing.T) { sn := makeSelfNode(t, []appctype.Conn25Attr{{ V4TransitIPPool: []netipx.IPRange{v4RangeFrom("40", "50")}, // 100.64.0.40 - 100.64.0.50 }}, []string{}) + cfg := mustConfig(t, sn) + mappedSrc := netip.MustParseAddr("100.0.0.1") unmappedSrc := netip.MustParseAddr("100.0.0.2") mappedTip := netip.MustParseAddr("100.64.0.41") @@ -1717,13 +2236,11 @@ func TestConnectorRealIPForTransitIPConnection(t *testing.T) { } { t.Run(tt.name, func(t *testing.T) { c := newConn25(t.Logf) - if err := c.reconfig(sn); err != nil { - t.Fatal(err) - } + c.reconfig(cfg) c.connector.transitIPs = map[netip.Addr]map[netip.Addr]appAddr{} c.connector.transitIPs[mappedSrc] = map[netip.Addr]appAddr{} c.connector.transitIPs[mappedSrc][mappedTip] = appAddr{addr: mappedMip} - mip, err := c.connector.realIPForTransitIPConnection(tt.src, tt.tip) + mip, err := c.ConnectorRealIPForTransitIPConnection(tt.src, tt.tip) if mip != tt.wantMip { t.Fatalf("checking magic ip: want %v, got %v", tt.wantMip, mip) } @@ -1739,7 +2256,7 @@ func TestIsKnownTransitIP(t *testing.T) { unknownTip := netip.MustParseAddr("100.64.0.42") c := newConn25(t.Logf) - c.client.assignments.insert(addrs{ + c.client.assignments.insert(&addrs{ transit: knownTip, }) @@ -1755,7 +2272,7 @@ func TestLinkLocalAllow(t *testing.T) { knownTip := netip.MustParseAddr("100.64.0.41") c := newConn25(t.Logf) - c.client.assignments.insert(addrs{ + c.client.assignments.insert(&addrs{ transit: knownTip, }) @@ -1812,10 +2329,9 @@ func TestGetMagicRange(t *testing.T) { V4MagicIPPool: []netipx.IPRange{netipx.IPRangeFrom(netip.MustParseAddr("0.0.0.1"), netip.MustParseAddr("0.0.0.3"))}, V6MagicIPPool: []netipx.IPRange{netipx.IPRangeFrom(netip.MustParseAddr("::1"), netip.MustParseAddr("::3"))}, }}, []string{}) + cfg := mustConfig(t, sn) c := newConn25(t.Logf) - if err := c.reconfig(sn); err != nil { - t.Fatal(err) - } + c.reconfig(cfg) ext := &extension{ conn25: c, } @@ -1853,3 +2369,69 @@ func TestGetMagicRange(t *testing.T) { } } } + +func TestReconfigDoesNotReissueInUseAddresses(t *testing.T) { + appName := "app1" + mustRange := func(from, to string) netipx.IPRange { + return netipx.IPRangeFrom(netip.MustParseAddr(from), netip.MustParseAddr(to)) + } + beforeRangeV4 := mustRange("0.0.0.1", "0.0.0.3") + beforeRangeV6 := mustRange("::1", "::3") + afterRangeV4 := mustRange("0.0.0.4", "0.0.0.7") + afterRangeV6 := mustRange("::4", "::7") + makeNodeFromMagicRange := func(v4, v6 netipx.IPRange) tailcfg.NodeView { + return makeSelfNode(t, []appctype.Conn25Attr{{ + Name: appName, + Connectors: []string{"tag:woo"}, + Domains: []string{"example.com"}, + V4MagicIPPool: []netipx.IPRange{v4}, + V6MagicIPPool: []netipx.IPRange{v6}, + V4TransitIPPool: []netipx.IPRange{mustRange("169.254.0.0", "169.254.0.10")}, + V6TransitIPPool: []netipx.IPRange{mustRange("fd7a:115c:a1e0:a99c:0200::", "fd7a:115c:a1e0:a99c:0200::10")}, + }}, []string{}) + } + domain := must.Get(dnsname.ToFQDN("example.com.")) + + for _, tt := range []struct { + name string + dstOne netip.Addr + dstTwo netip.Addr + }{ + { + name: "v4", + dstOne: netip.MustParseAddr("0.0.0.100"), + dstTwo: netip.MustParseAddr("0.0.0.101"), + }, + { + name: "v6", + dstOne: netip.MustParseAddr("::100"), + dstTwo: netip.MustParseAddr("::101"), + }, + } { + t.Run(tt.name, func(t *testing.T) { + c := newConn25(t.Logf) + ext := &extension{ + conn25: c, + } + + _, err := c.client.reserveAddresses(appName, domain, tt.dstOne) + if !errors.Is(err, errUninitializedIPPool) { + t.Fatalf("want %v, got %v", errUninitializedIPPool, err) + } + + ext.onSelfChange(makeNodeFromMagicRange(beforeRangeV4, beforeRangeV6)) + beforeAddrs, err := c.client.reserveAddresses(appName, domain, tt.dstOne) + if err != nil { + t.Fatal(err) + } + ext.onSelfChange(makeNodeFromMagicRange(afterRangeV4, afterRangeV6)) + afterAddrs, err := c.client.reserveAddresses(appName, domain, tt.dstTwo) + if err != nil { + t.Fatal(err) + } + if afterAddrs.magic == beforeAddrs.magic { + t.Errorf("pool reissued magic: %v that was already assigned", beforeAddrs.magic) + } + }) + } +} diff --git a/feature/conn25/ippool.go b/feature/conn25/ippool.go index e50186d88..72149cbee 100644 --- a/feature/conn25/ippool.go +++ b/feature/conn25/ippool.go @@ -8,17 +8,27 @@ import ( "net/netip" "go4.org/netipx" + "tailscale.com/util/set" ) // errPoolExhausted is returned when there are no more addresses to iterate over. var errPoolExhausted = errors.New("ip pool exhausted") -// ippool allows for iteration over all the addresses within a netipx.IPSet. +// errNotOurAddress is returned if a provided address is not from our pool +var errNotOurAddress = errors.New("not our address") + +// errAddrExists is returned if a returned address is already in the returned pool. +var errAddrExists = errors.New("address already returned") + +// errUninitializedIPPool is returned if the pool is used when it's not initialized +var errUninitializedIPPool = errors.New("uninitialized ippool") + +// ipSetIterator allows for round robin iteration over all the addresses within a netipx.IPSet. // netipx.IPSet has a Ranges call that returns the "minimum and sorted set of IP ranges that covers [the set]". // netipx.IPRange is "an inclusive range of IP addresses from the same address family.". So we can iterate over // all the addresses in the set by keeping a track of the last address we returned, calling Next on the last address -// to get the new one, and if we run off the edge of the current range, starting on the next one. -type ippool struct { +// to get the new one, and if we run off the edge of the current range, starting on the next one, or back at the beginning. +type ipSetIterator struct { // ranges defines the addresses in the pool ranges []netipx.IPRange // last is internal tracking of which the last address provided was. @@ -27,35 +37,106 @@ type ippool struct { rangeIdx int } +// next returns the next address from the set. +func (ipsi *ipSetIterator) next() (netip.Addr, error) { + if len(ipsi.ranges) == 0 { + // ipset is empty + return netip.Addr{}, errPoolExhausted + } + if !ipsi.last.IsValid() { + // not initialized yet + ipsi.last = ipsi.ranges[0].From() + return ipsi.last, nil + } + currRange := ipsi.ranges[ipsi.rangeIdx] + if ipsi.last == currRange.To() { + // then we need to move to the next range + ipsi.rangeIdx++ + if ipsi.rangeIdx >= len(ipsi.ranges) { + // back to the beginning + ipsi.rangeIdx = 0 + } + ipsi.last = ipsi.ranges[ipsi.rangeIdx].From() + return ipsi.last, nil + } + ipsi.last = ipsi.last.Next() + return ipsi.last, nil +} + func newIPPool(ipset *netipx.IPSet) *ippool { if ipset == nil { return &ippool{} } - return &ippool{ranges: ipset.Ranges()} + return &ippool{ + ipSet: ipset, + ipSetIterator: &ipSetIterator{ranges: ipset.Ranges()}, + inUse: &set.Set[netip.Addr]{}, + } } -// next returns the next address from the set, or errPoolExhausted if we have -// iterated over the whole set. +type ippool struct { + // ipSet defines the addresses within the ippool, it is configured by the user. + ipSet *netipx.IPSet + // ipSetIterator keeps track of iteration through the ippool. + ipSetIterator *ipSetIterator + // inUse is a set of addresses that have been handed out and not yet returned. + // Addresses in inUse won't be returned from next. + // Addresses in inUse may no longer be in the ipSet definition of the pool bounds + // if the ippool has been reconfigured. + inUse *set.Set[netip.Addr] +} + +// next returns the next available address from within the ippool. +// next will return errPoolExhausted if there are no more unused addresses. func (ipp *ippool) next() (netip.Addr, error) { - if ipp.rangeIdx >= len(ipp.ranges) { - // ipset is empty or we have iterated off the end - return netip.Addr{}, errPoolExhausted + if ipp == nil || ipp.ipSetIterator == nil { + return netip.Addr{}, errUninitializedIPPool } - if !ipp.last.IsValid() { - // not initialized yet - ipp.last = ipp.ranges[0].From() - return ipp.last, nil + a, err := ipp.ipSetIterator.next() + if err != nil { + return netip.Addr{}, err } - currRange := ipp.ranges[ipp.rangeIdx] - if ipp.last == currRange.To() { - // then we need to move to the next range - ipp.rangeIdx++ - if ipp.rangeIdx >= len(ipp.ranges) { + startedAt := a + for ipp.inUse.Contains(a) { + a, err = ipp.ipSetIterator.next() + if err != nil { + return a, err + } + if a == startedAt { return netip.Addr{}, errPoolExhausted } - ipp.last = ipp.ranges[ipp.rangeIdx].From() - return ipp.last, nil } - ipp.last = ipp.last.Next() - return ipp.last, nil + ipp.inUse.Add(a) + return a, nil +} + +// returnAddr puts an address back into the ippool, that address will +// now be available to be handed out when we iterate back around to it. +// returnAddr will return an error if the provided address is not one +// that's currently in inUse. +func (ipp *ippool) returnAddr(a netip.Addr) error { + if ipp.inUse.Contains(a) { + ipp.inUse.Delete(a) + return nil + } + if !ipp.ipSet.Contains(a) { + return errNotOurAddress + } + return errAddrExists +} + +// reconfig changes the definition of the addresses that are in the ippool +// while keeping track of the addresses that are currently in inUse. +func (ipp *ippool) reconfig(ipSet *netipx.IPSet) *ippool { + if ipp != nil && ipSet != nil && ipSet.Equal(ipp.ipSet) { + // in the common case that the definition has not changed, do nothing. + return ipp + } + newPool := newIPPool(ipSet) + if ipp != nil { + // even if the definition of which addresses are in the pool has changed + // we don't want to lose track of which addresses are currently in use + newPool.inUse = ipp.inUse + } + return newPool } diff --git a/feature/conn25/ippool_test.go b/feature/conn25/ippool_test.go index ccfaad3eb..1cc9845f5 100644 --- a/feature/conn25/ippool_test.go +++ b/feature/conn25/ippool_test.go @@ -13,7 +13,7 @@ import ( ) func TestNext(t *testing.T) { - a := ippool{} + a := ipSetIterator{} _, err := a.next() if !errors.Is(err, errPoolExhausted) { t.Fatalf("expected errPoolExhausted, got %v", err) @@ -58,3 +58,142 @@ func TestNext(t *testing.T) { t.Fatalf("expected errPoolExhausted, got %v", err) } } + +// TestReturnAddr tests that if a pool is exhausted, an address can be returned to the +// pool, and then that address will be handed out again. +func TestReturnAddr(t *testing.T) { + addrString := "192.168.0.0" + // There's an IPPool with one address in it. + var isb netipx.IPSetBuilder + isb.AddRange(netipx.IPRangeFrom(netip.MustParseAddr(addrString), netip.MustParseAddr(addrString))) + ipset := must.Get(isb.IPSet()) + ipp := newIPPool(ipset) + // The first time we call next we get the address. + addr, err := ipp.next() + if err != nil { + t.Fatalf("expected nil error, got: %v", err) + } + if addr != netip.MustParseAddr(addrString) { + t.Fatalf("want %v, got %v", addrString, addr) + } + // The second time we call next we get errPoolExhausted + _, err = ipp.next() + if !errors.Is(err, errPoolExhausted) { + t.Fatalf("expected errPoolExhausted, got %v", err) + } + // Return the addr to the pool + err = ipp.returnAddr(netip.MustParseAddr(addrString)) + if err != nil { + t.Fatal(err) + } + // It's not possible to return addresses that are already in the pool. + err = ipp.returnAddr(netip.MustParseAddr(addrString)) + if !errors.Is(err, errAddrExists) { + t.Fatalf("want errAddrExists, got: %v", err) + } + // When we call next we get the returned addr + addrAfterReturn, err := ipp.next() + if err != nil { + t.Fatalf("expected nil error, got: %v", err) + } + if addrAfterReturn != netip.MustParseAddr(addrString) { + t.Fatalf("want %v, got %v", addrString, addrAfterReturn) + } + // You can't return addresses that aren't from the pool. + err = ipp.returnAddr(netip.MustParseAddr("100.100.100.0")) + if !errors.Is(err, errNotOurAddress) { + t.Fatalf("want errNotOurAddress, got: %v", err) + } +} + +func expectAddrNext(t *testing.T, ipp *ippool, addrString string) { + t.Helper() + got, err := ipp.next() + if err != nil { + t.Fatalf("expected nil error, got: %v", err) + } + want := netip.MustParseAddr(addrString) + if want != got { + t.Fatalf("want %v; got %v", want, got) + } +} + +func expectErrPoolExhaustedNext(t *testing.T, ipp *ippool) { + t.Helper() + _, err := ipp.next() + if !errors.Is(err, errPoolExhausted) { + t.Fatalf("expected errPoolExhausted; got %v", err) + } +} + +// TestGettingReturnedAddresses tests that when addresses are returned to the IP Pool +// they are then handed out in the order they were returned. +func TestGettingReturnedAddresses(t *testing.T) { + var isb netipx.IPSetBuilder + isb.AddRange(netipx.IPRangeFrom(netip.MustParseAddr("192.168.0.0"), netip.MustParseAddr("192.168.0.4"))) + ipset := must.Get(isb.IPSet()) + ipp := newIPPool(ipset) + expectAddrNext(t, ipp, "192.168.0.0") + expectAddrNext(t, ipp, "192.168.0.1") + expectAddrNext(t, ipp, "192.168.0.2") + expectAddrNext(t, ipp, "192.168.0.3") + expectAddrNext(t, ipp, "192.168.0.4") + expectErrPoolExhaustedNext(t, ipp) + ipp.returnAddr(netip.MustParseAddr("192.168.0.2")) + ipp.returnAddr(netip.MustParseAddr("192.168.0.4")) + expectAddrNext(t, ipp, "192.168.0.2") + expectAddrNext(t, ipp, "192.168.0.4") + expectErrPoolExhaustedNext(t, ipp) +} + +func TestIPPoolReconfig(t *testing.T) { + var isb netipx.IPSetBuilder + isb.AddRange(netipx.IPRangeFrom(netip.MustParseAddr("192.168.0.0"), netip.MustParseAddr("192.168.0.4"))) + ipsetOne := must.Get(isb.IPSet()) + ipsetOneClone := must.Get(isb.IPSet()) + isb = netipx.IPSetBuilder{} + isb.AddRange(netipx.IPRangeFrom(netip.MustParseAddr("192.168.0.7"), netip.MustParseAddr("192.168.0.10"))) + ipsetTwo := must.Get(isb.IPSet()) + + var ipp *ippool + ipp = ipp.reconfig(ipsetOne) + if ipp.ipSet != ipsetOne { + t.Fatalf("want %v, got %v", ipsetOne, ipp.ipSet) + } + expectAddrNext(t, ipp, "192.168.0.0") + + // check that we don't lose iterator state when we reconfig with the same ranges + expectAddrNext(t, ipp, "192.168.0.1") + ipp.returnAddr(netip.MustParseAddr("192.168.0.1")) + ipp = ipp.reconfig(ipsetOneClone) + expectAddrNext(t, ipp, "192.168.0.2") + + // when we reconfig with different ranges, we only hand out addresses from the new ranges + ipp = ipp.reconfig(ipsetTwo) + if ipp.ipSet != ipsetTwo { + t.Fatalf("want %v, got %v", ipsetTwo, ipp.ipSet) + } + expectAddrNext(t, ipp, "192.168.0.7") + expectAddrNext(t, ipp, "192.168.0.8") + expectAddrNext(t, ipp, "192.168.0.9") + expectAddrNext(t, ipp, "192.168.0.10") + expectErrPoolExhaustedNext(t, ipp) + + // but we have not lost track of the fact that the old addresses are in use + if !ipp.inUse.Contains(netip.MustParseAddr("192.168.0.0")) { + t.Fatalf("expected inUse to still have the address") + } + + // old addresses can be returned + ipp.returnAddr(netip.MustParseAddr("192.168.0.0")) + + // but they are not handed out again + expectErrPoolExhaustedNext(t, ipp) + if ipp.inUse.Contains(netip.MustParseAddr("192.168.0.0")) { + t.Fatalf("expected inUse to no longer have the address") + } + + // returning addresses from the new ranges works as normal + ipp.returnAddr(netip.MustParseAddr("192.168.0.9")) + expectAddrNext(t, ipp, "192.168.0.9") +} diff --git a/feature/doctor/doctor.go b/feature/doctor/doctor.go index db061311b..01897f0a6 100644 --- a/feature/doctor/doctor.go +++ b/feature/doctor/doctor.go @@ -63,7 +63,7 @@ func visitDoctor(ctx context.Context, b *ipnlocal.LocalBackend, logf logger.Logf // IPs; this can interfere with our ability to connect to the Tailscale // controlplane. checks = append(checks, doctor.CheckFunc("dns-resolvers", func(_ context.Context, logf logger.Logf) error { - nm := b.NetMap() + nm := b.NetMapNoPeers() if nm == nil { return nil } diff --git a/feature/featuretags/featuretags.go b/feature/featuretags/featuretags.go index 4f34acbe8..2c12b6960 100644 --- a/feature/featuretags/featuretags.go +++ b/feature/featuretags/featuretags.go @@ -171,7 +171,6 @@ var Features = map[FeatureTag]FeatureMeta{ "ipnbus": {Sym: "IPNBus", Desc: "IPN notification bus (watch-ipn-bus) support, used by GUIs, debugging, and nicer 'tailscale up' support"}, "iptables": {Sym: "IPTables", Desc: "Linux iptables support"}, "kube": {Sym: "Kube", Desc: "Kubernetes integration"}, - "lazywg": {Sym: "LazyWG", Desc: "Lazy WireGuard configuration for memory-constrained devices with large netmaps"}, "linuxdnsfight": {Sym: "LinuxDNSFight", Desc: "Linux support for detecting DNS fights (inotify watching of /etc/resolv.conf)"}, "linkspeed": { Sym: "LinkSpeed", @@ -236,6 +235,10 @@ var Features = map[FeatureTag]FeatureMeta{ Desc: "Linux systemd-resolved integration", Deps: []FeatureTag{"dbus"}, }, + "routecheck": { + Sym: "RouteCheck", + Desc: "Support checking the reachability of overlapping routers, for choosing between multiple network paths to the same IP address", + }, "sdnotify": { Sym: "SDNotify", Desc: "systemd notification support", diff --git a/feature/posture/posture.go b/feature/posture/posture.go index d8db1ac19..0c60d38b0 100644 --- a/feature/posture/posture.go +++ b/feature/posture/posture.go @@ -8,8 +8,10 @@ package posture import ( "encoding/json" + "fmt" "net/http" + "tailscale.com/health" "tailscale.com/ipn/ipnext" "tailscale.com/ipn/ipnlocal" "tailscale.com/posture" @@ -25,6 +27,15 @@ func init() { ipnlocal.RegisterC2N("GET /posture/identity", handleC2NPostureIdentityGet) } +var postureSerialWarnable = health.Register(&health.Warnable{ + Code: "posture-checking-serial-collection-failed", + Title: "Device Posture: serial number collection failed", + Severity: health.SeverityMedium, + Text: func(args health.Args) string { + return fmt.Sprintf("Could not collect device serial numbers for posture checking. (%v)", args[health.ArgError]) + }, +}) + func newExtension(logf logger.Logf, b ipnext.SafeBackend) (ipnext.Extension, error) { e := &extension{ logf: logger.WithPrefix(logf, "posture: "), @@ -73,6 +84,9 @@ func handleC2NPostureIdentityGet(b *ipnlocal.LocalBackend, w http.ResponseWriter res.SerialNumbers, err = posture.GetSerialNumbers(b.PolicyClient(), e.logf) if err != nil { e.logf("c2n: GetSerialNumbers returned error: %v", err) + b.HealthTracker().SetUnhealthy(postureSerialWarnable, health.Args{health.ArgError: err.Error()}) + } else { + b.HealthTracker().SetHealthy(postureSerialWarnable) } // TODO(tailscale/corp#21371, 2024-07-10): once this has landed in a stable release diff --git a/feature/routecheck/routecheck.go b/feature/routecheck/routecheck.go new file mode 100644 index 000000000..055ceb379 --- /dev/null +++ b/feature/routecheck/routecheck.go @@ -0,0 +1,17 @@ +// Copyright (c) Tailscale Inc & contributors +// SPDX-License-Identifier: BSD-3-Clause + +// Package routecheck registers support for RouteCheck, +// which checks the reachability of overlapping routers. +// +// When there are multiple network paths to an IP address, it is being routed by +// overlapping routers. The client uses reachability to pick between those +// paths: either sticking with an active WireGuard session or choosing from the +// peers that it has determined it can reach. It doesn’t need reachability for +// IP addresses that have only one network path, since it can naively attempt to +// establish a WireGuard session. +package routecheck + +func init() { + // TODO(sfllaw): Initialize the new routecheck package. +} diff --git a/feature/taildrop/doc.go b/feature/taildrop/doc.go index c394ebe82..a3243b3c2 100644 --- a/feature/taildrop/doc.go +++ b/feature/taildrop/doc.go @@ -1,5 +1,10 @@ // Copyright (c) Tailscale Inc & contributors // SPDX-License-Identifier: BSD-3-Clause -// Package taildrop registers the taildrop (file sending) feature. +// Package taildrop contains the implementation of the Taildrop +// functionality including sending and retrieving files. +// This package does not validate permissions, the caller should +// be responsible for ensuring correct authorization. +// +// For related documentation see: http://go/taildrop-how-does-it-work package taildrop diff --git a/feature/taildrop/taildrop.go b/feature/taildrop/taildrop.go index 7042ca97a..9839b8330 100644 --- a/feature/taildrop/taildrop.go +++ b/feature/taildrop/taildrop.go @@ -1,12 +1,6 @@ // Copyright (c) Tailscale Inc & contributors // SPDX-License-Identifier: BSD-3-Clause -// Package taildrop contains the implementation of the Taildrop -// functionality including sending and retrieving files. -// This package does not validate permissions, the caller should -// be responsible for ensuring correct authorization. -// -// For related documentation see: http://go/taildrop-how-does-it-work package taildrop import ( diff --git a/feature/tailnetlock/tailnetlock.go b/feature/tailnetlock/tailnetlock.go new file mode 100644 index 000000000..325a13b08 --- /dev/null +++ b/feature/tailnetlock/tailnetlock.go @@ -0,0 +1,54 @@ +// Copyright (c) Tailscale Inc & contributors +// SPDX-License-Identifier: BSD-3-Clause + +// package tailnetlock registers the tailnet lock debug C2N handler. In the +// future, all tailnet lock code should move here. +package tailnetlock + +import ( + "fmt" + "net/http" + "strconv" + + "tailscale.com/cmd/tailscale/cli/jsonoutput" + "tailscale.com/feature" + "tailscale.com/feature/buildfeatures" + "tailscale.com/ipn/ipnlocal" +) + +func init() { + feature.Register("tailnetlock") + ipnlocal.RegisterC2N("/debug/tka/log", handleC2NDebugTKALog) +} + +const defaultC2NLogLimit = 50 +const maxC2NLogLimit = 1000 + +func handleC2NDebugTKALog(b *ipnlocal.LocalBackend, w http.ResponseWriter, r *http.Request) { + if !buildfeatures.HasDebug { + http.Error(w, feature.ErrUnavailable.Error(), http.StatusNotImplemented) + return + } + + logf := b.Logger() + logf("c2n: %s %s received", r.Method, r.URL) + + limit := defaultC2NLogLimit + limitStr := r.URL.Query().Get("limit") + if limitStr != "" { + if parsed, err := strconv.Atoi(limitStr); err == nil { + limit = min(parsed, maxC2NLogLimit) + } + } + + updates, err := b.NetworkLockLog(limit) + if ipnlocal.IsNetworkLockNotActive(err) { + http.Error(w, "tailnet lock not active", http.StatusBadRequest) + return + } else if err != nil { + http.Error(w, fmt.Sprintf("failed to get tailnet lock log: %v", err), http.StatusInternalServerError) + return + } + w.Header().Set("Content-Type", "application/json") + jsonoutput.PrintNetworkLockLogJSONV1(w, updates) +} diff --git a/feature/tailnetlock/tailnetlock_test.go b/feature/tailnetlock/tailnetlock_test.go new file mode 100644 index 000000000..bad294109 --- /dev/null +++ b/feature/tailnetlock/tailnetlock_test.go @@ -0,0 +1,143 @@ +// Copyright (c) Tailscale Inc & contributors +// SPDX-License-Identifier: BSD-3-Clause + +package tailnetlock + +import ( + "bytes" + "encoding/json" + "net/http/httptest" + "strings" + "testing" + + "tailscale.com/ipn/ipnlocal" + "tailscale.com/tka" + "tailscale.com/types/key" + "tailscale.com/util/must" +) + +func TestHandleC2NDebugTKA(t *testing.T) { + makeTKA := func(length int) (tka.CompactableChonk, *tka.Authority) { + if length == 0 { + return nil, nil + } + + signerKey := key.NewNLPrivate() + key1 := tka.Key{Kind: tka.Key25519, Public: signerKey.Public().Verifier(), Votes: 2} + state := tka.CreateStateForTest(key1) + + chonk := tka.ChonkMem() + authority, _, err := tka.Create(chonk, state, signerKey) + if err != nil { + t.Fatalf("tka.Create() failed: %v", err) + } + + for range length - 1 { + updater := authority.NewUpdater(signerKey) + key2 := tka.Key{Kind: tka.Key25519, Public: key.NewNLPrivate().Public().Verifier(), Votes: 2} + updater.AddKey(key2) + aums := must.Get(updater.Finalize(chonk)) + must.Do(authority.Inform(chonk, aums)) + } + + return chonk, authority + } + + bodyHead := func(body *bytes.Buffer) string { + count := 0 + var sb strings.Builder + for line := range strings.Lines(body.String()) { + if count == 10 { + sb.WriteString("...") + break + } + sb.WriteString(line) + count++ + } + return sb.String() + } + + // matches [jsonoutput.PrintNetworkLockLogJSONV1] + type response struct { + SchemaVersion string + Messages []any + } + + t.Run("tailnet-lock-disabled", func(t *testing.T) { + b := ipnlocal.LocalBackendWithTKAForTest(nil, nil) + + req := httptest.NewRequest("GET", "/debug/tka/log", nil) + rec := httptest.NewRecorder() + b.HandleC2NForTest(rec, req) + + if rec.Code != 400 { + t.Fatalf("got status code: %v, want: 400\nBody: %s", rec.Code, rec.Body) + } + }) + + t.Run("tailnet-lock-enabled", func(t *testing.T) { + chonk, authority := makeTKA(2) + b := ipnlocal.LocalBackendWithTKAForTest(chonk, authority) + + req := httptest.NewRequest("GET", "/debug/tka/log", nil) + rec := httptest.NewRecorder() + b.HandleC2NForTest(rec, req) + + if rec.Code != 200 { + t.Fatalf("got status code: %v, want: 200\nBody: %s", rec.Code, bodyHead(rec.Body)) + } + + var got response + if err := json.Unmarshal(rec.Body.Bytes(), &got); err != nil { + t.Fatalf("couldn't parse JSON: %v\nbody: %s", err, bodyHead(rec.Body)) + } + + if len(got.Messages) != 2 { + t.Fatalf("got %d items, want 2", len(got.Messages)) + } + }) + + t.Run("default-limit", func(t *testing.T) { + chonk, authority := makeTKA(60) + b := ipnlocal.LocalBackendWithTKAForTest(chonk, authority) + + req := httptest.NewRequest("GET", "/debug/tka/log", nil) + rec := httptest.NewRecorder() + b.HandleC2NForTest(rec, req) + + if rec.Code != 200 { + t.Fatalf("got status code: %v, want: 200\nBody: %s", rec.Code, bodyHead(rec.Body)) + } + + var got response + if err := json.Unmarshal(rec.Body.Bytes(), &got); err != nil { + t.Fatalf("couldn't parse JSON: %v\nbody: %s", err, bodyHead(rec.Body)) + } + + if len(got.Messages) != 50 { + t.Fatalf("got %d items, want 50", len(got.Messages)) + } + }) + + t.Run("override-limit", func(t *testing.T) { + chonk, authority := makeTKA(65) + b := ipnlocal.LocalBackendWithTKAForTest(chonk, authority) + + req := httptest.NewRequest("GET", "/debug/tka/log?limit=60", nil) + rec := httptest.NewRecorder() + b.HandleC2NForTest(rec, req) + + if rec.Code != 200 { + t.Fatalf("got status code: %v, want: 200\nBody: %s", rec.Code, bodyHead(rec.Body)) + } + + var got response + if err := json.Unmarshal(rec.Body.Bytes(), &got); err != nil { + t.Fatalf("couldn't parse JSON: %v\nbody: %s", err, bodyHead(rec.Body)) + } + + if len(got.Messages) != 60 { + t.Fatalf("got %d items, want 60", len(got.Messages)) + } + }) +} diff --git a/flake.nix b/flake.nix index 5c0fdbb2e..e2c237f42 100644 --- a/flake.nix +++ b/flake.nix @@ -48,7 +48,8 @@ }: let goVersion = nixpkgs.lib.fileContents ./go.toolchain.version; toolChainRev = nixpkgs.lib.fileContents ./go.toolchain.rev; - gitHash = nixpkgs.lib.fileContents ./go.toolchain.rev.sri; + flakeHashes = builtins.fromJSON (builtins.readFile ./flakehashes.json); + gitHash = flakeHashes.toolchain.sri; eachSystem = f: nixpkgs.lib.genAttrs (import systems) (system: f (import nixpkgs { @@ -103,7 +104,7 @@ name = "tailscale"; pname = "tailscale"; src = ./.; - vendorHash = pkgs.lib.fileContents ./go.mod.sri; + vendorHash = flakeHashes.vendor.sri; nativeBuildInputs = [pkgs.makeWrapper pkgs.installShellFiles]; ldflags = ["-X tailscale.com/version.gitCommitStamp=${tailscaleRev}"]; env.CGO_ENABLED = 0; @@ -163,4 +164,4 @@ }); }; } -# nix-direnv cache busting line: sha256-aZkUnWyQokNw+lxut9Fak3CazmwYE4tXILhzfK4jeK4= +# nix-direnv cache busting line: sha256-Xwm+ZLNqd2k7c2GFQJ2Pf/xuFLMcXhYl5I/YVgS9V4U= diff --git a/flakehashes.json b/flakehashes.json new file mode 100644 index 000000000..b5f6234e2 --- /dev/null +++ b/flakehashes.json @@ -0,0 +1,10 @@ +{ + "toolchain": { + "rev": "e877d973840c91ec9d4bc1921b0845789de359ae", + "sri": "sha256-HeD70CytKL0Ks/VDqMU73bN8fxpWkNc6mNgNr9PEO7k=" + }, + "vendor": { + "goModSum": "sha256-qAO4LAc1PwV43rr/kDsfYwkxeXAelP5DoNSZiCkwcpU=", + "sri": "sha256-Xwm+ZLNqd2k7c2GFQJ2Pf/xuFLMcXhYl5I/YVgS9V4U=" + } +} diff --git a/go.mod b/go.mod index 5c421e167..89038db40 100644 --- a/go.mod +++ b/go.mod @@ -1,6 +1,6 @@ module tailscale.com -go 1.26.2 +go 1.26.3 require ( filippo.io/mkcert v1.4.4 @@ -17,7 +17,7 @@ require ( github.com/aws/aws-sdk-go-v2/service/ssm v1.44.7 github.com/axiomhq/hyperloglog v0.0.0-20240319100328-84253e514e02 github.com/bradfitz/go-tool-cache v0.0.0-20260216153636-9e5201344fe5 - github.com/bradfitz/monogok v0.0.0-20260310223834-65a3d9465088 + github.com/bradfitz/monogok v0.0.0-20260429173803-229ef7981a6b github.com/bramvdbogaerde/go-scp v1.4.0 github.com/cilium/ebpf v0.16.0 github.com/coder/websocket v1.8.12 @@ -41,16 +41,18 @@ require ( github.com/go-json-experiment/json v0.0.0-20250813024750-ebf49471dced github.com/go-logr/zapr v1.3.0 github.com/go-ole/go-ole v1.3.0 + github.com/go4org/hashtriemap v0.0.0-20251130024219-545ba229f689 github.com/go4org/plan9netshell v0.0.0-20250324183649-788daa080737 github.com/godbus/dbus/v5 v5.1.1-0.20230522191255-76236955d466 github.com/gokrazy/breakglass v0.0.0-20251229072214-9dbc0478d486 - github.com/gokrazy/gokrazy v0.0.0-20260123094004-294c93fa173c + github.com/gokrazy/gokrazy v0.0.0-20260418085648-c38c3134b8a7 + github.com/gokrazy/kernel.arm64 v0.0.0-20260403054012-807489e0272a github.com/gokrazy/serial-busybox v0.0.0-20250119153030-ac58ba7574e7 github.com/golang/groupcache v0.0.0-20241129210726-2c02b8208cf8 github.com/golang/snappy v0.0.4 github.com/golangci/golangci-lint v1.57.1 github.com/google/go-cmp v0.7.0 - github.com/google/go-containerregistry v0.20.7 + github.com/google/go-containerregistry v0.21.5 github.com/google/go-tpm v0.9.4 github.com/google/gopacket v1.1.19 github.com/google/nftables v0.2.1-0.20240414091927-5e242ec57806 @@ -68,7 +70,7 @@ require ( github.com/jsimonetti/rtnetlink v1.4.0 github.com/kballard/go-shellquote v0.0.0-20180428030007-95032a82bc51 github.com/kdomanski/iso9660 v0.4.0 - github.com/klauspost/compress v1.18.2 + github.com/klauspost/compress v1.18.5 github.com/kortschak/wol v0.0.0-20200729010619-da482cc4850a github.com/mattn/go-colorable v0.1.13 github.com/mattn/go-isatty v0.0.20 @@ -81,14 +83,14 @@ require ( github.com/pires/go-proxyproto v0.8.1 github.com/pkg/errors v0.9.1 github.com/pkg/sftp v1.13.6 - github.com/prometheus-community/pro-bing v0.4.0 github.com/prometheus/client_golang v1.23.0 github.com/prometheus/common v0.65.0 github.com/prometheus/prometheus v0.49.2-0.20240125131847-c3b8ef1694ff + github.com/robert-nix/ansihtml v1.0.1 github.com/safchain/ethtool v0.3.0 github.com/skip2/go-qrcode v0.0.0-20200617195104-da1b6568686e github.com/studio-b12/gowebdav v0.9.0 - github.com/tailscale/certstore v0.1.1-0.20231202035212-d3fa0460f47e + github.com/tailscale/certstore v0.1.1-0.20260409135935-3638fb84b77d github.com/tailscale/depaware v0.0.0-20251001183927-9c2ad255ef3f github.com/tailscale/gliderssh v0.3.4-0.20260330083525-c1389c70ff89 github.com/tailscale/goexpect v0.0.0-20210902213824-6e8c725cea41 @@ -99,9 +101,10 @@ require ( github.com/tailscale/netlink v1.1.1-0.20240822203006-4d49adab4de7 github.com/tailscale/peercred v0.0.0-20250107143737-35a0c7bd7edc github.com/tailscale/setec v0.0.0-20251203133219-2ab774e4129a + github.com/tailscale/ts-gokrazy v0.0.0-20260429180033-fe741c6deb44 github.com/tailscale/web-client-prebuilt v0.0.0-20250124233751-d4cd19a26976 github.com/tailscale/wf v0.0.0-20240214030419-6fbb0a674ee6 - github.com/tailscale/wireguard-go v0.0.0-20260304043104-4184faf59e56 + github.com/tailscale/wireguard-go v0.0.0-20260427181203-e3ac4a0afb4e github.com/tailscale/xnet v0.0.0-20240729143630-8497ac4dab2e github.com/tc-hib/winres v0.2.1 github.com/tcnksm/go-httpstat v0.2.0 @@ -111,16 +114,16 @@ require ( go.uber.org/zap v1.27.0 go4.org/mem v0.0.0-20240501181205-ae6ca9944745 go4.org/netipx v0.0.0-20231129151722-fdeea329fbba - golang.org/x/crypto v0.46.0 + golang.org/x/crypto v0.50.0 golang.org/x/exp v0.0.0-20250620022241-b7579e27df2b - golang.org/x/mod v0.31.0 - golang.org/x/net v0.48.0 + golang.org/x/mod v0.35.0 + golang.org/x/net v0.53.0 golang.org/x/oauth2 v0.36.0 - golang.org/x/sync v0.19.0 - golang.org/x/sys v0.40.0 - golang.org/x/term v0.38.0 + golang.org/x/sync v0.20.0 + golang.org/x/sys v0.43.0 + golang.org/x/term v0.42.0 golang.org/x/time v0.12.0 - golang.org/x/tools v0.40.1-0.20260108161641-ca281cf95054 + golang.org/x/tools v0.44.0 golang.zx2c4.com/wintun v0.0.0-20230126152724-0fa3db229ce2 golang.zx2c4.com/wireguard/windows v0.5.3 gopkg.in/square/go-jose.v2 v2.6.0 @@ -174,7 +177,7 @@ require ( github.com/cyphar/filepath-securejoin v0.6.1 // indirect github.com/deckarep/golang-set/v2 v2.8.0 // indirect github.com/dgryski/go-metro v0.0.0-20180109044635-280f6062b5bc // indirect - github.com/docker/go-connections v0.5.0 // indirect + github.com/docker/go-connections v0.6.0 // indirect github.com/docker/go-events v0.0.0-20250808211157-605354379745 // indirect github.com/docker/go-units v0.5.0 // indirect github.com/evanphx/json-patch v5.9.11+incompatible // indirect @@ -194,7 +197,7 @@ require ( github.com/google/gnostic-models v0.7.0 // indirect github.com/google/go-github/v66 v66.0.0 // indirect github.com/google/go-querystring v1.1.0 // indirect - github.com/google/renameio/v2 v2.0.0 // indirect + github.com/google/renameio/v2 v2.0.2 // indirect github.com/google/shlex v0.0.0-20191202100458-e7afc7fbc510 // indirect github.com/gorilla/securecookie v1.1.2 // indirect github.com/gorilla/websocket v1.5.4-0.20250319132907-e064f32e3674 // indirect @@ -223,6 +226,8 @@ require ( github.com/mitchellh/go-wordwrap v1.0.1 // indirect github.com/moby/buildkit v0.20.2 // indirect github.com/moby/docker-image-spec v1.3.1 // indirect + github.com/moby/moby/api v1.54.1 // indirect + github.com/moby/moby/client v0.4.0 // indirect github.com/moby/spdystream v0.5.0 // indirect github.com/moby/term v0.5.2 // indirect github.com/monochromegane/go-gitignore v0.0.0-20200626010858-205db1a8cc00 // indirect @@ -237,7 +242,7 @@ require ( github.com/santhosh-tekuri/jsonschema/v5 v5.3.1 // indirect github.com/santhosh-tekuri/jsonschema/v6 v6.0.2 // indirect github.com/stacklok/frizbee v0.1.7 // indirect - github.com/vishvananda/netlink v1.3.1-0.20240922070040-084abd93d350 // indirect + github.com/vishvananda/netlink v1.3.1 // indirect github.com/xen0n/gosmopolitan v1.2.2 // indirect github.com/xlab/treeprint v1.2.0 // indirect github.com/ykadowak/zerologlint v0.1.5 // indirect @@ -253,7 +258,7 @@ require ( go.yaml.in/yaml/v2 v2.4.2 // indirect go.yaml.in/yaml/v3 v3.0.4 // indirect golang.org/x/crypto/x509roots/fallback v0.0.0-20260113154411-7d0074ccc6f1 // indirect - golang.org/x/telemetry v0.0.0-20251203150158-8fff8a5912fc // indirect + golang.org/x/telemetry v0.0.0-20260409153401-be6f6cb8b1fa // indirect golang.org/x/tools/go/packages/packagestest v0.1.1-deprecated // indirect golang.org/x/xerrors v0.0.0-20240716161551-93cc26a95ae9 // indirect google.golang.org/genproto/googleapis/api v0.0.0-20251213004720-97cd9d5aeac2 // indirect @@ -317,14 +322,12 @@ require ( github.com/charithe/durationcheck v0.0.10 // indirect github.com/chavacava/garif v0.1.0 // indirect github.com/cloudflare/circl v1.6.3 // indirect - github.com/containerd/stargz-snapshotter/estargz v0.18.1 // indirect + github.com/containerd/stargz-snapshotter/estargz v0.18.2 // indirect github.com/curioswitch/go-reassign v0.2.0 // indirect github.com/daixiang0/gci v0.12.3 // indirect github.com/davecgh/go-spew v1.1.2-0.20180830191138-d8f796af33cc // indirect github.com/denis-tingaikin/go-header v0.5.0 // indirect - github.com/docker/cli v29.0.3+incompatible // indirect - github.com/docker/distribution v2.8.3+incompatible // indirect - github.com/docker/docker v28.5.2+incompatible // indirect + github.com/docker/cli v29.4.0+incompatible // indirect github.com/docker/docker-credential-helpers v0.9.3 // indirect github.com/emicklei/go-restful/v3 v3.12.2 // indirect github.com/emirpasic/gods v1.18.1 // indirect @@ -337,8 +340,8 @@ require ( github.com/fzipp/gocyclo v0.6.0 // indirect github.com/go-critic/go-critic v0.11.2 // indirect github.com/go-git/gcfg v1.5.1-0.20230307220236-3a3c6141e376 // indirect - github.com/go-git/go-billy/v5 v5.6.2 // indirect - github.com/go-git/go-git/v5 v5.16.5 // indirect + github.com/go-git/go-billy/v5 v5.8.0 // indirect + github.com/go-git/go-git/v5 v5.17.1 // indirect github.com/go-logr/logr v1.4.3 // indirect github.com/go-openapi/jsonpointer v0.21.0 // indirect github.com/go-openapi/jsonreference v0.20.4 // indirect @@ -445,7 +448,7 @@ require ( github.com/sergi/go-diff v1.3.2-0.20230802210424-5b0b94c5c0d3 // indirect github.com/shazow/go-diff v0.0.0-20160112020656-b6b7b6733b8c // indirect github.com/shopspring/decimal v1.4.0 // indirect - github.com/sirupsen/logrus v1.9.3 // indirect + github.com/sirupsen/logrus v1.9.4 // indirect github.com/sivchari/containedctx v1.0.3 // indirect github.com/sivchari/tenv v1.7.1 // indirect github.com/skeema/knownhosts v1.3.1 // indirect @@ -455,7 +458,7 @@ require ( github.com/spf13/cast v1.7.0 // indirect github.com/spf13/cobra v1.10.2 // indirect github.com/spf13/jwalterweatherman v1.1.0 // indirect - github.com/spf13/pflag v1.0.9 // indirect + github.com/spf13/pflag v1.0.10 // indirect github.com/spf13/viper v1.16.0 // indirect github.com/ssgreg/nlreturn/v2 v2.2.1 // indirect github.com/stbenjam/no-sprintf-host-port v0.1.1 // indirect @@ -485,7 +488,7 @@ require ( go.uber.org/multierr v1.11.0 // indirect golang.org/x/exp/typeparams v0.0.0-20240314144324-c7f7c6466f7f // indirect golang.org/x/image v0.27.0 // indirect - golang.org/x/text v0.32.0 // indirect + golang.org/x/text v0.36.0 // indirect gomodules.xyz/jsonpatch/v2 v2.4.0 // indirect google.golang.org/protobuf v1.36.11 // indirect gopkg.in/inf.v0 v0.9.1 // indirect diff --git a/go.mod.sri b/go.mod.sri deleted file mode 100644 index 6ffc2c3f1..000000000 --- a/go.mod.sri +++ /dev/null @@ -1 +0,0 @@ -sha256-aZkUnWyQokNw+lxut9Fak3CazmwYE4tXILhzfK4jeK4= diff --git a/go.sum b/go.sum index cc0046798..b5e8950c3 100644 --- a/go.sum +++ b/go.sum @@ -205,8 +205,8 @@ github.com/bombsimon/wsl/v4 v4.2.1 h1:Cxg6u+XDWff75SIFFmNsqnIOgob+Q9hG6y/ioKbRFi github.com/bombsimon/wsl/v4 v4.2.1/go.mod h1:Xu/kDxGZTofQcDGCtQe9KCzhHphIe0fDuyWTxER9Feo= github.com/bradfitz/go-tool-cache v0.0.0-20260216153636-9e5201344fe5 h1:0sG3c7afYdBNlc3QyhckvZ4bV9iqlfqCQM1i+mWm0eE= github.com/bradfitz/go-tool-cache v0.0.0-20260216153636-9e5201344fe5/go.mod h1:78ZLITnBUCDJeU01+wYYJKaPYYgsDzJPRfxeI8qFh5g= -github.com/bradfitz/monogok v0.0.0-20260310223834-65a3d9465088 h1:dDVY5cJ+7bQQll29aeWGx1Ima4RIGy/f1fXVs+HlIxo= -github.com/bradfitz/monogok v0.0.0-20260310223834-65a3d9465088/go.mod h1:TG1HbU9fRVDnNgXncVkKz9GdvjIvqquXjH6QZSEVmY4= +github.com/bradfitz/monogok v0.0.0-20260429173803-229ef7981a6b h1:lhWZfi1U/yi8zuFA6pkJKYv45pVAC3xs6SUE2QsjsEE= +github.com/bradfitz/monogok v0.0.0-20260429173803-229ef7981a6b/go.mod h1:TG1HbU9fRVDnNgXncVkKz9GdvjIvqquXjH6QZSEVmY4= github.com/bramvdbogaerde/go-scp v1.4.0 h1:jKMwpwCbcX1KyvDbm/PDJuXcMuNVlLGi0Q0reuzjyKY= github.com/bramvdbogaerde/go-scp v1.4.0/go.mod h1:on2aH5AxaFb2G0N5Vsdy6B0Ml7k9HuHSwfo1y0QzAbQ= github.com/breml/bidichk v0.2.7 h1:dAkKQPLl/Qrk7hnP6P+E0xOodrq8Us7+U0o4UBOAlQY= @@ -268,8 +268,8 @@ github.com/containerd/log v0.1.0 h1:TCJt7ioM2cr/tfR8GPbGf9/VRAX8D2B4PjzCpfX540I= github.com/containerd/log v0.1.0/go.mod h1:VRRf09a7mHDIRezVKTRCrOq78v577GXq3bSa3EhrzVo= github.com/containerd/platforms v1.0.0-rc.2 h1:0SPgaNZPVWGEi4grZdV8VRYQn78y+nm6acgLGv/QzE4= github.com/containerd/platforms v1.0.0-rc.2/go.mod h1:J71L7B+aiM5SdIEqmd9wp6THLVRzJGXfNuWCZCllLA4= -github.com/containerd/stargz-snapshotter/estargz v0.18.1 h1:cy2/lpgBXDA3cDKSyEfNOFMA/c10O1axL69EU7iirO8= -github.com/containerd/stargz-snapshotter/estargz v0.18.1/go.mod h1:ALIEqa7B6oVDsrF37GkGN20SuvG/pIMm7FwP7ZmRb0Q= +github.com/containerd/stargz-snapshotter/estargz v0.18.2 h1:yXkZFYIzz3eoLwlTUZKz2iQ4MrckBxJjkmD16ynUTrw= +github.com/containerd/stargz-snapshotter/estargz v0.18.2/go.mod h1:XyVU5tcJ3PRpkA9XS2T5us6Eg35yM0214Y+wvrZTBrY= github.com/containerd/typeurl/v2 v2.2.3 h1:yNA/94zxWdvYACdYO8zofhrTVuQY73fFU1y++dYSw40= github.com/containerd/typeurl/v2 v2.2.3/go.mod h1:95ljDnPfD3bAbDJRugOiShd/DlAAsxGtUBhJxIn7SCk= github.com/coreos/go-iptables v0.7.1-0.20240112124308-65c67c9f46e6 h1:8h5+bWd7R6AYUslN6c6iuZWTKsKxUFDlpnmilO6R2n0= @@ -319,16 +319,12 @@ github.com/djherbis/times v1.6.0 h1:w2ctJ92J8fBvWPxugmXIv7Nz7Q3iDMKNx9v5ocVH20c= github.com/djherbis/times v1.6.0/go.mod h1:gOHeRAz2h+VJNZ5Gmc/o7iD9k4wW7NMVqieYCY99oc0= github.com/dlclark/regexp2 v1.11.0 h1:G/nrcoOa7ZXlpoa/91N3X7mM3r8eIlMBBJZvsz/mxKI= github.com/dlclark/regexp2 v1.11.0/go.mod h1:DHkYz0B9wPfa6wondMfaivmHpzrQ3v9q8cnmRbL6yW8= -github.com/docker/cli v29.0.3+incompatible h1:8J+PZIcF2xLd6h5sHPsp5pvvJA+Sr2wGQxHkRl53a1E= -github.com/docker/cli v29.0.3+incompatible/go.mod h1:JLrzqnKDaYBop7H2jaqPtU4hHvMKP+vjCwu2uszcLI8= -github.com/docker/distribution v2.8.3+incompatible h1:AtKxIZ36LoNK51+Z6RpzLpddBirtxJnzDrHLEKxTAYk= -github.com/docker/distribution v2.8.3+incompatible/go.mod h1:J2gT2udsDAN96Uj4KfcMRqY0/ypR+oyYUYmja8H+y+w= -github.com/docker/docker v28.5.2+incompatible h1:DBX0Y0zAjZbSrm1uzOkdr1onVghKaftjlSWt4AFexzM= -github.com/docker/docker v28.5.2+incompatible/go.mod h1:eEKB0N0r5NX/I1kEveEz05bcu8tLC/8azJZsviup8Sk= +github.com/docker/cli v29.4.0+incompatible h1:+IjXULMetlvWJiuSI0Nbor36lcJ5BTcVpUmB21KBoVM= +github.com/docker/cli v29.4.0+incompatible/go.mod h1:JLrzqnKDaYBop7H2jaqPtU4hHvMKP+vjCwu2uszcLI8= github.com/docker/docker-credential-helpers v0.9.3 h1:gAm/VtF9wgqJMoxzT3Gj5p4AqIjCBS4wrsOh9yRqcz8= github.com/docker/docker-credential-helpers v0.9.3/go.mod h1:x+4Gbw9aGmChi3qTLZj8Dfn0TD20M/fuWy0E5+WDeCo= -github.com/docker/go-connections v0.5.0 h1:USnMq7hx7gwdVZq1L49hLXaFtUdTADjXGp+uj1Br63c= -github.com/docker/go-connections v0.5.0/go.mod h1:ov60Kzw0kKElRwhNs9UlUHAE/F9Fe6GLaXnqyDdmEXc= +github.com/docker/go-connections v0.6.0 h1:LlMG9azAe1TqfR7sO+NJttz1gy6KO7VJBh+pMmjSD94= +github.com/docker/go-connections v0.6.0/go.mod h1:AahvXYshr6JgfUJGdDCs2b5EZG/vmaMAntpSFH5BFKE= github.com/docker/go-events v0.0.0-20250808211157-605354379745 h1:yOn6Ze6IbYI/KAw2lw/83ELYvZh6hvsygTVkD0dzMC4= github.com/docker/go-events v0.0.0-20250808211157-605354379745/go.mod h1:Uw6UezgYA44ePAFQYUehOuCzmy5zmg/+nl2ZfMWGkpA= github.com/docker/go-metrics v0.0.1 h1:AgB/0SvBxihN0X8OR4SjsblXkbMvalQ8cjmtKQ2rQV8= @@ -396,12 +392,12 @@ github.com/go-errors/errors v1.4.2 h1:J6MZopCL4uSllY1OfXM374weqZFFItUbrImctkmUxI github.com/go-errors/errors v1.4.2/go.mod h1:sIVyrIiJhuEF+Pj9Ebtd6P/rEYROXFi3BopGUQ5a5Og= github.com/go-git/gcfg v1.5.1-0.20230307220236-3a3c6141e376 h1:+zs/tPmkDkHx3U66DAb0lQFJrpS6731Oaa12ikc+DiI= github.com/go-git/gcfg v1.5.1-0.20230307220236-3a3c6141e376/go.mod h1:an3vInlBmSxCcxctByoQdvwPiA7DTK7jaaFDBTtu0ic= -github.com/go-git/go-billy/v5 v5.6.2 h1:6Q86EsPXMa7c3YZ3aLAQsMA0VlWmy43r6FHqa/UNbRM= -github.com/go-git/go-billy/v5 v5.6.2/go.mod h1:rcFC2rAsp/erv7CMz9GczHcuD0D32fWzH+MJAU+jaUU= +github.com/go-git/go-billy/v5 v5.8.0 h1:I8hjc3LbBlXTtVuFNJuwYuMiHvQJDq1AT6u4DwDzZG0= +github.com/go-git/go-billy/v5 v5.8.0/go.mod h1:RpvI/rw4Vr5QA+Z60c6d6LXH0rYJo0uD5SqfmrrheCY= github.com/go-git/go-git-fixtures/v4 v4.3.2-0.20231010084843-55a94097c399 h1:eMje31YglSBqCdIqdhKBW8lokaMrL3uTkpGYlE2OOT4= github.com/go-git/go-git-fixtures/v4 v4.3.2-0.20231010084843-55a94097c399/go.mod h1:1OCfN199q1Jm3HZlxleg+Dw/mwps2Wbk9frAWm+4FII= -github.com/go-git/go-git/v5 v5.16.5 h1:mdkuqblwr57kVfXri5TTH+nMFLNUxIj9Z7F5ykFbw5s= -github.com/go-git/go-git/v5 v5.16.5/go.mod h1:QOMLpNf1qxuSY4StA/ArOdfFR2TrKEjJiye2kel2m+M= +github.com/go-git/go-git/v5 v5.17.1 h1:WnljyxIzSj9BRRUlnmAU35ohDsjRK0EKmL0evDqi5Jk= +github.com/go-git/go-git/v5 v5.17.1/go.mod h1:pW/VmeqkanRFqR6AljLcs7EA7FbZaN5MQqO7oZADXpo= github.com/go-gl/glfw v0.0.0-20190409004039-e6da0acd62b1/go.mod h1:vR7hzQXu2zJy9AVAgeJqvqgH9Q5CA+iKCZ2gyEVpxRU= github.com/go-gl/glfw/v3.3/glfw v0.0.0-20191125211704-12ad95a8df72/go.mod h1:tQ2UAYgL5IevRw8kRxooKSPJfGvJ9fJQFa0TUsXzTg8= github.com/go-gl/glfw/v3.3/glfw v0.0.0-20200222043503-6f7a984d4dc4/go.mod h1:tQ2UAYgL5IevRw8kRxooKSPJfGvJ9fJQFa0TUsXzTg8= @@ -467,6 +463,8 @@ github.com/go-viper/mapstructure/v2 v2.4.0 h1:EBsztssimR/CONLSZZ04E8qAkxNYq4Qp9L github.com/go-viper/mapstructure/v2 v2.4.0/go.mod h1:oJDH3BJKyqBA2TXFhDsKDGDTlndYOZ6rGS0BRZIxGhM= github.com/go-xmlfmt/xmlfmt v1.1.2 h1:Nea7b4icn8s57fTx1M5AI4qQT5HEM3rVUO8MuE6g80U= github.com/go-xmlfmt/xmlfmt v1.1.2/go.mod h1:aUCEOzzezBEjDBbFBoSiya/gduyIiWYRP6CnSFIV8AM= +github.com/go4org/hashtriemap v0.0.0-20251130024219-545ba229f689 h1:0psnKZ+N2IP43/SZC8SKx6OpFJwLmQb9m9QyV9BC2f8= +github.com/go4org/hashtriemap v0.0.0-20251130024219-545ba229f689/go.mod h1:OGmRfY/9QEK2P5zCRtmqfbCF283xPkU2dvVA4MvbvpI= github.com/go4org/plan9netshell v0.0.0-20250324183649-788daa080737 h1:cf60tHxREO3g1nroKr2osU3JWZsJzkfi7rEg+oAB0Lo= github.com/go4org/plan9netshell v0.0.0-20250324183649-788daa080737/go.mod h1:MIS0jDzbU/vuM9MC4YnBITCv+RYuTRq8dJzmCrFsK9g= github.com/gobuffalo/flect v1.0.3 h1:xeWBM2nui+qnVvNM4S3foBhCAL2XgPU+a7FdpelbTq4= @@ -487,11 +485,13 @@ github.com/gokrazy/breakglass v0.0.0-20251229072214-9dbc0478d486/go.mod h1:PFPkR github.com/gokrazy/gokapi v0.0.0-20250222071133-506fdb322775 h1:f5+2UMRRbr3+e/gdWCBNn48chS/KMMljfbmlSSHfRBA= github.com/gokrazy/gokapi v0.0.0-20250222071133-506fdb322775/go.mod h1:q9mIV8al0wqmqFXJhKiO3SOHkL9/7Q4kIMynqUQWhgU= github.com/gokrazy/gokrazy v0.0.0-20200501080617-f3445e01a904/go.mod h1:pq6rGHqxMRPSaTXaCMzIZy0wLDusAJyoVNyNo05RLs0= -github.com/gokrazy/gokrazy v0.0.0-20260123094004-294c93fa173c h1:grjqEMf6dPJzZxf+gdo8rjx6bcyseO5p9hierlVkhXQ= -github.com/gokrazy/gokrazy v0.0.0-20260123094004-294c93fa173c/go.mod h1:NtMkrFeDGnwldKLi0dLdd2ipNwoVa7TI4HTxsy7lFRg= +github.com/gokrazy/gokrazy v0.0.0-20260418085648-c38c3134b8a7 h1:Isk3pOiVO5uj4BSrfRlQ16v6YpelnrTgMC618hEkKJ8= +github.com/gokrazy/gokrazy v0.0.0-20260418085648-c38c3134b8a7/go.mod h1:NtMkrFeDGnwldKLi0dLdd2ipNwoVa7TI4HTxsy7lFRg= github.com/gokrazy/internal v0.0.0-20200407075822-660ad467b7c9/go.mod h1:LA5TQy7LcvYGQOy75tkrYkFUhbV2nl5qEBP47PSi2JA= github.com/gokrazy/internal v0.0.0-20251208203110-3c1aa9087c82 h1:4ghNfD9NaZLpFrqQiBF6mPVFeMYXJSky38ubVA4ic2E= github.com/gokrazy/internal v0.0.0-20251208203110-3c1aa9087c82/go.mod h1:dQY4EMkD4L5ZjYJ0SPtpgYbV7MIUMCxNIXiOfnZ6jP4= +github.com/gokrazy/kernel.arm64 v0.0.0-20260403054012-807489e0272a h1:fa11POmSLo6fkkcqc+RUIyiqGJzBAOHEe/CCHAA/NGc= +github.com/gokrazy/kernel.arm64 v0.0.0-20260403054012-807489e0272a/go.mod h1:WWx72LXHEesuJxbopusRfSoKJQ6ffdwkT0DZditdrLo= github.com/gokrazy/serial-busybox v0.0.0-20250119153030-ac58ba7574e7 h1:gurTGc4sL7Ik+IKZ29rhGgHNZQTXPtEXLw+aM9E+/HE= github.com/gokrazy/serial-busybox v0.0.0-20250119153030-ac58ba7574e7/go.mod h1:OYcG5tSb+QrelmUOO4EZVUFcIHyyZb0QDbEbZFUp1TA= github.com/golang/freetype v0.0.0-20170609003504-e2365dfdc4a0 h1:DACJavvAHhabrF08vX0COfcOBJRhZ8lUbR+ZWIs0Y5g= @@ -564,8 +564,8 @@ github.com/google/go-cmp v0.5.8/go.mod h1:17dUlkBOakJ0+DkrSSNjCkIjxS6bF9zb3elmeN github.com/google/go-cmp v0.5.9/go.mod h1:17dUlkBOakJ0+DkrSSNjCkIjxS6bF9zb3elmeNGIjoY= github.com/google/go-cmp v0.7.0 h1:wk8382ETsv4JYUZwIsn6YpYiWiBsYLSJiTsyBybVuN8= github.com/google/go-cmp v0.7.0/go.mod h1:pXiqmnSA92OHEEa9HXL2W4E7lf9JzCmGVUdgjX3N/iU= -github.com/google/go-containerregistry v0.20.7 h1:24VGNpS0IwrOZ2ms2P1QE3Xa5X9p4phx0aUgzYzHW6I= -github.com/google/go-containerregistry v0.20.7/go.mod h1:Lx5LCZQjLH1QBaMPeGwsME9biPeo1lPx6lbGj/UmzgM= +github.com/google/go-containerregistry v0.21.5 h1:KTJG9Pn/jC0VdZR6ctV3/jcN+q6/Iqlx0sTVz3ywZlM= +github.com/google/go-containerregistry v0.21.5/go.mod h1:ySvMuiWg+dOsRW0Hw8GYwfMwBlNRTmpYBFJPlkco5zU= github.com/google/go-github/v66 v66.0.0 h1:ADJsaXj9UotwdgK8/iFZtv7MLc8E8WBl62WLd/D/9+M= github.com/google/go-github/v66 v66.0.0/go.mod h1:+4SO9Zkuyf8ytMj0csN1NR/5OTR+MfqPp8P8dVlcvY4= github.com/google/go-querystring v1.1.0 h1:AnCroh3fv4ZBgVIf1Iwtovgjaw/GiKJo8M8yD/fhyJ8= @@ -596,8 +596,8 @@ github.com/google/pprof v0.0.0-20200708004538-1a94d8640e99/go.mod h1:ZgVRPoUq/hf github.com/google/pprof v0.0.0-20241029153458-d1b30febd7db h1:097atOisP2aRj7vFgYQBbFN4U4JNXUNYpxael3UzMyo= github.com/google/pprof v0.0.0-20241029153458-d1b30febd7db/go.mod h1:vavhavw2zAxS5dIdcRluK6cSGGPlZynqzFM8NdvU144= github.com/google/renameio v0.1.0/go.mod h1:KWCgfxg9yswjAJkECMjeO8J8rahYeXnNhOm40UhjYkI= -github.com/google/renameio/v2 v2.0.0 h1:UifI23ZTGY8Tt29JbYFiuyIU3eX+RNFtUwefq9qAhxg= -github.com/google/renameio/v2 v2.0.0/go.mod h1:BtmJXm5YlszgC+TD4HOEEUFgkJP3nLxehU6hfe7jRt4= +github.com/google/renameio/v2 v2.0.2 h1:qKZs+tfn+arruZZhQ7TKC/ergJunuJicWS6gLDt/dGw= +github.com/google/renameio/v2 v2.0.2/go.mod h1:OX+G6WHHpHq3NVj7cAOleLOwJfcQ1s3uUJQCrr78SWo= github.com/google/rpmpack v0.5.0 h1:L16KZ3QvkFGpYhmp23iQip+mx1X39foEsqszjMNBm8A= github.com/google/rpmpack v0.5.0/go.mod h1:uqVAUVQLq8UY2hCDfmJ/+rtO3aw7qyhc90rCVEabEfI= github.com/google/shlex v0.0.0-20191202100458-e7afc7fbc510 h1:El6M4kTTCOh6aBiKaUGG7oYTSPP8MxqL4YI3kZKwcP4= @@ -765,8 +765,8 @@ github.com/kisielk/errcheck v1.7.0/go.mod h1:1kLL+jV4e+CFfueBmI1dSK2ADDyQnlrnrY/ github.com/kisielk/gotool v1.0.0/go.mod h1:XhKaO+MFFWcvkIS/tQcRk01m1F5IRFswLeQ+oQHNcck= github.com/kkHAIKE/contextcheck v1.1.4 h1:B6zAaLhOEEcjvUgIYEqystmnFk1Oemn8bvJhbt0GMb8= github.com/kkHAIKE/contextcheck v1.1.4/go.mod h1:1+i/gWqokIa+dm31mqGLZhZJ7Uh44DJGZVmr6QRBNJg= -github.com/klauspost/compress v1.18.2 h1:iiPHWW0YrcFgpBYhsA6D1+fqHssJscY/Tm/y2Uqnapk= -github.com/klauspost/compress v1.18.2/go.mod h1:R0h/fSBs8DE4ENlcrlib3PsXS61voFxhIs2DeRhCvJ4= +github.com/klauspost/compress v1.18.5 h1:/h1gH5Ce+VWNLSWqPzOVn6XBO+vJbCNGvjoaGBFW2IE= +github.com/klauspost/compress v1.18.5/go.mod h1:cwPg85FWrGar70rWktvGQj8/hthj3wpl0PGDogxkrSQ= github.com/klauspost/pgzip v1.2.6 h1:8RXeL5crjEUFnR2/Sn6GJNWtSQ3Dk8pq4CL3jvdDyjU= github.com/klauspost/pgzip v1.2.6/go.mod h1:Ch1tH69qFZu15pkjo5kYi6mth2Zzwzt50oCQKQE9RUs= github.com/konsorten/go-windows-terminal-sequences v1.0.1/go.mod h1:T0+1ngSBFLxvqU3pZ+m/2kptfBszLMUkC4ZK/EgS/cQ= @@ -873,12 +873,12 @@ github.com/moby/buildkit v0.20.2 h1:qIeR47eQ1tzI1rwz0on3Xx2enRw/1CKjFhoONVcTlMA= github.com/moby/buildkit v0.20.2/go.mod h1:DhaF82FjwOElTftl0JUAJpH/SUIUx4UvcFncLeOtlDI= github.com/moby/docker-image-spec v1.3.1 h1:jMKff3w6PgbfSa69GfNg+zN/XLhfXJGnEx3Nl2EsFP0= github.com/moby/docker-image-spec v1.3.1/go.mod h1:eKmb5VW8vQEh/BAr2yvVNvuiJuY6UIocYsFu/DxxRpo= +github.com/moby/moby/api v1.54.1 h1:TqVzuJkOLsgLDDwNLmYqACUuTehOHRGKiPhvH8V3Nn4= +github.com/moby/moby/api v1.54.1/go.mod h1:+RQ6wluLwtYaTd1WnPLykIDPekkuyD/ROWQClE83pzs= +github.com/moby/moby/client v0.4.0 h1:S+2XegzHQrrvTCvF6s5HFzcrywWQmuVnhOXe2kiWjIw= +github.com/moby/moby/client v0.4.0/go.mod h1:QWPbvWchQbxBNdaLSpoKpCdf5E+WxFAgNHogCWDoa7g= github.com/moby/spdystream v0.5.0 h1:7r0J1Si3QO/kjRitvSLVVFUjxMEb/YLj6S9FF62JBCU= github.com/moby/spdystream v0.5.0/go.mod h1:xBAYlnt/ay+11ShkdFKNAG7LsyK/tmNBVvVOwrfMgdI= -github.com/moby/sys/atomicwriter v0.1.0 h1:kw5D/EqkBwsBFi0ss9v1VG3wIkVhzGvLklJ+w3A14Sw= -github.com/moby/sys/atomicwriter v0.1.0/go.mod h1:Ul8oqv2ZMNHOceF643P6FKPXeCmYtlQMvpizfsSoaWs= -github.com/moby/sys/sequential v0.6.0 h1:qrx7XFUd/5DxtqcoH1h438hF5TmOvzC/lspjy7zgvCU= -github.com/moby/sys/sequential v0.6.0/go.mod h1:uyv8EUTrca5PnDsdMGXhZe6CCe8U/UiTWd+lL+7b/Ko= github.com/moby/term v0.5.2 h1:6qk3FJAFDs6i/q3W/pQ97SX192qKfZgGjCQqfCJkgzQ= github.com/moby/term v0.5.2/go.mod h1:d3djjFCrjnB+fl8NJux+EJzu0msscUP+f8it8hPkFLc= github.com/modern-go/concurrent v0.0.0-20180228061459-e0a39a4cb421/go.mod h1:6dJC0mAP4ikYIbvyc7fijjWJddQyLn8Ig3JB5CqoB9Q= @@ -893,8 +893,6 @@ github.com/monochromegane/go-gitignore v0.0.0-20200626010858-205db1a8cc00 h1:n6/ github.com/monochromegane/go-gitignore v0.0.0-20200626010858-205db1a8cc00/go.mod h1:Pm3mSP3c5uWn86xMLZ5Sa7JB9GsEZySvHYXCTK4E9q4= github.com/moricho/tparallel v0.3.1 h1:fQKD4U1wRMAYNngDonW5XupoB/ZGJHdpzrWqgyg9krA= github.com/moricho/tparallel v0.3.1/go.mod h1:leENX2cUv7Sv2qDgdi0D0fCftN8fRC67Bcn8pqzeYNI= -github.com/morikuni/aec v1.0.0 h1:nP9CBfwrvYnBRgY6qfDQkygYDmYwOilePFkwzv4dU8A= -github.com/morikuni/aec v1.0.0/go.mod h1:BbKIizmSmc5MMPqRYbxO4ZU0S0+P200+tUnFx7PXmsc= github.com/munnerz/goautoneg v0.0.0-20191010083416-a7dc8b61c822 h1:C3w9PqII01/Oq1c1nUAm88MOHcQC9l5mIlSMApZMrHA= github.com/munnerz/goautoneg v0.0.0-20191010083416-a7dc8b61c822/go.mod h1:+n7T8mK8HuQTcFwEeznm/DIxMOiR9yIdICNftLE1DvQ= github.com/mwitkow/go-conntrack v0.0.0-20161129095857-cc309e4a2223/go.mod h1:qRWi+5nqEBWmkhHvq77mSJWrCKwh8bxhgT7d/eI7P4U= @@ -970,8 +968,6 @@ github.com/poy/onpar v1.1.2 h1:QaNrNiZx0+Nar5dLgTVp5mXkyoVFIbepjyEoGSnhbAY= github.com/poy/onpar v1.1.2/go.mod h1:6X8FLNoxyr9kkmnlqpK6LSoiOtrO6MICtWwEuWkLjzg= github.com/prashantv/gostub v1.1.0 h1:BTyx3RfQjRHnUWaGF9oQos79AlQ5k8WNktv7VGvVH4g= github.com/prashantv/gostub v1.1.0/go.mod h1:A5zLQHz7ieHGG7is6LLXLz7I8+3LZzsrV0P1IAHhP5U= -github.com/prometheus-community/pro-bing v0.4.0 h1:YMbv+i08gQz97OZZBwLyvmmQEEzyfyrrjEaAchdy3R4= -github.com/prometheus-community/pro-bing v0.4.0/go.mod h1:b7wRYZtCcPmt4Sz319BykUU241rWLe1VFXyiyWK/dH4= github.com/prometheus/client_golang v0.9.1/go.mod h1:7SWBe2y4D6OKWSNQJUaRYU/AaXPKyh/dDVn+NZz0KFw= github.com/prometheus/client_golang v1.0.0/go.mod h1:db9x61etRT2tGnBNRi70OPL5FsnadC4Ky3P0J6CfImo= github.com/prometheus/client_golang v1.4.0/go.mod h1:e9GMxYsXl05ICDXkRhurwBS4Q3OK1iX/F2sw+iXX5zU= @@ -1023,6 +1019,8 @@ github.com/redis/go-redis/v9 v9.7.3/go.mod h1:bGUrSggJ9X9GUmZpZNEOQKaANxSGgOEBRl github.com/rivo/uniseg v0.2.0/go.mod h1:J6wj4VEh+S6ZtnVlnTBMWIodfgj8LQOQFoIToxlJtxc= github.com/rivo/uniseg v0.4.4 h1:8TfxU8dW6PdqD27gjM8MVNuicgxIjxpm4K7x4jp8sis= github.com/rivo/uniseg v0.4.4/go.mod h1:FN3SvrM+Zdj16jyLfmOkMNblXMcoc8DfTHruCPUcx88= +github.com/robert-nix/ansihtml v1.0.1 h1:VTiyQ6/+AxSJoSSLsMecnkh8i0ZqOEdiRl/odOc64fc= +github.com/robert-nix/ansihtml v1.0.1/go.mod h1:CJwclxYaTPc2RfcxtanEACsYuTksh4yDXcNeHHKZINE= github.com/rogpeppe/go-internal v1.3.0/go.mod h1:M8bDsm7K2OlrFYOpmOWEs/qY81heoFRclV5y23lUDJ4= github.com/rogpeppe/go-internal v1.9.0/go.mod h1:WtVeX8xhTBvf0smdhujwtBcq4Qrzq/fJaraNFVN+nFs= github.com/rogpeppe/go-internal v1.14.1 h1:UQB4HGPB6osV0SQTLymcB4TgvyWu6ZyliaW0tI/otEQ= @@ -1065,8 +1063,8 @@ github.com/sirupsen/logrus v1.2.0/go.mod h1:LxeOpSwHxABJmUn/MG1IvRgCAasNZTLOkJPx github.com/sirupsen/logrus v1.4.2/go.mod h1:tLMulIdttU9McNUspp0xgXVQah82FyeX6MwdIuYE2rE= github.com/sirupsen/logrus v1.6.0/go.mod h1:7uNnSEd1DgxDLC74fIahvMZmmYsHGZGEOFrfsX/uA88= github.com/sirupsen/logrus v1.7.0/go.mod h1:yWOB1SBYBC5VeMP7gHvWumXLIWorT60ONWic61uBYv0= -github.com/sirupsen/logrus v1.9.3 h1:dueUQJ1C2q9oE3F7wvmSGAaVtTmUizReu6fjN8uqzbQ= -github.com/sirupsen/logrus v1.9.3/go.mod h1:naHLuLoDiP4jHNo9R0sCBMtWGeIprob74mVsIT4qYEQ= +github.com/sirupsen/logrus v1.9.4 h1:TsZE7l11zFCLZnZ+teH4Umoq5BhEIfIzfRDZ1Uzql2w= +github.com/sirupsen/logrus v1.9.4/go.mod h1:ftWc9WdOfJ0a92nsE2jF5u5ZwH8Bv2zdeOC42RjbV2g= github.com/sivchari/containedctx v1.0.3 h1:x+etemjbsh2fB5ewm5FeLNi5bUjK0V8n0RB+Wwfd0XE= github.com/sivchari/containedctx v1.0.3/go.mod h1:c1RDvCbnJLtH4lLcYD/GqwiBSSf4F5Qk0xld2rBqzJ4= github.com/sivchari/tenv v1.7.1 h1:PSpuD4bu6fSmtWMxSGWcvqUUgIn7k3yOJhOIzVWn8Ak= @@ -1092,8 +1090,9 @@ github.com/spf13/cobra v1.10.2/go.mod h1:7C1pvHqHw5A4vrJfjNwvOdzYu0Gml16OCs2GRiT github.com/spf13/jwalterweatherman v1.1.0 h1:ue6voC5bR5F8YxI5S67j9i582FU4Qvo2bmqnqMYADFk= github.com/spf13/jwalterweatherman v1.1.0/go.mod h1:aNWZUN0dPAAO/Ljvb5BEdw96iTZ0EXowPYD95IqWIGo= github.com/spf13/pflag v1.0.5/go.mod h1:McXfInJRrz4CZXVZOBLb0bTZqETkiAhM9Iw0y3An2Bg= -github.com/spf13/pflag v1.0.9 h1:9exaQaMOCwffKiiiYk6/BndUBv+iRViNW+4lEMi0PvY= github.com/spf13/pflag v1.0.9/go.mod h1:McXfInJRrz4CZXVZOBLb0bTZqETkiAhM9Iw0y3An2Bg= +github.com/spf13/pflag v1.0.10 h1:4EBh2KAYBwaONj6b2Ye1GiHfwjqyROoF4RwYO+vPwFk= +github.com/spf13/pflag v1.0.10/go.mod h1:McXfInJRrz4CZXVZOBLb0bTZqETkiAhM9Iw0y3An2Bg= github.com/spf13/viper v1.16.0 h1:rGGH0XDZhdUOryiDWjmIvUSWpbNqisK8Wk0Vyefw8hc= github.com/spf13/viper v1.16.0/go.mod h1:yg78JgCJcbrQOvV9YLXgkLaZqUidkY9K+Dd1FofRzQg= github.com/ssgreg/nlreturn/v2 v2.2.1 h1:X4XDI7jstt3ySqGU86YGAURbxw3oTDPK9sPEi6YEwQ0= @@ -1128,8 +1127,8 @@ github.com/subosito/gotenv v1.4.2 h1:X1TuBLAMDFbaTAChgCBLu3DU3UPyELpnF2jjJ2cz/S8 github.com/subosito/gotenv v1.4.2/go.mod h1:ayKnFf/c6rvx/2iiLrJUk1e6plDbT3edrFNGqEflhK0= github.com/t-yuki/gocover-cobertura v0.0.0-20180217150009-aaee18c8195c h1:+aPplBwWcHBo6q9xrfWdMrT9o4kltkmmvpemgIjep/8= github.com/t-yuki/gocover-cobertura v0.0.0-20180217150009-aaee18c8195c/go.mod h1:SbErYREK7xXdsRiigaQiQkI9McGRzYMvlKYaP3Nimdk= -github.com/tailscale/certstore v0.1.1-0.20231202035212-d3fa0460f47e h1:PtWT87weP5LWHEY//SWsYkSO3RWRZo4OSWagh3YD2vQ= -github.com/tailscale/certstore v0.1.1-0.20231202035212-d3fa0460f47e/go.mod h1:XrBNfAFN+pwoWuksbFS9Ccxnopa15zJGgXRFN90l3K4= +github.com/tailscale/certstore v0.1.1-0.20260409135935-3638fb84b77d h1:JcGKBZAL7ePLwOhUdN8qGQZlP5GueEiIZwY7R62pejE= +github.com/tailscale/certstore v0.1.1-0.20260409135935-3638fb84b77d/go.mod h1:XrBNfAFN+pwoWuksbFS9Ccxnopa15zJGgXRFN90l3K4= github.com/tailscale/depaware v0.0.0-20251001183927-9c2ad255ef3f h1:PDPGJtm9PFBLNudHGwkfUGp/FWvP+kXXJ0D1pB35F40= github.com/tailscale/depaware v0.0.0-20251001183927-9c2ad255ef3f/go.mod h1:p9lPsd+cx33L3H9nNoecRRxPssFKUwwI50I3pZ0yT+8= github.com/tailscale/gliderssh v0.3.4-0.20260330083525-c1389c70ff89 h1:glgVc1ZYMjwN1Q/ITWeuSQyl029uayagaR2sjsifehc= @@ -1152,12 +1151,14 @@ github.com/tailscale/peercred v0.0.0-20250107143737-35a0c7bd7edc h1:24heQPtnFR+y github.com/tailscale/peercred v0.0.0-20250107143737-35a0c7bd7edc/go.mod h1:f93CXfllFsO9ZQVq+Zocb1Gp4G5Fz0b0rXHLOzt/Djc= github.com/tailscale/setec v0.0.0-20251203133219-2ab774e4129a h1:TApskGPim53XY5WRt5hX4DnO8V6CmVoimSklryIoGMM= github.com/tailscale/setec v0.0.0-20251203133219-2ab774e4129a/go.mod h1:+6WyG6kub5/5uPsMdYQuSti8i6F5WuKpFWLQnZt/Mms= +github.com/tailscale/ts-gokrazy v0.0.0-20260429180033-fe741c6deb44 h1:a6GdEBrBcDy/4XQ2CxKQvuCaKN8EFL5JTE7ZFOkXDzQ= +github.com/tailscale/ts-gokrazy v0.0.0-20260429180033-fe741c6deb44/go.mod h1:mu0sethAvP7xItcfBAxMJWiXZ3ZQ5qbKmjPYizOkSHE= github.com/tailscale/web-client-prebuilt v0.0.0-20250124233751-d4cd19a26976 h1:UBPHPtv8+nEAy2PD8RyAhOYvau1ek0HDJqLS/Pysi14= github.com/tailscale/web-client-prebuilt v0.0.0-20250124233751-d4cd19a26976/go.mod h1:agQPE6y6ldqCOui2gkIh7ZMztTkIQKH049tv8siLuNQ= github.com/tailscale/wf v0.0.0-20240214030419-6fbb0a674ee6 h1:l10Gi6w9jxvinoiq15g8OToDdASBni4CyJOdHY1Hr8M= github.com/tailscale/wf v0.0.0-20240214030419-6fbb0a674ee6/go.mod h1:ZXRML051h7o4OcI0d3AaILDIad/Xw0IkXaHM17dic1Y= -github.com/tailscale/wireguard-go v0.0.0-20260304043104-4184faf59e56 h1:/R1vu+eNhg1eKstmVPEKvsJgkh4TUyb+J+Eadwv+d/I= -github.com/tailscale/wireguard-go v0.0.0-20260304043104-4184faf59e56/go.mod h1:zvaAPQrjUBWufXgqpSQ1/BYu9ZFOKnsNWLFQe+E78cM= +github.com/tailscale/wireguard-go v0.0.0-20260427181203-e3ac4a0afb4e h1:GexFR7ak1iz26fxg8HWCpOEqAOL8UEZJ7J3JxeCalDs= +github.com/tailscale/wireguard-go v0.0.0-20260427181203-e3ac4a0afb4e/go.mod h1:6SerzcvHWQchKO2BfNdmquA77CHSECZuFl+D9fp4RnI= github.com/tailscale/xnet v0.0.0-20240729143630-8497ac4dab2e h1:zOGKqN5D5hHhiYUp091JqK7DPCqSARyUfduhGUY8Bek= github.com/tailscale/xnet v0.0.0-20240729143630-8497ac4dab2e/go.mod h1:orPd6JZXXRyuDusYilywte7k094d7dycXXU5YnWsrwg= github.com/tc-hib/winres v0.2.1 h1:YDE0FiP0VmtRaDn7+aaChp1KiF4owBiJa5l964l5ujA= @@ -1201,10 +1202,9 @@ github.com/uudashr/gocognit v1.1.2 h1:l6BAEKJqQH2UpKAPKdMfZf5kE4W/2xk8pfU1OVLvni github.com/uudashr/gocognit v1.1.2/go.mod h1:aAVdLURqcanke8h3vg35BC++eseDm66Z7KmchI5et4k= github.com/vbatts/tar-split v0.12.2 h1:w/Y6tjxpeiFMR47yzZPlPj/FcPLpXbTUi/9H7d3CPa4= github.com/vbatts/tar-split v0.12.2/go.mod h1:eF6B6i6ftWQcDqEn3/iGFRFRo8cBIMSJVOpnNdfTMFA= -github.com/vishvananda/netlink v1.3.1-0.20240922070040-084abd93d350 h1:w5OI+kArIBVksl8UGn6ARQshtPCQvDsbuA9NQie3GIg= -github.com/vishvananda/netlink v1.3.1-0.20240922070040-084abd93d350/go.mod h1:i6NetklAujEcC6fK0JPjT8qSwWyO0HLn4UKG+hGqeJs= +github.com/vishvananda/netlink v1.3.1 h1:3AEMt62VKqz90r0tmNhog0r/PpWKmrEShJU0wJW6bV0= +github.com/vishvananda/netlink v1.3.1/go.mod h1:ARtKouGSTGchR8aMwmkzC0qiNPrrWO5JS/XMVl45+b4= github.com/vishvananda/netns v0.0.0-20200728191858-db3c7e526aae/go.mod h1:DD4vA1DwXk04H54A1oHXtwZmA0grkVMdPxx/VGLCah0= -github.com/vishvananda/netns v0.0.4/go.mod h1:SpkAiCQRtJ6TvvxPnOSyH3BMl6unz3xZlaprSwhNNJM= github.com/vishvananda/netns v0.0.5 h1:DfiHV+j8bA32MFM7bfEunvT8IAqQ/NzSJHtcmW5zdEY= github.com/vishvananda/netns v0.0.5/go.mod h1:SpkAiCQRtJ6TvvxPnOSyH3BMl6unz3xZlaprSwhNNJM= github.com/x448/float16 v0.8.4 h1:qLwI1I70+NjRFUR3zs1JPUCgaCXSh3SW62uAKT1mSBM= @@ -1269,8 +1269,8 @@ go.opentelemetry.io/otel/exporters/otlp/otlptrace v1.34.0 h1:OeNbIYk/2C15ckl7glB go.opentelemetry.io/otel/exporters/otlp/otlptrace v1.34.0/go.mod h1:7Bept48yIeqxP2OZ9/AqIpYS94h2or0aB4FypJTc8ZM= go.opentelemetry.io/otel/exporters/otlp/otlptrace/otlptracegrpc v1.34.0 h1:tgJ0uaNS4c98WRNUEx5U3aDlrDOI5Rs+1Vifcw4DJ8U= go.opentelemetry.io/otel/exporters/otlp/otlptrace/otlptracegrpc v1.34.0/go.mod h1:U7HYyW0zt/a9x5J1Kjs+r1f/d4ZHnYFclhYY2+YbeoE= -go.opentelemetry.io/otel/exporters/otlp/otlptrace/otlptracehttp v1.33.0 h1:wpMfgF8E1rkrT1Z6meFh1NDtownE9Ii3n3X2GJYjsaU= -go.opentelemetry.io/otel/exporters/otlp/otlptrace/otlptracehttp v1.33.0/go.mod h1:wAy0T/dUbs468uOlkT31xjvqQgEVXv58BRFWEgn5v/0= +go.opentelemetry.io/otel/exporters/otlp/otlptrace/otlptracehttp v1.32.0 h1:cMyu9O88joYEaI47CnQkxO1XZdpoTF9fEnW2duIddhw= +go.opentelemetry.io/otel/exporters/otlp/otlptrace/otlptracehttp v1.32.0/go.mod h1:6Am3rn7P9TVVeXYG+wtcGE7IE1tsQ+bP3AuWcKt/gOI= go.opentelemetry.io/otel/exporters/prometheus v0.54.0 h1:rFwzp68QMgtzu9PgP3jm9XaMICI6TsofWWPcBDKwlsU= go.opentelemetry.io/otel/exporters/prometheus v0.54.0/go.mod h1:QyjcV9qDP6VeK5qPyKETvNjmaaEc7+gqjh4SS0ZYzDU= go.opentelemetry.io/otel/exporters/stdout/stdoutlog v0.8.0 h1:CHXNXwfKWfzS65yrlB2PVds1IBZcdsX8Vepy9of0iRU= @@ -1319,8 +1319,8 @@ golang.org/x/crypto v0.0.0-20200622213623-75b288015ac9/go.mod h1:LzIPMQfyMNhhGPh golang.org/x/crypto v0.0.0-20210921155107-089bfa567519/go.mod h1:GvvjBRRGRdwPK5ydBHafDWAxML/pGHZbMvKqRZ5+Abc= golang.org/x/crypto v0.0.0-20220622213112-05595931fe9d/go.mod h1:IxCIyHEi3zRg3s0A5j5BB6A9Jmi73HwBIUl50j+osU4= golang.org/x/crypto v0.1.0/go.mod h1:RecgLatLF4+eUMCP1PoPZQb+cVrJcOPbHkTkbkB9sbw= -golang.org/x/crypto v0.46.0 h1:cKRW/pmt1pKAfetfu+RCEvjvZkA9RimPbh7bhFjGVBU= -golang.org/x/crypto v0.46.0/go.mod h1:Evb/oLKmMraqjZ2iQTwDwvCtJkczlDuTmdJXoZVzqU0= +golang.org/x/crypto v0.50.0 h1:zO47/JPrL6vsNkINmLoo/PH1gcxpls50DNogFvB5ZGI= +golang.org/x/crypto v0.50.0/go.mod h1:3muZ7vA7PBCE6xgPX7nkzzjiUq87kRItoJQM1Yo8S+Q= golang.org/x/crypto/x509roots/fallback v0.0.0-20260113154411-7d0074ccc6f1 h1:EBHQuS9qI8xJ96+YRgVV2ahFLUYbWpt1rf3wPfXN2wQ= golang.org/x/crypto/x509roots/fallback v0.0.0-20260113154411-7d0074ccc6f1/go.mod h1:MEIPiCnxvQEjA4astfaKItNwEVZA5Ki+3+nyGbJ5N18= golang.org/x/exp v0.0.0-20190121172915-509febef88a4/go.mod h1:CJ0aWSM057203Lf6IL+f9T1iT9GByDxfZKAQTCR3kQA= @@ -1370,8 +1370,8 @@ golang.org/x/mod v0.6.0-dev.0.20220419223038-86c51ed26bb4/go.mod h1:jJ57K6gSWd91 golang.org/x/mod v0.6.0/go.mod h1:4mET923SAdbXp2ki8ey+zGs1SLqsuM2Y0uvdZR/fUNI= golang.org/x/mod v0.7.0/go.mod h1:iBbtSCu2XBx23ZKBPSOrRkjjQPZFPuis4dIYUhu/chs= golang.org/x/mod v0.8.0/go.mod h1:iBbtSCu2XBx23ZKBPSOrRkjjQPZFPuis4dIYUhu/chs= -golang.org/x/mod v0.31.0 h1:HaW9xtz0+kOcWKwli0ZXy79Ix+UW/vOfmWI5QVd2tgI= -golang.org/x/mod v0.31.0/go.mod h1:43JraMp9cGx1Rx3AqioxrbrhNsLl2l/iNAvuBkrezpg= +golang.org/x/mod v0.35.0 h1:Ww1D637e6Pg+Zb2KrWfHQUnH2dQRLBQyAtpr/haaJeM= +golang.org/x/mod v0.35.0/go.mod h1:+GwiRhIInF8wPm+4AoT6L0FA1QWAad3OMdTRx4tFYlU= golang.org/x/net v0.0.0-20180724234803-3673e40ba225/go.mod h1:mL1N/T3taQHkDXs73rZJwtUhF3w3ftmwwsq0BUmARs4= golang.org/x/net v0.0.0-20180826012351-8a410e7b638d/go.mod h1:mL1N/T3taQHkDXs73rZJwtUhF3w3ftmwwsq0BUmARs4= golang.org/x/net v0.0.0-20181114220301-adae6a3d119a/go.mod h1:mL1N/T3taQHkDXs73rZJwtUhF3w3ftmwwsq0BUmARs4= @@ -1411,8 +1411,8 @@ golang.org/x/net v0.1.0/go.mod h1:Cx3nUiGt4eDBEyega/BKRp+/AlGL8hYe7U9odMt2Cco= golang.org/x/net v0.2.0/go.mod h1:KqCZLdyyvdV855qA2rE3GC2aiw5xGR5TEjj8smXukLY= golang.org/x/net v0.5.0/go.mod h1:DivGGAXEgPSlEBzxGzZI+ZLohi+xUj054jfeKui00ws= golang.org/x/net v0.6.0/go.mod h1:2Tu9+aMcznHK/AK1HMvgo6xiTLG5rD5rZLDS+rp2Bjs= -golang.org/x/net v0.48.0 h1:zyQRTTrjc33Lhh0fBgT/H3oZq9WuvRR5gPC70xpDiQU= -golang.org/x/net v0.48.0/go.mod h1:+ndRgGjkh8FGtu1w1FGbEC31if4VrNVMuKTgcAAnQRY= +golang.org/x/net v0.53.0 h1:d+qAbo5L0orcWAr0a9JweQpjXF19LMXJE8Ey7hwOdUA= +golang.org/x/net v0.53.0/go.mod h1:JvMuJH7rrdiCfbeHoo3fCQU24Lf5JJwT9W3sJFulfgs= golang.org/x/oauth2 v0.0.0-20180821212333-d2e6202438be/go.mod h1:N/0e6XlmueqKjAGxoOufVs8QHGRruUQn6yWY3a++T0U= golang.org/x/oauth2 v0.0.0-20190226205417-e64efc72b421/go.mod h1:gOpvHmFTYa4IltrdGE7lF6nIHvwfUNPOp7c8zoXwtLw= golang.org/x/oauth2 v0.0.0-20190604053449-0f29369cfe45/go.mod h1:gOpvHmFTYa4IltrdGE7lF6nIHvwfUNPOp7c8zoXwtLw= @@ -1434,8 +1434,8 @@ golang.org/x/sync v0.0.0-20201207232520-09787c993a3a/go.mod h1:RxMgew5VJxzue5/jJ golang.org/x/sync v0.0.0-20210220032951-036812b2e83c/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= golang.org/x/sync v0.0.0-20220722155255-886fb9371eb4/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= golang.org/x/sync v0.1.0/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= -golang.org/x/sync v0.19.0 h1:vV+1eWNmZ5geRlYjzm2adRgW2/mcpevXNg50YZtPCE4= -golang.org/x/sync v0.19.0/go.mod h1:9KTHXmSnoGruLpwFjVSX0lNNA75CykiMECbovNTZqGI= +golang.org/x/sync v0.20.0 h1:e0PTpb7pjO8GAtTs2dQ6jYa5BWYlMuX047Dco/pItO4= +golang.org/x/sync v0.20.0/go.mod h1:9xrNwdLfx4jkKbNva9FpL6vEN7evnE43NNNJQ2LF3+0= golang.org/x/sys v0.0.0-20180830151530-49385e6e1522/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= golang.org/x/sys v0.0.0-20180905080454-ebe1bf3edb33/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= golang.org/x/sys v0.0.0-20181116152217-5ac8a444bdc5/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= @@ -1501,18 +1501,18 @@ golang.org/x/sys v0.4.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.5.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.6.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.10.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= -golang.org/x/sys v0.40.0 h1:DBZZqJ2Rkml6QMQsZywtnjnnGvHza6BTfYFWY9kjEWQ= -golang.org/x/sys v0.40.0/go.mod h1:OgkHotnGiDImocRcuBABYBEXf8A9a87e/uXjp9XT3ks= -golang.org/x/telemetry v0.0.0-20251203150158-8fff8a5912fc h1:bH6xUXay0AIFMElXG2rQ4uiE+7ncwtiOdPfYK1NK2XA= -golang.org/x/telemetry v0.0.0-20251203150158-8fff8a5912fc/go.mod h1:hKdjCMrbv9skySur+Nek8Hd0uJ0GuxJIoIX2payrIdQ= +golang.org/x/sys v0.43.0 h1:Rlag2XtaFTxp19wS8MXlJwTvoh8ArU6ezoyFsMyCTNI= +golang.org/x/sys v0.43.0/go.mod h1:4GL1E5IUh+htKOUEOaiffhrAeqysfVGipDYzABqnCmw= +golang.org/x/telemetry v0.0.0-20260409153401-be6f6cb8b1fa h1:efT73AJZfAAUV7SOip6pWGkwJDzIGiKBZGVzHYa+ve4= +golang.org/x/telemetry v0.0.0-20260409153401-be6f6cb8b1fa/go.mod h1:kHjTxDEnAu6/Nl9lDkzjWpR+bmKfxeiRuSDlsMb70gE= golang.org/x/term v0.0.0-20201126162022-7de9c90e9dd1/go.mod h1:bj7SfCRtBDWHUb9snDiAeCFNEtKQo2Wmx5Cou7ajbmo= golang.org/x/term v0.0.0-20210927222741-03fcf44c2211/go.mod h1:jbD1KX2456YbFQfuXm/mYQcufACuNUgVhRMnK/tPxf8= golang.org/x/term v0.1.0/go.mod h1:jbD1KX2456YbFQfuXm/mYQcufACuNUgVhRMnK/tPxf8= golang.org/x/term v0.2.0/go.mod h1:TVmDHMZPmdnySmBfhjOoOdhjzdE1h4u1VwSiw2l1Nuc= golang.org/x/term v0.4.0/go.mod h1:9P2UbLfCdcvo3p/nzKvsmas4TnlujnuoV9hGgYzW1lQ= golang.org/x/term v0.5.0/go.mod h1:jMB1sMXY+tzblOD4FWmEbocvup2/aLOaQEp7JmGp78k= -golang.org/x/term v0.38.0 h1:PQ5pkm/rLO6HnxFR7N2lJHOZX6Kez5Y1gDSJla6jo7Q= -golang.org/x/term v0.38.0/go.mod h1:bSEAKrOT1W+VSu9TSCMtoGEOUcKxOKgl3LE5QEF/xVg= +golang.org/x/term v0.42.0 h1:UiKe+zDFmJobeJ5ggPwOshJIVt6/Ft0rcfrXZDLWAWY= +golang.org/x/term v0.42.0/go.mod h1:Dq/D+snpsbazcBG5+F9Q1n2rXV8Ma+71xEjTRufARgY= golang.org/x/text v0.0.0-20170915032832-14c0d48ead0c/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ= golang.org/x/text v0.3.0/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ= golang.org/x/text v0.3.1-0.20180807135948-17ff2d5776d2/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ= @@ -1523,8 +1523,8 @@ golang.org/x/text v0.3.7/go.mod h1:u+2+/6zg+i71rQMx5EYifcz6MCKuco9NR6JIITiCfzQ= golang.org/x/text v0.4.0/go.mod h1:mrYo+phRRbMaCq/xk9113O4dZlRixOauAjOtrjsXDZ8= golang.org/x/text v0.6.0/go.mod h1:mrYo+phRRbMaCq/xk9113O4dZlRixOauAjOtrjsXDZ8= golang.org/x/text v0.7.0/go.mod h1:mrYo+phRRbMaCq/xk9113O4dZlRixOauAjOtrjsXDZ8= -golang.org/x/text v0.32.0 h1:ZD01bjUt1FQ9WJ0ClOL5vxgxOI/sVCNgX1YtKwcY0mU= -golang.org/x/text v0.32.0/go.mod h1:o/rUWzghvpD5TXrTIBuJU77MTaN0ljMWE47kxGJQ7jY= +golang.org/x/text v0.36.0 h1:JfKh3XmcRPqZPKevfXVpI1wXPTqbkE5f7JA92a55Yxg= +golang.org/x/text v0.36.0/go.mod h1:NIdBknypM8iqVmPiuco0Dh6P5Jcdk8lJL0CUebqK164= golang.org/x/time v0.0.0-20181108054448-85acf8d2951c/go.mod h1:tRJNPiyCQ0inRvYxbN9jk5I+vvW/OXSQhTDSoE431IQ= golang.org/x/time v0.0.0-20190308202827-9d24e82272b4/go.mod h1:tRJNPiyCQ0inRvYxbN9jk5I+vvW/OXSQhTDSoE431IQ= golang.org/x/time v0.0.0-20191024005414-555d28b269f0/go.mod h1:tRJNPiyCQ0inRvYxbN9jk5I+vvW/OXSQhTDSoE431IQ= @@ -1594,8 +1594,8 @@ golang.org/x/tools v0.2.0/go.mod h1:y4OqIKeOV/fWJetJ8bXPU1sEVniLMIyDAZWeHdV+NTA= golang.org/x/tools v0.3.0/go.mod h1:/rWhSS2+zyEVwoJf8YAX6L2f0ntZ7Kn/mGgAWcipA5k= golang.org/x/tools v0.5.0/go.mod h1:N+Kgy78s5I24c24dU8OfWNEotWjutIs8SnJvn5IDq+k= golang.org/x/tools v0.6.0/go.mod h1:Xwgl3UAJ/d3gWutnCtw505GrjyAbvKui8lOU390QaIU= -golang.org/x/tools v0.40.1-0.20260108161641-ca281cf95054 h1:CHVDrNHx9ZoOrNN9kKWYIbT5Rj+WF2rlwPkhbQQ5V4U= -golang.org/x/tools v0.40.1-0.20260108161641-ca281cf95054/go.mod h1:Ik/tzLRlbscWpqqMRjyWYDisX8bG13FrdXp3o4Sr9lc= +golang.org/x/tools v0.44.0 h1:UP4ajHPIcuMjT1GqzDWRlalUEoY+uzoZKnhOjbIPD2c= +golang.org/x/tools v0.44.0/go.mod h1:KA0AfVErSdxRZIsOVipbv3rQhVXTnlU6UhKxHd1seDI= golang.org/x/tools/go/expect v0.1.1-deprecated h1:jpBZDwmgPhXsKZC6WhL20P4b/wmnpsEAGHaNy0n/rJM= golang.org/x/tools/go/expect v0.1.1-deprecated/go.mod h1:eihoPOH+FgIqa3FpoTwguz/bVUSGBlGQU67vpBeOrBY= golang.org/x/tools/go/packages/packagestest v0.1.1-deprecated h1:1h2MnaIAIXISqTFKdENegdpAgUXz6NrPEsbIeWaBRvM= @@ -1728,8 +1728,8 @@ gopkg.in/yaml.v2 v2.4.0/go.mod h1:RDklbk79AGWmwhnvt/jBztapEOGDOx6ZbXqjP6csGnQ= gopkg.in/yaml.v3 v3.0.0-20200313102051-9f266ea9e77c/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA= gopkg.in/yaml.v3 v3.0.1/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= -gotest.tools/v3 v3.4.0 h1:ZazjZUfuVeZGLAmlKKuyv3IKP5orXcwtOwDQH6YVr6o= -gotest.tools/v3 v3.4.0/go.mod h1:CtbdzLSsqVhDgMtKsx03ird5YTGB3ar27v0u/yKBW5g= +gotest.tools/v3 v3.5.2 h1:7koQfIKdy+I8UTetycgUqXWSDwpgv193Ka+qRsmBY8Q= +gotest.tools/v3 v3.5.2/go.mod h1:LtdLGcnqToBH83WByAAi/wiwSFCArdFIUV/xxN4pcjA= gvisor.dev/gvisor v0.0.0-20260224225140-573d5e7127a8 h1:Zy8IV/+FMLxy6j6p87vk/vQGKcdnbprwjTxc8UiUtsA= gvisor.dev/gvisor v0.0.0-20260224225140-573d5e7127a8/go.mod h1:QkHjoMIBaYtpVufgwv3keYAbln78mBoCuShZrPrer1Q= helm.sh/helm/v3 v3.19.0 h1:krVyCGa8fa/wzTZgqw0DUiXuRT5BPdeqE/sQXujQ22k= @@ -1773,6 +1773,8 @@ mvdan.cc/unparam v0.0.0-20240104100049-c549a3470d14 h1:zCr3iRRgdk5eIikZNDphGcM6K mvdan.cc/unparam v0.0.0-20240104100049-c549a3470d14/go.mod h1:ZzZjEpJDOmx8TdVU6umamY3Xy0UAQUI2DHbf05USVbI= oras.land/oras-go/v2 v2.6.0 h1:X4ELRsiGkrbeox69+9tzTu492FMUu7zJQW6eJU+I2oc= oras.land/oras-go/v2 v2.6.0/go.mod h1:magiQDfG6H1O9APp+rOsvCPcW1GD2MM7vgnKY0Y+u1o= +pgregory.net/rapid v1.2.0 h1:keKAYRcjm+e1F0oAuU5F5+YPAWcyxNNRK2wud503Gnk= +pgregory.net/rapid v1.2.0/go.mod h1:PY5XlDGj0+V1FCq0o192FdRhpKHGTRIWBgqjDBTrq04= rsc.io/binaryregexp v0.2.0/go.mod h1:qTv7/COck+e2FymRvadv62gMdZztPaShugOCi3I+8D8= rsc.io/quote/v3 v3.1.0/go.mod h1:yEA65RcK8LyAZtP9Kv3t0HmxON59tX3rD+tICJqUlj0= rsc.io/sampler v1.3.0/go.mod h1:T1hPZKmBbMNahiBKFy5HrXp6adAjACjK9JXDnKaTXpA= diff --git a/go.toolchain.next.rev b/go.toolchain.next.rev index dbcbdd5e1..fd5a21697 100644 --- a/go.toolchain.next.rev +++ b/go.toolchain.next.rev @@ -1 +1 @@ -dfe2a5fd8ee2e68b08ce5ff259269f50ecadf2f4 +e877d973840c91ec9d4bc1921b0845789de359ae diff --git a/go.toolchain.rev b/go.toolchain.rev index dbcbdd5e1..fd5a21697 100644 --- a/go.toolchain.rev +++ b/go.toolchain.rev @@ -1 +1 @@ -dfe2a5fd8ee2e68b08ce5ff259269f50ecadf2f4 +e877d973840c91ec9d4bc1921b0845789de359ae diff --git a/go.toolchain.rev.sri b/go.toolchain.rev.sri deleted file mode 100644 index 00025af58..000000000 --- a/go.toolchain.rev.sri +++ /dev/null @@ -1 +0,0 @@ -sha256-pCvFNTFuvhSBb5O+PPuilaowP4tXcCOP1NgYUDJTcJU= diff --git a/go.toolchain.version b/go.toolchain.version index c7c3f3333..f8f738140 100644 --- a/go.toolchain.version +++ b/go.toolchain.version @@ -1 +1 @@ -1.26.2 +1.26.3 diff --git a/gokrazy/Makefile b/gokrazy/Makefile index bc55f2a52..014866851 100644 --- a/gokrazy/Makefile +++ b/gokrazy/Makefile @@ -11,3 +11,8 @@ qemu: image natlab: go run build.go --build --app=natlabapp qemu-img convert -O qcow2 natlabapp.img natlabapp.qcow2 + +# For natlab integration tests on macOS arm64: +natlab-arm64: + go run build.go --build --app=natlabapp.arm64 + qemu-img convert -O qcow2 natlabapp.arm64.img natlabapp.arm64.qcow2 diff --git a/gokrazy/natlabapp.arm64/config.json b/gokrazy/natlabapp.arm64/config.json index 2ba9a20f9..8283dc053 100644 --- a/gokrazy/natlabapp.arm64/config.json +++ b/gokrazy/natlabapp.arm64/config.json @@ -27,5 +27,7 @@ "KernelPackage": "github.com/gokrazy/kernel.arm64", "FirmwarePackage": "github.com/gokrazy/kernel.arm64", "EEPROMPackage": "", - "InternalCompatibilityFlags": {} + "InternalCompatibilityFlags": { + "InitImportPath": "github.com/tailscale/ts-gokrazy/gokrazyinit" + } } diff --git a/gokrazy/natlabapp.arm64/gokrazydeps.go b/gokrazy/natlabapp.arm64/gokrazydeps.go new file mode 100644 index 000000000..001ab89b8 --- /dev/null +++ b/gokrazy/natlabapp.arm64/gokrazydeps.go @@ -0,0 +1,16 @@ +// Copyright (c) Tailscale Inc & contributors +// SPDX-License-Identifier: BSD-3-Clause + +//go:build for_go_mod_tidy_only + +package gokrazydeps + +import ( + _ "github.com/gokrazy/gokrazy/cmd/dhcp" + _ "github.com/gokrazy/kernel.arm64" + _ "github.com/gokrazy/serial-busybox" + _ "github.com/tailscale/ts-gokrazy/gokrazyinit" + _ "tailscale.com/cmd/tailscale" + _ "tailscale.com/cmd/tailscaled" + _ "tailscale.com/cmd/tta" +) diff --git a/gokrazy/natlabapp/config.json b/gokrazy/natlabapp/config.json index 1968b2aac..c46f01879 100644 --- a/gokrazy/natlabapp/config.json +++ b/gokrazy/natlabapp/config.json @@ -27,5 +27,7 @@ "KernelPackage": "github.com/tailscale/gokrazy-kernel", "FirmwarePackage": "", "EEPROMPackage": "", - "InternalCompatibilityFlags": {} + "InternalCompatibilityFlags": { + "InitImportPath": "github.com/tailscale/ts-gokrazy/gokrazyinit" + } } diff --git a/gokrazy/natlabapp/gokrazydeps.go b/gokrazy/natlabapp/gokrazydeps.go index c5d2b32a3..2e4c1361c 100644 --- a/gokrazy/natlabapp/gokrazydeps.go +++ b/gokrazy/natlabapp/gokrazydeps.go @@ -6,10 +6,10 @@ package gokrazydeps import ( - _ "github.com/gokrazy/gokrazy" _ "github.com/gokrazy/gokrazy/cmd/dhcp" _ "github.com/gokrazy/serial-busybox" _ "github.com/tailscale/gokrazy-kernel" + _ "github.com/tailscale/ts-gokrazy/gokrazyinit" _ "tailscale.com/cmd/tailscale" _ "tailscale.com/cmd/tailscaled" _ "tailscale.com/cmd/tta" diff --git a/gokrazy/tsapp/config.json b/gokrazy/tsapp/config.json index b88be53a4..15533afd1 100644 --- a/gokrazy/tsapp/config.json +++ b/gokrazy/tsapp/config.json @@ -33,5 +33,7 @@ ], "KernelPackage": "github.com/tailscale/gokrazy-kernel", "FirmwarePackage": "github.com/tailscale/gokrazy-kernel", - "InternalCompatibilityFlags": {} + "InternalCompatibilityFlags": { + "InitImportPath": "github.com/tailscale/ts-gokrazy/gokrazyinit" + } } diff --git a/gokrazy/tsapp/gokrazydeps.go b/gokrazy/tsapp/gokrazydeps.go index 931080647..22bdc3a49 100644 --- a/gokrazy/tsapp/gokrazydeps.go +++ b/gokrazy/tsapp/gokrazydeps.go @@ -7,12 +7,12 @@ package gokrazydeps import ( _ "github.com/gokrazy/breakglass" - _ "github.com/gokrazy/gokrazy" _ "github.com/gokrazy/gokrazy/cmd/dhcp" _ "github.com/gokrazy/gokrazy/cmd/ntp" _ "github.com/gokrazy/gokrazy/cmd/randomd" _ "github.com/gokrazy/serial-busybox" _ "github.com/tailscale/gokrazy-kernel" + _ "github.com/tailscale/ts-gokrazy/gokrazyinit" _ "tailscale.com/cmd/tailscale" _ "tailscale.com/cmd/tailscaled" ) diff --git a/health/health.go b/health/health.go index 1829bd482..7e2878159 100644 --- a/health/health.go +++ b/health/health.go @@ -492,7 +492,12 @@ func (t *Tracker) SetHealthy(w *Warnable) { } func (t *Tracker) setHealthyLocked(w *Warnable) { - if !buildfeatures.HasHealth || t.warnableVal[w] == nil { + if !buildfeatures.HasHealth { + return + } + + ws := t.warnableVal[w] + if ws == nil { // Nothing to remove return } @@ -501,15 +506,28 @@ func (t *Tracker) setHealthyLocked(w *Warnable) { // Stop any pending visiblity timers for this Warnable if canc, ok := t.pendingVisibleTimers[w]; ok { + // We removed the warningState for this Warnable, + // and we hold the lock, so even if the timer callback + // has already started, it won't find a warningState + // for this Warnable and won't publish any changes. canc.Stop() delete(t.pendingVisibleTimers, w) } - change := Change{ - WarnableChanged: true, - Warnable: w, + // Only publish a change if the Warnable was unhealthy long + // enough to become visible to the user. Otherwise, it would + // not have been published as unhealthy, so there is no need + // to publish it as healthy. This prevents eventbus (and by + // extension the IPN bus) churn for Warnables that are marked + // unhealthy and then healthy again. Notably, this includes + // warnables touched by [Tracker.updateBuiltinWarnablesLocked]. + if w.IsVisible(ws, t.now) { + change := Change{ + WarnableChanged: true, + Warnable: w, + } + t.changePub.Publish(change) } - t.changePub.Publish(change) } // notifyWatchersControlChangedLocked calls each watcher to signal that control diff --git a/health/health_test.go b/health/health_test.go index ccd49b19a..4b5af6d76 100644 --- a/health/health_test.go +++ b/health/health_test.go @@ -271,47 +271,91 @@ func TestSetUnhealthyWithTimeToVisible(t *testing.T) { } for _, tt := range tests { t.Run(tt.name, func(*testing.T) { - bus := eventbustest.NewBus(t) - ht := NewTracker(bus) - mw := Register(&Warnable{ - Code: "test-warnable-3-secs-to-visible", - Title: "Test Warnable with 3 seconds to visible", - Text: StaticMessage("Hello world"), - TimeToVisible: 2 * time.Second, - ImpactsConnectivity: true, + synctest.Test(t, func(t *testing.T) { + clock := tstest.NewClock(tstest.ClockOpts{ + Start: time.Unix(123, 0), + FollowRealTime: false, + }) + bus := eventbustest.NewBus(t) + ht := NewTracker(bus) + ht.testClock = clock + mw := Register(&Warnable{ + Code: "test-warnable-3-secs-to-visible", + Title: "Test Warnable with 3 seconds to visible", + Text: StaticMessage("Hello world"), + TimeToVisible: 2 * time.Second, + ImpactsConnectivity: true, + }) + defer unregister(mw) + + becameUnhealthy := make(chan struct{}) + becameHealthy := make(chan struct{}) + + watchFunc := func(c Change) { + w := c.Warnable + us := c.UnhealthyState + if w != mw { + t.Fatalf("watcherFunc was called, but with an unexpected Warnable: %v, want: %v", w, w) + } + + if us != nil { + becameUnhealthy <- struct{}{} + } else { + becameHealthy <- struct{}{} + } + } + + tt.preFunc(t, ht, bus, watchFunc) + ht.SetUnhealthy(mw, Args{ArgError: "Hello world"}) + + // Advance time by half of the TimeToVisible duration. + clock.Advance(mw.TimeToVisible / 2) + + select { + case <-becameUnhealthy: + // Test failed because the watcher got notified of an unhealthy state + t.Fatalf("watcherFunc was called with an unhealthy state") + case <-becameHealthy: + // Test failed because the watcher got of a healthy state + t.Fatalf("watcherFunc was called with a healthy state") + default: + // As expected, watcherFunc still had not been called + // after mw.TimeToVisible / 2. + } + + // Advance time to get past the the TimeToVisible duration. + // The watcher should be notified of the unhealthy state. + clock.Advance(mw.TimeToVisible/2 + 1) + <-becameUnhealthy + + // Reset the warnable to neutral / healthy state before + // the next part of the test. + ht.SetHealthy(mw) + <-becameHealthy + + // Mark the warnable unhealthy and then immediately healthy + // before the TimeToVisible duration elapses. + // The watcher should not be notified of either change + // because the warnable never became visible. + ht.SetUnhealthy(mw, Args{ArgError: "Hello world"}) + ht.SetHealthy(mw) + + // Advance to get past the the TimeToVisible delay. + clock.Advance(mw.TimeToVisible * 2) + synctest.Wait() + + select { + case <-becameUnhealthy: + // Test failed because the watcher got notified of an unhealthy state + t.Fatalf("watcherFunc was called with an unhealthy state") + case <-becameHealthy: + // Test failed because the watcher got of a healthy state + t.Fatalf("watcherFunc was called with a healthy state") + default: + // As expected, watcherFunc was not called after marking + // the warnable healthy again as it never became visible. + } }) - - becameUnhealthy := make(chan struct{}) - becameHealthy := make(chan struct{}) - - watchFunc := func(c Change) { - w := c.Warnable - us := c.UnhealthyState - if w != mw { - t.Fatalf("watcherFunc was called, but with an unexpected Warnable: %v, want: %v", w, w) - } - - if us != nil { - becameUnhealthy <- struct{}{} - } else { - becameHealthy <- struct{}{} - } - } - - tt.preFunc(t, ht, bus, watchFunc) - ht.SetUnhealthy(mw, Args{ArgError: "Hello world"}) - - select { - case <-becameUnhealthy: - // Test failed because the watcher got notified of an unhealthy state - t.Fatalf("watcherFunc was called with an unhealthy state") - case <-becameHealthy: - // Test failed because the watcher got of a healthy state - t.Fatalf("watcherFunc was called with a healthy state") - case <-time.After(1 * time.Second): - // As expected, watcherFunc still had not been called after 1 second - } - unregister(mw) }) } } @@ -739,12 +783,11 @@ func TestControlHealthNotifies(t *testing.T) { ht.SetIPNState("NeedsLogin", true) ht.GotStreamedMapResponse() - // Expect events at starup, before doing anything else, skip unstable - // event and no warning event as they show up at different times. + // Expect events at starup, before doing anything else, skip + // the warming up events. synctest.Wait() if err := eventbustest.Expect(tw, CompareWarnableCode(t, tsconst.HealthWarnableWarmingUp), - CompareWarnableCode(t, tsconst.HealthWarnableNotInMapPoll), CompareWarnableCode(t, tsconst.HealthWarnableWarmingUp), ); err != nil { t.Errorf("startup error: %v", err) diff --git a/ipn/backend.go b/ipn/backend.go index 7ea7c92b4..51617e08e 100644 --- a/ipn/backend.go +++ b/ipn/backend.go @@ -118,6 +118,17 @@ type Notify struct { Prefs *PrefsView // if non-nil && Valid, the new or current preferences NetMap *netmap.NetworkMap // if non-nil, the new or current netmap + // SelfChange, if non-nil, indicates that this node's own [tailcfg.Node] + // has changed: addresses, name, key expiry, capabilities, etc. It carries + // the new self node so reactive consumers (containerboot, kube agents, + // sniproxy, etc.) can read the current self state without watching the + // full netmap. + // + // Consumers that need additional state (peers, DNS config, packet + // filter) should react to SelfChange by fetching the relevant bits on + // demand via [LocalClient]. + SelfChange *tailcfg.Node `json:",omitzero"` + // PeerChanges, if non-nil, is a list of [tailcfg.PeerChange] that have occurred since the last // full netmap update. This is sent in lieu of a full NetMap when [NotifyPeerChanges] is set in // the session's mask and a netmap update is derived from an incremental MapResponse. @@ -196,6 +207,9 @@ func (n Notify) String() string { if n.NetMap != nil { sb.WriteString("NetMap{...} ") } + if n.SelfChange != nil { + fmt.Fprintf(&sb, "SelfChange(%v) ", n.SelfChange.StableID) + } if n.PeerChanges != nil { fmt.Fprintf(&sb, "PeerChanges(%d) ", len(n.PeerChanges)) } diff --git a/ipn/ipnlocal/bus.go b/ipn/ipnlocal/bus.go index de04fd09a..8be508010 100644 --- a/ipn/ipnlocal/bus.go +++ b/ipn/ipnlocal/bus.go @@ -205,6 +205,7 @@ func isNotableNotify(n *ipn.Notify) bool { n.Prefs != nil || n.ErrMessage != nil || n.LoginFinished != nil || + n.SelfChange != nil || !n.DriveShares.IsNil() || n.Health != nil || len(n.IncomingFiles) > 0 || diff --git a/ipn/ipnlocal/bus_test.go b/ipn/ipnlocal/bus_test.go index 8e4d3ede8..048e5bff4 100644 --- a/ipn/ipnlocal/bus_test.go +++ b/ipn/ipnlocal/bus_test.go @@ -32,6 +32,7 @@ func TestIsNotableNotify(t *testing.T) { {"netmap", &ipn.Notify{NetMap: new(netmap.NetworkMap)}, false}, {"peerchanges", &ipn.Notify{PeerChanges: []*tailcfg.PeerChange{{}}}, false}, {"engine", &ipn.Notify{Engine: new(ipn.EngineStatus)}, false}, + {"selfchange", &ipn.Notify{SelfChange: &tailcfg.Node{}}, true}, } // Then for all other fields, assume they're notable. @@ -41,7 +42,7 @@ func TestIsNotableNotify(t *testing.T) { for sf := range rt.Fields() { n := &ipn.Notify{} switch sf.Name { - case "_", "NetMap", "PeerChanges", "Engine", "Version": + case "_", "NetMap", "PeerChanges", "SelfChange", "Engine", "Version": // Already covered above or not applicable. continue case "DriveShares": diff --git a/ipn/ipnlocal/c2n.go b/ipn/ipnlocal/c2n.go index 8284872b9..bf8cf2e03 100644 --- a/ipn/ipnlocal/c2n.go +++ b/ipn/ipnlocal/c2n.go @@ -27,6 +27,7 @@ import ( "tailscale.com/util/goroutines" "tailscale.com/util/httpm" "tailscale.com/util/set" + "tailscale.com/util/testenv" "tailscale.com/version" ) @@ -323,3 +324,10 @@ func handleC2NSetNetfilterKind(b *LocalBackend, w http.ResponseWriter, r *http.R w.WriteHeader(http.StatusNoContent) } + +// HandleC2NForTest calls [handleC2N], for use by feature/ packages that +// register C2N handlers and want to test them. +func (b *LocalBackend) HandleC2NForTest(w http.ResponseWriter, r *http.Request) { + testenv.AssertInTest() + b.handleC2N(w, r) +} diff --git a/ipn/ipnlocal/cert.go b/ipn/ipnlocal/cert.go index efab9db7a..eab70b295 100644 --- a/ipn/ipnlocal/cert.go +++ b/ipn/ipnlocal/cert.go @@ -909,7 +909,7 @@ func (b *LocalBackend) resolveCertDomain(domain string) (string, error) { } // Read the netmap once to get both CertDomains and capabilities atomically. - nm := b.NetMap() + nm := b.NetMapNoPeers() if nm == nil { return "", errors.New("no netmap available") } diff --git a/ipn/ipnlocal/cert_test.go b/ipn/ipnlocal/cert_test.go index e2d69da52..56d6df77f 100644 --- a/ipn/ipnlocal/cert_test.go +++ b/ipn/ipnlocal/cert_test.go @@ -519,8 +519,11 @@ func TestGetCertPEMWithValidity(t *testing.T) { for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { + tstest.AssertNotParallel(t) if tt.readOnlyMode { envknob.Setenv("TS_CERT_SHARE_MODE", "ro") + } else { + envknob.Setenv("TS_CERT_SHARE_MODE", "") } os.RemoveAll(certDir) diff --git a/ipn/ipnlocal/diskcache.go b/ipn/ipnlocal/diskcache.go index 03ced7967..3235869e6 100644 --- a/ipn/ipnlocal/diskcache.go +++ b/ipn/ipnlocal/diskcache.go @@ -35,7 +35,19 @@ func (b *LocalBackend) writeNetmapToDiskLocked(nm *netmap.NetworkMap) error { b.diskCache.cache = netmapcache.NewCache(netmapcache.FileStore(dir)) b.diskCache.dir = dir } - return b.diskCache.cache.Store(b.currentNode().Context(), nm) + + // Set the homeDERP on the self node before saving. The self node homeDERP is + // generally not used since the homeDERP for self is stored in magicsock, but + // to be able to load it during loading the cache, we use the existing field + // to save it. + + // Make a shallow copy and mutate a copy of the selfNode. + nmCopy := *nm + selfNode := nm.SelfNode.AsStruct() + selfNode.HomeDERP = int(b.currentNode().homeDERP.Load()) + nmCopy.SelfNode = selfNode.View() + + return b.diskCache.cache.Store(b.currentNode().Context(), &nmCopy) } func (b *LocalBackend) loadDiskCacheLocked() (om *netmap.NetworkMap, ok bool) { diff --git a/ipn/ipnlocal/diskcache_test.go b/ipn/ipnlocal/diskcache_test.go new file mode 100644 index 000000000..748ff6a40 --- /dev/null +++ b/ipn/ipnlocal/diskcache_test.go @@ -0,0 +1,229 @@ +// Copyright (c) Tailscale Inc & contributors +// SPDX-License-Identifier: BSD-3-Clause + +package ipnlocal + +import ( + "net/netip" + "testing" + + "tailscale.com/tailcfg" + "tailscale.com/tstest" + "tailscale.com/types/netmap" + "tailscale.com/util/eventbus" + "tailscale.com/wgengine/magicsock" +) + +// newCacheTestNetmap returns a minimal valid netmap suitable for testing disk +// cache operations. +func newCacheTestNetmap() *netmap.NetworkMap { + return &netmap.NetworkMap{ + SelfNode: (&tailcfg.Node{ + Name: "test-node.ts.net", + User: tailcfg.UserID(1), + Addresses: []netip.Prefix{ + netip.MustParsePrefix("100.64.0.1/32"), + }, + }).View(), + UserProfiles: map[tailcfg.UserID]tailcfg.UserProfileView{ + tailcfg.UserID(1): (&tailcfg.UserProfile{ + LoginName: "user@example.com", + DisplayName: "Test User", + }).View(), + }, + DERPMap: &tailcfg.DERPMap{ + Regions: map[int]*tailcfg.DERPRegion{ + 1: {}, + 2: {}, + 3: {}, + 4: {}, + 5: {}, + 6: {}, + 7: {}, + 8: {}, + 9: {}, + 10: {}, + 11: {}, + }, + }, + } +} + +func TestWriteAndLoadHomeDERP(t *testing.T) { + b := newTestBackend(t) + + nm := newCacheTestNetmap() + b.currentNode().SetNetMap(nm) + + const wantDERP = 7 + b.currentNode().homeDERP.Store(wantDERP) + + b.mu.Lock() + defer b.mu.Unlock() + + if err := b.writeNetmapToDiskLocked(nm); err != nil { + t.Fatalf("writeNetmapToDiskLocked: %v", err) + } + + loaded, ok := b.loadDiskCacheLocked() + if !ok { + t.Fatal("loadDiskCacheLocked returned ok=false") + } + if !loaded.SelfNode.Valid() { + t.Fatal("loaded netmap SelfNode is invalid") + } + if got := loaded.SelfNode.HomeDERP(); got != wantDERP { + t.Errorf("loaded SelfNode.HomeDERP() = %d, want %d", got, wantDERP) + } +} + +func TestOnHomeDERPUpdate(t *testing.T) { + t.Run("normal_derp_change", func(t *testing.T) { + b := newTestBackend(t) + done := make(chan struct{}) + tstest.Replace(t, &testOnlyHomeDERPUpdate, func() { close(done) }) + + nm := newCacheTestNetmap() + b.currentNode().SetNetMap(nm) + + // Publish a HomeDERPChanged event via the backend's event bus. + bus := b.Sys().Bus.Get() + ec := bus.Client("test.TestOnHomeDERPUpdate") + pub := eventbus.Publish[magicsock.HomeDERPChanged](ec) + + const wantDERP = 11 + pub.Publish(magicsock.HomeDERPChanged{Old: 0, New: wantDERP}) + <-done + + if got := b.currentNode().homeDERP.Load(); got != wantDERP { + t.Errorf("b.homeDERP = %d, want %d", got, wantDERP) + } + + // Verify the value was persisted to the disk cache. + b.mu.Lock() + defer b.mu.Unlock() + loaded, ok := b.loadDiskCacheLocked() + if !ok { + t.Fatal("loadDiskCacheLocked returned ok=false after homeDERP update") + } + if got := loaded.SelfNode.HomeDERP(); got != wantDERP { + t.Errorf("cached SelfNode.HomeDERP() = %d, want %d", got, wantDERP) + } + }) + t.Run("old_does_not_match", func(t *testing.T) { + b := newTestBackend(t) + done := make(chan struct{}) + tstest.Replace(t, &testOnlyHomeDERPUpdate, func() { close(done) }) + + const setDERP = 11 + const wantDERP = 4 + + nm := newCacheTestNetmap() + selfNode := nm.SelfNode.AsStruct() + selfNode.HomeDERP = wantDERP + nm.SelfNode = selfNode.View() + b.currentNode().SetNetMap(nm) + b.currentNode().homeDERP.Store(wantDERP) + + // Write an initial cache entry so we can verify it is not overwritten. + b.mu.Lock() + if err := b.writeNetmapToDiskLocked(nm); err != nil { + b.mu.Unlock() + t.Fatalf("setup writeNetmapToDiskLocked: %v", err) + } + b.mu.Unlock() + + // Publish a HomeDERPChanged event via the backend's event bus. + bus := b.Sys().Bus.Get() + ec := bus.Client("test.TestOnHomeDERPUpdate") + pub := eventbus.Publish[magicsock.HomeDERPChanged](ec) + pub.Publish(magicsock.HomeDERPChanged{Old: wantDERP + 1, New: setDERP}) + <-done + + if got := b.currentNode().homeDERP.Load(); got != wantDERP { + t.Errorf("b.homeDERP = %d, wanted no change %d", got, wantDERP) + } + + // Verify the cache still exists and still holds the original value. + b.mu.Lock() + defer b.mu.Unlock() + loaded, ok := b.loadDiskCacheLocked() + if !ok { + t.Fatal("loadDiskCacheLocked returned ok=false; expected cache to still exist") + } + if got := loaded.SelfNode.HomeDERP(); got != wantDERP { + t.Errorf("cached SelfNode.HomeDERP() = %d after rejected event, want original %d", got, wantDERP) + } + }) + t.Run("new_does_not_exist_in_map", func(t *testing.T) { + b := newTestBackend(t) + done := make(chan struct{}) + tstest.Replace(t, &testOnlyHomeDERPUpdate, func() { close(done) }) + + const setDERP = 111 + const wantDERP = 4 + + nm := newCacheTestNetmap() + selfNode := nm.SelfNode.AsStruct() + selfNode.HomeDERP = wantDERP + nm.SelfNode = selfNode.View() + b.currentNode().SetNetMap(nm) + b.currentNode().homeDERP.Store(wantDERP) + + // Write an initial cache entry so we can verify it is not overwritten. + b.mu.Lock() + if err := b.writeNetmapToDiskLocked(nm); err != nil { + b.mu.Unlock() + t.Fatalf("setup writeNetmapToDiskLocked: %v", err) + } + b.mu.Unlock() + + // Publish a HomeDERPChanged event via the backend's event bus. + // Old matches the stored homeDERP so only the "new region not in map" + // guard is exercised. + bus := b.Sys().Bus.Get() + ec := bus.Client("test.TestOnHomeDERPUpdate") + pub := eventbus.Publish[magicsock.HomeDERPChanged](ec) + pub.Publish(magicsock.HomeDERPChanged{Old: wantDERP, New: setDERP}) + <-done + + if got := b.currentNode().homeDERP.Load(); got != wantDERP { + t.Errorf("b.homeDERP = %d, wanted no change %d", got, wantDERP) + } + + // Verify the cache still exists and still holds the original value. + b.mu.Lock() + defer b.mu.Unlock() + loaded, ok := b.loadDiskCacheLocked() + if !ok { + t.Fatal("loadDiskCacheLocked returned ok=false; expected cache to still exist") + } + if got := loaded.SelfNode.HomeDERP(); got != wantDERP { + t.Errorf("cached SelfNode.HomeDERP() = %d after rejected event, want original %d", got, wantDERP) + } + }) +} + +func TestWriteNetmapDoesNotMutateOriginal(t *testing.T) { + b := newTestBackend(t) + + nm := newCacheTestNetmap() + b.currentNode().SetNetMap(nm) + + originalDERP := nm.SelfNode.HomeDERP() // expected to be 0 initially + + const storeDERP = 5 + b.currentNode().homeDERP.Store(storeDERP) + + b.mu.Lock() + defer b.mu.Unlock() + + if err := b.writeNetmapToDiskLocked(nm); err != nil { + t.Fatalf("writeNetmapToDiskLocked: %v", err) + } + + // The original netmap must not have been mutated. + if got := nm.SelfNode.HomeDERP(); got != originalDERP { + t.Errorf("original nm.SelfNode.HomeDERP() = %d after write, want %d (original was mutated)", got, originalDERP) + } +} diff --git a/ipn/ipnlocal/drive.go b/ipn/ipnlocal/drive.go index 485114eae..110ffff2a 100644 --- a/ipn/ipnlocal/drive.go +++ b/ipn/ipnlocal/drive.go @@ -303,18 +303,19 @@ func (b *LocalBackend) updateDrivePeersLocked(nm *netmap.NetworkMap) { } func (b *LocalBackend) driveRemotesFromPeers(nm *netmap.NetworkMap) []*drive.Remote { - b.logf("[v1] taildrive: setting up drive remotes from peers") + b.logf("[v1] taildrive: setting up drive remotes from %d peers", len(nm.Peers)) driveRemotes := make([]*drive.Remote, 0, len(nm.Peers)) for _, p := range nm.Peers { peer := p peerID := peer.ID() peerKey := peer.Key().ShortString() - b.logf("[v1] taildrive: appending remote for peer %s", peerKey) + peerName := peer.DisplayName(false) + driveRemotes = append(driveRemotes, &drive.Remote{ - Name: p.DisplayName(false), + Name: peerName, URL: func() string { url := fmt.Sprintf("%s/%s", b.currentNode().PeerAPIBase(peer), taildrivePrefix[1:]) - b.logf("[v2] taildrive: url for peer %s: %s", peerKey, url) + b.logf("[v2] taildrive: url for peer %s (%s): %s", peerKey, peerName, url) return url }, Available: func() bool { @@ -325,7 +326,7 @@ func (b *LocalBackend) driveRemotesFromPeers(nm *netmap.NetworkMap) []*drive.Rem cn := b.currentNode() peer, ok := cn.NodeByID(peerID) if !ok { - b.logf("[v2] taildrive: Available(): peer %s not found", peerKey) + b.logf("[v2] taildrive: peer %s (%s, id=%v) not found", peerKey, peerName, peerID) return false } @@ -338,26 +339,25 @@ func (b *LocalBackend) driveRemotesFromPeers(nm *netmap.NetworkMap) []*drive.Rem // The netmap.Peers slice is not updated in all cases. // It should be fixed now that we use PeerByIDOk. if !peer.Online().Get() { - b.logf("[v2] taildrive: Available(): peer %s offline", peerKey) + b.logf("[v2] taildrive: peer %s (%s, id=%v) offline", peerKey, peerName, peerID) return false } - - if b.currentNode().PeerAPIBase(peer) == "" { - b.logf("[v2] taildrive: Available(): peer %s PeerAPI unreachable", peerKey) + if cn.PeerAPIBase(peer) == "" { + b.logf("[v2] taildrive: peer %s (%s, id=%v) PeerAPI unreachable", peerKey, peerName, peerID) return false } - // Check that the peer is allowed to share with us. if cn.PeerHasCap(peer, tailcfg.PeerCapabilityTaildriveSharer) { - b.logf("[v2] taildrive: Available(): peer %s available", peerKey) + b.logf("[v2] taildrive: peer %s (%s, id=%v) available", peerKey, peerName, peerID) return true } - b.logf("[v2] taildrive: Available(): peer %s not allowed to share", peerKey) + b.logf("[v2] taildrive: peer %s (%s, id=%v) not allowed to share", peerKey, peerName, peerID) return false }, }) } + b.logf("[v1] taildrive: built %d candidate remotes", len(driveRemotes)) return driveRemotes } diff --git a/ipn/ipnlocal/local.go b/ipn/ipnlocal/local.go index 610d1d7b5..70bf87b70 100644 --- a/ipn/ipnlocal/local.go +++ b/ipn/ipnlocal/local.go @@ -536,10 +536,14 @@ func NewLocalBackend(logf logger.Logf, logID logid.PublicID, sys *tsd.System, lo needsCaptiveDetection: make(chan bool), } + sys.NoiseRoundTripper.Set(noiseRoundTripper{b}) + nb := newNodeBackend(ctx, b.logf, b.sys.Bus.Get()) b.currentNodeAtomic.Store(nb) nb.ready() + e.SetPeerByIPPacketFunc(b.lookupPeerByIP) + if sys.InitialConfig != nil { if err := b.initPrefsFromConfig(sys.InitialConfig); err != nil { return nil, err @@ -627,6 +631,7 @@ func NewLocalBackend(logf logger.Logf, logID logid.PublicID, sys *tsd.System, lo } eventbus.SubscribeFunc(ec, b.onAppConnectorRouteUpdate) eventbus.SubscribeFunc(ec, b.onAppConnectorStoreRoutes) + eventbus.SubscribeFunc(ec, b.onHomeDERPUpdate) mConn.SetNetInfoCallback(b.setNetInfo) // TODO(tailscale/tailscale#17887): move to eventbus return b, nil @@ -658,6 +663,53 @@ func (b *LocalBackend) onAppConnectorStoreRoutes(ri appctype.RouteInfo) { } } +// testOnlyHomeDERPUpdate if non-nil is called after setting home DERP and +// writing netmap to disk. +var testOnlyHomeDERPUpdate func() + +func (b *LocalBackend) onHomeDERPUpdate(du magicsock.HomeDERPChanged) { + b.mu.Lock() + defer b.mu.Unlock() + + b.onHomeDERPUpdateLocked(du) + + if testOnlyHomeDERPUpdate != nil { + testOnlyHomeDERPUpdate() + } +} + +// onHomeDERPUpdateLocked considitonally updates the homeDERP for use in the +// netmap cache. +// If we switched our currentNode by switching profiles, we might be trying +// to update the homeDERP from another profile. If the old homeDERP does not +// match what we expect, don't swap the homeDERP. +// In practice, it is possible that one profile with a homeDERP of 0 (no-derp) +// got switched before setting any home DERP or that DERP IDs match across +// DERP maps. Since the risk of this happening is small and the consequences +// of this is is just a possible less optimal DERP until the next reSTUN, +// accept this possibility. +func (b *LocalBackend) onHomeDERPUpdateLocked(du magicsock.HomeDERPChanged) { + cn := b.currentNode() + + if cn == nil || cn.DERPMap() == nil || cn.DERPMap().Regions == nil { + return + } + + if _, ok := cn.DERPMap().Regions[du.New]; !ok { + return + } + + if !cn.homeDERP.CompareAndSwap(int64(du.Old), int64(du.New)) { + return + } + + // Persist the full netmap (including up-to-date Peers) to disk for + // fast restart. + if err := b.writeNetmapToDiskLocked(b.NetMapWithPeers()); err != nil { + b.logf("write netmap to cache: %v", err) + } +} + func (b *LocalBackend) Clock() tstime.Clock { return b.clock } func (b *LocalBackend) Sys() *tsd.System { return b.sys } @@ -975,7 +1027,7 @@ func (b *LocalBackend) pauseOrResumeControlClientLocked() { return } networkUp := b.interfaceState.AnyInterfaceUp() - pauseForNetwork := (b.state == ipn.Stopped && b.NetMap() != nil) || (!networkUp && !testenv.InTest() && !envknob.AssumeNetworkUp()) + pauseForNetwork := (b.state == ipn.Stopped && b.NetMapNoPeers() != nil) || (!networkUp && !testenv.InTest() && !envknob.AssumeNetworkUp()) prefs := b.pm.CurrentPrefs() pauseForSyncPref := prefs.Valid() && prefs.Sync().EqualBool(false) @@ -1571,6 +1623,28 @@ func (b *LocalBackend) PeerCaps(src netip.Addr) tailcfg.PeerCapMap { return b.currentNode().PeerCaps(src) } +// PeerCapsForIP returns the capabilities that remote src IP has when +// talking to the given destination IP on this node. +func (b *LocalBackend) PeerCapsForIP(src, dst netip.Addr) tailcfg.PeerCapMap { + return b.currentNode().PeerCapsForIP(src, dst) +} + +// PeerCapsForService returns the capabilities that remote src IP has when +// talking to the named VIP service on this node. +func (b *LocalBackend) PeerCapsForService(src netip.Addr, svcName tailcfg.ServiceName) tailcfg.PeerCapMap { + return b.currentNode().PeerCapsForService(src, svcName) +} + +// PeerByID returns the current full [tailcfg.Node] for the peer with the +// given NodeID, in O(1) time. It returns ok=false if no such peer is in +// the current netmap. +// +// It is intended for callers that need the latest state of a single peer +// without fetching the entire netmap. +func (b *LocalBackend) PeerByID(id tailcfg.NodeID) (n tailcfg.NodeView, ok bool) { + return b.currentNode().NodeByID(id) +} + func (b *LocalBackend) GetFilterForTest() *filter.Filter { testenv.AssertInTest() nb := b.currentNode() @@ -1821,7 +1895,24 @@ func (b *LocalBackend) setControlClientStatusLocked(c controlclient.Client, st c } b.e.SetNetworkMap(st.NetMap) - b.MagicConn().SetDERPMap(st.NetMap.DERPMap) + + var cachedHome int + if c == nil && st.NetMap.Cached && st.NetMap.SelfNode.Valid() { + cachedHome = st.NetMap.SelfNode.HomeDERP() + } + if cachedHome != 0 { + // Loading from a cached netmap (c == nil means no live control + // client). Pre-seed the home DERP from the cached self node so + // that the guard in maybeSetNearestDERP prevents changing the + // DERP home before we reconnect to the control plane. If the cache has + // nothing in it, skip this, and let the node pick a DERP itself. + b.MagicConn().SetDERPMapWithoutReSTUN(st.NetMap.DERPMap) + b.health.SetOutOfPollNetMap() + b.MagicConn().ForceSetNearestDERP(cachedHome) + } else { + b.MagicConn().SetDERPMap(st.NetMap.DERPMap) + } + b.MagicConn().SetOnlyTCP443(st.NetMap.HasCap(tailcfg.NodeAttrOnlyTCP443)) // Update our cached DERP map @@ -1830,7 +1921,15 @@ func (b *LocalBackend) setControlClientStatusLocked(c controlclient.Client, st c // Update the DERP map in the health package, which uses it for health notifications b.health.SetDERPMap(st.NetMap.DERPMap) - b.sendLocked(ipn.Notify{NetMap: st.NetMap}) + // Notify watchers that the self node may have changed. Reactive + // consumers (containerboot, kube agents, sniproxy, etc.) listen on + // this signal and re-fetch peers/DNS via the LocalAPI if they need + // more than self info. + var selfChange *tailcfg.Node + if st.NetMap.SelfNode.Valid() { + selfChange = st.NetMap.SelfNode.AsStruct() + } + b.sendLocked(ipn.Notify{NetMap: st.NetMap, SelfChange: selfChange}) // The error here is unimportant as is the result. This will recalculate the suggested exit node // cache the value and push any changes to the IPN bus. @@ -2471,6 +2570,14 @@ func (b *LocalBackend) PeersForTest() []tailcfg.NodeView { return b.currentNode().PeersForTest() } +// AwaitNodeKeyForTest returns a channel that is closed once a peer with the +// given node key first appears in the current netmap. If the peer is already +// present, the returned channel is already closed. See +// [nodeBackend.AwaitNodeKeyForTest]. +func (b *LocalBackend) AwaitNodeKeyForTest(k key.NodePublic) <-chan struct{} { + return b.currentNode().AwaitNodeKeyForTest(k) +} + func (b *LocalBackend) getNewControlClientFuncLocked() clientGen { if b.ccGen == nil { // Initialize it rather than just returning the @@ -2767,7 +2874,7 @@ func (b *LocalBackend) startLocked(opts ipn.Options) error { // Without this, the state machine transitions to "NeedsLogin" implying // that user interaction is required, which is not the case and can // regress tsnet.Server restarts. - cc.Login(controlclient.LoginDefault) + cc.Login(b.loginFlags) } b.stateMachineLocked() @@ -3563,12 +3670,11 @@ func (b *LocalBackend) setAuthURLLocked(url string) { // // b.mu must be held. func (b *LocalBackend) popBrowserAuthNowLocked(url string, keyExpired bool, recipient ipnauth.Actor) { - b.logf("popBrowserAuthNow(%q): url=%v, key-expired=%v, seamless-key-renewal=%v", maybeUsernameOf(recipient), url != "", keyExpired, b.seamlessRenewalEnabled()) + b.logf("popBrowserAuthNow(%q): url=%v, key-expired=%v", maybeUsernameOf(recipient), url != "", keyExpired) - // Deconfigure the local network data plane if: - // - seamless key renewal is not enabled; - // - key is expired (in which case tailnet connectivity is down anyway). - if !b.seamlessRenewalEnabled() || keyExpired { + // Deconfigure the local network data plane if the key is expired + // (in which case tailnet connectivity is down anyway). + if keyExpired { b.blockEngineUpdatesLocked(true) b.stopEngineAndWaitLocked() @@ -3985,7 +4091,8 @@ func (b *LocalBackend) pingPeerAPI(ctx context.Context, ip netip.Addr) (peer tai var zero tailcfg.NodeView ctx, cancel := context.WithTimeout(ctx, 10*time.Second) defer cancel() - nm := b.NetMap() + // PeerByTailscaleIP needs an up-to-date Peers slice. + nm := b.NetMapWithPeers() if nm == nil { return zero, "", errors.New("no netmap") } @@ -4212,6 +4319,8 @@ func (b *LocalBackend) CurrentUserForTest() (ipn.WindowsUserID, ipnauth.Actor) { return b.pm.CurrentUserID(), b.currentUser } +// CheckPrefs validates the provided user modifiable settings for correctness +// and returns an error if they are invalid for the current backend. func (b *LocalBackend) CheckPrefs(p *ipn.Prefs) error { b.mu.Lock() defer b.mu.Unlock() @@ -4842,7 +4951,7 @@ func (b *LocalBackend) setPrefsLocked(newp *ipn.Prefs) ipn.PrefsView { if !oldp.WantRunning() && newp.WantRunning && cc != nil { b.logf("transitioning to running; doing Login...") - cc.Login(controlclient.LoginDefault) + cc.Login(b.loginFlags) } if oldp.WantRunning() != newp.WantRunning { @@ -4893,7 +5002,7 @@ func (b *LocalBackend) handlePeerAPIConn(remote, local netip.AddrPort, c net.Con } func (b *LocalBackend) isLocalIP(ip netip.Addr) bool { - nm := b.NetMap() + nm := b.NetMapNoPeers() return nm != nil && views.SliceContains(nm.GetAddresses(), netip.PrefixFrom(ip, ip.BitLen())) } @@ -5045,10 +5154,67 @@ func extractPeerAPIPorts(services []tailcfg.Service) portPair { // NetMap returns the latest cached network map received from // controlclient, or nil if no network map was received yet. +// +// Deprecated: callers should declare their needs explicitly by calling +// either [LocalBackend.NetMapNoPeers] (cheap; for code that reads +// non-Peers fields like SelfNode, DNS, PacketFilter, capabilities) or +// [LocalBackend.NetMapWithPeers] (currently the same; will be made to +// return an up-to-date Peers slice in a follow-up change, at the cost of +// O(N) work per call). NetMap will eventually be removed. func (b *LocalBackend) NetMap() *netmap.NetworkMap { return b.currentNode().NetMap() } +// NetMapNoPeers returns the latest cached network map received from +// controlclient WITHOUT a freshly-built Peers slice. +// +// On a tailnet with frequent peer churn the cached netmap's Peers slice +// can be stale relative to the live per-node-backend peers map; non-Peers +// fields (SelfNode, DNS, PacketFilter, capabilities, ...) are always +// current. Use this for any caller that does not need to iterate Peers, +// since it's O(1) regardless of tailnet size. +// +// Returns nil if no network map has been received yet. +func (b *LocalBackend) NetMapNoPeers() *netmap.NetworkMap { + return b.currentNode().NetMap() +} + +// NetMapWithPeers returns the latest network map with the Peers slice +// populated. +// +// Currently this is the same as [LocalBackend.NetMapNoPeers]: the cached +// netmap's Peers slice may be stale relative to the live per-node-backend +// peers map. A follow-up change will switch this method to return a +// freshly-built netmap with up-to-date Peers, at O(N) cost per call. +// Callers that genuinely need the up-to-date peer set should use this +// method (and document why) so the upcoming change reaches them. +// +// Returns nil if no network map has been received yet. +func (b *LocalBackend) NetMapWithPeers() *netmap.NetworkMap { + return b.currentNode().NetMap() +} + +// lookupPeerByIP returns the node public key for the peer that owns the +// given IP address. It is the fast path for [Engine.SetPeerByIPPacketFunc], +// handling exact-IP matches against node addresses; subnet routes and exit +// nodes are handled by a BART-based fallback in userspaceEngine that uses +// the wireguard-filtered peer list (see lastCfgFull). +// +// It is called by wireguard-go on every outbound packet (not cached), so +// it must be fast. +func (b *LocalBackend) lookupPeerByIP(ip netip.Addr) (key.NodePublic, bool) { + nb := b.currentNode() + nid, ok := nb.NodeByAddr(ip) + if !ok { + return key.NodePublic{}, false + } + peer, ok := nb.NodeByID(nid) + if !ok { + return key.NodePublic{}, false + } + return peer.Key(), true +} + func (b *LocalBackend) isEngineBlocked() bool { b.mu.Lock() defer b.mu.Unlock() @@ -5199,7 +5365,6 @@ func (b *LocalBackend) authReconfig() { // // b.mu must be held. func (b *LocalBackend) authReconfigLocked() { - if b.shutdownCalled { b.logf("[v1] authReconfig: skipping because in shutdown") return @@ -5307,18 +5472,22 @@ func shouldUseOneCGNATRoute(logf logger.Logf, mon *netmon.Monitor, controlKnobs return true } - // Also prefer to do this on the Mac, so that we don't need to constantly - // update the network extension configuration (which is disruptive to - // Chrome, see https://github.com/tailscale/tailscale/issues/3102). Only - // use fine-grained routes if another interfaces is also using the CGNAT + // Prefer a single CGNAT route on platforms where updateing the VPN + // configuration is espensive. On macOS, changing the network extension + // configuration can disrupt existing connections notably Chrome; see + // https://github.com/tailscale/tailscale/issues/3102). On Android, updating + // VpnService.Builder configuration requires establishing a new VPN interface, + // which tears down long lived TCP connections. + // + // Only use fine-grained routes if another interfaces is also using the CGNAT // IP range. - if versionOS == "macOS" { + if versionOS == "macOS" || versionOS == "android" { hasCGNATInterface, err := mon.HasCGNATInterface() if err != nil { logf("shouldUseOneCGNATRoute: Could not determine if any interfaces use CGNAT: %v", err) return false } - logf("[v1] shouldUseOneCGNATRoute: macOS automatic=%v", !hasCGNATInterface) + logf("[v1] shouldUseOneCGNATRoute: %s automatic=%v", versionOS, !hasCGNATInterface) if !hasCGNATInterface { return true } @@ -5673,13 +5842,14 @@ func (b *LocalBackend) routerConfigLocked(cfg *wgcfg.Config, prefs ipn.PrefsView } rs := &router.Config{ - LocalAddrs: unmapIPPrefixes(cfg.Addresses), - SubnetRoutes: unmapIPPrefixes(prefs.AdvertiseRoutes().AsSlice()), - SNATSubnetRoutes: !prefs.NoSNAT(), - StatefulFiltering: doStatefulFiltering, - NetfilterMode: prefs.NetfilterMode(), - Routes: peerRoutes(b.logf, cfg.Peers, singleRouteThreshold, prefs.RouteAll()), - NetfilterKind: netfilterKind, + LocalAddrs: unmapIPPrefixes(cfg.Addresses), + SubnetRoutes: unmapIPPrefixes(prefs.AdvertiseRoutes().AsSlice()), + SNATSubnetRoutes: !prefs.NoSNAT(), + StatefulFiltering: doStatefulFiltering, + NetfilterMode: prefs.NetfilterMode(), + Routes: peerRoutes(b.logf, cfg.Peers, singleRouteThreshold, prefs.RouteAll()), + NetfilterKind: netfilterKind, + RemoveCGNATDropRule: nm.HasCap(tailcfg.NodeAttrDisableLinuxCGNATDropRule), } if buildfeatures.HasSynology && distro.Get() == distro.Synology { @@ -5927,9 +6097,9 @@ func (b *LocalBackend) enterStateLocked(newState ipn.State) { switch newState { case ipn.NeedsLogin: feature.SystemdStatus("Needs login: %s", authURL) - // always block updates on NeedsLogin even if seamless renewal is enabled, - // to prevent calls to authReconfigLocked from reconfiguring the engine when our - // key has expired and we're waiting to authenticate to use the new key. + // always block updates on NeedsLogin, to prevent calls to authReconfigLocked + // from reconfiguring the engine when our key has expired and we're waiting + // to authenticate to use the new key. b.blockEngineUpdatesLocked(true) fallthrough case ipn.Stopped, ipn.NoState: @@ -6401,6 +6571,23 @@ func (b *LocalBackend) resolveExitNodeInPrefsLocked(prefs *ipn.Prefs) (changed b // received nm. If nm is nil, it resets all configuration as though // Tailscale is turned off. func (b *LocalBackend) setNetMapLocked(nm *netmap.NetworkMap) { + if buildfeatures.HasCacheNetMap { + // As a defensive measure, if something triggers a panic when we are + // installing a network map, make an effort to discard any cached netmaps. + // This helps avert the possibility that a restart after panic will stick in + // a cycle. Importantly, we do not attempt to swallow or handle the panic, + // since that indicates a real bug. + // + // See https://github.com/tailscale/tailscale/issues/12639 + defer func() { + if p := recover(); p != nil { + b.logf("WARNING: Panic while installing netmap; discardng caches") + b.discardDiskCacheLocked() + panic(p) // propagate + } + }() + } + oldSelf := b.currentNode().NetMap().SelfNodeOrZero() b.dialer.SetNetMap(nm) @@ -6414,7 +6601,11 @@ func (b *LocalBackend) setNetMapLocked(nm *netmap.NetworkMap) { b.currentNode().SetNetMap(nm) if ms, ok := b.sys.MagicSock.GetOK(); ok { if nm != nil { - ms.SetNetworkMap(nm.SelfNode, nm.Peers) + if nm.Cached { + ms.SetNetworkMapCached(nm.SelfNode, nm.Peers) + } else { + ms.SetNetworkMap(nm.SelfNode, nm.Peers) + } } else { ms.SetNetworkMap(tailcfg.NodeView{}, nil) } @@ -6866,7 +7057,7 @@ func (b *LocalBackend) AppConnector() *appc.AppConnector { func (b *LocalBackend) allowExitNodeDNSProxyToServeName(name string) bool { b.mu.Lock() defer b.mu.Unlock() - nm := b.NetMap() + nm := b.NetMapNoPeers() if nm == nil { return false } @@ -7010,11 +7201,28 @@ func (b *LocalBackend) DebugRotateDiscoKey() error { b.mu.Lock() cc := b.cc + wantRunning := b.pm.CurrentPrefs().WantRunning() b.mu.Unlock() if cc != nil { cc.SetDiscoPublicKey(newDiscoKey) } + // Bounce WantRunning to fully reset wireguard-go state for all peers. + if wantRunning { + if _, err := b.EditPrefs(&ipn.MaskedPrefs{ + Prefs: ipn.Prefs{WantRunning: false}, + WantRunningSet: true, + }); err != nil { + return err + } + if _, err := b.EditPrefs(&ipn.MaskedPrefs{ + Prefs: ipn.Prefs{WantRunning: true}, + WantRunningSet: true, + }); err != nil { + return err + } + } + return nil } @@ -7022,6 +7230,25 @@ func (b *LocalBackend) DebugPeerRelayServers() set.Set[netip.Addr] { return b.MagicConn().PeerRelays() } +// DebugPeerDiscoKeys returns the disco public keys this node has learned for +// each of its peers from the most recent network map. Intended for tests +// (the production [ipnstate.PeerStatus] purposefully does not surface disco +// keys; surfacing them via the [ipnstate.Status] API would also pollute +// every PeerStatus consumer with a non-comparable struct field). +func (b *LocalBackend) DebugPeerDiscoKeys() map[key.NodePublic]key.DiscoPublic { + nm := b.currentNode().NetMap() + if nm == nil { + return nil + } + m := make(map[key.NodePublic]key.DiscoPublic, len(nm.Peers)) + for _, p := range nm.Peers { + if dk := p.DiscoKey(); !dk.IsZero() { + m[p.Key()] = dk + } + } + return m +} + // ControlKnobs returns the node's control knobs. func (b *LocalBackend) ControlKnobs() *controlknobs.Knobs { return b.sys.ControlKnobs() @@ -7049,6 +7276,15 @@ func (b *LocalBackend) DoNoiseRequest(req *http.Request) (*http.Response, error) return cc.DoNoiseRequest(req) } +// noiseRoundTripper adapts LocalBackend.DoNoiseRequest to http.RoundTripper. +type noiseRoundTripper struct { + lb *LocalBackend +} + +func (n noiseRoundTripper) RoundTrip(req *http.Request) (*http.Response, error) { + return n.lb.DoNoiseRequest(req) +} + // ActiveSSHConns returns the number of active SSH connections, // or 0 if SSH is not linked into the binary or available on the platform. func (b *LocalBackend) ActiveSSHConns() int { @@ -7570,14 +7806,6 @@ func (b *LocalBackend) ReadRouteInfo() (*appctype.RouteInfo, error) { return b.readRouteInfoLocked() } -// seamlessRenewalEnabled reports whether seamless key renewals are enabled. -// -// As of 2025-09-11, this is the default behaviour unless nodes receive -// [tailcfg.NodeAttrDisableSeamlessKeyRenewal] in their netmap. -func (b *LocalBackend) seamlessRenewalEnabled() bool { - return b.ControlKnobs().SeamlessKeyRenewal.Load() -} - var ( disallowedAddrs = []netip.Addr{ netip.MustParseAddr("::1"), diff --git a/ipn/ipnlocal/local_test.go b/ipn/ipnlocal/local_test.go index d930735cd..21188d784 100644 --- a/ipn/ipnlocal/local_test.go +++ b/ipn/ipnlocal/local_test.go @@ -32,6 +32,7 @@ import ( "tailscale.com/appc" "tailscale.com/appc/appctest" "tailscale.com/control/controlclient" + "tailscale.com/control/controlknobs" "tailscale.com/drive" "tailscale.com/drive/driveimpl" "tailscale.com/feature" @@ -8101,3 +8102,133 @@ func TestNoSNATWithAdvertisedExitNodeWarning(t *testing.T) { } }) } + +// TestStartPreservesLoginFlags is a regression test for a bug where the +// LoginEphemeral flag stored on LocalBackend was silently dropped by the +// auto-login paths in Start() and setPrefsLocked(). The user-visible symptom +// was tsnet.Server.Ephemeral=true being ignored when combined with an auth +// key, because the resulting RegisterRequest.Ephemeral was false. +// +// The test manually constructs the LocalBackend to be able set +// loginFlags=LoginEphemeral, and then checks that at least one cc.Login call +// carried the LoginEphemeral bit. +func TestStartPreservesLoginFlags(t *testing.T) { + logf := tstest.WhileTestRunningLogger(t) + sys := tsd.NewSystem() + sys.Set(new(mem.Store)) + e, err := wgengine.NewFakeUserspaceEngine(logf, sys.Set, sys.HealthTracker.Get(), sys.UserMetricsRegistry(), sys.Bus.Get()) + if err != nil { + t.Fatalf("NewFakeUserspaceEngine: %v", err) + } + t.Cleanup(e.Close) + sys.Set(e) + + b, err := NewLocalBackend(logf, logid.PublicID{}, sys, controlclient.LoginEphemeral) + if err != nil { + t.Fatalf("NewLocalBackend: %v", err) + } + t.Cleanup(b.Shutdown) + + var cc *mockControl + b.SetControlClientGetterForTesting(func(opts controlclient.Options) (controlclient.Client, error) { + cc = newClient(t, opts) + return cc, nil + }) + + if err := b.Start(ipn.Options{ + UpdatePrefs: &ipn.Prefs{ + ControlURL: "https://controlplane.example.com", + WantRunning: false, + }, + AuthKey: "tskey-auth-test", + }); err != nil { + t.Fatalf("Start: %v", err) + } + + if _, err := b.EditPrefs(&ipn.MaskedPrefs{ + Prefs: ipn.Prefs{WantRunning: true}, + WantRunningSet: true, + }); err != nil { + t.Fatalf("EditPrefs: %v", err) + } + + cc.mu.Lock() + flags := cc.loginFlags + cc.mu.Unlock() + if flags&controlclient.LoginEphemeral == 0 { + t.Errorf("cc.Login was never called with LoginEphemeral; got flags=%v", flags) + } +} + +func TestShouldUseOneCGNATRoute(t *testing.T) { + tests := []struct { + name string + versionOS string + want bool + }{ + {"android", "android", true}, + {"macOS", "macOS", true}, + {"plan9", "plan9", true}, + {"linux", "linux", false}, + {"windows", "windows", false}, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got := shouldUseOneCGNATRoute(t.Logf, nil, nil, tt.versionOS) + if got != tt.want { + t.Errorf("shouldUseOneCGNATRoute(%q) = %v; want %v", tt.versionOS, got, tt.want) + } + }) + } + + // Control knob takes precedence over everything. + t.Run("control-knob-override", func(t *testing.T) { + knobs := &controlknobs.Knobs{} + knobs.OneCGNAT.Store(opt.NewBool(false)) + if got := shouldUseOneCGNATRoute(t.Logf, nil, knobs, "android"); got { + t.Error("control knob should override android default; got true, want false") + } + knobs.OneCGNAT.Store(opt.NewBool(true)) + if got := shouldUseOneCGNATRoute(t.Logf, nil, knobs, "linux"); !got { + t.Error("control knob should override linux default; got false, want true") + } + }) +} + +func TestPeerRoutesCGNATCollapse(t *testing.T) { + pp := netip.MustParsePrefix + + // With cgnatThreshold=1 (oneCGNATRoute), adding a peer should not + // change the route list. Both collapse to a single 100.64.0.0/10. + twoPeers := []wgcfg.Peer{ + {AllowedIPs: []netip.Prefix{pp("100.64.0.1/32")}}, + {AllowedIPs: []netip.Prefix{pp("100.64.0.2/32")}}, + } + threePeers := []wgcfg.Peer{ + {AllowedIPs: []netip.Prefix{pp("100.64.0.1/32")}}, + {AllowedIPs: []netip.Prefix{pp("100.64.0.2/32")}}, + {AllowedIPs: []netip.Prefix{pp("100.64.0.3/32")}}, + } + + routesTwo := peerRoutes(t.Logf, twoPeers, 1, true) + routesThree := peerRoutes(t.Logf, threePeers, 1, true) + + wantCGNAT := []netip.Prefix{pp("100.64.0.0/10")} + if !reflect.DeepEqual(routesTwo, wantCGNAT) { + t.Errorf("two peers: got %v; want %v", routesTwo, wantCGNAT) + } + if !reflect.DeepEqual(routesThree, wantCGNAT) { + t.Errorf("three peers: got %v; want %v", routesThree, wantCGNAT) + } + + // Subnet routes must still appear alongside the collapsed CGNAT route. + peersWithSubnet := []wgcfg.Peer{ + {AllowedIPs: []netip.Prefix{pp("100.64.0.1/32")}}, + {AllowedIPs: []netip.Prefix{pp("100.64.0.2/32"), pp("10.0.0.0/24")}}, + } + got := peerRoutes(t.Logf, peersWithSubnet, 1, true) + want := []netip.Prefix{pp("100.64.0.0/10"), pp("10.0.0.0/24")} + if !reflect.DeepEqual(got, want) { + t.Errorf("with subnet: got %v; want %v", got, want) + } +} diff --git a/ipn/ipnlocal/network-lock.go b/ipn/ipnlocal/network-lock.go index 535d9803d..3238a0a07 100644 --- a/ipn/ipnlocal/network-lock.go +++ b/ipn/ipnlocal/network-lock.go @@ -27,6 +27,7 @@ import ( "tailscale.com/health/healthmsg" "tailscale.com/ipn" "tailscale.com/ipn/ipnstate" + "tailscale.com/ipn/store/mem" "tailscale.com/net/tsaddr" "tailscale.com/tailcfg" "tailscale.com/tka" @@ -38,6 +39,7 @@ import ( "tailscale.com/types/tkatype" "tailscale.com/util/mak" "tailscale.com/util/set" + "tailscale.com/util/testenv" ) // TODO(tom): RPC retry/backoff was broken and has been removed. Fix? @@ -45,13 +47,15 @@ import ( var ( errMissingNetmap = errors.New("missing netmap: verify that you are logged in") errNetworkLockNotActive = errors.New("tailnet-lock is not active") - - tkaCompactionDefaults = tka.CompactionOptions{ - MinChain: 24, // Keep at minimum 24 AUMs since head. - MinAge: 14 * 24 * time.Hour, // Keep 2 weeks of AUMs. - } ) +// IsNetworkLockNotActive reports whether the given error indicates that +// tailnet-lock is not active. Stop-gap for feature/tailnetlock to check this +// until all of this is code is moved to the feature. +func IsNetworkLockNotActive(err error) bool { + return errors.Is(err, errNetworkLockNotActive) +} + type tkaState struct { profile ipn.ProfileID authority *tka.Authority @@ -76,13 +80,13 @@ func (b *LocalBackend) initTKALocked() error { root := b.TailscaleVarRoot() if root == "" { b.tka = nil - b.logf("cannot fetch existing TKA state; no state directory for network-lock") + b.logf("cannot fetch existing TKA state; no state directory for tailnet-lock") return nil } chonkDir := b.chonkPathLocked() if _, err := os.Stat(chonkDir); err == nil { - // The directory exists, which means network-lock has been initialized. + // The directory exists, which means tailnet-lock has been initialized. storage, err := tka.ChonkDir(chonkDir) if err != nil { return fmt.Errorf("opening tailchonk: %v", err) @@ -92,16 +96,12 @@ func (b *LocalBackend) initTKALocked() error { return fmt.Errorf("initializing tka: %v", err) } - if err := authority.Compact(storage, tkaCompactionDefaults); err != nil { - b.logf("tka compaction failed: %v", err) - } - b.tka = &tkaState{ profile: cp.ID(), authority: authority, storage: storage, } - b.logf("tka initialized at head %x", authority.Head()) + b.logf("tka initialized at head %s", authority.Head()) } return nil @@ -139,12 +139,12 @@ func (b *LocalBackend) tkaFilterNetmapLocked(nm *netmap.NetworkMap) { continue } if p.KeySignature().Len() == 0 { - b.logf("Network lock is dropping peer %v(%v) due to missing signature", p.ID(), p.StableID()) + b.logf("Tailnet lock is dropping peer %v(%v) due to missing signature", p.ID(), p.StableID()) mak.Set(&toDelete, i, true) } else { details, err := b.tka.authority.NodeKeyAuthorizedWithDetails(p.Key(), p.KeySignature().AsSlice()) if err != nil { - b.logf("Network lock is dropping peer %v(%v) due to failed signature check: %v", p.ID(), p.StableID(), err) + b.logf("Tailnet lock is dropping peer %v(%v) due to failed signature check: %v", p.ID(), p.StableID(), err) mak.Set(&toDelete, i, true) continue } @@ -166,7 +166,7 @@ func (b *LocalBackend) tkaFilterNetmapLocked(nm *netmap.NetworkMap) { peers = append(peers, p) } else { if obsoleteByRotation.Contains(p.Key()) { - b.logf("Network lock is dropping peer %v(%v) due to key rotation", p.ID(), p.StableID()) + b.logf("Tailnet lock is dropping peer %v(%v) due to key rotation", p.ID(), p.StableID()) } // Record information about the node we filtered out. filtered = append(filtered, tkaStateFromPeer(p)) @@ -304,7 +304,11 @@ func (b *LocalBackend) tkaSyncIfNeeded(nm *netmap.NetworkMap, prefs ipn.PrefsVie wantEnabled := nm.TKAEnabled if isEnabled || wantEnabled { - b.logf("tkaSyncIfNeeded: isEnabled=%t, wantEnabled=%t, head=%v", isEnabled, wantEnabled, nm.TKAHead) + nodeHead := "" + if b.tka != nil { + nodeHead = b.tka.authority.Head().String() + } + b.logf("tkaSyncIfNeeded: isEnabled=%t, wantEnabled=%t, nodeHead=%v, netmapHead=%v", isEnabled, wantEnabled, nodeHead, nm.TKAHead) } ourNodeKey, ok := prefs.Persist().PublicNodeKeyOK() @@ -360,7 +364,7 @@ func (b *LocalBackend) tkaSyncIfNeeded(nm *netmap.NetworkMap, prefs ipn.PrefsVie // // We run this on every sync so that clients compact consistently. In many // cases this will be a no-op. - if err := b.tka.authority.Compact(b.tka.storage, tkaCompactionDefaults); err != nil { + if err := b.tka.authority.Compact(b.tka.storage, tka.CompactionDefaults); err != nil { return fmt.Errorf("tka compact: %w", err) } } @@ -492,7 +496,7 @@ func (b *LocalBackend) tkaBootstrapFromGenesisLocked(g tkatype.MarshaledAUM, per var storage tka.CompactableChonk if root == "" { b.health.SetUnhealthy(noNetworkLockStateDirWarnable, nil) - b.logf("network-lock using in-memory storage; no state directory") + b.logf("tailnet-lock using in-memory storage; no state directory") storage = tka.ChonkMem() } else { chonkDir := b.chonkPathLocked() @@ -620,7 +624,7 @@ func tkaStateFromPeer(p tailcfg.NodeView) ipnstate.TKAPeer { return fp } -// NetworkLockInit enables network-lock for the tailnet, with the tailnets' +// NetworkLockInit enables tailnet-lock for the tailnet, with the tailnets' // key authority initialized to trust the provided keys. // // Initialization involves two RPCs with control, termed 'begin' and 'finish'. @@ -628,7 +632,7 @@ func tkaStateFromPeer(p tailcfg.NodeView) ipnstate.TKAPeer { // encodes the initial state of the authority, and the list of all nodes // needing signatures is returned as a response. // The Finish RPC submits signatures for all these nodes, at which point -// Control has everything it needs to atomically enable network lock. +// Control has everything it needs to atomically enable tailnet lock. func (b *LocalBackend) NetworkLockInit(keys []tka.Key, disablementValues [][]byte, supportDisablement []byte) error { var ourNodeKey key.NodePublic var nlPriv key.NLPrivate @@ -663,7 +667,7 @@ func (b *LocalBackend) NetworkLockInit(keys []tka.Key, disablementValues [][]byt return fmt.Errorf("tka.Create: %v", err) } - b.logf("Generated genesis AUM to initialize network lock, trusting the following keys:") + b.logf("Generated genesis AUM to initialize tailnet lock, trusting the following keys:") for i, k := range genesisAUM.State.Keys { b.logf(" - key[%d] = tlpub:%x with %d votes", i, k.Public, k.Votes) } @@ -678,7 +682,7 @@ func (b *LocalBackend) NetworkLockInit(keys []tka.Key, disablementValues [][]byt // node-key signatures, we need to sign keys for all the existing nodes. // If we don't get these signatures ahead of time, everyone will lose // connectivity because control won't have any signatures to send which - // satisfy network-lock checks. + // satisfy tailnet-lock checks. sigs := make(map[tailcfg.NodeID]tkatype.MarshaledSignature, len(initResp.NeedSignatures)) for _, nodeInfo := range initResp.NeedSignatures { nks, err := signNodeKey(nodeInfo, nlPriv) @@ -703,6 +707,7 @@ func (b *LocalBackend) NetworkLockAllowed() bool { // Only use is in tests. func (b *LocalBackend) NetworkLockVerifySignatureForTest(nks tkatype.MarshaledSignature, nodeKey key.NodePublic) error { + testenv.AssertInTest() b.mu.Lock() defer b.mu.Unlock() if b.tka == nil { @@ -713,10 +718,11 @@ func (b *LocalBackend) NetworkLockVerifySignatureForTest(nks tkatype.MarshaledSi // Only use is in tests. func (b *LocalBackend) NetworkLockKeyTrustedForTest(keyID tkatype.KeyID) bool { + testenv.AssertInTest() b.mu.Lock() defer b.mu.Unlock() if b.tka == nil { - panic("network lock not initialized") + panic("tailnet lock not initialized") } return b.tka.authority.KeyTrusted(keyID) } @@ -790,7 +796,7 @@ func (b *LocalBackend) NetworkLockSign(nodeKey key.NodePublic, rotationPublic [] return err } - b.logf("Generated network-lock signature for %v, submitting to control plane", nodeKey) + b.logf("Generated tailnet-lock signature for %v, submitting to control plane", nodeKey) if _, err := b.tkaSubmitSignature(ourNodeKey, sig.Serialize()); err != nil { return err } @@ -877,7 +883,7 @@ func (b *LocalBackend) NetworkLockModify(addKeys, removeKeys []tka.Key) (err err return nil } -// NetworkLockDisable disables network-lock using the provided disablement secret. +// NetworkLockDisable disables tailnet-lock using the provided disablement secret. func (b *LocalBackend) NetworkLockDisable(secret []byte) error { var ( ourNodeKey key.NodePublic @@ -1482,3 +1488,24 @@ func (b *LocalBackend) tkaReadAffectedSigs(ourNodeKey key.NodePublic, key tkatyp return a, nil } + +// LocalBackendWithTKAForTest creates a LocalBackend with an initialized TKA +// state for testing tailnet lock from the feature/tailnetlock package. Will be +// removed when tailnet lock is fully moved to its own package. Do not use this +// from any other package. +func LocalBackendWithTKAForTest(chonk tka.CompactableChonk, tka *tka.Authority) *LocalBackend { + testenv.AssertInTest() + + var state *tkaState + if tka != nil { + state = &tkaState{ + authority: tka, + storage: chonk, + } + } + return &LocalBackend{ + store: &mem.Store{}, + logf: logger.Discard, + tka: state, + } +} diff --git a/ipn/ipnlocal/network-lock_test.go b/ipn/ipnlocal/network-lock_test.go index eead2d892..1fceb748a 100644 --- a/ipn/ipnlocal/network-lock_test.go +++ b/ipn/ipnlocal/network-lock_test.go @@ -84,7 +84,29 @@ func fakeNoiseServer(t *testing.T, handler http.HandlerFunc) (*httptest.Server, return ts, client } +// newLocalBackendForTKA creates a new instance of [LocalBackend] for testing +// Tailnet Lock, in particular setting the tka field. +func newLocalBackendForTKA(t *testing.T, varRoot string, client *http.Client, pm *profileManager, authority *tka.Authority, chonk tka.CompactableChonk) LocalBackend { + t.Helper() + cc := fakeControlClient(t, client) + return LocalBackend{ + varRoot: varRoot, + cc: cc, + ccAuto: cc, + logf: t.Logf, + health: health.NewTracker(eventbustest.NewBus(t)), + tka: &tkaState{ + profile: pm.CurrentProfile().ID(), + authority: authority, + storage: chonk, + }, + pm: pm, + store: pm.Store(), + } +} + func setupProfileManager(t *testing.T, nodePriv key.NodePrivate, nlPriv key.NLPrivate) *profileManager { + t.Helper() pm := must.Get(newProfileManager(new(mem.Store), t.Logf, health.NewTracker(eventbustest.NewBus(t)))) must.Do(pm.SetPrefs((&ipn.Prefs{ Persist: &persist.Persist{ @@ -95,6 +117,18 @@ func setupProfileManager(t *testing.T, nodePriv key.NodePrivate, nlPriv key.NLPr return pm } +// setupChonkStorage creates a new [tka.FS] in a temporary folder. +func setupChonkStorage(t *testing.T, pm *profileManager) (varRoot string, chonk *tka.FS) { + varRoot = t.TempDir() + tkaPath := filepath.Join(varRoot, "tka-profile", string(pm.CurrentProfile().ID())) + os.Mkdir(tkaPath, 0755) + chonk, err := tka.ChonkDir(tkaPath) + if err != nil { + t.Fatal(err) + } + return varRoot, chonk +} + func TestTKAEnablementFlow(t *testing.T) { nodePriv := key.NewNode() @@ -102,11 +136,9 @@ func TestTKAEnablementFlow(t *testing.T) { // our mock server can communicate. nlPriv := key.NewNLPrivate() key := tka.Key{Kind: tka.Key25519, Public: nlPriv.Public().Verifier(), Votes: 2} + state := tka.CreateStateForTest(key) chonk := tka.ChonkMem() - a1, genesisAUM, err := tka.Create(chonk, tka.State{ - Keys: []tka.Key{key}, - DisablementValues: [][]byte{bytes.Repeat([]byte{0xa5}, 32)}, - }, nlPriv) + a1, genesisAUM, err := tka.Create(chonk, state, nlPriv) if err != nil { t.Fatalf("tka.Create() failed: %v", err) } @@ -188,13 +220,7 @@ func TestTKADisablementFlow(t *testing.T) { pm := setupProfileManager(t, nodePriv, nlPriv) - temp := t.TempDir() - tkaPath := filepath.Join(temp, "tka-profile", string(pm.CurrentProfile().ID())) - os.Mkdir(tkaPath, 0755) - chonk, err := tka.ChonkDir(tkaPath) - if err != nil { - t.Fatal(err) - } + varRoot, chonk := setupChonkStorage(t, pm) authority, _, err := tka.Create(chonk, tka.State{ Keys: []tka.Key{key}, DisablementValues: [][]byte{tka.DisablementKDF(disablementSecret)}, @@ -239,20 +265,7 @@ func TestTKADisablementFlow(t *testing.T) { })) defer ts.Close() - cc := fakeControlClient(t, client) - b := LocalBackend{ - varRoot: temp, - cc: cc, - ccAuto: cc, - logf: t.Logf, - health: health.NewTracker(eventbustest.NewBus(t)), - tka: &tkaState{ - authority: authority, - storage: chonk, - }, - pm: pm, - store: pm.Store(), - } + b := newLocalBackendForTKA(t, varRoot, client, pm, authority, chonk) // Test that the wrong disablement secret does not shut down the authority. returnWrongSecret = true @@ -289,8 +302,6 @@ func TestTKASync(t *testing.T) { someKeyPriv := key.NewNLPrivate() someKey := tka.Key{Kind: tka.Key25519, Public: someKeyPriv.Public().Verifier(), Votes: 1} - disablementSecret := bytes.Repeat([]byte{0xa5}, 32) - type tkaSyncScenario struct { name string // controlAUMs is called (if non-nil) to get any AUMs which the tka state @@ -369,10 +380,8 @@ func TestTKASync(t *testing.T) { // Setup the tka authority on the control plane. key := tka.Key{Kind: tka.Key25519, Public: nlPriv.Public().Verifier(), Votes: 2} controlStorage := tka.ChonkMem() - controlAuthority, bootstrap, err := tka.Create(controlStorage, tka.State{ - Keys: []tka.Key{key, someKey}, - DisablementValues: [][]byte{tka.DisablementKDF(disablementSecret)}, - }, nlPriv) + controlState := tka.CreateStateForTest(key, someKey) + controlAuthority, bootstrap, err := tka.Create(controlStorage, controlState, nlPriv) if err != nil { t.Fatalf("tka.Create() failed: %v", err) } @@ -382,14 +391,8 @@ func TestTKASync(t *testing.T) { } } - temp := t.TempDir() - tkaPath := filepath.Join(temp, "tka-profile", string(pm.CurrentProfile().ID())) - os.Mkdir(tkaPath, 0755) // Setup the TKA authority on the node. - nodeStorage, err := tka.ChonkDir(tkaPath) - if err != nil { - t.Fatal(err) - } + varRoot, nodeStorage := setupChonkStorage(t, pm) nodeAuthority, err := tka.Bootstrap(nodeStorage, bootstrap) if err != nil { t.Fatalf("tka.Bootstrap() failed: %v", err) @@ -424,20 +427,7 @@ func TestTKASync(t *testing.T) { defer ts.Close() // Setup the client. - cc := fakeControlClient(t, client) - b := LocalBackend{ - varRoot: temp, - cc: cc, - ccAuto: cc, - logf: t.Logf, - health: health.NewTracker(eventbustest.NewBus(t)), - pm: pm, - store: pm.Store(), - tka: &tkaState{ - authority: nodeAuthority, - storage: nodeStorage, - }, - } + b := newLocalBackendForTKA(t, varRoot, client, pm, nodeAuthority, nodeStorage) // Finally, let's trigger a sync. err = b.tkaSyncIfNeeded(&netmap.NetworkMap{ @@ -463,8 +453,6 @@ func TestTKASyncTriggersCompact(t *testing.T) { someKeyPriv := key.NewNLPrivate() someKey := tka.Key{Kind: tka.Key25519, Public: someKeyPriv.Public().Verifier(), Votes: 1} - disablementSecret := bytes.Repeat([]byte{0xa5}, 32) - nodePriv := key.NewNode() nlPriv := key.NewNLPrivate() pm := setupProfileManager(t, nodePriv, nlPriv) @@ -480,10 +468,8 @@ func TestTKASyncTriggersCompact(t *testing.T) { key := tka.Key{Kind: tka.Key25519, Public: nlPriv.Public().Verifier(), Votes: 2} controlStorage := tka.ChonkMem() controlStorage.SetClock(clock) - controlAuthority, bootstrap, err := tka.Create(controlStorage, tka.State{ - Keys: []tka.Key{key, someKey}, - DisablementValues: [][]byte{tka.DisablementKDF(disablementSecret)}, - }, nlPriv) + controlState := tka.CreateStateForTest(key, someKey) + controlAuthority, bootstrap, err := tka.Create(controlStorage, controlState, nlPriv) if err != nil { t.Fatalf("tka.Create() failed: %v", err) } @@ -542,19 +528,8 @@ func TestTKASyncTriggersCompact(t *testing.T) { defer ts.Close() // Setup the client. - cc := fakeControlClient(t, client) - b := LocalBackend{ - cc: cc, - ccAuto: cc, - logf: t.Logf, - health: health.NewTracker(eventbustest.NewBus(t)), - pm: pm, - store: pm.Store(), - tka: &tkaState{ - authority: nodeAuthority, - storage: nodeStorage, - }, - } + varRoot := "" + b := newLocalBackendForTKA(t, varRoot, client, pm, nodeAuthority, nodeStorage) // Trigger a sync. err = b.tkaSyncIfNeeded(&netmap.NetworkMap{ @@ -610,11 +585,9 @@ func TestTKASyncTriggersCompact(t *testing.T) { func TestTKAFilterNetmap(t *testing.T) { nlPriv := key.NewNLPrivate() nlKey := tka.Key{Kind: tka.Key25519, Public: nlPriv.Public().Verifier(), Votes: 2} + state := tka.CreateStateForTest(nlKey) storage := tka.ChonkMem() - authority, _, err := tka.Create(storage, tka.State{ - Keys: []tka.Key{nlKey}, - DisablementValues: [][]byte{bytes.Repeat([]byte{0xa5}, 32)}, - }, nlPriv) + authority, _, err := tka.Create(storage, state, nlPriv) if err != nil { t.Fatalf("tka.Create() failed: %v", err) } @@ -764,17 +737,11 @@ func TestTKADisable(t *testing.T) { // Make a fake TKA authority, to seed local state. disablementSecret := bytes.Repeat([]byte{0xa5}, 32) nlPriv := key.NewNLPrivate() + key := tka.Key{Kind: tka.Key25519, Public: nlPriv.Public().Verifier(), Votes: 2} pm := setupProfileManager(t, nodePriv, nlPriv) - temp := t.TempDir() - tkaPath := filepath.Join(temp, "tka-profile", string(pm.CurrentProfile().ID())) - os.Mkdir(tkaPath, 0755) - key := tka.Key{Kind: tka.Key25519, Public: nlPriv.Public().Verifier(), Votes: 2} - chonk, err := tka.ChonkDir(tkaPath) - if err != nil { - t.Fatal(err) - } + temp, chonk := setupChonkStorage(t, pm) authority, _, err := tka.Create(chonk, tka.State{ Keys: []tka.Key{key}, DisablementValues: [][]byte{tka.DisablementKDF(disablementSecret)}, @@ -821,21 +788,7 @@ func TestTKADisable(t *testing.T) { })) defer ts.Close() - cc := fakeControlClient(t, client) - b := LocalBackend{ - varRoot: temp, - cc: cc, - ccAuto: cc, - logf: t.Logf, - health: health.NewTracker(eventbustest.NewBus(t)), - tka: &tkaState{ - profile: pm.CurrentProfile().ID(), - authority: authority, - storage: chonk, - }, - pm: pm, - store: pm.Store(), - } + b := newLocalBackendForTKA(t, temp, client, pm, authority, chonk) // Test that we get an error for an incorrect disablement secret. if err := b.NetworkLockDisable([]byte{1, 2, 3, 4}); err == nil || err.Error() != "incorrect disablement secret" { @@ -854,20 +807,11 @@ func TestTKASign(t *testing.T) { pm := setupProfileManager(t, nodePriv, nlPriv) // Make a fake TKA authority, to seed local state. - disablementSecret := bytes.Repeat([]byte{0xa5}, 32) key := tka.Key{Kind: tka.Key25519, Public: nlPriv.Public().Verifier(), Votes: 2} + state := tka.CreateStateForTest(key) - temp := t.TempDir() - tkaPath := filepath.Join(temp, "tka-profile", string(pm.CurrentProfile().ID())) - os.Mkdir(tkaPath, 0755) - chonk, err := tka.ChonkDir(tkaPath) - if err != nil { - t.Fatal(err) - } - authority, _, err := tka.Create(chonk, tka.State{ - Keys: []tka.Key{key}, - DisablementValues: [][]byte{tka.DisablementKDF(disablementSecret)}, - }, nlPriv) + varRoot, chonk := setupChonkStorage(t, pm) + authority, _, err := tka.Create(chonk, state, nlPriv) if err != nil { t.Fatalf("tka.Create() failed: %v", err) } @@ -887,20 +831,8 @@ func TestTKASign(t *testing.T) { } })) defer ts.Close() - cc := fakeControlClient(t, client) - b := LocalBackend{ - varRoot: temp, - cc: cc, - ccAuto: cc, - logf: t.Logf, - health: health.NewTracker(eventbustest.NewBus(t)), - tka: &tkaState{ - authority: authority, - storage: chonk, - }, - pm: pm, - store: pm.Store(), - } + + b := newLocalBackendForTKA(t, varRoot, client, pm, authority, chonk) if err := b.NetworkLockSign(toSign.Public(), nil); err != nil { t.Errorf("NetworkLockSign() failed: %v", err) @@ -911,23 +843,14 @@ func TestTKAForceDisable(t *testing.T) { nodePriv := key.NewNode() // Make a fake TKA authority, to seed local state. - disablementSecret := bytes.Repeat([]byte{0xa5}, 32) nlPriv := key.NewNLPrivate() key := tka.Key{Kind: tka.Key25519, Public: nlPriv.Public().Verifier(), Votes: 2} + state := tka.CreateStateForTest(key) pm := setupProfileManager(t, nodePriv, nlPriv) - temp := t.TempDir() - tkaPath := filepath.Join(temp, "tka-profile", string(pm.CurrentProfile().ID())) - os.Mkdir(tkaPath, 0755) - chonk, err := tka.ChonkDir(tkaPath) - if err != nil { - t.Fatal(err) - } - authority, genesis, err := tka.Create(chonk, tka.State{ - Keys: []tka.Key{key}, - DisablementValues: [][]byte{tka.DisablementKDF(disablementSecret)}, - }, nlPriv) + temp, chonk := setupChonkStorage(t, pm) + authority, genesis, err := tka.Create(chonk, state, nlPriv) if err != nil { t.Fatalf("tka.Create() failed: %v", err) } @@ -1002,20 +925,11 @@ func TestTKAAffectedSigs(t *testing.T) { pm := setupProfileManager(t, nodePriv, nlPriv) // Make a fake TKA authority, to seed local state. - disablementSecret := bytes.Repeat([]byte{0xa5}, 32) tkaKey := tka.Key{Kind: tka.Key25519, Public: nlPriv.Public().Verifier(), Votes: 2} + state := tka.CreateStateForTest(tkaKey) - temp := t.TempDir() - tkaPath := filepath.Join(temp, "tka-profile", string(pm.CurrentProfile().ID())) - os.Mkdir(tkaPath, 0755) - chonk, err := tka.ChonkDir(tkaPath) - if err != nil { - t.Fatal(err) - } - authority, _, err := tka.Create(chonk, tka.State{ - Keys: []tka.Key{tkaKey}, - DisablementValues: [][]byte{tka.DisablementKDF(disablementSecret)}, - }, nlPriv) + varRoot, chonk := setupChonkStorage(t, pm) + authority, _, err := tka.Create(chonk, state, nlPriv) if err != nil { t.Fatalf("tka.Create() failed: %v", err) } @@ -1084,20 +998,7 @@ func TestTKAAffectedSigs(t *testing.T) { } })) defer ts.Close() - cc := fakeControlClient(t, client) - b := LocalBackend{ - varRoot: temp, - cc: cc, - ccAuto: cc, - logf: t.Logf, - health: health.NewTracker(eventbustest.NewBus(t)), - tka: &tkaState{ - authority: authority, - storage: chonk, - }, - pm: pm, - store: pm.Store(), - } + b := newLocalBackendForTKA(t, varRoot, client, pm, authority, chonk) sigs, err := b.NetworkLockAffectedSigs(nlPriv.KeyID()) switch { @@ -1130,22 +1031,13 @@ func TestTKARecoverCompromisedKeyFlow(t *testing.T) { pm := setupProfileManager(t, nodePriv, nlPriv) // Make a fake TKA authority, to seed local state. - disablementSecret := bytes.Repeat([]byte{0xa5}, 32) key := tka.Key{Kind: tka.Key25519, Public: nlPriv.Public().Verifier(), Votes: 2} cosignKey := tka.Key{Kind: tka.Key25519, Public: cosignPriv.Public().Verifier(), Votes: 2} compromisedKey := tka.Key{Kind: tka.Key25519, Public: compromisedPriv.Public().Verifier(), Votes: 1} + state := tka.CreateStateForTest(key, compromisedKey, cosignKey) - temp := t.TempDir() - tkaPath := filepath.Join(temp, "tka-profile", string(pm.CurrentProfile().ID())) - os.Mkdir(tkaPath, 0755) - chonk, err := tka.ChonkDir(tkaPath) - if err != nil { - t.Fatal(err) - } - authority, _, err := tka.Create(chonk, tka.State{ - Keys: []tka.Key{key, compromisedKey, cosignKey}, - DisablementValues: [][]byte{tka.DisablementKDF(disablementSecret)}, - }, nlPriv) + varRoot, chonk := setupChonkStorage(t, pm) + authority, _, err := tka.Create(chonk, state, nlPriv) if err != nil { t.Fatalf("tka.Create() failed: %v", err) } @@ -1170,20 +1062,7 @@ func TestTKARecoverCompromisedKeyFlow(t *testing.T) { } })) defer ts.Close() - cc := fakeControlClient(t, client) - b := LocalBackend{ - varRoot: temp, - cc: cc, - ccAuto: cc, - logf: t.Logf, - health: health.NewTracker(eventbustest.NewBus(t)), - tka: &tkaState{ - authority: authority, - storage: chonk, - }, - pm: pm, - store: pm.Store(), - } + b := newLocalBackendForTKA(t, varRoot, client, pm, authority, chonk) aum, err := b.NetworkLockGenerateRecoveryAUM([]tkatype.KeyID{compromisedPriv.KeyID()}, tka.AUMHash{}) if err != nil { @@ -1193,17 +1072,7 @@ func TestTKARecoverCompromisedKeyFlow(t *testing.T) { // Cosign using the cosigning key. { pm := setupProfileManager(t, nodePriv, cosignPriv) - b := LocalBackend{ - varRoot: temp, - logf: t.Logf, - health: health.NewTracker(eventbustest.NewBus(t)), - tka: &tkaState{ - authority: authority, - storage: chonk, - }, - pm: pm, - store: pm.Store(), - } + b := newLocalBackendForTKA(t, varRoot, client, pm, authority, chonk) if aum, err = b.NetworkLockCosignRecoveryAUM(aum); err != nil { t.Fatalf("NetworkLockCosignRecoveryAUM() failed: %v", err) } diff --git a/ipn/ipnlocal/node_backend.go b/ipn/ipnlocal/node_backend.go index 75550b3d5..59c26ebe5 100644 --- a/ipn/ipnlocal/node_backend.go +++ b/ipn/ipnlocal/node_backend.go @@ -29,6 +29,7 @@ import ( "tailscale.com/util/eventbus" "tailscale.com/util/mak" "tailscale.com/util/slicesx" + "tailscale.com/util/testenv" "tailscale.com/wgengine/filter" ) @@ -79,6 +80,13 @@ type nodeBackend struct { eventClient *eventbus.Client derpMapViewPub *eventbus.Publisher[tailcfg.DERPMapView] + // homeDERP lives here temporarily. as long as mapSession is short lived, we + // don't have a location delivering netmaps to local backend that knows our + // homeDERP hence why it is cached here for now. + // TODO(cmol): move this field into a refactored mapSession that is not + // short lived. + homeDERP atomic.Int64 + // TODO(nickkhyl): maybe use sync.RWMutex? mu syncs.Mutex // protects the following fields @@ -107,6 +115,12 @@ type nodeBackend struct { // nodeByKey is an index of node public key to node ID for fast lookups. // It is mutated in place (with mu held) and must not escape the [nodeBackend]. nodeByKey map[key.NodePublic]tailcfg.NodeID + + // keyWaitersForTest is the test-only registry of channels waiting for + // a given peer key to first appear in the netmap. See + // [nodeBackend.AwaitNodeKeyForTest]. It is populated lazily and remains + // nil in production, where no test installs a waiter. + keyWaitersForTest map[key.NodePublic]chan struct{} } func newNodeBackend(ctx context.Context, logf logger.Logf, bus *eventbus.Bus) *nodeBackend { @@ -317,6 +331,46 @@ func (nb *nodeBackend) peerCapsLocked(src netip.Addr) tailcfg.PeerCapMap { return nil } +// PeerCapsForIP returns the capabilities that remote src IP has when +// talking to the given destination IP on this node. The destination may +// be any IP the node handles: its own tailnet address, a VIP service +// address, or any future routable IP. +func (nb *nodeBackend) PeerCapsForIP(src, dst netip.Addr) tailcfg.PeerCapMap { + nb.mu.Lock() + defer nb.mu.Unlock() + if nb.netMap == nil { + return nil + } + filt := nb.filterAtomic.Load() + if filt == nil { + return nil + } + return filt.CapsWithValues(src, dst) +} + +// PeerCapsForService returns the capabilities that remote src IP has when +// talking to the named VIP service on this node. The service name is +// resolved to its VIP addresses via the node's service IP mappings, and +// the first address matching the src IP family is used for cap lookup. +func (nb *nodeBackend) PeerCapsForService(src netip.Addr, svcName tailcfg.ServiceName) tailcfg.PeerCapMap { + nb.mu.Lock() + defer nb.mu.Unlock() + if nb.netMap == nil { + return nil + } + filt := nb.filterAtomic.Load() + if filt == nil { + return nil + } + addrs := nb.netMap.GetVIPServiceIPMap()[svcName] + for _, ip := range addrs { + if ip.BitLen() == src.BitLen() { + return filt.CapsWithValues(src, ip) + } + } + return nil +} + // PeerHasCap reports whether the peer contains the given capability string, // with any value(s). func (nb *nodeBackend) PeerHasCap(peer tailcfg.NodeView, wantCap tailcfg.PeerCapability) bool { @@ -421,6 +475,7 @@ func (nb *nodeBackend) SetNetMap(nm *netmap.NetworkMap) { nb.updateNodeByAddrLocked() nb.updateNodeByKeyLocked() nb.updatePeersLocked() + nb.signalKeyWaitersForTestLocked() if nm != nil { nb.derpMapViewPub.Publish(nm.DERPMap.View()) } else { @@ -428,6 +483,43 @@ func (nb *nodeBackend) SetNetMap(nm *netmap.NetworkMap) { } } +// AwaitNodeKeyForTest returns a channel that is closed once a peer with the +// given node key first appears in this nodeBackend's peer index, or +// immediately (a closed channel) if it's already present. It is intended for +// in-process benchmarks that drive synthetic netmap deltas and need a +// zero-overhead signal that the client has applied a delta, replacing +// poll-based [local.Client.WhoIsNodeKey] loops in tests. It panics outside +// of tests. +func (nb *nodeBackend) AwaitNodeKeyForTest(k key.NodePublic) <-chan struct{} { + testenv.AssertInTest() + nb.mu.Lock() + defer nb.mu.Unlock() + if _, ok := nb.nodeByKey[k]; ok { + return syncs.ClosedChan() + } + if ch, ok := nb.keyWaitersForTest[k]; ok { + return ch + } + ch := make(chan struct{}) + mak.Set(&nb.keyWaitersForTest, k, ch) + return ch +} + +// signalKeyWaitersForTestLocked closes any waiter channels whose keys now +// appear in nb.nodeByKey. It is cheap when there are no waiters, which is +// the common case in production. It is called from [nodeBackend.SetNetMap] +// after the per-key index has been rebuilt. +// +// Caller must hold nb.mu. +func (nb *nodeBackend) signalKeyWaitersForTestLocked() { + for k, ch := range nb.keyWaitersForTest { + if _, ok := nb.nodeByKey[k]; ok { + close(ch) + delete(nb.keyWaitersForTest, k) + } + } +} + func (nb *nodeBackend) updateNodeByAddrLocked() { nm := nb.netMap if nm == nil { @@ -864,7 +956,7 @@ func dnsConfigForNetmap(nm *netmap.NetworkMap, peers map[tailcfg.NodeID]tailcfg. addSplitDNSRoutes(nm.DNS.Routes) // Add split DNS routes for conn25 - conn25DNSTargets := appc.PickSplitDNSPeers(nm.HasCap, nm.SelfNode, peers) + conn25DNSTargets := appc.PickSplitDNSPeers(nm.HasCap, nm.SelfNode, peers, prefs.AppConnector().Advertise) if conn25DNSTargets != nil { var m map[string][]*dnstype.Resolver for domain, candidateSplitDNSPeers := range conn25DNSTargets { diff --git a/ipn/ipnlocal/peerapi.go b/ipn/ipnlocal/peerapi.go index 322884fc7..d72a519ab 100644 --- a/ipn/ipnlocal/peerapi.go +++ b/ipn/ipnlocal/peerapi.go @@ -192,7 +192,7 @@ func (pln *peerAPIListener) ServeConn(src netip.AddrPort, c net.Conn) { c.Close() return } - nm := pln.lb.NetMap() + nm := pln.lb.NetMapNoPeers() if nm == nil || !nm.SelfNode.Valid() { logf("peerapi: no netmap") c.Close() diff --git a/ipn/ipnlocal/serve.go b/ipn/ipnlocal/serve.go index 9460896ad..83b8027d7 100644 --- a/ipn/ipnlocal/serve.go +++ b/ipn/ipnlocal/serve.go @@ -276,7 +276,7 @@ func (b *LocalBackend) updateServeTCPPortNetMapAddrListenersLocked(ports []uint1 } } - nm := b.NetMap() + nm := b.NetMapNoPeers() if nm == nil { b.logf("netMap is nil") return @@ -333,7 +333,7 @@ func (b *LocalBackend) setServeConfigLocked(config *ipn.ServeConfig, etag string return errors.New("can't reconfigure tailscaled when using a config file; config file is locked") } - nm := b.NetMap() + nm := b.NetMapNoPeers() if nm == nil { return errors.New("netMap is nil") } diff --git a/ipn/ipnlocal/state_test.go b/ipn/ipnlocal/state_test.go index 17e93a430..104c29a3f 100644 --- a/ipn/ipnlocal/state_test.go +++ b/ipn/ipnlocal/state_test.go @@ -136,6 +136,7 @@ type mockControl struct { calls []string authBlocked bool shutdown chan struct{} + loginFlags controlclient.LoginFlags hi *tailcfg.Hostinfo } @@ -273,6 +274,7 @@ func (cc *mockControl) Login(flags controlclient.LoginFlags) { cc.mu.Lock() defer cc.mu.Unlock() cc.authBlocked = interact || newKeys + cc.loginFlags |= flags } func (cc *mockControl) Logout(ctx context.Context) error { @@ -371,14 +373,6 @@ func (b *LocalBackend) nonInteractiveLoginForStateTest() { // predictable, but maybe a bit less thorough. This is more of an overall // state machine test than a test of the wgengine+magicsock integration. func TestStateMachine(t *testing.T) { - runTestStateMachine(t, false) -} - -func TestStateMachineSeamless(t *testing.T) { - runTestStateMachine(t, true) -} - -func runTestStateMachine(t *testing.T, seamless bool) { envknob.Setenv("TAILSCALE_USE_WIP_CODE", "1") defer envknob.Setenv("TAILSCALE_USE_WIP_CODE", "") c := qt.New(t) @@ -588,12 +582,6 @@ func runTestStateMachine(t *testing.T, seamless bool) { cc.persist.UserProfile.LoginName = "user1" cc.persist.NodeID = "node1" - // even if seamless is being enabled by default rather than by policy, this is - // the point where it will first get enabled. - if seamless { - sys.ControlKnobs().SeamlessKeyRenewal.Store(true) - } - cc.send(sendOpt{loginFinished: true, nm: &netmap.NetworkMap{}}) { nn := notifies.drain(3) @@ -696,6 +684,7 @@ func runTestStateMachine(t *testing.T, seamless bool) { notifies.expect(5) b.Logout(context.Background(), ipnauth.Self) { + b.awaitNoGoroutinesInTest() nn := notifies.drain(5) previousCC.assertCalls("pause", "Logout", "unpause", "Shutdown") // nn[0] is state notification (Stopped) @@ -873,7 +862,9 @@ func runTestStateMachine(t *testing.T, seamless bool) { // additional netmap updates. Since our LocalBackend instance already // has a netmap, we will reset it to nil to simulate the first netmap // retrieval. + b.mu.Lock() b.setNetMapLocked(nil) + b.mu.Unlock() cc.assertCalls("unpause") // // TODO: really the various GUIs and prefs should be refactored to @@ -1052,6 +1043,7 @@ func runTestStateMachine(t *testing.T, seamless bool) { } notifies.expect(1) // Fake a DERP connection. + b.awaitNoGoroutinesInTest() b.setWgengineStatus(&wgengine.Status{DERPs: 1, AsOf: time.Now()}, nil) { nn := notifies.drain(1) @@ -1474,11 +1466,23 @@ func TestEngineReconfigOnStateChange(t *testing.T) { lb.StartLoginInteractive(context.Background()) cc().sendAuthURL(node1) }, - // Without seamless renewal, even starting a reauth tears down everything: - wantState: ipn.Starting, - wantCfg: &wgcfg.Config{}, - wantRouterCfg: &router.Config{}, - wantDNSCfg: &dns.Config{}, + // Starting a reauth should leave everything up: + wantState: ipn.Starting, + wantCfg: &wgcfg.Config{ + Peers: []wgcfg.Peer{}, + Addresses: node1.SelfNode.Addresses().AsSlice(), + }, + wantRouterCfg: &router.Config{ + SNATSubnetRoutes: true, + NetfilterMode: preftype.NetfilterOn, + LocalAddrs: node1.SelfNode.Addresses().AsSlice(), + Routes: routesWithQuad100(), + }, + wantDNSCfg: &dns.Config{ + AcceptDNS: true, + Routes: map[dnsname.FQDN][]*dnstype.Resolver{}, + Hosts: hostsFor(node1), + }, }, { name: "Start/Connect/Login/InitReauth/Login", @@ -1512,71 +1516,8 @@ func TestEngineReconfigOnStateChange(t *testing.T) { }, }, { - name: "Seamless/Start/Connect/Login/InitReauth", + name: "Start/Connect/Login/Expire", steps: func(t *testing.T, lb *LocalBackend, cc func() *mockControl) { - lb.ControlKnobs().SeamlessKeyRenewal.Store(true) - mustDo(t)(lb.Start(ipn.Options{})) - mustDo2(t)(lb.EditPrefs(connect)) - cc().authenticated(node1) - - // Start the re-auth process: - lb.StartLoginInteractive(context.Background()) - cc().sendAuthURL(node1) - }, - // With seamless renewal, starting a reauth should leave everything up: - wantState: ipn.Starting, - wantCfg: &wgcfg.Config{ - Peers: []wgcfg.Peer{}, - Addresses: node1.SelfNode.Addresses().AsSlice(), - }, - wantRouterCfg: &router.Config{ - SNATSubnetRoutes: true, - NetfilterMode: preftype.NetfilterOn, - LocalAddrs: node1.SelfNode.Addresses().AsSlice(), - Routes: routesWithQuad100(), - }, - wantDNSCfg: &dns.Config{ - AcceptDNS: true, - Routes: map[dnsname.FQDN][]*dnstype.Resolver{}, - Hosts: hostsFor(node1), - }, - }, - { - name: "Seamless/Start/Connect/Login/InitReauth/Login", - steps: func(t *testing.T, lb *LocalBackend, cc func() *mockControl) { - lb.ControlKnobs().SeamlessKeyRenewal.Store(true) - mustDo(t)(lb.Start(ipn.Options{})) - mustDo2(t)(lb.EditPrefs(connect)) - cc().authenticated(node1) - - // Start the re-auth process: - lb.StartLoginInteractive(context.Background()) - cc().sendAuthURL(node1) - - // Complete the re-auth process: - cc().authenticated(node1) - }, - wantState: ipn.Starting, - wantCfg: &wgcfg.Config{ - Peers: []wgcfg.Peer{}, - Addresses: node1.SelfNode.Addresses().AsSlice(), - }, - wantRouterCfg: &router.Config{ - SNATSubnetRoutes: true, - NetfilterMode: preftype.NetfilterOn, - LocalAddrs: node1.SelfNode.Addresses().AsSlice(), - Routes: routesWithQuad100(), - }, - wantDNSCfg: &dns.Config{ - AcceptDNS: true, - Routes: map[dnsname.FQDN][]*dnstype.Resolver{}, - Hosts: hostsFor(node1), - }, - }, - { - name: "Seamless/Start/Connect/Login/Expire", - steps: func(t *testing.T, lb *LocalBackend, cc func() *mockControl) { - lb.ControlKnobs().SeamlessKeyRenewal.Store(true) mustDo(t)(lb.Start(ipn.Options{})) mustDo2(t)(lb.EditPrefs(connect)) cc().authenticated(node1) @@ -1586,7 +1527,7 @@ func TestEngineReconfigOnStateChange(t *testing.T) { }).View(), }}) }, - // Even with seamless, if the key we are using expires, we want to disconnect: + // If the key we are using expires, we want to disconnect: wantState: ipn.NeedsLogin, wantCfg: &wgcfg.Config{}, wantRouterCfg: &router.Config{}, @@ -1635,14 +1576,6 @@ func TestEngineReconfigOnStateChange(t *testing.T) { // TestSendPreservesAuthURL tests that wgengine updates arriving in the middle of // processing an auth URL doesn't result in the auth URL being cleared. func TestSendPreservesAuthURL(t *testing.T) { - runTestSendPreservesAuthURL(t, false) -} - -func TestSendPreservesAuthURLSeamless(t *testing.T) { - runTestSendPreservesAuthURL(t, true) -} - -func runTestSendPreservesAuthURL(t *testing.T, seamless bool) { var cc *mockControl b := newLocalBackendWithTestControl(t, true, func(tb testing.TB, opts controlclient.Options) controlclient.Client { cc = newClient(t, opts) @@ -1661,10 +1594,6 @@ func runTestSendPreservesAuthURL(t *testing.T, seamless bool) { cc.persist.UserProfile.LoginName = "user1" cc.persist.NodeID = "node1" - if seamless { - b.sys.ControlKnobs().SeamlessKeyRenewal.Store(true) - } - cc.send(sendOpt{loginFinished: true, nm: &netmap.NetworkMap{ SelfNode: (&tailcfg.Node{MachineAuthorized: true}).View(), }}) @@ -2009,6 +1938,8 @@ func (e *mockEngine) Ping(ip netip.Addr, pingType tailcfg.PingType, size int, cb func (e *mockEngine) InstallCaptureHook(packet.CaptureCallback) {} +func (e *mockEngine) SetPeerByIPPacketFunc(func(netip.Addr) (_ key.NodePublic, ok bool)) {} + func (e *mockEngine) Close() { e.mu.Lock() defer e.mu.Unlock() diff --git a/ipn/ipnlocal/web_client.go b/ipn/ipnlocal/web_client.go index 37dba93d0..6ab68858e 100644 --- a/ipn/ipnlocal/web_client.go +++ b/ipn/ipnlocal/web_client.go @@ -173,7 +173,7 @@ func (b *LocalBackend) waitWebClientAuthURL(ctx context.Context, id string, src // one to be completed, based on the presence or absence of the // provided id value. func (b *LocalBackend) doWebClientNoiseRequest(ctx context.Context, id string, src tailcfg.NodeID) (*tailcfg.WebClientAuthResponse, error) { - nm := b.NetMap() + nm := b.NetMapNoPeers() if nm == nil || !nm.SelfNode.Valid() { return nil, errors.New("[unexpected] no self node") } diff --git a/ipn/ipnstate/ipnstate.go b/ipn/ipnstate/ipnstate.go index 17e6ac870..f7df7e5a2 100644 --- a/ipn/ipnstate/ipnstate.go +++ b/ipn/ipnstate/ipnstate.go @@ -86,7 +86,7 @@ type Status struct { ClientVersion *tailcfg.ClientVersion } -// TKAKey describes a key trusted by network lock. +// TKAKey describes a key trusted by tailnet lock. type TKAKey struct { Kind string Key key.NLPublic @@ -94,7 +94,7 @@ type TKAKey struct { Votes uint } -// TKAPeer describes a peer and its network lock details. +// TKAPeer describes a peer and its tailnet lock details. type TKAPeer struct { Name string // DNS ID tailcfg.NodeID @@ -104,7 +104,7 @@ type TKAPeer struct { NodeKeySignature tka.NodeKeySignature } -// NetworkLockStatus represents whether network-lock is enabled, +// NetworkLockStatus represents whether tailnet-lock is enabled, // along with details about the locally-known state of the tailnet // key authority. type NetworkLockStatus struct { @@ -115,7 +115,7 @@ type NetworkLockStatus struct { // if network lock is not enabled. Head *[32]byte - // PublicKey describes the node's network-lock public key. + // PublicKey describes the node's tailnet-lock public key. // It may be zero if the node has not logged in. PublicKey key.NLPublic @@ -123,14 +123,14 @@ type NetworkLockStatus struct { // populated if the node is not operating (i.e. waiting for a login). NodeKey *key.NodePublic - // NodeKeySigned is true if our node is authorized by network-lock. + // NodeKeySigned is true if our node is authorized by tailnet-lock. NodeKeySigned bool // NodeKeySignature is the current signature of this node's key. NodeKeySignature *tka.NodeKeySignature // TrustedKeys describes the keys currently trusted to make changes - // to network-lock. + // to tailnet-lock. TrustedKeys []TKAKey // VisiblePeers describes peers which are visible in the netmap that @@ -148,7 +148,7 @@ type NetworkLockStatus struct { StateID uint64 } -// NetworkLockUpdate describes a change to network-lock state. +// NetworkLockUpdate describes a change to tailnet-lock state. type NetworkLockUpdate struct { Hash [32]byte Change string // values of tka.AUMKind.String() diff --git a/ipn/localapi/debug.go b/ipn/localapi/debug.go index d8e46040d..6f222bef0 100644 --- a/ipn/localapi/debug.go +++ b/ipn/localapi/debug.go @@ -9,6 +9,7 @@ import ( "cmp" "context" "encoding/json" + "errors" "fmt" "io" "net" @@ -232,6 +233,12 @@ func (h *Handler) serveDebug(w http.ResponseWriter, r *http.Request) { if err == nil { return } + case "peer-disco-keys": + w.Header().Set("Content-Type", "application/json") + err = json.NewEncoder(w).Encode(h.b.DebugPeerDiscoKeys()) + if err == nil { + return + } case "rotate-disco-key": err = h.b.DebugRotateDiscoKey() case "statedir": @@ -243,6 +250,22 @@ func (h *Handler) serveDebug(w http.ResponseWriter, r *http.Request) { } case "clear-netmap-cache": h.b.ClearNetmapCache(r.Context()) + case "current-netmap": + // Return the current netmap (with peers populated) as JSON. This + // is a debug-only path: the netmap.NetworkMap shape is an + // internal type and may change without notice. Production + // callers should fetch the narrower bits they need via their + // own LocalAPI methods instead. + nm := h.b.NetMapWithPeers() + if nm == nil { + err = errors.New("no netmap") + break + } + w.Header().Set("Content-Type", "application/json") + err = json.NewEncoder(w).Encode(nm) + if err == nil { + return + } case "": err = fmt.Errorf("missing parameter 'action'") default: @@ -278,7 +301,7 @@ func (h *Handler) serveDebugPacketFilterRules(w http.ResponseWriter, r *http.Req http.Error(w, "debug access denied", http.StatusForbidden) return } - nm := h.b.NetMap() + nm := h.b.NetMapNoPeers() if nm == nil { http.Error(w, "no netmap", http.StatusNotFound) return @@ -295,7 +318,7 @@ func (h *Handler) serveDebugPacketFilterMatches(w http.ResponseWriter, r *http.R http.Error(w, "debug access denied", http.StatusForbidden) return } - nm := h.b.NetMap() + nm := h.b.NetMapNoPeers() if nm == nil { http.Error(w, "no netmap", http.StatusNotFound) return diff --git a/ipn/localapi/localapi.go b/ipn/localapi/localapi.go index 43942c52f..9d4977e48 100644 --- a/ipn/localapi/localapi.go +++ b/ipn/localapi/localapi.go @@ -72,16 +72,20 @@ var handler = map[string]LocalAPIHandler{ // The other /localapi/v0/NAME handlers are exact matches and contain only NAME // without a trailing slash: + "cert-domains": (*Handler).serveCertDomains, "check-prefs": (*Handler).serveCheckPrefs, "check-so-mark-in-use": (*Handler).serveCheckSOMarkInUse, "derpmap": (*Handler).serveDERPMap, + "dns-config": (*Handler).serveDNSConfig, "goroutines": (*Handler).serveGoroutines, "login-interactive": (*Handler).serveLoginInteractive, "logout": (*Handler).serveLogout, + "peer-by-id": (*Handler).servePeerByID, "ping": (*Handler).servePing, "prefs": (*Handler).servePrefs, "reload-config": (*Handler).reloadConfig, "reset-auth": (*Handler).serveResetAuth, + "services": (*Handler).serveServices, "set-expiry-sooner": (*Handler).serveSetExpirySooner, "shutdown": (*Handler).serveShutdown, "start": (*Handler).serveStart, @@ -346,7 +350,7 @@ func (h *Handler) serveIDToken(w http.ResponseWriter, r *http.Request) { http.Error(w, "id-token access denied", http.StatusForbidden) return } - nm := h.b.NetMap() + nm := h.b.NetMapNoPeers() if nm == nil { http.Error(w, "no netmap", http.StatusServiceUnavailable) return @@ -416,7 +420,7 @@ func (h *Handler) serveBugReport(w http.ResponseWriter, r *http.Request) { } // Information about the current node from the netmap - if nm := h.b.NetMap(); nm != nil { + if nm := h.b.NetMapNoPeers(); nm != nil { if self := nm.SelfNode; self.Valid() { h.logf("user bugreport node info: nodeid=%q stableid=%q expiry=%q", self.ID(), self.StableID(), self.KeyExpiry().Format(time.RFC3339)) } @@ -541,6 +545,8 @@ type localBackendWhoIsMethods interface { WhoIs(string, netip.AddrPort) (n tailcfg.NodeView, u tailcfg.UserProfile, ok bool) WhoIsNodeKey(key.NodePublic) (n tailcfg.NodeView, u tailcfg.UserProfile, ok bool) PeerCaps(netip.Addr) tailcfg.PeerCapMap + PeerCapsForIP(src, dst netip.Addr) tailcfg.PeerCapMap + PeerCapsForService(src netip.Addr, svcName tailcfg.ServiceName) tailcfg.PeerCapMap } func (h *Handler) serveWhoIsWithBackend(w http.ResponseWriter, r *http.Request, b localBackendWhoIsMethods) { @@ -588,7 +594,25 @@ func (h *Handler) serveWhoIsWithBackend(w http.ResponseWriter, r *http.Request, UserProfile: &u, // always non-nil per WhoIsResponse contract } if n.Addresses().Len() > 0 { - res.CapMap = b.PeerCaps(n.Addresses().At(0).Addr()) + src := n.Addresses().At(0).Addr() + switch { + case r.FormValue("svc_name") != "": + svcName := tailcfg.AsServiceName(r.FormValue("svc_name")) + if svcName == "" { + http.Error(w, "invalid svc_name", http.StatusBadRequest) + return + } + res.CapMap = b.PeerCapsForService(src, svcName) + case r.FormValue("dst_ip") != "": + svcAddr, err := netip.ParseAddr(r.FormValue("dst_ip")) + if err != nil { + http.Error(w, "invalid dst_ip", http.StatusBadRequest) + return + } + res.CapMap = b.PeerCapsForIP(src, svcAddr) + default: + res.CapMap = b.PeerCaps(src) + } } j, err := json.MarshalIndent(res, "", "\t") if err != nil { @@ -1072,6 +1096,80 @@ func (h *Handler) serveDERPMap(w http.ResponseWriter, r *http.Request) { e.Encode(h.b.DERPMap()) } +// serveCertDomains returns the list of DNS.CertDomains from the current +// netmap, or an empty list if no netmap has been received yet. +// The returned list is sorted in ascending order. +func (h *Handler) serveCertDomains(w http.ResponseWriter, r *http.Request) { + if !h.PermitRead { + http.Error(w, "cert-domains access denied", http.StatusForbidden) + return + } + var domains []string + if nm := h.b.NetMapNoPeers(); nm != nil { + domains = slices.Clone(nm.DNS.CertDomains) + slices.Sort(domains) + } + w.Header().Set("Content-Type", "application/json") + json.NewEncoder(w).Encode(domains) +} + +// serveDNSConfig returns the [tailcfg.DNSConfig] from the current netmap. +// It returns 503 if no netmap has been received yet. +func (h *Handler) serveDNSConfig(w http.ResponseWriter, r *http.Request) { + if !h.PermitRead { + http.Error(w, "dns-config access denied", http.StatusForbidden) + return + } + nm := h.b.NetMapNoPeers() + if nm == nil { + http.Error(w, "no netmap", http.StatusServiceUnavailable) + return + } + w.Header().Set("Content-Type", "application/json") + e := json.NewEncoder(w) + e.SetIndent("", "\t") + e.Encode(nm.DNS) +} + +// peerByIDBackend is the subset of [ipnlocal.LocalBackend] used by +// [Handler.servePeerByID]. It exists so the handler can be tested with a +// trivial mock without spinning up a full LocalBackend. +type peerByIDBackend interface { + PeerByID(tailcfg.NodeID) (tailcfg.NodeView, bool) +} + +// servePeerByID returns the current full [tailcfg.Node] for the peer with +// the NodeID given in the "id" query parameter, in O(1) time. It returns +// 404 if no such peer is in the current netmap. +// +// It is intended for clients that need the latest state of a single peer +// without fetching the entire netmap. +func (h *Handler) servePeerByID(w http.ResponseWriter, r *http.Request) { + h.servePeerByIDWithBackend(w, r, h.b) +} + +func (h *Handler) servePeerByIDWithBackend(w http.ResponseWriter, r *http.Request, b peerByIDBackend) { + if !h.PermitRead { + http.Error(w, "peer-by-id access denied", http.StatusForbidden) + return + } + idStr := r.FormValue("id") + id, err := strconv.ParseInt(idStr, 10, 64) + if err != nil || id <= 0 { + http.Error(w, "invalid 'id' parameter", http.StatusBadRequest) + return + } + nv, ok := b.PeerByID(tailcfg.NodeID(id)) + if !ok { + http.Error(w, "no peer with that NodeID", http.StatusNotFound) + return + } + w.Header().Set("Content-Type", "application/json") + e := json.NewEncoder(w) + e.SetIndent("", "\t") + e.Encode(nv.AsStruct()) +} + // serveSetExpirySooner sets the expiry date on the current machine, specified // by an `expiry` unix timestamp as POST or query param. func (h *Handler) serveSetExpirySooner(w http.ResponseWriter, r *http.Request) { @@ -1168,16 +1266,34 @@ func (h *Handler) serveDial(w http.ResponseWriter, r *http.Request) { http.Error(w, "missing Dial-Host or Dial-Port header", http.StatusBadRequest) return } + network := cmp.Or(r.Header.Get("Dial-Network"), "tcp") + + addr := net.JoinHostPort(hostStr, portStr) + + // Check whether the resolved address is a Tailscale route. + // If not, tell the client to dial it directly so the connection + // comes from the calling user's UID rather than our root-owned daemon. + ipp, viaTailscale, err := h.b.Dialer().UserDialPlan(r.Context(), network, addr) + if err != nil { + http.Error(w, "resolve failure: "+err.Error(), http.StatusBadGateway) + return + } + if !viaTailscale { + w.Header().Set("Dial-Self", "true") + w.Header().Set("Dial-Addr", ipp.String()) + w.WriteHeader(http.StatusOK) + return + } + hijacker, ok := w.(http.Hijacker) if !ok { http.Error(w, "make request over HTTP/1", http.StatusBadRequest) return } - network := cmp.Or(r.Header.Get("Dial-Network"), "tcp") - - addr := net.JoinHostPort(hostStr, portStr) - outConn, err := h.b.Dialer().UserDial(r.Context(), network, addr) + // Dial via Tailscale using the resolved IP:port to avoid a TOCTOU + // race with DNS re-resolution. + outConn, err := h.b.Dialer().UserDial(r.Context(), network, ipp.String()) if err != nil { http.Error(w, "dial failure: "+err.Error(), http.StatusBadGateway) return @@ -1457,7 +1573,7 @@ func (h *Handler) serveQueryFeature(w http.ResponseWriter, r *http.Request) { http.Error(w, "missing feature", http.StatusInternalServerError) return } - nm := h.b.NetMap() + nm := h.b.NetMapNoPeers() if nm == nil { http.Error(w, "no netmap", http.StatusServiceUnavailable) return @@ -1707,6 +1823,20 @@ func (h *Handler) serveShutdown(w http.ResponseWriter, r *http.Request) { eventbus.Publish[Shutdown](ec).Publish(Shutdown{}) } +func (h *Handler) serveServices(w http.ResponseWriter, r *http.Request) { + if r.Method != httpm.GET { + http.Error(w, "only GET allowed", http.StatusMethodNotAllowed) + return + } + nm := h.b.NetMapNoPeers() + if nm == nil { + http.Error(w, "no netmap", http.StatusServiceUnavailable) + return + } + w.Header().Set("Content-Type", "application/json") + json.NewEncoder(w).Encode(nm.Services()) +} + func (h *Handler) serveGetAppcRouteInfo(w http.ResponseWriter, r *http.Request) { if !buildfeatures.HasAppConnectors { http.Error(w, feature.ErrUnavailable.Error(), http.StatusNotImplemented) diff --git a/ipn/localapi/localapi_test.go b/ipn/localapi/localapi_test.go index 47e334571..84d8e1e0f 100644 --- a/ipn/localapi/localapi_test.go +++ b/ipn/localapi/localapi_test.go @@ -116,9 +116,11 @@ func TestSetPushDeviceToken(t *testing.T) { } type whoIsBackend struct { - whoIs func(proto string, ipp netip.AddrPort) (n tailcfg.NodeView, u tailcfg.UserProfile, ok bool) - whoIsNodeKey func(key.NodePublic) (n tailcfg.NodeView, u tailcfg.UserProfile, ok bool) - peerCaps map[netip.Addr]tailcfg.PeerCapMap + whoIs func(proto string, ipp netip.AddrPort) (n tailcfg.NodeView, u tailcfg.UserProfile, ok bool) + whoIsNodeKey func(key.NodePublic) (n tailcfg.NodeView, u tailcfg.UserProfile, ok bool) + peerCaps map[netip.Addr]tailcfg.PeerCapMap + peerCapsForIP func(src, dst netip.Addr) tailcfg.PeerCapMap + peerCapsForSvcName func(src netip.Addr, svcName tailcfg.ServiceName) tailcfg.PeerCapMap } func (b whoIsBackend) WhoIs(proto string, ipp netip.AddrPort) (n tailcfg.NodeView, u tailcfg.UserProfile, ok bool) { @@ -133,6 +135,20 @@ func (b whoIsBackend) PeerCaps(ip netip.Addr) tailcfg.PeerCapMap { return b.peerCaps[ip] } +func (b whoIsBackend) PeerCapsForIP(src, dst netip.Addr) tailcfg.PeerCapMap { + if b.peerCapsForIP != nil { + return b.peerCapsForIP(src, dst) + } + return nil +} + +func (b whoIsBackend) PeerCapsForService(src netip.Addr, svcName tailcfg.ServiceName) tailcfg.PeerCapMap { + if b.peerCapsForSvcName != nil { + return b.peerCapsForSvcName(src, svcName) + } + return nil +} + // Tests that the WhoIs handler accepts IPs, IP:ports, or nodekeys. // // From https://github.com/tailscale/tailscale/pull/9714 (a PR that is effectively a bug report) @@ -202,6 +218,249 @@ func TestWhoIsArgTypes(t *testing.T) { } } +func TestWhoIsServiceParams(t *testing.T) { + h := handlerForTest(t, &Handler{ + PermitRead: true, + }) + + peerAddr := netip.MustParseAddr("100.101.102.103") + vipA := netip.MustParseAddr("100.100.0.1") + vipB := netip.MustParseAddr("100.100.0.2") + + nodeCapsForAddr := tailcfg.PeerCapMap{"host-cap": {`"host-val"`}} + vipACaps := tailcfg.PeerCapMap{"svc-a-cap": {`"a-val"`}} + vipBCaps := tailcfg.PeerCapMap{"svc-b-cap": {`"b-val"`}} + + match := func() (tailcfg.NodeView, tailcfg.UserProfile, bool) { + return (&tailcfg.Node{ + ID: 123, + Addresses: []netip.Prefix{netip.PrefixFrom(peerAddr, 32)}, + }).View(), tailcfg.UserProfile{ID: 456}, true + } + + backend := whoIsBackend{ + whoIs: func(proto string, ipp netip.AddrPort) (tailcfg.NodeView, tailcfg.UserProfile, bool) { + return match() + }, + peerCaps: map[netip.Addr]tailcfg.PeerCapMap{ + peerAddr: nodeCapsForAddr, + }, + peerCapsForIP: func(src, dst netip.Addr) tailcfg.PeerCapMap { + switch dst { + case vipA: + return vipACaps + case vipB: + return vipBCaps + } + return nil + }, + peerCapsForSvcName: func(src netip.Addr, svcName tailcfg.ServiceName) tailcfg.PeerCapMap { + switch svcName { + case "svc:db": + return vipACaps + case "svc:cache": + return vipBCaps + } + return nil + }, + } + + doWhoIs := func(t *testing.T, query string) apitype.WhoIsResponse { + t.Helper() + rec := httptest.NewRecorder() + h.serveWhoIsWithBackend(rec, httptest.NewRequest("GET", "/v0/whois?"+query, nil), backend) + if rec.Code != 200 { + t.Fatalf("response code %d; body: %s", rec.Code, rec.Body.String()) + } + var res apitype.WhoIsResponse + if err := json.Unmarshal(rec.Body.Bytes(), &res); err != nil { + t.Fatalf("parsing response: %v", err) + } + return res + } + + doWhoIsStatus := func(t *testing.T, query string) int { + t.Helper() + rec := httptest.NewRecorder() + h.serveWhoIsWithBackend(rec, httptest.NewRequest("GET", "/v0/whois?"+query, nil), backend) + return rec.Code + } + + // No service params — uses PeerCaps (host-level). + t.Run("no_service_params_uses_PeerCaps", func(t *testing.T) { + res := doWhoIs(t, "addr="+peerAddr.String()) + if _, ok := res.CapMap["host-cap"]; !ok { + t.Errorf("expected host-cap from PeerCaps; got %v", res.CapMap) + } + if _, ok := res.CapMap["svc-a-cap"]; ok { + t.Error("VIP cap should not appear without service param") + } + }) + + // dst_ip tests — PeerCapsForIP path. + t.Run("dst_ip_uses_PeerCapsForIP", func(t *testing.T) { + res := doWhoIs(t, "addr="+peerAddr.String()+"&dst_ip="+vipA.String()) + if _, ok := res.CapMap["svc-a-cap"]; !ok { + t.Errorf("expected svc-a-cap; got %v", res.CapMap) + } + if _, ok := res.CapMap["host-cap"]; ok { + t.Error("host-cap should not appear when dst_ip is specified") + } + }) + + t.Run("dst_ip_scopes_to_specific_service", func(t *testing.T) { + resA := doWhoIs(t, "addr="+peerAddr.String()+"&dst_ip="+vipA.String()) + resB := doWhoIs(t, "addr="+peerAddr.String()+"&dst_ip="+vipB.String()) + + if _, ok := resA.CapMap["svc-a-cap"]; !ok { + t.Errorf("dst_ip=vipA: expected svc-a-cap; got %v", resA.CapMap) + } + if _, ok := resA.CapMap["svc-b-cap"]; ok { + t.Error("dst_ip=vipA: svc-b-cap should not appear") + } + + if _, ok := resB.CapMap["svc-b-cap"]; !ok { + t.Errorf("dst_ip=vipB: expected svc-b-cap; got %v", resB.CapMap) + } + if _, ok := resB.CapMap["svc-a-cap"]; ok { + t.Error("dst_ip=vipB: svc-a-cap should not appear") + } + }) + + t.Run("dst_ip_unrelated_ip_returns_empty", func(t *testing.T) { + res := doWhoIs(t, "addr="+peerAddr.String()+"&dst_ip=10.0.0.99") + if len(res.CapMap) != 0 { + t.Errorf("expected empty CapMap for unrelated dst_ip; got %v", res.CapMap) + } + }) + + t.Run("dst_ip_invalid_returns_400", func(t *testing.T) { + if code := doWhoIsStatus(t, "addr="+peerAddr.String()+"&dst_ip=not-an-ip"); code != 400 { + t.Errorf("expected 400 for invalid dst_ip; got %d", code) + } + }) + + // svc_name tests — PeerCapsForService path. + t.Run("svc_name_uses_PeerCapsForService", func(t *testing.T) { + res := doWhoIs(t, "addr="+peerAddr.String()+"&svc_name=svc:db") + if _, ok := res.CapMap["svc-a-cap"]; !ok { + t.Errorf("expected svc-a-cap; got %v", res.CapMap) + } + if _, ok := res.CapMap["host-cap"]; ok { + t.Error("host-cap should not appear when svc_name is specified") + } + }) + + t.Run("svc_name_scopes_to_specific_service", func(t *testing.T) { + resA := doWhoIs(t, "addr="+peerAddr.String()+"&svc_name=svc:db") + resB := doWhoIs(t, "addr="+peerAddr.String()+"&svc_name=svc:cache") + + if _, ok := resA.CapMap["svc-a-cap"]; !ok { + t.Errorf("svc_name=svc:db: expected svc-a-cap; got %v", resA.CapMap) + } + if _, ok := resA.CapMap["svc-b-cap"]; ok { + t.Error("svc_name=svc:db: svc-b-cap should not appear") + } + + if _, ok := resB.CapMap["svc-b-cap"]; !ok { + t.Errorf("svc_name=svc:cache: expected svc-b-cap; got %v", resB.CapMap) + } + if _, ok := resB.CapMap["svc-a-cap"]; ok { + t.Error("svc_name=svc:cache: svc-a-cap should not appear") + } + }) + + t.Run("svc_name_unknown_service_returns_empty", func(t *testing.T) { + res := doWhoIs(t, "addr="+peerAddr.String()+"&svc_name=svc:unknown") + if len(res.CapMap) != 0 { + t.Errorf("expected empty CapMap for unknown service; got %v", res.CapMap) + } + }) + + t.Run("svc_name_invalid_returns_400", func(t *testing.T) { + if code := doWhoIsStatus(t, "addr="+peerAddr.String()+"&svc_name=not-a-service-name"); code != 400 { + t.Errorf("expected 400 for invalid svc_name; got %d", code) + } + }) + + // svc_name takes priority over dst_ip when both are specified. + t.Run("svc_name_takes_priority_over_dst_ip", func(t *testing.T) { + res := doWhoIs(t, "addr="+peerAddr.String()+"&svc_name=svc:cache&dst_ip="+vipA.String()) + if _, ok := res.CapMap["svc-b-cap"]; !ok { + t.Errorf("svc_name should take priority; expected svc-b-cap (cache); got %v", res.CapMap) + } + if _, ok := res.CapMap["svc-a-cap"]; ok { + t.Error("dst_ip result should not appear when svc_name is also specified") + } + }) +} + +type fakePeerByIDBackend map[tailcfg.NodeID]*tailcfg.Node + +func (f fakePeerByIDBackend) PeerByID(id tailcfg.NodeID) (tailcfg.NodeView, bool) { + n, ok := f[id] + if !ok { + return tailcfg.NodeView{}, false + } + return n.View(), true +} + +func TestServePeerByID(t *testing.T) { + h := handlerForTest(t, &Handler{PermitRead: true}) + b := fakePeerByIDBackend{ + 42: { + ID: 42, + Name: "alpha", + Addresses: []netip.Prefix{ + netip.MustParsePrefix("100.64.0.42/32"), + }, + }, + } + + tests := []struct { + name string + query string + wantCode int + wantNodeID tailcfg.NodeID + }{ + {"hit", "id=42", 200, 42}, + {"miss", "id=99", 404, 0}, + {"bad_id", "id=garbage", 400, 0}, + {"missing_id", "", 400, 0}, + {"zero_id", "id=0", 400, 0}, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + rec := httptest.NewRecorder() + req := httptest.NewRequest("GET", "/v0/peer-by-id?"+tt.query, nil) + h.servePeerByIDWithBackend(rec, req, b) + if rec.Code != tt.wantCode { + t.Fatalf("status = %d, want %d; body=%q", rec.Code, tt.wantCode, rec.Body.String()) + } + if tt.wantCode != 200 { + return + } + var got tailcfg.Node + if err := json.Unmarshal(rec.Body.Bytes(), &got); err != nil { + t.Fatalf("unmarshal body %q: %v", rec.Body.Bytes(), err) + } + if got.ID != tt.wantNodeID { + t.Errorf("Node.ID = %d, want %d", got.ID, tt.wantNodeID) + } + }) + } + + t.Run("forbidden", func(t *testing.T) { + hh := handlerForTest(t, &Handler{PermitRead: false}) + rec := httptest.NewRecorder() + req := httptest.NewRequest("GET", "/v0/peer-by-id?id=42", nil) + hh.servePeerByIDWithBackend(rec, req, b) + if rec.Code != http.StatusForbidden { + t.Fatalf("status = %d, want %d", rec.Code, http.StatusForbidden) + } + }) +} + func TestShouldDenyServeConfigForGOOSAndUserContext(t *testing.T) { newHandler := func(connIsLocalAdmin bool) *Handler { return handlerForTest(t, &Handler{ @@ -500,3 +759,69 @@ func TestServeWithUnhealthyState(t *testing.T) { }) } } + +func TestServeDialSelf(t *testing.T) { + h := handlerForTest(t, &Handler{ + PermitRead: true, + PermitWrite: true, + b: newTestLocalBackend(t), + }) + + tests := []struct { + name string + host string + port string + wantSelf bool + wantAddr string + wantStatus int + }{ + { + name: "loopback_v4", + host: "127.0.0.1", + port: "8080", + wantSelf: true, + wantAddr: "127.0.0.1:8080", + wantStatus: http.StatusOK, + }, + { + name: "loopback_v6", + host: "::1", + port: "8080", + wantSelf: true, + wantAddr: "[::1]:8080", + wantStatus: http.StatusOK, + }, + { + name: "localhost", + host: "localhost", + port: "3000", + wantSelf: true, + wantStatus: http.StatusOK, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + req := httptest.NewRequest("POST", "http://local-tailscaled.sock/localapi/v0/dial", nil) + req.Header.Set("Connection", "upgrade") + req.Header.Set("Upgrade", "ts-dial") + req.Header.Set("Dial-Host", tt.host) + req.Header.Set("Dial-Port", tt.port) + resp := httptest.NewRecorder() + h.serveDial(resp, req) + + if resp.Code != tt.wantStatus { + t.Fatalf("status = %d, want %d; body: %s", resp.Code, tt.wantStatus, resp.Body.String()) + } + gotSelf := resp.Header().Get("Dial-Self") == "true" + if gotSelf != tt.wantSelf { + t.Errorf("Dial-Self = %v, want %v", gotSelf, tt.wantSelf) + } + if tt.wantAddr != "" { + if got := resp.Header().Get("Dial-Addr"); got != tt.wantAddr { + t.Errorf("Dial-Addr = %q, want %q", got, tt.wantAddr) + } + } + }) + } +} diff --git a/k8s-operator/api.md b/k8s-operator/api.md index 5a60f66e0..9101c95ca 100644 --- a/k8s-operator/api.md +++ b/k8s-operator/api.md @@ -483,6 +483,8 @@ _Appears in:_ | Field | Description | Default | Validation | | --- | --- | --- | --- | | `tolerations` _[Toleration](https://kubernetes.io/docs/reference/generated/kubernetes-api/v1.3/#toleration-v1-core) array_ | If specified, applies tolerations to the pods deployed by the DNSConfig resource. | | | +| `affinity` _[Affinity](https://kubernetes.io/docs/reference/generated/kubernetes-api/v1.3/#affinity-v1-core)_ | If specified, applies affinity rules to the pods deployed by the DNSConfig resource. | | | +| `nodeSelector` _object (keys:string, values:string)_ | If specified, applies node selector rules to the pods deployed by the DNSConfig resource. | | | #### NameserverService diff --git a/k8s-operator/apis/v1alpha1/types_tsdnsconfig.go b/k8s-operator/apis/v1alpha1/types_tsdnsconfig.go index c1a2e7906..529114c2e 100644 --- a/k8s-operator/apis/v1alpha1/types_tsdnsconfig.go +++ b/k8s-operator/apis/v1alpha1/types_tsdnsconfig.go @@ -113,6 +113,12 @@ type NameserverPod struct { // If specified, applies tolerations to the pods deployed by the DNSConfig resource. // +optional Tolerations []corev1.Toleration `json:"tolerations,omitempty"` + // If specified, applies affinity rules to the pods deployed by the DNSConfig resource. + // +optional + Affinity *corev1.Affinity `json:"affinity,omitzero"` + // If specified, applies node selector rules to the pods deployed by the DNSConfig resource. + // +optional + NodeSelector map[string]string `json:"nodeSelector,omitzero"` } type DNSConfigStatus struct { diff --git a/k8s-operator/apis/v1alpha1/zz_generated.deepcopy.go b/k8s-operator/apis/v1alpha1/zz_generated.deepcopy.go index 2528c89f3..b401c6d87 100644 --- a/k8s-operator/apis/v1alpha1/zz_generated.deepcopy.go +++ b/k8s-operator/apis/v1alpha1/zz_generated.deepcopy.go @@ -469,6 +469,18 @@ func (in *NameserverPod) DeepCopyInto(out *NameserverPod) { (*in)[i].DeepCopyInto(&(*out)[i]) } } + if in.Affinity != nil { + in, out := &in.Affinity, &out.Affinity + *out = new(corev1.Affinity) + (*in).DeepCopyInto(*out) + } + if in.NodeSelector != nil { + in, out := &in.NodeSelector, &out.NodeSelector + *out = make(map[string]string, len(*in)) + for key, val := range *in { + (*out)[key] = val + } + } } // DeepCopy is an autogenerated deepcopy function, copying the receiver, creating a new NameserverPod. diff --git a/k8s-operator/utils.go b/k8s-operator/utils.go index 043a9d7b5..d83d98e0c 100644 --- a/k8s-operator/utils.go +++ b/k8s-operator/utils.go @@ -7,6 +7,8 @@ package kube import ( + "crypto/sha256" + "encoding/hex" "fmt" "tailscale.com/tailcfg" @@ -50,3 +52,17 @@ func CapVerFromFileName(name string) (tailcfg.CapabilityVersion, error) { _, err := fmt.Sscanf(name, "cap-%d.hujson", &cap) return cap, err } + +// TruncateLabelValue truncates a Kubernetes label value to fit within the +// 63-character limit. If the value exceeds the limit, it is truncated and a +// short hash suffix is appended to preserve uniqueness. +func TruncateLabelValue(val string) string { + const maxLen = 63 + if len(val) <= maxLen { + return val + } + hash := sha256.Sum256([]byte(val)) + suffix := hex.EncodeToString(hash[:4]) // 8 hex chars + truncated := val[:maxLen-len(suffix)-1] + return truncated + "-" + suffix +} diff --git a/k8s-operator/utils_test.go b/k8s-operator/utils_test.go new file mode 100644 index 000000000..7a30df6b4 --- /dev/null +++ b/k8s-operator/utils_test.go @@ -0,0 +1,78 @@ +// Copyright (c) Tailscale Inc & contributors +// SPDX-License-Identifier: BSD-3-Clause + +//go:build !plan9 + +package kube + +import ( + "strings" + "testing" +) + +func TestTruncateLabelValue(t *testing.T) { + tests := []struct { + name string + input string + want string // empty means expect input unchanged + }{ + { + name: "short-value-unchanged", + input: "my-service", + }, + { + name: "exactly-63-chars-unchanged", + input: strings.Repeat("a", 63), + }, + { + name: "64-chars-gets-truncated", + input: strings.Repeat("a", 64), + }, + { + name: "very-long-value-gets-truncated", + input: "tailscale-nginx-clickhouse-o11y-server-https-with-extra-long-suffix-that-exceeds-limit", + }, + { + name: "253-chars-max-k8s-resource-name", + input: strings.Repeat("x", 253), + }, + { + name: "empty-string-unchanged", + input: "", + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got := TruncateLabelValue(tt.input) + if len(got) > 63 { + t.Errorf("TruncateLabelValue(%q) = %q (len %d), exceeds 63 chars", tt.input, got, len(got)) + } + if len(tt.input) <= 63 && got != tt.input { + t.Errorf("TruncateLabelValue(%q) = %q, want unchanged input", tt.input, got) + } + if len(tt.input) > 63 && got == tt.input { + t.Errorf("TruncateLabelValue(%q) was not truncated", tt.input) + } + }) + } +} + +func TestTruncateLabelValueDeterministic(t *testing.T) { + input := strings.Repeat("a", 100) + first := TruncateLabelValue(input) + for i := 0; i < 10; i++ { + got := TruncateLabelValue(input) + if got != first { + t.Fatalf("non-deterministic: got %q, want %q", got, first) + } + } +} + +func TestTruncateLabelValueUniqueness(t *testing.T) { + // Two inputs sharing a long prefix but differing at the end should produce different outputs. + a := strings.Repeat("a", 100) + "-one" + b := strings.Repeat("a", 100) + "-two" + if TruncateLabelValue(a) == TruncateLabelValue(b) { + t.Errorf("collision: %q and %q produce the same truncated label", a, b) + } +} diff --git a/kube/authkey/authkey.go b/kube/authkey/authkey.go new file mode 100644 index 000000000..f544a0c81 --- /dev/null +++ b/kube/authkey/authkey.go @@ -0,0 +1,122 @@ +// Copyright (c) Tailscale Inc & contributors +// SPDX-License-Identifier: BSD-3-Clause + +//go:build !plan9 + +// Package authkey provides shared logic for handling auth key reissue +// requests between tailnet clients (containerboot, k8s-proxy) and the +// operator. +// +// When a client fails to authenticate (expired key, single-use key already +// used), it signals the operator by setting a marker in its state Secret. +// The operator responds by deleting the old device and issuing a new auth +// key. The client watches for the new key and restarts to apply it. +package authkey + +import ( + "context" + "fmt" + "log" + "time" + + "tailscale.com/ipn" + "tailscale.com/ipn/conffile" + "tailscale.com/kube/kubeapi" + "tailscale.com/kube/kubeclient" + "tailscale.com/kube/kubetypes" +) + +const ( + TailscaleContainerFieldManager = "tailscale-container" +) + +// SetReissueAuthKey sets the reissue_authkey marker in the state Secret to +// signal to the operator that a new auth key is needed. The marker value is +// the auth key that failed to authenticate. +func SetReissueAuthKey(ctx context.Context, kc kubeclient.Client, stateSecretName string, authKey string, fieldManager string) error { + s := &kubeapi.Secret{ + Data: map[string][]byte{ + kubetypes.KeyReissueAuthkey: []byte(authKey), + }, + } + + log.Printf("Requesting a new auth key from operator") + return kc.StrategicMergePatchSecret(ctx, stateSecretName, s, fieldManager) +} + +// ClearReissueAuthKey removes the reissue_authkey marker from the state Secret +// to signal to the operator that we've successfully received the new key. +func ClearReissueAuthKey(ctx context.Context, kc kubeclient.Client, stateSecretName string, fieldManager string) error { + existing, err := kc.GetSecret(ctx, stateSecretName) + if err != nil { + return fmt.Errorf("error getting state secret: %w", err) + } + + s := &kubeapi.Secret{ + Data: map[string][]byte{ + kubetypes.KeyReissueAuthkey: nil, + kubetypes.KeyDeviceID: nil, + kubetypes.KeyDeviceFQDN: nil, + kubetypes.KeyDeviceIPs: nil, + string(ipn.MachineKeyStateKey): nil, + string(ipn.CurrentProfileStateKey): nil, + string(ipn.KnownProfilesStateKey): nil, + }, + } + + if profileKey := string(existing.Data[string(ipn.CurrentProfileStateKey)]); profileKey != "" { + s.Data[profileKey] = nil + } + + return kc.StrategicMergePatchSecret(ctx, stateSecretName, s, fieldManager) +} + +// WaitForAuthKeyReissue polls getAuthKey for a new auth key different from +// oldAuthKey, returning when one is found or maxWait expires. If notify is +// non-nil, it is used to wake the loop on config changes; otherwise it falls +// back to periodic polling. The clearFn callback is called when a new key is +// detected, to clear the reissue marker from the state Secret. +func WaitForAuthKeyReissue(ctx context.Context, oldAuthKey string, maxWait time.Duration, getAuthKey func() string, clearFn func(context.Context) error, notify <-chan struct{}) error { + log.Printf("Waiting for operator to provide new auth key (max wait: %v)", maxWait) + + ctx, cancel := context.WithTimeout(ctx, maxWait) + defer cancel() + + pollInterval := 5 * time.Second + pt := time.NewTicker(pollInterval) + defer pt.Stop() + + start := time.Now() + + for { + select { + case <-ctx.Done(): + return fmt.Errorf("timeout waiting for auth key reissue after %v", maxWait) + case <-pt.C: + case <-notify: + } + + newAuthKey := getAuthKey() + if newAuthKey != "" && newAuthKey != oldAuthKey { + log.Printf("New auth key received from operator after %v", time.Since(start).Round(time.Second)) + if err := clearFn(ctx); err != nil { + log.Printf("Warning: failed to clear reissue request: %v", err) + } + return nil + } + + if notify == nil { + log.Printf("Waiting for new auth key from operator (%v elapsed)", time.Since(start).Round(time.Second)) + } + } +} + +// AuthKeyFromConfig extracts the auth key from a tailscaled config file. +// Returns empty string if the file cannot be read or contains no auth key. +func AuthKeyFromConfig(path string) string { + if cfg, err := conffile.Load(path); err == nil && cfg.Parsed.AuthKey != nil { + return *cfg.Parsed.AuthKey + } + + return "" +} diff --git a/kube/authkey/authkey_test.go b/kube/authkey/authkey_test.go new file mode 100644 index 000000000..268bc46d6 --- /dev/null +++ b/kube/authkey/authkey_test.go @@ -0,0 +1,124 @@ +// Copyright (c) Tailscale Inc & contributors +// SPDX-License-Identifier: BSD-3-Clause + +//go:build !plan9 + +package authkey + +import ( + "context" + "os" + "path/filepath" + "testing" + + "github.com/google/go-cmp/cmp" + "tailscale.com/ipn" + "tailscale.com/kube/kubeapi" + "tailscale.com/kube/kubeclient" + "tailscale.com/kube/kubetypes" +) + +func TestSetReissueAuthKey(t *testing.T) { + var patched map[string][]byte + kc := &kubeclient.FakeClient{ + StrategicMergePatchSecretImpl: func(ctx context.Context, name string, secret *kubeapi.Secret, _ string) error { + patched = secret.Data + return nil + }, + } + + err := SetReissueAuthKey(context.Background(), kc, "test-secret", "old-auth-key", TailscaleContainerFieldManager) + if err != nil { + t.Fatalf("SetReissueAuthKey() error = %v", err) + } + + want := map[string][]byte{ + kubetypes.KeyReissueAuthkey: []byte("old-auth-key"), + } + if diff := cmp.Diff(want, patched); diff != "" { + t.Errorf("SetReissueAuthKey() mismatch (-want +got):\n%s", diff) + } +} + +func TestClearReissueAuthKey(t *testing.T) { + var patched map[string][]byte + kc := &kubeclient.FakeClient{ + GetSecretImpl: func(ctx context.Context, name string) (*kubeapi.Secret, error) { + return &kubeapi.Secret{ + Data: map[string][]byte{ + "_current-profile": []byte("profile-abc1"), + "profile-abc1": []byte("some-profile-data"), + "_machinekey": []byte("machine-key-data"), + }, + }, nil + }, + StrategicMergePatchSecretImpl: func(ctx context.Context, name string, secret *kubeapi.Secret, _ string) error { + patched = secret.Data + return nil + }, + } + + err := ClearReissueAuthKey(context.Background(), kc, "test-secret", TailscaleContainerFieldManager) + if err != nil { + t.Fatalf("ClearReissueAuthKey() error = %v", err) + } + + want := map[string][]byte{ + kubetypes.KeyReissueAuthkey: nil, + kubetypes.KeyDeviceID: nil, + kubetypes.KeyDeviceFQDN: nil, + kubetypes.KeyDeviceIPs: nil, + string(ipn.MachineKeyStateKey): nil, + string(ipn.CurrentProfileStateKey): nil, + string(ipn.KnownProfilesStateKey): nil, + "profile-abc1": nil, + } + if diff := cmp.Diff(want, patched); diff != "" { + t.Errorf("ClearReissueAuthKey() mismatch (-want +got):\n%s", diff) + } +} + +func TestAuthKeyFromConfig(t *testing.T) { + for name, tc := range map[string]struct { + configContent string + want string + }{ + "valid_config_with_authkey": { + configContent: `{"Version":"alpha0","AuthKey":"test-auth-key"}`, + want: "test-auth-key", + }, + "valid_config_without_authkey": { + configContent: `{"Version":"alpha0"}`, + want: "", + }, + "invalid_config": { + configContent: `not valid json`, + want: "", + }, + "empty_config": { + configContent: ``, + want: "", + }, + } { + t.Run(name, func(t *testing.T) { + tmpDir := t.TempDir() + configPath := filepath.Join(tmpDir, "config.json") + + if err := os.WriteFile(configPath, []byte(tc.configContent), 0600); err != nil { + t.Fatalf("failed to write config file: %v", err) + } + + got := AuthKeyFromConfig(configPath) + if got != tc.want { + t.Errorf("AuthKeyFromConfig() = %q, want %q", got, tc.want) + } + }) + } + + t.Run("nonexistent_file", func(t *testing.T) { + got := AuthKeyFromConfig("/nonexistent/path/config.json") + if got != "" { + t.Errorf("AuthKeyFromConfig() = %q, want empty string for nonexistent file", got) + } + }) +} diff --git a/kube/certs/certs.go b/kube/certs/certs.go index 4c8ac88b6..fd7c82a10 100644 --- a/kube/certs/certs.go +++ b/kube/certs/certs.go @@ -171,8 +171,9 @@ func (cm *CertManager) runCertLoop(ctx context.Context, domain string) { } } -// waitForCertDomain ensures the requested domain is in the list of allowed -// domains before issuing the cert for the first time. +// domains before issuing the cert for the first time. It uses the IPN bus +// only as a wake-up trigger (Notify.SelfChange) and queries the current +// cert domains explicitly via [LocalClient.CertDomains]. func (cm *CertManager) waitForCertDomain(ctx context.Context, domain string) error { w, err := cm.lc.WatchIPNBus(ctx, ipn.NotifyInitialNetMap) if err != nil { @@ -185,11 +186,14 @@ func (cm *CertManager) waitForCertDomain(ctx context.Context, domain string) err if err != nil { return err } - if n.NetMap == nil { + if n.SelfChange == nil { continue } - - if slices.Contains(n.NetMap.DNS.CertDomains, domain) { + domains, err := cm.lc.CertDomains(ctx) + if err != nil { + continue + } + if slices.Contains(domains, domain) { return nil } } diff --git a/kube/certs/certs_test.go b/kube/certs/certs_test.go index f3662f6c3..27fe12752 100644 --- a/kube/certs/certs_test.go +++ b/kube/certs/certs_test.go @@ -12,7 +12,6 @@ import ( "tailscale.com/ipn" "tailscale.com/kube/localclient" "tailscale.com/tailcfg" - "tailscale.com/types/netmap" ) // TestEnsureCertLoops tests that the certManager correctly starts and stops @@ -201,17 +200,11 @@ func TestEnsureCertLoops(t *testing.T) { notifyChan := make(chan ipn.Notify) go func() { + // SelfChange wakes the cert manager; cert domains are + // then fetched via FakeLocalClient.CertDomainsResult. for { notifyChan <- ipn.Notify{ - NetMap: &netmap.NetworkMap{ - DNS: tailcfg.DNSConfig{ - CertDomains: []string{ - "my-app.tailnetxyz.ts.net", - "my-other-app.tailnetxyz.ts.net", - "my-apiserver.tailnetxyz.ts.net", - }, - }, - }, + SelfChange: &tailcfg.Node{StableID: "test"}, } } }() @@ -220,6 +213,11 @@ func TestEnsureCertLoops(t *testing.T) { FakeIPNBusWatcher: localclient.FakeIPNBusWatcher{ NotifyChan: notifyChan, }, + CertDomainsResult: []string{ + "my-app.tailnetxyz.ts.net", + "my-other-app.tailnetxyz.ts.net", + "my-apiserver.tailnetxyz.ts.net", + }, }, logf: log.Printf, certLoops: make(map[string]context.CancelFunc), diff --git a/kube/health/healthz.go b/kube/health/healthz.go index 53888922b..e9b459fc1 100644 --- a/kube/health/healthz.go +++ b/kube/health/healthz.go @@ -65,8 +65,8 @@ func (h *Healthz) MonitorHealth(ctx context.Context, lc *local.Client) error { return err } - if n.NetMap != nil { - h.Update(n.NetMap.SelfNode.Addresses().Len() != 0) + if self := n.SelfChange; self != nil { + h.Update(len(self.Addresses) != 0) } } } diff --git a/kube/localclient/fake-client.go b/kube/localclient/fake-client.go index a244ce31a..7ecada113 100644 --- a/kube/localclient/fake-client.go +++ b/kube/localclient/fake-client.go @@ -12,9 +12,10 @@ import ( type FakeLocalClient struct { FakeIPNBusWatcher - SetServeCalled bool - EditPrefsCalls []*ipn.MaskedPrefs - GetPrefsResult *ipn.Prefs + SetServeCalled bool + EditPrefsCalls []*ipn.MaskedPrefs + GetPrefsResult *ipn.Prefs + CertDomainsResult []string } func (m *FakeLocalClient) SetServeConfig(ctx context.Context, cfg *ipn.ServeConfig) error { @@ -45,6 +46,10 @@ func (f *FakeLocalClient) CertPair(ctx context.Context, domain string) ([]byte, return nil, nil, fmt.Errorf("CertPair not implemented") } +func (f *FakeLocalClient) CertDomains(ctx context.Context) ([]string, error) { + return f.CertDomainsResult, nil +} + type FakeIPNBusWatcher struct { NotifyChan chan ipn.Notify } diff --git a/kube/localclient/local-client.go b/kube/localclient/local-client.go index b8d40f406..f759568ba 100644 --- a/kube/localclient/local-client.go +++ b/kube/localclient/local-client.go @@ -19,6 +19,7 @@ type LocalClient interface { WatchIPNBus(ctx context.Context, mask ipn.NotifyWatchOpt) (IPNBusWatcher, error) SetServeConfig(context.Context, *ipn.ServeConfig) error EditPrefs(ctx context.Context, mp *ipn.MaskedPrefs) (*ipn.Prefs, error) + CertDomains(ctx context.Context) ([]string, error) CertIssuer } @@ -57,3 +58,7 @@ func (lc *localClient) WatchIPNBus(ctx context.Context, mask ipn.NotifyWatchOpt) func (lc *localClient) CertPair(ctx context.Context, domain string) ([]byte, []byte, error) { return lc.lc.CertPair(ctx, domain) } + +func (lc *localClient) CertDomains(ctx context.Context) ([]string, error) { + return lc.lc.CertDomains(ctx) +} diff --git a/kube/state/state.go b/kube/state/state.go index ebedb2f72..220eb439f 100644 --- a/kube/state/state.go +++ b/kube/state/state.go @@ -30,19 +30,8 @@ const ( keyDeviceFQDN = ipn.StateKey(kubetypes.KeyDeviceFQDN) ) -// SetInitialKeys sets Pod UID and cap ver and clears tailnet device state -// keys to help stop the operator using stale tailnet device state. +// SetInitialKeys sets Pod UID and cap ver. func SetInitialKeys(store ipn.StateStore, podUID string) error { - // Clear device state keys first so the operator knows if the pod UID - // matches, the other values are definitely not stale. - for _, key := range []ipn.StateKey{keyDeviceID, keyDeviceFQDN, keyDeviceIPs} { - if _, err := store.ReadState(key); err == nil { - if err := store.WriteState(key, nil); err != nil { - return fmt.Errorf("error writing %q to state store: %w", key, err) - } - } - } - if err := store.WriteState(keyPodUID, []byte(podUID)); err != nil { return fmt.Errorf("error writing pod UID to state store: %w", err) } @@ -55,9 +44,9 @@ func SetInitialKeys(store ipn.StateStore, podUID string) error { // KeepKeysUpdated sets state store keys consistent with containerboot to // signal proxy readiness to the operator. It runs until its context is -// cancelled or it hits an error. The passed in next function is expected to be -// from a local.IPNBusWatcher that is at least subscribed to -// ipn.NotifyInitialNetMap. +// cancelled or it hits an error. It watches the IPN bus for SelfChange +// notifications (which fire whenever the self node changes) and reads +// the new self node directly from the notify. func KeepKeysUpdated(ctx context.Context, store ipn.StateStore, lc klc.LocalClient) error { w, err := lc.WatchIPNBus(ctx, ipn.NotifyInitialNetMap) if err != nil { @@ -74,25 +63,26 @@ func KeepKeysUpdated(ctx context.Context, store ipn.StateStore, lc klc.LocalClie } return err } - if n.NetMap == nil { + self := n.SelfChange + if self == nil { continue } - if deviceID := n.NetMap.SelfNode.StableID(); deephash.Update(¤tDeviceID, &deviceID) { + if deviceID := self.StableID; deephash.Update(¤tDeviceID, &deviceID) { if err := store.WriteState(keyDeviceID, []byte(deviceID)); err != nil { return fmt.Errorf("failed to store device ID in state: %w", err) } } - if fqdn := n.NetMap.SelfNode.Name(); deephash.Update(¤tDeviceFQDN, &fqdn) { + if fqdn := self.Name; deephash.Update(¤tDeviceFQDN, &fqdn) { if err := store.WriteState(keyDeviceFQDN, []byte(fqdn)); err != nil { return fmt.Errorf("failed to store device FQDN in state: %w", err) } } - if addrs := n.NetMap.SelfNode.Addresses(); deephash.Update(¤tDeviceIPs, &addrs) { + if addrs := self.Addresses; deephash.Update(¤tDeviceIPs, &addrs) { var deviceIPs []string - for _, addr := range addrs.AsSlice() { + for _, addr := range addrs { deviceIPs = append(deviceIPs, addr.Addr().String()) } deviceIPsValue, err := json.Marshal(deviceIPs) diff --git a/kube/state/state_test.go b/kube/state/state_test.go index 9b2ce69be..5c438377e 100644 --- a/kube/state/state_test.go +++ b/kube/state/state_test.go @@ -18,7 +18,6 @@ import ( klc "tailscale.com/kube/localclient" "tailscale.com/tailcfg" "tailscale.com/types/logger" - "tailscale.com/types/netmap" ) func TestSetInitialStateKeys(t *testing.T) { @@ -58,9 +57,9 @@ func TestSetInitialStateKeys(t *testing.T) { expected: map[ipn.StateKey][]byte{ keyPodUID: podUID, keyCapVer: expectedCapVer, - keyDeviceID: nil, - keyDeviceFQDN: nil, - keyDeviceIPs: nil, + keyDeviceID: []byte("existing-device-id"), + keyDeviceFQDN: []byte("existing-device-fqdn"), + keyDeviceIPs: []byte(`["1.2.3.4"]`), }, }, } { @@ -133,12 +132,10 @@ func TestKeepStateKeysUpdated(t *testing.T) { { name: "authed", notify: ipn.Notify{ - NetMap: &netmap.NetworkMap{ - SelfNode: (&tailcfg.Node{ - StableID: "TESTCTRL00000001", - Name: "test-node.test.ts.net", - Addresses: []netip.Prefix{netip.MustParsePrefix("100.64.0.1/32"), netip.MustParsePrefix("fd7a:115c:a1e0:ab12:4843:cd96:0:1/128")}, - }).View(), + SelfChange: &tailcfg.Node{ + StableID: "TESTCTRL00000001", + Name: "test-node.test.ts.net", + Addresses: []netip.Prefix{netip.MustParsePrefix("100.64.0.1/32"), netip.MustParsePrefix("fd7a:115c:a1e0:ab12:4843:cd96:0:1/128")}, }, }, expected: []string{ @@ -150,12 +147,10 @@ func TestKeepStateKeysUpdated(t *testing.T) { { name: "updated_fields", notify: ipn.Notify{ - NetMap: &netmap.NetworkMap{ - SelfNode: (&tailcfg.Node{ - StableID: "TESTCTRL00000001", - Name: "updated.test.ts.net", - Addresses: []netip.Prefix{netip.MustParsePrefix("100.64.0.250/32")}, - }).View(), + SelfChange: &tailcfg.Node{ + StableID: "TESTCTRL00000001", + Name: "updated.test.ts.net", + Addresses: []netip.Prefix{netip.MustParsePrefix("100.64.0.250/32")}, }, }, expected: []string{ diff --git a/licenses/android.md b/licenses/android.md index 0b8fbe963..07c97948e 100644 --- a/licenses/android.md +++ b/licenses/android.md @@ -22,9 +22,9 @@ Client][]. See also the dependencies in the [Tailscale CLI][]. - [github.com/huin/goupnp](https://pkg.go.dev/github.com/huin/goupnp) ([BSD-2-Clause](https://github.com/huin/goupnp/blob/v1.3.0/LICENSE)) - [github.com/insomniacslk/dhcp](https://pkg.go.dev/github.com/insomniacslk/dhcp) ([BSD-3-Clause](https://github.com/insomniacslk/dhcp/blob/8c70d406f6d2/LICENSE)) - [github.com/jellydator/ttlcache/v3](https://pkg.go.dev/github.com/jellydator/ttlcache/v3) ([MIT](https://github.com/jellydator/ttlcache/blob/v3.1.0/LICENSE)) - - [github.com/klauspost/compress](https://pkg.go.dev/github.com/klauspost/compress) ([Apache-2.0](https://github.com/klauspost/compress/blob/v1.18.2/LICENSE)) - - [github.com/klauspost/compress/internal/snapref](https://pkg.go.dev/github.com/klauspost/compress/internal/snapref) ([BSD-3-Clause](https://github.com/klauspost/compress/blob/v1.18.2/internal/snapref/LICENSE)) - - [github.com/klauspost/compress/zstd/internal/xxhash](https://pkg.go.dev/github.com/klauspost/compress/zstd/internal/xxhash) ([MIT](https://github.com/klauspost/compress/blob/v1.18.2/zstd/internal/xxhash/LICENSE.txt)) + - [github.com/klauspost/compress](https://pkg.go.dev/github.com/klauspost/compress) ([Apache-2.0](https://github.com/klauspost/compress/blob/v1.18.5/LICENSE)) + - [github.com/klauspost/compress/internal/snapref](https://pkg.go.dev/github.com/klauspost/compress/internal/snapref) ([BSD-3-Clause](https://github.com/klauspost/compress/blob/v1.18.5/internal/snapref/LICENSE)) + - [github.com/klauspost/compress/zstd/internal/xxhash](https://pkg.go.dev/github.com/klauspost/compress/zstd/internal/xxhash) ([MIT](https://github.com/klauspost/compress/blob/v1.18.5/zstd/internal/xxhash/LICENSE.txt)) - [github.com/kortschak/wol](https://pkg.go.dev/github.com/kortschak/wol) ([BSD-3-Clause](https://github.com/kortschak/wol/blob/da482cc4850a/LICENSE)) - [github.com/mdlayher/netlink](https://pkg.go.dev/github.com/mdlayher/netlink) ([MIT](https://github.com/mdlayher/netlink/blob/fbb4dce95f42/LICENSE.md)) - [github.com/mdlayher/socket](https://pkg.go.dev/github.com/mdlayher/socket) ([MIT](https://github.com/mdlayher/socket/blob/v0.5.0/LICENSE.md)) @@ -32,22 +32,22 @@ Client][]. See also the dependencies in the [Tailscale CLI][]. - [github.com/pires/go-proxyproto](https://pkg.go.dev/github.com/pires/go-proxyproto) ([Apache-2.0](https://github.com/pires/go-proxyproto/blob/v0.8.1/LICENSE)) - [github.com/tailscale/peercred](https://pkg.go.dev/github.com/tailscale/peercred) ([BSD-3-Clause](https://github.com/tailscale/peercred/blob/35a0c7bd7edc/LICENSE)) - [github.com/tailscale/tailscale-android/libtailscale](https://pkg.go.dev/github.com/tailscale/tailscale-android/libtailscale) ([BSD-3-Clause](https://github.com/tailscale/tailscale-android/blob/HEAD/LICENSE)) - - [github.com/tailscale/wireguard-go](https://pkg.go.dev/github.com/tailscale/wireguard-go) ([MIT](https://github.com/tailscale/wireguard-go/blob/4184faf59e56/LICENSE)) + - [github.com/tailscale/wireguard-go](https://pkg.go.dev/github.com/tailscale/wireguard-go) ([MIT](https://github.com/tailscale/wireguard-go/blob/e3ac4a0afb4e/LICENSE)) - [github.com/tailscale/xnet/webdav](https://pkg.go.dev/github.com/tailscale/xnet/webdav) ([BSD-3-Clause](https://github.com/tailscale/xnet/blob/8497ac4dab2e/LICENSE)) - [github.com/u-root/uio](https://pkg.go.dev/github.com/u-root/uio) ([BSD-3-Clause](https://github.com/u-root/uio/blob/d2acac8f3701/LICENSE)) - [github.com/x448/float16](https://pkg.go.dev/github.com/x448/float16) ([MIT](https://github.com/x448/float16/blob/v0.8.4/LICENSE)) - [go4.org/mem](https://pkg.go.dev/go4.org/mem) ([Apache-2.0](https://github.com/go4org/mem/blob/ae6ca9944745/LICENSE)) - [go4.org/netipx](https://pkg.go.dev/go4.org/netipx) ([BSD-3-Clause](https://github.com/go4org/netipx/blob/fdeea329fbba/LICENSE)) - - [golang.org/x/crypto](https://pkg.go.dev/golang.org/x/crypto) ([BSD-3-Clause](https://cs.opensource.google/go/x/crypto/+/v0.46.0:LICENSE)) + - [golang.org/x/crypto](https://pkg.go.dev/golang.org/x/crypto) ([BSD-3-Clause](https://cs.opensource.google/go/x/crypto/+/v0.50.0:LICENSE)) - [golang.org/x/exp](https://pkg.go.dev/golang.org/x/exp) ([BSD-3-Clause](https://cs.opensource.google/go/x/exp/+/b7579e27:LICENSE)) - [golang.org/x/mobile](https://pkg.go.dev/golang.org/x/mobile) ([BSD-3-Clause](https://cs.opensource.google/go/x/mobile/+/81131f64:LICENSE)) - - [golang.org/x/mod/semver](https://pkg.go.dev/golang.org/x/mod/semver) ([BSD-3-Clause](https://cs.opensource.google/go/x/mod/+/v0.31.0:LICENSE)) - - [golang.org/x/net](https://pkg.go.dev/golang.org/x/net) ([BSD-3-Clause](https://cs.opensource.google/go/x/net/+/v0.48.0:LICENSE)) - - [golang.org/x/sync](https://pkg.go.dev/golang.org/x/sync) ([BSD-3-Clause](https://cs.opensource.google/go/x/sync/+/v0.19.0:LICENSE)) - - [golang.org/x/sys](https://pkg.go.dev/golang.org/x/sys) ([BSD-3-Clause](https://cs.opensource.google/go/x/sys/+/v0.40.0:LICENSE)) - - [golang.org/x/term](https://pkg.go.dev/golang.org/x/term) ([BSD-3-Clause](https://cs.opensource.google/go/x/term/+/v0.38.0:LICENSE)) - - [golang.org/x/text](https://pkg.go.dev/golang.org/x/text) ([BSD-3-Clause](https://cs.opensource.google/go/x/text/+/v0.32.0:LICENSE)) + - [golang.org/x/mod/semver](https://pkg.go.dev/golang.org/x/mod/semver) ([BSD-3-Clause](https://cs.opensource.google/go/x/mod/+/v0.35.0:LICENSE)) + - [golang.org/x/net](https://pkg.go.dev/golang.org/x/net) ([BSD-3-Clause](https://cs.opensource.google/go/x/net/+/v0.53.0:LICENSE)) + - [golang.org/x/sync](https://pkg.go.dev/golang.org/x/sync) ([BSD-3-Clause](https://cs.opensource.google/go/x/sync/+/v0.20.0:LICENSE)) + - [golang.org/x/sys](https://pkg.go.dev/golang.org/x/sys) ([BSD-3-Clause](https://cs.opensource.google/go/x/sys/+/v0.43.0:LICENSE)) + - [golang.org/x/term](https://pkg.go.dev/golang.org/x/term) ([BSD-3-Clause](https://cs.opensource.google/go/x/term/+/v0.42.0:LICENSE)) + - [golang.org/x/text](https://pkg.go.dev/golang.org/x/text) ([BSD-3-Clause](https://cs.opensource.google/go/x/text/+/v0.36.0:LICENSE)) - [golang.org/x/time/rate](https://pkg.go.dev/golang.org/x/time/rate) ([BSD-3-Clause](https://cs.opensource.google/go/x/time/+/v0.12.0:LICENSE)) - - [golang.org/x/tools](https://pkg.go.dev/golang.org/x/tools) ([BSD-3-Clause](https://cs.opensource.google/go/x/tools/+/ca281cf9:LICENSE)) + - [golang.org/x/tools](https://pkg.go.dev/golang.org/x/tools) ([BSD-3-Clause](https://cs.opensource.google/go/x/tools/+/v0.44.0:LICENSE)) - [gvisor.dev/gvisor/pkg](https://pkg.go.dev/gvisor.dev/gvisor/pkg) ([Apache-2.0](https://github.com/google/gvisor/blob/573d5e7127a8/LICENSE)) - [tailscale.com](https://pkg.go.dev/tailscale.com) ([BSD-3-Clause](https://github.com/tailscale/tailscale/blob/HEAD/LICENSE)) diff --git a/licenses/apple.md b/licenses/apple.md index fa80ca300..16d7ad036 100644 --- a/licenses/apple.md +++ b/licenses/apple.md @@ -12,23 +12,23 @@ See also the dependencies in the [Tailscale CLI][]. - [filippo.io/edwards25519](https://pkg.go.dev/filippo.io/edwards25519) ([BSD-3-Clause](https://github.com/FiloSottile/edwards25519/blob/v1.2.0/LICENSE)) - - [github.com/aws/aws-sdk-go-v2](https://pkg.go.dev/github.com/aws/aws-sdk-go-v2) ([Apache-2.0](https://github.com/aws/aws-sdk-go-v2/blob/v1.41.5/LICENSE.txt)) - - [github.com/aws/aws-sdk-go-v2/config](https://pkg.go.dev/github.com/aws/aws-sdk-go-v2/config) ([Apache-2.0](https://github.com/aws/aws-sdk-go-v2/blob/config/v1.32.5/config/LICENSE.txt)) - - [github.com/aws/aws-sdk-go-v2/credentials](https://pkg.go.dev/github.com/aws/aws-sdk-go-v2/credentials) ([Apache-2.0](https://github.com/aws/aws-sdk-go-v2/blob/credentials/v1.19.5/credentials/LICENSE.txt)) - - [github.com/aws/aws-sdk-go-v2/feature/ec2/imds](https://pkg.go.dev/github.com/aws/aws-sdk-go-v2/feature/ec2/imds) ([Apache-2.0](https://github.com/aws/aws-sdk-go-v2/blob/feature/ec2/imds/v1.18.16/feature/ec2/imds/LICENSE.txt)) - - [github.com/aws/aws-sdk-go-v2/internal/configsources](https://pkg.go.dev/github.com/aws/aws-sdk-go-v2/internal/configsources) ([Apache-2.0](https://github.com/aws/aws-sdk-go-v2/blob/internal/configsources/v1.4.16/internal/configsources/LICENSE.txt)) - - [github.com/aws/aws-sdk-go-v2/internal/endpoints/v2](https://pkg.go.dev/github.com/aws/aws-sdk-go-v2/internal/endpoints/v2) ([Apache-2.0](https://github.com/aws/aws-sdk-go-v2/blob/internal/endpoints/v2.7.16/internal/endpoints/v2/LICENSE.txt)) - - [github.com/aws/aws-sdk-go-v2/internal/ini](https://pkg.go.dev/github.com/aws/aws-sdk-go-v2/internal/ini) ([Apache-2.0](https://github.com/aws/aws-sdk-go-v2/blob/internal/ini/v1.8.4/internal/ini/LICENSE.txt)) - - [github.com/aws/aws-sdk-go-v2/internal/sync/singleflight](https://pkg.go.dev/github.com/aws/aws-sdk-go-v2/internal/sync/singleflight) ([BSD-3-Clause](https://github.com/aws/aws-sdk-go-v2/blob/v1.41.5/internal/sync/singleflight/LICENSE)) - - [github.com/aws/aws-sdk-go-v2/service/internal/accept-encoding](https://pkg.go.dev/github.com/aws/aws-sdk-go-v2/service/internal/accept-encoding) ([Apache-2.0](https://github.com/aws/aws-sdk-go-v2/blob/service/internal/accept-encoding/v1.13.4/service/internal/accept-encoding/LICENSE.txt)) - - [github.com/aws/aws-sdk-go-v2/service/internal/presigned-url](https://pkg.go.dev/github.com/aws/aws-sdk-go-v2/service/internal/presigned-url) ([Apache-2.0](https://github.com/aws/aws-sdk-go-v2/blob/service/internal/presigned-url/v1.13.16/service/internal/presigned-url/LICENSE.txt)) - - [github.com/aws/aws-sdk-go-v2/service/signin](https://pkg.go.dev/github.com/aws/aws-sdk-go-v2/service/signin) ([Apache-2.0](https://github.com/aws/aws-sdk-go-v2/blob/service/signin/v1.0.4/service/signin/LICENSE.txt)) + - [github.com/aws/aws-sdk-go-v2](https://pkg.go.dev/github.com/aws/aws-sdk-go-v2) ([Apache-2.0](https://github.com/aws/aws-sdk-go-v2/blob/v1.41.7/LICENSE.txt)) + - [github.com/aws/aws-sdk-go-v2/config](https://pkg.go.dev/github.com/aws/aws-sdk-go-v2/config) ([Apache-2.0](https://github.com/aws/aws-sdk-go-v2/blob/config/v1.32.17/config/LICENSE.txt)) + - [github.com/aws/aws-sdk-go-v2/credentials](https://pkg.go.dev/github.com/aws/aws-sdk-go-v2/credentials) ([Apache-2.0](https://github.com/aws/aws-sdk-go-v2/blob/credentials/v1.19.16/credentials/LICENSE.txt)) + - [github.com/aws/aws-sdk-go-v2/feature/ec2/imds](https://pkg.go.dev/github.com/aws/aws-sdk-go-v2/feature/ec2/imds) ([Apache-2.0](https://github.com/aws/aws-sdk-go-v2/blob/feature/ec2/imds/v1.18.23/feature/ec2/imds/LICENSE.txt)) + - [github.com/aws/aws-sdk-go-v2/internal/configsources](https://pkg.go.dev/github.com/aws/aws-sdk-go-v2/internal/configsources) ([Apache-2.0](https://github.com/aws/aws-sdk-go-v2/blob/internal/configsources/v1.4.23/internal/configsources/LICENSE.txt)) + - [github.com/aws/aws-sdk-go-v2/internal/endpoints/v2](https://pkg.go.dev/github.com/aws/aws-sdk-go-v2/internal/endpoints/v2) ([Apache-2.0](https://github.com/aws/aws-sdk-go-v2/blob/internal/endpoints/v2.7.23/internal/endpoints/v2/LICENSE.txt)) + - [github.com/aws/aws-sdk-go-v2/internal/sync/singleflight](https://pkg.go.dev/github.com/aws/aws-sdk-go-v2/internal/sync/singleflight) ([BSD-3-Clause](https://github.com/aws/aws-sdk-go-v2/blob/v1.41.7/internal/sync/singleflight/LICENSE)) + - [github.com/aws/aws-sdk-go-v2/internal/v4a](https://pkg.go.dev/github.com/aws/aws-sdk-go-v2/internal/v4a) ([Apache-2.0](https://github.com/aws/aws-sdk-go-v2/blob/internal/v4a/v1.4.24/internal/v4a/LICENSE.txt)) + - [github.com/aws/aws-sdk-go-v2/service/internal/accept-encoding](https://pkg.go.dev/github.com/aws/aws-sdk-go-v2/service/internal/accept-encoding) ([Apache-2.0](https://github.com/aws/aws-sdk-go-v2/blob/service/internal/accept-encoding/v1.13.9/service/internal/accept-encoding/LICENSE.txt)) + - [github.com/aws/aws-sdk-go-v2/service/internal/presigned-url](https://pkg.go.dev/github.com/aws/aws-sdk-go-v2/service/internal/presigned-url) ([Apache-2.0](https://github.com/aws/aws-sdk-go-v2/blob/service/internal/presigned-url/v1.13.23/service/internal/presigned-url/LICENSE.txt)) + - [github.com/aws/aws-sdk-go-v2/service/signin](https://pkg.go.dev/github.com/aws/aws-sdk-go-v2/service/signin) ([Apache-2.0](https://github.com/aws/aws-sdk-go-v2/blob/service/signin/v1.0.11/service/signin/LICENSE.txt)) - [github.com/aws/aws-sdk-go-v2/service/ssm](https://pkg.go.dev/github.com/aws/aws-sdk-go-v2/service/ssm) ([Apache-2.0](https://github.com/aws/aws-sdk-go-v2/blob/service/ssm/v1.45.0/service/ssm/LICENSE.txt)) - - [github.com/aws/aws-sdk-go-v2/service/sso](https://pkg.go.dev/github.com/aws/aws-sdk-go-v2/service/sso) ([Apache-2.0](https://github.com/aws/aws-sdk-go-v2/blob/service/sso/v1.30.7/service/sso/LICENSE.txt)) - - [github.com/aws/aws-sdk-go-v2/service/ssooidc](https://pkg.go.dev/github.com/aws/aws-sdk-go-v2/service/ssooidc) ([Apache-2.0](https://github.com/aws/aws-sdk-go-v2/blob/service/ssooidc/v1.35.12/service/ssooidc/LICENSE.txt)) - - [github.com/aws/aws-sdk-go-v2/service/sts](https://pkg.go.dev/github.com/aws/aws-sdk-go-v2/service/sts) ([Apache-2.0](https://github.com/aws/aws-sdk-go-v2/blob/service/sts/v1.41.5/service/sts/LICENSE.txt)) - - [github.com/aws/smithy-go](https://pkg.go.dev/github.com/aws/smithy-go) ([Apache-2.0](https://github.com/aws/smithy-go/blob/v1.24.2/LICENSE)) - - [github.com/aws/smithy-go/internal/sync/singleflight](https://pkg.go.dev/github.com/aws/smithy-go/internal/sync/singleflight) ([BSD-3-Clause](https://github.com/aws/smithy-go/blob/v1.24.2/internal/sync/singleflight/LICENSE)) + - [github.com/aws/aws-sdk-go-v2/service/sso](https://pkg.go.dev/github.com/aws/aws-sdk-go-v2/service/sso) ([Apache-2.0](https://github.com/aws/aws-sdk-go-v2/blob/service/sso/v1.30.17/service/sso/LICENSE.txt)) + - [github.com/aws/aws-sdk-go-v2/service/ssooidc](https://pkg.go.dev/github.com/aws/aws-sdk-go-v2/service/ssooidc) ([Apache-2.0](https://github.com/aws/aws-sdk-go-v2/blob/service/ssooidc/v1.35.21/service/ssooidc/LICENSE.txt)) + - [github.com/aws/aws-sdk-go-v2/service/sts](https://pkg.go.dev/github.com/aws/aws-sdk-go-v2/service/sts) ([Apache-2.0](https://github.com/aws/aws-sdk-go-v2/blob/service/sts/v1.42.1/service/sts/LICENSE.txt)) + - [github.com/aws/smithy-go](https://pkg.go.dev/github.com/aws/smithy-go) ([Apache-2.0](https://github.com/aws/smithy-go/blob/v1.25.1/LICENSE)) + - [github.com/aws/smithy-go/internal/sync/singleflight](https://pkg.go.dev/github.com/aws/smithy-go/internal/sync/singleflight) ([BSD-3-Clause](https://github.com/aws/smithy-go/blob/v1.25.1/internal/sync/singleflight/LICENSE)) - [github.com/coreos/go-iptables/iptables](https://pkg.go.dev/github.com/coreos/go-iptables/iptables) ([Apache-2.0](https://github.com/coreos/go-iptables/blob/65c67c9f46e6/LICENSE)) - [github.com/creachadair/msync/trigger](https://pkg.go.dev/github.com/creachadair/msync/trigger) ([BSD-3-Clause](https://github.com/creachadair/msync/blob/v0.8.1/LICENSE)) - [github.com/digitalocean/go-smbios/smbios](https://pkg.go.dev/github.com/digitalocean/go-smbios/smbios) ([Apache-2.0](https://github.com/digitalocean/go-smbios/blob/390a4f403a8e/LICENSE.md)) @@ -36,11 +36,10 @@ See also the dependencies in the [Tailscale CLI][]. - [github.com/fxamacker/cbor/v2](https://pkg.go.dev/github.com/fxamacker/cbor/v2) ([MIT](https://github.com/fxamacker/cbor/blob/v2.9.0/LICENSE)) - [github.com/gaissmai/bart](https://pkg.go.dev/github.com/gaissmai/bart) ([MIT](https://github.com/gaissmai/bart/blob/v0.26.1/LICENSE)) - [github.com/go-json-experiment/json](https://pkg.go.dev/github.com/go-json-experiment/json) ([BSD-3-Clause](https://github.com/go-json-experiment/json/blob/4849db3c2f7e/LICENSE)) - - [github.com/godbus/dbus/v5](https://pkg.go.dev/github.com/godbus/dbus/v5) ([BSD-2-Clause](https://github.com/godbus/dbus/blob/76236955d466/LICENSE)) + - [github.com/godbus/dbus/v5](https://pkg.go.dev/github.com/godbus/dbus/v5) ([BSD-2-Clause](https://github.com/godbus/dbus/blob/v5.2.2/LICENSE)) - [github.com/golang/groupcache/lru](https://pkg.go.dev/github.com/golang/groupcache/lru) ([Apache-2.0](https://github.com/golang/groupcache/blob/2c02b8208cf8/LICENSE)) - [github.com/google/btree](https://pkg.go.dev/github.com/google/btree) ([Apache-2.0](https://github.com/google/btree/blob/v1.1.3/LICENSE)) - [github.com/google/nftables](https://pkg.go.dev/github.com/google/nftables) ([Apache-2.0](https://github.com/google/nftables/blob/5e242ec57806/LICENSE)) - - [github.com/google/uuid](https://pkg.go.dev/github.com/google/uuid) ([BSD-3-Clause](https://github.com/google/uuid/blob/v1.6.0/LICENSE)) - [github.com/hdevalence/ed25519consensus](https://pkg.go.dev/github.com/hdevalence/ed25519consensus) ([BSD-3-Clause](https://github.com/hdevalence/ed25519consensus/blob/v0.2.0/LICENSE)) - [github.com/huin/goupnp](https://pkg.go.dev/github.com/huin/goupnp) ([BSD-2-Clause](https://github.com/huin/goupnp/blob/v1.3.0/LICENSE)) - [github.com/illarion/gonotify/v3](https://pkg.go.dev/github.com/illarion/gonotify/v3) ([MIT](https://github.com/illarion/gonotify/blob/v3.0.2/LICENSE)) @@ -48,9 +47,9 @@ See also the dependencies in the [Tailscale CLI][]. - [github.com/jellydator/ttlcache/v3](https://pkg.go.dev/github.com/jellydator/ttlcache/v3) ([MIT](https://github.com/jellydator/ttlcache/blob/v3.1.0/LICENSE)) - [github.com/jmespath/go-jmespath](https://pkg.go.dev/github.com/jmespath/go-jmespath) ([Apache-2.0](https://github.com/jmespath/go-jmespath/blob/v0.4.0/LICENSE)) - [github.com/jsimonetti/rtnetlink](https://pkg.go.dev/github.com/jsimonetti/rtnetlink) ([MIT](https://github.com/jsimonetti/rtnetlink/blob/v1.4.1/LICENSE.md)) - - [github.com/klauspost/compress](https://pkg.go.dev/github.com/klauspost/compress) ([Apache-2.0](https://github.com/klauspost/compress/blob/v1.18.4/LICENSE)) - - [github.com/klauspost/compress/internal/snapref](https://pkg.go.dev/github.com/klauspost/compress/internal/snapref) ([BSD-3-Clause](https://github.com/klauspost/compress/blob/v1.18.4/internal/snapref/LICENSE)) - - [github.com/klauspost/compress/zstd/internal/xxhash](https://pkg.go.dev/github.com/klauspost/compress/zstd/internal/xxhash) ([MIT](https://github.com/klauspost/compress/blob/v1.18.4/zstd/internal/xxhash/LICENSE.txt)) + - [github.com/klauspost/compress](https://pkg.go.dev/github.com/klauspost/compress) ([Apache-2.0](https://github.com/klauspost/compress/blob/v1.18.5/LICENSE)) + - [github.com/klauspost/compress/internal/snapref](https://pkg.go.dev/github.com/klauspost/compress/internal/snapref) ([BSD-3-Clause](https://github.com/klauspost/compress/blob/v1.18.5/internal/snapref/LICENSE)) + - [github.com/klauspost/compress/zstd/internal/xxhash](https://pkg.go.dev/github.com/klauspost/compress/zstd/internal/xxhash) ([MIT](https://github.com/klauspost/compress/blob/v1.18.5/zstd/internal/xxhash/LICENSE.txt)) - [github.com/kortschak/wol](https://pkg.go.dev/github.com/kortschak/wol) ([BSD-3-Clause](https://github.com/kortschak/wol/blob/da482cc4850a/LICENSE)) - [github.com/mdlayher/genetlink](https://pkg.go.dev/github.com/mdlayher/genetlink) ([MIT](https://github.com/mdlayher/genetlink/blob/v1.3.2/LICENSE.md)) - [github.com/mdlayher/netlink](https://pkg.go.dev/github.com/mdlayher/netlink) ([MIT](https://github.com/mdlayher/netlink/blob/fbb4dce95f42/LICENSE.md)) @@ -59,25 +58,24 @@ See also the dependencies in the [Tailscale CLI][]. - [github.com/mitchellh/go-ps](https://pkg.go.dev/github.com/mitchellh/go-ps) ([MIT](https://github.com/mitchellh/go-ps/blob/v1.0.0/LICENSE.md)) - [github.com/pierrec/lz4/v4](https://pkg.go.dev/github.com/pierrec/lz4/v4) ([BSD-3-Clause](https://github.com/pierrec/lz4/blob/v4.1.25/LICENSE)) - [github.com/pires/go-proxyproto](https://pkg.go.dev/github.com/pires/go-proxyproto) ([Apache-2.0](https://github.com/pires/go-proxyproto/blob/v0.8.1/LICENSE)) - - [github.com/prometheus-community/pro-bing](https://pkg.go.dev/github.com/prometheus-community/pro-bing) ([MIT](https://github.com/prometheus-community/pro-bing/blob/v0.4.0/LICENSE)) - [github.com/safchain/ethtool](https://pkg.go.dev/github.com/safchain/ethtool) ([Apache-2.0](https://github.com/safchain/ethtool/blob/v0.3.0/LICENSE)) - [github.com/tailscale/netlink](https://pkg.go.dev/github.com/tailscale/netlink) ([Apache-2.0](https://github.com/tailscale/netlink/blob/4d49adab4de7/LICENSE)) - [github.com/tailscale/peercred](https://pkg.go.dev/github.com/tailscale/peercred) ([BSD-3-Clause](https://github.com/tailscale/peercred/blob/35a0c7bd7edc/LICENSE)) - - [github.com/tailscale/wireguard-go](https://pkg.go.dev/github.com/tailscale/wireguard-go) ([MIT](https://github.com/tailscale/wireguard-go/blob/4184faf59e56/LICENSE)) + - [github.com/tailscale/wireguard-go](https://pkg.go.dev/github.com/tailscale/wireguard-go) ([MIT](https://github.com/tailscale/wireguard-go/blob/e3ac4a0afb4e/LICENSE)) - [github.com/tailscale/xnet/webdav](https://pkg.go.dev/github.com/tailscale/xnet/webdav) ([BSD-3-Clause](https://github.com/tailscale/xnet/blob/8497ac4dab2e/LICENSE)) - [github.com/u-root/uio](https://pkg.go.dev/github.com/u-root/uio) ([BSD-3-Clause](https://github.com/u-root/uio/blob/d2acac8f3701/LICENSE)) - [github.com/vishvananda/netns](https://pkg.go.dev/github.com/vishvananda/netns) ([Apache-2.0](https://github.com/vishvananda/netns/blob/v0.0.5/LICENSE)) - [github.com/x448/float16](https://pkg.go.dev/github.com/x448/float16) ([MIT](https://github.com/x448/float16/blob/v0.8.4/LICENSE)) - [go4.org/mem](https://pkg.go.dev/go4.org/mem) ([Apache-2.0](https://github.com/go4org/mem/blob/ae6ca9944745/LICENSE)) - [go4.org/netipx](https://pkg.go.dev/go4.org/netipx) ([BSD-3-Clause](https://github.com/go4org/netipx/blob/fdeea329fbba/LICENSE)) - - [golang.org/x/crypto](https://pkg.go.dev/golang.org/x/crypto) ([BSD-3-Clause](https://cs.opensource.google/go/x/crypto/+/v0.49.0:LICENSE)) - - [golang.org/x/exp](https://pkg.go.dev/golang.org/x/exp) ([BSD-3-Clause](https://cs.opensource.google/go/x/exp/+/a4bb9ffd:LICENSE)) - - [golang.org/x/net](https://pkg.go.dev/golang.org/x/net) ([BSD-3-Clause](https://cs.opensource.google/go/x/net/+/v0.52.0:LICENSE)) + - [golang.org/x/crypto](https://pkg.go.dev/golang.org/x/crypto) ([BSD-3-Clause](https://cs.opensource.google/go/x/crypto/+/v0.50.0:LICENSE)) + - [golang.org/x/exp](https://pkg.go.dev/golang.org/x/exp) ([BSD-3-Clause](https://cs.opensource.google/go/x/exp/+/3dfff04d:LICENSE)) + - [golang.org/x/net](https://pkg.go.dev/golang.org/x/net) ([BSD-3-Clause](https://cs.opensource.google/go/x/net/+/v0.53.0:LICENSE)) - [golang.org/x/sync](https://pkg.go.dev/golang.org/x/sync) ([BSD-3-Clause](https://cs.opensource.google/go/x/sync/+/v0.20.0:LICENSE)) - - [golang.org/x/sys](https://pkg.go.dev/golang.org/x/sys) ([BSD-3-Clause](https://cs.opensource.google/go/x/sys/+/v0.42.0:LICENSE)) - - [golang.org/x/term](https://pkg.go.dev/golang.org/x/term) ([BSD-3-Clause](https://cs.opensource.google/go/x/term/+/v0.41.0:LICENSE)) - - [golang.org/x/text](https://pkg.go.dev/golang.org/x/text) ([BSD-3-Clause](https://cs.opensource.google/go/x/text/+/v0.35.0:LICENSE)) - - [golang.org/x/time/rate](https://pkg.go.dev/golang.org/x/time/rate) ([BSD-3-Clause](https://cs.opensource.google/go/x/time/+/v0.12.0:LICENSE)) + - [golang.org/x/sys](https://pkg.go.dev/golang.org/x/sys) ([BSD-3-Clause](https://cs.opensource.google/go/x/sys/+/v0.43.0:LICENSE)) + - [golang.org/x/term](https://pkg.go.dev/golang.org/x/term) ([BSD-3-Clause](https://cs.opensource.google/go/x/term/+/v0.42.0:LICENSE)) + - [golang.org/x/text](https://pkg.go.dev/golang.org/x/text) ([BSD-3-Clause](https://cs.opensource.google/go/x/text/+/v0.36.0:LICENSE)) + - [golang.org/x/time/rate](https://pkg.go.dev/golang.org/x/time/rate) ([BSD-3-Clause](https://cs.opensource.google/go/x/time/+/v0.15.0:LICENSE)) - [gvisor.dev/gvisor/pkg](https://pkg.go.dev/gvisor.dev/gvisor/pkg) ([Apache-2.0](https://github.com/google/gvisor/blob/573d5e7127a8/LICENSE)) - [tailscale.com](https://pkg.go.dev/tailscale.com) ([BSD-3-Clause](https://github.com/tailscale/tailscale/blob/HEAD/LICENSE)) diff --git a/licenses/tailscale.md b/licenses/tailscale.md index d5e46e90f..489da564e 100644 --- a/licenses/tailscale.md +++ b/licenses/tailscale.md @@ -58,9 +58,9 @@ Some packages may only be included on certain architectures or operating systems - [github.com/jellydator/ttlcache/v3](https://pkg.go.dev/github.com/jellydator/ttlcache/v3) ([MIT](https://github.com/jellydator/ttlcache/blob/v3.1.0/LICENSE)) - [github.com/jmespath/go-jmespath](https://pkg.go.dev/github.com/jmespath/go-jmespath) ([Apache-2.0](https://github.com/jmespath/go-jmespath/blob/v0.4.0/LICENSE)) - [github.com/kballard/go-shellquote](https://pkg.go.dev/github.com/kballard/go-shellquote) ([MIT](https://github.com/kballard/go-shellquote/blob/95032a82bc51/LICENSE)) - - [github.com/klauspost/compress](https://pkg.go.dev/github.com/klauspost/compress) ([Apache-2.0](https://github.com/klauspost/compress/blob/v1.18.2/LICENSE)) - - [github.com/klauspost/compress/internal/snapref](https://pkg.go.dev/github.com/klauspost/compress/internal/snapref) ([BSD-3-Clause](https://github.com/klauspost/compress/blob/v1.18.2/internal/snapref/LICENSE)) - - [github.com/klauspost/compress/zstd/internal/xxhash](https://pkg.go.dev/github.com/klauspost/compress/zstd/internal/xxhash) ([MIT](https://github.com/klauspost/compress/blob/v1.18.2/zstd/internal/xxhash/LICENSE.txt)) + - [github.com/klauspost/compress](https://pkg.go.dev/github.com/klauspost/compress) ([Apache-2.0](https://github.com/klauspost/compress/blob/v1.18.5/LICENSE)) + - [github.com/klauspost/compress/internal/snapref](https://pkg.go.dev/github.com/klauspost/compress/internal/snapref) ([BSD-3-Clause](https://github.com/klauspost/compress/blob/v1.18.5/internal/snapref/LICENSE)) + - [github.com/klauspost/compress/zstd/internal/xxhash](https://pkg.go.dev/github.com/klauspost/compress/zstd/internal/xxhash) ([MIT](https://github.com/klauspost/compress/blob/v1.18.5/zstd/internal/xxhash/LICENSE.txt)) - [github.com/kortschak/wol](https://pkg.go.dev/github.com/kortschak/wol) ([BSD-3-Clause](https://github.com/kortschak/wol/blob/da482cc4850a/LICENSE)) - [github.com/kr/fs](https://pkg.go.dev/github.com/kr/fs) ([BSD-3-Clause](https://github.com/kr/fs/blob/v0.1.0/LICENSE)) - [github.com/mattn/go-colorable](https://pkg.go.dev/github.com/mattn/go-colorable) ([MIT](https://github.com/mattn/go-colorable/blob/v0.1.13/LICENSE)) @@ -72,13 +72,13 @@ Some packages may only be included on certain architectures or operating systems - [github.com/pierrec/lz4/v4](https://pkg.go.dev/github.com/pierrec/lz4/v4) ([BSD-3-Clause](https://github.com/pierrec/lz4/blob/v4.1.25/LICENSE)) - [github.com/pires/go-proxyproto](https://pkg.go.dev/github.com/pires/go-proxyproto) ([Apache-2.0](https://github.com/pires/go-proxyproto/blob/v0.8.1/LICENSE)) - [github.com/pkg/sftp](https://pkg.go.dev/github.com/pkg/sftp) ([BSD-2-Clause](https://github.com/pkg/sftp/blob/v1.13.6/LICENSE)) - - [github.com/prometheus-community/pro-bing](https://pkg.go.dev/github.com/prometheus-community/pro-bing) ([MIT](https://github.com/prometheus-community/pro-bing/blob/v0.4.0/LICENSE)) - [github.com/skip2/go-qrcode](https://pkg.go.dev/github.com/skip2/go-qrcode) ([MIT](https://github.com/skip2/go-qrcode/blob/da1b6568686e/LICENSE)) - - [github.com/tailscale/certstore](https://pkg.go.dev/github.com/tailscale/certstore) ([MIT](https://github.com/tailscale/certstore/blob/d3fa0460f47e/LICENSE.md)) + - [github.com/tailscale/certstore](https://pkg.go.dev/github.com/tailscale/certstore) ([MIT](https://github.com/tailscale/certstore/blob/3638fb84b77d/LICENSE.md)) + - [github.com/tailscale/gliderssh](https://pkg.go.dev/github.com/tailscale/gliderssh) ([BSD-3-Clause](https://github.com/tailscale/gliderssh/blob/c1389c70ff89/LICENSE)) - [github.com/tailscale/go-winio](https://pkg.go.dev/github.com/tailscale/go-winio) ([MIT](https://github.com/tailscale/go-winio/blob/c4f33415bf55/LICENSE)) - [github.com/tailscale/web-client-prebuilt](https://pkg.go.dev/github.com/tailscale/web-client-prebuilt) ([BSD-3-Clause](https://github.com/tailscale/web-client-prebuilt/blob/d4cd19a26976/LICENSE)) - [github.com/tailscale/wf](https://pkg.go.dev/github.com/tailscale/wf) ([BSD-3-Clause](https://github.com/tailscale/wf/blob/6fbb0a674ee6/LICENSE)) - - [github.com/tailscale/wireguard-go](https://pkg.go.dev/github.com/tailscale/wireguard-go) ([MIT](https://github.com/tailscale/wireguard-go/blob/4184faf59e56/LICENSE)) + - [github.com/tailscale/wireguard-go](https://pkg.go.dev/github.com/tailscale/wireguard-go) ([MIT](https://github.com/tailscale/wireguard-go/blob/e3ac4a0afb4e/LICENSE)) - [github.com/tailscale/xnet/webdav](https://pkg.go.dev/github.com/tailscale/xnet/webdav) ([BSD-3-Clause](https://github.com/tailscale/xnet/blob/8497ac4dab2e/LICENSE)) - [github.com/toqueteos/webbrowser](https://pkg.go.dev/github.com/toqueteos/webbrowser) ([MIT](https://github.com/toqueteos/webbrowser/blob/v1.2.0/LICENSE.md)) - [github.com/u-root/u-root/pkg/termios](https://pkg.go.dev/github.com/u-root/u-root/pkg/termios) ([BSD-3-Clause](https://github.com/u-root/u-root/blob/v0.14.0/LICENSE)) @@ -87,15 +87,15 @@ Some packages may only be included on certain architectures or operating systems - [go.yaml.in/yaml/v2](https://pkg.go.dev/go.yaml.in/yaml/v2) ([Apache-2.0](https://github.com/yaml/go-yaml/blob/v2.4.2/LICENSE)) - [go4.org/mem](https://pkg.go.dev/go4.org/mem) ([Apache-2.0](https://github.com/go4org/mem/blob/ae6ca9944745/LICENSE)) - [go4.org/netipx](https://pkg.go.dev/go4.org/netipx) ([BSD-3-Clause](https://github.com/go4org/netipx/blob/fdeea329fbba/LICENSE)) - - [golang.org/x/crypto](https://pkg.go.dev/golang.org/x/crypto) ([BSD-3-Clause](https://cs.opensource.google/go/x/crypto/+/v0.46.0:LICENSE)) + - [golang.org/x/crypto](https://pkg.go.dev/golang.org/x/crypto) ([BSD-3-Clause](https://cs.opensource.google/go/x/crypto/+/v0.50.0:LICENSE)) - [golang.org/x/exp](https://pkg.go.dev/golang.org/x/exp) ([BSD-3-Clause](https://cs.opensource.google/go/x/exp/+/b7579e27:LICENSE)) - [golang.org/x/image](https://pkg.go.dev/golang.org/x/image) ([BSD-3-Clause](https://cs.opensource.google/go/x/image/+/v0.27.0:LICENSE)) - - [golang.org/x/net](https://pkg.go.dev/golang.org/x/net) ([BSD-3-Clause](https://cs.opensource.google/go/x/net/+/v0.48.0:LICENSE)) - - [golang.org/x/oauth2](https://pkg.go.dev/golang.org/x/oauth2) ([BSD-3-Clause](https://cs.opensource.google/go/x/oauth2/+/v0.33.0:LICENSE)) - - [golang.org/x/sync](https://pkg.go.dev/golang.org/x/sync) ([BSD-3-Clause](https://cs.opensource.google/go/x/sync/+/v0.19.0:LICENSE)) - - [golang.org/x/sys](https://pkg.go.dev/golang.org/x/sys) ([BSD-3-Clause](https://cs.opensource.google/go/x/sys/+/v0.40.0:LICENSE)) - - [golang.org/x/term](https://pkg.go.dev/golang.org/x/term) ([BSD-3-Clause](https://cs.opensource.google/go/x/term/+/v0.38.0:LICENSE)) - - [golang.org/x/text](https://pkg.go.dev/golang.org/x/text) ([BSD-3-Clause](https://cs.opensource.google/go/x/text/+/v0.32.0:LICENSE)) + - [golang.org/x/net](https://pkg.go.dev/golang.org/x/net) ([BSD-3-Clause](https://cs.opensource.google/go/x/net/+/v0.53.0:LICENSE)) + - [golang.org/x/oauth2](https://pkg.go.dev/golang.org/x/oauth2) ([BSD-3-Clause](https://cs.opensource.google/go/x/oauth2/+/v0.36.0:LICENSE)) + - [golang.org/x/sync](https://pkg.go.dev/golang.org/x/sync) ([BSD-3-Clause](https://cs.opensource.google/go/x/sync/+/v0.20.0:LICENSE)) + - [golang.org/x/sys](https://pkg.go.dev/golang.org/x/sys) ([BSD-3-Clause](https://cs.opensource.google/go/x/sys/+/v0.43.0:LICENSE)) + - [golang.org/x/term](https://pkg.go.dev/golang.org/x/term) ([BSD-3-Clause](https://cs.opensource.google/go/x/term/+/v0.42.0:LICENSE)) + - [golang.org/x/text](https://pkg.go.dev/golang.org/x/text) ([BSD-3-Clause](https://cs.opensource.google/go/x/text/+/v0.36.0:LICENSE)) - [golang.org/x/time/rate](https://pkg.go.dev/golang.org/x/time/rate) ([BSD-3-Clause](https://cs.opensource.google/go/x/time/+/v0.12.0:LICENSE)) - [golang.zx2c4.com/wintun](https://pkg.go.dev/golang.zx2c4.com/wintun) ([MIT](https://git.zx2c4.com/wintun-go/tree/LICENSE?id=0fa3db229ce2)) - [golang.zx2c4.com/wireguard/windows/tunnel/winipcfg](https://pkg.go.dev/golang.zx2c4.com/wireguard/windows/tunnel/winipcfg) ([MIT](https://git.zx2c4.com/wireguard-windows/tree/COPYING?h=v0.5.3)) @@ -103,5 +103,4 @@ Some packages may only be included on certain architectures or operating systems - [k8s.io/client-go/util/homedir](https://pkg.go.dev/k8s.io/client-go/util/homedir) ([Apache-2.0](https://github.com/kubernetes/client-go/blob/v0.34.0/LICENSE)) - [sigs.k8s.io/yaml](https://pkg.go.dev/sigs.k8s.io/yaml) ([Apache-2.0](https://github.com/kubernetes-sigs/yaml/blob/v1.6.0/LICENSE)) - [tailscale.com](https://pkg.go.dev/tailscale.com) ([BSD-3-Clause](https://github.com/tailscale/tailscale/blob/HEAD/LICENSE)) - - [github.com/tailscale/gliderssh](https://pkg.go.dev/github.com/tailscale/gliderssh) ([BSD-3-Clause](https://github.com/tailscale/gliderssh/blob/HEAD/LICENSE)) - [tailscale.com/tempfork/spf13/cobra](https://pkg.go.dev/tailscale.com/tempfork/spf13/cobra) ([Apache-2.0](https://github.com/tailscale/tailscale/blob/HEAD/tempfork/spf13/cobra/LICENSE.txt)) diff --git a/licenses/windows.md b/licenses/windows.md index 6329655cc..33c142550 100644 --- a/licenses/windows.md +++ b/licenses/windows.md @@ -28,9 +28,9 @@ Windows][]. See also the dependencies in the [Tailscale CLI][]. - [github.com/hdevalence/ed25519consensus](https://pkg.go.dev/github.com/hdevalence/ed25519consensus) ([BSD-3-Clause](https://github.com/hdevalence/ed25519consensus/blob/v0.2.0/LICENSE)) - [github.com/jellydator/ttlcache/v3](https://pkg.go.dev/github.com/jellydator/ttlcache/v3) ([MIT](https://github.com/jellydator/ttlcache/blob/v3.1.0/LICENSE)) - [github.com/jsimonetti/rtnetlink](https://pkg.go.dev/github.com/jsimonetti/rtnetlink) ([MIT](https://github.com/jsimonetti/rtnetlink/blob/v1.4.1/LICENSE.md)) - - [github.com/klauspost/compress](https://pkg.go.dev/github.com/klauspost/compress) ([Apache-2.0](https://github.com/klauspost/compress/blob/v1.18.4/LICENSE)) - - [github.com/klauspost/compress/internal/snapref](https://pkg.go.dev/github.com/klauspost/compress/internal/snapref) ([BSD-3-Clause](https://github.com/klauspost/compress/blob/v1.18.4/internal/snapref/LICENSE)) - - [github.com/klauspost/compress/zstd/internal/xxhash](https://pkg.go.dev/github.com/klauspost/compress/zstd/internal/xxhash) ([MIT](https://github.com/klauspost/compress/blob/v1.18.4/zstd/internal/xxhash/LICENSE.txt)) + - [github.com/klauspost/compress](https://pkg.go.dev/github.com/klauspost/compress) ([Apache-2.0](https://github.com/klauspost/compress/blob/v1.18.5/LICENSE)) + - [github.com/klauspost/compress/internal/snapref](https://pkg.go.dev/github.com/klauspost/compress/internal/snapref) ([BSD-3-Clause](https://github.com/klauspost/compress/blob/v1.18.5/internal/snapref/LICENSE)) + - [github.com/klauspost/compress/zstd/internal/xxhash](https://pkg.go.dev/github.com/klauspost/compress/zstd/internal/xxhash) ([MIT](https://github.com/klauspost/compress/blob/v1.18.5/zstd/internal/xxhash/LICENSE.txt)) - [github.com/mdlayher/netlink](https://pkg.go.dev/github.com/mdlayher/netlink) ([MIT](https://github.com/mdlayher/netlink/blob/fbb4dce95f42/LICENSE.md)) - [github.com/mdlayher/socket](https://pkg.go.dev/github.com/mdlayher/socket) ([MIT](https://github.com/mdlayher/socket/blob/v0.5.0/LICENSE.md)) - [github.com/mitchellh/go-ps](https://pkg.go.dev/github.com/mitchellh/go-ps) ([MIT](https://github.com/mitchellh/go-ps/blob/v1.0.0/LICENSE.md)) @@ -42,7 +42,7 @@ Windows][]. See also the dependencies in the [Tailscale CLI][]. - [github.com/prometheus/common](https://pkg.go.dev/github.com/prometheus/common) ([Apache-2.0](https://github.com/prometheus/common/blob/v0.67.5/LICENSE)) - [github.com/skip2/go-qrcode](https://pkg.go.dev/github.com/skip2/go-qrcode) ([MIT](https://github.com/skip2/go-qrcode/blob/da1b6568686e/LICENSE)) - [github.com/tailscale/go-winio](https://pkg.go.dev/github.com/tailscale/go-winio) ([MIT](https://github.com/tailscale/go-winio/blob/c4f33415bf55/LICENSE)) - - [github.com/tailscale/hujson](https://pkg.go.dev/github.com/tailscale/hujson) ([BSD-3-Clause](https://github.com/tailscale/hujson/blob/992244df8c5a/LICENSE)) + - [github.com/tailscale/hujson](https://pkg.go.dev/github.com/tailscale/hujson) ([BSD-3-Clause](https://github.com/tailscale/hujson/blob/ecc657c15afd/LICENSE)) - [github.com/tailscale/walk](https://pkg.go.dev/github.com/tailscale/walk) ([BSD-3-Clause](https://github.com/tailscale/walk/blob/963e260a8227/LICENSE)) - [github.com/tailscale/win](https://pkg.go.dev/github.com/tailscale/win) ([BSD-3-Clause](https://github.com/tailscale/win/blob/f4da2b8ee071/LICENSE)) - [github.com/tailscale/xnet/webdav](https://pkg.go.dev/github.com/tailscale/xnet/webdav) ([BSD-3-Clause](https://github.com/tailscale/xnet/blob/8497ac4dab2e/LICENSE)) @@ -51,14 +51,14 @@ Windows][]. See also the dependencies in the [Tailscale CLI][]. - [go.yaml.in/yaml/v2](https://pkg.go.dev/go.yaml.in/yaml/v2) ([Apache-2.0](https://github.com/yaml/go-yaml/blob/v2.4.3/LICENSE)) - [go4.org/mem](https://pkg.go.dev/go4.org/mem) ([Apache-2.0](https://github.com/go4org/mem/blob/ae6ca9944745/LICENSE)) - [go4.org/netipx](https://pkg.go.dev/go4.org/netipx) ([BSD-3-Clause](https://github.com/go4org/netipx/blob/fdeea329fbba/LICENSE)) - - [golang.org/x/crypto](https://pkg.go.dev/golang.org/x/crypto) ([BSD-3-Clause](https://cs.opensource.google/go/x/crypto/+/v0.49.0:LICENSE)) - - [golang.org/x/exp](https://pkg.go.dev/golang.org/x/exp) ([BSD-3-Clause](https://cs.opensource.google/go/x/exp/+/a4bb9ffd:LICENSE)) + - [golang.org/x/crypto](https://pkg.go.dev/golang.org/x/crypto) ([BSD-3-Clause](https://cs.opensource.google/go/x/crypto/+/v0.50.0:LICENSE)) + - [golang.org/x/exp](https://pkg.go.dev/golang.org/x/exp) ([BSD-3-Clause](https://cs.opensource.google/go/x/exp/+/3dfff04d:LICENSE)) - [golang.org/x/image/bmp](https://pkg.go.dev/golang.org/x/image/bmp) ([BSD-3-Clause](https://cs.opensource.google/go/x/image/+/v0.27.0:LICENSE)) - - [golang.org/x/mod](https://pkg.go.dev/golang.org/x/mod) ([BSD-3-Clause](https://cs.opensource.google/go/x/mod/+/v0.33.0:LICENSE)) - - [golang.org/x/net](https://pkg.go.dev/golang.org/x/net) ([BSD-3-Clause](https://cs.opensource.google/go/x/net/+/v0.52.0:LICENSE)) + - [golang.org/x/mod](https://pkg.go.dev/golang.org/x/mod) ([BSD-3-Clause](https://cs.opensource.google/go/x/mod/+/v0.35.0:LICENSE)) + - [golang.org/x/net](https://pkg.go.dev/golang.org/x/net) ([BSD-3-Clause](https://cs.opensource.google/go/x/net/+/v0.53.0:LICENSE)) - [golang.org/x/sync](https://pkg.go.dev/golang.org/x/sync) ([BSD-3-Clause](https://cs.opensource.google/go/x/sync/+/v0.20.0:LICENSE)) - - [golang.org/x/sys](https://pkg.go.dev/golang.org/x/sys) ([BSD-3-Clause](https://cs.opensource.google/go/x/sys/+/v0.42.0:LICENSE)) - - [golang.org/x/term](https://pkg.go.dev/golang.org/x/term) ([BSD-3-Clause](https://cs.opensource.google/go/x/term/+/v0.41.0:LICENSE)) + - [golang.org/x/sys](https://pkg.go.dev/golang.org/x/sys) ([BSD-3-Clause](https://cs.opensource.google/go/x/sys/+/v0.43.0:LICENSE)) + - [golang.org/x/term](https://pkg.go.dev/golang.org/x/term) ([BSD-3-Clause](https://cs.opensource.google/go/x/term/+/v0.42.0:LICENSE)) - [golang.zx2c4.com/wintun](https://pkg.go.dev/golang.zx2c4.com/wintun) ([MIT](https://git.zx2c4.com/wintun-go/tree/LICENSE?id=0fa3db229ce2)) - [golang.zx2c4.com/wireguard/windows/tunnel/winipcfg](https://pkg.go.dev/golang.zx2c4.com/wireguard/windows/tunnel/winipcfg) ([MIT](https://git.zx2c4.com/wireguard-windows/tree/COPYING?h=v0.5.3)) - [google.golang.org/protobuf](https://pkg.go.dev/google.golang.org/protobuf) ([BSD-3-Clause](https://github.com/protocolbuffers/protobuf-go/blob/v1.36.11/LICENSE)) diff --git a/logtail/config.go b/logtail/config.go index c504047a3..0ee599905 100644 --- a/logtail/config.go +++ b/logtail/config.go @@ -64,4 +64,12 @@ type Config struct { // being included in the logs. The sequence number is incremented for each // log message sent, but is not persisted across process restarts. IncludeProcSequence bool + + // Disabled, if true, causes the returned [Logger] to start in the + // disabled state, dropping entries without buffering or uploading + // (equivalent to calling [Logger.SetEnabled] with false immediately). + // It applies before the internal startup banner is written, so no + // log entries are emitted until [Logger.SetEnabled] is called with + // true. The process-wide [Disable] kill switch still takes precedence. + Disabled bool } diff --git a/logtail/logtail.go b/logtail/logtail.go index ed3872e79..a45f1bfe9 100644 --- a/logtail/logtail.go +++ b/logtail/logtail.go @@ -132,6 +132,7 @@ func NewLogger(cfg Config, logf tslogger.Logf) *Logger { } logger.SetSockstatsLabel(sockstats.LabelLogtailLogger) logger.compressLogs = cfg.CompressLogs + logger.disabled.Store(cfg.Disabled) ctx, cancel := context.WithCancel(context.Background()) logger.uploadCancel = cancel @@ -172,6 +173,11 @@ type Logger struct { procID uint32 includeProcSequence bool + // disabled, when true, causes this logger to drop incoming log entries + // without buffering or uploading. It is independent of the process-wide + // Disable kill switch, which takes precedence. Toggled by SetEnabled. + disabled atomic.Bool + writeLock sync.Mutex // guards procSequence, flushTimer, buffer.Write calls procSequence uint64 flushTimer tstime.TimerController // used when flushDelay is >0 @@ -594,6 +600,15 @@ func Disable() { logtailDisabled.Store(true) } +// SetEnabled enables or disables log uploading by lg. When disabled, log +// entries passed to lg are dropped rather than buffered or uploaded; already +// buffered entries may still drain. The process-wide [Disable] kill switch +// takes precedence: if Disable has been called, SetEnabled(true) does not +// re-enable uploads. +func (lg *Logger) SetEnabled(enabled bool) { + lg.disabled.Store(!enabled) +} + var debugWakesAndUploads = envknob.RegisterBool("TS_DEBUG_LOGTAIL_WAKES") // tryDrainWake tries to send to lg.drainWake, to cause an uploading wakeup. @@ -613,7 +628,7 @@ func (lg *Logger) tryDrainWake() { func (lg *Logger) sendLocked(jsonBlob []byte) (int, error) { tapSend(jsonBlob) - if logtailDisabled.Load() { + if logtailDisabled.Load() || lg.disabled.Load() { return len(jsonBlob), nil } diff --git a/logtail/logtail_omit.go b/logtail/logtail_omit.go index 21f18c980..98f1c6a0e 100644 --- a/logtail/logtail_omit.go +++ b/logtail/logtail_omit.go @@ -20,6 +20,8 @@ type Buffer any func Disable() {} +func (*Logger) SetEnabled(enabled bool) {} + func NewLogger(cfg Config, logf tslogger.Logf) *Logger { return &Logger{} } diff --git a/logtail/logtail_test.go b/logtail/logtail_test.go index 19e1eeb7a..8273097c3 100644 --- a/logtail/logtail_test.go +++ b/logtail/logtail_test.go @@ -7,34 +7,51 @@ import ( "bytes" "context" "encoding/json" + "fmt" "io" + "net" "net/http" - "net/http/httptest" + "os" "strings" + "sync" "testing" + "testing/synctest" "time" "github.com/go-json-experiment/json/jsontext" + "tailscale.com/net/memnet" "tailscale.com/tstest" "tailscale.com/tstime" "tailscale.com/util/eventbus/eventbustest" "tailscale.com/util/must" ) -func TestFastShutdown(t *testing.T) { +// TestMain installs a safety net that refuses non-localhost dials for any +// test in this package. Config.BaseURL defaults to https://log.tailscale.com +// and Config.HTTPC defaults to http.DefaultClient, so a test that forgets to +// override either can otherwise silently hit the real logtail server. +// Tests that need an HTTP server should use memnet (see newTestLogtailServer). +func TestMain(m *testing.M) { + tr := http.DefaultTransport.(*http.Transport) + orig := tr.DialContext + tr.DialContext = func(ctx context.Context, network, addr string) (net.Conn, error) { + host, _, err := net.SplitHostPort(addr) + if err == nil && (host == "127.0.0.1" || host == "::1" || host == "localhost") { + return orig(ctx, network, addr) + } + return nil, fmt.Errorf("logtail tests: refusing to dial non-localhost address %q; use memnet or a custom Config.HTTPC", addr) + } + os.Exit(m.Run()) +} + +func TestFastShutdown(t *testing.T) { synctest.Test(t, synctestFastShutdown) } + +func synctestFastShutdown(t *testing.T) { ctx, cancel := context.WithCancel(context.Background()) cancel() - testServ := httptest.NewServer(http.HandlerFunc( - func(w http.ResponseWriter, r *http.Request) {})) - defer testServ.Close() - - logger := NewLogger(Config{ - BaseURL: testServ.URL, - Bus: eventbustest.NewBus(t), - }, t.Logf) - err := logger.Shutdown(ctx) - if err != nil { + _, logger := newTestLogtailServer(t) + if err := logger.Shutdown(ctx); err != nil { t.Error(err) } } @@ -43,49 +60,60 @@ func TestFastShutdown(t *testing.T) { const logLines = 3 type LogtailTestServer struct { - srv *httptest.Server // Log server uploaded chan []byte } -func NewLogtailTestHarness(t *testing.T) (*LogtailTestServer, *Logger) { - ts := LogtailTestServer{} +// newTestLogtailServer wires up an in-memory HTTP server (via memnet) and a +// *Logger whose HTTPC dials it. Lives inside the caller's synctest bubble so +// the default FlushDelay and any other fake timers advance automatically. +func newTestLogtailServer(t *testing.T) (*LogtailTestServer, *Logger) { + ts := &LogtailTestServer{ + // max channel backlog = 1 "started" + #logLines x "log line" + 1 "closed" + uploaded: make(chan []byte, 2+logLines), + } - // max channel backlog = 1 "started" + #logLines x "log line" + 1 "closed" - ts.uploaded = make(chan []byte, 2+logLines) + handler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + body, err := io.ReadAll(r.Body) + if err != nil { + t.Error("failed to read HTTP request") + } + ts.uploaded <- body + }) - ts.srv = httptest.NewServer(http.HandlerFunc( - func(w http.ResponseWriter, r *http.Request) { - body, err := io.ReadAll(r.Body) - if err != nil { - t.Error("failed to read HTTP request") - } - ts.uploaded <- body - })) - - t.Cleanup(ts.srv.Close) + ln := memnet.Listen("logtail-test:0") + httpsrv := &http.Server{Handler: handler} + go httpsrv.Serve(ln) + t.Cleanup(func() { + httpsrv.Close() + ln.Close() + }) logger := NewLogger(Config{ - BaseURL: ts.srv.URL, + BaseURL: "http://" + ln.Addr().String(), Bus: eventbustest.NewBus(t), + HTTPC: &http.Client{ + Transport: &http.Transport{DialContext: ln.Dial}, + }, }, t.Logf) - // There is always an initial "logtail started" message + // There is always an initial "logtail started" message. body := <-ts.uploaded if !strings.Contains(string(body), "started") { t.Errorf("unknown start logging statement: %q", string(body)) } - - return &ts, logger + return ts, logger } -func TestDrainPendingMessages(t *testing.T) { - ts, logger := NewLogtailTestHarness(t) +func TestDrainPendingMessages(t *testing.T) { synctest.Test(t, synctestDrainPendingMessages) } + +func synctestDrainPendingMessages(t *testing.T) { + ts, logger := newTestLogtailServer(t) for range logLines { logger.Write([]byte("log line")) } - // all of the "log line" messages usually arrive at once, but poll if needed. + // All the "log line" messages usually arrive at once, but poll if needed. var body strings.Builder for i := 0; i <= logLines; i++ { body.WriteString(string(<-ts.uploaded)) @@ -93,17 +121,17 @@ func TestDrainPendingMessages(t *testing.T) { if count == logLines { break } - // if we never find count == logLines, the test will eventually time out. } - err := logger.Shutdown(context.Background()) - if err != nil { + if err := logger.Shutdown(context.Background()); err != nil { t.Error(err) } } -func TestEncodeAndUploadMessages(t *testing.T) { - ts, logger := NewLogtailTestHarness(t) +func TestEncodeAndUploadMessages(t *testing.T) { synctest.Test(t, synctestEncodeAndUploadMessages) } + +func synctestEncodeAndUploadMessages(t *testing.T) { + ts, logger := newTestLogtailServer(t) tests := []struct { name string @@ -144,8 +172,7 @@ func TestEncodeAndUploadMessages(t *testing.T) { } } - err := logger.Shutdown(context.Background()) - if err != nil { + if err := logger.Shutdown(context.Background()); err != nil { t.Error(err) } } @@ -321,6 +348,90 @@ func TestLoggerWriteResult(t *testing.T) { } } +type roundTripperFunc func(*http.Request) (*http.Response, error) + +func (f roundTripperFunc) RoundTrip(r *http.Request) (*http.Response, error) { return f(r) } + +func TestNewLoggerDisabled(t *testing.T) { synctest.Test(t, synctestNewLoggerDisabled) } + +func synctestNewLoggerDisabled(t *testing.T) { + // When Config.Disabled is true, NewLogger must not emit the usual + // "logtail started" banner: the logger should start in the disabled + // state before the internal startup write, so nothing ever lands + // in the buffer for the upload goroutine to drain. + buf := NewMemoryBuffer(100) + + // Any HTTP attempt indicates the banner leaked into the buffer and + // the upload goroutine tried to ship it. Report it once (so the + // retry spin doesn't drown the log), then block on the request + // context so synctest.Wait sees a durable block and Shutdown's + // uploadCancel can unblock us cleanly. + var once sync.Once + httpc := &http.Client{ + Transport: roundTripperFunc(func(r *http.Request) (*http.Response, error) { + once.Do(func() { + t.Errorf("unexpected HTTP request while Disabled=true: %s", r.URL) + }) + <-r.Context().Done() + return nil, r.Context().Err() + }), + } + + logger := NewLogger(Config{ + BaseURL: "http://logtail.test.invalid", + HTTPC: httpc, + Bus: eventbustest.NewBus(t), + Buffer: buf, + Disabled: true, + }, t.Logf) + defer func() { + // Pass an already-cancelled context so Shutdown invokes + // uploadCancel immediately; otherwise on the regression path + // (Disabled=false) the upload goroutine stays in its retry + // loop and synctest.Test never returns. + ctx, cancel := context.WithCancel(context.Background()) + cancel() + logger.Shutdown(ctx) + }() + + synctest.Wait() + + if back, _ := buf.TryReadLine(); len(back) != 0 { + t.Errorf("Disabled logger buffered a startup entry: %q", back) + } +} + +func TestLoggerSetEnabled(t *testing.T) { + buf := NewMemoryBuffer(100) + lg := &Logger{ + clock: tstest.NewClock(tstest.ClockOpts{Start: time.Unix(123, 0)}), + buffer: buf, + } + + if _, err := lg.Write([]byte("enabled1")); err != nil { + t.Fatal(err) + } + if back, _ := buf.TryReadLine(); !strings.Contains(string(back), "enabled1") { + t.Fatalf("initial write not buffered; got %q", back) + } + + lg.SetEnabled(false) + if _, err := lg.Write([]byte("disabled")); err != nil { + t.Fatal(err) + } + if back, _ := buf.TryReadLine(); len(back) != 0 { + t.Errorf("write while disabled leaked into buffer: %q", back) + } + + lg.SetEnabled(true) + if _, err := lg.Write([]byte("enabled2")); err != nil { + t.Fatal(err) + } + if back, _ := buf.TryReadLine(); !strings.Contains(string(back), "enabled2") { + t.Errorf("write after re-enable not buffered; got %q", back) + } +} + func TestAppendMetadata(t *testing.T) { var lg Logger lg.clock = tstest.NewClock(tstest.ClockOpts{Start: time.Date(2000, 01, 01, 0, 0, 0, 0, time.UTC)}) diff --git a/maths/ewma.go b/maths/ewma.go deleted file mode 100644 index 1946081cf..000000000 --- a/maths/ewma.go +++ /dev/null @@ -1,72 +0,0 @@ -// Copyright (c) Tailscale Inc & contributors -// SPDX-License-Identifier: BSD-3-Clause - -// Package maths contains additional mathematical functions or structures not -// found in the standard library. -package maths - -import ( - "math" - "time" -) - -// EWMA is an exponentially weighted moving average supporting updates at -// irregular intervals with at most nanosecond resolution. -// The zero value will compute a half-life of 1 second. -// It is not safe for concurrent use. -// TODO(raggi): de-duplicate with tstime/rate.Value, which has a more complex -// and synchronized interface and does not provide direct access to the stable -// value. -type EWMA struct { - value float64 // current value of the average - lastTime int64 // time of last update in unix nanos - halfLife float64 // half-life in seconds -} - -// NewEWMA creates a new EWMA with the specified half-life. If halfLifeSeconds -// is 0, it defaults to 1. -func NewEWMA(halfLifeSeconds float64) *EWMA { - return &EWMA{ - halfLife: halfLifeSeconds, - } -} - -// Update adds a new sample to the average. If t is zero or precedes the last -// update, the update is ignored. -func (e *EWMA) Update(value float64, t time.Time) { - if t.IsZero() { - return - } - hl := e.halfLife - if hl == 0 { - hl = 1 - } - tn := t.UnixNano() - if e.lastTime == 0 { - e.value = value - e.lastTime = tn - return - } - - dt := (time.Duration(tn-e.lastTime) * time.Nanosecond).Seconds() - if dt < 0 { - // drop out of order updates - return - } - - // decay = 2^(-dt/halfLife) - decay := math.Exp2(-dt / hl) - e.value = e.value*decay + value*(1-decay) - e.lastTime = tn -} - -// Get returns the current value of the average -func (e *EWMA) Get() float64 { - return e.value -} - -// Reset clears the EWMA to its initial state -func (e *EWMA) Reset() { - e.value = 0 - e.lastTime = 0 -} diff --git a/maths/ewma_test.go b/maths/ewma_test.go deleted file mode 100644 index 9fddf34e1..000000000 --- a/maths/ewma_test.go +++ /dev/null @@ -1,178 +0,0 @@ -// Copyright (c) Tailscale Inc & contributors -// SPDX-License-Identifier: BSD-3-Clause - -package maths - -import ( - "slices" - "testing" - "time" -) - -// some real world latency samples. -var ( - latencyHistory1 = []int{ - 14, 12, 15, 6, 19, 12, 13, 13, 13, 16, 17, 11, 17, 11, 14, 15, 14, 15, - 16, 16, 17, 14, 12, 16, 18, 14, 14, 11, 15, 15, 25, 11, 15, 14, 12, 15, - 13, 12, 13, 15, 11, 13, 15, 14, 14, 15, 12, 15, 18, 12, 15, 22, 12, 13, - 10, 14, 16, 15, 16, 11, 14, 17, 18, 20, 16, 11, 16, 14, 5, 15, 17, 12, - 15, 11, 15, 20, 12, 17, 12, 17, 15, 12, 12, 11, 14, 15, 11, 20, 14, 13, - 11, 12, 13, 13, 11, 13, 11, 15, 13, 13, 14, 12, 11, 12, 12, 14, 11, 13, - 12, 12, 12, 19, 14, 13, 13, 14, 11, 12, 10, 11, 15, 12, 14, 11, 11, 14, - 14, 12, 12, 11, 14, 12, 11, 12, 14, 11, 12, 15, 12, 14, 12, 12, 21, 16, - 21, 12, 16, 9, 11, 16, 14, 13, 14, 12, 13, 16, - } - latencyHistory2 = []int{ - 18, 20, 21, 21, 20, 23, 18, 18, 20, 21, 20, 19, 22, 18, 20, 20, 19, 21, - 21, 22, 22, 19, 18, 22, 22, 19, 20, 17, 16, 11, 25, 16, 18, 21, 17, 22, - 19, 18, 22, 21, 20, 18, 22, 17, 17, 20, 19, 10, 19, 16, 19, 25, 17, 18, - 15, 20, 21, 20, 23, 22, 22, 22, 19, 22, 22, 17, 22, 20, 20, 19, 21, 22, - 20, 19, 17, 22, 16, 16, 20, 22, 17, 19, 21, 16, 20, 22, 19, 21, 20, 19, - 13, 14, 23, 19, 16, 10, 19, 15, 15, 17, 16, 18, 14, 16, 18, 22, 20, 18, - 18, 21, 15, 19, 18, 19, 18, 20, 17, 19, 21, 19, 20, 19, 20, 20, 17, 14, - 17, 17, 18, 21, 20, 18, 18, 17, 16, 17, 17, 20, 22, 19, 20, 21, 21, 20, - 21, 24, 20, 18, 12, 17, 18, 17, 19, 19, 19, - } -) - -func TestEWMALatencyHistory(t *testing.T) { - type result struct { - t time.Time - v float64 - s int - } - - for _, latencyHistory := range [][]int{latencyHistory1, latencyHistory2} { - startTime := time.Date(2025, 1, 1, 12, 0, 0, 0, time.UTC) - halfLife := 30.0 - - ewma := NewEWMA(halfLife) - - var results []result - sum := 0.0 - for i, latency := range latencyHistory { - t := startTime.Add(time.Duration(i) * time.Second) - ewma.Update(float64(latency), t) - sum += float64(latency) - - results = append(results, result{t, ewma.Get(), latency}) - } - mean := sum / float64(len(latencyHistory)) - min := float64(slices.Min(latencyHistory)) - max := float64(slices.Max(latencyHistory)) - - t.Logf("EWMA Latency History (half-life: %.1f seconds):", halfLife) - t.Logf("Mean latency: %.2f ms", mean) - t.Logf("Range: [%.1f, %.1f]", min, max) - - t.Log("Samples: ") - sparkline := []rune("▁▂▃▄▅▆▇█") - var sampleLine []rune - for _, r := range results { - idx := int(((float64(r.s) - min) / (max - min)) * float64(len(sparkline)-1)) - if idx >= len(sparkline) { - idx = len(sparkline) - 1 - } - sampleLine = append(sampleLine, sparkline[idx]) - } - t.Log(string(sampleLine)) - - t.Log("EWMA: ") - var ewmaLine []rune - for _, r := range results { - idx := int(((r.v - min) / (max - min)) * float64(len(sparkline)-1)) - if idx >= len(sparkline) { - idx = len(sparkline) - 1 - } - ewmaLine = append(ewmaLine, sparkline[idx]) - } - t.Log(string(ewmaLine)) - t.Log("") - - t.Logf("Time | Sample | Value | Value - Sample") - t.Logf("") - - for _, result := range results { - t.Logf("%10s | % 6d | % 5.2f | % 5.2f", result.t.Format("15:04:05"), result.s, result.v, result.v-float64(result.s)) - } - - // check that all results are greater than the min, and less than the max of the input, - // and they're all close to the mean. - for _, result := range results { - if result.v < float64(min) || result.v > float64(max) { - t.Errorf("result %f out of range [%f, %f]", result.v, min, max) - } - - if result.v < mean*0.9 || result.v > mean*1.1 { - t.Errorf("result %f not close to mean %f", result.v, mean) - } - } - } -} - -func TestHalfLife(t *testing.T) { - start := time.Date(2025, 1, 1, 12, 0, 0, 0, time.UTC) - - ewma := NewEWMA(30.0) - ewma.Update(10, start) - ewma.Update(0, start.Add(30*time.Second)) - - if ewma.Get() != 5 { - t.Errorf("expected 5, got %f", ewma.Get()) - } - - ewma.Update(10, start.Add(60*time.Second)) - if ewma.Get() != 7.5 { - t.Errorf("expected 7.5, got %f", ewma.Get()) - } - - ewma.Update(10, start.Add(90*time.Second)) - if ewma.Get() != 8.75 { - t.Errorf("expected 8.75, got %f", ewma.Get()) - } -} - -func TestZeroValue(t *testing.T) { - start := time.Date(2025, 1, 1, 12, 0, 0, 0, time.UTC) - - var ewma EWMA - ewma.Update(10, start) - ewma.Update(0, start.Add(time.Second)) - - if ewma.Get() != 5 { - t.Errorf("expected 5, got %f", ewma.Get()) - } - - ewma.Update(10, start.Add(2*time.Second)) - if ewma.Get() != 7.5 { - t.Errorf("expected 7.5, got %f", ewma.Get()) - } - - ewma.Update(10, start.Add(3*time.Second)) - if ewma.Get() != 8.75 { - t.Errorf("expected 8.75, got %f", ewma.Get()) - } -} - -func TestReset(t *testing.T) { - start := time.Date(2025, 1, 1, 12, 0, 0, 0, time.UTC) - - ewma := NewEWMA(30.0) - ewma.Update(10, start) - ewma.Update(0, start.Add(30*time.Second)) - - if ewma.Get() != 5 { - t.Errorf("expected 5, got %f", ewma.Get()) - } - - ewma.Reset() - - if ewma.Get() != 0 { - t.Errorf("expected 0, got %f", ewma.Get()) - } - - ewma.Update(10, start.Add(90*time.Second)) - if ewma.Get() != 10 { - t.Errorf("expected 10, got %f", ewma.Get()) - } -} diff --git a/metrics/multilabelmap_test.go b/metrics/multilabelmap_test.go index 70554c63e..0fa730992 100644 --- a/metrics/multilabelmap_test.go +++ b/metrics/multilabelmap_test.go @@ -86,10 +86,10 @@ metricname{foo="si",bar="si"} 5 func TestMultiLabelMapTypes(t *testing.T) { type LabelTypes struct { - S string - B bool - I int - U uint + S string + B bool + Int int + U uint } m := new(MultiLabelMap[LabelTypes]) @@ -100,7 +100,7 @@ func TestMultiLabelMapTypes(t *testing.T) { m.WritePrometheus(&buf, "metricname") const want = `# TYPE metricname counter # HELP metricname some good stuff -metricname{s="a",b="true",i="-1",u="2"} 3 +metricname{s="a",b="true",int="-1",u="2"} 3 ` if got := buf.String(); got != want { t.Errorf("got %q; want %q", got, want) diff --git a/misc/genreadme/genreadme.go b/misc/genreadme/genreadme.go new file mode 100644 index 000000000..97a8d9e16 --- /dev/null +++ b/misc/genreadme/genreadme.go @@ -0,0 +1,267 @@ +// Copyright (c) Tailscale Inc & contributors +// SPDX-License-Identifier: BSD-3-Clause + +// The genreadme tool generates/updates README.md files in the tailscale repo. +// +// # Running +// +// From the repo root, run: `./tool/go run ./misc/genreadme` and it will update all +// the README.md files that are stale in the tree. +package main + +import ( + "bytes" + "errors" + "flag" + "fmt" + "go/parser" + "go/token" + "io" + "io/fs" + "log" + "os" + "path" + "path/filepath" + "runtime" + "strings" + + "github.com/creachadair/taskgroup" + "tailscale.com/tempfork/pkgdoc" +) + +// modulePath is the current module's import path, read from go.mod at startup. +var modulePath string + +var skip = map[string]bool{ + "out": true, +} + +// bkSkip lists directories where the generated file should not mention +// Buildkite because a deploy workflow is not set up for them. +var bkSkip = map[string]bool{} + +// defaultRoots are the directory trees walked when genreadme is run with +// no arguments. Add a directory here to opt its package (and any +// sub-packages) into README.md generation from godoc. +var defaultRoots = []string{ + "tsnet", +} + +func main() { + flag.Parse() + modulePath = readModulePath("go.mod") + var roots []string + switch flag.NArg() { + case 0: + roots = defaultRoots + case 1: + root := flag.Arg(0) + root = strings.TrimPrefix(root, "./") + root = strings.TrimSuffix(root, "/") + roots = []string{root} + default: + log.Fatalf("Usage: genreadme [dir]") + } + + var updateErrs []error + g, run := taskgroup.New(func(err error) { + updateErrs = append(updateErrs, err) + }).Limit(runtime.NumCPU() * 2) // usually I/O bound + + for _, root := range roots { + g.Go(func() error { + return fs.WalkDir(os.DirFS("."), root, func(path string, d fs.DirEntry, err error) error { + if err != nil { + return err + } + if !d.IsDir() { + return nil + } + if skip[path] { + return fs.SkipDir + } + base := filepath.Base(path) + if base == "testdata" || (path != "." && base[0] == '.') { + return fs.SkipDir + } + run(func() error { + return update(path) + }) + return nil + }) + }) + } + g.Wait() + if err := errors.Join(updateErrs...); err != nil { + log.Fatal(err) + } +} + +func update(dir string) error { + readmePath := filepath.Join(dir, "README.md") + cur, err := os.ReadFile(readmePath) + exists := false + if err != nil && !os.IsNotExist(err) { + return err + } + if err == nil { + exists = true + if !isGenerated(cur) { + // Do nothing; a human wrote this file. + return nil + } + } + + newContents, err := getNewContent(dir) + if err != nil { + return err + } + if newContents == nil { + if exists { + log.Printf("Deleting %s ...", readmePath) + os.Remove(readmePath) + } + return nil + } + + if bytes.Equal(cur, newContents) { + return nil + } + log.Printf("Writing %s ...", readmePath) + return os.WriteFile(readmePath, newContents, 0644) +} + +func getNewContent(dir string) (newContent []byte, err error) { + dents, err := os.ReadDir(dir) + if err != nil { + return nil, err + } + + generators := []struct { + name string + quickTest func(dir string, dents []fs.DirEntry) bool + generate func(dir string) ([]byte, error) + }{ + {"go", hasGoFiles, genGoDoc}, + } + for _, gen := range generators { + if !gen.quickTest(dir, dents) { + continue + } + newContent, err := gen.generate(dir) + if newContent == nil && err == nil { + // Generator declined to generate, try next + continue + } + return newContent, err + } + return nil, nil +} + +func genGoDoc(dir string) ([]byte, error) { + abs, err := filepath.Abs(dir) + if err != nil { + return nil, fmt.Errorf("failed to get absolute path for %q: %w", dir, err) + } + var importPath string + if modulePath != "" { + importPath = path.Join(modulePath, filepath.ToSlash(dir)) + } + godoc, err := pkgdoc.PackageDoc(abs, importPath) + if err != nil { + return nil, fmt.Errorf("failed to get package doc for %q: %w", dir, err) + } + if len(bytes.TrimSpace(godoc)) == 0 { + // No godoc; skipping. + return nil, nil + } + isLibrary := bytes.HasPrefix(godoc, []byte("package ")) + if isLibrary { + // Strip the "package X // import Y\n\n" clause emitted for library packages. + if i := bytes.Index(godoc, []byte("\n\n")); i != -1 { + godoc = godoc[i+2:] + } + } + if len(bytes.TrimSpace(godoc)) == 0 { + return nil, nil + } + var buf bytes.Buffer + io.WriteString(&buf, genHeader) + fmt.Fprintf(&buf, "\n# %s\n\n", filepath.Base(dir)) + if isLibrary && importPath != "" { + fmt.Fprintf(&buf, "[![Go Reference](https://pkg.go.dev/badge/%s.svg)](https://pkg.go.dev/%s)\n\n", importPath, importPath) + } + buf.Write(godoc) + + if !bytes.Contains(godoc, []byte("## Deploying")) { + deployPath := filepath.Join(dir, "deploy.sh") + if _, err := os.Stat(deployPath); err == nil { + fmt.Fprint(&buf, "\n## Deploying\n\n") + if hasBuildkite(dir) { + fmt.Fprintf(&buf, + "To deploy, run the https://buildkite.com/tailscale/deploy-%s workflow in Buildkite.\n", + filepath.Base(dir), + ) + } + fmt.Fprintf(&buf, "To deploy manually, run `./%s` from the repo root.\n\n", deployPath) + } + } + return buf.Bytes(), nil +} + +const genHeader = "\n" + +func isGenerated(b []byte) bool { return bytes.HasPrefix(b, []byte(genHeader)) } + +// readModulePath returns the module path declared in the given go.mod file, +// or "" if it can't be read or parsed. +func readModulePath(file string) string { + b, err := os.ReadFile(file) + if err != nil { + return "" + } + for line := range strings.Lines(string(b)) { + if rest, ok := strings.CutPrefix(strings.TrimSpace(line), "module "); ok { + return strings.Trim(strings.TrimSpace(rest), `"`) + } + } + return "" +} + +func hasBuildkite(dir string) bool { + if bkSkip[dir] { + return false + } + _, flyErr := os.Stat(filepath.Join(dir, "fly.toml")) + return flyErr != nil +} + +func hasGoFiles(dir string, dents []fs.DirEntry) bool { + var fset *token.FileSet + + for _, de := range dents { + name := de.Name() + if !strings.HasSuffix(name, ".go") || + strings.HasSuffix(name, "_test.go") { + continue + } + if fset == nil { + fset = token.NewFileSet() + } + + path := filepath.Join(dir, name) + f, err := os.Open(path) + if err != nil { + continue + } + pkgFile, err := parser.ParseFile(fset, "", f, parser.PackageClauseOnly) + f.Close() + if err != nil { + // skip files with parse errors + continue + } + + return pkgFile.Name.Name != "" + } + return false +} diff --git a/misc/git_hook/HOOK_VERSION b/misc/git_hook/HOOK_VERSION deleted file mode 100644 index d00491fd7..000000000 --- a/misc/git_hook/HOOK_VERSION +++ /dev/null @@ -1 +0,0 @@ -1 diff --git a/misc/git_hook/README.md b/misc/git_hook/README.md new file mode 100644 index 000000000..49d768937 --- /dev/null +++ b/misc/git_hook/README.md @@ -0,0 +1,35 @@ +# git_hook + +Tailscale's git hooks. + +The shared logic lives in the `githook/` package and is also imported by +`tailscale/corp`. + +## Install + +From the repo root: + + ./tool/go run ./misc/install-git-hooks.go + +The script auto-updates in the future. + + +## Adding your own hooks + +Create an executable `.git/hooks/.local` to chain a custom +script after a built-in hook. For example, put a custom check in +`.git/hooks/pre-commit.local` and `chmod +x` it. The local hook runs +only if the built-in hook succeeds; failure aborts the git operation. + + +## Version bumps + +The launcher rebuilds when the installed binary's version differs from +the concatenation of two files: + +* `githook/HOOK_VERSION` (shared): bump when changing anything under + `githook/` or `git-hook.go`. Downstream repos pick it up after + bumping their `tailscale.com` dependency. +* `misc/git_hook/HOOK_VERSION` (repo-local, optional): bump to force a + rebuild for repo-specific config changes without touching the shared + version. This repo does not use one. diff --git a/misc/git_hook/git-hook.go b/misc/git_hook/git-hook.go index 89e78b120..2cf3ff421 100644 --- a/misc/git_hook/git-hook.go +++ b/misc/git_hook/git-hook.go @@ -1,322 +1,62 @@ // Copyright (c) Tailscale Inc & contributors // SPDX-License-Identifier: BSD-3-Clause -// The git-hook command is Tailscale's git hooks. It's built by -// misc/install-git-hooks.go and installed into .git/hooks -// as .git/hooks/ts-git-hook, with shell wrappers. +// The git-hook command is Tailscale's git hook binary, built and +// installed under .git/hooks/ts-git-hook-bin by the launcher at +// .git/hooks/ts-git-hook. misc/install-git-hooks.go writes the initial +// launcher; subsequent HOOK_VERSION bumps trigger self-rebuilds. // // # Adding your own hooks // -// To add your own hook for one that we have already hooked, create a file named -// .local in .git/hooks. For example, to add your own pre-commit hook, -// create .git/hooks/pre-commit.local and make it executable. It will be run after -// the ts-git-hook, if ts-git-hook executes successfully. +// To add your own hook alongside one we already hook, create an executable +// file .git/hooks/.local (e.g. pre-commit.local). It runs after +// the built-in hook succeeds. package main import ( - "bufio" - "bytes" - "crypto/rand" - _ "embed" - "errors" "fmt" - "io" "log" "os" - "os/exec" - "path/filepath" - "strconv" "strings" - "github.com/fatih/color" - "github.com/sourcegraph/go-diff/diff" - "golang.org/x/mod/modfile" + "tailscale.com/misc/git_hook/githook" ) +var pushRemotes = []string{ + "git@github.com:tailscale/tailscale", + "git@github.com:tailscale/tailscale.git", + "https://github.com/tailscale/tailscale", + "https://github.com/tailscale/tailscale.git", +} + +// hooks are the hook names this binary handles. Used by install to +// write per-hook wrappers; must stay in sync with the dispatcher below. +var hooks = []string{"pre-commit", "commit-msg", "pre-push"} + func main() { log.SetFlags(0) if len(os.Args) < 2 { return } cmd, args := os.Args[1], os.Args[2:] + var err error switch cmd { + case "version": + fmt.Print(strings.TrimSpace(githook.HookVersion) + ":0") + case "install": + err = githook.WriteHooks(hooks) case "pre-commit": - err = preCommit(args) + err = githook.CheckForbiddenMarkers() case "commit-msg": - err = commitMsg(args) + err = githook.AddChangeID(args) case "pre-push": - err = prePush(args) - case "post-checkout": - err = postCheckout(args) + err = githook.CheckGoModReplaces(args, pushRemotes, nil) } if err != nil { - p := log.Fatalf - if nfe, ok := err.(nonFatalErr); ok { - p = log.Printf - err = nfe - } - p("git-hook: %v: %v", cmd, err) + log.Fatalf("git-hook: %v: %v", cmd, err) } - - if err == nil || errors.Is(err, nonFatalErr{}) { - err := runLocalHook(cmd, args) - if err != nil { - log.Fatalf("git-hook: %v", err) - } + if err := githook.RunLocalHook(cmd, args); err != nil { + log.Fatalf("git-hook: %v", err) } } - -func runLocalHook(hookName string, args []string) error { - cmdPath, err := os.Executable() - if err != nil { - return err - } - hookDir := filepath.Dir(cmdPath) - localHookPath := filepath.Join(hookDir, hookName+".local") - if _, err := os.Stat(localHookPath); errors.Is(err, os.ErrNotExist) { - return nil - } else if err != nil { - return fmt.Errorf("checking for local hook: %w", err) - } - - cmd := exec.Command(localHookPath, args...) - cmd.Stdout = os.Stdout - cmd.Stderr = os.Stderr - if err := cmd.Run(); err != nil { - return fmt.Errorf("running local hook %q: %w", localHookPath, err) - } - return nil -} - -// pre-commit: "It takes no parameters, and is invoked before -// obtaining the proposed commit log message and making a -// commit. Exiting with a non-zero status from this script causes the -// git commit command to abort before creating a commit." -// -// https://git-scm.com/docs/githooks#_pre_commit -func preCommit(_ []string) error { - diffOut, err := exec.Command("git", "diff", "--cached").Output() - if err != nil { - return fmt.Errorf("Could not get git diff: %w", err) - } - - diffs, err := diff.ParseMultiFileDiff(diffOut) - if err != nil { - return fmt.Errorf("Could not parse diff: %w", err) - } - - foundForbidden := false - for _, diff := range diffs { - for _, hunk := range diff.Hunks { - lines := bytes.Split(hunk.Body, []byte{'\n'}) - for i, line := range lines { - if len(line) == 0 || line[0] != '+' { - continue - } - for _, forbidden := range preCommitForbiddenPatterns { - if bytes.Contains(line, forbidden) { - if !foundForbidden { - color.New(color.Bold, color.FgRed, color.Underline).Printf("%s found:\n", forbidden) - } - // Output file name (dropping the b/ prefix) and line - // number so that it can be linkified by terminals. - fmt.Printf("%s:%d: %s\n", diff.NewName[2:], int(hunk.NewStartLine)+i, line[1:]) - foundForbidden = true - } - } - } - } - } - if foundForbidden { - return fmt.Errorf("Found forbidden string") - } - - return nil -} - -var preCommitForbiddenPatterns = [][]byte{ - // Use concatenation to avoid including the forbidden literals (and thus - // triggering the pre-commit hook). - []byte("NOCOM" + "MIT"), - []byte("DO NOT " + "SUBMIT"), -} - -// https://git-scm.com/docs/githooks#_commit_msg -func commitMsg(args []string) error { - if len(args) != 1 { - return errors.New("usage: commit-msg message.txt") - } - file := args[0] - msg, err := os.ReadFile(file) - if err != nil { - return err - } - msg = filterCutLine(msg) - - var id [20]byte - if _, err := io.ReadFull(rand.Reader, id[:]); err != nil { - return fmt.Errorf("could not generate Change-Id: %v", err) - } - cmdLines := [][]string{ - // Trim whitespace and comments. - {"git", "stripspace", "--strip-comments"}, - // Add Change-Id trailer. - {"git", "interpret-trailers", "--no-divider", "--where=start", "--if-exists", "doNothing", "--trailer", fmt.Sprintf("Change-Id: I%x", id)}, - } - for _, cmdLine := range cmdLines { - if len(msg) == 0 { - // Don't allow commands to go from empty commit message to non-empty (issue 2205). - break - } - cmd := exec.Command(cmdLine[0], cmdLine[1:]...) - cmd.Stdin = bytes.NewReader(msg) - msg, err = cmd.CombinedOutput() - if err != nil { - return fmt.Errorf("failed to run '%v': %w\n%s", cmd, err, msg) - } - } - - return os.WriteFile(file, msg, 0666) -} - -// pre-push: "this hook is called by git-push and can be used to -// prevent a push from taking place. The hook is called with two -// parameters which provide the name and location of the destination -// remote, if a named remote is not being used both values will be the -// same. -// -// Information about what is to be pushed is provided on the hook's -// standard input with lines of the form: -// -// SP SP SP LF -// -// More: https://git-scm.com/docs/githooks#_pre_push -func prePush(args []string) error { - remoteName, remoteLoc := args[0], args[1] - _ = remoteName - - pushes, err := readPushes() - if err != nil { - return fmt.Errorf("reading pushes: %w", err) - } - - switch remoteLoc { - case "git@github.com:tailscale/tailscale", "git@github.com:tailscale/tailscale.git", - "https://github.com/tailscale/tailscale", "https://github.com/tailscale/tailscale.git": - for _, p := range pushes { - if p.isDoNotMergeRef() { - continue - } - if err := checkCommit(p.localSHA); err != nil { - return fmt.Errorf("not allowing push of %v to %v: %v", p.localSHA, p.remoteRef, err) - } - } - } - - return nil -} - -//go:embed HOOK_VERSION -var compiledHookVersion string - -// post-checkout: "This hook is invoked when a git-checkout[1] or -// git-switch[1] is run after having updated the worktree. The hook is -// given three parameters: the ref of the previous HEAD, the ref of -// the new HEAD (which may or may not have changed), and a flag -// indicating whether the checkout was a branch checkout (changing -// branches, flag=1) or a file checkout (retrieving a file from the -// index, flag=0). -// -// More: https://git-scm.com/docs/githooks#_post_checkout -func postCheckout(_ []string) error { - compiled, err := strconv.Atoi(strings.TrimSpace(compiledHookVersion)) - if err != nil { - return fmt.Errorf("couldn't parse compiled-in hook version: %v", err) - } - - bs, err := os.ReadFile("misc/git_hook/HOOK_VERSION") - if errors.Is(err, os.ErrNotExist) { - // Probably checked out a commit that predates the existence - // of HOOK_VERSION, don't complain. - return nil - } - actual, err := strconv.Atoi(strings.TrimSpace(string(bs))) - if err != nil { - return fmt.Errorf("couldn't parse misc/git_hook/HOOK_VERSION: %v", err) - } - - if actual > compiled { - return nonFatalErr{fmt.Errorf("a newer git hook script is available, please run `./tool/go run ./misc/install-git-hooks.go`")} - } - return nil -} - -func checkCommit(sha string) error { - // Allow people to delete remote refs. - if sha == zeroRef { - return nil - } - // Check that go.mod doesn't contain replacements to directories. - goMod, err := exec.Command("git", "show", sha+":go.mod").Output() - if err != nil { - return err - } - mf, err := modfile.Parse("go.mod", goMod, nil) - if err != nil { - return fmt.Errorf("failed to parse its go.mod: %v", err) - } - for _, r := range mf.Replace { - if modfile.IsDirectoryPath(r.New.Path) { - return fmt.Errorf("go.mod contains replace from %v => %v", r.Old.Path, r.New.Path) - } - } - - return nil -} - -const zeroRef = "0000000000000000000000000000000000000000" - -type push struct { - localRef string // "refs/heads/bradfitz/githooks" - localSHA string // what's being pushed - remoteRef string // "refs/heads/bradfitz/githooks", "refs/heads/main" - remoteSHA string // old value being replaced, or zeroRef if it doesn't exist -} - -func (p *push) isDoNotMergeRef() bool { - return strings.HasSuffix(p.remoteRef, "/DO-NOT-MERGE") -} - -func readPushes() (pushes []push, err error) { - bs := bufio.NewScanner(os.Stdin) - for bs.Scan() { - f := strings.Fields(bs.Text()) - if len(f) != 4 { - return nil, fmt.Errorf("unexpected push line %q", bs.Text()) - } - pushes = append(pushes, push{f[0], f[1], f[2], f[3]}) - } - if err := bs.Err(); err != nil { - return nil, err - } - return pushes, nil -} - -// nonFatalErr is an error wrapper type to indicate that main() should -// not exit fatally. -type nonFatalErr struct { - error -} - -var gitCutLine = []byte("# ------------------------ >8 ------------------------") - -// filterCutLine searches for a git cutline (see above) and filters it and any -// following lines from the given message. This is typically produced in a -// commit message file by `git commit -v`. -func filterCutLine(msg []byte) []byte { - if before, _, ok := bytes.Cut(msg, gitCutLine); ok { - return before - } - return msg -} diff --git a/misc/git_hook/githook/HOOK_VERSION b/misc/git_hook/githook/HOOK_VERSION new file mode 100644 index 000000000..00750edc0 --- /dev/null +++ b/misc/git_hook/githook/HOOK_VERSION @@ -0,0 +1 @@ +3 diff --git a/misc/git_hook/githook/commit-msg.go b/misc/git_hook/githook/commit-msg.go new file mode 100644 index 000000000..e75bc79f3 --- /dev/null +++ b/misc/git_hook/githook/commit-msg.go @@ -0,0 +1,64 @@ +// Copyright (c) Tailscale Inc & contributors +// SPDX-License-Identifier: BSD-3-Clause + +package githook + +import ( + "bytes" + "crypto/rand" + "errors" + "fmt" + "io" + "os" + "os/exec" +) + +// AddChangeID strips comments from the commit message at args[0] and +// prepends a random Change-Id trailer. +// +// Intended as a commit-msg hook. +// https://git-scm.com/docs/githooks#_commit_msg +func AddChangeID(args []string) error { + if len(args) != 1 { + return errors.New("usage: commit-msg message.txt") + } + file := args[0] + msg, err := os.ReadFile(file) + if err != nil { + return err + } + msg = filterCutLine(msg) + + var id [20]byte + if _, err := io.ReadFull(rand.Reader, id[:]); err != nil { + return fmt.Errorf("could not generate Change-Id: %v", err) + } + cmdLines := [][]string{ + {"git", "stripspace", "--strip-comments"}, + {"git", "interpret-trailers", "--no-divider", "--where=start", "--if-exists", "doNothing", "--trailer", fmt.Sprintf("Change-Id: I%x", id)}, + } + for _, cmdLine := range cmdLines { + if len(msg) == 0 { + // Don't let commands turn an empty message into a non-empty one (issue 2205). + break + } + cmd := exec.Command(cmdLine[0], cmdLine[1:]...) + cmd.Stdin = bytes.NewReader(msg) + msg, err = cmd.CombinedOutput() + if err != nil { + return fmt.Errorf("failed to run %v: %w\n%s", cmd, err, msg) + } + } + return os.WriteFile(file, msg, 0666) +} + +var gitCutLine = []byte("# ------------------------ >8 ------------------------") + +// filterCutLine strips a `git commit -v`-style cutline and everything +// after it from msg. +func filterCutLine(msg []byte) []byte { + if before, _, ok := bytes.Cut(msg, gitCutLine); ok { + return before + } + return msg +} diff --git a/misc/git_hook/githook/githook.go b/misc/git_hook/githook/githook.go new file mode 100644 index 000000000..aa44051fc --- /dev/null +++ b/misc/git_hook/githook/githook.go @@ -0,0 +1,52 @@ +// Copyright (c) Tailscale Inc & contributors +// SPDX-License-Identifier: BSD-3-Clause + +// Package githook contains the shared implementation of Tailscale's git +// hooks. The tailscale/tailscale and tailscale/corp repositories each have +// a thin main package that dispatches to this one, calling individual +// hook functions with per-repo arguments as needed. +package githook + +import ( + _ "embed" + "errors" + "fmt" + "os" + "os/exec" + "path/filepath" +) + +// Launcher is the canonical bytes of launcher.sh. Downstream repos +// (e.g. tailscale/corp) rely on these bytes at install time. +// +//go:embed launcher.sh +var Launcher []byte + +// HookVersion is the shared version of this package and launcher.sh. +// Bump HOOK_VERSION on any change under this package. +// +//go:embed HOOK_VERSION +var HookVersion string + +// RunLocalHook runs an optional user-supplied hook at +// .git/hooks/.local, if present. +func RunLocalHook(hookName string, args []string) error { + cmdPath, err := os.Executable() + if err != nil { + return err + } + localHookPath := filepath.Join(filepath.Dir(cmdPath), hookName+".local") + if _, err := os.Stat(localHookPath); errors.Is(err, os.ErrNotExist) { + return nil + } else if err != nil { + return fmt.Errorf("checking for local hook: %w", err) + } + + cmd := exec.Command(localHookPath, args...) + cmd.Stdout = os.Stdout + cmd.Stderr = os.Stderr + if err := cmd.Run(); err != nil { + return fmt.Errorf("running local hook %q: %w", localHookPath, err) + } + return nil +} diff --git a/misc/git_hook/githook/install.go b/misc/git_hook/githook/install.go new file mode 100644 index 000000000..3c08daf8d --- /dev/null +++ b/misc/git_hook/githook/install.go @@ -0,0 +1,177 @@ +// Copyright (c) Tailscale Inc & contributors +// SPDX-License-Identifier: BSD-3-Clause + +package githook + +import ( + "fmt" + "io" + "os" + "os/exec" + "path/filepath" + "regexp" + "strings" +) + +// Install writes the launcher to .git/hooks/ts-git-hook and runs it +// once with "version", bootstrapping the binary build and per-hook +// wrappers. Called from each repo's misc/install-git-hooks.go. +func Install() error { + hookDir, err := findHookDir() + if err != nil { + return err + } + target := filepath.Join(hookDir, "ts-git-hook") + if err := writeLauncher(target); err != nil { + return err + } + + // The launcher execs the binary with our arg at the end; we pass + // "version" only to trigger the rebuild-if-stale path, and discard + // its stdout so the version string doesn't leak to the caller. + cmd := exec.Command(target, "version") + cmd.Stdout = io.Discard + cmd.Stderr = os.Stderr + if err := cmd.Run(); err != nil { + return fmt.Errorf("initial hook setup failed: %v", err) + } + return nil +} + +// WriteHooks writes the launcher to .git/hooks/ts-git-hook and a wrapper +// for each name in hooks to .git/hooks/. Stale wrappers from +// prior versions (ours, but no longer in hooks) are removed. If a path +// we are about to write exists and is not one of our wrappers, +// WriteHooks aborts with an error rather than clobber the user's hook. +// Called by the binary's "install" handler (after a rebuild) and by +// Install (initial setup). +func WriteHooks(hooks []string) error { + hookDir, err := findHookDir() + if err != nil { + return err + } + if err := writeLauncher(filepath.Join(hookDir, "ts-git-hook")); err != nil { + return err + } + want := make(map[string]bool, len(hooks)) + for _, h := range hooks { + want[h] = true + } + entries, err := os.ReadDir(hookDir) + if err != nil { + return fmt.Errorf("reading hooks dir: %v", err) + } + for _, e := range entries { + if e.IsDir() { + continue + } + name := e.Name() + path := filepath.Join(hookDir, name) + mine, err := isOurWrapper(path) + if err != nil { + return fmt.Errorf("inspecting %s: %v", path, err) + } + switch { + case want[name] && !mine: + return fmt.Errorf("%s exists and is not a ts-git-hook wrapper; "+ + "move your hook to %s.local (it will be chained after the wrapper) or delete it, then re-run: ./tool/go run ./misc/install-git-hooks.go", + path, name) + case !want[name] && mine: + // Stale wrapper from a prior version (e.g. a hook we used + // to install but no longer do). + if err := os.Remove(path); err != nil { + return fmt.Errorf("removing stale wrapper %s: %v", name, err) + } + } + } + for _, h := range hooks { + content := fmt.Sprintf(wrapperScript, h) + if err := os.WriteFile(filepath.Join(hookDir, h), []byte(content), 0755); err != nil { + return fmt.Errorf("writing wrapper for %s: %v", h, err) + } + } + return nil +} + +// isOurWrapper reports whether path is a hook wrapper written by us +// (in any historical format). Files we will never own (the launcher +// itself, user-chained .local hooks, git's .sample examples) return +// false unconditionally and are not read. An I/O error other than +// "not found" is returned to the caller; a missing file is not an +// error. +func isOurWrapper(path string) (bool, error) { + name := filepath.Base(path) + if name == "ts-git-hook" || + strings.HasSuffix(name, ".local") || + strings.HasSuffix(name, ".sample") { + return false, nil + } + b, err := os.ReadFile(path) + if os.IsNotExist(err) { + return false, nil + } + if err != nil { + return false, err + } + return wrapperRE.Match(b), nil +} + +// writeLauncher writes the embedded launcher to target via atomic rename, +// so a currently-running launcher keeps reading its old inode. +func writeLauncher(target string) error { + dir, name := filepath.Split(target) + f, err := os.CreateTemp(dir, name+".*") + if err != nil { + return fmt.Errorf("creating temp launcher: %v", err) + } + tmp := f.Name() + if _, err := f.Write(Launcher); err != nil { + f.Close() + os.Remove(tmp) + return fmt.Errorf("writing temp launcher: %v", err) + } + if err := f.Close(); err != nil { + os.Remove(tmp) + return err + } + if err := os.Chmod(tmp, 0755); err != nil { + os.Remove(tmp) + return err + } + if err := os.Rename(tmp, target); err != nil { + os.Remove(tmp) + return fmt.Errorf("installing launcher: %v", err) + } + return nil +} + +func findHookDir() (string, error) { + out, err := exec.Command("git", "rev-parse", "--git-path", "hooks").CombinedOutput() + if err != nil { + return "", fmt.Errorf("finding hooks dir: %v, %s", err, out) + } + hookDir, err := filepath.Abs(strings.TrimSpace(string(out))) + if err != nil { + return "", err + } + fi, err := os.Stat(hookDir) + if err != nil { + return "", fmt.Errorf("checking hooks dir: %v", err) + } + if !fi.IsDir() { + return "", fmt.Errorf("%s is not a directory", hookDir) + } + return hookDir, nil +} + +const wrapperScript = `#!/usr/bin/env bash +exec "$(dirname "${BASH_SOURCE[0]}")/ts-git-hook" %s "$@" +` + +// wrapperRE matches every historical shape of wrapperScript: a tiny +// bash script that execs a sibling ts-git-hook with a single hook-name +// argument. The inner quoting of ${BASH_SOURCE[0]} changed between +// versions, hence the "?s. +var wrapperRE = regexp.MustCompile( + `\A#!/usr/bin/env bash\nexec "\$\(dirname "?\$\{BASH_SOURCE\[0\]\}"?\)/ts-git-hook" [\w-]+ "\$@"\n?\z`, +) diff --git a/misc/git_hook/githook/launcher.sh b/misc/git_hook/githook/launcher.sh new file mode 100755 index 000000000..eddab585e --- /dev/null +++ b/misc/git_hook/githook/launcher.sh @@ -0,0 +1,51 @@ +#!/usr/bin/env bash +# ts-git-hook launcher (installed at .git/hooks/ts-git-hook). +# +# Written by misc/install-git-hooks.go from the canonical copy embedded +# in tailscale.com/misc/git_hook/githook. On every invocation it: +# +# 1. Compares the binary's reported version against the shared +# githook HOOK_VERSION (resolved via `go list -m tailscale.com`) +# plus the repo-local HOOK_VERSION. +# 2. If stale or missing: rebuilds ts-git-hook-bin and runs +# `ts-git-hook-bin install`. +# 3. Execs the binary with the hook's args. +set -euo pipefail + +REPO_ROOT="$(git rev-parse --show-toplevel 2>/dev/null)" || { + echo "git-hook: not in a git repo" >&2 + exit 1 +} + +HOOK_DIR="$(git -C "$REPO_ROOT" rev-parse --git-path hooks)" +case "$HOOK_DIR" in +/*) ;; +*) HOOK_DIR="$REPO_ROOT/$HOOK_DIR" ;; +esac + +# Windows (Git for Windows / MSYS2) needs .exe suffixes. +EXE="" +case "$(uname -s)" in MINGW* | MSYS* | CYGWIN*) EXE=".exe" ;; esac + +BINARY="$HOOK_DIR/ts-git-hook-bin$EXE" + +GO="$REPO_ROOT/tool/go$EXE" +if [ ! -x "$GO" ]; then GO=go; fi + +OSS_DIR="$(cd "$REPO_ROOT" && GOWORK=off "$GO" list -m -f '{{.Dir}}' tailscale.com 2>/dev/null || true)" +SHARED_VER="$(cat "$OSS_DIR/misc/git_hook/githook/HOOK_VERSION" 2>/dev/null || echo 0)" +LOCAL_VER="$(cat "$REPO_ROOT/misc/git_hook/HOOK_VERSION" 2>/dev/null || echo 0)" +WANT="$SHARED_VER:$LOCAL_VER" +HAVE="$("$BINARY" version 2>/dev/null || echo none)" + +if [ "$WANT" != "$HAVE" ]; then + echo "git-hook: rebuilding ts-git-hook-bin..." >&2 + (cd "$REPO_ROOT" && GOWORK=off "$GO" build -o "$BINARY" ./misc/git_hook) || { + echo "git-hook: rebuild failed, run: ./tool/go run ./misc/install-git-hooks.go" >&2 + exit 1 + } + "$BINARY" install +fi + +exec "$BINARY" "$@" + diff --git a/misc/git_hook/githook/pre-commit.go b/misc/git_hook/githook/pre-commit.go new file mode 100644 index 000000000..30e4f6a9e --- /dev/null +++ b/misc/git_hook/githook/pre-commit.go @@ -0,0 +1,62 @@ +// Copyright (c) Tailscale Inc & contributors +// SPDX-License-Identifier: BSD-3-Clause + +package githook + +import ( + "bytes" + "errors" + "fmt" + "os/exec" + + "github.com/fatih/color" + "github.com/sourcegraph/go-diff/diff" +) + +var preCommitForbiddenPatterns = [][]byte{ + // Concatenation avoids tripping the check on this file. + []byte("NOCOM" + "MIT"), + []byte("DO NOT " + "SUBMIT"), +} + +// CheckForbiddenMarkers scans the staged diff for forbidden markers +// and returns an error if any are found. +// +// Intended as a pre-commit hook. +// https://git-scm.com/docs/githooks#_pre_commit +func CheckForbiddenMarkers() error { + diffOut, err := exec.Command("git", "diff", "--cached").Output() + if err != nil { + return fmt.Errorf("could not get git diff: %w", err) + } + + diffs, err := diff.ParseMultiFileDiff(diffOut) + if err != nil { + return fmt.Errorf("could not parse diff: %w", err) + } + + foundForbidden := false + for _, d := range diffs { + for _, hunk := range d.Hunks { + lines := bytes.Split(hunk.Body, []byte{'\n'}) + for i, line := range lines { + if len(line) == 0 || line[0] != '+' { + continue + } + for _, forbidden := range preCommitForbiddenPatterns { + if bytes.Contains(line, forbidden) { + if !foundForbidden { + color.New(color.Bold, color.FgRed, color.Underline).Printf("%s found:\n", forbidden) + } + fmt.Printf("%s:%d: %s\n", d.NewName[2:], int(hunk.NewStartLine)+i, line[1:]) + foundForbidden = true + } + } + } + } + } + if foundForbidden { + return errors.New("found forbidden string") + } + return nil +} diff --git a/misc/git_hook/githook/pre-push.go b/misc/git_hook/githook/pre-push.go new file mode 100644 index 000000000..9d5624523 --- /dev/null +++ b/misc/git_hook/githook/pre-push.go @@ -0,0 +1,112 @@ +// Copyright (c) Tailscale Inc & contributors +// SPDX-License-Identifier: BSD-3-Clause + +package githook + +import ( + "bufio" + "fmt" + "os" + "os/exec" + "strings" + + "golang.org/x/mod/modfile" +) + +// CheckGoModReplaces reads pushes from stdin and, for pushes to a +// remote URL in watchedRemotes, rejects any commit whose go.mod has a +// directory-path replace that is not in allowedReplaceDirs. args is +// the pre-push hook's argv (remoteName, remoteLoc). +// +// Intended as a pre-push hook. +// https://git-scm.com/docs/githooks#_pre_push +func CheckGoModReplaces(args []string, watchedRemotes, allowedReplaceDirs []string) error { + if len(args) < 2 { + return fmt.Errorf("pre-push: expected 2 args, got %d", len(args)) + } + remoteLoc := args[1] + + watched := false + for _, r := range watchedRemotes { + if r == remoteLoc { + watched = true + break + } + } + if !watched { + return nil + } + + pushes, err := readPushes() + if err != nil { + return fmt.Errorf("reading pushes: %w", err) + } + for _, p := range pushes { + if p.isDoNotMergeRef() { + continue + } + if err := checkCommit(p.localSHA, allowedReplaceDirs); err != nil { + return fmt.Errorf("not allowing push of %v to %v: %v", p.localSHA, p.remoteRef, err) + } + } + return nil +} + +func checkCommit(sha string, allowedReplaceDirs []string) error { + if sha == zeroRef { + // Allow ref deletions. + return nil + } + goMod, err := exec.Command("git", "show", sha+":go.mod").Output() + if err != nil { + return err + } + mf, err := modfile.Parse("go.mod", goMod, nil) + if err != nil { + return fmt.Errorf("failed to parse its go.mod: %v", err) + } + for _, r := range mf.Replace { + if !modfile.IsDirectoryPath(r.New.Path) { + continue + } + allowed := false + for _, a := range allowedReplaceDirs { + if a == r.New.Path { + allowed = true + break + } + } + if !allowed { + return fmt.Errorf("go.mod contains replace from %v => %v", r.Old.Path, r.New.Path) + } + } + return nil +} + +const zeroRef = "0000000000000000000000000000000000000000" + +type push struct { + localRef string + localSHA string + remoteRef string + remoteSHA string +} + +func (p *push) isDoNotMergeRef() bool { + return strings.HasSuffix(p.remoteRef, "/DO-NOT-MERGE") +} + +func readPushes() (pushes []push, err error) { + bs := bufio.NewScanner(os.Stdin) + for bs.Scan() { + f := strings.Fields(bs.Text()) + if len(f) != 4 { + return nil, fmt.Errorf("unexpected push line %q", bs.Text()) + } + pushes = append(pushes, push{f[0], f[1], f[2], f[3]}) + } + if err := bs.Err(); err != nil { + return nil, err + } + return pushes, nil +} diff --git a/misc/install-git-hooks.go b/misc/install-git-hooks.go index c66ecb8f8..813a45601 100644 --- a/misc/install-git-hooks.go +++ b/misc/install-git-hooks.go @@ -3,80 +3,19 @@ //go:build ignore -// The install-git-hooks program installs git hooks. -// -// It installs a Go binary at .git/hooks/ts-git-hook and a pre-hook -// forwarding shell wrapper to .git/hooks/NAME. +// The install-git-hooks program installs git hooks by delegating to +// githook.Install. See that function's doc for what it does. package main import ( - "fmt" "log" - "os" - "os/exec" - "path/filepath" - "runtime" - "strings" + + "tailscale.com/misc/git_hook/githook" ) -var hooks = []string{ - "pre-push", - "pre-commit", - "commit-msg", - "post-checkout", -} - -func fatalf(format string, a ...any) { - log.SetFlags(0) - log.Fatalf("install-git-hooks: "+format, a...) -} - func main() { - out, err := exec.Command("git", "rev-parse", "--git-common-dir").CombinedOutput() - if err != nil { - fatalf("finding git dir: %v, %s", err, out) - } - gitDir := strings.TrimSpace(string(out)) - - hookDir := filepath.Join(gitDir, "hooks") - if fi, err := os.Stat(hookDir); err != nil { - fatalf("checking hooks dir: %v", err) - } else if !fi.IsDir() { - fatalf("%s is not a directory", hookDir) - } - - buildOut, err := exec.Command(goBin(), "build", - "-o", filepath.Join(hookDir, "ts-git-hook"+exe()), - "./misc/git_hook").CombinedOutput() - if err != nil { - log.Fatalf("go build git-hook: %v, %s", err, buildOut) - } - - for _, hook := range hooks { - content := fmt.Sprintf(hookScript, hook) - file := filepath.Join(hookDir, hook) - // Install the hook. If it already exists, overwrite it, in case there's - // been changes. - if err := os.WriteFile(file, []byte(content), 0755); err != nil { - fatalf("%v", err) - } + log.SetFlags(0) + if err := githook.Install(); err != nil { + log.Fatalf("install-git-hooks: %v", err) } } - -const hookScript = `#!/usr/bin/env bash -exec "$(dirname ${BASH_SOURCE[0]})/ts-git-hook" %s "$@" -` - -func goBin() string { - if p, err := exec.LookPath("go"); err == nil { - return p - } - return "go" -} - -func exe() string { - if runtime.GOOS == "windows" { - return ".exe" - } - return "" -} diff --git a/net/dns/direct.go b/net/dns/direct.go index ec2e42e75..f6f2fd601 100644 --- a/net/dns/direct.go +++ b/net/dns/direct.go @@ -442,7 +442,9 @@ func (m *directManager) runFileWatcher() { if !ok { return } - if err := watchFile(m.ctx, "/etc/", resolvConf, m.checkForFileTrample); err != nil { + dir := m.fs.ActualPath(filepath.Dir(resolvConf)) + file := m.fs.ActualPath(resolvConf) + if err := watchFile(m.ctx, dir, file, m.checkForFileTrample); err != nil { // This is all best effort for now, so surface warnings to users. m.logf("dns: inotify: %s", err) } @@ -597,6 +599,19 @@ type wholeFileFS interface { ReadFile(name string) ([]byte, error) Remove(name string) error Rename(oldName, newName string) error + // ActualPath returns the real filesystem path for the given absolute + // logical path. All other methods in this interface accept logical + // paths (like "/etc/resolv.conf") and translate them internally; + // ActualPath exposes that same translation for callers that need + // the real path for use outside the interface (e.g. setting up an + // inotify watch on the correct directory). + // + // For directFS with an empty prefix (production), the input is + // returned unchanged ("/etc" → "/etc"). For directFS with a test + // prefix like "/tmp/test123", the prefix is joined + // ("/etc" → "/tmp/test123/etc"). For wslFS the input is returned + // unchanged, since paths are passed through to wsl.exe as-is. + ActualPath(name string) string Stat(name string) (isRegular bool, err error) Truncate(name string) error WriteFile(name string, contents []byte, perm os.FileMode) error @@ -613,6 +628,8 @@ type directFS struct { func (fs directFS) path(name string) string { return filepath.Join(fs.prefix, name) } +func (fs directFS) ActualPath(name string) string { return fs.path(name) } + func (fs directFS) Stat(name string) (isRegular bool, err error) { fi, err := os.Stat(fs.path(name)) if err != nil { diff --git a/net/dns/direct_linux_test.go b/net/dns/direct_linux_test.go index 8199b41f3..c053db178 100644 --- a/net/dns/direct_linux_test.go +++ b/net/dns/direct_linux_test.go @@ -7,21 +7,19 @@ package dns import ( "context" - "fmt" "net/netip" "os" "path/filepath" "testing" "testing/synctest" - - "github.com/illarion/gonotify/v3" + "time" "tailscale.com/util/dnsname" "tailscale.com/util/eventbus/eventbustest" ) func TestDNSTrampleRecovery(t *testing.T) { - HookWatchFile.Set(watchFile) + t.Cleanup(HookWatchFile.SetForTest(watchFile)) synctest.Test(t, func(t *testing.T) { tmp := t.TempDir() if err := os.MkdirAll(filepath.Join(tmp, "etc"), 0700); err != nil { @@ -77,33 +75,20 @@ search ts.net ts-dns.test }) } -// watchFile is generally copied from linuxtrample, but cancels the context -// after the first call to cb() after the first trample to end the test. +// watchFile is a test implementation of the file watcher that uses a timer +// instead of inotify. Real inotify (gonotify.NewDirWatcher) creates goroutines +// that block on real syscalls, which don't work inside synctest's fake-time +// bubble. Instead, we use a one-shot timer that synctest.Wait() will advance, +// triggering a callback to check for file trampling. func watchFile(ctx context.Context, dir, filename string, cb func()) error { - ctx, cancel := context.WithCancel(ctx) - defer cancel() - - const events = gonotify.IN_ATTRIB | - gonotify.IN_CLOSE_WRITE | - gonotify.IN_CREATE | - gonotify.IN_DELETE | - gonotify.IN_MODIFY | - gonotify.IN_MOVE - - watcher, err := gonotify.NewDirWatcher(ctx, events, dir) - if err != nil { - return fmt.Errorf("NewDirWatcher: %w", err) - } - - for { - select { - case event := <-watcher.C: - if event.Name == filename { - cb() - cancel() - } - case <-ctx.Done(): - return ctx.Err() - } + timer := time.NewTimer(time.Millisecond) + defer timer.Stop() + select { + case <-ctx.Done(): + return ctx.Err() + case <-timer.C: + cb() } + <-ctx.Done() + return ctx.Err() } diff --git a/net/dns/dns_clone.go b/net/dns/dns_clone.go index 32765ceb4..724e36dac 100644 --- a/net/dns/dns_clone.go +++ b/net/dns/dns_clone.go @@ -33,8 +33,19 @@ func (src *Config) Clone() *Config { } if dst.Routes != nil { dst.Routes = map[dnsname.FQDN][]*dnstype.Resolver{} - for k := range src.Routes { - dst.Routes[k] = append([]*dnstype.Resolver{}, src.Routes[k]...) + for k, sv := range src.Routes { + if sv == nil { + dst.Routes[k] = nil + continue + } + dst.Routes[k] = make([]*dnstype.Resolver, len(sv)) + for i := range sv { + if sv[i] == nil { + dst.Routes[k][i] = nil + } else { + dst.Routes[k][i] = sv[i].Clone() + } + } } } dst.SearchDomains = append(src.SearchDomains[:0:0], src.SearchDomains...) diff --git a/net/dns/manager_darwin.go b/net/dns/manager_darwin.go index bb590aa4e..90686b246 100644 --- a/net/dns/manager_darwin.go +++ b/net/dns/manager_darwin.go @@ -5,7 +5,10 @@ package dns import ( "bytes" + "fmt" + "io/fs" "os" + "strings" "go4.org/mem" "tailscale.com/control/controlknobs" @@ -22,15 +25,22 @@ import ( // // The health tracker, bus and the knobs may be nil and are ignored on this platform. func NewOSConfigurator(logf logger.Logf, _ *health.Tracker, _ *eventbus.Bus, _ policyclient.Client, _ *controlknobs.Knobs, ifName string) (OSConfigurator, error) { - return &darwinConfigurator{logf: logf, ifName: ifName}, nil + return &darwinConfigurator{ + logf: logf, + ifName: ifName, + resolverDir: "/etc/resolver", + resolvConfPath: "/etc/resolv.conf", + }, nil } // darwinConfigurator is the tailscaled-on-macOS DNS OS configurator that // maintains the Split DNS nameserver entries pointing MagicDNS DNS suffixes // to 100.100.100.100 using the macOS /etc/resolver/$SUFFIX files. type darwinConfigurator struct { - logf logger.Logf - ifName string + logf logger.Logf + ifName string + resolverDir string // default "/etc/resolver" + resolvConfPath string // default "/etc/resolv.conf" } func (c *darwinConfigurator) Close() error { @@ -51,10 +61,16 @@ func (c *darwinConfigurator) SetDNS(cfg OSConfig) error { buf.WriteString("\n") } - if err := os.MkdirAll("/etc/resolver", 0755); err != nil { + if err := os.MkdirAll(c.resolverDir, 0755); err != nil { return err } + root, err := os.OpenRoot(c.resolverDir) + if err != nil { + return err + } + defer root.Close() + var keep map[string]bool // Add a dummy file to /etc/resolver with a "search ..." directive if we have @@ -70,7 +86,7 @@ func (c *darwinConfigurator) SetDNS(cfg OSConfig) error { sbuf.WriteString(string(d.WithoutTrailingDot())) } sbuf.WriteString("\n") - if err := os.WriteFile("/etc/resolver/"+searchFile, sbuf.Bytes(), 0644); err != nil { + if err := root.WriteFile(searchFile, sbuf.Bytes(), 0644); err != nil { return err } } @@ -78,15 +94,34 @@ func (c *darwinConfigurator) SetDNS(cfg OSConfig) error { for _, d := range cfg.MatchDomains { fileBase := string(d.WithoutTrailingDot()) mak.Set(&keep, fileBase, true) - fullPath := "/etc/resolver/" + fileBase - if err := os.WriteFile(fullPath, buf.Bytes(), 0644); err != nil { + if !isValidResolverFileName(fileBase) { + c.logf("[unexpected] invalid resolver domain %q with slashes or colons", fileBase) + return fmt.Errorf("invalid resolver domain %q: must not contain slashes or colons", fileBase) + } + + if err := root.WriteFile(fileBase, buf.Bytes(), 0644); err != nil { return err } } return c.removeResolverFiles(func(domain string) bool { return !keep[domain] }) } +func isValidResolverFileName(name string) bool { + // Verify that the filename doesn't contain any characters that + // might cause issues when used as a filename; os.Root is a + // defense against path traversal, but prefer a nice error here + // if we can. These aren't valid for domain names anyway. + if strings.Contains(name, "/") || strings.Contains(name, "\\") { + return false + } + + if strings.Contains(name, ":") { + return false + } + return true +} + // GetBaseConfig returns the current OS DNS configuration, extracting it from /etc/resolv.conf. // We should really be using the SystemConfiguration framework to get this information, as this // is not a stable public API, and is provided mostly as a compatibility effort with Unix @@ -95,9 +130,9 @@ func (c *darwinConfigurator) SetDNS(cfg OSConfig) error { func (c *darwinConfigurator) GetBaseConfig() (OSConfig, error) { cfg := OSConfig{} - resolvConf, err := resolvconffile.ParseFile("/etc/resolv.conf") + resolvConf, err := resolvconffile.ParseFile(c.resolvConfPath) if err != nil { - c.logf("failed to parse /etc/resolv.conf: %v", err) + c.logf("failed to parse %s: %v", c.resolvConfPath, err) return cfg, ErrGetBaseConfigNotSupported } @@ -113,7 +148,7 @@ func (c *darwinConfigurator) GetBaseConfig() (OSConfig, error) { if len(cfg.Nameservers) == 0 { // Log a warning in case we couldn't find any nameservers in /etc/resolv.conf. - c.logf("no nameservers found in /etc/resolv.conf, DNS resolution might fail") + c.logf("no nameservers found in %s, DNS resolution might fail", c.resolvConfPath) } return cfg, nil @@ -124,13 +159,19 @@ const macResolverFileHeader = "# Added by tailscaled\n" // removeResolverFiles deletes all files in /etc/resolver for which the shouldDelete // func returns true. func (c *darwinConfigurator) removeResolverFiles(shouldDelete func(domain string) bool) error { - dents, err := os.ReadDir("/etc/resolver") + root, err := os.OpenRoot(c.resolverDir) if os.IsNotExist(err) { return nil } if err != nil { return err } + defer root.Close() + + dents, err := fs.ReadDir(root.FS(), ".") + if err != nil { + return err + } for _, de := range dents { if !de.Type().IsRegular() { continue @@ -139,8 +180,7 @@ func (c *darwinConfigurator) removeResolverFiles(shouldDelete func(domain string if !shouldDelete(name) { continue } - fullPath := "/etc/resolver/" + name - contents, err := os.ReadFile(fullPath) + contents, err := root.ReadFile(name) if err != nil { if os.IsNotExist(err) { // race? continue @@ -150,7 +190,7 @@ func (c *darwinConfigurator) removeResolverFiles(shouldDelete func(domain string if !mem.HasPrefix(mem.B(contents), mem.S(macResolverFileHeader)) { continue } - if err := os.Remove(fullPath); err != nil { + if err := root.Remove(name); err != nil { return err } } diff --git a/net/dns/manager_darwin_test.go b/net/dns/manager_darwin_test.go new file mode 100644 index 000000000..8596f9575 --- /dev/null +++ b/net/dns/manager_darwin_test.go @@ -0,0 +1,182 @@ +// Copyright (c) Tailscale Inc & contributors +// SPDX-License-Identifier: BSD-3-Clause + +package dns + +import ( + "errors" + "maps" + "net/netip" + "os" + "path/filepath" + "slices" + "testing" + + "tailscale.com/types/logger" + "tailscale.com/util/dnsname" +) + +func newTestConfigurator(t *testing.T) *darwinConfigurator { + t.Helper() + dir := t.TempDir() + + resolvConf := filepath.Join(dir, "resolv.conf") + if err := os.WriteFile(resolvConf, []byte("nameserver 8.8.8.8\n"), 0644); err != nil { + t.Fatal(err) + } + + resolverDir := filepath.Join(dir, "resolvers") + if err := os.Mkdir(resolverDir, 0755); err != nil { + t.Fatal(err) + } + + return &darwinConfigurator{ + logf: logger.Discard, + ifName: "utun99", + resolverDir: resolverDir, + resolvConfPath: resolvConf, + } +} + +func TestSetDNS(t *testing.T) { + c := newTestConfigurator(t) + + tests := []struct { + name string + cfg OSConfig + fileContents map[string]string // path -> expected file contents + }{ + { + name: "basic", + cfg: OSConfig{ + Nameservers: []netip.Addr{netip.MustParseAddr("100.100.100.100")}, + MatchDomains: []dnsname.FQDN{"example.com.", "ts.net."}, + }, + fileContents: map[string]string{ + "example.com": macResolverFileHeader + "nameserver 100.100.100.100\n", + "ts.net": macResolverFileHeader + "nameserver 100.100.100.100\n", + }, + }, + { + name: "SearchDomains", + cfg: OSConfig{ + Nameservers: []netip.Addr{netip.MustParseAddr("100.100.100.100")}, + SearchDomains: []dnsname.FQDN{"tail1234.ts.net."}, + MatchDomains: []dnsname.FQDN{"ts.net."}, + }, + fileContents: map[string]string{ + "ts.net": macResolverFileHeader + "nameserver 100.100.100.100\n", + "search.tailscale": macResolverFileHeader + "search tail1234.ts.net\n", + }, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + if err := c.SetDNS(tt.cfg); err != nil { + t.Fatalf("SetDNS failed: %v", err) + } + + // We want only the expected files in the resolverDir, + // and nothing else. + files, err := os.ReadDir(c.resolverDir) + if err != nil { + t.Fatalf("reading resolver directory: %v", err) + } + + var fileNames []string + for _, f := range files { + fileNames = append(fileNames, f.Name()) + } + + if len(files) != len(tt.fileContents) { + t.Fatalf("expected %d resolver files, got %d\ngot: %v\nwant: %v", + len(tt.fileContents), len(files), + fileNames, slices.Collect(maps.Keys(tt.fileContents)), + ) + } + + // Check each file's contents. + for domain, expected := range tt.fileContents { + path := filepath.Join(c.resolverDir, domain) + data, err := os.ReadFile(path) + if err != nil { + t.Fatalf("reading resolver file %q: %v", domain, err) + } + if string(data) != expected { + t.Errorf("resolver file %q contents mismatch:\ngot: %q\nwant: %q", domain, string(data), expected) + } + } + }) + } +} + +func TestSetDNS_PathTraversal(t *testing.T) { + c := newTestConfigurator(t) + + // Use a simple path traversal that tries to escape the resolver + // directory. With the previously-vulnerable code (os.WriteFile with string + // concatenation), this writes to the parent directory. With the + // fix (os.Root), this is rejected. + traversals := []dnsname.FQDN{ + "../evil.", + "../../evil.", + "sub/../../evil.", + } + + for _, traversal := range traversals { + cfg := OSConfig{ + Nameservers: []netip.Addr{netip.MustParseAddr("100.100.100.100")}, + MatchDomains: []dnsname.FQDN{traversal}, + } + + if err := c.SetDNS(cfg); err == nil { + t.Errorf("SetDNS with MatchDomain %q should have failed, but succeeded", traversal) + } + } + + // Verify no file named "evil" was written in the parent of resolverDir. + parent := filepath.Dir(c.resolverDir) + if fileExists(filepath.Join(parent, "evil")) { + t.Fatal("file 'evil' was written to parent directory via path traversal") + } +} + +func TestRemoveResolverFiles(t *testing.T) { + c := newTestConfigurator(t) + + // Write a tailscale-managed file. + managed := filepath.Join(c.resolverDir, "ts.net") + if err := os.WriteFile(managed, []byte(macResolverFileHeader+"nameserver 100.100.100.100\n"), 0644); err != nil { + t.Fatal(err) + } + + // Write a non-tailscale file that should be left alone. + unmanaged := filepath.Join(c.resolverDir, "other.conf") + if err := os.WriteFile(unmanaged, []byte("# not ours\nnameserver 8.8.8.8\n"), 0644); err != nil { + t.Fatal(err) + } + + // Remove all resolver files and verify that only the managed one is removed. + if err := c.removeResolverFiles(func(domain string) bool { return true }); err != nil { + t.Fatal(err) + } + + if fileExists(managed) { + t.Error("managed file should have been removed") + } + if !fileExists(unmanaged) { + t.Error("unmanaged file should still exist") + } +} + +func fileExists(path string) bool { + _, err := os.Stat(path) + if errors.Is(err, os.ErrNotExist) { + return false + } else if err == nil { + return true + } + + panic("unexpected error checking file existence: " + err.Error()) +} diff --git a/net/dns/manager_linux_test.go b/net/dns/manager_linux_test.go index a108a3297..c3c99307a 100644 --- a/net/dns/manager_linux_test.go +++ b/net/dns/manager_linux_test.go @@ -316,6 +316,7 @@ func (m memFS) Stat(name string) (isRegular bool, err error) { func (m memFS) Chmod(name string, mode os.FileMode) error { panic("TODO") } func (m memFS) Rename(oldName, newName string) error { panic("TODO") } func (m memFS) Remove(name string) error { panic("TODO") } +func (m memFS) ActualPath(name string) string { return name } func (m memFS) ReadFile(name string) ([]byte, error) { v, ok := m[name] if !ok { diff --git a/net/dns/manager_windows.go b/net/dns/manager_windows.go index 1e412b2d2..20102af86 100644 --- a/net/dns/manager_windows.go +++ b/net/dns/manager_windows.go @@ -8,6 +8,7 @@ import ( "bytes" "errors" "fmt" + "io/fs" "maps" "net/netip" "os" @@ -246,7 +247,13 @@ func (m *windowsManager) setHosts(hosts []*HostEntry) error { } hostsFile := filepath.Join(systemDir, "drivers", "etc", "hosts") b, err := os.ReadFile(hostsFile) - if err != nil { + switch { + case err == nil: + // Continue. + case errors.Is(err, fs.ErrNotExist): + // Non-fatal, we'll just create a new hosts file. + m.logf("failed to read the hosts file: %v", err) + default: return err } outB, err := setTailscaleHosts(m.logf, b, hosts) diff --git a/net/dns/resolver/forwarder.go b/net/dns/resolver/forwarder.go index ed7ff78f7..3f586b60f 100644 --- a/net/dns/resolver/forwarder.go +++ b/net/dns/resolver/forwarder.go @@ -13,6 +13,7 @@ import ( "errors" "fmt" "io" + "maps" "net" "net/http" "net/netip" @@ -41,8 +42,10 @@ import ( "tailscale.com/types/dnstype" "tailscale.com/types/logger" "tailscale.com/types/nettype" + "tailscale.com/types/views" "tailscale.com/util/cloudenv" "tailscale.com/util/dnsname" + "tailscale.com/util/mak" "tailscale.com/util/race" "tailscale.com/version" ) @@ -324,6 +327,19 @@ type forwarder struct { // resolver lookup. cloudHostFallback []resolverAndDelay + // schemes are the collection of registered URI scheme names that + // dynamically decide which resolver to use at the time of each query. The + // key is the scheme (the portion before the first `:`) and the value is a + // handler that determines where the current query should be sent. + // Use schemeCacheLocked() to get the current contents that can continue to + // be accessed once mu is released. This allows the (much more common) + // resolver code path to avoid repeated locking and unlocking. + // When modified, call invalidateSchemeCacheLocked() before unlocking mu. + schemes map[string]CustomSchemeHandler + // schemeCache is an immutable copy of schemes. Do not read directly, + // use schemeCacheLocked() which will regenerate its contents as needed. + schemeCache views.Map[string, CustomSchemeHandler] + // acceptDNS tracks the CorpDNS pref (--accept-dns) // This lets us skip health warnings if the forwarder receives inbound // queries directly - but we didn't configure it with any upstream resolvers. @@ -996,15 +1012,66 @@ func (f *forwarder) sendTCP(ctx context.Context, fq *forwardQuery, rr resolverAn return out, nil } +// applySchemes resolves any custom-scheme entries in rrs using the provided +// scheme handlers, returning the resulting slice. Entries whose handler returns +// an error or empty string are dropped. Entries with no registered scheme pass +// through unchanged. If schemes is nil, rrs is returned as-is. +func applySchemes(logf logger.Logf, rrs []resolverAndDelay, schemes views.Map[string, CustomSchemeHandler]) []resolverAndDelay { + if schemes.IsNil() { + return rrs + } + var result []resolverAndDelay + for i, rr := range rrs { + scheme, _, hasColon := strings.Cut(rr.name.Addr, ":") + handler, isCustom := schemes.GetOk(scheme) + if !hasColon || !isCustom { + if result != nil { + result = append(result, rr) + } + continue + } + // Avoid making a results slice in the common case where there + // are no custom scheme resolvers. + if result == nil { + result = make([]resolverAndDelay, i, len(rrs)) + copy(result, rrs) + } + newAddr, err := handler(rr.name.Addr) + if err != nil { + logf("error from custom scheme handler, skipping resolver : %v", err) + } + if err != nil || newAddr == "" { + continue + } + newResolver := *rr.name + newResolver.Addr = newAddr + result = append(result, resolverAndDelay{name: &newResolver, startDelay: rr.startDelay}) + } + // If we didn't have any custom schemes, return the original rrs. + if result == nil { + return rrs + } + return result +} + // resolvers returns the resolvers to use for domain. func (f *forwarder) resolvers(domain dnsname.FQDN) []resolverAndDelay { f.mu.Lock() routes := f.routes cloudHostFallback := f.cloudHostFallback + schemes := f.schemeCacheLocked() f.mu.Unlock() + for _, route := range routes { - if route.Suffix == "." || route.Suffix.Contains(domain) { - return route.Resolvers + if route.Suffix != "." && !route.Suffix.Contains(domain) { + continue + } + resolved := applySchemes(f.logf, route.Resolvers, schemes) + // If scheme resolution filtered out all resolvers from a non-empty + // route, fall through to the next matching route. If the resolvers + // were configured to be empty allow resolved to be empty. + if len(resolved) > 0 || len(route.Resolvers) == 0 { + return resolved } } return cloudHostFallback // or nil if no fallback @@ -1021,6 +1088,39 @@ func (f *forwarder) GetUpstreamResolvers(name dnsname.FQDN) []*dnstype.Resolver return upstreamResolvers } +// RegisterCustomScheme adds a [CustomSchemeHandler] that is called to provide +// an updated address when a [dnstype.Resolver.Addr] uses that scheme. +func (f *forwarder) RegisterCustomScheme(scheme string, h CustomSchemeHandler) error { + f.mu.Lock() + defer f.mu.Unlock() + if _, ok := f.schemes[scheme]; ok { + return fmt.Errorf("scheme %q already registered", scheme) + } + f.invalidateSchemeCacheLocked() + mak.Set(&f.schemes, scheme, h) + return nil +} + +// invalidateSchemeCacheLocked clears f.schemeCache so that it will be rebuilt +// on the next call to f.schemeCacheLocked(). +func (f *forwarder) invalidateSchemeCacheLocked() { + f.schemeCache = views.Map[string, CustomSchemeHandler]{} +} + +// schemeCacheLocked returns an immutable copy of f.schemes that can be used +// after mu is unlocked. +func (f *forwarder) schemeCacheLocked() views.Map[string, CustomSchemeHandler] { + if !f.schemeCache.IsNil() { + return f.schemeCache + } + if f.schemes == nil { + return f.schemeCache // returns a nil view + } + // Regenerate the cache + f.schemeCache = views.MapOf(maps.Clone(f.schemes)) + return f.schemeCache +} + // forwardQuery is information and state about a forwarded DNS query that's // being sent to 1 or more upstreams. // diff --git a/net/dns/resolver/forwarder_test.go b/net/dns/resolver/forwarder_test.go index faaaa9f3c..ebe4041a6 100644 --- a/net/dns/resolver/forwarder_test.go +++ b/net/dns/resolver/forwarder_test.go @@ -27,6 +27,7 @@ import ( "tailscale.com/net/tsdial" "tailscale.com/tstest" "tailscale.com/types/dnstype" + "tailscale.com/util/dnsname" "tailscale.com/util/eventbus/eventbustest" ) @@ -1385,3 +1386,142 @@ func TestForwarderHealthOnContextExpiry(t *testing.T) { }) } } + +func TestResolversCustomScheme(t *testing.T) { + t.Parallel() + tests := []struct { + name string + domain dnsname.FQDN + schemes map[string]CustomSchemeHandler + routes map[dnsname.FQDN][]*dnstype.Resolver + wantAddrs []string + }{ + { + name: "no-custom-scheme", + domain: "example.com.", + schemes: map[string]CustomSchemeHandler{}, + routes: map[dnsname.FQDN][]*dnstype.Resolver{ + "example.com.": { + {Addr: "192.168.1.1:53"}, + {Addr: "192.168.1.2:53"}, + }, + }, + wantAddrs: []string{"192.168.1.1:53", "192.168.1.2:53"}, + }, + { + name: "single-custom-scheme", + domain: "example.com.", + schemes: map[string]CustomSchemeHandler{ + "myscheme": func(string) (string, error) { return "1.2.3.4:53", nil }, + }, + routes: map[dnsname.FQDN][]*dnstype.Resolver{ + "example.com.": {{Addr: "myscheme:customKey"}}, + }, + wantAddrs: []string{"1.2.3.4:53"}, + }, + { + name: "with-other-resolvers", + domain: "example.com.", + schemes: map[string]CustomSchemeHandler{ + "myscheme": func(key string) (string, error) { return "1.2.3.4:53", nil }, + }, + routes: map[dnsname.FQDN][]*dnstype.Resolver{ + "example.com.": { + {Addr: "192.168.1.1:53"}, + {Addr: "myscheme:customKey"}, + {Addr: "192.168.1.2:53"}, + }, + }, + wantAddrs: []string{"192.168.1.1:53", "1.2.3.4:53", "192.168.1.2:53"}, + }, + { + name: "multiple-custom-schemes", + domain: "example.com.", + schemes: map[string]CustomSchemeHandler{ + "schemeOne": func(string) (string, error) { return "1.2.3.4:53", nil }, + "schemeTwo": func(string) (string, error) { return "5.6.7.8:53", nil }, + }, + routes: map[dnsname.FQDN][]*dnstype.Resolver{ + "example.com.": { + {Addr: "schemeOne:customKey"}, + {Addr: "schemeTwo:customKey"}, + }, + }, + wantAddrs: []string{"1.2.3.4:53", "5.6.7.8:53"}, + }, + { + name: "empty-string-means-no-resolver", + domain: "example.com.", + schemes: map[string]CustomSchemeHandler{ + "myscheme": func(string) (string, error) { return "", nil }, + }, + routes: map[dnsname.FQDN][]*dnstype.Resolver{ + "example.com.": { + {Addr: "192.168.1.1:53"}, + {Addr: "myscheme:customKey"}, + }, + }, + wantAddrs: []string{"192.168.1.1:53"}, + }, + { + name: "error-means-no-resolver", + domain: "example.com.", + schemes: map[string]CustomSchemeHandler{ + "myscheme": func(string) (string, error) { return "", fmt.Errorf("handler error") }, + }, + routes: map[dnsname.FQDN][]*dnstype.Resolver{ + "example.com.": { + {Addr: "192.168.1.1:53"}, + {Addr: "myscheme:customKey"}, + }, + }, + wantAddrs: []string{"192.168.1.1:53"}, + }, + { + // If the best-matching route yields no resolvers after scheme + // resolution, fall through to the next matching route. + name: "empty-scheme-result-falls-through-to-next-matching-route", + domain: "example.com.", + schemes: map[string]CustomSchemeHandler{ + "myscheme": func(string) (string, error) { return "", nil }, + }, + routes: map[dnsname.FQDN][]*dnstype.Resolver{ + "example.com.": {{Addr: "myscheme:customKey"}}, + ".": {{Addr: "192.168.1.1:53"}}, + }, + wantAddrs: []string{"192.168.1.1:53"}, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + logf := tstest.WhileTestRunningLogger(t) + bus := eventbustest.NewBus(t) + netMon, err := netmon.New(bus, logf) + if err != nil { + t.Fatal(err) + } + var dialer tsdial.Dialer + dialer.SetNetMon(netMon) + dialer.SetBus(bus) + + fwd := newForwarder(logf, netMon, nil, &dialer, health.NewTracker(bus), nil) + for scheme, handler := range tt.schemes { + if err := fwd.RegisterCustomScheme(scheme, handler); err != nil { + t.Fatal(err) + } + } + + fwd.setRoutes(tt.routes, false) + + got := fwd.resolvers(tt.domain) + var gotAddrs []string + for _, r := range got { + gotAddrs = append(gotAddrs, r.name.Addr) + } + if !slices.Equal(gotAddrs, tt.wantAddrs) { + t.Errorf("got %v, want %v", gotAddrs, tt.wantAddrs) + } + }) + } +} diff --git a/net/dns/resolver/tsdns.go b/net/dns/resolver/tsdns.go index 01f0c8a63..4b2db5705 100644 --- a/net/dns/resolver/tsdns.go +++ b/net/dns/resolver/tsdns.go @@ -293,6 +293,18 @@ func (r *Resolver) SetConfig(cfg Config) error { return nil } +// CustomSchemeHandler takes a URI (retrieved from [dnstype.Resolver.Addr]) and +// returns an updated URI to use for the current query. The result is only valid +// for right now and may change over time. +type CustomSchemeHandler func(addr string) (newAddr string, err error) + +// RegisterCustomScheme adds a [CustomSchemaHandler] that is called to provide +// an updated address to the forwarder when a [dnstype.Resolver.Addr] uses that +// scheme. +func (r *Resolver) RegisterCustomScheme(scheme string, h CustomSchemeHandler) error { + return r.forwarder.RegisterCustomScheme(scheme, h) +} + // Close shuts down the resolver and ensures poll goroutines have exited. // The Resolver cannot be used again after Close is called. func (r *Resolver) Close() { diff --git a/net/dns/wsl_windows.go b/net/dns/wsl_windows.go index 1b93142f5..b0e62170b 100644 --- a/net/dns/wsl_windows.go +++ b/net/dns/wsl_windows.go @@ -148,6 +148,8 @@ type wslFS struct { distro string } +func (fs wslFS) ActualPath(name string) string { return name } + func (fs wslFS) Stat(name string) (isRegular bool, err error) { err = wslRun(fs.cmd("test", "-f", name)) if ee, _ := err.(*exec.ExitError); ee != nil { diff --git a/net/ping/ping.go b/net/ping/ping.go index de79da51c..42d381c73 100644 --- a/net/ping/ping.go +++ b/net/ping/ping.go @@ -29,8 +29,10 @@ import ( ) const ( - v4Type = "ip4:icmp" - v6Type = "ip6:icmp" + v4Type = "ip4:icmp" + v6Type = "ip6:icmp" + v4UDPType = "udp4" // unprivileged datagram-oriented ICMPv4 + v6UDPType = "udp6" // unprivileged datagram-oriented ICMPv6 ) type response struct { @@ -54,12 +56,30 @@ type ListenPacketer interface { // A new instance should be created for each concurrent set of ping requests; // this type should not be reused. type Pinger struct { + // options that must be set before the first call to Send + + // Unprivileged, when set, makes the Pinger use non-privileged + // datagram-oriented ICMP sockets ("udp4"/"udp6") opened via + // golang.org/x/net/icmp.ListenPacket instead of raw ICMP sockets + // ("ip4:icmp"/"ip6:icmp") opened via the configured ListenPacketer. + // + // Unprivileged mode is supported on macOS, iOS, and Linux (subject to + // the /proc/sys/net/ipv4/ping_group_range sysctl). When set, the + // ListenPacketer passed to New is ignored and the kernel rewrites the + // outgoing ICMP echo ID to match the socket; replies are matched by + // sequence number and echo data only. + // + // Must be set before the first call to Send. + Unprivileged bool + + Verbose bool // verbose logging + Logf logger.Logf // optional logging function; if nil, logs to the standard logger + lp ListenPacketer // closed guards against send incrementing the waitgroup concurrently with close. - closed atomic.Bool - Logf logger.Logf - Verbose bool + closed atomic.Bool + timeNow func() time.Time id uint16 // uint16 per RFC 792 wg sync.WaitGroup @@ -95,7 +115,17 @@ func (p *Pinger) mkconn(ctx context.Context, typ, addr string) (net.PacketConn, return nil, net.ErrClosed } - c, err := p.lp.ListenPacket(ctx, typ, addr) + var c net.PacketConn + var err error + if p.Unprivileged { + // icmp.ListenPacket on "udp4"/"udp6" opens a datagram-oriented + // ICMP socket that does not require elevated privileges. The + // returned *icmp.PacketConn implements net.PacketConn and, on + // Darwin/iOS, strips the IPv4 header on read via IP_STRIPHDR. + c, err = icmp.ListenPacket(typ, addr) + } else { + c, err = p.lp.ListenPacket(ctx, typ, addr) + } if err != nil { return nil, err } @@ -125,7 +155,7 @@ func (p *Pinger) getConn(ctx context.Context, typ string) (net.PacketConn, error } var addr = "0.0.0.0" - if typ == v6Type { + if typ == v6Type || typ == v6UDPType { addr = "::" } c, err := p.mkconn(ctx, typ, addr) @@ -216,9 +246,9 @@ func (p *Pinger) handleResponse(buf []byte, now time.Time, typ string) { // and IPv6. var icmpType icmp.Type switch typ { - case v4Type: + case v4Type, v4UDPType: icmpType = ipv4.ICMPTypeEchoReply - case v6Type: + case v6Type, v6UDPType: icmpType = ipv6.ICMPTypeEchoReply default: p.vlogf("handleResponse: unknown icmp.Type") @@ -243,7 +273,10 @@ func (p *Pinger) handleResponse(buf []byte, now time.Time, typ string) { } // We assume we sent this if the ID in the response is ours. - if uint16(resp.ID) != p.id { + // In unprivileged ICMP DGRAM mode the kernel rewrites the ID to match + // the socket, so the value we set on the way out is not what comes + // back; rely on sequence and data matching instead. + if !p.Unprivileged && uint16(resp.ID) != p.id { p.vlogf("handleResponse: wanted ID=%d; got %d", p.id, resp.ID) return } @@ -294,14 +327,30 @@ func (p *Pinger) Send(ctx context.Context, dest net.Addr, data []byte) (time.Dur } if ap.Is6() { icmpType = ipv6.ICMPTypeEchoRequest - conn, err = p.getConn(ctx, v6Type) + typ := v6Type + if p.Unprivileged { + typ = v6UDPType + } + conn, err = p.getConn(ctx, typ) } else { - conn, err = p.getConn(ctx, v4Type) + typ := v4Type + if p.Unprivileged { + typ = v4UDPType + } + conn, err = p.getConn(ctx, typ) } if err != nil { return 0, err } + // In unprivileged ICMP DGRAM mode (icmp.ListenPacket on "udp4"/"udp6"), + // the kernel requires a *net.UDPAddr destination for WriteTo even though + // the wire packet is ICMP. + writeDst := dest + if p.Unprivileged { + writeDst = &net.UDPAddr{IP: ap.AsSlice(), Zone: ap.Zone()} + } + m := icmp.Message{ Type: icmpType, Code: 0, @@ -324,7 +373,7 @@ func (p *Pinger) Send(ctx context.Context, dest net.Addr, data []byte) (time.Dur p.mu.Unlock() start := p.timeNow() - n, err := conn.WriteTo(b, dest) + n, err := conn.WriteTo(b, writeDst) if err != nil { return 0, err } else if n != len(b) { diff --git a/net/tsdial/tsdial.go b/net/tsdial/tsdial.go index ebbafa52b..ca08810a3 100644 --- a/net/tsdial/tsdial.go +++ b/net/tsdial/tsdial.go @@ -515,6 +515,33 @@ func (d *Dialer) UserDial(ctx context.Context, network, addr string) (net.Conn, return stdDialer.DialContext(ctx, network, ipp.String()) } +// UserDialPlan resolves addr and reports whether the dialer would +// handle it via Tailscale. If viaTailscale is false, the resolved +// address is not a Tailscale route and the caller may dial it directly. +// +// Warning: there is a TOCTOU race if addr contains a DNS name and the +// caller subsequently passes the same DNS name to [Dialer.UserDial], as DNS +// may resolve differently the second time. Callers who want to only +// dial over Tailscale should call [Dialer.UserDial] with the returned +// ipp.String() (an IP:port) rather than the original DNS name. +func (d *Dialer) UserDialPlan(ctx context.Context, network, addr string) (ipp netip.AddrPort, viaTailscale bool, err error) { + ipp, err = d.userDialResolve(ctx, network, addr) + if err != nil { + return netip.AddrPort{}, false, err + } + if d.UseNetstackForIP != nil && d.UseNetstackForIP(ipp.Addr()) { + return ipp, true, nil + } + if routes := d.routes.Load(); routes != nil { + isTailscaleRoute, _ := routes.Lookup(ipp.Addr()) + return ipp, isTailscaleRoute, nil + } + if version.IsMacGUIVariant() && tsaddr.IsTailscaleIP(ipp.Addr()) { + return ipp, true, nil + } + return ipp, false, nil +} + // dialPeerAPI connects to a Tailscale peer's peerapi over TCP. // // network must a "tcp" type, and addr must be an ip:port. Name resolution diff --git a/net/tsdial/tsdial_test.go b/net/tsdial/tsdial_test.go new file mode 100644 index 000000000..92960acbe --- /dev/null +++ b/net/tsdial/tsdial_test.go @@ -0,0 +1,97 @@ +// Copyright (c) Tailscale Inc & contributors +// SPDX-License-Identifier: BSD-3-Clause + +package tsdial + +import ( + "context" + "net/netip" + "testing" + + "github.com/gaissmai/bart" +) + +func TestUserDialPlan(t *testing.T) { + tests := []struct { + name string + addr string + routes map[netip.Prefix]bool // nil means no routes configured + useNetstackFor func(netip.Addr) bool // nil means not set + wantVia bool + wantAddr netip.AddrPort + }{ + { + name: "loopback_no_routes", + addr: "127.0.0.1:8080", + wantVia: false, + wantAddr: netip.MustParseAddrPort("127.0.0.1:8080"), + }, + { + name: "loopback_v6_no_routes", + addr: "[::1]:8080", + wantVia: false, + wantAddr: netip.MustParseAddrPort("[::1]:8080"), + }, + { + name: "tailscale_ip_in_routes", + addr: "100.64.1.1:22", + routes: map[netip.Prefix]bool{ + netip.MustParsePrefix("100.64.0.0/10"): true, + }, + wantVia: true, + wantAddr: netip.MustParseAddrPort("100.64.1.1:22"), + }, + { + name: "non_tailscale_ip_in_local_routes", + addr: "10.0.0.5:80", + routes: map[netip.Prefix]bool{ + netip.MustParsePrefix("100.64.0.0/10"): true, + netip.MustParsePrefix("10.0.0.0/8"): false, // local route + }, + wantVia: false, + wantAddr: netip.MustParseAddrPort("10.0.0.5:80"), + }, + { + name: "loopback_with_routes_configured", + addr: "127.0.0.1:3000", + routes: map[netip.Prefix]bool{ + netip.MustParsePrefix("100.64.0.0/10"): true, + }, + wantVia: false, + wantAddr: netip.MustParseAddrPort("127.0.0.1:3000"), + }, + { + name: "netstack_for_ip", + addr: "100.100.100.100:53", + useNetstackFor: func(ip netip.Addr) bool { + return ip == netip.MustParseAddr("100.100.100.100") + }, + wantVia: true, + wantAddr: netip.MustParseAddrPort("100.100.100.100:53"), + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + d := &Dialer{} + if tt.routes != nil { + rt := &bart.Table[bool]{} + for pfx, v := range tt.routes { + rt.Insert(pfx, v) + } + d.routes.Store(rt) + } + d.UseNetstackForIP = tt.useNetstackFor + + ipp, viaTailscale, err := d.UserDialPlan(context.Background(), "tcp", tt.addr) + if err != nil { + t.Fatalf("UserDialPlan: %v", err) + } + if viaTailscale != tt.wantVia { + t.Errorf("viaTailscale = %v, want %v", viaTailscale, tt.wantVia) + } + if ipp != tt.wantAddr { + t.Errorf("addr = %v, want %v", ipp, tt.wantAddr) + } + }) + } +} diff --git a/net/tstun/wrap.go b/net/tstun/wrap.go index 1b28eb157..cd75aff5c 100644 --- a/net/tstun/wrap.go +++ b/net/tstun/wrap.go @@ -111,8 +111,7 @@ type Wrapper struct { // you might need to add an align64 field here. lastActivityAtomic mono.Time // time of last send or receive - destIPActivity syncs.AtomicValue[map[netip.Addr]func()] - discoKey syncs.AtomicValue[key.DiscoPublic] + discoKey syncs.AtomicValue[key.DiscoPublic] // timeNow, if non-nil, will be used to obtain the current time. timeNow func() time.Time @@ -340,16 +339,6 @@ func (t *Wrapper) now() time.Time { return time.Now() } -// SetDestIPActivityFuncs sets a map of funcs to run per packet -// destination (the map keys). -// -// The map ownership passes to the Wrapper. It must be non-nil. -func (t *Wrapper) SetDestIPActivityFuncs(m map[netip.Addr]func()) { - if buildfeatures.HasLazyWG { - t.destIPActivity.Store(m) - } -} - // SetDiscoKey sets the current discovery key. // // It is only used for filtering out bogus traffic when network @@ -997,13 +986,6 @@ func (t *Wrapper) Read(buffs [][]byte, sizes []int, offset int) (int, error) { for _, data := range res.data { p.Decode(data[res.dataOffset:]) - if buildfeatures.HasLazyWG { - if m := t.destIPActivity.Load(); m != nil { - if fn := m[p.Dst.Addr()]; fn != nil { - fn() - } - } - } if buildfeatures.HasCapture && captHook != nil { captHook(packet.FromLocal, t.now(), p.Buffer(), p.CaptureMeta) } @@ -1136,14 +1118,6 @@ func (t *Wrapper) injectedRead(res tunInjectedRead, outBuffs [][]byte, sizes []i pc.snat(p) invertGSOChecksum(pkt, gso) - if buildfeatures.HasLazyWG { - if m := t.destIPActivity.Load(); m != nil { - if fn := m[p.Dst.Addr()]; fn != nil { - fn() - } - } - } - if res.packet != nil { var gsoOptions tun.GSOOptions gsoOptions, err = stackGSOToTunGSO(pkt, gso) diff --git a/pkgdoc_test.go b/pkgdoc_test.go index 60b2d4856..d0f0d66bd 100644 --- a/pkgdoc_test.go +++ b/pkgdoc_test.go @@ -4,6 +4,7 @@ package tailscaleroot import ( + "go/ast" "go/parser" "go/token" "os" @@ -13,6 +14,17 @@ import ( "testing" ) +func hasIgnoreBuildTag(f *ast.File) bool { + for _, cg := range f.Comments { + for _, c := range cg.List { + if c.Text == "//go:build ignore" { + return true + } + } + } + return false +} + func TestPackageDocs(t *testing.T) { switch runtime.GOOS { case "darwin", "linux": @@ -26,8 +38,11 @@ func TestPackageDocs(t *testing.T) { if err != nil { return err } - if fi.Mode().IsDir() && path == ".git" { - return filepath.SkipDir // No documentation lives in .git + if fi.Mode().IsDir() && path != "." && strings.HasPrefix(filepath.Base(path), ".") { + return filepath.SkipDir // No documentation lives in dot directories (.git, .claude, etc) + } + if fi.Mode().IsDir() && filepath.Base(path) == "testdata" { + return filepath.SkipDir // testdata is ignored by the go tool; not real packages } if fi.Mode().IsRegular() && strings.HasSuffix(path, ".go") { if strings.HasSuffix(path, "_test.go") { @@ -48,6 +63,9 @@ func TestPackageDocs(t *testing.T) { if err != nil { t.Fatalf("failed to ParseFile %q: %v", fileName, err) } + if hasIgnoreBuildTag(f) { + continue + } dir := filepath.Dir(fileName) if _, ok := byDir[dir]; !ok { byDir[dir] = nil @@ -61,14 +79,8 @@ func TestPackageDocs(t *testing.T) { } } for dir, ff := range byDir { - switch dir { - case "tstest/integration/vms": - // This package has a couple go:build ignore commands and this test doesn't - // handle parsing those. Just allowlist that package for now (2024-07-10). - continue - } if len(ff) > 1 { - t.Logf("multiple files with package doc in %s: %q", dir, ff) + t.Errorf("multiple files with package doc in %s: %q", dir, ff) } if len(ff) == 0 { if strings.HasPrefix(dir, "gokrazy/") { diff --git a/posture/serialnumber_stub.go b/posture/serialnumber_stub.go index e040aacfb..6df9b4079 100644 --- a/posture/serialnumber_stub.go +++ b/posture/serialnumber_stub.go @@ -12,6 +12,7 @@ package posture import ( "errors" + "fmt" "tailscale.com/types/logger" "tailscale.com/util/syspolicy/policyclient" @@ -19,5 +20,5 @@ import ( // GetSerialNumber returns client machine serial number(s). func GetSerialNumbers(polc policyclient.Client, _ logger.Logf) ([]string, error) { - return nil, errors.New("not implemented") + return nil, fmt.Errorf("not implemented: %w", errors.ErrUnsupported) } diff --git a/pull-toolchain.sh b/pull-toolchain.sh index effeca669..8f34129c6 100755 --- a/pull-toolchain.sh +++ b/pull-toolchain.sh @@ -46,15 +46,15 @@ if [ "${TS_GO_NEXT:-}" != "1" ]; then fi fi -# Only update go.toolchain.version and go.toolchain.rev.sri for the main toolchain, +# Only update go.toolchain.version and flakehashes.json for the main toolchain, # skipping it if TS_GO_NEXT=1. Those two files are only used by Nix, and as of 2026-01-26 # don't yet support TS_GO_NEXT=1 with flake.nix or in our corp CI. if [ "${TS_GO_NEXT:-}" != "1" ]; then ./tool/go version 2>/dev/null | awk '{print $3}' | sed 's/^go//' > go.toolchain.version ./tool/go mod edit -go "$(cat go.toolchain.version)" - ./update-flake.sh + ./tool/go run ./tool/updateflakes fi -if [ -n "$(git diff-index --name-only HEAD -- "$go_toolchain_rev_file" go.toolchain.next.rev go.toolchain.rev.sri go.toolchain.version)" ]; then +if [ -n "$(git diff-index --name-only HEAD -- "$go_toolchain_rev_file" go.toolchain.next.rev flakehashes.json go.toolchain.version)" ]; then echo "pull-toolchain.sh: changes imported. Use git commit to make them permanent." >&2 fi diff --git a/release/dist/qnap/files/scripts/build-qpkg.sh b/release/dist/qnap/files/scripts/build-qpkg.sh index d478bfe6b..61786ead8 100755 --- a/release/dist/qnap/files/scripts/build-qpkg.sh +++ b/release/dist/qnap/files/scripts/build-qpkg.sh @@ -4,17 +4,9 @@ set -eu # Clean up folders and files created during build. function cleanup() { - rm -rf /Tailscale/$ARCH - rm -f /Tailscale/sed* - rm -f /Tailscale/qpkg.cfg - - # If this build was signed, a .qpkg.codesigning file will be created as an - # artifact of the build - # (see https://github.com/qnap-dev/qdk2/blob/93ac75c76941b90ee668557f7ce01e4b23881054/QDK_2.x/bin/qbuild#L992). - # - # go/client-release doesn't seem to need these, so we delete them here to - # avoid uploading them to pkgs.tailscale.com. - rm -f /out/*.qpkg.codesigning + rm -rf /Tailscale/$ARCH + rm -f /Tailscale/sed* + rm -f /Tailscale/qpkg.cfg } trap cleanup EXIT @@ -22,6 +14,6 @@ mkdir -p /Tailscale/$ARCH cp /tailscaled /Tailscale/$ARCH/tailscaled cp /tailscale /Tailscale/$ARCH/tailscale -sed "s/\$QPKG_VER/$TSTAG-$QNAPTAG/g" /Tailscale/qpkg.cfg.in > /Tailscale/qpkg.cfg +sed "s/\$QPKG_VER/$TSTAG-$QNAPTAG/g" /Tailscale/qpkg.cfg.in >/Tailscale/qpkg.cfg qbuild --root /Tailscale --build-arch $ARCH --build-dir /out diff --git a/release/dist/qnap/pkgs.go b/release/dist/qnap/pkgs.go index 1d69b3eaf..b505b1ac0 100644 --- a/release/dist/qnap/pkgs.go +++ b/release/dist/qnap/pkgs.go @@ -118,7 +118,16 @@ func (t *target) buildQPKG(b *dist.Build, qnapBuilds *qnapBuilds, inner *innerPk return nil, fmt.Errorf("docker run %v: %s", err, out) } - return []string{filePath, filePath + ".md5"}, nil + ret := []string{filePath, filePath + ".md5"} + // If the build was signed, a .codesigning file is produced containing + // the last 32 characters of the base64-encoded CMS signature. This is + // used by pkgserve to populate entries in the QNAP + // repository XML. + codesigning := filePath + ".codesigning" + if _, err := os.Stat(codesigning); err == nil { + ret = append(ret, codesigning) + } + return ret, nil } type qnapBuildsMemoizeKey struct{} diff --git a/scripts/installer.sh b/scripts/installer.sh index 2c15ea657..880c6a438 100755 --- a/scripts/installer.sh +++ b/scripts/installer.sh @@ -55,7 +55,7 @@ main() { VERSION_MAJOR="${VERSION_ID:-}" VERSION_MAJOR="${VERSION_MAJOR%%.*}" case "$ID" in - ubuntu|pop|neon|zorin|tuxedo) + ubuntu|pop|neon|tuxedo) OS="ubuntu" if [ "${UBUNTU_CODENAME:-}" != "" ]; then VERSION="$UBUNTU_CODENAME" @@ -266,7 +266,7 @@ main() { VERSION="leap/$VERSION_ID" PACKAGETYPE="zypper" ;; - opensuse-tumbleweed) + opensuse-tumbleweed|opensuse-slowroll) OS="opensuse" VERSION="tumbleweed" PACKAGETYPE="zypper" @@ -336,6 +336,16 @@ main() { VERSION="$VERSION_MAJOR" PACKAGETYPE="tdnf" ;; + zorin) + OS="ubuntu" + VERSION="$UBUNTU_CODENAME" + PACKAGETYPE="apt" + if [ "$VERSION_MAJOR" -lt 16 ]; then + APT_KEY_TYPE="legacy" + else + APT_KEY_TYPE="keyring" + fi + ;; steamos) echo "To install Tailscale on SteamOS, please follow the instructions here:" echo "https://github.com/tailscale-dev/deck-tailscale" diff --git a/shell.nix b/shell.nix index 17720bb3d..bd81d79f0 100644 --- a/shell.nix +++ b/shell.nix @@ -16,4 +16,4 @@ ) { src = ./.; }).shellNix -# nix-direnv cache busting line: sha256-aZkUnWyQokNw+lxut9Fak3CazmwYE4tXILhzfK4jeK4= +# nix-direnv cache busting line: sha256-Xwm+ZLNqd2k7c2GFQJ2Pf/xuFLMcXhYl5I/YVgS9V4U= diff --git a/ssh/tailssh/incubator.go b/ssh/tailssh/incubator.go index c20b18d3e..48c65e8e5 100644 --- a/ssh/tailssh/incubator.go +++ b/ssh/tailssh/incubator.go @@ -202,7 +202,7 @@ func (ss *sshSession) newIncubatorCommand(logf logger.Logf) (cmd *exec.Cmd, err incubatorArgs = append(incubatorArgs, "--is-selinux-enforcing") } - nm := ss.conn.srv.lb.NetMap() + nm := ss.conn.srv.lb.NetMapNoPeers() forceV1Behavior := nm.HasCap(tailcfg.NodeAttrSSHBehaviorV1) && !nm.HasCap(tailcfg.NodeAttrSSHBehaviorV2) if forceV1Behavior { incubatorArgs = append(incubatorArgs, "--force-v1-behavior") diff --git a/ssh/tailssh/incubator_plan9.go b/ssh/tailssh/incubator_plan9.go index 69112635f..8d0031413 100644 --- a/ssh/tailssh/incubator_plan9.go +++ b/ssh/tailssh/incubator_plan9.go @@ -92,7 +92,7 @@ func (ss *sshSession) newIncubatorCommand(logf logger.Logf) (cmd *exec.Cmd, err "--tty-name=", // updated in-place by startWithPTY } - nm := ss.conn.srv.lb.NetMap() + nm := ss.conn.srv.lb.NetMapNoPeers() forceV1Behavior := nm.HasCap(tailcfg.NodeAttrSSHBehaviorV1) && !nm.HasCap(tailcfg.NodeAttrSSHBehaviorV2) if forceV1Behavior { incubatorArgs = append(incubatorArgs, "--force-v1-behavior") diff --git a/ssh/tailssh/privs_test.go b/ssh/tailssh/privs_test.go index 7ddc9c861..bd483e2b4 100644 --- a/ssh/tailssh/privs_test.go +++ b/ssh/tailssh/privs_test.go @@ -20,6 +20,7 @@ import ( "syscall" "testing" + "tailscale.com/tstest" "tailscale.com/types/logger" ) @@ -71,9 +72,7 @@ func TestDoDropPrivileges(t *testing.T) { os.Exit(0) } - if os.Getuid() != 0 { - t.Skip("test only works when run as root") - } + tstest.RequireRoot(t) rerunSelf := func(t *testing.T, input SubprocInput) []byte { fpath := filepath.Join(t.TempDir(), "out.json") diff --git a/ssh/tailssh/tailssh.go b/ssh/tailssh/tailssh.go index c13d3d29e..e01f78eb3 100644 --- a/ssh/tailssh/tailssh.go +++ b/ssh/tailssh/tailssh.go @@ -76,6 +76,7 @@ const ( type ipnLocalBackend interface { ShouldRunSSH() bool NetMap() *netmap.NetworkMap + NetMapNoPeers() *netmap.NetworkMap WhoIs(proto string, ipp netip.AddrPort) (n tailcfg.NodeView, u tailcfg.UserProfile, ok bool) DoNoiseRequest(req *http.Request) (*http.Response, error) Dialer() *tsdial.Dialer @@ -598,7 +599,7 @@ func (c *conn) sshPolicy() (_ *tailcfg.SSHPolicy, ok bool) { if !lb.ShouldRunSSH() { return nil, false } - nm := lb.NetMap() + nm := lb.NetMapNoPeers() if nm == nil { return nil, false } @@ -717,7 +718,7 @@ func (c *conn) handleSessionPostSSHAuth(s gliderssh.Session) { } func (c *conn) expandDelegateURLLocked(actionURL string) string { - nm := c.srv.lb.NetMap() + nm := c.srv.lb.NetMapNoPeers() ci := c.info lu := c.localUser var dstNodeID string diff --git a/ssh/tailssh/tailssh_integration_test.go b/ssh/tailssh/tailssh_integration_test.go index d49ca8eef..7b70a6d51 100644 --- a/ssh/tailssh/tailssh_integration_test.go +++ b/ssh/tailssh/tailssh_integration_test.go @@ -6,7 +6,6 @@ package tailssh import ( - "bufio" "bytes" "context" "crypto/rand" @@ -60,56 +59,33 @@ import ( var testVarRoot string func TestMain(m *testing.M) { + debugTest.Store(true) + + // Create our log file. + if err := os.WriteFile("/tmp/tailscalessh.log", nil, 0666); err != nil { + log.Fatal(err) + } + + // Create a temp directory for SSH host keys. var err error testVarRoot, err = os.MkdirTemp("", "tailssh-test-var") if err != nil { log.Fatal(err) } - defer os.RemoveAll(testVarRoot) - // Create our log file. - file, err := os.OpenFile("/tmp/tailscalessh.log", os.O_CREATE|os.O_WRONLY, 0666) - if err != nil { - log.Fatal(err) - } - file.Close() + code := m.Run() - // Tail our log file. - cmd := exec.Command("tail", "-F", "/tmp/tailscalessh.log") + os.RemoveAll(testVarRoot) - r, err := cmd.StdoutPipe() - if err != nil { - return + // Print any log output from the incubator subprocesses. + if b, err := os.ReadFile("/tmp/tailscalessh.log"); err == nil && len(b) > 0 { + log.Print(string(b)) } - scanner := bufio.NewScanner(r) - go func() { - for scanner.Scan() { - line := scanner.Text() - log.Println(line) - } - }() - - err = cmd.Start() - if err != nil { - return - } - defer func() { - // tail -f has a default sleep interval of 1 second, so it takes a - // moment for it to finish reading our log file after we've terminated. - // So, wait a bit to let it catch up. - time.Sleep(2 * time.Second) - }() - - m.Run() + os.Exit(code) } func TestIntegrationSSH(t *testing.T) { - debugTest.Store(true) - t.Cleanup(func() { - debugTest.Store(false) - }) - homeDir := "/home/testuser" if runtime.GOOS == "darwin" { homeDir = "/Users/testuser" @@ -215,11 +191,6 @@ func TestIntegrationSSH(t *testing.T) { } func TestIntegrationSFTP(t *testing.T) { - debugTest.Store(true) - t.Cleanup(func() { - debugTest.Store(false) - }) - for _, forceV1Behavior := range []bool{false, true} { name := "v2" if forceV1Behavior { @@ -276,11 +247,6 @@ func TestIntegrationSFTP(t *testing.T) { } func TestIntegrationSCP(t *testing.T) { - debugTest.Store(true) - t.Cleanup(func() { - debugTest.Store(false) - }) - for _, forceV1Behavior := range []bool{false, true} { name := "v2" if forceV1Behavior { @@ -334,11 +300,6 @@ func TestIntegrationSCP(t *testing.T) { } func TestSSHAgentForwarding(t *testing.T) { - debugTest.Store(true) - t.Cleanup(func() { - debugTest.Store(false) - }) - // Create a client SSH key tmpDir, err := os.MkdirTemp("", "") if err != nil { @@ -428,11 +389,6 @@ func TestSSHAgentForwarding(t *testing.T) { // request 'none' auth and instead immediately authenticate with a public key // or password. func TestIntegrationParamiko(t *testing.T) { - debugTest.Store(true) - t.Cleanup(func() { - debugTest.Store(false) - }) - addr := testServer(t, "testuser", true, false) host, port, err := net.SplitHostPort(addr) if err != nil { @@ -736,26 +692,34 @@ func (s *session) run(t *testing.T, cmdString string, shell bool) string { func (s *session) read() string { ch := make(chan []byte) go func() { + defer close(ch) for { b := make([]byte, 1) n, err := s.stdout.Read(b) if n > 0 { ch <- b } - if err == io.EOF { + if err != nil { return } } }() // Read first byte in blocking fashion. - _got := <-ch + b, ok := <-ch + if !ok { + return "" + } + _got := b - // Read subsequent bytes in non-blocking fashion. + // Read subsequent bytes until EOF or silence. readLoop: for { select { - case b := <-ch: + case b, ok := <-ch: + if !ok { + break readLoop + } _got = append(_got, b...) case <-time.After(1 * time.Second): break readLoop @@ -937,8 +901,8 @@ func (tb *testBackend) NetMap() *netmap.NetworkMap { AllowLocalPortForwarding: tb.allowLocalPortForwarding, AllowRemotePortForwarding: tb.allowRemotePortForwarding, }, - SSHUsers: map[string]string{"*": tb.localUser}, - AcceptEnv: []string{"GIT_*", "EXACT_MATCH", "TEST?NG"}, + SSHUsers: map[string]string{"*": tb.localUser}, + AcceptEnv: []string{"GIT_*", "EXACT_MATCH", "TEST?NG"}, }, }, }, @@ -946,6 +910,8 @@ func (tb *testBackend) NetMap() *netmap.NetworkMap { } } +func (tb *testBackend) NetMapNoPeers() *netmap.NetworkMap { return tb.NetMap() } + func (tb *testBackend) WhoIs(_ string, ipp netip.AddrPort) (n tailcfg.NodeView, u tailcfg.UserProfile, ok bool) { return (&tailcfg.Node{}).View(), tailcfg.UserProfile{ LoginName: tb.localUser + "@example.com", diff --git a/ssh/tailssh/tailssh_test.go b/ssh/tailssh/tailssh_test.go index c8b5f698b..04c9cd2f5 100644 --- a/ssh/tailssh/tailssh_test.go +++ b/ssh/tailssh/tailssh_test.go @@ -36,6 +36,7 @@ import ( gliderssh "github.com/tailscale/gliderssh" "golang.org/x/net/http2" "golang.org/x/net/http2/h2c" + "tailscale.com/cmd/testwrapper/flakytest" "tailscale.com/ipn/ipnlocal" "tailscale.com/ipn/store/mem" "tailscale.com/net/memnet" @@ -386,7 +387,14 @@ type localState struct { serverActions map[string]*tailcfg.SSHAction } -var currentUser = os.Getenv("USER") // Use the current user for the test. +var currentUser = func() string { + // Prefer user.Current because the USER env var is not set in + // some environments (e.g. the golang:latest container used by CI). + if u, err := user.Current(); err == nil { + return u.Username + } + return os.Getenv("USER") +}() func (ts *localState) Dialer() *tsdial.Dialer { return &tsdial.Dialer{} @@ -414,6 +422,8 @@ func (ts *localState) NetMap() *netmap.NetworkMap { } } +func (ts *localState) NetMapNoPeers() *netmap.NetworkMap { return ts.NetMap() } + func (ts *localState) WhoIs(proto string, ipp netip.AddrPort) (n tailcfg.NodeView, u tailcfg.UserProfile, ok bool) { if proto != "tcp" { return tailcfg.NodeView{}, tailcfg.UserProfile{}, false @@ -469,6 +479,9 @@ func newSSHRule(action *tailcfg.SSHAction) *tailcfg.SSHRule { } func TestSSHRecordingCancelsSessionsOnUploadFailure(t *testing.T) { + if runtime.GOOS == "darwin" { + flakytest.Mark(t, "https://github.com/tailscale/tailscale/issues/7707") + } if runtime.GOOS != "linux" && runtime.GOOS != "darwin" { t.Skipf("skipping on %q; only runs on linux and darwin", runtime.GOOS) } diff --git a/ssh/tailssh/testcontainers/Dockerfile b/ssh/tailssh/testcontainers/Dockerfile index 768791028..9d662ca1a 100644 --- a/ssh/tailssh/testcontainers/Dockerfile +++ b/ssh/tailssh/testcontainers/Dockerfile @@ -28,64 +28,68 @@ COPY tailssh.test . RUN chmod 755 tailscaled -RUN echo "First run tests normally." -RUN eval `ssh-agent -s` && TAILSCALED_PATH=`pwd`tailscaled ./tailssh.test -test.v -test.run TestSSHAgentForwarding -RUN if echo "$BASE" | grep "ubuntu:"; then rm -Rf /home/testuser; fi -RUN TAILSCALED_PATH=`pwd`tailscaled ./tailssh.test -test.v -test.run TestIntegrationSFTP -RUN if echo "$BASE" | grep "ubuntu:"; then rm -Rf /home/testuser; fi -RUN TAILSCALED_PATH=`pwd`tailscaled ./tailssh.test -test.v -test.run TestIntegrationSCP -RUN if echo "$BASE" | grep "ubuntu:"; then rm -Rf /home/testuser; fi -RUN TAILSCALED_PATH=`pwd`tailscaled ./tailssh.test -test.v -test.run TestIntegrationSSH -RUN if echo "$BASE" | grep "ubuntu:"; then rm -Rf /home/testuser; fi -RUN TAILSCALED_PATH=`pwd`tailscaled ./tailssh.test -test.v -test.run TestIntegrationParamiko -RUN TAILSCALED_PATH=`pwd`tailscaled ./tailssh.test -test.v -test.run TestLocalUnixForwarding -RUN TAILSCALED_PATH=`pwd`tailscaled ./tailssh.test -test.v -test.run TestReverseUnixForwarding -RUN TAILSCALED_PATH=`pwd`tailscaled ./tailssh.test -test.v -test.run TestUnixForwardingDenied -RUN TAILSCALED_PATH=`pwd`tailscaled ./tailssh.test -test.v -test.run TestUnixForwardingPathRestriction +# Run tests normally. +# On Ubuntu, delete testuser's home directory between tests to verify +# that PAM's pam_mkhomedir recreates it each time. +RUN set -e && \ + eval $(ssh-agent -s) && \ + TAILSCALED_PATH=$(pwd)/tailscaled ./tailssh.test -test.v -test.run TestSSHAgentForwarding && \ + if echo "$BASE" | grep -q "ubuntu:"; then rm -Rf /home/testuser; fi && \ + TAILSCALED_PATH=$(pwd)/tailscaled ./tailssh.test -test.v -test.run TestIntegrationSFTP && \ + if echo "$BASE" | grep -q "ubuntu:"; then rm -Rf /home/testuser; fi && \ + TAILSCALED_PATH=$(pwd)/tailscaled ./tailssh.test -test.v -test.run TestIntegrationSCP && \ + if echo "$BASE" | grep -q "ubuntu:"; then rm -Rf /home/testuser; fi && \ + TAILSCALED_PATH=$(pwd)/tailscaled ./tailssh.test -test.v -test.run TestIntegrationSSH && \ + if echo "$BASE" | grep -q "ubuntu:"; then rm -Rf /home/testuser; fi && \ + TAILSCALED_PATH=$(pwd)/tailscaled ./tailssh.test -test.v -test.run TestIntegrationParamiko && \ + TAILSCALED_PATH=$(pwd)/tailscaled ./tailssh.test -test.v -test.run TestLocalUnixForwarding && \ + TAILSCALED_PATH=$(pwd)/tailscaled ./tailssh.test -test.v -test.run TestReverseUnixForwarding && \ + TAILSCALED_PATH=$(pwd)/tailscaled ./tailssh.test -test.v -test.run TestUnixForwardingDenied && \ + TAILSCALED_PATH=$(pwd)/tailscaled ./tailssh.test -test.v -test.run TestUnixForwardingPathRestriction -RUN echo "Then run tests as non-root user testuser and make sure tests still pass." -RUN touch /tmp/tailscalessh.log -RUN chown testuser:groupone /tmp/tailscalessh.log -RUN TAILSCALED_PATH=`pwd`tailscaled eval `su -m testuser -c ssh-agent -s` && su -m testuser -c "./tailssh.test -test.v -test.run TestSSHAgentForwarding" -RUN TAILSCALED_PATH=`pwd`tailscaled su -m testuser -c "./tailssh.test -test.v -test.run TestIntegration TestDoDropPrivileges" -RUN echo "Also, deny everyone access to the user's home directory and make sure non file-related tests still pass." -RUN mkdir -p /home/testuser && chown testuser:groupone /home/testuser && chmod 0000 /home/testuser -RUN TAILSCALED_PATH=`pwd`tailscaled SKIP_FILE_OPS=1 su -m testuser -c "./tailssh.test -test.v -test.run TestIntegrationSSH" -RUN chmod 0755 /home/testuser -RUN chown root:root /tmp/tailscalessh.log +# Run tests as non-root user testuser and make sure tests still pass. +RUN set -e && \ + touch /tmp/tailscalessh.log && \ + chown testuser:groupone /tmp/tailscalessh.log && \ + export TAILSCALED_PATH=$(pwd)/tailscaled && \ + eval $(su -m testuser -c "ssh-agent -s") && \ + su -m testuser -c "./tailssh.test -test.v -test.run 'TestSSHAgentForwarding|TestIntegration|TestDoDropPrivileges'" && \ + echo "Also, deny everyone access to the user's home directory and make sure non file-related tests still pass." && \ + mkdir -p /home/testuser && chown testuser:groupone /home/testuser && chmod 0000 /home/testuser && \ + SKIP_FILE_OPS=1 su -m testuser -c "./tailssh.test -test.v -test.run TestIntegrationSSH" && \ + chmod 0755 /home/testuser && \ + chown root:root /tmp/tailscalessh.log -RUN if echo "$BASE" | grep "ubuntu:"; then \ - echo "Then run tests in a system that's pretending to be SELinux in enforcing mode" && \ - # Remove execute permissions for /usr/bin/login so that it fails. +# On Ubuntu, run tests pretending to be SELinux in enforcing mode. +RUN if echo "$BASE" | grep -q "ubuntu:"; then \ + set -e && \ + echo "Run tests in a system that's pretending to be SELinux in enforcing mode" && \ mv /usr/bin/login /tmp/login_orig && \ - # Use nonsense for /usr/bin/login so that it fails. - # It's not the same failure mode as in SELinux, but failure is good enough for test. echo "adsfasdfasdf" > /usr/bin/login && \ chmod 755 /usr/bin/login && \ - # Simulate getenforce command printf "#!/bin/bash\necho 'Enforcing'" > /usr/bin/getenforce && \ chmod 755 /usr/bin/getenforce && \ - eval `ssh-agent -s` && TAILSCALED_PATH=`pwd`tailscaled ./tailssh.test -test.v -test.run TestSSHAgentForwarding && \ - TAILSCALED_PATH=`pwd`tailscaled ./tailssh.test -test.v -test.run TestIntegration && \ + eval $(ssh-agent -s) && \ + TAILSCALED_PATH=$(pwd)/tailscaled ./tailssh.test -test.v -test.run 'TestSSHAgentForwarding|TestIntegration' && \ mv /tmp/login_orig /usr/bin/login && \ rm /usr/bin/getenforce \ ; fi -RUN echo "Then remove the login command and make sure tests still pass." -RUN rm `which login` -RUN eval `ssh-agent -s` && TAILSCALED_PATH=`pwd`tailscaled ./tailssh.test -test.v -test.run TestSSHAgentForwarding -RUN if echo "$BASE" | grep "ubuntu:"; then rm -Rf /home/testuser; fi -RUN TAILSCALED_PATH=`pwd`tailscaled ./tailssh.test -test.v -test.run TestIntegrationSFTP -RUN if echo "$BASE" | grep "ubuntu:"; then rm -Rf /home/testuser; fi -RUN TAILSCALED_PATH=`pwd`tailscaled ./tailssh.test -test.v -test.run TestIntegrationSCP -RUN if echo "$BASE" | grep "ubuntu:"; then rm -Rf /home/testuser; fi -RUN TAILSCALED_PATH=`pwd`tailscaled ./tailssh.test -test.v -test.run TestIntegrationSSH +# Remove the login command and make sure tests still pass. +RUN set -e && \ + rm $(which login) && \ + eval $(ssh-agent -s) && \ + TAILSCALED_PATH=$(pwd)/tailscaled ./tailssh.test -test.v -test.run TestSSHAgentForwarding && \ + if echo "$BASE" | grep -q "ubuntu:"; then rm -Rf /home/testuser; fi && \ + TAILSCALED_PATH=$(pwd)/tailscaled ./tailssh.test -test.v -test.run TestIntegrationSFTP && \ + if echo "$BASE" | grep -q "ubuntu:"; then rm -Rf /home/testuser; fi && \ + TAILSCALED_PATH=$(pwd)/tailscaled ./tailssh.test -test.v -test.run TestIntegrationSCP && \ + if echo "$BASE" | grep -q "ubuntu:"; then rm -Rf /home/testuser; fi && \ + TAILSCALED_PATH=$(pwd)/tailscaled ./tailssh.test -test.v -test.run TestIntegrationSSH -RUN echo "Then remove the su command and make sure tests still pass." -RUN chown root:root /tmp/tailscalessh.log -RUN rm `which su` -RUN eval `ssh-agent -s` && TAILSCALED_PATH=`pwd`tailscaled ./tailssh.test -test.v -test.run TestSSHAgentForwarding -RUN TAILSCALED_PATH=`pwd`tailscaled ./tailssh.test -test.v -test.run TestIntegration - -RUN echo "Test doDropPrivileges" -RUN TAILSCALED_PATH=`pwd`tailscaled ./tailssh.test -test.v -test.run TestDoDropPrivileges +# Remove the su command and make sure tests still pass. +RUN set -e && \ + chown root:root /tmp/tailscalessh.log && \ + rm $(which su) && \ + eval $(ssh-agent -s) && \ + TAILSCALED_PATH=$(pwd)/tailscaled ./tailssh.test -test.v -test.run 'TestSSHAgentForwarding|TestIntegration|TestDoDropPrivileges' diff --git a/tailcfg/tailcfg.go b/tailcfg/tailcfg.go index 3d7921d75..57c68fad6 100644 --- a/tailcfg/tailcfg.go +++ b/tailcfg/tailcfg.go @@ -182,7 +182,10 @@ type CapabilityVersion int // - 133: 2026-02-17: client understands [NodeAttrForceRegisterMagicDNSIPv4Only]; MagicDNS IPv6 registered w/ OS by default // - 134: 2026-03-09: Client understands [NodeAttrDisableAndroidBindToActiveNetwork] // - 135: 2026-03-30: Client understands [NodeAttrCacheNetworkMaps] -const CurrentCapabilityVersion CapabilityVersion = 135 +// - 136: 2026-04-09: Client understands [NodeAttrDisableLinuxCGNATDropRule] +// - 137: 2026-04-15: Client handles 429 responses to /machine/register. +// - 138: 2026-03-31: can handle C2N /debug/tka. +const CurrentCapabilityVersion CapabilityVersion = 138 // ID is an integer ID for a user, node, or login allocated by the // control plane. @@ -1280,7 +1283,7 @@ type RegisterRequest struct { Ephemeral bool `json:",omitempty"` // NodeKeySignature is the node's own node-key signature, re-signed - // for its new node key using its network-lock key. + // for its new node key using its tailnet-lock key. // // This field is set when the client retries registration after learning // its NodeKeySignature (which is in need of rotation). @@ -2447,6 +2450,18 @@ type Oauth2Token struct { // These are also referred to as "Node Attributes" in the ACL policy file. type NodeCapability string +// NodeCapabilityPrefix is a prefix for [NodeCapMap] keys that share a common +// namespace, where each entry represents a distinct named instance (e.g. one +// per service). The full key is formed by concatenating the prefix with the +// instance name. +type NodeCapabilityPrefix string + +// ToAttribute returns the full [NodeCapability] key for the given value under +// this prefix, of the form prefix+value. +func (p NodeCapabilityPrefix) ToAttribute(value string) NodeCapability { + return NodeCapability(string(p) + value) +} + const ( CapabilityFileSharing NodeCapability = "https://tailscale.com/cap/file-sharing" CapabilityAdmin NodeCapability = "https://tailscale.com/cap/is-admin" @@ -2460,6 +2475,10 @@ const ( // CapabilityMacUIV2 makes the macOS GUI enable its v2 mode. CapabilityMacUIV2 NodeCapability = "https://tailscale.com/cap/mac-ui-v2" + // CapabilityServicesInDesktopClients enables services list/menu/section in desktop clients. + // If this capability is not present, desktop clients should not show services. + CapabilityServicesInDesktopClients NodeCapability = "https://tailscale.com/cap/services-in-desktop-clients" + // CapabilityBindToInterfaceByRoute changes how Darwin nodes create // sockets (in the net/netns package). See that package for more // details on the behaviour of this capability. @@ -2585,21 +2604,6 @@ const ( // This cannot be set simultaneously with NodeAttrLinuxMustUseIPTables. NodeAttrLinuxMustUseNfTables NodeCapability = "linux-netfilter?v=nftables" - // NodeAttrDisableSeamlessKeyRenewal disables seamless key renewal, which is - // enabled by default in clients as of 2025-09-17 (1.90 and later). - // - // We will use this attribute to manage the rollout, and disable seamless in - // clients with known bugs. - // http://go/seamless-key-renewal - NodeAttrDisableSeamlessKeyRenewal NodeCapability = "disable-seamless-key-renewal" - - // NodeAttrSeamlessKeyRenewal was used to opt-in to seamless key renewal - // during its private alpha. - // - // Deprecated: NodeAttrSeamlessKeyRenewal is deprecated as of CapabilityVersion 126, - // because seamless key renewal is now enabled by default. - NodeAttrSeamlessKeyRenewal NodeCapability = "seamless-key-renewal" - // NodeAttrProbeUDPLifetime makes the client probe UDP path lifetime at the // tail end of an active direct connection in magicsock. NodeAttrProbeUDPLifetime NodeCapability = "probe-udp-lifetime" @@ -2778,6 +2782,22 @@ const ( // absent (or removed), a node that supports netmap caching will ignore and // discard existing cached maps, and will not store any. NodeAttrCacheNetworkMaps NodeCapability = "cache-network-maps" + + // NodeAttrDisableLinuxCGNATDropRule tells Linux clients to not insert a + // blanket firewall DROP rule for inbound traffic from the CGNAT IP range + // that does not originate from the Tailscale network interface. + // This enables access to off-tailnet endpoints within that IP range. + NodeAttrDisableLinuxCGNATDropRule NodeCapability = "disable-linux-cgnat-drop-rule" +) + +const ( + // NodeAttrPrefixServices is the prefix for per-service [NodeCapMap] + // entries describing Services visible (accessible) to this node. + // Each value under such a key is of type [ServiceDetails]. + // The suffix after the prefix is an opaque server-chosen identifier; + // consumers must use [ServiceDetails.Name] as the canonical service name + // rather than parsing it from the map key. + NodeAttrPrefixServices NodeCapabilityPrefix = "services/" ) // SetDNSRequest is a request to add a DNS record. @@ -3318,6 +3338,51 @@ const LBHeader = "Ts-Lb" // this client is hosting can be ignored. type ServiceIPMappings map[ServiceName][]netip.Addr +// ServiceAction describes an action that a Tailscale +// client can invoke for a [ServiceDetails]. +type ServiceAction struct { + // Type is the action's identifier i.e. a unique slug corresponding to a well + // known action. It drives icon selection and client application matching. + Type string + + // Port is the target TCP port for this action. It must match one of + // the specific (non-range) TCP ports listed in the enclosing + // [ServiceDetails.Ports]. + Port uint16 + + // DisplayName is an optional human-readable label which may be shown + // in client menus when there are multiple actions to select from. + // If empty, a display name may be inferred from the Type field. + DisplayName string `json:",omitzero"` +} + +// ServiceDetails describes a Service visible to this node. +// It is the value type stored under [NodeAttrPrefixServices]+serviceName keys in [NodeCapMap]. +type ServiceDetails struct { + // Name is the name of the Service, of the form "svc:dns-label". + Name ServiceName + + // DisplayName is an optional human-readable label for the service. + // If empty, Name is used as a fallback by clients. + DisplayName string `json:",omitzero"` + + // Addrs are the IP addresses (IPv4 and IPv6) assigned to this Service. + Addrs []netip.Addr `json:",omitempty"` + + // Ports are the protocol/port combinations the Service accepts. + Ports []ProtoPortRange `json:",omitempty"` + + // Actions is an optional list of actions describing how a client may + // interact with this service. Each action maps a [ServiceAction.Type] to a + // specific TCP port; the port must match one of the concrete (non-range) + // ports listed in Ports. + // + // Multiple actions may reference the same port. Not every port requires + // a corresponding action. When Actions has length zero, clients may infer + // default interactions from Ports. + Actions []ServiceAction `json:",omitzero"` +} + // ClientAuditAction represents an auditable action that a client can report to the // control plane. These actions must correspond to the supported actions // in the control plane. diff --git a/tailcfg/tailcfg_clone.go b/tailcfg/tailcfg_clone.go index 8b966b621..df2d6d9aa 100644 --- a/tailcfg/tailcfg_clone.go +++ b/tailcfg/tailcfg_clone.go @@ -262,8 +262,19 @@ func (src *DNSConfig) Clone() *DNSConfig { } if dst.Routes != nil { dst.Routes = map[string][]*dnstype.Resolver{} - for k := range src.Routes { - dst.Routes[k] = append([]*dnstype.Resolver{}, src.Routes[k]...) + for k, sv := range src.Routes { + if sv == nil { + dst.Routes[k] = nil + continue + } + dst.Routes[k] = make([]*dnstype.Resolver, len(sv)) + for i := range sv { + if sv[i] == nil { + dst.Routes[k][i] = nil + } else { + dst.Routes[k][i] = sv[i].Clone() + } + } } } if src.FallbackResolvers != nil { diff --git a/tailcfg/tailcfg_view.go b/tailcfg/tailcfg_view.go index 9900efbcc..846663388 100644 --- a/tailcfg/tailcfg_view.go +++ b/tailcfg/tailcfg_view.go @@ -1349,7 +1349,7 @@ func (v RegisterRequestView) Hostinfo() HostinfoView { return v.ж.Hostinfo.View func (v RegisterRequestView) Ephemeral() bool { return v.ж.Ephemeral } // NodeKeySignature is the node's own node-key signature, re-signed -// for its new node key using its network-lock key. +// for its new node key using its tailnet-lock key. // // This field is set when the client retries registration after learning // its NodeKeySignature (which is in need of rotation). diff --git a/tailcfg/tka.go b/tailcfg/tka.go index 29c17b756..f392e6fd3 100644 --- a/tailcfg/tka.go +++ b/tailcfg/tka.go @@ -36,7 +36,7 @@ type TKASignInfo struct { // a NodeKeySignature (NKS), which rotates the node key. // // This is necessary so the node can rotate its node-key without - // talking to a node which holds a trusted network-lock key. + // talking to a node which holds a trusted tailnet-lock key. // It does this by nesting the original NKS in a 'rotation' NKS, // which it then signs with the key corresponding to RotationPubkey. // @@ -193,7 +193,7 @@ type TKASyncSendResponse struct { Head string } -// TKADisableRequest disables network-lock across the tailnet using the +// TKADisableRequest disables tailnet-lock across the tailnet using the // provided disablement secret. // // This is the request schema for a /tka/disable noise RPC. diff --git a/tempfork/pkgdoc/pkgdoc.go b/tempfork/pkgdoc/pkgdoc.go new file mode 100644 index 000000000..cab38dd48 --- /dev/null +++ b/tempfork/pkgdoc/pkgdoc.go @@ -0,0 +1,234 @@ +// Copyright 2015 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +// Package pkgdoc is a library-ified fork of Go's cmd/doc program +// that only does what we need for misc/genreadme. +package pkgdoc + +import ( + "bytes" + "errors" + "fmt" + "go/ast" + "go/build" + "go/doc" + "go/doc/comment" + "go/parser" + "go/token" + "io" + "io/fs" + "log" + "slices" +) + +const ( + punchedCardWidth = 80 + indent = " " +) + +type Package struct { + writer io.Writer // Destination for output. + name string // Package name, json for encoding/json. + userPath string // String the user used to find this package. + pkg *ast.Package // Parsed package. + file *ast.File // Merged from all files in the package + doc *doc.Package + build *build.Package + fs *token.FileSet // Needed for printing. + buf pkgBuffer +} + +func (pkg *Package) ToText(w io.Writer, text, prefix, codePrefix string) { + d := pkg.doc.Parser().Parse(text) + pr := pkg.doc.Printer() + pr.TextPrefix = prefix + pr.TextCodePrefix = codePrefix + w.Write(pr.Text(d)) +} + +// ToMarkdown parses the godoc comment text and writes a Markdown rendering to w +// suitable for a repository README.md: top-level sections become ## headings +// without per-heading anchor IDs, and [Symbol] doc links resolve to pkg.go.dev, +// including for symbols in the current package (which the default printer would +// otherwise emit as bare #Name fragments with no backing anchor). +func (pkg *Package) ToMarkdown(w io.Writer, text string) { + d := pkg.doc.Parser().Parse(text) + pr := pkg.doc.Printer() + pr.HeadingLevel = 2 + pr.HeadingID = func(*comment.Heading) string { return "" } + pr.DocLinkBaseURL = "https://pkg.go.dev" + pr.DocLinkURL = func(link *comment.DocLink) string { + importPath := link.ImportPath + if importPath == "" { + importPath = pkg.doc.ImportPath + } + name := link.Name + if link.Recv != "" { + name = link.Recv + "." + name + } + return "https://pkg.go.dev/" + importPath + "#" + name + } + w.Write(pr.Markdown(d)) +} + +// pkgBuffer is a wrapper for bytes.Buffer that prints a package clause the +// first time Write is called. +type pkgBuffer struct { + pkg *Package + printed bool // Prevent repeated package clauses. + bytes.Buffer +} + +func (pb *pkgBuffer) Write(p []byte) (int, error) { + pb.packageClause() + return pb.Buffer.Write(p) +} + +func (pb *pkgBuffer) packageClause() { + if !pb.printed { + pb.printed = true + // Only show package clause for commands if requested explicitly. + if pb.pkg.pkg.Name != "main" { + pb.pkg.packageClause() + } + } +} + +type PackageError string // type returned by pkg.Fatalf. + +func (p PackageError) Error() string { + return string(p) +} + +// parsePackage turns the build package we found into a parsed package +// we can then use to generate documentation. +func parsePackage(writer io.Writer, pkg *build.Package, userPath string) *Package { + // include tells parser.ParseDir which files to include. + // That means the file must be in the build package's GoFiles or CgoFiles + // list only (no tag-ignored files, tests, swig or other non-Go files). + include := func(info fs.FileInfo) bool { + return slices.Contains(pkg.GoFiles, info.Name()) || slices.Contains(pkg.CgoFiles, info.Name()) + } + fset := token.NewFileSet() + // Parse declarations (not just imports) so that doc.Package knows the + // package's symbols; the Markdown printer needs this to resolve + // [Symbol] doc links in package comments. + pkgs, err := parser.ParseDir(fset, pkg.Dir, include, parser.ParseComments) + if err != nil { + log.Fatal(err) + } + // Make sure they are all in one package. + if len(pkgs) == 0 { + log.Fatalf("no source-code package in directory %s", pkg.Dir) + } + if len(pkgs) > 1 { + log.Fatalf("multiple packages in directory %s", pkg.Dir) + } + astPkg := pkgs[pkg.Name] + + // TODO: go/doc does not include typed constants in the constants + // list, which is what we want. For instance, time.Sunday is of type + // time.Weekday, so it is defined in the type but not in the + // Consts list for the package. This prevents + // go doc time.Sunday + // from finding the symbol. Work around this for now, but we + // should fix it in go/doc. + // A similar story applies to factory functions. + mode := doc.AllDecls + docPkg := doc.New(astPkg, pkg.ImportPath, mode) + + p := &Package{ + writer: writer, + name: pkg.Name, + userPath: userPath, + pkg: astPkg, + file: ast.MergePackageFiles(astPkg, 0), + doc: docPkg, + build: pkg, + fs: fset, + } + p.buf.pkg = p + return p +} + +func (pkg *Package) Printf(format string, args ...any) { + fmt.Fprintf(&pkg.buf, format, args...) +} + +func (pkg *Package) flush() { + _, err := pkg.writer.Write(pkg.buf.Bytes()) + if err != nil { + log.Fatal(err) + } + pkg.buf.Reset() // Not needed, but it's a flush. +} + +var newlineBytes = []byte("\n\n") // We never ask for more than 2. + +// newlines guarantees there are n newlines at the end of the buffer. +func (pkg *Package) newlines(n int) { + for !bytes.HasSuffix(pkg.buf.Bytes(), newlineBytes[:n]) { + pkg.buf.WriteRune('\n') + } +} + +// packageDoc prints the docs for the package as Markdown. +func (pkg *Package) packageDoc() { + pkg.Printf("") // Trigger the package clause; we know the package exists. + pkg.ToMarkdown(&pkg.buf, pkg.doc.Doc) + pkg.newlines(1) + + pkg.bugs() +} + +// packageClause prints the package clause. +func (pkg *Package) packageClause() { + importPath := pkg.build.ImportComment + if importPath == "" { + importPath = pkg.build.ImportPath + } + + pkg.Printf("package %s // import %q\n\n", pkg.name, importPath) +} + +// bugs prints the BUGS information for the package. +// TODO: Provide access to TODOs and NOTEs as well (very noisy so off by default)? +func (pkg *Package) bugs() { + if pkg.doc.Notes["BUG"] == nil { + return + } + pkg.Printf("\n") + for _, note := range pkg.doc.Notes["BUG"] { + pkg.Printf("%s: %v\n", "BUG", note.Body) + } +} + +// PackageDoc generates Markdown documentation for the package in the given +// directory. importPath is the full Go import path of that package (e.g. +// "tailscale.com/tsnet"); it's used to render [Symbol] doc links to the +// right pkg.go.dev URL. If importPath is empty, build.ImportDir's guess +// is used (typically "." for module-based repos). +func PackageDoc(dir, importPath string) ([]byte, error) { + var buf bytes.Buffer + var writer io.Writer = &buf + + buildPackage, err := build.ImportDir(dir, build.ImportComment) + if err != nil { + var noGoError *build.NoGoError + if errors.As(err, &noGoError) { + return nil, nil + } + return nil, err + } + if importPath != "" { + buildPackage.ImportPath = importPath + } + userPath := dir + + pkg := parsePackage(writer, buildPackage, userPath) + pkg.packageDoc() + pkg.flush() + + return buf.Bytes(), nil +} diff --git a/tka/builder_test.go b/tka/builder_test.go index 29ecaf88c..4e820e14d 100644 --- a/tka/builder_test.go +++ b/tka/builder_test.go @@ -27,12 +27,10 @@ func (s signer25519) SignAUM(sigHash tkatype.AUMSigHash) ([]tkatype.Signature, e func TestAuthorityBuilderAddKey(t *testing.T) { pub, priv := testingKey25519(t, 1) key := Key{Kind: Key25519, Public: pub, Votes: 2} + state := CreateStateForTest(key) storage := ChonkMem() - a, _, err := Create(storage, State{ - Keys: []Key{key}, - DisablementValues: [][]byte{DisablementKDF([]byte{1, 2, 3})}, - }, signer25519(priv)) + a, _, err := Create(storage, state, signer25519(priv)) if err != nil { t.Fatalf("Create() failed: %v", err) } @@ -61,12 +59,10 @@ func TestAuthorityBuilderAddKey(t *testing.T) { func TestAuthorityBuilderMaxKey(t *testing.T) { pub, priv := testingKey25519(t, 1) key := Key{Kind: Key25519, Public: pub, Votes: 2} + state := CreateStateForTest(key) storage := ChonkMem() - a, _, err := Create(storage, State{ - Keys: []Key{key}, - DisablementValues: [][]byte{DisablementKDF([]byte{1, 2, 3})}, - }, signer25519(priv)) + a, _, err := Create(storage, state, signer25519(priv)) if err != nil { t.Fatalf("Create() failed: %v", err) } @@ -108,12 +104,10 @@ func TestAuthorityBuilderRemoveKey(t *testing.T) { key := Key{Kind: Key25519, Public: pub, Votes: 2} pub2, _ := testingKey25519(t, 2) key2 := Key{Kind: Key25519, Public: pub2, Votes: 1} + state := CreateStateForTest(key, key2) storage := ChonkMem() - a, _, err := Create(storage, State{ - Keys: []Key{key, key2}, - DisablementValues: [][]byte{DisablementKDF([]byte{1, 2, 3})}, - }, signer25519(priv)) + a, _, err := Create(storage, state, signer25519(priv)) if err != nil { t.Fatalf("Create() failed: %v", err) } @@ -154,12 +148,10 @@ func TestAuthorityBuilderRemoveKey(t *testing.T) { func TestAuthorityBuilderSetKeyVote(t *testing.T) { pub, priv := testingKey25519(t, 1) key := Key{Kind: Key25519, Public: pub, Votes: 2} + state := CreateStateForTest(key) storage := ChonkMem() - a, _, err := Create(storage, State{ - Keys: []Key{key}, - DisablementValues: [][]byte{DisablementKDF([]byte{1, 2, 3})}, - }, signer25519(priv)) + a, _, err := Create(storage, state, signer25519(priv)) if err != nil { t.Fatalf("Create() failed: %v", err) } @@ -190,12 +182,10 @@ func TestAuthorityBuilderSetKeyVote(t *testing.T) { func TestAuthorityBuilderSetKeyMeta(t *testing.T) { pub, priv := testingKey25519(t, 1) key := Key{Kind: Key25519, Public: pub, Votes: 2, Meta: map[string]string{"a": "b"}} + state := CreateStateForTest(key) storage := ChonkMem() - a, _, err := Create(storage, State{ - Keys: []Key{key}, - DisablementValues: [][]byte{DisablementKDF([]byte{1, 2, 3})}, - }, signer25519(priv)) + a, _, err := Create(storage, state, signer25519(priv)) if err != nil { t.Fatalf("Create() failed: %v", err) } @@ -226,12 +216,10 @@ func TestAuthorityBuilderSetKeyMeta(t *testing.T) { func TestAuthorityBuilderMultiple(t *testing.T) { pub, priv := testingKey25519(t, 1) key := Key{Kind: Key25519, Public: pub, Votes: 2} + state := CreateStateForTest(key) storage := ChonkMem() - a, _, err := Create(storage, State{ - Keys: []Key{key}, - DisablementValues: [][]byte{DisablementKDF([]byte{1, 2, 3})}, - }, signer25519(priv)) + a, _, err := Create(storage, state, signer25519(priv)) if err != nil { t.Fatalf("Create() failed: %v", err) } @@ -274,12 +262,10 @@ func TestAuthorityBuilderMultiple(t *testing.T) { func TestAuthorityBuilderCheckpointsAfterXUpdates(t *testing.T) { pub, priv := testingKey25519(t, 1) key := Key{Kind: Key25519, Public: pub, Votes: 2} + state := CreateStateForTest(key) storage := ChonkMem() - a, _, err := Create(storage, State{ - Keys: []Key{key}, - DisablementValues: [][]byte{DisablementKDF([]byte{1, 2, 3})}, - }, signer25519(priv)) + a, _, err := Create(storage, state, signer25519(priv)) if err != nil { t.Fatalf("Create() failed: %v", err) } diff --git a/tka/chaintest_test.go b/tka/chaintest_test.go index 467880e2c..71210608b 100644 --- a/tka/chaintest_test.go +++ b/tka/chaintest_test.go @@ -320,6 +320,20 @@ func optTemplate(name string, template AUM) testchainOpt { } } +func genesisTemplate(key Key) testchainOpt { + state := CreateStateForTest(key) + return optTemplate("genesis", AUM{MessageKind: AUMCheckpoint, State: &state}) +} + +func checkpointTemplate() testchainOpt { + fakeState := &State{ + Keys: []Key{{Kind: Key25519, Votes: 1}}, + DisablementValues: [][]byte{bytes.Repeat([]byte{1}, 32)}, + } + + return optTemplate("checkpoint", AUM{MessageKind: AUMCheckpoint, State: fakeState}) +} + func optKey(name string, key Key, priv ed25519.PrivateKey) testchainOpt { return testchainOpt{ Name: name, diff --git a/tka/deeplink_test.go b/tka/deeplink_test.go index 4d813272a..260ec9026 100644 --- a/tka/deeplink_test.go +++ b/tka/deeplink_test.go @@ -14,11 +14,7 @@ func TestGenerateDeeplink(t *testing.T) { G1 -> L1 G1.template = genesis - `, - optTemplate("genesis", AUM{MessageKind: AUMCheckpoint, State: &State{ - Keys: []Key{key}, - DisablementValues: [][]byte{DisablementKDF([]byte{1, 2, 3})}, - }}), + `, genesisTemplate(key), ) a, _ := Open(c.Chonk()) diff --git a/tka/disabled_stub.go b/tka/disabled_stub.go index d14473e5e..f3cabd491 100644 --- a/tka/disabled_stub.go +++ b/tka/disabled_stub.go @@ -8,6 +8,7 @@ package tka import ( "crypto/ed25519" "errors" + "time" "tailscale.com/types/key" "tailscale.com/types/logger" @@ -158,3 +159,8 @@ func SignByCredential(privKey []byte, wrapped *NodeKeySignature, nodeKey key.Nod } func (s NodeKeySignature) String() string { return "" } + +type CompactionOptions struct { + MinChain int + MinAge time.Duration +} diff --git a/tka/key.go b/tka/key.go index 08897d409..840a65f5c 100644 --- a/tka/key.go +++ b/tka/key.go @@ -32,7 +32,7 @@ func (k KeyKind) String() string { } } -// Key describes the public components of a key known to network-lock. +// Key describes the public components of a key known to tailnet-lock. type Key struct { Kind KeyKind `cbor:"1,keyasint"` diff --git a/tka/key_test.go b/tka/key_test.go index cc6a1f580..1edeaf54f 100644 --- a/tka/key_test.go +++ b/tka/key_test.go @@ -72,10 +72,8 @@ func TestNLPrivate(t *testing.T) { // Test that key.NLPrivate implements Signer by making a new // authority. k := Key{Kind: Key25519, Public: pub.Verifier(), Votes: 1} - _, aum, err := Create(ChonkMem(), State{ - Keys: []Key{k}, - DisablementValues: [][]byte{bytes.Repeat([]byte{1}, 32)}, - }, p) + state := CreateStateForTest(k) + _, aum, err := Create(ChonkMem(), state, p) if err != nil { t.Fatalf("Create() failed: %v", err) } diff --git a/tka/limits.go b/tka/limits.go index 7f5b8dccd..11f53654f 100644 --- a/tka/limits.go +++ b/tka/limits.go @@ -3,11 +3,15 @@ package tka +import ( + "time" +) + const ( // Upper bound on checkpoint elements, chosen arbitrarily. Intended // to cap the size of large AUMs. - maxDisablementSecrets = 32 - maxKeys = 512 + maxDisablementValues = 32 + maxKeys = 512 // Max amount of metadata that can be associated with a key, chosen arbitrarily. // Intended to avoid people abusing TKA as a key-value score. @@ -22,3 +26,10 @@ const ( // Limit on scanning AUM trees, chosen arbitrarily. maxScanIterations = 2000 ) + +var ( + CompactionDefaults = CompactionOptions{ + MinChain: 24, // Keep at minimum 24 AUMs since head. + MinAge: 14 * 24 * time.Hour, // Keep 2 weeks of AUMs. + } +) diff --git a/tka/scenario_test.go b/tka/scenario_test.go index 277bb1acf..61d9e2529 100644 --- a/tka/scenario_test.go +++ b/tka/scenario_test.go @@ -147,10 +147,7 @@ func testScenario(t *testing.T, sharedChain string, sharedOptions ...testchainOp pub, priv := testingKey25519(t, 1) key := Key{Kind: Key25519, Public: pub, Votes: 1} sharedOptions = append(sharedOptions, - optTemplate("genesis", AUM{MessageKind: AUMCheckpoint, State: &State{ - Keys: []Key{key}, - DisablementValues: [][]byte{DisablementKDF([]byte{1, 2, 3})}, - }}), + genesisTemplate(key), optKey("key", key, priv), optSignAllUsing("key")) diff --git a/tka/sig.go b/tka/sig.go index 9d107c98f..7361da3c4 100644 --- a/tka/sig.go +++ b/tka/sig.go @@ -178,7 +178,7 @@ func (s NodeKeySignature) UnverifiedAuthorizingKeyID() (tkatype.KeyID, error) { return s.authorizingKeyID() } -// authorizingKeyID returns the KeyID of the key trusted by network-lock which authorizes +// authorizingKeyID returns the KeyID of the key trusted by tailnet-lock which authorizes // this signature. func (s NodeKeySignature) authorizingKeyID() (tkatype.KeyID, error) { switch s.SigKind { @@ -349,14 +349,14 @@ func (s *NodeKeySignature) rotationDetails() (*RotationDetails, error) { // ResignNKS re-signs a node-key signature for a new node-key. // -// This only matters on network-locked tailnets, because node-key signatures are +// This only matters on tailnet-locked tailnets, because node-key signatures are // how other nodes know that a node-key is authentic. When the node-key is // rotated then the existing signature becomes invalid, so this function is // responsible for generating a new wrapping signature to certify the new node-key. // // The signature itself is a SigRotation signature, which embeds the old signature // and certifies the new node-key as a replacement for the old by signing the new -// signature with RotationPubkey (which is the node's own network-lock key). +// signature with RotationPubkey (which is the node's own tailnet-lock key). func ResignNKS(priv key.NLPrivate, nodeKey key.NodePublic, oldNKS tkatype.MarshaledSignature) (tkatype.MarshaledSignature, error) { var oldSig NodeKeySignature if err := oldSig.Unserialize(oldNKS); err != nil { diff --git a/tka/sig_test.go b/tka/sig_test.go index d02ef9cef..700967af2 100644 --- a/tka/sig_test.go +++ b/tka/sig_test.go @@ -51,7 +51,7 @@ func TestSigDirect(t *testing.T) { } func TestSigNested(t *testing.T) { - // Network-lock key (the key used to sign the nested sig) + // tailnet-lock key (the key used to sign the nested sig) pub, priv := testingKey25519(t, 1) k := Key{Kind: Key25519, Public: pub, Votes: 2} // Rotation key (the key used to sign the outer sig) @@ -64,7 +64,7 @@ func TestSigNested(t *testing.T) { nodeKeyPub, _ := node.Public().MarshalBinary() // The original signature for the old node key, signed by - // the network-lock key. + // the tailnet-lock key. nestedSig := NodeKeySignature{ SigKind: SigDirect, KeyID: k.MustID(), @@ -127,7 +127,7 @@ func TestSigNested(t *testing.T) { } func TestSigNested_DeepNesting(t *testing.T) { - // Network-lock key (the key used to sign the nested sig) + // tailnet-lock key (the key used to sign the nested sig) pub, priv := testingKey25519(t, 1) k := Key{Kind: Key25519, Public: pub, Votes: 2} // Rotation key (the key used to sign the outer sig) @@ -137,7 +137,7 @@ func TestSigNested_DeepNesting(t *testing.T) { oldPub, _ := oldNode.Public().MarshalBinary() // The original signature for the old node key, signed by - // the network-lock key. + // the tailnet-lock key. nestedSig := NodeKeySignature{ SigKind: SigDirect, KeyID: k.MustID(), @@ -173,11 +173,8 @@ func TestSigNested_DeepNesting(t *testing.T) { } // Test this works with our public API - a, _ := Open(newTestchain(t, "G1\nG1.template = genesis", - optTemplate("genesis", AUM{MessageKind: AUMCheckpoint, State: &State{ - Keys: []Key{k}, - DisablementValues: [][]byte{DisablementKDF([]byte{1, 2, 3})}, - }})).Chonk()) + c := newTestchain(t, "G1\nG1.template = genesis", genesisTemplate(k)) + a, _ := Open(c.Chonk()) if err := a.NodeKeyAuthorized(lastNodeKey.Public(), outer.Serialize()); err != nil { t.Errorf("NodeKeyAuthorized(lastNodeKey) failed: %v", err) } @@ -199,7 +196,7 @@ func TestSigNested_DeepNesting(t *testing.T) { } func TestSigCredential(t *testing.T) { - // Network-lock key (the key used to sign the nested sig) + // tailnet-lock key (the key used to sign the nested sig) pub, priv := testingKey25519(t, 1) k := Key{Kind: Key25519, Public: pub, Votes: 2} // 'credential' key (the one being delegated to) @@ -238,11 +235,8 @@ func TestSigCredential(t *testing.T) { } // Test someone can't misuse our public API for verifying node-keys - a, _ := Open(newTestchain(t, "G1\nG1.template = genesis", - optTemplate("genesis", AUM{MessageKind: AUMCheckpoint, State: &State{ - Keys: []Key{k}, - DisablementValues: [][]byte{DisablementKDF([]byte{1, 2, 3})}, - }})).Chonk()) + c := newTestchain(t, "G1\nG1.template = genesis", genesisTemplate(k)) + a, _ := Open(c.Chonk()) if err := a.NodeKeyAuthorized(node.Public(), nestedSig.Serialize()); err == nil { t.Error("NodeKeyAuthorized(SigCredential, node) did not fail") } @@ -519,7 +513,7 @@ func TestResignNKS(t *testing.T) { origPub, _ := origNode.Public().MarshalBinary() // The original signature for the old node key, signed by - // the network-lock key. + // the tailnet-lock key. directSig := NodeKeySignature{ SigKind: SigDirect, KeyID: authKey.MustID(), diff --git a/tka/state.go b/tka/state.go index 69b3dbfeb..ddf2d4ee0 100644 --- a/tka/state.go +++ b/tka/state.go @@ -13,6 +13,7 @@ import ( "golang.org/x/crypto/argon2" "tailscale.com/types/tkatype" + "tailscale.com/util/testenv" ) // ErrNoSuchKey is returned if the key referenced by a KeyID does not exist. @@ -261,8 +262,8 @@ func (s *State) staticValidateCheckpoint() error { if len(s.DisablementValues) == 0 { return errors.New("at least one disablement secret required") } - if numDS := len(s.DisablementValues); numDS > maxDisablementSecrets { - return fmt.Errorf("too many disablement secrets (%d, max %d)", numDS, maxDisablementSecrets) + if numDS := len(s.DisablementValues); numDS > maxDisablementValues { + return fmt.Errorf("too many disablement values (%d, max %d)", numDS, maxDisablementValues) } for i, ds := range s.DisablementValues { if len(ds) != disablementLength { @@ -313,3 +314,18 @@ func (s *State) staticValidateCheckpoint() error { } return nil } + +// CreateStateForTest creates a [State] that marks the given keys as trusted +// with an arbitrary disablement value. +// +// This is only for use in tests, and will panic if called outside a test. +func CreateStateForTest(keys ...Key) State { + testenv.AssertInTest() + + disablementSecret := bytes.Repeat([]byte{0xa5}, 32) + + return State{ + Keys: keys, + DisablementValues: [][]byte{DisablementKDF(disablementSecret)}, + } +} diff --git a/tka/sync_test.go b/tka/sync_test.go index 68f659ae5..48f197e8c 100644 --- a/tka/sync_test.go +++ b/tka/sync_test.go @@ -11,21 +11,29 @@ import ( "github.com/google/go-cmp/cmp" ) +// getSyncOffer returns a SyncOffer for the given Chonk. +func getSyncOffer(t *testing.T, storage Chonk) SyncOffer { + t.Helper() + + a, err := Open(storage) + if err != nil { + t.Fatal(err) + } + offer, err := a.SyncOffer(storage) + if err != nil { + t.Fatal(err) + } + + return offer +} + func TestSyncOffer(t *testing.T) { c := newTestchain(t, ` A1 -> A2 -> A3 -> A4 -> A5 -> A6 -> A7 -> A8 -> A9 -> A10 A10 -> A11 -> A12 -> A13 -> A14 -> A15 -> A16 -> A17 -> A18 A18 -> A19 -> A20 -> A21 -> A22 -> A23 -> A24 -> A25 `) - storage := c.Chonk() - a, err := Open(storage) - if err != nil { - t.Fatal(err) - } - got, err := a.SyncOffer(storage) - if err != nil { - t.Fatal(err) - } + got := getSyncOffer(t, c.Chonk()) // A SyncOffer includes a selection of AUMs going backwards in the tree, // progressively skipping more and more each iteration. @@ -52,24 +60,10 @@ func TestComputeSyncIntersection_FastForward(t *testing.T) { a1H, a2H := c.AUMHashes["A1"], c.AUMHashes["A2"] chonk1 := c.ChonkWith("A1", "A2") - n1, err := Open(chonk1) - if err != nil { - t.Fatal(err) - } - offer1, err := n1.SyncOffer(chonk1) - if err != nil { - t.Fatal(err) - } + offer1 := getSyncOffer(t, chonk1) chonk2 := c.Chonk() // All AUMs - n2, err := Open(chonk2) - if err != nil { - t.Fatal(err) - } - offer2, err := n2.SyncOffer(chonk2) - if err != nil { - t.Fatal(err) - } + offer2 := getSyncOffer(t, chonk2) // Node 1 only knows about the first two nodes, so the head of n2 is // alien to it. @@ -123,40 +117,28 @@ func TestComputeSyncIntersection_ForkSmallDiff(t *testing.T) { } chonk1 := c.ChonkWith("A1", "A2", "A3", "A4", "A5", "A6", "A7", "A8", "F1") - n1, err := Open(chonk1) - if err != nil { - t.Fatal(err) - } - offer1, err := n1.SyncOffer(chonk1) - if err != nil { - t.Fatal(err) - } - if diff := cmp.Diff(SyncOffer{ + offer1 := getSyncOffer(t, chonk1) + want1 := SyncOffer{ Head: c.AUMHashes["F1"], Ancestors: []AUMHash{ c.AUMHashes["A"+strconv.Itoa(9-ancestorsSkipStart)], c.AUMHashes["A1"], }, - }, offer1); diff != "" { + } + if diff := cmp.Diff(want1, offer1); diff != "" { t.Errorf("offer1 diff (-want, +got):\n%s", diff) } chonk2 := c.ChonkWith("A1", "A2", "A3", "A4", "A5", "A6", "A7", "A8", "A9", "A10") - n2, err := Open(chonk2) - if err != nil { - t.Fatal(err) - } - offer2, err := n2.SyncOffer(chonk2) - if err != nil { - t.Fatal(err) - } - if diff := cmp.Diff(SyncOffer{ + offer2 := getSyncOffer(t, chonk2) + want2 := SyncOffer{ Head: c.AUMHashes["A10"], Ancestors: []AUMHash{ c.AUMHashes["A"+strconv.Itoa(10-ancestorsSkipStart)], c.AUMHashes["A1"], }, - }, offer2); diff != "" { + } + if diff := cmp.Diff(want2, offer2); diff != "" { t.Errorf("offer2 diff (-want, +got):\n%s", diff) } @@ -339,10 +321,7 @@ func TestSyncSimpleE2E(t *testing.T) { G1 -> L1 -> L2 -> L3 G1.template = genesis `, - optTemplate("genesis", AUM{MessageKind: AUMCheckpoint, State: &State{ - Keys: []Key{key}, - DisablementValues: [][]byte{DisablementKDF([]byte{1, 2, 3})}, - }}), + genesisTemplate(key), optKey("key", key, priv), optSignAllUsing("key")) diff --git a/tka/tailchonk.go b/tka/tailchonk.go index da1915acb..3b083f327 100644 --- a/tka/tailchonk.go +++ b/tka/tailchonk.go @@ -597,9 +597,6 @@ func (c *FS) CommitVerifiedAUMs(updates []AUM) error { for i, aum := range updates { h := aum.Hash() err := c.commit(h, func(info *fsHashInfo) { - if info.PurgedUnix > 0 { - log.Printf("tka: CommitVerifiedAUMs: committing previously-deleted AUM %s", h.String()) - } info.PurgedUnix = 0 // just in-case it was set for some reason info.AUM = &aum }) @@ -976,8 +973,5 @@ func Compact(storage CompactableChonk, head AUMHash, opts CompactionOptions) (la if err := storage.SetLastActiveAncestor(lastActiveAncestor); err != nil { return AUMHash{}, err } - if len(toDelete) > 0 { - log.Printf("tka compaction: purging %d AUM(s) [%q]", len(toDelete), toDelete) - } return lastActiveAncestor, storage.PurgeAUMs(toDelete) } diff --git a/tka/tailchonk_test.go b/tka/tailchonk_test.go index afc4f4de0..125fbecc0 100644 --- a/tka/tailchonk_test.go +++ b/tka/tailchonk_test.go @@ -315,11 +315,6 @@ func TestMarkDescendantAUMs(t *testing.T) { } func TestMarkAncestorIntersectionAUMs(t *testing.T) { - fakeState := &State{ - Keys: []Key{{Kind: Key25519, Votes: 1}}, - DisablementValues: [][]byte{bytes.Repeat([]byte{1}, 32)}, - } - tcs := []struct { name string chain *testChain @@ -333,7 +328,7 @@ func TestMarkAncestorIntersectionAUMs(t *testing.T) { name: "genesis", chain: newTestchain(t, ` A - A.template = checkpoint`, optTemplate("checkpoint", AUM{MessageKind: AUMCheckpoint, State: fakeState})), + A.template = checkpoint`, checkpointTemplate()), initialAncestor: "A", wantAncestor: "A", verdicts: map[string]retainState{ @@ -346,7 +341,7 @@ func TestMarkAncestorIntersectionAUMs(t *testing.T) { chain: newTestchain(t, ` DEAD -> A -> B -> C A.template = checkpoint - B.template = checkpoint`, optTemplate("checkpoint", AUM{MessageKind: AUMCheckpoint, State: fakeState})), + B.template = checkpoint`, checkpointTemplate()), initialAncestor: "A", wantAncestor: "A", verdicts: map[string]retainState{ @@ -366,7 +361,7 @@ func TestMarkAncestorIntersectionAUMs(t *testing.T) { A.template = checkpoint C.template = checkpoint D.template = checkpoint - FORK.hashSeed = 2`, optTemplate("checkpoint", AUM{MessageKind: AUMCheckpoint, State: fakeState})), + FORK.hashSeed = 2`, checkpointTemplate()), initialAncestor: "D", wantAncestor: "C", verdicts: map[string]retainState{ @@ -387,7 +382,7 @@ func TestMarkAncestorIntersectionAUMs(t *testing.T) { A.template = checkpoint B.template = checkpoint E.template = checkpoint - FORK.hashSeed = 2`, optTemplate("checkpoint", AUM{MessageKind: AUMCheckpoint, State: fakeState})), + FORK.hashSeed = 2`, checkpointTemplate()), initialAncestor: "E", wantAncestor: "B", verdicts: map[string]retainState{ @@ -413,7 +408,7 @@ func TestMarkAncestorIntersectionAUMs(t *testing.T) { D.template = checkpoint E.template = checkpoint FORK.hashSeed = 2 - DEADFORK.hashSeed = 3`, optTemplate("checkpoint", AUM{MessageKind: AUMCheckpoint, State: fakeState})), + DEADFORK.hashSeed = 3`, checkpointTemplate()), initialAncestor: "D", wantAncestor: "C", verdicts: map[string]retainState{ @@ -443,7 +438,7 @@ func TestMarkAncestorIntersectionAUMs(t *testing.T) { F.template = checkpoint F1.hashSeed = 2 F2.hashSeed = 3 - F3.hashSeed = 4`, optTemplate("checkpoint", AUM{MessageKind: AUMCheckpoint, State: fakeState})), + F3.hashSeed = 4`, checkpointTemplate()), initialAncestor: "F", wantAncestor: "B", verdicts: map[string]retainState{ @@ -541,11 +536,6 @@ func cloneMem(src, dst *Mem) { } func TestCompact(t *testing.T) { - fakeState := &State{ - Keys: []Key{{Kind: Key25519, Votes: 1}}, - DisablementValues: [][]byte{bytes.Repeat([]byte{1}, 32)}, - } - // A & B are deleted because the new lastActiveAncestor advances beyond them. // OLD is deleted because it does not match retention criteria, and // though it is a descendant of the new lastActiveAncestor (C), it is not a @@ -578,7 +568,7 @@ func TestCompact(t *testing.T) { F1.hashSeed = 1 OLD.hashSeed = 2 G2.hashSeed = 3 - `, optTemplate("checkpoint", AUM{MessageKind: AUMCheckpoint, State: fakeState})) + `, checkpointTemplate()) storage := &compactingChonkFake{ aumAge: map[AUMHash]time.Time{(c.AUMHashes["F1"]): time.Now()}, @@ -607,12 +597,10 @@ func TestCompactLongButYoung(t *testing.T) { ourPriv := key.NewNLPrivate() ourKey := Key{Kind: Key25519, Public: ourPriv.Public().Verifier(), Votes: 1} someOtherKey := Key{Kind: Key25519, Public: key.NewNLPrivate().Public().Verifier(), Votes: 1} + state := CreateStateForTest(ourKey, someOtherKey) storage := ChonkMem() - auth, _, err := Create(storage, State{ - Keys: []Key{ourKey, someOtherKey}, - DisablementValues: [][]byte{DisablementKDF(bytes.Repeat([]byte{0xa5}, 32))}, - }, ourPriv) + auth, _, err := Create(storage, state, ourPriv) if err != nil { t.Fatalf("tka.Create() failed: %v", err) } diff --git a/tka/tka.go b/tka/tka.go index 57a8bd122..cb1c08326 100644 --- a/tka/tka.go +++ b/tka/tka.go @@ -10,7 +10,6 @@ import ( "bytes" "errors" "fmt" - "log" "os" "sort" @@ -557,8 +556,6 @@ func Bootstrap(storage Chonk, bootstrap AUM) (*Authority, error) { // Everything looks good, write it to storage. if err := storage.CommitVerifiedAUMs([]AUM{bootstrap}); err != nil { return nil, fmt.Errorf("commit: %v", err) - } else { - log.Printf("tka.Bootstrap: successfully committed bootstrap AUM (%s)", bootstrap.Hash()) } if err := storage.SetLastActiveAncestor(bootstrap.Hash()); err != nil { return nil, fmt.Errorf("set ancestor: %v", err) @@ -570,7 +567,7 @@ func Bootstrap(storage Chonk, bootstrap AUM) (*Authority, error) { // ValidDisablement returns true if the disablement secret was correct. // // If this method returns true, the caller should shut down the authority -// and purge all network-lock state. +// and purge all tailnet-lock state. func (a *Authority) ValidDisablement(secret []byte) bool { return a.state.checkDisablement(secret) } @@ -590,7 +587,6 @@ func (a *Authority) InformIdempotent(storage Chonk, updates []AUM) (Authority, e } stateAt := make(map[AUMHash]State, len(updates)+1) toCommit := make([]AUM, 0, len(updates)) - toCommitHashes := make([]AUMHash, 0, len(updates)) prevHash := a.Head() // The state at HEAD is the current state of the authority. It's likely @@ -640,13 +636,10 @@ func (a *Authority) InformIdempotent(storage Chonk, updates []AUM) (Authority, e } prevHash = hash toCommit = append(toCommit, update) - toCommitHashes = append(toCommitHashes, update.Hash()) } if err := storage.CommitVerifiedAUMs(toCommit); err != nil { return Authority{}, fmt.Errorf("commit: %v", err) - } else { - log.Printf("tka.CommitVerifiedAUMs: successfully committed %d AUMs: %v", len(toCommit), toCommitHashes) } if isHeadChain { diff --git a/tka/tka_test.go b/tka/tka_test.go index cb18be68d..f0ec3ff68 100644 --- a/tka/tka_test.go +++ b/tka/tka_test.go @@ -197,6 +197,7 @@ func TestComputeStateAt(t *testing.T) { // for tests you want one AUM to be 'lower' than another, so that // that chain is taken based on fork resolution rules). func fakeAUM(t *testing.T, template any, parent *AUMHash) (AUM, AUMHash) { + t.Helper() if seed, ok := template.(int); ok { a := AUM{MessageKind: AUMNoOp, KeyID: []byte{byte(seed)}} if parent != nil { @@ -299,15 +300,17 @@ func TestAuthorityHead(t *testing.T) { func TestAuthorityValidDisablement(t *testing.T) { pub, _ := testingKey25519(t, 1) key := Key{Kind: Key25519, Public: pub, Votes: 2} + disablementSecret := []byte{1, 2, 3} + state := State{ + Keys: []Key{key}, + DisablementValues: [][]byte{DisablementKDF(disablementSecret)}, + } c := newTestchain(t, ` G1 -> L1 G1.template = genesis `, - optTemplate("genesis", AUM{MessageKind: AUMCheckpoint, State: &State{ - Keys: []Key{key}, - DisablementValues: [][]byte{DisablementKDF([]byte{1, 2, 3})}, - }}), + optTemplate("genesis", AUM{MessageKind: AUMCheckpoint, State: &state}), ) a, _ := Open(c.Chonk()) @@ -320,10 +323,7 @@ func TestCreateBootstrapAuthority(t *testing.T) { pub, priv := testingKey25519(t, 1) key := Key{Kind: Key25519, Public: pub, Votes: 2} - a1, genesisAUM, err := Create(ChonkMem(), State{ - Keys: []Key{key}, - DisablementValues: [][]byte{DisablementKDF([]byte{1, 2, 3})}, - }, signer25519(priv)) + a1, genesisAUM, err := Create(ChonkMem(), CreateStateForTest(key), signer25519(priv)) if err != nil { t.Fatalf("Create() failed: %v", err) } @@ -352,10 +352,7 @@ func TestBootstrapChonkMustBeEmpty(t *testing.T) { pub, priv := testingKey25519(t, 1) key := Key{Kind: Key25519, Public: pub, Votes: 2} - state := State{ - Keys: []Key{key}, - DisablementValues: [][]byte{DisablementKDF([]byte{1, 2, 3})}, - } + state := CreateStateForTest(key) // Bootstrap our chonk for the first time, which should succeed. _, _, err := Create(chonk, state, signer25519(priv)) @@ -415,14 +412,11 @@ func TestAuthorityInformNonLinear(t *testing.T) { | -> L4 -> L5 G1.template = genesis - L1.hashSeed = 3 + L1.hashSeed = 2 L2.hashSeed = 2 L4.hashSeed = 2 `, - optTemplate("genesis", AUM{MessageKind: AUMCheckpoint, State: &State{ - Keys: []Key{key}, - DisablementValues: [][]byte{DisablementKDF([]byte{1, 2, 3})}, - }}), + genesisTemplate(key), optKey("key", key, priv), optSignAllUsing("key")) @@ -451,6 +445,8 @@ func TestAuthorityInformNonLinear(t *testing.T) { } if a.Head() != c.AUMHashes["L3"] { + t.Logf("a.Head() = %s", a.Head()) + t.Logf("auMHashes = %v", c.AUMHashes) t.Fatal("authority did not converge to correct AUM") } } @@ -464,10 +460,7 @@ func TestAuthorityInformLinear(t *testing.T) { G1.template = genesis `, - optTemplate("genesis", AUM{MessageKind: AUMCheckpoint, State: &State{ - Keys: []Key{key}, - DisablementValues: [][]byte{DisablementKDF([]byte{1, 2, 3})}, - }}), + genesisTemplate(key), optKey("key", key, priv), optSignAllUsing("key")) @@ -504,21 +497,12 @@ func TestInteropWithNLKey(t *testing.T) { pub2 := key.NewNLPrivate().Public() pub3 := key.NewNLPrivate().Public() - a, _, err := Create(ChonkMem(), State{ - Keys: []Key{ - { - Kind: Key25519, - Votes: 1, - Public: pub1.KeyID(), - }, - { - Kind: Key25519, - Votes: 1, - Public: pub2.KeyID(), - }, - }, - DisablementValues: [][]byte{DisablementKDF([]byte{1, 2, 3})}, - }, priv1) + state := CreateStateForTest( + Key{Kind: Key25519, Votes: 1, Public: pub1.KeyID()}, + Key{Kind: Key25519, Votes: 1, Public: pub2.KeyID()}, + ) + + a, _, err := Create(ChonkMem(), state, priv1) if err != nil { t.Errorf("tka.Create: %v", err) return @@ -538,6 +522,7 @@ func TestInteropWithNLKey(t *testing.T) { func TestAuthorityCompact(t *testing.T) { pub, priv := testingKey25519(t, 1) key := Key{Kind: Key25519, Public: pub, Votes: 2} + state := CreateStateForTest(key) c := newTestchain(t, ` G -> A -> B -> C -> D -> E @@ -545,14 +530,8 @@ func TestAuthorityCompact(t *testing.T) { G.template = genesis C.template = checkpoint2 `, - optTemplate("genesis", AUM{MessageKind: AUMCheckpoint, State: &State{ - Keys: []Key{key}, - DisablementValues: [][]byte{DisablementKDF([]byte{1, 2, 3})}, - }}), - optTemplate("checkpoint2", AUM{MessageKind: AUMCheckpoint, State: &State{ - Keys: []Key{key}, - DisablementValues: [][]byte{DisablementKDF([]byte{1, 2, 3})}, - }}), + genesisTemplate(key), + optTemplate("checkpoint2", AUM{MessageKind: AUMCheckpoint, State: &state}), optKey("key", key, priv), optSignAllUsing("key")) @@ -602,10 +581,7 @@ func TestFindParentForRewrite(t *testing.T) { C.template = add3 D.template = remove2 `, - optTemplate("genesis", AUM{MessageKind: AUMCheckpoint, State: &State{ - Keys: []Key{k1}, - DisablementValues: [][]byte{DisablementKDF([]byte{1, 2, 3})}, - }}), + genesisTemplate(k1), optTemplate("add2", AUM{MessageKind: AUMAddKey, Key: &k2}), optTemplate("add3", AUM{MessageKind: AUMAddKey, Key: &k3}), optTemplate("remove2", AUM{MessageKind: AUMRemoveKey, KeyID: k2ID})) @@ -671,10 +647,7 @@ func TestMakeRetroactiveRevocation(t *testing.T) { C.template = add2 D.template = add3 `, - optTemplate("genesis", AUM{MessageKind: AUMCheckpoint, State: &State{ - Keys: []Key{k1}, - DisablementValues: [][]byte{DisablementKDF([]byte{1, 2, 3})}, - }}), + genesisTemplate(k1), optTemplate("add2", AUM{MessageKind: AUMAddKey, Key: &k2}), optTemplate("add3", AUM{MessageKind: AUMAddKey, Key: &k3})) diff --git a/tool/listpkgs/listpkgs.go b/tool/listpkgs/listpkgs.go index 1c2dda257..b29db94b1 100644 --- a/tool/listpkgs/listpkgs.go +++ b/tool/listpkgs/listpkgs.go @@ -10,9 +10,12 @@ import ( "flag" "fmt" "go/build/constraint" + "io/fs" "log" "os" + "path/filepath" "slices" + "sort" "strings" "sync" @@ -27,11 +30,18 @@ var ( withoutTagsAnyStr = flag.String("without-tags-any", "", "if non-empty, a comma-separated list of build constraints to exclude (a package will be omitted if it contains any of these build tags)") shard = flag.String("shard", "", "if non-empty, a string of the form 'N/M' to only print packages in shard N of M (e.g. '1/3', '2/3', '3/3/' for different thirds of the list)") affectedByTag = flag.String("affected-by-tag", "", "if non-empty, only list packages whose test binary would be affected by the presence or absence of this build tag") + hasRootTests = flag.Bool("has-root-tests", false, "list packages (as ./relative/path) containing _test.go files that call tstest.RequireRoot") + hasGoGenerate = flag.Bool("has-go-generate", false, "only list packages that contain at least one //go:generate directive") ) func main() { flag.Parse() + if *hasRootTests { + printRootTestPkgs() + return + } + patterns := flag.Args() if len(patterns) == 0 { flag.Usage() @@ -112,6 +122,9 @@ Pkg: continue Pkg } } + if *hasGoGenerate && !pkgHasGoGenerate(pkg) { + continue Pkg + } matches++ if *shard != "" { @@ -281,3 +294,123 @@ func fileMentionsTag(filename, tag string) (bool, error) { } return tags[tag], nil } + +// pkgHasGoGenerate reports whether any source file in pkg contains a +// //go:generate directive. +func pkgHasGoGenerate(pkg *packages.Package) bool { + // Include IgnoredFiles so directives behind build constraints are still + // found; the caller can narrow by tag via -with-tags-all/-without-tags-any + // if they care. + all := slices.Concat(pkg.CompiledGoFiles, pkg.OtherFiles, pkg.IgnoredFiles) + for _, name := range all { + ok, err := fileHasGoGenerate(name) + if err != nil { + log.Printf("reading %s: %v", name, err) + continue + } + if ok { + return true + } + } + return false +} + +var ( + goGenerateMu sync.Mutex + goGenerate = map[string]bool{} // abs path -> whether file has //go:generate +) + +func fileHasGoGenerate(filename string) (bool, error) { + goGenerateMu.Lock() + v, ok := goGenerate[filename] + goGenerateMu.Unlock() + if ok { + return v, nil + } + + f, err := os.Open(filename) + if err != nil { + return false, err + } + defer f.Close() + + has := false + s := bufio.NewScanner(f) + for s.Scan() { + // go:generate directives must start at column 1 (no leading + // whitespace) to be recognized by the go tool. + if strings.HasPrefix(s.Text(), "//go:generate") { + has = true + break + } + } + if err := s.Err(); err != nil { + return false, fmt.Errorf("reading %s: %w", filename, err) + } + + goGenerateMu.Lock() + goGenerate[filename] = has + goGenerateMu.Unlock() + return has, nil +} + +// printRootTestPkgs walks the current directory tree looking for _test.go +// files that contain "tstest.RequireRoot" and prints the unique package +// directories as ./relative/path. +func printRootTestPkgs() { + root, err := os.Getwd() + if err != nil { + log.Fatal(err) + } + seen := map[string]bool{} + var dirs []string + filepath.WalkDir(root, func(path string, d fs.DirEntry, err error) error { + if err != nil { + return nil + } + name := d.Name() + if d.IsDir() { + // Skip hidden dirs and common non-Go dirs. + if strings.HasPrefix(name, ".") || name == "vendor" || name == "node_modules" { + return filepath.SkipDir + } + return nil + } + if !strings.HasSuffix(name, "_test.go") { + return nil + } + rel, err := filepath.Rel(root, path) + if err != nil { + return nil + } + dir := filepath.Dir(rel) + if seen[dir] { + return nil // already found a match in this dir + } + if fileContains(path, "tstest.RequireRoot") { + seen[dir] = true + dirs = append(dirs, dir) + } + return nil + }) + sort.Strings(dirs) + for _, d := range dirs { + fmt.Println("./" + filepath.ToSlash(d)) + } +} + +// fileContains reports whether the file at path contains the given substring. +func fileContains(path, substr string) bool { + f, err := os.Open(path) + if err != nil { + return false + } + defer f.Close() + s := bufio.NewScanner(f) + for s.Scan() { + if strings.Contains(s.Text(), substr) { + return true + } + } + return false +} diff --git a/tool/updateflakes/updateflakes.go b/tool/updateflakes/updateflakes.go new file mode 100644 index 000000000..e2a572d12 --- /dev/null +++ b/tool/updateflakes/updateflakes.go @@ -0,0 +1,264 @@ +// Copyright (c) Tailscale Inc & contributors +// SPDX-License-Identifier: BSD-3-Clause + +// updateflakes regenerates flakehashes.json, the file that records +// the Nix SRI hashes for the Go module vendor tree and the Tailscale +// Go toolchain tarball. +// +// The file is content-addressed: each block records the input +// fingerprint that produced its SRI, and updateflakes only +// regenerates a block when the current input differs from the +// recorded fingerprint. As a result, repeat runs with no input +// changes are no-ops. +// +// Run from the repo root: +// +// ./tool/go run ./tool/updateflakes +package main + +import ( + "bytes" + "crypto/sha256" + "encoding/base64" + "encoding/json" + "errors" + "flag" + "fmt" + "io/fs" + "log" + "net/http" + "os" + "os/exec" + "path/filepath" + "strings" + + "golang.org/x/sync/errgroup" + "tailscale.com/cmd/nardump/nardump" +) + +const ( + hashesFile = "flakehashes.json" + goModFile = "go.mod" + goSumFile = "go.sum" + toolchainRevFile = "go.toolchain.rev" + flakeNixFile = "flake.nix" + shellNixFile = "shell.nix" + cacheBustPrefix = "# nix-direnv cache busting line:" +) + +// FlakeHashes is the on-disk schema of flakehashes.json. It is also +// consumed directly by flake.nix via builtins.fromJSON, so changes +// to the JSON shape must be coordinated with flake.nix. +type FlakeHashes struct { + Toolchain ToolchainHash `json:"toolchain"` + Vendor VendorHash `json:"vendor"` +} + +// ToolchainHash records the SRI of the Tailscale Go toolchain +// tarball. Rev is the value in go.toolchain.rev that produced SRI. +type ToolchainHash struct { + Rev string `json:"rev"` + SRI string `json:"sri"` +} + +// VendorHash records the SRI of `go mod vendor` output. GoModSum is a +// fingerprint of go.mod and go.sum that produced SRI. +type VendorHash struct { + GoModSum string `json:"goModSum"` + SRI string `json:"sri"` +} + +func main() { + flag.Parse() + if err := run(); err != nil { + log.Fatal(err) + } +} + +func run() error { + have, err := loadHashes() + if err != nil { + return err + } + want := have + + rev, err := readTrim(toolchainRevFile) + if err != nil { + return err + } + wantToolchain := have.Toolchain.Rev != rev || have.Toolchain.SRI == "" + + goModSum, err := goModFingerprint() + if err != nil { + return err + } + wantVendor := have.Vendor.GoModSum != goModSum || have.Vendor.SRI == "" + + var ( + newToolchain ToolchainHash + newVendor VendorHash + ) + var g errgroup.Group + if wantToolchain { + g.Go(func() error { + sri, err := hashToolchain(rev) + if err != nil { + return err + } + newToolchain = ToolchainHash{Rev: rev, SRI: sri} + return nil + }) + } + if wantVendor { + g.Go(func() error { + sri, err := hashVendor() + if err != nil { + return err + } + newVendor = VendorHash{GoModSum: goModSum, SRI: sri} + return nil + }) + } + if err := g.Wait(); err != nil { + return err + } + if wantToolchain { + want.Toolchain = newToolchain + } + if wantVendor { + want.Vendor = newVendor + } + + if want != have { + if err := writeHashes(want); err != nil { + return err + } + } + + // nix-direnv only watches the top-level nix files for changes, + // so when a referenced hash changes we must also tickle + // flake.nix and shell.nix to force re-evaluation. + for _, f := range []string{flakeNixFile, shellNixFile} { + if err := updateCacheBust(f, want.Vendor.SRI); err != nil { + return err + } + } + return nil +} + +func loadHashes() (FlakeHashes, error) { + var h FlakeHashes + data, err := os.ReadFile(hashesFile) + if errors.Is(err, fs.ErrNotExist) { + return h, nil + } + if err != nil { + return h, err + } + if err := json.Unmarshal(data, &h); err != nil { + return h, fmt.Errorf("parse %s: %w", hashesFile, err) + } + return h, nil +} + +func writeHashes(h FlakeHashes) error { + b, err := json.MarshalIndent(h, "", " ") + if err != nil { + return err + } + b = append(b, '\n') + return os.WriteFile(hashesFile, b, 0644) +} + +func readTrim(path string) (string, error) { + b, err := os.ReadFile(path) + if err != nil { + return "", err + } + return strings.TrimSpace(string(b)), nil +} + +// goModFingerprint returns a content fingerprint of go.mod and go.sum +// that changes whenever either file changes. +func goModFingerprint() (string, error) { + h := sha256.New() + for _, f := range []string{goModFile, goSumFile} { + b, err := os.ReadFile(f) + if err != nil { + return "", err + } + fmt.Fprintf(h, "%s %d\n", f, len(b)) + h.Write(b) + } + return "sha256-" + base64.StdEncoding.EncodeToString(h.Sum(nil)), nil +} + +func hashVendor() (string, error) { + out, err := os.MkdirTemp("", "nar-vendor-") + if err != nil { + return "", err + } + // `go mod vendor -o` requires the destination to not already exist. + if err := os.Remove(out); err != nil { + return "", err + } + defer os.RemoveAll(out) + + cmd := exec.Command("./tool/go", "mod", "vendor", "-o", out) + cmd.Env = append(os.Environ(), "GOWORK=off") + cmd.Stderr = os.Stderr + if err := cmd.Run(); err != nil { + return "", fmt.Errorf("go mod vendor: %w", err) + } + return nardump.SRI(os.DirFS(out)) +} + +func hashToolchain(rev string) (string, error) { + out, err := os.MkdirTemp("", "nar-toolchain-") + if err != nil { + return "", err + } + defer os.RemoveAll(out) + + url := fmt.Sprintf("https://github.com/tailscale/go/archive/%s.tar.gz", rev) + resp, err := http.Get(url) + if err != nil { + return "", fmt.Errorf("fetching %s: %w", url, err) + } + defer resp.Body.Close() + if resp.StatusCode != http.StatusOK { + return "", fmt.Errorf("fetching %s: %s", url, resp.Status) + } + + tar := exec.Command("tar", "-xz", "-C", out) + tar.Stdin = resp.Body + tar.Stderr = os.Stderr + if err := tar.Run(); err != nil { + return "", fmt.Errorf("extracting toolchain tarball: %w", err) + } + return nardump.SRI(os.DirFS(filepath.Join(out, "go-"+rev))) +} + +// updateCacheBust rewrites the "# nix-direnv cache busting line" +// in path to embed sri so nix-direnv re-evaluates when the SRI +// changes. The line lives at end of file, so walk in reverse. +func updateCacheBust(path, sri string) error { + b, err := os.ReadFile(path) + if err != nil { + return err + } + want := []byte(cacheBustPrefix + " " + sri) + lines := bytes.Split(b, []byte("\n")) + for i := len(lines) - 1; i >= 0; i-- { + line := lines[i] + if !bytes.HasPrefix(line, []byte(cacheBustPrefix)) { + continue + } + if bytes.Equal(line, want) { + return nil + } + lines[i] = want + return os.WriteFile(path, bytes.Join(lines, []byte("\n")), 0644) + } + return fmt.Errorf("%s: missing %q line", path, cacheBustPrefix) +} diff --git a/tsconsensus/monitor.go b/tsconsensus/monitor.go index b937926a6..bf7410d0d 100644 --- a/tsconsensus/monitor.go +++ b/tsconsensus/monitor.go @@ -12,7 +12,6 @@ import ( "net/http" "slices" - "tailscale.com/ipn" "tailscale.com/ipn/ipnstate" "tailscale.com/tsnet" "tailscale.com/util/dnsname" @@ -108,24 +107,16 @@ func (m *monitor) handleNetmap(w http.ResponseWriter, r *http.Request) { http.Error(w, "", http.StatusInternalServerError) return } - watcher, err := lc.WatchIPNBus(r.Context(), ipn.NotifyInitialNetMap) + st, err := lc.Status(r.Context()) if err != nil { - log.Printf("monitor: error WatchIPNBus: %v", err) - http.Error(w, "", http.StatusInternalServerError) - return - } - defer watcher.Close() - - n, err := watcher.Next() - if err != nil { - log.Printf("monitor: error watcher.Next: %v", err) + log.Printf("monitor: error fetching status: %v", err) http.Error(w, "", http.StatusInternalServerError) return } encoder := json.NewEncoder(w) encoder.SetIndent("", "\t") - if err := encoder.Encode(n); err != nil { - log.Printf("monitor: error encoding netmap: %v", err) + if err := encoder.Encode(st); err != nil { + log.Printf("monitor: error encoding status: %v", err) return } } diff --git a/tsd/tsd.go b/tsd/tsd.go index 615c9c0e7..98b6a6bd3 100644 --- a/tsd/tsd.go +++ b/tsd/tsd.go @@ -20,6 +20,7 @@ package tsd import ( "crypto/x509" "fmt" + "net/http" "reflect" "tailscale.com/control/controlknobs" @@ -64,6 +65,10 @@ type System struct { PolicyClient SubSystem[policyclient.Client] HealthTracker SubSystem[*health.Tracker] + // NoiseRoundTripper, if set, provides an http.RoundTripper that + // sends requests over the control plane Noise connection. + NoiseRoundTripper SubSystem[http.RoundTripper] + // ExtraRootCAs, if non-nil, specifies additional trusted root CAs // beyond the system roots. On Android, this includes user-installed // CA certificates that Go's crypto/x509 does not see. diff --git a/tsnet/README.md b/tsnet/README.md new file mode 100644 index 000000000..f9a96af00 --- /dev/null +++ b/tsnet/README.md @@ -0,0 +1,109 @@ + + +# tsnet + +[![Go Reference](https://pkg.go.dev/badge/tailscale.com/tsnet.svg)](https://pkg.go.dev/tailscale.com/tsnet) + +Package tsnet embeds a Tailscale node directly into a Go program, allowing it to join a tailnet and accept or dial connections without running a separate tailscaled daemon or requiring any system-level configuration. + +## Overview + +Normally, Tailscale runs as a background system service (tailscaled) that manages a virtual network interface for the whole machine. tsnet takes a different approach: it runs a fully self-contained Tailscale node inside your process using a userspace TCP/IP stack (gVisor). This means: + + - No root privileges required. + - No system daemons to install or manage. + - Multiple independent Tailscale nodes can run within a single binary. + - The node's [Tailscale identity](https://tailscale.com/docs/concepts/tailscale-identity) and state are stored in a directory you control. + +The core type is [Server](https://pkg.go.dev/tailscale.com/tsnet#Server), which represents one embedded Tailscale node. Calling [Server.Listen](https://pkg.go.dev/tailscale.com/tsnet#Server.Listen) or [Server.Dial](https://pkg.go.dev/tailscale.com/tsnet#Server.Dial) routes traffic exclusively over the tailnet. The standard library's [net.Listener](https://pkg.go.dev/net#Listener) and [net.Conn](https://pkg.go.dev/net#Conn) interfaces are returned, so any existing Go HTTP server, gRPC server, or other net-based code works without modification. + +## Usage + + import "tailscale.com/tsnet" + + s := &tsnet.Server{ + Hostname: "my-service", + AuthKey: os.Getenv("TS_AUTHKEY"), + } + defer s.Close() + + ln, err := s.Listen("tcp", ":80") + if err != nil { + log.Fatal(err) + } + log.Fatal(http.Serve(ln, myHandler)) + +On first run, if no [Server.AuthKey](https://pkg.go.dev/tailscale.com/tsnet#Server.AuthKey) is provided and the node is not already enrolled, the server logs an authentication URL. Open it in a browser to add the node to your tailnet. + +## Authentication + +A [Server](https://pkg.go.dev/tailscale.com/tsnet#Server) authenticates using, in order of precedence: + + 1. [Server.AuthKey](https://pkg.go.dev/tailscale.com/tsnet#Server.AuthKey). + + 2. The TS\_AUTHKEY environment variable. + + 3. The TS\_AUTH\_KEY environment variable. + + 4. An OAuth client secret ([Server.ClientSecret](https://pkg.go.dev/tailscale.com/tsnet#Server.ClientSecret) or TS\_CLIENT\_SECRET), used to mint an auth key. + + 5. Workload identity federation ([Server.ClientID](https://pkg.go.dev/tailscale.com/tsnet#Server.ClientID) plus [Server.IDToken](https://pkg.go.dev/tailscale.com/tsnet#Server.IDToken) or [Server.Audience](https://pkg.go.dev/tailscale.com/tsnet#Server.Audience)). Available only if the program imports the feature: + + import \_ "tailscale.com/feature/identityfederation" + + The feature is not linked by default to keep the AWS SDK and other cloud-provider dependencies out of programs that don't use workload identity federation. + + 6. An interactive login URL printed to [Server.UserLogf](https://pkg.go.dev/tailscale.com/tsnet#Server.UserLogf). + +If the node is already enrolled (state found in [Server.Store](https://pkg.go.dev/tailscale.com/tsnet#Server.Store)), the auth key is ignored unless TSNET\_FORCE\_LOGIN=1 is set. + +## Identifying callers + +Use the WhoIs method on the client returned by [Server.LocalClient](https://pkg.go.dev/tailscale.com/tsnet#Server.LocalClient) to identify who is making a request: + + lc, _ := srv.LocalClient() + http.Serve(ln, http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + who, err := lc.WhoIs(r.Context(), r.RemoteAddr) + if err != nil { + http.Error(w, err.Error(), 500) + return + } + fmt.Fprintf(w, "Hello, %s!", who.UserProfile.LoginName) + })) + +## Tailscale Funnel + +[Server.ListenFunnel](https://pkg.go.dev/tailscale.com/tsnet#Server.ListenFunnel) exposes your service on the public internet. [Tailscale Funnel](https://tailscale.com/docs/features/tailscale-funnel) currently supports TCP on ports 443, 8443, and 10000. HTTPS must be enabled in the Tailscale admin console. + + ln, err := srv.ListenFunnel("tcp", ":443") + // ln is a TLS listener; connections can come from anywhere on the + // internet as well as from your tailnet. + + // To restrict to public traffic only: + ln, err = srv.ListenFunnel("tcp", ":443", tsnet.FunnelOnly()) + +## Tailscale Services + +[Server.ListenService](https://pkg.go.dev/tailscale.com/tsnet#Server.ListenService) advertises the node as a host for a named [Tailscale Service](https://tailscale.com/docs/features/tailscale-services). The node must use a tag-based identity. To advertise multiple ports, call ListenService once per port. + + srv.AdvertiseTags = []string{"tag:myservice"} + + ln, err := srv.ListenService("svc:my-service", tsnet.ServiceModeHTTP{ + HTTPS: true, + Port: 443, + }) + log.Printf("Listening on https://%s", ln.FQDN) + +## Running multiple nodes in one process + +Each [Server](https://pkg.go.dev/tailscale.com/tsnet#Server) instance is an independent node. Give each a unique [Server.Dir](https://pkg.go.dev/tailscale.com/tsnet#Server.Dir) and [Server.Hostname](https://pkg.go.dev/tailscale.com/tsnet#Server.Hostname): + + for _, name := range []string{"frontend", "backend"} { + srv := &tsnet.Server{ + Hostname: name, + Dir: filepath.Join(baseDir, name), + AuthKey: os.Getenv("TS_AUTHKEY"), + Ephemeral: true, + } + srv.Start() + } diff --git a/tsnet/depaware.txt b/tsnet/depaware.txt index b8b6aec98..a4eed2a13 100644 --- a/tsnet/depaware.txt +++ b/tsnet/depaware.txt @@ -6,77 +6,6 @@ tailscale.com/tsnet dependencies: (generated by github.com/tailscale/depaware) W 💣 github.com/alexbrainman/sspi from github.com/alexbrainman/sspi/internal/common+ W github.com/alexbrainman/sspi/internal/common from github.com/alexbrainman/sspi/negotiate W 💣 github.com/alexbrainman/sspi/negotiate from tailscale.com/net/tshttpproxy - github.com/aws/aws-sdk-go-v2/aws from github.com/aws/aws-sdk-go-v2/aws/defaults+ - github.com/aws/aws-sdk-go-v2/aws/defaults from github.com/aws/aws-sdk-go-v2/service/sso+ - github.com/aws/aws-sdk-go-v2/aws/middleware from github.com/aws/aws-sdk-go-v2/aws/retry+ - github.com/aws/aws-sdk-go-v2/aws/protocol/query from github.com/aws/aws-sdk-go-v2/service/sts - github.com/aws/aws-sdk-go-v2/aws/protocol/restjson from github.com/aws/aws-sdk-go-v2/service/sso+ - github.com/aws/aws-sdk-go-v2/aws/protocol/xml from github.com/aws/aws-sdk-go-v2/service/sts - github.com/aws/aws-sdk-go-v2/aws/ratelimit from github.com/aws/aws-sdk-go-v2/aws/retry - github.com/aws/aws-sdk-go-v2/aws/retry from github.com/aws/aws-sdk-go-v2/credentials/endpointcreds/internal/client+ - github.com/aws/aws-sdk-go-v2/aws/signer/internal/v4 from github.com/aws/aws-sdk-go-v2/aws/signer/v4 - github.com/aws/aws-sdk-go-v2/aws/signer/v4 from github.com/aws/aws-sdk-go-v2/internal/auth/smithy+ - github.com/aws/aws-sdk-go-v2/aws/transport/http from github.com/aws/aws-sdk-go-v2/config+ - github.com/aws/aws-sdk-go-v2/config from tailscale.com/wif - github.com/aws/aws-sdk-go-v2/credentials from github.com/aws/aws-sdk-go-v2/config - github.com/aws/aws-sdk-go-v2/credentials/ec2rolecreds from github.com/aws/aws-sdk-go-v2/config - github.com/aws/aws-sdk-go-v2/credentials/endpointcreds from github.com/aws/aws-sdk-go-v2/config - github.com/aws/aws-sdk-go-v2/credentials/endpointcreds/internal/client from github.com/aws/aws-sdk-go-v2/credentials/endpointcreds - github.com/aws/aws-sdk-go-v2/credentials/processcreds from github.com/aws/aws-sdk-go-v2/config - github.com/aws/aws-sdk-go-v2/credentials/ssocreds from github.com/aws/aws-sdk-go-v2/config - github.com/aws/aws-sdk-go-v2/credentials/stscreds from github.com/aws/aws-sdk-go-v2/config - github.com/aws/aws-sdk-go-v2/feature/ec2/imds from github.com/aws/aws-sdk-go-v2/config+ - github.com/aws/aws-sdk-go-v2/feature/ec2/imds/internal/config from github.com/aws/aws-sdk-go-v2/feature/ec2/imds - github.com/aws/aws-sdk-go-v2/internal/auth from github.com/aws/aws-sdk-go-v2/aws/signer/v4+ - github.com/aws/aws-sdk-go-v2/internal/auth/smithy from github.com/aws/aws-sdk-go-v2/service/sso+ - github.com/aws/aws-sdk-go-v2/internal/configsources from github.com/aws/aws-sdk-go-v2/service/sso+ - github.com/aws/aws-sdk-go-v2/internal/context from github.com/aws/aws-sdk-go-v2/aws/retry+ - github.com/aws/aws-sdk-go-v2/internal/endpoints from github.com/aws/aws-sdk-go-v2/service/sso+ - github.com/aws/aws-sdk-go-v2/internal/endpoints/awsrulesfn from github.com/aws/aws-sdk-go-v2/service/sso+ - github.com/aws/aws-sdk-go-v2/internal/endpoints/v2 from github.com/aws/aws-sdk-go-v2/service/sso/internal/endpoints+ - github.com/aws/aws-sdk-go-v2/internal/ini from github.com/aws/aws-sdk-go-v2/config - github.com/aws/aws-sdk-go-v2/internal/middleware from github.com/aws/aws-sdk-go-v2/service/sso+ - github.com/aws/aws-sdk-go-v2/internal/rand from github.com/aws/aws-sdk-go-v2/aws+ - github.com/aws/aws-sdk-go-v2/internal/sdk from github.com/aws/aws-sdk-go-v2/aws+ - github.com/aws/aws-sdk-go-v2/internal/sdkio from github.com/aws/aws-sdk-go-v2/credentials/processcreds - github.com/aws/aws-sdk-go-v2/internal/shareddefaults from github.com/aws/aws-sdk-go-v2/config+ - github.com/aws/aws-sdk-go-v2/internal/strings from github.com/aws/aws-sdk-go-v2/aws/signer/internal/v4 - github.com/aws/aws-sdk-go-v2/internal/sync/singleflight from github.com/aws/aws-sdk-go-v2/aws - github.com/aws/aws-sdk-go-v2/internal/timeconv from github.com/aws/aws-sdk-go-v2/aws/retry - github.com/aws/aws-sdk-go-v2/service/internal/accept-encoding from github.com/aws/aws-sdk-go-v2/service/sts - github.com/aws/aws-sdk-go-v2/service/internal/presigned-url from github.com/aws/aws-sdk-go-v2/service/sts - github.com/aws/aws-sdk-go-v2/service/sso from github.com/aws/aws-sdk-go-v2/config+ - github.com/aws/aws-sdk-go-v2/service/sso/internal/endpoints from github.com/aws/aws-sdk-go-v2/service/sso - github.com/aws/aws-sdk-go-v2/service/sso/types from github.com/aws/aws-sdk-go-v2/service/sso - github.com/aws/aws-sdk-go-v2/service/ssooidc from github.com/aws/aws-sdk-go-v2/config+ - github.com/aws/aws-sdk-go-v2/service/ssooidc/internal/endpoints from github.com/aws/aws-sdk-go-v2/service/ssooidc - github.com/aws/aws-sdk-go-v2/service/ssooidc/types from github.com/aws/aws-sdk-go-v2/service/ssooidc - github.com/aws/aws-sdk-go-v2/service/sts from github.com/aws/aws-sdk-go-v2/config+ - github.com/aws/aws-sdk-go-v2/service/sts/internal/endpoints from github.com/aws/aws-sdk-go-v2/service/sts - github.com/aws/aws-sdk-go-v2/service/sts/types from github.com/aws/aws-sdk-go-v2/credentials/stscreds+ - github.com/aws/smithy-go from github.com/aws/aws-sdk-go-v2/aws/protocol/restjson+ - github.com/aws/smithy-go/auth from github.com/aws/aws-sdk-go-v2/internal/auth+ - github.com/aws/smithy-go/auth/bearer from github.com/aws/aws-sdk-go-v2/aws+ - github.com/aws/smithy-go/context from github.com/aws/smithy-go/auth/bearer - github.com/aws/smithy-go/document from github.com/aws/aws-sdk-go-v2/service/sso+ - github.com/aws/smithy-go/encoding from github.com/aws/smithy-go/encoding/json+ - github.com/aws/smithy-go/encoding/httpbinding from github.com/aws/aws-sdk-go-v2/aws/protocol/query+ - github.com/aws/smithy-go/encoding/json from github.com/aws/aws-sdk-go-v2/service/ssooidc - github.com/aws/smithy-go/encoding/xml from github.com/aws/aws-sdk-go-v2/service/sts - github.com/aws/smithy-go/endpoints from github.com/aws/aws-sdk-go-v2/service/sso+ - github.com/aws/smithy-go/endpoints/private/rulesfn from github.com/aws/aws-sdk-go-v2/service/sts - github.com/aws/smithy-go/internal/sync/singleflight from github.com/aws/smithy-go/auth/bearer - github.com/aws/smithy-go/io from github.com/aws/aws-sdk-go-v2/feature/ec2/imds+ - github.com/aws/smithy-go/logging from github.com/aws/aws-sdk-go-v2/aws+ - github.com/aws/smithy-go/metrics from github.com/aws/aws-sdk-go-v2/aws/retry+ - github.com/aws/smithy-go/middleware from github.com/aws/aws-sdk-go-v2/aws+ - github.com/aws/smithy-go/private/requestcompression from github.com/aws/aws-sdk-go-v2/config - github.com/aws/smithy-go/ptr from github.com/aws/aws-sdk-go-v2/aws+ - github.com/aws/smithy-go/rand from github.com/aws/aws-sdk-go-v2/aws/middleware - github.com/aws/smithy-go/time from github.com/aws/aws-sdk-go-v2/service/sso+ - github.com/aws/smithy-go/tracing from github.com/aws/aws-sdk-go-v2/aws/middleware+ - github.com/aws/smithy-go/transport/http from github.com/aws/aws-sdk-go-v2/aws+ - github.com/aws/smithy-go/transport/http/internal/io from github.com/aws/smithy-go/transport/http LDW github.com/coder/websocket from tailscale.com/util/eventbus LDW github.com/coder/websocket/internal/errd from github.com/coder/websocket LDW github.com/coder/websocket/internal/util from github.com/coder/websocket @@ -105,7 +34,6 @@ tailscale.com/tsnet dependencies: (generated by github.com/tailscale/depaware) L 💣 github.com/godbus/dbus/v5 from tailscale.com/net/dns github.com/golang/groupcache/lru from tailscale.com/net/dnscache github.com/google/btree from gvisor.dev/gvisor/pkg/tcpip/transport/tcp - DI github.com/google/uuid from github.com/prometheus-community/pro-bing github.com/hdevalence/ed25519consensus from tailscale.com/tka github.com/huin/goupnp from github.com/huin/goupnp/dcps/internetgateway2+ github.com/huin/goupnp/dcps/internetgateway2 from tailscale.com/net/portmapper @@ -128,9 +56,8 @@ tailscale.com/tsnet dependencies: (generated by github.com/tailscale/depaware) LA 💣 github.com/mdlayher/socket from github.com/mdlayher/netlink+ LDW 💣 github.com/mitchellh/go-ps from tailscale.com/safesocket github.com/pires/go-proxyproto from tailscale.com/ipn/ipnlocal - DI github.com/prometheus-community/pro-bing from tailscale.com/wgengine/netstack L 💣 github.com/safchain/ethtool from tailscale.com/net/netkernelconf - W 💣 github.com/tailscale/certstore from tailscale.com/control/controlclient + DW 💣 github.com/tailscale/certstore from tailscale.com/control/controlclient W 💣 github.com/tailscale/go-winio from tailscale.com/safesocket W 💣 github.com/tailscale/go-winio/internal/fs from github.com/tailscale/go-winio W 💣 github.com/tailscale/go-winio/internal/socket from github.com/tailscale/go-winio @@ -219,11 +146,9 @@ tailscale.com/tsnet dependencies: (generated by github.com/tailscale/depaware) tailscale.com/feature/buildfeatures from tailscale.com/wgengine/magicsock+ tailscale.com/feature/c2n from tailscale.com/tsnet tailscale.com/feature/condlite/expvar from tailscale.com/wgengine/magicsock - tailscale.com/feature/condregister/identityfederation from tailscale.com/tsnet tailscale.com/feature/condregister/oauthkey from tailscale.com/tsnet tailscale.com/feature/condregister/portmapper from tailscale.com/tsnet tailscale.com/feature/condregister/useproxy from tailscale.com/tsnet - tailscale.com/feature/identityfederation from tailscale.com/feature/condregister/identityfederation tailscale.com/feature/oauthkey from tailscale.com/feature/condregister/oauthkey tailscale.com/feature/portmapper from tailscale.com/feature/condregister/portmapper tailscale.com/feature/syspolicy from tailscale.com/logpolicy @@ -304,7 +229,7 @@ tailscale.com/tsnet dependencies: (generated by github.com/tailscale/depaware) tailscale.com/tstime from tailscale.com/control/controlclient+ tailscale.com/tstime/mono from tailscale.com/net/tstun+ tailscale.com/tstime/rate from tailscale.com/wgengine/filter - LDW tailscale.com/tsweb from tailscale.com/util/eventbus + LDW tailscale.com/tsweb from tailscale.com/util/eventbus+ tailscale.com/tsweb/varz from tailscale.com/tsweb+ tailscale.com/types/appctype from tailscale.com/ipn/ipnlocal+ tailscale.com/types/bools from tailscale.com/tsnet+ @@ -394,7 +319,6 @@ tailscale.com/tsnet dependencies: (generated by github.com/tailscale/depaware) tailscale.com/wgengine/wgcfg/nmcfg from tailscale.com/ipn/ipnlocal 💣 tailscale.com/wgengine/wgint from tailscale.com/wgengine+ tailscale.com/wgengine/wglog from tailscale.com/wgengine - tailscale.com/wif from tailscale.com/feature/identityfederation golang.org/x/crypto/argon2 from tailscale.com/tka golang.org/x/crypto/blake2b from golang.org/x/crypto/argon2+ golang.org/x/crypto/blake2s from github.com/tailscale/wireguard-go/device+ @@ -414,16 +338,16 @@ tailscale.com/tsnet dependencies: (generated by github.com/tailscale/depaware) golang.org/x/net/dns/dnsmessage from tailscale.com/appc+ golang.org/x/net/http/httpguts from tailscale.com/ipn/ipnlocal golang.org/x/net/http/httpproxy from tailscale.com/net/tshttpproxy - golang.org/x/net/icmp from github.com/prometheus-community/pro-bing+ + golang.org/x/net/icmp from tailscale.com/net/ping golang.org/x/net/idna from golang.org/x/net/http/httpguts+ golang.org/x/net/internal/iana from golang.org/x/net/icmp+ - golang.org/x/net/internal/socket from golang.org/x/net/icmp+ + golang.org/x/net/internal/socket from golang.org/x/net/ipv4+ LDW golang.org/x/net/internal/socks from golang.org/x/net/proxy - golang.org/x/net/ipv4 from github.com/prometheus-community/pro-bing+ - golang.org/x/net/ipv6 from github.com/prometheus-community/pro-bing+ + golang.org/x/net/ipv4 from github.com/tailscale/wireguard-go/conn+ + golang.org/x/net/ipv6 from github.com/tailscale/wireguard-go/conn+ LDW golang.org/x/net/proxy from tailscale.com/net/netns DI golang.org/x/net/route from tailscale.com/net/netmon+ - golang.org/x/oauth2 from golang.org/x/oauth2/clientcredentials+ + golang.org/x/oauth2 from golang.org/x/oauth2/clientcredentials golang.org/x/oauth2/clientcredentials from tailscale.com/feature/oauthkey golang.org/x/oauth2/internal from golang.org/x/oauth2+ golang.org/x/sync/errgroup from github.com/mdlayher/socket+ @@ -526,12 +450,11 @@ tailscale.com/tsnet dependencies: (generated by github.com/tailscale/depaware) crypto/sha3 from crypto/internal/fips140hash+ crypto/sha512 from crypto/ecdsa+ crypto/subtle from crypto/cipher+ - crypto/tls from github.com/prometheus-community/pro-bing+ + crypto/tls from net/http+ crypto/tls/internal/fips140tls from crypto/tls crypto/x509 from crypto/tls+ DI crypto/x509/internal/macos from crypto/x509 crypto/x509/pkix from crypto/x509+ - DI database/sql/driver from github.com/google/uuid W debug/dwarf from debug/pe W debug/pe from github.com/dblohm7/wingoes/pe embed from github.com/tailscale/web-client-prebuilt+ @@ -620,7 +543,7 @@ tailscale.com/tsnet dependencies: (generated by github.com/tailscale/depaware) mime/quotedprintable from mime/multipart net from crypto/tls+ net/http from expvar+ - net/http/httptrace from github.com/prometheus-community/pro-bing+ + net/http/httptrace from net/http+ net/http/httputil from tailscale.com/client/web+ net/http/internal from net/http+ net/http/internal/ascii from net/http+ @@ -634,7 +557,7 @@ tailscale.com/tsnet dependencies: (generated by github.com/tailscale/depaware) os/user from github.com/godbus/dbus/v5+ path from debug/dwarf+ path/filepath from crypto/x509+ - reflect from database/sql/driver+ + reflect from encoding/asn1+ regexp from github.com/huin/goupnp/httpu+ regexp/syntax from regexp runtime from crypto/internal/fips140+ diff --git a/tsnet/example/tshello/README.md b/tsnet/example/tshello/README.md new file mode 100644 index 000000000..5d9d81829 --- /dev/null +++ b/tsnet/example/tshello/README.md @@ -0,0 +1,5 @@ + + +# tshello + +The tshello server demonstrates how to use Tailscale as a library. diff --git a/tsnet/example/tsnet-funnel/README.md b/tsnet/example/tsnet-funnel/README.md new file mode 100644 index 000000000..2b3031bed --- /dev/null +++ b/tsnet/example/tsnet-funnel/README.md @@ -0,0 +1,9 @@ + + +# tsnet-funnel + +The tsnet-funnel server demonstrates how to use tsnet with Funnel. + +To use it, generate an auth key from the Tailscale admin panel and run the demo with the key: + + TS_AUTHKEY= go run tsnet-funnel.go diff --git a/tsnet/example/tsnet-http-client/README.md b/tsnet/example/tsnet-http-client/README.md new file mode 100644 index 000000000..24aba97c8 --- /dev/null +++ b/tsnet/example/tsnet-http-client/README.md @@ -0,0 +1,5 @@ + + +# tsnet-http-client + +The tshello server demonstrates how to use Tailscale as a library. diff --git a/tsnet/example/tsnet-services/README.md b/tsnet/example/tsnet-services/README.md new file mode 100644 index 000000000..18bc072d7 --- /dev/null +++ b/tsnet/example/tsnet-services/README.md @@ -0,0 +1,32 @@ + + +# tsnet-services + +The tsnet-services example demonstrates how to use tsnet with Services. + +To run this example yourself: + + 1. Add access controls which (i) define a new ACL tag, (ii) allow the demo node to host the Service, and (iii) allow peers on the tailnet to reach the Service. A sample ACL policy is provided below. + 2. [Generate an auth key](https://tailscale.com/kb/1085/auth-keys#generate-an-auth-key) using the Tailscale admin panel. When doing so, add your new tag to your key (Service hosts must be tagged nodes). + 3. [Define a Service](https://tailscale.com/kb/1552/tailscale-services#step-1-define-a-tailscale-service). For the purposes of this demo, it must be defined to listen on TCP port 443. Note that you only need to follow Step 1 in the linked document. + 4. Run the demo on the command line (step 4 command shown below). + +Command for step 4: + + TS_AUTHKEY= go run tsnet-services.go -service + +The following is a sample ACL policy for step 1: + + "tagOwners": { + "tag:tsnet-demo-host": ["autogroup:member"], + }, + "autoApprovers": { + "services": { + "svc:tsnet-demo": ["tag:tsnet-demo-host"], + }, + }, + "grants": [ + "src": ["*"], + "dst": ["svc:tsnet-demo"], + "ip": ["*"], + ], diff --git a/tsnet/example/tsnet-services/tsnet-services.go b/tsnet/example/tsnet-services/tsnet-services.go index d72fd68fd..4604e8d3f 100644 --- a/tsnet/example/tsnet-services/tsnet-services.go +++ b/tsnet/example/tsnet-services/tsnet-services.go @@ -8,17 +8,16 @@ // 1. Add access controls which (i) define a new ACL tag, (ii) allow the demo // node to host the Service, and (iii) allow peers on the tailnet to reach // the Service. A sample ACL policy is provided below. -// // 2. [Generate an auth key] using the Tailscale admin panel. When doing so, add // your new tag to your key (Service hosts must be tagged nodes). -// // 3. [Define a Service]. For the purposes of this demo, it must be defined to // listen on TCP port 443. Note that you only need to follow Step 1 in the // linked document. +// 4. Run the demo on the command line (step 4 command shown below). // -// 4. Run the demo on the command line: +// Command for step 4: // -// TS_AUTHKEY= go run tsnet-services.go -service +// TS_AUTHKEY= go run tsnet-services.go -service // // The following is a sample ACL policy for step 1: // diff --git a/tsnet/example/web-client/README.md b/tsnet/example/web-client/README.md new file mode 100644 index 000000000..6b4c42235 --- /dev/null +++ b/tsnet/example/web-client/README.md @@ -0,0 +1,5 @@ + + +# web-client + +The web-client command demonstrates serving the Tailscale web client over tsnet. diff --git a/tsnet/tsnet.go b/tsnet/tsnet.go index f28179773..eb72d28d3 100644 --- a/tsnet/tsnet.go +++ b/tsnet/tsnet.go @@ -1,7 +1,139 @@ // Copyright (c) Tailscale Inc & contributors // SPDX-License-Identifier: BSD-3-Clause -// Package tsnet provides Tailscale as a library. +// Package tsnet embeds a Tailscale node directly into a Go program, +// allowing it to join a tailnet and accept or dial connections without +// running a separate tailscaled daemon or requiring any system-level +// configuration. +// +// # Overview +// +// Normally, Tailscale runs as a background system service (tailscaled) +// that manages a virtual network interface for the whole machine. tsnet +// takes a different approach: it runs a fully self-contained Tailscale +// node inside your process using a userspace TCP/IP stack (gVisor). +// This means: +// +// - No root privileges required. +// - No system daemons to install or manage. +// - Multiple independent Tailscale nodes can run within a single binary. +// - The node's [Tailscale identity] and state are stored in a directory you control. +// +// The core type is [Server], which represents one embedded Tailscale +// node. Calling [Server.Listen] or [Server.Dial] routes traffic +// exclusively over the tailnet. The standard library's [net.Listener] +// and [net.Conn] interfaces are returned, so any existing Go HTTP +// server, gRPC server, or other net-based code works without +// modification. +// +// # Usage +// +// import "tailscale.com/tsnet" +// +// s := &tsnet.Server{ +// Hostname: "my-service", +// AuthKey: os.Getenv("TS_AUTHKEY"), +// } +// defer s.Close() +// +// ln, err := s.Listen("tcp", ":80") +// if err != nil { +// log.Fatal(err) +// } +// log.Fatal(http.Serve(ln, myHandler)) +// +// On first run, if no [Server.AuthKey] is provided and the node is not +// already enrolled, the server logs an authentication URL. Open it in a +// browser to add the node to your tailnet. +// +// # Authentication +// +// A [Server] authenticates using, in order of precedence: +// +// 1. [Server.AuthKey]. +// +// 2. The TS_AUTHKEY environment variable. +// +// 3. The TS_AUTH_KEY environment variable. +// +// 4. An OAuth client secret ([Server.ClientSecret] or TS_CLIENT_SECRET), +// used to mint an auth key. +// +// 5. Workload identity federation ([Server.ClientID] plus +// [Server.IDToken] or [Server.Audience]). Available only if the +// program imports the feature: +// +// import _ "tailscale.com/feature/identityfederation" +// +// The feature is not linked by default to keep the AWS SDK and +// other cloud-provider dependencies out of programs that don't +// use workload identity federation. +// +// 6. An interactive login URL printed to [Server.UserLogf]. +// +// If the node is already enrolled (state found in [Server.Store]), the +// auth key is ignored unless TSNET_FORCE_LOGIN=1 is set. +// +// # Identifying callers +// +// Use the WhoIs method on the client returned by [Server.LocalClient] +// to identify who is making a request: +// +// lc, _ := srv.LocalClient() +// http.Serve(ln, http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { +// who, err := lc.WhoIs(r.Context(), r.RemoteAddr) +// if err != nil { +// http.Error(w, err.Error(), 500) +// return +// } +// fmt.Fprintf(w, "Hello, %s!", who.UserProfile.LoginName) +// })) +// +// # Tailscale Funnel +// +// [Server.ListenFunnel] exposes your service on the public internet. +// [Tailscale Funnel] currently supports TCP on ports 443, 8443, and +// 10000. HTTPS must be enabled in the Tailscale admin console. +// +// ln, err := srv.ListenFunnel("tcp", ":443") +// // ln is a TLS listener; connections can come from anywhere on the +// // internet as well as from your tailnet. +// +// // To restrict to public traffic only: +// ln, err = srv.ListenFunnel("tcp", ":443", tsnet.FunnelOnly()) +// +// # Tailscale Services +// +// [Server.ListenService] advertises the node as a host for a named +// [Tailscale Service]. The node must use a tag-based identity. To +// advertise multiple ports, call ListenService once per port. +// +// srv.AdvertiseTags = []string{"tag:myservice"} +// +// ln, err := srv.ListenService("svc:my-service", tsnet.ServiceModeHTTP{ +// HTTPS: true, +// Port: 443, +// }) +// log.Printf("Listening on https://%s", ln.FQDN) +// +// # Running multiple nodes in one process +// +// Each [Server] instance is an independent node. Give each a unique +// [Server.Dir] and [Server.Hostname]: +// +// for _, name := range []string{"frontend", "backend"} { +// srv := &tsnet.Server{ +// Hostname: name, +// Dir: filepath.Join(baseDir, name), +// AuthKey: os.Getenv("TS_AUTHKEY"), +// Ephemeral: true, +// } +// srv.Start() +// } +// +// [Tailscale identity]: https://tailscale.com/docs/concepts/tailscale-identity +// [Tailscale Funnel]: https://tailscale.com/docs/features/tailscale-funnel +// [Tailscale Service]: https://tailscale.com/docs/features/tailscale-services package tsnet import ( @@ -31,7 +163,6 @@ import ( "tailscale.com/control/controlclient" "tailscale.com/envknob" _ "tailscale.com/feature/c2n" - _ "tailscale.com/feature/condregister/identityfederation" _ "tailscale.com/feature/condregister/oauthkey" _ "tailscale.com/feature/condregister/portmapper" _ "tailscale.com/feature/condregister/useproxy" @@ -283,6 +414,19 @@ func (s *Server) LocalClient() (*local.Client, error) { return s.localClient, nil } +// TestHooks are hooks meant for internal-testing only; they're not stable +// or documented, intentionally. +var TestHooks testHooks + +type testHooks struct{} + +// LocalBackend returns the [ipnlocal.LocalBackend] backing s. It panics +// outside of tests. +func (testHooks) LocalBackend(s *Server) *ipnlocal.LocalBackend { + testenv.AssertInTest() + return s.lb +} + // Loopback starts a routing server on a loopback address. // // The server has multiple functions. @@ -535,7 +679,7 @@ func (s *Server) doInit() { // Server. // If the server is not running, it returns nil. func (s *Server) CertDomains() []string { - nm := s.lb.NetMap() + nm := s.lb.NetMapNoPeers() if nm == nil { return nil } @@ -546,7 +690,7 @@ func (s *Server) CertDomains() []string { // has not yet joined a tailnet or is otherwise unaware of its own IP addresses, // the returned ip4, ip6 will be !netip.IsValid(). func (s *Server) TailscaleIPs() (ip4, ip6 netip.Addr) { - nm := s.lb.NetMap() + nm := s.lb.NetMapNoPeers() if nm == nil { return } diff --git a/tsnet/tsnet_test.go b/tsnet/tsnet_test.go index 4ed20fb2a..4ee0ab10c 100644 --- a/tsnet/tsnet_test.go +++ b/tsnet/tsnet_test.go @@ -2907,9 +2907,14 @@ func TestDeps(t *testing.T) { BadDeps: map[string]string{ "golang.org/x/crypto/ssh": "tsnet should not depend on SSH", "golang.org/x/crypto/ssh/internal/bcrypt_pbkdf": "tsnet should not depend on SSH", + "tailscale.com/ipn/store/awsstore": "tsnet callers wanting AWS state storage should import awsstore themselves", + "tailscale.com/ipn/store/kubestore": "tsnet callers wanting Kubernetes state storage should import kubestore themselves", + "tailscale.com/wif": "tsnet callers wanting workload identity federation should import tailscale.com/feature/identityfederation themselves", }, OnDep: func(dep string) { - if strings.Contains(dep, "portlist") { + if strings.Contains(dep, "portlist") || + strings.Contains(dep, "github.com/aws/") || + strings.Contains(dep, "k8s.io/") { t.Errorf("unexpected dep: %q", dep) } }, diff --git a/tstest/clock.go b/tstest/clock.go index 5742c6e5a..1f88fb0a2 100644 --- a/tstest/clock.go +++ b/tstest/clock.go @@ -20,6 +20,9 @@ type ClockOpts struct { // to Clock.Now. If you are passing a value here, set an explicit // timezone, otherwise the test may be non-deterministic when TZ environment // variable is set to different values. The default time is in UTC. + // + // If you do not pass an explicit Start time, the clock will start at the + // current UTC time. Start time.Time // Step is the amount of time the Clock will advance whenever Clock.Now is diff --git a/tstest/integration/integration.go b/tstest/integration/integration.go index a98df8180..861ec808d 100644 --- a/tstest/integration/integration.go +++ b/tstest/integration/integration.go @@ -73,7 +73,11 @@ type Binaries struct { // BinaryInfo describes a tailscale or tailscaled binary. type BinaryInfo struct { - Path string // abs path to tailscale or tailscaled binary + // Path is the absolute path to the tailscale or tailscaled binary. + // This path may become invalid after the owning test's TempDir is + // cleaned up; use FD (or Contents on Windows) to access the binary + // contents. + Path string Size int64 // FD and FDmu are set on Unix to efficiently copy the binary to a new @@ -88,16 +92,24 @@ type BinaryInfo struct { Contents []byte } +// CopyTo copies or hardlinks the binary into dir, returning a new BinaryInfo +// with an updated Path. The source bytes come from FD (or Contents on Windows), +// not from b.Path, which may have been deleted when its owning test's TempDir +// was cleaned up. func (b BinaryInfo) CopyTo(dir string) (BinaryInfo, error) { ret := b ret.Path = filepath.Join(dir, path.Base(b.Path)) switch runtime.GOOS { case "linux": - // TODO(bradfitz): be fancy and use linkat with AT_EMPTY_PATH to avoid - // copying? I couldn't get it to work, though. - // For now, just do the same thing as every other Unix and copy - // the binary. + // Try to hardlink from the open FD via /proc/self/fd, avoiding a + // full copy of the binary. We can't use os.Link(b.Path, ret.Path) + // because b.Path is in the first test's TempDir, which may be + // cleaned up before later tests call CopyTo. The open FD keeps the + // inode alive after the path is deleted. + if err := tryLinkat(b.FD, ret.Path); err == nil { + return ret, nil + } fallthrough case "darwin", "freebsd", "openbsd", "netbsd": f, err := os.OpenFile(ret.Path, os.O_RDWR|os.O_CREATE|os.O_EXCL, 0o755) @@ -1044,6 +1056,9 @@ func (n *TestNode) Tailscale(arg ...string) *exec.Cmd { cmd.Env = append(os.Environ(), "TS_DEBUG_UP_FLAG_GOOS="+n.upFlagGOOS, "TS_LOGS_DIR="+n.env.t.TempDir(), + "SSH_CLIENT=", // Clear SSH_CLIENT to prevent isSSHOverTailscale() false positives in tests + "SSH_CONNECTION=", // just in case + "SSH_AUTH_SOCK=", // just in case ) if *verboseTailscale { cmd.Stdout = os.Stdout diff --git a/tstest/integration/integration_linkat_linux.go b/tstest/integration/integration_linkat_linux.go new file mode 100644 index 000000000..68e9075d9 --- /dev/null +++ b/tstest/integration/integration_linkat_linux.go @@ -0,0 +1,24 @@ +// Copyright (c) Tailscale Inc & contributors +// SPDX-License-Identifier: BSD-3-Clause + +package integration + +import ( + "fmt" + "os" + + "golang.org/x/sys/unix" +) + +// tryLinkat attempts to hardlink the file referenced by fd to newpath, +// avoiding a full copy of the binary. It uses /proc/self/fd/ with +// AT_SYMLINK_FOLLOW, which works without elevated privileges (unlike +// AT_EMPTY_PATH which requires CAP_DAC_READ_SEARCH). +func tryLinkat(fd *os.File, newpath string) error { + procPath := fmt.Sprintf("/proc/self/fd/%d", fd.Fd()) + err := unix.Linkat(unix.AT_FDCWD, procPath, unix.AT_FDCWD, newpath, unix.AT_SYMLINK_FOLLOW) + if err != nil { + return fmt.Errorf("linkat via /proc/self/fd: %w", err) + } + return nil +} diff --git a/tstest/integration/integration_linkat_linux_test.go b/tstest/integration/integration_linkat_linux_test.go new file mode 100644 index 000000000..fc0a2873f --- /dev/null +++ b/tstest/integration/integration_linkat_linux_test.go @@ -0,0 +1,48 @@ +// Copyright (c) Tailscale Inc & contributors +// SPDX-License-Identifier: BSD-3-Clause + +package integration + +import ( + "os" + "path/filepath" + "testing" + + "golang.org/x/sys/unix" +) + +func TestTryLinkat(t *testing.T) { + src := filepath.Join(t.TempDir(), "src") + if err := os.WriteFile(src, []byte("hello world"), 0o755); err != nil { + t.Fatal(err) + } + fd, err := os.Open(src) + if err != nil { + t.Fatal(err) + } + defer fd.Close() + + dst := filepath.Join(t.TempDir(), "dst") + if err := tryLinkat(fd, dst); err != nil { + t.Fatal(err) + } + + got, err := os.ReadFile(dst) + if err != nil { + t.Fatal(err) + } + if string(got) != "hello world" { + t.Fatalf("got %q, want %q", got, "hello world") + } + + var stSrc, stDst unix.Stat_t + if err := unix.Stat(src, &stSrc); err != nil { + t.Fatal(err) + } + if err := unix.Stat(dst, &stDst); err != nil { + t.Fatal(err) + } + if stSrc.Ino != stDst.Ino { + t.Fatalf("inodes differ: src=%d, dst=%d", stSrc.Ino, stDst.Ino) + } +} diff --git a/tstest/integration/integration_linkat_other.go b/tstest/integration/integration_linkat_other.go new file mode 100644 index 000000000..7e22ca0da --- /dev/null +++ b/tstest/integration/integration_linkat_other.go @@ -0,0 +1,15 @@ +// Copyright (c) Tailscale Inc & contributors +// SPDX-License-Identifier: BSD-3-Clause + +//go:build !linux + +package integration + +import ( + "errors" + "os" +) + +func tryLinkat(_ *os.File, _ string) error { + return errors.New("linkat with AT_EMPTY_PATH not supported on this OS") +} diff --git a/tstest/integration/integration_test.go b/tstest/integration/integration_test.go index 74c9c745b..3064d6a26 100644 --- a/tstest/integration/integration_test.go +++ b/tstest/integration/integration_test.go @@ -73,9 +73,7 @@ func TestMain(m *testing.M) { // https://github.com/tailscale/tailscale/issues/7894 func TestTUNMode(t *testing.T) { tstest.Shard(t) - if os.Getuid() != 0 { - t.Skip("skipping when not root") - } + tstest.RequireRoot(t) tstest.Parallel(t) env := NewTestEnv(t) env.tunMode = true @@ -469,83 +467,70 @@ func TestOneNodeUpAuth(t *testing.T) { }, } { tstest.Shard(t) + t.Run(tt.name, func(t *testing.T) { + tstest.Parallel(t) - for _, useSeamlessKeyRenewal := range []bool{true, false} { - name := tt.name - if useSeamlessKeyRenewal { - name += "-with-seamless" - } - t.Run(name, func(t *testing.T) { - tstest.Parallel(t) - - env := NewTestEnv(t, ConfigureControl( - func(control *testcontrol.Server) { - if tt.authKey != "" { - control.RequireAuthKey = tt.authKey - } else { - control.RequireAuth = true - } - - if tt.requireDeviceApproval { - control.RequireMachineAuth = true - } - - control.AllNodesSameUser = true - - if useSeamlessKeyRenewal { - control.DefaultNodeCapabilities = &tailcfg.NodeCapMap{ - tailcfg.NodeAttrSeamlessKeyRenewal: []tailcfg.RawMessage{}, - } - } - }, - )) - - n1 := NewTestNode(t, env) - d1 := n1.StartDaemon() - defer d1.MustCleanShutdown(t) - - for i, step := range tt.steps { - t.Logf("Running step %d", i) - cmdArgs := append(step.args, "--login-server="+env.ControlURL()) - - t.Logf("Running command: %s", strings.Join(cmdArgs, " ")) - - var authURLCount atomic.Int32 - var deviceApprovalURLCount atomic.Int32 - - handler := &authURLParserWriter{t: t, - authURLFn: completeLogin(t, env.Control, &authURLCount), - deviceApprovalURLFn: completeDeviceApproval(t, n1, &deviceApprovalURLCount), + env := NewTestEnv(t, ConfigureControl( + func(control *testcontrol.Server) { + if tt.authKey != "" { + control.RequireAuthKey = tt.authKey + } else { + control.RequireAuth = true } - cmd := n1.Tailscale(cmdArgs...) - cmd.Stdout = handler - cmd.Stdout = handler - cmd.Stderr = cmd.Stdout - if err := cmd.Run(); err != nil { - t.Fatalf("up: %v", err) + if tt.requireDeviceApproval { + control.RequireMachineAuth = true } - n1.AwaitRunning() + control.AllNodesSameUser = true + }, + )) - var wantAuthURLCount int32 - if step.wantAuthURL { - wantAuthURLCount = 1 - } - if n := authURLCount.Load(); n != wantAuthURLCount { - t.Errorf("Auth URLs completed = %d; want %d", n, wantAuthURLCount) - } + n1 := NewTestNode(t, env) + d1 := n1.StartDaemon() + defer d1.MustCleanShutdown(t) - var wantDeviceApprovalURLCount int32 - if step.wantDeviceApprovalURL { - wantDeviceApprovalURLCount = 1 - } - if n := deviceApprovalURLCount.Load(); n != wantDeviceApprovalURLCount { - t.Errorf("Device approval URLs completed = %d; want %d", n, wantDeviceApprovalURLCount) - } + for i, step := range tt.steps { + t.Logf("Running step %d", i) + cmdArgs := append(step.args, "--login-server="+env.ControlURL()) + + t.Logf("Running command: %s", strings.Join(cmdArgs, " ")) + + var authURLCount atomic.Int32 + var deviceApprovalURLCount atomic.Int32 + + handler := &authURLParserWriter{t: t, + authURLFn: completeLogin(t, env.Control, &authURLCount), + deviceApprovalURLFn: completeDeviceApproval(t, n1, &deviceApprovalURLCount), } - }) - } + + cmd := n1.Tailscale(cmdArgs...) + cmd.Stdout = handler + cmd.Stdout = handler + cmd.Stderr = cmd.Stdout + if err := cmd.Run(); err != nil { + t.Fatalf("up: %v", err) + } + + n1.AwaitRunning() + + var wantAuthURLCount int32 + if step.wantAuthURL { + wantAuthURLCount = 1 + } + if n := authURLCount.Load(); n != wantAuthURLCount { + t.Errorf("Auth URLs completed = %d; want %d", n, wantAuthURLCount) + } + + var wantDeviceApprovalURLCount int32 + if step.wantDeviceApprovalURL { + wantDeviceApprovalURLCount = 1 + } + if n := deviceApprovalURLCount.Load(); n != wantDeviceApprovalURLCount { + t.Errorf("Device approval URLs completed = %d; want %d", n, wantDeviceApprovalURLCount) + } + } + }) } } @@ -1565,9 +1550,7 @@ func testAutoUpdateDefaults(t *testing.T, useCap bool) { // https://github.com/tailscale/corp/issues/22511 func TestDNSOverTCPIntervalResolver(t *testing.T) { tstest.Shard(t) - if os.Getuid() != 0 { - t.Skip("skipping when not root") - } + tstest.RequireRoot(t) env := NewTestEnv(t) env.tunMode = true n1 := NewTestNode(t, env) @@ -1637,9 +1620,7 @@ func TestDNSOverTCPIntervalResolver(t *testing.T) { // directions. func TestNetstackTCPLoopback(t *testing.T) { tstest.Shard(t) - if os.Getuid() != 0 { - t.Skip("skipping when not root") - } + tstest.RequireRoot(t) env := NewTestEnv(t) env.tunMode = true @@ -1779,9 +1760,7 @@ func TestNetstackTCPLoopback(t *testing.T) { // directions. func TestNetstackUDPLoopback(t *testing.T) { tstest.Shard(t) - if os.Getuid() != 0 { - t.Skip("skipping when not root") - } + tstest.RequireRoot(t) env := NewTestEnv(t) env.tunMode = true diff --git a/tstest/integration/nat/nat_test.go b/tstest/integration/nat/nat_test.go index 98116f173..e3e53374c 100644 --- a/tstest/integration/nat/nat_test.go +++ b/tstest/integration/nat/nat_test.go @@ -7,6 +7,7 @@ import ( "bytes" "cmp" "context" + "encoding/json" "errors" "flag" "fmt" @@ -17,7 +18,7 @@ import ( "os" "os/exec" "path/filepath" - "strconv" + "runtime" "strings" "sync" "testing" @@ -26,7 +27,6 @@ import ( "golang.org/x/mod/modfile" "golang.org/x/sync/errgroup" "tailscale.com/client/tailscale" - "tailscale.com/envknob" "tailscale.com/ipn/ipnstate" "tailscale.com/syncs" "tailscale.com/tailcfg" @@ -115,10 +115,6 @@ func findKernelPath(goMod string) (string, error) { type addNodeFunc func(c *vnet.Config) *vnet.Node // returns nil to omit test -func v6cidr(n int) string { - return fmt.Sprintf("2000:%d::1/64", n) -} - func easy(c *vnet.Config) *vnet.Node { n := c.NumNodes() + 1 return c.AddNode(c.AddNetwork( @@ -126,57 +122,6 @@ func easy(c *vnet.Config) *vnet.Node { fmt.Sprintf("192.168.%d.1/24", n), vnet.EasyNAT)) } -func easyAnd6(c *vnet.Config) *vnet.Node { - n := c.NumNodes() + 1 - return c.AddNode(c.AddNetwork( - fmt.Sprintf("2.%d.%d.%d", n, n, n), // public IP - fmt.Sprintf("192.168.%d.1/24", n), - v6cidr(n), - vnet.EasyNAT)) -} - -// easyNoControlDiscoRotate sets up a node with easy NAT, cuts traffic to -// control after connecting, and then rotates the disco key to simulate a newly -// started node (from a disco perspective). -func easyNoControlDiscoRotate(c *vnet.Config) *vnet.Node { - n := c.NumNodes() + 1 - nw := c.AddNetwork( - fmt.Sprintf("2.%d.%d.%d", n, n, n), // public IP - fmt.Sprintf("192.168.%d.1/24", n), - vnet.EasyNAT) - nw.SetPostConnectControlBlackhole(true) - return c.AddNode( - vnet.TailscaledEnv{ - Key: "TS_USE_CACHED_NETMAP", - Value: "true", - }, - vnet.RotateDisco, vnet.PreICMPPing, nw) -} - -func v6AndBlackholedIPv4(c *vnet.Config) *vnet.Node { - n := c.NumNodes() + 1 - nw := c.AddNetwork( - fmt.Sprintf("2.%d.%d.%d", n, n, n), // public IP - fmt.Sprintf("192.168.%d.1/24", n), - v6cidr(n), - vnet.EasyNAT) - nw.SetBlackholedIPv4(true) - return c.AddNode(nw) -} - -func just6(c *vnet.Config) *vnet.Node { - n := c.NumNodes() + 1 - return c.AddNode(c.AddNetwork(v6cidr(n))) // public IPv6 prefix -} - -// easy + host firewall -func easyFW(c *vnet.Config) *vnet.Node { - n := c.NumNodes() + 1 - return c.AddNode(vnet.HostFirewall, c.AddNetwork( - fmt.Sprintf("2.%d.%d.%d", n, n, n), // public IP - fmt.Sprintf("192.168.%d.1/24", n), vnet.EasyNAT)) -} - func easyAF(c *vnet.Config) *vnet.Node { n := c.NumNodes() + 1 return c.AddNode(c.AddNetwork( @@ -209,46 +154,6 @@ func easyPMP(c *vnet.Config) *vnet.Node { fmt.Sprintf("192.168.%d.1/24", n), vnet.EasyNAT, vnet.NATPMP)) } -// easy + port mapping + host firewall + BPF -func easyPMPFWPlusBPF(c *vnet.Config) *vnet.Node { - n := c.NumNodes() + 1 - return c.AddNode( - vnet.HostFirewall, - vnet.TailscaledEnv{ - Key: "TS_ENABLE_RAW_DISCO", - Value: "true", - }, - vnet.TailscaledEnv{ - Key: "TS_DEBUG_RAW_DISCO", - Value: "1", - }, - vnet.TailscaledEnv{ - Key: "TS_DEBUG_DISCO", - Value: "1", - }, - vnet.TailscaledEnv{ - Key: "TS_LOG_VERBOSITY", - Value: "2", - }, - c.AddNetwork( - fmt.Sprintf("2.%d.%d.%d", n, n, n), // public IP - fmt.Sprintf("192.168.%d.1/24", n), vnet.EasyNAT, vnet.NATPMP)) -} - -// easy + port mapping + host firewall - BPF -func easyPMPFWNoBPF(c *vnet.Config) *vnet.Node { - n := c.NumNodes() + 1 - return c.AddNode( - vnet.HostFirewall, - vnet.TailscaledEnv{ - Key: "TS_ENABLE_RAW_DISCO", - Value: "false", - }, - c.AddNetwork( - fmt.Sprintf("2.%d.%d.%d", n, n, n), // public IP - fmt.Sprintf("192.168.%d.1/24", n), vnet.EasyNAT, vnet.NATPMP)) -} - func hard(c *vnet.Config) *vnet.Node { n := c.NumNodes() + 1 return c.AddNode(c.AddNetwork( @@ -256,22 +161,6 @@ func hard(c *vnet.Config) *vnet.Node { fmt.Sprintf("10.0.%d.1/24", n), vnet.HardNAT)) } -func hardNoDERPOrEndoints(c *vnet.Config) *vnet.Node { - n := c.NumNodes() + 1 - return c.AddNode(c.AddNetwork( - fmt.Sprintf("2.%d.%d.%d", n, n, n), // public IP - fmt.Sprintf("10.0.%d.1/24", n), vnet.HardNAT), - vnet.TailscaledEnv{ - Key: "TS_DEBUG_STRIP_ENDPOINTS", - Value: "1", - }, - vnet.TailscaledEnv{ - Key: "TS_DEBUG_STRIP_HOME_DERP", - Value: "1", - }, - ) -} - func hardPMP(c *vnet.Config) *vnet.Node { n := c.NumNodes() + 1 return c.AddNode(c.AddNetwork( @@ -327,6 +216,15 @@ func (nt *natTest) setupTest(ctx context.Context, addNode ...addNodeFunc) (nodes } }) + haveKVM := false + if runtime.GOOS == "linux" { + if f, err := os.OpenFile("/dev/kvm", os.O_RDWR, 0); err == nil { + f.Close() + haveKVM = true + } + } + + qmpSocks := make([]string, len(nodes)) for i, node := range nodes { disk := fmt.Sprintf("%s/node-%d.qcow2", nt.tempDir, i) out, err := exec.Command("qemu-img", "create", @@ -349,22 +247,28 @@ func (nt *natTest) setupTest(ctx context.Context, addNode ...addNodeFunc) (nodes } envStr := envBuf.String() - cmd := exec.Command("qemu-system-x86_64", + qmpSocks[i] = fmt.Sprintf("%s/qmp-node-%d.sock", nt.tempDir, i) + qemuArgs := []string{ "-M", "microvm,isa-serial=off", "-m", "384M", "-nodefaults", "-no-user-config", "-nographic", "-kernel", nt.kernel, - "-append", "console=hvc0 root=PARTUUID=60c24cc1-f3f9-427a-8199-76baa2d60001/PARTNROFF=1 ro init=/gokrazy/init panic=10 oops=panic pci=off nousb tsc=unstable clocksource=hpet gokrazy.remote_syslog.target="+sysLogAddr+" tailscale-tta=1"+envStr, - "-drive", "id=blk0,file="+disk+",format=qcow2", + "-append", "console=hvc0 root=PARTUUID=60c24cc1-f3f9-427a-8199-76baa2d60001/PARTNROFF=1 ro init=/gokrazy/init panic=10 oops=panic pci=off nousb gokrazy.remote_syslog.target=" + sysLogAddr + " tailscale-tta=1" + envStr, + "-drive", "id=blk0,file=" + disk + ",format=qcow2", "-device", "virtio-blk-device,drive=blk0", - "-netdev", "stream,id=net0,addr.type=unix,addr.path="+sockAddr, + "-netdev", "stream,id=net0,addr.type=unix,addr.path=" + sockAddr, "-device", "virtio-serial-device", "-device", "virtio-rng-device", - "-device", "virtio-net-device,netdev=net0,mac="+node.MAC().String(), + "-device", "virtio-net-device,netdev=net0,mac=" + node.MAC().String(), "-chardev", "stdio,id=virtiocon0,mux=on", "-device", "virtconsole,chardev=virtiocon0", "-mon", "chardev=virtiocon0,mode=readline", - ) + "-qmp", "unix:" + qmpSocks[i] + ",server=on,wait=off", + } + if haveKVM { + qemuArgs = append(qemuArgs, "-enable-kvm", "-cpu", "host") + } + cmd := exec.Command("qemu-system-x86_64", qemuArgs...) cmd.Stdout = os.Stdout cmd.Stderr = os.Stderr if err := cmd.Start(); err != nil { @@ -376,6 +280,15 @@ func (nt *natTest) setupTest(ctx context.Context, addNode ...addNodeFunc) (nodes }) } + for i, node := range nodes { + if err := nt.vnet.AwaitFirstPacket(ctx, node.MAC()); err != nil { + t.Logf("node %v: no boot progress (no packets received): %v", node, err) + t.Logf("node %v: QMP status: %s", node, qmpQueryStatus(qmpSocks[i])) + t.FailNow() + } + t.Logf("node %v: boot detected (first packet received)", node) + } + for _, n := range nodes { client := nt.vnet.NodeAgentClient(n) n.SetClient(client) @@ -411,6 +324,11 @@ func (nt *natTest) setupTest(ctx context.Context, addNode ...addNodeFunc) (nodes return fmt.Errorf("%v status: %w", node, err) } + if capMap := node.WantCapMap(); capMap != nil { + nt.tb.Logf("using capmap for %s: %+v", node.String(), capMap) + nt.vnet.ControlServer().SetNodeCapMap(st.Self.PublicKey, capMap) + } + if st.BackendState != "Running" { return fmt.Errorf("%v state = %q", node, st.BackendState) } @@ -430,34 +348,24 @@ func (nt *natTest) setupTest(ctx context.Context, addNode ...addNodeFunc) (nodes return nodes, clients, nt.vnet.Close } -func (nt *natTest) runHostConnectivityTest(addNode ...addNodeFunc) bool { - ctx, cancel := context.WithTimeout(context.Background(), 60*time.Second) - defer cancel() - nodes, clients, cleanup := nt.setupTest(ctx, addNode...) - defer cleanup() +type hasDeadline interface { + Deadline() (deadline time.Time, ok bool) +} - if len(nodes) != 2 { - nt.tb.Logf("ping can only be done among exactly two nodes") - return false - } - var fromClient, toClient *vnet.NodeAgentClient - for i, n := range nodes { - if n.ShouldJoinTailnet() && fromClient == nil { - fromClient = clients[i] - } else { - toClient = clients[i] +// testContext returns a context derived from the test's deadline (from -timeout), +// leaving a small margin for cleanup. Falls back to 60s if no deadline is set. +func testContext(tb testing.TB) (context.Context, context.CancelFunc) { + if t, ok := tb.(hasDeadline); ok { + if dl, ok := t.Deadline(); ok { + const margin = 5 * time.Second + return context.WithDeadline(context.Background(), dl.Add(-margin)) } } - got, err := sendHostNetworkPing(ctx, nt.tb, fromClient, toClient) - if err != nil { - nt.tb.Fatalf("ping host: %v", err) - } - nt.tb.Logf("ping success: %v", got) - return got + return context.WithTimeout(context.Background(), 60*time.Second) } func (nt *natTest) runTailscaleConnectivityTest(addNode ...addNodeFunc) pingRoute { - ctx, cancel := context.WithTimeout(context.Background(), 60*time.Second) + ctx, cancel := testContext(nt.tb) defer cancel() nodes, clients, cleanup := nt.setupTest(ctx, addNode...) @@ -580,6 +488,55 @@ func ping(ctx context.Context, t testing.TB, c *vnet.NodeAgentClient, target net return nil, fmt.Errorf("no ping response (ctx: %v)", ctx.Err()) } +// qmpQueryStatus connects to a QEMU QMP socket and returns the VM status +// (e.g. "running", "paused", "prelaunch") or an error string. +func qmpQueryStatus(sockPath string) string { + conn, err := net.DialTimeout("unix", sockPath, 2*time.Second) + if err != nil { + return fmt.Sprintf("dial error: %v", err) + } + defer conn.Close() + conn.SetDeadline(time.Now().Add(5 * time.Second)) + dec := json.NewDecoder(conn) + + // Read QMP greeting. + var greeting json.RawMessage + if err := dec.Decode(&greeting); err != nil { + return fmt.Sprintf("greeting error: %v", err) + } + + // Enter command mode. + if _, err := conn.Write([]byte(`{"execute":"qmp_capabilities"}` + "\n")); err != nil { + return fmt.Sprintf("write caps: %v", err) + } + var capsResp json.RawMessage + if err := dec.Decode(&capsResp); err != nil { + return fmt.Sprintf("caps response: %v", err) + } + + // Query status. + if _, err := conn.Write([]byte(`{"execute":"query-status"}` + "\n")); err != nil { + return fmt.Sprintf("write query-status: %v", err) + } + var statusResp struct { + Return struct { + Running bool `json:"running"` + Status string `json:"status"` + } `json:"return"` + Error *struct { + Class string `json:"class"` + Desc string `json:"desc"` + } `json:"error"` + } + if err := dec.Decode(&statusResp); err != nil { + return fmt.Sprintf("status response: %v", err) + } + if statusResp.Error != nil { + return fmt.Sprintf("qmp error: %s: %s", statusResp.Error.Class, statusResp.Error.Desc) + } + return fmt.Sprintf("status=%s running=%v", statusResp.Return.Status, statusResp.Return.Running) +} + func up(ctx context.Context, c *vnet.NodeAgentClient) error { req, err := http.NewRequestWithContext(ctx, "GET", "http://unused/up", nil) if err != nil { @@ -597,60 +554,6 @@ func up(ctx context.Context, c *vnet.NodeAgentClient) error { return nil } -func getClientIP(ctx context.Context, c *vnet.NodeAgentClient) (netip.Addr, error) { - getIPReq, err := http.NewRequestWithContext(ctx, "GET", "http://unused/ip", nil) - if err != nil { - return netip.Addr{}, err - } - res, err := c.HTTPClient.Do(getIPReq) - if err != nil { - return netip.Addr{}, err - } - defer res.Body.Close() - if res.StatusCode != http.StatusOK { - return netip.Addr{}, fmt.Errorf("client returned http status %q", res.Status) - } - ipBytes, err := io.ReadAll(res.Body) - if err != nil { - return netip.Addr{}, err - } - addrPort, err := netip.ParseAddrPort(string(ipBytes)) - if err != nil { - return netip.Addr{}, err - } - return addrPort.Addr(), nil -} - -// sendHostNetworkPing pings toClient from fromClient, and returns whether -// toClient responded to the ping. -func sendHostNetworkPing(ctx context.Context, tb testing.TB, fromClient, toClient *vnet.NodeAgentClient) (bool, error) { - toIP, err := getClientIP(ctx, toClient) - if err != nil { - return false, fmt.Errorf("get ip: %w", err) - } - req, err := http.NewRequestWithContext(ctx, "GET", fmt.Sprintf("http://unused/ping?host=%s", toIP.String()), nil) - if err != nil { - return false, err - } - res, err := fromClient.HTTPClient.Do(req) - if err != nil { - return false, err - } - defer res.Body.Close() - got, err := io.ReadAll(res.Body) - if err != nil { - tb.Logf("error while reading http body: %v", err) - } else { - tb.Logf("got response from ping: %q", got) - } - ec, err := strconv.Atoi(res.Header.Get("Exec-Exit-Code")) - if err != nil { - return false, fmt.Errorf("parse exit code: %w", err) - } - tb.Logf("got ec: %v", ec) - return ec == 0, nil -} - type nodeType struct { name string fn addNodeFunc @@ -667,26 +570,6 @@ var types = []nodeType{ {"cgnat", cgnatNoTailnet}, } -// want sets the expected ping route for the test. -func (nt *natTest) want(r pingRoute) { - if nt.gotRoute != r { - nt.tb.Errorf("ping route = %v; want %v", nt.gotRoute, r) - } -} - -func TestEasyEasy(t *testing.T) { - nt := newNatTest(t) - nt.runTailscaleConnectivityTest(easy, easy) - nt.want(routeDirect) -} - -func TestTwoEasyNoControlDiscoRotate(t *testing.T) { - envknob.Setenv("TS_USE_CACHED_NETMAP", "1") - nt := newNatTest(t) - nt.runTailscaleConnectivityTest(easyNoControlDiscoRotate, easyNoControlDiscoRotate) - nt.want(routeDirect) -} - func cgnatNoTailnet(c *vnet.Config) *vnet.Node { n := c.NumNodes() + 1 return c.AddNode(c.AddNetwork( @@ -696,101 +579,6 @@ func cgnatNoTailnet(c *vnet.Config) *vnet.Node { vnet.DontJoinTailnet) } -func TestNonTailscaleCGNATEndpoint(t *testing.T) { - if !*knownBroken { - t.Skip("skipping known-broken test; set --known-broken to run; see https://github.com/tailscale/corp/issues/36270") - } - nt := newNatTest(t) - if !nt.runHostConnectivityTest(cgnatNoTailnet, sameLAN) { - t.Fatalf("could not ping") - } -} - -// Issue tailscale/corp#26438: use learned DERP route as send path of last -// resort -// -// See (*magicsock.Conn).fallbackDERPRegionForPeer and its comment for -// background. -// -// This sets up a test with two nodes that must use DERP to communicate but the -// target of the ping (the second node) additionally is not getting DERP or -// Endpoint updates from the control plane. (Or rather, it's getting them but is -// configured to scrub them right when they come off the network before being -// processed) This then tests whether node2, upon receiving a packet, will be -// able to reply to node1 since it knows neither node1's endpoints nor its home -// DERP. The only reply route it can use is that fact that it just received a -// packet over a particular DERP from that peer. -func TestFallbackDERPRegionForPeer(t *testing.T) { - nt := newNatTest(t) - nt.runTailscaleConnectivityTest(hard, hardNoDERPOrEndoints) - nt.want(routeDERP) -} - -func TestSingleJustIPv6(t *testing.T) { - nt := newNatTest(t) - nt.runTailscaleConnectivityTest(just6) -} - -var knownBroken = flag.Bool("known-broken", false, "run known-broken tests") - -// TestSingleDualStackButBrokenIPv4 tests a dual-stack node with broken -// (blackholed) IPv4. -// -// See https://github.com/tailscale/tailscale/issues/13346 -func TestSingleDualBrokenIPv4(t *testing.T) { - if !*knownBroken { - t.Skip("skipping known-broken test; set --known-broken to run; see https://github.com/tailscale/tailscale/issues/13346") - } - nt := newNatTest(t) - nt.runTailscaleConnectivityTest(v6AndBlackholedIPv4) -} - -func TestJustIPv6(t *testing.T) { - nt := newNatTest(t) - nt.runTailscaleConnectivityTest(just6, just6) - nt.want(routeDirect) -} - -func TestEasy4AndJust6(t *testing.T) { - nt := newNatTest(t) - nt.runTailscaleConnectivityTest(easyAnd6, just6) - nt.want(routeDirect) -} - -func TestSameLAN(t *testing.T) { - nt := newNatTest(t) - nt.runTailscaleConnectivityTest(easy, sameLAN) - nt.want(routeLocal) -} - -// TestBPFDisco tests https://github.com/tailscale/tailscale/issues/3824 ... -// * server behind a Hard NAT -// * client behind a NAT with UPnP support -// * client machine has a stateful host firewall (e.g. ufw) -func TestBPFDisco(t *testing.T) { - nt := newNatTest(t) - nt.runTailscaleConnectivityTest(easyPMPFWPlusBPF, hard) - nt.want(routeDirect) -} - -func TestHostFWNoBPF(t *testing.T) { - nt := newNatTest(t) - nt.runTailscaleConnectivityTest(easyPMPFWNoBPF, hard) - nt.want(routeDERP) -} - -func TestHostFWPair(t *testing.T) { - nt := newNatTest(t) - nt.runTailscaleConnectivityTest(easyFW, easyFW) - nt.want(routeDirect) -} - -func TestOneHostFW(t *testing.T) { - nt := newNatTest(t) - nt.runTailscaleConnectivityTest(easy, easyFW) - nt.want(routeDirect) -} - var pair = flag.String("pair", "", "comma-separated pair of types to test (easy, easyAF, hard, easyPMP, hardPMP, one2one, sameLAN)") func TestPair(t *testing.T) { diff --git a/tstest/integration/testcontrol/testcontrol.go b/tstest/integration/testcontrol/testcontrol.go index 486bc8b81..c96b1ed33 100644 --- a/tstest/integration/testcontrol/testcontrol.go +++ b/tstest/integration/testcontrol/testcontrol.go @@ -69,6 +69,22 @@ type Server struct { // belong to the same user. AllNodesSameUser bool + // AllOnline, if true, marks every peer entry in MapResponses as + // Online=true. This is a coarse stand-in for the per-node + // online/offline tracking that production control servers do based + // on streaming map sessions: certain disco-key handling fast paths + // in [tailscale.com/control/controlclient] and + // [tailscale.com/wgengine/userspace] only fire when the peer is + // reported online, so without this flag they are silently skipped + // in tests, which can mask bugs and slow down recovery from disco + // rotations. See [tailscale.com/control/controlclient/map.go] + // removeUnwantedDiscoUpdates and + // removeUnwantedDiscoUpdatesFromFullNetmapUpdate for callers that + // branch on Online. + // + // Finer-grained per-node online tracking can be added later. + AllOnline bool + // DefaultNodeCapabilities overrides the capability map sent to each client. DefaultNodeCapabilities *tailcfg.NodeCapMap @@ -80,10 +96,18 @@ type Server struct { ExplicitBaseURL string // e.g. "http://127.0.0.1:1234" with no trailing URL HTTPTestServer *httptest.Server // if non-nil, used to get BaseURL + // MaybeRateLimitRegister, if non-nil, is called before processing + // register requests. If it returns true, a 429 response is sent + // with the given Retry-After header value and body string. + MaybeRateLimitRegister func() (reject bool, retryAfter string, msg string) + // ModifyFirstMapResponse, if non-nil, is called exactly once per // MapResponse stream to modify the first MapResponse sent in response to it. ModifyFirstMapResponse func(*tailcfg.MapResponse, *tailcfg.MapRequest) + // AltMapStream, if non-nil, takes over serveMap. See [AltMapStreamFunc]. + AltMapStream AltMapStreamFunc + initMuxOnce sync.Once mux *http.ServeMux @@ -132,12 +156,16 @@ type Server struct { updates map[tailcfg.NodeID]chan updateType authPath map[string]*AuthPath nodeKeyAuthed set.Set[key.NodePublic] - msgToSend map[key.NodePublic]any // value is *tailcfg.PingRequest or entire *tailcfg.MapResponse - allExpired bool // All nodes will be told their node key is expired. + msgToSend map[key.NodePublic][]any // FIFO queue per node; values are *tailcfg.PingRequest or *tailcfg.MapResponse + allExpired bool // All nodes will be told their node key is expired. // tkaStorage records the Tailnet Lock state, if any. // If nil, Tailnet Lock is not enabled in the Tailnet. tkaStorage tka.CompactableChonk + + // onMapRequest, if non-nil, is called at the start of each map poll request. + // It can be used in tests to panic or fail if a node contacts control unexpectedly. + onMapRequest func(nodeKey key.NodePublic) } // BaseURL returns the server's base URL, without trailing slash. @@ -276,14 +304,16 @@ func (s *Server) AddRawMapResponse(nodeKeyDst key.NodePublic, mr *tailcfg.MapRes func (s *Server) addDebugMessage(nodeKeyDst key.NodePublic, msg any) bool { s.mu.Lock() defer s.mu.Unlock() - if s.msgToSend == nil { - s.msgToSend = map[key.NodePublic]any{} - } - // Now send the update to the channel node := s.nodeLocked(nodeKeyDst) if node == nil { return false } + updatesCh := s.updates[node.ID] + if updatesCh == nil { + // No streaming poll is registered, so there's nobody to deliver + // the message to. + return false + } if _, ok := msg.(*tailcfg.MapResponse); ok { if s.suppressAutoMapResponses == nil { @@ -292,10 +322,14 @@ func (s *Server) addDebugMessage(nodeKeyDst key.NodePublic, msg any) bool { s.suppressAutoMapResponses.Add(nodeKeyDst) } - s.msgToSend[nodeKeyDst] = msg - nodeID := node.ID - oldUpdatesCh := s.updates[nodeID] - return sendUpdate(oldUpdatesCh, updateDebugInjection) + mak.Set(&s.msgToSend, nodeKeyDst, append(s.msgToSend[nodeKeyDst], msg)) + // sendUpdate returning false here is fine: the channel is a lossy + // wake-up signal whose buffer is single-slot. A full buffer means a + // prior wake-up is still pending, and the streaming poll will check + // msgToSend when it processes that wake-up. The queue in msgToSend + // is the source of truth. + sendUpdate(updatesCh, updateDebugInjection) + return true } // Mark the Node key of every node as expired @@ -768,6 +802,16 @@ func (s *Server) CompleteDeviceApproval(controlUrl string, urlStr string, nodeKe } func (s *Server) serveRegister(w http.ResponseWriter, r *http.Request, mkey key.MachinePublic) { + if fn := s.MaybeRateLimitRegister; fn != nil { + if reject, retryAfter, msg := fn(); reject { + if retryAfter != "" { + w.Header().Set("Retry-After", retryAfter) + } + http.Error(w, msg, http.StatusTooManyRequests) + return + } + } + msg, err := io.ReadAll(io.LimitReader(r.Body, msgLimit)) r.Body.Close() if err != nil { @@ -1129,6 +1173,21 @@ func (s *Server) serveMap(w http.ResponseWriter, r *http.Request, mkey key.Machi go panic(fmt.Sprintf("bad map request: %v", err)) } + s.mu.Lock() + if s.onMapRequest != nil { + s.onMapRequest(req.NodeKey) + } + s.mu.Unlock() + + if s.AltMapStream != nil { + // The caller takes over the stream entirely; it must handle + // keeping the HTTP response alive until ctx is done. + compress := req.Compress != "" + w.WriteHeader(200) + s.AltMapStream(ctx, &mapStreamSender{s: s, w: w, compress: compress}, req) + return + } + jitter := rand.N(8 * time.Second) keepAlive := 50*time.Second + jitter @@ -1142,8 +1201,15 @@ func (s *Server) serveMap(w http.ResponseWriter, r *http.Request, mkey key.Machi return } + // Per tailcfg.MapRequest.Stream docs: if Stream is true and Version >= 68, + // the server must treat this as read-only and ignore Hostinfo, Endpoints, + // DiscoKey, etc. — modern clients send those via a separate non-streaming + // POST /machine/map from a dedicated updateRoutine, not piggybacked on the + // streaming poll. Without this, the streaming MapRequest's zero-valued + // DiscoKey/Endpoints clobber whatever was just pushed out-of-band. + streamingNonUpdate := req.Stream && req.Version >= 68 var peersToUpdate []tailcfg.NodeID - if !req.ReadOnly { + if !req.ReadOnly && !streamingNonUpdate { endpoints := filterInvalidIPv6Endpoints(req.Endpoints) node.Endpoints = endpoints node.DiscoKey = req.DiscoKey @@ -1371,6 +1437,9 @@ func (s *Server) MapResponse(req *tailcfg.MapRequest) (res *tailcfg.MapResponse, p.PrimaryRoutes = routes p.AllowedIPs = append(p.AllowedIPs, routes...) } + if s.AllOnline { + p.Online = new(true) + } res.Peers = append(res.Peers, p) } @@ -1419,15 +1488,29 @@ func (s *Server) MapResponse(req *tailcfg.MapRequest) (res *tailcfg.MapResponse, res.Node.PrimaryRoutes = s.nodeSubnetRoutes[nk] res.Node.AllowedIPs = append(res.Node.Addresses, s.nodeSubnetRoutes[nk]...) - // Consume a PingRequest while protected by mutex if it exists - switch m := s.msgToSend[nk].(type) { - case *tailcfg.PingRequest: - res.PingRequest = m - delete(s.msgToSend, nk) + // Consume a PingRequest at the head of the queue, if any. + if q := s.msgToSend[nk]; len(q) > 0 { + if pr, ok := q[0].(*tailcfg.PingRequest); ok { + res.PingRequest = pr + s.popMsgToSendLocked(nk) + } } return res, nil } +// popMsgToSendLocked pops the head of the per-node message queue. +// s.mu must be held. +func (s *Server) popMsgToSendLocked(nk key.NodePublic) { + q := s.msgToSend[nk] + if len(q) <= 1 { + delete(s.msgToSend, nk) + return + } + // Zero the head to allow GC of any large referenced response. + q[0] = nil + s.msgToSend[nk] = q[1:] +} + func (s *Server) canGenerateAutomaticMapResponseFor(nk key.NodePublic) bool { s.mu.Lock() defer s.mu.Unlock() @@ -1437,22 +1520,21 @@ func (s *Server) canGenerateAutomaticMapResponseFor(nk key.NodePublic) bool { func (s *Server) hasPendingRawMapMessage(nk key.NodePublic) bool { s.mu.Lock() defer s.mu.Unlock() - _, ok := s.msgToSend[nk] - return ok + return len(s.msgToSend[nk]) > 0 } func (s *Server) takeRawMapMessage(nk key.NodePublic) (mapResJSON []byte, ok bool) { s.mu.Lock() defer s.mu.Unlock() - mr, ok := s.msgToSend[nk] - if !ok { + q := s.msgToSend[nk] + if len(q) == 0 { return nil, false } - delete(s.msgToSend, nk) + mr := q[0] + s.popMsgToSendLocked(nk) // If it's a bare PingRequest, wrap it in a MapResponse. - switch pr := mr.(type) { - case *tailcfg.PingRequest: + if pr, ok := mr.(*tailcfg.PingRequest); ok { mr = &tailcfg.MapResponse{PingRequest: pr} } @@ -1464,12 +1546,51 @@ func (s *Server) takeRawMapMessage(nk key.NodePublic) (mapResJSON []byte, ok boo return mapResJSON, true } +// AltMapStreamFunc is the type of [Server.AltMapStream]: a callback that +// takes over the serveMap handler entirely. The callback hand-builds and +// sends MapResponses via the provided [MapStreamWriter] and is responsible +// for keeping the stream alive until ctx is done. When set, the normal +// per-node map-stream state machine in serveMap is bypassed. +// +// The callback is invoked for every map long-poll, including the +// non-streaming "lite" polls controlclient issues to push HostInfo updates +// (req.Stream == false). Implementations that only care about the streaming +// long-poll typically respond to non-streaming polls with an empty +// MapResponse and return immediately. +// +// This hook is for benchmarks and stress tests that need to drive clients +// with a controlled sequence of responses. +type AltMapStreamFunc func(ctx context.Context, w MapStreamWriter, req *tailcfg.MapRequest) + +// MapStreamWriter is the interface passed to an [AltMapStreamFunc], +// letting the callback write framed MapResponse messages directly onto the +// long-poll HTTP response. +type MapStreamWriter interface { + // SendMapMessage encodes and writes msg as a single framed + // MapResponse on the stream. It respects the client's Compress flag + // (captured when the stream started). + SendMapMessage(msg *tailcfg.MapResponse) error +} + +// mapStreamSender implements [MapStreamWriter] for [Server.AltMapStream] +// callbacks. +type mapStreamSender struct { + s *Server + w http.ResponseWriter + compress bool +} + +func (m *mapStreamSender) SendMapMessage(msg *tailcfg.MapResponse) error { + return m.s.sendMapMsg(m.w, m.compress, msg) +} + func (s *Server) sendMapMsg(w http.ResponseWriter, compress bool, msg any) error { resBytes, err := s.encode(compress, msg) if err != nil { return err } - if len(resBytes) > 16<<20 { + const maxMapSize = 256 << 20 // 256MB + if len(resBytes) > maxMapSize { return fmt.Errorf("map message too big: %d", len(resBytes)) } var siz [4]byte @@ -1509,6 +1630,15 @@ func (s *Server) encode(compress bool, v any) (b []byte, err error) { return b, nil } +// SetOnMapRequest sets callback used for testing when a new mapRequest happens. +// Pass nil to remove the callback. +func (s *Server) SetOnMapRequest(f func(key.NodePublic)) { + s.mu.Lock() + defer s.mu.Unlock() + + s.onMapRequest = f +} + // filterInvalidIPv6Endpoints removes invalid IPv6 endpoints from eps, // modify the slice in place, returning the potentially smaller subset (aliasing // the original memory). diff --git a/tstest/integration/testcontrol/testcontrol_test.go b/tstest/integration/testcontrol/testcontrol_test.go new file mode 100644 index 000000000..d3008cdb7 --- /dev/null +++ b/tstest/integration/testcontrol/testcontrol_test.go @@ -0,0 +1,132 @@ +// Copyright (c) Tailscale Inc & contributors +// SPDX-License-Identifier: BSD-3-Clause + +package testcontrol_test + +import ( + "bytes" + "context" + "encoding/json" + "fmt" + "net" + "net/http" + "net/http/httptest" + "strings" + "testing" + "time" + + "tailscale.com/control/ts2021" + "tailscale.com/control/tsp" + "tailscale.com/net/tsdial" + "tailscale.com/tailcfg" + "tailscale.com/tstest/integration/testcontrol" + "tailscale.com/types/key" + "tailscale.com/util/must" +) + +// TestStreamingMapReqReadOnlyByVersion verifies that testcontrol matches +// production control's streaming-is-read-only semantics for clients at +// capability version >= 68. Per tailcfg.MapRequest.Stream docs, a streaming +// MapRequest from a cap>=68 client must be treated as read-only by the +// server (Endpoints/Hostinfo/DiscoKey are sent separately via a non-streaming +// /machine/map call), so the streaming MapRequest's zero-valued DiscoKey +// must not clobber the node's currently stored DiscoKey. +// +// For older (cap<68) clients, the streaming MapRequest is still a write and +// writes do happen, so DiscoKey=zero in the request does clobber. +func TestStreamingMapReqReadOnlyByVersion(t *testing.T) { + tests := []struct { + version tailcfg.CapabilityVersion + wantClobber bool + }{ + {67, true}, // pre-cap-68: streaming is a write, DiscoKey=zero clobbers. + {68, false}, // cap>=68: streaming is read-only, DiscoKey unchanged. + } + + for _, tt := range tests { + t.Run(fmt.Sprintf("v%d", tt.version), func(t *testing.T) { + ctrl := &testcontrol.Server{} + ctrl.HTTPTestServer = httptest.NewUnstartedServer(ctrl) + ctrl.HTTPTestServer.Start() + t.Cleanup(ctrl.HTTPTestServer.Close) + baseURL := ctrl.HTTPTestServer.URL + + ctx, cancel := context.WithTimeout(context.Background(), 15*time.Second) + defer cancel() + + serverKey := must.Get(tsp.DiscoverServerKey(ctx, baseURL)) + + // Register a node and push a known DiscoKey via SendMapUpdate + // (a non-streaming, unambiguously-a-write request). + nodeKey := key.NewNode() + machineKey := key.NewMachine() + wantDisco := key.NewDisco().Public() + + tc := must.Get(tsp.NewClient(tsp.ClientOpts{ + ServerURL: baseURL, + MachineKey: machineKey, + })) + defer tc.Close() + tc.SetControlPublicKey(serverKey) + must.Get(tc.Register(ctx, tsp.RegisterOpts{ + NodeKey: nodeKey, + Hostinfo: &tailcfg.Hostinfo{Hostname: "target"}, + })) + if err := tc.SendMapUpdate(ctx, tsp.SendMapUpdateOpts{ + NodeKey: nodeKey, + DiscoKey: wantDisco, + Hostinfo: &tailcfg.Hostinfo{Hostname: "target"}, + }); err != nil { + t.Fatalf("SendMapUpdate: %v", err) + } + if n := ctrl.Node(nodeKey.Public()); n == nil || n.DiscoKey != wantDisco { + t.Fatalf("pre: DiscoKey not set; node=%+v", n) + } + + // Fire a streaming MapRequest with the chosen Version and a + // zero DiscoKey. Use ts2021 directly because tsp.Map hardcodes + // Version to tailcfg.CurrentCapabilityVersion. + nc := must.Get(ts2021.NewClient(ts2021.ClientOpts{ + ServerURL: baseURL, + PrivKey: machineKey, + ServerPubKey: serverKey, + Dialer: tsdial.NewFromFuncForDebug(t.Logf, (&net.Dialer{}).DialContext), + })) + defer nc.Close() + + body := must.Get(json.Marshal(&tailcfg.MapRequest{ + Version: tt.version, + NodeKey: nodeKey.Public(), + Stream: true, + // DiscoKey intentionally zero. + })) + reqURL := strings.Replace(baseURL+"/machine/map", "http:", "https:", 1) + reqCtx, reqCancel := context.WithCancel(ctx) + defer reqCancel() + req := must.Get(http.NewRequestWithContext(reqCtx, "POST", reqURL, bytes.NewReader(body))) + ts2021.AddLBHeader(req, nodeKey.Public()) + + // nc.Do returns once response headers arrive, which in + // testcontrol's serveMap is AFTER the write branch has run + // (or been skipped). So by the time this returns, any write + // this request is going to do has already happened. + res, err := nc.Do(req) + if err != nil { + t.Fatalf("nc.Do: %v", err) + } + res.Body.Close() // tears down the streaming session server-side + + got := ctrl.Node(nodeKey.Public()) + if got == nil { + t.Fatal("node disappeared") + } + switch { + case tt.wantClobber && !got.DiscoKey.IsZero(): + t.Errorf("v%d: expected DiscoKey clobbered to zero, got %v", tt.version, got.DiscoKey) + case !tt.wantClobber && got.DiscoKey != wantDisco: + t.Errorf("v%d: DiscoKey changed from %v to %v; should have been left alone", + tt.version, wantDisco, got.DiscoKey) + } + }) + } +} diff --git a/tstest/integration/vms/distros.go b/tstest/integration/vms/distros.go index 94f11c77a..b6312dba4 100644 --- a/tstest/integration/vms/distros.go +++ b/tstest/integration/vms/distros.go @@ -35,11 +35,10 @@ func (d *Distro) InstallPre() string { return ` - [ dnf, install, "-y", iptables ]` case "apt": - return ` - [ apt-get, update ] - - [ apt-get, "-y", install, curl, "apt-transport-https", gnupg2 ]` + return ` - [ apt-get, "-y", install, curl, "apt-transport-https", gnupg2 ]` case "apk": - return ` - [ apk, "-U", add, curl, "ca-certificates", iptables, ip6tables ] + return ` - [ apk, add, curl, "ca-certificates", iptables, ip6tables ] - [ modprobe, tun ]` } diff --git a/tstest/iosdeps/iosdeps.go b/tstest/iosdeps/iosdeps.go index f6290af67..a1279e20b 100644 --- a/tstest/iosdeps/iosdeps.go +++ b/tstest/iosdeps/iosdeps.go @@ -4,28 +4,36 @@ // Package iosdeps is a just a list of the packages we import on iOS, to let us // test that our transitive closure of dependencies on iOS doesn't accidentally // grow too large, as we've historically been memory constrained there. +// +// It is intended to mirror the imports of the ipn-go-bridge package in the +// private "corp" repository (the Go side of the iOS / macOS app). package iosdeps import ( _ "bufio" _ "bytes" - _ "context" - _ "crypto/rand" + _ "crypto" + _ "crypto/ecdsa" + _ "crypto/elliptic" _ "crypto/sha256" + _ "encoding/base64" _ "encoding/json" _ "errors" _ "fmt" _ "io" - _ "io/fs" _ "log" _ "math" _ "net" _ "net/http" + _ "net/netip" + _ "net/url" _ "os" _ "os/signal" _ "path/filepath" _ "runtime" _ "runtime/debug" + _ "slices" + _ "strconv" _ "strings" _ "sync" _ "sync/atomic" @@ -35,24 +43,48 @@ import ( _ "github.com/tailscale/wireguard-go/device" _ "github.com/tailscale/wireguard-go/tun" - _ "go4.org/mem" _ "golang.org/x/sys/unix" + _ "tailscale.com/client/tailscale/apitype" + _ "tailscale.com/drive/driveimpl" + _ "tailscale.com/envknob" + _ "tailscale.com/feature/condregister" + _ "tailscale.com/feature/syspolicy" + _ "tailscale.com/feature/taildrop" _ "tailscale.com/hostinfo" _ "tailscale.com/ipn" + _ "tailscale.com/ipn/ipnauth" _ "tailscale.com/ipn/ipnlocal" _ "tailscale.com/ipn/localapi" + _ "tailscale.com/logpolicy" _ "tailscale.com/logtail" _ "tailscale.com/logtail/filch" _ "tailscale.com/net/dns" - _ "tailscale.com/net/netaddr" + _ "tailscale.com/net/netmon" + _ "tailscale.com/net/netutil" + _ "tailscale.com/net/tsaddr" _ "tailscale.com/net/tsdial" + _ "tailscale.com/net/tshttpproxy" _ "tailscale.com/net/tstun" _ "tailscale.com/paths" + _ "tailscale.com/safesocket" + _ "tailscale.com/tsd" _ "tailscale.com/types/empty" + _ "tailscale.com/types/key" + _ "tailscale.com/types/lazy" _ "tailscale.com/types/logger" + _ "tailscale.com/types/logid" + _ "tailscale.com/types/netmap" _ "tailscale.com/util/clientmetric" _ "tailscale.com/util/dnsname" + _ "tailscale.com/util/eventbus" + _ "tailscale.com/util/must" + _ "tailscale.com/util/set" + _ "tailscale.com/util/syspolicy" + _ "tailscale.com/util/syspolicy/pkey" + _ "tailscale.com/util/syspolicy/setting" + _ "tailscale.com/util/syspolicy/source" _ "tailscale.com/version" _ "tailscale.com/wgengine" + _ "tailscale.com/wgengine/netstack" _ "tailscale.com/wgengine/router" ) diff --git a/tstest/kernel_linux.go b/tstest/kernel_linux.go index ab7c0d529..ed48fd071 100644 --- a/tstest/kernel_linux.go +++ b/tstest/kernel_linux.go @@ -20,8 +20,13 @@ func KernelVersion() (major, minor, patch int) { return 0, 0, 0 } release := unix.ByteSliceToString(uname.Release[:]) + return parseKernelVersion(release) +} - // Parse version string (e.g., "5.15.0-...") +// parseKernelVersion parses a Linux kernel version string like "6.12.73+deb13-amd64" +// or "5.15.0-76-generic" and returns the major, minor, and patch components. +// It returns (0, 0, 0) if the version cannot be parsed. +func parseKernelVersion(release string) (major, minor, patch int) { parts := strings.Split(release, ".") if len(parts) < 3 { return 0, 0, 0 @@ -37,9 +42,12 @@ func KernelVersion() (major, minor, patch int) { return 0, 0, 0 } - // Patch version may have additional info after a hyphen (e.g., "0-76-generic") - // Extract just the numeric part before any hyphen - patchStr, _, _ := strings.Cut(parts[2], "-") + // Patch version may have additional info after a hyphen or plus (e.g., "0-76-generic" or "41+deb13-amd64") + // Extract just the numeric part before any hyphen or plus + patchStr := parts[2] + if idx := strings.IndexAny(patchStr, "-+"); idx != -1 { + patchStr = patchStr[:idx] + } patch, err = strconv.Atoi(patchStr) if err != nil { diff --git a/tstest/kernel_linux_test.go b/tstest/kernel_linux_test.go new file mode 100644 index 000000000..9445ebe2c --- /dev/null +++ b/tstest/kernel_linux_test.go @@ -0,0 +1,34 @@ +// Copyright (c) Tailscale Inc & contributors +// SPDX-License-Identifier: BSD-3-Clause + +//go:build linux + +package tstest + +import "testing" + +func TestParseKernelVersion(t *testing.T) { + tests := []struct { + release string + major, minor, patch int + }{ + {"5.15.0-76-generic", 5, 15, 0}, + {"6.12.73+deb13-amd64", 6, 12, 73}, + {"6.1.0-18-amd64", 6, 1, 0}, + {"5.4.0", 5, 4, 0}, + {"6.8.12", 6, 8, 12}, + {"4.19.0+1", 4, 19, 0}, + {"6.12.41+deb13-amd64", 6, 12, 41}, + {"", 0, 0, 0}, + {"not-a-version", 0, 0, 0}, + {"1.2", 0, 0, 0}, + {"a.b.c", 0, 0, 0}, + } + for _, tt := range tests { + major, minor, patch := parseKernelVersion(tt.release) + if major != tt.major || minor != tt.minor || patch != tt.patch { + t.Errorf("parseKernelVersion(%q) = (%d, %d, %d), want (%d, %d, %d)", + tt.release, major, minor, patch, tt.major, tt.minor, tt.patch) + } + } +} diff --git a/tstest/largetailnet/largetailnet.go b/tstest/largetailnet/largetailnet.go new file mode 100644 index 000000000..73ec2da80 --- /dev/null +++ b/tstest/largetailnet/largetailnet.go @@ -0,0 +1,265 @@ +// Copyright (c) Tailscale Inc & contributors +// SPDX-License-Identifier: BSD-3-Clause + +// Package largetailnet provides reusable building blocks for in-process +// benchmarks and stress tests that drive a single tailnet client (typically a +// [tsnet.Server]) with a synthetic large-tailnet MapResponse stream. +// +// A [Streamer] takes over the map long-poll on a [testcontrol.Server] via the +// AltMapStream hook: it sends one initial MapResponse announcing the self +// node and N synthetic peers, and then forwards caller-supplied delta +// MapResponses on the same stream until ctx is done. +// +// The package is designed so that a benchmark can: +// +// - Build a [Streamer] with the desired peer count. +// - Stand up a [testcontrol.Server] with the streamer's [Streamer.AltMapStream] +// installed. +// - Stand up a [tsnet.Server] pointed at the testcontrol; its Up call +// blocks until the initial netmap has been processed. +// - Reset the benchmark timer and drive add/remove deltas with +// [Streamer.SendDelta] and [Streamer.AllocPeer]. +package largetailnet + +import ( + "context" + cryptorand "crypto/rand" + "fmt" + "net/netip" + "sync/atomic" + "time" + + "go4.org/mem" + "tailscale.com/net/tsaddr" + "tailscale.com/tailcfg" + "tailscale.com/tstest/integration/testcontrol" + "tailscale.com/types/key" +) + +// SelfUserID is the synthetic [tailcfg.UserID] assigned to the self node and +// to every initial peer produced by [Streamer]. Tests that build their own +// peers via [MakePeer] should pass this value. +const SelfUserID tailcfg.UserID = 1_000_000 + +// Streamer drives a controlled MapResponse stream to a single client via +// [testcontrol.Server.AltMapStream]. It synthesizes an initial netmap with N +// peers and forwards caller-supplied delta MapResponses on the same stream. +// +// A Streamer is single-shot: it expects exactly one map long-poll over its +// lifetime and is not safe for re-use across multiple clients. +type Streamer struct { + n int + derpMap *tailcfg.DERPMap + + started chan struct{} // closed when the alt-map-stream callback first fires + initialDone chan struct{} // closed after initial MapResponse has been written + deltas chan *tailcfg.MapResponse + + // nextID is the next free node ID. It starts at N+2 (1 is the self + // node, 2..N+1 are the initial peers) and is bumped by AllocPeer. + nextID atomic.Int64 +} + +// New constructs a Streamer that will produce an initial netmap with n peers +// and a self node when its AltMapStream callback first fires. derpMap is +// included verbatim in the initial MapResponse. +func New(n int, derpMap *tailcfg.DERPMap) *Streamer { + s := &Streamer{ + n: n, + derpMap: derpMap, + started: make(chan struct{}), + initialDone: make(chan struct{}), + // Buffered so a benchmark loop body that does send-then-wait + // doesn't block on the channel under steady state. + deltas: make(chan *tailcfg.MapResponse, 64), + } + s.nextID.Store(int64(n) + 2) + return s +} + +// AltMapStream returns a callback suitable for [testcontrol.Server.AltMapStream]. +// On the first streaming long-poll it sends the initial big MapResponse and +// then forwards deltas enqueued via [Streamer.SendDelta] until ctx is done. +// Non-streaming "lite" polls are answered with an empty MapResponse so they +// complete quickly. The streamer is single-shot: any later streaming polls +// are kept alive but produce no further messages. +func (s *Streamer) AltMapStream() testcontrol.AltMapStreamFunc { + return func(ctx context.Context, w testcontrol.MapStreamWriter, req *tailcfg.MapRequest) { + if !req.Stream { + _ = w.SendMapMessage(&tailcfg.MapResponse{}) + return + } + + select { + case <-s.started: + // Re-poll after the original stream ended. Keep the + // connection alive so the client doesn't churn. + <-ctx.Done() + return + default: + close(s.started) + } + + if err := s.sendInitial(w, req); err != nil { + // Make the failure loud rather than wedging the + // caller's [tsnet.Server.Up] on a silent retry loop. + panic(fmt.Sprintf("largetailnet: sendInitial: %v", err)) + } + close(s.initialDone) + + for { + select { + case <-ctx.Done(): + return + case mr := <-s.deltas: + if err := w.SendMapMessage(mr); err != nil { + <-ctx.Done() + return + } + } + } + } +} + +// AwaitInitialSent blocks until the initial big MapResponse has been written +// to the wire. Note this is not the same as "the client has finished +// processing it"; for that, callers should rely on [tsnet.Server.Up] +// returning, or watch the IPN bus. +func (s *Streamer) AwaitInitialSent(ctx context.Context) error { + select { + case <-ctx.Done(): + return ctx.Err() + case <-s.initialDone: + return nil + } +} + +// SendDelta enqueues mr for delivery on the active MapResponse stream. It +// blocks if the internal queue is full or the stream hasn't started yet. +func (s *Streamer) SendDelta(ctx context.Context, mr *tailcfg.MapResponse) error { + select { + case <-ctx.Done(): + return ctx.Err() + case s.deltas <- mr: + return nil + } +} + +// AllocPeer returns a fresh synthetic peer node with a never-before-used +// [tailcfg.NodeID]. It's intended for use in PeersChanged deltas. +func (s *Streamer) AllocPeer() *tailcfg.Node { + return MakePeer(tailcfg.NodeID(s.nextID.Add(1)-1), SelfUserID) +} + +// SelfNodeID returns the [tailcfg.NodeID] used for the self node in the +// initial netmap. +func (s *Streamer) SelfNodeID() tailcfg.NodeID { return 1 } + +// sendInitial writes the big initial MapResponse with s.n peers. +func (s *Streamer) sendInitial(w testcontrol.MapStreamWriter, req *tailcfg.MapRequest) error { + selfNodeID := s.SelfNodeID() + selfIP4 := node4(selfNodeID) + selfIP6 := node6(selfNodeID) + + peers := make([]*tailcfg.Node, 0, s.n) + for i := 0; i < s.n; i++ { + peers = append(peers, MakePeer(tailcfg.NodeID(i+2), SelfUserID)) + } + + now := time.Now().UTC() + selfNode := &tailcfg.Node{ + ID: selfNodeID, + StableID: "largetailnet-self", + Name: "self.largetailnet.ts.net.", + User: SelfUserID, + Key: req.NodeKey, + KeyExpiry: now.Add(24 * time.Hour), + Machine: randMachineKey(), // fake; client doesn't verify + DiscoKey: req.DiscoKey, + MachineAuthorized: true, + Addresses: []netip.Prefix{selfIP4, selfIP6}, + AllowedIPs: []netip.Prefix{selfIP4, selfIP6}, + CapMap: map[tailcfg.NodeCapability][]tailcfg.RawMessage{}, + } + + initial := &tailcfg.MapResponse{ + KeepAlive: false, + Node: selfNode, + DERPMap: s.derpMap, + Peers: peers, + PacketFilter: []tailcfg.FilterRule{{ + // Accept-all filter so the client isn't logging packet-filter + // failures; this is a benchmark harness, not a security test. + SrcIPs: []string{"*"}, + DstPorts: []tailcfg.NetPortRange{{IP: "*", Ports: tailcfg.PortRangeAny}}, + }}, + DNSConfig: &tailcfg.DNSConfig{}, + Domain: "largetailnet.ts.net", + UserProfiles: []tailcfg.UserProfile{{ + ID: SelfUserID, + LoginName: "largetailnet@example.com", + DisplayName: "largetailnet", + }}, + ControlTime: &now, + } + return w.SendMapMessage(initial) +} + +// MakePeer constructs a synthetic [tailcfg.Node] for the given NodeID and +// UserID. The peer's node/disco/machine keys are derived from random bytes +// via the *PublicFromRaw32 constructors rather than via key.New*().Public(), +// which avoids the per-peer Curve25519 ScalarBaseMult and lets the harness +// construct hundreds of thousands of peers in a few hundred milliseconds. +// The client never crypto-validates these keys in the bench, so opaque +// random bytes are sufficient. +func MakePeer(nid tailcfg.NodeID, user tailcfg.UserID) *tailcfg.Node { + v4, v6 := node4(nid), node6(nid) + name := fmt.Sprintf("peer-%d", nid) + return &tailcfg.Node{ + ID: nid, + StableID: tailcfg.StableNodeID(name), + Name: name + ".largetailnet.ts.net.", + Key: randNodeKey(), + MachineAuthorized: true, + DiscoKey: randDiscoKey(), + Machine: randMachineKey(), + Addresses: []netip.Prefix{v4, v6}, + AllowedIPs: []netip.Prefix{v4, v6}, + User: user, + // Hostinfo must be non-nil: LocalBackend.populatePeerStatus + // dereferences it via HostinfoView.Hostname unconditionally. + Hostinfo: (&tailcfg.Hostinfo{Hostname: name}).View(), + } +} + +func randNodeKey() key.NodePublic { + var b [32]byte + cryptorand.Read(b[:]) + return key.NodePublicFromRaw32(mem.B(b[:])) +} + +func randDiscoKey() key.DiscoPublic { + var b [32]byte + cryptorand.Read(b[:]) + return key.DiscoPublicFromRaw32(mem.B(b[:])) +} + +func randMachineKey() key.MachinePublic { + var b [32]byte + cryptorand.Read(b[:]) + return key.MachinePublicFromRaw32(mem.B(b[:])) +} + +func node4(nid tailcfg.NodeID) netip.Prefix { + return netip.PrefixFrom( + netip.AddrFrom4([4]byte{100, 100 + byte(nid>>16), byte(nid >> 8), byte(nid)}), + 32) +} + +func node6(nid tailcfg.NodeID) netip.Prefix { + a := tsaddr.TailscaleULARange().Addr().As16() + a[13] = byte(nid >> 16) + a[14] = byte(nid >> 8) + a[15] = byte(nid) + return netip.PrefixFrom(netip.AddrFrom16(a), 128) +} diff --git a/tstest/largetailnet/largetailnet_test.go b/tstest/largetailnet/largetailnet_test.go new file mode 100644 index 000000000..07f67df82 --- /dev/null +++ b/tstest/largetailnet/largetailnet_test.go @@ -0,0 +1,218 @@ +// Copyright (c) Tailscale Inc & contributors +// SPDX-License-Identifier: BSD-3-Clause + +package largetailnet_test + +import ( + "context" + "flag" + "net/http/httptest" + "os" + "path/filepath" + "runtime" + "testing" + "time" + + "tailscale.com/ipn/store/mem" + "tailscale.com/tailcfg" + "tailscale.com/tsnet" + "tailscale.com/tstest/integration" + "tailscale.com/tstest/integration/testcontrol" + "tailscale.com/tstest/largetailnet" + "tailscale.com/types/logger" +) + +// tsnet.Server.Up handles the wait-for-ipn.Running step itself: it +// subscribes to the IPN bus with NotifyInitialState and blocks until State +// reaches ipn.Running, which by definition means a netmap has been applied. +// We don't redo that work here. + +var ( + flagActuallyTest = flag.Bool("actually-test-giant-tailnet", false, + "if set, run the BenchmarkGiantTailnet* benchmarks; otherwise they are skipped") + flagN = flag.Int("giant-tailnet-n", 250_000, + "size of the initial netmap (peer count) for BenchmarkGiantTailnet*") + flagBenchVerbose = flag.Bool("giant-tailnet-verbose", false, + "if set, log tsnet output and DERP setup to stderr") +) + +// BenchmarkGiantTailnet measures the per-delta CPU cost of a tailnet client +// processing peer-add/peer-remove deltas in steady state, with no IPN bus +// subscribers attached. This represents the headless-tailscaled workload +// (Linux subnet routers, container sidecars, ...) where the LocalBackend +// does not pay for fanning Notify.NetMap out to GUI watchers. +// +// Use [BenchmarkGiantTailnetBusWatcher] for the GUI-client workload. +// +// The benchmark is opt-in via --actually-test-giant-tailnet. +func BenchmarkGiantTailnet(b *testing.B) { + if !*flagActuallyTest { + b.Skip("set --actually-test-giant-tailnet to run this benchmark") + } + benchGiantTailnet(b, false) +} + +// BenchmarkGiantTailnetBusWatcher is like [BenchmarkGiantTailnet] but +// attaches one [local.Client.WatchIPNBus] subscriber for the duration of the +// benchmark. The Notify-fan-out cost (notably Notify.NetMap encoding to +// every watcher on every full-rebuild path) is therefore included in the +// per-delta measurement, which approximates the GUI-client workload. +// +// The benchmark is opt-in via --actually-test-giant-tailnet. +func BenchmarkGiantTailnetBusWatcher(b *testing.B) { + if !*flagActuallyTest { + b.Skip("set --actually-test-giant-tailnet to run this benchmark") + } + benchGiantTailnet(b, true) +} + +// benchGiantTailnet is the shared body of the BenchmarkGiantTailnet* +// benchmarks. Setup is entirely in-process: a [testcontrol.Server] hosts +// the control plane, a [tsnet.Server] hosts the client, and a +// [largetailnet.Streamer] hijacks the map long-poll to drive an exact +// MapResponse sequence. +// +// Each loop iteration sends one [tailcfg.MapResponse] with PeersChanged +// (a fresh peer) and PeersRemoved (the previous fresh peer), then waits +// for the client to apply it. Net peer count stays at flagN throughout the +// loop. +// +// The wait mechanism differs by variant: +// +// - busWatcher=false: block on a channel returned by +// [ipnlocal.LocalBackend.AwaitNodeKeyForTest] (reached via +// [tsnet.TestHooks]). The channel is closed by LocalBackend the moment +// the just-added peer's key appears in the netmap, so the wait has zero +// polling overhead. +// - busWatcher=true: drain Notify events from the bus subscription, since +// a Notify firing is exactly the side-effect we want to amortize into +// the per-delta measurement. +// +// Recommended invocation for profiling on unmodified main: +// +// go test ./tstest/largetailnet/ -run=^$ \ +// -bench='BenchmarkGiantTailnet(BusWatcher)?$' \ +// -benchtime=2000x -timeout=10m \ +// --actually-test-giant-tailnet \ +// --giant-tailnet-n=250000 \ +// -cpuprofile=/tmp/giant.cpu.pprof +func benchGiantTailnet(b *testing.B, busWatcher bool) { + logf := logger.Discard + if *flagBenchVerbose { + logf = b.Logf + } + + ctx, cancel := context.WithTimeout(context.Background(), 30*time.Minute) + b.Cleanup(cancel) + + derpMap := integration.RunDERPAndSTUN(b, logf, "127.0.0.1") + + streamer := largetailnet.New(*flagN, derpMap) + + ctrl := &testcontrol.Server{ + DERPMap: derpMap, + DNSConfig: &tailcfg.DNSConfig{}, + AltMapStream: streamer.AltMapStream(), + Logf: logf, + } + ctrl.HTTPTestServer = httptest.NewUnstartedServer(ctrl) + ctrl.HTTPTestServer.Start() + b.Cleanup(ctrl.HTTPTestServer.Close) + controlURL := ctrl.HTTPTestServer.URL + b.Logf("testcontrol listening on %s", controlURL) + + tmp := filepath.Join(b.TempDir(), "tsnet") + if err := os.MkdirAll(tmp, 0755); err != nil { + b.Fatal(err) + } + + s := &tsnet.Server{ + Dir: tmp, + ControlURL: controlURL, + Hostname: "largetailnet-bench", + Store: new(mem.Store), + Ephemeral: true, + Logf: logf, + } + b.Cleanup(func() { s.Close() }) + + // tsnet.Server.Up blocks until the backend reaches Running, which + // requires the initial flagN-peer MapResponse to have been processed. + upStart := time.Now() + if _, err := s.Up(ctx); err != nil { + b.Fatalf("tsnet.Server.Up: %v", err) + } + b.Logf("initial %d-peer netmap processed in %v", *flagN, time.Since(upStart)) + + lc, err := s.LocalClient() + if err != nil { + b.Fatalf("LocalClient: %v", err) + } + lb := tsnet.TestHooks.LocalBackend(s) + + var notifyCh chan struct{} + if busWatcher { + bw, err := lc.WatchIPNBus(ctx, 0) + if err != nil { + b.Fatalf("WatchIPNBus: %v", err) + } + b.Cleanup(func() { bw.Close() }) + notifyCh = make(chan struct{}, 1024) + go func() { + for { + n, err := bw.Next() + if err != nil { + return + } + if n.NetMap != nil || len(n.PeerChanges) > 0 { + select { + case notifyCh <- struct{}{}: + default: + } + } + } + }() + } + + var prevAdded *tailcfg.Node + runtime.GC() + + b.ResetTimer() + for b.Loop() { + added := streamer.AllocPeer() + mr := &tailcfg.MapResponse{ + PeersChanged: []*tailcfg.Node{added}, + } + if prevAdded != nil { + mr.PeersRemoved = []tailcfg.NodeID{prevAdded.ID} + } + prevAdded = added + + if err := streamer.SendDelta(ctx, mr); err != nil { + b.Fatalf("SendDelta: %v", err) + } + + if busWatcher { + // A Notify firing is itself part of the workload we + // want to measure on this variant. + select { + case <-notifyCh: + case <-time.After(10 * time.Second): + b.Fatal("timed out waiting for notify") + case <-ctx.Done(): + b.Fatalf("ctx done waiting for notify: %v", ctx.Err()) + } + } else { + // Block on the LocalBackend's test-only signal that + // the just-added peer key has landed in the netmap. + // No polling, no notify fan-out cost. + select { + case <-lb.AwaitNodeKeyForTest(added.Key): + case <-time.After(10 * time.Second): + b.Fatalf("timed out waiting for node key %v", added.Key) + case <-ctx.Done(): + b.Fatalf("ctx done waiting for node key: %v", ctx.Err()) + } + } + } +} diff --git a/tstest/natlab/vmtest/assets/event.html b/tstest/natlab/vmtest/assets/event.html new file mode 100644 index 000000000..a5f596673 --- /dev/null +++ b/tstest/natlab/vmtest/assets/event.html @@ -0,0 +1,45 @@ +{{if eq .Type "test_status"}} +{{.Message}} ({{.Detail}}) +{{end}} + +{{if eq .Type "step_changed"}} +
+ {{.Step.Status.Icon}} + {{.Step.Name}} + {{formatDuration .Step.Elapsed}} +
+{{end}} + +{{if eq .Type "console_output"}} +
{{ansi .Message}} +
+{{end}} + +{{if eq .Type "dhcp_discover"}} +Discover sent +{{end}} + +{{if eq .Type "dhcp_offer"}} +Offered {{.Detail}} +{{end}} + +{{if eq .Type "dhcp_request"}} +Requesting {{.Detail}} +{{end}} + +{{if eq .Type "dhcp_ack"}} +Got {{.Detail}} +{{end}} + +{{if eq .Type "tailscale"}} +{{.Detail}} +{{end}} + +{{if eq .Type "screenshot"}} +
+{{end}} + +{{if ne .Type "screenshot"}} +
{{.Time.Format "15:04:05.000"}} {{if .NodeName}}[{{.NodeName}}] {{end}}{{.Message}}{{if .Detail}} {{.Detail}}{{end}}
+
+{{end}} diff --git a/tstest/natlab/vmtest/assets/index.html b/tstest/natlab/vmtest/assets/index.html new file mode 100644 index 000000000..044efffee --- /dev/null +++ b/tstest/natlab/vmtest/assets/index.html @@ -0,0 +1,112 @@ + + + + + VMTest: {{.TestName}} + + + + + + +

VMTest: {{.TestName}} {{.TestStatus.State}} ({{formatDuration .TestStatus.Elapsed}})

+ +
+

Progress

+ {{range .Steps}} +
+ {{.Status.Icon}} + {{.Name}} + {{if ne .Status.String "pending"}}{{formatDuration .Elapsed}}{{end}} +
+ {{end}} +
+ +
+ {{range $node := .Nodes}} +
+
+ {{$node.Name}} + {{$node.OS}} +
+
+ {{range $i, $nic := $node.NICs}} +
+ DHCP{{if gt (len $node.NICs) 1}} ({{$nic.NetName}}){{end}}: + {{$nic.DHCP}} +
+ {{end}} + {{if $node.JoinsTailnet}} +
+ Tailscale: + {{$node.Tailscale}} +
+ {{end}} +
+
{{if $node.Screenshot}}{{end}}
+
{{range $node.Console}}{{ansi .}} +{{end}}
+
+ {{end}} +
+ +
+

Events

+
+
+ + + + + diff --git a/tstest/natlab/vmtest/assets/style.css b/tstest/natlab/vmtest/assets/style.css new file mode 100644 index 000000000..5970598b8 --- /dev/null +++ b/tstest/natlab/vmtest/assets/style.css @@ -0,0 +1,182 @@ +/* CSS reset */ +*, *::before, *::after { box-sizing: border-box; } +* { margin: 0; } +body { + font-family: -apple-system, BlinkMacSystemFont, "Segoe UI", Roboto, sans-serif; + line-height: 1.5; + background: #1a1a2e; + color: #e0e0e0; + padding: 16px; +} + +h1 { + font-size: 1.4em; + margin-bottom: 16px; + color: #fff; +} + +.test-status { + font-size: 0.7em; + padding: 2px 10px; + border-radius: 4px; + font-weight: bold; + vertical-align: middle; +} + +.test-Running { background: #2563eb; color: #fff; } +.test-Passed { background: #16a34a; color: #fff; } +.test-Failed { background: #dc2626; color: #fff; } + +h2 { + font-size: 1.1em; + margin-bottom: 8px; + color: #ccc; +} + +/* Step progress panel */ +.steps { + background: #16213e; + border: 1px solid #333; + border-radius: 6px; + padding: 12px; + margin-bottom: 16px; +} + +.step { + display: flex; + align-items: center; + gap: 8px; + padding: 4px 8px; + font-family: monospace; + font-size: 13px; + border-radius: 3px; +} + +.step-pending { color: #666; } +.step-running { color: #4af; font-weight: bold; background: rgba(68, 170, 255, 0.1); } +.step-done { color: #4a4; } +.step-failed { color: #f44; font-weight: bold; background: rgba(255, 68, 68, 0.1); } + +.step-icon { width: 1.2em; text-align: center; } +.step-name { flex: 1; } +.step-time { color: #666; font-size: 12px; min-width: 6em; text-align: right; } + +/* VM card grid */ +.vm-grid { + display: grid; + grid-template-columns: repeat(auto-fit, minmax(400px, 1fr)); + gap: 12px; + margin-bottom: 16px; +} + +.vm-card { + background: #16213e; + border: 1px solid #333; + border-radius: 6px; + padding: 12px; +} + +.vm-header { + display: flex; + align-items: center; + gap: 8px; + margin-bottom: 8px; +} + +.vm-name { + font-weight: bold; + font-size: 1.1em; + color: #fff; +} + +.vm-os { + font-size: 0.8em; + background: #333; + padding: 1px 6px; + border-radius: 3px; + color: #aaa; +} + +.vm-status { + display: flex; + flex-direction: column; + gap: 2px; + margin-bottom: 8px; + font-family: monospace; + font-size: 13px; +} + +.vm-status-line { + display: flex; + gap: 8px; +} + +.vm-status-label { + color: #888; + min-width: 7em; +} + +.vm-status-value { + color: #4af; +} + +/* VM display screenshot */ +.screenshot:empty { display: none; } +.screenshot { + margin-bottom: 4px; +} +.screenshot img { + width: 100%; + height: auto; + display: block; + border-radius: 4px; + border: 1px solid #222; + cursor: pointer; +} + +/* Console output */ +.console { + background: #0a0a0a; + color: #ccc; + font-family: "Cascadia Code", "Fira Code", "Consolas", monospace; + font-size: 11px; + line-height: 1.3; + max-height: 300px; + overflow-y: auto; + white-space: pre-wrap; + word-break: break-all; + padding: 8px; + border-radius: 4px; + border: 1px solid #222; +} + +/* Event log */ +.event-log { + background: #16213e; + border: 1px solid #333; + border-radius: 6px; + padding: 12px; +} + +.events { + max-height: 300px; + overflow-y: auto; +} + +.event { + font-family: monospace; + font-size: 12px; + padding: 1px 0; + border-bottom: 1px solid #1a1a2e; +} + +.event-time { color: #666; } +.event-node { color: #4af; font-weight: bold; } +.event-msg { color: #ccc; } +.event-detail { color: #888; } + +.event-dhcp_discover .event-msg, +.event-dhcp_request .event-msg { color: #fa4; } +.event-dhcp_offer .event-msg, +.event-dhcp_ack .event-msg { color: #4f4; } +.event-step_changed .event-msg { color: #aaf; } diff --git a/tstest/natlab/vmtest/cloudinit.go b/tstest/natlab/vmtest/cloudinit.go index 334863f9c..f0ef704fe 100644 --- a/tstest/natlab/vmtest/cloudinit.go +++ b/tstest/natlab/vmtest/cloudinit.go @@ -120,8 +120,11 @@ func (e *Env) generateLinuxUserData(n *Node) string { ud.WriteString(" - [\"sysctl\", \"-w\", \"net.ipv6.conf.all.forwarding=1\"]\n") } - // Start tailscaled in the background. - ud.WriteString(" - [\"/bin/sh\", \"-c\", \"/usr/local/bin/tailscaled --state=mem: &\"]\n") + // Start tailscaled in the background. --statedir provides a VarRoot so + // features like Taildrop (which needs a place to stash incoming files) + // have a directory to work with. + ud.WriteString(" - [\"mkdir\", \"-p\", \"/var/lib/tailscale\"]\n") + ud.WriteString(" - [\"/bin/sh\", \"-c\", \"/usr/local/bin/tailscaled --state=mem: --statedir=/var/lib/tailscale &\"]\n") ud.WriteString(" - [\"sleep\", \"2\"]\n") // Start tta (Tailscale Test Agent). @@ -170,14 +173,21 @@ func (e *Env) generateFreeBSDUserData(n *Node) string { ud.WriteString(" - \"sysctl net.inet6.ip6.forwarding=1\"\n") } - // Start tailscaled and tta in the background. - // Set PATH to include /usr/local/bin so that tta can find "tailscale" - // (TTA uses exec.Command("tailscale", ...) without a full path). - ud.WriteString(" - \"export PATH=/usr/local/bin:$PATH && /usr/local/bin/tailscaled --state=mem: &\"\n") + // Start tailscaled and tta in the background. Redirect stdio to log + // files and away from /dev/null on stdin; otherwise nuageinit's runcmd + // executor keeps the backgrounded child's stdout/stderr pipes open and + // blocks waiting for them, so subsequent runcmd entries (including the + // tta launch below) never run. Linux cloud-init doesn't have this + // gotcha. Set PATH to include /usr/local/bin so that tta can find + // "tailscale" (TTA uses exec.Command("tailscale", ...) without a full + // path). --statedir provides a VarRoot so features like Taildrop have a + // directory. + ud.WriteString(" - \"mkdir -p /var/lib/tailscale\"\n") + ud.WriteString(" - \"export PATH=/usr/local/bin:$PATH && /usr/local/bin/tailscaled --state=mem: --statedir=/var/lib/tailscale /var/log/tailscaled.log 2>&1 &\"\n") ud.WriteString(" - \"sleep 2\"\n") - // Start tta (Tailscale Test Agent). - ud.WriteString(" - \"export PATH=/usr/local/bin:$PATH && /usr/local/bin/tta &\"\n") + // Start tta (Tailscale Test Agent), with the same stdio redirection. + ud.WriteString(" - \"export PATH=/usr/local/bin:$PATH && /usr/local/bin/tta /var/log/tta.log 2>&1 &\"\n") return ud.String() } diff --git a/tstest/natlab/vmtest/cmd/natlabprep/natlabprep.go b/tstest/natlab/vmtest/cmd/natlabprep/natlabprep.go new file mode 100644 index 000000000..d60b878e3 --- /dev/null +++ b/tstest/natlab/vmtest/cmd/natlabprep/natlabprep.go @@ -0,0 +1,24 @@ +// Copyright (c) Tailscale Inc & contributors +// SPDX-License-Identifier: BSD-3-Clause + +// The natlabprep tool warms the local natlab vmtest cache by downloading +// every cloud VM image natlab can boot. It is intended for CI prep steps +// so a subsequent test run does not pay the per-image download cost. +package main + +import ( + "context" + "log" + + "tailscale.com/tstest/natlab/vmtest" +) + +func main() { + ctx := context.Background() + for _, img := range vmtest.CloudImages() { + log.Printf("ensuring %s ...", img.Name) + if err := vmtest.EnsureImage(ctx, img); err != nil { + log.Fatalf("ensuring %s: %v", img.Name, err) + } + } +} diff --git a/tstest/natlab/vmtest/connectivity.go b/tstest/natlab/vmtest/connectivity.go new file mode 100644 index 000000000..60d6e0e36 --- /dev/null +++ b/tstest/natlab/vmtest/connectivity.go @@ -0,0 +1,30 @@ +// Copyright (c) Tailscale Inc & contributors +// SPDX-License-Identifier: BSD-3-Clause + +package vmtest + +import ( + "fmt" + "time" +) + +// AddNodeFunc is used to describe a func passed to [RunConnectivityTest]. +type AddNodeFunc func(*Env) *Node + +// RunConnectivityTest adds the specified nodes to the network and then +// verifies that a Disco ping from n1 to n2 completes within 30 seconds. +func (env *Env) RunConnectivityTest(name string, pingRoute PingRoute, n1, n2 AddNodeFunc) { + n1(env) + n2(env) + + discoPingStep := env.AddStep( + fmt.Sprintf("[%s] Ping a → b Disco (want %s)", name, pingRoute)) + env.Start() + + discoPingStep.Begin() + if err := env.PingExpect(env.nodes[0], env.nodes[1], pingRoute, 30*time.Second); err != nil { + discoPingStep.End(err) + env.t.Error(err) + } + discoPingStep.End(nil) +} diff --git a/tstest/natlab/vmtest/connectivity_test.go b/tstest/natlab/vmtest/connectivity_test.go new file mode 100644 index 000000000..f9be6589a --- /dev/null +++ b/tstest/natlab/vmtest/connectivity_test.go @@ -0,0 +1,257 @@ +// Copyright (c) Tailscale Inc & contributors +// SPDX-License-Identifier: BSD-3-Clause + +package vmtest_test + +import ( + "flag" + "fmt" + "testing" + + "tailscale.com/tailcfg" + "tailscale.com/tstest/natlab/vmtest" + "tailscale.com/tstest/natlab/vnet" +) + +var knownBroken = flag.Bool("known-broken", false, "run known-broken tests") + +func v6cidr(n int) string { + return fmt.Sprintf("2000:%d::1/64", n) +} + +func easy(env *vmtest.Env) *vmtest.Node { + n := env.NumNodes() + return env.AddNode(fmt.Sprintf("node-%d", n), + env.AddNetwork( + fmt.Sprintf("2.%d.%d.%d", n, n, n), // public IP + fmt.Sprintf("192.168.%d.1/24", n), vnet.EasyNAT), + vmtest.OS(vmtest.Gokrazy)) +} + +func easyAnd6(env *vmtest.Env) *vmtest.Node { + n := env.NumNodes() + return env.AddNode(fmt.Sprintf("node-%d", n), + env.AddNetwork( + fmt.Sprintf("2.%d.%d.%d", n, n, n), // public IP + fmt.Sprintf("192.168.%d.1/24", n), + v6cidr(n), + vnet.EasyNAT), + vmtest.OS(vmtest.Gokrazy)) +} + +// easyNoControlDiscoRotate sets up a node with easy NAT, cuts traffic to +// control after connecting, and then rotates the disco key to simulate a newly +// started node (from a disco perspective). +func easyNoControlDiscoRotate(env *vmtest.Env) *vmtest.Node { + n := env.NumNodes() + nw := env.AddNetwork( + fmt.Sprintf("2.%d.%d.%d", n, n, n), // public IP + fmt.Sprintf("192.168.%d.1/24", n), + vnet.EasyNAT) + nw.SetPostConnectControlBlackhole(true) + return env.AddNode(fmt.Sprintf("node-%d", n), + vnet.TailscaledEnv{Key: "TS_USE_CACHED_NETMAP", Value: "true"}, + vnet.RotateDisco, vnet.PreICMPPing, + nw, + vmtest.OS(vmtest.Gokrazy)) +} + +// easyFW is easy + host firewall. +func easyFW(env *vmtest.Env) *vmtest.Node { + n := env.NumNodes() + return env.AddNode(fmt.Sprintf("node-%d", n), + vnet.HostFirewall, + env.AddNetwork( + fmt.Sprintf("2.%d.%d.%d", n, n, n), // public IP + fmt.Sprintf("192.168.%d.1/24", n), vnet.EasyNAT), + vmtest.OS(vmtest.Gokrazy)) +} + +// easyPMPFWPlusBPF is easy + port mapping + host firewall + BPF. +func easyPMPFWPlusBPF(env *vmtest.Env) *vmtest.Node { + n := env.NumNodes() + return env.AddNode(fmt.Sprintf("node-%d", n), + vnet.HostFirewall, + vnet.TailscaledEnv{Key: "TS_ENABLE_RAW_DISCO", Value: "true"}, + vnet.TailscaledEnv{Key: "TS_DEBUG_RAW_DISCO", Value: "1"}, + vnet.TailscaledEnv{Key: "TS_DEBUG_DISCO", Value: "1"}, + vnet.TailscaledEnv{Key: "TS_LOG_VERBOSITY", Value: "2"}, + env.AddNetwork( + fmt.Sprintf("2.%d.%d.%d", n, n, n), // public IP + fmt.Sprintf("192.168.%d.1/24", n), vnet.EasyNAT, vnet.NATPMP), + vmtest.OS(vmtest.Gokrazy)) +} + +// easyPMPFWNoBPF is easy + port mapping + host firewall - BPF. +func easyPMPFWNoBPF(env *vmtest.Env) *vmtest.Node { + n := env.NumNodes() + return env.AddNode(fmt.Sprintf("node-%d", n), + vnet.HostFirewall, + vnet.TailscaledEnv{Key: "TS_ENABLE_RAW_DISCO", Value: "false"}, + env.AddNetwork( + fmt.Sprintf("2.%d.%d.%d", n, n, n), // public IP + fmt.Sprintf("192.168.%d.1/24", n), vnet.EasyNAT, vnet.NATPMP), + vmtest.OS(vmtest.Gokrazy)) +} + +func hard(env *vmtest.Env) *vmtest.Node { + n := env.NumNodes() + return env.AddNode(fmt.Sprintf("node-%d", n), + env.AddNetwork( + fmt.Sprintf("2.%d.%d.%d", n, n, n), // public IP + fmt.Sprintf("10.0.%d.1/24", n), vnet.HardNAT), + vmtest.OS(vmtest.Gokrazy)) +} + +func hardNoDERPOrEndpoints(env *vmtest.Env) *vmtest.Node { + n := env.NumNodes() + return env.AddNode(fmt.Sprintf("node-%d", n), + env.AddNetwork( + fmt.Sprintf("2.%d.%d.%d", n, n, n), // public IP + fmt.Sprintf("10.0.%d.1/24", n), vnet.HardNAT), + vnet.TailscaledEnv{Key: "TS_DEBUG_STRIP_ENDPOINTS", Value: "1"}, + vnet.TailscaledEnv{Key: "TS_DEBUG_STRIP_HOME_DERP", Value: "1"}, + vmtest.OS(vmtest.Gokrazy)) +} + +func just6(env *vmtest.Env) *vmtest.Node { + n := env.NumNodes() + return env.AddNode(fmt.Sprintf("node-%d", n), + env.AddNetwork(v6cidr(n)), // public IPv6 prefix + vmtest.OS(vmtest.Gokrazy)) +} + +func v6AndBlackholedIPv4(env *vmtest.Env) *vmtest.Node { + n := env.NumNodes() + nw := env.AddNetwork( + fmt.Sprintf("2.%d.%d.%d", n, n, n), + fmt.Sprintf("192.168.%d.1/24", n), + v6cidr(n), + vnet.EasyNAT) + nw.SetBlackholedIPv4(true) + return env.AddNode(fmt.Sprintf("node-%d", n), nw, vmtest.OS(vmtest.Gokrazy)) +} + +func TestEasyEasy(t *testing.T) { + env := vmtest.New(t) + env.RunConnectivityTest(t.Name(), vmtest.PingRouteDirect, easy, easy) +} + +// TestTwoEasyNoControlDiscoRotate tests a situation where two nodes have been +// online and connected through control, but then lose control access and also +// rotate keys. It is not a perfect proxy for a cached node, as the node will +// still have a mapState and not use the backup method of inserting keys into +// the engine directly. +func TestTwoEasyNoControlDiscoRotate(t *testing.T) { + env := vmtest.New(t) + env.RunConnectivityTest(t.Name(), vmtest.PingRouteDirect, easyNoControlDiscoRotate, easyNoControlDiscoRotate) +} + +func TestJustIPv6(t *testing.T) { + env := vmtest.New(t) + env.RunConnectivityTest(t.Name(), vmtest.PingRouteDirect, just6, just6) +} + +func TestEasy4AndJust6(t *testing.T) { + env := vmtest.New(t) + env.RunConnectivityTest(t.Name(), vmtest.PingRouteDirect, easyAnd6, just6) +} + +func TestSameLAN(t *testing.T) { + env := vmtest.New(t) + var sharedNW *vnet.Network + makeEasy := func(env *vmtest.Env) *vmtest.Node { + n := env.NumNodes() + sharedNW = env.AddNetwork( + fmt.Sprintf("2.%d.%d.%d", n, n, n), // public IP + fmt.Sprintf("192.168.%d.1/24", n), vnet.EasyNAT) + return env.AddNode(fmt.Sprintf("node-%d", n), sharedNW, vmtest.OS(vmtest.Gokrazy)) + } + sameLAN := func(env *vmtest.Env) *vmtest.Node { + n := env.NumNodes() + return env.AddNode(fmt.Sprintf("node-%d", n), sharedNW, vmtest.OS(vmtest.Gokrazy)) + } + env.RunConnectivityTest(t.Name(), vmtest.PingRouteLocal, makeEasy, sameLAN) +} + +// TestBPFDisco tests https://github.com/tailscale/tailscale/issues/3824 ... +// * server behind a Hard NAT +// * client behind a NAT with UPnP support +// * client machine has a stateful host firewall (e.g. ufw) +func TestBPFDisco(t *testing.T) { + env := vmtest.New(t) + env.RunConnectivityTest(t.Name(), vmtest.PingRouteDirect, easyPMPFWPlusBPF, hard) +} + +func TestHostFWNoBPF(t *testing.T) { + env := vmtest.New(t) + env.RunConnectivityTest(t.Name(), vmtest.PingRouteDERP, easyPMPFWNoBPF, hard) +} + +func TestHostFWPair(t *testing.T) { + env := vmtest.New(t) + env.RunConnectivityTest(t.Name(), vmtest.PingRouteDirect, easyFW, easyFW) +} + +func TestOneHostFW(t *testing.T) { + env := vmtest.New(t) + env.RunConnectivityTest(t.Name(), vmtest.PingRouteDirect, easy, easyFW) +} + +// Issue tailscale/corp#26438: use learned DERP route as send path of last +// resort +// +// See (*magicsock.Conn).fallbackDERPRegionForPeer and its comment for +// background. +// +// This sets up a test with two nodes that must use DERP to communicate but the +// target of the ping (the second node) additionally is not getting DERP or +// Endpoint updates from the control plane. (Or rather, it's getting them but is +// configured to scrub them right when they come off the network before being +// processed) This then tests whether node2, upon receiving a packet, will be +// able to reply to node1 since it knows neither node1's endpoints nor its home +// DERP. The only reply route it can use is that fact that it just received a +// packet over a particular DERP from that peer. +func TestFallbackDERPRegionForPeer(t *testing.T) { + env := vmtest.New(t) + env.RunConnectivityTest(t.Name(), vmtest.PingRouteDERP, hard, hardNoDERPOrEndpoints) +} + +// TestSingleJustIPv6 tests that a node can connect to control with just IPv6. +// Since there is no connectivity testing needed, the test just asserts the +// node coming up which will be asserted by env.Start(). +func TestSingleJustIPv6(t *testing.T) { + env := vmtest.New(t) + just6(env) + env.Start() +} + +// TestSingleDualBrokenIPv4 tests a dual-stack node with broken +// (blackholed) IPv4. +// +// See https://github.com/tailscale/tailscale/issues/13346 +func TestSingleDualBrokenIPv4(t *testing.T) { + if !*knownBroken { + t.Skip("skipping known-broken test; set --known-broken to run; see https://github.com/tailscale/tailscale/issues/13346") + } + env := vmtest.New(t) + v6AndBlackholedIPv4(env) + env.Start() +} + +func TestNonTailscaleCGNATEndpoint(t *testing.T) { + env := vmtest.New(t) + + cgnatNW := env.AddNetwork("100.65.1.1/16", "2.1.1.1", vnet.EasyNAT) + n0 := env.AddNode("node-0", + cgnatNW, + vmtest.DontJoinTailnet(), + vmtest.OS(vmtest.Gokrazy)) + n1 := env.AddNode("node-1", + cgnatNW, + tailcfg.NodeCapMap{tailcfg.NodeAttrDisableLinuxCGNATDropRule: nil}, + vmtest.OS(vmtest.Gokrazy)) + + env.Start() + env.LANPing(n1, n0.LanIP(cgnatNW)) +} diff --git a/tstest/natlab/vmtest/images.go b/tstest/natlab/vmtest/images.go index 49eba443f..e01bf1d71 100644 --- a/tstest/natlab/vmtest/images.go +++ b/tstest/natlab/vmtest/images.go @@ -26,10 +26,14 @@ type OSImage struct { SHA256 string // expected SHA256 hash of the image (of the final qcow2, after any decompression) MemoryMB int // RAM for the VM IsGokrazy bool // true for gokrazy images (different QEMU setup) + IsMacOS bool // true for macOS images (launched via tailmac, not QEMU) } // GOOS returns the Go OS name for this image. func (img OSImage) GOOS() string { + if img.IsMacOS { + return "darwin" + } if img.IsGokrazy { return "linux" } @@ -41,6 +45,9 @@ func (img OSImage) GOOS() string { // GOARCH returns the Go architecture name for this image. func (img OSImage) GOARCH() string { + if img.IsMacOS { + return "arm64" + } return "amd64" } @@ -73,8 +80,33 @@ var ( URL: "https://download.freebsd.org/releases/VM-IMAGES/15.0-RELEASE/amd64/Latest/FreeBSD-15.0-RELEASE-amd64-BASIC-CLOUDINIT-ufs.qcow2.xz", MemoryMB: 1024, } + + // MacOS is a macOS VM launched via tailmac (Apple Virtualization.framework). + // Uses a Tart pre-built base image (ghcr.io/cirruslabs/macos-tahoe-base) + // which is automatically pulled on first use. Only runs on macOS arm64 hosts. + MacOS = OSImage{ + Name: "macos", + IsMacOS: true, + MemoryMB: 4096, + } ) +// CloudImages returns the set of QEMU-bootable cloud OS images natlab can +// use for vmtests, excluding gokrazy (built from source) and macOS (which +// uses a separate snapshot pipeline). It is intended for tooling such as +// a CI prep step that wants to warm the image cache. +func CloudImages() []OSImage { + return []OSImage{Ubuntu2404, Debian12, FreeBSD150} +} + +// EnsureImage downloads img to the local cache if not already present. +// It is intended for tooling that wants to warm the image cache before +// running natlab vmtests (e.g. a CI prep step). The test framework also +// calls into the package-internal equivalent on demand. +func EnsureImage(ctx context.Context, img OSImage) error { + return ensureImage(ctx, img) +} + // imageCacheDir returns the directory for cached VM images. func imageCacheDir() string { if d := os.Getenv("VMTEST_CACHE_DIR"); d != "" { diff --git a/tstest/natlab/vmtest/qemu.go b/tstest/natlab/vmtest/qemu.go index df56322fa..73b265078 100644 --- a/tstest/natlab/vmtest/qemu.go +++ b/tstest/natlab/vmtest/qemu.go @@ -5,6 +5,7 @@ package vmtest import ( "bytes" + "context" "encoding/json" "fmt" "net" @@ -12,18 +13,89 @@ import ( "os/exec" "path/filepath" "regexp" + "runtime" "strconv" + "strings" + "testing" "time" "tailscale.com/tstest/natlab/vnet" ) -// startQEMU launches a QEMU process for the given node. -func (e *Env) startQEMU(n *Node) error { - if n.os.IsGokrazy { - return e.startGokrazyQEMU(n) +// qemuAccelArgs returns QEMU command-line flags for hardware-accelerated +// virtualisation when available, or nil to fall back to TCG (software +// emulation). On Linux, KVM is used when /dev/kvm is accessible. On other +// platforms (macOS, etc.) TCG is used, which allows the tests to run +// without a same-architecture hypervisor at the cost of speed. +func qemuAccelArgs() []string { + if runtime.GOOS == "linux" { + if f, err := os.OpenFile("/dev/kvm", os.O_RDWR, 0); err == nil { + f.Close() + return []string{"-enable-kvm", "-cpu", "host"} + } } - return e.startCloudQEMU(n) + return nil +} + +// gokrazyPlatform boots gokrazy (Linux) VMs via QEMU. +type gokrazyPlatform struct{} + +func (gokrazyPlatform) planSteps(e *Env, n *Node) { + e.Step("Build gokrazy image") + e.Step("Launch QEMU: " + n.name) +} + +func (gokrazyPlatform) boot(ctx context.Context, e *Env, n *Node) error { + e.gokrazyOnce.Do(func() { + step := e.Step("Build gokrazy image") + step.Begin() + if err := e.ensureGokrazy(ctx); err != nil { + step.End(err) + e.t.Fatalf("ensureGokrazy: %v", err) + } + step.End(nil) + }) + + e.ensureQEMUSocket() + + vmStep := e.Step("Launch QEMU: " + n.name) + vmStep.Begin() + if err := e.startGokrazyQEMU(n); err != nil { + vmStep.End(err) + return err + } + vmStep.End(nil) + return nil +} + +// qemuCloudPlatform boots cloud images (Ubuntu, Debian, FreeBSD) via QEMU. +type qemuCloudPlatform struct{} + +func (qemuCloudPlatform) planSteps(e *Env, n *Node) { + e.Step(fmt.Sprintf("Compile %s_%s binaries", n.os.GOOS(), n.os.GOARCH())) + e.Step(fmt.Sprintf("Prepare %s image", n.os.Name)) + e.Step("Launch QEMU: " + n.name) +} + +func (qemuCloudPlatform) boot(ctx context.Context, e *Env, n *Node) error { + goos, goarch := n.os.GOOS(), n.os.GOARCH() + + e.ensureCompiled(ctx, goos, goarch) + + if err := e.ensureImage(ctx, n.os); err != nil { + return err + } + + e.ensureQEMUSocket() + + vmStep := e.Step("Launch QEMU: " + n.name) + vmStep.Begin() + if err := e.startCloudQEMU(n); err != nil { + vmStep.End(err) + return err + } + vmStep.End(nil) + return nil } // startGokrazyQEMU launches a QEMU process for a gokrazy node. @@ -40,6 +112,7 @@ func (e *Env) startGokrazyQEMU(n *Node) error { } sysLogAddr := net.JoinHostPort(vnet.FakeSyslogIPv4().String(), "995") if n.vnetNode.IsV6Only() { + fmt.Fprintf(&envBuf, " tta.nameserver=%s", vnet.FakeDNSIPv6()) sysLogAddr = net.JoinHostPort(vnet.FakeSyslogIPv6().String(), "995") } @@ -69,6 +142,7 @@ func (e *Env) startGokrazyQEMU(n *Node) error { ) } + args = append(args, qemuAccelArgs()...) return e.launchQEMU(n.name, logPath, args) } @@ -89,12 +163,11 @@ func (e *Env) startCloudQEMU(n *Node) error { } logPath := filepath.Join(e.tempDir, n.name+".log") - qmpSock := filepath.Join(e.tempDir, n.name+"-qmp.sock") + qmpSock := filepath.Join(e.sockDir, n.name+"-qmp.sock") args := []string{ - "-machine", "q35,accel=kvm", + "-machine", "q35", "-m", fmt.Sprintf("%dM", n.os.MemoryMB), - "-cpu", "host", "-smp", "2", "-display", "none", "-drive", fmt.Sprintf("file=%s,if=virtio", disk), @@ -123,6 +196,8 @@ func (e *Env) startCloudQEMU(n *Node) error { "-device", "virtio-net-pci,netdev=debug0,romfile=", ) + args = append(args, qemuAccelArgs()...) + if err := e.launchQEMU(n.name, logPath, args); err != nil { return err } @@ -137,55 +212,168 @@ func (e *Env) startCloudQEMU(n *Node) error { return nil } -// launchQEMU starts a qemu-system-x86_64 process with the given args. +// qemuRun is one running qemu-system-x86_64 process plus the file handles +// the wrapping code holds open on its behalf. kill tears the whole thing +// down (used both for normal cleanup and for the in-flight retry path). +type qemuRun struct { + cmd *exec.Cmd + parentPipe *os.File + devNull *os.File + qemuLog *os.File +} + +func (r *qemuRun) kill() { + killProcessTree(r.cmd) + r.cmd.Wait() + r.parentPipe.Close() + r.devNull.Close() + r.qemuLog.Close() +} + +// launchQEMU starts a qemu-system-x86_64 process with the given args and +// watches for console activity. If the guest produces no output within +// stuckTimeout (empty console *and* QEMU has not exited with an error), +// the QEMU process is killed and re-launched. This works around CI +// hypervisor flakes seen on shared GitHub Actions runners where a QEMU +// process starts but its vCPU never makes any forward progress (the +// failure presents as both the virtconsole log and the QEMU stderr log +// being zero bytes after many minutes, with the vnet stream socket +// connected but no packet ever sent). +// // VM console output goes to logPath (via QEMU's -serial or -chardev). // QEMU's own stdout/stderr go to logPath.qemu for diagnostics. func (e *Env) launchQEMU(name, logPath string, args []string) error { + // stuckTimeout is generous: a healthy VM prints SeaBIOS/kernel + // output within ~1-2s on KVM, but slow shared CI hardware can lag. + // Setting it too low risks killing a healthy-but-slow VM; setting it + // too high masks the wedge case we want to recover from. + const stuckTimeout = 45 * time.Second + const maxAttempts = 3 + + var lastErr error + for attempt := 1; attempt <= maxAttempts; attempt++ { + if attempt > 1 { + e.t.Logf("[%s] QEMU made no progress in %v; killing and retrying (attempt %d/%d)", name, stuckTimeout, attempt, maxAttempts) + // QEMU's -chardev file backend opens append-mode, so stale + // bytes from a previous attempt would falsely trip the + // progress check on retry. Truncate it. + os.Truncate(logPath, 0) + } + run, err := e.startQEMUOnce(name, logPath, args) + if err != nil { + lastErr = err + continue + } + if waitForConsoleProgress(logPath, stuckTimeout) { + e.qemuProcs = append(e.qemuProcs, run.cmd) + if e.ctx != nil { + go e.tailLogFile(e.ctx, name, logPath) + } + e.t.Cleanup(func() { + run.kill() + // Dump tail of VM log and QEMU's own stderr on failure. + // The console log (logPath) is empty when the guest never + // produced output (e.g. QEMU exited before the kernel ran); + // in that case the .qemu file holds the only diagnostic — + // KVM errors, "kvm not available", CPU model mismatch, etc. + if e.t.Failed() { + dumpLogTail(e.t, name, "console", logPath) + dumpLogTail(e.t, name, "qemu stderr", logPath+".qemu") + } + }) + return nil + } + lastErr = fmt.Errorf("QEMU for %s produced no console output in %v", name, stuckTimeout) + run.kill() + } + return fmt.Errorf("QEMU for %s failed after %d attempts: %w", name, maxAttempts, lastErr) +} + +// startQEMUOnce starts a single qemu-system-x86_64 process. On success the +// returned qemuRun owns the process and all file handles; the caller must +// invoke kill (either inline for a retry or via t.Cleanup for the +// surviving attempt). +func (e *Env) startQEMUOnce(name, logPath string, args []string) (*qemuRun, error) { cmd := exec.Command("qemu-system-x86_64", args...) - // Send stdout/stderr to the log file for any QEMU diagnostic messages. - // Stdin must be /dev/null to prevent QEMU from trying to read. devNull, err := os.Open(os.DevNull) if err != nil { - return fmt.Errorf("open /dev/null: %w", err) + return nil, fmt.Errorf("open /dev/null: %w", err) } cmd.Stdin = devNull qemuLog, err := os.Create(logPath + ".qemu") if err != nil { devNull.Close() - return err + return nil, err } cmd.Stdout = qemuLog cmd.Stderr = qemuLog - if err := cmd.Start(); err != nil { + parentPipe, err := killWithParent(cmd) + if err != nil { devNull.Close() qemuLog.Close() - return fmt.Errorf("qemu for %s: %w", name, err) + return nil, fmt.Errorf("killWithParent: %w", err) + } + if err := cmd.Start(); err != nil { + parentPipe.Close() + devNull.Close() + qemuLog.Close() + return nil, fmt.Errorf("qemu for %s: %w", name, err) } e.t.Logf("launched QEMU for %s (pid %d), log: %s", name, cmd.Process.Pid, logPath) - e.qemuProcs = append(e.qemuProcs, cmd) - e.t.Cleanup(func() { - cmd.Process.Kill() - cmd.Wait() - devNull.Close() - qemuLog.Close() - // Dump tail of VM log on failure for debugging. - if e.t.Failed() { - if data, err := os.ReadFile(logPath); err == nil { - lines := bytes.Split(data, []byte("\n")) - start := 0 - if len(lines) > 50 { - start = len(lines) - 50 - } - e.t.Logf("=== last 50 lines of %s log ===", name) - for _, line := range lines[start:] { - e.t.Logf("[%s] %s", name, line) - } - } - } - }) - return nil + return &qemuRun{ + cmd: cmd, + parentPipe: parentPipe, + devNull: devNull, + qemuLog: qemuLog, + }, nil } +// waitForConsoleProgress polls logPath until its size is non-zero or +// timeout elapses. It returns true on observed forward progress (any +// bytes written), false on timeout. +func waitForConsoleProgress(logPath string, timeout time.Duration) bool { + deadline := time.Now().Add(timeout) + for time.Now().Before(deadline) { + if fi, err := os.Stat(logPath); err == nil && fi.Size() > 0 { + return true + } + time.Sleep(200 * time.Millisecond) + } + return false +} + +// dumpLogTail prints the last 50 lines of the file at path to the test log, +// prefixed with the VM name and kind (e.g. "console", "qemu stderr"). It is +// a no-op (with a short note) if the file can't be read or is empty, so +// callers can use it unconditionally on test failure. +func dumpLogTail(t testing.TB, name, kind, path string) { + t.Helper() + data, err := os.ReadFile(path) + if err != nil { + t.Logf("=== %s %s log unavailable: %v ===", name, kind, err) + return + } + if len(data) == 0 { + t.Logf("=== %s %s log is empty ===", name, kind) + return + } + lines := bytes.Split(data, []byte("\n")) + start := 0 + if len(lines) > 50 { + start = len(lines) - 50 + } + t.Logf("=== last 50 lines of %s %s log ===", name, kind) + for _, line := range lines[start:] { + t.Logf("[%s] %s", name, line) + } +} + +// hostFwdRe matches a single TCP[HOST_FORWARD] line from QEMU's +// "info usernet" human-monitor command output, e.g.: +// +// TCP[HOST_FORWARD] 12 127.0.0.1 35323 10.0.2.15 22 +var hostFwdRe = regexp.MustCompile(`TCP\[HOST_FORWARD\]\s+\d+\s+127\.0\.0\.1\s+(\d+)\s+`) + // qmpQueryHostFwd connects to a QEMU QMP socket and queries the host port // assigned to the first TCP host forward rule (the SSH debug port). func qmpQueryHostFwd(sockPath string) (int, error) { @@ -203,7 +391,7 @@ func qmpQueryHostFwd(sockPath string) (int, error) { return 0, fmt.Errorf("QMP socket %s not available", sockPath) } defer conn.Close() - conn.SetDeadline(time.Now().Add(5 * time.Second)) + conn.SetDeadline(time.Now().Add(20 * time.Second)) // Read the QMP greeting. var greeting json.RawMessage @@ -219,21 +407,94 @@ func qmpQueryHostFwd(sockPath string) (int, error) { return 0, fmt.Errorf("reading qmp_capabilities response: %w", err) } - // Query "info usernet" via human-monitor-command. - fmt.Fprintf(conn, `{"execute":"human-monitor-command","arguments":{"command-line":"info usernet"}}`+"\n") - var hmpResp struct { - Return string `json:"return"` + // Poll "info usernet" until the SLIRP host-forward rule appears. + // On slow runners (e.g. GitHub Actions) QEMU sometimes returns an + // empty "info usernet" if we query it before user-mode networking + // has finished wiring up the forward, so single-shot lookups fail. + deadline := time.Now().Add(10 * time.Second) + var lastReturn string + for { + fmt.Fprintf(conn, `{"execute":"human-monitor-command","arguments":{"command-line":"info usernet"}}`+"\n") + var hmpResp struct { + Return string `json:"return"` + } + if err := dec.Decode(&hmpResp); err != nil { + return 0, fmt.Errorf("reading info usernet response: %w", err) + } + lastReturn = hmpResp.Return + if m := hostFwdRe.FindStringSubmatch(hmpResp.Return); m != nil { + return strconv.Atoi(m[1]) + } + if time.Now().After(deadline) { + break + } + time.Sleep(100 * time.Millisecond) + } + return 0, fmt.Errorf("no hostfwd port found after waiting: %q", lastReturn) +} + +// tailLogFile tails a VM's serial console log file and publishes each line +// as an EventConsoleOutput to the event bus for the web UI. +func (e *Env) tailLogFile(ctx context.Context, name, logPath string) { + // Wait for the file to appear (QEMU may not have created it yet). + var f *os.File + for { + var err error + f, err = os.Open(logPath) + if err == nil { + break + } + select { + case <-ctx.Done(): + return + case <-time.After(100 * time.Millisecond): + } + } + defer f.Close() + + // Read the file in a loop, tracking our position manually. + // We can't use bufio.Scanner because it caches EOF and won't + // pick up new data appended by QEMU after the first EOF. + var buf []byte + var partial string // incomplete line (no trailing newline yet) + readBuf := make([]byte, 4096) + for { + n, err := f.Read(readBuf) + if n > 0 { + buf = append(buf, readBuf[:n]...) + // Split into complete lines. + for { + idx := bytes.IndexByte(buf, '\n') + if idx < 0 { + break + } + line := partial + string(buf[:idx]) + partial = "" + buf = buf[idx+1:] + // Strip trailing \r from serial consoles. + line = strings.TrimRight(line, "\r") + if line == "" { + continue + } + e.appendConsoleLine(name, line) + e.eventBus.Publish(VMEvent{ + NodeName: name, + Type: EventConsoleOutput, + Message: line, + }) + } + if len(buf) > 0 { + partial = string(buf) + buf = buf[:0] + } + } + if err != nil || n == 0 { + // EOF or error — wait for more data. + select { + case <-ctx.Done(): + return + case <-time.After(100 * time.Millisecond): + } + } } - if err := dec.Decode(&hmpResp); err != nil { - return 0, fmt.Errorf("reading info usernet response: %w", err) - } - - // Parse the port from output like: - // TCP[HOST_FORWARD] 12 127.0.0.1 35323 10.0.2.15 22 - re := regexp.MustCompile(`TCP\[HOST_FORWARD\]\s+\d+\s+127\.0\.0\.1\s+(\d+)\s+`) - m := re.FindStringSubmatch(hmpResp.Return) - if m == nil { - return 0, fmt.Errorf("no hostfwd port found in: %s", hmpResp.Return) - } - return strconv.Atoi(m[1]) } diff --git a/tstest/natlab/vmtest/qemu_wrapper.go b/tstest/natlab/vmtest/qemu_wrapper.go new file mode 100644 index 000000000..5b5843bed --- /dev/null +++ b/tstest/natlab/vmtest/qemu_wrapper.go @@ -0,0 +1,90 @@ +// Copyright (c) Tailscale Inc & contributors +// SPDX-License-Identifier: BSD-3-Clause + +//go:build unix + +package vmtest + +import ( + "fmt" + "io" + "log" + "os" + "os/exec" + "strconv" + "syscall" +) + +// Re-exec'd as a wrapper around QEMU: when the test process dies (any +// reason, including SIGKILL), the kernel closes the pipe write end, the +// wrapper sees EOF, and kills QEMU's process group. + +const wrapperEnv = "TS_VMTEST_QEMU_WRAPPER" + +func init() { + if os.Getenv(wrapperEnv) == "" { + return + } + runQEMUWrapper() +} + +func runQEMUWrapper() { + fd, err := strconv.Atoi(os.Getenv(wrapperEnv)) + if err != nil { + log.Fatalf("vmtest qemu wrapper: bad %s: %v", wrapperEnv, err) + } + os.Unsetenv(wrapperEnv) + if len(os.Args) < 2 { + log.Fatalf("vmtest qemu wrapper: missing command") + } + pipeFd := os.NewFile(uintptr(fd), "parent-pipe") + + // QEMU inherits our pgid (the test set Setpgid on us), so a group kill + // from the test reaches QEMU too. Don't set Setpgid here. + cmd := exec.Command(os.Args[1], os.Args[2:]...) + cmd.Stdin, cmd.Stdout, cmd.Stderr = os.Stdin, os.Stdout, os.Stderr + if err := cmd.Start(); err != nil { + log.Fatalf("vmtest qemu wrapper: %v", err) + } + + go func() { + // Block until the parent's pipe write end closes (EOF), then kill + // our process group (which includes QEMU and any of its children). + io.Copy(io.Discard, pipeFd) + syscall.Kill(0, syscall.SIGKILL) + }() + + cmd.Wait() +} + +// killWithParent rewrites cmd to run via a wrapper that kills it if the +// test process dies. The returned *os.File must be kept alive until the +// command is no longer needed; closing it makes the wrapper exit. +func killWithParent(cmd *exec.Cmd) (*os.File, error) { + self, err := os.Executable() + if err != nil { + return nil, fmt.Errorf("os.Executable: %w", err) + } + r, w, err := os.Pipe() + if err != nil { + return nil, fmt.Errorf("pipe: %w", err) + } + + cmd.ExtraFiles = append(cmd.ExtraFiles, r) + pipeFd := 3 + len(cmd.ExtraFiles) - 1 // stdin/stdout/stderr + ExtraFiles index + cmd.Args = append([]string{self, cmd.Path}, cmd.Args[1:]...) + cmd.Path = self + if cmd.Env == nil { + cmd.Env = os.Environ() + } + cmd.Env = append(cmd.Env, fmt.Sprintf("%s=%d", wrapperEnv, pipeFd)) + cmd.SysProcAttr = &syscall.SysProcAttr{Setpgid: true} + + return w, nil +} + +// killProcessTree SIGKILLs cmd's process group (cmd plus any descendants +// that didn't escape it). +func killProcessTree(cmd *exec.Cmd) error { + return syscall.Kill(-cmd.Process.Pid, syscall.SIGKILL) +} diff --git a/tstest/natlab/vmtest/qemu_wrapper_windows.go b/tstest/natlab/vmtest/qemu_wrapper_windows.go new file mode 100644 index 000000000..59d96fad9 --- /dev/null +++ b/tstest/natlab/vmtest/qemu_wrapper_windows.go @@ -0,0 +1,20 @@ +// Copyright (c) Tailscale Inc & contributors +// SPDX-License-Identifier: BSD-3-Clause + +package vmtest + +import ( + "os" + "os/exec" +) + +// Stubs for Windows: no parent-death watcher, no process-group kill. +// The test still launches QEMU; cleanup just kills the single process. + +func killWithParent(cmd *exec.Cmd) (*os.File, error) { + return os.Open(os.DevNull) +} + +func killProcessTree(cmd *exec.Cmd) error { + return cmd.Process.Kill() +} diff --git a/tstest/natlab/vmtest/tailmac.go b/tstest/natlab/vmtest/tailmac.go new file mode 100644 index 000000000..167feeb04 --- /dev/null +++ b/tstest/natlab/vmtest/tailmac.go @@ -0,0 +1,736 @@ +// Copyright (c) Tailscale Inc & contributors +// SPDX-License-Identifier: BSD-3-Clause + +package vmtest + +import ( + "bufio" + "bytes" + "context" + "encoding/base64" + "encoding/json" + "fmt" + "io" + "net" + "net/http" + "net/netip" + "os" + "os/exec" + "path/filepath" + "strings" + "testing" + "time" + + "golang.org/x/crypto/ssh" +) + +// macPlatform boots macOS VMs via Tart base images and tailmac Host.app. +type macPlatform struct{} + +func (macPlatform) planSteps(e *Env, n *Node) { + e.Step("Prepare macOS Tart image") + e.Step("Launch macOS VM: " + n.name) +} + +func (macPlatform) boot(ctx context.Context, e *Env, n *Node) error { + imgStep := e.Step("Prepare macOS Tart image") + e.macosSnapshotOnce.Do(func() { + imgStep.Begin() + e.macosSnapshot = ensureSnapshot(e.t) + imgStep.End(nil) + }) + + e.ensureDgramSocket() + + vmStep := e.Step("Launch macOS VM: " + n.name) + vmStep.Begin() + if err := e.startTailMacVM(n); err != nil { + vmStep.End(err) + return err + } + vmStep.End(nil) + return nil +} + +const tartImage = "ghcr.io/cirruslabs/macos-tahoe-base:latest" + +// macOSSnapshotCodeVersion is bumped when the snapshot preparation logic +// changes in a way that invalidates old snapshots. Old snapshots with a +// different version are cleaned up automatically. +const macOSSnapshotCodeVersion = 5 + +// tartConfig is the subset of Tart's config.json we need. +type tartConfig struct { + HardwareModel string `json:"hardwareModel"` // base64 + ECID string `json:"ecid"` // base64 +} + +// tartManifest is the subset of Tart's OCI manifest.json we need. +type tartManifest struct { + Config struct { + Digest string `json:"digest"` // e.g. "sha256:3a6cb4eb6201..." + } `json:"config"` +} + +// ensureTartImage checks that the Tart base image is available, pulling it +// if necessary. Returns the path to the OCI cache directory containing +// disk.img, nvram.bin, config.json, and manifest.json. +func ensureTartImage(t testing.TB) string { + if _, err := exec.LookPath("tart"); err != nil { + t.Skip("tart not installed; skipping macOS VM test") + } + + home, err := os.UserHomeDir() + if err != nil { + t.Fatalf("UserHomeDir: %v", err) + } + + ociDir := filepath.Join(home, ".tart", "cache", "OCIs", + "ghcr.io", "cirruslabs", "macos-tahoe-base", "latest") + if _, err := os.Stat(filepath.Join(ociDir, "disk.img")); err == nil { + return ociDir + } + + t.Logf("pulling Tart image %s ...", tartImage) + cmd := exec.Command("tart", "pull", tartImage) + cmd.Stdout = os.Stderr + cmd.Stderr = os.Stderr + if err := cmd.Run(); err != nil { + t.Fatalf("tart pull: %v", err) + } + + if _, err := os.Stat(filepath.Join(ociDir, "disk.img")); err == nil { + return ociDir + } + t.Fatalf("tart pull succeeded but image not found at %s", ociDir) + return "" +} + +// snapshotCacheKey computes a cache key for the macOS VM snapshot. +// The key combines the image name, the first 12 hex chars of the Tart +// config digest (changes when the upstream image is updated), and the +// snapshot code version (changes when our prep logic changes). +func snapshotCacheKey(tartDir string) (string, error) { + manifestPath := filepath.Join(tartDir, "manifest.json") + data, err := os.ReadFile(manifestPath) + if err != nil { + return "", fmt.Errorf("reading manifest: %w", err) + } + var m tartManifest + if err := json.Unmarshal(data, &m); err != nil { + return "", fmt.Errorf("parsing manifest: %w", err) + } + digest := m.Config.Digest + // Strip "sha256:" prefix and take first 12 hex chars. + digest = strings.TrimPrefix(digest, "sha256:") + if len(digest) > 12 { + digest = digest[:12] + } + return fmt.Sprintf("snap-tahoe-%s-v%d", digest, macOSSnapshotCodeVersion), nil +} + +// macosVMBaseDir returns ~/.cache/tailscale/vmtest/macos/, the directory +// where Host.app expects to find VM directories by ID. +func macosVMBaseDir() (string, error) { + home, err := os.UserHomeDir() + if err != nil { + return "", err + } + return filepath.Join(home, ".cache", "tailscale", "vmtest", "macos"), nil +} + +// cleanOldSnapshots removes any snapshot directories for the given image +// prefix (e.g. "snap-tahoe") that don't match the current cache key. +func cleanOldSnapshots(t testing.TB, imagePrefix, currentKey string) { + base, err := macosVMBaseDir() + if err != nil { + return + } + matches, _ := filepath.Glob(filepath.Join(base, imagePrefix+"-*")) + currentPath := filepath.Join(base, currentKey) + for _, m := range matches { + if m != currentPath { + t.Logf("removing stale snapshot: %s", filepath.Base(m)) + os.RemoveAll(m) + } + } +} + +// ensureSnapshot returns the path to a cached macOS VM snapshot, creating +// one if necessary. The snapshot contains a fully booted VM with +// SaveFile.vzvmsave ready for fast restore. +func ensureSnapshot(t testing.TB) string { + tartDir := ensureTartImage(t) + + key, err := snapshotCacheKey(tartDir) + if err != nil { + t.Fatalf("snapshot cache key: %v", err) + } + + base, err := macosVMBaseDir() + if err != nil { + t.Fatalf("macOS VM base dir: %v", err) + } + os.MkdirAll(base, 0755) + + snapDir := filepath.Join(base, key) + saveFile := filepath.Join(snapDir, "SaveFile.vzvmsave") + if _, err := os.Stat(saveFile); err == nil { + t.Logf("using cached macOS snapshot: %s", key) + return snapDir + } + + // Clean up old snapshots for this image. + cleanOldSnapshots(t, "snap-tahoe", key) + + t.Logf("preparing macOS snapshot: %s (this takes ~30s on first run)", key) + if err := prepareSnapshot(t, tartDir, snapDir); err != nil { + os.RemoveAll(snapDir) + t.Fatalf("preparing snapshot: %v", err) + } + return snapDir +} + +// prepareSnapshot creates a new macOS VM snapshot by booting the Tart base +// image with a NAT NIC, waiting for SSH, and saving VM state. +func prepareSnapshot(t testing.TB, tartDir, snapDir string) error { + // The vmID must match the directory name under macosVMBaseDir + // because Host.app looks up VM files at //. + snapID := filepath.Base(snapDir) + + if err := cloneTartToTailmac(tartDir, snapDir, snapID, "52:cc:cc:cc:ce:01", "/dev/null"); err != nil { + return fmt.Errorf("cloning tart: %w", err) + } + + modRoot, err := findModRoot() + if err != nil { + return err + } + tailmacDir := filepath.Join(modRoot, "tstest", "tailmac", "bin") + hostBin := filepath.Join(tailmacDir, "Host.app", "Contents", "MacOS", "Host") + if _, err := os.Stat(hostBin); err != nil { + return fmt.Errorf("Host.app not found at %s; run 'make all' in tstest/tailmac/", hostBin) + } + + // Host.app reads VM files from ~/.cache/tailscale/vmtest/macos//. + // Our snapDir is already under that tree, and the config.json vmID matches. + cmd := exec.Command(hostBin, "run", "--id", snapID, "--headless", "--nat-nic") + cmd.Env = append(os.Environ(), "NSUnbufferedIO=YES") + + logPath := snapDir + ".prep.log" + logFile, err := os.Create(logPath) + if err != nil { + return err + } + defer logFile.Close() + cmd.Stdout = logFile + cmd.Stderr = logFile + devNull, _ := os.Open(os.DevNull) + cmd.Stdin = devNull + defer devNull.Close() + + if err := cmd.Start(); err != nil { + return fmt.Errorf("starting Host.app: %w", err) + } + t.Logf("snapshot prep: launched Host.app (pid %d)", cmd.Process.Pid) + + // Wait for SSH to become available via the NAT NIC. + // The VM gets an IP from macOS's vmnet DHCP (typically 192.168.64.x). + ip, err := waitForVMIP(t, "52:cc:cc:cc:ce:01", 60*time.Second) + if err != nil { + cmd.Process.Kill() + cmd.Wait() + return fmt.Errorf("waiting for VM IP: %w", err) + } + t.Logf("snapshot prep: VM IP is %s, waiting for SSH...", ip) + + sc, err := waitForSSH(ip, 60*time.Second) + if err != nil { + cmd.Process.Kill() + cmd.Wait() + return fmt.Errorf("waiting for SSH: %w", err) + } + t.Logf("snapshot prep: SSH connected") + + // Compile and install TTA in the macOS VM. + t.Logf("snapshot prep: installing TTA...") + if err := installTTA(t, sc); err != nil { + sc.Close() + cmd.Process.Kill() + cmd.Wait() + return fmt.Errorf("installing TTA: %w", err) + } + sc.Close() + + // Save VM state by sending SIGINT. + t.Logf("snapshot prep: saving VM state...") + cmd.Process.Signal(os.Interrupt) + done := make(chan error, 1) + go func() { done <- cmd.Wait() }() + select { + case err := <-done: + if err != nil { + // Host.app exits 0 after saving state, non-zero is unexpected. + t.Logf("snapshot prep: Host.app exited with: %v", err) + } + case <-time.After(60 * time.Second): + cmd.Process.Kill() + <-done + return fmt.Errorf("Host.app did not exit after SIGINT") + } + + // Verify the save file was created. + saveFile := filepath.Join(snapDir, "SaveFile.vzvmsave") + if _, err := os.Stat(saveFile); err != nil { + return fmt.Errorf("SaveFile.vzvmsave not found after prep") + } + t.Logf("snapshot prep: done, saved to %s", filepath.Base(snapDir)) + os.Remove(logPath) + return nil +} + +// installTTA compiles TTA for darwin/arm64 and installs it in the macOS VM +// as a LaunchDaemon via SSH/SCP. +func installTTA(t testing.TB, sc *ssh.Client) error { + modRoot, err := findModRoot() + if err != nil { + return err + } + + // Compile TTA for the macOS VM. + tmpDir := t.TempDir() + ttaBin := filepath.Join(tmpDir, "tta") + t.Logf("snapshot prep: compiling TTA for darwin/arm64...") + buildCmd := exec.Command("go", "build", "-o", ttaBin, "./cmd/tta") + buildCmd.Dir = modRoot + buildCmd.Env = append(os.Environ(), "GOOS=darwin", "GOARCH=arm64", "CGO_ENABLED=0") + if out, err := buildCmd.CombinedOutput(); err != nil { + return fmt.Errorf("compiling TTA: %v\n%s", err, out) + } + + // Read the binary. + ttaData, err := os.ReadFile(ttaBin) + if err != nil { + return fmt.Errorf("reading TTA binary: %w", err) + } + t.Logf("snapshot prep: TTA binary is %d bytes", len(ttaData)) + + // SCP the TTA binary to the VM via a temp file (admin user can't write /usr/local/bin directly). + if err := scpFile(sc, ttaData, "/tmp/tta", 0755); err != nil { + return fmt.Errorf("uploading TTA: %w", err) + } + if err := runSSHCmd(sc, "echo admin | sudo -S mv /tmp/tta /usr/local/bin/tta"); err != nil { + return fmt.Errorf("moving TTA to /usr/local/bin: %w", err) + } + + // Install the LaunchDaemon plist. + plist := ` + + + + Label + com.tailscale.tta + ProgramArguments + + /usr/local/bin/tta + + RunAtLoad + + KeepAlive + + StandardOutPath + /tmp/tta.log + StandardErrorPath + /tmp/tta.log + + +` + if err := scpFile(sc, []byte(plist), "/tmp/com.tailscale.tta.plist", 0644); err != nil { + return fmt.Errorf("uploading plist: %w", err) + } + if err := runSSHCmd(sc, "echo admin | sudo -S mv /tmp/com.tailscale.tta.plist /Library/LaunchDaemons/ && echo admin | sudo -S chown root:wheel /Library/LaunchDaemons/com.tailscale.tta.plist"); err != nil { + return fmt.Errorf("installing plist: %w", err) + } + + // Load the LaunchDaemon. + if err := runSSHCmd(sc, "echo admin | sudo -S launchctl load /Library/LaunchDaemons/com.tailscale.tta.plist"); err != nil { + return fmt.Errorf("loading LaunchDaemon: %w", err) + } + + // Wait for TTA to start. + for range 20 { + if err := runSSHCmd(sc, "pgrep -x tta"); err == nil { + break + } + time.Sleep(250 * time.Millisecond) + } + if err := runSSHCmd(sc, "pgrep -x tta"); err != nil { + return fmt.Errorf("TTA not running after install: %w", err) + } + t.Logf("snapshot prep: TTA installed and running") + return nil +} + +// scpFile uploads data to a remote path via SSH/SCP. +func scpFile(sc *ssh.Client, data []byte, remotePath string, mode os.FileMode) error { + sess, err := sc.NewSession() + if err != nil { + return err + } + defer sess.Close() + + // Use a simple shell command to write the file. + cmd := fmt.Sprintf("cat > %s && chmod %o %s", remotePath, mode, remotePath) + sess.Stdin = bytes.NewReader(data) + out, err := sess.CombinedOutput(cmd) + if err != nil { + return fmt.Errorf("%s: %v: %s", cmd, err, out) + } + return nil +} + +// runSSHCmd runs a command on the SSH client and returns an error if it fails. +func runSSHCmd(sc *ssh.Client, cmd string) error { + sess, err := sc.NewSession() + if err != nil { + return err + } + defer sess.Close() + out, err := sess.CombinedOutput(cmd) + if err != nil { + return fmt.Errorf("%s: %v: %s", cmd, err, out) + } + return nil +} + +// waitForVMIP polls /var/db/dhcpd_leases for a DHCP lease matching the +// given MAC address (from macOS's vmnet NAT). Returns the IP. +func waitForVMIP(t testing.TB, mac string, timeout time.Duration) (string, error) { + // Normalize MAC format: vmnet leases use "1,xx:xx:xx:xx:xx:xx" format + // with leading zeros stripped from each octet (e.g. "1,52:cc:cc:cc:ce:1" + // instead of "1,52:cc:cc:cc:ce:01"). + mac = strings.ToLower(mac) + parts := strings.Split(mac, ":") + for i, p := range parts { + parts[i] = strings.TrimLeft(p, "0") + if parts[i] == "" { + parts[i] = "0" + } + } + leaseMAC := "1," + strings.Join(parts, ":") + + deadline := time.Now().Add(timeout) + for time.Now().Before(deadline) { + data, err := os.ReadFile("/var/db/dhcpd_leases") + if err == nil { + // Parse the plist-like lease file. + lines := strings.Split(string(data), "\n") + var currentIP string + for _, line := range lines { + line = strings.TrimSpace(line) + if strings.HasPrefix(line, "ip_address=") { + currentIP = strings.TrimPrefix(line, "ip_address=") + } + if strings.HasPrefix(line, "hw_address=") { + hw := strings.TrimPrefix(line, "hw_address=") + if strings.ToLower(hw) == leaseMAC && currentIP != "" { + return currentIP, nil + } + } + if line == "}" { + currentIP = "" + } + } + } + time.Sleep(time.Second) + } + return "", fmt.Errorf("no DHCP lease for MAC %s after %v", mac, timeout) +} + +// waitForSSH retries SSH connection to the given IP until it succeeds or +// the timeout expires. +func waitForSSH(ip string, timeout time.Duration) (*ssh.Client, error) { + deadline := time.Now().Add(timeout) + addr := net.JoinHostPort(ip, "22") + cfg := &ssh.ClientConfig{ + User: "admin", + Auth: []ssh.AuthMethod{ssh.Password("admin")}, + HostKeyCallback: ssh.InsecureIgnoreHostKey(), + Timeout: 2 * time.Second, + } + for time.Now().Before(deadline) { + sc, err := ssh.Dial("tcp", addr, cfg) + if err == nil { + return sc, nil + } + time.Sleep(time.Second) + } + return nil, fmt.Errorf("SSH to %s timed out after %v", addr, timeout) +} + +// ensureTailMac locates the pre-built tailmac Host.app binary. +func (e *Env) ensureTailMac() error { + modRoot, err := findModRoot() + if err != nil { + return err + } + e.tailmacDir = filepath.Join(modRoot, "tstest", "tailmac", "bin") + hostApp := filepath.Join(e.tailmacDir, "Host.app", "Contents", "MacOS", "Host") + if _, err := os.Stat(hostApp); err != nil { + return fmt.Errorf("tailmac Host.app not found at %s; run 'make all' in tstest/tailmac/", hostApp) + } + return nil +} + +// cloneTartToTailmac creates a tailmac-compatible VM directory from a Tart +// base image. It uses APFS CoW clones for the disk and NVRAM, and extracts +// the hardware identity from Tart's config.json. +func cloneTartToTailmac(tartDir, cloneDir, testID, mac, dgramSock string) error { + if err := os.MkdirAll(cloneDir, 0755); err != nil { + return err + } + + cfgData, err := os.ReadFile(filepath.Join(tartDir, "config.json")) + if err != nil { + return fmt.Errorf("reading tart config: %w", err) + } + var tc tartConfig + if err := json.Unmarshal(cfgData, &tc); err != nil { + return fmt.Errorf("parsing tart config: %w", err) + } + + hwModel, err := base64.StdEncoding.DecodeString(tc.HardwareModel) + if err != nil { + return fmt.Errorf("decoding hardwareModel: %w", err) + } + if err := os.WriteFile(filepath.Join(cloneDir, "HardwareModel"), hwModel, 0644); err != nil { + return err + } + + ecid, err := base64.StdEncoding.DecodeString(tc.ECID) + if err != nil { + return fmt.Errorf("decoding ecid: %w", err) + } + if err := os.WriteFile(filepath.Join(cloneDir, "MachineIdentifier"), ecid, 0644); err != nil { + return err + } + + if out, err := exec.Command("cp", "-c", filepath.Join(tartDir, "disk.img"), filepath.Join(cloneDir, "Disk.img")).CombinedOutput(); err != nil { + if out2, err2 := exec.Command("cp", filepath.Join(tartDir, "disk.img"), filepath.Join(cloneDir, "Disk.img")).CombinedOutput(); err2 != nil { + return fmt.Errorf("copying disk: %v: %s (APFS clone: %v: %s)", err2, out2, err, out) + } + } + + if out, err := exec.Command("cp", "-c", filepath.Join(tartDir, "nvram.bin"), filepath.Join(cloneDir, "AuxiliaryStorage")).CombinedOutput(); err != nil { + if out2, err2 := exec.Command("cp", filepath.Join(tartDir, "nvram.bin"), filepath.Join(cloneDir, "AuxiliaryStorage")).CombinedOutput(); err2 != nil { + return fmt.Errorf("copying nvram: %v: %s (APFS clone: %v: %s)", err2, out2, err, out) + } + } + + tmCfg := struct { + VMid string `json:"vmID"` + ServerSocket string `json:"serverSocket"` + MemorySize uint64 `json:"memorySize"` + Mac string `json:"mac"` + }{ + VMid: testID, + ServerSocket: dgramSock, + MemorySize: 4 * 1024 * 1024 * 1024, + Mac: mac, + } + tmData, _ := json.MarshalIndent(tmCfg, "", " ") + return os.WriteFile(filepath.Join(cloneDir, "config.json"), tmData, 0644) +} + +// startTailMacVM restores a macOS VM from a cached snapshot and launches it +// via tailmac Host.app in headless mode, connected to vnet's dgram socket. +func (e *Env) startTailMacVM(n *Node) error { + snapDir := e.macosSnapshot + + if err := e.ensureTailMac(); err != nil { + return err + } + + testID := fmt.Sprintf("vmtest-%s-%d", n.name, os.Getpid()) + + // Host.app expects VM files under ~/.cache/tailscale/vmtest/macos// + home, err := os.UserHomeDir() + if err != nil { + return fmt.Errorf("UserHomeDir: %w", err) + } + vmBase := filepath.Join(home, ".cache", "tailscale", "vmtest", "macos") + os.MkdirAll(vmBase, 0755) + cloneDir := filepath.Join(vmBase, testID) + + // APFS clone the entire snapshot directory (includes SaveFile.vzvmsave). + e.t.Logf("[%s] cloning snapshot -> %s", n.name, testID) + if out, err := exec.Command("cp", "-c", "-r", snapDir, cloneDir).CombinedOutput(); err != nil { + if out2, err2 := exec.Command("cp", "-r", snapDir, cloneDir).CombinedOutput(); err2 != nil { + return fmt.Errorf("cloning snapshot: %v: %s (APFS clone: %v: %s)", err2, out2, err, out) + } + } + e.t.Cleanup(func() { os.RemoveAll(cloneDir) }) + + // Write test-specific config.json with the vnet MAC and dgram socket. + mac := n.vnetNode.NICMac(0) + cfg := struct { + VMid string `json:"vmID"` + ServerSocket string `json:"serverSocket"` + MemorySize uint64 `json:"memorySize"` + Mac string `json:"mac"` + }{ + VMid: testID, + ServerSocket: e.dgramSockAddr, + MemorySize: 8 * 1024 * 1024 * 1024, + Mac: mac.String(), + } + cfgData, _ := json.MarshalIndent(cfg, "", " ") + if err := os.WriteFile(filepath.Join(cloneDir, "config.json"), cfgData, 0644); err != nil { + return fmt.Errorf("writing config.json: %w", err) + } + + // Launch Host.app with disconnected NIC + hot-swap to vnet. + // Host.app will restore from SaveFile.vzvmsave (fast), then + // hot-swap the NIC to the vnet dgram socket. + hostBin := filepath.Join(e.tailmacDir, "Host.app", "Contents", "MacOS", "Host") + + // Compute the node's IP and gateway for static assignment via vsock. + nodeIP := n.vnetNode.LanIP(n.nets[0]) + // The gateway is the network's base address (e.g. 192.168.1.1 for /24). + // We derive it from the node IP: same /24 prefix, host part = 1. + gwIP := nodeIP.As4() + gwIP[3] = 1 + gateway := netip.AddrFrom4(gwIP) + + args := []string{ + "run", "--id", testID, "--headless", + "--disconnected-nic", + "--attach-network", e.dgramSockAddr, + "--assign-ip", fmt.Sprintf("%s/255.255.255.0/%s", nodeIP, gateway), + } + + wantScreenshots := *vmtestWeb != "" + if wantScreenshots { + args = append(args, "--screenshot-port", "0") + } + + logPath := filepath.Join(e.tempDir, n.name+"-tailmac.log") + logFile, err := os.Create(logPath) + if err != nil { + return fmt.Errorf("creating log file: %w", err) + } + + cmd := exec.Command(hostBin, args...) + cmd.Env = append(os.Environ(), "NSUnbufferedIO=YES") + + var stdoutPipe io.ReadCloser + if wantScreenshots { + stdoutPipe, err = cmd.StdoutPipe() + if err != nil { + logFile.Close() + return fmt.Errorf("stdout pipe: %w", err) + } + cmd.Stderr = logFile + } else { + cmd.Stdout = logFile + cmd.Stderr = logFile + } + devNull, err := os.Open(os.DevNull) + if err != nil { + logFile.Close() + return fmt.Errorf("open /dev/null: %w", err) + } + cmd.Stdin = devNull + + if err := cmd.Start(); err != nil { + devNull.Close() + logFile.Close() + return fmt.Errorf("starting tailmac for %s: %w", n.name, err) + } + e.t.Logf("[%s] launched tailmac (pid %d), log: %s", n.name, cmd.Process.Pid, logPath) + + if wantScreenshots { + screenshotPortCh := make(chan int, 1) + go func() { + scanner := bufio.NewScanner(stdoutPipe) + for scanner.Scan() { + line := scanner.Text() + fmt.Fprintln(logFile, line) + if port := 0; strings.HasPrefix(line, "SCREENSHOT_PORT=") { + fmt.Sscanf(line, "SCREENSHOT_PORT=%d", &port) + if port > 0 { + screenshotPortCh <- port + } + } + } + }() + go func() { + ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second) + defer cancel() + select { + case port := <-screenshotPortCh: + e.t.Logf("[%s] screenshot server on port %d", n.name, port) + e.setNodeScreenshotPort(n.name, port) + e.tailScreenshots(n.name, port) + case <-ctx.Done(): + e.t.Logf("[%s] screenshot port not received", n.name) + } + }() + } + + clientSock := fmt.Sprintf("/tmp/qemu-dgram-%s.sock", testID) + + e.t.Cleanup(func() { + // Kill immediately — no need to save state for ephemeral test clones. + cmd.Process.Kill() + cmd.Wait() + devNull.Close() + logFile.Close() + os.Remove(clientSock) + + if e.t.Failed() { + if data, err := os.ReadFile(logPath); err == nil { + lines := strings.Split(string(data), "\n") + start := 0 + if len(lines) > 50 { + start = len(lines) - 50 + } + e.t.Logf("=== last 50 lines of %s tailmac log ===", n.name) + for _, line := range lines[start:] { + e.t.Logf("[%s] %s", n.name, line) + } + } + } + }) + + return nil +} + +// tailScreenshots polls the Host.app screenshot HTTP server every 2 seconds +// and publishes each screenshot as a base64 data URI to the web UI. +func (e *Env) tailScreenshots(name string, port int) { + url := fmt.Sprintf("http://127.0.0.1:%d/screenshot", port) + client := &http.Client{Timeout: 5 * time.Second} + ticker := time.NewTicker(2 * time.Second) + defer ticker.Stop() + + for range ticker.C { + resp, err := client.Get(url) + if err != nil { + continue + } + data, _ := io.ReadAll(resp.Body) + resp.Body.Close() + if resp.StatusCode != 200 || len(data) == 0 { + continue + } + b64 := base64.StdEncoding.EncodeToString(data) + dataURI := "data:image/jpeg;base64," + b64 + e.setNodeScreenshot(name, dataURI) + e.eventBus.Publish(VMEvent{ + NodeName: name, + Type: EventScreenshot, + Message: b64, + }) + } +} diff --git a/tstest/natlab/vmtest/version.go b/tstest/natlab/vmtest/version.go new file mode 100644 index 000000000..7e76716e4 --- /dev/null +++ b/tstest/natlab/vmtest/version.go @@ -0,0 +1,195 @@ +// Copyright (c) Tailscale Inc & contributors +// SPDX-License-Identifier: BSD-3-Clause + +package vmtest + +import ( + "archive/tar" + "compress/gzip" + "context" + "encoding/json" + "fmt" + "io" + "net/http" + "os" + "path" + "path/filepath" + "regexp" + "strconv" + "strings" + + "tailscale.com/types/logger" +) + +// versionRE matches a concrete X.Y.Z release version. +var versionRE = regexp.MustCompile(`^\d+\.\d+\.\d+$`) + +// resolveTestVersion returns the concrete release version (e.g. "1.97.255") +// for the given --test-version flag value. If v is "unstable" or "stable", it +// queries pkgs.tailscale.com for the latest TarballsVersion on that track. +// Otherwise it returns v unchanged. +func resolveTestVersion(ctx context.Context, v string) (string, error) { + if v != "unstable" && v != "stable" { + if !versionRE.MatchString(v) { + return "", fmt.Errorf("invalid --test-version %q: want \"stable\", \"unstable\", or X.Y.Z", v) + } + return v, nil + } + url := "https://pkgs.tailscale.com/" + v + "/?mode=json" + req, err := http.NewRequestWithContext(ctx, "GET", url, nil) + if err != nil { + return "", err + } + resp, err := http.DefaultClient.Do(req) + if err != nil { + return "", fmt.Errorf("fetching %s: %w", url, err) + } + defer resp.Body.Close() + if resp.StatusCode != 200 { + return "", fmt.Errorf("fetching %s: HTTP %s", url, resp.Status) + } + var meta struct { + TarballsVersion string + } + if err := json.NewDecoder(resp.Body).Decode(&meta); err != nil { + return "", fmt.Errorf("decoding %s: %w", url, err) + } + if meta.TarballsVersion == "" { + return "", fmt.Errorf("no TarballsVersion in %s response", url) + } + return meta.TarballsVersion, nil +} + +// versionTrack returns the pkgs.tailscale.com track ("stable" or "unstable") +// for a release version. Even minors are stable; odd minors are unstable. +func versionTrack(version string) (string, error) { + parts := strings.Split(version, ".") + if len(parts) < 2 { + return "", fmt.Errorf("bad version %q (expected like 1.97.255)", version) + } + minor, err := strconv.Atoi(parts[1]) + if err != nil { + return "", fmt.Errorf("bad minor in version %q: %w", version, err) + } + if minor%2 == 0 { + return "stable", nil + } + return "unstable", nil +} + +// versionCacheRoot returns the root cache directory for downloaded version +// tarballs. +func versionCacheRoot() string { + if d := os.Getenv("VMTEST_BUILDS_CACHE_DIR"); d != "" { + return d + } + cache, err := os.UserCacheDir() + if err != nil { + panic(fmt.Sprintf("os.UserCacheDir: %v", err)) + } + return filepath.Join(cache, "tailscale-vmtest", "builds") +} + +// versionCacheDir returns the directory holding the extracted binaries for +// the given version+arch. +func versionCacheDir(version, arch string) string { + return filepath.Join(versionCacheRoot(), fmt.Sprintf("%s_%s", version, arch)) +} + +// ensureVersionBinaries downloads (if needed) and extracts the tailscale +// release tarball for the given concrete version+arch, returning the +// directory containing tailscale and tailscaled. +func ensureVersionBinaries(ctx context.Context, version, arch string, logf logger.Logf) (string, error) { + dir := versionCacheDir(version, arch) + tailscaled := filepath.Join(dir, "tailscaled") + tailscale := filepath.Join(dir, "tailscale") + if _, err1 := os.Stat(tailscaled); err1 == nil { + if _, err2 := os.Stat(tailscale); err2 == nil { + return dir, nil + } + } + + track, err := versionTrack(version) + if err != nil { + return "", err + } + url := fmt.Sprintf("https://pkgs.tailscale.com/%s/tailscale_%s_%s.tgz", track, version, arch) + logf("downloading %s", url) + + req, err := http.NewRequestWithContext(ctx, "GET", url, nil) + if err != nil { + return "", err + } + resp, err := http.DefaultClient.Do(req) + if err != nil { + return "", fmt.Errorf("fetching %s: %w", url, err) + } + defer resp.Body.Close() + if resp.StatusCode != 200 { + return "", fmt.Errorf("fetching %s: HTTP %s", url, resp.Status) + } + + if err := os.MkdirAll(dir, 0755); err != nil { + return "", err + } + + gzr, err := gzip.NewReader(resp.Body) + if err != nil { + return "", fmt.Errorf("gzip reader for %s: %w", url, err) + } + defer gzr.Close() + tr := tar.NewReader(gzr) + + wantBase := map[string]bool{ + "tailscale": true, + "tailscaled": true, + } + got := map[string]bool{} + for { + hdr, err := tr.Next() + if err == io.EOF { + break + } + if err != nil { + return "", fmt.Errorf("reading tarball %s: %w", url, err) + } + if hdr.Typeflag != tar.TypeReg { + continue + } + base := path.Base(hdr.Name) + if !wantBase[base] { + continue + } + if err := writeAtomic(filepath.Join(dir, base), tr, 0755); err != nil { + return "", fmt.Errorf("extracting %s from %s: %w", base, url, err) + } + got[base] = true + } + for b := range wantBase { + if !got[b] { + return "", fmt.Errorf("tarball %s missing %s", url, b) + } + } + logf("extracted %s and %s to %s", "tailscale", "tailscaled", dir) + return dir, nil +} + +// writeAtomic writes the contents of r to dst with the given permission +// bits, by writing to a sibling temp file and renaming on success. +func writeAtomic(dst string, r io.Reader, perm os.FileMode) error { + tmp := dst + ".tmp" + f, err := os.OpenFile(tmp, os.O_CREATE|os.O_TRUNC|os.O_WRONLY, perm) + if err != nil { + return err + } + if _, err := io.Copy(f, r); err != nil { + f.Close() + os.Remove(tmp) + return err + } + if err := f.Close(); err != nil { + os.Remove(tmp) + return err + } + return os.Rename(tmp, dst) +} diff --git a/tstest/natlab/vmtest/version_test.go b/tstest/natlab/vmtest/version_test.go new file mode 100644 index 000000000..375056290 --- /dev/null +++ b/tstest/natlab/vmtest/version_test.go @@ -0,0 +1,97 @@ +// Copyright (c) Tailscale Inc & contributors +// SPDX-License-Identifier: BSD-3-Clause + +package vmtest + +import ( + "context" + "flag" + "os" + "path/filepath" + "testing" +) + +var testDownloadVersion = flag.Bool("test-download-version", false, "in TestVersionDownload, actually hit pkgs.tailscale.com") + +func TestResolveTestVersionInvalid(t *testing.T) { + bad := []string{ + "", + "1.97", + "v1.97.255", + "1.97.255-pre", + "latest", + "unstabel", + } + for _, v := range bad { + got, err := resolveTestVersion(context.Background(), v) + if err == nil { + t.Errorf("resolveTestVersion(%q) = %q, want error", v, got) + } + } +} + +func TestVersionTrack(t *testing.T) { + cases := []struct { + v, want string + }{ + {"1.96.4", "stable"}, + {"1.97.255", "unstable"}, + {"1.98.0", "stable"}, + } + for _, c := range cases { + got, err := versionTrack(c.v) + if err != nil { + t.Errorf("versionTrack(%q): %v", c.v, err) + continue + } + if got != c.want { + t.Errorf("versionTrack(%q) = %q, want %q", c.v, got, c.want) + } + } +} + +// TestVersionDownload exercises the live network path (download + extract + +// cache). Skipped by default; set --test-download-version to run. +func TestVersionDownload(t *testing.T) { + if !*testDownloadVersion { + t.Skip("set --test-download-version to run") + } + cacheRoot := t.TempDir() + t.Setenv("VMTEST_BUILDS_CACHE_DIR", cacheRoot) + + ctx := context.Background() + const version = "1.96.4" // stable + dir, err := ensureVersionBinaries(ctx, version, "amd64", t.Logf) + if err != nil { + t.Fatal(err) + } + wantDir := filepath.Join(cacheRoot, version+"_amd64") + if dir != wantDir { + t.Errorf("dir = %q, want %q", dir, wantDir) + } + for _, name := range []string{"tailscale", "tailscaled"} { + fi, err := os.Stat(filepath.Join(dir, name)) + if err != nil { + t.Errorf("missing %s: %v", name, err) + continue + } + if fi.Size() < 1<<20 { + t.Errorf("%s suspiciously small: %d bytes", name, fi.Size()) + } + } + + // Re-fetch should be a fast no-op (cache hit). + if _, err := ensureVersionBinaries(ctx, version, "amd64", t.Logf); err != nil { + t.Fatalf("re-fetch: %v", err) + } + + // "unstable" resolution. + resolved, err := resolveTestVersion(ctx, "unstable") + if err != nil { + t.Fatalf("resolveTestVersion(unstable): %v", err) + } + t.Logf("unstable resolved to %q", resolved) + if resolved == "" || resolved == "unstable" { + t.Errorf("resolved = %q", resolved) + } +} diff --git a/tstest/natlab/vmtest/vmstatus.go b/tstest/natlab/vmtest/vmstatus.go new file mode 100644 index 000000000..240c34f42 --- /dev/null +++ b/tstest/natlab/vmtest/vmstatus.go @@ -0,0 +1,345 @@ +// Copyright (c) Tailscale Inc & contributors +// SPDX-License-Identifier: BSD-3-Clause + +package vmtest + +import ( + "fmt" + "sync" + "time" +) + +// StepStatus is the state of a declared test step. +type StepStatus int + +const ( + StepPending StepStatus = iota // not yet started + StepRunning // Begin called + StepDone // End(nil) called + StepFailed // End(non-nil) called +) + +func (s StepStatus) String() string { + switch s { + case StepPending: + return "pending" + case StepRunning: + return "running" + case StepDone: + return "done" + case StepFailed: + return "failed" + } + return fmt.Sprintf("StepStatus(%d)", int(s)) +} + +// Icon returns a Unicode icon for the step status. +func (s StepStatus) Icon() string { + switch s { + case StepPending: + return "○" + case StepRunning: + return "◉" + case StepDone: + return "✓" + case StepFailed: + return "✗" + } + return "?" +} + +// Step is a declared stage of a test, created by [Env.AddStep]. +// The web UI shows all steps from the start, tracking their progress. +type Step struct { + mu sync.Mutex + name string + index int // 0-based position in Env.steps + env *Env + status StepStatus + err error + started time.Time + ended time.Time +} + +// Name returns the step's display name. +func (s *Step) Name() string { return s.name } + +// Index returns the step's 0-based position. +func (s *Step) Index() int { return s.index } + +// Status returns the current status. +func (s *Step) Status() StepStatus { + s.mu.Lock() + defer s.mu.Unlock() + return s.status +} + +// Err returns the error if the step failed, or nil. +func (s *Step) Err() error { + s.mu.Lock() + defer s.mu.Unlock() + return s.err +} + +// Elapsed returns how long the step has been running (if running) +// or how long it took (if done/failed). Returns 0 if pending. +func (s *Step) Elapsed() time.Duration { + s.mu.Lock() + defer s.mu.Unlock() + if s.started.IsZero() { + return 0 + } + if !s.ended.IsZero() { + return s.ended.Sub(s.started) + } + return time.Since(s.started) +} + +// Begin marks the step as running. Publishes an event to the web UI. +func (s *Step) Begin() { + s.mu.Lock() + if s.status != StepPending { + s.mu.Unlock() + panic(fmt.Sprintf("Step %q: Begin called in state %s", s.name, s.status)) + } + s.started = time.Now() + s.status = StepRunning + s.mu.Unlock() + s.env.publishStepChange(s) +} + +// End marks the step as done (err == nil) or failed (err != nil). +// It publishes a status change event to the web UI. +// It does not call t.Fatalf; callers should handle the error as appropriate +// (return it from errgroup, call t.Fatalf on the test goroutine, etc). +func (s *Step) End(err error) { + s.mu.Lock() + if s.status != StepRunning { + s.mu.Unlock() + panic(fmt.Sprintf("Step %q: End called in state %s", s.name, s.status)) + } + s.ended = time.Now() + if err != nil { + s.status = StepFailed + s.err = err + } else { + s.status = StepDone + } + s.mu.Unlock() + s.env.publishStepChange(s) +} + +// Fatalf marks the step as failed (as [Step.End]), and then logs a test +// failure to the environment's test, with an error constructed from the given +// arguments. +func (s *Step) Fatalf(msg string, args ...any) { + s.Fatal(fmt.Errorf(msg, args...)) +} + +// Fatal marks the step as failed (as [Step.End]), and then logs a test failure +// to the environment's test, with the specified (non-nil) error. It will panic +// if err == nil. +func (s *Step) Fatal(err error) { + if err == nil { + panic(fmt.Sprintf("Step %q: Fatal called with a nil error", s.name)) + } + s.End(err) + s.env.t.Fatal(err) +} + +// EventType identifies the kind of event published to the EventBus. +type EventType string + +const ( + EventStepChanged EventType = "step_changed" // a Step changed status + EventConsoleOutput EventType = "console_output" // serial console line + EventDHCPDiscover EventType = "dhcp_discover" // VM sent DHCP Discover + EventDHCPOffer EventType = "dhcp_offer" // server sent DHCP Offer + EventDHCPRequest EventType = "dhcp_request" // VM sent DHCP Request + EventDHCPAck EventType = "dhcp_ack" // server sent DHCP Ack + EventScreenshot EventType = "screenshot" // VM display screenshot (JPEG, base64) + EventTailscale EventType = "tailscale" // Tailscale status change + EventTestStatus EventType = "test_status" // test Running/Passed/Failed +) + +// TestStatus tracks whether the overall test is running, passed, or failed. +type TestStatus struct { + mu sync.Mutex + state string // "Running", "Passed", "Failed" + started time.Time + ended time.Time +} + +func newTestStatus() *TestStatus { + return &TestStatus{state: "Running", started: time.Now()} +} + +// State returns the current test state. +func (ts *TestStatus) State() string { + ts.mu.Lock() + defer ts.mu.Unlock() + return ts.state +} + +// Elapsed returns total test duration. +func (ts *TestStatus) Elapsed() time.Duration { + ts.mu.Lock() + defer ts.mu.Unlock() + if !ts.ended.IsZero() { + return ts.ended.Sub(ts.started) + } + return time.Since(ts.started) +} + +// StartUnixMilli returns the test start time as Unix milliseconds, +// for the client-side elapsed timer. +func (ts *TestStatus) StartUnixMilli() int64 { + ts.mu.Lock() + defer ts.mu.Unlock() + return ts.started.UnixMilli() +} + +// finish marks the test as passed or failed. +func (ts *TestStatus) finish(failed bool) { + ts.mu.Lock() + defer ts.mu.Unlock() + ts.ended = time.Now() + if failed { + ts.state = "Failed" + } else { + ts.state = "Passed" + } +} + +// VMEvent is a single event published to the [EventBus]. +type VMEvent struct { + Time time.Time + NodeName string // "" for global events + Type EventType + Message string // human-readable description + Detail string // e.g. IP address, node key + Step *Step // non-nil for EventStepChanged + NIC int // NIC index for DHCP events (0-based); -1 if not applicable +} + +// NICStatus is the DHCP state for one NIC on a node. +type NICStatus struct { + NetName string // human label like "192.168.1.0/24" or "10.0.0.0/24" + DHCP string // "waiting", "Discover sent", "Got 10.0.0.101", etc. +} + +// NodeStatus tracks the current DHCP and Tailscale state of a VM node +// for rendering on the web UI's initial page load. +type NodeStatus struct { + Name string + OS string + NICs []NICStatus // one per NIC; index matches NIC index + JoinsTailnet bool // whether this node runs Tailscale + Tailscale string // "--", "Up (100.64.0.1)", etc. + Console []string // recent console output lines (ring buffer) + Screenshot string // latest screenshot as data URI, or "" + ScreenshotPort int // Host.app screenshot server port, or 0 +} + +const maxConsoleLines = 200 + +const ( + eventBusHistorySize = 500 + subscriberChannelSize = 1000 +) + +// EventBus broadcasts VMEvents to subscribers and keeps a history for +// late joiners. It is safe for concurrent use. +type EventBus struct { + mu sync.Mutex + history []VMEvent + subscribers map[*subscriber]struct{} +} + +func newEventBus() *EventBus { + return &EventBus{ + subscribers: make(map[*subscriber]struct{}), + } +} + +// Publish sends an event to all subscribers and appends it to the history. +// Non-blocking: slow subscribers are skipped. +func (b *EventBus) Publish(ev VMEvent) { + if ev.Time.IsZero() { + ev.Time = time.Now() + } + b.mu.Lock() + defer b.mu.Unlock() + // Don't store screenshots in history — they're large and only the + // latest one matters (stored in NodeStatus.Screenshot instead). + if ev.Type != EventScreenshot { + b.history = append(b.history, ev) + } + if len(b.history) > eventBusHistorySize { + // Trim old events. + copy(b.history, b.history[len(b.history)-eventBusHistorySize:]) + b.history = b.history[:eventBusHistorySize] + } + for sub := range b.subscribers { + select { + case sub.ch <- ev: + default: + // Slow consumer, skip. + } + } +} + +// Subscribe returns a new subscriber that receives the event history +// followed by live events. +func (b *EventBus) Subscribe() *subscriber { + b.mu.Lock() + defer b.mu.Unlock() + sub := &subscriber{ + bus: b, + ch: make(chan VMEvent, subscriberChannelSize), + done: make(chan struct{}), + } + // Send history. + for _, ev := range b.history { + select { + case sub.ch <- ev: + default: + } + } + b.subscribers[sub] = struct{}{} + return sub +} + +func (b *EventBus) unsubscribe(sub *subscriber) { + b.mu.Lock() + defer b.mu.Unlock() + delete(b.subscribers, sub) +} + +// subscriber receives events from an [EventBus]. +type subscriber struct { + bus *EventBus + ch chan VMEvent + done chan struct{} + once sync.Once +} + +// Events returns the channel of events. Closed when Close is called. +func (s *subscriber) Events() <-chan VMEvent { + return s.ch +} + +// Close unsubscribes and closes the event channel. +func (s *subscriber) Close() { + s.once.Do(func() { + if s.bus != nil { + s.bus.unsubscribe(s) + } + close(s.done) + }) +} + +// Done returns a channel that's closed when Close is called. +func (s *subscriber) Done() <-chan struct{} { + return s.done +} diff --git a/tstest/natlab/vmtest/vmtest.go b/tstest/natlab/vmtest/vmtest.go index e6c89467f..a256f9c10 100644 --- a/tstest/natlab/vmtest/vmtest.go +++ b/tstest/natlab/vmtest/vmtest.go @@ -7,7 +7,7 @@ // and multi-NIC configurations for scenarios like subnet routing. // // Prerequisites: -// - qemu-system-x86_64 and KVM access (typically the "kvm" group; no root required) +// - qemu-system-x86_64 (KVM is used automatically on Linux when /dev/kvm is accessible) // - A built gokrazy natlabapp image (auto-built on first run via "make natlab" in gokrazy/) // // Run tests with: @@ -16,31 +16,47 @@ package vmtest import ( + "bytes" "context" + "encoding/base64" + "encoding/json" "flag" "fmt" "io" "net" "net/http" "net/netip" + "net/url" "os" "os/exec" "path/filepath" + "runtime" + "strconv" "strings" + "sync" "testing" "time" + "github.com/google/gopacket/layers" + dto "github.com/prometheus/client_model/go" + "github.com/prometheus/common/expfmt" + "go4.org/mem" "golang.org/x/sync/errgroup" "tailscale.com/client/local" "tailscale.com/ipn" + "tailscale.com/ipn/ipnstate" "tailscale.com/tailcfg" + "tailscale.com/tstest" + "tailscale.com/tstest/integration/testcontrol" "tailscale.com/tstest/natlab/vnet" - "tailscale.com/util/set" + "tailscale.com/types/key" + "tailscale.com/util/mak" ) var ( - runVMTests = flag.Bool("run-vm-tests", false, "run tests that require VMs with KVM") + runVMTests = flag.Bool("run-vm-tests", false, "run tests that require QEMU VMs") verboseVMDebug = flag.Bool("verbose-vm-debug", false, "enable verbose debug logging for VM tests") + testVersion = flag.String("test-version", "", `if non-empty, download tailscale & tailscaled at the given release version (e.g. "1.97.255", "unstable", or "stable") instead of building from the source tree`) ) // Env is a test environment that manages virtual networks and QEMU VMs. @@ -52,14 +68,50 @@ type Env struct { nodes []*Node tempDir string - sockAddr string // shared Unix socket path for all QEMU netdevs - binDir string // directory for compiled binaries + sockDir string // short-path dir for Unix sockets (macOS has 104-byte limit) + sockAddr string // shared Unix socket path for all QEMU netdevs + dgramSockAddr string // Unix dgram socket path for macOS VMs (tailmac) + binDir string // directory for compiled binaries + + // testVersion is the resolved Tailscale release version to use (empty if + // building from source). When non-empty, tailscale and tailscaled binaries + // are downloaded from pkgs.tailscale.com instead of compiled from the tree. + testVersion string // gokrazy-specific paths gokrazyBase string // path to gokrazy base qcow2 image gokrazyKernel string // path to gokrazy kernel + // tailmac-specific paths (macOS VMs) + tailmacDir string // path to tailmac bin/ directory containing Host.app + macosSnapshot string // path to cached macOS VM snapshot directory + macosSnapshotOnce sync.Once + qemuProcs []*exec.Cmd // launched QEMU processes + + sameTailnetUser bool // all nodes register as the same Tailnet user + allOnline bool // mark every peer as Online=true in MapResponses + peerRelayGrants bool // grant peer-relay capabilities on the wildcard packet filter + + // Shared resource initialization (sync.Once for things multiple nodes share). + vnetOnce sync.Once + gokrazyOnce sync.Once + qemuSockOnce sync.Once + dgramSockOnce sync.Once + compileMu sync.Mutex + compileOnce map[string]*sync.Once // keyed by goos_goarch + imageOnce map[string]*sync.Once // keyed by OSImage.Name + + // Web UI support. + ctx context.Context // cancelled when test ends + eventBus *EventBus + testStatus *TestStatus + stepsMu sync.Mutex + stepsByKey map[string]*Step + steps []*Step + + nodeStatusMu sync.Mutex + nodeStatus map[string]*NodeStatus // keyed by node name } // logVerbosef logs a message only when --verbose-vm-debug is set. @@ -70,18 +122,266 @@ func (e *Env) logVerbosef(format string, args ...any) { } } -// New creates a new test environment. It skips the test if --run-vm-tests is not set. -func New(t testing.TB) *Env { +// vmPlatform defines how a VM type boots. Each OS image type (gokrazy, +// cloud, macOS) implements this interface. +type vmPlatform interface { + // planSteps registers steps with the web UI in a dry-run pass. + planSteps(e *Env, n *Node) + + // boot does everything needed to get this node running: ensure images, + // compile binaries, set up sockets, launch VM. Called concurrently. + boot(ctx context.Context, e *Env, n *Node) error +} + +// platform returns the vmPlatform for this node's OS type. +func (n *Node) platform() vmPlatform { + if n.os.IsMacOS { + return macPlatform{} + } + if n.os.IsGokrazy { + return gokrazyPlatform{} + } + return qemuCloudPlatform{} +} + +// AddStep declares an expected stage of the test. The web UI shows all steps +// from the start, tracking their progress. Call before or during the test. +// Returns a *Step whose Begin/End methods drive the progress display. +func (e *Env) AddStep(name string) *Step { + s := &Step{ + name: name, + index: len(e.steps), + env: e, + } + e.steps = append(e.steps, s) + return s +} + +// Step returns a step by key, creating it if it doesn't exist. +// Safe for concurrent use. Both planSteps (dry-run) and boot (real-run) +// call this to get the same Step object. +func (e *Env) Step(key string) *Step { + e.stepsMu.Lock() + defer e.stepsMu.Unlock() + if s, ok := e.stepsByKey[key]; ok { + return s + } + s := &Step{ + name: key, + index: len(e.steps), + env: e, + } + e.steps = append(e.steps, s) + if e.stepsByKey == nil { + e.stepsByKey = make(map[string]*Step) + } + e.stepsByKey[key] = s + return s +} + +// Steps returns all declared steps in order. +func (e *Env) Steps() []*Step { + return e.steps +} + +// publishStepChange publishes a step status change event. +func (e *Env) publishStepChange(s *Step) { + e.eventBus.Publish(VMEvent{ + Type: EventStepChanged, + Message: fmt.Sprintf("%s %s", s.Status().Icon(), s.name), + Step: s, + }) +} + +// initNodeStatus initializes the NodeStatus for all nodes. Called after +// AddNode but before Start so the web UI can render them. +func (e *Env) initNodeStatus() { + e.nodeStatusMu.Lock() + defer e.nodeStatusMu.Unlock() + for _, n := range e.nodes { + nics := make([]NICStatus, len(n.nets)) + for i := range n.nets { + nics[i] = NICStatus{ + NetName: e.nicLabel(n, i), + DHCP: "waiting", + } + } + e.nodeStatus[n.name] = &NodeStatus{ + Name: n.name, + OS: n.os.Name, + NICs: nics, + JoinsTailnet: n.joinTailnet, + Tailscale: "--", + } + } +} + +// nicLabel returns a short human-readable label for a node's i-th NIC. +// After Start(), we can use the assigned LAN IP. Before that, we use "NIC N". +func (e *Env) nicLabel(n *Node, i int) string { + if n.vnetNode != nil { + ip := n.vnetNode.LanIP(n.nets[i]) + if ip.IsValid() { + return ip.String() + } + } + return fmt.Sprintf("NIC %d", i) +} + +// getNodeStatus returns the current status for a node. +func (e *Env) getNodeStatus(name string) NodeStatus { + e.nodeStatusMu.Lock() + defer e.nodeStatusMu.Unlock() + ns := e.nodeStatus[name] + if ns == nil { + return NodeStatus{Name: name, Tailscale: "--"} + } + return *ns +} + +// setNodeDHCP updates the DHCP status for a specific NIC on a node. +func (e *Env) setNodeDHCP(name string, nicIdx int, status string) { + e.nodeStatusMu.Lock() + ns := e.nodeStatus[name] + if ns != nil && nicIdx < len(ns.NICs) { + ns.NICs[nicIdx].DHCP = status + } + e.nodeStatusMu.Unlock() +} + +// setNodeTailscale updates the Tailscale status for a node and publishes +// an event so the web UI updates via WebSocket. +func (e *Env) setNodeTailscale(name, status string) { + e.nodeStatusMu.Lock() + ns := e.nodeStatus[name] + if ns != nil { + ns.Tailscale = status + } + e.nodeStatusMu.Unlock() + e.eventBus.Publish(VMEvent{ + NodeName: name, + Type: EventTailscale, + Message: "Tailscale: " + status, + Detail: status, + }) +} + +// appendConsoleLine adds a line to a node's console buffer. +func (e *Env) appendConsoleLine(name, line string) { + e.nodeStatusMu.Lock() + ns := e.nodeStatus[name] + if ns != nil { + ns.Console = append(ns.Console, line) + if len(ns.Console) > maxConsoleLines { + ns.Console = ns.Console[len(ns.Console)-maxConsoleLines:] + } + } + e.nodeStatusMu.Unlock() +} + +// nicIndexForMAC returns the NIC index (0-based) for a given MAC on a node. +// Returns -1 if not found. +func (e *Env) nicIndexForMAC(name string, mac vnet.MAC) int { + for _, n := range e.nodes { + if n.name != name { + continue + } + for i := range n.nets { + if n.vnetNode.NICMac(i) == mac { + return i + } + } + } + return -1 +} + +// nodeNameByNum returns the node name for a given vnet node number. +func (e *Env) nodeNameByNum(num int) string { + for _, n := range e.nodes { + if n.num == num { + return n.name + } + } + return fmt.Sprintf("node%d", num) +} + +// New creates a new test environment. It skips the test if --run-vm-tests is +// not set. opts may contain [EnvOption] values returned by helpers like +// [SameTailnetUser]. +func New(t testing.TB, opts ...EnvOption) *Env { if !*runVMTests { t.Skip("skipping VM test; set --run-vm-tests to run") } tempDir := t.TempDir() - return &Env{ - t: t, - tempDir: tempDir, - binDir: filepath.Join(tempDir, "bin"), + + // Unix sockets have a short path limit (104 bytes on macOS). The Go + // test TempDir path easily exceeds that, so create a dedicated short + // directory under /tmp for sockets. + sockDir, err := os.MkdirTemp("", "vmtest") + if err != nil { + t.Fatalf("creating socket tempdir: %v", err) } + t.Cleanup(func() { os.RemoveAll(sockDir) }) + + e := &Env{ + t: t, + tempDir: tempDir, + sockDir: sockDir, + binDir: filepath.Join(tempDir, "bin"), + eventBus: newEventBus(), + testStatus: newTestStatus(), + nodeStatus: make(map[string]*NodeStatus), + } + for _, o := range opts { + o.applyTo(e) + } + t.Cleanup(func() { + e.testStatus.finish(t.Failed()) + e.eventBus.Publish(VMEvent{ + Type: EventTestStatus, + Message: e.testStatus.State(), + Detail: formatDuration(e.testStatus.Elapsed()), + }) + }) + return e +} + +// EnvOption configures an [Env] in [New]. +type EnvOption interface { + applyTo(*Env) +} + +type envOptFunc func(*Env) + +func (f envOptFunc) applyTo(e *Env) { f(e) } + +// SameTailnetUser returns an [EnvOption] that makes every node register with +// the test control server as the same Tailnet user. This is needed for +// cross-node features that require a same-user relationship — Taildrop, for +// example. +func SameTailnetUser() EnvOption { + return envOptFunc(func(e *Env) { e.sameTailnetUser = true }) +} + +// AllOnline returns an [EnvOption] that makes the test control server mark +// every peer as Online=true in MapResponses (testcontrol.Server.AllOnline). +// Several disco-key handling fast paths in the controlclient and wgengine +// only fire when the peer is reported online; without this option those +// paths are silently skipped, which can mask bugs and slow down recovery +// from disco-key rotations. +func AllOnline() EnvOption { + return envOptFunc(func(e *Env) { e.allOnline = true }) +} + +// PeerRelayGrants returns an [EnvOption] that makes the test control server +// grant [tailcfg.PeerCapabilityRelay] and [tailcfg.PeerCapabilityRelayTarget] +// on the wildcard packet filter (testcontrol.Server.PeerRelayGrants). Without +// those capabilities, magicsock does not consider any peer a candidate +// peer-relay server, so a node that has [ipn.Prefs.RelayServerPort] set +// cannot actually be used as a relay by its peers. +func PeerRelayGrants() EnvOption { + return envOptFunc(func(e *Env) { e.peerRelayGrants = true }) } // AddNetwork creates a new virtual network. Arguments follow the same pattern as @@ -95,14 +395,16 @@ type Node struct { name string num int // assigned during AddNode - os OSImage - nets []*vnet.Network - vnetNode *vnet.Node // primary vnet node (set during Start) - agent *vnet.NodeAgentClient - joinTailnet bool - advertiseRoutes string - webServerPort int - sshPort int // host port for SSH debug access (cloud VMs only) + os OSImage + nets []*vnet.Network + vnetNode *vnet.Node // primary vnet node (set during Start) + agent *vnet.NodeAgentClient + joinTailnet bool + noAgent bool // true to skip TTA agent setup (e.g. macOS VMs without TTA) + advertiseRoutes string + snatSubnetRoutes *bool // nil means default (true) + webServerPort int + sshPort int // host port for SSH debug access (cloud VMs only) } // AddNode creates a new VM node. The name is used for identification and as the @@ -128,8 +430,13 @@ func (e *Env) AddNode(name string, opts ...any) *Node { case nodeOptNoTailscale: n.joinTailnet = false vnetOpts = append(vnetOpts, vnet.DontJoinTailnet) + case nodeOptNoAgent: + n.noAgent = true case nodeOptAdvertiseRoutes: n.advertiseRoutes = string(o) + case nodeOptSNATSubnetRoutes: + v := bool(o) + n.snatSubnetRoutes = &v case nodeOptWebServer: n.webServerPort = int(o) default: @@ -138,22 +445,45 @@ func (e *Env) AddNode(name string, opts ...any) *Node { } } + // macOS VMs require a macOS arm64 host (Apple Virtualization.framework via + // tailmac). Skip the test now rather than letting it proceed through the + // rest of the setup only to fail later. + if n.os.IsMacOS && (runtime.GOOS != "darwin" || runtime.GOARCH != "arm64") { + e.t.Skipf("macOS VM tests require a macOS arm64 host (got %s/%s)", runtime.GOOS, runtime.GOARCH) + } + n.vnetNode = e.cfg.AddNode(vnetOpts...) n.num = n.vnetNode.Num() return n } +// Name returns the name of the Node. +func (n *Node) Name() string { + return n.name +} + // LanIP returns the LAN IPv4 address of this node on the given network. // This is only valid after Env.Start() has been called. +// Name returns the node's name as set in [Env.AddNode]. func (n *Node) LanIP(net *vnet.Network) netip.Addr { return n.vnetNode.LanIP(net) } +// DropControlTraffic sets up a blackhole for control traffic for just this +// node on all the networks belonging to the node. +func (n *Node) DropControlTraffic() { + for _, network := range n.nets { + network.BlackholeControlForAddr(n.LanIP(network)) + } +} + // NodeOption types for configuring nodes. type nodeOptOS OSImage type nodeOptNoTailscale struct{} +type nodeOptNoAgent struct{} type nodeOptAdvertiseRoutes string +type nodeOptSNATSubnetRoutes bool type nodeOptWebServer int // OS returns a NodeOption that sets the node's operating system image. @@ -162,132 +492,115 @@ func OS(img OSImage) nodeOptOS { return nodeOptOS(img) } // DontJoinTailnet returns a NodeOption that prevents the node from running tailscale up. func DontJoinTailnet() nodeOptNoTailscale { return nodeOptNoTailscale{} } +// NoAgent returns a NodeOption that skips TTA agent setup. The node will not +// have a test agent, so agent-dependent operations (Status, ExecOnNode, etc.) +// won't work. Useful for VMs that just need to boot and respond to ICMP. +func NoAgent() nodeOptNoAgent { return nodeOptNoAgent{} } + // AdvertiseRoutes returns a NodeOption that configures the node to advertise // the given routes (comma-separated CIDRs) when joining the tailnet. func AdvertiseRoutes(routes string) nodeOptAdvertiseRoutes { return nodeOptAdvertiseRoutes(routes) } +// SNATSubnetRoutes returns a NodeOption that sets whether the node should +// source NAT traffic to advertised subnet routes. The default is true. +// Setting this to false preserves original source IPs, which is needed +// for site-to-site configurations. +func SNATSubnetRoutes(v bool) nodeOptSNATSubnetRoutes { return nodeOptSNATSubnetRoutes(v) } + // WebServer returns a NodeOption that starts a webserver on the given port. -// The webserver responds with "Hello world I am " on all requests. +// The webserver responds with "Hello world I am from " on all requests. func WebServer(port int) nodeOptWebServer { return nodeOptWebServer(port) } -// Start initializes the virtual network, builds/downloads images, compiles -// binaries, launches QEMU processes, and waits for all TTA agents to connect. -// It should be called after all AddNetwork/AddNode calls. +// Start initializes the virtual network, boots all VMs in parallel, and waits +// for all TTA agents to connect. It should be called after all AddNetwork/AddNode calls. func (e *Env) Start() { t := e.t ctx, cancel := context.WithTimeout(context.Background(), 10*time.Minute) t.Cleanup(cancel) + e.ctx = ctx + + e.initNodeStatus() + e.maybeStartWebServer() if err := os.MkdirAll(e.binDir, 0755); err != nil { t.Fatal(err) } - // Determine which GOOS/GOARCH pairs need compiled binaries (non-gokrazy - // images). Gokrazy has binaries built-in, so doesn't need compilation. - type platform struct{ goos, goarch string } - needPlatform := set.Set[platform]{} - for _, n := range e.nodes { - if !n.os.IsGokrazy { - needPlatform.Add(platform{n.os.GOOS(), n.os.GOARCH()}) + if *testVersion != "" { + v, err := resolveTestVersion(ctx, *testVersion) + if err != nil { + t.Fatalf("resolving --test-version=%q: %v", *testVersion, err) } + e.testVersion = v + t.Logf("using Tailscale release version %s (from --test-version=%q)", v, *testVersion) } - // Compile binaries and download/build images in parallel. - // Any failure cancels the others via the errgroup context. - eg, egCtx := errgroup.WithContext(ctx) - for _, p := range needPlatform.Slice() { - eg.Go(func() error { - return e.compileBinariesForOS(egCtx, p.goos, p.goarch) + // Dry-run: let each platform register its steps with the web UI. + userSteps := e.steps + e.steps = nil + for _, n := range e.nodes { + n.platform().planSteps(e, n) + } + for _, n := range e.nodes { + if !n.noAgent { + e.Step("Wait for agent: " + n.name) + } + if n.joinTailnet { + e.Step("Tailscale up: " + n.name) + } + } + for _, s := range userSteps { + s.index = len(e.steps) + e.steps = append(e.steps, s) + } + + // Boot all nodes in parallel. Each platform handles its own + // dependencies (image prep, binary compilation, socket setup) + // via sync.Once, so independent work overlaps naturally. + var bootEg errgroup.Group + for _, n := range e.nodes { + bootEg.Go(func() error { + return n.platform().boot(ctx, e, n) }) } - didOS := set.Set[string]{} // dedup by image name - for _, n := range e.nodes { - if didOS.Contains(n.os.Name) { - continue - } - didOS.Add(n.os.Name) - if n.os.IsGokrazy { - eg.Go(func() error { - return e.ensureGokrazy(egCtx) - }) - } else { - eg.Go(func() error { - return ensureImage(egCtx, n.os) - }) - } - } - if err := eg.Wait(); err != nil { - t.Fatalf("setup: %v", err) - } - - // Create the vnet server. - var err error - e.server, err = vnet.New(&e.cfg) - if err != nil { - t.Fatalf("vnet.New: %v", err) - } - t.Cleanup(func() { e.server.Close() }) - - // Register compiled binaries with the file server VIP. - // Binaries are registered at _/ (e.g. "linux_amd64/tta"). - for _, p := range needPlatform.Slice() { - dir := p.goos + "_" + p.goarch - for _, name := range []string{"tta", "tailscale", "tailscaled"} { - data, err := os.ReadFile(filepath.Join(e.binDir, dir, name)) - if err != nil { - t.Fatalf("reading compiled %s/%s: %v", dir, name, err) - } - e.server.RegisterFile(dir+"/"+name, data) - } - } - - // Cloud-init config is delivered via local seed ISOs (created in startCloudQEMU), - // not via the cloud-init HTTP VIP, because network-config must be available - // during init-local before systemd-networkd-wait-online blocks. - - // Start Unix socket listener. - e.sockAddr = filepath.Join(e.tempDir, "vnet.sock") - srv, err := net.Listen("unix", e.sockAddr) - if err != nil { - t.Fatalf("listen unix: %v", err) - } - t.Cleanup(func() { srv.Close() }) - - go func() { - for { - c, err := srv.Accept() - if err != nil { - return - } - go e.server.ServeUnixConn(c.(*net.UnixConn), vnet.ProtocolQEMU) - } - }() - - // Launch QEMU processes. - for _, n := range e.nodes { - if err := e.startQEMU(n); err != nil { - t.Fatalf("startQEMU(%s): %v", n.name, err) - } + if err := bootEg.Wait(); err != nil { + t.Fatalf("boot: %v", err) } // Set up agent clients and wait for all agents to connect. for _, n := range e.nodes { + if n.noAgent { + continue + } + e.initVnet() // ensure vnet is ready for agent clients n.agent = e.server.NodeAgentClient(n.vnetNode) n.vnetNode.SetClient(n.agent) } - // Wait for agents, then bring up tailscale. var agentEg errgroup.Group for _, n := range e.nodes { + if n.noAgent { + continue + } agentEg.Go(func() error { + aStep := e.Step("Wait for agent: " + n.name) + aStep.Begin() t.Logf("[%s] waiting for agent...", n.name) - st, err := n.agent.Status(ctx) - if err != nil { - return fmt.Errorf("[%s] agent status: %w", n.name, err) + if n.joinTailnet { + st, err := n.agent.Status(ctx) + if err != nil { + return fmt.Errorf("[%s] agent status: %w", n.name, err) + } + t.Logf("[%s] agent connected, backend state: %s", n.name, st.BackendState) + } else { + if err := e.waitForAgentConn(ctx, n); err != nil { + return fmt.Errorf("[%s] agent connect: %w", n.name, err) + } + t.Logf("[%s] agent connected (no tailscale)", n.name) } - t.Logf("[%s] agent connected, backend state: %s", n.name, st.BackendState) + aStep.End(nil) if n.vnetNode.HostFirewall() { if err := n.agent.EnableHostFirewall(ctx); err != nil { @@ -296,17 +609,47 @@ func (e *Env) Start() { } if n.joinTailnet { + tsStep := e.Step("Tailscale up: " + n.name) + tsStep.Begin() if err := e.tailscaleUp(ctx, n); err != nil { return fmt.Errorf("[%s] tailscale up: %w", n.name, err) } - st, err = n.agent.Status(ctx) + st2, err := n.agent.Status(ctx) if err != nil { return fmt.Errorf("[%s] status after up: %w", n.name, err) } - if st.BackendState != "Running" { - return fmt.Errorf("[%s] state = %q, want Running", n.name, st.BackendState) + if st2.BackendState != "Running" { + return fmt.Errorf("[%s] state = %q, want Running", n.name, st2.BackendState) } - t.Logf("[%s] up with %v", n.name, st.Self.TailscaleIPs) + + // Apply any capabilities for the node to the map. + // SetNodeCapMap pushes an updated map response immediately, then wait + // until the node reports the capability in its status. + if cm := n.vnetNode.WantCapMap(); cm != nil { + e.server.ControlServer().SetNodeCapMap(st2.Self.PublicKey, cm) + if err := tstest.WaitFor(15*time.Second, func() error { + st, err := n.agent.Status(ctx) + if err != nil { + return err + } + if st.Self == nil { + return fmt.Errorf("self is nil") + } + for c := range cm { + if !st.Self.HasCap(c) { + return fmt.Errorf("cap %v not yet received", c) + } + } + return nil + }); err != nil { + return fmt.Errorf("[%s] waiting for capabilities: %w", n.name, err) + } + } + + ips := fmt.Sprintf("%v", st2.Self.TailscaleIPs) + e.setNodeTailscale(n.name, "Running "+ips) + t.Logf("[%s] up with %v", n.name, st2.Self.TailscaleIPs) + tsStep.End(nil) } return nil @@ -332,6 +675,13 @@ func (e *Env) tailscaleUp(ctx context.Context, n *Node) error { if n.advertiseRoutes != "" { url += "&advertise-routes=" + n.advertiseRoutes } + if n.snatSubnetRoutes != nil { + if *n.snatSubnetRoutes { + url += "&snat-subnet-routes=true" + } else { + url += "&snat-subnet-routes=false" + } + } req, err := http.NewRequestWithContext(ctx, "GET", url, nil) if err != nil { return err @@ -368,6 +718,228 @@ func (e *Env) startWebServer(ctx context.Context, n *Node) error { return nil } +// SetExitNode sets the client node's exit node to use for internet traffic. +// If exitNode is nil, the client's exit node is cleared (i.e., turned off). +// Otherwise exitNode must be a tailnet node with an approved 0.0.0.0/0 (and +// ::/0) route, typically configured via [AdvertiseRoutes] and +// [Env.ApproveRoutes]. +func (e *Env) SetExitNode(client, exitNode *Node) { + e.t.Helper() + ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second) + defer cancel() + + var ip netip.Addr + if exitNode != nil { + st, err := exitNode.agent.Status(ctx) + if err != nil { + e.t.Fatalf("SetExitNode: status for %s: %v", exitNode.name, err) + } + if len(st.Self.TailscaleIPs) == 0 { + e.t.Fatalf("SetExitNode: %s has no Tailscale IPs", exitNode.name) + } + ip = st.Self.TailscaleIPs[0] + } + + if _, err := client.agent.EditPrefs(ctx, &ipn.MaskedPrefs{ + Prefs: ipn.Prefs{ + ExitNodeID: "", + ExitNodeIP: ip, + }, + ExitNodeIDSet: true, + ExitNodeIPSet: true, + }); err != nil { + e.t.Fatalf("SetExitNode(%s -> %v): %v", client.name, exitNode, err) + } + if exitNode == nil { + e.t.Logf("[%s] cleared exit node", client.name) + } else { + e.t.Logf("[%s] using exit node %s (%v)", client.name, exitNode.name, ip) + } +} + +// SetExitNodeIP sets the client's ExitNodeIP preference directly, by IP. +// This is the right helper for plain-WireGuard exit nodes (Mullvad-style) +// that aren't on the tailnet — pass an invalid netip.Addr{} to clear. +// For tailnet exit nodes whose Tailscale IP is discoverable via TTA, use +// [Env.SetExitNode] instead. +func (e *Env) SetExitNodeIP(client *Node, ip netip.Addr) { + e.t.Helper() + ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second) + defer cancel() + if _, err := client.agent.EditPrefs(ctx, &ipn.MaskedPrefs{ + Prefs: ipn.Prefs{ + ExitNodeID: "", + ExitNodeIP: ip, + }, + ExitNodeIDSet: true, + ExitNodeIPSet: true, + }); err != nil { + e.t.Fatalf("SetExitNodeIP(%s, %v): %v", client.name, ip, err) + } + if !ip.IsValid() { + e.t.Logf("[%s] cleared exit node", client.name) + } else { + e.t.Logf("[%s] using exit-node IP %v", client.name, ip) + } +} + +// ControlServer returns the underlying test control server, for tests that +// need to inject custom peers, masquerade pairs, etc. The returned server's +// Node store is shared with the running tailnet, so changes take effect on +// the next netmap update sent to peers. +func (e *Env) ControlServer() *testcontrol.Server { + return e.server.ControlServer() +} + +// BringUpMullvadWGServer brings up a userspace WireGuard server on n, +// configured as a single-peer "Mullvad-style" exit-node target. The +// server runs inside n's TTA process on a Linux TUN named "wg0". +// +// gw is the WG interface address (e.g. 10.64.0.1/24). The server listens +// on listenPort, accepts only the single peer whose public key is peerPub +// at peerAllowedIP, and MASQUERADEs egress traffic from masqSrc so that +// decrypted packets from the peer egress with n's WAN IP. +// +// It returns the freshly generated public key of the WG server, which +// the caller must pin as the peer key on the [tailcfg.Node] it injects +// into the netmap to advertise this server as a plain-WireGuard exit +// node. It fatals the test on error. +func (e *Env) BringUpMullvadWGServer(n *Node, gw netip.Prefix, listenPort uint16, peerPub key.NodePublic, peerAllowedIP, masqSrc netip.Prefix) key.NodePublic { + e.t.Helper() + ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second) + defer cancel() + + peerPubRaw := peerPub.Raw32() + v := url.Values{ + "addr": {gw.String()}, + "listen-port": {strconv.Itoa(int(listenPort))}, + "peer-pub-b64": {base64.StdEncoding.EncodeToString(peerPubRaw[:])}, + "peer-allowed-ip": {peerAllowedIP.String()}, + "masq-src": {masqSrc.String()}, + } + reqURL := "http://unused/wg-server-up?" + v.Encode() + req, err := http.NewRequestWithContext(ctx, "GET", reqURL, nil) + if err != nil { + e.t.Fatalf("BringUpMullvadWGServer: %v", err) + } + res, err := n.agent.HTTPClient.Do(req) + if err != nil { + e.t.Fatalf("BringUpMullvadWGServer(%s): %v", n.name, err) + } + defer res.Body.Close() + body, _ := io.ReadAll(res.Body) + if res.StatusCode != 200 { + e.t.Fatalf("BringUpMullvadWGServer(%s): %s: %s", n.name, res.Status, body) + } + var pubB64 string + for _, line := range strings.Split(string(body), "\n") { + if s, ok := strings.CutPrefix(strings.TrimSpace(line), "PUBKEY="); ok { + pubB64 = s + break + } + } + if pubB64 == "" { + e.t.Fatalf("BringUpMullvadWGServer(%s): no PUBKEY in response: %q", n.name, body) + } + pubRaw, err := base64.StdEncoding.DecodeString(pubB64) + if err != nil || len(pubRaw) != 32 { + e.t.Fatalf("BringUpMullvadWGServer(%s): bad PUBKEY %q: %v", n.name, pubB64, err) + } + return key.NodePublicFromRaw32(mem.B(pubRaw)) +} + +// Status returns the tailscale status of the given node, fetched from its +// TTA agent. It fatals the test on error. +func (e *Env) Status(n *Node) *ipnstate.Status { + e.t.Helper() + ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second) + defer cancel() + st, err := n.agent.Status(ctx) + if err != nil { + e.t.Fatalf("Status(%s): %v", n.name, err) + } + return st +} + +// ClientMetrics returns the client metrics exported by the given node. +func (e *Env) ClientMetrics(n *Node) ClientMetrics { + e.t.Helper() + raw, err := n.Agent().DaemonMetrics(e.t.Context()) + if err != nil { + e.t.Fatalf("Node %q DaemonMetrics: %v", n.Name(), err) + } + + // Metrics are reported in Prometheus exposition format. + var parser expfmt.TextParser + mfs, err := parser.TextToMetricFamilies(bytes.NewReader(raw)) + if err != nil { + e.t.Fatalf("Node %q parse client metrics: %v", n.Name(), err) + } + + // Tailscale client metrics are all unlabelled integer-valued counters and + // gauges, so we don't need to handle the full generality of the Prometheus + // representation. If we see anything else, we'll log and skip it. + out := make(ClientMetrics) + for _, mf := range mfs { + name := mf.GetName() + if _, ok := out[name]; ok { + e.t.Logf("Node %q: duplicate client metric %q (ignored)", n.Name(), name) + continue + } else if len(mf.Metric) != 1 { + e.t.Logf("Node %q: got %d values for client metric %q, want 1 (ignored)", n.Name(), len(mf.Metric), name) + continue + } + + var mtype string + var value int64 + switch mf.GetType() { + case dto.MetricType_COUNTER: + mtype = "counter" + value = int64(mf.Metric[0].GetCounter().GetValue()) + case dto.MetricType_GAUGE: + mtype = "gauge" + value = int64(mf.Metric[0].GetGauge().GetValue()) + default: + e.t.Logf("Node %q unexpected client metric %q type %q (ignored)", n.Name(), name, mf.GetType().String()) + continue + } + out[name] = ClientMetric{ + Name: name, + Type: mtype, + Value: value, + } + } + return out +} + +// ClientMetrics is a view of the client metrics exported by a node. +// The keys of the map are the metric names. +type ClientMetrics map[string]ClientMetric + +// ClientMetric is a view of a node client metric. +type ClientMetric struct { + Name string // as published to the clientmetrics package + Type string // either "gauge" or "counter" + Value int64 // the gauge or counter value +} + +// SetAcceptRoutes toggles the node's RouteAll preference (the +// --accept-routes flag), controlling whether it installs subnet routes +// advertised by peers. +func (e *Env) SetAcceptRoutes(n *Node, on bool) { + e.t.Helper() + ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second) + defer cancel() + + if _, err := n.agent.EditPrefs(ctx, &ipn.MaskedPrefs{ + Prefs: ipn.Prefs{RouteAll: on}, + RouteAllSet: true, + }); err != nil { + e.t.Fatalf("SetAcceptRoutes(%s, %v): %v", n.name, on, err) + } + e.t.Logf("[%s] accept-routes=%v", n.name, on) +} + // ApproveRoutes tells the test control server to approve subnet routes // for the given node. The routes should be CIDR strings. func (e *Env) ApproveRoutes(n *Node, routes ...string) { @@ -432,33 +1004,198 @@ func (e *Env) ApproveRoutes(n *Node, routes ...string) { } } -// ping pings from one node to another's Tailscale IP, retrying until it succeeds -// or the timeout expires. This establishes the WireGuard tunnel between the nodes. +// ping does a disco ping from one node to another's Tailscale IP, retrying +// for up to 30 seconds, fataling on failure. It is used internally to wake +// up magicsock peer state before a test runs; tests that want to assert +// connectivity should use [Env.Ping] with the appropriate ping type and +// timeout. func (e *Env) ping(from, to *Node) { - ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second) + e.t.Helper() + if err := e.Ping(from, to, tailcfg.PingDisco, 30*time.Second); err != nil { + e.t.Fatal(err) + } +} + +// Ping pings from one node to another's Tailscale IP using the given ping +// type, retrying until it succeeds or timeout expires. It returns the error +// from the last attempt if the timeout expires. Unlike the internal ping +// helper, it does not fatal the test on failure; callers can check the error +// to assert on timing. +// +// [tailcfg.PingTSMP] actually flows packets across the WireGuard tunnel and is +// the right choice for asserting end-to-end connectivity. +// [tailcfg.PingDisco] only exchanges disco messages between magicsock layers +// and is useful for warming up peer state without requiring a working tunnel. +func (e *Env) Ping(from, to *Node, ptype tailcfg.PingType, timeout time.Duration) error { + e.t.Helper() + ctx, cancel := context.WithTimeout(context.Background(), timeout) defer cancel() toSt, err := to.agent.Status(ctx) if err != nil { - e.t.Fatalf("ping: can't get %s status: %v", to.name, err) + return fmt.Errorf("ping: can't get %s status: %w", to.name, err) } if len(toSt.Self.TailscaleIPs) == 0 { - e.t.Fatalf("ping: %s has no Tailscale IPs", to.name) + return fmt.Errorf("ping: %s has no Tailscale IPs", to.name) } targetIP := toSt.Self.TailscaleIPs[0] + var lastErr error for { - pingCtx, pingCancel := context.WithTimeout(ctx, 3*time.Second) - pr, err := from.agent.PingWithOpts(pingCtx, targetIP, tailcfg.PingDisco, local.PingOpts{}) + // Per-attempt timeout: cap at 3s but never exceed the remaining budget. + attemptTimeout := 3 * time.Second + if d := time.Until(deadline(ctx)); d < attemptTimeout { + attemptTimeout = d + } + if attemptTimeout <= 0 { + break + } + pingCtx, pingCancel := context.WithTimeout(ctx, attemptTimeout) + pr, err := from.agent.PingWithOpts(pingCtx, targetIP, ptype, local.PingOpts{}) pingCancel() if err == nil && pr.Err == "" { - e.logVerbosef("ping: %s -> %s OK", from.name, targetIP) - return + e.logVerbosef("ping(%s): %s -> %s OK", ptype, from.name, targetIP) + return nil + } + switch { + case err != nil: + lastErr = err + case pr.Err != "": + lastErr = fmt.Errorf("%s", pr.Err) } if ctx.Err() != nil { - e.t.Fatalf("ping: %s -> %s timed out", from.name, targetIP) + break } - time.Sleep(time.Second) + time.Sleep(500 * time.Millisecond) + } + if lastErr == nil { + lastErr = ctx.Err() + } + return fmt.Errorf("ping(%s): %s -> %s (%s) timed out after %v: %w", ptype, from.name, to.name, targetIP, timeout, lastErr) +} + +// deadline returns ctx's deadline, or a zero Time if it has none. +func deadline(ctx context.Context) time.Time { + d, _ := ctx.Deadline() + return d +} + +// PeerDiscoKey returns n's view of the given peer's disco key. It returns a +// non-nil error if the LocalAPI request fails (e.g. tailscaled briefly +// unavailable during a restart). It returns (zero, false, nil) if n is +// reachable but has no record of the given peer in its current netmap. +// +// PeerDiscoKey is suitable for use inside a [tstest.WaitFor] poll loop: it +// does not fatal the test on transient errors. +// +// The disco key is fetched from the debug-only "peer-disco-keys" LocalAPI +// action ([ipnlocal.LocalBackend.DebugPeerDiscoKeys]) rather than via +// [ipnstate.Status], to keep the production PeerStatus struct free of disco +// keys (and free of non-comparable fields like [key.DiscoPublic] that break +// reflect-based test helpers). +func (e *Env) PeerDiscoKey(n *Node, peer key.NodePublic) (key.DiscoPublic, bool, error) { + ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second) + defer cancel() + got, err := n.agent.DebugResultJSON(ctx, "peer-disco-keys") + if err != nil { + return key.DiscoPublic{}, false, err + } + // DebugResultJSON returns the result as a generic any (the body is + // re-decoded into any), so the map comes back keyed by string text- + // encoded node keys. Re-marshal+unmarshal into a typed map for cleaner + // lookup. (Roundtripping through JSON is fine for a test helper.) + raw, err := json.Marshal(got) + if err != nil { + return key.DiscoPublic{}, false, fmt.Errorf("re-marshal: %w", err) + } + var m map[key.NodePublic]key.DiscoPublic + if err := json.Unmarshal(raw, &m); err != nil { + return key.DiscoPublic{}, false, fmt.Errorf("unmarshal peer-disco-keys: %w", err) + } + d, ok := m[peer] + return d, ok, nil +} + +// RotateDiscoKey asks tailscaled on n to rotate its discovery (magicsock) key +// in place via the LocalAPI debug action. The node key, control connection, +// and other tailscaled state are unaffected. It fatals the test on error. +func (e *Env) RotateDiscoKey(n *Node) { + e.t.Helper() + ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second) + defer cancel() + if err := n.agent.DebugAction(ctx, "rotate-disco-key"); err != nil { + e.t.Fatalf("RotateDiscoKey(%s): %v", n.name, err) + } +} + +// RestartTailscaled signals tailscaled on n to die so that its supervisor +// (gokrazy) restarts it. It then waits for tailscaled to come back to the +// "Running" backend state. It fatals the test on error. +// +// Restarting tailscaled is currently only supported on gokrazy nodes. +func (e *Env) RestartTailscaled(n *Node) { + e.t.Helper() + if !n.os.IsGokrazy { + e.t.Fatalf("RestartTailscaled(%s): only supported on gokrazy nodes (have %q)", n.name, n.os.Name) + } + ctx, cancel := context.WithTimeout(context.Background(), 60*time.Second) + defer cancel() + + req, err := http.NewRequestWithContext(ctx, "GET", "http://unused/restart-tailscaled", nil) + if err != nil { + e.t.Fatalf("RestartTailscaled(%s): %v", n.name, err) + } + res, err := n.agent.HTTPClient.Do(req) + if err != nil { + e.t.Fatalf("RestartTailscaled(%s): %v", n.name, err) + } + body, _ := io.ReadAll(res.Body) + res.Body.Close() + if res.StatusCode != 200 { + e.t.Fatalf("RestartTailscaled(%s): %s: %s", n.name, res.Status, body) + } + e.t.Logf("[%s] %s", n.name, strings.TrimSpace(string(body))) + + // Wait for tailscaled to come back. Status calls will fail while the unix + // socket is gone, then return Starting/NeedsLogin briefly before settling + // on Running. + if err := tstest.WaitFor(45*time.Second, func() error { + st, err := n.agent.Status(ctx) + if err != nil { + return err + } + if st.BackendState != "Running" { + return fmt.Errorf("backend state = %q", st.BackendState) + } + return nil + }); err != nil { + e.t.Fatalf("RestartTailscaled(%s): waiting for Running: %v", n.name, err) + } +} + +// AddRoute adds a kernel static route on the given node, pointing prefix at +// via. It uses TTA's /add-route handler, so it works on any node where TTA +// is running (which is all of them — DontJoinTailnet only skips +// `tailscale up`; the agent runs regardless). Currently Linux-only in TTA. +// +// It fatals the test on error. +func (e *Env) AddRoute(n *Node, prefix, via string) { + e.t.Helper() + ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second) + defer cancel() + reqURL := fmt.Sprintf("http://unused/add-route?prefix=%s&via=%s", prefix, via) + req, err := http.NewRequestWithContext(ctx, "GET", reqURL, nil) + if err != nil { + e.t.Fatalf("AddRoute: %v", err) + } + resp, err := n.agent.HTTPClient.Do(req) + if err != nil { + e.t.Fatalf("AddRoute(%s, %s → %s): %v", n.name, prefix, via, err) + } + defer resp.Body.Close() + body, _ := io.ReadAll(resp.Body) + if resp.StatusCode != 200 { + e.t.Fatalf("AddRoute(%s, %s → %s): %s: %s", n.name, prefix, via, resp.Status, body) } } @@ -575,7 +1312,312 @@ func (e *Env) HTTPGet(from *Node, targetURL string) string { return "" } -// ensureGokrazy finds or builds the gokrazy base image and kernel. +// setNodeScreenshot stores the latest screenshot data URI for a node. +func (e *Env) setNodeScreenshot(name, dataURI string) { + e.nodeStatusMu.Lock() + if ns := e.nodeStatus[name]; ns != nil { + ns.Screenshot = dataURI + } + e.nodeStatusMu.Unlock() +} + +// setNodeScreenshotPort stores the Host.app screenshot server port for a node. +func (e *Env) setNodeScreenshotPort(name string, port int) { + e.nodeStatusMu.Lock() + if ns := e.nodeStatus[name]; ns != nil { + ns.ScreenshotPort = port + } + e.nodeStatusMu.Unlock() +} + +// nodeScreenshotPort returns the Host.app screenshot server port for a node, or 0. +func (e *Env) nodeScreenshotPort(name string) int { + e.nodeStatusMu.Lock() + defer e.nodeStatusMu.Unlock() + if ns := e.nodeStatus[name]; ns != nil { + return ns.ScreenshotPort + } + return 0 +} + +// initVnet creates the vnet server. Called once via sync.Once. +func (e *Env) initVnet() { + e.vnetOnce.Do(func() { + var err error + e.server, err = vnet.New(&e.cfg) + if err != nil { + e.t.Fatalf("vnet.New: %v", err) + } + e.t.Cleanup(func() { e.server.Close() }) + + e.server.SetDHCPCallback(func(mac vnet.MAC, nodeNum int, msgType layers.DHCPMsgType, ip netip.Addr) { + name := e.nodeNameByNum(nodeNum) + nicIdx := e.nicIndexForMAC(name, mac) + ipStr := ip.String() + switch msgType { + case layers.DHCPMsgTypeDiscover: + e.setNodeDHCP(name, nicIdx, "Discover sent") + e.eventBus.Publish(VMEvent{NodeName: name, Type: EventDHCPDiscover, Message: "DHCP Discover sent", NIC: nicIdx}) + case layers.DHCPMsgTypeOffer: + e.setNodeDHCP(name, nicIdx, "Offered "+ipStr) + e.eventBus.Publish(VMEvent{NodeName: name, Type: EventDHCPOffer, Message: "DHCP Offer received", Detail: ipStr, NIC: nicIdx}) + case layers.DHCPMsgTypeRequest: + e.setNodeDHCP(name, nicIdx, "Requesting "+ipStr) + e.eventBus.Publish(VMEvent{NodeName: name, Type: EventDHCPRequest, Message: "DHCP Request sent", Detail: ipStr, NIC: nicIdx}) + case layers.DHCPMsgTypeAck: + e.setNodeDHCP(name, nicIdx, "Got "+ipStr) + e.eventBus.Publish(VMEvent{NodeName: name, Type: EventDHCPAck, Message: "DHCP Ack: got " + ipStr, Detail: ipStr, NIC: nicIdx}) + } + }) + + if e.sameTailnetUser { + e.server.ControlServer().AllNodesSameUser = true + } + if e.allOnline { + e.server.ControlServer().AllOnline = true + } + if e.peerRelayGrants { + e.server.ControlServer().PeerRelayGrants = true + } + }) +} + +// ensureQEMUSocket creates the Unix stream socket for QEMU VMs. Called once. +func (e *Env) ensureQEMUSocket() { + e.qemuSockOnce.Do(func() { + e.initVnet() + e.sockAddr = filepath.Join(e.sockDir, "vnet.sock") + srv, err := net.Listen("unix", e.sockAddr) + if err != nil { + e.t.Fatalf("listen unix: %v", err) + } + e.t.Cleanup(func() { srv.Close() }) + go func() { + for { + c, err := srv.Accept() + if err != nil { + return + } + go e.server.ServeUnixConn(c.(*net.UnixConn), vnet.ProtocolQEMU) + } + }() + }) +} + +// ensureDgramSocket creates the Unix dgram socket for macOS VMs. Called once. +func (e *Env) ensureDgramSocket() { + e.dgramSockOnce.Do(func() { + e.initVnet() + e.dgramSockAddr = filepath.Join(e.sockDir, "dgram.sock") + dgramAddr, err := net.ResolveUnixAddr("unixgram", e.dgramSockAddr) + if err != nil { + e.t.Fatalf("resolve dgram addr: %v", err) + } + uc, err := net.ListenUnixgram("unixgram", dgramAddr) + if err != nil { + e.t.Fatalf("listen unixgram: %v", err) + } + e.t.Cleanup(func() { uc.Close() }) + go e.server.ServeUnixConn(uc, vnet.ProtocolUnixDGRAM) + }) +} + +// ensureCompiled compiles binaries for the given platform and registers them +// with the vnet file server. Safe for concurrent use; only compiles once per platform. +func (e *Env) ensureCompiled(ctx context.Context, goos, goarch string) { + key := goos + "_" + goarch + + e.compileMu.Lock() + once, ok := e.compileOnce[key] + if !ok { + once = new(sync.Once) + mak.Set(&e.compileOnce, key, once) + } + e.compileMu.Unlock() + + once.Do(func() { + step := e.Step(fmt.Sprintf("Compile %s_%s binaries", goos, goarch)) + step.Begin() + if err := e.compileBinariesForOS(ctx, goos, goarch); err != nil { + step.End(err) + e.t.Fatalf("compileBinariesForOS(%s, %s): %v", goos, goarch, err) + } + step.End(nil) + e.registerBinaries(goos, goarch) + }) +} + +// ensureImage prepares the cloud image for os and returns any error from the +// preparation. Safe for concurrent use; only prepares once per OS name. +func (e *Env) ensureImage(ctx context.Context, os OSImage) error { + e.compileMu.Lock() + once, ok := e.imageOnce[os.Name] + if !ok { + once = new(sync.Once) + mak.Set(&e.imageOnce, os.Name, once) + } + e.compileMu.Unlock() + + var err error + once.Do(func() { + step := e.Step(fmt.Sprintf("Prepare %s image", os.Name)) + step.Begin() + err = ensureImage(ctx, os) + step.End(err) + }) + return err +} + +// registerBinaries registers compiled binaries with the vnet file server. +// Safe for concurrent use. +func (e *Env) registerBinaries(goos, goarch string) { + e.initVnet() + dir := goos + "_" + goarch + for _, name := range []string{"tta", "tailscale", "tailscaled"} { + data, err := os.ReadFile(filepath.Join(e.binDir, dir, name)) + if err != nil { + e.t.Fatalf("reading compiled %s/%s: %v", dir, name, err) + } + e.server.RegisterFile(dir+"/"+name, data) + } +} + +// waitForAgentConn waits for a TTA agent to connect by issuing a simple +// HTTP GET to the root endpoint, without requiring tailscaled. +func (e *Env) waitForAgentConn(ctx context.Context, n *Node) error { + for { + reqCtx, cancel := context.WithTimeout(ctx, 2*time.Second) + req, err := http.NewRequestWithContext(reqCtx, "GET", "http://unused/", nil) + if err != nil { + cancel() + return err + } + res, err := n.agent.HTTPClient.Do(req) + cancel() + if err == nil { + res.Body.Close() + return nil + } + if ctx.Err() != nil { + return ctx.Err() + } + time.Sleep(500 * time.Millisecond) + } +} + +// Agent returns the node's TTA agent client, or nil if NoAgent is set. +func (n *Node) Agent() *vnet.NodeAgentClient { + return n.agent +} + +// LANPing pings a LAN IP from the given node using TTA's /ping endpoint. +// It retries for up to 2 minutes, which is enough for a macOS VM to boot +// and acquire a DHCP lease. +func (e *Env) LANPing(from *Node, targetIP netip.Addr) { + if from.agent == nil { + e.t.Fatalf("LANPing: node %s has no agent (NoAgent set?)", from.name) + } + e.t.Logf("LANPing: %s -> %s", from.name, targetIP) + deadline := time.Now().Add(2 * time.Minute) + for attempt := 0; time.Now().Before(deadline); attempt++ { + ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second) + reqURL := fmt.Sprintf("http://unused/ping?host=%s", targetIP) + req, err := http.NewRequestWithContext(ctx, "GET", reqURL, nil) + if err != nil { + cancel() + e.t.Fatalf("LANPing: %v", err) + } + res, err := from.agent.HTTPClient.Do(req) + cancel() + if err != nil { + if attempt%10 == 0 { + e.t.Logf("LANPing attempt %d: %v", attempt+1, err) + } + time.Sleep(2 * time.Second) + continue + } + body, _ := io.ReadAll(res.Body) + res.Body.Close() + if res.StatusCode == 200 { + e.t.Logf("LANPing: %s -> %s succeeded on attempt %d", from.name, targetIP, attempt+1) + return + } + if attempt%10 == 0 { + e.t.Logf("LANPing attempt %d: status %d, body: %s", attempt+1, res.StatusCode, string(body)) + } + time.Sleep(2 * time.Second) + } + e.t.Fatalf("LANPing: %s -> %s timed out after 2 minutes", from.name, targetIP) +} + +// SendTaildropFile sends a file via Taildrop from one node to another. +// The to node must be on the tailnet. It fatals on error. +func (e *Env) SendTaildropFile(from, to *Node, name string, content []byte) { + e.t.Helper() + ctx, cancel := context.WithTimeout(context.Background(), 60*time.Second) + defer cancel() + + st, err := to.agent.Status(ctx) + if err != nil { + e.t.Fatalf("SendTaildropFile: status for %s: %v", to.name, err) + } + if len(st.Self.TailscaleIPs) == 0 { + e.t.Fatalf("SendTaildropFile: %s has no Tailscale IPs", to.name) + } + target := st.Self.TailscaleIPs[0].String() + + reqURL := fmt.Sprintf("http://unused/taildrop-send?to=%s&name=%s", target, name) + req, err := http.NewRequestWithContext(ctx, "POST", reqURL, bytes.NewReader(content)) + if err != nil { + e.t.Fatalf("SendTaildropFile: %v", err) + } + res, err := from.agent.HTTPClient.Do(req) + if err != nil { + e.t.Fatalf("SendTaildropFile(%s -> %s): %v", from.name, to.name, err) + } + defer res.Body.Close() + body, _ := io.ReadAll(res.Body) + if res.StatusCode != 200 { + e.t.Fatalf("SendTaildropFile(%s -> %s): %s: %s", from.name, to.name, res.Status, body) + } + if msg := strings.TrimSpace(string(body)); msg != "" { + e.t.Logf("[%s] %s", from.name, msg) + } + e.t.Logf("[%s] sent Taildrop %q (%d bytes) to %s", from.name, name, len(content), to.name) +} + +// RecvTaildropFile waits for an incoming Taildrop file on the node and +// returns the filename and contents. The provided context bounds the wait; +// in addition, RecvTaildropFile imposes its own 90s upper bound. It fatals +// on error or timeout. +func (e *Env) RecvTaildropFile(ctx context.Context, n *Node) (name string, content []byte) { + e.t.Helper() + ctx, cancel := context.WithTimeout(ctx, 90*time.Second) + defer cancel() + + req, err := http.NewRequestWithContext(ctx, "GET", "http://unused/taildrop-recv", nil) + if err != nil { + e.t.Fatalf("RecvTaildropFile: %v", err) + } + res, err := n.agent.HTTPClient.Do(req) + if err != nil { + e.t.Fatalf("RecvTaildropFile(%s): %v", n.name, err) + } + defer res.Body.Close() + body, _ := io.ReadAll(res.Body) + if res.StatusCode != 200 { + e.t.Fatalf("RecvTaildropFile(%s): %s: %s", n.name, res.Status, body) + } + name = res.Header.Get("Taildrop-Filename") + e.t.Logf("[%s] received Taildrop %q (%d bytes)", n.name, name, len(body)) + return name, body +} + +var buildGokrazy sync.Once + +// ensureGokrazy builds the gokrazy base image (once per test process) and +// locates the kernel. The build is fast (~4s) so we always rebuild to ensure +// the baked-in binaries (tta, tailscale, tailscaled) match the current source. func (e *Env) ensureGokrazy(ctx context.Context) error { if e.gokrazyBase != "" { return nil // already found @@ -586,21 +1628,23 @@ func (e *Env) ensureGokrazy(ctx context.Context) error { return err } - e.gokrazyBase = filepath.Join(modRoot, "gokrazy/natlabapp.qcow2") - if _, err := os.Stat(e.gokrazyBase); err != nil { - if !os.IsNotExist(err) { - return err - } + var buildErr error + buildGokrazy.Do(func() { e.t.Logf("building gokrazy natlab image...") cmd := exec.CommandContext(ctx, "make", "natlab") cmd.Dir = filepath.Join(modRoot, "gokrazy") cmd.Stderr = os.Stderr cmd.Stdout = os.Stdout if err := cmd.Run(); err != nil { - return fmt.Errorf("make natlab: %w", err) + buildErr = fmt.Errorf("make natlab: %w", err) } + }) + if buildErr != nil { + return buildErr } + e.gokrazyBase = filepath.Join(modRoot, "gokrazy/natlabapp.qcow2") + kernel, err := findKernelPath(filepath.Join(modRoot, "go.mod")) if err != nil { return fmt.Errorf("finding kernel: %w", err) @@ -609,8 +1653,13 @@ func (e *Env) ensureGokrazy(ctx context.Context) error { return nil } -// compileBinariesForOS cross-compiles tta, tailscale, and tailscaled for the -// given GOOS/GOARCH and places them in e.binDir/_/. +// compileBinariesForOS prepares the tta, tailscale, and tailscaled binaries +// for the given GOOS/GOARCH and places them in e.binDir/_/. +// +// tta is always built from the local source tree (the test agent must match +// the test framework). When --test-version is set, tailscale and tailscaled +// are taken from the downloaded release tarball instead of being compiled +// from source. func (e *Env) compileBinariesForOS(ctx context.Context, goos, goarch string) error { modRoot, err := findModRoot() if err != nil { @@ -623,14 +1672,20 @@ func (e *Env) compileBinariesForOS(ctx context.Context, goos, goarch string) err return err } - binaries := []struct{ name, pkg string }{ - {"tta", "./cmd/tta"}, - {"tailscale", "./cmd/tailscale"}, - {"tailscaled", "./cmd/tailscaled"}, + // Use downloaded release binaries only on Linux: pkgs.tailscale.com only + // publishes Linux tarballs, so other GOOS values still build from source. + useDownloaded := e.testVersion != "" && goos == "linux" + + type binary struct{ name, pkg string } + buildBins := []binary{{"tta", "./cmd/tta"}} + if !useDownloaded { + buildBins = append(buildBins, + binary{"tailscale", "./cmd/tailscale"}, + binary{"tailscaled", "./cmd/tailscaled"}) } var eg errgroup.Group - for _, bin := range binaries { + for _, bin := range buildBins { eg.Go(func() error { outPath := filepath.Join(outDir, bin.name) e.t.Logf("compiling %s/%s...", dir, bin.name) @@ -644,9 +1699,36 @@ func (e *Env) compileBinariesForOS(ctx context.Context, goos, goarch string) err return nil }) } + + if useDownloaded { + eg.Go(func() error { + srcDir, err := ensureVersionBinaries(ctx, e.testVersion, goarch, e.t.Logf) + if err != nil { + return err + } + for _, name := range []string{"tailscale", "tailscaled"} { + if err := copyFile(filepath.Join(srcDir, name), filepath.Join(outDir, name), 0755); err != nil { + return fmt.Errorf("staging %s/%s: %w", dir, name, err) + } + } + e.t.Logf("staged version %s tailscale & tailscaled for %s", e.testVersion, dir) + return nil + }) + } + return eg.Wait() } +// copyFile copies src to dst with the given permission bits. +func copyFile(src, dst string, perm os.FileMode) error { + in, err := os.Open(src) + if err != nil { + return err + } + defer in.Close() + return writeAtomic(dst, in, perm) +} + // findModRoot returns the root of the Go module (where go.mod is). func findModRoot() (string, error) { out, err := exec.Command("go", "env", "GOMOD").CombinedOutput() @@ -686,3 +1768,73 @@ func findKernelPath(goMod string) (string, error) { } return "", fmt.Errorf("gokrazy-kernel not found in %s", goMod) } + +// PingRoute describes what connection type was used to transfer a Disco ping. +type PingRoute string + +const ( + PingRouteDirect PingRoute = "direct" + PingRouteDERP PingRoute = "derp" + PingRouteLocal PingRoute = "local" + PingRouteNil PingRoute = "nil" +) + +// classifyPing finds what kind of route has been used on a ping path. +// It is only really relevant for DiscoPings. +func classifyPing(pr *ipnstate.PingResult) PingRoute { + if pr == nil { + return PingRouteNil + } + + if pr.Endpoint == "" { + return PingRouteDERP + } + + ap, err := netip.ParseAddrPort(pr.Endpoint) + if err == nil && ap.Addr().IsPrivate() { + return PingRouteLocal + } + return PingRouteDirect +} + +// PingExpect retries disco pings until the result matches wantRoute or the +// timeout is reached. It is using DiscoPings as this is the only ping type +// that can classify the connection type. +func (e *Env) PingExpect(from, to *Node, wantRoute PingRoute, timeout time.Duration) error { + e.t.Helper() + ctx, cancel := context.WithTimeout(e.t.Context(), timeout) + defer cancel() + var lastRoute PingRoute + toSt, err := to.agent.Status(ctx) + if err != nil { + return fmt.Errorf("ping: can't get %s status: %w", to.name, err) + } + if len(toSt.Self.TailscaleIPs) == 0 { + return fmt.Errorf("ping: %s has no Tailscale IPs", to.name) + } + targetIP := toSt.Self.TailscaleIPs[0] + for ctx.Err() == nil { + pingCtx, pingCancel := context.WithTimeout(ctx, 3*time.Second) + pr, err := from.agent.PingWithOpts(pingCtx, targetIP, tailcfg.PingDisco, local.PingOpts{}) + pingCancel() + if err == nil && pr.Err == "" { + if got := classifyPing(pr); got == wantRoute { + e.t.Logf("Saw ping type %q", got) + return nil + } else { + e.t.Logf("Saw ping type %q", got) + lastRoute = got + } + } + select { + case <-time.After(500 * time.Millisecond): + case <-ctx.Done(): + } + } + return fmt.Errorf("ping route = %q, want %q (after %v)", lastRoute, wantRoute, timeout) +} + +// NumNodes returns the current number of nodes configured in the env. +func (env *Env) NumNodes() int { + return len(env.nodes) +} diff --git a/tstest/natlab/vmtest/vmtest_test.go b/tstest/natlab/vmtest/vmtest_test.go index 91c8359f1..ad8c6f296 100644 --- a/tstest/natlab/vmtest/vmtest_test.go +++ b/tstest/natlab/vmtest/vmtest_test.go @@ -4,14 +4,78 @@ package vmtest_test import ( + "bytes" + "context" "fmt" + "net/netip" + "runtime" "strings" "testing" + "time" + "tailscale.com/client/local" + "tailscale.com/ipn" + "tailscale.com/net/udprelay/status" + "tailscale.com/tailcfg" + "tailscale.com/tstest" + "tailscale.com/tstest/integration/testcontrol" "tailscale.com/tstest/natlab/vmtest" "tailscale.com/tstest/natlab/vnet" + "tailscale.com/types/key" + "tailscale.com/types/netmap" + "tailscale.com/util/set" ) +// skipIfNotMacOSArm64 skips the test when the host isn't a macOS arm64 host. +// macOS VM tests require Apple Virtualization.framework via tailmac. +// AddNode also enforces this when a macOS node is added, but having an +// explicit skip at the top of macOS-only tests makes the requirement +// obvious to readers. +func skipIfNotMacOSArm64(t *testing.T) { + t.Helper() + if runtime.GOOS != "darwin" || runtime.GOARCH != "arm64" { + t.Skipf("macOS VM tests require a macOS arm64 host (got %s/%s)", runtime.GOOS, runtime.GOARCH) + } +} + +func TestMacOSAndLinuxCanPing(t *testing.T) { + skipIfNotMacOSArm64(t) + env := vmtest.New(t) + + lan := env.AddNetwork("192.168.1.1/24") + + linux := env.AddNode("linux", lan, + vmtest.OS(vmtest.Gokrazy), + vmtest.DontJoinTailnet()) + macos := env.AddNode("macos", lan, + vmtest.OS(vmtest.MacOS), + vmtest.DontJoinTailnet()) + + env.Start() + + env.LANPing(linux, macos.LanIP(lan)) +} + +func TestTwoMacOSVMsCanPing(t *testing.T) { + skipIfNotMacOSArm64(t) + env := vmtest.New(t) + + lan := env.AddNetwork("192.168.1.1/24") + + mac1 := env.AddNode("mac1", lan, + vmtest.OS(vmtest.MacOS), + vmtest.DontJoinTailnet()) + mac2 := env.AddNode("mac2", lan, + vmtest.OS(vmtest.MacOS), + vmtest.DontJoinTailnet()) + + env.Start() + + // Both macOS VMs have TTA. Ping from mac1 to mac2 and vice versa. + env.LANPing(mac1, mac2.LanIP(lan)) + env.LANPing(mac2, mac1.LanIP(lan)) +} + func TestSubnetRouter(t *testing.T) { testSubnetRouterForOS(t, vmtest.Ubuntu2404) } @@ -37,11 +101,1090 @@ func testSubnetRouterForOS(t testing.TB, srOS vmtest.OSImage) { vmtest.DontJoinTailnet(), vmtest.WebServer(8080)) - env.Start() - env.ApproveRoutes(sr, "10.0.0.0/24") + // Declare test-specific steps for the web UI. + approveStep := env.AddStep("Approve subnet routes") + httpStep := env.AddStep("HTTP GET through subnet router") + env.Start() + + approveStep.Begin() + env.ApproveRoutes(sr, "10.0.0.0/24") + approveStep.End(nil) + + httpStep.Begin() body := env.HTTPGet(client, fmt.Sprintf("http://%s:8080/", backend.LanIP(internalNet))) if !strings.Contains(body, "Hello world I am backend") { - t.Fatalf("got %q", body) + httpStep.Fatalf("got %q", body) + } + httpStep.End(nil) +} + +func TestSiteToSite(t *testing.T) { + testSiteToSite(t, vmtest.Ubuntu2404) +} + +// testSiteToSite runs a site-to-site subnet routing test with +// --snat-subnet-routes=false, verifying that original source IPs are preserved +// across Tailscale subnet routes. +// +// Topology: +// +// Site A: backend-a (10.1.0.0/24) ← → sr-a (WAN + LAN-A) +// Site B: backend-b (10.2.0.0/24) ← → sr-b (WAN + LAN-B) +// +// Both subnet routers are on Tailscale with --snat-subnet-routes=false. +// The test sends HTTP from backend-a to backend-b through the subnet routers +// and verifies that backend-b sees backend-a's LAN IP (not the subnet router's). +func testSiteToSite(t *testing.T, srOS vmtest.OSImage) { + env := vmtest.New(t) + + // WAN networks for each site (each behind NAT). + wanA := env.AddNetwork("2.1.1.1", "192.168.1.1/24", vnet.EasyNAT) + wanB := env.AddNetwork("3.1.1.1", "192.168.2.1/24", vnet.EasyNAT) + + // Internal LAN for each site. + lanA := env.AddNetwork("10.1.0.1/24") + lanB := env.AddNetwork("10.2.0.1/24") + + // Subnet routers: each on its WAN + LAN, advertising the local LAN, + // with SNAT disabled to preserve source IPs. + srA := env.AddNode("sr-a", wanA, lanA, + vmtest.OS(srOS), + vmtest.AdvertiseRoutes("10.1.0.0/24"), + vmtest.SNATSubnetRoutes(false)) + srB := env.AddNode("sr-b", wanB, lanB, + vmtest.OS(srOS), + vmtest.AdvertiseRoutes("10.2.0.0/24"), + vmtest.SNATSubnetRoutes(false)) + + // Backend servers on each site's LAN (not on Tailscale). + // Use Ubuntu so we can SSH in to add static routes. + backendA := env.AddNode("backend-a", lanA, + vmtest.OS(vmtest.Ubuntu2404), + vmtest.DontJoinTailnet(), + vmtest.WebServer(8080)) + backendB := env.AddNode("backend-b", lanB, + vmtest.OS(vmtest.Ubuntu2404), + vmtest.DontJoinTailnet(), + vmtest.WebServer(8080)) + + // Declare test-specific steps for the web UI. + approveStep := env.AddStep("Approve subnet routes (sr-a, sr-b)") + staticRouteStep := env.AddStep("Add static routes on backends") + httpStep := env.AddStep("HTTP GET through site-to-site") + + env.Start() + + approveStep.Begin() + env.ApproveRoutes(srA, "10.1.0.0/24") + env.ApproveRoutes(srB, "10.2.0.0/24") + approveStep.End(nil) + + // Add static routes on the backends so that traffic to the remote site's + // subnet goes through the local subnet router. This mirrors how a real + // site-to-site deployment is configured. + srALanIP := srA.LanIP(lanA).String() + srBLanIP := srB.LanIP(lanB).String() + t.Logf("sr-a LAN IP: %s, sr-b LAN IP: %s", srALanIP, srBLanIP) + t.Logf("backend-a LAN IP: %s, backend-b LAN IP: %s", backendA.LanIP(lanA), backendB.LanIP(lanB)) + + staticRouteStep.Begin() + env.AddRoute(backendA, "10.2.0.0/24", srALanIP) + env.AddRoute(backendB, "10.1.0.0/24", srBLanIP) + staticRouteStep.End(nil) + + // Make an HTTP request from backend-a to backend-b through the subnet routers. + // TTA's /http-get falls back to direct dial on non-Tailscale nodes. + httpStep.Begin() + backendBIP := backendB.LanIP(lanB) + body := env.HTTPGet(backendA, fmt.Sprintf("http://%s:8080/", backendBIP)) + t.Logf("response: %s", body) + + if !strings.Contains(body, "Hello world I am backend-b") { + httpStep.Fatalf("expected response from backend-b, got %q", body) + } + + // Verify the source IP was preserved. With --snat-subnet-routes=false, + // backend-b should see backend-a's LAN IP as the source, not sr-b's LAN IP. + backendAIP := backendA.LanIP(lanA).String() + if !strings.Contains(body, "from "+backendAIP) { + httpStep.Fatalf("source IP not preserved: expected %q in response, got %q", backendAIP, body) + } + httpStep.End(nil) +} + +// TestInterNetworkTCP verifies that vnet routes raw TCP between simulated +// networks: a non-Tailscale VM on one NAT'd LAN can reach a webserver on a +// different network using a 1:1 NAT, and the webserver sees the client's +// network's WAN IP as the source (post-NAT). +func TestInterNetworkTCP(t *testing.T) { + env := vmtest.New(t) + + const ( + clientWAN = "1.0.0.1" + webWAN = "5.0.0.1" + ) + + clientNet := env.AddNetwork(clientWAN, "192.168.1.1/24", vnet.EasyNAT) + webNet := env.AddNetwork(webWAN, "192.168.5.1/24", vnet.One2OneNAT) + + client := env.AddNode("client", clientNet, + vmtest.OS(vmtest.Gokrazy), + vmtest.DontJoinTailnet()) + env.AddNode("webserver", webNet, + vmtest.OS(vmtest.Gokrazy), + vmtest.DontJoinTailnet(), + vmtest.WebServer(8080)) + + // Declare test-specific steps for the web UI. + httpStep := env.AddStep("HTTP GET across networks via NAT") + + env.Start() + + httpStep.Begin() + body := env.HTTPGet(client, fmt.Sprintf("http://%s:8080/", webWAN)) + t.Logf("response: %s", body) + if !strings.Contains(body, "Hello world I am webserver") { + httpStep.Fatalf("unexpected response: %q", body) + } + if !strings.Contains(body, "from "+clientWAN) { + httpStep.Fatalf("expected source %q in response, got %q", clientWAN, body) + } + httpStep.End(nil) +} + +// TestSubnetRouterPublicIP verifies that toggling --accept-routes on the +// client switches between dialing a webserver directly and routing through a +// subnet router that advertises the webserver's public IP range. +// +// Topology: client, subnet router, and webserver each live behind their own +// NAT'd network with distinct WAN IPs; the subnet router advertises the +// webserver's network as a route. The webserver echoes the source IP it +// sees: +// - accept-routes=off: client dials webserver directly; source is client's WAN. +// - accept-routes=on: client tunnels to the subnet router, which forwards +// and SNATs; source is subnet router's WAN. +func TestSubnetRouterPublicIP(t *testing.T) { + env := vmtest.New(t) + + const ( + clientWAN = "1.0.0.1" + routerWAN = "2.0.0.1" + webWAN = "5.0.0.1" + webRoute = "5.0.0.0/24" + ) + + clientNet := env.AddNetwork(clientWAN, "192.168.1.1/24", vnet.EasyNAT) + routerNet := env.AddNetwork(routerWAN, "192.168.2.1/24", vnet.EasyNAT) + webNet := env.AddNetwork(webWAN, "192.168.5.1/24", vnet.One2OneNAT) + + client := env.AddNode("client", clientNet, + vmtest.OS(vmtest.Gokrazy)) + sr := env.AddNode("subnet-router", routerNet, + vmtest.OS(vmtest.Gokrazy), + vmtest.AdvertiseRoutes(webRoute)) + env.AddNode("webserver", webNet, + vmtest.OS(vmtest.Gokrazy), + vmtest.DontJoinTailnet(), + vmtest.WebServer(8080)) + + // Declare test-specific steps for the web UI. + approveStep := env.AddStep("Approve subnet route (public IP)") + checkOn1Step := env.AddStep("HTTP GET (accept-routes=on)") + checkOffStep := env.AddStep("HTTP GET (accept-routes=off)") + checkOn2Step := env.AddStep("HTTP GET (accept-routes=on, again)") + + env.Start() + // ApproveRoutes also turns on RouteAll on the client. + approveStep.Begin() + env.ApproveRoutes(sr, webRoute) + approveStep.End(nil) + + webURL := fmt.Sprintf("http://%s:8080/", webWAN) + check := func(step *vmtest.Step, label, wantSrc string) { + t.Helper() + step.Begin() + body := env.HTTPGet(client, webURL) + t.Logf("[%s] response: %s", label, body) + if !strings.Contains(body, "Hello world I am webserver") { + step.Fatalf("[%s] unexpected webserver response: %q", label, body) + } + if !strings.Contains(body, "from "+wantSrc) { + step.Fatalf("[%s] expected source %q in response, got %q", label, wantSrc, body) + } + step.End(nil) + } + + // accept-routes=on (set by ApproveRoutes): traffic flows via the subnet router. + check(checkOn1Step, "accept-routes=on", routerWAN) + + // accept-routes=off: client dials the webserver directly. + env.SetAcceptRoutes(client, false) + check(checkOffStep, "accept-routes=off", clientWAN) + + // Toggle back on to confirm the transition works in both directions. + env.SetAcceptRoutes(client, true) + check(checkOn2Step, "accept-routes=on (again)", routerWAN) +} + +// TestSubnetRouterAndExitNode checks how the subnet router and exit node +// preferences interact. Topology: client, subnet router, exit node, and +// webserver, each on its own NAT'd network with distinct WAN IPs. The subnet +// router advertises the webserver's network (5.0.0.0/24); the exit node +// advertises 0.0.0.0/0 + ::/0. The webserver echoes the source IP it sees: +// +// exit=off, subnet=off → client's WAN (direct dial) +// exit=off, subnet=on → subnet router's WAN +// exit=on, subnet=off → exit node's WAN +// exit=on, subnet=on → subnet router's WAN (more-specific /24 beats /0) +func TestSubnetRouterAndExitNode(t *testing.T) { + env := vmtest.New(t) + + const ( + clientWAN = "1.0.0.1" + routerWAN = "2.0.0.1" + exitWAN = "3.0.0.1" + webWAN = "5.0.0.1" + webRoute = "5.0.0.0/24" + ) + + clientNet := env.AddNetwork(clientWAN, "192.168.1.1/24", vnet.EasyNAT) + routerNet := env.AddNetwork(routerWAN, "192.168.2.1/24", vnet.EasyNAT) + exitNet := env.AddNetwork(exitWAN, "192.168.3.1/24", vnet.EasyNAT) + webNet := env.AddNetwork(webWAN, "192.168.5.1/24", vnet.One2OneNAT) + + client := env.AddNode("client", clientNet, + vmtest.OS(vmtest.Gokrazy)) + sr := env.AddNode("subnet-router", routerNet, + vmtest.OS(vmtest.Gokrazy), + vmtest.AdvertiseRoutes(webRoute)) + exit := env.AddNode("exit", exitNet, + vmtest.OS(vmtest.Gokrazy), + vmtest.AdvertiseRoutes("0.0.0.0/0,::/0")) + env.AddNode("webserver", webNet, + vmtest.OS(vmtest.Gokrazy), + vmtest.DontJoinTailnet(), + vmtest.WebServer(8080)) + + // Declare test-specific steps for the web UI. + approveStep := env.AddStep("Approve subnet & exit routes") + + webURL := fmt.Sprintf("http://%s:8080/", webWAN) + tests := []struct { + name string // subtest name; describes (exit, subnet) toggles + exit *vmtest.Node + subnet bool + wantSrc string + step *vmtest.Step + }{ + {"exit-off,subnet-off", nil, false, clientWAN, nil}, + {"exit-off,subnet-on", nil, true, routerWAN, nil}, + {"exit-on,subnet-off", exit, false, exitWAN, nil}, + // More-specific 5.0.0.0/24 from sr beats 0.0.0.0/0 from exit. + {"exit-on,subnet-on", exit, true, routerWAN, nil}, + } + for i := range tests { + tests[i].step = env.AddStep("HTTP GET: " + tests[i].name) + } + + env.Start() + approveStep.Begin() + env.ApproveRoutes(sr, webRoute) + env.ApproveRoutes(exit, "0.0.0.0/0", "::/0") + // Don't let the exit node itself forward via the subnet router: when the + // client is using the exit node only, we want the exit node to egress to + // the simulated internet directly so the webserver sees the exit's WAN. + env.SetAcceptRoutes(exit, false) + approveStep.End(nil) + + for _, tc := range tests { + t.Run(tc.name, func(t *testing.T) { + tc.step.Begin() + env.SetExitNode(client, tc.exit) + env.SetAcceptRoutes(client, tc.subnet) + body := env.HTTPGet(client, webURL) + t.Logf("response: %s", body) + if !strings.Contains(body, "Hello world I am webserver") { + tc.step.End(fmt.Errorf("unexpected webserver response: %q", body)) + t.Fatalf("unexpected webserver response: %q", body) + } + if !strings.Contains(body, "from "+tc.wantSrc) { + tc.step.End(fmt.Errorf("expected source %q in response, got %q", tc.wantSrc, body)) + t.Fatalf("expected source %q in response, got %q", tc.wantSrc, body) + } + tc.step.End(nil) + }) } } + +// TestTaildrop verifies that one Ubuntu node can send a file to another +// Ubuntu node via Taildrop, and the receiver gets the same content. +// +// Topology: two Ubuntu nodes, each behind its own EasyNAT, both joined to the +// tailnet. The sender runs `tailscale file cp` to push to the receiver's +// Tailscale IP; the receiver then runs `tailscale file get --wait` to fetch +// it. +func TestTaildrop(t *testing.T) { + env := vmtest.New(t, vmtest.SameTailnetUser()) + + senderNet := env.AddNetwork("1.0.0.1", "192.168.1.1/24", vnet.EasyNAT) + receiverNet := env.AddNetwork("2.0.0.1", "192.168.2.1/24", vnet.EasyNAT) + + sender := env.AddNode("sender", senderNet, + vmtest.OS(vmtest.Ubuntu2404)) + receiver := env.AddNode("receiver", receiverNet, + vmtest.OS(vmtest.Ubuntu2404)) + + // Declare test-specific steps for the web UI. + sendStep := env.AddStep("Taildrop send (sender -> receiver)") + recvStep := env.AddStep("Taildrop receive (on receiver)") + verifyStep := env.AddStep("Verify received name and contents") + + env.Start() + + const filename = "hello.txt" + want := []byte("hello world this is a Taildrop test\n") + + sendStep.Begin() + env.SendTaildropFile(sender, receiver, filename, want) + sendStep.End(nil) + + recvStep.Begin() + gotName, gotContent := env.RecvTaildropFile(t.Context(), receiver) + recvStep.End(nil) + + verifyStep.Begin() + if gotName != filename { + verifyStep.Fatalf("received name = %q; want %q", gotName, filename) + return + } + if !bytes.Equal(gotContent, want) { + verifyStep.Fatalf("received content = %q; want %q", gotContent, want) + return + } + verifyStep.End(nil) +} + +// TestExitNode verifies that switching the client's exit node setting between +// off, exit1, and exit2 correctly routes the client's internet traffic. +// +// Topology: each of the client and the two exit nodes lives behind its own NAT +// with a unique WAN IP, and a webserver lives on yet another network using a +// 1:1 NAT so it's reachable from the simulated internet at a stable address. +// The webserver echoes the source IP of incoming requests, so we can tell +// which network's NAT the client's traffic egressed through: +// - off: source is the client's network WAN IP. +// - exit1: source is exit1's network WAN IP. +// - exit2: source is exit2's network WAN IP. +func TestExitNode(t *testing.T) { + env := vmtest.New(t) + + const ( + clientWAN = "1.0.0.1" + exit1WAN = "2.0.0.1" + exit2WAN = "3.0.0.1" + webWAN = "5.0.0.1" + ) + + clientNet := env.AddNetwork(clientWAN, "192.168.1.1/24", vnet.EasyNAT) + exit1Net := env.AddNetwork(exit1WAN, "192.168.2.1/24", vnet.EasyNAT) + exit2Net := env.AddNetwork(exit2WAN, "192.168.3.1/24", vnet.EasyNAT) + webNet := env.AddNetwork(webWAN, "192.168.5.1/24", vnet.One2OneNAT) + + client := env.AddNode("client", clientNet, + vmtest.OS(vmtest.Gokrazy)) + exit1 := env.AddNode("exit1", exit1Net, + vmtest.OS(vmtest.Gokrazy), + vmtest.AdvertiseRoutes("0.0.0.0/0,::/0")) + exit2 := env.AddNode("exit2", exit2Net, + vmtest.OS(vmtest.Gokrazy), + vmtest.AdvertiseRoutes("0.0.0.0/0,::/0")) + env.AddNode("webserver", webNet, + vmtest.OS(vmtest.Gokrazy), + vmtest.DontJoinTailnet(), + vmtest.WebServer(8080)) + + // Declare test-specific steps for the web UI. + approveStep := env.AddStep("Approve exit-node routes (exit1, exit2)") + + webURL := fmt.Sprintf("http://%s:8080/", webWAN) + tests := []struct { + name string // subtest name + exit *vmtest.Node + wantSrc string + step *vmtest.Step + }{ + {"off", nil, clientWAN, nil}, + {"exit1", exit1, exit1WAN, nil}, + {"exit2", exit2, exit2WAN, nil}, + } + for i := range tests { + tests[i].step = env.AddStep("HTTP GET: exit=" + tests[i].name) + } + + env.Start() + approveStep.Begin() + env.ApproveRoutes(exit1, "0.0.0.0/0", "::/0") + env.ApproveRoutes(exit2, "0.0.0.0/0", "::/0") + approveStep.End(nil) + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + tt.step.Begin() + env.SetExitNode(client, tt.exit) + body := env.HTTPGet(client, webURL) + t.Logf("response: %s", body) + if !strings.Contains(body, "Hello world I am webserver") { + tt.step.Fatalf("unexpected webserver response: %q", body) + } + if !strings.Contains(body, "from "+tt.wantSrc) { + tt.step.Fatalf("expected source %q in response, got %q", tt.wantSrc, body) + } + tt.step.End(nil) + }) + } +} + +// TestDiscoKeyChange verifies that when one node's disco key rotates without +// its WireGuard node key changing, peers detect the change, tear down stale +// WireGuard session state for that peer, and re-establish the tunnel in both +// directions. This exercises the disco-key-change handling that the +// bradfitz/rm_lazy_wg branch relies on for traffic to and from a peer whose +// magicsock state has been reset. +// +// Topology: two gokrazy nodes A and B, each on its own One2OneNAT network so +// every connection between them is a direct UDP path with no port-mapping or +// filtering. With NAT effects out of the way, what we measure here is the +// speed of disco-key-change reconciliation in wgengine/magicsock alone. The +// test control server is also configured with [testcontrol.Server.AllOnline] +// (via [vmtest.AllOnline]) so the controlclient/wgengine fast paths that +// branch on Online actually fire — without that flag the test exercises +// only the offline-peer code paths, which mask separate latent issues and +// are several seconds slower. +// +// The test runs four B-side rotations followed by a TSMP ping in the +// requested direction: +// +// rotate (LocalAPI rotate-disco-key) → ping B → A +// rotate (LocalAPI rotate-disco-key) → ping A → B +// restart (SIGKILL tailscaled) → ping B → A +// restart (SIGKILL tailscaled) → ping A → B +// +// Plus an initial A→B TSMP ping with a generous 30s budget to bring up the +// WireGuard tunnel before the rotations begin (so the post-rotation pings +// measure stale-state recovery, not first-time setup). All pings are TSMP +// because TSMP traverses the actual WireGuard data plane; PingDisco only +// exercises the magicsock disco layer and would mask any stale WG session +// problems. +// +// Two rotation methods are exercised: +// +// - LocalAPI rotate-disco-key (debug action): rolls B's magicsock disco +// private key in place, then bounces WantRunning to force wgengine to +// drop wireguard-go session keys for every peer (RotateDiscoKey alone +// only touches local disco state; without the WantRunning bounce, B +// keeps using stale per-peer session keys against A and A drops +// everything until B's WG rekey timer eventually fires). +// - SIGKILL of tailscaled (via TTA's /kill-tailscaled): the gokrazy +// supervisor respawns tailscaled, fully resetting B's magicsock and +// wgengine state in addition to rotating the disco key. +// +// Each post-rotation ping currently gets a 15-second budget. On a +// hypothetical perfect build it should take well under a second. In +// practice today there are two unavoidable multi-second waits: +// +// - The rotate-then-a→b phase on main takes ~10s for LazyWG. After +// B's WantRunning bounce, B's wgengine resets its sentActivityAt/ +// recvActivityAt maps and trims A out of the wireguard-go config +// as an "idle peer"; B only re-adds A on inbound activity, by +// which point A's first few TSMP packets have been silently +// dropped at B's tundev. The bradfitz/rm_lazy_wg branch removes +// that trimming entirely (verified locally), so this phase will +// drop to <100ms once that branch lands. +// +// - The restart phases take ~5s for the wireguard-go handshake retry +// timer. After SIGKILL+respawn the first WG handshake init from +// the restarted node sometimes goes into the void (likely the +// brief peer-removed window in the receiver's two-step +// [wgengine.userspaceEngine.maybeReconfigWireguardLocked] reconfig +// during which the peer is absent from wireguard-go), and wg-go's +// [device.RekeyTimeout] of 5s + jitter is the next opportunity to +// retry. That retry succeeds and the staged TSMP packet flushes. +// This is intrinsic to the protocol's retransmit policy. +// +// Once LazyWG is removed and the first-handshake-after-reconfig race +// is fixed, this budget should be tightened to 5s (or less). +// +// All four rotations also assert that B's WireGuard node key is unchanged. +func TestDiscoKeyChange(t *testing.T) { + // AllOnline makes the test control server mark every peer as Online=true + // in its MapResponses. Several disco-key handling fast paths + // (controlclient.removeUnwantedDiscoUpdates, + // removeUnwantedDiscoUpdatesFromFullNetmapUpdate, and the wgengine + // tsmpLearnedDisco fast path) only fire for online peers. Production + // control servers always populate Online; without this flag the test + // would only exercise the offline-peer paths. + env := vmtest.New(t, vmtest.AllOnline()) + + // One2OneNAT so each node has a 1:1 mapping to a public WAN IP with no + // port-translation or address-port filtering. This makes A↔B traffic + // behave like two unfirewalled hosts on the public internet, so any + // slowness we observe in this test cannot be blamed on NAT traversal. + aNet := env.AddNetwork("1.0.0.1", "192.168.1.1/24", vnet.One2OneNAT) + bNet := env.AddNetwork("2.0.0.1", "192.168.2.1/24", vnet.One2OneNAT) + + a := env.AddNode("a", aNet, vmtest.OS(vmtest.Gokrazy)) + b := env.AddNode("b", bNet, vmtest.OS(vmtest.Gokrazy)) + + type phase struct { + name string + rotate func() + pingFrom *vmtest.Node + pingTo *vmtest.Node + applyStep *vmtest.Step + verify *vmtest.Step + wait *vmtest.Step + ping *vmtest.Step + } + phases := []*phase{ + {name: "rotate (LocalAPI), b → a", pingFrom: b, pingTo: a, rotate: func() { env.RotateDiscoKey(b) }}, + {name: "rotate (LocalAPI), a → b", pingFrom: a, pingTo: b, rotate: func() { env.RotateDiscoKey(b) }}, + {name: "restart, b → a", pingFrom: b, pingTo: a, rotate: func() { env.RestartTailscaled(b) }}, + {name: "restart, a → b", pingFrom: a, pingTo: b, rotate: func() { env.RestartTailscaled(b) }}, + } + + pingABStep := env.AddStep("Ping a → b TSMP (establish tunnel)") + for _, p := range phases { + p.applyStep = env.AddStep("Apply: " + p.name) + p.verify = env.AddStep("Verify b: same node key, new disco key (" + p.name + ")") + p.wait = env.AddStep("Wait for a to see b's new disco key (" + p.name + ")") + p.ping = env.AddStep("Ping " + p.pingFrom.Name() + " → " + p.pingTo.Name() + " TSMP (" + p.name + ")") + } + + env.Start() + + pingABStep.Begin() + if err := env.Ping(a, b, tailcfg.PingTSMP, 30*time.Second); err != nil { + pingABStep.Fatal(err) + } + pingABStep.End(nil) + + bStInitial := env.Status(b) + bNodeKey := bStInitial.Self.PublicKey + cs := env.ControlServer() + bCtlNode := cs.Node(bNodeKey) + if bCtlNode == nil { + t.Fatalf("control server has no node for b's key %v", bNodeKey) + } + prevDisco := bCtlNode.DiscoKey + if prevDisco.IsZero() { + t.Fatalf("control server has no disco key for b before rotation") + } + t.Logf("[b] initial: nodekey=%s discokey=%s", bNodeKey.ShortString(), prevDisco.ShortString()) + + for _, p := range phases { + p.applyStep.Begin() + p.rotate() + p.applyStep.End(nil) + prevDisco = checkDiscoRotated(t, env, a, b, p.pingFrom, p.pingTo, bNodeKey, prevDisco, p.name, + p.verify, p.wait, p.ping) + } +} + +// checkDiscoRotated verifies that after some action that should have rotated +// b's disco key, control has learned the new key, b's node key is unchanged, +// a's local view picks up the new disco key, and pingFrom can ping pingTo +// (TSMP) within the budget. It returns b's new disco key and fatals on +// failure. +// +// The TSMP ping budget is 15 seconds rather than the few hundred ms it +// ought to take. See the top-level test docstring for a full breakdown: +// it has to absorb LazyWG's trim+re-add for the rotate-a→b phase (~10s) +// and wireguard-go's RekeyTimeout retry for the SIGKILL+restart phases +// (~5s). Tighten this once both are addressed. +func checkDiscoRotated(t *testing.T, env *vmtest.Env, a, b, pingFrom, pingTo *vmtest.Node, bNodeKey key.NodePublic, oldDisco key.DiscoPublic, label string, verifyStep, waitStep, pingStep *vmtest.Step) key.DiscoPublic { + t.Helper() + cs := env.ControlServer() + + verifyStep.Begin() + bSt := env.Status(b) + if got := bSt.Self.PublicKey; got != bNodeKey { + verifyStep.Fatalf("[%s] b's node key changed: %v -> %v", label, bNodeKey, got) + } + var newDisco key.DiscoPublic + if err := tstest.WaitFor(15*time.Second, func() error { + n := cs.Node(bNodeKey) + if n == nil { + return fmt.Errorf("control server has no node for b") + } + if n.DiscoKey.IsZero() || n.DiscoKey == oldDisco { + return fmt.Errorf("control still has old disco key %v for b", n.DiscoKey) + } + newDisco = n.DiscoKey + return nil + }); err != nil { + verifyStep.Fatalf("[%s] %v", label, err) + } + t.Logf("[b] after %s: nodekey=%s discokey=%s", label, bNodeKey.ShortString(), newDisco.ShortString()) + verifyStep.End(nil) + + waitStep.Begin() + if err := tstest.WaitFor(30*time.Second, func() error { + d, ok, err := env.PeerDiscoKey(a, bNodeKey) + if err != nil { + return err + } + if !ok { + return fmt.Errorf("a doesn't yet have b in its status") + } + if d != newDisco { + return fmt.Errorf("a still sees b's old disco %v, want %v", d.ShortString(), newDisco.ShortString()) + } + return nil + }); err != nil { + waitStep.End(err) + env.DumpStatus(a) + t.Fatalf("[%s] %v", label, err) + } + waitStep.End(nil) + + pingStep.Begin() + t0 := time.Now() + if err := env.Ping(pingFrom, pingTo, tailcfg.PingTSMP, 15*time.Second); err != nil { + pingStep.End(err) + env.DumpStatus(a) + env.DumpStatus(b) + t.Fatalf("[%s] %v", label, err) + } + t.Logf("[%s] ping %s -> %s succeeded in %v", label, pingFrom.Name(), pingTo.Name(), time.Since(t0).Round(100*time.Millisecond)) + pingStep.End(nil) + return newDisco +} + +// TestMullvadExitNode verifies that a Tailscale client whose netmap contains +// a plain-WireGuard exit node (the way Mullvad exit nodes are wired up by +// the control plane) can route internet traffic through it, with the source +// IP rewritten to the per-client Mullvad-assigned address. +// +// Topology: +// +// client (Tailscale, gokrazy) — clientNet (EasyNAT) WAN 1.0.0.1 +// mullvad (Ubuntu, userspace WG) — mullvadNet (One2OneNAT) WAN 2.0.0.1 +// webserver (no Tailscale, gokrazy) — webNet (One2OneNAT) WAN 5.0.0.1 +// +// The mullvad VM impersonates a Mullvad WireGuard server. After boot, the +// test asks its TTA agent to bring up a userspace WireGuard interface (a +// real Linux TUN driven by wireguard-go) that pins the client's Tailscale +// node public key as its only allowed peer, sets up IP-forwarding + a +// MASQUERADE rule, and reports the WG server's freshly generated public +// key back. Userspace vs kernel WireGuard makes no difference on the wire +// — what's being tested is Tailscale's plain-WireGuard exit-node code +// path, not the kernel module. +// +// The test then injects a netmap peer with IsWireGuardOnly=true, +// AllowedIPs=[gw/32, 0.0.0.0/0, ::/0], the WG endpoint, and a per-client +// SelfNodeV4MasqAddrForThisPeer (the mock equivalent of the per-client IP +// Mullvad's API hands out at registration time). +// +// The webserver echoes the source IP it sees: +// - exit-node off: source is client's WAN (direct egress) +// - exit-node on: source is mullvad's WAN (egress via WG + MASQUERADE) +func TestMullvadExitNode(t *testing.T) { + env := vmtest.New(t) + + const ( + clientWAN = "1.0.0.1" + mullvadWAN = "2.0.0.1" + webWAN = "5.0.0.1" + ) + // Mullvad-side WG network. The client appears as clientMasqIP to + // mullvad's wg0; mullvad terminates the tunnel at gw. + var ( + mullvadWGNet = netip.MustParsePrefix("10.64.0.0/24") + gw = netip.MustParsePrefix("10.64.0.1/24") + clientMasq = netip.MustParsePrefix("10.64.0.2/32") + ) + const wgListenPort uint16 = 51820 + + clientNet := env.AddNetwork(clientWAN, "192.168.1.1/24", vnet.EasyNAT) + mullvadNet := env.AddNetwork(mullvadWAN, "192.168.2.1/24", vnet.One2OneNAT) + webNet := env.AddNetwork(webWAN, "192.168.5.1/24", vnet.One2OneNAT) + + client := env.AddNode("client", clientNet, vmtest.OS(vmtest.Gokrazy)) + mullvad := env.AddNode("mullvad", mullvadNet, + vmtest.OS(vmtest.Ubuntu2404), + vmtest.DontJoinTailnet()) + env.AddNode("webserver", webNet, + vmtest.OS(vmtest.Gokrazy), + vmtest.DontJoinTailnet(), + vmtest.WebServer(8080)) + + // Declare test-specific steps for the web UI. + wgUpStep := env.AddStep("Bring up Mullvad WG server") + injectStep := env.AddStep("Inject Mullvad netmap peer") + checkOff1Step := env.AddStep("HTTP GET (exit off)") + checkMullvadStep := env.AddStep("HTTP GET (exit=mullvad)") + checkOff2Step := env.AddStep("HTTP GET (exit off, again)") + + env.Start() + + // Bring up the WG server inside mullvad's TTA, pinning the client's + // Tailscale node public key as the sole allowed peer. + wgUpStep.Begin() + clientStatus := env.Status(client) + mullvadPub := env.BringUpMullvadWGServer(mullvad, + gw, wgListenPort, + clientStatus.Self.PublicKey, clientMasq, mullvadWGNet) + wgUpStep.End(nil) + + // Inject the mullvad node into the netmap as a plain-WireGuard exit + // node. This mirrors how the control plane describes Mullvad exit + // nodes to clients (see control/cmullvad in the closed repo): a + // peer with IsWireGuardOnly=true, an Endpoints entry pointing at + // the public WG host:port, and AllowedIPs covering both the gateway + // /32 and the 0.0.0.0/0+::/0 exit-node routes. + injectStep.Begin() + mullvadEndpoint := netip.AddrPortFrom(netip.MustParseAddr(mullvadWAN), wgListenPort) + gwHost := netip.PrefixFrom(gw.Addr(), gw.Addr().BitLen()) + mullvadNode := &tailcfg.Node{ + ID: 999_001, + StableID: "mullvad-test", + Name: "mullvad-test.fake-control.example.net.", + Key: mullvadPub, + MachineAuthorized: true, + IsWireGuardOnly: true, + Endpoints: []netip.AddrPort{mullvadEndpoint}, + Addresses: []netip.Prefix{gwHost}, + AllowedIPs: []netip.Prefix{ + gwHost, + netip.MustParsePrefix("0.0.0.0/0"), + netip.MustParsePrefix("::/0"), + }, + Hostinfo: (&tailcfg.Hostinfo{ + Hostname: "mullvad-test", + }).View(), + } + cs := env.ControlServer() + cs.UpdateNode(mullvadNode) + + // Set the per-peer source-IP masquerade. The control plane normally + // derives this from the Mullvad API's per-client registration; here + // we just pin it to the address mullvad's wg0 was told to accept. + cs.SetMasqueradeAddresses([]testcontrol.MasqueradePair{{ + Node: clientStatus.Self.PublicKey, + Peer: mullvadPub, + NodeMasqueradesAs: clientMasq.Addr(), + }}) + injectStep.End(nil) + + webURL := fmt.Sprintf("http://%s:8080/", webWAN) + check := func(step *vmtest.Step, label, wantSrc string) { + t.Helper() + step.Begin() + body := env.HTTPGet(client, webURL) + t.Logf("[%s] response: %s", label, body) + if !strings.Contains(body, "Hello world I am webserver") { + step.Fatalf("[%s] unexpected webserver response: %q", label, body) + } + if !strings.Contains(body, "from "+wantSrc) { + step.Fatalf("[%s] expected source %q in response, got %q", label, wantSrc, body) + } + step.End(nil) + } + + // Exit-node off: client routes 0.0.0.0/0 directly via its host stack, + // so the webserver sees client's WAN IP. + check(checkOff1Step, "exit-off", clientWAN) + + // Switch to the Mullvad WG-only peer as exit node. The client should + // now route 0.0.0.0/0 through the WG tunnel; mullvad MASQUERADEs to + // its WAN; the webserver sees the mullvad VM's WAN IP. + env.SetExitNodeIP(client, gw.Addr()) + check(checkMullvadStep, "exit-mullvad", mullvadWAN) + + // And back off again, to make sure the transition works in both + // directions. + env.SetExitNodeIP(client, netip.Addr{}) + check(checkOff2Step, "exit-off (again)", clientWAN) +} + +// checkClientMetrics verifies that each entry in want exists and has the given +// value in metrics. +func checkClientMetrics(t *testing.T, label string, metrics vmtest.ClientMetrics, want map[string]int64) { + t.Helper() + for name, wantValue := range want { + got, ok := metrics[name] + if !ok { + t.Errorf("%s: required metric %q not found", label, name) + } else if got.Value != wantValue { + t.Errorf("%s: metric %q: got %v, want %v", label, name, got.Value, wantValue) + } + } +} + +// TestCachedNetmapAfterRestart verifies that two nodes with netmap +// caching enabled (NodeAttrCacheNetworkMaps) can re-establish a direct +// WireGuard tunnel after both are restarted while the control server is +// unreachable. After restart the nodes must use only their on-disk cached +// netmaps to re-connect. +func TestCachedNetmapAfterRestart(t *testing.T) { + env := vmtest.New(t) + + aNet := env.AddNetwork("1.0.0.1", "192.168.1.1/24", vnet.EasyNAT) + bNet := env.AddNetwork("2.0.0.1", "192.168.2.1/24", vnet.EasyNAT) + + a := env.AddNode("a", aNet, + vmtest.OS(vmtest.Gokrazy), + tailcfg.NodeCapMap{tailcfg.NodeAttrCacheNetworkMaps: nil}) + b := env.AddNode("b", bNet, + vmtest.OS(vmtest.Gokrazy), + tailcfg.NodeCapMap{tailcfg.NodeAttrCacheNetworkMaps: nil}) + + connectStep := env.AddStep("Establish initial TSMP tunnel") + cutControlStep := env.AddStep("Cut control server access") + restartStep := env.AddStep("Restart tailscaled on both nodes") + netmapCheckStep := env.AddStep("Check netmap loaded is cached") + pingStep := env.AddStep("Ping a → b TSMP (cached netmap, no control)") + + env.Start() + + connectStep.Begin() + if err := env.Ping(a, b, tailcfg.PingTSMP, 30*time.Second); err != nil { + connectStep.Fatal(err) + } + connectStep.End(nil) + + cutControlStep.Begin() + // Both nodes lose connection to control + a.DropControlTraffic() + b.DropControlTraffic() + env.ControlServer().SetOnMapRequest(func(nk key.NodePublic) { + panic(fmt.Sprintf("got connection from %v", nk)) + }) + cutControlStep.End(nil) + + restartStep.Begin() + env.RestartTailscaled(a) + env.RestartTailscaled(b) + restartStep.End(nil) + + netmapCheckStep.Begin() + for _, node := range []*vmtest.Node{a, b} { + nm, err := local.GetDebugResultJSON[netmap.NetworkMap](t.Context(), node.Agent().Client, "current-netmap") + if err != nil { + netmapCheckStep.Fatalf("[%s] got err fetching netmap %q", node.Name(), err) + } + if !nm.Cached { + netmapCheckStep.Fatalf("[%s] expected netmap.Cached = true, got: %t", node.Name(), nm.Cached) + } + } + netmapCheckStep.End(nil) + + // 90s is generous on purpose. After both nodes restart with stale cached + // netmap entries, a's first WG handshake to b's pre-restart endpoint + // hits the dead NAT mapping on b's side and is silently dropped (we + // see this as "no recent outgoing packet" NAT drops in the vnet log). + // Recovery then waits on wireguard-go's REKEY_TIMEOUT (~5s) before the + // next handshake attempt, and on disco-via-DERP to teach each side the + // other's new endpoint. On an idle host this converges in well under + // 15s; on a contended host (a 14/16-CPU-loaded local repro, or any + // shared CI runner) the same sequence has been observed at 50-60s + // because every timer fires multiple times under scheduling jitter. + pingStep.Begin() + if err := env.Ping(a, b, tailcfg.PingTSMP, 90*time.Second); err != nil { + pingStep.Fatal(err) + } + pingStep.End(nil) +} + +// TestDirectConnectionWithCachedNetmap verifies that two nodes with netmap +// caching enabled (NodeAttrCacheNetworkMaps) can re-establish a direct +// WireGuard tunnel after one is restarted while the control server is +// unreachable. After restart the node must use only its on-disk cached +// netmaps to re-connect and ping the other (still online) node. +func TestDirectConnectionWithCachedNetmapOnOneNode(t *testing.T) { + env := vmtest.New(t) + + aNet := env.AddNetwork("1.0.0.1", "192.168.1.1/24", vnet.EasyNAT) + bNet := env.AddNetwork("2.0.0.1", "192.168.2.1/24", vnet.EasyNAT) + + a := env.AddNode("a", aNet, + vmtest.OS(vmtest.Gokrazy), + tailcfg.NodeCapMap{tailcfg.NodeAttrCacheNetworkMaps: nil}) + b := env.AddNode("b", bNet, + vmtest.OS(vmtest.Gokrazy), + tailcfg.NodeCapMap{tailcfg.NodeAttrCacheNetworkMaps: nil}) + + checkInitialMetrics := env.AddStep("Check initial client metrics") + cutControlStep := env.AddStep("Cut control server access") + restartStep := env.AddStep("Restart tailscaled on a") + tsmpPingStep := env.AddStep("Ping a → b TSMP (cached netmap, no control)") + discoPingStep := env.AddStep("Ping a → b Disco (want Direct)") + checkFinalMetrics := env.AddStep("Check final client metrics") + + env.Start() + + // Before: Verify that we have not recorded any cached contacts. + checkInitialMetrics.Begin() + checkClientMetrics(t, "Node A", env.ClientMetrics(a), map[string]int64{ + "magicsock_cached_peer_contact_derp": 0, + "magicsock_cached_peer_contact_direct": 0, + }) + checkInitialMetrics.End(nil) + + cutControlStep.Begin() + a.DropControlTraffic() + env.ControlServer().SetOnMapRequest(func(nk key.NodePublic) { + if env.ControlServer().Node(nk).Name == a.Name() { + panic(fmt.Sprintf("got connection from %v", a.Name())) + } + }) + cutControlStep.End(nil) + + restartStep.Begin() + env.RestartTailscaled(a) + restartStep.End(nil) + + tsmpPingStep.Begin() + if err := env.Ping(a, b, tailcfg.PingTSMP, 30*time.Second); err != nil { + tsmpPingStep.Fatal(err) + } + tsmpPingStep.End(nil) + + discoPingStep.Begin() + if err := env.PingExpect(a, b, vmtest.PingRouteDirect, 30*time.Second); err != nil { + discoPingStep.Fatal(err) + } + discoPingStep.End(nil) + + // After: Verify that we recorded a direct contact on the disconnected node. + checkFinalMetrics.Begin() + checkClientMetrics(t, "Node A", env.ClientMetrics(a), map[string]int64{ + "magicsock_cached_peer_contact_direct": 1, + }) + checkFinalMetrics.End(nil) +} + +// TestPeerRelay verifies that two Tailscale nodes whose direct UDP path is +// impossible at the network layer (both behind HardNAT, with no port-mapping +// services on either of their networks) can still communicate via a third +// Tailscale node configured as a peer-relay server. +// +// Topology: +// +// a (gokrazy, HardNAT) — aNet WAN 1.0.0.1 +// b (gokrazy, HardNAT) — bNet WAN 2.0.0.1 +// relay (gokrazy, One2OneNAT) — relayNet WAN 3.0.0.1 +// +// HardNAT in natlab is endpoint-dependent (each (src, dst) tuple gets a fresh +// outbound port, and the inbound table keys on (wanPort, src)). Without +// NAT-PMP/UPnP a→b and b→a direct UDP paths cannot be established. The relay +// uses One2OneNAT so its STUN-discovered WAN endpoint is reachable from both +// peers. The test then asserts that magicsock chose the peer-relay path +// (not DERP) and that the relay reports the session. +func TestPeerRelay(t *testing.T) { + env := vmtest.New(t, vmtest.PeerRelayGrants()) + + aNet := env.AddNetwork("1.0.0.1", "192.168.1.1/24", vnet.HardNAT) + bNet := env.AddNetwork("2.0.0.1", "192.168.2.1/24", vnet.HardNAT) + relayNet := env.AddNetwork("3.0.0.1", "192.168.3.1/24", vnet.One2OneNAT) + + a := env.AddNode("a", aNet, vmtest.OS(vmtest.Gokrazy)) + b := env.AddNode("b", bNet, vmtest.OS(vmtest.Gokrazy)) + relay := env.AddNode("relay", relayNet, vmtest.OS(vmtest.Gokrazy)) + + enableRelayStep := env.AddStep("Enable peer-relay server on relay") + pingStep := env.AddStep("Disco ping a → b (want peer-relay path)") + sessionsStep := env.AddStep("Check DebugPeerRelaySessions on relay") + + env.Start() + + // Turn on the relay server. Port 0 picks an unused port. + enableRelayStep.Begin() + editCtx, editCancel := context.WithTimeout(t.Context(), 30*time.Second) + _, err := relay.Agent().EditPrefs(editCtx, &ipn.MaskedPrefs{ + Prefs: ipn.Prefs{RelayServerPort: new(uint16(0))}, + RelayServerPortSet: true, + }) + editCancel() + if err != nil { + enableRelayStep.Fatalf("EditPrefs(relay, RelayServerPort=0): %v", err) + } + enableRelayStep.End(nil) + + // Wait for the relay to start, peers to learn about it via netmap, + // and the a→b disco ping to traverse it. + // PingResult.PeerRelay is set by magicsock to "ip:port:vni:N" when the + // disco probe rode a peer relay (vs Endpoint for direct UDP or + // DERPRegionID for DERP). + pingStep.Begin() + bIP := env.Status(b).Self.TailscaleIPs[0] + var lastDetail string + err = tstest.WaitFor(60*time.Second, func() error { + ctx, cancel := context.WithTimeout(t.Context(), 5*time.Second) + defer cancel() + pr, err := a.Agent().PingWithOpts(ctx, bIP, tailcfg.PingDisco, local.PingOpts{}) + if err != nil { + return fmt.Errorf("ping: %w", err) + } + if pr.Err != "" { + return fmt.Errorf("ping err: %s", pr.Err) + } + if pr.PeerRelay == "" { + lastDetail = fmt.Sprintf("endpoint=%q derp=%d", pr.Endpoint, pr.DERPRegionID) + return fmt.Errorf("ping did not use a peer relay; %s", lastDetail) + } + t.Logf("a → b disco ping rode peer-relay %s", pr.PeerRelay) + return nil + }) + if err != nil { + env.DumpStatus(a) + env.DumpStatus(b) + env.DumpStatus(relay) + pingStep.Fatalf("waiting for peer-relay path a → b: %v (last: %s)", err, lastDetail) + } + pingStep.End(nil) + + // The relay's local debug-peer-relay-sessions LocalAPI should now + // report a single session for the a↔b disco probe. Cross-check the + // session's client disco keys against control's view of a and b, and + // confirm both sides recorded non-zero packet/byte counts (the disco + // ping + pong each take one underlay packet through the relay). + sessionsStep.Begin() + ctx, cancel := context.WithTimeout(t.Context(), 5*time.Second) + defer cancel() + srv, err := relay.Agent().DebugPeerRelaySessions(ctx) + if err != nil { + sessionsStep.Fatalf("DebugPeerRelaySessions: %v", err) + } + if srv.UDPPort == nil { + sessionsStep.Fatalf("relay UDPPort is nil; want set") + } + if got, want := len(srv.Sessions), 1; got != want { + sessionsStep.Fatalf("relay sessions = %d; want %d: %+v", got, want, srv.Sessions) + } + cs := env.ControlServer() + wantShorts := set.Of( + cs.Node(env.Status(a).Self.PublicKey).DiscoKey.ShortString(), + cs.Node(env.Status(b).Self.PublicKey).DiscoKey.ShortString(), + ) + session := srv.Sessions[0] + gotShorts := set.Of(session.Client1.ShortDisco, session.Client2.ShortDisco) + if !gotShorts.Equal(wantShorts) { + sessionsStep.Fatalf("session disco shorts = %v; want %v", gotShorts, wantShorts) + } + for _, ci := range []status.ClientInfo{session.Client1, session.Client2} { + if !ci.Endpoint.IsValid() { + sessionsStep.Fatalf("session client %s: invalid Endpoint", ci.ShortDisco) + } + if ci.PacketsTx == 0 { + sessionsStep.Fatalf("session client %s: PacketsTx = 0; want >0", ci.ShortDisco) + } + if ci.BytesTx == 0 { + sessionsStep.Fatalf("session client %s: BytesTx = 0; want >0", ci.ShortDisco) + } + } + t.Logf("relay session VNI=%d %s <-> %s on UDP port %d", + session.VNI, session.Client1.ShortDisco, session.Client2.ShortDisco, *srv.UDPPort) + sessionsStep.End(nil) +} diff --git a/tstest/natlab/vmtest/web.go b/tstest/natlab/vmtest/web.go new file mode 100644 index 000000000..d512740e6 --- /dev/null +++ b/tstest/natlab/vmtest/web.go @@ -0,0 +1,209 @@ +// Copyright (c) Tailscale Inc & contributors +// SPDX-License-Identifier: BSD-3-Clause + +package vmtest + +import ( + "embed" + "flag" + "fmt" + "hash/crc32" + "html/template" + "io" + "io/fs" + "net" + "net/http" + "os" + "path/filepath" + "strings" + "sync" + "time" + + "github.com/coder/websocket" + "github.com/robert-nix/ansihtml" +) + +var vmtestWeb = flag.String("vmtest-web", "", "listen address for vmtest web UI (e.g. :0, localhost:0, :8080)") + +//go:embed assets/*.html +var templatesSrc embed.FS + +//go:embed assets/*.css +var staticAssets embed.FS + +var tmpl = sync.OnceValue(func() *template.Template { + d, err := fs.Sub(templatesSrc, "assets") + if err != nil { + panic(fmt.Errorf("getting vmtest web templates subdir: %w", err)) + } + return template.Must(template.New("").Funcs(template.FuncMap{ + "formatDuration": formatDuration, + "ansi": ansiToHTML, + }).ParseFS(d, "*")) +}) + +// ansiToHTML converts a string with ANSI escape sequences to HTML with +// inline styles. Returns template.HTML so html/template doesn't double-escape it. +func ansiToHTML(s string) template.HTML { + return template.HTML(ansihtml.ConvertToHTML([]byte(s))) +} + +// formatDuration returns a human-readable duration like "1.2s" or "45.3s". +func formatDuration(d time.Duration) string { + if d < time.Second { + return fmt.Sprintf("%dms", d.Milliseconds()) + } + return fmt.Sprintf("%.1fs", d.Seconds()) +} + +// deterministicPort returns a deterministic port in the range [20000, 40000) +// based on the test name, so re-running the same test gets the same URL. +func deterministicPort(testName string) int { + return int(crc32.ChecksumIEEE([]byte(testName)))%20000 + 20000 +} + +// listenWeb listens on the given address. If the port is 0, it first tries a +// deterministic port based on the test name so re-runs get the same URL. +// Falls back to :0 (OS-assigned) on any listen error. +func (e *Env) listenWeb(addr string) (net.Listener, error) { + host, port, _ := net.SplitHostPort(addr) + if port == "0" { + detPort := deterministicPort(e.t.Name()) + detAddr := net.JoinHostPort(host, fmt.Sprintf("%d", detPort)) + if ln, err := net.Listen("tcp", detAddr); err == nil { + return ln, nil + } + // Deterministic port busy; fall back to OS-assigned. + } + return net.Listen("tcp", addr) +} + +// maybeStartWebServer starts the web UI if --vmtest-web is set. +// Called at the very top of Env.Start(), before compilation or image downloads. +func (e *Env) maybeStartWebServer() { + addr := *vmtestWeb + if addr == "" { + return + } + + ln, err := e.listenWeb(addr) + if err != nil { + e.t.Fatalf("vmtest-web listen: %v", err) + } + e.t.Cleanup(func() { ln.Close() }) + + actualAddr := ln.Addr().(*net.TCPAddr) + + host, _, _ := net.SplitHostPort(addr) + if host == "" || host == "0.0.0.0" || host == "::" { + hostname, err := os.Hostname() + if err != nil { + hostname = "localhost" + } + e.t.Logf("Status at http://%s:%d/", hostname, actualAddr.Port) + } else { + e.t.Logf("Status at http://%s/", actualAddr.String()) + } + + mux := http.NewServeMux() + mux.HandleFunc("GET /", e.serveIndex) + mux.HandleFunc("GET /ws", e.serveWebSocket) + mux.HandleFunc("GET /screenshot/{node}", e.serveScreenshot) + mux.HandleFunc("GET /style.css", serveStaticAsset("style.css")) + + srv := &http.Server{Handler: mux} + go srv.Serve(ln) + e.t.Cleanup(func() { srv.Close() }) +} + +func serveStaticAsset(name string) http.HandlerFunc { + return func(w http.ResponseWriter, r *http.Request) { + if !strings.HasSuffix(name, ".css") { + http.Error(w, "not found", http.StatusNotFound) + return + } + w.Header().Set("Content-Type", "text/css") + f, err := staticAssets.Open(filepath.Join("assets", name)) + if err != nil { + http.Error(w, err.Error(), http.StatusInternalServerError) + return + } + defer f.Close() + io.Copy(w, f) + } +} + +func (e *Env) serveIndex(w http.ResponseWriter, r *http.Request) { + type indexData struct { + TestName string + TestStatus *TestStatus + Steps []*Step + Nodes []NodeStatus + } + + data := indexData{ + TestName: e.t.Name(), + TestStatus: e.testStatus, + Steps: e.Steps(), + } + for _, n := range e.nodes { + data.Nodes = append(data.Nodes, e.getNodeStatus(n.name)) + } + + w.Header().Set("Content-Type", "text/html; charset=utf-8") + if err := tmpl().ExecuteTemplate(w, "index.html", data); err != nil { + http.Error(w, err.Error(), http.StatusInternalServerError) + } +} + +// serveScreenshot proxies a full-resolution screenshot from the Host.app +// screenshot server. Returns raw JPEG with no HTML wrapper. +func (e *Env) serveScreenshot(w http.ResponseWriter, r *http.Request) { + name := r.PathValue("node") + port := e.nodeScreenshotPort(name) + if port == 0 { + http.Error(w, "no screenshot server for node", http.StatusNotFound) + return + } + resp, err := http.Get(fmt.Sprintf("http://127.0.0.1:%d/screenshot?full=1", port)) + if err != nil { + http.Error(w, err.Error(), http.StatusBadGateway) + return + } + defer resp.Body.Close() + w.Header().Set("Content-Type", "image/jpeg") + io.Copy(w, resp.Body) +} + +func (e *Env) serveWebSocket(w http.ResponseWriter, r *http.Request) { + conn, err := websocket.Accept(w, r, nil) + if err != nil { + return + } + defer conn.CloseNow() + wsCtx := conn.CloseRead(r.Context()) + + sub := e.eventBus.Subscribe() + defer sub.Close() + + for { + select { + case <-wsCtx.Done(): + return + case <-sub.Done(): + return + case ev := <-sub.Events(): + msg, err := conn.Writer(r.Context(), websocket.MessageText) + if err != nil { + return + } + if err := tmpl().ExecuteTemplate(msg, "event.html", ev); err != nil { + msg.Close() + return + } + if err := msg.Close(); err != nil { + return + } + } + } +} diff --git a/tstest/natlab/vnet/conf.go b/tstest/natlab/vnet/conf.go index 64f28fbc9..191de9e18 100644 --- a/tstest/natlab/vnet/conf.go +++ b/tstest/natlab/vnet/conf.go @@ -15,6 +15,7 @@ import ( "github.com/google/gopacket/layers" "github.com/google/gopacket/pcapgo" + "tailscale.com/tailcfg" "tailscale.com/types/logger" "tailscale.com/util/must" "tailscale.com/util/set" @@ -137,6 +138,8 @@ func (c *Config) AddNode(opts ...any) *Node { } case MAC: n.mac = o + case tailcfg.NodeCapMap: + n.capMap = o default: if n.err == nil { n.err = fmt.Errorf("unknown AddNode option type %T", o) @@ -225,6 +228,7 @@ type Node struct { preICMPPing bool verboseSyslog bool dontJoinTailnet bool + capMap tailcfg.NodeCapMap // TODO(bradfitz): this is halfway converted to supporting multiple NICs // but not done. We need a MAC-per-Network. @@ -318,6 +322,12 @@ func (n *Node) ShouldJoinTailnet() bool { return !n.dontJoinTailnet } +// WantCapMap returns the [tailcfg.NodeCapMap] that control should send down to +// this node, if any. +func (n *Node) WantCapMap() tailcfg.NodeCapMap { + return n.capMap +} + // IsV6Only reports whether this node is only connected to IPv6 networks. func (n *Node) IsV6Only() bool { for _, net := range n.nets { @@ -435,6 +445,12 @@ func (n *Network) PostConnectedToControl() { n.network.SetControlBlackholed(n.postConnectBlackholeControl) } +// BlackholeControlForAddr sets weither the network should drop all control +// traffic for the specified addr starting immediately. +func (n *Network) BlackholeControlForAddr(addr netip.Addr) { + n.network.BlackholeControlForAddr(addr) +} + // NetworkService is a service that can be added to a network. type NetworkService string diff --git a/tstest/natlab/vnet/vip.go b/tstest/natlab/vnet/vip.go index a6973ed50..07b64f54c 100644 --- a/tstest/natlab/vnet/vip.go +++ b/tstest/natlab/vnet/vip.go @@ -33,6 +33,13 @@ func (v virtualIP) Match(a netip.Addr) bool { return v.v4 == a.Unmap() || v.v6 == a } +// TestDriverIPv4 returns the IPv4 address of the test driver VIP (52.52.0.2). +// TTA agents dial this IP on port TestDriverPort to connect to the test harness. +func TestDriverIPv4() netip.Addr { return fakeTestAgent.v4 } + +// TestDriverPort is the port the test driver listens on. +const TestDriverPort = 8008 + // FakeDNSIPv4 returns the fake DNS IPv4 address. func FakeDNSIPv4() netip.Addr { return fakeDNS.v4 } diff --git a/tstest/natlab/vnet/vnet.go b/tstest/natlab/vnet/vnet.go index 43256dafe..958da04de 100644 --- a/tstest/natlab/vnet/vnet.go +++ b/tstest/natlab/vnet/vnet.go @@ -205,7 +205,7 @@ func (n *network) initStack() error { return tcpFwd.HandlePacket(tei, pb) }) - go func() { + n.s.wg.Go(func() { for { pkt := n.linkEP.ReadContext(n.s.shutdownCtx) if pkt == nil { @@ -217,7 +217,7 @@ func (n *network) initStack() error { } n.handleIPPacketFromGvisor(pkt.ToView().AsSlice()) } - }() + }) return nil } @@ -352,7 +352,7 @@ func (n *network) acceptTCP(r *tcp.ForwarderRequest) { return } - if destPort == 8008 && fakeTestAgent.Match(destIP) { + if destPort == TestDriverPort && fakeTestAgent.Match(destIP) { node, ok := n.nodeByIP(clientRemoteIP) if !ok { n.logf("unknown client IP %v trying to connect to test driver", clientRemoteIP) @@ -369,8 +369,11 @@ func (n *network) acceptTCP(r *tcp.ForwarderRequest) { if destPort == 80 && fakeControl.Match(destIP) { r.Complete(false) tc := gonet.NewTCPConn(&wq, ep) + context.AfterFunc(n.s.shutdownCtx, func() { tc.SetDeadline(time.Now()) }) hs := &http.Server{Handler: n.s.control} - go hs.Serve(netutil.NewOneConnListener(tc, nil)) + n.s.wg.Go(func() { + hs.Serve(netutil.NewOneConnListener(tc, nil)) + }) return } @@ -383,39 +386,54 @@ func (n *network) acceptTCP(r *tcp.ForwarderRequest) { r.Complete(false) tc := gonet.NewTCPConn(&wq, ep) + context.AfterFunc(n.s.shutdownCtx, func() { tc.SetDeadline(time.Now()) }) tlsConn := tls.Server(tc, ds.tlsConfig) hs := &http.Server{Handler: ds.handler} - go hs.Serve(netutil.NewOneConnListener(tlsConn, nil)) + n.s.wg.Go(func() { + hs.Serve(netutil.NewOneConnListener(tlsConn, nil)) + }) return } if destPort == 80 { r.Complete(false) tc := gonet.NewTCPConn(&wq, ep) + context.AfterFunc(n.s.shutdownCtx, func() { tc.SetDeadline(time.Now()) }) hs := &http.Server{Handler: n.s.derps[0].handler} - go hs.Serve(netutil.NewOneConnListener(tc, nil)) + n.s.wg.Go(func() { + hs.Serve(netutil.NewOneConnListener(tc, nil)) + }) return } } if destPort == 443 && fakeLogCatcher.Match(destIP) { r.Complete(false) tc := gonet.NewTCPConn(&wq, ep) - go n.serveLogCatcherConn(clientRemoteIP, tc) + context.AfterFunc(n.s.shutdownCtx, func() { tc.SetDeadline(time.Now()) }) + n.s.wg.Go(func() { + n.serveLogCatcherConn(clientRemoteIP, tc) + }) return } if destPort == 80 && fakeCloudInit.Match(destIP) { r.Complete(false) tc := gonet.NewTCPConn(&wq, ep) + context.AfterFunc(n.s.shutdownCtx, func() { tc.SetDeadline(time.Now()) }) hs := &http.Server{Handler: n.s.cloudInitHandler()} - go hs.Serve(netutil.NewOneConnListener(tc, nil)) + n.s.wg.Go(func() { + hs.Serve(netutil.NewOneConnListener(tc, nil)) + }) return } if destPort == 80 && fakeFiles.Match(destIP) { r.Complete(false) tc := gonet.NewTCPConn(&wq, ep) + context.AfterFunc(n.s.shutdownCtx, func() { tc.SetDeadline(time.Now()) }) hs := &http.Server{Handler: n.s.fileServerHandler()} - go hs.Serve(netutil.NewOneConnListener(tc, nil)) + n.s.wg.Go(func() { + hs.Serve(netutil.NewOneConnListener(tc, nil)) + }) return } @@ -588,6 +606,9 @@ type network struct { // writers is a map of MAC -> networkWriters to write packets to that MAC. // It contains entries for connected nodes only. writers syncs.Map[MAC, networkWriter] // MAC -> to networkWriter for that MAC + + blackholeMu sync.Mutex + blackholeMap map[netip.Addr]netip.Addr // blackholeMap contains address pairs for dropping traffic (in either direction) } // registerWriter registers a client address with a MAC address. @@ -635,6 +656,19 @@ func (n *network) SetControlBlackholed(v bool) { n.blackholeControl = v } +// BlackholeControlForAddr sets up a map entry, ensuring that traffic to or from +// control from the addr is dropped. +func (n *network) BlackholeControlForAddr(addr netip.Addr) { + n.blackholeMu.Lock() + defer n.blackholeMu.Unlock() + + if addr.Is6() { + mak.Set(&n.blackholeMap, addr, fakeControl.v6) + } else { + mak.Set(&n.blackholeMap, addr, fakeControl.v4) + } +} + // nodeNIC represents a single network interface on a node. // For multi-homed nodes, additional NICs beyond the primary are stored in node.extraNICs. type nodeNIC struct { @@ -747,9 +781,14 @@ type Server struct { agentConnWaiter map[*node]chan<- struct{} // signaled after added to set agentConns set.Set[*agentConn] // not keyed by node; should be small/cheap enough to scan all agentDialer map[*node]netx.DialFunc + gotFirstPacket map[MAC]chan struct{} // closed on first packet from each MAC cloudInitData map[int]*CloudInitData // node num → cloud-init config fileContents map[string][]byte // filename → file bytes + + // onDHCPEvent, if non-nil, is called when DHCP messages are processed. + // Parameters are: source MAC, node number, DHCP message type, assigned IP. + onDHCPEvent func(nodeMAC MAC, nodeNum int, msgType layers.DHCPMsgType, assignedIP netip.Addr) } func (s *Server) logf(format string, args ...any) { @@ -764,6 +803,13 @@ func (s *Server) SetLoggerForTest(logf func(format string, args ...any)) { s.optLogf = logf } +// SetDHCPCallback registers a function to be called when DHCP messages are +// processed. The callback receives the source MAC, node number, DHCP message +// type (Discover, Offer, Request, Ack), and the assigned IP address. +func (s *Server) SetDHCPCallback(fn func(MAC, int, layers.DHCPMsgType, netip.Addr)) { + s.onDHCPEvent = fn +} + var derpMap = &tailcfg.DERPMap{ Regions: map[int]*tailcfg.DERPRegion{ 1: { @@ -825,6 +871,10 @@ func New(c *Config) (*Server, error) { if err := s.initFromConfig(c); err != nil { return nil, err } + s.gotFirstPacket = make(map[MAC]chan struct{}) + for mac := range s.nodeByMAC { + s.gotFirstPacket[mac] = make(chan struct{}) + } for n := range s.networks { if err := n.initStack(); err != nil { return nil, fmt.Errorf("newServer: initStack: %v", err) @@ -932,6 +982,22 @@ func (s *Server) Close() { s.wg.Wait() } +// AwaitFirstPacket waits until the first ethernet frame is received from the +// given MAC address, indicating the VM has booted far enough to send network +// traffic. It returns an error if the context expires first. +func (s *Server) AwaitFirstPacket(ctx context.Context, mac MAC) error { + ch, ok := s.gotFirstPacket[mac] + if !ok { + return fmt.Errorf("unknown MAC %v", mac) + } + select { + case <-ch: + return nil + case <-ctx.Done(): + return fmt.Errorf("no network packets received from %v: %w", mac, ctx.Err()) + } +} + // MACs returns the MAC addresses of the configured nodes. func (s *Server) MACs() iter.Seq[MAC] { return maps.Keys(s.nodeByMAC) @@ -1045,8 +1111,7 @@ func (s *Server) ServeUnixConn(uc *net.UnixConn, proto Protocol) { n, addr, err := uc.ReadFromUnix(buf) raddr = addr if err != nil { - if s.shutdownCtx.Err() != nil { - // Return without logging. + if s.shutdownCtx.Err() != nil || errors.Is(err, net.ErrClosed) { return } s.logf("ReadFromUnix: %#v", err) @@ -1088,6 +1153,13 @@ func (s *Server) ServeUnixConn(uc *net.UnixConn, proto Protocol) { } if !didReg[srcMAC] { didReg[srcMAC] = true + if ch, ok := s.gotFirstPacket[srcMAC]; ok { + select { + case <-ch: // already closed + default: + close(ch) + } + } srcNet := srcNode.netForMAC(srcMAC) if srcNet == nil { s.logf("[conn %p] node %v has no network for MAC %v", c.uc, srcNode, srcMAC) @@ -1133,6 +1205,23 @@ func (s *Server) handleEthernetFrameFromVM(packetRaw []byte) error { return nil } +// routeTCPPacket forwards a TCP packet to the network owning the +// destination IP (looked up by WAN IP). Used for inter-network TCP +// forwarding so guest VM TCP stacks talk end-to-end through vnet's +// packet-level NAT. +func (s *Server) routeTCPPacket(tp TCPPacket) { + dstIP := tp.Dst.Addr() + netw, ok := s.networkByWAN.Lookup(dstIP) + if !ok { + if dstIP.IsPrivate() { + return + } + log.Printf("no network to route TCP packet for %v", tp.Dst) + return + } + netw.HandleTCPPacket(tp) +} + func (s *Server) routeUDPPacket(up UDPPacket) { // Find which network owns this based on the destination IP // and all the known networks' wan IPs. @@ -1369,6 +1458,65 @@ func (n *network) nodeByIP(ip netip.Addr) (node *node, ok bool) { return node, ok } +// HandleTCPPacket handles a TCP packet arriving from the simulated +// internet, addressed to the network's WAN IP. It NATs the destination +// back to a LAN node and writes the rewritten packet onto the LAN. +func (n *network) HandleTCPPacket(p TCPPacket) { + buf, err := n.serializedTCPPacket(p.Src, p.Dst, p.TCP, nil) + if err != nil { + n.logf("serializing TCP packet: %v", err) + return + } + n.s.pcapWriter.WritePacket(gopacket.CaptureInfo{ + Timestamp: time.Now(), + CaptureLength: len(buf), + Length: len(buf), + InterfaceIndex: n.wanInterfaceID, + }, buf) + if p.Dst.Addr().Is4() && n.breakWAN4 { + return + } + dst := n.doNATIn(p.Src, p.Dst) + if !dst.IsValid() { + n.logf("Warning: NAT dropped TCP packet; no mapping for %v=>%v", p.Src, p.Dst) + return + } + p.Dst = dst + buf, err = n.serializedTCPPacket(p.Src, p.Dst, p.TCP, nil) + if err != nil { + n.logf("serializing TCP packet: %v", err) + return + } + n.s.pcapWriter.WritePacket(gopacket.CaptureInfo{ + Timestamp: time.Now(), + CaptureLength: len(buf), + Length: len(buf), + InterfaceIndex: n.lanInterfaceID, + }, buf) + n.WriteTCPPacketNoNAT(p) +} + +// WriteTCPPacketNoNAT writes a TCP packet to the network without doing +// any NAT translation. The src/dst in p must already be in their final +// form for the LAN. +func (n *network) WriteTCPPacketNoNAT(p TCPPacket) { + node, ok := n.nodeByIP(p.Dst.Addr()) + if !ok { + n.logf("no node for dest IP %v in TCP packet %v=>%v", p.Dst.Addr(), p.Src, p.Dst) + return + } + eth := &layers.Ethernet{ + SrcMAC: n.mac.HWAddr(), + DstMAC: node.macForNet(n).HWAddr(), + } + ethRaw, err := n.serializedTCPPacket(p.Src, p.Dst, p.TCP, eth) + if err != nil { + n.logf("serializing TCP packet: %v", err) + return + } + n.writeEth(ethRaw) +} + // WriteUDPPacketNoNAT writes a UDP packet to the network, without // doing any NAT translation. // @@ -1418,6 +1566,27 @@ func mkIPLayer(proto layers.IPProtocol, src, dst netip.Addr) serializableNetwork panic("invalid src IP") } +// serializedTCPPacket serializes a TCP packet with the given src/dst, +// using the provided TCP layer (its flags, seq/ack, window, options, +// and payload are preserved; only the src/dst ports are overwritten). +// +// If eth is non-nil, it is used as the Ethernet layer, otherwise the +// Ethernet layer is omitted. +func (n *network) serializedTCPPacket(src, dst netip.AddrPort, tcp *layers.TCP, eth *layers.Ethernet) ([]byte, error) { + ip := mkIPLayer(layers.IPProtocolTCP, src.Addr(), dst.Addr()) + // Copy the TCP layer with new ports and a zeroed checksum so + // gopacket recomputes it against the new IP pseudo-header. + newTCP := *tcp + newTCP.SrcPort = layers.TCPPort(src.Port()) + newTCP.DstPort = layers.TCPPort(dst.Port()) + newTCP.Checksum = 0 + payload := gopacket.Payload(tcp.Payload) + if eth == nil { + return mkPacket(ip, &newTCP, payload) + } + return mkPacket(eth, ip, &newTCP, payload) +} + // serializedUDPPacket serializes a UDP packet with the given source and // destination IP:port pairs, and payload. // @@ -1468,6 +1637,17 @@ func (n *network) HandleEthernetPacketForRouter(ep EthernetPacket) { // Blackhole the packet. return } + + // Drop traffic to/from address pairs in the blackholeMap. + n.blackholeMu.Lock() + defer n.blackholeMu.Unlock() + if src, ok := n.blackholeMap[flow.dst]; ok && flow.src == src { + return + } + if dst, ok := n.blackholeMap[flow.src]; ok && flow.dst == dst { + return + } + var base *layers.BaseLayer proto := header.IPv4ProtocolNumber if v4, ok := packet.Layer(layers.LayerTypeIPv4).(*layers.IPv4); ok { @@ -1489,6 +1669,19 @@ func (n *network) HandleEthernetPacketForRouter(ep EthernetPacket) { return } + // Inter-network TCP forwarding: a guest VM is sending TCP to another + // simulated network's WAN IP. Apply egress NAT (rewriting src) and + // hand the packet off to the destination network for ingress NAT and + // LAN delivery, so the two guest TCP stacks talk end-to-end. + if toForward && flow.dst.Is4() { + if tcp, ok := packet.Layer(layers.LayerTypeTCP).(*layers.TCP); ok { + if _, ok := n.s.networkByWAN.Lookup(flow.dst); ok { + n.handleTCPPacketForRouter(tcp, flow) + return + } + } + } + if flow.src.Is6() && flow.src.IsLinkLocalUnicast() && !flow.dst.IsLinkLocalUnicast() { // Don't log. return @@ -1503,6 +1696,54 @@ func (n *network) HandleEthernetPacketForRouter(ep EthernetPacket) { n.logf("router got unknown packet: %v", packet) } +// handleTCPPacketForRouter handles a TCP packet from a LAN node that +// targets another simulated network's WAN IP. It rewrites src via the +// local NAT, then routes the packet to the destination network where +// HandleTCPPacket rewrites dst and delivers it to the LAN. +func (n *network) handleTCPPacketForRouter(tcp *layers.TCP, flow ipSrcDst) { + if flow.dst.Is4() && n.breakWAN4 { + return + } + src := netip.AddrPortFrom(flow.src, uint16(tcp.SrcPort)) + dst := netip.AddrPortFrom(flow.dst, uint16(tcp.DstPort)) + + buf, err := n.serializedTCPPacket(src, dst, tcp, nil) + if err != nil { + n.logf("serializing TCP packet: %v", err) + return + } + n.s.pcapWriter.WritePacket(gopacket.CaptureInfo{ + Timestamp: time.Now(), + CaptureLength: len(buf), + Length: len(buf), + InterfaceIndex: n.lanInterfaceID, + }, buf) + + lanSrc := src + src = n.doNATOut(src, dst) + if !src.IsValid() { + n.logf("warning: NAT dropped TCP packet; no NAT out mapping for %v=>%v", lanSrc, dst) + return + } + buf, err = n.serializedTCPPacket(src, dst, tcp, nil) + if err != nil { + n.logf("serializing TCP packet: %v", err) + return + } + n.s.pcapWriter.WritePacket(gopacket.CaptureInfo{ + Timestamp: time.Now(), + CaptureLength: len(buf), + Length: len(buf), + InterfaceIndex: n.wanInterfaceID, + }, buf) + + n.s.routeTCPPacket(TCPPacket{ + Src: src, + Dst: dst, + TCP: tcp, + }) +} + func (n *network) handleUDPPacketForRouter(ep EthernetPacket, udp *layers.UDP, toForward bool, flow ipSrcDst) { packet := ep.gp srcIP, dstIP := flow.src, flow.dst @@ -1804,6 +2045,10 @@ func (s *Server) createDHCPResponse(request gopacket.Packet) ([]byte, error) { Length: 4, }, ) + if s.onDHCPEvent != nil { + s.onDHCPEvent(srcMAC, node.num, layers.DHCPMsgTypeDiscover, clientIP) + s.onDHCPEvent(srcMAC, node.num, layers.DHCPMsgTypeOffer, clientIP) + } case layers.DHCPMsgTypeRequest: response.Options = append(response.Options, layers.DHCPOption{ @@ -1832,6 +2077,10 @@ func (s *Server) createDHCPResponse(request gopacket.Packet) ([]byte, error) { Length: 4, }, ) + if s.onDHCPEvent != nil { + s.onDHCPEvent(srcMAC, node.num, layers.DHCPMsgTypeRequest, clientIP) + s.onDHCPEvent(srcMAC, node.num, layers.DHCPMsgTypeAck, clientIP) + } } eth := &layers.Ethernet{ @@ -1902,7 +2151,7 @@ func (s *Server) shouldInterceptTCP(pkt gopacket.Packet) bool { return true } } - if tcp.DstPort == 8008 && fakeTestAgent.Match(flow.dst) { + if tcp.DstPort == TestDriverPort && fakeTestAgent.Match(flow.dst) { // Connection from cmd/tta. return true } @@ -2292,6 +2541,17 @@ type UDPPacket struct { Payload []byte // everything after UDP header } +// TCPPacket is a TCP packet flowing through vnet's NAT, used for +// packet-level TCP forwarding between simulated networks. Unlike UDP +// (which only needs ports + payload), TCP carries flags, sequence +// numbers, and options that must be preserved end-to-end so the guest +// VM kernels' TCP state machines stay in sync. +type TCPPacket struct { + Src netip.AddrPort + Dst netip.AddrPort + TCP *layers.TCP // full parsed TCP layer (header + options + payload) +} + func (s *Server) WriteStartingBanner(w io.Writer) { fmt.Fprintf(w, "vnet serving clients:\n") @@ -2323,14 +2583,24 @@ func (s *Server) addIdleAgentConn(ac *agentConn) { func (s *Server) takeAgentConn(ctx context.Context, n *node) (_ *agentConn, ok bool) { const debug = false + // stuckThreshold is how long we wait before deciding the agent is slow + // enough to warrant a log line. Below this we stay quiet because, in + // healthy runs with many agent dials in flight, even a few-millisecond + // wait would otherwise log every poll for every concurrent waiter. + const stuckThreshold = 10 * time.Second + start := time.Now() + var lastWarn time.Time for { - ac, ok := s.takeAgentConnOne(n) - if ok { + ac, miss := s.takeAgentConnOne(n) + if ac != nil { if debug { log.Printf("takeAgentConn: got agent conn for %v", n.mac) } return ac, true } + if debug && miss > 0 { + log.Printf("takeAgentConnOne: missed %d times for %v", miss, n.mac) + } s.mu.Lock() ready := make(chan struct{}) mak.Set(&s.agentConnWaiter, n, ready) @@ -2339,6 +2609,10 @@ func (s *Server) takeAgentConn(ctx context.Context, n *node) (_ *agentConn, ok b if debug { log.Printf("takeAgentConn: waiting for agent conn for %v", n.mac) } + if elapsed := time.Since(start); elapsed > stuckThreshold && time.Since(lastWarn) > stuckThreshold { + log.Printf("takeAgentConn: still waiting for agent conn for %v after %v (%d idle conns for other nodes)", n.mac, elapsed.Round(time.Second), miss) + lastWarn = time.Now() + } select { case <-ctx.Done(): return nil, false @@ -2351,21 +2625,21 @@ func (s *Server) takeAgentConn(ctx context.Context, n *node) (_ *agentConn, ok b } } -func (s *Server) takeAgentConnOne(n *node) (_ *agentConn, ok bool) { +// takeAgentConnOne returns an idle agent conn for n if one is available, +// otherwise nil. miss is the number of idle agent conns for other nodes that +// were walked over while looking; the caller may use it for diagnostics when +// a wait drags on. +func (s *Server) takeAgentConnOne(n *node) (ac *agentConn, miss int) { s.mu.Lock() defer s.mu.Unlock() - miss := 0 for ac := range s.agentConns { if ac.node == n { s.agentConns.Delete(ac) - return ac, true + return ac, 0 } miss++ } - if miss > 0 { - log.Printf("takeAgentConnOne: missed %d times for %v", miss, n.mac) - } - return nil, false + return nil, miss } type NodeAgentClient struct { diff --git a/tstest/tailmac/Makefile b/tstest/tailmac/Makefile index b87e44ed1..303f72c1f 100644 --- a/tstest/tailmac/Makefile +++ b/tstest/tailmac/Makefile @@ -5,12 +5,12 @@ endif .PHONY: tailmac tailmac: - xcodebuild -scheme tailmac -destination 'platform=macOS,arch=arm64' -derivedDataPath build -configuration Release build | $(XCPRETTIFIER) + set -o pipefail && xcodebuild -scheme tailmac -destination 'platform=macOS,arch=arm64' -derivedDataPath build -configuration Release build | $(XCPRETTIFIER) cp -r ./build/Build/Products/Release/tailmac ./bin/tailmac .PHONY: host host: - xcodebuild -scheme host -destination 'platform=macOS,arch=arm64' -derivedDataPath build -configuration Release build | $(XCPRETTIFIER) + set -o pipefail && xcodebuild -scheme host -destination 'platform=macOS,arch=arm64' -derivedDataPath build -configuration Release build | $(XCPRETTIFIER) cp -r ./build/Build/Products/Release/Host.app ./bin/Host.app .PHONY: clean diff --git a/tstest/tailmac/Swift/Common/Config.swift b/tstest/tailmac/Swift/Common/Config.swift index 53d768020..53281628a 100644 --- a/tstest/tailmac/Swift/Common/Config.swift +++ b/tstest/tailmac/Swift/Common/Config.swift @@ -103,10 +103,10 @@ class Config: Codable { } -// The VM Bundle URL holds the restore image and a set of VM images -// By default, VM's are persisted at ~/VM.bundle +// The VM Bundle URL holds the restore image and a set of VM images. +// VMs are stored under ~/.cache/tailscale/vmtest/macos/. var vmBundleURL: URL = { - let vmBundlePath = NSHomeDirectory() + "/VM.bundle/" + let vmBundlePath = NSHomeDirectory() + "/.cache/tailscale/vmtest/macos/" createDir(vmBundlePath) let bundleURL = URL(fileURLWithPath: vmBundlePath) return bundleURL diff --git a/tstest/tailmac/Swift/Common/TailMacConfigHelper.swift b/tstest/tailmac/Swift/Common/TailMacConfigHelper.swift index fc7f2d89d..562eae1fa 100644 --- a/tstest/tailmac/Swift/Common/TailMacConfigHelper.swift +++ b/tstest/tailmac/Swift/Common/TailMacConfigHelper.swift @@ -74,18 +74,31 @@ struct TailMacConfigHelper { return networkDevice } + /// Creates a NIC configuration connected to the vnet dgram socket. func createSocketNetworkDeviceConfiguration() -> VZVirtioNetworkDeviceConfiguration { let networkDevice = VZVirtioNetworkDeviceConfiguration() networkDevice.macAddress = VZMACAddress(string: config.mac)! + if let attachment = createDgramAttachment(serverSocket: config.serverSocket, clientID: config.vmID) { + networkDevice.attachment = attachment + } + return networkDevice + } + /// Creates a NIC configuration with no attachment (disconnected). + /// The attachment can be hot-swapped later via VZNetworkDevice.attachment. + func createDisconnectedNetworkDeviceConfiguration() -> VZVirtioNetworkDeviceConfiguration { + let networkDevice = VZVirtioNetworkDeviceConfiguration() + networkDevice.macAddress = VZMACAddress(string: config.mac)! + // No attachment — NIC appears disconnected to the guest. + return networkDevice + } + + /// Creates a dgram socket attachment for connecting to a vnet server. + /// Returns nil on error. + func createDgramAttachment(serverSocket: String, clientID: String) -> VZFileHandleNetworkDeviceAttachment? { let socket = Darwin.socket(AF_UNIX, SOCK_DGRAM, 0) - // Outbound network packets - let serverSocket = config.serverSocket - - // Inbound network packets - let clientSockId = config.vmID - let clientSocket = "/tmp/qemu-dgram-\(clientSockId).sock" + let clientSocket = "/tmp/qemu-dgram-\(clientID).sock" unlink(clientSocket) var clientAddr = sockaddr_un() @@ -102,7 +115,7 @@ struct TailMacConfigHelper { if bindRes == -1 { print("Error binding virtual network client socket - \(String(cString: strerror(errno)))") - return networkDevice + return nil } var serverAddr = sockaddr_un() @@ -118,20 +131,16 @@ struct TailMacConfigHelper { socklen_t(MemoryLayout.size)) if connectRes == -1 { - print("Error binding virtual network server socket - \(String(cString: strerror(errno)))") - return networkDevice + print("Error connecting to server socket \(serverSocket) - \(String(cString: strerror(errno)))") + return nil } print("Virtual if mac address is \(config.mac)") print("Client bound to \(clientSocket)") print("Connected to server at \(serverSocket)") - print("Socket fd is \(socket)") - let handle = FileHandle(fileDescriptor: socket) - let device = VZFileHandleNetworkDeviceAttachment(fileHandle: handle) - networkDevice.attachment = device - return networkDevice + return VZFileHandleNetworkDeviceAttachment(fileHandle: handle) } func createPointingDeviceConfiguration() -> VZPointingDeviceConfiguration { diff --git a/tstest/tailmac/Swift/Host/HostCli.swift b/tstest/tailmac/Swift/Host/HostCli.swift index 9c9ae6fa0..16711b2aa 100644 --- a/tstest/tailmac/Swift/Host/HostCli.swift +++ b/tstest/tailmac/Swift/Host/HostCli.swift @@ -20,13 +20,291 @@ extension HostCli { struct Run: ParsableCommand { @Option var id: String @Option var share: String? + @Flag(help: "Run without GUI (for automated testing)") var headless: Bool = false + @Flag(help: "Create NIC with no attachment (for later hot-swap)") var disconnectedNic: Bool = false + @Flag(help: "Use NAT NIC instead of socket NIC (for snapshot prep)") var natNic: Bool = false + @Option(help: "Hot-swap NIC to this dgram socket path after boot/restore") var attachNetwork: String? + @Option(help: "Serve screenshots on this localhost port (0 = auto)") var screenshotPort: Int? + @Option(help: "Assign IP/mask/gw to guest via vsock (e.g. 192.168.1.2/255.255.255.0/192.168.1.1)") var assignIp: String? mutating func run() { config = Config(id) config.sharedDir = share print("Running vm with identifier \(id) and sharedDir \(share ?? "")") - _ = NSApplicationMain(CommandLine.argc, CommandLine.unsafeArgv) + + if headless { + let attachSocket = attachNetwork + let useNatNIC = natNic + let disconnected = !useNatNIC && (disconnectedNic || attachSocket != nil) + let wantScreenshots = screenshotPort != nil + let requestedPort = UInt16(screenshotPort ?? 0) + let ipConfig = assignIp + + // Set up SIGINT handler before entering the event loop. + // The dispatch source must be stored in a global to prevent ARC deallocation. + signal(SIGINT, SIG_IGN) + let sigintSource = DispatchSource.makeSignalSource(signal: SIGINT, queue: .main) + retainedSigintSource = sigintSource + + DispatchQueue.main.async { + let controller = VMController() + controller.createVirtualMachine(headless: true, disconnectedNIC: disconnected, natNIC: useNatNIC) + + // Start vsock listener for IP assignment. + // If --assign-ip is set, the listener replies with the IP config JSON. + // If not set (snapshot prep), it replies "wait" so TTA keeps polling. + if let ipCfg = ipConfig { + let parts = ipCfg.split(separator: "/") + if parts.count == 3 { + let response = "{\"ip\":\"\(parts[0])\",\"mask\":\"\(parts[1])\",\"gw\":\"\(parts[2])\"}" + controller.startIPConfigListener(response: response) + } + } else { + controller.startIPConfigListener(response: "wait") + } + + sigintSource.setEventHandler { + print("SIGINT received, disconnecting NIC and saving VM state...") + controller.disconnectNetwork() + controller.pauseAndSaveVirtualMachine { + print("VM state saved, exiting.") + Foundation.exit(0) + } + } + sigintSource.resume() + + // Set up screenshot HTTP server if requested. + // The window must be ordered on-screen for the window server + // to composite VZVirtualMachineView's content. We place it + // behind all other windows and make it tiny (1x1) so it's + // effectively invisible. + if wantScreenshots { + let vmView = VZVirtualMachineView() + vmView.virtualMachine = controller.virtualMachine + vmView.frame = NSRect(x: 0, y: 0, width: 1920, height: 1200) + + let window = NSWindow( + contentRect: NSRect(x: 0, y: 0, width: 1920, height: 1200), + styleMask: [.borderless], + backing: .buffered, + defer: false + ) + window.isReleasedWhenClosed = false + window.contentView = vmView + // Place behind all other windows so it's not visible to the user. + window.level = NSWindow.Level(rawValue: Int(CGWindowLevelForKey(.minimumWindow)) - 1) + window.orderFront(nil) + + startScreenshotServer(view: vmView, port: requestedPort) + } + + let doAttach = { + if let sock = attachSocket { + controller.attachNetwork(serverSocket: sock, clientID: config.vmID) + } + } + + let fileManager = FileManager.default + if fileManager.fileExists(atPath: config.saveFileURL.path) { + print("Restoring virtual machine state from \(config.saveFileURL)") + controller.restoreVirtualMachine() + doAttach() + } else { + print("Starting virtual machine") + controller.startVirtualMachine() + doAttach() + } + } + + if wantScreenshots { + // NSApp event loop needed for VZVirtualMachineView rendering. + let app = NSApplication.shared + app.setActivationPolicy(.accessory) + print("STARTING_NSAPP") + fflush(stdout) + app.run() + } else { + // Use dispatchMain() instead of RunLoop.main.run() so that + // GCD dispatch sources (like the SIGINT handler) are processed. + dispatchMain() + } + } else { + _ = NSApplicationMain(CommandLine.argc, CommandLine.unsafeArgv) + } } } } +// startScreenshotServer starts a localhost HTTP server that serves VM display +// screenshots on GET /screenshot as JPEG. The port is printed to stdout as +// "SCREENSHOT_PORT=" so the Go test harness can discover it. +var retainedSigintSource: DispatchSourceSignal? // prevent ARC deallocation +var screenshotServer: ScreenshotHTTPServer? // prevent GC + +func startScreenshotServer(view: NSView, port: UInt16) { + let server = ScreenshotHTTPServer(view: view) + screenshotServer = server + server.start(port: port) +} + +/// Minimal HTTP server that serves screenshots of a VZVirtualMachineView. +class ScreenshotHTTPServer: NSObject { + let view: NSView + var acceptSource: DispatchSourceRead? // prevent GC + + init(view: NSView) { + self.view = view + } + + private func log(_ msg: String) { + let s = msg + "\n" + FileHandle.standardError.write(Data(s.utf8)) + } + + func start(port: UInt16) { + let queue = DispatchQueue(label: "screenshot-server") + + let fd = socket(AF_INET, SOCK_STREAM, 0) + guard fd >= 0 else { + log("screenshot server: socket() failed") + return + } + var yes: Int32 = 1 + setsockopt(fd, SOL_SOCKET, SO_REUSEADDR, &yes, socklen_t(MemoryLayout.size)) + + var addr = sockaddr_in() + addr.sin_len = UInt8(MemoryLayout.size) + addr.sin_family = sa_family_t(AF_INET) + addr.sin_port = port.bigEndian + addr.sin_addr.s_addr = UInt32(0x7f000001).bigEndian // 127.0.0.1 + + let bindResult = withUnsafePointer(to: &addr) { ptr in + ptr.withMemoryRebound(to: sockaddr.self, capacity: 1) { sockPtr in + Darwin.bind(fd, sockPtr, socklen_t(MemoryLayout.size)) + } + } + guard bindResult == 0 else { + log("screenshot server: bind() failed: \(errno)") + close(fd) + return + } + guard Darwin.listen(fd, 4) == 0 else { + log("screenshot server: listen() failed") + close(fd) + return + } + + var boundAddr = sockaddr_in() + var boundLen = socklen_t(MemoryLayout.size) + withUnsafeMutablePointer(to: &boundAddr) { ptr in + ptr.withMemoryRebound(to: sockaddr.self, capacity: 1) { sockPtr in + getsockname(fd, sockPtr, &boundLen) + } + } + let actualPort = UInt16(bigEndian: boundAddr.sin_port) + print("SCREENSHOT_PORT=\(actualPort)") + fflush(stdout) + + let source = DispatchSource.makeReadSource(fileDescriptor: fd, queue: queue) + source.setEventHandler { [self] in + let clientFd = accept(fd, nil, nil) + self.log("screenshot: accept fd=\(clientFd)") + guard clientFd >= 0 else { return } + self.handleConnection(clientFd) + } + source.setCancelHandler { close(fd) } + source.resume() + self.acceptSource = source + } + + private func handleConnection(_ fd: Int32) { + var buf = [UInt8](repeating: 0, count: 4096) + let n = read(fd, &buf, buf.count) + let requestLine = n > 0 ? String(bytes: buf[.. Data? { + guard let window = view.window else { + log("screenshot: no window") + return nil + } + + // Use CGWindowListCreateImage to capture the composited window content, + // which includes GPU-rendered layers like VZVirtualMachineView's Metal surface. + let windowID = CGWindowID(window.windowNumber) + guard let cgImage = CGWindowListCreateImage( + .null, + .optionIncludingWindow, + windowID, + [.boundsIgnoreFraming, .bestResolution] + ) else { + log("screenshot: CGWindowListCreateImage returned nil") + return nil + } + + if fullSize { + let bitmapRep = NSBitmapImageRep(cgImage: cgImage) + return bitmapRep.representation(using: .jpeg, properties: [.compressionFactor: 0.85]) + } + + // Resize to ~800px wide for thumbnails. + let targetWidth = 800 + let scale = Double(targetWidth) / Double(cgImage.width) + let targetHeight = Int(Double(cgImage.height) * scale) + + guard let ctx = CGContext( + data: nil, + width: targetWidth, + height: targetHeight, + bitsPerComponent: 8, + bytesPerRow: 0, + space: CGColorSpaceCreateDeviceRGB(), + bitmapInfo: CGImageAlphaInfo.premultipliedFirst.rawValue + ) else { + log("screenshot: CGContext creation failed") + return nil + } + ctx.interpolationQuality = .high + ctx.draw(cgImage, in: CGRect(x: 0, y: 0, width: targetWidth, height: targetHeight)) + + guard let resized = ctx.makeImage() else { + log("screenshot: makeImage failed") + return nil + } + + let bitmapRep = NSBitmapImageRep(cgImage: resized) + return bitmapRep.representation(using: .jpeg, properties: [.compressionFactor: 0.6]) + } +} + diff --git a/tstest/tailmac/Swift/Host/VMController.swift b/tstest/tailmac/Swift/Host/VMController.swift index a19d7222e..c2014009a 100644 --- a/tstest/tailmac/Swift/Host/VMController.swift +++ b/tstest/tailmac/Swift/Host/VMController.swift @@ -81,7 +81,7 @@ class VMController: NSObject, VZVirtualMachineDelegate { return macPlatform } - func createVirtualMachine() { + func createVirtualMachine(headless: Bool = false, disconnectedNIC: Bool = false, natNIC: Bool = false) { let virtualMachineConfiguration = VZVirtualMachineConfiguration() virtualMachineConfiguration.platform = createMacPlaform() @@ -90,7 +90,21 @@ class VMController: NSObject, VZVirtualMachineDelegate { virtualMachineConfiguration.memorySize = helper.computeMemorySize() virtualMachineConfiguration.graphicsDevices = [helper.createGraphicsDeviceConfiguration()] virtualMachineConfiguration.storageDevices = [helper.createBlockDeviceConfiguration()] - virtualMachineConfiguration.networkDevices = [helper.createNetworkDeviceConfiguration(), helper.createSocketNetworkDeviceConfiguration()] + if headless { + if natNIC { + // NAT NIC for SSH access during snapshot preparation. + virtualMachineConfiguration.networkDevices = [helper.createNetworkDeviceConfiguration()] + } else if disconnectedNIC { + // Create a NIC with no attachment. The NIC exists in the hardware + // config (so saved state is compatible) but appears disconnected. + // Call attachNetwork() after restore to hot-swap the attachment. + virtualMachineConfiguration.networkDevices = [helper.createDisconnectedNetworkDeviceConfiguration()] + } else { + virtualMachineConfiguration.networkDevices = [helper.createSocketNetworkDeviceConfiguration()] + } + } else { + virtualMachineConfiguration.networkDevices = [helper.createNetworkDeviceConfiguration(), helper.createSocketNetworkDeviceConfiguration()] + } virtualMachineConfiguration.pointingDevices = [helper.createPointingDeviceConfiguration()] virtualMachineConfiguration.keyboards = [helper.createKeyboardConfiguration()] virtualMachineConfiguration.socketDevices = [helper.createSocketDeviceConfiguration()] @@ -109,6 +123,33 @@ class VMController: NSObject, VZVirtualMachineDelegate { virtualMachine.delegate = self } + /// Disconnect the NIC by setting its attachment to nil. + /// Call before saving state so the snapshot has no active link. + func disconnectNetwork() { + guard let nic = virtualMachine.networkDevices.first else { + print("disconnectNetwork: no network devices") + return + } + nic.attachment = nil + print("disconnectNetwork: NIC attachment set to nil") + } + + /// Hot-swap the NIC attachment on a running VM. The VM must have been + /// created with disconnectedNIC=true. After calling this, the guest + /// sees the link come up and does DHCP. + func attachNetwork(serverSocket: String, clientID: String) { + guard let nic = virtualMachine.networkDevices.first else { + print("attachNetwork: no network devices") + return + } + guard let attachment = helper.createDgramAttachment(serverSocket: serverSocket, clientID: clientID) else { + print("attachNetwork: failed to create attachment") + return + } + nic.attachment = attachment + print("attachNetwork: NIC attachment swapped to \(serverSocket)") + } + func startVirtualMachine() { virtualMachine.start(completionHandler: { (result) in @@ -130,6 +171,21 @@ class VMController: NSObject, VZVirtualMachineDelegate { } } + /// Start a vsock listener that tells the guest TTA agent what IP to configure. + /// If response is nil, the listener replies "wait" (snapshot prep mode). + func startIPConfigListener(response: String) { + guard let device = virtualMachine.socketDevices.first as? VZVirtioSocketDevice else { + print("startIPConfigListener: no socket device") + return + } + let listener = IPConfigListener(response: response) + retainedIPConfigListener = listener + let vsockListener = VZVirtioSocketListener() + vsockListener.delegate = listener + device.setSocketListener(vsockListener, forPort: 51011) + print("startIPConfigListener: listening on vsock port 51011") + } + func resumeVirtualMachine() { virtualMachine.resume(completionHandler: { (result) in if case let .failure(error) = result { @@ -184,3 +240,28 @@ class VMController: NSObject, VZVirtualMachineDelegate { exit(0) } } + +// Global to prevent ARC deallocation of the vsock listener. +var retainedIPConfigListener: IPConfigListener? + +/// Listens on vsock port 51011 for TTA connections and replies with +/// an IP configuration JSON string (or "wait" during snapshot prep). +class IPConfigListener: NSObject, VZVirtioSocketListenerDelegate { + let response: String + + init(response: String) { + self.response = response + } + + func listener(_ listener: VZVirtioSocketListener, + shouldAcceptNewConnection connection: VZVirtioSocketConnection, + from socketDevice: VZVirtioSocketDevice) -> Bool { + let fd = connection.fileDescriptor + let data = Array((response + "\n").utf8) + data.withUnsafeBufferPointer { buf in + _ = write(fd, buf.baseAddress!, buf.count) + } + connection.close() + return true + } +} diff --git a/tstest/tailmac/Swift/TailMac/TailMac.swift b/tstest/tailmac/Swift/TailMac/TailMac.swift index 3859b9b0b..2271d3bb2 100644 --- a/tstest/tailmac/Swift/TailMac/TailMac.swift +++ b/tstest/tailmac/Swift/TailMac/TailMac.swift @@ -329,7 +329,7 @@ extension Tailmac { } } - dispatchMain() + RunLoop.main.run() } } } diff --git a/tstest/tstest.go b/tstest/tstest.go index 4e00fbaa3..7e25ce8a0 100644 --- a/tstest/tstest.go +++ b/tstest/tstest.go @@ -20,8 +20,22 @@ import ( "tailscale.com/util/cibuild" ) +// AssertNotParallel asserts that t has not been marked as parallel. +// It panics (via t.Setenv) if t.Parallel has already been called. +// +// Use this when a test modifies package-level globals or other shared +// state that would be unsafe to modify concurrently with other tests. +func AssertNotParallel(t testing.TB) { + t.Helper() + t.Setenv("ASSERT_NOT_PARALLEL_TEST", "1") // panics if t.Parallel was called +} + // Replace replaces the value of target with val. // The old value is restored when the test ends. +// +// When target is a package-level variable, the caller should also call +// [AssertNotParallel] to ensure the test is not running in parallel with +// other tests that may access the same variable. func Replace[T any](t testing.TB, target *T, val T) { t.Helper() if target == nil { @@ -95,6 +109,14 @@ func Parallel(t *testing.T) { } } +// RequireRoot skips the test if the current user is not root. +func RequireRoot(tb testing.TB) { + tb.Helper() + if os.Getuid() != 0 { + tb.Skip("skipping test; requires root") + } +} + // SkipOnKernelVersions skips the test if the current // kernel version is in the specified list. func SkipOnKernelVersions(t testing.TB, issue string, versions ...string) { diff --git a/tsweb/varz/varz_test.go b/tsweb/varz/varz_test.go index d041edb4b..27094e77b 100644 --- a/tsweb/varz/varz_test.go +++ b/tsweb/varz/varz_test.go @@ -205,7 +205,7 @@ func TestVarzHandler(t *testing.T) { "string_map", func() *expvar.Map { m := new(expvar.Map) - m.Set("a", expvar.NewString("foo")) + m.Set("a", new(expvar.String)) return m }(), "# skipping \"string_map\" expvar map key \"a\" with unknown value type *expvar.String\n", diff --git a/types/key/nl.go b/types/key/nl.go index 0e8c5ed96..32bc94364 100644 --- a/types/key/nl.go +++ b/types/key/nl.go @@ -29,7 +29,7 @@ const ( nlPublicHexPrefixCLI = "tlpub:" ) -// NLPrivate is a node-managed network-lock key, used for signing +// NLPrivate is a node-managed tailnet-lock key, used for signing // node-key signatures and authority update messages. type NLPrivate struct { _ structs.Incomparable // because == isn't constant-time @@ -42,7 +42,7 @@ func (k NLPrivate) IsZero() bool { return subtle.ConstantTimeCompare(k.k[:], empty.k[:]) == 1 } -// NewNLPrivate creates and returns a new network-lock key. +// NewNLPrivate creates and returns a new tailnet-lock key. func NewNLPrivate() NLPrivate { // ed25519.GenerateKey 'clamps' the key, not that it // matters given we don't do Diffie-Hellman. @@ -120,7 +120,7 @@ type NLPublic struct { // a type of NLPublic. // // New uses of this function should be avoided, as it's possible to -// accidentally construct an NLPublic from a non network-lock key. +// accidentally construct an NLPublic from a non tailnet-lock key. func NLPublicFromEd25519Unsafe(public ed25519.PublicKey) NLPublic { var out NLPublic copy(out.k[:], public) diff --git a/types/key/node.go b/types/key/node.go index 98f72c719..a1d8e47ba 100644 --- a/types/key/node.go +++ b/types/key/node.go @@ -65,6 +65,11 @@ func NewNode() NodePrivate { // Raw32 returns k as 32 raw bytes. func (k NodePrivate) Raw32() [32]byte { return k.k } +// NodePrivateAs returns a NodePrivate as a named fixed-size array of bytes. +// It's intended for interoperability with wireguard-go's +// device.NoisePrivateKey type. +func NodePrivateAs[T ~[32]byte](k NodePrivate) T { return k.k } + // NodePrivateFromRaw32 parses a 32-byte raw value as a NodePrivate. // // Deprecated: only needed to cast from legacy node private key types, diff --git a/types/netmap/netmap.go b/types/netmap/netmap.go index ac95254da..fbf415be0 100644 --- a/types/netmap/netmap.go +++ b/types/netmap/netmap.go @@ -146,6 +146,34 @@ func (nm *NetworkMap) GetIPVIPServiceMap() IPServiceMappings { return res } +// Services returns the Services visible (accessible) to this node, +// decoded from [tailcfg.NodeAttrPrefixServices] entries in the self node's +// CapMap. The returned map is keyed by [tailcfg.ServiceDetails.Name], +// which is the canonical service name. It returns nil if nm is nil +// or SelfNode is invalid. +// +// TODO(adrianosela): cache the result of decoding the capmap so +// we don't have to decode it multiple times after each netmap update. +func (nm *NetworkMap) Services() map[tailcfg.ServiceName]tailcfg.ServiceDetails { + if nm == nil || !nm.SelfNode.Valid() { + return nil + } + result := make(map[tailcfg.ServiceName]tailcfg.ServiceDetails) + for cap := range nm.SelfNode.CapMap().All() { + if !strings.HasPrefix(string(cap), string(tailcfg.NodeAttrPrefixServices)) { + continue + } + svcs, err := tailcfg.UnmarshalNodeCapViewJSON[tailcfg.ServiceDetails](nm.SelfNode.CapMap(), cap) + if err != nil || len(svcs) < 1 { + continue + } + // NOTE(adrianosela): the NodeCapMap key suffix is opaque and MUST not + // be parsed or relied upon (so we extract name from the inner field). + result[svcs[0].Name] = svcs[0] + } + return result +} + // SelfNodeOrZero returns the self node, or a zero value if nm is nil. func (nm *NetworkMap) SelfNodeOrZero() tailcfg.NodeView { if nm == nil { @@ -284,13 +312,6 @@ func (nm *NetworkMap) TailnetDisplayName() string { return tailnetDisplayNames[0] } -// HasSelfCapability reports whether nm.SelfNode contains capability c. -// -// It exists to satisify an unused (as of 2025-01-04) interface in the logknob package. -func (nm *NetworkMap) HasSelfCapability(c tailcfg.NodeCapability) bool { - return nm.AllCaps.Contains(c) -} - func (nm *NetworkMap) String() string { return nm.Concise() } diff --git a/update-flake.sh b/update-flake.sh deleted file mode 100755 index c22572b86..000000000 --- a/update-flake.sh +++ /dev/null @@ -1,26 +0,0 @@ -#!/bin/sh -# Updates SRI hashes for flake.nix. - -set -eu - -OUT=$(mktemp -d -t nar-hash-XXXXXX) -rm -rf "$OUT" - -./tool/go mod vendor -o "$OUT" -./tool/go run tailscale.com/cmd/nardump --sri "$OUT" >go.mod.sri -rm -rf "$OUT" - -GOOUT=$(mktemp -d -t gocross-XXXXXX) -GOREV=$(xargs < ./go.toolchain.rev) -TARBALL="$GOOUT/go-$GOREV.tar.gz" -curl -Ls -o "$TARBALL" "https://github.com/tailscale/go/archive/$GOREV.tar.gz" -tar -xzf "$TARBALL" -C "$GOOUT" -./tool/go run tailscale.com/cmd/nardump --sri "$GOOUT/go-$GOREV" > go.toolchain.rev.sri -rm -rf "$GOOUT" - -# nix-direnv only watches the top-level nix file for changes. As a -# result, when we change a referenced SRI file, we have to cause some -# change to shell.nix and flake.nix as well, so that nix-direnv -# notices and reevaluates everything. Sigh. -perl -pi -e "s,# nix-direnv cache busting line:.*,# nix-direnv cache busting line: $(cat go.mod.sri)," shell.nix -perl -pi -e "s,# nix-direnv cache busting line:.*,# nix-direnv cache busting line: $(cat go.mod.sri)," flake.nix diff --git a/util/cibuild/cibuild.go b/util/cibuild/cibuild.go index 4a4e241ac..a821862b0 100644 --- a/util/cibuild/cibuild.go +++ b/util/cibuild/cibuild.go @@ -12,3 +12,15 @@ func On() bool { // https://docs.github.com/en/actions/learn-github-actions/environment-variables#default-environment-variables return os.Getenv("GITHUB_ACTIONS") != "" || os.Getenv("CI") == "true" } + +// OnTailscaleCI reports whether the current binary is executing on +// tailscale/tailscale's own GitHub Actions CI, as opposed to a fork's CI +// or an unrelated downstream CI (such as a Linux distribution's package +// build infrastructure) that also sets the generic CI=true environment +// variable. +func OnTailscaleCI() bool { + // GITHUB_REPOSITORY_OWNER is set by GitHub Actions to the owner of + // the repository whose workflow is running. For pull requests, this + // is the base repository's owner, not the fork's. + return os.Getenv("GITHUB_REPOSITORY_OWNER") == "tailscale" +} diff --git a/util/clientmetric/clientmetric.go b/util/clientmetric/clientmetric.go index b67cbbd39..98068b9fa 100644 --- a/util/clientmetric/clientmetric.go +++ b/util/clientmetric/clientmetric.go @@ -22,6 +22,7 @@ import ( "tailscale.com/feature/buildfeatures" "tailscale.com/util/set" + "tailscale.com/util/testenv" ) var ( @@ -452,6 +453,24 @@ func (b *deltaEncBuf) writeHexVarint(v int64) { b.buf.Write(hexBuf) } +// ResetForTest resets all client metric values to zero. +// It panics if not in a test or if called from a parallel test. +func ResetForTest(t testenv.TB) { + if !testenv.InTest() { + panic("clientmetric.ResetForTest called outside a test") + } + if testenv.InParallelTest(t) { + panic("clientmetric.ResetForTest called from a parallel test") + } + mu.Lock() + defer mu.Unlock() + for _, m := range metrics { + if m.v != nil { + atomic.StoreInt64(m.v, 0) + } + } +} + var TestHooks testHooks type testHooks struct{} diff --git a/util/clientmetric/omit.go b/util/clientmetric/omit.go index 380205eeb..74018f12a 100644 --- a/util/clientmetric/omit.go +++ b/util/clientmetric/omit.go @@ -30,3 +30,5 @@ func NewCounter(string) *Metric { return &zeroMetric } func NewGauge(string) *Metric { return &zeroMetric } func NewAggregateCounter(string) *Metric { return &zeroMetric } func NewCounterFunc(string, func() int64) *Metric { return &zeroMetric } + +func ResetForTest(any) {} diff --git a/util/cstruct/cstruct.go b/util/cstruct/cstruct.go deleted file mode 100644 index afb0150bb..000000000 --- a/util/cstruct/cstruct.go +++ /dev/null @@ -1,177 +0,0 @@ -// Copyright (c) Tailscale Inc & contributors -// SPDX-License-Identifier: BSD-3-Clause - -// Package cstruct provides a helper for decoding binary data that is in the -// form of a padded C structure. -package cstruct - -import ( - "encoding/binary" - "errors" - "io" -) - -// Size of a pointer-typed value, in bits -const pointerSize = 32 << (^uintptr(0) >> 63) - -// We assume that non-64-bit platforms are 32-bit; we don't expect Go to run on -// a 16- or 8-bit architecture any time soon. -const is64Bit = pointerSize == 64 - -// Decoder reads and decodes padded fields from a slice of bytes. All fields -// are decoded with native endianness. -// -// Methods of a Decoder do not return errors, but rather store any error within -// the Decoder. The first error can be obtained via the Err method; after the -// first error, methods will return the zero value for their type. -type Decoder struct { - b []byte - off int - err error - dbuf [8]byte // for decoding -} - -// NewDecoder creates a Decoder from a byte slice. -func NewDecoder(b []byte) *Decoder { - return &Decoder{b: b} -} - -var errUnsupportedSize = errors.New("unsupported size") - -func padBytes(offset, size int) int { - if offset == 0 || size == 1 { - return 0 - } - remainder := offset % size - return size - remainder -} - -func (d *Decoder) getField(b []byte) error { - size := len(b) - - // We only support fields that are multiples of 2 (or 1-sized) - if size != 1 && size&1 == 1 { - return errUnsupportedSize - } - - // Fields are aligned to their size - padBytes := padBytes(d.off, size) - if d.off+size+padBytes > len(d.b) { - return io.EOF - } - d.off += padBytes - - copy(b, d.b[d.off:d.off+size]) - d.off += size - return nil -} - -// Err returns the first error that was encountered by this Decoder. -func (d *Decoder) Err() error { - return d.err -} - -// Offset returns the current read offset for data in the buffer. -func (d *Decoder) Offset() int { - return d.off -} - -// Byte returns a single byte from the buffer. -func (d *Decoder) Byte() byte { - if d.err != nil { - return 0 - } - - if err := d.getField(d.dbuf[0:1]); err != nil { - d.err = err - return 0 - } - return d.dbuf[0] -} - -// Byte returns a number of bytes from the buffer based on the size of the -// input slice. No padding is applied. -// -// If an error is encountered or this Decoder has previously encountered an -// error, no changes are made to the provided buffer. -func (d *Decoder) Bytes(b []byte) { - if d.err != nil { - return - } - - // No padding for byte slices - size := len(b) - if d.off+size >= len(d.b) { - d.err = io.EOF - return - } - copy(b, d.b[d.off:d.off+size]) - d.off += size -} - -// Uint16 returns a uint16 decoded from the buffer. -func (d *Decoder) Uint16() uint16 { - if d.err != nil { - return 0 - } - - if err := d.getField(d.dbuf[0:2]); err != nil { - d.err = err - return 0 - } - return binary.NativeEndian.Uint16(d.dbuf[0:2]) -} - -// Uint32 returns a uint32 decoded from the buffer. -func (d *Decoder) Uint32() uint32 { - if d.err != nil { - return 0 - } - - if err := d.getField(d.dbuf[0:4]); err != nil { - d.err = err - return 0 - } - return binary.NativeEndian.Uint32(d.dbuf[0:4]) -} - -// Uint64 returns a uint64 decoded from the buffer. -func (d *Decoder) Uint64() uint64 { - if d.err != nil { - return 0 - } - - if err := d.getField(d.dbuf[0:8]); err != nil { - d.err = err - return 0 - } - return binary.NativeEndian.Uint64(d.dbuf[0:8]) -} - -// Uintptr returns a uintptr decoded from the buffer. -func (d *Decoder) Uintptr() uintptr { - if d.err != nil { - return 0 - } - - if is64Bit { - return uintptr(d.Uint64()) - } else { - return uintptr(d.Uint32()) - } -} - -// Int16 returns a int16 decoded from the buffer. -func (d *Decoder) Int16() int16 { - return int16(d.Uint16()) -} - -// Int32 returns a int32 decoded from the buffer. -func (d *Decoder) Int32() int32 { - return int32(d.Uint32()) -} - -// Int64 returns a int64 decoded from the buffer. -func (d *Decoder) Int64() int64 { - return int64(d.Uint64()) -} diff --git a/util/cstruct/cstruct_example_test.go b/util/cstruct/cstruct_example_test.go deleted file mode 100644 index a665abe35..000000000 --- a/util/cstruct/cstruct_example_test.go +++ /dev/null @@ -1,73 +0,0 @@ -// Copyright (c) Tailscale Inc & contributors -// SPDX-License-Identifier: BSD-3-Clause - -// Only built on 64-bit platforms to avoid complexity - -//go:build amd64 || arm64 || mips64le || ppc64le || riscv64 - -package cstruct - -import "fmt" - -// This test provides a semi-realistic example of how you can -// use this package to decode a C structure. -func ExampleDecoder() { - // Our example C structure: - // struct mystruct { - // char *p; - // char c; - // /* implicit: char _pad[3]; */ - // int x; - // }; - // - // The Go structure definition: - type myStruct struct { - Ptr uintptr - Ch byte - Intval uint32 - } - - // Our "in-memory" version of the above structure - buf := []byte{ - 1, 2, 3, 4, 0, 0, 0, 0, // ptr - 5, // ch - 99, 99, 99, // padding - 78, 6, 0, 0, // x - } - d := NewDecoder(buf) - - // Decode the structure; if one of these function returns an error, - // then subsequent decoder functions will return the zero value. - var x myStruct - x.Ptr = d.Uintptr() - x.Ch = d.Byte() - x.Intval = d.Uint32() - - // Note that per the Go language spec: - // [...] when evaluating the operands of an expression, assignment, - // or return statement, all function calls, method calls, and - // (channel) communication operations are evaluated in lexical - // left-to-right order - // - // Since each field is assigned via a function call, one could use the - // following snippet to decode the struct. - // x := myStruct{ - // Ptr: d.Uintptr(), - // Ch: d.Byte(), - // Intval: d.Uint32(), - // } - // - // However, this means that reordering the fields in the initialization - // statement–normally a semantically identical operation–would change - // the way the structure is parsed. Thus we do it as above with - // explicit ordering. - - // After finishing with the decoder, check errors - if err := d.Err(); err != nil { - panic(err) - } - - // Print the decoder offset and structure - fmt.Printf("off=%d struct=%#v\n", d.Offset(), x) - // Output: off=16 struct=cstruct.myStruct{Ptr:0x4030201, Ch:0x5, Intval:0x64e} -} diff --git a/util/cstruct/cstruct_test.go b/util/cstruct/cstruct_test.go deleted file mode 100644 index 95d4876ca..000000000 --- a/util/cstruct/cstruct_test.go +++ /dev/null @@ -1,150 +0,0 @@ -// Copyright (c) Tailscale Inc & contributors -// SPDX-License-Identifier: BSD-3-Clause - -package cstruct - -import ( - "errors" - "fmt" - "io" - "testing" -) - -func TestPadBytes(t *testing.T) { - testCases := []struct { - offset int - size int - want int - }{ - // No padding at beginning of structure - {0, 1, 0}, - {0, 2, 0}, - {0, 4, 0}, - {0, 8, 0}, - - // No padding for single bytes - {1, 1, 0}, - - // Single byte padding - {1, 2, 1}, - {3, 4, 1}, - - // Multi-byte padding - {1, 4, 3}, - {2, 8, 6}, - } - for _, tc := range testCases { - t.Run(fmt.Sprintf("%d_%d_%d", tc.offset, tc.size, tc.want), func(t *testing.T) { - got := padBytes(tc.offset, tc.size) - if got != tc.want { - t.Errorf("got=%d; want=%d", got, tc.want) - } - }) - } -} - -func TestDecoder(t *testing.T) { - t.Run("UnsignedTypes", func(t *testing.T) { - dec := func(n int) *Decoder { - buf := make([]byte, n) - buf[0] = 1 - - d := NewDecoder(buf) - - // Use t.Cleanup to perform an assertion on this - // decoder after the test code is finished with it. - t.Cleanup(func() { - if err := d.Err(); err != nil { - t.Fatal(err) - } - }) - return d - } - if got := dec(2).Uint16(); got != 1 { - t.Errorf("uint16: got=%d; want=1", got) - } - if got := dec(4).Uint32(); got != 1 { - t.Errorf("uint32: got=%d; want=1", got) - } - if got := dec(8).Uint64(); got != 1 { - t.Errorf("uint64: got=%d; want=1", got) - } - if got := dec(pointerSize / 8).Uintptr(); got != 1 { - t.Errorf("uintptr: got=%d; want=1", got) - } - }) - - t.Run("SignedTypes", func(t *testing.T) { - dec := func(n int) *Decoder { - // Make a buffer of the exact size that consists of 0xff bytes - buf := make([]byte, n) - for i := range n { - buf[i] = 0xff - } - - d := NewDecoder(buf) - - // Use t.Cleanup to perform an assertion on this - // decoder after the test code is finished with it. - t.Cleanup(func() { - if err := d.Err(); err != nil { - t.Fatal(err) - } - }) - return d - } - if got := dec(2).Int16(); got != -1 { - t.Errorf("int16: got=%d; want=-1", got) - } - if got := dec(4).Int32(); got != -1 { - t.Errorf("int32: got=%d; want=-1", got) - } - if got := dec(8).Int64(); got != -1 { - t.Errorf("int64: got=%d; want=-1", got) - } - }) - - t.Run("InsufficientData", func(t *testing.T) { - dec := func(n int) *Decoder { - // Make a buffer that's too small and contains arbitrary bytes - buf := make([]byte, n-1) - for i := range n - 1 { - buf[i] = 0xAD - } - - // Use t.Cleanup to perform an assertion on this - // decoder after the test code is finished with it. - d := NewDecoder(buf) - t.Cleanup(func() { - if err := d.Err(); err == nil || !errors.Is(err, io.EOF) { - t.Errorf("(n=%d) expected io.EOF; got=%v", n, err) - } - }) - return d - } - - dec(2).Uint16() - dec(4).Uint32() - dec(8).Uint64() - dec(pointerSize / 8).Uintptr() - - dec(2).Int16() - dec(4).Int32() - dec(8).Int64() - }) - - t.Run("Bytes", func(t *testing.T) { - d := NewDecoder([]byte("hello worldasdf")) - t.Cleanup(func() { - if err := d.Err(); err != nil { - t.Fatal(err) - } - }) - - buf := make([]byte, 11) - d.Bytes(buf) - if got := string(buf); got != "hello world" { - t.Errorf("bytes: got=%q; want=%q", got, "hello world") - } - }) -} diff --git a/util/deephash/deephash_test.go b/util/deephash/deephash_test.go index 7ea83566f..a82203d50 100644 --- a/util/deephash/deephash_test.go +++ b/util/deephash/deephash_test.go @@ -59,7 +59,7 @@ func TestHash(t *testing.T) { I16 int16 I32 int32 I64 int64 - I int + Int int U8 uint8 U16 uint16 U32 uint32 @@ -92,7 +92,7 @@ func TestHash(t *testing.T) { {in: tuple{scalars{I16: math.MinInt16}, scalars{I16: math.MinInt16 / 2}}, wantEq: false}, {in: tuple{scalars{I32: math.MinInt32}, scalars{I32: math.MinInt32 / 2}}, wantEq: false}, {in: tuple{scalars{I64: math.MinInt64}, scalars{I64: math.MinInt64 / 2}}, wantEq: false}, - {in: tuple{scalars{I: -1234}, scalars{I: -1234 / 2}}, wantEq: false}, + {in: tuple{scalars{Int: -1234}, scalars{Int: -1234 / 2}}, wantEq: false}, {in: tuple{scalars{U8: math.MaxUint8}, scalars{U8: math.MaxUint8 / 2}}, wantEq: false}, {in: tuple{scalars{U16: math.MaxUint16}, scalars{U16: math.MaxUint16 / 2}}, wantEq: false}, {in: tuple{scalars{U32: math.MaxUint32}, scalars{U32: math.MaxUint32 / 2}}, wantEq: false}, diff --git a/util/eventbus/bench_test.go b/util/eventbus/bench_test.go index 7cd7a4241..9657c4e6f 100644 --- a/util/eventbus/bench_test.go +++ b/util/eventbus/bench_test.go @@ -39,6 +39,27 @@ func BenchmarkBasicThroughput(b *testing.B) { bus.Close() } +// BenchmarkBasicFuncThroughput is the SubscribeFunc analogue of +// BenchmarkBasicThroughput: one publisher and one SubscribeFunc +// callback, shoveling events as fast as they can through the +// plumbing. Useful for tracking per-event allocation behavior on the +// SubscribeFunc dispatch path. +func BenchmarkBasicFuncThroughput(b *testing.B) { + bus := eventbus.New() + pcli := bus.Client(b.Name() + "-pub") + scli := bus.Client(b.Name() + "-sub") + + type emptyEvent [0]byte + + pub := eventbus.Publish[emptyEvent](pcli) + eventbus.SubscribeFunc(scli, func(emptyEvent) {}) + + for b.Loop() { + pub.Publish(emptyEvent{}) + } + bus.Close() +} + func BenchmarkSubsThroughput(b *testing.B) { bus := eventbus.New() pcli := bus.Client(b.Name() + "-pub") diff --git a/util/eventbus/client.go b/util/eventbus/client.go index f405146ce..54a841b27 100644 --- a/util/eventbus/client.go +++ b/util/eventbus/client.go @@ -146,7 +146,10 @@ func Subscribe[T any](c *Client) *Subscriber[T] { r := c.subscribeStateLocked() s := newSubscriber[T](r, logfForCaller(c.logger())) - r.addSubscriber(s) + // Register the non-generic core with the bus rather than the typed facade, + // mirroring SubscribeFunc and Publish: this keeps the bus's outputs map + // and subscriber-interface itab out of per-T cost. + r.addSubscriber(s.core) return s } @@ -169,7 +172,11 @@ func SubscribeFunc[T any](c *Client, f func(T)) *SubscriberFunc[T] { r := c.subscribeStateLocked() s := newSubscriberFunc[T](r, f, logfForCaller(c.logger())) - r.addSubscriber(s) + // Register the non-generic core, not the typed facade. Doing + // so means the bus's outputs map and the subscriber interface + // itab are not parameterized by T, eliminating per-T itab and + // dictionary cost. + r.addSubscriber(s.core) return s } @@ -177,6 +184,11 @@ func SubscribeFunc[T any](c *Client, f func(T)) *SubscriberFunc[T] { // It panics if c is closed. func Publish[T any](c *Client) *Publisher[T] { p := newPublisher[T](c) - c.addPublisher(p) + // Register the non-generic core with the client so the + // per-Client publisher set, the publisher interface itab, and + // the publisher equality function are not parameterized by T. + // This eliminates per-T itab/dictionary/eq cost for every new + // event type passed through Publish[T]. + c.addPublisher(p.core) return p } diff --git a/util/eventbus/publish.go b/util/eventbus/publish.go index f6fd029b7..fd037ac34 100644 --- a/util/eventbus/publish.go +++ b/util/eventbus/publish.go @@ -7,8 +7,13 @@ import ( "reflect" ) -// publisher is a uniformly typed wrapper around Publisher[T], so that -// debugging facilities can look at active publishers. +// publisher is a uniformly typed wrapper around publisherCore so that +// debugging facilities can enumerate active publishers on a [Client] +// and report the types each one publishes. The interface is +// implemented by the non-generic *publisherCore (not by the typed +// user-facing *Publisher[T]); this keeps the bus's per-Client +// publisher set, and the publisher itab/dictionary, free of +// per-T duplication. type publisher interface { publishType() reflect.Type Close() @@ -16,12 +21,35 @@ type publisher interface { // A Publisher publishes typed events on a bus. type Publisher[T any] struct { + // Implementation note: Publisher[T] is a thin user-facing facade over a + // non-generic *publisherCore. Carrying T on the public type preserves the + // typed API of Publish(v T), but all of the actual state (the *Client + // back-pointer, the stop flag, and the cached reflect.Type used by + // diagnostic introspection) lives on the core and is not duplicated per T. + // + // The diagnostic surface that motivates the publisher interface + // (Debugger.PublishTypes) is served by *publisherCore directly, so adding + // new typed publishers does not pay an itab+dictionary cost just to satisfy + // diagnostic enumeration. + core *publisherCore +} + +// publisherCore is the non-generic implementation of a Publisher. +// It implements the package-private publisher interface; the bus's +// outputs map and itab key on this single type, not on Publisher[T]. +type publisherCore struct { client *Client stop stopFlag + typ reflect.Type // cached reflect.TypeFor[T]() } func newPublisher[T any](c *Client) *Publisher[T] { - return &Publisher[T]{client: c} + return &Publisher[T]{ + core: &publisherCore{ + client: c, + typ: reflect.TypeFor[T](), + }, + } } // Close closes the publisher. @@ -31,35 +59,46 @@ func newPublisher[T any](c *Client) *Publisher[T] { // If the Bus or Client from which the Publisher was created is closed, // the Publisher is implicitly closed and does not need to be closed // separately. -func (p *Publisher[T]) Close() { +func (p *Publisher[T]) Close() { p.core.Close() } + +// Close implements the publisher interface and the user-facing +// (*Publisher[T]).Close. +func (c *publisherCore) Close() { // Just unblocks any active calls to Publish, no other // synchronization needed. - p.stop.Stop() - p.client.deletePublisher(p) + c.stop.Stop() + c.client.deletePublisher(c) } -func (p *Publisher[T]) publishType() reflect.Type { - return reflect.TypeFor[T]() -} +// publishType implements the publisher interface. +func (c *publisherCore) publishType() reflect.Type { return c.typ } // Publish publishes event v on the bus. func (p *Publisher[T]) Publish(v T) { + publish(p.core, v) +} + +// publish is the non-generic body of Publisher[T].Publish. The only +// per-T work is the boxing of v into evt.Event (an `any` field) and +// the construction of the PublishedEvent struct itself; all of the +// channel/select dance is shared across every T. +func publish(c *publisherCore, v any) { // Check for just a stopped publisher or bus before trying to // write, so that once closed Publish consistently does nothing. select { - case <-p.stop.Done(): + case <-c.stop.Done(): return default: } evt := PublishedEvent{ Event: v, - From: p.client, + From: c.client, } select { - case p.client.publish() <- evt: - case <-p.stop.Done(): + case c.client.publish() <- evt: + case <-c.stop.Done(): } } @@ -70,5 +109,5 @@ func (p *Publisher[T]) Publish(v T) { // nobody seems to care. Publishers must not assume that someone will // definitely receive an event if ShouldPublish returns true. func (p *Publisher[T]) ShouldPublish() bool { - return p.client.shouldPublish(reflect.TypeFor[T]()) + return p.core.client.shouldPublish(p.core.typ) } diff --git a/util/eventbus/subscribe.go b/util/eventbus/subscribe.go index 3edf6deb4..3c3dced1f 100644 --- a/util/eventbus/subscribe.go +++ b/util/eventbus/subscribe.go @@ -184,49 +184,62 @@ func (s *subscribeState) closed() <-chan struct{} { // A Subscriber delivers one type of event from a [Client]. // Events are sent to the [Subscriber.Events] channel. type Subscriber[T any] struct { - stop stopFlag - read chan T - unregister func() - logf logger.Logf - slow *time.Timer // used to detect slow subscriber service + // core holds the non-generic subscriber-interface implementation + // (Close, subscribeType, dispatch, slow timer, unregister) shared + // with [SubscriberFunc] via [subscriberCore]. The only per-T state + // owned by the facade itself is the typed delivery channel; the + // dispatch loop, unlike SubscriberFunc, must remain per-T — see + // [Subscriber.dispatchTyped]. + core *subscriberCore + read chan T } func newSubscriber[T any](r *subscribeState, logf logger.Logf) *Subscriber[T] { - slow := time.NewTimer(0) - slow.Stop() // reset in dispatch - return &Subscriber[T]{ - read: make(chan T), - unregister: func() { r.deleteSubscriber(reflect.TypeFor[T]()) }, - logf: logf, - slow: slow, + core := newSubscriberCore(r, logf, reflect.TypeFor[T]()) + s := &Subscriber[T]{ + core: core, + read: make(chan T), } -} - -func newMonitor[T any](attach func(fn func(T)) (cancel func())) *Subscriber[T] { - ret := &Subscriber[T]{ - read: make(chan T, 100), // arbitrary, large + // Subscriber[T] keeps a per-T dispatch loop; see [Subscriber.dispatchTyped] + // for why we don't share the non-generic dispatchFunc that SubscriberFunc + // uses. + core.dispatchFn = func( + ctx context.Context, + vals *queue[DeliveredEvent], + acceptCh func() chan DeliveredEvent, + snapshot chan chan []DeliveredEvent, + ) bool { + return s.dispatchTyped(ctx, vals, acceptCh, snapshot) } - ret.unregister = attach(ret.monitor) - return ret + return s } -func (s *Subscriber[T]) subscribeType() reflect.Type { - return reflect.TypeFor[T]() -} - -func (s *Subscriber[T]) monitor(debugEvent T) { - select { - case s.read <- debugEvent: - case <-s.stop.Done(): - } -} - -func (s *Subscriber[T]) dispatch(ctx context.Context, vals *queue[DeliveredEvent], acceptCh func() chan DeliveredEvent, snapshot chan chan []DeliveredEvent) bool { +// dispatchTyped is the per-T dispatch loop for Subscriber[T]. It has to remain +// generic because the typed channel send `case s.read <- t:` must appear +// lexically inside the select. The rest of the cases match the non-generic +// dispatchFunc body to keep behavior aligned between Subscriber and +// SubscriberFunc. +// +// We don't share dispatchFunc (the way SubscriberFunc does) because bridging +// the typed channel send and the non-generic select would require running the +// send on its own goroutine on every event delivery. That bridge was measured +// at ~2.7x throughput regression on BenchmarkBasicThroughput, so we keep +// dispatchTyped generic and pay the per-shape stencil cost instead (measured +// at ~1,600 B body + ~1,100 B pclntab per shape on linux/amd64 tailscaled). +// Only the typed select lives in the per-shape stencil; the surrounding state +// (slow timer, log function, type name) is reached through the non-generic +// core. +func (s *Subscriber[T]) dispatchTyped( + ctx context.Context, + vals *queue[DeliveredEvent], + acceptCh func() chan DeliveredEvent, + snapshot chan chan []DeliveredEvent, +) bool { t := vals.Peek().Event.(T) start := time.Now() - s.slow.Reset(slowSubscriberTimeout) - defer s.slow.Stop() + s.core.slow.Reset(slowSubscriberTimeout) + defer s.core.slow.Stop() for { // Keep the cases in this select in sync with subscribeState.pump @@ -242,13 +255,35 @@ func (s *Subscriber[T]) dispatch(ctx context.Context, vals *queue[DeliveredEvent return false case ch := <-snapshot: ch <- vals.Snapshot() - case <-s.slow.C: - s.logf("subscriber for %T is slow (%v elapsed)", t, time.Since(start)) - s.slow.Reset(slowSubscriberTimeout) + case <-s.core.slow.C: + s.core.logf("subscriber for %s is slow (%v elapsed)", s.core.typeName, time.Since(start)) + s.core.slow.Reset(slowSubscriberTimeout) } } } +func newMonitor[T any](attach func(fn func(T)) (cancel func())) *Subscriber[T] { + s := &Subscriber[T]{ + // Monitors don't go through the bus's dispatch path (they + // are attached directly to the debug hook), so they don't + // need a fully-initialized subscriberCore — only the typed + // delivery channel and an unregister callback. We give them + // a placeholder core so Close() and Done() work uniformly. + core: &subscriberCore{}, + read: make(chan T, 100), // arbitrary, large + } + cancel := attach(s.monitor) + s.core.unregister = func(reflect.Type) { cancel() } + return s +} + +func (s *Subscriber[T]) monitor(debugEvent T) { + select { + case s.read <- debugEvent: + case <-s.core.stop.Done(): + } +} + // Events returns a channel on which the subscriber's events are // delivered. func (s *Subscriber[T]) Events() <-chan T { @@ -258,7 +293,7 @@ func (s *Subscriber[T]) Events() <-chan T { // Done returns a channel that is closed when the subscriber is // closed. func (s *Subscriber[T]) Done() <-chan struct{} { - return s.stop.Done() + return s.core.stop.Done() } // Close closes the Subscriber, indicating the caller no longer wishes @@ -268,30 +303,105 @@ func (s *Subscriber[T]) Done() <-chan struct{} { // If the Bus from which the Subscriber was created is closed, // the Subscriber is implicitly closed and does not need to be closed // separately. -func (s *Subscriber[T]) Close() { - s.stop.Stop() // unblock receivers - s.unregister() -} +func (s *Subscriber[T]) Close() { s.core.Close() } // A SubscriberFunc delivers one type of event from a [Client]. // Events are forwarded synchronously to a function provided at construction. type SubscriberFunc[T any] struct { + // core holds the non-generic subscriber-interface implementation shared + // with [Subscriber] via [subscriberCore]. The user callback is captured + // in the dispatchFn closure on the core, so SubscriberFunc[T] itself + // carries no per-T state beyond the core pointer; per-T cost is limited + // to the small forwarding Close method below. + core *subscriberCore +} + +// subscriberCore is the non-generic backing for both Subscriber[T] and +// SubscriberFunc[T]. It implements the package-private subscriber interface +// so that the bus (and the subscribeState map) can store it without per-T +// itabs or dictionaries. The per-T behavior (type assertion plus either typed +// channel send or user callback invocation) is encapsulated in the dispatchFn +// closure set up by the constructor of the typed facade. +type subscriberCore struct { stop stopFlag - read func(T) - unregister func() + unregister func(reflect.Type) logf logger.Logf slow *time.Timer // used to detect slow subscriber service + + // typ is the cached reflect.Type of T. Returned by + // subscribeType() and used by the dispatch closure to format + // slow-subscriber log messages. + typ reflect.Type + // typeName is the cached reflect.TypeFor[T]().String() result. + // Computed once at construction time so the dispatch closure + // (which runs once per delivered event) doesn't allocate a + // fresh string on every call. The string is also independent + // of T, so it doesn't contribute to per-T stencil cost. + typeName string + + // dispatchFn is the per-T dispatch closure. It performs the type + // assertion vals.Peek().Event.(T) and runs the typed delivery (either a + // user-callback invocation for SubscriberFunc[T] or a typed channel send + // for Subscriber[T]). The closure body is non-generic apart from those + // two T-bound operations; the bulk of the dispatch work happens in the + // non-generic dispatchFunc helper (used by SubscriberFunc) or in the + // Subscriber[T].dispatchTyped per-shape stencil. + dispatchFn func( + ctx context.Context, + vals *queue[DeliveredEvent], + acceptCh func() chan DeliveredEvent, + snapshot chan chan []DeliveredEvent, + ) bool } func newSubscriberFunc[T any](r *subscribeState, f func(T), logf logger.Logf) *SubscriberFunc[T] { + core := newSubscriberCore(r, logf, reflect.TypeFor[T]()) + // The dispatch closure is the only piece that intrinsically + // needs T: it performs the type assertion on the head queue + // value and forwards the unboxed value to the user callback. + // All non-generic setup (timer, core allocation, unregister + // closure) lives in newSubscriberCore so it isn't + // duplicated per T. + core.dispatchFn = func( + ctx context.Context, + vals *queue[DeliveredEvent], + acceptCh func() chan DeliveredEvent, + snapshot chan chan []DeliveredEvent, + ) bool { + t := vals.Peek().Event.(T) + callDone := make(chan struct{}) + // `go runFuncCallback(f, t, callDone)` binds its arguments + // directly to the new goroutine's frame; using a closure + // (`go func() { f(t) }()`) would allocate a closure on the + // heap on every dispatched event. + go runFuncCallback(f, t, callDone) + return dispatchFunc(ctx, core, vals, acceptCh, snapshot, callDone) + } + return &SubscriberFunc[T]{core: core} +} + +// newSubscriberCore performs the non-generic portion of subscriber +// construction: timer setup, core struct allocation, and assignment of the +// unregister method-value. The caller fills in the per-T dispatchFn +// afterward. +// +// Hoisting this out of the typed constructors (newSubscriber[T] and +// newSubscriberFunc[T]) eliminates the bulk of their per-T stencil cost; the +// only T-typed instructions left in each generic constructor are the +// reflect.TypeFor[T]() call (whose body is shared via the +// internal/abi.TypeFor[T] dictionary) and the construction of the dispatch +// closure itself. +func newSubscriberCore(r *subscribeState, logf logger.Logf, typ reflect.Type) *subscriberCore { slow := time.NewTimer(0) slow.Stop() // reset in dispatch - return &SubscriberFunc[T]{ - read: f, - unregister: func() { r.deleteSubscriber(reflect.TypeFor[T]()) }, - logf: logf, - slow: slow, + core := &subscriberCore{ + logf: logf, + slow: slow, + typ: typ, + typeName: typ.String(), } + core.unregister = r.deleteSubscriber + return core } // Close closes the SubscriberFunc, indicating the caller no longer wishes to @@ -300,24 +410,54 @@ func newSubscriberFunc[T any](r *subscribeState, f func(T), logf logger.Logf) *S // // If the [Bus] from which s was created is closed, s is implicitly closed and // does not need to be closed separately. -func (s *SubscriberFunc[T]) Close() { s.stop.Stop(); s.unregister() } +func (s *SubscriberFunc[T]) Close() { s.core.Close() } -// subscribeType implements part of the subscriber interface. -func (s *SubscriberFunc[T]) subscribeType() reflect.Type { return reflect.TypeFor[T]() } +// Close implements the subscriber interface and the user-facing Close on +// both Subscriber[T] and SubscriberFunc[T]. +func (c *subscriberCore) Close() { + c.stop.Stop() + c.unregister(c.typ) +} -// dispatch implements part of the subscriber interface. -func (s *SubscriberFunc[T]) dispatch(ctx context.Context, vals *queue[DeliveredEvent], acceptCh func() chan DeliveredEvent, snapshot chan chan []DeliveredEvent) bool { - t := vals.Peek().Event.(T) - callDone := make(chan struct{}) - go s.runCallback(t, callDone) +// subscribeType implements the subscriber interface. +func (c *subscriberCore) subscribeType() reflect.Type { return c.typ } +// dispatch implements the subscriber interface by invoking the +// per-T dispatch closure that was captured at construction time. +func (c *subscriberCore) dispatch( + ctx context.Context, + vals *queue[DeliveredEvent], + acceptCh func() chan DeliveredEvent, + snapshot chan chan []DeliveredEvent, +) bool { + return c.dispatchFn(ctx, vals, acceptCh, snapshot) +} + +// dispatchFunc is the non-generic body of SubscriberFunc[T].dispatch. +// It is identical in observable behavior to the original loop; the +// only differences are that the dispatched value has already been +// unboxed by the caller (and the user callback is already running +// on its own goroutine, signaling completion via callDone) and the +// slow-subscriber timer / cached type name come from the +// non-generic core, not from a per-T struct. +// +// callDone is closed by runFuncCallback when the user callback returns. +func dispatchFunc( + ctx context.Context, + core *subscriberCore, + vals *queue[DeliveredEvent], + acceptCh func() chan DeliveredEvent, + snapshot chan chan []DeliveredEvent, + callDone chan struct{}, +) bool { start := time.Now() - s.slow.Reset(slowSubscriberTimeout) - defer s.slow.Stop() + core.slow.Reset(slowSubscriberTimeout) + defer core.slow.Stop() // Keep the cases in this select in sync with subscribeState.pump // above. The only difference should be that this select - // delivers a value by calling s.read. + // delivers a value by calling the user callback (via the + // goroutine spawned by the typed wrapper). for { select { case <-callDone: @@ -327,30 +467,35 @@ func (s *SubscriberFunc[T]) dispatch(ctx context.Context, vals *queue[DeliveredE vals.Add(val) case <-ctx.Done(): // Wait for the callback to be complete, but not forever. - s.slow.Reset(5 * slowSubscriberTimeout) + core.slow.Reset(5 * slowSubscriberTimeout) select { - case <-s.slow.C: - s.logf("giving up on subscriber for %T after %v at close", t, time.Since(start)) + case <-core.slow.C: + core.logf("giving up on subscriber for %s after %v at close", core.typeName, time.Since(start)) if cibuild.On() { all := make([]byte, 2<<20) n := runtime.Stack(all, true) - s.logf("goroutine stacks:\n%s", all[:n]) + core.logf("goroutine stacks:\n%s", all[:n]) } case <-callDone: } return false case ch := <-snapshot: ch <- vals.Snapshot() - case <-s.slow.C: - s.logf("subscriber for %T is slow (%v elapsed)", t, time.Since(start)) - s.slow.Reset(slowSubscriberTimeout) + case <-core.slow.C: + core.logf("subscriber for %s is slow (%v elapsed)", core.typeName, time.Since(start)) + core.slow.Reset(slowSubscriberTimeout) } } } -// runCallback invokes the callback on v and closes ch when it returns. -// This should be run in a goroutine. -func (s *SubscriberFunc[T]) runCallback(v T, ch chan struct{}) { - defer close(ch) - s.read(v) +// runFuncCallback runs f(t) and closes done when it returns. It is +// the per-T worker spawned as a goroutine for each dispatched +// event. Keeping it as a regular generic function (rather than a +// closure) means `go runFuncCallback(f, t, done)` binds its +// arguments to the goroutine's frame directly, with no per-event +// closure allocation. The body is small (defer + one indirect +// call), so the per-shape stencil cost is minimal. +func runFuncCallback[T any](f func(T), t T, done chan struct{}) { + defer close(done) + f(t) } diff --git a/util/expvarx/expvarx.go b/util/expvarx/expvarx.go deleted file mode 100644 index 6dc2379b9..000000000 --- a/util/expvarx/expvarx.go +++ /dev/null @@ -1,89 +0,0 @@ -// Copyright (c) Tailscale Inc & contributors -// SPDX-License-Identifier: BSD-3-Clause - -// Package expvarx provides some extensions to the [expvar] package. -package expvarx - -import ( - "encoding/json" - "expvar" - "time" - - "tailscale.com/syncs" - "tailscale.com/types/lazy" -) - -// SafeFunc is a wrapper around [expvar.Func] that guards against unbounded call -// time and ensures that only a single call is in progress at any given time. -type SafeFunc struct { - f expvar.Func - limit time.Duration - onSlow func(time.Duration, any) - - mu syncs.Mutex - inflight *lazy.SyncValue[any] -} - -// NewSafeFunc returns a new SafeFunc that wraps f. -// If f takes longer than limit to execute then Value calls return nil. -// If onSlow is non-nil, it is called when f takes longer than limit to execute. -// onSlow is called with the duration of the slow call and the final computed -// value. -func NewSafeFunc(f expvar.Func, limit time.Duration, onSlow func(time.Duration, any)) *SafeFunc { - return &SafeFunc{f: f, limit: limit, onSlow: onSlow} -} - -// Value acts similarly to [expvar.Func.Value], but if the underlying function -// takes longer than the configured limit, all callers will receive nil until -// the underlying operation completes. On completion of the underlying -// operation, the onSlow callback is called if set. -func (s *SafeFunc) Value() any { - s.mu.Lock() - - if s.inflight == nil { - s.inflight = new(lazy.SyncValue[any]) - } - var inflight = s.inflight - s.mu.Unlock() - - // inflight ensures that only a single work routine is spawned at any given - // time, but if the routine takes too long inflight is populated with a nil - // result. The long running computed value is lost forever. - return inflight.Get(func() any { - start := time.Now() - result := make(chan any, 1) - - // work is spawned in routine so that the caller can timeout. - go func() { - // Allow new work to be started after this work completes - defer func() { - s.mu.Lock() - s.inflight = nil - s.mu.Unlock() - - }() - - v := s.f.Value() - result <- v - }() - - select { - case v := <-result: - return v - case <-time.After(s.limit): - if s.onSlow != nil { - go func() { - s.onSlow(time.Since(start), <-result) - }() - } - return nil - } - }) -} - -// String implements stringer in the same pattern as [expvar.Func], calling -// Value and serializing the result as JSON, ignoring errors. -func (s *SafeFunc) String() string { - v, _ := json.Marshal(s.Value()) - return string(v) -} diff --git a/util/expvarx/expvarx_test.go b/util/expvarx/expvarx_test.go deleted file mode 100644 index f8d2139d3..000000000 --- a/util/expvarx/expvarx_test.go +++ /dev/null @@ -1,141 +0,0 @@ -// Copyright (c) Tailscale Inc & contributors -// SPDX-License-Identifier: BSD-3-Clause - -package expvarx - -import ( - "expvar" - "fmt" - "sync" - "sync/atomic" - "testing" - "testing/synctest" - "time" -) - -func ExampleNewSafeFunc() { - // An artificial blocker to emulate a slow operation. - blocker := make(chan struct{}) - - // limit is the amount of time a call can take before Value returns nil. No - // new calls to the unsafe func will be started until the slow call - // completes, at which point onSlow will be called. - limit := time.Millisecond - - // onSlow is called with the final call duration and the final value in the - // event a slow call. - onSlow := func(d time.Duration, v any) { - _ = d // d contains the time the call took - _ = v // v contains the final value computed by the slow call - fmt.Println("slow call!") - } - - // An unsafe expvar.Func that blocks on the blocker channel. - unsafeFunc := expvar.Func(func() any { - for range blocker { - } - return "hello world" - }) - - // f implements the same interface as expvar.Func, but returns nil values - // when the unsafe func is too slow. - f := NewSafeFunc(unsafeFunc, limit, onSlow) - - fmt.Println(f.Value()) - fmt.Println(f.Value()) - close(blocker) - time.Sleep(time.Millisecond) - fmt.Println(f.Value()) - // Output: - // - // slow call! - // hello world -} - -func TestSafeFuncHappyPath(t *testing.T) { - synctest.Test(t, func(t *testing.T) { - var count int - f := NewSafeFunc(expvar.Func(func() any { - count++ - return count - }), time.Second, nil) - - if got, want := f.Value(), 1; got != want { - t.Errorf("got %v, want %v", got, want) - } - time.Sleep(5 * time.Second) // (fake time in synctest) - if got, want := f.Value(), 2; got != want { - t.Errorf("got %v, want %v", got, want) - } - }) -} - -func TestSafeFuncSlow(t *testing.T) { - var count int - blocker := make(chan struct{}) - var wg sync.WaitGroup - wg.Add(1) - f := NewSafeFunc(expvar.Func(func() any { - defer wg.Done() - count++ - <-blocker - return count - }), time.Millisecond, nil) - - if got := f.Value(); got != nil { - t.Errorf("got %v; want nil", got) - } - if got := f.Value(); got != nil { - t.Errorf("got %v; want nil", got) - } - - close(blocker) - wg.Wait() - - if count != 1 { - t.Errorf("got count=%d; want 1", count) - } -} - -func TestSafeFuncSlowOnSlow(t *testing.T) { - var count int - blocker := make(chan struct{}) - var wg sync.WaitGroup - wg.Add(2) - var slowDuration atomic.Pointer[time.Duration] - var slowCallCount atomic.Int32 - var slowValue atomic.Value - f := NewSafeFunc(expvar.Func(func() any { - defer wg.Done() - count++ - <-blocker - return count - }), time.Millisecond, func(d time.Duration, v any) { - defer wg.Done() - slowDuration.Store(&d) - slowCallCount.Add(1) - slowValue.Store(v) - }) - - for range 10 { - if got := f.Value(); got != nil { - t.Fatalf("got value=%v; want nil", got) - } - } - - close(blocker) - wg.Wait() - - if count != 1 { - t.Errorf("got count=%d; want 1", count) - } - if got, want := *slowDuration.Load(), 1*time.Millisecond; got < want { - t.Errorf("got slowDuration=%v; want at least %d", got, want) - } - if got, want := slowCallCount.Load(), int32(1); got != want { - t.Errorf("got slowCallCount=%d; want %d", got, want) - } - if got, want := slowValue.Load().(int), 1; got != want { - t.Errorf("got slowValue=%d, want %d", got, want) - } -} diff --git a/util/httpm/httpm_test.go b/util/httpm/httpm_test.go index 4a36a38e1..e8342a74f 100644 --- a/util/httpm/httpm_test.go +++ b/util/httpm/httpm_test.go @@ -24,6 +24,13 @@ func TestUsedConsistently(t *testing.T) { t.Skipf("skipping test since .git doesn't exist: %v", err) } + // Open .git/index so Go's test cache tracks it as an input. + // The index file changes on git reset, checkout, pull, etc., + // so the cache is properly invalidated when moving between commits. + if f, err := os.Open(filepath.Join(rootDir, ".git", "index")); err == nil { + f.Close() + } + cmd := exec.Command("git", "grep", "-l", "-F", "http.Method") cmd.Dir = rootDir matches, _ := cmd.Output() diff --git a/util/linuxfw/fake_netfilter.go b/util/linuxfw/fake_netfilter.go index eac5d904c..e9d853508 100644 --- a/util/linuxfw/fake_netfilter.go +++ b/util/linuxfw/fake_netfilter.go @@ -95,3 +95,5 @@ func (f *FakeNetfilterRunner) DeleteSvc(svc, tun string, targetIPs []netip.Addr, func (f *FakeNetfilterRunner) EnsurePortMapRuleForSvc(svc, tun string, targetIP netip.Addr, pm PortMap) error { return nil } +func (f *FakeNetfilterRunner) AddExternalCGNATRules(mode CGNATMode, tunname string) error { return nil } +func (f *FakeNetfilterRunner) DelExternalCGNATRules(mode CGNATMode, tunname string) error { return nil } diff --git a/util/linuxfw/iptables_runner.go b/util/linuxfw/iptables_runner.go index 0d50bdd61..868114ac8 100644 --- a/util/linuxfw/iptables_runner.go +++ b/util/linuxfw/iptables_runner.go @@ -214,23 +214,8 @@ func (i *iptablesRunner) AddBase(tunname string) error { // addBase4 adds some basic IPv4 processing rules to be // supplemented by later calls to other helpers. func (i *iptablesRunner) addBase4(tunname string) error { - // Only allow CGNAT range traffic to come from tailscale0. There - // is an exception carved out for ranges used by ChromeOS, for - // which we fall out of the Tailscale chain. - // - // Note, this will definitely break nodes that end up using the - // CGNAT range for other purposes :(. - args := []string{"!", "-i", tunname, "-s", tsaddr.ChromeOSVMRange().String(), "-j", "RETURN"} - if err := i.ipt4.Append("filter", "ts-input", args...); err != nil { - return fmt.Errorf("adding %v in v4/filter/ts-input: %w", args, err) - } - args = []string{"!", "-i", tunname, "-s", tsaddr.CGNATRange().String(), "-j", "DROP"} - if err := i.ipt4.Append("filter", "ts-input", args...); err != nil { - return fmt.Errorf("adding %v in v4/filter/ts-input: %w", args, err) - } - - // Explicitly allow all other inbound traffic to the tun interface - args = []string{"-i", tunname, "-j", "ACCEPT"} + // Explicitly allow all inbound traffic to the tun interface + args := []string{"-i", tunname, "-j", "ACCEPT"} if err := i.ipt4.Append("filter", "ts-input", args...); err != nil { return fmt.Errorf("adding %v in v4/filter/ts-input: %w", args, err) } @@ -551,10 +536,14 @@ func (i *iptablesRunner) AddConnmarkSaveRule() error { // mangle/PREROUTING: Restore mark from conntrack for ESTABLISHED/RELATED connections // This runs BEFORE routing decision and rp_filter check + // The connmark check ensures we only restore when Tailscale has marked the connection, + // preventing us from wiping mark bits set by other systems when ct mark is zero. for _, ipt := range i.getTables() { args := []string{ "-m", "conntrack", "--ctstate", "ESTABLISHED,RELATED", + "-m", "connmark", + "!", "--mark", "0x0/" + fwmarkMask, // Only restore if ct mark has Tailscale bits set "-j", "CONNMARK", "--restore-mark", "--nfmask", fwmarkMask, @@ -592,6 +581,8 @@ func (i *iptablesRunner) DelConnmarkSaveRule() error { args := []string{ "-m", "conntrack", "--ctstate", "ESTABLISHED,RELATED", + "-m", "connmark", + "!", "--mark", "0x0/" + fwmarkMask, "-j", "CONNMARK", "--restore-mark", "--nfmask", fwmarkMask, @@ -682,6 +673,67 @@ func (i *iptablesRunner) DelMagicsockPortRule(port uint16, network string) error return nil } +// buildExternalCGNATRules abstracts out logic for constructing firewall rules +// for handling non-Tailscale CGNAT traffic, since these rules need to be +// identical across [AddExternalCGNATRules] and [DelExternalCGNATRules]. +func buildExternalCGNATRules(mode CGNATMode, tunname string) ([][]string, error) { + switch mode { + case CGNATModeDrop: + // Only allow CGNAT range traffic to come from the Tailscale interface. + // There is an exception carved out for ranges used by ChromeOS, for + // which we fall out of the Tailscale chain. + return [][]string{ + {"!", "-i", tunname, "-s", tsaddr.ChromeOSVMRange().String(), "-j", "RETURN"}, + {"!", "-i", tunname, "-s", tsaddr.CGNATRange().String(), "-j", "DROP"}, + }, nil + case CGNATModeReturn: + // Fall out of the Tailscale chain for CGNAT traffic that doesn't + // originate from the Tailscale interface. + return [][]string{ + {"!", "-i", tunname, "-s", tsaddr.CGNATRange().String(), "-j", "RETURN"}, + }, nil + default: + return nil, fmt.Errorf("unsupported mode %q", mode) + } +} + +// AddExternalCGNATRules adds rules to the ts-input chain to deal with +// traffic from the CGNAT range that arrives on non-Tailscale network +// interfaces. +func (i *iptablesRunner) AddExternalCGNATRules(mode CGNATMode, tunname string) error { + rules, err := buildExternalCGNATRules(mode, tunname) + if err != nil { + return fmt.Errorf("build cgnat mode rule: %v", err) + } + for _, rule := range rules { + if err := i.ipt4.Append("filter", "ts-input", rule...); err != nil { + return fmt.Errorf("adding %v in v4/filter/ts-input: %w", rule, err) + } + } + return nil +} + +// DelExternalCGNATRules removes the rules created by AddExternalCGNATRules, +// if they exist. +func (i *iptablesRunner) DelExternalCGNATRules(mode CGNATMode, tunname string) error { + rules, err := buildExternalCGNATRules(mode, tunname) + if err != nil { + return fmt.Errorf("build cgnat mode rule: %v", err) + } + for _, rule := range rules { + if found, err := i.ipt4.Exists("filter", "ts-input", rule...); err != nil { + return fmt.Errorf("checking for %v in v4/filter/ts-input: %w", rule, err) + } else if !found { + // Don't need to delete a rule that isn't there. + continue + } + if err := i.ipt4.Delete("filter", "ts-input", rule...); err != nil { + return fmt.Errorf("deleting %v in v4/filter/ts-input: %w", rule, err) + } + } + return nil +} + // delTSHook deletes hook in a chain that jumps to a ts-chain. If the hook does not // exist, it's a no-op since the desired state is already achieved but we log the // error because error code from the iptables module resists unwrapping. diff --git a/util/linuxfw/iptables_runner_test.go b/util/linuxfw/iptables_runner_test.go index 77c753004..5bf624ef4 100644 --- a/util/linuxfw/iptables_runner_test.go +++ b/util/linuxfw/iptables_runner_test.go @@ -126,8 +126,6 @@ func TestAddAndDeleteBase(t *testing.T) { // Check that the rules were created. tsRulesV4 := []fakeRule{ // table/chain/rule - {"filter", "ts-input", []string{"!", "-i", tunname, "-s", tsaddr.ChromeOSVMRange().String(), "-j", "RETURN"}}, - {"filter", "ts-input", []string{"!", "-i", tunname, "-s", tsaddr.CGNATRange().String(), "-j", "DROP"}}, {"filter", "ts-forward", []string{"-o", tunname, "-s", tsaddr.CGNATRange().String(), "-j", "DROP"}}, } @@ -369,6 +367,8 @@ func TestAddAndDelConnmarkSaveRule(t *testing.T) { preroutingArgs := []string{ "-m", "conntrack", "--ctstate", "ESTABLISHED,RELATED", + "-m", "connmark", + "!", "--mark", "0x0/0xff0000", "-j", "CONNMARK", "--restore-mark", "--nfmask", "0xff0000", @@ -504,3 +504,56 @@ func TestAddAndDelConnmarkSaveRule(t *testing.T) { } }) } + +func TestAddAndDelCGNATRules(t *testing.T) { + iptr := newFakeIPTablesRunner() + tunname := "tun0" + + // We need the chains to exist so we can add rules into them. + if err := iptr.AddChains(); err != nil { + t.Fatal(err) + } + + tests := []struct { + mode CGNATMode + wantRules []fakeRule + }{ + { + CGNATModeDrop, []fakeRule{ + {"filter", "ts-input", []string{"!", "-i", tunname, "-s", tsaddr.ChromeOSVMRange().String(), "-j", "RETURN"}}, + {"filter", "ts-input", []string{"!", "-i", tunname, "-s", tsaddr.CGNATRange().String(), "-j", "DROP"}}, + }, + }, + { + CGNATModeReturn, []fakeRule{ + {"filter", "ts-input", []string{"!", "-i", tunname, "-s", tsaddr.CGNATRange().String(), "-j", "RETURN"}}, + }, + }, + } + + for _, tt := range tests { + if err := iptr.AddExternalCGNATRules(tt.mode, tunname); err != nil { + t.Fatal(err) + } + + for _, tr := range tt.wantRules { + if exists, err := iptr.ipt4.Exists(tr.table, tr.chain, tr.args...); err != nil { + t.Fatalf("mode %q: error checking for rule: %v", tt.mode, err) + } else if !exists { + t.Errorf("mode %q: rule %s/%s/%s doesn't exist", tt.mode, tr.table, tr.chain, strings.Join(tr.args, " ")) + } + } + + if err := iptr.DelExternalCGNATRules(tt.mode, tunname); err != nil { + t.Fatal(err) + } + + for _, tr := range tt.wantRules { + if exists, err := iptr.ipt4.Exists(tr.table, tr.chain, tr.args...); err != nil { + t.Fatalf("mode %q: error checking for rule: %v", tt.mode, err) + } else if exists { + t.Errorf("mode %q: rule %s/%s/%s not deleted", tt.mode, tr.table, tr.chain, strings.Join(tr.args, " ")) + } + } + } +} diff --git a/util/linuxfw/linuxfw.go b/util/linuxfw/linuxfw.go index 325a5809f..eb2da1896 100644 --- a/util/linuxfw/linuxfw.go +++ b/util/linuxfw/linuxfw.go @@ -7,6 +7,7 @@ package linuxfw import ( + "encoding/binary" "errors" "fmt" "os" @@ -53,6 +54,13 @@ const ( FirewallModeNfTables FirewallMode = "nftables" ) +type CGNATMode string + +const ( + CGNATModeDrop CGNATMode = "DROP" + CGNATModeReturn CGNATMode = "RETURN" +) + // The following bits are added to packet marks for Tailscale use. // // We tried to pick bits sufficiently out of the way that it's @@ -79,19 +87,28 @@ const ( bypassMarkNum = tsconst.LinuxBypassMarkNum ) -// getTailscaleFwmarkMaskNeg returns the negation of TailscaleFwmarkMask in bytes. +// getTailscaleFwmarkMaskNeg returns the negation of TailscaleFwmarkMask +// in native byte order. func getTailscaleFwmarkMaskNeg() []byte { - return []byte{0xff, 0x00, 0xff, 0xff} + return nativeEndianUint32(^uint32(fwmarkMaskNum)) } -// getTailscaleFwmarkMask returns the TailscaleFwmarkMask in bytes. +// getTailscaleFwmarkMask returns the TailscaleFwmarkMask in native byte order. func getTailscaleFwmarkMask() []byte { - return []byte{0x00, 0xff, 0x00, 0x00} + return nativeEndianUint32(fwmarkMaskNum) } -// getTailscaleSubnetRouteMark returns the TailscaleSubnetRouteMark in bytes. +// getTailscaleSubnetRouteMark returns the TailscaleSubnetRouteMark +// in native byte order. func getTailscaleSubnetRouteMark() []byte { - return []byte{0x00, 0x04, 0x00, 0x00} + return nativeEndianUint32(subnetRouteMarkNum) +} + +// nativeEndianUint32 returns v as a 4-byte slice in the host's native byte order. +func nativeEndianUint32(v uint32) []byte { + b := make([]byte, 4) + binary.NativeEndian.PutUint32(b, v) + return b } // checkIPv6ForTest can be set in tests. diff --git a/util/linuxfw/linuxfwtest/linuxfwtest.go b/util/linuxfw/linuxfwtest/linuxfwtest.go deleted file mode 100644 index bf1477ad9..000000000 --- a/util/linuxfw/linuxfwtest/linuxfwtest.go +++ /dev/null @@ -1,31 +0,0 @@ -// Copyright (c) Tailscale Inc & contributors -// SPDX-License-Identifier: BSD-3-Clause - -//go:build cgo && linux - -// Package linuxfwtest contains tests for the linuxfw package. Go does not -// support cgo in tests, and we don't want the main package to have a cgo -// dependency, so we put all the tests here and call them from the main package -// in tests intead. -package linuxfwtest - -import ( - "testing" - "unsafe" -) - -/* -#include // socket() -*/ -import "C" - -type SizeInfo struct { - SizeofSocklen uintptr -} - -func TestSizes(t *testing.T, si *SizeInfo) { - want := unsafe.Sizeof(C.socklen_t(0)) - if want != si.SizeofSocklen { - t.Errorf("sockLen has wrong size; want=%d got=%d", want, si.SizeofSocklen) - } -} diff --git a/util/linuxfw/linuxfwtest/linuxfwtest_unsupported.go b/util/linuxfw/linuxfwtest/linuxfwtest_unsupported.go deleted file mode 100644 index ec2d24d35..000000000 --- a/util/linuxfw/linuxfwtest/linuxfwtest_unsupported.go +++ /dev/null @@ -1,18 +0,0 @@ -// Copyright (c) Tailscale Inc & contributors -// SPDX-License-Identifier: BSD-3-Clause - -//go:build !cgo || !linux - -package linuxfwtest - -import ( - "testing" -) - -type SizeInfo struct { - SizeofSocklen uintptr -} - -func TestSizes(t *testing.T, si *SizeInfo) { - t.Skip("not supported without cgo") -} diff --git a/util/linuxfw/nftables_runner.go b/util/linuxfw/nftables_runner.go index cdb1c5bfb..639a044de 100644 --- a/util/linuxfw/nftables_runner.go +++ b/util/linuxfw/nftables_runner.go @@ -453,8 +453,13 @@ func getOrCreateChain(c *nftables.Conn, cinfo chainInfo) (*nftables.Chain, error // type/hook/priority, but for "conventional chains" assume they're what // we expect (in case iptables-nft/ufw make minor behavior changes in // the future). - if isTSChain(chain.Name) && (chain.Type != cinfo.chainType || *chain.Hooknum != *cinfo.chainHook || *chain.Priority != *cinfo.chainPriority) { - return nil, fmt.Errorf("chain %s already exists with different type/hook/priority", cinfo.name) + if isTSChain(chain.Name) { + if chain.Hooknum == nil || chain.Priority == nil { + return nil, errors.New("nftables chain has nil hooknum or priority; kernel may lack nftables support (CONFIG_NF_TABLES)") + } + if chain.Type != cinfo.chainType || *chain.Hooknum != *cinfo.chainHook || *chain.Priority != *cinfo.chainPriority { + return nil, fmt.Errorf("chain %s already exists with different type/hook/priority", cinfo.name) + } } return chain, nil } @@ -588,6 +593,15 @@ type NetfilterRunner interface { // DelMagicsockPortRule removes the rule created by AddMagicsockPortRule, // if it exists. DelMagicsockPortRule(port uint16, network string) error + + // AddExternalCGNATRules adds rules to the ts-input chain to deal with + // traffic from the CGNAT range that arrives on non-Tailscale network + // interfaces. + AddExternalCGNATRules(mode CGNATMode, tunname string) error + + // DelExternalCGNATRules removes the rules created by AddExternalCGNATRules, + // if they exist. + DelExternalCGNATRules(mode CGNATMode, tunname string) error } // New creates a NetfilterRunner, auto-detecting whether to use @@ -1216,6 +1230,27 @@ func addReturnChromeOSVMRangeRule(c *nftables.Conn, table *nftables.Table, chain return nil } +// delReturnChromeOSVMRangeRule deletes the rule created by addReturnChromeOSVMRangeRule, +// if it exists. +func delReturnChromeOSVMRangeRule(c *nftables.Conn, table *nftables.Table, chain *nftables.Chain, tunname string) error { + rule, err := createRangeRule(table, chain, tunname, tsaddr.ChromeOSVMRange(), expr.VerdictReturn) + if err != nil { + return fmt.Errorf("create rule: %w", err) + } + rule, err = findRule(c, rule) + if err != nil { + return fmt.Errorf("find rule: %v", err) + } + if rule == nil { + return nil + } + _ = c.DelRule(rule) + if err := c.Flush(); err != nil { + return fmt.Errorf("flush del rule: %w", err) + } + return nil +} + // addDropCGNATRangeRule adds a rule to drop if the source IP is in the // CGNAT range. func addDropCGNATRangeRule(c *nftables.Conn, table *nftables.Table, chain *nftables.Chain, tunname string) error { @@ -1230,6 +1265,62 @@ func addDropCGNATRangeRule(c *nftables.Conn, table *nftables.Table, chain *nftab return nil } +// delDropCGNATRangeRule deletes the rule created by addDropCGNATRangeRule, +// if it exists. +func delDropCGNATRangeRule(c *nftables.Conn, table *nftables.Table, chain *nftables.Chain, tunname string) error { + rule, err := createRangeRule(table, chain, tunname, tsaddr.CGNATRange(), expr.VerdictDrop) + if err != nil { + return fmt.Errorf("create rule: %w", err) + } + rule, err = findRule(c, rule) + if err != nil { + return fmt.Errorf("find rule: %v", err) + } + if rule == nil { + return nil + } + _ = c.DelRule(rule) + if err := c.Flush(); err != nil { + return fmt.Errorf("flush del rule: %w", err) + } + return nil +} + +// addReturnCGNATRangeRule adds a rule to return if the source IP is in the +// CGNAT range. +func addReturnCGNATRangeRule(c *nftables.Conn, table *nftables.Table, chain *nftables.Chain, tunname string) error { + rule, err := createRangeRule(table, chain, tunname, tsaddr.CGNATRange(), expr.VerdictReturn) + if err != nil { + return fmt.Errorf("create rule: %w", err) + } + _ = c.AddRule(rule) + if err = c.Flush(); err != nil { + return fmt.Errorf("add rule: %w", err) + } + return nil +} + +// delReturnCGNATRangeRule deletes the rule created by addReturnCGNATRangeRule, +// if it exists. +func delReturnCGNATRangeRule(c *nftables.Conn, table *nftables.Table, chain *nftables.Chain, tunname string) error { + rule, err := createRangeRule(table, chain, tunname, tsaddr.CGNATRange(), expr.VerdictReturn) + if err != nil { + return fmt.Errorf("create rule: %w", err) + } + rule, err = findRule(c, rule) + if err != nil { + return fmt.Errorf("find rule: %v", err) + } + if rule == nil { + return nil + } + _ = c.DelRule(rule) + if err := c.Flush(); err != nil { + return fmt.Errorf("flush del rule: %w", err) + } + return nil +} + // createSetSubnetRouteMarkRule creates a rule to set the subnet route // mark if the packet is from the given interface. func createSetSubnetRouteMarkRule(table *nftables.Table, chain *nftables.Chain, tunname string) (*nftables.Rule, error) { @@ -1497,6 +1588,67 @@ func (n *nftablesRunner) DelMagicsockPortRule(port uint16, network string) error return nil } +// AddExternalCGNATRules adds rules to the ts-input chain to deal with +// traffic from the CGNAT range that arrives on non-Tailscale network +// interfaces. +func (n *nftablesRunner) AddExternalCGNATRules(mode CGNATMode, tunname string) error { + conn := n.conn + + inputChain, err := getChainFromTable(conn, n.nft4.Filter, chainNameInput) + if err != nil { + return fmt.Errorf("get input chain v4: %v", err) + } + switch mode { + case CGNATModeDrop: + if err = addReturnChromeOSVMRangeRule(conn, n.nft4.Filter, inputChain, tunname); err != nil { + return fmt.Errorf("add return chromeos vm range rule v4: %w", err) + } + if err = addDropCGNATRangeRule(conn, n.nft4.Filter, inputChain, tunname); err != nil { + return fmt.Errorf("add drop cgnat range rule v4: %w", err) + } + case CGNATModeReturn: + if err = addReturnCGNATRangeRule(conn, n.nft4.Filter, inputChain, tunname); err != nil { + return fmt.Errorf("add return cgnat range rule v4: %w", err) + } + default: + return fmt.Errorf("unsupported cgnat mode %q", mode) + } + if err = conn.Flush(); err != nil { + return fmt.Errorf("flush cgnat rules v4: %w", err) + } + return nil +} + +// DelExternalCGNATRules removes the rules created by AddExternalCGNATRules, +// if they exist. +func (n *nftablesRunner) DelExternalCGNATRules(mode CGNATMode, tunname string) error { + conn := n.conn + + inputChain, err := getChainFromTable(conn, n.nft4.Filter, chainNameInput) + if err != nil { + return fmt.Errorf("get input chain v4: %v", err) + } + switch mode { + case CGNATModeDrop: + if err = delReturnChromeOSVMRangeRule(conn, n.nft4.Filter, inputChain, tunname); err != nil { + return fmt.Errorf("del return chromeos vm range rule v4: %w", err) + } + if err = delDropCGNATRangeRule(conn, n.nft4.Filter, inputChain, tunname); err != nil { + return fmt.Errorf("del drop cgnat range rule v4: %w", err) + } + case CGNATModeReturn: + if err = delReturnCGNATRangeRule(conn, n.nft4.Filter, inputChain, tunname); err != nil { + return fmt.Errorf("del return cgnat range rule v4: %w", err) + } + default: + return fmt.Errorf("unsupported mode %q", mode) + } + if err = conn.Flush(); err != nil { + return fmt.Errorf("flush cgnat rules v4: %w", err) + } + return nil +} + // createAcceptIncomingPacketRule creates a rule to accept incoming packets to // the given interface. func createAcceptIncomingPacketRule(table *nftables.Table, chain *nftables.Chain, tunname string) *nftables.Rule { @@ -1550,12 +1702,6 @@ func (n *nftablesRunner) addBase4(tunname string) error { if err != nil { return fmt.Errorf("get input chain v4: %v", err) } - if err = addReturnChromeOSVMRangeRule(conn, n.nft4.Filter, inputChain, tunname); err != nil { - return fmt.Errorf("add return chromeos vm range rule v4: %w", err) - } - if err = addDropCGNATRangeRule(conn, n.nft4.Filter, inputChain, tunname); err != nil { - return fmt.Errorf("add drop cgnat range rule v4: %w", err) - } if err = addAcceptIncomingPacketRule(conn, n.nft4.Filter, inputChain, tunname); err != nil { return fmt.Errorf("add accept incoming packet rule v4: %w", err) } @@ -1771,9 +1917,7 @@ func (n *nftablesRunner) DelSNATRule() error { } func nativeUint32(v uint32) []byte { - b := make([]byte, 4) - binary.NativeEndian.PutUint32(b, v) - return b + return nativeEndianUint32(v) } func makeStatefulRuleExprs(tunname string) []expr.Any { @@ -1960,6 +2104,24 @@ func (n *nftablesRunner) DelStatefulRule(tunname string) error { // makeConnmarkRestoreExprs creates nftables expressions to restore mark from conntrack. // Implements: ct state established,related ct mark & 0xff0000 != 0 meta mark set ct mark & 0xff0000 +// +// LIMITATION: Unlike iptables CONNMARK --restore-mark with --nfmask, this implementation +// overwrites non-Tailscale bits in the packet mark rather than merging them. This is a +// fundamental limitation of the Linux kernel's nftables expression VM (not the Go library). +// +// The nftables Bitwise expression only supports: (register & CONSTANT_MASK) ^ CONSTANT_XOR. +// It cannot perform register-to-register operations needed for perfect bit preservation: +// +// meta mark = (meta mark & ~0xff0000) | (ct mark & 0xff0000) +// ^^^^^ ^^^^^^^ +// needs meta mark and ct mark combined +// +// In contrast, iptables CONNMARK is a specialized kernel module with custom C code that +// can atomically merge marks from different sources. +// +// The conditional check (ct mark & 0xff0000 != 0) prevents the worst case of wiping all +// mark bits to zero. Perfect bit preservation would require kernel +// changes to add register-to-register bitwise operations to nftables. func makeConnmarkRestoreExprs() []expr.Any { return []expr.Any{ // Load conntrack state into register 1 @@ -1995,7 +2157,13 @@ func makeConnmarkRestoreExprs() []expr.Any { Mask: getTailscaleFwmarkMask(), Xor: []byte{0x00, 0x00, 0x00, 0x00}, }, - // Set packet mark from register 1 + // Check if masked ct mark is non-zero (critical: prevents wiping marks with 0) + &expr.Cmp{ + Op: expr.CmpOpNeq, + Register: 1, + Data: []byte{0, 0, 0, 0}, + }, + // Set packet mark from register 1 (contains ct mark & 0xff0000) &expr.Meta{ Key: expr.MetaKeyMARK, SourceRegister: true, diff --git a/util/linuxfw/nftables_runner_test.go b/util/linuxfw/nftables_runner_test.go index 17945e245..5aa418378 100644 --- a/util/linuxfw/nftables_runner_test.go +++ b/util/linuxfw/nftables_runner_test.go @@ -10,7 +10,6 @@ import ( "errors" "fmt" "net/netip" - "os" "runtime" "slices" "strings" @@ -299,7 +298,7 @@ func TestAddSetSubnetRouteMarkRule(t *testing.T) { // nft add chain ip ts-filter-test ts-forward-test { type filter hook forward priority 0\; } []byte("\x02\x00\x00\x00\x13\x00\x01\x00\x74\x73\x2d\x66\x69\x6c\x74\x65\x72\x2d\x74\x65\x73\x74\x00\x00\x14\x00\x03\x00\x74\x73\x2d\x66\x6f\x72\x77\x61\x72\x64\x2d\x74\x65\x73\x74\x00\x14\x00\x04\x80\x08\x00\x01\x00\x00\x00\x00\x02\x08\x00\x02\x00\x00\x00\x00\x00\x0b\x00\x07\x00\x66\x69\x6c\x74\x65\x72\x00\x00"), // nft add rule ip ts-filter-test ts-forward-test iifname "testTunn" counter meta mark set mark and 0xff00ffff xor 0x40000 - []byte("\x02\x00\x00\x00\x13\x00\x01\x00\x74\x73\x2d\x66\x69\x6c\x74\x65\x72\x2d\x74\x65\x73\x74\x00\x00\x14\x00\x02\x00\x74\x73\x2d\x66\x6f\x72\x77\x61\x72\x64\x2d\x74\x65\x73\x74\x00\x10\x01\x04\x80\x24\x00\x01\x80\x09\x00\x01\x00\x6d\x65\x74\x61\x00\x00\x00\x00\x14\x00\x02\x80\x08\x00\x02\x00\x00\x00\x00\x06\x08\x00\x01\x00\x00\x00\x00\x01\x30\x00\x01\x80\x08\x00\x01\x00\x63\x6d\x70\x00\x24\x00\x02\x80\x08\x00\x01\x00\x00\x00\x00\x01\x08\x00\x02\x00\x00\x00\x00\x00\x10\x00\x03\x80\x0c\x00\x01\x00\x74\x65\x73\x74\x54\x75\x6e\x6e\x2c\x00\x01\x80\x0c\x00\x01\x00\x63\x6f\x75\x6e\x74\x65\x72\x00\x1c\x00\x02\x80\x0c\x00\x01\x00\x00\x00\x00\x00\x00\x00\x00\x00\x0c\x00\x02\x00\x00\x00\x00\x00\x00\x00\x00\x00\x24\x00\x01\x80\x09\x00\x01\x00\x6d\x65\x74\x61\x00\x00\x00\x00\x14\x00\x02\x80\x08\x00\x02\x00\x00\x00\x00\x03\x08\x00\x01\x00\x00\x00\x00\x01\x44\x00\x01\x80\x0c\x00\x01\x00\x62\x69\x74\x77\x69\x73\x65\x00\x34\x00\x02\x80\x08\x00\x01\x00\x00\x00\x00\x01\x08\x00\x02\x00\x00\x00\x00\x01\x08\x00\x03\x00\x00\x00\x00\x04\x0c\x00\x04\x80\x08\x00\x01\x00\xff\x00\xff\xff\x0c\x00\x05\x80\x08\x00\x01\x00\x00\x04\x00\x00\x24\x00\x01\x80\x09\x00\x01\x00\x6d\x65\x74\x61\x00\x00\x00\x00\x14\x00\x02\x80\x08\x00\x02\x00\x00\x00\x00\x03\x08\x00\x03\x00\x00\x00\x00\x01"), + []byte("\x02\x00\x00\x00\x13\x00\x01\x00\x74\x73\x2d\x66\x69\x6c\x74\x65\x72\x2d\x74\x65\x73\x74\x00\x00\x14\x00\x02\x00\x74\x73\x2d\x66\x6f\x72\x77\x61\x72\x64\x2d\x74\x65\x73\x74\x00\x10\x01\x04\x80\x24\x00\x01\x80\x09\x00\x01\x00\x6d\x65\x74\x61\x00\x00\x00\x00\x14\x00\x02\x80\x08\x00\x02\x00\x00\x00\x00\x06\x08\x00\x01\x00\x00\x00\x00\x01\x30\x00\x01\x80\x08\x00\x01\x00\x63\x6d\x70\x00\x24\x00\x02\x80\x08\x00\x01\x00\x00\x00\x00\x01\x08\x00\x02\x00\x00\x00\x00\x00\x10\x00\x03\x80\x0c\x00\x01\x00\x74\x65\x73\x74\x54\x75\x6e\x6e\x2c\x00\x01\x80\x0c\x00\x01\x00\x63\x6f\x75\x6e\x74\x65\x72\x00\x1c\x00\x02\x80\x0c\x00\x01\x00\x00\x00\x00\x00\x00\x00\x00\x00\x0c\x00\x02\x00\x00\x00\x00\x00\x00\x00\x00\x00\x24\x00\x01\x80\x09\x00\x01\x00\x6d\x65\x74\x61\x00\x00\x00\x00\x14\x00\x02\x80\x08\x00\x02\x00\x00\x00\x00\x03\x08\x00\x01\x00\x00\x00\x00\x01\x44\x00\x01\x80\x0c\x00\x01\x00\x62\x69\x74\x77\x69\x73\x65\x00\x34\x00\x02\x80\x08\x00\x01\x00\x00\x00\x00\x01\x08\x00\x02\x00\x00\x00\x00\x01\x08\x00\x03\x00\x00\x00\x00\x04\x0c\x00\x04\x80\x08\x00\x01\x00\xff\xff\x00\xff\x0c\x00\x05\x80\x08\x00\x01\x00\x00\x00\x04\x00\x24\x00\x01\x80\x09\x00\x01\x00\x6d\x65\x74\x61\x00\x00\x00\x00\x14\x00\x02\x80\x08\x00\x02\x00\x00\x00\x00\x03\x08\x00\x03\x00\x00\x00\x00\x01"), // batch end []byte("\x00\x00\x00\x0a"), } @@ -427,7 +426,7 @@ func TestAddMatchSubnetRouteMarkRuleMasq(t *testing.T) { // nft add chain ip ts-nat-test ts-postrouting-test { type nat hook postrouting priority 100; } []byte("\x02\x00\x00\x00\x10\x00\x01\x00\x74\x73\x2d\x6e\x61\x74\x2d\x74\x65\x73\x74\x00\x18\x00\x03\x00\x74\x73\x2d\x70\x6f\x73\x74\x72\x6f\x75\x74\x69\x6e\x67\x2d\x74\x65\x73\x74\x00\x14\x00\x04\x80\x08\x00\x01\x00\x00\x00\x00\x04\x08\x00\x02\x00\x00\x00\x00\x64\x08\x00\x07\x00\x6e\x61\x74\x00"), // nft add rule ip ts-nat-test ts-postrouting-test meta mark & 0x00ff0000 == 0x00040000 counter masquerade - []byte("\x02\x00\x00\x00\x10\x00\x01\x00\x74\x73\x2d\x6e\x61\x74\x2d\x74\x65\x73\x74\x00\x18\x00\x02\x00\x74\x73\x2d\x70\x6f\x73\x74\x72\x6f\x75\x74\x69\x6e\x67\x2d\x74\x65\x73\x74\x00\xd8\x00\x04\x80\x24\x00\x01\x80\x09\x00\x01\x00\x6d\x65\x74\x61\x00\x00\x00\x00\x14\x00\x02\x80\x08\x00\x02\x00\x00\x00\x00\x03\x08\x00\x01\x00\x00\x00\x00\x01\x44\x00\x01\x80\x0c\x00\x01\x00\x62\x69\x74\x77\x69\x73\x65\x00\x34\x00\x02\x80\x08\x00\x01\x00\x00\x00\x00\x01\x08\x00\x02\x00\x00\x00\x00\x01\x08\x00\x03\x00\x00\x00\x00\x04\x0c\x00\x04\x80\x08\x00\x01\x00\x00\xff\x00\x00\x0c\x00\x05\x80\x08\x00\x01\x00\x00\x00\x00\x00\x2c\x00\x01\x80\x08\x00\x01\x00\x63\x6d\x70\x00\x20\x00\x02\x80\x08\x00\x01\x00\x00\x00\x00\x01\x08\x00\x02\x00\x00\x00\x00\x00\x0c\x00\x03\x80\x08\x00\x01\x00\x00\x04\x00\x00\x2c\x00\x01\x80\x0c\x00\x01\x00\x63\x6f\x75\x6e\x74\x65\x72\x00\x1c\x00\x02\x80\x0c\x00\x01\x00\x00\x00\x00\x00\x00\x00\x00\x00\x0c\x00\x02\x00\x00\x00\x00\x00\x00\x00\x00\x00\x14\x00\x01\x80\x09\x00\x01\x00\x6d\x61\x73\x71\x00\x00\x00\x00\x04\x00\x02\x80"), + []byte("\x02\x00\x00\x00\x10\x00\x01\x00\x74\x73\x2d\x6e\x61\x74\x2d\x74\x65\x73\x74\x00\x18\x00\x02\x00\x74\x73\x2d\x70\x6f\x73\x74\x72\x6f\x75\x74\x69\x6e\x67\x2d\x74\x65\x73\x74\x00\xd8\x00\x04\x80\x24\x00\x01\x80\x09\x00\x01\x00\x6d\x65\x74\x61\x00\x00\x00\x00\x14\x00\x02\x80\x08\x00\x02\x00\x00\x00\x00\x03\x08\x00\x01\x00\x00\x00\x00\x01\x44\x00\x01\x80\x0c\x00\x01\x00\x62\x69\x74\x77\x69\x73\x65\x00\x34\x00\x02\x80\x08\x00\x01\x00\x00\x00\x00\x01\x08\x00\x02\x00\x00\x00\x00\x01\x08\x00\x03\x00\x00\x00\x00\x04\x0c\x00\x04\x80\x08\x00\x01\x00\x00\x00\xff\x00\x0c\x00\x05\x80\x08\x00\x01\x00\x00\x00\x00\x00\x2c\x00\x01\x80\x08\x00\x01\x00\x63\x6d\x70\x00\x20\x00\x02\x80\x08\x00\x01\x00\x00\x00\x00\x01\x08\x00\x02\x00\x00\x00\x00\x00\x0c\x00\x03\x80\x08\x00\x01\x00\x00\x00\x04\x00\x2c\x00\x01\x80\x0c\x00\x01\x00\x63\x6f\x75\x6e\x74\x65\x72\x00\x1c\x00\x02\x80\x0c\x00\x01\x00\x00\x00\x00\x00\x00\x00\x00\x00\x0c\x00\x02\x00\x00\x00\x00\x00\x00\x00\x00\x00\x14\x00\x01\x80\x09\x00\x01\x00\x6d\x61\x73\x71\x00\x00\x00\x00\x04\x00\x02\x80"), // batch end []byte("\x00\x00\x00\x0a"), } @@ -498,7 +497,7 @@ func TestAddMatchSubnetRouteMarkRuleAccept(t *testing.T) { // nft add chain ip ts-filter-test ts-forward-test { type filter hook forward priority 0\; } []byte("\x02\x00\x00\x00\x13\x00\x01\x00\x74\x73\x2d\x66\x69\x6c\x74\x65\x72\x2d\x74\x65\x73\x74\x00\x00\x14\x00\x03\x00\x74\x73\x2d\x66\x6f\x72\x77\x61\x72\x64\x2d\x74\x65\x73\x74\x00\x14\x00\x04\x80\x08\x00\x01\x00\x00\x00\x00\x02\x08\x00\x02\x00\x00\x00\x00\x00\x0b\x00\x07\x00\x66\x69\x6c\x74\x65\x72\x00\x00"), // nft add rule ip ts-filter-test ts-forward-test meta mark and 0x00ff0000 eq 0x00040000 counter accept - []byte("\x02\x00\x00\x00\x13\x00\x01\x00\x74\x73\x2d\x66\x69\x6c\x74\x65\x72\x2d\x74\x65\x73\x74\x00\x00\x14\x00\x02\x00\x74\x73\x2d\x66\x6f\x72\x77\x61\x72\x64\x2d\x74\x65\x73\x74\x00\xf4\x00\x04\x80\x24\x00\x01\x80\x09\x00\x01\x00\x6d\x65\x74\x61\x00\x00\x00\x00\x14\x00\x02\x80\x08\x00\x02\x00\x00\x00\x00\x03\x08\x00\x01\x00\x00\x00\x00\x01\x44\x00\x01\x80\x0c\x00\x01\x00\x62\x69\x74\x77\x69\x73\x65\x00\x34\x00\x02\x80\x08\x00\x01\x00\x00\x00\x00\x01\x08\x00\x02\x00\x00\x00\x00\x01\x08\x00\x03\x00\x00\x00\x00\x04\x0c\x00\x04\x80\x08\x00\x01\x00\x00\xff\x00\x00\x0c\x00\x05\x80\x08\x00\x01\x00\x00\x00\x00\x00\x2c\x00\x01\x80\x08\x00\x01\x00\x63\x6d\x70\x00\x20\x00\x02\x80\x08\x00\x01\x00\x00\x00\x00\x01\x08\x00\x02\x00\x00\x00\x00\x00\x0c\x00\x03\x80\x08\x00\x01\x00\x00\x04\x00\x00\x2c\x00\x01\x80\x0c\x00\x01\x00\x63\x6f\x75\x6e\x74\x65\x72\x00\x1c\x00\x02\x80\x0c\x00\x01\x00\x00\x00\x00\x00\x00\x00\x00\x00\x0c\x00\x02\x00\x00\x00\x00\x00\x00\x00\x00\x00\x30\x00\x01\x80\x0e\x00\x01\x00\x69\x6d\x6d\x65\x64\x69\x61\x74\x65\x00\x00\x00\x1c\x00\x02\x80\x08\x00\x01\x00\x00\x00\x00\x00\x10\x00\x02\x80\x0c\x00\x02\x80\x08\x00\x01\x00\x00\x00\x00\x01"), + []byte("\x02\x00\x00\x00\x13\x00\x01\x00\x74\x73\x2d\x66\x69\x6c\x74\x65\x72\x2d\x74\x65\x73\x74\x00\x00\x14\x00\x02\x00\x74\x73\x2d\x66\x6f\x72\x77\x61\x72\x64\x2d\x74\x65\x73\x74\x00\xf4\x00\x04\x80\x24\x00\x01\x80\x09\x00\x01\x00\x6d\x65\x74\x61\x00\x00\x00\x00\x14\x00\x02\x80\x08\x00\x02\x00\x00\x00\x00\x03\x08\x00\x01\x00\x00\x00\x00\x01\x44\x00\x01\x80\x0c\x00\x01\x00\x62\x69\x74\x77\x69\x73\x65\x00\x34\x00\x02\x80\x08\x00\x01\x00\x00\x00\x00\x01\x08\x00\x02\x00\x00\x00\x00\x01\x08\x00\x03\x00\x00\x00\x00\x04\x0c\x00\x04\x80\x08\x00\x01\x00\x00\x00\xff\x00\x0c\x00\x05\x80\x08\x00\x01\x00\x00\x00\x00\x00\x2c\x00\x01\x80\x08\x00\x01\x00\x63\x6d\x70\x00\x20\x00\x02\x80\x08\x00\x01\x00\x00\x00\x00\x01\x08\x00\x02\x00\x00\x00\x00\x00\x0c\x00\x03\x80\x08\x00\x01\x00\x00\x00\x04\x00\x2c\x00\x01\x80\x0c\x00\x01\x00\x63\x6f\x75\x6e\x74\x65\x72\x00\x1c\x00\x02\x80\x0c\x00\x01\x00\x00\x00\x00\x00\x00\x00\x00\x00\x0c\x00\x02\x00\x00\x00\x00\x00\x00\x00\x00\x00\x30\x00\x01\x80\x0e\x00\x01\x00\x69\x6d\x6d\x65\x64\x69\x61\x74\x65\x00\x00\x00\x1c\x00\x02\x80\x08\x00\x01\x00\x00\x00\x00\x00\x10\x00\x02\x80\x0c\x00\x02\x80\x08\x00\x01\x00\x00\x00\x00\x01"), // batch end []byte("\x00\x00\x00\x0a"), } @@ -522,10 +521,7 @@ func TestAddMatchSubnetRouteMarkRuleAccept(t *testing.T) { func newSysConn(t *testing.T) *nftables.Conn { t.Helper() - if os.Geteuid() != 0 { - t.Skip(t.Name(), " requires privileges to create a namespace in order to run") - return nil - } + tstest.RequireRoot(t) runtime.LockOSThread() @@ -637,7 +633,7 @@ func TestAddAndDelNetfilterChains(t *testing.T) { func getTsChains( conn *nftables.Conn, proto nftables.TableFamily) (*nftables.Chain, *nftables.Chain, *nftables.Chain, error) { - chains, err := conn.ListChainsOfTableFamily(nftables.TableFamilyIPv4) + chains, err := conn.ListChainsOfTableFamily(proto) if err != nil { return nil, nil, nil, fmt.Errorf("list chains failed: %w", err) } @@ -662,17 +658,7 @@ func findV4BaseRules( forwChain *nftables.Chain, tunname string) ([]*nftables.Rule, error) { want := []*nftables.Rule{} - rule, err := createRangeRule(inpChain.Table, inpChain, tunname, tsaddr.ChromeOSVMRange(), expr.VerdictReturn) - if err != nil { - return nil, fmt.Errorf("create rule: %w", err) - } - want = append(want, rule) - rule, err = createRangeRule(inpChain.Table, inpChain, tunname, tsaddr.CGNATRange(), expr.VerdictDrop) - if err != nil { - return nil, fmt.Errorf("create rule: %w", err) - } - want = append(want, rule) - rule, err = createDropOutgoingPacketFromCGNATRangeRuleWithTunname(forwChain.Table, forwChain, tunname) + rule, err := createDropOutgoingPacketFromCGNATRangeRuleWithTunname(forwChain.Table, forwChain, tunname) if err != nil { return nil, fmt.Errorf("create rule: %w", err) } @@ -749,7 +735,7 @@ func TestNFTAddAndDelNetfilterBase(t *testing.T) { if err != nil { t.Fatalf("getTsChains() failed: %v", err) } - checkChainRules(t, conn, inputV4, 3) + checkChainRules(t, conn, inputV4, 1) checkChainRules(t, conn, forwardV4, 4) checkChainRules(t, conn, postroutingV4, 0) @@ -767,8 +753,8 @@ func TestNFTAddAndDelNetfilterBase(t *testing.T) { if err != nil { t.Fatalf("getTsChains() failed: %v", err) } - checkChainRules(t, conn, inputV6, 3) - checkChainRules(t, conn, forwardV6, 4) + checkChainRules(t, conn, inputV6, 1) + checkChainRules(t, conn, forwardV6, 3) checkChainRules(t, conn, postroutingV6, 0) _, err = findCommonBaseRules(conn, forwardV6, "testTunn") @@ -787,6 +773,92 @@ func TestNFTAddAndDelNetfilterBase(t *testing.T) { } } +func findCGNATRules( + conn *nftables.Conn, + inpChain *nftables.Chain, + mode CGNATMode, + tunname string, +) error { + want := []*nftables.Rule{} + switch mode { + case CGNATModeDrop: + rule, err := createRangeRule(inpChain.Table, inpChain, tunname, tsaddr.ChromeOSVMRange(), expr.VerdictReturn) + if err != nil { + return fmt.Errorf("create rule: %w", err) + } + want = append(want, rule) + rule, err = createRangeRule(inpChain.Table, inpChain, tunname, tsaddr.CGNATRange(), expr.VerdictDrop) + if err != nil { + return fmt.Errorf("create rule: %w", err) + } + want = append(want, rule) + case CGNATModeReturn: + rule, err := createRangeRule(inpChain.Table, inpChain, tunname, tsaddr.CGNATRange(), expr.VerdictReturn) + if err != nil { + return fmt.Errorf("create rule: %w", err) + } + want = append(want, rule) + default: + return fmt.Errorf("unknown mode %q", mode) + } + for _, rule := range want { + _, err := findRule(conn, rule) + if err != nil { + return fmt.Errorf("find rule: %w", err) + } + } + return nil +} + +func TestNFTAddAndDelCGNATRules(t *testing.T) { + modes := []CGNATMode{CGNATModeDrop, CGNATModeReturn} + for _, mode := range modes { + t.Run(string(mode), func(t *testing.T) { + conn := newSysConn(t) + + runner := newFakeNftablesRunnerWithConn(t, conn, false) + + if err := runner.AddChains(); err != nil { + t.Fatalf("AddChains() failed: %v", err) + } + defer runner.DelChains() + + inputV4, _, _, err := getTsChains(conn, nftables.TableFamilyIPv4) + if err != nil { + t.Fatalf("getTsChains() failed: %v", err) + } + + checkChainRules(t, conn, inputV4, 0) + + tunname := "tun0" + + if err := runner.AddExternalCGNATRules(mode, tunname); err != nil { + t.Fatalf("add rules: %v", err) + } + + switch mode { + case CGNATModeDrop: + checkChainRules(t, conn, inputV4, 2) + case CGNATModeReturn: + checkChainRules(t, conn, inputV4, 1) + default: + t.Fatalf("unknown mode %q", mode) + } + + if err := findCGNATRules(conn, inputV4, mode, tunname); err != nil { + t.Fatalf("find rules: %v", err) + } + + if err := runner.DelExternalCGNATRules(mode, tunname); err != nil { + t.Fatalf("delete rules: %v", err) + } + + // Verify that all the rules have been deleted (0 remaining). + checkChainRules(t, conn, inputV4, 0) + }) + } +} + func findLoopBackRule(conn *nftables.Conn, proto nftables.TableFamily, table *nftables.Table, chain *nftables.Chain, addr netip.Addr) (*nftables.Rule, error) { matchingAddr := addr.AsSlice() saddrExpr, err := newLoadSaddrExpr(proto, 1) @@ -849,16 +921,16 @@ func TestNFTAddAndDelLoopbackRule(t *testing.T) { runner.AddBase("testTunn") defer runner.DelBase() - checkChainRules(t, conn, inputV4, 3) - checkChainRules(t, conn, inputV6, 3) + checkChainRules(t, conn, inputV4, 1) + checkChainRules(t, conn, inputV6, 1) addr := netip.MustParseAddr("192.168.0.2") addrV6 := netip.MustParseAddr("2001:db8::2") runner.AddLoopbackRule(addr) runner.AddLoopbackRule(addrV6) - checkChainRules(t, conn, inputV4, 4) - checkChainRules(t, conn, inputV6, 4) + checkChainRules(t, conn, inputV4, 2) + checkChainRules(t, conn, inputV6, 2) existingLoopBackRule, err := findLoopBackRule(conn, nftables.TableFamilyIPv4, runner.nft4.Filter, inputV4, addr) if err != nil { @@ -881,8 +953,8 @@ func TestNFTAddAndDelLoopbackRule(t *testing.T) { runner.DelLoopbackRule(addr) runner.DelLoopbackRule(addrV6) - checkChainRules(t, conn, inputV4, 3) - checkChainRules(t, conn, inputV6, 3) + checkChainRules(t, conn, inputV4, 1) + checkChainRules(t, conn, inputV6, 1) } func TestNFTAddAndDelHookRule(t *testing.T) { @@ -1246,7 +1318,7 @@ func TestMakeConnmarkRestoreExprs(t *testing.T) { // nft add chain ip mangle PREROUTING { type filter hook prerouting priority mangle; } []byte("\x02\x00\x00\x00\x0b\x00\x01\x00\x6d\x61\x6e\x67\x6c\x65\x00\x00\x0f\x00\x03\x00\x50\x52\x45\x52\x4f\x55\x54\x49\x4e\x47\x00\x00\x14\x00\x04\x80\x08\x00\x01\x00\x00\x00\x00\x00\x08\x00\x02\x00\xff\xff\xff\x6a\x0b\x00\x07\x00\x66\x69\x6c\x74\x65\x72\x00\x00"), // nft add rule ip mangle PREROUTING ct state established,related ct mark & 0xff0000 != 0 meta mark set ct mark & 0xff0000 - []byte("\x02\x00\x00\x00\x0b\x00\x01\x00\x6d\x61\x6e\x67\x6c\x65\x00\x00\x0f\x00\x02\x00\x50\x52\x45\x52\x4f\x55\x54\x49\x4e\x47\x00\x00\x1c\x01\x04\x80\x20\x00\x01\x80\x07\x00\x01\x00\x63\x74\x00\x00\x14\x00\x02\x80\x08\x00\x02\x00\x00\x00\x00\x00\x08\x00\x01\x00\x00\x00\x00\x01\x44\x00\x01\x80\x0c\x00\x01\x00\x62\x69\x74\x77\x69\x73\x65\x00\x34\x00\x02\x80\x08\x00\x01\x00\x00\x00\x00\x01\x08\x00\x02\x00\x00\x00\x00\x01\x08\x00\x03\x00\x00\x00\x00\x04\x0c\x00\x04\x80\x08\x00\x01\x00\x06\x00\x00\x00\x0c\x00\x05\x80\x08\x00\x01\x00\x00\x00\x00\x00\x2c\x00\x01\x80\x08\x00\x01\x00\x63\x6d\x70\x00\x20\x00\x02\x80\x08\x00\x01\x00\x00\x00\x00\x01\x08\x00\x02\x00\x00\x00\x00\x01\x0c\x00\x03\x80\x08\x00\x01\x00\x00\x00\x00\x00\x20\x00\x01\x80\x07\x00\x01\x00\x63\x74\x00\x00\x14\x00\x02\x80\x08\x00\x02\x00\x00\x00\x00\x03\x08\x00\x01\x00\x00\x00\x00\x01\x44\x00\x01\x80\x0c\x00\x01\x00\x62\x69\x74\x77\x69\x73\x65\x00\x34\x00\x02\x80\x08\x00\x01\x00\x00\x00\x00\x01\x08\x00\x02\x00\x00\x00\x00\x01\x08\x00\x03\x00\x00\x00\x00\x04\x0c\x00\x04\x80\x08\x00\x01\x00\x00\xff\x00\x00\x0c\x00\x05\x80\x08\x00\x01\x00\x00\x00\x00\x00\x24\x00\x01\x80\x09\x00\x01\x00\x6d\x65\x74\x61\x00\x00\x00\x00\x14\x00\x02\x80\x08\x00\x02\x00\x00\x00\x00\x03\x08\x00\x03\x00\x00\x00\x00\x01"), + []byte("\x02\x00\x00\x00\x0b\x00\x01\x00\x6d\x61\x6e\x67\x6c\x65\x00\x00\x0f\x00\x02\x00\x50\x52\x45\x52\x4f\x55\x54\x49\x4e\x47\x00\x00\x48\x01\x04\x80\x20\x00\x01\x80\x07\x00\x01\x00\x63\x74\x00\x00\x14\x00\x02\x80\x08\x00\x02\x00\x00\x00\x00\x00\x08\x00\x01\x00\x00\x00\x00\x01\x44\x00\x01\x80\x0c\x00\x01\x00\x62\x69\x74\x77\x69\x73\x65\x00\x34\x00\x02\x80\x08\x00\x01\x00\x00\x00\x00\x01\x08\x00\x02\x00\x00\x00\x00\x01\x08\x00\x03\x00\x00\x00\x00\x04\x0c\x00\x04\x80\x08\x00\x01\x00\x06\x00\x00\x00\x0c\x00\x05\x80\x08\x00\x01\x00\x00\x00\x00\x00\x2c\x00\x01\x80\x08\x00\x01\x00\x63\x6d\x70\x00\x20\x00\x02\x80\x08\x00\x01\x00\x00\x00\x00\x01\x08\x00\x02\x00\x00\x00\x00\x01\x0c\x00\x03\x80\x08\x00\x01\x00\x00\x00\x00\x00\x20\x00\x01\x80\x07\x00\x01\x00\x63\x74\x00\x00\x14\x00\x02\x80\x08\x00\x02\x00\x00\x00\x00\x03\x08\x00\x01\x00\x00\x00\x00\x01\x44\x00\x01\x80\x0c\x00\x01\x00\x62\x69\x74\x77\x69\x73\x65\x00\x34\x00\x02\x80\x08\x00\x01\x00\x00\x00\x00\x01\x08\x00\x02\x00\x00\x00\x00\x01\x08\x00\x03\x00\x00\x00\x00\x04\x0c\x00\x04\x80\x08\x00\x01\x00\x00\x00\xff\x00\x0c\x00\x05\x80\x08\x00\x01\x00\x00\x00\x00\x00\x2c\x00\x01\x80\x08\x00\x01\x00\x63\x6d\x70\x00\x20\x00\x02\x80\x08\x00\x01\x00\x00\x00\x00\x01\x08\x00\x02\x00\x00\x00\x00\x01\x0c\x00\x03\x80\x08\x00\x01\x00\x00\x00\x00\x00\x24\x00\x01\x80\x09\x00\x01\x00\x6d\x65\x74\x61\x00\x00\x00\x00\x14\x00\x02\x80\x08\x00\x02\x00\x00\x00\x00\x03\x08\x00\x03\x00\x00\x00\x00\x01"), // batch end []byte("\x00\x00\x00\x0a"), } @@ -1287,7 +1359,7 @@ func TestMakeConnmarkSaveExprs(t *testing.T) { // nft add chain ip mangle OUTPUT { type route hook output priority mangle; } []byte("\x02\x00\x00\x00\x0b\x00\x01\x00\x6d\x61\x6e\x67\x6c\x65\x00\x00\x0b\x00\x03\x00\x4f\x55\x54\x50\x55\x54\x00\x00\x14\x00\x04\x80\x08\x00\x01\x00\x00\x00\x00\x03\x08\x00\x02\x00\xff\xff\xff\x6a\x0a\x00\x07\x00\x72\x6f\x75\x74\x65\x00\x00\x00"), // nft add rule ip mangle OUTPUT ct state new meta mark & 0xff0000 != 0 ct mark set meta mark & 0xff0000 - []byte("\x02\x00\x00\x00\x0b\x00\x01\x00\x6d\x61\x6e\x67\x6c\x65\x00\x00\x0b\x00\x02\x00\x4f\x55\x54\x50\x55\x54\x00\x00\xb0\x01\x04\x80\x20\x00\x01\x80\x07\x00\x01\x00\x63\x74\x00\x00\x14\x00\x02\x80\x08\x00\x02\x00\x00\x00\x00\x00\x08\x00\x01\x00\x00\x00\x00\x01\x44\x00\x01\x80\x0c\x00\x01\x00\x62\x69\x74\x77\x69\x73\x65\x00\x34\x00\x02\x80\x08\x00\x01\x00\x00\x00\x00\x01\x08\x00\x02\x00\x00\x00\x00\x01\x08\x00\x03\x00\x00\x00\x00\x04\x0c\x00\x04\x80\x08\x00\x01\x00\x08\x00\x00\x00\x0c\x00\x05\x80\x08\x00\x01\x00\x00\x00\x00\x00\x2c\x00\x01\x80\x08\x00\x01\x00\x63\x6d\x70\x00\x20\x00\x02\x80\x08\x00\x01\x00\x00\x00\x00\x01\x08\x00\x02\x00\x00\x00\x00\x01\x0c\x00\x03\x80\x08\x00\x01\x00\x00\x00\x00\x00\x24\x00\x01\x80\x09\x00\x01\x00\x6d\x65\x74\x61\x00\x00\x00\x00\x14\x00\x02\x80\x08\x00\x02\x00\x00\x00\x00\x03\x08\x00\x01\x00\x00\x00\x00\x01\x44\x00\x01\x80\x0c\x00\x01\x00\x62\x69\x74\x77\x69\x73\x65\x00\x34\x00\x02\x80\x08\x00\x01\x00\x00\x00\x00\x01\x08\x00\x02\x00\x00\x00\x00\x01\x08\x00\x03\x00\x00\x00\x00\x04\x0c\x00\x04\x80\x08\x00\x01\x00\x00\xff\x00\x00\x0c\x00\x05\x80\x08\x00\x01\x00\x00\x00\x00\x00\x2c\x00\x01\x80\x08\x00\x01\x00\x63\x6d\x70\x00\x20\x00\x02\x80\x08\x00\x01\x00\x00\x00\x00\x01\x08\x00\x02\x00\x00\x00\x00\x01\x0c\x00\x03\x80\x08\x00\x01\x00\x00\x00\x00\x00\x24\x00\x01\x80\x09\x00\x01\x00\x6d\x65\x74\x61\x00\x00\x00\x00\x14\x00\x02\x80\x08\x00\x02\x00\x00\x00\x00\x03\x08\x00\x01\x00\x00\x00\x00\x01\x44\x00\x01\x80\x0c\x00\x01\x00\x62\x69\x74\x77\x69\x73\x65\x00\x34\x00\x02\x80\x08\x00\x01\x00\x00\x00\x00\x01\x08\x00\x02\x00\x00\x00\x00\x01\x08\x00\x03\x00\x00\x00\x00\x04\x0c\x00\x04\x80\x08\x00\x01\x00\x00\xff\x00\x00\x0c\x00\x05\x80\x08\x00\x01\x00\x00\x00\x00\x00\x20\x00\x01\x80\x07\x00\x01\x00\x63\x74\x00\x00\x14\x00\x02\x80\x08\x00\x02\x00\x00\x00\x00\x03\x08\x00\x04\x00\x00\x00\x00\x01"), + []byte("\x02\x00\x00\x00\x0b\x00\x01\x00\x6d\x61\x6e\x67\x6c\x65\x00\x00\x0b\x00\x02\x00\x4f\x55\x54\x50\x55\x54\x00\x00\xb0\x01\x04\x80\x20\x00\x01\x80\x07\x00\x01\x00\x63\x74\x00\x00\x14\x00\x02\x80\x08\x00\x02\x00\x00\x00\x00\x00\x08\x00\x01\x00\x00\x00\x00\x01\x44\x00\x01\x80\x0c\x00\x01\x00\x62\x69\x74\x77\x69\x73\x65\x00\x34\x00\x02\x80\x08\x00\x01\x00\x00\x00\x00\x01\x08\x00\x02\x00\x00\x00\x00\x01\x08\x00\x03\x00\x00\x00\x00\x04\x0c\x00\x04\x80\x08\x00\x01\x00\x08\x00\x00\x00\x0c\x00\x05\x80\x08\x00\x01\x00\x00\x00\x00\x00\x2c\x00\x01\x80\x08\x00\x01\x00\x63\x6d\x70\x00\x20\x00\x02\x80\x08\x00\x01\x00\x00\x00\x00\x01\x08\x00\x02\x00\x00\x00\x00\x01\x0c\x00\x03\x80\x08\x00\x01\x00\x00\x00\x00\x00\x24\x00\x01\x80\x09\x00\x01\x00\x6d\x65\x74\x61\x00\x00\x00\x00\x14\x00\x02\x80\x08\x00\x02\x00\x00\x00\x00\x03\x08\x00\x01\x00\x00\x00\x00\x01\x44\x00\x01\x80\x0c\x00\x01\x00\x62\x69\x74\x77\x69\x73\x65\x00\x34\x00\x02\x80\x08\x00\x01\x00\x00\x00\x00\x01\x08\x00\x02\x00\x00\x00\x00\x01\x08\x00\x03\x00\x00\x00\x00\x04\x0c\x00\x04\x80\x08\x00\x01\x00\x00\x00\xff\x00\x0c\x00\x05\x80\x08\x00\x01\x00\x00\x00\x00\x00\x2c\x00\x01\x80\x08\x00\x01\x00\x63\x6d\x70\x00\x20\x00\x02\x80\x08\x00\x01\x00\x00\x00\x00\x01\x08\x00\x02\x00\x00\x00\x00\x01\x0c\x00\x03\x80\x08\x00\x01\x00\x00\x00\x00\x00\x24\x00\x01\x80\x09\x00\x01\x00\x6d\x65\x74\x61\x00\x00\x00\x00\x14\x00\x02\x80\x08\x00\x02\x00\x00\x00\x00\x03\x08\x00\x01\x00\x00\x00\x00\x01\x44\x00\x01\x80\x0c\x00\x01\x00\x62\x69\x74\x77\x69\x73\x65\x00\x34\x00\x02\x80\x08\x00\x01\x00\x00\x00\x00\x01\x08\x00\x02\x00\x00\x00\x00\x01\x08\x00\x03\x00\x00\x00\x00\x04\x0c\x00\x04\x80\x08\x00\x01\x00\x00\x00\xff\x00\x0c\x00\x05\x80\x08\x00\x01\x00\x00\x00\x00\x00\x20\x00\x01\x80\x07\x00\x01\x00\x63\x74\x00\x00\x14\x00\x02\x80\x08\x00\x02\x00\x00\x00\x00\x03\x08\x00\x04\x00\x00\x00\x00\x01"), // batch end []byte("\x00\x00\x00\x0a"), } @@ -1313,3 +1385,39 @@ func TestMakeConnmarkSaveExprs(t *testing.T) { t.Fatalf("Flush() failed: %v", err) } } + +// TestGetOrCreateChainNilHooknum verifies that getOrCreateChain returns a clear +// error when a ts- chain exists but has nil Hooknum/Priority, which happens when +// the kernel lacks nftables support (CONFIG_NF_TABLES). +func TestGetOrCreateChainNilHooknum(t *testing.T) { + conn := newSysConn(t) + + table := conn.AddTable(&nftables.Table{ + Family: nftables.TableFamilyIPv4, + Name: "ts-filter-test", + }) + // Add a ts- chain without hooknum/priority (regular chain), simulating + // the broken state returned by a kernel without nftables support. + conn.AddChain(&nftables.Chain{ + Name: "ts-input", + Table: table, + }) + if err := conn.Flush(); err != nil { + t.Fatalf("Flush() failed: %v", err) + } + + // Now try getOrCreateChain expecting a base chain with hooknum/priority. + _, err := getOrCreateChain(conn, chainInfo{ + table: table, + name: "ts-input", + chainType: nftables.ChainTypeFilter, + chainHook: nftables.ChainHookInput, + chainPriority: nftables.ChainPriorityFilter, + }) + if err == nil { + t.Fatal("expected error for chain with nil hooknum/priority, got nil") + } + if !strings.Contains(err.Error(), "nil hooknum") { + t.Fatalf("unexpected error: %v", err) + } +} diff --git a/util/pidowner/pidowner.go b/util/pidowner/pidowner.go deleted file mode 100644 index cec92ba36..000000000 --- a/util/pidowner/pidowner.go +++ /dev/null @@ -1,24 +0,0 @@ -// Copyright (c) Tailscale Inc & contributors -// SPDX-License-Identifier: BSD-3-Clause - -// Package pidowner handles lookups from process ID to its owning user. -package pidowner - -import ( - "errors" - "runtime" -) - -var ErrNotImplemented = errors.New("not implemented for GOOS=" + runtime.GOOS) - -var ErrProcessNotFound = errors.New("process not found") - -// OwnerOfPID returns the user ID that owns the given process ID. -// -// The returned user ID is suitable to passing to os/user.LookupId. -// -// The returned error will be ErrNotImplemented for operating systems where -// this isn't supported. -func OwnerOfPID(pid int) (userID string, err error) { - return ownerOfPID(pid) -} diff --git a/util/pidowner/pidowner_linux.go b/util/pidowner/pidowner_linux.go deleted file mode 100644 index f3f5cd97d..000000000 --- a/util/pidowner/pidowner_linux.go +++ /dev/null @@ -1,36 +0,0 @@ -// Copyright (c) Tailscale Inc & contributors -// SPDX-License-Identifier: BSD-3-Clause - -package pidowner - -import ( - "fmt" - "os" - "strings" - - "tailscale.com/util/lineiter" -) - -func ownerOfPID(pid int) (userID string, err error) { - file := fmt.Sprintf("/proc/%d/status", pid) - for lr := range lineiter.File(file) { - line, err := lr.Value() - if err != nil { - if os.IsNotExist(err) { - return "", ErrProcessNotFound - } - return "", err - } - if len(line) < 4 || string(line[:4]) != "Uid:" { - continue - } - f := strings.Fields(string(line)) - if len(f) >= 2 { - userID = f[1] // real userid - } - } - if userID == "" { - return "", fmt.Errorf("missing Uid line in %s", file) - } - return userID, nil -} diff --git a/util/pidowner/pidowner_noimpl.go b/util/pidowner/pidowner_noimpl.go deleted file mode 100644 index 4bc665d61..000000000 --- a/util/pidowner/pidowner_noimpl.go +++ /dev/null @@ -1,8 +0,0 @@ -// Copyright (c) Tailscale Inc & contributors -// SPDX-License-Identifier: BSD-3-Clause - -//go:build !windows && !linux - -package pidowner - -func ownerOfPID(pid int) (userID string, err error) { return "", ErrNotImplemented } diff --git a/util/pidowner/pidowner_test.go b/util/pidowner/pidowner_test.go deleted file mode 100644 index 2774a8ab0..000000000 --- a/util/pidowner/pidowner_test.go +++ /dev/null @@ -1,50 +0,0 @@ -// Copyright (c) Tailscale Inc & contributors -// SPDX-License-Identifier: BSD-3-Clause - -package pidowner - -import ( - "math/rand" - "os" - "os/user" - "testing" -) - -func TestOwnerOfPID(t *testing.T) { - id, err := OwnerOfPID(os.Getpid()) - if err == ErrNotImplemented { - t.Skip(err) - } - if err != nil { - t.Fatal(err) - } - t.Logf("id=%q", id) - - u, err := user.LookupId(id) - if err != nil { - t.Fatalf("LookupId: %v", err) - } - t.Logf("Got: %+v", u) -} - -// validate that OS implementation returns ErrProcessNotFound. -func TestNotFoundError(t *testing.T) { - // Try a bunch of times to stumble upon a pid that doesn't exist... - const tries = 50 - for range tries { - _, err := OwnerOfPID(rand.Intn(1e9)) - if err == ErrNotImplemented { - t.Skip(err) - } - if err == nil { - // We got unlucky and this pid existed. Try again. - continue - } - if err == ErrProcessNotFound { - // Pass. - return - } - t.Fatalf("Error is not ErrProcessNotFound: %T %v", err, err) - } - t.Errorf("after %d tries, couldn't find a process that didn't exist", tries) -} diff --git a/util/pidowner/pidowner_windows.go b/util/pidowner/pidowner_windows.go deleted file mode 100644 index 8edd7698d..000000000 --- a/util/pidowner/pidowner_windows.go +++ /dev/null @@ -1,35 +0,0 @@ -// Copyright (c) Tailscale Inc & contributors -// SPDX-License-Identifier: BSD-3-Clause - -package pidowner - -import ( - "fmt" - "syscall" - - "golang.org/x/sys/windows" -) - -func ownerOfPID(pid int) (userID string, err error) { - procHnd, err := windows.OpenProcess(windows.PROCESS_QUERY_INFORMATION, false, uint32(pid)) - if err == syscall.Errno(0x57) { // invalid parameter, for PIDs that don't exist - return "", ErrProcessNotFound - } - if err != nil { - return "", fmt.Errorf("OpenProcess: %T %#v", err, err) - } - defer windows.CloseHandle(procHnd) - - var tok windows.Token - if err := windows.OpenProcessToken(procHnd, windows.TOKEN_QUERY, &tok); err != nil { - return "", fmt.Errorf("OpenProcessToken: %w", err) - } - - tokUser, err := tok.GetTokenUser() - if err != nil { - return "", fmt.Errorf("GetTokenUser: %w", err) - } - - sid := tokUser.User.Sid - return sid.String(), nil -} diff --git a/util/pool/pool.go b/util/pool/pool.go deleted file mode 100644 index 2e223e577..000000000 --- a/util/pool/pool.go +++ /dev/null @@ -1,211 +0,0 @@ -// Copyright (c) Tailscale Inc & contributors -// SPDX-License-Identifier: BSD-3-Clause - -// Package pool contains a generic type for managing a pool of resources; for -// example, connections to a database, or to a remote service. -// -// Unlike sync.Pool from the Go standard library, this pool does not remove -// items from the pool when garbage collection happens, nor is it safe for -// concurrent use like sync.Pool. -package pool - -import ( - "fmt" - "math/rand/v2" -) - -// consistencyCheck enables additional runtime checks to ensure that the pool -// is well-formed; it is disabled by default, and can be enabled during tests -// to catch additional bugs. -const consistencyCheck = false - -// Pool is a pool of resources. It is not safe for concurrent use. -type Pool[V any] struct { - s []itemAndIndex[V] -} - -type itemAndIndex[V any] struct { - // item is the element in the pool - item V - - // index is the current location of this item in pool.s. It gets set to - // -1 when the item is removed from the pool. - index *int -} - -// Handle is an opaque handle to a resource in a pool. It is used to delete an -// item from the pool, without requiring the item to be comparable. -type Handle[V any] struct { - idx *int // pointer to index; -1 if not in slice -} - -// Len returns the current size of the pool. -func (p *Pool[V]) Len() int { - return len(p.s) -} - -// Clear removes all items from the pool. -func (p *Pool[V]) Clear() { - p.s = nil -} - -// AppendTakeAll removes all items from the pool, appending them to the -// provided slice (which can be nil) and returning them. The returned slice can -// be nil if the provided slice was nil and the pool was empty. -// -// This function does not free the backing storage for the pool; to do that, -// use the Clear function. -func (p *Pool[V]) AppendTakeAll(dst []V) []V { - ret := dst - for i := range p.s { - e := p.s[i] - if consistencyCheck && e.index == nil { - panic(fmt.Sprintf("pool: index is nil at %d", i)) - } - if *e.index >= 0 { - ret = append(ret, p.s[i].item) - } - } - p.s = p.s[:0] - return ret -} - -// Add adds an item to the pool and returns a handle to it. The handle can be -// used to delete the item from the pool with the Delete method. -func (p *Pool[V]) Add(item V) Handle[V] { - // Store the index in a pointer, so that we can pass it to both the - // handle and store it in the itemAndIndex. - idx := new(len(p.s)) - p.s = append(p.s, itemAndIndex[V]{ - item: item, - index: idx, - }) - return Handle[V]{idx} -} - -// Peek will return the item with the given handle without removing it from the -// pool. -// -// It will return ok=false if the item has been deleted or previously taken. -func (p *Pool[V]) Peek(h Handle[V]) (v V, ok bool) { - p.checkHandle(h) - idx := *h.idx - if idx < 0 { - var zero V - return zero, false - } - p.checkIndex(idx) - return p.s[idx].item, true -} - -// Delete removes the item from the pool. -// -// It reports whether the element was deleted; it will return false if the item -// has been taken with the TakeRandom function, or if the item was already -// deleted. -func (p *Pool[V]) Delete(h Handle[V]) bool { - p.checkHandle(h) - idx := *h.idx - if idx < 0 { - return false - } - p.deleteIndex(idx) - return true -} - -func (p *Pool[V]) deleteIndex(idx int) { - // Mark the item as deleted. - p.checkIndex(idx) - *(p.s[idx].index) = -1 - - // If this isn't the last element in the slice, overwrite the element - // at this item's index with the last element. - lastIdx := len(p.s) - 1 - - if idx < lastIdx { - last := p.s[lastIdx] - p.checkElem(lastIdx, last) - *last.index = idx - p.s[idx] = last - } - - // Zero out last element (for GC) and truncate slice. - p.s[lastIdx] = itemAndIndex[V]{} - p.s = p.s[:lastIdx] -} - -// Take will remove the item with the given handle from the pool and return it. -// -// It will return ok=false and the zero value if the item has been deleted or -// previously taken. -func (p *Pool[V]) Take(h Handle[V]) (v V, ok bool) { - p.checkHandle(h) - idx := *h.idx - if idx < 0 { - var zero V - return zero, false - } - - e := p.s[idx] - p.deleteIndex(idx) - return e.item, true -} - -// TakeRandom returns and removes a random element from p -// and reports whether there was one to take. -// -// It will return ok=false and the zero value if the pool is empty. -func (p *Pool[V]) TakeRandom() (v V, ok bool) { - if len(p.s) == 0 { - var zero V - return zero, false - } - pick := rand.IntN(len(p.s)) - e := p.s[pick] - p.checkElem(pick, e) - p.deleteIndex(pick) - return e.item, true -} - -// checkIndex verifies that the provided index is within the bounds of the -// pool's slice, and that the corresponding element has a non-nil index -// pointer, and panics if not. -func (p *Pool[V]) checkIndex(idx int) { - if !consistencyCheck { - return - } - - if idx >= len(p.s) { - panic(fmt.Sprintf("pool: index %d out of range (len %d)", idx, len(p.s))) - } - if p.s[idx].index == nil { - panic(fmt.Sprintf("pool: index is nil at %d", idx)) - } -} - -// checkHandle verifies that the provided handle is not nil, and panics if it -// is. -func (p *Pool[V]) checkHandle(h Handle[V]) { - if !consistencyCheck { - return - } - - if h.idx == nil { - panic("pool: nil handle") - } -} - -// checkElem verifies that the provided itemAndIndex has a non-nil index, and -// that the stored index matches the expected position within the slice. -func (p *Pool[V]) checkElem(idx int, e itemAndIndex[V]) { - if !consistencyCheck { - return - } - - if e.index == nil { - panic("pool: index is nil") - } - if got := *e.index; got != idx { - panic(fmt.Sprintf("pool: index is incorrect: want %d, got %d", idx, got)) - } -} diff --git a/util/pool/pool_test.go b/util/pool/pool_test.go deleted file mode 100644 index ad509a563..000000000 --- a/util/pool/pool_test.go +++ /dev/null @@ -1,203 +0,0 @@ -// Copyright (c) Tailscale Inc & contributors -// SPDX-License-Identifier: BSD-3-Clause - -package pool - -import ( - "slices" - "testing" -) - -func TestPool(t *testing.T) { - p := Pool[int]{} - - if got, want := p.Len(), 0; got != want { - t.Errorf("got initial length %v; want %v", got, want) - } - - h1 := p.Add(101) - h2 := p.Add(102) - h3 := p.Add(103) - h4 := p.Add(104) - - if got, want := p.Len(), 4; got != want { - t.Errorf("got length %v; want %v", got, want) - } - - tests := []struct { - h Handle[int] - want int - }{ - {h1, 101}, - {h2, 102}, - {h3, 103}, - {h4, 104}, - } - for i, test := range tests { - got, ok := p.Peek(test.h) - if !ok { - t.Errorf("test[%d]: did not find item", i) - continue - } - if got != test.want { - t.Errorf("test[%d]: got %v; want %v", i, got, test.want) - } - } - - if deleted := p.Delete(h2); !deleted { - t.Errorf("h2 not deleted") - } - if deleted := p.Delete(h2); deleted { - t.Errorf("h2 should not be deleted twice") - } - if got, want := p.Len(), 3; got != want { - t.Errorf("got length %v; want %v", got, want) - } - if _, ok := p.Peek(h2); ok { - t.Errorf("h2 still in pool") - } - - // Remove an item by handle - got, ok := p.Take(h4) - if !ok { - t.Errorf("h4 not found") - } - if got != 104 { - t.Errorf("got %v; want 104", got) - } - - // Take doesn't work on previously-taken or deleted items. - if _, ok := p.Take(h4); ok { - t.Errorf("h4 should not be taken twice") - } - if _, ok := p.Take(h2); ok { - t.Errorf("h2 should not be taken after delete") - } - - // Remove all items and return them - items := p.AppendTakeAll(nil) - want := []int{101, 103} - if !slices.Equal(items, want) { - t.Errorf("got items %v; want %v", items, want) - } - if got := p.Len(); got != 0 { - t.Errorf("got length %v; want 0", got) - } - - // Insert and then clear should result in no items. - p.Add(105) - p.Clear() - if got := p.Len(); got != 0 { - t.Errorf("got length %v; want 0", got) - } -} - -func TestTakeRandom(t *testing.T) { - p := Pool[int]{} - for i := range 10 { - p.Add(i + 100) - } - - seen := make(map[int]bool) - for range 10 { - item, ok := p.TakeRandom() - if !ok { - t.Errorf("unexpected empty pool") - break - } - if seen[item] { - t.Errorf("got duplicate item %v", item) - } - seen[item] = true - } - - // Verify that the pool is empty - if _, ok := p.TakeRandom(); ok { - t.Errorf("expected empty pool") - } - - for i := range 10 { - want := 100 + i - if !seen[want] { - t.Errorf("item %v not seen", want) - } - } - - if t.Failed() { - t.Logf("seen: %+v", seen) - } -} - -func BenchmarkPool_AddDelete(b *testing.B) { - b.Run("impl=Pool", func(b *testing.B) { - p := Pool[int]{} - - // Warm up/force an initial allocation - h := p.Add(0) - p.Delete(h) - - b.ResetTimer() - - for i := 0; i < b.N; i++ { - h := p.Add(i) - p.Delete(h) - } - }) - b.Run("impl=map", func(b *testing.B) { - p := make(map[int]bool) - - // Force initial allocation - p[0] = true - delete(p, 0) - - b.ResetTimer() - - for i := 0; i < b.N; i++ { - p[i] = true - delete(p, i) - } - }) -} - -func BenchmarkPool_TakeRandom(b *testing.B) { - b.Run("impl=Pool", func(b *testing.B) { - p := Pool[int]{} - - // Insert the number of items we'll be taking, then reset the timer. - for i := 0; i < b.N; i++ { - p.Add(i) - } - b.ResetTimer() - - // Now benchmark taking all the items. - for i := 0; i < b.N; i++ { - p.TakeRandom() - } - - if p.Len() != 0 { - b.Errorf("pool not empty") - } - }) - b.Run("impl=map", func(b *testing.B) { - p := make(map[int]bool) - - // Insert the number of items we'll be taking, then reset the timer. - for i := 0; i < b.N; i++ { - p[i] = true - } - b.ResetTimer() - - // Now benchmark taking all the items. - for i := 0; i < b.N; i++ { - // Taking a random item is simulated by a single map iteration. - for k := range p { - delete(p, k) // "take" the item by removing it - break - } - } - - if len(p) != 0 { - b.Errorf("map not empty") - } - }) -} diff --git a/util/sysresources/memory.go b/util/sysresources/memory.go deleted file mode 100644 index 3c6b9ae85..000000000 --- a/util/sysresources/memory.go +++ /dev/null @@ -1,10 +0,0 @@ -// Copyright (c) Tailscale Inc & contributors -// SPDX-License-Identifier: BSD-3-Clause - -package sysresources - -// TotalMemory returns the total accessible system memory, in bytes. If the -// value cannot be determined, then 0 will be returned. -func TotalMemory() uint64 { - return totalMemoryImpl() -} diff --git a/util/sysresources/memory_bsd.go b/util/sysresources/memory_bsd.go deleted file mode 100644 index 945f86ea3..000000000 --- a/util/sysresources/memory_bsd.go +++ /dev/null @@ -1,16 +0,0 @@ -// Copyright (c) Tailscale Inc & contributors -// SPDX-License-Identifier: BSD-3-Clause - -//go:build freebsd || openbsd || dragonfly || netbsd - -package sysresources - -import "golang.org/x/sys/unix" - -func totalMemoryImpl() uint64 { - val, err := unix.SysctlUint64("hw.physmem") - if err != nil { - return 0 - } - return val -} diff --git a/util/sysresources/memory_darwin.go b/util/sysresources/memory_darwin.go deleted file mode 100644 index 165f12eb3..000000000 --- a/util/sysresources/memory_darwin.go +++ /dev/null @@ -1,16 +0,0 @@ -// Copyright (c) Tailscale Inc & contributors -// SPDX-License-Identifier: BSD-3-Clause - -//go:build darwin - -package sysresources - -import "golang.org/x/sys/unix" - -func totalMemoryImpl() uint64 { - val, err := unix.SysctlUint64("hw.memsize") - if err != nil { - return 0 - } - return val -} diff --git a/util/sysresources/memory_linux.go b/util/sysresources/memory_linux.go deleted file mode 100644 index 3885a8aa6..000000000 --- a/util/sysresources/memory_linux.go +++ /dev/null @@ -1,19 +0,0 @@ -// Copyright (c) Tailscale Inc & contributors -// SPDX-License-Identifier: BSD-3-Clause - -//go:build linux - -package sysresources - -import "golang.org/x/sys/unix" - -func totalMemoryImpl() uint64 { - var info unix.Sysinfo_t - - if err := unix.Sysinfo(&info); err != nil { - return 0 - } - - // uint64 casts are required since these might be uint32s - return uint64(info.Totalram) * uint64(info.Unit) -} diff --git a/util/sysresources/memory_unsupported.go b/util/sysresources/memory_unsupported.go deleted file mode 100644 index c88e9ed52..000000000 --- a/util/sysresources/memory_unsupported.go +++ /dev/null @@ -1,8 +0,0 @@ -// Copyright (c) Tailscale Inc & contributors -// SPDX-License-Identifier: BSD-3-Clause - -//go:build !(linux || darwin || freebsd || openbsd || dragonfly || netbsd) - -package sysresources - -func totalMemoryImpl() uint64 { return 0 } diff --git a/util/sysresources/sysresources.go b/util/sysresources/sysresources.go deleted file mode 100644 index 33d0d5d96..000000000 --- a/util/sysresources/sysresources.go +++ /dev/null @@ -1,6 +0,0 @@ -// Copyright (c) Tailscale Inc & contributors -// SPDX-License-Identifier: BSD-3-Clause - -// Package sysresources provides OS-independent methods of determining the -// resources available to the current system. -package sysresources diff --git a/util/sysresources/sysresources_test.go b/util/sysresources/sysresources_test.go deleted file mode 100644 index 7fea1bf0f..000000000 --- a/util/sysresources/sysresources_test.go +++ /dev/null @@ -1,25 +0,0 @@ -// Copyright (c) Tailscale Inc & contributors -// SPDX-License-Identifier: BSD-3-Clause - -package sysresources - -import ( - "runtime" - "testing" -) - -func TestTotalMemory(t *testing.T) { - switch runtime.GOOS { - case "linux": - case "freebsd", "openbsd", "dragonfly", "netbsd": - case "darwin": - default: - t.Skipf("not supported on runtime.GOOS=%q yet", runtime.GOOS) - } - - mem := TotalMemory() - if mem == 0 { - t.Fatal("wanted TotalMemory > 0") - } - t.Logf("total memory: %v bytes", mem) -} diff --git a/util/topk/topk.go b/util/topk/topk.go deleted file mode 100644 index 95ebd895d..000000000 --- a/util/topk/topk.go +++ /dev/null @@ -1,261 +0,0 @@ -// Copyright (c) Tailscale Inc & contributors -// SPDX-License-Identifier: BSD-3-Clause - -// Package topk defines a count-min sketch and a cheap probabilistic top-K data -// structure that uses the count-min sketch to track the top K items in -// constant memory and O(log(k)) time. -package topk - -import ( - "container/heap" - "hash/maphash" - "math" - "slices" - "sync" -) - -// TopK is a probabilistic counter of the top K items, using a count-min sketch -// to keep track of item counts and a heap to track the top K of them. -type TopK[T any] struct { - heap minHeap[T] - k int - sf SerializeFunc[T] - cms CountMinSketch -} - -// HashFunc is responsible for providing a []byte serialization of a value, -// appended to the provided byte slice. This is used for hashing the value when -// adding to a CountMinSketch. -type SerializeFunc[T any] func([]byte, T) []byte - -// New creates a new TopK that stores k values. Parameters for the underlying -// count-min sketch are chosen for a 0.1% error rate and a 0.1% probability of -// error. -func New[T any](k int, sf SerializeFunc[T]) *TopK[T] { - hashes, buckets := PickParams(0.001, 0.001) - return NewWithParams(k, sf, hashes, buckets) -} - -// NewWithParams creates a new TopK that stores k values, and additionally -// allows customizing the parameters for the underlying count-min sketch. -func NewWithParams[T any](k int, sf SerializeFunc[T], numHashes, numCols int) *TopK[T] { - ret := &TopK[T]{ - heap: make(minHeap[T], 0, k), - k: k, - sf: sf, - } - ret.cms.init(numHashes, numCols) - return ret -} - -// Add calls AddN(val, 1). -func (tk *TopK[T]) Add(val T) uint64 { - return tk.AddN(val, 1) -} - -var hashPool = &sync.Pool{ - New: func() any { - buf := make([]byte, 0, 128) - return &buf - }, -} - -// AddN adds the given item to the set with the provided count, returning the -// new estimated count. -func (tk *TopK[T]) AddN(val T, count uint64) uint64 { - buf := hashPool.Get().(*[]byte) - defer hashPool.Put(buf) - ser := tk.sf((*buf)[:0], val) - - vcount := tk.cms.AddN(ser, count) - - // If we don't have a full heap, just push it. - if len(tk.heap) < tk.k { - heap.Push(&tk.heap, mhValue[T]{ - count: vcount, - val: val, - }) - return vcount - } - - // If this item's count surpasses the heap's minimum, update the heap. - if vcount > tk.heap[0].count { - tk.heap[0] = mhValue[T]{ - count: vcount, - val: val, - } - heap.Fix(&tk.heap, 0) - } - return vcount -} - -// Top returns the estimated top K items as stored by this TopK. -func (tk *TopK[T]) Top() []T { - ret := make([]T, 0, tk.k) - for _, item := range tk.heap { - ret = append(ret, item.val) - } - return ret -} - -// AppendTop appends the estimated top K items as stored by this TopK to the -// provided slice, allocating only if the slice does not have enough capacity -// to store all items. The provided slice can be nil. -func (tk *TopK[T]) AppendTop(sl []T) []T { - sl = slices.Grow(sl, tk.k) - for _, item := range tk.heap { - sl = append(sl, item.val) - } - return sl -} - -// CountMinSketch implements a count-min sketch, a probabilistic data structure -// that tracks the frequency of events in a stream of data. -// -// See: https://en.wikipedia.org/wiki/Count%E2%80%93min_sketch -type CountMinSketch struct { - hashes []maphash.Seed - nbuckets int - matrix []uint64 -} - -// NewCountMinSketch creates a new CountMinSketch with the provided number of -// hashes and buckets. Hashes and buckets are often called "depth" and "width", -// or "d" and "w", respectively. -func NewCountMinSketch(hashes, buckets int) *CountMinSketch { - ret := &CountMinSketch{} - ret.init(hashes, buckets) - return ret -} - -// PickParams provides good parameters for 'hashes' and 'buckets' when -// constructing a CountMinSketch, given an estimated total number of counts -// (i.e. the sum of all counts ever stored), the error factor ϵ as a float -// (e.g. 1% is 0.001), and the probability factor δ. -// -// Parameters are chosen such that with a probability of 1−δ, the error is at -// most ϵ∗totalCount. Or, in other words: if N is the true count of an event, -// E is the estimate given by a sketch and T the total count of items in the -// sketch, E ≤ N + T*ϵ with probability (1 - δ). -func PickParams(err, probability float64) (hashes, buckets int) { - d := math.Ceil(math.Log(1 / probability)) - w := math.Ceil(math.E / err) - - return int(d), int(w) -} - -func (cms *CountMinSketch) init(hashes, buckets int) { - for range hashes { - cms.hashes = append(cms.hashes, maphash.MakeSeed()) - } - - // Need a matrix of hashes * buckets to store counts - cms.nbuckets = buckets - cms.matrix = make([]uint64, hashes*buckets) -} - -// Add calls AddN(val, 1). -func (cms *CountMinSketch) Add(val []byte) uint64 { - return cms.AddN(val, 1) -} - -// AddN increments the count for the given value by the provided count, -// returning the new count. -func (cms *CountMinSketch) AddN(val []byte, count uint64) uint64 { - var ( - mh maphash.Hash - ret uint64 = math.MaxUint64 - ) - for i, seed := range cms.hashes { - mh.SetSeed(seed) - - // Generate a hash for this value using Lemire's alternative to modular reduction: - // https://lemire.me/blog/2016/06/27/a-fast-alternative-to-the-modulo-reduction/ - mh.Write(val) - hash := mh.Sum64() - hash = multiplyHigh64(hash, uint64(cms.nbuckets)) - - // The index in our matrix is (i * buckets) to move "down" i - // rows in our matrix to the row for this hash, plus 'hash' to - // move inside this row. - idx := (i * cms.nbuckets) + int(hash) - - // Add to this row - cms.matrix[idx] += count - ret = min(ret, cms.matrix[idx]) - } - return ret -} - -// Get returns the count for the provided value. -func (cms *CountMinSketch) Get(val []byte) uint64 { - var ( - mh maphash.Hash - ret uint64 = math.MaxUint64 - ) - for i, seed := range cms.hashes { - mh.SetSeed(seed) - - // Generate a hash for this value using Lemire's alternative to modular reduction: - // https://lemire.me/blog/2016/06/27/a-fast-alternative-to-the-modulo-reduction/ - mh.Write(val) - hash := mh.Sum64() - hash = multiplyHigh64(hash, uint64(cms.nbuckets)) - - // The index in our matrix is (i * buckets) to move "down" i - // rows in our matrix to the row for this hash, plus 'hash' to - // move inside this row. - idx := (i * cms.nbuckets) + int(hash) - - // Select the minimal value among all rows - ret = min(ret, cms.matrix[idx]) - } - return ret -} - -// multiplyHigh64 implements (x * y) >> 64 "the long way" without access to a -// 128-bit type. This function is adapted from something similar in Tensorflow: -// -// https://github.com/tensorflow/tensorflow/commit/a47a300185026fe7829990def9113bf3a5109fed -// -// TODO(andrew-d): this could be replaced with a single "MULX" instruction on -// x86_64 platforms, which we can do if this ever turns out to be a performance -// bottleneck. -func multiplyHigh64(x, y uint64) uint64 { - x_lo := x & 0xffffffff - x_hi := x >> 32 - buckets_lo := y & 0xffffffff - buckets_hi := y >> 32 - prod_hi := x_hi * buckets_hi - prod_lo := x_lo * buckets_lo - prod_mid1 := x_hi * buckets_lo - prod_mid2 := x_lo * buckets_hi - carry := ((prod_mid1 & 0xffffffff) + (prod_mid2 & 0xffffffff) + (prod_lo >> 32)) >> 32 - return prod_hi + (prod_mid1 >> 32) + (prod_mid2 >> 32) + carry -} - -type mhValue[T any] struct { - count uint64 - val T -} - -// An minHeap is a min-heap of ints and associated values. -type minHeap[T any] []mhValue[T] - -func (h minHeap[T]) Len() int { return len(h) } -func (h minHeap[T]) Less(i, j int) bool { return h[i].count < h[j].count } -func (h minHeap[T]) Swap(i, j int) { h[i], h[j] = h[j], h[i] } - -func (h *minHeap[T]) Push(x any) { - // Push and Pop use pointer receivers because they modify the slice's length, - // not just its contents. - *h = append(*h, x.(mhValue[T])) -} - -func (h *minHeap[T]) Pop() any { - old := *h - n := len(old) - x := old[n-1] - *h = old[0 : n-1] - return x -} diff --git a/util/topk/topk_test.go b/util/topk/topk_test.go deleted file mode 100644 index 7679f59a3..000000000 --- a/util/topk/topk_test.go +++ /dev/null @@ -1,135 +0,0 @@ -// Copyright (c) Tailscale Inc & contributors -// SPDX-License-Identifier: BSD-3-Clause - -package topk - -import ( - "encoding/binary" - "fmt" - "slices" - "testing" -) - -func TestCountMinSketch(t *testing.T) { - cms := NewCountMinSketch(4, 10) - items := []string{"foo", "bar", "baz", "asdf", "quux"} - for _, item := range items { - cms.Add([]byte(item)) - } - for _, item := range items { - count := cms.Get([]byte(item)) - if count < 1 { - t.Errorf("item %q should have count >= 1", item) - } else if count > 1 { - t.Logf("item %q has count > 1: %d", item, count) - } - } - - // Test that an item that's *not* in the set has a value lower than the - // total number of items we inserted (in the case that all items - // collided). - noItemCount := cms.Get([]byte("doesn't exist")) - if noItemCount > uint64(len(items)) { - t.Errorf("expected nonexistent item to have value < %d; got %d", len(items), noItemCount) - } -} - -func TestTopK(t *testing.T) { - // This is probabilistic, so we're going to try 10 times to get the - // "right" value; the likelihood that we fail on all attempts is - // vanishingly small since the number of hash buckets is drastically - // larger than the number of items we're inserting. - var ( - got []int - want = []int{5, 6, 7, 8, 9} - ) - for range 10 { - topk := NewWithParams[int](5, func(in []byte, val int) []byte { - return binary.LittleEndian.AppendUint64(in, uint64(val)) - }, 4, 1000) - - // Add the first 10 integers with counts equal to 2x their value - for i := range 10 { - topk.AddN(i, uint64(i*2)) - } - - got = topk.Top() - t.Logf("top K items: %+v", got) - slices.Sort(got) - - if slices.Equal(got, want) { - // All good! - return - } - - // continue and retry or fail - } - - t.Errorf("top K mismatch\ngot: %v\nwant: %v", got, want) -} - -func TestPickParams(t *testing.T) { - hashes, buckets := PickParams( - 0.001, // 0.1% error rate - 0.001, // 0.1% chance of having an error, or 99.9% chance of not having an error - ) - t.Logf("hashes = %d, buckets = %d", hashes, buckets) -} - -func BenchmarkCountMinSketch(b *testing.B) { - cms := NewCountMinSketch(PickParams(0.001, 0.001)) - b.ResetTimer() - b.ReportAllocs() - - var enc [8]byte - for i := range b.N { - binary.LittleEndian.PutUint64(enc[:], uint64(i)) - cms.Add(enc[:]) - } -} - -func BenchmarkTopK(b *testing.B) { - for _, n := range []int{ - 10, - 128, - 256, - 1024, - 8192, - } { - b.Run(fmt.Sprintf("Top%d", n), func(b *testing.B) { - out := make([]int, 0, n) - topk := New[int](n, func(in []byte, val int) []byte { - return binary.LittleEndian.AppendUint64(in, uint64(val)) - }) - b.ResetTimer() - b.ReportAllocs() - - for i := range b.N { - topk.Add(i) - } - out = topk.AppendTop(out[:0]) // should not allocate - _ = out // appease linter - }) - } -} - -func TestMultiplyHigh64(t *testing.T) { - testCases := []struct { - x, y uint64 - want uint64 - }{ - {0, 0, 0}, - {0xffffffff, 0xffffffff, 0}, - {0x2, 0xf000000000000000, 1}, - {0x3, 0xf000000000000000, 2}, - {0x3, 0xf000000000000001, 2}, - {0x3, 0xffffffffffffffff, 2}, - {0xffffffffffffffff, 0xffffffffffffffff, 0xfffffffffffffffe}, - } - for _, tc := range testCases { - got := multiplyHigh64(tc.x, tc.y) - if got != tc.want { - t.Errorf("got multiplyHigh64(%x, %x) = %x, want %x", tc.x, tc.y, got, tc.want) - } - } -} diff --git a/util/winutil/gp/gp_windows_test.go b/util/winutil/gp/gp_windows_test.go index dfad02930..1ce75871e 100644 --- a/util/winutil/gp/gp_windows_test.go +++ b/util/winutil/gp/gp_windows_test.go @@ -106,23 +106,13 @@ func TestGroupPolicyReadLockClose(t *testing.T) { doWithCustomEnterLeaveFuncs(t, func(gpLock *PolicyLock) { done := make(chan struct{}) + var lockErr error go func() { defer close(done) - - err := gpLock.Lock() - if err == nil { + lockErr = gpLock.Lock() + if lockErr == nil { defer gpLock.Unlock() } - - // We closed gpLock before the enter function returned. - // (*PolicyLock).Lock is expected to fail. - if err == nil || !errors.Is(err, ErrInvalidLockState) { - t.Errorf("(*PolicyLock).Lock: got %v; want %v", err, ErrInvalidLockState) - } - // gpLock must not be held as Lock() failed. - if lockCnt := gpLock.lockCnt.Load(); lockCnt != 0 { - t.Errorf("lockCnt: got %v; want 0", lockCnt) - } }() <-init @@ -130,7 +120,22 @@ func TestGroupPolicyReadLockClose(t *testing.T) { if err := gpLock.Close(); err != nil { t.Fatalf("(*PolicyLock).Close failed: %v", err) } + // Wait for Lock to fully unwind before reading lockCnt. + // Otherwise we race with Close clearing the LSB: + // close(lk.closing) wakes lockSlow's select, whose defer + // runs Add(-2) on lockCnt before Close completes its CAS, + // briefly leaving lockCnt at 1 instead of 0. <-done + + // We closed gpLock before the enter function returned. + // (*PolicyLock).Lock is expected to fail. + if lockErr == nil || !errors.Is(lockErr, ErrInvalidLockState) { + t.Errorf("(*PolicyLock).Lock: got %v; want %v", lockErr, ErrInvalidLockState) + } + // gpLock must not be held as Lock() failed. + if lockCnt := gpLock.lockCnt.Load(); lockCnt != 0 { + t.Errorf("lockCnt: got %v; want 0", lockCnt) + } }, enter, leave) } diff --git a/version/cmdname.go b/version/cmdname.go index 5a0b84875..8e6adb047 100644 --- a/version/cmdname.go +++ b/version/cmdname.go @@ -13,6 +13,7 @@ import ( "os" "path" "runtime" + "runtime/debug" "strings" ) @@ -20,6 +21,15 @@ import ( // using os.Executable. If os.Executable fails (it shouldn't), then // "cmd" is returned. func CmdName() string { + // On non-Windows, the modinfo embedded in the running binary is + // authoritative and avoids re-reading the executable from disk. + // Windows needs the executable-name-based GUI override in cmdName, + // so it still takes the slower path. + if runtime.GOOS != "windows" { + if info, ok := debug.ReadBuildInfo(); ok && info.Path != "" { + return path.Base(info.Path) + } + } e, err := os.Executable() if err != nil { return "cmd" diff --git a/version/print.go b/version/print.go index ca62226ee..3b4a256cf 100644 --- a/version/print.go +++ b/version/print.go @@ -24,7 +24,14 @@ var stringLazy = sync.OnceValue(func() string { if extraGitCommitStamp != "" { fmt.Fprintf(&ret, " other commit: %s\n", extraGitCommitStamp) } - fmt.Fprintf(&ret, " go version: %s\n", runtime.Version()) + if tsGoRev := tailscaleToolchainRev(); tsGoRev != "" { + if len(tsGoRev) > 10 { + tsGoRev = tsGoRev[:10] + } + fmt.Fprintf(&ret, " go version: %s (tailscale/go %s)\n", runtime.Version(), tsGoRev) + } else { + fmt.Fprintf(&ret, " go version: %s\n", runtime.Version()) + } return strings.TrimSpace(ret.String()) }) diff --git a/version/prop.go b/version/prop.go index 36d769917..59ca74086 100644 --- a/version/prop.go +++ b/version/prop.go @@ -312,6 +312,11 @@ type Meta struct { // GitCommitTime is the commit time of the git commit in GitCommit. GitCommitTime string `json:"gitCommitTime,omitempty"` + // TailscaleGoGitHash is the git commit hash from + // https://github.com/tailscale/go used to build this binary, if built + // with the Tailscale Go toolchain. Otherwise it is empty. + TailscaleGoGitHash string `json:"tailscaleGoGitHash,omitempty"` + // Cap is the current Tailscale capability version. It's a monotonically // incrementing integer that's incremented whenever a new capability is // added. @@ -324,17 +329,18 @@ var getMeta lazy.SyncValue[Meta] func GetMeta() Meta { return getMeta.Get(func() Meta { return Meta{ - MajorMinorPatch: majorMinorPatch(), - Short: Short(), - Long: Long(), - GitCommitTime: getEmbeddedInfo().commitTime, - GitCommit: gitCommit(), - GitDirty: gitDirty(), - OSVariant: osVariant(), - ExtraGitCommit: extraGitCommitStamp, - IsDev: isDev(), - UnstableBranch: IsUnstableBuild(), - Cap: int(tailcfg.CurrentCapabilityVersion), + MajorMinorPatch: majorMinorPatch(), + Short: Short(), + Long: Long(), + GitCommitTime: getEmbeddedInfo().commitTime, + GitCommit: gitCommit(), + GitDirty: gitDirty(), + OSVariant: osVariant(), + ExtraGitCommit: extraGitCommitStamp, + IsDev: isDev(), + UnstableBranch: IsUnstableBuild(), + TailscaleGoGitHash: tailscaleToolchainRev(), + Cap: int(tailcfg.CurrentCapabilityVersion), } }) } diff --git a/version/version.go b/version/version.go index 7d8efc375..8ffc21832 100644 --- a/version/version.go +++ b/version/version.go @@ -146,6 +146,23 @@ var getEmbeddedInfo = sync.OnceValue(func() embeddedInfo { return ret }) +// tailscaleToolchainRev returns the git hash of the Tailscale Go toolchain +// used to build this binary, if any. It is read separately from getEmbeddedInfo +// because that function discards build info when VCS fields are missing (e.g. +// in test binaries), but the toolchain rev is still present. +var tailscaleToolchainRev = sync.OnceValue(func() string { + bi, ok := debug.ReadBuildInfo() + if !ok { + return "" + } + for _, s := range bi.Settings { + if s.Key == "tailscale.toolchain.rev" { + return s.Value + } + } + return "" +}) + func gitCommit() string { if gitCommitStamp != "" { return gitCommitStamp diff --git a/version/version_internal_test.go b/version/version_internal_test.go index c78df4ff8..72b2dcd5f 100644 --- a/version/version_internal_test.go +++ b/version/version_internal_test.go @@ -3,7 +3,13 @@ package version -import "testing" +import ( + "os/exec" + "strings" + "testing" + + "tailscale.com/util/cibuild" +) func TestIsValidLongWithTwoRepos(t *testing.T) { tests := []struct { @@ -26,6 +32,26 @@ func TestIsValidLongWithTwoRepos(t *testing.T) { } } +func TestTailscaleToolchainRev(t *testing.T) { + out, err := exec.Command("go", "env", "GOROOT").Output() + if err != nil { + t.Fatalf("go env GOROOT: %v", err) + } + goRoot := strings.TrimSpace(string(out)) + isTsgo := strings.Contains(goRoot, "/.cache/tsgo/") + if !cibuild.On() && !isTsgo { + t.Skip("skipping; not in CI and not using the Tailscale Go toolchain") + } + if !isTailscaleGo { + t.Skip("skipping; not built with tailscale_go build tag") + } + rev := tailscaleToolchainRev() + if rev == "" { + t.Fatal("tailscale.toolchain.rev is empty in build info; expected non-empty when using tsgo") + } + t.Logf("tailscale.toolchain.rev = %s", rev) +} + func TestPrepExeNameForCmp(t *testing.T) { cases := []struct { exe string diff --git a/version/version_test.go b/version/version_test.go index 42bcf2163..01fcd47ec 100644 --- a/version/version_test.go +++ b/version/version_test.go @@ -6,6 +6,8 @@ package version_test import ( "bytes" "os" + "path" + "runtime/debug" "testing" ts "tailscale.com" @@ -49,3 +51,21 @@ func TestShortAllocs(t *testing.T) { t.Errorf("allocs = %v; want 0", allocs) } } + +func BenchmarkCmdName(b *testing.B) { + b.ReportAllocs() + for b.Loop() { + _ = version.CmdName() + } +} + +func BenchmarkReadBuildInfo(b *testing.B) { + b.ReportAllocs() + for b.Loop() { + info, ok := debug.ReadBuildInfo() + if !ok { + b.Fatal("ReadBuildInfo failed") + } + _ = path.Base(info.Path) + } +} diff --git a/wgengine/magicsock/debughttp.go b/wgengine/magicsock/debughttp.go index 68019d0a7..a9f4734f9 100644 --- a/wgengine/magicsock/debughttp.go +++ b/wgengine/magicsock/debughttp.go @@ -108,8 +108,8 @@ func (c *Conn) ServeHTTPDebug(w http.ResponseWriter, r *http.Request) { } sort.Slice(ent, func(i, j int) bool { return ent[i].pub.Less(ent[j].pub) }) - peers := map[key.NodePublic]tailcfg.NodeView{} - for _, p := range c.peers.All() { + peers := make(map[key.NodePublic]tailcfg.NodeView, len(c.peersByID)) + for _, p := range c.peersByID { peers[p.Key()] = p } diff --git a/wgengine/magicsock/derp.go b/wgengine/magicsock/derp.go index 1cab52b93..72c75db5a 100644 --- a/wgengine/magicsock/derp.go +++ b/wgengine/magicsock/derp.go @@ -102,6 +102,7 @@ type activeDerp struct { var ( pickDERPFallbackForTests func() int + reSTUNHookForTests func(why string) ) // pickDERPFallback returns a non-zero but deterministic DERP node to @@ -155,7 +156,7 @@ var checkControlHealthDuringNearestDERPInTests = false // region that it selected and set (via setNearestDERP). // // c.mu must NOT be held. -func (c *Conn) maybeSetNearestDERP(report *netcheck.Report) (preferredDERP int) { +func (c *Conn) maybeSetNearestDERP(report *netcheck.Report, force bool) (preferredDERP int) { // Don't change our PreferredDERP if we don't have a connection to // control; if we don't, then we can't inform peers about a DERP home // change, which breaks all connectivity. Even if this DERP region is @@ -169,7 +170,10 @@ func (c *Conn) maybeSetNearestDERP(report *netcheck.Report) (preferredDERP int) // // Despite the above behaviour, ensure that we set the nearest DERP if // we don't currently have one set; any DERP server is better than - // none, even if not connected to control. + // none, even if not connected to control. The exception here is if we have + // a cached netmap with a previous DERP server. Retaining the previous DERP + // makes it easier for other nodes to find each other when control is not + // available. var connectedToControl bool if testenv.InTest() && !checkControlHealthDuringNearestDERPInTests { connectedToControl = true @@ -179,7 +183,7 @@ func (c *Conn) maybeSetNearestDERP(report *netcheck.Report) (preferredDERP int) c.mu.Lock() myDerp := c.myDerp c.mu.Unlock() - if !connectedToControl { + if !connectedToControl && !force { if myDerp != 0 { metricDERPHomeNoChangeNoControl.Add(1) return myDerp @@ -198,15 +202,32 @@ func (c *Conn) maybeSetNearestDERP(report *netcheck.Report) (preferredDERP int) } if preferredDERP != myDerp { c.logf( - "magicsock: home DERP changing from derp-%d [%dms] to derp-%d [%dms]", - c.myDerp, report.RegionLatency[myDerp].Milliseconds(), preferredDERP, report.RegionLatency[preferredDERP].Milliseconds()) + "magicsock: home DERP changing from derp-%d [%dms] to derp-%d [%dms] (forced=%t)", + c.myDerp, report.RegionLatency[myDerp].Milliseconds(), preferredDERP, report.RegionLatency[preferredDERP].Milliseconds(), force) } if !c.setNearestDERP(preferredDERP) { preferredDERP = 0 + } else if preferredDERP != myDerp { + c.homeDERPChangedPub.Publish(HomeDERPChanged{Old: myDerp, New: preferredDERP}) } return } +// HomeDERPChanged is an event sent on the [eventbus.Bus] when a new home DERP +// server has been selected. Its publisher is [magicsock.Coon]; its main +// subscriber is [ipnlocal.LocalBackend] that updates the homeDERP used by the +// netmap cache. +// TODO(cmol): Move the subscriber to not inject into localBackend, but rather +// into the netmap at the controlClient mapSession level once there is a stable +// abstraction to use. +type HomeDERPChanged struct { + Old, New int +} + +func (c *Conn) ForceSetNearestDERP(regionID int) int { + return c.maybeSetNearestDERP(&netcheck.Report{PreferredDERP: regionID}, true) +} + func (c *Conn) derpRegionCodeLocked(regionID int) string { if c.derpMap == nil { return "" @@ -771,7 +792,24 @@ func (c *Conn) SetOnlyTCP443(v bool) { // SetDERPMap controls which (if any) DERP servers are used. // A nil value means to disable DERP; it's disabled by default. +// +// SetDERPMap triggers a ReSTUN after updating the map. Callers that want to +// set the map without triggering a ReSTUN should use [Conn.SetDERPMapWithoutReSTUN] +// instead. func (c *Conn) SetDERPMap(dm *tailcfg.DERPMap) { + c.setDERPMap(dm, true) +} + +// SetDERPMapWithoutReSTUN is like [Conn.SetDERPMap] but does not trigger a +// ReSTUN after updating the map. +// +// It is used for setting the map from a cache, so the homeDERP can be set +// from cache before any STUN happens. +func (c *Conn) SetDERPMapWithoutReSTUN(dm *tailcfg.DERPMap) { + c.setDERPMap(dm, false) +} + +func (c *Conn) setDERPMap(dm *tailcfg.DERPMap, doReStun bool) { c.mu.Lock() defer c.mu.Unlock() @@ -828,8 +866,14 @@ func (c *Conn) SetDERPMap(dm *tailcfg.DERPMap) { } } - go c.ReSTUN("derp-map-update") + if doReStun { + if reSTUNHookForTests != nil { + reSTUNHookForTests("derp-map-update") + } + go c.ReSTUN("derp-map-update") + } } + func (c *Conn) wantDerpLocked() bool { return c.derpMap != nil } // c.mu must be held. diff --git a/wgengine/magicsock/derp_test.go b/wgengine/magicsock/derp_test.go index 084f710d8..c79882d54 100644 --- a/wgengine/magicsock/derp_test.go +++ b/wgengine/magicsock/derp_test.go @@ -4,9 +4,15 @@ package magicsock import ( + "fmt" "testing" + "tailscale.com/health" "tailscale.com/net/netcheck" + "tailscale.com/tailcfg" + "tailscale.com/tstest" + "tailscale.com/util/eventbus" + "tailscale.com/util/eventbus/eventbustest" ) func CheckDERPHeuristicTimes(t *testing.T) { @@ -14,3 +20,111 @@ func CheckDERPHeuristicTimes(t *testing.T) { t.Errorf("PreferredDERPFrameTime too low; should be at least frameReceiveRecordRate") } } + +func TestForceSetNearestDERP(t *testing.T) { + derpMap := &tailcfg.DERPMap{ + Regions: map[int]*tailcfg.DERPRegion{ + 7: { + RegionID: 7, + RegionCode: "test", + Nodes: []*tailcfg.DERPNode{ + { + Name: "7a", + RegionID: 7, + HostName: "derp7.test.unused", + IPv4: "127.0.0.1", + IPv6: "none", + }, + }, + }, + }, + } + + // Force the real control health check so we can verify force=true bypasses it. + tstest.Replace(t, &checkControlHealthDuringNearestDERPInTests, true) + + bus := eventbustest.NewBus(t) + ht := health.NewTracker(bus) + c := newConn(t.Logf) + ec := bus.Client("magicsock.Conn.Test") + c.eventClient = ec + c.homeDERPChangedPub = eventbus.Publish[HomeDERPChanged](ec) + c.eventBus = bus + c.derpMap = derpMap + c.health = ht + + ht.SetOutOfPollNetMap() + + tw := eventbustest.NewWatcher(t, bus) + + got := c.ForceSetNearestDERP(7) + if got != 7 { + t.Fatalf("ForceSetNearestDERP(7) = %d, want 7", got) + } + if c.myDerp != 7 { + t.Errorf("c.myDerp = %d after ForceSetNearestDERP, want 7", c.myDerp) + } + + if err := eventbustest.Expect(tw, func(e HomeDERPChanged) error { + if e.Old != 0 || e.New != 7 { + return fmt.Errorf("got HomeDERPChanged{Old:%d, New:%d}, want {Old:0, New:7}", e.Old, e.New) + } + return nil + }); err != nil { + t.Errorf("expected HomeDERPChanged event: %v", err) + } +} + +func TestSetDERPMapDoReStun(t *testing.T) { + derpMap1 := &tailcfg.DERPMap{ + Regions: map[int]*tailcfg.DERPRegion{ + 1: { + RegionID: 1, + RegionCode: "cph", + Nodes: []*tailcfg.DERPNode{ + {Name: "1a", RegionID: 1, HostName: "cph.test.unused", IPv4: "127.0.0.1", IPv6: "none"}, + }, + }, + }, + } + derpMap2 := &tailcfg.DERPMap{ + Regions: map[int]*tailcfg.DERPRegion{ + 2: { + RegionID: 2, + RegionCode: "inc", + Nodes: []*tailcfg.DERPNode{ + {Name: "2a", RegionID: 2, HostName: "inc.test.unused", IPv4: "127.0.0.1", IPv6: "none"}, + }, + }, + }, + } + + var reSTUNCalls int + tstest.Replace(t, &reSTUNHookForTests, func(_ string) { + reSTUNCalls++ + }) + + bus := eventbustest.NewBus(t) + ht := health.NewTracker(bus) + c := newConn(t.Logf) + ec := bus.Client("magicsock.Conn.Test") + c.eventClient = ec + c.homeDERPChangedPub = eventbus.Publish[HomeDERPChanged](ec) + c.eventBus = bus + c.health = ht + // With a zero private key and everHadKey=true, ReSTUN returns early without + // spawning updateEndpoints. + c.everHadKey = true + + // SetDERPMapWithoutReSTUN should not trigger a ReSTUN. + c.SetDERPMapWithoutReSTUN(derpMap1) + if reSTUNCalls != 0 { + t.Errorf("SetDERPMapWithoutReSTUN: got %d ReSTUN calls, want 0", reSTUNCalls) + } + + // SetDERPMap should trigger a ReSTUN. + c.SetDERPMap(derpMap2) + if reSTUNCalls != 1 { + t.Errorf("SetDERPMap: got %d ReSTUN calls, want 1", reSTUNCalls) + } +} diff --git a/wgengine/magicsock/endpoint.go b/wgengine/magicsock/endpoint.go index b8d3b96be..d831a9032 100644 --- a/wgengine/magicsock/endpoint.go +++ b/wgengine/magicsock/endpoint.go @@ -138,6 +138,13 @@ func (de *endpoint) udpRelayEndpointReady(maybeBest addrQuality) { func (de *endpoint) setBestAddrLocked(v addrQuality) { if v.epAddr != de.bestAddr.epAddr { de.probeUDPLifetime.resetCycleEndpointLocked() + + // Reaching here, if we are using data from a cached netmap and we are + // upgrading from an invalid (missing) address to a valid one, increment + // the counter for peers established. + if !de.bestAddr.ap.IsValid() && v.ap.IsValid() && de.c.usingCachedNetmap.Load() { + metricCachedPeerContactDirect.Add(1) + } } de.bestAddr = v } @@ -530,11 +537,6 @@ func (de *endpoint) noteRecvActivity(src epAddr, now mono.Time) bool { elapsed := now.Sub(de.lastRecvWG.LoadAtomic()) if elapsed > 10*time.Second { de.lastRecvWG.StoreAtomic(now) - - if de.c.noteRecvActivity == nil { - return false - } - de.c.noteRecvActivity(de.publicKey) return true } return false @@ -897,7 +899,7 @@ func (de *endpoint) wantUDPRelayPathDiscoveryLocked(now mono.Time) bool { if runtime.GOOS == "js" { return false } - if !de.c.hasPeerRelayServers.Load() { + if !de.c.relayManager.hasPeerRelayServers.Load() { // Changes in this value between its access and a call to // [endpoint.discoverUDPRelayPathsLocked] are fine, we will eventually // do the "right" thing during future path discovery. The worst case is @@ -1183,8 +1185,7 @@ func (de *endpoint) discoPingTimeout(txid stun.TxID) { return } bestUntrusted := mono.Now().After(de.trustBestAddrUntil) - if sp.to == de.bestAddr.epAddr && sp.to.vni.IsSet() && bestUntrusted { - // TODO(jwhited): consider applying this to direct UDP paths as well + if sp.to == de.bestAddr.epAddr && bestUntrusted { de.clearBestAddrLocked() } if debugDisco() || !de.bestAddr.ap.IsValid() || bestUntrusted { @@ -1778,12 +1779,8 @@ func (de *endpoint) handlePongConnLocked(m *disco.Pong, di *discoInfo, src epAdd latency: latency, wireMTU: pingSizeToPktLen(sp.size, sp.to), } - // TODO(jwhited): consider checking de.trustBestAddrUntil as well. If - // de.bestAddr is untrusted we may want to clear it, otherwise we could - // get stuck with a forever untrusted bestAddr that blackholes, since - // we don't clear direct UDP paths on disco ping timeout (see - // discoPingTimeout). - if betterAddr(thisPong, de.bestAddr) { + bestUntrusted := now.After(de.trustBestAddrUntil) + if betterAddr(thisPong, de.bestAddr) || bestUntrusted { de.c.logf("magicsock: disco: node %v %v now using %v mtu=%v tx=%x", de.publicKey.ShortString(), de.discoShort(), sp.to, thisPong.wireMTU, m.TxID[:6]) de.debugUpdates.Add(EndpointChange{ When: time.Now(), @@ -2098,7 +2095,7 @@ func (de *endpoint) setDERPHome(regionID uint16) { de.mu.Lock() defer de.mu.Unlock() de.derpAddr = netip.AddrPortFrom(tailcfg.DerpMagicIPAddr, uint16(regionID)) - if de.c.hasPeerRelayServers.Load() { + if de.c.relayManager.hasPeerRelayServers.Load() { de.c.relayManager.handleDERPHomeChange(de.publicKey, regionID) } } diff --git a/wgengine/magicsock/endpoint_test.go b/wgengine/magicsock/endpoint_test.go index eba742b00..593cf1455 100644 --- a/wgengine/magicsock/endpoint_test.go +++ b/wgengine/magicsock/endpoint_test.go @@ -6,12 +6,16 @@ package magicsock import ( "net/netip" "testing" + "testing/synctest" "time" + "tailscale.com/disco" "tailscale.com/net/packet" + "tailscale.com/net/stun" "tailscale.com/tailcfg" "tailscale.com/tstime/mono" "tailscale.com/types/key" + "tailscale.com/util/ringlog" ) func TestProbeUDPLifetimeConfig_Equals(t *testing.T) { @@ -453,3 +457,233 @@ func Test_endpoint_udpRelayEndpointReady(t *testing.T) { }) } } + +func Test_endpoint_discoPingTimeout(t *testing.T) { + expired := -1 * time.Hour + valid := 1 * time.Hour + directAddrA := epAddr{ap: netip.MustParseAddrPort("192.0.2.1:7")} + relayAddrA := epAddr{ap: netip.MustParseAddrPort("192.0.2.2:77")} + relayAddrA.vni.Set(1) + directAddrB := epAddr{ap: netip.MustParseAddrPort("192.0.2.3:7")} + relayAddrB := epAddr{ap: netip.MustParseAddrPort("192.0.2.4:77")} + relayAddrB.vni.Set(1) + + for _, tc := range []struct { + name string + bestAddr addrQuality + trustBestAddrUntil time.Duration + pingTo epAddr + wantBestAddrCleared bool + }{ + { + name: "relay-path-trust-expired", + bestAddr: addrQuality{epAddr: relayAddrA}, + trustBestAddrUntil: expired, + pingTo: relayAddrA, + wantBestAddrCleared: true, + }, + { + name: "direct-udp-path-trust-expired", + bestAddr: addrQuality{epAddr: directAddrA}, + trustBestAddrUntil: expired, + pingTo: directAddrA, + wantBestAddrCleared: true, + }, + { + name: "direct-udp-path-trust-valid", + bestAddr: addrQuality{epAddr: directAddrA}, + trustBestAddrUntil: valid, + pingTo: directAddrA, + wantBestAddrCleared: false, + }, + { + name: "relay-path-trust-valid", + bestAddr: addrQuality{epAddr: relayAddrA}, + trustBestAddrUntil: valid, + pingTo: relayAddrA, + wantBestAddrCleared: false, + }, + { + name: "ping-to-different-direct-addr-trust-expired", + bestAddr: addrQuality{epAddr: directAddrA}, + trustBestAddrUntil: expired, + pingTo: directAddrB, + wantBestAddrCleared: false, + }, + { + name: "ping-to-different-relay-addr-trust-expired", + bestAddr: addrQuality{epAddr: relayAddrA}, + trustBestAddrUntil: expired, + pingTo: relayAddrB, + wantBestAddrCleared: false, + }, + } { + t.Run(tc.name, func(t *testing.T) { + synctest.Test(t, func(t *testing.T) { + now := mono.Now() // synctest to match this to the internal 'now' + c := &Conn{ + logf: func(msg string, args ...any) {}, + } + c.discoAtomic.Set(key.NewDisco()) + de := &endpoint{ + c: c, + bestAddr: tc.bestAddr, + trustBestAddrUntil: now.Add(tc.trustBestAddrUntil), + sentPing: make(map[stun.TxID]sentPing), + } + txid := stun.NewTxID() + timer := time.NewTimer(time.Hour) + timer.Stop() + de.sentPing[txid] = sentPing{ + to: tc.pingTo, + at: now.Add(-100 * time.Millisecond), + timer: timer, + purpose: pingDiscovery, + } + + de.discoPingTimeout(txid) + if tc.wantBestAddrCleared { + if de.bestAddr.ap.IsValid() { + t.Errorf("expected bestAddr to be cleared, but bestAddr.ap is valid: %v", de.bestAddr.ap) + } + if de.trustBestAddrUntil != 0 { + t.Errorf("expected trustBestAddrUntil to be cleared, but got: %v", de.trustBestAddrUntil) + } + } else { + if de.bestAddr != tc.bestAddr { + t.Errorf("expected bestAddr to be unchanged, got: %v, want: %v", de.bestAddr, tc.bestAddr) + } + } + if _, ok := de.sentPing[txid]; ok { + t.Errorf("expected sentPing[txid] to be removed, but it still exists") + } + }) + }) + } +} + +func Test_endpoint_handlePongConnLocked(t *testing.T) { + goodLatency := 50 * time.Millisecond + badLatency := 100 * time.Millisecond + expired := -1 * time.Hour + valid := 1 * time.Hour + directAddrA := epAddr{ap: netip.MustParseAddrPort("192.0.2.1:7")} + directAddrB := epAddr{ap: netip.MustParseAddrPort("192.0.2.2:8")} + derpAddr := epAddr{ap: netip.AddrPortFrom(tailcfg.DerpMagicIPAddr, 0)} + + for _, tc := range []struct { + name string + bestAddr addrQuality + trustBestAddrUntil time.Duration + pongFrom epAddr + pongLatency time.Duration + wantBestAddr epAddr + }{ + { + name: "better-latency-trust-valid", + bestAddr: addrQuality{epAddr: directAddrA, latency: badLatency}, + trustBestAddrUntil: valid, + pongFrom: directAddrB, + pongLatency: goodLatency, + wantBestAddr: directAddrB, + }, + { + name: "worse-latency-trust-valid", + bestAddr: addrQuality{epAddr: directAddrA, latency: goodLatency}, + trustBestAddrUntil: valid, + pongFrom: directAddrB, + pongLatency: badLatency, + wantBestAddr: directAddrA, + }, + { + name: "worse-latency-trust-expired", + bestAddr: addrQuality{epAddr: directAddrA, latency: goodLatency}, + trustBestAddrUntil: expired, + pongFrom: directAddrB, + pongLatency: badLatency, + wantBestAddr: directAddrB, + }, + { + name: "same-path-trust-expired", + bestAddr: addrQuality{epAddr: directAddrA, latency: badLatency}, + trustBestAddrUntil: expired, + pongFrom: directAddrA, + pongLatency: goodLatency, // updated latency + wantBestAddr: directAddrA, + }, + { + name: "derp-pong-trust-expired", + bestAddr: addrQuality{epAddr: directAddrA, latency: badLatency}, + trustBestAddrUntil: expired, + pongFrom: derpAddr, + pongLatency: goodLatency, + wantBestAddr: directAddrA, + }, + { + name: "better-latency-trust-expired", + bestAddr: addrQuality{epAddr: directAddrA, latency: badLatency}, + trustBestAddrUntil: expired, + pongFrom: directAddrB, + pongLatency: goodLatency, + wantBestAddr: directAddrB, + }, + } { + t.Run(tc.name, func(t *testing.T) { + synctest.Test(t, func(t *testing.T) { + now := mono.Now() // synctest to match this to the internal 'now' + pm := newPeerMap() + c := &Conn{ + logf: func(msg string, args ...any) {}, + peerMap: pm, + } + c.discoAtomic.Set(key.NewDisco()) + de := &endpoint{ + c: c, + bestAddr: tc.bestAddr, + bestAddrAt: now.Add(-5 * time.Minute), + trustBestAddrUntil: now.Add(tc.trustBestAddrUntil), + sentPing: make(map[stun.TxID]sentPing), + endpointState: make(map[netip.AddrPort]*endpointState), + debugUpdates: ringlog.New[EndpointChange](10), + } + txid := stun.NewTxID() + pong := &disco.Pong{ + TxID: txid, + Src: tc.pongFrom.ap, + } + timer := time.NewTimer(time.Hour) + timer.Stop() + de.sentPing[txid] = sentPing{ + to: tc.pongFrom, + at: now.Add(-tc.pongLatency), + timer: timer, + purpose: pingDiscovery, + } + if tc.pongFrom.ap.Addr() != tailcfg.DerpMagicIPAddr && !tc.pongFrom.vni.IsSet() { + de.endpointState[tc.pongFrom.ap] = &endpointState{} + } + di := &discoInfo{ + discoKey: key.NewDisco().Public(), + discoShort: "test", + } + + knownTxID := de.handlePongConnLocked(pong, di, tc.pongFrom) + if !knownTxID { + t.Errorf("expected knownTxID to be true, got false") + } + if de.bestAddr.epAddr != tc.wantBestAddr { + t.Errorf("expected bestAddr.epAddr to be %v, got: %v", tc.wantBestAddr, de.bestAddr.epAddr) + } + if tc.pongFrom == tc.bestAddr.epAddr && de.bestAddr.latency-tc.pongLatency > 0 { + t.Errorf("expected latency to be %v, got: %v", tc.pongLatency, de.bestAddr.latency) + } + if tc.pongFrom != derpAddr && de.trustBestAddrUntil.Before(now) { + t.Errorf("expected trustBestAddrUntil to be refreshed, but it's in the past: %v", de.trustBestAddrUntil) + } + if _, ok := de.sentPing[txid]; ok { + t.Errorf("expected sentPing[txid] to be removed, but it still exists") + } + }) + }) + } +} diff --git a/wgengine/magicsock/magicsock.go b/wgengine/magicsock/magicsock.go index f13e31554..6461c552e 100644 --- a/wgengine/magicsock/magicsock.go +++ b/wgengine/magicsock/magicsock.go @@ -164,7 +164,6 @@ type Conn struct { derpActiveFunc func() idleFunc func() time.Duration // nil means unknown testOnlyPacketListener nettype.PacketListener - noteRecvActivity func(key.NodePublic) // or nil, see Options.NoteRecvActivity onDERPRecv func(int, key.NodePublic, []byte) bool // or nil, see Options.OnDERPRecv netMon *netmon.Monitor // must be non-nil health *health.Tracker // or nil @@ -183,6 +182,7 @@ type Conn struct { allocRelayEndpointPub *eventbus.Publisher[UDPRelayAllocReq] portUpdatePub *eventbus.Publisher[router.PortUpdate] tsmpDiscoKeyAvailablePub *eventbus.Publisher[NewDiscoKeyAvailable] + homeDERPChangedPub *eventbus.Publisher[HomeDERPChanged] // pconn4 and pconn6 are the underlying UDP sockets used to // send/receive packets for wireguard and other magicsock @@ -269,15 +269,12 @@ type Conn struct { // captureHook, if non-nil, is the pcap logging callback when capturing. captureHook syncs.AtomicValue[packet.CaptureCallback] - // hasPeerRelayServers is whether [relayManager] is configured with at least - // one peer relay server via [relayManager.handleRelayServersSet]. It exists - // to suppress calls into [relayManager] leading to wasted work involving - // channel operations and goroutine creation. - hasPeerRelayServers atomic.Bool - // discoAtomic is the current disco private and public keypair for this conn. discoAtomic discoAtomic + // usingCacheNetmap is whether the latest update to self and peersByID are from a cached network map + usingCachedNetmap atomic.Bool + // ============================================================ // mu guards all following fields; see userspaceEngine lock // ordering rules against the engine. For derphttp, mu must @@ -361,18 +358,19 @@ type Conn struct { // magicsock could do with any complexity reduction it can get. netInfoLast *tailcfg.NetInfo - derpMap *tailcfg.DERPMap // nil (or zero regions/nodes) means DERP is disabled - self tailcfg.NodeView // from last SetNetworkMap - peers views.Slice[tailcfg.NodeView] // from last SetNetworkMap, sorted by Node.ID; Note: [netmap.NodeMutation]'s rx'd in UpdateNetmapDelta are never applied - filt *filter.Filter // from last SetFilter - relayClientEnabled bool // whether we can allocate UDP relay endpoints on UDP relay servers or receive CallMeMaybeVia messages from peers - lastFlags debugFlags // at time of last SetNetworkMap - privateKey key.NodePrivate // WireGuard private key for this node - everHadKey bool // whether we ever had a non-zero private key - myDerp int // nearest DERP region ID; 0 means none/unknown - homeless bool // if true, don't try to find & stay conneted to a DERP home (myDerp will stay 0) - derpStarted chan struct{} // closed on first connection to DERP; for tests & cleaner Close - activeDerp map[int]activeDerp // DERP regionID -> connection to a node in that region + derpMap *tailcfg.DERPMap // nil (or zero regions/nodes) means DERP is disabled + self tailcfg.NodeView // from last SetNetworkMap + peersByID map[tailcfg.NodeID]tailcfg.NodeView // current peer set, keyed by NodeID. Maintained by SetNetworkMap/UpsertPeer/RemovePeer. Note: per-field NodeMutation patches received in UpdateNetmapDelta are never applied to these snapshots. + + filt *filter.Filter // from last SetFilter + relayClientEnabled bool // whether we can allocate UDP relay endpoints on UDP relay servers or receive CallMeMaybeVia messages from peers + lastFlags debugFlags // at time of last SetNetworkMap + privateKey key.NodePrivate // WireGuard private key for this node + everHadKey bool // whether we ever had a non-zero private key + myDerp int // nearest DERP region ID; 0 means none/unknown + homeless bool // if true, don't try to find & stay conneted to a DERP home (myDerp will stay 0) + derpStarted chan struct{} // closed on first connection to DERP; for tests & cleaner Close + activeDerp map[int]activeDerp // DERP regionID -> connection to a node in that region prevDerp map[int]*syncs.WaitGroupChan // derpRoute contains optional alternate routes to use as an @@ -462,19 +460,6 @@ type Options struct { // Only used by tests. TestOnlyPacketListener nettype.PacketListener - // NoteRecvActivity, if provided, is a func for magicsock to call - // whenever it receives a packet from a a peer if it's been more - // than ~10 seconds since the last one. (10 seconds is somewhat - // arbitrary; the sole user, lazy WireGuard configuration, - // just doesn't need or want it called on - // every packet, just every minute or two for WireGuard timeouts, - // and 10 seconds seems like a good trade-off between often enough - // and not too often.) - // The provided func is likely to call back into - // Conn.ParseEndpoint, which acquires Conn.mu. As such, you should - // not hold Conn.mu while calling it. - NoteRecvActivity func(key.NodePublic) - // NetMon is the network monitor to use. // It must be non-nil. NetMon *netmon.Monitor @@ -653,7 +638,6 @@ func NewConn(opts Options) (*Conn, error) { c.derpActiveFunc = opts.derpActiveFunc() c.idleFunc = opts.IdleFunc c.testOnlyPacketListener = opts.TestOnlyPacketListener - c.noteRecvActivity = opts.NoteRecvActivity c.onDERPRecv = opts.OnDERPRecv // Set up publishers and subscribers. Subscribe calls must return before @@ -663,6 +647,7 @@ func NewConn(opts Options) (*Conn, error) { c.allocRelayEndpointPub = eventbus.Publish[UDPRelayAllocReq](ec) c.portUpdatePub = eventbus.Publish[router.PortUpdate](ec) c.tsmpDiscoKeyAvailablePub = eventbus.Publish[NewDiscoKeyAvailable](ec) + c.homeDERPChangedPub = eventbus.Publish[HomeDERPChanged](ec) eventbus.SubscribeFunc(ec, c.onPortMapChanged) eventbus.SubscribeFunc(ec, c.onUDPRelayAllocResp) @@ -1062,7 +1047,7 @@ func (c *Conn) updateNetInfo(ctx context.Context) (*netcheck.Report, error) { ni.OSHasIPv6.Set(report.OSHasIPv6) ni.WorkingUDP.Set(report.UDP) ni.WorkingICMPv4.Set(report.ICMPv4) - ni.PreferredDERP = c.maybeSetNearestDERP(report) + ni.PreferredDERP = c.maybeSetNearestDERP(report, false) ni.FirewallMode = hostinfo.FirewallMode() c.callNetInfoCallback(ni) @@ -2353,6 +2338,13 @@ func (c *Conn) handleDiscoMessage(msg []byte, src epAddr, shouldBeRelayHandshake c.logf("[unexpected] %s from peer via DERP whose netmap discokey != disco source", msgType) return } + + // Reaching here, if we are using data from a cached network map the + // receipt of a CallMeMaybe from a peer indicates we have a sufficiently + // viable connection to that peer to count it as active while cached. + if c.usingCachedNetmap.Load() { + metricCachedPeerContactDERP.Add(1) + } if isVia { c.dlogf("[v1] magicsock: disco: %v<-%v via %v (%v, %v) got call-me-maybe-via, %d endpoints", c.discoAtomic.Short(), epDisco.short, via.ServerDisco.ShortString(), @@ -2430,24 +2422,12 @@ func (c *Conn) handleDiscoMessage(msg []byte, src epAddr, shouldBeRelayHandshake if c.filt == nil { return } - // Binary search of peers is O(log n) while c.mu is held. - // TODO: We might be able to use ep.nodeAddr instead of all addresses, - // or we might be able to release c.mu before doing this work. Keep it - // simple and slow for now. c.peers.AsSlice is a copy. We may need to - // write our own binary search for a [views.Slice]. - peerI, ok := slices.BinarySearchFunc(c.peers.AsSlice(), ep.nodeID, func(peer tailcfg.NodeView, target tailcfg.NodeID) int { - if peer.ID() < target { - return -1 - } else if peer.ID() > target { - return 1 - } - return 0 - }) + peer, ok := c.peersByID[ep.nodeID] if !ok { // unexpected return } - if !nodeHasCap(c.filt, c.peers.At(peerI), c.self, tailcfg.PeerCapabilityRelay) { + if !nodeHasCap(c.filt, peer, c.self, tailcfg.PeerCapabilityRelay) { return } // [Conn.mu] must not be held while publishing, or [Conn.onUDPRelayAllocResp] @@ -2784,18 +2764,6 @@ func (c *Conn) UpdatePeers(newPeers set.Set[key.NodePublic]) { } } -func nodesEqual(x, y views.Slice[tailcfg.NodeView]) bool { - if x.Len() != y.Len() { - return false - } - for i := range x.Len() { - if !x.At(i).Equal(y.At(i)) { - return false - } - } - return true -} - // debugRingBufferSize returns a maximum size for our set of endpoint ring // buffers by assuming that a single large update is ~500 bytes, and that we // want to not use more than 1MiB of memory on phones / 4MiB on other devices. @@ -2883,7 +2851,7 @@ func (c *Conn) SetFilter(f *filter.Filter) { c.mu.Lock() c.filt = f self := c.self - peers := c.peers + peers := c.peerSnapshotLocked() relayClientEnabled := c.relayClientEnabled c.mu.Unlock() // release c.mu before potentially calling c.updateRelayServersSet which is O(m * n) @@ -2897,11 +2865,26 @@ func (c *Conn) SetFilter(f *filter.Filter) { c.updateRelayServersSet(f, self, peers) } +// peerSnapshotLocked returns a freshly-allocated slice of the current peers. +// It's used by callers that need to pass peer state to an O(m * n) callee +// (like [Conn.updateRelayServersSet]) after releasing c.mu. c.mu must be held. +func (c *Conn) peerSnapshotLocked() []tailcfg.NodeView { + if len(c.peersByID) == 0 { + return nil + } + out := make([]tailcfg.NodeView, 0, len(c.peersByID)) + for _, p := range c.peersByID { + out = append(out, p) + } + return out +} + // updateRelayServersSet iterates all peers and self, evaluating filt for each // one in order to determine which are relay server candidates. filt, self, and // peers are passed as args (vs c.mu-guarded fields) to enable callers to // release c.mu before calling as this is O(m * n) (we iterate all cap rules 'm' -// in filt for every peer 'n'). +// in filt for every peer 'n'). peers must be a snapshot owned by the caller; +// this function does not retain it after return. // // Calls to updateRelayServersSet must never run concurrent to // [endpoint.setDERPHome], otherwise [candidatePeerRelay] DERP home changes may @@ -2913,9 +2896,9 @@ func (c *Conn) SetFilter(f *filter.Filter) { // them. // 2. Moving this work upstream into [nodeBackend] or similar, and publishing // the computed result over the eventbus instead. -func (c *Conn) updateRelayServersSet(filt *filter.Filter, self tailcfg.NodeView, peers views.Slice[tailcfg.NodeView]) { +func (c *Conn) updateRelayServersSet(filt *filter.Filter, self tailcfg.NodeView, peers []tailcfg.NodeView) { relayServers := make(set.Set[candidatePeerRelay]) - nodes := append(peers.AsSlice(), self) + nodes := append(peers, self) for _, maybeCandidate := range nodes { if maybeCandidate.ID() != self.ID() && !capVerIsRelayCapable(maybeCandidate.Cap()) { // If maybeCandidate's [tailcfg.CapabilityVersion] is not relay-capable, @@ -2933,12 +2916,9 @@ func (c *Conn) updateRelayServersSet(filt *filter.Filter, self tailcfg.NodeView, derpHomeRegionID: uint16(maybeCandidate.HomeDERP()), }) } + // [relayManager]'s run loop updates [relayManager.hasPeerRelayServers] + // to reflect the new server count. c.relayManager.handleRelayServersSet(relayServers) - if len(relayServers) > 0 { - c.hasPeerRelayServers.Store(true) - } else { - c.hasPeerRelayServers.Store(false) - } } // nodeHasCap returns true if src has cap on dst, otherwise it returns false. @@ -2985,12 +2965,31 @@ func (c *candidatePeerRelay) isValid() bool { return !c.nodeKey.IsZero() && !c.discoKey.IsZero() } -// SetNetworkMap updates the network map with the given self node and peers. -// It must be called synchronously from the caller's goroutine to ensure -// magicsock has the current state before subsequent operations proceed. +// SetNetworkMap updates the network map with the given self node and peers +// reported by the control plane (rather than cached). It must be called +// synchronously from the caller's goroutine to ensure magicsock has the +// current state before subsequent operations proceed. // // self may be invalid if there's no network map. +// +// SetNetworkMap takes the full peer list and walks all of it. For incremental +// updates where only a single peer changes, prefer the O(1) [Conn.UpsertPeer] +// and [Conn.RemovePeer] methods. SetNetworkMap remains the right call for the +// initial netmap and for changes to self or to global state (filter, DERP, +// etc.) that aren't covered by the per-peer methods. func (c *Conn) SetNetworkMap(self tailcfg.NodeView, peers []tailcfg.NodeView) { + c.setNetworkMapInternal(self, peers, false) +} + +// SetNetworkMapCached behaves as SetNetworkMap, but indicates to c that the +// data provided are from a cache rather than the control plane. The same +// constraints otherwise apply. +func (c *Conn) SetNetworkMapCached(self tailcfg.NodeView, peers []tailcfg.NodeView) { + c.setNetworkMapInternal(self, peers, true) +} + +// setNetworkMapInternal is the shared implementation of SetNetworkMap and SetNetworkMapCached. +func (c *Conn) setNetworkMapInternal(self tailcfg.NodeView, peers []tailcfg.NodeView, isCached bool) { peersChanged := c.updateNodes(self, peers) relayClientEnabled := self.Valid() && @@ -3002,8 +3001,9 @@ func (c *Conn) SetNetworkMap(self tailcfg.NodeView, peers []tailcfg.NodeView) { c.relayClientEnabled = relayClientEnabled filt := c.filt selfView := c.self - peersView := c.peers + peersSnap := c.peerSnapshotLocked() isClosed := c.closed + c.usingCachedNetmap.Store(isCached) c.mu.Unlock() // release c.mu before potentially calling c.updateRelayServersSet which is O(m * n) if isClosed { @@ -3012,16 +3012,16 @@ func (c *Conn) SetNetworkMap(self tailcfg.NodeView, peers []tailcfg.NodeView) { if peersChanged || relayClientChanged { if !relayClientEnabled { + // [relayManager]'s run loop updates [relayManager.hasPeerRelayServers]. c.relayManager.handleRelayServersSet(nil) - c.hasPeerRelayServers.Store(false) } else { - c.updateRelayServersSet(filt, selfView, peersView) + c.updateRelayServersSet(filt, selfView, peersSnap) } } } // updateNodes updates [Conn] to reflect the given self node and peers. -// It reports whether the peers were changed from before. +// It reports whether the peer set (membership or any field) changed. func (c *Conn) updateNodes(self tailcfg.NodeView, peers []tailcfg.NodeView) (peersChanged bool) { c.mu.Lock() defer c.mu.Unlock() @@ -3030,13 +3030,9 @@ func (c *Conn) updateNodes(self tailcfg.NodeView, peers []tailcfg.NodeView) (pee return false } - priorPeers := c.peers metricNumPeers.Set(int64(len(peers))) - // Update c.self & c.peers regardless, before the following early return. c.self = self - curPeers := views.SliceOf(peers) - c.peers = curPeers // [debugFlags] are mutable in [Conn.SetSilentDisco] & // [Conn.SetProbeUDPLifetime]. These setters are passed [controlknobs.Knobs] @@ -3049,137 +3045,43 @@ func (c *Conn) updateNodes(self tailcfg.NodeView, peers []tailcfg.NodeView) (pee // TODO: mutate [debugFlags] here instead of in various [Conn] setters. flags := c.debugFlagsLocked() - peersChanged = !nodesEqual(priorPeers, curPeers) - if !peersChanged && c.lastFlags == flags { - // The rest of this function is all adjusting state for peers that have - // changed. But if the set of peers is equal and the debug flags (for - // silent disco and probe UDP lifetime) haven't changed, there is no - // need to do anything else. - return + // Fast path: if the peer set and every peer's NodeView are unchanged, + // and flags are unchanged, skip all further work. + if c.lastFlags == flags && len(peers) == len(c.peersByID) { + allSame := true + for _, n := range peers { + if prev, ok := c.peersByID[n.ID()]; !ok || !prev.Equal(n) { + allSame = false + break + } + } + if allSame { + return false + } } c.lastFlags = flags - c.logf("[v1] magicsock: got updated network map; %d peers", len(peers)) entriesPerBuffer := debugRingBufferSize(len(peers)) - // Try a pass of just upserting nodes and creating missing - // endpoints. If the set of nodes is the same, this is an - // efficient alloc-free update. If the set of nodes is different, - // we'll fall through to the next pass, which allocates but can - // handle full set updates. + // Build the new peer map while upserting each peer. + newPeers := make(map[tailcfg.NodeID]tailcfg.NodeView, len(peers)) for _, n := range peers { - if n.ID() == 0 { - devPanicf("node with zero ID") - continue - } - if n.Key().IsZero() { - devPanicf("node with zero key") - continue - } - ep, ok := c.peerMap.endpointForNodeID(n.ID()) - if ok && ep.publicKey != n.Key() { - // The node rotated public keys. Delete the old endpoint and create - // it anew. - c.peerMap.deleteEndpoint(ep) - ok = false - } - if ok { - // At this point we're modifying an existing endpoint (ep) whose - // public key and nodeID match n. Its other fields (such as disco - // key or endpoints) might've changed. - - if n.DiscoKey().IsZero() && !n.IsWireGuardOnly() { - // Discokey transitioned from non-zero to zero? This should not - // happen in the wild, however it could mean: - // 1. A node was downgraded from post 0.100 to pre 0.100. - // 2. A Tailscale node key was extracted and used on a - // non-Tailscale node (should not enter here due to the - // IsWireGuardOnly check) - // 3. The server is misbehaving. - c.peerMap.deleteEndpoint(ep) - continue - } - var oldDiscoKey key.DiscoPublic - if epDisco := ep.disco.Load(); epDisco != nil { - oldDiscoKey = epDisco.key - } - ep.updateFromNode(n, flags.heartbeatDisabled, flags.probeUDPLifetimeOn) - c.peerMap.upsertEndpoint(ep, oldDiscoKey) // maybe update discokey mappings in peerMap - continue - } - - if ep, ok := c.peerMap.endpointForNodeKey(n.Key()); ok { - // At this point n.Key() should be for a key we've never seen before. If - // ok was true above, it was an update to an existing matching key and - // we don't get this far. If ok was false above, that means it's a key - // that differs from the one the NodeID had. But double check. - if ep.nodeID != n.ID() { - // Server error. This is known to be a particular issue for Mullvad - // nodes (http://go/corp/27300), so log a distinct error for the - // Mullvad and non-Mullvad cases. The error will be logged either way, - // so an approximate heuristic is fine. - // - // When #27300 is fixed, we can delete this branch and log the same - // panic for any public key moving. - if strings.HasSuffix(n.Name(), ".mullvad.ts.net.") { - devPanicf("public key moved between Mullvad nodeIDs (old=%v new=%v, key=%s); see http://go/corp/27300", ep.nodeID, n.ID(), n.Key().String()) - } else { - devPanicf("public key moved between nodeIDs (old=%v new=%v, key=%s)", ep.nodeID, n.ID(), n.Key().String()) - } - } else { - // Internal data structures out of sync. - devPanicf("public key found in peerMap but not by nodeID") - } - continue - } - if n.DiscoKey().IsZero() && !n.IsWireGuardOnly() { - // Ancient pre-0.100 node, which does not have a disco key. - // No longer supported. - continue - } - - ep = &endpoint{ - c: c, - nodeID: n.ID(), - publicKey: n.Key(), - publicKeyHex: n.Key().UntypedHexString(), - sentPing: map[stun.TxID]sentPing{}, - endpointState: map[netip.AddrPort]*endpointState{}, - heartbeatDisabled: flags.heartbeatDisabled, - isWireguardOnly: n.IsWireGuardOnly(), - } - switch runtime.GOOS { - case "ios", "android": - // Omit, to save memory. Prior to 2024-03-20 we used to limit it to - // ~1MB on mobile but we never used the data so the memory was just - // wasted. - default: - ep.debugUpdates = ringlog.New[EndpointChange](entriesPerBuffer) - } - if n.Addresses().Len() > 0 { - ep.nodeAddr = n.Addresses().At(0).Addr() - } - ep.initFakeUDPAddr() - ep.updateDiscoKey(n.DiscoKey()) - - if debugPeerMap() { - c.logEndpointCreated(n) - } - - ep.updateFromNode(n, flags.heartbeatDisabled, flags.probeUDPLifetimeOn) - c.peerMap.upsertEndpoint(ep, key.DiscoPublic{}) + newPeers[n.ID()] = n + c.upsertPeerLocked(n, flags, entriesPerBuffer) } + if len(newPeers) != len(peers) { + // Duplicate NodeIDs in the input shouldn't happen, but log if so. + c.logf("[unexpected] magicsock.updateNodes: %d peers input but %d unique IDs", len(peers), len(newPeers)) + } + c.peersByID = newPeers - // If the set of nodes changed since the last SetNetworkMap, the - // upsert loop just above made c.peerMap contain the union of the - // old and new peers - which will be larger than the set from the - // current netmap. If that happens, go through the allocful - // deletion path to clean up moribund nodes. - if c.peerMap.nodeCount() != len(peers) { + // If the upsert pass left stale endpoints in peerMap (peers removed + // relative to before), clean them up. + if c.peerMap.nodeCount() != len(newPeers) { keep := set.Set[key.NodePublic]{} - for _, n := range peers { + for _, n := range newPeers { keep.Add(n.Key()) } c.peerMap.forEachEndpoint(func(ep *endpoint) { @@ -3189,14 +3091,226 @@ func (c *Conn) updateNodes(self tailcfg.NodeView, peers []tailcfg.NodeView) (pee }) } - // discokeys might have changed in the above. Discard unused info. + // discokeys might have changed above. Discard unused info. for dk := range c.discoInfo { if !c.peerMap.knownPeerDiscoKey(dk) { delete(c.discoInfo, dk) } } - return peersChanged + return true +} + +// upsertPeerLocked upserts a single peer's endpoint in c.peerMap. It is the +// per-peer body shared by [Conn.SetNetworkMap]'s upsert pass and by the +// efficient per-peer [Conn.UpsertPeer] path. +// +// c.mu must be held. +func (c *Conn) upsertPeerLocked(n tailcfg.NodeView, flags debugFlags, entriesPerBuffer int) { + if n.ID() == 0 { + devPanicf("node with zero ID") + return + } + if n.Key().IsZero() { + devPanicf("node with zero key") + return + } + ep, ok := c.peerMap.endpointForNodeID(n.ID()) + if ok && ep.publicKey != n.Key() { + // The node rotated public keys. Delete the old endpoint and create + // it anew. + c.peerMap.deleteEndpoint(ep) + ok = false + } + if ok { + // At this point we're modifying an existing endpoint (ep) whose + // public key and nodeID match n. Its other fields (such as disco + // key or endpoints) might've changed. + + if n.DiscoKey().IsZero() && !n.IsWireGuardOnly() { + // Discokey transitioned from non-zero to zero? This should not + // happen in the wild, however it could mean: + // 1. A node was downgraded from post 0.100 to pre 0.100. + // 2. A Tailscale node key was extracted and used on a + // non-Tailscale node (should not enter here due to the + // IsWireGuardOnly check) + // 3. The server is misbehaving. + c.peerMap.deleteEndpoint(ep) + return + } + var oldDiscoKey key.DiscoPublic + if epDisco := ep.disco.Load(); epDisco != nil { + oldDiscoKey = epDisco.key + } + ep.updateFromNode(n, flags.heartbeatDisabled, flags.probeUDPLifetimeOn) + c.peerMap.upsertEndpoint(ep, oldDiscoKey) // maybe update discokey mappings in peerMap + return + } + + if ep, ok := c.peerMap.endpointForNodeKey(n.Key()); ok { + // At this point n.Key() should be for a key we've never seen before. If + // ok was true above, it was an update to an existing matching key and + // we don't get this far. If ok was false above, that means it's a key + // that differs from the one the NodeID had. But double check. + if ep.nodeID != n.ID() { + // Server error. This is known to be a particular issue for Mullvad + // nodes (http://go/corp/27300), so log a distinct error for the + // Mullvad and non-Mullvad cases. The error will be logged either way, + // so an approximate heuristic is fine. + // + // When #27300 is fixed, we can delete this branch and log the same + // panic for any public key moving. + if strings.HasSuffix(n.Name(), ".mullvad.ts.net.") { + devPanicf("public key moved between Mullvad nodeIDs (old=%v new=%v, key=%s); see http://go/corp/27300", ep.nodeID, n.ID(), n.Key().String()) + } else { + devPanicf("public key moved between nodeIDs (old=%v new=%v, key=%s)", ep.nodeID, n.ID(), n.Key().String()) + } + } else { + // Internal data structures out of sync. + devPanicf("public key found in peerMap but not by nodeID") + } + return + } + if n.DiscoKey().IsZero() && !n.IsWireGuardOnly() { + // Ancient pre-0.100 node, which does not have a disco key. + // No longer supported. + return + } + + ep = &endpoint{ + c: c, + nodeID: n.ID(), + publicKey: n.Key(), + publicKeyHex: n.Key().UntypedHexString(), + sentPing: map[stun.TxID]sentPing{}, + endpointState: map[netip.AddrPort]*endpointState{}, + heartbeatDisabled: flags.heartbeatDisabled, + isWireguardOnly: n.IsWireGuardOnly(), + } + switch runtime.GOOS { + case "ios", "android": + // Omit, to save memory. Prior to 2024-03-20 we used to limit it to + // ~1MB on mobile but we never used the data so the memory was just + // wasted. + default: + ep.debugUpdates = ringlog.New[EndpointChange](entriesPerBuffer) + } + if n.Addresses().Len() > 0 { + ep.nodeAddr = n.Addresses().At(0).Addr() + } + ep.initFakeUDPAddr() + ep.updateDiscoKey(n.DiscoKey()) + + if debugPeerMap() { + c.logEndpointCreated(n) + } + + ep.updateFromNode(n, flags.heartbeatDisabled, flags.probeUDPLifetimeOn) + c.peerMap.upsertEndpoint(ep, key.DiscoPublic{}) +} + +// UpsertPeer adds or updates a single peer in c. It is the efficient +// O(1)-per-peer alternative to [Conn.SetNetworkMap] when a single peer was +// added or its fields changed. The caller is responsible for serializing +// UpsertPeer/RemovePeer/SetNetworkMap calls relative to one another. +// +// UpsertPeer updates the relay-server set incrementally (O(1)) when the +// upserted peer's relay candidacy changed, rather than rebuilding the +// whole set with [Conn.updateRelayServersSet]. +func (c *Conn) UpsertPeer(n tailcfg.NodeView) { + c.mu.Lock() + if c.closed { + c.mu.Unlock() + return + } + if n.ID() == 0 { + c.mu.Unlock() + devPanicf("UpsertPeer: node with zero ID") + return + } + flags := c.debugFlagsLocked() + c.peersByID[n.ID()] = n + c.upsertPeerLocked(n, flags, debugRingBufferSize(len(c.peersByID))) + + var relayUpsert candidatePeerRelay + relayQualifies := false + if c.relayClientEnabled { + relayQualifies, relayUpsert = c.relayCandidateLocked(n) + } + relayClientEnabled := c.relayClientEnabled + c.mu.Unlock() + + if relayClientEnabled { + if relayQualifies { + c.relayManager.handleRelayServerUpsert(relayUpsert) + } else { + // The peer may have previously qualified; remove covers that + // case and is a no-op otherwise. + c.relayManager.handleRelayServerRemove(n.Key()) + } + } +} + +// RemovePeer removes a single peer from c. It is the efficient +// O(1)-per-peer alternative to [Conn.SetNetworkMap] when a single peer was +// removed. The caller is responsible for serializing UpsertPeer/RemovePeer/ +// SetNetworkMap calls relative to one another. +func (c *Conn) RemovePeer(nid tailcfg.NodeID) { + c.mu.Lock() + if c.closed { + c.mu.Unlock() + return + } + prev, ok := c.peersByID[nid] + if !ok { + c.mu.Unlock() + return + } + delete(c.peersByID, nid) + if ep, ok := c.peerMap.endpointForNodeID(nid); ok { + c.peerMap.deleteEndpoint(ep) + } + + // If the peer we just removed held the only reference to its disco + // key, drop the now-orphaned c.discoInfo entry. No need to scan the + // whole map — only this peer's disco key can have become unreferenced + // by this single removal. + if dk := prev.DiscoKey(); !dk.IsZero() && !c.peerMap.knownPeerDiscoKey(dk) { + delete(c.discoInfo, dk) + } + + relayClientEnabled := c.relayClientEnabled + c.mu.Unlock() + + if relayClientEnabled { + // Tell the relay manager to drop the peer. The run loop no-ops + // this if the peer wasn't a relay server. + c.relayManager.handleRelayServerRemove(prev.Key()) + } +} + +// relayCandidateLocked reports whether peer p is eligible to be a relay +// server candidate for self, and if so returns the [candidatePeerRelay] +// that would be added to the relay-server set. c.mu must be held. +// +// It mirrors the per-peer predicate in [Conn.updateRelayServersSet]. +func (c *Conn) relayCandidateLocked(p tailcfg.NodeView) (ok bool, cp candidatePeerRelay) { + if !p.Valid() { + return false, candidatePeerRelay{} + } + // The cap-version gate in updateRelayServersSet only applies to peers + // (not self). This helper is only called for peers, so always check. + if !capVerIsRelayCapable(p.Cap()) { + return false, candidatePeerRelay{} + } + if !nodeHasCap(c.filt, p, c.self, tailcfg.PeerCapabilityRelayTarget) { + return false, candidatePeerRelay{} + } + return true, candidatePeerRelay{ + nodeKey: p.Key(), + discoKey: p.DiscoKey(), + derpHomeRegionID: uint16(p.HomeDERP()), + } } func devPanicf(format string, a ...any) { @@ -4109,6 +4223,10 @@ var ( metricTSMPDiscoKeyAdvertisementReceived = clientmetric.NewCounter("magicsock_tsmp_disco_key_advertisement_received") metricTSMPDiscoKeyAdvertisementApplied = clientmetric.NewCounter("magicsock_tsmp_disco_key_advertisement_applied") metricTSMPDiscoKeyAdvertisementUnchanged = clientmetric.NewCounter("magicsock_tsmp_disco_key_advertisement_unchanged") + + // Counters for peer contacts established using cached network map data. + metricCachedPeerContactDERP = clientmetric.NewCounter("magicsock_cached_peer_contact_derp") + metricCachedPeerContactDirect = clientmetric.NewCounter("magicsock_cached_peer_contact_direct") ) // newUDPLifetimeCounter returns a new *clientmetric.Metric with the provided @@ -4166,16 +4284,10 @@ var _ conn.Endpoint = (*lazyEndpoint)(nil) // InitiationMessagePublicKey implements [conn.InitiationAwareEndpoint]. // wireguard-go calls us here if we passed it a [*lazyEndpoint] for an -// initiation message, for which it might not have the relevant peer configured, -// enabling us to just-in-time configure it and note its activity via -// [*endpoint.noteRecvActivity], before it performs peer lookup and attempts -// decryption. +// initiation message, for which it might not have the relevant peer configured. +// Wireguard-go's PeerLookupFunc handles on-demand peer creation. // -// Reception of all other WireGuard message types implies pre-existing knowledge -// of the peer by wireguard-go for it to do useful work. See -// [userspaceEngine.maybeReconfigWireguardLocked] & -// [userspaceEngine.noteRecvActivity] for more details around just-in-time -// wireguard-go peer (de)configuration. +// We still update endpoint activity tracking for bestAddr management. func (le *lazyEndpoint) InitiationMessagePublicKey(peerPublicKey [32]byte) { pubKey := key.NodePublicFromRaw32(mem.B(peerPublicKey[:])) if le.maybeEP != nil && pubKey.Compare(le.maybeEP.publicKey) == 0 { @@ -4183,9 +4295,6 @@ func (le *lazyEndpoint) InitiationMessagePublicKey(peerPublicKey [32]byte) { } le.c.mu.Lock() ep, ok := le.c.peerMap.endpointForNodeKey(pubKey) - // [Conn.mu] must not be held while [Conn.noteRecvActivity] is called, which - // [endpoint.noteRecvActivity] can end up calling. See - // [Options.NoteRecvActivity] docs. le.c.mu.Unlock() if !ok { return @@ -4193,11 +4302,6 @@ func (le *lazyEndpoint) InitiationMessagePublicKey(peerPublicKey [32]byte) { now := mono.Now() ep.lastRecvUDPAny.StoreAtomic(now) ep.noteRecvActivity(le.src, now) - // [ep.noteRecvActivity] may end up JIT configuring the peer, but we don't - // update [peerMap] as wireguard-go hasn't decrypted the initiation - // message yet. wireguard-go will call us below in [lazyEndpoint.FromPeer] - // if it successfully decrypts the message, at which point it's safe to - // insert le.src into the [peerMap] for ep. } func (le *lazyEndpoint) ClearSrc() {} @@ -4328,11 +4432,8 @@ type NewDiscoKeyAvailable struct { // maybeSendTSMPDiscoAdvert conditionally emits an event indicating that we // should send our DiscoKey to the first node address of the magicksock endpoint. -// The event is only emitted if we have not yet contacted that endpoint since -// the DiscoKey changed. -// -// This condition is most likely met only once per endpoint, after the start of -// tailscaled, but not until we contact the endpoint for the first time. +// The event is only emitted if we are not already communicating directly and +// more than 60 seconds has passed since the last DiscoKey was sent. // // We do not need the Conn to be locked, but the endpoint should be. func (c *Conn) maybeSendTSMPDiscoAdvert(de *endpoint) { @@ -4342,11 +4443,16 @@ func (c *Conn) maybeSendTSMPDiscoAdvert(de *endpoint) { de.mu.Lock() defer de.mu.Unlock() - if mono.Now().Sub(de.lastDiscoKeyAdvertisement) > discoKeyAdvertisementInterval { - de.lastDiscoKeyAdvertisement = mono.Now() - c.tsmpDiscoKeyAvailablePub.Publish(NewDiscoKeyAvailable{ - NodeFirstAddr: de.nodeAddr, - NodeID: de.nodeID, - }) + + now := mono.Now() + if now.Sub(de.lastDiscoKeyAdvertisement) <= discoKeyAdvertisementInterval || + (!de.lastDiscoKeyAdvertisement.IsZero() && de.bestAddr.isDirect()) { + return } + + de.lastDiscoKeyAdvertisement = now + c.tsmpDiscoKeyAvailablePub.Publish(NewDiscoKeyAvailable{ + NodeFirstAddr: de.nodeAddr, + NodeID: de.nodeID, + }) } diff --git a/wgengine/magicsock/magicsock_test.go b/wgengine/magicsock/magicsock_test.go index b7492b867..b3c21cb24 100644 --- a/wgengine/magicsock/magicsock_test.go +++ b/wgengine/magicsock/magicsock_test.go @@ -39,7 +39,6 @@ import ( "go4.org/mem" "golang.org/x/net/icmp" "golang.org/x/net/ipv4" - "tailscale.com/cmd/testwrapper/flakytest" "tailscale.com/control/controlknobs" "tailscale.com/derp/derpserver" "tailscale.com/disco" @@ -63,7 +62,6 @@ import ( "tailscale.com/types/netlogtype" "tailscale.com/types/netmap" "tailscale.com/types/nettype" - "tailscale.com/types/views" "tailscale.com/util/cibuild" "tailscale.com/util/clientmetric" "tailscale.com/util/eventbus" @@ -244,6 +242,25 @@ func newMagicStackWithKey(t testing.TB, logf logger.Logf, ln nettype.PacketListe func (s *magicStack) Reconfig(cfg *wgcfg.Config) error { s.tsTun.SetWGConfig(cfg) s.wgLogger.SetPeers(cfg.Peers) + + // In production, LocalBackend installs a PeerByIPPacketFunc via + // Engine.SetPeerByIPPacketFunc. Tests that bypass LocalBackend need + // to install one here for outbound packet routing. + ipToPeer := make(map[netip.Addr]device.NoisePublicKey, len(cfg.Peers)) + for _, p := range cfg.Peers { + pk := p.PublicKey.Raw32() + for _, pfx := range p.AllowedIPs { + if pfx.IsSingleIP() { + ipToPeer[pfx.Addr()] = pk + } + } + } + s.dev.SetPeerByIPPacketFunc(func(_, dst netip.Addr, _ []byte) (device.NoisePublicKey, bool) { + pk, ok := ipToPeer[dst] + return pk, ok + }) + + s.dev.SetPrivateKey(key.NodePrivateAs[device.NoisePrivateKey](cfg.PrivateKey)) return wgcfg.ReconfigDevice(s.dev, cfg, s.conn.logf) } @@ -414,9 +431,11 @@ func TestNewConn(t *testing.T) { stunAddr, stunCleanupFn := stuntest.Serve(t) defer stunCleanupFn() - port := pickPort(t) + // Use port 0 to let the system assign a port, avoiding TOCTOU races + // from the previous pickPort approach which would close a socket and + // hope to rebind to the same port. conn, err := NewConn(Options{ - Port: port, + Port: 0, DisablePortMapper: true, EndpointsFunc: epFunc, Logf: t.Logf, @@ -428,6 +447,13 @@ func TestNewConn(t *testing.T) { t.Fatal(err) } defer conn.Close() + + // Get the actual port that was assigned + port := conn.LocalPort() + if port == 0 { + t.Fatal("LocalPort returned 0") + } + conn.SetDERPMap(stuntest.DERPMapOf(stunAddr.String())) conn.SetPrivateKey(key.NewNode()) @@ -463,16 +489,6 @@ collectEndpoints: } } -func pickPort(t testing.TB) uint16 { - t.Helper() - conn, err := net.ListenPacket("udp4", "127.0.0.1:0") - if err != nil { - t.Fatal(err) - } - defer conn.Close() - return uint16(conn.LocalAddr().(*net.UDPAddr).Port) -} - func TestPickDERPFallback(t *testing.T) { tstest.PanicOnLog() tstest.ResourceCheck(t) @@ -733,7 +749,6 @@ func (localhostListener) ListenPacket(ctx context.Context, network, address stri } func TestTwoDevicePing(t *testing.T) { - flakytest.Mark(t, "https://github.com/tailscale/tailscale/issues/11762") ln, ip := localhostListener{}, netaddr.IPv4(127, 0, 0, 1) n := &devices{ m1: ln, @@ -1265,96 +1280,133 @@ func testTwoDevicePing(t *testing.T, d *devices) { t.Run("compare-metrics-stats", func(t *testing.T) { setT(t) defer setT(outerT) - m1.counts.Reset() - m2.counts.Reset() - m1.conn.resetMetricsForTest() - m2.conn.resetMetricsForTest() - t.Logf("Metrics before: %s\n", m1.metrics.String()) + + // Snapshot both counting systems before pings rather than + // resetting them. Resetting two independent systems + // non-atomically left a window where background WireGuard + // keepalives could increment one system but not the other, + // causing flaky off-by-one mismatches. + physBefore1, metricBefore1 := snapshotCounts(m1) + physBefore2, metricBefore2 := snapshotCounts(m2) + ping1(t) ping2(t) - assertConnStatsAndUserMetricsEqual(t, m1) - assertConnStatsAndUserMetricsEqual(t, m2) + + assertConnStatDeltasMatchMetricDeltas(t, m1, physBefore1, metricBefore1) + assertConnStatDeltasMatchMetricDeltas(t, m2, physBefore2, metricBefore2) assertGlobalMetricsMatchPerConn(t, m1, m2) - t.Logf("Metrics after: %s\n", m1.metrics.String()) }) } -func (c *Conn) resetMetricsForTest() { - c.metrics.inboundBytesIPv4Total.Set(0) - c.metrics.inboundPacketsIPv4Total.Set(0) - c.metrics.outboundBytesIPv4Total.Set(0) - c.metrics.outboundPacketsIPv4Total.Set(0) - c.metrics.inboundBytesIPv6Total.Set(0) - c.metrics.inboundPacketsIPv6Total.Set(0) - c.metrics.outboundBytesIPv6Total.Set(0) - c.metrics.outboundPacketsIPv6Total.Set(0) - c.metrics.inboundBytesDERPTotal.Set(0) - c.metrics.inboundPacketsDERPTotal.Set(0) - c.metrics.outboundBytesDERPTotal.Set(0) - c.metrics.outboundPacketsDERPTotal.Set(0) +// countSnapshot holds a point-in-time snapshot of packet/byte statistics, +// categorized by transport type (IPv4 vs DERP). +type countSnapshot struct { + ipv4RxBytes, ipv4TxBytes int64 + ipv4RxPackets, ipv4TxPackets int64 + derpRxBytes, derpTxBytes int64 + derpRxPackets, derpTxPackets int64 } -func assertConnStatsAndUserMetricsEqual(t *testing.T, ms *magicStack) { - t.Helper() - physIPv4RxBytes := int64(0) - physIPv4TxBytes := int64(0) - physDERPRxBytes := int64(0) - physDERPTxBytes := int64(0) - physIPv4RxPackets := int64(0) - physIPv4TxPackets := int64(0) - physDERPRxPackets := int64(0) - physDERPTxPackets := int64(0) +// snapshotCounts captures the current physical connection counter values and +// user metrics for ms, returning them as separate snapshots. Reading both +// systems back-to-back (rather than resetting them non-atomically) avoids a +// race where background WireGuard keepalives could increment one system but +// not the other during a reset window. +func snapshotCounts(ms *magicStack) (phys, metric countSnapshot) { for conn, count := range ms.counts.Clone() { - t.Logf("physconn src: %s, dst: %s", conn.Src.String(), conn.Dst.String()) if conn.Dst.String() == "127.3.3.40:1" { - physDERPRxBytes += int64(count.RxBytes) - physDERPTxBytes += int64(count.TxBytes) - physDERPRxPackets += int64(count.RxPackets) - physDERPTxPackets += int64(count.TxPackets) + phys.derpRxBytes += int64(count.RxBytes) + phys.derpTxBytes += int64(count.TxBytes) + phys.derpRxPackets += int64(count.RxPackets) + phys.derpTxPackets += int64(count.TxPackets) } else { - physIPv4RxBytes += int64(count.RxBytes) - physIPv4TxBytes += int64(count.TxBytes) - physIPv4RxPackets += int64(count.RxPackets) - physIPv4TxPackets += int64(count.TxPackets) + phys.ipv4RxBytes += int64(count.RxBytes) + phys.ipv4TxBytes += int64(count.TxBytes) + phys.ipv4RxPackets += int64(count.RxPackets) + phys.ipv4TxPackets += int64(count.TxPackets) } } - - metricIPv4RxBytes := ms.conn.metrics.inboundBytesIPv4Total.Value() - metricIPv4RxPackets := ms.conn.metrics.inboundPacketsIPv4Total.Value() - metricIPv4TxBytes := ms.conn.metrics.outboundBytesIPv4Total.Value() - metricIPv4TxPackets := ms.conn.metrics.outboundPacketsIPv4Total.Value() - - metricDERPRxBytes := ms.conn.metrics.inboundBytesDERPTotal.Value() - metricDERPRxPackets := ms.conn.metrics.inboundPacketsDERPTotal.Value() - metricDERPTxBytes := ms.conn.metrics.outboundBytesDERPTotal.Value() - metricDERPTxPackets := ms.conn.metrics.outboundPacketsDERPTotal.Value() - - // Reset counts after reading all values to minimize the window where a - // background packet could increment metrics but miss the cloned counts. - ms.counts.Reset() - - // Compare physical connection stats with per-conn user metrics. - // A rebind during the measurement window can reset the physical connection - // counter, causing physical stats to show 0 while user metrics recorded - // packets normally. Tolerate this by logging instead of failing. - checkPhysVsMetric := func(phys, metric int64, name string) { - if phys == metric { - return - } - if phys == 0 && metric > 0 { - t.Logf("%s: physical counter is 0 but metric is %d (possible rebind during measurement)", name, metric) - return - } - t.Errorf("%s: physical=%d, metric=%d", name, phys, metric) + metric = countSnapshot{ + ipv4RxBytes: ms.conn.metrics.inboundBytesIPv4Total.Value(), + ipv4TxBytes: ms.conn.metrics.outboundBytesIPv4Total.Value(), + ipv4RxPackets: ms.conn.metrics.inboundPacketsIPv4Total.Value(), + ipv4TxPackets: ms.conn.metrics.outboundPacketsIPv4Total.Value(), + derpRxBytes: ms.conn.metrics.inboundBytesDERPTotal.Value(), + derpTxBytes: ms.conn.metrics.outboundBytesDERPTotal.Value(), + derpRxPackets: ms.conn.metrics.inboundPacketsDERPTotal.Value(), + derpTxPackets: ms.conn.metrics.outboundPacketsDERPTotal.Value(), + } + return phys, metric +} + +// assertConnStatDeltasMatchMetricDeltas checks that the changes in physical +// connection counters since physBefore match the changes in user metrics since +// metricBefore. Using deltas avoids a race from non-atomically resetting the +// two independent counting systems. +// +// As a safety net, a difference of exactly one packet (and the corresponding +// bytes) is tolerated, since a background WireGuard keepalive could still +// arrive in the narrow window between snapshotting the two systems. +func assertConnStatDeltasMatchMetricDeltas(t *testing.T, ms *magicStack, physBefore, metricBefore countSnapshot) { + t.Helper() + physAfter, metricAfter := snapshotCounts(ms) + + type stat struct { + name string + physDelta, metDelta int64 + isPackets bool // true for packet counts, false for byte counts + packetDeltaTolerated bool // set by packet check, used by byte check + } + + stats := []stat{ + {name: "IPv4RxPackets", physDelta: physAfter.ipv4RxPackets - physBefore.ipv4RxPackets, metDelta: metricAfter.ipv4RxPackets - metricBefore.ipv4RxPackets, isPackets: true}, + {name: "IPv4RxBytes", physDelta: physAfter.ipv4RxBytes - physBefore.ipv4RxBytes, metDelta: metricAfter.ipv4RxBytes - metricBefore.ipv4RxBytes}, + {name: "IPv4TxPackets", physDelta: physAfter.ipv4TxPackets - physBefore.ipv4TxPackets, metDelta: metricAfter.ipv4TxPackets - metricBefore.ipv4TxPackets, isPackets: true}, + {name: "IPv4TxBytes", physDelta: physAfter.ipv4TxBytes - physBefore.ipv4TxBytes, metDelta: metricAfter.ipv4TxBytes - metricBefore.ipv4TxBytes}, + {name: "DERPRxPackets", physDelta: physAfter.derpRxPackets - physBefore.derpRxPackets, metDelta: metricAfter.derpRxPackets - metricBefore.derpRxPackets, isPackets: true}, + {name: "DERPRxBytes", physDelta: physAfter.derpRxBytes - physBefore.derpRxBytes, metDelta: metricAfter.derpRxBytes - metricBefore.derpRxBytes}, + {name: "DERPTxPackets", physDelta: physAfter.derpTxPackets - physBefore.derpTxPackets, metDelta: metricAfter.derpTxPackets - metricBefore.derpTxPackets, isPackets: true}, + {name: "DERPTxBytes", physDelta: physAfter.derpTxBytes - physBefore.derpTxBytes, metDelta: metricAfter.derpTxBytes - metricBefore.derpTxBytes}, + } + + // First pass: check packet counts, tolerating ±1 from stray keepalives. + for i := range stats { + s := &stats[i] + if !s.isPackets { + continue + } + if s.physDelta == s.metDelta { + continue + } + diff := s.physDelta - s.metDelta + if diff < 0 { + diff = -diff + } + if diff <= 1 { + s.packetDeltaTolerated = true + t.Logf("%s: physical delta=%d, metric delta=%d (off by 1, likely background WireGuard keepalive)", s.name, s.physDelta, s.metDelta) + continue + } + t.Errorf("%s: physical delta=%d, metric delta=%d", s.name, s.physDelta, s.metDelta) + } + + // Second pass: check byte counts; tolerate mismatches when the + // corresponding packet count was already tolerated. + for i := range stats { + s := &stats[i] + if s.isPackets { + continue + } + if s.physDelta == s.metDelta { + continue + } + // The preceding entry in the slice is always the corresponding packet stat. + if stats[i-1].packetDeltaTolerated { + t.Logf("%s: physical delta=%d, metric delta=%d (within single-packet tolerance)", s.name, s.physDelta, s.metDelta) + continue + } + t.Errorf("%s: physical delta=%d, metric delta=%d", s.name, s.physDelta, s.metDelta) } - checkPhysVsMetric(physDERPRxBytes, metricDERPRxBytes, "DERPRxBytes") - checkPhysVsMetric(physDERPTxBytes, metricDERPTxBytes, "DERPTxBytes") - checkPhysVsMetric(physIPv4RxBytes, metricIPv4RxBytes, "IPv4RxBytes") - checkPhysVsMetric(physIPv4TxBytes, metricIPv4TxBytes, "IPv4TxBytes") - checkPhysVsMetric(physDERPRxPackets, metricDERPRxPackets, "DERPRxPackets") - checkPhysVsMetric(physDERPTxPackets, metricDERPTxPackets, "DERPTxPackets") - checkPhysVsMetric(physIPv4RxPackets, metricIPv4RxPackets, "IPv4RxPackets") - checkPhysVsMetric(physIPv4TxPackets, metricIPv4TxPackets, "IPv4TxPackets") } // assertGlobalMetricsMatchPerConn validates that the global clientmetric @@ -1408,34 +1460,39 @@ func TestDiscoStringLogRace(t *testing.T) { wg.Wait() } +// Test32bitAlignment verifies that that the 64-bit atomic mono.Time fields are +// 64-bit aligned, so that StoreAtomic and LoadAtomic won't panic on 32-bit +// platforms. +// +// For normal Go atomic types (sync/atomic.Int64, etc), the Go compiler +// guarantees 64-bit alignment on 32-bit platforms with an unexported magic +// embedded struct field. We can't make mono.Time use that easily. We could change +// mono.Time to be type Time struct { atomic.Int64 }, but that's pretty invasive. +// Instead, we just have this test to keep us safe on 32-bit platforms. func Test32bitAlignment(t *testing.T) { - // Need an associated conn with non-nil noteRecvActivity to - // trigger interesting work on the atomics in endpoint. - called := 0 + if rt := reflect.TypeFor[mono.Time](); rt.Kind() != reflect.Int64 { + t.Fatalf("mono.Time is not a 64-bit integer type anymore; this test may be irrelevant now or out of date") + } + de := endpoint{ - c: &Conn{ - noteRecvActivity: func(key.NodePublic) { called++ }, - }, + c: &Conn{}, } if off := unsafe.Offsetof(de.lastRecvWG); off%8 != 0 { t.Fatalf("endpoint.lastRecvWG is not 8-byte aligned") } + if off := unsafe.Offsetof(de.lastRecvUDPAny); off%8 != 0 { + t.Fatalf("endpoint.lastRecvUDPAny is not 8-byte aligned") + } - de.noteRecvActivity(epAddr{}, mono.Now()) // verify this doesn't panic on 32-bit - if called != 1 { - t.Fatal("expected call to noteRecvActivity") - } - de.noteRecvActivity(epAddr{}, mono.Now()) - if called != 1 { - t.Error("expected no second call to noteRecvActivity") - } + // Verify these don't panic. + de.lastRecvWG.StoreAtomic(mono.Now()) + de.lastRecvUDPAny.StoreAtomic(mono.Now()) } // newTestConn returns a new Conn. func newTestConn(t testing.TB) *Conn { t.Helper() - port := pickPort(t) bus := eventbustest.NewBus(t) @@ -1452,7 +1509,7 @@ func newTestConn(t testing.TB) *Conn { Metrics: new(usermetric.Registry), DisablePortMapper: true, Logf: t.Logf, - Port: port, + Port: 0, TestOnlyPacketListener: localhostListener{}, EndpointsFunc: func(eps []tailcfg.Endpoint) { t.Logf("endpoints: %q", eps) @@ -3048,6 +3105,7 @@ func TestMaybeSetNearestDERP(t *testing.T) { old int reportDERP int connectedToControl bool + force bool want int }{ { @@ -3071,6 +3129,22 @@ func TestMaybeSetNearestDERP(t *testing.T) { connectedToControl: false, // not connected... want: 21, // ... but want to change to new DERP }, + { + name: "force_not_connected_with_report_derp", + old: 1, + reportDERP: 21, + connectedToControl: false, + force: true, + want: 21, // force bypasses the no-change-without-control guard + }, + { + name: "force_not_connected_no_derp_no_current", + old: 0, + reportDERP: 0, + connectedToControl: false, + force: true, + want: 31, // force + no report DERP → deterministic fallback + }, { name: "not_connected_with_fallback_and_no_current", old: 0, // no current DERP @@ -3095,8 +3169,13 @@ func TestMaybeSetNearestDERP(t *testing.T) { } for _, tt := range testCases { t.Run(tt.name, func(t *testing.T) { - ht := health.NewTracker(eventbustest.NewBus(t)) + bus := eventbustest.NewBus(t) + ht := health.NewTracker(bus) c := newConn(t.Logf) + ec := bus.Client("magicsock.Conn.Test") + c.eventClient = ec + c.homeDERPChangedPub = eventbus.Publish[HomeDERPChanged](ec) + c.eventBus = bus c.myDerp = tt.old c.derpMap = derpMap c.health = ht @@ -3114,7 +3193,7 @@ func TestMaybeSetNearestDERP(t *testing.T) { } } - got := c.maybeSetNearestDERP(report) + got := c.maybeSetNearestDERP(report, tt.force) if got != tt.want { t.Errorf("got new DERP region %d, want %d", got, tt.want) } @@ -3814,7 +3893,7 @@ func TestConn_SetNetworkMap_updateRelayServersSet(t *testing.T) { c.filt = tt.filt if len(tt.wantRelayServers) == 0 { // So we can verify it gets flipped back. - c.hasPeerRelayServers.Store(true) + c.relayManager.hasPeerRelayServers.Store(true) } c.SetNetworkMap(tt.self, tt.peers) @@ -3822,8 +3901,8 @@ func TestConn_SetNetworkMap_updateRelayServersSet(t *testing.T) { if !got.Equal(tt.wantRelayServers) { t.Fatalf("got: %v != want: %v", got, tt.wantRelayServers) } - if len(tt.wantRelayServers) > 0 != c.hasPeerRelayServers.Load() { - t.Fatalf("c.hasPeerRelayServers: %v != len(tt.wantRelayServers) > 0: %v", c.hasPeerRelayServers.Load(), len(tt.wantRelayServers) > 0) + if got, want := c.relayManager.hasPeerRelayServers.Load(), len(tt.wantRelayServers) > 0; got != want { + t.Fatalf("c.relayManager.hasPeerRelayServers: %v != len(tt.wantRelayServers) > 0: %v", got, want) } if c.relayClientEnabled != tt.wantRelayClientEnabled { t.Fatalf("c.relayClientEnabled: %v != wantRelayClientEnabled: %v", c.relayClientEnabled, tt.wantRelayClientEnabled) @@ -3903,60 +3982,55 @@ func TestConn_receiveIP(t *testing.T) { // If [*endpoint] then we expect 'got' to be the same [*endpoint]. If // [*lazyEndpoint] and [*lazyEndpoint.maybeEP] is non-nil, we expect // got.maybeEP to also be non-nil. Must not be reused across tests. - wantEndpointType wgconn.Endpoint - wantSize int - wantIsGeneveEncap bool - wantOk bool - wantMetricInc *clientmetric.Metric - wantNoteRecvActivityCalled bool + wantEndpointType wgconn.Endpoint + wantSize int + wantIsGeneveEncap bool + wantOk bool + wantMetricInc *clientmetric.Metric }{ { - name: "naked-disco", - b: looksLikeNakedDisco, - ipp: netip.MustParseAddrPort("127.0.0.1:7777"), - cache: &epAddrEndpointCache{}, - wantEndpointType: nil, - wantSize: 0, - wantIsGeneveEncap: false, - wantOk: false, - wantMetricInc: metricRecvDiscoBadPeer, - wantNoteRecvActivityCalled: false, + name: "naked-disco", + b: looksLikeNakedDisco, + ipp: netip.MustParseAddrPort("127.0.0.1:7777"), + cache: &epAddrEndpointCache{}, + wantEndpointType: nil, + wantSize: 0, + wantIsGeneveEncap: false, + wantOk: false, + wantMetricInc: metricRecvDiscoBadPeer, }, { - name: "geneve-encap-disco", - b: looksLikeGeneveDisco, - ipp: netip.MustParseAddrPort("127.0.0.1:7777"), - cache: &epAddrEndpointCache{}, - wantEndpointType: nil, - wantSize: 0, - wantIsGeneveEncap: false, - wantOk: false, - wantMetricInc: metricRecvDiscoBadPeer, - wantNoteRecvActivityCalled: false, + name: "geneve-encap-disco", + b: looksLikeGeneveDisco, + ipp: netip.MustParseAddrPort("127.0.0.1:7777"), + cache: &epAddrEndpointCache{}, + wantEndpointType: nil, + wantSize: 0, + wantIsGeneveEncap: false, + wantOk: false, + wantMetricInc: metricRecvDiscoBadPeer, }, { - name: "STUN-binding", - b: looksLikeSTUNBinding, - ipp: netip.MustParseAddrPort("127.0.0.1:7777"), - cache: &epAddrEndpointCache{}, - wantEndpointType: nil, - wantSize: 0, - wantIsGeneveEncap: false, - wantOk: false, - wantMetricInc: findMetricByName("netcheck_stun_recv_ipv4"), - wantNoteRecvActivityCalled: false, + name: "STUN-binding", + b: looksLikeSTUNBinding, + ipp: netip.MustParseAddrPort("127.0.0.1:7777"), + cache: &epAddrEndpointCache{}, + wantEndpointType: nil, + wantSize: 0, + wantIsGeneveEncap: false, + wantOk: false, + wantMetricInc: findMetricByName("netcheck_stun_recv_ipv4"), }, { - name: "naked-WireGuard-init-lazyEndpoint-empty-peerMap", - b: looksLikeNakedWireGuardInit, - ipp: netip.MustParseAddrPort("127.0.0.1:7777"), - cache: &epAddrEndpointCache{}, - wantEndpointType: &lazyEndpoint{}, - wantSize: len(looksLikeNakedWireGuardInit), - wantIsGeneveEncap: false, - wantOk: true, - wantMetricInc: nil, - wantNoteRecvActivityCalled: false, + name: "naked-WireGuard-init-lazyEndpoint-empty-peerMap", + b: looksLikeNakedWireGuardInit, + ipp: netip.MustParseAddrPort("127.0.0.1:7777"), + cache: &epAddrEndpointCache{}, + wantEndpointType: &lazyEndpoint{}, + wantSize: len(looksLikeNakedWireGuardInit), + wantIsGeneveEncap: false, + wantOk: true, + wantMetricInc: nil, }, { name: "naked-WireGuard-init-endpoint-matching-peerMap-entry", @@ -3970,19 +4044,17 @@ func TestConn_receiveIP(t *testing.T) { wantIsGeneveEncap: false, wantOk: true, wantMetricInc: nil, - wantNoteRecvActivityCalled: true, }, { - name: "geneve-WireGuard-init-lazyEndpoint-empty-peerMap", - b: looksLikeGeneveWireGuardInit, - ipp: netip.MustParseAddrPort("127.0.0.1:7777"), - cache: &epAddrEndpointCache{}, - wantEndpointType: &lazyEndpoint{}, - wantSize: len(looksLikeGeneveWireGuardInit) - packet.GeneveFixedHeaderLength, - wantIsGeneveEncap: true, - wantOk: true, - wantMetricInc: nil, - wantNoteRecvActivityCalled: false, + name: "geneve-WireGuard-init-lazyEndpoint-empty-peerMap", + b: looksLikeGeneveWireGuardInit, + ipp: netip.MustParseAddrPort("127.0.0.1:7777"), + cache: &epAddrEndpointCache{}, + wantEndpointType: &lazyEndpoint{}, + wantSize: len(looksLikeGeneveWireGuardInit) - packet.GeneveFixedHeaderLength, + wantIsGeneveEncap: true, + wantOk: true, + wantMetricInc: nil, }, { name: "geneve-WireGuard-init-lazyEndpoint-matching-peerMap-activity-noted", @@ -3994,11 +4066,10 @@ func TestConn_receiveIP(t *testing.T) { wantEndpointType: &lazyEndpoint{ maybeEP: newPeerMapInsertableEndpoint(0), }, - wantSize: len(looksLikeGeneveWireGuardInit) - packet.GeneveFixedHeaderLength, - wantIsGeneveEncap: true, - wantOk: true, - wantMetricInc: nil, - wantNoteRecvActivityCalled: true, + wantSize: len(looksLikeGeneveWireGuardInit) - packet.GeneveFixedHeaderLength, + wantIsGeneveEncap: true, + wantOk: true, + wantMetricInc: nil, }, { name: "geneve-WireGuard-init-lazyEndpoint-matching-peerMap-no-activity-noted", @@ -4010,17 +4081,15 @@ func TestConn_receiveIP(t *testing.T) { wantEndpointType: &lazyEndpoint{ maybeEP: newPeerMapInsertableEndpoint(mono.Now().Add(time.Hour * 24)), }, - wantSize: len(looksLikeGeneveWireGuardInit) - packet.GeneveFixedHeaderLength, - wantIsGeneveEncap: true, - wantOk: true, - wantMetricInc: nil, - wantNoteRecvActivityCalled: false, + wantSize: len(looksLikeGeneveWireGuardInit) - packet.GeneveFixedHeaderLength, + wantIsGeneveEncap: true, + wantOk: true, + wantMetricInc: nil, }, // TODO(jwhited): verify cache.de is used when conditions permit } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - noteRecvActivityCalled := false metricBefore := int64(0) if tt.wantMetricInc != nil { metricBefore = tt.wantMetricInc.Value() @@ -4033,9 +4102,6 @@ func TestConn_receiveIP(t *testing.T) { peerMap: newPeerMap(), } c.havePrivateKey.Store(true) - c.noteRecvActivity = func(public key.NodePublic) { - noteRecvActivityCalled = true - } var counts netlogtype.CountsByConnection c.SetConnectionCounter(counts.Add) @@ -4090,10 +4156,6 @@ func TestConn_receiveIP(t *testing.T) { if tt.wantMetricInc != nil && tt.wantMetricInc.Value() != metricBefore+1 { t.Errorf("receiveIP() metric %v not incremented", tt.wantMetricInc.Name()) } - if tt.wantNoteRecvActivityCalled != noteRecvActivityCalled { - t.Errorf("receiveIP() noteRecvActivityCalled = %v, want %v", noteRecvActivityCalled, tt.wantNoteRecvActivityCalled) - } - if tt.cache.de != nil { switch ep := got.(type) { case *endpoint: @@ -4145,34 +4207,29 @@ func TestConn_receiveIP(t *testing.T) { func Test_lazyEndpoint_InitiationMessagePublicKey(t *testing.T) { tests := []struct { - name string - callWithPeerMapKey bool - maybeEPMatchingKey bool - wantNoteRecvActivityCalled bool + name string + callWithPeerMapKey bool + maybeEPMatchingKey bool }{ { - name: "noteRecvActivity-called", - callWithPeerMapKey: true, - maybeEPMatchingKey: false, - wantNoteRecvActivityCalled: true, + name: "noteRecvActivity-called", + callWithPeerMapKey: true, + maybeEPMatchingKey: false, }, { - name: "maybeEP-early-return", - callWithPeerMapKey: true, - maybeEPMatchingKey: true, - wantNoteRecvActivityCalled: false, + name: "maybeEP-early-return", + callWithPeerMapKey: true, + maybeEPMatchingKey: true, }, { - name: "not-in-peerMap-early-return", - callWithPeerMapKey: false, - maybeEPMatchingKey: false, - wantNoteRecvActivityCalled: false, + name: "not-in-peerMap-early-return", + callWithPeerMapKey: false, + maybeEPMatchingKey: false, }, { - name: "not-in-peerMap-maybeEP-early-return", - callWithPeerMapKey: false, - maybeEPMatchingKey: true, - wantNoteRecvActivityCalled: false, + name: "not-in-peerMap-maybeEP-early-return", + callWithPeerMapKey: false, + maybeEPMatchingKey: true, }, } for _, tt := range tests { @@ -4185,19 +4242,7 @@ func Test_lazyEndpoint_InitiationMessagePublicKey(t *testing.T) { key: key.NewDisco().Public(), }) - var noteRecvActivityCalledFor key.NodePublic conn := newConn(t.Logf) - conn.noteRecvActivity = func(public key.NodePublic) { - // wireguard-go will call into ParseEndpoint if the "real" - // noteRecvActivity ends up JIT configuring the peer. Mimic that - // to ensure there are no deadlocks around conn.mu. - // See tailscale/tailscale#16651 & http://go/corp#30836 - _, err := conn.ParseEndpoint(ep.publicKey.UntypedHexString()) - if err != nil { - t.Fatalf("ParseEndpoint() err: %v", err) - } - noteRecvActivityCalledFor = public - } ep.c = conn var pubKey [32]byte @@ -4213,13 +4258,6 @@ func Test_lazyEndpoint_InitiationMessagePublicKey(t *testing.T) { le.maybeEP = ep } le.InitiationMessagePublicKey(pubKey) - want := key.NodePublic{} - if tt.wantNoteRecvActivityCalled { - want = ep.publicKey - } - if noteRecvActivityCalledFor.Compare(want) != 0 { - t.Fatalf("noteRecvActivityCalledFor = %v, want %v", noteRecvActivityCalledFor, want) - } }) } } @@ -4389,7 +4427,7 @@ func TestReceiveTSMPDiscoKeyAdvertisement(t *testing.T) { netip.MustParsePrefix("100.64.0.1/32"), }, }).View() - conn.peers = views.SliceOf([]tailcfg.NodeView{nodeView}) + conn.peersByID = map[tailcfg.NodeID]tailcfg.NodeView{nodeView.ID(): nodeView} conn.mu.Unlock() conn.peerMap.upsertEndpoint(ep, key.DiscoPublic{}) @@ -4411,7 +4449,6 @@ func TestReceiveTSMPDiscoKeyAdvertisement(t *testing.T) { } func TestSendingTSMPDiscoTimer(t *testing.T) { - t.Setenv("TS_USE_CACHED_NETMAP", "1") conn := newTestConn(t) tw := eventbustest.NewWatcher(t, conn.eventBus) t.Cleanup(func() { conn.Close() }) @@ -4435,7 +4472,7 @@ func TestSendingTSMPDiscoTimer(t *testing.T) { netip.MustParsePrefix("100.64.0.1/32"), }, }).View() - conn.peers = views.SliceOf([]tailcfg.NodeView{nodeView}) + conn.peersByID = map[tailcfg.NodeID]tailcfg.NodeView{nodeView.ID(): nodeView} conn.mu.Unlock() conn.peerMap.upsertEndpoint(ep, key.DiscoPublic{}) @@ -4444,12 +4481,36 @@ func TestSendingTSMPDiscoTimer(t *testing.T) { t.Errorf("Original disco key %s, does not match %s", discoKey.ShortString(), ep.discoShort()) } + // Only one gets through, second is rate limited. conn.maybeSendTSMPDiscoAdvert(ep) conn.maybeSendTSMPDiscoAdvert(ep) - eventbustest.ExpectExactly(tw, eventbustest.Type[NewDiscoKeyAvailable]()) + if err := eventbustest.ExpectExactly(tw, eventbustest.Type[NewDiscoKeyAvailable]()); err != nil { + t.Errorf("expected only one event, got: %s", err) + } + + // Reset to get the event firing again. ep.mu.Lock() ep.lastDiscoKeyAdvertisement = 0 ep.mu.Unlock() conn.maybeSendTSMPDiscoAdvert(ep) - eventbustest.Expect(tw, eventbustest.Type[NewDiscoKeyAvailable]()) + if err := eventbustest.Expect(tw, eventbustest.Type[NewDiscoKeyAvailable]()); err != nil { + t.Errorf("expected only one event, got: %s", err) + } + + // With a direct bestAddr and a non-zero lastDiscoKeyAdvertisement past the + // rate-limit interval. No advert should be sent due to the active bestAddr. + ep.mu.Lock() + ep.lastDiscoKeyAdvertisement = mono.Now().Add(-discoKeyAdvertisementInterval - time.Second) + ep.bestAddr = addrQuality{epAddr: epAddr{ap: netip.MustParseAddrPort("1.2.3.4:567")}} + ep.mu.Unlock() + conn.maybeSendTSMPDiscoAdvert(ep) + + // Simulating restart should send an advert. + ep.mu.Lock() + ep.lastDiscoKeyAdvertisement = 0 + ep.mu.Unlock() + conn.maybeSendTSMPDiscoAdvert(ep) + if err := eventbustest.ExpectExactly(tw, eventbustest.Type[NewDiscoKeyAvailable]()); err != nil { + t.Errorf("expected only one event, got: %s", err) + } } diff --git a/wgengine/magicsock/relaymanager.go b/wgengine/magicsock/relaymanager.go index e4cd5eb9f..8ea15bce3 100644 --- a/wgengine/magicsock/relaymanager.go +++ b/wgengine/magicsock/relaymanager.go @@ -9,6 +9,7 @@ import ( "fmt" "net/netip" "sync" + "sync/atomic" "time" "tailscale.com/disco" @@ -34,6 +35,14 @@ import ( type relayManager struct { initOnce sync.Once + // hasPeerRelayServers is whether relayManager is configured with at + // least one peer relay server via [relayManager.handleRelayServersSet] + // (or per-peer variants). Exposed as an atomic so [endpoint] hot paths + // can short-circuit when there are no relay servers without taking any + // lock or entering the run loop. Written only from runLoop() via + // [relayManager.publishHasServersRunLoop]. + hasPeerRelayServers atomic.Bool + // =================================================================== // The following fields are owned by a single goroutine, runLoop(). serversByNodeKey map[key.NodePublic]candidatePeerRelay @@ -56,6 +65,8 @@ type relayManager struct { newServerEndpointCh chan newRelayServerEndpointEvent rxDiscoMsgCh chan relayDiscoMsgEvent serversCh chan set.Set[candidatePeerRelay] + serverUpsertCh chan candidatePeerRelay + serverRemoveCh chan key.NodePublic getServersCh chan chan set.Set[candidatePeerRelay] derpHomeChangeCh chan derpHomeChangeEvent @@ -228,6 +239,16 @@ func (r *relayManager) runLoop() { if !r.hasActiveWorkRunLoop() { return } + case upsert := <-r.serverUpsertCh: + r.handleServerUpsertRunLoop(upsert) + if !r.hasActiveWorkRunLoop() { + return + } + case nk := <-r.serverRemoveCh: + r.handleServerRemoveRunLoop(nk) + if !r.hasActiveWorkRunLoop() { + return + } case getServersCh := <-r.getServersCh: r.handleGetServersRunLoop(getServersCh) if !r.hasActiveWorkRunLoop() { @@ -265,6 +286,34 @@ func (r *relayManager) handleServersUpdateRunLoop(update set.Set[candidatePeerRe for _, v := range update.Slice() { r.serversByNodeKey[v.nodeKey] = v } + r.publishHasServersRunLoop() +} + +// handleServerUpsertRunLoop inserts or updates cp in serversByNodeKey. It is +// the per-peer analog of [relayManager.handleServersUpdateRunLoop] used by +// [Conn.UpsertPeer]. +func (r *relayManager) handleServerUpsertRunLoop(cp candidatePeerRelay) { + r.serversByNodeKey[cp.nodeKey] = cp + r.publishHasServersRunLoop() +} + +// handleServerRemoveRunLoop deletes nk from serversByNodeKey. It is a no-op +// if nk isn't currently a known server. It is the per-peer analog of +// [relayManager.handleServersUpdateRunLoop] used by [Conn.RemovePeer] and by +// [Conn.UpsertPeer] when a peer is upserted with fields that make it no +// longer a relay candidate. +func (r *relayManager) handleServerRemoveRunLoop(nk key.NodePublic) { + if _, ok := r.serversByNodeKey[nk]; !ok { + return + } + delete(r.serversByNodeKey, nk) + r.publishHasServersRunLoop() +} + +// publishHasServersRunLoop updates [relayManager.hasPeerRelayServers] to +// reflect whether any relay servers are currently known. +func (r *relayManager) publishHasServersRunLoop() { + r.hasPeerRelayServers.Store(len(r.serversByNodeKey) > 0) } type relayDiscoMsgEvent struct { @@ -330,6 +379,8 @@ func (r *relayManager) init() { r.newServerEndpointCh = make(chan newRelayServerEndpointEvent) r.rxDiscoMsgCh = make(chan relayDiscoMsgEvent) r.serversCh = make(chan set.Set[candidatePeerRelay]) + r.serverUpsertCh = make(chan candidatePeerRelay) + r.serverRemoveCh = make(chan key.NodePublic) r.getServersCh = make(chan chan set.Set[candidatePeerRelay]) r.derpHomeChangeCh = make(chan derpHomeChangeEvent) r.runLoopStoppedCh = make(chan struct{}, 1) @@ -436,6 +487,21 @@ func (r *relayManager) handleRelayServersSet(servers set.Set[candidatePeerRelay] relayManagerInputEvent(r, nil, &r.serversCh, servers) } +// handleRelayServerUpsert is the O(1) per-peer variant of +// [relayManager.handleRelayServersSet]: it inserts or updates a single +// relay server entry. +func (r *relayManager) handleRelayServerUpsert(cp candidatePeerRelay) { + relayManagerInputEvent(r, nil, &r.serverUpsertCh, cp) +} + +// handleRelayServerRemove is the O(1) per-peer variant of +// [relayManager.handleRelayServersSet]: it removes a single relay server +// entry by node key. It is a no-op if nk is not currently a known relay +// server. +func (r *relayManager) handleRelayServerRemove(nk key.NodePublic) { + relayManagerInputEvent(r, nil, &r.serverRemoveCh, nk) +} + // relayManagerInputEvent initializes [relayManager] if necessary, starts // relayManager.runLoop() if it is not running, and writes 'event' on 'eventCh'. // diff --git a/wgengine/netstack/netstack.go b/wgengine/netstack/netstack.go index 4da89e364..659e07924 100644 --- a/wgengine/netstack/netstack.go +++ b/wgengine/netstack/netstack.go @@ -64,10 +64,12 @@ import ( const debugPackets = false // If non-zero, these override the values returned from the corresponding -// functions, below. +// functions, below. They are accessed atomically because background +// goroutines in the gVisor TCP stack read them while test cleanup +// goroutines may be restoring them concurrently. var ( - maxInFlightConnectionAttemptsForTest int - maxInFlightConnectionAttemptsPerClientForTest int + maxInFlightConnectionAttemptsForTest atomic.Int32 + maxInFlightConnectionAttemptsPerClientForTest atomic.Int32 ) // maxInFlightConnectionAttempts returns the global number of in-flight @@ -80,8 +82,8 @@ var ( // connection, so we want to ensure that we don't allow an unbounded number of // connections. func maxInFlightConnectionAttempts() int { - if n := maxInFlightConnectionAttemptsForTest; n > 0 { - return n + if n := maxInFlightConnectionAttemptsForTest.Load(); n > 0 { + return int(n) } if version.IsMobile() { @@ -106,8 +108,8 @@ func maxInFlightConnectionAttempts() int { // maxInFlightConnectionAttempts, but applies on a per-client basis // (i.e. keyed by the remote Tailscale IP). func maxInFlightConnectionAttemptsPerClient() int { - if n := maxInFlightConnectionAttemptsPerClientForTest; n > 0 { - return n + if n := maxInFlightConnectionAttemptsPerClientForTest.Load(); n > 0 { + return int(n) } // For now, allow each individual client at most 2/3rds of the global @@ -214,6 +216,7 @@ type Impl struct { dialer *tsdial.Dialer ctx context.Context // alive until Close ctxCancel context.CancelFunc // called on Close + injectWG sync.WaitGroup // wait for the inject goroutine lb *ipnlocal.LocalBackend // or nil dns *dns.Manager @@ -448,6 +451,7 @@ func (ns *Impl) Close() error { ns.ctxCancel() ns.ipstack.Close() ns.ipstack.Wait() + ns.injectWG.Wait() return nil } @@ -642,7 +646,9 @@ func (ns *Impl) Start(b LocalBackend) error { udpFwd := udp.NewForwarder(ns.ipstack, ns.acceptUDPNoICMP) ns.ipstack.SetTransportProtocolHandler(tcp.ProtocolNumber, ns.wrapTCPProtocolHandler(tcpFwd.HandlePacket)) ns.ipstack.SetTransportProtocolHandler(udp.ProtocolNumber, ns.wrapUDPProtocolHandler(udpFwd.HandlePacket)) - go ns.inject() + ns.injectWG.Go(func() { + ns.inject() + }) if ns.ready.Swap(true) { panic("already started") } @@ -843,20 +849,27 @@ func (ns *Impl) handleLocalPackets(p *packet.Parsed, t *tstun.Wrapper, gro *gro. serviceName, isVIPServiceIP := ns.atomicIPVIPServiceMap.Load()[dst] switch { case dst == serviceIP || dst == serviceIPv6: - // We want to intercept some traffic to the "service IP" (e.g. - // 100.100.100.100 for IPv4). However, of traffic to the - // service IP, we only care about UDP 53, and TCP on port 53, - // 80, and 8080. - switch p.IPProto { - case ipproto.TCP: - if port := p.Dst.Port(); port != 53 && port != 80 && port != 8080 && !ns.isLoopbackPort(port) { - return filter.Accept, gro - } - case ipproto.UDP: - if port := p.Dst.Port(); port != 53 && !ns.isLoopbackPort(port) { - return filter.Accept, gro - } - } + // Traffic to the Tailscale service IP (100.100.100.100 / + // fd7a:115c:a1e0::53) is always terminated locally on this + // node; it must never be forwarded out over WireGuard to a + // peer. Netstack's TCP/UDP acceptors handle the ports we + // actually serve (UDP 53 MagicDNS, TCP 53/80/8080 for DNS, + // the web client, and Taildrive, plus any debug loopback + // port). Other ports are rejected cleanly by netstack: UDP + // closes the endpoint in acceptUDP, and TCP is RST'd by + // acceptTCP's hittingServiceIP guard. + // + // Previously we returned filter.Accept for TCP/UDP on any + // other port, which let the packet fall through to the ACL + // filter and ultimately wireguard-go, where no peer owns the + // quad-100 AllowedIP. That produced noisy "open-conn-track: + // timeout opening ...; no associated peer node" log lines + // (e.g. for stray traffic to 100.100.100.100:853 / DoT) and + // leaked quad-100 packets onto the tailnet. + // + // We now unconditionally absorb quad-100 into netstack here, + // regardless of IP protocol or port, so such traffic never + // reaches the conntrack / peer-routing layers. case isVIPServiceIP: // returns all active VIP services in a set, since the IPVIPServiceMap // contains inactive service IPs when node hosts the service, we need to @@ -1652,6 +1665,24 @@ func (ns *Impl) acceptTCP(r *tcp.ForwarderRequest) { } else { dialIP = ipv4Loopback } + case hittingServiceIP: + // TCP to the Tailscale service IP on a port we don't serve + // (anything other than DNS/53, web client/80, Taildrive/8080, + // or the debug loopback port handled above). handleLocalPackets + // absorbs all quad-100 traffic into netstack to prevent it + // from leaking to WireGuard peers as noisy "open-conn-track: + // timeout opening ...; no associated peer node" log lines + // (see the comment there). + // + // Without this explicit guard, execution would fall through + // to the isTailscaleIP case below (quad-100 is in the + // tailscale IP range), rewriting the dial target to + // 127.0.0.1: and forwardTCP'ing the connection onto + // whatever random service happens to be listening on the + // host's loopback at that port. Reject cleanly with a RST + // here instead. + r.Complete(true) // sends a RST + return case isTailscaleIP: dialIP = ipv4Loopback } diff --git a/wgengine/netstack/netstack_test.go b/wgengine/netstack/netstack_test.go index e588fa47c..7f248cd44 100644 --- a/wgengine/netstack/netstack_test.go +++ b/wgengine/netstack/netstack_test.go @@ -33,6 +33,7 @@ import ( "tailscale.com/types/ipproto" "tailscale.com/types/logid" "tailscale.com/types/netmap" + "tailscale.com/util/clientmetric" "tailscale.com/wgengine" "tailscale.com/wgengine/filter" ) @@ -736,7 +737,10 @@ func makeHangDialer(tb testing.TB) (netx.DialFunc, chan struct{}) { // TestTCPForwardLimits verifies that the limits on the TCP forwarder work in a // success case (i.e. when we don't hit the limit). func TestTCPForwardLimits(t *testing.T) { + tstest.AssertNotParallel(t) // calls envknob.Setenv envknob.Setenv("TS_DEBUG_NETSTACK", "true") + t.Cleanup(func() { envknob.Setenv("TS_DEBUG_NETSTACK", "") }) + impl := makeNetstack(t, func(impl *Impl) { impl.ProcessSubnets = true }) @@ -811,11 +815,16 @@ func TestTCPForwardLimits(t *testing.T) { // TestTCPForwardLimits_PerClient verifies that the per-client limit for TCP // forwarding works. func TestTCPForwardLimits_PerClient(t *testing.T) { + clientmetric.ResetForTest(t) + tstest.AssertNotParallel(t) // calls envknob.Setenv envknob.Setenv("TS_DEBUG_NETSTACK", "true") + t.Cleanup(func() { envknob.Setenv("TS_DEBUG_NETSTACK", "") }) // Set our test override limits during this test. - tstest.Replace(t, &maxInFlightConnectionAttemptsForTest, 2) - tstest.Replace(t, &maxInFlightConnectionAttemptsPerClientForTest, 1) + maxInFlightConnectionAttemptsForTest.Store(2) + t.Cleanup(func() { maxInFlightConnectionAttemptsForTest.Store(0) }) + maxInFlightConnectionAttemptsPerClientForTest.Store(1) + t.Cleanup(func() { maxInFlightConnectionAttemptsPerClientForTest.Store(0) }) impl := makeNetstack(t, func(impl *Impl) { impl.ProcessSubnets = true @@ -948,6 +957,7 @@ func TestHandleLocalPackets(t *testing.T) { impl.lb.SetIPServiceMappingsForTest(IPServiceMap) t.Run("ShouldHandleServiceIP", func(t *testing.T) { + t.Parallel() pkt := &packet.Parsed{ IPVersion: 4, IPProto: ipproto.TCP, @@ -960,7 +970,94 @@ func TestHandleLocalPackets(t *testing.T) { t.Errorf("got filter outcome %v, want filter.DropSilently", resp) } }) + // Any port on the quad-100 service IP must be absorbed locally by + // netstack and never leak out to the WireGuard / peer-routing + // layers. Historically we only intercepted specific ports (UDP 53 + // and TCP 53/80/8080), causing stray traffic to other ports such + // as 100.100.100.100:853 (DoT) to time out in wireguard-go and + // produce "open-conn-track: timeout opening ...; no associated + // peer node" log spam. See the handleLocalPackets comment. + quad100LeakCases := []struct { + name string + proto ipproto.Proto + dst string + }{ + {"TCP-853-DoT-v4", ipproto.TCP, "100.100.100.100:853"}, + {"TCP-443-DoH-v4", ipproto.TCP, "100.100.100.100:443"}, + {"TCP-9000-stray-v4", ipproto.TCP, "100.100.100.100:9000"}, + {"UDP-853-DoQ-v4", ipproto.UDP, "100.100.100.100:853"}, + {"UDP-443-v4", ipproto.UDP, "100.100.100.100:443"}, + {"TCP-853-DoT-v6", ipproto.TCP, "[fd7a:115c:a1e0::53]:853"}, + {"UDP-443-v6", ipproto.UDP, "[fd7a:115c:a1e0::53]:443"}, + } + for _, tc := range quad100LeakCases { + t.Run("ShouldNotLeakQuad100_"+tc.name, func(t *testing.T) { + t.Parallel() + dst := netip.MustParseAddrPort(tc.dst) + ipVersion := uint8(4) + if dst.Addr().Is6() { + ipVersion = 6 + } + src := "127.0.0.1:9999" + if ipVersion == 6 { + src = "[::1]:9999" + } + pkt := &packet.Parsed{ + IPVersion: ipVersion, + IPProto: tc.proto, + Src: netip.MustParseAddrPort(src), + Dst: dst, + } + if tc.proto == ipproto.TCP { + pkt.TCPFlags = packet.TCPSyn + } + resp, _ := impl.handleLocalPackets(pkt, impl.tundev, nil) + if resp != filter.DropSilently { + t.Errorf("quad-100 %s packet leaked: got filter outcome %v, want filter.DropSilently", tc.name, resp) + } + }) + } + // Exhaustive sweep of all ports for both transport protocols and + // both IP versions, confirming no port leaks. The quad-100 branch + // of handleLocalPackets is port-independent by construction; this + // test serves as a regression guard against accidental port-based + // exemptions slipping back in. + t.Run("ShouldNotLeakQuad100_AllPorts", func(t *testing.T) { + t.Parallel() + protos := []ipproto.Proto{ipproto.TCP, ipproto.UDP} + dsts := []netip.Addr{ + netip.MustParseAddr("100.100.100.100"), + netip.MustParseAddr("fd7a:115c:a1e0::53"), + } + for _, proto := range protos { + for _, dstAddr := range dsts { + ipVersion := uint8(4) + srcStr := "127.0.0.1:9999" + if dstAddr.Is6() { + ipVersion = 6 + srcStr = "[::1]:9999" + } + src := netip.MustParseAddrPort(srcStr) + for port := 1; port <= 65535; port++ { + pkt := &packet.Parsed{ + IPVersion: ipVersion, + IPProto: proto, + Src: src, + Dst: netip.AddrPortFrom(dstAddr, uint16(port)), + } + if proto == ipproto.TCP { + pkt.TCPFlags = packet.TCPSyn + } + resp, _ := impl.handleLocalPackets(pkt, impl.tundev, nil) + if resp != filter.DropSilently { + t.Fatalf("port=%d proto=%v dst=%v: got %v, want filter.DropSilently", port, proto, dstAddr, resp) + } + } + } + } + }) t.Run("ShouldHandle4via6", func(t *testing.T) { + t.Parallel() pkt := &packet.Parsed{ IPVersion: 6, IPProto: ipproto.TCP, @@ -983,6 +1080,7 @@ func TestHandleLocalPackets(t *testing.T) { } }) t.Run("ShouldHandleLocalTailscaleServices", func(t *testing.T) { + t.Parallel() pkt := &packet.Parsed{ IPVersion: 4, IPProto: ipproto.TCP, @@ -996,6 +1094,7 @@ func TestHandleLocalPackets(t *testing.T) { } }) t.Run("OtherNonHandled", func(t *testing.T) { + t.Parallel() pkt := &packet.Parsed{ IPVersion: 6, IPProto: ipproto.TCP, @@ -1018,6 +1117,100 @@ func TestHandleLocalPackets(t *testing.T) { }) } +// TestQuad100UnservedTCPPortDoesNotForward verifies that a TCP SYN to the +// Tailscale service IP (100.100.100.100) on a port we don't serve is +// absorbed by netstack and rejected cleanly, without triggering the +// outbound forwardTCP dialer. +// +// handleLocalPackets now absorbs all quad-100 traffic regardless of +// port to prevent it leaking to WireGuard peers (which produced noisy +// "open-conn-track: timeout opening ...; no associated peer node" log +// lines). That leaves acceptTCP responsible for rejecting connections +// to ports we don't handle; without an explicit guard, execution would +// fall through to the isTailscaleIP case (quad-100 is in the tailscale +// range), rewriting the dial target to 127.0.0.1: and forwarding +// the connection to whatever random service happened to be listening +// on the host's loopback at that port. +// +// This test asserts that the forward dialer is NOT invoked for quad-100 +// SYNs on unserved ports; the guard in acceptTCP must RST instead. +func TestQuad100UnservedTCPPortDoesNotForward(t *testing.T) { + impl := makeNetstack(t, func(impl *Impl) { + impl.ProcessSubnets = false + impl.ProcessLocalIPs = false + impl.atomicIsLocalIPFunc.Store(looksLikeATailscaleSelfAddress) + }) + + dialFn, gotConn := makeHangDialer(t) + impl.forwardDialFunc = dialFn + + // Use a client IP in the CGNAT range so shouldProcessInbound-adjacent + // code paths treat this as plausibly-peer-sourced traffic, matching + // what a real stray quad-100 probe from the host OS would look like. + client := netip.MustParseAddr("100.101.102.103") + quad100 := tsaddr.TailscaleServiceIP() + + // 853 is DoT, the specific case called out in the original bug + // report ("conntrack error no peer found for 100.100.100.100:853"). + // Before the fix, port 853 (and any non-{53,80,8080} port) leaked + // out to WireGuard; after the fix it is absorbed here and must NOT + // trigger forwardTCP. + pkt := tcp4syn(t, client, quad100, 1234, 853) + var parsed packet.Parsed + parsed.Decode(pkt) + + resp, _ := impl.handleLocalPackets(&parsed, impl.tundev, nil) + if resp != filter.DropSilently { + t.Fatalf("handleLocalPackets for quad-100:853: got %v, want filter.DropSilently", resp) + } + + // acceptTCP runs asynchronously in the gVisor TCP dispatcher after + // handleLocalPackets injects the packet into netstack. Use the + // in-flight connection counter as a deterministic synchronization + // point: wrapTCPProtocolHandler increments connsInFlightByClient + // when the dispatcher hands the connection off to acceptTCP, and + // acceptTCP's deferred decrementInFlightTCPForward decrements it + // on return. + // + // On the green path (RST guard fires), acceptTCP returns promptly + // and the counter reaches 0. On the red path (fall-through to + // forwardTCP), acceptTCP blocks inside the forwardDialFunc call — + // makeHangDialer signals gotConn on entry (buffered, non-blocking) + // and then blocks forever — so the counter never reaches 0 but + // gotConn fires synchronously from the dispatcher goroutine. A + // select on both races those outcomes without real-time padding. + // + // testing/synctest is not usable here: gVisor's sleep package calls + // the runtime's gopark directly rather than via the standard + // library, so synctest.Wait() cannot observe those goroutines + // becoming durably blocked and hangs indefinitely. + inFlightZero := make(chan struct{}) + go func() { + for { + impl.mu.Lock() + n := impl.connsInFlightByClient[client] + impl.mu.Unlock() + if n == 0 { + close(inFlightZero) + return + } + time.Sleep(time.Millisecond) + } + }() + + select { + case <-gotConn: + t.Fatalf("forwardDialFunc was called for quad-100:853; acceptTCP fell through to forwardTCP instead of sending RST. This means stray traffic to quad-100 on unserved ports is being redirected to the host's loopback at the same port.") + case <-inFlightZero: + // acceptTCP returned cleanly; the RST guard fired. + case <-time.After(5 * time.Second): + // Safety net so a regression in the in-flight counter plumbing + // doesn't hang the whole test run; both outcomes above should + // fire within milliseconds in practice. + t.Fatal("timed out waiting for acceptTCP to dispatch quad-100:853 SYN") + } +} + func TestShouldSendToHost(t *testing.T) { var ( selfIP4 = netip.MustParseAddr("100.64.1.2") diff --git a/wgengine/netstack/netstack_userping_apple.go b/wgengine/netstack/netstack_userping_apple.go index a82b81e99..cb6926f0a 100644 --- a/wgengine/netstack/netstack_userping_apple.go +++ b/wgengine/netstack/netstack_userping_apple.go @@ -6,33 +6,30 @@ package netstack import ( + "context" + "net" "net/netip" "time" - probing "github.com/prometheus-community/pro-bing" + "tailscale.com/net/ping" ) // sendOutboundUserPing sends a non-privileged ICMP (or ICMPv6) ping to dstIP with the given timeout. func (ns *Impl) sendOutboundUserPing(dstIP netip.Addr, timeout time.Duration) error { - p, err := probing.NewPinger(dstIP.String()) + ctx, cancel := context.WithTimeout(context.Background(), timeout) + defer cancel() + + p := ping.New(ctx, ns.logf, nil) + p.Unprivileged = true + defer p.Close() + + dst := &net.IPAddr{IP: dstIP.AsSlice(), Zone: dstIP.Zone()} + ns.logf("sendOutboundUserPing: forwarding ping to %s", dstIP) + d, err := p.Send(ctx, dst, []byte("tailscale-userping")) if err != nil { - ns.logf("sendICMPPingToIP failed to create pinger: %v", err) + ns.logf("sendOutboundUserPing: ping to %s failed: %v", dstIP, err) return err } - - p.Timeout = timeout - p.Count = 1 - p.SetPrivileged(false) - - p.OnSend = func(pkt *probing.Packet) { - ns.logf("sendICMPPingToIP: forwarding ping to %s:", p.Addr()) - } - p.OnRecv = func(pkt *probing.Packet) { - ns.logf("sendICMPPingToIP: %d bytes pong from %s: icmp_seq=%d time=%v", pkt.Nbytes, pkt.IPAddr, pkt.Seq, pkt.Rtt) - } - p.OnFinish = func(stats *probing.Statistics) { - ns.logf("sendICMPPingToIP: done, %d replies received", stats.PacketsRecv) - } - - return p.Run() + ns.logf("sendOutboundUserPing: pong from %s in %v", dstIP, d) + return nil } diff --git a/wgengine/router/osrouter/router_linux.go b/wgengine/router/osrouter/router_linux.go index 3c261c912..73f65cdf1 100644 --- a/wgengine/router/osrouter/router_linux.go +++ b/wgengine/router/osrouter/router_linux.go @@ -89,6 +89,7 @@ type linuxRouter struct { connmarkEnabled bool // whether connmark rules are currently enabled netfilterMode preftype.NetfilterMode netfilterKind string + cgnatMode linuxfw.CGNATMode magicsockPortV4 uint16 magicsockPortV6 uint16 } @@ -489,7 +490,9 @@ func (r *linuxRouter) Set(cfg *router.Config) error { // Connmark rules for rp_filter compatibility. // Always enabled when netfilter is ON to handle all rp_filter=1 scenarios // (normal operation, exit nodes, subnet routers, and clients using exit nodes). - netfilterOn := cfg.NetfilterMode == netfilterOn + // Gate on r.netfilterMode (actual state) rather than cfg.NetfilterMode + // (desired state) so we don't call into the runner when chain setup failed. + netfilterOn := r.netfilterMode == netfilterOn switch { case netfilterOn == r.connmarkEnabled: // state already correct, nothing to do. @@ -502,6 +505,14 @@ func (r *linuxRouter) Set(cfg *router.Config) error { // Only update state on success to keep it in sync with actual rules r.connmarkEnabled = true } + // Enable src_valid_mark so the kernel uses the packet's fwmark + // during the rp_filter reverse-path check. Without this, the + // connmark restore in mangle/PREROUTING is ineffective — rp_filter + // does its routing lookup with fwmark=0, ignoring the restored + // bypass mark, and drops reply packets as martians. + if err := writeSysctl("net.ipv4.conf.all.src_valid_mark", "1"); err != nil { + r.logf("warning: failed to enable src_valid_mark: %v", err) + } default: r.logf("disabling connmark-based rp_filter workaround") if err := r.nfr.DelConnmarkSaveRule(); err != nil { @@ -521,9 +532,50 @@ func (r *linuxRouter) Set(cfg *router.Config) error { r.enableIPForwarding() } + // Remove the rule to drop off-tailnet CGNAT traffic, if needed. + if netfilterOn || r.netfilterMode == netfilterNoDivert { + var cgnatMode linuxfw.CGNATMode + if cfg.RemoveCGNATDropRule { + cgnatMode = linuxfw.CGNATModeReturn + } else { + cgnatMode = linuxfw.CGNATModeDrop + } + err := r.setCGNATDropModeLocked(cgnatMode) + if err != nil { + errs = append(errs, fmt.Errorf("set cgnat mode: %w", err)) + } + } + return errors.Join(errs...) } +// setCGNATDropModeLocked clears old rules and add new rules for the desired +// behavior for incoming non-Tailscale CGNAT packets. +// [linuxRouter.mu] must be held. +func (r *linuxRouter) setCGNATDropModeLocked(want linuxfw.CGNATMode) error { + if want == r.cgnatMode { + return nil + } + // r.cgnatMode is empty at initial startup, before this function has been + // called for the first time. In that case, we can skip deleting old + // rules, because there aren't any. + if r.cgnatMode != "" { + err := r.nfr.DelExternalCGNATRules(r.cgnatMode, r.tunname) + if err != nil { + return fmt.Errorf("clear old cgnat rules: %w", err) + } + } + err := r.nfr.AddExternalCGNATRules(want, r.tunname) + if err != nil { + // We currently have no rules set, so change the state to reflect that + // so we might try again on a future Router update. + r.cgnatMode = "" + return fmt.Errorf("add new cgnat rules: %w", err) + } + r.cgnatMode = want + return nil +} + var dockerStatefulFilteringWarnable = health.Register(&health.Warnable{ Code: "docker-stateful-filtering", Title: "Docker with stateful filtering", @@ -772,6 +824,20 @@ func (r *linuxRouter) setNetfilterModeLocked(mode preftype.NetfilterMode) error } } + // Re-add the CGNAT rules if we had any set. + // This does not call [linuxRouter.setCGNATDropModeLocked] because that + // function assumes that [linuxRouter.cgnatMode] accurately represents the + // current state in the firewall. This would not be true when we hit this + // code path, and is what we're fixing up here. + if r.cgnatMode != "" { + if err := r.nfr.AddExternalCGNATRules(r.cgnatMode, r.tunname); err != nil { + // We currently have no rules set, so change the state to reflect that + // so we might try again on a future Router update. + r.cgnatMode = "" + return fmt.Errorf("add cgnat rules: %w", err) + } + } + return nil } diff --git a/wgengine/router/osrouter/router_linux_test.go b/wgengine/router/osrouter/router_linux_test.go index 07aa8ced8..340ebb148 100644 --- a/wgengine/router/osrouter/router_linux_test.go +++ b/wgengine/router/osrouter/router_linux_test.go @@ -562,6 +562,10 @@ type fakeIPTablesRunner struct { ipt4 map[string][]string ipt6 map[string][]string // we always assume ipv6 and ipv6 nat are enabled when testing + + addChainsErr error // if non-nil, AddChains returns it instead of setting up chains + addConnmarkSaveCalls int + addExternalCGNATCalls int } func newIPTablesRunner(t *testing.T) linuxfw.NetfilterRunner { @@ -717,11 +721,11 @@ func (n *fakeIPTablesRunner) DeleteDNATRuleForSvc(svcName string, origDst, dst n return errors.New("not implemented") } +type iptRule struct{ chain, rule string } + func (n *fakeIPTablesRunner) addBase4(tunname string) error { curIPT := n.ipt4 - newRules := []struct{ chain, rule string }{ - {"filter/ts-input", fmt.Sprintf("! -i %s -s %s -j RETURN", tunname, tsaddr.ChromeOSVMRange().String())}, - {"filter/ts-input", fmt.Sprintf("! -i %s -s %s -j DROP", tunname, tsaddr.CGNATRange().String())}, + newRules := []iptRule{ {"filter/ts-forward", fmt.Sprintf("-i %s -j MARK --set-mark %s/%s", tunname, tsconst.LinuxSubnetRouteMark, tsconst.LinuxFwmarkMask)}, {"filter/ts-forward", fmt.Sprintf("-m mark --mark %s/%s -j ACCEPT", tsconst.LinuxSubnetRouteMark, tsconst.LinuxFwmarkMask)}, {"filter/ts-forward", fmt.Sprintf("-o %s -s %s -j DROP", tunname, tsaddr.CGNATRange().String())}, @@ -737,7 +741,7 @@ func (n *fakeIPTablesRunner) addBase4(tunname string) error { func (n *fakeIPTablesRunner) addBase6(tunname string) error { curIPT := n.ipt6 - newRules := []struct{ chain, rule string }{ + newRules := []iptRule{ {"filter/ts-forward", fmt.Sprintf("-i %s -j MARK --set-mark %s/%s", tunname, tsconst.LinuxSubnetRouteMark, tsconst.LinuxFwmarkMask)}, {"filter/ts-forward", fmt.Sprintf("-m mark --mark %s/%s -j ACCEPT", tsconst.LinuxSubnetRouteMark, tsconst.LinuxFwmarkMask)}, {"filter/ts-forward", fmt.Sprintf("-o %s -j ACCEPT", tunname)}, @@ -762,7 +766,7 @@ func (n *fakeIPTablesRunner) DelLoopbackRule(addr netip.Addr) error { } func (n *fakeIPTablesRunner) AddHooks() error { - newRules := []struct{ chain, rule string }{ + newRules := []iptRule{ {"filter/INPUT", "-j ts-input"}, {"filter/FORWARD", "-j ts-forward"}, {"nat/POSTROUTING", "-j ts-postrouting"}, @@ -778,7 +782,7 @@ func (n *fakeIPTablesRunner) AddHooks() error { } func (n *fakeIPTablesRunner) DelHooks(logf logger.Logf) error { - delRules := []struct{ chain, rule string }{ + delRules := []iptRule{ {"filter/INPUT", "-j ts-input"}, {"filter/FORWARD", "-j ts-forward"}, {"nat/POSTROUTING", "-j ts-postrouting"}, @@ -794,6 +798,9 @@ func (n *fakeIPTablesRunner) DelHooks(logf logger.Logf) error { } func (n *fakeIPTablesRunner) AddChains() error { + if n.addChainsErr != nil { + return n.addChainsErr + } for _, ipt := range []map[string][]string{n.ipt4, n.ipt6} { for _, chain := range []string{"filter/ts-input", "filter/ts-forward", "nat/ts-postrouting"} { ipt[chain] = nil @@ -922,6 +929,7 @@ func (n *fakeIPTablesRunner) DelMagicsockPortRule(port uint16, network string) e } func (n *fakeIPTablesRunner) AddConnmarkSaveRule() error { + n.addConnmarkSaveCalls++ // PREROUTING rule: restore mark from conntrack prerouteRule := "-m conntrack --ctstate ESTABLISHED,RELATED -j CONNMARK --restore-mark --nfmask 0xff0000 --ctmask 0xff0000" for _, ipt := range []map[string][]string{n.ipt4, n.ipt6} { @@ -953,6 +961,49 @@ func (n *fakeIPTablesRunner) DelConnmarkSaveRule() error { return nil } +func buildExternalCGNATRules(mode linuxfw.CGNATMode, tunname string) ([]iptRule, error) { + switch mode { + case linuxfw.CGNATModeDrop: + return []iptRule{ + {"filter/ts-input", fmt.Sprintf("! -i %s -s %s -j RETURN", tunname, tsaddr.ChromeOSVMRange().String())}, + {"filter/ts-input", fmt.Sprintf("! -i %s -s %s -j DROP", tunname, tsaddr.CGNATRange().String())}, + }, nil + case linuxfw.CGNATModeReturn: + return []iptRule{ + {"filter/ts-input", fmt.Sprintf("! -i %s -s %s -j RETURN", tunname, tsaddr.CGNATRange().String())}, + }, nil + default: + return nil, fmt.Errorf("unsupported mode %q", mode) + } +} + +func (n *fakeIPTablesRunner) AddExternalCGNATRules(mode linuxfw.CGNATMode, tunname string) error { + n.addExternalCGNATCalls++ + rules, err := buildExternalCGNATRules(mode, tunname) + if err != nil { + return err + } + for _, rule := range rules { + if err := appendRule(n, n.ipt4, rule.chain, rule.rule); err != nil { + return fmt.Errorf("add rule %q to chain %q: %w", rule.rule, rule.chain, err) + } + } + return nil +} + +func (n *fakeIPTablesRunner) DelExternalCGNATRules(mode linuxfw.CGNATMode, tunname string) error { + rules, err := buildExternalCGNATRules(mode, tunname) + if err != nil { + return err + } + for _, rule := range rules { + if err := deleteRule(n, n.ipt4, rule.chain, rule.rule); err != nil { + return fmt.Errorf("del rule %q to chain %q: %w", rule.rule, rule.chain, err) + } + } + return nil +} + func (n *fakeIPTablesRunner) HasIPV6() bool { return true } func (n *fakeIPTablesRunner) HasIPV6NAT() bool { return true } func (n *fakeIPTablesRunner) HasIPV6Filter() bool { return true } @@ -1157,9 +1208,7 @@ func (lt *linuxTest) Close() error { } func newLinuxRootTest(t *testing.T) (*linuxTest, *eventbus.Bus) { - if os.Getuid() != 0 { - t.Skip("test requires root") - } + tstest.RequireRoot(t) lt := new(linuxTest) lt.tun = createTestTUN(t) @@ -1211,7 +1260,9 @@ func TestRuleDeletedEvent(t *testing.T) { } func TestDelRouteIdempotent(t *testing.T) { + fake := NewFakeOS(t) lt, _ := newLinuxRootTest(t) + lt.r.nfr = fake.nfr defer lt.Close() for _, s := range []string{ @@ -1237,7 +1288,9 @@ func TestDelRouteIdempotent(t *testing.T) { } func TestAddRemoveRules(t *testing.T) { + fake := NewFakeOS(t) lt, _ := newLinuxRootTest(t) + lt.r.nfr = fake.nfr defer lt.Close() r := lt.r @@ -1506,3 +1559,53 @@ func TestUpdateMagicsockPortChange(t *testing.T) { oldPortRule, nfr.ipt4["filter/ts-input"]) } } + +// TestSetSkipsNetfilterAddonsWhenSetupFails verifies that Set does not invoke +// rule-management methods that depend on the ts-* chains existing when chain +// setup failed. +func TestSetSkipsNetfilterAddonsWhenSetupFails(t *testing.T) { + nfr := newIPTablesRunner(t).(*fakeIPTablesRunner) + nfr.addChainsErr = errors.New("kernel lacks netfilter support") + + bus := eventbus.New() + defer bus.Close() + mon, err := netmon.New(bus, logger.Discard) + if err != nil { + t.Fatal(err) + } + mon.Start() + defer mon.Close() + + fake := NewFakeOS(t) + ht := health.NewTracker(bus) + r, err := newUserspaceRouterAdvanced(logger.Discard, "tailscale0", mon, fake, ht, bus) + if err != nil { + t.Fatalf("newUserspaceRouterAdvanced: %v", err) + } + lr := r.(*linuxRouter) + lr.nfr = nfr + if err := lr.Up(); err != nil { + t.Fatalf("Up: %v", err) + } + defer lr.Close() + + cfg := &Config{ + LocalAddrs: mustCIDRs("100.101.102.103/10"), + NetfilterMode: netfilterOn, + } + // Set must return an error (chain setup failed) but must not panic. + if err := lr.Set(cfg); err == nil { + t.Fatal("Set returned nil; want error because AddChains failed") + } + if lr.netfilterMode != netfilterOff { + t.Errorf("netfilterMode = %v; want netfilterOff after failed AddChains", lr.netfilterMode) + } + if nfr.addConnmarkSaveCalls != 0 { + t.Errorf("AddConnmarkSaveRule called %d times; want 0 when chain setup failed", + nfr.addConnmarkSaveCalls) + } + if nfr.addExternalCGNATCalls != 0 { + t.Errorf("AddExternalCGNATRules called %d times; want 0 when chain setup failed", + nfr.addExternalCGNATCalls) + } +} diff --git a/wgengine/router/router.go b/wgengine/router/router.go index 6868acb43..f8d702d47 100644 --- a/wgengine/router/router.go +++ b/wgengine/router/router.go @@ -132,10 +132,11 @@ type Config struct { SubnetRoutes []netip.Prefix // Linux-only things below, ignored on other platforms. - SNATSubnetRoutes bool // SNAT traffic to local subnets - StatefulFiltering bool // Apply stateful filtering to inbound connections - NetfilterMode preftype.NetfilterMode // how much to manage netfilter rules - NetfilterKind string // what kind of netfilter to use ("nftables", "iptables", or "" to auto-detect) + SNATSubnetRoutes bool // SNAT traffic to local subnets + StatefulFiltering bool // Apply stateful filtering to inbound connections + NetfilterMode preftype.NetfilterMode // how much to manage netfilter rules + NetfilterKind string // what kind of netfilter to use ("nftables", "iptables", or "" to auto-detect) + RemoveCGNATDropRule bool // whether to remove the firewall rule to drop non-Tailscale inbound traffic from CGNAT IPs } func (a *Config) Equal(b *Config) bool { diff --git a/wgengine/router/router_test.go b/wgengine/router/router_test.go index f6176f1d0..e6b415586 100644 --- a/wgengine/router/router_test.go +++ b/wgengine/router/router_test.go @@ -15,7 +15,7 @@ func TestConfigEqual(t *testing.T) { testedFields := []string{ "LocalAddrs", "Routes", "LocalRoutes", "NewMTU", "SubnetRoutes", "SNATSubnetRoutes", "StatefulFiltering", - "NetfilterMode", "NetfilterKind", + "NetfilterMode", "NetfilterKind", "RemoveCGNATDropRule", } configType := reflect.TypeFor[Config]() configFields := []string{} diff --git a/wgengine/userspace.go b/wgengine/userspace.go index 1b77d4b97..222df1bc8 100644 --- a/wgengine/userspace.go +++ b/wgengine/userspace.go @@ -4,23 +4,21 @@ package wgengine import ( - "bufio" "context" crand "crypto/rand" "crypto/x509" "errors" "fmt" "io" - "maps" "math" "net/netip" - "reflect" "runtime" "slices" - "strings" "sync" + "sync/atomic" "time" + "github.com/gaissmai/bart" "github.com/tailscale/wireguard-go/device" "github.com/tailscale/wireguard-go/tun" "tailscale.com/control/controlknobs" @@ -49,7 +47,6 @@ import ( "tailscale.com/types/logger" "tailscale.com/types/netmap" "tailscale.com/types/views" - "tailscale.com/util/backoff" "tailscale.com/util/checkchange" "tailscale.com/util/clientmetric" "tailscale.com/util/eventbus" @@ -70,29 +67,6 @@ import ( "tailscale.com/wgengine/wglog" ) -// Lazy wireguard-go configuration parameters. -const ( - // lazyPeerIdleThreshold is the idle duration after - // which we remove a peer from the wireguard configuration. - // (This includes peers that have never been idle, which - // effectively have infinite idleness) - lazyPeerIdleThreshold = 5 * time.Minute - - // packetSendTimeUpdateFrequency controls how often we record - // the time that we wrote a packet to an IP address. - packetSendTimeUpdateFrequency = 10 * time.Second - - // packetSendRecheckWireguardThreshold controls how long we can go - // between packet sends to an IP before checking to see - // whether this IP address needs to be added back to the - // WireGuard peer oconfig. - packetSendRecheckWireguardThreshold = 1 * time.Minute -) - -// statusPollInterval is how often we ask wireguard-go for its engine -// status (as long as there's activity). See docs on its use below. -const statusPollInterval = 1 * time.Minute - // networkLoggerUploadTimeout is the maximum timeout to wait when // shutting down the network logger as it uploads the last network log messages. const networkLoggerUploadTimeout = 5 * time.Second @@ -134,21 +108,27 @@ type userspaceEngine struct { // is being routed over Tailscale. isDNSIPOverTailscale syncs.AtomicValue[func(netip.Addr) bool] - wgLock sync.Mutex // serializes all wgdev operations; see lock order comment below - lastCfgFull wgcfg.Config - lastNMinPeers int - lastRouter *router.Config - lastEngineFull *wgcfg.Config // of full wireguard config, not trimmed - lastEngineInputs *maybeReconfigInputs - lastDNSConfig dns.ConfigView // or invalid if none - lastIsSubnetRouter bool // was the node a primary subnet router in the last run. - recvActivityAt map[key.NodePublic]mono.Time - trimmedNodes map[key.NodePublic]bool // set of node keys of peers currently excluded from wireguard config - sentActivityAt map[netip.Addr]*mono.Time // value is accessed atomically - destIPActivityFuncs map[netip.Addr]func() - lastStatusPollTime mono.Time // last time we polled the engine status - reconfigureVPN func() error // or nil - conn25PacketHooks Conn25PacketHooks // or nil + wgLock sync.Mutex // serializes all wgdev operations; see lock order comment below + + // peerByIPRoute is a longest-prefix-match table built from + // lastCfgFull.Peers AllowedIPs. It's the slow path for + // SetPeerByIPPacketFunc, used when LocalBackend's exact-IP fast path + // (nodeByAddr) misses — i.e. for subnet routes and exit-node default + // routes. Built from lastCfgFull (the wireguard-filtered peer list) + // rather than the netmap so that exit-node selection is honored: the + // netmap has 0.0.0.0/0 in AllowedIPs for every exit-capable peer, but + // lastCfgFull only has it for the currently-selected exit node. + // + // Replaced (not mutated) by maybeReconfigWireguardLocked. Read by + // the per-packet wgdev callback without locking. + peerByIPRoute atomic.Pointer[bart.Table[key.NodePublic]] + + lastCfgFull wgcfg.Config + lastRouter *router.Config + lastDNSConfig dns.ConfigView // or invalid if none + lastIsSubnetRouter bool // was the node a primary subnet router in the last run. + reconfigureVPN func() error // or nil + conn25PacketHooks Conn25PacketHooks // or nil mu sync.Mutex // guards following; see lock order comment below netMap *netmap.NetworkMap // or nil @@ -462,10 +442,6 @@ func NewUserspaceEngine(logf logger.Logf, conf Config) (_ Engine, reterr error) ForceDiscoKey: conf.ForceDiscoKey, OnDERPRecv: conf.OnDERPRecv, } - if buildfeatures.HasLazyWG { - magicsockOpts.NoteRecvActivity = e.noteRecvActivity - } - var err error e.magicConn, err = magicsock.NewConn(magicsockOpts) if err != nil { @@ -533,6 +509,16 @@ func NewUserspaceEngine(logf logger.Logf, conf Config) (_ Engine, reterr error) e.logf("Creating WireGuard device...") e.wgdev = wgcfg.NewDevice(e.tundev, e.magicConn.Bind(), e.wgLogger.DeviceLogger) closePool.addFunc(e.wgdev.Close) + + // Install a default outbound-packet peer lookup callback. It uses only + // the engine's BART table, which is rebuilt from the wireguard-filtered + // peer list on every Reconfig. Consumers (e.g. LocalBackend) may later + // call SetPeerByIPPacketFunc to additionally install a fast path for + // exact node-address matches; the BART remains the slow-path fallback. + // Without this default, callers that don't run a LocalBackend would + // have no way to route outbound packets to peers, since peers are + // created lazily from inbound packets only via SetPeerLookupFunc. + e.SetPeerByIPPacketFunc(nil) closePool.addFunc(func() { if err := e.magicConn.Close(); err != nil { e.logf("error closing magicconn: %v", err) @@ -692,135 +678,11 @@ func (e *userspaceEngine) handleLocalPackets(p *packet.Parsed, t *tstun.Wrapper) return filter.Accept } -var debugTrimWireguard = envknob.RegisterOptBool("TS_DEBUG_TRIM_WIREGUARD") - -// forceFullWireguardConfig reports whether we should give wireguard our full -// network map, even for inactive peers. -// -// TODO(bradfitz): remove this at some point. We had a TODO to do it before 1.0 -// but it's still there as of 1.30. Really we should not do this wireguard lazy -// peer config at all and just fix wireguard-go to not have so much extra memory -// usage per peer. That would simplify a lot of Tailscale code. OTOH, we have 50 -// MB of memory on iOS now instead of 15 MB, so the other option is to just give -// up on lazy wireguard config and blow the memory and hope for the best on iOS. -// That's sad too. Or we get rid of these knobs (lazy wireguard config has been -// stable!) but I'm worried that a future regression would be easier to debug -// with these knobs in place. -func (e *userspaceEngine) forceFullWireguardConfig(numPeers int) bool { - // Did the user explicitly enable trimming via the environment variable knob? - if b, ok := debugTrimWireguard().Get(); ok { - return !b - } - return e.controlKnobs != nil && e.controlKnobs.KeepFullWGConfig.Load() -} - -// isTrimmablePeer reports whether p is a peer that we can trim out of the -// network map. -// -// For implementation simplicity, we can only trim peers that have -// only non-subnet AllowedIPs (an IPv4 /32 or IPv6 /128), which is the -// common case for most peers. Subnet router nodes will just always be -// created in the wireguard-go config. -func (e *userspaceEngine) isTrimmablePeer(p *wgcfg.Peer, numPeers int) bool { - if e.forceFullWireguardConfig(numPeers) { - return false - } - - // AllowedIPs must all be single IPs, not subnets. - for _, aip := range p.AllowedIPs { - if !aip.IsSingleIP() { - return false - } - } - return true -} - -// noteRecvActivity is called by magicsock when a packet has been -// received for the peer with node key nk. Magicsock calls this no -// more than every 10 seconds for a given peer. -func (e *userspaceEngine) noteRecvActivity(nk key.NodePublic) { - e.wgLock.Lock() - defer e.wgLock.Unlock() - - if _, ok := e.recvActivityAt[nk]; !ok { - // Not a trimmable peer we care about tracking. (See isTrimmablePeer) - if e.trimmedNodes[nk] { - e.logf("wgengine: [unexpected] noteReceiveActivity called on idle node %v that's not in recvActivityAt", nk.ShortString()) - } - return - } - now := e.timeNow() - e.recvActivityAt[nk] = now - - // As long as there's activity, periodically poll the engine to get - // stats for the far away side effect of - // ipn/ipnlocal.LocalBackend.parseWgStatusLocked to log activity, for - // use in various admin dashboards. - // This particularly matters on platforms without a connected GUI, as - // the GUIs generally poll this enough to cause that logging. But - // tailscaled alone did not, hence this. - if e.lastStatusPollTime.IsZero() || now.Sub(e.lastStatusPollTime) >= statusPollInterval { - e.lastStatusPollTime = now - go e.RequestStatus() - } - - // If the last activity time jumped a bunch (say, at least - // half the idle timeout) then see if we need to reprogram - // WireGuard. This could probably be just - // lazyPeerIdleThreshold without the divide by 2, but - // maybeReconfigWireguardLocked is cheap enough to call every - // couple minutes (just not on every packet). - if e.trimmedNodes[nk] { - e.logf("wgengine: idle peer %v now active, reconfiguring WireGuard", nk.ShortString()) - e.maybeReconfigWireguardLocked(nil) - } -} - -// isActiveSinceLocked reports whether the peer identified by (nk, ip) -// has had a packet sent to or received from it since t. +// maybeReconfigWireguardLocked reconfigures wireguard-go with the current +// full config, installing a PeerLookupFunc for on-demand peer creation. // // e.wgLock must be held. -func (e *userspaceEngine) isActiveSinceLocked(nk key.NodePublic, ip netip.Addr, t mono.Time) bool { - if e.recvActivityAt[nk].After(t) { - return true - } - timePtr, ok := e.sentActivityAt[ip] - if !ok { - return false - } - return timePtr.LoadAtomic().After(t) -} - -// maybeReconfigInputs holds the inputs to the maybeReconfigWireguardLocked -// function. If these things don't change between calls, there's nothing to do. -type maybeReconfigInputs struct { - WGConfig *wgcfg.Config - TrimmedNodes map[key.NodePublic]bool - TrackNodes views.Slice[key.NodePublic] - TrackIPs views.Slice[netip.Addr] -} - -func (i *maybeReconfigInputs) Equal(o *maybeReconfigInputs) bool { - return reflect.DeepEqual(i, o) -} - -func (i *maybeReconfigInputs) Clone() *maybeReconfigInputs { - if i == nil { - return nil - } - v := *i - v.WGConfig = i.WGConfig.Clone() - v.TrimmedNodes = maps.Clone(i.TrimmedNodes) - return &v -} - -// discoChanged are the set of peers whose disco keys have changed, implying they've restarted. -// If a peer is in this set and was previously in the live wireguard config, -// it needs to be first removed and then re-added to flush out its wireguard session key. -// If discoChanged is nil or empty, this extra removal step isn't done. -// -// e.wgLock must be held. -func (e *userspaceEngine) maybeReconfigWireguardLocked(discoChanged map[key.NodePublic]bool) error { +func (e *userspaceEngine) maybeReconfigWireguardLocked() error { if hook := e.testMaybeReconfigHook; hook != nil { hook() return nil @@ -829,177 +691,49 @@ func (e *userspaceEngine) maybeReconfigWireguardLocked(discoChanged map[key.Node full := e.lastCfgFull e.wgLogger.SetPeers(full.Peers) - // Compute a minimal config to pass to wireguard-go - // based on the full config. Prune off all the peers - // and only add the active ones back. - min := full - min.Peers = make([]wgcfg.Peer, 0, e.lastNMinPeers) - - // We'll only keep a peer around if it's been active in - // the past 5 minutes. That's more than WireGuard's key - // rotation time anyway so it's no harm if we remove it - // later if it's been inactive. - var activeCutoff mono.Time - if buildfeatures.HasLazyWG { - activeCutoff = e.timeNow().Add(-lazyPeerIdleThreshold) - } - - // Not all peers can be trimmed from the network map (see - // isTrimmablePeer). For those that are trimmable, keep track of - // their NodeKey and Tailscale IPs. These are the ones we'll need - // to install tracking hooks for to watch their send/receive - // activity. - var trackNodes []key.NodePublic - var trackIPs []netip.Addr - if buildfeatures.HasLazyWG { - trackNodes = make([]key.NodePublic, 0, len(full.Peers)) - trackIPs = make([]netip.Addr, 0, len(full.Peers)) - } - - // Don't re-alloc the map; the Go compiler optimizes map clears as of - // Go 1.11, so we can re-use the existing + allocated map. - if e.trimmedNodes != nil { - clear(e.trimmedNodes) - } else { - e.trimmedNodes = make(map[key.NodePublic]bool) - } - - needRemoveStep := false - for i := range full.Peers { - p := &full.Peers[i] - nk := p.PublicKey - if !buildfeatures.HasLazyWG || !e.isTrimmablePeer(p, len(full.Peers)) { - min.Peers = append(min.Peers, *p) - if discoChanged[nk] { - needRemoveStep = true - } - continue - } - trackNodes = append(trackNodes, nk) - recentlyActive := false - for _, cidr := range p.AllowedIPs { - trackIPs = append(trackIPs, cidr.Addr()) - recentlyActive = recentlyActive || e.isActiveSinceLocked(nk, cidr.Addr(), activeCutoff) - } - if recentlyActive { - min.Peers = append(min.Peers, *p) - if discoChanged[nk] { - needRemoveStep = true - } - } else { - e.trimmedNodes[nk] = true + // Rebuild the prefix-match peer routing table from the current + // (wireguard-filtered) peer list and publish it atomically. + rt := &bart.Table[key.NodePublic]{} + for _, p := range full.Peers { + for _, pfx := range p.AllowedIPs { + rt.Insert(pfx, p.PublicKey) } } - e.lastNMinPeers = len(min.Peers) + e.peerByIPRoute.Store(rt) - if changed := checkchange.Update(&e.lastEngineInputs, &maybeReconfigInputs{ - WGConfig: &min, - TrimmedNodes: e.trimmedNodes, - TrackNodes: views.SliceOf(trackNodes), - TrackIPs: views.SliceOf(trackIPs), - }); !changed { - return nil - } - - if buildfeatures.HasLazyWG { - e.updateActivityMapsLocked(trackNodes, trackIPs) - } - - if needRemoveStep { - minner := min - minner.Peers = nil - numRemove := 0 - for _, p := range min.Peers { - if discoChanged[p.PublicKey] { - numRemove++ - continue - } - minner.Peers = append(minner.Peers, p) - } - if numRemove > 0 { - e.logf("wgengine: Reconfig: removing session keys for %d peers", numRemove) - if err := wgcfg.ReconfigDevice(e.wgdev, &minner, e.logf); err != nil { - e.logf("wgdev.Reconfig: %v", err) - return err - } - } - } - - e.logf("wgengine: Reconfig: configuring userspace WireGuard config (with %d/%d peers)", len(min.Peers), len(full.Peers)) - if err := wgcfg.ReconfigDevice(e.wgdev, &min, e.logf); err != nil { + e.logf("wgengine: Reconfig: configuring userspace WireGuard config (with %d peers)", len(full.Peers)) + if err := wgcfg.ReconfigDevice(e.wgdev, &full, e.logf); err != nil { e.logf("wgdev.Reconfig: %v", err) return err } return nil } -// updateActivityMapsLocked updates the data structures used for tracking the activity -// of wireguard peers that we might add/remove dynamically from the real config -// as given to wireguard-go. +// SetPeerByIPPacketFunc installs a callback used by wireguard-go to look up +// which peer should handle an outbound packet by destination IP. // -// e.wgLock must be held. -func (e *userspaceEngine) updateActivityMapsLocked(trackNodes []key.NodePublic, trackIPs []netip.Addr) { - if !buildfeatures.HasLazyWG { - return - } - // Generate the new map of which nodekeys we want to track - // receive times for. - mr := map[key.NodePublic]mono.Time{} // TODO: only recreate this if set of keys changed - for _, nk := range trackNodes { - // Preserve old times in the new map, but also - // populate map entries for new trackNodes values with - // time.Time{} zero values. (Only entries in this map - // are tracked, so the Time zero values allow it to be - // tracked later) - mr[nk] = e.recvActivityAt[nk] - } - e.recvActivityAt = mr - - oldTime := e.sentActivityAt - e.sentActivityAt = make(map[netip.Addr]*mono.Time, len(oldTime)) - oldFunc := e.destIPActivityFuncs - e.destIPActivityFuncs = make(map[netip.Addr]func(), len(oldFunc)) - - updateFn := func(timePtr *mono.Time) func() { - return func() { - now := e.timeNow() - old := timePtr.LoadAtomic() - - // How long's it been since we last sent a packet? - elapsed := now.Sub(old) - if old == 0 { - // For our first packet, old is 0, which has indeterminate meaning. - // Set elapsed to a big number (four score and seven years). - elapsed = 762642 * time.Hour - } - - if elapsed >= packetSendTimeUpdateFrequency { - timePtr.StoreAtomic(now) - } - // On a big jump, assume we might no longer be in the wireguard - // config and go check. - if elapsed >= packetSendRecheckWireguardThreshold { - e.wgLock.Lock() - defer e.wgLock.Unlock() - e.maybeReconfigWireguardLocked(nil) +// fn is an optional fast path for exact node-address matches (e.g. dst is a +// Tailscale IP). On miss (or if fn is nil), the engine's own BART table +// ([userspaceEngine.peerByIPRoute], built from the wireguard-filtered peer +// list) is consulted to handle subnet routes and exit-node default routes. +// +// [NewUserspaceEngine] installs a BART-only default at engine creation time, +// so callers that don't call SetPeerByIPPacketFunc (e.g. those not running +// a LocalBackend) still get working outbound packet routing. +func (e *userspaceEngine) SetPeerByIPPacketFunc(fn func(netip.Addr) (_ key.NodePublic, ok bool)) { + e.wgdev.SetPeerByIPPacketFunc(func(_, dst netip.Addr, _ []byte) (device.NoisePublicKey, bool) { + if fn != nil { + if pk, ok := fn(dst); ok { + return pk.Raw32(), true } } - } - - for _, ip := range trackIPs { - timePtr := oldTime[ip] - if timePtr == nil { - timePtr = new(mono.Time) + if rt := e.peerByIPRoute.Load(); rt != nil { + if pk, ok := rt.Lookup(dst); ok { + return pk.Raw32(), true + } } - e.sentActivityAt[ip] = timePtr - - fn := oldFunc[ip] - if fn == nil { - fn = updateFn(timePtr) - } - e.destIPActivityFuncs[ip] = fn - } - e.tundev.SetDestIPActivityFuncs(e.destIPActivityFuncs) + return device.NoisePublicKey{}, false + }) } // hasOverlap checks if there is a IPPrefix which is common amongst the two @@ -1014,29 +748,17 @@ func hasOverlap(aips, rips views.Slice[netip.Prefix]) bool { } // ResetAndStop resets the engine to a clean state (like calling Reconfig -// with all pointers to zero values) and waits for it to be fully stopped, -// with no live peers or DERPs. +// with all pointers to zero values) and returns the resulting status. // // Unlike Reconfig, it does not return ErrNoChanges. // -// If the engine stops, returns the status. NB that this status will not be sent -// to the registered status callback, it is on the caller to ensure this status -// is handled appropriately. +// The returned status will not be sent to the registered status callback; +// it is on the caller to ensure this status is handled appropriately. func (e *userspaceEngine) ResetAndStop() (*Status, error) { if err := e.Reconfig(&wgcfg.Config{}, &router.Config{}, &dns.Config{}); err != nil && !errors.Is(err, ErrNoChanges) { return nil, err } - bo := backoff.NewBackoff("UserspaceEngineResetAndStop", e.logf, 1*time.Second) - for { - st, err := e.getStatus() - if err != nil { - return nil, err - } - if len(st.Peers) == 0 && st.DERPs == 0 { - return st, nil - } - bo.BackOff(context.Background(), fmt.Errorf("waiting for engine to stop: peers=%d derps=%d", len(st.Peers), st.DERPs)) - } + return e.getStatus() } func (e *userspaceEngine) PatchDiscoKey(pub key.NodePublic, disco key.DiscoPublic) { @@ -1088,7 +810,7 @@ func (e *userspaceEngine) Reconfig(cfg *wgcfg.Config, routerCfg *router.Config, } isSubnetRouterChanged := buildfeatures.HasAdvertiseRoutes && isSubnetRouter != e.lastIsSubnetRouter - engineChanged := checkchange.Update(&e.lastEngineFull, cfg) + engineChanged := !e.lastCfgFull.Equal(cfg) routerChanged := checkchange.Update(&e.lastRouter, routerCfg) dnsChanged := buildfeatures.HasDNS && !e.lastDNSConfig.Equal(dnsCfg.View()) if dnsChanged { @@ -1120,11 +842,10 @@ func (e *userspaceEngine) Reconfig(cfg *wgcfg.Config, routerCfg *router.Config, } // See if any peers have changed disco keys, which means they've restarted. - // If so, we need to update the wireguard-go/device.Device in two phases: - // once without the node which has restarted, to clear its wireguard session key, - // and a second time with it. + // If so, remove the peer from wireguard-go to flush its session key, + // then let the PeerLookupFunc re-create it on demand. discoChanged := make(map[key.NodePublic]bool) - { + if engineChanged { prevEP := make(map[key.NodePublic]key.DiscoPublic) for i := range e.lastCfgFull.Peers { if p := &e.lastCfgFull.Peers[i]; !p.DiscoKey.IsZero() { @@ -1137,7 +858,6 @@ func (e *userspaceEngine) Reconfig(cfg *wgcfg.Config, routerCfg *router.Config, continue } - // If the key changed, mark the connection for reconfiguration. pub := p.PublicKey if old, ok := prevEP[pub]; ok && old != p.DiscoKey { @@ -1145,31 +865,26 @@ func (e *userspaceEngine) Reconfig(cfg *wgcfg.Config, routerCfg *router.Config, // wireguard config as the new key was received over an existing wireguard // connection. if discoTSMP, okTSMP := e.tsmpLearnedDisco[p.PublicKey]; okTSMP { + // Key matches, remove entry from map. + delete(e.tsmpLearnedDisco, p.PublicKey) if discoTSMP == p.DiscoKey { - // Key matches, remove entry from map. e.logf("wgengine: Skipping reconfig (TSMP key): %s changed from %q to %q", pub.ShortString(), old, p.DiscoKey) - delete(e.tsmpLearnedDisco, p.PublicKey) - } else { - // The new disco key does not match what we received via - // TSMP for this peer. This is unexpected, so log it. - // If it does happen, overwrite the previously-saved - // disco key with the new one for now: We expect another - // update must be pending in that case, so keep the map - // entry. - // The reason why this should never happen is that only a single - // request is coming through the netmap pipeline at a time, and there - // should realistically ever only be a single entry in the map. This - // is really a belt and suspenders solution to find usage that is - // inconsistent with our expectations. - e.logf("wgengine: [unexpected] Reconfig: using TSMP key for %s (control stale): tsmp=%q control=%q old=%q", - pub.ShortString(), discoTSMP, p.DiscoKey, old) - metricTSMPLearnedKeyMismatch.Add(1) - p.DiscoKey = discoTSMP + // Skip session clear. + continue } - // Skip session clear no matter what. - continue + // The new disco key does not match what we received via + // TSMP for this peer. This is unexpected, though possible + // if processing a change in a large netmap ends up taking + // longer than the 2 second timeout in + // [controlClient.mapRoutineState.UpdateNetmapDelta], or if + // the context is cancelled mid update. Log the event, and reset + // the connection as it is possibly a stale entry in the map + // instead of a TSMP disco key update that led us here. + e.logf("wgengine: [unexpected] Reconfig: using TSMP key for %s (control stale): tsmp=%q control=%q old=%q", + pub.ShortString(), discoTSMP, p.DiscoKey, old) + metricTSMPLearnedKeyMismatch.Add(1) } discoChanged[pub] = true @@ -1183,21 +898,36 @@ func (e *userspaceEngine) Reconfig(cfg *wgcfg.Config, routerCfg *router.Config, e.testDiscoChangedHook(discoChanged) } + if !e.lastCfgFull.PrivateKey.Equal(cfg.PrivateKey) { + // Tell magicsock about the new (or initial) private key + // (which is needed by DERP) before wgdev gets it, as wgdev + // will start trying to handshake, which we want to be able to + // go over DERP. + if err := e.magicConn.SetPrivateKey(cfg.PrivateKey); err != nil { + e.logf("wgengine: Reconfig: SetPrivateKey: %v", err) + } + + if err := e.wgdev.SetPrivateKey(key.NodePrivateAs[device.NoisePrivateKey](cfg.PrivateKey)); err != nil { + e.logf("wgengine: Reconfig: wgdev.SetPrivateKey: %v", err) + } + } + e.lastCfgFull = *cfg.Clone() - // Tell magicsock about the new (or initial) private key - // (which is needed by DERP) before wgdev gets it, as wgdev - // will start trying to handshake, which we want to be able to - // go over DERP. - if err := e.magicConn.SetPrivateKey(cfg.PrivateKey); err != nil { - e.logf("wgengine: Reconfig: SetPrivateKey: %v", err) - } e.magicConn.UpdatePeers(peerSet) e.magicConn.SetPreferredPort(listenPort) e.magicConn.UpdatePMTUD() - if err := e.maybeReconfigWireguardLocked(discoChanged); err != nil { - return err + if engineChanged { + if err := e.maybeReconfigWireguardLocked(); err != nil { + return err + } + // Now that we've reconfigured wireguard-go, remove any peers with + // changed disco keys to flush their session keys, and let them be + // re-created on demand by the PeerLookupFunc. + for pub := range discoChanged { + e.wgdev.RemovePeer(pub.Raw32()) + } } // Cleanup map of tsmp marks for peers that no longer exists in config. @@ -1342,8 +1072,14 @@ func (e *userspaceEngine) PeerByKey(pubKey key.NodePublic) (_ wgint.Peer, ok boo if dev == nil { return wgint.Peer{}, false } - peer := dev.LookupPeer(pubKey.Raw32()) - if peer == nil { + // Use LookupActivePeer (not LookupPeer) to avoid triggering on-demand + // peer creation via PeerLookupFunc. PeerByKey is called from status + // polling paths (getStatus, getPeerStatusLite) which iterate every peer + // in the netmap; using LookupPeer would lazily create a wireguard-go + // peer for every single netmap peer on each status poll, leaking + // memory via per-peer queues and goroutines. + peer, ok := dev.LookupActivePeer(pubKey.Raw32()) + if !ok { return wgint.Peer{}, false } return wgint.PeerOf(peer), true @@ -1439,8 +1175,6 @@ func (e *userspaceEngine) Close() { e.closing = true e.mu.Unlock() - r := bufio.NewReader(strings.NewReader("")) - e.wgdev.IpcSetOperation(r) e.magicConn.Close() if e.netMonOwned { e.netMon.Close() diff --git a/wgengine/userspace_test.go b/wgengine/userspace_test.go index 558df4ced..b2f40fada 100644 --- a/wgengine/userspace_test.go +++ b/wgengine/userspace_test.go @@ -8,8 +8,9 @@ import ( "math/rand" "net/netip" "os" - "reflect" "runtime" + "slices" + "sync" "testing" "go4.org/mem" @@ -18,81 +19,22 @@ import ( "tailscale.com/envknob" "tailscale.com/health" "tailscale.com/net/dns" + "tailscale.com/net/dns/resolver" "tailscale.com/net/netaddr" - "tailscale.com/net/tstun" + "tailscale.com/net/netmon" "tailscale.com/tailcfg" - "tailscale.com/tstest" - "tailscale.com/tstime/mono" + "tailscale.com/types/dnstype" "tailscale.com/types/key" + "tailscale.com/types/logger" "tailscale.com/types/netmap" "tailscale.com/types/opt" + "tailscale.com/util/dnsname" "tailscale.com/util/eventbus/eventbustest" "tailscale.com/util/usermetric" "tailscale.com/wgengine/router" "tailscale.com/wgengine/wgcfg" ) -func TestNoteReceiveActivity(t *testing.T) { - now := mono.Time(123456) - var logBuf tstest.MemLogger - - confc := make(chan bool, 1) - gotConf := func() bool { - select { - case <-confc: - return true - default: - return false - } - } - e := &userspaceEngine{ - timeNow: func() mono.Time { return now }, - recvActivityAt: map[key.NodePublic]mono.Time{}, - logf: logBuf.Logf, - tundev: new(tstun.Wrapper), - testMaybeReconfigHook: func() { confc <- true }, - trimmedNodes: map[key.NodePublic]bool{}, - } - ra := e.recvActivityAt - - nk := key.NewNode().Public() - - // Activity on an untracked key should do nothing. - e.noteRecvActivity(nk) - if len(ra) != 0 { - t.Fatalf("unexpected growth in map: now has %d keys; want 0", len(ra)) - } - if logBuf.Len() != 0 { - t.Fatalf("unexpected log write (and thus activity): %s", logBuf.Bytes()) - } - - // Now track it, but don't mark it trimmed, so shouldn't update. - ra[nk] = 0 - e.noteRecvActivity(nk) - if len(ra) != 1 { - t.Fatalf("unexpected growth in map: now has %d keys; want 1", len(ra)) - } - if got := ra[nk]; got != now { - t.Fatalf("time in map = %v; want %v", got, now) - } - if gotConf() { - t.Fatalf("unexpected reconfig") - } - - // Now mark it trimmed and expect an update. - e.trimmedNodes[nk] = true - e.noteRecvActivity(nk) - if len(ra) != 1 { - t.Fatalf("unexpected growth in map: now has %d keys; want 1", len(ra)) - } - if got := ra[nk]; got != now { - t.Fatalf("time in map = %v; want %v", got, now) - } - if !gotConf() { - t.Fatalf("didn't get expected reconfig") - } -} - func nodeViews(v []*tailcfg.Node) []tailcfg.NodeView { nv := make([]tailcfg.NodeView, len(v)) for i, n := range v { @@ -111,7 +53,6 @@ func TestUserspaceEngineReconfig(t *testing.T) { t.Fatal(err) } t.Cleanup(e.Close) - ue := e.(*userspaceEngine) routerCfg := &router.Config{} @@ -147,20 +88,6 @@ func TestUserspaceEngineReconfig(t *testing.T) { if err != nil { t.Fatal(err) } - - wantRecvAt := map[key.NodePublic]mono.Time{ - nkFromHex(nodeHex): 0, - } - if got := ue.recvActivityAt; !reflect.DeepEqual(got, wantRecvAt) { - t.Errorf("wrong recvActivityAt\n got: %v\nwant: %v\n", got, wantRecvAt) - } - - wantTrimmedNodes := map[key.NodePublic]bool{ - nkFromHex(nodeHex): true, - } - if got := ue.trimmedNodes; !reflect.DeepEqual(got, wantTrimmedNodes) { - t.Errorf("wrong wantTrimmedNodes\n got: %v\nwant: %v\n", got, wantTrimmedNodes) - } } } @@ -263,8 +190,9 @@ func TestUserspaceEngineTSMPLearnedMismatch(t *testing.T) { wrongKey bool }{ {tsmp: false, inMap: false, wrongKey: false}, - {tsmp: true, inMap: false, wrongKey: true}, - {tsmp: false, inMap: false, wrongKey: false}, + {tsmp: true, inMap: false, wrongKey: false}, + {tsmp: true, inMap: true, wrongKey: true}, + {tsmp: false, inMap: true, wrongKey: false}, } nkHex := "aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa" @@ -609,3 +537,76 @@ func BenchmarkGenLocalAddrFunc(b *testing.B) { }) b.Logf("x = %v", x) } + +// Regression test for #19730: on major link change, MatchDomains Routes must +// be preserved. +func TestLinkChangeReapplyPreservesMagicDNSRoutes(t *testing.T) { + switch runtime.GOOS { + case "linux", "android", "darwin", "ios", "openbsd": + default: + t.Skipf("linkChange DNS reapply path not exercised on %s", runtime.GOOS) + } + + bus := eventbustest.NewBus(t) + noop, err := dns.NewNoopManager() + if err != nil { + t.Fatal(err) + } + e, err := NewUserspaceEngine(t.Logf, Config{ + HealthTracker: health.NewTracker(bus), + Metrics: new(usermetric.Registry), + EventBus: bus, + DNS: noop, + RespondToPing: true, + }) + if err != nil { + t.Fatal(err) + } + t.Cleanup(e.Close) + + var ( + mu sync.Mutex + last resolver.Config + ) + e.(*userspaceEngine).dns.Resolver().TestOnlySetHook(func(cfg resolver.Config) { + mu.Lock() + defer mu.Unlock() + last = cfg + }) + snapshot := func() []dnsname.FQDN { + mu.Lock() + defer mu.Unlock() + return slices.Clone(last.LocalDomains) + } + + dnsCfg := &dns.Config{ + Routes: map[dnsname.FQDN][]*dnstype.Resolver{ + "ts.net.": {{Addr: "199.247.155.53"}}, + "foo.ts.net.": nil, + "64.100.in-addr.arpa.": nil, + }, + Hosts: map[dnsname.FQDN][]netip.Addr{ + "node.foo.ts.net.": {netip.MustParseAddr("100.64.0.5")}, + }, + SearchDomains: []dnsname.FQDN{"foo.ts.net."}, + } + if err := e.Reconfig(&wgcfg.Config{}, &router.Config{}, dnsCfg); err != nil { + t.Fatalf("Reconfig: %v", err) + } + initial := snapshot() + + cd, err := netmon.NewChangeDelta(nil, &netmon.State{HaveV4: true}, 0, true) + if err != nil { + t.Fatal(err) + } + cd.RebindLikelyRequired = true + e.(*userspaceEngine).linkChange(cd) + + after := snapshot() + slices.Sort(initial) + slices.Sort(after) + if !slices.Equal(initial, after) { + t.Errorf("resolver LocalDomains changed after linkChange:\n initial: %s\n after: %s", + logger.AsJSON(initial), logger.AsJSON(after)) + } +} diff --git a/wgengine/watchdog.go b/wgengine/watchdog.go deleted file mode 100644 index 4bb320b4b..000000000 --- a/wgengine/watchdog.go +++ /dev/null @@ -1,256 +0,0 @@ -// Copyright (c) Tailscale Inc & contributors -// SPDX-License-Identifier: BSD-3-Clause - -//go:build !js && !ts_omit_debug - -package wgengine - -import ( - "fmt" - "log" - "net/netip" - "runtime/pprof" - "strings" - "sync" - "time" - - "tailscale.com/envknob" - "tailscale.com/feature/buildfeatures" - "tailscale.com/ipn/ipnstate" - "tailscale.com/net/dns" - "tailscale.com/net/packet" - "tailscale.com/tailcfg" - "tailscale.com/types/key" - "tailscale.com/types/netmap" - "tailscale.com/util/clientmetric" - "tailscale.com/wgengine/filter" - "tailscale.com/wgengine/router" - "tailscale.com/wgengine/wgcfg" - "tailscale.com/wgengine/wgint" -) - -type watchdogEvent string - -const ( - Any watchdogEvent = "Any" - Reconfig watchdogEvent = "Reconfig" - ResetAndStop watchdogEvent = "ResetAndStop" - SetFilter watchdogEvent = "SetFilter" - SetJailedFilter watchdogEvent = "SetJailedFilter" - SetStatusCallback watchdogEvent = "SetStatusCallback" - UpdateStatus watchdogEvent = "UpdateStatus" - RequestStatus watchdogEvent = "RequestStatus" - SetNetworkMap watchdogEvent = "SetNetworkMap" - Ping watchdogEvent = "Ping" - Close watchdogEvent = "Close" - PeerForIPEvent watchdogEvent = "PeerForIP" -) - -var ( - watchdogMetrics = map[watchdogEvent]*clientmetric.Metric{ - Any: clientmetric.NewCounter("watchdog_timeout_any_total"), - Reconfig: clientmetric.NewCounter("watchdog_timeout_reconfig"), - ResetAndStop: clientmetric.NewCounter("watchdog_timeout_resetandstop"), - SetFilter: clientmetric.NewCounter("watchdog_timeout_setfilter"), - SetJailedFilter: clientmetric.NewCounter("watchdog_timeout_setjailedfilter"), - SetStatusCallback: clientmetric.NewCounter("watchdog_timeout_setstatuscallback"), - UpdateStatus: clientmetric.NewCounter("watchdog_timeout_updatestatus"), - RequestStatus: clientmetric.NewCounter("watchdog_timeout_requeststatus"), - SetNetworkMap: clientmetric.NewCounter("watchdog_timeout_setnetworkmap"), - Ping: clientmetric.NewCounter("watchdog_timeout_ping"), - Close: clientmetric.NewCounter("watchdog_timeout_close"), - PeerForIPEvent: clientmetric.NewCounter("watchdog_timeout_peerforipevent"), - } -) - -// NewWatchdog wraps an Engine and makes sure that all methods complete -// within a reasonable amount of time. -// -// If they do not, the watchdog crashes the process. -func NewWatchdog(e Engine) Engine { - if envknob.Bool("TS_DEBUG_DISABLE_WATCHDOG") { - return e - } - return &watchdogEngine{ - wrap: e, - logf: log.Printf, - fatalf: log.Fatalf, - maxWait: 45 * time.Second, - inFlight: make(map[inFlightKey]time.Time), - } -} - -type inFlightKey struct { - op watchdogEvent - ctr uint64 -} - -type watchdogEngine struct { - wrap Engine - logf func(format string, args ...any) - fatalf func(format string, args ...any) - maxWait time.Duration - - // Track the start time(s) of in-flight operations - inFlightMu sync.Mutex - inFlight map[inFlightKey]time.Time - inFlightCtr uint64 -} - -func (e *watchdogEngine) watchdogErr(event watchdogEvent, fn func() error) error { - // Track all in-flight operations so we can print more useful error - // messages on watchdog failure - e.inFlightMu.Lock() - - key := inFlightKey{ - op: event, - ctr: e.inFlightCtr, - } - e.inFlightCtr++ - e.inFlight[key] = time.Now() - e.inFlightMu.Unlock() - - defer func() { - e.inFlightMu.Lock() - defer e.inFlightMu.Unlock() - delete(e.inFlight, key) - }() - - errCh := make(chan error) - go func() { - errCh <- fn() - }() - t := time.NewTimer(e.maxWait) - select { - case err := <-errCh: - t.Stop() - return err - case <-t.C: - buf := new(strings.Builder) - pprof.Lookup("goroutine").WriteTo(buf, 1) - e.logf("wgengine watchdog stacks:\n%s", buf.String()) - // Collect the list of in-flight operations for debugging. - var ( - b []byte - now = time.Now() - ) - e.inFlightMu.Lock() - for k, t := range e.inFlight { - dur := now.Sub(t).Round(time.Millisecond) - b = fmt.Appendf(b, "in-flight[%d]: name=%s duration=%v start=%s\n", k.ctr, k.op, dur, t.Format(time.RFC3339Nano)) - } - e.recordEvent(event) - e.inFlightMu.Unlock() - - // Print everything as a single string to avoid log - // rate limits. - e.logf("wgengine watchdog in-flight:\n%s", b) - e.fatalf("wgengine: watchdog timeout on %s", event) - return nil - } -} - -func (e *watchdogEngine) recordEvent(event watchdogEvent) { - if watchdogMetrics == nil { - return - } - - mEvent, ok := watchdogMetrics[event] - if ok { - mEvent.Add(1) - } - mAny, ok := watchdogMetrics[Any] - if ok { - mAny.Add(1) - } -} - -func (e *watchdogEngine) watchdog(event watchdogEvent, fn func()) { - e.watchdogErr(event, func() error { - fn() - return nil - }) -} - -func (e *watchdogEngine) Reconfig(cfg *wgcfg.Config, routerCfg *router.Config, dnsCfg *dns.Config) error { - return e.watchdogErr(Reconfig, func() error { return e.wrap.Reconfig(cfg, routerCfg, dnsCfg) }) -} - -func (e *watchdogEngine) ResetAndStop() (st *Status, err error) { - e.watchdog(ResetAndStop, func() { - st, err = e.wrap.ResetAndStop() - }) - return st, err -} - -func (e *watchdogEngine) GetFilter() *filter.Filter { - return e.wrap.GetFilter() -} - -func (e *watchdogEngine) SetFilter(filt *filter.Filter) { - e.watchdog(SetFilter, func() { e.wrap.SetFilter(filt) }) -} - -func (e *watchdogEngine) GetJailedFilter() *filter.Filter { - return e.wrap.GetJailedFilter() -} - -func (e *watchdogEngine) SetJailedFilter(filt *filter.Filter) { - e.watchdog(SetJailedFilter, func() { e.wrap.SetJailedFilter(filt) }) -} - -func (e *watchdogEngine) SetStatusCallback(cb StatusCallback) { - e.watchdog(SetStatusCallback, func() { e.wrap.SetStatusCallback(cb) }) -} - -func (e *watchdogEngine) UpdateStatus(sb *ipnstate.StatusBuilder) { - e.watchdog(UpdateStatus, func() { e.wrap.UpdateStatus(sb) }) -} - -func (e *watchdogEngine) RequestStatus() { - e.watchdog(RequestStatus, func() { e.wrap.RequestStatus() }) -} - -func (e *watchdogEngine) SetNetworkMap(nm *netmap.NetworkMap) { - e.watchdog(SetNetworkMap, func() { e.wrap.SetNetworkMap(nm) }) -} - -func (e *watchdogEngine) Ping(ip netip.Addr, pingType tailcfg.PingType, size int, cb func(*ipnstate.PingResult)) { - e.watchdog(Ping, func() { e.wrap.Ping(ip, pingType, size, cb) }) -} - -func (e *watchdogEngine) Close() { - e.watchdog(Close, e.wrap.Close) -} - -func (e *watchdogEngine) PeerForIP(ip netip.Addr) (ret PeerForIP, ok bool) { - e.watchdog(PeerForIPEvent, func() { ret, ok = e.wrap.PeerForIP(ip) }) - return ret, ok -} - -func (e *watchdogEngine) Done() <-chan struct{} { - return e.wrap.Done() -} - -func (e *watchdogEngine) InstallCaptureHook(cb packet.CaptureCallback) { - if !buildfeatures.HasCapture { - return - } - e.wrap.InstallCaptureHook(cb) -} - -func (e *watchdogEngine) PeerByKey(pubKey key.NodePublic) (_ wgint.Peer, ok bool) { - return e.wrap.PeerByKey(pubKey) -} - -func (e *watchdogEngine) PatchDiscoKey(pub key.NodePublic, disco key.DiscoPublic) { - // PatchDiscoKey mirrors the implementation of [controlclient.patchDiscoKeyer ]. - // It is implemented here to avoid the dependency edge to controlclient, but must be kept - // in sync with the original implementation. - type patchDiscoKeyer interface { - PatchDiscoKey(key.NodePublic, key.DiscoPublic) - } - if n, ok := e.wrap.(patchDiscoKeyer); ok { - n.PatchDiscoKey(pub, disco) - } -} diff --git a/wgengine/watchdog_omit.go b/wgengine/watchdog_omit.go deleted file mode 100644 index b4ed43442..000000000 --- a/wgengine/watchdog_omit.go +++ /dev/null @@ -1,8 +0,0 @@ -// Copyright (c) Tailscale Inc & contributors -// SPDX-License-Identifier: BSD-3-Clause - -//go:build js || ts_omit_debug - -package wgengine - -func NewWatchdog(e Engine) Engine { return e } diff --git a/wgengine/watchdog_test.go b/wgengine/watchdog_test.go deleted file mode 100644 index a0ce9cf07..000000000 --- a/wgengine/watchdog_test.go +++ /dev/null @@ -1,141 +0,0 @@ -// Copyright (c) Tailscale Inc & contributors -// SPDX-License-Identifier: BSD-3-Clause - -//go:build !js - -package wgengine - -import ( - "runtime" - "sync" - "testing" - "time" - - "tailscale.com/health" - "tailscale.com/util/eventbus/eventbustest" - "tailscale.com/util/usermetric" -) - -func TestWatchdog(t *testing.T) { - t.Parallel() - - var maxWaitMultiple time.Duration = 1 - if runtime.GOOS == "darwin" { - // Work around slow close syscalls on Big Sur with content filter Network Extensions installed. - // See https://github.com/tailscale/tailscale/issues/1598. - maxWaitMultiple = 15 - } - - t.Run("default-watchdog-does-not-fire", func(t *testing.T) { - t.Parallel() - bus := eventbustest.NewBus(t) - ht := health.NewTracker(bus) - reg := new(usermetric.Registry) - e, err := NewFakeUserspaceEngine(t.Logf, 0, ht, reg, bus) - if err != nil { - t.Fatal(err) - } - - e = NewWatchdog(e) - e.(*watchdogEngine).maxWait = maxWaitMultiple * 150 * time.Millisecond - e.(*watchdogEngine).logf = t.Logf - e.(*watchdogEngine).fatalf = t.Fatalf - - e.RequestStatus() - e.RequestStatus() - e.RequestStatus() - e.Close() - }) -} - -func TestWatchdogMetrics(t *testing.T) { - tests := []struct { - name string - events []watchdogEvent - wantCounts map[watchdogEvent]int64 - }{ - { - name: "single-event-types", - events: []watchdogEvent{RequestStatus, PeerForIPEvent, Ping}, - wantCounts: map[watchdogEvent]int64{ - RequestStatus: 1, - PeerForIPEvent: 1, - Ping: 1, - }, - }, - { - name: "repeated-events", - events: []watchdogEvent{RequestStatus, RequestStatus, Ping, RequestStatus}, - wantCounts: map[watchdogEvent]int64{ - RequestStatus: 3, - Ping: 1, - }, - }, - } - - // For swallowing fatalf calls and stack traces - logf := func(format string, args ...any) {} - - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - clearMetrics(t) - bus := eventbustest.NewBus(t) - ht := health.NewTracker(bus) - reg := new(usermetric.Registry) - e, err := NewFakeUserspaceEngine(logf, 0, ht, reg, bus) - if err != nil { - t.Fatal(err) - } - e = NewWatchdog(e) - w := e.(*watchdogEngine) - w.maxWait = 1 * time.Microsecond - w.logf = logf - w.fatalf = logf - - var wg sync.WaitGroup - wg.Add(len(tt.events)) - - for _, ev := range tt.events { - blocked := make(chan struct{}) - w.watchdog(ev, func() { - defer wg.Done() - <-blocked - }) - close(blocked) - } - wg.Wait() - - // Check individual event counts - for ev, want := range tt.wantCounts { - m, ok := watchdogMetrics[ev] - if !ok { - t.Fatalf("no metric found for event %q", ev) - } - got := m.Value() - if got != want { - t.Errorf("got %d metric events for %q, want %d", got, ev, want) - } - } - - // Check total count for Any - m, ok := watchdogMetrics[Any] - if !ok { - t.Fatalf("no Any metric found") - } - got := m.Value() - if got != int64(len(tt.events)) { - t.Errorf("got %d metric events for Any, want %d", got, len(tt.events)) - } - }) - } -} - -func clearMetrics(t *testing.T) { - t.Helper() - if watchdogMetrics == nil { - return - } - for _, m := range watchdogMetrics { - m.Set(0) - } -} diff --git a/wgengine/wgcfg/config.go b/wgengine/wgcfg/config.go index 782812139..5510b65b2 100644 --- a/wgengine/wgcfg/config.go +++ b/wgengine/wgcfg/config.go @@ -53,11 +53,6 @@ type Peer struct { V6MasqAddr *netip.Addr // if non-nil, masquerade IPv6 traffic to this peer using this address IsJailed bool // if true, this peer is jailed and cannot initiate connections PersistentKeepalive uint16 // in seconds between keep-alives; 0 to disable - // wireguard-go's endpoint for this peer. It should always equal Peer.PublicKey. - // We represent it explicitly so that we can detect if they diverge and recover. - // There is no need to set WGEndpoint explicitly when constructing a Peer by hand. - // It is only populated when reading Peers from wireguard-go. - WGEndpoint key.NodePublic } func addrPtrEq(a, b *netip.Addr) bool { @@ -74,8 +69,7 @@ func (p Peer) Equal(o Peer) bool { p.IsJailed == o.IsJailed && p.PersistentKeepalive == o.PersistentKeepalive && addrPtrEq(p.V4MasqAddr, o.V4MasqAddr) && - addrPtrEq(p.V6MasqAddr, o.V6MasqAddr) && - p.WGEndpoint == o.WGEndpoint + addrPtrEq(p.V6MasqAddr, o.V6MasqAddr) } // PeerWithKey returns the Peer with key k and reports whether it was found. diff --git a/wgengine/wgcfg/config_test.go b/wgengine/wgcfg/config_test.go index 7059b17b2..013d3a4b4 100644 --- a/wgengine/wgcfg/config_test.go +++ b/wgengine/wgcfg/config_test.go @@ -30,7 +30,7 @@ func TestPeerEqual(t *testing.T) { for sf := range rt.Fields() { switch sf.Name { case "PublicKey", "DiscoKey", "AllowedIPs", "IsJailed", - "PersistentKeepalive", "V4MasqAddr", "V6MasqAddr", "WGEndpoint": + "PersistentKeepalive", "V4MasqAddr", "V6MasqAddr": // These are compared in [Peer.Equal]. default: t.Errorf("Have you added field %q to Peer.Equal? Do so if not, and then update TestPeerEqual", sf.Name) diff --git a/wgengine/wgcfg/device.go b/wgengine/wgcfg/device.go index ba29cfbdc..ed32f8337 100644 --- a/wgengine/wgcfg/device.go +++ b/wgengine/wgcfg/device.go @@ -4,9 +4,8 @@ package wgcfg import ( - "errors" - "io" - "sort" + "fmt" + "net/netip" "github.com/tailscale/wireguard-go/conn" "github.com/tailscale/wireguard-go/device" @@ -21,27 +20,15 @@ func NewDevice(tunDev tun.Device, bind conn.Bind, logger *device.Logger) *device return ret } -func DeviceConfig(d *device.Device) (*Config, error) { - r, w := io.Pipe() - errc := make(chan error, 1) - go func() { - errc <- d.IpcGetOperation(w) - w.Close() - }() - cfg, fromErr := FromUAPI(r) - r.Close() - getErr := <-errc - err := errors.Join(getErr, fromErr) - if err != nil { - return nil, err - } - sort.Slice(cfg.Peers, func(i, j int) bool { - return cfg.Peers[i].PublicKey.Less(cfg.Peers[j].PublicKey) - }) - return cfg, nil -} - // ReconfigDevice replaces the existing device configuration with cfg. +// +// Instead of using the UAPI text protocol, it uses the wireguard-go direct API +// to install a [device.PeerLookupFunc] callback that creates peers on demand. +// +// The caller is responsible for: +// - calling [device.Device.SetPrivateKey] when the key changes +// - installing a [device.PeerByIPPacketFunc] on the device for outbound +// packet routing (e.g. via [tailscale.com/wgengine.Engine.SetPeerByIPPacketFunc]) func ReconfigDevice(d *device.Device, cfg *Config, logf logger.Logf) (err error) { defer func() { if err != nil { @@ -49,20 +36,52 @@ func ReconfigDevice(d *device.Device, cfg *Config, logf logger.Logf) (err error) } }() - prev, err := DeviceConfig(d) - if err != nil { - return err + // Build peer map: public key → allowed IPs. + peers := make(map[device.NoisePublicKey][]netip.Prefix, len(cfg.Peers)) + for _, p := range cfg.Peers { + peers[p.PublicKey.Raw32()] = p.AllowedIPs } - r, w := io.Pipe() - errc := make(chan error, 1) - go func() { - errc <- d.IpcSetOperation(r) - r.Close() - }() + // Remove peers not in the new config. + d.RemoveMatchingPeers(func(pk device.NoisePublicKey) bool { + _, exists := peers[pk] + return !exists + }) - toErr := cfg.ToUAPI(logf, w, prev) - w.Close() - setErr := <-errc - return errors.Join(setErr, toErr) + // Update AllowedIPs on any already-active peers whose config may have + // changed. Peers that don't exist yet will get the correct AllowedIPs + // from PeerLookupFunc when they are lazily created. + for pk, allowedIPs := range peers { + if peer, ok := d.LookupActivePeer(pk); ok { + peer.SetAllowedIPs(allowedIPs) + } + } + + // Install callback for lazy peer creation (incoming packets). + bind := d.Bind() + d.SetPeerLookupFunc(func(pubk device.NoisePublicKey) (_ *device.NewPeerConfig, ok bool) { + allowedIPs, ok := peers[pubk] + if !ok { + return nil, false + } + ep, err := bind.ParseEndpoint(fmt.Sprintf("%x", pubk[:])) + if err != nil { + logf("wgcfg: failed to parse endpoint for peer %x: %v", pubk[:8], err) + return nil, false + } + return &device.NewPeerConfig{ + AllowedIPs: allowedIPs, + Endpoint: ep, + }, true + }) + + // RemoveMatchingPeers _again_, now that SetPeerLookupFunc is installed, + // lest any removed peers got re-created before the new SetPeerLookupFunc + // func was installed. + d.RemoveMatchingPeers(func(pk device.NoisePublicKey) bool { + _, exists := peers[pk] + return !exists + }) + + return nil } diff --git a/wgengine/wgcfg/device_test.go b/wgengine/wgcfg/device_test.go index 507f22311..07eb41adb 100644 --- a/wgengine/wgcfg/device_test.go +++ b/wgengine/wgcfg/device_test.go @@ -4,33 +4,22 @@ package wgcfg import ( - "bufio" - "bytes" "io" "net/netip" "os" - "sort" - "strings" - "sync" "testing" "github.com/tailscale/wireguard-go/conn" "github.com/tailscale/wireguard-go/device" "github.com/tailscale/wireguard-go/tun" - "go4.org/mem" "tailscale.com/types/key" ) -func TestDeviceConfig(t *testing.T) { - newK := func() (key.NodePublic, key.NodePrivate) { - t.Helper() - k := key.NewNode() - return k.Public(), k - } +func TestReconfigDevice(t *testing.T) { k1, pk1 := newK() ip1 := netip.MustParsePrefix("10.0.0.1/32") - k2, pk2 := newK() + k2, _ := newK() ip2 := netip.MustParsePrefix("10.0.0.2/32") k3, _ := newK() @@ -38,165 +27,80 @@ func TestDeviceConfig(t *testing.T) { cfg1 := &Config{ PrivateKey: pk1, - Peers: []Peer{{ - PublicKey: k2, - AllowedIPs: []netip.Prefix{ip2}, - }}, + Peers: []Peer{ + {PublicKey: k2, AllowedIPs: []netip.Prefix{ip2}}, + }, } - cfg2 := &Config{ - PrivateKey: pk2, - Peers: []Peer{{ - PublicKey: k1, - AllowedIPs: []netip.Prefix{ip1}, - PersistentKeepalive: 5, - }}, - } + dev := NewDevice(newNilTun(), new(noopBind), device.NewLogger(device.LogLevelError, "test")) + defer dev.Close() - device1 := NewDevice(newNilTun(), new(noopBind), device.NewLogger(device.LogLevelError, "device1")) - device2 := NewDevice(newNilTun(), new(noopBind), device.NewLogger(device.LogLevelError, "device2")) - defer device1.Close() - defer device2.Close() - - cmp := func(t *testing.T, d *device.Device, want *Config) { - t.Helper() - got, err := DeviceConfig(d) - if err != nil { + t.Run("initial-config", func(t *testing.T) { + if err := ReconfigDevice(dev, cfg1, t.Logf); err != nil { t.Fatal(err) } - prev := new(Config) - gotbuf := new(strings.Builder) - err = got.ToUAPI(t.Logf, gotbuf, prev) - gotStr := gotbuf.String() - if err != nil { - t.Errorf("got.ToUAPI(): error: %v", err) - return + // Peer should be creatable on demand via LookupPeer. + peer := dev.LookupPeer(k2.Raw32()) + if peer == nil { + t.Fatal("expected peer k2 to exist via LookupPeer") } - wantbuf := new(strings.Builder) - err = want.ToUAPI(t.Logf, wantbuf, prev) - wantStr := wantbuf.String() - if err != nil { - t.Errorf("want.ToUAPI(): error: %v", err) - return + // Unknown peer should not be found. + peer = dev.LookupPeer(k3.Raw32()) + if peer != nil { + t.Fatal("expected unknown peer k3 to not exist") } - if gotStr != wantStr { - buf := new(bytes.Buffer) - w := bufio.NewWriter(buf) - if err := d.IpcGetOperation(w); err != nil { - t.Errorf("on error, could not IpcGetOperation: %v", err) - } - w.Flush() - t.Errorf("config mismatch:\n---- got:\n%s\n---- want:\n%s\n---- uapi:\n%s", gotStr, wantStr, buf.String()) - } - } - - t.Run("device1-config", func(t *testing.T) { - if err := ReconfigDevice(device1, cfg1, t.Logf); err != nil { - t.Fatal(err) - } - cmp(t, device1, cfg1) }) - t.Run("device2-config", func(t *testing.T) { - if err := ReconfigDevice(device2, cfg2, t.Logf); err != nil { - t.Fatal(err) - } - cmp(t, device2, cfg2) - }) - - // This is only to test that Config and Reconfig are properly synchronized. - t.Run("device2-config-reconfig", func(t *testing.T) { - var wg sync.WaitGroup - wg.Add(2) - - go func() { - ReconfigDevice(device2, cfg2, t.Logf) - wg.Done() - }() - - go func() { - DeviceConfig(device2) - wg.Done() - }() - - wg.Wait() - }) - - t.Run("device1-modify-peer", func(t *testing.T) { - cfg1.Peers[0].DiscoKey = key.DiscoPublicFromRaw32(mem.B([]byte{0: 1, 31: 0})) - if err := ReconfigDevice(device1, cfg1, t.Logf); err != nil { - t.Fatal(err) - } - cmp(t, device1, cfg1) - }) - - t.Run("device1-replace-endpoint", func(t *testing.T) { - cfg1.Peers[0].DiscoKey = key.DiscoPublicFromRaw32(mem.B([]byte{0: 2, 31: 0})) - if err := ReconfigDevice(device1, cfg1, t.Logf); err != nil { - t.Fatal(err) - } - cmp(t, device1, cfg1) - }) - - t.Run("device1-add-new-peer", func(t *testing.T) { + t.Run("add-peer", func(t *testing.T) { cfg1.Peers = append(cfg1.Peers, Peer{ PublicKey: k3, AllowedIPs: []netip.Prefix{ip3}, }) - sort.Slice(cfg1.Peers, func(i, j int) bool { - return cfg1.Peers[i].PublicKey.Less(cfg1.Peers[j].PublicKey) - }) - - origCfg, err := DeviceConfig(device1) - if err != nil { + if err := ReconfigDevice(dev, cfg1, t.Logf); err != nil { t.Fatal(err) } - - if err := ReconfigDevice(device1, cfg1, t.Logf); err != nil { - t.Fatal(err) + // Both peers should now be discoverable. + if p := dev.LookupPeer(k2.Raw32()); p == nil { + t.Fatal("expected peer k2 to exist") } - cmp(t, device1, cfg1) - - newCfg, err := DeviceConfig(device1) - if err != nil { - t.Fatal(err) - } - - peer0 := func(cfg *Config) Peer { - p, ok := cfg.PeerWithKey(k2) - if !ok { - t.Helper() - t.Fatal("failed to look up peer 2") - } - return p - } - peersEqual := func(p, q Peer) bool { - return p.PublicKey == q.PublicKey && p.DiscoKey == q.DiscoKey && p.PersistentKeepalive == q.PersistentKeepalive && cidrsEqual(p.AllowedIPs, q.AllowedIPs) - } - if !peersEqual(peer0(origCfg), peer0(newCfg)) { - t.Error("reconfig modified old peer") + if p := dev.LookupPeer(k3.Raw32()); p == nil { + t.Fatal("expected peer k3 to exist") } }) - t.Run("device1-remove-peer", func(t *testing.T) { - removeKey := cfg1.Peers[len(cfg1.Peers)-1].PublicKey - cfg1.Peers = cfg1.Peers[:len(cfg1.Peers)-1] - - if err := ReconfigDevice(device1, cfg1, t.Logf); err != nil { + t.Run("remove-peer", func(t *testing.T) { + cfg2 := &Config{ + PrivateKey: pk1, + Peers: []Peer{ + {PublicKey: k2, AllowedIPs: []netip.Prefix{ip2}}, + }, + } + if err := ReconfigDevice(dev, cfg2, t.Logf); err != nil { t.Fatal(err) } - cmp(t, device1, cfg1) - - newCfg, err := DeviceConfig(device1) - if err != nil { - t.Fatal(err) + // k2 should still be discoverable. + if p := dev.LookupPeer(k2.Raw32()); p == nil { + t.Fatal("expected peer k2 to exist") } - - _, ok := newCfg.PeerWithKey(removeKey) - if ok { - t.Error("reconfig failed to remove peer") + // k3 should no longer be discoverable. + if p := dev.LookupPeer(k3.Raw32()); p != nil { + t.Fatal("expected peer k3 to not exist after removal") } }) + + t.Run("self-key-not-peer", func(t *testing.T) { + // The device's own key should not be a peer. + if p := dev.LookupPeer(k1.Raw32()); p != nil { + t.Fatal("expected own key to not be a peer") + } + }) + + _ = ip1 // suppress unused +} + +func newK() (key.NodePublic, key.NodePrivate) { + k := key.NewNode() + return k.Public(), k } // TODO: replace with a loopback tunnel diff --git a/wgengine/wgcfg/parser.go b/wgengine/wgcfg/parser.go deleted file mode 100644 index 8fb921409..000000000 --- a/wgengine/wgcfg/parser.go +++ /dev/null @@ -1,186 +0,0 @@ -// Copyright (c) Tailscale Inc & contributors -// SPDX-License-Identifier: BSD-3-Clause - -package wgcfg - -import ( - "bufio" - "fmt" - "io" - "net" - "net/netip" - "strconv" - "strings" - - "go4.org/mem" - "tailscale.com/types/key" -) - -type ParseError struct { - why string - offender string -} - -func (e *ParseError) Error() string { - return fmt.Sprintf("%s: %q", e.why, e.offender) -} - -func parseEndpoint(s string) (host string, port uint16, err error) { - i := strings.LastIndexByte(s, ':') - if i < 0 { - return "", 0, &ParseError{"Missing port from endpoint", s} - } - host, portStr := s[:i], s[i+1:] - if len(host) < 1 { - return "", 0, &ParseError{"Invalid endpoint host", host} - } - uport, err := strconv.ParseUint(portStr, 10, 16) - if err != nil { - return "", 0, err - } - hostColon := strings.IndexByte(host, ':') - if host[0] == '[' || host[len(host)-1] == ']' || hostColon > 0 { - err := &ParseError{"Brackets must contain an IPv6 address", host} - if len(host) > 3 && host[0] == '[' && host[len(host)-1] == ']' && hostColon > 0 { - maybeV6 := net.ParseIP(host[1 : len(host)-1]) - if maybeV6 == nil || len(maybeV6) != net.IPv6len { - return "", 0, err - } - } else { - return "", 0, err - } - host = host[1 : len(host)-1] - } - return host, uint16(uport), nil -} - -// memROCut separates a mem.RO at the separator if it exists, otherwise -// it returns two empty ROs and reports that it was not found. -func memROCut(s mem.RO, sep byte) (before, after mem.RO, found bool) { - if i := mem.IndexByte(s, sep); i >= 0 { - return s.SliceTo(i), s.SliceFrom(i + 1), true - } - found = false - return -} - -// FromUAPI generates a Config from r. -// r should be generated by calling device.IpcGetOperation; -// it is not compatible with other uapi streams. -func FromUAPI(r io.Reader) (*Config, error) { - cfg := new(Config) - var peer *Peer // current peer being operated on - deviceConfig := true - - scanner := bufio.NewScanner(r) - for scanner.Scan() { - line := mem.B(scanner.Bytes()) - if line.Len() == 0 { - continue - } - key, value, ok := memROCut(line, '=') - if !ok { - return nil, fmt.Errorf("failed to cut line %q on =", line.StringCopy()) - } - valueBytes := scanner.Bytes()[key.Len()+1:] - - if key.EqualString("public_key") { - if deviceConfig { - deviceConfig = false - } - // Load/create the peer we are now configuring. - var err error - peer, err = cfg.handlePublicKeyLine(valueBytes) - if err != nil { - return nil, err - } - continue - } - - var err error - if deviceConfig { - err = cfg.handleDeviceLine(key, value, valueBytes) - } else { - err = cfg.handlePeerLine(peer, key, value, valueBytes) - } - if err != nil { - return nil, err - } - } - - if err := scanner.Err(); err != nil { - return nil, err - } - - return cfg, nil -} - -func (cfg *Config) handleDeviceLine(k, value mem.RO, valueBytes []byte) error { - switch { - case k.EqualString("private_key"): - // wireguard-go guarantees not to send zero value; private keys are already clamped. - var err error - cfg.PrivateKey, err = key.ParseNodePrivateUntyped(value) - if err != nil { - return err - } - case k.EqualString("listen_port") || k.EqualString("fwmark"): - // ignore - default: - return fmt.Errorf("unexpected IpcGetOperation key: %q", k.StringCopy()) - } - return nil -} - -func (cfg *Config) handlePublicKeyLine(valueBytes []byte) (*Peer, error) { - p := Peer{} - var err error - p.PublicKey, err = key.ParseNodePublicUntyped(mem.B(valueBytes)) - if err != nil { - return nil, err - } - cfg.Peers = append(cfg.Peers, p) - return &cfg.Peers[len(cfg.Peers)-1], nil -} - -func (cfg *Config) handlePeerLine(peer *Peer, k, value mem.RO, valueBytes []byte) error { - switch { - case k.EqualString("endpoint"): - nk, err := key.ParseNodePublicUntyped(value) - if err != nil { - return fmt.Errorf("invalid endpoint %q for peer %q, expected a hex public key", value.StringCopy(), peer.PublicKey.ShortString()) - } - // nk ought to equal peer.PublicKey. - // Under some rare circumstances, it might not. See corp issue #3016. - // Even if that happens, don't stop early, so that we can recover from it. - // Instead, note the value of nk so we can fix as needed. - peer.WGEndpoint = nk - case k.EqualString("persistent_keepalive_interval"): - n, err := mem.ParseUint(value, 10, 16) - if err != nil { - return err - } - peer.PersistentKeepalive = uint16(n) - case k.EqualString("allowed_ip"): - ipp := netip.Prefix{} - err := ipp.UnmarshalText(valueBytes) - if err != nil { - return err - } - peer.AllowedIPs = append(peer.AllowedIPs, ipp) - case k.EqualString("protocol_version"): - if !value.EqualString("1") { - return fmt.Errorf("invalid protocol version: %q", value.StringCopy()) - } - case k.EqualString("replace_allowed_ips") || - k.EqualString("preshared_key") || - k.EqualString("last_handshake_time_sec") || - k.EqualString("last_handshake_time_nsec") || - k.EqualString("tx_bytes") || - k.EqualString("rx_bytes"): - // ignore - default: - return fmt.Errorf("unexpected IpcGetOperation key: %q", k.StringCopy()) - } - return nil -} diff --git a/wgengine/wgcfg/parser_test.go b/wgengine/wgcfg/parser_test.go deleted file mode 100644 index 8c38ec025..000000000 --- a/wgengine/wgcfg/parser_test.go +++ /dev/null @@ -1,95 +0,0 @@ -// Copyright (c) Tailscale Inc & contributors -// SPDX-License-Identifier: BSD-3-Clause - -package wgcfg - -import ( - "bufio" - "bytes" - "io" - "net/netip" - "reflect" - "runtime" - "testing" - - "tailscale.com/types/key" -) - -func noError(t *testing.T, err error) bool { - if err == nil { - return true - } - _, fn, line, _ := runtime.Caller(1) - t.Errorf("Error at %s:%d: %#v", fn, line, err) - return false -} - -func equal(t *testing.T, expected, actual any) bool { - if reflect.DeepEqual(expected, actual) { - return true - } - _, fn, line, _ := runtime.Caller(1) - t.Errorf("Failed equals at %s:%d\nactual %#v\nexpected %#v", fn, line, actual, expected) - return false -} - -func TestParseEndpoint(t *testing.T) { - _, _, err := parseEndpoint("[192.168.42.0:]:51880") - if err == nil { - t.Error("Error was expected") - } - host, port, err := parseEndpoint("192.168.42.0:51880") - if noError(t, err) { - equal(t, "192.168.42.0", host) - equal(t, uint16(51880), port) - } - host, port, err = parseEndpoint("test.wireguard.com:18981") - if noError(t, err) { - equal(t, "test.wireguard.com", host) - equal(t, uint16(18981), port) - } - host, port, err = parseEndpoint("[2607:5300:60:6b0::c05f:543]:2468") - if noError(t, err) { - equal(t, "2607:5300:60:6b0::c05f:543", host) - equal(t, uint16(2468), port) - } - _, _, err = parseEndpoint("[::::::invalid:18981") - if err == nil { - t.Error("Error was expected") - } -} - -func BenchmarkFromUAPI(b *testing.B) { - newK := func() (key.NodePublic, key.NodePrivate) { - b.Helper() - k := key.NewNode() - return k.Public(), k - } - k1, pk1 := newK() - ip1 := netip.MustParsePrefix("10.0.0.1/32") - - peer := Peer{ - PublicKey: k1, - AllowedIPs: []netip.Prefix{ip1}, - } - cfg1 := &Config{ - PrivateKey: pk1, - Peers: []Peer{peer, peer, peer, peer}, - } - - buf := new(bytes.Buffer) - w := bufio.NewWriter(buf) - if err := cfg1.ToUAPI(b.Logf, w, &Config{}); err != nil { - b.Fatal(err) - } - w.Flush() - r := bytes.NewReader(buf.Bytes()) - b.ReportAllocs() - for range b.N { - r.Seek(0, io.SeekStart) - _, err := FromUAPI(r) - if err != nil { - b.Errorf("failed from UAPI: %v", err) - } - } -} diff --git a/wgengine/wgcfg/wgcfg_clone.go b/wgengine/wgcfg/wgcfg_clone.go index 9e8de7b6f..a8a212267 100644 --- a/wgengine/wgcfg/wgcfg_clone.go +++ b/wgengine/wgcfg/wgcfg_clone.go @@ -72,5 +72,4 @@ var _PeerCloneNeedsRegeneration = Peer(struct { V6MasqAddr *netip.Addr IsJailed bool PersistentKeepalive uint16 - WGEndpoint key.NodePublic }{}) diff --git a/wgengine/wgcfg/writer.go b/wgengine/wgcfg/writer.go deleted file mode 100644 index f4981e3e9..000000000 --- a/wgengine/wgcfg/writer.go +++ /dev/null @@ -1,154 +0,0 @@ -// Copyright (c) Tailscale Inc & contributors -// SPDX-License-Identifier: BSD-3-Clause - -package wgcfg - -import ( - "fmt" - "io" - "net/netip" - "strconv" - - "tailscale.com/types/key" - "tailscale.com/types/logger" -) - -// ToUAPI writes cfg in UAPI format to w. -// Prev is the previous device Config. -// -// Prev is required so that we can remove now-defunct peers without having to -// remove and re-add all peers, and so that we can avoid writing information -// about peers that have not changed since the previous time we wrote our -// Config. -func (cfg *Config) ToUAPI(logf logger.Logf, w io.Writer, prev *Config) error { - var stickyErr error - set := func(key, value string) { - if stickyErr != nil { - return - } - _, err := fmt.Fprintf(w, "%s=%s\n", key, value) - if err != nil { - stickyErr = err - } - } - setUint16 := func(key string, value uint16) { - set(key, strconv.FormatUint(uint64(value), 10)) - } - setPeer := func(peer Peer) { - set("public_key", peer.PublicKey.UntypedHexString()) - } - - // Device config. - if !prev.PrivateKey.Equal(cfg.PrivateKey) { - set("private_key", cfg.PrivateKey.UntypedHexString()) - } - - old := make(map[key.NodePublic]Peer) - for _, p := range prev.Peers { - old[p.PublicKey] = p - } - - // Add/configure all new peers. - for _, p := range cfg.Peers { - oldPeer, wasPresent := old[p.PublicKey] - - // We only want to write the peer header/version if we're about - // to change something about that peer, or if it's a new peer. - // Figure out up-front whether we'll need to do anything for - // this peer, and skip doing anything if not. - // - // If the peer was not present in the previous config, this - // implies that this is a new peer; set all of these to 'true' - // to ensure that we're writing the full peer configuration. - willSetEndpoint := oldPeer.WGEndpoint != p.PublicKey || !wasPresent - willChangeIPs := !cidrsEqual(oldPeer.AllowedIPs, p.AllowedIPs) || !wasPresent - willChangeKeepalive := oldPeer.PersistentKeepalive != p.PersistentKeepalive // if not wasPresent, no need to redundantly set zero (default) - - if !willSetEndpoint && !willChangeIPs && !willChangeKeepalive { - // It's safe to skip doing anything here; wireguard-go - // will not remove a peer if it's unspecified unless we - // tell it to (which we do below if necessary). - continue - } - - setPeer(p) - set("protocol_version", "1") - - // Avoid setting endpoints if the correct one is already known - // to WireGuard, because doing so generates a bit more work in - // calling magicsock's ParseEndpoint for effectively a no-op. - if willSetEndpoint { - if wasPresent { - // We had an endpoint, and it was wrong. - // By construction, this should not happen. - // If it does, keep going so that we can recover from it, - // but log so that we know about it, - // because it is an indicator of other failed invariants. - // See corp issue 3016. - logf("[unexpected] endpoint changed from %s to %s", oldPeer.WGEndpoint, p.PublicKey) - } - set("endpoint", p.PublicKey.UntypedHexString()) - } - - // TODO: replace_allowed_ips is expensive. - // If p.AllowedIPs is a strict superset of oldPeer.AllowedIPs, - // then skip replace_allowed_ips and instead add only - // the new ipps with allowed_ip. - if willChangeIPs { - set("replace_allowed_ips", "true") - for _, ipp := range p.AllowedIPs { - set("allowed_ip", ipp.String()) - } - } - - // Set PersistentKeepalive after the peer is otherwise configured, - // because it can trigger handshake packets. - if willChangeKeepalive { - setUint16("persistent_keepalive_interval", p.PersistentKeepalive) - } - } - - // Remove peers that were present but should no longer be. - for _, p := range cfg.Peers { - delete(old, p.PublicKey) - } - for _, p := range old { - setPeer(p) - set("remove", "true") - } - - if stickyErr != nil { - stickyErr = fmt.Errorf("ToUAPI: %w", stickyErr) - } - return stickyErr -} - -func cidrsEqual(x, y []netip.Prefix) bool { - // TODO: re-implement using netaddr.IPSet.Equal. - if len(x) != len(y) { - return false - } - // First see if they're equal in order, without allocating. - exact := true - for i := range x { - if x[i] != y[i] { - exact = false - break - } - } - if exact { - return true - } - - // Otherwise, see if they're the same, but out of order. - m := make(map[netip.Prefix]bool) - for _, v := range x { - m[v] = true - } - for _, v := range y { - if !m[v] { - return false - } - } - return true -} diff --git a/wgengine/wgengine.go b/wgengine/wgengine.go index 9dd782e4a..5ca4b75cf 100644 --- a/wgengine/wgengine.go +++ b/wgengine/wgengine.go @@ -137,4 +137,8 @@ type Engine interface { // packets traversing the data path. The hook can be uninstalled by // calling this function with a nil value. InstallCaptureHook(packet.CaptureCallback) + + // SetPeerByIPPacketFunc installs a callback used by wireguard-go to + // look up which peer should handle an outbound packet by destination IP. + SetPeerByIPPacketFunc(func(netip.Addr) (_ key.NodePublic, ok bool)) }