From 7b3474abff37d9ebbe11e48ed8aa38bd077bdf2e Mon Sep 17 00:00:00 2001 From: Alex Dima Date: Thu, 25 Nov 2021 16:44:30 +0100 Subject: [PATCH] Make sure websocket frames are processed in order --- src/vs/base/parts/ipc/node/ipc.net.ts | 290 ++++++++++-------- .../base/parts/ipc/test/node/ipc.net.test.ts | 9 + 2 files changed, 171 insertions(+), 128 deletions(-) diff --git a/src/vs/base/parts/ipc/node/ipc.net.ts b/src/vs/base/parts/ipc/node/ipc.net.ts index 4b3ae6bdef9..af596a0fed6 100644 --- a/src/vs/base/parts/ipc/node/ipc.net.ts +++ b/src/vs/base/parts/ipc/node/ipc.net.ts @@ -184,13 +184,7 @@ interface ISocketTracer { export class WebSocketNodeSocket extends Disposable implements ISocket, ISocketTracer { public readonly socket: NodeSocket; - public readonly permessageDeflate: boolean; - private _totalIncomingWireBytes: number = 0; - private _totalIncomingDataBytes: number = 0; - private _totalOutgoingWireBytes: number = 0; - private _totalOutgoingDataBytes: number = 0; - private readonly _zlibInflateStream: ZlibInflateStream | null; - private readonly _zlibDeflateStream: ZlibDeflateStream | null; + private readonly _flowManager: WebSocketFlowManager; private readonly _incomingData: ChunkStream; private readonly _onData = this._register(new Emitter()); private readonly _onClose = this._register(new Emitter()); @@ -205,27 +199,12 @@ export class WebSocketNodeSocket extends Disposable implements ISocket, ISocketT mask: 0 }; - public get totalIncomingWireBytes(): number { - return this._totalIncomingWireBytes; - } - - public get totalIncomingDataBytes(): number { - return this._totalIncomingDataBytes; - } - - public get totalOutgoingWireBytes(): number { - return this._totalOutgoingWireBytes; - } - - public get totalOutgoingDataBytes(): number { - return this._totalOutgoingDataBytes; + public get permessageDeflate(): boolean { + return this._flowManager.permessageDeflate; } public get recordedInflateBytes(): VSBuffer { - if (this._zlibInflateStream) { - return this._zlibInflateStream.recordedInflateBytes; - } - return VSBuffer.alloc(0); + return this._flowManager.recordedInflateBytes; } public traceSocketEvent(type: SocketDiagnosticsEventType, data?: VSBuffer | Uint8Array | ArrayBuffer | ArrayBufferView | any): void { @@ -248,55 +227,33 @@ export class WebSocketNodeSocket extends Disposable implements ISocket, ISocketT super(); this.socket = socket; this.traceSocketEvent(SocketDiagnosticsEventType.Created, { type: 'WebSocketNodeSocket', permessageDeflate, inflateBytesLength: inflateBytes?.byteLength || 0, recordInflateBytes }); - this.permessageDeflate = permessageDeflate; - if (permessageDeflate) { - // See https://tools.ietf.org/html/rfc7692#page-16 - // To simplify our logic, we don't negotiate the window size - // and simply dedicate (2^15) / 32kb per web socket - this._zlibInflateStream = this._register(new ZlibInflateStream(this, recordInflateBytes, inflateBytes, { - windowBits: 15 - })); - this._register(this._zlibInflateStream.onError((err) => { - // zlib errors are fatal, since we have no idea how to recover - console.error(err); - onUnexpectedError(err); - this._onClose.fire({ - type: SocketCloseEventType.NodeSocketCloseEvent, - hadError: true, - error: err - }); - })); - this._register(this._zlibInflateStream.onData((data) => { - this._totalIncomingDataBytes += data.byteLength; - this._onData.fire(data); - })); - - this._zlibDeflateStream = this._register(new ZlibDeflateStream(this, { - windowBits: 15 - })); - this._register(this._zlibDeflateStream.onError((err) => { - // zlib errors are fatal, since we have no idea how to recover - console.error(err); - onUnexpectedError(err); - this._onClose.fire({ - type: SocketCloseEventType.NodeSocketCloseEvent, - hadError: true, - error: err - }); - })); - } else { - this._zlibInflateStream = null; - this._zlibDeflateStream = null; - } + this._flowManager = this._register(new WebSocketFlowManager( + this, + permessageDeflate, + inflateBytes, + recordInflateBytes, + this._onData, + (data, compressed) => this._write(data, compressed) + )); + this._register(this._flowManager.onError((err) => { + // zlib errors are fatal, since we have no idea how to recover + console.error(err); + onUnexpectedError(err); + this._onClose.fire({ + type: SocketCloseEventType.NodeSocketCloseEvent, + hadError: true, + error: err + }); + })); this._incomingData = new ChunkStream(); this._register(this.socket.onData(data => this._acceptChunk(data))); this._register(this.socket.onClose((e) => this._onClose.fire(e))); } public override dispose(): void { - if (this._zlibDeflateStream && this._zlibDeflateStream.needsDraining()) { + if (this._flowManager.isProcessingWriteQueue()) { // Wait for any outstanding writes to finish before disposing - this._register(this._zlibDeflateStream.onDidDrain(() => { + this._register(this._flowManager.onDidFinishProcessingWriteQueue(() => { this.dispose(); })); } else { @@ -318,22 +275,15 @@ export class WebSocketNodeSocket extends Disposable implements ISocket, ISocketT } public write(buffer: VSBuffer): void { - this._totalOutgoingDataBytes += buffer.byteLength; - - if (this._zlibDeflateStream) { - this._zlibDeflateStream.write(buffer); - this._zlibDeflateStream.flush((data) => { - if (!this._isEnded) { - // Avoid ERR_STREAM_WRITE_AFTER_END - this._write(data, true); - } - }); - } else { - this._write(buffer, false); - } + this._flowManager.writeMessage(buffer); } private _write(buffer: VSBuffer, compressed: boolean): void { + if (this._isEnded) { + // Avoid ERR_STREAM_WRITE_AFTER_END + return; + } + this.traceSocketEvent(SocketDiagnosticsEventType.WebSocketNodeSocketWrite, buffer); let headerLen = Constants.MinHeaderByteSize; if (buffer.byteLength < 126) { @@ -371,7 +321,6 @@ export class WebSocketNodeSocket extends Disposable implements ISocket, ISocketT header.writeUInt8((buffer.byteLength >>> 0) & 0b11111111, ++offset); } - this._totalOutgoingWireBytes += header.byteLength + buffer.byteLength; this.socket.write(VSBuffer.concat([header, buffer])); } @@ -384,7 +333,6 @@ export class WebSocketNodeSocket extends Disposable implements ISocket, ISocketT if (data.byteLength === 0) { return; } - this._totalIncomingWireBytes += data.byteLength; this._incomingData.acceptChunk(data); @@ -467,45 +415,153 @@ export class WebSocketNodeSocket extends Disposable implements ISocket, ISocketT this._state.readLen = Constants.MinHeaderByteSize; this._state.mask = 0; - if (this._zlibInflateStream && this._state.compressed) { - // See https://datatracker.ietf.org/doc/html/rfc7692#section-9.2 - // Even if permessageDeflate is negotiated, it is possible - // that the other side might decide to send uncompressed messages - // So only decompress messages that have the RSV 1 bit set - // - // See https://tools.ietf.org/html/rfc7692#section-7.2.2 - - this._zlibInflateStream.write(body); - if (this._state.fin) { - this._zlibInflateStream.write(VSBuffer.fromByteArray([0x00, 0x00, 0xff, 0xff])); - } - this._zlibInflateStream.flush(); - } else { - this._totalIncomingDataBytes += body.byteLength; - this._onData.fire(body); - } + this._flowManager.acceptFrame(body, this._state.compressed, !!this._state.fin); } } } public async drain(): Promise { this.traceSocketEvent(SocketDiagnosticsEventType.WebSocketNodeSocketDrainBegin); - if (this._zlibDeflateStream) { - await this._zlibDeflateStream.drain(); + if (this._flowManager.isProcessingWriteQueue()) { + await Event.toPromise(this._flowManager.onDidFinishProcessingWriteQueue); } await this.socket.drain(); this.traceSocketEvent(SocketDiagnosticsEventType.WebSocketNodeSocketDrainEnd); } } +class WebSocketFlowManager extends Disposable { + + private readonly _onError = this._register(new Emitter()); + public readonly onError = this._onError.event; + + private readonly _zlibInflateStream: ZlibInflateStream | null; + private readonly _zlibDeflateStream: ZlibDeflateStream | null; + private readonly _writeQueue: VSBuffer[] = []; + private readonly _readQueue: { data: VSBuffer, isCompressed: boolean, isLastFrameOfMessage: boolean }[] = []; + + private readonly _onDidFinishProcessingWriteQueue = this._register(new Emitter()); + public readonly onDidFinishProcessingWriteQueue = this._onDidFinishProcessingWriteQueue.event; + + public get permessageDeflate(): boolean { + return Boolean(this._zlibInflateStream && this._zlibDeflateStream); + } + + public get recordedInflateBytes(): VSBuffer { + if (this._zlibInflateStream) { + return this._zlibInflateStream.recordedInflateBytes; + } + return VSBuffer.alloc(0); + } + + constructor( + private readonly _tracer: ISocketTracer, + permessageDeflate: boolean, + inflateBytes: VSBuffer | null, + recordInflateBytes: boolean, + private readonly _onData: Emitter, + private readonly _writeFn: (data: VSBuffer, compressed: boolean) => void + ) { + super(); + if (permessageDeflate) { + // See https://tools.ietf.org/html/rfc7692#page-16 + // To simplify our logic, we don't negotiate the window size + // and simply dedicate (2^15) / 32kb per web socket + this._zlibInflateStream = this._register(new ZlibInflateStream(this._tracer, recordInflateBytes, inflateBytes, { windowBits: 15 })); + this._zlibDeflateStream = this._register(new ZlibDeflateStream(this._tracer, { windowBits: 15 })); + this._register(this._zlibInflateStream.onError((err) => this._onError.fire(err))); + this._register(this._zlibDeflateStream.onError((err) => this._onError.fire(err))); + } else { + this._zlibInflateStream = null; + this._zlibDeflateStream = null; + } + } + + public writeMessage(message: VSBuffer): void { + this._writeQueue.push(message); + this._processWriteQueue(); + } + + private _isProcessingWriteQueue = false; + private async _processWriteQueue(): Promise { + if (this._isProcessingWriteQueue) { + return; + } + this._isProcessingWriteQueue = true; + while (this._writeQueue.length > 0) { + const message = this._writeQueue.shift()!; + if (this._zlibDeflateStream) { + const data = await this._deflateMessage(this._zlibDeflateStream, message); + this._writeFn(data, true); + } else { + this._writeFn(message, false); + } + } + this._isProcessingWriteQueue = false; + this._onDidFinishProcessingWriteQueue.fire(); + } + + public isProcessingWriteQueue(): boolean { + return (this._isProcessingWriteQueue); + } + + /** + * Subsequent calls should wait for the previous `_deflateBuffer` call to complete. + */ + private _deflateMessage(zlibDeflateStream: ZlibDeflateStream, buffer: VSBuffer): Promise { + return new Promise((resolve, reject) => { + zlibDeflateStream.write(buffer); + zlibDeflateStream.flush(data => resolve(data)); + }); + } + + public acceptFrame(data: VSBuffer, isCompressed: boolean, isLastFrameOfMessage: boolean): void { + this._readQueue.push({ data, isCompressed, isLastFrameOfMessage }); + this._processReadQueue(); + } + + private _isProcessingReadQueue = false; + private async _processReadQueue(): Promise { + if (this._isProcessingReadQueue) { + return; + } + this._isProcessingReadQueue = true; + while (this._readQueue.length > 0) { + const frameInfo = this._readQueue.shift()!; + if (this._zlibInflateStream && frameInfo.isCompressed) { + // See https://datatracker.ietf.org/doc/html/rfc7692#section-9.2 + // Even if permessageDeflate is negotiated, it is possible + // that the other side might decide to send uncompressed messages + // So only decompress messages that have the RSV 1 bit set + const data = await this._inflateFrame(this._zlibInflateStream, frameInfo.data, frameInfo.isLastFrameOfMessage); + this._onData.fire(data); + } else { + this._onData.fire(frameInfo.data); + } + } + this._isProcessingReadQueue = false; + } + + /** + * Subsequent calls should wait for the previous `transformRead` call to complete. + */ + private _inflateFrame(zlibInflateStream: ZlibInflateStream, buffer: VSBuffer, isLastFrameOfMessage: boolean): Promise { + return new Promise((resolve, reject) => { + // See https://tools.ietf.org/html/rfc7692#section-7.2.2 + zlibInflateStream.write(buffer); + if (isLastFrameOfMessage) { + zlibInflateStream.write(VSBuffer.fromByteArray([0x00, 0x00, 0xff, 0xff])); + } + zlibInflateStream.flush(data => resolve(data)); + }); + } +} + class ZlibInflateStream extends Disposable { private readonly _onError = this._register(new Emitter()); public readonly onError = this._onError.event; - private readonly _onData = this._register(new Emitter()); - public readonly onData = this._onData.event; - private readonly _zlibInflate: zlib.InflateRaw; private readonly _recordedInflateBytes: VSBuffer[] = []; private readonly _pendingInflateData: VSBuffer[] = []; @@ -551,12 +607,12 @@ class ZlibInflateStream extends Disposable { this._zlibInflate.write(buffer.buffer); } - public flush(): void { + public flush(callback: (data: VSBuffer) => void): void { this._zlibInflate.flush(() => { this._tracer.traceSocketEvent(SocketDiagnosticsEventType.zlibInflateFlushFired); const data = VSBuffer.concat(this._pendingInflateData); this._pendingInflateData.length = 0; - this._onData.fire(data); + callback(data); }); } } @@ -566,12 +622,8 @@ class ZlibDeflateStream extends Disposable { private readonly _onError = this._register(new Emitter()); public readonly onError = this._onError.event; - private readonly _onDidDrain = this._register(new Emitter()); - public readonly onDidDrain = this._onDidDrain.event; - private readonly _zlibDeflate: zlib.DeflateRaw; private readonly _pendingDeflateData: VSBuffer[] = []; - private _flushWaitingCount: number = 0; constructor( private readonly _tracer: ISocketTracer, @@ -598,12 +650,8 @@ class ZlibDeflateStream extends Disposable { } public flush(callback: (data: VSBuffer) => void): void { - this._flushWaitingCount++; - // See https://zlib.net/manual.html#Constants this._zlibDeflate.flush(/*Z_SYNC_FLUSH*/2, () => { - this._flushWaitingCount--; - this._tracer.traceSocketEvent(SocketDiagnosticsEventType.zlibDeflateFlushFired); let data = VSBuffer.concat(this._pendingDeflateData); @@ -613,22 +661,8 @@ class ZlibDeflateStream extends Disposable { data = data.slice(0, data.byteLength - 4); callback(data); - - if (this._flushWaitingCount === 0) { - this._onDidDrain.fire(); - } }); } - - public needsDraining(): boolean { - return (this._flushWaitingCount > 0); - } - - public async drain(): Promise { - if (this._flushWaitingCount > 0) { - await Event.toPromise(this.onDidDrain); - } - } } function unmask(buffer: VSBuffer, mask: number): void { diff --git a/src/vs/base/parts/ipc/test/node/ipc.net.test.ts b/src/vs/base/parts/ipc/test/node/ipc.net.test.ts index 4345fcccaca..042bfcbabfe 100644 --- a/src/vs/base/parts/ipc/test/node/ipc.net.test.ts +++ b/src/vs/base/parts/ipc/test/node/ipc.net.test.ts @@ -525,5 +525,14 @@ suite('WebSocketNodeSocket', () => { const actual = await testReading(frames, true); assert.deepStrictEqual(actual, 'Hello'); }); + + test('A single-frame compressed text message followed by a single-frame non-compressed text message', async () => { + const frames = [ + [0xc1, 0x07, 0xf2, 0x48, 0xcd, 0xc9, 0xc9, 0x07, 0x00], // contains "Hello" + [0x81, 0x05, 0x77, 0x6f, 0x72, 0x6c, 0x64] // contains "world" + ]; + const actual = await testReading(frames, true); + assert.deepStrictEqual(actual, 'Helloworld'); + }); }); });