tood_head.py 14 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370
  1. # Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. from __future__ import absolute_import
  15. from __future__ import division
  16. from __future__ import print_function
  17. import paddle
  18. import paddle.nn as nn
  19. import paddle.nn.functional as F
  20. from paddle import ParamAttr
  21. from paddle.nn.initializer import Constant
  22. from ppdet.core.workspace import register
  23. from ..initializer import normal_, constant_, bias_init_with_prob
  24. from ppdet.modeling.bbox_utils import bbox_center, batch_distance2bbox
  25. from ..losses import GIoULoss
  26. from ppdet.modeling.layers import ConvNormLayer
  27. from ppdet.modeling.ops import get_static_shape
  28. from ppdet.modeling.assigners.utils import generate_anchors_for_grid_cell
  29. class ScaleReg(nn.Layer):
  30. """
  31. Parameter for scaling the regression outputs.
  32. """
  33. def __init__(self, init_scale=1.):
  34. super(ScaleReg, self).__init__()
  35. self.scale_reg = self.create_parameter(
  36. shape=[1],
  37. attr=ParamAttr(initializer=Constant(value=init_scale)),
  38. dtype="float32")
  39. def forward(self, inputs):
  40. out = inputs * self.scale_reg
  41. return out
  42. class TaskDecomposition(nn.Layer):
  43. """This code is based on
  44. https://github.com/fcjian/TOOD/blob/master/mmdet/models/dense_heads/tood_head.py
  45. """
  46. def __init__(
  47. self,
  48. feat_channels,
  49. stacked_convs,
  50. la_down_rate=8,
  51. norm_type='gn',
  52. norm_groups=32, ):
  53. super(TaskDecomposition, self).__init__()
  54. self.feat_channels = feat_channels
  55. self.stacked_convs = stacked_convs
  56. self.norm_type = norm_type
  57. self.norm_groups = norm_groups
  58. self.in_channels = self.feat_channels * self.stacked_convs
  59. self.la_conv1 = nn.Conv2D(self.in_channels,
  60. self.in_channels // la_down_rate, 1)
  61. self.la_conv2 = nn.Conv2D(self.in_channels // la_down_rate,
  62. self.stacked_convs, 1)
  63. self.reduction_conv = ConvNormLayer(
  64. self.in_channels,
  65. self.feat_channels,
  66. filter_size=1,
  67. stride=1,
  68. norm_type=self.norm_type,
  69. norm_groups=self.norm_groups)
  70. self._init_weights()
  71. def _init_weights(self):
  72. normal_(self.la_conv1.weight, std=0.001)
  73. normal_(self.la_conv2.weight, std=0.001)
  74. def forward(self, feat, avg_feat):
  75. b, _, h, w = get_static_shape(feat)
  76. weight = F.relu(self.la_conv1(avg_feat))
  77. weight = F.sigmoid(self.la_conv2(weight)).unsqueeze(-1)
  78. feat = paddle.reshape(
  79. feat, [b, self.stacked_convs, self.feat_channels, h, w]) * weight
  80. feat = self.reduction_conv(feat.flatten(1, 2))
  81. feat = F.relu(feat)
  82. return feat
  83. @register
  84. class TOODHead(nn.Layer):
  85. """This code is based on
  86. https://github.com/fcjian/TOOD/blob/master/mmdet/models/dense_heads/tood_head.py
  87. """
  88. __inject__ = ['nms', 'static_assigner', 'assigner']
  89. __shared__ = ['num_classes']
  90. def __init__(self,
  91. num_classes=80,
  92. feat_channels=256,
  93. stacked_convs=6,
  94. fpn_strides=(8, 16, 32, 64, 128),
  95. grid_cell_scale=8,
  96. grid_cell_offset=0.5,
  97. norm_type='gn',
  98. norm_groups=32,
  99. static_assigner_epoch=4,
  100. use_align_head=True,
  101. loss_weight={
  102. 'class': 1.0,
  103. 'bbox': 1.0,
  104. 'iou': 2.0,
  105. },
  106. nms='MultiClassNMS',
  107. static_assigner='ATSSAssigner',
  108. assigner='TaskAlignedAssigner'):
  109. super(TOODHead, self).__init__()
  110. self.num_classes = num_classes
  111. self.feat_channels = feat_channels
  112. self.stacked_convs = stacked_convs
  113. self.fpn_strides = fpn_strides
  114. self.grid_cell_scale = grid_cell_scale
  115. self.grid_cell_offset = grid_cell_offset
  116. self.static_assigner_epoch = static_assigner_epoch
  117. self.use_align_head = use_align_head
  118. self.nms = nms
  119. self.static_assigner = static_assigner
  120. self.assigner = assigner
  121. self.loss_weight = loss_weight
  122. self.giou_loss = GIoULoss()
  123. self.inter_convs = nn.LayerList()
  124. for i in range(self.stacked_convs):
  125. self.inter_convs.append(
  126. ConvNormLayer(
  127. self.feat_channels,
  128. self.feat_channels,
  129. filter_size=3,
  130. stride=1,
  131. norm_type=norm_type,
  132. norm_groups=norm_groups))
  133. self.cls_decomp = TaskDecomposition(
  134. self.feat_channels,
  135. self.stacked_convs,
  136. self.stacked_convs * 8,
  137. norm_type=norm_type,
  138. norm_groups=norm_groups)
  139. self.reg_decomp = TaskDecomposition(
  140. self.feat_channels,
  141. self.stacked_convs,
  142. self.stacked_convs * 8,
  143. norm_type=norm_type,
  144. norm_groups=norm_groups)
  145. self.tood_cls = nn.Conv2D(
  146. self.feat_channels, self.num_classes, 3, padding=1)
  147. self.tood_reg = nn.Conv2D(self.feat_channels, 4, 3, padding=1)
  148. if self.use_align_head:
  149. self.cls_prob_conv1 = nn.Conv2D(self.feat_channels *
  150. self.stacked_convs,
  151. self.feat_channels // 4, 1)
  152. self.cls_prob_conv2 = nn.Conv2D(
  153. self.feat_channels // 4, 1, 3, padding=1)
  154. self.reg_offset_conv1 = nn.Conv2D(self.feat_channels *
  155. self.stacked_convs,
  156. self.feat_channels // 4, 1)
  157. self.reg_offset_conv2 = nn.Conv2D(
  158. self.feat_channels // 4, 4 * 2, 3, padding=1)
  159. self.scales_regs = nn.LayerList([ScaleReg() for _ in self.fpn_strides])
  160. self._init_weights()
  161. @classmethod
  162. def from_config(cls, cfg, input_shape):
  163. return {
  164. 'feat_channels': input_shape[0].channels,
  165. 'fpn_strides': [i.stride for i in input_shape],
  166. }
  167. def _init_weights(self):
  168. bias_cls = bias_init_with_prob(0.01)
  169. normal_(self.tood_cls.weight, std=0.01)
  170. constant_(self.tood_cls.bias, bias_cls)
  171. normal_(self.tood_reg.weight, std=0.01)
  172. if self.use_align_head:
  173. normal_(self.cls_prob_conv1.weight, std=0.01)
  174. normal_(self.cls_prob_conv2.weight, std=0.01)
  175. constant_(self.cls_prob_conv2.bias, bias_cls)
  176. normal_(self.reg_offset_conv1.weight, std=0.001)
  177. constant_(self.reg_offset_conv2.weight)
  178. constant_(self.reg_offset_conv2.bias)
  179. def _reg_grid_sample(self, feat, offset, anchor_points):
  180. b, _, h, w = get_static_shape(feat)
  181. feat = paddle.reshape(feat, [-1, 1, h, w])
  182. offset = paddle.reshape(offset, [-1, 2, h, w]).transpose([0, 2, 3, 1])
  183. grid_shape = paddle.concat([w, h]).astype('float32')
  184. grid = (offset + anchor_points) / grid_shape
  185. grid = 2 * grid.clip(0., 1.) - 1
  186. feat = F.grid_sample(feat, grid)
  187. feat = paddle.reshape(feat, [b, -1, h, w])
  188. return feat
  189. def forward(self, feats):
  190. assert len(feats) == len(self.fpn_strides), \
  191. "The size of feats is not equal to size of fpn_strides"
  192. anchors, anchor_points, num_anchors_list, stride_tensor =\
  193. generate_anchors_for_grid_cell(
  194. feats, self.fpn_strides, self.grid_cell_scale,
  195. self.grid_cell_offset)
  196. anchor_centers_split = paddle.split(anchor_points / stride_tensor,
  197. num_anchors_list)
  198. cls_score_list, bbox_pred_list = [], []
  199. for feat, scale_reg, anchor_centers, stride in zip(
  200. feats, self.scales_regs, anchor_centers_split,
  201. self.fpn_strides):
  202. b, _, h, w = get_static_shape(feat)
  203. inter_feats = []
  204. for inter_conv in self.inter_convs:
  205. feat = F.relu(inter_conv(feat))
  206. inter_feats.append(feat)
  207. feat = paddle.concat(inter_feats, axis=1)
  208. # task decomposition
  209. avg_feat = F.adaptive_avg_pool2d(feat, (1, 1))
  210. cls_feat = self.cls_decomp(feat, avg_feat)
  211. reg_feat = self.reg_decomp(feat, avg_feat)
  212. # cls prediction and alignment
  213. cls_logits = self.tood_cls(cls_feat)
  214. if self.use_align_head:
  215. cls_prob = F.relu(self.cls_prob_conv1(feat))
  216. cls_prob = F.sigmoid(self.cls_prob_conv2(cls_prob))
  217. cls_score = (F.sigmoid(cls_logits) * cls_prob).sqrt()
  218. else:
  219. cls_score = F.sigmoid(cls_logits)
  220. cls_score_list.append(cls_score.flatten(2).transpose([0, 2, 1]))
  221. # reg prediction and alignment
  222. reg_dist = scale_reg(self.tood_reg(reg_feat).exp())
  223. reg_dist = reg_dist.flatten(2).transpose([0, 2, 1])
  224. reg_bbox = batch_distance2bbox(
  225. anchor_centers.unsqueeze(0), reg_dist)
  226. if self.use_align_head:
  227. reg_offset = F.relu(self.reg_offset_conv1(feat))
  228. reg_offset = self.reg_offset_conv2(reg_offset)
  229. reg_bbox = reg_bbox.transpose([0, 2, 1]).reshape([b, 4, h, w])
  230. anchor_centers = anchor_centers.reshape([1, h, w, 2])
  231. bbox_pred = self._reg_grid_sample(reg_bbox, reg_offset,
  232. anchor_centers)
  233. bbox_pred = bbox_pred.flatten(2).transpose([0, 2, 1])
  234. else:
  235. bbox_pred = reg_bbox
  236. if not self.training:
  237. bbox_pred *= stride
  238. bbox_pred_list.append(bbox_pred)
  239. cls_score_list = paddle.concat(cls_score_list, axis=1)
  240. bbox_pred_list = paddle.concat(bbox_pred_list, axis=1)
  241. return cls_score_list, bbox_pred_list, anchors, num_anchors_list, stride_tensor
  242. @staticmethod
  243. def _focal_loss(score, label, alpha=0.25, gamma=2.0):
  244. weight = (score - label).pow(gamma)
  245. if alpha > 0:
  246. alpha_t = alpha * label + (1 - alpha) * (1 - label)
  247. weight *= alpha_t
  248. loss = F.binary_cross_entropy(
  249. score, label, weight=weight, reduction='sum')
  250. return loss
  251. def get_loss(self, head_outs, gt_meta):
  252. pred_scores, pred_bboxes, anchors, \
  253. num_anchors_list, stride_tensor = head_outs
  254. gt_labels = gt_meta['gt_class']
  255. gt_bboxes = gt_meta['gt_bbox']
  256. pad_gt_mask = gt_meta['pad_gt_mask']
  257. # label assignment
  258. if gt_meta['epoch_id'] < self.static_assigner_epoch:
  259. assigned_labels, assigned_bboxes, assigned_scores = self.static_assigner(
  260. anchors,
  261. num_anchors_list,
  262. gt_labels,
  263. gt_bboxes,
  264. pad_gt_mask,
  265. bg_index=self.num_classes)
  266. alpha_l = 0.25
  267. else:
  268. assigned_labels, assigned_bboxes, assigned_scores = self.assigner(
  269. pred_scores.detach(),
  270. pred_bboxes.detach() * stride_tensor,
  271. bbox_center(anchors),
  272. num_anchors_list,
  273. gt_labels,
  274. gt_bboxes,
  275. pad_gt_mask,
  276. bg_index=self.num_classes)
  277. alpha_l = -1
  278. # rescale bbox
  279. assigned_bboxes /= stride_tensor
  280. # classification loss
  281. loss_cls = self._focal_loss(pred_scores, assigned_scores, alpha=alpha_l)
  282. # select positive samples mask
  283. mask_positive = (assigned_labels != self.num_classes)
  284. num_pos = mask_positive.astype(paddle.float32).sum()
  285. # bbox regression loss
  286. if num_pos > 0:
  287. bbox_mask = mask_positive.unsqueeze(-1).tile([1, 1, 4])
  288. pred_bboxes_pos = paddle.masked_select(pred_bboxes,
  289. bbox_mask).reshape([-1, 4])
  290. assigned_bboxes_pos = paddle.masked_select(
  291. assigned_bboxes, bbox_mask).reshape([-1, 4])
  292. bbox_weight = paddle.masked_select(
  293. assigned_scores.sum(-1), mask_positive).unsqueeze(-1)
  294. # iou loss
  295. loss_iou = self.giou_loss(pred_bboxes_pos,
  296. assigned_bboxes_pos) * bbox_weight
  297. loss_iou = loss_iou.sum() / bbox_weight.sum()
  298. # l1 loss
  299. loss_l1 = F.l1_loss(pred_bboxes_pos, assigned_bboxes_pos)
  300. else:
  301. loss_iou = paddle.zeros([1])
  302. loss_l1 = paddle.zeros([1])
  303. loss_cls /= assigned_scores.sum().clip(min=1)
  304. loss = self.loss_weight['class'] * loss_cls + self.loss_weight[
  305. 'iou'] * loss_iou
  306. return {
  307. 'loss': loss,
  308. 'loss_class': loss_cls,
  309. 'loss_iou': loss_iou,
  310. 'loss_l1': loss_l1
  311. }
  312. def post_process(self, head_outs, img_shape, scale_factor):
  313. pred_scores, pred_bboxes, _, _, _ = head_outs
  314. pred_scores = pred_scores.transpose([0, 2, 1])
  315. for i in range(len(pred_bboxes)):
  316. pred_bboxes[i, :, 0] = pred_bboxes[i, :, 0].clip(
  317. min=0, max=img_shape[i, 1])
  318. pred_bboxes[i, :, 1] = pred_bboxes[i, :, 1].clip(
  319. min=0, max=img_shape[i, 0])
  320. pred_bboxes[i, :, 2] = pred_bboxes[i, :, 2].clip(
  321. min=0, max=img_shape[i, 1])
  322. pred_bboxes[i, :, 3] = pred_bboxes[i, :, 3].clip(
  323. min=0, max=img_shape[i, 0])
  324. # scale bbox to origin
  325. scale_factor = scale_factor.flip([1]).tile([1, 2]).unsqueeze(1)
  326. pred_bboxes /= scale_factor
  327. bbox_pred, bbox_num, _ = self.nms(pred_bboxes, pred_scores)
  328. return bbox_pred, bbox_num