client/web: add Sec-Fetch-Site CSRF protection (#16046)
RELNOTE=Fix CSRF errors in the client Web UI Replace gorilla/csrf with a Sec-Fetch-Site based CSRF protection middleware that falls back to comparing the Host & Origin headers if no SFS value is passed by the client. Add an -origin override to the web CLI that allows callers to specify the origin at which the web UI will be available if it is hosted behind a reverse proxy or within another application via CGI. Updates #14872 Updates #15065 Signed-off-by: Patrick O'Doherty <patrick@tailscale.com>
This commit is contained in:
committed by
GitHub
parent
3ee4c60ff0
commit
a05924a9e5
+89
-74
@@ -11,7 +11,6 @@ import (
|
||||
"fmt"
|
||||
"io"
|
||||
"net/http"
|
||||
"net/http/cookiejar"
|
||||
"net/http/httptest"
|
||||
"net/netip"
|
||||
"net/url"
|
||||
@@ -21,14 +20,12 @@ import (
|
||||
"time"
|
||||
|
||||
"github.com/google/go-cmp/cmp"
|
||||
"github.com/gorilla/csrf"
|
||||
"tailscale.com/client/local"
|
||||
"tailscale.com/client/tailscale/apitype"
|
||||
"tailscale.com/ipn"
|
||||
"tailscale.com/ipn/ipnstate"
|
||||
"tailscale.com/net/memnet"
|
||||
"tailscale.com/tailcfg"
|
||||
"tailscale.com/tstest/nettest"
|
||||
"tailscale.com/types/views"
|
||||
"tailscale.com/util/httpm"
|
||||
)
|
||||
@@ -1492,81 +1489,99 @@ func mockWaitAuthURL(_ context.Context, id string, src tailcfg.NodeID) (*tailcfg
|
||||
}
|
||||
|
||||
func TestCSRFProtect(t *testing.T) {
|
||||
s := &Server{}
|
||||
|
||||
mux := http.NewServeMux()
|
||||
mux.HandleFunc("GET /test/csrf-token", func(w http.ResponseWriter, r *http.Request) {
|
||||
token := csrf.Token(r)
|
||||
_, err := io.WriteString(w, token)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
})
|
||||
mux.HandleFunc("POST /test/csrf-protected", func(w http.ResponseWriter, r *http.Request) {
|
||||
_, err := io.WriteString(w, "ok")
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
})
|
||||
h := s.withCSRF(mux)
|
||||
ser := nettest.NewHTTPServer(nettest.GetNetwork(t), h)
|
||||
defer ser.Close()
|
||||
|
||||
jar, err := cookiejar.New(nil)
|
||||
if err != nil {
|
||||
t.Fatalf("unable to construct cookie jar: %v", err)
|
||||
tests := []struct {
|
||||
name string
|
||||
method string
|
||||
secFetchSite string
|
||||
host string
|
||||
origin string
|
||||
originOverride string
|
||||
wantError bool
|
||||
}{
|
||||
{
|
||||
name: "GET requests with no header are allowed",
|
||||
method: "GET",
|
||||
},
|
||||
{
|
||||
name: "POST requests with same-origin are allowed",
|
||||
method: "POST",
|
||||
secFetchSite: "same-origin",
|
||||
},
|
||||
{
|
||||
name: "POST requests with cross-site are not allowed",
|
||||
method: "POST",
|
||||
secFetchSite: "cross-site",
|
||||
wantError: true,
|
||||
},
|
||||
{
|
||||
name: "POST requests with unknown sec-fetch-site values are not allowed",
|
||||
method: "POST",
|
||||
secFetchSite: "new-unknown-value",
|
||||
wantError: true,
|
||||
},
|
||||
{
|
||||
name: "POST requests with none are not allowed",
|
||||
method: "POST",
|
||||
secFetchSite: "none",
|
||||
wantError: true,
|
||||
},
|
||||
{
|
||||
name: "POST requests with no sec-fetch-site header but matching host and origin are allowed",
|
||||
method: "POST",
|
||||
host: "example.com",
|
||||
origin: "https://example.com",
|
||||
},
|
||||
{
|
||||
name: "POST requests with no sec-fetch-site and non-matching host and origin are not allowed",
|
||||
method: "POST",
|
||||
host: "example.com",
|
||||
origin: "https://example.net",
|
||||
wantError: true,
|
||||
},
|
||||
{
|
||||
name: "POST requests with no sec-fetch-site and and origin that matches the override are allowed",
|
||||
method: "POST",
|
||||
originOverride: "example.net",
|
||||
host: "internal.example.foo", // Host can be changed by reverse proxies
|
||||
origin: "http://example.net",
|
||||
},
|
||||
}
|
||||
|
||||
client := ser.Client()
|
||||
client.Jar = jar
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
handler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
fmt.Fprintf(w, "OK")
|
||||
})
|
||||
|
||||
// make GET request to populate cookie jar
|
||||
resp, err := client.Get(ser.URL + "/test/csrf-token")
|
||||
if err != nil {
|
||||
t.Fatalf("unable to make request: %v", err)
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
if resp.StatusCode != http.StatusOK {
|
||||
t.Fatalf("unexpected status: %v", resp.Status)
|
||||
}
|
||||
tokenBytes, err := io.ReadAll(resp.Body)
|
||||
if err != nil {
|
||||
t.Fatalf("unable to read body: %v", err)
|
||||
}
|
||||
s := &Server{
|
||||
originOverride: tt.originOverride,
|
||||
}
|
||||
withCSRF := s.csrfProtect(handler)
|
||||
|
||||
csrfToken := strings.TrimSpace(string(tokenBytes))
|
||||
if csrfToken == "" {
|
||||
t.Fatal("empty csrf token")
|
||||
}
|
||||
r := httptest.NewRequest(tt.method, "http://example.com/", nil)
|
||||
if tt.secFetchSite != "" {
|
||||
r.Header.Set("Sec-Fetch-Site", tt.secFetchSite)
|
||||
}
|
||||
if tt.host != "" {
|
||||
r.Host = tt.host
|
||||
}
|
||||
if tt.origin != "" {
|
||||
r.Header.Set("Origin", tt.origin)
|
||||
}
|
||||
|
||||
// make a POST request without the CSRF header; ensure it fails
|
||||
resp, err = client.Post(ser.URL+"/test/csrf-protected", "text/plain", nil)
|
||||
if err != nil {
|
||||
t.Fatalf("unable to make request: %v", err)
|
||||
}
|
||||
if resp.StatusCode != http.StatusForbidden {
|
||||
t.Fatalf("unexpected status: %v", resp.Status)
|
||||
}
|
||||
|
||||
// make a POST request with the CSRF header; ensure it succeeds
|
||||
req, err := http.NewRequest("POST", ser.URL+"/test/csrf-protected", nil)
|
||||
if err != nil {
|
||||
t.Fatalf("error building request: %v", err)
|
||||
}
|
||||
req.Header.Set("X-CSRF-Token", csrfToken)
|
||||
resp, err = client.Do(req)
|
||||
if err != nil {
|
||||
t.Fatalf("unable to make request: %v", err)
|
||||
}
|
||||
if resp.StatusCode != http.StatusOK {
|
||||
t.Fatalf("unexpected status: %v", resp.Status)
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
out, err := io.ReadAll(resp.Body)
|
||||
if err != nil {
|
||||
t.Fatalf("unable to read body: %v", err)
|
||||
}
|
||||
if string(out) != "ok" {
|
||||
t.Fatalf("unexpected body: %q", out)
|
||||
w := httptest.NewRecorder()
|
||||
withCSRF.ServeHTTP(w, r)
|
||||
res := w.Result()
|
||||
defer res.Body.Close()
|
||||
if tt.wantError {
|
||||
if res.StatusCode != http.StatusForbidden {
|
||||
t.Errorf("expected status forbidden, got %v", res.StatusCode)
|
||||
}
|
||||
return
|
||||
}
|
||||
if res.StatusCode != http.StatusOK {
|
||||
t.Errorf("expected status ok, got %v", res.StatusCode)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user