k8s-operator/sessionrecording/ws: unify Read/Write frame parsing (#19227)

Consolidate the duplicated WebSocket frame-parsing logic from Read
and Write into a shared processFrames loop, fixing several bugs in
the process:

- Mixed control and data frames in a single Read/Write call buffer
  were not handled: a control frame would cause merged data frames
  to be skipped.
- Multiple data frames in one Write call weren't being correctly
  parsed: only the first frame was processed, ignoring the rest in
  the buffer.
- msg.isFinalized was being set before confirming the fragment was
  complete, so an incomplete message fragment could sometimes have
  been marked as finalized.
- Continuation frames without any payload were being treated as if
  they didn't have a stream ID, even though the ID is already known
  from the initial fragment.

Fixes tailscale/corp#39583

Signed-off-by: Fernando Serboncini <fserb@tailscale.com>
Signed-off-by: chaosinthecrd <tom@tmlabs.co.uk>
Co-authored-by: chaosinthecrd <tom@tmlabs.co.uk>
main
Fernando Serboncini 1 week ago committed by GitHub
parent 8a7e160a6e
commit 07399275f1
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
  1. 280
      k8s-operator/sessionrecording/ws/conn.go
  2. 88
      k8s-operator/sessionrecording/ws/conn_test.go
  3. 13
      k8s-operator/sessionrecording/ws/message.go

@ -147,88 +147,12 @@ func (c *conn) Read(b []byte) (int, error) {
return 0, nil
}
// TODO(tomhjp): If we get multiple frames in a single Read with different
// types, we may parse the second frame with the wrong type.
typ := messageType(opcode(b))
if (typ == noOpcode && c.readMsgIsIncomplete()) || c.readBufHasIncompleteFragment() { // subsequent fragment
if typ, err = c.curReadMsgType(); err != nil {
return 0, err
}
}
// A control message can not be fragmented and we are not interested in
// these messages. Just return.
// TODO(tomhjp): If we get multiple frames in a single Read, we may skip
// some non-control messages.
if isControlMessage(typ) {
return n, nil
}
// The only data message type that Kubernetes supports is binary message.
// If we received another message type, return and let the API server close the connection.
// https://github.com/kubernetes/client-go/blob/release-1.30/tools/remotecommand/websocket.go#L281
if typ != binaryMessage {
c.log.Infof("[unexpected] received a data message with a type that is not binary message type %v", typ)
return n, nil
}
if _, err := c.readBuf.Write(b[:n]); err != nil {
return 0, fmt.Errorf("[unexpected] error writing message contents to read buffer: %w", err)
}
for c.readBuf.Len() != 0 {
readMsg := &message{typ: typ} // start a new message...
// ... or pick up an already started one if the previous fragment was not final.
if c.readMsgIsIncomplete() {
readMsg = c.currentReadMsg
}
ok, err := readMsg.Parse(c.readBuf.Bytes(), c.log)
if err != nil {
return 0, fmt.Errorf("error parsing message: %v", err)
}
if !ok { // incomplete fragment
return n, nil
}
c.readBuf.Next(len(readMsg.raw))
if readMsg.isFinalized && !c.readMsgIsIncomplete() {
// we want to send stream resize messages for terminal sessions
// Stream IDs for websocket streams are static.
// https://github.com/kubernetes/client-go/blob/v0.30.0-rc.1/tools/remotecommand/websocket.go#L218
if readMsg.streamID.Load() == remotecommand.StreamResize && c.hasTerm {
var msg tsrecorder.ResizeMsg
if err = json.Unmarshal(readMsg.payload, &msg); err != nil {
return 0, fmt.Errorf("error umarshalling resize message: %w", err)
}
c.ch.Width = msg.Width
c.ch.Height = msg.Height
var isInitialResize bool
c.writeCastHeaderOnce.Do(func() {
isInitialResize = true
// If this is a session with a terminal attached,
// we must wait for the terminal width and
// height to be parsed from a resize message
// before sending CastHeader, else tsrecorder
// will not be able to play this recording.
err = c.rec.WriteCastHeader(c.ch)
close(c.initialCastHeaderSent)
})
if err != nil {
return 0, fmt.Errorf("error writing CastHeader: %w", err)
}
if !isInitialResize {
if err := c.rec.WriteResize(msg.Height, msg.Width); err != nil {
return 0, fmt.Errorf("error writing resize message: %w", err)
}
}
}
}
c.currentReadMsg = readMsg
if _, err := c.processFrames(&c.readBuf, &c.currentReadMsg); err != nil {
return 0, err
}
return n, nil
@ -245,64 +169,21 @@ func (c *conn) Write(b []byte) (int, error) {
return 0, nil
}
typ := messageType(opcode(b))
// If we are in process of parsing a message fragment, the received
// bytes are not structured as a message fragment and can not be used to
// determine a message fragment.
if c.writeBufHasIncompleteFragment() { // buffer contains previous incomplete fragment
var err error
if typ, err = c.curWriteMsgType(); err != nil {
return 0, err
}
}
if isControlMessage(typ) {
return c.Conn.Write(b)
}
writeMsg := &message{typ: typ} // start a new message...
// ... or continue the existing one if it has not been finalized.
if c.writeMsgIsIncomplete() || c.writeBufHasIncompleteFragment() {
writeMsg = c.currentWriteMsg
}
if _, err := c.writeBuf.Write(b); err != nil {
c.log.Errorf("write: error writing to write buf: %v", err)
return 0, fmt.Errorf("[unexpected] error writing to internal write buffer: %w", err)
}
ok, err := writeMsg.Parse(c.writeBuf.Bytes(), c.log)
raw, err := c.processFrames(&c.writeBuf, &c.currentWriteMsg)
if err != nil {
c.log.Errorf("write: parsing a message errored: %v", err)
return 0, fmt.Errorf("write: error parsing message: %v", err)
}
c.currentWriteMsg = writeMsg
if !ok { // incomplete fragment
return len(b), nil
return 0, err
}
c.writeBuf.Next(len(writeMsg.raw)) // advance frame
if len(writeMsg.payload) != 0 && writeMsg.isFinalized {
if writeMsg.streamID.Load() == remotecommand.StreamStdOut || writeMsg.streamID.Load() == remotecommand.StreamStdErr {
// we must wait for confirmation that the initial cast header was sent before proceeding with any more writes
select {
case <-c.ctx.Done():
return 0, c.ctx.Err()
case <-c.initialCastHeaderSent:
if err := c.rec.WriteOutput(writeMsg.payload); err != nil {
return 0, fmt.Errorf("error writing message to recorder: %w", err)
}
}
if len(raw) > 0 {
if _, err := c.Conn.Write(raw); err != nil {
return 0, err
}
}
_, err = c.Conn.Write(c.currentWriteMsg.raw)
if err != nil {
c.log.Errorf("write: error writing to conn: %v", err)
}
return len(b), nil
}
@ -318,48 +199,125 @@ func (c *conn) Close() error {
return errors.Join(connCloseErr, recCloseErr)
}
// writeBufHasIncompleteFragment returns true if the latest data message
// fragment written to the connection was incomplete and the following write
// must be the remaining payload bytes of that fragment.
func (c *conn) writeBufHasIncompleteFragment() bool {
return c.writeBuf.Len() != 0
// handleData passes a finalized data message on to the session recorder,
// choosing what to do from the message's stream ID:
//
//   - stdout/stderr: wait until initialCastHeaderSent is closed (or the
//     connection context is done) and then record the payload as output.
//   - resize: only when c.hasTerm is set, decode the payload as a
//     tsrecorder.ResizeMsg and store its dimensions on the cast header. The
//     first such message writes the CastHeader and closes
//     initialCastHeaderSent; every later one is recorded as a resize event.
//
// Messages on any other stream are ignored and nil is returned.
func (c *conn) handleData(msg *message) error {
	streamID := msg.streamID.Load()

	if streamID == remotecommand.StreamStdOut || streamID == remotecommand.StreamStdErr {
		// Output must not be recorded until the CastHeader has been written.
		select {
		case <-c.ctx.Done():
			return c.ctx.Err()
		case <-c.initialCastHeaderSent:
		}
		if err := c.rec.WriteOutput(msg.payload); err != nil {
			return fmt.Errorf("error writing message to recorder: %w", err)
		}
		return nil
	}

	// Anything that is not a resize for a terminal session is of no interest.
	if streamID != remotecommand.StreamResize || !c.hasTerm {
		return nil
	}

	var resize tsrecorder.ResizeMsg
	if err := json.Unmarshal(msg.payload, &resize); err != nil {
		return fmt.Errorf("error unmarshalling resize message: %w", err)
	}
	c.ch.Width, c.ch.Height = resize.Width, resize.Height

	// The Once body runs for the first resize only: it emits the CastHeader
	// and closes the channel that the stdout/stderr path above waits on.
	wroteHeader := false
	var castHeaderErr error
	c.writeCastHeaderOnce.Do(func() {
		wroteHeader = true
		castHeaderErr = c.rec.WriteCastHeader(c.ch)
		close(c.initialCastHeaderSent)
	})
	if castHeaderErr != nil {
		return fmt.Errorf("error writing CastHeader: %w", castHeaderErr)
	}
	if wroteHeader {
		// The initial dimensions are carried by the CastHeader itself.
		return nil
	}

	if err := c.rec.WriteResize(resize.Height, resize.Width); err != nil {
		return fmt.Errorf("error writing resize message: %w", err)
	}
	return nil
}
// readBufHasIncompleteFragment returns true if the latest data message
// fragment read from the connection was incomplete and the following read
// must be the remaining payload bytes of that fragment.
func (c *conn) readBufHasIncompleteFragment() bool {
return c.readBuf.Len() != 0
}
// processFrames drains complete WebSocket frames from buf, recording session
// data via handleData for finalized binary messages. It returns the raw bytes
// of every consumed frame so the Write path can forward them to the underlying
// connection. Incomplete frames are left in buf for the next call.
//
// Control frames are consumed whole without inspection. Non-binary data frames
// are unexpected (k8s only uses binary) and cause the buffer to be discarded.
func (c *conn) processFrames(
buf *bytes.Buffer,
curMsg **message,
) ([]byte, error) {
var raw []byte
for buf.Len() != 0 {
b := buf.Bytes()
if len(b) < 2 {
return raw, nil
}
// writeMsgIsIncomplete returns true if the latest WebSocket message written to
// the connection was fragmented and the next data message fragment written to
// the connection must be a fragment of that message.
// https://www.rfc-editor.org/rfc/rfc6455#section-5.4
func (c *conn) writeMsgIsIncomplete() bool {
return c.currentWriteMsg != nil && !c.currentWriteMsg.isFinalized
}
// Continuation frames (opcode 0) inherit the type of the in-progress message.
typ := messageType(opcode(b))
if typ == noOpcode && *curMsg != nil {
typ = (*curMsg).typ
}
// readMsgIsIncomplete returns true if the latest WebSocket message written to
// the connection was fragmented and the next data message fragment written to
// the connection must be a fragment of that message.
// https://www.rfc-editor.org/rfc/rfc6455#section-5.4
func (c *conn) readMsgIsIncomplete() bool {
return c.currentReadMsg != nil && !c.currentReadMsg.isFinalized
}
// Control frames: pass through without inspection.
if isControlMessage(typ) {
maskSet := isMasked(b)
payloadLen, payloadOffset, _, err := fragmentDimensions(b, maskSet)
if err != nil {
return nil, fmt.Errorf("error parsing control frame: %w", err)
}
frameLen := int(payloadOffset + payloadLen)
if len(b) < frameLen {
return raw, nil // incomplete control frame
}
raw = append(raw, b[:frameLen]...)
buf.Next(frameLen)
continue
}
func (c *conn) curReadMsgType() (messageType, error) {
if c.currentReadMsg != nil {
return c.currentReadMsg.typ, nil
}
return 0, errors.New("[unexpected] attempted to determine type for nil message")
}
// k8s remotecommand only uses binary data messages.
if typ != binaryMessage {
c.log.Infof("[unexpected] received a data message with a type that is not binary message type %v", typ)
buf.Reset()
return raw, nil
}
func (c *conn) curWriteMsgType() (messageType, error) {
if c.currentWriteMsg != nil {
return c.currentWriteMsg.typ, nil
// Continue a fragmented message or start a new one.
msg := &message{typ: typ}
if *curMsg != nil && !(*curMsg).isFinalized {
msg = *curMsg
}
ok, err := msg.Parse(b, c.log)
if err != nil {
return nil, fmt.Errorf("error parsing message: %w", err)
}
if !ok {
*curMsg = msg
return raw, nil // incomplete fragment, wait for more bytes
}
buf.Next(len(msg.raw))
*curMsg = msg
raw = append(raw, msg.raw...)
if msg.isFinalized && len(msg.payload) > 0 {
if err := c.handleData(msg); err != nil {
return nil, err
}
}
}
return 0, errors.New("[unexpected] attempted to determine type for nil message")
return raw, nil
}
// opcode reads the websocket message opcode that denotes the message type.

@ -37,6 +37,17 @@ func Test_conn_Read(t *testing.T) {
wantCastHeaderHeight int
wantRecorded []byte
}{
// Empty final continuation frame after a resize frame.
{
name: "continuation_frame_with_empty_payload",
inputs: [][]byte{
append([]byte{0x02, lenResizeMsgPayload}, testResizeMsg...),
{0x80, 0x00},
},
wantRecorded: fakes.AsciinemaCastHeaderMsg(t, 10, 20),
wantCastHeaderWidth: 10,
wantCastHeaderHeight: 20,
},
{
name: "single_read_control_message",
inputs: [][]byte{{0x88, 0x0}},
@ -58,6 +69,19 @@ func Test_conn_Read(t *testing.T) {
wantCastHeaderWidth: 10,
wantCastHeaderHeight: 20,
},
{
// A control frame (close) followed by a resize data frame in
// a single Read. Without the frame loop, the close frame
// would cause the data frame to be skipped.
name: "control_then_data_in_one_read",
inputs: [][]byte{
// close frame (0x88, len 0), then resize data frame
append([]byte{0x88, 0x00, 0x82, lenResizeMsgPayload}, testResizeMsg...),
},
wantRecorded: fakes.AsciinemaCastHeaderMsg(t, 10, 20),
wantCastHeaderWidth: 10,
wantCastHeaderHeight: 20,
},
{
name: "resize_data_frame_two_in_one_read",
inputs: [][]byte{
@ -156,6 +180,26 @@ func Test_conn_Write(t *testing.T) {
wantRecorded []byte
hasTerm bool
}{
// Empty final continuation frame; stream ID already set by
// the initial fragment.
{
name: "continuation_frame_with_empty_payload",
inputs: [][]byte{
{0x02, 0x03, 0x01, 0x07, 0x08},
{0x80, 0x00},
},
wantForwarded: []byte{0x02, 0x03, 0x01, 0x07, 0x08, 0x80, 0x00},
wantRecorded: fakes.CastLine(t, []byte{0x07, 0x08}, cl),
},
// Same as above but both fragments land in one Write call.
{
name: "continuation_frame_with_empty_payload_single_write",
inputs: [][]byte{
{0x02, 0x03, 0x01, 0x07, 0x08, 0x80, 0x00},
},
wantForwarded: []byte{0x02, 0x03, 0x01, 0x07, 0x08, 0x80, 0x00},
wantRecorded: fakes.CastLine(t, []byte{0x07, 0x08}, cl),
},
{
name: "single_write_control_frame",
inputs: [][]byte{{0x88, 0x0}},
@ -203,6 +247,38 @@ func Test_conn_Write(t *testing.T) {
wantRecorded: fakes.CastLine(t, []byte{0x7, 0x8, 0x1, 0x2, 0x3, 0x4, 0x5}, cl),
hasTerm: true,
},
{
// Two complete WebSocket frames coalesced into a single
// Write() call: a stdout binary frame followed by a close
// frame. Without a loop in the Write path, the close frame
// gets stranded in writeBuf and misinterpreted on the next
// Write.
name: "two_frames_in_one_write_data_then_close",
inputs: [][]byte{
// binary frame (opcode 0x2, FIN set = 0x82), payload len 3,
// stream ID 1 (stdout), two data bytes,
// then close frame (opcode 0x8, FIN set = 0x88), payload len 0
{0x82, 0x03, 0x01, 0x07, 0x08, 0x88, 0x00},
},
wantForwarded: []byte{0x82, 0x03, 0x01, 0x07, 0x08, 0x88, 0x00},
wantRecorded: fakes.CastLine(t, []byte{0x07, 0x08}, cl),
},
{
// Two complete stdout data frames in one Write() call.
// Mirrors the "resize_data_frame_two_in_one_read" test
// for the Read path.
name: "two_data_frames_in_one_write",
inputs: [][]byte{
// first: binary frame, payload len 3, stdout stream, two data bytes
// second: binary frame, payload len 3, stdout stream, two different data bytes
{0x82, 0x03, 0x01, 0x07, 0x08, 0x82, 0x03, 0x01, 0x09, 0x0a},
},
wantForwarded: []byte{0x82, 0x03, 0x01, 0x07, 0x08, 0x82, 0x03, 0x01, 0x09, 0x0a},
wantRecorded: append(
fakes.CastLine(t, []byte{0x07, 0x08}, cl),
fakes.CastLine(t, []byte{0x09, 0x0a}, cl)...,
),
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
@ -254,12 +330,20 @@ func Test_conn_ReadRand(t *testing.T) {
if err != nil {
t.Fatalf("error creating a test logger: %v", err)
}
cl := tstest.NewClock(tstest.ClockOpts{})
sr := &fakes.TestSessionRecorder{}
rec := tsrecorder.New(sr, cl, cl.Now(), true, zl.Sugar())
for i := range 100 {
tc := &fakes.TestConn{}
tc.ResetReadBuf()
headerSent := make(chan struct{})
close(headerSent) // pre-close so handleData doesn't block
c := &conn{
Conn: tc,
log: zl.Sugar(),
Conn: tc,
log: zl.Sugar(),
ctx: context.Background(),
rec: rec,
initialCastHeaderSent: headerSent,
}
bb := fakes.RandomBytes(t)
for j, input := range bb {

@ -99,19 +99,19 @@ func (msg *message) Parse(b []byte, log *zap.SugaredLogger) (bool, error) {
}
isInitialFragment := len(msg.raw) == 0
msg.isFinalized = isFinalFragment(b)
finalized := isFinalFragment(b)
maskSet := isMasked(b)
payloadLength, payloadOffset, maskOffset, err := fragmentDimensions(b, maskSet)
if err != nil {
return false, fmt.Errorf("error determining payload length: %w", err)
}
log.Debugf("parse: parsing a message fragment with payload length: %d payload offset: %d maskOffset: %d mask set: %t, is finalized: %t, is initial fragment: %t", payloadLength, payloadOffset, maskOffset, maskSet, msg.isFinalized, isInitialFragment)
log.Debugf("parse: parsing a message fragment with payload length: %d payload offset: %d maskOffset: %d mask set: %t, is finalized: %t, is initial fragment: %t", payloadLength, payloadOffset, maskOffset, maskSet, finalized, isInitialFragment)
if len(b) < int(payloadOffset+payloadLength) { // incomplete fragment
return false, nil
}
msg.isFinalized = finalized
// TODO (irbekrm): perhaps only do this extra allocation if we know we
// will need to unmask?
msg.raw = make([]byte, int(payloadOffset)+int(payloadLength))
@ -136,6 +136,13 @@ func (msg *message) Parse(b []byte, log *zap.SugaredLogger) (bool, error) {
// message payload.
// https://github.com/kubernetes/apimachinery/commit/73d12d09c5be8703587b5127416eb83dc3b7e182#diff-291f96e8632d04d2d20f5fb00f6b323492670570d65434e8eac90c7a442d13bdR23-R36
if len(msgPayload) == 0 {
if !isInitialFragment {
// Continuation frame with zero payload. The stream ID is
// already known from the initial fragment, so this is not
// fatal, just unusual.
log.Infof("[unexpected] received a continuation fragment with no payload")
return true, nil
}
return false, errors.New("[unexpected] received a message fragment with no stream ID")
}

Loading…
Cancel
Save