control/controlclient,ipn/ipnlocal: replace State enum with boolean flags
Remove the State enum (StateNew, StateNotAuthenticated, etc.) from controlclient and replace it with two explicit boolean fields: - LoginFinished: indicates successful authentication - Synced: indicates we've received at least one netmap This makes the state more composable and easier to reason about, as multiple conditions can be true independently rather than being encoded in a single enum value. The State enum was originally intended as the state machine for the whole client, but that abstraction moved to ipn.Backend long ago. This change continues moving away from the legacy state machine by representing state as a combination of independent facts. Also adds test helpers in ipnlocal that check independent, observable facts (hasValidNetMap, needsLogin, etc.) rather than relying on derived state enums, making tests more robust. Updates #12639 Signed-off-by: James Tucker <james@tailscale.com>
This commit is contained in:
committed by
James Tucker
parent
c5919b4ed1
commit
a96ef432cf
@@ -138,7 +138,6 @@ type Auto struct {
|
||||
loggedIn bool // true if currently logged in
|
||||
loginGoal *LoginGoal // non-nil if some login activity is desired
|
||||
inMapPoll bool // true once we get the first MapResponse in a stream; false when HTTP response ends
|
||||
state State // TODO(bradfitz): delete this, make it computed by method from other state
|
||||
|
||||
authCtx context.Context // context used for auth requests
|
||||
mapCtx context.Context // context used for netmap and update requests
|
||||
@@ -296,10 +295,11 @@ func (c *Auto) authRoutine() {
|
||||
c.mu.Lock()
|
||||
goal := c.loginGoal
|
||||
ctx := c.authCtx
|
||||
loggedIn := c.loggedIn
|
||||
if goal != nil {
|
||||
c.logf("[v1] authRoutine: %s; wantLoggedIn=%v", c.state, true)
|
||||
c.logf("[v1] authRoutine: loggedIn=%v; wantLoggedIn=%v", loggedIn, true)
|
||||
} else {
|
||||
c.logf("[v1] authRoutine: %s; goal=nil paused=%v", c.state, c.paused)
|
||||
c.logf("[v1] authRoutine: loggedIn=%v; goal=nil paused=%v", loggedIn, c.paused)
|
||||
}
|
||||
c.mu.Unlock()
|
||||
|
||||
@@ -322,11 +322,6 @@ func (c *Auto) authRoutine() {
|
||||
|
||||
c.mu.Lock()
|
||||
c.urlToVisit = goal.url
|
||||
if goal.url != "" {
|
||||
c.state = StateURLVisitRequired
|
||||
} else {
|
||||
c.state = StateAuthenticating
|
||||
}
|
||||
c.mu.Unlock()
|
||||
|
||||
var url string
|
||||
@@ -360,7 +355,6 @@ func (c *Auto) authRoutine() {
|
||||
flags: LoginDefault,
|
||||
url: url,
|
||||
}
|
||||
c.state = StateURLVisitRequired
|
||||
c.mu.Unlock()
|
||||
|
||||
c.sendStatus("authRoutine-url", err, url, nil)
|
||||
@@ -380,7 +374,6 @@ func (c *Auto) authRoutine() {
|
||||
c.urlToVisit = ""
|
||||
c.loggedIn = true
|
||||
c.loginGoal = nil
|
||||
c.state = StateAuthenticated
|
||||
c.mu.Unlock()
|
||||
|
||||
c.sendStatus("authRoutine-success", nil, "", nil)
|
||||
@@ -431,12 +424,9 @@ func (mrs mapRoutineState) UpdateFullNetmap(nm *netmap.NetworkMap) {
|
||||
|
||||
c.mu.Lock()
|
||||
c.inMapPoll = true
|
||||
if c.loggedIn {
|
||||
c.state = StateSynchronized
|
||||
}
|
||||
c.expiry = nm.Expiry
|
||||
stillAuthed := c.loggedIn
|
||||
c.logf("[v1] mapRoutine: netmap received: %s", c.state)
|
||||
c.logf("[v1] mapRoutine: netmap received: loggedIn=%v inMapPoll=true", stillAuthed)
|
||||
c.mu.Unlock()
|
||||
|
||||
if stillAuthed {
|
||||
@@ -484,8 +474,8 @@ func (c *Auto) mapRoutine() {
|
||||
}
|
||||
|
||||
c.mu.Lock()
|
||||
c.logf("[v1] mapRoutine: %s", c.state)
|
||||
loggedIn := c.loggedIn
|
||||
c.logf("[v1] mapRoutine: loggedIn=%v", loggedIn)
|
||||
ctx := c.mapCtx
|
||||
c.mu.Unlock()
|
||||
|
||||
@@ -516,9 +506,6 @@ func (c *Auto) mapRoutine() {
|
||||
c.direct.health.SetOutOfPollNetMap()
|
||||
c.mu.Lock()
|
||||
c.inMapPoll = false
|
||||
if c.state == StateSynchronized {
|
||||
c.state = StateAuthenticated
|
||||
}
|
||||
paused := c.paused
|
||||
c.mu.Unlock()
|
||||
|
||||
@@ -584,12 +571,12 @@ func (c *Auto) sendStatus(who string, err error, url string, nm *netmap.NetworkM
|
||||
c.mu.Unlock()
|
||||
return
|
||||
}
|
||||
state := c.state
|
||||
loggedIn := c.loggedIn
|
||||
inMapPoll := c.inMapPoll
|
||||
loginGoal := c.loginGoal
|
||||
c.mu.Unlock()
|
||||
|
||||
c.logf("[v1] sendStatus: %s: %v", who, state)
|
||||
c.logf("[v1] sendStatus: %s: loggedIn=%v inMapPoll=%v", who, loggedIn, inMapPoll)
|
||||
|
||||
var p persist.PersistView
|
||||
if nm != nil && loggedIn && inMapPoll {
|
||||
@@ -600,11 +587,12 @@ func (c *Auto) sendStatus(who string, err error, url string, nm *netmap.NetworkM
|
||||
nm = nil
|
||||
}
|
||||
newSt := &Status{
|
||||
URL: url,
|
||||
Persist: p,
|
||||
NetMap: nm,
|
||||
Err: err,
|
||||
state: state,
|
||||
URL: url,
|
||||
Persist: p,
|
||||
NetMap: nm,
|
||||
Err: err,
|
||||
LoggedIn: loggedIn && loginGoal == nil,
|
||||
InMapPoll: inMapPoll,
|
||||
}
|
||||
|
||||
if c.observer == nil {
|
||||
@@ -667,14 +655,15 @@ func canSkipStatus(s1, s2 *Status) bool {
|
||||
// we can't skip it.
|
||||
return false
|
||||
}
|
||||
if s1.Err != nil || s1.URL != "" {
|
||||
// If s1 has an error or a URL, we shouldn't skip it, lest the error go
|
||||
// away in s2 or in-between. We want to make sure all the subsystems see
|
||||
// it. Plus there aren't many of these, so not worth skipping.
|
||||
if s1.Err != nil || s1.URL != "" || s1.LoggedIn {
|
||||
// If s1 has an error, a URL, or LoginFinished set, we shouldn't skip it,
|
||||
// lest the error go away in s2 or in-between. We want to make sure all
|
||||
// the subsystems see it. Plus there aren't many of these, so not worth
|
||||
// skipping.
|
||||
return false
|
||||
}
|
||||
if !s1.Persist.Equals(s2.Persist) || s1.state != s2.state {
|
||||
// If s1 has a different Persist or state than s2,
|
||||
if !s1.Persist.Equals(s2.Persist) || s1.LoggedIn != s2.LoggedIn || s1.InMapPoll != s2.InMapPoll || s1.URL != s2.URL {
|
||||
// If s1 has a different Persist, LoginFinished, Synced, or URL than s2,
|
||||
// don't skip it. We only care about skipping the typical
|
||||
// entries where the only difference is the NetMap.
|
||||
return false
|
||||
@@ -736,7 +725,6 @@ func (c *Auto) Logout(ctx context.Context) error {
|
||||
}
|
||||
c.mu.Lock()
|
||||
c.loggedIn = false
|
||||
c.state = StateNotAuthenticated
|
||||
c.cancelAuthCtxLocked()
|
||||
c.cancelMapCtxLocked()
|
||||
c.mu.Unlock()
|
||||
|
||||
@@ -15,7 +15,6 @@ import (
|
||||
"net/netip"
|
||||
"net/url"
|
||||
"reflect"
|
||||
"slices"
|
||||
"sync/atomic"
|
||||
"testing"
|
||||
"time"
|
||||
@@ -49,7 +48,7 @@ func fieldsOf(t reflect.Type) (fields []string) {
|
||||
|
||||
func TestStatusEqual(t *testing.T) {
|
||||
// Verify that the Equal method stays in sync with reality
|
||||
equalHandles := []string{"Err", "URL", "NetMap", "Persist", "state"}
|
||||
equalHandles := []string{"Err", "URL", "LoggedIn", "InMapPoll", "NetMap", "Persist"}
|
||||
if have := fieldsOf(reflect.TypeFor[Status]()); !reflect.DeepEqual(have, equalHandles) {
|
||||
t.Errorf("Status.Equal check might be out of sync\nfields: %q\nhandled: %q\n",
|
||||
have, equalHandles)
|
||||
@@ -81,7 +80,7 @@ func TestStatusEqual(t *testing.T) {
|
||||
},
|
||||
{
|
||||
&Status{},
|
||||
&Status{state: StateAuthenticated},
|
||||
&Status{LoggedIn: true, Persist: new(persist.Persist).View()},
|
||||
false,
|
||||
},
|
||||
}
|
||||
@@ -135,8 +134,20 @@ func TestCanSkipStatus(t *testing.T) {
|
||||
want: false,
|
||||
},
|
||||
{
|
||||
name: "s1-state-diff",
|
||||
s1: &Status{state: 123, NetMap: nm1},
|
||||
name: "s1-login-finished-diff",
|
||||
s1: &Status{LoggedIn: true, Persist: new(persist.Persist).View(), NetMap: nm1},
|
||||
s2: &Status{NetMap: nm2},
|
||||
want: false,
|
||||
},
|
||||
{
|
||||
name: "s1-login-finished",
|
||||
s1: &Status{LoggedIn: true, Persist: new(persist.Persist).View(), NetMap: nm1},
|
||||
s2: &Status{NetMap: nm2},
|
||||
want: false,
|
||||
},
|
||||
{
|
||||
name: "s1-synced-diff",
|
||||
s1: &Status{InMapPoll: true, LoggedIn: true, Persist: new(persist.Persist).View(), NetMap: nm1},
|
||||
s2: &Status{NetMap: nm2},
|
||||
want: false,
|
||||
},
|
||||
@@ -167,10 +178,11 @@ func TestCanSkipStatus(t *testing.T) {
|
||||
})
|
||||
}
|
||||
|
||||
want := []string{"Err", "URL", "NetMap", "Persist", "state"}
|
||||
if f := fieldsOf(reflect.TypeFor[Status]()); !slices.Equal(f, want) {
|
||||
t.Errorf("Status fields = %q; this code was only written to handle fields %q", f, want)
|
||||
coveredFields := []string{"Err", "URL", "LoggedIn", "InMapPoll", "NetMap", "Persist"}
|
||||
if have := fieldsOf(reflect.TypeFor[Status]()); !reflect.DeepEqual(have, coveredFields) {
|
||||
t.Errorf("Status fields = %q; this code was only written to handle fields %q", have, coveredFields)
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
func TestRetryableErrors(t *testing.T) {
|
||||
|
||||
@@ -4,8 +4,6 @@
|
||||
package controlclient
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"reflect"
|
||||
|
||||
"tailscale.com/types/netmap"
|
||||
@@ -13,57 +11,6 @@ import (
|
||||
"tailscale.com/types/structs"
|
||||
)
|
||||
|
||||
// State is the high-level state of the client. It is used only in
|
||||
// unit tests for proper sequencing, don't depend on it anywhere else.
|
||||
//
|
||||
// TODO(apenwarr): eliminate the state, as it's now obsolete.
|
||||
//
|
||||
// apenwarr: Historical note: controlclient.Auto was originally
|
||||
// intended to be the state machine for the whole tailscale client, but that
|
||||
// turned out to not be the right abstraction layer, and it moved to
|
||||
// ipn.Backend. Since ipn.Backend now has a state machine, it would be
|
||||
// much better if controlclient could be a simple stateless API. But the
|
||||
// current server-side API (two interlocking polling https calls) makes that
|
||||
// very hard to implement. A server side API change could untangle this and
|
||||
// remove all the statefulness.
|
||||
type State int
|
||||
|
||||
const (
|
||||
StateNew = State(iota)
|
||||
StateNotAuthenticated
|
||||
StateAuthenticating
|
||||
StateURLVisitRequired
|
||||
StateAuthenticated
|
||||
StateSynchronized // connected and received map update
|
||||
)
|
||||
|
||||
func (s State) AppendText(b []byte) ([]byte, error) {
|
||||
return append(b, s.String()...), nil
|
||||
}
|
||||
|
||||
func (s State) MarshalText() ([]byte, error) {
|
||||
return []byte(s.String()), nil
|
||||
}
|
||||
|
||||
func (s State) String() string {
|
||||
switch s {
|
||||
case StateNew:
|
||||
return "state:new"
|
||||
case StateNotAuthenticated:
|
||||
return "state:not-authenticated"
|
||||
case StateAuthenticating:
|
||||
return "state:authenticating"
|
||||
case StateURLVisitRequired:
|
||||
return "state:url-visit-required"
|
||||
case StateAuthenticated:
|
||||
return "state:authenticated"
|
||||
case StateSynchronized:
|
||||
return "state:synchronized"
|
||||
default:
|
||||
return fmt.Sprintf("state:unknown:%d", int(s))
|
||||
}
|
||||
}
|
||||
|
||||
type Status struct {
|
||||
_ structs.Incomparable
|
||||
|
||||
@@ -76,6 +23,14 @@ type Status struct {
|
||||
// URL, if non-empty, is the interactive URL to visit to finish logging in.
|
||||
URL string
|
||||
|
||||
// LoggedIn, if true, indicates that serveRegister has completed and no
|
||||
// other login change is in progress.
|
||||
LoggedIn bool
|
||||
|
||||
// InMapPoll, if true, indicates that we've received at least one netmap
|
||||
// and are connected to receive updates.
|
||||
InMapPoll bool
|
||||
|
||||
// NetMap is the latest server-pushed state of the tailnet network.
|
||||
NetMap *netmap.NetworkMap
|
||||
|
||||
@@ -83,26 +38,8 @@ type Status struct {
|
||||
//
|
||||
// TODO(bradfitz,maisem): clarify this.
|
||||
Persist persist.PersistView
|
||||
|
||||
// state is the internal state. It should not be exposed outside this
|
||||
// package, but we have some automated tests elsewhere that need to
|
||||
// use it via the StateForTest accessor.
|
||||
// TODO(apenwarr): Unexport or remove these.
|
||||
state State
|
||||
}
|
||||
|
||||
// LoginFinished reports whether the controlclient is in its "StateAuthenticated"
|
||||
// state where it's in a happy register state but not yet in a map poll.
|
||||
//
|
||||
// TODO(bradfitz): delete this and everything around Status.state.
|
||||
func (s *Status) LoginFinished() bool { return s.state == StateAuthenticated }
|
||||
|
||||
// StateForTest returns the internal state of s for tests only.
|
||||
func (s *Status) StateForTest() State { return s.state }
|
||||
|
||||
// SetStateForTest sets the internal state of s for tests only.
|
||||
func (s *Status) SetStateForTest(state State) { s.state = state }
|
||||
|
||||
// Equal reports whether s and s2 are equal.
|
||||
func (s *Status) Equal(s2 *Status) bool {
|
||||
if s == nil && s2 == nil {
|
||||
@@ -111,15 +48,8 @@ func (s *Status) Equal(s2 *Status) bool {
|
||||
return s != nil && s2 != nil &&
|
||||
s.Err == s2.Err &&
|
||||
s.URL == s2.URL &&
|
||||
s.state == s2.state &&
|
||||
s.LoggedIn == s2.LoggedIn &&
|
||||
s.InMapPoll == s2.InMapPoll &&
|
||||
reflect.DeepEqual(s.Persist, s2.Persist) &&
|
||||
reflect.DeepEqual(s.NetMap, s2.NetMap)
|
||||
}
|
||||
|
||||
func (s Status) String() string {
|
||||
b, err := json.MarshalIndent(s, "", "\t")
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
return s.state.String() + " " + string(b)
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user