Replace register_backward_hook with register_full_backward_hook (#720)
Summary:

register_backward_hook is deprecated in PyTorch in favor of register_full_backward_hook, which is guaranteed to fire with the complete gradients for a module's inputs and outputs. Switch the hook registration in GradSampleModule accordingly.
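For reference, a minimal sketch of the API being migrated to. This is illustrative PyTorch (>= 1.8) code, not Opacus internals; the model and hook below are made up:

import torch
import torch.nn as nn

# Both registration methods take a hook with the same signature:
#     hook(module, grad_input, grad_output)
# The deprecated register_backward_hook could fire with incomplete gradients
# for modules composed of multiple autograd ops; register_full_backward_hook
# is guaranteed to see the full grad_input/grad_output tuples.
def capture_grads(module, grad_input, grad_output):
    # grad_input / grad_output are tuples of tensors (entries may be None).
    print(type(module).__name__, [g.shape for g in grad_output if g is not None])

model = nn.Sequential(nn.Linear(4, 8), nn.ReLU(), nn.Linear(8, 2))
handles = [m.register_full_backward_hook(capture_grads) for m in model]  # was: register_backward_hook

loss = model(torch.randn(3, 4, requires_grad=True)).sum()
loss.backward()  # hooks fire for each submodule during the backward pass

for handle in handles:
    handle.remove()  # handles allow later removal; cf. autograd_grad_sample_hooks in the diff below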

Differential Revision: D68562558
iden-kalemaj authored and facebook-github-bot committed Jan 23, 2025
1 parent ea4cb95 · commit bf7bf0c
Showing 3 changed files with 3 additions and 2 deletions.
.github/workflows/ci_gpu.yml: 1 addition, 1 deletion
@@ -83,7 +83,7 @@ jobs:
 mkdir -p runs/cifar10/logs
 mkdir -p runs/cifar10/test-reports
 pip install tensorboard
-python examples/cifar10.py --lr 0.1 --sigma 1.5 -c 10 --batch-size 2000 --epochs 10 --data-root runs/cifar10/data --log-dir runs/cifar10/logs --device cuda
+python examples/cifar10.py --lr 0.1 --sigma 1.5 -c 10 --batch-size 2000 --epochs 10 --data-root runs/cifar10/data --log-dir runs/cifar10/logs --device cuda --clip_per_layer
 python -c "import torch; model = torch.load('model_best.pth.tar'); exit(0) if (model['best_acc1']>0.4 and model['best_acc1']<0.49) else exit(1)"
 python examples/cifar10.py --lr 0.1 --sigma 1.5 -c 10 --batch-size 2000 --epochs 10 --data-root runs/cifar10/data --log-dir runs/cifar10/logs --device cuda --grad_sample_mode no_op
 python -c "import torch; model = torch.load('model_best.pth.tar'); exit(0) if (model['best_acc1']>0.4 and model['best_acc1']<0.49) else exit(1)"
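The new --clip_per_layer run exercises per-layer clipping, which depends on the hooks-based grad-sample path changed in this commit. For context, a hedged sketch of enabling it through Opacus's PrivacyEngine; it assumes the Opacus 1.x make_private keyword arguments (clipping=, grad_sample_mode=), and the toy model, optimizer, and loader are illustrative:

import torch
from opacus import PrivacyEngine

model = torch.nn.Sequential(torch.nn.Linear(4, 8), torch.nn.ReLU(), torch.nn.Linear(8, 2))
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
dataset = torch.utils.data.TensorDataset(torch.randn(64, 4), torch.randint(0, 2, (64,)))
data_loader = torch.utils.data.DataLoader(dataset, batch_size=16)

privacy_engine = PrivacyEngine()
n_trainable = sum(1 for p in model.parameters() if p.requires_grad)
model, optimizer, data_loader = privacy_engine.make_private(
    module=model,
    optimizer=optimizer,
    data_loader=data_loader,
    noise_multiplier=1.5,
    max_grad_norm=[1.0] * n_trainable,  # per-layer clipping: one threshold per trainable parameter
    clipping="per_layer",
    grad_sample_mode="hooks",  # the mode whose hook registration this commit changes
)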
opacus/grad_sample/grad_sample_module.py: 1 addition, 1 deletion
@@ -207,7 +207,7 @@ def add_hooks(
 )
 
 self.autograd_grad_sample_hooks.append(
-    module.register_backward_hook(
+    module.register_full_backward_hook(
         partial(
             self.capture_backprops_hook,
             loss_reduction=loss_reduction,
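The hook above is registered through functools.partial, which binds configuration such as loss_reduction at registration time while the callback keeps the (module, grad_input, grad_output) signature PyTorch expects. A minimal sketch of that pattern; capture_backprops and its scaling logic are simplified stand-ins for Opacus's capture_backprops_hook:

import torch
import torch.nn as nn
from functools import partial

def capture_backprops(module, grad_input, grad_output, loss_reduction):
    # With a mean-reduced loss, backprops are scaled by 1/batch_size, so a
    # grad-sample hook would rescale by the batch size to undo that.
    scale = grad_output[0].shape[0] if loss_reduction == "mean" else 1
    module.backprops = grad_output[0] * scale

layer = nn.Linear(4, 2)
handle = layer.register_full_backward_hook(
    partial(capture_backprops, loss_reduction="mean")  # extra kwarg bound here
)

out = layer(torch.randn(3, 4, requires_grad=True))
out.mean().backward()
print(layer.backprops.shape)  # torch.Size([3, 2])
handle.remove()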
opacus/tests/multigpu_gradcheck.py: 1 addition, 0 deletions
@@ -206,6 +206,7 @@ def test_gradient_correct(self) -> None:
     )
     clipping_grad_sample_pairs.append(("ghost", "ghost"))
 
+    clipping_grad_sample_pairs = [("per_layer", "hooks")]
     for clipping, grad_sample_mode in clipping_grad_sample_pairs:
 
         weight_dp, weight_nodp = torch.zeros(10, 10), torch.zeros(10, 10)
