net/nettest: make nettest.NewConn pass x/net/nettest.TestConn.

Signed-off-by: Maisem Ali <maisem@tailscale.com>
This commit is contained in:
Maisem Ali
2021-04-02 20:04:52 -07:00
committed by Maisem Ali
parent e0e677a8f6
commit 57756ef673
6 changed files with 284 additions and 154 deletions
+118 -134
View File
@@ -5,11 +5,13 @@
package nettest
import (
"bytes"
"context"
"errors"
"fmt"
"io"
"log"
"net"
"os"
"sync"
"time"
)
@@ -20,13 +22,12 @@ const debugPipe = false
type Pipe struct {
name string
maxBuf int
rCh chan struct{}
wCh chan struct{}
mu sync.Mutex
cnd *sync.Cond
mu sync.Mutex
closed bool
blocked bool
buf []byte
closed bool
buf bytes.Buffer
readTimeout time.Time
writeTimeout time.Time
cancelReadTimer func()
@@ -35,21 +36,42 @@ type Pipe struct {
// NewPipe creates a Pipe with a buffer size fixed at maxBuf.
func NewPipe(name string, maxBuf int) *Pipe {
return &Pipe{
p := &Pipe{
name: name,
maxBuf: maxBuf,
rCh: make(chan struct{}, 1),
wCh: make(chan struct{}, 1),
}
p.cnd = sync.NewCond(&p.mu)
return p
}
var (
ErrTimeout = errors.New("timeout")
ErrReadTimeout = fmt.Errorf("read %w", ErrTimeout)
ErrWriteTimeout = fmt.Errorf("write %w", ErrTimeout)
)
// readOrBlock attempts to read from the buffer, if the buffer is empty and
// the connection hasn't been closed it will block until there is a change.
func (p *Pipe) readOrBlock(b []byte) (int, error) {
p.mu.Lock()
defer p.mu.Unlock()
if !p.readTimeout.IsZero() && !time.Now().Before(p.readTimeout) {
return 0, os.ErrDeadlineExceeded
}
if p.blocked {
p.cnd.Wait()
return 0, nil
}
n, err := p.buf.Read(b)
// err will either be nil or io.EOF.
if err == io.EOF {
if p.closed {
return n, err
}
// Wait for something to change.
p.cnd.Wait()
}
return n, nil
}
// Read implements io.Reader.
// Once the buffer is drained (i.e. after Close), subsequent calls will
// return io.EOF.
func (p *Pipe) Read(b []byte) (n int, err error) {
if debugPipe {
orig := b
@@ -57,35 +79,48 @@ func (p *Pipe) Read(b []byte) (n int, err error) {
log.Printf("Pipe(%q).Read( %q) n=%d, err=%v", p.name, string(orig[:n]), n, err)
}()
}
for {
p.mu.Lock()
closed := p.closed
timedout := !p.readTimeout.IsZero() && !time.Now().Before(p.readTimeout)
blocked := p.blocked
if !closed && !timedout && len(p.buf) > 0 {
n2 := copy(b, p.buf)
p.buf = p.buf[n2:]
b = b[n2:]
n += n2
for n == 0 {
n2, err := p.readOrBlock(b)
if err != nil {
return n2, err
}
p.mu.Unlock()
if closed {
return 0, fmt.Errorf("nettest.Pipe(%q): closed: %w", p.name, io.EOF)
}
if timedout {
return 0, fmt.Errorf("nettest.Pipe(%q): %w", p.name, ErrReadTimeout)
}
if blocked {
<-p.rCh
continue
}
if n > 0 {
p.signalWrite()
return n, nil
}
<-p.rCh
n += n2
}
p.cnd.Signal()
return n, nil
}
// writeOrBlock attempts to write to the buffer, if the buffer is full it will
// block until there is a change.
func (p *Pipe) writeOrBlock(b []byte) (int, error) {
p.mu.Lock()
defer p.mu.Unlock()
if p.closed {
return 0, net.ErrClosed
}
if !p.writeTimeout.IsZero() && !time.Now().Before(p.writeTimeout) {
return 0, os.ErrDeadlineExceeded
}
if p.blocked {
p.cnd.Wait()
return 0, nil
}
// Optimistically we want to write the entire slice.
n := len(b)
if limit := p.maxBuf - p.buf.Len(); limit < n {
// However, we don't have enough capacity to write everything.
n = limit
}
if n == 0 {
// Wait for something to change.
p.cnd.Wait()
return 0, nil
}
p.buf.Write(b[:n])
p.cnd.Signal()
return n, nil
}
// Write implements io.Writer.
@@ -96,47 +131,23 @@ func (p *Pipe) Write(b []byte) (n int, err error) {
log.Printf("Pipe(%q).Write(%q) n=%d, err=%v", p.name, string(orig), n, err)
}()
}
for {
p.mu.Lock()
closed := p.closed
timedout := !p.writeTimeout.IsZero() && !time.Now().Before(p.writeTimeout)
blocked := p.blocked
if !closed && !timedout {
n2 := len(b)
if limit := p.maxBuf - len(p.buf); limit < n2 {
n2 = limit
}
p.buf = append(p.buf, b[:n2]...)
b = b[n2:]
n += n2
for len(b) > 0 {
n2, err := p.writeOrBlock(b)
if err != nil {
return n + n2, err
}
p.mu.Unlock()
if closed {
return n, fmt.Errorf("nettest.Pipe(%q): closed: %w", p.name, io.EOF)
}
if timedout {
return n, fmt.Errorf("nettest.Pipe(%q): %w", p.name, ErrWriteTimeout)
}
if blocked {
<-p.wCh
continue
}
if n > 0 {
p.signalRead()
}
if len(b) == 0 {
return n, nil
}
<-p.wCh
n += n2
b = b[n2:]
}
return n, nil
}
// Close implements io.Closer.
// Close closes the pipe.
func (p *Pipe) Close() error {
p.mu.Lock()
closed := p.closed
defer p.mu.Unlock()
p.closed = true
p.blocked = false
if p.cancelWriteTimer != nil {
p.cancelWriteTimer()
p.cancelWriteTimer = nil
@@ -145,77 +156,65 @@ func (p *Pipe) Close() error {
p.cancelReadTimer()
p.cancelReadTimer = nil
}
p.mu.Unlock()
p.cnd.Broadcast()
if closed {
return fmt.Errorf("nettest.Pipe(%q).Close: already closed", p.name)
}
p.signalRead()
p.signalWrite()
return nil
}
func (p *Pipe) deadlineTimer(t time.Time) func() {
if t.IsZero() {
return nil
}
if t.Before(time.Now()) {
p.cnd.Broadcast()
return nil
}
ctx, cancel := context.WithDeadline(context.Background(), t)
go func() {
<-ctx.Done()
if ctx.Err() == context.DeadlineExceeded {
p.cnd.Broadcast()
}
}()
return cancel
}
// SetReadDeadline sets the deadline for future Read calls.
func (p *Pipe) SetReadDeadline(t time.Time) error {
p.mu.Lock()
defer p.mu.Unlock()
p.readTimeout = t
// If we already have a deadline, cancel it and create a new one.
if p.cancelReadTimer != nil {
p.cancelReadTimer()
p.cancelReadTimer = nil
}
if d := time.Until(t); !t.IsZero() && d > 0 {
ctx, cancel := context.WithCancel(context.Background())
p.cancelReadTimer = cancel
go func() {
t := time.NewTimer(d)
defer t.Stop()
select {
case <-t.C:
p.signalRead()
case <-ctx.Done():
}
}()
}
p.mu.Unlock()
p.signalRead()
p.cancelReadTimer = p.deadlineTimer(t)
return nil
}
// SetWriteDeadline sets the deadline for future Write calls.
func (p *Pipe) SetWriteDeadline(t time.Time) error {
p.mu.Lock()
defer p.mu.Unlock()
p.writeTimeout = t
// If we already have a deadline, cancel it and create a new one.
if p.cancelWriteTimer != nil {
p.cancelWriteTimer()
p.cancelWriteTimer = nil
}
if d := time.Until(t); !t.IsZero() && d > 0 {
ctx, cancel := context.WithCancel(context.Background())
p.cancelWriteTimer = cancel
go func() {
t := time.NewTimer(d)
defer t.Stop()
select {
case <-t.C:
p.signalWrite()
case <-ctx.Done():
}
}()
}
p.mu.Unlock()
p.signalWrite()
p.cancelWriteTimer = p.deadlineTimer(t)
return nil
}
// Block will cause all calls to Read and Write to block until they either
// timeout, are unblocked or the pipe is closed.
func (p *Pipe) Block() error {
p.mu.Lock()
defer p.mu.Unlock()
closed := p.closed
blocked := p.blocked
p.blocked = true
p.mu.Unlock()
if closed {
return fmt.Errorf("nettest.Pipe(%q).Block: closed", p.name)
@@ -223,17 +222,17 @@ func (p *Pipe) Block() error {
if blocked {
return fmt.Errorf("nettest.Pipe(%q).Block: already blocked", p.name)
}
p.signalRead()
p.signalWrite()
p.cnd.Broadcast()
return nil
}
// Unblock will cause all blocked Read/Write calls to continue execution.
func (p *Pipe) Unblock() error {
p.mu.Lock()
defer p.mu.Unlock()
closed := p.closed
blocked := p.blocked
p.blocked = false
p.mu.Unlock()
if closed {
return fmt.Errorf("nettest.Pipe(%q).Block: closed", p.name)
@@ -241,21 +240,6 @@ func (p *Pipe) Unblock() error {
if !blocked {
return fmt.Errorf("nettest.Pipe(%q).Block: already unblocked", p.name)
}
p.signalRead()
p.signalWrite()
p.cnd.Broadcast()
return nil
}
func (p *Pipe) signalRead() {
select {
case p.rCh <- struct{}{}:
default:
}
}
func (p *Pipe) signalWrite() {
select {
case p.wCh <- struct{}{}:
default:
}
}