loss.py
# -*- coding: utf-8 -*-
# @Time : 27/12/2022 4:33 PM
# @Author : Breeze
# @Email : [email protected]
import torch
import torch.nn.functional as F
from torch import nn
from torch.nn.functional import one_hot

from util.dice_score import dice_loss
def lw_loss(pred, gt):
    """Cross-entropy between a prediction and a one-hot encoded 4-class target."""
    pred = pred.squeeze()
    gt = gt.squeeze().long()  # cast to int64 without re-wrapping an existing tensor
    # One-hot over 4 classes, classes in the last dimension; this matches
    # nn.CrossEntropyLoss only when gt squeezes to a scalar or a 1-D index vector.
    gt = one_hot(gt, 4)
    criterion = nn.CrossEntropyLoss()
    # Earlier distance-weighted and Dice variants, kept for reference:
    # for d in range(len(gt)):
    #     w = abs(d - torch.nonzero(gt))
    #     loss += w * (pred[d] - gt[d])
    # epsilon = 1e-6
    # set_inner = 2 * (pred * gt).sum()
    # set_sum = pred.sum() + gt.sum()
    # set_sum = torch.where(set_sum == 0, set_inner, set_sum)
    # dice = (set_inner + epsilon) / (set_sum + epsilon)
    # loss += (1 - dice)
    # Probabilistic (one-hot) targets require PyTorch >= 1.10.
    loss = criterion(pred.float(), gt.float())
    return loss
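
# A minimal usage sketch for lw_loss (hedged: the 4-class score vector and the
# scalar class index below are assumptions for illustration; real call sites
# may pass different shapes, and soft targets need a recent PyTorch):
#
#   pred = torch.randn(4)        # raw scores for 4 classes
#   gt = torch.tensor(2)         # ground-truth class index
#   loss = lw_loss(pred, gt)     # cross-entropy against the one-hot target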
# Define the Focal Loss function
class FocalLoss(torch.nn.Module):
    def __init__(self, alpha=1, gamma=2, reduction='mean'):
        super(FocalLoss, self).__init__()
        self.alpha = alpha  # plain scalar; it broadcasts in forward()
        self.gamma = gamma
        self.reduction = reduction

    def forward(self, logits, labels):
        logits = logits.float()  # ensure torch.float32
        labels = labels.long()   # ensure torch.int64
        # per-element cross-entropy, then down-weight well-classified examples
        ce_loss = F.cross_entropy(logits, labels, reduction='none')
        pt = torch.exp(-ce_loss)  # probability assigned to the true class
        focal_loss = self.alpha * (1 - pt) ** self.gamma * ce_loss
        if self.reduction == 'mean':
            focal_loss = torch.mean(focal_loss)
        elif self.reduction == 'sum':
            focal_loss = torch.sum(focal_loss)
        return focal_loss
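
# Hedged usage sketch for FocalLoss (the shapes are assumptions; any (N, C, ...)
# logits with matching integer labels accepted by F.cross_entropy will work):
#
#   focal = FocalLoss(alpha=1, gamma=2)
#   logits = torch.randn(8, 4)          # (N, C) raw scores
#   labels = torch.randint(0, 4, (8,))  # (N,) class indices
#   loss = focal(logits, labels)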
def unet_loss(model, masks_pred, true_masks):
    """Cross-entropy (or BCE) plus Dice loss, as in the standard U-Net recipe."""
    criterion = nn.CrossEntropyLoss() if model.n_classes > 1 else nn.BCEWithLogitsLoss()
    if true_masks.dim() == 2:
        # add the batch dimension to an unbatched (H, W) mask
        true_masks = true_masks.unsqueeze(0)
    if model.n_classes == 1:
        loss = criterion(masks_pred.squeeze(1), true_masks.float())
        loss += dice_loss(torch.sigmoid(masks_pred.squeeze(1)), true_masks.float(), multiclass=False)
    else:
        loss = criterion(masks_pred, true_masks)
        loss += dice_loss(
            F.softmax(masks_pred, dim=1).float(),
            F.one_hot(true_masks, model.n_classes).permute(0, 3, 1, 2).float(),
            multiclass=True,
        )
    return loss
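
# Hedged usage sketch for unet_loss. The SimpleNamespace stand-in for the model
# and the tensor shapes are assumptions made for illustration:
#
#   from types import SimpleNamespace
#   model = SimpleNamespace(n_classes=4)
#   masks_pred = torch.randn(2, 4, 64, 64)         # (N, C, H, W) logits
#   true_masks = torch.randint(0, 4, (2, 64, 64))  # (N, H, W) class indices
#   loss = unet_loss(model, masks_pred, true_masks)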
def unetpp_loss(model, masks_pred, true_masks):
    """Deep-supervision loss for UNet++: average the U-Net loss over all side outputs."""
    criterion = nn.CrossEntropyLoss() if model.n_classes > 1 else nn.BCEWithLogitsLoss()
    loss = 0
    for pred in masks_pred:
        if model.n_classes == 1:
            loss += criterion(pred.squeeze(1), true_masks.float())
            loss += dice_loss(torch.sigmoid(pred.squeeze(1)), true_masks.float(), multiclass=False)
        else:
            loss += criterion(pred, true_masks)
            loss += dice_loss(
                F.softmax(pred, dim=1).float(),
                F.one_hot(true_masks, model.n_classes).permute(0, 3, 1, 2).float(),
                multiclass=True,
            )
    return loss / len(masks_pred)
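
# Hedged usage sketch for unetpp_loss: UNet++ with deep supervision yields one
# prediction per decoder depth, so a list of logit tensors is passed in (the
# number of side outputs below is an assumption):
#
#   from types import SimpleNamespace
#   model = SimpleNamespace(n_classes=4)
#   true_masks = torch.randint(0, 4, (2, 64, 64))
#   side_outputs = [torch.randn(2, 4, 64, 64) for _ in range(4)]
#   loss = unetpp_loss(model, side_outputs, true_masks)  # averaged over outputs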
class TverskyLoss(nn.Module):
    def __init__(self, alpha=0.5, beta=0.5, smooth=1, weight=None, size_average=True):
        super(TverskyLoss, self).__init__()
        self.alpha = alpha    # weight on false positives
        self.beta = beta      # weight on false negatives
        self.smooth = smooth  # smoothing term to avoid division by zero
        # weight and size_average are accepted but unused

    def forward(self, inputs, targets):
        # comment out if your model already contains a sigmoid or equivalent activation layer
        inputs = torch.sigmoid(inputs)
        bn = inputs.shape[0]
        # flatten label and prediction tensors
        inputs = inputs.view(bn, -1)
        targets = targets.view(bn, -1)
        # true positives, false positives and false negatives
        TP = (inputs * targets).sum()
        FP = ((1 - targets) * inputs).sum()
        FN = (targets * (1 - inputs)).sum()
        tversky = (TP + self.smooth) / (TP + self.alpha * FP + self.beta * FN + self.smooth)
        return 1 - tversky
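
# A minimal, self-contained smoke test. The tensor shapes below are assumptions
# chosen for illustration, not values taken from the training pipeline.
if __name__ == "__main__":
    logits = torch.randn(2, 32, 32)                   # binary-segmentation logits (N, H, W)
    masks = torch.randint(0, 2, (2, 32, 32)).float()  # binary ground-truth masks
    print("focal  :", FocalLoss()(torch.randn(2, 4, 32, 32),
                                  torch.randint(0, 4, (2, 32, 32))).item())
    print("tversky:", TverskyLoss()(logits, masks).item())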