Skip to content

Commit

Permalink
Merge pull request #24 from TogetherCrew/feature/async
Browse files Browse the repository at this point in the history
Add new tc-broker-lib's version
  • Loading branch information
cyri113 authored Jan 1, 2024
2 parents d3fad4a + b3f8212 commit 1714cc6
Show file tree
Hide file tree
Showing 3 changed files with 125 additions and 131 deletions.
2 changes: 1 addition & 1 deletion packages/ml/pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,7 @@ dependencies = [
"backoff<3.0.0,>=2.2.1",
"aiofiles>=23.2.1",
"aio-pika>=9.3.0",
"tc-messagebroker>=1.6.0",
"tc-messageBroker>=1.6.3",
"faiss-cpu>=1.7.4",
]
name = "QABot"
Expand Down
136 changes: 74 additions & 62 deletions packages/ml/src/api.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,24 +2,23 @@
import os
import traceback
from asyncio import Task
from typing import ClassVar

from logger.hivemind_logger import logger
from server.async_broker import AsyncBroker
from server.broker import EventBroker
from utils.constants import HIVEMIND_API_PORT
from utils.util import configure_logging

import aiormq
from tc_messageBroker.rabbit_mq.event import Event
from tc_messageBroker.rabbit_mq.queue import Queue

os.environ["CUDA_VISIBLE_DEVICES"] = "-1"

if __name__ == "__main__":
import uvicorn

uvicorn.run("api:app", host="0.0.0.0",
port=int(HIVEMIND_API_PORT),
reload=False)
uvicorn.run("api:app", host="0.0.0.0", port=int(HIVEMIND_API_PORT), reload=False)
else:

import asyncio
from typing import AsyncGenerator
from asgi_correlation_id import CorrelationIdMiddleware
Expand All @@ -37,8 +36,10 @@

from asgi_correlation_id import correlation_id

ERROR_MESSAGE = f"An error occurred when trying to answer your question. An error report has been sent to " \
f"the developers."
ERROR_MESSAGE = (
f"An error occurred when trying to answer your question. An error report has been sent to "
f"the developers."
)

app = FastAPI()

Expand All @@ -51,21 +52,21 @@
allow_credentials=True,
allow_methods=["*"],
allow_headers=["*"],
expose_headers=['X-Request-ID'],
expose_headers=["X-Request-ID"],
)
#
app.add_middleware(CorrelationIdMiddleware)


@app.exception_handler(Exception)
async def unhandled_exception_handler(request: Request, exc: Exception) -> Response:
return await http_exception_handler(
request,
HTTPException(
status_code=500,
detail=f'Internal server error: {exc}',
headers={'X-Request-ID': correlation_id.get() or ""}
))
detail=f"Internal server error: {exc}",
headers={"X-Request-ID": correlation_id.get() or ""},
),
)

# create / route for health check
@app.get("/")
Expand All @@ -74,43 +75,49 @@ async def root():

class AsyncResponse(BaseModel):
# background_tasks: BackgroundTasks
callback_handler = AsyncChunkIteratorCallbackHandler()
total_tokens = 0
callback_handler: ClassVar = AsyncChunkIteratorCallbackHandler()
total_tokens: int = 0

class Config:
arbitrary_types_allowed = True

@staticmethod
def get_agent():
from main import load

return load()

async def generate_response(self, request: Request, question: str) -> AsyncGenerator[str, None]:
async def generate_response(
self, request: Request, question: str
) -> AsyncGenerator[str, None]:
run: asyncio.Task | None = None

session = ''
session = ""
try:
session = f"{request.headers['x-request-id']}"
logger.debug(f"session: {session}")

agent = AsyncResponse.get_agent()
run = asyncio.create_task(agent.run(question, self.callback_handler))

logger.info('Running...')
logger.info("Running...")

async for response in self.callback_handler.aiter():
# check token type
logger.info(response.__dict__)
if isinstance(response, TextChunk):
res_token = f'{response.token}\n\n'

# await a_publish(Queue.DISCORD_BOT, Event.DISCORD_BOT.SEND_MESSAGE,
# await eb.a_publish(Queue.DISCORD_BOT, "SEND_MESSAGE",
# {
# "uuid": f"s-{session}",
# "question": question,
# "streaming": res_token,
# })
res_token = f"{response.token}\n\n"

await broker.publish(
Queue.DISCORD_BOT,
Event.DISCORD_BOT.SEND_MESSAGE,
{
"uuid": f"s-{session}",
"question": question,
"streaming": res_token,
},
)

yield res_token
elif isinstance(response, InfoChunk):
self.total_tokens += response.count_tokens
Expand All @@ -123,15 +130,20 @@ async def generate_response(self, request: Request, question: str) -> AsyncGener
except BaseException as e: # asyncio.CancelledError
logger.error(f"session:{session} Caught BaseException: {str(e)}")
print(traceback.format_exc())
logger.exception('Something got wrong')
logger.exception("Something got wrong")
yield ERROR_MESSAGE
finally:
logger.info(f'Total tokens used: {self.total_tokens}')
# await eb.a_publish(Queue.DISCORD_BOT, "SEND_MESSAGE",
# {"uuid": f"s-{session}",
# "question": question,
# "user": user.json(),
# "total_tokens": self.total_tokens})
logger.info(f"Total tokens used: {self.total_tokens}")

await broker.publish(
Queue.DISCORD_BOT,
Event.DISCORD_BOT.SEND_MESSAGE,
{
"uuid": f"s-{session}",
"question": question,
"total_tokens": self.total_tokens,
},
)

# yield ERROR_MESSAGE

Expand All @@ -151,61 +163,61 @@ async def streamer(gen: AsyncGenerator[str, None]):
class Ask(BaseModel):
question: str


@app.post("/ask/")
async def ask(request: Request, body: Ask) -> StreamingResponse:

logger.info(f"Received question:{body.question}")
session = f"{request.headers['x-request-id']}"
logger.debug(f'session: {session}')
logger.debug(f"session: {session}")

# await eb.a_publish(Queue.DISCORD_BOT, "SEND_MESSAGE",
# {
# "uuid": f"s-{session}",
# "question": body.question,
# "user": current_user.json(),
# })
await broker.publish(
Queue.DISCORD_BOT,
Event.DISCORD_BOT.SEND_MESSAGE,
{
"uuid": f"s-{session}",
"question": body.question,
},
)

ar = AsyncResponse()
return StreamingResponse(AsyncResponse.streamer(ar.generate_response(request, body.question)))

return StreamingResponse(
AsyncResponse.streamer(ar.generate_response(request, body.question))
)

def log_event(msg: str, queue_name: str, event_name: str):
logger.info(f"{queue_name}->{event_name}::{msg}")


######################################
### STARTUP GLOBALS VARIABLES HERE ###
######################################

ab = AsyncBroker()
broker = EventBroker()
hivemind_task: Task | None = None


@app.on_event("startup")
async def startup():
configure_logging()

try:
await ab.connect()
loop = asyncio.get_event_loop()
loop.set_debug(True)
# logger.info(loop)

# hivemind_task = asyncio.get_event_loop().create_task(ab.listen(queue_name=Queue.HIVEMIND,
# event_name=Event.HIVEMIND.GUILD_MESSAGES_UPDATED,
# callback=log_event
# ))
except aiormq.exceptions.AMQPConnectionError as amqp:
logger.error(amqp)
hivemind_task = asyncio.get_event_loop().create_task(
broker.listen(
queue=Queue.HIVEMIND,
event=Event.HIVEMIND.GUILD_MESSAGES_UPDATED,
callback=log_event,
)
)

print("Server Startup!")
await hivemind_task

except Exception as error:
logger.error(error)

print("Server Startup!")

@app.on_event("shutdown")
async def shutdown():
if hivemind_task:
hivemind_task.cancel()
if ab and ab.connection:
await ab.connection.close()

if broker:
await broker.close()
print("Server Shutdown!")
Loading

0 comments on commit 1714cc6

Please sign in to comment.