Updates #10263 Co-authored-by: Brad Fitzpatrick <bradfitz@tailscale.com> Change-Id: I240bc9b5ecf2df6f92c45929d105fde66c06a860 Signed-off-by: Brad Fitzpatrick <bradfitz@tailscale.com>main
parent
2dbd546766
commit
7d4221c295
@ -0,0 +1,703 @@ |
||||
// Copyright (c) Tailscale Inc & AUTHORS
|
||||
// SPDX-License-Identifier: BSD-3-Clause
|
||||
|
||||
// The tsidp command is an OpenID Connect Identity Provider server.
|
||||
//
|
||||
// See https://github.com/tailscale/tailscale/issues/10263 for background.
|
||||
package main |
||||
|
||||
import ( |
||||
"context" |
||||
crand "crypto/rand" |
||||
"crypto/rsa" |
||||
"crypto/tls" |
||||
"crypto/x509" |
||||
"encoding/base64" |
||||
"encoding/binary" |
||||
"encoding/json" |
||||
"encoding/pem" |
||||
"flag" |
||||
"fmt" |
||||
"io" |
||||
"log" |
||||
"net" |
||||
"net/http" |
||||
"net/netip" |
||||
"net/url" |
||||
"os" |
||||
"strconv" |
||||
"strings" |
||||
"sync" |
||||
"time" |
||||
|
||||
"gopkg.in/square/go-jose.v2" |
||||
"gopkg.in/square/go-jose.v2/jwt" |
||||
"tailscale.com/client/tailscale" |
||||
"tailscale.com/client/tailscale/apitype" |
||||
"tailscale.com/envknob" |
||||
"tailscale.com/ipn/ipnstate" |
||||
"tailscale.com/tailcfg" |
||||
"tailscale.com/tsnet" |
||||
"tailscale.com/types/key" |
||||
"tailscale.com/types/lazy" |
||||
"tailscale.com/types/logger" |
||||
"tailscale.com/types/views" |
||||
"tailscale.com/util/mak" |
||||
"tailscale.com/util/must" |
||||
"tailscale.com/util/rands" |
||||
) |
||||
|
||||
var ( |
||||
flagVerbose = flag.Bool("verbose", false, "be verbose") |
||||
flagPort = flag.Int("port", 443, "port to listen on") |
||||
flagLocalPort = flag.Int("local-port", -1, "allow requests from localhost") |
||||
flagUseLocalTailscaled = flag.Bool("use-local-tailscaled", false, "use local tailscaled instead of tsnet") |
||||
) |
||||
|
||||
func main() { |
||||
flag.Parse() |
||||
ctx := context.Background() |
||||
if !envknob.UseWIPCode() { |
||||
log.Fatal("cmd/tsidp is a work in progress and has not been security reviewed;\nits use requires TAILSCALE_USE_WIP_CODE=1 be set in the environment for now.") |
||||
} |
||||
|
||||
var ( |
||||
lc *tailscale.LocalClient |
||||
st *ipnstate.Status |
||||
err error |
||||
|
||||
lns []net.Listener |
||||
) |
||||
if *flagUseLocalTailscaled { |
||||
lc = &tailscale.LocalClient{} |
||||
st, err = lc.StatusWithoutPeers(ctx) |
||||
if err != nil { |
||||
log.Fatalf("getting status: %v", err) |
||||
} |
||||
portStr := fmt.Sprint(*flagPort) |
||||
anySuccess := false |
||||
for _, ip := range st.TailscaleIPs { |
||||
ln, err := net.Listen("tcp", net.JoinHostPort(ip.String(), portStr)) |
||||
if err != nil { |
||||
log.Printf("failed to listen on %v: %v", ip, err) |
||||
continue |
||||
} |
||||
anySuccess = true |
||||
ln = tls.NewListener(ln, &tls.Config{ |
||||
GetCertificate: lc.GetCertificate, |
||||
}) |
||||
lns = append(lns, ln) |
||||
} |
||||
if !anySuccess { |
||||
log.Fatalf("failed to listen on any of %v", st.TailscaleIPs) |
||||
} |
||||
} else { |
||||
ts := &tsnet.Server{ |
||||
Hostname: "idp", |
||||
} |
||||
if !*flagVerbose { |
||||
ts.Logf = logger.Discard |
||||
} |
||||
st, err = ts.Up(ctx) |
||||
if err != nil { |
||||
log.Fatal(err) |
||||
} |
||||
lc, err = ts.LocalClient() |
||||
if err != nil { |
||||
log.Fatalf("getting local client: %v", err) |
||||
} |
||||
ln, err := ts.ListenTLS("tcp", fmt.Sprintf(":%d", *flagPort)) |
||||
if err != nil { |
||||
log.Fatal(err) |
||||
} |
||||
lns = append(lns, ln) |
||||
} |
||||
|
||||
srv := &idpServer{ |
||||
lc: lc, |
||||
} |
||||
if *flagPort != 443 { |
||||
srv.serverURL = fmt.Sprintf("https://%s:%d", strings.TrimSuffix(st.Self.DNSName, "."), *flagPort) |
||||
} else { |
||||
srv.serverURL = fmt.Sprintf("https://%s", strings.TrimSuffix(st.Self.DNSName, ".")) |
||||
} |
||||
|
||||
log.Printf("Running tsidp at %s ...", srv.serverURL) |
||||
|
||||
if *flagLocalPort != -1 { |
||||
log.Printf("Also running tsidp at %s ...", srv.loopbackURL) |
||||
srv.loopbackURL = fmt.Sprintf("http://localhost:%d", *flagLocalPort) |
||||
ln, err := net.Listen("tcp", fmt.Sprintf("localhost:%d", *flagLocalPort)) |
||||
if err != nil { |
||||
log.Fatal(err) |
||||
} |
||||
lns = append(lns, ln) |
||||
} |
||||
|
||||
for _, ln := range lns { |
||||
go http.Serve(ln, srv) |
||||
} |
||||
select {} |
||||
} |
||||
|
||||
type idpServer struct { |
||||
lc *tailscale.LocalClient |
||||
loopbackURL string |
||||
serverURL string // "https://foo.bar.ts.net"
|
||||
|
||||
lazyMux lazy.SyncValue[*http.ServeMux] |
||||
lazySigningKey lazy.SyncValue[*signingKey] |
||||
lazySigner lazy.SyncValue[jose.Signer] |
||||
|
||||
mu sync.Mutex // guards the fields below
|
||||
code map[string]*authRequest // keyed by random hex
|
||||
accessToken map[string]*authRequest // keyed by random hex
|
||||
} |
||||
|
||||
type authRequest struct { |
||||
// localRP is true if the request is from a relying party running on the
|
||||
// same machine as the idp server. It is mutually exclusive with rpNodeID.
|
||||
localRP bool |
||||
|
||||
// rpNodeID is the NodeID of the relying party (who requested the auth, such
|
||||
// as Proxmox or Synology), not the user node who is being authenticated. It
|
||||
// is mutually exclusive with localRP.
|
||||
rpNodeID tailcfg.NodeID |
||||
|
||||
// clientID is the "client_id" sent in the authorized request.
|
||||
clientID string |
||||
|
||||
// nonce presented in the request.
|
||||
nonce string |
||||
|
||||
// redirectURI is the redirect_uri presented in the request.
|
||||
redirectURI string |
||||
|
||||
// remoteUser is the user who is being authenticated.
|
||||
remoteUser *apitype.WhoIsResponse |
||||
|
||||
// validTill is the time until which the token is valid.
|
||||
// As of 2023-11-14, it is 5 minutes.
|
||||
// TODO: add routine to delete expired tokens.
|
||||
validTill time.Time |
||||
} |
||||
|
||||
func (ar *authRequest) allowRelyingParty(ctx context.Context, remoteAddr string, lc *tailscale.LocalClient) error { |
||||
if ar.localRP { |
||||
ra, err := netip.ParseAddrPort(remoteAddr) |
||||
if err != nil { |
||||
return err |
||||
} |
||||
if !ra.Addr().IsLoopback() { |
||||
return fmt.Errorf("tsidp: request from non-loopback address") |
||||
} |
||||
return nil |
||||
} |
||||
who, err := lc.WhoIs(ctx, remoteAddr) |
||||
if err != nil { |
||||
return fmt.Errorf("tsidp: error getting WhoIs: %w", err) |
||||
} |
||||
if ar.rpNodeID != who.Node.ID { |
||||
return fmt.Errorf("tsidp: token for different node") |
||||
} |
||||
return nil |
||||
} |
||||
|
||||
func (s *idpServer) authorize(w http.ResponseWriter, r *http.Request) { |
||||
who, err := s.lc.WhoIs(r.Context(), r.RemoteAddr) |
||||
if err != nil { |
||||
log.Printf("Error getting WhoIs: %v", err) |
||||
http.Error(w, err.Error(), http.StatusInternalServerError) |
||||
return |
||||
} |
||||
|
||||
uq := r.URL.Query() |
||||
|
||||
code := rands.HexString(32) |
||||
ar := &authRequest{ |
||||
nonce: uq.Get("nonce"), |
||||
remoteUser: who, |
||||
redirectURI: uq.Get("redirect_uri"), |
||||
clientID: uq.Get("client_id"), |
||||
} |
||||
|
||||
if r.URL.Path == "/authorize/localhost" { |
||||
ar.localRP = true |
||||
} else { |
||||
var ok bool |
||||
ar.rpNodeID, ok = parseID[tailcfg.NodeID](strings.TrimPrefix(r.URL.Path, "/authorize/")) |
||||
if !ok { |
||||
http.Error(w, "tsidp: invalid node ID suffix after /authorize/", http.StatusBadRequest) |
||||
return |
||||
} |
||||
} |
||||
|
||||
s.mu.Lock() |
||||
mak.Set(&s.code, code, ar) |
||||
s.mu.Unlock() |
||||
|
||||
q := make(url.Values) |
||||
q.Set("code", code) |
||||
q.Set("state", uq.Get("state")) |
||||
u := uq.Get("redirect_uri") + "?" + q.Encode() |
||||
log.Printf("Redirecting to %q", u) |
||||
|
||||
http.Redirect(w, r, u, http.StatusFound) |
||||
} |
||||
|
||||
func (s *idpServer) newMux() *http.ServeMux { |
||||
mux := http.NewServeMux() |
||||
mux.HandleFunc(oidcJWKSPath, s.serveJWKS) |
||||
mux.HandleFunc(oidcConfigPath, s.serveOpenIDConfig) |
||||
mux.HandleFunc("/authorize/", s.authorize) |
||||
mux.HandleFunc("/userinfo", s.serveUserInfo) |
||||
mux.HandleFunc("/token", s.serveToken) |
||||
mux.HandleFunc("/", func(w http.ResponseWriter, r *http.Request) { |
||||
if r.URL.Path == "/" { |
||||
io.WriteString(w, "<html><body><h1>Tailscale OIDC IdP</h1>") |
||||
return |
||||
} |
||||
http.Error(w, "tsidp: not found", http.StatusNotFound) |
||||
}) |
||||
return mux |
||||
} |
||||
|
||||
func (s *idpServer) ServeHTTP(w http.ResponseWriter, r *http.Request) { |
||||
log.Printf("%v %v", r.Method, r.URL) |
||||
s.lazyMux.Get(s.newMux).ServeHTTP(w, r) |
||||
} |
||||
|
||||
func (s *idpServer) serveUserInfo(w http.ResponseWriter, r *http.Request) { |
||||
if r.Method != "GET" { |
||||
http.Error(w, "tsidp: method not allowed", http.StatusMethodNotAllowed) |
||||
return |
||||
} |
||||
tk, ok := strings.CutPrefix(r.Header.Get("Authorization"), "Bearer ") |
||||
if !ok { |
||||
http.Error(w, "tsidp: invalid Authorization header", http.StatusBadRequest) |
||||
return |
||||
} |
||||
|
||||
s.mu.Lock() |
||||
ar, ok := s.accessToken[tk] |
||||
s.mu.Unlock() |
||||
if !ok { |
||||
http.Error(w, "tsidp: invalid token", http.StatusBadRequest) |
||||
return |
||||
} |
||||
if err := ar.allowRelyingParty(r.Context(), r.RemoteAddr, s.lc); err != nil { |
||||
log.Printf("Error allowing relying party: %v", err) |
||||
http.Error(w, err.Error(), http.StatusForbidden) |
||||
return |
||||
} |
||||
|
||||
if ar.validTill.Before(time.Now()) { |
||||
http.Error(w, "tsidp: token expired", http.StatusBadRequest) |
||||
s.mu.Lock() |
||||
delete(s.accessToken, tk) |
||||
s.mu.Unlock() |
||||
} |
||||
|
||||
ui := userInfo{} |
||||
if ar.remoteUser.Node.IsTagged() { |
||||
http.Error(w, "tsidp: tagged nodes not supported", http.StatusBadRequest) |
||||
return |
||||
} |
||||
ui.Sub = ar.remoteUser.Node.User.String() |
||||
ui.Name = ar.remoteUser.UserProfile.DisplayName |
||||
ui.Email = ar.remoteUser.UserProfile.LoginName |
||||
ui.Picture = ar.remoteUser.UserProfile.ProfilePicURL |
||||
|
||||
// TODO(maisem): not sure if this is the right thing to do
|
||||
ui.UserName, _, _ = strings.Cut(ar.remoteUser.UserProfile.LoginName, "@") |
||||
|
||||
w.Header().Set("Content-Type", "application/json") |
||||
if err := json.NewEncoder(w).Encode(ui); err != nil { |
||||
http.Error(w, err.Error(), http.StatusInternalServerError) |
||||
} |
||||
} |
||||
|
||||
type userInfo struct { |
||||
Sub string `json:"sub"` |
||||
Name string `json:"name"` |
||||
Email string `json:"email"` |
||||
Picture string `json:"picture"` |
||||
UserName string `json:"username"` |
||||
} |
||||
|
||||
func (s *idpServer) serveToken(w http.ResponseWriter, r *http.Request) { |
||||
if r.Method != "POST" { |
||||
http.Error(w, "tsidp: method not allowed", http.StatusMethodNotAllowed) |
||||
return |
||||
} |
||||
if r.FormValue("grant_type") != "authorization_code" { |
||||
http.Error(w, "tsidp: grant_type not supported", http.StatusBadRequest) |
||||
return |
||||
} |
||||
code := r.FormValue("code") |
||||
if code == "" { |
||||
http.Error(w, "tsidp: code is required", http.StatusBadRequest) |
||||
return |
||||
} |
||||
s.mu.Lock() |
||||
ar, ok := s.code[code] |
||||
if ok { |
||||
delete(s.code, code) |
||||
} |
||||
s.mu.Unlock() |
||||
if !ok { |
||||
http.Error(w, "tsidp: code not found", http.StatusBadRequest) |
||||
return |
||||
} |
||||
if err := ar.allowRelyingParty(r.Context(), r.RemoteAddr, s.lc); err != nil { |
||||
log.Printf("Error allowing relying party: %v", err) |
||||
http.Error(w, err.Error(), http.StatusForbidden) |
||||
return |
||||
} |
||||
if ar.redirectURI != r.FormValue("redirect_uri") { |
||||
http.Error(w, "tsidp: redirect_uri mismatch", http.StatusBadRequest) |
||||
return |
||||
} |
||||
signer, err := s.oidcSigner() |
||||
if err != nil { |
||||
log.Printf("Error getting signer: %v", err) |
||||
http.Error(w, err.Error(), http.StatusInternalServerError) |
||||
return |
||||
} |
||||
jti := rands.HexString(32) |
||||
who := ar.remoteUser |
||||
|
||||
// TODO(maisem): not sure if this is the right thing to do
|
||||
userName, _, _ := strings.Cut(ar.remoteUser.UserProfile.LoginName, "@") |
||||
n := who.Node.View() |
||||
if n.IsTagged() { |
||||
http.Error(w, "tsidp: tagged nodes not supported", http.StatusBadRequest) |
||||
return |
||||
} |
||||
|
||||
now := time.Now() |
||||
_, tcd, _ := strings.Cut(n.Name(), ".") |
||||
tsClaims := tailscaleClaims{ |
||||
Claims: jwt.Claims{ |
||||
Audience: jwt.Audience{ar.clientID}, |
||||
Expiry: jwt.NewNumericDate(now.Add(5 * time.Minute)), |
||||
ID: jti, |
||||
IssuedAt: jwt.NewNumericDate(now), |
||||
Issuer: s.serverURL, |
||||
NotBefore: jwt.NewNumericDate(now), |
||||
Subject: n.User().String(), |
||||
}, |
||||
Nonce: ar.nonce, |
||||
Key: n.Key(), |
||||
Addresses: n.Addresses(), |
||||
NodeID: n.ID(), |
||||
NodeName: n.Name(), |
||||
Tailnet: tcd, |
||||
UserID: n.User(), |
||||
Email: who.UserProfile.LoginName, |
||||
UserName: userName, |
||||
} |
||||
if ar.localRP { |
||||
tsClaims.Issuer = s.loopbackURL |
||||
} |
||||
|
||||
// Create an OIDC token using this issuer's signer.
|
||||
token, err := jwt.Signed(signer).Claims(tsClaims).CompactSerialize() |
||||
if err != nil { |
||||
log.Printf("Error getting token: %v", err) |
||||
http.Error(w, err.Error(), http.StatusInternalServerError) |
||||
return |
||||
} |
||||
|
||||
at := rands.HexString(32) |
||||
s.mu.Lock() |
||||
ar.validTill = now.Add(5 * time.Minute) |
||||
mak.Set(&s.accessToken, at, ar) |
||||
s.mu.Unlock() |
||||
|
||||
w.Header().Set("Content-Type", "application/json") |
||||
if err := json.NewEncoder(w).Encode(oidcTokenResponse{ |
||||
AccessToken: at, |
||||
TokenType: "Bearer", |
||||
ExpiresIn: 5 * 60, |
||||
IDToken: token, |
||||
}); err != nil { |
||||
http.Error(w, err.Error(), http.StatusInternalServerError) |
||||
} |
||||
} |
||||
|
||||
type oidcTokenResponse struct { |
||||
IDToken string `json:"id_token"` |
||||
TokenType string `json:"token_type"` |
||||
AccessToken string `json:"access_token"` |
||||
RefreshToken string `json:"refresh_token"` |
||||
ExpiresIn int `json:"expires_in"` |
||||
} |
||||
|
||||
const ( |
||||
oidcJWKSPath = "/.well-known/jwks.json" |
||||
oidcConfigPath = "/.well-known/openid-configuration" |
||||
) |
||||
|
||||
func (s *idpServer) oidcSigner() (jose.Signer, error) { |
||||
return s.lazySigner.GetErr(func() (jose.Signer, error) { |
||||
sk, err := s.oidcPrivateKey() |
||||
if err != nil { |
||||
return nil, err |
||||
} |
||||
return jose.NewSigner(jose.SigningKey{ |
||||
Algorithm: jose.RS256, |
||||
Key: sk.k, |
||||
}, &jose.SignerOptions{EmbedJWK: false, ExtraHeaders: map[jose.HeaderKey]any{ |
||||
jose.HeaderType: "JWT", |
||||
"kid": fmt.Sprint(sk.kid), |
||||
}}) |
||||
}) |
||||
} |
||||
|
||||
func (s *idpServer) oidcPrivateKey() (*signingKey, error) { |
||||
return s.lazySigningKey.GetErr(func() (*signingKey, error) { |
||||
var sk signingKey |
||||
b, err := os.ReadFile("oidc-key.json") |
||||
if err == nil { |
||||
if err := sk.UnmarshalJSON(b); err == nil { |
||||
return &sk, nil |
||||
} else { |
||||
log.Printf("Error unmarshaling key: %v", err) |
||||
} |
||||
} |
||||
id, k := mustGenRSAKey(2048) |
||||
sk.k = k |
||||
sk.kid = id |
||||
b, err = sk.MarshalJSON() |
||||
if err != nil { |
||||
log.Fatalf("Error marshaling key: %v", err) |
||||
} |
||||
if err := os.WriteFile("oidc-key.json", b, 0600); err != nil { |
||||
log.Fatalf("Error writing key: %v", err) |
||||
} |
||||
return &sk, nil |
||||
}) |
||||
} |
||||
|
||||
func (s *idpServer) serveJWKS(w http.ResponseWriter, r *http.Request) { |
||||
if r.URL.Path != oidcJWKSPath { |
||||
http.Error(w, "tsidp: not found", http.StatusNotFound) |
||||
return |
||||
} |
||||
w.Header().Set("Content-Type", "application/json") |
||||
sk, err := s.oidcPrivateKey() |
||||
if err != nil { |
||||
http.Error(w, err.Error(), http.StatusInternalServerError) |
||||
return |
||||
} |
||||
// TODO(maisem): maybe only marshal this once and reuse?
|
||||
// TODO(maisem): implement key rotation.
|
||||
je := json.NewEncoder(w) |
||||
je.SetIndent("", " ") |
||||
if err := je.Encode(jose.JSONWebKeySet{ |
||||
Keys: []jose.JSONWebKey{ |
||||
{ |
||||
Key: sk.k.Public(), |
||||
Algorithm: string(jose.RS256), |
||||
Use: "sig", |
||||
KeyID: fmt.Sprint(sk.kid), |
||||
}, |
||||
}, |
||||
}); err != nil { |
||||
http.Error(w, err.Error(), http.StatusInternalServerError) |
||||
} |
||||
return |
||||
} |
||||
|
||||
// openIDProviderMetadata is a partial representation of
|
||||
// https://openid.net/specs/openid-connect-discovery-1_0.html#ProviderMetadata.
|
||||
type openIDProviderMetadata struct { |
||||
Issuer string `json:"issuer"` |
||||
AuthorizationEndpoint string `json:"authorization_endpoint,omitempty"` |
||||
TokenEndpoint string `json:"token_endpoint,omitempty"` |
||||
UserInfoEndpoint string `json:"userinfo_endpoint,omitempty"` |
||||
JWKS_URI string `json:"jwks_uri"` |
||||
ScopesSupported views.Slice[string] `json:"scopes_supported"` |
||||
ResponseTypesSupported views.Slice[string] `json:"response_types_supported"` |
||||
SubjectTypesSupported views.Slice[string] `json:"subject_types_supported"` |
||||
ClaimsSupported views.Slice[string] `json:"claims_supported"` |
||||
IDTokenSigningAlgValuesSupported views.Slice[string] `json:"id_token_signing_alg_values_supported"` |
||||
// TODO(maisem): maybe add other fields?
|
||||
// Currently we fill out the REQUIRED fields, scopes_supported and claims_supported.
|
||||
} |
||||
|
||||
type tailscaleClaims struct { |
||||
jwt.Claims `json:",inline"` |
||||
Nonce string `json:"nonce,omitempty"` // the nonce from the request
|
||||
Key key.NodePublic `json:"key"` // the node public key
|
||||
Addresses views.Slice[netip.Prefix] `json:"addresses"` // the Tailscale IPs of the node
|
||||
NodeID tailcfg.NodeID `json:"nid"` // the stable node ID
|
||||
NodeName string `json:"node"` // name of the node
|
||||
Tailnet string `json:"tailnet"` // tailnet (like tail-scale.ts.net)
|
||||
|
||||
// Email is the "emailish" value with an '@' sign. It might not be a valid email.
|
||||
Email string `json:"email,omitempty"` // user emailish (like "alice@github" or "bob@example.com")
|
||||
UserID tailcfg.UserID `json:"uid,omitempty"` |
||||
|
||||
// UserName is the local part of Email (without '@' and domain).
|
||||
// It is a temporary (2023-11-15) hack during development.
|
||||
// We should probably let this be configured via grants.
|
||||
UserName string `json:"username,omitempty"` |
||||
} |
||||
|
||||
var ( |
||||
openIDSupportedClaims = views.SliceOf([]string{ |
||||
// Standard claims, these correspond to fields in jwt.Claims.
|
||||
"sub", "aud", "exp", "iat", "iss", "jti", "nbf", "username", "email", |
||||
|
||||
// Tailscale claims, these correspond to fields in tailscaleClaims.
|
||||
"key", "addresses", "nid", "node", "tailnet", "tags", "user", "uid", |
||||
}) |
||||
|
||||
// As defined in the OpenID spec this should be "openid".
|
||||
openIDSupportedScopes = views.SliceOf([]string{"openid", "email", "profile"}) |
||||
|
||||
// We only support getting the id_token.
|
||||
openIDSupportedReponseTypes = views.SliceOf([]string{"id_token", "code"}) |
||||
|
||||
// The type of the "sub" field in the JWT, which means it is globally unique identifier.
|
||||
// The other option is "pairwise", which means the identifier is different per receiving 3p.
|
||||
openIDSupportedSubjectTypes = views.SliceOf([]string{"public"}) |
||||
|
||||
// The algo used for signing. The OpenID spec says "The algorithm RS256 MUST be included."
|
||||
// https://openid.net/specs/openid-connect-discovery-1_0.html#ProviderMetadata
|
||||
openIDSupportedSigningAlgos = views.SliceOf([]string{string(jose.RS256)}) |
||||
) |
||||
|
||||
func (s *idpServer) serveOpenIDConfig(w http.ResponseWriter, r *http.Request) { |
||||
if r.URL.Path != oidcConfigPath { |
||||
http.Error(w, "tsidp: not found", http.StatusNotFound) |
||||
return |
||||
} |
||||
ap, err := netip.ParseAddrPort(r.RemoteAddr) |
||||
if err != nil { |
||||
log.Printf("Error parsing remote addr: %v", err) |
||||
return |
||||
} |
||||
var authorizeEndpoint string |
||||
rpEndpoint := s.serverURL |
||||
if who, err := s.lc.WhoIs(r.Context(), r.RemoteAddr); err == nil { |
||||
authorizeEndpoint = fmt.Sprintf("%s/authorize/%d", s.serverURL, who.Node.ID) |
||||
} else if ap.Addr().IsLoopback() { |
||||
rpEndpoint = s.loopbackURL |
||||
authorizeEndpoint = fmt.Sprintf("%s/authorize/localhost", s.serverURL) |
||||
} else { |
||||
log.Printf("Error getting WhoIs: %v", err) |
||||
http.Error(w, err.Error(), http.StatusInternalServerError) |
||||
return |
||||
} |
||||
|
||||
w.Header().Set("Content-Type", "application/json") |
||||
je := json.NewEncoder(w) |
||||
je.SetIndent("", " ") |
||||
if err := je.Encode(openIDProviderMetadata{ |
||||
AuthorizationEndpoint: authorizeEndpoint, |
||||
Issuer: rpEndpoint, |
||||
JWKS_URI: rpEndpoint + oidcJWKSPath, |
||||
UserInfoEndpoint: rpEndpoint + "/userinfo", |
||||
TokenEndpoint: rpEndpoint + "/token", |
||||
ScopesSupported: openIDSupportedScopes, |
||||
ResponseTypesSupported: openIDSupportedReponseTypes, |
||||
SubjectTypesSupported: openIDSupportedSubjectTypes, |
||||
ClaimsSupported: openIDSupportedClaims, |
||||
IDTokenSigningAlgValuesSupported: openIDSupportedSigningAlgos, |
||||
}); err != nil { |
||||
http.Error(w, err.Error(), http.StatusInternalServerError) |
||||
} |
||||
} |
||||
|
||||
const ( |
||||
minimumRSAKeySize = 2048 |
||||
) |
||||
|
||||
// mustGenRSAKey generates a new RSA key with the provided number of bits. It
|
||||
// panics on failure. bits must be at least minimumRSAKeySizeBytes * 8.
|
||||
func mustGenRSAKey(bits int) (kid uint64, k *rsa.PrivateKey) { |
||||
if bits < minimumRSAKeySize { |
||||
panic("request to generate a too-small RSA key") |
||||
} |
||||
kid = must.Get(readUint64(crand.Reader)) |
||||
k = must.Get(rsa.GenerateKey(crand.Reader, bits)) |
||||
return |
||||
} |
||||
|
||||
// readUint64 reads from r until 8 bytes represent a non-zero uint64.
|
||||
func readUint64(r io.Reader) (uint64, error) { |
||||
for { |
||||
var b [8]byte |
||||
if _, err := io.ReadFull(r, b[:]); err != nil { |
||||
return 0, err |
||||
} |
||||
if v := binary.BigEndian.Uint64(b[:]); v != 0 { |
||||
return v, nil |
||||
} |
||||
} |
||||
} |
||||
|
||||
// rsaPrivateKeyJSONWrapper is the the JSON serialization
|
||||
// format used by RSAPrivateKey.
|
||||
type rsaPrivateKeyJSONWrapper struct { |
||||
Key string |
||||
ID uint64 |
||||
} |
||||
|
||||
type signingKey struct { |
||||
k *rsa.PrivateKey |
||||
kid uint64 |
||||
} |
||||
|
||||
func (sk *signingKey) MarshalJSON() ([]byte, error) { |
||||
b := pem.Block{ |
||||
Type: "RSA PRIVATE KEY", |
||||
Bytes: x509.MarshalPKCS1PrivateKey(sk.k), |
||||
} |
||||
bts := pem.EncodeToMemory(&b) |
||||
return json.Marshal(rsaPrivateKeyJSONWrapper{ |
||||
Key: base64.URLEncoding.EncodeToString(bts), |
||||
ID: sk.kid, |
||||
}) |
||||
} |
||||
|
||||
func (sk *signingKey) UnmarshalJSON(b []byte) error { |
||||
var wrapper rsaPrivateKeyJSONWrapper |
||||
if err := json.Unmarshal(b, &wrapper); err != nil { |
||||
return err |
||||
} |
||||
if len(wrapper.Key) == 0 { |
||||
return nil |
||||
} |
||||
b64dec, err := base64.URLEncoding.DecodeString(wrapper.Key) |
||||
if err != nil { |
||||
return err |
||||
} |
||||
blk, _ := pem.Decode(b64dec) |
||||
k, err := x509.ParsePKCS1PrivateKey(blk.Bytes) |
||||
if err != nil { |
||||
return err |
||||
} |
||||
sk.k = k |
||||
sk.kid = wrapper.ID |
||||
return nil |
||||
} |
||||
|
||||
// parseID takes a string input and returns a typed IntID T and true, or a zero
|
||||
// value and false if the input is unhandled syntax or out of a valid range.
|
||||
func parseID[T ~int64](input string) (_ T, ok bool) { |
||||
if input == "" { |
||||
return 0, false |
||||
} |
||||
i, err := strconv.ParseInt(input, 10, 64) |
||||
if err != nil { |
||||
return 0, false |
||||
} |
||||
if i < 0 { |
||||
return 0, false |
||||
} |
||||
return T(i), true |
||||
} |
||||
Loading…
Reference in new issue