Skip to content

Commit 092848f

Browse files
committed
add point slicing to dask.array.core
This is equivalent to numpy slicing with multiple input lists. We could use a better name. cc @shoyer @jhamman Example ------- >>> x = np.arange(56).reshape((7, 8)) >>> x array([[ 0, 1, 2, 3, 4, 5, 6, 7], [ 8, 9, 10, 11, 12, 13, 14, 15], [16, 17, 18, 19, 20, 21, 22, 23], [24, 25, 26, 27, 28, 29, 30, 31], [32, 33, 34, 35, 36, 37, 38, 39], [40, 41, 42, 43, 44, 45, 46, 47], [48, 49, 50, 51, 52, 53, 54, 55]]) >>> d = from_array(x, chunks=(3, 4)) >>> result = isel(d, [0, 1, 6, 0], [0, 1, 0, 7]) >>> result.compute() array([ 0, 9, 48, 7]) Fixes dask#433
1 parent 65cca6e commit 092848f

File tree

3 files changed

+78
-1
lines changed

3 files changed

+78
-1
lines changed

dask/array/core.py

Lines changed: 66 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2222,3 +2222,69 @@ def to_hdf5(filename, *args, **kwargs):
22222222
**kwargs)
22232223
for dp, x in data.items()]
22242224
store(list(data.values()), dsets)
2225+
2226+
2227+
def isel(x, *indexes):
2228+
""" Point wise slicing
2229+
2230+
This is equivalent to numpy slicing with multiple input lists
2231+
2232+
>>> x = np.arange(56).reshape((7, 8))
2233+
>>> x
2234+
array([[ 0, 1, 2, 3, 4, 5, 6, 7],
2235+
[ 8, 9, 10, 11, 12, 13, 14, 15],
2236+
[16, 17, 18, 19, 20, 21, 22, 23],
2237+
[24, 25, 26, 27, 28, 29, 30, 31],
2238+
[32, 33, 34, 35, 36, 37, 38, 39],
2239+
[40, 41, 42, 43, 44, 45, 46, 47],
2240+
[48, 49, 50, 51, 52, 53, 54, 55]])
2241+
2242+
>>> d = from_array(x, chunks=(3, 4))
2243+
>>> result = isel(d, [0, 1, 6, 0], [0, 1, 0, 7])
2244+
>>> result.compute()
2245+
array([ 0, 9, 48, 7])
2246+
"""
2247+
indexes = list(map(list, indexes))
2248+
bounds = [list(accumulate(add, (0,) + c)) for c in x.chunks]
2249+
points = list()
2250+
2251+
for i, idx in enumerate(zip(*indexes)):
2252+
block_idx = [np.searchsorted(b, ind, 'right') - 1 for b, ind in zip(bounds, idx)]
2253+
inblock_idx = [ind - bounds[k][j]
2254+
for k, (ind, j) in enumerate(zip(idx, block_idx))]
2255+
points.append((i, (x.name,) + tuple(block_idx), tuple(inblock_idx)))
2256+
2257+
per_block = groupby(1, points)
2258+
2259+
token = next(tokens)
2260+
name = 'isel-slice' + token
2261+
dsk = dict(((name, i), (_isel_slice, key, list(pluck(2, per_block[key]))))
2262+
for i, key in enumerate(per_block))
2263+
2264+
dsk[('isel-merge' + token, 0)] = (_isel_merge,
2265+
[list(pluck(0, per_block[key])) for key in per_block],
2266+
[(name, i) for i in range(len(per_block))])
2267+
2268+
chunks = ((len(points),),)
2269+
2270+
return Array(merge(x.dask, dsk), 'isel-merge' + token, chunks, x.dtype)
2271+
2272+
2273+
def _isel_slice(block, points):
2274+
points2 = list(zip(*points))
2275+
return block[points2]
2276+
2277+
2278+
def _isel_merge(locations, values):
2279+
locations = list(map(list, locations))
2280+
values = list(values)
2281+
2282+
n = sum(map(len, locations))
2283+
dtype = values[0].dtype
2284+
2285+
x = np.empty(n, dtype=dtype)
2286+
2287+
for loc, values in zip(locations, values):
2288+
x[loc] = values
2289+
2290+
return x

dask/array/slicing.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
from itertools import count, product
2-
from toolz import merge, first, accumulate
2+
from toolz import merge, first, accumulate, groupby, pluck
33
from operator import getitem, add
44
from math import ceil
55
from ..compatibility import long

dask/array/tests/test_array_core.py

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1220,3 +1220,14 @@ def test_h5py_newaxis():
12201220

12211221
def test_ellipsis_slicing():
12221222
assert eq(da.ones(4, chunks=2)[...], np.ones(4))
1223+
1224+
1225+
def test_point_slicing():
1226+
x = np.arange(56).reshape((7, 8))
1227+
d = da.from_array(x, chunks=(3, 4))
1228+
1229+
result = isel(d, [1, 2, 5, 5], [3, 1, 6, 1])
1230+
assert eq(result, x[[1, 2, 5, 5], [3, 1, 6, 1]])
1231+
1232+
result = isel(d, [0, 1, 6, 0], [0, 1, 0, 7])
1233+
assert eq(result, x[[0, 1, 6, 0], [0, 1, 0, 7]])

0 commit comments

Comments
 (0)