tsnet: clean up state when Service listener is closed

Previous to this change, closing the listener returned by
Server.ListenService would free system resources, but not clean up state
in the Server's local backend. With this change, the local backend state
is now cleaned on close.

Fixes tailscale/corp#35860

Signed-off-by: Harry Harpham <harry@tailscale.com>
main
Harry Harpham 3 months ago
parent 1794765cc6
commit 4f43ad3042
  1. 267
      tsnet/tsnet.go
  2. 242
      tsnet/tsnet_test.go

@ -59,6 +59,7 @@ import (
"tailscale.com/types/logger" "tailscale.com/types/logger"
"tailscale.com/types/logid" "tailscale.com/types/logid"
"tailscale.com/types/nettype" "tailscale.com/types/nettype"
"tailscale.com/types/views"
"tailscale.com/util/clientmetric" "tailscale.com/util/clientmetric"
"tailscale.com/util/mak" "tailscale.com/util/mak"
"tailscale.com/util/set" "tailscale.com/util/set"
@ -175,32 +176,33 @@ type Server struct {
// This field must be set before calling Start. // This field must be set before calling Start.
Tun tun.Device Tun tun.Device
initOnce sync.Once initOnce sync.Once
initErr error initErr error
lb *ipnlocal.LocalBackend lb *ipnlocal.LocalBackend
sys *tsd.System sys *tsd.System
netstack *netstack.Impl netstack *netstack.Impl
netMon *netmon.Monitor netMon *netmon.Monitor
rootPath string // the state directory rootPath string // the state directory
hostname string hostname string
shutdownCtx context.Context shutdownCtx context.Context
shutdownCancel context.CancelFunc shutdownCancel context.CancelFunc
proxyCred string // SOCKS5 proxy auth for loopbackListener proxyCred string // SOCKS5 proxy auth for loopbackListener
localAPICred string // basic auth password for loopbackListener localAPICred string // basic auth password for loopbackListener
loopbackListener net.Listener // optional loopback for localapi and proxies loopbackListener net.Listener // optional loopback for localapi and proxies
localAPIListener net.Listener // in-memory, used by localClient localAPIListener net.Listener // in-memory, used by localClient
localClient *local.Client // in-memory localClient *local.Client // in-memory
localAPIServer *http.Server localAPIServer *http.Server
resetServeConfigOnce sync.Once resetServeStateOnce sync.Once
logbuffer *filch.Filch logbuffer *filch.Filch
logtail *logtail.Logger logtail *logtail.Logger
logid logid.PublicID logid logid.PublicID
mu sync.Mutex mu sync.Mutex
listeners map[listenKey]*listener listeners map[listenKey]*listener
nextEphemeralPort uint16 // next port to try in ephemeral range; 0 means use ephemeralPortFirst nextEphemeralPort uint16 // next port to try in ephemeral range; 0 means use ephemeralPortFirst
fallbackTCPHandlers set.HandleSet[FallbackTCPHandler] fallbackTCPHandlers set.HandleSet[FallbackTCPHandler]
dialer *tsdial.Dialer dialer *tsdial.Dialer
advertisedServices map[tailcfg.ServiceName]int
closeOnce sync.Once closeOnce sync.Once
} }
@ -415,15 +417,27 @@ func (s *Server) Up(ctx context.Context) (*ipnstate.Status, error) {
return nil, errors.New("tsnet.Up: running, but no ip") return nil, errors.New("tsnet.Up: running, but no ip")
} }
// The first time Up is run, clear the persisted serve config. // The first time Up is run, clear the persisted serve config
// We do this to prevent messy interactions with stale config in // and Service advertisements. We do this to prevent messy
// the face of code changes. // interactions with stale config in the face of code changes.
var srvResetErr error var srvCfgErr error
s.resetServeConfigOnce.Do(func() { var svcAdErr error
srvResetErr = lc.SetServeConfig(ctx, new(ipn.ServeConfig)) s.resetServeStateOnce.Do(func() {
if err := lc.SetServeConfig(ctx, new(ipn.ServeConfig)); err != nil {
srvCfgErr = fmt.Errorf("clearing serve config: %w", err)
}
_, err := s.lb.EditPrefs(&ipn.MaskedPrefs{
AdvertiseServicesSet: true,
Prefs: ipn.Prefs{
AdvertiseServices: []string{},
},
})
if err != nil {
svcAdErr = fmt.Errorf("clearing Service advertisements: %w", err)
}
}) })
if srvResetErr != nil { if err := errors.Join(srvCfgErr, svcAdErr); err != nil {
return nil, fmt.Errorf("tsnet.Up: clearing serve config: %w", err) return nil, fmt.Errorf("tsnet.Up: %w", err)
} }
return status, nil return status, nil
@ -1474,6 +1488,13 @@ type ServiceListener struct {
// FQDN is the fully-qualifed domain name of this Service. // FQDN is the fully-qualifed domain name of this Service.
FQDN string FQDN string
// Used by Close.
closeOnce sync.Once
closeErr error // written to during execution of closeOnce, read by Close()
s *Server // read and written to during execution of closeOnce
svcName tailcfg.ServiceName // read during execution of closeOnce
mode ServiceMode // read during execution of closeOnce
} }
// Addr returns the listener's network address. This will be the Service's // Addr returns the listener's network address. This will be the Service's
@ -1481,16 +1502,142 @@ type ServiceListener struct {
// //
// A hostname is not truly a network address, but Services listen on multiple // A hostname is not truly a network address, but Services listen on multiple
// addresses (the IPv4 and IPv6 virtual IPs). // addresses (the IPv4 and IPv6 virtual IPs).
func (sl ServiceListener) Addr() net.Addr { func (sl *ServiceListener) Addr() net.Addr {
return sl.addr return sl.addr
} }
// cleanServeConfig cleans serve config changes made to support this listener.
// This should only be called by Close.
func (sl *ServiceListener) cleanServeConfig() error {
sc, etag, err := sl.s.lb.ServeConfigETag()
if err != nil {
return fmt.Errorf("fetching current config: %w", err)
}
if !sc.Valid() || !sc.Services().Contains(sl.svcName) {
return nil
}
srvConfig := sc.AsStruct()
svcConfig := srvConfig.Services[sl.svcName]
switch m := sl.mode.(type) {
case ServiceModeTCP:
delete(svcConfig.TCP, m.Port)
case ServiceModeHTTP:
hp := net.JoinHostPort(sl.FQDN, strconv.Itoa(int(m.Port)))
delete(svcConfig.Web, ipn.HostPort(hp))
delete(svcConfig.TCP, m.Port)
default:
return fmt.Errorf("unexpected ServiceMode %T", sl.mode)
}
if err := sl.s.lb.SetServeConfig(srvConfig, etag); err != nil {
return fmt.Errorf("setting config: %w", err)
}
return nil
}
// Close closes the listener and clears state related to hosting the Service.
// Behavior is undefined after the [Server] has been closed.
func (sl *ServiceListener) Close() error {
// We should only clean up state once. Otherwise we can stomp on state
// created by new listeners.
sl.closeOnce.Do(func() {
// Two pieces of state we need to clear:
// 1. The Service advertisement pref
// 2. Artifacts in the serve config
// Then we can close the listener.
var adErr error
if err := sl.s.decrementServiceAdvertisement(sl.svcName); err != nil {
adErr = fmt.Errorf("managing Service advertisements: %w", err)
}
var srvCfgErr error
if err := sl.cleanServeConfig(); err != nil {
srvCfgErr = fmt.Errorf("cleaning config changes: %w", err)
}
sl.closeErr = errors.Join(sl.Listener.Close(), adErr, srvCfgErr)
})
return sl.closeErr
}
// ErrUntaggedServiceHost is returned by ListenService when run on a node // ErrUntaggedServiceHost is returned by ListenService when run on a node
// without any ACL tags. A node must use a tag-based identity to act as a // without any ACL tags. A node must use a tag-based identity to act as a
// Service host. For more information, see: // Service host. For more information, see:
// https://tailscale.com/kb/1552/tailscale-services#prerequisites // https://tailscale.com/kb/1552/tailscale-services#prerequisites
var ErrUntaggedServiceHost = errors.New("service hosts must be tagged nodes") var ErrUntaggedServiceHost = errors.New("service hosts must be tagged nodes")
// advertiseService ensures the Service is advertised by this node.
func (s *Server) advertiseService(name tailcfg.ServiceName) error {
s.mu.Lock()
defer s.mu.Unlock()
advertised := s.lb.Prefs().AdvertiseServices()
if !views.SliceContains(advertised, name.String()) {
newAdvertised := make([]string, 0, advertised.Len()+1)
advertised.AppendTo(newAdvertised)
newAdvertised = append(newAdvertised, name.String())
_, err := s.lb.EditPrefs(&ipn.MaskedPrefs{
AdvertiseServicesSet: true,
Prefs: ipn.Prefs{
AdvertiseServices: newAdvertised,
},
})
if err != nil {
return err
}
}
mak.Set(&s.advertisedServices, name, s.advertisedServices[name]+1)
return nil
}
// decrementServiceAdvertisement decrements the count of listeners this node has
// advertising the Service. Advertisement of the Service will be withdrawn if
// the count hits zero. It is an error to call this function when the Service is
// not being advertised by this node.
func (s *Server) decrementServiceAdvertisement(name tailcfg.ServiceName) error {
s.mu.Lock()
defer s.mu.Unlock()
cleanAdvertisement := func() error {
delete(s.advertisedServices, name)
advertised := s.lb.Prefs().AdvertiseServices()
if !views.SliceContains(advertised, name.String()) {
return nil
}
newAdvertised := make([]string, 0, advertised.Len()-1)
for _, svc := range advertised.All() {
if svc == name.String() {
continue
}
newAdvertised = append(newAdvertised, svc)
}
_, err := s.lb.EditPrefs(&ipn.MaskedPrefs{
AdvertiseServicesSet: true,
Prefs: ipn.Prefs{
AdvertiseServices: newAdvertised,
},
})
return err
}
if s.advertisedServices[name] <= 0 {
advertisements := s.advertisedServices[name]
// We somehow mismatched increments and decrements. Clear current
// advertisements and surface the mismatch as an error.
return errors.Join(
cleanAdvertisement(),
fmt.Errorf("service decrement requested with %d advertisements", advertisements),
)
}
s.advertisedServices[name]--
if s.advertisedServices[name] > 0 {
// If there are still listeners advertising the Service, then there's
// nothing more for us to do.
return nil
}
return cleanAdvertisement()
}
// ListenService creates a network listener for a Tailscale Service. This will // ListenService creates a network listener for a Tailscale Service. This will
// advertise this node as hosting the Service. Note that: // advertise this node as hosting the Service. Note that:
// - Approval must still be granted by an admin or by ACL auto-approval rules. // - Approval must still be granted by an admin or by ACL auto-approval rules.
@ -1503,13 +1650,22 @@ var ErrUntaggedServiceHost = errors.New("service hosts must be tagged nodes")
// //
// This function will start the server if it is not already started. // This function will start the server if it is not already started.
func (s *Server) ListenService(name string, mode ServiceMode) (*ServiceListener, error) { func (s *Server) ListenService(name string, mode ServiceMode) (*ServiceListener, error) {
if err := tailcfg.ServiceName(name).Validate(); err != nil { svcName := tailcfg.ServiceName(name)
if err := svcName.Validate(); err != nil {
return nil, err return nil, err
} }
if mode == nil { if mode == nil {
return nil, errors.New("mode may not be nil") return nil, errors.New("mode may not be nil")
} }
svcName := name
// We collect cleanup tasks as we go and execute these on error. If we make
// it to the end we abandon these cleanup tasks by setting onError to nil.
var onError []func()
defer func() {
for _, f := range onError {
f()
}
}()
// TODO(hwh33,tailscale/corp#35859): support TUN mode // TODO(hwh33,tailscale/corp#35859): support TUN mode
@ -1524,31 +1680,21 @@ func (s *Server) ListenService(name string, mode ServiceMode) (*ServiceListener,
return nil, ErrUntaggedServiceHost return nil, ErrUntaggedServiceHost
} }
advertisedServices := s.lb.Prefs().AdvertiseServices().AsSlice() if err := s.advertiseService(svcName); err != nil {
if !slices.Contains(advertisedServices, svcName) { return nil, fmt.Errorf("advertising Service: %w", err)
// TODO(hwh33,tailscale/corp#35860): clean these prefs up when (a) we
// exit early due to error or (b) when the returned listener is closed.
_, err = s.lb.EditPrefs(&ipn.MaskedPrefs{
AdvertiseServicesSet: true,
Prefs: ipn.Prefs{
AdvertiseServices: append(advertisedServices, svcName),
},
})
if err != nil {
return nil, fmt.Errorf("updating advertised Services: %w", err)
}
} }
onError = append(onError, func() { s.decrementServiceAdvertisement(svcName) })
srvConfig := new(ipn.ServeConfig) srvCfg := new(ipn.ServeConfig)
sc, srvConfigETag, err := s.lb.ServeConfigETag() sc, srvCfgETag, err := s.lb.ServeConfigETag()
if err != nil { if err != nil {
return nil, fmt.Errorf("fetching current serve config: %w", err) return nil, fmt.Errorf("fetching current serve config: %w", err)
} }
if sc.Valid() { if sc.Valid() {
srvConfig = sc.AsStruct() srvCfg = sc.AsStruct()
} }
fqdn := tailcfg.ServiceName(svcName).WithoutPrefix() + "." + st.CurrentTailnet.MagicDNSSuffix fqdn := svcName.WithoutPrefix() + "." + st.CurrentTailnet.MagicDNSSuffix
// svcAddr is used to implement Addr() on the returned listener. // svcAddr is used to implement Addr() on the returned listener.
svcAddr := addr{ svcAddr := addr{
@ -1564,6 +1710,13 @@ func (s *Server) ListenService(name string, mode ServiceMode) (*ServiceListener,
if m.port() == 0 { if m.port() == 0 {
return nil, errors.New("must specify a port to advertise") return nil, errors.New("must specify a port to advertise")
} }
if svcCfg, ok := srvCfg.Services[svcName]; ok {
if _, handlerExists := svcCfg.TCP[m.port()]; handlerExists {
// We know that a handler must have been started in this runtime
// because serve config is reset on the first [Server.Up].
return nil, errors.New("a Service handler already exists for this port")
}
}
svcAddr.addr += ":" + strconv.Itoa(int(m.port())) svcAddr.addr += ":" + strconv.Itoa(int(m.port()))
} }
@ -1572,11 +1725,12 @@ func (s *Server) ListenService(name string, mode ServiceMode) (*ServiceListener,
if err != nil { if err != nil {
return nil, fmt.Errorf("starting local listener: %w", err) return nil, fmt.Errorf("starting local listener: %w", err)
} }
onError = append(onError, func() { ln.Close() })
switch m := mode.(type) { switch m := mode.(type) {
case ServiceModeTCP: case ServiceModeTCP:
// Forward all connections from service-hostname:port to our socket. // Forward all connections from service-hostname:port to our socket.
srvConfig.SetTCPForwardingForService( srvCfg.SetTCPForwardingForService(
m.Port, ln.Addr().String(), m.TerminateTLS, m.Port, ln.Addr().String(), m.TerminateTLS,
tailcfg.ServiceName(svcName), m.PROXYProtocolVersion, st.CurrentTailnet.MagicDNSSuffix) tailcfg.ServiceName(svcName), m.PROXYProtocolVersion, st.CurrentTailnet.MagicDNSSuffix)
case ServiceModeHTTP: case ServiceModeHTTP:
@ -1597,30 +1751,29 @@ func (s *Server) ListenService(name string, mode ServiceMode) (*ServiceListener,
} else { } else {
h.Proxy += path h.Proxy += path
} }
srvConfig.SetWebHandler(&h, svcName, m.Port, path, m.HTTPS, mds) srvCfg.SetWebHandler(&h, svcName.String(), m.Port, path, m.HTTPS, mds)
} }
// We always need a root handler. // We always need a root handler.
if !haveRootHandler { if !haveRootHandler {
h := ipn.HTTPHandler{Proxy: ln.Addr().String()} h := ipn.HTTPHandler{Proxy: ln.Addr().String()}
srvConfig.SetWebHandler(&h, svcName, m.Port, "/", m.HTTPS, mds) srvCfg.SetWebHandler(&h, svcName.String(), m.Port, "/", m.HTTPS, mds)
} }
default: default:
ln.Close()
return nil, fmt.Errorf("unknown ServiceMode type %T", m) return nil, fmt.Errorf("unknown ServiceMode type %T", m)
} }
if err := s.lb.SetServeConfig(srvConfig, srvConfigETag); err != nil { if err := s.lb.SetServeConfig(srvCfg, srvCfgETag); err != nil {
ln.Close()
return nil, err return nil, err
} }
// TODO(hwh33,tailscale/corp#35860): clean up state (advertising prefs, onError = nil
// serve config changes) when the returned listener is closed.
return &ServiceListener{ return &ServiceListener{
Listener: ln, Listener: ln,
FQDN: fqdn, FQDN: fqdn,
addr: svcAddr, addr: svcAddr,
s: s,
svcName: svcName,
mode: mode,
}, nil }, nil
} }

@ -59,7 +59,6 @@ import (
"tailscale.com/types/key" "tailscale.com/types/key"
"tailscale.com/types/logger" "tailscale.com/types/logger"
"tailscale.com/types/netmap" "tailscale.com/types/netmap"
"tailscale.com/types/views"
"tailscale.com/util/mak" "tailscale.com/util/mak"
"tailscale.com/util/must" "tailscale.com/util/must"
) )
@ -877,7 +876,7 @@ func TestFunnelClose(t *testing.T) {
// To obtain config the listener might want to clobber, we: // To obtain config the listener might want to clobber, we:
// - run a listener // - run a listener
// - grab the config // - grab the config
// - close the listener (clearing config) // - close the listener (so we can run another on the same port)
ln := must.Get(s.ListenFunnel("tcp", ":443")) ln := must.Get(s.ListenFunnel("tcp", ":443"))
before := s.lb.ServeConfig() before := s.lb.ServeConfig()
ln.Close() ln.Close()
@ -935,10 +934,7 @@ func TestFunnelClose(t *testing.T) {
// The listener should immediately return an error indicating closure. // The listener should immediately return an error indicating closure.
_, err := ln.Accept() _, err := ln.Accept()
// Looking for a string in the error sucks, but it's supposed to stay if !errors.Is(err, net.ErrClosed) {
// consistent:
// https://github.com/golang/go/blob/108b333d510c1f60877ac917375d7931791acfe6/src/internal/poll/fd.go#L20-L24
if err == nil || !strings.Contains(err.Error(), "use of closed network connection") {
t.Fatal("expected listener to be closed, got:", err) t.Fatal("expected listener to be closed, got:", err)
} }
}) })
@ -947,24 +943,6 @@ func TestFunnelClose(t *testing.T) {
func TestListenService(t *testing.T) { func TestListenService(t *testing.T) {
tstest.Shard(t) tstest.Shard(t)
// First test an error case which doesn't require all of the fancy setup.
t.Run("untagged_node_error", func(t *testing.T) {
ctx := t.Context()
controlURL, _ := startControl(t)
serviceHost, _, _ := startServer(t, ctx, controlURL, "service-host")
ln, err := serviceHost.ListenService("svc:foo", ServiceModeTCP{Port: 8080})
if ln != nil {
ln.Close()
}
if !errors.Is(err, ErrUntaggedServiceHost) {
t.Fatalf("expected %v, got %v", ErrUntaggedServiceHost, err)
}
})
// Now on to the fancier tests.
type dialFn func(context.Context, string, string) (net.Conn, error) type dialFn func(context.Context, string, string) (net.Conn, error)
// TCP helpers // TCP helpers
@ -1243,11 +1221,10 @@ func TestListenService(t *testing.T) {
// The Service host must have the 'service-host' capability, which // The Service host must have the 'service-host' capability, which
// is a mapping from the Service name to the Service VIP. // is a mapping from the Service name to the Service VIP.
var serviceHostCaps map[tailcfg.ServiceName]views.Slice[netip.Addr]
mak.Set(&serviceHostCaps, serviceName, views.SliceOf([]netip.Addr{netip.MustParseAddr(serviceVIP)}))
j := must.Get(json.Marshal(serviceHostCaps))
cm := serviceHost.lb.NetMap().SelfNode.CapMap().AsMap() cm := serviceHost.lb.NetMap().SelfNode.CapMap().AsMap()
mak.Set(&cm, tailcfg.NodeAttrServiceHost, []tailcfg.RawMessage{tailcfg.RawMessage(j)}) mak.Set(&cm, tailcfg.NodeAttrServiceHost, []tailcfg.RawMessage{
tailcfg.RawMessage(fmt.Sprintf(`{"%s": ["%s"]}`, serviceName, serviceVIP)),
})
control.SetNodeCapMap(serviceHost.lb.NodeKey(), cm) control.SetNodeCapMap(serviceHost.lb.NodeKey(), cm)
// The Service host must be allowed to advertise the Service VIP. // The Service host must be allowed to advertise the Service VIP.
@ -1269,16 +1246,19 @@ func TestListenService(t *testing.T) {
}, },
})) }))
// Do the test's extra setup before configuring DNS. This allows
// us to use the configured DNS records as sentinel values when
// waiting for all of this setup to be visible to test nodes.
if tt.extraSetup != nil {
tt.extraSetup(t, control)
}
// Set up DNS for our Service. // Set up DNS for our Service.
control.AddDNSRecords(tailcfg.DNSRecord{ control.AddDNSRecords(tailcfg.DNSRecord{
Name: serviceName.WithoutPrefix() + "." + control.MagicDNSDomain, Name: serviceName.WithoutPrefix() + "." + control.MagicDNSDomain,
Value: serviceVIP, Value: serviceVIP,
}) })
if tt.extraSetup != nil {
tt.extraSetup(t, control)
}
// Wait until both nodes have up-to-date netmaps before // Wait until both nodes have up-to-date netmaps before
// proceeding with the test. // proceeding with the test.
netmapUpToDate := func(nm *netmap.NetworkMap) bool { netmapUpToDate := func(nm *netmap.NetworkMap) bool {
@ -1295,12 +1275,210 @@ func TestListenService(t *testing.T) {
} }
waitForLatestNetmap(t, serviceClient) waitForLatestNetmap(t, serviceClient)
waitForLatestNetmap(t, serviceHost) waitForLatestNetmap(t, serviceHost)
// == Done setting up mock state ==
// Start the Service listeners.
listeners := make([]*ServiceListener, 0, len(tt.modes))
for _, input := range tt.modes {
ln := must.Get(serviceHost.ListenService(serviceName.String(), input))
defer ln.Close()
listeners = append(listeners, ln)
}
tt.run(t, listeners, serviceClient)
} }
t.Run("TUN", func(t *testing.T) { doTest(t, true) }) t.Run("TUN", func(t *testing.T) { doTest(t, true) })
t.Run("netstack", func(t *testing.T) { doTest(t, false) }) t.Run("netstack", func(t *testing.T) { doTest(t, false) })
}) })
} }
// Error cases.
t.Run("untagged_node_error", func(t *testing.T) {
ctx := t.Context()
controlURL, _ := startControl(t)
serviceHost, _, _ := startServer(t, ctx, controlURL, "service-host")
ln, err := serviceHost.ListenService("svc:foo", ServiceModeTCP{Port: 8080})
if ln != nil {
ln.Close()
}
if !errors.Is(err, ErrUntaggedServiceHost) {
t.Fatalf("expected %v, got %v", ErrUntaggedServiceHost, err)
}
})
t.Run("duplicate_listeners", func(t *testing.T) {
ctx := t.Context()
controlURL, control := startControl(t)
serviceHost, _, _ := startServer(t, ctx, controlURL, "service-host")
// Service hosts must be a tagged node (any tag will do).
serviceHostNode := control.Node(serviceHost.lb.NodeKey())
serviceHostNode.Tags = append(serviceHostNode.Tags, "some-tag")
control.UpdateNode(serviceHostNode)
// Wait for an up-to-date netmap before proceeding with the test.
netmapUpToDate := func(nm *netmap.NetworkMap) bool {
return nm != nil && nm.SelfNode.IsTagged()
}
waitForLatestNetmap := func(t *testing.T, s *Server) {
t.Helper()
w := must.Get(s.localClient.WatchIPNBus(t.Context(), ipn.NotifyInitialNetMap))
defer w.Close()
for n := must.Get(w.Next()); !netmapUpToDate(n.NetMap); n = must.Get(w.Next()) {
}
}
waitForLatestNetmap(t, serviceHost)
ln := must.Get(serviceHost.ListenService("svc:foo", ServiceModeTCP{Port: 8080}))
defer ln.Close()
ln, err := serviceHost.ListenService("svc:foo", ServiceModeTCP{Port: 8080})
if ln != nil {
ln.Close()
}
if err == nil {
t.Fatal("expected error for redundant listener")
}
// An HTTP listener on the same port should also collide
ln, err = serviceHost.ListenService("svc:foo", ServiceModeHTTP{Port: 8080})
if ln != nil {
ln.Close()
}
if err == nil {
t.Fatal("expected error for redundant listener")
}
})
}
func TestListenServiceClose(t *testing.T) {
tstest.Shard(t)
diffServeConfig := func(a, b ipn.ServeConfigView) string {
// We treat a mapping from svc:foo to nil or the zero value as if it
// didn't exist at all. This is consistent with how the local backend
// treats service configs when nil or zero.
tr := cmp.Transformer("DeleteEmptyServices", func(m map[tailcfg.ServiceName]*ipn.ServiceConfig) map[tailcfg.ServiceName]*ipn.ServiceConfig {
mCopy := map[tailcfg.ServiceName]*ipn.ServiceConfig{}
for k, v := range m {
if v == nil {
continue
}
if rv := reflect.ValueOf(*v); rv.IsValid() && rv.IsZero() {
continue
}
mCopy[k] = v
}
return mCopy
})
return cmp.Diff(a.AsStruct(), b.AsStruct(), tr)
}
tests := []struct {
name string
run func(t *testing.T, serviceHost *Server)
}{
{
name: "TCP",
run: func(t *testing.T, s *Server) {
before := s.lb.ServeConfig()
ln := must.Get(s.ListenService("svc:foo", ServiceModeTCP{Port: 8080}))
ln.Close()
after := s.lb.ServeConfig()
if diff := diffServeConfig(after, before); diff != "" {
t.Fatalf("expected serve config to be unchanged after close (-got, +want):\n%s", diff)
}
},
},
{
name: "HTTP",
run: func(t *testing.T, s *Server) {
before := s.lb.ServeConfig()
ln := must.Get(s.ListenService("svc:foo", ServiceModeHTTP{Port: 8080}))
ln.Close()
after := s.lb.ServeConfig()
if diff := diffServeConfig(after, before); diff != "" {
t.Fatalf("expected serve config to be unchanged after close (-got, +want):\n%s", diff)
}
},
},
{
// Closing one listener should not affect config for another listener.
name: "two_listeners",
run: func(t *testing.T, s *Server) {
// Start a listener on 443.
ln1 := must.Get(s.ListenService("svc:foo", ServiceModeTCP{Port: 443}))
defer ln1.Close()
// Save the serve config for this original listener.
before := s.lb.ServeConfig()
// Now start and close a new listener on a different port.
ln2 := must.Get(s.ListenService("svc:foo", ServiceModeTCP{Port: 8080}))
ln2.Close()
// The serve config for the original listener should be intact.
after := s.lb.ServeConfig()
if diff := diffServeConfig(after, before); diff != "" {
t.Fatalf("expected existing config to remain intact (-got, +want):\n%s", diff)
}
},
},
{
// It should be possible to close a listener and free system
// resources even when the Server has been closed (or the listener
// should be automatically closed).
name: "after_server_close",
run: func(t *testing.T, s *Server) {
ln := must.Get(s.ListenService("svc:foo", ServiceModeTCP{Port: 8080}))
// Close the server, then close the listener.
must.Do(s.Close())
// We don't care whether we get an error from the listener closing.
t.Log("close error:", ln.Close())
// The listener should immediately return an error indicating closure.
_, err := ln.Accept()
if !errors.Is(err, net.ErrClosed) {
t.Fatal("expected listener to be closed, got:", err)
}
},
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
ctx := t.Context()
controlURL, control := startControl(t)
s, _, _ := startServer(t, ctx, controlURL, "service-host")
// Service hosts must be a tagged node (any tag will do).
serviceHostNode := control.Node(s.lb.NodeKey())
serviceHostNode.Tags = append(serviceHostNode.Tags, "some-tag")
control.UpdateNode(serviceHostNode)
// Wait for an up-to-date netmap before proceeding with the test.
netmapUpToDate := func(nm *netmap.NetworkMap) bool {
return nm != nil && nm.SelfNode.IsTagged()
}
waitForLatestNetmap := func(t *testing.T, s *Server) {
t.Helper()
w := must.Get(s.localClient.WatchIPNBus(t.Context(), ipn.NotifyInitialNetMap))
defer w.Close()
for n := must.Get(w.Next()); !netmapUpToDate(n.NetMap); n = must.Get(w.Next()) {
}
}
waitForLatestNetmap(t, s)
tt.run(t, s)
})
}
} }
func TestListenerClose(t *testing.T) { func TestListenerClose(t *testing.T) {

Loading…
Cancel
Save