Skip to content

Support for __array_function__ implementers (sparse arrays) [WIP] #3117

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 31 commits into from
Aug 5, 2019
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
31 commits
Select commit Hold shift + click to select a range
5b5e245
Support for __array_function__ implementers
nvictus Jul 13, 2019
ad403fc
Pep8
nvictus Jul 13, 2019
df618ad
Consistent naming
nvictus Jul 14, 2019
1b62f6d
Check for NEP18 enabled and nep18 non-numpy arrays
nvictus Jul 14, 2019
21c0a21
Replace .values with .data
nvictus Jul 14, 2019
d62820b
Add initial test for nep18
nvictus Jul 14, 2019
0493263
Fix linting issues
nvictus Jul 14, 2019
ad42627
Add parameterized tests
nvictus Jul 14, 2019
ec01625
Internal clean-up of isnull() to avoid relying on pandas
shoyer Jul 15, 2019
fd05566
Merge branch 'master' into isnull-duck
shoyer Jul 15, 2019
5d8edfd
Add sparse to ci requirements
nvictus Jul 17, 2019
7c7a9f2
Moar tests
nvictus Jul 17, 2019
a231571
Two more patches for __array_function__ duck-arrays
nvictus Jul 17, 2019
4009379
Don't use coords attribute from duck-arrays that aren't derived from …
nvictus Jul 17, 2019
2c3b183
Improve checking for coords, and autopep8
nvictus Jul 17, 2019
23633b4
Skip tests if NEP-18 envvar is not set
nvictus Jul 17, 2019
38c4717
flake8
nvictus Jul 17, 2019
e9d9c41
Update xarray/core/dataarray.py
nvictus Jul 18, 2019
3f1eec2
Fix coords parsing
nvictus Jul 19, 2019
56864c3
More tests
nvictus Jul 19, 2019
1876892
Add align tests
nvictus Jul 19, 2019
ef51976
Replace nep18 tests with more extensive tests on pydata/sparse
nvictus Aug 2, 2019
86ee35c
Merge remote-tracking branch 'shoyer/isnull-duck' into sparse_xarrays
nvictus Aug 2, 2019
3afe7e2
Add xfails for missing np.result_type (fixed by pydata/sparse/pull/261)
nvictus Aug 2, 2019
33a07e7
Fix xpasses
nvictus Aug 2, 2019
5a817a8
Revert isnull/notnull
nvictus Aug 4, 2019
5675819
Fix as_like_arrays by coercing dense arrays to COO if any sparse
nvictus Aug 4, 2019
96e1346
Make Variable.load a no-op for non-dask duck arrays
nvictus Aug 5, 2019
563148c
Merge branch 'master' into sparse_xarrays
nvictus Aug 5, 2019
66c2f82
Add additional method tests
nvictus Aug 5, 2019
b353a0b
Fix utils.as_scalar to handle duck arrays with ndim>0
nvictus Aug 5, 2019
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions ci/requirements/py37.yml
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@ dependencies:
- pip
- scipy
- seaborn
- sparse
- toolz
Copy link
Contributor

@crusaderky crusaderky Aug 8, 2019

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Why did you exclude py36 and Windows?

- rasterio
- boto3
Expand Down
7 changes: 5 additions & 2 deletions xarray/core/dataarray.py
Original file line number Diff line number Diff line change
Expand Up @@ -260,15 +260,18 @@ def __init__(
else:
# try to fill in arguments from data if they weren't supplied
if coords is None:
coords = getattr(data, 'coords', None)
if isinstance(data, pd.Series):

if isinstance(data, DataArray):
coords = data.coords
elif isinstance(data, pd.Series):
coords = [data.index]
elif isinstance(data, pd.DataFrame):
coords = [data.index, data.columns]
elif isinstance(data, (pd.Index, IndexVariable)):
coords = [data]
elif isinstance(data, pdcompat.Panel):
coords = [data.items, data.major_axis, data.minor_axis]

if dims is None:
dims = getattr(data, 'dims', getattr(coords, 'dims', None))
if name is None:
Expand Down
12 changes: 10 additions & 2 deletions xarray/core/duck_array_ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@

from . import dask_array_ops, dtypes, npcompat, nputils
from .nputils import nanfirst, nanlast
from .pycompat import dask_array_type
from .pycompat import dask_array_type, sparse_array_type

try:
import dask.array as dask_array
Expand Down Expand Up @@ -64,6 +64,7 @@ def fail_on_dask_array_input(values, msg=None, func_name=None):
around = _dask_or_eager_func('around')
isclose = _dask_or_eager_func('isclose')


if hasattr(np, 'isnat') and (
dask_array is None or hasattr(dask_array_type, '__array_ufunc__')):
# np.isnat is available since NumPy 1.13, so __array_ufunc__ is always
Expand Down Expand Up @@ -153,7 +154,11 @@ def trapz(y, x, axis):


def asarray(data):
    """Coerce ``data`` to an array, leaving duck arrays untouched.

    Dask arrays and any object implementing the ``__array_function__``
    protocol (NEP 18 duck arrays, e.g. pydata/sparse) are returned
    unchanged; anything else is converted with ``np.asarray``.

    NOTE(review): the scraped diff span contained both the old one-line
    return and the new multi-line return; only the new behavior is kept.
    """
    if isinstance(data, dask_array_type) or hasattr(data, '__array_function__'):
        return data
    return np.asarray(data)


def as_shared_dtype(scalars_or_arrays):
Expand All @@ -170,6 +175,9 @@ def as_shared_dtype(scalars_or_arrays):
def as_like_arrays(*data):
    """Coerce all arguments to a single array flavour.

    * all dask arrays        -> returned unchanged (stay lazy)
    * any pydata/sparse array -> everything converted to ``sparse.COO``
    * otherwise              -> everything converted to ``np.ndarray``
    """
    if all(isinstance(d, dask_array_type) for d in data):
        # Already uniformly dask; keep the lazy graphs intact.
        return data
    if any(isinstance(d, sparse_array_type) for d in data):
        # Densifying sparse data could blow up memory, so promote the
        # dense operands to COO instead.
        from sparse import COO
        return tuple(COO(d) for d in data)
    return tuple(np.asarray(d) for d in data)

Expand Down
7 changes: 5 additions & 2 deletions xarray/core/formatting.py
Original file line number Diff line number Diff line change
Expand Up @@ -357,7 +357,10 @@ def set_numpy_options(*args, **kwargs):


def short_array_repr(array):
array = np.asarray(array)

if not hasattr(array, '__array_function__'):
array = np.asarray(array)

# default to lower precision so a full (abbreviated) line can fit on
# one line with the default display_width
options = {
Expand Down Expand Up @@ -394,7 +397,7 @@ def short_data_repr(array):
if isinstance(getattr(array, 'variable', array)._data, dask_array_type):
return short_dask_repr(array)
elif array._in_memory or array.size < 1e5:
return short_array_repr(array.values)
return short_array_repr(array.data)
else:
return u'[{} values with dtype={}]'.format(array.size, array.dtype)

Expand Down
13 changes: 13 additions & 0 deletions xarray/core/indexing.py
Original file line number Diff line number Diff line change
Expand Up @@ -657,6 +657,9 @@ def as_indexable(array):
return PandasIndexAdapter(array)
if isinstance(array, dask_array_type):
return DaskIndexingAdapter(array)
if hasattr(array, '__array_function__'):
return NdArrayLikeIndexingAdapter(array)

raise TypeError('Invalid array type: {}'.format(type(array)))


Expand Down Expand Up @@ -1189,6 +1192,16 @@ def __setitem__(self, key, value):
raise


class NdArrayLikeIndexingAdapter(NumpyIndexingAdapter):
    """Indexing adapter for NEP-18 duck arrays.

    Wraps an object that implements the ``__array_function__`` protocol
    and reuses ``NumpyIndexingAdapter``'s indexing machinery on it.
    """

    def __init__(self, array):
        # Reject anything that does not speak the protocol up front,
        # so failures surface at wrap time rather than at index time.
        if not hasattr(array, '__array_function__'):
            raise TypeError(
                'NdArrayLikeIndexingAdapter must wrap an object that '
                'implements the __array_function__ protocol'
            )
        self.array = array


class DaskIndexingAdapter(ExplicitlyIndexedNDArrayMixin):
"""Wrap a dask array to support explicit indexing."""

Expand Down
15 changes: 15 additions & 0 deletions xarray/core/npcompat.py
Original file line number Diff line number Diff line change
Expand Up @@ -357,3 +357,18 @@ def moveaxis(a, source, destination):
# https://github.com/numpy/numpy/issues/7370
# https://github.com/numpy/numpy-stubs/
DTypeLike = Union[np.dtype, str]


# from dask/array/utils.py
def _is_nep18_active():
class A:
def __array_function__(self, *args, **kwargs):
return True

try:
return np.concatenate([A()])
except ValueError:
return False


IS_NEP18_ACTIVE = _is_nep18_active()
7 changes: 7 additions & 0 deletions xarray/core/pycompat.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,3 +8,10 @@
dask_array_type = (dask.array.Array,)
except ImportError: # pragma: no cover
dask_array_type = ()

try:
    # solely for isinstance checks
    import sparse
    # Tuple of pydata/sparse array classes when sparse is installed;
    # empty tuple otherwise, so isinstance checks are always safe.
    sparse_array_type = (sparse.SparseArray,)
except ImportError:  # pragma: no cover
    sparse_array_type = ()
4 changes: 3 additions & 1 deletion xarray/core/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -243,7 +243,9 @@ def is_scalar(value: Any) -> bool:
return (
getattr(value, 'ndim', None) == 0 or
isinstance(value, (str, bytes)) or not
isinstance(value, (Iterable, ) + dask_array_type))
(isinstance(value, (Iterable, ) + dask_array_type) or
hasattr(value, '__array_function__'))
)


def is_valid_numpy_dtype(dtype: Any) -> bool:
Expand Down
39 changes: 27 additions & 12 deletions xarray/core/variable.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
as_indexable)
from .options import _get_keep_attrs
from .pycompat import dask_array_type, integer_types
from .npcompat import IS_NEP18_ACTIVE
from .utils import (
OrderedSet, decode_numpy_dict_values, either_dict_or_kwargs,
ensure_us_time_resolution)
Expand Down Expand Up @@ -179,6 +180,18 @@ def as_compatible_data(data, fastpath=False):
else:
data = np.asarray(data)

if not isinstance(data, np.ndarray):
if hasattr(data, '__array_function__'):
if IS_NEP18_ACTIVE:
return data
else:
raise TypeError(
'Got an NumPy-like array type providing the '
'__array_function__ protocol but NEP18 is not enabled. '
'Check that numpy >= v1.16 and that the environment '
'variable "NUMPY_EXPERIMENTAL_ARRAY_FUNCTION" is set to '
'"1"')

# validate whether the data is valid data types
data = np.asarray(data)

Expand Down Expand Up @@ -288,7 +301,7 @@ def _in_memory(self):

@property
def data(self):
    """Return the underlying data as a duck array or computed values.

    Duck arrays implementing ``__array_function__`` (dask, sparse, …)
    are handed back directly so they stay lazy/sparse; everything else
    goes through ``.values`` (i.e. a numpy ndarray).
    """
    underlying = self._data
    if hasattr(underlying, '__array_function__'):
        return underlying
    return self.values
Expand Down Expand Up @@ -320,7 +333,7 @@ def load(self, **kwargs):
"""
if isinstance(self._data, dask_array_type):
self._data = as_compatible_data(self._data.compute(**kwargs))
elif not isinstance(self._data, np.ndarray):
elif not hasattr(self._data, '__array_function__'):
self._data = np.asarray(self._data)
return self

Expand Down Expand Up @@ -705,8 +718,8 @@ def __setitem__(self, key, value):

if new_order:
value = duck_array_ops.asarray(value)
value = value[(len(dims) - value.ndim) * (np.newaxis,) +
(Ellipsis,)]
value = value[(len(dims) - value.ndim) * (np.newaxis,)
+ (Ellipsis,)]
value = duck_array_ops.moveaxis(
value, new_order, range(len(new_order)))

Expand Down Expand Up @@ -805,7 +818,8 @@ def copy(self, deep=True, data=None):
data = indexing.MemoryCachedArray(data.array)

if deep:
if isinstance(data, dask_array_type):
if (hasattr(data, '__array_function__')
or isinstance(data, dask_array_type)):
data = data.copy()
elif not isinstance(data, PandasIndexAdapter):
# pandas.Index is immutable
Expand Down Expand Up @@ -1494,9 +1508,10 @@ def equals(self, other, equiv=duck_array_ops.array_equiv):
"""
other = getattr(other, 'variable', other)
try:
return (self.dims == other.dims and
(self._data is other._data or
equiv(self.data, other.data)))
return (
self.dims == other.dims and
(self._data is other._data or equiv(self.data, other.data))
)
except (TypeError, AttributeError):
return False

Expand All @@ -1517,8 +1532,8 @@ def identical(self, other):
"""Like equals, but also checks attributes.
"""
try:
return (utils.dict_equiv(self.attrs, other.attrs) and
self.equals(other))
return (utils.dict_equiv(self.attrs, other.attrs)
and self.equals(other))
except (TypeError, AttributeError):
return False

Expand Down Expand Up @@ -1959,8 +1974,8 @@ def equals(self, other, equiv=None):
# otherwise use the native index equals, rather than looking at _data
other = getattr(other, 'variable', other)
try:
return (self.dims == other.dims and
self._data_equals(other))
return (self.dims == other.dims
and self._data_equals(other))
except (TypeError, AttributeError):
return False

Expand Down
Loading