Skip to content

Commit fb83378

Browse files
committed
Catch corner cases in feature analysis script
1 parent 71d22c8 commit fb83378

File tree

1 file changed

+7 -9 lines changed

1 file changed

+7 -9 lines changed

project/datasets/analysis/analyze_feature_correlation.py

+7 -9
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44

55
import atom3.pair as pa
66
import matplotlib.pyplot as plt
7+
import numpy as np
78
import pandas as pd
89
import seaborn as sns
910

@@ -17,15 +18,12 @@
1718
@click.command()
1819
@click.argument('output_dir', default='../DIPS/final/raw', type=click.Path())
1920
@click.option('--source_type', default='rcsb', type=click.Choice(['rcsb', 'db5']))
20-
@click.option('--feature_types_to_correlate', default='rcsb', type=click.Choice(['rsa_value-rd_value']))
21+
@click.option('--feature_types_to_correlate', default='rcsb', type=click.Choice(['rsa_value-rd_value', 'rsa_value-cn_value', 'rd_value-cn_value']))
2122
def main(output_dir: str, source_type: str, feature_types_to_correlate: str):
2223
logger = logging.getLogger(__name__)
2324
logger.info("Analyzing feature correlation for each dataset example...")
2425

25-
if feature_types_to_correlate == "rsa_value-rd_value":
26-
features_to_correlate = feature_types_to_correlate.split("-")
27-
else:
28-
raise NotImplementedError(f"Feature types {features_to_correlate} are currently not supported.")
26+
features_to_correlate = feature_types_to_correlate.split("-")
2927
assert len(features_to_correlate) == 2, "Exactly two features may be currently compared for correlation measures."
3028

3129
if source_type.lower() == "rcsb":
@@ -63,8 +61,8 @@ def main(output_dir: str, source_type: str, feature_types_to_correlate: str):
6361
download_pdb_file(os.path.basename(r_b_pdb_filepath), r_b_pdb_filepath)
6462
assert os.path.exists(l_b_pdb_filepath) and os.path.exists(r_b_pdb_filepath), "Both left and right-bound PDB files collected must exist."
6563

66-
l_b_df0_feature_values = postprocessed_train_pair.df0[features_to_correlate].dropna()
67-
r_b_df1_feature_values = postprocessed_train_pair.df1[features_to_correlate].dropna()
64+
l_b_df0_feature_values = postprocessed_train_pair.df0[features_to_correlate].applymap(lambda x: np.nan if x == 'NA' else x).dropna().apply(pd.to_numeric)
65+
r_b_df1_feature_values = postprocessed_train_pair.df1[features_to_correlate].applymap(lambda x: np.nan if x == 'NA' else x).dropna().apply(pd.to_numeric)
6866
train_feature_values.append(pd.concat([l_b_df0_feature_values, r_b_df1_feature_values]))
6967

7068
# Collect (and, if necessary, extract) all validation PDB files
@@ -101,8 +99,8 @@ def main(output_dir: str, source_type: str, feature_types_to_correlate: str):
10199
download_pdb_file(os.path.basename(r_b_pdb_filepath), r_b_pdb_filepath)
102100
assert os.path.exists(l_b_pdb_filepath) and os.path.exists(r_b_pdb_filepath), "Both left and right-bound PDB files collected must exist."
103101

104-
l_b_df0_feature_values = postprocessed_val_pair.df0[features_to_correlate].dropna()
105-
r_b_df1_feature_values = postprocessed_val_pair.df1[features_to_correlate].dropna()
102+
l_b_df0_feature_values = postprocessed_val_pair.df0[features_to_correlate].applymap(lambda x: np.nan if x == 'NA' else x).dropna().apply(pd.to_numeric)
103+
r_b_df1_feature_values = postprocessed_val_pair.df1[features_to_correlate].applymap(lambda x: np.nan if x == 'NA' else x).dropna().apply(pd.to_numeric)
106104
val_feature_values.append(pd.concat([l_b_df0_feature_values, r_b_df1_feature_values]))
107105

108106
# Train PDBs

0 commit comments

Comments
 (0)