Imported type and parsing, with minor modifications. Updates tailscale/corp#15043 Signed-off-by: James Tucker <james@tailscale.com>main
parent
d62af8e643
commit
96f01a73b1
@ -0,0 +1,160 @@ |
||||
// Copyright (c) Tailscale Inc & AUTHORS
|
||||
// SPDX-License-Identifier: BSD-3-Clause
|
||||
|
||||
package tailcfg |
||||
|
||||
import ( |
||||
"errors" |
||||
"fmt" |
||||
"strconv" |
||||
"strings" |
||||
|
||||
"tailscale.com/types/ipproto" |
||||
"tailscale.com/util/vizerror" |
||||
) |
||||
|
||||
// ProtoPortRange is used to encode "proto:port" format.
|
||||
// The following formats are supported:
|
||||
//
|
||||
// "*" allows all TCP, UDP and ICMP traffic on all ports.
|
||||
// "<ports>" allows all TCP, UDP and ICMP traffic on the specified ports.
|
||||
// "proto:*" allows traffic of the specified proto on all ports.
|
||||
// "proto:<port>" allows traffic of the specified proto on the specified port.
|
||||
//
|
||||
// Ports are either a single port number or a range of ports (e.g. "80-90").
|
||||
// String named protocols support names that ipproto.Proto accepts.
|
||||
type ProtoPortRange struct { |
||||
// Proto is the IP protocol number.
|
||||
// If Proto is 0, it means TCP+UDP+ICMP(4+6).
|
||||
Proto int |
||||
Ports PortRange |
||||
} |
||||
|
||||
func (ppr ProtoPortRange) String() string { |
||||
if ppr.Proto == 0 { |
||||
if ppr.Ports == PortRangeAny { |
||||
return "*" |
||||
} |
||||
} |
||||
var buf strings.Builder |
||||
if ppr.Proto != 0 { |
||||
// Proto.MarshalText is infallible.
|
||||
text, _ := ipproto.Proto(ppr.Proto).MarshalText() |
||||
buf.Write(text) |
||||
buf.Write([]byte(":")) |
||||
} |
||||
pr := ppr.Ports |
||||
if pr.First == pr.Last { |
||||
fmt.Fprintf(&buf, "%d", pr.First) |
||||
} else if pr == PortRangeAny { |
||||
buf.WriteByte('*') |
||||
} else { |
||||
fmt.Fprintf(&buf, "%d-%d", pr.First, pr.Last) |
||||
} |
||||
return buf.String() |
||||
} |
||||
|
||||
// ParseProtoPortRanges parses a slice of IP port range fields.
|
||||
func ParseProtoPortRanges(ips []string) ([]ProtoPortRange, error) { |
||||
var out []ProtoPortRange |
||||
for _, p := range ips { |
||||
ppr, err := parseProtoPortRange(p) |
||||
if err != nil { |
||||
return nil, err |
||||
} |
||||
out = append(out, *ppr) |
||||
} |
||||
return out, nil |
||||
} |
||||
|
||||
func parseProtoPortRange(ipProtoPort string) (*ProtoPortRange, error) { |
||||
if ipProtoPort == "" { |
||||
return nil, errors.New("empty string") |
||||
} |
||||
if ipProtoPort == "*" { |
||||
return &ProtoPortRange{Ports: PortRangeAny}, nil |
||||
} |
||||
if !strings.Contains(ipProtoPort, ":") { |
||||
ipProtoPort = "*:" + ipProtoPort |
||||
} |
||||
protoStr, portRange, err := parseHostPortRange(ipProtoPort) |
||||
if err != nil { |
||||
return nil, err |
||||
} |
||||
if protoStr == "" { |
||||
return nil, errors.New("empty protocol") |
||||
} |
||||
|
||||
ppr := &ProtoPortRange{ |
||||
Ports: portRange, |
||||
} |
||||
if protoStr == "*" { |
||||
return ppr, nil |
||||
} |
||||
var ipProto ipproto.Proto |
||||
if err := ipProto.UnmarshalText([]byte(protoStr)); err != nil { |
||||
return nil, err |
||||
} |
||||
ppr.Proto = int(ipProto) |
||||
return ppr, nil |
||||
} |
||||
|
||||
// parseHostPortRange parses hostport as HOST:PORTS where HOST is
|
||||
// returned unchanged and PORTS is is either "*" or PORTLOW-PORTHIGH ranges.
|
||||
func parseHostPortRange(hostport string) (host string, ports PortRange, err error) { |
||||
hostport = strings.ToLower(hostport) |
||||
colon := strings.LastIndexByte(hostport, ':') |
||||
if colon < 0 { |
||||
return "", ports, vizerror.New("hostport must contain a colon (\":\")") |
||||
} |
||||
host = hostport[:colon] |
||||
portlist := hostport[colon+1:] |
||||
|
||||
if strings.Contains(host, ",") { |
||||
return "", ports, vizerror.New("host cannot contain a comma (\",\")") |
||||
} |
||||
|
||||
if portlist == "*" { |
||||
// Special case: permit hostname:* as a port wildcard.
|
||||
return host, PortRangeAny, nil |
||||
} |
||||
|
||||
if len(portlist) == 0 { |
||||
return "", ports, vizerror.Errorf("invalid port list: %#v", portlist) |
||||
} |
||||
|
||||
if strings.Count(portlist, "-") > 1 { |
||||
return "", ports, vizerror.Errorf("port range %#v: too many dashes(-)", portlist) |
||||
} |
||||
|
||||
firstStr, lastStr, isRange := strings.Cut(portlist, "-") |
||||
|
||||
var first, last uint64 |
||||
first, err = strconv.ParseUint(firstStr, 10, 16) |
||||
if err != nil { |
||||
return "", ports, vizerror.Errorf("port range %#v: invalid first integer", portlist) |
||||
} |
||||
|
||||
if isRange { |
||||
last, err = strconv.ParseUint(lastStr, 10, 16) |
||||
if err != nil { |
||||
return "", ports, vizerror.Errorf("port range %#v: invalid last integer", portlist) |
||||
} |
||||
} else { |
||||
last = first |
||||
} |
||||
|
||||
if first == 0 { |
||||
return "", ports, vizerror.Errorf("port range %#v: first port must be >0, or use '*' for wildcard", portlist) |
||||
} |
||||
|
||||
if first > last { |
||||
return "", ports, vizerror.Errorf("port range %#v: first port must be >= last port", portlist) |
||||
} |
||||
|
||||
return host, newPortRange(uint16(first), uint16(last)), nil |
||||
} |
||||
|
||||
func newPortRange(first, last uint16) PortRange { |
||||
return PortRange{First: first, Last: last} |
||||
} |
||||
@ -0,0 +1,90 @@ |
||||
// Copyright (c) Tailscale Inc & AUTHORS
|
||||
// SPDX-License-Identifier: BSD-3-Clause
|
||||
|
||||
package tailcfg |
||||
|
||||
import ( |
||||
"errors" |
||||
"testing" |
||||
|
||||
"tailscale.com/types/ipproto" |
||||
) |
||||
|
||||
func TestProtoPortRangeParsing(t *testing.T) { |
||||
pr := func(s, e uint16) PortRange { |
||||
return PortRange{First: s, Last: e} |
||||
} |
||||
tests := []struct { |
||||
in string |
||||
out ProtoPortRange |
||||
err error |
||||
}{ |
||||
{in: "tcp:80", out: ProtoPortRange{Proto: int(ipproto.TCP), Ports: pr(80, 80)}}, |
||||
{in: "80", out: ProtoPortRange{Ports: pr(80, 80)}}, |
||||
{in: "*", out: ProtoPortRange{Ports: PortRangeAny}}, |
||||
{in: "*:*", out: ProtoPortRange{Ports: PortRangeAny}}, |
||||
{in: "tcp:*", out: ProtoPortRange{Proto: int(ipproto.TCP), Ports: PortRangeAny}}, |
||||
{ |
||||
in: "tcp:", |
||||
err: errors.New(`invalid port list: ""`), |
||||
}, |
||||
{ |
||||
in: ":80", |
||||
err: errors.New(`empty protocol`), |
||||
}, |
||||
{ |
||||
in: "", |
||||
err: errors.New(`empty string`), |
||||
}, |
||||
} |
||||
|
||||
for _, tc := range tests { |
||||
t.Run(tc.in, func(t *testing.T) { |
||||
ppr, err := parseProtoPortRange(tc.in) |
||||
if gotErr, wantErr := err != nil, tc.err != nil; gotErr != wantErr { |
||||
t.Fatalf("got err %v; want %v", err, tc.err) |
||||
} else if gotErr { |
||||
if err.Error() != tc.err.Error() { |
||||
t.Fatalf("got err %q; want %q", err, tc.err) |
||||
} |
||||
return |
||||
} |
||||
if *ppr != tc.out { |
||||
t.Fatalf("got %v; want %v", ppr, tc.out) |
||||
} |
||||
}) |
||||
} |
||||
} |
||||
|
||||
func TestProtoPortRangeString(t *testing.T) { |
||||
tests := []struct { |
||||
input ProtoPortRange |
||||
want string |
||||
}{ |
||||
{ProtoPortRange{}, "0"}, |
||||
|
||||
// Zero protocol.
|
||||
{ProtoPortRange{Ports: PortRangeAny}, "*"}, |
||||
{ProtoPortRange{Ports: PortRange{23, 23}}, "23"}, |
||||
{ProtoPortRange{Ports: PortRange{80, 120}}, "80-120"}, |
||||
|
||||
// Non-zero unnamed protocol.
|
||||
{ProtoPortRange{Proto: 100, Ports: PortRange{80, 80}}, "100:80"}, |
||||
{ProtoPortRange{Proto: 200, Ports: PortRange{101, 105}}, "200:101-105"}, |
||||
|
||||
// Non-zero named protocol.
|
||||
{ProtoPortRange{Proto: 1, Ports: PortRangeAny}, "icmp:*"}, |
||||
{ProtoPortRange{Proto: 2, Ports: PortRangeAny}, "igmp:*"}, |
||||
{ProtoPortRange{Proto: 6, Ports: PortRange{10, 13}}, "tcp:10-13"}, |
||||
{ProtoPortRange{Proto: 17, Ports: PortRangeAny}, "udp:*"}, |
||||
{ProtoPortRange{Proto: 0x84, Ports: PortRange{999, 999}}, "sctp:999"}, |
||||
{ProtoPortRange{Proto: 0x3a, Ports: PortRangeAny}, "ipv6-icmp:*"}, |
||||
{ProtoPortRange{Proto: 0x21, Ports: PortRangeAny}, "dccp:*"}, |
||||
{ProtoPortRange{Proto: 0x2f, Ports: PortRangeAny}, "gre:*"}, |
||||
} |
||||
for _, tc := range tests { |
||||
if got := tc.input.String(); got != tc.want { |
||||
t.Errorf("String for %v: got %q, want %q", tc.input, got, tc.want) |
||||
} |
||||
} |
||||
} |
||||
Loading…
Reference in new issue