Skip to content

Commit 8f1c2bc

Browse files
author
Zonglin Peng
committed
jarvis-nightly-operators-test-aten-where-out
Differential Revision: [D85364554](https://our.internmc.facebook.com/intern/diff/D85364554/) ghstack-source-id: 320085176 Pull Request resolved: #15500
1 parent d3897f8 commit 8f1c2bc

File tree

1 file changed

+6
-22
lines changed

1 file changed

+6
-22
lines changed

backends/cadence/utils/facto_util.py

Lines changed: 6 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -189,47 +189,31 @@ def random_size_constraint(deps: object, r: int, d: int) -> int:
189189
if index == 0: # condition
190190
tensor_constraints = [
191191
cp.Dtype.In(lambda deps: [torch.bool]),
192-
cp.Value.Ge(lambda deps, dtype, struct: -(2**4)),
193-
cp.Value.Le(lambda deps, dtype, struct: 2**4),
192+
cp.Value.Ge(lambda deps, dtype, struct: 0),
193+
cp.Value.Le(lambda deps, dtype, struct: 1),
194194
cp.Rank.Ge(lambda deps: 1),
195195
cp.Size.Ge(lambda deps, r, d: 1),
196196
max_size_constraint,
197197
]
198198
elif index == 1: # input tensor(a)
199199
tensor_constraints = [
200-
cp.Dtype.In(
201-
lambda deps: [
202-
torch.int8,
203-
torch.int16,
204-
torch.uint8,
205-
torch.uint16,
206-
torch.int32,
207-
torch.float32,
208-
]
209-
),
200+
cp.Dtype.In(lambda deps: [torch.float32]),
210201
cp.Value.Ge(lambda deps, dtype, struct: -(2**4)),
211202
cp.Value.Le(lambda deps, dtype, struct: 2**4),
212203
cp.Rank.Ge(lambda deps: 1),
213204
cp.Size.Ge(lambda deps, r, d: 1),
205+
cp.Size.In(lambda deps, r, d: fn.broadcast_with(deps[0].shape, r, d)),
214206
max_size_constraint,
215207
]
216208
else: # input tensor(b)
217209
tensor_constraints = [
218-
cp.Dtype.In(
219-
lambda deps: [
220-
torch.int8,
221-
torch.int16,
222-
torch.uint8,
223-
torch.uint16,
224-
torch.int32,
225-
torch.float32,
226-
]
227-
),
210+
cp.Dtype.In(lambda deps: [torch.float32]),
228211
cp.Dtype.Eq(lambda deps: deps[1].dtype),
229212
cp.Value.Ge(lambda deps, dtype, struct: -(2**4)),
230213
cp.Value.Le(lambda deps, dtype, struct: 2**4),
231214
cp.Rank.Ge(lambda deps: 1),
232215
cp.Size.Ge(lambda deps, r, d: 1),
216+
cp.Size.In(lambda deps, r, d: fn.broadcast_with(fn.broadcasted_shape(deps[0].shape, deps[1].shape), r, d)),
233217
max_size_constraint,
234218
]
235219
case "embedding.default":

0 commit comments

Comments
 (0)