# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

# The code is based on:
# https://github.com/open-mmlab/mmdetection/blob/master/mmdet/models/dense_heads/gfl_head.py

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import math
import numpy as np

import paddle
import paddle.nn as nn
import paddle.nn.functional as F
from paddle import ParamAttr
from paddle.nn.initializer import Normal, Constant

from ppdet.core.workspace import register
from ppdet.modeling.layers import ConvNormLayer
from ppdet.modeling.bbox_utils import distance2bbox, bbox2distance, batch_distance2bbox
from ppdet.data.transform.atss_assigner import bbox_overlaps


class ScaleReg(nn.Layer):
    """
    Parameter for scaling the regression outputs.
    """

    def __init__(self):
        super(ScaleReg, self).__init__()
        self.scale_reg = self.create_parameter(
            shape=[1],
            attr=ParamAttr(initializer=Constant(value=1.)),
            dtype="float32")

    def forward(self, inputs):
        out = inputs * self.scale_reg
        return out


class Integral(nn.Layer):
    """A fixed layer for calculating integral result from distribution.

    This layer calculates the target location by :math:`sum{P(y_i) * y_i}`,
    where P(y_i) denotes the softmax vector that represents the discrete
    distribution and y_i denotes the discrete set, usually
    {0, 1, 2, ..., reg_max}.

    Args:
        reg_max (int): The maximal value of the discrete set. Default: 16. You
            may want to reset it according to your new dataset or related
            settings.
    """

    def __init__(self, reg_max=16):
        super(Integral, self).__init__()
        self.reg_max = reg_max
        self.register_buffer('project',
                             paddle.linspace(0, self.reg_max, self.reg_max + 1))

    def forward(self, x):
        """Forward feature from the regression head to get integral result of
        bounding box location.

        Args:
            x (Tensor): Features of the regression head, shape (N, 4*(n+1)),
                n is self.reg_max.

        Returns:
            x (Tensor): Integral result of box locations, i.e., distance
                offsets from the box center in four directions, shape (N, 4).
        """
        x = F.softmax(x.reshape([-1, self.reg_max + 1]), axis=1)
        x = F.linear(x, self.project)
        if self.training:
            x = x.reshape([-1, 4])
        return x
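
# Worked example for the Integral layer above (illustration only, kept as a
# comment so importing this module stays side-effect free): with reg_max = 2
# the projection vector is [0, 1, 2], so a softmax distribution of
# [0.1, 0.2, 0.7] over those bins maps to an expected offset of
# 0.1 * 0 + 0.2 * 1 + 0.7 * 2 = 1.6. A runnable sketch that exercises this
# layer is given in the __main__ block at the bottom of the file.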


@register
class DGQP(nn.Layer):
    """Distribution-Guided Quality Predictor of GFocal head

    Args:
        reg_topk (int): top-k statistics of distribution to guide LQE
        reg_channels (int): hidden layer unit to generate LQE
        add_mean (bool): Whether to calculate the mean of top-k statistics
    """

    def __init__(self, reg_topk=4, reg_channels=64, add_mean=True):
        super(DGQP, self).__init__()
        self.reg_topk = reg_topk
        self.reg_channels = reg_channels
        self.add_mean = add_mean
        self.total_dim = reg_topk
        if add_mean:
            self.total_dim += 1
        self.reg_conv1 = self.add_sublayer(
            'dgqp_reg_conv1',
            nn.Conv2D(
                in_channels=4 * self.total_dim,
                out_channels=self.reg_channels,
                kernel_size=1,
                weight_attr=ParamAttr(initializer=Normal(
                    mean=0., std=0.01)),
                bias_attr=ParamAttr(initializer=Constant(value=0))))
        self.reg_conv2 = self.add_sublayer(
            'dgqp_reg_conv2',
            nn.Conv2D(
                in_channels=self.reg_channels,
                out_channels=1,
                kernel_size=1,
                weight_attr=ParamAttr(initializer=Normal(
                    mean=0., std=0.01)),
                bias_attr=ParamAttr(initializer=Constant(value=0))))

    def forward(self, x):
        """Predict the localization quality estimate (LQE) score from the
        statistics of the regression distribution.

        Args:
            x (Tensor): Output of the regression head, shape
                (N, 4*(reg_max+1), H, W).

        Returns:
            y (Tensor): Quality score in (0, 1), shape (N, 1, H, W).
        """
        N, _, H, W = x.shape[:]
        prob = F.softmax(x.reshape([N, 4, -1, H, W]), axis=2)
        prob_topk, _ = prob.topk(self.reg_topk, axis=2)
        if self.add_mean:
            stat = paddle.concat(
                [prob_topk, prob_topk.mean(
                    axis=2, keepdim=True)], axis=2)
        else:
            stat = prob_topk
        y = F.relu(self.reg_conv1(stat.reshape([N, -1, H, W])))
        y = F.sigmoid(self.reg_conv2(y))
        return y
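
# Shape walk-through for DGQP (illustration only), using the defaults
# reg_topk=4, add_mean=True and reg_max=16 from the head below: the input
# (N, 4*17, H, W) is viewed as (N, 4, 17, H, W) and softmaxed over the 17
# bins; the top-4 probabilities plus their mean give 5 statistics per side,
# i.e. 4 * 5 = 20 channels into reg_conv1, and reg_conv2 reduces them to a
# single (N, 1, H, W) quality map.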


@register
class GFLHead(nn.Layer):
    """
    GFLHead

    Args:
        conv_feat (object): Instance of 'FCOSFeat'
        num_classes (int): Number of classes
        fpn_stride (list): The stride of each FPN Layer
        prior_prob (float): Used to set the bias init for the class prediction layer
        loss_class (object): Instance of QualityFocalLoss.
        loss_dfl (object): Instance of DistributionFocalLoss.
        loss_bbox (object): Instance of bbox loss.
        reg_max (int): Max value of the integral set :math:`{0, ..., reg_max}`
            in the QFL setting. Default: 16.
    """
    __inject__ = [
        'conv_feat', 'dgqp_module', 'loss_class', 'loss_dfl', 'loss_bbox', 'nms'
    ]
    __shared__ = ['num_classes']

    def __init__(self,
                 conv_feat='FCOSFeat',
                 dgqp_module=None,
                 num_classes=80,
                 fpn_stride=[8, 16, 32, 64, 128],
                 prior_prob=0.01,
                 loss_class='QualityFocalLoss',
                 loss_dfl='DistributionFocalLoss',
                 loss_bbox='GIoULoss',
                 reg_max=16,
                 feat_in_chan=256,
                 nms=None,
                 nms_pre=1000,
                 cell_offset=0):
        super(GFLHead, self).__init__()
        self.conv_feat = conv_feat
        self.dgqp_module = dgqp_module
        self.num_classes = num_classes
        self.fpn_stride = fpn_stride
        self.prior_prob = prior_prob
        self.loss_qfl = loss_class
        self.loss_dfl = loss_dfl
        self.loss_bbox = loss_bbox
        self.reg_max = reg_max
        self.feat_in_chan = feat_in_chan
        self.nms = nms
        self.nms_pre = nms_pre
        self.cell_offset = cell_offset
        self.use_sigmoid = self.loss_qfl.use_sigmoid
        if self.use_sigmoid:
            self.cls_out_channels = self.num_classes
        else:
            self.cls_out_channels = self.num_classes + 1

        conv_cls_name = "gfl_head_cls"
        bias_init_value = -math.log((1 - self.prior_prob) / self.prior_prob)
        self.gfl_head_cls = self.add_sublayer(
            conv_cls_name,
            nn.Conv2D(
                in_channels=self.feat_in_chan,
                out_channels=self.cls_out_channels,
                kernel_size=3,
                stride=1,
                padding=1,
                weight_attr=ParamAttr(initializer=Normal(
                    mean=0., std=0.01)),
                bias_attr=ParamAttr(
                    initializer=Constant(value=bias_init_value))))

        conv_reg_name = "gfl_head_reg"
        self.gfl_head_reg = self.add_sublayer(
            conv_reg_name,
            nn.Conv2D(
                in_channels=self.feat_in_chan,
                out_channels=4 * (self.reg_max + 1),
                kernel_size=3,
                stride=1,
                padding=1,
                weight_attr=ParamAttr(initializer=Normal(
                    mean=0., std=0.01)),
                bias_attr=ParamAttr(initializer=Constant(value=0))))

        self.scales_regs = []
        for i in range(len(self.fpn_stride)):
            lvl = int(math.log(int(self.fpn_stride[i]), 2))
            feat_name = 'p{}_feat'.format(lvl)
            scale_reg = self.add_sublayer(feat_name, ScaleReg())
            self.scales_regs.append(scale_reg)

        self.distribution_project = Integral(self.reg_max)

    def forward(self, fpn_feats):
        assert len(fpn_feats) == len(
            self.fpn_stride
        ), "The size of fpn_feats is not equal to size of fpn_stride"
        cls_logits_list = []
        bboxes_reg_list = []
        for stride, scale_reg, fpn_feat in zip(self.fpn_stride,
                                               self.scales_regs, fpn_feats):
            conv_cls_feat, conv_reg_feat = self.conv_feat(fpn_feat)
            cls_score = self.gfl_head_cls(conv_cls_feat)
            bbox_pred = scale_reg(self.gfl_head_reg(conv_reg_feat))
            if self.dgqp_module:
                quality_score = self.dgqp_module(bbox_pred)
                cls_score = F.sigmoid(cls_score) * quality_score
            if not self.training:
                cls_score = F.sigmoid(cls_score.transpose([0, 2, 3, 1]))
                bbox_pred = bbox_pred.transpose([0, 2, 3, 1])
                b, cell_h, cell_w, _ = paddle.shape(cls_score)
                y, x = self.get_single_level_center_point(
                    [cell_h, cell_w], stride, cell_offset=self.cell_offset)
                center_points = paddle.stack([x, y], axis=-1)
                cls_score = cls_score.reshape([b, -1, self.cls_out_channels])
                bbox_pred = self.distribution_project(bbox_pred) * stride
                bbox_pred = bbox_pred.reshape([b, cell_h * cell_w, 4])

                # NOTE: When keep_ratio=False and the image shape is a
                # multiple of 32, batch_distance2bbox can skip the max_shapes
                # clipping to speed up prediction. If clipping is needed,
                # pass inputs['im_shape'] as max_shapes instead of None.
                bbox_pred = batch_distance2bbox(
                    center_points, bbox_pred, max_shapes=None)

            cls_logits_list.append(cls_score)
            bboxes_reg_list.append(bbox_pred)

        return (cls_logits_list, bboxes_reg_list)

    def _images_to_levels(self, target, num_level_anchors):
        """
        Convert targets by image to targets by feature level.
        """
        level_targets = []
        start = 0
        for n in num_level_anchors:
            end = start + n
            level_targets.append(target[:, start:end].squeeze(0))
            start = end
        return level_targets
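
    # Illustration (hypothetical numbers): with num_level_anchors = [400, 100]
    # a per-image target stacked along the anchor axis, e.g. shape [N, 500, 4],
    # is sliced into level chunks of 400 and 100 anchors, one entry per FPN
    # level.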

    def _grid_cells_to_center(self, grid_cells):
        """
        Get center location of each grid cell.

        Args:
            grid_cells: grid cells of a feature map

        Returns:
            center points
        """
        cells_cx = (grid_cells[:, 2] + grid_cells[:, 0]) / 2
        cells_cy = (grid_cells[:, 3] + grid_cells[:, 1]) / 2
        return paddle.stack([cells_cx, cells_cy], axis=-1)

    def get_loss(self, gfl_head_outs, gt_meta):
        cls_logits, bboxes_reg = gfl_head_outs
        num_level_anchors = [
            featmap.shape[-2] * featmap.shape[-1] for featmap in cls_logits
        ]
        grid_cells_list = self._images_to_levels(gt_meta['grid_cells'],
                                                 num_level_anchors)
        labels_list = self._images_to_levels(gt_meta['labels'],
                                             num_level_anchors)
        label_weights_list = self._images_to_levels(gt_meta['label_weights'],
                                                    num_level_anchors)
        bbox_targets_list = self._images_to_levels(gt_meta['bbox_targets'],
                                                   num_level_anchors)
        num_total_pos = sum(gt_meta['pos_num'])
        try:
            num_total_pos = paddle.distributed.all_reduce(num_total_pos.clone(
            )) / paddle.distributed.get_world_size()
        except:
            num_total_pos = max(num_total_pos, 1)

        loss_bbox_list, loss_dfl_list, loss_qfl_list, avg_factor = [], [], [], []
        for cls_score, bbox_pred, grid_cells, labels, label_weights, bbox_targets, stride in zip(
                cls_logits, bboxes_reg, grid_cells_list, labels_list,
                label_weights_list, bbox_targets_list, self.fpn_stride):
            grid_cells = grid_cells.reshape([-1, 4])
            cls_score = cls_score.transpose([0, 2, 3, 1]).reshape(
                [-1, self.cls_out_channels])
            bbox_pred = bbox_pred.transpose([0, 2, 3, 1]).reshape(
                [-1, 4 * (self.reg_max + 1)])
            bbox_targets = bbox_targets.reshape([-1, 4])
            labels = labels.reshape([-1])
            label_weights = label_weights.reshape([-1])

            bg_class_ind = self.num_classes
            pos_inds = paddle.nonzero(
                paddle.logical_and((labels >= 0), (labels < bg_class_ind)),
                as_tuple=False).squeeze(1)
            score = np.zeros(labels.shape)

            if len(pos_inds) > 0:
                pos_bbox_targets = paddle.gather(bbox_targets, pos_inds, axis=0)
                pos_bbox_pred = paddle.gather(bbox_pred, pos_inds, axis=0)
                pos_grid_cells = paddle.gather(grid_cells, pos_inds, axis=0)
                pos_grid_cell_centers = self._grid_cells_to_center(
                    pos_grid_cells) / stride

                weight_targets = F.sigmoid(cls_score.detach())
                weight_targets = paddle.gather(
                    weight_targets.max(axis=1, keepdim=True), pos_inds, axis=0)
                pos_bbox_pred_corners = self.distribution_project(pos_bbox_pred)
                pos_decode_bbox_pred = distance2bbox(pos_grid_cell_centers,
                                                     pos_bbox_pred_corners)
                pos_decode_bbox_targets = pos_bbox_targets / stride
                bbox_iou = bbox_overlaps(
                    pos_decode_bbox_pred.detach().numpy(),
                    pos_decode_bbox_targets.detach().numpy(),
                    is_aligned=True)
                score[pos_inds.numpy()] = bbox_iou
                pred_corners = pos_bbox_pred.reshape([-1, self.reg_max + 1])
                target_corners = bbox2distance(pos_grid_cell_centers,
                                               pos_decode_bbox_targets,
                                               self.reg_max).reshape([-1])

                # regression loss
                loss_bbox = paddle.sum(
                    self.loss_bbox(pos_decode_bbox_pred,
                                   pos_decode_bbox_targets) * weight_targets)

                # dfl loss
                loss_dfl = self.loss_dfl(
                    pred_corners,
                    target_corners,
                    weight=weight_targets.expand([-1, 4]).reshape([-1]),
                    avg_factor=4.0)
            else:
                loss_bbox = bbox_pred.sum() * 0
                loss_dfl = bbox_pred.sum() * 0
                weight_targets = paddle.to_tensor([0], dtype='float32')

            # qfl loss
            score = paddle.to_tensor(score)
            loss_qfl = self.loss_qfl(
                cls_score, (labels, score),
                weight=label_weights,
                avg_factor=num_total_pos)
            loss_bbox_list.append(loss_bbox)
            loss_dfl_list.append(loss_dfl)
            loss_qfl_list.append(loss_qfl)
            avg_factor.append(weight_targets.sum())

        avg_factor = sum(avg_factor)
        try:
            avg_factor = paddle.distributed.all_reduce(avg_factor.clone())
            avg_factor = paddle.clip(
                avg_factor / paddle.distributed.get_world_size(), min=1)
        except:
            avg_factor = max(avg_factor.item(), 1)

        if avg_factor <= 0:
            loss_qfl = paddle.to_tensor(0, dtype='float32', stop_gradient=False)
            loss_bbox = paddle.to_tensor(
                0, dtype='float32', stop_gradient=False)
            loss_dfl = paddle.to_tensor(0, dtype='float32', stop_gradient=False)
        else:
            losses_bbox = list(map(lambda x: x / avg_factor, loss_bbox_list))
            losses_dfl = list(map(lambda x: x / avg_factor, loss_dfl_list))
            loss_qfl = sum(loss_qfl_list)
            loss_bbox = sum(losses_bbox)
            loss_dfl = sum(losses_dfl)

        loss_states = dict(
            loss_qfl=loss_qfl, loss_bbox=loss_bbox, loss_dfl=loss_dfl)

        return loss_states

    def get_single_level_center_point(self, featmap_size, stride,
                                      cell_offset=0):
        """
        Generate pixel centers of a single stage feature map.

        Args:
            featmap_size: height and width of the feature map
            stride: down sample stride of the feature map

        Returns:
            y and x of the center points
        """
        h, w = featmap_size
        x_range = (paddle.arange(w, dtype='float32') + cell_offset) * stride
        y_range = (paddle.arange(h, dtype='float32') + cell_offset) * stride
        y, x = paddle.meshgrid(y_range, x_range)
        y = y.flatten()
        x = x.flatten()
        return y, x
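
    # Worked example (illustration only): for featmap_size=(2, 3), stride=8
    # and cell_offset=0 this yields
    #   x = [0, 8, 16, 0, 8, 16]  and  y = [0, 0, 0, 8, 8, 8],
    # i.e. one (x, y) point per feature-map cell in row-major order.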

    def post_process(self, gfl_head_outs, im_shape, scale_factor):
        cls_scores, bboxes_reg = gfl_head_outs
        bboxes = paddle.concat(bboxes_reg, axis=1)
        # rescale: [h_scale, w_scale] -> [w_scale, h_scale, w_scale, h_scale]
        im_scale = scale_factor.flip([1]).tile([1, 2]).unsqueeze(1)
        bboxes /= im_scale
        mlvl_scores = paddle.concat(cls_scores, axis=1)
        mlvl_scores = mlvl_scores.transpose([0, 2, 1])
        bbox_pred, bbox_num, _ = self.nms(bboxes, mlvl_scores)
        return bbox_pred, bbox_num
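

# ---------------------------------------------------------------------------
# Minimal smoke-test sketch, guarded so it never runs on import. It exercises
# only the standalone sub-modules defined above (Integral and DGQP) on random
# tensors, so no PaddleDetection config or injected loss objects are needed.
# The batch size, feature-map size and reg_max below are arbitrary
# illustrative choices, not values prescribed by this file.
if __name__ == '__main__':
    reg_max = 16
    # Fake regression logits for a batch of 2 on an 8x8 feature map.
    bbox_logits = paddle.randn([2, 4 * (reg_max + 1), 8, 8])

    # DGQP turns the distribution statistics into a (2, 1, 8, 8) quality map.
    dgqp = DGQP(reg_topk=4, reg_channels=64, add_mean=True)
    quality = dgqp(bbox_logits)

    # Integral converts the per-side distributions into expected l, t, r, b
    # offsets; flatten to (N*H*W, 4*(reg_max+1)) first, as the head does.
    integral = Integral(reg_max=reg_max)
    flat = bbox_logits.transpose([0, 2, 3, 1]).reshape(
        [-1, 4 * (reg_max + 1)])
    distances = integral(flat)

    print('quality:', quality.shape, 'distances:', distances.shape)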