Skip to content

Commit

Permalink
fix: prevent instanceof handler check failures between different MS…
Browse files Browse the repository at this point in the history
…W versions (#2349)
  • Loading branch information
kettanaito authored Nov 6, 2024
1 parent 3135575 commit 28d26bd
Show file tree
Hide file tree
Showing 10 changed files with 105 additions and 22 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@ import { XMLHttpRequestInterceptor } from '@mswjs/interceptors/XMLHttpRequest'
import { SetupWorkerInternalContext, StartOptions } from '../glossary'
import type { RequiredDeep } from '~/core/typeUtils'
import { handleRequest } from '~/core/utils/handleRequest'
import { isHandlerKind } from '~/core/utils/internal/isHandlerKind'

export function createFallbackRequestListener(
context: SetupWorkerInternalContext,
Expand All @@ -24,7 +25,7 @@ export function createFallbackRequestListener(
const response = await handleRequest(
request,
requestId,
context.getRequestHandlers(),
context.getRequestHandlers().filter(isHandlerKind('RequestHandler')),
options,
context.emitter,
{
Expand Down
9 changes: 2 additions & 7 deletions src/browser/setupWorker/start/createRequestListener.ts
Original file line number Diff line number Diff line change
Expand Up @@ -9,12 +9,11 @@ import {
} from './utils/createMessageChannel'
import { parseWorkerRequest } from '../../utils/parseWorkerRequest'
import { RequestHandler } from '~/core/handlers/RequestHandler'
import { HttpHandler } from '~/core/handlers/HttpHandler'
import { GraphQLHandler } from '~/core/handlers/GraphQLHandler'
import { handleRequest } from '~/core/utils/handleRequest'
import { RequiredDeep } from '~/core/typeUtils'
import { devUtils } from '~/core/utils/internal/devUtils'
import { toResponseInit } from '~/core/utils/toResponseInit'
import { isHandlerKind } from '~/core/utils/internal/isHandlerKind'

export const createRequestListener = (
context: SetupWorkerInternalContext,
Expand Down Expand Up @@ -45,11 +44,7 @@ export const createRequestListener = (
await handleRequest(
request,
requestId,
context.getRequestHandlers().filter((handler) => {
return (
handler instanceof HttpHandler || handler instanceof GraphQLHandler
)
}),
context.getRequestHandlers().filter(isHandlerKind('RequestHandler')),
options,
context.emitter,
{
Expand Down
4 changes: 4 additions & 0 deletions src/core/handlers/RequestHandler.ts
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@ import {
import type { ResponseResolutionContext } from '../utils/executeHandlers'
import type { MaybePromise } from '../typeUtils'
import { StrictRequest, StrictResponse } from '..//HttpResponse'
import type { HandlerKind } from './common'

export type DefaultRequestMultipartBody = Record<
string,
Expand Down Expand Up @@ -117,6 +118,8 @@ export abstract class RequestHandler<
StrictRequest<DefaultBodyType>
>()

private readonly __kind: HandlerKind

public info: HandlerInfo & RequestHandlerInternalInfo
/**
* Indicates whether this request handler has been used
Expand Down Expand Up @@ -151,6 +154,7 @@ export abstract class RequestHandler<
}

this.isUsed = false
this.__kind = 'RequestHandler'
}

/**
Expand Down
4 changes: 4 additions & 0 deletions src/core/handlers/WebSocketHandler.ts
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@ import {
matchRequestUrl,
} from '../utils/matching/matchRequestUrl'
import { getCallFrame } from '../utils/internal/getCallFrame'
import type { HandlerKind } from './common'

type WebSocketHandlerParsedResult = {
match: Match
Expand All @@ -28,6 +29,8 @@ const kStopPropagationPatched = Symbol('kStopPropagationPatched')
const KOnStopPropagation = Symbol('KOnStopPropagation')

export class WebSocketHandler {
private readonly __kind: HandlerKind

public id: string
public callFrame?: string

Expand All @@ -38,6 +41,7 @@ export class WebSocketHandler {

this[kEmitter] = new Emitter()
this.callFrame = getCallFrame(new Error())
this.__kind = 'EventHandler'
}

public parse(args: {
Expand Down
1 change: 1 addition & 0 deletions src/core/handlers/common.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
export type HandlerKind = 'RequestHandler' | 'EventHandler'
6 changes: 1 addition & 5 deletions src/core/utils/executeHandlers.ts
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@ export interface ResponseResolutionContext {
* Returns the execution result object containing any matching request
* handler and any mocked response it returned.
*/
export const executeHandlers = async <Handlers extends Array<unknown>>({
export const executeHandlers = async <Handlers extends Array<RequestHandler>>({
request,
requestId,
handlers,
Expand All @@ -33,10 +33,6 @@ export const executeHandlers = async <Handlers extends Array<unknown>>({
let result: RequestHandlerExecutionResult<any> | null = null

for (const handler of handlers) {
if (!(handler instanceof RequestHandler)) {
continue
}

result = await handler.run({ request, requestId, resolutionContext })

// If the handler produces some result for this request,
Expand Down
64 changes: 64 additions & 0 deletions src/core/utils/internal/isHandlerKind.test.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,64 @@
import { GraphQLHandler } from '../../handlers/GraphQLHandler'
import { HttpHandler } from '../../handlers/HttpHandler'
import { RequestHandler } from '../../handlers/RequestHandler'
import { WebSocketHandler } from '../../handlers/WebSocketHandler'
import { isHandlerKind } from './isHandlerKind'

it('returns true if expected a request handler and given a request handler', () => {
expect(
isHandlerKind('RequestHandler')(new HttpHandler('*', '*', () => {})),
).toBe(true)

expect(
isHandlerKind('RequestHandler')(
new GraphQLHandler('all', '*', '*', () => {}),
),
).toBe(true)
})

it('returns true if expected a request handler and given a custom request handler', () => {
class MyHandler extends RequestHandler {
constructor() {
super({ info: { header: '*' }, resolver: () => {} })
}
predicate = () => false
log() {}
}

expect(isHandlerKind('RequestHandler')(new MyHandler())).toBe(true)
})

it('returns false if expected a request handler but given event handler', () => {
expect(isHandlerKind('RequestHandler')(new WebSocketHandler('*'))).toBe(false)
})

it('returns false if expected a request handler but given arbitrary object', () => {
expect(isHandlerKind('RequestHandler')(undefined)).toBe(false)
expect(isHandlerKind('RequestHandler')(null)).toBe(false)
expect(isHandlerKind('RequestHandler')({})).toBe(false)
expect(isHandlerKind('RequestHandler')([])).toBe(false)
expect(isHandlerKind('RequestHandler')(123)).toBe(false)
expect(isHandlerKind('RequestHandler')('hello')).toBe(false)
})

it('returns true if expected an event handler and given an event handler', () => {
expect(isHandlerKind('EventHandler')(new WebSocketHandler('*'))).toBe(true)
})

it('returns true if expected an event handler and given a custom event handler', () => {
class MyEventHandler extends WebSocketHandler {
constructor() {
super('*')
}
}
expect(isHandlerKind('EventHandler')(new MyEventHandler())).toBe(true)
})

it('returns false if expected an event handler but given arbitrary object', () => {
expect(isHandlerKind('EventHandler')(undefined)).toBe(false)
expect(isHandlerKind('EventHandler')(null)).toBe(false)
expect(isHandlerKind('EventHandler')({})).toBe(false)
expect(isHandlerKind('EventHandler')([])).toBe(false)
expect(isHandlerKind('EventHandler')(123)).toBe(false)
expect(isHandlerKind('EventHandler')('hello')).toBe(false)
})
21 changes: 21 additions & 0 deletions src/core/utils/internal/isHandlerKind.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,21 @@
import type { HandlerKind } from '../../handlers/common'
import type { RequestHandler } from '../../handlers/RequestHandler'
import type { WebSocketHandler } from '../../handlers/WebSocketHandler'

/**
* A filter function that ensures that the provided argument
* is a handler of the given kind. This helps differentiate
* between different kinds of handlers, e.g. request and event handlers.
*/
export function isHandlerKind<K extends HandlerKind>(kind: K) {
return (
input: unknown,
): input is K extends 'EventHandler' ? WebSocketHandler : RequestHandler => {
return (
input != null &&
typeof input === 'object' &&
'__kind' in input &&
input.__kind === kind
)
}
}
3 changes: 2 additions & 1 deletion src/core/ws/handleWebSocketEvent.ts
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@ import {
onUnhandledRequest,
UnhandledRequestStrategy,
} from '../utils/request/onUnhandledRequest'
import { isHandlerKind } from '../utils/internal/isHandlerKind'

interface HandleWebSocketEventOptions {
getUnhandledRequestStrategy: () => UnhandledRequestStrategy
Expand All @@ -30,7 +31,7 @@ export function handleWebSocketEvent(options: HandleWebSocketEventOptions) {

for (const handler of handlers) {
if (
handler instanceof WebSocketHandler &&
isHandlerKind('EventHandler')(handler) &&
handler.predicate({
event: connectionEvent,
parsedResult: handler.parse({
Expand Down
12 changes: 4 additions & 8 deletions src/node/SetupServerCommonApi.ts
Original file line number Diff line number Diff line change
Expand Up @@ -14,14 +14,13 @@ import type { LifeCycleEventsMap, SharedOptions } from '~/core/sharedOptions'
import { SetupApi } from '~/core/SetupApi'
import { handleRequest } from '~/core/utils/handleRequest'
import type { RequestHandler } from '~/core/handlers/RequestHandler'
import { HttpHandler } from '~/core/handlers/HttpHandler'
import { GraphQLHandler } from '~/core/handlers/GraphQLHandler'
import type { WebSocketHandler } from '~/core/handlers/WebSocketHandler'
import { mergeRight } from '~/core/utils/internal/mergeRight'
import { InternalError, devUtils } from '~/core/utils/internal/devUtils'
import type { SetupServerCommon } from './glossary'
import { handleWebSocketEvent } from '~/core/ws/handleWebSocketEvent'
import { webSocketInterceptor } from '~/core/ws/webSocketInterceptor'
import { isHandlerKind } from '~/core/utils/internal/isHandlerKind'

export const DEFAULT_LISTEN_OPTIONS: RequiredDeep<SharedOptions> = {
onUnhandledRequest: 'warn',
Expand Down Expand Up @@ -63,12 +62,9 @@ export class SetupServerCommonApi
const response = await handleRequest(
request,
requestId,
this.handlersController.currentHandlers().filter((handler) => {
return (
handler instanceof HttpHandler ||
handler instanceof GraphQLHandler
)
}),
this.handlersController
.currentHandlers()
.filter(isHandlerKind('RequestHandler')),
this.resolvedOptions,
this.emitter,
)
Expand Down

0 comments on commit 28d26bd

Please sign in to comment.