Dice Loss损失函数

Jul 6, 2022· · 1 min read

一种在分割任务中常用的损失函数。

公式如下, $$\begin{aligned} L_{Dice} &= 1 - \frac{2 \times |A \cap B|}{|A| + |B|} \quad& Standard \\ L_{Dice} &= 1 - \frac{2 \times |A \cap B| + 1}{|A| + |B| + 1} \quad& Laplace\ Smooothing \\ L_{Dice} &= 1 - \frac{2 \times |A \cap B|}{|A|^2 + |B|^2} \quad& Square \\ \end{aligned}$$

如果使用集合的角度来看,Dice Loss可以写作 $\frac{2*TP}{2*TP+FP+FN}$,也就是类似F1-Value的表示。

Dice Loss的好处是可以解决正负样本类别不平衡的现象。表现在图像分割任务中,通常前景占据的面积比较少,背景占据的面积比较大。Dice Loss在训练前期会对正样本比较敏感(较大的梯度),因而倾向于挖掘前景区域。而CrossEntropy比较公平的对待前景和背景,因此容易被负样本淹没。

一种pytorch的实现代码如下,其中pred和gt是基于分割的模型输出,即输出与原图像等比例缩放的取值0-1的分割图,因此可以直接使用乘法和求和函数计算。

class DiceLoss(nn.Module):
    def __init__(self, eps=1e-6):
        super(DiceLoss, self).__init__()
        self.eps = eps
    def forward(self, pred, gt):
        intersection = (pred * gt).sum()
        union = pred.sum() + gt.sum() + self.eps
        loss = 1 - 2.0 * intersection / union
        return loss