Commit e5572dd (parent fb83378)

Add feature insertion script

9 files changed: +429 −2304 lines

environment.yml

+42 −7
@@ -150,8 +150,6 @@ dependencies:
   - llvm-openmp=16.0.4=h4dfa4b3_0
   - lz4-c=1.9.4=h6a678d5_0
   - magma=2.6.2=hc72dce7_0
-  - matplotlib=3.3.2=0
-  - matplotlib-base=3.3.2=py38h5c7f4ab_1
   - matplotlib-inline=0.1.6=pyhd8ed1ab_0
   - mkl=2022.2.1=h84fe81f_16997
   - mpi=1.0=openmpi
@@ -164,8 +162,6 @@ dependencies:
   - ninja=1.11.1=h924138e_0
   - nsight-compute=2023.1.1.4=0
   - numexpr=2.8.4=py38hd2a5715_1
-  - numpy=1.24.3=py38hf838250_0
-  - numpy-base=1.24.3=py38h1e6e340_0
   - openh264=2.1.1=h780b84a_0
   - openjpeg=2.5.0=hfec8fc6_2
   - openmpi=4.1.5=h414af15_101
@@ -187,7 +183,6 @@ dependencies:
   - pycparser=2.21=pyhd3eb1b0_0
   - pygments=2.15.1=pyhd8ed1ab_0
   - pyopenssl=23.1.1=pyhd8ed1ab_0
-  - pyparsing=3.1.0=pyhd8ed1ab_0
   - pysocks=1.7.1=py38h06a4308_0
   - python=3.8.16=he550d4f_1_cpython
   - python-dateutil=2.8.2=pyhd3eb1b0_0
@@ -235,63 +230,103 @@ dependencies:
     - git+https://github.com/amorehead/atom3.git@83987404ceed38a1f5a5abd517aa38128d0a4f2c
     - attrs==23.1.0
     - babel==2.12.1
+    - beautifulsoup4==4.12.2
     - biopandas==0.5.0.dev0
+    - bioservices==1.11.2
     - cachetools==5.3.1
+    - cattrs==23.1.2
     - click==7.0
+    - colorlog==6.7.0
     - configparser==5.3.0
+    - contourpy==1.1.0
+    - deepdiff==6.3.1
     - dill==0.3.3
     - docker-pycreds==0.4.0
     - docutils==0.17.1
     - easy-parallel-py3==0.1.6.4
+    - easydev==0.12.1
+    - exceptiongroup==1.1.2
     - fairscale==0.4.0
+    - fonttools==4.40.0
     - frozenlist==1.3.3
     - fsspec==2023.5.0
     - future==0.18.3
+    - gevent==22.10.2
     - gitdb==4.0.10
     - gitpython==3.1.31
     - google-auth==2.19.0
     - google-auth-oauthlib==1.0.0
+    - git+https://github.com/a-r-j/graphein.git@371ce9a462b610529488e87a712484328a89de36
+    - greenlet==2.0.2
+    - grequests==0.7.0
     - grpcio==1.54.2
     - h5py==3.8.0
     - hickle==5.0.2
     - imagesize==1.4.1
+    - importlib-resources==6.0.0
     - install==1.3.5
+    - jaxtyping==0.2.19
+    - jinja2==2.11.3
     - loguru==0.7.0
     - looseversion==1.1.2
+    - lxml==4.9.3
     - markdown==3.4.3
-    - markupsafe==2.1.3
+    - markdown-it-py==3.0.0
+    - markupsafe==1.1.1
+    - matplotlib==3.7.2
+    - mdurl==0.1.2
     - mmtf-python==1.1.3
     - mpi4py==3.0.3
     - msgpack==1.0.5
     - multidict==6.0.4
+    - multipledispatch==1.0.0
     - multiprocess==0.70.11.1
+    - numpy==1.23.5
     - oauthlib==3.2.2
+    - ordered-set==4.1.0
     - pathos==0.2.7
     - pathtools==0.1.2
     - pdb-tools==2.5.0
+    - platformdirs==3.8.1
+    - plotly==5.15.0
     - pox==0.3.2
     - ppft==1.7.6.6
     - promise==2.3
     - protobuf==3.20.3
     - pyasn1==0.5.0
     - pyasn1-modules==0.3.0
+    - pydantic==1.10.11
     - pydeprecate==0.3.1
+    - pyparsing==3.0.9
     - pytorch-lightning==1.4.8
-    - pyyaml==6.0
+    - pyyaml==5.4.1
+    - requests-cache==1.1.0
     - requests-oauthlib==1.3.1
+    - rich==13.4.2
+    - rich-click==1.6.1
     - rsa==4.9
     - seaborn==0.12.2
     - sentry-sdk==1.24.0
     - shortuuid==1.0.11
     - smmap==5.0.0
     - snowballstemmer==2.2.0
+    - soupsieve==2.4.1
     - subprocess32==3.5.4
+    - suds-community==1.1.2
+    - tenacity==8.2.2
     - tensorboard==2.13.0
     - tensorboard-data-server==0.7.0
     - termcolor==2.3.0
     - torchmetrics==0.5.1
+    - typeguard==4.0.0
+    - url-normalize==1.4.3
     - wandb==0.12.2
     - werkzeug==2.3.6
     - wget==3.2
+    - wrapt==1.15.0
+    - xarray==2023.1.0
+    - xmltodict==0.13.0
     - yarl==1.9.2
     - yaspin==2.3.0
+    - zope-event==5.0
+    - zope-interface==6.0
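
Net effect of the changes above: numpy drops from the conda pin 1.24.3 to the pip pin 1.23.5, matplotlib from conda 3.3.2 to pip 3.7.2, pyparsing from conda 3.1.0 to pip 3.0.9, and pyyaml from 6.0 to 5.4.1, alongside a number of newly added pip packages (graphein, bioservices, rich, and others). A minimal sanity check for a recreated environment follows; the expected versions are copied from the diff above, and the script is a sketch rather than part of the repository:

# Hedged sketch (not part of this commit): verify a few of the pins changed above.
# Expected versions are taken from the environment.yml diff in this commit.
import importlib.metadata as importlib_metadata

EXPECTED = {"numpy": "1.23.5", "matplotlib": "3.7.2", "pyparsing": "3.0.9", "PyYAML": "5.4.1"}

for distribution, expected_version in EXPECTED.items():
    installed_version = importlib_metadata.version(distribution)
    status = "OK" if installed_version == expected_version else f"MISMATCH (found {installed_version})"
    print(f"{distribution}: expected {expected_version} -> {status}")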

project/datasets/analysis/analyze_experiment_types_and_resolution.py

+3 −3
@@ -8,10 +8,10 @@
 import numpy as np
 import pandas as pd
 
+from graphein.ml.datasets import PDBManager
 from pathlib import Path
 from tqdm import tqdm
 
-from project.datasets.analysis.pdb_data import PDBManager
 from project.utils.utils import download_pdb_file, gunzip_file
 
 
@@ -79,7 +79,7 @@ def main(output_dir: str, source_type: str):
     # Collect (and, if necessary, extract) all training PDB files
     train_pdb_codes = []
     pairs_postprocessed_train_txt = os.path.join(output_dir, 'pairs-postprocessed-train-before-structure-based-filtering.txt')
-    assert os.path.exists(pairs_postprocessed_train_txt), "DB5-Plus train filenames must be curated in advance to partition training and validation filenames."
+    assert os.path.exists(pairs_postprocessed_train_txt), "DIPS-Plus train filenames must be curated in advance."
     with open(pairs_postprocessed_train_txt, "r") as f:
         train_filenames = [line.strip() for line in f.readlines()]
     for train_filename in tqdm(train_filenames):
@@ -117,7 +117,7 @@ def main(output_dir: str, source_type: str):
     # Collect (and, if necessary, extract) all validation PDB files
     val_pdb_codes = []
     pairs_postprocessed_val_txt = os.path.join(output_dir, 'pairs-postprocessed-val-before-structure-based-filtering.txt')
-    assert os.path.exists(pairs_postprocessed_val_txt), "DB5-Plus validation filenames must be curated in advance to partition training and validation filenames."
+    assert os.path.exists(pairs_postprocessed_val_txt), "DIPS-Plus validation filenames must be curated in advance."
     with open(pairs_postprocessed_val_txt, "r") as f:
         val_filenames = [line.strip() for line in f.readlines()]
     for val_filename in tqdm(val_filenames):
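
The import swap above replaces the repository's vendored PDBManager with the one shipped by the graphein commit now pinned in environment.yml. For orientation only, a minimal usage sketch follows; it assumes (this is an assumption, not taken from this commit) that PDBManager accepts a root_dir keyword for its metadata cache and exposes the parsed RCSB metadata as a pandas DataFrame via a df attribute, and those names may differ between graphein versions:

# Hedged sketch, not code from this commit: minimal use of graphein's PDBManager
# after the import swap above.  The `root_dir` keyword and `df` attribute are
# assumptions about the pinned graphein version and may need adjusting.
from graphein.ml.datasets import PDBManager

pdb_manager = PDBManager(root_dir=".")  # caches RCSB metadata locally (assumed keyword)
print(pdb_manager.df.head())            # per-chain metadata table (assumed attribute)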

project/datasets/analysis/analyze_feature_correlation.py

+2 −2
@@ -30,7 +30,7 @@ def main(output_dir: str, source_type: str, feature_types_to_correlate: str):
     # Collect (and, if necessary, extract) all training PDB files
     train_feature_values = []
     pairs_postprocessed_train_txt = os.path.join(output_dir, 'pairs-postprocessed-train-before-structure-based-filtering.txt')
-    assert os.path.exists(pairs_postprocessed_train_txt), "DB5-Plus train filenames must be curated in advance to partition training and validation filenames."
+    assert os.path.exists(pairs_postprocessed_train_txt), "DIPS-Plus train filenames must be curated in advance."
     with open(pairs_postprocessed_train_txt, "r") as f:
         train_filenames = [line.strip() for line in f.readlines()]
     for train_filename in tqdm(train_filenames):
@@ -68,7 +68,7 @@ def main(output_dir: str, source_type: str, feature_types_to_correlate: str):
     # Collect (and, if necessary, extract) all validation PDB files
     val_feature_values = []
     pairs_postprocessed_val_txt = os.path.join(output_dir, 'pairs-postprocessed-val-before-structure-based-filtering.txt')
-    assert os.path.exists(pairs_postprocessed_val_txt), "DB5-Plus validation filenames must be curated in advance to partition training and validation filenames."
+    assert os.path.exists(pairs_postprocessed_val_txt), "DIPS-Plus validation filenames must be curated in advance."
     with open(pairs_postprocessed_val_txt, "r") as f:
         val_filenames = [line.strip() for line in f.readlines()]
     for val_filename in tqdm(val_filenames):

New file (+212 lines)

@@ -0,0 +1,212 @@
+import click
+import logging
+import os
+import warnings
+
+import atom3.pair as pa
+import numpy as np
+import pandas as pd
+
+from Bio import BiopythonWarning
+from Bio.PDB import NeighborSearch
+from Bio.PDB import PDBParser
+from pathlib import Path
+from tqdm import tqdm
+
+from project.utils.utils import download_pdb_file, gunzip_file
+
+
+@click.command()
+@click.argument('output_dir', default='../DIPS/final/raw', type=click.Path())
+@click.option('--source_type', default='rcsb', type=click.Choice(['rcsb', 'db5']))
+@click.option('--interfacing_water_distance_cutoff', default=10.0, type=float)
+def main(output_dir: str, source_type: str, interfacing_water_distance_cutoff: float):
+    logger = logging.getLogger(__name__)
+    logger.info("Analyzing interface waters within each dataset example...")
+
+    if source_type.lower() == "rcsb":
+        parser = PDBParser()
+
+        # Filter and suppress BioPython warnings
+        warnings.filterwarnings("ignore", category=BiopythonWarning)
+
+        # Collect (and, if necessary, extract) all training PDB files
+        train_num_complexes = 0
+        train_complex_num_waters = 0
+        pairs_postprocessed_train_txt = os.path.join(output_dir, 'pairs-postprocessed-train-before-structure-based-filtering.txt')
+        assert os.path.exists(pairs_postprocessed_train_txt), "DIPS-Plus train filenames must be curated in advance."
+        with open(pairs_postprocessed_train_txt, "r") as f:
+            train_filenames = [line.strip() for line in f.readlines()]
+        for train_filename in tqdm(train_filenames):
+            try:
+                postprocessed_train_pair: pa.Pair = pd.read_pickle(os.path.join(output_dir, train_filename))
+            except Exception as e:
+                logging.error(f"Could not open postprocessed training pair {os.path.join(output_dir, train_filename)} due to: {e}")
+                continue
+            pdb_code = postprocessed_train_pair.df0.pdb_name[0].split("_")[0][1:3]
+            pdb_dir = os.path.join(Path(output_dir).parent.parent, "raw", "pdb", pdb_code)
+            l_b_pdb_filepath = os.path.join(pdb_dir, postprocessed_train_pair.df0.pdb_name[0])
+            r_b_pdb_filepath = os.path.join(pdb_dir, postprocessed_train_pair.df1.pdb_name[0])
+            l_b_df0_chains = postprocessed_train_pair.df0.chain.unique()
+            r_b_df1_chains = postprocessed_train_pair.df1.chain.unique()
+            assert (
+                len(postprocessed_train_pair.df0.pdb_name.unique()) == len(l_b_df0_chains) == 1
+            ), "Only a single PDB filename and chain identifier can be associated with a single training example."
+            assert (
+                len(postprocessed_train_pair.df1.pdb_name.unique()) == len(r_b_df1_chains) == 1
+            ), "Only a single PDB filename and chain identifier can be associated with a single training example."
+            if not os.path.exists(l_b_pdb_filepath) and os.path.exists(l_b_pdb_filepath + ".gz"):
+                gunzip_file(l_b_pdb_filepath)
+            if not os.path.exists(r_b_pdb_filepath) and os.path.exists(r_b_pdb_filepath + ".gz"):
+                gunzip_file(r_b_pdb_filepath)
+            if not os.path.exists(l_b_pdb_filepath):
+                download_pdb_file(os.path.basename(l_b_pdb_filepath), l_b_pdb_filepath)
+            if not os.path.exists(r_b_pdb_filepath):
+                download_pdb_file(os.path.basename(r_b_pdb_filepath), r_b_pdb_filepath)
+            assert os.path.exists(l_b_pdb_filepath) and os.path.exists(r_b_pdb_filepath), "Both left and right-bound PDB files collected must exist."
+
+            l_b_structure = parser.get_structure('protein', l_b_pdb_filepath)
+            r_b_structure = parser.get_structure('protein', r_b_pdb_filepath)
+
+            l_b_interface_residues = postprocessed_train_pair.df0[postprocessed_train_pair.df0.index.isin(postprocessed_train_pair.pos_idx[:, 0])]
+            r_b_interface_residues = postprocessed_train_pair.df1[postprocessed_train_pair.df1.index.isin(postprocessed_train_pair.pos_idx[:, 1])]
+
+            train_num_complexes += 1
+
+            l_b_ns = NeighborSearch(list(l_b_structure.get_atoms()))
+            for index, row in l_b_interface_residues.iterrows():
+                chain_id = row['chain']
+                residue = row['residue'].strip()
+                model = l_b_structure[0]
+                chain = model[chain_id]
+                if residue.lstrip("-").isdigit():
+                    residue = int(residue)
+                else:
+                    residue_index, residue_icode = residue[:-1], residue[-1:]
+                    if residue_icode.strip() == "":
+                        residue = int(residue)
+                    else:
+                        residue = (" ", int(residue_index), residue_icode)
+                target_residue = chain[residue]
+                target_coords = np.array([atom.get_coord() for atom in target_residue.get_atoms() if atom.get_name() == 'CA']).squeeze()
+                interfacing_atoms = l_b_ns.search(target_coords, interfacing_water_distance_cutoff, 'A')
+                waters_within_threshold = [atom for atom in interfacing_atoms if atom.get_parent().get_resname() in ['HOH', 'WAT']]
+                train_complex_num_waters += len(waters_within_threshold)
+
+            r_b_ns = NeighborSearch(list(r_b_structure.get_atoms()))
+            for index, row in r_b_interface_residues.iterrows():
+                chain_id = row['chain']
+                residue = row['residue'].strip()
+                model = r_b_structure[0]
+                chain = model[chain_id]
+                if residue.lstrip("-").isdigit():
+                    residue = int(residue)
+                else:
+                    residue_index, residue_icode = residue[:-1], residue[-1:]
+                    residue = (" ", int(residue_index), residue_icode)
+                target_residue = chain[residue]
+                target_coords = np.array([atom.get_coord() for atom in target_residue.get_atoms() if atom.get_name() == 'CA']).squeeze()
+                interfacing_atoms = r_b_ns.search(target_coords, interfacing_water_distance_cutoff, 'A')
+                waters_within_threshold = [atom for atom in interfacing_atoms if atom.get_parent().get_resname() in ['HOH', 'WAT']]
+                train_complex_num_waters += len(waters_within_threshold)
+
+        # Collect (and, if necessary, extract) all validation PDB files
+        val_num_complexes = 0
+        val_complex_num_waters = 0
+        pairs_postprocessed_val_txt = os.path.join(output_dir, 'pairs-postprocessed-val-before-structure-based-filtering.txt')
+        assert os.path.exists(pairs_postprocessed_val_txt), "DIPS-Plus validation filenames must be curated in advance."
+        with open(pairs_postprocessed_val_txt, "r") as f:
+            val_filenames = [line.strip() for line in f.readlines()]
+        for val_filename in tqdm(val_filenames):
+            try:
+                postprocessed_val_pair: pa.Pair = pd.read_pickle(os.path.join(output_dir, val_filename))
+            except Exception as e:
+                logging.error(f"Could not open postprocessed validation pair {os.path.join(output_dir, val_filename)} due to: {e}")
+                continue
+            pdb_code = postprocessed_val_pair.df0.pdb_name[0].split("_")[0][1:3]
+            pdb_dir = os.path.join(Path(output_dir).parent.parent, "raw", "pdb", pdb_code)
+            l_b_pdb_filepath = os.path.join(pdb_dir, postprocessed_val_pair.df0.pdb_name[0])
+            r_b_pdb_filepath = os.path.join(pdb_dir, postprocessed_val_pair.df1.pdb_name[0])
+            l_b_df0_chains = postprocessed_val_pair.df0.chain.unique()
+            r_b_df1_chains = postprocessed_val_pair.df1.chain.unique()
+            assert (
+                len(postprocessed_val_pair.df0.pdb_name.unique()) == len(l_b_df0_chains) == 1
+            ), "Only a single PDB filename and chain identifier can be associated with a single validation example."
+            assert (
+                len(postprocessed_val_pair.df1.pdb_name.unique()) == len(r_b_df1_chains) == 1
+            ), "Only a single PDB filename and chain identifier can be associated with a single validation example."
+            if not os.path.exists(l_b_pdb_filepath) and os.path.exists(l_b_pdb_filepath + ".gz"):
+                gunzip_file(l_b_pdb_filepath)
+            if not os.path.exists(r_b_pdb_filepath) and os.path.exists(r_b_pdb_filepath + ".gz"):
+                gunzip_file(r_b_pdb_filepath)
+            if not os.path.exists(l_b_pdb_filepath):
+                download_pdb_file(os.path.basename(l_b_pdb_filepath), l_b_pdb_filepath)
+            if not os.path.exists(r_b_pdb_filepath):
+                download_pdb_file(os.path.basename(r_b_pdb_filepath), r_b_pdb_filepath)
+            assert os.path.exists(l_b_pdb_filepath) and os.path.exists(r_b_pdb_filepath), "Both left and right-bound PDB files collected must exist."
+
+            l_b_structure = parser.get_structure('protein', l_b_pdb_filepath)
+            r_b_structure = parser.get_structure('protein', r_b_pdb_filepath)
+
+            l_b_interface_residues = postprocessed_val_pair.df0[postprocessed_val_pair.df0.index.isin(postprocessed_val_pair.pos_idx[:, 0])]
+            r_b_interface_residues = postprocessed_val_pair.df1[postprocessed_val_pair.df1.index.isin(postprocessed_val_pair.pos_idx[:, 1])]
+
+            val_num_complexes += 1
+
+            l_b_ns = NeighborSearch(list(l_b_structure.get_atoms()))
+            for index, row in l_b_interface_residues.iterrows():
+                chain_id = row['chain']
+                residue = row['residue'].strip()
+                model = l_b_structure[0]
+                chain = model[chain_id]
+                if residue.lstrip("-").isdigit():
+                    residue = int(residue)
+                else:
+                    residue_index, residue_icode = residue[:-1], residue[-1:]
+                    residue = (" ", int(residue_index), residue_icode)
+                target_residue = chain[residue]
+                target_coords = np.array([atom.get_coord() for atom in target_residue.get_atoms() if atom.get_name() == 'CA']).squeeze()
+                interfacing_atoms = l_b_ns.search(target_coords, interfacing_water_distance_cutoff, 'A')
+                waters_within_threshold = [atom for atom in interfacing_atoms if atom.get_parent().get_resname() in ['HOH', 'WAT']]
+                val_complex_num_waters += len(waters_within_threshold)
+
+            r_b_ns = NeighborSearch(list(r_b_structure.get_atoms()))
+            for index, row in r_b_interface_residues.iterrows():
+                chain_id = row['chain']
+                residue = row['residue'].strip()
+                model = r_b_structure[0]
+                chain = model[chain_id]
+                if residue.lstrip("-").isdigit():
+                    residue = int(residue)
+                else:
+                    residue_index, residue_icode = residue[:-1], residue[-1:]
+                    residue = (" ", int(residue_index), residue_icode)
+                target_residue = chain[residue]
+                target_coords = np.array([atom.get_coord() for atom in target_residue.get_atoms() if atom.get_name() == 'CA']).squeeze()
+                interfacing_atoms = r_b_ns.search(target_coords, interfacing_water_distance_cutoff, 'A')
+                waters_within_threshold = [atom for atom in interfacing_atoms if atom.get_parent().get_resname() in ['HOH', 'WAT']]
+                val_complex_num_waters += len(waters_within_threshold)
+
+        # Train complexes
+        train_num_waters_per_complex = train_complex_num_waters / train_num_complexes
+        logging.info(f"Number of waters, on average, in each training complex: {train_num_waters_per_complex}")
+
+        # Validation complexes
+        val_num_waters_per_complex = val_complex_num_waters / val_num_complexes
+        logging.info(f"Number of waters, on average, in each validation complex: {val_num_waters_per_complex}")
+
+        # Train + Validation complexes
+        train_val_num_waters_per_complex = (train_complex_num_waters + val_complex_num_waters) / (train_num_complexes + val_num_complexes)
+        logging.info(f"Number of waters, on average, in each training (or validation) complex: {train_val_num_waters_per_complex}")
+
+        logger.info("Finished analyzing interface waters for all training and validation complexes")
+
+    else:
+        raise NotImplementedError(f"Source type {source_type} is currently not supported.")
+
+
+if __name__ == "__main__":
+    log_fmt = '%(asctime)s %(levelname)s %(process)d: %(message)s'
+    logging.basicConfig(level=logging.INFO, format=log_fmt)
+
+    main()
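
The new script above repeats the residue-lookup and water-counting logic four times (training and validation, left- and right-bound structures), and only the first copy includes the fallback for an empty insertion code when parsing the residue identifier. A possible consolidation is sketched below; it is not part of the commit, and parse_residue_id and count_interfacing_waters are hypothetical helper names introduced here for illustration:

# Hedged refactoring sketch (not part of this commit): factor the repeated
# residue-ID parsing and water counting into helpers shared by the train and
# validation loops.
import numpy as np
from Bio.PDB import NeighborSearch


def parse_residue_id(residue_field: str):
    """Map a dataframe residue string (e.g. '42', '-3', '42A') to a Bio.PDB residue ID."""
    residue_field = residue_field.strip()
    if residue_field.lstrip("-").isdigit():
        return int(residue_field)
    residue_index, residue_icode = residue_field[:-1], residue_field[-1:]
    if residue_icode.strip() == "":
        return int(residue_field)
    # Bio.PDB residue IDs are (hetfield, sequence number, insertion code) tuples
    return (" ", int(residue_index), residue_icode)


def count_interfacing_waters(structure, interface_residues, distance_cutoff: float) -> int:
    """Count water atoms within `distance_cutoff` angstroms of each interface residue's CA atom."""
    neighbor_search = NeighborSearch(list(structure.get_atoms()))
    model = structure[0]
    num_waters = 0
    for _, row in interface_residues.iterrows():
        target_residue = model[row["chain"]][parse_residue_id(row["residue"])]
        ca_coords = np.array([atom.get_coord() for atom in target_residue.get_atoms()
                              if atom.get_name() == "CA"]).squeeze()
        interfacing_atoms = neighbor_search.search(ca_coords, distance_cutoff, "A")
        num_waters += sum(1 for atom in interfacing_atoms
                          if atom.get_parent().get_resname() in ("HOH", "WAT"))
    return num_waters

With helpers like these, each of the four blocks reduces to one count_interfacing_waters call per bound structure, and insertion-code handling is applied uniformly.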
