Skip to content

Commit

Permalink
feat: add take_long_axis to specifiation
Browse files Browse the repository at this point in the history
PR-URL: #816
Closes: #808
  • Loading branch information
kgryte authored Sep 19, 2024
1 parent b877795 commit 390e9cc
Show file tree
Hide file tree
Showing 2 changed files with 26 additions and 1 deletion.
1 change: 1 addition & 0 deletions spec/draft/API_specification/indexing_functions.rst
Original file line number Diff line number Diff line change
Expand Up @@ -21,3 +21,4 @@ Objects in API
:template: method.rst

take
take_along_axis
26 changes: 25 additions & 1 deletion src/array_api_stubs/_draft/indexing_functions.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
__all__ = ["take"]
__all__ = ["take", "take_along_axis"]

from ._types import Union, Optional, array

Expand Down Expand Up @@ -38,3 +38,27 @@ def take(x: array, indices: array, /, *, axis: Optional[int] = None) -> array:
.. versionchanged:: 2023.12
Out-of-bounds behavior is explicitly left unspecified.
"""


def take_along_axis(x: array, indices: array, /, *, axis: int = -1) -> array:
"""
Returns elements from an array at the one-dimensional indices specified by ``indices`` along a provided ``axis``.
Parameters
----------
x: array
input array. Must be compatible with ``indices``, except for the axis (dimension) specified by ``axis`` (see :ref:`broadcasting`).
indices: array
array indices. Must have the same rank (i.e., number of dimensions) as ``x``.
.. note::
This specification does not require bounds checking. The behavior for out-of-bounds indices is left unspecified.
axis: int
axis along which to select values. If ``axis`` is negative, the function must determine the axis along which to select values by counting from the last dimension. Default: ``-1``.
Returns
-------
out: array
an array having the same data type as ``x``. Must have the same rank (i.e., number of dimensions) as ``x`` and must have a shape determined according to :ref:`broadcasting`, except for the axis (dimension) specified by ``axis`` whose size must equal the size of the corresponding axis (dimension) in ``indices``.
"""

0 comments on commit 390e9cc

Please sign in to comment.