Skip to content

Commit 2b85c08

Browse files
committed
feat(core,common,testing): support overriding middleware for testing
closes: nestjs#4073
1 parent 88c8cf8 commit 2b85c08

File tree

11 files changed

+480
-14
lines changed

11 files changed

+480
-14
lines changed
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,191 @@
1+
import {
2+
Injectable,
3+
MiddlewareConsumer,
4+
Module,
5+
NestMiddleware,
6+
} from '@nestjs/common';
7+
import { Test } from '@nestjs/testing';
8+
import * as request from 'supertest';
9+
import { expect } from 'chai';
10+
11+
describe('Middleware overriding', () => {
12+
@Injectable()
13+
class MiddlewareA implements NestMiddleware {
14+
use(req, res, next) {
15+
middlewareAApplied = true;
16+
next();
17+
}
18+
}
19+
20+
function MiddlewareAOverride(req, res, next) {
21+
middlewareAOverrideApplied = true;
22+
next();
23+
}
24+
25+
function MiddlewareB(req, res, next) {
26+
middlewareBApplied = true;
27+
next();
28+
}
29+
30+
@Injectable()
31+
class MiddlewareBOverride implements NestMiddleware {
32+
use(req, res, next) {
33+
middlewareBOverrideApplied = true;
34+
next();
35+
}
36+
}
37+
38+
@Injectable()
39+
class MiddlewareC implements NestMiddleware {
40+
use(req, res, next) {
41+
middlewareCApplied = true;
42+
next();
43+
}
44+
}
45+
46+
@Injectable()
47+
class MiddlewareC1Override implements NestMiddleware {
48+
use(req, res, next) {
49+
middlewareC1OverrideApplied = true;
50+
next();
51+
}
52+
}
53+
54+
function MiddlewareC2Override(req, res, next) {
55+
middlewareC2OverrideApplied = true;
56+
next();
57+
}
58+
59+
@Module({})
60+
class AppModule {
61+
configure(consumer: MiddlewareConsumer) {
62+
return consumer
63+
.apply(MiddlewareA)
64+
.forRoutes('a')
65+
.apply(MiddlewareB)
66+
.forRoutes('b')
67+
.apply(MiddlewareC)
68+
.forRoutes('c');
69+
}
70+
}
71+
72+
let middlewareAApplied: boolean;
73+
let middlewareAOverrideApplied: boolean;
74+
75+
let middlewareBApplied: boolean;
76+
let middlewareBOverrideApplied: boolean;
77+
78+
let middlewareCApplied: boolean;
79+
let middlewareC1OverrideApplied: boolean;
80+
let middlewareC2OverrideApplied: boolean;
81+
82+
const resetMiddlewareApplicationFlags = () => {
83+
middlewareAApplied =
84+
middlewareAOverrideApplied =
85+
middlewareBApplied =
86+
middlewareBOverrideApplied =
87+
middlewareCApplied =
88+
middlewareC1OverrideApplied =
89+
middlewareC2OverrideApplied =
90+
false;
91+
};
92+
93+
beforeEach(() => {
94+
resetMiddlewareApplicationFlags();
95+
});
96+
it('should override class middleware', async () => {
97+
const testingModule = await Test.createTestingModule({
98+
imports: [AppModule],
99+
})
100+
.overrideMiddleware(MiddlewareA)
101+
.useMiddleware(MiddlewareAOverride)
102+
.overrideMiddleware(MiddlewareC)
103+
.useMiddleware(MiddlewareC1Override, MiddlewareC2Override)
104+
.compile();
105+
106+
const app = testingModule.createNestApplication();
107+
await app.init();
108+
109+
await request(app.getHttpServer()).get('/a');
110+
111+
expect(middlewareAApplied).to.be.false;
112+
expect(middlewareAOverrideApplied).to.be.true;
113+
expect(middlewareBApplied).to.be.false;
114+
expect(middlewareBOverrideApplied).to.be.false;
115+
expect(middlewareCApplied).to.be.false;
116+
expect(middlewareC1OverrideApplied).to.be.false;
117+
expect(middlewareC2OverrideApplied).to.be.false;
118+
resetMiddlewareApplicationFlags();
119+
120+
await request(app.getHttpServer()).get('/b');
121+
122+
expect(middlewareAApplied).to.be.false;
123+
expect(middlewareAOverrideApplied).to.be.false;
124+
expect(middlewareBApplied).to.be.true;
125+
expect(middlewareBOverrideApplied).to.be.false;
126+
expect(middlewareCApplied).to.be.false;
127+
expect(middlewareC1OverrideApplied).to.be.false;
128+
expect(middlewareC2OverrideApplied).to.be.false;
129+
resetMiddlewareApplicationFlags();
130+
131+
await request(app.getHttpServer()).get('/c');
132+
133+
expect(middlewareAApplied).to.be.false;
134+
expect(middlewareAOverrideApplied).to.be.false;
135+
expect(middlewareBApplied).to.be.false;
136+
expect(middlewareBOverrideApplied).to.be.false;
137+
expect(middlewareCApplied).to.be.false;
138+
expect(middlewareC1OverrideApplied).to.be.true;
139+
expect(middlewareC2OverrideApplied).to.be.true;
140+
resetMiddlewareApplicationFlags();
141+
142+
await app.close();
143+
});
144+
145+
it('should override functional middleware', async () => {
146+
const testingModule = await Test.createTestingModule({
147+
imports: [AppModule],
148+
})
149+
.overrideMiddleware(MiddlewareB)
150+
.useMiddleware(MiddlewareBOverride)
151+
.compile();
152+
153+
const app = testingModule.createNestApplication();
154+
await app.init();
155+
156+
await request(app.getHttpServer()).get('/a');
157+
158+
expect(middlewareAApplied).to.be.true;
159+
expect(middlewareAOverrideApplied).to.be.false;
160+
expect(middlewareBApplied).to.be.false;
161+
expect(middlewareBOverrideApplied).to.be.false;
162+
expect(middlewareCApplied).to.be.false;
163+
expect(middlewareC1OverrideApplied).to.be.false;
164+
expect(middlewareC2OverrideApplied).to.be.false;
165+
resetMiddlewareApplicationFlags();
166+
167+
await request(app.getHttpServer()).get('/b');
168+
169+
expect(middlewareAApplied).to.be.false;
170+
expect(middlewareAOverrideApplied).to.be.false;
171+
expect(middlewareBApplied).to.be.false;
172+
expect(middlewareBOverrideApplied).to.be.true;
173+
expect(middlewareCApplied).to.be.false;
174+
expect(middlewareC1OverrideApplied).to.be.false;
175+
expect(middlewareC2OverrideApplied).to.be.false;
176+
resetMiddlewareApplicationFlags();
177+
178+
await request(app.getHttpServer()).get('/c');
179+
180+
expect(middlewareAApplied).to.be.false;
181+
expect(middlewareAOverrideApplied).to.be.false;
182+
expect(middlewareBApplied).to.be.false;
183+
expect(middlewareBOverrideApplied).to.be.false;
184+
expect(middlewareCApplied).to.be.true;
185+
expect(middlewareC1OverrideApplied).to.be.false;
186+
expect(middlewareC2OverrideApplied).to.be.false;
187+
resetMiddlewareApplicationFlags();
188+
189+
await app.close();
190+
});
191+
});
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,17 @@
1+
{
2+
"compilerOptions": {
3+
"module": "commonjs",
4+
"declaration": true,
5+
"removeComments": true,
6+
"emitDecoratorMetadata": true,
7+
"experimentalDecorators": true,
8+
"allowSyntheticDefaultImports": true,
9+
"target": "ES2021",
10+
"sourceMap": true,
11+
"outDir": "./dist",
12+
"baseUrl": "./",
13+
"incremental": true,
14+
"skipLibCheck": true
15+
},
16+
"include": ["src/**/*"]
17+
}

packages/common/interfaces/middleware/middleware-consumer.interface.ts

+13
Original file line numberDiff line numberDiff line change
@@ -16,4 +16,17 @@ export interface MiddlewareConsumer {
1616
* @returns {MiddlewareConfigProxy}
1717
*/
1818
apply(...middleware: (Type<any> | Function)[]): MiddlewareConfigProxy;
19+
20+
/**
21+
* Replaces the currently applied middleware with a new (set of) middleware.
22+
*
23+
* @param {Type | Function} middlewareToReplace middleware class/function to be replaced.
24+
* @param {(Type | Function)[]} middlewareReplacement middleware class/function(s) that serve as a replacement for {@link middlewareToReplace}.
25+
*
26+
* @returns {MiddlewareConsumer}
27+
*/
28+
replace(
29+
middlewareToReplace: Type<any> | Function,
30+
...middlewareReplacement: (Type<any> | Function)[]
31+
): MiddlewareConsumer;
1932
}

packages/core/middleware/builder.ts

+48-12
Original file line numberDiff line numberDiff line change
@@ -15,8 +15,15 @@ import { RouteInfoPathExtractor } from './route-info-path-extractor';
1515
import { RoutesMapper } from './routes-mapper';
1616
import { filterMiddleware } from './utils';
1717

18+
type MiddlewareConfigurationContext = {
19+
middleware: (Type<any> | Function)[];
20+
routes: RouteInfo[];
21+
excludedRoutes: RouteInfo[];
22+
};
23+
1824
export class MiddlewareBuilder implements MiddlewareConsumer {
19-
private readonly middlewareCollection = new Set<MiddlewareConfiguration>();
25+
private readonly middlewareConfigurationContexts: MiddlewareConfigurationContext[] =
26+
[];
2027

2128
constructor(
2229
private readonly routesMapper: RoutesMapper,
@@ -34,8 +41,39 @@ export class MiddlewareBuilder implements MiddlewareConsumer {
3441
);
3542
}
3643

44+
public replace(
45+
middlewareToReplace: Type<any> | Function,
46+
...middlewareReplacements: Array<Type<any> | Function>
47+
): MiddlewareBuilder {
48+
for (const currentConfigurationContext of this
49+
.middlewareConfigurationContexts) {
50+
currentConfigurationContext.middleware = flatten(
51+
currentConfigurationContext.middleware.map(middleware =>
52+
middleware === middlewareToReplace
53+
? middlewareReplacements
54+
: middleware,
55+
),
56+
) as (Type<any> | Function)[];
57+
}
58+
59+
return this;
60+
}
61+
62+
public getMiddlewareConfigurationContexts(): MiddlewareConfigurationContext[] {
63+
return this.middlewareConfigurationContexts;
64+
}
65+
3766
public build(): MiddlewareConfiguration[] {
38-
return [...this.middlewareCollection];
67+
return this.middlewareConfigurationContexts.map(
68+
({ middleware, routes, excludedRoutes }) => ({
69+
middleware: filterMiddleware(
70+
middleware,
71+
excludedRoutes,
72+
this.getHttpAdapter(),
73+
),
74+
forRoutes: routes,
75+
}),
76+
);
3977
}
4078

4179
public getHttpAdapter(): HttpServer {
@@ -68,19 +106,17 @@ export class MiddlewareBuilder implements MiddlewareConsumer {
68106
public forRoutes(
69107
...routes: Array<string | Type<any> | RouteInfo>
70108
): MiddlewareConsumer {
71-
const { middlewareCollection } = this.builder;
109+
const { middlewareConfigurationContexts } = this.builder;
72110

73111
const flattedRoutes = this.getRoutesFlatList(routes);
74112
const forRoutes = this.removeOverlappedRoutes(flattedRoutes);
75-
const configuration = {
76-
middleware: filterMiddleware(
77-
this.middleware,
78-
this.excludedRoutes,
79-
this.builder.getHttpAdapter(),
80-
),
81-
forRoutes,
82-
};
83-
middlewareCollection.add(configuration);
113+
114+
middlewareConfigurationContexts.push({
115+
middleware: this.middleware,
116+
routes: forRoutes,
117+
excludedRoutes: this.excludedRoutes,
118+
});
119+
84120
return this.builder;
85121
}
86122

0 commit comments

Comments
 (0)