Some parts of `TaskTrainer.__call__` are not fit for vmapping when training multiple models in parallel:

- Logging and plotting operations should not be vmapped, but some logging should still happen. Say we're sending some validation plots to TensorBoard: do we log one model, or all of them separately? Do we send all of the loss values separately, or just summary statistics? (See the sketch after this list.)
- It also doesn't make sense to vmap over a `tqdm` progress bar, though we still want to keep the progress bar when training ensembles.
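To make the first point concrete, here is a minimal sketch (not from the codebase) of what happens to an ordinary host-side logging call under `vmap`:

```python
import jax
import jax.numpy as jnp

def train_step(w):
    loss = jnp.sum(w ** 2)
    # A plain Python side effect runs once, at trace time, with an
    # abstract tracer -- not once per trained model with a concrete value.
    print("loss:", loss)
    return loss

# Prints a single line like "loss: Traced<ShapedArray(float32[])>...",
# rather than one concrete loss per model.
jax.vmap(train_step)(jnp.ones((3, 5)))
```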
The current solution is to add several `if ensembled:` blocks (example) at appropriate points in `TaskTrainer.__call__`, at which we:

- add (or don't add) batch dimensions to the arrays meant to store the training history, e.g. the losses;
- split (or don't split) random keys;
- apply (or don't apply) `vmap` to `TaskTrainer._train_step`, `optimizer.init`, etc. before using these functions (sketched below).
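For illustration, here is a hypothetical condensation of that branching; the function name, signature, and `n_replicates` parameter are stand-ins, not the actual `TaskTrainer` internals:

```python
import jax
import jax.numpy as jnp

def setup(optimizer, params, key, n_batches, ensembled, n_replicates=1):
    """Hypothetical condensation of the branching in TaskTrainer.__call__."""
    if ensembled:
        # Batch dimension on the training-history arrays, e.g. the losses.
        losses = jnp.full((n_batches, n_replicates), jnp.nan)
        # One key per model replicate.
        keys = jax.random.split(key, n_replicates)
        # Vmapped variants of the per-model functions; similarly we might do
        # train_step = jax.vmap(TaskTrainer._train_step, in_axes=(0, 0, 0, None))
        opt_init = jax.vmap(optimizer.init)
    else:
        losses = jnp.full((n_batches,), jnp.nan)
        keys = key
        opt_init = optimizer.init
    opt_state = opt_init(params)
    return losses, keys, opt_state
```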
It would be nice for `__call__` itself to be vmappable, but I'm not sure how this could be achieved. Perhaps we could use something like `jax.experimental.host_callback` to pass data back to logging functions, but I don't see how this would solve the progress bar issue.
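For what it's worth, here's a minimal sketch of the callback idea. Since `jax.experimental.host_callback` has been deprecated in recent JAX releases, this uses its successor `jax.debug.callback`; the `log_loss` function is a hypothetical stand-in for a real logger:

```python
import jax
import jax.numpy as jnp

# Hypothetical host-side logger; in practice this might write to TensorBoard.
def log_loss(loss):
    print(f"loss: {float(loss):.4f}")

def train_step(w, x):
    loss = jnp.mean((w @ x) ** 2)
    # The callback escapes the trace and runs on the host; under vmap it
    # fires once per model, so each replicate's loss is logged separately.
    jax.debug.callback(log_loss, loss)
    return loss

ws = jnp.ones((3, 4))  # hypothetical ensemble: 3 models, 4 weights each
x = jnp.ones(4)
jax.vmap(train_step, in_axes=(0, None))(ws, x)
```

This handles shipping values out to a logger, but it still doesn't obviously help with driving a single `tqdm` bar from inside a vmapped training loop.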