diff --git a/.gitignore b/.gitignore new file mode 100644 index 0000000..9295f54 --- /dev/null +++ b/.gitignore @@ -0,0 +1,7 @@ +build +*.egg-info +.python-version +__pycache__ +.mypy_cache + +ablang/model-weights-* diff --git a/ablang/fairseq_mha.py b/ablang/fairseq_mha.py index fd29f8d..992b5e6 100644 --- a/ablang/fairseq_mha.py +++ b/ablang/fairseq_mha.py @@ -750,7 +750,7 @@ def forward( # treats bias in linear module as method. and not torch.jit.is_scripting() # The Multihead attention implemented in pytorch forces strong dimension check - # for input embedding dimention and K,Q,V projection dimension. + # for input embedding dimension and K,Q,V projection dimension. # Since pruning will break the dimension check and it is not easy to modify the pytorch API, # it is preferred to bypass the pytorch MHA when we need to skip embed_dim_check and not self.skip_embed_dim_check @@ -1082,7 +1082,7 @@ def reorder_incremental_state( return incremental_state def set_beam_size(self, beam_size): - """Used for effiecient beamable enc-dec attention""" + """Used for efficient beamable enc-dec attention""" self.beam_size = beam_size def _get_input_buffer( diff --git a/ablang/pretrained.py b/ablang/pretrained.py index 6c8c231..41348ef 100644 --- a/ablang/pretrained.py +++ b/ablang/pretrained.py @@ -49,7 +49,7 @@ def __init__(self, chain="heavy", model_folder="download", random_init=False, nc self.AbLang.to(self.used_device) if not random_init: - self.AbLang.load_state_dict(torch.load(self.model_file, map_location=self.used_device)) + self.AbLang.load_state_dict(torch.load(self.model_file, map_location=self.used_device, weights_only=True)) self.tokenizer = tokenizers.ABtokenizer(os.path.join(model_folder, 'vocab.json')) self.AbRep = self.AbLang.AbRep diff --git a/examples/example-ablang-usecases.ipynb b/examples/example-ablang-usecases.ipynb index 9a03c96..0916a4e 100644 --- a/examples/example-ablang-usecases.ipynb +++ b/examples/example-ablang-usecases.ipynb @@ -41,7 +41,7 @@ "--------------\n", "## **AbLang building blocks**\n", "\n", - "For easy use we have build the AbLang module (see below), however; for incoorporating AbLang into personal codebases it might be more convenient to use the individual building blocks." + "For easy use we have build the AbLang module (see below), however; for incorporating AbLang into personal codebases it might be more convenient to use the individual building blocks." ] }, { diff --git a/examples/example-ablang-usecases.py b/examples/example-ablang-usecases.py new file mode 100644 index 0000000..405d0e7 --- /dev/null +++ b/examples/example-ablang-usecases.py @@ -0,0 +1,216 @@ +#!/usr/bin/env python +# coding: utf-8 + +# # **AbLang Examples** +# +# AbLang is a RoBERTa inspired language model trained on antibody sequences. The following is a set of possible use cases of AbLang. + +# In[1]: +import sys +from pathlib import Path + +sys.path.append(str(Path(__file__).parent.parent)) + +import ablang + +# Print module path: + +# In[2]: + + +heavy_ablang = ablang.pretrained("heavy") +heavy_ablang.freeze() + + +# -------------- +# ## **AbLang building blocks** +# +# For easy use we have build the AbLang module (see below), however; for incorporating AbLang into personal codebases it might be more convenient to use the individual building blocks. + +# #### AbLang tokenizer + +# In[3]: + + +seqs = [ + "EVQLVESGPGLVQPGKSLRLSCVASGFTFSGYGMHWVRQAPGKGLEWIALIIYDESNKYYADSVKGRFTISRDNSKNTLYLQMSSLRAEDTAVFYCAKVKFYDPTAPNDYWGQGTLVTVSS", + "QVQLVQSGAEVKKPGASVKVSCKASGYTFTSYGISWVRQAPGQGLEWMGWISAYNGNTNYAQKLQGRVTMTTDTSTSTAYMELRSLRSDDTAVYYCARVLGWGSMDVWGQGTTVTVSS", +] + +print("-" * 100) +print("Input sequences:") +for seq in seqs: + print("-", seq) + +tokens = heavy_ablang.tokenizer(seqs, pad=True) + +print("-" * 100) +print("Tokens:") +print(tokens.shape) +print(tokens) + +# #### AbLang encoder (AbRep) + +rescodings = heavy_ablang.AbRep(tokens) + +print("-" * 100) +print("Res-codings:") +print(rescodings) + +# #### AbLang full model (AbRep+AbHead) + +# In[5]: + +likelihoods = heavy_ablang.AbLang(tokens) + +print("-" * 100) +print("Likelihoods:") +print(likelihoods) + + +# ----- +# ## **AbLang module: Res-codings** +# +# The res-codings are the 768 values for each residue, describing both a residue's individual properties (e.g. size, hydrophobicity, etc.) and properties in relation to the rest of the sequence (e.g. secondary structure, position, etc.). +# +# To calculate the res-codings, you can use the mode "rescoding" as seen below. + +# In[6]: + + +seqs = [ + "EVQLVESGPGLVQPGKSLRLSCVASGFTFSGYGMHWVRQAPGKGLEWIALIIYDESNKYYADSVKGRFTISRDNSKNTLYLQMSSLRAEDTAVFYCAKVKFYDPTAPNDYWGQGTLVTVSS", + "QVQLVQSGAEVKKPGASVKVSCKASGYTFTSYGISWVRQAPGQGLEWMGWISAYNGNTNYAQKLQGRVTMTTDTSTSTAYMELRSLRSDDTAVYYCARVLGWGSMDVWGQGTTVTVSS", +] + +rescodings = heavy_ablang(seqs, mode="rescoding") + +print("-" * 100) +print("The output shape of a single sequence:", rescodings[0].shape) +print("This shape is different for each sequence, depending on their length.") +print("-" * 100) +print(rescodings) + + +# ---- +# An additional feature, is the ability to align the rescodings. This can be done by setting the parameter align to "True". +# +# Alignment is done by numbering with anarci and then aligning sequences to all unique numberings found in input antibody sequences. +# +# **NB:** You need to install anarci and pandas for this feature. + +# In[7]: + +try: + import anarci +except ImportError: + print("Please install anarci to use the alignment feature.") + sys.exit() + +seqs = [ + "EVQLVESGPGLVQPGKSLRLSCVASGFTFSGYGMHWVRQAPGKGLEWIALIIYDESNKYYADSVKGRFTISRDNSKNTLYLQMSSLRAEDTAVFYCAKVKFYDPTAPNDYWGQGTLVTVSS", + "QVQLVQSGAEVKKPGASVKVSCKASGYTFTSYGISWVRQAPGQGLEWMGWISAYNGNTNYAQKLQGRVTMTTDTSTSTAYMELRSLRSDDTAVYYCARVLGWGSMDVWGQGTTVTVSS", +] + +rescodings = heavy_ablang(seqs, mode="rescoding", align=True) + +print("-" * 100) +print( + "The output shape for the aligned sequences ('aligned_embeds'):", + rescodings[0].aligned_embeds.shape, +) +print( + "This output also includes this numberings ('number_alignment') used for this set of sequences." +) +print("-" * 100) +print(rescodings[0].aligned_embeds) +print(rescodings[0].number_alignment) + + +# --------- +# ## **AbLang module: Seq-codings** +# +# Seq-codings are a set of 768 values for each sequences, derived from averaging across the res-codings. Seq-codings allow one to avoid sequence alignments, as every antibody sequence, regardless of their length, will be represented with 768 values. + +# In[8]: + + +seqs = [ + "EVQLVESGPGLVQPGKSLRLSCVASGFTFSGYGMHWVRQAPGKGLEWIALIIYDESNKYYADSVKGRFTISRDNSKNTLYLQMSSLRAEDTAVFYCAKVKFYDPTAPNDYWGQGTLVTVSS", + "QVQLVQSGAEVKKPGASVKVSCKASGYTFTSYGISWVRQAPGQGLEWMGWISAYNGNTNYAQKLQGRVTMTTDTSTSTAYMELRSLRSDDTAVYYCARVLGWGSMDVWGQGTTVTVSS", +] + +seqcodings = heavy_ablang(seqs, mode="seqcoding") +print("-" * 100) +print("The output shape of the seq-codings:", seqcodings.shape) +print("-" * 100) + +print(seqcodings) + + +# ----- +# ## **AbLang module: Residue likelihood** +# +# Res- and seq-codings are both derived from the representations created by AbRep. Another interesting representation are the likelihoods created by AbHead. These values are the likelihoods of each amino acids at each position in the sequence. These can be used to explore which amino acids are most likely to be mutated into and thereby explore the mutational space. +# +# **NB:** Currently, the likelihoods includes the start and end tokens and padding. + +# In[9]: + + +seqs = [ + "EVQLVESGPGLVQPGKSLRLSCVASGFTFSGYGMHWVRQAPGKGLEWIALIIYDESNKYYADSVKGRFTISRDNSKNTLYLQMSSLRAEDTAVFYCAKVKFYDPTAPNDYWGQGTLVTVSS", + "QVQLVQSGAEVKKPGASVKVSCKASGYTFTSYGISWVRQAPGQGLEWMGWISAYNGNTNYAQKLQGRVTMTTDTSTSTAYMELRSLRSDDTAVYYCARVLGWGSMDVWGQGTTVTVSS", +] + +likelihoods = heavy_ablang(seqs, mode="likelihood") +print("-" * 100) +print("The output shape with paddings still there:", likelihoods.shape) +print("-" * 100) +print(likelihoods) + + +# ### The corresponding amino acids for each likelihood +# +# For each position the likelihood for each of the 20 amino acids are returned. The amino acid order can be found by looking at the ablang vocabulary. For this output the likelihoods for '<', '-', '>' and '\*' have been removed. + +# In[10]: + + +ablang_vocab = heavy_ablang.tokenizer.vocab_to_aa +ablang_vocab + + +# ----- +# ## **AbLang module: Antibody sequence restoration** +# +# In some cases, an antibody sequence is missing some residues. This could be derived from sequencing errors or limitations of current sequencing methods. To solve this AbLang has the "restore" mode, as seen below, which picks the amino acid with the highest likelihood for residues marked with an asterisk (*). + +# In[11]: + + +seqs = [ + "EV*LVESGPGLVQPGKSLRLSCVASGFTFSGYGMHWVRQAPGKGLEWIALIIYDESNKYYADSVKGRFTISRDNSKNTLYLQMSSLRAEDTAVFYCAKVKFYDPTAPNDYWGQGTLVTVSS", + "*************PGKSLRLSCVASGFTFSGYGMHWVRQAPGKGLEWIALIIYDESNK*YADSVKGRFTISRDNSKNTLYLQMSSLRAEDTAVFYCAKVKFYDPTAPNDYWGQGTL*****", +] + +print("-" * 100) +print("Restoration of masked residues.") +print("-" * 100) +print(heavy_ablang(seqs, mode="restore")) + + +# In cases where sequences are missing unknown lengths at the ends, we can use the "align=True" argument. + +# In[12]: + + +seqs = [ + "EV*LVESGPGLVQPGKSLRLSCVASGFTFSGYGMHWVRQAPGKGLEWIALIIYDESNKYYADSVKGRFTISRDNSKNTLYLQMSSLRAEDTAVFYCAKVKFYDPTAPNDYWGQGTLVTVSS", + "PGKSLRLSCVASGFTFSGYGMHWVRQAPGKGLEWIALIIYDESNK*YADSVKGRFTISRDNSKNTLYLQMSSLRAEDTAVFYCAKVKFYDPTAPNDYWGQGTL", +] + +print("-" * 100) +print("Restoration of masked residues and unknown missing end lengths.") +print("-" * 100) +print(heavy_ablang(seqs, mode="restore", align=True)) diff --git a/tests/conftest.py b/tests/conftest.py new file mode 100644 index 0000000..18cc950 --- /dev/null +++ b/tests/conftest.py @@ -0,0 +1,4 @@ +import sys +from pathlib import Path + +sys.path.append(str(Path(__file__).parent.parent)) \ No newline at end of file diff --git a/tests/test_ablang.py b/tests/test_ablang.py new file mode 100644 index 0000000..2e30393 --- /dev/null +++ b/tests/test_ablang.py @@ -0,0 +1,42 @@ +import ablang + +# import numpy as np +# import numpy.testing as npt + + +def test_ablang(): + model = ablang.pretrained("light") + model.freeze() + + seqs = [ + "EVQLVESGPGLVQPGKSLRLSCVASGFTFSGYGMHWVRQAPGKGLEWIALIIYDESNKYYADSVKGRFTISRDNSKNTLYLQMSSLRAEDTAVFYCAKVKFYDPTAPNDYWGQGTLVTVSS", + "QVQLVQSGAEVKKPGASVKVSCKASGYTFTSYGISWVRQAPGQGLEWMGWISAYNGNTNYAQKLQGRVTMTTDTSTSTAYMELRSLRSDDTAVYYCARVLGWGSMDVWGQGTTVTVSS", + ] + + tokens = model.tokenizer(seqs, pad=True) + + assert tuple(tokens.shape) == (2, 123) + + # fmt: off + expected_tokens = [ + [ + 0, 6, 15, 10, 20, 15, 6, 7, 12, 13, 12, 20, 15, 10, 13, 12, 4, 7, + 20, 2, 20, 7, 11, 15, 14, 7, 12, 17, 8, 17, 7, 12, 18, 12, 1, 3, 19, + 15, 2, 10, 14, 13, 12, 4, 12, 20, 6, 19, 16, 14, 20, 16, 16, 18, 5, + 6, 7, 9, 4, 18, 18, 14, 5, 7, 15, 4, 12, 2, 17, 8, 16, 7, 2, 5, 9, + 7, 4, 9, 8, 20, 18, 20, 10, 1, 7, 7, 20, 2, 14, 6, 5, 8, 14, 15, 17, + 18, 11, 14, 4, 15, 4, 17, 18, 5, 13, 8, 14, 13, 9, 5, 18, 19, 12, + 10, 12, 8, 20, 15, 8, 15, 7, 7, 22, + ], + [ + 0, 10, 15, 10, 20, 15, 10, 7, 12, 14, 6, 15, 4, 4, 13, 12, 14, 7, + 15, 4, 15, 7, 11, 4, 14, 7, 12, 18, 8, 17, 8, 7, 18, 12, 16, 7, 19, + 15, 2, 10, 14, 13, 12, 10, 12, 20, 6, 19, 1, 12, 19, 16, 7, 14, 18, + 9, 12, 9, 8, 9, 18, 14, 10, 4, 20, 10, 12, 2, 15, 8, 1, 8, 8, 5, 8, + 7, 8, 7, 8, 14, 18, 1, 6, 20, 2, 7, 20, 2, 7, 5, 5, 8, 14, 15, 18, + 18, 11, 14, 2, 15, 20, 12, 19, 12, 7, 1, 5, 15, 19, 12, 10, 12, 8, + 8, 15, 8, 15, 7, 7, 22, 21, 21, 21] + ] + # fmt: on + + assert tokens.tolist() == expected_tokens