diff --git a/pytorch3dunet/datasets/hdf5.py b/pytorch3dunet/datasets/hdf5.py
index 147c1d97..d4a02e1c 100644
--- a/pytorch3dunet/datasets/hdf5.py
+++ b/pytorch3dunet/datasets/hdf5.py
@@ -72,6 +72,7 @@ def __init__(self, file_path, phase, slice_builder_config, transformer_config, r
 
     @staticmethod
     def load_dataset(input_file, internal_path):
+        assert internal_path in input_file, f"Internal path: {internal_path} not found in the H5 file"
         ds = input_file[internal_path][:]
         assert ds.ndim in [3, 4], \
             f"Invalid dataset dimension: {ds.ndim}. Supported dataset formats: (C, Z, Y, X) or (Z, Y, X)"
diff --git a/pytorch3dunet/unet3d/losses.py b/pytorch3dunet/unet3d/losses.py
index 598b0715..6a53966f 100644
--- a/pytorch3dunet/unet3d/losses.py
+++ b/pytorch3dunet/unet3d/losses.py
@@ -1,11 +1,8 @@
 import torch
 import torch.nn.functional as F
 from torch import nn as nn
-from torch.autograd import Variable
 from torch.nn import MSELoss, SmoothL1Loss, L1Loss
 
-from pytorch3dunet.unet3d.utils import expand_as_one_hot
-
 
 def compute_per_channel_dice(input, target, epsilon=1e-6, weight=None):
     """
@@ -69,15 +66,17 @@ def __init__(self, loss, squeeze_channel=False):
         self.loss = loss
         self.squeeze_channel = squeeze_channel
 
-    def forward(self, input, target):
+    def forward(self, input, target, weight=None):
         assert target.size(1) > 1, 'Target tensor has a singleton channel dimension, cannot remove channel'
 
         # skips last target channel if needed
         target = target[:, :-1, ...]
 
         if self.squeeze_channel:
-            # squeeze channel dimension if singleton
+            # squeeze channel dimension
             target = torch.squeeze(target, dim=1)
+        if weight is not None:
+            return self.loss(input, target, weight)
         return self.loss(input, target)
 
 
@@ -198,14 +197,13 @@ def _class_weights(input):
         flattened = flatten(input)
         nominator = (1. - flattened).sum(-1)
         denominator = flattened.sum(-1)
-        class_weights = Variable(nominator / denominator, requires_grad=False)
-        return class_weights
+        class_weights = nominator / denominator
+        return class_weights.detach()
 
 
 class PixelWiseCrossEntropyLoss(nn.Module):
-    def __init__(self, class_weights=None, ignore_index=None):
+    def __init__(self, ignore_index=None):
         super(PixelWiseCrossEntropyLoss, self).__init__()
-        self.register_buffer('class_weights', class_weights)
         self.ignore_index = ignore_index
         self.log_softmax = nn.LogSoftmax(dim=1)
 
@@ -214,26 +212,26 @@ def forward(self, input, target, weights):
         # normalize the input
         log_probabilities = self.log_softmax(input)
         # standard CrossEntropyLoss requires the target to be (NxDxHxW), so we need to expand it to (NxCxDxHxW)
-        target = expand_as_one_hot(target, C=input.size()[1], ignore_index=self.ignore_index)
-        # expand weights
-        weights = weights.unsqueeze(1)
-        weights = weights.expand_as(input)
-
-        # create default class_weights if None
-        if self.class_weights is None:
-            class_weights = torch.ones(input.size()[1]).float().cuda()
+        if self.ignore_index is not None:
+            mask = target == self.ignore_index
+            target[mask] = 0
         else:
-            class_weights = self.class_weights
-
-        # resize class_weights to be broadcastable into the weights
-        class_weights = class_weights.view(1, -1, 1, 1, 1)
-
-        # multiply weights tensor by class weights
-        weights = class_weights * weights
-
+            mask = torch.zeros_like(target)
+        # add channel dimension and invert the mask
+        mask = 1 - mask.unsqueeze(1)
+        # convert target to one-hot encoding
+        target = F.one_hot(target.long())
+        if target.ndim == 5:
+            # permute target to (NxCxDxHxW)
+            target = target.permute(0, 4, 1, 2, 3).contiguous()
+        else:
+            target = target.permute(0, 3, 1, 2).contiguous()
+        # apply the mask on the target
+        target = target * mask
+        # add channel dimension to the weights
+        weights = weights.unsqueeze(1)
         # compute the losses
         result = -weights * target * log_probabilities
-        # average the losses
         return result.mean()
 
 
@@ -326,7 +324,7 @@ def _create_loss(name, loss_config, weight, ignore_index, pos_weight):
             ignore_index = -100  # use the default 'ignore_index' as defined in the CrossEntropyLoss
         return WeightedCrossEntropyLoss(ignore_index=ignore_index)
     elif name == 'PixelWiseCrossEntropyLoss':
-        return PixelWiseCrossEntropyLoss(class_weights=weight, ignore_index=ignore_index)
+        return PixelWiseCrossEntropyLoss(ignore_index=ignore_index)
     elif name == 'GeneralizedDiceLoss':
         normalization = loss_config.get('normalization', 'sigmoid')
         return GeneralizedDiceLoss(normalization=normalization)
diff --git a/tests/test_criterion.py b/tests/test_criterion.py
index bf17db69..40a4b1d9 100644
--- a/tests/test_criterion.py
+++ b/tests/test_criterion.py
@@ -5,7 +5,7 @@
 
 from pytorch3dunet.augment.transforms import LabelToAffinities, StandardLabelToBoundary
 from pytorch3dunet.unet3d.losses import GeneralizedDiceLoss, DiceLoss, WeightedSmoothL1Loss, _MaskingLossWrapper, \
-    SkipLastTargetChannelWrapper, BCEDiceLoss
+    SkipLastTargetChannelWrapper, BCEDiceLoss, PixelWiseCrossEntropyLoss
 from pytorch3dunet.unet3d.metrics import DiceCoefficient, MeanIoU, BoundaryAveragePrecision, AdaptedRandError, \
     BoundaryAdaptedRandError
 
@@ -39,7 +39,7 @@ def _eval_criterion(criterion, batch_shape, n_times=100):
 class TestCriterion:
     def test_dice_coefficient(self):
         results = _compute_criterion(DiceCoefficient())
-        # check that all of the coefficients belong to [0, 1]
+        # check that all the coefficients belong to [0, 1]
         results = np.array(results)
         assert np.all(results > 0)
         assert np.all(results < 1)
@@ -133,6 +133,22 @@ def test_bce_dice_loss(self):
         results = np.array(results)
         assert np.all(results > 0)
 
+    def test_pixel_wise_cross_entropy_3d(self):
+        loss = PixelWiseCrossEntropyLoss()
+        input = torch.randn(3, 2, 30, 30, 30)
+        target = torch.empty(3, 30, 30, 30, dtype=torch.long).random_(2)
+        weight = torch.rand(3, 30, 30, 30)
+        output = loss(input, target, weight)
+        assert output.item() > 0
+
+    def test_pixel_wise_cross_entropy_2d(self):
+        loss = PixelWiseCrossEntropyLoss()
+        input = torch.randn(3, 2, 30, 30)
+        target = torch.empty(3, 30, 30, dtype=torch.long).random_(2)
+        weight = torch.rand(3, 30, 30)
+        output = loss(input, target, weight)
+        assert output.item() > 0
+
     def test_ignore_index_loss(self):
         loss = _MaskingLossWrapper(nn.BCEWithLogitsLoss(), ignore_index=-1)
         input = torch.rand((3, 3))
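On the `hdf5.py` change: `load_dataset` previously surfaced a bare `KeyError` when the configured `internal_path` was absent from the H5 file; the new assert fails fast with an explicit message instead. A minimal sketch of the membership check it relies on (the file name and dataset keys below are hypothetical):

```python
import h5py
import numpy as np

# Hypothetical file with only a 'raw' dataset; 'label' is deliberately missing.
with h5py.File('sample.h5', 'w') as f:
    f.create_dataset('raw', data=np.zeros((8, 8, 8), dtype='float32'))

with h5py.File('sample.h5', 'r') as f:
    assert 'raw' in f         # h5py files support `in` for dataset/group membership
    assert 'label' not in f   # the case the new assert now reports with a clear message
```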
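On the `_class_weights` change: `torch.autograd.Variable` has been a deprecated no-op wrapper since PyTorch 0.4, and `Variable(t, requires_grad=False)` is equivalent to `t.detach()`. A one-line sanity check of the replacement, mirroring the formula in the diff:

```python
import torch

probs = torch.rand(4, requires_grad=True)
class_weights = ((1. - probs).sum(-1) / probs.sum(-1)).detach()
assert not class_weights.requires_grad  # same intent as Variable(..., requires_grad=False)
```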
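The reworked `PixelWiseCrossEntropyLoss` drops the `class_weights` argument entirely (it is likewise no longer passed in `_create_loss`) and builds the one-hot target with `F.one_hot` instead of `expand_as_one_hot`, handling both 4D and 5D inputs. A usage sketch mirroring the new tests (shapes are illustrative):

```python
import torch
from pytorch3dunet.unet3d.losses import PixelWiseCrossEntropyLoss

loss = PixelWiseCrossEntropyLoss()

input = torch.randn(2, 3, 16, 16, 16, requires_grad=True)  # logits: (N, C, D, H, W)
target = torch.randint(0, 3, (2, 16, 16, 16))              # labels: (N, D, H, W)
weight = torch.rand(2, 16, 16, 16)                         # per-voxel weights: (N, D, H, W)

# scalar: mean of -weight * one_hot(target) * log_softmax(input)
output = loss(input, target, weight)
assert output.requires_grad
```

Note that when `ignore_index` is set, the masked voxels are zeroed out of the one-hot target but still counted in the final `mean()`.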
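The `SkipLastTargetChannelWrapper` change threads an optional per-pixel weight through to the wrapped criterion, so weighted losses can now sit behind the wrapper. A sketch, assuming a wrapped loss that accepts `(input, target, weight)` as the reworked `PixelWiseCrossEntropyLoss` does:

```python
import torch
from pytorch3dunet.unet3d.losses import PixelWiseCrossEntropyLoss, SkipLastTargetChannelWrapper

# squeeze_channel=True: after dropping the last target channel, collapse the
# remaining singleton channel so the label map matches what the loss expects
criterion = SkipLastTargetChannelWrapper(PixelWiseCrossEntropyLoss(), squeeze_channel=True)

input = torch.randn(2, 3, 16, 16, 16)             # logits: (N, C, D, H, W)
target = torch.randint(0, 3, (2, 2, 16, 16, 16))  # labels plus one extra channel to skip
weight = torch.rand(2, 16, 16, 16)                # per-voxel weights

output = criterion(input, target, weight)  # weight is forwarded to the wrapped loss
```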