
deprecation message for non-full backward hook #328

Open
Solosneros opened this issue Jan 18, 2022 · 13 comments
Labels
enhancement New feature or request

@Solosneros
Contributor

🐛 Bug

I am getting the following warning when using opacus on my system:
/usr/local/lib/python3.7/dist-packages/torch/nn/modules/module.py:1025: UserWarning: Using a non-full backward hook when the forward contains multiple autograd Nodes is deprecated and will be removed in future versions. This hook will be missing some grad_input. Please use register_full_backward_hook to get the documented behavior.
warnings.warn("Using a non-full backward hook when the forward contains multiple autograd Nodes "

To Reproduce

Steps to reproduce the behavior:
(Sorry, I had problems running Google Colab.)

  1. Execute the minimal example code given below:
import numpy as np
import opacus
import torch
import torch.nn as nn
from opacus.utils.uniform_sampler import UniformWithReplacementSampler
from torch.utils.data import DataLoader, TensorDataset, Dataset


DEVICE = "cuda:0"

BATCH_SZ = 8
FEATURE_DIM = 16
N_CLASSES = 2
DATASET_LEN = 50_000

class DummyDataset(Dataset):
    def __init__(self):
        pass

    def __getitem__(self, i):
        x = torch.randn([FEATURE_DIM]).to(DEVICE) 
        y = 0
        return x, y

    def __len__(self):
        return DATASET_LEN

train_ds = DummyDataset()
train_loader = DataLoader(train_ds, BATCH_SZ, shuffle=False)

model = nn.Linear(FEATURE_DIM, N_CLASSES).to(DEVICE)
criterion = nn.CrossEntropyLoss()

privacy_engine = opacus.PrivacyEngine()
model, optimizer, train_loader = privacy_engine.make_private(
        module=model,
        optimizer=torch.optim.SGD(model.parameters(), 0.01),
        data_loader=train_loader,
        noise_multiplier=1,
        max_grad_norm=10,
    )

for epoch in range(2):
    for x, y in train_loader:
        x = x.to(DEVICE)
        y = y.to(DEVICE)
        optimizer.zero_grad()
        out = model(x)
        loss = criterion(out, y)
        loss.backward()
        optimizer.step()

Expected behavior

No warning message.

Environment

PyTorch version: 1.10.1
Is debug build: False
CUDA used to build PyTorch: 10.2
ROCM used to build PyTorch: N/A

OS: Ubuntu 20.04.3 LTS (x86_64)
GCC version: (Ubuntu 9.3.0-17ubuntu1~20.04) 9.3.0
Clang version: 10.0.0-4ubuntu1
CMake version: version 3.16.3
Libc version: glibc-2.31

Python version: 3.9.7 (default, Sep 16 2021, 13:09:58) [GCC 7.5.0] (64-bit runtime)
Python platform: Linux-5.13.0-1026-oem-x86_64-with-glibc2.31
Is CUDA available: False
CUDA runtime version: No CUDA
GPU models and configuration: No CUDA
Nvidia driver version: No CUDA
cuDNN version: No CUDA
HIP runtime version: N/A
MIOpen runtime version: N/A

Versions of relevant libraries:
[pip3] mypy-extensions==0.4.3
[pip3] numpy==1.20.3
[pip3] numpydoc==1.1.0
[pip3] pytorch-ranger==0.1.1
[pip3] torch==1.10.1
[pip3] torch-optimizer==0.3.0
[pip3] torchaudio==0.10.1
[pip3] torchvision==0.11.2
[conda] blas 1.0 mkl
[conda] cudatoolkit 10.2.89 hfd86e86_1
[conda] ffmpeg 4.3 hf484d3e_0 pytorch
[conda] mkl 2021.4.0 h06a4308_640
[conda] mkl-service 2.4.0 py39h7f8727e_0
[conda] mkl_fft 1.3.1 py39hd3c417c_0
[conda] mkl_random 1.2.2 py39h51133e4_0
[conda] mypy_extensions 0.4.3 py39h06a4308_0
[conda] numpy 1.20.3 py39hf144106_0
[conda] numpy-base 1.20.3 py39h74d4b33_0
[conda] numpydoc 1.1.0 pyhd3eb1b0_1
[conda] pytorch 1.10.1 py3.9_cuda10.2_cudnn7.6.5_0 pytorch
[conda] pytorch-mutex 1.0 cuda pytorch
[conda] pytorch-ranger 0.1.1 pyhd8ed1ab_0 conda-forge
[conda] torch-optimizer 0.3.0 pyhd8ed1ab_0 conda-forge
[conda] torchaudio 0.10.1 py39_cu102 pytorch
[conda] torchvision 0.11.2 py39_cu102 pytorch

@alexandresablayrolles alexandresablayrolles added the enhancement New feature or request label Jan 18, 2022
@alexandresablayrolles alexandresablayrolles self-assigned this Jan 18, 2022
@alexandresablayrolles
Contributor

Thanks for flagging this. We are currently discussing this with the PyTorch team, as the new proposed hooks are not ideal for our use case.

@Solosneros
Contributor Author

Thanks for letting me know!
I noticed that the Colab notebook in the bug report template is not updated for opacus >= 1.0.0. As this seems to be a small issue, I would open a PR for that. Or is that already taken care of?

@ffuuugor
Contributor

ffuuugor commented Feb 1, 2022

Whoops, thanks for pointing out the old colab in the bug report template - fixed now.

With full backward hooks, unfortunately, it's not as simple as just replacing the deprecated hooks with the new method - their behaviour does not always match.
One issue in particular is in-place operations - they are flat-out forbidden with the new method, but are actively used in torchvision models (at least the last time I checked, a couple of months ago). Abandoning support for torchvision models is a hard trade-off to make, so we're currently weighing our options (including the #259 proposal).

I'll leave this issue open for tracking progress in the future, as we definitely plan to address this eventually.
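To make the in-place issue above concrete, here is a minimal sketch (my own illustration, not Opacus code): a torchvision-style block with an in-place ReLU still works with the deprecated `register_backward_hook` (emitting exactly the warning from this issue), while `register_full_backward_hook` forbids such in-place operations in the forward pass.

```python
import torch
import torch.nn as nn

# torchvision-style block: Linear followed by an in-place activation
block = nn.Sequential(nn.Linear(4, 4), nn.ReLU(inplace=True))

# Deprecated hook: runs, but the forward contains multiple autograd nodes,
# so PyTorch emits the "non-full backward hook" deprecation warning on backward.
block.register_backward_hook(lambda module, grad_input, grad_output: None)

# Full hook (what the warning recommends): in-place ops inside the forward
# are not allowed with it, which is why Opacus cannot simply switch over.
# block.register_full_backward_hook(lambda module, grad_input, grad_output: None)

out = block(torch.randn(2, 4, requires_grad=True))
out.sum().backward()  # triggers the UserWarning with the deprecated hook
```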

@mmsaki

mmsaki commented Jun 29, 2022

I'm having the same issue.

@alexandresablayrolles
Contributor

Thanks for flagging, @mmsaki. Currently, it is a warning, so you can safely ignore it. We are working on a solution for the next version of PyTorch.
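Since the warning is safe to ignore, one option in the meantime (my own suggestion, not an official Opacus recommendation) is to silence just this message via Python's standard `warnings` module:

```python
import warnings

# Suppress only this specific deprecation warning emitted by torch.nn
warnings.filterwarnings(
    "ignore",
    message="Using a non-full backward hook when the forward contains multiple autograd Nodes",
)
```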

@xiyuanyang45

Has this issue been solved now?

@xiyuanyang45

I have the same problem.

@karthikprasad
Contributor

Hi @dDCTRr, the warning for hooks still exists due to the reasons outlined by @ffuuugor in this comment; you can safely ignore it though (apologies for the annoyance).

That said, starting with Opacus 1.2.0, we support functorch-based per-sample gradient computation (no hooks, no warnings). To use this, simply set `grad_sample_mode="functorch"` in the call to `make_private()`. You can find more details about this in the release notes.
Note: as mentioned in the release notes, functorch support is still in beta, and it could be slower than hooks.
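For reference, a minimal self-contained sketch of switching the gradient sample mode (the model, optimizer, and data loader here are placeholders; only `grad_sample_mode="functorch"` is the relevant part):

```python
import torch
import torch.nn as nn
from opacus import PrivacyEngine
from torch.utils.data import DataLoader, TensorDataset

model = nn.Linear(16, 2)
optimizer = torch.optim.SGD(model.parameters(), lr=0.01)
train_loader = DataLoader(
    TensorDataset(torch.randn(64, 16), torch.zeros(64, dtype=torch.long)),
    batch_size=8,
)

privacy_engine = PrivacyEngine()
model, optimizer, train_loader = privacy_engine.make_private(
    module=model,
    optimizer=optimizer,
    data_loader=train_loader,
    noise_multiplier=1.0,
    max_grad_norm=10.0,
    grad_sample_mode="functorch",  # per-sample gradients via functorch, no backward hooks
)
```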

@weiiewwei

I am getting the same error.

@weiiewwei

Has this issue been solved now?

@xiyuanyang45

xiyuanyang45 commented Mar 12, 2023 via email

@erik-buchholz

> Hi @dDCTRr, the warning for hooks still exists due to the reasons outlined by @ffuuugor in this comment; you can safely ignore it though (apologies for the annoyance).
>
> That said, starting with Opacus 1.2.0, we support functorch-based per-sample gradient computation (no hooks, no warnings). To use this, simply set `grad_sample_mode="functorch"` in the call to `make_private()`. You can find more details about this in the release notes. Note: as mentioned in the release notes, functorch support is still in beta, and it could be slower than hooks.

Hi @karthikprasad,

I tried setting `grad_sample_mode="functorch"` and also `grad_sample_mode="ew"`, but in both cases the warning is still shown. Can you give me any hint on why that is? And more importantly, can I still assume that it is safe to ignore the warning? I am using the module within a GAN, but below is a reduced example that already triggers the warning. I am using Opacus version 1.4.0 and PyTorch 2.0.0+cu118.

Thanks a lot for your help!

import os

import numpy as np
import torch
import torch.nn as nn
import torchvision.transforms as transforms
from opacus import PrivacyEngine
from torch.utils.data import DataLoader
from torchvision import datasets
import torchvision.transforms.functional as F

channels = 1
img_size = 28
img_shape = (channels, img_size, img_size)

def load_data(batch_size: int) -> DataLoader:
    # Configure data loader
    os.makedirs(f"tmp/data/mnist", exist_ok=True)
    dataloader = torch.utils.data.DataLoader(
        datasets.MNIST(
            f"tmp/data/mnist",
            train=True,
            download=True,
            transform=transforms.Compose(
                [transforms.Resize(img_size), transforms.ToTensor(), transforms.Normalize([0.5], [0.5])]
            ),
        ),
        batch_size=batch_size,
        shuffle=True,
    )
    return dataloader

class Generator(nn.Module):
    def __init__(self, dp: bool = False):
        super(Generator, self).__init__()
        self.model = nn.Linear(100, int(np.prod(img_shape)))

    def forward(self, z):
        img = self.model(z).tanh()
        img = img.view(img.size(0), *img_shape)
        return img


m = Generator(True)
m = m.to('cuda')
opt = torch.optim.Adam(m.parameters())
dataloader = load_data(128)
privacy_engine = PrivacyEngine(accountant='rdp')


m, opt, dataloader = privacy_engine.make_private_with_epsilon(
    module=m,
    optimizer=opt,
    data_loader=dataloader,
    epochs=10,
    target_epsilon=10,
    target_delta=1e-5,
    max_grad_norm=1.0,
    poisson_sampling=True,
    grad_sample_mode="functorch"  # also tried grad_sample_mode="ew"
)

loss = torch.nn.BCELoss()
for epoch in range(1, 10 + 1):
    for i, (imgs, _) in enumerate(dataloader):
        valid = torch.ones((imgs.size(0), 1), device='cuda:0')
        fake = torch.zeros((imgs.size(0), 1), device='cuda:0')

        real_imgs = imgs.to('cuda')

        # -----------------
        #  Train Generator
        # -----------------
        m.train()
        opt.zero_grad()

        z = torch.randn(size=(imgs.shape[0], 100), device='cuda')

        # The following line triggers the warning:
        gen_imgs = m(z)
        # Quit here, as this is enough to reproduce the warning
        break
    break

The displayed warning:

[...]/lib/python3.10/site-packages/torch/nn/modules/module.py:1344: UserWarning: Using a non-full backward hook when the forward contains multiple autograd Nodes is deprecated and will be removed in future versions. This hook will be missing some grad_input. Please use register_full_backward_hook to get the documented behavior.
warnings.warn("Using a non-full backward hook when the forward contains multiple autograd Nodes "

facebook-github-bot pushed a commit that referenced this issue Apr 18, 2023
Summary:
This PR is a collection of smaller fixes that will save us some deprecation issues in the future

## 1. Updating to PyTorch 2.0

**Key files: grad_sample/functorch.py, requirements.txt**

`functorch` has been a part of core PyTorch since 1.13.
Now they're going a step further and changing the API, while deprecating the old one.

There's a [guide](https://pytorch.org/docs/master/func.migrating.html) on how to migrate. TL;DR - `make_functional` will no longer be part of the API, with `torch.func.functional_call()` being a (non-drop-in) replacement.

The key difference for us is that `make_functional()` creates a fresh copy of the module, while `functional_call()` uses the existing module. As a matter of fact, we need the fresh copy (otherwise all the hooks start firing and you enter nested madness), so I've copy-pasted a [gist](https://gist.github.com/zou3519/7769506acc899d83ef1464e28f22e6cf) from the official guide on how to get a full replacement for `make_functional`.
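For context, a minimal sketch of the new API (my own illustration, not the gist's full replacement): `functional_call()` runs an existing module with the parameters you pass in, so mimicking `make_functional()`'s fresh copy requires copying the module explicitly.

```python
import copy
import torch
import torch.nn as nn
from torch.func import functional_call  # PyTorch >= 2.0

model = nn.Linear(4, 2)

# functional_call() operates on an existing module object; to mimic
# make_functional()'s behaviour of working on a fresh copy, copy the module first.
detached_copy = copy.deepcopy(model)
params = {name: p.detach() for name, p in model.named_parameters()}

x = torch.randn(3, 4)
out = functional_call(detached_copy, params, (x,))
```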

## 2. New mechanism for gradient accumulation detection

**Key file: privacy_engine.py, grad_sample_module.py**

As [reported](https://discuss.pytorch.org/t/gan-raises-userwarning-using-a-non-full-backward-hook-when-the-forward-contains-multiple/175638/2) on the forum, clients are still getting "non-full backward hook" warning even when using `grad_sample_mode="ew"`. Naturally, `functorch` and `hooks` modes rely on backward hooks and can't be migrated to full hooks because [reasons](#328 (comment)). However, `ew` doesn't rely on hooks and it's unclear why the message should appear.

The reason, however, is simple. If the client is using Poisson sampling, we add an extra check to prohibit gradient accumulation (two Poisson batches combined is not a Poisson batch), and we do that by means of backward hooks.

~In this case, the backward hook serves a simple purpose and there shouldn't be any problems with migrating to the new method; however, that involved changing the checking method. That's because `register_backward_hook` is called *after* hooks on submodules, but `register_full_backward_hook` is called before them.~

The strikethrough solution didn't work, because hook execution order is weird for complex graphs, e.g. for GANs. For example, if your forward call looks like this:
```
Discriminator(Generator(x))
```
then the top-level module hook will precede the submodule hooks for `Generator`, but not for `Discriminator`.

As such, I've realised that gradient accumulation is not even supported in `ExpandedWeights`, so we don't have to worry about that. And the other two modes are both hooks-based, so we can just check the accumulation in the existing backward hook, no need for an extra hook. Deleted some code, profit.

## 3. Refactoring `wrap_collate_with_empty` to please pickle

Now, here are two facts I didn't know before:

1) You can't pickle a nested function, e.g. you can't do the following
```python
def foo():
    def bar():
        <...>

    return bar

pickle.dump(foo(), ...)
```

2) Whether or not `multiprocessing` uses pickle is python- and platform-dependent.

This affects our tests when we test `DataLoader` with multiple workers. As such, our data loaders tests:
* Pass on CircleCI with python3.9
* Fail on my local machine with python3.9
* Pass on my local machine with python3.7

I'm not sure how common the issue is, but it's safer to just refactor `wrap_collate_with_empty` to avoid nested functions.
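For illustration, a generic sketch of the refactoring pattern (hypothetical names, not the actual Opacus implementation): replace the nested function returned by a factory with a module-level callable class, which pickles cleanly for multi-worker `DataLoader`s.

```python
import pickle

import torch
from torch.utils.data import default_collate


class CollateWithEmpty:
    """Picklable replacement for a closure returned by a factory function."""

    def __init__(self, collate_fn, empty_shapes):
        self.collate_fn = collate_fn
        self.empty_shapes = empty_shapes

    def __call__(self, batch):
        if len(batch) == 0:
            # Empty Poisson batch: return correctly shaped empty tensors.
            return [torch.zeros(shape) for shape in self.empty_shapes]
        return self.collate_fn(batch)


# A top-level class instance can be pickled; a nested function cannot.
collate = CollateWithEmpty(default_collate, [(0, 3)])
pickle.dumps(collate)
```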

## 4. Fix benchmark tests

We don't really run `benchmarks/tests` on a regular basis, and some of them have been broken since we upgraded to PyTorch 1.13 (`API_CUTOFF_VERSION` doesn't exist anymore).

## 5. Fix flake8 config

The flake8 config no [longer supports](https://flake8.pycqa.org/en/latest/user/configuration.html) inline comments; a fix is due.

Pull Request resolved: #581

Reviewed By: alexandresablayrolles

Differential Revision: D44749760

Pulled By: ffuuugor

fbshipit-source-id: cf225f4134c049da4ee2eef53e1af3ef54d090bf
@volmodaoist

I am getting the same UserWarning.
