diff --git a/cubed/array_api/array_object.py b/cubed/array_api/array_object.py index 5dc78c9b3..62435600f 100644 --- a/cubed/array_api/array_object.py +++ b/cubed/array_api/array_object.py @@ -29,6 +29,24 @@ 120 # cubed doesn't have a config module like dask does so hard-code this for now ) +_HANDLED_FUNCTIONS = {} + + +def implements(*numpy_functions): + """Register an __array_function__ implementation for cubed.Array + + Note that this is **only** used for functions that are not defined in the + Array API Standard. + """ + + def decorator(cubed_func): + for numpy_function in numpy_functions: + _HANDLED_FUNCTIONS[numpy_function] = cubed_func + + return cubed_func + + return decorator + class Array(CoreArray): """Chunked array backed by Zarr storage that conforms to the Python Array API standard.""" @@ -44,6 +62,12 @@ def __array__(self, dtype=None) -> np.ndarray: x = np.array(x) return x + def __array_function__(self, func, types, args, kwargs): + # Only dispatch to functions that are not defined in the Array API Standard + if func in _HANDLED_FUNCTIONS: + return _HANDLED_FUNCTIONS[func](*args, **kwargs) + return NotImplemented + def __repr__(self): return f"cubed.Array<{self.name}, shape={self.shape}, dtype={self.dtype}, chunks={self.chunks}>" diff --git a/cubed/nan_functions.py b/cubed/nan_functions.py index 928a2f41d..a493b9fff 100644 --- a/cubed/nan_functions.py +++ b/cubed/nan_functions.py @@ -1,5 +1,7 @@ import numpy as np +from cubed.array_api.array_object import implements +from cubed.array_api.creation_functions import asarray from cubed.array_api.dtypes import ( _numeric_dtypes, _signed_integer_dtypes, @@ -18,9 +20,10 @@ # https://github.com/data-apis/array-api/issues/621 -def nanmean(x, /, *, axis=None, keepdims=False, split_every=None): +@implements(np.nanmean) +def nanmean(x, /, *, axis=None, dtype=None, keepdims=False, split_every=None): """Compute the arithmetic mean along the specified axis, ignoring NaNs.""" - dtype = x.dtype + dtype = dtype or x.dtype intermediate_dtype = [("n", nxp.int64), ("total", nxp.float64)] return reduction( x, @@ -60,6 +63,7 @@ def _nannumel(x, **kwargs): return nxp.sum(~(nxp.isnan(x)), **kwargs) +@implements(np.nansum) def nansum(x, /, *, axis=None, dtype=None, keepdims=False, split_every=None): """Return the sum of array elements over a given axis treating NaNs as zero.""" if x.dtype not in _numeric_dtypes: @@ -83,3 +87,12 @@ def nansum(x, /, *, axis=None, dtype=None, keepdims=False, split_every=None): keepdims=keepdims, split_every=split_every, ) + + +@implements(np.isclose) +def isclose(a, b, rtol=1.0e-5, atol=1.0e-8, equal_nan=False): + # Note: this should only be used for testing small arrays since it + # materialize arrays in memory + na = nxp.asarray(a) + nb = nxp.asarray(b) + return asarray(nxp.isclose(na, nb, rtol=rtol, atol=atol, equal_nan=equal_nan)) diff --git a/cubed/pad.py b/cubed/pad.py index c292c65e5..afb0c4d95 100644 --- a/cubed/pad.py +++ b/cubed/pad.py @@ -1,6 +1,13 @@ +import numpy as np + +from cubed.array_api.array_object import implements from cubed.array_api.manipulation_functions import concat +# TODO: refactor once pad is standardized: +# https://github.com/data-apis/array-api/issues/187 + +@implements(np.pad) def pad(x, pad_width, mode=None, chunks=None): """Pad an array.""" if len(pad_width) != x.ndim: diff --git a/cubed/tests/runtime/utils.py b/cubed/tests/runtime/utils.py index 238218aa5..f8da354bb 100644 --- a/cubed/tests/runtime/utils.py +++ b/cubed/tests/runtime/utils.py @@ -56,7 +56,7 @@ def deterministic_failure(path, timing_map, i, *, default_sleep=0.01, name=None) else: time.sleep(-timing_code) raise RuntimeError( - f"Deliberately fail on invocation number {invocation_count+1} for input {i}" + f"Deliberately fail on invocation number {invocation_count + 1} for input {i}" ) diff --git a/cubed/tests/test_nan_functions.py b/cubed/tests/test_nan_functions.py index 53264e791..f67ce71d3 100644 --- a/cubed/tests/test_nan_functions.py +++ b/cubed/tests/test_nan_functions.py @@ -11,9 +11,11 @@ def spec(tmp_path): return cubed.Spec(tmp_path, allowed_mem=100000) -def test_nanmean(spec): +@pytest.mark.parametrize("namespace", [cubed, np]) +def test_nanmean(spec, namespace): a = xp.asarray([[1, 2, 3], [4, 5, 6], [7, 8, xp.nan]], chunks=(2, 2), spec=spec) - b = cubed.nanmean(a) + b = namespace.nanmean(a) + assert isinstance(b, cubed.Array) assert_array_equal( b.compute(), np.nanmean(np.array([[1, 2, 3], [4, 5, 6], [7, 8, np.nan]])) ) @@ -26,9 +28,11 @@ def test_nanmean_allnan(spec): assert_array_equal(b.compute(), np.nanmean(np.array([np.nan]))) -def test_nansum(spec): +@pytest.mark.parametrize("namespace", [cubed, np]) +def test_nansum(spec, namespace): a = xp.asarray([[1, 2, 3], [4, 5, 6], [7, 8, xp.nan]], chunks=(2, 2), spec=spec) - b = cubed.nansum(a) + b = namespace.nansum(a) + assert isinstance(b, cubed.Array) assert_array_equal( b.compute(), np.nansum(np.array([[1, 2, 3], [4, 5, 6], [7, 8, np.nan]])) ) diff --git a/cubed/tests/test_pad.py b/cubed/tests/test_pad.py index 7ba985f46..027348da7 100644 --- a/cubed/tests/test_pad.py +++ b/cubed/tests/test_pad.py @@ -11,11 +11,15 @@ def spec(tmp_path): return cubed.Spec(tmp_path, allowed_mem=100000) -def test_pad(spec): +@pytest.mark.parametrize("namespace", [cubed, np]) +def test_pad(spec, namespace): an = np.array([[1, 2, 3], [4, 5, 6], [7, 8, 9]]) a = xp.asarray([[1, 2, 3], [4, 5, 6], [7, 8, 9]], chunks=(2, 2), spec=spec) - b = cubed.pad(a, ((1, 0), (0, 0)), mode="symmetric") + # check that we can dispatch via the numpy namespace (via __array_function__) + # since pad is not yet a part of the Array API Standard + b = namespace.pad(a, ((1, 0), (0, 0)), mode="symmetric") + assert isinstance(b, cubed.Array) assert b.chunks == ((2, 2), (2, 1)) assert_array_equal(b.compute(), np.pad(an, ((1, 0), (0, 0)), mode="symmetric"))