cmd/tailscale/cli: prevent all dup flags, not just strings

The earlier #15534 prevent some dup string flags. This does it for all
flag types.

Updates #6813

Change-Id: Iec2871448394ea9a5b604310bdbf7b499434bf01
Signed-off-by: Brad Fitzpatrick <bradfitz@tailscale.com>
This commit is contained in:
Brad Fitzpatrick
2025-04-08 12:34:18 -07:00
committed by Brad Fitzpatrick
parent 8597b25840
commit 79ff067db3
5 changed files with 101 additions and 100 deletions
+36
View File
@@ -165,6 +165,41 @@ func Run(args []string) (err error) {
return err return err
} }
type onceFlagValue struct {
flag.Value
set bool
}
func (v *onceFlagValue) Set(s string) error {
if v.set {
return fmt.Errorf("flag provided multiple times")
}
v.set = true
return v.Value.Set(s)
}
func (v *onceFlagValue) IsBoolFlag() bool {
type boolFlag interface {
IsBoolFlag() bool
}
bf, ok := v.Value.(boolFlag)
return ok && bf.IsBoolFlag()
}
// noDupFlagify modifies c recursively to make all the
// flag values be wrappers that permit setting the value
// at most once.
func noDupFlagify(c *ffcli.Command) {
if c.FlagSet != nil {
c.FlagSet.VisitAll(func(f *flag.Flag) {
f.Value = &onceFlagValue{Value: f.Value}
})
}
for _, sub := range c.Subcommands {
noDupFlagify(sub)
}
}
func newRootCmd() *ffcli.Command { func newRootCmd() *ffcli.Command {
rootfs := newFlagSet("tailscale") rootfs := newFlagSet("tailscale")
rootfs.Func("socket", "path to tailscaled socket", func(s string) error { rootfs.Func("socket", "path to tailscaled socket", func(s string) error {
@@ -236,6 +271,7 @@ change in the future.
}) })
ffcomplete.Inject(rootCmd, func(c *ffcli.Command) { c.LongHelp = hidden + c.LongHelp }, usageFunc) ffcomplete.Inject(rootCmd, func(c *ffcli.Command) { c.LongHelp = hidden + c.LongHelp }, usageFunc)
noDupFlagify(rootCmd)
return rootCmd return rootCmd
} }
+52 -68
View File
@@ -657,13 +657,6 @@ func upArgsFromOSArgs(goos string, flagArgs ...string) (args upArgsT) {
return return
} }
func newSingleUseStringForTest(v string) singleUseStringFlag {
return singleUseStringFlag{
set: true,
value: v,
}
}
func TestPrefsFromUpArgs(t *testing.T) { func TestPrefsFromUpArgs(t *testing.T) {
tests := []struct { tests := []struct {
name string name string
@@ -728,14 +721,14 @@ func TestPrefsFromUpArgs(t *testing.T) {
{ {
name: "error_advertise_route_invalid_ip", name: "error_advertise_route_invalid_ip",
args: upArgsT{ args: upArgsT{
advertiseRoutes: newSingleUseStringForTest("foo"), advertiseRoutes: "foo",
}, },
wantErr: `"foo" is not a valid IP address or CIDR prefix`, wantErr: `"foo" is not a valid IP address or CIDR prefix`,
}, },
{ {
name: "error_advertise_route_unmasked_bits", name: "error_advertise_route_unmasked_bits",
args: upArgsT{ args: upArgsT{
advertiseRoutes: newSingleUseStringForTest("1.2.3.4/16"), advertiseRoutes: "1.2.3.4/16",
}, },
wantErr: `1.2.3.4/16 has non-address bits set; expected 1.2.0.0/16`, wantErr: `1.2.3.4/16 has non-address bits set; expected 1.2.0.0/16`,
}, },
@@ -756,7 +749,7 @@ func TestPrefsFromUpArgs(t *testing.T) {
{ {
name: "error_tag_prefix", name: "error_tag_prefix",
args: upArgsT{ args: upArgsT{
advertiseTags: newSingleUseStringForTest("foo"), advertiseTags: "foo",
}, },
wantErr: `tag: "foo": tags must start with 'tag:'`, wantErr: `tag: "foo": tags must start with 'tag:'`,
}, },
@@ -836,7 +829,7 @@ func TestPrefsFromUpArgs(t *testing.T) {
name: "via_route_good", name: "via_route_good",
goos: "linux", goos: "linux",
args: upArgsT{ args: upArgsT{
advertiseRoutes: newSingleUseStringForTest("fd7a:115c:a1e0:b1a::bb:10.0.0.0/112"), advertiseRoutes: "fd7a:115c:a1e0:b1a::bb:10.0.0.0/112",
netfilterMode: "off", netfilterMode: "off",
}, },
want: &ipn.Prefs{ want: &ipn.Prefs{
@@ -855,7 +848,7 @@ func TestPrefsFromUpArgs(t *testing.T) {
name: "via_route_good_16_bit", name: "via_route_good_16_bit",
goos: "linux", goos: "linux",
args: upArgsT{ args: upArgsT{
advertiseRoutes: newSingleUseStringForTest("fd7a:115c:a1e0:b1a::aabb:10.0.0.0/112"), advertiseRoutes: "fd7a:115c:a1e0:b1a::aabb:10.0.0.0/112",
netfilterMode: "off", netfilterMode: "off",
}, },
want: &ipn.Prefs{ want: &ipn.Prefs{
@@ -874,7 +867,7 @@ func TestPrefsFromUpArgs(t *testing.T) {
name: "via_route_short_prefix", name: "via_route_short_prefix",
goos: "linux", goos: "linux",
args: upArgsT{ args: upArgsT{
advertiseRoutes: newSingleUseStringForTest("fd7a:115c:a1e0:b1a::/64"), advertiseRoutes: "fd7a:115c:a1e0:b1a::/64",
netfilterMode: "off", netfilterMode: "off",
}, },
wantErr: "fd7a:115c:a1e0:b1a::/64 4-in-6 prefix must be at least a /96", wantErr: "fd7a:115c:a1e0:b1a::/64 4-in-6 prefix must be at least a /96",
@@ -883,7 +876,7 @@ func TestPrefsFromUpArgs(t *testing.T) {
name: "via_route_short_reserved_siteid", name: "via_route_short_reserved_siteid",
goos: "linux", goos: "linux",
args: upArgsT{ args: upArgsT{
advertiseRoutes: newSingleUseStringForTest("fd7a:115c:a1e0:b1a:1234:5678::/112"), advertiseRoutes: "fd7a:115c:a1e0:b1a:1234:5678::/112",
netfilterMode: "off", netfilterMode: "off",
}, },
wantErr: "route fd7a:115c:a1e0:b1a:1234:5678::/112 contains invalid site ID 12345678; must be 0xffff or less", wantErr: "route fd7a:115c:a1e0:b1a:1234:5678::/112 contains invalid site ID 12345678; must be 0xffff or less",
@@ -1113,7 +1106,6 @@ func TestUpdatePrefs(t *testing.T) {
}, },
env: upCheckEnv{backendState: "Running"}, env: upCheckEnv{backendState: "Running"},
}, },
{ {
// Issue 3808: explicitly empty --operator= should clear value. // Issue 3808: explicitly empty --operator= should clear value.
name: "explicit_empty_operator", name: "explicit_empty_operator",
@@ -1507,6 +1499,51 @@ func TestParseNLArgs(t *testing.T) {
} }
} }
// makeQuietContinueOnError modifies c recursively to make all the
// flagsets have error mode flag.ContinueOnError and not
// spew all over stderr.
func makeQuietContinueOnError(c *ffcli.Command) {
if c.FlagSet != nil {
c.FlagSet.Init(c.Name, flag.ContinueOnError)
c.FlagSet.Usage = func() {}
c.FlagSet.SetOutput(io.Discard)
}
c.UsageFunc = func(*ffcli.Command) string { return "" }
for _, sub := range c.Subcommands {
makeQuietContinueOnError(sub)
}
}
// see tailscale/tailscale#6813
func TestNoDups(t *testing.T) {
tests := []struct {
name string
args []string
want string
}{
{
name: "dup-boolean",
args: []string{"up", "--json", "--json"},
want: "error parsing commandline arguments: invalid boolean flag json: flag provided multiple times",
},
{
name: "dup-string",
args: []string{"up", "--hostname=foo", "--hostname=bar"},
want: "error parsing commandline arguments: invalid value \"bar\" for flag -hostname: flag provided multiple times",
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
cmd := newRootCmd()
makeQuietContinueOnError(cmd)
err := cmd.Parse(tt.args)
if got := fmt.Sprint(err); got != tt.want {
t.Errorf("got %q, want %q", got, tt.want)
}
})
}
}
func TestHelpAlias(t *testing.T) { func TestHelpAlias(t *testing.T) {
var stdout, stderr bytes.Buffer var stdout, stderr bytes.Buffer
tstest.Replace[io.Writer](t, &Stdout, &stdout) tstest.Replace[io.Writer](t, &Stdout, &stdout)
@@ -1606,56 +1643,3 @@ func TestDepsNoCapture(t *testing.T) {
}.Check(t) }.Check(t)
} }
func TestSingleUseStringFlag(t *testing.T) {
tests := []struct {
name string
setValues []string
wantValue string
wantErr bool
}{
{
name: "set once",
setValues: []string{"foo"},
wantValue: "foo",
wantErr: false,
},
{
name: "set twice",
setValues: []string{"foo", "bar"},
wantValue: "foo",
wantErr: true,
},
{
name: "set nothing",
setValues: []string{},
wantValue: "",
wantErr: false,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
var flag singleUseStringFlag
var lastErr error
for _, val := range tt.setValues {
lastErr = flag.Set(val)
}
if tt.wantErr {
if lastErr == nil {
t.Errorf("expected error on final Set, got nil")
}
} else {
if lastErr != nil {
t.Errorf("unexpected error on final Set: %v", lastErr)
}
}
if got := flag.String(); got != tt.wantValue {
t.Errorf("String() = %q, want %q", got, tt.wantValue)
}
})
}
}
+4 -4
View File
@@ -49,7 +49,7 @@ type setArgsT struct {
runSSH bool runSSH bool
runWebClient bool runWebClient bool
hostname string hostname string
advertiseRoutes singleUseStringFlag advertiseRoutes string
advertiseDefaultRoute bool advertiseDefaultRoute bool
advertiseConnector bool advertiseConnector bool
opUser string opUser string
@@ -75,7 +75,7 @@ func newSetFlagSet(goos string, setArgs *setArgsT) *flag.FlagSet {
setf.BoolVar(&setArgs.shieldsUp, "shields-up", false, "don't allow incoming connections") setf.BoolVar(&setArgs.shieldsUp, "shields-up", false, "don't allow incoming connections")
setf.BoolVar(&setArgs.runSSH, "ssh", false, "run an SSH server, permitting access per tailnet admin's declared policy") setf.BoolVar(&setArgs.runSSH, "ssh", false, "run an SSH server, permitting access per tailnet admin's declared policy")
setf.StringVar(&setArgs.hostname, "hostname", "", "hostname to use instead of the one provided by the OS") setf.StringVar(&setArgs.hostname, "hostname", "", "hostname to use instead of the one provided by the OS")
setf.Var(&setArgs.advertiseRoutes, "advertise-routes", "routes to advertise to other nodes (comma-separated, e.g. \"10.0.0.0/8,192.168.0.0/24\") or empty string to not advertise routes") setf.StringVar(&setArgs.advertiseRoutes, "advertise-routes", "", "routes to advertise to other nodes (comma-separated, e.g. \"10.0.0.0/8,192.168.0.0/24\") or empty string to not advertise routes")
setf.BoolVar(&setArgs.advertiseDefaultRoute, "advertise-exit-node", false, "offer to be an exit node for internet traffic for the tailnet") setf.BoolVar(&setArgs.advertiseDefaultRoute, "advertise-exit-node", false, "offer to be an exit node for internet traffic for the tailnet")
setf.BoolVar(&setArgs.advertiseConnector, "advertise-connector", false, "offer to be an app connector for domain specific internet traffic for the tailnet") setf.BoolVar(&setArgs.advertiseConnector, "advertise-connector", false, "offer to be an app connector for domain specific internet traffic for the tailnet")
setf.BoolVar(&setArgs.updateCheck, "update-check", true, "notify about available Tailscale updates") setf.BoolVar(&setArgs.updateCheck, "update-check", true, "notify about available Tailscale updates")
@@ -259,11 +259,11 @@ func runSet(ctx context.Context, args []string) (retErr error) {
// setArgs is the parsed command-line arguments. // setArgs is the parsed command-line arguments.
func calcAdvertiseRoutesForSet(advertiseExitNodeSet, advertiseRoutesSet bool, curPrefs *ipn.Prefs, setArgs setArgsT) (routes []netip.Prefix, err error) { func calcAdvertiseRoutesForSet(advertiseExitNodeSet, advertiseRoutesSet bool, curPrefs *ipn.Prefs, setArgs setArgsT) (routes []netip.Prefix, err error) {
if advertiseExitNodeSet && advertiseRoutesSet { if advertiseExitNodeSet && advertiseRoutesSet {
return netutil.CalcAdvertiseRoutes(setArgs.advertiseRoutes.String(), setArgs.advertiseDefaultRoute) return netutil.CalcAdvertiseRoutes(setArgs.advertiseRoutes, setArgs.advertiseDefaultRoute)
} }
if advertiseRoutesSet { if advertiseRoutesSet {
return netutil.CalcAdvertiseRoutes(setArgs.advertiseRoutes.String(), curPrefs.AdvertisesExitNode()) return netutil.CalcAdvertiseRoutes(setArgs.advertiseRoutes, curPrefs.AdvertisesExitNode())
} }
if advertiseExitNodeSet { if advertiseExitNodeSet {
alreadyAdvertisesExitNode := curPrefs.AdvertisesExitNode() alreadyAdvertisesExitNode := curPrefs.AdvertisesExitNode()
+1 -1
View File
@@ -116,7 +116,7 @@ func TestCalcAdvertiseRoutesForSet(t *testing.T) {
sa.advertiseDefaultRoute = *tc.setExit sa.advertiseDefaultRoute = *tc.setExit
} }
if tc.setRoutes != nil { if tc.setRoutes != nil {
sa.advertiseRoutes = newSingleUseStringForTest(*tc.setRoutes) sa.advertiseRoutes = *tc.setRoutes
} }
got, err := calcAdvertiseRoutesForSet(tc.setExit != nil, tc.setRoutes != nil, curPrefs, sa) got, err := calcAdvertiseRoutesForSet(tc.setExit != nil, tc.setRoutes != nil, curPrefs, sa)
if err != nil { if err != nil {
+8 -27
View File
@@ -82,25 +82,6 @@ func acceptRouteDefault(goos string) bool {
return p.DefaultRouteAll(goos) return p.DefaultRouteAll(goos)
} }
// singleUseStringFlag will throw an error if the flag is specified more than once.
type singleUseStringFlag struct {
set bool
value string
}
func (s singleUseStringFlag) String() string {
return s.value
}
func (s *singleUseStringFlag) Set(v string) error {
if s.set {
return errors.New("flag can only be specified once")
}
s.set = true
s.value = v
return nil
}
var upFlagSet = newUpFlagSet(effectiveGOOS(), &upArgsGlobal, "up") var upFlagSet = newUpFlagSet(effectiveGOOS(), &upArgsGlobal, "up")
// newUpFlagSet returns a new flag set for the "up" and "login" commands. // newUpFlagSet returns a new flag set for the "up" and "login" commands.
@@ -123,9 +104,9 @@ func newUpFlagSet(goos string, upArgs *upArgsT, cmd string) *flag.FlagSet {
upf.BoolVar(&upArgs.exitNodeAllowLANAccess, "exit-node-allow-lan-access", false, "Allow direct access to the local network when routing traffic via an exit node") upf.BoolVar(&upArgs.exitNodeAllowLANAccess, "exit-node-allow-lan-access", false, "Allow direct access to the local network when routing traffic via an exit node")
upf.BoolVar(&upArgs.shieldsUp, "shields-up", false, "don't allow incoming connections") upf.BoolVar(&upArgs.shieldsUp, "shields-up", false, "don't allow incoming connections")
upf.BoolVar(&upArgs.runSSH, "ssh", false, "run an SSH server, permitting access per tailnet admin's declared policy") upf.BoolVar(&upArgs.runSSH, "ssh", false, "run an SSH server, permitting access per tailnet admin's declared policy")
upf.Var(&upArgs.advertiseTags, "advertise-tags", "comma-separated ACL tags to request; each must start with \"tag:\" (e.g. \"tag:eng,tag:montreal,tag:ssh\")") upf.StringVar(&upArgs.advertiseTags, "advertise-tags", "", "comma-separated ACL tags to request; each must start with \"tag:\" (e.g. \"tag:eng,tag:montreal,tag:ssh\")")
upf.StringVar(&upArgs.hostname, "hostname", "", "hostname to use instead of the one provided by the OS") upf.StringVar(&upArgs.hostname, "hostname", "", "hostname to use instead of the one provided by the OS")
upf.Var(&upArgs.advertiseRoutes, "advertise-routes", "routes to advertise to other nodes (comma-separated, e.g. \"10.0.0.0/8,192.168.0.0/24\") or empty string to not advertise routes") upf.StringVar(&upArgs.advertiseRoutes, "advertise-routes", "", "routes to advertise to other nodes (comma-separated, e.g. \"10.0.0.0/8,192.168.0.0/24\") or empty string to not advertise routes")
upf.BoolVar(&upArgs.advertiseConnector, "advertise-connector", false, "advertise this node as an app connector") upf.BoolVar(&upArgs.advertiseConnector, "advertise-connector", false, "advertise this node as an app connector")
upf.BoolVar(&upArgs.advertiseDefaultRoute, "advertise-exit-node", false, "offer to be an exit node for internet traffic for the tailnet") upf.BoolVar(&upArgs.advertiseDefaultRoute, "advertise-exit-node", false, "offer to be an exit node for internet traffic for the tailnet")
upf.BoolVar(&upArgs.postureChecking, "posture-checking", false, hidden+"allow management plane to gather device posture information") upf.BoolVar(&upArgs.postureChecking, "posture-checking", false, hidden+"allow management plane to gather device posture information")
@@ -193,9 +174,9 @@ type upArgsT struct {
runWebClient bool runWebClient bool
forceReauth bool forceReauth bool
forceDaemon bool forceDaemon bool
advertiseRoutes singleUseStringFlag advertiseRoutes string
advertiseDefaultRoute bool advertiseDefaultRoute bool
advertiseTags singleUseStringFlag advertiseTags string
advertiseConnector bool advertiseConnector bool
snat bool snat bool
statefulFiltering bool statefulFiltering bool
@@ -263,7 +244,7 @@ func warnf(format string, args ...any) {
// function exists for testing and should have no side effects or // function exists for testing and should have no side effects or
// outside interactions (e.g. no making Tailscale LocalAPI calls). // outside interactions (e.g. no making Tailscale LocalAPI calls).
func prefsFromUpArgs(upArgs upArgsT, warnf logger.Logf, st *ipnstate.Status, goos string) (*ipn.Prefs, error) { func prefsFromUpArgs(upArgs upArgsT, warnf logger.Logf, st *ipnstate.Status, goos string) (*ipn.Prefs, error) {
routes, err := netutil.CalcAdvertiseRoutes(upArgs.advertiseRoutes.String(), upArgs.advertiseDefaultRoute) routes, err := netutil.CalcAdvertiseRoutes(upArgs.advertiseRoutes, upArgs.advertiseDefaultRoute)
if err != nil { if err != nil {
return nil, err return nil, err
} }
@@ -273,8 +254,8 @@ func prefsFromUpArgs(upArgs upArgsT, warnf logger.Logf, st *ipnstate.Status, goo
} }
var tags []string var tags []string
if upArgs.advertiseTags.String() != "" { if upArgs.advertiseTags != "" {
tags = strings.Split(upArgs.advertiseTags.String(), ",") tags = strings.Split(upArgs.advertiseTags, ",")
for _, tag := range tags { for _, tag := range tags {
err := tailcfg.CheckTag(tag) err := tailcfg.CheckTag(tag)
if err != nil { if err != nil {
@@ -574,7 +555,7 @@ func runUp(ctx context.Context, cmd string, args []string, upArgs upArgsT) (retE
if err != nil { if err != nil {
return err return err
} }
authKey, err = resolveAuthKey(ctx, authKey, upArgs.advertiseTags.String()) authKey, err = resolveAuthKey(ctx, authKey, upArgs.advertiseTags)
if err != nil { if err != nil {
return err return err
} }