diff --git a/cmd/tsconnect/wasm/wasm_js.go b/cmd/tsconnect/wasm/wasm_js.go index 71e8476a0..13806e271 100644 --- a/cmd/tsconnect/wasm/wasm_js.go +++ b/cmd/tsconnect/wasm/wasm_js.go @@ -20,11 +20,18 @@ import ( "net" "net/http" "net/netip" + "strconv" "strings" "syscall/js" "time" "golang.org/x/crypto/ssh" + "gvisor.dev/gvisor/pkg/tcpip" + "gvisor.dev/gvisor/pkg/tcpip/adapters/gonet" + "gvisor.dev/gvisor/pkg/tcpip/network/ipv4" + "gvisor.dev/gvisor/pkg/tcpip/network/ipv6" + "gvisor.dev/gvisor/pkg/tcpip/transport/icmp" + "gvisor.dev/gvisor/pkg/waiter" "tailscale.com/control/controlclient" "tailscale.com/ipn" "tailscale.com/ipn/ipnauth" @@ -154,6 +161,7 @@ func newIPN(jsConfig js.Value) map[string]any { dialer: dialer, srv: srv, lb: lb, + ns: ns, controlURL: controlURL, authKey: authKey, hostname: hostname, @@ -208,6 +216,27 @@ func newIPN(jsConfig js.Value) map[string]any { url := args[0].String() return jsIPN.fetch(url) }), + "dial": js.FuncOf(func(this js.Value, args []js.Value) any { + if len(args) != 2 { + log.Printf("Usage: dial(network, addr)") + return nil + } + return jsIPN.dial(args[0].String(), args[1].String()) + }), + "listen": js.FuncOf(func(this js.Value, args []js.Value) any { + if len(args) != 2 { + log.Printf("Usage: listen(network, addr)") + return nil + } + return jsIPN.listen(args[0].String(), args[1].String()) + }), + "listenICMP": js.FuncOf(func(this js.Value, args []js.Value) any { + if len(args) != 1 { + log.Printf("Usage: listenICMP(network)") + return nil + } + return jsIPN.listenICMP(args[0].String()) + }), } } @@ -215,6 +244,7 @@ type jsIPN struct { dialer *tsdial.Dialer srv *ipnserver.Server lb *ipnlocal.LocalBackend + ns *netstack.Impl controlURL string authKey string hostname string @@ -531,6 +561,162 @@ func (i *jsIPN) fetch(url string) js.Value { }) } +func (i *jsIPN) dial(network, addr string) js.Value { + return makePromise(func() (any, error) { + ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second) + defer cancel() + conn, err := i.dialer.UserDial(ctx, network, addr) + if err != nil { + return nil, err + } + return wrapConn(conn), nil + }) +} + +func (i *jsIPN) listen(network, addr string) js.Value { + return makePromise(func() (any, error) { + pc, err := i.ns.ListenPacket(network, addr) + if err != nil { + return nil, err + } + return wrapPacketConn(pc), nil + }) +} + +func (i *jsIPN) listenICMP(network string) js.Value { + return makePromise(func() (any, error) { + var transportProto tcpip.TransportProtocolNumber + var networkProto tcpip.NetworkProtocolNumber + + switch network { + case "icmp4", "icmp": + transportProto = icmp.ProtocolNumber4 + networkProto = ipv4.ProtocolNumber + case "icmp6": + transportProto = icmp.ProtocolNumber6 + networkProto = ipv6.ProtocolNumber + default: + return nil, fmt.Errorf("unsupported network %q (use \"icmp4\" or \"icmp6\")", network) + } + + st := i.ns.Stack() + var wq waiter.Queue + ep, nserr := st.NewEndpoint(transportProto, networkProto, &wq) + if nserr != nil { + return nil, fmt.Errorf("creating ICMP endpoint: %v", nserr) + } + + pc := gonet.NewUDPConn(&wq, ep) + return wrapPacketConn(pc), nil + }) +} + +// wrapConn exposes a net.Conn to JavaScript with binary (Uint8Array) I/O. +func wrapConn(conn net.Conn) map[string]any { + return map[string]any{ + "read": js.FuncOf(func(this js.Value, args []js.Value) any { + return makePromise(func() (any, error) { + buf := make([]byte, 65536) + n, err := conn.Read(buf) + if err != nil { + return nil, err + } + arr := js.Global().Get("Uint8Array").New(n) + js.CopyBytesToJS(arr, buf[:n]) + return arr, nil + }) + }), + "write": js.FuncOf(func(this js.Value, args []js.Value) any { + return makePromise(func() (any, error) { + data := args[0] + buf := make([]byte, data.Get("length").Int()) + js.CopyBytesToGo(buf, data) + n, err := conn.Write(buf) + if err != nil { + return nil, err + } + return n, nil + }) + }), + "close": js.FuncOf(func(this js.Value, args []js.Value) any { + return conn.Close() != nil + }), + "localAddr": js.FuncOf(func(this js.Value, args []js.Value) any { + return conn.LocalAddr().String() + }), + "remoteAddr": js.FuncOf(func(this js.Value, args []js.Value) any { + return conn.RemoteAddr().String() + }), + } +} + +// wrapPacketConn exposes a net.PacketConn to JavaScript with binary (Uint8Array) I/O. +func wrapPacketConn(pc net.PacketConn) map[string]any { + return map[string]any{ + "readFrom": js.FuncOf(func(this js.Value, args []js.Value) any { + return makePromise(func() (any, error) { + buf := make([]byte, 65536) + n, addr, err := pc.ReadFrom(buf) + if err != nil { + return nil, err + } + arr := js.Global().Get("Uint8Array").New(n) + js.CopyBytesToJS(arr, buf[:n]) + return map[string]any{ + "data": arr, + "addr": addr.String(), + }, nil + }) + }), + "writeTo": js.FuncOf(func(this js.Value, args []js.Value) any { + return makePromise(func() (any, error) { + data := args[0] + addrStr := args[1].String() + buf := make([]byte, data.Get("length").Int()) + js.CopyBytesToGo(buf, data) + addr, err := resolveUDPAddr(addrStr) + if err != nil { + return nil, err + } + n, err := pc.WriteTo(buf, addr) + if err != nil { + return nil, err + } + return n, nil + }) + }), + "close": js.FuncOf(func(this js.Value, args []js.Value) any { + return pc.Close() != nil + }), + "localAddr": js.FuncOf(func(this js.Value, args []js.Value) any { + return pc.LocalAddr().String() + }), + } +} + +// resolveUDPAddr parses an address string that is either "host:port" or just +// an IP (for ICMP, where port defaults to 0). +func resolveUDPAddr(s string) (*net.UDPAddr, error) { + host, portStr, err := net.SplitHostPort(s) + if err != nil { + // Bare IP address without port (used for ICMP). + ip := net.ParseIP(s) + if ip == nil { + return nil, fmt.Errorf("invalid address: %s", s) + } + return &net.UDPAddr{IP: ip}, nil + } + ip := net.ParseIP(host) + if ip == nil { + return nil, fmt.Errorf("invalid IP: %s", host) + } + port, err := strconv.Atoi(portStr) + if err != nil { + return nil, fmt.Errorf("invalid port: %s", portStr) + } + return &net.UDPAddr{IP: ip, Port: port}, nil +} + type termWriter struct { f js.Value } diff --git a/wgengine/netstack/netstack.go b/wgengine/netstack/netstack.go index 4da89e364..1b3571a5b 100644 --- a/wgengine/netstack/netstack.go +++ b/wgengine/netstack/netstack.go @@ -280,6 +280,11 @@ type Impl struct { packetsInFlight map[stack.TransportEndpointID]struct{} } +// Stack returns the underlying gVisor network stack. +func (ns *Impl) Stack() *stack.Stack { + return ns.ipstack +} + const nicID = 1 // maxUDPPacketSize is the maximum size of a UDP packet we copy in