Skip to content

Commit

Permalink
refact(call): use rust extension for repeat counting + best seq repre…
Browse files Browse the repository at this point in the history
…sentative
  • Loading branch information
davidlougheed committed Jun 17, 2024
1 parent 1a3443e commit fd71d08
Show file tree
Hide file tree
Showing 4 changed files with 14 additions and 97 deletions.
8 changes: 4 additions & 4 deletions requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -7,11 +7,11 @@ Flask==3.0.3
importlib_metadata==7.1.0
iniconfig==2.0.0
itsdangerous==2.2.0
Jinja2==3.1.3
Jinja2==3.1.4
joblib==1.3.2
MarkupSafe==2.1.5
numpy==1.26.4
orjson==3.10.3
orjson==3.10.5
packaging==24.0
pandas==2.2.2
parasail==1.3.4
Expand All @@ -24,10 +24,10 @@ pytest-cov==4.1.0
python-dateutil==2.8.2
pytz==2024.1
scikit-learn==1.4.2
scipy==1.13.0
scipy==1.13.1
six==1.16.0
statsmodels==0.14.2
strkit_rust_ext==0.12.2
strkit_rust_ext==0.13.0
threadpoolctl==3.4.0
tomli==2.0.1
tzdata==2023.4
Expand Down
2 changes: 1 addition & 1 deletion setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@
"scikit-learn>=1.2.1,<1.5",
"scipy>=1.10,<1.14",
"statsmodels>=0.14.0,<0.15",
"strkit_rust_ext==0.12.2",
"strkit_rust_ext==0.13.0",
],

description="A toolkit for analyzing variation in short(ish) tandem repeats.",
Expand Down
23 changes: 2 additions & 21 deletions strkit/call/consensus.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
import logging
from random import choice
from strkit_rust_ext import best_representatives as _best_representatives, consensus_seq as _consensus_seq
from typing import Iterable, Optional, Sequence
from strkit_rust_ext import best_representative, consensus_seq as _consensus_seq
from typing import Iterable, Optional

from .types import ConsensusMethod

Expand All @@ -11,24 +10,6 @@
]


def best_representative(seqs: Sequence[str]) -> Optional[str]:
"""
Slightly different from a true consensus - returns the string with the minimum Levenshtein distance to all other
strings for a particular allele. This roughly approximates a true consensus when |seqs| is large. If more than one
best representative exist, one is chosen at random. If |best| == |seqs| or |best| == 0, None is returned since there
is effectively no true consensus.
:param seqs: An iterable of sequences to find the best representative of.
:return: One of the best representative sequences from the passed sequences.
"""
res = _best_representatives(seqs)
if len(res) == 1:
return res.pop()
elif len(res) == 0:
return None
else:
return choice(tuple(res))


def _run_best_representative(seqs: list[str], logger: logging.Logger) -> Optional[tuple[str, ConsensusMethod]]:
res = best_representative(seqs)
method: ConsensusMethod = "best_rep"
Expand Down
78 changes: 7 additions & 71 deletions strkit/call/repeats.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,11 @@
import parasail

from functools import lru_cache
from typing import Literal, Optional, Union
from typing import Literal, Union

from .align_matrix import dna_matrix, indel_penalty, match_score
from strkit_rust_ext import get_repeat_count as _get_repeat_count

from .align_matrix import dna_matrix, indel_penalty
from .utils import idx_1_getter

__all__ = [
Expand Down Expand Up @@ -67,75 +69,9 @@ def get_repeat_count(
local_search_range: int = DEFAULT_LOCAL_SEARCH_RANGE, # TODO: Parametrize for user
step_size: int = 1,
) -> tuple[tuple[int, int], int, int]:

db_seq_profile: parasail.Profile = parasail.profile_create_sat(
f"{flank_left_seq}{tr_seq}{flank_right_seq}", dna_matrix)

max_init_score = (len(motif) * start_count + len(flank_left_seq) + len(flank_right_seq)) * match_score
start_score = score_candidate(db_seq_profile, motif, start_count, flank_left_seq, flank_right_seq)

score_diff = abs(start_score - max_init_score) / max_init_score

if score_diff < 0.05: # TODO: parametrize
# If we're very close to the maximum, explore less.
local_search_range = 1
step_size = 1
elif score_diff < 0.1 and local_search_range > 2:
local_search_range = 2
step_size = 1

explored_sizes: set[int] = {start_count}
best_size: int = start_count
best_score: int = start_score
n_explored: int = 1
to_explore: list[tuple[int, Literal[-1, 1]]] = [(start_count - 1, -1), (start_count + 1, 1)]

while to_explore and n_explored < max_iters:
size_to_explore, direction = to_explore.pop()
if size_to_explore < 0:
continue

skip_search: bool = step_size > local_search_range # whether we're skipping small areas for a faster search

best_size_this_round: Optional[int] = None
best_score_this_round: int = -99999999999

start_size = max(size_to_explore - (local_search_range if (direction == -1 or skip_search) else 0), 0)
end_size = size_to_explore + (local_search_range if (direction == 1 or skip_search) else 0)

for i in range(start_size, end_size + 1):
if i not in explored_sizes:
# Generate a candidate TR tract by copying the provided motif 'i' times & score it
# Separate this from the .get() to postpone computation to until we need it
explored_sizes.add(i)
i_score = score_candidate(db_seq_profile, motif, i, flank_left_seq, flank_right_seq)

if best_size_this_round is None or i_score > best_score_this_round:
best_size_this_round = i
best_score_this_round = i_score

n_explored += 1

if best_size_this_round:
# If this round is the best we've got so far, update the record size/score for the final return
if best_score_this_round > best_score:
best_size = best_size_this_round
best_score = best_score_this_round

if local_search_range > 1 and abs(best_score - max_init_score) / max_init_score < 0.05:
# reduce search range as we approach an optimum
local_search_range = 1

if (best_size_this_round > size_to_explore and
(new_rc := best_size_this_round + step_size) not in explored_sizes):
if new_rc >= 0:
to_explore.append((new_rc, 1))
elif (best_size_this_round < size_to_explore and
(new_rc := best_size_this_round - step_size) not in explored_sizes):
if new_rc >= 0:
to_explore.append((new_rc, -1))

return (best_size, best_score), n_explored, best_size - start_count
return _get_repeat_count(
start_count, tr_seq, flank_left_seq, flank_right_seq, motif, max_iters, local_search_range, step_size
)


def get_ref_repeat_count(
Expand Down

0 comments on commit fd71d08

Please sign in to comment.