diff --git a/backends/src/lib.rs b/backends/src/lib.rs index 245715b3..f4d1b300 100644 --- a/backends/src/lib.rs +++ b/backends/src/lib.rs @@ -364,6 +364,7 @@ async fn init_backend( otlp_service_name: String, ) -> Result, BackendError> { let mut backend_start_failed = false; + let api_repo = api_repo.map(Arc::new); if cfg!(feature = "ort") { #[cfg(feature = "ort")] @@ -409,7 +410,7 @@ async fn init_backend( if let Some(api_repo) = api_repo.as_ref() { if cfg!(feature = "python") || cfg!(feature = "candle") { let start = std::time::Instant::now(); - if download_safetensors(api_repo).await.is_err() { + if download_safetensors(api_repo.clone()).await.is_err() { tracing::warn!("safetensors weights not found. Using `pytorch_model.bin` instead. Model loading will be significantly slower."); tracing::info!("Downloading `pytorch_model.bin`"); api_repo @@ -581,7 +582,7 @@ enum BackendCommand { ), } -async fn download_safetensors(api: &ApiRepo) -> Result, ApiError> { +async fn download_safetensors(api: Arc) -> Result, ApiError> { // Single file tracing::info!("Downloading `model.safetensors`"); match api.get("model.safetensors").await { @@ -611,10 +612,20 @@ async fn download_safetensors(api: &ApiRepo) -> Result, ApiError> { } // Download weight files - let mut safetensors_files = Vec::new(); + let mut handles = Vec::with_capacity(safetensors_filenames.len()); for n in safetensors_filenames { - tracing::info!("Downloading `{}`", n); - safetensors_files.push(api.get(&n).await?); + let api = Arc::clone(&api); + handles.push(tokio::spawn(async move { + tracing::info!("Downloading `{}`", n); + api.get(&n).await + })); + } + + let mut safetensors_files = Vec::with_capacity(handles.len()); + for handle in handles { + // Await the JoinHandle to get the result of the task, + // then unpack the inner result from api.get() + safetensors_files.push(handle.await??); } Ok(safetensors_files)