post_process.py 30 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
3704705706707708709710711712713714715716717718719720721722723724725726727728729730731732733734735736737738739740741742743744745746747748749750751752753754755756757758759760761762763764765766767768769770771772773774775776
  1. # Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. import numpy as np
  15. import paddle
  16. import paddle.nn as nn
  17. import paddle.nn.functional as F
  18. from ppdet.core.workspace import register
  19. from ppdet.modeling.bbox_utils import nonempty_bbox, rbox2poly
  20. from ppdet.modeling.layers import TTFBox
  21. from .transformers import bbox_cxcywh_to_xyxy
  22. try:
  23. from collections.abc import Sequence
  24. except Exception:
  25. from collections import Sequence
# Post-processing classes exported by this module (the module-level `nms`
# helper defined below is intentionally not listed here).
__all__ = [
    'BBoxPostProcess', 'MaskPostProcess', 'FCOSPostProcess',
    'S2ANetBBoxPostProcess', 'JDEBBoxPostProcess', 'CenterNetPostProcess',
    'DETRBBoxPostProcess', 'SparsePostProcess'
]
@register
class BBoxPostProcess(object):
    """
    Decode the outputs of a bbox head, optionally run NMS, and map the
    resulting boxes back onto the original image.

    Args:
        num_classes (int): number of classes, forwarded to ``nms``.
            80 by default.
        decode (callable): decoder injected from config, called as
            ``decode(head_out, rois, im_shape, scale_factor)``.
        nms (callable|None): NMS callable injected from config. When it is
            None, ``decode`` must itself return ``(bbox_pred, bbox_num)``.
        export_onnx (bool): whether the model is being exported to ONNX.
            Enables the fake-box padding in ``__call__`` and the simplified
            batch-size-1 path in ``get_pred``.
    """
    __shared__ = ['num_classes', 'export_onnx']
    __inject__ = ['decode', 'nms']

    def __init__(self, num_classes=80, decode=None, nms=None,
                 export_onnx=False):
        super(BBoxPostProcess, self).__init__()
        self.num_classes = num_classes
        self.decode = decode
        self.nms = nms
        self.export_onnx = export_onnx

    def __call__(self, head_out, rois, im_shape, scale_factor):
        """
        Decode the bbox and do NMS if needed.

        Args:
            head_out (tuple): bbox_pred and cls_prob of bbox_head output.
            rois (tuple): roi and rois_num of rpn_head output.
            im_shape (Tensor): The shape of the input image.
            scale_factor (Tensor): The scale factor of the input image.

        Returns:
            bbox_pred (Tensor): The output prediction with shape [N, 6], including
                labels, scores and bboxes. The size of bboxes are corresponding
                to the input image, the bboxes may be used in other branch.
            bbox_num (Tensor): The number of prediction boxes of each batch with
                shape [1], and is N.

        Notes:
            When ``self.export_onnx`` is True, one fake box
            ``[0, 0, 0, 0, 1, 1]`` is appended to ``bbox_pred`` and
            ``bbox_num`` is incremented by 1.
        """
        if self.nms is not None:
            bboxes, score = self.decode(head_out, rois, im_shape, scale_factor)
            bbox_pred, bbox_num, _ = self.nms(bboxes, score, self.num_classes)
        else:
            bbox_pred, bbox_num = self.decode(head_out, rois, im_shape,
                                              scale_factor)

        if self.export_onnx:
            # add fake box after postprocess when exporting onnx
            fake_bboxes = paddle.to_tensor(
                np.array(
                    [[0., 0.0, 0.0, 0.0, 1.0, 1.0]], dtype='float32'))

            bbox_pred = paddle.concat([bbox_pred, fake_bboxes])
            bbox_num = bbox_num + 1

        return bbox_pred, bbox_num

    def get_pred(self, bboxes, bbox_num, im_shape, scale_factor):
        """
        Rescale, clip and filter the bbox from the output of NMS to
        get final prediction.

        Notes:
            The ``export_onnx`` path assumes batch size 1. Outside that path,
            every image whose ``bbox_num`` entry is 0 receives one fake box
            ``[0, 0, 0, 0, 1, 1]`` so that no image has an empty result.

        Args:
            bboxes (Tensor): The output bboxes with shape [N, 6] after decode
                and NMS, including labels, scores and bboxes.
            bbox_num (Tensor): The number of prediction boxes of each image;
                one entry per image.
            im_shape (Tensor): The shape of the input image; column 0 is the
                height and column 1 the width.
            scale_factor (Tensor): The scale factor of the input image; rows
                are [scale_y, scale_x].

        Returns:
            bboxes (Tensor): the input boxes, padded with fake boxes as
                described above.
            pred_result (Tensor): The final prediction results with shape
                [N, 6] including labels, scores and bboxes. Boxes are divided
                by the scale factor and clipped to [0, original size]; the
                label of a box that becomes empty is set to -1.
            bbox_num (Tensor): the (possibly padded) number of boxes per
                image.

        Side effects:
            Stores the per-box original image shape in
            ``self.origin_shape_list`` (read back by ``get_origin_shape``).
        """
        if not self.export_onnx:
            bboxes_list = []
            bbox_num_list = []
            id_start = 0
            fake_bboxes = paddle.to_tensor(
                np.array(
                    [[0., 0.0, 0.0, 0.0, 1.0, 1.0]], dtype='float32'))
            fake_bbox_num = paddle.to_tensor(np.array([1], dtype='int32'))

            # add fake bbox when output is empty for each batch
            for i in range(bbox_num.shape[0]):
                if bbox_num[i] == 0:
                    bboxes_i = fake_bboxes
                    bbox_num_i = fake_bbox_num
                else:
                    bboxes_i = bboxes[id_start:id_start + bbox_num[i], :]
                    bbox_num_i = bbox_num[i]
                    id_start += bbox_num[i]
                bboxes_list.append(bboxes_i)
                bbox_num_list.append(bbox_num_i)
            bboxes = paddle.concat(bboxes_list)
            bbox_num = paddle.concat(bbox_num_list)

        # original (pre-resize) image size, rounded to the nearest pixel
        origin_shape = paddle.floor(im_shape / scale_factor + 0.5)

        if not self.export_onnx:
            origin_shape_list = []
            scale_factor_list = []
            # scale_factor: scale_y, scale_x
            for i in range(bbox_num.shape[0]):
                # repeat image i's shape / scale once per box of image i
                expand_shape = paddle.expand(origin_shape[i:i + 1, :],
                                             [bbox_num[i], 2])
                scale_y, scale_x = scale_factor[i][0], scale_factor[i][1]
                scale = paddle.concat([scale_x, scale_y, scale_x, scale_y])
                expand_scale = paddle.expand(scale, [bbox_num[i], 4])
                origin_shape_list.append(expand_shape)
                scale_factor_list.append(expand_scale)

            self.origin_shape_list = paddle.concat(origin_shape_list)
            scale_factor_list = paddle.concat(scale_factor_list)

        else:
            # simplify the computation for bs=1 when exporting onnx
            scale_y, scale_x = scale_factor[0][0], scale_factor[0][1]
            scale = paddle.concat(
                [scale_x, scale_y, scale_x, scale_y]).unsqueeze(0)
            self.origin_shape_list = paddle.expand(origin_shape,
                                                   [bbox_num[0], 2])
            scale_factor_list = paddle.expand(scale, [bbox_num[0], 4])

        # bboxes: [N, 6], label, score, bbox
        pred_label = bboxes[:, 0:1]
        pred_score = bboxes[:, 1:2]
        pred_bbox = bboxes[:, 2:]
        # rescale bbox to original image
        scaled_bbox = pred_bbox / scale_factor_list
        origin_h = self.origin_shape_list[:, 0]
        origin_w = self.origin_shape_list[:, 1]
        zeros = paddle.zeros_like(origin_h)
        # clip bbox to [0, original_size]
        x1 = paddle.maximum(paddle.minimum(scaled_bbox[:, 0], origin_w), zeros)
        y1 = paddle.maximum(paddle.minimum(scaled_bbox[:, 1], origin_h), zeros)
        x2 = paddle.maximum(paddle.minimum(scaled_bbox[:, 2], origin_w), zeros)
        y2 = paddle.maximum(paddle.minimum(scaled_bbox[:, 3], origin_h), zeros)
        pred_bbox = paddle.stack([x1, y1, x2, y2], axis=-1)
        # filter empty bbox: keep the row but mark its label as -1
        keep_mask = nonempty_bbox(pred_bbox, return_mask=True)
        keep_mask = paddle.unsqueeze(keep_mask, [1])
        pred_label = paddle.where(keep_mask, pred_label,
                                  paddle.ones_like(pred_label) * -1)
        pred_result = paddle.concat([pred_label, pred_score, pred_bbox], axis=1)
        return bboxes, pred_result, bbox_num

    def get_origin_shape(self, ):
        """Return the per-box original image shapes stored by ``get_pred``."""
        return self.origin_shape_list
  157. @register
  158. class MaskPostProcess(object):
  159. __shared__ = ['export_onnx', 'assign_on_cpu']
  160. """
  161. refer to:
  162. https://github.com/facebookresearch/detectron2/layers/mask_ops.py
  163. Get Mask output according to the output from model
  164. """
  165. def __init__(self,
  166. binary_thresh=0.5,
  167. export_onnx=False,
  168. assign_on_cpu=False):
  169. super(MaskPostProcess, self).__init__()
  170. self.binary_thresh = binary_thresh
  171. self.export_onnx = export_onnx
  172. self.assign_on_cpu = assign_on_cpu
  173. def paste_mask(self, masks, boxes, im_h, im_w):
  174. """
  175. Paste the mask prediction to the original image.
  176. """
  177. x0_int, y0_int = 0, 0
  178. x1_int, y1_int = im_w, im_h
  179. x0, y0, x1, y1 = paddle.split(boxes, 4, axis=1)
  180. N = masks.shape[0]
  181. img_y = paddle.arange(y0_int, y1_int) + 0.5
  182. img_x = paddle.arange(x0_int, x1_int) + 0.5
  183. img_y = (img_y - y0) / (y1 - y0) * 2 - 1
  184. img_x = (img_x - x0) / (x1 - x0) * 2 - 1
  185. # img_x, img_y have shapes (N, w), (N, h)
  186. if self.assign_on_cpu:
  187. paddle.set_device('cpu')
  188. gx = img_x[:, None, :].expand(
  189. [N, paddle.shape(img_y)[1], paddle.shape(img_x)[1]])
  190. gy = img_y[:, :, None].expand(
  191. [N, paddle.shape(img_y)[1], paddle.shape(img_x)[1]])
  192. grid = paddle.stack([gx, gy], axis=3)
  193. img_masks = F.grid_sample(masks, grid, align_corners=False)
  194. return img_masks[:, 0]
  195. def __call__(self, mask_out, bboxes, bbox_num, origin_shape):
  196. """
  197. Decode the mask_out and paste the mask to the origin image.
  198. Args:
  199. mask_out (Tensor): mask_head output with shape [N, 28, 28].
  200. bbox_pred (Tensor): The output bboxes with shape [N, 6] after decode
  201. and NMS, including labels, scores and bboxes.
  202. bbox_num (Tensor): The number of prediction boxes of each batch with
  203. shape [1], and is N.
  204. origin_shape (Tensor): The origin shape of the input image, the tensor
  205. shape is [N, 2], and each row is [h, w].
  206. Returns:
  207. pred_result (Tensor): The final prediction mask results with shape
  208. [N, h, w] in binary mask style.
  209. """
  210. num_mask = mask_out.shape[0]
  211. origin_shape = paddle.cast(origin_shape, 'int32')
  212. device = paddle.device.get_device()
  213. if self.export_onnx:
  214. h, w = origin_shape[0][0], origin_shape[0][1]
  215. mask_onnx = self.paste_mask(mask_out[:, None, :, :], bboxes[:, 2:],
  216. h, w)
  217. mask_onnx = mask_onnx >= self.binary_thresh
  218. pred_result = paddle.cast(mask_onnx, 'int32')
  219. else:
  220. max_h = paddle.max(origin_shape[:, 0])
  221. max_w = paddle.max(origin_shape[:, 1])
  222. pred_result = paddle.zeros(
  223. [num_mask, max_h, max_w], dtype='int32') - 1
  224. id_start = 0
  225. for i in range(paddle.shape(bbox_num)[0]):
  226. bboxes_i = bboxes[id_start:id_start + bbox_num[i], :]
  227. mask_out_i = mask_out[id_start:id_start + bbox_num[i], :, :]
  228. im_h = origin_shape[i, 0]
  229. im_w = origin_shape[i, 1]
  230. bbox_num_i = bbox_num[id_start]
  231. pred_mask = self.paste_mask(mask_out_i[:, None, :, :],
  232. bboxes_i[:, 2:], im_h, im_w)
  233. pred_mask = paddle.cast(pred_mask >= self.binary_thresh,
  234. 'int32')
  235. pred_result[id_start:id_start + bbox_num[i], :im_h, :
  236. im_w] = pred_mask
  237. id_start += bbox_num[i]
  238. if self.assign_on_cpu:
  239. paddle.set_device(device)
  240. return pred_result
  241. @register
  242. class FCOSPostProcess(object):
  243. __inject__ = ['decode', 'nms']
  244. def __init__(self, decode=None, nms=None):
  245. super(FCOSPostProcess, self).__init__()
  246. self.decode = decode
  247. self.nms = nms
  248. def __call__(self, fcos_head_outs, scale_factor):
  249. """
  250. Decode the bbox and do NMS in FCOS.
  251. """
  252. locations, cls_logits, bboxes_reg, centerness = fcos_head_outs
  253. bboxes, score = self.decode(locations, cls_logits, bboxes_reg,
  254. centerness, scale_factor)
  255. bbox_pred, bbox_num, _ = self.nms(bboxes, score)
  256. return bbox_pred, bbox_num
  257. @register
  258. class S2ANetBBoxPostProcess(nn.Layer):
  259. __shared__ = ['num_classes']
  260. __inject__ = ['nms']
  261. def __init__(self, num_classes=15, nms_pre=2000, min_bbox_size=0, nms=None):
  262. super(S2ANetBBoxPostProcess, self).__init__()
  263. self.num_classes = num_classes
  264. self.nms_pre = nms_pre
  265. self.min_bbox_size = min_bbox_size
  266. self.nms = nms
  267. self.origin_shape_list = []
  268. self.fake_pred_cls_score_bbox = paddle.to_tensor(
  269. np.array(
  270. [[-1, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0]],
  271. dtype='float32'))
  272. self.fake_bbox_num = paddle.to_tensor(np.array([1], dtype='int32'))
  273. def forward(self, pred_scores, pred_bboxes):
  274. """
  275. pred_scores : [N, M] score
  276. pred_bboxes : [N, 5] xc, yc, w, h, a
  277. im_shape : [N, 2] im_shape
  278. scale_factor : [N, 2] scale_factor
  279. """
  280. pred_ploys0 = rbox2poly(pred_bboxes)
  281. pred_ploys = paddle.unsqueeze(pred_ploys0, axis=0)
  282. # pred_scores [NA, 16] --> [16, NA]
  283. pred_scores0 = paddle.transpose(pred_scores, [1, 0])
  284. pred_scores = paddle.unsqueeze(pred_scores0, axis=0)
  285. pred_cls_score_bbox, bbox_num, _ = self.nms(pred_ploys, pred_scores,
  286. self.num_classes)
  287. # Prevent empty bbox_pred from decode or NMS.
  288. # Bboxes and score before NMS may be empty due to the score threshold.
  289. if pred_cls_score_bbox.shape[0] <= 0 or pred_cls_score_bbox.shape[
  290. 1] <= 1:
  291. pred_cls_score_bbox = self.fake_pred_cls_score_bbox
  292. bbox_num = self.fake_bbox_num
  293. pred_cls_score_bbox = paddle.reshape(pred_cls_score_bbox, [-1, 10])
  294. return pred_cls_score_bbox, bbox_num
  295. def get_pred(self, bboxes, bbox_num, im_shape, scale_factor):
  296. """
  297. Rescale, clip and filter the bbox from the output of NMS to
  298. get final prediction.
  299. Args:
  300. bboxes(Tensor): bboxes [N, 10]
  301. bbox_num(Tensor): bbox_num
  302. im_shape(Tensor): [1 2]
  303. scale_factor(Tensor): [1 2]
  304. Returns:
  305. bbox_pred(Tensor): The output is the prediction with shape [N, 8]
  306. including labels, scores and bboxes. The size of
  307. bboxes are corresponding to the original image.
  308. """
  309. origin_shape = paddle.floor(im_shape / scale_factor + 0.5)
  310. origin_shape_list = []
  311. scale_factor_list = []
  312. # scale_factor: scale_y, scale_x
  313. for i in range(bbox_num.shape[0]):
  314. expand_shape = paddle.expand(origin_shape[i:i + 1, :],
  315. [bbox_num[i], 2])
  316. scale_y, scale_x = scale_factor[i][0], scale_factor[i][1]
  317. scale = paddle.concat([
  318. scale_x, scale_y, scale_x, scale_y, scale_x, scale_y, scale_x,
  319. scale_y
  320. ])
  321. expand_scale = paddle.expand(scale, [bbox_num[i], 8])
  322. origin_shape_list.append(expand_shape)
  323. scale_factor_list.append(expand_scale)
  324. origin_shape_list = paddle.concat(origin_shape_list)
  325. scale_factor_list = paddle.concat(scale_factor_list)
  326. # bboxes: [N, 10], label, score, bbox
  327. pred_label_score = bboxes[:, 0:2]
  328. pred_bbox = bboxes[:, 2:]
  329. # rescale bbox to original image
  330. pred_bbox = pred_bbox.reshape([-1, 8])
  331. scaled_bbox = pred_bbox / scale_factor_list
  332. origin_h = origin_shape_list[:, 0]
  333. origin_w = origin_shape_list[:, 1]
  334. bboxes = scaled_bbox
  335. zeros = paddle.zeros_like(origin_h)
  336. x1 = paddle.maximum(paddle.minimum(bboxes[:, 0], origin_w - 1), zeros)
  337. y1 = paddle.maximum(paddle.minimum(bboxes[:, 1], origin_h - 1), zeros)
  338. x2 = paddle.maximum(paddle.minimum(bboxes[:, 2], origin_w - 1), zeros)
  339. y2 = paddle.maximum(paddle.minimum(bboxes[:, 3], origin_h - 1), zeros)
  340. x3 = paddle.maximum(paddle.minimum(bboxes[:, 4], origin_w - 1), zeros)
  341. y3 = paddle.maximum(paddle.minimum(bboxes[:, 5], origin_h - 1), zeros)
  342. x4 = paddle.maximum(paddle.minimum(bboxes[:, 6], origin_w - 1), zeros)
  343. y4 = paddle.maximum(paddle.minimum(bboxes[:, 7], origin_h - 1), zeros)
  344. pred_bbox = paddle.stack([x1, y1, x2, y2, x3, y3, x4, y4], axis=-1)
  345. pred_result = paddle.concat([pred_label_score, pred_bbox], axis=1)
  346. return pred_result
  347. @register
  348. class JDEBBoxPostProcess(nn.Layer):
  349. __shared__ = ['num_classes']
  350. __inject__ = ['decode', 'nms']
  351. def __init__(self, num_classes=1, decode=None, nms=None, return_idx=True):
  352. super(JDEBBoxPostProcess, self).__init__()
  353. self.num_classes = num_classes
  354. self.decode = decode
  355. self.nms = nms
  356. self.return_idx = return_idx
  357. self.fake_bbox_pred = paddle.to_tensor(
  358. np.array(
  359. [[-1, 0.0, 0.0, 0.0, 0.0, 0.0]], dtype='float32'))
  360. self.fake_bbox_num = paddle.to_tensor(np.array([1], dtype='int32'))
  361. self.fake_nms_keep_idx = paddle.to_tensor(
  362. np.array(
  363. [[0]], dtype='int32'))
  364. self.fake_yolo_boxes_out = paddle.to_tensor(
  365. np.array(
  366. [[[0.0, 0.0, 0.0, 0.0]]], dtype='float32'))
  367. self.fake_yolo_scores_out = paddle.to_tensor(
  368. np.array(
  369. [[[0.0]]], dtype='float32'))
  370. self.fake_boxes_idx = paddle.to_tensor(np.array([[0]], dtype='int64'))
  371. def forward(self, head_out, anchors):
  372. """
  373. Decode the bbox and do NMS for JDE model.
  374. Args:
  375. head_out (list): Bbox_pred and cls_prob of bbox_head output.
  376. anchors (list): Anchors of JDE model.
  377. Returns:
  378. boxes_idx (Tensor): The index of kept bboxes after decode 'JDEBox'.
  379. bbox_pred (Tensor): The output is the prediction with shape [N, 6]
  380. including labels, scores and bboxes.
  381. bbox_num (Tensor): The number of prediction of each batch with shape [N].
  382. nms_keep_idx (Tensor): The index of kept bboxes after NMS.
  383. """
  384. boxes_idx, yolo_boxes_scores = self.decode(head_out, anchors)
  385. if len(boxes_idx) == 0:
  386. boxes_idx = self.fake_boxes_idx
  387. yolo_boxes_out = self.fake_yolo_boxes_out
  388. yolo_scores_out = self.fake_yolo_scores_out
  389. else:
  390. yolo_boxes = paddle.gather_nd(yolo_boxes_scores, boxes_idx)
  391. # TODO: only support bs=1 now
  392. yolo_boxes_out = paddle.reshape(
  393. yolo_boxes[:, :4], shape=[1, len(boxes_idx), 4])
  394. yolo_scores_out = paddle.reshape(
  395. yolo_boxes[:, 4:5], shape=[1, 1, len(boxes_idx)])
  396. boxes_idx = boxes_idx[:, 1:]
  397. if self.return_idx:
  398. bbox_pred, bbox_num, nms_keep_idx = self.nms(
  399. yolo_boxes_out, yolo_scores_out, self.num_classes)
  400. if bbox_pred.shape[0] == 0:
  401. bbox_pred = self.fake_bbox_pred
  402. bbox_num = self.fake_bbox_num
  403. nms_keep_idx = self.fake_nms_keep_idx
  404. return boxes_idx, bbox_pred, bbox_num, nms_keep_idx
  405. else:
  406. bbox_pred, bbox_num, _ = self.nms(yolo_boxes_out, yolo_scores_out,
  407. self.num_classes)
  408. if bbox_pred.shape[0] == 0:
  409. bbox_pred = self.fake_bbox_pred
  410. bbox_num = self.fake_bbox_num
  411. return _, bbox_pred, bbox_num, _
  412. @register
  413. class CenterNetPostProcess(TTFBox):
  414. """
  415. Postprocess the model outputs to get final prediction:
  416. 1. Do NMS for heatmap to get top `max_per_img` bboxes.
  417. 2. Decode bboxes using center offset and box size.
  418. 3. Rescale decoded bboxes reference to the origin image shape.
  419. Args:
  420. max_per_img(int): the maximum number of predicted objects in a image,
  421. 500 by default.
  422. down_ratio(int): the down ratio from images to heatmap, 4 by default.
  423. regress_ltrb (bool): whether to regress left/top/right/bottom or
  424. width/height for a box, true by default.
  425. for_mot (bool): whether return other features used in tracking model.
  426. """
  427. __shared__ = ['down_ratio', 'for_mot']
  428. def __init__(self,
  429. max_per_img=500,
  430. down_ratio=4,
  431. regress_ltrb=True,
  432. for_mot=False):
  433. super(TTFBox, self).__init__()
  434. self.max_per_img = max_per_img
  435. self.down_ratio = down_ratio
  436. self.regress_ltrb = regress_ltrb
  437. self.for_mot = for_mot
  438. def __call__(self, hm, wh, reg, im_shape, scale_factor):
  439. heat = self._simple_nms(hm)
  440. scores, inds, topk_clses, ys, xs = self._topk(heat)
  441. scores = scores.unsqueeze(1)
  442. clses = topk_clses.unsqueeze(1)
  443. reg_t = paddle.transpose(reg, [0, 2, 3, 1])
  444. # Like TTFBox, batch size is 1.
  445. # TODO: support batch size > 1
  446. reg = paddle.reshape(reg_t, [-1, reg_t.shape[-1]])
  447. reg = paddle.gather(reg, inds)
  448. xs = paddle.cast(xs, 'float32')
  449. ys = paddle.cast(ys, 'float32')
  450. xs = xs + reg[:, 0:1]
  451. ys = ys + reg[:, 1:2]
  452. wh_t = paddle.transpose(wh, [0, 2, 3, 1])
  453. wh = paddle.reshape(wh_t, [-1, wh_t.shape[-1]])
  454. wh = paddle.gather(wh, inds)
  455. if self.regress_ltrb:
  456. x1 = xs - wh[:, 0:1]
  457. y1 = ys - wh[:, 1:2]
  458. x2 = xs + wh[:, 2:3]
  459. y2 = ys + wh[:, 3:4]
  460. else:
  461. x1 = xs - wh[:, 0:1] / 2
  462. y1 = ys - wh[:, 1:2] / 2
  463. x2 = xs + wh[:, 0:1] / 2
  464. y2 = ys + wh[:, 1:2] / 2
  465. n, c, feat_h, feat_w = paddle.shape(hm)
  466. padw = (feat_w * self.down_ratio - im_shape[0, 1]) / 2
  467. padh = (feat_h * self.down_ratio - im_shape[0, 0]) / 2
  468. x1 = x1 * self.down_ratio
  469. y1 = y1 * self.down_ratio
  470. x2 = x2 * self.down_ratio
  471. y2 = y2 * self.down_ratio
  472. x1 = x1 - padw
  473. y1 = y1 - padh
  474. x2 = x2 - padw
  475. y2 = y2 - padh
  476. bboxes = paddle.concat([x1, y1, x2, y2], axis=1)
  477. scale_y = scale_factor[:, 0:1]
  478. scale_x = scale_factor[:, 1:2]
  479. scale_expand = paddle.concat(
  480. [scale_x, scale_y, scale_x, scale_y], axis=1)
  481. boxes_shape = bboxes.shape[:]
  482. scale_expand = paddle.expand(scale_expand, shape=boxes_shape)
  483. bboxes = paddle.divide(bboxes, scale_expand)
  484. results = paddle.concat([clses, scores, bboxes], axis=1)
  485. if self.for_mot:
  486. return results, inds, topk_clses
  487. else:
  488. return results, paddle.shape(results)[0:1], topk_clses
  489. @register
  490. class DETRBBoxPostProcess(object):
  491. __shared__ = ['num_classes', 'use_focal_loss']
  492. __inject__ = []
  493. def __init__(self,
  494. num_classes=80,
  495. num_top_queries=100,
  496. use_focal_loss=False):
  497. super(DETRBBoxPostProcess, self).__init__()
  498. self.num_classes = num_classes
  499. self.num_top_queries = num_top_queries
  500. self.use_focal_loss = use_focal_loss
  501. def __call__(self, head_out, im_shape, scale_factor):
  502. """
  503. Decode the bbox.
  504. Args:
  505. head_out (tuple): bbox_pred, cls_logit and masks of bbox_head output.
  506. im_shape (Tensor): The shape of the input image.
  507. scale_factor (Tensor): The scale factor of the input image.
  508. Returns:
  509. bbox_pred (Tensor): The output prediction with shape [N, 6], including
  510. labels, scores and bboxes. The size of bboxes are corresponding
  511. to the input image, the bboxes may be used in other branch.
  512. bbox_num (Tensor): The number of prediction boxes of each batch with
  513. shape [bs], and is N.
  514. """
  515. bboxes, logits, masks = head_out
  516. bbox_pred = bbox_cxcywh_to_xyxy(bboxes)
  517. origin_shape = paddle.floor(im_shape / scale_factor + 0.5)
  518. img_h, img_w = origin_shape.unbind(1)
  519. origin_shape = paddle.stack(
  520. [img_w, img_h, img_w, img_h], axis=-1).unsqueeze(0)
  521. bbox_pred *= origin_shape
  522. scores = F.sigmoid(logits) if self.use_focal_loss else F.softmax(
  523. logits)[:, :, :-1]
  524. if not self.use_focal_loss:
  525. scores, labels = scores.max(-1), scores.argmax(-1)
  526. if scores.shape[1] > self.num_top_queries:
  527. scores, index = paddle.topk(
  528. scores, self.num_top_queries, axis=-1)
  529. labels = paddle.stack(
  530. [paddle.gather(l, i) for l, i in zip(labels, index)])
  531. bbox_pred = paddle.stack(
  532. [paddle.gather(b, i) for b, i in zip(bbox_pred, index)])
  533. else:
  534. scores, index = paddle.topk(
  535. scores.reshape([logits.shape[0], -1]),
  536. self.num_top_queries,
  537. axis=-1)
  538. labels = index % logits.shape[2]
  539. index = index // logits.shape[2]
  540. bbox_pred = paddle.stack(
  541. [paddle.gather(b, i) for b, i in zip(bbox_pred, index)])
  542. bbox_pred = paddle.concat(
  543. [
  544. labels.unsqueeze(-1).astype('float32'), scores.unsqueeze(-1),
  545. bbox_pred
  546. ],
  547. axis=-1)
  548. bbox_num = paddle.to_tensor(
  549. bbox_pred.shape[1], dtype='int32').tile([bbox_pred.shape[0]])
  550. bbox_pred = bbox_pred.reshape([-1, 6])
  551. return bbox_pred, bbox_num
  552. @register
  553. class SparsePostProcess(object):
  554. __shared__ = ['num_classes']
  555. def __init__(self, num_proposals, num_classes=80):
  556. super(SparsePostProcess, self).__init__()
  557. self.num_classes = num_classes
  558. self.num_proposals = num_proposals
  559. def __call__(self, box_cls, box_pred, scale_factor_wh, img_whwh):
  560. """
  561. Arguments:
  562. box_cls (Tensor): tensor of shape (batch_size, num_proposals, K).
  563. The tensor predicts the classification probability for each proposal.
  564. box_pred (Tensor): tensors of shape (batch_size, num_proposals, 4).
  565. The tensor predicts 4-vector (x,y,w,h) box
  566. regression values for every proposal
  567. scale_factor_wh (Tensor): tensors of shape [batch_size, 2] the scalor of per img
  568. img_whwh (Tensor): tensors of shape [batch_size, 4]
  569. Returns:
  570. bbox_pred (Tensor): tensors of shape [num_boxes, 6] Each row has 6 values:
  571. [label, confidence, xmin, ymin, xmax, ymax]
  572. bbox_num (Tensor): tensors of shape [batch_size] the number of RoIs in each image.
  573. """
  574. assert len(box_cls) == len(scale_factor_wh) == len(img_whwh)
  575. img_wh = img_whwh[:, :2]
  576. scores = F.sigmoid(box_cls)
  577. labels = paddle.arange(0, self.num_classes). \
  578. unsqueeze(0).tile([self.num_proposals, 1]).flatten(start_axis=0, stop_axis=1)
  579. classes_all = []
  580. scores_all = []
  581. boxes_all = []
  582. for i, (scores_per_image,
  583. box_pred_per_image) in enumerate(zip(scores, box_pred)):
  584. scores_per_image, topk_indices = scores_per_image.flatten(
  585. 0, 1).topk(
  586. self.num_proposals, sorted=False)
  587. labels_per_image = paddle.gather(labels, topk_indices, axis=0)
  588. box_pred_per_image = box_pred_per_image.reshape([-1, 1, 4]).tile(
  589. [1, self.num_classes, 1]).reshape([-1, 4])
  590. box_pred_per_image = paddle.gather(
  591. box_pred_per_image, topk_indices, axis=0)
  592. classes_all.append(labels_per_image)
  593. scores_all.append(scores_per_image)
  594. boxes_all.append(box_pred_per_image)
  595. bbox_num = paddle.zeros([len(scale_factor_wh)], dtype="int32")
  596. boxes_final = []
  597. for i in range(len(scale_factor_wh)):
  598. classes = classes_all[i]
  599. boxes = boxes_all[i]
  600. scores = scores_all[i]
  601. boxes[:, 0::2] = paddle.clip(
  602. boxes[:, 0::2], min=0, max=img_wh[i][0]) / scale_factor_wh[i][0]
  603. boxes[:, 1::2] = paddle.clip(
  604. boxes[:, 1::2], min=0, max=img_wh[i][1]) / scale_factor_wh[i][1]
  605. boxes_w, boxes_h = (boxes[:, 2] - boxes[:, 0]).numpy(), (
  606. boxes[:, 3] - boxes[:, 1]).numpy()
  607. keep = (boxes_w > 1.) & (boxes_h > 1.)
  608. if (keep.sum() == 0):
  609. bboxes = paddle.zeros([1, 6]).astype("float32")
  610. else:
  611. boxes = paddle.to_tensor(boxes.numpy()[keep]).astype("float32")
  612. classes = paddle.to_tensor(classes.numpy()[keep]).astype(
  613. "float32").unsqueeze(-1)
  614. scores = paddle.to_tensor(scores.numpy()[keep]).astype(
  615. "float32").unsqueeze(-1)
  616. bboxes = paddle.concat([classes, scores, boxes], axis=-1)
  617. boxes_final.append(bboxes)
  618. bbox_num[i] = bboxes.shape[0]
  619. bbox_pred = paddle.concat(boxes_final)
  620. return bbox_pred, bbox_num
  621. def nms(dets, thresh):
  622. """Apply classic DPM-style greedy NMS."""
  623. if dets.shape[0] == 0:
  624. return dets[[], :]
  625. scores = dets[:, 0]
  626. x1 = dets[:, 1]
  627. y1 = dets[:, 2]
  628. x2 = dets[:, 3]
  629. y2 = dets[:, 4]
  630. areas = (x2 - x1 + 1) * (y2 - y1 + 1)
  631. order = scores.argsort()[::-1]
  632. ndets = dets.shape[0]
  633. suppressed = np.zeros((ndets), dtype=np.int)
  634. # nominal indices
  635. # _i, _j
  636. # sorted indices
  637. # i, j
  638. # temp variables for box i's (the box currently under consideration)
  639. # ix1, iy1, ix2, iy2, iarea
  640. # variables for computing overlap with box j (lower scoring box)
  641. # xx1, yy1, xx2, yy2
  642. # w, h
  643. # inter, ovr
  644. for _i in range(ndets):
  645. i = order[_i]
  646. if suppressed[i] == 1:
  647. continue
  648. ix1 = x1[i]
  649. iy1 = y1[i]
  650. ix2 = x2[i]
  651. iy2 = y2[i]
  652. iarea = areas[i]
  653. for _j in range(_i + 1, ndets):
  654. j = order[_j]
  655. if suppressed[j] == 1:
  656. continue
  657. xx1 = max(ix1, x1[j])
  658. yy1 = max(iy1, y1[j])
  659. xx2 = min(ix2, x2[j])
  660. yy2 = min(iy2, y2[j])
  661. w = max(0.0, xx2 - xx1 + 1)
  662. h = max(0.0, yy2 - yy1 + 1)
  663. inter = w * h
  664. ovr = inter / (iarea + areas[j] - inter)
  665. if ovr >= thresh:
  666. suppressed[j] = 1
  667. keep = np.where(suppressed == 0)[0]
  668. dets = dets[keep, :]
  669. return dets