Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
30 changes: 15 additions & 15 deletions Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

12 changes: 10 additions & 2 deletions crates/api/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,10 @@ use crate::{
StateStore,
},
billing::{get_billing_costs, BillingRouteState},
completions::{chat_completions, image_edits, image_generations, models, rerank, score},
completions::{
audio_transcriptions, chat_completions, image_edits, image_generations, models, rerank,
score,
},
conversations,
health::health_check,
models::{get_model_by_name, list_models, ModelsAppState},
Expand Down Expand Up @@ -47,6 +50,9 @@ use std::sync::Arc;
use tower_http::cors::{AllowOrigin, Any, CorsLayer};
use utoipa::OpenApi;

// Audio transcription file size limit (25 MB for OpenAI Whisper API compatibility)
const AUDIO_TRANSCRIPTION_MAX_BODY_SIZE: usize = 25 * 1024 * 1024; // 25 MB

/// Service initialization components
#[derive(Clone)]
pub struct AuthComponents {
Expand Down Expand Up @@ -872,13 +878,15 @@ pub fn build_completion_routes(
) -> Router {
use crate::routes::files::MAX_FILE_SIZE;

// Text-based inference routes (chat/completions, image generation)
// Text-based inference routes (chat/completions, image generation, audio transcription, rerank, score)
// Body limit is raised to 25 MB below (AUDIO_TRANSCRIPTION_MAX_BODY_SIZE) so multipart audio uploads fit; the JSON-only routes share the same cap
let text_inference_routes = Router::new()
.route("/chat/completions", post(chat_completions))
.route("/images/generations", post(image_generations))
.route("/audio/transcriptions", post(audio_transcriptions))
.route("/rerank", post(rerank))
.route("/score", post(score))
.layer(DefaultBodyLimit::max(AUDIO_TRANSCRIPTION_MAX_BODY_SIZE))
.with_state(app_state.clone())
.layer(from_fn_with_state(
usage_state.clone(),
Expand Down
191 changes: 191 additions & 0 deletions crates/api/src/models.rs
Original file line number Diff line number Diff line change
Expand Up @@ -389,6 +389,144 @@ pub struct ImageData {
pub revised_prompt: Option<String>,
}

// ========================================
// Audio Transcription
// ========================================

/// Audio transcription request schema, used only for OpenAPI documentation.
///
/// The live endpoint consumes `multipart/form-data`; this struct mirrors the
/// form fields so utoipa can render them in the generated spec. It is never
/// deserialized from an actual request body — see `AudioTranscriptionRequest`
/// for the runtime struct.
#[derive(Debug, Clone, Serialize, Deserialize, ToSchema)]
pub struct AudioTranscriptionRequestSchema {
    /// Audio file (required) - binary audio data.
    /// Typed as `String` purely as an OpenAPI placeholder for binary content.
    pub file: String, // Placeholder for binary data in OpenAPI

    /// Model identifier (required) - e.g. "openai/whisper-large-v3"
    pub model: String,

    /// Language code (optional) - e.g. "en", "es"
    #[serde(skip_serializing_if = "Option::is_none")]
    pub language: Option<String>,

    /// Response format (optional) - one of: "json", "text", "srt", "verbose_json", "vtt"
    #[serde(skip_serializing_if = "Option::is_none")]
    pub response_format: Option<String>,
}

/// Audio transcription request (internal runtime struct).
///
/// Populated by hand from the parsed `multipart/form-data` body rather than
/// via serde — hence the `#[serde(skip)]` on every form-derived field. Only
/// the optional tuning fields participate in (de)serialization.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct AudioTranscriptionRequest {
    /// Model identifier (required via form field)
    #[serde(skip)]
    pub model: String,

    /// Audio file bytes (required via form field)
    #[serde(skip)]
    pub file_bytes: Vec<u8>,

    /// Original filename as supplied in the multipart part
    /// (validated against path traversal in `validate`)
    #[serde(skip)]
    pub filename: String,

    /// Language code (optional, e.g. "en", "es")
    #[serde(skip_serializing_if = "Option::is_none")]
    pub language: Option<String>,

    /// Response format: "json", "text", "srt", "verbose_json", "vtt"
    #[serde(skip_serializing_if = "Option::is_none")]
    pub response_format: Option<String>,

    /// Sampling temperature (0-1)
    #[serde(skip_serializing_if = "Option::is_none")]
    pub temperature: Option<f32>,

    /// Timestamp granularities: "word", "segment"
    #[serde(skip_serializing_if = "Option::is_none")]
    pub timestamp_granularities: Option<Vec<String>>,
}

impl AudioTranscriptionRequest {
    /// Validate all request fields, returning the first failing check as a
    /// human-readable error message.
    ///
    /// Checks, in order:
    /// - `model` is non-empty
    /// - `file_bytes` is non-empty and within the 25 MB limit (OpenAI
    ///   Whisper API compatibility)
    /// - `filename` is a plain base name: no NUL bytes, no `/` or `\`
    ///   separators, no traversal components, at most 255 bytes, and has an
    ///   extension
    /// - `temperature` (if set) lies within [0, 1]
    /// - `response_format` (if set) is one of the supported formats
    /// - `timestamp_granularities` (if set) contains only "word"/"segment"
    ///
    /// # Errors
    /// Returns `Err(String)` describing the first invalid field.
    pub fn validate(&self) -> Result<(), String> {
        use std::path::Path;

        // Model identifier is mandatory.
        if self.model.trim().is_empty() {
            return Err("Model is required and cannot be empty".to_string());
        }

        // Audio payload is mandatory.
        if self.file_bytes.is_empty() {
            return Err("Audio file is required and cannot be empty".to_string());
        }

        // Validate file size (25 MB limit per OpenAI spec).
        const MAX_FILE_SIZE: usize = 25 * 1024 * 1024; // 25 MB
        if self.file_bytes.len() > MAX_FILE_SIZE {
            return Err(format!(
                "Audio file size exceeds maximum of 25 MB (got {} MB)",
                self.file_bytes.len() / (1024 * 1024)
            ));
        }

        if self.filename.is_empty() {
            return Err("Filename cannot be empty".to_string());
        }

        // Reject NUL bytes before any path parsing: a NUL can truncate paths
        // in C-based systems downstream.
        if self.filename.contains('\0') {
            return Err("Filename cannot contain null bytes".to_string());
        }

        // Reject BOTH separator styles explicitly. `Path::file_name` only
        // recognizes the separators of the OS this binary is compiled for, so
        // on a Unix server a Windows-style "..\\..\\evil.mp3" would otherwise
        // survive the `file_name` round-trip below and pass validation.
        if self.filename.contains('/') || self.filename.contains('\\') {
            return Err(
                "Filename cannot contain path components or traversal sequences".to_string(),
            );
        }

        // Strip any remaining path structure (e.g. "..", ".") via Path and
        // require that nothing was stripped — a changed value indicates a
        // traversal attempt.
        let safe_filename = Path::new(&self.filename)
            .file_name()
            .and_then(|n| n.to_str())
            .ok_or_else(|| {
                "Invalid filename: must be a valid UTF-8 filename without path components"
                    .to_string()
            })?;

        if safe_filename != self.filename {
            return Err(
                "Filename cannot contain path components or traversal sequences".to_string(),
            );
        }

        // Validate filename length (255 bytes per common filesystem
        // component limit).
        if safe_filename.len() > 255 {
            return Err("Filename exceeds maximum length of 255 characters".to_string());
        }

        // Require an extension so downstream format detection has a hint.
        if !safe_filename.contains('.') {
            return Err("Filename must have an extension (e.g., .mp3, .wav)".to_string());
        }

        // Validate temperature if provided.
        if let Some(temp) = self.temperature {
            if !(0.0..=1.0).contains(&temp) {
                return Err("Temperature must be between 0 and 1".to_string());
            }
        }

        // Validate response_format if provided.
        if let Some(format) = &self.response_format {
            let valid_formats = ["json", "text", "srt", "verbose_json", "vtt"];
            if !valid_formats.contains(&format.as_str()) {
                return Err(format!(
                    "Invalid response_format. Must be one of: {}",
                    valid_formats.join(", ")
                ));
            }
        }

        // Validate timestamp_granularities if provided; only "word" and
        // "segment" are meaningful per the OpenAI transcription API.
        if let Some(granularities) = &self.timestamp_granularities {
            for granularity in granularities {
                if granularity != "word" && granularity != "segment" {
                    return Err(format!(
                        "Invalid timestamp_granularity '{}'. Must be one of: word, segment",
                        granularity
                    ));
                }
            }
        }

        Ok(())
    }
}

// ========== Rerank Models ==========

/// Request for document reranking
Expand Down Expand Up @@ -439,6 +577,59 @@ impl RerankRequest {
}
}

/// Audio transcription response (with OpenAPI schema).
///
/// Optional fields are omitted from the serialized JSON when `None`;
/// presumably only `verbose_json` responses populate the timing fields —
/// TODO(review): confirm against the inference provider's behavior.
#[derive(Debug, Clone, Serialize, Deserialize, ToSchema)]
pub struct AudioTranscriptionResponse {
    /// Transcribed text
    pub text: String,

    /// Total audio duration in seconds
    #[serde(skip_serializing_if = "Option::is_none")]
    pub duration: Option<f64>,

    /// Detected language code
    #[serde(skip_serializing_if = "Option::is_none")]
    pub language: Option<String>,

    /// Transcription segments with timing
    #[serde(skip_serializing_if = "Option::is_none")]
    pub segments: Option<Vec<TranscriptionSegment>>,

    /// Word-level timing information
    #[serde(skip_serializing_if = "Option::is_none")]
    pub words: Option<Vec<TranscriptionWord>>,
}

/// Transcription segment with optional metadata fields.
///
/// Matches the inference_providers version to ensure consistency with actual
/// provider responses. Field meanings follow the Whisper-style segment
/// format — NOTE(review): timing units assumed to be seconds; confirm against
/// provider payloads.
#[derive(Debug, Clone, Serialize, Deserialize, ToSchema)]
pub struct TranscriptionSegment {
    /// Segment index within the transcription
    pub id: i32,
    /// Seek offset reported by the provider
    pub seek: i32,
    /// Segment start time (seconds, per Whisper convention)
    pub start: f64,
    /// Segment end time (seconds, per Whisper convention)
    pub end: f64,
    /// Transcribed text for this segment
    pub text: String,
    /// Token ids for this segment
    pub tokens: Vec<i32>,
    /// Sampling temperature the provider used for this segment
    pub temperature: f64,
    /// Optional: may be null in some provider responses
    #[serde(skip_serializing_if = "Option::is_none")]
    pub avg_logprob: Option<f64>,
    /// Optional: may be null in some provider responses
    #[serde(skip_serializing_if = "Option::is_none")]
    pub compression_ratio: Option<f64>,
    /// Optional: may be null in some provider responses
    #[serde(skip_serializing_if = "Option::is_none")]
    pub no_speech_prob: Option<f64>,
}

/// Word-level timing information for a single transcribed word.
#[derive(Debug, Clone, Serialize, Deserialize, ToSchema)]
pub struct TranscriptionWord {
    /// The transcribed word
    pub word: String,
    /// Word start time (seconds, per Whisper convention — confirm with provider)
    pub start: f64,
    /// Word end time (seconds, per Whisper convention — confirm with provider)
    pub end: f64,
}

/// Response from document reranking
#[derive(Debug, Serialize, Deserialize, ToSchema)]
pub struct RerankResponse {
Expand Down
4 changes: 4 additions & 0 deletions crates/api/src/openapi.rs
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@ use utoipa::{Modify, OpenApi};
tags(
(name = "Chat", description = "Chat completion endpoints for AI model inference"),
(name = "Images", description = "Image generation endpoints"),
(name = "Audio", description = "Audio transcription endpoints"),
(name = "Rerank", description = "Document reranking endpoints"),
(name = "Score", description = "Text similarity scoring endpoints"),
(name = "Models", description = "Public model catalog and information"),
Expand All @@ -41,6 +42,7 @@ use utoipa::{Modify, OpenApi};
// Chat completion endpoints (most important for users)
crate::routes::completions::chat_completions,
crate::routes::completions::image_generations,
crate::routes::completions::audio_transcriptions,
crate::routes::completions::image_edits,
crate::routes::completions::rerank,
crate::routes::completions::score,
Expand Down Expand Up @@ -146,6 +148,8 @@ use utoipa::{Modify, OpenApi};
CompletionRequest, ModelsResponse, ModelInfo, ModelPricing, ErrorResponse,
// Image generation models
ImageGenerationRequest, ImageGenerationResponse, ImageData,
// Audio transcription models
AudioTranscriptionRequestSchema, AudioTranscriptionResponse, TranscriptionSegment, TranscriptionWord,
// Image edit models
ImageEditRequestSchema,
// Rerank models
Expand Down
Loading