diff --git a/alphafold/model/model.py b/alphafold/model/model.py index ba6938e20..9abea9087 100644 --- a/alphafold/model/model.py +++ b/alphafold/model/model.py @@ -29,6 +29,7 @@ import tensorflow.compat.v1 as tf import tree + class RunModel: """Container for JAX model.""" @@ -123,7 +124,8 @@ def predict(self, random_seed: int = 0, return_representations: bool = False, fix_single_representation: bool = True, - callback: Any = None) -> Mapping[str, Any]: + callback: Any = None, + manipulation_callback: Any = None) -> Mapping[str, Any]: """Makes a prediction by inferencing the model on the provided features. Args: @@ -199,6 +201,10 @@ def _jnp_to_np(x): # callback if callback is not None: callback(result, r) + # modify prev representations to use in next recycling iteration + if manipulation_callback is not None: + prev = manipulation_callback(prev) + # decide when to stop if result["ranking_confidence"] > self.config.model.stop_at_score: break