Skip to content

Commit b477e01

Browse files
committed
tests for core dimensions
1 parent 87e6067 commit b477e01

File tree

2 files changed

+115
-72
lines changed

2 files changed

+115
-72
lines changed

xarray/core/computation.py

Lines changed: 36 additions & 72 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
1+
import collections
12
import functools
23
import itertools
34
import re
@@ -138,12 +139,10 @@ def _default_result_attrs(attrs, func, signature):
138139

139140
def _build_output_coords(args, signature, new_coords=None):
140141

141-
def get_coord_variables(arg):
142-
return getattr(getattr(arg, 'coords', {}), 'variables', {})
143-
144-
coord_variables = [get_coord_variables(a) for a in args]
142+
coord_variables = [getattr(getattr(arg, 'coords', arg), 'variables', {})
143+
for arg in args]
145144
if new_coords is not None:
146-
coord_variables.append(get_coord_variables(new_coords))
145+
coord_variables.append(new_coords)
147146

148147
merged = merge_coords_without_align(coord_variables)
149148

@@ -176,19 +175,17 @@ def apply_dataarray_ufunc(func, *args, **kwargs):
176175
args = deep_align(args, join=join, copy=False, raise_on_invalid=False)
177176

178177
name = result_name(args)
179-
list_of_coords = _build_output_coords(args, signature, new_coords)
178+
result_coords = _build_output_coords(args, signature, new_coords)
180179

181180
data_vars = [getattr(a, 'variable') for a in args]
182-
variable_or_variables = func(*data_vars)
181+
result_var = func(*data_vars)
183182

184183
if signature.n_outputs > 1:
185-
return tuple(DataArray(variable, coords, name=name, fastpath=True)
186-
for variable, coords in zip(
187-
variable_or_variables, list_of_coords))
184+
return tuple(DataArray(variable, coords, name=name)
185+
for variable, coords in zip(result_var, result_coords))
188186
else:
189-
variable = variable_or_variables
190-
coords, = list_of_coords
191-
return DataArray(variable, coords, name=name, fastpath=True)
187+
coords, = result_coords
188+
return DataArray(result_var, coords, name=name)
192189

193190

194191
def join_dict_keys(objects, how='inner'):
@@ -243,14 +240,6 @@ def apply_dataset_ufunc(func, args, signature=None, join='inner',
243240
for name, variable_args in zip(names, lists_of_args):
244241
result_vars[name] = func(*variable_args)
245242

246-
def make_dataset(data_vars, coord_vars):
247-
# Normally, we would copy data_vars to be safe, but we created the
248-
# OrderedDict in this function and don't use it for anything else.
249-
variables = data_vars
250-
variables.update(coord_vars)
251-
coord_names = set(coord_vars)
252-
return Dataset._from_vars_and_coord_names(variables, coord_names)
253-
254243
if signature.n_outputs > 1:
255244
# we need to unpack result_vars from Dict[object, Tuple[Variable]] ->
256245
# Tuple[Dict[object, Variable]].
@@ -259,12 +248,12 @@ def make_dataset(data_vars, coord_vars):
259248
for value, results_dict in zip(values, result_dict_list):
260249
results_dict[name] = value
261250

262-
return tuple(make_dataset(*args)
251+
return tuple(Dataset(*args)
263252
for args in zip(result_dict_list, list_of_coords))
264253
else:
265254
data_vars = result_vars
266255
coord_vars, = list_of_coords
267-
return make_dataset(data_vars, coord_vars)
256+
return Dataset(data_vars, coord_vars)
268257

269258

270259
def _iter_over_selections(obj, dim, values):
@@ -383,47 +372,6 @@ def broadcast_compat_data(variable, broadcast_dims, core_dims):
383372
return data
384373

385374

386-
def _deep_unpack_list(arg):
387-
if isinstance(arg, list):
388-
arg, = arg
389-
return _deep_unpack_list(arg)
390-
return arg
391-
392-
393-
def _apply_with_dask_atop(func, args, signature, kwargs, dtype):
394-
395-
if signature.all_input_core_dims or signature.all_output_core_dims:
396-
raise ValueError("cannot use dask_array='auto' on unlabeled dask "
397-
'arrays with a function signature that uses core '
398-
'dimensions')
399-
return da.elemwise(func, *args, dtype=dtype, **kwargs)
400-
401-
# import toolz # required dependency of dask.array
402-
403-
# if len(signature.output_core_dims) > 1:
404-
# raise ValueError('cannot create use dask.array.atop for '
405-
# 'multiple outputs')
406-
# if signature.all_output_core_dims - signature.all_input_core_dims:
407-
# raise ValueError('cannot create new dimensions in dask.array.atop')
408-
409-
# input_dims = [broadcast_dims + inp for inp in signature.input_core_dims]
410-
# dropped = signature.all_input_core_dims - signature.all_output_core_dims
411-
# for data, dims in zip(args, input_dims):
412-
# if isinstance(data, dask_array_type):
413-
# for dropped_dim in dropped:
414-
# if (dropped_dim in dims and
415-
# len(data.chunks[dims.index(dropped_dim)]) != 1):
416-
# raise ValueError(
417-
# 'dimension %r dropped in the output does not '
418-
# 'consist of exactly one chunk on all arrays '
419-
# 'in the inputs' % dropped_dim)
420-
421-
# out_ind, = output_dims
422-
# atop_args = [ai for a in (args, input_dims) for ai in a]
423-
# func2 = toolz.functools.compose(func, _deep_unpack_list)
424-
# result_data = da.atop(func2, out_ind, *atop_args, dtype=dtype, **kwargs)
425-
426-
427375
def apply_variable_ufunc(func, *args, **kwargs):
428376
"""
429377
def apply_variable_ufunc(func, args, signature=None, dask_array='forbidden',
@@ -465,8 +413,8 @@ def apply_variable_ufunc(func, args, signature=None, dask_array='forbidden',
465413
if dask_array == 'forbidden' and contains_dask:
466414
raise ValueError('encountered dask array')
467415
elif dask_array == 'auto' and contains_dask:
468-
result_data = _apply_with_dask_atop(func, list_of_input_data, signature,
469-
kwargs_, dask_dtype)
416+
result_data = apply_dask_array(func, *args, signature=signature,
417+
kwargs=kwargs, dtype=dask_dtype)
470418
else:
471419
result_data = func(*list_of_input_data, **kwargs_)
472420

@@ -481,6 +429,26 @@ def apply_variable_ufunc(func, args, signature=None, dask_array='forbidden',
481429
return Variable(dims, data)
482430

483431

432+
def apply_dask_ufunc(func, *args, **kwargs):
    """Apply an elementwise function to unlabeled arrays through dask.

    Parameters
    ----------
    func : callable
        Function handed to ``dask.array.elemwise``.
    *args
        Positional operands forwarded unchanged to ``dask.array.elemwise``.
    signature
        Required keyword. Object exposing ``n_outputs``,
        ``all_input_core_dims`` and ``all_output_core_dims``.
    kwargs : dict, optional
        Keyword arguments bound to ``func`` before it is applied. ``None``
        (the default) is treated as "no keyword arguments".
    dtype : optional
        Forwarded to ``dask.array.elemwise`` as ``dtype``.

    Returns
    -------
    The value returned by ``dask.array.elemwise``.

    Raises
    ------
    KeyError
        If ``signature`` is not supplied.
    ValueError
        If the signature declares more than one output, or declares any
        input or output core dimensions.
    """
    # NOTE(review): the call sites visible in this change invoke
    # ``apply_dask_array`` with these same keywords -- confirm which name is
    # intended so that the definition and its callers agree.
    signature = kwargs.pop('signature')
    # ``kwargs=None`` is the documented default upstream; normalise it so the
    # ``**`` expansion below never receives ``None``.
    func_kwargs = kwargs.pop('kwargs', None) or {}
    dtype = kwargs.pop('dtype', None)

    if signature.n_outputs != 1:
        raise ValueError("cannot use dask_array='auto' with functions that "
                         'return multiple values')

    if signature.all_input_core_dims or signature.all_output_core_dims:
        raise ValueError("cannot use dask_array='auto' on unlabeled dask "
                         'arrays with a function signature that uses core '
                         'dimensions')

    # Imported only after validation so that signature errors are reported
    # even when dask is not installed.
    import dask.array as da

    bound_func = functools.partial(func, **func_kwargs)
    return da.elemwise(bound_func, *args, dtype=dtype)
450+
451+
484452
def apply_ufunc(func, *args, **kwargs):
485453
"""apply_ufunc(func, *args, signature=None, join='inner', new_coords=None,
486454
kwargs=None, dask_array='forbidden', dask_dtype=None)
@@ -556,12 +524,8 @@ def apply_ufunc(func, *args, **kwargs):
556524
return variables_ufunc(*args)
557525
elif dask_array == 'auto' and any(
558526
isinstance(arg, dask_array_type) for arg in args):
559-
import dask.array as da
560-
if signature.all_input_core_dims or signature.all_output_core_dims:
561-
raise ValueError("cannot use dask_array='auto' on unlabeled dask "
562-
'arrays with a function signature that uses core '
563-
'dimensions')
564-
return da.elemwise(func, *args, dtype=dask_dtype)
527+
return apply_dask_array(func, *args, signature=signature, kwargs=kwargs,
528+
dtype=dask_dtype)
565529
else:
566530
return func(*args)
567531

xarray/test/test_computation.py

Lines changed: 79 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -100,6 +100,85 @@ def test_apply_ufunc_two_outputs():
100100
assert_identical(out1, 2 * dataset)
101101

102102

103+
def test_apply_ufunc_input_core_dimension():
    """Check apply_ufunc with one input core dimension on each container.

    The applied function keeps the first element along the trailing axis;
    the expected results below hold the first element along ``dim`` for a
    Variable, a DataArray and a Dataset.
    """

    def first_element(obj, dim):
        # One input with core dimension ``dim``; one output with none.
        signature = ([(dim,)], [(),])
        return xr.apply_ufunc(lambda x: x[..., 0], obj, signature=signature)

    values = np.array([[1, 2], [3, 4]])
    variable = xr.Variable(['x', 'y'], values)
    data_array = xr.DataArray(variable, {'x': ['a', 'b'], 'y': [-1, -2]})
    dataset = xr.Dataset({'data': data_array})

    # Expected results when 'x' is the core dimension.
    expected_variable_x = xr.Variable(['y'], [1, 2])
    expected_data_array_x = xr.DataArray(expected_variable_x,
                                         {'y': [-1, -2]})
    expected_dataset_x = xr.Dataset({'data': expected_data_array_x})

    # Expected results when 'y' is the core dimension.
    expected_variable_y = xr.Variable(['x'], [1, 3])
    expected_data_array_y = xr.DataArray(expected_variable_y,
                                         {'x': ['a', 'b']})
    expected_dataset_y = xr.Dataset({'data': expected_data_array_y})

    # (input object, core dimension, expected result), in the original order.
    cases = [
        (variable, 'x', expected_variable_x),
        (variable, 'y', expected_variable_y),
        (data_array, 'x', expected_data_array_x),
        (data_array, 'y', expected_data_array_y),
        (dataset, 'x', expected_dataset_x),
        (dataset, 'y', expected_dataset_y),
    ]
    for obj, dim, expected in cases:
        actual = first_element(obj, dim)
        assert_identical(actual, expected)
137+
138+
139+
def test_apply_ufunc_output_core_dimension():
    """Check apply_ufunc with a new output core dimension named 'sign'.

    Covers a Variable, a DataArray and a Dataset with explicit
    ``new_coords``, then a DataArray and a Dataset without ``new_coords``,
    where the expected 'sign' coordinate is ``[0, 1]``.
    """

    def stack_negative(obj):
        # No input core dimensions; one output core dimension, 'sign'.
        signature = ([()], [('sign',)])
        sign_coords = {'sign': [1, -1]}
        return xr.apply_ufunc(
            lambda x: xr.core.npcompat.stack([x, -x], axis=-1),
            obj, signature=signature, new_coords=sign_coords)

    values = np.array([[1, 2], [3, 4]])
    variable = xr.Variable(['x', 'y'], values)
    data_array = xr.DataArray(variable, {'x': ['a', 'b'], 'y': [-1, -2]})
    dataset = xr.Dataset({'data': data_array})

    stacked_values = np.array([[[1, -1], [2, -2]], [[3, -3], [4, -4]]])
    expected_variable = xr.Variable(['x', 'y', 'sign'], stacked_values)
    expected_coords = {'x': ['a', 'b'], 'y': [-1, -2], 'sign': [1, -1]}
    expected_data_array = xr.DataArray(expected_variable, expected_coords)
    expected_dataset = xr.Dataset({'data': expected_data_array})

    labelled_cases = [
        (variable, expected_variable),
        (data_array, expected_data_array),
        (dataset, expected_dataset),
    ]
    for obj, expected in labelled_cases:
        actual = stack_negative(obj)
        assert_identical(actual, expected)

    def stack2(obj):
        # Same operation as stack_negative, but no new_coords are passed.
        signature = ([()], [('sign',)])
        return xr.apply_ufunc(
            lambda x: xr.core.npcompat.stack([x, -x], axis=-1),
            obj, signature=signature)

    unlabelled_cases = [
        (data_array, expected_data_array),
        (dataset, expected_dataset),
    ]
    for obj, expected in unlabelled_cases:
        actual = stack2(obj)
        # Without new_coords the expected 'sign' coordinate is [0, 1];
        # the expected object is updated in place after the call, as before.
        expected.coords['sign'] = [0, 1]
        assert_identical(actual, expected)
180+
181+
103182
def test_broadcast_compat_data_1d():
104183
data = np.arange(5)
105184
var = xr.Variable('x', data)

0 commit comments

Comments
 (0)