diff --git a/.github/workflows/checks.yml b/.github/workflows/checks.yml index 42ab51d3..f81163f4 100644 --- a/.github/workflows/checks.yml +++ b/.github/workflows/checks.yml @@ -18,7 +18,6 @@ env: LANG: C.UTF-8 LC_ALL: C.UTF-8 - jobs: check: name: Format & Lint @@ -57,4 +56,4 @@ jobs: if: success() || failure() run: | redis-server --version - RUST_BACKTRACE=1 cargo test -- --nocapture \ No newline at end of file + RUST_BACKTRACE=1 cargo test -- --nocapture diff --git a/Cargo.lock b/Cargo.lock index 8e7d544c..6e8a731e 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -8605,6 +8605,7 @@ dependencies = [ "thiserror 2.0.12", "time", "tokio", + "tokio-stream", "tokio-util", "toml", "tracing", diff --git a/crates/worker/Cargo.toml b/crates/worker/Cargo.toml index 506afc4c..91f48dd1 100644 --- a/crates/worker/Cargo.toml +++ b/crates/worker/Cargo.toml @@ -34,7 +34,7 @@ anyhow = { workspace = true } thiserror = "2.0.11" toml = { workspace = true } ctrlc = "3.4.5" -tokio-util = { workspace = true } +tokio-util = { workspace = true, features = ["rt"] } futures = { workspace = true } chrono = { workspace = true } serial_test = "0.5.1" @@ -55,4 +55,5 @@ iroh = { workspace = true } rand_v8 = { workspace = true } rand_core_v6 = { workspace = true } dashmap = "6.1.0" +tokio-stream = { version = "0.1.17", features = ["net"] } homedir = "0.3" diff --git a/crates/worker/src/cli/command.rs b/crates/worker/src/cli/command.rs index e54ce3f4..6a75330d 100644 --- a/crates/worker/src/cli/command.rs +++ b/crates/worker/src/cli/command.rs @@ -453,7 +453,7 @@ pub async fn execute_command( gpu, system_memory, task_bridge - .socket_path + .get_socket_path() .to_str() .expect("path is valid utf-8 string") .to_string(), @@ -469,11 +469,10 @@ pub async fn execute_command( let bridge_cancellation_token = cancellation_token.clone(); tokio::spawn(async move { - let bridge_clone = task_bridge.clone(); tokio::select! { _ = bridge_cancellation_token.cancelled() => { } - _ = bridge_clone.run() => { + _ = task_bridge.run() => { } } }); diff --git a/crates/worker/src/docker/taskbridge/bridge.rs b/crates/worker/src/docker/taskbridge/bridge.rs index 6faa5a83..9a461b43 100644 --- a/crates/worker/src/docker/taskbridge/bridge.rs +++ b/crates/worker/src/docker/taskbridge/bridge.rs @@ -2,13 +2,19 @@ use crate::docker::taskbridge::file_handler; use crate::docker::taskbridge::json_helper; use crate::metrics::store::MetricsStore; use crate::state::system_state::SystemState; +use anyhow::bail; use anyhow::Result; +use futures::future::BoxFuture; +use futures::stream::FuturesUnordered; +use futures::FutureExt; +use futures::StreamExt as _; use log::{debug, error, info, warn}; use serde::{Deserialize, Serialize}; use shared::models::node::Node; use shared::web3::contracts::core::builder::Contracts; use shared::web3::wallet::Wallet; use shared::web3::wallet::WalletProvider; +use std::collections::HashSet; #[cfg(unix)] use std::os::unix::fs::PermissionsExt; use std::sync::Arc; @@ -19,13 +25,21 @@ use tokio::{io::BufReader, net::UnixListener}; const DEFAULT_SOCKET_FILE: &str = "prime-worker/com.prime.worker/metrics.sock"; pub struct TaskBridge { - pub socket_path: std::path::PathBuf, - pub metrics_store: Arc, - pub contracts: Option>, - pub node_config: Option, - pub node_wallet: Option, - pub docker_storage_path: String, - pub state: Arc, + socket_path: std::path::PathBuf, + config: TaskBridgeConfig, +} + +#[derive(Clone)] +struct TaskBridgeConfig { + metrics_store: Arc, + + // TODO: the optional values are only used for testing; refactor + // the tests such that these aren't optional + contracts: Option>, + node_config: Option, + node_wallet: Option, + docker_storage_path: String, + state: Arc, } #[derive(Deserialize, Serialize, Debug)] @@ -37,7 +51,7 @@ struct MetricInput { impl TaskBridge { #[allow(clippy::too_many_arguments)] - pub fn new( + pub(crate) fn new( socket_path: Option<&str>, metrics_store: Arc, contracts: Option>, @@ -45,7 +59,7 @@ impl TaskBridge { node_wallet: Option, docker_storage_path: String, state: Arc, - ) -> Result> { + ) -> Result { let path = match socket_path { Some(path) => std::path::PathBuf::from(path), None => { @@ -55,144 +69,30 @@ impl TaskBridge { } }; - Ok(Arc::new(Self { + Ok(Self { socket_path: path, - metrics_store, - contracts, - node_config, - node_wallet, - docker_storage_path, - state, - })) - } - - async fn handle_metric(self: Arc, input: &MetricInput) -> Result<()> { - debug!("Processing metric message"); - for (key, value) in input.metrics.iter() { - debug!("Metric - Key: {key}, Value: {value}"); - let _ = self - .metrics_store - .update_metric( - input.task_id.clone(), - key.to_string(), - value.as_f64().unwrap_or(0.0), - ) - .await; - } - Ok(()) + config: TaskBridgeConfig { + metrics_store, + contracts, + node_config, + node_wallet, + docker_storage_path, + state, + }, + }) } - fn handle_file_upload(self: Arc, json_str: &str) -> Result<()> { - debug!("Handling file upload"); - if let Ok(file_info) = serde_json::from_str::(json_str) { - let task_id = file_info["task_id"].as_str().unwrap_or("unknown"); - - // Handle file upload if save_path is present - if let Some(file_name) = file_info["output/save_path"].as_str() { - info!("Handling file upload for task_id: {task_id}, file: {file_name}"); - - let storage_path_inner = self.docker_storage_path.clone(); - let task_id_inner = task_id.to_string(); - let file_name_inner = file_name.to_string(); - let wallet_inner = self.node_wallet.as_ref().unwrap().clone(); - let state_inner = self.state.clone(); - - tokio::spawn(async move { - if let Err(e) = file_handler::handle_file_upload( - &storage_path_inner, - &task_id_inner, - &file_name_inner, - &wallet_inner, - &state_inner, - ) - .await - { - error!("Failed to handle file upload: {e}"); - } else { - info!("File upload handled successfully"); - } - }); - } - - // Handle file validation if sha256 is present - if let Some(file_sha) = file_info["output/sha256"].as_str() { - debug!("Processing file validation message"); - let output_flops: f64 = file_info["output/output_flops"].as_f64().unwrap_or(0.0); - let input_flops: f64 = file_info["output/input_flops"].as_f64().unwrap_or(0.0); - - info!( - "Handling file validation for task_id: {task_id}, sha: {file_sha}, output_flops: {output_flops}, input_flops: {input_flops}" - ); - if let (Some(contracts_ref), Some(node_ref)) = - (self.contracts.clone(), self.node_config.clone()) - { - let file_sha_inner = file_sha.to_string(); - let contracts_inner = contracts_ref.clone(); - let node_inner = node_ref.clone(); - let provider = match self.node_wallet.as_ref() { - Some(wallet) => wallet.provider(), - None => { - error!("No wallet provider found"); - return Err(anyhow::anyhow!("No wallet provider found")); - } - }; - - if output_flops <= 0.0 { - error!("Invalid work units calculation: output_flops ({output_flops}) must be greater than 0.0. Blocking file validation submission."); - return Err(anyhow::anyhow!( - "Invalid work units: output_flops must be greater than 0.0" - )); - } - let work_units = output_flops; - - tokio::spawn(async move { - if let Err(e) = file_handler::handle_file_validation( - &file_sha_inner, - &contracts_inner, - &node_inner, - &provider, - work_units, - ) - .await - { - error!("Failed to handle file validation: {e}"); - } - }); - } else { - error!("Missing contracts or node configuration for file validation"); - } - } - } else { - error!("Failed to parse JSON: {json_str}"); - } - Ok(()) + pub(crate) fn get_socket_path(&self) -> &std::path::Path { + &self.socket_path } - async fn handle_message(self: Arc, json_str: &str) -> Result<()> { - debug!("Extracted JSON object: {json_str}"); - if json_str.contains("output/save_path") { - if let Err(e) = self.handle_file_upload(json_str) { - error!("Failed to handle file upload: {e}"); - } - } else { - debug!("Processing metric message"); - match serde_json::from_str::(json_str) { - Ok(input) => { - if let Err(e) = self.handle_metric(&input).await { - error!("Failed to handle metric: {e}"); - } - } - Err(e) => { - error!("Failed to parse metric input: {json_str} {e}"); - } - } - } - - Ok(()) - } + pub(crate) async fn run(self) -> Result<()> { + let Self { + socket_path, + config, + } = self; - pub async fn run(self: Arc) -> Result<()> { - let socket_path = Path::new(&self.socket_path); + let socket_path = Path::new(&socket_path); debug!("Setting up TaskBridge socket at: {}", socket_path.display()); if let Some(parent) = socket_path.parent() { @@ -240,91 +140,315 @@ impl TaskBridge { } } info!("TaskBridge socket created at: {}", socket_path.display()); + + let mut handle_stream_futures = FuturesUnordered::new(); + let mut listener_stream = tokio_stream::wrappers::UnixListenerStream::new(listener); + let (file_validation_futures_tx, mut file_validation_futures_rx) = + tokio::sync::mpsc::channel::<(String, BoxFuture>)>(100); + let mut file_validation_futures_set = HashSet::new(); + let mut file_validation_futures = FuturesUnordered::new(); + let (file_upload_futures_tx, mut file_upload_futures_rx) = + tokio::sync::mpsc::channel::>>(100); + let mut file_upload_futures_set = FuturesUnordered::new(); + loop { - let bridge = self.clone(); - match listener.accept().await { - Ok((stream, _addr)) => { - tokio::spawn(async move { - debug!("Received connection from {_addr:?}"); - let mut reader = BufReader::new(stream); - let mut buffer = vec![0; 1024]; - let mut data = Vec::new(); - - loop { - let n = match reader.read(&mut buffer).await { - Ok(0) => { - debug!("Connection closed by client"); - 0 - } - Ok(n) => { - debug!("Read {n} bytes from socket"); - n - } - Err(e) => { - error!("Error reading from stream: {e}"); - break; - } - }; - - data.extend_from_slice(&buffer[..n]); - debug!("Current data buffer size: {} bytes", data.len()); - - if let Ok(data_str) = std::str::from_utf8(&data) { - debug!("Raw data received: {data_str}"); - } else { - debug!("Raw data received (non-UTF8): {} bytes", data.len()); - } - - let mut current_pos = 0; - while current_pos < data.len() { - // Try to find a complete JSON object - if let Some((json_str, byte_length)) = - json_helper::extract_next_json(&data[current_pos..]) - { - let json_str = json_str.to_string(); - let bridge_clone = bridge.clone(); - if let Err(e) = bridge_clone.handle_message(&json_str).await { - error!("Error handling message: {e}"); - } - - current_pos += byte_length; - debug!( - "Advanced position to {current_pos} after processing JSON" - ); - } else { - debug!("No complete JSON object found, waiting for more data"); - break; - } - } - - data = data.split_off(current_pos); - debug!( - "Remaining data buffer size after processing: {} bytes", - data.len() - ); - if n == 0 { - if data.is_empty() { - // No data left to process, we can break - break; - } else { - // We have data but couldn't parse it as complete JSON objects - // and the connection is closed - log and discard - if let Ok(unparsed) = std::str::from_utf8(&data) { - warn!("Discarding unparseable data after connection close: {unparsed}"); - } else { - warn!("Discarding unparseable binary data after connection close ({} bytes)", data.len()); - } - // Break out of the loop - break; - } - } + tokio::select! { + Some(res) = listener_stream.next() => { + match res { + Ok(stream) => { + let handle_future = handle_stream(config.clone(), stream, file_upload_futures_tx.clone(), file_validation_futures_tx.clone()).fuse(); + handle_stream_futures.push(tokio::task::spawn(handle_future)); + } + Err(e) => { + error!("Accept failed on Unix socket: {e}"); + } + } + } + Some(res) = handle_stream_futures.next() => { + match res { + Ok(Ok(())) => { + debug!("Stream handler completed successfully"); + } + Ok(Err(e)) => { + error!("Stream handler failed: {e}"); + } + Err(e) => { + error!("Error joining stream handler task: {e}"); + } + } + } + Some((hash, fut)) = file_validation_futures_rx.recv() => { + if file_validation_futures_set.contains(&hash) { + debug!("duplicate file validation task for hash: {hash}, skipping"); + continue; + } + // we never remove hashes from this set, as we should never + // submit the same file for validation twice. + file_validation_futures_set.insert(hash.clone()); + file_validation_futures.push(async move {(hash, tokio::task::spawn(fut).await)}); + } + Some((hash, res)) = file_validation_futures.next() => { + match res { + Ok(Ok(())) => { + debug!("File validation task for hash {hash} completed successfully"); + } + Ok(Err(e)) => { + error!("File validation task for hash {hash} failed: {e}"); + } + Err(e) => { + error!("Error joining file validation task for hash {hash}: {e}"); + } + } + } + Some(fut) = file_upload_futures_rx.recv() => { + file_upload_futures_set.push(tokio::task::spawn(fut)); + } + Some(res) = file_upload_futures_set.next() => { + match res { + Ok(Ok(())) => { + debug!("File upload task completed successfully"); } - }); + Ok(Err(e)) => { + error!("File upload task failed: {e}"); + } + Err(e) => { + error!("Error joining file upload task: {e}"); + } + } + } + } + } + } +} + +async fn handle_stream( + config: TaskBridgeConfig, + stream: tokio::net::UnixStream, + file_upload_futures_tx: tokio::sync::mpsc::Sender>>, + file_validation_futures_tx: tokio::sync::mpsc::Sender<( + String, + BoxFuture<'_, anyhow::Result<()>>, + )>, +) -> Result<()> { + let addr = stream.peer_addr()?; + debug!("Received connection from {addr:?}"); + let mut reader = BufReader::new(stream); + let mut buffer = vec![0; 1024]; + let mut data = Vec::new(); + + loop { + let n = match reader.read(&mut buffer).await { + Ok(0) => { + debug!("Connection closed by client"); + 0 + } + Ok(n) => { + debug!("Read {n} bytes from socket"); + n + } + Err(e) => { + bail!("Error reading from stream: {e}"); + } + }; + + data.extend_from_slice(&buffer[..n]); + debug!("Current data buffer size: {} bytes", data.len()); + + if let Ok(data_str) = std::str::from_utf8(&data) { + debug!("Raw data received: {data_str}"); + } else { + debug!("Raw data received (non-UTF8): {} bytes", data.len()); + } + + let mut current_pos = 0; + while current_pos < data.len() { + // Try to find a complete JSON object + if let Some((json_str, byte_length)) = + json_helper::extract_next_json(&data[current_pos..]) + { + let json_str = json_str.to_string(); + if let Err(e) = handle_message( + config.clone(), + &json_str, + file_upload_futures_tx.clone(), + file_validation_futures_tx.clone(), + ) + .await + { + error!("Error handling message: {e}"); + } + + current_pos += byte_length; + debug!("Advanced position to {current_pos} after processing JSON"); + } else { + debug!("No complete JSON object found, waiting for more data"); + break; + } + } + + data = data.split_off(current_pos); + debug!( + "Remaining data buffer size after processing: {} bytes", + data.len() + ); + if n == 0 { + if data.is_empty() { + // No data left to process, we can break + break; + } else { + // We have data but couldn't parse it as complete JSON objects + // and the connection is closed - log and discard + if let Ok(unparsed) = std::str::from_utf8(&data) { + warn!("Discarding unparseable data after connection close: {unparsed}"); + } else { + warn!( + "Discarding unparseable binary data after connection close ({} bytes)", + data.len() + ); + } + // Break out of the loop + break; + } + } + } + Ok(()) +} + +async fn handle_metric(config: TaskBridgeConfig, input: &MetricInput) -> Result<()> { + debug!("Processing metric message"); + for (key, value) in input.metrics.iter() { + debug!("Metric - Key: {key}, Value: {value}"); + let _ = config + .metrics_store + .update_metric( + input.task_id.clone(), + key.to_string(), + value.as_f64().unwrap_or(0.0), + ) + .await; + } + Ok(()) +} + +async fn handle_file_upload( + config: TaskBridgeConfig, + json_str: &str, + file_upload_futures_tx: tokio::sync::mpsc::Sender>>, + file_validation_futures_tx: tokio::sync::mpsc::Sender<( + String, + BoxFuture<'_, anyhow::Result<()>>, + )>, +) -> Result<()> { + debug!("Handling file upload"); + if let Ok(file_info) = serde_json::from_str::(json_str) { + let task_id = file_info["task_id"].as_str().unwrap_or("unknown"); + + // Handle file upload if save_path is present + if let Some(file_name) = file_info["output/save_path"].as_str() { + info!("Handling file upload for task_id: {task_id}, file: {file_name}"); + + let Some(wallet) = config.node_wallet.as_ref() else { + bail!("no wallet found; must be set to upload files"); + }; + + let _ = file_upload_futures_tx + .send(Box::pin(file_handler::handle_file_upload( + config.docker_storage_path.clone(), + task_id.to_string(), + file_name.to_string(), + wallet.clone(), + config.state.clone(), + ))) + .await; + } + + // Handle file validation if sha256 is present + if let Some(file_sha) = file_info["output/sha256"].as_str() { + debug!("Processing file validation message"); + let output_flops: f64 = file_info["output/output_flops"].as_f64().unwrap_or(0.0); + let input_flops: f64 = file_info["output/input_flops"].as_f64().unwrap_or(0.0); + + info!( + "Handling file validation for task_id: {task_id}, sha: {file_sha}, output_flops: {output_flops}, input_flops: {input_flops}" + ); + + if let (Some(contracts), Some(node)) = + (config.contracts.clone(), config.node_config.clone()) + { + let provider = match config.node_wallet.as_ref() { + Some(wallet) => wallet.provider(), + None => { + error!("No wallet provider found"); + return Err(anyhow::anyhow!("No wallet provider found")); + } + }; + + if output_flops <= 0.0 { + error!("Invalid work units calculation: output_flops ({output_flops}) must be greater than 0.0. Blocking file validation submission."); + return Err(anyhow::anyhow!( + "Invalid work units: output_flops must be greater than 0.0" + )); + } + let work_units = output_flops; + + let _ = file_validation_futures_tx + .send(( + file_sha.to_string(), + Box::pin(file_handler::handle_file_validation( + file_sha.to_string(), + contracts.clone(), + node.clone(), + provider, + work_units, + )), + )) + .await; + } else { + error!("Missing contracts or node configuration for file validation"); + } + } + } else { + error!("Failed to parse JSON: {json_str}"); + } + Ok(()) +} + +async fn handle_message( + config: TaskBridgeConfig, + json_str: &str, + file_upload_futures_tx: tokio::sync::mpsc::Sender>>, + file_validation_futures_tx: tokio::sync::mpsc::Sender<( + String, + BoxFuture<'_, anyhow::Result<()>>, + )>, +) -> Result<()> { + debug!("Extracted JSON object: {json_str}"); + if json_str.contains("output/save_path") { + if let Err(e) = handle_file_upload( + config, + json_str, + file_upload_futures_tx, + file_validation_futures_tx, + ) + .await + { + error!("Failed to handle file upload: {e}"); + } + } else { + debug!("Processing metric message"); + match serde_json::from_str::(json_str) { + Ok(input) => { + if let Err(e) = handle_metric(config, &input).await { + error!("Failed to handle metric: {e}"); } - Err(e) => error!("Accept failed on Unix socket: {e}"), + } + Err(e) => { + error!("Failed to parse metric input: {json_str} {e}"); } } } + + Ok(()) } #[cfg(test)] diff --git a/crates/worker/src/docker/taskbridge/file_handler.rs b/crates/worker/src/docker/taskbridge/file_handler.rs index 857c04ae..82d856b6 100644 --- a/crates/worker/src/docker/taskbridge/file_handler.rs +++ b/crates/worker/src/docker/taskbridge/file_handler.rs @@ -17,11 +17,11 @@ use std::time::Duration; /// Handles a file upload request pub async fn handle_file_upload( - storage_path: &str, - task_id: &str, - file_name: &str, - wallet: &Wallet, - state: &Arc, + storage_path: String, + task_id: String, + file_name: String, + wallet: Wallet, + state: Arc, ) -> Result<()> { info!("📄 Received file upload request: {file_name}"); info!("Task ID: {task_id}, Storage path: {storage_path}"); @@ -44,7 +44,7 @@ pub async fn handle_file_upload( info!("Clean file name: {clean_file_name}"); let task_dir = format!("prime-task-{task_id}"); - let file_path = Path::new(storage_path) + let file_path = Path::new(&storage_path) .join(&task_dir) .join("data") .join(clean_file_name); @@ -122,7 +122,7 @@ pub async fn handle_file_upload( }; let signature = - match sign_request_with_nonce("/storage/request-upload", wallet, Some(&request_value)) + match sign_request_with_nonce("/storage/request-upload", &wallet, Some(&request_value)) .await { Ok(sig) => { @@ -313,10 +313,10 @@ pub async fn handle_file_upload( /// Handles a file validation request pub async fn handle_file_validation( - file_sha: &str, - contracts: &Contracts, - node: &Node, - provider: &WalletProvider, + file_sha: String, + contracts: Contracts, + node: Node, + provider: WalletProvider, work_units: f64, ) -> Result<()> { info!("📄 Received file SHA for validation: {file_sha}");