Skip to content

Commit

Permalink
Merge branch 'runtime-check'
Browse files Browse the repository at this point in the history
  • Loading branch information
bluegenes committed May 8, 2024
2 parents 4fce226 + 74cad4d commit e826265
Show file tree
Hide file tree
Showing 4 changed files with 37 additions and 30 deletions.
6 changes: 0 additions & 6 deletions build.rs

This file was deleted.

26 changes: 11 additions & 15 deletions src/directsketch.rs
Original file line number Diff line number Diff line change
@@ -1,34 +1,30 @@
use anyhow::{anyhow, bail, Context, Error, Result};
use async_zip::base::write::{self, ZipFileWriter};
use async_zip::Compression;
use async_zip::{ZipDateTime, ZipEntryBuilder};
use async_zip::base::write::ZipFileWriter;
use async_zip::{Compression, ZipDateTime, ZipEntryBuilder};
use camino::Utf8PathBuf as PathBuf;
use chrono::Utc;
use md5;
use needletail::parse_fastx_reader;
use regex::Regex;
use reqwest::Client;
use std::collections::HashMap;
use std::fs::{self, create_dir_all};
use std::io::Cursor;
use std::path::Path;
use std::sync::Arc;
use tokio::fs::File;
use tokio::task;
use tokio::io::{AsyncWriteExt, BufWriter};
use tokio::sync::Semaphore;
use tokio::time::Duration;
use tokio_util::compat::Compat;

use pyo3::prelude::*;

use std::sync::Arc;
use tokio::io::{AsyncWriteExt, BufWriter};

use tokio::sync::Semaphore;
use tokio::time::{interval, Duration};

use sourmash::manifest::{Manifest, Record};
use sourmash::signature::Signature;

use crate::utils::{build_siginfo, load_accession_info, parse_params_str};

#[allow(dead_code)]
enum GenBankFileType {
Genomic,
Protein,
Expand Down Expand Up @@ -449,7 +445,7 @@ async fn write_sig(
pub fn sigwriter_handle(
mut recv_sigs: tokio::sync::mpsc::Receiver<Vec<Signature>>,
output_sigs: String,
mut error_sender: tokio::sync::mpsc::Sender<anyhow::Error>,
error_sender: tokio::sync::mpsc::Sender<anyhow::Error>,
) -> tokio::task::JoinHandle<()> {
tokio::spawn(async move {
let mut md5sum_occurrences = HashMap::new();
Expand Down Expand Up @@ -481,7 +477,7 @@ pub fn sigwriter_handle(
Ok(_) => wrote_sigs = true,
Err(e) => {
let error = e.context("Error processing signature");
if let Err(send_error) = error_sender.send(error).await {
if (error_sender.send(error).await).is_err() {
return; // Exit on failure to send error
}
}
Expand Down Expand Up @@ -530,7 +526,7 @@ pub fn sigwriter_handle(
pub fn failures_handle(
failed_csv: String,
mut recv_failed: tokio::sync::mpsc::Receiver<FailedDownload>,
mut error_sender: tokio::sync::mpsc::Sender<Error>, // Additional parameter for error channel
error_sender: tokio::sync::mpsc::Sender<Error>, // Additional parameter for error channel
) -> tokio::task::JoinHandle<()> {
tokio::spawn(async move {
match File::create(&failed_csv).await {
Expand Down Expand Up @@ -698,7 +694,7 @@ pub async fn download_and_sketch(
}
}
Err(e) => {
let _ = send_errors.send(e.into()).await;
let _ = send_errors.send(e).await;
}
}
drop(send_errors);
Expand Down
32 changes: 23 additions & 9 deletions src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -18,18 +18,32 @@ lazy_static! {
#[pyfunction]
fn set_tokio_thread_pool(num_threads: usize) -> PyResult<usize> {
let mut rt_lock = GLOBAL_RUNTIME.lock().unwrap();
if rt_lock.is_none() {
// Only initialize the runtime if it has not been initialized already
let runtime = Builder::new_multi_thread()
.worker_threads(num_threads)
.enable_all()
.build()
.map_err(|e| PyErr::new::<PyRuntimeError, _>(e))?;

*rt_lock = Some(runtime);
// Check if pytest is running
let pytest_running = std::env::var("PYTEST_RUNNING").is_ok();

// Check if runtime is already initialized
if rt_lock.is_some() {
if pytest_running {
// If pytest is running, simply return the number of threads without error
return Ok(num_threads);
} else {
// If not under pytest, return an error on reinitialization attempts
return Err(PyErr::new::<PyRuntimeError, _>(
"Tokio runtime is already initialized.",
));
}
}

// Return the number of threads, which is now guaranteed to be set
// Initialize the runtime if not already initialized
let runtime = Builder::new_multi_thread()
.worker_threads(num_threads)
.enable_all()
.build()
.map_err(PyErr::new::<PyRuntimeError, _>)?;

*rt_lock = Some(runtime);

Ok(num_threads)
}

Expand Down
3 changes: 3 additions & 0 deletions tests/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,3 +14,6 @@ def runtmp():
# Set environment variable PYTEST_RUNNING
def pytest_configure(config):
os.environ["PYTEST_RUNNING"] = "1"

def pytest_unconfigure(config):
del os.environ["PYTEST_RUNNING"]

0 comments on commit e826265

Please sign in to comment.