diff --git a/aws_lambda_powertools/event_handler/api_gateway.py b/aws_lambda_powertools/event_handler/api_gateway.py index 17b8da641d3..407cd00781b 100644 --- a/aws_lambda_powertools/event_handler/api_gateway.py +++ b/aws_lambda_powertools/event_handler/api_gateway.py @@ -469,7 +469,7 @@ def __call__( # Save CPU cycles by building middleware stack once if not self._middleware_stack_built: - self._build_middleware_stack(router_middlewares=router_middlewares) + self._build_middleware_stack(router_middlewares=router_middlewares, app=app) # If debug is turned on then output the middleware stack to the console if app._debug: @@ -487,7 +487,7 @@ def __call__( # Call the Middleware Wrapped _call_stack function handler with the app return self._middleware_stack(app) - def _build_middleware_stack(self, router_middlewares: list[Callable[..., Any]]) -> None: + def _build_middleware_stack(self, router_middlewares: list[Callable[..., Any]], app) -> None: """ Builds the middleware stack for the handler by wrapping each handler in an instance of MiddlewareWrapper which is used to contain the state @@ -505,7 +505,25 @@ def _build_middleware_stack(self, router_middlewares: list[Callable[..., Any]]) The Route Middleware stack is processed in reverse order. This is so the stack of middleware handlers is applied in the order of being added to the handler. """ - all_middlewares = router_middlewares + self.middlewares + # Build middleware stack in the correct order for validation: + # 1. Request validation middleware (first) + # 2. Router middlewares + user middlewares (middle) + # 3. Response validation middleware (before route handler) + # 4. Route handler adapter (last) + + all_middlewares = [] + + # Add request validation middleware first if validation is enabled + if hasattr(app, "_request_validation_middleware"): + all_middlewares.append(app._request_validation_middleware) + + # Add user middlewares in the middle + all_middlewares.extend(router_middlewares + self.middlewares) + + # Add response validation middleware before the route handler if validation is enabled + if hasattr(app, "_response_validation_middleware"): + all_middlewares.append(app._response_validation_middleware) + logger.debug(f"Building middleware stack: {all_middlewares}") # IMPORTANT: @@ -1639,17 +1657,16 @@ def __init__( self._json_body_deserializer = json_body_deserializer if self._enable_validation: - from aws_lambda_powertools.event_handler.middlewares.openapi_validation import OpenAPIValidationMiddleware - - # Note the serializer argument: only use custom serializer if provided by the caller - # Otherwise, fully rely on the internal Pydantic based mechanism to serialize responses for validation. - self.use( - [ - OpenAPIValidationMiddleware( - validation_serializer=serializer, - has_response_validation_error=self._has_response_validation_error, - ), - ], + from aws_lambda_powertools.event_handler.middlewares.openapi_validation import ( + OpenAPIRequestValidationMiddleware, + OpenAPIResponseValidationMiddleware, + ) + + # Store validation middlewares to be added in the correct order later + self._request_validation_middleware = OpenAPIRequestValidationMiddleware() + self._response_validation_middleware = OpenAPIResponseValidationMiddleware( + validation_serializer=serializer, + has_response_validation_error=self._has_response_validation_error, ) def _validate_response_validation_error_http_code( @@ -2668,7 +2685,7 @@ def _call_exception_handler(self, exp: Exception, route: Route) -> ResponseBuild route=route, ) - # OpenAPIValidationMiddleware will only raise ResponseValidationError when + # OpenAPIResponseValidationMiddleware will only raise ResponseValidationError when # 'self._response_validation_error_http_code' is not None or # when route has custom_response_validation_http_code if isinstance(exp, ResponseValidationError): diff --git a/aws_lambda_powertools/event_handler/middlewares/openapi_validation.py b/aws_lambda_powertools/event_handler/middlewares/openapi_validation.py index e5745ebddf3..6a276de20fb 100644 --- a/aws_lambda_powertools/event_handler/middlewares/openapi_validation.py +++ b/aws_lambda_powertools/event_handler/middlewares/openapi_validation.py @@ -37,56 +37,20 @@ APPLICATION_FORM_CONTENT_TYPE = "application/x-www-form-urlencoded" -class OpenAPIValidationMiddleware(BaseMiddlewareHandler): +class OpenAPIRequestValidationMiddleware(BaseMiddlewareHandler): """ - OpenAPIValidationMiddleware is a middleware that validates the request against the OpenAPI schema defined by the - Lambda handler. It also validates the response against the OpenAPI schema defined by the Lambda handler. It - should not be used directly, but rather through the `enable_validation` parameter of the `ApiGatewayResolver`. + OpenAPI request validation middleware - validates only incoming requests. - Example - -------- - - ```python - from pydantic import BaseModel - - from aws_lambda_powertools.event_handler.api_gateway import ( - APIGatewayRestResolver, - ) - - class Todo(BaseModel): - name: str - - app = APIGatewayRestResolver(enable_validation=True) - - @app.get("/todos") - def get_todos(): list[Todo]: - return [Todo(name="hello world")] - ``` + This middleware should be used first in the middleware chain to validate + requests before they reach user middlewares. """ - def __init__( - self, - validation_serializer: Callable[[Any], str] | None = None, - has_response_validation_error: bool = False, - ): - """ - Initialize the OpenAPIValidationMiddleware. - - Parameters - ---------- - validation_serializer : Callable, optional - Optional serializer to use when serializing the response for validation. - Use it when you have a custom type that cannot be serialized by the default jsonable_encoder. - - has_response_validation_error: bool, optional - Optional flag used to distinguish between payload and validation errors. - By setting this flag to True, ResponseValidationError will be raised if response could not be validated. - """ - self._validation_serializer = validation_serializer - self._has_response_validation_error = has_response_validation_error + def __init__(self): + """Initialize the request validation middleware.""" + pass def handler(self, app: EventHandlerInstance, next_middleware: NextMiddleware) -> Response: - logger.debug("OpenAPIValidationMiddleware handler") + logger.debug("OpenAPIRequestValidationMiddleware handler") route: Route = app.context["_route"] @@ -140,15 +104,111 @@ def handler(self, app: EventHandlerInstance, next_middleware: NextMiddleware) -> if errors: # Raise the validation errors raise RequestValidationError(_normalize_errors(errors)) + + # Re-write the route_args with the validated values + app.context["_route_args"] = values + + # Call the next middleware + return next_middleware(app) + + def _get_body(self, app: EventHandlerInstance) -> dict[str, Any]: + """ + Get the request body from the event, and parse it according to content type. + """ + content_type = app.current_event.headers.get("content-type", "").strip() + + # Handle JSON content + if not content_type or content_type.startswith(APPLICATION_JSON_CONTENT_TYPE): + return self._parse_json_data(app) + + # Handle URL-encoded form data + elif content_type.startswith(APPLICATION_FORM_CONTENT_TYPE): + return self._parse_form_data(app) + else: - # Re-write the route_args with the validated values, and call the next middleware - app.context["_route_args"] = values + raise NotImplementedError("Only JSON body or Form() are supported") - # Call the handler by calling the next middleware - response = next_middleware(app) + def _parse_json_data(self, app: EventHandlerInstance) -> dict[str, Any]: + """Parse JSON data from the request body.""" + try: + return app.current_event.json_body + except json.JSONDecodeError as e: + raise RequestValidationError( + [ + { + "type": "json_invalid", + "loc": ("body", e.pos), + "msg": "JSON decode error", + "input": {}, + "ctx": {"error": e.msg}, + }, + ], + body=e.doc, + ) from e - # Process the response - return self._handle_response(route=route, response=response) + def _parse_form_data(self, app: EventHandlerInstance) -> dict[str, Any]: + """Parse URL-encoded form data from the request body.""" + try: + body = app.current_event.decoded_body or "" + # parse_qs returns dict[str, list[str]], but we want dict[str, str] for single values + parsed = parse_qs(body, keep_blank_values=True) + + result: dict[str, Any] = {key: values[0] if len(values) == 1 else values for key, values in parsed.items()} + return result + + except Exception as e: # pragma: no cover + raise RequestValidationError( # pragma: no cover + [ + { + "type": "form_invalid", + "loc": ("body",), + "msg": "Form data parsing error", + "input": {}, + "ctx": {"error": str(e)}, + }, + ], + ) from e + + +class OpenAPIResponseValidationMiddleware(BaseMiddlewareHandler): + """ + OpenAPI response validation middleware - validates only outgoing responses. + + This middleware should be used last in the middleware chain to validate + responses only from route handlers, not from user middlewares. + """ + + def __init__( + self, + validation_serializer: Callable[[Any], str] | None = None, + has_response_validation_error: bool = False, + ): + """ + Initialize the response validation middleware. + + Parameters + ---------- + validation_serializer : Callable, optional + Optional serializer to use when serializing the response for validation. + Use it when you have a custom type that cannot be serialized by the default jsonable_encoder. + + has_response_validation_error: bool, optional + Optional flag used to distinguish between payload and validation errors. + By setting this flag to True, ResponseValidationError will be raised if response could not be validated. + """ + self._validation_serializer = validation_serializer + self._has_response_validation_error = has_response_validation_error + + def handler(self, app: EventHandlerInstance, next_middleware: NextMiddleware) -> Response: + logger.debug("OpenAPIResponseValidationMiddleware handler") + + route: Route = app.context["_route"] + + # Call the next middleware (should be the route handler) + response = next_middleware(app) + + # Process the response + return self._handle_response(route=route, response=response) def _handle_response(self, *, route: Route, response: Response): # Process the response body if it exists @@ -228,85 +288,27 @@ def _prepare_response_content( """ Prepares the response content for serialization. """ - if isinstance(res, BaseModel): - return _model_dump( + if isinstance(res, BaseModel): # pragma: no cover + return _model_dump( # pragma: no cover res, by_alias=True, exclude_unset=exclude_unset, exclude_defaults=exclude_defaults, exclude_none=exclude_none, ) - elif isinstance(res, list): - return [ + elif isinstance(res, list): # pragma: no cover + return [ # pragma: no cover self._prepare_response_content(item, exclude_unset=exclude_unset, exclude_defaults=exclude_defaults) for item in res ] - elif isinstance(res, dict): - return { + elif isinstance(res, dict): # pragma: no cover + return { # pragma: no cover k: self._prepare_response_content(v, exclude_unset=exclude_unset, exclude_defaults=exclude_defaults) for k, v in res.items() } - elif dataclasses.is_dataclass(res): - return dataclasses.asdict(res) # type: ignore[arg-type] - return res - - def _get_body(self, app: EventHandlerInstance) -> dict[str, Any]: - """ - Get the request body from the event, and parse it according to content type. - """ - content_type = app.current_event.headers.get("content-type", "").strip() - - # Handle JSON content - if not content_type or content_type.startswith(APPLICATION_JSON_CONTENT_TYPE): - return self._parse_json_data(app) - - # Handle URL-encoded form data - elif content_type.startswith(APPLICATION_FORM_CONTENT_TYPE): - return self._parse_form_data(app) - - else: - raise NotImplementedError("Only JSON body or Form() are supported") - - def _parse_json_data(self, app: EventHandlerInstance) -> dict[str, Any]: - """Parse JSON data from the request body.""" - try: - return app.current_event.json_body - except json.JSONDecodeError as e: - raise RequestValidationError( - [ - { - "type": "json_invalid", - "loc": ("body", e.pos), - "msg": "JSON decode error", - "input": {}, - "ctx": {"error": e.msg}, - }, - ], - body=e.doc, - ) from e - - def _parse_form_data(self, app: EventHandlerInstance) -> dict[str, Any]: - """Parse URL-encoded form data from the request body.""" - try: - body = app.current_event.decoded_body or "" - # parse_qs returns dict[str, list[str]], but we want dict[str, str] for single values - parsed = parse_qs(body, keep_blank_values=True) - - result: dict[str, Any] = {key: values[0] if len(values) == 1 else values for key, values in parsed.items()} - return result - - except Exception as e: # pragma: no cover - raise RequestValidationError( # pragma: no cover - [ - { - "type": "form_invalid", - "loc": ("body",), - "msg": "Form data parsing error", - "input": {}, - "ctx": {"error": str(e)}, - }, - ], - ) from e + elif dataclasses.is_dataclass(res): # pragma: no cover + return dataclasses.asdict(res) # type: ignore[arg-type] # pragma: no cover + return res # pragma: no cover def _request_params_to_args( diff --git a/docs/core/event_handler/api_gateway.md b/docs/core/event_handler/api_gateway.md index 618d055fcfa..76991182737 100644 --- a/docs/core/event_handler/api_gateway.md +++ b/docs/core/event_handler/api_gateway.md @@ -960,6 +960,20 @@ As a practical example, let's refactor our correlation ID middleware so it accep !!! note "Class-based **vs** function-based middlewares" When registering a middleware, we expect a callable in both cases. For class-based middlewares, `BaseMiddlewareHandler` is doing the work of calling your `handler` method with the correct parameters, hence why we expect an instance of it. +#### Middleware and data validation + +When you enable data validation with `enable_validation=True`, we split validation into two separate middlewares: + +1. **Request validation** runs first to validate incoming data +2. **Your middlewares** run in the middle and can return early responses +3. **Response validation** runs last, only for route handler responses + +This ensures your middlewares can return early responses (401, 403, 429, etc.) without triggering validation errors, while still validating actual route handler responses for data integrity. + +```python hl_lines="5 11 23 36" title="Middleware early returns work seamlessly with validation" +--8<-- "examples/event_handler_rest/src/middleware_and_data_validation.py" +``` + #### Native middlewares These are native middlewares that may become native features depending on customer demand. diff --git a/examples/event_handler_rest/src/middleware_and_data_validation.py b/examples/event_handler_rest/src/middleware_and_data_validation.py new file mode 100644 index 00000000000..69459daa0a2 --- /dev/null +++ b/examples/event_handler_rest/src/middleware_and_data_validation.py @@ -0,0 +1,20 @@ +from __future__ import annotations + +from aws_lambda_powertools.event_handler import APIGatewayRestResolver, Response +from aws_lambda_powertools.event_handler.middlewares import NextMiddleware + +app = APIGatewayRestResolver(enable_validation=True) + + +def auth_middleware(app: APIGatewayRestResolver, next_middleware: NextMiddleware) -> Response: + # This 401 response won't trigger validation errors + return Response(status_code=401, content_type="application/json", body="{}") + + +app.use(middlewares=[auth_middleware]) + + +@app.get("/protected") +def protected_route() -> list[str]: + # Only this response will be validated against OpenAPI schema + return ["protected", "route"] diff --git a/tests/e2e/event_handler/handlers/data_validation_and_middleware.py b/tests/e2e/event_handler/handlers/data_validation_and_middleware.py new file mode 100644 index 00000000000..63be10f7ac2 --- /dev/null +++ b/tests/e2e/event_handler/handlers/data_validation_and_middleware.py @@ -0,0 +1,27 @@ +from pydantic import BaseModel + +from aws_lambda_powertools.event_handler import APIGatewayRestResolver, Response +from aws_lambda_powertools.event_handler.middlewares import NextMiddleware +from aws_lambda_powertools.utilities.typing import LambdaContext + + +def middleware_auth(app: APIGatewayRestResolver, next_middleware: NextMiddleware): + # Return early response + return Response(status_code=202, content_type="application/json", body="{}") + + +app = APIGatewayRestResolver(enable_validation=True) +app.use(middlewares=[middleware_auth]) + + +class MyModel(BaseModel): + name: str + + +@app.get("/data_validation_middleware") +def get_data_validation_and_middleware() -> MyModel: + return MyModel(name="powertools") + + +def lambda_handler(event, context: LambdaContext): + return app.resolve(event, context) diff --git a/tests/e2e/event_handler/infrastructure.py b/tests/e2e/event_handler/infrastructure.py index 67d370d2340..46f7cfe2473 100644 --- a/tests/e2e/event_handler/infrastructure.py +++ b/tests/e2e/event_handler/infrastructure.py @@ -23,6 +23,7 @@ def create_resources(self): functions["ApiGatewayRestHandler"], functions["OpenapiHandler"], functions["OpenapiHandlerWithPep563"], + functions["DataValidationAndMiddleware"], ], ) self._create_api_gateway_http(function=functions["ApiGatewayHttpHandler"]) @@ -101,6 +102,9 @@ def _create_api_gateway_rest(self, function: list[Function]): openapi_schema = apigw.root.add_resource("openapi_schema_with_pep563") openapi_schema.add_method("GET", apigwv1.LambdaIntegration(function[2], proxy=True)) + openapi_schema = apigw.root.add_resource("data_validation_middleware") + openapi_schema.add_method("GET", apigwv1.LambdaIntegration(function[3], proxy=True)) + CfnOutput(self.stack, "APIGatewayRestUrl", value=apigw.url) def _create_lambda_function_url(self, function: Function): diff --git a/tests/e2e/event_handler/test_openapi.py b/tests/e2e/event_handler/test_openapi.py index 3a91f804d31..b5255e44661 100644 --- a/tests/e2e/event_handler/test_openapi.py +++ b/tests/e2e/event_handler/test_openapi.py @@ -44,3 +44,18 @@ def test_get_openapi_schema_with_pep563(apigw_rest_endpoint): assert "Powertools e2e API" in response.text assert "x-amazon-apigateway-gateway-responses" in response.text assert response.status_code == 200 + + +def test_get_openapi_validation_and_middleware(apigw_rest_endpoint): + # GIVEN + url = f"{apigw_rest_endpoint}data_validation_middleware" + + # WHEN + response = data_fetcher.get_http_response( + Request( + method="GET", + url=url, + ), + ) + + assert response.status_code == 202 diff --git a/tests/functional/event_handler/_pydantic/test_openapi_validation_middleware.py b/tests/functional/event_handler/_pydantic/test_openapi_validation_middleware.py index b41beda36bc..1fd919b7b71 100644 --- a/tests/functional/event_handler/_pydantic/test_openapi_validation_middleware.py +++ b/tests/functional/event_handler/_pydantic/test_openapi_validation_middleware.py @@ -1632,3 +1632,354 @@ def handler() -> CustomModel: result = app({"httpMethod": "GET", "path": "/test"}, {}) assert result["statusCode"] == 200 + + +def test_middleware_early_return_without_validation_error(gw_event): + """Test that middleware can return early response without triggering validation error (Issue #5228)""" + # GIVEN an APIGatewayRestResolver with validation enabled + app = APIGatewayRestResolver(enable_validation=True) + execution_log = [] + + def auth_middleware(app, next_middleware): + execution_log.append("auth_middleware") + # Return 401 without calling next_middleware - should not trigger validation + return Response(status_code=401, content_type="application/json", body="{}") + + def logging_middleware(app, next_middleware): + execution_log.append("logging_middleware") # Should not be called + return next_middleware(app) + + app.use(middlewares=[auth_middleware, logging_middleware]) + + class UserModel(BaseModel): + name: str + age: int + email: str + + @app.get("/protected") + def protected_route() -> UserModel: + execution_log.append("route_handler") # Should not be called + return UserModel(name="John", age=30, email="john@example.com") + + # WHEN calling the protected route + gw_event["path"] = "/protected" + gw_event["httpMethod"] = "GET" + + # THEN it should return 401 without validation error + result = app(gw_event, {}) + + assert result["statusCode"] == 401 + assert result["body"] == "{}" + + # Check execution order - only auth_middleware should have run + assert execution_log == ["auth_middleware"] + + +def test_middleware_allows_validation_to_proceed(gw_event): + """Test that when middleware calls next_middleware, validation still works""" + # GIVEN an APIGatewayRestResolver with validation enabled + app = APIGatewayRestResolver(enable_validation=True) + execution_log = [] + + def logging_middleware(app, next_middleware): + execution_log.append("logging_middleware") + # Log and continue to next middleware + result = next_middleware(app) + execution_log.append("logging_middleware_after") + return result + + app.use(middlewares=[logging_middleware]) + + class UserModel(BaseModel): + name: str + age: int + email: str + + @app.get("/user") + def get_user() -> UserModel: + execution_log.append("route_handler") + return UserModel(name="Jane", age=25, email="jane@example.com") + + # WHEN calling the user route + gw_event["path"] = "/user" + gw_event["httpMethod"] = "GET" + + # THEN it should return 200 with validated response + result = app(gw_event, {}) + + assert result["statusCode"] == 200 + response_body = json.loads(result["body"]) + assert response_body["name"] == "Jane" + assert response_body["age"] == 25 + assert response_body["email"] == "jane@example.com" + + # Check execution order + expected_log = ["logging_middleware", "route_handler", "logging_middleware_after"] + assert execution_log == expected_log + + +def test_request_validation_fails_before_user_middlewares(gw_event): + """Test that request validation fails before user middlewares are executed""" + # GIVEN an APIGatewayRestResolver with validation enabled + app = APIGatewayRestResolver(enable_validation=True) + execution_log = [] + + def passthrough_middleware(app, next_middleware): + execution_log.append("passthrough_middleware") + return next_middleware(app) + + app.use(middlewares=[passthrough_middleware]) + + class UserModel(BaseModel): + name: str + age: int + email: str + + @app.post("/user") + def create_user(user: UserModel) -> UserModel: + execution_log.append("route_handler") # Should not be called due to validation error + return user + + # WHEN sending invalid request body (missing required fields) + gw_event["path"] = "/user" + gw_event["httpMethod"] = "POST" + gw_event["body"] = '{"name": "John"}' # Missing age and email + gw_event["headers"]["Content-Type"] = "application/json" + + # THEN it should return 422 for validation error + result = app(gw_event, {}) + + assert result["statusCode"] == 422 + response_body = json.loads(result["body"]) + assert "detail" in response_body + + # Request validation happens BEFORE user middlewares, so neither should run + assert "passthrough_middleware" not in execution_log + assert "route_handler" not in execution_log + + +def test_request_validation_passes_then_middlewares_execute(gw_event): + """Test that when request validation passes, user middlewares execute normally""" + # GIVEN an APIGatewayRestResolver with validation enabled + app = APIGatewayRestResolver(enable_validation=True) + execution_log = [] + + def passthrough_middleware(app, next_middleware): + execution_log.append("passthrough_middleware") + return next_middleware(app) + + app.use(middlewares=[passthrough_middleware]) + + class UserModel(BaseModel): + name: str + age: int + email: str + + @app.post("/user") + def create_user(user: UserModel) -> UserModel: + execution_log.append("route_handler") + return user + + # WHEN sending valid request body + gw_event["path"] = "/user" + gw_event["httpMethod"] = "POST" + gw_event["body"] = '{"name": "John", "age": 30, "email": "john@example.com"}' + gw_event["headers"]["Content-Type"] = "application/json" + + # THEN it should return 200 and middlewares should execute + result = app(gw_event, {}) + + assert result["statusCode"] == 200 + + # Both middleware and route handler should have executed + assert "passthrough_middleware" in execution_log + assert "route_handler" in execution_log + + +def test_multiple_middlewares_with_early_return(gw_event): + """Test multiple middlewares where one returns early (Issue #4656)""" + # GIVEN an APIGatewayRestResolver with validation enabled + app = APIGatewayRestResolver(enable_validation=True) + execution_log = [] + + def first_middleware(app, next_middleware): + execution_log.append("first_middleware") + return next_middleware(app) + + def auth_middleware(app, next_middleware): + execution_log.append("auth_middleware") + # Return early - should not trigger validation + return Response(status_code=403, content_type="application/json", body="{}") + + def third_middleware(app, next_middleware): + execution_log.append("third_middleware") # Should not be called + return next_middleware(app) + + app.use(middlewares=[first_middleware, auth_middleware, third_middleware]) + + class UserModel(BaseModel): + name: str + age: int + email: str + + @app.get("/protected") + def protected_route() -> UserModel: + execution_log.append("route_handler") # Should not be called + return UserModel(name="Secret", age=42, email="secret@example.com") + + # WHEN calling the protected route + gw_event["path"] = "/protected" + gw_event["httpMethod"] = "GET" + + # THEN it should return 403 without validation error + result = app(gw_event, {}) + + assert result["statusCode"] == 403 + assert result["body"] == "{}" + + # Check execution order - should stop at auth_middleware + expected_log = ["first_middleware", "auth_middleware"] + assert execution_log == expected_log + + +def test_middleware_execution_order_with_validation(gw_event): + """Test that middleware execution order is correct with validation enabled""" + # GIVEN an APIGatewayRestResolver with validation enabled + app = APIGatewayRestResolver(enable_validation=True) + execution_log = [] + + def first_middleware(app, next_middleware): + execution_log.append("first_middleware") + return next_middleware(app) + + def second_middleware(app, next_middleware): + execution_log.append("second_middleware") + return next_middleware(app) + + app.use(middlewares=[first_middleware, second_middleware]) + + class UserModel(BaseModel): + name: str + age: int + email: str + + @app.get("/test") + def test_route() -> UserModel: + execution_log.append("route_handler") + return UserModel(name="Test", age=30, email="test@example.com") + + # WHEN calling the test route + gw_event["path"] = "/test" + gw_event["httpMethod"] = "GET" + + # THEN it should return 200 with correct execution order + result = app(gw_event, {}) + + assert result["statusCode"] == 200 + + # Expected order: first -> second -> route + expected_order = ["first_middleware", "second_middleware", "route_handler"] + assert execution_log == expected_order + + +def test_rate_limiting_middleware_response_not_validated(gw_event): + """Test rate limiting middleware response (429) is not validated""" + # GIVEN an APIGatewayRestResolver with validation enabled + app = APIGatewayRestResolver(enable_validation=True) + + def rate_limit_middleware(app, next_middleware): + # Return 429 with simple body - should not be validated + return Response(status_code=429, content_type="application/json", body="{}") + + app.use(middlewares=[rate_limit_middleware]) + + class UserModel(BaseModel): + name: str + age: int + email: str + + @app.get("/api/data") + def get_data() -> UserModel: + return UserModel(name="Data", age=1, email="data@example.com") + + # WHEN calling the rate limited route + gw_event["path"] = "/api/data" + gw_event["httpMethod"] = "GET" + + # THEN it should return 429 without validation error + result = app(gw_event, {}) + + assert result["statusCode"] == 429 + assert result["body"] == "{}" + + +def test_middleware_with_complex_auth_response_gets_validated(gw_event): + """Test middleware with complex auth response that should be validated""" + # GIVEN an APIGatewayRestResolver with validation enabled + app = APIGatewayRestResolver(enable_validation=True) + + def auth_middleware(app, next_middleware): + # Return complex 401 response - should trigger validation + return Response( + status_code=401, + content_type="application/json", + body='{"error": "Unauthorized", "message": "Token expired", "code": 1001}', + ) + + app.use(middlewares=[auth_middleware]) + + class UserModel(BaseModel): + name: str + age: int + email: str + + @app.get("/protected") + def protected_route() -> UserModel: + return UserModel(name="Secret", age=42, email="secret@example.com") + + # WHEN calling the protected route + gw_event["path"] = "/protected" + gw_event["httpMethod"] = "GET" + + # THEN it should return 401 with complex body (validation should occur) + result = app(gw_event, {}) + + assert result["statusCode"] == 401 + response_body = json.loads(result["body"]) + assert response_body["error"] == "Unauthorized" + assert response_body["message"] == "Token expired" + assert response_body["code"] == 1001 + + +def test_normal_route_response_validation_still_works(gw_event): + """Test that normal route responses are still validated""" + # GIVEN an APIGatewayRestResolver with validation enabled + app = APIGatewayRestResolver(enable_validation=True) + + def logging_middleware(app, next_middleware): + result = next_middleware(app) + return result + + app.use(middlewares=[logging_middleware]) + + class UserModel(BaseModel): + name: str + age: int + email: str + + @app.get("/user/") + def get_user(user_id: int) -> UserModel: + return UserModel(name=f"User{user_id}", age=user_id + 20, email=f"user{user_id}@example.com") + + # WHEN calling the user route + gw_event["path"] = "/user/123" + gw_event["httpMethod"] = "GET" + + # THEN it should return 200 with validated response + result = app(gw_event, {}) + + assert result["statusCode"] == 200 + response_body = json.loads(result["body"]) + assert response_body["name"] == "User123" + assert response_body["age"] == 143 + assert response_body["email"] == "user123@example.com"