Skip to content

Commit 002f27a

Browse files
authored
[torchax] Fixes test for kthvalue #7458 (#9223)
1 parent a687881 commit 002f27a

File tree

2 files changed

+18
-2
lines changed

2 files changed

+18
-2
lines changed

torchax/test/test_ops.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,6 @@
1919
"histogram", # hard op: AssertionError: Tensor-likes are not close!
2020
"histogramdd", # TypeError: histogram requires ndarray or scalar arguments, got <class 'list'> at position 1.
2121
"index_reduce",
22-
"kthvalue",
2322
"linalg.ldl_solve",
2423
"max_pool2d_with_indices_backward",
2524
"nn.functional.adaptive_max_pool1d",
@@ -175,7 +174,7 @@ def run_export_and_compare(testcase,
175174
# Sort related ops should ignore index;
176175
# For example: sort( [1, 0, 0]) -> [0, 0, 1]
177176
# the correct index can be [1, 2, 0] or [2, 1, 0]
178-
should_ignore_indexes = {"topk", "mode"}
177+
should_ignore_indexes = {"topk", "mode", "kthvalue"}
179178

180179

181180
class TestOpInfo(TestCase):

torchax/torchax/ops/jaten.py

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5496,3 +5496,20 @@ def linear(input, weight, bias=None):
54965496
if bias is not None:
54975497
res += bias
54985498
return res
5499+
5500+
5501+
@op(torch.ops.aten.kthvalue)
5502+
def kthvalue(input, k, dim=None, keepdim=False, *, out=None):
5503+
if input.ndim == 0:
5504+
return input, jnp.array(0)
5505+
dimension = -1
5506+
if dim is not None:
5507+
dimension = dim
5508+
while dimension < 0:
5509+
dimension = dimension + input.ndim
5510+
values = jax.lax.index_in_dim(
5511+
jnp.partition(input, k - 1, dimension), k - 1, dimension, keepdim)
5512+
indices = jax.lax.index_in_dim(
5513+
jnp.argpartition(input, k - 1, dimension).astype('int64'), k - 1,
5514+
dimension, keepdim)
5515+
return values, indices

0 commit comments

Comments
 (0)