diff --git a/discovery/src/api/routes/node.rs b/discovery/src/api/routes/node.rs index 1c676e85..54ab4edb 100644 --- a/discovery/src/api/routes/node.rs +++ b/discovery/src/api/routes/node.rs @@ -220,6 +220,7 @@ mod tests { count: Some(4), model: Some("A100".to_string()), memory_mb: Some(40000), + indices: Some(vec![0, 1, 2, 3]), }), cpu: Some(CpuSpecs { cores: Some(16), diff --git a/shared/src/models/node.rs b/shared/src/models/node.rs index 92b0a1f7..3c74e063 100644 --- a/shared/src/models/node.rs +++ b/shared/src/models/node.rs @@ -66,6 +66,7 @@ pub struct GpuSpecs { pub count: Option, pub model: Option, pub memory_mb: Option, + pub indices: Option>, } impl fmt::Display for GpuSpecs { @@ -445,6 +446,7 @@ mod tests { count: gpu_count, model: gpu_model.map(String::from), memory_mb: gpu_mem, + indices: None, }) } else { None @@ -744,6 +746,7 @@ mod tests { count: Some(4), model: Some("A100".to_string()), memory_mb: None, + indices: None, }), cpu: Some(CpuSpecs { cores: Some(16), diff --git a/worker/src/checks/hardware/gpu.rs b/worker/src/checks/hardware/gpu.rs index 541ffa59..9b58a3d8 100644 --- a/worker/src/checks/hardware/gpu.rs +++ b/worker/src/checks/hardware/gpu.rs @@ -5,52 +5,44 @@ use shared::models::node::GpuSpecs; use std::sync::Mutex; #[allow(dead_code)] -const BYTES_TO_GB: f64 = 1024.0 * 1024.0 * 1024.0; +const BYTES_TO_MB: u64 = 1024 * 1024; // Use lazy_static to initialize NVML once and reuse it lazy_static! { static ref NVML: Mutex> = Mutex::new(None); } +#[derive(Debug)] #[allow(dead_code)] -enum GpuDevice { - Available { - name: String, - memory: u64, - driver_version: String, - device_count: usize, - }, - NotAvailable(String), +struct GpuDevice { + name: String, + memory: u64, + driver_version: String, + count: u32, + indices: Vec, } -pub fn detect_gpu() -> Option { +pub fn detect_gpu() -> Vec { Console::title("GPU Detection"); - // Changed return type to GpuSpecs - match get_gpu_status() { - GpuDevice::Available { - name, - memory, - driver_version: _, - device_count, - } => Some(GpuSpecs { - // Create GpuSpecs directly - count: Some(device_count as u32), - model: Some( - name.to_lowercase() - .split_whitespace() - .collect::>() - .join("_"), - ), - memory_mb: Some((memory / 1024 / 1024) as u32), // Convert bytes to MB - }), - GpuDevice::NotAvailable(_) => { - Console::user_error("GPU not available"); - None - } + + let gpu_devices = get_gpu_status(); + if gpu_devices.is_empty() { + Console::user_error("No GPU devices detected"); + return vec![]; } + + gpu_devices + .into_iter() + .map(|device| GpuSpecs { + count: Some(device.count), + model: Some(device.name.to_lowercase()), + memory_mb: Some((device.memory / BYTES_TO_MB) as u32), + indices: Some(device.indices), + }) + .collect() } -fn get_gpu_status() -> GpuDevice { +fn get_gpu_status() -> Vec { let mut nvml_guard = NVML.lock().unwrap(); // Initialize NVML if not already initialized @@ -62,7 +54,10 @@ fn get_gpu_status() -> GpuDevice { .init() { Ok(nvml) => *nvml_guard = Some(nvml), - Err(e) => return GpuDevice::NotAvailable(format!("Failed to initialize NVML: {}", e)), + Err(e) => { + Console::user_error(&format!("Failed to initialize NVML: {}", e)); + return vec![]; + } } } @@ -71,30 +66,50 @@ fn get_gpu_status() -> GpuDevice { // Get device count let device_count = match nvml.device_count() { Ok(count) => count as usize, - Err(e) => return GpuDevice::NotAvailable(format!("Failed to get device count: {}", e)), + Err(e) => { + Console::user_error(&format!("Failed to get device count: {}", e)); + return vec![]; + } }; if device_count == 0 { - return GpuDevice::NotAvailable("No GPU devices detected".to_string()); + Console::user_error("No GPU devices detected"); + return vec![]; } - // Get first device info - // TODO: Get all devices - match nvml.device_by_index(0) { - Ok(device) => { - let name = device.name().unwrap_or_else(|_| "Unknown".to_string()); - let memory = device.memory_info().map(|m| m.total).unwrap_or(0); - let driver_version = nvml - .sys_driver_version() - .unwrap_or_else(|_| "Unknown".to_string()); + let mut device_map: std::collections::HashMap = + std::collections::HashMap::new(); + + for i in 0..device_count { + match nvml.device_by_index(i as u32) { + Ok(device) => { + let name = device.name().unwrap_or_else(|_| "Unknown".to_string()); + let memory = device.memory_info().map(|m| m.total).unwrap_or(0); + let driver_version = nvml + .sys_driver_version() + .unwrap_or_else(|_| "Unknown".to_string()); - GpuDevice::Available { - name, - memory, - driver_version, - device_count, + if let Some(existing_device) = device_map.get_mut(&name) { + existing_device.count += 1; + existing_device.indices.push(i as u32); + } else { + device_map.insert( + name.clone(), + GpuDevice { + name, + memory, + driver_version, + count: 1, + indices: vec![i as u32], + }, + ); + } + } + Err(e) => { + Console::user_error(&format!("Failed to get device {}: {}", i, e)); } } - Err(e) => GpuDevice::NotAvailable(format!("Failed to get device: {}", e)), } + + device_map.into_values().collect() } diff --git a/worker/src/checks/hardware/hardware_check.rs b/worker/src/checks/hardware/hardware_check.rs index 3d35d35b..dac3f807 100644 --- a/worker/src/checks/hardware/hardware_check.rs +++ b/worker/src/checks/hardware/hardware_check.rs @@ -141,11 +141,16 @@ impl HardwareChecker { } fn collect_gpu_specs(&self) -> Result, Box> { - Ok(detect_gpu().map(|gpu| GpuSpecs { - count: Some(gpu.count.unwrap_or(0)), - model: gpu.model, - memory_mb: gpu.memory_mb, - })) + let gpu_specs = detect_gpu(); + if gpu_specs.is_empty() { + return Ok(None); + } + + let main_gpu = gpu_specs + .into_iter() + .max_by_key(|gpu| gpu.count.unwrap_or(0)); + + Ok(main_gpu) } fn collect_memory_specs(&self) -> Result<(u32, u32), Box> { diff --git a/worker/src/cli/command.rs b/worker/src/cli/command.rs index a1e56283..f4af1111 100644 --- a/worker/src/cli/command.rs +++ b/worker/src/cli/command.rs @@ -347,14 +347,6 @@ pub async fn execute_command( } } - let has_gpu = match node_config.compute_specs { - Some(ref specs) => specs.gpu.is_some(), - None => { - Console::warning("Compute specs are not available, assuming no GPU."); - false - } - }; - let metrics_store = Arc::new(MetricsStore::new()); let heartbeat_metrics_clone = metrics_store.clone(); let bridge_contracts = contracts.clone(); @@ -380,9 +372,13 @@ pub async fn execute_command( .as_ref() .map(|specs| specs.ram_mb.unwrap_or(0)); + let gpu = node_config + .compute_specs + .clone() + .and_then(|specs| specs.gpu.clone()); let docker_service = Arc::new(DockerService::new( cancellation_token.clone(), - has_gpu, + gpu, system_memory, task_bridge.socket_path.clone(), docker_storage_path, diff --git a/worker/src/docker/docker_manager.rs b/worker/src/docker/docker_manager.rs index 96469d43..d5aefe51 100644 --- a/worker/src/docker/docker_manager.rs +++ b/worker/src/docker/docker_manager.rs @@ -11,6 +11,7 @@ use bollard::volume::CreateVolumeOptions; use bollard::Docker; use futures_util::StreamExt; use log::{debug, error, info}; +use shared::models::node::GpuSpecs; use std::collections::HashMap; use std::time::Duration; use strip_ansi_escapes::strip; @@ -104,7 +105,7 @@ impl DockerManager { name: &str, env_vars: Option>, command: Option>, - gpu_enabled: bool, + gpu: Option, // Simple Vec of (host_path, container_path, read_only) volumes: Option>, shm_size: Option, @@ -189,13 +190,25 @@ impl DockerManager { Some(binds) }; - let host_config = if gpu_enabled { + let host_config = if gpu.is_some() { + let gpu = gpu.unwrap(); + let device_ids = match &gpu.indices { + Some(indices) if !indices.is_empty() => { + // Use specific GPU indices if available + indices.iter().map(|i| i.to_string()).collect() + } + _ => { + // Request all available GPUs if no specific indices + vec!["all".to_string()] + } + }; + Some(HostConfig { extra_hosts: Some(vec!["host.docker.internal:host-gateway".into()]), device_requests: Some(vec![DeviceRequest { - driver: Some("".into()), - count: Some(-1), - device_ids: None, + driver: Some("nvidia".into()), + count: None, + device_ids: Some(device_ids), capabilities: Some(vec![vec!["gpu".into()]]), options: Some(HashMap::new()), }]), diff --git a/worker/src/docker/service.rs b/worker/src/docker/service.rs index 69f349de..7ca8cfaf 100644 --- a/worker/src/docker/service.rs +++ b/worker/src/docker/service.rs @@ -4,6 +4,7 @@ use super::DockerState; use crate::console::Console; use bollard::models::ContainerStateStatusEnum; use chrono::{DateTime, Utc}; +use shared::models::node::GpuSpecs; use shared::models::task::Task; use shared::models::task::TaskState; use std::collections::HashMap; @@ -17,7 +18,7 @@ pub struct DockerService { docker_manager: Arc, cancellation_token: CancellationToken, pub state: Arc, - has_gpu: bool, + gpu: Option, system_memory_mb: Option, task_bridge_socket_path: String, node_address: String, @@ -28,7 +29,7 @@ const TASK_PREFIX: &str = "prime-task"; impl DockerService { pub fn new( cancellation_token: CancellationToken, - has_gpu: bool, + gpu: Option, system_memory_mb: Option, task_bridge_socket_path: String, storage_path: Option, @@ -39,7 +40,7 @@ impl DockerService { docker_manager, cancellation_token, state: Arc::new(DockerState::new()), - has_gpu, + gpu, system_memory_mb, task_bridge_socket_path, node_address, @@ -151,7 +152,7 @@ impl DockerService { Console::info("DockerService", "Starting new container ..."); let manager_clone = manager_clone.clone(); let state_clone = task_state_clone.clone(); - let has_gpu = self.has_gpu; + let gpu = self.gpu.clone(); let system_memory_mb = self.system_memory_mb; let task_bridge_socket_path = self.task_bridge_socket_path.clone(); let node_address = self.node_address.clone(); @@ -195,7 +196,7 @@ impl DockerService { 67108864 // Default to 64MB in bytes } }; - match manager_clone.start_container(&payload.image, &container_task_id, Some(env_vars), Some(cmd), has_gpu, Some(volumes), Some(shm_size)).await { + match manager_clone.start_container(&payload.image, &container_task_id, Some(env_vars), Some(cmd), gpu, Some(volumes), Some(shm_size)).await { Ok(container_id) => { Console::info("DockerService", &format!("Container started with id: {}", container_id)); }, @@ -320,7 +321,7 @@ mod tests { let cancellation_token = CancellationToken::new(); let docker_service = DockerService::new( cancellation_token.clone(), - false, + None, Some(1024), "/tmp/com.prime.worker/metrics.sock".to_string(), None, @@ -365,7 +366,7 @@ mod tests { let cancellation_token = CancellationToken::new(); let docker_service = DockerService::new( cancellation_token.clone(), - false, + None, Some(1024), "/tmp/com.prime.worker/metrics.sock".to_string(), None,