From 4af61f0408dc0c463ad272c7dba71c4e02cd6d4d Mon Sep 17 00:00:00 2001 From: Jannik Straube Date: Thu, 26 Jun 2025 16:19:05 +1000 Subject: [PATCH 1/3] poc to allow to partition gpus --- crates/shared/src/models/task.rs | 6 + crates/worker/src/docker/service.rs | 486 +++++++++++++++------ crates/worker/src/docker/task_container.rs | 52 ++- 3 files changed, 402 insertions(+), 142 deletions(-) diff --git a/crates/shared/src/models/task.rs b/crates/shared/src/models/task.rs index 0a6ffac3..c5a96493 100644 --- a/crates/shared/src/models/task.rs +++ b/crates/shared/src/models/task.rs @@ -153,6 +153,8 @@ pub struct TaskRequest { pub storage_config: Option, pub metadata: Option, pub volume_mounts: Option>, + #[serde(default)] + pub partition_by_gpu: bool, } #[derive(Debug, Clone, Serialize, Deserialize, PartialEq, ToSchema)] @@ -182,6 +184,8 @@ pub struct Task { pub metadata: Option, #[serde(default)] pub volume_mounts: Option>, + #[serde(default)] + pub partition_by_gpu: bool, } impl Task { @@ -238,6 +242,7 @@ impl Default for Task { storage_config: None, metadata: None, volume_mounts: None, + partition_by_gpu: false, } } } @@ -301,6 +306,7 @@ impl TryFrom for Task { storage_config: request.storage_config, metadata: request.metadata, volume_mounts: request.volume_mounts, + partition_by_gpu: request.partition_by_gpu, }) } } diff --git a/crates/worker/src/docker/service.rs b/crates/worker/src/docker/service.rs index a5dd921f..4e005508 100644 --- a/crates/worker/src/docker/service.rs +++ b/crates/worker/src/docker/service.rs @@ -11,6 +11,7 @@ use shared::models::task::Task; use shared::models::task::TaskState; use std::collections::HashMap; use std::path::Path; +use std::str::FromStr; use std::sync::Arc; use tokio::sync::Mutex; use tokio::time::{interval, Duration}; @@ -56,6 +57,34 @@ impl DockerService { } } + /// Helper function to get expected container names for a task + /// Returns a vector of (container_name, gpu_index) tuples + fn get_container_names_for_task(&self, task: &Task) -> Vec<(String, Option)> { + let config_hash = task.generate_config_hash(); + let base_name = format!("{}-{}-{:x}", TASK_PREFIX, task.id, config_hash); + + if task.partition_by_gpu { + // For GPU partitioned tasks, create one container per GPU + if let Some(ref gpu_specs) = self.gpu { + let gpu_count = gpu_specs + .indices + .as_ref() + .map(|indices| indices.len()) + .unwrap_or(gpu_specs.count.unwrap_or(1) as usize); + + (0..gpu_count) + .map(|i| (format!("{}-gpu{}", base_name, i), Some(i as u32))) + .collect() + } else { + // No GPUs available, fall back to single container + vec![(base_name, None)] + } + } else { + // Non-partitioned task: single container with access to all GPUs + vec![(base_name, None)] + } + } + pub async fn run(&self) -> Result<(), Box> { let mut interval = interval(Duration::from_secs(5)); let manager = self.docker_manager.clone(); @@ -143,10 +172,29 @@ impl DockerService { } } - if current_task.is_some() && task_id.is_some() { - let container_task_id = task_id.as_ref().unwrap().clone(); - let container_match = all_containers.iter().find(|c| c.names.contains(&format!("/{container_task_id}"))); - if container_match.is_none() { + if let Some(ref current_task_ref) = current_task { + // Get expected container names for this task + let expected_containers = self.get_container_names_for_task(current_task_ref); + let task_containers: Vec = all_containers + .iter() + .filter(|c| { + expected_containers.iter().any(|(name, _)| { + c.names.contains(&format!("/{}", name)) + }) + }) + .cloned() + .collect(); + + let missing_containers: Vec<(String, Option)> = expected_containers + .into_iter() + .filter(|(name, _)| { + !task_containers.iter().any(|c| { + c.names.contains(&format!("/{}", name)) + }) + }) + .collect(); + + if !missing_containers.is_empty() { let running_tasks = starting_container_tasks.lock().await; let has_running_tasks = running_tasks.iter().any(|h| !h.is_finished()); drop(running_tasks); @@ -171,143 +219,193 @@ impl DockerService { } else { Console::info("DockerService", "Starting new container..."); } - let manager_clone = manager_clone.clone(); - let state_clone = task_state_clone.clone(); - let gpu = self.gpu.clone(); - let system_memory_mb = self.system_memory_mb; - let task_bridge_socket_path = self.task_bridge_socket_path.clone(); - let node_address = self.node_address.clone(); - let p2p_seed = self.p2p_seed; - let handle = tokio::spawn(async move { - let payload = match state_clone.get_current_task().await { - Some(payload) => payload, - None => { - return; - } - }; - let cmd = match payload.cmd { - Some(cmd_vec) => { - cmd_vec.into_iter().map(|arg| { - let mut processed_arg = arg.replace("${SOCKET_PATH}", &task_bridge_socket_path); + + // Start containers for all missing GPU indices + for (container_name, gpu_index) in missing_containers { + let manager_clone = manager_clone.clone(); + let state_clone = task_state_clone.clone(); + let gpu = self.gpu.clone(); + let system_memory_mb = self.system_memory_mb; + let task_bridge_socket_path = self.task_bridge_socket_path.clone(); + let node_address = self.node_address.clone(); + let p2p_seed = self.p2p_seed; + let handle = tokio::spawn(async move { + let payload = match state_clone.get_current_task().await { + Some(payload) => payload, + None => { + return; + } + }; + let cmd = match payload.cmd { + Some(cmd_vec) => { + cmd_vec.into_iter().map(|arg| { + let mut processed_arg = arg.replace("${SOCKET_PATH}", &task_bridge_socket_path); + if let Some(seed) = p2p_seed { + processed_arg = processed_arg.replace("${WORKER_P2P_SEED}", &seed.to_string()); + } + processed_arg + }).collect() + } + None => vec!["sleep".to_string(), "infinity".to_string()], + }; + + let mut env_vars: HashMap = HashMap::new(); + if let Some(env) = &payload.env_vars { + // Clone env vars and replace ${SOCKET_PATH} in values + for (key, value) in env.iter() { + let mut processed_value = value.replace("${SOCKET_PATH}", &task_bridge_socket_path); if let Some(seed) = p2p_seed { - processed_arg = processed_arg.replace("${WORKER_P2P_SEED}", &seed.to_string()); + processed_value = processed_value.replace("${WORKER_P2P_SEED}", &seed.to_string()); } - processed_arg - }).collect() - } - None => vec!["sleep".to_string(), "infinity".to_string()], - }; - - let mut env_vars: HashMap = HashMap::new(); - if let Some(env) = &payload.env_vars { - // Clone env vars and replace ${SOCKET_PATH} in values - for (key, value) in env.iter() { - let mut processed_value = value.replace("${SOCKET_PATH}", &task_bridge_socket_path); - if let Some(seed) = p2p_seed { - processed_value = processed_value.replace("${WORKER_P2P_SEED}", &seed.to_string()); + env_vars.insert(key.clone(), processed_value); } - env_vars.insert(key.clone(), processed_value); } - } - - env_vars.insert("NODE_ADDRESS".to_string(), node_address); - env_vars.insert("PRIME_MONITOR__SOCKET__PATH".to_string(), task_bridge_socket_path.to_string()); - env_vars.insert("PRIME_TASK_ID".to_string(), payload.id.to_string()); - - let mut volumes = vec![ - ( - Path::new(&task_bridge_socket_path).parent().unwrap().to_path_buf().to_string_lossy().to_string(), - Path::new(&task_bridge_socket_path).parent().unwrap().to_path_buf().to_string_lossy().to_string(), - false, - false, - ) - ]; - - if let Some(volume_mounts) = &payload.volume_mounts { - for volume_mount in volume_mounts { - volumes.push(( - volume_mount.host_path.clone(), - volume_mount.container_path.clone(), - false, - true - )); + + env_vars.insert("NODE_ADDRESS".to_string(), node_address); + env_vars.insert("PRIME_MONITOR__SOCKET__PATH".to_string(), task_bridge_socket_path.to_string()); + env_vars.insert("PRIME_TASK_ID".to_string(), payload.id.to_string()); + + // Add GPU_INDEX environment variable for partitioned tasks + if let Some(idx) = gpu_index { + env_vars.insert("GPU_INDEX".to_string(), idx.to_string()); } - } - let shm_size = match system_memory_mb { - Some(mem_mb) => (mem_mb as u64) * 1024 * 1024 / 2, // Convert MB to bytes and divide by 2 - None => { - Console::warning("System memory not available, using default shm size"); - 67108864 // Default to 64MB in bytes + + let mut volumes = vec![ + ( + Path::new(&task_bridge_socket_path).parent().unwrap().to_path_buf().to_string_lossy().to_string(), + Path::new(&task_bridge_socket_path).parent().unwrap().to_path_buf().to_string_lossy().to_string(), + false, + false, + ) + ]; + + if let Some(volume_mounts) = &payload.volume_mounts { + for volume_mount in volume_mounts { + volumes.push(( + volume_mount.host_path.clone(), + volume_mount.container_path.clone(), + false, + true + )); + } } - }; - match manager_clone.start_container(&payload.image, &container_task_id, Some(env_vars), Some(cmd), gpu, Some(volumes), Some(shm_size), payload.entrypoint, None).await { - Ok(container_id) => { - Console::info("DockerService", &format!("Container started with id: {container_id}")); - }, - Err(e) => { - log::error!("Error starting container: {e}"); - state_clone.update_task_state(payload.id, TaskState::FAILED).await; + let shm_size = match system_memory_mb { + Some(mem_mb) => (mem_mb as u64) * 1024 * 1024 / 2, // Convert MB to bytes and divide by 2 + None => { + Console::warning("System memory not available, using default shm size"); + 67108864 // Default to 64MB in bytes + } + }; + + // Prepare GPU specs for this specific container + let container_gpu = if payload.partition_by_gpu && gpu_index.is_some() { + // For partitioned tasks, assign specific GPU + gpu.map(|mut g| { + let gpu_idx = gpu_index.unwrap(); + // If indices are specified, use the actual GPU index + if let Some(indices) = &g.indices { + if (gpu_idx as usize) < indices.len() { + g.indices = Some(vec![indices[gpu_idx as usize]]); + } + } else { + // Otherwise, use the sequential index + g.indices = Some(vec![gpu_idx]); + } + g + }) + } else { + // For non-partitioned tasks, use all GPUs + gpu + }; + + match manager_clone.start_container(&payload.image, &container_name, Some(env_vars), Some(cmd), container_gpu, Some(volumes), Some(shm_size), payload.entrypoint.clone(), None).await { + Ok(container_id) => { + Console::info("DockerService", &format!("Container started with id: {container_id}")); + }, + Err(e) => { + log::error!("Error starting container: {e}"); + state_clone.update_task_state(payload.id, TaskState::FAILED).await; + } } - } - state_clone.set_last_started(Utc::now()).await; - }); - starting_container_tasks.lock().await.push(handle); - + state_clone.set_last_started(Utc::now()).await; + }); + starting_container_tasks.lock().await.push(handle); + } } - } } else { - let container_status = container_match.unwrap().clone(); - let status = match manager.get_container_details(&container_status.id).await { - Ok(status) => status, - Err(e) => { - log::error!("Error getting container details: {e}"); - continue; - } - }; - - let task_state_current = match task_state_clone.get_current_task().await { - Some(task) => task.state, - None => { - log::error!("No task found in state"); - continue; - } - }; - // handle edge case where container instantly dies due to invalid command - if status.status == Some(ContainerStateStatusEnum::CREATED) && task_state_current == TaskState::FAILED { - Console::info("DockerService", "Task failed, waiting for new command from manager ..."); - } else { - debug!("docker container status: {:?}, status_code: {:?}", status.status, status.status_code); - let task_state_live = match (status.status, status.status_code) { - (Some(ContainerStateStatusEnum::RUNNING), _) => TaskState::RUNNING, - (Some(ContainerStateStatusEnum::CREATED), _) => TaskState::PENDING, - (Some(ContainerStateStatusEnum::EXITED), Some(0)) => TaskState::COMPLETED, - (Some(ContainerStateStatusEnum::EXITED), Some(code)) if code != 0 => TaskState::FAILED, - (Some(ContainerStateStatusEnum::DEAD), _) => TaskState::FAILED, - (Some(ContainerStateStatusEnum::PAUSED), _) => TaskState::PAUSED, - (Some(ContainerStateStatusEnum::RESTARTING), _) => TaskState::RESTARTING, - (Some(ContainerStateStatusEnum::REMOVING), _) => TaskState::UNKNOWN, - _ => TaskState::UNKNOWN, + // All expected containers exist, now aggregate their states + let mut any_failed = false; + let mut all_completed = true; + let mut any_running = false; + let mut all_created = true; + + for container in &task_containers { + let status = match manager.get_container_details(&container.id).await { + Ok(status) => status, + Err(e) => { + log::error!("Error getting container details: {e}"); + continue; + } }; - // Only log if state changed - if task_state_live != task_state_current { - Console::info("DockerService", &format!("Task state changed from {task_state_current:?} to {task_state_live:?}")); + match (status.status, status.status_code) { + (Some(ContainerStateStatusEnum::RUNNING), _) => { + any_running = true; + all_completed = false; + all_created = false; + }, + (Some(ContainerStateStatusEnum::CREATED), _) => { + all_completed = false; + }, + (Some(ContainerStateStatusEnum::EXITED), Some(0)) => { + all_created = false; + }, + (Some(ContainerStateStatusEnum::EXITED), Some(code)) if code != 0 => { + any_failed = true; + all_created = false; + }, + (Some(ContainerStateStatusEnum::DEAD), _) => { + any_failed = true; + all_created = false; + }, + _ => { + all_completed = false; + all_created = false; + } + } + } - if task_state_live == TaskState::FAILED { + let task_state_current = current_task_ref.state.clone(); + + // Determine aggregated task state + let task_state_live = if any_failed { + TaskState::FAILED + } else if all_completed { + TaskState::COMPLETED + } else if any_running { + TaskState::RUNNING + } else if all_created { + TaskState::PENDING + } else { + TaskState::UNKNOWN + }; - consecutive_failures += 1; - Console::info("DockerService", &format!("Task failed (attempt {consecutive_failures}), waiting with exponential backoff before restart")); + // Only log if state changed + if task_state_live != task_state_current { + Console::info("DockerService", &format!("Task state changed from {task_state_current:?} to {task_state_live:?}")); - } else if task_state_live == TaskState::RUNNING { - // Reset failure counter when container runs successfully - consecutive_failures = 0; - } + if task_state_live == TaskState::FAILED { + consecutive_failures += 1; + Console::info("DockerService", &format!("Task failed (attempt {consecutive_failures}), waiting with exponential backoff before restart")); + } else if task_state_live == TaskState::RUNNING { + // Reset failure counter when container runs successfully + consecutive_failures = 0; } + } - if let Some(task) = task_state_clone.get_current_task().await { - task_state_clone.update_task_state(task.id, task_state_live).await; - } + if let Some(task) = task_state_clone.get_current_task().await { + task_state_clone.update_task_state(task.id, task_state_live).await; } } } @@ -322,17 +420,38 @@ impl DockerService { let current_task = self.state.get_current_task().await; match current_task { Some(task) => { - let config_hash = task.generate_config_hash(); - let container_id = format!("{}-{}-{:x}", TASK_PREFIX, task.id, config_hash); - - let logs = self - .docker_manager - .get_container_logs(&container_id, None) - .await?; - if logs.is_empty() { - Ok("No logs found in docker container".to_string()) + let container_names = self.get_container_names_for_task(&task); + let mut all_logs = Vec::new(); + + for (container_name, gpu_index) in container_names { + let logs = match self + .docker_manager + .get_container_logs(&container_name, None) + .await + { + Ok(logs) if !logs.is_empty() => logs, + Ok(_) => continue, // Empty logs, skip + Err(e) => { + log::debug!( + "Failed to get logs for container {}: {}", + container_name, + e + ); + continue; + } + }; + + // Add header for partitioned tasks + if task.partition_by_gpu && gpu_index.is_some() { + all_logs.push(format!("=== GPU {} ===", gpu_index.unwrap())); + } + all_logs.push(logs); + } + + if all_logs.is_empty() { + Ok("No logs found in docker containers".to_string()) } else { - Ok(logs) + Ok(all_logs.join("\n\n")) } } None => Ok("No task running".to_string()), @@ -343,9 +462,24 @@ impl DockerService { let current_task = self.state.get_current_task().await; match current_task { Some(task) => { - let config_hash = task.generate_config_hash(); - let container_id = format!("{}-{}-{:x}", TASK_PREFIX, task.id, config_hash); - self.docker_manager.restart_container(&container_id).await?; + let container_names = self.get_container_names_for_task(&task); + let mut restart_errors = Vec::new(); + + for (container_name, _) in container_names { + if let Err(e) = self.docker_manager.restart_container(&container_name).await { + log::error!("Failed to restart container {}: {}", container_name, e); + restart_errors.push(format!("{}: {}", container_name, e)); + } + } + + if !restart_errors.is_empty() { + return Err(format!( + "Failed to restart some containers: {}", + restart_errors.join(", ") + ) + .into()); + } + Ok(()) } None => Ok(()), @@ -353,8 +487,10 @@ impl DockerService { } pub async fn get_task_details(&self, task: &Task) -> Option { - let config_hash = task.generate_config_hash(); - let container_name = format!("{}-{}-{:x}", TASK_PREFIX, task.id, config_hash); + let container_names = self.get_container_names_for_task(task); + + // For partitioned tasks, use the first container (gpu0) as representative + let (container_name, _) = container_names.first()?; match self.docker_manager.list_containers(true).await { Ok(containers) => { @@ -540,4 +676,76 @@ mod tests { // without mocking DockerManager, but we've verified the logic visually cancellation_token.cancel(); } + + #[test] + fn test_get_container_names_for_task() { + let cancellation_token = CancellationToken::new(); + + // Test with GPU partitioning enabled and multiple GPUs + let gpu_specs = Some(GpuSpecs { + count: Some(4), + indices: Some(vec![0, 1, 2, 3]), + model: None, + memory_mb: None, + }); + + let docker_service = DockerService::new( + cancellation_token.clone(), + gpu_specs, + Some(1024), + "/tmp/test.sock".to_string(), + "/tmp/test-storage".to_string(), + Address::ZERO.to_string(), + None, + false, + ); + + // Test partitioned task + let partitioned_task = Task { + image: "test:latest".to_string(), + name: "partitioned".to_string(), + id: Uuid::from_str("123e4567-e89b-12d3-a456-426614174000").unwrap(), + partition_by_gpu: true, + ..Default::default() + }; + + let container_names = docker_service.get_container_names_for_task(&partitioned_task); + assert_eq!(container_names.len(), 4); + assert_eq!(container_names[0].1, Some(0)); + assert_eq!(container_names[1].1, Some(1)); + assert_eq!(container_names[2].1, Some(2)); + assert_eq!(container_names[3].1, Some(3)); + assert!(container_names[0].0.ends_with("-gpu0")); + assert!(container_names[1].0.ends_with("-gpu1")); + + // Test non-partitioned task + let non_partitioned_task = Task { + image: "test:latest".to_string(), + name: "non-partitioned".to_string(), + id: Uuid::from_str("123e4567-e89b-12d3-a456-426614174000").unwrap(), + partition_by_gpu: false, + ..Default::default() + }; + + let container_names = docker_service.get_container_names_for_task(&non_partitioned_task); + assert_eq!(container_names.len(), 1); + assert_eq!(container_names[0].1, None); + assert!(!container_names[0].0.contains("-gpu")); + + // Test partitioned task with no GPUs available + let docker_service_no_gpu = DockerService::new( + cancellation_token.clone(), + None, + Some(1024), + "/tmp/test.sock".to_string(), + "/tmp/test-storage".to_string(), + Address::ZERO.to_string(), + None, + false, + ); + + let container_names = docker_service_no_gpu.get_container_names_for_task(&partitioned_task); + assert_eq!(container_names.len(), 1); + assert_eq!(container_names[0].1, None); + } } diff --git a/crates/worker/src/docker/task_container.rs b/crates/worker/src/docker/task_container.rs index 5a6f903f..39d360ae 100644 --- a/crates/worker/src/docker/task_container.rs +++ b/crates/worker/src/docker/task_container.rs @@ -2,6 +2,7 @@ pub struct TaskContainer { pub task_id: String, pub config_hash: String, + pub gpu_index: Option, } impl TaskContainer { @@ -18,10 +19,25 @@ impl std::str::FromStr for TaskContainer { if parts.len() >= 8 && parts[0] == "prime" && parts[1] == "task" { let task_id = parts[2..7].join("-"); - let config_hash = parts[7..].join("-"); + + // Check if the container name has a GPU suffix + let (config_hash, gpu_index) = if parts.len() >= 9 && parts[parts.len() - 2] == "gpu" { + // Has GPU suffix: extract GPU index + let gpu_idx = parts[parts.len() - 1] + .parse::() + .map_err(|_| "Invalid GPU index")?; + let config_hash = parts[7..parts.len() - 2].join("-"); + (config_hash, Some(gpu_idx)) + } else { + // No GPU suffix: normal container + let config_hash = parts[7..].join("-"); + (config_hash, None) + }; + Ok(Self { task_id, config_hash, + gpu_index, }) } else { Err("Invalid container name format") @@ -40,17 +56,19 @@ mod tests { let container_name = "prime-task-123e4567-e89b-12d3-a456-426614174000-a1b2c3d4"; let result = TaskContainer::from_str(container_name); assert_eq!( - result.map(|c| c.task_id), + result.as_ref().map(|c| c.task_id.clone()), Ok("123e4567-e89b-12d3-a456-426614174000".to_string()) ); + assert_eq!(result.as_ref().map(|c| c.gpu_index), Ok(None)); // Test with leading slash let container_name = "/prime-task-123e4567-e89b-12d3-a456-426614174000-a1b2c3d4"; let result = TaskContainer::from_str(container_name); assert_eq!( - result.map(|c| c.task_id), + result.as_ref().map(|c| c.task_id.clone()), Ok("123e4567-e89b-12d3-a456-426614174000".to_string()) ); + assert_eq!(result.as_ref().map(|c| c.gpu_index), Ok(None)); // Test with invalid format let container_name = "not-a-prime-task"; @@ -67,4 +85,32 @@ mod tests { let result = TaskContainer::from_str(container_name); assert!(result.is_err()); } + + #[test] + fn test_gpu_partitioned_container_names() { + // Test with GPU suffix + let container_name = "prime-task-123e4567-e89b-12d3-a456-426614174000-a1b2c3d4-gpu0"; + let result = TaskContainer::from_str(container_name).unwrap(); + assert_eq!(result.task_id, "123e4567-e89b-12d3-a456-426614174000"); + assert_eq!(result.config_hash, "a1b2c3d4"); + assert_eq!(result.gpu_index, Some(0)); + + // Test with GPU suffix (gpu1) + let container_name = "prime-task-123e4567-e89b-12d3-a456-426614174000-a1b2c3d4-gpu1"; + let result = TaskContainer::from_str(container_name).unwrap(); + assert_eq!(result.task_id, "123e4567-e89b-12d3-a456-426614174000"); + assert_eq!(result.config_hash, "a1b2c3d4"); + assert_eq!(result.gpu_index, Some(1)); + + // Test data_dir_name doesn't include GPU suffix + assert_eq!( + result.data_dir_name(), + "prime-task-123e4567-e89b-12d3-a456-426614174000" + ); + + // Test with invalid GPU index + let container_name = "prime-task-123e4567-e89b-12d3-a456-426614174000-a1b2c3d4-gpu-invalid"; + let result = TaskContainer::from_str(container_name); + assert!(result.is_err()); + } } From 6b7a812d1db44cc2b37bbf0f8c49dc94b3fed007 Mon Sep 17 00:00:00 2001 From: Jannik Straube Date: Thu, 26 Jun 2025 18:08:13 +1000 Subject: [PATCH 2/3] add more log details --- crates/worker/src/docker/service.rs | 7 +++++-- 1 file changed, 5 insertions(+), 2 deletions(-) diff --git a/crates/worker/src/docker/service.rs b/crates/worker/src/docker/service.rs index 4e005508..6b7f33ff 100644 --- a/crates/worker/src/docker/service.rs +++ b/crates/worker/src/docker/service.rs @@ -441,9 +441,12 @@ impl DockerService { } }; - // Add header for partitioned tasks + // Add prominent header for partitioned tasks if task.partition_by_gpu && gpu_index.is_some() { - all_logs.push(format!("=== GPU {} ===", gpu_index.unwrap())); + let gpu_num = gpu_index.unwrap(); + all_logs.push(format!("\n{}", "=".repeat(60))); + all_logs.push(format!(" GPU {} LOGS", gpu_num)); + all_logs.push(format!("{}", "=".repeat(60))); } all_logs.push(logs); } From 77454e28dcdc18b30403045d5c62bf9f3dfe8a03 Mon Sep 17 00:00:00 2001 From: Jannik Straube Date: Thu, 26 Jun 2025 18:33:31 +1000 Subject: [PATCH 3/3] fix tests --- crates/worker/src/docker/service.rs | 31 ++++++----- crates/worker/src/docker/task_container.rs | 61 +++++++++++++--------- 2 files changed, 53 insertions(+), 39 deletions(-) diff --git a/crates/worker/src/docker/service.rs b/crates/worker/src/docker/service.rs index 6b7f33ff..1b58284a 100644 --- a/crates/worker/src/docker/service.rs +++ b/crates/worker/src/docker/service.rs @@ -11,7 +11,6 @@ use shared::models::task::Task; use shared::models::task::TaskState; use std::collections::HashMap; use std::path::Path; -use std::str::FromStr; use std::sync::Arc; use tokio::sync::Mutex; use tokio::time::{interval, Duration}; @@ -300,19 +299,22 @@ impl DockerService { // Prepare GPU specs for this specific container let container_gpu = if payload.partition_by_gpu && gpu_index.is_some() { // For partitioned tasks, assign specific GPU - gpu.map(|mut g| { - let gpu_idx = gpu_index.unwrap(); - // If indices are specified, use the actual GPU index - if let Some(indices) = &g.indices { - if (gpu_idx as usize) < indices.len() { - g.indices = Some(vec![indices[gpu_idx as usize]]); + if let Some(idx) = gpu_index { + gpu.map(|mut g| { + // If indices are specified, use the actual GPU index + if let Some(indices) = &g.indices { + if (idx as usize) < indices.len() { + g.indices = Some(vec![indices[idx as usize]]); + } + } else { + // Otherwise, use the sequential index + g.indices = Some(vec![idx]); } - } else { - // Otherwise, use the sequential index - g.indices = Some(vec![gpu_idx]); - } - g - }) + g + }) + } else { + None + } } else { // For non-partitioned tasks, use all GPUs gpu @@ -446,7 +448,7 @@ impl DockerService { let gpu_num = gpu_index.unwrap(); all_logs.push(format!("\n{}", "=".repeat(60))); all_logs.push(format!(" GPU {} LOGS", gpu_num)); - all_logs.push(format!("{}", "=".repeat(60))); + all_logs.push("=".repeat(60).to_string()); } all_logs.push(logs); } @@ -557,6 +559,7 @@ mod tests { use alloy::primitives::Address; use shared::models::task::Task; use shared::models::task::TaskState; + use std::str::FromStr; use uuid::Uuid; #[tokio::test] diff --git a/crates/worker/src/docker/task_container.rs b/crates/worker/src/docker/task_container.rs index 39d360ae..6fbb1d2d 100644 --- a/crates/worker/src/docker/task_container.rs +++ b/crates/worker/src/docker/task_container.rs @@ -15,24 +15,31 @@ impl std::str::FromStr for TaskContainer { type Err = &'static str; fn from_str(container_name: &str) -> Result { - let parts: Vec<&str> = container_name.trim_start_matches('/').split('-').collect(); - + let parts: Vec<&str> = container_name + .trim() + .trim_start_matches('/') + .split('-') + .collect(); if parts.len() >= 8 && parts[0] == "prime" && parts[1] == "task" { let task_id = parts[2..7].join("-"); // Check if the container name has a GPU suffix - let (config_hash, gpu_index) = if parts.len() >= 9 && parts[parts.len() - 2] == "gpu" { - // Has GPU suffix: extract GPU index - let gpu_idx = parts[parts.len() - 1] - .parse::() - .map_err(|_| "Invalid GPU index")?; - let config_hash = parts[7..parts.len() - 2].join("-"); - (config_hash, Some(gpu_idx)) - } else { - // No GPU suffix: normal container - let config_hash = parts[7..].join("-"); - (config_hash, None) - }; + let (config_hash, gpu_index) = + if parts.len() >= 9 && parts[parts.len() - 1].starts_with("gpu") { + // Has GPU suffix: extract GPU index from "gpu0", "gpu1", etc. + let gpu_part = parts[parts.len() - 1]; + let gpu_idx = gpu_part + .strip_prefix("gpu") + .ok_or("Invalid GPU suffix format")? + .parse::() + .map_err(|_| "Invalid GPU index")?; + let config_hash = parts[7..parts.len() - 1].join("-"); + (config_hash, Some(gpu_idx)) + } else { + // No GPU suffix: normal container + let config_hash = parts[7..].join("-"); + (config_hash, None) + }; Ok(Self { task_id, @@ -89,27 +96,31 @@ mod tests { #[test] fn test_gpu_partitioned_container_names() { // Test with GPU suffix - let container_name = "prime-task-123e4567-e89b-12d3-a456-426614174000-a1b2c3d4-gpu0"; + let container_name = + " prime-task-c45f0b5b-683b-400a-9452-132d0c1bd00e-73adfcfcfbf417c1-gpu1"; let result = TaskContainer::from_str(container_name).unwrap(); - assert_eq!(result.task_id, "123e4567-e89b-12d3-a456-426614174000"); - assert_eq!(result.config_hash, "a1b2c3d4"); - assert_eq!(result.gpu_index, Some(0)); + println!("result: {:?}", result); + assert_eq!(result.task_id, "c45f0b5b-683b-400a-9452-132d0c1bd00e"); + assert_eq!(result.config_hash, "73adfcfcfbf417c1"); + assert_eq!(result.gpu_index, Some(1)); // Test with GPU suffix (gpu1) - let container_name = "prime-task-123e4567-e89b-12d3-a456-426614174000-a1b2c3d4-gpu1"; + let container_name = + " prime-task-c45f0b5b-683b-400a-9452-132d0c1bd00e-73adfcfcfbf417c1-gpu2"; let result = TaskContainer::from_str(container_name).unwrap(); - assert_eq!(result.task_id, "123e4567-e89b-12d3-a456-426614174000"); - assert_eq!(result.config_hash, "a1b2c3d4"); - assert_eq!(result.gpu_index, Some(1)); + assert_eq!(result.task_id, "c45f0b5b-683b-400a-9452-132d0c1bd00e"); + assert_eq!(result.config_hash, "73adfcfcfbf417c1"); + assert_eq!(result.gpu_index, Some(2)); // Test data_dir_name doesn't include GPU suffix assert_eq!( result.data_dir_name(), - "prime-task-123e4567-e89b-12d3-a456-426614174000" + "prime-task-c45f0b5b-683b-400a-9452-132d0c1bd00e" ); - // Test with invalid GPU index - let container_name = "prime-task-123e4567-e89b-12d3-a456-426614174000-a1b2c3d4-gpu-invalid"; + // Test with invalid GPU index (non-numeric) + let container_name = + "prime-task-c45f0b5b-683b-400a-9452-132d0c1bd00e-73adfcfcfbf417c1-gpuinvalid"; let result = TaskContainer::from_str(container_name); assert!(result.is_err()); }