Skip to content

Commit

Permalink
Merge pull request #151 from arnoldknott/dev
Browse files Browse the repository at this point in the history
Dev: adds session handling and callbacks to socketio namespaces
  • Loading branch information
arnoldknott authored Dec 8, 2024
2 parents 242a71c + 62c54b0 commit 3659939
Show file tree
Hide file tree
Showing 11 changed files with 84 additions and 25 deletions.
2 changes: 1 addition & 1 deletion backendAPI/src/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,8 +24,8 @@
from routers.api.v1.public_resource import router as public_resource_router
from routers.api.v1.tag import router as tag_router
from routers.socketio.v1.base import presentation_interests_router, socketio_server
from routers.socketio.v1.public_namespace import public_namespace_router
from routers.socketio.v1.demo_namespace import demo_namespace_router
from routers.socketio.v1.public_namespace import public_namespace_router
from routers.ws.v1.websockets import router as websocket_router

# print("Current directory:", os.getcwd())
Expand Down
29 changes: 24 additions & 5 deletions backendAPI/src/routers/socketio/v1/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,9 +4,9 @@

from core.config import config
from core.security import (
check_token_against_guards,
get_azure_token_payload,
get_token_from_cache,
check_token_against_guards,
)
from core.types import GuardTypes

Expand Down Expand Up @@ -130,13 +130,17 @@ def __init__(
room: str = None,
guards: GuardTypes = None,
crud=None,
callback_on_connect=None,
callback_on_disconnect=None,
):
super().__init__(namespace=namespace)
self.guards = guards
self.crud = crud
self.server = socketio_server
self.namespace = namespace
self.room = room
self.callback_on_connect = callback_on_connect
self.callback_on_disconnect = callback_on_disconnect

async def callback(self):
print("=== base - callback ===")
Expand All @@ -159,15 +163,26 @@ async def on_connect(
try:
# TBD: add get scopes from guards - potentially distinguish between MSGraph scopes and backendAPI scopes?!
# token = await get_token_from_cache(auth["session_id"], ["User.Read"])
# catch and handle an expired token gracefully and return something to the client on a different message channel,
# so it can initiate the authentication process and come back with a new session id
token = await get_token_from_cache(
auth["session_id"], [f"api://{config.API_SCOPE}/socketio"]
) # TBD: add get scopes from guards - potentially distinguish between MSGraph scopes and backendAPI scopes?!
token_payload = await get_azure_token_payload(token)
print("=== base - on_connect - token_payload ===")
print(token_payload, flush=True)
# print("=== base - on_connect - token_payload ===")
# print(token_payload, flush=True)
# print("=== base - on_connect - token_payload - name ===")
# print(token_payload["name"], flush=True)
current_user = await check_token_against_guards(token_payload, guards)
print("=== base - on_connect - current_user ===")
print(current_user, flush=True)
session_data = {
"user_name": token_payload["name"],
"current_user": current_user,
}
await self.server.save_session(
sid, session_data, namespace=self.namespace
)
# print("=== base - on_connect - current_user ===")
# print(current_user, flush=True)
logger.info(
f"Client authenticated to access protected namespace {self.namespace}."
)
Expand All @@ -179,6 +194,8 @@ async def on_connect(
else:
current_user = None
logger.info(f"Client authenticated to public namespace {self.namespace}.")
if self.callback_on_connect is not None:
await self.callback_on_connect(sid)

# current_user = await check_token_against_guards(token_payload, self.guards)
# print("=== base - on_connect - sid - current_user ===")
Expand All @@ -195,3 +212,5 @@ async def on_connect(
async def on_disconnect(self, sid):
"""Disconnect event for socket.io namespaces."""
logger.info(f"Client with session id {sid} disconnected.")
if self.callback_on_disconnect is not None:
await self.callback_on_disconnect(sid)
27 changes: 25 additions & 2 deletions backendAPI/src/routers/socketio/v1/demo_namespace.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,8 +24,31 @@ def __init__(self, namespace=None):
namespace=namespace,
guards=GuardTypes(scopes=["socketio", "api.write"], roles=["User"]),
crud=ProtectedResourceCRUD,
callback_on_connect=self.callback_on_connect,
)
self.namespace = namespace
# self.namespace = namespace

async def callback_on_connect(self, sid):
"""Callback on connect for socket.io namespaces."""
# print("=== demo_namespace - callback_on_connect - sid ===")
# print(sid)
# if self.server.get_session(sid):
session = await self.server.get_session(sid, namespace=self.namespace)
# print("=== demo_namespace - callback_on_connect - session ===")
# print(session, flush=True)
if session:
await self.server.emit(
"demo_message",
f"Welcome {session['user_name']} to {self.namespace}.",
namespace=self.namespace,
)
else:
await self.server.emit(
"demo_message",
f"Welcome ANONYMOUS to {self.namespace}.",
namespace=self.namespace,
)
return "callback_on_connect"

async def on_demo_message(self, sid, data):
"""Demo message event for socket.io namespaces with guards."""
Expand All @@ -36,5 +59,5 @@ async def on_demo_message(self, sid, data):
)


demo_namespace_router = DemoNamespace("/demo_namespace")
demo_namespace_router = DemoNamespace("/demo-namespace")
# socketio_server.register_namespace(ProtectedEvents())
10 changes: 10 additions & 0 deletions backendAPI/src/routers/socketio/v1/interactive_documentation.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,10 @@
import logging

from .base import BaseNamespace

logger = logging.getLogger(__name__)


class InteractiveDocumentation(BaseNamespace):
def __init__(self):
super().__init__(namespace="/interactive-documentation")
5 changes: 3 additions & 2 deletions backendAPI/src/routers/socketio/v1/public_namespace.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,9 @@
import logging

from .base import BaseNamespace
from crud.public_resource import PublicResourceCRUD

from .base import BaseNamespace

logger = logging.getLogger(__name__)


Expand All @@ -26,4 +27,4 @@ async def on_public_message(self, sid, data):
)


public_namespace_router = PublicNamespace("/public_namespace")
public_namespace_router = PublicNamespace("/public-namespace")
14 changes: 7 additions & 7 deletions backendAPI/src/routers/socketio/v1/tests/test_demo_namespace.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,14 +38,14 @@ async def test_demo_message_with_test_server(

sio = socketio_test_server

sio.register_namespace(DemoNamespace("/demo_namespace"))
sio.register_namespace(DemoNamespace("/demo-namespace"))

async for client in socketio_test_client(["/demo_namespace"]):
await client.emit("demo_message", "Something", namespace="/demo_namespace")
async for client in socketio_test_client(["/demo-namespace"]):
await client.emit("demo_message", "Something", namespace="/demo-namespace")

response = ""

@client.event(namespace="/demo_namespace")
@client.event(namespace="/demo-namespace")
async def demo_message(data):

nonlocal response
Expand All @@ -67,17 +67,17 @@ async def test_demo_message_with_production_server_fails_without_token(

try:
async for client in socketio_test_client(
["/demo_namespace"], "http://127.0.0.1:80"
["/demo-namespace"], "http://127.0.0.1:80"
):
response = None

@client.on("demo_message", namespace="/demo_namespace")
@client.on("demo_message", namespace="/demo-namespace")
async def handler(data):
nonlocal response
response = data

await client.emit(
"demo_message", "Hello, world!", namespace="/demo_namespace"
"demo_message", "Hello, world!", namespace="/demo-namespace"
)

await client.sleep(1)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -6,17 +6,17 @@ async def test_public_message_event_in_public_namespace(socketio_test_client):
"""Test the public message event in socket.io's public namespace."""

async for client in socketio_test_client(
["/public_namespace"], "http://127.0.0.1:80"
["/public-namespace"], "http://127.0.0.1:80"
):
response = None

@client.on("public_message", namespace="/public_namespace")
@client.on("public_message", namespace="/public-namespace")
async def handler(data):
nonlocal response
response = data

await client.emit(
"public_message", "Hello, world!", namespace="/public_namespace"
"public_message", "Hello, world!", namespace="/public-namespace"
)

await client.sleep(1)
Expand Down
3 changes: 3 additions & 0 deletions backendAPI/src/tests/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,12 +50,15 @@
# Mocks payload to detect scope api.write:
token_payload_user_id = {
"oid": many_test_azure_users[0]["azure_user_id"],
"name": "Test User",
}
token_payload_another_user_id = {
"oid": many_test_azure_users[1]["azure_user_id"],
"name": "Another Test User",
}
token_payload_random_user_id = {
"oid": many_test_azure_users[2]["azure_user_id"],
"name": "Random Test User",
}
token_payload_tenant_id = {
"tid": many_test_azure_users[0]["azure_tenant_id"],
Expand Down
4 changes: 2 additions & 2 deletions frontend_svelte/src/components/Chat.svelte
Original file line number Diff line number Diff line change
Expand Up @@ -23,8 +23,8 @@
$effect(() => {
socketio.client.on(connection.event, (data) => {
console.log(`Received: ${data}`);
old_messages.push(`Received: ${data}`);
console.log(`Received from socket.io server: ${data}`);
old_messages.push(`${data}`);
});
});
</script>
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -10,5 +10,8 @@ export const load: PageServerLoad = async ({ locals }) => {
throw new Error('No session id!');
}
// TBD: change scope to socketio!
await msalAuthProvider.getAccessToken(sessionId, [`${appConfig.api_scope}/socketio`, `${appConfig.api_scope}/api.write`]);
await msalAuthProvider.getAccessToken(sessionId, [
`${appConfig.api_scope}/socketio`,
`${appConfig.api_scope}/api.write`
]);
};
Original file line number Diff line number Diff line change
Expand Up @@ -13,13 +13,13 @@
const public_message_connection = {
event: 'public_message',
namespace: '/public_namespace',
namespace: '/public-namespace',
room: '',
cookie_session_id: $page.data.session.sessionId
};
const demo_message_connection = {
event: 'demo_message',
namespace: '/demo_namespace',
namespace: '/demo-namespace',
// namespace: '',
room: '',
cookie_session_id: $page.data.session.sessionId
Expand Down

0 comments on commit 3659939

Please sign in to comment.