Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
21 changes: 16 additions & 5 deletions backends/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -364,6 +364,7 @@ async fn init_backend(
otlp_service_name: String,
) -> Result<Box<dyn CoreBackend + Send>, BackendError> {
let mut backend_start_failed = false;
let api_repo = api_repo.map(Arc::new);

if cfg!(feature = "ort") {
#[cfg(feature = "ort")]
Expand Down Expand Up @@ -409,7 +410,7 @@ async fn init_backend(
if let Some(api_repo) = api_repo.as_ref() {
if cfg!(feature = "python") || cfg!(feature = "candle") {
let start = std::time::Instant::now();
if download_safetensors(api_repo).await.is_err() {
if download_safetensors(api_repo.clone()).await.is_err() {
tracing::warn!("safetensors weights not found. Using `pytorch_model.bin` instead. Model loading will be significantly slower.");
tracing::info!("Downloading `pytorch_model.bin`");
api_repo
Expand Down Expand Up @@ -581,7 +582,7 @@ enum BackendCommand {
),
}

async fn download_safetensors(api: &ApiRepo) -> Result<Vec<PathBuf>, ApiError> {
async fn download_safetensors(api: Arc<ApiRepo>) -> Result<Vec<PathBuf>, ApiError> {
// Single file
tracing::info!("Downloading `model.safetensors`");
match api.get("model.safetensors").await {
Expand Down Expand Up @@ -611,10 +612,20 @@ async fn download_safetensors(api: &ApiRepo) -> Result<Vec<PathBuf>, ApiError> {
}

// Download weight files
let mut safetensors_files = Vec::new();
let mut handles = Vec::with_capacity(safetensors_filenames.len());
for n in safetensors_filenames {
tracing::info!("Downloading `{}`", n);
safetensors_files.push(api.get(&n).await?);
let api = Arc::clone(&api);
handles.push(tokio::spawn(async move {
tracing::info!("Downloading `{}`", n);
api.get(&n).await
}));
}

let mut safetensors_files = Vec::with_capacity(handles.len());
for handle in handles {
// Await the JoinHandle to get the result of the task,
// then unpack the inner result from api.get()
safetensors_files.push(handle.await??);
}

Ok(safetensors_files)
Expand Down