net/tlsdial: fix TLS cert validation of HTTPS proxies

If you had HTTPS_PROXY=https://some-valid-cert.example.com running a
CONNECT proxy, we should've been able to do a TLS CONNECT request to
e.g. controlplane.tailscale.com:443 through that, and I'm pretty sure
it used to work, but refactorings and lack of integration tests made
it regress.

It probably regressed when we added the baked-in LetsEncrypt root cert
validation fallback code, which was testing against the wrong hostname
(the ultimate one, not the one which we were being asked to validate)

Fixes #16222

Change-Id: If014e395f830e2f87f056f588edacad5c15e91bc
Signed-off-by: Brad Fitzpatrick <bradfitz@tailscale.com>
This commit is contained in:
Brad Fitzpatrick
2025-06-08 18:51:41 -07:00
committed by Brad Fitzpatrick
parent 4979ce7a94
commit e92eb6b17b
17 changed files with 672 additions and 49 deletions
+3 -2
View File
@@ -7,6 +7,7 @@ package bakedroots
import (
"crypto/x509"
"fmt"
"sync"
"tailscale.com/util/testenv"
@@ -14,7 +15,7 @@ import (
// Get returns the baked-in roots.
//
// As of 2025-01-21, this includes only the LetsEncrypt ISRG Root X1 root.
// As of 2025-01-21, this includes only the LetsEncrypt ISRG Root X1 & X2 roots.
func Get() *x509.CertPool {
roots.once.Do(func() {
roots.parsePEM(append(
@@ -56,7 +57,7 @@ type rootsOnce struct {
func (r *rootsOnce) parsePEM(caPEM []byte) {
p := x509.NewCertPool()
if !p.AppendCertsFromPEM(caPEM) {
panic("bogus PEM")
panic(fmt.Sprintf("bogus PEM: %q", caPEM))
}
r.p = p
}
+93
View File
@@ -0,0 +1,93 @@
// Copyright (c) Tailscale Inc & AUTHORS
// SPDX-License-Identifier: BSD-3-Clause
// Package connectproxy contains some CONNECT proxy code.
package connectproxy
import (
"context"
"io"
"log"
"net"
"net/http"
"time"
"tailscale.com/net/netx"
"tailscale.com/types/logger"
)
// Handler is an HTTP CONNECT proxy handler.
type Handler struct {
// Dial, if non-nil, is an alternate dialer to use
// instead of the default dialer.
Dial netx.DialFunc
// Logf, if non-nil, is an alterate logger to
// use instead of log.Printf.
Logf logger.Logf
// Check, if non-nil, validates the CONNECT target.
Check func(hostPort string) error
}
func (h *Handler) ServeHTTP(w http.ResponseWriter, r *http.Request) {
ctx := r.Context()
if r.Method != "CONNECT" {
http.Error(w, "Method not allowed", http.StatusMethodNotAllowed)
return
}
dial := h.Dial
if dial == nil {
var d net.Dialer
dial = d.DialContext
}
logf := h.Logf
if logf == nil {
logf = log.Printf
}
hostPort := r.RequestURI
if h.Check != nil {
if err := h.Check(hostPort); err != nil {
logf("CONNECT target %q not allowed: %v", hostPort, err)
http.Error(w, "Invalid CONNECT target", http.StatusForbidden)
return
}
}
ctx, cancel := context.WithTimeout(ctx, 15*time.Second)
defer cancel()
back, err := dial(ctx, "tcp", hostPort)
if err != nil {
logf("error CONNECT dialing %v: %v", hostPort, err)
http.Error(w, "Connect failure", http.StatusBadGateway)
return
}
defer back.Close()
hj, ok := w.(http.Hijacker)
if !ok {
http.Error(w, "CONNECT hijack unavailable", http.StatusInternalServerError)
return
}
c, br, err := hj.Hijack()
if err != nil {
logf("CONNECT hijack: %v", err)
return
}
defer c.Close()
io.WriteString(c, "HTTP/1.1 200 OK\r\n\r\n")
errc := make(chan error, 2)
go func() {
_, err := io.Copy(c, back)
errc <- err
}()
go func() {
_, err := io.Copy(back, br)
errc <- err
}()
<-errc
}
+12 -1
View File
@@ -24,6 +24,7 @@ import (
"tailscale.com/util/cloudenv"
"tailscale.com/util/singleflight"
"tailscale.com/util/slicesx"
"tailscale.com/util/testenv"
)
var zaddr netip.Addr
@@ -63,6 +64,10 @@ type Resolver struct {
// If nil, net.DefaultResolver is used.
Forward *net.Resolver
// LookupIPForTest, if non-nil and in tests, handles requests instead
// of the usual mechanisms.
LookupIPForTest func(ctx context.Context, host string) ([]netip.Addr, error)
// LookupIPFallback optionally provides a backup DNS mechanism
// to use if Forward returns an error or no results.
LookupIPFallback func(ctx context.Context, host string) ([]netip.Addr, error)
@@ -284,7 +289,13 @@ func (r *Resolver) lookupIP(ctx context.Context, host string) (ip, ip6 netip.Add
lookupCtx, lookupCancel := context.WithTimeout(ctx, r.lookupTimeoutForHost(host))
defer lookupCancel()
ips, err := r.fwd().LookupNetIP(lookupCtx, "ip", host)
var ips []netip.Addr
if r.LookupIPForTest != nil && testenv.InTest() {
ips, err = r.LookupIPForTest(ctx, host)
} else {
ips, err = r.fwd().LookupNetIP(lookupCtx, "ip", host)
}
if err != nil || len(ips) == 0 {
if resolver, ok := r.cloudHostResolver(); ok {
r.dlogf("resolving %q via cloud resolver", host)
+1 -1
View File
@@ -286,7 +286,7 @@ func bootstrapDNSMap(ctx context.Context, serverName string, serverIP netip.Addr
tr.DialContext = func(ctx context.Context, netw, addr string) (net.Conn, error) {
return dialer.DialContext(ctx, "tcp", net.JoinHostPort(serverIP.String(), "443"))
}
tr.TLSClientConfig = tlsdial.Config(serverName, ht, tr.TLSClientConfig)
tr.TLSClientConfig = tlsdial.Config(ht, tr.TLSClientConfig)
c := &http.Client{Transport: tr}
req, err := http.NewRequestWithContext(ctx, "GET", "https://"+serverName+"/bootstrap-dns?q="+url.QueryEscape(queryName), nil)
if err != nil {
+35 -33
View File
@@ -59,18 +59,26 @@ var mitmBlockWarnable = health.Register(&health.Warnable{
ImpactsConnectivity: true,
})
// Config returns a tls.Config for connecting to a server.
// Config returns a tls.Config for connecting to a server that
// uses system roots for validation but, if those fail, also tries
// the baked-in LetsEncrypt roots as a fallback validation method.
//
// If base is non-nil, it's cloned as the base config before
// being configured and returned.
// If ht is non-nil, it's used to report health errors.
func Config(host string, ht *health.Tracker, base *tls.Config) *tls.Config {
func Config(ht *health.Tracker, base *tls.Config) *tls.Config {
var conf *tls.Config
if base == nil {
conf = new(tls.Config)
} else {
conf = base.Clone()
}
conf.ServerName = host
// Note: we do NOT set conf.ServerName here (as we accidentally did
// previously), as this path is also used when dialing an HTTPS proxy server
// (through which we'll send a CONNECT request to get a TCP connection to do
// the real TCP connection) because host is the ultimate hostname, but this
// tls.Config is used for both the proxy and the ultimate target.
if n := sslKeyLogFile; n != "" {
f, err := os.OpenFile(n, os.O_CREATE|os.O_APPEND|os.O_WRONLY, 0600)
@@ -93,7 +101,9 @@ func Config(host string, ht *health.Tracker, base *tls.Config) *tls.Config {
// (with the baked-in fallback root) in the VerifyConnection hook.
conf.InsecureSkipVerify = true
conf.VerifyConnection = func(cs tls.ConnectionState) (retErr error) {
if host == "log.tailscale.com" && hostinfo.IsNATLabGuestVM() {
dialedHost := cs.ServerName
if dialedHost == "log.tailscale.com" && hostinfo.IsNATLabGuestVM() {
// Allow log.tailscale.com TLS MITM for integration tests when
// the client's running within a NATLab VM.
return nil
@@ -116,7 +126,7 @@ func Config(host string, ht *health.Tracker, base *tls.Config) *tls.Config {
// Show a dedicated warning.
m, ok := blockblame.VerifyCertificate(cert)
if ok {
log.Printf("tlsdial: server cert for %q looks like %q equipment (could be blocking Tailscale)", host, m.Name)
log.Printf("tlsdial: server cert seen while dialing %q looks like %q equipment (could be blocking Tailscale)", dialedHost, m.Name)
ht.SetUnhealthy(mitmBlockWarnable, health.Args{"manufacturer": m.Name})
} else {
ht.SetHealthy(mitmBlockWarnable)
@@ -135,7 +145,7 @@ func Config(host string, ht *health.Tracker, base *tls.Config) *tls.Config {
ht.SetTLSConnectionError(cs.ServerName, nil)
if selfSignedIssuer != "" {
// Log the self-signed issuer, but don't treat it as an error.
log.Printf("tlsdial: warning: server cert for %q passed x509 validation but is self-signed by %q", host, selfSignedIssuer)
log.Printf("tlsdial: warning: server cert for %q passed x509 validation but is self-signed by %q", dialedHost, selfSignedIssuer)
}
}
}()
@@ -144,7 +154,7 @@ func Config(host string, ht *health.Tracker, base *tls.Config) *tls.Config {
// First try doing x509 verification with the system's
// root CA pool.
opts := x509.VerifyOptions{
DNSName: cs.ServerName,
DNSName: dialedHost,
Intermediates: x509.NewCertPool(),
}
for _, cert := range cs.PeerCertificates[1:] {
@@ -152,7 +162,7 @@ func Config(host string, ht *health.Tracker, base *tls.Config) *tls.Config {
}
_, errSys := cs.PeerCertificates[0].Verify(opts)
if debug() {
log.Printf("tlsdial(sys %q): %v", host, errSys)
log.Printf("tlsdial(sys %q): %v", dialedHost, errSys)
}
// Always verify with our baked-in Let's Encrypt certificate,
@@ -161,13 +171,11 @@ func Config(host string, ht *health.Tracker, base *tls.Config) *tls.Config {
opts.Roots = bakedroots.Get()
_, bakedErr := cs.PeerCertificates[0].Verify(opts)
if debug() {
log.Printf("tlsdial(bake %q): %v", host, bakedErr)
log.Printf("tlsdial(bake %q): %v", dialedHost, bakedErr)
} else if bakedErr != nil {
if _, loaded := tlsdialWarningPrinted.LoadOrStore(host, true); !loaded {
if errSys == nil {
log.Printf("tlsdial: warning: server cert for %q is not a Let's Encrypt cert", host)
} else {
log.Printf("tlsdial: error: server cert for %q failed to verify and is not a Let's Encrypt cert", host)
if _, loaded := tlsdialWarningPrinted.LoadOrStore(dialedHost, true); !loaded {
if errSys != nil {
log.Printf("tlsdial: error: server cert for %q failed both system roots & Let's Encrypt root validation", dialedHost)
}
}
}
@@ -202,9 +210,6 @@ func SetConfigExpectedCert(c *tls.Config, certDNSName string) {
c.ServerName = certDNSName
return
}
if c.VerifyPeerCertificate != nil {
panic("refusing to override tls.Config.VerifyPeerCertificate")
}
// Set InsecureSkipVerify to prevent crypto/tls from doing its
// own cert verification, but do the same work that it'd do
// (but using certDNSName) in the VerifyPeerCertificate hook.
@@ -257,29 +262,30 @@ func SetConfigExpectedCertHash(c *tls.Config, wantFullCertSHA256Hex string) {
if c.VerifyPeerCertificate != nil {
panic("refusing to override tls.Config.VerifyPeerCertificate")
}
// Set InsecureSkipVerify to prevent crypto/tls from doing its
// own cert verification, but do the same work that it'd do
// (but using certDNSName) in the VerifyPeerCertificate hook.
// (but using certDNSName) in the VerifyConnection hook.
c.InsecureSkipVerify = true
c.VerifyConnection = nil
c.VerifyPeerCertificate = func(rawCerts [][]byte, _ [][]*x509.Certificate) error {
c.VerifyConnection = func(cs tls.ConnectionState) error {
dialedHost := cs.ServerName
var sawGoodCert bool
for _, rawCert := range rawCerts {
cert, err := x509.ParseCertificate(rawCert)
if err != nil {
return fmt.Errorf("ParseCertificate: %w", err)
}
for _, cert := range cs.PeerCertificates {
if strings.HasPrefix(cert.Subject.CommonName, derpconst.MetaCertCommonNamePrefix) {
continue
}
if sawGoodCert {
return errors.New("unexpected multiple certs presented")
}
if fmt.Sprintf("%02x", sha256.Sum256(rawCert)) != wantFullCertSHA256Hex {
if fmt.Sprintf("%02x", sha256.Sum256(cert.Raw)) != wantFullCertSHA256Hex {
return fmt.Errorf("cert hash does not match expected cert hash")
}
if err := cert.VerifyHostname(c.ServerName); err != nil {
return fmt.Errorf("cert does not match server name %q: %w", c.ServerName, err)
if dialedHost != "" { // it's empty when dialing a derper by IP with no hostname
if err := cert.VerifyHostname(dialedHost); err != nil {
return fmt.Errorf("cert does not match server name %q: %w", dialedHost, err)
}
}
now := time.Now()
if now.After(cert.NotAfter) {
@@ -302,12 +308,8 @@ func SetConfigExpectedCertHash(c *tls.Config, wantFullCertSHA256Hex string) {
func NewTransport() *http.Transport {
return &http.Transport{
DialTLSContext: func(ctx context.Context, network, addr string) (net.Conn, error) {
host, _, err := net.SplitHostPort(addr)
if err != nil {
return nil, err
}
var d tls.Dialer
d.Config = Config(host, nil, nil)
d.Config = Config(nil, nil)
return d.DialContext(ctx, network, addr)
},
}
+1 -1
View File
@@ -86,7 +86,7 @@ func TestFallbackRootWorks(t *testing.T) {
DisableKeepAlives: true, // for test cleanup ease
}
ht := new(health.Tracker)
tr.TLSClientConfig = Config("tlsdial.test", ht, tr.TLSClientConfig)
tr.TLSClientConfig = Config(ht, tr.TLSClientConfig)
c := &http.Client{Transport: tr}
ctr0 := atomic.LoadInt32(&counterFallbackOK)