Skip to content

Commit 5b2b66b

Browse files
committed
FIX: fixed array[slice_group_from_an_incompatible_axis] (fixes #1146 and #1117)
probably also fixed a few edge cases in Axis.index()
1 parent cbfbcb6 commit 5b2b66b

File tree

5 files changed

+143
-52
lines changed

5 files changed

+143
-52
lines changed

doc/source/changes/version_0_35.rst.inc

Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -112,6 +112,28 @@ Miscellaneous improvements
112112
Fixes
113113
^^^^^
114114

115+
* fixed array[slice_group_from_an_incompatible_axis] and
116+
array.sum(slice_group_from_an_incompatible_axis) (closes :issue:`1146`
117+
and :issue:`1117`).
118+
It used to evaluate the slice on the array axis instead of first evaluating
119+
the slice on the axis it was created on, then take the corresponding labels
120+
from the array axis.
121+
122+
>>> arr = ndtest(3)
123+
>>> arr
124+
a a0 a1 a2
125+
0 1 2
126+
>>> other_axis_a = Axis('a=a0,a1')
127+
>>> group = other_axis_a[:]
128+
>>> print(group)
129+
['a0' 'a1']
130+
>>> arr[group] # <-- before
131+
a a0 a1 a2
132+
0 1 2
133+
>>> arr[group] # <-- now
134+
a a0 a1
135+
0 1
136+
115137
* fixed error message when trying to take a subset of an array with an array
116138
key which has ndim > 1 and some bad values in the key. The message was also
117139
improved (see the issue for details). Closes :issue:`1134`.

larray/core/array.py

Lines changed: 14 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -2140,6 +2140,9 @@ def sort_values(self, key=None, axis=None, ascending=True) -> 'Array':
21402140
# FWIW, using .data, I get IGroup([1, 2, 0], axis='nat'), which works.
21412141
sorter = axis.i[indicesofsorted.data]
21422142
res = self[sorter]
2143+
# res has its axis in a different order than the original axis
2144+
# so we need this line to reverse the order below if not ascending
2145+
axis = res.axes[axis]
21432146
else:
21442147
res = self.combine_axes()
21452148
indicesofsorted = np.argsort(res.data)
@@ -2799,22 +2802,27 @@ def _group_aggregate(self, op, items, keepaxes=False, out=None, **kwargs) -> 'Ar
27992802
if isinstance(item, tuple):
28002803
assert all(isinstance(g, Group) for g in item)
28012804
groups = item
2802-
axis = groups[0].axis
2805+
group_axis = groups[0].axis
28032806
# they should all have the same axis (this is already checked
28042807
# in _prepare_aggregate though)
2805-
assert all(g.axis.equals(axis) for g in groups[1:])
2808+
assert all(g.axis.equals(group_axis) for g in groups[1:])
28062809
killaxis = False
28072810
else:
28082811
# item is in fact a single group
28092812
assert isinstance(item, Group), type(item)
28102813
groups = (item,)
2811-
axis = item.axis
2814+
group_axis = item.axis
28122815
# it is easier to kill the axis after the fact
28132816
killaxis = True
28142817

2815-
axis, axis_idx = res.axes[axis], res.axes.index(axis)
2818+
axis_idx = res.axes.index(group_axis)
2819+
res_axis = res.axes[axis_idx]
2820+
assert group_axis.equals(res_axis)
2821+
28162822
# potentially translate axis reference to real axes
2817-
groups = tuple(g.with_axis(axis) for g in groups)
2823+
# with_axis is correct because we already checked
2824+
# that g.axis.equals(axis)
2825+
groups = tuple(g.with_axis(res_axis) for g in groups)
28182826
res_shape[axis_idx] = len(groups)
28192827

28202828
# XXX: this code is fragile. I wonder if there isn't a way to ask the function what kind of dtype/shape it
@@ -2866,7 +2874,7 @@ def _group_aggregate(self, op, items, keepaxes=False, out=None, **kwargs) -> 'Ar
28662874
# We do NOT modify the axis name (eg append "_agg" or "*") even though this creates a new axis that is
28672875
# independent from the original one because the original name is what users will want to use to access
28682876
# that axis (eg in .filter kwargs)
2869-
res_axes[axis_idx] = Axis(groups, axis.name)
2877+
res_axes[axis_idx] = Axis(groups, res_axis.name)
28702878

28712879
if isinstance(res_data, np.ndarray):
28722880
res = Array(res_data, res_axes)

larray/core/axis.py

Lines changed: 72 additions & 44 deletions
Original file line numberDiff line numberDiff line change
@@ -922,32 +922,41 @@ def index(self, key) -> Union[int, np.ndarray, slice]:
922922
"""
923923
mapping = self._mapping
924924

925-
if isinstance(key, Group) and key.axis is not self and key.axis is not None:
926-
try:
927-
# XXX: this is potentially very expensive if key.key is an array or list and should be tried as a last
928-
# resort
929-
potential_tick = _to_tick(key)
930-
931-
# avoid matching 0 against False or 0.0, note that None has object dtype and so always pass this test
932-
if self._is_key_type_compatible(potential_tick):
933-
try:
934-
res_idx = mapping[potential_tick]
935-
if potential_tick != key.key:
936-
# only warn if no KeyError was raised (potential_tick is in mapping)
937-
msg = "Using a Group object which was used to create an aggregate to " \
938-
"target its aggregated label is deprecated. " \
939-
"Please use the aggregated label directly instead. " \
940-
f"In this case, you should use {potential_tick!r} instead of " \
941-
f"using {key!r}."
942-
# let us hope the stacklevel does not vary by codepath
943-
warnings.warn(msg, FutureWarning, stacklevel=8)
944-
return res_idx
945-
except KeyError:
946-
pass
947-
# we must catch TypeError because key might not be hashable (eg slice)
948-
# IndexError is for when mapping is an ndarray
949-
except (KeyError, TypeError, IndexError):
950-
pass
925+
if isinstance(key, Group):
926+
if key.axis is self:
927+
if isinstance(key, IGroup):
928+
return key.key
929+
else:
930+
# at this point we do not care about the axis nor the name
931+
key = key.key
932+
elif key.axis is not None:
933+
try:
934+
# TODO: remove this as it is potentially very expensive
935+
# if key.key is an array or list and should be tried
936+
# as a last resort
937+
potential_tick = _to_tick(key)
938+
939+
# avoid matching 0 against False or 0.0, note that None has
940+
# object dtype and so always pass this test
941+
if self._is_key_type_compatible(potential_tick):
942+
try:
943+
res_idx = mapping[potential_tick]
944+
if potential_tick != key.key:
945+
# only warn if no KeyError was raised (potential_tick is in mapping)
946+
msg = "Using a Group object which was used to create an aggregate to " \
947+
"target its aggregated label is deprecated. " \
948+
"Please use the aggregated label directly instead. " \
949+
f"In this case, you should use {potential_tick!r} instead of " \
950+
f"using {key!r}."
951+
# let us hope the stacklevel does not vary by codepath
952+
warnings.warn(msg, FutureWarning, stacklevel=8)
953+
return res_idx
954+
except KeyError:
955+
pass
956+
# we must catch TypeError because key might not be hashable (eg slice)
957+
# IndexError is for when mapping is an ndarray
958+
except (KeyError, TypeError, IndexError):
959+
pass
951960

952961
if isinstance(key, str):
953962
# try the key as-is to allow getting at ticks with special characters (",", ":", ...)
@@ -961,24 +970,35 @@ def index(self, key) -> Union[int, np.ndarray, slice]:
961970
except (KeyError, TypeError, IndexError):
962971
pass
963972

964-
# transform "specially formatted strings" for slices, lists, LGroup and IGroup to actual objects
973+
# transform "specially formatted strings" for slices, lists, LGroup
974+
# and IGroup to actual objects
965975
key = _to_key(key)
966976

967977
if isinstance(key, range):
968978
key = list(key)
969-
970-
# this can happen when key was passed as a string and converted to a Group via _to_key
971-
if isinstance(key, Group) and isinstance(key.axis, str) and key.axis != self.name:
972-
raise KeyError(key)
973-
974-
if isinstance(key, IGroup):
975-
if isinstance(key.axis, Axis):
976-
assert key.axis is self
977-
return key.key
978-
979-
if isinstance(key, LGroup):
980-
# at this point we do not care about the axis nor the name
981-
key = key.key
979+
elif isinstance(key, Group):
980+
key_axis = key.axis
981+
if isinstance(key_axis, str):
982+
if key_axis != self.name:
983+
raise KeyError(key)
984+
elif isinstance(key_axis, AxisReference):
985+
if key_axis.name != self.name:
986+
raise KeyError(key)
987+
elif isinstance(key_axis, Axis): # we know it is not self
988+
# IGroups will be retargeted to LGroups
989+
key = key.retarget_to(self)
990+
elif isinstance(key_axis, int):
991+
raise TypeError('Axis.index() does not support Group keys with '
992+
'integer axis')
993+
else:
994+
assert key_axis is None
995+
# an IGroup can still exist at this point if the key was an IGroup
996+
# with a compatible axis (string or AxisReference axis with the
997+
# correct name or Axis object equal to self)
998+
if isinstance(key, IGroup):
999+
return key.key
1000+
else:
1001+
key = key.key
9821002

9831003
if isinstance(key, slice):
9841004
start = mapping[key.start] if key.start is not None else None
@@ -1915,7 +1935,8 @@ def __contains__(self, key) -> bool:
19151935
if isinstance(key, int):
19161936
return -len(self) <= key < len(self)
19171937
elif isinstance(key, Axis):
1918-
# the special case is just a performance optimization to avoid scanning through the whole list
1938+
# the special case is just a performance optimization to avoid
1939+
# scanning through the whole list
19191940
if key.name is not None:
19201941
return key.name in self._map
19211942
else:
@@ -2808,7 +2829,7 @@ def _guess_axis(self, axis_key):
28082829
# we have axis information but not necessarily an Axis object from self
28092830
real_axis = self[group_axis]
28102831
if group_axis is not real_axis:
2811-
axis_key = axis_key.with_axis(real_axis)
2832+
axis_key = axis_key.retarget_to(real_axis)
28122833
return axis_key
28132834

28142835
real_axis, axis_pos_key = self._translate_nice_key(axis_key)
@@ -2828,6 +2849,7 @@ def _translate_axis_key_chunk(self, axis_key):
28282849
(axis, indices)
28292850
Indices group with a valid axis (from self)
28302851
"""
2852+
orig_key = axis_key
28312853
axis_key = remove_nested_groups(axis_key)
28322854

28332855
if isinstance(axis_key, IGroup):
@@ -2852,11 +2874,16 @@ def _translate_axis_key_chunk(self, axis_key):
28522874
# labels but known axis
28532875
if isinstance(axis_key, LGroup) and axis_key.axis is not None:
28542876
try:
2855-
real_axis = self[axis_key.axis]
2877+
key_axis = axis_key.axis
2878+
real_axis = self[key_axis]
2879+
if isinstance(key_axis, (AxisReference, int)):
2880+
# this is one of the rare cases where with_axis is correct !
2881+
axis_key = axis_key.with_axis(real_axis)
2882+
28562883
try:
28572884
axis_pos_key = real_axis.index(axis_key)
28582885
except KeyError:
2859-
raise ValueError(f"{axis_key!r} is not a valid label for the {real_axis.name!r} axis "
2886+
raise ValueError(f"{orig_key!r} is not a valid label for the {real_axis.name!r} axis "
28602887
f"with labels: {', '.join(repr(label) for label in real_axis.labels)}")
28612888
return real_axis, axis_pos_key
28622889
except KeyError:
@@ -3889,6 +3916,7 @@ def align_axis_collections(axis_collections, join='outer', axes=None):
38893916

38903917
class AxisReference(ABCAxisReference, ExprNode, Axis):
38913918
def __init__(self, name):
3919+
assert isinstance(name, (int, str))
38923920
self.name = name
38933921
self._labels = None
38943922
self._iswildcard = False

larray/tests/test_array.py

Lines changed: 25 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -560,6 +560,17 @@ def test_getitem(array):
560560
_ = array[bad[1, 2], a[3, 4]]
561561

562562

563+
def test_getitem_group_from_another_axis():
564+
# using slice Group from an axis not present, we must retarget the group
565+
arr = ndtest(3)
566+
a2 = Axis('a=a0,a1')
567+
568+
# issue #1146
569+
expected = ndtest(2)
570+
res = arr[a2[:]]
571+
assert_larray_equal(res, expected)
572+
573+
563574
def test_getitem_abstract_axes(array):
564575
raw = array.data
565576
a, b, c, d = array.axes
@@ -1130,8 +1141,6 @@ def test_getitem_single_larray_key_guess():
11301141
_ = arr[key]
11311142

11321143

1133-
1134-
11351144
def test_getitem_multiple_larray_key_guess():
11361145
a, b, c, d, e = ndtest((2, 3, 2, 3, 2)).axes
11371146
arr = ndtest((a, b))
@@ -2161,6 +2170,20 @@ def test_group_agg_label_group(array):
21612170
res = array.sum(a, c).sum((g1, g2, g3, g_all))
21622171
assert res.shape == (4, 6)
21632172

2173+
# d) group aggregate using a group from another axis
2174+
# 1) LGroup
2175+
array = ndtest(3)
2176+
smaller_a_axis = Axis('a=a0,a1')
2177+
group = smaller_a_axis[:]
2178+
res = array.sum(group)
2179+
assert res == 1
2180+
2181+
# 2) IGroup
2182+
group = Axis("a=a1,a0").i[0] # targets a1
2183+
assert array[group] == 1
2184+
res = array.sum(group)
2185+
assert res == 1
2186+
21642187

21652188
def test_group_agg_label_group_no_axis(array):
21662189
a, b, c, d = array.axes

larray/tests/test_axis.py

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -123,6 +123,16 @@ def test_index():
123123
assert a.index('a1') == 1
124124
assert a.index('a1 >> A1') == 1
125125

126+
time = Axis([2007, 2009], 'time')
127+
res = time.index(time.i[1])
128+
assert res == 1
129+
130+
res = time.index(X.time.i[1])
131+
assert res == 1
132+
133+
res = time.index('time.i[1]')
134+
assert res == 1
135+
126136

127137
def test_astype():
128138
arr = ndtest(Axis('time=2015..2020,total')).drop('total')

0 commit comments

Comments
 (0)