Commit 5fcbf5a

Lucas Fidon and monai-bot authored

Focal loss (#325)

* added the Focal loss
* focal loss inherits from _WeightedLoss and improved documentation
* [MONAI] python code formatting

Co-authored-by: monai-bot <[email protected]>

1 parent 57abc3b · commit 5fcbf5a

File tree

4 files changed, +321 -0 lines changed

docs/source/losses.rst
monai/losses/__init__.py
monai/losses/focal_loss.py
tests/test_focal_loss.py

docs/source/losses.rst

Lines changed: 4 additions & 0 deletions
@@ -22,6 +22,10 @@ Segmentation Losses
 .. autoclass:: GeneralizedDiceLoss
     :members:
 
+`FocalLoss`
+~~~~~~~~~~~
+.. autoclass:: monai.losses.focal_loss.FocalLoss
+    :members:
 
 .. automodule:: monai.losses.tversky
 .. currentmodule:: monai.losses.tversky

monai/losses/__init__.py

Lines changed: 1 addition & 0 deletions
@@ -10,4 +10,5 @@
 # limitations under the License.
 
 from .dice import *
+from .focal_loss import *
 from .tversky import *
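With this re-export in place, the loss becomes available directly from the package namespace, which is how the tests below import it. A minimal sketch:

from monai.losses import FocalLoss

# instantiate with the defaults: gamma=2.0, no class weights, mean reduction
loss = FocalLoss()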

monai/losses/focal_loss.py

Lines changed: 98 additions & 0 deletions
@@ -0,0 +1,98 @@
# Copyright 2020 MONAI Consortium
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
# http://www.apache.org/licenses/LICENSE-2.0
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import torch
import torch.nn.functional as F
from torch.nn.modules.loss import _WeightedLoss


class FocalLoss(_WeightedLoss):
    """
    PyTorch implementation of the Focal Loss.
    [1] "Focal Loss for Dense Object Detection", T. Lin et al., ICCV 2017
    """

    def __init__(self, gamma=2.0, weight=None, reduction="mean"):
        r"""
        Args:
            gamma: (float) value of the exponent gamma in the definition
                of the Focal loss.
            weight: (tensor) weights to apply to the voxels of each class.
                If None, no weights are applied.
                This corresponds to the weights \alpha in [1].
            reduction: (string) reduction operation to apply on the loss batch.
                It can be 'mean', 'sum' or 'none', as in the standard PyTorch API
                for loss functions.

        Example:
            pred = torch.tensor([[1, 0], [0, 1], [1, 0]], dtype=torch.float32)
            grnd = torch.tensor([0, 1, 0], dtype=torch.int64)
            fl = FocalLoss()
            fl(pred, grnd)
        """
        super(FocalLoss, self).__init__(weight=weight, reduction=reduction)
        self.gamma = gamma

    def forward(self, input, target):
        """
        Args:
            input: (tensor) the shape should be BNH[WD].
            target: (tensor) the shape should be B1H[WD] or BH[WD].
        """
        i = input
        t = target

        if t.dim() < i.dim():
            # Add a class dimension to the ground-truth segmentation.
            t = t.unsqueeze(1)  # N,H,W => N,1,H,W

        # Change the shape of input and target to
        # num_batch x num_class x num_voxels.
        if input.dim() > 2:
            i = i.view(i.size(0), i.size(1), -1)  # N,C,H,W => N,C,H*W
            t = t.view(t.size(0), t.size(1), -1)  # N,1,H,W => N,1,H*W
        else:  # Compatibility with classification.
            i = i.unsqueeze(2)  # N,C => N,C,1
            t = t.unsqueeze(2)  # N,1 => N,1,1

        # Compute the log probabilities
        # (more stable numerically than computing softmax, then log).
        logpt = F.log_softmax(i, dim=1)  # N,C,H*W
        # Keep only the log probability of the ground-truth class of each voxel.
        logpt = logpt.gather(1, t)  # N,C,H*W => N,1,H*W
        logpt = torch.squeeze(logpt, dim=1)  # N,1,H*W => N,H*W

        # Get the probabilities.
        pt = torch.exp(logpt)  # N,H*W

        if self.weight is not None:
            if self.weight.type() != i.data.type():
                self.weight = self.weight.type_as(i.data)
            # Convert the weights to a map in which each voxel
            # carries the weight of its ground-truth label in target.
            at = self.weight[None, :, None]  # C => 1,C,1
            at = at.expand((t.size(0), -1, t.size(2)))  # 1,C,1 => N,C,H*W
            at = at.gather(1, t.data)  # selection of the weights => N,1,H*W
            at = torch.squeeze(at, dim=1)  # N,1,H*W => N,H*W
            # Multiply the log probabilities by their class weights.
            logpt = logpt * at

        # Compute the loss for the mini-batch:
        # FL(pt) = -(1 - pt)^gamma * log(pt), averaged over the voxels.
        weight = torch.pow(-pt + 1.0, self.gamma)
        loss = torch.mean(-weight * logpt, dim=1)  # N

        if self.reduction == "sum":
            return loss.sum()
        elif self.reduction == "none":
            return loss
        # Default is mean reduction.
        else:
            return loss.mean()
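For gamma=0 the modulating factor (1 - pt)^gamma equals 1, so the focal loss reduces to the standard cross entropy; the tests below rely on exactly this property. A minimal usage sketch along those lines (shapes chosen arbitrarily for illustration):

import torch
import torch.nn as nn

from monai.losses import FocalLoss

# random logits for a batch of 16 images, 10 classes, spatial size 8x4
logits = torch.rand(16, 10, 8, 4)
labels = torch.randint(low=0, high=10, size=(16, 8, 4))

# with gamma=0 the focal loss coincides with the cross entropy loss
fl = FocalLoss(gamma=0.0, reduction="mean")
ce = nn.CrossEntropyLoss(reduction="mean")
print(float(fl(logits, labels)), float(ce(logits, labels)))  # nearly equal

# per-class weights (the alpha terms in [1]) are passed as a 1-D tensor of size C
fl_weighted = FocalLoss(gamma=2.0, weight=torch.ones(10))
print(float(fl_weighted(logits, labels)))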

tests/test_focal_loss.py

Lines changed: 218 additions & 0 deletions
@@ -0,0 +1,218 @@
# Copyright 2020 MONAI Consortium
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
# http://www.apache.org/licenses/LICENSE-2.0
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import unittest

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim

from monai.losses import FocalLoss


class TestFocalLoss(unittest.TestCase):
    def test_consistency_with_cross_entropy_2d(self):
        # For gamma=0 the focal loss reduces to the cross entropy loss
        focal_loss = FocalLoss(gamma=0.0, reduction="mean")
        ce = nn.CrossEntropyLoss(reduction="mean")
        max_error = 0
        class_num = 10
        batch_size = 128
        for _ in range(100):
            # Create a random tensor of shape (batch_size, class_num, 8, 4)
            x = torch.rand(batch_size, class_num, 8, 4, requires_grad=True)
            # Create a random batch of classes
            l = torch.randint(low=0, high=class_num, size=(batch_size, 8, 4))
            l = l.long()
            if torch.cuda.is_available():
                x = x.cuda()
                l = l.cuda()
            output0 = focal_loss.forward(x, l)
            output1 = ce.forward(x, l)
            a = float(output0.cpu().detach())
            b = float(output1.cpu().detach())
            if abs(a - b) > max_error:
                max_error = abs(a - b)
        self.assertAlmostEqual(max_error, 0.0, places=3)

    def test_consistency_with_cross_entropy_classification(self):
        # For gamma=0 the focal loss reduces to the cross entropy loss
        focal_loss = FocalLoss(gamma=0.0, reduction="mean")
        ce = nn.CrossEntropyLoss(reduction="mean")
        max_error = 0
        class_num = 10
        batch_size = 128
        for _ in range(100):
            # Create a random scores tensor of shape (batch_size, class_num)
            x = torch.rand(batch_size, class_num, requires_grad=True)
            # Create a random batch of classes
            l = torch.randint(low=0, high=class_num, size=(batch_size,))
            l = l.long()
            if torch.cuda.is_available():
                x = x.cuda()
                l = l.cuda()
            output0 = focal_loss.forward(x, l)
            output1 = ce.forward(x, l)
            a = float(output0.cpu().detach())
            b = float(output1.cpu().detach())
            if abs(a - b) > max_error:
                max_error = abs(a - b)
        self.assertAlmostEqual(max_error, 0.0, places=3)

    def test_bin_seg_2d(self):
        # define 2d examples
        target = torch.tensor([[0, 0, 0, 0], [0, 1, 1, 0], [0, 1, 1, 0], [0, 0, 0, 0]])
        # add another dimension corresponding to the batch (batch size = 1 here)
        target = target.unsqueeze(0)  # shape (1, H, W)
        pred_very_good = 1000 * F.one_hot(target, num_classes=2).permute(0, 3, 1, 2).float()

        # initialize the focal loss
        loss = FocalLoss()

        # focal loss for pred_very_good should be close to 0
        focal_loss_good = float(loss.forward(pred_very_good, target).cpu())
        self.assertAlmostEqual(focal_loss_good, 0.0, places=3)

        # Same test, but for target with a class dimension
        target = target.unsqueeze(1)  # shape (1, 1, H, W)
        focal_loss_good = float(loss.forward(pred_very_good, target).cpu())
        self.assertAlmostEqual(focal_loss_good, 0.0, places=3)

    def test_empty_class_2d(self):
        num_classes = 2
        # define 2d examples
        target = torch.tensor([[0, 0, 0, 0], [0, 0, 0, 0], [0, 0, 0, 0], [0, 0, 0, 0]])
        # add another dimension corresponding to the batch (batch size = 1 here)
        target = target.unsqueeze(0)  # shape (1, H, W)
        pred_very_good = 1000 * F.one_hot(target, num_classes=num_classes).permute(0, 3, 1, 2).float()

        # initialize the focal loss
        loss = FocalLoss()

        # focal loss for pred_very_good should be close to 0
        focal_loss_good = float(loss.forward(pred_very_good, target).cpu())
        self.assertAlmostEqual(focal_loss_good, 0.0, places=3)

    def test_multi_class_seg_2d(self):
        num_classes = 6  # labels 0 to 5
        # define 2d examples
        target = torch.tensor([[0, 0, 0, 0], [0, 1, 2, 0], [0, 3, 4, 0], [0, 0, 0, 0]])
        # add another dimension corresponding to the batch (batch size = 1 here)
        target = target.unsqueeze(0)  # shape (1, H, W)
        pred_very_good = 1000 * F.one_hot(target, num_classes=num_classes).permute(0, 3, 1, 2).float()

        # initialize the focal loss
        loss = FocalLoss()

        # focal loss for pred_very_good should be close to 0
        focal_loss_good = float(loss.forward(pred_very_good, target).cpu())
        self.assertAlmostEqual(focal_loss_good, 0.0, places=3)

    def test_bin_seg_3d(self):
        # define 3d examples
        target = torch.tensor(
            [
                # slice 0
                [[0, 0, 0, 0], [0, 1, 1, 0], [0, 1, 1, 0], [0, 0, 0, 0]],
                # slice 1
                [[0, 0, 0, 0], [0, 1, 1, 0], [0, 1, 1, 0], [0, 0, 0, 0]],
                # slice 2
                [[0, 0, 0, 0], [0, 1, 1, 0], [0, 1, 1, 0], [0, 0, 0, 0]],
            ]
        )
        # add another dimension corresponding to the batch (batch size = 1 here)
        target = target.unsqueeze(0)  # shape (1, H, W, D)
        pred_very_good = 1000 * F.one_hot(target, num_classes=2).permute(0, 4, 1, 2, 3).float()

        # initialize the focal loss
        loss = FocalLoss()

        # focal loss for pred_very_good should be close to 0
        focal_loss_good = float(loss.forward(pred_very_good, target).cpu())
        self.assertAlmostEqual(focal_loss_good, 0.0, places=3)

    def test_convergence(self):
        """
        The goal of this test is to assess if the gradient of the loss function
        is correct by testing if we can train a one-layer neural network
        to segment one image.
        We verify that the loss is decreasing in almost all SGD steps.
        """
        learning_rate = 0.001
        max_iter = 20

        # define a simple 3d example
        target_seg = torch.tensor(
            [
                # slice 0
                [[0, 0, 0, 0], [0, 1, 1, 0], [0, 1, 1, 0], [0, 0, 0, 0]],
                # slice 1
                [[0, 0, 0, 0], [0, 1, 1, 0], [0, 1, 1, 0], [0, 0, 0, 0]],
                # slice 2
                [[0, 0, 0, 0], [0, 1, 1, 0], [0, 1, 1, 0], [0, 0, 0, 0]],
            ]
        )
        target_seg = torch.unsqueeze(target_seg, dim=0)
        image = 12 * target_seg + 27
        image = image.float()
        num_classes = 2
        num_voxels = 3 * 4 * 4

        # define a one-layer model
        class OnelayerNet(nn.Module):
            def __init__(self):
                super(OnelayerNet, self).__init__()
                self.layer = nn.Linear(num_voxels, num_voxels * num_classes)

            def forward(self, x):
                x = x.view(-1, num_voxels)
                x = self.layer(x)
                x = x.view(-1, num_classes, 3, 4, 4)
                return x

        # initialise the network
        net = OnelayerNet()

        # initialize the loss
        loss = FocalLoss()

        # initialize an SGD optimizer
        optimizer = optim.SGD(net.parameters(), lr=learning_rate, momentum=0.9)

        loss_history = []
        # train the network
        for _ in range(max_iter):
            # set the gradient to zero
            optimizer.zero_grad()

            # forward pass
            output = net(image)
            loss_val = loss(output, target_seg)

            # backward pass
            loss_val.backward()
            optimizer.step()

            # stats
            loss_history.append(loss_val.item())

        # count the number of SGD steps in which the loss decreases
        num_decreasing_steps = 0
        for i in range(len(loss_history) - 1):
            if loss_history[i] > loss_history[i + 1]:
                num_decreasing_steps += 1
        decreasing_steps_ratio = float(num_decreasing_steps) / (len(loss_history) - 1)

        # verify that the loss is decreasing for sufficiently many SGD steps
        self.assertTrue(decreasing_steps_ratio > 0.9)


if __name__ == "__main__":
    unittest.main()
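Besides running the file directly, the suite can be invoked programmatically; a minimal sketch using the standard unittest discovery API, assuming it is run from the repository root:

import unittest

# discover and run the focal loss tests
suite = unittest.defaultTestLoader.discover("tests", pattern="test_focal_loss.py")
unittest.TextTestRunner(verbosity=2).run(suite)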
