diff --git a/starsessions/middleware.py b/starsessions/middleware.py index 602aa21..f911c70 100644 --- a/starsessions/middleware.py +++ b/starsessions/middleware.py @@ -66,7 +66,7 @@ async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None: connection = HTTPConnection(scope) session_id = connection.cookies.get(self.cookie_name) - handler = SessionHandler(connection, session_id, self.store, self.serializer, self.lifetime) + handler = SessionHandler(connection, session_id, self.rolling, self.store, self.serializer, self.lifetime) scope["session"] = LoadGuard() scope["session_handler"] = handler diff --git a/starsessions/session.py b/starsessions/session.py index b832ea8..e66616b 100644 --- a/starsessions/session.py +++ b/starsessions/session.py @@ -72,7 +72,10 @@ def get_session_remaining_seconds(connection: HTTPConnection) -> int: """Get total seconds remaining before this session expires.""" now = time.time() metadata = get_session_metadata(connection) - return int((metadata["created"] + metadata["lifetime"]) - now) + # use "last_update" if rolling session, otherwise use "created" + rolling = metadata.get("rolling", False) + last_update = metadata["last_update"] if rolling and "last_update" in metadata else metadata["created"] + return int((last_update + metadata["lifetime"]) - now) class SessionHandler: @@ -86,12 +89,14 @@ def __init__( self, connection: HTTPConnection, session_id: typing.Optional[str], + rolling: bool, store: SessionStore, serializer: Serializer, lifetime: int, ) -> None: self.connection = connection self.session_id = session_id + self.rolling = rolling self.store = store self.serializer = serializer self.is_loaded = False @@ -115,7 +120,12 @@ async def load(self) -> None: ) # read and merge metadata - metadata = {"lifetime": self.lifetime, "created": time.time(), "last_access": time.time()} + metadata = { + "lifetime": self.lifetime, + "created": time.time(), + "last_access": time.time(), + "rolling": self.rolling, + } metadata.update(data.pop("__metadata__", {})) metadata.update({"last_access": time.time()}) # force update self.metadata = metadata # type: ignore[assignment] @@ -126,6 +136,8 @@ async def load(self) -> None: self.initially_empty = len(self.connection.session) == 0 async def save(self, remaining_time: int) -> str: + if self.rolling and self.metadata: + self.metadata["last_update"] = time.time() self.connection.session.update({"__metadata__": self.metadata}) self.session_id = await self.store.write( diff --git a/starsessions/types.py b/starsessions/types.py index b1df2a5..668b7b0 100644 --- a/starsessions/types.py +++ b/starsessions/types.py @@ -5,3 +5,4 @@ class SessionMetadata(typing.TypedDict): lifetime: int created: float # timestamp last_access: float # timestamp + last_update: float # timestamp diff --git a/tests/backends/test_redis.py b/tests/backends/test_redis.py index d5c1937..1bbb239 100644 --- a/tests/backends/test_redis.py +++ b/tests/backends/test_redis.py @@ -1,7 +1,7 @@ import os import pytest import redis.asyncio -from pytest_asyncio.plugin import SubRequest +from pytest import FixtureRequest from starsessions import ImproperlyConfigured from starsessions.stores.base import SessionStore @@ -13,7 +13,7 @@ def redis_key_callable(session_id: str) -> str: @pytest.fixture(params=["prefix_", redis_key_callable], ids=["using string", "using redis_key_callable"]) -def redis_store(request: SubRequest) -> SessionStore: +def redis_store(request: FixtureRequest) -> SessionStore: redis_key = request.param url = os.environ.get("REDIS_URL", "redis://localhost") return RedisStore(url, prefix=redis_key) diff --git a/tests/test_metadata.py b/tests/test_metadata.py index a764d83..232272e 100644 --- a/tests/test_metadata.py +++ b/tests/test_metadata.py @@ -37,6 +37,7 @@ async def app(scope: Scope, receive: Receive, send: Send) -> None: "created": 1660556520, "last_access": 1660556520, "lifetime": 1209600, + "rolling": False, } @@ -59,6 +60,7 @@ async def app(scope: Scope, receive: Receive, send: Send) -> None: "created": 42, "last_access": 1660556520, "lifetime": 0, + "rolling": False, } @@ -79,6 +81,7 @@ async def app(scope: Scope, receive: Receive, send: Send) -> None: "created": 1660556520, "last_access": 1660556520, "lifetime": 1209600, + "rolling": False, } with mock.patch("time.time", lambda: 1660556000): @@ -86,4 +89,5 @@ async def app(scope: Scope, receive: Receive, send: Send) -> None: "created": 1660556520, "last_access": 1660556000, "lifetime": 1209600, + "rolling": False, } diff --git a/tests/test_rolling_session.py b/tests/test_rolling_session.py index 11d8281..c2c4dba 100644 --- a/tests/test_rolling_session.py +++ b/tests/test_rolling_session.py @@ -7,6 +7,7 @@ from unittest import mock from starsessions import SessionMiddleware, SessionStore, load_session +from starsessions.session import get_session_remaining_seconds @pytest.mark.asyncio @@ -19,21 +20,22 @@ async def app(scope: Scope, receive: Receive, send: Send) -> None: await load_session(connection) response = JSONResponse(connection.session) await response(scope, receive, send) + # the session has been rolled again to last for another 10 seconds + assert get_session_remaining_seconds(connection) == session_lifetime app = SessionMiddleware(app, store=store, lifetime=10, rolling=True) client = TestClient(app) current_time = time.time() - # it must set max-age = 10 with mock.patch("time.time", lambda: current_time): response = client.get("/", cookies={"session": "session_id"}) - first_max_age = next(cookie for cookie in response.cookies if cookie.name == "session").expires + first_expiration_date = next(cookie for cookie in response.cookies if cookie.name == "session").expires - # it must set max-age = 10 regardless of any previous value + # fast forward 2 seconds with mock.patch("time.time", lambda: current_time + 2): response = client.get("/", cookies={"session": "session_id"}) - second_max_age = next(cookie for cookie in response.cookies if cookie.name == "session").expires + second_expiration_date = next(cookie for cookie in response.cookies if cookie.name == "session").expires - # the expiration date of the second response must be larger - assert second_max_age > first_max_age + # the expiration date of the cooke in the second response must be extended by 2 seconds + assert second_expiration_date - first_expiration_date == 2 diff --git a/tests/test_session.py b/tests/test_session.py index d4eefa2..1b88edc 100644 --- a/tests/test_session.py +++ b/tests/test_session.py @@ -22,7 +22,9 @@ def test_regenerate_session_id(store: SessionStore, serializer: Serializer) -> N base_session_id = "some_id" connection = HTTPConnection(scope) connection.scope["session"] = {} - connection.scope["session_handler"] = SessionHandler(connection, base_session_id, store, serializer, lifetime=60) + connection.scope["session_handler"] = SessionHandler( + connection, base_session_id, False, store, serializer, lifetime=60 + ) session_id = regenerate_session_id(connection) assert session_id @@ -34,7 +36,9 @@ def test_get_session_id(store: SessionStore, serializer: Serializer) -> None: base_session_id = "some_id" connection = HTTPConnection(scope) connection.scope["session"] = {} - connection.scope["session_handler"] = SessionHandler(connection, base_session_id, store, serializer, lifetime=60) + connection.scope["session_handler"] = SessionHandler( + connection, base_session_id, False, store, serializer, lifetime=60 + ) session_id = get_session_id(connection) assert session_id == base_session_id @@ -45,7 +49,9 @@ def test_get_session_handler(store: SessionStore, serializer: Serializer) -> Non base_session_id = "some_id" connection = HTTPConnection(scope) connection.scope["session"] = {} - connection.scope["session_handler"] = SessionHandler(connection, base_session_id, store, serializer, lifetime=60) + connection.scope["session_handler"] = SessionHandler( + connection, base_session_id, False, store, serializer, lifetime=60 + ) assert get_session_handler(connection) == connection.scope["session_handler"] @@ -56,7 +62,9 @@ async def test_load_session(store: SessionStore, serializer: Serializer) -> None base_session_id = "session_id" connection = HTTPConnection(scope) connection.scope["session"] = {} - connection.scope["session_handler"] = SessionHandler(connection, base_session_id, store, serializer, lifetime=60) + connection.scope["session_handler"] = SessionHandler( + connection, base_session_id, False, store, serializer, lifetime=60 + ) await store.write("session_id", b'{"key": "value"}', lifetime=60, ttl=60) await load_session(connection)