cmd/k8s-operator,k8s-operator/sessionrecording: support recording kubectl exec sessions over WebSockets (#12947)
cmd/k8s-operator,k8s-operator/sessionrecording: support recording WebSocket sessions Kubernetes currently supports two streaming protocols, SPDY and WebSockets. WebSockets are replacing SPDY, see https://github.com/kubernetes/enhancements/issues/4006. We were currently only supporting SPDY, erroring out if session was not SPDY and relying on the kube's built-in SPDY fallback. This PR: - adds support for parsing contents of 'kubectl exec' sessions streamed over WebSockets - adds logic to distinguish 'kubectl exec' requests for a SPDY/WebSockets sessions and call the relevant handler Updates tailscale/corp#19821 Signed-off-by: Irbe Krumina <irbe@tailscale.com> Co-authored-by: Tom Proctor <tomhjp@users.noreply.github.com>main
parent
4c2e978f1e
commit
a15ff1bade
@ -1,20 +0,0 @@ |
||||
// Copyright (c) Tailscale Inc & AUTHORS
|
||||
// SPDX-License-Identifier: BSD-3-Clause
|
||||
|
||||
//go:build !plan9
|
||||
|
||||
// Package conn contains shared interface for the hijacked
|
||||
// connection of a 'kubectl exec' session that is being recorded.
|
||||
package conn |
||||
|
||||
import "net" |
||||
|
||||
type Conn interface { |
||||
net.Conn |
||||
// Fail can be called to set connection state to failed. By default any
|
||||
// bytes left over in write buffer are forwarded to the intended
|
||||
// destination when the connection is being closed except for when the
|
||||
// connection state is failed- so set the state to failed when erroring
|
||||
// out and failure policy is to fail closed.
|
||||
Fail() |
||||
} |
||||
@ -0,0 +1,301 @@ |
||||
// Copyright (c) Tailscale Inc & AUTHORS
|
||||
// SPDX-License-Identifier: BSD-3-Clause
|
||||
|
||||
//go:build !plan9
|
||||
|
||||
// package ws has functionality to parse 'kubectl exec' sessions streamed using
|
||||
// WebSocket protocol.
|
||||
package ws |
||||
|
||||
import ( |
||||
"bytes" |
||||
"encoding/json" |
||||
"errors" |
||||
"fmt" |
||||
"io" |
||||
"net" |
||||
"sync" |
||||
|
||||
"go.uber.org/zap" |
||||
"k8s.io/apimachinery/pkg/util/remotecommand" |
||||
"tailscale.com/k8s-operator/sessionrecording/tsrecorder" |
||||
"tailscale.com/sessionrecording" |
||||
"tailscale.com/util/multierr" |
||||
) |
||||
|
||||
// New wraps the provided network connection and returns a connection whose reads and writes will get triggered as data is received on the hijacked connection.
|
||||
// The connection must be a hijacked connection for a 'kubectl exec' session using WebSocket protocol and a *.channel.k8s.io subprotocol.
|
||||
// The hijacked connection is used to transmit *.channel.k8s.io streams between Kubernetes client ('kubectl') and the destination proxy controlled by Kubernetes.
|
||||
// Data read from the underlying network connection is data sent via one of the streams from the client to the container.
|
||||
// Data written to the underlying connection is data sent from the container to the client.
|
||||
// We parse the data and send everything for the STDOUT/STDERR streams to the configured tsrecorder as an asciinema recording with the provided header.
|
||||
// https://github.com/kubernetes/enhancements/tree/master/keps/sig-api-machinery/4006-transition-spdy-to-websockets#proposal-new-remotecommand-sub-protocol-version---v5channelk8sio
|
||||
func New(c net.Conn, rec *tsrecorder.Client, ch sessionrecording.CastHeader, log *zap.SugaredLogger) net.Conn { |
||||
return &conn{ |
||||
Conn: c, |
||||
rec: rec, |
||||
ch: ch, |
||||
log: log, |
||||
} |
||||
} |
||||
|
||||
// conn is a wrapper around net.Conn. It reads the bytestream
|
||||
// for a 'kubectl exec' session, sends session recording data to the configured
|
||||
// recorder and forwards the raw bytes to the original destination.
|
||||
// A new conn is created per session.
|
||||
// conn only knows to how to read a 'kubectl exec' session that is streamed using WebSocket protocol.
|
||||
// https://www.rfc-editor.org/rfc/rfc6455
|
||||
type conn struct { |
||||
net.Conn |
||||
// rec knows how to send data to a tsrecorder instance.
|
||||
rec *tsrecorder.Client |
||||
// ch is the asiinema CastHeader for a session.
|
||||
ch sessionrecording.CastHeader |
||||
log *zap.SugaredLogger |
||||
|
||||
rmu sync.Mutex // sequences reads
|
||||
// currentReadMsg contains parsed contents of a websocket binary data message that
|
||||
// is currently being read from the underlying net.Conn.
|
||||
currentReadMsg *message |
||||
// readBuf contains bytes for a currently parsed binary data message
|
||||
// read from the underlying conn. If the message is masked, it is
|
||||
// unmasked in place, so having this buffer allows us to avoid modifying
|
||||
// the original byte array.
|
||||
readBuf bytes.Buffer |
||||
|
||||
wmu sync.Mutex // sequences writes
|
||||
writeCastHeaderOnce sync.Once |
||||
closed bool // connection is closed
|
||||
// writeBuf contains bytes for a currently parsed binary data message
|
||||
// being written to the underlying conn. If the message is masked, it is
|
||||
// unmasked in place, so having this buffer allows us to avoid modifying
|
||||
// the original byte array.
|
||||
writeBuf bytes.Buffer |
||||
// currentWriteMsg contains parsed contents of a websocket binary data message that
|
||||
// is currently being written to the underlying net.Conn.
|
||||
currentWriteMsg *message |
||||
} |
||||
|
||||
// Read reads bytes from the original connection and parses them as websocket
|
||||
// message fragments.
|
||||
// Bytes read from the original connection are the bytes sent from the Kubernetes client (kubectl) to the destination container via kubelet.
|
||||
|
||||
// If the message is for the resize stream, sets the width
|
||||
// and height of the CastHeader for this connection.
|
||||
// The fragment can be incomplete.
|
||||
func (c *conn) Read(b []byte) (int, error) { |
||||
c.rmu.Lock() |
||||
defer c.rmu.Unlock() |
||||
n, err := c.Conn.Read(b) |
||||
if err != nil { |
||||
// It seems that we sometimes get a wrapped io.EOF, but the
|
||||
// caller checks for io.EOF with ==.
|
||||
if errors.Is(err, io.EOF) { |
||||
err = io.EOF |
||||
} |
||||
return 0, err |
||||
} |
||||
if n == 0 { |
||||
c.log.Debug("[unexpected] Read called for 0 length bytes") |
||||
return 0, nil |
||||
} |
||||
|
||||
typ := messageType(opcode(b)) |
||||
if (typ == noOpcode && c.readMsgIsIncomplete()) || c.readBufHasIncompleteFragment() { // subsequent fragment
|
||||
if typ, err = c.curReadMsgType(); err != nil { |
||||
return 0, err |
||||
} |
||||
} |
||||
|
||||
// A control message can not be fragmented and we are not interested in
|
||||
// these messages. Just return.
|
||||
if isControlMessage(typ) { |
||||
return n, nil |
||||
} |
||||
|
||||
// The only data message type that Kubernetes supports is binary message.
|
||||
// If we received another message type, return and let the API server close the connection.
|
||||
// https://github.com/kubernetes/client-go/blob/release-1.30/tools/remotecommand/websocket.go#L281
|
||||
if typ != binaryMessage { |
||||
c.log.Infof("[unexpected] received a data message with a type that is not binary message type %v", typ) |
||||
return n, nil |
||||
} |
||||
|
||||
readMsg := &message{typ: typ} // start a new message...
|
||||
// ... or pick up an already started one if the previous fragment was not final.
|
||||
if c.readMsgIsIncomplete() || c.readBufHasIncompleteFragment() { |
||||
readMsg = c.currentReadMsg |
||||
} |
||||
|
||||
if _, err := c.readBuf.Write(b[:n]); err != nil { |
||||
return 0, fmt.Errorf("[unexpected] error writing message contents to read buffer: %w", err) |
||||
} |
||||
|
||||
ok, err := readMsg.Parse(c.readBuf.Bytes(), c.log) |
||||
if err != nil { |
||||
return 0, fmt.Errorf("error parsing message: %v", err) |
||||
} |
||||
if !ok { // incomplete fragment
|
||||
return n, nil |
||||
} |
||||
c.readBuf.Next(len(readMsg.raw)) |
||||
|
||||
if readMsg.isFinalized { |
||||
// Stream IDs for websocket streams are static.
|
||||
// https://github.com/kubernetes/client-go/blob/v0.30.0-rc.1/tools/remotecommand/websocket.go#L218
|
||||
if readMsg.streamID.Load() == remotecommand.StreamResize { |
||||
var err error |
||||
var msg tsrecorder.ResizeMsg |
||||
if err = json.Unmarshal(readMsg.payload, &msg); err != nil { |
||||
return 0, fmt.Errorf("error umarshalling resize message: %w", err) |
||||
} |
||||
c.ch.Width = msg.Width |
||||
c.ch.Height = msg.Height |
||||
} |
||||
} |
||||
c.currentReadMsg = readMsg |
||||
return n, nil |
||||
} |
||||
|
||||
// Write parses the written bytes as WebSocket message fragment. If the message
|
||||
// is for stdout or stderr streams, it is written to the configured tsrecorder.
|
||||
// A message fragment can be incomplete.
|
||||
func (c *conn) Write(b []byte) (int, error) { |
||||
c.wmu.Lock() |
||||
defer c.wmu.Unlock() |
||||
if len(b) == 0 { |
||||
c.log.Debug("[unexpected] Write called with 0 bytes") |
||||
return 0, nil |
||||
} |
||||
|
||||
typ := messageType(opcode(b)) |
||||
// If we are in process of parsing a message fragment, the received
|
||||
// bytes are not structured as a message fragment and can not be used to
|
||||
// determine a message fragment.
|
||||
if c.writeBufHasIncompleteFragment() { // buffer contains previous incomplete fragment
|
||||
var err error |
||||
if typ, err = c.curWriteMsgType(); err != nil { |
||||
return 0, err |
||||
} |
||||
} |
||||
|
||||
if isControlMessage(typ) { |
||||
return c.Conn.Write(b) |
||||
} |
||||
|
||||
writeMsg := &message{typ: typ} // start a new message...
|
||||
// ... or continue the existing one if it has not been finalized.
|
||||
if c.writeMsgIsIncomplete() || c.writeBufHasIncompleteFragment() { |
||||
writeMsg = c.currentWriteMsg |
||||
} |
||||
|
||||
if _, err := c.writeBuf.Write(b); err != nil { |
||||
c.log.Errorf("write: error writing to write buf: %v", err) |
||||
return 0, fmt.Errorf("[unexpected] error writing to internal write buffer: %w", err) |
||||
} |
||||
|
||||
ok, err := writeMsg.Parse(c.writeBuf.Bytes(), c.log) |
||||
if err != nil { |
||||
c.log.Errorf("write: parsing a message errored: %v", err) |
||||
return 0, fmt.Errorf("write: error parsing message: %v", err) |
||||
} |
||||
c.currentWriteMsg = writeMsg |
||||
if !ok { // incomplete fragment
|
||||
return len(b), nil |
||||
} |
||||
c.writeBuf.Next(len(writeMsg.raw)) // advance frame
|
||||
|
||||
if len(writeMsg.payload) != 0 && writeMsg.isFinalized { |
||||
if writeMsg.streamID.Load() == remotecommand.StreamStdOut || writeMsg.streamID.Load() == remotecommand.StreamStdErr { |
||||
var err error |
||||
c.writeCastHeaderOnce.Do(func() { |
||||
var j []byte |
||||
j, err = json.Marshal(c.ch) |
||||
if err != nil { |
||||
c.log.Errorf("error marhsalling conn: %v", err) |
||||
return |
||||
} |
||||
j = append(j, '\n') |
||||
err = c.rec.WriteCastLine(j) |
||||
if err != nil { |
||||
c.log.Errorf("received error from recorder: %v", err) |
||||
} |
||||
}) |
||||
if err != nil { |
||||
return 0, fmt.Errorf("error writing CastHeader: %w", err) |
||||
} |
||||
if err := c.rec.Write(writeMsg.payload); err != nil { |
||||
return 0, fmt.Errorf("error writing message to recorder: %v", err) |
||||
} |
||||
} |
||||
} |
||||
_, err = c.Conn.Write(c.currentWriteMsg.raw) |
||||
if err != nil { |
||||
c.log.Errorf("write: error writing to conn: %v", err) |
||||
} |
||||
return len(b), nil |
||||
} |
||||
|
||||
func (c *conn) Close() error { |
||||
c.wmu.Lock() |
||||
defer c.wmu.Unlock() |
||||
if c.closed { |
||||
return nil |
||||
} |
||||
c.closed = true |
||||
connCloseErr := c.Conn.Close() |
||||
recCloseErr := c.rec.Close() |
||||
return multierr.New(connCloseErr, recCloseErr) |
||||
} |
||||
|
||||
// writeBufHasIncompleteFragment returns true if the latest data message
|
||||
// fragment written to the connection was incomplete and the following write
|
||||
// must be the remaining payload bytes of that fragment.
|
||||
func (c *conn) writeBufHasIncompleteFragment() bool { |
||||
return c.writeBuf.Len() != 0 |
||||
} |
||||
|
||||
// readBufHasIncompleteFragment returns true if the latest data message
|
||||
// fragment read from the connection was incomplete and the following read
|
||||
// must be the remaining payload bytes of that fragment.
|
||||
func (c *conn) readBufHasIncompleteFragment() bool { |
||||
return c.readBuf.Len() != 0 |
||||
} |
||||
|
||||
// writeMsgIsIncomplete returns true if the latest WebSocket message written to
|
||||
// the connection was fragmented and the next data message fragment written to
|
||||
// the connection must be a fragment of that message.
|
||||
// https://www.rfc-editor.org/rfc/rfc6455#section-5.4
|
||||
func (c *conn) writeMsgIsIncomplete() bool { |
||||
return c.currentWriteMsg != nil && !c.currentWriteMsg.isFinalized |
||||
} |
||||
|
||||
// readMsgIsIncomplete returns true if the latest WebSocket message written to
|
||||
// the connection was fragmented and the next data message fragment written to
|
||||
// the connection must be a fragment of that message.
|
||||
// https://www.rfc-editor.org/rfc/rfc6455#section-5.4
|
||||
func (c *conn) readMsgIsIncomplete() bool { |
||||
return c.currentReadMsg != nil && !c.currentReadMsg.isFinalized |
||||
} |
||||
func (c *conn) curReadMsgType() (messageType, error) { |
||||
if c.currentReadMsg != nil { |
||||
return c.currentReadMsg.typ, nil |
||||
} |
||||
return 0, errors.New("[unexpected] attempted to determine type for nil message") |
||||
} |
||||
|
||||
func (c *conn) curWriteMsgType() (messageType, error) { |
||||
if c.currentWriteMsg != nil { |
||||
return c.currentWriteMsg.typ, nil |
||||
} |
||||
return 0, errors.New("[unexpected] attempted to determine type for nil message") |
||||
} |
||||
|
||||
// opcode reads the websocket message opcode that denotes the message type.
|
||||
// opcode is contained in bits [4-8] of the message.
|
||||
// https://www.rfc-editor.org/rfc/rfc6455#section-5.2
|
||||
func opcode(b []byte) int { |
||||
// 0xf = 00001111; b & 00001111 zeroes out bits [0 - 3] of b
|
||||
var mask byte = 0xf |
||||
return int(b[0] & mask) |
||||
} |
||||
@ -0,0 +1,257 @@ |
||||
// Copyright (c) Tailscale Inc & AUTHORS
|
||||
// SPDX-License-Identifier: BSD-3-Clause
|
||||
|
||||
//go:build !plan9
|
||||
|
||||
package ws |
||||
|
||||
import ( |
||||
"fmt" |
||||
"reflect" |
||||
"testing" |
||||
|
||||
"go.uber.org/zap" |
||||
"k8s.io/apimachinery/pkg/util/remotecommand" |
||||
"tailscale.com/k8s-operator/sessionrecording/fakes" |
||||
"tailscale.com/k8s-operator/sessionrecording/tsrecorder" |
||||
"tailscale.com/sessionrecording" |
||||
"tailscale.com/tstest" |
||||
) |
||||
|
||||
func Test_conn_Read(t *testing.T) { |
||||
zl, err := zap.NewDevelopment() |
||||
if err != nil { |
||||
t.Fatal(err) |
||||
} |
||||
// Resize stream ID + {"width": 10, "height": 20}
|
||||
testResizeMsg := []byte{byte(remotecommand.StreamResize), 0x7b, 0x22, 0x77, 0x69, 0x64, 0x74, 0x68, 0x22, 0x3a, 0x31, 0x30, 0x2c, 0x22, 0x68, 0x65, 0x69, 0x67, 0x68, 0x74, 0x22, 0x3a, 0x32, 0x30, 0x7d} |
||||
lenResizeMsgPayload := byte(len(testResizeMsg)) |
||||
|
||||
tests := []struct { |
||||
name string |
||||
inputs [][]byte |
||||
wantWidth int |
||||
wantHeight int |
||||
}{ |
||||
{ |
||||
name: "single_read_control_message", |
||||
inputs: [][]byte{{0x88, 0x0}}, |
||||
}, |
||||
{ |
||||
name: "single_read_resize_message", |
||||
inputs: [][]byte{append([]byte{0x82, lenResizeMsgPayload}, testResizeMsg...)}, |
||||
wantWidth: 10, |
||||
wantHeight: 20, |
||||
}, |
||||
{ |
||||
name: "two_reads_resize_message", |
||||
inputs: [][]byte{{0x2, 0x9, 0x4, 0x7b, 0x22, 0x77, 0x69, 0x64, 0x74, 0x68, 0x22}, {0x80, 0x11, 0x4, 0x3a, 0x31, 0x30, 0x2c, 0x22, 0x68, 0x65, 0x69, 0x67, 0x68, 0x74, 0x22, 0x3a, 0x32, 0x30, 0x7d}}, |
||||
wantWidth: 10, |
||||
wantHeight: 20, |
||||
}, |
||||
{ |
||||
name: "three_reads_resize_message_with_split_fragment", |
||||
inputs: [][]byte{{0x2, 0x9, 0x4, 0x7b, 0x22, 0x77, 0x69, 0x64, 0x74, 0x68, 0x22}, {0x80, 0x11, 0x4, 0x3a, 0x31, 0x30, 0x2c, 0x22, 0x68, 0x65, 0x69, 0x67, 0x68, 0x74}, {0x22, 0x3a, 0x32, 0x30, 0x7d}}, |
||||
wantWidth: 10, |
||||
wantHeight: 20, |
||||
}, |
||||
} |
||||
for _, tt := range tests { |
||||
t.Run(tt.name, func(t *testing.T) { |
||||
tc := &fakes.TestConn{} |
||||
tc.ResetReadBuf() |
||||
c := &conn{ |
||||
Conn: tc, |
||||
log: zl.Sugar(), |
||||
} |
||||
for i, input := range tt.inputs { |
||||
if err := tc.WriteReadBufBytes(input); err != nil { |
||||
t.Fatalf("writing bytes to test conn: %v", err) |
||||
} |
||||
_, err := c.Read(make([]byte, len(input))) |
||||
if err != nil { |
||||
t.Errorf("[%d] conn.Read() errored %v", i, err) |
||||
return |
||||
} |
||||
} |
||||
if tt.wantHeight != 0 || tt.wantWidth != 0 { |
||||
if tt.wantWidth != c.ch.Width { |
||||
t.Errorf("wants width: %v, got %v", tt.wantWidth, c.ch.Width) |
||||
} |
||||
if tt.wantHeight != c.ch.Height { |
||||
t.Errorf("want height: %v, got %v", tt.wantHeight, c.ch.Height) |
||||
} |
||||
} |
||||
}) |
||||
} |
||||
} |
||||
|
||||
func Test_conn_Write(t *testing.T) { |
||||
zl, err := zap.NewDevelopment() |
||||
if err != nil { |
||||
t.Fatal(err) |
||||
} |
||||
cl := tstest.NewClock(tstest.ClockOpts{}) |
||||
tests := []struct { |
||||
name string |
||||
inputs [][]byte |
||||
wantForwarded []byte |
||||
wantRecorded []byte |
||||
firstWrite bool |
||||
width int |
||||
height int |
||||
}{ |
||||
{ |
||||
name: "single_write_control_frame", |
||||
inputs: [][]byte{{0x88, 0x0}}, |
||||
wantForwarded: []byte{0x88, 0x0}, |
||||
}, |
||||
{ |
||||
name: "single_write_stdout_data_message", |
||||
inputs: [][]byte{{0x82, 0x3, 0x1, 0x7, 0x8}}, |
||||
wantForwarded: []byte{0x82, 0x3, 0x1, 0x7, 0x8}, |
||||
wantRecorded: fakes.CastLine(t, []byte{0x7, 0x8}, cl), |
||||
}, |
||||
{ |
||||
name: "single_write_stderr_data_message", |
||||
inputs: [][]byte{{0x82, 0x3, 0x2, 0x7, 0x8}}, |
||||
wantForwarded: []byte{0x82, 0x3, 0x2, 0x7, 0x8}, |
||||
wantRecorded: fakes.CastLine(t, []byte{0x7, 0x8}, cl), |
||||
}, |
||||
{ |
||||
name: "single_write_stdin_data_message", |
||||
inputs: [][]byte{{0x82, 0x3, 0x0, 0x7, 0x8}}, |
||||
wantForwarded: []byte{0x82, 0x3, 0x0, 0x7, 0x8}, |
||||
}, |
||||
{ |
||||
name: "single_write_stdout_data_message_with_cast_header", |
||||
inputs: [][]byte{{0x82, 0x3, 0x1, 0x7, 0x8}}, |
||||
wantForwarded: []byte{0x82, 0x3, 0x1, 0x7, 0x8}, |
||||
wantRecorded: append(fakes.AsciinemaResizeMsg(t, 10, 20), fakes.CastLine(t, []byte{0x7, 0x8}, cl)...), |
||||
width: 10, |
||||
height: 20, |
||||
firstWrite: true, |
||||
}, |
||||
{ |
||||
name: "two_writes_stdout_data_message", |
||||
inputs: [][]byte{{0x2, 0x3, 0x1, 0x7, 0x8}, {0x80, 0x6, 0x1, 0x1, 0x2, 0x3, 0x4, 0x5}}, |
||||
wantForwarded: []byte{0x2, 0x3, 0x1, 0x7, 0x8, 0x80, 0x6, 0x1, 0x1, 0x2, 0x3, 0x4, 0x5}, |
||||
wantRecorded: fakes.CastLine(t, []byte{0x7, 0x8, 0x1, 0x2, 0x3, 0x4, 0x5}, cl), |
||||
}, |
||||
{ |
||||
name: "three_writes_stdout_data_message_with_split_fragment", |
||||
inputs: [][]byte{{0x2, 0x3, 0x1, 0x7, 0x8}, {0x80, 0x6, 0x1, 0x1, 0x2, 0x3}, {0x4, 0x5}}, |
||||
wantForwarded: []byte{0x2, 0x3, 0x1, 0x7, 0x8, 0x80, 0x6, 0x1, 0x1, 0x2, 0x3, 0x4, 0x5}, |
||||
wantRecorded: fakes.CastLine(t, []byte{0x7, 0x8, 0x1, 0x2, 0x3, 0x4, 0x5}, cl), |
||||
}, |
||||
} |
||||
for _, tt := range tests { |
||||
t.Run(tt.name, func(t *testing.T) { |
||||
tc := &fakes.TestConn{} |
||||
sr := &fakes.TestSessionRecorder{} |
||||
rec := tsrecorder.New(sr, cl, cl.Now(), true) |
||||
c := &conn{ |
||||
Conn: tc, |
||||
log: zl.Sugar(), |
||||
ch: sessionrecording.CastHeader{ |
||||
Width: tt.width, |
||||
Height: tt.height, |
||||
}, |
||||
rec: rec, |
||||
} |
||||
if !tt.firstWrite { |
||||
// This test case does not intend to test that cast header gets written once.
|
||||
c.writeCastHeaderOnce.Do(func() {}) |
||||
} |
||||
for i, input := range tt.inputs { |
||||
_, err := c.Write(input) |
||||
if err != nil { |
||||
t.Fatalf("[%d] conn.Write() errored: %v", i, err) |
||||
} |
||||
} |
||||
// Assert that the expected bytes have been forwarded to the original destination.
|
||||
gotForwarded := tc.WriteBufBytes() |
||||
if !reflect.DeepEqual(gotForwarded, tt.wantForwarded) { |
||||
t.Errorf("expected bytes not forwarded, wants\n%x\ngot\n%x", tt.wantForwarded, gotForwarded) |
||||
} |
||||
|
||||
// Assert that the expected bytes have been forwarded to the session recorder.
|
||||
gotRecorded := sr.Bytes() |
||||
if !reflect.DeepEqual(gotRecorded, tt.wantRecorded) { |
||||
t.Errorf("expected bytes not recorded, wants\n%b\ngot\n%b", tt.wantRecorded, gotRecorded) |
||||
} |
||||
}) |
||||
} |
||||
} |
||||
|
||||
// Test_conn_ReadRand tests reading arbitrarily generated byte slices from conn to
|
||||
// test that we don't panic when parsing input from a broken or malicious
|
||||
// client.
|
||||
func Test_conn_ReadRand(t *testing.T) { |
||||
zl, err := zap.NewDevelopment() |
||||
if err != nil { |
||||
t.Fatalf("error creating a test logger: %v", err) |
||||
} |
||||
for i := range 100 { |
||||
tc := &fakes.TestConn{} |
||||
tc.ResetReadBuf() |
||||
c := &conn{ |
||||
Conn: tc, |
||||
log: zl.Sugar(), |
||||
} |
||||
bb := fakes.RandomBytes(t) |
||||
for j, input := range bb { |
||||
if err := tc.WriteReadBufBytes(input); err != nil { |
||||
t.Fatalf("[%d] writing bytes to test conn: %v", i, err) |
||||
} |
||||
f := func() { |
||||
c.Read(make([]byte, len(input))) |
||||
} |
||||
testPanic(t, f, fmt.Sprintf("[%d %d] Read panic parsing input of length %d first bytes: %v, current read message: %+#v", i, j, len(input), firstBytes(input), c.currentReadMsg)) |
||||
} |
||||
} |
||||
} |
||||
|
||||
// Test_conn_WriteRand calls conn.Write with an arbitrary input to validate that it does not
|
||||
// panic.
|
||||
func Test_conn_WriteRand(t *testing.T) { |
||||
zl, err := zap.NewDevelopment() |
||||
if err != nil { |
||||
t.Fatalf("error creating a test logger: %v", err) |
||||
} |
||||
cl := tstest.NewClock(tstest.ClockOpts{}) |
||||
sr := &fakes.TestSessionRecorder{} |
||||
rec := tsrecorder.New(sr, cl, cl.Now(), true) |
||||
for i := range 100 { |
||||
tc := &fakes.TestConn{} |
||||
c := &conn{ |
||||
Conn: tc, |
||||
log: zl.Sugar(), |
||||
rec: rec, |
||||
} |
||||
bb := fakes.RandomBytes(t) |
||||
for j, input := range bb { |
||||
f := func() { |
||||
c.Write(input) |
||||
} |
||||
testPanic(t, f, fmt.Sprintf("[%d %d] Write: panic parsing input of length %d first bytes %b current write message %+#v", i, j, len(input), firstBytes(input), c.currentWriteMsg)) |
||||
} |
||||
} |
||||
} |
||||
|
||||
func testPanic(t *testing.T, f func(), msg string) { |
||||
t.Helper() |
||||
defer func() { |
||||
if r := recover(); r != nil { |
||||
t.Fatal(msg, r) |
||||
} |
||||
}() |
||||
f() |
||||
} |
||||
|
||||
func firstBytes(b []byte) []byte { |
||||
if len(b) < 10 { |
||||
return b |
||||
} |
||||
return b[:10] |
||||
} |
||||
@ -0,0 +1,267 @@ |
||||
// Copyright (c) Tailscale Inc & AUTHORS
|
||||
// SPDX-License-Identifier: BSD-3-Clause
|
||||
|
||||
//go:build !plan9
|
||||
|
||||
package ws |
||||
|
||||
import ( |
||||
"encoding/binary" |
||||
"fmt" |
||||
"sync/atomic" |
||||
|
||||
"github.com/pkg/errors" |
||||
"go.uber.org/zap" |
||||
|
||||
"golang.org/x/net/websocket" |
||||
) |
||||
|
||||
const ( |
||||
noOpcode messageType = 0 // continuation frame for fragmented messages
|
||||
binaryMessage messageType = 2 |
||||
) |
||||
|
||||
// messageType is the type of a websocket data or control message as defined by opcode.
|
||||
// https://www.rfc-editor.org/rfc/rfc6455#section-5.2
|
||||
// Known types of control messages are close, ping and pong.
|
||||
// https://www.rfc-editor.org/rfc/rfc6455#section-5.5
|
||||
// The only data message type supported by Kubernetes is binary message
|
||||
// https://github.com/kubernetes/client-go/blob/v0.30.0-rc.1/tools/remotecommand/websocket.go#L281
|
||||
type messageType int |
||||
|
||||
// message is a parsed Websocket Message.
|
||||
type message struct { |
||||
// payload is the contents of the so far parsed Websocket
|
||||
// data Message payload, potentially from multiple fragments written by
|
||||
// multiple invocations of Parse. As per RFC 6455 We can assume that the
|
||||
// fragments will always arrive in order and data messages will not be
|
||||
// interleaved.
|
||||
payload []byte |
||||
|
||||
// isFinalized is set to true if msgPayload contains full contents of
|
||||
// the message (the final fragment has been received).
|
||||
isFinalized bool |
||||
|
||||
// streamID is the stream to which the message belongs, i.e stdin, stout
|
||||
// etc. It is one of the stream IDs defined in
|
||||
// https://github.com/kubernetes/apimachinery/blob/73d12d09c5be8703587b5127416eb83dc3b7e182/pkg/util/httpstream/wsstream/doc.go#L23-L36
|
||||
streamID atomic.Uint32 |
||||
|
||||
// typ is the type of a WebsocketMessage as defined by its opcode
|
||||
// https://www.rfc-editor.org/rfc/rfc6455#section-5.2
|
||||
typ messageType |
||||
raw []byte |
||||
} |
||||
|
||||
// Parse accepts a websocket message fragment as a byte slice and parses its contents.
|
||||
// It returns true if the fragment is complete, false if the fragment is incomplete.
|
||||
// If the fragment is incomplete, Parse will be called again with the same fragment + more bytes when those are received.
|
||||
// If the fragment is complete, it will be parsed into msg.
|
||||
// A complete fragment can be:
|
||||
// - a fragment that consists of a whole message
|
||||
// - an initial fragment for a message for which we expect more fragments
|
||||
// - a subsequent fragment for a message that we are currently parsing and whose so-far parsed contents are stored in msg.
|
||||
// Parse must not be called with bytes that don't contain fragment header (so, no less than 2 bytes).
|
||||
// 0 1 2 3
|
||||
// 0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 6 7 8 9 0 1
|
||||
// +-+-+-+-+-------+-+-------------+-------------------------------+
|
||||
// |F|R|R|R| opcode|M| Payload len | Extended payload length |
|
||||
// |I|S|S|S| (4) |A| (7) | (16/64) |
|
||||
// |N|V|V|V| |S| | (if payload len==126/127) |
|
||||
// | |1|2|3| |K| | |
|
||||
// +-+-+-+-+-------+-+-------------+ - - - - - - - - - - - - - - - +
|
||||
// | Extended payload length continued, if payload len == 127 |
|
||||
// + - - - - - - - - - - - - - - - +-------------------------------+
|
||||
// | |Masking-key, if MASK set to 1 |
|
||||
// +-------------------------------+-------------------------------+
|
||||
// | Masking-key (continued) | Payload Data |
|
||||
// +-------------------------------- - - - - - - - - - - - - - - - +
|
||||
// : Payload Data continued ... :
|
||||
// + - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - +
|
||||
// | Payload Data continued ... |
|
||||
// +---------------------------------------------------------------+
|
||||
// https://www.rfc-editor.org/rfc/rfc6455#section-5.2
|
||||
//
|
||||
// Fragmentation rules:
|
||||
// An unfragmented message consists of a single frame with the FIN
|
||||
// bit set (Section 5.2) and an opcode other than 0.
|
||||
// A fragmented message consists of a single frame with the FIN bit
|
||||
// clear and an opcode other than 0, followed by zero or more frames
|
||||
// with the FIN bit clear and the opcode set to 0, and terminated by
|
||||
// a single frame with the FIN bit set and an opcode of 0.
|
||||
// https://www.rfc-editor.org/rfc/rfc6455#section-5.4
|
||||
func (msg *message) Parse(b []byte, log *zap.SugaredLogger) (bool, error) { |
||||
if len(b) < 2 { |
||||
return false, fmt.Errorf("[unexpected] Parse should not be called with less than 2 bytes, got %d bytes", len(b)) |
||||
} |
||||
if msg.typ != binaryMessage { |
||||
return false, fmt.Errorf("[unexpected] internal error: attempted to parse a message with type %d", msg.typ) |
||||
} |
||||
isInitialFragment := len(msg.raw) == 0 |
||||
|
||||
msg.isFinalized = isFinalFragment(b) |
||||
|
||||
maskSet := isMasked(b) |
||||
|
||||
payloadLength, payloadOffset, maskOffset, err := fragmentDimensions(b, maskSet) |
||||
if err != nil { |
||||
return false, fmt.Errorf("error determining payload length: %w", err) |
||||
} |
||||
log.Debugf("parse: parsing a message fragment with payload length: %d payload offset: %d maskOffset: %d mask set: %t, is finalized: %t, is initial fragment: %t", payloadLength, payloadOffset, maskOffset, maskSet, msg.isFinalized, isInitialFragment) |
||||
|
||||
if len(b) < int(payloadOffset+payloadLength) { // incomplete fragment
|
||||
return false, nil |
||||
} |
||||
// TODO (irbekrm): perhaps only do this extra allocation if we know we
|
||||
// will need to unmask?
|
||||
msg.raw = make([]byte, int(payloadOffset)+int(payloadLength)) |
||||
copy(msg.raw, b[:payloadOffset+payloadLength]) |
||||
|
||||
// Extract the payload.
|
||||
msgPayload := b[payloadOffset : payloadOffset+payloadLength] |
||||
|
||||
// Unmask the payload if needed.
|
||||
// TODO (irbekrm): instead of unmasking all of the payload each time,
|
||||
// determine if the payload is for a resize message early and skip
|
||||
// unmasking the remaining bytes if not.
|
||||
if maskSet { |
||||
m := b[maskOffset:payloadOffset] |
||||
var mask [4]byte |
||||
copy(mask[:], m) |
||||
maskBytes(mask, msgPayload) |
||||
} |
||||
|
||||
// Determine what stream the message is for. Stream ID of a Kubernetes
|
||||
// streaming session is a 32bit integer, stored in the first byte of the
|
||||
// message payload.
|
||||
// https://github.com/kubernetes/apimachinery/commit/73d12d09c5be8703587b5127416eb83dc3b7e182#diff-291f96e8632d04d2d20f5fb00f6b323492670570d65434e8eac90c7a442d13bdR23-R36
|
||||
if len(msgPayload) == 0 { |
||||
return false, errors.New("[unexpected] received a message fragment with no stream ID") |
||||
} |
||||
|
||||
streamID := uint32(msgPayload[0]) |
||||
if !isInitialFragment && msg.streamID.Load() != streamID { |
||||
return false, fmt.Errorf("[unexpected] received message fragments with mismatched streamIDs %d and %d", msg.streamID.Load(), streamID) |
||||
} |
||||
msg.streamID.Store(streamID) |
||||
|
||||
// This is normal, Kubernetes seem to send a couple data messages with
|
||||
// no payloads at the start.
|
||||
if len(msgPayload) < 2 { |
||||
return true, nil |
||||
} |
||||
msgPayload = msgPayload[1:] // remove the stream ID byte
|
||||
msg.payload = append(msg.payload, msgPayload...) |
||||
return true, nil |
||||
} |
||||
|
||||
// maskBytes applies mask to bytes in place.
|
||||
// https://www.rfc-editor.org/rfc/rfc6455#section-5.3
|
||||
func maskBytes(key [4]byte, b []byte) { |
||||
for i := range b { |
||||
b[i] = b[i] ^ key[i%4] |
||||
} |
||||
} |
||||
|
||||
// isControlMessage returns true if the message type is one of the known control
|
||||
// frame message types.
|
||||
// https://www.rfc-editor.org/rfc/rfc6455#section-5.5
|
||||
func isControlMessage(t messageType) bool { |
||||
const ( |
||||
closeMessage messageType = 8 |
||||
pingMessage messageType = 9 |
||||
pongMessage messageType = 10 |
||||
) |
||||
return t == closeMessage || t == pingMessage || t == pongMessage |
||||
} |
||||
|
||||
// isFinalFragment can be called with websocket message fragment and returns true if
|
||||
// the fragment is the final fragment of a websocket message.
|
||||
func isFinalFragment(b []byte) bool { |
||||
return extractFirstBit(b[0]) != 0 |
||||
} |
||||
|
||||
// isMasked can be called with a websocket message fragment and returns true if
|
||||
// the payload of the message is masked. It uses the mask bit to determine if
|
||||
// the payload is masked.
|
||||
// https://www.rfc-editor.org/rfc/rfc6455#section-5.3
|
||||
func isMasked(b []byte) bool { |
||||
return extractFirstBit(b[1]) != 0 |
||||
} |
||||
|
||||
// extractFirstBit extracts first bit of a byte by zeroing out all the other
|
||||
// bits.
|
||||
func extractFirstBit(b byte) byte { |
||||
return b & 0x80 |
||||
} |
||||
|
||||
// zeroFirstBit returns the provided byte with the first bit set to 0.
|
||||
func zeroFirstBit(b byte) byte { |
||||
return b & 0x7f |
||||
} |
||||
|
||||
// fragmentDimensions returns payload length as well as payload offset and mask offset.
|
||||
func fragmentDimensions(b []byte, maskSet bool) (payloadLength, payloadOffset, maskOffset uint64, _ error) { |
||||
|
||||
// payload length can be stored either in bits [9-15] or in bytes 2, 3
|
||||
// or in bytes 2, 3, 4, 5, 6, 7.
|
||||
// https://www.rfc-editor.org/rfc/rfc6455#section-5.2
|
||||
// 0 1 2 3
|
||||
// 0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 6 7 8 9 0 1
|
||||
// +-+-+-+-+-------+-+-------------+-------------------------------+
|
||||
// |F|R|R|R| opcode|M| Payload len | Extended payload length |
|
||||
// |I|S|S|S| (4) |A| (7) | (16/64) |
|
||||
// |N|V|V|V| |S| | (if payload len==126/127) |
|
||||
// | |1|2|3| |K| | |
|
||||
// +-+-+-+-+-------+-+-------------+ - - - - - - - - - - - - - - - +
|
||||
// | Extended payload length continued, if payload len == 127 |
|
||||
// + - - - - - - - - - - - - - - - +-------------------------------+
|
||||
// | |Masking-key, if MASK set to 1 |
|
||||
// +-------------------------------+-------------------------------+
|
||||
payloadLengthIndicator := zeroFirstBit(b[1]) |
||||
switch { |
||||
case payloadLengthIndicator < 126: |
||||
maskOffset = 2 |
||||
payloadLength = uint64(payloadLengthIndicator) |
||||
case payloadLengthIndicator == 126: |
||||
maskOffset = 4 |
||||
if len(b) < int(maskOffset) { |
||||
return 0, 0, 0, fmt.Errorf("invalid message fragment- length indicator suggests that length is stored in bytes 2:4, but message length is only %d", len(b)) |
||||
} |
||||
payloadLength = uint64(binary.BigEndian.Uint16(b[2:4])) |
||||
case payloadLengthIndicator == 127: |
||||
maskOffset = 10 |
||||
if len(b) < int(maskOffset) { |
||||
return 0, 0, 0, fmt.Errorf("invalid message fragment- length indicator suggests that length is stored in bytes 2:10, but message length is only %d", len(b)) |
||||
} |
||||
payloadLength = binary.BigEndian.Uint64(b[2:10]) |
||||
default: |
||||
return 0, 0, 0, fmt.Errorf("unexpected payload length indicator value: %v", payloadLengthIndicator) |
||||
} |
||||
|
||||
// Ensure that a rogue or broken client doesn't cause us attempt to
|
||||
// allocate a huge array by setting a high payload size.
|
||||
// websocket.DefaultMaxPayloadBytes is the maximum payload size accepted
|
||||
// by server side of this connection, so we can safely reject messages
|
||||
// with larger payload size.
|
||||
if payloadLength > websocket.DefaultMaxPayloadBytes { |
||||
return 0, 0, 0, fmt.Errorf("[unexpected]: too large payload size: %v", payloadLength) |
||||
} |
||||
|
||||
// Masking key can take up 0 or 4 bytes- we need to take that into
|
||||
// account when determining payload offset.
|
||||
// 0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 6 7 8 9 0 1
|
||||
// ....
|
||||
// + - - - - - - - - - - - - - - - +-------------------------------+
|
||||
// | |Masking-key, if MASK set to 1 |
|
||||
// +-------------------------------+-------------------------------+
|
||||
// | Masking-key (continued) | Payload Data |
|
||||
// + - - - - - - - - - - - - - - - +-------------------------------+
|
||||
// ...
|
||||
if maskSet { |
||||
payloadOffset = maskOffset + 4 |
||||
} else { |
||||
payloadOffset = maskOffset |
||||
} |
||||
return |
||||
} |
||||
@ -0,0 +1,215 @@ |
||||
// Copyright (c) Tailscale Inc & AUTHORS
|
||||
// SPDX-License-Identifier: BSD-3-Clause
|
||||
|
||||
//go:build !plan9
|
||||
|
||||
package ws |
||||
|
||||
import ( |
||||
"encoding/binary" |
||||
"fmt" |
||||
"reflect" |
||||
"testing" |
||||
"time" |
||||
|
||||
"math/rand" |
||||
|
||||
"go.uber.org/zap" |
||||
"golang.org/x/net/websocket" |
||||
) |
||||
|
||||
func Test_msg_Parse(t *testing.T) { |
||||
zl, err := zap.NewDevelopment() |
||||
if err != nil { |
||||
t.Fatalf("error creating a test logger: %v", err) |
||||
} |
||||
testMask := [4]byte{1, 2, 3, 4} |
||||
bs126, bs126Len := bytesSlice2ByteLen(t) |
||||
bs127, bs127Len := byteSlice8ByteLen(t) |
||||
tests := []struct { |
||||
name string |
||||
b []byte |
||||
initialPayload []byte |
||||
wantPayload []byte |
||||
wantIsFinalized bool |
||||
wantStreamID uint32 |
||||
wantErr bool |
||||
}{ |
||||
{ |
||||
name: "single_fragment_stdout_stream_no_payload_no_mask", |
||||
b: []byte{0x82, 0x1, 0x1}, |
||||
wantPayload: nil, |
||||
wantIsFinalized: true, |
||||
wantStreamID: 1, |
||||
}, |
||||
{ |
||||
name: "single_fragment_stderr_steam_no_payload_has_mask", |
||||
b: append([]byte{0x82, 0x81, 0x1, 0x2, 0x3, 0x4}, maskedBytes(testMask, []byte{0x2})...), |
||||
wantPayload: nil, |
||||
wantIsFinalized: true, |
||||
wantStreamID: 2, |
||||
}, |
||||
{ |
||||
name: "single_fragment_stdout_stream_no_mask_has_payload", |
||||
b: []byte{0x82, 0x3, 0x1, 0x7, 0x8}, |
||||
wantPayload: []byte{0x7, 0x8}, |
||||
wantIsFinalized: true, |
||||
wantStreamID: 1, |
||||
}, |
||||
{ |
||||
name: "single_fragment_stdout_stream_has_mask_has_payload", |
||||
b: append([]byte{0x82, 0x83, 0x1, 0x2, 0x3, 0x4}, maskedBytes(testMask, []byte{0x1, 0x7, 0x8})...), |
||||
wantPayload: []byte{0x7, 0x8}, |
||||
wantIsFinalized: true, |
||||
wantStreamID: 1, |
||||
}, |
||||
{ |
||||
name: "initial_fragment_stdout_stream_no_mask_has_payload", |
||||
b: []byte{0x2, 0x3, 0x1, 0x7, 0x8}, |
||||
wantPayload: []byte{0x7, 0x8}, |
||||
wantStreamID: 1, |
||||
}, |
||||
{ |
||||
name: "initial_fragment_stdout_stream_has_mask_has_payload", |
||||
b: append([]byte{0x2, 0x83, 0x1, 0x2, 0x3, 0x4}, maskedBytes(testMask, []byte{0x1, 0x7, 0x8})...), |
||||
wantPayload: []byte{0x7, 0x8}, |
||||
wantStreamID: 1, |
||||
}, |
||||
{ |
||||
name: "subsequent_fragment_stdout_stream_no_mask_has_payload", |
||||
b: []byte{0x0, 0x3, 0x1, 0x7, 0x8}, |
||||
initialPayload: []byte{0x1, 0x2, 0x3}, |
||||
wantPayload: []byte{0x1, 0x2, 0x3, 0x7, 0x8}, |
||||
wantStreamID: 1, |
||||
}, |
||||
{ |
||||
name: "subsequent_fragment_stdout_stream_has_mask_has_payload", |
||||
b: append([]byte{0x0, 0x83, 0x1, 0x2, 0x3, 0x4}, maskedBytes(testMask, []byte{0x1, 0x7, 0x8})...), |
||||
initialPayload: []byte{0x1, 0x2, 0x3}, |
||||
wantPayload: []byte{0x1, 0x2, 0x3, 0x7, 0x8}, |
||||
wantStreamID: 1, |
||||
}, |
||||
{ |
||||
name: "final_fragment_stdout_stream_no_mask_has_payload", |
||||
b: []byte{0x80, 0x3, 0x1, 0x7, 0x8}, |
||||
initialPayload: []byte{0x1, 0x2, 0x3}, |
||||
wantIsFinalized: true, |
||||
wantPayload: []byte{0x1, 0x2, 0x3, 0x7, 0x8}, |
||||
wantStreamID: 1, |
||||
}, |
||||
{ |
||||
name: "final_fragment_stdout_stream_has_mask_has_payload", |
||||
b: append([]byte{0x80, 0x83, 0x1, 0x2, 0x3, 0x4}, maskedBytes(testMask, []byte{0x1, 0x7, 0x8})...), |
||||
initialPayload: []byte{0x1, 0x2, 0x3}, |
||||
wantIsFinalized: true, |
||||
wantPayload: []byte{0x1, 0x2, 0x3, 0x7, 0x8}, |
||||
wantStreamID: 1, |
||||
}, |
||||
{ |
||||
name: "single_large_fragment_no_mask_length_hint_126", |
||||
b: append(append([]byte{0x80, 0x7e}, bs126Len...), append([]byte{0x1}, bs126...)...), |
||||
wantIsFinalized: true, |
||||
wantPayload: bs126, |
||||
wantStreamID: 1, |
||||
}, |
||||
{ |
||||
name: "single_large_fragment_no_mask_length_hint_127", |
||||
b: append(append([]byte{0x80, 0x7f}, bs127Len...), append([]byte{0x1}, bs127...)...), |
||||
wantIsFinalized: true, |
||||
wantPayload: bs127, |
||||
wantStreamID: 1, |
||||
}, |
||||
{ |
||||
name: "zero_length_bytes", |
||||
b: []byte{}, |
||||
wantErr: true, |
||||
}, |
||||
} |
||||
for _, tt := range tests { |
||||
t.Run(tt.name, func(t *testing.T) { |
||||
msg := &message{ |
||||
typ: binaryMessage, |
||||
payload: tt.initialPayload, |
||||
} |
||||
if _, err := msg.Parse(tt.b, zl.Sugar()); (err != nil) != tt.wantErr { |
||||
t.Errorf("msg.Parse() = %v, wantsErr: %t", err, tt.wantErr) |
||||
} |
||||
if msg.isFinalized != tt.wantIsFinalized { |
||||
t.Errorf("wants message to be finalized: %t, got: %t", tt.wantIsFinalized, msg.isFinalized) |
||||
} |
||||
if msg.streamID.Load() != tt.wantStreamID { |
||||
t.Errorf("wants stream ID: %d, got: %d", tt.wantStreamID, msg.streamID.Load()) |
||||
} |
||||
if !reflect.DeepEqual(msg.payload, tt.wantPayload) { |
||||
t.Errorf("unexpected message payload after Parse, wants %b got %b", tt.wantPayload, msg.payload) |
||||
} |
||||
}) |
||||
} |
||||
} |
||||
|
||||
// Test_msg_Parse_Rand calls Parse with a randomly generated input to verify
|
||||
// that it doesn't panic.
|
||||
func Test_msg_Parse_Rand(t *testing.T) { |
||||
zl, err := zap.NewDevelopment() |
||||
if err != nil { |
||||
t.Fatalf("error creating a test logger: %v", err) |
||||
} |
||||
r := rand.New(rand.NewSource(time.Now().UnixNano())) |
||||
for i := range 100 { |
||||
n := r.Intn(4096) |
||||
b := make([]byte, n) |
||||
_, err := r.Read(b) |
||||
if err != nil { |
||||
t.Fatalf("error generating random byte slice: %v", err) |
||||
} |
||||
msg := message{typ: binaryMessage} |
||||
f := func() { |
||||
msg.Parse(b, zl.Sugar()) |
||||
} |
||||
testPanic(t, f, fmt.Sprintf("[%d] Parse panicked running with byte slice of length %d: %v", i, n, r)) |
||||
} |
||||
} |
||||
|
||||
// byteSlice2ByteLen generates a number that represents websocket message fragment length and is stored in an 8 byte slice.
|
||||
// Returns the byte slice with the length as well as a slice of arbitrary bytes of the given length.
|
||||
// This is used to generate test input representing websocket message with payload length hint 126.
|
||||
// https://www.rfc-editor.org/rfc/rfc6455#section-5.2
|
||||
func bytesSlice2ByteLen(t *testing.T) ([]byte, []byte) { |
||||
r := rand.New(rand.NewSource(time.Now().UnixNano())) |
||||
var n uint16 |
||||
n = uint16(rand.Intn(65535 - 1)) // space for and additional 1 byte stream ID
|
||||
b := make([]byte, n) |
||||
_, err := r.Read(b) |
||||
if err != nil { |
||||
t.Fatalf("error generating random byte slice: %v ", err) |
||||
} |
||||
bb := make([]byte, 2) |
||||
binary.BigEndian.PutUint16(bb, n+1) // + stream ID
|
||||
return b, bb |
||||
} |
||||
|
||||
// byteSlice8ByteLen generates a number that represents websocket message fragment length and is stored in an 8 byte slice.
|
||||
// Returns the byte slice with the length as well as a slice of arbitrary bytes of the given length.
|
||||
// This is used to generate test input representing websocket message with payload length hint 127.
|
||||
// https://www.rfc-editor.org/rfc/rfc6455#section-5.2
|
||||
func byteSlice8ByteLen(t *testing.T) ([]byte, []byte) { |
||||
nanos := time.Now().UnixNano() |
||||
t.Logf("Creating random source with seed %v", nanos) |
||||
r := rand.New(rand.NewSource(nanos)) |
||||
var n uint64 |
||||
n = uint64(rand.Intn(websocket.DefaultMaxPayloadBytes - 1)) // space for and additional 1 byte stream ID
|
||||
t.Logf("byteSlice8ByteLen: generating message payload of length %d", n) |
||||
b := make([]byte, n) |
||||
_, err := r.Read(b) |
||||
if err != nil { |
||||
t.Fatalf("error generating random byte slice: %v ", err) |
||||
} |
||||
bb := make([]byte, 8) |
||||
binary.BigEndian.PutUint64(bb, n+1) // + stream ID
|
||||
return b, bb |
||||
} |
||||
|
||||
func maskedBytes(mask [4]byte, b []byte) []byte { |
||||
maskBytes(mask, b) |
||||
return b |
||||
} |
||||
Loading…
Reference in new issue