target.py 27 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606607608609610611612613614615616617618619620621622623624625626627628629630631632633634635636637638639640641642643644645646647648649650651652653654655656657658659660661662663664665666667668669670671672673674675676677
  1. # Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. import numpy as np
  15. import paddle
  16. from ..bbox_utils import bbox2delta, bbox_overlaps
  17. def rpn_anchor_target(anchors,
  18. gt_boxes,
  19. rpn_batch_size_per_im,
  20. rpn_positive_overlap,
  21. rpn_negative_overlap,
  22. rpn_fg_fraction,
  23. use_random=True,
  24. batch_size=1,
  25. ignore_thresh=-1,
  26. is_crowd=None,
  27. weights=[1., 1., 1., 1.],
  28. assign_on_cpu=False):
  29. tgt_labels = []
  30. tgt_bboxes = []
  31. tgt_deltas = []
  32. for i in range(batch_size):
  33. gt_bbox = gt_boxes[i]
  34. is_crowd_i = is_crowd[i] if is_crowd else None
  35. # Step1: match anchor and gt_bbox
  36. matches, match_labels = label_box(
  37. anchors, gt_bbox, rpn_positive_overlap, rpn_negative_overlap, True,
  38. ignore_thresh, is_crowd_i, assign_on_cpu)
  39. # Step2: sample anchor
  40. fg_inds, bg_inds = subsample_labels(match_labels, rpn_batch_size_per_im,
  41. rpn_fg_fraction, 0, use_random)
  42. # Fill with the ignore label (-1), then set positive and negative labels
  43. labels = paddle.full(match_labels.shape, -1, dtype='int32')
  44. if bg_inds.shape[0] > 0:
  45. labels = paddle.scatter(labels, bg_inds, paddle.zeros_like(bg_inds))
  46. if fg_inds.shape[0] > 0:
  47. labels = paddle.scatter(labels, fg_inds, paddle.ones_like(fg_inds))
  48. # Step3: make output
  49. if gt_bbox.shape[0] == 0:
  50. matched_gt_boxes = paddle.zeros([matches.shape[0], 4])
  51. tgt_delta = paddle.zeros([matches.shape[0], 4])
  52. else:
  53. matched_gt_boxes = paddle.gather(gt_bbox, matches)
  54. tgt_delta = bbox2delta(anchors, matched_gt_boxes, weights)
  55. matched_gt_boxes.stop_gradient = True
  56. tgt_delta.stop_gradient = True
  57. labels.stop_gradient = True
  58. tgt_labels.append(labels)
  59. tgt_bboxes.append(matched_gt_boxes)
  60. tgt_deltas.append(tgt_delta)
  61. return tgt_labels, tgt_bboxes, tgt_deltas
  62. def label_box(anchors,
  63. gt_boxes,
  64. positive_overlap,
  65. negative_overlap,
  66. allow_low_quality,
  67. ignore_thresh,
  68. is_crowd=None,
  69. assign_on_cpu=False):
  70. if assign_on_cpu:
  71. device = paddle.device.get_device()
  72. paddle.set_device("cpu")
  73. iou = bbox_overlaps(gt_boxes, anchors)
  74. paddle.set_device(device)
  75. else:
  76. iou = bbox_overlaps(gt_boxes, anchors)
  77. n_gt = gt_boxes.shape[0]
  78. if n_gt == 0 or is_crowd is None:
  79. n_gt_crowd = 0
  80. else:
  81. n_gt_crowd = paddle.nonzero(is_crowd).shape[0]
  82. if iou.shape[0] == 0 or n_gt_crowd == n_gt:
  83. # No truth, assign everything to background
  84. default_matches = paddle.full((iou.shape[1], ), 0, dtype='int64')
  85. default_match_labels = paddle.full((iou.shape[1], ), 0, dtype='int32')
  86. return default_matches, default_match_labels
  87. # if ignore_thresh > 0, remove anchor if it is closed to
  88. # one of the crowded ground-truth
  89. if n_gt_crowd > 0:
  90. N_a = anchors.shape[0]
  91. ones = paddle.ones([N_a])
  92. mask = is_crowd * ones
  93. if ignore_thresh > 0:
  94. crowd_iou = iou * mask
  95. valid = (paddle.sum((crowd_iou > ignore_thresh).cast('int32'),
  96. axis=0) > 0).cast('float32')
  97. iou = iou * (1 - valid) - valid
  98. # ignore the iou between anchor and crowded ground-truth
  99. iou = iou * (1 - mask) - mask
  100. matched_vals, matches = paddle.topk(iou, k=1, axis=0)
  101. match_labels = paddle.full(matches.shape, -1, dtype='int32')
  102. # set ignored anchor with iou = -1
  103. neg_cond = paddle.logical_and(matched_vals > -1,
  104. matched_vals < negative_overlap)
  105. match_labels = paddle.where(neg_cond,
  106. paddle.zeros_like(match_labels), match_labels)
  107. match_labels = paddle.where(matched_vals >= positive_overlap,
  108. paddle.ones_like(match_labels), match_labels)
  109. if allow_low_quality:
  110. highest_quality_foreach_gt = iou.max(axis=1, keepdim=True)
  111. pred_inds_with_highest_quality = paddle.logical_and(
  112. iou > 0, iou == highest_quality_foreach_gt).cast('int32').sum(
  113. 0, keepdim=True)
  114. match_labels = paddle.where(pred_inds_with_highest_quality > 0,
  115. paddle.ones_like(match_labels),
  116. match_labels)
  117. matches = matches.flatten()
  118. match_labels = match_labels.flatten()
  119. return matches, match_labels
  120. def subsample_labels(labels,
  121. num_samples,
  122. fg_fraction,
  123. bg_label=0,
  124. use_random=True):
  125. positive = paddle.nonzero(
  126. paddle.logical_and(labels != -1, labels != bg_label))
  127. negative = paddle.nonzero(labels == bg_label)
  128. fg_num = int(num_samples * fg_fraction)
  129. fg_num = min(positive.numel(), fg_num)
  130. bg_num = num_samples - fg_num
  131. bg_num = min(negative.numel(), bg_num)
  132. if fg_num == 0 and bg_num == 0:
  133. fg_inds = paddle.zeros([0], dtype='int32')
  134. bg_inds = paddle.zeros([0], dtype='int32')
  135. return fg_inds, bg_inds
  136. # randomly select positive and negative examples
  137. negative = negative.cast('int32').flatten()
  138. bg_perm = paddle.randperm(negative.numel(), dtype='int32')
  139. bg_perm = paddle.slice(bg_perm, axes=[0], starts=[0], ends=[bg_num])
  140. if use_random:
  141. bg_inds = paddle.gather(negative, bg_perm)
  142. else:
  143. bg_inds = paddle.slice(negative, axes=[0], starts=[0], ends=[bg_num])
  144. if fg_num == 0:
  145. fg_inds = paddle.zeros([0], dtype='int32')
  146. return fg_inds, bg_inds
  147. positive = positive.cast('int32').flatten()
  148. fg_perm = paddle.randperm(positive.numel(), dtype='int32')
  149. fg_perm = paddle.slice(fg_perm, axes=[0], starts=[0], ends=[fg_num])
  150. if use_random:
  151. fg_inds = paddle.gather(positive, fg_perm)
  152. else:
  153. fg_inds = paddle.slice(positive, axes=[0], starts=[0], ends=[fg_num])
  154. return fg_inds, bg_inds
  155. def generate_proposal_target(rpn_rois,
  156. gt_classes,
  157. gt_boxes,
  158. batch_size_per_im,
  159. fg_fraction,
  160. fg_thresh,
  161. bg_thresh,
  162. num_classes,
  163. ignore_thresh=-1.,
  164. is_crowd=None,
  165. use_random=True,
  166. is_cascade=False,
  167. cascade_iou=0.5,
  168. assign_on_cpu=False):
  169. rois_with_gt = []
  170. tgt_labels = []
  171. tgt_bboxes = []
  172. tgt_gt_inds = []
  173. new_rois_num = []
  174. # In cascade rcnn, the threshold for foreground and background
  175. # is used from cascade_iou
  176. fg_thresh = cascade_iou if is_cascade else fg_thresh
  177. bg_thresh = cascade_iou if is_cascade else bg_thresh
  178. for i, rpn_roi in enumerate(rpn_rois):
  179. gt_bbox = gt_boxes[i]
  180. is_crowd_i = is_crowd[i] if is_crowd else None
  181. gt_class = paddle.squeeze(gt_classes[i], axis=-1)
  182. # Concat RoIs and gt boxes except cascade rcnn or none gt
  183. if not is_cascade and gt_bbox.shape[0] > 0:
  184. bbox = paddle.concat([rpn_roi, gt_bbox])
  185. else:
  186. bbox = rpn_roi
  187. # Step1: label bbox
  188. matches, match_labels = label_box(bbox, gt_bbox, fg_thresh, bg_thresh,
  189. False, ignore_thresh, is_crowd_i,
  190. assign_on_cpu)
  191. # Step2: sample bbox
  192. sampled_inds, sampled_gt_classes = sample_bbox(
  193. matches, match_labels, gt_class, batch_size_per_im, fg_fraction,
  194. num_classes, use_random, is_cascade)
  195. # Step3: make output
  196. rois_per_image = bbox if is_cascade else paddle.gather(bbox,
  197. sampled_inds)
  198. sampled_gt_ind = matches if is_cascade else paddle.gather(matches,
  199. sampled_inds)
  200. if gt_bbox.shape[0] > 0:
  201. sampled_bbox = paddle.gather(gt_bbox, sampled_gt_ind)
  202. else:
  203. num = rois_per_image.shape[0]
  204. sampled_bbox = paddle.zeros([num, 4], dtype='float32')
  205. rois_per_image.stop_gradient = True
  206. sampled_gt_ind.stop_gradient = True
  207. sampled_bbox.stop_gradient = True
  208. tgt_labels.append(sampled_gt_classes)
  209. tgt_bboxes.append(sampled_bbox)
  210. rois_with_gt.append(rois_per_image)
  211. tgt_gt_inds.append(sampled_gt_ind)
  212. new_rois_num.append(paddle.shape(sampled_inds)[0])
  213. new_rois_num = paddle.concat(new_rois_num)
  214. return rois_with_gt, tgt_labels, tgt_bboxes, tgt_gt_inds, new_rois_num
  215. def sample_bbox(matches,
  216. match_labels,
  217. gt_classes,
  218. batch_size_per_im,
  219. fg_fraction,
  220. num_classes,
  221. use_random=True,
  222. is_cascade=False):
  223. n_gt = gt_classes.shape[0]
  224. if n_gt == 0:
  225. # No truth, assign everything to background
  226. gt_classes = paddle.ones(matches.shape, dtype='int32') * num_classes
  227. #return matches, match_labels + num_classes
  228. else:
  229. gt_classes = paddle.gather(gt_classes, matches)
  230. gt_classes = paddle.where(match_labels == 0,
  231. paddle.ones_like(gt_classes) * num_classes,
  232. gt_classes)
  233. gt_classes = paddle.where(match_labels == -1,
  234. paddle.ones_like(gt_classes) * -1, gt_classes)
  235. if is_cascade:
  236. index = paddle.arange(matches.shape[0])
  237. return index, gt_classes
  238. rois_per_image = int(batch_size_per_im)
  239. fg_inds, bg_inds = subsample_labels(gt_classes, rois_per_image, fg_fraction,
  240. num_classes, use_random)
  241. if fg_inds.shape[0] == 0 and bg_inds.shape[0] == 0:
  242. # fake output labeled with -1 when all boxes are neither
  243. # foreground nor background
  244. sampled_inds = paddle.zeros([1], dtype='int32')
  245. else:
  246. sampled_inds = paddle.concat([fg_inds, bg_inds])
  247. sampled_gt_classes = paddle.gather(gt_classes, sampled_inds)
  248. return sampled_inds, sampled_gt_classes
  249. def polygons_to_mask(polygons, height, width):
  250. """
  251. Convert the polygons to mask format
  252. Args:
  253. polygons (list[ndarray]): each array has shape (Nx2,)
  254. height (int): mask height
  255. width (int): mask width
  256. Returns:
  257. ndarray: a bool mask of shape (height, width)
  258. """
  259. import pycocotools.mask as mask_util
  260. assert len(polygons) > 0, "COCOAPI does not support empty polygons"
  261. rles = mask_util.frPyObjects(polygons, height, width)
  262. rle = mask_util.merge(rles)
  263. return mask_util.decode(rle).astype(np.bool)
  264. def rasterize_polygons_within_box(poly, box, resolution):
  265. w, h = box[2] - box[0], box[3] - box[1]
  266. polygons = [np.asarray(p, dtype=np.float64) for p in poly]
  267. for p in polygons:
  268. p[0::2] = p[0::2] - box[0]
  269. p[1::2] = p[1::2] - box[1]
  270. ratio_h = resolution / max(h, 0.1)
  271. ratio_w = resolution / max(w, 0.1)
  272. if ratio_h == ratio_w:
  273. for p in polygons:
  274. p *= ratio_h
  275. else:
  276. for p in polygons:
  277. p[0::2] *= ratio_w
  278. p[1::2] *= ratio_h
  279. # 3. Rasterize the polygons with coco api
  280. mask = polygons_to_mask(polygons, resolution, resolution)
  281. mask = paddle.to_tensor(mask, dtype='int32')
  282. return mask
  283. def generate_mask_target(gt_segms, rois, labels_int32, sampled_gt_inds,
  284. num_classes, resolution):
  285. mask_rois = []
  286. mask_rois_num = []
  287. tgt_masks = []
  288. tgt_classes = []
  289. mask_index = []
  290. tgt_weights = []
  291. for k in range(len(rois)):
  292. labels_per_im = labels_int32[k]
  293. # select rois labeled with foreground
  294. fg_inds = paddle.nonzero(
  295. paddle.logical_and(labels_per_im != -1, labels_per_im !=
  296. num_classes))
  297. has_fg = True
  298. # generate fake roi if foreground is empty
  299. if fg_inds.numel() == 0:
  300. has_fg = False
  301. fg_inds = paddle.ones([1], dtype='int32')
  302. inds_per_im = sampled_gt_inds[k]
  303. inds_per_im = paddle.gather(inds_per_im, fg_inds)
  304. rois_per_im = rois[k]
  305. fg_rois = paddle.gather(rois_per_im, fg_inds)
  306. # Copy the foreground roi to cpu
  307. # to generate mask target with ground-truth
  308. boxes = fg_rois.numpy()
  309. gt_segms_per_im = gt_segms[k]
  310. new_segm = []
  311. inds_per_im = inds_per_im.numpy()
  312. if len(gt_segms_per_im) > 0:
  313. for i in inds_per_im:
  314. new_segm.append(gt_segms_per_im[i])
  315. fg_inds_new = fg_inds.reshape([-1]).numpy()
  316. results = []
  317. if len(gt_segms_per_im) > 0:
  318. for j in fg_inds_new:
  319. results.append(
  320. rasterize_polygons_within_box(new_segm[j], boxes[j],
  321. resolution))
  322. else:
  323. results.append(paddle.ones([resolution, resolution], dtype='int32'))
  324. fg_classes = paddle.gather(labels_per_im, fg_inds)
  325. weight = paddle.ones([fg_rois.shape[0]], dtype='float32')
  326. if not has_fg:
  327. # now all sampled classes are background
  328. # which will cause error in loss calculation,
  329. # make fake classes with weight of 0.
  330. fg_classes = paddle.zeros([1], dtype='int32')
  331. weight = weight - 1
  332. tgt_mask = paddle.stack(results)
  333. tgt_mask.stop_gradient = True
  334. fg_rois.stop_gradient = True
  335. mask_index.append(fg_inds)
  336. mask_rois.append(fg_rois)
  337. mask_rois_num.append(paddle.shape(fg_rois)[0])
  338. tgt_classes.append(fg_classes)
  339. tgt_masks.append(tgt_mask)
  340. tgt_weights.append(weight)
  341. mask_index = paddle.concat(mask_index)
  342. mask_rois_num = paddle.concat(mask_rois_num)
  343. tgt_classes = paddle.concat(tgt_classes, axis=0)
  344. tgt_masks = paddle.concat(tgt_masks, axis=0)
  345. tgt_weights = paddle.concat(tgt_weights, axis=0)
  346. return mask_rois, mask_rois_num, tgt_classes, tgt_masks, mask_index, tgt_weights
  347. def libra_sample_pos(max_overlaps, max_classes, pos_inds, num_expected):
  348. if len(pos_inds) <= num_expected:
  349. return pos_inds
  350. else:
  351. unique_gt_inds = np.unique(max_classes[pos_inds])
  352. num_gts = len(unique_gt_inds)
  353. num_per_gt = int(round(num_expected / float(num_gts)) + 1)
  354. sampled_inds = []
  355. for i in unique_gt_inds:
  356. inds = np.nonzero(max_classes == i)[0]
  357. before_len = len(inds)
  358. inds = list(set(inds) & set(pos_inds))
  359. after_len = len(inds)
  360. if len(inds) > num_per_gt:
  361. inds = np.random.choice(inds, size=num_per_gt, replace=False)
  362. sampled_inds.extend(list(inds)) # combine as a new sampler
  363. if len(sampled_inds) < num_expected:
  364. num_extra = num_expected - len(sampled_inds)
  365. extra_inds = np.array(list(set(pos_inds) - set(sampled_inds)))
  366. assert len(sampled_inds) + len(extra_inds) == len(pos_inds), \
  367. "sum of sampled_inds({}) and extra_inds({}) length must be equal with pos_inds({})!".format(
  368. len(sampled_inds), len(extra_inds), len(pos_inds))
  369. if len(extra_inds) > num_extra:
  370. extra_inds = np.random.choice(
  371. extra_inds, size=num_extra, replace=False)
  372. sampled_inds.extend(extra_inds.tolist())
  373. elif len(sampled_inds) > num_expected:
  374. sampled_inds = np.random.choice(
  375. sampled_inds, size=num_expected, replace=False)
  376. return paddle.to_tensor(sampled_inds)
  377. def libra_sample_via_interval(max_overlaps, full_set, num_expected, floor_thr,
  378. num_bins, bg_thresh):
  379. max_iou = max_overlaps.max()
  380. iou_interval = (max_iou - floor_thr) / num_bins
  381. per_num_expected = int(num_expected / num_bins)
  382. sampled_inds = []
  383. for i in range(num_bins):
  384. start_iou = floor_thr + i * iou_interval
  385. end_iou = floor_thr + (i + 1) * iou_interval
  386. tmp_set = set(
  387. np.where(
  388. np.logical_and(max_overlaps >= start_iou, max_overlaps <
  389. end_iou))[0])
  390. tmp_inds = list(tmp_set & full_set)
  391. if len(tmp_inds) > per_num_expected:
  392. tmp_sampled_set = np.random.choice(
  393. tmp_inds, size=per_num_expected, replace=False)
  394. else:
  395. tmp_sampled_set = np.array(tmp_inds, dtype=np.int)
  396. sampled_inds.append(tmp_sampled_set)
  397. sampled_inds = np.concatenate(sampled_inds)
  398. if len(sampled_inds) < num_expected:
  399. num_extra = num_expected - len(sampled_inds)
  400. extra_inds = np.array(list(full_set - set(sampled_inds)))
  401. assert len(sampled_inds) + len(extra_inds) == len(full_set), \
  402. "sum of sampled_inds({}) and extra_inds({}) length must be equal with full_set({})!".format(
  403. len(sampled_inds), len(extra_inds), len(full_set))
  404. if len(extra_inds) > num_extra:
  405. extra_inds = np.random.choice(extra_inds, num_extra, replace=False)
  406. sampled_inds = np.concatenate([sampled_inds, extra_inds])
  407. return sampled_inds
  408. def libra_sample_neg(max_overlaps,
  409. max_classes,
  410. neg_inds,
  411. num_expected,
  412. floor_thr=-1,
  413. floor_fraction=0,
  414. num_bins=3,
  415. bg_thresh=0.5):
  416. if len(neg_inds) <= num_expected:
  417. return neg_inds
  418. else:
  419. # balance sampling for negative samples
  420. neg_set = set(neg_inds.tolist())
  421. if floor_thr > 0:
  422. floor_set = set(
  423. np.where(
  424. np.logical_and(max_overlaps >= 0, max_overlaps < floor_thr))
  425. [0])
  426. iou_sampling_set = set(np.where(max_overlaps >= floor_thr)[0])
  427. elif floor_thr == 0:
  428. floor_set = set(np.where(max_overlaps == 0)[0])
  429. iou_sampling_set = set(np.where(max_overlaps > floor_thr)[0])
  430. else:
  431. floor_set = set()
  432. iou_sampling_set = set(np.where(max_overlaps > floor_thr)[0])
  433. floor_thr = 0
  434. floor_neg_inds = list(floor_set & neg_set)
  435. iou_sampling_neg_inds = list(iou_sampling_set & neg_set)
  436. num_expected_iou_sampling = int(num_expected * (1 - floor_fraction))
  437. if len(iou_sampling_neg_inds) > num_expected_iou_sampling:
  438. if num_bins >= 2:
  439. iou_sampled_inds = libra_sample_via_interval(
  440. max_overlaps,
  441. set(iou_sampling_neg_inds), num_expected_iou_sampling,
  442. floor_thr, num_bins, bg_thresh)
  443. else:
  444. iou_sampled_inds = np.random.choice(
  445. iou_sampling_neg_inds,
  446. size=num_expected_iou_sampling,
  447. replace=False)
  448. else:
  449. iou_sampled_inds = np.array(iou_sampling_neg_inds, dtype=np.int)
  450. num_expected_floor = num_expected - len(iou_sampled_inds)
  451. if len(floor_neg_inds) > num_expected_floor:
  452. sampled_floor_inds = np.random.choice(
  453. floor_neg_inds, size=num_expected_floor, replace=False)
  454. else:
  455. sampled_floor_inds = np.array(floor_neg_inds, dtype=np.int)
  456. sampled_inds = np.concatenate((sampled_floor_inds, iou_sampled_inds))
  457. if len(sampled_inds) < num_expected:
  458. num_extra = num_expected - len(sampled_inds)
  459. extra_inds = np.array(list(neg_set - set(sampled_inds)))
  460. if len(extra_inds) > num_extra:
  461. extra_inds = np.random.choice(
  462. extra_inds, size=num_extra, replace=False)
  463. sampled_inds = np.concatenate((sampled_inds, extra_inds))
  464. return paddle.to_tensor(sampled_inds)
  465. def libra_label_box(anchors, gt_boxes, gt_classes, positive_overlap,
  466. negative_overlap, num_classes):
  467. # TODO: use paddle API to speed up
  468. gt_classes = gt_classes.numpy()
  469. gt_overlaps = np.zeros((anchors.shape[0], num_classes))
  470. matches = np.zeros((anchors.shape[0]), dtype=np.int32)
  471. if len(gt_boxes) > 0:
  472. proposal_to_gt_overlaps = bbox_overlaps(anchors, gt_boxes).numpy()
  473. overlaps_argmax = proposal_to_gt_overlaps.argmax(axis=1)
  474. overlaps_max = proposal_to_gt_overlaps.max(axis=1)
  475. # Boxes which with non-zero overlap with gt boxes
  476. overlapped_boxes_ind = np.where(overlaps_max > 0)[0]
  477. overlapped_boxes_gt_classes = gt_classes[overlaps_argmax[
  478. overlapped_boxes_ind]]
  479. for idx in range(len(overlapped_boxes_ind)):
  480. gt_overlaps[overlapped_boxes_ind[idx], overlapped_boxes_gt_classes[
  481. idx]] = overlaps_max[overlapped_boxes_ind[idx]]
  482. matches[overlapped_boxes_ind[idx]] = overlaps_argmax[
  483. overlapped_boxes_ind[idx]]
  484. gt_overlaps = paddle.to_tensor(gt_overlaps)
  485. matches = paddle.to_tensor(matches)
  486. matched_vals = paddle.max(gt_overlaps, axis=1)
  487. match_labels = paddle.full(matches.shape, -1, dtype='int32')
  488. match_labels = paddle.where(matched_vals < negative_overlap,
  489. paddle.zeros_like(match_labels), match_labels)
  490. match_labels = paddle.where(matched_vals >= positive_overlap,
  491. paddle.ones_like(match_labels), match_labels)
  492. return matches, match_labels, matched_vals
  493. def libra_sample_bbox(matches,
  494. match_labels,
  495. matched_vals,
  496. gt_classes,
  497. batch_size_per_im,
  498. num_classes,
  499. fg_fraction,
  500. fg_thresh,
  501. bg_thresh,
  502. num_bins,
  503. use_random=True,
  504. is_cascade_rcnn=False):
  505. rois_per_image = int(batch_size_per_im)
  506. fg_rois_per_im = int(np.round(fg_fraction * rois_per_image))
  507. bg_rois_per_im = rois_per_image - fg_rois_per_im
  508. if is_cascade_rcnn:
  509. fg_inds = paddle.nonzero(matched_vals >= fg_thresh)
  510. bg_inds = paddle.nonzero(matched_vals < bg_thresh)
  511. else:
  512. matched_vals_np = matched_vals.numpy()
  513. match_labels_np = match_labels.numpy()
  514. # sample fg
  515. fg_inds = paddle.nonzero(matched_vals >= fg_thresh).flatten()
  516. fg_nums = int(np.minimum(fg_rois_per_im, fg_inds.shape[0]))
  517. if (fg_inds.shape[0] > fg_nums) and use_random:
  518. fg_inds = libra_sample_pos(matched_vals_np, match_labels_np,
  519. fg_inds.numpy(), fg_rois_per_im)
  520. fg_inds = fg_inds[:fg_nums]
  521. # sample bg
  522. bg_inds = paddle.nonzero(matched_vals < bg_thresh).flatten()
  523. bg_nums = int(np.minimum(rois_per_image - fg_nums, bg_inds.shape[0]))
  524. if (bg_inds.shape[0] > bg_nums) and use_random:
  525. bg_inds = libra_sample_neg(
  526. matched_vals_np,
  527. match_labels_np,
  528. bg_inds.numpy(),
  529. bg_rois_per_im,
  530. num_bins=num_bins,
  531. bg_thresh=bg_thresh)
  532. bg_inds = bg_inds[:bg_nums]
  533. sampled_inds = paddle.concat([fg_inds, bg_inds])
  534. gt_classes = paddle.gather(gt_classes, matches)
  535. gt_classes = paddle.where(match_labels == 0,
  536. paddle.ones_like(gt_classes) * num_classes,
  537. gt_classes)
  538. gt_classes = paddle.where(match_labels == -1,
  539. paddle.ones_like(gt_classes) * -1, gt_classes)
  540. sampled_gt_classes = paddle.gather(gt_classes, sampled_inds)
  541. return sampled_inds, sampled_gt_classes
  542. def libra_generate_proposal_target(rpn_rois,
  543. gt_classes,
  544. gt_boxes,
  545. batch_size_per_im,
  546. fg_fraction,
  547. fg_thresh,
  548. bg_thresh,
  549. num_classes,
  550. use_random=True,
  551. is_cascade_rcnn=False,
  552. max_overlaps=None,
  553. num_bins=3):
  554. rois_with_gt = []
  555. tgt_labels = []
  556. tgt_bboxes = []
  557. sampled_max_overlaps = []
  558. tgt_gt_inds = []
  559. new_rois_num = []
  560. for i, rpn_roi in enumerate(rpn_rois):
  561. max_overlap = max_overlaps[i] if is_cascade_rcnn else None
  562. gt_bbox = gt_boxes[i]
  563. gt_class = paddle.squeeze(gt_classes[i], axis=-1)
  564. if is_cascade_rcnn:
  565. rpn_roi = filter_roi(rpn_roi, max_overlap)
  566. bbox = paddle.concat([rpn_roi, gt_bbox])
  567. # Step1: label bbox
  568. matches, match_labels, matched_vals = libra_label_box(
  569. bbox, gt_bbox, gt_class, fg_thresh, bg_thresh, num_classes)
  570. # Step2: sample bbox
  571. sampled_inds, sampled_gt_classes = libra_sample_bbox(
  572. matches, match_labels, matched_vals, gt_class, batch_size_per_im,
  573. num_classes, fg_fraction, fg_thresh, bg_thresh, num_bins,
  574. use_random, is_cascade_rcnn)
  575. # Step3: make output
  576. rois_per_image = paddle.gather(bbox, sampled_inds)
  577. sampled_gt_ind = paddle.gather(matches, sampled_inds)
  578. sampled_bbox = paddle.gather(gt_bbox, sampled_gt_ind)
  579. sampled_overlap = paddle.gather(matched_vals, sampled_inds)
  580. rois_per_image.stop_gradient = True
  581. sampled_gt_ind.stop_gradient = True
  582. sampled_bbox.stop_gradient = True
  583. sampled_overlap.stop_gradient = True
  584. tgt_labels.append(sampled_gt_classes)
  585. tgt_bboxes.append(sampled_bbox)
  586. rois_with_gt.append(rois_per_image)
  587. sampled_max_overlaps.append(sampled_overlap)
  588. tgt_gt_inds.append(sampled_gt_ind)
  589. new_rois_num.append(paddle.shape(sampled_inds)[0])
  590. new_rois_num = paddle.concat(new_rois_num)
  591. # rois_with_gt, tgt_labels, tgt_bboxes, tgt_gt_inds, new_rois_num
  592. return rois_with_gt, tgt_labels, tgt_bboxes, tgt_gt_inds, new_rois_num