net/tstun: refactor natConfig to not be per-family
This was a holdover from the older, pre-BART days and is no longer necessary. Updates #cleanup Signed-off-by: Andrew Dunham <andrew@du.nham.ca> Change-Id: I71b892bab1898077767b9ff51cef33d59c08faf8
This commit is contained in:
+77
-98
@@ -550,61 +550,14 @@ func findV6(addrs []netip.Prefix) netip.Addr {
|
|||||||
//
|
//
|
||||||
// The nil value is a valid configuration.
|
// The nil value is a valid configuration.
|
||||||
type natConfig struct {
|
type natConfig struct {
|
||||||
v4, v6 *natFamilyConfig
|
// nativeAddr4 and nativeAddr6 are the IPv4/IPv6 Tailscale Addresses of
|
||||||
}
|
// the current node.
|
||||||
|
//
|
||||||
func (c *natConfig) String() string {
|
// These are implicitly used as the address to rewrite to in the DNAT
|
||||||
if c == nil {
|
// path (as configured by listenAddrs, below). The IPv4 address will be
|
||||||
return "<nil>"
|
// used if the inbound packet is IPv4, and the IPv6 address if the
|
||||||
}
|
// inbound packet is IPv6.
|
||||||
|
nativeAddr4, nativeAddr6 netip.Addr
|
||||||
var b strings.Builder
|
|
||||||
b.WriteString("natConfig{")
|
|
||||||
fmt.Fprintf(&b, "v4: %v, ", c.v4)
|
|
||||||
fmt.Fprintf(&b, "v6: %v", c.v6)
|
|
||||||
b.WriteString("}")
|
|
||||||
return b.String()
|
|
||||||
}
|
|
||||||
|
|
||||||
// mapDstIP returns the destination IP to use for a packet to dst.
|
|
||||||
// If dst is not one of the listen addresses, it is returned as-is,
|
|
||||||
// otherwise the native address is returned.
|
|
||||||
func (c *natConfig) mapDstIP(oldDst netip.Addr) netip.Addr {
|
|
||||||
if c == nil {
|
|
||||||
return oldDst
|
|
||||||
}
|
|
||||||
if oldDst.Is4() {
|
|
||||||
return c.v4.mapDstIP(oldDst)
|
|
||||||
}
|
|
||||||
if oldDst.Is6() {
|
|
||||||
return c.v6.mapDstIP(oldDst)
|
|
||||||
}
|
|
||||||
return oldDst
|
|
||||||
}
|
|
||||||
|
|
||||||
// selectSrcIP returns the source IP to use for a packet to dst.
|
|
||||||
// If the packet is not from the native address, it is returned as-is.
|
|
||||||
func (c *natConfig) selectSrcIP(oldSrc, dst netip.Addr) netip.Addr {
|
|
||||||
if c == nil {
|
|
||||||
return oldSrc
|
|
||||||
}
|
|
||||||
if oldSrc.Is4() {
|
|
||||||
return c.v4.selectSrcIP(oldSrc, dst)
|
|
||||||
}
|
|
||||||
if oldSrc.Is6() {
|
|
||||||
return c.v6.selectSrcIP(oldSrc, dst)
|
|
||||||
}
|
|
||||||
return oldSrc
|
|
||||||
}
|
|
||||||
|
|
||||||
// natFamilyConfig is the NAT configuration for a particular
|
|
||||||
// address family.
|
|
||||||
// It should be treated as immutable.
|
|
||||||
//
|
|
||||||
// The nil value is a valid configuration.
|
|
||||||
type natFamilyConfig struct {
|
|
||||||
// nativeAddr is the Tailscale Address of the current node.
|
|
||||||
nativeAddr netip.Addr
|
|
||||||
|
|
||||||
// listenAddrs is the set of addresses that should be
|
// listenAddrs is the set of addresses that should be
|
||||||
// mapped to the native address. These are the addresses that
|
// mapped to the native address. These are the addresses that
|
||||||
@@ -620,13 +573,14 @@ type natFamilyConfig struct {
|
|||||||
masqAddrCounts map[netip.Addr]int
|
masqAddrCounts map[netip.Addr]int
|
||||||
}
|
}
|
||||||
|
|
||||||
func (c *natFamilyConfig) String() string {
|
func (c *natConfig) String() string {
|
||||||
if c == nil {
|
if c == nil {
|
||||||
return "natFamilyConfig(nil)"
|
return "natConfig(nil)"
|
||||||
}
|
}
|
||||||
var b strings.Builder
|
var b strings.Builder
|
||||||
b.WriteString("natFamilyConfig{")
|
b.WriteString("natConfig{")
|
||||||
fmt.Fprintf(&b, "nativeAddr: %v, ", c.nativeAddr)
|
fmt.Fprintf(&b, "nativeAddr4: %v, ", c.nativeAddr4)
|
||||||
|
fmt.Fprintf(&b, "nativeAddr6: %v, ", c.nativeAddr6)
|
||||||
fmt.Fprint(&b, "listenAddrs: [")
|
fmt.Fprint(&b, "listenAddrs: [")
|
||||||
|
|
||||||
i := 0
|
i := 0
|
||||||
@@ -656,23 +610,31 @@ func (c *natFamilyConfig) String() string {
|
|||||||
// mapDstIP returns the destination IP to use for a packet to dst.
|
// mapDstIP returns the destination IP to use for a packet to dst.
|
||||||
// If dst is not one of the listen addresses, it is returned as-is,
|
// If dst is not one of the listen addresses, it is returned as-is,
|
||||||
// otherwise the native address is returned.
|
// otherwise the native address is returned.
|
||||||
func (c *natFamilyConfig) mapDstIP(oldDst netip.Addr) netip.Addr {
|
func (c *natConfig) mapDstIP(oldDst netip.Addr) netip.Addr {
|
||||||
if c == nil {
|
if c == nil {
|
||||||
return oldDst
|
return oldDst
|
||||||
}
|
}
|
||||||
if _, ok := c.listenAddrs.GetOk(oldDst); ok {
|
if _, ok := c.listenAddrs.GetOk(oldDst); ok {
|
||||||
return c.nativeAddr
|
if oldDst.Is4() && c.nativeAddr4.IsValid() {
|
||||||
|
return c.nativeAddr4
|
||||||
|
}
|
||||||
|
if oldDst.Is6() && c.nativeAddr6.IsValid() {
|
||||||
|
return c.nativeAddr6
|
||||||
|
}
|
||||||
}
|
}
|
||||||
return oldDst
|
return oldDst
|
||||||
}
|
}
|
||||||
|
|
||||||
// selectSrcIP returns the source IP to use for a packet to dst.
|
// selectSrcIP returns the source IP to use for a packet to dst.
|
||||||
// If the packet is not from the native address, it is returned as-is.
|
// If the packet is not from the native address, it is returned as-is.
|
||||||
func (c *natFamilyConfig) selectSrcIP(oldSrc, dst netip.Addr) netip.Addr {
|
func (c *natConfig) selectSrcIP(oldSrc, dst netip.Addr) netip.Addr {
|
||||||
if c == nil {
|
if c == nil {
|
||||||
return oldSrc
|
return oldSrc
|
||||||
}
|
}
|
||||||
if oldSrc != c.nativeAddr {
|
if oldSrc.Is4() && oldSrc != c.nativeAddr4 {
|
||||||
|
return oldSrc
|
||||||
|
}
|
||||||
|
if oldSrc.Is6() && oldSrc != c.nativeAddr6 {
|
||||||
return oldSrc
|
return oldSrc
|
||||||
}
|
}
|
||||||
eip, ok := c.dstMasqAddrs.Get(dst)
|
eip, ok := c.dstMasqAddrs.Get(dst)
|
||||||
@@ -682,22 +644,16 @@ func (c *natFamilyConfig) selectSrcIP(oldSrc, dst netip.Addr) netip.Addr {
|
|||||||
return eip
|
return eip
|
||||||
}
|
}
|
||||||
|
|
||||||
// natConfigFromWGConfig generates a natFamilyConfig from nm,
|
// natConfigFromWGConfig generates a natConfig from nm. If NAT is not required,
|
||||||
// for the indicated address family.
|
// it returns nil.
|
||||||
// If NAT is not required for that address family, it returns nil.
|
func natConfigFromWGConfig(wcfg *wgcfg.Config) *natConfig {
|
||||||
func natConfigFromWGConfig(wcfg *wgcfg.Config, addrFam ipproto.Version) *natFamilyConfig {
|
|
||||||
if wcfg == nil {
|
if wcfg == nil {
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
var nativeAddr netip.Addr
|
nativeAddr4 := findV4(wcfg.Addresses)
|
||||||
switch addrFam {
|
nativeAddr6 := findV6(wcfg.Addresses)
|
||||||
case ipproto.Version4:
|
if !nativeAddr4.IsValid() && !nativeAddr6.IsValid() {
|
||||||
nativeAddr = findV4(wcfg.Addresses)
|
|
||||||
case ipproto.Version6:
|
|
||||||
nativeAddr = findV6(wcfg.Addresses)
|
|
||||||
}
|
|
||||||
if !nativeAddr.IsValid() {
|
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -714,10 +670,10 @@ func natConfigFromWGConfig(wcfg *wgcfg.Config, addrFam ipproto.Version) *natFami
|
|||||||
for _, p := range wcfg.Peers {
|
for _, p := range wcfg.Peers {
|
||||||
isExitNode := slices.Contains(p.AllowedIPs, tsaddr.AllIPv4()) || slices.Contains(p.AllowedIPs, tsaddr.AllIPv6())
|
isExitNode := slices.Contains(p.AllowedIPs, tsaddr.AllIPv4()) || slices.Contains(p.AllowedIPs, tsaddr.AllIPv6())
|
||||||
if isExitNode {
|
if isExitNode {
|
||||||
hasMasqAddrsForFamily := false ||
|
hasMasqAddr := false ||
|
||||||
(addrFam == ipproto.Version4 && p.V4MasqAddr != nil && p.V4MasqAddr.IsValid()) ||
|
(p.V4MasqAddr != nil && p.V4MasqAddr.IsValid()) ||
|
||||||
(addrFam == ipproto.Version6 && p.V6MasqAddr != nil && p.V6MasqAddr.IsValid())
|
(p.V6MasqAddr != nil && p.V6MasqAddr.IsValid())
|
||||||
if hasMasqAddrsForFamily {
|
if hasMasqAddr {
|
||||||
exitNodeRequiresMasq = true
|
exitNodeRequiresMasq = true
|
||||||
}
|
}
|
||||||
break
|
break
|
||||||
@@ -725,29 +681,56 @@ func natConfigFromWGConfig(wcfg *wgcfg.Config, addrFam ipproto.Version) *natFami
|
|||||||
}
|
}
|
||||||
for i := range wcfg.Peers {
|
for i := range wcfg.Peers {
|
||||||
p := &wcfg.Peers[i]
|
p := &wcfg.Peers[i]
|
||||||
var addrToUse netip.Addr
|
|
||||||
if addrFam == ipproto.Version4 && p.V4MasqAddr != nil && p.V4MasqAddr.IsValid() {
|
// Build a routing table that configures DNAT (i.e. changing
|
||||||
addrToUse = *p.V4MasqAddr
|
// the V4MasqAddr/V6MasqAddr for a given peer to the current
|
||||||
mak.Set(&listenAddrs, addrToUse, struct{}{})
|
// peer's v4/v6 IP).
|
||||||
} else if addrFam == ipproto.Version6 && p.V6MasqAddr != nil && p.V6MasqAddr.IsValid() {
|
var addrToUse4, addrToUse6 netip.Addr
|
||||||
addrToUse = *p.V6MasqAddr
|
if p.V4MasqAddr != nil && p.V4MasqAddr.IsValid() {
|
||||||
mak.Set(&listenAddrs, addrToUse, struct{}{})
|
addrToUse4 = *p.V4MasqAddr
|
||||||
} else if exitNodeRequiresMasq {
|
mak.Set(&listenAddrs, addrToUse4, struct{}{})
|
||||||
addrToUse = nativeAddr
|
masqAddrCounts[addrToUse4]++
|
||||||
} else {
|
}
|
||||||
|
if p.V6MasqAddr != nil && p.V6MasqAddr.IsValid() {
|
||||||
|
addrToUse6 = *p.V6MasqAddr
|
||||||
|
mak.Set(&listenAddrs, addrToUse6, struct{}{})
|
||||||
|
masqAddrCounts[addrToUse6]++
|
||||||
|
}
|
||||||
|
|
||||||
|
// If the exit node requires masquerading, set the masquerade
|
||||||
|
// addresses to our native addresses.
|
||||||
|
if exitNodeRequiresMasq {
|
||||||
|
if !addrToUse4.IsValid() && nativeAddr4.IsValid() {
|
||||||
|
addrToUse4 = nativeAddr4
|
||||||
|
}
|
||||||
|
if !addrToUse6.IsValid() && nativeAddr6.IsValid() {
|
||||||
|
addrToUse6 = nativeAddr6
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if !addrToUse4.IsValid() && !addrToUse6.IsValid() {
|
||||||
|
// NAT not required for this peer.
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
|
|
||||||
masqAddrCounts[addrToUse]++
|
// Build the SNAT table that maps each AllowedIP to the
|
||||||
|
// masquerade address.
|
||||||
for _, ip := range p.AllowedIPs {
|
for _, ip := range p.AllowedIPs {
|
||||||
rt.Insert(ip, addrToUse)
|
is4 := ip.Addr().Is4()
|
||||||
|
if is4 && addrToUse4.IsValid() {
|
||||||
|
rt.Insert(ip, addrToUse4)
|
||||||
|
}
|
||||||
|
if !is4 && addrToUse6.IsValid() {
|
||||||
|
rt.Insert(ip, addrToUse6)
|
||||||
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
if len(listenAddrs) == 0 && len(masqAddrCounts) == 0 {
|
if len(listenAddrs) == 0 && len(masqAddrCounts) == 0 {
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
return &natFamilyConfig{
|
return &natConfig{
|
||||||
nativeAddr: nativeAddr,
|
nativeAddr4: nativeAddr4,
|
||||||
|
nativeAddr6: nativeAddr6,
|
||||||
listenAddrs: views.MapOf(listenAddrs),
|
listenAddrs: views.MapOf(listenAddrs),
|
||||||
dstMasqAddrs: &rt,
|
dstMasqAddrs: &rt,
|
||||||
masqAddrCounts: masqAddrCounts,
|
masqAddrCounts: masqAddrCounts,
|
||||||
@@ -756,11 +739,7 @@ func natConfigFromWGConfig(wcfg *wgcfg.Config, addrFam ipproto.Version) *natFami
|
|||||||
|
|
||||||
// SetNetMap is called when a new NetworkMap is received.
|
// SetNetMap is called when a new NetworkMap is received.
|
||||||
func (t *Wrapper) SetWGConfig(wcfg *wgcfg.Config) {
|
func (t *Wrapper) SetWGConfig(wcfg *wgcfg.Config) {
|
||||||
v4, v6 := natConfigFromWGConfig(wcfg, ipproto.Version4), natConfigFromWGConfig(wcfg, ipproto.Version6)
|
cfg := natConfigFromWGConfig(wcfg)
|
||||||
var cfg *natConfig
|
|
||||||
if v4 != nil || v6 != nil {
|
|
||||||
cfg = &natConfig{v4: v4, v6: v6}
|
|
||||||
}
|
|
||||||
|
|
||||||
old := t.natConfig.Swap(cfg)
|
old := t.natConfig.Swap(cfg)
|
||||||
if !reflect.DeepEqual(old, cfg) {
|
if !reflect.DeepEqual(old, cfg) {
|
||||||
|
|||||||
@@ -601,6 +601,8 @@ func TestFilterDiscoLoop(t *testing.T) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// TODO(andrew-d): refactor this test to no longer use addrFam, after #11945
|
||||||
|
// removed it in natConfigFromWGConfig
|
||||||
func TestNATCfg(t *testing.T) {
|
func TestNATCfg(t *testing.T) {
|
||||||
node := func(ip, masqIP netip.Addr, otherAllowedIPs ...netip.Prefix) wgcfg.Peer {
|
node := func(ip, masqIP netip.Addr, otherAllowedIPs ...netip.Prefix) wgcfg.Peer {
|
||||||
p := wgcfg.Peer{
|
p := wgcfg.Peer{
|
||||||
@@ -800,7 +802,7 @@ func TestNATCfg(t *testing.T) {
|
|||||||
|
|
||||||
for _, tc := range tests {
|
for _, tc := range tests {
|
||||||
t.Run(fmt.Sprintf("%v/%v", addrFam, tc.name), func(t *testing.T) {
|
t.Run(fmt.Sprintf("%v/%v", addrFam, tc.name), func(t *testing.T) {
|
||||||
ncfg := natConfigFromWGConfig(tc.wcfg, addrFam)
|
ncfg := natConfigFromWGConfig(tc.wcfg)
|
||||||
for peer, want := range tc.snatMap {
|
for peer, want := range tc.snatMap {
|
||||||
if got := ncfg.selectSrcIP(selfNativeIP, peer); got != want {
|
if got := ncfg.selectSrcIP(selfNativeIP, peer); got != want {
|
||||||
t.Errorf("selectSrcIP[%v]: got %v; want %v", peer, got, want)
|
t.Errorf("selectSrcIP[%v]: got %v; want %v", peer, got, want)
|
||||||
|
|||||||
Reference in New Issue
Block a user