|
|
|
|
@ -5,6 +5,7 @@ |
|
|
|
|
package dns |
|
|
|
|
|
|
|
|
|
import ( |
|
|
|
|
"errors" |
|
|
|
|
"fmt" |
|
|
|
|
"strings" |
|
|
|
|
|
|
|
|
|
@ -13,6 +14,12 @@ import ( |
|
|
|
|
"tailscale.com/types/logger" |
|
|
|
|
) |
|
|
|
|
|
|
|
|
|
const ( |
|
|
|
|
ipv4RegBase = `SYSTEM\CurrentControlSet\Services\Tcpip\Parameters` |
|
|
|
|
ipv6RegBase = `SYSTEM\CurrentControlSet\Services\Tcpip6\Parameters` |
|
|
|
|
tsRegBase = `SOFTWARE\Tailscale IPN` |
|
|
|
|
) |
|
|
|
|
|
|
|
|
|
type windowsManager struct { |
|
|
|
|
logf logger.Logf |
|
|
|
|
guid string |
|
|
|
|
@ -25,29 +32,88 @@ func newManager(mconfig ManagerConfig) managerImpl { |
|
|
|
|
} |
|
|
|
|
} |
|
|
|
|
|
|
|
|
|
func setRegistry(path, nameservers, domains string) error { |
|
|
|
|
key, err := registry.OpenKey(registry.LOCAL_MACHINE, path, registry.READ|registry.SET_VALUE) |
|
|
|
|
func setRegistryString(path, name, value string) error { |
|
|
|
|
key, err := registry.OpenKey(registry.LOCAL_MACHINE, path, registry.SET_VALUE) |
|
|
|
|
if err != nil { |
|
|
|
|
return fmt.Errorf("opening %s: %w", path, err) |
|
|
|
|
} |
|
|
|
|
defer key.Close() |
|
|
|
|
|
|
|
|
|
err = key.SetStringValue("NameServer", nameservers) |
|
|
|
|
err = key.SetStringValue(name, value) |
|
|
|
|
if err != nil { |
|
|
|
|
return fmt.Errorf("setting %s[%s]: %w", path, name, err) |
|
|
|
|
} |
|
|
|
|
return nil |
|
|
|
|
} |
|
|
|
|
|
|
|
|
|
func getRegistryString(path, name string) (string, error) { |
|
|
|
|
key, err := registry.OpenKey(registry.LOCAL_MACHINE, path, registry.READ) |
|
|
|
|
if err != nil { |
|
|
|
|
return "", fmt.Errorf("opening %s: %w", path, err) |
|
|
|
|
} |
|
|
|
|
defer key.Close() |
|
|
|
|
|
|
|
|
|
value, _, err := key.GetStringValue(name) |
|
|
|
|
if err != nil { |
|
|
|
|
return fmt.Errorf("setting %s/NameServer: %w", path, err) |
|
|
|
|
return "", fmt.Errorf("getting %s[%s]: %w", path, name, err) |
|
|
|
|
} |
|
|
|
|
return value, nil |
|
|
|
|
} |
|
|
|
|
|
|
|
|
|
func (m windowsManager) setNameservers(basePath string, nameservers []string) error { |
|
|
|
|
path := fmt.Sprintf(`%s\Interfaces\%s`, basePath, m.guid) |
|
|
|
|
value := strings.Join(nameservers, ",") |
|
|
|
|
return setRegistryString(path, "NameServer", value) |
|
|
|
|
} |
|
|
|
|
|
|
|
|
|
err = key.SetStringValue("Domain", domains) |
|
|
|
|
func (m windowsManager) setDomains(path string, oldDomains, newDomains []string) error { |
|
|
|
|
// We reimplement setRegistryString to ensure that we hold the key for the whole operation.
|
|
|
|
|
key, err := registry.OpenKey(registry.LOCAL_MACHINE, path, registry.READ|registry.SET_VALUE) |
|
|
|
|
if err != nil { |
|
|
|
|
return fmt.Errorf("setting %s/Domain: %w", path, err) |
|
|
|
|
return fmt.Errorf("opening %s: %w", path, err) |
|
|
|
|
} |
|
|
|
|
defer key.Close() |
|
|
|
|
|
|
|
|
|
searchList, _, err := key.GetStringValue("SearchList") |
|
|
|
|
if err != nil && err != registry.ErrNotExist { |
|
|
|
|
return fmt.Errorf("getting %s[SearchList]: %w", path, err) |
|
|
|
|
} |
|
|
|
|
currentDomains := strings.Split(searchList, ",") |
|
|
|
|
|
|
|
|
|
var domainsToSet []string |
|
|
|
|
for _, domain := range currentDomains { |
|
|
|
|
inOld, inNew := false, false |
|
|
|
|
|
|
|
|
|
// The number of domains should be small,
|
|
|
|
|
// so this is probaly faster than constructing a map.
|
|
|
|
|
for _, oldDomain := range oldDomains { |
|
|
|
|
if domain == oldDomain { |
|
|
|
|
inOld = true |
|
|
|
|
} |
|
|
|
|
} |
|
|
|
|
for _, newDomain := range newDomains { |
|
|
|
|
if domain == newDomain { |
|
|
|
|
inNew = true |
|
|
|
|
} |
|
|
|
|
} |
|
|
|
|
|
|
|
|
|
if !inNew && !inOld { |
|
|
|
|
domainsToSet = append(domainsToSet, domain) |
|
|
|
|
} |
|
|
|
|
} |
|
|
|
|
domainsToSet = append(domainsToSet, newDomains...) |
|
|
|
|
|
|
|
|
|
searchList = strings.Join(domainsToSet, ",") |
|
|
|
|
if err := key.SetStringValue("SearchList", searchList); err != nil { |
|
|
|
|
return fmt.Errorf("setting %s[SearchList]: %w", path, err) |
|
|
|
|
} |
|
|
|
|
return nil |
|
|
|
|
} |
|
|
|
|
|
|
|
|
|
func (m windowsManager) Up(config Config) error { |
|
|
|
|
var ipsv4 []string |
|
|
|
|
var ipsv6 []string |
|
|
|
|
|
|
|
|
|
for _, ip := range config.Nameservers { |
|
|
|
|
if ip.Is4() { |
|
|
|
|
ipsv4 = append(ipsv4, ip.String()) |
|
|
|
|
@ -55,23 +121,29 @@ func (m windowsManager) Up(config Config) error { |
|
|
|
|
ipsv6 = append(ipsv6, ip.String()) |
|
|
|
|
} |
|
|
|
|
} |
|
|
|
|
nsv4 := strings.Join(ipsv4, ",") |
|
|
|
|
nsv6 := strings.Join(ipsv6, ",") |
|
|
|
|
|
|
|
|
|
var domains string |
|
|
|
|
if len(config.Domains) > 0 { |
|
|
|
|
if len(config.Domains) > 1 { |
|
|
|
|
m.logf("only a single search domain is supported") |
|
|
|
|
} |
|
|
|
|
domains = config.Domains[0] |
|
|
|
|
lastSearchList, err := getRegistryString(tsRegBase, "SearchList") |
|
|
|
|
if err != nil && !errors.Is(err, registry.ErrNotExist) { |
|
|
|
|
return err |
|
|
|
|
} |
|
|
|
|
lastDomains := strings.Split(lastSearchList, ",") |
|
|
|
|
|
|
|
|
|
v4Path := `SYSTEM\CurrentControlSet\Services\Tcpip\Parameters\Interfaces\` + m.guid |
|
|
|
|
if err := setRegistry(v4Path, nsv4, domains); err != nil { |
|
|
|
|
if err := m.setNameservers(ipv4RegBase, ipsv4); err != nil { |
|
|
|
|
return err |
|
|
|
|
} |
|
|
|
|
if err := m.setDomains(ipv4RegBase, lastDomains, config.Domains); err != nil { |
|
|
|
|
return err |
|
|
|
|
} |
|
|
|
|
v6Path := `SYSTEM\CurrentControlSet\Services\Tcpip6\Parameters\Interfaces\` + m.guid |
|
|
|
|
if err := setRegistry(v6Path, nsv6, domains); err != nil { |
|
|
|
|
|
|
|
|
|
if err := m.setNameservers(ipv6RegBase, ipsv6); err != nil { |
|
|
|
|
return err |
|
|
|
|
} |
|
|
|
|
if err := m.setDomains(ipv6RegBase, lastDomains, config.Domains); err != nil { |
|
|
|
|
return err |
|
|
|
|
} |
|
|
|
|
|
|
|
|
|
newSearchList := strings.Join(config.Domains, ",") |
|
|
|
|
if err := setRegistryString(tsRegBase, "SearchList", newSearchList); err != nil { |
|
|
|
|
return err |
|
|
|
|
} |
|
|
|
|
|
|
|
|
|
|