Skip to content
This repository was archived by the owner on Jan 27, 2026. It is now read-only.
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions discovery/src/api/routes/node.rs
Original file line number Diff line number Diff line change
Expand Up @@ -220,6 +220,7 @@ mod tests {
count: Some(4),
model: Some("A100".to_string()),
memory_mb: Some(40000),
indices: Some(vec![0, 1, 2, 3]),
}),
cpu: Some(CpuSpecs {
cores: Some(16),
Expand Down
3 changes: 3 additions & 0 deletions shared/src/models/node.rs
Original file line number Diff line number Diff line change
Expand Up @@ -66,6 +66,7 @@ pub struct GpuSpecs {
pub count: Option<u32>,
pub model: Option<String>,
pub memory_mb: Option<u32>,
pub indices: Option<Vec<u32>>,
}

impl fmt::Display for GpuSpecs {
Expand Down Expand Up @@ -445,6 +446,7 @@ mod tests {
count: gpu_count,
model: gpu_model.map(String::from),
memory_mb: gpu_mem,
indices: None,
})
} else {
None
Expand Down Expand Up @@ -744,6 +746,7 @@ mod tests {
count: Some(4),
model: Some("A100".to_string()),
memory_mb: None,
indices: None,
}),
cpu: Some(CpuSpecs {
cores: Some(16),
Expand Down
117 changes: 66 additions & 51 deletions worker/src/checks/hardware/gpu.rs
Original file line number Diff line number Diff line change
Expand Up @@ -5,52 +5,44 @@ use shared::models::node::GpuSpecs;
use std::sync::Mutex;

#[allow(dead_code)]
const BYTES_TO_GB: f64 = 1024.0 * 1024.0 * 1024.0;
const BYTES_TO_MB: u64 = 1024 * 1024;

// Use lazy_static to initialize NVML once and reuse it
lazy_static! {
static ref NVML: Mutex<Option<Nvml>> = Mutex::new(None);
}

#[derive(Debug)]
#[allow(dead_code)]
enum GpuDevice {
Available {
name: String,
memory: u64,
driver_version: String,
device_count: usize,
},
NotAvailable(String),
struct GpuDevice {
name: String,
memory: u64,
driver_version: String,
count: u32,
indices: Vec<u32>,
}

pub fn detect_gpu() -> Option<GpuSpecs> {
pub fn detect_gpu() -> Vec<GpuSpecs> {
Console::title("GPU Detection");
// Changed return type to GpuSpecs
match get_gpu_status() {
GpuDevice::Available {
name,
memory,
driver_version: _,
device_count,
} => Some(GpuSpecs {
// Create GpuSpecs directly
count: Some(device_count as u32),
model: Some(
name.to_lowercase()
.split_whitespace()
.collect::<Vec<&str>>()
.join("_"),
),
memory_mb: Some((memory / 1024 / 1024) as u32), // Convert bytes to MB
}),
GpuDevice::NotAvailable(_) => {
Console::user_error("GPU not available");
None
}

let gpu_devices = get_gpu_status();
if gpu_devices.is_empty() {
Console::user_error("No GPU devices detected");
return vec![];
}

gpu_devices
.into_iter()
.map(|device| GpuSpecs {
count: Some(device.count),
model: Some(device.name.to_lowercase()),
memory_mb: Some((device.memory / BYTES_TO_MB) as u32),
indices: Some(device.indices),
})
.collect()
}

fn get_gpu_status() -> GpuDevice {
fn get_gpu_status() -> Vec<GpuDevice> {
let mut nvml_guard = NVML.lock().unwrap();

// Initialize NVML if not already initialized
Expand All @@ -62,7 +54,10 @@ fn get_gpu_status() -> GpuDevice {
.init()
{
Ok(nvml) => *nvml_guard = Some(nvml),
Err(e) => return GpuDevice::NotAvailable(format!("Failed to initialize NVML: {}", e)),
Err(e) => {
Console::user_error(&format!("Failed to initialize NVML: {}", e));
return vec![];
}
}
}

Expand All @@ -71,30 +66,50 @@ fn get_gpu_status() -> GpuDevice {
// Get device count
let device_count = match nvml.device_count() {
Ok(count) => count as usize,
Err(e) => return GpuDevice::NotAvailable(format!("Failed to get device count: {}", e)),
Err(e) => {
Console::user_error(&format!("Failed to get device count: {}", e));
return vec![];
}
};

if device_count == 0 {
return GpuDevice::NotAvailable("No GPU devices detected".to_string());
Console::user_error("No GPU devices detected");
return vec![];
}

// Get first device info
// TODO: Get all devices
match nvml.device_by_index(0) {
Ok(device) => {
let name = device.name().unwrap_or_else(|_| "Unknown".to_string());
let memory = device.memory_info().map(|m| m.total).unwrap_or(0);
let driver_version = nvml
.sys_driver_version()
.unwrap_or_else(|_| "Unknown".to_string());
let mut device_map: std::collections::HashMap<String, GpuDevice> =
std::collections::HashMap::new();

for i in 0..device_count {
match nvml.device_by_index(i as u32) {
Ok(device) => {
let name = device.name().unwrap_or_else(|_| "Unknown".to_string());
let memory = device.memory_info().map(|m| m.total).unwrap_or(0);
let driver_version = nvml
.sys_driver_version()
.unwrap_or_else(|_| "Unknown".to_string());

GpuDevice::Available {
name,
memory,
driver_version,
device_count,
if let Some(existing_device) = device_map.get_mut(&name) {
existing_device.count += 1;
existing_device.indices.push(i as u32);
} else {
device_map.insert(
name.clone(),
GpuDevice {
name,
memory,
driver_version,
count: 1,
indices: vec![i as u32],
},
);
}
Comment on lines +86 to +106
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I looked up a few SMI examples, and it seems the "name" field is usually unique to a product, but this isn't guaranteed. For example, with VRAM-modded cards (which might become more popular in the future), grouping devices by name alone without checking the available VRAM could cause errors. It might be worthwhile to append the memory size as a string to the name and use that combined value as the key.

}
Err(e) => {
Console::user_error(&format!("Failed to get device {}: {}", i, e));
}
}
Err(e) => GpuDevice::NotAvailable(format!("Failed to get device: {}", e)),
}

device_map.into_values().collect()
}
15 changes: 10 additions & 5 deletions worker/src/checks/hardware/hardware_check.rs
Original file line number Diff line number Diff line change
Expand Up @@ -141,11 +141,16 @@ impl HardwareChecker {
}

fn collect_gpu_specs(&self) -> Result<Option<GpuSpecs>, Box<dyn std::error::Error>> {
Ok(detect_gpu().map(|gpu| GpuSpecs {
count: Some(gpu.count.unwrap_or(0)),
model: gpu.model,
memory_mb: gpu.memory_mb,
}))
let gpu_specs = detect_gpu();
if gpu_specs.is_empty() {
return Ok(None);
}

let main_gpu = gpu_specs
.into_iter()
.max_by_key(|gpu| gpu.count.unwrap_or(0));

Ok(main_gpu)
}

fn collect_memory_specs(&self) -> Result<(u32, u32), Box<dyn std::error::Error>> {
Expand Down
14 changes: 5 additions & 9 deletions worker/src/cli/command.rs
Original file line number Diff line number Diff line change
Expand Up @@ -347,14 +347,6 @@ pub async fn execute_command(
}
}

let has_gpu = match node_config.compute_specs {
Some(ref specs) => specs.gpu.is_some(),
None => {
Console::warning("Compute specs are not available, assuming no GPU.");
false
}
};

let metrics_store = Arc::new(MetricsStore::new());
let heartbeat_metrics_clone = metrics_store.clone();
let bridge_contracts = contracts.clone();
Expand All @@ -380,9 +372,13 @@ pub async fn execute_command(
.as_ref()
.map(|specs| specs.ram_mb.unwrap_or(0));

let gpu = node_config
.compute_specs
.clone()
.and_then(|specs| specs.gpu.clone());
let docker_service = Arc::new(DockerService::new(
cancellation_token.clone(),
has_gpu,
gpu,
system_memory,
task_bridge.socket_path.clone(),
docker_storage_path,
Expand Down
23 changes: 18 additions & 5 deletions worker/src/docker/docker_manager.rs
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@ use bollard::volume::CreateVolumeOptions;
use bollard::Docker;
use futures_util::StreamExt;
use log::{debug, error, info};
use shared::models::node::GpuSpecs;
use std::collections::HashMap;
use std::time::Duration;
use strip_ansi_escapes::strip;
Expand Down Expand Up @@ -104,7 +105,7 @@ impl DockerManager {
name: &str,
env_vars: Option<HashMap<String, String>>,
command: Option<Vec<String>>,
gpu_enabled: bool,
gpu: Option<GpuSpecs>,
// Simple Vec of (host_path, container_path, read_only)
volumes: Option<Vec<(String, String, bool)>>,
shm_size: Option<u64>,
Expand Down Expand Up @@ -189,13 +190,25 @@ impl DockerManager {
Some(binds)
};

let host_config = if gpu_enabled {
let host_config = if gpu.is_some() {
let gpu = gpu.unwrap();
let device_ids = match &gpu.indices {
Some(indices) if !indices.is_empty() => {
// Use specific GPU indices if available
indices.iter().map(|i| i.to_string()).collect()
}
_ => {
// Request all available GPUs if no specific indices
vec!["all".to_string()]
}
};

Some(HostConfig {
extra_hosts: Some(vec!["host.docker.internal:host-gateway".into()]),
device_requests: Some(vec![DeviceRequest {
driver: Some("".into()),
count: Some(-1),
device_ids: None,
driver: Some("nvidia".into()),
count: None,
device_ids: Some(device_ids),
capabilities: Some(vec![vec!["gpu".into()]]),
options: Some(HashMap::new()),
}]),
Expand Down
15 changes: 8 additions & 7 deletions worker/src/docker/service.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ use super::DockerState;
use crate::console::Console;
use bollard::models::ContainerStateStatusEnum;
use chrono::{DateTime, Utc};
use shared::models::node::GpuSpecs;
use shared::models::task::Task;
use shared::models::task::TaskState;
use std::collections::HashMap;
Expand All @@ -17,7 +18,7 @@ pub struct DockerService {
docker_manager: Arc<DockerManager>,
cancellation_token: CancellationToken,
pub state: Arc<DockerState>,
has_gpu: bool,
gpu: Option<GpuSpecs>,
system_memory_mb: Option<u32>,
task_bridge_socket_path: String,
node_address: String,
Expand All @@ -28,7 +29,7 @@ const TASK_PREFIX: &str = "prime-task";
impl DockerService {
pub fn new(
cancellation_token: CancellationToken,
has_gpu: bool,
gpu: Option<GpuSpecs>,
system_memory_mb: Option<u32>,
task_bridge_socket_path: String,
storage_path: Option<String>,
Expand All @@ -39,7 +40,7 @@ impl DockerService {
docker_manager,
cancellation_token,
state: Arc::new(DockerState::new()),
has_gpu,
gpu,
system_memory_mb,
task_bridge_socket_path,
node_address,
Expand Down Expand Up @@ -151,7 +152,7 @@ impl DockerService {
Console::info("DockerService", "Starting new container ...");
let manager_clone = manager_clone.clone();
let state_clone = task_state_clone.clone();
let has_gpu = self.has_gpu;
let gpu = self.gpu.clone();
let system_memory_mb = self.system_memory_mb;
let task_bridge_socket_path = self.task_bridge_socket_path.clone();
let node_address = self.node_address.clone();
Expand Down Expand Up @@ -195,7 +196,7 @@ impl DockerService {
67108864 // Default to 64MB in bytes
}
};
match manager_clone.start_container(&payload.image, &container_task_id, Some(env_vars), Some(cmd), has_gpu, Some(volumes), Some(shm_size)).await {
match manager_clone.start_container(&payload.image, &container_task_id, Some(env_vars), Some(cmd), gpu, Some(volumes), Some(shm_size)).await {
Ok(container_id) => {
Console::info("DockerService", &format!("Container started with id: {}", container_id));
},
Expand Down Expand Up @@ -320,7 +321,7 @@ mod tests {
let cancellation_token = CancellationToken::new();
let docker_service = DockerService::new(
cancellation_token.clone(),
false,
None,
Some(1024),
"/tmp/com.prime.worker/metrics.sock".to_string(),
None,
Expand Down Expand Up @@ -365,7 +366,7 @@ mod tests {
let cancellation_token = CancellationToken::new();
let docker_service = DockerService::new(
cancellation_token.clone(),
false,
None,
Some(1024),
"/tmp/com.prime.worker/metrics.sock".to_string(),
None,
Expand Down