Skip to content

Commit bb673fb

Browse files
srossross authored and pytorchmergebot committed
fix: update error when tensor escapes vmap (pytorch#89077)
Fixes pytorch/functorch#1054 @zou3519, I played around with it, but I am unsure of how to repro the cases for gen_vmap_inplace_plumbing and below in gen_vmap_plumbing_no_returns I've also seen that there are 24 other instances of the `TORCH_INTERNAL_ASSERT(maybe_layer.has_value());` assert, should I change all of these and add tests? Pull Request resolved: pytorch#89077 Approved by: https://github.com/zou3519
1 parent 2c2cce7 commit bb673fb

File tree

5 files changed

+49
-7
lines changed

5 files changed

+49
-7
lines changed

aten/src/ATen/functorch/BatchRulesHelper.h

Lines changed: 8 additions & 4 deletions
Original file line number | Diff line number | Diff line change
@@ -3,6 +3,9 @@
33
//
44
// This source code is licensed under the BSD-style license found in the
55
// LICENSE file in the root directory of this source tree.
6+
#pragma once
7+
8+
#include <c10/util/TypeList.h>
69

710
#include <ATen/ATen.h>
811
#include <ATen/Operators.h>
@@ -65,7 +68,7 @@ template <typename A, A a, typename C>
6568
struct BasicUnaryBatchRuleHelper;
6669

6770
template <typename F, F Func, typename A, typename... T>
68-
struct BasicUnaryBatchRuleHelper<F, Func, typelist<A, T...>> {
71+
struct BasicUnaryBatchRuleHelper<F, Func, c10::guts::typelist::typelist<A, T...>> {
6972
static std::tuple<Tensor,optional<int64_t>> apply(
7073
const Tensor& tensor,
7174
optional<int64_t> batch_dim,
@@ -90,7 +93,7 @@ template <typename A, A a, typename C>
9093
struct VariadicBdimsBatchRuleHelper;
9194

9295
template <typename F, F Func, typename A, typename... T>
93-
struct VariadicBdimsBatchRuleHelper<F, Func, typelist<A, T...>> {
96+
struct VariadicBdimsBatchRuleHelper<F, Func, c10::guts::typelist::typelist<A, T...>> {
9497
static std::tuple<Tensor,optional<int64_t>> apply(
9598
const Tensor& tensor,
9699
optional<int64_t> batch_dim,
@@ -123,7 +126,8 @@ void boxed_tensor_inputs_batch_rule(const c10::OperatorHandle& op, torch::jit::S
123126

124127
c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched);
125128
auto maybe_layer = maybeCurrentDynamicLayer();
126-
TORCH_INTERNAL_ASSERT(maybe_layer.has_value());
129+
vmap_check_escaped(maybe_layer, "boxed_tensor_inputs_batch_rule");
130+
127131
int64_t cur_level = maybe_layer->layerId();
128132

129133
auto orig_arguments = torch::jit::last(*stack, num_arguments);
@@ -379,7 +383,7 @@ template <typename A, A a, typename C>
379383
struct ExistingBdimBatchRuleHelper;
380384

381385
template <typename F, F Func, typename A, typename... T>
382-
struct ExistingBdimBatchRuleHelper<F, Func, typelist<A, T...>> {
386+
struct ExistingBdimBatchRuleHelper<F, Func, c10::guts::typelist::typelist<A, T...>> {
383387
static std::tuple<Tensor,optional<int64_t>> apply(
384388
const Tensor& self,
385389
optional<int64_t> self_bdim,

aten/src/ATen/functorch/PlumbingHelper.cpp

Lines changed: 11 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -10,6 +10,17 @@
1010

1111
namespace at { namespace functorch {
1212

13+
void vmap_check_escaped(const optional<DynamicLayer> &layer, const char* what) {
14+
TORCH_CHECK(
15+
layer.has_value(),
16+
"Either your tensor may have escaped from inside a function being vmapped and this is a user error ",
17+
"(see https://pytorch.org/functorch/stable/ux_limitations.html), "
18+
"or there is an internal functorch error in `",
19+
what,
20+
"` Please file an issue if it looks like the latter"
21+
)
22+
}
23+
1324
Tensor makeBatched(const Tensor& tensor, optional<int64_t> bdim, int64_t level) {
1425
if (bdim.has_value()) {
1526
TORCH_INTERNAL_ASSERT(*bdim >= 0);

aten/src/ATen/functorch/PlumbingHelper.h

Lines changed: 2 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -26,6 +26,8 @@
2626

2727
namespace at { namespace functorch {
2828

29+
void vmap_check_escaped(const optional<DynamicLayer> &layer, const char* what);
30+
2931
// Create a BatchedTensor given a tensor, bdim, and level
3032
TORCH_API Tensor makeBatched(const Tensor& tensor, optional<int64_t> bdim, int64_t level);
3133

test/functorch/test_vmap.py

Lines changed: 25 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -3920,6 +3920,31 @@ def test_vmap_multi_dot_failure_1D_input(self):
39203920
with self.assertRaisesRegex(RuntimeError, "tensor 1 must be 2D but got 1D"):
39213921
return vmap(torch.linalg.multi_dot)(inputs)
39223922

3923+
def test_vmap_escaped_error(self):
3924+
escaped = None
3925+
3926+
def f(x):
3927+
nonlocal escaped
3928+
escaped = x
3929+
return x ** 2
3930+
3931+
x = torch.randn(3)
3932+
vmap(f)(x)
3933+
3934+
common_message = r"your tensor may have escaped from inside a function being vmapped.*{0}.*"
3935+
3936+
with self.assertRaisesRegex(RuntimeError, common_message.format("gen_vmap_plumbing")):
3937+
escaped.sin()
3938+
3939+
with self.assertRaisesRegex(RuntimeError, common_message.format("boxed_tensor_inputs_batch_rule")):
3940+
escaped.sin_()
3941+
3942+
with self.assertRaisesRegex(RuntimeError, common_message.format("gen_vmap_inplace_plumbing")):
3943+
escaped.mul_(1)
3944+
3945+
vmap(f)(torch.tensor([[0, 0], [0, 0]], dtype=torch.int))
3946+
with self.assertRaisesRegex(RuntimeError, common_message.format("gen_vmap_plumbing_no_returns")):
3947+
torch.ops.aten._linalg_check_errors(escaped, 'linalg.inv', is_matrix=False)
39233948

39243949
class TestRandomness(TestCase):
39253950
def _reset_random(self, generator, orig_state, use_generator, seed):

torchgen/gen_vmap_plumbing.py

Lines changed: 3 additions & 3 deletions
Original file line number | Diff line number | Diff line change
@@ -167,7 +167,7 @@ def gen_vmap_inplace_plumbing(native_function: NativeFunction) -> Optional[str]:
167167
{sig.decl(name=schema.name.unambiguous_name() + '_generated_plumbing')} {{
168168
c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched);
169169
auto maybe_layer = maybeCurrentDynamicLayer();
170-
TORCH_INTERNAL_ASSERT(maybe_layer.has_value());
170+
vmap_check_escaped(maybe_layer, "gen_vmap_inplace_plumbing");
171171
int64_t {cur_level_var} = maybe_layer->layerId();
172172
{textwrap.indent(bdims_all_none_case, " ")}
173173
{textwrap.indent(unwraps, " ")}
@@ -189,7 +189,7 @@ def gen_vmap_plumbing_no_returns(native_function: NativeFunction) -> str:
189189
{sig.decl(name=schema.name.unambiguous_name() + '_generated_plumbing')} {{
190190
c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched);
191191
auto maybe_layer = maybeCurrentDynamicLayer();
192-
TORCH_INTERNAL_ASSERT(maybe_layer.has_value());
192+
vmap_check_escaped(maybe_layer, "gen_vmap_plumbing_no_returns");
193193
int64_t {cur_level_var} = maybe_layer->layerId();
194194
{textwrap.indent(bdims_all_none_case, " ")}
195195
{textwrap.indent(unwraps, " ")}
@@ -232,7 +232,7 @@ def gen_vmap_plumbing(native_function: NativeFunction) -> Optional[str]:
232232
{sig.decl(name=schema.name.unambiguous_name() + '_generated_plumbing')} {{
233233
c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched);
234234
auto maybe_layer = maybeCurrentDynamicLayer();
235-
TORCH_INTERNAL_ASSERT(maybe_layer.has_value());
235+
vmap_check_escaped(maybe_layer, "gen_vmap_plumbing");
236236
int64_t {cur_level_var} = maybe_layer->layerId();
237237
{textwrap.indent(bdims_all_none_case, " ")}
238238
{textwrap.indent(unwraps, " ")}

0 commit comments

Comments (0)