diff --git a/integration/testing/middleware-override/e2e/middleware-override.spec.ts b/integration/testing/middleware-override/e2e/middleware-override.spec.ts new file mode 100644 index 00000000000..f0e853e30e6 --- /dev/null +++ b/integration/testing/middleware-override/e2e/middleware-override.spec.ts @@ -0,0 +1,192 @@ +import { + Injectable, + MiddlewareConsumer, + Module, + NestMiddleware, +} from '@nestjs/common'; +import { Test } from '@nestjs/testing'; +import * as request from 'supertest'; +import { expect } from 'chai'; + +describe('Middleware overriding', () => { + @Injectable() + class MiddlewareA implements NestMiddleware { + use(req, res, next) { + middlewareAApplied = true; + next(); + } + } + + function MiddlewareAOverride(req, res, next) { + middlewareAOverrideApplied = true; + next(); + } + + function MiddlewareB(req, res, next) { + middlewareBApplied = true; + next(); + } + + @Injectable() + class MiddlewareBOverride implements NestMiddleware { + use(req, res, next) { + middlewareBOverrideApplied = true; + next(); + } + } + + @Injectable() + class MiddlewareC implements NestMiddleware { + use(req, res, next) { + middlewareCApplied = true; + next(); + } + } + + @Injectable() + class MiddlewareC1Override implements NestMiddleware { + use(req, res, next) { + middlewareC1OverrideApplied = true; + next(); + } + } + + function MiddlewareC2Override(req, res, next) { + middlewareC2OverrideApplied = true; + next(); + } + + @Module({}) + class AppModule { + configure(consumer: MiddlewareConsumer) { + return consumer + .apply(MiddlewareA) + .forRoutes('a') + .apply(MiddlewareB) + .forRoutes('b') + .apply(MiddlewareC) + .forRoutes('c'); + } + } + + let middlewareAApplied: boolean; + let middlewareAOverrideApplied: boolean; + + let middlewareBApplied: boolean; + let middlewareBOverrideApplied: boolean; + + let middlewareCApplied: boolean; + let middlewareC1OverrideApplied: boolean; + let middlewareC2OverrideApplied: boolean; + + const resetMiddlewareApplicationFlags = () => { + middlewareAApplied = + middlewareAOverrideApplied = + middlewareBApplied = + middlewareBOverrideApplied = + middlewareCApplied = + middlewareC1OverrideApplied = + middlewareC2OverrideApplied = + false; + }; + + beforeEach(() => { + resetMiddlewareApplicationFlags(); + }); + + it('should override class middleware', async () => { + const testingModule = await Test.createTestingModule({ + imports: [AppModule], + }) + .overrideMiddleware(MiddlewareA) + .use(MiddlewareAOverride) + .overrideMiddleware(MiddlewareC) + .use(MiddlewareC1Override, MiddlewareC2Override) + .compile(); + + const app = testingModule.createNestApplication(); + await app.init(); + + await request(app.getHttpServer()).get('/a'); + + expect(middlewareAApplied).to.be.false; + expect(middlewareAOverrideApplied).to.be.true; + expect(middlewareBApplied).to.be.false; + expect(middlewareBOverrideApplied).to.be.false; + expect(middlewareCApplied).to.be.false; + expect(middlewareC1OverrideApplied).to.be.false; + expect(middlewareC2OverrideApplied).to.be.false; + resetMiddlewareApplicationFlags(); + + await request(app.getHttpServer()).get('/b'); + + expect(middlewareAApplied).to.be.false; + expect(middlewareAOverrideApplied).to.be.false; + expect(middlewareBApplied).to.be.true; + expect(middlewareBOverrideApplied).to.be.false; + expect(middlewareCApplied).to.be.false; + expect(middlewareC1OverrideApplied).to.be.false; + expect(middlewareC2OverrideApplied).to.be.false; + resetMiddlewareApplicationFlags(); + + await request(app.getHttpServer()).get('/c'); + + expect(middlewareAApplied).to.be.false; + expect(middlewareAOverrideApplied).to.be.false; + expect(middlewareBApplied).to.be.false; + expect(middlewareBOverrideApplied).to.be.false; + expect(middlewareCApplied).to.be.false; + expect(middlewareC1OverrideApplied).to.be.true; + expect(middlewareC2OverrideApplied).to.be.true; + resetMiddlewareApplicationFlags(); + + await app.close(); + }); + + it('should override functional middleware', async () => { + const testingModule = await Test.createTestingModule({ + imports: [AppModule], + }) + .overrideMiddleware(MiddlewareB) + .use(MiddlewareBOverride) + .compile(); + + const app = testingModule.createNestApplication(); + await app.init(); + + await request(app.getHttpServer()).get('/a'); + + expect(middlewareAApplied).to.be.true; + expect(middlewareAOverrideApplied).to.be.false; + expect(middlewareBApplied).to.be.false; + expect(middlewareBOverrideApplied).to.be.false; + expect(middlewareCApplied).to.be.false; + expect(middlewareC1OverrideApplied).to.be.false; + expect(middlewareC2OverrideApplied).to.be.false; + resetMiddlewareApplicationFlags(); + + await request(app.getHttpServer()).get('/b'); + + expect(middlewareAApplied).to.be.false; + expect(middlewareAOverrideApplied).to.be.false; + expect(middlewareBApplied).to.be.false; + expect(middlewareBOverrideApplied).to.be.true; + expect(middlewareCApplied).to.be.false; + expect(middlewareC1OverrideApplied).to.be.false; + expect(middlewareC2OverrideApplied).to.be.false; + resetMiddlewareApplicationFlags(); + + await request(app.getHttpServer()).get('/c'); + + expect(middlewareAApplied).to.be.false; + expect(middlewareAOverrideApplied).to.be.false; + expect(middlewareBApplied).to.be.false; + expect(middlewareBOverrideApplied).to.be.false; + expect(middlewareCApplied).to.be.true; + expect(middlewareC1OverrideApplied).to.be.false; + expect(middlewareC2OverrideApplied).to.be.false; + resetMiddlewareApplicationFlags(); + + await app.close(); + }); +}); diff --git a/integration/testing-module-override/tsconfig.json b/integration/testing/middleware-override/tsconfig.json similarity index 100% rename from integration/testing-module-override/tsconfig.json rename to integration/testing/middleware-override/tsconfig.json diff --git a/integration/testing-module-override/e2e/circular-dependency/a.module.ts b/integration/testing/module-override/e2e/circular-dependency/a.module.ts similarity index 100% rename from integration/testing-module-override/e2e/circular-dependency/a.module.ts rename to integration/testing/module-override/e2e/circular-dependency/a.module.ts diff --git a/integration/testing-module-override/e2e/circular-dependency/b.module.ts b/integration/testing/module-override/e2e/circular-dependency/b.module.ts similarity index 100% rename from integration/testing-module-override/e2e/circular-dependency/b.module.ts rename to integration/testing/module-override/e2e/circular-dependency/b.module.ts diff --git a/integration/testing-module-override/e2e/modules-override.spec.ts b/integration/testing/module-override/e2e/modules-override.spec.ts similarity index 100% rename from integration/testing-module-override/e2e/modules-override.spec.ts rename to integration/testing/module-override/e2e/modules-override.spec.ts diff --git a/integration/testing/module-override/tsconfig.json b/integration/testing/module-override/tsconfig.json new file mode 100644 index 00000000000..d9c82ca9758 --- /dev/null +++ b/integration/testing/module-override/tsconfig.json @@ -0,0 +1,17 @@ +{ + "compilerOptions": { + "module": "commonjs", + "declaration": true, + "removeComments": true, + "emitDecoratorMetadata": true, + "experimentalDecorators": true, + "allowSyntheticDefaultImports": true, + "target": "ES2021", + "sourceMap": true, + "outDir": "./dist", + "baseUrl": "./", + "incremental": true, + "skipLibCheck": true + }, + "include": ["src/**/*"] +} diff --git a/packages/common/interfaces/middleware/middleware-consumer.interface.ts b/packages/common/interfaces/middleware/middleware-consumer.interface.ts index 21bcbc5e9bc..c00742f96b6 100644 --- a/packages/common/interfaces/middleware/middleware-consumer.interface.ts +++ b/packages/common/interfaces/middleware/middleware-consumer.interface.ts @@ -16,4 +16,17 @@ export interface MiddlewareConsumer { * @returns {MiddlewareConfigProxy} */ apply(...middleware: (Type | Function)[]): MiddlewareConfigProxy; + + /** + * Replaces the currently applied middleware with a new (set of) middleware. + * + * @param {Type | Function} middlewareToReplace middleware class/function to be replaced. + * @param {(Type | Function)[]} middlewareReplacement middleware class/function(s) that serve as a replacement for {@link middlewareToReplace}. + * + * @returns {MiddlewareConsumer} + */ + replace( + middlewareToReplace: Type | Function, + ...middlewareReplacement: (Type | Function)[] + ): MiddlewareConsumer; } diff --git a/packages/core/middleware/builder.ts b/packages/core/middleware/builder.ts index aa1b3675a80..8ded50ebad4 100644 --- a/packages/core/middleware/builder.ts +++ b/packages/core/middleware/builder.ts @@ -15,8 +15,15 @@ import { RouteInfoPathExtractor } from './route-info-path-extractor'; import { RoutesMapper } from './routes-mapper'; import { filterMiddleware } from './utils'; +type MiddlewareConfigurationContext = { + middleware: (Type | Function)[]; + routes: RouteInfo[]; + excludedRoutes: RouteInfo[]; +}; + export class MiddlewareBuilder implements MiddlewareConsumer { - private readonly middlewareCollection = new Set(); + private readonly middlewareConfigurationContexts: MiddlewareConfigurationContext[] = + []; constructor( private readonly routesMapper: RoutesMapper, @@ -34,8 +41,39 @@ export class MiddlewareBuilder implements MiddlewareConsumer { ); } + public replace( + middlewareToReplace: Type | Function, + ...middlewareReplacements: Array | Function> + ): MiddlewareBuilder { + for (const currentConfigurationContext of this + .middlewareConfigurationContexts) { + currentConfigurationContext.middleware = flatten( + currentConfigurationContext.middleware.map(middleware => + middleware === middlewareToReplace + ? middlewareReplacements + : middleware, + ), + ) as (Type | Function)[]; + } + + return this; + } + + public getMiddlewareConfigurationContexts(): MiddlewareConfigurationContext[] { + return this.middlewareConfigurationContexts; + } + public build(): MiddlewareConfiguration[] { - return [...this.middlewareCollection]; + return this.middlewareConfigurationContexts.map( + ({ middleware, routes, excludedRoutes }) => ({ + middleware: filterMiddleware( + middleware, + excludedRoutes, + this.getHttpAdapter(), + ), + forRoutes: routes, + }), + ); } public getHttpAdapter(): HttpServer { @@ -68,19 +106,17 @@ export class MiddlewareBuilder implements MiddlewareConsumer { public forRoutes( ...routes: Array | RouteInfo> ): MiddlewareConsumer { - const { middlewareCollection } = this.builder; + const { middlewareConfigurationContexts } = this.builder; const flattedRoutes = this.getRoutesFlatList(routes); const forRoutes = this.removeOverlappedRoutes(flattedRoutes); - const configuration = { - middleware: filterMiddleware( - this.middleware, - this.excludedRoutes, - this.builder.getHttpAdapter(), - ), - forRoutes, - }; - middlewareCollection.add(configuration); + + middlewareConfigurationContexts.push({ + middleware: this.middleware, + routes: forRoutes, + excludedRoutes: this.excludedRoutes, + }); + return this.builder; } diff --git a/packages/core/test/middleware/builder.spec.ts b/packages/core/test/middleware/builder.spec.ts index fa73acae66a..ed9a63d0b40 100644 --- a/packages/core/test/middleware/builder.spec.ts +++ b/packages/core/test/middleware/builder.spec.ts @@ -4,6 +4,8 @@ import { Delete, Get, Head, + Injectable, + NestMiddleware, Options, Patch, Post, @@ -21,8 +23,50 @@ import { RoutesMapper } from '../../middleware/routes-mapper'; import { NoopHttpAdapter } from './../utils/noop-adapter.spec'; describe('MiddlewareBuilder', () => { + @Injectable() + class MiddlewareA implements NestMiddleware { + use(_req, _res, next) { + next(); + } + } + + function MiddlewareB(_req, _res, next) { + next(); + } + + @Injectable() + class MiddlewareC implements NestMiddleware { + use(_req, _res, next) { + next(); + } + } + let builder: MiddlewareBuilder; + const route = { path: '/test', method: RequestMethod.GET }; + const routesOfTestController = [ + { + method: RequestMethod.GET, + path: '/path/route', + }, + { + method: RequestMethod.GET, + path: '/path/versioned', + version: '1', + }, + ]; + const versionedRoutesOfTestController = [ + { + method: RequestMethod.GET, + path: '/path/route', + }, + { + method: RequestMethod.GET, + path: '/v1/path/versioned', + version: '1', + }, + ]; + beforeEach(() => { const container = new NestContainer(); const appConfig = new ApplicationConfig(); @@ -46,6 +90,7 @@ describe('MiddlewareBuilder', () => { beforeEach(() => { configProxy = builder.apply([]); }); + @Controller('path') class Test { @Get('route') @@ -55,7 +100,56 @@ describe('MiddlewareBuilder', () => { @Get('versioned') public getAllVersioned() {} } - const route = { path: '/test', method: RequestMethod.GET }; + + it('should generate the correct middleware configuration contexts', () => { + configProxy.forRoutes(route, Test); + + expect(builder.getMiddlewareConfigurationContexts()).to.be.eql([ + { + middleware: [], + routes: [route, ...routesOfTestController], + excludedRoutes: [], + }, + ]); + + builder + .apply(MiddlewareA, MiddlewareB, MiddlewareC) + .forRoutes(route) + .apply(MiddlewareA, MiddlewareB) + .exclude(route) + .forRoutes(Test) + .apply(MiddlewareC) + .exclude(route, ...routesOfTestController) + .forRoutes('*'); + + expect(builder.getMiddlewareConfigurationContexts()).to.be.eql([ + { + middleware: [], + routes: [route, ...routesOfTestController], + excludedRoutes: [], + }, + { + middleware: [MiddlewareA, MiddlewareB, MiddlewareC], + routes: [route], + excludedRoutes: [], + }, + { + middleware: [MiddlewareA, MiddlewareB], + routes: routesOfTestController, + excludedRoutes: [route], + }, + { + middleware: [MiddlewareC], + routes: [ + { + method: -1, + path: '/*', + }, + ], + excludedRoutes: [route, ...versionedRoutesOfTestController], + }, + ]); + }); it('should store configuration passed as argument', () => { configProxy.forRoutes(route, Test); @@ -198,4 +292,66 @@ describe('MiddlewareBuilder', () => { ]); }); }); + + describe('replace', () => { + function MiddlewareAOverride(_req, _res, next) { + next(); + } + + @Injectable() + class MiddlewareBOverride implements NestMiddleware { + use(_req, _res, next) { + next(); + } + } + + @Injectable() + class MiddlewareC1Override implements NestMiddleware { + use(_req, _res, next) { + next(); + } + } + + function MiddlewareC2Override(_req, _res, next) { + next(); + } + + it('should replace class middleware', () => { + builder + .apply(MiddlewareA, MiddlewareB, MiddlewareC) + .exclude(route) + .forRoutes(...routesOfTestController) + .replace(MiddlewareA, MiddlewareAOverride) + .replace(MiddlewareC, MiddlewareC1Override, MiddlewareC2Override); + + expect(builder.getMiddlewareConfigurationContexts()).to.be.eql([ + { + middleware: [ + MiddlewareAOverride, + MiddlewareB, + MiddlewareC1Override, + MiddlewareC2Override, + ], + routes: [...routesOfTestController], + excludedRoutes: [route], + }, + ]); + }); + + it('should replace functional middleware', () => { + builder + .apply(MiddlewareA, MiddlewareB, MiddlewareC) + .exclude(route) + .forRoutes(route, ...routesOfTestController) + .replace(MiddlewareB, MiddlewareBOverride); + + expect(builder.getMiddlewareConfigurationContexts()).to.be.eql([ + { + middleware: [MiddlewareA, MiddlewareBOverride, MiddlewareC], + routes: [route, ...routesOfTestController], + excludedRoutes: [route], + }, + ]); + }); + }); }); diff --git a/packages/testing/interfaces/override-middleware.interface.ts b/packages/testing/interfaces/override-middleware.interface.ts new file mode 100644 index 00000000000..2035ff7c2f6 --- /dev/null +++ b/packages/testing/interfaces/override-middleware.interface.ts @@ -0,0 +1,10 @@ +import { ModuleDefinition } from '@nestjs/core/interfaces/module-definition.interface'; +import { TestingModuleBuilder } from '../testing-module.builder'; +import { Type } from '@nestjs/common'; + +/** + * @publicApi + */ +export interface OverrideMiddleware { + use: (...newMiddleware: (Type | Function)[]) => TestingModuleBuilder; +} diff --git a/packages/testing/testing-module.builder.ts b/packages/testing/testing-module.builder.ts index b6670cbce51..545f8fd19c3 100644 --- a/packages/testing/testing-module.builder.ts +++ b/packages/testing/testing-module.builder.ts @@ -1,4 +1,11 @@ -import { Logger, LoggerService, Module, ModuleMetadata } from '@nestjs/common'; +import { + Logger, + LoggerService, + MiddlewareConsumer, + Module, + ModuleMetadata, + Type, +} from '@nestjs/common'; import { NestApplicationContextOptions } from '@nestjs/common/interfaces/nest-application-context-options.interface'; import { ApplicationConfig } from '@nestjs/core/application-config'; import { NestContainer } from '@nestjs/core/injector/container'; @@ -22,6 +29,7 @@ import { TestingLogger } from './services/testing-logger.service'; import { TestingInjector } from './testing-injector'; import { TestingInstanceLoader } from './testing-instance-loader'; import { TestingModule } from './testing-module'; +import { OverrideMiddleware } from './interfaces/override-middleware.interface'; /** * @publicApi @@ -34,6 +42,11 @@ export class TestingModuleBuilder { ModuleDefinition, ModuleDefinition >(); + private readonly middlewareOverloadsMap = new Map< + Type | Function, + (Type | Function)[] + >(); + private readonly module: any; private testingLogger: LoggerService; private mocker?: MockFactory; @@ -84,6 +97,17 @@ export class TestingModuleBuilder { }; } + public overrideMiddleware( + middlewareToOverride: Type | Function, + ): OverrideMiddleware { + return { + use: (...newMiddleware: (Type | Function)[]) => { + this.middlewareOverloadsMap.set(middlewareToOverride, newMiddleware); + return this; + }, + }; + } + public async compile( options: Pick = {}, ): Promise { @@ -111,6 +135,7 @@ export class TestingModuleBuilder { this.applyOverloadsMap(); await this.createInstancesOfDependencies(graphInspector, options); scanner.applyApplicationProviders(); + this.applyMiddlewareOverrides(); const root = this.getRootModule(); return new TestingModule( @@ -150,6 +175,23 @@ export class TestingModuleBuilder { }); } + private applyMiddlewareOverrides() { + for (const { instance } of this.container.getModules().values()) { + const originalConfigurationMethod = instance.configure; + instance.configure = (middlewareConsumer: MiddlewareConsumer) => { + if (!originalConfigurationMethod) { + return []; + } + originalConfigurationMethod(middlewareConsumer); + for (const [middlewareToOverride, newMiddleware] of this + .middlewareOverloadsMap) { + middlewareConsumer.replace(middlewareToOverride, ...newMiddleware); + } + return middlewareConsumer; + }; + } + } + private getModuleOverloads(): ModuleOverride[] { const overloads = [...this.moduleOverloadsMap.entries()]; return overloads.map(([moduleToReplace, newModule]) => ({ @@ -183,6 +225,7 @@ export class TestingModuleBuilder { private createModule(metadata: ModuleMetadata) { class RootTestModule {} + Module(metadata)(RootTestModule); return RootTestModule; }