
Fixes pytorch/xla#7398 (#9047)


Open · wants to merge 10 commits into master
1 change: 0 additions & 1 deletion torchax/test/test_ops.py
@@ -13,7 +13,6 @@
"_segment_reduce",
"bincount", # NOTE: dtype for int input torch gives float. This is weird.
"byte",
"cat",
"cholesky_solve",
"diagonal_copy",
"geqrf",
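With "cat" removed from the skip list above, the op-info tests for torch.cat run again. As an illustrative sketch (assuming pytest is used as the runner; the -k filter is just one way to select these tests):

    pytest torchax/test/test_ops.py -k cat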
5 changes: 5 additions & 0 deletions torchax/torchax/ops/jaten.py
@@ -1248,6 +1248,11 @@ def _aten_relu(self):

 @op(torch.ops.aten.cat)
 def _aten_cat(tensors, dims=0):
+  # Special case: torch.cat silently ignores a zero-element 1-D tensor,
+  # while jnp.concatenate raises an error when dims > 0, so drop an
+  # empty first tensor before concatenating.
+  if len(tensors) > 0 and tensors[0].ndim == 1 and tensors[0].shape[0] == 0:
+    return jnp.concatenate(tensors[1:], dims)
   return jnp.concatenate(tensors, dims)


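For context, a minimal standalone sketch of the behavior gap the new branch works around, using plain torch and jax.numpy outside torchax (shapes are illustrative):

    import torch
    import jax.numpy as jnp

    x = torch.ones(2, 3)
    empty = torch.empty(0)  # zero-element 1-D tensor

    # torch.cat skips the empty tensor even though its ndim does not match:
    print(torch.cat([empty, x], dim=1).shape)  # torch.Size([2, 3])

    # jax.numpy has no such special case: the 1-D array has no axis 1,
    # so concatenating along dims > 0 raises (the exact exception type
    # may vary across JAX versions):
    try:
        jnp.concatenate([jnp.zeros(0), jnp.ones((2, 3))], 1)
    except Exception as e:
        print(type(e).__name__, e)

One caveat worth noting: if tensors held only that single empty tensor, tensors[1:] would be empty and jnp.concatenate would still raise, whereas torch.cat returns the empty tensor, so the added branch assumes at least one non-empty input follows.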