Focal loss#325
Conversation
There was a problem hiding this comment.
thanks for the PR @LucasFidon!
the unit tests look great, after merging this I'll populate the tests to all segmentation losses.
a few minor changes needed:
- append this doc section to
docs/source/losses.rst
`FocalLoss`
~~~~~~~~~
.. autoclass:: monai.losses.focal_loss.FocalLoss
:members:
-
in the unit test, allocate cuda variables only if
torch.cuda.is_available(), so that we can also run tests with CPU only -
fix the flake8 warnings https://github.com/Project-MONAI/MONAI/runs/629355850?check_suite_focus=true#step:5:133 (if you use IDEs such as pycharm/vscode most of them could be resolved automatically)
| t = t.unsqueeze(2) # N,1 => N,1,1 | ||
|
|
||
| # Compute the log proba (more stable numerically than softmax) | ||
| logpt = F.log_softmax(i, dim=1) # N,C,H*W |
There was a problem hiding this comment.
came across sigmoidfocal and softmaxfocal in the pytorch repo (for 2D):
https://github.com/pytorch/pytorch/blob/821b5f138a987807032a2fd908fe10a5be5439d9/modules/detectron/sigmoid_focal_loss_op.cu#L26
https://github.com/pytorch/pytorch/blob/821b5f138a987807032a2fd908fe10a5be5439d9/modules/detectron/softmax_focal_loss_op.cu#L59
shall we consider both options here?
There was a problem hiding this comment.
The sigmoid formulation is for binary classification I think.
|
/black |
Fixes #115 Port focal loss .
Description
Add an implementation of the Focal loss.
Status
Ready
Types of changes