3 changes: 3 additions & 0 deletions models/model.py
@@ -301,6 +301,7 @@ async def _generate_text(self, message: dict, run_params: dict, error_tracker: E
if self.inference_type in (
constants.INFERENCE_SERVER_VLLM_CHAT_COMPLETION,
constants.OPENAI_CHAT_COMPLETION,
constants.GEMINI_CHAT_COMPLETION,
):
for i in range(num_chunks):
start = i * max_samples
@@ -381,6 +382,7 @@ async def _generate_text(self, message: dict, run_params: dict, error_tracker: E
if self.inference_type in (
constants.INFERENCE_SERVER_VLLM_CHAT_COMPLETION,
constants.OPENAI_CHAT_COMPLETION,
constants.GEMINI_CHAT_COMPLETION,
):
# Cut to first 30s, then process as chat completion
if audio_array is not None and len(audio_array) > 0:
@@ -472,6 +474,7 @@ async def _handle_multi_turn(self, message: dict, run_params: dict, error_tracke
if self.inference_type not in (
constants.INFERENCE_SERVER_VLLM_CHAT_COMPLETION,
constants.OPENAI_CHAT_COMPLETION,
constants.GEMINI_CHAT_COMPLETION,
):
raise ValueError("Multi-turn conversations only supported for chat completion inference types")
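The "# Cut to first 30s" comment in the second hunk refers to trimming the decoded audio before it is sent down the chat-completion path. A rough sketch of that trim, assuming a 16 kHz mono numpy array (the actual sample rate and helper names in model.py are not visible in this diff):

import numpy as np

# Assumed parameters; model.py's real values are not shown in this diff.
sample_rate = 16000      # assumed sample rate (Hz)
chunk_seconds = 30       # matches the "first 30s" comment above

audio_array = np.zeros(45 * sample_rate)  # hypothetical 45-second clip

if audio_array is not None and len(audio_array) > 0:
    audio_array = audio_array[: chunk_seconds * sample_rate]

print(len(audio_array) / sample_rate)  # 30.0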

39 changes: 35 additions & 4 deletions models/request_resp_handler.py
@@ -5,6 +5,8 @@
import inspect

import httpx
from google.auth import default
from google.auth.transport.requests import Request
from openai import AsyncAzureOpenAI, AsyncOpenAI
from models.model_response import ModelResponse, ErrorTracker
from utils import constants
@@ -20,6 +22,8 @@ def __init__(self, inference_type: str, model_info: dict, generation_params: dic
self.model_info = model_info
self.api = model_info.get("url")
self.auth = model_info.get("auth_token", "")
self.location = model_info.get("location", "")
self.project_id = model_info.get("project_id", "")
self.api_version = model_info.get("api_version", "")
self.client = None
self.timeout = timeout
@@ -153,6 +157,25 @@ def set_client(self, verify_ssl: bool, timeout: int):
http_client=httpx.AsyncClient(verify=verify_ssl),
)
)
elif self.inference_type == constants.GEMINI_CHAT_COMPLETION:
# Gemini via the Vertex AI OpenAI-compatible chat completions endpoint

# Set an API host for Gemini on Vertex AI
api_host = "aiplatform.googleapis.com"
if self.location != "global":
api_host = f"{self.location}-aiplatform.googleapis.com"

credentials, _ = default(scopes=["https://www.googleapis.com/auth/cloud-platform"])
credentials.refresh(Request())

self.client = AsyncOpenAI(
base_url=f"https://{api_host}/v1/projects/{self.project_id}/locations/{self.location}/endpoints/openapi",
api_key=credentials.token,
timeout=timeout,
max_retries=0,
default_headers={"Connection": "close"},
http_client=httpx.AsyncClient(verify=verify_ssl),
)

def validated_safe_generation_params(self, generation_params):
"""Validate and sanitize generation parameters for the OpenAI API client.
@@ -187,19 +210,27 @@ async def request_server(self, msg_body, tools=None, error_tracker: ErrorTracker
2. Any exception is wrapped in a `ModelResponse` with ``response_code = 500``.
"""
model_name: str | None = self.model_info.get("model")
reasoning_effort = self.model_info.get("reasoning_effort", None)
if tools:
tools = self.convert_to_tool(tools)

start_time = time.time()
# Re-create a fresh client for this request to avoid closed-loop issues
self.set_client(verify_ssl=True, timeout=self.timeout)
try:
if self.inference_type == constants.OPENAI_CHAT_COMPLETION or self.inference_type == constants.INFERENCE_SERVER_VLLM_CHAT_COMPLETION:
if self.inference_type in (constants.OPENAI_CHAT_COMPLETION, constants.INFERENCE_SERVER_VLLM_CHAT_COMPLETION, constants.GEMINI_CHAT_COMPLETION):
# openai, vllm, and gemini chat completions
self.generation_params = self.validated_safe_generation_params(self.generation_params)
prediction = await self.client.chat.completions.create(
model=model_name, messages=msg_body, tools=tools, **self.generation_params
)

if reasoning_effort:
prediction = await self.client.chat.completions.create(
model=model_name, messages=msg_body, tools=tools, reasoning_effort=reasoning_effort, **self.generation_params
)
else:
prediction = await self.client.chat.completions.create(
model=model_name, messages=msg_body, tools=tools, **self.generation_params
)

raw_response: str = self._extract_response_data(prediction)
llm_response: str = raw_response['choices'][0]['message']['content'] or " "
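Taken together, the new set_client branch and the reasoning_effort handling above amount to calling Gemini through Vertex AI's OpenAI-compatible endpoint. A minimal standalone sketch of that flow; the project, location, model name, and prompt below are hypothetical placeholders, while the real values come from the model entry in the config:

import asyncio

import httpx
from google.auth import default
from google.auth.transport.requests import Request
from openai import AsyncOpenAI

async def main():
    # Hypothetical GCP settings; in the handler these come from model_info.
    location = "us-central1"
    project_id = "my-gcp-project"
    api_host = "aiplatform.googleapis.com" if location == "global" else f"{location}-aiplatform.googleapis.com"

    # Application-default credentials supply a short-lived OAuth token used as the API key.
    credentials, _ = default(scopes=["https://www.googleapis.com/auth/cloud-platform"])
    credentials.refresh(Request())

    client = AsyncOpenAI(
        base_url=f"https://{api_host}/v1/projects/{project_id}/locations/{location}/endpoints/openapi",
        api_key=credentials.token,
        timeout=300,
        max_retries=0,
        http_client=httpx.AsyncClient(verify=True),
    )
    prediction = await client.chat.completions.create(
        model="google/gemini-2.5-flash",
        messages=[{"role": "user", "content": "Say hello."}],
        reasoning_effort="medium",  # passed only when the config sets reasoning_effort
    )
    print(prediction.choices[0].message.content)

asyncio.run(main())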

6 changes: 5 additions & 1 deletion requirements.txt
@@ -51,4 +51,8 @@ nest_asyncio==1.6.0
immutabledict==4.2.1

# Debugging
debugpy==1.8.0 # VSCode debugging support
debugpy==1.8.0 # VSCode debugging support

# Gemini libraries
google-auth==2.40.3
google-genai==1.38.0
73 changes: 42 additions & 31 deletions sample_config.yaml
@@ -25,64 +25,75 @@ filter:
num_samples: 100 # number of samples to run (remove for all)
length_filter: [0.0, 30.0] # optional - keeps only audio samples within this length range (seconds) - only supported for general and callhome preprocessors

judge_properties:
judge_settings:
judge_concurrency: 8 #judge call(optional)
judge_model: "gpt-4o-mini" #optional
judge_type: "openai" # mandatory (vllm or openai)
judge_api_version: "${API_VERSION}" # optional(needed for openai)
judge_api_endpoint: "${ENDPOINT_URL}" # mandatory
judge_api_key: "${AUTH_TOKEN}" # mandatory
judge_model: gpt-4o-mini #optional
judge_type: openai # mandatory (vllm or openai)
judge_api_version: ${API_VERSION} # optional(needed for openai)
judge_api_endpoint: ${ENDPOINT_URL} # mandatory
judge_api_key: ${AUTH_TOKEN} # mandatory
judge_temperature: 0.1 # optional
judge_prompt_model_override: "gpt-4o-mini-enhanced" # optional
judge_prompt_model_override: gpt-4o-mini-enhanced # optional

logging:
log_file: "audiobench.log" # Path to the main log file


models:
- name: "gpt-4o-mini-audio-preview-1" # mandatory - must be unique
inference_type: "openai" # mandatory - you can use vllm(vllm), openai(openai), (chat completion) or audio transcription endpoint(transcription)
url: "${ENDPOINT_URL}" # mandatory - endpoint url
- name: gpt-4o-mini-audio-preview-1 # must be unique
inference_type: openai # you can use vllm, openai, gemini or transcription
url: ${ENDPOINT_URL} # endpoint url
delay: 100
retry_attempts: 8
timeout: 30
model: "gpt-4o-mini-audio-preview" # mandatory - only needed for vllm
auth_token: "${AUTH_TOKEN}"
api_version: "${API_VERSION}"
model: gpt-4o-mini-audio-preview
auth_token: ${AUTH_TOKEN}
api_version: ${API_VERSION}
batch_size: 300 # Optional - batch eval size
chunk_size: 30 # Optional - max audio length in seconds fed to model

- name: "gpt-4o-mini-audio-preview-2" # mandatory - must be unique
inference_type: "openai" # mandatory - you can use vllm(vllm), openai(openai), (chat completion) or audio transcription endpoint(transcription)
url: "${ENDPOINT_URL}" # mandatory - endpoint url
- name: gpt-4o-mini-audio-preview-2 # must be unique
inference_type: openai # you can use vllm, openai, gemini or transcription
url: ${ENDPOINT_URL} # endpoint url
delay: 100
retry_attempts: 8
timeout: 30
model: "gpt-4o-mini-audio-preview" # mandatory - only needed for vllm
auth_token: "${AUTH_TOKEN}"
api_version: "${API_VERSION}"
batch_size: 100 # Optional - batch eval size
model: gpt-4o-mini-audio-preview
auth_token: ${AUTH_TOKEN}
api_version: ${API_VERSION}
batch_size: 300 # Optional - batch eval size
chunk_size: 30 # Optional - max audio length in seconds fed to model

- name: "qwen-2.5-omni"
inference_type: "vllm" # mandatory - you can use vllm(vllm), openai(openai), (chat completion) or audio transcription endpoint(transcription)
url: "${ENDPOINT_URL}" # mandatory - endpoint url
- name: gemini-2.5-flash # must be unique
inference_type: gemini # you can use vllm, openai, gemini or transcription
location: ${GOOGLE_CLOUD_LOCATION} # GCP Vertex AI configuration
project_id: ${GOOGLE_CLOUD_PROJECT} # GCP Vertex AI configuration
reasoning_effort: medium # Optional - Reasoning effort for supported reasoning models like gemini-2.5-flash, gpt-5,...
delay: 100
retry_attempts: 5
timeout: 300
model: google/gemini-2.5-flash
batch_size: 100 # Optional - batch eval size
chunk_size: 30240 # Optional - max audio length in seconds fed to model

- name: qwen-2.5-omni # must be unique
inference_type: vllm # you can use vllm, openai, gemini or transcription
url: ${ENDPOINT_URL} # endpoint url
delay: 100
retry_attempts: 8
timeout: 30
model: "qwen-2.5-omni" # mandatory - only needed for vllm
auth_token: "${AUTH_TOKEN}"
model: qwen-2.5-omni
auth_token: ${AUTH_TOKEN}
batch_size: 200 # Optional - batch eval size
chunk_size: 40 # Optional - max audio length in seconds fed to model

- name: "whisper-large-3"
inference_type: "vllm" # mandatory - you can use vllm(vllm), openai(openai), (chat completion) or audio transcription endpoint(transcription)
url: "${ENDPOINT_URL}" # mandatory - endpoint url
- name: whisper-large-3 # must be unique
inference_type: transcription # you can use vllm, openai, gemini or transcription
url: ${ENDPOINT_URL} # endpoint url
delay: 100
retry_attempts: 8
timeout: 30
model: "whisper-large-3" # mandatory - only needed for vllm
auth_token: "${AUTH_TOKEN}"
model: whisper-large-3
auth_token: ${AUTH_TOKEN}
batch_size: 100 # Optional - batch eval size
chunk_size: 30 # Optional - max audio length in seconds fed to model
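The ${...} placeholders throughout this config are resolved from the environment when the file is read (see the utils/util.py changes below). A minimal sketch of the variables this sample references, using placeholder values rather than real endpoints or credentials:

import os

# Placeholder values for illustration only; substitute your own endpoint, token, and GCP settings.
os.environ.setdefault("ENDPOINT_URL", "https://example-endpoint.openai.azure.com")
os.environ.setdefault("AUTH_TOKEN", "replace-me")
os.environ.setdefault("API_VERSION", "2024-06-01")
os.environ.setdefault("GOOGLE_CLOUD_LOCATION", "us-central1")
os.environ.setdefault("GOOGLE_CLOUD_PROJECT", "my-gcp-project")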

1 change: 1 addition & 0 deletions utils/constants.py
@@ -21,6 +21,7 @@
# Inference server types
INFERENCE_SERVER_VLLM_CHAT_COMPLETION = 'vllm'
OPENAI_CHAT_COMPLETION = 'openai'
GEMINI_CHAT_COMPLETION = 'gemini'
TRANSCRIPTION = 'transcription'

# WER/CER metrics constants
67 changes: 65 additions & 2 deletions utils/util.py
@@ -1,6 +1,6 @@

import importlib
import json
import re
import logging
import os
import statistics
@@ -219,7 +219,7 @@ def _validate_models(config: Dict) -> None:
ValueError: If the models section is invalid
"""
def validate_required_fields(info: Dict, index: int) -> None:
required_fields = ['name', 'model', 'inference_type', 'url']
required_fields = ['name', 'model', 'inference_type']
for field in required_fields:
if not info.get(field) or not isinstance(info[field], str) or not info[field].strip():
raise ValueError(f"Model {index}: '{field}' must be a non-empty string")
@@ -411,9 +411,68 @@ def setup_logging(log_file: str):
# Set httpx logger to WARNING level to reduce noise
logging.getLogger("httpx").setLevel(logging.WARNING)

def _replace_env_vars(value):
"""
Replace environment variables in strings.
Supports ${ENV_VAR} and $ENV_VAR syntax.

Args:
value: String value that may contain environment variables

Returns:
String with environment variables substituted
"""
if not isinstance(value, str):
return value

# Replace ${VAR} format
pattern1 = re.compile(r'\${([^}^{]+)}')
matches = pattern1.findall(value)
if matches:
for match in matches:
env_var = os.environ.get(match)
if env_var is not None:
value = value.replace(f"${{{match}}}", env_var)
else:
logger.warning(f"Environment variable '{match}' not found when processing config")

# Replace $VAR format
pattern2 = re.compile(r'(?<!\\)\$([a-zA-Z0-9_]+)')
matches = pattern2.findall(value)
if matches:
for match in matches:
env_var = os.environ.get(match)
if env_var is not None:
value = value.replace(f"${match}", env_var)
else:
logger.warning(f"Environment variable '{match}' not found when processing config")

return value

def _process_nested_env_vars(data):
"""
Process all values in a nested dictionary/list structure,
replacing environment variables in string values.

Args:
data: Dict, list, or scalar value

Returns:
Data with environment variables substituted in string values
"""
if isinstance(data, dict):
return {k: _process_nested_env_vars(v) for k, v in data.items()}
elif isinstance(data, list):
return [_process_nested_env_vars(item) for item in data]
elif isinstance(data, str):
return _replace_env_vars(data)
else:
return data

def read_config(cfg_path: str):
"""
Read configuration file and set up logging.
Supports environment variable substitution in the format ${ENV_VAR} or $ENV_VAR.

Args:
cfg_path: Path to configuration file
@@ -424,6 +483,10 @@ def read_config(cfg_path: str):
# Set up logging
with open(cfg_path, encoding='utf-8') as f:
raw_cfg = yaml.safe_load(f)

# Process environment variables in the config
raw_cfg = _process_nested_env_vars(raw_cfg)

log_file = raw_cfg.get("logging", {}).get("log_file", "default.log")
setup_logging(log_file)
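A quick sketch of how the new substitution helpers behave, assuming they are imported from utils.util; the variable values below are made up for illustration:

import os

from utils.util import _process_nested_env_vars

os.environ["ENDPOINT_URL"] = "https://example-endpoint.openai.azure.com"
os.environ["AUTH_TOKEN"] = "secret-token"

cfg = {
    "models": [
        {"name": "gpt-4o-mini-audio-preview-1", "url": "${ENDPOINT_URL}", "auth_token": "$AUTH_TOKEN"},
    ]
}

resolved = _process_nested_env_vars(cfg)
print(resolved["models"][0]["url"])         # https://example-endpoint.openai.azure.com
print(resolved["models"][0]["auth_token"])  # secret-token
# Variables not present in the environment are left untouched and a warning is logged.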
