Fix PixelWiseCrossEntropy error when combined with skip_last_target flag
wolny committed Jan 3, 2024
1 parent ab2a1ff commit 8e31345
Showing 3 changed files with 44 additions and 29 deletions.
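In short: SkipLastTargetChannelWrapper.forward previously accepted only (input, target), so enabling the skip_last_target flag together with PixelWiseCrossEntropyLoss, whose forward also takes a per-pixel weight map, raised an error. The wrapper now forwards an optional weight argument to the wrapped loss. A minimal sketch of the fixed combination (all tensor shapes below are hypothetical, not taken from the commit):

    import torch
    from pytorch3dunet.unet3d.losses import PixelWiseCrossEntropyLoss, SkipLastTargetChannelWrapper

    input = torch.randn(4, 2, 16, 16, 16)             # (N, C, D, H, W) logits
    target = torch.randint(0, 2, (4, 2, 16, 16, 16))  # extra target channel, stripped by the wrapper
    weight = torch.rand(4, 16, 16, 16)                # per-pixel weight map

    loss = SkipLastTargetChannelWrapper(PixelWiseCrossEntropyLoss(), squeeze_channel=True)
    value = loss(input, target, weight)               # weight is now passed through to the wrapped loss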
1 change: 1 addition & 0 deletions pytorch3dunet/datasets/hdf5.py
@@ -72,6 +72,7 @@ def __init__(self, file_path, phase, slice_builder_config, transformer_config, r
 
     @staticmethod
     def load_dataset(input_file, internal_path):
+        assert internal_path in input_file, f"Internal path: {internal_path} not found in the H5 file"
         ds = input_file[internal_path][:]
         assert ds.ndim in [3, 4], \
             f"Invalid dataset dimension: {ds.ndim}. Supported dataset formats: (C, Z, Y, X) or (Z, Y, X)"
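The added assert relies on h5py's membership test ('name' in file), so a missing internal path now fails fast with a readable message instead of a KeyError raised later by the slicing. A small illustration, with hypothetical file and dataset names:

    import h5py
    import numpy as np

    # build a toy H5 file containing only a 'raw' dataset
    with h5py.File('example.h5', 'w') as f:
        f.create_dataset('raw', data=np.zeros((8, 8, 8)))

    with h5py.File('example.h5', 'r') as f:
        assert 'raw' in f        # the membership check the new assert performs
        assert 'label' not in f  # requesting 'label' would now trigger the clear assert message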
52 changes: 25 additions & 27 deletions pytorch3dunet/unet3d/losses.py
@@ -1,11 +1,8 @@
 import torch
 import torch.nn.functional as F
 from torch import nn as nn
-from torch.autograd import Variable
 from torch.nn import MSELoss, SmoothL1Loss, L1Loss
 
-from pytorch3dunet.unet3d.utils import expand_as_one_hot
-
 
 def compute_per_channel_dice(input, target, epsilon=1e-6, weight=None):
     """
@@ -69,15 +66,17 @@ def __init__(self, loss, squeeze_channel=False):
         self.loss = loss
         self.squeeze_channel = squeeze_channel
 
-    def forward(self, input, target):
+    def forward(self, input, target, weight=None):
         assert target.size(1) > 1, 'Target tensor has a singleton channel dimension, cannot remove channel'
 
         # skips last target channel if needed
         target = target[:, :-1, ...]
 
         if self.squeeze_channel:
-            # squeeze channel dimension if singleton
+            # squeeze channel dimension
             target = torch.squeeze(target, dim=1)
+        if weight is not None:
+            return self.loss(input, target, weight)
         return self.loss(input, target)


@@ -198,14 +197,13 @@ def _class_weights(input):
         flattened = flatten(input)
         nominator = (1. - flattened).sum(-1)
         denominator = flattened.sum(-1)
-        class_weights = Variable(nominator / denominator, requires_grad=False)
-        return class_weights
+        class_weights = nominator / denominator
+        return class_weights.detach()
 
 
 class PixelWiseCrossEntropyLoss(nn.Module):
-    def __init__(self, class_weights=None, ignore_index=None):
+    def __init__(self, ignore_index=None):
         super(PixelWiseCrossEntropyLoss, self).__init__()
-        self.register_buffer('class_weights', class_weights)
         self.ignore_index = ignore_index
         self.log_softmax = nn.LogSoftmax(dim=1)

@@ -214,26 +212,26 @@ def forward(self, input, target, weights):
         # normalize the input
         log_probabilities = self.log_softmax(input)
         # standard CrossEntropyLoss requires the target to be (NxDxHxW), so we need to expand it to (NxCxDxHxW)
-        target = expand_as_one_hot(target, C=input.size()[1], ignore_index=self.ignore_index)
-        # expand weights
-        weights = weights.unsqueeze(1)
-        weights = weights.expand_as(input)
-
-        # create default class_weights if None
-        if self.class_weights is None:
-            class_weights = torch.ones(input.size()[1]).float().cuda()
+        if self.ignore_index is not None:
+            mask = target == self.ignore_index
+            target[mask] = 0
         else:
-            class_weights = self.class_weights
-
-        # resize class_weights to be broadcastable into the weights
-        class_weights = class_weights.view(1, -1, 1, 1, 1)
-
-        # multiply weights tensor by class weights
-        weights = class_weights * weights
-
+            mask = torch.zeros_like(target)
+        # add channel dimension and invert the mask
+        mask = 1 - mask.unsqueeze(1)
+        # convert target to one-hot encoding
+        target = F.one_hot(target.long())
+        if target.ndim == 5:
+            # permute target to (NxCxDxHxW)
+            target = target.permute(0, 4, 1, 2, 3).contiguous()
+        else:
+            target = target.permute(0, 3, 1, 2).contiguous()
+        # apply the mask on the target
+        target = target * mask
+        # add channel dimension to the weights
+        weights = weights.unsqueeze(1)
         # compute the losses
         result = -weights * target * log_probabilities
         # average the losses
         return result.mean()


@@ -326,7 +324,7 @@ def _create_loss(name, loss_config, weight, ignore_index, pos_weight):
         ignore_index = -100  # use the default 'ignore_index' as defined in the CrossEntropyLoss
         return WeightedCrossEntropyLoss(ignore_index=ignore_index)
     elif name == 'PixelWiseCrossEntropyLoss':
-        return PixelWiseCrossEntropyLoss(class_weights=weight, ignore_index=ignore_index)
+        return PixelWiseCrossEntropyLoss(ignore_index=ignore_index)
     elif name == 'GeneralizedDiceLoss':
         normalization = loss_config.get('normalization', 'sigmoid')
         return GeneralizedDiceLoss(normalization=normalization)
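The rewritten forward is plain one-hot cross entropy scaled by a per-pixel weight. As a sanity check (mine, not part of the commit): with unit weights and no ignore_index it reduces to torch's standard cross entropy divided by the number of channels C, because result.mean() also averages over the channel axis:

    import torch
    import torch.nn.functional as F

    input = torch.randn(2, 4, 8, 8, 8)          # (N, C, D, H, W) with C = 4
    target = torch.randint(0, 4, (2, 8, 8, 8))
    weights = torch.ones(2, 8, 8, 8)            # unit per-pixel weights

    log_p = F.log_softmax(input, dim=1)
    one_hot = F.one_hot(target, num_classes=4).permute(0, 4, 1, 2, 3).float()
    pixel_wise = (-weights.unsqueeze(1) * one_hot * log_p).mean()

    assert torch.allclose(pixel_wise, F.cross_entropy(input, target) / 4)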
20 changes: 18 additions & 2 deletions tests/test_criterion.py
@@ -5,7 +5,7 @@
 
 from pytorch3dunet.augment.transforms import LabelToAffinities, StandardLabelToBoundary
 from pytorch3dunet.unet3d.losses import GeneralizedDiceLoss, DiceLoss, WeightedSmoothL1Loss, _MaskingLossWrapper, \
-    SkipLastTargetChannelWrapper, BCEDiceLoss
+    SkipLastTargetChannelWrapper, BCEDiceLoss, PixelWiseCrossEntropyLoss
 from pytorch3dunet.unet3d.metrics import DiceCoefficient, MeanIoU, BoundaryAveragePrecision, AdaptedRandError, \
     BoundaryAdaptedRandError
@@ -39,7 +39,7 @@ def _eval_criterion(criterion, batch_shape, n_times=100):
 class TestCriterion:
     def test_dice_coefficient(self):
         results = _compute_criterion(DiceCoefficient())
-        # check that all of the coefficients belong to [0, 1]
+        # check that all the coefficients belong to [0, 1]
         results = np.array(results)
         assert np.all(results > 0)
         assert np.all(results < 1)
@@ -133,6 +133,22 @@ def test_bce_dice_loss(self):
         results = np.array(results)
         assert np.all(results > 0)
 
+    def test_pixel_wise_cross_entropy_3d(self):
+        loss = PixelWiseCrossEntropyLoss()
+        input = torch.randn(3, 2, 30, 30, 30)
+        target = torch.empty(3, 30, 30, 30, dtype=torch.long).random_(2)
+        weight = torch.rand(3, 30, 30, 30)
+        output = loss(input, target, weight)
+        assert output.item() > 0
+
+    def test_pixel_wise_cross_entropy_2d(self):
+        loss = PixelWiseCrossEntropyLoss()
+        input = torch.randn(3, 2, 30, 30)
+        target = torch.empty(3, 30, 30, dtype=torch.long).random_(2)
+        weight = torch.rand(3, 30, 30)
+        output = loss(input, target, weight)
+        assert output.item() > 0
+
     def test_ignore_index_loss(self):
         loss = _MaskingLossWrapper(nn.BCEWithLogitsLoss(), ignore_index=-1)
         input = torch.rand((3, 3))
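Assuming a standard pytest setup for the repository, the two new tests can be run in isolation with:

    pytest tests/test_criterion.py -k pixel_wise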
