client/web: add Sec-Fetch-Site CSRF protection (#16046)

RELNOTE=Fix CSRF errors in the client Web UI

Replace gorilla/csrf with a Sec-Fetch-Site based CSRF protection
middleware that falls back to comparing the Host & Origin headers if no
SFS value is passed by the client.

Add an -origin override to the web CLI that allows callers to specify
the origin at which the web UI will be available if it is hosted behind
a reverse proxy or within another application via CGI.

Updates #14872
Updates #15065

Signed-off-by: Patrick O'Doherty <patrick@tailscale.com>
This commit is contained in:
Patrick O'Doherty
2025-05-22 12:26:02 -07:00
committed by GitHub
parent 3ee4c60ff0
commit a05924a9e5
8 changed files with 184 additions and 169 deletions
+89 -74
View File
@@ -11,7 +11,6 @@ import (
"fmt"
"io"
"net/http"
"net/http/cookiejar"
"net/http/httptest"
"net/netip"
"net/url"
@@ -21,14 +20,12 @@ import (
"time"
"github.com/google/go-cmp/cmp"
"github.com/gorilla/csrf"
"tailscale.com/client/local"
"tailscale.com/client/tailscale/apitype"
"tailscale.com/ipn"
"tailscale.com/ipn/ipnstate"
"tailscale.com/net/memnet"
"tailscale.com/tailcfg"
"tailscale.com/tstest/nettest"
"tailscale.com/types/views"
"tailscale.com/util/httpm"
)
@@ -1492,81 +1489,99 @@ func mockWaitAuthURL(_ context.Context, id string, src tailcfg.NodeID) (*tailcfg
}
func TestCSRFProtect(t *testing.T) {
s := &Server{}
mux := http.NewServeMux()
mux.HandleFunc("GET /test/csrf-token", func(w http.ResponseWriter, r *http.Request) {
token := csrf.Token(r)
_, err := io.WriteString(w, token)
if err != nil {
t.Fatal(err)
}
})
mux.HandleFunc("POST /test/csrf-protected", func(w http.ResponseWriter, r *http.Request) {
_, err := io.WriteString(w, "ok")
if err != nil {
t.Fatal(err)
}
})
h := s.withCSRF(mux)
ser := nettest.NewHTTPServer(nettest.GetNetwork(t), h)
defer ser.Close()
jar, err := cookiejar.New(nil)
if err != nil {
t.Fatalf("unable to construct cookie jar: %v", err)
tests := []struct {
name string
method string
secFetchSite string
host string
origin string
originOverride string
wantError bool
}{
{
name: "GET requests with no header are allowed",
method: "GET",
},
{
name: "POST requests with same-origin are allowed",
method: "POST",
secFetchSite: "same-origin",
},
{
name: "POST requests with cross-site are not allowed",
method: "POST",
secFetchSite: "cross-site",
wantError: true,
},
{
name: "POST requests with unknown sec-fetch-site values are not allowed",
method: "POST",
secFetchSite: "new-unknown-value",
wantError: true,
},
{
name: "POST requests with none are not allowed",
method: "POST",
secFetchSite: "none",
wantError: true,
},
{
name: "POST requests with no sec-fetch-site header but matching host and origin are allowed",
method: "POST",
host: "example.com",
origin: "https://example.com",
},
{
name: "POST requests with no sec-fetch-site and non-matching host and origin are not allowed",
method: "POST",
host: "example.com",
origin: "https://example.net",
wantError: true,
},
{
name: "POST requests with no sec-fetch-site and and origin that matches the override are allowed",
method: "POST",
originOverride: "example.net",
host: "internal.example.foo", // Host can be changed by reverse proxies
origin: "http://example.net",
},
}
client := ser.Client()
client.Jar = jar
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
handler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
fmt.Fprintf(w, "OK")
})
// make GET request to populate cookie jar
resp, err := client.Get(ser.URL + "/test/csrf-token")
if err != nil {
t.Fatalf("unable to make request: %v", err)
}
defer resp.Body.Close()
if resp.StatusCode != http.StatusOK {
t.Fatalf("unexpected status: %v", resp.Status)
}
tokenBytes, err := io.ReadAll(resp.Body)
if err != nil {
t.Fatalf("unable to read body: %v", err)
}
s := &Server{
originOverride: tt.originOverride,
}
withCSRF := s.csrfProtect(handler)
csrfToken := strings.TrimSpace(string(tokenBytes))
if csrfToken == "" {
t.Fatal("empty csrf token")
}
r := httptest.NewRequest(tt.method, "http://example.com/", nil)
if tt.secFetchSite != "" {
r.Header.Set("Sec-Fetch-Site", tt.secFetchSite)
}
if tt.host != "" {
r.Host = tt.host
}
if tt.origin != "" {
r.Header.Set("Origin", tt.origin)
}
// make a POST request without the CSRF header; ensure it fails
resp, err = client.Post(ser.URL+"/test/csrf-protected", "text/plain", nil)
if err != nil {
t.Fatalf("unable to make request: %v", err)
}
if resp.StatusCode != http.StatusForbidden {
t.Fatalf("unexpected status: %v", resp.Status)
}
// make a POST request with the CSRF header; ensure it succeeds
req, err := http.NewRequest("POST", ser.URL+"/test/csrf-protected", nil)
if err != nil {
t.Fatalf("error building request: %v", err)
}
req.Header.Set("X-CSRF-Token", csrfToken)
resp, err = client.Do(req)
if err != nil {
t.Fatalf("unable to make request: %v", err)
}
if resp.StatusCode != http.StatusOK {
t.Fatalf("unexpected status: %v", resp.Status)
}
defer resp.Body.Close()
out, err := io.ReadAll(resp.Body)
if err != nil {
t.Fatalf("unable to read body: %v", err)
}
if string(out) != "ok" {
t.Fatalf("unexpected body: %q", out)
w := httptest.NewRecorder()
withCSRF.ServeHTTP(w, r)
res := w.Result()
defer res.Body.Close()
if tt.wantError {
if res.StatusCode != http.StatusForbidden {
t.Errorf("expected status forbidden, got %v", res.StatusCode)
}
return
}
if res.StatusCode != http.StatusOK {
t.Errorf("expected status ok, got %v", res.StatusCode)
}
})
}
}