Skip to content

Commit 43e7cf1

Browse files
pssrawat authored and facebook-github-bot committed
Don't assert for symbolic stride in dim_order_from_stride() (#15472)
Summary: Currently dim_order_from_stride() checks if any stride is 0, and fails if so. N7613577 shows a min repro from a factorized joiner use case with a symbolic stride, where this check fails: P2015309933. This failure is blocking us from migrating live translation models to ExecuTorch. This diff fixes the blocker by skipping the assert for symbolic strides. Reviewed By: angelayi Differential Revision: D85875885
1 parent be8b775 commit 43e7cf1

File tree

3 files changed

+24
-3
lines changed

3 files changed

+24
-3
lines changed

exir/tensor.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -67,12 +67,12 @@ def dim_order_from_stride(stride: Tuple[int]) -> Tuple[bytes]:
6767
Another example is: sizes = (1, 3, 1, 1) with strides = (3, 1, 3, 3), returned
6868
value is (0, 2, 3, 1)
6969
"""
70+
from torch.fx.experimental.symbolic_shapes import guard_size_oblivious, guard_or_false
71+
7072
for _, s in enumerate(stride):
71-
if s == 0:
73+
if guard_or_false(s == 0):
7274
raise ValueError("0 in strides is not supported for ExecuTorch.")
7375

74-
from torch.fx.experimental.symbolic_shapes import guard_size_oblivious
75-
7676
class K(NamedTuple):
7777
stride: int
7878

exir/tests/TARGETS

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -385,6 +385,7 @@ python_unittest(
385385
deps = [
386386
"//caffe2:torch",
387387
"//executorch/exir:dim_order_utils",
388+
"//executorch/exir:lib",
388389
],
389390
)
390391

exir/tests/test_dim_order_utils.py

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@
88
import unittest
99

1010
import torch
11+
from executorch.exir import to_edge_transform_and_lower
1112
from executorch.exir.dim_order_utils import get_dim_order, get_memory_format
1213

1314

@@ -27,3 +28,22 @@ def test_get_dim_order(self) -> None:
2728
list(range(ndim)), get_dim_order(torch.contiguous_format, ndim)
2829
)
2930
self.assertEqual([0, 2, 3, 1], get_dim_order(torch.channels_last, 4))
31+
32+
def test_dim_order_from_stride(self):
33+
class Test(torch.nn.Module):
34+
def __init__(self):
35+
super().__init__()
36+
37+
def forward(self, t1, t2):
38+
idx = torch.nonzero(t1).reshape(-1)
39+
y = torch.index_select(t2, 0, idx)
40+
return y
41+
42+
M = Test()
43+
x = torch.tensor([0, 1, 1, 0, 1], dtype=torch.bool)
44+
y = torch.randn(5, 6)
45+
M(x, y)
46+
47+
expo_prog = torch.export.export_for_training(M, (x, y))
48+
edge_prog = to_edge_transform_and_lower(expo_prog)
49+
edge_prog.to_executorch()

0 commit comments

Comments
 (0)