diff --git a/src/jaxued/level_sampler.py b/src/jaxued/level_sampler.py index 2d5a729..e5f9455 100644 --- a/src/jaxued/level_sampler.py +++ b/src/jaxued/level_sampler.py @@ -236,7 +236,9 @@ def get_levels_extra(self, sampler: Sampler, level_idx: int) -> dict: def update(self, sampler: Sampler, idx: int, score: float, level_extra: dict=None) -> Sampler: """ - This updates the score and level_extras of a level + This updates the score and level_extras of a level. + The update is performed only if the score is not -inf, i.e. if at least + one episode was completed during the policy rollout. Args: sampler (Sampler): The sampler object @@ -247,13 +249,18 @@ def update(self, sampler: Sampler, idx: int, score: float, level_extra: dict=Non Returns: Sampler: Updated Sampler """ - new_sampler = { - **sampler, - "scores": sampler["scores"].at[idx].set(score), - } - if level_extra is not None: - new_sampler["levels_extra"] = jax.tree_map(lambda x, y: x.at[idx].set(y), new_sampler["levels_extra"], level_extra) - return new_sampler + update_cond = score > -jnp.inf + + def _replace(): + new_sampler = { + **sampler, + "scores": sampler["scores"].at[idx].set(score), + } + if level_extra is not None: + new_sampler["levels_extra"] = jax.tree_map(lambda x, y: x.at[idx].set(y), new_sampler["levels_extra"], level_extra) + return new_sampler + + return jax.lax.cond(update_cond, _replace, lambda: sampler,) def update_batch(self, sampler: Sampler, level_inds: chex.Array, scores: chex.Array, level_extras: dict=None) -> Sampler: """