Skip to content

Commit 1181a6a

Browse files
committed
Fix the Op info test of cat
1 parent 14256e6 commit 1181a6a

File tree

1 file changed

+4
-1
lines changed

1 file changed

+4
-1
lines changed

torchax/test/test_ops.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,6 @@
1414
"_segment_reduce",
1515
"bincount", # NOTE: dtype for int input torch gives float. This is weird.
1616
"byte",
17-
"cat",
1817
"cholesky_solve",
1918
"diagonal_copy",
2019
"geqrf",
@@ -204,6 +203,10 @@ def test_reference_eager(self, device, dtype, op):
204203

205204
# print("[DEBUG] sample_input: ", sample_input)
206205

206+
if op.name == "cat":
207+
dim = sample_input.kwargs.get('dim')
208+
if dim and dim >= sample_input.input[0].dim():
209+
continue
207210
# TODO: this is a workaround to skip int64 cast for linspace
208211
# reference: https://github.com/pytorch/xla/issues/7505#issuecomment-2400895692 and subsequent comments
209212
# we have opened a bug in pytorch: https://github.com/pytorch/pytorch/issues/137546

0 commit comments

Comments
 (0)