1 change: 1 addition & 0 deletions .gitignore
@@ -35,3 +35,4 @@ data/spider/
test/

CLAUDE.md
.claude/
14 changes: 14 additions & 0 deletions README.md
@@ -278,6 +278,17 @@ models:
auth_token: ${AUTH_TOKEN} # Mandatory
batch_size: 150 # Mandatory
chunk_size: 30 # Optional - Max audio length in seconds

- name: "Sonic-v3"
inference_type: "cartesia_tts"
model: "sonic-3"
auth_token: "${AUTH_TOKEN}"
delay: 100
retry_attempts: 8
timeout: 30
batch_size: 8
chunk_size: 30
voice_id: "6ccbfb76-1fc6-48f7-b71d-91ac6298247b" # Required for any cartesia_tts or elevenlabs_tts inference type
```

**Note**: Batch-size-proportional dataset sharding is applied when multiple endpoints of the same model are provided. Be sure to give each endpoint a unique `name` attribute, as shown above.
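For intuition, a minimal sketch of what batch-size-proportional sharding could look like (the helper name `shard_dataset`, the endpoint dictionaries, and the rounding strategy are illustrative assumptions, not the repository's actual code):

```python
# Hypothetical sketch: split records across endpoints in proportion to batch_size.
# Helper name, endpoint dicts, and rounding are assumptions for illustration only.
def shard_dataset(records: list, endpoints: list[dict]) -> dict[str, list]:
    total = sum(ep["batch_size"] for ep in endpoints)
    shards, start = {}, 0
    for i, ep in enumerate(endpoints):
        # The last endpoint takes the remainder so every record is assigned exactly once.
        end = len(records) if i == len(endpoints) - 1 else start + round(len(records) * ep["batch_size"] / total)
        shards[ep["name"]] = records[start:end]  # unique 'name' keys keep shards distinct
        start = end
    return shards

# Example: batch sizes 150 and 50 give roughly a 75% / 25% split of 1,000 records.
shards = shard_dataset(
    list(range(1000)),
    [{"name": "endpoint-a", "batch_size": 150}, {"name": "endpoint-b", "batch_size": 50}],
)
```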
@@ -289,6 +300,9 @@ models:
| "openai" | AsyncAzureOpenAI (Chat Completions) |
| "vllm" | AsyncOpenAI (Chat Completions) |
| "transcription" | AsyncOpenAI (Transcriptions) |
| "cartesia_tts" | AsyncCartesia (Text-to-Speech) |
| "deepgram_tts" | Deepgram (Text-to-Speech) |
| "elevenlabs_tts" | AsyncElevenLabs (Text-to-Speech) |

#### Judge Configuration
LLM-Judge setup is required to run any tasks requiring LLM-judge metrics. For specific task-metric pair compatibility, visit [Task Documentation](./tasks/README.md) and [Metric Documentation](./metrics/README.md).
11 changes: 10 additions & 1 deletion metrics/README.md
@@ -26,6 +26,7 @@ For more detailed documentation regarding which metrics can be used for which ta
| `sql_score` (↑) | SQL correctness and execution match | text2sql_score |
| `instruction_following` (↑) | LLM-judged instruction following capability | final |
| `gsm8k_exact_match` (↑) | Exact-match accuracy of the final numerical answer. | gsm8k_exact_match |
| `utmos` (↑) | UTMOSv2-based audio quality evaluation for TTS | utmos |

---

@@ -156,4 +157,12 @@ For more detailed documentation regarding which metrics can be used for which ta
- **Type**: Math correctness metric
- **Description**: Measure the exact-match accuracy of the final numerical answer (expected within `\boxed{}`) with the reference numerical answer.
- **Scoring (record-level)**: Score between `0` and `100`, higher is better.
- **Used In**: Math (`gsm8k`)

---

### `utmos`
- **Type**: TTS audio quality metric
- **Description**: Evaluates the quality of TTS-generated audio using the UTMOSv2 (Universal Text-to-Speech Mean Opinion Score v2) model, which predicts the naturalness and quality of synthesized speech.
- **Scoring (record-level)**: Score between `0` and `5` (MOS scale), higher is better.
- **Used In**: Text-to-Speech (`tts`)
133 changes: 133 additions & 0 deletions metrics/utmos.py
@@ -0,0 +1,133 @@
"""UTMOSv2 metric for TTS audio quality evaluation."""

import logging
import os
import shutil
import tempfile
import warnings

import torch
import utmosv2
from tqdm import tqdm

from metrics.metrics import Metrics
from utils import util
from utils.custom_logging import write_record_log, append_final_score

logging.getLogger("timm").setLevel(logging.WARNING)
logging.getLogger("transformers").setLevel(logging.WARNING)
warnings.filterwarnings("ignore", category=FutureWarning, module="utmosv2")
warnings.filterwarnings("ignore", message=".*pin_memory.*")
warnings.filterwarnings("ignore", message=".*CUDA is not available.*")

logger = logging.getLogger(__name__)


class UTMOSMetric(Metrics):
"""UTMOSv2 metric for evaluating TTS audio quality."""

def __init__(self, batch_size=1):
"""Initialize UTMOSv2 metric.

Args:
batch_size: Number of audio files to process in parallel (default: 1)
"""
super().__init__()
self.name = "utmos"
self.record_level_scores = None
self.batch_size = batch_size

# Load model once and reuse for all evaluations
logger.info("[UTMOSMetric] Loading UTMOSv2 model...")
self.model = utmosv2.create_model(pretrained=True)

# Determine device
if torch.backends.mps.is_available():
self.device = 'mps'
elif torch.cuda.is_available():
self.device = 'cuda'
else:
self.device = 'cpu'

logger.info(f"[UTMOSMetric] Model loaded on {self.device} with batch_size={batch_size}")

def __call__(self, candidates, references, instructions=None, *,
task_name: str | None = None, model_name: str | None = None,
model_responses=None):
"""Compute UTMOSv2 scores for TTS-generated audio.

Args:
candidates: List of audio file paths (from TTS generation)
references: List of ground truth text (not used for UTMOS)
instructions: Optional instructions
task_name: Name of the task
model_name: Name of the model
model_responses: Raw model responses

Returns:
Dictionary with overall UTMOS score
"""
self.instructions = instructions

# Compute UTMOS scores for each audio file
self.record_level_scores = self.compute_record_level_scores(candidates)

# Calculate mean UTMOS
scores = self.record_level_scores.get(self.name, [])
valid_scores = [score for score in scores if score is not None]
mean_utmos = sum(valid_scores) / len(valid_scores) if valid_scores else 0.0
overall_score = {self.name: util.smart_round(mean_utmos, 3)}

if task_name and model_name:
write_record_log(self, references, candidates, scores,
task_name, model_name, instructions=self.instructions)
append_final_score(self, overall_score, task_name, model_name)

return overall_score

def compute_record_level_scores(self, audio_files: list) -> dict[str, list]:
"""Compute UTMOSv2 scores for each audio file in batches.

Args:
audio_files: List of paths to generated audio files

Returns:
Dictionary with UTMOS scores for each audio file
"""
num_batches = (len(audio_files) + self.batch_size - 1) // self.batch_size
all_scores = []

for batch_idx in tqdm(range(num_batches), desc="UTMOS", total=num_batches):
start_idx = batch_idx * self.batch_size
end_idx = min(start_idx + self.batch_size, len(audio_files))
batch_files = audio_files[start_idx:end_idx]

# Create temp directory for this batch
with tempfile.TemporaryDirectory() as temp_dir:
# Copy batch files to temp directory with indices
for i, audio_file in enumerate(batch_files):
temp_name = f"audio_{i:06d}.wav"
temp_path = os.path.join(temp_dir, temp_name)
shutil.copy(audio_file, temp_path)

# Batch prediction
results = self.model.predict(
input_dir=temp_dir,
device=self.device,
batch_size=self.batch_size,
num_workers=0,
verbose=False
)

# Extract scores for this batch
batch_scores = []
for i in range(len(batch_files)):
temp_name = f"audio_{i:06d}.wav"
score = None
for result in results:
if result.get('file_path', '').endswith(temp_name):
score = result.get('predicted_mos')
break
batch_scores.append(score)

all_scores.extend(batch_scores)

return {self.name: all_scores}
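For reference, a possible way to invoke the metric above (the `.wav` paths, batch size, and task/model names are placeholders):

```python
# Hypothetical invocation of UTMOSMetric; the audio paths and names are placeholders.
from metrics.utmos import UTMOSMetric

metric = UTMOSMetric(batch_size=4)
overall = metric(
    candidates=["outputs/sample_000.wav", "outputs/sample_001.wav"],  # TTS-generated audio
    references=["hello world", "good morning"],  # ground-truth text, unused by UTMOS
    task_name="tts",
    model_name="Sonic-v3",
)
print(overall)  # e.g. {"utmos": 3.87} -- mean predicted MOS, rounded to 3 decimals
```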
57 changes: 53 additions & 4 deletions models/model.py
@@ -177,6 +177,7 @@ async def generate_text_with_retry(
"""
# Create a new error tracker instance for this specific call
call_errors = ErrorTracker()

result = None
try:
async for attempt in AsyncRetrying(
@@ -187,11 +188,15 @@
):
with attempt:
try:
# All data prep is now in _generate_text
# Set attempt number for downstream logging
# All data prep is now in _generate_text or _generate_audio
self.req_resp_hndlr.current_attempt = attempt.retry_state.attempt_number
# Pass the error tracker to _generate_text
result: ModelResponse = await self._generate_text(message, run_params, call_errors)

# Route to appropriate generator based on inference type
if self.inference_type in (constants.CARTESIA_TTS, constants.ELEVENLABS_TTS, constants.DEEPGRAM_TTS):
result: ModelResponse = await self._generate_audio(message, run_params, call_errors)
else:
result: ModelResponse = await self._generate_text(message, run_params, call_errors)

# Ensure the result has our error tracker
if not result.error_tracker:
result.error_tracker = call_errors
@@ -254,6 +259,50 @@ async def generate_text_with_retry(
await self._mark_errors(result, call_errors)
return result

async def _generate_audio(self, message: dict, run_params: dict,
error_tracker: ErrorTracker) -> ModelResponse:
"""Generate audio from text using TTS provider.

Args:
message: Input message containing ground_truth_text
run_params: Runtime parameters for the inference request
error_tracker: Error tracker for this call

Returns:
ModelResponse: Response object containing audio file path and metadata
"""
text = message.get("ground_truth_text", "")

if not text:
logger.error("[Model.generate_audio] No ground_truth_text found in message")
return ModelResponse(
input_prompt="",
llm_response="",
raw_response="Missing ground_truth_text",
response_code=500,
performance=None,
wait_time=0,
error_tracker=error_tracker
)

try:
# Pass text to request handler
result = await self.req_resp_hndlr.request_server(
{"text": text},
error_tracker=error_tracker
)
return result
except Exception as e:
logger.error("[Model.generate_audio] TTS generation failed: %s", e)
return ModelResponse(
input_prompt=text,
llm_response="",
raw_response=str(e),
response_code=500,
performance=None,
wait_time=0,
error_tracker=error_tracker
)

async def _generate_text(self, message: dict, run_params: dict, error_tracker: ErrorTracker = None) -> ModelResponse:
"""
6 changes: 3 additions & 3 deletions models/model_response.py
@@ -18,11 +18,11 @@ class Performance(BaseModel):
"""Python object for wrapping performance info from a model."""

latency: float
prompt_tokens: float
response_tokens: float
prompt_tokens: float | None
response_tokens: float | None
reasoning_tokens: float | None
time_per_token: float | None
relative_output_tokens: float
relative_output_tokens: float | None


class ErrorTracker(BaseModel):
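Relaxing the token fields above lets audio-only responses (e.g. from the TTS providers) populate `Performance` with latency but no token counts. An illustrative construction, with invented values:

```python
# Illustrative only: a Performance record for a TTS response, where token-based
# fields do not apply and are left as None. All values are invented for the example.
from models.model_response import Performance

tts_perf = Performance(
    latency=1.42,               # seconds taken to synthesize the audio
    prompt_tokens=None,         # TTS requests are not token-accounted here
    response_tokens=None,
    reasoning_tokens=None,
    time_per_token=None,
    relative_output_tokens=None,
)
```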