Skip to content

Commit

Permalink
Revert partial usage and add chema bool
Browse files Browse the repository at this point in the history
  • Loading branch information
EvieePy committed May 5, 2024
1 parent dee0d94 commit 8ba004b
Show file tree
Hide file tree
Showing 2 changed files with 27 additions and 8 deletions.
34 changes: 26 additions & 8 deletions starlette_plus/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,11 +19,12 @@
import inspect
import logging
from collections.abc import Callable, Coroutine, Iterator, Sequence
from functools import partial
from typing import TYPE_CHECKING, Any, ClassVar, Self, TypeAlias, TypedDict, Unpack

from starlette.applications import Starlette
from starlette.middleware import Middleware
from starlette.requests import Request
from starlette.responses import Response
from starlette.routing import Mount, Route, WebSocketRoute
from starlette.types import Receive, Scope, Send

Expand Down Expand Up @@ -90,6 +91,16 @@ def __init__(self, **kwargs: Unpack[RouteOptions]) -> None:
self._limits: list[RateLimitData] = kwargs.get("limits", [])
self._is_websocket: bool = kwargs.get("websocket", False)
self._view: View | None = None
self._include_in_schema: bool = kwargs["include_in_schema"]

async def __call__(self, scope: Scope, receive: Receive, send: Send) -> Any:
request: Request = Request(scope, receive, send)
response: Response | None = await self._coro(self._view, request)

if not response:
return Response(status_code=204)

await response(scope, receive, send)


LimitDecorator: TypeAlias = Callable[..., RouteCoro] | _Route
Expand All @@ -103,6 +114,7 @@ def route(
methods: Methods = ["GET"],
prefix: bool = True,
websocket: bool = False,
include_in_schema: bool = True,
) -> Callable[..., _Route]:
def decorator(coro: Callable[..., RouteCoro]) -> _Route:
if not asyncio.iscoroutinefunction(coro):
Expand All @@ -113,7 +125,15 @@ def decorator(coro: Callable[..., RouteCoro]) -> _Route:
raise ValueError(f"Route callback function must not be named any: {', '.join(disallowed)}")

limits: list[RateLimitData] = getattr(coro, "__limits__", [])
return _Route(path=path, coro=coro, methods=methods, prefix=prefix, limits=limits, websocket=websocket)
return _Route(
path=path,
coro=coro,
methods=methods,
prefix=prefix,
limits=limits,
websocket=websocket,
include_in_schema=include_in_schema,
)

return decorator

Expand Down Expand Up @@ -186,14 +206,13 @@ def __new__(cls, *args: Any, **kwargs: Any) -> Self:
setattr(member, method, member._coro)

new: WebSocketRoute | Route
endpoint: partial[RouteCoro] = partial(member._coro, self)

if member._is_websocket:
new = WebSocketRoute(path=path, endpoint=endpoint, name=f"{name}.{member._coro.__name__}")
new = WebSocketRoute(path=path, endpoint=member, name=f"{name}.{member._coro.__name__}")
else:
new = Route(
path=path,
endpoint=endpoint,
endpoint=member,
methods=member._methods,
name=f"{name}.{member._coro.__name__}",
)
Expand Down Expand Up @@ -262,14 +281,13 @@ def __new__(cls, *args: Any, **kwargs: Any) -> Self:
setattr(member, method, member._coro)

new: WebSocketRoute | Route
endpoint: partial[RouteCoro] = partial(member._coro, self)

if member._is_websocket:
new = WebSocketRoute(path=path, endpoint=endpoint, name=f"{name}.{member._coro.__name__}")
new = WebSocketRoute(path=path, endpoint=member, name=f"{name}.{member._coro.__name__}")
else:
new = Route(
path=path,
endpoint=endpoint,
endpoint=member,
methods=member._methods,
name=f"{name}.{member._coro.__name__}",
)
Expand Down
1 change: 1 addition & 0 deletions starlette_plus/types_/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,3 +32,4 @@ class RouteOptions(TypedDict):
prefix: bool
websocket: bool
limits: list[RateLimitData]
include_in_schema: bool

0 comments on commit 8ba004b

Please sign in to comment.