
Commit f999928

Started implementing parallelization along broadcasted dimensions

1 parent: 1b8b7c0

1 file changed (+78, -46)

reproject/common.py

Lines changed: 78 additions & 46 deletions
@@ -97,6 +97,23 @@ def _reproject_dispatcher(
     if reproject_func_kwargs is None:
         reproject_func_kwargs = {}

+    # Determine whether any broadcasting is taking place
+    broadcasting = wcs_in.low_level_wcs.pixel_n_dim < len(shape_out)
+
+    # Determine whether block size indicates we should parallelize over broadcasted dimension
+    broadcasted_parallelization = False
+    if broadcasting and block_size:
+        if len(block_size) == len(shape_out):
+            if (
+                block_size[-wcs_in.low_level_wcs.pixel_n_dim :]
+                == shape_out[-wcs_in.low_level_wcs.pixel_n_dim :]
+            ):
+                broadcasted_parallelization = True
+                block_size = (
+                    block_size[: -wcs_in.low_level_wcs.pixel_n_dim]
+                    + (-1,) * wcs_in.low_level_wcs.pixel_n_dim
+                )
+
     # We set up a global temporary directory since this will be used e.g. to
     # store memory mapped Numpy arrays and zarr arrays.

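Note: the added check above can be read on its own. Broadcasted parallelization is only enabled when block_size has the same length as shape_out and its trailing entries cover the full celestial (pixel) dimensions, so chunking happens only along the extra broadcast dimensions. A minimal sketch of the same decision on plain tuples follows; the helper name and example shapes are illustrative, not part of the commit:

def wants_broadcasted_parallelization(block_size, shape_out, pixel_n_dim):
    # True only when the trailing pixel_n_dim entries of block_size span the
    # full output shape, i.e. each chunk contains a whole celestial plane.
    if not block_size or len(block_size) != len(shape_out):
        return False
    return tuple(block_size[-pixel_n_dim:]) == tuple(shape_out[-pixel_n_dim:])

# Example: a 2-D celestial WCS broadcast over 100 extra planes.
shape_out = (100, 2048, 2048)
print(wants_broadcasted_parallelization((10, 2048, 2048), shape_out, 2))  # True
print(wants_broadcasted_parallelization((10, 512, 512), shape_out, 2))    # False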
@@ -154,28 +171,10 @@ def _reproject_dispatcher(

     shape_in = array_in.shape

-    # When in parallel mode, we want to make sure we avoid having to copy the
-    # input array to all processes for each chunk, so instead we write out
-    # the input array to a Numpy memory map and load it in inside each process
-    # as a memory-mapped array. We need to be careful how this gets passed to
-    # reproject_single_block so we pass a variable that can be either a string
-    # or the array itself (for synchronous mode). If the input array is a dask
-    # array we should always write it out to a memmap even in synchronous mode
-    # otherwise map_blocks gets confused if it gets two dask arrays and tries
-    # to iterate over both.
-
-    if isinstance(array_in, da.core.Array) or parallel:
-        array_in_or_path = as_delayed_memmap_path(array_in, tmp_dir)
-    else:
-        # Here we could set array_in_or_path to array_in_path if it
-        # has been set previously, but in synchronous mode it is better to
-        # simply pass a reference to the memmap array itself to avoid having
-        # to load the memmap inside each reproject_single_block call.
-        array_in_or_path = array_in
-
     def reproject_single_block(a, array_or_path, block_info=None):
         if a.ndim == 0 or block_info is None or block_info == []:
             return np.array([a, a])
+
         slices = [slice(*x) for x in block_info[None]["array-location"][-wcs_out.pixel_n_dim :]]

         if isinstance(wcs_out, BaseHighLevelWCS):
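Note: the comment block removed here (and re-added inside the new else branch in the next hunk) describes the memmap hand-off: the input is written to disk once and each worker re-opens it, instead of pickling the full array for every chunk. A generic sketch of that pattern, not the actual as_delayed_memmap_path implementation; the file name and shapes below are made up:

import os
import tempfile

import numpy as np

def write_input_memmap(array, tmp_dir):
    # Write the input once so workers can re-open it lazily instead of
    # receiving a pickled copy with every chunk.
    path = os.path.join(tmp_dir, "input.dat")
    mm = np.memmap(path, dtype=array.dtype, mode="w+", shape=array.shape)
    mm[...] = array
    mm.flush()
    return path, array.dtype, array.shape

def open_input_memmap(path, dtype, shape):
    # Called inside each worker process; mapping the file is cheap and read-only.
    return np.memmap(path, dtype=dtype, mode="r", shape=shape)

tmp_dir = tempfile.mkdtemp()
path, dtype, shape = write_input_memmap(np.random.random((4, 5)), tmp_dir)
array_view = open_input_memmap(path, dtype, shape)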
@@ -209,37 +208,70 @@ def reproject_single_block(a, array_or_path, block_info=None):
     # NOTE: the following array is just used to set up the iteration in map_blocks
     # but isn't actually used otherwise - this is deliberate.

-    if block_size:
-        if wcs_in.low_level_wcs.pixel_n_dim < len(shape_out):
-            if len(block_size) < len(shape_out):
-                block_size = [-1] * (len(shape_out) - len(block_size)) + list(block_size)
-            else:
-                for i in range(len(shape_out) - wcs_in.low_level_wcs.pixel_n_dim):
-                    if block_size[i] != -1 and block_size[i] != shape_out[i]:
-                        raise ValueError(
-                            "block shape for extra broadcasted dimensions should cover entire array along those dimensions"
-                        )
+    if broadcasted_parallelization:
         array_out_dask = da.empty(shape_out, chunks=block_size)
+        array_in = array_in.rechunk(block_size)
+
+        result = da.map_blocks(
+            reproject_single_block,
+            array_out_dask,
+            array_in,
+            dtype=float,
+            new_axis=0,
+            chunks=(2,) + array_out_dask.chunksize,
+        )
+
     else:
-        if wcs_in.low_level_wcs.pixel_n_dim < len(shape_out):
-            chunks = (-1,) * (len(shape_out) - wcs_in.low_level_wcs.pixel_n_dim)
-            chunks += ("auto",) * wcs_in.low_level_wcs.pixel_n_dim
-            rechunk_kwargs = {"chunks": chunks}
+        # When in parallel mode, we want to make sure we avoid having to copy the
+        # input array to all processes for each chunk, so instead we write out
+        # the input array to a Numpy memory map and load it in inside each process
+        # as a memory-mapped array. We need to be careful how this gets passed to
+        # reproject_single_block so we pass a variable that can be either a string
+        # or the array itself (for synchronous mode). If the input array is a dask
+        # array we should always write it out to a memmap even in synchronous mode
+        # otherwise map_blocks gets confused if it gets two dask arrays and tries
+        # to iterate over both.
+
+        if isinstance(array_in, da.core.Array) or parallel:
+            array_in_or_path = as_delayed_memmap_path(array_in, tmp_dir)
         else:
-            rechunk_kwargs = {}
-        array_out_dask = da.empty(shape_out)
-        array_out_dask = array_out_dask.rechunk(
-            block_size_limit=8 * 1024**2, **rechunk_kwargs
-        )
+            # Here we could set array_in_or_path to array_in_path if it
+            # has been set previously, but in synchronous mode it is better to
+            # simply pass a reference to the memmap array itself to avoid having
+            # to load the memmap inside each reproject_single_block call.
+            array_in_or_path = array_in
+
+        if block_size:
+            if broadcasting:
+                if len(block_size) < len(shape_out):
+                    block_size = [-1] * (len(shape_out) - len(block_size)) + list(block_size)
+                else:
+                    for i in range(len(shape_out) - wcs_in.low_level_wcs.pixel_n_dim):
+                        if block_size[i] != -1 and block_size[i] != shape_out[i]:
+                            raise ValueError(
+                                "block shape for extra broadcasted dimensions should cover entire array along those dimensions"
+                            )
+            array_out_dask = da.empty(shape_out, chunks=block_size)
+        else:
+            if broadcasting:
+                chunks = (-1,) * (len(shape_out) - wcs_in.low_level_wcs.pixel_n_dim)
+                chunks += ("auto",) * wcs_in.low_level_wcs.pixel_n_dim
+                rechunk_kwargs = {"chunks": chunks}
+            else:
+                rechunk_kwargs = {}
+            array_out_dask = da.empty(shape_out)
+            array_out_dask = array_out_dask.rechunk(
+                block_size_limit=8 * 1024**2, **rechunk_kwargs
+            )

-    result = da.map_blocks(
-        reproject_single_block,
-        array_out_dask,
-        array_in_or_path,
-        dtype=float,
-        new_axis=0,
-        chunks=(2,) + array_out_dask.chunksize,
-    )
+        result = da.map_blocks(
+            reproject_single_block,
+            array_out_dask,
+            array_in_or_path,
+            dtype=float,
+            new_axis=0,
+            chunks=(2,) + array_out_dask.chunksize,
+        )

     # Ensure that there are no more references to Numpy memmaps
     array_in = None
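Note: both branches end in the same da.map_blocks pattern. A dask array with the desired output chunking only drives the iteration (per the NOTE above), and each block returns an array with an extra leading axis of length 2, which in reproject carries the reprojected values and the footprint. A toy version of that pattern follows, with a trivial block function standing in for reproject_single_block; the shapes and the doubling logic are invented for illustration:

import numpy as np
import dask.array as da

def fake_reproject_block(a, block_info=None):
    # Same guard as reproject_single_block: dask may call the function with a
    # dummy array (and no block_info) when inferring the output metadata.
    if a.ndim == 0 or block_info is None or block_info == []:
        return np.array([a, a])
    # Stack "values" and "footprint" along a new leading axis.
    return np.stack([a * 2, np.ones_like(a)])

array_out_dask = da.ones((100, 2048), chunks=(25, 2048))
result = da.map_blocks(
    fake_reproject_block,
    array_out_dask,
    dtype=float,
    new_axis=0,                              # leading axis holding the pair
    chunks=(2,) + array_out_dask.chunksize,  # each block yields (2, 25, 2048)
)
values, footprint = result.compute()         # two (100, 2048) arrays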

0 commit comments
