Skip to content

Commit 4da5ad3

Browse files
committed
fix(request-audio): loop through model_names
Signed-off-by: Max Wittig <[email protected]>
1 parent b6dd717 commit 4da5ad3

File tree

3 files changed

+39
-102
lines changed

3 files changed

+39
-102
lines changed

src/vllm_router/service_discovery.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -94,6 +94,8 @@ class EndpointInfo:
 94  94   # Model label
 95  95   model_label: str
 96  96
     97 + model_type: str
     98 +
 97  99   # Endpoint's sleep status
 98 100   sleep: bool
 99 101

@@ -306,13 +308,15 @@ def get_endpoint_info(self) -> List[EndpointInfo]:
306 308   ):
307 309   continue
308 310   model_label = self.model_labels[i] if self.model_labels else "default"
    311 + model_type = self.model_types[i] if self.model_types else "default"
309 312   endpoint_info = EndpointInfo(
310 313   url=url,
311 314   model_names=[model], # Convert single model to list
312 315   Id=self.engines_id[i],
313 316   sleep=False,
314 317   added_timestamp=self.added_timestamp,
315 318   model_label=model_label,
    319 + model_type=model_type,
316 320   model_info=self._get_model_info(model),
317 321   )
318 322   endpoint_infos.append(endpoint_info)

src/vllm_router/services/request_service/request.py

Lines changed: 12 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -539,7 +539,6 @@ async def route_general_transcriptions(
539 539   content={"error": f"Invalid request: missing '{e.args[0]}' in form data."},
540 540   )
541 541
542     - logger.debug("==== Enter audio_transcriptions ====")
543 542   logger.debug("Received upload: %s (%s)", file.filename, file.content_type)
544 543   logger.debug(
545 544   "Params: model=%s prompt=%r response_format=%r temperature=%r language=%s",
@@ -565,18 +564,12 @@ async def route_general_transcriptions(
565 564
566 565   endpoints = service_discovery.get_endpoint_info()
567 566
568     - logger.debug("==== Total endpoints ====")
569     - logger.debug(endpoints)
570     - logger.debug("==== Total endpoints ====")
571     -
572     - # filter the endpoints url by model name and label for transcriptions
573     - transcription_endpoints = [
574     - ep
575     - for ep in endpoints
576     - if model == ep.model_name
577     - and ep.model_label == "transcription"
578     - and not ep.sleep # Added ep.sleep == False
579     - ]
    567 + # filter the endpoints url by model name and model_type for transcriptions
    568 + transcription_endpoints = []
    569 + for ep in endpoints:
    570 + for model_name in ep.model_names:
    571 + if model == model_name and ep.model_type == "transcription" and not ep.sleep:
    572 + transcription_endpoints.append(ep)
580 573
581 574   logger.debug("====List of transcription endpoints====")
582 575   logger.debug(transcription_endpoints)
@@ -620,10 +613,6 @@ async def route_general_transcriptions(
620 613
621 614   logger.info("Proxying transcription request for model %s to %s", model, chosen_url)
622 615
623     - logger.debug("==== data payload keys ====")
624     - logger.debug(list(data.keys()))
625     - logger.debug("==== data payload keys ====")
626     -
627 616   try:
628 617   client = request.app.state.aiohttp_client_wrapper()
629 618

@@ -687,3 +676,9 @@ async def route_general_transcriptions(
687 676   status_code=503,
688 677   content={"error": f"Failed to connect to backend: {str(client_error)}"},
689 678   )
    679 + except Exception as e:
    680 + logger.error(e)
    681 + return JSONResponse(
    682 + status_code=500,
    683 + content={"error": f"Internal server error"},
    684 + )

0 commit comments

Comments
 (0)