[FR] Support Automatic Mixed Precision training #3316
Comments
@fritzo I might be willing to try to tackle this. Do you have any opinions on how to expose the functionality to the end user?
Hi @austinv11, Thanks for offering. I'd guess there are a few ways we could support AMP in Pyro:
Would you be interested in getting (1) or (2) working for yourself then contributing docs to show how you did it? We're happy to answer any questions about Pyro, but I think you know more about AMP than us 🙂
It looks like I might need to try option 3 since AMP-aware gradient scaling requires access to the optimizer's step() function. I could try making it a boolean flag for But I could see most users wanting to just activate AMP for their entire model rather than just specific portions of code. Do you think it might be worth adding a new ELBO function that autocasts the entire model for the user?
Let me try again to persuade you towards options (1) or (2) 😄, admitting I don't know your details or how AMP works. Back in the early days of Pyro we decided to wrap PyTorch's optimizer classes so we could have more control over dynamically created parameters. In practice this made Pyro's optimization idioms incompatible with other frameworks built on top of PyTorch, e.g. lightning, horovod, AMP, new higher-order optimizers. To work around this incompatibility we've since added ways to compute differentiable losses in Pyro so that optimization can be done entirely using torch idioms, without ever using For example instead of the original pyro-idiomatic optimization

    def model(args):
        ...
    guide = AutoNormal(model)
    elbo = Trace_ELBO()
    optim = pyro.optim.Adam(...)  # <---- pyro idioms
    svi = SVI(model, guide, optim, elbo)
    for step in range(...):
        svi.step(args)

you can use torch-idiomatic optimizers

    class Model(PyroModule):
        def forward(self, args):
            ...
    model = Model()
    guide = AutoNormal(model)
    elbo = Trace_ELBO()
    loss_fn = elbo(model, guide)  # a torch module wrapping model and guide
    optimizer = torch.optim.Adam(loss_fn.parameters(), ...)  # <---- torch idioms
    for step in range(...):
        optimizer.zero_grad()
        loss = loss_fn(args)
        loss.backward()
        optimizer.step()  # <---- Can we use AMP here?

What I'm hoping is that by switching to torch-native optimizers it will be easy/trivial to support AMP. That said, we'd still be open to adding AMP support to pyro.optim if you can find a simple, maintainable way to do so 🙂.
Ah, I see what you mean. Am I correct in understanding that this wouldn't be compatible with the SVI trainer and would require using PyroModules then?
That is also incompatible with models/guides that dynamically create parameters during training, if I understand correctly.
@austinv11 @ilia-kats correct.
Issue Description
Better support for mixed precision training would be extremely helpful, at least for SVI. I can manually cast data into float16 or bfloat16, but I am unable to leverage PyTorch's automatic mixed precision training. This is because it requires the use of the GradScaler class during the optimization loop to properly scale gradients in a mixed-precision-aware manner. See the documentation for more info: https://pytorch.org/docs/stable/amp.html

It would be nice to have support for using this class within Pyro optimizers to allow for AMP support.
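For context, the plain-PyTorch AMP loop that the description refers to looks roughly like the sketch below. It involves no Pyro at all; the linear model, synthetic data, and hyperparameters are made up for illustration. The point is that GradScaler has to wrap both the backward pass and the optimizer's step, which is why a wrapped optimizer that hides step() gets in the way. On a machine without CUDA the scaler and autocast are disabled and the loop runs in full precision.

```python
import torch

device = "cuda" if torch.cuda.is_available() else "cpu"
use_amp = device == "cuda"  # assumption: AMP only when a GPU is available

torch.manual_seed(0)
net = torch.nn.Linear(8, 1).to(device)  # hypothetical toy model
optimizer = torch.optim.SGD(net.parameters(), lr=0.1)
# Newer PyTorch releases also expose this class as torch.amp.GradScaler.
scaler = torch.cuda.amp.GradScaler(enabled=use_amp)
amp_dtype = torch.float16 if use_amp else torch.bfloat16

x = torch.randn(64, 8, device=device)
y = x.sum(dim=1, keepdim=True)

losses = []
for step in range(20):
    optimizer.zero_grad()
    # Forward pass runs in reduced precision where autocast deems it safe.
    with torch.autocast(device_type=device, dtype=amp_dtype, enabled=use_amp):
        loss = torch.nn.functional.mse_loss(net(x), y)
    scaler.scale(loss).backward()  # backprop through the scaled loss
    scaler.step(optimizer)         # needs direct access to the optimizer
    scaler.update()
    losses.append(loss.item())
```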