@@ -97,6 +97,23 @@ def _reproject_dispatcher(
97
97
if reproject_func_kwargs is None :
98
98
reproject_func_kwargs = {}
99
99
100
+ # Determine whether any broadcasting is taking place
101
+ broadcasting = wcs_in .low_level_wcs .pixel_n_dim < len (shape_out )
102
+
103
+ # Determine whether block size indicates we should parallelize over broadcasted dimension
104
+ broadcasted_parallelization = False
105
+ if broadcasting and block_size :
106
+ if len (block_size ) == len (shape_out ):
107
+ if (
108
+ block_size [- wcs_in .low_level_wcs .pixel_n_dim :]
109
+ == shape_out [- wcs_in .low_level_wcs .pixel_n_dim :]
110
+ ):
111
+ broadcasted_parallelization = True
112
+ block_size = (
113
+ block_size [: - wcs_in .low_level_wcs .pixel_n_dim ]
114
+ + (- 1 ,) * wcs_in .low_level_wcs .pixel_n_dim
115
+ )
116
+
100
117
# We set up a global temporary directory since this will be used e.g. to
101
118
# store memory mapped Numpy arrays and zarr arrays.
102
119
@@ -154,28 +171,10 @@ def _reproject_dispatcher(
154
171
155
172
shape_in = array_in .shape
156
173
157
- # When in parallel mode, we want to make sure we avoid having to copy the
158
- # input array to all processes for each chunk, so instead we write out
159
- # the input array to a Numpy memory map and load it in inside each process
160
- # as a memory-mapped array. We need to be careful how this gets passed to
161
- # reproject_single_block so we pass a variable that can be either a string
162
- # or the array itself (for synchronous mode). If the input array is a dask
163
- # array we should always write it out to a memmap even in synchronous mode
164
- # otherwise map_blocks gets confused if it gets two dask arrays and tries
165
- # to iterate over both.
166
-
167
- if isinstance (array_in , da .core .Array ) or parallel :
168
- array_in_or_path = as_delayed_memmap_path (array_in , tmp_dir )
169
- else :
170
- # Here we could set array_in_or_path to array_in_path if it
171
- # has been set previously, but in synchronous mode it is better to
172
- # simply pass a reference to the memmap array itself to avoid having
173
- # to load the memmap inside each reproject_single_block call.
174
- array_in_or_path = array_in
175
-
176
174
def reproject_single_block (a , array_or_path , block_info = None ):
177
175
if a .ndim == 0 or block_info is None or block_info == []:
178
176
return np .array ([a , a ])
177
+
179
178
slices = [slice (* x ) for x in block_info [None ]["array-location" ][- wcs_out .pixel_n_dim :]]
180
179
181
180
if isinstance (wcs_out , BaseHighLevelWCS ):
@@ -209,37 +208,70 @@ def reproject_single_block(a, array_or_path, block_info=None):
209
208
# NOTE: the following array is just used to set up the iteration in map_blocks
210
209
# but isn't actually used otherwise - this is deliberate.
211
210
212
- if block_size :
213
- if wcs_in .low_level_wcs .pixel_n_dim < len (shape_out ):
214
- if len (block_size ) < len (shape_out ):
215
- block_size = [- 1 ] * (len (shape_out ) - len (block_size )) + list (block_size )
216
- else :
217
- for i in range (len (shape_out ) - wcs_in .low_level_wcs .pixel_n_dim ):
218
- if block_size [i ] != - 1 and block_size [i ] != shape_out [i ]:
219
- raise ValueError (
220
- "block shape for extra broadcasted dimensions should cover entire array along those dimensions"
221
- )
211
+ if broadcasted_parallelization :
222
212
array_out_dask = da .empty (shape_out , chunks = block_size )
213
+ array_in = array_in .rechunk (block_size )
214
+
215
+ result = da .map_blocks (
216
+ reproject_single_block ,
217
+ array_out_dask ,
218
+ array_in ,
219
+ dtype = float ,
220
+ new_axis = 0 ,
221
+ chunks = (2 ,) + array_out_dask .chunksize ,
222
+ )
223
+
223
224
else :
224
- if wcs_in .low_level_wcs .pixel_n_dim < len (shape_out ):
225
- chunks = (- 1 ,) * (len (shape_out ) - wcs_in .low_level_wcs .pixel_n_dim )
226
- chunks += ("auto" ,) * wcs_in .low_level_wcs .pixel_n_dim
227
- rechunk_kwargs = {"chunks" : chunks }
225
+ # When in parallel mode, we want to make sure we avoid having to copy the
226
+ # input array to all processes for each chunk, so instead we write out
227
+ # the input array to a Numpy memory map and load it in inside each process
228
+ # as a memory-mapped array. We need to be careful how this gets passed to
229
+ # reproject_single_block so we pass a variable that can be either a string
230
+ # or the array itself (for synchronous mode). If the input array is a dask
231
+ # array we should always write it out to a memmap even in synchronous mode
232
+ # otherwise map_blocks gets confused if it gets two dask arrays and tries
233
+ # to iterate over both.
234
+
235
+ if isinstance (array_in , da .core .Array ) or parallel :
236
+ array_in_or_path = as_delayed_memmap_path (array_in , tmp_dir )
228
237
else :
229
- rechunk_kwargs = {}
230
- array_out_dask = da .empty (shape_out )
231
- array_out_dask = array_out_dask .rechunk (
232
- block_size_limit = 8 * 1024 ** 2 , ** rechunk_kwargs
233
- )
238
+ # Here we could set array_in_or_path to array_in_path if it
239
+ # has been set previously, but in synchronous mode it is better to
240
+ # simply pass a reference to the memmap array itself to avoid having
241
+ # to load the memmap inside each reproject_single_block call.
242
+ array_in_or_path = array_in
243
+
244
+ if block_size :
245
+ if broadcasting :
246
+ if len (block_size ) < len (shape_out ):
247
+ block_size = [- 1 ] * (len (shape_out ) - len (block_size )) + list (block_size )
248
+ else :
249
+ for i in range (len (shape_out ) - wcs_in .low_level_wcs .pixel_n_dim ):
250
+ if block_size [i ] != - 1 and block_size [i ] != shape_out [i ]:
251
+ raise ValueError (
252
+ "block shape for extra broadcasted dimensions should cover entire array along those dimensions"
253
+ )
254
+ array_out_dask = da .empty (shape_out , chunks = block_size )
255
+ else :
256
+ if broadcasting :
257
+ chunks = (- 1 ,) * (len (shape_out ) - wcs_in .low_level_wcs .pixel_n_dim )
258
+ chunks += ("auto" ,) * wcs_in .low_level_wcs .pixel_n_dim
259
+ rechunk_kwargs = {"chunks" : chunks }
260
+ else :
261
+ rechunk_kwargs = {}
262
+ array_out_dask = da .empty (shape_out )
263
+ array_out_dask = array_out_dask .rechunk (
264
+ block_size_limit = 8 * 1024 ** 2 , ** rechunk_kwargs
265
+ )
234
266
235
- result = da .map_blocks (
236
- reproject_single_block ,
237
- array_out_dask ,
238
- array_in_or_path ,
239
- dtype = float ,
240
- new_axis = 0 ,
241
- chunks = (2 ,) + array_out_dask .chunksize ,
242
- )
267
+ result = da .map_blocks (
268
+ reproject_single_block ,
269
+ array_out_dask ,
270
+ array_in_or_path ,
271
+ dtype = float ,
272
+ new_axis = 0 ,
273
+ chunks = (2 ,) + array_out_dask .chunksize ,
274
+ )
243
275
244
276
# Ensure that there are no more references to Numpy memmaps
245
277
array_in = None
0 commit comments