diff --git a/build.rs b/build.rs deleted file mode 100644 index 8f19e94..0000000 --- a/build.rs +++ /dev/null @@ -1,6 +0,0 @@ -use pyo3_build_config; - -fn main() { - // If you have an existing build.rs file, just add this line to it. - pyo3_build_config::use_pyo3_cfgs(); -} diff --git a/src/directsketch.rs b/src/directsketch.rs index 788974d..202762a 100644 --- a/src/directsketch.rs +++ b/src/directsketch.rs @@ -1,10 +1,8 @@ use anyhow::{anyhow, bail, Context, Error, Result}; -use async_zip::base::write::{self, ZipFileWriter}; -use async_zip::Compression; -use async_zip::{ZipDateTime, ZipEntryBuilder}; +use async_zip::base::write::ZipFileWriter; +use async_zip::{Compression, ZipDateTime, ZipEntryBuilder}; use camino::Utf8PathBuf as PathBuf; use chrono::Utc; -use md5; use needletail::parse_fastx_reader; use regex::Regex; use reqwest::Client; @@ -12,23 +10,21 @@ use std::collections::HashMap; use std::fs::{self, create_dir_all}; use std::io::Cursor; use std::path::Path; +use std::sync::Arc; use tokio::fs::File; -use tokio::task; +use tokio::io::{AsyncWriteExt, BufWriter}; +use tokio::sync::Semaphore; +use tokio::time::Duration; use tokio_util::compat::Compat; use pyo3::prelude::*; -use std::sync::Arc; -use tokio::io::{AsyncWriteExt, BufWriter}; - -use tokio::sync::Semaphore; -use tokio::time::{interval, Duration}; - use sourmash::manifest::{Manifest, Record}; use sourmash::signature::Signature; use crate::utils::{build_siginfo, load_accession_info, parse_params_str}; +#[allow(dead_code)] enum GenBankFileType { Genomic, Protein, @@ -449,7 +445,7 @@ async fn write_sig( pub fn sigwriter_handle( mut recv_sigs: tokio::sync::mpsc::Receiver>, output_sigs: String, - mut error_sender: tokio::sync::mpsc::Sender, + error_sender: tokio::sync::mpsc::Sender, ) -> tokio::task::JoinHandle<()> { tokio::spawn(async move { let mut md5sum_occurrences = HashMap::new(); @@ -481,7 +477,7 @@ pub fn sigwriter_handle( Ok(_) => wrote_sigs = true, Err(e) => { let error = e.context("Error processing signature"); - if let Err(send_error) = error_sender.send(error).await { + if (error_sender.send(error).await).is_err() { return; // Exit on failure to send error } } @@ -530,7 +526,7 @@ pub fn sigwriter_handle( pub fn failures_handle( failed_csv: String, mut recv_failed: tokio::sync::mpsc::Receiver, - mut error_sender: tokio::sync::mpsc::Sender, // Additional parameter for error channel + error_sender: tokio::sync::mpsc::Sender, // Additional parameter for error channel ) -> tokio::task::JoinHandle<()> { tokio::spawn(async move { match File::create(&failed_csv).await { @@ -698,7 +694,7 @@ pub async fn download_and_sketch( } } Err(e) => { - let _ = send_errors.send(e.into()).await; + let _ = send_errors.send(e).await; } } drop(send_errors); diff --git a/src/lib.rs b/src/lib.rs index 9cb329d..6977c4e 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -18,18 +18,32 @@ lazy_static! { #[pyfunction] fn set_tokio_thread_pool(num_threads: usize) -> PyResult { let mut rt_lock = GLOBAL_RUNTIME.lock().unwrap(); - if rt_lock.is_none() { - // Only initialize the runtime if it has not been initialized already - let runtime = Builder::new_multi_thread() - .worker_threads(num_threads) - .enable_all() - .build() - .map_err(|e| PyErr::new::(e))?; - *rt_lock = Some(runtime); + // Check if pytest is running + let pytest_running = std::env::var("PYTEST_RUNNING").is_ok(); + + // Check if runtime is already initialized + if rt_lock.is_some() { + if pytest_running { + // If pytest is running, simply return the number of threads without error + return Ok(num_threads); + } else { + // If not under pytest, return an error on reinitialization attempts + return Err(PyErr::new::( + "Tokio runtime is already initialized.", + )); + } } - // Return the number of threads, which is now guaranteed to be set + // Initialize the runtime if not already initialized + let runtime = Builder::new_multi_thread() + .worker_threads(num_threads) + .enable_all() + .build() + .map_err(PyErr::new::)?; + + *rt_lock = Some(runtime); + Ok(num_threads) } diff --git a/tests/conftest.py b/tests/conftest.py index 49cb49f..3f90580 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -14,3 +14,6 @@ def runtmp(): # Set environment variable PYTEST_RUNNING def pytest_configure(config): os.environ["PYTEST_RUNNING"] = "1" + +def pytest_unconfigure(config): + del os.environ["PYTEST_RUNNING"]