Skip to content

Commit

Permalink
fix get_session_remaining_seconds() for rolling sessions
Browse files Browse the repository at this point in the history
`get_session_remaining_seconds()` should return `last_update + lifetime - now` instead of `created + lifetime - now`.

Fix mypy: use FixtureRequest instead of SubRequest

`SubRequest` is a subclass of `FixtureRequest`, but is currently private
so pytest-asyncio uses `Any` instead. However, `FixtureRequest` typing
is sufficient for our needs, so can use that instead.
  • Loading branch information
deepcyrille committed Mar 27, 2024
1 parent af1c557 commit 65b13d1
Show file tree
Hide file tree
Showing 7 changed files with 42 additions and 15 deletions.
2 changes: 1 addition & 1 deletion starsessions/middleware.py
Original file line number Diff line number Diff line change
Expand Up @@ -66,7 +66,7 @@ async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None:

connection = HTTPConnection(scope)
session_id = connection.cookies.get(self.cookie_name)
handler = SessionHandler(connection, session_id, self.store, self.serializer, self.lifetime)
handler = SessionHandler(connection, session_id, self.rolling, self.store, self.serializer, self.lifetime)

scope["session"] = LoadGuard()
scope["session_handler"] = handler
Expand Down
16 changes: 14 additions & 2 deletions starsessions/session.py
Original file line number Diff line number Diff line change
Expand Up @@ -72,7 +72,10 @@ def get_session_remaining_seconds(connection: HTTPConnection) -> int:
"""Get total seconds remaining before this session expires."""
now = time.time()
metadata = get_session_metadata(connection)
return int((metadata["created"] + metadata["lifetime"]) - now)
# use "last_update" if rolling session, otherwise use "created"
rolling = metadata.get("rolling", False)
last_update = metadata["last_update"] if rolling and "last_update" in metadata else metadata["created"]
return int((last_update + metadata["lifetime"]) - now)


class SessionHandler:
Expand All @@ -86,12 +89,14 @@ def __init__(
self,
connection: HTTPConnection,
session_id: typing.Optional[str],
rolling: bool,
store: SessionStore,
serializer: Serializer,
lifetime: int,
) -> None:
self.connection = connection
self.session_id = session_id
self.rolling = rolling
self.store = store
self.serializer = serializer
self.is_loaded = False
Expand All @@ -115,7 +120,12 @@ async def load(self) -> None:
)

# read and merge metadata
metadata = {"lifetime": self.lifetime, "created": time.time(), "last_access": time.time()}
metadata = {
"lifetime": self.lifetime,
"created": time.time(),
"last_access": time.time(),
"rolling": self.rolling,
}
metadata.update(data.pop("__metadata__", {}))
metadata.update({"last_access": time.time()}) # force update
self.metadata = metadata # type: ignore[assignment]
Expand All @@ -126,6 +136,8 @@ async def load(self) -> None:
self.initially_empty = len(self.connection.session) == 0

async def save(self, remaining_time: int) -> str:
if self.rolling and self.metadata:
self.metadata["last_update"] = time.time()
self.connection.session.update({"__metadata__": self.metadata})

self.session_id = await self.store.write(
Expand Down
1 change: 1 addition & 0 deletions starsessions/types.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,3 +5,4 @@ class SessionMetadata(typing.TypedDict):
lifetime: int
created: float # timestamp
last_access: float # timestamp
last_update: float # timestamp
4 changes: 2 additions & 2 deletions tests/backends/test_redis.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
import os
import pytest
import redis.asyncio
from pytest_asyncio.plugin import SubRequest
from pytest import FixtureRequest

from starsessions import ImproperlyConfigured
from starsessions.stores.base import SessionStore
Expand All @@ -13,7 +13,7 @@ def redis_key_callable(session_id: str) -> str:


@pytest.fixture(params=["prefix_", redis_key_callable], ids=["using string", "using redis_key_callable"])
def redis_store(request: SubRequest) -> SessionStore:
def redis_store(request: FixtureRequest) -> SessionStore:
redis_key = request.param
url = os.environ.get("REDIS_URL", "redis://localhost")
return RedisStore(url, prefix=redis_key)
Expand Down
4 changes: 4 additions & 0 deletions tests/test_metadata.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,7 @@ async def app(scope: Scope, receive: Receive, send: Send) -> None:
"created": 1660556520,
"last_access": 1660556520,
"lifetime": 1209600,
"rolling": False,
}


Expand All @@ -59,6 +60,7 @@ async def app(scope: Scope, receive: Receive, send: Send) -> None:
"created": 42,
"last_access": 1660556520,
"lifetime": 0,
"rolling": False,
}


Expand All @@ -79,11 +81,13 @@ async def app(scope: Scope, receive: Receive, send: Send) -> None:
"created": 1660556520,
"last_access": 1660556520,
"lifetime": 1209600,
"rolling": False,
}

with mock.patch("time.time", lambda: 1660556000):
assert client.get("/", cookies={"session": "session_id"}).json() == {
"created": 1660556520,
"last_access": 1660556000,
"lifetime": 1209600,
"rolling": False,
}
14 changes: 8 additions & 6 deletions tests/test_rolling_session.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
from unittest import mock

from starsessions import SessionMiddleware, SessionStore, load_session
from starsessions.session import get_session_remaining_seconds


@pytest.mark.asyncio
Expand All @@ -19,21 +20,22 @@ async def app(scope: Scope, receive: Receive, send: Send) -> None:
await load_session(connection)
response = JSONResponse(connection.session)
await response(scope, receive, send)
# the session has been rolled again to last for another 10 seconds
assert get_session_remaining_seconds(connection) == session_lifetime

app = SessionMiddleware(app, store=store, lifetime=10, rolling=True)
client = TestClient(app)

current_time = time.time()

# it must set max-age = 10
with mock.patch("time.time", lambda: current_time):
response = client.get("/", cookies={"session": "session_id"})
first_max_age = next(cookie for cookie in response.cookies if cookie.name == "session").expires
first_expiration_date = next(cookie for cookie in response.cookies if cookie.name == "session").expires

# it must set max-age = 10 regardless of any previous value
# fast forward 2 seconds
with mock.patch("time.time", lambda: current_time + 2):
response = client.get("/", cookies={"session": "session_id"})
second_max_age = next(cookie for cookie in response.cookies if cookie.name == "session").expires
second_expiration_date = next(cookie for cookie in response.cookies if cookie.name == "session").expires

# the expiration date of the second response must be larger
assert second_max_age > first_max_age
# the expiration date of the cooke in the second response must be extended by 2 seconds
assert second_expiration_date - first_expiration_date == 2
16 changes: 12 additions & 4 deletions tests/test_session.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,9 @@ def test_regenerate_session_id(store: SessionStore, serializer: Serializer) -> N
base_session_id = "some_id"
connection = HTTPConnection(scope)
connection.scope["session"] = {}
connection.scope["session_handler"] = SessionHandler(connection, base_session_id, store, serializer, lifetime=60)
connection.scope["session_handler"] = SessionHandler(
connection, base_session_id, False, store, serializer, lifetime=60
)

session_id = regenerate_session_id(connection)
assert session_id
Expand All @@ -34,7 +36,9 @@ def test_get_session_id(store: SessionStore, serializer: Serializer) -> None:
base_session_id = "some_id"
connection = HTTPConnection(scope)
connection.scope["session"] = {}
connection.scope["session_handler"] = SessionHandler(connection, base_session_id, store, serializer, lifetime=60)
connection.scope["session_handler"] = SessionHandler(
connection, base_session_id, False, store, serializer, lifetime=60
)

session_id = get_session_id(connection)
assert session_id == base_session_id
Expand All @@ -45,7 +49,9 @@ def test_get_session_handler(store: SessionStore, serializer: Serializer) -> Non
base_session_id = "some_id"
connection = HTTPConnection(scope)
connection.scope["session"] = {}
connection.scope["session_handler"] = SessionHandler(connection, base_session_id, store, serializer, lifetime=60)
connection.scope["session_handler"] = SessionHandler(
connection, base_session_id, False, store, serializer, lifetime=60
)

assert get_session_handler(connection) == connection.scope["session_handler"]

Expand All @@ -56,7 +62,9 @@ async def test_load_session(store: SessionStore, serializer: Serializer) -> None
base_session_id = "session_id"
connection = HTTPConnection(scope)
connection.scope["session"] = {}
connection.scope["session_handler"] = SessionHandler(connection, base_session_id, store, serializer, lifetime=60)
connection.scope["session_handler"] = SessionHandler(
connection, base_session_id, False, store, serializer, lifetime=60
)

await store.write("session_id", b'{"key": "value"}', lifetime=60, ttl=60)
await load_session(connection)
Expand Down

0 comments on commit 65b13d1

Please sign in to comment.