Skip to content

Commit 27d04a1

Browse files
authored
New function for applying vectorized functions for unlabeled arrays to xarray objects (#964)
* WIP: apply_ufunc for applying generic functions to xarray objects * Move operations.py -> computation.py * Rewrite _build_and_check_signature * More work on signatures, including auto-dask * apply_groupby_ufunc * fixes * add gufunc string signature parsing * Add some basic tests (and get them passing) * Signature parse tests * more docs, work on Signature parsing * add check/test for unexpected dimensions * tests for core dimensions * build_output_coords fixes * Tests and fixes for binary ops * performance improvements * GroupBy ufunc fixes * More succinct tests & fixes for groupby ufuncs * More optimizations * Tests for _calculate_unified_dim_sizes * apply_dask_ufunc tests * Tests for dask.array * Fix function signatures * Tweaks * Tests and full docstring for apply_ufunc * Switch new_coords to a list * Add exclude_dims * Move deep_align back into alignment.py * Lint * changes per MaximilianR's review * style fixes * Docstrings * Rename apply_ufunc to apply * Fix recursion bug in test * Simpler exclude_dims * WIP * Fix test failures * Remove unused function. Rename apply to apply_ufunc. * Remove extraneous "new_coords" argument
1 parent 21a792d commit 27d04a1

File tree

7 files changed

+1362
-82
lines changed

7 files changed

+1362
-82
lines changed

xarray/core/alignment.py

Lines changed: 74 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,7 @@
1212
from .common import _maybe_promote
1313
from .indexing import get_indexer
1414
from .pycompat import iteritems, OrderedDict, suppress
15-
from .utils import is_full_slice
15+
from .utils import is_full_slice, is_dict_like
1616
from .variable import Variable, IndexVariable
1717

1818

@@ -29,8 +29,12 @@ def _get_joiner(join):
2929
raise ValueError('invalid value for join: %s' % join)
3030

3131

32+
# Immutable empty set used as the fallback for the ``exclude`` keyword of
# ``align`` when the caller does not supply one.
_DEFAULT_EXCLUDE = frozenset()
33+
34+
3235
def align(*objects, **kwargs):
33-
"""align(*objects, join='inner', copy=True)
36+
"""align(*objects, join='inner', copy=True, indexes=None,
37+
exclude=frozenset())
3438
3539
Given any number of Dataset and/or DataArray objects, returns new
3640
objects with aligned indexes and dimension sizes.
@@ -76,15 +80,18 @@ def align(*objects, **kwargs):
7680
join = kwargs.pop('join', 'inner')
7781
copy = kwargs.pop('copy', True)
7882
indexes = kwargs.pop('indexes', None)
79-
exclude = kwargs.pop('exclude', None)
83+
exclude = kwargs.pop('exclude', _DEFAULT_EXCLUDE)
8084
if indexes is None:
8185
indexes = {}
82-
if exclude is None:
83-
exclude = set()
8486
if kwargs:
8587
raise TypeError('align() got unexpected keyword arguments: %s'
8688
% list(kwargs))
8789

90+
if not indexes and len(objects) == 1:
91+
# fast path for the trivial case
92+
obj, = objects
93+
return (obj.copy(deep=copy),)
94+
8895
all_indexes = defaultdict(list)
8996
unlabeled_dim_sizes = defaultdict(set)
9097
for obj in objects:
@@ -142,11 +149,72 @@ def align(*objects, **kwargs):
142149
for obj in objects:
143150
valid_indexers = {k: v for k, v in joined_indexes.items()
144151
if k in obj.dims}
145-
result.append(obj.reindex(copy=copy, **valid_indexers))
152+
if not valid_indexers:
153+
# fast path for no reindexing necessary
154+
new_obj = obj.copy(deep=copy)
155+
else:
156+
new_obj = obj.reindex(copy=copy, **valid_indexers)
157+
result.append(new_obj)
146158

147159
return tuple(result)
148160

149161

162+
def deep_align(objects, join='inner', copy=True, indexes=None,
               exclude=frozenset(), raise_on_invalid=True):
    """Align objects for merging, looking inside dictionary values.

    Each entry of ``objects`` is handled according to what it is:

    * an entry exposing both ``indexes`` and ``reindex`` attributes is
      aligned directly;
    * a dict-like entry is copied into an ``OrderedDict``; those of its
      values that are alignable, and whose key is not in ``indexes``, are
      aligned and written back into the copy under the same key;
    * anything else raises ``ValueError`` when ``raise_on_invalid`` is
      true, and is otherwise passed through untouched.

    All collected targets are aligned together with a single call to
    ``align``, forwarding ``join``, ``copy``, ``indexes`` and ``exclude``.

    Returns a list with one entry per entry of ``objects``.

    This function is not public API.
    """
    if indexes is None:
        indexes = {}

    def is_alignable(obj):
        return hasattr(obj, 'indexes') and hasattr(obj, 'reindex')

    # Unique sentinels: ``top_level`` marks a target that is itself an entry
    # of ``objects`` (rather than a value inside a dict), ``pending`` marks
    # an output slot that is still waiting for its aligned replacement.
    top_level = object()
    pending = object()

    slots = []      # (position, key) for each target, in the same order
    to_align = []   # the objects handed to ``align``
    result = []     # one entry per input object

    for position, item in enumerate(objects):
        if is_alignable(item):
            slots.append((position, top_level))
            to_align.append(item)
            result.append(pending)
            continue

        if is_dict_like(item):
            for name, value in item.items():
                if is_alignable(value) and name not in indexes:
                    # Skip variables in indexes for alignment, because these
                    # should be overwritten instead:
                    # https://github.com/pydata/xarray/issues/725
                    slots.append((position, name))
                    to_align.append(value)
            result.append(OrderedDict(item))
            continue

        if raise_on_invalid:
            raise ValueError('object to align is neither an xarray.Dataset, '
                             'an xarray.DataArray nor a dictionary: %r'
                             % item)

        result.append(item)

    aligned = align(*to_align, join=join, copy=copy, indexes=indexes,
                    exclude=exclude)

    for (position, key), aligned_obj in zip(slots, aligned):
        if key is top_level:
            result[position] = aligned_obj
        else:
            result[position][key] = aligned_obj

    # something went wrong: we should have replaced all sentinel values
    assert not any(entry is pending for entry in result)

    return result
216+
217+
150218
def reindex_like_indexers(target, other):
151219
"""Extract indexers to align target with other.
152220

0 commit comments

Comments
 (0)