diff --git a/k8s-operator/sessionrecording/ws/conn.go b/k8s-operator/sessionrecording/ws/conn.go index 4762630ca..ed0ecc7ac 100644 --- a/k8s-operator/sessionrecording/ws/conn.go +++ b/k8s-operator/sessionrecording/ws/conn.go @@ -147,88 +147,12 @@ func (c *conn) Read(b []byte) (int, error) { return 0, nil } - // TODO(tomhjp): If we get multiple frames in a single Read with different - // types, we may parse the second frame with the wrong type. - typ := messageType(opcode(b)) - if (typ == noOpcode && c.readMsgIsIncomplete()) || c.readBufHasIncompleteFragment() { // subsequent fragment - if typ, err = c.curReadMsgType(); err != nil { - return 0, err - } - } - - // A control message can not be fragmented and we are not interested in - // these messages. Just return. - // TODO(tomhjp): If we get multiple frames in a single Read, we may skip - // some non-control messages. - if isControlMessage(typ) { - return n, nil - } - - // The only data message type that Kubernetes supports is binary message. - // If we received another message type, return and let the API server close the connection. - // https://github.com/kubernetes/client-go/blob/release-1.30/tools/remotecommand/websocket.go#L281 - if typ != binaryMessage { - c.log.Infof("[unexpected] received a data message with a type that is not binary message type %v", typ) - return n, nil - } - if _, err := c.readBuf.Write(b[:n]); err != nil { return 0, fmt.Errorf("[unexpected] error writing message contents to read buffer: %w", err) } - for c.readBuf.Len() != 0 { - readMsg := &message{typ: typ} // start a new message... - // ... or pick up an already started one if the previous fragment was not final. - if c.readMsgIsIncomplete() { - readMsg = c.currentReadMsg - } - - ok, err := readMsg.Parse(c.readBuf.Bytes(), c.log) - if err != nil { - return 0, fmt.Errorf("error parsing message: %v", err) - } - if !ok { // incomplete fragment - return n, nil - } - c.readBuf.Next(len(readMsg.raw)) - - if readMsg.isFinalized && !c.readMsgIsIncomplete() { - // we want to send stream resize messages for terminal sessions - // Stream IDs for websocket streams are static. - // https://github.com/kubernetes/client-go/blob/v0.30.0-rc.1/tools/remotecommand/websocket.go#L218 - if readMsg.streamID.Load() == remotecommand.StreamResize && c.hasTerm { - var msg tsrecorder.ResizeMsg - if err = json.Unmarshal(readMsg.payload, &msg); err != nil { - return 0, fmt.Errorf("error umarshalling resize message: %w", err) - } - - c.ch.Width = msg.Width - c.ch.Height = msg.Height - - var isInitialResize bool - c.writeCastHeaderOnce.Do(func() { - isInitialResize = true - // If this is a session with a terminal attached, - // we must wait for the terminal width and - // height to be parsed from a resize message - // before sending CastHeader, else tsrecorder - // will not be able to play this recording. - err = c.rec.WriteCastHeader(c.ch) - close(c.initialCastHeaderSent) - }) - if err != nil { - return 0, fmt.Errorf("error writing CastHeader: %w", err) - } - - if !isInitialResize { - if err := c.rec.WriteResize(msg.Height, msg.Width); err != nil { - return 0, fmt.Errorf("error writing resize message: %w", err) - } - } - } - } - - c.currentReadMsg = readMsg + if _, err := c.processFrames(&c.readBuf, &c.currentReadMsg); err != nil { + return 0, err } return n, nil @@ -245,64 +169,21 @@ func (c *conn) Write(b []byte) (int, error) { return 0, nil } - typ := messageType(opcode(b)) - // If we are in process of parsing a message fragment, the received - // bytes are not structured as a message fragment and can not be used to - // determine a message fragment. - if c.writeBufHasIncompleteFragment() { // buffer contains previous incomplete fragment - var err error - if typ, err = c.curWriteMsgType(); err != nil { - return 0, err - } - } - - if isControlMessage(typ) { - return c.Conn.Write(b) - } - - writeMsg := &message{typ: typ} // start a new message... - // ... or continue the existing one if it has not been finalized. - if c.writeMsgIsIncomplete() || c.writeBufHasIncompleteFragment() { - writeMsg = c.currentWriteMsg - } - if _, err := c.writeBuf.Write(b); err != nil { c.log.Errorf("write: error writing to write buf: %v", err) return 0, fmt.Errorf("[unexpected] error writing to internal write buffer: %w", err) } - ok, err := writeMsg.Parse(c.writeBuf.Bytes(), c.log) + raw, err := c.processFrames(&c.writeBuf, &c.currentWriteMsg) if err != nil { - c.log.Errorf("write: parsing a message errored: %v", err) - return 0, fmt.Errorf("write: error parsing message: %v", err) - } - - c.currentWriteMsg = writeMsg - if !ok { // incomplete fragment - return len(b), nil + return 0, err } - - c.writeBuf.Next(len(writeMsg.raw)) // advance frame - - if len(writeMsg.payload) != 0 && writeMsg.isFinalized { - if writeMsg.streamID.Load() == remotecommand.StreamStdOut || writeMsg.streamID.Load() == remotecommand.StreamStdErr { - // we must wait for confirmation that the initial cast header was sent before proceeding with any more writes - select { - case <-c.ctx.Done(): - return 0, c.ctx.Err() - case <-c.initialCastHeaderSent: - if err := c.rec.WriteOutput(writeMsg.payload); err != nil { - return 0, fmt.Errorf("error writing message to recorder: %w", err) - } - } + if len(raw) > 0 { + if _, err := c.Conn.Write(raw); err != nil { + return 0, err } } - _, err = c.Conn.Write(c.currentWriteMsg.raw) - if err != nil { - c.log.Errorf("write: error writing to conn: %v", err) - } - return len(b), nil } @@ -318,48 +199,125 @@ func (c *conn) Close() error { return errors.Join(connCloseErr, recCloseErr) } -// writeBufHasIncompleteFragment returns true if the latest data message -// fragment written to the connection was incomplete and the following write -// must be the remaining payload bytes of that fragment. -func (c *conn) writeBufHasIncompleteFragment() bool { - return c.writeBuf.Len() != 0 +// handleData records a finalized data message to the session recorder. +// It handles resize messages (updating terminal dimensions and writing the +// CastHeader on the first one) and stdout/stderr messages (recording output). +// Other stream IDs (stdin, error) are ignored. +func (c *conn) handleData(msg *message) error { + switch msg.streamID.Load() { + case remotecommand.StreamResize: + if !c.hasTerm { + return nil + } + var rm tsrecorder.ResizeMsg + if err := json.Unmarshal(msg.payload, &rm); err != nil { + return fmt.Errorf("error unmarshalling resize message: %w", err) + } + c.ch.Width = rm.Width + c.ch.Height = rm.Height + + // The first resize writes the CastHeader and unblocks output recording. + var headerErr error + var isInitialResize bool + c.writeCastHeaderOnce.Do(func() { + isInitialResize = true + headerErr = c.rec.WriteCastHeader(c.ch) + close(c.initialCastHeaderSent) + }) + if headerErr != nil { + return fmt.Errorf("error writing CastHeader: %w", headerErr) + } + if !isInitialResize { + if err := c.rec.WriteResize(rm.Height, rm.Width); err != nil { + return fmt.Errorf("error writing resize message: %w", err) + } + } + case remotecommand.StreamStdOut, remotecommand.StreamStdErr: + // Wait for the CastHeader before recording any output. + select { + case <-c.ctx.Done(): + return c.ctx.Err() + case <-c.initialCastHeaderSent: + if err := c.rec.WriteOutput(msg.payload); err != nil { + return fmt.Errorf("error writing message to recorder: %w", err) + } + } + } + return nil } -// readBufHasIncompleteFragment returns true if the latest data message -// fragment read from the connection was incomplete and the following read -// must be the remaining payload bytes of that fragment. -func (c *conn) readBufHasIncompleteFragment() bool { - return c.readBuf.Len() != 0 -} +// processFrames drains complete WebSocket frames from buf, recording session +// data via handleData for finalized binary messages. It returns the raw bytes +// of every consumed frame so the Write path can forward them to the underlying +// connection. Incomplete frames are left in buf for the next call. +// +// Control frames are consumed whole without inspection. Non-binary data frames +// are unexpected (k8s only uses binary) and cause the buffer to be discarded. +func (c *conn) processFrames( + buf *bytes.Buffer, + curMsg **message, +) ([]byte, error) { + var raw []byte + for buf.Len() != 0 { + b := buf.Bytes() + if len(b) < 2 { + return raw, nil + } -// writeMsgIsIncomplete returns true if the latest WebSocket message written to -// the connection was fragmented and the next data message fragment written to -// the connection must be a fragment of that message. -// https://www.rfc-editor.org/rfc/rfc6455#section-5.4 -func (c *conn) writeMsgIsIncomplete() bool { - return c.currentWriteMsg != nil && !c.currentWriteMsg.isFinalized -} + // Continuation frames (opcode 0) inherit the type of the in-progress message. + typ := messageType(opcode(b)) + if typ == noOpcode && *curMsg != nil { + typ = (*curMsg).typ + } -// readMsgIsIncomplete returns true if the latest WebSocket message written to -// the connection was fragmented and the next data message fragment written to -// the connection must be a fragment of that message. -// https://www.rfc-editor.org/rfc/rfc6455#section-5.4 -func (c *conn) readMsgIsIncomplete() bool { - return c.currentReadMsg != nil && !c.currentReadMsg.isFinalized -} + // Control frames: pass through without inspection. + if isControlMessage(typ) { + maskSet := isMasked(b) + payloadLen, payloadOffset, _, err := fragmentDimensions(b, maskSet) + if err != nil { + return nil, fmt.Errorf("error parsing control frame: %w", err) + } + frameLen := int(payloadOffset + payloadLen) + if len(b) < frameLen { + return raw, nil // incomplete control frame + } + raw = append(raw, b[:frameLen]...) + buf.Next(frameLen) + continue + } -func (c *conn) curReadMsgType() (messageType, error) { - if c.currentReadMsg != nil { - return c.currentReadMsg.typ, nil - } - return 0, errors.New("[unexpected] attempted to determine type for nil message") -} + // k8s remotecommand only uses binary data messages. + if typ != binaryMessage { + c.log.Infof("[unexpected] received a data message with a type that is not binary message type %v", typ) + buf.Reset() + return raw, nil + } -func (c *conn) curWriteMsgType() (messageType, error) { - if c.currentWriteMsg != nil { - return c.currentWriteMsg.typ, nil + // Continue a fragmented message or start a new one. + msg := &message{typ: typ} + if *curMsg != nil && !(*curMsg).isFinalized { + msg = *curMsg + } + + ok, err := msg.Parse(b, c.log) + if err != nil { + return nil, fmt.Errorf("error parsing message: %w", err) + } + if !ok { + *curMsg = msg + return raw, nil // incomplete fragment, wait for more bytes + } + buf.Next(len(msg.raw)) + *curMsg = msg + + raw = append(raw, msg.raw...) + if msg.isFinalized && len(msg.payload) > 0 { + if err := c.handleData(msg); err != nil { + return nil, err + } + } } - return 0, errors.New("[unexpected] attempted to determine type for nil message") + return raw, nil } // opcode reads the websocket message opcode that denotes the message type. diff --git a/k8s-operator/sessionrecording/ws/conn_test.go b/k8s-operator/sessionrecording/ws/conn_test.go index 0b4353698..ea9aca192 100644 --- a/k8s-operator/sessionrecording/ws/conn_test.go +++ b/k8s-operator/sessionrecording/ws/conn_test.go @@ -37,6 +37,17 @@ func Test_conn_Read(t *testing.T) { wantCastHeaderHeight int wantRecorded []byte }{ + // Empty final continuation frame after a resize frame. + { + name: "continuation_frame_with_empty_payload", + inputs: [][]byte{ + append([]byte{0x02, lenResizeMsgPayload}, testResizeMsg...), + {0x80, 0x00}, + }, + wantRecorded: fakes.AsciinemaCastHeaderMsg(t, 10, 20), + wantCastHeaderWidth: 10, + wantCastHeaderHeight: 20, + }, { name: "single_read_control_message", inputs: [][]byte{{0x88, 0x0}}, @@ -58,6 +69,19 @@ func Test_conn_Read(t *testing.T) { wantCastHeaderWidth: 10, wantCastHeaderHeight: 20, }, + { + // A control frame (close) followed by a resize data frame in + // a single Read. Without the frame loop, the close frame + // would cause the data frame to be skipped. + name: "control_then_data_in_one_read", + inputs: [][]byte{ + // close frame (0x88, len 0), then resize data frame + append([]byte{0x88, 0x00, 0x82, lenResizeMsgPayload}, testResizeMsg...), + }, + wantRecorded: fakes.AsciinemaCastHeaderMsg(t, 10, 20), + wantCastHeaderWidth: 10, + wantCastHeaderHeight: 20, + }, { name: "resize_data_frame_two_in_one_read", inputs: [][]byte{ @@ -156,6 +180,26 @@ func Test_conn_Write(t *testing.T) { wantRecorded []byte hasTerm bool }{ + // Empty final continuation frame; stream ID already set by + // the initial fragment. + { + name: "continuation_frame_with_empty_payload", + inputs: [][]byte{ + {0x02, 0x03, 0x01, 0x07, 0x08}, + {0x80, 0x00}, + }, + wantForwarded: []byte{0x02, 0x03, 0x01, 0x07, 0x08, 0x80, 0x00}, + wantRecorded: fakes.CastLine(t, []byte{0x07, 0x08}, cl), + }, + // Same as above but both fragments land in one Write call. + { + name: "continuation_frame_with_empty_payload_single_write", + inputs: [][]byte{ + {0x02, 0x03, 0x01, 0x07, 0x08, 0x80, 0x00}, + }, + wantForwarded: []byte{0x02, 0x03, 0x01, 0x07, 0x08, 0x80, 0x00}, + wantRecorded: fakes.CastLine(t, []byte{0x07, 0x08}, cl), + }, { name: "single_write_control_frame", inputs: [][]byte{{0x88, 0x0}}, @@ -203,6 +247,38 @@ func Test_conn_Write(t *testing.T) { wantRecorded: fakes.CastLine(t, []byte{0x7, 0x8, 0x1, 0x2, 0x3, 0x4, 0x5}, cl), hasTerm: true, }, + { + // Two complete WebSocket frames coalesced into a single + // Write() call: a stdout binary frame followed by a close + // frame. Without a loop in the Write path, the close frame + // gets stranded in writeBuf and misinterpreted on the next + // Write. + name: "two_frames_in_one_write_data_then_close", + inputs: [][]byte{ + // binary frame (opcode 0x2, FIN set = 0x82), payload len 3, + // stream ID 1 (stdout), two data bytes, + // then close frame (opcode 0x8, FIN set = 0x88), payload len 0 + {0x82, 0x03, 0x01, 0x07, 0x08, 0x88, 0x00}, + }, + wantForwarded: []byte{0x82, 0x03, 0x01, 0x07, 0x08, 0x88, 0x00}, + wantRecorded: fakes.CastLine(t, []byte{0x07, 0x08}, cl), + }, + { + // Two complete stdout data frames in one Write() call. + // Mirrors the "resize_data_frame_two_in_one_read" test + // for the Read path. + name: "two_data_frames_in_one_write", + inputs: [][]byte{ + // first: binary frame, payload len 3, stdout stream, two data bytes + // second: binary frame, payload len 3, stdout stream, two different data bytes + {0x82, 0x03, 0x01, 0x07, 0x08, 0x82, 0x03, 0x01, 0x09, 0x0a}, + }, + wantForwarded: []byte{0x82, 0x03, 0x01, 0x07, 0x08, 0x82, 0x03, 0x01, 0x09, 0x0a}, + wantRecorded: append( + fakes.CastLine(t, []byte{0x07, 0x08}, cl), + fakes.CastLine(t, []byte{0x09, 0x0a}, cl)..., + ), + }, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { @@ -254,12 +330,20 @@ func Test_conn_ReadRand(t *testing.T) { if err != nil { t.Fatalf("error creating a test logger: %v", err) } + cl := tstest.NewClock(tstest.ClockOpts{}) + sr := &fakes.TestSessionRecorder{} + rec := tsrecorder.New(sr, cl, cl.Now(), true, zl.Sugar()) for i := range 100 { tc := &fakes.TestConn{} tc.ResetReadBuf() + headerSent := make(chan struct{}) + close(headerSent) // pre-close so handleData doesn't block c := &conn{ - Conn: tc, - log: zl.Sugar(), + Conn: tc, + log: zl.Sugar(), + ctx: context.Background(), + rec: rec, + initialCastHeaderSent: headerSent, } bb := fakes.RandomBytes(t) for j, input := range bb { diff --git a/k8s-operator/sessionrecording/ws/message.go b/k8s-operator/sessionrecording/ws/message.go index 36359996a..47177ef19 100644 --- a/k8s-operator/sessionrecording/ws/message.go +++ b/k8s-operator/sessionrecording/ws/message.go @@ -99,19 +99,19 @@ func (msg *message) Parse(b []byte, log *zap.SugaredLogger) (bool, error) { } isInitialFragment := len(msg.raw) == 0 - msg.isFinalized = isFinalFragment(b) - + finalized := isFinalFragment(b) maskSet := isMasked(b) payloadLength, payloadOffset, maskOffset, err := fragmentDimensions(b, maskSet) if err != nil { return false, fmt.Errorf("error determining payload length: %w", err) } - log.Debugf("parse: parsing a message fragment with payload length: %d payload offset: %d maskOffset: %d mask set: %t, is finalized: %t, is initial fragment: %t", payloadLength, payloadOffset, maskOffset, maskSet, msg.isFinalized, isInitialFragment) + log.Debugf("parse: parsing a message fragment with payload length: %d payload offset: %d maskOffset: %d mask set: %t, is finalized: %t, is initial fragment: %t", payloadLength, payloadOffset, maskOffset, maskSet, finalized, isInitialFragment) if len(b) < int(payloadOffset+payloadLength) { // incomplete fragment return false, nil } + msg.isFinalized = finalized // TODO (irbekrm): perhaps only do this extra allocation if we know we // will need to unmask? msg.raw = make([]byte, int(payloadOffset)+int(payloadLength)) @@ -136,6 +136,13 @@ func (msg *message) Parse(b []byte, log *zap.SugaredLogger) (bool, error) { // message payload. // https://github.com/kubernetes/apimachinery/commit/73d12d09c5be8703587b5127416eb83dc3b7e182#diff-291f96e8632d04d2d20f5fb00f6b323492670570d65434e8eac90c7a442d13bdR23-R36 if len(msgPayload) == 0 { + if !isInitialFragment { + // Continuation frame with zero payload. The stream ID is + // already known from the initial fragment, so this is not + // fatal, just unusual. + log.Infof("[unexpected] received a continuation fragment with no payload") + return true, nil + } return false, errors.New("[unexpected] received a message fragment with no stream ID") }