Commit

Merge pull request scilus#1053 from EmmaRenauld/no_multiproc_cov
Improve coverage in multiprocessing cases
arnaudbore authored Nov 26, 2024
2 parents 3ea1854 + 06122c1 commit 4d7f52c
Showing 11 changed files with 425 additions and 282 deletions.
10 changes: 10 additions & 0 deletions scilpy/image/tests/test_volume_operations.py
@@ -204,6 +204,16 @@ def test_resample_volume():
assert_equal(resampled_img.get_fdata(), ref3d)
assert resampled_img.affine[0, 0] == 3

# 4) Same test, with a fake 4th dimension
moving3d = np.stack((moving3d, moving3d), axis=-1)
moving3d_img = nib.Nifti1Image(moving3d, np.eye(4))
resampled_img = resample_volume(moving3d_img, voxel_res=(3, 3, 3),
interp='nn')
result = resampled_img.get_fdata()
assert_equal(result[:, :, :, 0], ref3d)
assert_equal(result[:, :, :, 1], ref3d)
assert resampled_img.affine[0, 0] == 3


def test_reshape_volume_pad():
# 3D img
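The new test fakes a 4th dimension by stacking the 3D volume onto itself, then compares each 3D sub-volume of the resampled result against the same reference. A NumPy-only sketch of that stacking trick (illustrative shapes, not part of the commit):

# Illustrative sketch, not part of this commit.
import numpy as np

# A small 3D volume; the values only matter for the equality checks below.
vol3d = np.arange(64, dtype=float).reshape((4, 4, 4))

# Stacking two copies along a new last axis fakes a 4D image of shape (X, Y, Z, 2).
vol4d = np.stack((vol3d, vol3d), axis=-1)
assert vol4d.shape == (4, 4, 4, 2)

# Each 3D sub-volume can then be checked against the same 3D reference,
# exactly like result[:, :, :, 0] and result[:, :, :, 1] in the test above.
np.testing.assert_array_equal(vol4d[:, :, :, 0], vol3d)
np.testing.assert_array_equal(vol4d[:, :, :, 1], vol3d)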
94 changes: 56 additions & 38 deletions scilpy/reconst/divide.py
@@ -100,7 +100,7 @@ def _gamma_data2fit(signal, gtab_infos, fit_iters=1, random_iters=50,
Returns
-------
best_params : np.array
Array containing the parameters of the fit.
Array containing the parameters of the fit. Shape: (4,)
"""
if np.sum(gtab_infos[3]) > 0 and do_multiple_s0 is True:
ns = len(np.unique(gtab_infos[3])) - 1
@@ -263,25 +263,33 @@ def gamma_fit2metrics(params):


def _fit_gamma_parallel(args):
data = args[0]
gtab_infos = args[1]
fit_iters = args[2]
random_iters = args[3]
do_weight_bvals = args[4]
do_weight_pa = args[5]
do_multiple_s0 = args[6]
chunk_id = args[7]

sub_fit_array = np.zeros((data.shape[0], 4))
for i in range(data.shape[0]):
if data[i].any():
sub_fit_array[i] = _gamma_data2fit(data[i], gtab_infos, fit_iters,
random_iters, do_weight_bvals,
do_weight_pa, do_multiple_s0)
# Data: Ravelled 4D data. Shape [N, X] where N is the number of voxels.
(data, gtab_infos, fit_iters, random_iters,
do_weight_bvals, do_weight_pa, do_multiple_s0, chunk_id) = args

sub_fit_array = _fit_gamma_loop(data, gtab_infos, fit_iters,
random_iters, do_weight_bvals,
do_weight_pa, do_multiple_s0)

return chunk_id, sub_fit_array


def _fit_gamma_loop(data, gtab_infos, fit_iters, random_iters,
do_weight_bvals, do_weight_pa, do_multiple_s0):
"""
Loops on 2D data and fits each voxel separately
See _gamma_data2fit for a complete description.
"""
# Data: Ravelled 4D data. Shape [N, X] where N is the number of voxels.
tmp_fit_array = np.zeros((data.shape[0], 4))
for i in range(data.shape[0]):
if data[i].any():
tmp_fit_array[i] = _gamma_data2fit(
data[i], gtab_infos, fit_iters, random_iters,
do_weight_bvals, do_weight_pa, do_multiple_s0)
return tmp_fit_array


def fit_gamma(data, gtab_infos, mask=None, fit_iters=1, random_iters=50,
do_weight_bvals=False, do_weight_pa=False, do_multiple_s0=False,
nbr_processes=None):
@@ -328,30 +328,40 @@ def fit_gamma(data, gtab_infos, mask=None, fit_iters=1, random_iters=50,
or nbr_processes <= 0 else nbr_processes

# Ravel the first 3 dimensions while keeping the 4th intact, like a list of
# 1D time series voxels. Then separate it in chunks of len(nbr_processes).
# 1D time series voxels.
data = data[mask].reshape((np.count_nonzero(mask), data_shape[3]))
chunks = np.array_split(data, nbr_processes)

chunk_len = np.cumsum([0] + [len(c) for c in chunks])
pool = multiprocessing.Pool(nbr_processes)
results = pool.map(_fit_gamma_parallel,
zip(chunks,
itertools.repeat(gtab_infos),
itertools.repeat(fit_iters),
itertools.repeat(random_iters),
itertools.repeat(do_weight_bvals),
itertools.repeat(do_weight_pa),
itertools.repeat(do_multiple_s0),
np.arange(len(chunks))))
pool.close()
pool.join()

# Re-assemble the chunk together in the original shape.
fit_array = np.zeros((data_shape[0:3])+(4,))
tmp_fit_array = np.zeros((np.count_nonzero(mask), 4))
for i, fit in results:
tmp_fit_array[chunk_len[i]:chunk_len[i+1]] = fit

# Separating the case nbr_processes=1 to help get good coverage metrics
# (codecov does not deal well with multiprocessing)
if nbr_processes == 1:
tmp_fit_array = _fit_gamma_loop(data, gtab_infos, fit_iters,
random_iters, do_weight_bvals,
do_weight_pa, do_multiple_s0)
else:
# Separate the data in chunks of len(nbr_processes).
chunks = np.array_split(data, nbr_processes)

pool = multiprocessing.Pool(nbr_processes)
results = pool.map(_fit_gamma_parallel,
zip(chunks,
itertools.repeat(gtab_infos),
itertools.repeat(fit_iters),
itertools.repeat(random_iters),
itertools.repeat(do_weight_bvals),
itertools.repeat(do_weight_pa),
itertools.repeat(do_multiple_s0),
np.arange(len(chunks))))
pool.close()
pool.join()

# Re-assemble the chunks together.
chunk_len = np.cumsum([0] + [len(c) for c in chunks])
tmp_fit_array = np.zeros((np.count_nonzero(mask), 4))
for chunk_id, fit in results:
tmp_fit_array[chunk_len[chunk_id]:chunk_len[chunk_id + 1]] = fit

# Bring back to the original shape
fit_array = np.zeros((data_shape[0:3]) + (4,))
fit_array[mask] = tmp_fit_array

return fit_array
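The refactor above factors the per-voxel loop out of the multiprocessing worker so that the nbr_processes == 1 case can call it directly, keeping the loop visible to coverage tools, while the parallel path splits the voxels into chunks and reassembles them by chunk id. A standalone, simplified sketch of that structure, with a hypothetical _toy_fit standing in for _gamma_data2fit:

# Simplified sketch of the single-process / multiprocessing split; _toy_fit
# is a hypothetical stand-in for _gamma_data2fit, not part of this commit.
import multiprocessing

import numpy as np


def _toy_fit(voxel):
    # Pretend per-voxel fit: return a single "parameter" (the mean).
    return np.array([voxel.mean()])


def _fit_loop(data):
    # Single-process path: a plain loop, easy for coverage tools to trace.
    out = np.zeros((data.shape[0], 1))
    for i in range(data.shape[0]):
        out[i] = _toy_fit(data[i])
    return out


def _fit_parallel(args):
    # Worker: unpack the chunk, run the same loop, return it with its id.
    data, chunk_id = args
    return chunk_id, _fit_loop(data)


def fit(data, nbr_processes=1):
    if nbr_processes == 1:
        return _fit_loop(data)

    chunks = np.array_split(data, nbr_processes)
    with multiprocessing.Pool(nbr_processes) as pool:
        results = pool.map(_fit_parallel, zip(chunks, range(len(chunks))))

    # Re-assemble the chunks in their original order using cumulative lengths.
    chunk_len = np.cumsum([0] + [len(c) for c in chunks])
    out = np.zeros((data.shape[0], 1))
    for chunk_id, fitted in results:
        out[chunk_len[chunk_id]:chunk_len[chunk_id + 1]] = fitted
    return out


if __name__ == "__main__":
    data = np.random.rand(10, 3)
    assert np.allclose(fit(data, nbr_processes=1), fit(data, nbr_processes=2))

Pool.map already returns results in submission order, but indexing by chunk id keeps the reassembly explicit, matching the bookkeeping used in fit_gamma.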
63 changes: 39 additions & 24 deletions scilpy/reconst/fodf.py
@@ -126,20 +126,27 @@ def get_ventricles_max_fodf(data, fa, md, zoom, sh_basis,


def _fit_from_model_parallel(args):
model = args[0]
data = args[1]
chunk_id = args[2]
(model, data, chunk_id) = args
sub_fit_array = _fit_from_model_loop(data, model)

sub_fit_array = np.zeros((data.shape[0],), dtype='object')
return chunk_id, sub_fit_array


def _fit_from_model_loop(data, model):
"""
Loops on 2D data and fits each voxel separately.
See fit_from_model for more information.
"""
# Data: Ravelled 4D data. Shape [N, X] where N is the number of voxels.
tmp_fit_array = np.zeros((data.shape[0],), dtype='object')
for i in range(data.shape[0]):
if data[i].any():
try:
sub_fit_array[i] = model.fit(data[i])
tmp_fit_array[i] = model.fit(data[i])
except cvx.error.SolverError:
coeff = np.full((len(model.n)), np.NaN)
sub_fit_array[i] = MSDeconvFit(model, coeff, None)

return chunk_id, sub_fit_array
tmp_fit_array[i] = MSDeconvFit(model, coeff, None)
return tmp_fit_array


def fit_from_model(model, data, mask=None, nbr_processes=None):
@@ -181,23 +188,31 @@ def fit_from_model(model, data, mask=None, nbr_processes=None):
# Ravel the first 3 dimensions while keeping the 4th intact, like a list of
# 1D time series voxels. Then separate it in chunks of len(nbr_processes).
data = data[mask].reshape((np.count_nonzero(mask), data_shape[3]))
chunks = np.array_split(data, nbr_processes)

chunk_len = np.cumsum([0] + [len(c) for c in chunks])
pool = multiprocessing.Pool(nbr_processes)
results = pool.map(_fit_from_model_parallel,
zip(itertools.repeat(model),
chunks,
np.arange(len(chunks))))
pool.close()
pool.join()

# Re-assemble the chunk together in the original shape.
fit_array = np.zeros(data_shape[0:3], dtype='object')
tmp_fit_array = np.zeros((np.count_nonzero(mask)), dtype='object')
for i, fit in results:
tmp_fit_array[chunk_len[i]:chunk_len[i+1]] = fit

# Separating the case nbr_processes=1 to help get good coverage metrics
# (codecov does not deal well with multiprocessing)
if nbr_processes == 1:
tmp_fit_array = _fit_from_model_loop(data, model)
else:
# Separate the data in chunks of len(nbr_processes).
chunks = np.array_split(data, nbr_processes)

pool = multiprocessing.Pool(nbr_processes)
results = pool.map(_fit_from_model_parallel,
zip(itertools.repeat(model),
chunks,
np.arange(len(chunks))))
pool.close()
pool.join()

# Re-assemble the chunks together.
chunk_len = np.cumsum([0] + [len(c) for c in chunks])
tmp_fit_array = np.zeros((np.count_nonzero(mask)), dtype='object')
for i, fit in results:
tmp_fit_array[chunk_len[i]:chunk_len[i+1]] = fit

# Bring back to the original shape
fit_array = np.zeros(data_shape[0:3], dtype='object')
fit_array[mask] = tmp_fit_array
fit_array = MultiVoxelFit(model, fit_array, mask)

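Both fit_gamma and fit_from_model rely on the same masking round trip: ravel the voxels selected by the mask into a 2D array, fit each row, then scatter the results back into the original spatial shape with fit_array[mask] = tmp_fit_array. A minimal NumPy sketch of that round trip (shapes are illustrative, not part of the commit):

# Illustrative sketch of the mask/ravel/scatter round trip; not part of
# this commit, and the shapes are arbitrary.
import numpy as np

data = np.random.rand(3, 3, 3, 4)       # fake 4D image: 4 samples per voxel
mask = data[..., 0] > 0.5                # boolean mask over the spatial dims

# Ravel the masked voxels into a (N, 4) array, one row per voxel.
flat = data[mask].reshape((np.count_nonzero(mask), data.shape[3]))

# Stand-in "fit": keep one value per voxel (here, the mean of its samples).
tmp_fit_array = flat.mean(axis=1, keepdims=True)          # shape (N, 1)

# Scatter back to the original spatial shape; unmasked voxels stay at zero.
fit_array = np.zeros(data.shape[0:3] + (1,))
fit_array[mask] = tmp_fit_array
assert fit_array.shape == (3, 3, 3, 1)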
