tsnet,wgengine/netstack: add ListenPacket and tests
This adds a new ListenPacket function on tsnet.Server which acts mostly like `net.ListenPacket`. Unlike `Server.Listen`, this requires listening on a specific IP and does not automatically listen on both V4 and V6 addresses of the Server when the IP is unspecified. To test this, it also adds UDP support to tsdial.Dialer.UserDial and plumbs it through the localapi. Then an associated test to make sure the UDP functionality works from both sides. Updates #12182 Signed-off-by: Maisem Ali <maisem@tailscale.com>
This commit is contained in:
+81
-27
@@ -562,14 +562,25 @@ func (s *Server) start() (reterr error) {
|
||||
return ok
|
||||
}
|
||||
s.dialer.NetstackDialTCP = func(ctx context.Context, dst netip.AddrPort) (net.Conn, error) {
|
||||
// Note: don't just return ns.DialContextTCP or we'll
|
||||
// return an interface containing a nil pointer.
|
||||
// Note: don't just return ns.DialContextTCP or we'll return
|
||||
// *gonet.TCPConn(nil) instead of a nil interface which trips up
|
||||
// callers.
|
||||
tcpConn, err := ns.DialContextTCP(ctx, dst)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return tcpConn, nil
|
||||
}
|
||||
s.dialer.NetstackDialUDP = func(ctx context.Context, dst netip.AddrPort) (net.Conn, error) {
|
||||
// Note: don't just return ns.DialContextUDP or we'll return
|
||||
// *gonet.UDPConn(nil) instead of a nil interface which trips up
|
||||
// callers.
|
||||
udpConn, err := ns.DialContextUDP(ctx, dst)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return udpConn, nil
|
||||
}
|
||||
|
||||
if s.Store == nil {
|
||||
stateFile := filepath.Join(s.rootPath, "tailscaled.state")
|
||||
@@ -908,6 +919,34 @@ func (s *Server) Listen(network, addr string) (net.Listener, error) {
|
||||
return s.listen(network, addr, listenOnTailnet)
|
||||
}
|
||||
|
||||
// ListenPacket announces on the Tailscale network.
|
||||
//
|
||||
// The network must be "udp", "udp4" or "udp6". The addr must be of the form
|
||||
// "ip:port" (or "[ip]:port") where ip is a valid IPv4 or IPv6 address
|
||||
// corresponding to "udp4" or "udp6" respectively. IP must be specified.
|
||||
//
|
||||
// If s has not been started yet, it will be started.
|
||||
func (s *Server) ListenPacket(network, addr string) (net.PacketConn, error) {
|
||||
ap, err := resolveListenAddr(network, addr)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if !ap.Addr().IsValid() {
|
||||
return nil, fmt.Errorf("tsnet.ListenPacket(%q, %q): address must be a valid IP", network, addr)
|
||||
}
|
||||
if network == "udp" {
|
||||
if ap.Addr().Is4() {
|
||||
network = "udp4"
|
||||
} else {
|
||||
network = "udp6"
|
||||
}
|
||||
}
|
||||
if err := s.Start(); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return s.netstack.ListenPacket(network, ap.String())
|
||||
}
|
||||
|
||||
// ListenTLS announces only on the Tailscale network.
|
||||
// It returns a TLS listener wrapping the tsnet listener.
|
||||
// It will start the server if it has not been started yet.
|
||||
@@ -1070,50 +1109,65 @@ const (
|
||||
listenOnBoth = listenOn("listen-on-both")
|
||||
)
|
||||
|
||||
func (s *Server) listen(network, addr string, lnOn listenOn) (net.Listener, error) {
|
||||
switch network {
|
||||
case "", "tcp", "tcp4", "tcp6", "udp", "udp4", "udp6":
|
||||
default:
|
||||
return nil, errors.New("unsupported network type")
|
||||
}
|
||||
// resolveListenAddr resolves a network and address into a netip.AddrPort. The
|
||||
// returned netip.AddrPort.Addr will be the zero value if the address is empty.
|
||||
// The port must be a valid port number. The caller is responsible for checking
|
||||
// the network and address are valid.
|
||||
//
|
||||
// It resolves well-known port names and validates the address is a valid IP
|
||||
// literal for the network.
|
||||
func resolveListenAddr(network, addr string) (netip.AddrPort, error) {
|
||||
var zero netip.AddrPort
|
||||
host, portStr, err := net.SplitHostPort(addr)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("tsnet: %w", err)
|
||||
return zero, fmt.Errorf("tsnet: %w", err)
|
||||
}
|
||||
port, err := net.LookupPort(network, portStr)
|
||||
if err != nil || port < 0 || port > math.MaxUint16 {
|
||||
// LookupPort returns an error on out of range values so the bounds
|
||||
// checks on port should be unnecessary, but harmless. If they do
|
||||
// match, worst case this error message says "invalid port: <nil>".
|
||||
return nil, fmt.Errorf("invalid port: %w", err)
|
||||
return zero, fmt.Errorf("invalid port: %w", err)
|
||||
}
|
||||
var bindHostOrZero netip.Addr
|
||||
if host != "" {
|
||||
bindHostOrZero, err = netip.ParseAddr(host)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("invalid Listen addr %q; host part must be empty or IP literal", host)
|
||||
}
|
||||
if strings.HasSuffix(network, "4") && !bindHostOrZero.Is4() {
|
||||
return nil, fmt.Errorf("invalid non-IPv4 addr %v for network %q", host, network)
|
||||
}
|
||||
if strings.HasSuffix(network, "6") && !bindHostOrZero.Is6() {
|
||||
return nil, fmt.Errorf("invalid non-IPv6 addr %v for network %q", host, network)
|
||||
}
|
||||
if host == "" {
|
||||
return netip.AddrPortFrom(netip.Addr{}, uint16(port)), nil
|
||||
}
|
||||
|
||||
bindHostOrZero, err := netip.ParseAddr(host)
|
||||
if err != nil {
|
||||
return zero, fmt.Errorf("invalid Listen addr %q; host part must be empty or IP literal", host)
|
||||
}
|
||||
if strings.HasSuffix(network, "4") && !bindHostOrZero.Is4() {
|
||||
return zero, fmt.Errorf("invalid non-IPv4 addr %v for network %q", host, network)
|
||||
}
|
||||
if strings.HasSuffix(network, "6") && !bindHostOrZero.Is6() {
|
||||
return zero, fmt.Errorf("invalid non-IPv6 addr %v for network %q", host, network)
|
||||
}
|
||||
return netip.AddrPortFrom(bindHostOrZero, uint16(port)), nil
|
||||
}
|
||||
|
||||
func (s *Server) listen(network, addr string, lnOn listenOn) (*listener, error) {
|
||||
switch network {
|
||||
case "", "tcp", "tcp4", "tcp6", "udp", "udp4", "udp6":
|
||||
default:
|
||||
return nil, errors.New("unsupported network type")
|
||||
}
|
||||
host, err := resolveListenAddr(network, addr)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if err := s.Start(); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
var keys []listenKey
|
||||
switch lnOn {
|
||||
case listenOnTailnet:
|
||||
keys = append(keys, listenKey{network, bindHostOrZero, uint16(port), false})
|
||||
keys = append(keys, listenKey{network, host.Addr(), host.Port(), false})
|
||||
case listenOnFunnel:
|
||||
keys = append(keys, listenKey{network, bindHostOrZero, uint16(port), true})
|
||||
keys = append(keys, listenKey{network, host.Addr(), host.Port(), true})
|
||||
case listenOnBoth:
|
||||
keys = append(keys, listenKey{network, bindHostOrZero, uint16(port), false})
|
||||
keys = append(keys, listenKey{network, bindHostOrZero, uint16(port), true})
|
||||
keys = append(keys, listenKey{network, host.Addr(), host.Port(), false})
|
||||
keys = append(keys, listenKey{network, host.Addr(), host.Port(), true})
|
||||
}
|
||||
|
||||
ln := &listener{
|
||||
|
||||
@@ -745,3 +745,73 @@ func TestCapturePcap(t *testing.T) {
|
||||
t.Errorf("s2 pcap file size = %d, want > pcapHeaderSize(%d)", got, pcapHeaderSize)
|
||||
}
|
||||
}
|
||||
|
||||
func TestUDPConn(t *testing.T) {
|
||||
tstest.ResourceCheck(t)
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second)
|
||||
defer cancel()
|
||||
|
||||
controlURL, _ := startControl(t)
|
||||
s1, s1ip, _ := startServer(t, ctx, controlURL, "s1")
|
||||
s2, s2ip, _ := startServer(t, ctx, controlURL, "s2")
|
||||
|
||||
lc2, err := s2.LocalClient()
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
// ping to make sure the connection is up.
|
||||
res, err := lc2.Ping(ctx, s1ip, tailcfg.PingICMP)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
t.Logf("ping success: %#+v", res)
|
||||
|
||||
pc := must.Get(s1.ListenPacket("udp", fmt.Sprintf("%s:8081", s1ip)))
|
||||
defer pc.Close()
|
||||
|
||||
// Dial to s1 from s2
|
||||
w, err := s2.Dial(ctx, "udp", fmt.Sprintf("%s:8081", s1ip))
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
defer w.Close()
|
||||
|
||||
// Send a packet from s2 to s1
|
||||
want := "hello"
|
||||
if _, err := io.WriteString(w, want); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
// Receive the packet on s1
|
||||
got := make([]byte, 1024)
|
||||
n, from, err := pc.ReadFrom(got)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
got = got[:n]
|
||||
t.Logf("got: %q", got)
|
||||
if string(got) != want {
|
||||
t.Errorf("got %q, want %q", got, want)
|
||||
}
|
||||
if from.(*net.UDPAddr).AddrPort().Addr() != s2ip {
|
||||
t.Errorf("got from %v, want %v", from, s2ip)
|
||||
}
|
||||
|
||||
// Write a response back to s2
|
||||
if _, err := pc.WriteTo([]byte("world"), from); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
// Receive the response on s2
|
||||
got = make([]byte, 1024)
|
||||
n, err = w.Read(got)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
got = got[:n]
|
||||
t.Logf("got: %q", got)
|
||||
if string(got) != "world" {
|
||||
t.Errorf("got %q, want world", got)
|
||||
}
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user