From 065668403d07a1adc1e97254ca7f9a73bc39c34b Mon Sep 17 00:00:00 2001 From: Victor Garcia Reolid Date: Thu, 20 Mar 2025 23:00:23 +0100 Subject: [PATCH 1/8] feat: support numba compiled sort and argsort functions Signed-off-by: Victor Garcia Reolid --- pytensor/link/numba/dispatch/basic.py | 41 +++++++++++++++++++++++++++ tests/link/numba/test_basic.py | 15 ++++++++++ 2 files changed, 56 insertions(+) diff --git a/pytensor/link/numba/dispatch/basic.py b/pytensor/link/numba/dispatch/basic.py index 19e91e5f8e..a5322cc442 100644 --- a/pytensor/link/numba/dispatch/basic.py +++ b/pytensor/link/numba/dispatch/basic.py @@ -37,6 +37,7 @@ from pytensor.tensor.math import Dot from pytensor.tensor.shape import Reshape, Shape, Shape_i, SpecifyShape from pytensor.tensor.slinalg import Solve +from pytensor.tensor.sort import ArgSortOp, SortOp from pytensor.tensor.type import TensorType from pytensor.tensor.type_other import MakeSlice, NoneConst @@ -432,6 +433,46 @@ def shape_i(x): return shape_i +@numba_funcify.register(SortOp) +def numba_funcify_SortOp(op, node, **kwargs): + if op.kind == "quicksort": + + @numba_njit + def sort_f(a, axis): + return np.sort(a) # numba supports sort without arguments + else: + ret_sig = get_numba_type(node.outputs[0].type) + + def sort_f(a, axis): + with numba.objmode(ret=ret_sig): + ret = np.sort(a, axis=axis, kind=op.kind) + return ret + + return sort_f + + +@numba_funcify.register(ArgSortOp) +def numba_funcify_ArgSortOp(op, node, **kwargs): + def argsort_f_kind(kind): + @numba_njit + def argsort_f(a, axis): + return np.argsort(a, kind=kind) + + return argsort_f + + if op.kind in ["quicksort", "mergesort"]: + return argsort_f_kind(op.kind) + else: + ret_sig = get_numba_type(node.outputs[0].type) + + def argsort_f(a, axis): + with numba.objmode(ret=ret_sig): + ret = np.argsort(a, axis=axis, kind=op.kind) + return ret + + return argsort_f + + @numba.extending.intrinsic def direct_cast(typingctx, val, typ): if isinstance(typ, numba.types.TypeRef): diff --git a/tests/link/numba/test_basic.py b/tests/link/numba/test_basic.py index 654cbe7bd4..606d2114eb 100644 --- a/tests/link/numba/test_basic.py +++ b/tests/link/numba/test_basic.py @@ -33,6 +33,7 @@ from pytensor.tensor import blas from pytensor.tensor.elemwise import Elemwise from pytensor.tensor.shape import Reshape, Shape, Shape_i, SpecifyShape +from pytensor.tensor.sort import ArgSortOp, SortOp if TYPE_CHECKING: @@ -378,6 +379,20 @@ def test_Shape(x, i): compare_numba_and_py([], [g], []) +@pytest.mark.parametrize("kind", ["quicksort"]) +def test_Sort(kind): + x = [5, 4, 3, 2, 1] + g = SortOp(kind)(pt.as_tensor_variable(x)) + compare_numba_and_py([], [g], []) + + +@pytest.mark.parametrize("kind", ["quicksort", "mergesort"]) +def test_ArgSort(kind): + x = [5, 4, 3, 2, 1] + g = ArgSortOp(kind)(pt.as_tensor_variable(x)) + compare_numba_and_py([], [g], []) + + @pytest.mark.parametrize( "v, shape, ndim", [ From 5aa1a3930e89553c0f8a79420fe1aea1864638c5 Mon Sep 17 00:00:00 2001 From: Victor Garcia Reolid Date: Fri, 21 Mar 2025 19:34:22 +0100 Subject: [PATCH 2/8] default to supported kind and add warning Signed-off-by: Victor Garcia Reolid --- pytensor/link/numba/dispatch/basic.py | 43 ++++++++++++++------------- tests/link/numba/test_basic.py | 39 ++++++++++++++++++++---- 2 files changed, 57 insertions(+), 25 deletions(-) diff --git a/pytensor/link/numba/dispatch/basic.py b/pytensor/link/numba/dispatch/basic.py index a5322cc442..ee1fd06b74 100644 --- a/pytensor/link/numba/dispatch/basic.py +++ b/pytensor/link/numba/dispatch/basic.py @@ -435,18 +435,18 @@ def shape_i(x): @numba_funcify.register(SortOp) def numba_funcify_SortOp(op, node, **kwargs): - if op.kind == "quicksort": - - @numba_njit - def sort_f(a, axis): - return np.sort(a) # numba supports sort without arguments - else: - ret_sig = get_numba_type(node.outputs[0].type) + @numba_njit + def sort_f(a, axis): + return np.sort(a) # numba supports sort without arguments - def sort_f(a, axis): - with numba.objmode(ret=ret_sig): - ret = np.sort(a, axis=axis, kind=op.kind) - return ret + if op.kind != "quicksort": + warnings.warn( + ( + f'Numba function sort doesn\'t support kind="{op.kind}"' + " switching to `quicksort`." + ), + UserWarning, + ) return sort_f @@ -460,17 +460,20 @@ def argsort_f(a, axis): return argsort_f - if op.kind in ["quicksort", "mergesort"]: - return argsort_f_kind(op.kind) - else: - ret_sig = get_numba_type(node.outputs[0].type) + kind = op.kind - def argsort_f(a, axis): - with numba.objmode(ret=ret_sig): - ret = np.argsort(a, axis=axis, kind=op.kind) - return ret + if kind in ["quicksort", "mergesort"]: + return argsort_f_kind(kind) + else: + warnings.warn( + ( + f'Numba function argsort doesn\'t support kind="{op.kind}"' + " switching to `quicksort`." + ), + UserWarning, + ) - return argsort_f + return argsort_f_kind("quicksort") @numba.extending.intrinsic diff --git a/tests/link/numba/test_basic.py b/tests/link/numba/test_basic.py index 606d2114eb..aace623956 100644 --- a/tests/link/numba/test_basic.py +++ b/tests/link/numba/test_basic.py @@ -379,18 +379,47 @@ def test_Shape(x, i): compare_numba_and_py([], [g], []) -@pytest.mark.parametrize("kind", ["quicksort"]) -def test_Sort(kind): +@pytest.mark.parametrize( + "kind, exc", + [ + ["quicksort", None], + ["mergesort", UserWarning], + ["heapsort", UserWarning], + ["stable", UserWarning], + ], +) +def test_Sort(kind, exc): x = [5, 4, 3, 2, 1] + g = SortOp(kind)(pt.as_tensor_variable(x)) + + if exc: + with pytest.warns(exc): + compare_numba_and_py([], [g], []) + else: + compare_numba_and_py([], [g], []) + compare_numba_and_py([], [g], []) -@pytest.mark.parametrize("kind", ["quicksort", "mergesort"]) -def test_ArgSort(kind): +@pytest.mark.parametrize( + "kind, exc", + [ + ["quicksort", None], + ["mergesort", None], + ["heapsort", UserWarning], + ["stable", UserWarning], + ], +) +def test_ArgSort(kind, exc): x = [5, 4, 3, 2, 1] g = ArgSortOp(kind)(pt.as_tensor_variable(x)) - compare_numba_and_py([], [g], []) + + if exc: + with pytest.warns(exc): + compare_numba_and_py([], [g], []) + else: + compare_numba_and_py([], [g], []) @pytest.mark.parametrize( From d0159e53b7e25f42b5522e36f0edcfbc5309e62e Mon Sep 17 00:00:00 2001 From: Victor Garcia Reolid Date: Fri, 28 Mar 2025 21:40:09 +0100 Subject: [PATCH 3/8] feat: support axis Signed-off-by: Victor Garcia Reolid --- pytensor/link/numba/dispatch/basic.py | 34 ++++++++++++-- tests/link/numba/test_basic.py | 66 ++++++++++++++++----------- 2 files changed, 69 insertions(+), 31 deletions(-) diff --git a/pytensor/link/numba/dispatch/basic.py b/pytensor/link/numba/dispatch/basic.py index ee1fd06b74..afebbce9a3 100644 --- a/pytensor/link/numba/dispatch/basic.py +++ b/pytensor/link/numba/dispatch/basic.py @@ -11,7 +11,7 @@ import scipy import scipy.special from llvmlite import ir -from numba import types +from numba import prange, types from numba.core.errors import NumbaWarning, TypingError from numba.cpython.unsafe.tuple import tuple_setitem # noqa: F401 from numba.extending import box, overload @@ -437,7 +437,14 @@ def shape_i(x): def numba_funcify_SortOp(op, node, **kwargs): @numba_njit def sort_f(a, axis): - return np.sort(a) # numba supports sort without arguments + if not isinstance(axis, int): + axis = -1 + + a_swapped = np.swapaxes(a, axis, -1) + a_sorted = np.sort(a_swapped) + a_sorted_swapped = np.swapaxes(a_sorted, -1, axis) + + return a_sorted_swapped if op.kind != "quicksort": warnings.warn( @@ -455,10 +462,27 @@ def sort_f(a, axis): def numba_funcify_ArgSortOp(op, node, **kwargs): def argsort_f_kind(kind): @numba_njit - def argsort_f(a, axis): - return np.argsort(a, kind=kind) + def argort_vec(X, axis): + if axis > len(X.shape): + raise ValueError("Wrong axis.") + + axis = axis.item() + + Y = np.swapaxes(X, axis, 0) + result = np.empty_like(Y) + + N = int(np.prod(np.array(Y.shape)[1:])) + indices = list(np.ndindex(Y.shape[1:])) + + for i in prange(N): + idx = indices[i] + result[:, *idx] = np.argsort(Y[:, *idx], kind=kind) + + result = np.swapaxes(result, 0, axis) + + return result - return argsort_f + return argort_vec kind = op.kind diff --git a/tests/link/numba/test_basic.py b/tests/link/numba/test_basic.py index aace623956..a923a20f60 100644 --- a/tests/link/numba/test_basic.py +++ b/tests/link/numba/test_basic.py @@ -380,45 +380,59 @@ def test_Shape(x, i): @pytest.mark.parametrize( - "kind, exc", + "x, axis, kind, exc", [ - ["quicksort", None], - ["mergesort", UserWarning], - ["heapsort", UserWarning], - ["stable", UserWarning], + [[3, 2, 1], None, "quicksort", None], + [[], None, "quicksort", None], + [[[3, 2, 1], [5, 6, 7]], None, "quicksort", None], + [[3, 2, 1], None, "mergesort", UserWarning], + [[3, 2, 1], None, "heapsort", UserWarning], + [[3, 2, 1], None, "stable", UserWarning], + [[[3, 2, 1], [5, 6, 7]], 0, "quicksort", None], + [[[3, 2, 1], [5, 6, 7]], 1, "quicksort", None], + [[[3, 2, 1], [5, 6, 7]], -1, "quicksort", None], + [[3, 2, 1], 0, "quicksort", None], + [np.random.randint(0, 100, (40, 40, 40, 40)), 3, "quicksort", None], ], ) -def test_Sort(kind, exc): - x = [5, 4, 3, 2, 1] +def test_Sort(x, axis, kind, exc): + if axis: + g = SortOp(kind)(pt.as_tensor_variable(x), axis) + else: + g = SortOp(kind)(pt.as_tensor_variable(x)) - g = SortOp(kind)(pt.as_tensor_variable(x)) + cm = contextlib.suppress() if not exc else pytest.warns(exc) - if exc: - with pytest.warns(exc): - compare_numba_and_py([], [g], []) - else: + with cm: compare_numba_and_py([], [g], []) - compare_numba_and_py([], [g], []) - @pytest.mark.parametrize( - "kind, exc", + "x, axis, kind, exc", [ - ["quicksort", None], - ["mergesort", None], - ["heapsort", UserWarning], - ["stable", UserWarning], + [[3, 2, 1], None, "quicksort", None], + [[], None, "quicksort", None], + [[[3, 2, 1], [5, 6, 7]], None, "quicksort", None], + [[3, 2, 1], None, "heapsort", UserWarning], + [[3, 2, 1], None, "stable", UserWarning], + [[[3, 2, 1], [5, 6, 7]], 0, "quicksort", None], + [[[3, 2, 1], [5, 6, 7]], None, "quicksort", None], + [[[3, 2, 1], [5, 6, 7]], 1, "quicksort", None], + [[[3, 2, 1], [5, 6, 7]], -1, "quicksort", None], + [[3, 2, 1], 0, "quicksort", None], + [np.random.randint(0, 10, (3, 2, 3)), 1, "quicksort", None], + [np.random.randint(0, 10, (3, 2, 3, 4, 4)), 2, "quicksort", None], ], ) -def test_ArgSort(kind, exc): - x = [5, 4, 3, 2, 1] - g = ArgSortOp(kind)(pt.as_tensor_variable(x)) - - if exc: - with pytest.warns(exc): - compare_numba_and_py([], [g], []) +def test_ArgSort(x, axis, kind, exc): + if axis: + g = ArgSortOp(kind)(pt.as_tensor_variable(x), axis) else: + g = ArgSortOp(kind)(pt.as_tensor_variable(x)) + + cm = contextlib.suppress() if not exc else pytest.warns(exc) + + with cm: compare_numba_and_py([], [g], []) From cf75225851935f0ca45b2b20e6bb35a56ab85342 Mon Sep 17 00:00:00 2001 From: Victor Garcia Reolid Date: Sun, 30 Mar 2025 11:57:36 +0200 Subject: [PATCH 4/8] use syntax compatible with python 3.10 Signed-off-by: Victor Garcia Reolid --- pytensor/link/numba/dispatch/basic.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/pytensor/link/numba/dispatch/basic.py b/pytensor/link/numba/dispatch/basic.py index afebbce9a3..ee72647db0 100644 --- a/pytensor/link/numba/dispatch/basic.py +++ b/pytensor/link/numba/dispatch/basic.py @@ -476,7 +476,9 @@ def argort_vec(X, axis): for i in prange(N): idx = indices[i] - result[:, *idx] = np.argsort(Y[:, *idx], kind=kind) + result[(slice(None), *idx)] = np.argsort( + Y[(slice(None), *idx)], kind=kind + ) result = np.swapaxes(result, 0, axis) From 0eb3b7e89ef18b4b7f21b02df79a22c0ff6b49e0 Mon Sep 17 00:00:00 2001 From: Victor Garcia Reolid Date: Mon, 31 Mar 2025 08:41:58 +0200 Subject: [PATCH 5/8] remove checks Signed-off-by: Victor Garcia Reolid --- pytensor/link/numba/dispatch/basic.py | 6 +----- 1 file changed, 1 insertion(+), 5 deletions(-) diff --git a/pytensor/link/numba/dispatch/basic.py b/pytensor/link/numba/dispatch/basic.py index ee72647db0..e4c9bbb9ca 100644 --- a/pytensor/link/numba/dispatch/basic.py +++ b/pytensor/link/numba/dispatch/basic.py @@ -437,8 +437,7 @@ def shape_i(x): def numba_funcify_SortOp(op, node, **kwargs): @numba_njit def sort_f(a, axis): - if not isinstance(axis, int): - axis = -1 + axis = axis.item() a_swapped = np.swapaxes(a, axis, -1) a_sorted = np.sort(a_swapped) @@ -463,9 +462,6 @@ def numba_funcify_ArgSortOp(op, node, **kwargs): def argsort_f_kind(kind): @numba_njit def argort_vec(X, axis): - if axis > len(X.shape): - raise ValueError("Wrong axis.") - axis = axis.item() Y = np.swapaxes(X, axis, 0) From 216f6b2ba8bcbda275b9df92fdebda9804e699cc Mon Sep 17 00:00:00 2001 From: Victor Garcia Reolid Date: Mon, 31 Mar 2025 08:43:06 +0200 Subject: [PATCH 6/8] use range instead of prange Signed-off-by: Victor Garcia Reolid --- pytensor/link/numba/dispatch/basic.py | 6 ++---- 1 file changed, 2 insertions(+), 4 deletions(-) diff --git a/pytensor/link/numba/dispatch/basic.py b/pytensor/link/numba/dispatch/basic.py index e4c9bbb9ca..9f605d47e6 100644 --- a/pytensor/link/numba/dispatch/basic.py +++ b/pytensor/link/numba/dispatch/basic.py @@ -11,7 +11,7 @@ import scipy import scipy.special from llvmlite import ir -from numba import prange, types +from numba import types from numba.core.errors import NumbaWarning, TypingError from numba.cpython.unsafe.tuple import tuple_setitem # noqa: F401 from numba.extending import box, overload @@ -467,11 +467,9 @@ def argort_vec(X, axis): Y = np.swapaxes(X, axis, 0) result = np.empty_like(Y) - N = int(np.prod(np.array(Y.shape)[1:])) indices = list(np.ndindex(Y.shape[1:])) - for i in prange(N): - idx = indices[i] + for idx in indices: result[(slice(None), *idx)] = np.argsort( Y[(slice(None), *idx)], kind=kind ) From daa6730396437018ad92aece0d445107b91af292 Mon Sep 17 00:00:00 2001 From: Victor Garcia Reolid Date: Mon, 31 Mar 2025 08:47:16 +0200 Subject: [PATCH 7/8] add extra case to check Axis error is raised Signed-off-by: Victor Garcia Reolid --- tests/link/numba/test_basic.py | 9 ++++++++- 1 file changed, 8 insertions(+), 1 deletion(-) diff --git a/tests/link/numba/test_basic.py b/tests/link/numba/test_basic.py index a923a20f60..3605a0a301 100644 --- a/tests/link/numba/test_basic.py +++ b/tests/link/numba/test_basic.py @@ -393,6 +393,7 @@ def test_Shape(x, i): [[[3, 2, 1], [5, 6, 7]], -1, "quicksort", None], [[3, 2, 1], 0, "quicksort", None], [np.random.randint(0, 100, (40, 40, 40, 40)), 3, "quicksort", None], + [[3, 2, 1], -5, "quicksort", np.exceptions.AxisError], ], ) def test_Sort(x, axis, kind, exc): @@ -401,7 +402,13 @@ def test_Sort(x, axis, kind, exc): else: g = SortOp(kind)(pt.as_tensor_variable(x)) - cm = contextlib.suppress() if not exc else pytest.warns(exc) + cm = ( + contextlib.suppress() + if not exc + else pytest.warns(exc) + if isinstance(exc, Warning) + else pytest.raises(exc) + ) with cm: compare_numba_and_py([], [g], []) From 87269552f45b881fdf2fd5ca0d452e1586c9ae03 Mon Sep 17 00:00:00 2001 From: Victor Garcia Reolid Date: Tue, 1 Apr 2025 18:19:33 +0200 Subject: [PATCH 8/8] simplify tests Signed-off-by: Victor Garcia Reolid --- pytensor/link/numba/dispatch/basic.py | 7 ++- tests/link/numba/test_basic.py | 66 +++++++++++++-------------- 2 files changed, 36 insertions(+), 37 deletions(-) diff --git a/pytensor/link/numba/dispatch/basic.py b/pytensor/link/numba/dispatch/basic.py index 9f605d47e6..92bc44739f 100644 --- a/pytensor/link/numba/dispatch/basic.py +++ b/pytensor/link/numba/dispatch/basic.py @@ -482,9 +482,8 @@ def argort_vec(X, axis): kind = op.kind - if kind in ["quicksort", "mergesort"]: - return argsort_f_kind(kind) - else: + if kind not in ["quicksort", "mergesort"]: + kind = "quicksort" warnings.warn( ( f'Numba function argsort doesn\'t support kind="{op.kind}"' @@ -493,7 +492,7 @@ def argort_vec(X, axis): UserWarning, ) - return argsort_f_kind("quicksort") + return argsort_f_kind(kind) @numba.extending.intrinsic diff --git a/tests/link/numba/test_basic.py b/tests/link/numba/test_basic.py index 3605a0a301..101dd393d3 100644 --- a/tests/link/numba/test_basic.py +++ b/tests/link/numba/test_basic.py @@ -380,20 +380,21 @@ def test_Shape(x, i): @pytest.mark.parametrize( - "x, axis, kind, exc", + "x", [ - [[3, 2, 1], None, "quicksort", None], - [[], None, "quicksort", None], - [[[3, 2, 1], [5, 6, 7]], None, "quicksort", None], - [[3, 2, 1], None, "mergesort", UserWarning], - [[3, 2, 1], None, "heapsort", UserWarning], - [[3, 2, 1], None, "stable", UserWarning], - [[[3, 2, 1], [5, 6, 7]], 0, "quicksort", None], - [[[3, 2, 1], [5, 6, 7]], 1, "quicksort", None], - [[[3, 2, 1], [5, 6, 7]], -1, "quicksort", None], - [[3, 2, 1], 0, "quicksort", None], - [np.random.randint(0, 100, (40, 40, 40, 40)), 3, "quicksort", None], - [[3, 2, 1], -5, "quicksort", np.exceptions.AxisError], + [], # Empty list + [3, 2, 1], # Simple list + np.random.randint(0, 10, (3, 2, 3, 4, 4)), # Multi-dimensional array + ], +) +@pytest.mark.parametrize("axis", [0, -1, None]) +@pytest.mark.parametrize( + ("kind", "exc"), + [ + ["quicksort", None], + ["mergesort", UserWarning], + ["heapsort", UserWarning], + ["stable", UserWarning], ], ) def test_Sort(x, axis, kind, exc): @@ -402,36 +403,35 @@ def test_Sort(x, axis, kind, exc): else: g = SortOp(kind)(pt.as_tensor_variable(x)) - cm = ( - contextlib.suppress() - if not exc - else pytest.warns(exc) - if isinstance(exc, Warning) - else pytest.raises(exc) - ) + cm = contextlib.suppress() if not exc else pytest.warns(exc) with cm: compare_numba_and_py([], [g], []) @pytest.mark.parametrize( - "x, axis, kind, exc", + "x", + [ + [], # Empty list + [3, 2, 1], # Simple list + None, # Multi-dimensional array (see below) + ], +) +@pytest.mark.parametrize("axis", [0, -1, None]) +@pytest.mark.parametrize( + ("kind", "exc"), [ - [[3, 2, 1], None, "quicksort", None], - [[], None, "quicksort", None], - [[[3, 2, 1], [5, 6, 7]], None, "quicksort", None], - [[3, 2, 1], None, "heapsort", UserWarning], - [[3, 2, 1], None, "stable", UserWarning], - [[[3, 2, 1], [5, 6, 7]], 0, "quicksort", None], - [[[3, 2, 1], [5, 6, 7]], None, "quicksort", None], - [[[3, 2, 1], [5, 6, 7]], 1, "quicksort", None], - [[[3, 2, 1], [5, 6, 7]], -1, "quicksort", None], - [[3, 2, 1], 0, "quicksort", None], - [np.random.randint(0, 10, (3, 2, 3)), 1, "quicksort", None], - [np.random.randint(0, 10, (3, 2, 3, 4, 4)), 2, "quicksort", None], + ["quicksort", None], + ["heapsort", None], + ["stable", UserWarning], ], ) def test_ArgSort(x, axis, kind, exc): + if x is None: + x = np.arange(5 * 5 * 5 * 5) + np.random.shuffle(x) + x = np.reshape(x, (5, 5, 5, 5)) + if axis: g = ArgSortOp(kind)(pt.as_tensor_variable(x), axis) else: