Skip to content

Commit 9d4ea83

Browse files
authored
Merge pull request wolny#115 from Josh-Talks/fix-multi-threshold
fix overwriting of prediction for multiple thresholds
2 parents 8ef822f + 19aae9e commit 9d4ea83

File tree

1 file changed

+4
-4
lines changed

1 file changed

+4
-4
lines changed

pytorch3dunet/unet3d/metrics.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -205,16 +205,16 @@ def input_to_segm(self, input):
205205
for predictions in input:
206206
for th in self.thresholds:
207207
# threshold probability maps
208-
predictions = predictions > th
208+
predictions_th = predictions > th
209209

210210
if self.invert_pmaps:
211211
# for connected component analysis we need to treat boundary signal as background
212212
# assign 0-label to boundary mask
213-
predictions = np.logical_not(predictions)
213+
predictions_th = np.logical_not(predictions_th)
214214

215-
predictions = predictions.astype(np.uint8)
215+
predictions_th = predictions_th.astype(np.uint8)
216216
# run connected components on the predicted mask; consider only 1-connectivity
217-
seg = measure.label(predictions, background=0, connectivity=1)
217+
seg = measure.label(predictions_th, background=0, connectivity=1)
218218
segs.append(seg)
219219

220220
return np.stack(segs)

0 commit comments

Comments
 (0)