diff --git a/docs/source/tutorials/cli_usage.md b/docs/source/tutorials/cli_usage.md index a0c536e..f495993 100644 --- a/docs/source/tutorials/cli_usage.md +++ b/docs/source/tutorials/cli_usage.md @@ -29,7 +29,7 @@ The load command builds the raw AnnData object from your raw sequencing data. It - adata.X contains binarized modification data (conversion/deaminase), or modification probabilitiesc (native). - Adds basic read-level QC annotations (Read start, end, length, mean quality). - Adds layers encoding read DNA sequences, base quality scores, base mismatches. -- Maintains BAM Tags/Flags in adata.obs. +- Maintains BAM tags/flags in adata.obs (UMI and barcode annotations loaded from Parquet sidecars). - Writes the raw AnnData to the canonical output path and runs MultiQC. - Optionally deletes intermediate BAMs, H5ADs, and TSVs. diff --git a/docs/source/tutorials/experiment_config.md b/docs/source/tutorials/experiment_config.md index 89a1106..72e2f39 100644 --- a/docs/source/tutorials/experiment_config.md +++ b/docs/source/tutorials/experiment_config.md @@ -57,18 +57,25 @@ Below are some of the most commonly edited fields and how they affect the CLI wo - Lists are written in bracketed form, e.g. `[5mC]` or `[5mC_5hmC]`. - If you update the CSV, re-run the CLI command pointing at the updated file. -## BAM tags +## Read annotations -smftools writes and/or propagates the following BAM tags when loading data. These are also loaded -into `adata.obs` when `load_adata` reads BAM tags. +smftools annotates reads during `load_adata` and stores the results in `adata.obs`. Standard BAM +tags (e.g. `NM`, `MD`, `MM`, `ML`) are read directly from BAM files. UMI and barcode annotations +are computed in parallel and written to Parquet sidecar files alongside the aligned BAM, then loaded +into `adata.obs` from those sidecars. The aligned BAM itself is not modified. 
-**UMI tags** +**UMI annotations** (written to `.umi_tags.parquet`) -- `U1`: UMI from the *top* flank (read start or read end depending on match). -- `U2`: UMI from the *bottom* flank. +- `U1`: Orientation-corrected UMI for the *left* reference end of the mapped fragment (forward reads: US, reverse reads: UE). +- `U2`: Orientation-corrected UMI for the *right* reference end of the mapped fragment (forward reads: UE, reverse reads: US). +- `US`: Positional UMI from read start (delimited `UMI_seq;slot;flank_seq`). +- `UE`: Positional UMI from read end (delimited `UMI_seq;slot;flank_seq`). - `RX`: Combined UMI string (`U1-U2`, or `U1`/`U2` if only one is present). +- `FC`: Flank context of the U1/U2 pair (e.g. `top-bottom`). -**Barcode tags (smftools demux backend)** +When `threads` is greater than 1, UMI extraction is parallelized across multiple CPU cores; otherwise it runs in a single process. + +**Barcode annotations (smftools demux backend)** (written to `.barcode_tags.parquet`) - `BC`: Assigned barcode name, or `unclassified`. - `BM`: Match type (`both`, `read_start_only`, `read_end_only`, `mismatch`, `unclassified`). @@ -79,9 +86,13 @@ into `adata.obs` when `load_adata` reads BAM tags. - `B5`: Barcode name matched at the read start (corresponds to `B1`/`B3`). - `B6`: Barcode name matched at the read end (corresponds to `B2`/`B4`). -**Barcode tags (dorado demux backend)** +When `threads` is set, barcode extraction is parallelized across multiple CPU cores. +Demultiplexing (splitting reads into per-barcode BAMs) uses the sidecar `BC` assignments. +Only primary alignments are included in split BAMs and sidecar files. + +**Barcode annotations (dorado demux backend)** -- `BC`: Assigned barcode name. +- `BC`: Assigned barcode name (read from BAM tag). - `bi`: Dorado barcode info array (if present; expanded into columns during load). 
Notes: diff --git a/src/smftools/cli/load_adata.py b/src/smftools/cli/load_adata.py index 15f8c7d..0094a84 100644 --- a/src/smftools/cli/load_adata.py +++ b/src/smftools/cli/load_adata.py @@ -6,6 +6,7 @@ from typing import Iterable, Union import numpy as np +import pandas as pd from smftools.constants import BARCODE_KIT_ALIASES, LOAD_DIR, LOGGING_DIR, UMI_KIT_ALIASES from smftools.logging_utils import get_logger, setup_logging @@ -516,6 +517,8 @@ def load_adata_core(cfg, paths: AdataPaths, config_path: str | None = None): ######################################################################################################################## ################################### 4.5) Optional UMI annotation ############################################# + umi_sidecar = None + barcode_sidecar = None if getattr(cfg, "use_umi", False): logger.info("Annotating UMIs in aligned and sorted BAM before demultiplexing") @@ -538,20 +541,20 @@ def load_adata_core(cfg, paths: AdataPaths, config_path: str | None = None): umi_kit_config = load_umi_config_from_yaml(umi_yaml_path) resolved_umi = resolve_umi_config(umi_kit_config, cfg) - annotate_umi_tags_in_bam( + umi_sidecar = annotate_umi_tags_in_bam( aligned_sorted_output, use_umi=True, - umi_adapters=getattr(cfg, "umi_adapters", None), + umi_kit_config=umi_kit_config, umi_length=getattr(cfg, "umi_length", None), umi_search_window=getattr(cfg, "umi_search_window", 200), - umi_adapter_matcher=getattr(cfg, "umi_adapter_matcher", "exact"), + umi_adapter_matcher=getattr(cfg, "umi_adapter_matcher", "edlib"), umi_adapter_max_edits=resolved_umi["umi_adapter_max_edits"], samtools_backend=cfg.samtools_backend, - umi_kit_config=umi_kit_config, umi_ends=resolved_umi["umi_ends"], umi_flank_mode=resolved_umi["umi_flank_mode"], umi_amplicon_max_edits=resolved_umi["umi_amplicon_max_edits"], same_orientation=resolved_umi.get("same_orientation", False), + threads=cfg.threads, ) 
######################################################################################################################## @@ -603,7 +606,7 @@ def load_adata_core(cfg, paths: AdataPaths, config_path: str | None = None): resolved_bc = resolve_barcode_config(barcode_kit_config, cfg) logger.info("Extracting and assigning barcodes to aligned BAM using smftools backend") - barcoded_bam = extract_and_assign_barcodes_in_bam( + barcode_sidecar = extract_and_assign_barcodes_in_bam( aligned_sorted_output, barcode_adapters=getattr(cfg, "barcode_adapters", [None, None]), barcode_references=barcode_references, @@ -619,10 +622,9 @@ def load_adata_core(cfg, paths: AdataPaths, config_path: str | None = None): barcode_kit_config=barcode_kit_config, barcode_ends=resolved_bc["barcode_ends"], barcode_amplicon_gap_tolerance=resolved_bc["barcode_amplicon_gap_tolerance"], + threads=cfg.threads, ) - # Update aligned_sorted_output to point to the barcoded BAM - aligned_sorted_output = barcoded_bam - logger.info(f"smftools barcode extraction complete: {barcoded_bam}") + logger.info(f"smftools barcode extraction complete: {barcode_sidecar}") ######################################################################################################################## ################################### 5) Demultiplexing ###################################################################### @@ -646,6 +648,7 @@ def load_adata_core(cfg, paths: AdataPaths, config_path: str | None = None): cfg.split_path, cfg.bam_suffix, samtools_backend=cfg.samtools_backend, + barcode_sidecar=barcode_sidecar, ) unclassified_bams = [p for p in all_bam_files if "unclassified" in p.name] @@ -716,7 +719,7 @@ def load_adata_core(cfg, paths: AdataPaths, config_path: str | None = None): # Annotate BM tag from bi per-end scores on each demuxed BAM for bam in bam_files: if "unclassified" not in bam.name: - annotate_demux_type_from_bi_tag(bam, threshold=0.0) + annotate_demux_type_from_bi_tag(bam) se_bam_files = bam_files bam_dir = 
cfg.split_path @@ -956,14 +959,10 @@ def load_adata_core(cfg, paths: AdataPaths, config_path: str | None = None): default_tags = ["NM", "MD", "fn"] if cfg.smf_modality == "direct": default_tags.extend(["MM", "ML"]) - # Add UMI tags if UMI extraction was enabled - if getattr(cfg, "use_umi", False): - default_tags.extend(["U1", "U2", "RX"]) - # Add barcode tags if smftools barcode extraction was used - if demux_backend == "smftools" and cfg.barcode_kit: - default_tags.extend(["BC", "BM", "B1", "B2", "B3", "B4", "B5", "B6"]) + # UMI tags are loaded from Parquet sidecar below (not from BAM) + # Barcode tags are loaded from Parquet sidecar below (not from BAM) # Add barcode tags from dorado single-pass demux (BM annotated from bi tag) - elif demux_backend == "dorado" and cfg.barcode_kit and not cfg.input_already_demuxed: + if demux_backend == "dorado" and cfg.barcode_kit and not cfg.input_already_demuxed: dorado_ver = _get_dorado_version() if dorado_ver is not None and dorado_ver >= (1, 3, 1): default_tags.extend(["BC", "BM", "bi"]) @@ -980,6 +979,31 @@ def load_adata_core(cfg, paths: AdataPaths, config_path: str | None = None): samtools_backend=cfg.samtools_backend, ) + # Load UMI tags from Parquet sidecar (written by annotate_umi_tags_in_bam) + if getattr(cfg, "use_umi", False) and umi_sidecar and Path(umi_sidecar).exists(): + logger.info("Loading UMI tags from Parquet sidecar: %s", umi_sidecar) + umi_df = pd.read_parquet(umi_sidecar).set_index("read_name") + umi_df = umi_df.reindex(raw_adata.obs_names) + for col in ["U1", "U2", "RX", "FC", "US", "UE"]: + if col in umi_df.columns: + raw_adata.obs[col] = umi_df[col].values + del umi_df + + # Load barcode tags from Parquet sidecar (written by extract_and_assign_barcodes_in_bam) + if ( + demux_backend == "smftools" + and cfg.barcode_kit + and barcode_sidecar + and Path(barcode_sidecar).exists() + ): + logger.info("Loading barcode tags from Parquet sidecar: %s", barcode_sidecar) + bc_df = 
pd.read_parquet(barcode_sidecar).set_index("read_name") + bc_df = bc_df.reindex(raw_adata.obs_names) + for col in ["BC", "BM", "B1", "B2", "B3", "B4", "B5", "B6"]: + if col in bc_df.columns: + raw_adata.obs[col] = bc_df[col].values + del bc_df + # Expand dorado bi array tag into individual float score columns if "bi" in bam_tag_names: expand_bi_tag_columns(raw_adata, bi_column="bi") diff --git a/src/smftools/config/default.yaml b/src/smftools/config/default.yaml index 7fcc8db..df32e18 100644 --- a/src/smftools/config/default.yaml +++ b/src/smftools/config/default.yaml @@ -89,7 +89,6 @@ demux_backend: dorado # smftools|dorado - smftools uses YAML-based barcode refs, barcode_both_ends: False # Require barcode match at both ends trim: False # dorado adapter and barcode removal during demultiplexing use_umi: False # Whether to detect and annotate UMIs in aligned_sorted BAM before demultiplexing -umi_adapters: [null, null] # Two-slot list [left_ref_end_adapter, right_ref_end_adapter] umi_length: null # Length of each UMI (required when use_umi is true) umi_search_window: 200 # Max distance from read ends to consider an adapter match for UMI extraction umi_adapter_matcher: edlib # exact|edlib diff --git a/src/smftools/informatics/bam_functions.py b/src/smftools/informatics/bam_functions.py index 677913a..ee278bb 100644 --- a/src/smftools/informatics/bam_functions.py +++ b/src/smftools/informatics/bam_functions.py @@ -7,9 +7,10 @@ import subprocess import time from collections import Counter, defaultdict, deque -from concurrent.futures import ThreadPoolExecutor, as_completed +from concurrent.futures import ProcessPoolExecutor, ThreadPoolExecutor, as_completed from dataclasses import dataclass, field from itertools import zip_longest +from math import ceil from pathlib import Path from typing import TYPE_CHECKING, Any, Dict, Iterable, List, Optional, Tuple, Union @@ -164,6 +165,15 @@ class UMIKitConfig: same_orientation: bool = False +@dataclass +class UMIExtractionResult: 
+ """Per-position UMI extraction result.""" + + umi_seq: Optional[str] = None + slot: Optional[str] = None + flank_seq: Optional[str] = None + + _BAM_FLAG_BITS: Tuple[Tuple[int, str], ...] = ( (0x1, "paired"), (0x2, "proper_pair"), @@ -261,118 +271,6 @@ def _parse_idxstats_output(output: str) -> Tuple[int, int, Dict[str, Tuple[int, return aligned_reads_count, unaligned_reads_count, proportions -def _normalize_umi_adapters(umi_adapters: Any) -> List[Optional[str]]: - """Normalize UMI adapters into a two-slot [left_ref_end, right_ref_end] list.""" - if umi_adapters is None: - adapters: List[Any] = [] - elif isinstance(umi_adapters, (list, tuple)): - adapters = list(umi_adapters) - else: - adapters = [umi_adapters] - - if len(adapters) != 2: - raise ValueError("umi_adapters must be a two-item list: [left_ref_end, right_ref_end].") - - normalized: List[Optional[str]] = [] - for adapter in adapters: - if adapter is None: - normalized.append(None) - continue - value = str(adapter).strip() - if not value or value.lower() == "none": - normalized.append(None) - continue - normalized.append(value.upper()) - return normalized - - -def validate_umi_config( - use_umi: bool, - umi_adapters: Any, - umi_length: Any, -) -> Tuple[List[Optional[str]], Optional[int]]: - """Validate UMI settings and return normalized adapters and length.""" - if not use_umi: - return [], None - - adapters = _normalize_umi_adapters(umi_adapters) - if all(adapter is None for adapter in adapters): - raise ValueError( - "cfg.use_umi is true, but no UMI adapter sequences were provided in umi_adapters." 
- ) - - try: - length = int(umi_length) - except Exception as exc: - raise ValueError("UMI length must be a positive integer when cfg.use_umi is true.") from exc - if length <= 0: - raise ValueError("UMI length must be a positive integer when cfg.use_umi is true.") - - return adapters, length - - -def _extract_umi_adjacent_to_adapter_on_read_end( - read_sequence: str, - adapter_sequence: str, - umi_length: int, - umi_search_window: int, - search_from_start: bool, - adapter_matcher: str = "exact", - adapter_max_edits: int = 0, -) -> Optional[str]: - """Extract UMI adjacent to adapter constrained to either start or end of read.""" - if not read_sequence or not adapter_sequence: - return None - - seq = read_sequence.upper() - adapter = adapter_sequence.upper() - seq_len = len(seq) - if seq_len == 0: - return None - - matcher = str(adapter_matcher).strip().lower() - if matcher not in {"exact", "edlib"}: - raise ValueError("adapter_matcher must be one of: exact, edlib") - - if matcher == "exact": - matches = [(m.start(), m.end()) for m in re.finditer(re.escape(adapter), seq)] - else: - edlib = require("edlib", extra="umi", purpose="fuzzy UMI adapter matching") - result = edlib.align(adapter, seq, mode="HW", task="locations", k=max(0, adapter_max_edits)) - locations = result.get("locations", []) if isinstance(result, dict) else [] - matches = [] - for loc in locations: - if not isinstance(loc, (list, tuple)) or len(loc) != 2: - continue - start_i, end_i = int(loc[0]), int(loc[1]) - if start_i < 0 or end_i < start_i: - continue - matches.append((start_i, end_i + 1)) - - best: Optional[Tuple[int, int]] = None - for start, end in matches: - distance = start if search_from_start else (seq_len - end) - if distance > umi_search_window: - continue - if best is None or distance < best[0]: - best = (distance, start) - - if best is None: - return None - - start = best[1] - end = start + len(adapter) - if search_from_start: - umi_start, umi_end = end, end + umi_length - else: - 
umi_start, umi_end = start - umi_length, start - - if umi_start < 0 or umi_end > seq_len: - return None - umi = seq[umi_start:umi_end] - return umi if len(umi) == umi_length else None - - _COMPLEMENT = str.maketrans("ACGTNacgtn", "TGCANtgcan") @@ -884,63 +782,157 @@ def _extract_sequence_with_flanking( return extracted, tgt_start, tgt_end -def _target_read_end_for_ref_side(is_reverse: bool, ref_side: str) -> str: - """Map reference-side adapter slot to a read-end target given read strand.""" - if ref_side == "left": - return "end" if is_reverse else "start" - if ref_side == "right": - return "start" if is_reverse else "end" - raise ValueError(f"Unknown ref_side: {ref_side}") +def _extract_umis_for_reads( + sequences: List[str], + length: int, + search_window: int, + matcher: str, + max_edits: int, + umi_amplicon_max_edits: int, + effective_flank_mode: str, + flanking_candidates: List[Tuple[str, FlankingConfig]], + configured_slots: List[int], + check_start: bool, + check_end: bool, +) -> List[Tuple[Optional[UMIExtractionResult], Optional[UMIExtractionResult]]]: + """Extract UMIs for a list of read sequences. + + Returns a list of ``(us_result, ue_result)`` tuples, one per read. + Each element is an :class:`UMIExtractionResult` with UMI sequence, slot, + and flanking sequence metadata. 
+ """ + # Pre-compute RC'd flanking configs for read-end searches + end_flanking_cache: List[FlankingConfig] = [] + for _slot, candidate in flanking_candidates: + if candidate.adapter_side and candidate.amplicon_side: + end_flanking_cache.append( + FlankingConfig( + adapter_side=_reverse_complement(candidate.amplicon_side), + amplicon_side=_reverse_complement(candidate.adapter_side), + adapter_pad=candidate.amplicon_pad, + amplicon_pad=candidate.adapter_pad, + ) + ) + elif candidate.adapter_side: + end_flanking_cache.append( + FlankingConfig( + adapter_side=_reverse_complement(candidate.adapter_side), + amplicon_side=None, + adapter_pad=candidate.adapter_pad, + amplicon_pad=candidate.amplicon_pad, + ) + ) + elif candidate.amplicon_side: + end_flanking_cache.append( + FlankingConfig( + adapter_side=None, + amplicon_side=_reverse_complement(candidate.amplicon_side), + adapter_pad=candidate.adapter_pad, + amplicon_pad=candidate.amplicon_pad, + ) + ) + else: + end_flanking_cache.append(FlankingConfig(adapter_side=None, amplicon_side=None)) + + results: List[Tuple[Optional[UMIExtractionResult], Optional[UMIExtractionResult]]] = [] + for sequence in sequences: + us_result: Optional[UMIExtractionResult] = None + ue_result: Optional[UMIExtractionResult] = None + found_slots: List[Optional[str]] = [None, None] # track by slot index + + for read_end in ("start", "end"): + if read_end == "start" and not check_start: + continue + if read_end == "end" and not check_end: + continue + search_from_start = read_end == "start" + + for i, (slot, candidate) in enumerate(flanking_candidates): + end_flanking = candidate if search_from_start else end_flanking_cache[i] + + extracted, _, _ = _extract_sequence_with_flanking( + read_sequence=sequence, + target_length=length, + search_window=search_window, + search_from_start=search_from_start, + flanking=end_flanking, + flank_mode=effective_flank_mode, + adapter_matcher=matcher, + adapter_max_edits=max_edits, + 
amplicon_max_edits=umi_amplicon_max_edits, + same_orientation=False, + ) + if extracted and read_end == "end": + extracted = _reverse_complement(extracted) + + if extracted: + # Use original (pre-RC) adapter_side as flank identity + flank_seq = candidate.adapter_side or "" + result = UMIExtractionResult( + umi_seq=extracted, + slot=slot, + flank_seq=flank_seq, + ) + idx = 0 if slot == "top" else 1 + if found_slots[idx] is None: + found_slots[idx] = slot + if search_from_start: + if us_result is None: + us_result = result + else: + if ue_result is None: + ue_result = result + + if configured_slots and all(found_slots[idx] is not None for idx in configured_slots): + break + + results.append((us_result, ue_result)) + return results def annotate_umi_tags_in_bam( bam_path: str | Path, *, use_umi: bool, - umi_adapters: Any, - umi_length: Any, + umi_kit_config: UMIKitConfig, + umi_length: Any = None, umi_search_window: int = 200, - umi_adapter_matcher: str = "exact", + umi_adapter_matcher: str = "edlib", umi_adapter_max_edits: int = 0, samtools_backend: str | None = "auto", - # New flanking parameters (optional) - umi_kit_config: Optional[UMIKitConfig] = None, umi_ends: Optional[str] = None, umi_flank_mode: Optional[str] = None, umi_amplicon_max_edits: int = 0, same_orientation: bool = False, + threads: Optional[int] = None, ) -> Path: """Annotate aligned BAM reads with UMI tags before demultiplexing. - When ``umi_kit_config`` with flanking sequences is provided, extraction uses - ``_extract_sequence_with_flanking`` instead of - ``_extract_umi_adjacent_to_adapter_on_read_end``. + Uses flanking-sequence-based extraction via ``_extract_sequence_with_flanking``. + When ``threads`` > 1, UMI extraction is parallelized across CPU cores using + multiprocessing while BAM I/O remains single-threaded. 
+ + Tags written to the Parquet sidecar (the BAM itself is not modified): + US / UE – positional: delimited ``"UMI_seq;slot;flank_seq"`` from read **start** / **end** + U1 / U2 – orientation-corrected: U1 = left reference end, U2 = right reference end (fwd: U1=US,U2=UE; rev: U1=UE,U2=US; UMI sequence only) + FC – flank context: slot names of U1/U2 pair (e.g. ``"top-bottom"``) + RX – combined tag using orientation-assigned U1-U2 """ input_bam = Path(bam_path) if not use_umi: return input_bam - # Determine if flanking-based extraction should be used - use_flanking = umi_kit_config is not None and umi_kit_config.flanking is not None + if umi_kit_config is None or umi_kit_config.flanking is None: + raise ValueError("umi_kit_config with flanking sequences is required for UMI annotation.") - if use_flanking: - flanking_config = umi_kit_config.flanking - length = umi_kit_config.length if umi_kit_config.length else int(umi_length or 0) - effective_umi_ends = umi_ends or umi_kit_config.umi_ends or "both" - effective_flank_mode = umi_flank_mode or umi_kit_config.umi_flank_mode or "adapter_only" - if length <= 0: - raise ValueError( - "UMI length must be a positive integer when using flanking-based extraction." 
- ) - # We still need adapters validated for the legacy path if flanking is partial - adapters = [None, None] # Not used in flanking path - configured_adapter_count = 0 - else: - flanking_config = None - adapters, length = validate_umi_config(use_umi, umi_adapters, umi_length) - effective_umi_ends = umi_ends or "both" - effective_flank_mode = umi_flank_mode or "adapter_only" - configured_adapter_count = sum(1 for adapter in adapters if adapter is not None) + flanking_config = umi_kit_config.flanking + length = umi_kit_config.length if umi_kit_config.length else int(umi_length or 0) + effective_umi_ends = umi_ends or umi_kit_config.umi_ends or "both" + effective_flank_mode = umi_flank_mode or umi_kit_config.umi_flank_mode or "adapter_only" + if length <= 0: + raise ValueError( + "UMI length must be a positive integer when using flanking-based extraction." + ) search_window = max(0, int(umi_search_window)) matcher = str(umi_adapter_matcher).strip().lower() @@ -955,236 +947,206 @@ def annotate_umi_tags_in_bam( check_end = effective_umi_ends in ("both", "right_only", "read_end") flanking_candidates: List[Tuple[str, FlankingConfig]] = [] - if use_flanking and flanking_config is not None: - if flanking_config.left_ref_end is not None and ( - flanking_config.left_ref_end.adapter_side or flanking_config.left_ref_end.amplicon_side - ): - flanking_candidates.append(("top", flanking_config.left_ref_end)) - if flanking_config.right_ref_end is not None and ( - flanking_config.right_ref_end.adapter_side - or flanking_config.right_ref_end.amplicon_side - ): - flanking_candidates.append(("bottom", flanking_config.right_ref_end)) + if flanking_config.left_ref_end is not None and ( + flanking_config.left_ref_end.adapter_side or flanking_config.left_ref_end.amplicon_side + ): + flanking_candidates.append(("top", flanking_config.left_ref_end)) + if flanking_config.right_ref_end is not None and ( + flanking_config.right_ref_end.adapter_side or 
flanking_config.right_ref_end.amplicon_side + ): + flanking_candidates.append(("bottom", flanking_config.right_ref_end)) + + configured_slots = [0 if slot == "top" else 1 for slot, _ in flanking_candidates] + configured_adapter_count = len(configured_slots) + + cpu_count = os.cpu_count() or 1 + num_workers = min(max(1, int(threads)), cpu_count) if threads else 1 + + # ── Shared extraction kwargs ──────────────────────────────────────── + extraction_kwargs = dict( + length=length, + search_window=search_window, + matcher=matcher, + max_edits=max_edits, + umi_amplicon_max_edits=umi_amplicon_max_edits, + effective_flank_mode=effective_flank_mode, + flanking_candidates=flanking_candidates, + configured_slots=configured_slots, + check_start=check_start, + check_end=check_end, + ) - # Count configured ends for flanking - if use_flanking: - configured_slots = [0 if slot == "top" else 1 for slot, _ in flanking_candidates] - configured_adapter_count = len(configured_slots) + # ── Single BAM pass: collect metadata + sequences ─────────────────── + read_names: List[str] = [] + is_reverse_flags: List[bool] = [] + is_primary: List[bool] = [] + sequences: List[str] = [] + if backend_choice == "python": + pysam_mod = _require_pysam() + with pysam_mod.AlignmentFile(str(input_bam), "rb") as in_bam: + for read in tqdm(in_bam.fetch(until_eof=True), desc="UMI: reading BAM", unit=" reads"): + read_names.append(read.query_name) + is_reverse_flags.append(read.is_reverse) + is_primary.append(not read.is_secondary and not read.is_supplementary) + sequences.append(read.query_sequence or "") + else: + cmd = ["samtools", "view", str(input_bam)] + proc = subprocess.Popen(cmd, stdout=subprocess.PIPE, stderr=subprocess.PIPE, text=True) + assert proc.stdout is not None + for line in tqdm(proc.stdout, desc="UMI: reading BAM", unit=" reads"): + if not line.strip() or line.startswith("@"): + continue + fields = line.rstrip("\n").split("\t") + if len(fields) < 11: + continue + flag = int(fields[1]) + 
read_names.append(fields[0]) + is_reverse_flags.append(bool(flag & 0x10)) + is_primary.append(not bool(flag & 0x100) and not bool(flag & 0x800)) + seq = fields[9] + sequences.append("" if seq == "*" else seq) + rc = proc.wait() + if rc != 0: + stderr = proc.stderr.read() if proc.stderr else "" + raise RuntimeError(f"samtools view failed (exit {rc}):\n{stderr}") + total_reads = len(read_names) - pysam_mod = _require_pysam() - tmp_bam = input_bam.with_name(f"{input_bam.stem}.umi_tmp{input_bam.suffix}") + # ── UMI extraction ────────────────────────────────────────────────── + if num_workers <= 1 or total_reads == 0: + # Single-process fast path – no IPC overhead + all_results = _extract_umis_for_reads(sequences, **extraction_kwargs) + else: + chunk_size = max(1000, ceil(total_reads / num_workers)) + num_chunks = ceil(total_reads / chunk_size) + actual_workers = min(num_workers, num_chunks) + logger.info( + "UMI extraction: %d reads across %d workers (chunk_size=%d)", + total_reads, + actual_workers, + chunk_size, + ) + with ProcessPoolExecutor(max_workers=actual_workers) as pool: + futures = {} + for chunk_idx in range(num_chunks): + start_idx = chunk_idx * chunk_size + end_idx = min(start_idx + chunk_size, total_reads) + chunk = sequences[start_idx:end_idx] + future = pool.submit( + _extract_umis_for_reads, + chunk, + **extraction_kwargs, + ) + futures[future] = chunk_idx + + # Collect results in submission order, with progress per chunk + chunk_results: Dict[ + int, List[Tuple[Optional[UMIExtractionResult], Optional[UMIExtractionResult]]] + ] = {} + with tqdm( + total=total_reads, + desc=f"UMI: extracting ({actual_workers} workers)", + unit=" reads", + ) as pbar: + for future in as_completed(futures): + idx = futures[future] + result = future.result() + chunk_results[idx] = result + pbar.update(len(result)) + + all_results: List[ + Tuple[Optional[UMIExtractionResult], Optional[UMIExtractionResult]] + ] = [] + for chunk_idx in range(num_chunks): + 
all_results.extend(chunk_results[chunk_idx]) + del chunk_results + + del sequences + + # ── Compute tags in-memory and write Parquet sidecar ──────────────── + import pandas as pd - total_reads = 0 reads_with_any_umi = 0 reads_with_all_umis = 0 - with ( - pysam_mod.AlignmentFile(str(input_bam), "rb") as in_bam, - pysam_mod.AlignmentFile(str(tmp_bam), "wb", template=in_bam) as out_bam, - ): - for read in in_bam.fetch(until_eof=True): - total_reads += 1 - sequence = read.query_sequence or "" - umi_values: List[Optional[str]] = [None, None] - - if use_flanking: - for read_end in ("start", "end"): - if read_end == "start" and not check_start: - continue - if read_end == "end" and not check_end: - continue - search_from_start = read_end == "start" - - for slot, candidate in flanking_candidates: - end_flanking = candidate - if read_end == "end": - if candidate.adapter_side and candidate.amplicon_side: - end_flanking = FlankingConfig( - adapter_side=_reverse_complement(candidate.amplicon_side), - amplicon_side=_reverse_complement(candidate.adapter_side), - adapter_pad=candidate.amplicon_pad, - amplicon_pad=candidate.adapter_pad, - ) - elif candidate.adapter_side: - end_flanking = FlankingConfig( - adapter_side=_reverse_complement(candidate.adapter_side), - amplicon_side=None, - adapter_pad=candidate.adapter_pad, - amplicon_pad=candidate.amplicon_pad, - ) - elif candidate.amplicon_side: - end_flanking = FlankingConfig( - adapter_side=None, - amplicon_side=_reverse_complement(candidate.amplicon_side), - adapter_pad=candidate.adapter_pad, - amplicon_pad=candidate.amplicon_pad, - ) - else: - end_flanking = FlankingConfig(adapter_side=None, amplicon_side=None) - - extracted, _, _ = _extract_sequence_with_flanking( - read_sequence=sequence, - target_length=length, - search_window=search_window, - search_from_start=search_from_start, - flanking=end_flanking, - flank_mode=effective_flank_mode, - adapter_matcher=matcher, - adapter_max_edits=max_edits, - 
amplicon_max_edits=umi_amplicon_max_edits, - same_orientation=False, - ) - if extracted and read_end == "end": - extracted = _reverse_complement(extracted) + tag_rows: List[Dict[str, Optional[str]]] = [] + for read_idx in range(total_reads): + us_result, ue_result = all_results[read_idx] - if extracted: - idx = 0 if slot == "top" else 1 - if umi_values[idx] is None: - umi_values[idx] = extracted + present = sum(1 for r in (us_result, ue_result) if r and r.umi_seq) + if present: + reads_with_any_umi += 1 + if configured_adapter_count and present == configured_adapter_count: + reads_with_all_umis += 1 - if configured_slots and all( - umi_values[idx] is not None for idx in configured_slots - ): - break - else: - for i, ref_side in enumerate(["left", "right"]): - if i == 0 and not check_start: - continue - if i == 1 and not check_end: - continue + # Only include primary alignments in the sidecar + if not is_primary[read_idx]: + continue - read_end = _target_read_end_for_ref_side(read.is_reverse, ref_side) - search_from_start = read_end == "start" + row: Dict[str, Optional[str]] = {"read_name": read_names[read_idx]} - adapter = adapters[i] if i < len(adapters) else None - if adapter is None: - continue - # RC adapter for reverse-strand reads - search_adapter = _reverse_complement(adapter) if read.is_reverse else adapter - extracted_umi = _extract_umi_adjacent_to_adapter_on_read_end( - read_sequence=sequence, - adapter_sequence=search_adapter, - umi_length=length, - umi_search_window=search_window, - search_from_start=search_from_start, - adapter_matcher=matcher, - adapter_max_edits=max_edits, - ) - # RC extracted UMI for reverse-strand reads - if extracted_umi and read.is_reverse: - extracted_umi = _reverse_complement(extracted_umi) - umi_values[i] = extracted_umi - - present = [u for u in umi_values if u] - if present: - reads_with_any_umi += 1 - if configured_adapter_count and len(present) == configured_adapter_count: - reads_with_all_umis += 1 - - umi1, umi2 = 
umi_values[0], umi_values[1] - if umi1: - read.set_tag("U1", umi1, value_type="Z") - if umi2: - read.set_tag("U2", umi2, value_type="Z") - if umi1 and umi2: - read.set_tag("RX", f"{umi1}-{umi2}", value_type="Z") - elif umi1: - read.set_tag("RX", umi1, value_type="Z") - elif umi2: - read.set_tag("RX", umi2, value_type="Z") - - out_bam.write(read) - - tmp_bam.replace(input_bam) - index_paths = ( - input_bam.with_suffix(input_bam.suffix + ".bai"), - Path(str(input_bam) + ".bai"), - ) - for idx_path in index_paths: - if idx_path.exists(): - idx_path.unlink() - if backend_choice == "python": - _index_bam_with_pysam(input_bam) - else: - _index_bam_with_samtools(input_bam) + # Positional tags (US/UE) as delimited strings + if us_result and us_result.umi_seq: + parts = [us_result.umi_seq, us_result.slot or "", us_result.flank_seq or ""] + row["US"] = ";".join(parts) + if ue_result and ue_result.umi_seq: + parts = [ue_result.umi_seq, ue_result.slot or "", ue_result.flank_seq or ""] + row["UE"] = ";".join(parts) + + # Assign U1/U2 based on alignment orientation + is_rev = is_reverse_flags[read_idx] + if is_rev: + u1_result, u2_result = ue_result, us_result + else: + u1_result, u2_result = us_result, ue_result + + u1 = u1_result.umi_seq if u1_result else None + u2 = u2_result.umi_seq if u2_result else None + + if u1: + row["U1"] = u1 + if u2: + row["U2"] = u2 + + # FC tag: flank context of U1/U2 pair + u1_slot = u1_result.slot if u1_result else None + u2_slot = u2_result.slot if u2_result else None + if u1_slot and u2_slot: + row["FC"] = f"{u1_slot}-{u2_slot}" + elif u1_slot: + row["FC"] = u1_slot + elif u2_slot: + row["FC"] = u2_slot + + # Combined tag (RX) using orientation-assigned U1/U2 + if u1 and u2: + row["RX"] = f"{u1}-{u2}" + elif u1: + row["RX"] = u1 + elif u2: + row["RX"] = u2 + + tag_rows.append(row) + + del all_results, read_names, is_reverse_flags, is_primary + + sidecar_path = input_bam.with_suffix(".umi_tags.parquet") + df = pd.DataFrame(tag_rows) + # Ensure 
all tag columns exist even if no reads had UMIs + for col in ("U1", "U2", "US", "UE", "RX", "FC"): + if col not in df.columns: + df[col] = None + df.to_parquet(sidecar_path, index=False) + del tag_rows, df logger.info( - "UMI annotation complete for %s: total_reads=%d, reads_with_any_umi=%d, reads_with_all_umis=%d", + "UMI annotation complete for %s: total_reads=%d, reads_with_any_umi=%d, reads_with_all_umis=%d, sidecar=%s", input_bam, total_reads, reads_with_any_umi, reads_with_all_umis, + sidecar_path, ) - return input_bam - - -def _extract_barcode_adjacent_to_adapter_on_read_end( - read_sequence: str, - adapter_sequence: str, - barcode_length: int, - barcode_search_window: int, - search_from_start: bool, - adapter_matcher: str = "edlib", - adapter_max_edits: int = 2, -) -> Tuple[Optional[str], Optional[int]]: - """ - Extract barcode sequence adjacent to adapter, constrained to read end. - - Returns - ------- - Tuple[Optional[str], Optional[int]] - (barcode_sequence, adapter_start_position) or (None, None) if not found. 
- """ - if not read_sequence or not adapter_sequence: - return None, None - - seq = read_sequence.upper() - adapter = adapter_sequence.upper() - seq_len = len(seq) - if seq_len == 0: - return None, None - - matcher = str(adapter_matcher).strip().lower() - if matcher not in {"exact", "edlib"}: - raise ValueError("adapter_matcher must be one of: exact, edlib") - - if matcher == "exact": - matches = [(m.start(), m.end()) for m in re.finditer(re.escape(adapter), seq)] - else: - edlib = require("edlib", extra="umi", purpose="fuzzy barcode adapter matching") - result = edlib.align(adapter, seq, mode="HW", task="locations", k=max(0, adapter_max_edits)) - locations = result.get("locations", []) if isinstance(result, dict) else [] - matches = [] - for loc in locations: - if not isinstance(loc, (list, tuple)) or len(loc) != 2: - continue - start_i, end_i = int(loc[0]), int(loc[1]) - if start_i < 0 or end_i < start_i: - continue - matches.append((start_i, end_i + 1)) - - best: Optional[Tuple[int, int]] = None - for start, end in matches: - distance = start if search_from_start else (seq_len - end) - if distance > barcode_search_window: - continue - if best is None or distance < best[0]: - best = (distance, start) - - if best is None: - return None, None - - adapter_start = best[1] - adapter_end = adapter_start + len(adapter) - if search_from_start: - bc_start, bc_end = adapter_end, adapter_end + barcode_length - else: - bc_start, bc_end = adapter_start - barcode_length, adapter_start - - if bc_start < 0 or bc_end > seq_len: - return None, None - barcode = seq[bc_start:bc_end] - return (barcode, adapter_start) if len(barcode) == barcode_length else (None, None) + return sidecar_path def _match_barcode_to_references( @@ -1643,6 +1605,173 @@ def _get(attr: str, yaml_val: Any, default: Any) -> Any: } +def _extract_and_match_barcodes_for_reads( + sequences: List[str], + *, + barcode_length: int, + barcode_search_window: int, + barcode_max_edit_distance: int, + 
barcode_adapter_matcher: str, + barcode_composite_max_edits: int, + barcode_min_separation: Optional[int], + require_both_ends: bool, + min_barcode_score: Optional[int], + bc_refs: Dict[str, str], + flanking_config: Any, + check_start: bool, + check_end: bool, +) -> List[Dict[str, Optional[str | int]]]: + """Extract and match barcodes for a batch of read sequences. + + Pure worker function suitable for multiprocessing. Returns one dict per read + with keys: BC, BM, B1, B2, B3, B4, B5, B6. + """ + matcher = barcode_adapter_matcher + composite_max_edits = barcode_composite_max_edits + + # Pre-compute end flanking configs for flanking-based extraction + end_flanking_cache: List[FlankingConfig] = [] + flanking_candidates: List[FlankingConfig] = [] + if flanking_config is not None: + if flanking_config.left_ref_end is not None and ( + flanking_config.left_ref_end.adapter_side or flanking_config.left_ref_end.amplicon_side + ): + flanking_candidates.append(flanking_config.left_ref_end) + if ( + flanking_config.right_ref_end is not None + and flanking_config.right_ref_end not in flanking_candidates + and ( + flanking_config.right_ref_end.adapter_side + or flanking_config.right_ref_end.amplicon_side + ) + ): + flanking_candidates.append(flanking_config.right_ref_end) + + for candidate in flanking_candidates: + if candidate.adapter_side and candidate.amplicon_side: + end_flanking_cache.append( + FlankingConfig( + adapter_side=_reverse_complement(candidate.amplicon_side), + amplicon_side=_reverse_complement(candidate.adapter_side), + adapter_pad=candidate.amplicon_pad, + amplicon_pad=candidate.adapter_pad, + ) + ) + elif candidate.adapter_side: + end_flanking_cache.append( + FlankingConfig( + adapter_side=_reverse_complement(candidate.adapter_side), + amplicon_side=None, + adapter_pad=candidate.adapter_pad, + amplicon_pad=candidate.amplicon_pad, + ) + ) + elif candidate.amplicon_side: + end_flanking_cache.append( + FlankingConfig( + adapter_side=None, + 
amplicon_side=_reverse_complement(candidate.amplicon_side), + adapter_pad=candidate.adapter_pad, + amplicon_pad=candidate.amplicon_pad, + ) + ) + else: + end_flanking_cache.append(FlankingConfig(adapter_side=None, amplicon_side=None)) + + results: List[Dict[str, Optional[str | int]]] = [] + for sequence in sequences: + bc_matches: List[Tuple[Optional[str], Optional[int]]] = [ + (None, None), + (None, None), + ] + extracted_start_seq: Optional[str] = None + extracted_end_seq: Optional[str] = None + padded_region: Optional[str] = None + + for i, read_end in enumerate(["start", "end"]): + if i == 0 and not check_start: + continue + if i == 1 and not check_end: + continue + + search_from_start = read_end == "start" + + extracted_bc: Optional[str] = None + padded_region = None + + for ci, candidate in enumerate(flanking_candidates): + end_flanking = candidate if search_from_start else end_flanking_cache[ci] + + extracted_bc, _, _, padded_region = _extract_barcode_with_flanking( + read_sequence=sequence, + target_length=barcode_length, + search_window=barcode_search_window, + search_from_start=search_from_start, + flanking=end_flanking, + adapter_matcher=matcher, + composite_max_edits=composite_max_edits, + ) + + if extracted_bc and read_end == "end": + extracted_bc = _reverse_complement(extracted_bc) + if padded_region is not None: + padded_region = _reverse_complement(padded_region) + + if extracted_bc: + break + + if extracted_bc: + if read_end == "start": + extracted_start_seq = extracted_bc + else: + extracted_end_seq = extracted_bc + match_name, match_dist = _match_barcode_to_references( + extracted_bc, + bc_refs, + max_edit_distance=barcode_max_edit_distance, + min_separation=barcode_min_separation, + padded_region=padded_region, + ) + if match_name is not None: + if min_barcode_score is None or match_dist <= min_barcode_score: + bc_matches[i] = (match_name, match_dist) + + left_match, left_dist = bc_matches[0] + right_match, right_dist = bc_matches[1] + + # 
Determine match type and final barcode assignment + if left_match and right_match: + if left_match == right_match: + match_type = "both" + assigned_bc = left_match + else: + match_type = "mismatch" + assigned_bc = "unclassified" + elif left_match: + match_type = "read_start_only" + assigned_bc = "unclassified" if require_both_ends else left_match + elif right_match: + match_type = "read_end_only" + assigned_bc = "unclassified" if require_both_ends else right_match + else: + match_type = "unclassified" + assigned_bc = "unclassified" + + row: Dict[str, Optional[str | int]] = { + "BC": assigned_bc, + "BM": match_type, + "B1": left_dist, + "B2": right_dist, + "B3": extracted_start_seq, + "B4": extracted_end_seq, + "B5": left_match, + "B6": right_match, + } + results.append(row) + + return results + + def extract_and_assign_barcodes_in_bam( bam_path: str | Path, *, @@ -1657,34 +1786,38 @@ def extract_and_assign_barcodes_in_bam( require_both_ends: bool = False, min_barcode_score: Optional[int] = None, samtools_backend: str | None = "auto", - # New flanking parameters (optional; when provided, use flanking-based extraction) barcode_kit_config: Optional[BarcodeKitConfig] = None, barcode_ends: Optional[str] = None, barcode_amplicon_gap_tolerance: int = 5, + threads: Optional[int] = None, ) -> Path: - """Extract barcodes from reads and assign best-matching barcode from reference set. + """Extract barcodes from reads and write results to a Parquet sidecar file. 
This function extracts barcode sequences adjacent to adapter sequences at read ends, - matches them against a reference barcode set, and writes BAM tags for: + matches them against a reference barcode set, and writes results to a Parquet sidecar + file (``.barcode_tags.parquet``) with columns: - BC: Assigned barcode name (or "unclassified") - B1: Read-start match edit distance (if found) - B2: Read-end match edit distance (if found) + - B3: Read-start extracted sequence (if found) + - B4: Read-end extracted sequence (if found) - B5: Read-start barcode name (if found) - B6: Read-end barcode name (if found) - BM: Match type ("both", "read_start_only", "read_end_only", "mismatch", "unclassified") - When ``barcode_kit_config`` with flanking sequences is provided, extraction uses - ``_extract_sequence_with_flanking`` instead of - ``_extract_barcode_adjacent_to_adapter_on_read_end``. + When ``threads`` > 1, barcode extraction is parallelized across CPU cores using + multiprocessing while BAM I/O remains single-threaded. Parameters ---------- bam_path : str or Path - Path to the input BAM file (will be modified in place). + Path to the input BAM file (not modified). barcode_adapters : List[Optional[str]] Two-element list of adapter sequences: [left_adapter, right_adapter]. - Either can be None to skip that end. + Either can be None to skip that end. Legacy parameter retained for + backwards compatibility; adapters are converted to flanking config + by the caller. barcode_references : Dict[str, str] Mapping of barcode names to barcode sequences. barcode_length : int, optional @@ -1704,20 +1837,22 @@ def extract_and_assign_barcodes_in_bam( min_barcode_score : int, optional Minimum edit distance threshold. samtools_backend : str or None - Backend for BAM indexing. + Backend for BAM reading ("python" for pysam, "cli" for samtools). barcode_kit_config : BarcodeKitConfig, optional - Full barcode kit config with flanking sequences. 
When provided with - flanking data, enables flanking-based extraction. + Barcode kit config with flanking sequences. Required for extraction. barcode_ends : str, optional Which read ends to check: "both", "read_start", "read_end", "left_only", "right_only". barcode_amplicon_gap_tolerance : int Allowed gap/overlap (bp) between amplicon and barcode in amplicon-only extraction. + threads : int, optional + Number of worker processes for barcode extraction. If None or <= 1, + extraction runs in a single process. Returns ------- Path - Path to the modified BAM file. + Path to the Parquet sidecar file containing barcode tags. """ input_bam = Path(bam_path) @@ -1733,35 +1868,15 @@ def extract_and_assign_barcodes_in_bam( ) barcode_length = lengths.pop() - # Determine if we should use flanking-based extraction - use_flanking = barcode_kit_config is not None and barcode_kit_config.flanking is not None - flanking_config = barcode_kit_config.flanking if use_flanking else None + flanking_config = barcode_kit_config.flanking if barcode_kit_config else None + if flanking_config is None: + raise ValueError( + "barcode_kit_config with flanking sequences is required for barcode extraction." 
+ ) effective_barcode_ends = barcode_ends or ( barcode_kit_config.barcode_ends if barcode_kit_config else "both" ) - # Build legacy adapter list or determine which ends to check - if not use_flanking: - # Legacy path: validate adapters - if barcode_adapters is None: - barcode_adapters = [None, None] - elif not isinstance(barcode_adapters, (list, tuple)) or len(barcode_adapters) != 2: - raise ValueError( - "barcode_adapters must be a two-element list: [left_adapter, right_adapter]" - ) - - adapters: List[Optional[str]] = [] - for adapter in barcode_adapters: - if adapter is None: - adapters.append(None) - else: - val = str(adapter).strip().upper() - adapters.append(val if val and val.lower() != "none" else None) - - if all(a is None for a in adapters): - logger.warning("No barcode adapters provided; skipping barcode extraction") - return input_bam - if not barcode_references: raise ValueError("barcode_references must be provided with at least one barcode") @@ -1783,205 +1898,149 @@ def extract_and_assign_barcodes_in_bam( check_end = effective_barcode_ends in ("both", "right_only", "read_end") backend_choice = _resolve_samtools_backend(samtools_backend) - pysam_mod = _require_pysam() - tmp_bam = input_bam.with_name(f"{input_bam.stem}.bc_tmp{input_bam.suffix}") + cpu_count = os.cpu_count() or 1 + num_workers = min(max(1, int(threads)), cpu_count) if threads else 1 + + # ── Shared extraction kwargs ──────────────────────────────────────── + extraction_kwargs = dict( + barcode_length=barcode_length, + barcode_search_window=barcode_search_window, + barcode_max_edit_distance=barcode_max_edit_distance, + barcode_adapter_matcher=matcher, + barcode_composite_max_edits=composite_max_edits, + barcode_min_separation=barcode_min_separation, + require_both_ends=require_both_ends, + min_barcode_score=min_barcode_score, + bc_refs=bc_refs, + flanking_config=flanking_config, + check_start=check_start, + check_end=check_end, + ) + + # ── Single BAM pass: collect read_names + sequences 
───────────────── + read_names: List[str] = [] + is_primary: List[bool] = [] + sequences: List[str] = [] + if backend_choice == "python": + pysam_mod = _require_pysam() + with pysam_mod.AlignmentFile(str(input_bam), "rb") as in_bam: + for read in tqdm( + in_bam.fetch(until_eof=True), desc="Barcode: reading BAM", unit=" reads" + ): + read_names.append(read.query_name) + is_primary.append(not read.is_secondary and not read.is_supplementary) + sequences.append(read.query_sequence or "") + else: + cmd = ["samtools", "view", str(input_bam)] + proc = subprocess.Popen(cmd, stdout=subprocess.PIPE, stderr=subprocess.PIPE, text=True) + assert proc.stdout is not None + for line in tqdm(proc.stdout, desc="Barcode: reading BAM", unit=" reads"): + if not line.strip() or line.startswith("@"): + continue + fields = line.rstrip("\n").split("\t") + if len(fields) < 11: + continue + flag = int(fields[1]) + read_names.append(fields[0]) + is_primary.append(not bool(flag & 0x100) and not bool(flag & 0x800)) + seq = fields[9] + sequences.append("" if seq == "*" else seq) + rc = proc.wait() + if rc != 0: + stderr = proc.stderr.read() if proc.stderr else "" + raise RuntimeError(f"samtools view failed (exit {rc}):\n{stderr}") + total_reads = len(read_names) + + # ── Barcode extraction ────────────────────────────────────────────── + if num_workers <= 1 or total_reads == 0: + all_results = _extract_and_match_barcodes_for_reads(sequences, **extraction_kwargs) + else: + chunk_size = max(1000, ceil(total_reads / num_workers)) + num_chunks = ceil(total_reads / chunk_size) + actual_workers = min(num_workers, num_chunks) + logger.info( + "Barcode extraction: %d reads across %d workers (chunk_size=%d)", + total_reads, + actual_workers, + chunk_size, + ) + with ProcessPoolExecutor(max_workers=actual_workers) as pool: + futures = {} + for chunk_idx in range(num_chunks): + start_idx = chunk_idx * chunk_size + end_idx = min(start_idx + chunk_size, total_reads) + chunk = sequences[start_idx:end_idx] + 
future = pool.submit( + _extract_and_match_barcodes_for_reads, + chunk, + **extraction_kwargs, + ) + futures[future] = chunk_idx + + chunk_results: Dict[int, List[Dict[str, Optional[str | int]]]] = {} + with tqdm( + total=total_reads, + desc=f"Barcode: extracting ({actual_workers} workers)", + unit=" reads", + ) as pbar: + for future in as_completed(futures): + idx = futures[future] + result = future.result() + chunk_results[idx] = result + pbar.update(len(result)) + + all_results: List[Dict[str, Optional[str | int]]] = [] + for chunk_idx in range(num_chunks): + all_results.extend(chunk_results[chunk_idx]) + del chunk_results + + del sequences + + # ── Compute statistics and write Parquet sidecar ──────────────────── + import pandas as pd - # Statistics - total_reads = 0 reads_both_ends = 0 reads_start_only = 0 reads_end_only = 0 reads_unclassified = 0 reads_mismatch_ends = 0 - with ( - pysam_mod.AlignmentFile(str(input_bam), "rb") as in_bam, - pysam_mod.AlignmentFile(str(tmp_bam), "wb", template=in_bam) as out_bam, - ): - for read in in_bam.fetch(until_eof=True): - total_reads += 1 - sequence = read.query_sequence or "" - - bc_matches: List[Tuple[Optional[str], Optional[int]]] = [ - (None, None), - (None, None), - ] - extracted_start_seq: Optional[str] = None - extracted_end_seq: Optional[str] = None - padded_region: Optional[str] = None + tag_rows: List[Dict[str, Optional[str | int]]] = [] + for read_idx in range(total_reads): + row = all_results[read_idx] + match_type = row["BM"] + if match_type == "both": + reads_both_ends += 1 + elif match_type == "read_start_only": + reads_start_only += 1 + elif match_type == "read_end_only": + reads_end_only += 1 + elif match_type == "mismatch": + reads_mismatch_ends += 1 + reads_unclassified += 1 + elif match_type == "unclassified": + reads_unclassified += 1 + # Count unclassified for require_both_ends single-end matches + if require_both_ends and match_type in ("read_start_only", "read_end_only"): + reads_unclassified += 1 
+ + # Only include primary alignments in the sidecar + if not is_primary[read_idx]: + continue - for i, read_end in enumerate(["start", "end"]): - if i == 0 and not check_start: - continue - if i == 1 and not check_end: - continue + tag_row = {"read_name": read_names[read_idx]} + tag_row.update(row) + tag_rows.append(tag_row) - search_from_start = read_end == "start" - - extracted_bc: Optional[str] = None - padded_region = None - - if use_flanking: - flanking_candidates: List[FlankingConfig] = [] - if flanking_config is not None: - if flanking_config.left_ref_end is not None and ( - flanking_config.left_ref_end.adapter_side - or flanking_config.left_ref_end.amplicon_side - ): - flanking_candidates.append(flanking_config.left_ref_end) - if ( - flanking_config.right_ref_end is not None - and flanking_config.right_ref_end not in flanking_candidates - and ( - flanking_config.right_ref_end.adapter_side - or flanking_config.right_ref_end.amplicon_side - ) - ): - flanking_candidates.append(flanking_config.right_ref_end) - - for candidate in flanking_candidates: - end_flanking = candidate - if read_end == "end": - if candidate.adapter_side and candidate.amplicon_side: - end_flanking = FlankingConfig( - adapter_side=_reverse_complement(candidate.amplicon_side), - amplicon_side=_reverse_complement(candidate.adapter_side), - adapter_pad=candidate.amplicon_pad, - amplicon_pad=candidate.adapter_pad, - ) - elif candidate.adapter_side: - end_flanking = FlankingConfig( - adapter_side=_reverse_complement(candidate.adapter_side), - amplicon_side=None, - adapter_pad=candidate.adapter_pad, - amplicon_pad=candidate.amplicon_pad, - ) - elif candidate.amplicon_side: - end_flanking = FlankingConfig( - adapter_side=None, - amplicon_side=_reverse_complement(candidate.amplicon_side), - adapter_pad=candidate.adapter_pad, - amplicon_pad=candidate.amplicon_pad, - ) - else: - end_flanking = FlankingConfig( - adapter_side=None, - amplicon_side=None, - ) - - extracted_bc, _, _, padded_region = 
_extract_barcode_with_flanking( - read_sequence=sequence, - target_length=barcode_length, - search_window=barcode_search_window, - search_from_start=search_from_start, - flanking=end_flanking, - adapter_matcher=matcher, - composite_max_edits=composite_max_edits, - ) + del all_results, read_names, is_primary - if extracted_bc and read_end == "end": - extracted_bc = _reverse_complement(extracted_bc) - if padded_region is not None: - padded_region = _reverse_complement(padded_region) - - if extracted_bc: - break - else: - # Legacy path - adapter = adapters[i] if i < len(adapters) else None - if adapter is None: - continue - search_adapter = _reverse_complement(adapter) if read_end == "end" else adapter - extracted_bc, _ = _extract_barcode_adjacent_to_adapter_on_read_end( - read_sequence=sequence, - adapter_sequence=search_adapter, - barcode_length=barcode_length, - barcode_search_window=barcode_search_window, - search_from_start=search_from_start, - adapter_matcher=matcher, - adapter_max_edits=composite_max_edits, - ) - if extracted_bc and read_end == "end": - extracted_bc = _reverse_complement(extracted_bc) - - if extracted_bc: - if read_end == "start": - extracted_start_seq = extracted_bc - else: - extracted_end_seq = extracted_bc - match_name, match_dist = _match_barcode_to_references( - extracted_bc, - bc_refs, - max_edit_distance=barcode_max_edit_distance, - min_separation=barcode_min_separation, - padded_region=padded_region, - ) - if match_name is not None: - if min_barcode_score is None or match_dist <= min_barcode_score: - bc_matches[i] = (match_name, match_dist) - - left_match, left_dist = bc_matches[0] - right_match, right_dist = bc_matches[1] - - # Determine match type and final barcode assignment - if left_match and right_match: - if left_match == right_match: - match_type = "both" - assigned_bc = left_match - reads_both_ends += 1 - else: - match_type = "mismatch" - assigned_bc = "unclassified" - reads_mismatch_ends += 1 - reads_unclassified += 1 - elif 
left_match: - match_type = "read_start_only" - reads_start_only += 1 - assigned_bc = "unclassified" if require_both_ends else left_match - if require_both_ends: - reads_unclassified += 1 - elif right_match: - match_type = "read_end_only" - reads_end_only += 1 - assigned_bc = "unclassified" if require_both_ends else right_match - if require_both_ends: - reads_unclassified += 1 - else: - match_type = "unclassified" - assigned_bc = "unclassified" - reads_unclassified += 1 - - # Write tags - read.set_tag("BC", assigned_bc, value_type="Z") - read.set_tag("BM", match_type, value_type="Z") - - if extracted_start_seq: - read.set_tag("B3", extracted_start_seq, value_type="Z") - if extracted_end_seq: - read.set_tag("B4", extracted_end_seq, value_type="Z") - - if left_match is not None: - read.set_tag("B1", left_dist, value_type="i") - read.set_tag("B5", left_match, value_type="Z") - if right_match is not None: - read.set_tag("B2", right_dist, value_type="i") - read.set_tag("B6", right_match, value_type="Z") - - out_bam.write(read) - - # Replace original BAM and re-index - tmp_bam.replace(input_bam) - index_paths = ( - input_bam.with_suffix(input_bam.suffix + ".bai"), - Path(str(input_bam) + ".bai"), - ) - for idx_path in index_paths: - if idx_path.exists(): - idx_path.unlink() - if backend_choice == "python": - _index_bam_with_pysam(input_bam) - else: - _index_bam_with_samtools(input_bam) + sidecar_path = input_bam.with_suffix(".barcode_tags.parquet") + df = pd.DataFrame(tag_rows) + for col in ("BC", "BM", "B1", "B2", "B3", "B4", "B5", "B6"): + if col not in df.columns: + df[col] = None + df.to_parquet(sidecar_path, index=False) + del tag_rows, df logger.info( "Barcode extraction complete for %s:\n" @@ -1990,7 +2049,8 @@ def extract_and_assign_barcodes_in_bam( " read_start_only=%d (%.1f%%)\n" " read_end_only=%d (%.1f%%)\n" " mismatch_ends=%d (%.1f%%)\n" - " unclassified=%d (%.1f%%)", + " unclassified=%d (%.1f%%)\n" + " sidecar=%s", input_bam, total_reads, reads_both_ends, @@ 
-2003,9 +2063,10 @@ def extract_and_assign_barcodes_in_bam( 100 * reads_mismatch_ends / max(1, total_reads), reads_unclassified, 100 * reads_unclassified / max(1, total_reads), + sidecar_path, ) - return input_bam + return sidecar_path def _stream_dorado_logs(stderr_iter) -> None: @@ -2979,7 +3040,7 @@ def count_aligned_reads(bam_file, samtools_backend: str | None = "auto"): def annotate_demux_type_from_bi_tag( - bam_path: str | Path, output_path: Optional[str | Path] = None, threshold: float = 0.0 + bam_path: str | Path, output_path: Optional[str | Path] = None, threshold: float = 0.65 ) -> Path: """Annotate reads with a BM tag based on dorado bi per-end barcode scores. @@ -3868,57 +3929,105 @@ def extract_readnames_from_bam(aligned_BAM, samtools_backend: str | None = "auto def separate_bam_by_bc( - input_bam, output_prefix, bam_suffix, split_dir, samtools_backend: str | None = "auto" + input_bam, + output_prefix, + bam_suffix, + split_dir, + samtools_backend: str | None = "auto", + barcode_sidecar: Optional[Path] = None, ): """ - Separates an input BAM file on the BC SAM tag values. + Separates an input BAM file by barcode assignment. + + When *barcode_sidecar* is provided, barcode assignments are read from the + Parquet sidecar file (``read_name → BC`` mapping) instead of from BAM tags. Parameters: input_bam (str): File path to the BAM file to split. output_prefix (str): A prefix to append to the output BAM. bam_suffix (str): A suffix to add to the bam file. - split_dir (str): String indicating path to directory to split BAMs into + split_dir (str): String indicating path to directory to split BAMs into. + samtools_backend (str or None): Backend for BAM I/O. + barcode_sidecar (Path, optional): Path to barcode_tags.parquet sidecar. Returns: None Writes out split BAM files. 
""" - logger.debug("Demultiplexing BAM based on the BC tag") - bam_base = input_bam.name - bam_base_minus_suffix = input_bam.stem + import pandas as pd + bam_base_minus_suffix = input_bam.stem backend_choice = _resolve_samtools_backend(samtools_backend) + # Load barcode assignments from sidecar if available + bc_lookup: Optional[Dict[str, str]] = None + if barcode_sidecar is not None and Path(barcode_sidecar).exists(): + logger.debug("Loading barcode assignments from sidecar: %s", barcode_sidecar) + bc_df = pd.read_parquet(barcode_sidecar, columns=["read_name", "BC"]) + bc_lookup = dict(zip(bc_df["read_name"], bc_df["BC"])) + del bc_df + + # When using a sidecar, only write primary alignments to split BAMs + # so that split BAMs and sidecar are 1:1 on read_name. + primary_only = bc_lookup is not None + if backend_choice == "python": pysam_mod = _require_pysam() - # Open the input BAM file for reading with pysam_mod.AlignmentFile(str(input_bam), "rb") as bam: - # Create a dictionary to store output BAM files output_files = {} - # Iterate over each read in the BAM file for read in bam: - try: - # Get the barcode tag value - bc_tag = read.get_tag("BC", with_value_type=True)[0] - # bc_tag = read.get_tag("BC", with_value_type=True)[0].split('barcode')[1] - # Open the output BAM file corresponding to the barcode - if bc_tag not in output_files: - output_path = ( - split_dir - / f"{output_prefix}_{bam_base_minus_suffix}_{bc_tag}{bam_suffix}" - ) - output_files[bc_tag] = pysam_mod.AlignmentFile( - str(output_path), "wb", header=bam.header - ) - # Write the read to the corresponding output BAM file - output_files[bc_tag].write(read) - except KeyError: - logger.warning(f"BC tag not present for read: {read.query_name}") - # Close all output BAM files + if primary_only and (read.is_secondary or read.is_supplementary): + continue + + # Look up barcode from sidecar or BAM tag + bc_tag = None + if bc_lookup is not None: + bc_tag = bc_lookup.get(read.query_name) + else: + try: + 
bc_tag = read.get_tag("BC", with_value_type=True)[0] + except KeyError: + pass + + if bc_tag is None: + bc_tag = "unclassified" + + if bc_tag not in output_files: + output_path = ( + split_dir / f"{output_prefix}_{bam_base_minus_suffix}_{bc_tag}{bam_suffix}" + ) + output_files[bc_tag] = pysam_mod.AlignmentFile( + str(output_path), "wb", header=bam.header + ) + output_files[bc_tag].write(read) + for output_file in output_files.values(): + output_file.close() + return + + # CLI backend: if we have a sidecar, use pysam-style single-pass splitting + # since samtools view -d can't filter by sidecar + if bc_lookup is not None: + pysam_mod = _require_pysam() + with pysam_mod.AlignmentFile(str(input_bam), "rb") as bam: + output_files = {} + for read in bam: + if read.is_secondary or read.is_supplementary: + continue + + bc_tag = bc_lookup.get(read.query_name, "unclassified") + if bc_tag not in output_files: + output_path = ( + split_dir / f"{output_prefix}_{bam_base_minus_suffix}_{bc_tag}{bam_suffix}" + ) + output_files[bc_tag] = pysam_mod.AlignmentFile( + str(output_path), "wb", header=bam.header + ) + output_files[bc_tag].write(read) for output_file in output_files.values(): output_file.close() return + # CLI backend without sidecar: use samtools to split by BAM BC tag def _collect_bc_tags() -> set[str]: bc_tags: set[str] = set() proc = subprocess.Popen( @@ -3959,7 +4068,11 @@ def _collect_bc_tags() -> set[str]: def split_and_index_BAM( - aligned_sorted_BAM, split_dir, bam_suffix, samtools_backend: str | None = "auto" + aligned_sorted_BAM, + split_dir, + bam_suffix, + samtools_backend: str | None = "auto", + barcode_sidecar: Optional[Path] = None, ): """ A wrapper function for splitting BAMS and indexing them. @@ -3967,6 +4080,7 @@ def split_and_index_BAM( aligned_sorted_BAM (str): A string representing the file path of the aligned_sorted BAM file. split_dir (str): A string representing the file path to the directory to split the BAMs into. 
bam_suffix (str): A suffix to add to the bam file. + barcode_sidecar (Path, optional): Path to barcode_tags.parquet sidecar. Returns: None @@ -3983,6 +4097,7 @@ def split_and_index_BAM( bam_suffix, split_dir, samtools_backend=samtools_backend, + barcode_sidecar=barcode_sidecar, ) # Make a BAM index file for the BAMs in that directory bam_pattern = "*" + bam_suffix diff --git a/src/smftools/informatics/h5ad_functions.py b/src/smftools/informatics/h5ad_functions.py index 0c4961f..272c555 100644 --- a/src/smftools/informatics/h5ad_functions.py +++ b/src/smftools/informatics/h5ad_functions.py @@ -714,6 +714,7 @@ def add_demux_type_from_bm_tag(adata, bm_column="BM"): if bm_column not in adata.obs.columns: logger.warning(f"Column '{bm_column}' not found in adata.obs, cannot derive demux_type") + adata.obs["demux_type"] = "unknown" return logger.info(f"Deriving demux_type from {bm_column} tag") diff --git a/tests/unit/informatics/test_barcode_extraction.py b/tests/unit/informatics/test_barcode_extraction.py index e846e38..d4e57c5 100644 --- a/tests/unit/informatics/test_barcode_extraction.py +++ b/tests/unit/informatics/test_barcode_extraction.py @@ -12,6 +12,7 @@ UMIKitConfig, _build_flanking_from_adapters, _extract_sequence_with_flanking, + _match_barcode_to_references, _parse_flanking_config_from_dict, _parse_per_end_flanking, _reverse_complement, @@ -67,106 +68,22 @@ def _read_bam_tags(bam_path): with _pysam.AlignmentFile(str(bam_path), "rb") as fh: for read in fh.fetch(until_eof=True): out[read.query_name] = dict(read.get_tags()) - return out - - -class TestExtractBarcodeAdjacentToAdapter: - """Tests for barcode extraction from read sequences.""" - - def test_extract_barcode_from_read_start(self): - """Test barcode extraction when adapter is at read start.""" - read = "ACGTAAAACCCCGGGGTTTT" - bc, pos = bam_functions._extract_barcode_adjacent_to_adapter_on_read_end( - read_sequence=read, - adapter_sequence="ACGT", - barcode_length=4, - barcode_search_window=10, - 
search_from_start=True, - adapter_matcher="exact", - ) - assert bc == "AAAA" - assert pos == 0 - - def test_extract_barcode_from_read_end(self): - """Test barcode extraction when adapter is at read end.""" - read = "TTTTCCCCGGGGACGT" - bc, pos = bam_functions._extract_barcode_adjacent_to_adapter_on_read_end( - read_sequence=read, - adapter_sequence="ACGT", - barcode_length=4, - barcode_search_window=10, - search_from_start=False, - adapter_matcher="exact", - ) - assert bc == "GGGG" - assert pos == 12 - - def test_extract_barcode_no_adapter_found(self): - """Test that None is returned when adapter not found.""" - read = "TTTTCCCCGGGGAAAA" - bc, pos = bam_functions._extract_barcode_adjacent_to_adapter_on_read_end( - read_sequence=read, - adapter_sequence="ACGT", - barcode_length=4, - barcode_search_window=10, - search_from_start=True, - adapter_matcher="exact", - ) - assert bc is None - assert pos is None - - def test_extract_barcode_outside_search_window(self): - """Test that adapter outside search window is not found.""" - read = "TTTTTTTTTTACGTAAAA" - bc, pos = bam_functions._extract_barcode_adjacent_to_adapter_on_read_end( - read_sequence=read, - adapter_sequence="ACGT", - barcode_length=4, - barcode_search_window=5, - search_from_start=True, - adapter_matcher="exact", - ) - assert bc is None - assert pos is None - def test_extract_barcode_insufficient_length(self): - """Test that None is returned when barcode region is too short.""" - read = "ACGTAA" # Only 2 bases after adapter - bc, pos = bam_functions._extract_barcode_adjacent_to_adapter_on_read_end( - read_sequence=read, - adapter_sequence="ACGT", - barcode_length=4, - barcode_search_window=10, - search_from_start=True, - adapter_matcher="exact", - ) - assert bc is None - assert pos is None - def test_extract_barcode_empty_sequence(self): - """Test handling of empty read sequence.""" - bc, pos = bam_functions._extract_barcode_adjacent_to_adapter_on_read_end( - read_sequence="", - adapter_sequence="ACGT", - 
barcode_length=4, - barcode_search_window=10, - search_from_start=True, - adapter_matcher="exact", - ) - assert bc is None - assert pos is None +def _read_parquet_tags(parquet_path): + """Return ``{read_name: {tag: value}}`` from a barcode sidecar parquet.""" + from pathlib import Path - def test_extract_barcode_rejects_unknown_matcher(self): - """Test that unknown matcher raises ValueError.""" - with pytest.raises(ValueError, match="adapter_matcher must be one of"): - bam_functions._extract_barcode_adjacent_to_adapter_on_read_end( - read_sequence="ACGTAAAA", - adapter_sequence="ACGT", - barcode_length=4, - barcode_search_window=10, - search_from_start=True, - adapter_matcher="invalid", - ) + p = Path(parquet_path) + df = pd.read_parquet(p) + out = {} + for _, row in df.iterrows(): + tags = {} + for col in ["BC", "BM", "B1", "B2", "B3", "B4", "B5", "B6"]: + if col in row.index and pd.notna(row[col]): + tags[col] = row[col] + out[row["read_name"]] = tags + return out class TestMatchBarcodeToReferences: @@ -937,95 +854,25 @@ def _run(self, tmp_path, reads, **kwargs): barcode_composite_max_edits=0, samtools_backend="python", ) + # Provide a default flanking kit config unless the caller overrides + if "barcode_kit_config" not in kwargs: + defaults["barcode_kit_config"] = BarcodeKitConfig( + barcodes=self.BC_REFS, + barcode_length=4, + flanking=PerEndFlankingConfig( + left_ref_end=FlankingConfig( + adapter_side=self.LEFT_ADAPTER, + amplicon_side=None, + ), + right_ref_end=FlankingConfig( + adapter_side=self.RIGHT_ADAPTER, + amplicon_side=None, + ), + ), + ) defaults.update(kwargs) - bam_functions.extract_and_assign_barcodes_in_bam(bam, **defaults) - return _read_bam_tags(bam) - - # -- Legacy adapter path -------------------------------------------------- - - def test_legacy_both_ends_match(self, tmp_path): - """Both ends match BC01 → BC='BC01', BM='both'.""" - # ACGT(0-3) AAAA(4-7) NNNNNNNN(8-15) AAAA(16-19) TGCA(20-23) - reads = [{"name": "r1", "sequence": 
"ACGTAAAANNNNNNNNAAAATGCA"}] - tags = self._run(tmp_path, reads) - assert tags["r1"]["BC"] == "BC01" - assert tags["r1"]["BM"] == "both" - assert tags["r1"]["B5"] == "BC01" - assert tags["r1"]["B6"] == "BC01" - - def test_legacy_mismatch_ends(self, tmp_path): - """Different barcodes at each end → 'mismatch', 'unclassified'.""" - reads = [{"name": "r1", "sequence": "ACGTAAAANNNNNNNNCCCCTGCA"}] - tags = self._run(tmp_path, reads) - assert tags["r1"]["BC"] == "unclassified" - assert tags["r1"]["BM"] == "mismatch" - assert tags["r1"]["B5"] == "BC01" - assert tags["r1"]["B6"] == "BC02" - - def test_legacy_left_only(self, tmp_path): - """Only left adapter found → assigned from start, BM='read_start_only'.""" - reads = [{"name": "r1", "sequence": "ACGTAAAANNNNNNNNNNNNNNNN"}] - tags = self._run(tmp_path, reads) - assert tags["r1"]["BC"] == "BC01" - assert tags["r1"]["BM"] == "read_start_only" - assert tags["r1"]["B5"] == "BC01" - assert "B6" not in tags["r1"] - - def test_legacy_right_only(self, tmp_path): - """Only right adapter found → assigned from end, BM='read_end_only'.""" - reads = [{"name": "r1", "sequence": "NNNNNNNNNNNNNNNNCCCCTGCA"}] - tags = self._run(tmp_path, reads) - assert tags["r1"]["BC"] == "BC02" - assert tags["r1"]["BM"] == "read_end_only" - assert "B5" not in tags["r1"] - assert tags["r1"]["B6"] == "BC02" - - def test_legacy_unclassified(self, tmp_path): - """No adapters found → 'unclassified'.""" - reads = [{"name": "r1", "sequence": "TTTTTTTTTTTTTTTTTTTTTTTT"}] - tags = self._run(tmp_path, reads) - assert tags["r1"]["BC"] == "unclassified" - assert tags["r1"]["BM"] == "unclassified" - assert "B5" not in tags["r1"] - assert "B6" not in tags["r1"] - - # -- Filtering options ---------------------------------------------------- - - def test_require_both_ends_rejects_single(self, tmp_path): - """require_both_ends=True with only left match → 'unclassified'.""" - reads = [{"name": "r1", "sequence": "ACGTAAAANNNNNNNNNNNNNNNN"}] - tags = self._run(tmp_path, 
reads, require_both_ends=True) - assert tags["r1"]["BC"] == "unclassified" - assert tags["r1"]["BM"] == "read_start_only" - assert tags["r1"]["B5"] == "BC01" - - def test_min_barcode_score_filters_weak_match(self, tmp_path): - """min_barcode_score=0 rejects matches with edit distance > 0.""" - # AAAT has Hamming distance 1 from AAAA (BC01) - reads = [{"name": "r1", "sequence": "ACGTAAATNNNNNNNNAAATTGCA"}] - tags = self._run(tmp_path, reads, min_barcode_score=0) - assert tags["r1"]["BC"] == "unclassified" - assert tags["r1"]["BM"] == "unclassified" - - # -- barcode_ends --------------------------------------------------------- - - def test_barcode_ends_left_only(self, tmp_path): - """barcode_ends='left_only' ignores read end entirely.""" - reads = [{"name": "r1", "sequence": "ACGTAAAANNNNNNNNCCCCTGCA"}] - tags = self._run(tmp_path, reads, barcode_ends="left_only") - assert tags["r1"]["BC"] == "BC01" - assert tags["r1"]["BM"] == "read_start_only" - assert tags["r1"]["B5"] == "BC01" - assert "B6" not in tags["r1"] - - def test_barcode_ends_right_only(self, tmp_path): - """barcode_ends='right_only' ignores read start entirely.""" - reads = [{"name": "r1", "sequence": "ACGTAAAANNNNNNNNCCCCTGCA"}] - tags = self._run(tmp_path, reads, barcode_ends="right_only") - assert tags["r1"]["BC"] == "BC02" - assert tags["r1"]["BM"] == "read_end_only" - assert "B5" not in tags["r1"] - assert tags["r1"]["B6"] == "BC02" + sidecar = bam_functions.extract_and_assign_barcodes_in_bam(bam, **defaults) + return _read_parquet_tags(sidecar) # -- Flanking-based extraction ------------------------------------------- @@ -1132,57 +979,51 @@ def test_either_mode_uses_both_when_available(self, tmp_path): assert tags["r1"]["BC"] == "BC01" assert tags["r1"]["BM"] == "both" - def test_either_mode_falls_back_to_amplicon_only(self, tmp_path): - """'either' mode: adapter not found → falls back to amplicon_only. 
- - Forward read with RC construct at right end: - - Left: BC01(AAAA) + amplicon(CGATCGAT) + ... - - Right: ... + RC(amplicon)(=ATCGATCG) + RC(BC01)(=TTTT) - """ - amplicon = "CGATCGAT" - rc_amplicon = _reverse_complement(amplicon) # ATCGATCG - rc_bc = _reverse_complement("AAAA") # TTTT - kit = BarcodeKitConfig( - barcodes=self.BC_REFS, - barcode_length=4, - flanking=PerEndFlankingConfig( - left_ref_end=FlankingConfig(adapter_side="ZZZZ", amplicon_side=amplicon), - right_ref_end=FlankingConfig(adapter_side="ZZZZ", amplicon_side=amplicon), - ), - ) - reads = [{"name": "r1", "sequence": f"AAAA{amplicon}NNNNNNNN{rc_amplicon}{rc_bc}"}] - tags = self._run( - tmp_path, - reads, - barcode_kit_config=kit, - barcode_adapter_matcher="exact", - ) - assert tags["r1"]["BC"] == "BC01" - assert tags["r1"]["BM"] == "both" - # -- Error handling ------------------------------------------------------- def test_empty_references_raises(self, tmp_path): """Empty barcode_references raises ValueError.""" bam = tmp_path / "test.bam" _create_test_bam(bam, [{"name": "r1", "sequence": "ACGTAAAANNNNNNNN"}]) + kit = BarcodeKitConfig( + barcodes={}, + barcode_length=4, + flanking=PerEndFlankingConfig( + left_ref_end=FlankingConfig(adapter_side="ACGT", amplicon_side=None), + right_ref_end=FlankingConfig(adapter_side="TGCA", amplicon_side=None), + ), + ) with pytest.raises(ValueError, match="barcode_references"): bam_functions.extract_and_assign_barcodes_in_bam( bam, barcode_adapters=["ACGT", "TGCA"], barcode_references={}, barcode_adapter_matcher="exact", + barcode_kit_config=kit, samtools_backend="python", ) # -- Multiple reads ------------------------------------------------------- def test_multiple_reads_mixed_outcomes(self, tmp_path): - """Multiple reads produce correct per-read tags.""" + """Multiple reads produce correct per-read tags. + + Uses flanking-based extraction with adapter-only flanking. 
+ Right end: RC(adapter) + RC(barcode), so for BC01=AAAA the right end + has RC(TGCA)=TGCA followed by RC(AAAA)=TTTT, but the code extracts + the barcode before RC(adapter) and then reverse-complements it back. + """ + rc_adapter = _reverse_complement("TGCA") # TGCA + rc_bc01 = _reverse_complement("AAAA") # TTTT + rc_bc02 = _reverse_complement("CCCC") # GGGG reads = [ - {"name": "both_bc01", "sequence": "ACGTAAAANNNNNNNNAAAATGCA"}, - {"name": "mismatch", "sequence": "ACGTAAAANNNNNNNNCCCCTGCA"}, + # Both ends match BC01 + {"name": "both_bc01", "sequence": f"ACGTAAAANNNNNNNN{rc_bc01}{rc_adapter}"}, + # Start=BC01, End=BC02 → mismatch + {"name": "mismatch", "sequence": f"ACGTAAAANNNNNNNN{rc_bc02}{rc_adapter}"}, + # Only left adapter found {"name": "left_bc03", "sequence": "ACGTGGGGNNNNNNNNNNNNNNNN"}, + # No adapters found {"name": "unclassified", "sequence": "TTTTTTTTTTTTTTTTTTTTTTTT"}, ] tags = self._run(tmp_path, reads) @@ -1329,8 +1170,8 @@ def _run(self, tmp_path, reads, **kwargs): samtools_backend="python", ) defaults.update(kwargs) - bam_functions.extract_and_assign_barcodes_in_bam(bam, **defaults) - return _read_bam_tags(bam) + sidecar = bam_functions.extract_and_assign_barcodes_in_bam(bam, **defaults) + return _read_parquet_tags(sidecar) def _make_kit(self, **overrides): kw = dict( diff --git a/tests/unit/informatics/test_umi_annotation.py b/tests/unit/informatics/test_umi_annotation.py index d439ac7..92d4a81 100644 --- a/tests/unit/informatics/test_umi_annotation.py +++ b/tests/unit/informatics/test_umi_annotation.py @@ -1,3 +1,4 @@ +import pandas as pd import pytest from smftools.informatics import bam_functions @@ -53,131 +54,24 @@ def _read_bam_tags(bam_path): return out -def test_validate_umi_config_requires_adapters_when_enabled(): - with pytest.raises(ValueError, match="no UMI adapter sequences were provided"): - bam_functions.validate_umi_config(True, [None, None], 8) - - -def test_validate_umi_config_requires_two_slot_adapter_list(): - with 
pytest.raises(ValueError, match="two-item list"): - bam_functions.validate_umi_config(True, ["ACGT"], 10) - - -def test_validate_umi_config_accepts_directional_two_slot_adapters(): - adapters, length = bam_functions.validate_umi_config(True, ["ACGT", None], 10) - assert adapters == ["ACGT", None] - assert length == 10 - - adapters, length = bam_functions.validate_umi_config(True, [None, "TTAA"], 12) - assert adapters == [None, "TTAA"] - assert length == 12 - - -def test_extract_umi_from_read_start_reports_same_orientation(): - read = "ACGTAAACTGCTGATCGTAG" - umi = bam_functions._extract_umi_adjacent_to_adapter_on_read_end( - read_sequence=read, - adapter_sequence="ACGT", - umi_length=5, - umi_search_window=10, - search_from_start=True, - ) - assert umi == "AAACT" - - -def test_extract_umi_from_read_end_reports_same_orientation(): - read = "GATTACAACCCCGGGTTTT" - umi = bam_functions._extract_umi_adjacent_to_adapter_on_read_end( - read_sequence=read, - adapter_sequence="ACGT", - umi_length=4, - umi_search_window=10, - search_from_start=False, - ) - assert umi is None - - -def test_extract_umi_from_read_end_with_match(): - read = "GATTACAACCCCGGGTTTT" - umi = bam_functions._extract_umi_adjacent_to_adapter_on_read_end( - read_sequence=read, - adapter_sequence="GGG", - umi_length=4, - umi_search_window=10, - search_from_start=False, - ) - assert umi == "CCCC" - - -def test_extract_umi_respects_search_window(): - read = "TTTTACGTAAAATTTT" - umi = bam_functions._extract_umi_adjacent_to_adapter_on_read_end( - read_sequence=read, - adapter_sequence="ACGT", - umi_length=4, - umi_search_window=1, - search_from_start=True, - ) - assert umi is None - - -def test_extract_umi_uses_adapter_occurrence_nearest_targeted_end(): - # Read has two "ACGT" adapters. When searching from end, should use the - # one nearest to the end (second occurrence) and extract UMI before it. 
- # Structure: NNNNNNNN ACGT AAAA TTTT ACGT GGGG - # ^^^1 ^^^^ ^^^2 - # UMI adapter (nearest to end) - read = "NNNNNNNNACGTAAAATTTTACGTGGGG" - umi = bam_functions._extract_umi_adjacent_to_adapter_on_read_end( - read_sequence=read, - adapter_sequence="ACGT", - umi_length=4, - umi_search_window=10, - search_from_start=False, - ) - # UMI is extracted BEFORE the adapter when searching from end - assert umi == "TTTT" - - -def test_extract_umi_rejects_unknown_matcher(): - with pytest.raises(ValueError, match="adapter_matcher must be one of"): - bam_functions._extract_umi_adjacent_to_adapter_on_read_end( - read_sequence="ACGTAAAA", - adapter_sequence="ACGT", - umi_length=4, - umi_search_window=10, - search_from_start=True, - adapter_matcher="unknown", - ) - - -def test_extract_umi_can_use_edlib_matcher(monkeypatch): - class _FakeEdlib: - @staticmethod - def align(_query, _target, mode, task, k): - assert mode == "HW" - assert task == "locations" - assert k == 1 - return {"editDistance": 1, "locations": [(0, 3)]} +def _read_parquet_tags(parquet_path): + """Return ``{read_name: {tag: value}}`` from a Parquet sidecar file. - monkeypatch.setattr(bam_functions, "require", lambda *args, **kwargs: _FakeEdlib()) - umi = bam_functions._extract_umi_adjacent_to_adapter_on_read_end( - read_sequence="ACGTAAAA", - adapter_sequence="ACGA", - umi_length=4, - umi_search_window=10, - search_from_start=True, - adapter_matcher="edlib", - adapter_max_edits=1, - ) - assert umi == "AAAA" - - -def test_target_read_end_for_ref_side_respects_strand(): - assert bam_functions._target_read_end_for_ref_side(False, "left") == "start" - assert bam_functions._target_read_end_for_ref_side(False, "right") == "end" - assert bam_functions._target_read_end_for_ref_side(True, "left") == "end" - assert bam_functions._target_read_end_for_ref_side(True, "right") == "start" + Only includes non-null tag values to match the BAM tag semantics + (missing tag == not present). 
+ """ + df = pd.read_parquet(parquet_path) + out = {} + for _, row in df.iterrows(): + tags = {} + for col in df.columns: + if col == "read_name": + continue + val = row[col] + if pd.notna(val): + tags[col] = val + out[row["read_name"]] = tags + return out class TestUMIFlankingExtraction: @@ -318,52 +212,29 @@ def test_per_end_different_flanking(self): class TestAnnotateUmiTagsInBam: """Integration tests for UMI annotation orchestration.""" - LEFT_ADAPTER = "ACGT" - RIGHT_ADAPTER = "TGCA" - def _run(self, tmp_path, reads, **kwargs): - """Create BAM → run UMI annotation → return {name: {tag: val}}.""" + """Create BAM -> run UMI annotation -> return {name: {tag: val}} from Parquet sidecar.""" bam = tmp_path / "test.bam" _create_test_bam(bam, reads) defaults = dict( use_umi=True, - umi_adapters=[self.LEFT_ADAPTER, self.RIGHT_ADAPTER], - umi_length=4, + umi_kit_config=UMIKitConfig( + flanking=PerEndFlankingConfig( + left_ref_end=FlankingConfig(adapter_side="ACGT"), + right_ref_end=FlankingConfig(adapter_side="TGCA"), + ), + length=4, + umi_ends="both", + umi_flank_mode="adapter_only", + ), umi_search_window=200, umi_adapter_matcher="exact", umi_adapter_max_edits=0, samtools_backend="python", ) defaults.update(kwargs) - bam_functions.annotate_umi_tags_in_bam(bam, **defaults) - return _read_bam_tags(bam) - - # -- Legacy adapter path -------------------------------------------------- - - def test_legacy_both_ends_umi(self, tmp_path): - """UMI extracted at both ends → U1, U2, and combined RX.""" - # ACGT(0-3) AAAA(4-7) NNNNNNNN(8-15) CCCC(16-19) TGCA(20-23) - reads = [{"name": "r1", "sequence": "ACGTAAAANNNNNNNNCCCCTGCA"}] - tags = self._run(tmp_path, reads) - assert tags["r1"]["U1"] == "AAAA" - assert tags["r1"]["U2"] == "CCCC" - assert tags["r1"]["RX"] == "AAAA-CCCC" - - def test_legacy_left_only_umi(self, tmp_path): - """UMI only at left end → U1 and RX set, no U2.""" - reads = [{"name": "r1", "sequence": "ACGTGGGGNNNNNNNNNNNNNNNN"}] - tags = self._run(tmp_path, reads) 
- assert tags["r1"]["U1"] == "GGGG" - assert "U2" not in tags["r1"] - assert tags["r1"]["RX"] == "GGGG" - - def test_legacy_right_only_umi(self, tmp_path): - """UMI only at right end → U2 and RX set, no U1.""" - reads = [{"name": "r1", "sequence": "NNNNNNNNNNNNNNNNCCCCTGCA"}] - tags = self._run(tmp_path, reads) - assert "U1" not in tags["r1"] - assert tags["r1"]["U2"] == "CCCC" - assert tags["r1"]["RX"] == "CCCC" + sidecar = bam_functions.annotate_umi_tags_in_bam(bam, **defaults) + return _read_parquet_tags(sidecar) # -- use_umi=False -------------------------------------------------------- @@ -374,38 +245,31 @@ def test_use_umi_false_returns_early(self, tmp_path): result = bam_functions.annotate_umi_tags_in_bam( bam, use_umi=False, - umi_adapters=[None, None], - umi_length=0, + umi_kit_config=UMIKitConfig( + flanking=PerEndFlankingConfig( + left_ref_end=FlankingConfig(adapter_side="ACGT"), + ), + length=4, + ), samtools_backend="python", ) assert result == bam + # No sidecar should be created + sidecar = bam.with_suffix(".umi_tags.parquet") + assert not sidecar.exists() + # BAM should have no UMI tags tags = _read_bam_tags(bam) assert "U1" not in tags["r1"] assert "U2" not in tags["r1"] + assert "US" not in tags["r1"] + assert "UE" not in tags["r1"] assert "RX" not in tags["r1"] - - # -- umi_ends filtering --------------------------------------------------- - - def test_umi_ends_left_only(self, tmp_path): - """umi_ends='left_only' skips right end.""" - reads = [{"name": "r1", "sequence": "ACGTAAAANNNNNNNNCCCCTGCA"}] - tags = self._run(tmp_path, reads, umi_ends="left_only") - assert tags["r1"]["U1"] == "AAAA" - assert "U2" not in tags["r1"] - assert tags["r1"]["RX"] == "AAAA" - - def test_umi_ends_right_only(self, tmp_path): - """umi_ends='right_only' skips left end.""" - reads = [{"name": "r1", "sequence": "ACGTAAAANNNNNNNNCCCCTGCA"}] - tags = self._run(tmp_path, reads, umi_ends="right_only") - assert "U1" not in tags["r1"] - assert tags["r1"]["U2"] == "CCCC" - 
assert tags["r1"]["RX"] == "CCCC" + assert "FC" not in tags["r1"] # -- Flanking-based extraction ------------------------------------------- def test_flanking_umi_extraction(self, tmp_path): - """Flanking-based UMI extraction with UMIKitConfig.""" + """Flanking-based UMI extraction with UMIKitConfig (forward read, left_only).""" umi_kit = UMIKitConfig( flanking=PerEndFlankingConfig( left_ref_end=FlankingConfig(adapter_side="ACGT"), @@ -422,12 +286,17 @@ def test_flanking_umi_extraction(self, tmp_path): umi_ends="left_only", umi_flank_mode="adapter_only", ) + # US is delimited: "UMI_seq;slot;flank_seq" + assert tags["r1"]["US"] == "AAAA;top;ACGT" + # Forward read: U1=US, U2=UE assert tags["r1"]["U1"] == "AAAA" assert "U2" not in tags["r1"] + assert "UE" not in tags["r1"] assert tags["r1"]["RX"] == "AAAA" + assert tags["r1"]["FC"] == "top" def test_flanking_top_bottom_across_ends(self, tmp_path): - """Top flank in read start -> U1, bottom flank in read end -> U2 (RC).""" + """Top flank at read start, bottom flank at read end (forward read).""" umi_kit = UMIKitConfig( flanking=PerEndFlankingConfig( left_ref_end=FlankingConfig(adapter_side="GCTA"), @@ -447,32 +316,103 @@ def test_flanking_top_bottom_across_ends(self, tmp_path): umi_ends="both", umi_flank_mode="adapter_only", ) + # Delimited US/UE + assert tags["r1"]["US"] == "GGTT;top;GCTA" + assert tags["r1"]["UE"] == "ACGA;bottom;CCGA" + # Forward read: U1=US, U2=UE assert tags["r1"]["U1"] == "GGTT" assert tags["r1"]["U2"] == "ACGA" assert tags["r1"]["RX"] == "GGTT-ACGA" + assert tags["r1"]["FC"] == "top-bottom" - # -- Strand handling ------------------------------------------------------ + # -- Reverse-read orientation swap ---------------------------------------- - def test_reverse_strand_umi(self, tmp_path): - """Reverse read: left ref → read end, right ref → read start.""" - # TGCA(0-3) GGGG(4-7) NNNNNNNN(8-15) TTTT(16-19) ACGT(20-23) - reads = [{"name": "r1", "sequence": "TGCAGGGGNNNNNNNNTTTTACGT", "is_reverse": 
True}] - tags = self._run(tmp_path, reads) - # Left ref → search from end → find ACGT at 20-24 → UMI before = TTTT - assert tags["r1"]["U1"] == "TTTT" - # Right ref → search from start → find TGCA at 0-4 → UMI after = GGGG + def test_reverse_read_swaps_u1_u2(self, tmp_path): + """Reverse-mapped read: U1=UE, U2=US (swapped from forward).""" + umi_kit = UMIKitConfig( + flanking=PerEndFlankingConfig( + left_ref_end=FlankingConfig(adapter_side="GCTA"), + right_ref_end=FlankingConfig(adapter_side="CCGA"), + ), + length=4, + umi_ends="both", + umi_flank_mode="adapter_only", + ) + # Same sequence as test_flanking_top_bottom_across_ends but reverse-mapped + reads = [{"name": "r1", "sequence": "GCTAGGTTNNNNNNNNNNTCGTTCGG", "is_reverse": True}] + tags = self._run( + tmp_path, + reads, + umi_kit_config=umi_kit, + umi_ends="both", + umi_flank_mode="adapter_only", + ) + # US/UE are positional (unchanged by orientation) + assert tags["r1"]["US"] == "GGTT;top;GCTA" + assert tags["r1"]["UE"] == "ACGA;bottom;CCGA" + # Reverse read: U1=UE, U2=US + assert tags["r1"]["U1"] == "ACGA" # from UE + assert tags["r1"]["U2"] == "GGTT" # from US + assert tags["r1"]["RX"] == "ACGA-GGTT" + assert tags["r1"]["FC"] == "bottom-top" + + # -- umi_ends filtering --------------------------------------------------- + + def test_umi_ends_left_only(self, tmp_path): + """umi_ends='left_only' skips right end.""" + umi_kit = UMIKitConfig( + flanking=PerEndFlankingConfig( + left_ref_end=FlankingConfig(adapter_side="ACGT"), + right_ref_end=FlankingConfig(adapter_side="TGCA"), + ), + length=4, + umi_ends="both", + umi_flank_mode="adapter_only", + ) + reads = [{"name": "r1", "sequence": "ACGTAAAANNNNNNNNCCCCTGCA"}] + tags = self._run(tmp_path, reads, umi_kit_config=umi_kit, umi_ends="left_only") + assert tags["r1"]["US"] == "AAAA;top;ACGT" + # Forward: U1=US + assert tags["r1"]["U1"] == "AAAA" + assert "U2" not in tags["r1"] + assert "UE" not in tags["r1"] + assert tags["r1"]["RX"] == "AAAA" + assert 
tags["r1"]["FC"] == "top" + + def test_umi_ends_right_only(self, tmp_path): + """umi_ends='right_only' skips left end.""" + umi_kit = UMIKitConfig( + flanking=PerEndFlankingConfig( + left_ref_end=FlankingConfig(adapter_side="ACGT"), + right_ref_end=FlankingConfig(adapter_side="TGCA"), + ), + length=4, + umi_ends="both", + umi_flank_mode="adapter_only", + ) + reads = [{"name": "r1", "sequence": "ACGTAAAANNNNNNNNCCCCTGCA"}] + tags = self._run(tmp_path, reads, umi_kit_config=umi_kit, umi_ends="right_only") + assert "US" not in tags["r1"] + # bottom flank from read end: RC(TGCA)=TGCA found, UMI before = CCCC, then RC'd = GGGG + assert tags["r1"]["UE"] == "GGGG;bottom;TGCA" + # Forward read: U1=US=None, U2=UE + assert "U1" not in tags["r1"] assert tags["r1"]["U2"] == "GGGG" - assert tags["r1"]["RX"] == "TTTT-GGGG" + assert tags["r1"]["RX"] == "GGGG" + assert tags["r1"]["FC"] == "bottom" # -- No UMI found --------------------------------------------------------- def test_no_umi_found(self, tmp_path): - """No adapter found → no UMI tags set.""" + """No adapter found -> no UMI tags set.""" reads = [{"name": "r1", "sequence": "TTTTTTTTTTTTTTTTTTTTTTTT"}] tags = self._run(tmp_path, reads) assert "U1" not in tags["r1"] assert "U2" not in tags["r1"] + assert "US" not in tags["r1"] + assert "UE" not in tags["r1"] assert "RX" not in tags["r1"] + assert "FC" not in tags["r1"] # -- Multiple reads ------------------------------------------------------- @@ -483,7 +423,68 @@ def test_multiple_reads_mixed(self, tmp_path): {"name": "left", "sequence": "ACGTGGGGNNNNNNNNNNNNNNNN"}, {"name": "none", "sequence": "TTTTTTTTTTTTTTTTTTTTTTTT"}, ] - tags = self._run(tmp_path, reads) - assert tags["both"]["RX"] == "AAAA-CCCC" + umi_kit = UMIKitConfig( + flanking=PerEndFlankingConfig( + left_ref_end=FlankingConfig(adapter_side="ACGT"), + right_ref_end=FlankingConfig(adapter_side="TGCA"), + ), + length=4, + umi_ends="both", + umi_flank_mode="adapter_only", + ) + tags = self._run(tmp_path, reads, 
umi_kit_config=umi_kit) + # "both" read: top adapter at start, bottom adapter at end + assert "U1" in tags["both"] + assert "US" in tags["both"] + assert "RX" in tags["both"] + assert "FC" in tags["both"] + # "left" read: only top adapter at start (forward: U1=US) + assert tags["left"]["U1"] == "GGGG" + assert tags["left"]["US"] == "GGGG;top;ACGT" assert tags["left"]["RX"] == "GGGG" + assert tags["left"]["FC"] == "top" + # "none" read: no adapters assert "RX" not in tags["none"] + assert "US" not in tags["none"] + assert "FC" not in tags["none"] + + # -- Multiprocessing path ------------------------------------------------- + + def test_multiprocessing_produces_same_results(self, tmp_path): + """threads=2 produces identical tags to single-threaded path.""" + umi_kit = UMIKitConfig( + flanking=PerEndFlankingConfig( + left_ref_end=FlankingConfig(adapter_side="GCTA"), + right_ref_end=FlankingConfig(adapter_side="CCGA"), + ), + length=4, + umi_ends="both", + umi_flank_mode="adapter_only", + ) + reads = [ + {"name": "r1", "sequence": "GCTAGGTTNNNNNNNNNNTCGTTCGG"}, + {"name": "r2", "sequence": "GCTAGGTTNNNNNNNNNNTCGTTCGG"}, + {"name": "r3", "sequence": "TTTTTTTTTTTTTTTTTTTTTTTT"}, + ] + single_dir = tmp_path / "single" + single_dir.mkdir() + multi_dir = tmp_path / "multi" + multi_dir.mkdir() + tags_single = self._run( + single_dir, + reads, + umi_kit_config=umi_kit, + umi_ends="both", + umi_flank_mode="adapter_only", + threads=1, + ) + tags_multi = self._run( + multi_dir, + reads, + umi_kit_config=umi_kit, + umi_ends="both", + umi_flank_mode="adapter_only", + threads=2, + ) + for name in ("r1", "r2", "r3"): + assert tags_single[name] == tags_multi[name]