retina_head.py

# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import math

import paddle
import paddle.nn as nn
import paddle.nn.functional as F
from paddle import ParamAttr
from paddle.nn.initializer import Normal, Constant

from ppdet.modeling.bbox_utils import bbox2delta, delta2bbox
from ppdet.modeling.heads.fcos_head import FCOSFeat
from ppdet.core.workspace import register

__all__ = ['RetinaHead']


@register
class RetinaFeat(FCOSFeat):
    """We use FCOSFeat to construct conv layers in RetinaNet.
    We rename FCOSFeat to RetinaFeat to avoid confusion.
    """
    pass


@register
class RetinaHead(nn.Layer):
    """Detection head used in RetinaNet (https://arxiv.org/pdf/1708.02002.pdf).
    """
    __shared__ = ['num_classes']
    __inject__ = [
        'conv_feat', 'anchor_generator', 'bbox_assigner', 'loss_class',
        'loss_bbox', 'nms'
    ]

    def __init__(self,
                 num_classes=80,
                 conv_feat='RetinaFeat',
                 anchor_generator='RetinaAnchorGenerator',
                 bbox_assigner='MaxIoUAssigner',
                 loss_class='FocalLoss',
                 loss_bbox='SmoothL1Loss',
                 nms='MultiClassNMS',
                 prior_prob=0.01,
                 nms_pre=1000,
                 weights=[1., 1., 1., 1.]):
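        """
        Args:
            num_classes (int): number of foreground classes.
            conv_feat (object): builds the shared conv towers for the cls and
                reg branches, RetinaFeat by default.
            anchor_generator (object): generates anchors for every FPN level.
            bbox_assigner (object): assigns anchors to ground-truth boxes.
            loss_class (object): classification loss, FocalLoss by default.
            loss_bbox (object): box regression loss, SmoothL1Loss by default.
            nms (object): NMS operator used in post processing.
            prior_prob (float): prior foreground probability used to
                initialize the classification bias.
            nms_pre (int): maximum number of top-scoring candidates kept per
                FPN level before NMS.
            weights (list): weights used when encoding/decoding box deltas.
        """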
        super(RetinaHead, self).__init__()
        self.num_classes = num_classes
        self.conv_feat = conv_feat
        self.anchor_generator = anchor_generator
        self.bbox_assigner = bbox_assigner
        self.loss_class = loss_class
        self.loss_bbox = loss_bbox
        self.nms = nms
        self.nms_pre = nms_pre
        self.weights = weights
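
        # The classification branch predicts num_classes scores per anchor and
        # the regression branch predicts 4 box deltas per anchor. The cls bias
        # is initialized so that sigmoid(bias) == prior_prob, the focal-loss
        # initialization from the RetinaNet paper that stabilizes early training.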
        bias_init_value = -math.log((1 - prior_prob) / prior_prob)
        num_anchors = self.anchor_generator.num_anchors
        self.retina_cls = nn.Conv2D(
            in_channels=self.conv_feat.feat_out,
            out_channels=self.num_classes * num_anchors,
            kernel_size=3,
            stride=1,
            padding=1,
            weight_attr=ParamAttr(initializer=Normal(
                mean=0.0, std=0.01)),
            bias_attr=ParamAttr(initializer=Constant(value=bias_init_value)))
        self.retina_reg = nn.Conv2D(
            in_channels=self.conv_feat.feat_out,
            out_channels=4 * num_anchors,
            kernel_size=3,
            stride=1,
            padding=1,
            weight_attr=ParamAttr(initializer=Normal(
                mean=0.0, std=0.01)),
            bias_attr=ParamAttr(initializer=Constant(value=0)))

    def forward(self, neck_feats, targets=None):
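        """Apply the shared conv towers and the cls/reg heads to every FPN
        level. Returns the loss dict in training mode, otherwise the raw
        per-level class logits and box regression maps.
        """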
        cls_logits_list = []
        bboxes_reg_list = []
        for neck_feat in neck_feats:
            conv_cls_feat, conv_reg_feat = self.conv_feat(neck_feat)
            cls_logits = self.retina_cls(conv_cls_feat)
            bbox_reg = self.retina_reg(conv_reg_feat)
            cls_logits_list.append(cls_logits)
            bboxes_reg_list.append(bbox_reg)

        if self.training:
            return self.get_loss([cls_logits_list, bboxes_reg_list], targets)
        else:
            return [cls_logits_list, bboxes_reg_list]

    def get_loss(self, head_outputs, targets):
        """Here we calculate loss for a batch of images.
        We assign anchors to gts in each image and gather all the assigned
        positive and negative samples. Then loss is calculated on the gathered
        samples.
        """
        cls_logits_list, bboxes_reg_list = head_outputs
        anchors = self.anchor_generator(cls_logits_list)
        anchors = paddle.concat(anchors)

        # matches: contain gt_inds
        # match_labels: -1(ignore), 0(neg) or 1(pos)
        matches_list, match_labels_list = [], []
        # assign anchors to gts, no sampling is involved
        for gt_bbox in targets['gt_bbox']:
            matches, match_labels = self.bbox_assigner(anchors, gt_bbox)
            matches_list.append(matches)
            match_labels_list.append(match_labels)

        # reshape network outputs
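        # Note: in paddle.reshape a 0 in the target shape keeps the matching
        # input dimension, so the batch dimension is preserved below.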
        cls_logits = [
            _.transpose([0, 2, 3, 1]).reshape([0, -1, self.num_classes])
            for _ in cls_logits_list
        ]
        bboxes_reg = [
            _.transpose([0, 2, 3, 1]).reshape([0, -1, 4])
            for _ in bboxes_reg_list
        ]
        cls_logits = paddle.concat(cls_logits, axis=1)
        bboxes_reg = paddle.concat(bboxes_reg, axis=1)

        cls_pred_list, cls_tar_list = [], []
        reg_pred_list, reg_tar_list = [], []
        # find and gather preds and targets in each image
        for matches, match_labels, cls_logit, bbox_reg, gt_bbox, gt_class in \
                zip(matches_list, match_labels_list, cls_logits, bboxes_reg,
                    targets['gt_bbox'], targets['gt_class']):
            pos_mask = (match_labels == 1)
            neg_mask = (match_labels == 0)
            chosen_mask = paddle.logical_or(pos_mask, neg_mask)

            gt_class = gt_class.reshape([-1])
            bg_class = paddle.to_tensor(
                [self.num_classes], dtype=gt_class.dtype)
            # a trick to assign num_classes to negative targets
            gt_class = paddle.concat([gt_class, bg_class], axis=-1)
            matches = paddle.where(
                neg_mask,
                paddle.full_like(matches, gt_class.size - 1), matches)

            cls_pred = cls_logit[chosen_mask]
            cls_tar = gt_class[matches[chosen_mask]]
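            # regression is supervised on positive anchors only; targets are
            # encoded as deltas relative to the matched anchors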
            reg_pred = bbox_reg[pos_mask].reshape([-1, 4])
            reg_tar = gt_bbox[matches[pos_mask]].reshape([-1, 4])
            reg_tar = bbox2delta(anchors[pos_mask], reg_tar, self.weights)
            cls_pred_list.append(cls_pred)
            cls_tar_list.append(cls_tar)
            reg_pred_list.append(reg_pred)
            reg_tar_list.append(reg_tar)
        cls_pred = paddle.concat(cls_pred_list)
        cls_tar = paddle.concat(cls_tar_list)
        reg_pred = paddle.concat(reg_pred_list)
        reg_tar = paddle.concat(reg_tar_list)

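        # normalize both losses by the number of positive samples in the batch
        # (clamped to at least 1 to avoid division by zero)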
        avg_factor = max(1.0, reg_pred.shape[0])
        cls_loss = self.loss_class(
            cls_pred, cls_tar, reduction='sum') / avg_factor

        if reg_pred.shape[0] == 0:
            reg_loss = paddle.zeros([1])
            reg_loss.stop_gradient = False
        else:
            reg_loss = self.loss_bbox(
                reg_pred, reg_tar, reduction='sum') / avg_factor

        loss = cls_loss + reg_loss
        out_dict = {
            'loss_cls': cls_loss,
            'loss_reg': reg_loss,
            'loss': loss,
        }
        return out_dict

    def get_bboxes_single(self,
                          anchors,
                          cls_scores_list,
                          bbox_preds_list,
                          im_shape,
                          scale_factor,
                          rescale=True):
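        """Decode predictions of all FPN levels for a single image.

        At most `nms_pre` top-scoring candidates are kept per level; their
        deltas are decoded against the corresponding anchors and the boxes
        are optionally rescaled back to the original image size.
        """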
        assert len(cls_scores_list) == len(bbox_preds_list)
        mlvl_bboxes = []
        mlvl_scores = []
        for anchor, cls_score, bbox_pred in zip(anchors, cls_scores_list,
                                                bbox_preds_list):
            cls_score = cls_score.reshape([-1, self.num_classes])
            bbox_pred = bbox_pred.reshape([-1, 4])
            if self.nms_pre is not None and cls_score.shape[0] > self.nms_pre:
                max_score = cls_score.max(axis=1)
                _, topk_inds = max_score.topk(self.nms_pre)
                bbox_pred = bbox_pred.gather(topk_inds)
                anchor = anchor.gather(topk_inds)
                cls_score = cls_score.gather(topk_inds)
            bbox_pred = delta2bbox(bbox_pred, anchor, self.weights).squeeze()
            mlvl_bboxes.append(bbox_pred)
            mlvl_scores.append(F.sigmoid(cls_score))
        mlvl_bboxes = paddle.concat(mlvl_bboxes)
        mlvl_bboxes = paddle.squeeze(mlvl_bboxes)
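        # scale_factor is assumed to be [scale_y, scale_x]; reversing and
        # tiling it matches the [x1, y1, x2, y2] box layout when mapping
        # boxes back to the original image scale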
        if rescale:
            mlvl_bboxes = mlvl_bboxes / paddle.concat(
                [scale_factor[::-1], scale_factor[::-1]])
        mlvl_scores = paddle.concat(mlvl_scores)
        mlvl_scores = mlvl_scores.transpose([1, 0])
        return mlvl_bboxes, mlvl_scores

    def decode(self, anchors, cls_logits, bboxes_reg, im_shape, scale_factor):
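        """Decode predictions image by image and stack the results into
        batch tensors of boxes and scores.
        """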
        batch_bboxes = []
        batch_scores = []
        for img_id in range(cls_logits[0].shape[0]):
            num_lvls = len(cls_logits)
            cls_scores_list = [cls_logits[i][img_id] for i in range(num_lvls)]
            bbox_preds_list = [bboxes_reg[i][img_id] for i in range(num_lvls)]
            bboxes, scores = self.get_bboxes_single(
                anchors, cls_scores_list, bbox_preds_list, im_shape[img_id],
                scale_factor[img_id])
            batch_bboxes.append(bboxes)
            batch_scores.append(scores)
        batch_bboxes = paddle.stack(batch_bboxes, axis=0)
        batch_scores = paddle.stack(batch_scores, axis=0)
        return batch_bboxes, batch_scores

    def post_process(self, head_outputs, im_shape, scale_factor):
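        """Decode raw head outputs into boxes and scores, then apply
        multi-class NMS to produce the final detections.
        """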
        cls_logits_list, bboxes_reg_list = head_outputs
        anchors = self.anchor_generator(cls_logits_list)
        cls_logits = [_.transpose([0, 2, 3, 1]) for _ in cls_logits_list]
        bboxes_reg = [_.transpose([0, 2, 3, 1]) for _ in bboxes_reg_list]
        bboxes, scores = self.decode(anchors, cls_logits, bboxes_reg, im_shape,
                                     scale_factor)
        bbox_pred, bbox_num, _ = self.nms(bboxes, scores)
        return bbox_pred, bbox_num