From 1181a6a59bc801fce09f91457e9baeeca6abcd1a Mon Sep 17 00:00:00 2001 From: Qianmin Jiang Date: Sat, 26 Apr 2025 06:37:15 +0000 Subject: [PATCH 1/2] Fix the Op info test of cat --- torchax/test/test_ops.py | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/torchax/test/test_ops.py b/torchax/test/test_ops.py index 001e2eb8bc6b..5aafe5f1e404 100644 --- a/torchax/test/test_ops.py +++ b/torchax/test/test_ops.py @@ -14,7 +14,6 @@ "_segment_reduce", "bincount", # NOTE: dtype for int input torch gives float. This is weird. "byte", - "cat", "cholesky_solve", "diagonal_copy", "geqrf", @@ -204,6 +203,10 @@ def test_reference_eager(self, device, dtype, op): # print("[DEBUG] sample_input: ", sample_input) + if op.name == "cat": + dim = sample_input.kwargs.get('dim') + if dim and dim >= sample_input.input[0].dim(): + continue # TODO: this is a workaround to skip int64 cast for linspace # reference: https://github.com/pytorch/xla/issues/7505#issuecomment-2400895692 and subsequent comments # we have opened a bug in pytorch: https://github.com/pytorch/pytorch/issues/137546 From 7e8482e0deaded04c39fca371c1fc327b6c8837d Mon Sep 17 00:00:00 2001 From: Qianmin Jiang Date: Sat, 26 Apr 2025 06:37:15 +0000 Subject: [PATCH 2/2] Fix the Op info test of cat --- torchax/test/test_ops.py | 1 - torchax/torchax/ops/jaten.py | 5 +++++ 2 files changed, 5 insertions(+), 1 deletion(-) diff --git a/torchax/test/test_ops.py b/torchax/test/test_ops.py index 001e2eb8bc6b..2cfafa62a5b5 100644 --- a/torchax/test/test_ops.py +++ b/torchax/test/test_ops.py @@ -14,7 +14,6 @@ "_segment_reduce", "bincount", # NOTE: dtype for int input torch gives float. This is weird. "byte", - "cat", "cholesky_solve", "diagonal_copy", "geqrf", diff --git a/torchax/torchax/ops/jaten.py b/torchax/torchax/ops/jaten.py index c47068165949..6e7df1591dc6 100644 --- a/torchax/torchax/ops/jaten.py +++ b/torchax/torchax/ops/jaten.py @@ -1198,6 +1198,11 @@ def _aten_relu(self): @op(torch.ops.aten.cat) def _aten_cat(tensors, dims=0): + # handle it as a special case if the first tensor is empty. + # torch.cat will ignore the empty tensor, while jnp.concatenate + # will error if the dims > 0. + if tensors[0].ndim == 1 and tensors[0].shape[0] == 0: + return jnp.concatenate(tensors[1:], dims) return jnp.concatenate(tensors, dims)