diff --git a/ssh/tailssh/tailssh.go b/ssh/tailssh/tailssh.go index 95cf771af..c13d3d29e 100644 --- a/ssh/tailssh/tailssh.go +++ b/ssh/tailssh/tailssh.go @@ -1034,9 +1034,11 @@ func (ss *sshSession) run() { if err != nil { logf("start failed: %v", err.Error()) if errors.Is(err, context.Canceled) { - err := context.Cause(ss.ctx) - if uve, ok := errors.AsType[userVisibleError](err); ok { - fmt.Fprintf(ss, "%s\r\n", uve) + cause := context.Cause(ss.ctx) + if serr, ok := cause.(SSHTerminationError); ok { + if msg := serr.SSHTerminationMessage(); msg != "" { + io.WriteString(ss.Stderr(), "\r\n\r\n"+msg+"\r\n\r\n") + } } } ss.Exit(1) @@ -1093,6 +1095,15 @@ func (ss *sshSession) run() { err = ss.cmd.Wait() processDone.Store(true) + if ss.ctx.Err() != nil { + // Context was canceled (e.g., recording upload failure). + // Wait for killProcessOnContextDone to finish writing any + // termination message before we proceed. This must happen + // before closeAll and CloseWrite so the SSH channel is + // still writable. + <-ss.exitHandled + } + // This will either make the SSH Termination goroutine be a no-op, // or itself will be a no-op because the process was killed by the // aforementioned goroutine. @@ -1105,9 +1116,6 @@ func (ss *sshSession) run() { select { case <-outputDone: case <-ss.ctx.Done(): - // Wait for killProcessOnContextDone to finish writing any - // termination message to the client before we call ss.Exit, - // which tears down the SSH channel. <-ss.exitHandled } diff --git a/ssh/tailssh/tailssh_test.go b/ssh/tailssh/tailssh_test.go index 5141209ec..c8b5f698b 100644 --- a/ssh/tailssh/tailssh_test.go +++ b/ssh/tailssh/tailssh_test.go @@ -36,7 +36,6 @@ import ( gliderssh "github.com/tailscale/gliderssh" "golang.org/x/net/http2" "golang.org/x/net/http2/h2c" - "tailscale.com/cmd/testwrapper/flakytest" "tailscale.com/ipn/ipnlocal" "tailscale.com/ipn/store/mem" "tailscale.com/net/memnet" @@ -470,8 +469,6 @@ func newSSHRule(action *tailcfg.SSHAction) *tailcfg.SSHRule { } func TestSSHRecordingCancelsSessionsOnUploadFailure(t *testing.T) { - flakytest.Mark(t, "https://github.com/tailscale/tailscale/issues/7707") - if runtime.GOOS != "linux" && runtime.GOOS != "darwin" { t.Skipf("skipping on %q; only runs on linux and darwin", runtime.GOOS) }