diff --git a/models/model.py b/models/model.py
index 9db6b2a..5b92476 100644
--- a/models/model.py
+++ b/models/model.py
@@ -301,6 +301,7 @@ async def _generate_text(self, message: dict, run_params: dict, error_tracker: E
         if self.inference_type in (
             constants.INFERENCE_SERVER_VLLM_CHAT_COMPLETION,
             constants.OPENAI_CHAT_COMPLETION,
+            constants.GEMINI_CHAT_COMPLETION,
         ):
             for i in range(num_chunks):
                 start = i * max_samples
@@ -381,6 +382,7 @@ async def _generate_text(self, message: dict, run_params: dict, error_tracker: E
         if self.inference_type in (
             constants.INFERENCE_SERVER_VLLM_CHAT_COMPLETION,
             constants.OPENAI_CHAT_COMPLETION,
+            constants.GEMINI_CHAT_COMPLETION,
         ):
             # Cut to first 30s, then process as chat completion
             if audio_array is not None and len(audio_array) > 0:
@@ -472,6 +474,7 @@ async def _handle_multi_turn(self, message: dict, run_params: dict, error_tracke
         if self.inference_type not in (
             constants.INFERENCE_SERVER_VLLM_CHAT_COMPLETION,
             constants.OPENAI_CHAT_COMPLETION,
+            constants.GEMINI_CHAT_COMPLETION,
         ):
             raise ValueError("Multi-turn conversations only supported for chat completion inference types")
diff --git a/models/request_resp_handler.py b/models/request_resp_handler.py
index 580ddcb..29a0ab7 100644
--- a/models/request_resp_handler.py
+++ b/models/request_resp_handler.py
@@ -5,6 +5,8 @@
 import inspect
 
 import httpx
+from google.auth import default
+from google.auth.transport.requests import Request
 from openai import AsyncAzureOpenAI, AsyncOpenAI
 from models.model_response import ModelResponse, ErrorTracker
 from utils import constants
@@ -20,6 +22,8 @@ def __init__(self, inference_type: str, model_info: dict, generation_params: dic
         self.model_info = model_info
         self.api = model_info.get("url")
         self.auth = model_info.get("auth_token", "")
+        self.location = model_info.get("location", "")
+        self.project_id = model_info.get("project_id", "")
         self.api_version = model_info.get("api_version", "")
         self.client = None
         self.timeout = timeout
@@ -153,6 +157,25 @@ def set_client(self, verify_ssl: bool, timeout: int):
                     http_client=httpx.AsyncClient(verify=verify_ssl),
                 )
             )
+        elif self.inference_type == constants.GEMINI_CHAT_COMPLETION:
+            # Gemini endpoints
+
+            # Set an API host for Gemini on Vertex AI
+            api_host = "aiplatform.googleapis.com"
+            if self.location != "global":
+                api_host = f"{self.location}-aiplatform.googleapis.com"
+
+            credentials, _ = default(scopes=["https://www.googleapis.com/auth/cloud-platform"])
+            credentials.refresh(Request())
+
+            self.client = AsyncOpenAI(
+                base_url=f"https://{api_host}/v1/projects/{self.project_id}/locations/{self.location}/endpoints/openapi",
+                api_key=credentials.token,
+                timeout=timeout,
+                max_retries=0,
+                default_headers={"Connection": "close"},
+                http_client=httpx.AsyncClient(verify=verify_ssl),
+            )
 
     def validated_safe_generation_params(self, generation_params):
         """Validate and sanitize generation parameters for the OpenAI API client.
@@ -187,6 +210,7 @@ async def request_server(self, msg_body, tools=None, error_tracker: ErrorTracker
         2. Any exception is wrapped in a `ModelResponse` with ``response_code = 500``.
""" model_name: str | None = self.model_info.get("model") + reasoning_effort = self.model_info.get("reasoning_effort", None) if tools: tools = self.convert_to_tool(tools) @@ -194,12 +218,19 @@ async def request_server(self, msg_body, tools=None, error_tracker: ErrorTracker # Re-create a fresh client for this request to avoid closed-loop issues self.set_client(verify_ssl=True, timeout=self.timeout) try: - if self.inference_type == constants.OPENAI_CHAT_COMPLETION or self.inference_type == constants.INFERENCE_SERVER_VLLM_CHAT_COMPLETION: + if self.inference_type in (constants.OPENAI_CHAT_COMPLETION, constants.INFERENCE_SERVER_VLLM_CHAT_COMPLETION, constants.GEMINI_CHAT_COMPLETION): # openai chat completions, vllm chat completions self.generation_params = self.validated_safe_generation_params(self.generation_params) - prediction = await self.client.chat.completions.create( - model=model_name, messages=msg_body, tools=tools, **self.generation_params - ) + + if reasoning_effort: + prediction = await self.client.chat.completions.create( + model=model_name, messages=msg_body, tools=tools, reasoning_effort=reasoning_effort, **self.generation_params + ) + else: + prediction = await self.client.chat.completions.create( + model=model_name, messages=msg_body, tools=tools, **self.generation_params + ) + raw_response: str = self._extract_response_data(prediction) llm_response: str = raw_response['choices'][0]['message']['content'] or " " diff --git a/requirements.txt b/requirements.txt index ba34d86..47459d0 100644 --- a/requirements.txt +++ b/requirements.txt @@ -51,4 +51,8 @@ nest_asyncio==1.6.0 immutabledict==4.2.1 # Debugging -debugpy==1.8.0 # VSCode debugging support \ No newline at end of file +debugpy==1.8.0 # VSCode debugging support + +# Gemimini libraries +google-auth==2.40.3 +google-genai==1.38.0 \ No newline at end of file diff --git a/sample_config.yaml b/sample_config.yaml index e4c2cfe..a28d69c 100644 --- a/sample_config.yaml +++ b/sample_config.yaml @@ -25,64 +25,75 @@ filter: num_samples: 100 # number of samples to run(remove for all) length_filter: [0.0, 30.0] #optional - filters for only audio samples in this length(seconds) - only supported for general and callhome preprocessors -judge_properties: +judge_settings: judge_concurrency: 8 #judge call(optional) - judge_model: "gpt-4o-mini" #optional - judge_type: "openai" # mandatory (vllm or openai) - judge_api_version: "${API_VERSION}" # optional(needed for openai) - judge_api_endpoint: "${ENDPOINT_URL}" # mandatory - judge_api_key: "${AUTH_TOKEN}" # mandatory + judge_model: gpt-4o-mini #optional + judge_type: openai # mandatory (vllm or openai) + judge_api_version: ${API_VERSION} # optional(needed for openai) + judge_api_endpoint: ${ENDPOINT_URL} # mandatory + judge_api_key: ${AUTH_TOKEN} # mandatory judge_temperature: 0.1 # optional - judge_prompt_model_override: "gpt-4o-mini-enhanced" # optional + judge_prompt_model_override: gpt-4o-mini-enhanced # optional logging: log_file: "audiobench.log" # Path to the main log file - models: - - name: "gpt-4o-mini-audio-preview-1" # mandatory - must be unique - inference_type: "openai" # mandatory - you can use vllm(vllm), openai(openai), (chat completion) or audio transcription endpoint(transcription) - url: "${ENDPOINT_URL}" # mandatory - endpoint url + - name: gpt-4o-mini-audio-preview-1 # must be unique + inference_type: openai # you can use vllm, openai, gemini or transcription + url: ${ENDPOINT_URL} # endpoint url delay: 100 retry_attempts: 8 timeout: 30 - model: 
"gpt-4o-mini-audio-preview" # mandatory - only needed for vllm - auth_token: "${AUTH_TOKEN}" - api_version: "${API_VERSION}" + model: gpt-4o-mini-audio-preview + auth_token: ${AUTH_TOKEN} + api_version: ${API_VERSION} batch_size: 300 # Optional - batch eval size chunk_size: 30 # Optional - max audio length in seconds fed to model - - name: "gpt-4o-mini-audio-preview-2" # mandatory - must be unique - inference_type: "openai" # mandatory - you can use vllm(vllm), openai(openai), (chat completion) or audio transcription endpoint(transcription) - url: "${ENDPOINT_URL}" # mandatory - endpoint url + - name: gpt-4o-mini-audio-preview-2 # must be unique + inference_type: openai # you can use vllm, openai, gemini or transcription + url: ${ENDPOINT_URL} # endpoint url delay: 100 retry_attempts: 8 timeout: 30 - model: "gpt-4o-mini-audio-preview" # mandatory - only needed for vllm - auth_token: "${AUTH_TOKEN}" - api_version: "${API_VERSION}" - batch_size: 100 # Optional - batch eval size + model: gpt-4o-mini-audio-preview + auth_token: ${AUTH_TOKEN} + api_version: ${API_VERSION} + batch_size: 300 # Optional - batch eval size chunk_size: 30 # Optional - max audio length in seconds fed to model - - name: "qwen-2.5-omni" - inference_type: "vllm" # mandatory - you can use vllm(vllm), openai(openai), (chat completion) or audio transcription endpoint(transcription) - url: "${ENDPOINT_URL}" # mandatory - endpoint url + - name: gemini-2.5-flash # must be unique + inference_type: gemini # you can use vllm, openai, gemini or transcription + location: ${GOOGLE_CLOUD_LOCATION} # GCP Vertex AI configureation + project_id: ${GOOGLE_CLOUD_PROJECT} # GCP Vertex AI configureation + reasoning_effort: medium # Optional - Reasoning effort for supported reasoning models like gemini-2.5-flash, gpt-5,... 
+    delay: 100
+    retry_attempts: 5
+    timeout: 300
+    model: google/gemini-2.5-flash
+    batch_size: 100 # Optional - batch eval size
+    chunk_size: 30240 # Optional - max audio length in seconds fed to model
+
+  - name: qwen-2.5-omni # must be unique
+    inference_type: vllm # you can use vllm, openai, gemini or transcription
+    url: ${ENDPOINT_URL} # endpoint url
     delay: 100
     retry_attempts: 8
     timeout: 30
-    model: "qwen-2.5-omni" # mandatory - only needed for vllm
-    auth_token: "${AUTH_TOKEN}"
+    model: qwen-2.5-omni
+    auth_token: ${AUTH_TOKEN}
     batch_size: 200 # Optional - batch eval size
     chunk_size: 40 # Optional - max audio length in seconds fed to model
 
-  - name: "whisper-large-3"
-    inference_type: "vllm" # mandatory - you can use vllm(vllm), openai(openai), (chat completion) or audio transcription endpoint(transcription)
-    url: "${ENDPOINT_URL}" # mandatory - endpoint url
+  - name: whisper-large-3 # must be unique
+    inference_type: transcription # you can use vllm, openai, gemini or transcription
+    url: ${ENDPOINT_URL} # endpoint url
     delay: 100
     retry_attempts: 8
     timeout: 30
-    model: "whisper-large-3" # mandatory - only needed for vllm
-    auth_token: "${AUTH_TOKEN}"
+    model: whisper-large-3
+    auth_token: ${AUTH_TOKEN}
     batch_size: 100 # Optional - batch eval size
     chunk_size: 30 # Optional - max audio length in seconds fed to model
diff --git a/utils/constants.py b/utils/constants.py
index 6703ae5..2a89a4f 100644
--- a/utils/constants.py
+++ b/utils/constants.py
@@ -21,6 +21,7 @@
 # Inference server types
 INFERENCE_SERVER_VLLM_CHAT_COMPLETION = 'vllm'
 OPENAI_CHAT_COMPLETION = 'openai'
+GEMINI_CHAT_COMPLETION = 'gemini'
 TRANSCRIPTION = 'transcription'
 
 # WER/CER metrics constants
diff --git a/utils/util.py b/utils/util.py
index 4dbbe0d..fa1db41 100644
--- a/utils/util.py
+++ b/utils/util.py
@@ -1,6 +1,6 @@
 import importlib
-import json
+import re
 import logging
 import os
 import statistics
@@ -219,7 +219,7 @@ def _validate_models(config: Dict) -> None:
         ValueError: If the models section is invalid
     """
     def validate_required_fields(info: Dict, index: int) -> None:
-        required_fields = ['name', 'model', 'inference_type', 'url']
+        required_fields = ['name', 'model', 'inference_type']
         for field in required_fields:
             if not info.get(field) or not isinstance(info[field], str) or not info[field].strip():
                 raise ValueError(f"Model {index}: '{field}' must be a non-empty string")
@@ -411,9 +411,68 @@ def setup_logging(log_file: str):
     # Set httpx logger to WARNING level to reduce noise
     logging.getLogger("httpx").setLevel(logging.WARNING)
 
+def _replace_env_vars(value):
+    """
+    Replace environment variables in strings.
+    Supports ${ENV_VAR} and $ENV_VAR syntax.
+
+    Args:
+        value: String value that may contain environment variables
+
+    Returns:
+        String with environment variables substituted
+    """
+    if not isinstance(value, str):
+        return value
+
+    # Replace ${VAR} format
+    pattern1 = re.compile(r'\${([^}^{]+)}')
+    matches = pattern1.findall(value)
+    if matches:
+        for match in matches:
+            env_var = os.environ.get(match)
+            if env_var is not None:
+                value = value.replace(f"${{{match}}}", env_var)
+            else:
+                logger.warning(f"Environment variable '{match}' not found when processing config")
+
+    # Replace $VAR format
+    pattern2 = re.compile(r'(?
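
# A minimal standalone sketch (not part of the diff above) of the Vertex AI flow
# that the new GEMINI_CHAT_COMPLETION branch relies on: the OpenAI-compatible
# endpoint, an Application Default Credentials access token used as the API key,
# and the optional reasoning_effort forwarded to chat.completions.create.
# The project, location, and prompt values below are placeholders, not part of the change.
import asyncio

import httpx
from google.auth import default
from google.auth.transport.requests import Request
from openai import AsyncOpenAI


async def main():
    # Placeholder values; in the harness these come from the model entry in sample_config.yaml.
    project_id, location, model = "my-gcp-project", "us-central1", "google/gemini-2.5-flash"

    # Regional endpoints use "<location>-aiplatform.googleapis.com"; "global" uses the bare host.
    api_host = "aiplatform.googleapis.com" if location == "global" else f"{location}-aiplatform.googleapis.com"

    # ADC supplies a short-lived OAuth token that acts as the API key for the OpenAI-compatible endpoint.
    credentials, _ = default(scopes=["https://www.googleapis.com/auth/cloud-platform"])
    credentials.refresh(Request())

    client = AsyncOpenAI(
        base_url=f"https://{api_host}/v1/projects/{project_id}/locations/{location}/endpoints/openapi",
        api_key=credentials.token,
        max_retries=0,
        http_client=httpx.AsyncClient(),
    )

    # reasoning_effort is only forwarded when set on the model config (see request_resp_handler.py above).
    prediction = await client.chat.completions.create(
        model=model,
        messages=[{"role": "user", "content": "Say hello in one short sentence."}],
        reasoning_effort="medium",
    )
    print(prediction.choices[0].message.content)
    await client.close()


asyncio.run(main())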