Skip to content

Commit a357704

Browse files
committed
refact: rewrite Optional/Union as py3.10 union op
1 parent 9d34be5 commit a357704

20 files changed

+158
-174
lines changed

strkit/call/allele.py

Lines changed: 13 additions & 13 deletions
Original file line number | Diff line number | Diff line change
@@ -20,7 +20,7 @@
2020
from warnings import simplefilter
2121

2222
from numpy.typing import NDArray
23-
from typing import Iterable, Literal, Optional, TypedDict, Union
23+
from typing import Iterable, Literal, TypedDict, Union
2424

2525
import strkit.constants as cc
2626

@@ -34,7 +34,7 @@
3434
"call_alleles",
3535
]
3636

37-
RepeatCounts = Union[list[int], tuple[int, ...], NDArray[np.int_]]
37+
RepeatCounts = list[int] | tuple[int, ...] | NDArray[np.int_]
3838

3939

4040
# K-means convergence errors - we expect convergence to some extent with homozygous alleles
@@ -54,7 +54,7 @@
5454
}
5555

5656

57-
def _array_as_int(n: Union[NDArray[np.int_], NDArray[np.float_]]) -> NDArray[np.int32]:
57+
def _array_as_int(n: NDArray[np.int_] | NDArray[np.float_]) -> NDArray[np.int32]:
5858
return np.rint(n).astype(np.int32)
5959

6060

@@ -65,7 +65,7 @@ def _calculate_cis(samples, ci: str = Literal["95", "99"]) -> NDArray[np.int32]:
6565
return _array_as_int(percentiles)
6666

6767

68-
def get_n_alleles(default_n_alleles: int, sample_sex_chroms: Optional[str], contig: str) -> Optional[int]:
68+
def get_n_alleles(default_n_alleles: int, sample_sex_chroms: str | None, contig: str) -> int | None:
6969
if contig in cc.M_CHROMOSOME_NAMES:
7070
return 1
7171

@@ -105,9 +105,9 @@ def fit_gmm(
105105
hq: bool,
106106
gm_filter_factor: int,
107107
init_params: GMMInitParamsMethod = "k-means++", # TODO: parameterize outside
108-
) -> Optional[object]:
108+
) -> object | None:
109109
sample_rs = sample.reshape(-1, 1)
110-
g: Optional[object] = None
110+
g: object | None = None
111111

112112
n_components: int = n_alleles
113113
while n_components > 0:
@@ -165,26 +165,26 @@ class CallDict(BaseCallDict, total=False):
165165
ps: int
166166

167167

168-
def make_read_weights(read_weights: Optional[Iterable[float]], num_reads: int) -> NDArray[np.float_]:
168+
def make_read_weights(read_weights: Iterable[float] | None, num_reads: int) -> NDArray[np.float_]:
169169
return np.array(
170170
read_weights if read_weights is not None else np.array(([1/num_reads] * num_reads) if num_reads else []))
171171

172172

173173
def call_alleles(
174174
repeats_fwd: NDArray[np.int32],
175175
repeats_rev: NDArray[np.int32],
176-
read_weights_fwd: Optional[Iterable[float]],
177-
read_weights_rev: Optional[Iterable[float]],
176+
read_weights_fwd: Iterable[float] | None,
177+
read_weights_rev: Iterable[float] | None,
178178
params: CallParams,
179179
min_reads: int,
180180
n_alleles: int,
181181
separate_strands: bool,
182182
read_bias_corr_min: int,
183183
gm_filter_factor: int,
184-
seed: Optional[int],
184+
seed: int | None,
185185
logger_: logging.Logger,
186186
debug_str: str,
187-
) -> Optional[CallDict]:
187+
) -> CallDict | None:
188188
fwd_len = repeats_fwd.shape[0]
189189
rev_len = repeats_rev.shape[0]
190190

@@ -268,7 +268,7 @@ def call_alleles(
268268

269269
gmm_cache = {}
270270

271-
def _get_fitted_gmm(s: Union[NDArray[np.int_], NDArray[np.float_]]) -> Optional[object]:
271+
def _get_fitted_gmm(s: NDArray[np.int_] | NDArray[np.float_]) -> object | None:
272272
if (s_t := s.tobytes()) not in gmm_cache:
273273
# Fit Gaussian mixture model to the resampled data
274274
gmm_cache[s_t] = fit_gmm(rng, s, n_alleles, allele_filter, params.hq, gm_filter_factor)
@@ -282,7 +282,7 @@ def _get_fitted_gmm(s: Union[NDArray[np.int_], NDArray[np.float_]]) -> Optional[
282282
for i in range(params.num_bootstrap):
283283
sample = concat_samples[i, :]
284284

285-
g: Optional[object] = _get_fitted_gmm(sample)
285+
g: object | None = _get_fitted_gmm(sample)
286286
if not g:
287287
# Could not fit any Gaussian mixture; skip this allele
288288
return None

strkit/call/call_locus.py

Lines changed: 19 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,7 @@
1515
from sklearn.mixture import GaussianMixture
1616

1717
from numpy.typing import NDArray
18-
from typing import Iterable, Literal, Optional, Union
18+
from typing import Iterable, Literal
1919

2020
from strkit_rust_ext import (
2121
CandidateSNVs,
@@ -28,7 +28,7 @@
2828
)
2929

3030
from strkit.call.allele import CallDict, call_alleles
31-
from strkit.utils import apply_or_none
31+
from strkit.utils import idx_0_getter, apply_or_none
3232

3333
from .align_matrix import match_score
3434
from .cigar import decode_cigar_np
@@ -45,9 +45,7 @@
4545
from .types import (
4646
VCFContigFormat, AssignMethod, AssignMethodWithHP, ConsensusMethod, ReadDict, ReadDictExtra, CalledSNV, LocusResult
4747
)
48-
from .utils import (
49-
idx_0_getter, cn_getter, find_pair_by_ref_pos, normalize_contig, get_new_seed, calculate_seq_with_wildcards
50-
)
48+
from .utils import cn_getter, find_pair_by_ref_pos, normalize_contig, get_new_seed, calculate_seq_with_wildcards
5149

5250

5351
__all__ = [
@@ -224,12 +222,12 @@ def call_alleles_with_haplotags(
224222
# ---
225223
logger_: logging.Logger,
226224
locus_log_str: str,
227-
) -> Optional[dict]:
225+
) -> dict | None:
228226
n_alleles: int = len(haplotags)
229227

230228
hp_reads: list[tuple[ReadDict, ...]] = []
231229
cns: list[NDArray[np.int32]] = []
232-
c_ws: list[Union[NDArray[np.int_], NDArray[np.float_]]] = []
230+
c_ws: list[NDArray[np.int_] | NDArray[np.float_]] = []
233231

234232
for hi, hp in enumerate(haplotags):
235233
# Find reads for cluster
@@ -315,13 +313,13 @@ def _determine_snv_call_phase_set(
315313
# ---
316314
logger_: logging.Logger,
317315
locus_log_str: str,
318-
) -> Optional[int]:
316+
) -> int | None:
319317
# May mutate: cdd_ordered
320318

321319
# We may need to re-order (flip) calls based on SNVs. Check each SNV to see if it's in the SNV genotype/phase-set
322320
# dictionary; otherwise, assign a phase set to all reads which have been used for peak calling here.
323321

324-
call_phase_set: Optional[int]
322+
call_phase_set: int | None
325323

326324
snv_pss_with_should_flip: list[tuple[int, bool]] = []
327325

@@ -464,7 +462,7 @@ def call_alleles_with_incorporated_snvs(
464462
rng: np.random.Generator,
465463
logger_: logging.Logger,
466464
locus_log_str: str,
467-
) -> tuple[AssignMethod, Optional[tuple[dict, list[CalledSNV]]]]:
465+
) -> tuple[AssignMethod, tuple[dict, list[CalledSNV]] | None]:
468466
assign_method: AssignMethod = "dist"
469467

470468
# TODO: parametrize min 'enough to do pure SNV haplotyping' thresholds
@@ -479,7 +477,7 @@ def call_alleles_with_incorporated_snvs(
479477

480478
for read_item in read_dict_items:
481479
rn, read = read_item
482-
snv_bases: Optional[tuple[tuple[str, int], ...]] = read_dict_extra[rn].get("snv_bases")
480+
snv_bases: tuple[tuple[str, int], ...] | None = read_dict_extra[rn].get("snv_bases")
483481

484482
if snv_bases is None:
485483
read_dict_items_with_no_snvs.append(read_item)
@@ -597,7 +595,7 @@ def call_alleles_with_incorporated_snvs(
597595
cdd: list[CallDict] = []
598596

599597
for ci in cluster_indices:
600-
cc: Optional[CallDict] = call_alleles(
598+
cc: CallDict | None = call_alleles(
601599
cns[ci], EMPTY_NP_ARRAY, # Don't bother separating by strand for now...
602600
c_ws[ci], (),
603601
params,
@@ -671,7 +669,7 @@ def call_alleles_with_incorporated_snvs(
671669
# - cdd_ordered
672670
# - called_useful_snvs
673671

674-
call_phase_set: Optional[int] = _determine_snv_call_phase_set(
672+
call_phase_set: int | None = _determine_snv_call_phase_set(
675673
read_dict,
676674
cdd_ordered,
677675
called_useful_snvs,
@@ -724,11 +722,11 @@ def _calc_motif_size_kmers(tr_read_seq_wc: str, tr_len: int, motif_size: int):
724722
yield tr_read_seq_wc[i:i + motif_size]
725723

726724

727-
def _ndarray_serialize(x: Iterable) -> list[Union[int, np.int_]]:
725+
def _ndarray_serialize(x: Iterable) -> list[int | np.int_]:
728726
return list(map(round, x))
729727

730728

731-
def _nested_ndarray_serialize(x: Iterable) -> list[list[Union[int, np.int_]]]:
729+
def _nested_ndarray_serialize(x: Iterable) -> list[list[int | np.int_]]:
732730
return list(map(_ndarray_serialize, x))
733731

734732

@@ -755,13 +753,13 @@ def call_locus(
755753
logger_: logging.Logger,
756754
locus_log_str: str,
757755
# ---
758-
snv_vcf_file: Optional[STRkitVCFReader] = None,
756+
snv_vcf_file: STRkitVCFReader | None = None,
759757
snv_vcf_contigs: tuple[str, ...] = (),
760758
snv_vcf_file_format: VCFContigFormat = "",
761759
# ---
762760
read_file_has_chr: bool = True,
763761
ref_file_has_chr: bool = True,
764-
) -> Optional[LocusResult]:
762+
) -> LocusResult | None:
765763
call_timer = time.perf_counter()
766764

767765
# params de-structuring ------------
@@ -870,7 +868,7 @@ def call_locus(
870868
ref_max_iters = 50
871869
ref_local_search_range = 1
872870

873-
ref_cn: Union[int, float]
871+
ref_cn: int | float
874872
(ref_cn, _), l_offset, r_offset, r_n_is, (ref_left_flank_seq, ref_seq, ref_right_flank_seq) = get_ref_repeat_count(
875873
ref_est_cn,
876874
ref_seq,
@@ -946,7 +944,7 @@ def get_read_length_partition_mean(p_idx: int) -> float:
946944

947945
# Find candidate SNVs, if we're using SNV data
948946

949-
candidate_snvs: Optional[CandidateSNVs] = None # Lookup dictionary for candidate SNVs by position
947+
candidate_snvs: CandidateSNVs | None = None # Lookup dictionary for candidate SNVs by position
950948
if n_overlapping_reads and should_incorporate_snvs and snv_vcf_file:
951949
# ^^ n_overlapping_reads check since otherwise we will have invalid left/right_most_coord
952950
candidate_snvs = snv_vcf_file.get_candidate_snvs(
@@ -991,8 +989,8 @@ def get_read_length_partition_mean(p_idx: int) -> float:
991989
right_flank_start = -1
992990
right_flank_end = -1
993991

994-
q_coords: Optional[NDArray[np.uint64]] = None
995-
r_coords: Optional[NDArray[np.uint64]] = None
992+
q_coords: NDArray[np.uint64] | None = None
993+
r_coords: NDArray[np.uint64] | None = None
996994

997995
# Soft-clipping in large insertions can result from mapping difficulties.
998996
# If we have a soft clip which overlaps with our TR region (+ flank), we can try to recover it

strkit/call/call_sample.py

Lines changed: 9 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,7 @@
1515
from pysam import VariantFile as PySamVariantFile
1616
from queue import Empty as QueueEmpty
1717
from threading import Lock
18-
from typing import Iterable, Literal, Optional
18+
from typing import Iterable, Literal
1919

2020
from .allele import get_n_alleles
2121
from .call_locus import call_locus
@@ -108,7 +108,7 @@ def locus_worker(
108108

109109
snv_vcf_reader = STRkitVCFReader(str(params.snv_vcf)) if params.snv_vcf else None
110110

111-
current_contig: Optional[str] = None
111+
current_contig: str | None = None
112112
results: list[LocusResult] = []
113113

114114
while True:
@@ -192,7 +192,7 @@ def locus_worker(
192192

193193

194194
def progress_worker(
195-
sample_id: Optional[str],
195+
sample_id: str | None,
196196
start_time: float,
197197
log_level: int,
198198
locus_queue: mp.Queue,
@@ -260,8 +260,8 @@ def parse_loci_bed(loci_file: str) -> Iterable[tuple[str, ...]]:
260260

261261
def call_sample(
262262
params: CallParams,
263-
json_path: Optional[str] = None,
264-
vcf_path: Optional[str] = None,
263+
json_path: str | None = None,
264+
vcf_path: str | None = None,
265265
indent_json: bool = False,
266266
output_tsv: bool = True,
267267
) -> None:
@@ -283,7 +283,7 @@ def call_sample(
283283
locus_queue = manager.Queue() # TODO: one queue per contig?
284284

285285
# Cache get_n_alleles calls for contigs
286-
contig_n_alleles: dict[str, Optional[int]] = {}
286+
contig_n_alleles: dict[str, int | None] = {}
287287

288288
def _get_contig_n_alleles(ctg: str):
289289
if ctg not in contig_n_alleles:
@@ -295,12 +295,12 @@ def _get_contig_n_alleles(ctg: str):
295295
num_loci: int = 0
296296
# Keep track of all contigs we are processing to speed up downstream Mendelian inheritance analysis.
297297
contig_set: set[str] = set()
298-
last_contig: Optional[str] = None
298+
last_contig: str | None = None
299299
last_none_append_n_loci: int = 0
300300
for t_idx, t in enumerate(parse_loci_bed(params.loci_file), 1):
301301
contig = t[0]
302302

303-
n_alleles: Optional[int] = _get_contig_n_alleles(contig)
303+
n_alleles: int | None = _get_contig_n_alleles(contig)
304304
if (
305305
n_alleles is None # Sex chromosome, but we don't have a specified sex chromosome karyotype
306306
or n_alleles == 0 # Don't have this chromosome, e.g., Y chromosome for an XX individual
@@ -341,7 +341,7 @@ def _get_contig_n_alleles(ctg: str):
341341

342342
# If we're outputting a VCF, open the file and write the header
343343
sample_id_str = params.sample_id or "sample"
344-
vf: Optional[PySamVariantFile] = None
344+
vf: PySamVariantFile | None = None
345345
if vcf_path is not None:
346346
vh = build_vcf_header(sample_id_str, params.reference_file)
347347
vf = PySamVariantFile(vcf_path if vcf_path != "stdout" else "-", "w", header=vh)

strkit/call/output/tsv.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,9 @@
11
import sys
2-
from typing import Union
32

43
__all__ = ["output_tsv"]
54

65

7-
def _cn_to_str(cn: Union[int, float]) -> str:
6+
def _cn_to_str(cn: int | float) -> str:
87
return f"{cn:.1f}" if isinstance(cn, float) else str(cn)
98

109

strkit/call/output/vcf.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -7,10 +7,10 @@
77
from pysam import FastaFile, VariantFile, VariantHeader, VariantRecord
88
from typing import Iterable, Optional
99

10-
from strkit.utils import cat_strs, is_none
10+
from strkit.utils import cat_strs, is_none, idx_0_getter
1111
from ..allele import get_n_alleles
1212
from ..params import CallParams
13-
from ..utils import idx_0_getter, cn_getter
13+
from ..utils import cn_getter
1414

1515
__all__ = [
1616
"build_vcf_header",
@@ -170,7 +170,7 @@ def output_contig_vcf_lines(
170170
call = result["call"]
171171
call_95_cis = result["call_95_cis"]
172172

173-
seq_alleles_raw: tuple[Optional[str], ...] = (
173+
seq_alleles_raw: tuple[str | None, ...] = (
174174
((ref_seq, ref_start_anchor), *(seq_alts or (None,)))
175175
if call is not None
176176
else ()

0 commit comments

Comments
 (0)