Skip to content

Commit 330a9b5

Browse files
committed
apply_groupby_ufunc
1 parent 1e7005a commit 330a9b5

File tree

2 files changed

+77
-8
lines changed

2 files changed

+77
-8
lines changed

xarray/core/computation.py

Lines changed: 76 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -88,8 +88,8 @@ def get_coord_variables(arg):
8888
return output
8989

9090

91-
def apply_dataarray(args, func, signature=None, join='inner',
92-
kwargs=None, new_coords=None, combine_names=None):
91+
def apply_dataarray_ufunc(args, func, signature=None, join='inner',
92+
kwargs=None, new_coords=None, combine_names=None):
9393
if signature is None:
9494
signature = _default_signature(len(args))
9595

@@ -137,8 +137,9 @@ def collect_dict_values(objects, keys, fill_value=None)
137137
return result_values
138138

139139

140-
def apply_dataset(args, func, signature=None, join='inner', fill_value=None,
141-
kwargs=None, new_coords=None, result_attrs=None):
140+
def apply_dataset_ufunc(args, func, signature=None, join='inner',
141+
fill_value=None, kwargs=None, new_coords=None,
142+
result_attrs=None):
142143
if kwargs is None:
143144
kwargs = {}
144145

@@ -194,6 +195,65 @@ def make_dataset(data_vars, coord_vars, attrs):
194195
return make_dataset(data_vars, coord_vars, attrs)
195196

196197

198+
199+
200+
def _iter_over_selections(obj, dim, values):
201+
"""Iterate over selections of an xarray object in the provided order.
202+
"""
203+
dummy = None
204+
for value in values:
205+
try:
206+
obj_sel = obj.sel(**{dim: values})
207+
except KeyError:
208+
if dim not in obj.dims:
209+
raise ValueError('incompatible dimensions for a grouped '
210+
'binary operation: the group variable %r '
211+
'is not a dimension on the other argument'
212+
% dim)
213+
if dummy is None:
214+
dummy = _dummy_copy(obj)
215+
obj_sel = dummy
216+
yield obj_sel
217+
218+
219+
def apply_groupby_ufunc(args, func):
    """Apply ``func`` group-by-group over ``args`` and combine the results.

    Parameters
    ----------
    args : sequence
        Arguments for ``func``. At least one must be a ``GroupBy``; all
        ``GroupBy`` arguments must have equal ``unique_coord``.
    func : callable
        Called once per group as ``func(*group_args)``, with one positional
        argument per element of ``args``.

    Returns
    -------
    The per-group results merged by the first ``GroupBy``'s combine method.
    If ``func`` returns tuples, a tuple is returned with each output
    position combined separately.

    Raises
    ------
    ValueError
        If ``args`` contains no ``GroupBy``, if the ``GroupBy`` arguments do
        not share the same unique coordinate, or if a ``Variable`` argument
        has the grouped dimension among its ``dims``.
    """
    groupbys = [arg for arg in args if isinstance(arg, GroupBy)]
    if not groupbys:
        raise ValueError('must have at least one groupby to iterate over')
    first_groupby = groupbys[0]
    if any(not first_groupby.unique_coord.equals(gb.unique_coord)
           for gb in groupbys[1:]):
        raise ValueError('can only perform operations over multiple groupbys '
                         'at once if they have all the same unique coordinate')

    grouped_dim = first_groupby.group.name
    unique_values = first_groupby.unique_coord.values

    # Build one iterator per argument so they can be advanced in lockstep,
    # one step per group.
    iterators = []
    for arg in args:
        if isinstance(arg, GroupBy):
            # Iterating a GroupBy yields pairs; only the second element (the
            # grouped object) is passed on to ``func``.
            iterator = (value for _, value in arg)
        elif hasattr(arg, 'dims') and grouped_dim in arg.dims:
            if isinstance(arg, Variable):
                raise ValueError(
                    'groupby operations cannot be performed with '
                    'xarray.Variable objects that share a dimension with '
                    'the grouped dimension')
            iterator = _iter_over_selections(arg, grouped_dim, unique_values)
        else:
            # Arguments without the grouped dimension are passed unchanged
            # to every group.
            iterator = itertools.repeat(arg)
        iterators.append(iterator)

    # ``zip(*iterators)`` advances every argument iterator together; zipping
    # the list itself would only yield 1-tuples holding the iterators.
    applied = (func(*zipped_args) for zipped_args in zip(*iterators))
    applied_example, applied = peek_at(applied)
    # NOTE(review): the attribute name ``_combined`` is kept as written;
    # confirm it matches the combine method actually defined on GroupBy.
    combine = first_groupby._combined
    if isinstance(applied_example, tuple):
        # Multiple outputs: transpose the per-group tuples and combine each
        # output position on its own.
        combined = tuple(combine(output) for output in zip(*applied))
    else:
        combined = combine(applied)
    return combined
255+
256+
197257
def _calculate_unified_dim_sizes(variables):
198258
dim_sizes = OrderedDict()
199259

@@ -340,11 +400,19 @@ def apply_ufunc(args, func=None, signature=None, join='inner',
340400
apply_variable_ufunc, func=func, dask_array=dask_array,
341401
combine_attrs=combine_variable_attrs, kwargs=kwargs)
342402

343-
if any(is_dict_like(a) for a in args):
344-
return apply_dataset(args, variables_ufunc, join=join,
345-
combine_attrs=combine_dataset_attrs)
403+
if any(isinstance(a, GroupBy) for a in args):
404+
partial_apply_ufunc = functools.partial(
405+
apply_ufunc, func=func, signature=signature, join=join,
406+
dask_array=dask_array, kwargs=kwargs,
407+
combine_dataset_attrs=combine_dataset_attrs,
408+
combine_variable_attrs=combine_variable_attrs,
409+
dtype=None)
410+
return apply_groupby_ufunc(args, partial_apply_ufunc)
411+
elif any(is_dict_like(a) for a in args):
412+
return apply_dataset_ufunc(args, variables_ufunc, join=join,
413+
combine_attrs=combine_dataset_attrs)
346414
elif any(isinstance(a, DataArray) for a in args):
347-
return apply_dataarray(args, variables_ufunc, join=join)
415+
return apply_dataarray_ufunc(args, variables_ufunc, join=join)
348416
elif any(isinstance(a, Variable) for a in args):
349417
return variables_ufunc(args)
350418
elif dask_array == 'auto' and any(

xarray/core/groupby.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -69,6 +69,7 @@ def _dummy_copy(xarray_obj):
6969
raise AssertionError
7070
return res
7171

72+
7273
def _is_one_or_none(obj):
7374
return obj == 1 or obj is None
7475

0 commit comments

Comments
 (0)