loss.py 16 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304
  1. # YOLOv5 🚀 by Ultralytics, GPL-3.0 license
  2. """
  3. Loss functions
  4. """
  5. import torch
  6. import torch.nn as nn
  7. from utils.metrics import bbox_iou
  8. from utils.torch_utils import de_parallel
  9. # 标签平滑
  10. def smooth_BCE(eps=0.1): # https://github.com/ultralytics/yolov3/issues/238#issuecomment-598028441
  11. # return positive, negative label smoothing BCE targets
  12. return 1.0 - 0.5 * eps, 0.5 * eps
  13. class BCEBlurWithLogitsLoss(nn.Module):
  14. # BCEwithLogitLoss() with reduced missing label effects.
  15. def __init__(self, alpha=0.05):
  16. """
  17. 标签平滑操作 [1, 0] => [0.95, 0.05]
  18. :param alpha:平滑参数
  19. :type alpha:
  20. """
  21. super().__init__()
  22. self.loss_fcn = nn.BCEWithLogitsLoss(reduction='none') # must be nn.BCEWithLogitsLoss()
  23. self.alpha = alpha
  24. def forward(self, pred, true):
  25. loss = self.loss_fcn(pred, true)
  26. pred = torch.sigmoid(pred) # prob from logits
  27. dx = pred - true # reduce only missing label effects
  28. # dx = (pred - true).abs() # reduce missing label and false label effects
  29. alpha_factor = 1 - torch.exp((dx - 1) / (self.alpha + 1e-4))
  30. loss *= alpha_factor
  31. return loss.mean()
  32. class FocalLoss(nn.Module):
  33. # Wraps focal loss around existing loss_fcn(), i.e. criteria = FocalLoss(nn.BCEWithLogitsLoss(), gamma=1.5)
  34. def __init__(self, loss_fcn, gamma=1.5, alpha=0.25):
  35. super().__init__()
  36. self.loss_fcn = loss_fcn # must be nn.BCEWithLogitsLoss()
  37. self.gamma = gamma
  38. self.alpha = alpha
  39. self.reduction = loss_fcn.reduction
  40. self.loss_fcn.reduction = 'none' # required to apply FL to each element
  41. def forward(self, pred, true):
  42. loss = self.loss_fcn(pred, true)
  43. # p_t = torch.exp(-loss)
  44. # loss *= self.alpha * (1.000001 - p_t) ** self.gamma # non-zero power for gradient stability
  45. # TF implementation https://github.com/tensorflow/addons/blob/v0.7.1/tensorflow_addons/losses/focal_loss.py
  46. pred_prob = torch.sigmoid(pred) # prob from logits
  47. p_t = true * pred_prob + (1 - true) * (1 - pred_prob)
  48. alpha_factor = true * self.alpha + (1 - true) * (1 - self.alpha)
  49. modulating_factor = (1.0 - p_t) ** self.gamma
  50. loss *= alpha_factor * modulating_factor
  51. if self.reduction == 'mean':
  52. return loss.mean()
  53. elif self.reduction == 'sum':
  54. return loss.sum()
  55. else: # 'none'
  56. return loss
  57. class QFocalLoss(nn.Module):
  58. # Wraps Quality focal loss around existing loss_fcn(), i.e. criteria = FocalLoss(nn.BCEWithLogitsLoss(), gamma=1.5)
  59. def __init__(self, loss_fcn, gamma=1.5, alpha=0.25):
  60. super().__init__()
  61. self.loss_fcn = loss_fcn # must be nn.BCEWithLogitsLoss()
  62. self.gamma = gamma
  63. self.alpha = alpha
  64. self.reduction = loss_fcn.reduction
  65. self.loss_fcn.reduction = 'none' # required to apply FL to each element
  66. def forward(self, pred, true):
  67. loss = self.loss_fcn(pred, true)
  68. pred_prob = torch.sigmoid(pred) # prob from logits
  69. alpha_factor = true * self.alpha + (1 - true) * (1 - self.alpha)
  70. modulating_factor = torch.abs(true - pred_prob) ** self.gamma
  71. loss *= alpha_factor * modulating_factor
  72. if self.reduction == 'mean':
  73. return loss.mean()
  74. elif self.reduction == 'sum':
  75. return loss.sum()
  76. else: # 'none'
  77. return loss
  78. #计算损失(分类损失 + 置信度损失 + 坐标框损失)
  79. class ComputeLoss:
  80. sort_obj_iou = False
  81. # Compute losses
  82. def __init__(self, model, autobalance=False):
  83. device = next(model.parameters()).device # get model device
  84. h = model.hyp # hyperparameters
  85. # Define criteria
  86. BCEcls = nn.BCEWithLogitsLoss(pos_weight=torch.tensor([h['cls_pw']], device=device))
  87. BCEobj = nn.BCEWithLogitsLoss(pos_weight=torch.tensor([h['obj_pw']], device=device))
  88. # Class label smoothing https://arxiv.org/pdf/1902.04103.pdf eqn 3
  89. self.cp, self.cn = smooth_BCE(eps=h.get('label_smoothing', 0.0)) # positive, negative BCE targets
  90. # Focal loss
  91. g = h['fl_gamma'] # focal loss gamma 如果设置了fl_gamma参数, 就是用focal loss,默认没有使用
  92. if g > 0:
  93. BCEcls, BCEobj = FocalLoss(BCEcls, g), FocalLoss(BCEobj, g)
  94. m = de_parallel(model).model[-1] # Detect() module
  95. self.balance = {3: [4.0, 1.0, 0.4]}.get(m.nl, [4.0, 1.0, 0.25, 0.06, 0.02]) # P3-P7 设置三个特征图对应输出的损失系数
  96. self.ssi = list(m.stride).index(16) if autobalance else 0 # stride 16 index
  97. self.BCEcls, self.BCEobj, self.gr, self.hyp, self.autobalance = BCEcls, BCEobj, 1.0, h, autobalance
  98. self.na = m.na # number of anchors
  99. self.nc = m.nc # number of classes
  100. self.nl = m.nl # number of layers
  101. self.anchors = m.anchors
  102. self.device = device
  103. def __call__(self, p, targets): # predictions, targets #
  104. '''
  105. :param p: 网络输出,List[torch.tensor * 3, p[i].shape = (b, 3, h, w, nc+5)], hw分别为特征图的长宽,b为batch-size
  106. :type p:
  107. :param targets:targets.shape = (nt, 6), 6=icxywh, i=0表示第一张图片, c为类别, 然后为坐标xywh
  108. :type targets:
  109. :return:
  110. :rtype:
  111. '''
  112. #初始化各个损失
  113. lcls = torch.zeros(1, device=self.device) # class loss
  114. lbox = torch.zeros(1, device=self.device) # box loss
  115. lobj = torch.zeros(1, device=self.device) # object loss
  116. tcls, tbox, indices, anchors = self.build_targets(p, targets) # targets 获得标签分类,边框,索引,anchors
  117. # Losses 遍历每个预测输出
  118. for i, pi in enumerate(p): # layer index, layer predictions
  119. # b表示当前bbox属于batch内部的第几张图片,
  120. # a表示当前bbox和当前层的第几个anchor匹配上,
  121. # gi,gj是对应的负责预测该bbox的网格坐标
  122. b, a, gj, gi = indices[i] # image, anchor, gridy, gridx
  123. tobj = torch.zeros(pi.shape[:4], dtype=pi.dtype, device=self.device) # target obj
  124. n = b.shape[0] # number of targets
  125. if n:
  126. # pxy, pwh, _, pcls = pi[b, a, gj, gi].tensor_split((2, 4, 5), dim=1) # faster, requires torch 1.8.0
  127. pxy, pwh, _, pcls = pi[b, a, gj, gi].split((2, 2, 1, self.nc), 1) # target-subset of predictions 找到对应网格的输出,取出对应位置预测值
  128. # Regression 目标框回归
  129. pxy = pxy.sigmoid() * 2 - 0.5
  130. pwh = (pwh.sigmoid() * 2) ** 2 * anchors[i]
  131. pbox = torch.cat((pxy, pwh), 1) # predicted box
  132. iou = bbox_iou(pbox, tbox[i], CIoU=True).squeeze() # iou(prediction, target) 计算边框损失,计算的是CIOU
  133. lbox += (1.0 - iou).mean() # iou loss
  134. # Objectness 置信度损失
  135. iou = iou.detach().clamp(0).type(tobj.dtype)
  136. if self.sort_obj_iou:
  137. j = iou.argsort()
  138. b, a, gj, gi, iou = b[j], a[j], gj[j], gi[j], iou[j]
  139. if self.gr < 1:
  140. iou = (1.0 - self.gr) + self.gr * iou
  141. # 将正样本的iou赋给
  142. tobj[b, a, gj, gi] = iou # iou ratio
  143. # Classification 分类损失
  144. if self.nc > 1: # cls loss (only if multiple classes) 类别数大于1
  145. t = torch.full_like(pcls, self.cn, device=self.device) # targets
  146. t[range(n), tcls[i]] = self.cp
  147. lcls += self.BCEcls(pcls, t) # BCE 分别对每个类别计算loss
  148. # Append targets to text file
  149. # with open('targets.txt', 'a') as file:
  150. # [file.write('%11.5g ' * 4 % tuple(x) + '\n') for x in torch.cat((txy[i], twh[i]), 1)]
  151. obji = self.BCEobj(pi[..., 4], tobj)
  152. lobj += obji * self.balance[i] # obj loss
  153. if self.autobalance:
  154. self.balance[i] = self.balance[i] * 0.9999 + 0.0001 / obji.detach().item()
  155. if self.autobalance:
  156. self.balance = [x / self.balance[self.ssi] for x in self.balance]
  157. # 根据超参数设置的各个部分损失的系数获取最终的损失
  158. lbox *= self.hyp['box']
  159. lobj *= self.hyp['obj']
  160. lcls *= self.hyp['cls']
  161. bs = tobj.shape[0] # batch size
  162. return (lbox + lobj + lcls) * bs, torch.cat((lbox, lobj, lcls)).detach()
  163. '''
  164. build_targets函数用于获得在训练时计算loss函数所需要的目标框,即被认为是正样本
  165. 与yolov3/v4的不同:yolov5支持跨网格预测
  166. 对于任何一个bbox,三个输出预测特征层都可能有先验框anchors匹配
  167. 该函数输出的正样本框可能比传入的targets(GT框)数目多
  168. 具体处理过程:
  169. (1)对于任何一层计算当前bbox和当前层anchor的匹配程度,不采用iou,而是shape比例;如果anchor和bbox的宽高比差距大于4,则不认为匹配,此时忽略相应的bbox,即当作背景;
  170. (2)然后对bbox计算落在的网格所有anchors都计算loss(并不是直接和GT框比较计算loss) 注意此时落在网格不再是一个,而是附近多个,这样就增加了正样本数,可能u才能在有些bbox在三个尺度都预测的情况;
  171. 另外,yolov5也没有conf分支忽略阈值(ignore_thresh)的操作,而yolov3/v4有
  172. '''
  173. def build_targets(self, p, targets): # p: 网络输出, targets:GT框, model:模型
  174. # Build targets for compute_loss(), input targets(image,class,x,y,w,h)
  175. na, nt = self.na, targets.shape[0] # number of anchors, targets anchor数量和标签框的数量
  176. tcls, tbox, indices, anch = [], [], [], []
  177. # ai,shape = (na, nt)生成anchor索引
  178. # anchor索引,用于表示当前bbox和当前层的那个anchor匹配
  179. gain = torch.ones(7, device=self.device) # normalized to gridspace gain
  180. ai = torch.arange(na, device=self.device).float().view(na, 1).repeat(1, nt) # same as .repeat_interleave(nt)
  181. targets = torch.cat((targets.repeat(na, 1, 1), ai[..., None]), 2) # append anchor indices 先repeat targets和当前层anchor个数一样,相当于每个bbox变成了三个,然后和3个anchor单独匹配
  182. g = 0.5 # bias 设置网格中心偏移量
  183. off = torch.tensor(
  184. [
  185. [0, 0], # 当前网格
  186. [1, 0], # 右边网格
  187. [0, 1], # 下边网格
  188. [-1, 0], # 左边网格
  189. [0, -1], # j,k,l,m # 上边网格
  190. # [1, 1], [1, -1], [-1, 1], [-1, -1], # jk,jm,lk,lm
  191. ],
  192. device=self.device).float() * g # offsets 找出当前网格临近的4个网格
  193. # 对每个检测层进行处理
  194. for i in range(self.nl): # 三个尺度的预测特征图输出分支 self.nl=3
  195. anchors = self.anchors[i]# 当前分支的anchor大小
  196. gain[2:6] = torch.tensor(p[i].shape)[[3, 2, 3, 2]] # xyxy gain 当前特征层大小
  197. # Match targets to anchors
  198. t = targets * gain # shape(3,n,7) 将标签框的xywh从基于0~1映射到基于特征图;targets的xywh本省是归一化尺度,故需要变成特征图尺度
  199. #对每个输出层单独匹配;首先将targets变成anchor尺度,方便计算;
  200. # 然后将target wh shape和anchor的wh计算比例,如果比例过大,则说明匹配度不高,将该bbox过滤,在当前层认为是背景层
  201. if nt:
  202. # Matches
  203. '''
  204. 预测的wh与anchor的wh做匹配,筛选掉比值大于hyp['anchor_t']的,从而更好的回归。
  205. 作者采用新的wh回归方式
  206. 与拿来yolov3/v4为anchors[i] * exp(wh)
  207. 将标签框与anchor的备注控制在0~4之间;hyp.scratch.yaml中的超参数anchor_t=4, 用于判定anchors与标签框默契度;
  208. '''
  209. # 计算当前target的wh和anchor的wh比例值
  210. # 如果最大比例大于预设值model.hyp['anchor_t']=4,则当前target和anchor匹配度不高,不强制回归,而把target丢弃
  211. # 计算比值ratio
  212. r = t[..., 4:6] / anchors[:, None] # wh ratio 不考虑xy坐标
  213. j = torch.max(r, 1 / r).max(2)[0] < self.hyp['anchor_t'] # compare 筛选满足 1/hyp['anchor_t'] < targets_wh/anchor_wh < hyp['anchor_t']的框;
  214. # j = wh_iou(anchors, t[:, 4:6]) > model.hyp['iou_t'] # iou(3,n)=wh_iou(anchors(3,2), gwh(n,2))
  215. # 筛选过后的t.shape = (M, 7), M为筛选过后的数量
  216. t = t[j] # filter 注意过滤规则没有考虑xy, 也就是当前bbox的wh是和所有anchoe计算的
  217. # Offsets
  218. gxy = t[:, 2:4] # grid xy label的中心点坐标
  219. gxi = gain[[2, 3]] - gxy # inverse 得到中心点相对于当前特征图的坐标
  220. '''
  221. 把相对于各个网格左上角x<0.5,y<0.5和相对于右下角的x<0.5,y<0.5的框提取出来,也就是j,k,l,m;
  222. 在选取gij(标签分配给的网格)的时候对这四个部分的框都做一个偏移(减去上面的offsets),
  223. 也就是下面的gij=(gxy - offsets).long()操作;
  224. 再将这四个部分的框与原始的gxy拼接在一起,总共就是五个部分;
  225. yolov3/v4仅仅采用当前网格的anchor进行回归;yolov4也有解决网格跑偏的措施,即通过对sigmoid限制输出;
  226. yolov5中心点回归从yolov3/v4的0~1的范围变成-0.5~1.5的范围;
  227. 中心点回归的公式变为:xy.sigmoid() * 2. - 0.5 + cx (其中对原始中心点网格坐标扩展两个邻居像素)
  228. '''
  229. # 对于筛选后的bbox,计算其落在哪个网格内,同时找出邻近的网格,将这些网格都认为是负责预测该bbox的网格
  230. # 浮点数取模的数学定义:对于两个浮点数a和b,a % b = a - n * b, 其中n为不能超过a / b 的最大整数
  231. j, k = ((gxy % 1 < g) & (gxy > 1)).T
  232. l, m = ((gxi % 1 < g) & (gxi > 1)).T
  233. j = torch.stack((torch.ones_like(j), j, k, l, m))
  234. t = t.repeat((5, 1, 1))[j] # 预设offset是5
  235. offsets = (torch.zeros_like(gxy)[None] + off[:, None])[j] # 选择出最近的3个
  236. else:
  237. t = targets[0]
  238. offsets = 0
  239. # Define
  240. '''
  241. 对每个bbox找出对应的正样本anchor,其中包括b表示当前bbox属于batch内部的第几张图片,a表示当前bbox和当前层的第几个anchor匹配上,
  242. gi,gj是对应的负责预测该bbox的网格坐标,
  243. gxy是不考虑offset或者说yolov3/v4里面设定的该bbox的负责预测网格中心点坐标xy,
  244. gwh是对应的bbox wh, c是该bbox类别
  245. '''
  246. bc, gxy, gwh, a = t.chunk(4, 1) # (image, class), grid xy, grid wh, anchors 中心点回归标签和宽高回归标签
  247. a, (b, c) = a.long().view(-1), bc.long().T # anchors, image, class
  248. gij = (gxy - offsets).long() # 当前label落在哪个网格上
  249. gi, gj = gij.T # grid indices
  250. # Append
  251. indices.append((b, a, gj.clamp_(0, gain[3] - 1), gi.clamp_(0, gain[2] - 1))) # image, anchor, grid indices 添加索引,方便计算损失的时取出对应位置的输出
  252. tbox.append(torch.cat((gxy - gij, gwh), 1)) # box 坐标值
  253. anch.append(anchors[a]) # anchors 尺寸
  254. tcls.append(c) # class
  255. return tcls, tbox, indices, anch