diff --git a/packages/ml/pyproject.toml b/packages/ml/pyproject.toml index 3428a4d..551a2d6 100644 --- a/packages/ml/pyproject.toml +++ b/packages/ml/pyproject.toml @@ -47,7 +47,7 @@ dependencies = [ "backoff<3.0.0,>=2.2.1", "aiofiles>=23.2.1", "aio-pika>=9.3.0", - "tc-messagebroker>=1.6.0", + "tc-messageBroker>=1.6.3", "faiss-cpu>=1.7.4", ] name = "QABot" diff --git a/packages/ml/src/api.py b/packages/ml/src/api.py index eb460b7..225c8d5 100644 --- a/packages/ml/src/api.py +++ b/packages/ml/src/api.py @@ -2,24 +2,23 @@ import os import traceback from asyncio import Task +from typing import ClassVar from logger.hivemind_logger import logger -from server.async_broker import AsyncBroker +from server.broker import EventBroker from utils.constants import HIVEMIND_API_PORT from utils.util import configure_logging -import aiormq +from tc_messageBroker.rabbit_mq.event import Event +from tc_messageBroker.rabbit_mq.queue import Queue os.environ["CUDA_VISIBLE_DEVICES"] = "-1" if __name__ == "__main__": import uvicorn - uvicorn.run("api:app", host="0.0.0.0", - port=int(HIVEMIND_API_PORT), - reload=False) + uvicorn.run("api:app", host="0.0.0.0", port=int(HIVEMIND_API_PORT), reload=False) else: - import asyncio from typing import AsyncGenerator from asgi_correlation_id import CorrelationIdMiddleware @@ -37,8 +36,10 @@ from asgi_correlation_id import correlation_id - ERROR_MESSAGE = f"An error occurred when trying to answer your question. An error report has been sent to " \ - f"the developers." + ERROR_MESSAGE = ( + f"An error occurred when trying to answer your question. An error report has been sent to " + f"the developers." + ) app = FastAPI() @@ -51,21 +52,21 @@ allow_credentials=True, allow_methods=["*"], allow_headers=["*"], - expose_headers=['X-Request-ID'], + expose_headers=["X-Request-ID"], ) # app.add_middleware(CorrelationIdMiddleware) - @app.exception_handler(Exception) async def unhandled_exception_handler(request: Request, exc: Exception) -> Response: return await http_exception_handler( request, HTTPException( status_code=500, - detail=f'Internal server error: {exc}', - headers={'X-Request-ID': correlation_id.get() or ""} - )) + detail=f"Internal server error: {exc}", + headers={"X-Request-ID": correlation_id.get() or ""}, + ), + ) # create / route for health check @app.get("/") @@ -74,8 +75,8 @@ async def root(): class AsyncResponse(BaseModel): # background_tasks: BackgroundTasks - callback_handler = AsyncChunkIteratorCallbackHandler() - total_tokens = 0 + callback_handler: ClassVar = AsyncChunkIteratorCallbackHandler() + total_tokens: int = 0 class Config: arbitrary_types_allowed = True @@ -83,12 +84,15 @@ class Config: @staticmethod def get_agent(): from main import load + return load() - async def generate_response(self, request: Request, question: str) -> AsyncGenerator[str, None]: + async def generate_response( + self, request: Request, question: str + ) -> AsyncGenerator[str, None]: run: asyncio.Task | None = None - session = '' + session = "" try: session = f"{request.headers['x-request-id']}" logger.debug(f"session: {session}") @@ -96,21 +100,24 @@ async def generate_response(self, request: Request, question: str) -> AsyncGener agent = AsyncResponse.get_agent() run = asyncio.create_task(agent.run(question, self.callback_handler)) - logger.info('Running...') + logger.info("Running...") async for response in self.callback_handler.aiter(): # check token type logger.info(response.__dict__) if isinstance(response, TextChunk): - res_token = f'{response.token}\n\n' - - # await a_publish(Queue.DISCORD_BOT, Event.DISCORD_BOT.SEND_MESSAGE, - # await eb.a_publish(Queue.DISCORD_BOT, "SEND_MESSAGE", - # { - # "uuid": f"s-{session}", - # "question": question, - # "streaming": res_token, - # }) + res_token = f"{response.token}\n\n" + + await broker.publish( + Queue.DISCORD_BOT, + Event.DISCORD_BOT.SEND_MESSAGE, + { + "uuid": f"s-{session}", + "question": question, + "streaming": res_token, + }, + ) + yield res_token elif isinstance(response, InfoChunk): self.total_tokens += response.count_tokens @@ -123,15 +130,20 @@ async def generate_response(self, request: Request, question: str) -> AsyncGener except BaseException as e: # asyncio.CancelledError logger.error(f"session:{session} Caught BaseException: {str(e)}") print(traceback.format_exc()) - logger.exception('Something got wrong') + logger.exception("Something got wrong") yield ERROR_MESSAGE finally: - logger.info(f'Total tokens used: {self.total_tokens}') - # await eb.a_publish(Queue.DISCORD_BOT, "SEND_MESSAGE", - # {"uuid": f"s-{session}", - # "question": question, - # "user": user.json(), - # "total_tokens": self.total_tokens}) + logger.info(f"Total tokens used: {self.total_tokens}") + + await broker.publish( + Queue.DISCORD_BOT, + Event.DISCORD_BOT.SEND_MESSAGE, + { + "uuid": f"s-{session}", + "question": question, + "total_tokens": self.total_tokens, + }, + ) # yield ERROR_MESSAGE @@ -151,61 +163,61 @@ async def streamer(gen: AsyncGenerator[str, None]): class Ask(BaseModel): question: str - @app.post("/ask/") async def ask(request: Request, body: Ask) -> StreamingResponse: - logger.info(f"Received question:{body.question}") session = f"{request.headers['x-request-id']}" - logger.debug(f'session: {session}') + logger.debug(f"session: {session}") - # await eb.a_publish(Queue.DISCORD_BOT, "SEND_MESSAGE", - # { - # "uuid": f"s-{session}", - # "question": body.question, - # "user": current_user.json(), - # }) + await broker.publish( + Queue.DISCORD_BOT, + Event.DISCORD_BOT.SEND_MESSAGE, + { + "uuid": f"s-{session}", + "question": body.question, + }, + ) ar = AsyncResponse() - return StreamingResponse(AsyncResponse.streamer(ar.generate_response(request, body.question))) - + return StreamingResponse( + AsyncResponse.streamer(ar.generate_response(request, body.question)) + ) def log_event(msg: str, queue_name: str, event_name: str): logger.info(f"{queue_name}->{event_name}::{msg}") - ###################################### ### STARTUP GLOBALS VARIABLES HERE ### ###################################### - ab = AsyncBroker() + broker = EventBroker() hivemind_task: Task | None = None - @app.on_event("startup") async def startup(): configure_logging() try: - await ab.connect() - loop = asyncio.get_event_loop() - loop.set_debug(True) - # logger.info(loop) - - # hivemind_task = asyncio.get_event_loop().create_task(ab.listen(queue_name=Queue.HIVEMIND, - # event_name=Event.HIVEMIND.GUILD_MESSAGES_UPDATED, - # callback=log_event - # )) - except aiormq.exceptions.AMQPConnectionError as amqp: - logger.error(amqp) + hivemind_task = asyncio.get_event_loop().create_task( + broker.listen( + queue=Queue.HIVEMIND, + event=Event.HIVEMIND.GUILD_MESSAGES_UPDATED, + callback=log_event, + ) + ) - print("Server Startup!") + await hivemind_task + except Exception as error: + logger.error(error) + + print("Server Startup!") @app.on_event("shutdown") async def shutdown(): if hivemind_task: hivemind_task.cancel() - if ab and ab.connection: - await ab.connection.close() + + if broker: + await broker.close() print("Server Shutdown!") diff --git a/packages/ml/src/server/broker.py b/packages/ml/src/server/broker.py index 38f4530..223b8a0 100644 --- a/packages/ml/src/server/broker.py +++ b/packages/ml/src/server/broker.py @@ -1,12 +1,11 @@ import asyncio import logging -import os import threading -from enum import Enum -from typing import Callable, Any, Type +from typing import Callable, Any, Type, Union from tc_messageBroker import RabbitMQ from tc_messageBroker.rabbit_mq.event import Event +from tc_messageBroker.rabbit_mq.queue import Queue import tc_messageBroker.rabbit_mq.event.events_microservice as Events from logger.hivemind_logger import logger @@ -37,22 +36,38 @@ class EventBroker: rabbit_mq: RabbitMQ def __init__(self): - self.rabbit_mq = None self.broker_url = constants.RABBITMQ_HOST self.port = constants.RABBITMQ_PORT self.username = constants.RABBITMQ_USER self.password = constants.RABBITMQ_PASS logger.info(f"__init__ broker_url: {self.broker_url}:{self.port}") - self.connect() + self.rmq_consume = RabbitMQ( + broker_url=self.broker_url, + port=self.port, + username=self.username, + password=self.password, + ) + + self.rmq_publish = RabbitMQ( + broker_url=self.broker_url, + port=self.port, + username=self.username, + password=self.password, + ) @staticmethod def get_queue_by_event(string_to_find: str): member_mapping = {} - for cls in (Event, Events.BotBaseEvent, Events.AnalyzerBaseEvent, - Events.ServerEvent, Events.DiscordBotEvent, - # Events.DiscordAnalyzerEvent, Events.HivemindEvent): # if there not UNIQUE events across all classes that will get the last one - Events.DiscordAnalyzerEvent): # if there not UNIQUE events across all classes that will get the last one + for cls in ( + Event, + Events.BotBaseEvent, + Events.AnalyzerBaseEvent, + Events.ServerEvent, + Events.DiscordBotEvent, + # Events.DiscordAnalyzerEvent, Events.HivemindEvent): # if there not UNIQUE events across all classes that will get the last one + Events.DiscordAnalyzerEvent, + ): # if there not UNIQUE events across all classes that will get the last one for name, value in cls.__dict__.items(): member_mapping[name] = cls # print(member_mapping) @@ -71,86 +86,53 @@ def get_queue_by_event(string_to_find: str): return _queue_found - def connect(self) -> RabbitMQ: - if self.rabbit_mq is None: - logger.info(f"broker_url: {self.broker_url}:{self.port}") - - self.rabbit_mq = RabbitMQ( - broker_url=self.broker_url, port=self.port, username=self.username, password=self.password - ) - - return self.rabbit_mq - - async def a_listen(self, queue: str, event: str, callback: Callable): - asyncio.get_event_loop().run_in_executor(None, self.listen, queue, event, callback) - - def listen(self, queue: str, event: str, callback: Callable): - logger.debug("listening %s", queue) - - self.rabbit_mq.on_event(event, callback) - - # print("Waiting for messages...") - self.rabbit_mq.connect(queue) - print(f"Connected to {queue} queue!") - - self.rabbit_mq.consume(queue) - print("consume messages...") - if self.rabbit_mq.channel is not None: - print("listening messages...") - try: - self.rabbit_mq.channel.start_consuming() - print("Never reach here!") - except KeyboardInterrupt: - self.rabbit_mq.channel.stop_consuming() - print("Disconnected from broker successfully!") - else: - print("Connection to broker was not successful!") - - def add_event(self, event: str, callback: Callable): - self.rabbit_mq.on_event(event, callback) - - def t_listen(self, queue: str): + async def listen(self, queue: str, event: str, callback: Callable): logger.debug("listening %s", queue) - # print("Waiting for messages...") + async def job(queue: str, event: str, callback: Callable): + await self.rmq_consume.on_event_async(event, callback) - def consume_messages(): - self.rabbit_mq.connect(queue) + # print("Waiting for messages...") + await self.rmq_consume.connect_async(queue) print(f"Connected to {queue} queue!") - self.rabbit_mq.consume(queue) + + await self.rmq_consume.consume_async(queue) print("consume messages...") - if self.rabbit_mq.channel is not None: + if self.rmq_consume.channel is not None: print("listening messages...") try: - self.rabbit_mq.channel.start_consuming() + self.rmq_consume.channel.start_consuming() print("Never reach here!") except KeyboardInterrupt: - self.rabbit_mq.channel.stop_consuming() + self.rmq_consume.channel.stop_consuming() print("Disconnected from broker successfully!") else: print("Connection to broker was not successful!") - # Create a separate thread to run the consume_messages function - consume_thread = threading.Thread(target=consume_messages) - consume_thread.start() + threading.Thread( + target=asyncio.run, args=(job(queue, event, callback),) + ).start() - async def a_publish(self, queue: str, event: str, content: dict[str, Any] | None): - asyncio.get_event_loop().run_in_executor(None, self.publish, queue, event, content) + async def add_event(self, event: str, callback: Callable): + await self.rmq_consume.on_event_async(event, callback) - def publish(self, queue: str, event: str, content: dict[str, Any] | None): + async def publish(self, queue: str, event: str, content: dict[str, Any] | None): logger.debug("publishing %s", content) - self.rabbit_mq = RabbitMQ( - broker_url=self.broker_url, port=self.port, username=self.username, password=self.password - ) - - self.rabbit_mq.connect(queue) + await self.rmq_publish.connect_async(queue) if content is None: - content = {"uuid": "d99a1490-fba6-11ed-b9a9-0d29e7612dp8", "data": "some results"} + content = { + "uuid": "d99a1490-fba6-11ed-b9a9-0d29e7612dp8", + "data": "some results", + } - self.rabbit_mq.publish( + await self.rmq_publish.publish_async( queue, event=event, content=content, ) + + async def close(self): + self.rmq_consume.connection.close() + self.rmq_publish.connection.close()