diff --git a/starlette_plus/core.py b/starlette_plus/core.py index e5ec4af..6212f9f 100644 --- a/starlette_plus/core.py +++ b/starlette_plus/core.py @@ -19,11 +19,12 @@ import inspect import logging from collections.abc import Callable, Coroutine, Iterator, Sequence -from functools import partial from typing import TYPE_CHECKING, Any, ClassVar, Self, TypeAlias, TypedDict, Unpack from starlette.applications import Starlette from starlette.middleware import Middleware +from starlette.requests import Request +from starlette.responses import Response from starlette.routing import Mount, Route, WebSocketRoute from starlette.types import Receive, Scope, Send @@ -90,6 +91,16 @@ def __init__(self, **kwargs: Unpack[RouteOptions]) -> None: self._limits: list[RateLimitData] = kwargs.get("limits", []) self._is_websocket: bool = kwargs.get("websocket", False) self._view: View | None = None + self._include_in_schema: bool = kwargs["include_in_schema"] + + async def __call__(self, scope: Scope, receive: Receive, send: Send) -> Any: + request: Request = Request(scope, receive, send) + response: Response | None = await self._coro(self._view, request) + + if not response: + return Response(status_code=204) + + await response(scope, receive, send) LimitDecorator: TypeAlias = Callable[..., RouteCoro] | _Route @@ -103,6 +114,7 @@ def route( methods: Methods = ["GET"], prefix: bool = True, websocket: bool = False, + include_in_schema: bool = True, ) -> Callable[..., _Route]: def decorator(coro: Callable[..., RouteCoro]) -> _Route: if not asyncio.iscoroutinefunction(coro): @@ -113,7 +125,15 @@ def decorator(coro: Callable[..., RouteCoro]) -> _Route: raise ValueError(f"Route callback function must not be named any: {', '.join(disallowed)}") limits: list[RateLimitData] = getattr(coro, "__limits__", []) - return _Route(path=path, coro=coro, methods=methods, prefix=prefix, limits=limits, websocket=websocket) + return _Route( + path=path, + coro=coro, + methods=methods, + prefix=prefix, + limits=limits, + websocket=websocket, + include_in_schema=include_in_schema, + ) return decorator @@ -186,14 +206,13 @@ def __new__(cls, *args: Any, **kwargs: Any) -> Self: setattr(member, method, member._coro) new: WebSocketRoute | Route - endpoint: partial[RouteCoro] = partial(member._coro, self) if member._is_websocket: - new = WebSocketRoute(path=path, endpoint=endpoint, name=f"{name}.{member._coro.__name__}") + new = WebSocketRoute(path=path, endpoint=member, name=f"{name}.{member._coro.__name__}") else: new = Route( path=path, - endpoint=endpoint, + endpoint=member, methods=member._methods, name=f"{name}.{member._coro.__name__}", ) @@ -262,14 +281,13 @@ def __new__(cls, *args: Any, **kwargs: Any) -> Self: setattr(member, method, member._coro) new: WebSocketRoute | Route - endpoint: partial[RouteCoro] = partial(member._coro, self) if member._is_websocket: - new = WebSocketRoute(path=path, endpoint=endpoint, name=f"{name}.{member._coro.__name__}") + new = WebSocketRoute(path=path, endpoint=member, name=f"{name}.{member._coro.__name__}") else: new = Route( path=path, - endpoint=endpoint, + endpoint=member, methods=member._methods, name=f"{name}.{member._coro.__name__}", ) diff --git a/starlette_plus/types_/core.py b/starlette_plus/types_/core.py index 8378525..4f14913 100644 --- a/starlette_plus/types_/core.py +++ b/starlette_plus/types_/core.py @@ -32,3 +32,4 @@ class RouteOptions(TypedDict): prefix: bool websocket: bool limits: list[RateLimitData] + include_in_schema: bool