diff --git a/examples/app.py b/examples/app.py index fe62c18..05ebd81 100644 --- a/examples/app.py +++ b/examples/app.py @@ -5,15 +5,15 @@ sio = SocketManager(app=app) -@app.sio.on('join') -async def handle_join(sid, *args, **kwargs): - await sio.emit('lobby', 'User joined') +@sio.event +async def connect(sid, *args, **kwargs): + print(f"[{sid}] Connected!") + await sio.emit('test', 'Hello world!') @sio.on('test') -async def test(sid, *args, **kwargs): - await sio.emit('hey', 'joe') - +async def test(sid, data, **kwargs): + print(f'[{sid}] Message Received! >> ', data) if __name__ == '__main__': @@ -25,4 +25,4 @@ async def test(sid, *args, **kwargs): import uvicorn - uvicorn.run("examples.app:app", host='0.0.0.0', port=8000, reload=True, debug=False) + uvicorn.run("app:app", host='0.0.0.0', port=8000) diff --git a/examples/client.py b/examples/client.py new file mode 100644 index 0000000..ceece6f --- /dev/null +++ b/examples/client.py @@ -0,0 +1,18 @@ +import socketio + +sio = socketio.Client() + + +@sio.event +def connect(): + print("Connected!") + + +@sio.on('test') +def on_message(data): + print('Message Received! >> ', data) + sio.emit('test', 'Hello world!') + + +sio.connect('http://127.0.0.1:8000') +sio.wait() diff --git a/examples/cors.py b/examples/cors.py new file mode 100644 index 0000000..b0e45af --- /dev/null +++ b/examples/cors.py @@ -0,0 +1,22 @@ +from fastapi import FastAPI +from fastapi.middleware.cors import CORSMiddleware + +from fastapi_socketio import SocketManager + +app = FastAPI() +# Adding the CORS middleware will overwrite SocketManager's CORS settings +# Make sure to add the CORS middleware before SocketManager +app.add_middleware( + CORSMiddleware, + allow_origins=["*"], + allow_credentials=True, + allow_methods=["*"], + allow_headers=["*"], +) +sio = SocketManager(app=app, cors_allowed_origins="*") + + +if __name__ == '__main__': + import uvicorn + + uvicorn.run("examples.cors:app", host='0.0.0.0', port=8000) diff --git a/fastapi_socketio/__init__.py b/fastapi_socketio/__init__.py index 4bf4510..34802f2 100644 --- a/fastapi_socketio/__init__.py +++ b/fastapi_socketio/__init__.py @@ -1 +1 @@ -from .socket_manager import SocketManager \ No newline at end of file +from .socket_manager import SocketManager diff --git a/fastapi_socketio/socket_manager.py b/fastapi_socketio/socket_manager.py index ff23726..ece4fb2 100644 --- a/fastapi_socketio/socket_manager.py +++ b/fastapi_socketio/socket_manager.py @@ -2,8 +2,10 @@ import socketio from fastapi import FastAPI +from fastapi.middleware.cors import CORSMiddleware -class SocketManager: + +class SocketManager(socketio.AsyncServer): """ Integrates SocketIO with FastAPI app. Adds `sio` property to FastAPI object (app). @@ -18,86 +20,26 @@ class SocketManager: """ def __init__( - self, - app: FastAPI, - mount_location: str = "/ws", - socketio_path: str = "socket.io", - cors_allowed_origins: Union[str, list] = '*', - async_mode: str = "asgi", - **kwargs + self, + app: FastAPI, + mount_location: str = "/ws", + socketio_path: str = "socket.io", + cors_allowed_origins: Union[str, list] = '*', + async_mode: str = "asgi", + **kwargs ) -> None: - # TODO: Change Cors policy based on fastapi cors Middleware - self._sio = socketio.AsyncServer(async_mode=async_mode, cors_allowed_origins=cors_allowed_origins, **kwargs) + middleware = next((x for x in app.user_middleware if issubclass(x.cls, CORSMiddleware)), None) + if middleware: + cors_allowed_origins = middleware.options.get("allow_origins", "*") + super().__init__(cors_allowed_origins=cors_allowed_origins, async_mode=async_mode, **kwargs) self._app = socketio.ASGIApp( - socketio_server=self._sio, socketio_path=socketio_path + socketio_server=self, socketio_path=socketio_path ) app.mount(mount_location, self._app) - app.sio = self._sio + app.add_route(f"/{socketio_path}/", route=self._app, methods=["GET", "POST"]) + app.add_websocket_route(f"/{socketio_path}/", self._app) + app.sio = self def is_asyncio_based(self) -> bool: return True - - @property - def on(self): - return self._sio.on - - @property - def attach(self): - return self._sio.attach - - @property - def emit(self): - return self._sio.emit - - @property - def send(self): - return self._sio.send - - @property - def call(self): - return self._sio.call - - @property - def close_room(self): - return self._sio.close_room - - @property - def get_session(self): - return self._sio.get_session - - @property - def save_session(self): - return self._sio.save_session - - @property - def session(self): - return self._sio.session - - @property - def disconnect(self): - return self._sio.disconnect - - @property - def handle_request(self): - return self._sio.handle_request - - @property - def start_background_task(self): - return self._sio.start_background_task - - @property - def sleep(self): - return self._sio.sleep - - @property - def enter_room(self): - return self._sio.enter_room - - @property - def leave_room(self): - return self._sio.leave_room - - @property - def register_namespace(self): - return self._sio.register_namespace