Skip to content

Commit

Permalink
Allow in-memory sessions.
Browse files Browse the repository at this point in the history
  • Loading branch information
EvieePy committed Apr 27, 2024
1 parent 5a97b3c commit 594185e
Showing 1 changed file with 22 additions and 8 deletions.
30 changes: 22 additions & 8 deletions starlette_plus/middleware/sessions.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,27 +42,39 @@


class Storage:
__slots__ = "pool"
__slots__ = ("pool", "_keys")

def __init__(self, *, redis: Redis) -> None:
self.pool: redis.Redis = redis.pool
def __init__(self, *, redis: Redis | None = None) -> None:
self.pool: redis.Redis | None = redis.pool if redis else None
self._keys: dict[str, Any] = {}

async def get(self, data: dict[str, Any]) -> dict[str, Any]:
expiry: datetime.datetime = datetime.datetime.fromisoformat(data["expiry"])
key: str = data["_session_secret_key"]

if expiry <= datetime.datetime.now():
await self.pool.delete(key) # type: ignore
await self.delete(key)
return {}

session: Any = await self.pool.get(key) # type: ignore
if self.pool:
session: Any = await self.pool.get(key) # type: ignore
else:
session: Any = self._keys.get(key)

return json.loads(session) if session else {}

async def set(self, key: str, value: dict[str, Any], *, max_age: int) -> None:
await self.pool.set(key, json.dumps(value), ex=max_age) # type: ignore
if self.pool:
await self.pool.set(key, json.dumps(value), ex=max_age) # type: ignore
return

self._keys[key] = json.dumps(value)

async def delete(self, key: str) -> None:
await self.pool.delete(key) # type: ignore
if self.pool:
await self.pool.delete(key) # type: ignore
else:
self._keys.pop(key, None)


class SessionMiddleware:
Expand All @@ -75,7 +87,7 @@ def __init__(
max_age: int | None = None,
same_site: str = "lax",
secure: bool = True,
redis: Redis,
redis: Redis | None = None,
) -> None:
self.app: ASGIApp = app
self.name: str = name or "__session_cookie"
Expand Down Expand Up @@ -113,6 +125,8 @@ async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None:
scope["session"] = session

async def wrapper(message: Message) -> None:
nonlocal original, session, cookie

if message["type"] != "http.response.start":
await send(message)
return
Expand Down

0 comments on commit 594185e

Please sign in to comment.