Focal loss (#325)
* added the Focal loss

* focal loss inherits from _WeightedLoss and improved documentation

* [MONAI] python code formatting

Co-authored-by: monai-bot <[email protected]>
LucasFidon and monai-bot authored May 7, 2020
1 parent 57abc3b commit 5fcbf5a
Showing 4 changed files with 321 additions and 0 deletions.
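For context, the loss implemented below follows the standard focal loss formulation from Lin et al. [1], where p_t is the softmax probability of the ground-truth class, gamma is the focusing exponent, and alpha_t the optional per-class weight (this is a restatement of the paper's formula, not text from the diff):

    FL(p_t) = -\alpha_t \, (1 - p_t)^{\gamma} \, \log(p_t)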
4 changes: 4 additions & 0 deletions docs/source/losses.rst
@@ -22,6 +22,10 @@ Segmentation Losses
.. autoclass:: GeneralizedDiceLoss
:members:

`FocalLoss`
~~~~~~~~~~~
.. autoclass:: monai.losses.focal_loss.FocalLoss
:members:

.. automodule:: monai.losses.tversky
.. currentmodule:: monai.losses.tversky
1 change: 1 addition & 0 deletions monai/losses/__init__.py
@@ -10,4 +10,5 @@
# limitations under the License.

from .dice import *
from .focal_loss import *
from .tversky import *
98 changes: 98 additions & 0 deletions monai/losses/focal_loss.py
@@ -0,0 +1,98 @@
# Copyright 2020 MONAI Consortium
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
# http://www.apache.org/licenses/LICENSE-2.0
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import torch
import torch.nn.functional as F
from torch.nn.modules.loss import _WeightedLoss


class FocalLoss(_WeightedLoss):
"""
PyTorch implementation of the Focal Loss.
[1] "Focal Loss for Dense Object Detection", T. Lin et al., ICCV 2017
"""

def __init__(self, gamma=2.0, weight=None, reduction="mean"):
"""
Args:
gamma: (float) value of the exponent gamma in the definition
of the Focal loss.
weight: (tensor) weights to apply to the
voxels of each class. If None no weights are applied.
This corresponds to the weights \alpha in [1].
reduction: (string) reduction operation to apply on the loss batch.
It can be 'mean', 'sum' or 'none' as in the standard PyTorch API
for loss functions.
Exemple:
pred = torch.tensor([[1, 0], [0, 1], [1, 0]], dtype=torch.float32)
grnd = torch.tensor([0, 1 ,0], dtype=torch.int64)
fl = FocalLoss()
fl(pred, grnd)
"""
super(FocalLoss, self).__init__(weight=weight, reduction=reduction)
self.gamma = gamma

def forward(self, input, target):
"""
Args:
            input: (tensor) the shape should be BNH[WD], where N is the number of classes.
            target: (tensor) the shape should be B1H[WD] or BH[WD], with one
                ground-truth class index per voxel.
"""
i = input
t = target

if t.dim() < i.dim():
# Add a class dimension to the ground-truth segmentation.
t = t.unsqueeze(1) # N,H,W => N,1,H,W

# Change the shape of input and target to
# num_batch x num_class x num_voxels.
if input.dim() > 2:
i = i.view(i.size(0), i.size(1), -1) # N,C,H,W => N,C,H*W
t = t.view(t.size(0), t.size(1), -1) # N,1,H,W => N,1,H*W
else: # Compatibility with classification.
i = i.unsqueeze(2) # N,C => N,C,1
t = t.unsqueeze(2) # N,1 => N,1,1

        # Compute the log-probabilities (log_softmax is more numerically stable
        # than taking the log of a softmax).
        logpt = F.log_softmax(i, dim=1)  # N,C,H*W
        # Keep only the log-probability of the ground-truth class for each voxel.
        logpt = logpt.gather(1, t)  # N,C,H*W => N,1,H*W
        logpt = torch.squeeze(logpt, dim=1)  # N,1,H*W => N,H*W

        # Recover the probabilities of the ground-truth class.
        pt = torch.exp(logpt)  # N,H*W

if self.weight is not None:
if self.weight.type() != i.data.type():
self.weight = self.weight.type_as(i.data)
            # Build a weight map in which each voxel carries the weight
            # of its ground-truth class (the alpha term in [1]).
            at = self.weight[None, :, None]  # C => 1,C,1
            at = at.expand((t.size(0), -1, t.size(2)))  # 1,C,1 => N,C,H*W
            at = at.gather(1, t.data)  # select the weights => N,1,H*W
            at = torch.squeeze(at, dim=1)  # N,1,H*W => N,H*W
            # Weight the log-probabilities.
            logpt = logpt * at

        # Compute the focal term (1 - pt)^gamma and the loss for each sample.
        weight = torch.pow(-pt + 1.0, self.gamma)
        loss = torch.mean(-weight * logpt, dim=1)  # one value per sample in the batch (N)

if self.reduction == "sum":
return loss.sum()
elif self.reduction == "none":
return loss
# Default is mean reduction.
else:
return loss.mean()
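A minimal usage sketch of the class added above, with per-class weights for a 2D binary segmentation (shapes follow the docstring; the tensors and values here are illustrative assumptions, not part of the diff):

import torch

from monai.losses import FocalLoss

# Hypothetical logits for a batch of one image, 2 classes, 4x4 voxels (B, N, H, W).
pred = torch.randn(1, 2, 4, 4)
# Ground-truth class index per voxel, shape (B, H, W); a class dim of size 1 also works.
grnd = torch.randint(low=0, high=2, size=(1, 4, 4))
# Optional per-class weights (the alpha term); gamma=2.0 matches the constructor's default.
fl = FocalLoss(gamma=2.0, weight=torch.tensor([0.25, 0.75]), reduction="mean")
loss = fl(pred, grnd)  # scalar tensor when reduction="mean"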
218 changes: 218 additions & 0 deletions tests/test_focal_loss.py
@@ -0,0 +1,218 @@
# Copyright 2020 MONAI Consortium
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
# http://www.apache.org/licenses/LICENSE-2.0
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import unittest

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim

from monai.losses import FocalLoss


class TestFocalLoss(unittest.TestCase):
def test_consistency_with_cross_entropy_2d(self):
# For gamma=0 the focal loss reduces to the cross entropy loss
focal_loss = FocalLoss(gamma=0.0, reduction="mean")
ce = nn.CrossEntropyLoss(reduction="mean")
max_error = 0
class_num = 10
batch_size = 128
for _ in range(100):
# Create a random tensor of shape (batch_size, class_num, 8, 4)
x = torch.rand(batch_size, class_num, 8, 4, requires_grad=True)
# Create a random batch of classes
l = torch.randint(low=0, high=class_num, size=(batch_size, 8, 4))
l = l.long()
if torch.cuda.is_available():
x = x.cuda()
l = l.cuda()
output0 = focal_loss.forward(x, l)
output1 = ce.forward(x, l)
a = float(output0.cpu().detach())
b = float(output1.cpu().detach())
if abs(a - b) > max_error:
max_error = abs(a - b)
self.assertAlmostEqual(max_error, 0.0, places=3)

def test_consistency_with_cross_entropy_classification(self):
# for gamma=0 the focal loss reduces to the cross entropy loss
focal_loss = FocalLoss(gamma=0.0, reduction="mean")
ce = nn.CrossEntropyLoss(reduction="mean")
max_error = 0
class_num = 10
batch_size = 128
for _ in range(100):
# Create a random scores tensor of shape (batch_size, class_num)
x = torch.rand(batch_size, class_num, requires_grad=True)
# Create a random batch of classes
l = torch.randint(low=0, high=class_num, size=(batch_size,))
l = l.long()
if torch.cuda.is_available():
x = x.cuda()
l = l.cuda()
output0 = focal_loss.forward(x, l)
output1 = ce.forward(x, l)
a = float(output0.cpu().detach())
b = float(output1.cpu().detach())
if abs(a - b) > max_error:
max_error = abs(a - b)
self.assertAlmostEqual(max_error, 0.0, places=3)

def test_bin_seg_2d(self):
# define 2d examples
target = torch.tensor([[0, 0, 0, 0], [0, 1, 1, 0], [0, 1, 1, 0], [0, 0, 0, 0]])
# add another dimension corresponding to the batch (batch size = 1 here)
target = target.unsqueeze(0) # shape (1, H, W)
pred_very_good = 1000 * F.one_hot(target, num_classes=2).permute(0, 3, 1, 2).float()

        # initialize the focal loss
loss = FocalLoss()

# focal loss for pred_very_good should be close to 0
focal_loss_good = float(loss.forward(pred_very_good, target).cpu())
self.assertAlmostEqual(focal_loss_good, 0.0, places=3)

# Same test, but for target with a class dimension
target = target.unsqueeze(1) # shape (1, 1, H, W)
focal_loss_good = float(loss.forward(pred_very_good, target).cpu())
self.assertAlmostEqual(focal_loss_good, 0.0, places=3)

def test_empty_class_2d(self):
num_classes = 2
# define 2d examples
target = torch.tensor([[0, 0, 0, 0], [0, 0, 0, 0], [0, 0, 0, 0], [0, 0, 0, 0]])
# add another dimension corresponding to the batch (batch size = 1 here)
target = target.unsqueeze(0) # shape (1, H, W)
pred_very_good = 1000 * F.one_hot(target, num_classes=num_classes).permute(0, 3, 1, 2).float()

        # initialize the focal loss
loss = FocalLoss()

# focal loss for pred_very_good should be close to 0
focal_loss_good = float(loss.forward(pred_very_good, target).cpu())
self.assertAlmostEqual(focal_loss_good, 0.0, places=3)

def test_multi_class_seg_2d(self):
num_classes = 6 # labels 0 to 5
# define 2d examples
target = torch.tensor([[0, 0, 0, 0], [0, 1, 2, 0], [0, 3, 4, 0], [0, 0, 0, 0]])
# add another dimension corresponding to the batch (batch size = 1 here)
target = target.unsqueeze(0) # shape (1, H, W)
pred_very_good = 1000 * F.one_hot(target, num_classes=num_classes).permute(0, 3, 1, 2).float()

        # initialize the focal loss
loss = FocalLoss()

# focal loss for pred_very_good should be close to 0
focal_loss_good = float(loss.forward(pred_very_good, target).cpu())
self.assertAlmostEqual(focal_loss_good, 0.0, places=3)

def test_bin_seg_3d(self):
        # define a 3d example
target = torch.tensor(
[
                # slice 0
[[0, 0, 0, 0], [0, 1, 1, 0], [0, 1, 1, 0], [0, 0, 0, 0]],
                # slice 1
[[0, 0, 0, 0], [0, 1, 1, 0], [0, 1, 1, 0], [0, 0, 0, 0]],
                # slice 2
[[0, 0, 0, 0], [0, 1, 1, 0], [0, 1, 1, 0], [0, 0, 0, 0]],
]
)
# add another dimension corresponding to the batch (batch size = 1 here)
target = target.unsqueeze(0) # shape (1, H, W, D)
pred_very_good = 1000 * F.one_hot(target, num_classes=2).permute(0, 4, 1, 2, 3).float()

        # initialize the focal loss
loss = FocalLoss()

# focal loss for pred_very_good should be close to 0
focal_loss_good = float(loss.forward(pred_very_good, target).cpu())
self.assertAlmostEqual(focal_loss_good, 0.0, places=3)

def test_convergence(self):
"""
        The goal of this test is to assess whether the gradient of the loss function
        is correct by checking that we can train a one-layer neural network
        to segment one image.
        We verify that the loss decreases in almost all SGD steps.
"""
learning_rate = 0.001
max_iter = 20

# define a simple 3d example
target_seg = torch.tensor(
[
                # slice 0
[[0, 0, 0, 0], [0, 1, 1, 0], [0, 1, 1, 0], [0, 0, 0, 0]],
                # slice 1
[[0, 0, 0, 0], [0, 1, 1, 0], [0, 1, 1, 0], [0, 0, 0, 0]],
                # slice 2
[[0, 0, 0, 0], [0, 1, 1, 0], [0, 1, 1, 0], [0, 0, 0, 0]],
]
)
target_seg = torch.unsqueeze(target_seg, dim=0)
image = 12 * target_seg + 27
image = image.float()
num_classes = 2
num_voxels = 3 * 4 * 4

# define a one layer model
class OnelayerNet(nn.Module):
def __init__(self):
super(OnelayerNet, self).__init__()
self.layer = nn.Linear(num_voxels, num_voxels * num_classes)

def forward(self, x):
x = x.view(-1, num_voxels)
x = self.layer(x)
x = x.view(-1, num_classes, 3, 4, 4)
return x

        # initialize the network
net = OnelayerNet()

# initialize the loss
loss = FocalLoss()

        # initialize an SGD optimizer
optimizer = optim.SGD(net.parameters(), lr=learning_rate, momentum=0.9)

loss_history = []
# train the network
for _ in range(max_iter):
# set the gradient to zero
optimizer.zero_grad()

# forward pass
output = net(image)
loss_val = loss(output, target_seg)

# backward pass
loss_val.backward()
optimizer.step()

# stats
loss_history.append(loss_val.item())

# count the number of SGD steps in which the loss decreases
num_decreasing_steps = 0
for i in range(len(loss_history) - 1):
if loss_history[i] > loss_history[i + 1]:
num_decreasing_steps += 1
decreasing_steps_ratio = float(num_decreasing_steps) / (len(loss_history) - 1)

# verify that the loss is decreasing for sufficiently many SGD steps
self.assertTrue(decreasing_steps_ratio > 0.9)


if __name__ == "__main__":
unittest.main()
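Because of the unittest.main() entry point above, this test file can be run directly as a script from the repository root (assuming the layout shown in this diff):

    python tests/test_focal_loss.py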
