diff --git a/.gitignore b/.gitignore
index c7eaf9a..7667076 100644
--- a/.gitignore
+++ b/.gitignore
@@ -35,3 +35,4 @@ data/spider/
 test/
 CLAUDE.md
+.claude/
diff --git a/README.md b/README.md
index fe1d7fd..75a5235 100644
--- a/README.md
+++ b/README.md
@@ -278,6 +278,17 @@ models:
     auth_token: ${AUTH_TOKEN} # Mandatory
     batch_size: 150 # Mandatory
     chunk_size: 30 # Optional - Max audio length in seconds
+
+  - name: "Sonic-v3"
+    inference_type: "cartesia_tts"
+    model: "sonic-3"
+    auth_token: "${AUTH_TOKEN}"
+    delay: 100
+    retry_attempts: 8
+    timeout: 30
+    batch_size: 8
+    chunk_size: 30
+    voice_id: "6ccbfb76-1fc6-48f7-b71d-91ac6298247b" # Required for any cartesia_tts or elevenlabs_tts inference type
 ```
 
 **Note**: Batch-size proportional dataset sharding is implemented when multiple endpoints of the same model are provided. Be sure to have unique 'name' attributes for each unique endpoint, as shown above
@@ -289,6 +300,9 @@ models:
 | "openai" | AsyncAzureOpenAI (Chat Completions) |
 | "vllm" | AsyncOpenAI (Chat Completions) |
 | "transcription" | AsyncOpenAI (Transcriptions) |
+| "cartesia_tts" | AsyncCartesia (Text-to-Speech) |
+| "deepgram_tts" | Deepgram (Text-to-Speech) |
+| "elevenlabs_tts" | AsyncElevenLabs (Text-to-Speech) |
 
 #### Judge Configuration
 LLM-Judge setup is required to run any tasks requiring LLM-judge metrics. For specific task-metric pair compatibility, visit [Task Documentation](./tasks/README.md) and [Metric Documentation](./metrics/README.md).
diff --git a/metrics/README.md b/metrics/README.md
index a3a5796..2d27088 100644
--- a/metrics/README.md
+++ b/metrics/README.md
@@ -26,6 +26,7 @@ For more detailed documentation regarding which metrics can be used for which ta
 | `sql_score` (↑) | SQL correctness and execution match | text2sql_score |
 | `instruction_following` (↑) | LLM-judged instruction following capability | final |
 | `gsm8k_exact_match` (↑) | Exact-match accuracy of the final numerical answer. | gsm8k_exact_match |
+| `utmos` (↑) | UTMOSv2-based audio quality evaluation for TTS | utmos |
 
 ---
 
@@ -156,4 +157,12 @@ For more detailed documentation regarding which metrics can be used for which ta
 - **Type**: Math correctness metric
 - **Description**: Measure the exact-match accuracy of the final numerical answer (expected within `\boxed{}`) with the reference numerical answer.
 - **Scoring (record-level)** Score between `0` and `100`, higher is better.
-- **Used In**: Math (`gsm8k`)
\ No newline at end of file
+- **Used In**: Math (`gsm8k`)
+
+---
+
+### `utmos`
+- **Type**: TTS audio quality metric
+- **Description**: Evaluates the quality of TTS-generated audio using the UTMOSv2 (UTokyo-SaruLab MOS v2) model, which predicts the naturalness and quality of synthesized speech.
+- **Scoring (record-level)** Score between `0` and `5` (MOS scale), higher is better.
+- **Used In**: Text-to-Speech (`tts`)
\ No newline at end of file
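
Editor's note: for context on the `utmos` metric documented above, here is a minimal standalone sketch of UTMOSv2 scoring, assuming the `utmosv2` package pinned in requirements.txt below; `sample.wav` is a placeholder path, not a file from this repo.

```python
# Minimal sketch of UTMOSv2 scoring outside the harness.
# Assumes: utmosv2 installed from the sarulab-speech repo; "sample.wav" is a placeholder.
import utmosv2

model = utmosv2.create_model(pretrained=True)  # downloads pretrained weights on first use
mos = model.predict(input_path="sample.wav")   # single-file prediction returns a MOS value
print(f"Predicted MOS: {mos:.2f}")             # MOS scale 1-5, higher is better
```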
diff --git a/metrics/utmos.py b/metrics/utmos.py
new file mode 100644
index 0000000..e54d6d9
--- /dev/null
+++ b/metrics/utmos.py
@@ -0,0 +1,133 @@
+"""UTMOSv2 metric for TTS audio quality evaluation."""
+
+import logging
+import torch
+import utmosv2
+from metrics.metrics import Metrics
+from utils.custom_logging import write_record_log, append_final_score
+from utils import util
+import tempfile
+import shutil
+import os
+import warnings
+from tqdm import tqdm
+
+logging.getLogger("timm").setLevel(logging.WARNING)
+logging.getLogger("transformers").setLevel(logging.WARNING)
+warnings.filterwarnings("ignore", category=FutureWarning, module="utmosv2")
+warnings.filterwarnings("ignore", message=".*pin_memory.*")
+warnings.filterwarnings("ignore", message=".*CUDA is not available.*")
+
+logger = logging.getLogger(__name__)
+
+
+class UTMOSMetric(Metrics):
+    """UTMOSv2 metric for evaluating TTS audio quality."""
+
+    def __init__(self, batch_size=1):
+        """Initialize UTMOSv2 metric.
+
+        Args:
+            batch_size: Number of audio files to process in parallel (default: 1)
+        """
+        super().__init__()
+        self.name = "utmos"
+        self.record_level_scores = None
+        self.batch_size = batch_size
+
+        # Load model once and reuse for all evaluations
+        logger.info("[UTMOSMetric] Loading UTMOSv2 model...")
+        self.model = utmosv2.create_model(pretrained=True)
+
+        # Determine device
+        if torch.backends.mps.is_available():
+            self.device = 'mps'
+        elif torch.cuda.is_available():
+            self.device = 'cuda'
+        else:
+            self.device = 'cpu'
+
+        logger.info(f"[UTMOSMetric] Model loaded on {self.device} with batch_size={batch_size}")
+
+    def __call__(self, candidates, references, instructions=None, *,
+                 task_name: str | None = None, model_name: str | None = None,
+                 model_responses=None):
+        """Compute UTMOSv2 scores for TTS-generated audio.
+
+        Args:
+            candidates: List of audio file paths (from TTS generation)
+            references: List of ground truth text (not used for UTMOS)
+            instructions: Optional instructions
+            task_name: Name of the task
+            model_name: Name of the model
+            model_responses: Raw model responses
+
+        Returns:
+            Dictionary with overall UTMOS score
+        """
+        self.instructions = instructions
+
+        # Compute UTMOS scores for each audio file
+        self.record_level_scores = self.compute_record_level_scores(candidates)
+
+        # Calculate mean UTMOS
+        scores = self.record_level_scores.get(self.name, [])
+        valid_scores = [score for score in scores if score is not None]
+        mean_utmos = sum(valid_scores) / len(valid_scores) if valid_scores else 0.0
+        overall_score = {self.name: util.smart_round(mean_utmos, 3)}
+
+        if task_name and model_name:
+            write_record_log(self, references, candidates, scores,
+                             task_name, model_name, instructions=self.instructions)
+            append_final_score(self, overall_score, task_name, model_name)
+
+        return overall_score
+
+    def compute_record_level_scores(self, audio_files: list) -> dict[str, list]:
+        """Compute UTMOSv2 scores for each audio file in batches.
+
+        Args:
+            audio_files: List of paths to generated audio files
+
+        Returns:
+            Dictionary with UTMOS scores for each audio file
+        """
+        num_batches = (len(audio_files) + self.batch_size - 1) // self.batch_size
+        all_scores = []
+
+        for batch_idx in tqdm(range(num_batches), desc="UTMOS", total=num_batches):
+            start_idx = batch_idx * self.batch_size
+            end_idx = min(start_idx + self.batch_size, len(audio_files))
+            batch_files = audio_files[start_idx:end_idx]
+
+            # Create temp directory for this batch
+            with tempfile.TemporaryDirectory() as temp_dir:
+                # Copy batch files to temp directory with indices
+                for i, audio_file in enumerate(batch_files):
+                    temp_name = f"audio_{i:06d}.wav"
+                    temp_path = os.path.join(temp_dir, temp_name)
+                    shutil.copy(audio_file, temp_path)
+
+                # Batch prediction
+                results = self.model.predict(
+                    input_dir=temp_dir,
+                    device=self.device,
+                    batch_size=self.batch_size,
+                    num_workers=0,
+                    verbose=False
+                )
+
+                # Extract scores for this batch
+                batch_scores = []
+                for i in range(len(batch_files)):
+                    temp_name = f"audio_{i:06d}.wav"
+                    score = None
+                    for result in results:
+                        if result.get('file_path', '').endswith(temp_name):
+                            score = result.get('predicted_mos')
+                            break
+                    batch_scores.append(score)
+
+                all_scores.extend(batch_scores)
+
+        return {self.name: all_scores}
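
Editor's note: a hedged usage sketch of the `UTMOSMetric` class above. The wav paths are placeholders; without `task_name`/`model_name` no record logs are written, so the call simply returns the mean score.

```python
# Illustrative call matching UTMOSMetric.__call__ above; file paths are placeholders.
from metrics.utmos import UTMOSMetric

metric = UTMOSMetric(batch_size=4)
overall = metric(
    candidates=["out/gen_000.wav", "out/gen_001.wav"],  # TTS output paths
    references=["hello world", "good morning"],         # reference text (unused by UTMOS)
)
print(overall)  # e.g. {'utmos': 3.412}
```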
diff --git a/models/model.py b/models/model.py
index 9db6b2a..786fcc4 100644
--- a/models/model.py
+++ b/models/model.py
@@ -177,6 +177,7 @@ async def generate_text_with_retry(
     """
     # Create a new error tracker instance for this specific call
     call_errors = ErrorTracker()
+    result = None
 
     try:
         async for attempt in AsyncRetrying(
@@ -187,11 +188,15 @@
         ):
             with attempt:
                 try:
-                    # All data prep is now in _generate_text
-                    # Set attempt number for downstream logging
+                    # All data prep is now in _generate_text or _generate_audio
                     self.req_resp_hndlr.current_attempt = attempt.retry_state.attempt_number
-                    # Pass the error tracker to _generate_text
-                    result: ModelResponse = await self._generate_text(message, run_params, call_errors)
+
+                    # Route to appropriate generator based on inference type
+                    if self.inference_type in (constants.CARTESIA_TTS, constants.ELEVENLABS_TTS, constants.DEEPGRAM_TTS):
+                        result: ModelResponse = await self._generate_audio(message, run_params, call_errors)
+                    else:
+                        result: ModelResponse = await self._generate_text(message, run_params, call_errors)
+
                     # Ensure the result has our error tracker
                     if not result.error_tracker:
                         result.error_tracker = call_errors
@@ -254,6 +259,50 @@
         await self._mark_errors(result, call_errors)
         return result
 
+    async def _generate_audio(self, message: dict, run_params: dict,
+                              error_tracker: ErrorTracker) -> ModelResponse:
+        """Generate audio from text using TTS provider.
+
+        Args:
+            message: Input message containing ground_truth_text
+            run_params: Runtime parameters for the inference request
+            error_tracker: Error tracker for this call
+
+        Returns:
+            ModelResponse: Response object containing audio file path and metadata
+        """
+        text = message.get("ground_truth_text", "")
+
+        if not text:
+            logger.error("[Model.generate_audio] No ground_truth_text found in message")
+            return ModelResponse(
+                input_prompt="",
+                llm_response="",
+                raw_response="Missing ground_truth_text",
+                response_code=500,
+                performance=None,
+                wait_time=0,
+                error_tracker=error_tracker
+            )
+
+        try:
+            # Pass text to request handler
+            result = await self.req_resp_hndlr.request_server(
+                {"text": text},
+                error_tracker=error_tracker
+            )
+            return result
+        except Exception as e:
+            logger.error("[Model.generate_audio] TTS generation failed: %s", e)
+            return ModelResponse(
+                input_prompt=text,
+                llm_response="",
+                raw_response=str(e),
+                response_code=500,
+                performance=None,
+                wait_time=0,
+                error_tracker=error_tracker
+            )
+
     async def _generate_text(self, message: dict, run_params: dict,
                              error_tracker: ErrorTracker = None) -> ModelResponse:
         """
diff --git a/models/model_response.py b/models/model_response.py
index c460a37..fb1d445 100644
--- a/models/model_response.py
+++ b/models/model_response.py
@@ -18,11 +18,11 @@ class Performance(BaseModel):
     """Python object for wrapping performance info from a model."""
 
     latency: float
-    prompt_tokens: float
-    response_tokens: float
+    prompt_tokens: float | None
+    response_tokens: float | None
     reasoning_tokens: float | None
     time_per_token: float | None
-    relative_output_tokens: float
+    relative_output_tokens: float | None
 
 
 class ErrorTracker(BaseModel):
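
Editor's note: the `Performance` fields become optional because TTS responses carry no token accounting. A hedged illustration of a TTS-side instance; the values are made up.

```python
# Illustrative: a TTS response has latency but no token counts,
# which is why the token-related fields are now Optional.
from models.model_response import Performance

tts_perf = Performance(
    latency=1.8,                   # seconds to synthesize
    prompt_tokens=None,            # no tokenized prompt for TTS
    response_tokens=None,          # output is audio bytes, not tokens
    reasoning_tokens=None,
    time_per_token=None,
    relative_output_tokens=None,
)
```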
diff --git a/models/request_resp_handler.py b/models/request_resp_handler.py
index 580ddcb..c1cde03 100644
--- a/models/request_resp_handler.py
+++ b/models/request_resp_handler.py
@@ -3,9 +3,15 @@
 import re
 import time
 import inspect
-
+import tempfile
+import numpy as np
+import soundfile as sf
+import os
 import httpx
 from openai import AsyncAzureOpenAI, AsyncOpenAI
+from cartesia import AsyncCartesia
+from elevenlabs.client import AsyncElevenLabs
+from deepgram import AsyncDeepgramClient
 
 from models.model_response import ModelResponse, ErrorTracker
 from utils import constants
@@ -153,6 +159,17 @@ def set_client(self, verify_ssl: bool, timeout: int):
                     http_client=httpx.AsyncClient(verify=verify_ssl),
                 )
             )
+        elif self.inference_type == constants.CARTESIA_TTS:
+            # Cartesia TTS client
+            self.client = AsyncCartesia(api_key=self.auth)
+
+        elif self.inference_type == constants.ELEVENLABS_TTS:
+            # ElevenLabs TTS client
+            self.client = AsyncElevenLabs(api_key=self.auth)
+
+        elif self.inference_type == constants.DEEPGRAM_TTS:
+            # Deepgram TTS async client
+            self.client = AsyncDeepgramClient(api_key=self.auth)
 
     def validated_safe_generation_params(self, generation_params):
         """Validate and sanitize generation parameters for the OpenAI API client.
@@ -178,7 +195,122 @@
         safe_params['max_completion_tokens'] = safe_params.get('max_completion_tokens',
                                                                constants.DEFAULT_MAX_COMPLETION_TOKENS)
         return safe_params
-
+
+    async def request_tts_server(self, text: str, model_name: str, voice_id: str,
+                                 start_time: float, error_tracker: ErrorTracker) -> ModelResponse:
+        """Helper function for TTS request processing.
+
+        Args:
+            text: Text to convert to speech
+            model_name: TTS model name
+            voice_id: Voice ID for TTS
+            start_time: Request start time for performance tracking
+            error_tracker: Error tracker for this call
+
+        Returns:
+            ModelResponse: Response object with audio file path
+        """
+        if self.inference_type == constants.CARTESIA_TTS:
+            # Cartesia TTS generation
+            bytes_iter = self.client.tts.bytes(
+                model_id=model_name,
+                transcript=text,
+                voice={"mode": "id", "id": voice_id},
+                output_format={
+                    "container": "wav",
+                    "sample_rate": 16000,
+                    "encoding": "pcm_s16le",
+                }
+            )
+
+            # Save to temp file
+            with tempfile.NamedTemporaryFile(delete=False, suffix=".wav", mode='wb') as f:
+                async for chunk in bytes_iter:
+                    f.write(chunk)
+                audio_path = f.name
+
+            elapsed_time = time.time() - start_time
+
+            # Get file size for response
+            audio_bytes_len = os.path.getsize(audio_path)
+
+            return ModelResponse(
+                input_prompt=text,
+                llm_response=audio_path,
+                raw_response={"audio_bytes": audio_bytes_len, "audio_path": audio_path},
+                response_code=200,
+                performance=None,
+                wait_time=elapsed_time,
+                error_tracker=error_tracker,
+            )
+
+        elif self.inference_type == constants.ELEVENLABS_TTS:
+            # ElevenLabs TTS generation
+            audio = self.client.text_to_speech.convert(
+                text=text,
+                voice_id=voice_id,
+                model_id=model_name,
+                output_format="pcm_16000"
+            )
+
+            # Collect PCM chunks
+            pcm_chunks = []
+            async for chunk in audio:
+                if isinstance(chunk, bytes):
+                    pcm_chunks.append(chunk)
+
+            pcm_data = b''.join(pcm_chunks)
+
+            # Convert PCM to WAV with headers using soundfile
+            audio_array = np.frombuffer(pcm_data, dtype=np.int16)
+
+            with tempfile.NamedTemporaryFile(delete=False, suffix=".wav") as f:
+                sf.write(f.name, audio_array, 16000, format='WAV')
+                audio_path = f.name
+
+            elapsed_time = time.time() - start_time
+            return ModelResponse(
+                input_prompt=text,
+                llm_response=audio_path,
+                raw_response={"audio_bytes": len(pcm_data), "audio_path": audio_path},
+                response_code=200,
+                performance=None,
+                wait_time=elapsed_time,
+                error_tracker=error_tracker,
+            )
+
+        elif self.inference_type == constants.DEEPGRAM_TTS:
+            # Deepgram TTS generation
+            response = self.client.speak.v1.audio.generate(
+                text=text,
+                model=model_name,
+            )
+
+            # response is an async generator, collect chunks
+            audio_chunks = []
+            async for chunk in response:
+                if isinstance(chunk, bytes):
+                    audio_chunks.append(chunk)
+
+            audio_data = b''.join(audio_chunks)
+
+            # Write to temporary WAV file
+            with tempfile.NamedTemporaryFile(delete=False, suffix=".wav", mode='wb') as f:
+                f.write(audio_data)
+                audio_path = f.name
+
+            elapsed_time = time.time() - start_time
+
+            return ModelResponse(
+                input_prompt=text,
+                llm_response=audio_path,
+                raw_response={"audio_bytes": len(audio_data), "audio_path": audio_path},
+                response_code=200,
+                performance=None,
+                wait_time=elapsed_time,
+                error_tracker=error_tracker,
+            )
+
     async def request_server(self, msg_body, tools=None, error_tracker: ErrorTracker = None) -> ModelResponse:
         """Send a request to the inference server and return a `Model Response`.
@@ -192,9 +324,17 @@
         start_time = time.time()
 
         # Re-create a fresh client for this request to avoid closed-loop issues
-        self.set_client(verify_ssl=True, timeout=self.timeout)
+        # Skip for TTS clients - they should be reused to avoid file descriptor leaks
+        if self.inference_type not in (constants.CARTESIA_TTS, constants.ELEVENLABS_TTS, constants.DEEPGRAM_TTS):
+            self.set_client(verify_ssl=True, timeout=self.timeout)
+
         try:
-            if self.inference_type == constants.OPENAI_CHAT_COMPLETION or self.inference_type == constants.INFERENCE_SERVER_VLLM_CHAT_COMPLETION:
+            # Handle TTS requests
+            if self.inference_type in (constants.CARTESIA_TTS, constants.ELEVENLABS_TTS, constants.DEEPGRAM_TTS):
+                text = msg_body.get("text")
+                voice_id = self.model_info.get("voice_id")
+                return await self.request_tts_server(text, model_name, voice_id, start_time, error_tracker)
+            elif self.inference_type == constants.OPENAI_CHAT_COMPLETION or self.inference_type == constants.INFERENCE_SERVER_VLLM_CHAT_COMPLETION:
                 # openai chat completions, vllm chat completions
                 self.generation_params = self.validated_safe_generation_params(self.generation_params)
                 prediction = await self.client.chat.completions.create(
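
Editor's note: the ElevenLabs branch above wraps headerless PCM in a WAV container before scoring. A standalone sketch of that conversion, assuming 16 kHz mono 16-bit little-endian PCM (what the `pcm_16000` output format requests); the function name and paths are illustrative.

```python
# Standalone sketch of the PCM -> WAV step used in the ElevenLabs branch.
import numpy as np
import soundfile as sf

def pcm16_to_wav(pcm_data: bytes, out_path: str, sample_rate: int = 16000) -> None:
    """Wrap headerless 16-bit PCM bytes in a WAV container."""
    audio = np.frombuffer(pcm_data, dtype=np.int16)
    sf.write(out_path, audio, sample_rate, format="WAV", subtype="PCM_16")

pcm16_to_wav(b"\x00\x00" * 16000, "silence.wav")  # one second of silence at 16 kHz
```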
+""" + +import logging +from typing import Dict, List, Any +import numpy as np +from tqdm import tqdm +from datasets import Dataset +from preprocessors.base import Preprocessor + +logger = logging.getLogger(__name__) + + +class TtsPreprocessor(Preprocessor): + """Preprocessor for TTS tasks - extracts text for generation.""" + + def process(self, dataset: Dataset, task_config: Dict[str, Any], + run_config: Dict[str, Any]) -> List[Dict[str, Any]]: + """Extract text from dataset for TTS generation. + + Args: + dataset: The task dataset to pre-process + task_config: Dictionary containing task configuration parameters + run_config: Dictionary containing run configuration parameters + + Returns: + List of dictionaries where each dictionary represents a pre-processed sample + """ + + # Get dataset info + dataset_keys = list(dataset.features.keys()) + dataset_size = len(dataset) + self.log_dataset_info(dataset_keys, dataset_size) + + # Get dataset filters + length_filter, num_samples_filter = self.get_dataset_filters( + run_config.get('filter', None), dataset_size + ) + + processed_data = [] + sample_count = 0 + + for i, row in enumerate(tqdm(dataset, desc="Processing TTS samples")): + record = {k: row[k] for k in dataset_keys} + + # Find text column (priority order: normalized_text -> text_normalized -> text) + text = None + for key in ['normalized_text', 'text_normalized', 'text']: + if key in record: + text = record[key] + break + + if not text: + logger.warning("No text column found in sample %d, skipping", i) + continue + + # Store text in standardized column + record["ground_truth_text"] = text + + # Placeholder audio (not used for TTS generation) + record["array"] = np.array([]) + record["sampling_rate"] = 16000 + record["instruction"] = "" # Not used for TTS + record["model_target"] = text # For compatibility with postprocessor + + # Apply sample count filter + if num_samples_filter and sample_count >= num_samples_filter: + break + + processed_data.append(record) + sample_count += 1 + + self.log_dataset_info(dataset_keys, dataset_size, sample_count) + + return processed_data diff --git a/requirements.txt b/requirements.txt index ba34d86..f84f58d 100644 --- a/requirements.txt +++ b/requirements.txt @@ -18,6 +18,7 @@ sacrebleu==2.5.1 nltk==3.9.1 bert-score unbabel-comet +torch # Misc utilities PyYAML==6.0.2 @@ -29,7 +30,7 @@ setuptools==80.9.0 pillow==11.1.0 logger==1.4 Jinja2==3.1.5 -librosa==0.10.0 +librosa==0.10.2 jiwer==4.0.0 num2words==0.5.14 jaconv==0.4.0 @@ -51,4 +52,10 @@ nest_asyncio==1.6.0 immutabledict==4.2.1 # Debugging -debugpy==1.8.0 # VSCode debugging support \ No newline at end of file +debugpy==1.8.0 # VSCode debugging support + +# TTS support +cartesia==2.0.15 +elevenlabs==2.22.0 +deepgram-sdk==5.3.0 +utmosv2 @ git+https://github.com/sarulab-speech/UTMOSv2.git \ No newline at end of file diff --git a/tasks/speech_recognition/asr/gigaspeech2/gigaspeech2_th_test.yaml b/tasks/speech_recognition/asr/gigaspeech2/gigaspeech2_th_test.yaml index ad090c9..488bfee 100644 --- a/tasks/speech_recognition/asr/gigaspeech2/gigaspeech2_th_test.yaml +++ b/tasks/speech_recognition/asr/gigaspeech2/gigaspeech2_th_test.yaml @@ -1,4 +1,4 @@ task_name: gigaspeech2_th_test -extends: ["../base.yaml#"] +extends: ["./base.yaml#"] subset: th-test language: th \ No newline at end of file diff --git a/tasks/speech_recognition/asr/gigaspeech2/gigaspeech2_vi_test.yaml b/tasks/speech_recognition/asr/gigaspeech2/gigaspeech2_vi_test.yaml index 6a1e74c..57dd3c9 100644 --- 
diff --git a/tasks/speech_recognition/asr/gigaspeech2/gigaspeech2_vi_test.yaml b/tasks/speech_recognition/asr/gigaspeech2/gigaspeech2_vi_test.yaml
index 6a1e74c..57dd3c9 100644
--- a/tasks/speech_recognition/asr/gigaspeech2/gigaspeech2_vi_test.yaml
+++ b/tasks/speech_recognition/asr/gigaspeech2/gigaspeech2_vi_test.yaml
@@ -1,4 +1,4 @@
 task_name: gigaspeech2_vi_test
-extends: ["../base.yaml#"]
+extends: ["./base.yaml#"]
 subset: vi-test
 language: vi
\ No newline at end of file
diff --git a/tasks/speech_recognition/asr/mnsc/base.yaml b/tasks/speech_recognition/asr/mnsc/base.yaml
index 4f15b8d..a24dcc3 100644
--- a/tasks/speech_recognition/asr/mnsc/base.yaml
+++ b/tasks/speech_recognition/asr/mnsc/base.yaml
@@ -2,7 +2,7 @@ extends: ["../base.yaml#"]
 
 dataset_path: AudioLLMs/Multitask-National-Speech-Corpus-v1-extend
 language: en
-split: test
+split: train
 accented: true
 audio_column: context
 target_column: answer
diff --git a/tasks/tts/speech_synthesis/base.yaml b/tasks/tts/speech_synthesis/base.yaml
new file mode 100644
index 0000000..967d952
--- /dev/null
+++ b/tasks/tts/speech_synthesis/base.yaml
@@ -0,0 +1,11 @@
+# Base configuration for TTS (Text-to-Speech) tasks
+preprocessor: TtsPreprocessor
+postprocessor: TtsPostprocessor
+split: test
+long_audio_processing_logic: truncate
+
+generation_kwargs:
+  temperature: 0.2
+
+metrics:
+  - metric: utmos
diff --git a/tasks/tts/speech_synthesis/libritts/libritts_test.yaml b/tasks/tts/speech_synthesis/libritts/libritts_test.yaml
new file mode 100644
index 0000000..4465344
--- /dev/null
+++ b/tasks/tts/speech_synthesis/libritts/libritts_test.yaml
@@ -0,0 +1,6 @@
+task_name: libritts_test
+extends: ["../base.yaml#"]
+dataset_path: "mythicinfinity/libritts"
+subset: "clean"
+split: "test.clean"
+language: "en"
diff --git a/tasks/tts/speech_synthesis/libritts/libritts_test_other.yaml b/tasks/tts/speech_synthesis/libritts/libritts_test_other.yaml
new file mode 100644
index 0000000..92160a9
--- /dev/null
+++ b/tasks/tts/speech_synthesis/libritts/libritts_test_other.yaml
@@ -0,0 +1,8 @@
+task_name: libritts_test_other
+extends: ["../base.yaml#"]
+dataset_path: "mythicinfinity/libritts"
+subset: "other"
+split: "test.other"
+language: "en"
+generation_kwargs:
+  temperature: 0.3
diff --git a/utils/constants.py b/utils/constants.py
index 6703ae5..849b7a2 100644
--- a/utils/constants.py
+++ b/utils/constants.py
@@ -22,6 +22,9 @@
 INFERENCE_SERVER_VLLM_CHAT_COMPLETION = 'vllm'
 OPENAI_CHAT_COMPLETION = 'openai'
 TRANSCRIPTION = 'transcription'
+CARTESIA_TTS = 'cartesia_tts'
+ELEVENLABS_TTS = 'elevenlabs_tts'
+DEEPGRAM_TTS = 'deepgram_tts'
 
 # WER/CER metrics constants
 # Define WER/CER related constants
@@ -70,6 +73,7 @@
     "word_error_rate": ("metrics.word_error_rate_metrics", "WERMetrics"),
     "comet": ("metrics.comet_score", "CometScore"),
     "mt_bench_llm_judge": ("metrics.llm_judge", "MtbenchLLMJudgeMetric"),
+    "utmos": ("metrics.utmos", "UTMOSMetric"),
 }
 
 
@@ -140,6 +144,9 @@
     'safety': ['llm_judge_redteaming'],
     'spoofing': ['llm_judge_detailed', 'llm_judge_binary'],
     'mtbench': ['mt_bench_llm_judge'],
+
+    # Text-to-Speech
+    'tts': ['utmos'],
 }
 
 metric_output = {
@@ -157,7 +164,8 @@
     "sql_score": ["sql_score"],  # need to find real metric
     "instruction_following": ["strict_instruction", "loose_instruction", "final"],
     "diarization_metrics": ["average_sample_wder", "overall_wder", "average_sample_cpwer", "overall_cpwer", "speaker_count_absolute_error"],
-    "comet": ["comet"]
+    "comet": ["comet"],
+    "utmos": ["utmos"]
 }
 
 # Dictionary mapping language names to their standard codes
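
Editor's note: the `"utmos"` entry above follows the registry's lazy-import convention of `(module_path, class_name)` tuples. A hedged sketch of how such an entry resolves; the loader function is illustrative, not the harness's actual lookup code.

```python
# Sketch of resolving a registry entry like the "utmos" tuple added above.
# The load_metric helper is hypothetical; the harness's real lookup may differ.
import importlib

def load_metric(entry: tuple[str, str], **kwargs):
    module_path, class_name = entry
    metric_cls = getattr(importlib.import_module(module_path), class_name)
    return metric_cls(**kwargs)

utmos_metric = load_metric(("metrics.utmos", "UTMOSMetric"), batch_size=4)
```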
diff --git a/utils/util.py b/utils/util.py
index 4dbbe0d..2fee1fe 100644
--- a/utils/util.py
+++ b/utils/util.py
@@ -219,10 +219,22 @@ def _validate_models(config: Dict) -> None:
         ValueError: If the models section is invalid
     """
     def validate_required_fields(info: Dict, index: int) -> None:
-        required_fields = ['name', 'model', 'inference_type', 'url']
+        # Base required fields for all models
+        required_fields = ['name', 'model', 'inference_type']
+
+        # URL is only required for non-TTS inference types
+        inference_type = info.get('inference_type')
+        if inference_type not in ['cartesia_tts', 'elevenlabs_tts', 'deepgram_tts']:
+            required_fields.append('url')
+
         for field in required_fields:
             if not info.get(field) or not isinstance(info[field], str) or not info[field].strip():
                 raise ValueError(f"Model {index}: '{field}' must be a non-empty string")
+
+        # Require voice_id for TTS inference types
+        if inference_type in ['cartesia_tts', 'elevenlabs_tts']:
+            if not info.get('voice_id') or not isinstance(info['voice_id'], str):
+                raise ValueError(f"Model {index}: 'voice_id' is required for TTS inference types")
 
     def validate_optional_fields(info: Dict, index: int) -> None:
         optional_fields = {
            'delay': int, 'retry_attempts': int, 'timeout': int,
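
Editor's note: to make the new validation rules concrete, a hedged illustration of model entries against `validate_required_fields` above. The dict literals and voice id are placeholders, and the remark about Deepgram voice selection is an assumption based on `deepgram_tts` being exempt from the `voice_id` check.

```python
# Illustrative model entries under the new validation rules; values are placeholders.
tts_ok = {
    "name": "Sonic-v3",
    "model": "sonic-3",
    "inference_type": "cartesia_tts",
    "voice_id": "placeholder-voice-id",  # required for cartesia_tts / elevenlabs_tts
}  # passes: 'url' is not required for TTS inference types

deepgram_ok = {
    "name": "Aura",
    "model": "aura-example-en",
    "inference_type": "deepgram_tts",
}  # also passes: deepgram_tts selects the voice via the model name, so no voice_id check

chat_ok = {
    "name": "gpt",
    "model": "gpt-4o",
    "inference_type": "openai",
    "url": "https://example.endpoint",  # still mandatory for non-TTS types
}
```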