Skip to content

Commit 560e8ff

Browse files
authored
Update dice.py (#608)
1 parent d316fbf commit 560e8ff

File tree

1 file changed

+2
-2
lines changed
  • segmentation_models_pytorch/losses

1 file changed

+2
-2
lines changed

segmentation_models_pytorch/losses/dice.py

+2-2
Original file line numberDiff line numberDiff line change
@@ -89,10 +89,10 @@ def forward(self, y_pred: torch.Tensor, y_true: torch.Tensor) -> torch.Tensor:
8989
y_pred = y_pred * mask.unsqueeze(1)
9090

9191
y_true = F.one_hot((y_true * mask).to(torch.long), num_classes) # N,H*W -> N,H*W, C
92-
y_true = y_true.permute(0, 2, 1) * mask.unsqueeze(1) # H, C, H*W
92+
y_true = y_true.permute(0, 2, 1) * mask.unsqueeze(1) # N, C, H*W
9393
else:
9494
y_true = F.one_hot(y_true, num_classes) # N,H*W -> N,H*W, C
95-
y_true = y_true.permute(0, 2, 1) # H, C, H*W
95+
y_true = y_true.permute(0, 2, 1) # N, C, H*W
9696

9797
if self.mode == MULTILABEL_MODE:
9898
y_true = y_true.view(bs, num_classes, -1)

0 commit comments

Comments
 (0)