diff --git a/cmd/tsconnect/wasm/wasm_js.go b/cmd/tsconnect/wasm/wasm_js.go index 423fb8024..9b080c1e2 100644 --- a/cmd/tsconnect/wasm/wasm_js.go +++ b/cmd/tsconnect/wasm/wasm_js.go @@ -30,6 +30,7 @@ import ( "time" "golang.org/x/crypto/ssh" + "golang.org/x/net/dns/dnsmessage" "gvisor.dev/gvisor/pkg/tcpip" "gvisor.dev/gvisor/pkg/tcpip/adapters/gonet" "gvisor.dev/gvisor/pkg/tcpip/network/ipv4" @@ -315,6 +316,46 @@ func newIPN(jsConfig js.Value) map[string]any { } return jsIPN.setFunnel(args[0].String(), uint16(args[1].Int()), args[2].Bool()) }), + "whoIs": js.FuncOf(func(this js.Value, args []js.Value) any { + if len(args) < 1 { + log.Printf("Usage: whoIs(addrPort[, proto])") + return nil + } + proto := "" + if len(args) >= 2 { + proto = args[1].String() + } + return jsIPN.whoIs(args[0].String(), proto) + }), + "queryDNS": js.FuncOf(func(this js.Value, args []js.Value) any { + if len(args) < 1 { + log.Printf("Usage: queryDNS(name[, type])") + return nil + } + qtype := 1 // TypeA + if len(args) >= 2 { + qtype = args[1].Int() + } + return jsIPN.queryDNS(args[0].String(), qtype) + }), + "ping": js.FuncOf(func(this js.Value, args []js.Value) any { + if len(args) < 1 { + log.Printf("Usage: ping(ip[, type[, size]])") + return nil + } + pingType := "TSMP" + if len(args) >= 2 { + pingType = args[1].String() + } + size := 0 + if len(args) >= 3 { + size = args[2].Int() + } + return jsIPN.ping(args[0].String(), pingType, size) + }), + "suggestExitNode": js.FuncOf(func(this js.Value, args []js.Value) any { + return jsIPN.suggestExitNode() + }), } } @@ -1001,6 +1042,156 @@ func (i *jsIPN) setFunnel(hostname string, port uint16, enabled bool) js.Value { }) } +func (i *jsIPN) whoIs(addrPort string, proto string) js.Value { + return makePromise(func() (any, error) { + ipp, err := netip.ParseAddrPort(addrPort) + if err != nil { + return nil, fmt.Errorf("whoIs: invalid addr:port %q: %w", addrPort, err) + } + n, u, ok := i.lb.WhoIs(proto, ipp) + if !ok { + return nil, nil + } + addrs := make([]any, n.Addresses().Len()) + for idx, ap := range n.Addresses().All() { + addrs[idx] = ap.Addr().String() + } + return map[string]any{ + "node": map[string]any{ + "id": string(n.StableID()), + "name": n.Name(), + "addresses": addrs, + }, + "user": map[string]any{ + "id": int64(u.ID), + "loginName": u.LoginName, + "displayName": u.DisplayName, + "profilePicURL": u.ProfilePicURL, + }, + }, nil + }) +} + +func (i *jsIPN) queryDNS(name string, queryType int) js.Value { + return makePromise(func() (any, error) { + res, resolvers, err := i.lb.QueryDNS(name, dnsmessage.Type(queryType)) + if err != nil { + return nil, err + } + var p dnsmessage.Parser + if _, err := p.Start(res); err != nil { + return nil, fmt.Errorf("queryDNS: parsing response: %w", err) + } + if err := p.SkipAllQuestions(); err != nil { + return nil, fmt.Errorf("queryDNS: skipping questions: %w", err) + } + var answers []any + for { + h, err := p.AnswerHeader() + if err == dnsmessage.ErrSectionDone { + break + } + if err != nil { + return nil, fmt.Errorf("queryDNS: reading answer: %w", err) + } + switch h.Type { + case dnsmessage.TypeA: + r, err := p.AResource() + if err != nil { + return nil, fmt.Errorf("queryDNS: reading A record: %w", err) + } + answers = append(answers, netip.AddrFrom4(r.A).String()) + case dnsmessage.TypeAAAA: + r, err := p.AAAAResource() + if err != nil { + return nil, fmt.Errorf("queryDNS: reading AAAA record: %w", err) + } + answers = append(answers, netip.AddrFrom16(r.AAAA).String()) + case dnsmessage.TypeCNAME: + r, err := p.CNAMEResource() + if err != nil { + return nil, fmt.Errorf("queryDNS: reading CNAME record: %w", err) + } + answers = append(answers, r.CNAME.String()) + case dnsmessage.TypeTXT: + r, err := p.TXTResource() + if err != nil { + return nil, fmt.Errorf("queryDNS: reading TXT record: %w", err) + } + for _, s := range r.TXT { + answers = append(answers, s) + } + default: + if err := p.SkipAnswer(); err != nil { + return nil, fmt.Errorf("queryDNS: skipping unknown answer: %w", err) + } + } + } + resolverAddrs := make([]any, len(resolvers)) + for idx, r := range resolvers { + resolverAddrs[idx] = r.Addr + } + return map[string]any{ + "answers": answers, + "resolvers": resolverAddrs, + }, nil + }) +} + +func (i *jsIPN) ping(ip string, pingType string, size int) js.Value { + return makePromise(func() (any, error) { + addr, err := netip.ParseAddr(ip) + if err != nil { + return nil, fmt.Errorf("ping: invalid IP %q: %w", ip, err) + } + ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second) + defer cancel() + pr, err := i.lb.Ping(ctx, addr, tailcfg.PingType(pingType), size) + if err != nil { + return nil, err + } + result := map[string]any{ + "ip": pr.IP, + "nodeIP": pr.NodeIP, + "nodeName": pr.NodeName, + "latencySeconds": pr.LatencySeconds, + "endpoint": pr.Endpoint, + "derpRegionID": pr.DERPRegionID, + "derpRegionCode": pr.DERPRegionCode, + "peerAPIURL": pr.PeerAPIURL, + "isLocalIP": pr.IsLocalIP, + } + if pr.Err != "" { + result["err"] = pr.Err + } + return result, nil + }) +} + +func (i *jsIPN) suggestExitNode() js.Value { + return makePromise(func() (any, error) { + resp, err := i.lb.SuggestExitNode() + if err != nil { + return nil, err + } + result := map[string]any{ + "id": string(resp.ID), + "name": resp.Name, + } + if l := resp.Location; l.Valid() { + result["location"] = map[string]any{ + "country": l.Country(), + "countryCode": l.CountryCode(), + "city": l.City(), + "cityCode": l.CityCode(), + "latitude": l.Latitude(), + "longitude": l.Longitude(), + } + } + return result, nil + }) +} + // wrapConn exposes a net.Conn to JavaScript with binary (Uint8Array) I/O. func wrapConn(conn net.Conn) map[string]any { return map[string]any{