Implementations of several losses for class-imbalanced NLP tasks, such as Focal Loss, Dice Loss, DSC Loss, and GHM Loss.
Summary
This repository collects implementations of losses designed for imbalanced data:
Loss Name | Paper | Notes |
---|---|---|
Weighted CE Loss | UNet Architectures in Multiplanar Volumetric Segmentation -- Validated on Three Knee MRI Cohorts | |
Focal Loss | Focal Loss for Dense Object Detection | |
Dice Loss | V-Net: Fully Convolutional Neural Networks for Volumetric Medical Image Segmentation | |
DSC Loss | Dice Loss for Data-imbalanced NLP Tasks | |
GHM Loss | Gradient Harmonized Single-stage Detector | |
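The table cites each loss only by its paper. To make the differences more concrete, here is a minimal PyTorch sketch of three of them, written under our own assumptions rather than taken from this repository. `inverse_frequency_weights`, `weighted_ce_loss`, `soft_dice_loss`, and `ghm_ce_loss` are hypothetical helpers; inverse-frequency weighting is just one common choice (the cited Weighted CE paper's scheme may differ); and the GHM variant adapts the paper's binary GHM-C to softmax by taking the gradient norm with respect to the gold-class logit, without the momentum term some implementations use.

```python
import torch
import torch.nn.functional as F


def inverse_frequency_weights(targets, num_class):
    """Class weights proportional to 1 / class frequency (an assumed weighting scheme)."""
    counts = torch.bincount(targets, minlength=num_class).float()
    # clamp avoids division by zero for classes that never occur
    return counts.sum() / (num_class * counts.clamp(min=1))


def weighted_ce_loss(logits, targets, class_weights):
    """Cross-entropy where each example is scaled by the weight of its gold class."""
    return F.cross_entropy(logits, targets, weight=class_weights)


def soft_dice_loss(logits, targets, smooth=1.0):
    """V-Net-style soft Dice loss on one-hot targets, averaged over classes."""
    probs = logits.softmax(dim=-1)                          # (batch, num_class)
    one_hot = F.one_hot(targets, probs.size(-1)).float()    # (batch, num_class)
    intersection = (probs * one_hot).sum(dim=0)
    denominator = (probs ** 2).sum(dim=0) + (one_hot ** 2).sum(dim=0)
    dice = (2 * intersection + smooth) / (denominator + smooth)
    return 1 - dice.mean()


def ghm_ce_loss(logits, targets, bins=10):
    """GHM-C-style cross-entropy: examples in crowded gradient-norm bins are down-weighted."""
    ce = F.cross_entropy(logits, targets, reduction="none")
    with torch.no_grad():
        p_t = logits.softmax(dim=-1).gather(1, targets.unsqueeze(1)).squeeze(1)
        g = 1 - p_t                                         # |dCE/dz_t| for the gold logit
        edges = torch.linspace(0, 1, bins + 1, device=logits.device)
        edges[-1] += 1e-6                                   # so g == 1 lands in the last bin
        weights = torch.zeros_like(g)
        n, nonempty = g.numel(), 0
        for i in range(bins):
            in_bin = (g >= edges[i]) & (g < edges[i + 1])
            count = int(in_bin.sum())
            if count > 0:
                weights[in_bin] = n / count                 # inverse gradient density
                nonempty += 1
        weights = weights / max(nonempty, 1)                # weights now sum to n
    return (weights * ce).sum() / n
```

Because the GHM weights sum to the batch size, a batch whose examples all fall into a single bin recovers plain cross-entropy. In practice you would likely compute class weights once from training-set counts rather than per batch.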
How to use?
Usage examples for every loss are in `test_loss.py`.
Here is a minimal example using the focal loss:
```python
import torch
from unbalanced_loss.focal_loss import MultiFocalLoss

batch_size, num_class = 64, 10
loss_fn = MultiFocalLoss(num_class=num_class, gamma=2.0, reduction='mean')

logits = torch.rand(batch_size, num_class, requires_grad=True)  # (batch_size, num_class)
targets = torch.randint(0, num_class, size=(batch_size,))       # (batch_size,) class indices

loss = loss_fn(logits, targets)
loss.backward()
```
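The internals of `MultiFocalLoss` are not shown here. As a rough reference for what a multi-class focal loss computes, the sketch below implements `FL(p_t) = -(1 - p_t)^gamma * log(p_t)` from the Focal Loss paper, without the optional alpha class-weighting term. `focal_loss_reference` is a hypothetical helper, not part of `unbalanced_loss`, and it is not guaranteed to match `MultiFocalLoss` numerically (the repository class may, for example, apply class weights or other adjustments).

```python
import torch
import torch.nn.functional as F


def focal_loss_reference(logits, targets, gamma=2.0, reduction="mean"):
    """Multi-class focal loss -(1 - p_t)^gamma * log(p_t), with no alpha weighting."""
    log_probs = F.log_softmax(logits, dim=-1)                        # numerically stable log p
    log_p_t = log_probs.gather(1, targets.unsqueeze(1)).squeeze(1)   # log-prob of the gold class
    p_t = log_p_t.exp()
    loss = -((1 - p_t) ** gamma) * log_p_t   # easy examples (p_t close to 1) are down-weighted
    if reduction == "mean":
        return loss.mean()
    if reduction == "sum":
        return loss.sum()
    return loss                              # reduction="none": one value per example


batch_size, num_class = 64, 10
logits = torch.randn(batch_size, num_class, requires_grad=True)
targets = torch.randint(0, num_class, size=(batch_size,))

# With gamma = 0 the modulating factor is 1, so the loss reduces to plain cross-entropy.
assert torch.allclose(focal_loss_reference(logits, targets, gamma=0.0),
                      F.cross_entropy(logits, targets))

loss = focal_loss_reference(logits, targets, gamma=2.0)
loss.backward()
```

Since `(1 - p_t)^gamma` is at most 1, the focal loss never exceeds cross-entropy on the same batch; raising `gamma` shifts more of the loss onto hard, misclassified examples.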