Training, and vmapping over ensembles of models #5

Open
mlprt opened this issue Feb 16, 2024 · 0 comments

mlprt (Owner) commented Feb 16, 2024

Some parts of TaskTrainer.__call__ are not fit for vmapping when training multiple models in parallel:

  • Logging and plotting operations should not be vmapped, but some logging should still be done. Say we're sending some validation plots to tensorboard: do we log one model, or all of them separately? Do we send all of the loss values separately, or just statistics?
  • It also doesn't make sense to vmap over a tqdm progress bar, though we still want to keep the progress bar when training ensembles.
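To illustrate the progress-bar point (a minimal sketch, not from the codebase): `vmap` traces its function once, so a Python-side effect inside the step, such as a `tqdm` update, would run once at trace time rather than once per model:

```python
import jax
import jax.numpy as jnp

calls = []

def step(x):
    # A Python side effect (stand-in for a tqdm update): this runs at
    # *trace* time, once, regardless of the batch size mapped over.
    calls.append(1)
    return x * 2

out = jax.vmap(step)(jnp.arange(4.0))
# len(calls) == 1, even though four "models" were processed
```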

The current solution is to add several if ensembled: blocks (example) at appropriate points in TaskTrainer.__call__, at which we:

  • either add (or don't add) batch dimensions to the arrays meant to store the training history, e.g. losses;
  • either split (or don't split) random keys;
  • either apply (or don't apply) vmap to TaskTrainer._train_step, optimizer.init, etc. prior to using these functions.
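The branching pattern above might look something like this (a sketch with a hypothetical stand-in for `TaskTrainer._train_step`, not the actual implementation):

```python
import jax
import jax.numpy as jnp
import jax.random as jr

# Hypothetical stand-in for TaskTrainer._train_step: takes one model's
# parameters and one PRNG key, returns updated parameters.
def train_step(params, key):
    noise = jr.normal(key, params.shape)
    return params - 0.01 * noise

ensembled = True
n_ensemble = 4
n_params = 3
key = jr.PRNGKey(0)

if ensembled:
    # add a leading batch dimension for the ensemble...
    params = jnp.zeros((n_ensemble, n_params))
    # ...split the random key, one key per model...
    keys = jr.split(key, n_ensemble)
    # ...and vmap the step over the ensemble axis prior to use
    step = jax.vmap(train_step)
else:
    params = jnp.zeros((n_params,))
    keys = key
    step = train_step

new_params = step(params, keys)
```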

It would be nice for __call__ itself to be vmappable, but I'm not sure how this could be achieved. Perhaps we could use something like jax.experimental.host_callback to pass data back to logging functions, but I don't see how this would solve the progress bar issue.
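One possibility along these lines (a sketch, using `jax.debug.callback`, a newer host-callback mechanism than `jax.experimental.host_callback` that is compatible with `vmap`): values computed inside the vmapped step can be handed back to ordinary Python, where a logger could consume them. This addresses logging, though, as noted, not the progress bar:

```python
import jax
import jax.numpy as jnp

logged = []

def log_loss(model_idx, loss):
    # Runs host-side as ordinary Python, so it could write to
    # tensorboard; names here are illustrative only.
    logged.append((int(model_idx), float(loss)))

def step(model_idx, params):
    loss = jnp.sum(params ** 2)
    # Pass values out of traced code to the host; under vmap, the
    # callback fires once per ensemble member.
    jax.debug.callback(log_loss, model_idx, loss)
    return loss

params = jnp.arange(8.0).reshape(4, 2)  # an "ensemble" of 4 tiny models
losses = jax.vmap(step)(jnp.arange(4), params)
jax.effects_barrier()  # wait for host callbacks to complete
```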

@mlprt mlprt added the help wanted and jax labels, then removed the help wanted label, Feb 16, 2024
@mlprt mlprt pinned this issue Mar 4, 2024