Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
23 changes: 15 additions & 8 deletions src/jaxued/level_sampler.py
Original file line number Diff line number Diff line change
Expand Up @@ -236,7 +236,9 @@ def get_levels_extra(self, sampler: Sampler, level_idx: int) -> dict:

def update(self, sampler: Sampler, idx: int, score: float, level_extra: dict=None) -> Sampler:
"""
This updates the score and level_extras of a level
This updates the score and level_extras of a level.
The update is performed only if the score is not -inf, i.e. if at least
one episode was completed during the policy rollout.

Args:
sampler (Sampler): The sampler object
Expand All @@ -247,13 +249,18 @@ def update(self, sampler: Sampler, idx: int, score: float, level_extra: dict=Non
Returns:
Sampler: Updated Sampler
"""
new_sampler = {
**sampler,
"scores": sampler["scores"].at[idx].set(score),
}
if level_extra is not None:
new_sampler["levels_extra"] = jax.tree_map(lambda x, y: x.at[idx].set(y), new_sampler["levels_extra"], level_extra)
return new_sampler
update_cond = score > -jnp.inf

def _replace():
new_sampler = {
**sampler,
"scores": sampler["scores"].at[idx].set(score),
}
if level_extra is not None:
new_sampler["levels_extra"] = jax.tree_map(lambda x, y: x.at[idx].set(y), new_sampler["levels_extra"], level_extra)
return new_sampler

return jax.lax.cond(update_cond, _replace, lambda: sampler,)

def update_batch(self, sampler: Sampler, level_inds: chex.Array, scores: chex.Array, level_extras: dict=None) -> Sampler:
"""
Expand Down