sparsercnn_loss.py 16 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425
  1. # Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. """
  15. This code is based on https://github.com/PeizeSun/SparseR-CNN/blob/main/projects/SparseRCNN/sparsercnn/loss.py
  16. The copyright of PeizeSun/SparseR-CNN is as follows:
  17. MIT License [see LICENSE for details]
  18. """
  19. from __future__ import absolute_import
  20. from __future__ import division
  21. from __future__ import print_function
  22. from scipy.optimize import linear_sum_assignment
  23. import paddle
  24. import paddle.nn as nn
  25. import paddle.nn.functional as F
  26. from paddle.metric import accuracy
  27. from ppdet.core.workspace import register
  28. from ppdet.modeling.losses.iou_loss import GIoULoss
  29. __all__ = ["SparseRCNNLoss"]
  30. @register
  31. class SparseRCNNLoss(nn.Layer):
  32. """ This class computes the loss for SparseRCNN.
  33. The process happens in two steps:
  34. 1) we compute hungarian assignment between ground truth boxes and the outputs of the model
  35. 2) we supervise each pair of matched ground-truth / prediction (supervise class and box)
  36. """
  37. __shared__ = ['num_classes']
  38. def __init__(self,
  39. losses,
  40. focal_loss_alpha,
  41. focal_loss_gamma,
  42. num_classes=80,
  43. class_weight=2.,
  44. l1_weight=5.,
  45. giou_weight=2.):
  46. """ Create the criterion.
  47. Parameters:
  48. num_classes: number of object categories, omitting the special no-object category
  49. weight_dict: dict containing as key the names of the losses and as values their relative weight.
  50. losses: list of all the losses to be applied. See get_loss for list of available losses.
  51. matcher: module able to compute a matching between targets and proposals
  52. """
  53. super().__init__()
  54. self.num_classes = num_classes
  55. weight_dict = {
  56. "loss_ce": class_weight,
  57. "loss_bbox": l1_weight,
  58. "loss_giou": giou_weight
  59. }
  60. self.weight_dict = weight_dict
  61. self.losses = losses
  62. self.giou_loss = GIoULoss(reduction="sum")
  63. self.focal_loss_alpha = focal_loss_alpha
  64. self.focal_loss_gamma = focal_loss_gamma
  65. self.matcher = HungarianMatcher(focal_loss_alpha, focal_loss_gamma,
  66. class_weight, l1_weight, giou_weight)
  67. def loss_labels(self, outputs, targets, indices, num_boxes, log=True):
  68. """Classification loss (NLL)
  69. targets dicts must contain the key "labels" containing a tensor of dim [nb_target_boxes]
  70. """
  71. assert 'pred_logits' in outputs
  72. src_logits = outputs['pred_logits']
  73. idx = self._get_src_permutation_idx(indices)
  74. target_classes_o = paddle.concat([
  75. paddle.gather(
  76. t["labels"], J, axis=0) for t, (_, J) in zip(targets, indices)
  77. ])
  78. target_classes = paddle.full(
  79. src_logits.shape[:2], self.num_classes, dtype="int32")
  80. for i, ind in enumerate(zip(idx[0], idx[1])):
  81. target_classes[int(ind[0]), int(ind[1])] = target_classes_o[i]
  82. target_classes.stop_gradient = True
  83. src_logits = src_logits.flatten(start_axis=0, stop_axis=1)
  84. # prepare one_hot target.
  85. target_classes = target_classes.flatten(start_axis=0, stop_axis=1)
  86. class_ids = paddle.arange(0, self.num_classes)
  87. labels = (target_classes.unsqueeze(-1) == class_ids).astype("float32")
  88. labels.stop_gradient = True
  89. # comp focal loss.
  90. class_loss = sigmoid_focal_loss(
  91. src_logits,
  92. labels,
  93. alpha=self.focal_loss_alpha,
  94. gamma=self.focal_loss_gamma,
  95. reduction="sum", ) / num_boxes
  96. losses = {'loss_ce': class_loss}
  97. if log:
  98. label_acc = target_classes_o.unsqueeze(-1)
  99. src_idx = [src for (src, _) in indices]
  100. pred_list = []
  101. for i in range(outputs["pred_logits"].shape[0]):
  102. pred_list.append(
  103. paddle.gather(
  104. outputs["pred_logits"][i], src_idx[i], axis=0))
  105. pred = F.sigmoid(paddle.concat(pred_list, axis=0))
  106. acc = accuracy(pred, label_acc.astype("int64"))
  107. losses["acc"] = acc
  108. return losses
  109. def loss_boxes(self, outputs, targets, indices, num_boxes):
  110. """Compute the losses related to the bounding boxes, the L1 regression loss and the GIoU loss
  111. targets dicts must contain the key "boxes" containing a tensor of dim [nb_target_boxes, 4]
  112. The target boxes are expected in format (center_x, center_y, w, h), normalized by the image size.
  113. """
  114. assert 'pred_boxes' in outputs # [batch_size, num_proposals, 4]
  115. src_idx = [src for (src, _) in indices]
  116. src_boxes_list = []
  117. for i in range(outputs["pred_boxes"].shape[0]):
  118. src_boxes_list.append(
  119. paddle.gather(
  120. outputs["pred_boxes"][i], src_idx[i], axis=0))
  121. src_boxes = paddle.concat(src_boxes_list, axis=0)
  122. target_boxes = paddle.concat(
  123. [
  124. paddle.gather(
  125. t['boxes'], I, axis=0)
  126. for t, (_, I) in zip(targets, indices)
  127. ],
  128. axis=0)
  129. target_boxes.stop_gradient = True
  130. losses = {}
  131. losses['loss_giou'] = self.giou_loss(src_boxes,
  132. target_boxes) / num_boxes
  133. image_size = paddle.concat([v["img_whwh_tgt"] for v in targets])
  134. src_boxes_ = src_boxes / image_size
  135. target_boxes_ = target_boxes / image_size
  136. loss_bbox = F.l1_loss(src_boxes_, target_boxes_, reduction='sum')
  137. losses['loss_bbox'] = loss_bbox / num_boxes
  138. return losses
  139. def _get_src_permutation_idx(self, indices):
  140. # permute predictions following indices
  141. batch_idx = paddle.concat(
  142. [paddle.full_like(src, i) for i, (src, _) in enumerate(indices)])
  143. src_idx = paddle.concat([src for (src, _) in indices])
  144. return batch_idx, src_idx
  145. def _get_tgt_permutation_idx(self, indices):
  146. # permute targets following indices
  147. batch_idx = paddle.concat(
  148. [paddle.full_like(tgt, i) for i, (_, tgt) in enumerate(indices)])
  149. tgt_idx = paddle.concat([tgt for (_, tgt) in indices])
  150. return batch_idx, tgt_idx
  151. def get_loss(self, loss, outputs, targets, indices, num_boxes, **kwargs):
  152. loss_map = {
  153. 'labels': self.loss_labels,
  154. 'boxes': self.loss_boxes,
  155. }
  156. assert loss in loss_map, f'do you really want to compute {loss} loss?'
  157. return loss_map[loss](outputs, targets, indices, num_boxes, **kwargs)
  158. def forward(self, outputs, targets):
  159. """ This performs the loss computation.
  160. Parameters:
  161. outputs: dict of tensors, see the output specification of the model for the format
  162. targets: list of dicts, such that len(targets) == batch_size.
  163. The expected keys in each dict depends on the losses applied, see each loss' doc
  164. """
  165. outputs_without_aux = {
  166. k: v
  167. for k, v in outputs.items() if k != 'aux_outputs'
  168. }
  169. # Retrieve the matching between the outputs of the last layer and the targets
  170. indices = self.matcher(outputs_without_aux, targets)
  171. # Compute the average number of target boxes across all nodes, for normalization purposes
  172. num_boxes = sum(len(t["labels"]) for t in targets)
  173. num_boxes = paddle.to_tensor(
  174. [num_boxes],
  175. dtype="float32",
  176. place=next(iter(outputs.values())).place)
  177. # Compute all the requested losses
  178. losses = {}
  179. for loss in self.losses:
  180. losses.update(
  181. self.get_loss(loss, outputs, targets, indices, num_boxes))
  182. # In case of auxiliary losses, we repeat this process with the output of each intermediate layer.
  183. if 'aux_outputs' in outputs:
  184. for i, aux_outputs in enumerate(outputs['aux_outputs']):
  185. indices = self.matcher(aux_outputs, targets)
  186. for loss in self.losses:
  187. kwargs = {}
  188. if loss == 'labels':
  189. # Logging is enabled only for the last layer
  190. kwargs = {'log': False}
  191. l_dict = self.get_loss(loss, aux_outputs, targets, indices,
  192. num_boxes, **kwargs)
  193. w_dict = {}
  194. for k in l_dict.keys():
  195. if k in self.weight_dict:
  196. w_dict[k + f'_{i}'] = l_dict[k] * self.weight_dict[
  197. k]
  198. else:
  199. w_dict[k + f'_{i}'] = l_dict[k]
  200. losses.update(w_dict)
  201. return losses
  202. class HungarianMatcher(nn.Layer):
  203. """This class computes an assignment between the targets and the predictions of the network
  204. For efficiency reasons, the targets don't include the no_object. Because of this, in general,
  205. there are more predictions than targets. In this case, we do a 1-to-1 matching of the best predictions,
  206. while the others are un-matched (and thus treated as non-objects).
  207. """
  208. def __init__(self,
  209. focal_loss_alpha,
  210. focal_loss_gamma,
  211. cost_class: float=1,
  212. cost_bbox: float=1,
  213. cost_giou: float=1):
  214. """Creates the matcher
  215. Params:
  216. cost_class: This is the relative weight of the classification error in the matching cost
  217. cost_bbox: This is the relative weight of the L1 error of the bounding box coordinates in the matching cost
  218. cost_giou: This is the relative weight of the giou loss of the bounding box in the matching cost
  219. """
  220. super().__init__()
  221. self.cost_class = cost_class
  222. self.cost_bbox = cost_bbox
  223. self.cost_giou = cost_giou
  224. self.focal_loss_alpha = focal_loss_alpha
  225. self.focal_loss_gamma = focal_loss_gamma
  226. assert cost_class != 0 or cost_bbox != 0 or cost_giou != 0, "all costs cant be 0"
  227. @paddle.no_grad()
  228. def forward(self, outputs, targets):
  229. """ Performs the matching
  230. Args:
  231. outputs: This is a dict that contains at least these entries:
  232. "pred_logits": Tensor of dim [batch_size, num_queries, num_classes] with the classification logits
  233. "pred_boxes": Tensor of dim [batch_size, num_queries, 4] with the predicted box coordinates
  234. eg. outputs = {"pred_logits": pred_logits, "pred_boxes": pred_boxes}
  235. targets: This is a list of targets (len(targets) = batch_size), where each target is a dict containing:
  236. "labels": Tensor of dim [num_target_boxes] (where num_target_boxes is the number of ground-truth
  237. objects in the target) containing the class labels
  238. "boxes": Tensor of dim [num_target_boxes, 4] containing the target box coordinates
  239. eg. targets = [{"labels":labels, "boxes": boxes}, ...,{"labels":labels, "boxes": boxes}]
  240. Returns:
  241. A list of size batch_size, containing tuples of (index_i, index_j) where:
  242. - index_i is the indices of the selected predictions (in order)
  243. - index_j is the indices of the corresponding selected targets (in order)
  244. For each batch element, it holds:
  245. len(index_i) = len(index_j) = min(num_queries, num_target_boxes)
  246. """
  247. bs, num_queries = outputs["pred_logits"].shape[:2]
  248. # We flatten to compute the cost matrices in a batch
  249. out_prob = F.sigmoid(outputs["pred_logits"].flatten(
  250. start_axis=0, stop_axis=1))
  251. out_bbox = outputs["pred_boxes"].flatten(start_axis=0, stop_axis=1)
  252. # Also concat the target labels and boxes
  253. tgt_ids = paddle.concat([v["labels"] for v in targets])
  254. assert (tgt_ids > -1).all()
  255. tgt_bbox = paddle.concat([v["boxes"] for v in targets])
  256. # Compute the classification cost. Contrary to the loss, we don't use the NLL,
  257. # but approximate it in 1 - proba[target class].
  258. # The 1 is a constant that doesn't change the matching, it can be ommitted.
  259. # Compute the classification cost.
  260. alpha = self.focal_loss_alpha
  261. gamma = self.focal_loss_gamma
  262. neg_cost_class = (1 - alpha) * (out_prob**gamma) * (-(
  263. 1 - out_prob + 1e-8).log())
  264. pos_cost_class = alpha * ((1 - out_prob)
  265. **gamma) * (-(out_prob + 1e-8).log())
  266. cost_class = paddle.gather(
  267. pos_cost_class, tgt_ids, axis=1) - paddle.gather(
  268. neg_cost_class, tgt_ids, axis=1)
  269. # Compute the L1 cost between boxes
  270. image_size_out = paddle.concat(
  271. [v["img_whwh"].unsqueeze(0) for v in targets])
  272. image_size_out = image_size_out.unsqueeze(1).tile(
  273. [1, num_queries, 1]).flatten(
  274. start_axis=0, stop_axis=1)
  275. image_size_tgt = paddle.concat([v["img_whwh_tgt"] for v in targets])
  276. out_bbox_ = out_bbox / image_size_out
  277. tgt_bbox_ = tgt_bbox / image_size_tgt
  278. cost_bbox = F.l1_loss(
  279. out_bbox_.unsqueeze(-2), tgt_bbox_,
  280. reduction='none').sum(-1) # [batch_size * num_queries, num_tgts]
  281. # Compute the giou cost betwen boxes
  282. cost_giou = -get_bboxes_giou(out_bbox, tgt_bbox)
  283. # Final cost matrix
  284. C = self.cost_bbox * cost_bbox + self.cost_class * cost_class + self.cost_giou * cost_giou
  285. C = C.reshape([bs, num_queries, -1])
  286. sizes = [len(v["boxes"]) for v in targets]
  287. indices = [
  288. linear_sum_assignment(c[i].numpy())
  289. for i, c in enumerate(C.split(sizes, -1))
  290. ]
  291. return [(paddle.to_tensor(
  292. i, dtype="int32"), paddle.to_tensor(
  293. j, dtype="int32")) for i, j in indices]
  294. def box_area(boxes):
  295. assert (boxes[:, 2:] >= boxes[:, :2]).all()
  296. wh = boxes[:, 2:] - boxes[:, :2]
  297. return wh[:, 0] * wh[:, 1]
  298. def boxes_iou(boxes1, boxes2):
  299. '''
  300. Compute iou
  301. Args:
  302. boxes1 (paddle.tensor) shape (N, 4)
  303. boxes2 (paddle.tensor) shape (M, 4)
  304. Return:
  305. (paddle.tensor) shape (N, M)
  306. '''
  307. area1 = box_area(boxes1)
  308. area2 = box_area(boxes2)
  309. lt = paddle.maximum(boxes1.unsqueeze(-2)[:, :, :2], boxes2[:, :2])
  310. rb = paddle.minimum(boxes1.unsqueeze(-2)[:, :, 2:], boxes2[:, 2:])
  311. wh = (rb - lt).astype("float32").clip(min=1e-9)
  312. inter = wh[:, :, 0] * wh[:, :, 1]
  313. union = area1.unsqueeze(-1) + area2 - inter + 1e-9
  314. iou = inter / union
  315. return iou, union
  316. def get_bboxes_giou(boxes1, boxes2, eps=1e-9):
  317. """calculate the ious of boxes1 and boxes2
  318. Args:
  319. boxes1 (Tensor): shape [N, 4]
  320. boxes2 (Tensor): shape [M, 4]
  321. eps (float): epsilon to avoid divide by zero
  322. Return:
  323. ious (Tensor): ious of boxes1 and boxes2, with the shape [N, M]
  324. """
  325. assert (boxes1[:, 2:] >= boxes1[:, :2]).all()
  326. assert (boxes2[:, 2:] >= boxes2[:, :2]).all()
  327. iou, union = boxes_iou(boxes1, boxes2)
  328. lt = paddle.minimum(boxes1.unsqueeze(-2)[:, :, :2], boxes2[:, :2])
  329. rb = paddle.maximum(boxes1.unsqueeze(-2)[:, :, 2:], boxes2[:, 2:])
  330. wh = (rb - lt).astype("float32").clip(min=eps)
  331. enclose_area = wh[:, :, 0] * wh[:, :, 1]
  332. giou = iou - (enclose_area - union) / enclose_area
  333. return giou
  334. def sigmoid_focal_loss(inputs, targets, alpha, gamma, reduction="sum"):
  335. assert reduction in ["sum", "mean"
  336. ], f'do not support this {reduction} reduction?'
  337. p = F.sigmoid(inputs)
  338. ce_loss = F.binary_cross_entropy_with_logits(
  339. inputs, targets, reduction="none")
  340. p_t = p * targets + (1 - p) * (1 - targets)
  341. loss = ce_loss * ((1 - p_t)**gamma)
  342. if alpha >= 0:
  343. alpha_t = alpha * targets + (1 - alpha) * (1 - targets)
  344. loss = alpha_t * loss
  345. if reduction == "mean":
  346. loss = loss.mean()
  347. elif reduction == "sum":
  348. loss = loss.sum()
  349. return loss