
Extra loss terms before loss.backward() seem to have no effects #249

Open
kenziyuliu opened this issue Nov 8, 2021 · 10 comments

@kenziyuliu

🐛 Bug

Extra loss terms added before loss.backward() do not seem to have any effect when privacy_engine is in use. One use case this blocks is regularizing model weights towards another set of weights (e.g. multi-task learning regularization), as well as other weight-based regularization techniques.

Please reproduce using our template Colab and post here the link

https://colab.research.google.com/drive/1TyZMh4IgkB8qTak1JqYpBFMrrE_x1Rbp?usp=sharing

  • 1st code cell: adds an extra loss term based on the model weights (L2 loss)
  • last 2 code cells: train models with and without privacy_engine, respectively

To Reproduce

  1. Run all cells in the notebook
  2. With privacy_engine attached, I would expect the extra loss term (1st code cell) to have an effect on model learning
  3. If we look at the output of the last two cells, it seems that when privacy_engine is enabled, the extra loss term is not taken into account

Expected behavior

When we add loss terms before backprop, e.g.,

loss = criterion(y_pred, y_true)
loss += l2_loss(model)
loss += proximal_loss(model, another_model)   # e.g. encourage two models to have similar weights
loss.backward()

the extra loss terms should be reflected in training. However, when we use privacy_engine they seem to have no effect. This is unexpected, since we only clip and noise the gradients corresponding to the training examples.

Environment

The issue should be reproducible in the provided colab notebook

@romovpa
Contributor

romovpa commented Nov 9, 2021

@kenziyuliu Thanks for reporting! Indeed, there is an issue with GradSampleModule implementation leading to incorrect gradients.

In your example you use the squared sum of the parameters (model is an nn.Linear):

regularizer = 5 * torch.sum(torch.stack([torch.square(p).sum() for p in model.parameters()]))
loss = criterion(model(x), y) + regularizer

GradSampleModule collects grad_samples using module backward hooks. When the criterion term is computed, the module backward hook fires because that expression calls the model. The regularizer term does not call the model, so the module backward hook does not fire for it. In both cases, however, the tensor hooks are invoked.

Thanks to this message for the tip. It also recommends using Tensor hooks instead of Module hooks. I think that's a good idea, and we should redesign the grad sampler to use Tensor hooks.
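
A minimal illustration of the difference (a standalone sketch with plain PyTorch hooks, not the Opacus implementation):

import torch
import torch.nn as nn

model = nn.Linear(4, 1)
x = torch.randn(3, 4)

# Module hook: fires only when the backward pass goes through the module's output.
model.register_full_backward_hook(
    lambda module, grad_in, grad_out: print("module backward hook fired")
)
# Tensor hook: fires whenever a gradient for this parameter is computed.
model.weight.register_hook(lambda g: print("tensor hook on weight fired"))

model(x).sum().backward()             # both hooks fire
(model.weight ** 2).sum().backward()  # only the tensor hook fires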

@kenziyuliu
Author

Hi @romovpa, thanks for the reply! Would there be a workaround for now with minor code changes?

@romovpa
Contributor

romovpa commented Nov 12, 2021

@kenziyuliu Off the top of my head, one possible workaround is to create a custom Module and implement a grad sampler for it. Something like:

model = RegularizedLinear()

y_hat, regularizer = model(x)
loss = criterion(y_hat, y) + regularizer
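
For concreteness, a rough (untested) sketch of such a module, where the weight penalty is computed inside the module's own forward pass; the custom grad sampler registration is omitted, since the exact registration API depends on the Opacus version:

import torch
import torch.nn as nn

class RegularizedLinear(nn.Module):
    def __init__(self, in_features=16, out_features=2, reg_coef=5.0):
        super().__init__()
        self.linear = nn.Linear(in_features, out_features)
        self.reg_coef = reg_coef

    def forward(self, x):
        y_hat = self.linear(x)
        # Compute the squared-weight penalty inside forward, so it is part of
        # the wrapped module's computation rather than a detached loss term.
        regularizer = self.reg_coef * sum((p ** 2).sum() for p in self.parameters())
        return y_hat, regularizer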

@ffuuugor @karthikprasad Could you check whether this is viable? Can we handle multiple outputs?

Created a separate issue for the bigger grad sampler problem #259

@alexandresablayrolles
Contributor

@romovpa: I don't believe this will work, because it will create "per-sample" gradients of the regularizer.
@kenziyuliu: In your particular case there should be a workaround, which consists of adding to p.grad the value model.p - another_model.p right after the PrivacyEngine step (which is executed before the Optimizer step).

I believe something along these lines should work:

import types

def new_step(self, is_empty=False):
    self.original_step(is_empty)  # run the original clip-and-noise step first
    with torch.no_grad():  # keep the proximal term out of the autograd graph
        for p, p_another in zip(model.parameters(), another_model.parameters()):
            p.grad += lambda_regularizer * (p - p_another)

engine.original_step = engine.step
engine.step = types.MethodType(new_step, engine)

@kenziyuliu
Author

kenziyuliu commented Dec 12, 2021

Hi @alexandresablayrolles, thanks for the response! It seems that the API has changed for the engine object and engine.step is no longer available. Do we now apply this to the DPOptimizer object after the engine.make_private() call? If so, where should I specify the is_empty argument (by default we just call optimizer.step())? Thanks.
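
For what it's worth, my best guess at adapting the snippet above to the new API would be something like the following. This is untested, it relies on DPOptimizer internals (pre_step and original_optimizer) which may differ between Opacus versions, and lambda_regularizer / another_model are the same illustrative names as before:

import torch

def dp_step_with_proximal_term(optimizer, model, another_model, lambda_regularizer):
    # pre_step clips and noises the per-sample gradients and writes the
    # aggregated result into p.grad, without updating the parameters yet.
    if optimizer.pre_step():
        with torch.no_grad():
            for p, p_another in zip(model.parameters(), another_model.parameters()):
                if p.grad is not None:
                    # Add the data-independent proximal gradient after
                    # clipping/noising but before the parameter update.
                    p.grad += lambda_regularizer * (p - p_another)
        return optimizer.original_optimizer.step()

# In the training loop, call this in place of optimizer.step():
#   loss.backward()
#   dp_step_with_proximal_term(optimizer, model, another_model, 0.01)
#   optimizer.zero_grad()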

@kenziyuliu
Author

kenziyuliu commented Dec 16, 2022

Hi Opacus team, just bumping this issue -- is this by any chance resolved, or does it have a clean workaround, in the newer versions of Opacus? I know that with functorch one could now do DP training with grad and vmap manually, as in JAX, though it'd be very nice to have this as part of the PrivacyEngine. Thanks!
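
For reference, a rough (untested) sketch of taking per-example gradients with functorch, where the regularizer enters every per-example loss before any clipping or noising; the model shape, lambda_reg, and batch sizes are just illustrative:

import torch
import torch.nn as nn
from functorch import make_functional, grad, vmap

model = nn.Linear(16, 2)
func_model, params = make_functional(model)
lambda_reg = 0.01

def per_example_loss(params, x, y):
    # Inside vmap, x and y are single examples; add back a batch dim of 1.
    logits = func_model(params, x.unsqueeze(0))
    loss = nn.functional.cross_entropy(logits, y.unsqueeze(0))
    # The weight penalty is part of every per-example loss, so it shows up in
    # each per-example gradient before clipping and noising.
    return loss + lambda_reg * sum((p ** 2).sum() for p in params)

# One gradient per example: a tuple of tensors with a leading batch dimension.
per_example_grads = vmap(grad(per_example_loss), in_dims=(None, 0, 0))

x, y = torch.randn(8, 16), torch.randint(0, 2, (8,))
grads = per_example_grads(params, x, y)
# Clipping, noising, and averaging over the batch would follow from here.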

@lucacorbucci

Hi @kenziyuliu, have you found a way to solve this problem? Is there a workaround for adding an extra term to the loss before loss.backward()?

@kenziyuliu
Author

Hi @lucacorbucci, I haven't tried since then, though I believe the new functorch package provides a good workaround that allows you to take per-example gradients manually (with grad and vmap), as in JAX.

@lucacorbucci

Thank you @kenziyuliu! I'll try with functorch

@PaulaDelgado-Santos

PaulaDelgado-Santos commented Jul 19, 2024

Hi, has anyone tried this (adding an extra term to the loss before loss.backward()) since, and do you know if it works? Thank you very much.
