Skip to content

Commit 879ee32

Browse files
committed
Tests and fixes for binary ops
1 parent 4dfca28 commit 879ee32

File tree

2 files changed

+59
-17
lines changed

2 files changed

+59
-17
lines changed

xarray/core/computation.py

Lines changed: 10 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -177,7 +177,7 @@ def apply_dataarray_ufunc(func, *args, **kwargs):
177177
name = result_name(args)
178178
result_coords = build_output_coords(args, signature, new_coords)
179179

180-
data_vars = [getattr(a, 'variable') for a in args]
180+
data_vars = [getattr(a, 'variable', a) for a in args]
181181
result_var = func(*data_vars)
182182

183183
if signature.n_outputs > 1:
@@ -196,16 +196,11 @@ def join_dict_keys(objects, how='inner'):
196196

197197

198198
def collect_dict_values(objects, keys, fill_value=None):
199-
result_values = []
200-
for key in keys:
201-
values = []
202-
for obj in objects:
203-
if hasattr(obj, 'keys'):
204-
values.append(obj.get(key, fill_value))
205-
else:
206-
values = tobj
207-
result_values.append(values)
208-
return result_values
199+
return [[obj.get(key, fill_value)
200+
if is_dict_like(obj)
201+
else obj
202+
for obj in objects]
203+
for key in keys]
209204

210205

211206
def apply_dataset_ufunc(func, *args, **kwargs):
@@ -220,8 +215,8 @@ def apply_dataset_ufunc(func, args, signature=None, join='inner',
220215
fill_value = kwargs.pop('fill_value', None)
221216
new_coords = kwargs.pop('new_coords', None)
222217
if kwargs:
223-
raise TypeError('apply_dataarray_ufunc() got unexpected keyword arguments: %s'
224-
% list(kwargs))
218+
raise TypeError('apply_dataarray_ufunc() got unexpected keyword '
219+
'arguments: %s' % list(kwargs))
225220

226221
if signature is None:
227222
signature = _default_signature(len(args))
@@ -230,10 +225,10 @@ def apply_dataset_ufunc(func, args, signature=None, join='inner',
230225

231226
list_of_coords = build_output_coords(args, signature, new_coords)
232227

233-
list_of_data_vars = [getattr(a, 'data_vars', {}) for a in args]
228+
list_of_data_vars = [getattr(a, 'data_vars', a) for a in args]
234229
names = join_dict_keys(list_of_data_vars, how=join)
235230

236-
list_of_variables = [getattr(a, 'variables', {}) for a in args]
231+
list_of_variables = [getattr(a, 'variables', a) for a in args]
237232
lists_of_args = collect_dict_values(list_of_variables, names, fill_value)
238233

239234
result_vars = OrderedDict()

xarray/test/test_computation.py

Lines changed: 49 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
from collections import OrderedDict
2+
import operator
23

34
import numpy as np
45
import pytest
@@ -47,8 +48,8 @@ def test_join_dict_keys():
4748

4849

4950
def test_collect_dict_values():
50-
dicts = [{'x': 1, 'y': 2, 'z': 3}, {'z': 4}]
51-
expected = [[1, 0], [2, 0], [3, 4]]
51+
dicts = [{'x': 1, 'y': 2, 'z': 3}, {'z': 4}, 5]
52+
expected = [[1, 0, 5], [2, 0, 5], [3, 4, 5]]
5253
collected = collect_dict_values(dicts, ['x', 'y', 'z'], fill_value=0)
5354
assert collected == expected
5455

@@ -74,6 +75,52 @@ def test_apply_ufunc_identity():
7475
assert_identical(output, dataset)
7576

7677

78+
def test_apply_ufunc_two_inputs():
79+
array = np.array([1, 2, 3])
80+
variable = xr.Variable('x', array)
81+
data_array = xr.DataArray(variable, [('x', -array)])
82+
dataset = xr.Dataset({'y': variable}, {'x': -array})
83+
84+
zeros_array = np.zeros_like(array)
85+
zeros_variable = xr.Variable('x', zeros_array)
86+
zeros_data_array = xr.DataArray(zeros_variable, [('x', -array)])
87+
zeros_dataset = xr.Dataset({'y': zeros_variable}, {'x': -array})
88+
89+
add = lambda a, b: xr.apply_ufunc(operator.add, a, b)
90+
91+
assert_array_equal(array, add(array, 0))
92+
assert_array_equal(array, add(array, zeros_array))
93+
assert_array_equal(array, add(0, array))
94+
assert_array_equal(array, add(zeros_array, array))
95+
96+
assert_identical(variable, add(variable, 0))
97+
assert_identical(variable, add(variable, zeros_array))
98+
assert_identical(variable, add(variable, zeros_variable))
99+
assert_identical(variable, add(0, variable))
100+
assert_identical(variable, add(zeros_array, variable))
101+
assert_identical(variable, add(zeros_variable, variable))
102+
103+
assert_identical(data_array, add(data_array, 0))
104+
assert_identical(data_array, add(data_array, zeros_array))
105+
assert_identical(data_array, add(data_array, zeros_variable))
106+
assert_identical(data_array, add(data_array, zeros_data_array))
107+
assert_identical(data_array, add(0, data_array))
108+
assert_identical(data_array, add(zeros_array, data_array))
109+
assert_identical(data_array, add(zeros_variable, data_array))
110+
assert_identical(data_array, add(zeros_data_array, data_array))
111+
112+
assert_identical(dataset, add(dataset, 0))
113+
assert_identical(dataset, add(dataset, zeros_array))
114+
assert_identical(dataset, add(dataset, zeros_variable))
115+
assert_identical(dataset, add(dataset, zeros_data_array))
116+
assert_identical(dataset, add(dataset, zeros_dataset))
117+
assert_identical(dataset, add(0, dataset))
118+
assert_identical(dataset, add(zeros_array, dataset))
119+
assert_identical(dataset, add(zeros_variable, dataset))
120+
assert_identical(dataset, add(zeros_data_array, dataset))
121+
assert_identical(dataset, add(zeros_dataset, dataset))
122+
123+
77124
def test_apply_ufunc_two_outputs():
78125
array = np.arange(10)
79126
variable = xr.Variable('x', array)

0 commit comments

Comments
 (0)