Skip to content

fix(event_handler): split OpenAPI validation to respect middleware returns #7050

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 5 commits into from
Jul 30, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
47 changes: 32 additions & 15 deletions aws_lambda_powertools/event_handler/api_gateway.py
Original file line number Diff line number Diff line change
Expand Up @@ -469,7 +469,7 @@ def __call__(

# Save CPU cycles by building middleware stack once
if not self._middleware_stack_built:
self._build_middleware_stack(router_middlewares=router_middlewares)
self._build_middleware_stack(router_middlewares=router_middlewares, app=app)

# If debug is turned on then output the middleware stack to the console
if app._debug:
Expand All @@ -487,7 +487,7 @@ def __call__(
# Call the Middleware Wrapped _call_stack function handler with the app
return self._middleware_stack(app)

def _build_middleware_stack(self, router_middlewares: list[Callable[..., Any]]) -> None:
def _build_middleware_stack(self, router_middlewares: list[Callable[..., Any]], app) -> None:
"""
Builds the middleware stack for the handler by wrapping each
handler in an instance of MiddlewareWrapper which is used to contain the state
Expand All @@ -505,7 +505,25 @@ def _build_middleware_stack(self, router_middlewares: list[Callable[..., Any]])
The Route Middleware stack is processed in reverse order. This is so the stack of
middleware handlers is applied in the order of being added to the handler.
"""
all_middlewares = router_middlewares + self.middlewares
# Build middleware stack in the correct order for validation:
# 1. Request validation middleware (first)
# 2. Router middlewares + user middlewares (middle)
# 3. Response validation middleware (before route handler)
# 4. Route handler adapter (last)

all_middlewares = []

# Add request validation middleware first if validation is enabled
if hasattr(app, "_request_validation_middleware"):
all_middlewares.append(app._request_validation_middleware)

# Add user middlewares in the middle
all_middlewares.extend(router_middlewares + self.middlewares)

# Add response validation middleware before the route handler if validation is enabled
if hasattr(app, "_response_validation_middleware"):
all_middlewares.append(app._response_validation_middleware)

logger.debug(f"Building middleware stack: {all_middlewares}")

# IMPORTANT:
Expand Down Expand Up @@ -1639,17 +1657,16 @@ def __init__(
self._json_body_deserializer = json_body_deserializer

if self._enable_validation:
from aws_lambda_powertools.event_handler.middlewares.openapi_validation import OpenAPIValidationMiddleware

# Note the serializer argument: only use custom serializer if provided by the caller
# Otherwise, fully rely on the internal Pydantic based mechanism to serialize responses for validation.
self.use(
[
OpenAPIValidationMiddleware(
validation_serializer=serializer,
has_response_validation_error=self._has_response_validation_error,
),
],
from aws_lambda_powertools.event_handler.middlewares.openapi_validation import (
OpenAPIRequestValidationMiddleware,
OpenAPIResponseValidationMiddleware,
)

# Store validation middlewares to be added in the correct order later
self._request_validation_middleware = OpenAPIRequestValidationMiddleware()
self._response_validation_middleware = OpenAPIResponseValidationMiddleware(
validation_serializer=serializer,
has_response_validation_error=self._has_response_validation_error,
)

def _validate_response_validation_error_http_code(
Expand Down Expand Up @@ -2668,7 +2685,7 @@ def _call_exception_handler(self, exp: Exception, route: Route) -> ResponseBuild
route=route,
)

# OpenAPIValidationMiddleware will only raise ResponseValidationError when
# OpenAPIResponseValidationMiddleware will only raise ResponseValidationError when
# 'self._response_validation_error_http_code' is not None or
# when route has custom_response_validation_http_code
if isinstance(exp, ResponseValidationError):
Expand Down
236 changes: 119 additions & 117 deletions aws_lambda_powertools/event_handler/middlewares/openapi_validation.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,56 +37,20 @@
APPLICATION_FORM_CONTENT_TYPE = "application/x-www-form-urlencoded"


class OpenAPIValidationMiddleware(BaseMiddlewareHandler):
class OpenAPIRequestValidationMiddleware(BaseMiddlewareHandler):
"""
OpenAPIValidationMiddleware is a middleware that validates the request against the OpenAPI schema defined by the
Lambda handler. It also validates the response against the OpenAPI schema defined by the Lambda handler. It
should not be used directly, but rather through the `enable_validation` parameter of the `ApiGatewayResolver`.
OpenAPI request validation middleware - validates only incoming requests.

Example
--------

```python
from pydantic import BaseModel

from aws_lambda_powertools.event_handler.api_gateway import (
APIGatewayRestResolver,
)

class Todo(BaseModel):
name: str

app = APIGatewayRestResolver(enable_validation=True)

@app.get("/todos")
def get_todos() -> list[Todo]:
return [Todo(name="hello world")]
```
This middleware should be used first in the middleware chain to validate
requests before they reach user middlewares.
"""

def __init__(
self,
validation_serializer: Callable[[Any], str] | None = None,
has_response_validation_error: bool = False,
):
"""
Initialize the OpenAPIValidationMiddleware.

Parameters
----------
validation_serializer : Callable, optional
Optional serializer to use when serializing the response for validation.
Use it when you have a custom type that cannot be serialized by the default jsonable_encoder.

has_response_validation_error: bool, optional
Optional flag used to distinguish between payload and validation errors.
By setting this flag to True, ResponseValidationError will be raised if response could not be validated.
"""
self._validation_serializer = validation_serializer
self._has_response_validation_error = has_response_validation_error
def __init__(self):
"""Initialize the request validation middleware."""
pass

def handler(self, app: EventHandlerInstance, next_middleware: NextMiddleware) -> Response:
logger.debug("OpenAPIValidationMiddleware handler")
logger.debug("OpenAPIRequestValidationMiddleware handler")

route: Route = app.context["_route"]

Expand Down Expand Up @@ -140,15 +104,111 @@ def handler(self, app: EventHandlerInstance, next_middleware: NextMiddleware) ->
if errors:
# Raise the validation errors
raise RequestValidationError(_normalize_errors(errors))

# Re-write the route_args with the validated values
app.context["_route_args"] = values

# Call the next middleware
return next_middleware(app)

def _get_body(self, app: EventHandlerInstance) -> dict[str, Any]:
    """
    Get the request body from the event, and parse it according to content type.

    Parameters
    ----------
    app : EventHandlerInstance
        Resolver whose ``current_event`` supplies the request headers and body.

    Returns
    -------
    dict[str, Any]
        The parsed body, as returned by ``_parse_json_data`` or ``_parse_form_data``.

    Raises
    ------
    NotImplementedError
        When the content type is neither JSON nor URL-encoded form data.
    """
    content_type = app.current_event.headers.get("content-type", "").strip()

    # A missing or empty content type is handled as JSON
    if not content_type or content_type.startswith(APPLICATION_JSON_CONTENT_TYPE):
        return self._parse_json_data(app)

    # Handle URL-encoded form data
    if content_type.startswith(APPLICATION_FORM_CONTENT_TYPE):
        return self._parse_form_data(app)

    # This helper only parses the body: it has no validated route arguments to
    # store, so the unsupported case must do nothing but raise.
    raise NotImplementedError("Only JSON body or Form() are supported")

# Call the handler by calling the next middleware
response = next_middleware(app)
def _parse_json_data(self, app: EventHandlerInstance) -> dict[str, Any]:
    """
    Return the JSON body of the current event.

    Raises
    ------
    RequestValidationError
        When reading ``json_body`` raises ``json.JSONDecodeError``. The error
        entry carries the decoder's position and message, and the offending
        document is passed along as ``body``.
    """
    try:
        payload = app.current_event.json_body
    except json.JSONDecodeError as exc:
        decode_error = {
            "type": "json_invalid",
            "loc": ("body", exc.pos),
            "msg": "JSON decode error",
            "input": {},
            "ctx": {"error": exc.msg},
        }
        raise RequestValidationError([decode_error], body=exc.doc) from exc

    return payload

# Process the response
return self._handle_response(route=route, response=response)
def _parse_form_data(self, app: EventHandlerInstance) -> dict[str, Any]:
    """
    Return the URL-encoded form body of the current event as a dictionary.

    Keys with exactly one value map to that value; keys with several values
    keep the full list. Blank values are kept.

    Raises
    ------
    RequestValidationError
        When anything goes wrong while reading or parsing the body.
    """
    try:
        raw_body = app.current_event.decoded_body or ""

        # parse_qs yields a list for every key; unwrap the single-value ones
        fields: dict[str, Any] = {}
        for name, items in parse_qs(raw_body, keep_blank_values=True).items():
            fields[name] = items[0] if len(items) == 1 else items
        return fields

    except Exception as exc:  # pragma: no cover
        form_error = {  # pragma: no cover
            "type": "form_invalid",
            "loc": ("body",),
            "msg": "Form data parsing error",
            "input": {},
            "ctx": {"error": str(exc)},
        }
        raise RequestValidationError([form_error]) from exc  # pragma: no cover


class OpenAPIResponseValidationMiddleware(BaseMiddlewareHandler):
"""
OpenAPI response validation middleware - validates only outgoing responses.

This middleware should be used last in the middleware chain to validate
responses only from route handlers, not from user middlewares.
"""

def __init__(
self,
validation_serializer: Callable[[Any], str] | None = None,
has_response_validation_error: bool = False,
):
"""
Initialize the response validation middleware.

Parameters
----------
validation_serializer : Callable, optional
Optional serializer to use when serializing the response for validation.
Use it when you have a custom type that cannot be serialized by the default jsonable_encoder.

has_response_validation_error: bool, optional
Optional flag used to distinguish between payload and validation errors.
By setting this flag to True, ResponseValidationError will be raised if response could not be validated.
"""
self._validation_serializer = validation_serializer
self._has_response_validation_error = has_response_validation_error

def handler(self, app: EventHandlerInstance, next_middleware: NextMiddleware) -> Response:
    """
    Call the next middleware and pass its response to ``_handle_response``.

    The route is read from ``app.context["_route"]`` before the downstream
    call is made.
    """
    logger.debug("OpenAPIResponseValidationMiddleware handler")

    current_route: Route = app.context["_route"]

    # The next callable in the chain should be the route handler
    downstream_response = next_middleware(app)

    return self._handle_response(route=current_route, response=downstream_response)

def _handle_response(self, *, route: Route, response: Response):
# Process the response body if it exists
Expand Down Expand Up @@ -228,85 +288,27 @@ def _prepare_response_content(
"""
Prepares the response content for serialization.
"""
if isinstance(res, BaseModel):
return _model_dump(
if isinstance(res, BaseModel): # pragma: no cover
return _model_dump( # pragma: no cover
res,
by_alias=True,
exclude_unset=exclude_unset,
exclude_defaults=exclude_defaults,
exclude_none=exclude_none,
)
elif isinstance(res, list):
return [
elif isinstance(res, list): # pragma: no cover
return [ # pragma: no cover
self._prepare_response_content(item, exclude_unset=exclude_unset, exclude_defaults=exclude_defaults)
for item in res
]
elif isinstance(res, dict):
return {
elif isinstance(res, dict): # pragma: no cover
return { # pragma: no cover
k: self._prepare_response_content(v, exclude_unset=exclude_unset, exclude_defaults=exclude_defaults)
for k, v in res.items()
}
elif dataclasses.is_dataclass(res):
return dataclasses.asdict(res) # type: ignore[arg-type]
return res

def _get_body(self, app: EventHandlerInstance) -> dict[str, Any]:
"""
Get the request body from the event, and parse it according to content type.
"""
content_type = app.current_event.headers.get("content-type", "").strip()

# Handle JSON content
if not content_type or content_type.startswith(APPLICATION_JSON_CONTENT_TYPE):
return self._parse_json_data(app)

# Handle URL-encoded form data
elif content_type.startswith(APPLICATION_FORM_CONTENT_TYPE):
return self._parse_form_data(app)

else:
raise NotImplementedError("Only JSON body or Form() are supported")

def _parse_json_data(self, app: EventHandlerInstance) -> dict[str, Any]:
"""Parse JSON data from the request body."""
try:
return app.current_event.json_body
except json.JSONDecodeError as e:
raise RequestValidationError(
[
{
"type": "json_invalid",
"loc": ("body", e.pos),
"msg": "JSON decode error",
"input": {},
"ctx": {"error": e.msg},
},
],
body=e.doc,
) from e

def _parse_form_data(self, app: EventHandlerInstance) -> dict[str, Any]:
"""Parse URL-encoded form data from the request body."""
try:
body = app.current_event.decoded_body or ""
# parse_qs returns dict[str, list[str]], but we want dict[str, str] for single values
parsed = parse_qs(body, keep_blank_values=True)

result: dict[str, Any] = {key: values[0] if len(values) == 1 else values for key, values in parsed.items()}
return result

except Exception as e: # pragma: no cover
raise RequestValidationError( # pragma: no cover
[
{
"type": "form_invalid",
"loc": ("body",),
"msg": "Form data parsing error",
"input": {},
"ctx": {"error": str(e)},
},
],
) from e
elif dataclasses.is_dataclass(res): # pragma: no cover
return dataclasses.asdict(res) # type: ignore[arg-type] # pragma: no cover
return res # pragma: no cover


def _request_params_to_args(
Expand Down
14 changes: 14 additions & 0 deletions docs/core/event_handler/api_gateway.md
Original file line number Diff line number Diff line change
Expand Up @@ -960,6 +960,20 @@ As a practical example, let's refactor our correlation ID middleware so it accep
!!! note "Class-based **vs** function-based middlewares"
When registering a middleware, we expect a callable in both cases. For class-based middlewares, `BaseMiddlewareHandler` does the work of calling your `handler` method with the correct parameters, which is why we expect an instance of it.

#### Middleware and data validation

When you enable data validation with `enable_validation=True`, we split validation into two separate middlewares that run around your own middlewares, in this order:

1. **Request validation** runs first to validate incoming data
2. **Your middlewares** run in the middle and can return early responses
3. **Response validation** runs last, only for route handler responses

This ensures your middlewares can return early responses (401, 403, 429, etc.) without triggering validation errors, while still validating actual route handler responses for data integrity.

```python hl_lines="5 11 23 36" title="Middleware early returns work seamlessly with validation"
--8<-- "examples/event_handler_rest/src/middleware_and_data_validation.py"
```

#### Native middlewares

These are native middlewares that may become native features depending on customer demand.
Expand Down
Loading
Loading