WIP: rebase for 2026-05-18 #7
@@ -13,6 +13,7 @@ import (
|
||||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
"maps"
|
||||
"net"
|
||||
"net/http"
|
||||
"net/netip"
|
||||
@@ -41,8 +42,10 @@ import (
|
||||
"tailscale.com/types/dnstype"
|
||||
"tailscale.com/types/logger"
|
||||
"tailscale.com/types/nettype"
|
||||
"tailscale.com/types/views"
|
||||
"tailscale.com/util/cloudenv"
|
||||
"tailscale.com/util/dnsname"
|
||||
"tailscale.com/util/mak"
|
||||
"tailscale.com/util/race"
|
||||
"tailscale.com/version"
|
||||
)
|
||||
@@ -324,6 +327,19 @@ type forwarder struct {
|
||||
// resolver lookup.
|
||||
cloudHostFallback []resolverAndDelay
|
||||
|
||||
// schemes are the collection of registered URI scheme names that
|
||||
// dynamically decide which resolver to use at the time of each query. The
|
||||
// key is the scheme (the portion before the first `:`) and the value is a
|
||||
// handler that determines where the current query should be sent.
|
||||
// Use schemeCacheLocked() to get the current contents that can continue to
|
||||
// be accessed once mu is released. This allows the (much more common)
|
||||
// resolver code path to avoid repeated locking and unlocking.
|
||||
// When modified, call invalidateSchemeCacheLocked() before unlocking mu.
|
||||
schemes map[string]CustomSchemeHandler
|
||||
// schemeCache is an immutable copy of schemes. Do not read directly,
|
||||
// use schemeCacheLocked() which will regenerate its contents as needed.
|
||||
schemeCache views.Map[string, CustomSchemeHandler]
|
||||
|
||||
// acceptDNS tracks the CorpDNS pref (--accept-dns)
|
||||
// This lets us skip health warnings if the forwarder receives inbound
|
||||
// queries directly - but we didn't configure it with any upstream resolvers.
|
||||
@@ -996,15 +1012,66 @@ func (f *forwarder) sendTCP(ctx context.Context, fq *forwardQuery, rr resolverAn
|
||||
return out, nil
|
||||
}
|
||||
|
||||
// applySchemes resolves any custom-scheme entries in rrs using the provided
|
||||
// scheme handlers, returning the resulting slice. Entries whose handler returns
|
||||
// an error or empty string are dropped. Entries with no registered scheme pass
|
||||
// through unchanged. If schemes is nil, rrs is returned as-is.
|
||||
func applySchemes(logf logger.Logf, rrs []resolverAndDelay, schemes views.Map[string, CustomSchemeHandler]) []resolverAndDelay {
|
||||
if schemes.IsNil() {
|
||||
return rrs
|
||||
}
|
||||
var result []resolverAndDelay
|
||||
for i, rr := range rrs {
|
||||
scheme, _, hasColon := strings.Cut(rr.name.Addr, ":")
|
||||
handler, isCustom := schemes.GetOk(scheme)
|
||||
if !hasColon || !isCustom {
|
||||
if result != nil {
|
||||
result = append(result, rr)
|
||||
}
|
||||
continue
|
||||
}
|
||||
// Avoid making a results slice in the common case where there
|
||||
// are no custom scheme resolvers.
|
||||
if result == nil {
|
||||
result = make([]resolverAndDelay, i, len(rrs))
|
||||
copy(result, rrs)
|
||||
}
|
||||
newAddr, err := handler(rr.name.Addr)
|
||||
if err != nil {
|
||||
logf("error from custom scheme handler, skipping resolver : %v", err)
|
||||
}
|
||||
if err != nil || newAddr == "" {
|
||||
continue
|
||||
}
|
||||
newResolver := *rr.name
|
||||
newResolver.Addr = newAddr
|
||||
result = append(result, resolverAndDelay{name: &newResolver, startDelay: rr.startDelay})
|
||||
}
|
||||
// If we didn't have any custom schemes, return the original rrs.
|
||||
if result == nil {
|
||||
return rrs
|
||||
}
|
||||
return result
|
||||
}
|
||||
|
||||
// resolvers returns the resolvers to use for domain.
|
||||
func (f *forwarder) resolvers(domain dnsname.FQDN) []resolverAndDelay {
|
||||
f.mu.Lock()
|
||||
routes := f.routes
|
||||
cloudHostFallback := f.cloudHostFallback
|
||||
schemes := f.schemeCacheLocked()
|
||||
f.mu.Unlock()
|
||||
|
||||
for _, route := range routes {
|
||||
if route.Suffix == "." || route.Suffix.Contains(domain) {
|
||||
return route.Resolvers
|
||||
if route.Suffix != "." && !route.Suffix.Contains(domain) {
|
||||
continue
|
||||
}
|
||||
resolved := applySchemes(f.logf, route.Resolvers, schemes)
|
||||
// If scheme resolution filtered out all resolvers from a non-empty
|
||||
// route, fall through to the next matching route. If the resolvers
|
||||
// were configured to be empty allow resolved to be empty.
|
||||
if len(resolved) > 0 || len(route.Resolvers) == 0 {
|
||||
return resolved
|
||||
}
|
||||
}
|
||||
return cloudHostFallback // or nil if no fallback
|
||||
@@ -1021,6 +1088,39 @@ func (f *forwarder) GetUpstreamResolvers(name dnsname.FQDN) []*dnstype.Resolver
|
||||
return upstreamResolvers
|
||||
}
|
||||
|
||||
// RegisterCustomScheme adds a [CustomSchemeHandler] that is called to provide
|
||||
// an updated address when a [dnstype.Resolver.Addr] uses that scheme.
|
||||
func (f *forwarder) RegisterCustomScheme(scheme string, h CustomSchemeHandler) error {
|
||||
f.mu.Lock()
|
||||
defer f.mu.Unlock()
|
||||
if _, ok := f.schemes[scheme]; ok {
|
||||
return fmt.Errorf("scheme %q already registered", scheme)
|
||||
}
|
||||
f.invalidateSchemeCacheLocked()
|
||||
mak.Set(&f.schemes, scheme, h)
|
||||
return nil
|
||||
}
|
||||
|
||||
// invalidateSchemeCacheLocked clears f.schemeCache so that it will be rebuilt
|
||||
// on the next call to f.schemeCacheLocked().
|
||||
func (f *forwarder) invalidateSchemeCacheLocked() {
|
||||
f.schemeCache = views.Map[string, CustomSchemeHandler]{}
|
||||
}
|
||||
|
||||
// schemeCacheLocked returns an immutable copy of f.schemes that can be used
|
||||
// after mu is unlocked.
|
||||
func (f *forwarder) schemeCacheLocked() views.Map[string, CustomSchemeHandler] {
|
||||
if !f.schemeCache.IsNil() {
|
||||
return f.schemeCache
|
||||
}
|
||||
if f.schemes == nil {
|
||||
return f.schemeCache // returns a nil view
|
||||
}
|
||||
// Regenerate the cache
|
||||
f.schemeCache = views.MapOf(maps.Clone(f.schemes))
|
||||
return f.schemeCache
|
||||
}
|
||||
|
||||
// forwardQuery is information and state about a forwarded DNS query that's
|
||||
// being sent to 1 or more upstreams.
|
||||
//
|
||||
|
||||
@@ -27,6 +27,7 @@ import (
|
||||
"tailscale.com/net/tsdial"
|
||||
"tailscale.com/tstest"
|
||||
"tailscale.com/types/dnstype"
|
||||
"tailscale.com/util/dnsname"
|
||||
"tailscale.com/util/eventbus/eventbustest"
|
||||
)
|
||||
|
||||
@@ -1385,3 +1386,142 @@ func TestForwarderHealthOnContextExpiry(t *testing.T) {
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestResolversCustomScheme(t *testing.T) {
|
||||
t.Parallel()
|
||||
tests := []struct {
|
||||
name string
|
||||
domain dnsname.FQDN
|
||||
schemes map[string]CustomSchemeHandler
|
||||
routes map[dnsname.FQDN][]*dnstype.Resolver
|
||||
wantAddrs []string
|
||||
}{
|
||||
{
|
||||
name: "no-custom-scheme",
|
||||
domain: "example.com.",
|
||||
schemes: map[string]CustomSchemeHandler{},
|
||||
routes: map[dnsname.FQDN][]*dnstype.Resolver{
|
||||
"example.com.": {
|
||||
{Addr: "192.168.1.1:53"},
|
||||
{Addr: "192.168.1.2:53"},
|
||||
},
|
||||
},
|
||||
wantAddrs: []string{"192.168.1.1:53", "192.168.1.2:53"},
|
||||
},
|
||||
{
|
||||
name: "single-custom-scheme",
|
||||
domain: "example.com.",
|
||||
schemes: map[string]CustomSchemeHandler{
|
||||
"myscheme": func(string) (string, error) { return "1.2.3.4:53", nil },
|
||||
},
|
||||
routes: map[dnsname.FQDN][]*dnstype.Resolver{
|
||||
"example.com.": {{Addr: "myscheme:customKey"}},
|
||||
},
|
||||
wantAddrs: []string{"1.2.3.4:53"},
|
||||
},
|
||||
{
|
||||
name: "with-other-resolvers",
|
||||
domain: "example.com.",
|
||||
schemes: map[string]CustomSchemeHandler{
|
||||
"myscheme": func(key string) (string, error) { return "1.2.3.4:53", nil },
|
||||
},
|
||||
routes: map[dnsname.FQDN][]*dnstype.Resolver{
|
||||
"example.com.": {
|
||||
{Addr: "192.168.1.1:53"},
|
||||
{Addr: "myscheme:customKey"},
|
||||
{Addr: "192.168.1.2:53"},
|
||||
},
|
||||
},
|
||||
wantAddrs: []string{"192.168.1.1:53", "1.2.3.4:53", "192.168.1.2:53"},
|
||||
},
|
||||
{
|
||||
name: "multiple-custom-schemes",
|
||||
domain: "example.com.",
|
||||
schemes: map[string]CustomSchemeHandler{
|
||||
"schemeOne": func(string) (string, error) { return "1.2.3.4:53", nil },
|
||||
"schemeTwo": func(string) (string, error) { return "5.6.7.8:53", nil },
|
||||
},
|
||||
routes: map[dnsname.FQDN][]*dnstype.Resolver{
|
||||
"example.com.": {
|
||||
{Addr: "schemeOne:customKey"},
|
||||
{Addr: "schemeTwo:customKey"},
|
||||
},
|
||||
},
|
||||
wantAddrs: []string{"1.2.3.4:53", "5.6.7.8:53"},
|
||||
},
|
||||
{
|
||||
name: "empty-string-means-no-resolver",
|
||||
domain: "example.com.",
|
||||
schemes: map[string]CustomSchemeHandler{
|
||||
"myscheme": func(string) (string, error) { return "", nil },
|
||||
},
|
||||
routes: map[dnsname.FQDN][]*dnstype.Resolver{
|
||||
"example.com.": {
|
||||
{Addr: "192.168.1.1:53"},
|
||||
{Addr: "myscheme:customKey"},
|
||||
},
|
||||
},
|
||||
wantAddrs: []string{"192.168.1.1:53"},
|
||||
},
|
||||
{
|
||||
name: "error-means-no-resolver",
|
||||
domain: "example.com.",
|
||||
schemes: map[string]CustomSchemeHandler{
|
||||
"myscheme": func(string) (string, error) { return "", fmt.Errorf("handler error") },
|
||||
},
|
||||
routes: map[dnsname.FQDN][]*dnstype.Resolver{
|
||||
"example.com.": {
|
||||
{Addr: "192.168.1.1:53"},
|
||||
{Addr: "myscheme:customKey"},
|
||||
},
|
||||
},
|
||||
wantAddrs: []string{"192.168.1.1:53"},
|
||||
},
|
||||
{
|
||||
// If the best-matching route yields no resolvers after scheme
|
||||
// resolution, fall through to the next matching route.
|
||||
name: "empty-scheme-result-falls-through-to-next-matching-route",
|
||||
domain: "example.com.",
|
||||
schemes: map[string]CustomSchemeHandler{
|
||||
"myscheme": func(string) (string, error) { return "", nil },
|
||||
},
|
||||
routes: map[dnsname.FQDN][]*dnstype.Resolver{
|
||||
"example.com.": {{Addr: "myscheme:customKey"}},
|
||||
".": {{Addr: "192.168.1.1:53"}},
|
||||
},
|
||||
wantAddrs: []string{"192.168.1.1:53"},
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
logf := tstest.WhileTestRunningLogger(t)
|
||||
bus := eventbustest.NewBus(t)
|
||||
netMon, err := netmon.New(bus, logf)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
var dialer tsdial.Dialer
|
||||
dialer.SetNetMon(netMon)
|
||||
dialer.SetBus(bus)
|
||||
|
||||
fwd := newForwarder(logf, netMon, nil, &dialer, health.NewTracker(bus), nil)
|
||||
for scheme, handler := range tt.schemes {
|
||||
if err := fwd.RegisterCustomScheme(scheme, handler); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
}
|
||||
|
||||
fwd.setRoutes(tt.routes, false)
|
||||
|
||||
got := fwd.resolvers(tt.domain)
|
||||
var gotAddrs []string
|
||||
for _, r := range got {
|
||||
gotAddrs = append(gotAddrs, r.name.Addr)
|
||||
}
|
||||
if !slices.Equal(gotAddrs, tt.wantAddrs) {
|
||||
t.Errorf("got %v, want %v", gotAddrs, tt.wantAddrs)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
@@ -293,6 +293,18 @@ func (r *Resolver) SetConfig(cfg Config) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
// CustomSchemeHandler takes a URI (retrieved from [dnstype.Resolver.Addr]) and
|
||||
// returns an updated URI to use for the current query. The result is only valid
|
||||
// for right now and may change over time.
|
||||
type CustomSchemeHandler func(addr string) (newAddr string, err error)
|
||||
|
||||
// RegisterCustomScheme adds a [CustomSchemaHandler] that is called to provide
|
||||
// an updated address to the forwarder when a [dnstype.Resolver.Addr] uses that
|
||||
// scheme.
|
||||
func (r *Resolver) RegisterCustomScheme(scheme string, h CustomSchemeHandler) error {
|
||||
return r.forwarder.RegisterCustomScheme(scheme, h)
|
||||
}
|
||||
|
||||
// Close shuts down the resolver and ensures poll goroutines have exited.
|
||||
// The Resolver cannot be used again after Close is called.
|
||||
func (r *Resolver) Close() {
|
||||
|
||||
Reference in New Issue
Block a user