Updates tailscale/corp#27805 Updates tailscale/corp#27806 Updates tailscale/corp#37964 Change-Id: I7bb5ed7f258e840a8208e5d725c7b2f126d7ef96 Signed-off-by: Brad Fitzpatrick <bradfitz@tailscale.com>main
parent
120f27f383
commit
d42b3743b7
@ -0,0 +1,176 @@ |
||||
// Copyright (c) Tailscale Inc & contributors
|
||||
// SPDX-License-Identifier: BSD-3-Clause
|
||||
|
||||
// Package porttrack provides race-free ephemeral port assignment for
|
||||
// subprocess tests. The parent test process creates a [Collector] that
|
||||
// listens on a TCP port; the child process uses [Listen] which, when
|
||||
// given a magic address, binds to localhost:0 and reports the actual
|
||||
// port back to the collector.
|
||||
//
|
||||
// The magic address format is:
|
||||
//
|
||||
// testport-report:HOST:PORT/LABEL
|
||||
//
|
||||
// where HOST:PORT is the collector's TCP address and LABEL identifies
|
||||
// which listener this is (e.g. "main", "plaintext").
|
||||
//
|
||||
// When [Listen] is called with a non-magic address, it falls through to
|
||||
// [net.Listen] with zero overhead beyond a single [strings.HasPrefix]
|
||||
// check.
|
||||
package porttrack |
||||
|
||||
import ( |
||||
"bufio" |
||||
"context" |
||||
"fmt" |
||||
"net" |
||||
"strconv" |
||||
"strings" |
||||
"sync" |
||||
|
||||
"tailscale.com/util/testenv" |
||||
) |
||||
|
||||
const magicPrefix = "testport-report:" |
||||
|
||||
// Collector is the parent/test side of the porttrack protocol. It
|
||||
// listens for port reports from child processes that used [Listen]
|
||||
// with a magic address obtained from [Collector.Addr].
|
||||
type Collector struct { |
||||
ln net.Listener |
||||
mu sync.Mutex |
||||
cond *sync.Cond |
||||
ports map[string]int |
||||
err error // non-nil if a context passed to Port was cancelled
|
||||
} |
||||
|
||||
// NewCollector creates a new Collector. The collector's TCP listener is
|
||||
// closed when t finishes.
|
||||
func NewCollector(t testenv.TB) *Collector { |
||||
t.Helper() |
||||
ln, err := net.Listen("tcp", "127.0.0.1:0") |
||||
if err != nil { |
||||
t.Fatalf("porttrack.NewCollector: %v", err) |
||||
} |
||||
c := &Collector{ |
||||
ln: ln, |
||||
ports: make(map[string]int), |
||||
} |
||||
c.cond = sync.NewCond(&c.mu) |
||||
go c.accept(t) |
||||
t.Cleanup(func() { ln.Close() }) |
||||
return c |
||||
} |
||||
|
||||
// accept runs in a goroutine, accepting connections and parsing port
|
||||
// reports until the listener is closed.
|
||||
func (c *Collector) accept(t testenv.TB) { |
||||
for { |
||||
conn, err := c.ln.Accept() |
||||
if err != nil { |
||||
return // listener closed
|
||||
} |
||||
go c.handleConn(t, conn) |
||||
} |
||||
} |
||||
|
||||
func (c *Collector) handleConn(t testenv.TB, conn net.Conn) { |
||||
defer conn.Close() |
||||
scanner := bufio.NewScanner(conn) |
||||
for scanner.Scan() { |
||||
line := scanner.Text() |
||||
label, portStr, ok := strings.Cut(line, "\t") |
||||
if !ok { |
||||
t.Errorf("porttrack: malformed report line: %q", line) |
||||
return |
||||
} |
||||
port, err := strconv.Atoi(portStr) |
||||
if err != nil { |
||||
t.Errorf("porttrack: bad port in report %q: %v", line, err) |
||||
return |
||||
} |
||||
c.mu.Lock() |
||||
c.ports[label] = port |
||||
c.cond.Broadcast() |
||||
c.mu.Unlock() |
||||
} |
||||
} |
||||
|
||||
// Addr returns a magic address string that, when passed to [Listen],
|
||||
// causes the child to bind to localhost:0 and report its actual port
|
||||
// back to this collector under the given label.
|
||||
func (c *Collector) Addr(label string) string { |
||||
return magicPrefix + c.ln.Addr().String() + "/" + label |
||||
} |
||||
|
||||
// Port blocks until the child process has reported the port for the
|
||||
// given label, then returns it. If ctx is cancelled before a port is
|
||||
// reported, Port returns the context's cause as an error.
|
||||
func (c *Collector) Port(ctx context.Context, label string) (int, error) { |
||||
stop := context.AfterFunc(ctx, func() { |
||||
c.mu.Lock() |
||||
defer c.mu.Unlock() |
||||
if c.err == nil { |
||||
c.err = context.Cause(ctx) |
||||
} |
||||
c.cond.Broadcast() |
||||
}) |
||||
defer stop() |
||||
|
||||
c.mu.Lock() |
||||
defer c.mu.Unlock() |
||||
for { |
||||
if p, ok := c.ports[label]; ok { |
||||
return p, nil |
||||
} |
||||
if c.err != nil { |
||||
return 0, c.err |
||||
} |
||||
c.cond.Wait() |
||||
} |
||||
} |
||||
|
||||
// Listen is the child/production side of the porttrack protocol.
|
||||
//
|
||||
// If address has the magic prefix (as returned by [Collector.Addr]),
|
||||
// Listen binds to localhost:0 on the given network, then TCP-connects
|
||||
// to the collector and writes "LABEL\tPORT\n" to report the actual
|
||||
// port. The collector connection is closed before returning.
|
||||
//
|
||||
// If address does not have the magic prefix, Listen is simply
|
||||
// [net.Listen](network, address).
|
||||
func Listen(network, address string) (net.Listener, error) { |
||||
rest, ok := strings.CutPrefix(address, magicPrefix) |
||||
if !ok { |
||||
return net.Listen(network, address) |
||||
} |
||||
|
||||
// rest is "HOST:PORT/LABEL"
|
||||
slashIdx := strings.LastIndex(rest, "/") |
||||
if slashIdx < 0 { |
||||
return nil, fmt.Errorf("porttrack: malformed magic address %q: missing /LABEL", address) |
||||
} |
||||
collectorAddr := rest[:slashIdx] |
||||
label := rest[slashIdx+1:] |
||||
|
||||
ln, err := net.Listen(network, "localhost:0") |
||||
if err != nil { |
||||
return nil, err |
||||
} |
||||
|
||||
port := ln.Addr().(*net.TCPAddr).Port |
||||
|
||||
conn, err := net.Dial("tcp", collectorAddr) |
||||
if err != nil { |
||||
ln.Close() |
||||
return nil, fmt.Errorf("porttrack: failed to connect to collector at %s: %v", collectorAddr, err) |
||||
} |
||||
_, err = fmt.Fprintf(conn, "%s\t%d\n", label, port) |
||||
conn.Close() |
||||
if err != nil { |
||||
ln.Close() |
||||
return nil, fmt.Errorf("porttrack: failed to report port to collector: %v", err) |
||||
} |
||||
|
||||
return ln, nil |
||||
} |
||||
@ -0,0 +1,95 @@ |
||||
// Copyright (c) Tailscale Inc & contributors
|
||||
// SPDX-License-Identifier: BSD-3-Clause
|
||||
|
||||
package porttrack |
||||
|
||||
import ( |
||||
"context" |
||||
"errors" |
||||
"fmt" |
||||
"net" |
||||
"net/http" |
||||
"testing" |
||||
) |
||||
|
||||
func TestCollectorAndListen(t *testing.T) { |
||||
c := NewCollector(t) |
||||
|
||||
labels := []string{"main", "plaintext", "debug"} |
||||
ports := make([]int, len(labels)) |
||||
|
||||
for i, label := range labels { |
||||
ln, err := Listen("tcp", c.Addr(label)) |
||||
if err != nil { |
||||
t.Fatalf("Listen(%q): %v", label, err) |
||||
} |
||||
defer ln.Close() |
||||
p, err := c.Port(t.Context(), label) |
||||
if err != nil { |
||||
t.Fatalf("Port(%q): %v", label, err) |
||||
} |
||||
ports[i] = p |
||||
} |
||||
|
||||
// All ports should be distinct non-zero values.
|
||||
seen := map[int]string{} |
||||
for i, label := range labels { |
||||
if ports[i] == 0 { |
||||
t.Errorf("Port(%q) = 0", label) |
||||
} |
||||
if prev, ok := seen[ports[i]]; ok { |
||||
t.Errorf("Port(%q) = Port(%q) = %d", label, prev, ports[i]) |
||||
} |
||||
seen[ports[i]] = label |
||||
} |
||||
} |
||||
|
||||
func TestListenPassthrough(t *testing.T) { |
||||
ln, err := Listen("tcp", "localhost:0") |
||||
if err != nil { |
||||
t.Fatalf("Listen passthrough: %v", err) |
||||
} |
||||
defer ln.Close() |
||||
if ln.Addr().(*net.TCPAddr).Port == 0 { |
||||
t.Fatal("expected non-zero port") |
||||
} |
||||
} |
||||
|
||||
func TestRoundTrip(t *testing.T) { |
||||
c := NewCollector(t) |
||||
|
||||
ln, err := Listen("tcp", c.Addr("http")) |
||||
if err != nil { |
||||
t.Fatalf("Listen: %v", err) |
||||
} |
||||
defer ln.Close() |
||||
|
||||
// Start a server on the listener.
|
||||
go http.Serve(ln, http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { |
||||
w.WriteHeader(http.StatusNoContent) |
||||
})) |
||||
|
||||
port, err := c.Port(t.Context(), "http") |
||||
if err != nil { |
||||
t.Fatalf("Port: %v", err) |
||||
} |
||||
resp, err := http.Get(fmt.Sprintf("http://localhost:%d/", port)) |
||||
if err != nil { |
||||
t.Fatalf("http.Get: %v", err) |
||||
} |
||||
resp.Body.Close() |
||||
if resp.StatusCode != http.StatusNoContent { |
||||
t.Fatalf("status = %d, want %d", resp.StatusCode, http.StatusNoContent) |
||||
} |
||||
} |
||||
|
||||
func TestPortContextCancelled(t *testing.T) { |
||||
c := NewCollector(t) |
||||
// Nobody will ever report "never", so Port should block until ctx is done.
|
||||
ctx, cancel := context.WithCancel(t.Context()) |
||||
cancel() |
||||
_, err := c.Port(ctx, "never") |
||||
if !errors.Is(err, context.Canceled) { |
||||
t.Fatalf("Port with cancelled context: got %v, want %v", err, context.Canceled) |
||||
} |
||||
} |
||||
Loading…
Reference in new issue