Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix get_session_remaining_seconds() for rolling sessions #70

Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion starsessions/middleware.py
Original file line number Diff line number Diff line change
Expand Up @@ -66,7 +66,7 @@ async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None:

connection = HTTPConnection(scope)
session_id = connection.cookies.get(self.cookie_name)
handler = SessionHandler(connection, session_id, self.store, self.serializer, self.lifetime)
handler = SessionHandler(connection, session_id, self.rolling, self.store, self.serializer, self.lifetime)

scope["session"] = LoadGuard()
scope["session_handler"] = handler
Expand Down
16 changes: 14 additions & 2 deletions starsessions/session.py
Original file line number Diff line number Diff line change
Expand Up @@ -72,7 +72,10 @@ def get_session_remaining_seconds(connection: HTTPConnection) -> int:
"""Get total seconds remaining before this session expires."""
now = time.time()
metadata = get_session_metadata(connection)
return int((metadata["created"] + metadata["lifetime"]) - now)
# use "last_update" if rolling session, otherwise use "created"
rolling = metadata.get("rolling", False)
last_update = metadata["last_update"] if rolling and "last_update" in metadata else metadata["created"]
Copy link
Owner

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

What if we will use last_access everywhere? This seems very natural to me. Rolling session is alive while it's being used. So last_access may work.

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I tried that but last_access is updated before in the load function, whereas last_update is updated in the save function

Copy link
Owner

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I believe we should use last_access here, as the purpose of rolling session is not to expire while being used.
last_access is the timestamp when the session was last loaded, read used.

I think that we should not use last_updated to count session remaining time because I think it is a good idea not to call handler.save at all if session was not modified to avoid network round trip, e.g. with Redis backend (feature out of scope).

Since we always update last_access on every load (read endpoint call), we do extend rolling session TTL by this action. So, my opinion is just to use last_access in place of created and this should solve your requirement.

return int((last_update + metadata["lifetime"]) - now)


class SessionHandler:
Expand All @@ -86,12 +89,14 @@ def __init__(
self,
connection: HTTPConnection,
session_id: typing.Optional[str],
rolling: bool,
store: SessionStore,
serializer: Serializer,
lifetime: int,
) -> None:
self.connection = connection
self.session_id = session_id
self.rolling = rolling
self.store = store
self.serializer = serializer
self.is_loaded = False
Expand All @@ -115,7 +120,12 @@ async def load(self) -> None:
)

# read and merge metadata
metadata = {"lifetime": self.lifetime, "created": time.time(), "last_access": time.time()}
metadata = {
"lifetime": self.lifetime,
"created": time.time(),
"last_access": time.time(),
"rolling": self.rolling,
}
metadata.update(data.pop("__metadata__", {}))
metadata.update({"last_access": time.time()}) # force update
self.metadata = metadata # type: ignore[assignment]
Expand All @@ -126,6 +136,8 @@ async def load(self) -> None:
self.initially_empty = len(self.connection.session) == 0

async def save(self, remaining_time: int) -> str:
if self.rolling and self.metadata:
self.metadata["last_update"] = time.time()
self.connection.session.update({"__metadata__": self.metadata})

self.session_id = await self.store.write(
Expand Down
1 change: 1 addition & 0 deletions starsessions/types.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,3 +5,4 @@ class SessionMetadata(typing.TypedDict):
lifetime: int
created: float # timestamp
last_access: float # timestamp
last_update: float # timestamp
4 changes: 2 additions & 2 deletions tests/backends/test_redis.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
import os
import pytest
import redis.asyncio
from pytest_asyncio.plugin import SubRequest
from pytest import FixtureRequest

from starsessions import ImproperlyConfigured
from starsessions.stores.base import SessionStore
Expand All @@ -13,7 +13,7 @@ def redis_key_callable(session_id: str) -> str:


@pytest.fixture(params=["prefix_", redis_key_callable], ids=["using string", "using redis_key_callable"])
def redis_store(request: SubRequest) -> SessionStore:
def redis_store(request: FixtureRequest) -> SessionStore:
redis_key = request.param
url = os.environ.get("REDIS_URL", "redis://localhost")
return RedisStore(url, prefix=redis_key)
Expand Down
4 changes: 4 additions & 0 deletions tests/test_metadata.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,7 @@ async def app(scope: Scope, receive: Receive, send: Send) -> None:
"created": 1660556520,
"last_access": 1660556520,
"lifetime": 1209600,
"rolling": False,
}


Expand All @@ -59,6 +60,7 @@ async def app(scope: Scope, receive: Receive, send: Send) -> None:
"created": 42,
"last_access": 1660556520,
"lifetime": 0,
"rolling": False,
}


Expand All @@ -79,11 +81,13 @@ async def app(scope: Scope, receive: Receive, send: Send) -> None:
"created": 1660556520,
"last_access": 1660556520,
"lifetime": 1209600,
"rolling": False,
}

with mock.patch("time.time", lambda: 1660556000):
assert client.get("/", cookies={"session": "session_id"}).json() == {
"created": 1660556520,
"last_access": 1660556000,
"lifetime": 1209600,
"rolling": False,
}
14 changes: 8 additions & 6 deletions tests/test_rolling_session.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
from unittest import mock

from starsessions import SessionMiddleware, SessionStore, load_session
from starsessions.session import get_session_remaining_seconds


@pytest.mark.asyncio
Expand All @@ -19,21 +20,22 @@ async def app(scope: Scope, receive: Receive, send: Send) -> None:
await load_session(connection)
response = JSONResponse(connection.session)
await response(scope, receive, send)
# the session has been rolled again to last for another 10 seconds
assert get_session_remaining_seconds(connection) == session_lifetime

app = SessionMiddleware(app, store=store, lifetime=10, rolling=True)
client = TestClient(app)

current_time = time.time()

# it must set max-age = 10
with mock.patch("time.time", lambda: current_time):
response = client.get("/", cookies={"session": "session_id"})
first_max_age = next(cookie for cookie in response.cookies if cookie.name == "session").expires
first_expiration_date = next(cookie for cookie in response.cookies if cookie.name == "session").expires

# it must set max-age = 10 regardless of any previous value
# fast forward 2 seconds
with mock.patch("time.time", lambda: current_time + 2):
response = client.get("/", cookies={"session": "session_id"})
second_max_age = next(cookie for cookie in response.cookies if cookie.name == "session").expires
second_expiration_date = next(cookie for cookie in response.cookies if cookie.name == "session").expires

# the expiration date of the second response must be larger
assert second_max_age > first_max_age
# the expiration date of the cooke in the second response must be extended by 2 seconds
assert second_expiration_date - first_expiration_date == 2
16 changes: 12 additions & 4 deletions tests/test_session.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,9 @@ def test_regenerate_session_id(store: SessionStore, serializer: Serializer) -> N
base_session_id = "some_id"
connection = HTTPConnection(scope)
connection.scope["session"] = {}
connection.scope["session_handler"] = SessionHandler(connection, base_session_id, store, serializer, lifetime=60)
connection.scope["session_handler"] = SessionHandler(
connection, base_session_id, False, store, serializer, lifetime=60
)

session_id = regenerate_session_id(connection)
assert session_id
Expand All @@ -34,7 +36,9 @@ def test_get_session_id(store: SessionStore, serializer: Serializer) -> None:
base_session_id = "some_id"
connection = HTTPConnection(scope)
connection.scope["session"] = {}
connection.scope["session_handler"] = SessionHandler(connection, base_session_id, store, serializer, lifetime=60)
connection.scope["session_handler"] = SessionHandler(
connection, base_session_id, False, store, serializer, lifetime=60
)

session_id = get_session_id(connection)
assert session_id == base_session_id
Expand All @@ -45,7 +49,9 @@ def test_get_session_handler(store: SessionStore, serializer: Serializer) -> Non
base_session_id = "some_id"
connection = HTTPConnection(scope)
connection.scope["session"] = {}
connection.scope["session_handler"] = SessionHandler(connection, base_session_id, store, serializer, lifetime=60)
connection.scope["session_handler"] = SessionHandler(
connection, base_session_id, False, store, serializer, lifetime=60
)

assert get_session_handler(connection) == connection.scope["session_handler"]

Expand All @@ -56,7 +62,9 @@ async def test_load_session(store: SessionStore, serializer: Serializer) -> None
base_session_id = "session_id"
connection = HTTPConnection(scope)
connection.scope["session"] = {}
connection.scope["session_handler"] = SessionHandler(connection, base_session_id, store, serializer, lifetime=60)
connection.scope["session_handler"] = SessionHandler(
connection, base_session_id, False, store, serializer, lifetime=60
)

await store.write("session_id", b'{"key": "value"}', lifetime=60, ttl=60)
await load_session(connection)
Expand Down
Loading