Skip to content

Commit

Permalink
refact: rewrite Optional/Union as py3.10 union op
Browse files (browse the repository at this point in the history)
  • Loading branch information
davidlougheed committed Dec 10, 2024
1 parent 9d34be5 commit a357704
Show file tree
Hide file tree
Showing 20 changed files with 158 additions and 174 deletions.
26 changes: 13 additions & 13 deletions strkit/call/allele.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@
from warnings import simplefilter

from numpy.typing import NDArray
from typing import Iterable, Literal, Optional, TypedDict, Union
from typing import Iterable, Literal, TypedDict, Union

import strkit.constants as cc

Expand All @@ -34,7 +34,7 @@
"call_alleles",
]

RepeatCounts = Union[list[int], tuple[int, ...], NDArray[np.int_]]
RepeatCounts = list[int] | tuple[int, ...] | NDArray[np.int_]


# K-means convergence errors - we expect convergence to some extent with homozygous alleles
Expand All @@ -54,7 +54,7 @@
}


def _array_as_int(n: Union[NDArray[np.int_], NDArray[np.float_]]) -> NDArray[np.int32]:
def _array_as_int(n: NDArray[np.int_] | NDArray[np.float_]) -> NDArray[np.int32]:
return np.rint(n).astype(np.int32)


Expand All @@ -65,7 +65,7 @@ def _calculate_cis(samples, ci: str = Literal["95", "99"]) -> NDArray[np.int32]:
return _array_as_int(percentiles)


def get_n_alleles(default_n_alleles: int, sample_sex_chroms: Optional[str], contig: str) -> Optional[int]:
def get_n_alleles(default_n_alleles: int, sample_sex_chroms: str | None, contig: str) -> int | None:
if contig in cc.M_CHROMOSOME_NAMES:
return 1

Expand Down Expand Up @@ -105,9 +105,9 @@ def fit_gmm(
hq: bool,
gm_filter_factor: int,
init_params: GMMInitParamsMethod = "k-means++", # TODO: parameterize outside
) -> Optional[object]:
) -> object | None:
sample_rs = sample.reshape(-1, 1)
g: Optional[object] = None
g: object | None = None

n_components: int = n_alleles
while n_components > 0:
Expand Down Expand Up @@ -165,26 +165,26 @@ class CallDict(BaseCallDict, total=False):
ps: int


def make_read_weights(read_weights: Optional[Iterable[float]], num_reads: int) -> NDArray[np.float_]:
def make_read_weights(read_weights: Iterable[float] | None, num_reads: int) -> NDArray[np.float_]:
return np.array(
read_weights if read_weights is not None else np.array(([1/num_reads] * num_reads) if num_reads else []))


def call_alleles(
repeats_fwd: NDArray[np.int32],
repeats_rev: NDArray[np.int32],
read_weights_fwd: Optional[Iterable[float]],
read_weights_rev: Optional[Iterable[float]],
read_weights_fwd: Iterable[float] | None,
read_weights_rev: Iterable[float] | None,
params: CallParams,
min_reads: int,
n_alleles: int,
separate_strands: bool,
read_bias_corr_min: int,
gm_filter_factor: int,
seed: Optional[int],
seed: int | None,
logger_: logging.Logger,
debug_str: str,
) -> Optional[CallDict]:
) -> CallDict | None:
fwd_len = repeats_fwd.shape[0]
rev_len = repeats_rev.shape[0]

Expand Down Expand Up @@ -268,7 +268,7 @@ def call_alleles(

gmm_cache = {}

def _get_fitted_gmm(s: Union[NDArray[np.int_], NDArray[np.float_]]) -> Optional[object]:
def _get_fitted_gmm(s: NDArray[np.int_] | NDArray[np.float_]) -> object | None:
if (s_t := s.tobytes()) not in gmm_cache:
# Fit Gaussian mixture model to the resampled data
gmm_cache[s_t] = fit_gmm(rng, s, n_alleles, allele_filter, params.hq, gm_filter_factor)
Expand All @@ -282,7 +282,7 @@ def _get_fitted_gmm(s: Union[NDArray[np.int_], NDArray[np.float_]]) -> Optional[
for i in range(params.num_bootstrap):
sample = concat_samples[i, :]

g: Optional[object] = _get_fitted_gmm(sample)
g: object | None = _get_fitted_gmm(sample)
if not g:
# Could not fit any Gaussian mixture; skip this allele
return None
Expand Down
40 changes: 19 additions & 21 deletions strkit/call/call_locus.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@
from sklearn.mixture import GaussianMixture

from numpy.typing import NDArray
from typing import Iterable, Literal, Optional, Union
from typing import Iterable, Literal

from strkit_rust_ext import (
CandidateSNVs,
Expand All @@ -28,7 +28,7 @@
)

from strkit.call.allele import CallDict, call_alleles
from strkit.utils import apply_or_none
from strkit.utils import idx_0_getter, apply_or_none

from .align_matrix import match_score
from .cigar import decode_cigar_np
Expand All @@ -45,9 +45,7 @@
from .types import (
VCFContigFormat, AssignMethod, AssignMethodWithHP, ConsensusMethod, ReadDict, ReadDictExtra, CalledSNV, LocusResult
)
from .utils import (
idx_0_getter, cn_getter, find_pair_by_ref_pos, normalize_contig, get_new_seed, calculate_seq_with_wildcards
)
from .utils import cn_getter, find_pair_by_ref_pos, normalize_contig, get_new_seed, calculate_seq_with_wildcards


__all__ = [
Expand Down Expand Up @@ -224,12 +222,12 @@ def call_alleles_with_haplotags(
# ---
logger_: logging.Logger,
locus_log_str: str,
) -> Optional[dict]:
) -> dict | None:
n_alleles: int = len(haplotags)

hp_reads: list[tuple[ReadDict, ...]] = []
cns: list[NDArray[np.int32]] = []
c_ws: list[Union[NDArray[np.int_], NDArray[np.float_]]] = []
c_ws: list[NDArray[np.int_] | NDArray[np.float_]] = []

for hi, hp in enumerate(haplotags):
# Find reads for cluster
Expand Down Expand Up @@ -315,13 +313,13 @@ def _determine_snv_call_phase_set(
# ---
logger_: logging.Logger,
locus_log_str: str,
) -> Optional[int]:
) -> int | None:
# May mutate: cdd_ordered

# We may need to re-order (flip) calls based on SNVs. Check each SNV to see if it's in the SNV genotype/phase-set
# dictionary; otherwise, assign a phase set to all reads which have been used for peak calling here.

call_phase_set: Optional[int]
call_phase_set: int | None

snv_pss_with_should_flip: list[tuple[int, bool]] = []

Expand Down Expand Up @@ -464,7 +462,7 @@ def call_alleles_with_incorporated_snvs(
rng: np.random.Generator,
logger_: logging.Logger,
locus_log_str: str,
) -> tuple[AssignMethod, Optional[tuple[dict, list[CalledSNV]]]]:
) -> tuple[AssignMethod, tuple[dict, list[CalledSNV]] | None]:
assign_method: AssignMethod = "dist"

# TODO: parametrize min 'enough to do pure SNV haplotyping' thresholds
Expand All @@ -479,7 +477,7 @@ def call_alleles_with_incorporated_snvs(

for read_item in read_dict_items:
rn, read = read_item
snv_bases: Optional[tuple[tuple[str, int], ...]] = read_dict_extra[rn].get("snv_bases")
snv_bases: tuple[tuple[str, int], ...] | None = read_dict_extra[rn].get("snv_bases")

if snv_bases is None:
read_dict_items_with_no_snvs.append(read_item)
Expand Down Expand Up @@ -597,7 +595,7 @@ def call_alleles_with_incorporated_snvs(
cdd: list[CallDict] = []

for ci in cluster_indices:
cc: Optional[CallDict] = call_alleles(
cc: CallDict | None = call_alleles(
cns[ci], EMPTY_NP_ARRAY, # Don't bother separating by strand for now...
c_ws[ci], (),
params,
Expand Down Expand Up @@ -671,7 +669,7 @@ def call_alleles_with_incorporated_snvs(
# - cdd_ordered
# - called_useful_snvs

call_phase_set: Optional[int] = _determine_snv_call_phase_set(
call_phase_set: int | None = _determine_snv_call_phase_set(
read_dict,
cdd_ordered,
called_useful_snvs,
Expand Down Expand Up @@ -724,11 +722,11 @@ def _calc_motif_size_kmers(tr_read_seq_wc: str, tr_len: int, motif_size: int):
yield tr_read_seq_wc[i:i + motif_size]


def _ndarray_serialize(x: Iterable) -> list[Union[int, np.int_]]:
def _ndarray_serialize(x: Iterable) -> list[int | np.int_]:
return list(map(round, x))


def _nested_ndarray_serialize(x: Iterable) -> list[list[Union[int, np.int_]]]:
def _nested_ndarray_serialize(x: Iterable) -> list[list[int | np.int_]]:
return list(map(_ndarray_serialize, x))


Expand All @@ -755,13 +753,13 @@ def call_locus(
logger_: logging.Logger,
locus_log_str: str,
# ---
snv_vcf_file: Optional[STRkitVCFReader] = None,
snv_vcf_file: STRkitVCFReader | None = None,
snv_vcf_contigs: tuple[str, ...] = (),
snv_vcf_file_format: VCFContigFormat = "",
# ---
read_file_has_chr: bool = True,
ref_file_has_chr: bool = True,
) -> Optional[LocusResult]:
) -> LocusResult | None:
call_timer = time.perf_counter()

# params de-structuring ------------
Expand Down Expand Up @@ -870,7 +868,7 @@ def call_locus(
ref_max_iters = 50
ref_local_search_range = 1

ref_cn: Union[int, float]
ref_cn: int | float
(ref_cn, _), l_offset, r_offset, r_n_is, (ref_left_flank_seq, ref_seq, ref_right_flank_seq) = get_ref_repeat_count(
ref_est_cn,
ref_seq,
Expand Down Expand Up @@ -946,7 +944,7 @@ def get_read_length_partition_mean(p_idx: int) -> float:

# Find candidate SNVs, if we're using SNV data

candidate_snvs: Optional[CandidateSNVs] = None # Lookup dictionary for candidate SNVs by position
candidate_snvs: CandidateSNVs | None = None # Lookup dictionary for candidate SNVs by position
if n_overlapping_reads and should_incorporate_snvs and snv_vcf_file:
# ^^ n_overlapping_reads check since otherwise we will have invalid left/right_most_coord
candidate_snvs = snv_vcf_file.get_candidate_snvs(
Expand Down Expand Up @@ -991,8 +989,8 @@ def get_read_length_partition_mean(p_idx: int) -> float:
right_flank_start = -1
right_flank_end = -1

q_coords: Optional[NDArray[np.uint64]] = None
r_coords: Optional[NDArray[np.uint64]] = None
q_coords: NDArray[np.uint64] | None = None
r_coords: NDArray[np.uint64] | None = None

# Soft-clipping in large insertions can result from mapping difficulties.
# If we have a soft clip which overlaps with our TR region (+ flank), we can try to recover it
Expand Down
18 changes: 9 additions & 9 deletions strkit/call/call_sample.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@
from pysam import VariantFile as PySamVariantFile
from queue import Empty as QueueEmpty
from threading import Lock
from typing import Iterable, Literal, Optional
from typing import Iterable, Literal

from .allele import get_n_alleles
from .call_locus import call_locus
Expand Down Expand Up @@ -108,7 +108,7 @@ def locus_worker(

snv_vcf_reader = STRkitVCFReader(str(params.snv_vcf)) if params.snv_vcf else None

current_contig: Optional[str] = None
current_contig: str | None = None
results: list[LocusResult] = []

while True:
Expand Down Expand Up @@ -192,7 +192,7 @@ def locus_worker(


def progress_worker(
sample_id: Optional[str],
sample_id: str | None,
start_time: float,
log_level: int,
locus_queue: mp.Queue,
Expand Down Expand Up @@ -260,8 +260,8 @@ def parse_loci_bed(loci_file: str) -> Iterable[tuple[str, ...]]:

def call_sample(
params: CallParams,
json_path: Optional[str] = None,
vcf_path: Optional[str] = None,
json_path: str | None = None,
vcf_path: str | None = None,
indent_json: bool = False,
output_tsv: bool = True,
) -> None:
Expand All @@ -283,7 +283,7 @@ def call_sample(
locus_queue = manager.Queue() # TODO: one queue per contig?

# Cache get_n_alleles calls for contigs
contig_n_alleles: dict[str, Optional[int]] = {}
contig_n_alleles: dict[str, int | None] = {}

def _get_contig_n_alleles(ctg: str):
if ctg not in contig_n_alleles:
Expand All @@ -295,12 +295,12 @@ def _get_contig_n_alleles(ctg: str):
num_loci: int = 0
# Keep track of all contigs we are processing to speed up downstream Mendelian inheritance analysis.
contig_set: set[str] = set()
last_contig: Optional[str] = None
last_contig: str | None = None
last_none_append_n_loci: int = 0
for t_idx, t in enumerate(parse_loci_bed(params.loci_file), 1):
contig = t[0]

n_alleles: Optional[int] = _get_contig_n_alleles(contig)
n_alleles: int | None = _get_contig_n_alleles(contig)
if (
n_alleles is None # Sex chromosome, but we don't have a specified sex chromosome karyotype
or n_alleles == 0 # Don't have this chromosome, e.g., Y chromosome for an XX individual
Expand Down Expand Up @@ -341,7 +341,7 @@ def _get_contig_n_alleles(ctg: str):

# If we're outputting a VCF, open the file and write the header
sample_id_str = params.sample_id or "sample"
vf: Optional[PySamVariantFile] = None
vf: PySamVariantFile | None = None
if vcf_path is not None:
vh = build_vcf_header(sample_id_str, params.reference_file)
vf = PySamVariantFile(vcf_path if vcf_path != "stdout" else "-", "w", header=vh)
Expand Down
3 changes: 1 addition & 2 deletions strkit/call/output/tsv.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,9 @@
import sys
from typing import Union

__all__ = ["output_tsv"]


def _cn_to_str(cn: Union[int, float]) -> str:
def _cn_to_str(cn: int | float) -> str:
return f"{cn:.1f}" if isinstance(cn, float) else str(cn)


Expand Down
6 changes: 3 additions & 3 deletions strkit/call/output/vcf.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,10 +7,10 @@
from pysam import FastaFile, VariantFile, VariantHeader, VariantRecord
from typing import Iterable, Optional

from strkit.utils import cat_strs, is_none
from strkit.utils import cat_strs, is_none, idx_0_getter
from ..allele import get_n_alleles
from ..params import CallParams
from ..utils import idx_0_getter, cn_getter
from ..utils import cn_getter

__all__ = [
"build_vcf_header",
Expand Down Expand Up @@ -170,7 +170,7 @@ def output_contig_vcf_lines(
call = result["call"]
call_95_cis = result["call_95_cis"]

seq_alleles_raw: tuple[Optional[str], ...] = (
seq_alleles_raw: tuple[str | None, ...] = (
((ref_seq, ref_start_anchor), *(seq_alts or (None,)))
if call is not None
else ()
Expand Down
Loading

0 comments on commit a357704

Please sign in to comment.