Skip to content

Commit 7901c0c

Browse files
committed
Fall back sel -> isel if no index, add/fix some tests
1 parent 9ae7ddf commit 7901c0c

File tree

8 files changed

+133
-57
lines changed

8 files changed

+133
-57
lines changed

doc/whats-new.rst

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -95,6 +95,10 @@ Bug fixes
9595
- ``.where()`` and ``.fillna()`` now preserve attributes (:issue:`1009`).
9696
By `Fabien Maussion <https://github.com/fmaussion>`_.
9797

98+
- Fixed accessing coordinate variables with non-string names from ``.coords``
99+
(:issue:`TBD`).
100+
By `Stephan Hoyer <https://github.com/shoyer>`_.
101+
98102
.. _whats-new.0.8.2:
99103

100104
v0.8.2 (18 August 2016)

xarray/core/alignment.py

Lines changed: 30 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -73,6 +73,12 @@ def align(*objects, **kwargs):
7373
-------
7474
aligned : same as *objects
7575
Tuple of objects with aligned coordinates.
76+
77+
Raises
78+
------
79+
ValueError
80+
If any dimensions without labels on the arguments have different sizes,
81+
or a different size than the size of the aligned dimension labels.
7682
"""
7783
join = kwargs.pop('join', 'inner')
7884
copy = kwargs.pop('copy', True)
@@ -87,19 +93,19 @@ def align(*objects, **kwargs):
8793
% list(kwargs))
8894

8995
all_indexes = defaultdict(list)
90-
unlabeled_dim_sizes = defaultdict(list)
96+
unlabeled_dim_sizes = defaultdict(set)
9197
for obj in objects:
9298
for dim in obj.dims:
9399
if dim not in exclude:
94100
try:
95101
index = obj.indexes[dim]
96102
except KeyError:
97103
size = get_size(obj, dim)
98-
unlabeled_dim_sizes[dim].append(size)
104+
unlabeled_dim_sizes[dim].add(size)
99105
else:
100106
all_indexes[dim].append(index)
101107

102-
# We don't join over dimensions with all equal indexes for two reasons:
108+
# We don't reindex over dimensions with all equal indexes for two reasons:
103109
# - It's faster for the usual case (already aligned objects).
104110
# - It ensures it's possible to do operations that don't require alignment
105111
# on indexes with duplicate values (which cannot be reindexed with
@@ -114,20 +120,34 @@ def align(*objects, **kwargs):
114120
else:
115121
if any(not matching_indexes[0].equals(other)
116122
for other in matching_indexes[1:]):
117-
joined_indexes[name] = joiner(matching_indexes)
123+
index = joiner(matching_indexes)
124+
joined_indexes[name] = index
125+
else:
126+
index = matching_indexes[0]
127+
128+
if dim in unlabeled_dim_sizes:
129+
unlabeled_sizes = unlabeled_dim_sizes[dim]
130+
labeled_size = index.size
131+
if len(unlabeled_sizes | {labeled_size}) > 1:
132+
raise ValueError(
133+
'arguments without labels along dimension %r cannot be '
134+
'aligned because they have different dimension size(s) %r '
135+
'than the size of the aligned dimension labels: %r'
136+
% (dim, unlabeled_sizes, labeled_size))
118137

119138
for dim in unlabeled_dim_sizes:
120-
if dim not in joined_indexes:
121-
sizes = set(unlabeled_dim_sizes[dim])
139+
if dim not in all_indexes:
140+
sizes = unlabeled_dim_sizes[dim]
122141
if len(sizes) > 1:
123142
raise ValueError(
124-
'dimension %r without indexes cannot be aligned because '
125-
'it has different sizes: %r' % (dim, sizes))
143+
'arguments without labels along dimension %r cannot be '
144+
'aligned because they have different dimension sizes: %r'
145+
% (dim, sizes))
126146

127147
result = []
128148
for obj in objects:
129-
valid_indexers = dict((k, v) for k, v in joined_indexes.items()
130-
if k in obj.dims)
149+
valid_indexers = {k: v for k, v in joined_indexes.items()
150+
if k in obj.dims}
131151
result.append(obj.reindex(copy=copy, **valid_indexers))
132152

133153
return tuple(result)
@@ -298,11 +318,6 @@ def broadcast(*args, **kwargs):
298318
The same data as the input arrays, but with additional dimensions
299319
inserted so that all data arrays have the same dimensions and shape.
300320
301-
Raises
302-
------
303-
ValueError
304-
If indexes on the different objects are not aligned.
305-
306321
Examples
307322
--------
308323

xarray/core/coordinates.py

Lines changed: 9 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -11,13 +11,7 @@
1111

1212
class AbstractCoordinates(Mapping, formatting.ReprMixin):
1313
def __getitem__(self, key):
14-
if (key in self._names or
15-
(isinstance(key, basestring) and
16-
key.split('.')[0] in self._names)):
17-
# allow indexing current coordinates or components
18-
return self._data[key]
19-
else:
20-
raise KeyError(key)
14+
raise NotImplementedError
2115

2216
def __setitem__(self, key, value):
2317
self.update({key: value})
@@ -143,6 +137,11 @@ def variables(self):
143137
for k, v in self._data.variables.items()
144138
if k in self._names))
145139

140+
def __getitem__(self, key):
141+
if key in self._data.data_vars:
142+
raise KeyError(key)
143+
return self._data[key]
144+
146145
def to_dataset(self):
147146
"""Convert these coordinates into a new Dataset
148147
"""
@@ -188,6 +187,9 @@ def __init__(self, dataarray):
188187
def _names(self):
189188
return set(self._data._coords)
190189

190+
def __getitem__(self, key):
191+
return self._data._getitem_coord(key)
192+
191193
def _update_coords(self, coords):
192194
from .dataset import calculate_dimensions
193195

xarray/core/dataarray.py

Lines changed: 13 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -439,17 +439,21 @@ def _level_coords(self):
439439
level_coords.update({lname: dim for lname in level_names})
440440
return level_coords
441441

442-
def __getitem__(self, key):
443-
if isinstance(key, basestring):
444-
from .dataset import _get_virtual_variable
442+
def _getitem_coord(self, key):
443+
from .dataset import _get_virtual_variable
445444

446-
try:
447-
var = self._coords[key]
448-
except KeyError:
449-
_, key, var = _get_virtual_variable(
450-
self._coords, key, self._level_coords)
445+
try:
446+
var = self._coords[key]
447+
except KeyError:
448+
dim_sizes = dict(zip(self.dims, self.shape))
449+
_, key, var = _get_virtual_variable(
450+
self._coords, key, self._level_coords, dim_sizes)
451+
452+
return self._replace_maybe_drop_dims(var, name=key)
451453

452-
return self._replace_maybe_drop_dims(var, name=key)
454+
def __getitem__(self, key):
455+
if isinstance(key, basestring):
456+
return self._getitem_coord(key)
453457
else:
454458
# orthogonal array indexing
455459
return self.isel(**self._item_key_to_dict(key))

xarray/core/dataset.py

Lines changed: 13 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -33,10 +33,19 @@
3333
'quarter']
3434

3535

36-
def _get_virtual_variable(variables, key, level_vars={}):
36+
def _get_virtual_variable(variables, key, level_vars=None, dim_sizes=None):
3737
"""Get a virtual variable (e.g., 'time.year' or a MultiIndex level)
3838
from a dict of xarray.Variable objects (if possible)
3939
"""
40+
if level_vars is None:
41+
level_vars = {}
42+
if dim_sizes is None:
43+
dim_sizes = {}
44+
45+
if key in dim_sizes:
46+
variable = Variable((key,), np.arange(dim_sizes[key]))
47+
return key, key, variable
48+
4049
if not isinstance(key, basestring):
4150
raise KeyError(key)
4251

@@ -452,9 +461,9 @@ def _copy_listed(self, names):
452461
variables[name] = self._variables[name]
453462
except KeyError:
454463
ref_name, var_name, var = _get_virtual_variable(
455-
self._variables, name, self._level_coords)
464+
self._variables, name, self._level_coords, self.dims)
456465
variables[var_name] = var
457-
if ref_name in self._coord_names:
466+
if ref_name in self._coord_names or ref_name in self.dims:
458467
coord_names.add(var_name)
459468

460469
return self._subset_with_all_valid_coords(variables, coord_names,
@@ -469,7 +478,7 @@ def _construct_dataarray(self, name):
469478
variable = self._variables[name]
470479
except KeyError:
471480
_, name, variable = _get_virtual_variable(
472-
self._variables, name, self._level_coords)
481+
self._variables, name, self._level_coords, self.dims)
473482

474483
coords = OrderedDict()
475484
needed_dims = set(variable.dims)

xarray/core/indexing.py

Lines changed: 16 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -266,14 +266,22 @@ def remap_label_indexers(data_obj, indexers, method=None, tolerance=None):
266266
if method is not None and not isinstance(method, str):
267267
raise TypeError('``method`` must be a string')
268268

269-
pos_indexers, new_indexes = {}, {}
270-
for dim, label in iteritems(get_dim_indexers(data_obj, indexers)):
271-
index = data_obj[dim].to_index()
272-
idxr, new_idx = convert_label_indexer(index, label,
273-
dim, method, tolerance)
274-
pos_indexers[dim] = idxr
275-
if new_idx is not None:
276-
new_indexes[dim] = new_idx
269+
pos_indexers = {}
270+
new_indexes = {}
271+
272+
dim_indexers = get_dim_indexers(data_obj, indexers)
273+
for dim, label in iteritems(dim_indexers):
274+
try:
275+
index_coord = data_obj[dim]
276+
except KeyError:
277+
pos_indexers[dim] = label
278+
else:
279+
index = index_coord.to_index()
280+
idxr, new_idx = convert_label_indexer(index, label,
281+
dim, method, tolerance)
282+
pos_indexers[dim] = idxr
283+
if new_idx is not None:
284+
new_indexes[dim] = new_idx
277285

278286
return pos_indexers, new_indexes
279287

xarray/test/test_dataarray.py

Lines changed: 32 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -11,8 +11,8 @@
1111
from xarray.core.pycompat import iteritems, OrderedDict
1212
from xarray.core.common import _full_like
1313

14-
from xarray.test import (TestCase, ReturnItem, source_ndarray, unittest, requires_dask,
15-
requires_bottleneck)
14+
from xarray.test import (TestCase, ReturnItem, source_ndarray, unittest,
15+
requires_dask, requires_bottleneck)
1616

1717

1818
class TestDataArray(TestCase):
@@ -609,7 +609,13 @@ def test_sel(lab_indexer, pos_indexer, replaced_idx=False,
609609
self.assertDataArrayIdentical(mdata.sel(x={'one': 'a', 'two': 1}),
610610
mdata.sel(one='a', two=1))
611611

612-
def test_time_components(self):
612+
def test_virtual_default_coords(self):
613+
array = DataArray(np.zeros((5,)), dims='x')
614+
expected = DataArray(range(5), dims='x', name='x')
615+
self.assertDataArrayIdentical(expected, array['x'])
616+
self.assertDataArrayIdentical(expected, array.coords['x'])
617+
618+
def test_virtual_time_components(self):
613619
dates = pd.date_range('2000-01-01', periods=10)
614620
da = DataArray(np.arange(1, 11), [('time', dates)])
615621

@@ -743,7 +749,9 @@ def test_coords_alignment(self):
743749
rhs = DataArray([2, 3, 4], [('x', [1, 2, 3])])
744750
lhs.coords['rhs'] = rhs
745751

746-
expected = DataArray([1, 2, 3], coords={'rhs': ('x', [np.nan, 2, 3])},
752+
expected = DataArray([1, 2, 3],
753+
coords={'rhs': ('x', [np.nan, 2, 3]),
754+
'x': [0, 1, 2]},
747755
dims='x')
748756
self.assertDataArrayIdentical(lhs, expected)
749757

@@ -755,6 +763,12 @@ def test_coords_replacement_alignment(self):
755763
expected = DataArray([0, 1, 2], coords=[('abc', [1, 2, 3])])
756764
self.assertDataArrayIdentical(arr, expected)
757765

766+
def test_coords_non_string(self):
767+
arr = DataArray(0, coords={1: 2})
768+
actual = arr.coords[1]
769+
expected = DataArray(2, coords={1: 2}, name=1)
770+
self.assertDataArrayIdentical(actual, expected)
771+
758772
def test_reindex(self):
759773
foo = self.dv
760774
bar = self.dv[:2, :2]
@@ -1649,10 +1663,11 @@ def test_resample_upsampling(self):
16491663
self.assertDataArrayIdentical(expected, actual)
16501664

16511665
def test_align(self):
1652-
self.ds['x'] = ('x', np.array(list('abcdefghij')))
1653-
dv1, dv2 = align(self.dv, self.dv[:5], join='inner')
1654-
self.assertDataArrayIdentical(dv1, self.dv[:5])
1655-
self.assertDataArrayIdentical(dv2, self.dv[:5])
1666+
array = DataArray(np.random.random((6, 8)),
1667+
coords={'x': list('abcdef')}, dims=['x', 'y'])
1668+
array1, array2 = align(array, array[:5], join='inner')
1669+
self.assertDataArrayIdentical(array1, array[:5])
1670+
self.assertDataArrayIdentical(array2, array[:5])
16561671

16571672
def test_align_dtype(self):
16581673
# regression test for #264
@@ -1723,6 +1738,15 @@ def test_align_indexes(self):
17231738
coords=[('a', [-2, 7, 10, -1])])
17241739
self.assertDataArrayIdentical(expected_x2, x2)
17251740

1741+
def test_align_without_indexes_errors(self):
1742+
with self.assertRaisesRegexp(ValueError, 'cannot be aligned'):
1743+
align(DataArray([1, 2, 3], dims=['x']),
1744+
DataArray([1, 2], dims=['x']))
1745+
1746+
with self.assertRaisesRegexp(ValueError, 'cannot be aligned'):
1747+
align(DataArray([1, 2, 3], dims=['x']),
1748+
DataArray([1, 2], coords=[('x', [0, 1])]))
1749+
17261750
def test_broadcast_arrays(self):
17271751
x = DataArray([1, 2], coords=[('a', [-1, -2])], name='x')
17281752
y = DataArray([1, 2], coords=[('b', [3, 4])], name='y')

xarray/test/test_dataset.py

Lines changed: 16 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -33,7 +33,6 @@ def create_test_data(seed=None):
3333

3434
obj = Dataset()
3535
obj['time'] = ('time', pd.date_range('2000-01-01', periods=20))
36-
obj['dim1'] = ('dim1', np.arange(_dims['dim1'], dtype='int64'))
3736
obj['dim2'] = ('dim2', 0.5 * np.arange(_dims['dim2']))
3837
obj['dim3'] = ('dim3', list('abcdefghij'))
3938
for v, dims in sorted(_vars.items()):
@@ -70,7 +69,6 @@ def test_repr(self):
7069
Dimensions: (dim1: 8, dim2: 9, dim3: 10, time: 20)
7170
Coordinates:
7271
* time (time) datetime64[ns] 2000-01-01 2000-01-02 2000-01-03 ...
73-
* dim1 (dim1) int64 0 1 2 3 4 5 6 7
7472
* dim2 (dim2) float64 0.0 0.5 1.0 1.5 2.0 2.5 3.0 3.5 4.0
7573
* dim3 (dim3) %s 'a' 'b' 'c' 'd' 'e' 'f' 'g' 'h' 'i' 'j'
7674
numbers (dim3) int64 0 1 2 0 0 1 1 2 2 3
@@ -566,10 +564,12 @@ def test_coords_to_dataset(self):
566564
self.assertDatasetIdentical(expected, actual)
567565

568566
def test_coords_merge(self):
569-
orig_coords = Dataset(coords={'a': ('x', [1, 2])}).coords
570-
other_coords = Dataset(coords={'b': ('x', ['a', 'b'])}).coords
567+
orig_coords = Dataset(coords={'a': ('x', [1, 2]), 'x': [0, 1]}).coords
568+
other_coords = Dataset(coords={'b': ('x', ['a', 'b']),
569+
'x': [0, 1]}).coords
571570
expected = Dataset(coords={'a': ('x', [1, 2]),
572-
'b': ('x', ['a', 'b'])})
571+
'b': ('x', ['a', 'b']),
572+
'x': [0, 1]})
573573
actual = orig_coords.merge(other_coords)
574574
self.assertDatasetIdentical(expected, actual)
575575
actual = other_coords.merge(orig_coords)
@@ -1495,7 +1495,17 @@ def test_getitem_hashable(self):
14951495
with self.assertRaisesRegexp(KeyError, "('var1', 'var2')"):
14961496
data[('var1', 'var2')]
14971497

1498-
def test_virtual_variables(self):
1498+
def test_virtual_variables_default_coords(self):
1499+
dataset = Dataset({'foo': ('x', range(10))})
1500+
expected = DataArray(range(10), dims='x', name='x')
1501+
actual = dataset['x']
1502+
self.assertDataArrayIdentical(expected, actual)
1503+
1504+
actual = dataset[['x', 'foo']]
1505+
expected = dataset.assign_coords(x=range(10))
1506+
self.assertDatasetIdentical(expected, actual)
1507+
1508+
def test_virtual_variables_time(self):
14991509
# access virtual variables
15001510
data = create_test_data()
15011511
expected = DataArray(1 + np.arange(20), coords=[data['time']],

0 commit comments

Comments
 (0)