Focal loss #325
Merged
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -10,4 +10,5 @@ | |
# limitations under the License. | ||
|
||
from .dice import * | ||
from .focal_loss import * | ||
from .tversky import * |
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,98 @@ | ||
# Copyright 2020 MONAI Consortium | ||
# Licensed under the Apache License, Version 2.0 (the "License"); | ||
# you may not use this file except in compliance with the License. | ||
# You may obtain a copy of the License at | ||
# http://www.apache.org/licenses/LICENSE-2.0 | ||
# Unless required by applicable law or agreed to in writing, software | ||
# distributed under the License is distributed on an "AS IS" BASIS, | ||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
# See the License for the specific language governing permissions and | ||
# limitations under the License. | ||
|
||
import torch | ||
import torch.nn.functional as F | ||
from torch.nn.modules.loss import _WeightedLoss | ||
|
||
|
||
class FocalLoss(_WeightedLoss): | ||
""" | ||
PyTorch implementation of the Focal Loss. | ||
[1] "Focal Loss for Dense Object Detection", T. Lin et al., ICCV 2017 | ||
""" | ||
|
||
def __init__(self, gamma=2.0, weight=None, reduction="mean"): | ||
""" | ||
Args: | ||
gamma: (float) value of the exponent gamma in the definition | ||
of the Focal loss. | ||
weight: (tensor) weights to apply to the | ||
voxels of each class. If None no weights are applied. | ||
This corresponds to the weights alpha in [1]. | ||
reduction: (string) reduction operation to apply on the loss batch. | ||
It can be 'mean', 'sum' or 'none' as in the standard PyTorch API | ||
for loss functions. | ||
|
||
Example: | ||
pred = torch.tensor([[1, 0], [0, 1], [1, 0]], dtype=torch.float32) | ||
grnd = torch.tensor([0, 1, 0], dtype=torch.int64) | ||
fl = FocalLoss() | ||
fl(pred, grnd) | ||
""" | ||
super(FocalLoss, self).__init__(weight=weight, reduction=reduction) | ||
self.gamma = gamma | ||
|
||
def forward(self, input, target): | ||
""" | ||
Args: | ||
input: (tensor): the shape should be BNH[WD]. | ||
target: (tensor): the shape should be B1H[WD] or BH[WD]. | ||
""" | ||
i = input | ||
t = target | ||
|
||
if t.dim() < i.dim(): | ||
# Add a class dimension to the ground-truth segmentation. | ||
t = t.unsqueeze(1) # N,H,W => N,1,H,W | ||
|
||
# Change the shape of input and target to | ||
# num_batch x num_class x num_voxels. | ||
if i.dim() > 2: | ||
i = i.view(i.size(0), i.size(1), -1) # N,C,H,W => N,C,H*W | ||
t = t.view(t.size(0), t.size(1), -1) # N,1,H,W => N,1,H*W | ||
else: # Compatibility with classification. | ||
i = i.unsqueeze(2) # N,C => N,C,1 | ||
t = t.unsqueeze(2) # N,1 => N,1,1 | ||
|
||
# Compute the log proba (more stable numerically than softmax). | ||
logpt = F.log_softmax(i, dim=1) # N,C,H*W | ||
# Keep only log proba values of the ground truth class for each voxel. | ||
logpt = logpt.gather(1, t) # N,C,H*W => N,1,H*W | ||
logpt = torch.squeeze(logpt, dim=1) # N,1,H*W => N,H*W | ||
|
||
# Get the proba | ||
pt = torch.exp(logpt) # N,H*W | ||
|
||
if self.weight is not None: | ||
if self.weight.type() != i.data.type(): | ||
self.weight = self.weight.type_as(i.data) | ||
# Convert the weight to a map in which each voxel | ||
# has the weight associated with the ground-truth label | ||
# associated with this voxel in target. | ||
at = self.weight[None, :, None] # C => 1,C,1 | ||
at = at.expand((t.size(0), -1, t.size(2))) # 1,C,1 => N,C,H*W | ||
at = at.gather(1, t.data) # selection of the weights => N,1,H*W | ||
at = torch.squeeze(at, dim=1) # N,1,H*W => N,H*W | ||
# Multiply the log proba by their weights. | ||
logpt = logpt * at | ||
|
||
# Compute the loss mini-batch. | ||
weight = torch.pow(-pt + 1.0, self.gamma) | ||
loss = torch.mean(-weight * logpt, dim=1) # N | ||
|
||
if self.reduction == "sum": | ||
return loss.sum() | ||
elif self.reduction == "none": | ||
return loss | ||
# Default is mean reduction. | ||
else: | ||
return loss.mean() |
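The per-voxel computation in `forward` above can be checked by hand with a small dependency-free sketch. It is not the PR's implementation: the helper names `log_softmax` and `focal_loss_single` and the `alpha` argument are illustrative assumptions, and it handles a single voxel rather than a batch. It follows the same order of operations (log-softmax, pick the ground-truth class, optional class weight, then the `(1 - pt) ** gamma` factor).

```python
import math


def log_softmax(scores):
    """Numerically stable log-softmax of a list of class scores."""
    m = max(scores)
    log_sum = m + math.log(sum(math.exp(s - m) for s in scores))
    return [s - log_sum for s in scores]


def focal_loss_single(scores, label, gamma=2.0, alpha=None):
    """Focal loss for one voxel: -alpha_t * (1 - pt) ** gamma * log(pt).

    `alpha` is a hypothetical optional list of per-class weights.
    """
    logpt = log_softmax(scores)[label]  # log-probability of the true class
    pt = math.exp(logpt)                # probability of the true class
    at = 1.0 if alpha is None else alpha[label]
    return -at * (1.0 - pt) ** gamma * logpt


scores = [2.0, 0.5, -1.0]

# With gamma=0 the modulating factor is 1, so the loss is the cross entropy.
ce = -log_softmax(scores)[0]
fl0 = focal_loss_single(scores, label=0, gamma=0.0)

# With gamma=2 a well-classified voxel (label 0) contributes much less
# than a badly classified one (label 2).
easy = focal_loss_single(scores, label=0, gamma=2.0)
hard = focal_loss_single(scores, label=2, gamma=2.0)
```

Here `fl0` matches `ce` up to floating-point error, which is the property the `test_consistency_with_cross_entropy_*` tests rely on, and `easy` is smaller than `hard`.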
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,218 @@ | ||
# Copyright 2020 MONAI Consortium | ||
# Licensed under the Apache License, Version 2.0 (the "License"); | ||
# you may not use this file except in compliance with the License. | ||
# You may obtain a copy of the License at | ||
# http://www.apache.org/licenses/LICENSE-2.0 | ||
# Unless required by applicable law or agreed to in writing, software | ||
# distributed under the License is distributed on an "AS IS" BASIS, | ||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
# See the License for the specific language governing permissions and | ||
# limitations under the License. | ||
|
||
import torch.nn.functional as F | ||
import torch.nn as nn | ||
import torch.optim as optim | ||
import torch | ||
import unittest | ||
from monai.losses import FocalLoss | ||
|
||
|
||
class TestFocalLoss(unittest.TestCase): | ||
def test_consistency_with_cross_entropy_2d(self): | ||
# For gamma=0 the focal loss reduces to the cross entropy loss | ||
focal_loss = FocalLoss(gamma=0.0, reduction="mean") | ||
ce = nn.CrossEntropyLoss(reduction="mean") | ||
max_error = 0 | ||
class_num = 10 | ||
batch_size = 128 | ||
for _ in range(100): | ||
# Create a random tensor of shape (batch_size, class_num, 8, 4) | ||
x = torch.rand(batch_size, class_num, 8, 4, requires_grad=True) | ||
# Create a random batch of classes | ||
l = torch.randint(low=0, high=class_num, size=(batch_size, 8, 4)) | ||
l = l.long() | ||
if torch.cuda.is_available(): | ||
x = x.cuda() | ||
l = l.cuda() | ||
output0 = focal_loss.forward(x, l) | ||
output1 = ce.forward(x, l) | ||
a = float(output0.cpu().detach()) | ||
b = float(output1.cpu().detach()) | ||
if abs(a - b) > max_error: | ||
max_error = abs(a - b) | ||
self.assertAlmostEqual(max_error, 0.0, places=3) | ||
|
||
def test_consistency_with_cross_entropy_classification(self): | ||
# For gamma=0 the focal loss reduces to the cross entropy loss | ||
focal_loss = FocalLoss(gamma=0.0, reduction="mean") | ||
ce = nn.CrossEntropyLoss(reduction="mean") | ||
max_error = 0 | ||
class_num = 10 | ||
batch_size = 128 | ||
for _ in range(100): | ||
# Create a random scores tensor of shape (batch_size, class_num) | ||
x = torch.rand(batch_size, class_num, requires_grad=True) | ||
# Create a random batch of classes | ||
l = torch.randint(low=0, high=class_num, size=(batch_size,)) | ||
l = l.long() | ||
if torch.cuda.is_available(): | ||
x = x.cuda() | ||
l = l.cuda() | ||
output0 = focal_loss.forward(x, l) | ||
output1 = ce.forward(x, l) | ||
a = float(output0.cpu().detach()) | ||
b = float(output1.cpu().detach()) | ||
if abs(a - b) > max_error: | ||
max_error = abs(a - b) | ||
self.assertAlmostEqual(max_error, 0.0, places=3) | ||
|
||
def test_bin_seg_2d(self): | ||
# define 2d examples | ||
target = torch.tensor([[0, 0, 0, 0], [0, 1, 1, 0], [0, 1, 1, 0], [0, 0, 0, 0]]) | ||
# add another dimension corresponding to the batch (batch size = 1 here) | ||
target = target.unsqueeze(0) # shape (1, H, W) | ||
pred_very_good = 1000 * F.one_hot(target, num_classes=2).permute(0, 3, 1, 2).float() | ||
|
||
# initialize the focal loss | ||
loss = FocalLoss() | ||
|
||
# focal loss for pred_very_good should be close to 0 | ||
focal_loss_good = float(loss.forward(pred_very_good, target).cpu()) | ||
self.assertAlmostEqual(focal_loss_good, 0.0, places=3) | ||
|
||
# Same test, but for target with a class dimension | ||
target = target.unsqueeze(1) # shape (1, 1, H, W) | ||
focal_loss_good = float(loss.forward(pred_very_good, target).cpu()) | ||
self.assertAlmostEqual(focal_loss_good, 0.0, places=3) | ||
|
||
def test_empty_class_2d(self): | ||
num_classes = 2 | ||
# define 2d examples | ||
target = torch.tensor([[0, 0, 0, 0], [0, 0, 0, 0], [0, 0, 0, 0], [0, 0, 0, 0]]) | ||
# add another dimension corresponding to the batch (batch size = 1 here) | ||
target = target.unsqueeze(0) # shape (1, H, W) | ||
pred_very_good = 1000 * F.one_hot(target, num_classes=num_classes).permute(0, 3, 1, 2).float() | ||
|
||
# initialize the focal loss | ||
loss = FocalLoss() | ||
|
||
# focal loss for pred_very_good should be close to 0 | ||
focal_loss_good = float(loss.forward(pred_very_good, target).cpu()) | ||
self.assertAlmostEqual(focal_loss_good, 0.0, places=3) | ||
|
||
def test_multi_class_seg_2d(self): | ||
num_classes = 6 # labels 0 to 5 | ||
# define 2d examples | ||
target = torch.tensor([[0, 0, 0, 0], [0, 1, 2, 0], [0, 3, 4, 0], [0, 0, 0, 0]]) | ||
# add another dimension corresponding to the batch (batch size = 1 here) | ||
target = target.unsqueeze(0) # shape (1, H, W) | ||
pred_very_good = 1000 * F.one_hot(target, num_classes=num_classes).permute(0, 3, 1, 2).float() | ||
|
||
# initialize the focal loss | ||
loss = FocalLoss() | ||
|
||
# focal loss for pred_very_good should be close to 0 | ||
focal_loss_good = float(loss.forward(pred_very_good, target).cpu()) | ||
self.assertAlmostEqual(focal_loss_good, 0.0, places=3) | ||
|
||
def test_bin_seg_3d(self): | ||
# define 3d examples | ||
target = torch.tensor( | ||
[ | ||
# slice 0 | ||
[[0, 0, 0, 0], [0, 1, 1, 0], [0, 1, 1, 0], [0, 0, 0, 0]], | ||
# slice 1 | ||
[[0, 0, 0, 0], [0, 1, 1, 0], [0, 1, 1, 0], [0, 0, 0, 0]], | ||
# slice 2 | ||
[[0, 0, 0, 0], [0, 1, 1, 0], [0, 1, 1, 0], [0, 0, 0, 0]], | ||
] | ||
) | ||
# add another dimension corresponding to the batch (batch size = 1 here) | ||
target = target.unsqueeze(0) # shape (1, H, W, D) | ||
pred_very_good = 1000 * F.one_hot(target, num_classes=2).permute(0, 4, 1, 2, 3).float() | ||
|
||
# initialize the focal loss | ||
loss = FocalLoss() | ||
|
||
# focal loss for pred_very_good should be close to 0 | ||
focal_loss_good = float(loss.forward(pred_very_good, target).cpu()) | ||
self.assertAlmostEqual(focal_loss_good, 0.0, places=3) | ||
|
||
def test_convergence(self): | ||
""" | ||
The goal of this test is to assess if the gradient of the loss function | ||
is correct by testing if we can train a one layer neural network | ||
to segment one image. | ||
We verify that the loss is decreasing in almost all SGD steps. | ||
""" | ||
learning_rate = 0.001 | ||
max_iter = 20 | ||
|
||
# define a simple 3d example | ||
target_seg = torch.tensor( | ||
[ | ||
# slice 0 | ||
[[0, 0, 0, 0], [0, 1, 1, 0], [0, 1, 1, 0], [0, 0, 0, 0]], | ||
# slice 1 | ||
[[0, 0, 0, 0], [0, 1, 1, 0], [0, 1, 1, 0], [0, 0, 0, 0]], | ||
# slice 2 | ||
[[0, 0, 0, 0], [0, 1, 1, 0], [0, 1, 1, 0], [0, 0, 0, 0]], | ||
] | ||
) | ||
target_seg = torch.unsqueeze(target_seg, dim=0) | ||
image = 12 * target_seg + 27 | ||
image = image.float() | ||
num_classes = 2 | ||
num_voxels = 3 * 4 * 4 | ||
|
||
# define a one layer model | ||
class OnelayerNet(nn.Module): | ||
def __init__(self): | ||
super(OnelayerNet, self).__init__() | ||
self.layer = nn.Linear(num_voxels, num_voxels * num_classes) | ||
|
||
def forward(self, x): | ||
x = x.view(-1, num_voxels) | ||
x = self.layer(x) | ||
x = x.view(-1, num_classes, 3, 4, 4) | ||
return x | ||
|
||
# initialise the network | ||
net = OnelayerNet() | ||
|
||
# initialize the loss | ||
loss = FocalLoss() | ||
|
||
# initialize an SGD | ||
optimizer = optim.SGD(net.parameters(), lr=learning_rate, momentum=0.9) | ||
|
||
loss_history = [] | ||
# train the network | ||
for _ in range(max_iter): | ||
# set the gradient to zero | ||
optimizer.zero_grad() | ||
|
||
# forward pass | ||
output = net(image) | ||
loss_val = loss(output, target_seg) | ||
|
||
# backward pass | ||
loss_val.backward() | ||
optimizer.step() | ||
|
||
# stats | ||
loss_history.append(loss_val.item()) | ||
|
||
# count the number of SGD steps in which the loss decreases | ||
num_decreasing_steps = 0 | ||
for i in range(len(loss_history) - 1): | ||
if loss_history[i] > loss_history[i + 1]: | ||
num_decreasing_steps += 1 | ||
decreasing_steps_ratio = float(num_decreasing_steps) / (len(loss_history) - 1) | ||
|
||
# verify that the loss is decreasing for sufficiently many SGD steps | ||
self.assertTrue(decreasing_steps_ratio > 0.9) | ||
|
||
|
||
if __name__ == "__main__": | ||
unittest.main() |
I came across sigmoid and softmax focal loss implementations in the PyTorch repo (for 2D):
https://github.com/pytorch/pytorch/blob/821b5f138a987807032a2fd908fe10a5be5439d9/modules/detectron/sigmoid_focal_loss_op.cu#L26
https://github.com/pytorch/pytorch/blob/821b5f138a987807032a2fd908fe10a5be5439d9/modules/detectron/softmax_focal_loss_op.cu#L59
Shall we consider both options here?
ref clcarwin/focal_loss_pytorch#7
I think the sigmoid formulation is for binary classification.
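To make the comparison concrete, here is a rough, dependency-free sketch of the two formulations for a single prediction. The function names are hypothetical and this is not the code from the linked detectron ops or from this PR. It only illustrates that, in the two-class case, a softmax over the scores `[0, x]` gives the same class probabilities as a sigmoid on the logit `x`, so the two losses agree there.

```python
import math


def sigmoid(x):
    # Plain logistic function; not hardened against extreme logits.
    return 1.0 / (1.0 + math.exp(-x))


def sigmoid_focal_loss(logit, label, gamma=2.0):
    """Binary focal loss on a single logit; label is 0 or 1."""
    p = sigmoid(logit)
    pt = p if label == 1 else 1.0 - p  # probability of the true class
    return -((1.0 - pt) ** gamma) * math.log(pt)


def softmax_focal_loss(scores, label, gamma=2.0):
    """Focal loss on a list of class scores, using a stable log-softmax."""
    m = max(scores)
    log_sum = m + math.log(sum(math.exp(s - m) for s in scores))
    logpt = scores[label] - log_sum
    return -((1.0 - math.exp(logpt)) ** gamma) * logpt


x = 1.3
# Sigmoid on the logit x versus softmax on the two scores [0, x].
sig = [sigmoid_focal_loss(x, y) for y in (0, 1)]
soft = [softmax_focal_loss([0.0, x], y) for y in (0, 1)]
```

`sig` and `soft` agree up to floating-point error for both labels, which suggests the softmax version in this PR already covers binary segmentation when the network outputs two channels.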