|
|
|
|
@ -8,11 +8,24 @@ |
|
|
|
|
package tailssh |
|
|
|
|
|
|
|
|
|
import ( |
|
|
|
|
"context" |
|
|
|
|
"errors" |
|
|
|
|
"fmt" |
|
|
|
|
"net" |
|
|
|
|
"os/exec" |
|
|
|
|
"os/user" |
|
|
|
|
"testing" |
|
|
|
|
"time" |
|
|
|
|
|
|
|
|
|
"github.com/gliderlabs/ssh" |
|
|
|
|
"inet.af/netaddr" |
|
|
|
|
"tailscale.com/ipn" |
|
|
|
|
"tailscale.com/ipn/ipnlocal" |
|
|
|
|
"tailscale.com/net/tsdial" |
|
|
|
|
"tailscale.com/tailcfg" |
|
|
|
|
"tailscale.com/tstest" |
|
|
|
|
"tailscale.com/types/logger" |
|
|
|
|
"tailscale.com/wgengine" |
|
|
|
|
) |
|
|
|
|
|
|
|
|
|
func TestMatchRule(t *testing.T) { |
|
|
|
|
@ -155,3 +168,75 @@ func TestMatchRule(t *testing.T) { |
|
|
|
|
} |
|
|
|
|
|
|
|
|
|
func timePtr(t time.Time) *time.Time { return &t } |
|
|
|
|
|
|
|
|
|
func TestSSH(t *testing.T) { |
|
|
|
|
ml := new(tstest.MemLogger) |
|
|
|
|
var logf logger.Logf = ml.Logf |
|
|
|
|
eng, err := wgengine.NewFakeUserspaceEngine(logf, 0) |
|
|
|
|
if err != nil { |
|
|
|
|
t.Fatal(err) |
|
|
|
|
} |
|
|
|
|
lb, err := ipnlocal.NewLocalBackend(logf, "", |
|
|
|
|
new(ipn.MemoryStore), |
|
|
|
|
new(tsdial.Dialer), |
|
|
|
|
eng, 0) |
|
|
|
|
if err != nil { |
|
|
|
|
t.Fatal(err) |
|
|
|
|
} |
|
|
|
|
defer lb.Shutdown() |
|
|
|
|
dir := t.TempDir() |
|
|
|
|
lb.SetVarRoot(dir) |
|
|
|
|
|
|
|
|
|
srv := &server{lb, logf} |
|
|
|
|
ss, err := srv.newSSHServer() |
|
|
|
|
if err != nil { |
|
|
|
|
t.Fatal(err) |
|
|
|
|
} |
|
|
|
|
|
|
|
|
|
u, err := user.Current() |
|
|
|
|
if err != nil { |
|
|
|
|
t.Fatal(err) |
|
|
|
|
} |
|
|
|
|
|
|
|
|
|
ci := &sshConnInfo{ |
|
|
|
|
sshUser: "test", |
|
|
|
|
srcIP: netaddr.MustParseIP("1.2.3.4"), |
|
|
|
|
node: &tailcfg.Node{}, |
|
|
|
|
uprof: &tailcfg.UserProfile{}, |
|
|
|
|
} |
|
|
|
|
|
|
|
|
|
ctx, cancel := context.WithCancel(context.Background()) |
|
|
|
|
defer cancel() |
|
|
|
|
ss.Handler = func(s ssh.Session) { |
|
|
|
|
srv.handleAcceptedSSH(ctx, s, ci, u) |
|
|
|
|
} |
|
|
|
|
|
|
|
|
|
ln, err := net.Listen("tcp4", "127.0.0.1:0") |
|
|
|
|
if err != nil { |
|
|
|
|
t.Fatal(err) |
|
|
|
|
} |
|
|
|
|
defer ln.Close() |
|
|
|
|
port := ln.Addr().(*net.TCPAddr).Port |
|
|
|
|
|
|
|
|
|
go func() { |
|
|
|
|
for { |
|
|
|
|
c, err := ln.Accept() |
|
|
|
|
if err != nil { |
|
|
|
|
if !errors.Is(err, net.ErrClosed) { |
|
|
|
|
t.Errorf("Accept: %v", err) |
|
|
|
|
} |
|
|
|
|
return |
|
|
|
|
} |
|
|
|
|
go ss.HandleConn(c) |
|
|
|
|
} |
|
|
|
|
}() |
|
|
|
|
|
|
|
|
|
got, err := exec.Command("ssh", |
|
|
|
|
"-p", fmt.Sprint(port), |
|
|
|
|
"-o", "StrictHostKeyChecking=no", |
|
|
|
|
"user@127.0.0.1", "env").CombinedOutput() |
|
|
|
|
if err != nil { |
|
|
|
|
t.Fatal(err) |
|
|
|
|
} |
|
|
|
|
t.Logf("Got: %s", got) |
|
|
|
|
} |
|
|
|
|
|