66 changes: 63 additions & 3 deletions ai_spider/app.py
@@ -8,20 +8,22 @@
from asyncio import Queue
from json import JSONDecodeError
from typing import Optional, cast, AsyncIterator
from tempfile import NamedTemporaryFile

import fastapi
import starlette.websockets
import websockets
from dotenv import load_dotenv
from fastapi import FastAPI, WebSocket, Request, HTTPException, Depends
from fastapi import FastAPI, WebSocket, Request, HTTPException, Depends, Form
from fastapi.exceptions import RequestValidationError
from fastapi.responses import FileResponse
from fastapi.responses import JSONResponse
from starlette.exceptions import HTTPException as StarletteHTTPException
from starlette.requests import HTTPConnection

from .db import init_db_store
from .openai_types import CompletionChunk, ChatCompletion, CreateChatCompletionRequest, EmbeddingRequest, Embedding, \
ImageGenerationRequest, ImageGenerationResponse
ImageGenerationRequest, ImageGenerationResponse, AudioTranscriptionRequest

from fastapi.middleware.cors import CORSMiddleware

@@ -39,6 +41,7 @@
log = logging.getLogger(__name__)

load_dotenv()
from fastapi import FastAPI, File, UploadFile

SECRET_KEY = os.environ["SECRET_KEY"]
APP_NAME = os.environ.get("APP_NAME", "GPUTopia QueenBee")
@@ -168,6 +171,12 @@ async def worker_stats() -> dict:
return mgr.worker_stats()


@app.get("/worker/extended", tags=["worker"])
async def worker_detail_extended() -> dict:
"""List of all workers, with detailed info"""
mgr = get_reg_mgr()
return mgr.worker_detail()

@app.get("/worker/detail", tags=["worker"])
async def worker_detail(query: Optional[str] = None, user_id: str = Depends(optional_bearer_token)) -> list:
"""List of all workers, with anonymized info"""
@@ -178,6 +187,54 @@ async def worker_detail(query: Optional[str] = None, user_id: str = Depends(opti
return []
return mgr.worker_anon_detail(query=query)

async def save_temp(file):
    """Save the upload to a temp file and return its path.

    The caller must delete the file once a worker has read it successfully.
    """
    with NamedTemporaryFile("wb", delete=False) as fp:
        fp.write(await file.read())
    return fp.name

@app.get("/storage/")
def get_temp_file(filename: str):
"""Return a file from a temp location"""
    # TODO: maybe we should pass only the filename or a key, not the full path
return FileResponse(filename, media_type="application/octet-stream", filename=os.path.basename(filename))

@app.post("/v1/audio/transcriptions")
async def post_audio_transcription(
request: Request,
model: str = Form(...),
file: UploadFile = File(...),
) -> dict:
await check_creds_and_funds(request)
tmp_file = await save_temp(file)
body = AudioTranscriptionRequest(model=model, file=tmp_file)
worker_type = worker_type_from_model_name(body.model)
msize = get_model_size(body.model)
mgr = get_reg_mgr()
gpu_filter = body.gpu_filter
gpu_filter["capabilities"] = ["whisper"]
try:
with mgr.get_socket_for_inference(msize, worker_type, gpu_filter) as ws:
js, job_time = await single_response_model_job("/v1/audio/transcriptions", body.model_dump(), ws)
schedule_task(check_bill_usage(request, msize, js, ws.info, job_time))
return js
except HTTPException as ex:
log.error("transcription failed : %s", repr(ex))
raise
except TimeoutError as ex:
log.error("transcription failed : %s", repr(ex))
raise HTTPException(408, detail=repr(ex))
except AssertionError as ex:
log.error("transcription failed : %s", repr(ex))
raise HTTPException(400, detail=repr(ex))
except Exception as ex:
log.exception("unknown error : %s", repr(ex))
raise HTTPException(500, detail=repr(ex))
finally:
os.unlink(tmp_file)

@app.post("/v1/chat/completions")
async def create_chat_completion(
@@ -202,7 +259,9 @@ async def create_chat_completion(
raise
log.error("try again: %s: ", repr(ex))
await asyncio.sleep(0.25)
with mgr.get_socket_for_inference(msize, worker_type, gpu_filter) as ws:
if isinstance(ex, HTTPException) and "socked dropped" in ex.detail:
mgr.drop_worker(ws)
with mgr.get_socket_for_inference(msize, worker_type, gpu_filter, avoid=[ws]) as ws:
return await do_inference(request, body, ws, final=True)
except HTTPException as ex:
log.error("inference failed : %s", repr(ex))
@@ -532,6 +591,7 @@ async def worker_connect(websocket: WebSocket):
# get the source ip, for long-term punishment of bad actors
req = HTTPConnection(websocket.scope)
websocket.info["ip"] = get_ip(req)
js["ip"] = websocket.info["ip"]

websocket.queue = Queue()
websocket.results = Queue()
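The new transcription route accepts a multipart form with a model field and an uploaded file, stages the audio in a temp file via save_temp, and dispatches the job to a worker advertising the "whisper" capability; the temp file is unlinked in the finally block once the job returns. Below is a minimal client sketch, assuming the service runs at http://localhost:8000, that a bearer token satisfies check_creds_and_funds, and an illustrative model name; none of these values come from this diff.

import requests

BASE_URL = "http://localhost:8000"   # assumption: local queen-bee instance
TOKEN = "sk-example"                 # assumption: placeholder bearer token

with open("sample.wav", "rb") as audio:
    resp = requests.post(
        f"{BASE_URL}/v1/audio/transcriptions",
        headers={"Authorization": f"Bearer {TOKEN}"},
        data={"model": "whisper-1"},            # form field, per Form(...)
        files={"file": ("sample.wav", audio)},  # upload, per File(...)
        timeout=300,                            # mirrors the request's 5-minute default
    )
resp.raise_for_status()
print(resp.json())

Workers can fetch the staged audio back through GET /storage/?filename=<path>; as the TODO above notes, exposing the full temp path there is worth revisiting.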
5 changes: 5 additions & 0 deletions ai_spider/openai_types.py
@@ -193,6 +193,11 @@ class ChatCompletionRequestMessage(BaseModel):
default=0.1, ge=0.001, le=1.0, description="Mirostat learning rate"
)

class AudioTranscriptionRequest(BaseModel):
model: str
file: str
gpu_filter: dict = {}
timeout: int = 60 * 5

class CreateChatCompletionRequest(BaseModel):
messages: List[ChatCompletionRequestMessage] = Field(
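AudioTranscriptionRequest is the payload the queen bee forwards to the worker (app.py serializes it with body.model_dump() before sending it over the websocket). A small sketch of constructing and dumping it, assuming pydantic v2 semantics for model_dump and using an illustrative model name and path:

from ai_spider.openai_types import AudioTranscriptionRequest

req = AudioTranscriptionRequest(
    model="whisper-1",                         # illustrative model name
    file="/tmp/tmpabc123",                     # path as returned by save_temp()
    gpu_filter={"capabilities": ["whisper"]},  # set by the endpoint before dispatch
)
print(req.model_dump())
# {'model': 'whisper-1', 'file': '/tmp/tmpabc123',
#  'gpu_filter': {'capabilities': ['whisper']}, 'timeout': 300}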
9 changes: 8 additions & 1 deletion ai_spider/workers.py
@@ -29,6 +29,8 @@ def anon_info(ent, **fin):
info = ent.info
fin["worker_version"] = info.get("worker_version")
fin["capabilities"] = info.get("capabilities", [])
fin["balena_device_name"] = info.get("balena_device_name", "")
fin["ip"] = info.get("ip")
nv_gpu_cnt = sum([1 for _ in info.get("nv_gpus", [])])
cl_gpu_cnt = sum([1 for _ in info.get("cl_gpus", [])])
web_gpu_cnt = sum([1 for _ in info.get("web_gpus", [])])
@@ -56,7 +58,7 @@ def drop_worker(self, sock):
self.busy.pop(sock, None)

@contextlib.contextmanager
def get_socket_for_inference(self, msize: int, worker_type: WORKER_TYPES, gpu_filter={}) \
def get_socket_for_inference(self, msize: int, worker_type: WORKER_TYPES, gpu_filter={}, avoid=[]) \
-> Generator[QueueSocket, None, None]:
# msize is params adjusted by quant level with a heuristic

@@ -76,11 +78,15 @@ def get_socket_for_inference(self, msize: int, worker_type: WORKER_TYPES, gpu_fi
good = []
close = []
for sock, info in self.socks.items():
if sock in avoid:
continue
cpu_vram = info.get("vram", 0)
disk_space = info.get("disk_space", 0)
nv_gpu_ram = sum([el.get("memory", 0) for el in info.get("nv_gpus", [])])
cl_gpu_ram = sum([el.get("memory", 0) for el in info.get("cl_gpus", [])])
have_web_gpus = is_web_worker(info)
if not nv_gpu_ram and not cl_gpu_ram and not have_web_gpus:
continue

if ver := gpu_filter.get("min_version"):
try:
Expand Down Expand Up @@ -118,6 +124,7 @@ def get_socket_for_inference(self, msize: int, worker_type: WORKER_TYPES, gpu_fi
good.append(sock)

if worker_type in ("any", "cli"):
good.append(sock)
if gpu_needed < nv_gpu_ram and cpu_needed < cpu_vram and disk_needed < disk_space:
good.append(sock)
elif gpu_needed < cl_gpu_ram and cpu_needed < cpu_vram and disk_needed < disk_space:
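The new avoid parameter exists so the retry path in create_chat_completion can skip the worker whose socket just dropped instead of landing on it again. A rough sketch of that failover pattern, assuming the helpers named in app.py (do_inference, mgr, and the "socked dropped" detail string as spelled there); it is not the PR's exact control flow.

from fastapi import HTTPException

async def infer_with_failover(mgr, request, body, msize, worker_type, gpu_filter):
    # Sketch only; do_inference is the helper used in app.py, assumed available here.
    with mgr.get_socket_for_inference(msize, worker_type, gpu_filter) as ws:
        try:
            return await do_inference(request, body, ws)
        except HTTPException as ex:
            if "socked dropped" not in ex.detail:
                raise
            mgr.drop_worker(ws)  # forget the dead socket
            failed = ws
    # Second attempt: re-select, explicitly skipping the worker that just failed.
    with mgr.get_socket_for_inference(msize, worker_type, gpu_filter, avoid=[failed]) as ws:
        return await do_inference(request, body, ws, final=True)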