Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
50 changes: 48 additions & 2 deletions crates/mofa-foundation/src/inference/types.rs
Original file line number Diff line number Diff line change
Expand Up @@ -139,6 +139,18 @@ pub struct InferenceRequest {
pub priority: RequestPriority,
/// Preferred precision level (orchestrator may downgrade under pressure)
pub preferred_precision: Precision,
/// Optional cap for completion length.
///
/// This is forwarded from API-level request parameters so downstream
/// providers can enforce generation limits.
#[serde(default, skip_serializing_if = "Option::is_none")]
pub max_tokens: Option<u32>,
/// Optional sampling temperature.
///
/// This is forwarded from API-level request parameters so downstream
/// providers can control generation stochasticity.
#[serde(default, skip_serializing_if = "Option::is_none")]
pub temperature: Option<f32>,
}

impl InferenceRequest {
Expand All @@ -150,6 +162,8 @@ impl InferenceRequest {
required_memory_mb: memory_mb,
priority: RequestPriority::default(),
preferred_precision: Precision::F16,
max_tokens: None,
temperature: None,
}
}

Expand All @@ -164,6 +178,18 @@ impl InferenceRequest {
self.preferred_precision = precision;
self
}

/// Set maximum generated tokens.
pub fn with_max_tokens(mut self, max_tokens: u32) -> Self {
self.max_tokens = Some(max_tokens);
self
}

/// Set sampling temperature.
pub fn with_temperature(mut self, temperature: f32) -> Self {
self.temperature = Some(temperature);
self
}
}

/// The result of an inference request after orchestration.
Expand Down Expand Up @@ -206,12 +232,16 @@ mod tests {
fn test_request_builder() {
let req = InferenceRequest::new("llama-3-13b", "Hello world", 13312)
.with_priority(RequestPriority::High)
.with_precision(Precision::Q8);
.with_precision(Precision::Q8)
.with_max_tokens(256)
.with_temperature(0.7);

assert_eq!(req.model_id, "llama-3-13b");
assert_eq!(req.required_memory_mb, 13312);
assert_eq!(req.priority, RequestPriority::High);
assert_eq!(req.preferred_precision, Precision::Q8);
assert_eq!(req.max_tokens, Some(256));
assert_eq!(req.temperature, Some(0.7));
}

#[test]
Expand Down Expand Up @@ -284,12 +314,28 @@ mod tests {
fn test_inference_request_serde_roundtrip() {
let req = InferenceRequest::new("llama-3-13b", "Hello world", 13312)
.with_priority(RequestPriority::High)
.with_precision(Precision::Q8);
.with_precision(Precision::Q8)
.with_max_tokens(128)
.with_temperature(0.4);
let json = serde_json::to_string(&req).unwrap();
let back: InferenceRequest = serde_json::from_str(&json).unwrap();
assert_eq!(back, req);
}

#[test]
fn test_inference_request_deserialize_without_generation_fields() {
let json = r#"{
"model_id":"m",
"prompt":"p",
"required_memory_mb":1024,
"priority":"Normal",
"preferred_precision":"F16"
}"#;
let req: InferenceRequest = serde_json::from_str(json).unwrap();
assert_eq!(req.max_tokens, None);
assert_eq!(req.temperature, None);
}

#[test]
fn test_inference_result_serde_roundtrip() {
let result = InferenceResult {
Expand Down
11 changes: 11 additions & 0 deletions crates/mofa-gateway/src/inference_bridge.rs
Original file line number Diff line number Diff line change
Expand Up @@ -130,6 +130,17 @@ impl InferenceBridge {
.with_priority(RequestPriority::Normal)
.with_precision(Precision::F16);

let inference_request = if let Some(max_tokens) = request.max_tokens {
inference_request.with_max_tokens(max_tokens)
} else {
inference_request
};
let inference_request = if let Some(temperature) = request.temperature {
inference_request.with_temperature(temperature)
} else {
inference_request
};

// Call the orchestrator
let result = {
let mut orch = self.orchestrator.lock();
Expand Down
3 changes: 1 addition & 2 deletions crates/mofa-gateway/src/openai_compat/handler.rs
Original file line number Diff line number Diff line change
Expand Up @@ -196,8 +196,7 @@ pub async fn chat_completions(

// ── Build InferenceRequest ────────────────────────────────────────────────
let prompt = req.to_prompt();
let inference_req =
InferenceRequest::new(&req.model, &prompt, 7168).with_priority(req.priority());
let inference_req = req.to_inference_request(7168);

// ── Invoke orchestrator ───────────────────────────────────────────────────
let start = Instant::now();
Expand Down
43 changes: 42 additions & 1 deletion crates/mofa-gateway/src/openai_compat/types.rs
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@

use serde::{Deserialize, Serialize};

use mofa_foundation::inference::types::RequestPriority;
use mofa_foundation::inference::types::{InferenceRequest, RequestPriority};
use mofa_foundation::inference::{OrchestratorConfig, RoutingPolicy};

// ──────────────────────────────────────────────────────────────────────────────
Expand Down Expand Up @@ -80,6 +80,25 @@ impl ChatCompletionRequest {
_ => RequestPriority::Normal,
}
}

/// Convert API request into an internal inference request.
///
/// Carries generation controls (`max_tokens`, `temperature`) forward so
/// downstream routing/providers can consume them instead of silently
/// dropping user-specified parameters.
pub fn to_inference_request(&self, required_memory_mb: usize) -> InferenceRequest {
let mut req = InferenceRequest::new(&self.model, self.to_prompt(), required_memory_mb)
.with_priority(self.priority());

if let Some(max_tokens) = self.max_tokens {
req = req.with_max_tokens(max_tokens);
}
if let Some(temperature) = self.temperature {
req = req.with_temperature(temperature);
}

req
}
}

/// Serializable counterpart to [`RequestPriority`] for JSON deserialization.
Expand Down Expand Up @@ -396,6 +415,28 @@ mod tests {
assert!(prompt.contains("user: Hi"));
}

#[test]
fn test_to_inference_request_propagates_generation_params() {
let req = ChatCompletionRequest {
model: "mofa-local".into(),
messages: vec![ChatMessage {
role: "user".into(),
content: "Generate".into(),
}],
stream: false,
max_tokens: Some(128),
temperature: Some(0.2),
priority: RequestPriorityParam::High,
};

let internal = req.to_inference_request(7168);
assert_eq!(internal.model_id, "mofa-local");
assert_eq!(internal.required_memory_mb, 7168);
assert_eq!(internal.max_tokens, Some(128));
assert_eq!(internal.temperature, Some(0.2));
assert_eq!(internal.priority, RequestPriority::High);
}

#[test]
fn test_gateway_error_body_rate_limited() {
let err = GatewayErrorBody::rate_limited();
Expand Down
5 changes: 1 addition & 4 deletions crates/mofa-gateway/src/streaming/ws.rs
Original file line number Diff line number Diff line change
Expand Up @@ -133,10 +133,7 @@ async fn handle_ws_session(mut socket: axum::extract::ws::WebSocket, state: AppS
}

// ── 3. Run the orchestrator ───────────────────────────────────────────
use mofa_foundation::inference::types::InferenceRequest;
let prompt = req.to_prompt();
let inference_req =
InferenceRequest::new(&req.model, &prompt, 7168).with_priority(req.priority());
let inference_req = req.to_inference_request(7168);

let (_result, token_stream) = {
let mut orch = state.orchestrator.write().await;
Expand Down
Loading