varifocal_loss.py 5.7 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152
  1. # Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. # The code is based on:
  15. # https://github.com/open-mmlab/mmdetection/blob/master/mmdet/models/losses/varifocal_loss.py
  16. from __future__ import absolute_import
  17. from __future__ import division
  18. from __future__ import print_function
  19. import numpy as np
  20. import paddle
  21. import paddle.nn as nn
  22. import paddle.nn.functional as F
  23. from ppdet.core.workspace import register, serializable
  24. from ppdet.modeling import ops
  25. __all__ = ['VarifocalLoss']
  26. def varifocal_loss(pred,
  27. target,
  28. alpha=0.75,
  29. gamma=2.0,
  30. iou_weighted=True,
  31. use_sigmoid=True):
  32. """`Varifocal Loss <https://arxiv.org/abs/2008.13367>`_
  33. Args:
  34. pred (Tensor): The prediction with shape (N, C), C is the
  35. number of classes
  36. target (Tensor): The learning target of the iou-aware
  37. classification score with shape (N, C), C is the number of classes.
  38. alpha (float, optional): A balance factor for the negative part of
  39. Varifocal Loss, which is different from the alpha of Focal Loss.
  40. Defaults to 0.75.
  41. gamma (float, optional): The gamma for calculating the modulating
  42. factor. Defaults to 2.0.
  43. iou_weighted (bool, optional): Whether to weight the loss of the
  44. positive example with the iou target. Defaults to True.
  45. """
  46. # pred and target should be of the same size
  47. assert pred.shape == target.shape
  48. if use_sigmoid:
  49. pred_new = F.sigmoid(pred)
  50. else:
  51. pred_new = pred
  52. target = target.cast(pred.dtype)
  53. if iou_weighted:
  54. focal_weight = target * (target > 0.0).cast('float32') + \
  55. alpha * (pred_new - target).abs().pow(gamma) * \
  56. (target <= 0.0).cast('float32')
  57. else:
  58. focal_weight = (target > 0.0).cast('float32') + \
  59. alpha * (pred_new - target).abs().pow(gamma) * \
  60. (target <= 0.0).cast('float32')
  61. if use_sigmoid:
  62. loss = F.binary_cross_entropy_with_logits(
  63. pred, target, reduction='none') * focal_weight
  64. else:
  65. loss = F.binary_cross_entropy(
  66. pred, target, reduction='none') * focal_weight
  67. loss = loss.sum(axis=1)
  68. return loss
  69. @register
  70. @serializable
  71. class VarifocalLoss(nn.Layer):
  72. def __init__(self,
  73. use_sigmoid=True,
  74. alpha=0.75,
  75. gamma=2.0,
  76. iou_weighted=True,
  77. reduction='mean',
  78. loss_weight=1.0):
  79. """`Varifocal Loss <https://arxiv.org/abs/2008.13367>`_
  80. Args:
  81. use_sigmoid (bool, optional): Whether the prediction is
  82. used for sigmoid or softmax. Defaults to True.
  83. alpha (float, optional): A balance factor for the negative part of
  84. Varifocal Loss, which is different from the alpha of Focal
  85. Loss. Defaults to 0.75.
  86. gamma (float, optional): The gamma for calculating the modulating
  87. factor. Defaults to 2.0.
  88. iou_weighted (bool, optional): Whether to weight the loss of the
  89. positive examples with the iou target. Defaults to True.
  90. reduction (str, optional): The method used to reduce the loss into
  91. a scalar. Defaults to 'mean'. Options are "none", "mean" and
  92. "sum".
  93. loss_weight (float, optional): Weight of loss. Defaults to 1.0.
  94. """
  95. super(VarifocalLoss, self).__init__()
  96. assert alpha >= 0.0
  97. self.use_sigmoid = use_sigmoid
  98. self.alpha = alpha
  99. self.gamma = gamma
  100. self.iou_weighted = iou_weighted
  101. self.reduction = reduction
  102. self.loss_weight = loss_weight
  103. def forward(self, pred, target, weight=None, avg_factor=None):
  104. """Forward function.
  105. Args:
  106. pred (Tensor): The prediction.
  107. target (Tensor): The learning target of the prediction.
  108. weight (Tensor, optional): The weight of loss for each
  109. prediction. Defaults to None.
  110. avg_factor (int, optional): Average factor that is used to average
  111. the loss. Defaults to None.
  112. Returns:
  113. Tensor: The calculated loss
  114. """
  115. loss = self.loss_weight * varifocal_loss(
  116. pred,
  117. target,
  118. alpha=self.alpha,
  119. gamma=self.gamma,
  120. iou_weighted=self.iou_weighted,
  121. use_sigmoid=self.use_sigmoid)
  122. if weight is not None:
  123. loss = loss * weight
  124. if avg_factor is None:
  125. if self.reduction == 'none':
  126. return loss
  127. elif self.reduction == 'mean':
  128. return loss.mean()
  129. elif self.reduction == 'sum':
  130. return loss.sum()
  131. else:
  132. # if reduction is mean, then average the loss by avg_factor
  133. if self.reduction == 'mean':
  134. loss = loss.sum() / avg_factor
  135. # if reduction is 'none', then do nothing, otherwise raise an error
  136. elif self.reduction != 'none':
  137. raise ValueError(
  138. 'avg_factor can not be used with reduction="sum"')
  139. return loss