
Commit 503cde9

pytorchbot and Zonglin Peng authored
[Jarvis][Nightly] address zero division jarvis-nightly-operators-test-aten-div-out (#15570)
This PR was created by the merge bot to help merge the original PR into the main branch.

ghstack PR number: #15496 by @zonglinpeng
^ Please use this as the source of truth for the PR details, comments, and reviews
ghstack PR base: https://github.com/pytorch/executorch/tree/gh/zonglinpeng/8/base
ghstack PR head: https://github.com/pytorch/executorch/tree/gh/zonglinpeng/8/head
Merge bot PR base: https://github.com/pytorch/executorch/tree/gh/zonglinpeng/7/orig
Merge bot PR head: https://github.com/pytorch/executorch/tree/gh/zonglinpeng/8/orig
Differential Revision: [D85364549](https://our.internmc.facebook.com/intern/diff/D85364549/)
@diff-train-skip-merge

---------

Co-authored-by: Zonglin Peng <[email protected]>
1 parent 691d16e commit 503cde9
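For context on the fix itself: float division by zero yields inf/nan, which is what the nightly aten.div.out comparison trips over when FACTO samples a zero divisor. A minimal sketch, illustrative only and not part of this commit (plain PyTorch plus a hypothetical list of candidate divisor values):

    import torch

    # Illustrative only: why generated divisors must exclude zero.
    a = torch.tensor([1.0, 0.0, -3.0])
    b = torch.tensor([2.0, 0.0, 0.0])
    print(torch.div(a, b))  # tensor([0.5000, nan, -inf])

    # The diff below constrains the divisor the same way this filter does,
    # via cp.Value.Ne(lambda deps, dtype, struct: 0) applied to index == 1.
    candidates = [-4.0, 0.0, 2.5]  # hypothetical sampled values
    print([v for v in candidates if v != 0])  # [-4.0, 2.5]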

File tree

1 file changed: +130, −43 lines


backends/cadence/utils/facto_util.py

Lines changed: 130 additions & 43 deletions
@@ -15,6 +15,7 @@
 import torch
 from facto.inputgen.argtuple.gen import ArgumentTupleGenerator
 from facto.inputgen.specs.model import ConstraintProducer as cp
+from facto.inputgen.utils.random_manager import seeded_random_manager as rm
 from facto.inputgen.variable.type import ScalarDtype
 from facto.specdb.db import SpecDictDB
@@ -26,6 +27,33 @@
 _shape_cache: dict[str, list[int]] = {}


+def _positive_valid_dim_list(tensor: torch.Tensor, length: int) -> set[tuple[int, ...]]:
+    """
+    Generate valid permutations using only positive dimension indices.
+    This is required for Cadence/Xtensa kernels that don't support negative indexing.
+
+    Args:
+        tensor: Input tensor to generate permutations for
+        length: Number of dimensions in the permutation (must equal tensor.dim())
+
+    Returns:
+        Set of valid permutation tuples containing only positive indices [0, rank-1]
+    """
+    if length > tensor.dim():
+        return set()
+
+    n = tensor.dim()
+    pool = list(range(n))
+
+    # Generate multiple valid permutations (only positive indices)
+    permutations: set[tuple[int, ...]] = set()
+    for _ in range(3):  # Generate 3 different permutations for diversity
+        perm = tuple(rm.get_random().sample(pool, length))
+        permutations.add(perm)
+
+    return permutations
+
+
 def apply_tensor_contraints(op_name: str, index: int) -> list[object]:
     # Constraint to limit tensor size to < 4000 bytes with fully randomized shapes
     import random
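Aside: a standalone sketch of what the new _positive_valid_dim_list helper produces. It swaps FACTO's seeded random manager (rm) for Python's random module, so treat it as an approximation of the behavior rather than the exact implementation:

    import random
    import torch

    # Standalone approximation of _positive_valid_dim_list, for illustration only;
    # the real helper draws from facto's seeded_random_manager instead of `random`.
    def positive_valid_dim_list(tensor: torch.Tensor, length: int) -> set:
        if length > tensor.dim():
            return set()
        pool = list(range(tensor.dim()))  # positive indices only: 0 .. rank-1
        return {tuple(random.sample(pool, length)) for _ in range(3)}

    print(positive_valid_dim_list(torch.zeros(2, 3, 4), 3))
    # e.g. {(2, 0, 1), (1, 2, 0), (0, 2, 1)} -- never contains -1, -2, ...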
@@ -161,47 +189,37 @@ def random_size_constraint(deps: object, r: int, d: int) -> int:
             if index == 0:  # condition
                 tensor_constraints = [
                     cp.Dtype.In(lambda deps: [torch.bool]),
-                    cp.Value.Ge(lambda deps, dtype, struct: -(2**4)),
-                    cp.Value.Le(lambda deps, dtype, struct: 2**4),
+                    cp.Value.Ge(lambda deps, dtype, struct: 0),
+                    cp.Value.Le(lambda deps, dtype, struct: 1),
                     cp.Rank.Ge(lambda deps: 1),
                     cp.Size.Ge(lambda deps, r, d: 1),
                     max_size_constraint,
                 ]
             elif index == 1:  # input tensor(a)
                 tensor_constraints = [
-                    cp.Dtype.In(
-                        lambda deps: [
-                            torch.int8,
-                            torch.int16,
-                            torch.uint8,
-                            torch.uint16,
-                            torch.int32,
-                            torch.float32,
-                        ]
-                    ),
+                    cp.Dtype.In(lambda deps: [torch.float32]),
                     cp.Value.Ge(lambda deps, dtype, struct: -(2**4)),
                     cp.Value.Le(lambda deps, dtype, struct: 2**4),
                     cp.Rank.Ge(lambda deps: 1),
                     cp.Size.Ge(lambda deps, r, d: 1),
+                    cp.Size.In(
+                        lambda deps, r, d: fn.broadcast_with(deps[0].shape, r, d)
+                    ),
                     max_size_constraint,
                 ]
             else:  # input tensor(b)
                 tensor_constraints = [
-                    cp.Dtype.In(
-                        lambda deps: [
-                            torch.int8,
-                            torch.int16,
-                            torch.uint8,
-                            torch.uint16,
-                            torch.int32,
-                            torch.float32,
-                        ]
-                    ),
+                    cp.Dtype.In(lambda deps: [torch.float32]),
                     cp.Dtype.Eq(lambda deps: deps[1].dtype),
                     cp.Value.Ge(lambda deps, dtype, struct: -(2**4)),
                     cp.Value.Le(lambda deps, dtype, struct: 2**4),
                     cp.Rank.Ge(lambda deps: 1),
                     cp.Size.Ge(lambda deps, r, d: 1),
+                    cp.Size.In(
+                        lambda deps, r, d: fn.broadcast_with(
+                            fn.broadcasted_shape(deps[0].shape, deps[1].shape), r, d
+                        )
+                    ),
                     max_size_constraint,
                 ]
         case "embedding.default":
@@ -248,6 +266,9 @@ def random_size_constraint(deps: object, r: int, d: int) -> int:
             tensor_constraints.extend(
                 [
                     cp.Dtype.In(lambda deps: [torch.float32, torch.int32]),
+                    # Avoid NaN/Inf values that expose clamp NaN handling bugs
+                    cp.Value.Ge(lambda deps, dtype, struct: -(2**4)),
+                    cp.Value.Le(lambda deps, dtype, struct: 2**4),
                 ]
             )
         case "rsqrt.default":
@@ -323,12 +344,15 @@ def random_size_constraint(deps: object, r: int, d: int) -> int:
                 ]
             )
         case "constant_pad_nd.default":
-            tensor_constraints.extend(
-                [
-                    cp.Dtype.In(lambda deps: [torch.float32]),
-                    cp.Size.Le(lambda deps, r, d: 2**2),
-                ]
-            )
+            tensor_constraints = [
+                cp.Dtype.In(lambda deps: [torch.float32]),
+                cp.Value.Ge(lambda deps, dtype, struct: -(2**4)),
+                cp.Value.Le(lambda deps, dtype, struct: 2**4),
+                cp.Rank.Ge(lambda deps: 1),
+                cp.Rank.Le(lambda deps: 2),  # Reduced from 3 to 2 (max 2D tensors)
+                cp.Size.Ge(lambda deps, r, d: 1),
+                cp.Size.Le(lambda deps, r, d: 3),  # Max dimension size of 3
+            ]
         case "avg_pool2d.default":
             tensor_constraints.extend(
                 [
@@ -344,14 +368,25 @@ def random_size_constraint(deps: object, r: int, d: int) -> int:
                 ]
             )
         case "div.Tensor":
-            tensor_constraints.extend(
-                [
-                    cp.Value.Ne(lambda deps, dtype, struct: 0),
-                    cp.Value.Le(lambda deps, dtype, struct: 2**3),
-                    cp.Size.Le(lambda deps, r, d: 2**3),
-                    cp.Rank.Le(lambda deps: 2**2),
-                ]
-            )
+            if index == 1:  # Only apply zero-prevention to divisor
+                tensor_constraints.extend(
+                    [
+                        cp.Value.Ne(
+                            lambda deps, dtype, struct: 0
+                        ),  # Prevent division by zero
+                        cp.Value.Le(lambda deps, dtype, struct: 2**3),
+                        cp.Size.Le(lambda deps, r, d: 2**3),
+                        cp.Rank.Le(lambda deps: 2**2),
+                    ]
+                )
+            else:
+                tensor_constraints.extend(
+                    [
+                        cp.Value.Le(lambda deps, dtype, struct: 2**3),
+                        cp.Size.Le(lambda deps, r, d: 2**3),
+                        cp.Rank.Le(lambda deps: 2**2),
+                    ]
+                )
         case "pow.Tensor_Scalar":
             tensor_constraints.extend(
                 [
@@ -405,6 +440,12 @@ def random_size_constraint(deps: object, r: int, d: int) -> int:
                     cp.Size.Le(lambda deps, r, d: 2**2),
                 ]
             )
+        case "flip.default":
+            tensor_constraints.extend(
+                [
+                    cp.Dtype.In(lambda deps: [torch.float32]),
+                ]
+            )
         case _:
             pass
     return tensor_constraints
@@ -418,6 +459,7 @@ def apply_scalar_contraints(op_name: str) -> list[ScalarDtype]:
             | "mul.Scalar"
             | "div.Scalar"
             | "constant_pad_nd.default"
+            | "clamp.default"
         ):
             return [ScalarDtype.int]
         case "full.default":
@@ -445,7 +487,32 @@ def facto_testcase_gen(  # noqa: C901
                         cp.Size.Le(lambda deps, r, d: 2**2),
                     ]
                 )
-                if in_spec.name == "max_val":  # hardtanh
+                # Special handling for clamp.default to ensure min < max with sufficient gap (at least 2) and never None
+                if op_name == "clamp.default":
+                    if in_spec.name == "min":
+                        # min must always be provided (not None) and bounded, leave room for max
+                        spec.inspec[index].constraints.extend(
+                            [
+                                cp.Optional.Eq(lambda deps: False),  # Never None
+                                cp.Value.Ge(lambda deps, dtype: -(2**4)),
+                                cp.Value.Le(
+                                    lambda deps, dtype: 2**4 - 2
+                                ),  # Leave room for max (at least 2 units)
+                            ]
+                        )
+                    elif in_spec.name == "max":
+                        # max must always be provided (not None), be >= min + 2 (sufficient gap), and bounded
+                        spec.inspec[index].deps = [0, 1]  # deps on input tensor and min
+                        spec.inspec[index].constraints.extend(
+                            [
+                                cp.Optional.Eq(lambda deps: False),  # Never None
+                                cp.Value.Ge(
+                                    lambda deps, dtype: deps[1] + 2
+                                ),  # max >= min + 2 (sufficient gap)
+                                cp.Value.Le(lambda deps, dtype: 2**4),
+                            ]
+                        )
+                elif in_spec.name == "max_val":  # hardtanh
                     spec.inspec[index].deps = [0, 1]
                     spec.inspec[index].constraints.extend(
                         [cp.Value.Ge(lambda deps, _: deps[1])]
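Aside: a quick arithmetic check (plain Python, illustrative only) that the clamp bounds above always leave a valid window. min ranges over [-16, 14], max is forced to at least min + 2 and at most 16, so every generated (min, max) pair is non-None and at least 2 apart:

    # Sanity check of the clamp scalar bounds used above (illustrative only).
    MIN_LO, MIN_HI = -(2**4), 2**4 - 2  # min sampled from [-16, 14]
    MAX_HI = 2**4                       # max capped at 16

    for lo in range(MIN_LO, MIN_HI + 1):
        assert lo + 2 <= MAX_HI         # max >= min + 2 always has room
    print("every min in [-16, 14] admits a max in [min + 2, 16]")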
@@ -482,12 +549,32 @@ def facto_testcase_gen(  # noqa: C901
                 apply_tensor_contraints(op_name, index)
             )
         elif in_spec.type.is_dim_list():
-            spec.inspec[index].constraints.extend(
-                [
-                    cp.Length.Ge(lambda deps: 1),
-                    cp.Optional.Eq(lambda deps: False),
-                ]
-            )
+            # Special handling for permute_copy.default to ensure valid permutation
+            if op_name == "permute_copy.default":
+                spec.inspec[index].constraints.extend(
+                    [
+                        cp.Length.Ge(lambda deps: 1),
+                        cp.Length.Eq(
+                            lambda deps: deps[0].dim()
+                        ),  # Must be a complete permutation
+                        cp.Optional.Eq(lambda deps: False),
+                        # Generate valid permutations using only positive indices
+                        # Cadence/Xtensa hardware kernels do not support negative dimension indices
+                        cp.Value.Gen(
+                            lambda deps, length: (
+                                _positive_valid_dim_list(deps[0], length),
+                                fn.invalid_dim_list(deps[0], length),
+                            )
+                        ),
+                    ]
+                )
+            else:
+                spec.inspec[index].constraints.extend(
+                    [
+                        cp.Length.Ge(lambda deps: 1),
+                        cp.Optional.Eq(lambda deps: False),
+                    ]
+                )
         elif in_spec.type.is_bool():
             spec.inspec[index].constraints.extend(
                 [
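Aside: what the permute_copy dim-list constraints above guarantee, in plain PyTorch terms (illustrative, not from this commit). Eager permute accepts either spelling, but the positive-only generator keeps the negative form out of the Cadence test inputs:

    import torch

    # Illustrative only: a complete permutation with positive indices, as the
    # constraints above require, versus the negative spelling they avoid.
    x = torch.zeros(2, 3, 4)
    print(x.permute(2, 0, 1).shape)     # torch.Size([4, 2, 3])
    print(x.permute(-1, -3, -2).shape)  # same result in eager, but uses negative dims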

0 commit comments
