net/memnet: rename from net/nettest
This is just #cleanup to resolve a TODO Also add a package doc. Signed-off-by: Maisem Ali <maisem@tailscale.com>
This commit is contained in:
@@ -0,0 +1,110 @@
|
||||
// Copyright (c) Tailscale Inc & AUTHORS
|
||||
// SPDX-License-Identifier: BSD-3-Clause
|
||||
|
||||
package memnet
|
||||
|
||||
import (
|
||||
"net"
|
||||
"net/netip"
|
||||
"time"
|
||||
)
|
||||
|
||||
// Conn is a net.Conn that can additionally have its reads and writes blocked and unblocked.
|
||||
type Conn interface {
|
||||
net.Conn
|
||||
|
||||
// SetReadBlock blocks or unblocks the Read method of this Conn.
|
||||
// It reports an error if the existing value matches the new value,
|
||||
// or if the Conn has been Closed.
|
||||
SetReadBlock(bool) error
|
||||
|
||||
// SetWriteBlock blocks or unblocks the Write method of this Conn.
|
||||
// It reports an error if the existing value matches the new value,
|
||||
// or if the Conn has been Closed.
|
||||
SetWriteBlock(bool) error
|
||||
}
|
||||
|
||||
// NewConn creates a pair of Conns that are wired together by pipes.
|
||||
func NewConn(name string, maxBuf int) (Conn, Conn) {
|
||||
r := NewPipe(name+"|0", maxBuf)
|
||||
w := NewPipe(name+"|1", maxBuf)
|
||||
|
||||
return &connHalf{r: r, w: w}, &connHalf{r: w, w: r}
|
||||
}
|
||||
|
||||
// NewTCPConn creates a pair of Conns that are wired together by pipes.
|
||||
func NewTCPConn(src, dst netip.AddrPort, maxBuf int) (local Conn, remote Conn) {
|
||||
r := NewPipe(src.String(), maxBuf)
|
||||
w := NewPipe(dst.String(), maxBuf)
|
||||
|
||||
lAddr := net.TCPAddrFromAddrPort(src)
|
||||
rAddr := net.TCPAddrFromAddrPort(dst)
|
||||
|
||||
return &connHalf{r: r, w: w, remote: rAddr, local: lAddr}, &connHalf{r: w, w: r, remote: lAddr, local: rAddr}
|
||||
}
|
||||
|
||||
type connAddr string
|
||||
|
||||
func (a connAddr) Network() string { return "mem" }
|
||||
func (a connAddr) String() string { return string(a) }
|
||||
|
||||
type connHalf struct {
|
||||
local, remote net.Addr
|
||||
r, w *Pipe
|
||||
}
|
||||
|
||||
func (c *connHalf) LocalAddr() net.Addr {
|
||||
if c.local != nil {
|
||||
return c.local
|
||||
}
|
||||
return connAddr(c.r.name)
|
||||
}
|
||||
|
||||
func (c *connHalf) RemoteAddr() net.Addr {
|
||||
if c.remote != nil {
|
||||
return c.remote
|
||||
}
|
||||
return connAddr(c.w.name)
|
||||
}
|
||||
|
||||
func (c *connHalf) Read(b []byte) (n int, err error) {
|
||||
return c.r.Read(b)
|
||||
}
|
||||
func (c *connHalf) Write(b []byte) (n int, err error) {
|
||||
return c.w.Write(b)
|
||||
}
|
||||
|
||||
func (c *connHalf) Close() error {
|
||||
if err := c.w.Close(); err != nil {
|
||||
return err
|
||||
}
|
||||
return c.r.Close()
|
||||
}
|
||||
|
||||
func (c *connHalf) SetDeadline(t time.Time) error {
|
||||
err1 := c.SetReadDeadline(t)
|
||||
err2 := c.SetWriteDeadline(t)
|
||||
if err1 != nil {
|
||||
return err1
|
||||
}
|
||||
return err2
|
||||
}
|
||||
func (c *connHalf) SetReadDeadline(t time.Time) error {
|
||||
return c.r.SetReadDeadline(t)
|
||||
}
|
||||
func (c *connHalf) SetWriteDeadline(t time.Time) error {
|
||||
return c.w.SetWriteDeadline(t)
|
||||
}
|
||||
|
||||
func (c *connHalf) SetReadBlock(b bool) error {
|
||||
if b {
|
||||
return c.r.Block()
|
||||
}
|
||||
return c.r.Unblock()
|
||||
}
|
||||
func (c *connHalf) SetWriteBlock(b bool) error {
|
||||
if b {
|
||||
return c.w.Block()
|
||||
}
|
||||
return c.w.Unblock()
|
||||
}
|
||||
@@ -0,0 +1,21 @@
|
||||
// Copyright (c) Tailscale Inc & AUTHORS
|
||||
// SPDX-License-Identifier: BSD-3-Clause
|
||||
|
||||
package memnet
|
||||
|
||||
import (
|
||||
"net"
|
||||
"testing"
|
||||
|
||||
"golang.org/x/net/nettest"
|
||||
)
|
||||
|
||||
func TestConn(t *testing.T) {
|
||||
nettest.TestConn(t, func() (c1 net.Conn, c2 net.Conn, stop func(), err error) {
|
||||
c1, c2 = NewConn("test", bufferSize)
|
||||
return c1, c2, func() {
|
||||
c1.Close()
|
||||
c2.Close()
|
||||
}, nil
|
||||
})
|
||||
}
|
||||
@@ -0,0 +1,89 @@
|
||||
// Copyright (c) Tailscale Inc & AUTHORS
|
||||
// SPDX-License-Identifier: BSD-3-Clause
|
||||
|
||||
package memnet
|
||||
|
||||
import (
|
||||
"context"
|
||||
"net"
|
||||
"strings"
|
||||
"sync"
|
||||
)
|
||||
|
||||
const (
|
||||
bufferSize = 256 * 1024
|
||||
)
|
||||
|
||||
// Listener is a net.Listener using NewConn to create pairs of network
|
||||
// connections connected in memory using a buffered pipe. It also provides a
|
||||
// Dial method to establish new connections.
|
||||
type Listener struct {
|
||||
addr connAddr
|
||||
ch chan Conn
|
||||
closeOnce sync.Once
|
||||
closed chan struct{}
|
||||
}
|
||||
|
||||
// Listen returns a new Listener for the provided address.
|
||||
func Listen(addr string) *Listener {
|
||||
return &Listener{
|
||||
addr: connAddr(addr),
|
||||
ch: make(chan Conn),
|
||||
closed: make(chan struct{}),
|
||||
}
|
||||
}
|
||||
|
||||
// Addr implements net.Listener.Addr.
|
||||
func (l *Listener) Addr() net.Addr {
|
||||
return l.addr
|
||||
}
|
||||
|
||||
// Close closes the pipe listener.
|
||||
func (l *Listener) Close() error {
|
||||
l.closeOnce.Do(func() {
|
||||
close(l.closed)
|
||||
})
|
||||
return nil
|
||||
}
|
||||
|
||||
// Accept blocks until a new connection is available or the listener is closed.
|
||||
func (l *Listener) Accept() (net.Conn, error) {
|
||||
select {
|
||||
case c := <-l.ch:
|
||||
return c, nil
|
||||
case <-l.closed:
|
||||
return nil, net.ErrClosed
|
||||
}
|
||||
}
|
||||
|
||||
// Dial connects to the listener using the provided context.
|
||||
// The provided Context must be non-nil. If the context expires before the
|
||||
// connection is complete, an error is returned. Once successfully connected
|
||||
// any expiration of the context will not affect the connection.
|
||||
func (l *Listener) Dial(ctx context.Context, network, addr string) (_ net.Conn, err error) {
|
||||
if !strings.HasSuffix(network, "tcp") {
|
||||
return nil, net.UnknownNetworkError(network)
|
||||
}
|
||||
if connAddr(addr) != l.addr {
|
||||
return nil, &net.AddrError{
|
||||
Err: "invalid address",
|
||||
Addr: addr,
|
||||
}
|
||||
}
|
||||
c, s := NewConn(addr, bufferSize)
|
||||
defer func() {
|
||||
if err != nil {
|
||||
c.Close()
|
||||
s.Close()
|
||||
}
|
||||
}()
|
||||
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
return nil, ctx.Err()
|
||||
case <-l.closed:
|
||||
return nil, net.ErrClosed
|
||||
case l.ch <- s:
|
||||
return c, nil
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,33 @@
|
||||
// Copyright (c) Tailscale Inc & AUTHORS
|
||||
// SPDX-License-Identifier: BSD-3-Clause
|
||||
|
||||
package memnet
|
||||
|
||||
import (
|
||||
"context"
|
||||
"testing"
|
||||
)
|
||||
|
||||
func TestListener(t *testing.T) {
|
||||
l := Listen("srv.local")
|
||||
defer l.Close()
|
||||
go func() {
|
||||
c, err := l.Accept()
|
||||
if err != nil {
|
||||
t.Error(err)
|
||||
return
|
||||
}
|
||||
defer c.Close()
|
||||
}()
|
||||
|
||||
if c, err := l.Dial(context.Background(), "tcp", "invalid"); err == nil {
|
||||
c.Close()
|
||||
t.Fatalf("dial to invalid address succeeded")
|
||||
}
|
||||
c, err := l.Dial(context.Background(), "tcp", "srv.local")
|
||||
if err != nil {
|
||||
t.Fatalf("dial failed: %v", err)
|
||||
return
|
||||
}
|
||||
c.Close()
|
||||
}
|
||||
@@ -0,0 +1,8 @@
|
||||
// Copyright (c) Tailscale Inc & AUTHORS
|
||||
// SPDX-License-Identifier: BSD-3-Clause
|
||||
|
||||
// Package memnet implements an in-memory network implementation.
|
||||
// It is useful for dialing and listening on in-memory addresses
|
||||
// in tests and other situations where you don't want to use the
|
||||
// network.
|
||||
package memnet
|
||||
@@ -0,0 +1,244 @@
|
||||
// Copyright (c) Tailscale Inc & AUTHORS
|
||||
// SPDX-License-Identifier: BSD-3-Clause
|
||||
|
||||
package memnet
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"fmt"
|
||||
"io"
|
||||
"log"
|
||||
"net"
|
||||
"os"
|
||||
"sync"
|
||||
"time"
|
||||
)
|
||||
|
||||
const debugPipe = false
|
||||
|
||||
// Pipe implements an in-memory FIFO with timeouts.
|
||||
type Pipe struct {
|
||||
name string
|
||||
maxBuf int
|
||||
mu sync.Mutex
|
||||
cnd *sync.Cond
|
||||
|
||||
blocked bool
|
||||
closed bool
|
||||
buf bytes.Buffer
|
||||
readTimeout time.Time
|
||||
writeTimeout time.Time
|
||||
cancelReadTimer func()
|
||||
cancelWriteTimer func()
|
||||
}
|
||||
|
||||
// NewPipe creates a Pipe with a buffer size fixed at maxBuf.
|
||||
func NewPipe(name string, maxBuf int) *Pipe {
|
||||
p := &Pipe{
|
||||
name: name,
|
||||
maxBuf: maxBuf,
|
||||
}
|
||||
p.cnd = sync.NewCond(&p.mu)
|
||||
return p
|
||||
}
|
||||
|
||||
// readOrBlock attempts to read from the buffer, if the buffer is empty and
|
||||
// the connection hasn't been closed it will block until there is a change.
|
||||
func (p *Pipe) readOrBlock(b []byte) (int, error) {
|
||||
p.mu.Lock()
|
||||
defer p.mu.Unlock()
|
||||
if !p.readTimeout.IsZero() && !time.Now().Before(p.readTimeout) {
|
||||
return 0, os.ErrDeadlineExceeded
|
||||
}
|
||||
if p.blocked {
|
||||
p.cnd.Wait()
|
||||
return 0, nil
|
||||
}
|
||||
|
||||
n, err := p.buf.Read(b)
|
||||
// err will either be nil or io.EOF.
|
||||
if err == io.EOF {
|
||||
if p.closed {
|
||||
return n, err
|
||||
}
|
||||
// Wait for something to change.
|
||||
p.cnd.Wait()
|
||||
}
|
||||
return n, nil
|
||||
}
|
||||
|
||||
// Read implements io.Reader.
|
||||
// Once the buffer is drained (i.e. after Close), subsequent calls will
|
||||
// return io.EOF.
|
||||
func (p *Pipe) Read(b []byte) (n int, err error) {
|
||||
if debugPipe {
|
||||
orig := b
|
||||
defer func() {
|
||||
log.Printf("Pipe(%q).Read(%q) n=%d, err=%v", p.name, string(orig[:n]), n, err)
|
||||
}()
|
||||
}
|
||||
for n == 0 {
|
||||
n2, err := p.readOrBlock(b)
|
||||
if err != nil {
|
||||
return n2, err
|
||||
}
|
||||
n += n2
|
||||
}
|
||||
p.cnd.Signal()
|
||||
return n, nil
|
||||
}
|
||||
|
||||
// writeOrBlock attempts to write to the buffer, if the buffer is full it will
|
||||
// block until there is a change.
|
||||
func (p *Pipe) writeOrBlock(b []byte) (int, error) {
|
||||
p.mu.Lock()
|
||||
defer p.mu.Unlock()
|
||||
if p.closed {
|
||||
return 0, net.ErrClosed
|
||||
}
|
||||
if !p.writeTimeout.IsZero() && !time.Now().Before(p.writeTimeout) {
|
||||
return 0, os.ErrDeadlineExceeded
|
||||
}
|
||||
if p.blocked {
|
||||
p.cnd.Wait()
|
||||
return 0, nil
|
||||
}
|
||||
|
||||
// Optimistically we want to write the entire slice.
|
||||
n := len(b)
|
||||
if limit := p.maxBuf - p.buf.Len(); limit < n {
|
||||
// However, we don't have enough capacity to write everything.
|
||||
n = limit
|
||||
}
|
||||
if n == 0 {
|
||||
// Wait for something to change.
|
||||
p.cnd.Wait()
|
||||
return 0, nil
|
||||
}
|
||||
|
||||
p.buf.Write(b[:n])
|
||||
p.cnd.Signal()
|
||||
return n, nil
|
||||
}
|
||||
|
||||
// Write implements io.Writer.
|
||||
func (p *Pipe) Write(b []byte) (n int, err error) {
|
||||
if debugPipe {
|
||||
orig := b
|
||||
defer func() {
|
||||
log.Printf("Pipe(%q).Write(%q) n=%d, err=%v", p.name, string(orig), n, err)
|
||||
}()
|
||||
}
|
||||
for len(b) > 0 {
|
||||
n2, err := p.writeOrBlock(b)
|
||||
if err != nil {
|
||||
return n + n2, err
|
||||
}
|
||||
n += n2
|
||||
b = b[n2:]
|
||||
}
|
||||
return n, nil
|
||||
}
|
||||
|
||||
// Close closes the pipe.
|
||||
func (p *Pipe) Close() error {
|
||||
p.mu.Lock()
|
||||
defer p.mu.Unlock()
|
||||
p.closed = true
|
||||
p.blocked = false
|
||||
if p.cancelWriteTimer != nil {
|
||||
p.cancelWriteTimer()
|
||||
p.cancelWriteTimer = nil
|
||||
}
|
||||
if p.cancelReadTimer != nil {
|
||||
p.cancelReadTimer()
|
||||
p.cancelReadTimer = nil
|
||||
}
|
||||
p.cnd.Broadcast()
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (p *Pipe) deadlineTimer(t time.Time) func() {
|
||||
if t.IsZero() {
|
||||
return nil
|
||||
}
|
||||
if t.Before(time.Now()) {
|
||||
p.cnd.Broadcast()
|
||||
return nil
|
||||
}
|
||||
ctx, cancel := context.WithDeadline(context.Background(), t)
|
||||
go func() {
|
||||
<-ctx.Done()
|
||||
if ctx.Err() == context.DeadlineExceeded {
|
||||
p.cnd.Broadcast()
|
||||
}
|
||||
}()
|
||||
return cancel
|
||||
}
|
||||
|
||||
// SetReadDeadline sets the deadline for future Read calls.
|
||||
func (p *Pipe) SetReadDeadline(t time.Time) error {
|
||||
p.mu.Lock()
|
||||
defer p.mu.Unlock()
|
||||
p.readTimeout = t
|
||||
// If we already have a deadline, cancel it and create a new one.
|
||||
if p.cancelReadTimer != nil {
|
||||
p.cancelReadTimer()
|
||||
p.cancelReadTimer = nil
|
||||
}
|
||||
p.cancelReadTimer = p.deadlineTimer(t)
|
||||
return nil
|
||||
}
|
||||
|
||||
// SetWriteDeadline sets the deadline for future Write calls.
|
||||
func (p *Pipe) SetWriteDeadline(t time.Time) error {
|
||||
p.mu.Lock()
|
||||
defer p.mu.Unlock()
|
||||
p.writeTimeout = t
|
||||
// If we already have a deadline, cancel it and create a new one.
|
||||
if p.cancelWriteTimer != nil {
|
||||
p.cancelWriteTimer()
|
||||
p.cancelWriteTimer = nil
|
||||
}
|
||||
p.cancelWriteTimer = p.deadlineTimer(t)
|
||||
return nil
|
||||
}
|
||||
|
||||
// Block will cause all calls to Read and Write to block until they either
|
||||
// timeout, are unblocked or the pipe is closed.
|
||||
func (p *Pipe) Block() error {
|
||||
p.mu.Lock()
|
||||
defer p.mu.Unlock()
|
||||
closed := p.closed
|
||||
blocked := p.blocked
|
||||
p.blocked = true
|
||||
|
||||
if closed {
|
||||
return fmt.Errorf("memnet.Pipe(%q).Block: closed", p.name)
|
||||
}
|
||||
if blocked {
|
||||
return fmt.Errorf("memnet.Pipe(%q).Block: already blocked", p.name)
|
||||
}
|
||||
p.cnd.Broadcast()
|
||||
return nil
|
||||
}
|
||||
|
||||
// Unblock will cause all blocked Read/Write calls to continue execution.
|
||||
func (p *Pipe) Unblock() error {
|
||||
p.mu.Lock()
|
||||
defer p.mu.Unlock()
|
||||
closed := p.closed
|
||||
blocked := p.blocked
|
||||
p.blocked = false
|
||||
|
||||
if closed {
|
||||
return fmt.Errorf("memnet.Pipe(%q).Block: closed", p.name)
|
||||
}
|
||||
if !blocked {
|
||||
return fmt.Errorf("memnet.Pipe(%q).Block: already unblocked", p.name)
|
||||
}
|
||||
p.cnd.Broadcast()
|
||||
return nil
|
||||
}
|
||||
@@ -0,0 +1,117 @@
|
||||
// Copyright (c) Tailscale Inc & AUTHORS
|
||||
// SPDX-License-Identifier: BSD-3-Clause
|
||||
|
||||
package memnet
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"fmt"
|
||||
"os"
|
||||
"testing"
|
||||
"time"
|
||||
)
|
||||
|
||||
func TestPipeHello(t *testing.T) {
|
||||
p := NewPipe("p1", 1<<16)
|
||||
msg := "Hello, World!"
|
||||
if n, err := p.Write([]byte(msg)); err != nil {
|
||||
t.Fatal(err)
|
||||
} else if n != len(msg) {
|
||||
t.Errorf("p.Write(%q) n=%d, want %d", msg, n, len(msg))
|
||||
}
|
||||
b := make([]byte, len(msg))
|
||||
if n, err := p.Read(b); err != nil {
|
||||
t.Fatal(err)
|
||||
} else if n != len(b) {
|
||||
t.Errorf("p.Read(%q) n=%d, want %d", string(b[:n]), n, len(b))
|
||||
}
|
||||
if got := string(b); got != msg {
|
||||
t.Errorf("p.Read: %q, want %q", got, msg)
|
||||
}
|
||||
}
|
||||
|
||||
func TestPipeTimeout(t *testing.T) {
|
||||
t.Run("write", func(t *testing.T) {
|
||||
p := NewPipe("p1", 1<<16)
|
||||
p.SetWriteDeadline(time.Now().Add(-1 * time.Second))
|
||||
n, err := p.Write([]byte{'h'})
|
||||
if !errors.Is(err, os.ErrDeadlineExceeded) {
|
||||
t.Errorf("missing write timeout got err: %v", err)
|
||||
}
|
||||
if n != 0 {
|
||||
t.Errorf("n=%d on timeout", n)
|
||||
}
|
||||
})
|
||||
t.Run("read", func(t *testing.T) {
|
||||
p := NewPipe("p1", 1<<16)
|
||||
p.Write([]byte{'h'})
|
||||
|
||||
p.SetReadDeadline(time.Now().Add(-1 * time.Second))
|
||||
b := make([]byte, 1)
|
||||
n, err := p.Read(b)
|
||||
if !errors.Is(err, os.ErrDeadlineExceeded) {
|
||||
t.Errorf("missing read timeout got err: %v", err)
|
||||
}
|
||||
if n != 0 {
|
||||
t.Errorf("n=%d on timeout", n)
|
||||
}
|
||||
})
|
||||
t.Run("block-write", func(t *testing.T) {
|
||||
p := NewPipe("p1", 1<<16)
|
||||
p.SetWriteDeadline(time.Now().Add(10 * time.Millisecond))
|
||||
if err := p.Block(); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if _, err := p.Write([]byte{'h'}); !errors.Is(err, os.ErrDeadlineExceeded) {
|
||||
t.Fatalf("want write timeout got: %v", err)
|
||||
}
|
||||
})
|
||||
t.Run("block-read", func(t *testing.T) {
|
||||
p := NewPipe("p1", 1<<16)
|
||||
p.Write([]byte{'h', 'i'})
|
||||
p.SetReadDeadline(time.Now().Add(10 * time.Millisecond))
|
||||
b := make([]byte, 1)
|
||||
if err := p.Block(); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if _, err := p.Read(b); !errors.Is(err, os.ErrDeadlineExceeded) {
|
||||
t.Fatalf("want read timeout got: %v", err)
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
func TestLimit(t *testing.T) {
|
||||
p := NewPipe("p1", 1)
|
||||
errCh := make(chan error)
|
||||
go func() {
|
||||
n, err := p.Write([]byte{'a', 'b', 'c'})
|
||||
if err != nil {
|
||||
errCh <- err
|
||||
} else if n != 3 {
|
||||
errCh <- fmt.Errorf("p.Write n=%d, want 3", n)
|
||||
} else {
|
||||
errCh <- nil
|
||||
}
|
||||
}()
|
||||
b := make([]byte, 3)
|
||||
|
||||
if n, err := p.Read(b); err != nil {
|
||||
t.Fatal(err)
|
||||
} else if n != 1 {
|
||||
t.Errorf("Read(%q): n=%d want 1", string(b), n)
|
||||
}
|
||||
if n, err := p.Read(b); err != nil {
|
||||
t.Fatal(err)
|
||||
} else if n != 1 {
|
||||
t.Errorf("Read(%q): n=%d want 1", string(b), n)
|
||||
}
|
||||
if n, err := p.Read(b); err != nil {
|
||||
t.Fatal(err)
|
||||
} else if n != 1 {
|
||||
t.Errorf("Read(%q): n=%d want 1", string(b), n)
|
||||
}
|
||||
|
||||
if err := <-errCh; err != nil {
|
||||
t.Error(err)
|
||||
}
|
||||
}
|
||||
Reference in New Issue
Block a user