diff --git a/discovery/src/api/routes/node.rs b/discovery/src/api/routes/node.rs index 1c676e85..62cceb2c 100644 --- a/discovery/src/api/routes/node.rs +++ b/discovery/src/api/routes/node.rs @@ -33,6 +33,24 @@ pub async fn register_node( return HttpResponse::BadRequest() .json(ApiResponse::new(false, "Invalid x-address header")); } + if let Some(contracts) = data.contracts.clone() { + if (contracts + .compute_registry + .get_node( + node.provider_address.parse().unwrap(), + node.id.parse().unwrap(), + ) + .await) + .is_err() + { + return HttpResponse::BadRequest().json(ApiResponse::new( + false, + "Node not found in compute registry", + )); + } + } + + let node_store = data.node_store.clone(); let update_node = node.clone(); let existing_node = data.node_store.get_node(update_node.id.clone()); diff --git a/discovery/src/chainsync/sync.rs b/discovery/src/chainsync/sync.rs index 002a1c96..76d874e6 100644 --- a/discovery/src/chainsync/sync.rs +++ b/discovery/src/chainsync/sync.rs @@ -33,7 +33,6 @@ impl ChainSync { last_chain_sync, } } - async fn sync_single_node( node_store: Arc, contracts: Arc, @@ -41,95 +40,103 @@ impl ChainSync { ) -> Result<(), Error> { let mut n = node.clone(); - // Safely parse provider_address and node_address - let provider_address = Address::from_str(&node.provider_address).map_err(|e| { - eprintln!("Failed to parse provider address: {}", e); - anyhow::anyhow!("Invalid provider address") - })?; - - let node_address = Address::from_str(&node.id).map_err(|e| { - eprintln!("Failed to parse node address: {}", e); - anyhow::anyhow!("Invalid node address") - })?; + async fn sync_single_node( + node_store: Arc, + contracts: Arc, + node: DiscoveryNode, + ) -> Result<(), Error> { + let mut n = node.clone(); - let node_info = contracts - .compute_registry - .get_node(provider_address, node_address) - .await - .map_err(|e| { - eprintln!("Error retrieving node info: {}", e); - anyhow::anyhow!("Failed to retrieve node info") + // Safely parse provider_address and node_address + let provider_address = Address::from_str(&node.provider_address).map_err(|e| { + eprintln!("Failed to parse provider address: {}", e); + anyhow::anyhow!("Invalid provider address") })?; - let provider_info = contracts - .compute_registry - .get_provider(provider_address) - .await - .map_err(|e| { - eprintln!("Error retrieving provider info: {}", e); - anyhow::anyhow!("Failed to retrieve provider info") + let node_address = Address::from_str(&node.id).map_err(|e| { + eprintln!("Failed to parse node address: {}", e); + anyhow::anyhow!("Invalid node address") })?; - let (is_active, is_validated) = node_info; - n.is_active = is_active; - n.is_validated = is_validated; - n.is_provider_whitelisted = provider_info.is_whitelisted; + let node_info = contracts + .compute_registry + .get_node(provider_address, node_address) + .await + .map_err(|e| { + eprintln!("Error retrieving node info: {}", e); + anyhow::anyhow!("Failed to retrieve node info") + })?; - // Handle potential errors from async calls - let is_blacklisted = contracts - .compute_pool - .is_node_blacklisted(node.node.compute_pool_id, node_address) - .await - .map_err(|e| { - eprintln!("Error checking if node is blacklisted: {}", e); - anyhow::anyhow!("Failed to check blacklist status") - })?; - n.is_blacklisted = is_blacklisted; - match node_store.update_node(n) { - Ok(_) => (), - Err(e) => { - error!("Error updating node: {}", e); + let provider_info = contracts + .compute_registry + .get_provider(provider_address) + .await + .map_err(|e| { + eprintln!("Error retrieving provider info: {}", e); + anyhow::anyhow!("Failed to retrieve provider info") + })?; + + let (is_active, is_validated) = node_info; + n.is_active = is_active; + n.is_validated = is_validated; + n.is_provider_whitelisted = provider_info.is_whitelisted; + + // Handle potential errors from async calls + let is_blacklisted = contracts + .compute_pool + .is_node_blacklisted(node.node.compute_pool_id, node_address) + .await + .map_err(|e| { + eprintln!("Error checking if node is blacklisted: {}", e); + anyhow::anyhow!("Failed to check blacklist status") + })?; + n.is_blacklisted = is_blacklisted; + match node_store.update_node(n) { + Ok(_) => (), + Err(e) => { + error!("Error updating node: {}", e); + } } - } - Ok(()) - } + Ok(()) + } - pub async fn run(&self) -> Result<(), Error> { - let node_store_clone = self.node_store.clone(); - let contracts_clone = self.contracts.clone(); - let cancel_token = self.cancel_token.clone(); - let chain_sync_interval = self.chain_sync_interval; - let last_chain_sync = self.last_chain_sync.clone(); + pub async fn run(&self) -> Result<(), Error> { + let node_store_clone = self.node_store.clone(); + let contracts_clone = self.contracts.clone(); + let cancel_token = self.cancel_token.clone(); + let chain_sync_interval = self.chain_sync_interval; + let last_chain_sync = self.last_chain_sync.clone(); - tokio::spawn(async move { - let mut interval = tokio::time::interval(chain_sync_interval); - loop { - tokio::select! { - _ = interval.tick() => { - let nodes = node_store_clone.get_nodes(); - match nodes { - Ok(nodes) => { - for node in nodes { - if let Err(e) = ChainSync::sync_single_node(node_store_clone.clone(), contracts_clone.clone(), node).await { - error!("Error syncing node: {}", e); + tokio::spawn(async move { + let mut interval = tokio::time::interval(chain_sync_interval); + loop { + tokio::select! { + _ = interval.tick() => { + let nodes = node_store_clone.get_nodes(); + match nodes { + Ok(nodes) => { + for node in nodes { + if let Err(e) = ChainSync::sync_single_node(node_store_clone.clone(), contracts_clone.clone(), node).await { + error!("Error syncing node: {}", e); + } } + // Update the last chain sync time + let mut last_sync = last_chain_sync.lock().await; + *last_sync = Some(SystemTime::now()); + } + Err(e) => { + error!("Error getting nodes: {}", e); } - // Update the last chain sync time - let mut last_sync = last_chain_sync.lock().await; - *last_sync = Some(SystemTime::now()); - } - Err(e) => { - error!("Error getting nodes: {}", e); } } - } - _ = cancel_token.cancelled() => { - break; + _ = cancel_token.cancelled() => { + break; + } } } - } - }); - Ok(()) + }); + Ok(()) + } } } diff --git a/orchestrator/src/discovery/monitor.rs b/orchestrator/src/discovery/monitor.rs index 9518da3f..ea5b12b5 100644 --- a/orchestrator/src/discovery/monitor.rs +++ b/orchestrator/src/discovery/monitor.rs @@ -206,6 +206,21 @@ impl<'b> DiscoveryMonitor<'b> { Ok(()) } + async fn get_nodes(&self) -> Result, Error> { + let discovery_nodes = self.fetch_nodes_from_discovery().await?; + + for discovery_node in &discovery_nodes { + if let Err(e) = self.sync_single_node_with_discovery(discovery_node).await { + error!("Error syncing node with discovery: {}", e); + } + } + Ok(discovery_nodes + .into_iter() + .map(OrchestratorNode::from) + .collect()) + } +} + async fn get_nodes(&self) -> Result, Error> { let discovery_nodes = self.fetch_nodes_from_discovery().await?; diff --git a/shared/src/web3/contracts/implementations/compute_pool_contract.rs b/shared/src/web3/contracts/implementations/compute_pool_contract.rs index 1e737a34..f6dd0b8a 100644 --- a/shared/src/web3/contracts/implementations/compute_pool_contract.rs +++ b/shared/src/web3/contracts/implementations/compute_pool_contract.rs @@ -168,6 +168,24 @@ impl ComputePool { Ok(result) } + pub async fn submit_work( + &self, + pool_id: U256, + node: Address, + data: Vec, + ) -> Result, Box> { + let result = self + .instance + .instance() + .function("submitWork", &[pool_id.into(), node.into(), data.into()])? + .send() + .await? + .watch() + .await?; + println!("Result: {:?}", result); + Ok(result) + } + pub async fn blacklist_node( &self, pool_id: u32, diff --git a/validator/src/main.rs b/validator/src/main.rs index d99b5f0d..ea2def4c 100644 --- a/validator/src/main.rs +++ b/validator/src/main.rs @@ -489,3 +489,56 @@ mod tests { assert_eq!(resp.result, expected_response.result); } } + +#[cfg(test)] +mod tests { + + use actix_web::{test, App}; + use actix_web::{ + web::{self, post}, + HttpResponse, Scope, + }; + use shared::models::challenge::{calc_matrix, ChallengeRequest, ChallengeResponse, FixedF64}; + + pub async fn handle_challenge(challenge: web::Json) -> HttpResponse { + let result = calc_matrix(&challenge); + HttpResponse::Ok().json(result) + } + + pub fn challenge_routes() -> Scope { + web::scope("/challenge") + .route("", post().to(handle_challenge)) + .route("/", post().to(handle_challenge)) + } + + #[actix_web::test] + async fn test_challenge_route() { + let app = test::init_service(App::new().service(challenge_routes())).await; + + let vec_a = [1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0]; + let vec_b = [9.0, 8.0, 7.0, 6.0, 5.0, 4.0, 3.0, 2.0, 1.0]; + + // convert vectors to FixedF64 + let data_a: Vec = vec_a.iter().map(|x| FixedF64(*x)).collect(); + let data_b: Vec = vec_b.iter().map(|x| FixedF64(*x)).collect(); + + let challenge_request = ChallengeRequest { + rows_a: 3, + cols_a: 3, + data_a, + rows_b: 3, + cols_b: 3, + data_b, + }; + + let req = test::TestRequest::post() + .uri("/challenge") + .set_json(&challenge_request) + .to_request(); + + let resp: ChallengeResponse = test::call_and_read_body_json(&app, req).await; + let expected_response = calc_matrix(&challenge_request); + + assert_eq!(resp.result, expected_response.result); + } +} diff --git a/worker/src/api/server.rs b/worker/src/api/server.rs index 86cf8684..9175c19b 100644 --- a/worker/src/api/server.rs +++ b/worker/src/api/server.rs @@ -1,6 +1,7 @@ use crate::api::routes::challenge::challenge_routes; use crate::api::routes::invite::invite_routes; use crate::api::routes::task::task_routes; +use crate::console::Console; use crate::docker::DockerService; use crate::operations::heartbeat::service::HeartbeatService; use crate::state::system_state::SystemState; diff --git a/worker/src/checks/hardware/hardware_check.rs b/worker/src/checks/hardware/hardware_check.rs index 3d35d35b..dac3f807 100644 --- a/worker/src/checks/hardware/hardware_check.rs +++ b/worker/src/checks/hardware/hardware_check.rs @@ -141,11 +141,16 @@ impl HardwareChecker { } fn collect_gpu_specs(&self) -> Result, Box> { - Ok(detect_gpu().map(|gpu| GpuSpecs { - count: Some(gpu.count.unwrap_or(0)), - model: gpu.model, - memory_mb: gpu.memory_mb, - })) + let gpu_specs = detect_gpu(); + if gpu_specs.is_empty() { + return Ok(None); + } + + let main_gpu = gpu_specs + .into_iter() + .max_by_key(|gpu| gpu.count.unwrap_or(0)); + + Ok(main_gpu) } fn collect_memory_specs(&self) -> Result<(u32, u32), Box> { diff --git a/worker/src/console/logger.rs b/worker/src/console/logger.rs new file mode 100644 index 00000000..7e4f5271 --- /dev/null +++ b/worker/src/console/logger.rs @@ -0,0 +1,81 @@ +use console::{style, Term}; +use std::cmp; +use unicode_width::UnicodeWidthStr; + +pub struct Console; + +impl Console { + /// Maximum content width for the box. + const MAX_WIDTH: usize = 80; + + /// Calculates the content width for boxes. + /// It uses the available terminal width (minus a margin) and caps it at MAX_WIDTH. + fn get_content_width() -> usize { + let term_width = Term::stdout().size().1 as usize; + // Leave a margin of 10 columns. + let available = if term_width > 10 { + term_width - 10 + } else { + term_width + }; + cmp::min(available, Self::MAX_WIDTH) + } + + /// Centers a given text within a given width based on its display width. + fn center_text(text: &str, width: usize) -> String { + let text_width = UnicodeWidthStr::width(text); + if width > text_width { + let total_padding = width - text_width; + let left = total_padding / 2; + let right = total_padding - left; + format!("{}{}{}", " ".repeat(left), text, " ".repeat(right)) + } else { + text.to_string() + } + } + + /// Prints a section header as an aligned box. + pub fn section(title: &str) { + let content_width = Self::get_content_width(); + let top_border = format!("╔{}╗", "═".repeat(content_width)); + let centered_title = Self::center_text(title, content_width); + let middle_line = format!("║{}║", centered_title); + let bottom_border = format!("╚{}╝", "═".repeat(content_width)); + + println!(); + println!("{}", style(top_border).white().bold()); + println!("{}", style(middle_line).white().bold()); + println!("{}", style(bottom_border).white().bold()); + } + + /// Prints a sub-title. + pub fn title(text: &str) { + println!(); + println!("{}", style(text).white().bold()); + } + + /// Prints an informational message. + pub fn info(label: &str, value: &str) { + println!("{}: {}", style(label).dim().white(), style(value).white()); + } + + /// Prints a success message. + pub fn success(text: &str) { + println!("{} {}", style("✓").green().bold(), style(text).green()); + } + + /// Prints a warning message. + pub fn warning(text: &str) { + println!("{} {}", style("⚠").yellow().bold(), style(text).yellow()); + } + + /// Prints an error message. + pub fn error(text: &str) { + println!("{} {}", style("✗").red().bold(), style(text).red()); + } + + /// Prints a progress message. + pub fn progress(text: &str) { + println!("{} {}", style("→").cyan().bold(), style(text).cyan()); + } +}