diff --git a/packages/http/src/websocket/websocket.test.ts b/packages/http/src/websocket/websocket.test.ts index d5f6fbd..6c452cb 100644 --- a/packages/http/src/websocket/websocket.test.ts +++ b/packages/http/src/websocket/websocket.test.ts @@ -155,6 +155,41 @@ suite("WebSocket", { skip: skipIfNotIntegration }, () => { /Sec-WebSocket-Version/, ) }) + + test("sets Sec-WebSocket-Protocol response header from options.protocols[0]", async () => { + const [clientTransport, serverTransport] = loopbackTransportPair() + const serverConn = new ServerConnection(serverTransport) + const clientConn = new ClientConnection(clientTransport) + let receivedHeader = "" + + await Promise.all([ + serverConn.handle(async ({ req, res }) => { + const ws = await upgradeWebSocket(req, res, { protocols: ["chat"] }) + await ws.close() + for await (const _ of ws) { /* drain */ } + }), + (async () => { + const key = generateKey() + const req = new ClientRequestImpl({ + method: "GET", + target: "/", + headers: { + Host: "localhost", + Upgrade: "websocket", + Connection: "Upgrade", + "Sec-WebSocket-Key": key, + "Sec-WebSocket-Version": "13", + }, + }) + const res = await clientConn.request(req) + receivedHeader = (res.getHeader("sec-websocket-protocol") as string) ?? "" + const ws = await connectWebSocket(res, key) + for await (const _ of ws) { /* drain */ } + })(), + ]) + + assert.strictEqual(receivedHeader, "chat") + }) }) suite("connectWebSocket(dialer, url) — URL validation", () => { @@ -290,6 +325,52 @@ suite("WebSocket", { skip: skipIfNotIntegration }, () => { ), ]) }) + + test("sends Sec-WebSocket-Protocol header when protocols option is provided", async () => { + const [listener, dialer] = loopbackListener() + let receivedProtocol = "" + + await Promise.all([ + (async () => { + const t = await listener.accept() + listener.close() + const buf = new ReadBuffer(t) + await buf.readLine() // request line + let line: string + while ((line = await buf.readLine()) !== "") { + if (line.toLowerCase().startsWith("sec-websocket-protocol:")) + receivedProtocol = line.slice(line.indexOf(":") + 1).trim() + } + await t.close() + })(), + connectWebSocket(dialer, "ws://localhost/", { protocols: ["chat", "superchat"] }).catch(() => {}), + ]) + + assert.strictEqual(receivedProtocol, "chat, superchat") + }) + + test("sends Authorization header for URL with credentials", async () => { + const [listener, dialer] = loopbackListener() + let authHeader = "" + + await Promise.all([ + (async () => { + const t = await listener.accept() + listener.close() + const buf = new ReadBuffer(t) + await buf.readLine() // request line + let line: string + while ((line = await buf.readLine()) !== "") { + if (line.toLowerCase().startsWith("authorization:")) + authHeader = line.slice(line.indexOf(":") + 1).trim() + } + await t.close() + })(), + connectWebSocket(dialer, "ws://alice:secret@localhost/").catch(() => {}), + ]) + + assert.strictEqual(authHeader, `Basic ${btoa("alice:secret")}`) + }) }) suite("connectWebSocket(res, key) — response-based overload", () => { @@ -541,6 +622,41 @@ suite("WebSocket", { skip: skipIfNotIntegration }, () => { }, ) }) + + test("abrupt transport close without Close frame exits iterator cleanly", async () => { + // Exercises the catch { break } path in the async iterator (common.ts readFrame error). + const [listener, dialer] = loopbackListener() + let ended = false + + await Promise.all([ + (async () => { + const t = await listener.accept() + listener.close() + // Manually perform the WS handshake, then close abruptly without a Close frame. + const buf = new ReadBuffer(t) + let key = "" + let line: string + while ((line = await buf.readLine()) !== "") { + if (line.toLowerCase().startsWith("sec-websocket-key:")) + key = line.slice(line.indexOf(":") + 1).trim() + } + const accept = await computeAccept(key) + await t.write( + enc.encode( + `HTTP/1.1 101 Switching Protocols\r\nUpgrade: websocket\r\nConnection: Upgrade\r\nSec-WebSocket-Accept: ${accept}\r\n\r\n`, + ), + ) + await t.close() // abrupt close — no WS Close frame + })(), + (async () => { + const ws = await connectWebSocket(dialer, "ws://localhost/") + for await (const _ of ws) { /* drain */ } + ended = true + })(), + ]) + + assert.ok(ended) + }) }) suite("fragmented messages", () => {