
Commit 130582c

Disable calling Tensor.requires_grad_() inside a functorch transform (#849)
Fixes #847

We no longer allow users to call requires_grad_() inside a functorch transform. Calling requires_grad_() is effectively a request for another layer of autograd, but that does not actually work: setting up a layer of autograd requires extra bookkeeping (e.g. pushing an autograd layer onto the DynamicLayerStack). Instead, when a user calls requires_grad_() (and similarly retain_grad()) inside a transform, we raise a clear error message.

This has the intended consequence of making torch.autograd.functional.{jvp, vjp, jacobian} error out when called inside a functorch transform; users should use the functorch equivalents instead.

Test Plan:
- added tests
1 parent d16c10b commit 130582c
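To make the new behavior concrete, here is an illustrative snippet (not part of the commit); the error it triggers is the message added to DynamicLayer.cpp below:

import torch
from functorch import grad

def f(x):
    # Asking for another layer of autograd inside a transform is now rejected.
    x.requires_grad_()
    return x.sin().sum()

x = torch.randn(3)
grad(f)(x)  # RuntimeError: "You are attempting to call Tensor.requires_grad_() ..."

# Calling requires_grad_() outside of any functorch transform is unaffected.
x.requires_grad_()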

File tree (5 files changed: +118 -3 lines changed)
- functorch/_src/eager_transforms.py
- functorch/csrc/DynamicLayer.cpp
- functorch/csrc/DynamicLayer.h
- functorch/csrc/init.cpp
- test/test_eager_transforms.py


Diff for: functorch/_src/eager_transforms.py

+7 -1

@@ -32,6 +32,7 @@
     _func_increment_nesting,
     _assert_wrapped_functional,
     _propagate_functional_input_mutation,
+    set_inplace_requires_grad_allowed,
 )
 
 argnums_t = Union[int, Tuple[int, ...]]
@@ -40,7 +41,12 @@
 def _create_differentiable(inps, level=None):
     def create_differentiable(x):
         if isinstance(x, torch.Tensor):
-            return x.requires_grad_()
+            try:
+                set_inplace_requires_grad_allowed(True)
+                return x.requires_grad_()
+            finally:
+                set_inplace_requires_grad_allowed(False)
+
         raise ValueError(f'Thing passed to transform API must be Tensor, '
                          f'got {type(x)}')
     return tree_map(create_differentiable, inps)
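_create_differentiable is how functorch's own transforms (grad, vjp, etc.) mark their inputs as differentiable, so it temporarily re-allows the in-place requires_grad_() that user code is now blocked from calling. Below is a minimal sketch of the same guard written as a context manager; the enable_inplace_requires_grad helper is hypothetical and the functorch._C import path is an assumption (the diff only shows the symbol being added to an existing import block), while the commit itself uses the explicit try/finally above:

from contextlib import contextmanager

import torch
from functorch._C import set_inplace_requires_grad_allowed  # assumed module path

@contextmanager
def enable_inplace_requires_grad():
    # Mirror the try/finally in _create_differentiable: lift the restriction,
    # then always restore the default, even if requires_grad_() throws.
    set_inplace_requires_grad_allowed(True)
    try:
        yield
    finally:
        set_inplace_requires_grad_allowed(False)

def create_differentiable(x):
    if isinstance(x, torch.Tensor):
        with enable_inplace_requires_grad():
            return x.requires_grad_()
    raise ValueError(f'Thing passed to transform API must be Tensor, got {type(x)}')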

Diff for: functorch/csrc/DynamicLayer.cpp

+21 -2

@@ -102,13 +102,26 @@ class FuncTorchTLS : public FuncTorchTLSBase {
   }
 
   void checkSupportsInplaceRequiresGrad() const override {
-    // Does nothing
+    TORCH_CHECK(dynamicLayerStack.size() == 0 || allow_inplace_requires_grad_,
+        "You are attempting to call Tensor.requires_grad_() (or perhaps using ",
+        "torch.autograd.functional.* APIs) inside of a function being transformed ",
+        "by a functorch transform. ",
+        "This is unsupported, please attempt to use the functorch transforms ",
+        "(e.g. grad, vjp, jacrev, jacfwd, hessian) or call requires_grad_() "
+        "outside of a function being transformed instead.");
   }
   void checkSupportsRetainGrad() const override {
-    // Does nothing
+    TORCH_CHECK(dynamicLayerStack.size() == 0,
+        "You are attempting to call Tensor.retain_grad() ",
+        "inside of a function being transformed ",
+        "by a functorch transform. ",
+        "This is unsupported, please attempt to use the functorch transforms ",
+        "(e.g. grad, vjp, jacrev, jacfwd, hessian) or call retain_grad() "
+        "outside of a function being transformed instead.");
   }
 
   std::vector<DynamicLayer> dynamicLayerStack;
+  bool allow_inplace_requires_grad_ = false;
 };
 
 static FuncTorchTLS* getRawFunctorchTLS() {
@@ -122,6 +135,12 @@ static FuncTorchTLS* getRawFunctorchTLS() {
   return result;
 }
 
+void setInplaceRequiresGradAllowed(bool allowed) {
+  auto* functorch_tls = getRawFunctorchTLS();
+  functorch_tls->allow_inplace_requires_grad_ = allowed;
+}
+
+
 static std::vector<DynamicLayer>& dynamicLayerStackAccessor() {
   return getRawFunctorchTLS()->dynamicLayerStack;
 }
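Both checks key off whether the DynamicLayerStack is non-empty, so code running outside of any transform is unaffected; inside a transform, retain_grad() is rejected unconditionally, while requires_grad_() can be re-allowed internally via the new allow_inplace_requires_grad_ flag. An illustrative snippet (not part of the commit) of the retain_grad() side:

import torch
from functorch import grad

# No transform active: the stack is empty and both calls work as in stock PyTorch.
x = torch.randn(3)
x.requires_grad_()
y = x.sin()
y.retain_grad()

def f(x):
    y = x.sin()
    y.retain_grad()  # a grad layer is on the stack -> checkSupportsRetainGrad fires
    return y.sum()

grad(f)(torch.randn(3))  # RuntimeError mentioning Tensor.retain_grad()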

Diff for: functorch/csrc/DynamicLayer.h

+2
@@ -85,6 +85,8 @@ Tensor unwrapIfDead(const Tensor& tensor);
 std::ostream& operator<<(std::ostream& os, const DynamicLayer& layer);
 std::ostream& operator<<(std::ostream& os, const std::vector<DynamicLayer>& dynamicLayerStack);
 
+void setInplaceRequiresGradAllowed(bool allowed);
+
 
 }
 } // namespace at

Diff for: functorch/csrc/init.cpp

+1
@@ -379,6 +379,7 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
   m.def("_set_vmap_fallback_warning_enabled", &at::functorch::setVmapFallbackWarningEnabled, "Set vmap fallback warnings");
   m.def("_set_vmap_fallback_enabled", &at::functorch::setVmapFallbackEnabled);
   m.def("_is_vmap_fallback_enabled", &at::functorch::isVmapFallbackEnabled);
+  m.def("set_inplace_requires_grad_allowed", &at::functorch::setInplaceRequiresGradAllowed);
   m.def("dlevel", &at::functorch::dlevel, "dlevel");
   m.def("dump_tensor", &at::functorch::dump_tensor, "dump_tensor");
   m.def("reshape_dim_into", &at::functorch::reshape_dim_into);
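This binding is what the new Python-side import in eager_transforms.py resolves to. Assuming the extension module is exposed as functorch._C (the diff does not spell out the import path), Python usage is simply:

from functorch._C import set_inplace_requires_grad_allowed  # assumed module path

set_inplace_requires_grad_allowed(True)   # lift the check for functorch internals
set_inplace_requires_grad_allowed(False)  # restore the default user-facing behavior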

Diff for: test/test_eager_transforms.py

+87
@@ -2157,6 +2157,93 @@ def f(x):
         new_cotangent = torch.randn(())
         self.assertEqual(fx_f(new_cotangent, True, True), vjp_fn(new_cotangent))
 
+    def test_requires_grad_inside_transform(self, device):
+        def f(x):
+            x.requires_grad_()
+            return x.sin().sum()
+
+        x = torch.randn(3)
+
+        with self.assertRaisesRegex(RuntimeError, "Tensor.requires_grad_()"):
+            vmap(f)(x)
+        with self.assertRaisesRegex(RuntimeError, "Tensor.requires_grad_()"):
+            grad(f)(x)
+        with self.assertRaisesRegex(RuntimeError, "Tensor.requires_grad_()"):
+            vmap(grad(f))(x)
+
+        x = torch.randn([])
+        with self.assertRaisesRegex(RuntimeError, "Tensor.requires_grad_()"):
+            grad(grad(f))(x)
+
+    def test_retain_grad_inside_transform(self, device):
+        def f(x):
+            y = x.sin()
+            y.retain_grad()
+            return y.sum()
+
+        x = torch.randn(3)
+
+        with self.assertRaisesRegex(RuntimeError, "Tensor.retain_grad()"):
+            grad(f)(x)
+
+    def test_autograd_functional_jacrev_inside_transform(self, device):
+        def f(x):
+            y = torch.autograd.functional.jacobian(lambda x: x.sin().sum(), x)
+            return y
+
+        B = 5
+        x = torch.randn(B, 3)
+        with self.assertRaisesRegex(RuntimeError, "torch.autograd.functional"):
+            vmap(f)(x)
+
+        x = torch.randn([])
+        with self.assertRaisesRegex(RuntimeError, "torch.autograd.functional"):
+            grad(f)(x)
+
+    def test_autograd_functional_vjp_inside_transform(self, device):
+        def f(x):
+            y = torch.autograd.functional.vjp(lambda x: x.sin().sum(), x)
+            return y
+
+        B = 5
+        x = torch.randn(B, 3)
+        with self.assertRaisesRegex(RuntimeError, "torch.autograd.functional"):
+            vmap(f)(x)
+
+        x = torch.randn([])
+        with self.assertRaisesRegex(RuntimeError, "torch.autograd.functional"):
+            grad(f)(x)
+
+    def test_autograd_functional_jvp_inside_transform(self, device):
+        def f(x):
+            t = torch.ones_like(x)
+            y = torch.autograd.functional.jvp(lambda x: x.sin().sum(), (x,), (t,))
+            return y
+
+        B = 5
+        x = torch.randn(B, 3)
+        with self.assertRaisesRegex(RuntimeError, "torch.autograd.functional"):
+            vmap(f)(x)
+
+        x = torch.randn([])
+        with self.assertRaisesRegex(RuntimeError, "torch.autograd.functional"):
+            grad(f)(x)
+
+    def test_autograd_functional_jacfwd_inside_transform(self, device):
+        def f(x):
+            y = torch.autograd.functional.jacobian(
+                lambda x: x.sin().sum(), x, strategy='forward-mode', vectorize=True)
+            return y
+
+        B = 5
+        x = torch.randn(B, 3)
+        with self.assertRaises(RuntimeError):
+            vmap(f)(x)
+
+        x = torch.randn([])
+        with self.assertRaises(RuntimeError):
+            grad(f)(x)
+
 
 class TestMakeFunctional(TestCase):
     @parametrize('disable_autograd_tracking', [True, False])
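The recommended migration for the torch.autograd.functional cases above is to the functorch equivalents. A sketch (not part of the commit) of rewriting the batched-Jacobian pattern from test_autograd_functional_jacrev_inside_transform with jacrev:

import torch
from functorch import vmap, jacrev

B = 5
x = torch.randn(B, 3)

# Errors after this commit:
#   vmap(lambda xi: torch.autograd.functional.jacobian(lambda x: x.sin().sum(), xi))(x)

# functorch equivalent: per-sample Jacobians of the same scalar function.
per_sample_jac = vmap(jacrev(lambda xi: xi.sin().sum()))(x)
print(per_sample_jac.shape)  # torch.Size([5, 3])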
