Skip to content

Commit

Permalink
Update scripts (#15)
Browse files Browse the repository at this point in the history
By David Hahn and Victoria Lim

* Update plots.

* Merge master.

* Update analysis scripts.

* Update analysis scripts.

* Update color_by_moiety.py.

* Update compare_ffs.py.

* Format compare_ffs.py with Black.

* Uncomment line plots of relative energies and bar plots of RMSEs per force field.

* Remove customization of labels in scripts color_by_moiety.py, compare_ffs.py and match_minima.py.

* Fix lgtm alerts.

* A couple of final LGTM alerts
  • Loading branch information
dfhahn authored Sep 27, 2020
1 parent f5ddb3d commit 339b806
Show file tree
Hide file tree
Showing 7 changed files with 623 additions and 253 deletions.
8 changes: 4 additions & 4 deletions 02_calc/minimize_ffs.py
Original file line number Diff line number Diff line change
Expand Up @@ -220,11 +220,11 @@ def min_gaffx(mol, ofs, gaff2=False):
openmoltools.amber.run_antechamber(title, tmol2, charge_method=None,
gaff_mol2_filename = gmol2, frcmod_filename = frc,
gaff_version = invar)
except Exception:
except Exception as e:
# earlier smilabel seems to be missing
smilabel = oechem.OEGetSDData(mol, "SMILES QCArchive")
print( ' >>> Antechamber failed to produce GAFF mol2 file: '
f'{title} {smilabel}')
f'{title} {smilabel}: {e}')
return

# generate gaff inpcrd and prmtop files
Expand Down Expand Up @@ -282,10 +282,10 @@ def min_ffxml(mol, ofs, ffxml):
#system = ff.create_openmm_system(topology, charge_from_molecules=[off_mol])
system = ff.create_openmm_system(topology)

except Exception:
except Exception as e:
smilabel = oechem.OEGetSDData(oe_mol, "SMILES QCArchive")
print( ' >>> openforcefield failed to create OpenMM system: '
f'{oe_mol.GetTitle()} {smilabel}')
f'{oe_mol.GetTitle()} {smilabel}: {e}')
return

positions = structure.extractPositionsFromOEMol(oe_mol)
Expand Down
156 changes: 107 additions & 49 deletions 03_analysis/color_by_moiety.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,14 +19,29 @@
import re
import numpy as np
import pickle
import matplotlib as mpl
import matplotlib.pyplot as plt
import openeye.oechem as oechem
import reader

def draw_scatter_moiety(x_data, y_data, all_x_subset, all_y_subset,
labels_subset, x_label, y_label, out_file, what_for='talk',
x_range=None, y_range=None):
import seaborn as sns
import pandas as pd

sns.set(rc={"text.usetex": True})


def draw_scatter_moiety(
x_data,
y_data,
all_x_subset,
all_y_subset,
labels_subset,
method_label,
x_label,
y_label,
out_file,
what_for="talk",
x_range=None,
y_range=None,
):

"""
Draw a scatter plot and color only a subset of points.
Expand Down Expand Up @@ -60,51 +75,58 @@ def draw_scatter_moiety(x_data, y_data, all_x_subset, all_y_subset,
"""
print(f"Number of data points in full scatter plot: {len(x_data)}")

num_methods = len(x_data)

# set plot limits if specified
if x_range is not None:
plt.xlim(x_range[0], x_range[1])
if y_range is not None:
plt.ylim(y_range[0], y_range[1])

# set log scaling but use symmetric log for negative values
#plt.yscale('symlog')
# plt.yscale('symlog')

if what_for == 'paper':
if what_for == "paper":
fig = plt.gcf()
fig.set_size_inches(4, 3)
plt.xlabel(x_label, fontsize=10)
plt.ylabel(y_label, fontsize=10)
plt.xticks(fontsize=10)
plt.yticks(fontsize=10)
plt_options = {'s':10, 'alpha':1.0}
plt.rc('legend', fontsize=10)
plt_options = {"s": 10, "alpha": 1.0}
plt.rc("legend", fontsize=10)

elif what_for == 'talk':
elif what_for == "talk":
fig = plt.gcf()
fig.set_size_inches(8, 6)
plt.xlabel(x_label, fontsize=14)
plt.ylabel(y_label, fontsize=14)
plt.xticks(fontsize=12)
plt.yticks(fontsize=12)
plt_options = {'s':10, 'alpha':1.0}
plt.rc('legend', fontsize=14)
plt_options = {"s": 10, "alpha": 1.0}
plt.rc("legend", fontsize=14)

# generate the plot with full set
plt.scatter(x_data, y_data,
c='white', edgecolors='grey', **plt_options)
plt.scatter(x_data, y_data, c="white", edgecolors="grey", **plt_options)

# generate the plot with subset(s)
subset_colors = ['#01959f', '#614051', '#e96058'] # todo generalize
for i, (xs, ys, lab) in enumerate(zip(all_x_subset, all_y_subset, labels_subset)):
print(f"Number of data points in subset {i}: {len(xs)}")
plt.scatter(xs, ys, label=lab, c=subset_colors[i], zorder=2, **plt_options)

plt.legend(loc=(0.2, 1.04))
plt.savefig(out_file, bbox_inches='tight')
subset_colors = ["#01959f", "#614051", "#e96058"] # todo generalize
with open("statistics.dat", "a") as file:
file.write(
f"{method_label:19s},{'all':9s},{len(x_data):9d},{np.average(x_data[~np.isnan(x_data)]):9.3f},{np.std(x_data[~np.isnan(x_data)]):9.3f},{np.average(y_data[~np.isnan(y_data)]):9.3f},{np.std(y_data[~np.isnan(y_data)]):9.3f}\n"
)
for i, (xs, ys, lab) in enumerate(
zip(all_x_subset, all_y_subset, labels_subset)
):
print(f"Number of data points in subset {i}: {len(xs)}")
file.write(
f"{method_label:19s},{lab:9s},{len(xs):9d},{np.average(xs):9.3f},{np.std(xs):9.3f},{np.average(ys):9.3f},{np.std(ys):9.3f}\n"
)
plt.scatter(xs, ys, label=lab, c=subset_colors[i], zorder=2, **plt_options)

plt.title(method_label)
plt.legend(loc=(0.35, 0.02))
plt.savefig(out_file, bbox_inches="tight")
plt.clf()
#plt.show()
# plt.show()


def main(in_dict, pickle_file, smi_files, out_prefix):
Expand All @@ -131,8 +153,7 @@ def main(in_dict, pickle_file, smi_files, out_prefix):
num_methods = len(method_labels)

# enes_full[i][j][k] = ddE of ith method, jth mol, kth conformer.
enes_full, rmsds_full, tfds_full, smiles_full = pickle.load(
open(pickle_file, 'rb'))
enes_full, rmsds_full, tfds_full, smiles_full = pickle.load(open(pickle_file, "rb"))

all_smi_subsets = []
labels_subset = []
Expand All @@ -143,9 +164,7 @@ def main(in_dict, pickle_file, smi_files, out_prefix):
smiles_subset = [x.strip() for x in smiles_subset]
all_smi_subsets.append(smiles_subset)

x_subset = []
y_subset = []
all_inds_subset = [] # list[i][j] has the smiles indices for ith subset
all_inds_subset = [] # list[i][j] has the smiles indices for ith subset

# flatten list of lists of smiles from pickle file
smi_flat = [val for sublist in smiles_full[0] for val in sublist]
Expand All @@ -164,10 +183,9 @@ def main(in_dict, pickle_file, smi_files, out_prefix):
# store and move onto next subset file
all_inds_subset.append(inds_subset)

for i in range(num_methods-1):

for i in range(num_methods - 1):
# get output filename and make sure it has no forbidden characters
out_file = out_prefix + method_labels[i+1] + '.png'
out_file = out_prefix + method_labels[i + 1] + ".png"
out_file = re.sub(r'[\\/*?:"<>|]', "", out_file)
print(f"\n{out_file}")

Expand All @@ -187,50 +205,86 @@ def main(in_dict, pickle_file, smi_files, out_prefix):
print(f"min/max y: {np.nanmin(y_data):10.4f}\t{np.nanmax(y_data):10.4f}")

draw_scatter_moiety(
x_data, y_data,
all_x_subset, all_y_subset, labels_subset,
x_data,
y_data,
all_x_subset,
all_y_subset,
labels_subset,
method_labels[i + 1],
"TFD",
"ddE (kcal/mol)",
out_file,
what_for='paper',
what_for="paper",
x_range=(0, 0.8),
y_range=(-50, 30))
y_range=(-50, 30),
)

# Summarize the statistics accumulated in statistics.dat (one row per
# force field and moiety) as one grouped bar plot per statistic.
data = pd.read_csv("statistics.dat", sep=",")
print(data)

# (output file stem, y-axis label, unit suffix) for each statistic, listed
# in the same order as the statistic columns of statistics.dat.
# Raw strings keep the LaTeX backslashes literal; in a normal string
# "\o", "\m" and "\s" are invalid escape sequences that newer Python
# versions warn about.
stat_plots = [
    ("avg_tfd", r"$\overline{\mathrm{TFD}}$", ""),
    ("std_tfd", r"$\sigma(\mathrm{TFD})$", ""),
    ("avg_ene", r"$\overline{\mathrm{dd}E}$", " [kcal/mol]"),
    ("std_ene", r"$\sigma(\mathrm{dd}E)$", " [kcal/mol]"),
]
for i, (dat, lab, unit) in enumerate(stat_plots):
    # Column 0 is the force field label and column 1 the moiety; the
    # statistics start at column 3, after the count column.
    sns.barplot(x=data.iloc[:, 0], y=data.iloc[:, i + 3], hue=data.iloc[:, 1])
    plt.xticks(rotation=90)
    plt.xlabel("")
    plt.ylabel(lab + unit)
    # Place the legend outside the axes so it does not cover the bars.
    plt.legend(bbox_to_anchor=(1.05, 1), loc=2, borderaxespad=0.0)
    plt.savefig(f"{dat}.png", bbox_inches="tight")
    plt.clf()


### ------------------- Parser -------------------

if __name__ == "__main__":
import argparse

parser = argparse.ArgumentParser()

# parse slice if not analyzing full set
# https://stackoverflow.com/questions/18632320/numpy-array-indices-via-argparse-how-to-do-it-properly
def _parse_slice(inslice):
    """
    Convert a command-line string into an index or a slice.

    Parameters
    ----------
    inslice : str
        Either "all", a single integer such as "3", or colon-separated
        integers in start:stop:step form such as "1:5" or "::2".
        Empty fields are treated as None.

    Returns
    -------
    slice or int
        slice(None) for "all", an int for a single integer, otherwise
        a slice built from the colon-separated fields.

    Raises
    ------
    ValueError
        If a field is not an integer or more than three fields are given.
    """
    if inslice == "all":
        return slice(None)
    try:
        section = int(inslice)
    except ValueError:
        # Not a plain integer: interpret as start:stop:step.
        section = [int(s) if s else None for s in inslice.split(":")]
        if len(section) > 3:
            raise ValueError("error parsing input slice")
        section = slice(*section)
    return section

parser.add_argument("-i", "--infile",
parser.add_argument(
"-i",
"--infile",
help="Name of text file with force field in first column and molecule "
"file in second column. Columns separated by commas.")
"file in second column. Columns separated by commas.",
)

parser.add_argument("-p", "--picklefile",
help="Pickle file from compare_ffs.py analysis")
parser.add_argument(
"-p", "--picklefile", help="Pickle file from compare_ffs.py analysis"
)

parser.add_argument("-s", "--smifiles", nargs='+',
parser.add_argument(
"-s",
"--smifiles",
nargs="+",
help="One or more text files with SMILES from SD tags for "
"molecules to color in plot. If multiple files, list them "
"in order from bottommost color to topmost color.")
"molecules to color in plot. If multiple files, list them "
"in order from bottommost color to topmost color.",
)

parser.add_argument("-o", "--out_prefix",
help="Prefix of the names of the output plots")
parser.add_argument(
"-o", "--out_prefix", help="Prefix of the names of the output plots"
)

# parse arguments
args = parser.parse_args()
Expand All @@ -247,5 +301,9 @@ def _parse_slice(inslice):

# run main
print("Log file from color_by_moiety.py\n")
with open("statistics.dat", "w") as file:
file.write("")
file.write(
f"#{'FF':18s},{'moiety':9s},{'count':>9s},{'avg_tfd':>9s},{'std_tfd':>9s},{'avg_dde':>9s},{'std_dde':>9s}\n"
)
main(in_dict, args.picklefile, args.smifiles, args.out_prefix)

Loading

0 comments on commit 339b806

Please sign in to comment.