diff --git a/README.md b/README.md index 39d11cc..aeda292 100644 --- a/README.md +++ b/README.md @@ -11,63 +11,108 @@ pip3 install realtime==1.0.2 ## Installation from source ```bash -pip3 install -r requirements.txt +poetry install python3 usage.py - ``` ## Quick Start ```python from realtime.connection import Socket +import asyncio -def callback1(payload): - print("Callback 1: ", payload) +async def callback1(payload): + print(f"Got message: {payload}") -def callback2(payload): - print("Callback 2: ", payload) +async def main(): -if __name__ == "__main__": - URL = "ws://localhost:4000/socket/websocket" - s = Socket(URL) - s.connect() + # your phoenix server token + TOKEN = "" + # your phoenix server URL + URL = f"ws://127.0.0.1:4000/socket/websocket?token={TOKEN}&vsn=2.0.0" + + client = Socket(URL) + + # connect to the server + await client.connect() + + # fire and forget the listening routine + listen_task = asyncio.ensure_future(client.listen()) + + # join the channel + channel = client.set_channel("this:is:my:topic") + await channel.join() + + # by using a partial function + channel.on("your_event_name", None, callback1) + + # we give it some time to complete + await asyncio.sleep(10) - channel_1 = s.set_channel("realtime:public:todos") - channel_1.join().on("UPDATE", callback1) + # proper shut down + listen_task.cancel() - channel_2 = s.set_channel("realtime:public:users") - channel_2.join().on("*", callback2) +if __name__ == '__main__': + loop = asyncio.get_event_loop() + try: + loop.run_until_complete(main()) - s.listen() + except KeyboardInterrupt: + loop.stop() + exit(0) ``` +## Sending and Receiving data +Sending data to phoenix channels using `send`: +```python +await channel.send("your_handler", "this is my payload", None) +``` +One can also use references for queries/answers: +```python + +ref = 1 +channel.on(None, ref, callback1) +await channel.send("your_handler", "this is my payload", ref) +# remove the callback when your are done +# |-> exercise left to the reader ;) +channel.off(None, ref, callback1) +``` +## Examples +see `usage.py`, `sending-receiving-usage.py`, and `fd-usage.py`. ## Sample usage with Supabase -Here's how you could connect to your realtime endpoint using Supabase endpoint. Correct as of 5th June 2021. Please replace `SUPABASE_ID` and `API_KEY` with your own `SUPABASE_ID` and `API_KEY`. The variables shown below are fake and they will not work if you try to run the snippet. +Here's how you could connect to your realtime endpoint using Supabase endpoint. Should be correct as of 13th Feb 2024. Please replace `SUPABASE_ID` and `API_KEY` with your own `SUPABASE_ID` and `API_KEY`. The variables shown below are fake and they will not work if you try to run the snippet. ```python from realtime.connection import Socket +import asyncio SUPABASE_ID = "dlzlllxhaakqdmaapvji" API_KEY = "eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJyb2xlIjoiYW5vbiIsImlhdCI6MT" -def callback1(payload): +async def callback1(payload): print("Callback 1: ", payload) -if __name__ == "__main__": +async def main(): URL = f"wss://{SUPABASE_ID}.supabase.co/realtime/v1/websocket?apikey={API_KEY}&vsn=1.0.0" s = Socket(URL) - s.connect() + await s.connect() + listen_task = asyncio.ensure_future(s.listen()) channel_1 = s.set_channel("realtime:*") - channel_1.join().on("UPDATE", callback1) - s.listen() - -``` - -Then, go to the Supabase interface and toggle a row in a table. You should see a corresponding payload show up in your console/terminal. + await channel_1.join() + channel_1.on("UPDATE", callback1) +if __name__ == '__main__': + loop = asyncio.get_event_loop() + try: + loop.run_until_complete(main()) + except KeyboardInterrupt: + loop.stop() + exit(0) +``` +Then, go to the Supabase interface and toggle a row in a table. You should see a corresponding payload show up in your console/terminal. \ No newline at end of file diff --git a/fd-usage.py b/fd-usage.py new file mode 100644 index 0000000..3e9e82b --- /dev/null +++ b/fd-usage.py @@ -0,0 +1,65 @@ +from realtime.connection import Socket +import asyncio +import uuid +import json +# We will use a partial function to pass the file descriptor to the callback +from functools import partial + +# notice that the callback has two arguments +# and that it is not an async function +# it will be executed in a different thread +def callback(fd, payload): + fd.write(json.dumps(payload)) + print(f"Callback with reference c2: {payload}") + +async def main(): + + # your phoenix server token + TOKEN = "" + # your phoenix server URL + URL = f"ws://127.0.0.1:4000/socket/websocket?token={TOKEN}&vsn=2.0.0" + + # We create a file descriptor to write the received messages + fd = create_file_and_return_fd() + + client = Socket(URL) + + # connect to the server + await client.connect() + + # fire and forget the listening routine + listen_task = asyncio.ensure_future(client.listen()) + + # join the channel + channel = client.set_channel("this:is:my:topic") + await channel.join() + + # we can also use reference for the callback + # with a proper reply elixir handler: + #def handle_in("ping", payload, socket) do + # {:reply, {:ok, payload}, socket} + # Here we use uuid, use whatever you want + ref = str(uuid.uuid4()) + # Pass the file descriptor to the callback through a partial function + channel.on(None, ref, partial(callback, fd)) + await channel.send("ping", "this is the ping payload that shall appear in myfile.txt", ref) + + # we give it some time to complete + await asyncio.sleep(10) + + # proper shut down + listen_task.cancel() + fd.close() + +def create_file_and_return_fd(): + fd = open("myfile.txt", "w") + return fd + +if __name__ == '__main__': + loop = asyncio.get_event_loop() + try: + loop.run_until_complete(main()) + + except KeyboardInterrupt: + loop.stop() + exit(0) \ No newline at end of file diff --git a/poetry.lock b/poetry.lock index cc13076..a168c32 100644 --- a/poetry.lock +++ b/poetry.lock @@ -5,7 +5,7 @@ name = "annotated-types" version = "0.6.0" description = "Reusable constraint types to use with typing.Annotated" optional = false -python-versions = ">=3.8" +python-versions = ">=3.9" files = [ {file = "annotated_types-0.6.0-py3-none-any.whl", hash = "sha256:0641064de18ba7a25dee8f96403ebc39113d0cb953a01429249d5c7564666a43"}, {file = "annotated_types-0.6.0.tar.gz", hash = "sha256:563339e807e53ffd9c267e99fc6d9ea23eb8443c08f112651963e24e22f84a5d"}, diff --git a/pyproject.toml b/pyproject.toml index 2831bfe..0b09188 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -15,6 +15,7 @@ python = "^3.8" websockets = "^11.0" python-dateutil = "^2.8.1" typing-extensions = "^4.2.0" +uuid = "^1.30" [tool.poetry.dev-dependencies] pytest = "^7.2.0" diff --git a/realtime/channel.py b/realtime/channel.py index e4ac908..8aaf7d3 100644 --- a/realtime/channel.py +++ b/realtime/channel.py @@ -1,9 +1,11 @@ from __future__ import annotations -import asyncio import json -from typing import Any, List, Dict, TYPE_CHECKING, NamedTuple +import logging +import uuid +from typing import List, TYPE_CHECKING, NamedTuple, Dict, Any +from realtime.message import ChannelEvents from realtime.types import Callback if TYPE_CHECKING: @@ -13,6 +15,7 @@ class CallbackListener(NamedTuple): """A tuple with `event` and `callback` """ event: str + ref: str callback: Callback @@ -20,6 +23,7 @@ class Channel: """ `Channel` is an abstraction for a topic listener for an existing socket connection. Each Channel has its own topic and a list of event-callbacks that responds to messages. + A client can also send messages to a channel and register callback when expecting replies. Should only be instantiated through `connection.Socket().set_channel(topic)` Topic-Channel has a 1-many relationship. """ @@ -35,45 +39,81 @@ def __init__(self, socket: Socket, topic: str, params: Dict[str, Any] = {}) -> N self.topic = topic self.listeners: List[CallbackListener] = [] self.joined = False + self.join_ref = str(uuid.uuid4()) + self.control_msg_ref = "" - def join(self) -> Channel: + async def join(self) -> None: """ - Wrapper for async def _join() to expose a non-async interface - Essentially gets the only event loop and attempt joining a topic - :return: Channel + Coroutine that attempts to join Phoenix Realtime server via a certain topic + :return: None """ - loop = asyncio.get_event_loop() # TODO: replace with get_running_loop - loop.run_until_complete(self._join()) - return self + if self.socket.version == 1: + join_req = dict(topic=self.topic, event=ChannelEvents.join, + payload={}, ref=None) + elif self.socket.version == 2: + # [join_reference, message_reference, topic_name, event_name, payload] + self.control_msg_ref = str(uuid.uuid4()) + join_req = [self.join_ref, self.control_msg_ref, self.topic, ChannelEvents.join, self.params] + + try: + await self.socket.ws_connection.send(json.dumps(join_req)) + except Exception as e: + logging.error(f"Error while joining channel: {str(e)}", exc_info=True) + return - async def _join(self) -> None: + async def leave(self) -> None: """ - Coroutine that attempts to join Phoenix Realtime server via a certain topic + Coroutine that attempts to leave Phoenix Realtime server via a certain topic :return: None """ - join_req = dict(topic=self.topic, event="phx_join", - payload={}, ref=None) + if self.socket.version == 1: + leave_req = dict(topic=self.topic, event=ChannelEvents.leave, + payload={}, ref=None) + elif self.socket.version == 2: + leave_req = [self.join_ref, None, self.topic, ChannelEvents.leave, {}] try: - await self.socket.ws_connection.send(json.dumps(join_req)) + await self.socket.ws_connection.send(json.dumps(leave_req)) except Exception as e: - print(str(e)) # TODO: better error propagation + logging.error(f"Error while leaving channel: {str(e)}", exc_info=True) return - def on(self, event: str, callback: Callback) -> Channel: + def on(self, event: str, ref: str, callback: Callback) -> Channel: """ :param event: A specific event will have a specific callback + :param ref: A specific reference that will have a specific callback :param callback: Callback that takes msg payload as its first argument :return: Channel """ - cl = CallbackListener(event=event, callback=callback) + cl = CallbackListener(event=event, ref=ref, callback=callback) self.listeners.append(cl) return self - def off(self, event: str) -> None: + def off(self, event: str, ref: str) -> None: """ :param event: Stop responding to a certain event + :param event: Stop responding to a certain reference :return: None """ self.listeners = [ - callback for callback in self.listeners if callback.event != event] + callback for callback in self.listeners if (callback.event != event and callback.ref != ref)] + + async def send(self, event_name: str, payload: str, ref: str) -> None: + """ + Coroutine that attempts to join Phoenix Realtime server via a certain topic + :param event_name: The event_name: it must match the first argument of a handle_in function on the server channel module. + :param payload: The payload to be sent to the phoenix server + :param ref: The message reference that the server will use for replying + :return: None + """ + if self.socket.version == 1: + msg = dict(topic=self.topic, event=event_name, + payload=payload, ref=None) + elif self.socket.version == 2: + msg = [None, ref, self.topic, event_name, payload] + + try: + await self.socket.ws_connection.send(json.dumps(msg)) + except Exception as e: + logging.error(f"Error while sending message: {str(e)}", exc_info=True) + raise diff --git a/realtime/connection.py b/realtime/connection.py index 1c2e7eb..ed347f1 100644 --- a/realtime/connection.py +++ b/realtime/connection.py @@ -7,6 +7,12 @@ import websockets from typing_extensions import ParamSpec +from websockets.exceptions import ( + ConnectionClosed, + InvalidHandshake, + InvalidMessage, + ConnectionClosedOK, +) from realtime.channel import Channel from realtime.exceptions import NotConnectedError @@ -29,6 +35,10 @@ def wrapper(*args: T_ParamSpec.args, **kwargs: T_ParamSpec.kwargs) -> T_Retval: return wrapper +class CallbackError(Exception): + pass + + class Socket: def __init__( self, @@ -36,6 +46,8 @@ def __init__( auto_reconnect: bool = False, params: Dict[str, Any] = {}, hb_interval: int = 5, + version: int = 2, + ping_timeout: int = 20, ) -> None: """ `Socket` is the abstraction for an actual socket connection that receives and 'reroutes' `Message` according to its `topic` and `event`. @@ -43,7 +55,8 @@ def __init__( Socket-Topic has a 1-many relationship. :param url: Websocket URL of the Realtime server. starts with `ws://` or `wss://` :param params: Optional parameters for connection. - :param hb_interval: WS connection is kept alive by sending a heartbeat message. Optional, defaults to 5. + :param hb_interval: WS connection is kept alive by sending a heartbeat message. Optional, defaults to 30. + :param version: phoenix JSON serializer version. """ self.url = url self.channels = defaultdict(list) @@ -51,93 +64,181 @@ def __init__( self.params = params self.hb_interval = hb_interval self.ws_connection: websockets.client.WebSocketClientProtocol - self.kept_alive = False + self.kept_alive = set() self.auto_reconnect = auto_reconnect + self.version = version + self.ping_timeout = ping_timeout self.channels: DefaultDict[str, List[Channel]] = defaultdict(list) - @ensure_connection - def listen(self) -> None: - """ - Wrapper for async def _listen() to expose a non-async interface - In most cases, this should be the last method executed as it starts an infinite listening loop. - :return: None - """ - loop = asyncio.get_event_loop() # TODO: replace with get_running_loop - loop.run_until_complete(asyncio.gather(self._listen(), self._keep_alive())) + async def _run_callback_safe(self, callback: Callback, payload: Dict) -> None: + try: + if asyncio.iscoroutinefunction(callback): + asyncio.create_task(callback(payload)) + else: + asyncio.create_task(asyncio.to_thread(callback, payload)) + except Exception as e: + raise CallbackError("Error in callback") from e - async def _listen(self) -> None: + @ensure_connection + async def listen(self) -> None: """ An infinite loop that keeps listening. :return: None """ + if self.hb_interval >= 0: + self.kept_alive.add(asyncio.create_task(self.keep_alive())) + while True: try: + await asyncio.sleep(0) + msg = await self.ws_connection.recv() - msg = Message(**json.loads(msg)) + if self.version == 1: + msg = Message(**json.loads(msg)) + elif self.version == 2: + msg_array = json.loads(msg) + msg = Message( + join_ref=msg_array[0], + ref=msg_array[1], + topic=msg_array[2], + event=msg_array[3], + payload=msg_array[4], + ) if msg.event == ChannelEvents.reply: - continue + for channel in self.channels.get(msg.topic, []): + if msg.ref == channel.control_msg_ref: + if msg.payload["status"] == "error": + logging.info( + f"Error joining channel: {msg.topic} - {msg.payload['response']['reason']}" + ) + break + elif msg.payload["status"] == "ok": + logging.info(f"Successfully joined {msg.topic}") + continue + else: + for cl in channel.listeners: + if cl.ref in ["*", msg.ref]: + await self._run_callback_safe(cl.callback, msg.payload) + + if msg.event == ChannelEvents.close: + for channel in self.channels.get(msg.topic, []): + if msg.join_ref == channel.join_ref: + logging.info(f"Successfully left {msg.topic}") + continue for channel in self.channels.get(msg.topic, []): for cl in channel.listeners: if cl.event in ["*", msg.event]: - cl.callback(msg.payload) - except websockets.exceptions.ConnectionClosed: - if self.auto_reconnect: - logging.info( - "Connection with server closed, trying to reconnect..." - ) - await self._connect() - for topic, channels in self.channels.items(): - for channel in channels: - await channel._join() - else: - logging.exception("Connection with the server closed.") - break - - def connect(self) -> None: - """ - Wrapper for async def _connect() to expose a non-async interface - """ - loop = asyncio.get_event_loop() # TODO: replace with get_running - loop.run_until_complete(self._connect()) - self.connected = True + await self._run_callback_safe(cl.callback, msg.payload) + + except ConnectionClosedOK: + logging.info("Connection was closed normally.") + await self.leave_all() + break + + except InvalidMessage: + logging.error( + "Received an invalid message. Check message format and content." + ) + + except ConnectionClosed as e: + logging.error(f"Connection closed unexpectedly: {e}") + await self._handle_reconnection() + + except InvalidHandshake: + logging.error( + "Invalid handshake while connecting. Ensure your client and server configurations match." + ) - async def _connect(self) -> None: - ws_connection = await websockets.connect(self.url) + except asyncio.CancelledError: + logging.info("Listen task was cancelled.") + await self.leave_all() + break - if ws_connection.open: - logging.info("Connection was successful") - self.ws_connection = ws_connection - self.connected = True + except CallbackError: + logging.info("Error in callback") + + except ( + Exception + ) as e: # A general exception handler should be the last resort + logging.error(f"Unexpected error in listen: {e}") + await self._handle_reconnection() + + async def connect(self) -> None: + while True: + try: + ws_connection = await websockets.connect(self.url, ping_timeout=self.ping_timeout) + + self.ws_connection = ws_connection + self.connected = True + logging.info("Connection was successful") + break + except OSError: + logging.error( + "Connection failed. Retrying in 3 seconds. Ensure the server is alive and responsive." + ) + await asyncio.sleep(3) + except asyncio.CancelledError: + logging.info("Connect task was cancelled.") + break + + async def _handle_reconnection(self) -> None: + if self.auto_reconnect: + logging.info("Connection with server closed, trying to reconnect...") + await self.connect() + for topic, channels in self.channels.items(): + for channel in channels: + await channel.join() else: - raise Exception("Connection Failed") + logging.exception("Connection with the server closed.") + + async def leave_all(self) -> None: + for channel in self.channels: + for chan in self.channels.get(channel, []): + await chan.leave() - async def _keep_alive(self) -> None: + async def keep_alive(self) -> None: """ Sending heartbeat to server every 5 seconds Ping - pong messages to verify connection is alive """ while True: try: - data = dict( - topic=PHOENIX_CHANNEL, - event=ChannelEvents.heartbeat, - payload=HEARTBEAT_PAYLOAD, - ref=None, - ) + if self.version == 1: + data = dict( + topic=PHOENIX_CHANNEL, + event=ChannelEvents.heartbeat, + payload=HEARTBEAT_PAYLOAD, + ref=None, + ) + elif self.version == 2: + # [null,"4","phoenix","heartbeat",{}] + data = [ + None, + None, + PHOENIX_CHANNEL, + ChannelEvents.heartbeat, + HEARTBEAT_PAYLOAD, + ] + await self.ws_connection.send(json.dumps(data)) await asyncio.sleep(self.hb_interval) - except websockets.exceptions.ConnectionClosed: - if self.auto_reconnect: - logging.info( - "Connection with server closed, trying to reconnect..." - ) - await self._connect() - else: - logging.exception("Connection with the server closed.") - break + + except asyncio.CancelledError: + logging.info("Keep alive task was cancelled.") + break + except ConnectionClosed: + logging.error( + "Connection closed unexpectedly during heartbeat. Ensure the server is alive and responsive." + ) + await self._handle_reconnection() + + except ( + Exception + ) as e: # A general exception handler should be the last resort + logging.error(f"Unexpected error in keep_alive: {e}") @ensure_connection def set_channel(self, topic: str) -> Channel: @@ -150,11 +251,20 @@ def set_channel(self, topic: str) -> Channel: return chan + def remove_channel(self, topic: str) -> None: + """ + :param topic: Removes a channel from the socket + :return: None + """ + self.channels.pop(topic, None) + def summary(self) -> None: """ - Prints a list of topics and event the socket is listening to + Prints a list of topics and event, and reference that the socket is listening to :return: None """ for topic, chans in self.channels.items(): for chan in chans: - print(f"Topic: {topic} | Events: {[e for e, _ in chan.callbacks]}]") + print( + f"Topic: {topic} | Events: {[e for e, _, _ in chan.listeners]} | References: {[r for _, r, _ in chan.listeners]}]" + ) diff --git a/realtime/message.py b/realtime/message.py index 9909d4d..532a145 100644 --- a/realtime/message.py +++ b/realtime/message.py @@ -13,6 +13,8 @@ class Message: ref: Any topic: str + join_ref: Any = None # V2 + def __hash__(self): return hash((self.event, tuple(list(self.payload.values())), self.ref, self.topic)) @@ -32,4 +34,4 @@ class ChannelEvents(str, Enum): PHOENIX_CHANNEL = "phoenix" -HEARTBEAT_PAYLOAD = {"msg": "ping"} +HEARTBEAT_PAYLOAD = {} diff --git a/sending-receiving-usage.py b/sending-receiving-usage.py new file mode 100644 index 0000000..597e24b --- /dev/null +++ b/sending-receiving-usage.py @@ -0,0 +1,66 @@ +from realtime.connection import Socket +import asyncio +import uuid + +async def callback1(payload): + print(f"c1: {payload}") + +def callback2(payload): + print(f"c2: {payload}") + +async def main(): + + # your phoenix server token + TOKEN = "" + # your phoenix server URL + URL = f"ws://127.0.0.1:4000/socket/websocket?token={TOKEN}&vsn=2.0.0" + + client = Socket(URL) + + # connect to the server + await client.connect() + + # fire and forget the listening routine + listen_task = asyncio.ensure_future(client.listen()) + + # join the channel + channel = client.set_channel("this:is:my:topic") + await channel.join() + + channel.on("test_event", None, callback1) + + # here is an example corresponding elixir handler for the sake of the example: + #def handle_in("request_ping", payload, socket) do + # push(socket, "test_event", %{body: payload}) + # {:noreply, socket} + #end + + await channel.send("request_ping", "this is my payload 1", None) + await channel.send("request_ping", "this is my payload 2", None) + await channel.send("request_ping", "this is my payload 3", None) + + # we can also use reference for the callback + # with a proper reply elixir handler: + #def handle_in("ping", payload, socket) do + # {:reply, {:ok, payload}, socket} + #end + + # Here we use uuid, use whatever you want + ref = str(uuid.uuid4()) + channel.on(None, ref, callback2) + await channel.send("ping", "this is my ping payload", ref) + + # we give it some time to complete + await asyncio.sleep(10) + + # proper shut down + listen_task.cancel() + +if __name__ == '__main__': + loop = asyncio.get_event_loop() + try: + loop.run_until_complete(main()) + + except KeyboardInterrupt: + loop.stop() + exit(0) \ No newline at end of file diff --git a/usage.py b/usage.py index 56339ee..5085e95 100644 --- a/usage.py +++ b/usage.py @@ -1,23 +1,42 @@ from realtime.connection import Socket +import asyncio +async def callback1(payload): + print(f"Got message: {payload}") -def callback1(payload): - print("Callback 1: ", payload) +async def main(): + # your phoenix server token + TOKEN = "" + # your phoenix server URL + URL = f"ws://127.0.0.1:4000/socket/websocket?token={TOKEN}&vsn=2.0.0" -def callback2(payload): - print("Callback 2: ", payload) + client = Socket(URL) + # connect to the server + await client.connect() -if __name__ == "__main__": - URL = "ws://localhost:4000/socket/websocket" - s = Socket(URL) - s.connect() + # fire and forget the listening routine + listen_task = asyncio.ensure_future(client.listen()) - channel_1 = s.set_channel("realtime:public:todos") - channel_1.join().on("UPDATE", callback1) + # join the channel + channel = client.set_channel("this:is:my:topic") + await channel.join() - channel_2 = s.set_channel("realtime:public:users") - channel_2.join().on("*", callback2) + # by using a partial function + channel.on("your_event_name", None, callback1) - s.listen() + # we give it some time to complete + await asyncio.sleep(10) + + # proper shut down + listen_task.cancel() + +if __name__ == '__main__': + loop = asyncio.get_event_loop() + try: + loop.run_until_complete(main()) + + except KeyboardInterrupt: + loop.stop() + exit(0) \ No newline at end of file