Skip to content
Draft
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 7 additions & 1 deletion alphafold/model/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@
import tensorflow.compat.v1 as tf
import tree


class RunModel:
"""Container for JAX model."""

Expand Down Expand Up @@ -123,7 +124,8 @@ def predict(self,
random_seed: int = 0,
return_representations: bool = False,
fix_single_representation: bool = True,
callback: Any = None) -> Mapping[str, Any]:
callback: Any = None,
manipulation_callback: Any = None) -> Mapping[str, Any]:
"""Makes a prediction by inferencing the model on the provided features.

Args:
Expand Down Expand Up @@ -199,6 +201,10 @@ def _jnp_to_np(x):
# callback
if callback is not None: callback(result, r)

# modify prev representations to use in next recycling iteration
if manipulation_callback is not None:
prev = manipulation_callback(prev)

# decide when to stop
if result["ranking_confidence"] > self.config.model.stop_at_score:
break
Expand Down