
Focal loss #325


Merged · 4 commits merged on May 7, 2020
4 changes: 4 additions & 0 deletions docs/source/losses.rst
@@ -22,6 +22,10 @@ Segmentation Losses
.. autoclass:: GeneralizedDiceLoss
:members:

`FocalLoss`
~~~~~~~~~~~
.. autoclass:: monai.losses.focal_loss.FocalLoss
:members:

.. automodule:: monai.losses.tversky
.. currentmodule:: monai.losses.tversky
1 change: 1 addition & 0 deletions monai/losses/__init__.py
@@ -10,4 +10,5 @@
# limitations under the License.

from .dice import *
from .focal_loss import *
from .tversky import *
98 changes: 98 additions & 0 deletions monai/losses/focal_loss.py
@@ -0,0 +1,98 @@
# Copyright 2020 MONAI Consortium
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
# http://www.apache.org/licenses/LICENSE-2.0
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import torch
import torch.nn.functional as F
from torch.nn.modules.loss import _WeightedLoss


class FocalLoss(_WeightedLoss):
"""
PyTorch implementation of the Focal Loss.
[1] "Focal Loss for Dense Object Detection", T. Lin et al., ICCV 2017
"""

def __init__(self, gamma=2.0, weight=None, reduction="mean"):
"""
Args:
gamma: (float) value of the exponent gamma in the definition
of the Focal loss.
weight: (tensor) weights to apply to the
voxels of each class. If None no weights are applied.
This corresponds to the weights alpha in [1].
reduction: (string) reduction operation to apply on the loss batch.
It can be 'mean', 'sum' or 'none' as in the standard PyTorch API
for loss functions.

Example:
pred = torch.tensor([[1, 0], [0, 1], [1, 0]], dtype=torch.float32)
grnd = torch.tensor([0, 1, 0], dtype=torch.int64)
fl = FocalLoss()
fl(pred, grnd)
"""
super(FocalLoss, self).__init__(weight=weight, reduction=reduction)
self.gamma = gamma

def forward(self, input, target):
"""
Args:
input: (tensor) the logits, the shape should be BNH[WD].
target: (tensor) the class indices, the shape should be BH[WD] or B1H[WD].
"""
i = input
t = target

if t.dim() < i.dim():
# Add a class dimension to the ground-truth segmentation.
t = t.unsqueeze(1) # N,H,W => N,1,H,W

# Change the shape of input and target to
# num_batch x num_class x num_voxels.
if input.dim() > 2:
i = i.view(i.size(0), i.size(1), -1) # N,C,H,W => N,C,H*W
t = t.view(t.size(0), t.size(1), -1) # N,1,H,W => N,1,H*W
else: # Compatibility with classification.
i = i.unsqueeze(2) # N,C => N,C,1
t = t.unsqueeze(2) # N,1 => N,1,1

# Compute the log probabilities with log_softmax (numerically more stable than taking the log of the softmax).
logpt = F.log_softmax(i, dim=1) # N,C,H*W

Inline review thread (started by @wyli, May 2, 2020); a project member commented:
The sigmoid formulation is for binary classification I think.
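
For context, a minimal sketch of the sigmoid (binary) formulation mentioned above, assuming a single-channel logit map and a binary target of the same shape; the function name and signature are illustrative assumptions, not part of this PR:

import torch
import torch.nn.functional as F

def binary_focal_loss(logits, target, gamma=2.0, alpha=None, reduction="mean"):
    # binary_cross_entropy_with_logits returns -log(p_t) per element in a numerically stable way.
    ce = F.binary_cross_entropy_with_logits(logits, target.float(), reduction="none")
    pt = torch.exp(-ce)  # probability assigned to the true binary class
    loss = (1.0 - pt) ** gamma * ce
    if alpha is not None:
        # alpha weights the positive class, (1 - alpha) the negative class.
        loss = (alpha * target + (1.0 - alpha) * (1.0 - target)) * loss
    if reduction == "sum":
        return loss.sum()
    if reduction == "none":
        return loss
    return loss.mean()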

# Keep only log proba values of the ground truth class for each voxel.
logpt = logpt.gather(1, t) # N,C,H*W => N,1,H*W
logpt = torch.squeeze(logpt, dim=1) # N,1,H*W => N,H*W

# Recover the probabilities from the log probabilities.
pt = torch.exp(logpt) # N,H*W

if self.weight is not None:
if self.weight.type() != i.data.type():
self.weight = self.weight.type_as(i.data)
# Convert the class weights to a map in which each voxel
# carries the weight of its ground-truth class in target.
at = self.weight[None, :, None] # C => 1,C,1
at = at.expand((t.size(0), -1, t.size(2))) # 1,C,1 => N,C,H*W
at = at.gather(1, t.data) # selection of the weights => N,1,H*W
at = torch.squeeze(at, dim=1) # N,1,H*W => N,H*W
# Multiply the log proba by their weights.
logpt = logpt * at

# Compute the loss for each sample in the mini-batch.
weight = torch.pow(-pt + 1.0, self.gamma)
loss = torch.mean(-weight * logpt, dim=1) # N

if self.reduction == "sum":
return loss.sum()
elif self.reduction == "none":
return loss
# Default is mean reduction.
else:
return loss.mean()
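
For reference, the implementation above computes, per voxel, FL(p_t) = -alpha_t * (1 - p_t)^gamma * log(p_t) as defined in [1], before applying the chosen reduction. A minimal usage sketch (shapes and values are illustrative only):

import torch
from monai.losses import FocalLoss

# Batch of 2 samples, 3 classes, 4x4 images: raw logits and integer class labels.
logits = torch.randn(2, 3, 4, 4, requires_grad=True)
labels = torch.randint(low=0, high=3, size=(2, 4, 4))

# Optional per-class alpha weights and the focusing parameter gamma.
loss_fn = FocalLoss(gamma=2.0, weight=torch.tensor([0.2, 0.3, 0.5]))
loss = loss_fn(logits, labels)
loss.backward()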
218 changes: 218 additions & 0 deletions tests/test_focal_loss.py
@@ -0,0 +1,218 @@
# Copyright 2020 MONAI Consortium
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
# http://www.apache.org/licenses/LICENSE-2.0
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import unittest

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim

from monai.losses import FocalLoss


class TestFocalLoss(unittest.TestCase):
def test_consistency_with_cross_entropy_2d(self):
# For gamma=0 the focal loss reduces to the cross entropy loss
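# With gamma = 0 the modulating factor (1 - p_t) ** gamma is identically 1,
# so each voxel simply contributes -log(p_t), i.e. the standard cross entropy.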
focal_loss = FocalLoss(gamma=0.0, reduction="mean")
ce = nn.CrossEntropyLoss(reduction="mean")
max_error = 0
class_num = 10
batch_size = 128
for _ in range(100):
# Create a random tensor of shape (batch_size, class_num, 8, 4)
x = torch.rand(batch_size, class_num, 8, 4, requires_grad=True)
# Create a random batch of classes
l = torch.randint(low=0, high=class_num, size=(batch_size, 8, 4))
l = l.long()
if torch.cuda.is_available():
x = x.cuda()
l = l.cuda()
output0 = focal_loss.forward(x, l)
output1 = ce.forward(x, l)
a = float(output0.cpu().detach())
b = float(output1.cpu().detach())
if abs(a - b) > max_error:
max_error = abs(a - b)
self.assertAlmostEqual(max_error, 0.0, places=3)

def test_consistency_with_cross_entropy_classification(self):
# for gamma=0 the focal loss reduces to the cross entropy loss
focal_loss = FocalLoss(gamma=0.0, reduction="mean")
ce = nn.CrossEntropyLoss(reduction="mean")
max_error = 0
class_num = 10
batch_size = 128
for _ in range(100):
# Create a random scores tensor of shape (batch_size, class_num)
x = torch.rand(batch_size, class_num, requires_grad=True)
# Create a random batch of classes
l = torch.randint(low=0, high=class_num, size=(batch_size,))
l = l.long()
if torch.cuda.is_available():
x = x.cuda()
l = l.cuda()
output0 = focal_loss.forward(x, l)
output1 = ce.forward(x, l)
a = float(output0.cpu().detach())
b = float(output1.cpu().detach())
if abs(a - b) > max_error:
max_error = abs(a - b)
self.assertAlmostEqual(max_error, 0.0, places=3)

def test_bin_seg_2d(self):
# define 2d examples
target = torch.tensor([[0, 0, 0, 0], [0, 1, 1, 0], [0, 1, 1, 0], [0, 0, 0, 0]])
# add another dimension corresponding to the batch (batch size = 1 here)
target = target.unsqueeze(0) # shape (1, H, W)
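# Scaling the one-hot encoding by a large constant makes the softmax of
# pred_very_good put essentially all probability mass on the target class,
# so the focal loss below should be close to zero.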
pred_very_good = 1000 * F.one_hot(target, num_classes=2).permute(0, 3, 1, 2).float()

# initialize the focal loss
loss = FocalLoss()

# focal loss for pred_very_good should be close to 0
focal_loss_good = float(loss.forward(pred_very_good, target).cpu())
self.assertAlmostEqual(focal_loss_good, 0.0, places=3)

# Same test, but for target with a class dimension
target = target.unsqueeze(1) # shape (1, 1, H, W)
focal_loss_good = float(loss.forward(pred_very_good, target).cpu())
self.assertAlmostEqual(focal_loss_good, 0.0, places=3)

def test_empty_class_2d(self):
num_classes = 2
# define 2d examples
target = torch.tensor([[0, 0, 0, 0], [0, 0, 0, 0], [0, 0, 0, 0], [0, 0, 0, 0]])
# add another dimension corresponding to the batch (batch size = 1 here)
target = target.unsqueeze(0) # shape (1, H, W)
pred_very_good = 1000 * F.one_hot(target, num_classes=num_classes).permute(0, 3, 1, 2).float()

# initialize the focal loss
loss = FocalLoss()

# focal loss for pred_very_good should be close to 0
focal_loss_good = float(loss.forward(pred_very_good, target).cpu())
self.assertAlmostEqual(focal_loss_good, 0.0, places=3)

def test_multi_class_seg_2d(self):
num_classes = 6 # labels 0 to 5
# define 2d examples
target = torch.tensor([[0, 0, 0, 0], [0, 1, 2, 0], [0, 3, 4, 0], [0, 0, 0, 0]])
# add another dimension corresponding to the batch (batch size = 1 here)
target = target.unsqueeze(0) # shape (1, H, W)
pred_very_good = 1000 * F.one_hot(target, num_classes=num_classes).permute(0, 3, 1, 2).float()

# initialize the focal loss
loss = FocalLoss()

# focal loss for pred_very_good should be close to 0
focal_loss_good = float(loss.forward(pred_very_good, target).cpu())
self.assertAlmostEqual(focal_loss_good, 0.0, places=3)

def test_bin_seg_3d(self):
# define 3d examples
target = torch.tensor(
[
# slice 0
[[0, 0, 0, 0], [0, 1, 1, 0], [0, 1, 1, 0], [0, 0, 0, 0]],
# slice 1
[[0, 0, 0, 0], [0, 1, 1, 0], [0, 1, 1, 0], [0, 0, 0, 0]],
# slice 2
[[0, 0, 0, 0], [0, 1, 1, 0], [0, 1, 1, 0], [0, 0, 0, 0]],
]
)
# add another dimension corresponding to the batch (batch size = 1 here)
target = target.unsqueeze(0) # shape (1, H, W, D)
pred_very_good = 1000 * F.one_hot(target, num_classes=2).permute(0, 4, 1, 2, 3).float()

# initialize the focal loss
loss = FocalLoss()

# focal loss for pred_very_good should be close to 0
focal_loss_good = float(loss.forward(pred_very_good, target).cpu())
self.assertAlmostEqual(focal_loss_good, 0.0, places=3)

def test_convergence(self):
"""
The goal of this test is to assess whether the gradient of the loss function
is correct by checking that a one-layer neural network can be trained
to segment one image.
We verify that the loss decreases at almost every SGD step.
"""
learning_rate = 0.001
max_iter = 20

# define a simple 3d example
target_seg = torch.tensor(
[
# slice 0
[[0, 0, 0, 0], [0, 1, 1, 0], [0, 1, 1, 0], [0, 0, 0, 0]],
# slice 1
[[0, 0, 0, 0], [0, 1, 1, 0], [0, 1, 1, 0], [0, 0, 0, 0]],
# slice 2
[[0, 0, 0, 0], [0, 1, 1, 0], [0, 1, 1, 0], [0, 0, 0, 0]],
]
)
target_seg = torch.unsqueeze(target_seg, dim=0)
image = 12 * target_seg + 27
image = image.float()
num_classes = 2
num_voxels = 3 * 4 * 4

# define a one layer model
class OnelayerNet(nn.Module):
def __init__(self):
super(OnelayerNet, self).__init__()
self.layer = nn.Linear(num_voxels, num_voxels * num_classes)

def forward(self, x):
x = x.view(-1, num_voxels)
x = self.layer(x)
x = x.view(-1, num_classes, 3, 4, 4)
return x

# initialise the network
net = OnelayerNet()

# initialize the loss
loss = FocalLoss()

# initialize an SGD
optimizer = optim.SGD(net.parameters(), lr=learning_rate, momentum=0.9)

loss_history = []
# train the network
for _ in range(max_iter):
# set the gradient to zero
optimizer.zero_grad()

# forward pass
output = net(image)
loss_val = loss(output, target_seg)

# backward pass
loss_val.backward()
optimizer.step()

# stats
loss_history.append(loss_val.item())

# count the number of SGD steps in which the loss decreases
num_decreasing_steps = 0
for i in range(len(loss_history) - 1):
if loss_history[i] > loss_history[i + 1]:
num_decreasing_steps += 1
decreasing_steps_ratio = float(num_decreasing_steps) / (len(loss_history) - 1)

# verify that the loss is decreasing for sufficiently many SGD steps
self.assertTrue(decreasing_steps_ratio > 0.9)


if __name__ == "__main__":
unittest.main()