simota_head.py 21 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498
  1. # Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. # The code is based on:
  15. # https://github.com/open-mmlab/mmdetection/blob/master/mmdet/models/dense_heads/yolox_head.py
  16. from __future__ import absolute_import
  17. from __future__ import division
  18. from __future__ import print_function
  19. import math
  20. from functools import partial
  21. import numpy as np
  22. import paddle
  23. import paddle.nn as nn
  24. import paddle.nn.functional as F
  25. from paddle import ParamAttr
  26. from paddle.nn.initializer import Normal, Constant
  27. from ppdet.core.workspace import register
  28. from ppdet.modeling.bbox_utils import distance2bbox, bbox2distance
  29. from ppdet.data.transform.atss_assigner import bbox_overlaps
  30. from .gfl_head import GFLHead
  31. @register
  32. class OTAHead(GFLHead):
  33. """
  34. OTAHead
  35. Args:
  36. conv_feat (object): Instance of 'FCOSFeat'
  37. num_classes (int): Number of classes
  38. fpn_stride (list): The stride of each FPN Layer
  39. prior_prob (float): Used to set the bias init for the class prediction layer
  40. loss_qfl (object): Instance of QualityFocalLoss.
  41. loss_dfl (object): Instance of DistributionFocalLoss.
  42. loss_bbox (object): Instance of bbox loss.
  43. assigner (object): Instance of label assigner.
  44. reg_max: Max value of integral set :math: `{0, ..., reg_max}`
  45. n QFL setting. Default: 16.
  46. """
  47. __inject__ = [
  48. 'conv_feat', 'dgqp_module', 'loss_class', 'loss_dfl', 'loss_bbox',
  49. 'assigner', 'nms'
  50. ]
  51. __shared__ = ['num_classes']
  52. def __init__(self,
  53. conv_feat='FCOSFeat',
  54. dgqp_module=None,
  55. num_classes=80,
  56. fpn_stride=[8, 16, 32, 64, 128],
  57. prior_prob=0.01,
  58. loss_class='QualityFocalLoss',
  59. loss_dfl='DistributionFocalLoss',
  60. loss_bbox='GIoULoss',
  61. assigner='SimOTAAssigner',
  62. reg_max=16,
  63. feat_in_chan=256,
  64. nms=None,
  65. nms_pre=1000,
  66. cell_offset=0):
  67. super(OTAHead, self).__init__(
  68. conv_feat=conv_feat,
  69. dgqp_module=dgqp_module,
  70. num_classes=num_classes,
  71. fpn_stride=fpn_stride,
  72. prior_prob=prior_prob,
  73. loss_class=loss_class,
  74. loss_dfl=loss_dfl,
  75. loss_bbox=loss_bbox,
  76. reg_max=reg_max,
  77. feat_in_chan=feat_in_chan,
  78. nms=nms,
  79. nms_pre=nms_pre,
  80. cell_offset=cell_offset)
  81. self.conv_feat = conv_feat
  82. self.dgqp_module = dgqp_module
  83. self.num_classes = num_classes
  84. self.fpn_stride = fpn_stride
  85. self.prior_prob = prior_prob
  86. self.loss_qfl = loss_class
  87. self.loss_dfl = loss_dfl
  88. self.loss_bbox = loss_bbox
  89. self.reg_max = reg_max
  90. self.feat_in_chan = feat_in_chan
  91. self.nms = nms
  92. self.nms_pre = nms_pre
  93. self.cell_offset = cell_offset
  94. self.use_sigmoid = self.loss_qfl.use_sigmoid
  95. self.assigner = assigner
  96. def _get_target_single(self, flatten_cls_pred, flatten_center_and_stride,
  97. flatten_bbox, gt_bboxes, gt_labels):
  98. """Compute targets for priors in a single image.
  99. """
  100. pos_num, label, label_weight, bbox_target = self.assigner(
  101. F.sigmoid(flatten_cls_pred), flatten_center_and_stride,
  102. flatten_bbox, gt_bboxes, gt_labels)
  103. return (pos_num, label, label_weight, bbox_target)
  104. def get_loss(self, head_outs, gt_meta):
  105. cls_scores, bbox_preds = head_outs
  106. num_level_anchors = [
  107. featmap.shape[-2] * featmap.shape[-1] for featmap in cls_scores
  108. ]
  109. num_imgs = gt_meta['im_id'].shape[0]
  110. featmap_sizes = [[featmap.shape[-2], featmap.shape[-1]]
  111. for featmap in cls_scores]
  112. decode_bbox_preds = []
  113. center_and_strides = []
  114. for featmap_size, stride, bbox_pred in zip(featmap_sizes,
  115. self.fpn_stride, bbox_preds):
  116. # center in origin image
  117. yy, xx = self.get_single_level_center_point(featmap_size, stride,
  118. self.cell_offset)
  119. center_and_stride = paddle.stack([xx, yy, stride, stride], -1).tile(
  120. [num_imgs, 1, 1])
  121. center_and_strides.append(center_and_stride)
  122. center_in_feature = center_and_stride.reshape(
  123. [-1, 4])[:, :-2] / stride
  124. bbox_pred = bbox_pred.transpose([0, 2, 3, 1]).reshape(
  125. [num_imgs, -1, 4 * (self.reg_max + 1)])
  126. pred_distances = self.distribution_project(bbox_pred)
  127. decode_bbox_pred_wo_stride = distance2bbox(
  128. center_in_feature, pred_distances).reshape([num_imgs, -1, 4])
  129. decode_bbox_preds.append(decode_bbox_pred_wo_stride * stride)
  130. flatten_cls_preds = [
  131. cls_pred.transpose([0, 2, 3, 1]).reshape(
  132. [num_imgs, -1, self.cls_out_channels])
  133. for cls_pred in cls_scores
  134. ]
  135. flatten_cls_preds = paddle.concat(flatten_cls_preds, axis=1)
  136. flatten_bboxes = paddle.concat(decode_bbox_preds, axis=1)
  137. flatten_center_and_strides = paddle.concat(center_and_strides, axis=1)
  138. gt_boxes, gt_labels = gt_meta['gt_bbox'], gt_meta['gt_class']
  139. pos_num_l, label_l, label_weight_l, bbox_target_l = [], [], [], []
  140. for flatten_cls_pred,flatten_center_and_stride,flatten_bbox,gt_box, gt_label \
  141. in zip(flatten_cls_preds.detach(),flatten_center_and_strides.detach(), \
  142. flatten_bboxes.detach(),gt_boxes, gt_labels):
  143. pos_num, label, label_weight, bbox_target = self._get_target_single(
  144. flatten_cls_pred, flatten_center_and_stride, flatten_bbox,
  145. gt_box, gt_label)
  146. pos_num_l.append(pos_num)
  147. label_l.append(label)
  148. label_weight_l.append(label_weight)
  149. bbox_target_l.append(bbox_target)
  150. labels = paddle.to_tensor(np.stack(label_l, axis=0))
  151. label_weights = paddle.to_tensor(np.stack(label_weight_l, axis=0))
  152. bbox_targets = paddle.to_tensor(np.stack(bbox_target_l, axis=0))
  153. center_and_strides_list = self._images_to_levels(
  154. flatten_center_and_strides, num_level_anchors)
  155. labels_list = self._images_to_levels(labels, num_level_anchors)
  156. label_weights_list = self._images_to_levels(label_weights,
  157. num_level_anchors)
  158. bbox_targets_list = self._images_to_levels(bbox_targets,
  159. num_level_anchors)
  160. num_total_pos = sum(pos_num_l)
  161. try:
  162. num_total_pos = paddle.distributed.all_reduce(num_total_pos.clone(
  163. )) / paddle.distributed.get_world_size()
  164. except:
  165. num_total_pos = max(num_total_pos, 1)
  166. loss_bbox_list, loss_dfl_list, loss_qfl_list, avg_factor = [], [], [], []
  167. for cls_score, bbox_pred, center_and_strides, labels, label_weights, bbox_targets, stride in zip(
  168. cls_scores, bbox_preds, center_and_strides_list, labels_list,
  169. label_weights_list, bbox_targets_list, self.fpn_stride):
  170. center_and_strides = center_and_strides.reshape([-1, 4])
  171. cls_score = cls_score.transpose([0, 2, 3, 1]).reshape(
  172. [-1, self.cls_out_channels])
  173. bbox_pred = bbox_pred.transpose([0, 2, 3, 1]).reshape(
  174. [-1, 4 * (self.reg_max + 1)])
  175. bbox_targets = bbox_targets.reshape([-1, 4])
  176. labels = labels.reshape([-1])
  177. label_weights = label_weights.reshape([-1])
  178. bg_class_ind = self.num_classes
  179. pos_inds = paddle.nonzero(
  180. paddle.logical_and((labels >= 0), (labels < bg_class_ind)),
  181. as_tuple=False).squeeze(1)
  182. score = np.zeros(labels.shape)
  183. if len(pos_inds) > 0:
  184. pos_bbox_targets = paddle.gather(bbox_targets, pos_inds, axis=0)
  185. pos_bbox_pred = paddle.gather(bbox_pred, pos_inds, axis=0)
  186. pos_centers = paddle.gather(
  187. center_and_strides[:, :-2], pos_inds, axis=0) / stride
  188. weight_targets = F.sigmoid(cls_score.detach())
  189. weight_targets = paddle.gather(
  190. weight_targets.max(axis=1, keepdim=True), pos_inds, axis=0)
  191. pos_bbox_pred_corners = self.distribution_project(pos_bbox_pred)
  192. pos_decode_bbox_pred = distance2bbox(pos_centers,
  193. pos_bbox_pred_corners)
  194. pos_decode_bbox_targets = pos_bbox_targets / stride
  195. bbox_iou = bbox_overlaps(
  196. pos_decode_bbox_pred.detach().numpy(),
  197. pos_decode_bbox_targets.detach().numpy(),
  198. is_aligned=True)
  199. score[pos_inds.numpy()] = bbox_iou
  200. pred_corners = pos_bbox_pred.reshape([-1, self.reg_max + 1])
  201. target_corners = bbox2distance(pos_centers,
  202. pos_decode_bbox_targets,
  203. self.reg_max).reshape([-1])
  204. # regression loss
  205. loss_bbox = paddle.sum(
  206. self.loss_bbox(pos_decode_bbox_pred,
  207. pos_decode_bbox_targets) * weight_targets)
  208. # dfl loss
  209. loss_dfl = self.loss_dfl(
  210. pred_corners,
  211. target_corners,
  212. weight=weight_targets.expand([-1, 4]).reshape([-1]),
  213. avg_factor=4.0)
  214. else:
  215. loss_bbox = bbox_pred.sum() * 0
  216. loss_dfl = bbox_pred.sum() * 0
  217. weight_targets = paddle.to_tensor([0], dtype='float32')
  218. # qfl loss
  219. score = paddle.to_tensor(score)
  220. loss_qfl = self.loss_qfl(
  221. cls_score, (labels, score),
  222. weight=label_weights,
  223. avg_factor=num_total_pos)
  224. loss_bbox_list.append(loss_bbox)
  225. loss_dfl_list.append(loss_dfl)
  226. loss_qfl_list.append(loss_qfl)
  227. avg_factor.append(weight_targets.sum())
  228. avg_factor = sum(avg_factor)
  229. try:
  230. avg_factor = paddle.distributed.all_reduce(avg_factor.clone())
  231. avg_factor = paddle.clip(
  232. avg_factor / paddle.distributed.get_world_size(), min=1)
  233. except:
  234. avg_factor = max(avg_factor.item(), 1)
  235. if avg_factor <= 0:
  236. loss_qfl = paddle.to_tensor(0, dtype='float32', stop_gradient=False)
  237. loss_bbox = paddle.to_tensor(
  238. 0, dtype='float32', stop_gradient=False)
  239. loss_dfl = paddle.to_tensor(0, dtype='float32', stop_gradient=False)
  240. else:
  241. losses_bbox = list(map(lambda x: x / avg_factor, loss_bbox_list))
  242. losses_dfl = list(map(lambda x: x / avg_factor, loss_dfl_list))
  243. loss_qfl = sum(loss_qfl_list)
  244. loss_bbox = sum(losses_bbox)
  245. loss_dfl = sum(losses_dfl)
  246. loss_states = dict(
  247. loss_qfl=loss_qfl, loss_bbox=loss_bbox, loss_dfl=loss_dfl)
  248. return loss_states
  249. @register
  250. class OTAVFLHead(OTAHead):
  251. __inject__ = [
  252. 'conv_feat', 'dgqp_module', 'loss_class', 'loss_dfl', 'loss_bbox',
  253. 'assigner', 'nms'
  254. ]
  255. __shared__ = ['num_classes']
  256. def __init__(self,
  257. conv_feat='FCOSFeat',
  258. dgqp_module=None,
  259. num_classes=80,
  260. fpn_stride=[8, 16, 32, 64, 128],
  261. prior_prob=0.01,
  262. loss_class='VarifocalLoss',
  263. loss_dfl='DistributionFocalLoss',
  264. loss_bbox='GIoULoss',
  265. assigner='SimOTAAssigner',
  266. reg_max=16,
  267. feat_in_chan=256,
  268. nms=None,
  269. nms_pre=1000,
  270. cell_offset=0):
  271. super(OTAVFLHead, self).__init__(
  272. conv_feat=conv_feat,
  273. dgqp_module=dgqp_module,
  274. num_classes=num_classes,
  275. fpn_stride=fpn_stride,
  276. prior_prob=prior_prob,
  277. loss_class=loss_class,
  278. loss_dfl=loss_dfl,
  279. loss_bbox=loss_bbox,
  280. reg_max=reg_max,
  281. feat_in_chan=feat_in_chan,
  282. nms=nms,
  283. nms_pre=nms_pre,
  284. cell_offset=cell_offset)
  285. self.conv_feat = conv_feat
  286. self.dgqp_module = dgqp_module
  287. self.num_classes = num_classes
  288. self.fpn_stride = fpn_stride
  289. self.prior_prob = prior_prob
  290. self.loss_vfl = loss_class
  291. self.loss_dfl = loss_dfl
  292. self.loss_bbox = loss_bbox
  293. self.reg_max = reg_max
  294. self.feat_in_chan = feat_in_chan
  295. self.nms = nms
  296. self.nms_pre = nms_pre
  297. self.cell_offset = cell_offset
  298. self.use_sigmoid = self.loss_vfl.use_sigmoid
  299. self.assigner = assigner
  300. def get_loss(self, head_outs, gt_meta):
  301. cls_scores, bbox_preds = head_outs
  302. num_level_anchors = [
  303. featmap.shape[-2] * featmap.shape[-1] for featmap in cls_scores
  304. ]
  305. num_imgs = gt_meta['im_id'].shape[0]
  306. featmap_sizes = [[featmap.shape[-2], featmap.shape[-1]]
  307. for featmap in cls_scores]
  308. decode_bbox_preds = []
  309. center_and_strides = []
  310. for featmap_size, stride, bbox_pred in zip(featmap_sizes,
  311. self.fpn_stride, bbox_preds):
  312. # center in origin image
  313. yy, xx = self.get_single_level_center_point(featmap_size, stride,
  314. self.cell_offset)
  315. strides = paddle.full((len(xx), ), stride)
  316. center_and_stride = paddle.stack([xx, yy, strides, strides],
  317. -1).tile([num_imgs, 1, 1])
  318. center_and_strides.append(center_and_stride)
  319. center_in_feature = center_and_stride.reshape(
  320. [-1, 4])[:, :-2] / stride
  321. bbox_pred = bbox_pred.transpose([0, 2, 3, 1]).reshape(
  322. [num_imgs, -1, 4 * (self.reg_max + 1)])
  323. pred_distances = self.distribution_project(bbox_pred)
  324. decode_bbox_pred_wo_stride = distance2bbox(
  325. center_in_feature, pred_distances).reshape([num_imgs, -1, 4])
  326. decode_bbox_preds.append(decode_bbox_pred_wo_stride * stride)
  327. flatten_cls_preds = [
  328. cls_pred.transpose([0, 2, 3, 1]).reshape(
  329. [num_imgs, -1, self.cls_out_channels])
  330. for cls_pred in cls_scores
  331. ]
  332. flatten_cls_preds = paddle.concat(flatten_cls_preds, axis=1)
  333. flatten_bboxes = paddle.concat(decode_bbox_preds, axis=1)
  334. flatten_center_and_strides = paddle.concat(center_and_strides, axis=1)
  335. gt_boxes, gt_labels = gt_meta['gt_bbox'], gt_meta['gt_class']
  336. pos_num_l, label_l, label_weight_l, bbox_target_l = [], [], [], []
  337. for flatten_cls_pred, flatten_center_and_stride, flatten_bbox,gt_box,gt_label \
  338. in zip(flatten_cls_preds.detach(), flatten_center_and_strides.detach(), \
  339. flatten_bboxes.detach(),gt_boxes,gt_labels):
  340. pos_num, label, label_weight, bbox_target = self._get_target_single(
  341. flatten_cls_pred, flatten_center_and_stride, flatten_bbox,
  342. gt_box, gt_label)
  343. pos_num_l.append(pos_num)
  344. label_l.append(label)
  345. label_weight_l.append(label_weight)
  346. bbox_target_l.append(bbox_target)
  347. labels = paddle.to_tensor(np.stack(label_l, axis=0))
  348. label_weights = paddle.to_tensor(np.stack(label_weight_l, axis=0))
  349. bbox_targets = paddle.to_tensor(np.stack(bbox_target_l, axis=0))
  350. center_and_strides_list = self._images_to_levels(
  351. flatten_center_and_strides, num_level_anchors)
  352. labels_list = self._images_to_levels(labels, num_level_anchors)
  353. label_weights_list = self._images_to_levels(label_weights,
  354. num_level_anchors)
  355. bbox_targets_list = self._images_to_levels(bbox_targets,
  356. num_level_anchors)
  357. num_total_pos = sum(pos_num_l)
  358. try:
  359. num_total_pos = paddle.distributed.all_reduce(num_total_pos.clone(
  360. )) / paddle.distributed.get_world_size()
  361. except:
  362. num_total_pos = max(num_total_pos, 1)
  363. loss_bbox_list, loss_dfl_list, loss_vfl_list, avg_factor = [], [], [], []
  364. for cls_score, bbox_pred, center_and_strides, labels, label_weights, bbox_targets, stride in zip(
  365. cls_scores, bbox_preds, center_and_strides_list, labels_list,
  366. label_weights_list, bbox_targets_list, self.fpn_stride):
  367. center_and_strides = center_and_strides.reshape([-1, 4])
  368. cls_score = cls_score.transpose([0, 2, 3, 1]).reshape(
  369. [-1, self.cls_out_channels])
  370. bbox_pred = bbox_pred.transpose([0, 2, 3, 1]).reshape(
  371. [-1, 4 * (self.reg_max + 1)])
  372. bbox_targets = bbox_targets.reshape([-1, 4])
  373. labels = labels.reshape([-1])
  374. bg_class_ind = self.num_classes
  375. pos_inds = paddle.nonzero(
  376. paddle.logical_and((labels >= 0), (labels < bg_class_ind)),
  377. as_tuple=False).squeeze(1)
  378. # vfl
  379. vfl_score = np.zeros(cls_score.shape)
  380. if len(pos_inds) > 0:
  381. pos_bbox_targets = paddle.gather(bbox_targets, pos_inds, axis=0)
  382. pos_bbox_pred = paddle.gather(bbox_pred, pos_inds, axis=0)
  383. pos_centers = paddle.gather(
  384. center_and_strides[:, :-2], pos_inds, axis=0) / stride
  385. weight_targets = F.sigmoid(cls_score.detach())
  386. weight_targets = paddle.gather(
  387. weight_targets.max(axis=1, keepdim=True), pos_inds, axis=0)
  388. pos_bbox_pred_corners = self.distribution_project(pos_bbox_pred)
  389. pos_decode_bbox_pred = distance2bbox(pos_centers,
  390. pos_bbox_pred_corners)
  391. pos_decode_bbox_targets = pos_bbox_targets / stride
  392. bbox_iou = bbox_overlaps(
  393. pos_decode_bbox_pred.detach().numpy(),
  394. pos_decode_bbox_targets.detach().numpy(),
  395. is_aligned=True)
  396. # vfl
  397. pos_labels = paddle.gather(labels, pos_inds, axis=0)
  398. vfl_score[pos_inds.numpy(), pos_labels] = bbox_iou
  399. pred_corners = pos_bbox_pred.reshape([-1, self.reg_max + 1])
  400. target_corners = bbox2distance(pos_centers,
  401. pos_decode_bbox_targets,
  402. self.reg_max).reshape([-1])
  403. # regression loss
  404. loss_bbox = paddle.sum(
  405. self.loss_bbox(pos_decode_bbox_pred,
  406. pos_decode_bbox_targets) * weight_targets)
  407. # dfl loss
  408. loss_dfl = self.loss_dfl(
  409. pred_corners,
  410. target_corners,
  411. weight=weight_targets.expand([-1, 4]).reshape([-1]),
  412. avg_factor=4.0)
  413. else:
  414. loss_bbox = bbox_pred.sum() * 0
  415. loss_dfl = bbox_pred.sum() * 0
  416. weight_targets = paddle.to_tensor([0], dtype='float32')
  417. # vfl loss
  418. num_pos_avg_per_gpu = num_total_pos
  419. vfl_score = paddle.to_tensor(vfl_score)
  420. loss_vfl = self.loss_vfl(
  421. cls_score, vfl_score, avg_factor=num_pos_avg_per_gpu)
  422. loss_bbox_list.append(loss_bbox)
  423. loss_dfl_list.append(loss_dfl)
  424. loss_vfl_list.append(loss_vfl)
  425. avg_factor.append(weight_targets.sum())
  426. avg_factor = sum(avg_factor)
  427. try:
  428. avg_factor = paddle.distributed.all_reduce(avg_factor.clone())
  429. avg_factor = paddle.clip(
  430. avg_factor / paddle.distributed.get_world_size(), min=1)
  431. except:
  432. avg_factor = max(avg_factor.item(), 1)
  433. if avg_factor <= 0:
  434. loss_vfl = paddle.to_tensor(0, dtype='float32', stop_gradient=False)
  435. loss_bbox = paddle.to_tensor(
  436. 0, dtype='float32', stop_gradient=False)
  437. loss_dfl = paddle.to_tensor(0, dtype='float32', stop_gradient=False)
  438. else:
  439. losses_bbox = list(map(lambda x: x / avg_factor, loss_bbox_list))
  440. losses_dfl = list(map(lambda x: x / avg_factor, loss_dfl_list))
  441. loss_vfl = sum(loss_vfl_list)
  442. loss_bbox = sum(losses_bbox)
  443. loss_dfl = sum(losses_dfl)
  444. loss_states = dict(
  445. loss_vfl=loss_vfl, loss_bbox=loss_bbox, loss_dfl=loss_dfl)
  446. return loss_states