|
|
|
|
@ -26,6 +26,7 @@ import ( |
|
|
|
|
"reflect" |
|
|
|
|
"strings" |
|
|
|
|
"sync" |
|
|
|
|
"sync/atomic" |
|
|
|
|
"testing" |
|
|
|
|
"time" |
|
|
|
|
|
|
|
|
|
@ -652,23 +653,23 @@ func TestFallbackTCPHandler(t *testing.T) { |
|
|
|
|
} |
|
|
|
|
t.Logf("ping success: %#+v", res) |
|
|
|
|
|
|
|
|
|
s1TcpConnCount := 0 |
|
|
|
|
var s1TcpConnCount atomic.Int32 |
|
|
|
|
deregister := s1.RegisterFallbackTCPHandler(func(src, dst netip.AddrPort) (handler func(net.Conn), intercept bool) { |
|
|
|
|
s1TcpConnCount++ |
|
|
|
|
s1TcpConnCount.Add(1) |
|
|
|
|
return nil, false |
|
|
|
|
}) |
|
|
|
|
|
|
|
|
|
if _, err = s2.Dial(ctx, "tcp", fmt.Sprintf("%s:8081", s1ip)); err == nil { |
|
|
|
|
if _, err := s2.Dial(ctx, "tcp", fmt.Sprintf("%s:8081", s1ip)); err == nil { |
|
|
|
|
t.Fatal("Expected dial error because fallback handler did not intercept") |
|
|
|
|
} |
|
|
|
|
if s1TcpConnCount != 1 { |
|
|
|
|
t.Errorf("s1TcpConnCount = %d, want %d", s1TcpConnCount, 1) |
|
|
|
|
if got := s1TcpConnCount.Load(); got != 1 { |
|
|
|
|
t.Errorf("s1TcpConnCount = %d, want %d", got, 1) |
|
|
|
|
} |
|
|
|
|
deregister() |
|
|
|
|
if _, err = s2.Dial(ctx, "tcp", fmt.Sprintf("%s:8081", s1ip)); err == nil { |
|
|
|
|
if _, err := s2.Dial(ctx, "tcp", fmt.Sprintf("%s:8081", s1ip)); err == nil { |
|
|
|
|
t.Fatal("Expected dial error because nothing would intercept") |
|
|
|
|
} |
|
|
|
|
if s1TcpConnCount != 1 { |
|
|
|
|
t.Errorf("s1TcpConnCount = %d, want %d", s1TcpConnCount, 1) |
|
|
|
|
if got := s1TcpConnCount.Load(); got != 1 { |
|
|
|
|
t.Errorf("s1TcpConnCount = %d, want %d", got, 1) |
|
|
|
|
} |
|
|
|
|
} |
|
|
|
|
|