From 28d26bd7fa585d76f9ca69a67dfa70a234450ed9 Mon Sep 17 00:00:00 2001 From: Artem Zakharchenko Date: Wed, 6 Nov 2024 21:31:04 +0100 Subject: [PATCH] fix: prevent `instanceof` handler check failures between different MSW versions (#2349) --- .../start/createFallbackRequestListener.ts | 3 +- .../start/createRequestListener.ts | 9 +-- src/core/handlers/RequestHandler.ts | 4 ++ src/core/handlers/WebSocketHandler.ts | 4 ++ src/core/handlers/common.ts | 1 + src/core/utils/executeHandlers.ts | 6 +- src/core/utils/internal/isHandlerKind.test.ts | 64 +++++++++++++++++++ src/core/utils/internal/isHandlerKind.ts | 21 ++++++ src/core/ws/handleWebSocketEvent.ts | 3 +- src/node/SetupServerCommonApi.ts | 12 ++-- 10 files changed, 105 insertions(+), 22 deletions(-) create mode 100644 src/core/handlers/common.ts create mode 100644 src/core/utils/internal/isHandlerKind.test.ts create mode 100644 src/core/utils/internal/isHandlerKind.ts diff --git a/src/browser/setupWorker/start/createFallbackRequestListener.ts b/src/browser/setupWorker/start/createFallbackRequestListener.ts index a87e5da81..1afee6f8d 100644 --- a/src/browser/setupWorker/start/createFallbackRequestListener.ts +++ b/src/browser/setupWorker/start/createFallbackRequestListener.ts @@ -8,6 +8,7 @@ import { XMLHttpRequestInterceptor } from '@mswjs/interceptors/XMLHttpRequest' import { SetupWorkerInternalContext, StartOptions } from '../glossary' import type { RequiredDeep } from '~/core/typeUtils' import { handleRequest } from '~/core/utils/handleRequest' +import { isHandlerKind } from '~/core/utils/internal/isHandlerKind' export function createFallbackRequestListener( context: SetupWorkerInternalContext, @@ -24,7 +25,7 @@ export function createFallbackRequestListener( const response = await handleRequest( request, requestId, - context.getRequestHandlers(), + context.getRequestHandlers().filter(isHandlerKind('RequestHandler')), options, context.emitter, { diff --git a/src/browser/setupWorker/start/createRequestListener.ts b/src/browser/setupWorker/start/createRequestListener.ts index e9e1cf904..ec96603ae 100644 --- a/src/browser/setupWorker/start/createRequestListener.ts +++ b/src/browser/setupWorker/start/createRequestListener.ts @@ -9,12 +9,11 @@ import { } from './utils/createMessageChannel' import { parseWorkerRequest } from '../../utils/parseWorkerRequest' import { RequestHandler } from '~/core/handlers/RequestHandler' -import { HttpHandler } from '~/core/handlers/HttpHandler' -import { GraphQLHandler } from '~/core/handlers/GraphQLHandler' import { handleRequest } from '~/core/utils/handleRequest' import { RequiredDeep } from '~/core/typeUtils' import { devUtils } from '~/core/utils/internal/devUtils' import { toResponseInit } from '~/core/utils/toResponseInit' +import { isHandlerKind } from '~/core/utils/internal/isHandlerKind' export const createRequestListener = ( context: SetupWorkerInternalContext, @@ -45,11 +44,7 @@ export const createRequestListener = ( await handleRequest( request, requestId, - context.getRequestHandlers().filter((handler) => { - return ( - handler instanceof HttpHandler || handler instanceof GraphQLHandler - ) - }), + context.getRequestHandlers().filter(isHandlerKind('RequestHandler')), options, context.emitter, { diff --git a/src/core/handlers/RequestHandler.ts b/src/core/handlers/RequestHandler.ts index f9d34f384..0a5e6f83d 100644 --- a/src/core/handlers/RequestHandler.ts +++ b/src/core/handlers/RequestHandler.ts @@ -7,6 +7,7 @@ import { import type { ResponseResolutionContext } from '../utils/executeHandlers' import type { MaybePromise } from '../typeUtils' import { StrictRequest, StrictResponse } from '..//HttpResponse' +import type { HandlerKind } from './common' export type DefaultRequestMultipartBody = Record< string, @@ -117,6 +118,8 @@ export abstract class RequestHandler< StrictRequest >() + private readonly __kind: HandlerKind + public info: HandlerInfo & RequestHandlerInternalInfo /** * Indicates whether this request handler has been used @@ -151,6 +154,7 @@ export abstract class RequestHandler< } this.isUsed = false + this.__kind = 'RequestHandler' } /** diff --git a/src/core/handlers/WebSocketHandler.ts b/src/core/handlers/WebSocketHandler.ts index 26d443872..f37f1bd6f 100644 --- a/src/core/handlers/WebSocketHandler.ts +++ b/src/core/handlers/WebSocketHandler.ts @@ -8,6 +8,7 @@ import { matchRequestUrl, } from '../utils/matching/matchRequestUrl' import { getCallFrame } from '../utils/internal/getCallFrame' +import type { HandlerKind } from './common' type WebSocketHandlerParsedResult = { match: Match @@ -28,6 +29,8 @@ const kStopPropagationPatched = Symbol('kStopPropagationPatched') const KOnStopPropagation = Symbol('KOnStopPropagation') export class WebSocketHandler { + private readonly __kind: HandlerKind + public id: string public callFrame?: string @@ -38,6 +41,7 @@ export class WebSocketHandler { this[kEmitter] = new Emitter() this.callFrame = getCallFrame(new Error()) + this.__kind = 'EventHandler' } public parse(args: { diff --git a/src/core/handlers/common.ts b/src/core/handlers/common.ts new file mode 100644 index 000000000..ef0d1018a --- /dev/null +++ b/src/core/handlers/common.ts @@ -0,0 +1 @@ +export type HandlerKind = 'RequestHandler' | 'EventHandler' diff --git a/src/core/utils/executeHandlers.ts b/src/core/utils/executeHandlers.ts index 3df00901e..a1c450aeb 100644 --- a/src/core/utils/executeHandlers.ts +++ b/src/core/utils/executeHandlers.ts @@ -18,7 +18,7 @@ export interface ResponseResolutionContext { * Returns the execution result object containing any matching request * handler and any mocked response it returned. */ -export const executeHandlers = async >({ +export const executeHandlers = async >({ request, requestId, handlers, @@ -33,10 +33,6 @@ export const executeHandlers = async >({ let result: RequestHandlerExecutionResult | null = null for (const handler of handlers) { - if (!(handler instanceof RequestHandler)) { - continue - } - result = await handler.run({ request, requestId, resolutionContext }) // If the handler produces some result for this request, diff --git a/src/core/utils/internal/isHandlerKind.test.ts b/src/core/utils/internal/isHandlerKind.test.ts new file mode 100644 index 000000000..84486fbe9 --- /dev/null +++ b/src/core/utils/internal/isHandlerKind.test.ts @@ -0,0 +1,64 @@ +import { GraphQLHandler } from '../../handlers/GraphQLHandler' +import { HttpHandler } from '../../handlers/HttpHandler' +import { RequestHandler } from '../../handlers/RequestHandler' +import { WebSocketHandler } from '../../handlers/WebSocketHandler' +import { isHandlerKind } from './isHandlerKind' + +it('returns true if expected a request handler and given a request handler', () => { + expect( + isHandlerKind('RequestHandler')(new HttpHandler('*', '*', () => {})), + ).toBe(true) + + expect( + isHandlerKind('RequestHandler')( + new GraphQLHandler('all', '*', '*', () => {}), + ), + ).toBe(true) +}) + +it('returns true if expected a request handler and given a custom request handler', () => { + class MyHandler extends RequestHandler { + constructor() { + super({ info: { header: '*' }, resolver: () => {} }) + } + predicate = () => false + log() {} + } + + expect(isHandlerKind('RequestHandler')(new MyHandler())).toBe(true) +}) + +it('returns false if expected a request handler but given event handler', () => { + expect(isHandlerKind('RequestHandler')(new WebSocketHandler('*'))).toBe(false) +}) + +it('returns false if expected a request handler but given arbitrary object', () => { + expect(isHandlerKind('RequestHandler')(undefined)).toBe(false) + expect(isHandlerKind('RequestHandler')(null)).toBe(false) + expect(isHandlerKind('RequestHandler')({})).toBe(false) + expect(isHandlerKind('RequestHandler')([])).toBe(false) + expect(isHandlerKind('RequestHandler')(123)).toBe(false) + expect(isHandlerKind('RequestHandler')('hello')).toBe(false) +}) + +it('returns true if expected an event handler and given an event handler', () => { + expect(isHandlerKind('EventHandler')(new WebSocketHandler('*'))).toBe(true) +}) + +it('returns true if expected an event handler and given a custom event handler', () => { + class MyEventHandler extends WebSocketHandler { + constructor() { + super('*') + } + } + expect(isHandlerKind('EventHandler')(new MyEventHandler())).toBe(true) +}) + +it('returns false if expected an event handler but given arbitrary object', () => { + expect(isHandlerKind('EventHandler')(undefined)).toBe(false) + expect(isHandlerKind('EventHandler')(null)).toBe(false) + expect(isHandlerKind('EventHandler')({})).toBe(false) + expect(isHandlerKind('EventHandler')([])).toBe(false) + expect(isHandlerKind('EventHandler')(123)).toBe(false) + expect(isHandlerKind('EventHandler')('hello')).toBe(false) +}) diff --git a/src/core/utils/internal/isHandlerKind.ts b/src/core/utils/internal/isHandlerKind.ts new file mode 100644 index 000000000..d877bc847 --- /dev/null +++ b/src/core/utils/internal/isHandlerKind.ts @@ -0,0 +1,21 @@ +import type { HandlerKind } from '../../handlers/common' +import type { RequestHandler } from '../../handlers/RequestHandler' +import type { WebSocketHandler } from '../../handlers/WebSocketHandler' + +/** + * A filter function that ensures that the provided argument + * is a handler of the given kind. This helps differentiate + * between different kinds of handlers, e.g. request and event handlers. + */ +export function isHandlerKind(kind: K) { + return ( + input: unknown, + ): input is K extends 'EventHandler' ? WebSocketHandler : RequestHandler => { + return ( + input != null && + typeof input === 'object' && + '__kind' in input && + input.__kind === kind + ) + } +} diff --git a/src/core/ws/handleWebSocketEvent.ts b/src/core/ws/handleWebSocketEvent.ts index a20bd6ec4..919ae0e89 100644 --- a/src/core/ws/handleWebSocketEvent.ts +++ b/src/core/ws/handleWebSocketEvent.ts @@ -6,6 +6,7 @@ import { onUnhandledRequest, UnhandledRequestStrategy, } from '../utils/request/onUnhandledRequest' +import { isHandlerKind } from '../utils/internal/isHandlerKind' interface HandleWebSocketEventOptions { getUnhandledRequestStrategy: () => UnhandledRequestStrategy @@ -30,7 +31,7 @@ export function handleWebSocketEvent(options: HandleWebSocketEventOptions) { for (const handler of handlers) { if ( - handler instanceof WebSocketHandler && + isHandlerKind('EventHandler')(handler) && handler.predicate({ event: connectionEvent, parsedResult: handler.parse({ diff --git a/src/node/SetupServerCommonApi.ts b/src/node/SetupServerCommonApi.ts index dca9bd85a..0d2104119 100644 --- a/src/node/SetupServerCommonApi.ts +++ b/src/node/SetupServerCommonApi.ts @@ -14,14 +14,13 @@ import type { LifeCycleEventsMap, SharedOptions } from '~/core/sharedOptions' import { SetupApi } from '~/core/SetupApi' import { handleRequest } from '~/core/utils/handleRequest' import type { RequestHandler } from '~/core/handlers/RequestHandler' -import { HttpHandler } from '~/core/handlers/HttpHandler' -import { GraphQLHandler } from '~/core/handlers/GraphQLHandler' import type { WebSocketHandler } from '~/core/handlers/WebSocketHandler' import { mergeRight } from '~/core/utils/internal/mergeRight' import { InternalError, devUtils } from '~/core/utils/internal/devUtils' import type { SetupServerCommon } from './glossary' import { handleWebSocketEvent } from '~/core/ws/handleWebSocketEvent' import { webSocketInterceptor } from '~/core/ws/webSocketInterceptor' +import { isHandlerKind } from '~/core/utils/internal/isHandlerKind' export const DEFAULT_LISTEN_OPTIONS: RequiredDeep = { onUnhandledRequest: 'warn', @@ -63,12 +62,9 @@ export class SetupServerCommonApi const response = await handleRequest( request, requestId, - this.handlersController.currentHandlers().filter((handler) => { - return ( - handler instanceof HttpHandler || - handler instanceof GraphQLHandler - ) - }), + this.handlersController + .currentHandlers() + .filter(isHandlerKind('RequestHandler')), this.resolvedOptions, this.emitter, )