diff --git a/README.md b/README.md index 3c85bcc..e1c3c20 100644 --- a/README.md +++ b/README.md @@ -123,7 +123,7 @@ long read data should still work. ```bash strkit call \ - path/to/read/file.bam \ # [REQUIRED] At least one indexed read file (BAM/CRAM) + path/to/read/file.bam \ # [REQUIRED] One indexed read file (BAM/CRAM) --hq \ # If using PacBio HiFi reads, enable this to get better genotyping & more robust expansion detection --realign \ # If using PacBio HiFi reads, enable this to enable local realignment / read recovery. Good for detecting expansions, but slows down calling. --ref path/to/reference.fa.gz \ # [REQUIRED] Indexed FASTA-formatted reference genome @@ -146,10 +146,6 @@ If you're using HiFi reads as input, **use the `--hq` and `--realign` options** genotype calculation and a greater proportion of reads incorporated into the computed genotypes, respectively. These should not add much performance overhead. -If more than one read file is specified, the reads will be pooled. This can come in handy if you -have e.g. multiple flow cells of the same sample split into different BAM files, or the reads are -split by chromosome. - If you want to **incorporate haplotagging from an alignment file (`HP` tags)** into the process, which should speed up runtime and potentially improve calling results, you must pass the `--use-hp` flag. **This flag is experimental, and has not been tested extensively.** diff --git a/requirements.txt b/requirements.txt index 4b59635..475b1ec 100644 --- a/requirements.txt +++ b/requirements.txt @@ -3,33 +3,33 @@ click==8.1.7 coverage==7.4.1 Cython==3.0.8 exceptiongroup==1.2.0 -Flask==3.0.2 -importlib-metadata==7.0.1 +Flask==3.0.3 +importlib_metadata==7.1.0 iniconfig==2.0.0 -itsdangerous==2.1.2 +itsdangerous==2.2.0 Jinja2==3.1.3 joblib==1.3.2 MarkupSafe==2.1.5 numpy==1.26.4 -orjson==3.9.13 -packaging==23.2 +orjson==3.10.3 +packaging==24.0 pandas==2.2.0 parasail==1.3.4 patsy==0.5.6 pluggy==1.4.0 -pyparsing==3.1.1 -pysam==0.22.0 +pyparsing==3.1.2 +pysam==0.22.1 pytest==7.4.4 pytest-cov==4.1.0 python-dateutil==2.8.2 pytz==2024.1 -scikit-learn==1.4.0 -scipy==1.12.0 +scikit-learn==1.4.2 +scipy==1.13.0 six==1.16.0 -statsmodels==0.14.1 -strkit_rust_ext==0.10.0 -threadpoolctl==3.2.0 +statsmodels==0.14.2 +strkit_rust_ext==0.11.0 +threadpoolctl==3.4.0 tomli==2.0.1 tzdata==2023.4 -Werkzeug==3.0.1 +Werkzeug==3.0.2 zipp==3.17.0 diff --git a/setup.py b/setup.py index dd1b919..c0f2c79 100644 --- a/setup.py +++ b/setup.py @@ -22,7 +22,7 @@ "scikit-learn>=1.2.1,<1.5", "scipy>=1.10,<1.14", "statsmodels>=0.14.0,<0.15", - "strkit_rust_ext==0.10.0", + "strkit_rust_ext==0.11.0", ], description="A toolkit for analyzing variation in short(ish) tandem repeats.", diff --git a/strkit/call/call_locus.py b/strkit/call/call_locus.py index cb24d38..4ceadca 100644 --- a/strkit/call/call_locus.py +++ b/strkit/call/call_locus.py @@ -1,12 +1,10 @@ from __future__ import annotations import functools -import itertools import logging import multiprocessing as mp import multiprocessing.managers as mmg import numpy as np -import pysam import operator import queue import time @@ -15,31 +13,33 @@ from collections import Counter from collections.abc import Sequence from datetime import datetime -from pysam import AlignmentFile, FastaFile +from pysam import FastaFile from sklearn.cluster import AgglomerativeClustering from numpy.typing import NDArray from typing import Iterable, Literal, Optional, Union -from strkit_rust_ext import get_pairs_and_tr_read_coords +from strkit_rust_ext import ( + 
+    get_pairs_and_tr_read_coords, STRkitBAMReader, STRkitAlignedSegment, STRkitVCFReader, CandidateSNVs
+)
 
 from strkit.call.allele import CallDict, call_alleles
 from strkit.utils import apply_or_none
 
 from .align_matrix import match_score
-from .consensus import best_representative
-# from .consensus import consensus_seq
+from .cigar import decode_cigar_np
+# from .consensus import best_representative
+from .consensus import consensus_seq
 from .params import CallParams
 from .realign import realign_read
 from .repeats import get_repeat_count, get_ref_repeat_count
 from .snvs import (
     SNV_OUT_OF_RANGE_CHAR,
-    get_candidate_snvs,
     call_and_filter_useful_snvs,
     process_read_snvs_for_locus_and_calculate_useful_snvs,
 )
-from .types import ReadDict, ReadDictExtra, CandidateSNV, CalledSNV
-from .utils import cat_strs, idx_0_getter, find_pair_by_ref_pos, normalize_contig, round_to_base_pos, get_new_seed
+from .types import ReadDict, ReadDictExtra, CalledSNV
+from .utils import cat_strs, idx_0_getter, find_pair_by_ref_pos, normalize_contig, get_new_seed
 
 
 __all__ = [
@@ -75,7 +75,7 @@ def _mask_low_q_base(base_and_qual: tuple[str, int]) -> str:
     return base_and_qual[0] if base_and_qual[1] > base_wildcard_threshold else "X"
 
 
-def calculate_seq_with_wildcards(qs: str, quals: Optional[list[int]]) -> str:
+def calculate_seq_with_wildcards(qs: str, quals: Optional[NDArray[np.uint8]]) -> str:
     if quals is None:
         return qs  # No quality information, so don't do anything
     return cat_strs(map(_mask_low_q_base, zip(qs, quals)))
 
 
@@ -146,77 +146,6 @@ def get_read_coords_from_matched_pairs(
     return left_flank_start, left_flank_end, right_flank_start, right_flank_end
 
 
-def get_overlapping_segments_and_related_data(
-    bfs: tuple[pysam.AlignmentFile, ...],
-    read_contig: str,
-    left_flank_coord: int,
-    right_flank_coord: int,
-    max_reads: int,
-    logger_: logging.Logger,
-    locus_log_str: str,
-) -> tuple[list[pysam.AlignedSegment], list[int], dict[str, int], int, int]:
-
-    left_most_coord: int = 999999999999999
-    right_most_coord: int = 0
-
-    overlapping_segments: list[pysam.AlignedSegment] = []
-    seen_reads: set[str] = set()
-    read_lengths: list[int] = []
-    n_seen: int = 0
-
-    chimeric_read_status: dict[str, int] = {}
-
-    segment: pysam.AlignedSegment
-
-    for segment in itertools.chain.from_iterable(
-        map(lambda bf: bf.fetch(read_contig, left_flank_coord, right_flank_coord), bfs)
-    ):
-        rn = segment.query_name
-
-        if rn is None:  # Skip reads with no name
-            logger_.debug(f"{locus_log_str} - skipping entry for read with no name")
-            continue
-
-        supp = segment.flag & 2048
-
-        # If we have two overlapping alignments for the same read, we have a chimeric read within the TR
-        # (so probably a large expansion...)
- - chimeric_read_status[rn] = chimeric_read_status.get(rn, 0) | (2 if supp else 1) - - if supp: # Skip supplemental alignments - logger_.debug(f"{locus_log_str} - skipping entry for read {rn} (supplemental)") - continue - - if rn in seen_reads: - logger_.debug(f"{locus_log_str} - skipping entry for read {rn} (already seen)") - continue - - qal: int = segment.query_alignment_length - - if segment.query_length == 0 or qal == 0: - # No aligned segment, skip entry (used to pull segment.query_sequence, but that's extra work) - continue - - if segment.reference_end is None: - logger_.debug(f"{locus_log_str} - skipping entry for read {rn} (reference_end is None, unmapped?)") - continue - - n_seen += 1 - - seen_reads.add(rn) - overlapping_segments.append(segment) - read_lengths.append(qal) - - left_most_coord = min(left_most_coord, segment.reference_start) - right_most_coord = max(right_most_coord, segment.reference_end) - - if n_seen > max_reads: - break - - return overlapping_segments, read_lengths, chimeric_read_status, left_most_coord, right_most_coord - - def calculate_read_distance( n_reads: int, read_dict_items: Sequence[tuple[str, ReadDict]], @@ -284,7 +213,7 @@ def calculate_read_distance( def call_alleles_with_haplotags( params: CallParams, - haplotags: list[str], + haplotags: list[int], ps_id: int, read_dict_items: tuple[tuple[str, ReadDict], ...], # We could derive this again, but we already have before... rng: np.random.Generator, @@ -387,12 +316,8 @@ def _determine_snv_call_phase_set( snv_pss_with_should_flip: list[tuple[int, bool]] = [] l1 = snv_genotype_update_lock.acquire(timeout=300) - l2 = phase_set_lock.acquire(timeout=300) - - if not (l1 and l2): - logger_.error( - f"Failed to acquire one of the following locks: {'snv_genotype_update_lock' if not l1 else ''} " - f"{'phase_set_lock' if not l2 else ''}") + if not l1: + logger_.error("Failed to acquire snv_genotype_update_lock") return None try: @@ -414,10 +339,21 @@ def _determine_snv_call_phase_set( logger_.debug(f"{locus_log_str} - assigned new phase set {call_phase_set} to SNVs {snv_id_list}") - else: + return call_phase_set + + finally: + snv_genotype_update_lock.release() + + l2 = phase_set_lock.acquire(timeout=300) + if not l2: + logger_.error("Failed to acquire phase_set_lock") + return None + + try: + if snv_pss_with_should_flip: # else from above, but we want to release snv_genotype_update_lock first # Have found SNVs, should flip/not flip and assign existing phase set - phase_set_consensus_set = tuple(sorted(set(snv_pss_with_should_flip), key=lambda x: x[0])) + phase_set_consensus_set = tuple(sorted(set(snv_pss_with_should_flip), key=idx_0_getter)) call_phase_set, should_flip = phase_set_consensus_set[0] # Use the phase set synonymous graph to get back to the smallest-count phase set to use for these SNVs @@ -467,11 +403,10 @@ def _determine_snv_call_phase_set( # We're good as-is, so assign the phase set r["ps"] = call_phase_set + return call_phase_set + finally: phase_set_lock.release() - snv_genotype_update_lock.release() - - return call_phase_set def call_alleles_with_incorporated_snvs( @@ -483,7 +418,7 @@ def call_alleles_with_incorporated_snvs( read_dict_extra: dict[str, dict], n_reads_in_dict: int, # We could derive this again, but we already have before... 
     useful_snvs: list[tuple[int, int]],
-    candidate_snvs_dict: dict[int, CandidateSNV],
+    candidate_snvs_dict: CandidateSNVs,
     # ---
     phase_set_lock: threading.Lock,
     phase_set_counter: mmg.ValueProxy,
@@ -737,7 +672,7 @@ def call_locus(
     t_idx: int,
     t: tuple,
     n_alleles: int,
-    bfs: tuple[AlignmentFile, ...],
+    bf: STRkitBAMReader,
     ref: FastaFile,
     params: CallParams,
     # ---
@@ -751,7 +686,7 @@
     seed: int,
     logger_: logging.Logger,
     # ---
-    snv_vcf_file: Optional[pysam.VariantFile] = None,
+    snv_vcf_file: Optional[STRkitVCFReader] = None,
     snv_vcf_contigs: tuple[str, ...] = (),
     snv_vcf_file_format: Literal["chr", "num", "acc", ""] = "",
     # ---
@@ -868,17 +803,21 @@
 
     # If SNV-based peak calling is enabled, we can use this to pre-fetch reference data for all reads to reduce the
    # fairly significant overhead involved in reading from the reference genome for each read to identify SNVs.
-    overlapping_segments: list[pysam.AlignedSegment]
-    read_lengths: list[int]
+    overlapping_segments: NDArray[STRkitAlignedSegment]
+    read_lengths: NDArray[np.uint]
     chimeric_read_status: dict[str, int]
 
     max_reads: int = params.max_reads
 
-    overlapping_segments, read_lengths, chimeric_read_status, left_most_coord, right_most_coord = \
-        get_overlapping_segments_and_related_data(
-            bfs, read_contig, left_flank_coord, right_flank_coord, max_reads, logger_, locus_log_str)
-
-    n_overlapping_reads = len(overlapping_segments)
+    (
+        overlapping_segments,
+        n_overlapping_reads,
+        read_lengths,
+        chimeric_read_status,
+        left_most_coord,
+        right_most_coord,
+    ) = bf.get_overlapping_segments_and_related_data(
+        read_contig, left_flank_coord, right_flank_coord, max_reads, logger_, locus_log_str)
 
     if n_overlapping_reads > params.max_reads:
         logger_.warning(f"{locus_log_str} - skipping locus; too many overlapping reads")
@@ -888,11 +827,12 @@
 
     # Find candidate SNVs, if we're using SNV data
 
-    candidate_snvs_dict: dict[int, CandidateSNV] = {}  # Lookup dictionary for candidate SNVs by position
+    candidate_snvs_dict: Optional[CandidateSNVs] = None  # Lookup structure for candidate SNVs by position, if any
     if n_overlapping_reads and should_incorporate_snvs and snv_vcf_file:
         # ^^ n_overlapping_reads check since otherwise we will have invalid left/right_most_coord
-        candidate_snvs_dict = get_candidate_snvs(
-            snv_vcf_file, snv_vcf_contigs, snv_vcf_file_format, contig, left_most_coord, right_most_coord)
+        candidate_snvs_dict = snv_vcf_file.get_candidate_snvs(
+            snv_vcf_contigs, snv_vcf_file_format, contig, left_most_coord, right_most_coord
+        )
 
     # Build the read dictionary with segment information, copy number, weight, & more. ---------------------------------
@@ -902,7 +842,7 @@
 
     # Various aggregators for if we have a phased alignment file:
     haplotagged_reads_count: int = 0  # Number of reads with HP tags
-    haplotags: set[str] = set()
+    haplotags: set[int] = set()
     phase_sets: Counter[int] = Counter()
 
     # Aggregations for additional read-level data
@@ -912,18 +852,16 @@
 
     n_extremely_poor_scoring_reads = 0
 
-    segment: pysam.AlignedSegment
+    segment: STRkitAlignedSegment
     for segment, read_len in zip(overlapping_segments, read_lengths):
-        rn: str = segment.query_name  # Know this is not None from overlapping_segments calculation
-        segment_start: int = segment.reference_start
-        segment_end: int = segment.reference_end  # Optional[int], but if it's here we know it isn't None
+        rn: str = segment.name  # Know this is not None from overlapping_segments calculation
+        segment_start: int = segment.start
+        segment_end: int = segment.end
 
-        # While .query_sequence is Optional[str], we know (because we skipped all segments with query_sequence is None
-        # above) that this is guaranteed to be, in fact, not None.
         qs: str = segment.query_sequence
 
-        fqqs: Optional[list[int]] = segment.query_qualities
-        cigar_tuples: list[tuple[int, int]] = segment.cigartuples
+        fqqs: NDArray[np.uint8] = segment.query_qualities
+        cigar_tuples: list[tuple[int, int]] = list(decode_cigar_np(segment.raw_cigar))
 
         realigned: bool = False
@@ -1031,7 +969,7 @@
             continue
 
         # we can fit PHRED scores in uint8
-        qqs = np.fromiter(fqqs[left_flank_end:right_flank_start], dtype=np.uint8)
+        qqs = fqqs[left_flank_end:right_flank_start]
         if qqs.shape[0] and (m_qqs := np.mean(qqs)) < (min_avg_phred := params.min_avg_phred):  # TODO: check flank?
             logger_.debug(
                 f"{locus_log_str} - skipping read {rn} due to low average base quality ({m_qqs:.2f} < {min_avg_phred})")
@@ -1054,7 +992,7 @@
         tr_len_w_flank: int = tr_len + flank_len
 
         tr_read_seq = qs[left_flank_end:right_flank_start]
-        tr_read_seq_wc = calculate_seq_with_wildcards(qs[left_flank_end:right_flank_start], qqs)
+        tr_read_seq_wc = calculate_seq_with_wildcards(tr_read_seq, qqs)
 
         if count_kmers != "none":
             read_kmers.clear()
@@ -1122,8 +1060,7 @@
 
         # Reads can show up more than once - TODO - cache this information across loci
         if params.use_hp:
-            tags = dict(segment.get_tags())
-            if (hp := tags.get("HP")) is not None and (ps := tags.get("PS")) is not None:
+            if (hp := segment.hp) is not None and (ps := segment.ps) is not None:
                 orig_ps = int(ps)
 
                 phase_set_lock.acquire(timeout=600)
@@ -1389,7 +1326,7 @@
 
     if call_data and consensus:
         call_seqs = list(
-            map(lambda a: best_representative(map(lambda rr: read_dict_extra[rr]["_tr_seq"], a)), allele_reads)
+            map(lambda a: consensus_seq(map(lambda rr: read_dict_extra[rr]["_tr_seq"], a)), allele_reads)
         )
 
         peak_data = {
@@ -1416,14 +1353,10 @@
 
     # Compile the call into a dictionary with all information to return ------------------------------------------------
 
-    if fractional:
-        def _ndarray_serialize(x: Iterable) -> list[Union[float, np.float_]]:
-            return [round_to_base_pos(y, motif_size) for y in x]
-    else:
-        def _ndarray_serialize(x: Iterable) -> list[Union[int, float, np.int_, np.float_]]:
-            return list(map(round, x))
+    def _ndarray_serialize(x: Iterable) -> list[Union[int, np.int_]]:
+        return list(map(round, x))
 
-    def _nested_ndarray_serialize(x: Iterable) -> list[list[Union[int, float, np.int_, np.float_]]]:
+    def _nested_ndarray_serialize(x: Iterable) -> list[list[Union[int, np.int_]]]:
         return list(map(_ndarray_serialize, x))
 
     call_val = apply_or_none(_ndarray_serialize, call)
diff --git a/strkit/call/call_sample.py b/strkit/call/call_sample.py
index c841332..4ed07f4 100644
--- a/strkit/call/call_sample.py
+++ b/strkit/call/call_sample.py
@@ -44,11 +44,7 @@
 get_locus_index = itemgetter("locus_index")
 
 
-def get_vcf_contig_format(snv_vcf_file: Optional[pysam.VariantFile]) -> Literal["chr", "num", "acc", ""]:
-    if snv_vcf_file is None:
-        return ""
-
-    snv_vcf_contigs = list(map(lambda c: c.name, snv_vcf_file.header.contigs.values()))
+def get_vcf_contig_format(snv_vcf_contigs: list[str]) -> Literal["chr", "num", "acc", ""]:
     if not snv_vcf_contigs or snv_vcf_contigs[0].startswith("chr"):
         return "chr"
     elif NUMERAL_CONTIG_PATTERN.match(snv_vcf_contigs[0]):
@@ -80,6 +76,7 @@ def locus_worker(
         pr = None
 
     import pysam as p
+    from strkit_rust_ext import STRkitBAMReader, STRkitVCFReader
 
     lg: logging.Logger
     if is_single_processed:
@@ -89,14 +86,16 @@ def locus_worker(
         lg = create_process_logger(os.getpid(), params.log_level)
 
     ref = p.FastaFile(params.reference_file)
-    bfs = tuple(p.AlignmentFile(rf, reference_filename=params.reference_file) for rf in params.read_files)
+    bf = STRkitBAMReader(params.read_file, params.reference_file)
 
     snv_vcf_file = p.VariantFile(params.snv_vcf) if params.snv_vcf else None
-    snv_vcf_contigs: list[str] = []
-    vcf_file_format: Literal["chr", "num", "acc", ""] = get_vcf_contig_format(snv_vcf_file)
+    snv_vcf_contigs = list(map(lambda c: c.name, snv_vcf_file.header.contigs.values())) if snv_vcf_file else []
+    vcf_file_format: Literal["chr", "num", "acc", ""] = get_vcf_contig_format(snv_vcf_contigs)
 
     ref_file_has_chr = any(r.startswith("chr") for r in ref.references)
-    read_file_has_chr = any(r.startswith("chr") for bf in bfs for r in bf.references)
+    read_file_has_chr = any(r.startswith("chr") for r in bf.references)
+
+    snv_vcf_reader = STRkitVCFReader(str(params.snv_vcf)) if params.snv_vcf else None
 
     results: list[dict] = []
@@ -108,7 +107,7 @@
         t_idx, t, n_alleles, locus_seed = td
 
         res = call_locus(
-            t_idx, t, n_alleles, bfs, ref, params,
+            t_idx, t, n_alleles, bf, ref, params,
             phase_set_lock,
             phase_set_counter,
             phase_set_remap,
@@ -117,7 +116,7 @@
             snv_genotype_cache,
             seed=locus_seed,
             logger_=lg,
-            snv_vcf_file=snv_vcf_file,
+            snv_vcf_file=snv_vcf_reader,
             snv_vcf_contigs=tuple(snv_vcf_contigs),
             snv_vcf_file_format=vcf_file_format,
             read_file_has_chr=read_file_has_chr,
diff --git a/strkit/call/cigar.py b/strkit/call/cigar.py
index 042a94d..5b19f81 100644
--- a/strkit/call/cigar.py
+++ b/strkit/call/cigar.py
@@ -1,3 +1,5 @@
+import numpy as np
+from numpy.typing import NDArray
 from typing import Iterable, Union
 
 from strkit_rust_ext import get_aligned_pair_matches
@@ -5,6 +7,7 @@
 __all__ = [
     "CoordPair",
     "decode_cigar",
+    "decode_cigar_np",
     "get_aligned_pair_matches",
 ]
@@ -18,3 +21,7 @@ def _decode_cigar_item(item: int) -> tuple[int, int]:
 
 def decode_cigar(encoded_cigar: list[int]) -> Iterable[tuple[int, int]]:
     return map(_decode_cigar_item, encoded_cigar)
+
+
+def decode_cigar_np(encoded_cigar: NDArray[np.uint32]) -> Iterable[tuple[int, int]]:
+    return zip(np.bitwise_and(encoded_cigar, 15, dtype=int), np.right_shift(encoded_cigar, 4, dtype=int))
diff --git a/strkit/call/params.py b/strkit/call/params.py
index 120f6a5..897eb14 100644
--- a/strkit/call/params.py
+++ b/strkit/call/params.py
@@ -15,7 +15,7 @@ def __init__(
         logger: logging.Logger,
-        read_files: tuple[str, ...],
+        read_file: str,
         reference_file: str,
         loci_file: str,
         sample_id: Optional[str],
@@ -39,7 +39,7 @@
         seed: Optional[int] = None,
         processes: int = 1,
     ):
-        self.read_files: tuple[str, ...] = read_files
+        self.read_file: str = read_file
         self.reference_file: str = reference_file
         self.loci_file: str = loci_file
         self.min_reads: int = min_reads
@@ -62,12 +62,12 @@
         self.seed: Optional[int] = seed
         self.processes: int = processes
 
-        bfs = tuple(AlignmentFile(rf, reference_filename=reference_file) for rf in read_files)
+        bf = AlignmentFile(read_file, reference_filename=reference_file)
 
         # noinspection PyTypeChecker
-        bfhs = [bf.header.to_dict() for bf in bfs]
+        bfh = bf.header.to_dict()
 
-        sns: set[str] = {e.get("SM") for bfh in bfhs for e in bfh.get("RG", ()) if e.get("SM")}
+        sns: set[str] = {e.get("SM") for e in bfh.get("RG", ()) if e.get("SM")}
 
         bam_sample_id: Optional[str] = None
         if len(sns) > 1:
@@ -87,7 +87,7 @@ def from_args(cls, logger: logging.Logger, p_args):
         return cls(
             logger,
-            tuple(p_args.read_files),
+            p_args.read_file,
             p_args.ref,
             p_args.loci,
             sample_id=p_args.sample_id,
@@ -114,7 +114,7 @@ def to_dict(self, as_inputted: bool = False):
         return {
-            "read_files": self.read_files,
+            "read_file": self.read_file,
             "reference_file": self.reference_file,
             "min_reads": self.min_reads,
             "min_allele_reads": self.min_allele_reads,
diff --git a/strkit/call/snvs.py b/strkit/call/snvs.py
index 811b215..9cc4824 100644
--- a/strkit/call/snvs.py
+++ b/strkit/call/snvs.py
@@ -1,22 +1,20 @@
 import logging
 import multiprocessing.managers as mmg
-import numpy as np
-import pysam
 import threading
 
 from collections import Counter
-from typing import Literal, Optional
+from typing import Optional
 
-from strkit_rust_ext import get_read_snvs, process_read_snvs_for_locus_and_calculate_useful_snvs
+from strkit_rust_ext import get_read_snvs, process_read_snvs_for_locus_and_calculate_useful_snvs, CandidateSNVs
 
 from strkit.logger import logger
 
-from .types import ReadDict, CandidateSNV, CalledSNV
+from .types import ReadDict, CalledSNV
+from .utils import idx_1_getter
 
 
 __all__ = [
     "SNV_OUT_OF_RANGE_CHAR",
     "SNV_GAP_CHAR",
-    "get_candidate_snvs",
     "get_read_snvs",
     "call_and_filter_useful_snvs",
     "process_read_snvs_for_locus_and_calculate_useful_snvs",
@@ -26,59 +24,12 @@
 SNV_GAP_CHAR = "_"
 
 
-def _human_chrom_to_refseq_accession(contig: str, snv_vcf_contigs: tuple[str, ...]) -> Optional[str]:
-    contig = contig.removeprefix("chr")
-    if contig == "X":
-        contig = "23"
-    if contig == "Y":
-        contig = "24"
-    if contig == "M":
-        contig = "12920"
-    contig = f"NC_{contig.zfill(6)}"
-
-    for vcf_contig in snv_vcf_contigs:
-        if vcf_contig.startswith(contig):
-            return vcf_contig
-
-    return None
-
-
-def get_candidate_snvs(
-    snv_vcf_file: pysam.VariantFile,
-    snv_vcf_contigs: tuple[str, ...],
-    snv_vcf_file_format: Literal["chr", "num", "acc", ""],
-    contig: str,
-    left_most_coord: int,
-    right_most_coord: int,
-) -> dict[int, CandidateSNV]:
-    candidate_snvs_dict: dict[int, CandidateSNV] = {}  # Lookup dictionary for candidate SNVs by position
-
-    snv_contig: str = contig
-    if snv_contig not in snv_vcf_contigs:
-        if snv_vcf_file_format == "num":
-            snv_contig = snv_contig.removeprefix("chr")
-        elif snv_vcf_file_format == "acc":
-            snv_contig = _human_chrom_to_refseq_accession(snv_contig, snv_vcf_contigs)
-        # Otherwise, leave as-is
-
-    for snv in snv_vcf_file.fetch(snv_contig, left_most_coord, right_most_coord + 1):
-        snv_ref = snv.ref
-        snv_alts = snv.alts
-        # check actually is SNV
-        if snv_ref is not None and len(snv_ref) == 1 and snv_alts and any(len(a) == 1 for a in snv_alts):
-            # Convert from 1-based indexing to 0-based indexing!!!
-            # See https://pysam.readthedocs.io/en/latest/api.html#pysam.VariantRecord.pos
-            candidate_snvs_dict[snv.pos - 1] = CandidateSNV(id=snv.id, ref=snv.ref, alts=snv_alts)
-
-    return candidate_snvs_dict
-
-
 def call_and_filter_useful_snvs(
     contig: str,
     n_alleles: int,
     read_dict: dict[str, ReadDict],
     useful_snvs: list[tuple[int, int]],
-    candidate_snvs_dict: dict[int, CandidateSNV],
+    candidate_snvs_dict: CandidateSNVs,
     # ---
     snv_genotype_update_lock: threading.Lock,
     snv_genotype_cache: mmg.DictProxy,
@@ -158,21 +109,27 @@
             call.append(mcc[0])
             rs.append(mcc[1])
 
-        if not skipped and len(set(call)) == 1:
+        snv_call_set = set(call)
+
+        if not skipped and len(snv_call_set) == 1:
             # print(u_idx, u_ref, peak_counts, call, rs)
             logger_.warning(f"{locus_log_str} - for SNV position {u_ref}: got degenerate call {call} from {peak_counts=}")
             skipped = True
 
         snv_rec = candidate_snvs_dict.get(u_ref)
-        snv_id = snv_rec["id"] if snv_rec is not None else f"{contig}_{u_ref}"
-        snv_call = np.array(call).tolist()
+        if snv_rec is not None:
+            snv_id = snv_rec["id"]
+            if snv_id == ".":
+                snv_id = f"{contig}_{u_ref}"
+        else:
+            snv_id = f"{contig}_{u_ref}"
 
         if not skipped:
             snv_genotype_update_lock.acquire(timeout=600)
-            if snv_id in snv_genotype_cache and (cgt := set(snv_genotype_cache[snv_id][0])) != (sgt := set(snv_call)):
+            if snv_id in snv_genotype_cache and (cgt := set(snv_genotype_cache[snv_id][0])) != snv_call_set:
                 logger_.warning(
                     f"{locus_log_str} - got mismatch for SNV {snv_id} (position {u_ref}); cache genotype set {cgt} != "
-                    f"current genotype set {sgt}")
+                    f"current genotype set {snv_call_set}")
                 skipped = True
             snv_genotype_update_lock.release()
@@ -182,10 +139,10 @@
 
         called_snvs.append({
             "id": snv_id,
-            **({"ref": snv_rec["ref"]} if snv_rec is not None else {}),
+            **({"ref": snv_rec["ref_base"]} if snv_rec is not None else {}),
             "pos": u_ref,
-            "call": snv_call,
-            "rcs": np.array(rs).tolist(),
+            "call": call,
+            "rcs": rs,
         })
 
     # If we've skipped any SNVs, filter them out of the read dict - MUTATION
@@ -193,7 +150,7 @@
     for read in read_dict.values():
         if "snvu" not in read:
             continue
-        read["snvu"] = tuple(b for i, b in enumerate(read["snvu"]) if i not in skipped_snvs)
+        read["snvu"] = tuple(map(idx_1_getter, filter(lambda e: e[0] not in skipped_snvs, enumerate(read["snvu"]))))
     logger.debug(f"{locus_log_str} - filtered out {len(skipped_snvs)} not-actually-useful SNVs")
 
     return called_snvs
diff --git a/strkit/call/types.py b/strkit/call/types.py
index cc69f38..e06ccae 100644
--- a/strkit/call/types.py
+++ b/strkit/call/types.py
@@ -31,7 +31,7 @@ class ReadDict(_ReadDictBase, total=False):
     kmers: dict[str, int]  # Dictionary of {kmer: count}
 
     # Only added if HP tags from a haplotagged alignment file are being incorporated:
-    hp: str
+    hp: int
     ps: int
 
     # Only added if SNVs are being incorporated:
diff --git a/strkit/entry.py b/strkit/entry.py
index 5379813..db6d2bc 100644
--- a/strkit/entry.py
+++ b/strkit/entry.py
@@ -1,7 +1,6 @@
 from __future__ import annotations
 
 import argparse
-import logging
 import pathlib
 import os
 import sys
@@ -16,10 +15,9 @@
 
 def add_call_parser_args(call_parser):
     call_parser.add_argument(
-        "read_files",
-        nargs="+",
+        "read_file",
         type=str,
-        help="BAM file(s) with reads to call from. If multiple files are specified, the reads will be pooled.")
+        help="Indexed BAM/CRAM file with reads to call from.")
 
     call_parser.add_argument(
         "--ref", "-r",
@@ -332,7 +330,7 @@ def add_cv_parser_args(al_parser):
 
 
 def add_vs_parser_args(vs_parser):
-    vs_parser.add_argument("align_files", nargs="+", type=str, help="Alignment file(s) to visualize.")
+    vs_parser.add_argument("align_file", type=str, help="Alignment file to visualize.")
 
     vs_parser.add_argument(
         "--align-indices",
         nargs="*",
@@ -470,26 +468,17 @@ def _exec_viz_server(p_args):
     from strkit.json import json
     from strkit.viz.server import run_server as viz_run_server
 
-    align_files = [str(pathlib.Path(af).resolve()) for af in p_args.align_files]
-    align_indices = [str(pathlib.Path(aif).resolve()) for aif in (p_args.align_indices or ())]
+    align_file = str(pathlib.Path(p_args.align_file).resolve())
+    align_index = str(pathlib.Path(p_args.align_index).resolve()) if p_args.align_index else None
 
-    if align_indices and len(align_files) != len(align_indices):
-        raise ParamError(
-            f"Number of alignment indices must match number of alignment files ({len(align_indices)} vs. "
-            f"{len(align_files)})")
+    align_format = os.path.splitext(align_file)[-1].lstrip(".")
+    if align_format not in ("bam", "cram"):
+        raise ParamError(f"File type '{align_format}' not supported")
 
-    align_formats = []
-
-    for af in align_files:
-        if (align_type := os.path.splitext(af)[-1].lstrip(".")) not in ("bam", "cram"):
-            raise ParamError(f"File type '{align_type}' not supported")
-        align_formats.append(align_type)
-
-    if not align_indices:
-        for idx, af in enumerate(align_files):
-            if not os.path.exists(align_index := f"{af}.{'crai' if align_formats[idx] == 'cram' else 'bai'}"):
-                raise ParamError(f"Missing index at '{align_index}'")
-            align_indices.append(align_index)
+    if not align_index:
+        align_index = f"{align_file}.{'crai' if align_format == 'cram' else 'bai'}"
+        if not os.path.exists(align_index):
+            raise ParamError(f"Missing index at '{align_index}'")
 
     # TODO: Conditionally use this code if ref looks like a path
     # ref = pathlib.Path(p_args.ref).resolve()
@@ -518,11 +507,11 @@ def _exec_viz_server(p_args):
         initial_i=idx-1,
         ref=p_args.ref,
         # ref_index=ref_index,
-        align_files=align_files,
-        align_indices=align_indices,
-        align_names=[os.path.basename(af) for af in align_files],
-        align_index_names=[os.path.basename(ai) for ai in align_indices],
-        align_formats=align_formats,
+        align_file=align_file,
+        align_index=align_index,
+        align_name=os.path.basename(align_file),
+        align_index_name=os.path.basename(align_index),
+        align_format=align_format,
     )
 
     return 0
diff --git a/strkit/viz/server.py b/strkit/viz/server.py
index 1aae55b..b7f8f6f 100644
--- a/strkit/viz/server.py
+++ b/strkit/viz/server.py
@@ -66,14 +66,14 @@ def get_call_data(i: int):
 # return send_file(app.config["PARAMS"]["ref_index"], conditional=True)
 
 
-@app.route("/align_files/<int:i>")
-def get_align_file(i: int):
-    return send_file(app.config["PARAMS"]["align_files"][i], conditional=True)
+@app.route("/align_file")
+def get_align_file():
+    return send_file(app.config["PARAMS"]["align_file"], conditional=True)
 
 
-@app.route("/align_indices/<int:i>")
-def get_align_index_file(i: int):
-    return send_file(app.config["PARAMS"]["align_indices"][i], conditional=True)
+@app.route("/align_index")
+def get_align_index_file():
+    return send_file(app.config["PARAMS"]["align_index"], conditional=True)
 
 
 def run_server(call_report, **kwargs):
diff --git a/strkit/viz/templates/browser.html b/strkit/viz/templates/browser.html
index 7763b6e..401b528 100644
--- a/strkit/viz/templates/browser.html
+++ b/strkit/viz/templates/browser.html
@@ -378,24 +378,21 @@
             browser = null;
         }
 
-        const alignTracks = [];
-        for (let i = 0; i < params.cmd.align_names.length; i++) {
-            alignTracks.push({
-                name: params.cmd.align_names[i],
-                url: `/align_files/${i}`,
-                indexURL: `/align_indices/${i}`,
-                format: params.cmd.align_formats[i],
-                showSoftClips: true,
-                showInsertionText: true,
-                showDeletionText: true,
-                color: alignment => readColours[alignment.readName] ?? "rgb(180, 180, 180)",
-                sort: {
-                    chr: callData.contig,
-                    position: ((callData.start + callData.end) / 2).toFixed(0),
-                    option: "INSERT_SIZE",
-                },
-            });
-        }
+        const alignTracks = [{
+            name: params.cmd.align_name,
+            url: `/align_file`,
+            indexURL: `/align_index`,
+            format: params.cmd.align_format,
+            showSoftClips: true,
+            showInsertionText: true,
+            showDeletionText: true,
+            color: alignment => readColours[alignment.readName] ?? "rgb(180, 180, 180)",
+            sort: {
+                chr: callData.contig,
+                position: ((callData.start + callData.end) / 2).toFixed(0),
+                option: "INSERT_SIZE",
+            },
+        }];
 
         // Set up new IGV browser instance
         igv.createBrowser(igvContainer, {