diff --git a/client/web/web.go b/client/web/web.go index 3e5fa4b54..cebea60fc 100644 --- a/client/web/web.go +++ b/client/web/web.go @@ -1140,6 +1140,9 @@ type postRoutesRequest struct { } func (s *Server) servePostRoutes(ctx context.Context, data postRoutesRequest) error { + if !data.SetExitNode && !data.SetRoutes { + return errors.New("must specify SetExitNode or SetRoutes") + } prefs, err := s.lc.GetPrefs(ctx) if err != nil { return err @@ -1153,13 +1156,14 @@ func (s *Server) servePostRoutes(ctx context.Context, data postRoutesRequest) er } currNonExitRoutes = append(currNonExitRoutes, r.String()) } - // Set non-edited fields to their current values. - if data.SetExitNode { - data.AdvertiseRoutes = currNonExitRoutes - } else if data.SetRoutes { + // For each group of fields not being set, preserve the current prefs. + if !data.SetExitNode { data.AdvertiseExitNode = currAdvertisingExitNode data.UseExitNode = prefs.ExitNodeID } + if !data.SetRoutes { + data.AdvertiseRoutes = currNonExitRoutes + } // Calculate routes. routesStr := strings.Join(data.AdvertiseRoutes, ",") diff --git a/client/web/web_test.go b/client/web/web_test.go index 032cd5222..fbf459545 100644 --- a/client/web/web_test.go +++ b/client/web/web_test.go @@ -1604,3 +1604,125 @@ func TestCSRFProtect(t *testing.T) { }) } } + +func TestServePostRoutes(t *testing.T) { + existingExitNodeID := tailcfg.StableNodeID("existing-exit-node") + existingRoute := netip.MustParsePrefix("192.168.1.0/24") + + existingPrefs := &ipn.Prefs{ + ExitNodeID: existingExitNodeID, + AdvertiseRoutes: []netip.Prefix{existingRoute}, + } + + tests := []struct { + name string + data postRoutesRequest + wantErr bool + wantEditPrefs bool // whether EditPrefs (PATCH /prefs) should be called + wantExitNodeID tailcfg.StableNodeID + wantRoutes []netip.Prefix + }{ + { + name: "empty-request", + data: postRoutesRequest{}, + wantErr: true, + wantEditPrefs: false, + }, + { + name: "SetExitNode-only", + data: postRoutesRequest{ + SetExitNode: true, + UseExitNode: "new-exit-node", + }, + wantEditPrefs: true, + wantExitNodeID: "new-exit-node", + wantRoutes: []netip.Prefix{existingRoute}, + }, + { + name: "SetRoutes-only", + data: postRoutesRequest{ + SetRoutes: true, + AdvertiseRoutes: []string{"10.0.0.0/8"}, + }, + wantEditPrefs: true, + wantExitNodeID: existingExitNodeID, + wantRoutes: []netip.Prefix{netip.MustParsePrefix("10.0.0.0/8")}, + }, + { + name: "SetExitNode-and-SetRoutes", + data: postRoutesRequest{ + SetExitNode: true, + SetRoutes: true, + UseExitNode: "new-exit-node", + AdvertiseRoutes: []string{"10.0.0.0/8"}, + }, + wantEditPrefs: true, + wantExitNodeID: "new-exit-node", + wantRoutes: []netip.Prefix{netip.MustParsePrefix("10.0.0.0/8")}, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + var gotPrefs *ipn.MaskedPrefs + + lal := memnet.Listen("local-tailscaled.sock:80") + defer lal.Close() + + localapi := &http.Server{Handler: http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + if r.URL.Path != "/localapi/v0/prefs" { + t.Errorf("unexpected localapi call to %q", r.URL.Path) + http.Error(w, "unexpected localapi call", http.StatusInternalServerError) + return + } + switch r.Method { + case httpm.GET: + writeJSON(w, existingPrefs) + case httpm.PATCH: + var mp ipn.MaskedPrefs + if err := json.NewDecoder(r.Body).Decode(&mp); err != nil { + http.Error(w, err.Error(), http.StatusBadRequest) + return + } + gotPrefs = &mp + writeJSON(w, gotPrefs.Prefs) + default: + t.Errorf("unexpected method %q on /prefs", r.Method) + http.Error(w, "unexpected method", http.StatusMethodNotAllowed) + } + })} + defer localapi.Close() + go localapi.Serve(lal) + + s := &Server{ + mode: ManageServerMode, + lc: &local.Client{Dial: lal.Dial}, + } + + err := s.servePostRoutes(context.Background(), tt.data) + + if tt.wantErr { + if err == nil { + t.Error("wanted error, got nil") + } + if gotPrefs != nil { + t.Error("EditPrefs should not have been called on error") + } + return + } + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + + if gotPrefs == nil { + t.Fatal("expected EditPrefs to be called") + } + if diff := cmp.Diff(tt.wantExitNodeID, gotPrefs.ExitNodeID); diff != "" { + t.Errorf("ExitNodeID mismatch (-want +got):\n%s", diff) + } + if diff := cmp.Diff(tt.wantRoutes, gotPrefs.AdvertiseRoutes, cmp.Comparer(func(a, b netip.Prefix) bool { return a.Compare(b) == 0 })); diff != "" { + t.Errorf("AdvertiseRoutes mismatch (-want +got):\n%s", diff) + } + }) + } +}