diff --git a/ai_spider/app.py b/ai_spider/app.py
index 10ee639..767f8d8 100644
--- a/ai_spider/app.py
+++ b/ai_spider/app.py
@@ -8,20 +8,22 @@
 from asyncio import Queue
 from json import JSONDecodeError
 from typing import Optional, cast, AsyncIterator
+from tempfile import NamedTemporaryFile
 
 import fastapi
 import starlette.websockets
 import websockets
 from dotenv import load_dotenv
-from fastapi import FastAPI, WebSocket, Request, HTTPException, Depends
+from fastapi import FastAPI, WebSocket, Request, HTTPException, Depends, Form
 from fastapi.exceptions import RequestValidationError
+from fastapi.responses import FileResponse
 from fastapi.responses import JSONResponse
 from starlette.exceptions import HTTPException as StarletteHTTPException
 from starlette.requests import HTTPConnection
 
 from .db import init_db_store
 from .openai_types import CompletionChunk, ChatCompletion, CreateChatCompletionRequest, EmbeddingRequest, Embedding, \
-    ImageGenerationRequest, ImageGenerationResponse
+    ImageGenerationRequest, ImageGenerationResponse, AudioTranscriptionRequest
 
 from fastapi.middleware.cors import CORSMiddleware
 
@@ -39,6 +41,7 @@
 log = logging.getLogger(__name__)
 
 load_dotenv()
+from fastapi import FastAPI, File, UploadFile
 
 SECRET_KEY = os.environ["SECRET_KEY"]
 APP_NAME = os.environ.get("APP_NAME", "GPUTopia QueenBee")
@@ -168,6 +171,12 @@ async def worker_stats() -> dict:
     return mgr.worker_stats()
 
 
+@app.get("/worker/extended", tags=["worker"])
+async def worker_detail_extended() -> dict:
+    """List of all workers, with detailed info"""
+    mgr = get_reg_mgr()
+    return mgr.worker_detail()
+
 @app.get("/worker/detail", tags=["worker"])
 async def worker_detail(query: Optional[str] = None, user_id: str = Depends(optional_bearer_token)) -> list:
     """List of all workers, with anonymized info"""
@@ -178,6 +187,54 @@ async def worker_detail(query: Optional[str] = None, user_id: str = Depends(opti
         return []
     return mgr.worker_anon_detail(query=query)
 
+async def save_temp(file):
+    """Save an uploaded file to a temp location and return the path.
+       The temp file must be deleted once a worker has successfully read it.
+ """ + with NamedTemporaryFile("wb", delete=False) as fp: + fp.write(await file.read()) + fp.close() + return fp.name + +@app.get("/storage/") +def get_temp_file(filename: str): + """Return a file from a temp location""" + # TODO: maybe we need pass only the filename or a key and not the full path + return FileResponse(filename, media_type="application/octet-stream", filename=os.path.basename(filename)) + +@app.post("/v1/audio/transcriptions") +async def post_audio_transcription( + request: Request, + model: str = Form(...), + file: UploadFile = File(...), +) -> dict: + await check_creds_and_funds(request) + tmp_file = await save_temp(file) + body = AudioTranscriptionRequest(model=model, file=tmp_file) + worker_type = worker_type_from_model_name(body.model) + msize = get_model_size(body.model) + mgr = get_reg_mgr() + gpu_filter = body.gpu_filter + gpu_filter["capabilities"] = ["whisper"] + try: + with mgr.get_socket_for_inference(msize, worker_type, gpu_filter) as ws: + js, job_time = await single_response_model_job("/v1/audio/transcriptions", body.model_dump(), ws) + schedule_task(check_bill_usage(request, msize, js, ws.info, job_time)) + return js + except HTTPException as ex: + log.error("transcription failed : %s", repr(ex)) + raise + except TimeoutError as ex: + log.error("transcription failed : %s", repr(ex)) + raise HTTPException(408, detail=repr(ex)) + except AssertionError as ex: + log.error("transcription failed : %s", repr(ex)) + raise HTTPException(400, detail=repr(ex)) + except Exception as ex: + log.exception("unknown error : %s", repr(ex)) + raise HTTPException(500, detail=repr(ex)) + finally: + os.unlink(tmp_file) @app.post("/v1/chat/completions") async def create_chat_completion( @@ -202,7 +259,9 @@ async def create_chat_completion( raise log.error("try again: %s: ", repr(ex)) await asyncio.sleep(0.25) - with mgr.get_socket_for_inference(msize, worker_type, gpu_filter) as ws: + if isinstance(ex, HTTPException) and "socked dropped" in ex.detail: + mgr.drop_worker(ws) + with mgr.get_socket_for_inference(msize, worker_type, gpu_filter, avoid=[ws]) as ws: return await do_inference(request, body, ws, final=True) except HTTPException as ex: log.error("inference failed : %s", repr(ex)) @@ -532,6 +591,7 @@ async def worker_connect(websocket: WebSocket): # get the source ip, for long-term punishment of bad actors req = HTTPConnection(websocket.scope) websocket.info["ip"] = get_ip(req) + js["ip"] = websocket.info["ip"] websocket.queue = Queue() websocket.results = Queue() diff --git a/ai_spider/openai_types.py b/ai_spider/openai_types.py index b9a518e..7adb8c3 100644 --- a/ai_spider/openai_types.py +++ b/ai_spider/openai_types.py @@ -193,6 +193,11 @@ class ChatCompletionRequestMessage(BaseModel): default=0.1, ge=0.001, le=1.0, description="Mirostat learning rate" ) +class AudioTranscriptionRequest(BaseModel): + model: str + file: str + gpu_filter: dict = {} + timeout: int = 60 * 5 class CreateChatCompletionRequest(BaseModel): messages: List[ChatCompletionRequestMessage] = Field( diff --git a/ai_spider/workers.py b/ai_spider/workers.py index 2fecd2c..ee19e86 100644 --- a/ai_spider/workers.py +++ b/ai_spider/workers.py @@ -29,6 +29,8 @@ def anon_info(ent, **fin): info = ent.info fin["worker_version"] = info.get("worker_version") fin["capabilities"] = info.get("capabilities", []) + fin["balena_device_name"] = info.get("balena_device_name", "") + fin["ip"] = info.get("ip") nv_gpu_cnt = sum([1 for _ in info.get("nv_gpus", [])]) cl_gpu_cnt = sum([1 for _ in info.get("cl_gpus", [])]) 
     web_gpu_cnt = sum([1 for _ in info.get("web_gpus", [])])
@@ -56,7 +58,7 @@ def drop_worker(self, sock):
         self.busy.pop(sock, None)
 
     @contextlib.contextmanager
-    def get_socket_for_inference(self, msize: int, worker_type: WORKER_TYPES, gpu_filter={}) \
+    def get_socket_for_inference(self, msize: int, worker_type: WORKER_TYPES, gpu_filter={}, avoid=[]) \
             -> Generator[QueueSocket, None, None]:
         # msize is params adjusted by quant level with a heuristic
 
@@ -76,11 +78,15 @@ def get_socket_for_inference(self, msize: int, worker_type: WORKER_TYPES, gpu_fi
         good = []
         close = []
         for sock, info in self.socks.items():
+            if sock in avoid:
+                continue
            cpu_vram = info.get("vram", 0)
            disk_space = info.get("disk_space", 0)
            nv_gpu_ram = sum([el.get("memory", 0) for el in info.get("nv_gpus", [])])
            cl_gpu_ram = sum([el.get("memory", 0) for el in info.get("cl_gpus", [])])
            have_web_gpus = is_web_worker(info)
+            if not nv_gpu_ram and not cl_gpu_ram and not have_web_gpus:
+                continue
 
             if ver := gpu_filter.get("min_version"):
                 try:
@@ -118,6 +124,7 @@
                     good.append(sock)
 
             if worker_type in ("any", "cli"):
+                good.append(sock)
                 if gpu_needed < nv_gpu_ram and cpu_needed < cpu_vram and disk_needed < disk_space:
                     good.append(sock)
                 elif gpu_needed < cl_gpu_ram and cpu_needed < cpu_vram and disk_needed < disk_space: