post_process.py 11 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335
  1. # Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. from __future__ import absolute_import
  15. from __future__ import division
  16. from __future__ import print_function
  17. import logging
  18. import numpy as np
  19. import cv2
  20. __all__ = ['nms']
  21. logger = logging.getLogger(__name__)
  22. def box_flip(boxes, im_shape):
  23. im_width = im_shape[0][1]
  24. flipped_boxes = boxes.copy()
  25. flipped_boxes[:, 0::4] = im_width - boxes[:, 2::4] - 1
  26. flipped_boxes[:, 2::4] = im_width - boxes[:, 0::4] - 1
  27. return flipped_boxes
  28. def nms(dets, thresh):
  29. """
  30. refer to:
  31. https://github.com/facebookresearch/Detectron/blob/main/detectron/utils/cython_nms.pyx
  32. Apply classic DPM-style greedy NMS.
  33. """
  34. if dets.shape[0] == 0:
  35. return dets[[], :]
  36. scores = dets[:, 0]
  37. x1 = dets[:, 1]
  38. y1 = dets[:, 2]
  39. x2 = dets[:, 3]
  40. y2 = dets[:, 4]
  41. areas = (x2 - x1 + 1) * (y2 - y1 + 1)
  42. order = scores.argsort()[::-1]
  43. ndets = dets.shape[0]
  44. suppressed = np.zeros((ndets), dtype=np.int)
  45. # nominal indices
  46. # _i, _j
  47. # sorted indices
  48. # i, j
  49. # temp variables for box i's (the box currently under consideration)
  50. # ix1, iy1, ix2, iy2, iarea
  51. # variables for computing overlap with box j (lower scoring box)
  52. # xx1, yy1, xx2, yy2
  53. # w, h
  54. # inter, ovr
  55. for _i in range(ndets):
  56. i = order[_i]
  57. if suppressed[i] == 1:
  58. continue
  59. ix1 = x1[i]
  60. iy1 = y1[i]
  61. ix2 = x2[i]
  62. iy2 = y2[i]
  63. iarea = areas[i]
  64. for _j in range(_i + 1, ndets):
  65. j = order[_j]
  66. if suppressed[j] == 1:
  67. continue
  68. xx1 = max(ix1, x1[j])
  69. yy1 = max(iy1, y1[j])
  70. xx2 = min(ix2, x2[j])
  71. yy2 = min(iy2, y2[j])
  72. w = max(0.0, xx2 - xx1 + 1)
  73. h = max(0.0, yy2 - yy1 + 1)
  74. inter = w * h
  75. ovr = inter / (iarea + areas[j] - inter)
  76. if ovr >= thresh:
  77. suppressed[j] = 1
  78. keep = np.where(suppressed == 0)[0]
  79. dets = dets[keep, :]
  80. return dets
  81. def soft_nms(dets, sigma, thres):
  82. """
  83. refer to:
  84. https://github.com/facebookresearch/Detectron/blob/main/detectron/utils/cython_nms.pyx
  85. """
  86. dets_final = []
  87. while len(dets) > 0:
  88. maxpos = np.argmax(dets[:, 0])
  89. dets_final.append(dets[maxpos].copy())
  90. ts, tx1, ty1, tx2, ty2 = dets[maxpos]
  91. scores = dets[:, 0]
  92. # force remove bbox at maxpos
  93. scores[maxpos] = -1
  94. x1 = dets[:, 1]
  95. y1 = dets[:, 2]
  96. x2 = dets[:, 3]
  97. y2 = dets[:, 4]
  98. areas = (x2 - x1 + 1) * (y2 - y1 + 1)
  99. xx1 = np.maximum(tx1, x1)
  100. yy1 = np.maximum(ty1, y1)
  101. xx2 = np.minimum(tx2, x2)
  102. yy2 = np.minimum(ty2, y2)
  103. w = np.maximum(0.0, xx2 - xx1 + 1)
  104. h = np.maximum(0.0, yy2 - yy1 + 1)
  105. inter = w * h
  106. ovr = inter / (areas + areas[maxpos] - inter)
  107. weight = np.exp(-(ovr * ovr) / sigma)
  108. scores = scores * weight
  109. idx_keep = np.where(scores >= thres)
  110. dets[:, 0] = scores
  111. dets = dets[idx_keep]
  112. dets_final = np.array(dets_final).reshape(-1, 5)
  113. return dets_final
  114. def bbox_area(box):
  115. w = box[2] - box[0] + 1
  116. h = box[3] - box[1] + 1
  117. return w * h
  118. def bbox_overlaps(x, y):
  119. N = x.shape[0]
  120. K = y.shape[0]
  121. overlaps = np.zeros((N, K), dtype=np.float32)
  122. for k in range(K):
  123. y_area = bbox_area(y[k])
  124. for n in range(N):
  125. iw = min(x[n, 2], y[k, 2]) - max(x[n, 0], y[k, 0]) + 1
  126. if iw > 0:
  127. ih = min(x[n, 3], y[k, 3]) - max(x[n, 1], y[k, 1]) + 1
  128. if ih > 0:
  129. x_area = bbox_area(x[n])
  130. ua = x_area + y_area - iw * ih
  131. overlaps[n, k] = iw * ih / ua
  132. return overlaps
  133. def box_voting(nms_dets, dets, vote_thresh):
  134. top_dets = nms_dets.copy()
  135. top_boxes = nms_dets[:, 1:]
  136. all_boxes = dets[:, 1:]
  137. all_scores = dets[:, 0]
  138. top_to_all_overlaps = bbox_overlaps(top_boxes, all_boxes)
  139. for k in range(nms_dets.shape[0]):
  140. inds_to_vote = np.where(top_to_all_overlaps[k] >= vote_thresh)[0]
  141. boxes_to_vote = all_boxes[inds_to_vote, :]
  142. ws = all_scores[inds_to_vote]
  143. top_dets[k, 1:] = np.average(boxes_to_vote, axis=0, weights=ws)
  144. return top_dets
  145. def get_nms_result(boxes,
  146. scores,
  147. config,
  148. num_classes,
  149. background_label=0,
  150. labels=None):
  151. has_labels = labels is not None
  152. cls_boxes = [[] for _ in range(num_classes)]
  153. start_idx = 1 if background_label == 0 else 0
  154. for j in range(start_idx, num_classes):
  155. inds = np.where(labels == j)[0] if has_labels else np.where(
  156. scores[:, j] > config['score_thresh'])[0]
  157. scores_j = scores[inds] if has_labels else scores[inds, j]
  158. boxes_j = boxes[inds, :] if has_labels else boxes[inds, j * 4:(j + 1) *
  159. 4]
  160. dets_j = np.hstack((scores_j[:, np.newaxis], boxes_j)).astype(
  161. np.float32, copy=False)
  162. if config.get('use_soft_nms', False):
  163. nms_dets = soft_nms(dets_j, config['sigma'], config['nms_thresh'])
  164. else:
  165. nms_dets = nms(dets_j, config['nms_thresh'])
  166. if config.get('enable_voting', False):
  167. nms_dets = box_voting(nms_dets, dets_j, config['vote_thresh'])
  168. #add labels
  169. label = np.array([j for _ in range(len(nms_dets))])
  170. nms_dets = np.hstack((label[:, np.newaxis], nms_dets)).astype(
  171. np.float32, copy=False)
  172. cls_boxes[j] = nms_dets
  173. # Limit to max_per_image detections **over all classes**
  174. image_scores = np.hstack(
  175. [cls_boxes[j][:, 1] for j in range(start_idx, num_classes)])
  176. if len(image_scores) > config['detections_per_im']:
  177. image_thresh = np.sort(image_scores)[-config['detections_per_im']]
  178. for j in range(start_idx, num_classes):
  179. keep = np.where(cls_boxes[j][:, 1] >= image_thresh)[0]
  180. cls_boxes[j] = cls_boxes[j][keep, :]
  181. im_results = np.vstack(
  182. [cls_boxes[j] for j in range(start_idx, num_classes)])
  183. return im_results
  184. def mstest_box_post_process(result, config, num_classes):
  185. """
  186. Multi-scale Test
  187. Only available for batch_size=1 now.
  188. """
  189. post_bbox = {}
  190. use_flip = False
  191. ms_boxes = []
  192. ms_scores = []
  193. im_shape = result['im_shape'][0]
  194. for k in result.keys():
  195. if 'bbox' in k:
  196. boxes = result[k][0]
  197. boxes = np.reshape(boxes, (-1, 4 * num_classes))
  198. scores = result['score' + k[4:]][0]
  199. if 'flip' in k:
  200. boxes = box_flip(boxes, im_shape)
  201. use_flip = True
  202. ms_boxes.append(boxes)
  203. ms_scores.append(scores)
  204. ms_boxes = np.concatenate(ms_boxes)
  205. ms_scores = np.concatenate(ms_scores)
  206. bbox_pred = get_nms_result(ms_boxes, ms_scores, config, num_classes)
  207. post_bbox.update({'bbox': (bbox_pred, [[len(bbox_pred)]])})
  208. if use_flip:
  209. bbox = bbox_pred[:, 2:]
  210. bbox_flip = np.append(
  211. bbox_pred[:, :2], box_flip(bbox, im_shape), axis=1)
  212. post_bbox.update({'bbox_flip': (bbox_flip, [[len(bbox_flip)]])})
  213. return post_bbox
  214. def mstest_mask_post_process(result, cfg):
  215. mask_list = []
  216. im_shape = result['im_shape'][0]
  217. M = cfg.FPNRoIAlign['mask_resolution']
  218. for k in result.keys():
  219. if 'mask' in k:
  220. masks = result[k][0]
  221. if len(masks.shape) != 4:
  222. masks = np.zeros((0, M, M))
  223. mask_list.append(masks)
  224. continue
  225. if 'flip' in k:
  226. masks = masks[:, :, :, ::-1]
  227. mask_list.append(masks)
  228. mask_pred = np.mean(mask_list, axis=0)
  229. return {'mask': (mask_pred, [[len(mask_pred)]])}
  230. def mask_encode(results, resolution, thresh_binarize=0.5):
  231. import pycocotools.mask as mask_util
  232. from ppdet.utils.coco_eval import expand_boxes
  233. scale = (resolution + 2.0) / resolution
  234. bboxes = results['bbox'][0]
  235. masks = results['mask'][0]
  236. lengths = results['mask'][1][0]
  237. im_shapes = results['im_shape'][0]
  238. segms = []
  239. if bboxes.shape == (1, 1) or bboxes is None:
  240. return segms
  241. if len(bboxes.tolist()) == 0:
  242. return segms
  243. s = 0
  244. # for each sample
  245. for i in range(len(lengths)):
  246. num = lengths[i]
  247. im_shape = im_shapes[i]
  248. bbox = bboxes[s:s + num][:, 2:]
  249. clsid_scores = bboxes[s:s + num][:, 0:2]
  250. mask = masks[s:s + num]
  251. s += num
  252. im_h = int(im_shape[0])
  253. im_w = int(im_shape[1])
  254. expand_bbox = expand_boxes(bbox, scale)
  255. expand_bbox = expand_bbox.astype(np.int32)
  256. padded_mask = np.zeros(
  257. (resolution + 2, resolution + 2), dtype=np.float32)
  258. for j in range(num):
  259. xmin, ymin, xmax, ymax = expand_bbox[j].tolist()
  260. clsid, score = clsid_scores[j].tolist()
  261. clsid = int(clsid)
  262. padded_mask[1:-1, 1:-1] = mask[j, clsid, :, :]
  263. w = xmax - xmin + 1
  264. h = ymax - ymin + 1
  265. w = np.maximum(w, 1)
  266. h = np.maximum(h, 1)
  267. resized_mask = cv2.resize(padded_mask, (w, h))
  268. resized_mask = np.array(
  269. resized_mask > thresh_binarize, dtype=np.uint8)
  270. im_mask = np.zeros((im_h, im_w), dtype=np.uint8)
  271. x0 = min(max(xmin, 0), im_w)
  272. x1 = min(max(xmax + 1, 0), im_w)
  273. y0 = min(max(ymin, 0), im_h)
  274. y1 = min(max(ymax + 1, 0), im_h)
  275. im_mask[y0:y1, x0:x1] = resized_mask[(y0 - ymin):(y1 - ymin), (
  276. x0 - xmin):(x1 - xmin)]
  277. segm = mask_util.encode(
  278. np.array(
  279. im_mask[:, :, np.newaxis], order='F'))[0]
  280. segms.append(segm)
  281. return segms
  282. def corner_post_process(results, config, num_classes):
  283. detections = results['bbox'][0]
  284. keep_inds = (detections[:, 1] > -1)
  285. detections = detections[keep_inds]
  286. labels = detections[:, 0]
  287. scores = detections[:, 1]
  288. boxes = detections[:, 2:6]
  289. cls_boxes = get_nms_result(
  290. boxes, scores, config, num_classes, background_label=-1, labels=labels)
  291. results.update({'bbox': (cls_boxes, [[len(cls_boxes)]])})