diff --git a/src/clm/commands/write_structural_prior_CV.py b/src/clm/commands/write_structural_prior_CV.py
index 7a48a018..e5c1b855 100644
--- a/src/clm/commands/write_structural_prior_CV.py
+++ b/src/clm/commands/write_structural_prior_CV.py
@@ -242,6 +242,46 @@ def write_structural_prior_CV(
     logger.info("Reading sample file from generative model")
     gen = read_csv_file(sample_file)
+    # some SMILES may be invalid when tabulate_molecules used a different
+    # rdkit version -- validate only generated SMILES that are candidates to
+    # match a test molecule
+    gen_sorted = gen.sort_values("mass", kind="stable")
+    gen_masses = gen_sorted["mass"].values
+    # element 0 / 1 of each mass_range is used as the lower / upper bound
+    lo_vals = test["mass_range"].apply(lambda r: r[0]).values
+    hi_vals = test["mass_range"].apply(lambda r: r[1]).values
+    # positions [left, right) of gen_sorted have lo <= mass <= hi
+    lefts = np.searchsorted(gen_masses, lo_vals, side="left")
+    rights = np.searchsorted(gen_masses, hi_vals, side="right")
+    candidate_positions = set()
+    for left, right in zip(lefts, rights):
+        candidate_positions.update(range(left, right))
+    candidates = gen_sorted.iloc[sorted(candidate_positions)]
+    # cast to bool so the mask stays a boolean indexer even when there are
+    # no candidates (an empty apply result is not bool-typed)
+    invalid_mask = candidates["smiles"].progress_apply(
+        lambda s: clean_mol(s, raise_error=False) is None
+    ).astype(bool)
+    invalid_idx = candidates.index[invalid_mask]
+
+    gen = gen.drop(invalid_idx)
+
+    n_candidates = len(candidates)
+    n_invalid = len(invalid_idx)
+
+    # log if invalid SMILES were detected and removed
+    if n_invalid > 0:
+        # read the examples from candidates: these rows were just dropped
+        # from gen, so gen.loc[invalid_idx] would raise a KeyError
+        examples = candidates.loc[invalid_idx, "smiles"].head(5).tolist()
+
+        logger.warning(
+            f"Removed {n_invalid} invalid SMILES among "
+            f"{n_candidates} candidates to match a test molecule "
+            f"(possibly due to a different rdkit version). "
+            f"Examples: {examples}"
+        )
+
     inputs = {"model": gen.assign(source="model")}
     if pubchem_file: