ssh/tailssh: fix race in session termination message delivery

When a recording upload fails mid-session, the recording goroutine
cancels the session context. This triggers two concurrent paths:
exec.CommandContext kills the process (causing cmd.Wait to return),
and killProcessOnContextDone tries to write the termination message
via exitOnce.Do. If cmd.Wait returns first, the main goroutine's
exitOnce.Do(func(){}) steals the once, and the termination message
is never written to the client.

Fix by waiting for killProcessOnContextDone to finish writing the
termination message (via <-ss.exitHandled) before claiming exitOnce,
when the context is already done.

Also fix the fallback path when launchProcess itself fails due to
context cancellation: use SSHTerminationMessage() with the correct
"\r\n\r\n" framing instead of fmt.Fprintf with the internal error
string.

Deflakes TestSSHRecordingCancelsSessionsOnUploadFailure, which was
failing consistently at a low rate due to the exitOnce race. After
this fix, flakestress passes with 8,668 runs, 0 failures.

Fixes #7707 (again. hopefully for good.)

Change-Id: I5ab911c71574db8d3f9d979fb839f273be51ecf9
Signed-off-by: Brad Fitzpatrick <bradfitz@tailscale.com>
main
Brad Fitzpatrick 2 weeks ago committed by Brad Fitzpatrick
parent 6e44c6828b
commit 2b1cfa7c4d
  1. 20
      ssh/tailssh/tailssh.go
  2. 3
      ssh/tailssh/tailssh_test.go

@ -1034,9 +1034,11 @@ func (ss *sshSession) run() {
if err != nil { if err != nil {
logf("start failed: %v", err.Error()) logf("start failed: %v", err.Error())
if errors.Is(err, context.Canceled) { if errors.Is(err, context.Canceled) {
err := context.Cause(ss.ctx) cause := context.Cause(ss.ctx)
if uve, ok := errors.AsType[userVisibleError](err); ok { if serr, ok := cause.(SSHTerminationError); ok {
fmt.Fprintf(ss, "%s\r\n", uve) if msg := serr.SSHTerminationMessage(); msg != "" {
io.WriteString(ss.Stderr(), "\r\n\r\n"+msg+"\r\n\r\n")
}
} }
} }
ss.Exit(1) ss.Exit(1)
@ -1093,6 +1095,15 @@ func (ss *sshSession) run() {
err = ss.cmd.Wait() err = ss.cmd.Wait()
processDone.Store(true) processDone.Store(true)
if ss.ctx.Err() != nil {
// Context was canceled (e.g., recording upload failure).
// Wait for killProcessOnContextDone to finish writing any
// termination message before we proceed. This must happen
// before closeAll and CloseWrite so the SSH channel is
// still writable.
<-ss.exitHandled
}
// This will either make the SSH Termination goroutine be a no-op, // This will either make the SSH Termination goroutine be a no-op,
// or itself will be a no-op because the process was killed by the // or itself will be a no-op because the process was killed by the
// aforementioned goroutine. // aforementioned goroutine.
@ -1105,9 +1116,6 @@ func (ss *sshSession) run() {
select { select {
case <-outputDone: case <-outputDone:
case <-ss.ctx.Done(): case <-ss.ctx.Done():
// Wait for killProcessOnContextDone to finish writing any
// termination message to the client before we call ss.Exit,
// which tears down the SSH channel.
<-ss.exitHandled <-ss.exitHandled
} }

@ -36,7 +36,6 @@ import (
gliderssh "github.com/tailscale/gliderssh" gliderssh "github.com/tailscale/gliderssh"
"golang.org/x/net/http2" "golang.org/x/net/http2"
"golang.org/x/net/http2/h2c" "golang.org/x/net/http2/h2c"
"tailscale.com/cmd/testwrapper/flakytest"
"tailscale.com/ipn/ipnlocal" "tailscale.com/ipn/ipnlocal"
"tailscale.com/ipn/store/mem" "tailscale.com/ipn/store/mem"
"tailscale.com/net/memnet" "tailscale.com/net/memnet"
@ -470,8 +469,6 @@ func newSSHRule(action *tailcfg.SSHAction) *tailcfg.SSHRule {
} }
func TestSSHRecordingCancelsSessionsOnUploadFailure(t *testing.T) { func TestSSHRecordingCancelsSessionsOnUploadFailure(t *testing.T) {
flakytest.Mark(t, "https://github.com/tailscale/tailscale/issues/7707")
if runtime.GOOS != "linux" && runtime.GOOS != "darwin" { if runtime.GOOS != "linux" && runtime.GOOS != "darwin" {
t.Skipf("skipping on %q; only runs on linux and darwin", runtime.GOOS) t.Skipf("skipping on %q; only runs on linux and darwin", runtime.GOOS)
} }

Loading…
Cancel
Save