From 927acdb6774095c6718d016b4b90565d28ceda00 Mon Sep 17 00:00:00 2001
From: Lucas Fidon
Date: Wed, 29 Apr 2020 12:51:38 +0200
Subject: [PATCH 1/3] added the Focal loss

---
 monai/losses/__init__.py   |   1 +
 monai/losses/focal_loss.py |  95 ++++++++++++++
 tests/test_focal_loss.py   | 250 +++++++++++++++++++++++++++++++++++++
 3 files changed, 346 insertions(+)
 create mode 100644 monai/losses/focal_loss.py
 create mode 100644 tests/test_focal_loss.py

diff --git a/monai/losses/__init__.py b/monai/losses/__init__.py
index beaeaf6dbf..94ec922549 100644
--- a/monai/losses/__init__.py
+++ b/monai/losses/__init__.py
@@ -10,3 +10,4 @@
 # limitations under the License.
 
 from .dice import *
+from .focal_loss import *
diff --git a/monai/losses/focal_loss.py b/monai/losses/focal_loss.py
new file mode 100644
index 0000000000..f49e2be024
--- /dev/null
+++ b/monai/losses/focal_loss.py
@@ -0,0 +1,95 @@
+# Copyright 2020 MONAI Consortium
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+# http://www.apache.org/licenses/LICENSE-2.0
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+from torch.autograd import Variable
+
+
+class FocalLoss(nn.Module):
+    """
+    PyTorch implementation of the Focal Loss.
+    [1] "Focal Loss for Dense Object Detection", T. Lin et al., ICCV 2017
+    """
+    def __init__(self, gamma=2., alpha=None, reduction='mean'):
+        """
+        Args:
+            gamma: (float) value of the exponent gamma in the definition
+                of the Focal loss.
+            alpha: (float or float list or None) weights to apply to the
+                voxels of each class. If None no weights are applied.
+            reduction: (string) Reduction operation to apply on the loss batch.
+                It can be 'mean', 'sum' or 'none' as in the standard PyTorch API
+                for loss functions.
+        """
+        super(FocalLoss, self).__init__()
+        # same default parameters as in the original paper [1]
+        self.gamma = gamma
+        self.alpha = alpha  # weight for the classes
+        if isinstance(alpha, (float, int)):
+            self.alpha = torch.Tensor([alpha, 1 - alpha])
+        if isinstance(alpha, list):
+            self.alpha = torch.Tensor(alpha)
+        self.reduction = reduction
+
+    def forward(self, input, target):
+        """
+        Args:
+            input: (tensor): the shape should be BNH[WD].
+            target: (tensor): the shape should be BNH[WD].
+        """
+        i = input
+        t = target
+        # Resize the input and target
+        if t.dim() < i.dim():
+            # Add a class dimension to the ground-truth segmentation
+            t = t.unsqueeze(1)  # N,H,W => N,1,H,W
+        if input.dim() > 2:
+            i = i.view(i.size(0), i.size(1), -1)  # N,C,H,W => N,C,H*W
+            t = t.view(t.size(0), t.size(1), -1)  # N,1,H,W => N,1,H*W
+        else:  # Compatibility with classification
+            i = i.unsqueeze(2)  # N,C => N,C,1
+            t = t.unsqueeze(2)  # N,1 => N,1,1
+
+        # Compute the log proba (more stable numerically than softmax)
+        logpt = F.log_softmax(i, dim=1)  # N,C,H*W
+        # Keep only log proba values of the ground truth class for each voxel
+        logpt = logpt.gather(1, t)  # N,C,H*W => N,1,H*W
+        logpt = torch.squeeze(logpt, dim=1)  # N,1,H*W => N,H*W
+
+        # Get the proba
+        pt = torch.exp(logpt)  # N,H*W
+
+        if self.alpha is not None:
+            if self.alpha.type() != i.data.type():
+                self.alpha = self.alpha.type_as(i.data)
+            # Select the correct weight for each voxel depending on its
+            # associated gt label
+            at = torch.unsqueeze(self.alpha, dim=0)  # C => 1,C
+            at = torch.unsqueeze(at, dim=2)  # 1,C => 1,C,1
+            at = at.expand((t.size(0), -1, t.size(2)))  # 1,C,1 => N,C,H*W
+            at = at.gather(1, t.data)  # selection of the weights => N,1,H*W
+            at = torch.squeeze(at, dim=1)  # N,1,H*W => N,H*W
+            # Multiply the log proba by their weights
+            logpt = logpt * Variable(at)
+
+        # Compute the loss mini-batch
+        weight = torch.pow(-pt + 1., self.gamma)
+        loss = torch.mean(-weight * logpt, dim=1)  # N
+
+        if self.reduction == 'sum':
+            return loss.sum()
+        elif self.reduction == 'none':
+            return loss
+        # Default is mean reduction
+        else:
+            return loss.mean()
diff --git a/tests/test_focal_loss.py b/tests/test_focal_loss.py
new file mode 100644
index 0000000000..9fb05ba57a
--- /dev/null
+++ b/tests/test_focal_loss.py
@@ -0,0 +1,250 @@
+# Copyright 2020 MONAI Consortium
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+# http://www.apache.org/licenses/LICENSE-2.0
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+ +import torch.nn.functional as F +import torch.nn as nn +import torch.optim as optim +from torch.autograd import Variable +import torch +import unittest +from monai.losses import FocalLoss + + +class TestFocalLoss(unittest.TestCase): + def test_consistency_with_cross_entropy_2d(self): + # For gamma=0 the focal loss reduces to the cross entropy loss + focal_loss = FocalLoss(gamma=0., reduction='mean') + ce = nn.CrossEntropyLoss(reduction='mean') + max_error = 0 + class_num = 10 + batch_size = 128 + for _ in range(100): + # Create a random tensor of shape (batch_size, class_num, 8, 4) + x = torch.rand(batch_size, class_num, 8, 4) + x = Variable(x.cuda()) + # Create a random batch of classes + l = torch.randint(low=0, high=class_num, size=(batch_size, 8, 4)) + l = l.long() + l = Variable(l.cuda()) + output0 = focal_loss.forward(x, l) + output1 = ce.forward(x, l) + a = float(output0.cpu().detach()) + b = float(output1.cpu().detach()) + if abs(a - b) > max_error: max_error = abs(a - b) + self.assertAlmostEqual(max_error, 0., places=3) + + def test_consistency_with_cross_entropy_classification(self): + # for gamma=0 the focal loss reduces to the cross entropy loss + focal_loss = FocalLoss(gamma=0., reduction='mean') + ce = nn.CrossEntropyLoss(reduction='mean') + max_error = 0 + class_num = 10 + batch_size = 128 + for _ in range(100): + # Create a random scores tensor of shape (batch_size, class_num) + x = torch.rand(batch_size, class_num) + x = Variable(x.cuda()) + # Create a random batch of classes + l = torch.randint(low=0, high=class_num, size=(batch_size,)) + l = l.long() + l = Variable(l.cuda()) + output0 = focal_loss.forward(x, l) + output1 = ce.forward(x, l) + a = float(output0.cpu().detach()) + b = float(output1.cpu().detach()) + if abs(a - b) > max_error: max_error = abs(a - b) + self.assertAlmostEqual(max_error, 0., places=3) + + def test_bin_seg_2d(self): + # define 2d examples + target = torch.tensor( + [[0,0,0,0], + [0,1,1,0], + [0,1,1,0], + [0,0,0,0]] + ) + # add another dimension corresponding to the batch (batch size = 1 here) + target = target.unsqueeze(0) # shape (1, H, W) + pred_very_good = 1000 * F.one_hot( + target, num_classes=2).permute(0, 3, 1, 2).float() + + # initialize the mean dice loss + loss = FocalLoss() + + # focal loss for pred_very_good should be close to 0 + focal_loss_good = float(loss.forward(pred_very_good, target).cpu()) + self.assertAlmostEqual(focal_loss_good, 0., places=3) + + # Same test, but for target with a class dimension + target = target.unsqueeze(1) # shape (1, 1, H, W) + focal_loss_good = float(loss.forward(pred_very_good, target).cpu()) + self.assertAlmostEqual(focal_loss_good, 0., places=3) + + def test_empty_class_2d(self): + num_classes = 2 + # define 2d examples + target = torch.tensor( + [[0,0,0,0], + [0,0,0,0], + [0,0,0,0], + [0,0,0,0]] + ) + # add another dimension corresponding to the batch (batch size = 1 here) + target = target.unsqueeze(0) # shape (1, H, W) + pred_very_good = 1000 * F.one_hot( + target, num_classes=num_classes).permute(0, 3, 1, 2).float() + + # initialize the mean dice loss + loss = FocalLoss() + + # focal loss for pred_very_good should be close to 0 + focal_loss_good = float(loss.forward(pred_very_good, target).cpu()) + self.assertAlmostEqual(focal_loss_good, 0., places=3) + + def test_multi_class_seg_2d(self): + num_classes = 6 # labels 0 to 5 + # define 2d examples + target = torch.tensor( + [[0,0,0,0], + [0,1,2,0], + [0,3,4,0], + [0,0,0,0]] + ) + # add another dimension corresponding to the batch (batch size = 
1 here) + target = target.unsqueeze(0) # shape (1, H, W) + pred_very_good = 1000 * F.one_hot( + target, num_classes=num_classes).permute(0, 3, 1, 2).float() + + # initialize the mean dice loss + loss = FocalLoss() + + # focal loss for pred_very_good should be close to 0 + focal_loss_good = float(loss.forward(pred_very_good, target).cpu()) + self.assertAlmostEqual(focal_loss_good, 0., places=3) + + def test_bin_seg_3d(self): + # define 2d examples + target = torch.tensor( + [ + # raw 0 + [[0, 0, 0, 0], + [0, 1, 1, 0], + [0, 1, 1, 0], + [0, 0, 0, 0]], + # raw 1 + [[0, 0, 0, 0], + [0, 1, 1, 0], + [0, 1, 1, 0], + [0, 0, 0, 0]], + # raw 2 + [[0, 0, 0, 0], + [0, 1, 1, 0], + [0, 1, 1, 0], + [0, 0, 0, 0]] + ] + ) + # add another dimension corresponding to the batch (batch size = 1 here) + target = target.unsqueeze(0) # shape (1, H, W, D) + pred_very_good = 1000 * F.one_hot( + target, num_classes=2).permute(0, 4, 1, 2, 3).float() + + # initialize the mean dice loss + loss = FocalLoss() + + # focal loss for pred_very_good should be close to 0 + focal_loss_good = float(loss.forward(pred_very_good, target).cpu()) + self.assertAlmostEqual(focal_loss_good, 0., places=3) + + def test_convergence(self): + """ + The goal of this test is to assess if the gradient of the loss function + is correct by testing if we can train a one layer neural network + to segment one image. + We verify that the loss is decreasing in almost all SGD steps. + """ + learning_rate = 0.001 + max_iter = 20 + + # define a simple 3d example + target_seg = torch.tensor( + [ + # raw 0 + [[0, 0, 0, 0], + [0, 1, 1, 0], + [0, 1, 1, 0], + [0, 0, 0, 0]], + # raw 1 + [[0, 0, 0, 0], + [0, 1, 1, 0], + [0, 1, 1, 0], + [0, 0, 0, 0]], + # raw 2 + [[0, 0, 0, 0], + [0, 1, 1, 0], + [0, 1, 1, 0], + [0, 0, 0, 0]] + ] + ) + target_seg = torch.unsqueeze(target_seg, dim=0) + image = 12 * target_seg + 27 + image = image.float() + num_classes = 2 + num_voxels = 3 * 4 * 4 + # define a one layer model + class OnelayerNet(nn.Module): + def __init__(self): + super(OnelayerNet, self).__init__() + self.layer = nn.Linear(num_voxels, num_voxels * num_classes) + def forward(self, x): + x = x.view(-1, num_voxels) + x = self.layer(x) + x = x.view(-1, num_classes, 3, 4, 4) + return x + + # initialise the network + net = OnelayerNet() + + # initialize the loss + loss = FocalLoss() + + # initialize an SGD + optimizer = optim.SGD(net.parameters(), lr=learning_rate, momentum=0.9) + + loss_history = [] + # train the network + for _ in range(max_iter): + # set the gradient to zero + optimizer.zero_grad() + + # forward pass + output = net(image) + loss_val = loss(output, target_seg) + + # backward pass + loss_val.backward() + optimizer.step() + + # stats + loss_history.append(loss_val.item()) + + # count the number of SGD steps in which the loss decreases + num_decreasing_steps = 0 + for i in range(len(loss_history) - 1): + if loss_history[i] > loss_history[i+1]: + num_decreasing_steps += 1 + decreasing_steps_ratio = float(num_decreasing_steps) / (len(loss_history) - 1) + + # verify that the loss is decreasing for sufficiently many SGD steps + self.assertTrue(decreasing_steps_ratio > 0.9) + + +if __name__ == '__main__': + unittest.main() From d364fa32770cac231b6644c89ee658fa8c84228e Mon Sep 17 00:00:00 2001 From: Lucas Fidon Date: Wed, 6 May 2020 12:29:01 +0200 Subject: [PATCH 2/3] focal loss inherits from _WeightedLoss and improved documentation --- docs/source/losses.rst | 5 ++ monai/losses/focal_loss.py | 62 +++++++++++++------------ tests/test_focal_loss.py | 95 
+++++++++++++++++++------------------- 3 files changed, 85 insertions(+), 77 deletions(-) diff --git a/docs/source/losses.rst b/docs/source/losses.rst index 50e4563ca1..3d05cd33a2 100644 --- a/docs/source/losses.rst +++ b/docs/source/losses.rst @@ -21,3 +21,8 @@ Segmentation Losses ~~~~~~~~~~~~~~~~~~~~~ .. autoclass:: GeneralizedDiceLoss :members: + +`FocalLoss` +~~~~~~~~~~~ +.. autoclass:: monai.losses.focal_loss.FocalLoss + :members: diff --git a/monai/losses/focal_loss.py b/monai/losses/focal_loss.py index f49e2be024..64da5654de 100644 --- a/monai/losses/focal_loss.py +++ b/monai/losses/focal_loss.py @@ -10,36 +10,35 @@ # limitations under the License. import torch -import torch.nn as nn import torch.nn.functional as F -from torch.autograd import Variable +from torch.nn.modules.loss import _WeightedLoss -class FocalLoss(nn.Module): +class FocalLoss(_WeightedLoss): """ PyTorch implementation of the Focal Loss. [1] "Focal Loss for Dense Object Detection", T. Lin et al., ICCV 2017 """ - def __init__(self, gamma=2., alpha=None, reduction='mean'): + def __init__(self, gamma=2., weight=None, reduction='mean'): """ Args: gamma: (float) value of the exponent gamma in the definition of the Focal loss. - alpha: (float or float list or None) weights to apply to the + weight: (tensor) weights to apply to the voxels of each class. If None no weights are applied. - reduction: (string) Reduction operation to apply on the loss batch. + This corresponds to the weights \alpha in [1]. + reduction: (string) reduction operation to apply on the loss batch. It can be 'mean', 'sum' or 'none' as in the standard PyTorch API for loss functions. + + Exemple: + pred = torch.tensor([[1, 0], [0, 1], [1, 0]], dtype=torch.float32) + grnd = torch.tensor([0, 1 ,0], dtype=torch.int64) + fl = FocalLoss() + fl(pred, grnd) """ - super(FocalLoss, self).__init__() - # same default parameters as in the original paper [1] + super(FocalLoss, self).__init__(weight=weight, reduction=reduction) self.gamma = gamma - self.alpha = alpha # weight for the classes - if isinstance(alpha, (float, int)): - self.alpha = torch.Tensor([alpha, 1 - alpha]) - if isinstance(alpha, list): - self.alpha = torch.Tensor(alpha) - self.reduction = reduction def forward(self, input, target): """ @@ -49,40 +48,43 @@ def forward(self, input, target): """ i = input t = target - # Resize the input and target + if t.dim() < i.dim(): - # Add a class dimension to the ground-truth segmentation + # Add a class dimension to the ground-truth segmentation. t = t.unsqueeze(1) # N,H,W => N,1,H,W + + # Change the shape of input and target to + # num_batch x num_class x num_voxels. if input.dim() > 2: i = i.view(i.size(0), i.size(1), -1) # N,C,H,W => N,C,H*W t = t.view(t.size(0), t.size(1), -1) # N,1,H,W => N,1,H*W - else: # Compatibility with classification + else: # Compatibility with classification. i = i.unsqueeze(2) # N,C => N,C,1 t = t.unsqueeze(2) # N,1 => N,1,1 - # Compute the log proba (more stable numerically than softmax) + # Compute the log proba (more stable numerically than softmax). logpt = F.log_softmax(i, dim=1) # N,C,H*W - # Keep only log proba values of the ground truth class for each voxel + # Keep only log proba values of the ground truth class for each voxel. 
logpt = logpt.gather(1, t) # N,C,H*W => N,1,H*W logpt = torch.squeeze(logpt, dim=1) # N,1,H*W => N,H*W # Get the proba pt = torch.exp(logpt) # N,H*W - if self.alpha is not None: - if self.alpha.type() != i.data.type(): - self.alpha = self.alpha.type_as(i.data) - # Select the correct weight for each voxel depending on its - # associated gt label - at = torch.unsqueeze(self.alpha, dim=0) # C => 1,C - at = torch.unsqueeze(at, dim=2) # 1,C => 1,C,1 + if self.weight is not None: + if self.weight.type() != i.data.type(): + self.weight = self.weight.type_as(i.data) + # Convert the weight to a map in which each voxel + # has the weight associated with the ground-truth label + # associated with this voxel in target. + at = self.weight[None, :, None] # C => 1,C,1 at = at.expand((t.size(0), -1, t.size(2))) # 1,C,1 => N,C,H*W at = at.gather(1, t.data) # selection of the weights => N,1,H*W at = torch.squeeze(at, dim=1) # N,1,H*W => N,H*W - # Multiply the log proba by their weights - logpt = logpt * Variable(at) + # Multiply the log proba by their weights. + logpt = logpt * at - # Compute the loss mini-batch + # Compute the loss mini-batch. weight = torch.pow(-pt + 1., self.gamma) loss = torch.mean(-weight * logpt, dim=1) # N @@ -90,6 +92,6 @@ def forward(self, input, target): return loss.sum() elif self.reduction == 'none': return loss - # Default is mean reduction + # Default is mean reduction. else: return loss.mean() diff --git a/tests/test_focal_loss.py b/tests/test_focal_loss.py index 9fb05ba57a..a45c3afd4e 100644 --- a/tests/test_focal_loss.py +++ b/tests/test_focal_loss.py @@ -12,7 +12,6 @@ import torch.nn.functional as F import torch.nn as nn import torch.optim as optim -from torch.autograd import Variable import torch import unittest from monai.losses import FocalLoss @@ -28,17 +27,19 @@ def test_consistency_with_cross_entropy_2d(self): batch_size = 128 for _ in range(100): # Create a random tensor of shape (batch_size, class_num, 8, 4) - x = torch.rand(batch_size, class_num, 8, 4) - x = Variable(x.cuda()) + x = torch.rand(batch_size, class_num, 8, 4, requires_grad=True) # Create a random batch of classes l = torch.randint(low=0, high=class_num, size=(batch_size, 8, 4)) l = l.long() - l = Variable(l.cuda()) + if torch.cuda.is_available(): + x = x.cuda() + l = l.cuda() output0 = focal_loss.forward(x, l) output1 = ce.forward(x, l) a = float(output0.cpu().detach()) b = float(output1.cpu().detach()) - if abs(a - b) > max_error: max_error = abs(a - b) + if abs(a - b) > max_error: + max_error = abs(a - b) self.assertAlmostEqual(max_error, 0., places=3) def test_consistency_with_cross_entropy_classification(self): @@ -50,26 +51,28 @@ def test_consistency_with_cross_entropy_classification(self): batch_size = 128 for _ in range(100): # Create a random scores tensor of shape (batch_size, class_num) - x = torch.rand(batch_size, class_num) - x = Variable(x.cuda()) + x = torch.rand(batch_size, class_num, requires_grad=True) # Create a random batch of classes l = torch.randint(low=0, high=class_num, size=(batch_size,)) l = l.long() - l = Variable(l.cuda()) + if torch.cuda.is_available(): + x = x.cuda() + l = l.cuda() output0 = focal_loss.forward(x, l) output1 = ce.forward(x, l) a = float(output0.cpu().detach()) b = float(output1.cpu().detach()) - if abs(a - b) > max_error: max_error = abs(a - b) + if abs(a - b) > max_error: + max_error = abs(a - b) self.assertAlmostEqual(max_error, 0., places=3) def test_bin_seg_2d(self): # define 2d examples target = torch.tensor( - [[0,0,0,0], - [0,1,1,0], - [0,1,1,0], - 
[0,0,0,0]] + [[0, 0, 0, 0], + [0, 1, 1, 0], + [0, 1, 1, 0], + [0, 0, 0, 0]] ) # add another dimension corresponding to the batch (batch size = 1 here) target = target.unsqueeze(0) # shape (1, H, W) @@ -92,10 +95,10 @@ def test_empty_class_2d(self): num_classes = 2 # define 2d examples target = torch.tensor( - [[0,0,0,0], - [0,0,0,0], - [0,0,0,0], - [0,0,0,0]] + [[0, 0, 0, 0], + [0, 0, 0, 0], + [0, 0, 0, 0], + [0, 0, 0, 0]] ) # add another dimension corresponding to the batch (batch size = 1 here) target = target.unsqueeze(0) # shape (1, H, W) @@ -113,10 +116,10 @@ def test_multi_class_seg_2d(self): num_classes = 6 # labels 0 to 5 # define 2d examples target = torch.tensor( - [[0,0,0,0], - [0,1,2,0], - [0,3,4,0], - [0,0,0,0]] + [[0, 0, 0, 0], + [0, 1, 2, 0], + [0, 3, 4, 0], + [0, 0, 0, 0]] ) # add another dimension corresponding to the batch (batch size = 1 here) target = target.unsqueeze(0) # shape (1, H, W) @@ -132,25 +135,23 @@ def test_multi_class_seg_2d(self): def test_bin_seg_3d(self): # define 2d examples - target = torch.tensor( - [ + target = torch.tensor([ # raw 0 [[0, 0, 0, 0], [0, 1, 1, 0], [0, 1, 1, 0], [0, 0, 0, 0]], # raw 1 - [[0, 0, 0, 0], - [0, 1, 1, 0], - [0, 1, 1, 0], - [0, 0, 0, 0]], + [[0, 0, 0, 0], + [0, 1, 1, 0], + [0, 1, 1, 0], + [0, 0, 0, 0]], # raw 2 - [[0, 0, 0, 0], - [0, 1, 1, 0], - [0, 1, 1, 0], - [0, 0, 0, 0]] - ] - ) + [[0, 0, 0, 0], + [0, 1, 1, 0], + [0, 1, 1, 0], + [0, 0, 0, 0]] + ]) # add another dimension corresponding to the batch (batch size = 1 here) target = target.unsqueeze(0) # shape (1, H, W, D) pred_very_good = 1000 * F.one_hot( @@ -174,35 +175,35 @@ def test_convergence(self): max_iter = 20 # define a simple 3d example - target_seg = torch.tensor( - [ + target_seg = torch.tensor([ # raw 0 + [[0, 0, 0, 0], + [0, 1, 1, 0], + [0, 1, 1, 0], + [0, 0, 0, 0]], + # raw 1 [[0, 0, 0, 0], [0, 1, 1, 0], [0, 1, 1, 0], [0, 0, 0, 0]], - # raw 1 - [[0, 0, 0, 0], - [0, 1, 1, 0], - [0, 1, 1, 0], - [0, 0, 0, 0]], # raw 2 - [[0, 0, 0, 0], - [0, 1, 1, 0], - [0, 1, 1, 0], - [0, 0, 0, 0]] - ] - ) + [[0, 0, 0, 0], + [0, 1, 1, 0], + [0, 1, 1, 0], + [0, 0, 0, 0]] + ]) target_seg = torch.unsqueeze(target_seg, dim=0) image = 12 * target_seg + 27 image = image.float() num_classes = 2 num_voxels = 3 * 4 * 4 + # define a one layer model class OnelayerNet(nn.Module): def __init__(self): super(OnelayerNet, self).__init__() self.layer = nn.Linear(num_voxels, num_voxels * num_classes) + def forward(self, x): x = x.view(-1, num_voxels) x = self.layer(x) @@ -238,7 +239,7 @@ def forward(self, x): # count the number of SGD steps in which the loss decreases num_decreasing_steps = 0 for i in range(len(loss_history) - 1): - if loss_history[i] > loss_history[i+1]: + if loss_history[i] > loss_history[i + 1]: num_decreasing_steps += 1 decreasing_steps_ratio = float(num_decreasing_steps) / (len(loss_history) - 1) From b6b5b7ddc06a7171d1e64bfcdb20c44c178cd24e Mon Sep 17 00:00:00 2001 From: monai-bot Date: Thu, 7 May 2020 16:48:32 +0000 Subject: [PATCH 3/3] [MONAI] python code formatting --- monai/losses/__init__.py | 2 +- monai/losses/focal_loss.py | 9 +-- tests/test_focal_loss.py | 111 +++++++++++++------------------------ 3 files changed, 45 insertions(+), 77 deletions(-) diff --git a/monai/losses/__init__.py b/monai/losses/__init__.py index e3124bf0af..3162078810 100644 --- a/monai/losses/__init__.py +++ b/monai/losses/__init__.py @@ -11,4 +11,4 @@ from .dice import * from .focal_loss import * -from .tversky import * \ No newline at end of file +from .tversky import * diff --git 
a/monai/losses/focal_loss.py b/monai/losses/focal_loss.py index 64da5654de..b9e173d879 100644 --- a/monai/losses/focal_loss.py +++ b/monai/losses/focal_loss.py @@ -19,7 +19,8 @@ class FocalLoss(_WeightedLoss): PyTorch implementation of the Focal Loss. [1] "Focal Loss for Dense Object Detection", T. Lin et al., ICCV 2017 """ - def __init__(self, gamma=2., weight=None, reduction='mean'): + + def __init__(self, gamma=2.0, weight=None, reduction="mean"): """ Args: gamma: (float) value of the exponent gamma in the definition @@ -85,12 +86,12 @@ def forward(self, input, target): logpt = logpt * at # Compute the loss mini-batch. - weight = torch.pow(-pt + 1., self.gamma) + weight = torch.pow(-pt + 1.0, self.gamma) loss = torch.mean(-weight * logpt, dim=1) # N - if self.reduction == 'sum': + if self.reduction == "sum": return loss.sum() - elif self.reduction == 'none': + elif self.reduction == "none": return loss # Default is mean reduction. else: diff --git a/tests/test_focal_loss.py b/tests/test_focal_loss.py index a45c3afd4e..1a51b9ca2a 100644 --- a/tests/test_focal_loss.py +++ b/tests/test_focal_loss.py @@ -20,8 +20,8 @@ class TestFocalLoss(unittest.TestCase): def test_consistency_with_cross_entropy_2d(self): # For gamma=0 the focal loss reduces to the cross entropy loss - focal_loss = FocalLoss(gamma=0., reduction='mean') - ce = nn.CrossEntropyLoss(reduction='mean') + focal_loss = FocalLoss(gamma=0.0, reduction="mean") + ce = nn.CrossEntropyLoss(reduction="mean") max_error = 0 class_num = 10 batch_size = 128 @@ -40,12 +40,12 @@ def test_consistency_with_cross_entropy_2d(self): b = float(output1.cpu().detach()) if abs(a - b) > max_error: max_error = abs(a - b) - self.assertAlmostEqual(max_error, 0., places=3) + self.assertAlmostEqual(max_error, 0.0, places=3) def test_consistency_with_cross_entropy_classification(self): # for gamma=0 the focal loss reduces to the cross entropy loss - focal_loss = FocalLoss(gamma=0., reduction='mean') - ce = nn.CrossEntropyLoss(reduction='mean') + focal_loss = FocalLoss(gamma=0.0, reduction="mean") + ce = nn.CrossEntropyLoss(reduction="mean") max_error = 0 class_num = 10 batch_size = 128 @@ -64,105 +64,79 @@ def test_consistency_with_cross_entropy_classification(self): b = float(output1.cpu().detach()) if abs(a - b) > max_error: max_error = abs(a - b) - self.assertAlmostEqual(max_error, 0., places=3) + self.assertAlmostEqual(max_error, 0.0, places=3) def test_bin_seg_2d(self): # define 2d examples - target = torch.tensor( - [[0, 0, 0, 0], - [0, 1, 1, 0], - [0, 1, 1, 0], - [0, 0, 0, 0]] - ) + target = torch.tensor([[0, 0, 0, 0], [0, 1, 1, 0], [0, 1, 1, 0], [0, 0, 0, 0]]) # add another dimension corresponding to the batch (batch size = 1 here) target = target.unsqueeze(0) # shape (1, H, W) - pred_very_good = 1000 * F.one_hot( - target, num_classes=2).permute(0, 3, 1, 2).float() + pred_very_good = 1000 * F.one_hot(target, num_classes=2).permute(0, 3, 1, 2).float() # initialize the mean dice loss loss = FocalLoss() # focal loss for pred_very_good should be close to 0 focal_loss_good = float(loss.forward(pred_very_good, target).cpu()) - self.assertAlmostEqual(focal_loss_good, 0., places=3) + self.assertAlmostEqual(focal_loss_good, 0.0, places=3) # Same test, but for target with a class dimension target = target.unsqueeze(1) # shape (1, 1, H, W) focal_loss_good = float(loss.forward(pred_very_good, target).cpu()) - self.assertAlmostEqual(focal_loss_good, 0., places=3) + self.assertAlmostEqual(focal_loss_good, 0.0, places=3) def test_empty_class_2d(self): num_classes = 
2 # define 2d examples - target = torch.tensor( - [[0, 0, 0, 0], - [0, 0, 0, 0], - [0, 0, 0, 0], - [0, 0, 0, 0]] - ) + target = torch.tensor([[0, 0, 0, 0], [0, 0, 0, 0], [0, 0, 0, 0], [0, 0, 0, 0]]) # add another dimension corresponding to the batch (batch size = 1 here) target = target.unsqueeze(0) # shape (1, H, W) - pred_very_good = 1000 * F.one_hot( - target, num_classes=num_classes).permute(0, 3, 1, 2).float() + pred_very_good = 1000 * F.one_hot(target, num_classes=num_classes).permute(0, 3, 1, 2).float() # initialize the mean dice loss loss = FocalLoss() # focal loss for pred_very_good should be close to 0 focal_loss_good = float(loss.forward(pred_very_good, target).cpu()) - self.assertAlmostEqual(focal_loss_good, 0., places=3) + self.assertAlmostEqual(focal_loss_good, 0.0, places=3) def test_multi_class_seg_2d(self): num_classes = 6 # labels 0 to 5 # define 2d examples - target = torch.tensor( - [[0, 0, 0, 0], - [0, 1, 2, 0], - [0, 3, 4, 0], - [0, 0, 0, 0]] - ) + target = torch.tensor([[0, 0, 0, 0], [0, 1, 2, 0], [0, 3, 4, 0], [0, 0, 0, 0]]) # add another dimension corresponding to the batch (batch size = 1 here) target = target.unsqueeze(0) # shape (1, H, W) - pred_very_good = 1000 * F.one_hot( - target, num_classes=num_classes).permute(0, 3, 1, 2).float() + pred_very_good = 1000 * F.one_hot(target, num_classes=num_classes).permute(0, 3, 1, 2).float() # initialize the mean dice loss loss = FocalLoss() # focal loss for pred_very_good should be close to 0 focal_loss_good = float(loss.forward(pred_very_good, target).cpu()) - self.assertAlmostEqual(focal_loss_good, 0., places=3) + self.assertAlmostEqual(focal_loss_good, 0.0, places=3) def test_bin_seg_3d(self): # define 2d examples - target = torch.tensor([ - # raw 0 - [[0, 0, 0, 0], - [0, 1, 1, 0], - [0, 1, 1, 0], - [0, 0, 0, 0]], - # raw 1 - [[0, 0, 0, 0], - [0, 1, 1, 0], - [0, 1, 1, 0], - [0, 0, 0, 0]], - # raw 2 - [[0, 0, 0, 0], - [0, 1, 1, 0], - [0, 1, 1, 0], - [0, 0, 0, 0]] - ]) + target = torch.tensor( + [ + # raw 0 + [[0, 0, 0, 0], [0, 1, 1, 0], [0, 1, 1, 0], [0, 0, 0, 0]], + # raw 1 + [[0, 0, 0, 0], [0, 1, 1, 0], [0, 1, 1, 0], [0, 0, 0, 0]], + # raw 2 + [[0, 0, 0, 0], [0, 1, 1, 0], [0, 1, 1, 0], [0, 0, 0, 0]], + ] + ) # add another dimension corresponding to the batch (batch size = 1 here) target = target.unsqueeze(0) # shape (1, H, W, D) - pred_very_good = 1000 * F.one_hot( - target, num_classes=2).permute(0, 4, 1, 2, 3).float() + pred_very_good = 1000 * F.one_hot(target, num_classes=2).permute(0, 4, 1, 2, 3).float() # initialize the mean dice loss loss = FocalLoss() # focal loss for pred_very_good should be close to 0 focal_loss_good = float(loss.forward(pred_very_good, target).cpu()) - self.assertAlmostEqual(focal_loss_good, 0., places=3) + self.assertAlmostEqual(focal_loss_good, 0.0, places=3) def test_convergence(self): """ @@ -175,23 +149,16 @@ def test_convergence(self): max_iter = 20 # define a simple 3d example - target_seg = torch.tensor([ - # raw 0 - [[0, 0, 0, 0], - [0, 1, 1, 0], - [0, 1, 1, 0], - [0, 0, 0, 0]], - # raw 1 - [[0, 0, 0, 0], - [0, 1, 1, 0], - [0, 1, 1, 0], - [0, 0, 0, 0]], - # raw 2 - [[0, 0, 0, 0], - [0, 1, 1, 0], - [0, 1, 1, 0], - [0, 0, 0, 0]] - ]) + target_seg = torch.tensor( + [ + # raw 0 + [[0, 0, 0, 0], [0, 1, 1, 0], [0, 1, 1, 0], [0, 0, 0, 0]], + # raw 1 + [[0, 0, 0, 0], [0, 1, 1, 0], [0, 1, 1, 0], [0, 0, 0, 0]], + # raw 2 + [[0, 0, 0, 0], [0, 1, 1, 0], [0, 1, 1, 0], [0, 0, 0, 0]], + ] + ) target_seg = torch.unsqueeze(target_seg, dim=0) image = 12 * target_seg + 27 image = image.float() @@ 
-247,5 +214,5 @@ def forward(self, x):
         self.assertTrue(decreasing_steps_ratio > 0.9)
 
 
-if __name__ == '__main__':
+if __name__ == "__main__":
     unittest.main()
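
Usage sketch: a minimal, illustrative example of calling the FocalLoss introduced by this series, assuming the final API after patches 2 and 3 (gamma, weight, reduction); the batch size, class count, spatial shape, and class weights below are made up for illustration.

    import torch
    from monai.losses import FocalLoss

    batch_size, num_classes = 2, 3
    # Raw class scores of shape (N, C, H, W); softmax is applied inside the loss.
    logits = torch.rand(batch_size, num_classes, 4, 4, requires_grad=True)
    # Integer ground-truth labels of shape (N, H, W); the class dimension is added internally.
    labels = torch.randint(low=0, high=num_classes, size=(batch_size, 4, 4))

    # Unweighted focal loss with the default gamma=2.0 and mean reduction.
    loss = FocalLoss()(logits, labels)
    loss.backward()

    # Per-class weights (the alpha term of the paper) are passed as a 1D tensor
    # of length num_classes, matching the class dimension of the input.
    weighted_loss = FocalLoss(gamma=2.0, weight=torch.tensor([0.2, 0.4, 0.4]))
    print(loss.item(), weighted_loss(logits, labels).item())

As the tests above check, setting gamma=0 and leaving weight=None recovers torch.nn.CrossEntropyLoss, which is a quick sanity check for any new use of this loss.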