|
|
|
|
@ -326,6 +326,108 @@ func newSSHRule(action *tailcfg.SSHAction) *tailcfg.SSHRule { |
|
|
|
|
} |
|
|
|
|
} |
|
|
|
|
|
|
|
|
|
func TestSSHRecordingCancelsSessionsOnUploadFailure(t *testing.T) { |
|
|
|
|
if runtime.GOOS != "linux" && runtime.GOOS != "darwin" { |
|
|
|
|
t.Skipf("skipping on %q; only runs on linux and darwin", runtime.GOOS) |
|
|
|
|
} |
|
|
|
|
|
|
|
|
|
var handler http.HandlerFunc |
|
|
|
|
recordingServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { |
|
|
|
|
handler(w, r) |
|
|
|
|
})) |
|
|
|
|
defer recordingServer.Close() |
|
|
|
|
|
|
|
|
|
s := &server{ |
|
|
|
|
logf: t.Logf, |
|
|
|
|
httpc: recordingServer.Client(), |
|
|
|
|
lb: &localState{ |
|
|
|
|
sshEnabled: true, |
|
|
|
|
matchingRule: newSSHRule( |
|
|
|
|
&tailcfg.SSHAction{ |
|
|
|
|
Accept: true, |
|
|
|
|
Recorders: []netip.AddrPort{ |
|
|
|
|
netip.MustParseAddrPort(recordingServer.Listener.Addr().String()), |
|
|
|
|
}, |
|
|
|
|
}, |
|
|
|
|
), |
|
|
|
|
}, |
|
|
|
|
} |
|
|
|
|
defer s.Shutdown() |
|
|
|
|
|
|
|
|
|
const sshUser = "alice" |
|
|
|
|
cfg := &gossh.ClientConfig{ |
|
|
|
|
User: sshUser, |
|
|
|
|
HostKeyCallback: gossh.InsecureIgnoreHostKey(), |
|
|
|
|
} |
|
|
|
|
|
|
|
|
|
tests := []struct { |
|
|
|
|
name string |
|
|
|
|
handler func(w http.ResponseWriter, r *http.Request) |
|
|
|
|
sshCommand string |
|
|
|
|
wantClientOutput string |
|
|
|
|
}{ |
|
|
|
|
{ |
|
|
|
|
name: "upload-denied", |
|
|
|
|
handler: func(w http.ResponseWriter, r *http.Request) { |
|
|
|
|
w.WriteHeader(http.StatusForbidden) |
|
|
|
|
}, |
|
|
|
|
sshCommand: "echo hello", |
|
|
|
|
wantClientOutput: "recording: server responded with 403 Forbidden\r\n", |
|
|
|
|
}, |
|
|
|
|
{ |
|
|
|
|
name: "upload-fails-after-starting", |
|
|
|
|
handler: func(w http.ResponseWriter, r *http.Request) { |
|
|
|
|
r.Body.Read(make([]byte, 1)) |
|
|
|
|
time.Sleep(100 * time.Millisecond) |
|
|
|
|
w.WriteHeader(http.StatusInternalServerError) |
|
|
|
|
}, |
|
|
|
|
sshCommand: "echo hello && sleep 1 && echo world", |
|
|
|
|
wantClientOutput: "hello\n\r\n\r\nrecording server responded with: 500 Internal Server Error\r\n\r\n", |
|
|
|
|
}, |
|
|
|
|
} |
|
|
|
|
|
|
|
|
|
src, dst := must.Get(netip.ParseAddrPort("100.100.100.101:2231")), must.Get(netip.ParseAddrPort("100.100.100.102:22")) |
|
|
|
|
|
|
|
|
|
for _, tt := range tests { |
|
|
|
|
t.Run(tt.name, func(t *testing.T) { |
|
|
|
|
tstest.Replace(t, &handler, tt.handler) |
|
|
|
|
sc, dc := memnet.NewTCPConn(src, dst, 1024) |
|
|
|
|
var wg sync.WaitGroup |
|
|
|
|
wg.Add(1) |
|
|
|
|
go func() { |
|
|
|
|
defer wg.Done() |
|
|
|
|
c, chans, reqs, err := gossh.NewClientConn(sc, sc.RemoteAddr().String(), cfg) |
|
|
|
|
if err != nil { |
|
|
|
|
t.Errorf("client: %v", err) |
|
|
|
|
return |
|
|
|
|
} |
|
|
|
|
client := gossh.NewClient(c, chans, reqs) |
|
|
|
|
defer client.Close() |
|
|
|
|
session, err := client.NewSession() |
|
|
|
|
if err != nil { |
|
|
|
|
t.Errorf("client: %v", err) |
|
|
|
|
return |
|
|
|
|
} |
|
|
|
|
defer session.Close() |
|
|
|
|
t.Logf("client established session") |
|
|
|
|
got, err := session.CombinedOutput(tt.sshCommand) |
|
|
|
|
if err != nil { |
|
|
|
|
t.Logf("client got: %q: %v", got, err) |
|
|
|
|
} else { |
|
|
|
|
t.Errorf("client did not get kicked out: %q", got) |
|
|
|
|
} |
|
|
|
|
if string(got) != tt.wantClientOutput { |
|
|
|
|
t.Errorf("client got %q, want %q", got, tt.wantClientOutput) |
|
|
|
|
} |
|
|
|
|
}() |
|
|
|
|
if err := s.HandleSSHConn(dc); err != nil { |
|
|
|
|
t.Errorf("unexpected error: %v", err) |
|
|
|
|
} |
|
|
|
|
wg.Wait() |
|
|
|
|
}) |
|
|
|
|
} |
|
|
|
|
} |
|
|
|
|
|
|
|
|
|
// TestSSHRecordingNonInteractive tests that the SSH server records the SSH session
|
|
|
|
|
// when the client is not interactive (i.e. no PTY).
|
|
|
|
|
// It starts a local SSH server and a recording server. The recording server
|
|
|
|
|
@ -346,30 +448,28 @@ func TestSSHRecordingNonInteractive(t *testing.T) { |
|
|
|
|
t.Error(err) |
|
|
|
|
return |
|
|
|
|
} |
|
|
|
|
w.WriteHeader(http.StatusOK) |
|
|
|
|
})) |
|
|
|
|
defer recordingServer.Close() |
|
|
|
|
|
|
|
|
|
state := &localState{ |
|
|
|
|
sshEnabled: true, |
|
|
|
|
matchingRule: newSSHRule( |
|
|
|
|
&tailcfg.SSHAction{ |
|
|
|
|
Accept: true, |
|
|
|
|
Recorders: []netip.AddrPort{ |
|
|
|
|
must.Get(netip.ParseAddrPort(recordingServer.Listener.Addr().String())), |
|
|
|
|
}, |
|
|
|
|
}, |
|
|
|
|
), |
|
|
|
|
} |
|
|
|
|
s := &server{ |
|
|
|
|
logf: t.Logf, |
|
|
|
|
logf: logger.Discard, |
|
|
|
|
httpc: recordingServer.Client(), |
|
|
|
|
lb: &localState{ |
|
|
|
|
sshEnabled: true, |
|
|
|
|
matchingRule: newSSHRule( |
|
|
|
|
&tailcfg.SSHAction{ |
|
|
|
|
Accept: true, |
|
|
|
|
Recorders: []netip.AddrPort{ |
|
|
|
|
must.Get(netip.ParseAddrPort(recordingServer.Listener.Addr().String())), |
|
|
|
|
}, |
|
|
|
|
}, |
|
|
|
|
), |
|
|
|
|
}, |
|
|
|
|
} |
|
|
|
|
defer s.Shutdown() |
|
|
|
|
|
|
|
|
|
src, dst := must.Get(netip.ParseAddrPort("100.100.100.101:2231")), must.Get(netip.ParseAddrPort("100.100.100.102:22")) |
|
|
|
|
sc, dc := memnet.NewTCPConn(src, dst, 1024) |
|
|
|
|
s.lb = state |
|
|
|
|
|
|
|
|
|
const sshUser = "alice" |
|
|
|
|
cfg := &gossh.ClientConfig{ |
|
|
|
|
|