fix: -jnp.inf scores could be set when updating the plr buffer
#6
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
While experimenting with the
maze_plr.pyexample, I realized that the metricsweighted_scoreandmean_scorewhere being reported as-Infinityin Weights and Biases.After investigating a bit, I saw this was not a problem when sampling new levels: the
replace_condcondition in the_insert_newmethod fromlevel_sampler.pyensures that-jnp.infscores cannot be pushed into the buffer since the minimum score in the buffer is precisely that amount.Instead, the problem is introduced when levels are replayed and their scores need to be updated in the buffer. If the rollout for a given level doesn't contain at least one full episode (i.e., the rollout does not end in a terminal state or reach the maximum number of steps), the score is
-jnp.inf. Theupdatemethod from theLevelSamplerdoes not account for this possibility, and can hence introduce-jnp.infscores leading to problems in theweighted_scoreandmean_scoremetrics computation; besides, I think it would make the sampling of the associated levels based on staleness only.This phenomenon does not occur with the default PPO rollout length used in
jaxued(256) since the maximum episode length used in Maze is 250; therefore, it is guaranteed that each rollout will contain at least one episode, hence avoiding the-jnp.inf. By setting the rollout length to a value lower than 250 (e.g., 128), the phenomenon is easily reproduced.I investigated whether other UED codebases perform a similar check to the one I have introduced. For example,
minimaxdoes it here (note thatignore_valis-jnp.inf, like in your approach).My suggestion is to perform a similar check to the one in
minimax, which seems to do the job as shown below. You can see that forrollout-128-bugthere are many "bullet points" at the bottom of the plot representing the-inf. I've followed a similar structure to the one you used forinsert_new.Hope this is useful!