Make sure websocket frames are processed in order

This commit is contained in:
Alex Dima 2021-11-25 16:44:30 +01:00
parent a3dce400d6
commit 7b3474abff
No known key found for this signature in database
GPG key ID: 39563C1504FDD0C9
2 changed files with 171 additions and 128 deletions

View file

@ -184,13 +184,7 @@ interface ISocketTracer {
export class WebSocketNodeSocket extends Disposable implements ISocket, ISocketTracer {
public readonly socket: NodeSocket;
public readonly permessageDeflate: boolean;
private _totalIncomingWireBytes: number = 0;
private _totalIncomingDataBytes: number = 0;
private _totalOutgoingWireBytes: number = 0;
private _totalOutgoingDataBytes: number = 0;
private readonly _zlibInflateStream: ZlibInflateStream | null;
private readonly _zlibDeflateStream: ZlibDeflateStream | null;
private readonly _flowManager: WebSocketFlowManager;
private readonly _incomingData: ChunkStream;
private readonly _onData = this._register(new Emitter<VSBuffer>());
private readonly _onClose = this._register(new Emitter<SocketCloseEvent>());
@ -205,27 +199,12 @@ export class WebSocketNodeSocket extends Disposable implements ISocket, ISocketT
mask: 0
};
public get totalIncomingWireBytes(): number {
return this._totalIncomingWireBytes;
}
public get totalIncomingDataBytes(): number {
return this._totalIncomingDataBytes;
}
public get totalOutgoingWireBytes(): number {
return this._totalOutgoingWireBytes;
}
public get totalOutgoingDataBytes(): number {
return this._totalOutgoingDataBytes;
public get permessageDeflate(): boolean {
return this._flowManager.permessageDeflate;
}
public get recordedInflateBytes(): VSBuffer {
if (this._zlibInflateStream) {
return this._zlibInflateStream.recordedInflateBytes;
}
return VSBuffer.alloc(0);
return this._flowManager.recordedInflateBytes;
}
public traceSocketEvent(type: SocketDiagnosticsEventType, data?: VSBuffer | Uint8Array | ArrayBuffer | ArrayBufferView | any): void {
@ -248,55 +227,33 @@ export class WebSocketNodeSocket extends Disposable implements ISocket, ISocketT
super();
this.socket = socket;
this.traceSocketEvent(SocketDiagnosticsEventType.Created, { type: 'WebSocketNodeSocket', permessageDeflate, inflateBytesLength: inflateBytes?.byteLength || 0, recordInflateBytes });
this.permessageDeflate = permessageDeflate;
if (permessageDeflate) {
// See https://tools.ietf.org/html/rfc7692#page-16
// To simplify our logic, we don't negotiate the window size
// and simply dedicate (2^15) / 32kb per web socket
this._zlibInflateStream = this._register(new ZlibInflateStream(this, recordInflateBytes, inflateBytes, {
windowBits: 15
}));
this._register(this._zlibInflateStream.onError((err) => {
// zlib errors are fatal, since we have no idea how to recover
console.error(err);
onUnexpectedError(err);
this._onClose.fire({
type: SocketCloseEventType.NodeSocketCloseEvent,
hadError: true,
error: err
});
}));
this._register(this._zlibInflateStream.onData((data) => {
this._totalIncomingDataBytes += data.byteLength;
this._onData.fire(data);
}));
this._zlibDeflateStream = this._register(new ZlibDeflateStream(this, {
windowBits: 15
}));
this._register(this._zlibDeflateStream.onError((err) => {
// zlib errors are fatal, since we have no idea how to recover
console.error(err);
onUnexpectedError(err);
this._onClose.fire({
type: SocketCloseEventType.NodeSocketCloseEvent,
hadError: true,
error: err
});
}));
} else {
this._zlibInflateStream = null;
this._zlibDeflateStream = null;
}
this._flowManager = this._register(new WebSocketFlowManager(
this,
permessageDeflate,
inflateBytes,
recordInflateBytes,
this._onData,
(data, compressed) => this._write(data, compressed)
));
this._register(this._flowManager.onError((err) => {
// zlib errors are fatal, since we have no idea how to recover
console.error(err);
onUnexpectedError(err);
this._onClose.fire({
type: SocketCloseEventType.NodeSocketCloseEvent,
hadError: true,
error: err
});
}));
this._incomingData = new ChunkStream();
this._register(this.socket.onData(data => this._acceptChunk(data)));
this._register(this.socket.onClose((e) => this._onClose.fire(e)));
}
public override dispose(): void {
if (this._zlibDeflateStream && this._zlibDeflateStream.needsDraining()) {
if (this._flowManager.isProcessingWriteQueue()) {
// Wait for any outstanding writes to finish before disposing
this._register(this._zlibDeflateStream.onDidDrain(() => {
this._register(this._flowManager.onDidFinishProcessingWriteQueue(() => {
this.dispose();
}));
} else {
@ -318,22 +275,15 @@ export class WebSocketNodeSocket extends Disposable implements ISocket, ISocketT
}
public write(buffer: VSBuffer): void {
this._totalOutgoingDataBytes += buffer.byteLength;
if (this._zlibDeflateStream) {
this._zlibDeflateStream.write(buffer);
this._zlibDeflateStream.flush((data) => {
if (!this._isEnded) {
// Avoid ERR_STREAM_WRITE_AFTER_END
this._write(data, true);
}
});
} else {
this._write(buffer, false);
}
this._flowManager.writeMessage(buffer);
}
private _write(buffer: VSBuffer, compressed: boolean): void {
if (this._isEnded) {
// Avoid ERR_STREAM_WRITE_AFTER_END
return;
}
this.traceSocketEvent(SocketDiagnosticsEventType.WebSocketNodeSocketWrite, buffer);
let headerLen = Constants.MinHeaderByteSize;
if (buffer.byteLength < 126) {
@ -371,7 +321,6 @@ export class WebSocketNodeSocket extends Disposable implements ISocket, ISocketT
header.writeUInt8((buffer.byteLength >>> 0) & 0b11111111, ++offset);
}
this._totalOutgoingWireBytes += header.byteLength + buffer.byteLength;
this.socket.write(VSBuffer.concat([header, buffer]));
}
@ -384,7 +333,6 @@ export class WebSocketNodeSocket extends Disposable implements ISocket, ISocketT
if (data.byteLength === 0) {
return;
}
this._totalIncomingWireBytes += data.byteLength;
this._incomingData.acceptChunk(data);
@ -467,45 +415,153 @@ export class WebSocketNodeSocket extends Disposable implements ISocket, ISocketT
this._state.readLen = Constants.MinHeaderByteSize;
this._state.mask = 0;
if (this._zlibInflateStream && this._state.compressed) {
// See https://datatracker.ietf.org/doc/html/rfc7692#section-9.2
// Even if permessageDeflate is negotiated, it is possible
// that the other side might decide to send uncompressed messages
// So only decompress messages that have the RSV 1 bit set
//
// See https://tools.ietf.org/html/rfc7692#section-7.2.2
this._zlibInflateStream.write(body);
if (this._state.fin) {
this._zlibInflateStream.write(VSBuffer.fromByteArray([0x00, 0x00, 0xff, 0xff]));
}
this._zlibInflateStream.flush();
} else {
this._totalIncomingDataBytes += body.byteLength;
this._onData.fire(body);
}
this._flowManager.acceptFrame(body, this._state.compressed, !!this._state.fin);
}
}
}
public async drain(): Promise<void> {
this.traceSocketEvent(SocketDiagnosticsEventType.WebSocketNodeSocketDrainBegin);
if (this._zlibDeflateStream) {
await this._zlibDeflateStream.drain();
if (this._flowManager.isProcessingWriteQueue()) {
await Event.toPromise(this._flowManager.onDidFinishProcessingWriteQueue);
}
await this.socket.drain();
this.traceSocketEvent(SocketDiagnosticsEventType.WebSocketNodeSocketDrainEnd);
}
}
class WebSocketFlowManager extends Disposable {
private readonly _onError = this._register(new Emitter<Error>());
public readonly onError = this._onError.event;
private readonly _zlibInflateStream: ZlibInflateStream | null;
private readonly _zlibDeflateStream: ZlibDeflateStream | null;
private readonly _writeQueue: VSBuffer[] = [];
private readonly _readQueue: { data: VSBuffer, isCompressed: boolean, isLastFrameOfMessage: boolean }[] = [];
private readonly _onDidFinishProcessingWriteQueue = this._register(new Emitter<void>());
public readonly onDidFinishProcessingWriteQueue = this._onDidFinishProcessingWriteQueue.event;
public get permessageDeflate(): boolean {
return Boolean(this._zlibInflateStream && this._zlibDeflateStream);
}
public get recordedInflateBytes(): VSBuffer {
if (this._zlibInflateStream) {
return this._zlibInflateStream.recordedInflateBytes;
}
return VSBuffer.alloc(0);
}
constructor(
private readonly _tracer: ISocketTracer,
permessageDeflate: boolean,
inflateBytes: VSBuffer | null,
recordInflateBytes: boolean,
private readonly _onData: Emitter<VSBuffer>,
private readonly _writeFn: (data: VSBuffer, compressed: boolean) => void
) {
super();
if (permessageDeflate) {
// See https://tools.ietf.org/html/rfc7692#page-16
// To simplify our logic, we don't negotiate the window size
// and simply dedicate (2^15) / 32kb per web socket
this._zlibInflateStream = this._register(new ZlibInflateStream(this._tracer, recordInflateBytes, inflateBytes, { windowBits: 15 }));
this._zlibDeflateStream = this._register(new ZlibDeflateStream(this._tracer, { windowBits: 15 }));
this._register(this._zlibInflateStream.onError((err) => this._onError.fire(err)));
this._register(this._zlibDeflateStream.onError((err) => this._onError.fire(err)));
} else {
this._zlibInflateStream = null;
this._zlibDeflateStream = null;
}
}
public writeMessage(message: VSBuffer): void {
this._writeQueue.push(message);
this._processWriteQueue();
}
private _isProcessingWriteQueue = false;
private async _processWriteQueue(): Promise<void> {
if (this._isProcessingWriteQueue) {
return;
}
this._isProcessingWriteQueue = true;
while (this._writeQueue.length > 0) {
const message = this._writeQueue.shift()!;
if (this._zlibDeflateStream) {
const data = await this._deflateMessage(this._zlibDeflateStream, message);
this._writeFn(data, true);
} else {
this._writeFn(message, false);
}
}
this._isProcessingWriteQueue = false;
this._onDidFinishProcessingWriteQueue.fire();
}
public isProcessingWriteQueue(): boolean {
return (this._isProcessingWriteQueue);
}
/**
* Subsequent calls should wait for the previous `_deflateBuffer` call to complete.
*/
private _deflateMessage(zlibDeflateStream: ZlibDeflateStream, buffer: VSBuffer): Promise<VSBuffer> {
return new Promise<VSBuffer>((resolve, reject) => {
zlibDeflateStream.write(buffer);
zlibDeflateStream.flush(data => resolve(data));
});
}
public acceptFrame(data: VSBuffer, isCompressed: boolean, isLastFrameOfMessage: boolean): void {
this._readQueue.push({ data, isCompressed, isLastFrameOfMessage });
this._processReadQueue();
}
private _isProcessingReadQueue = false;
private async _processReadQueue(): Promise<void> {
if (this._isProcessingReadQueue) {
return;
}
this._isProcessingReadQueue = true;
while (this._readQueue.length > 0) {
const frameInfo = this._readQueue.shift()!;
if (this._zlibInflateStream && frameInfo.isCompressed) {
// See https://datatracker.ietf.org/doc/html/rfc7692#section-9.2
// Even if permessageDeflate is negotiated, it is possible
// that the other side might decide to send uncompressed messages
// So only decompress messages that have the RSV 1 bit set
const data = await this._inflateFrame(this._zlibInflateStream, frameInfo.data, frameInfo.isLastFrameOfMessage);
this._onData.fire(data);
} else {
this._onData.fire(frameInfo.data);
}
}
this._isProcessingReadQueue = false;
}
/**
* Subsequent calls should wait for the previous `transformRead` call to complete.
*/
private _inflateFrame(zlibInflateStream: ZlibInflateStream, buffer: VSBuffer, isLastFrameOfMessage: boolean): Promise<VSBuffer> {
return new Promise<VSBuffer>((resolve, reject) => {
// See https://tools.ietf.org/html/rfc7692#section-7.2.2
zlibInflateStream.write(buffer);
if (isLastFrameOfMessage) {
zlibInflateStream.write(VSBuffer.fromByteArray([0x00, 0x00, 0xff, 0xff]));
}
zlibInflateStream.flush(data => resolve(data));
});
}
}
class ZlibInflateStream extends Disposable {
private readonly _onError = this._register(new Emitter<Error>());
public readonly onError = this._onError.event;
private readonly _onData = this._register(new Emitter<VSBuffer>());
public readonly onData = this._onData.event;
private readonly _zlibInflate: zlib.InflateRaw;
private readonly _recordedInflateBytes: VSBuffer[] = [];
private readonly _pendingInflateData: VSBuffer[] = [];
@ -551,12 +607,12 @@ class ZlibInflateStream extends Disposable {
this._zlibInflate.write(buffer.buffer);
}
public flush(): void {
public flush(callback: (data: VSBuffer) => void): void {
this._zlibInflate.flush(() => {
this._tracer.traceSocketEvent(SocketDiagnosticsEventType.zlibInflateFlushFired);
const data = VSBuffer.concat(this._pendingInflateData);
this._pendingInflateData.length = 0;
this._onData.fire(data);
callback(data);
});
}
}
@ -566,12 +622,8 @@ class ZlibDeflateStream extends Disposable {
private readonly _onError = this._register(new Emitter<Error>());
public readonly onError = this._onError.event;
private readonly _onDidDrain = this._register(new Emitter<void>());
public readonly onDidDrain = this._onDidDrain.event;
private readonly _zlibDeflate: zlib.DeflateRaw;
private readonly _pendingDeflateData: VSBuffer[] = [];
private _flushWaitingCount: number = 0;
constructor(
private readonly _tracer: ISocketTracer,
@ -598,12 +650,8 @@ class ZlibDeflateStream extends Disposable {
}
public flush(callback: (data: VSBuffer) => void): void {
this._flushWaitingCount++;
// See https://zlib.net/manual.html#Constants
this._zlibDeflate.flush(/*Z_SYNC_FLUSH*/2, () => {
this._flushWaitingCount--;
this._tracer.traceSocketEvent(SocketDiagnosticsEventType.zlibDeflateFlushFired);
let data = VSBuffer.concat(this._pendingDeflateData);
@ -613,22 +661,8 @@ class ZlibDeflateStream extends Disposable {
data = data.slice(0, data.byteLength - 4);
callback(data);
if (this._flushWaitingCount === 0) {
this._onDidDrain.fire();
}
});
}
public needsDraining(): boolean {
return (this._flushWaitingCount > 0);
}
public async drain(): Promise<void> {
if (this._flushWaitingCount > 0) {
await Event.toPromise(this.onDidDrain);
}
}
}
function unmask(buffer: VSBuffer, mask: number): void {

View file

@ -525,5 +525,14 @@ suite('WebSocketNodeSocket', () => {
const actual = await testReading(frames, true);
assert.deepStrictEqual(actual, 'Hello');
});
test('A single-frame compressed text message followed by a single-frame non-compressed text message', async () => {
const frames = [
[0xc1, 0x07, 0xf2, 0x48, 0xcd, 0xc9, 0xc9, 0x07, 0x00], // contains "Hello"
[0x81, 0x05, 0x77, 0x6f, 0x72, 0x6c, 0x64] // contains "world"
];
const actual = await testReading(frames, true);
assert.deepStrictEqual(actual, 'Helloworld');
});
});
});