diff --git a/deepfrier/Predictor.py b/deepfrier/Predictor.py
index 3381d73..83500ad 100644
--- a/deepfrier/Predictor.py
+++ b/deepfrier/Predictor.py
@@ -54,6 +54,9 @@ def __init__(self, model_prefix, gcn=True):
         self.model_prefix = model_prefix
         self.gcn = gcn
         self._load_model()
+        # Result stores that predict_with_cmap / predict_from_sequence add to without recreating them.
+        self.prot2goterms = {}
+        self.goidx2chains = {}
 
     def _load_model(self):
         self.model = tf.keras.models.load_model(self.model_prefix + '.hdf5',
@@ -130,6 +133,65 @@ def predict(self, test_prot, cmap_thresh=10.0, chain='query_prot'):
                 self.goidx2chains[idx].add(chain)
                 self.prot2goterms[chain].append((self.goterms[idx], self.gonames[idx], float(y[idx])))
 
+    def _record_go_terms(self, chain, y):
+        """Store the GO terms of ``chain`` whose score in ``y`` reaches ``self.thresh``.
+
+        ``self.prot2goterms[chain]`` is replaced by a fresh list of
+        ``(goterm, goname, score)`` tuples and ``chain`` is added to
+        ``self.goidx2chains`` under every selected term index.
+        """
+        self.prot2goterms[chain] = []
+        for idx in np.where(y >= self.thresh)[0]:
+            self.goidx2chains.setdefault(idx, set()).add(chain)
+            self.prot2goterms[chain].append((self.goterms[idx], self.gonames[idx], float(y[idx])))
+
+    def predict_with_cmap(self, seqres, cmap, chain):
+        """Predict GO terms for one chain from its sequence and a ready-made contact map.
+
+        Args:
+            seqres: sequence of the chain; encoded with ``seq2onehot``.
+            cmap: contact map, handed to the model unchanged.
+            chain: key under which the results are stored.
+
+        Resets ``self.Y_hat``, ``self.data`` and ``self.test_prot_list`` to hold
+        this chain only. When ``self.gcn`` is false no prediction is made.
+        """
+        self.Y_hat = np.zeros((1, len(self.goterms)), dtype=float)
+        self.data = {}
+        self.test_prot_list = [chain]
+        if not self.gcn:
+            # NOTE(review): no prediction is attempted for non-GCN models; confirm that
+            # a silent return (rather than an error) is what callers expect.
+            return
+        S = seq2onehot(seqres)
+        S = S.reshape(1, *S.shape)  # add a leading axis of size 1
+        y = self.model([cmap, S], training=False).numpy()[:, :, 0].reshape(-1)
+        self.Y_hat[0] = y
+        self.data[chain] = [[cmap, S], seqres]
+        self._record_go_terms(chain, y)
+
+    def predict_from_sequence(self, sequence, chain):
+        """Predict GO terms for one chain from its sequence alone.
+
+        Args:
+            sequence: sequence of the chain; converted with ``str`` and encoded
+                with ``seq2onehot``.
+            chain: key under which the results are stored.
+
+        Resets ``self.Y_hat``, ``self.data`` and ``self.test_prot_list`` to hold
+        this chain only.
+        """
+        self.test_prot_list = [chain]
+        self.Y_hat = np.zeros((1, len(self.goterms)), dtype=float)
+        self.data = {}
+
+        S = seq2onehot(str(sequence))
+        S = S.reshape(1, *S.shape)  # add a leading axis of size 1
+        y = self.model(S, training=False).numpy()[:, :, 0].reshape(-1)
+        self.Y_hat[0] = y
+        self.data[chain] = [[S], sequence]
+        self._record_go_terms(chain, y)
+
     def predict_from_PDB_dir(self, dir_name, cmap_thresh=10.0):
         print ("### Computing predictions from directory with PDB files...")
         pdb_fn_list = glob.glob(dir_name + '/*.pdb*')