This is an idea to add named dimensions/axes to any Array API compatible backend. Named axes are useful in many ways, especially for improving readability and reducing the risk of errors such as reducing over the wrong axis. There is a nice writeup by the Harvard NLP group highlighting various benefits: https://nlp.seas.harvard.edu/NamedTensor.html.
Based on my research, the following libraries implement named axes (please suggest more to add):
- NamedTensor: https://github.com/harvardnlp/NamedTensor
- XArray: https://docs.xarray.dev/en/stable/user-guide/indexing.html
- PyTorch: https://pytorch.org/docs/stable/name_inference.html#name-inference-reference-doc
- Awkward Array: https://awkward-array.org/doc/main/user-guide/how-to-array-properties-named-axis.html
- hist: https://hist.readthedocs.io/en/latest/
- Penzai: https://penzai.readthedocs.io/en/stable/notebooks/named_axes.html
- haliax: https://github.com/stanford-crfm/haliax
Named axes could extend the Array API in the following way (inspired by #2):
from named_array import array_api_strict as nxp
x = nxp.asarray([[1, 2], [3, 4]], named_axes=["height", "width"])
print(nxp.sum(x, axis="width")) # array([3, 7])
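To make the idea concrete, here is a minimal sketch of how a wrapper might translate an axis name into a positional axis before delegating to the backend. It uses plain NumPy as the backend; `resolve_axis` and `named_sum` are hypothetical helper names made up for illustration and are not part of `named_array` or any existing library.

```python
import numpy as np


def resolve_axis(named_axes, axis):
    """Map an axis name (or an integer) to a non-negative positional axis."""
    names = list(named_axes)
    if isinstance(axis, str):
        if axis not in names:
            raise ValueError(f"unknown axis name {axis!r}; known names: {names}")
        return names.index(axis)
    n = len(names)
    if not -n <= axis < n:
        raise ValueError(f"axis {axis} is out of bounds for {n} dimensions")
    return axis % n  # normalize negative integers


def named_sum(data, named_axes, axis):
    """Sum over one axis given by name; also return the surviving axis names."""
    pos = resolve_axis(named_axes, axis)
    remaining = [name for i, name in enumerate(named_axes) if i != pos]
    return np.sum(data, axis=pos), remaining


data = np.asarray([[1, 2], [3, 4]])
result, names = named_sum(data, ["height", "width"], axis="width")
print(result, names)  # [3 7] ['height']
```

Returning the remaining names alongside the result is one possible way to keep the names attached through a reduction; a real implementation would more likely store them on the array object itself.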
In addition, this enables a dict-based indexing syntax that many of the packages listed above implement:
import numpy as np
print(x[{"width": 0}]) # array([1, 3])
print(x[{"width": np.s_[0:1]}]) # array([[1], [3]])
This makes it possible to specify slices along particular dimensions by name rather than by position.
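As a rough sketch of how dict-based indexing could work under the hood, the dict can be expanded into an ordinary positional index tuple, with a full slice for every axis that is not mentioned. The helper name `to_positional_index` is an assumption made for this example, and plain NumPy stands in for the backend.

```python
import numpy as np


def to_positional_index(named_axes, named_index):
    """Expand {axis name: index} into a tuple usable for positional indexing."""
    unknown = set(named_index) - set(named_axes)
    if unknown:
        raise KeyError(f"unknown axis names: {sorted(unknown)}")
    # Axes that are not mentioned are kept whole via slice(None), i.e. ":".
    return tuple(named_index.get(name, slice(None)) for name in named_axes)


data = np.asarray([[1, 2], [3, 4]])
axes = ["height", "width"]

first_column = data[to_positional_index(axes, {"width": 0})]
print(first_column)  # [1 3]

column_slice = data[to_positional_index(axes, {"width": np.s_[0:1]})]
print(column_slice.shape)  # (2, 1)
```

An integer index drops the named axis while a slice keeps it, so a full implementation would also need to update the list of axis names after indexing.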
Best, Peter
PS: If people like this "new" named indexing, maybe we can find some people to revive PEP 472 (https://peps.python.org/pep-0472/)? Keyword-like syntax inside square brackets recently became more or less possible in type annotations for generics with bounds/constraints (https://docs.python.org/3/reference/compound_stmts.html#type-params), and many more applications and packages with named axes exist now than before.
The new syntax could then look like this:
print(x[width=0]) # array([1, 3])
print(x[width=0:1]) # array([[1], [3]])