-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathmain.py
More file actions
64 lines (48 loc) · 3.09 KB
/
main.py
File metadata and controls
64 lines (48 loc) · 3.09 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
import os
import argparse
from scripts.msa import compute_msa
from scripts.extract_co_evovled_pairs import extract_coevolved_pairs
from scripts.esm2 import compute_esm2
from scripts.construct_graphs import construct_co_evolved_graphs
from scripts.metal_binding_predictor import metal_binding_predictor
from scripts.metal_type_predictor import metal_type_predictor
def main(argv=None):
    """Run the metal-binding and metal-type prediction pipeline on one FASTA file.

    All outputs are written under ``run_<stem>/`` in the current working
    directory, where ``<stem>`` is the input file's base name up to its first
    dot. The stages run strictly in order; each one consumes files produced by
    earlier stages.

    Args:
        argv: Optional list of command-line arguments (without the program
            name). When ``None`` (the default), ``sys.argv[1:]`` is parsed,
            so calling ``main()`` behaves as a normal script entry point.

    Raises:
        SystemExit: With status 2 if the arguments are invalid or the input
            FASTA path does not point to an existing file.
    """
    parser = argparse.ArgumentParser(description='Process a FASTA file.')
    parser.add_argument('input_fasta', type=str, help='Path to the input FASTA file containing protein sequences.')
    args = parser.parse_args(argv)
    input_fasta = args.input_fasta

    # Fail fast, before creating any folders or starting a pipeline stage,
    # so a mistyped path gives a clear usage error instead of a late failure.
    if not os.path.isfile(input_fasta):
        parser.error(f'input FASTA file not found: {input_fasta}')

    # Everything before the first dot of the file name, e.g. "a.b.fasta" -> "a".
    # Computed once; every output name below is derived from it.
    stem = os.path.basename(input_fasta).split(".")[0]

    run_folder = f'run_{stem}'
    os.makedirs(run_folder, exist_ok=True)

    def run_path(suffix):
        """Return the path of the output named ``<stem>_<suffix>`` in the run folder."""
        return os.path.join(run_folder, f'{stem}_{suffix}')

    msa_folder = run_path('msa')
    coevolved_pairs_file = run_path('coevolved_pairs.tsv')
    embeddings_folder = run_path('embeddings')
    co_evolved_graphs_file = run_path('co_evolved_graphs.pt')
    co_evolved_metal_binding_graphs_file = run_path('co_evolved_metal_binding_graphs.pt')
    metal_binding_file = run_path('metal_binding_result.tsv')
    metal_type_file = run_path('metal_type_result.tsv')

    # Model weights are located relative to this script, not the working directory.
    root = os.path.dirname(os.path.realpath(__file__))
    metal_binding_predictors_path = os.path.join(root, 'model_weights/metal_binding_predictor')
    metal_type_predictor_path = os.path.join(root, 'model_weights/metal_type_predictor')

    # Step 1: Compute MSA using compute_msa from scripts.msa
    print("Performing MSA using Colabfold server...")
    compute_msa(input_fasta, msa_folder)

    # Step 2: Extract coevolved pairs
    print("Extracing co-evolved pairs...")
    extract_coevolved_pairs(msa_folder, coevolved_pairs_file, num_seq=64, coevo_threshold=0.1, cuda=0)

    # Step 3: Compute embedding
    print("Deriving ESM2 embedding...")
    compute_esm2(input_fasta, embeddings_folder)

    # Step 4: Construct Co-evolved graphs
    print("Constructing graphs...")
    construct_co_evolved_graphs(coevolved_pairs_file, embeddings_folder, co_evolved_graphs_file)

    # Step 5: Metal binding prediction
    print("Predicting metal binding residues...")
    metal_binding_predictor(co_evolved_graphs_file, metal_binding_predictors_path, metal_binding_file)

    # Step 6: Construct Co-evolved Metal-binding graphs
    # Same graph builder as Step 4, but additionally given the Step 5 result file.
    print("Constructing co-evolved metal-binding graphs...")
    construct_co_evolved_graphs(coevolved_pairs_file, embeddings_folder, co_evolved_metal_binding_graphs_file, metal_binding_file)

    # Step 7: Metal type prediction
    print("Predicting metal types...")
    metal_type_predictor(co_evolved_metal_binding_graphs_file, metal_type_predictor_path, metal_type_file)
# Run the pipeline only when this file is executed as a script, not when imported.
if __name__ == "__main__":
    main()