Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
23 commits
Select commit Hold shift + click to select a range
6b47f89
Add a flag for SCFV to the number method.
ALGW71 May 1, 2025
940bfe9
Simple CWC fix for SCFVs.
ALGW71 May 1, 2025
8116ec5
Trial code for new SCFV method, not reliant on CWC. Need to document …
ALGW71 Jun 12, 2025
371c636
More edits and testing on multi chain fasta.
ALGW71 Jun 12, 2025
b741e58
SCFV code when missing CWC is as good as ever. Now add back in CWC an…
ALGW71 Jun 13, 2025
14d0ad7
SCFV mode now working better - CWC does not work. Back to window method.
ALGW71 Jun 13, 2025
f52c9fd
On test set 6/1133 are misnumbered but still ID as SCFVs.
ALGW71 Jun 13, 2025
aeed5d3
Added comments for clarity.
ALGW71 Jun 16, 2025
6cc98ba
Added comments for clarity.
ALGW71 Jun 16, 2025
c8e60bc
Merge branch 'main' into scfv_testing
ALGW71 Jun 16, 2025
c3b2453
Add SCFV test file.
ALGW71 Jun 16, 2025
1cef6e3
Modify SCFv notebook.
ALGW71 Jun 16, 2025
1ba5538
Basics of MHC working - need to account for long seq length.
ALGW71 Jun 25, 2025
1ea3db6
Basics of MHC working - need to account for long seq length.
ALGW71 Jun 25, 2025
7ea92d4
Long sequence issue sorted. Now need to explore edge cases.
ALGW71 Jun 25, 2025
a6f9601
Added cli support for mhc.
ALGW71 Jun 25, 2025
9dd30be
TCR cut off implemented and continue statement skips all MHC.
ALGW71 Jul 1, 2025
4d96cbc
Forgot to modify score cutoff.
ALGW71 Jul 2, 2025
bb80370
Update with new model and tweaks to classifii code.
ALGW71 Jul 2, 2025
d9230aa
Merge branch 'main' into MHC
ALGW71 Jul 2, 2025
9226083
to_scheme() should not renumber MHC/HLA - these have no conversion an…
ALGW71 Jul 2, 2025
7b0eacf
The MHC chain type annotations are far from perfect - need to train m…
ALGW71 Jul 2, 2025
9e94b35
Found an important BUG - self.seq_type was passed to number with type…
ALGW71 Jul 2, 2025
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@ __pycache__/
*.out
*.pdb
*.gz
*.cif

!*expected*.txt
!*expected*.json
Expand Down
2 changes: 1 addition & 1 deletion .pre-commit-config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@ repos:
rev: v5.0.0
hooks:
- id: check-added-large-files
exclude: "^tests/data/.*"
exclude: "^(tests/data/.*|src/anarcii/models/.*)$"
- id: check-yaml
- id: check-merge-conflict
- id: end-of-file-fixer
Expand Down
165 changes: 152 additions & 13 deletions notebook/renumbering_pdb_files.ipynb

Large diffs are not rendered by default.

2,437 changes: 2,437 additions & 0 deletions notebook/scfv_testing.ipynb

Large diffs are not rendered by default.

4 changes: 2 additions & 2 deletions src/anarcii/classifii/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
from anarcii.inference.utils import dataloader
from anarcii.input_data_processing.tokeniser import Tokeniser

type_tokens = {"A": "antibody", "T": "tcr"}
type_tokens = {"A": "antibody", "T": "tcr", "M": "mhc"}


class TypeTokeniser(Tokeniser):
Expand Down Expand Up @@ -103,7 +103,7 @@ def __call__(self, sequences: dict[str, str]) -> dict[str, dict[str, str]]:
tokenized_seqs = []
# Capped at 305 for now.
for seq in sequences.values():
bookend_seq = [self.aa.start, *seq[:235], self.aa.end]
bookend_seq = [self.aa.start, *seq[:305], self.aa.end]
try:
tokenized_seqs.append(torch.from_numpy(self.aa.encode(bookend_seq)))
except KeyError as e:
Expand Down
Binary file modified src/anarcii/classifii/classifii.pt
Binary file not shown.
2 changes: 1 addition & 1 deletion src/anarcii/classifii/model.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
import torch
import torch.nn as nn

seq_max_len = 240
seq_max_len = 310


class EncoderLayer(nn.Module):
Expand Down
2 changes: 1 addition & 1 deletion src/anarcii/classifii/params.json
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
"TRG_PAD_IDX": 0,

"INPUT_DIM": 24,
"OUTPUT_DIM": 4,
"OUTPUT_DIM": 5,
"HID_DIM": 128,
"LAYERS": 1,

Expand Down
4 changes: 2 additions & 2 deletions src/anarcii/cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,9 +15,9 @@
"--seq_type",
type=str,
default="antibody",
choices=["antibody", "tcr", "vnar", "vhh", "shark", "unknown"],
choices=["antibody", "tcr", "vnar", "vhh", "shark", "unknown", "mhc"],
help=(
"Sequence type to process: antibody, tcr, vnar/vhh/shark or unknown"
"Sequence type to process: antibody, tcr, vnar/vhh/shark, mhc or unknown"
"(default: antibody)."
),
)
Expand Down
6 changes: 2 additions & 4 deletions src/anarcii/inference/model.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,6 @@
import torch
import torch.nn as nn

seq_max_len = 210


class EncoderLayer(nn.Module):
def __init__(self, hid_dim, n_heads, pf_dim, dropout, device):
Expand Down Expand Up @@ -49,7 +47,7 @@ def __init__(
pf_dim,
dropout,
device,
max_length=seq_max_len,
max_length,
):
super().__init__()

Expand Down Expand Up @@ -242,7 +240,7 @@ def __init__(
pf_dim,
dropout,
device,
max_length=seq_max_len,
max_length,
):
super().__init__()

Expand Down
13 changes: 13 additions & 0 deletions src/anarcii/inference/model_loader.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,10 +32,16 @@ def __init__(self, sequence_type, mode, device):
def _load_params(self):
if self.type == "shark":
param_filename = f"{self.type}_4_2_128_512.json"

elif self.type == "mhc":
param_filename = f"{self.type}.json"

elif self.mode == "speed":
param_filename = f"{self.type}_4_1_128_512.json"

elif self.mode == "accuracy":
param_filename = f"{self.type}_4_2_128_512.json"

else:
raise ValueError(
"Invalid mode specified. Choose either 'speed' or 'accuracy' or "
Expand Down Expand Up @@ -67,6 +73,11 @@ def _get_model_path(self):
return str(model_path)

def _load_model(self):
if self.type == "mhc":
seq_max_len = 400
else:
seq_max_len = 210

ENC = model.Encoder(
self.INPUT_DIM,
self.HID_DIM,
Expand All @@ -75,6 +86,7 @@ def _load_model(self):
self.ENC_PF_DIM,
self.ENC_DROPOUT,
self.device,
seq_max_len,
)

DEC = model.Decoder(
Expand All @@ -85,6 +97,7 @@ def _load_model(self):
self.DEC_PF_DIM,
self.DEC_DROPOUT,
self.device,
seq_max_len,
)

S2S = model.S2S(ENC, DEC, self.SRC_PAD_IDX, self.TRG_PAD_IDX, self.device)
Expand Down
27 changes: 25 additions & 2 deletions src/anarcii/inference/model_runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,9 @@
# A cutoff score to consider a sequence as well numbered by the language model.
CUTOFF_SCORE = 15

# For TCRs this needs to be higher.
TCR_CUTOFF_SCORE = 25


class ModelRunner:
"""
Expand Down Expand Up @@ -46,14 +49,20 @@ def __init__(self, sequence_type, mode, batch_size, device, verbose):
self.batch_size = batch_size
self.device = device
self.verbose = verbose
self.cut_off = CUTOFF_SCORE

if self.type in ["antibody", "shark"]:
self.sequence_tokeniser = NumberingTokeniser("protein_antibody")
self.number_tokeniser = NumberingTokeniser("number_antibody")

elif self.type == "tcr":
self.cut_off = TCR_CUTOFF_SCORE
self.sequence_tokeniser = NumberingTokeniser("protein_tcr")
self.number_tokeniser = NumberingTokeniser("number_tcr")

elif self.type == "mhc":
self.sequence_tokeniser = NumberingTokeniser("protein_mhc")
self.number_tokeniser = NumberingTokeniser("number_mhc")
else:
raise ValueError(f"Invalid model type: {self.type}")

Expand Down Expand Up @@ -256,7 +265,7 @@ def _predict_numbering(self, dl):
normalized_score = 0.0
error_msg = "Less than 50 non insertion residues numbered."

if normalized_score < CUTOFF_SCORE:
if normalized_score < self.cut_off:
numbering.append(
{
"numbering": None,
Expand Down Expand Up @@ -438,8 +447,22 @@ def _predict_numbering(self, dl):
# The last number depends on chain type - check type here.
if pred_tokens[batch_no, 1] in ["H", "A", "G"]:
last_num = 128
else:
elif pred_tokens[batch_no, 1] in ["L", "K", "B", "D"]:
last_num = 127
else:
# Chain type is not a known antibody/TCR type (presumably MHC, or a
# misclassification): there is no fixed last IMGT number to validate
# against, so record the numbering as-is and skip the checks below.
numbering.append(
{
"numbering": list(zip(nums, residues)),
"chain_type": str(pred_tokens[batch_no, 1]),
"score": normalized_score,
"query_start": start_index,
"query_end": end_index,
"error": None,
"scheme": "imgt",
}
)
continue

try:
last_predicted_num = int(
Expand Down
48 changes: 12 additions & 36 deletions src/anarcii/inference/window_selector.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,5 @@
# import matplotlib.pyplot as plt
import torch

# import matplotlib.pyplot as plt
from anarcii.input_data_processing.tokeniser import NumberingTokeniser

from .model_loader import Loader
Expand All @@ -15,32 +13,6 @@ def first_index_above_threshold(preds, threshold=25):
return None


def detect_peaks(data, threshold=25, min_distance=50):
peaks = []
peak_values = []

for i in range(1, len(data) - 1):
# Check if current point is a peak above the threshold
if data[i] > data[i - 1] and data[i] > data[i + 1] and data[i] > threshold:
# Ensure a minimum distance from the last detected peak
if len(peaks) == 0 or (i - peaks[-1] >= min_distance):
peaks.append(i)
peak_values.append(data[i])

print(
"Number of high scoring chains found: ",
len(peaks),
"\n",
"Indices: ",
peaks,
"\n",
"Values: ",
peak_values,
)

return peaks


class WindowFinder:
def __init__(self, sequence_type, mode, batch_size, device):
self.type = sequence_type.lower()
Expand All @@ -55,6 +27,11 @@ def __init__(self, sequence_type, mode, batch_size, device):
elif self.type == "tcr":
self.sequence_tokeniser = NumberingTokeniser("protein_tcr")
self.number_tokeniser = NumberingTokeniser("number_tcr")

elif self.type == "mhc":
self.sequence_tokeniser = NumberingTokeniser("protein_mhc")
self.number_tokeniser = NumberingTokeniser("number_mhc")

else:
raise ValueError(f"Invalid model type: {self.type}")

Expand All @@ -64,16 +41,11 @@ def _load_model(self):
model_loader = Loader(self.type, self.mode, self.device)
return model_loader.model

def __call__(self, list_of_seqs, fallback: bool = False):
def __call__(self, list_of_seqs, fallback: bool = False, scfv: bool = False):
"""
Select the highest-scoring sequence.

list_of_seqs: Sequences from which to select the highest scoring above a
threshold score.
fallback: If `True` and no sequence scores above the threshold for
selection, return the highest-scoring sequence anyway. Otherwise,
return `None`.

list_of_seqs: Sequences from which to select the highest scoring above a
    threshold score.
fallback: If `True` and no sequence scores above the threshold for
    selection, return the highest-scoring sequence anyway. Otherwise,
    return `None`.
scfv: If `True`, return the raw list of per-window scores instead of
    selecting a single window (used for scFv detection).
"""
dl = dataloader(self.batch_size, list_of_seqs)
preds = []
Expand All @@ -96,7 +68,11 @@ def __call__(self, list_of_seqs, fallback: bool = False):
normalized_likelihood = likelihoods[batch_no, 0].item()
preds.append(normalized_likelihood)

# print(preds)
if scfv:
#### DEBUG CMDS FOR SCFV DEV ####
# plt.plot(preds)
# plt.show()
return preds

# find first index over 25
magic_number = first_index_above_threshold(preds, 25)
Expand Down
Loading