bbox_utils.py 29 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606607608609610611612613614615616617618619620621622623624625626627628629630631632633634635636637638639640641642643644645646647648649650651652653654655656657658659660661662663664665666667668669670671672673674675676677678679680681682683684685686687688689690691692693694695696697698699700701702703704705706707708709710711712713714715716717718719720721722723724725726727728729730731732733734735736737738739740741742743744745746747748749750751752753754755756757758759760761762763764765766767768769770771772773774775776777778779780781782783784785786787788789790791792793794795796797798799800801802803804805806807808809810811812813814815816817818819820821822823824825826827828829830831832833834835836837838839840841842843844845846847848849850851852853854855856857858859860861862863864865866867868
  1. # Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. import math
  15. import paddle
  16. import numpy as np
  17. def bbox2delta(src_boxes, tgt_boxes, weights):
  18. src_w = src_boxes[:, 2] - src_boxes[:, 0]
  19. src_h = src_boxes[:, 3] - src_boxes[:, 1]
  20. src_ctr_x = src_boxes[:, 0] + 0.5 * src_w
  21. src_ctr_y = src_boxes[:, 1] + 0.5 * src_h
  22. tgt_w = tgt_boxes[:, 2] - tgt_boxes[:, 0]
  23. tgt_h = tgt_boxes[:, 3] - tgt_boxes[:, 1]
  24. tgt_ctr_x = tgt_boxes[:, 0] + 0.5 * tgt_w
  25. tgt_ctr_y = tgt_boxes[:, 1] + 0.5 * tgt_h
  26. wx, wy, ww, wh = weights
  27. dx = wx * (tgt_ctr_x - src_ctr_x) / src_w
  28. dy = wy * (tgt_ctr_y - src_ctr_y) / src_h
  29. dw = ww * paddle.log(tgt_w / src_w)
  30. dh = wh * paddle.log(tgt_h / src_h)
  31. deltas = paddle.stack((dx, dy, dw, dh), axis=1)
  32. return deltas
  33. def delta2bbox(deltas, boxes, weights):
  34. clip_scale = math.log(1000.0 / 16)
  35. widths = boxes[:, 2] - boxes[:, 0]
  36. heights = boxes[:, 3] - boxes[:, 1]
  37. ctr_x = boxes[:, 0] + 0.5 * widths
  38. ctr_y = boxes[:, 1] + 0.5 * heights
  39. wx, wy, ww, wh = weights
  40. dx = deltas[:, 0::4] / wx
  41. dy = deltas[:, 1::4] / wy
  42. dw = deltas[:, 2::4] / ww
  43. dh = deltas[:, 3::4] / wh
  44. # Prevent sending too large values into paddle.exp()
  45. dw = paddle.clip(dw, max=clip_scale)
  46. dh = paddle.clip(dh, max=clip_scale)
  47. pred_ctr_x = dx * widths.unsqueeze(1) + ctr_x.unsqueeze(1)
  48. pred_ctr_y = dy * heights.unsqueeze(1) + ctr_y.unsqueeze(1)
  49. pred_w = paddle.exp(dw) * widths.unsqueeze(1)
  50. pred_h = paddle.exp(dh) * heights.unsqueeze(1)
  51. pred_boxes = []
  52. pred_boxes.append(pred_ctr_x - 0.5 * pred_w)
  53. pred_boxes.append(pred_ctr_y - 0.5 * pred_h)
  54. pred_boxes.append(pred_ctr_x + 0.5 * pred_w)
  55. pred_boxes.append(pred_ctr_y + 0.5 * pred_h)
  56. pred_boxes = paddle.stack(pred_boxes, axis=-1)
  57. return pred_boxes
  58. def expand_bbox(bboxes, scale):
  59. w_half = (bboxes[:, 2] - bboxes[:, 0]) * .5
  60. h_half = (bboxes[:, 3] - bboxes[:, 1]) * .5
  61. x_c = (bboxes[:, 2] + bboxes[:, 0]) * .5
  62. y_c = (bboxes[:, 3] + bboxes[:, 1]) * .5
  63. w_half *= scale
  64. h_half *= scale
  65. bboxes_exp = np.zeros(bboxes.shape, dtype=np.float32)
  66. bboxes_exp[:, 0] = x_c - w_half
  67. bboxes_exp[:, 2] = x_c + w_half
  68. bboxes_exp[:, 1] = y_c - h_half
  69. bboxes_exp[:, 3] = y_c + h_half
  70. return bboxes_exp
  71. def clip_bbox(boxes, im_shape):
  72. h, w = im_shape[0], im_shape[1]
  73. x1 = boxes[:, 0].clip(0, w)
  74. y1 = boxes[:, 1].clip(0, h)
  75. x2 = boxes[:, 2].clip(0, w)
  76. y2 = boxes[:, 3].clip(0, h)
  77. return paddle.stack([x1, y1, x2, y2], axis=1)
  78. def nonempty_bbox(boxes, min_size=0, return_mask=False):
  79. w = boxes[:, 2] - boxes[:, 0]
  80. h = boxes[:, 3] - boxes[:, 1]
  81. mask = paddle.logical_and(h > min_size, w > min_size)
  82. if return_mask:
  83. return mask
  84. keep = paddle.nonzero(mask).flatten()
  85. return keep
  86. def bbox_area(boxes):
  87. return (boxes[:, 2] - boxes[:, 0]) * (boxes[:, 3] - boxes[:, 1])
  88. def bbox_overlaps(boxes1, boxes2):
  89. """
  90. Calculate overlaps between boxes1 and boxes2
  91. Args:
  92. boxes1 (Tensor): boxes with shape [M, 4]
  93. boxes2 (Tensor): boxes with shape [N, 4]
  94. Return:
  95. overlaps (Tensor): overlaps between boxes1 and boxes2 with shape [M, N]
  96. """
  97. M = boxes1.shape[0]
  98. N = boxes2.shape[0]
  99. if M * N == 0:
  100. return paddle.zeros([M, N], dtype='float32')
  101. area1 = bbox_area(boxes1)
  102. area2 = bbox_area(boxes2)
  103. xy_max = paddle.minimum(
  104. paddle.unsqueeze(boxes1, 1)[:, :, 2:], boxes2[:, 2:])
  105. xy_min = paddle.maximum(
  106. paddle.unsqueeze(boxes1, 1)[:, :, :2], boxes2[:, :2])
  107. width_height = xy_max - xy_min
  108. width_height = width_height.clip(min=0)
  109. inter = width_height.prod(axis=2)
  110. overlaps = paddle.where(inter > 0, inter /
  111. (paddle.unsqueeze(area1, 1) + area2 - inter),
  112. paddle.zeros_like(inter))
  113. return overlaps
  114. def batch_bbox_overlaps(bboxes1,
  115. bboxes2,
  116. mode='iou',
  117. is_aligned=False,
  118. eps=1e-6):
  119. """Calculate overlap between two set of bboxes.
  120. If ``is_aligned `` is ``False``, then calculate the overlaps between each
  121. bbox of bboxes1 and bboxes2, otherwise the overlaps between each aligned
  122. pair of bboxes1 and bboxes2.
  123. Args:
  124. bboxes1 (Tensor): shape (B, m, 4) in <x1, y1, x2, y2> format or empty.
  125. bboxes2 (Tensor): shape (B, n, 4) in <x1, y1, x2, y2> format or empty.
  126. B indicates the batch dim, in shape (B1, B2, ..., Bn).
  127. If ``is_aligned `` is ``True``, then m and n must be equal.
  128. mode (str): "iou" (intersection over union) or "iof" (intersection over
  129. foreground).
  130. is_aligned (bool, optional): If True, then m and n must be equal.
  131. Default False.
  132. eps (float, optional): A value added to the denominator for numerical
  133. stability. Default 1e-6.
  134. Returns:
  135. Tensor: shape (m, n) if ``is_aligned `` is False else shape (m,)
  136. """
  137. assert mode in ['iou', 'iof', 'giou'], 'Unsupported mode {}'.format(mode)
  138. # Either the boxes are empty or the length of boxes's last dimenstion is 4
  139. assert (bboxes1.shape[-1] == 4 or bboxes1.shape[0] == 0)
  140. assert (bboxes2.shape[-1] == 4 or bboxes2.shape[0] == 0)
  141. # Batch dim must be the same
  142. # Batch dim: (B1, B2, ... Bn)
  143. assert bboxes1.shape[:-2] == bboxes2.shape[:-2]
  144. batch_shape = bboxes1.shape[:-2]
  145. rows = bboxes1.shape[-2] if bboxes1.shape[0] > 0 else 0
  146. cols = bboxes2.shape[-2] if bboxes2.shape[0] > 0 else 0
  147. if is_aligned:
  148. assert rows == cols
  149. if rows * cols == 0:
  150. if is_aligned:
  151. return paddle.full(batch_shape + (rows, ), 1)
  152. else:
  153. return paddle.full(batch_shape + (rows, cols), 1)
  154. area1 = (bboxes1[:, 2] - bboxes1[:, 0]) * (bboxes1[:, 3] - bboxes1[:, 1])
  155. area2 = (bboxes2[:, 2] - bboxes2[:, 0]) * (bboxes2[:, 3] - bboxes2[:, 1])
  156. if is_aligned:
  157. lt = paddle.maximum(bboxes1[:, :2], bboxes2[:, :2]) # [B, rows, 2]
  158. rb = paddle.minimum(bboxes1[:, 2:], bboxes2[:, 2:]) # [B, rows, 2]
  159. wh = (rb - lt).clip(min=0) # [B, rows, 2]
  160. overlap = wh[:, 0] * wh[:, 1]
  161. if mode in ['iou', 'giou']:
  162. union = area1 + area2 - overlap
  163. else:
  164. union = area1
  165. if mode == 'giou':
  166. enclosed_lt = paddle.minimum(bboxes1[:, :2], bboxes2[:, :2])
  167. enclosed_rb = paddle.maximum(bboxes1[:, 2:], bboxes2[:, 2:])
  168. else:
  169. lt = paddle.maximum(bboxes1[:, :2].reshape([rows, 1, 2]),
  170. bboxes2[:, :2]) # [B, rows, cols, 2]
  171. rb = paddle.minimum(bboxes1[:, 2:].reshape([rows, 1, 2]),
  172. bboxes2[:, 2:]) # [B, rows, cols, 2]
  173. wh = (rb - lt).clip(min=0) # [B, rows, cols, 2]
  174. overlap = wh[:, :, 0] * wh[:, :, 1]
  175. if mode in ['iou', 'giou']:
  176. union = area1.reshape([rows,1]) \
  177. + area2.reshape([1,cols]) - overlap
  178. else:
  179. union = area1[:, None]
  180. if mode == 'giou':
  181. enclosed_lt = paddle.minimum(bboxes1[:, :2].reshape([rows, 1, 2]),
  182. bboxes2[:, :2])
  183. enclosed_rb = paddle.maximum(bboxes1[:, 2:].reshape([rows, 1, 2]),
  184. bboxes2[:, 2:])
  185. eps = paddle.to_tensor([eps])
  186. union = paddle.maximum(union, eps)
  187. ious = overlap / union
  188. if mode in ['iou', 'iof']:
  189. return ious
  190. # calculate gious
  191. enclose_wh = (enclosed_rb - enclosed_lt).clip(min=0)
  192. enclose_area = enclose_wh[:, :, 0] * enclose_wh[:, :, 1]
  193. enclose_area = paddle.maximum(enclose_area, eps)
  194. gious = ious - (enclose_area - union) / enclose_area
  195. return 1 - gious
  196. def xywh2xyxy(box):
  197. x, y, w, h = box
  198. x1 = x - w * 0.5
  199. y1 = y - h * 0.5
  200. x2 = x + w * 0.5
  201. y2 = y + h * 0.5
  202. return [x1, y1, x2, y2]
  203. def make_grid(h, w, dtype):
  204. yv, xv = paddle.meshgrid([paddle.arange(h), paddle.arange(w)])
  205. return paddle.stack((xv, yv), 2).cast(dtype=dtype)
  206. def decode_yolo(box, anchor, downsample_ratio):
  207. """decode yolo box
  208. Args:
  209. box (list): [x, y, w, h], all have the shape [b, na, h, w, 1]
  210. anchor (list): anchor with the shape [na, 2]
  211. downsample_ratio (int): downsample ratio, default 32
  212. scale (float): scale, default 1.
  213. Return:
  214. box (list): decoded box, [x, y, w, h], all have the shape [b, na, h, w, 1]
  215. """
  216. x, y, w, h = box
  217. na, grid_h, grid_w = x.shape[1:4]
  218. grid = make_grid(grid_h, grid_w, x.dtype).reshape((1, 1, grid_h, grid_w, 2))
  219. x1 = (x + grid[:, :, :, :, 0:1]) / grid_w
  220. y1 = (y + grid[:, :, :, :, 1:2]) / grid_h
  221. anchor = paddle.to_tensor(anchor)
  222. anchor = paddle.cast(anchor, x.dtype)
  223. anchor = anchor.reshape((1, na, 1, 1, 2))
  224. w1 = paddle.exp(w) * anchor[:, :, :, :, 0:1] / (downsample_ratio * grid_w)
  225. h1 = paddle.exp(h) * anchor[:, :, :, :, 1:2] / (downsample_ratio * grid_h)
  226. return [x1, y1, w1, h1]
  227. def iou_similarity(box1, box2, eps=1e-9):
  228. """Calculate iou of box1 and box2
  229. Args:
  230. box1 (Tensor): box with the shape [N, M1, 4]
  231. box2 (Tensor): box with the shape [N, M2, 4]
  232. Return:
  233. iou (Tensor): iou between box1 and box2 with the shape [N, M1, M2]
  234. """
  235. box1 = box1.unsqueeze(2) # [N, M1, 4] -> [N, M1, 1, 4]
  236. box2 = box2.unsqueeze(1) # [N, M2, 4] -> [N, 1, M2, 4]
  237. px1y1, px2y2 = box1[:, :, :, 0:2], box1[:, :, :, 2:4]
  238. gx1y1, gx2y2 = box2[:, :, :, 0:2], box2[:, :, :, 2:4]
  239. x1y1 = paddle.maximum(px1y1, gx1y1)
  240. x2y2 = paddle.minimum(px2y2, gx2y2)
  241. overlap = (x2y2 - x1y1).clip(0).prod(-1)
  242. area1 = (px2y2 - px1y1).clip(0).prod(-1)
  243. area2 = (gx2y2 - gx1y1).clip(0).prod(-1)
  244. union = area1 + area2 - overlap + eps
  245. return overlap / union
  246. def bbox_iou(box1, box2, giou=False, diou=False, ciou=False, eps=1e-9):
  247. """calculate the iou of box1 and box2
  248. Args:
  249. box1 (list): [x, y, w, h], all have the shape [b, na, h, w, 1]
  250. box2 (list): [x, y, w, h], all have the shape [b, na, h, w, 1]
  251. giou (bool): whether use giou or not, default False
  252. diou (bool): whether use diou or not, default False
  253. ciou (bool): whether use ciou or not, default False
  254. eps (float): epsilon to avoid divide by zero
  255. Return:
  256. iou (Tensor): iou of box1 and box1, with the shape [b, na, h, w, 1]
  257. """
  258. px1, py1, px2, py2 = box1
  259. gx1, gy1, gx2, gy2 = box2
  260. x1 = paddle.maximum(px1, gx1)
  261. y1 = paddle.maximum(py1, gy1)
  262. x2 = paddle.minimum(px2, gx2)
  263. y2 = paddle.minimum(py2, gy2)
  264. overlap = ((x2 - x1).clip(0)) * ((y2 - y1).clip(0))
  265. area1 = (px2 - px1) * (py2 - py1)
  266. area1 = area1.clip(0)
  267. area2 = (gx2 - gx1) * (gy2 - gy1)
  268. area2 = area2.clip(0)
  269. union = area1 + area2 - overlap + eps
  270. iou = overlap / union
  271. if giou or ciou or diou:
  272. # convex w, h
  273. cw = paddle.maximum(px2, gx2) - paddle.minimum(px1, gx1)
  274. ch = paddle.maximum(py2, gy2) - paddle.minimum(py1, gy1)
  275. if giou:
  276. c_area = cw * ch + eps
  277. return iou - (c_area - union) / c_area
  278. else:
  279. # convex diagonal squared
  280. c2 = cw**2 + ch**2 + eps
  281. # center distance
  282. rho2 = ((px1 + px2 - gx1 - gx2)**2 + (py1 + py2 - gy1 - gy2)**2) / 4
  283. if diou:
  284. return iou - rho2 / c2
  285. else:
  286. w1, h1 = px2 - px1, py2 - py1 + eps
  287. w2, h2 = gx2 - gx1, gy2 - gy1 + eps
  288. delta = paddle.atan(w1 / h1) - paddle.atan(w2 / h2)
  289. v = (4 / math.pi**2) * paddle.pow(delta, 2)
  290. alpha = v / (1 + eps - iou + v)
  291. alpha.stop_gradient = True
  292. return iou - (rho2 / c2 + v * alpha)
  293. else:
  294. return iou
  295. def rect2rbox(bboxes):
  296. """
  297. :param bboxes: shape (n, 4) (xmin, ymin, xmax, ymax)
  298. :return: dbboxes: shape (n, 5) (x_ctr, y_ctr, w, h, angle)
  299. """
  300. bboxes = bboxes.reshape(-1, 4)
  301. num_boxes = bboxes.shape[0]
  302. x_ctr = (bboxes[:, 2] + bboxes[:, 0]) / 2.0
  303. y_ctr = (bboxes[:, 3] + bboxes[:, 1]) / 2.0
  304. edges1 = np.abs(bboxes[:, 2] - bboxes[:, 0])
  305. edges2 = np.abs(bboxes[:, 3] - bboxes[:, 1])
  306. angles = np.zeros([num_boxes], dtype=bboxes.dtype)
  307. inds = edges1 < edges2
  308. rboxes = np.stack((x_ctr, y_ctr, edges1, edges2, angles), axis=1)
  309. rboxes[inds, 2] = edges2[inds]
  310. rboxes[inds, 3] = edges1[inds]
  311. rboxes[inds, 4] = np.pi / 2.0
  312. return rboxes
  313. def delta2rbox(rrois,
  314. deltas,
  315. means=[0, 0, 0, 0, 0],
  316. stds=[1, 1, 1, 1, 1],
  317. wh_ratio_clip=1e-6):
  318. """
  319. :param rrois: (cx, cy, w, h, theta)
  320. :param deltas: (dx, dy, dw, dh, dtheta)
  321. :param means:
  322. :param stds:
  323. :param wh_ratio_clip:
  324. :return:
  325. """
  326. means = paddle.to_tensor(means)
  327. stds = paddle.to_tensor(stds)
  328. deltas = paddle.reshape(deltas, [-1, deltas.shape[-1]])
  329. denorm_deltas = deltas * stds + means
  330. dx = denorm_deltas[:, 0]
  331. dy = denorm_deltas[:, 1]
  332. dw = denorm_deltas[:, 2]
  333. dh = denorm_deltas[:, 3]
  334. dangle = denorm_deltas[:, 4]
  335. max_ratio = np.abs(np.log(wh_ratio_clip))
  336. dw = paddle.clip(dw, min=-max_ratio, max=max_ratio)
  337. dh = paddle.clip(dh, min=-max_ratio, max=max_ratio)
  338. rroi_x = rrois[:, 0]
  339. rroi_y = rrois[:, 1]
  340. rroi_w = rrois[:, 2]
  341. rroi_h = rrois[:, 3]
  342. rroi_angle = rrois[:, 4]
  343. gx = dx * rroi_w * paddle.cos(rroi_angle) - dy * rroi_h * paddle.sin(
  344. rroi_angle) + rroi_x
  345. gy = dx * rroi_w * paddle.sin(rroi_angle) + dy * rroi_h * paddle.cos(
  346. rroi_angle) + rroi_y
  347. gw = rroi_w * dw.exp()
  348. gh = rroi_h * dh.exp()
  349. ga = np.pi * dangle + rroi_angle
  350. ga = (ga + np.pi / 4) % np.pi - np.pi / 4
  351. ga = paddle.to_tensor(ga)
  352. gw = paddle.to_tensor(gw, dtype='float32')
  353. gh = paddle.to_tensor(gh, dtype='float32')
  354. bboxes = paddle.stack([gx, gy, gw, gh, ga], axis=-1)
  355. return bboxes
  356. def rbox2delta(proposals, gt, means=[0, 0, 0, 0, 0], stds=[1, 1, 1, 1, 1]):
  357. """
  358. Args:
  359. proposals:
  360. gt:
  361. means: 1x5
  362. stds: 1x5
  363. Returns:
  364. """
  365. proposals = proposals.astype(np.float64)
  366. PI = np.pi
  367. gt_widths = gt[..., 2]
  368. gt_heights = gt[..., 3]
  369. gt_angle = gt[..., 4]
  370. proposals_widths = proposals[..., 2]
  371. proposals_heights = proposals[..., 3]
  372. proposals_angle = proposals[..., 4]
  373. coord = gt[..., 0:2] - proposals[..., 0:2]
  374. dx = (np.cos(proposals[..., 4]) * coord[..., 0] + np.sin(proposals[..., 4])
  375. * coord[..., 1]) / proposals_widths
  376. dy = (-np.sin(proposals[..., 4]) * coord[..., 0] + np.cos(proposals[..., 4])
  377. * coord[..., 1]) / proposals_heights
  378. dw = np.log(gt_widths / proposals_widths)
  379. dh = np.log(gt_heights / proposals_heights)
  380. da = (gt_angle - proposals_angle)
  381. da = (da + PI / 4) % PI - PI / 4
  382. da /= PI
  383. deltas = np.stack([dx, dy, dw, dh, da], axis=-1)
  384. means = np.array(means, dtype=deltas.dtype)
  385. stds = np.array(stds, dtype=deltas.dtype)
  386. deltas = (deltas - means) / stds
  387. deltas = deltas.astype(np.float32)
  388. return deltas
  389. def bbox_decode(bbox_preds,
  390. anchors,
  391. means=[0, 0, 0, 0, 0],
  392. stds=[1, 1, 1, 1, 1]):
  393. """decode bbox from deltas
  394. Args:
  395. bbox_preds: [N,H,W,5]
  396. anchors: [H*W,5]
  397. return:
  398. bboxes: [N,H,W,5]
  399. """
  400. means = paddle.to_tensor(means)
  401. stds = paddle.to_tensor(stds)
  402. num_imgs, H, W, _ = bbox_preds.shape
  403. bboxes_list = []
  404. for img_id in range(num_imgs):
  405. bbox_pred = bbox_preds[img_id]
  406. # bbox_pred.shape=[5,H,W]
  407. bbox_delta = bbox_pred
  408. anchors = paddle.to_tensor(anchors)
  409. bboxes = delta2rbox(
  410. anchors, bbox_delta, means, stds, wh_ratio_clip=1e-6)
  411. bboxes = paddle.reshape(bboxes, [H, W, 5])
  412. bboxes_list.append(bboxes)
  413. return paddle.stack(bboxes_list, axis=0)
  414. def poly2rbox(polys):
  415. """
  416. poly:[x0,y0,x1,y1,x2,y2,x3,y3]
  417. to
  418. rotated_boxes:[x_ctr,y_ctr,w,h,angle]
  419. """
  420. rotated_boxes = []
  421. for poly in polys:
  422. poly = np.array(poly[:8], dtype=np.float32)
  423. pt1 = (poly[0], poly[1])
  424. pt2 = (poly[2], poly[3])
  425. pt3 = (poly[4], poly[5])
  426. pt4 = (poly[6], poly[7])
  427. edge1 = np.sqrt((pt1[0] - pt2[0]) * (pt1[0] - pt2[0]) + (pt1[1] - pt2[
  428. 1]) * (pt1[1] - pt2[1]))
  429. edge2 = np.sqrt((pt2[0] - pt3[0]) * (pt2[0] - pt3[0]) + (pt2[1] - pt3[
  430. 1]) * (pt2[1] - pt3[1]))
  431. width = max(edge1, edge2)
  432. height = min(edge1, edge2)
  433. rbox_angle = 0
  434. if edge1 > edge2:
  435. rbox_angle = np.arctan2(
  436. float(pt2[1] - pt1[1]), float(pt2[0] - pt1[0]))
  437. elif edge2 >= edge1:
  438. rbox_angle = np.arctan2(
  439. float(pt4[1] - pt1[1]), float(pt4[0] - pt1[0]))
  440. def norm_angle(angle, range=[-np.pi / 4, np.pi]):
  441. return (angle - range[0]) % range[1] + range[0]
  442. rbox_angle = norm_angle(rbox_angle)
  443. x_ctr = float(pt1[0] + pt3[0]) / 2
  444. y_ctr = float(pt1[1] + pt3[1]) / 2
  445. rotated_box = np.array([x_ctr, y_ctr, width, height, rbox_angle])
  446. rotated_boxes.append(rotated_box)
  447. ret_rotated_boxes = np.array(rotated_boxes)
  448. assert ret_rotated_boxes.shape[1] == 5
  449. return ret_rotated_boxes
  450. def cal_line_length(point1, point2):
  451. import math
  452. return math.sqrt(
  453. math.pow(point1[0] - point2[0], 2) + math.pow(point1[1] - point2[1], 2))
  454. def get_best_begin_point_single(coordinate):
  455. x1, y1, x2, y2, x3, y3, x4, y4 = coordinate
  456. xmin = min(x1, x2, x3, x4)
  457. ymin = min(y1, y2, y3, y4)
  458. xmax = max(x1, x2, x3, x4)
  459. ymax = max(y1, y2, y3, y4)
  460. combinate = [[[x1, y1], [x2, y2], [x3, y3], [x4, y4]],
  461. [[x4, y4], [x1, y1], [x2, y2], [x3, y3]],
  462. [[x3, y3], [x4, y4], [x1, y1], [x2, y2]],
  463. [[x2, y2], [x3, y3], [x4, y4], [x1, y1]]]
  464. dst_coordinate = [[xmin, ymin], [xmax, ymin], [xmax, ymax], [xmin, ymax]]
  465. force = 100000000.0
  466. force_flag = 0
  467. for i in range(4):
  468. temp_force = cal_line_length(combinate[i][0], dst_coordinate[0]) \
  469. + cal_line_length(combinate[i][1], dst_coordinate[1]) \
  470. + cal_line_length(combinate[i][2], dst_coordinate[2]) \
  471. + cal_line_length(combinate[i][3], dst_coordinate[3])
  472. if temp_force < force:
  473. force = temp_force
  474. force_flag = i
  475. if force_flag != 0:
  476. pass
  477. return np.array(combinate[force_flag]).reshape(8)
  478. def rbox2poly_np(rrects):
  479. """
  480. rrect:[x_ctr,y_ctr,w,h,angle]
  481. to
  482. poly:[x0,y0,x1,y1,x2,y2,x3,y3]
  483. """
  484. polys = []
  485. for i in range(rrects.shape[0]):
  486. rrect = rrects[i]
  487. # x_ctr, y_ctr, width, height, angle = rrect[:5]
  488. x_ctr = rrect[0]
  489. y_ctr = rrect[1]
  490. width = rrect[2]
  491. height = rrect[3]
  492. angle = rrect[4]
  493. tl_x, tl_y, br_x, br_y = -width / 2, -height / 2, width / 2, height / 2
  494. rect = np.array([[tl_x, br_x, br_x, tl_x], [tl_y, tl_y, br_y, br_y]])
  495. R = np.array([[np.cos(angle), -np.sin(angle)],
  496. [np.sin(angle), np.cos(angle)]])
  497. poly = R.dot(rect)
  498. x0, x1, x2, x3 = poly[0, :4] + x_ctr
  499. y0, y1, y2, y3 = poly[1, :4] + y_ctr
  500. poly = np.array([x0, y0, x1, y1, x2, y2, x3, y3], dtype=np.float32)
  501. poly = get_best_begin_point_single(poly)
  502. polys.append(poly)
  503. polys = np.array(polys)
  504. return polys
  505. def rbox2poly(rrects):
  506. """
  507. rrect:[x_ctr,y_ctr,w,h,angle]
  508. to
  509. poly:[x0,y0,x1,y1,x2,y2,x3,y3]
  510. """
  511. N = paddle.shape(rrects)[0]
  512. x_ctr = rrects[:, 0]
  513. y_ctr = rrects[:, 1]
  514. width = rrects[:, 2]
  515. height = rrects[:, 3]
  516. angle = rrects[:, 4]
  517. tl_x, tl_y, br_x, br_y = -width * 0.5, -height * 0.5, width * 0.5, height * 0.5
  518. normal_rects = paddle.stack(
  519. [tl_x, br_x, br_x, tl_x, tl_y, tl_y, br_y, br_y], axis=0)
  520. normal_rects = paddle.reshape(normal_rects, [2, 4, N])
  521. normal_rects = paddle.transpose(normal_rects, [2, 0, 1])
  522. sin, cos = paddle.sin(angle), paddle.cos(angle)
  523. # M.shape=[N,2,2]
  524. M = paddle.stack([cos, -sin, sin, cos], axis=0)
  525. M = paddle.reshape(M, [2, 2, N])
  526. M = paddle.transpose(M, [2, 0, 1])
  527. # polys:[N,8]
  528. polys = paddle.matmul(M, normal_rects)
  529. polys = paddle.transpose(polys, [2, 1, 0])
  530. polys = paddle.reshape(polys, [-1, N])
  531. polys = paddle.transpose(polys, [1, 0])
  532. tmp = paddle.stack(
  533. [x_ctr, y_ctr, x_ctr, y_ctr, x_ctr, y_ctr, x_ctr, y_ctr], axis=1)
  534. polys = polys + tmp
  535. return polys
  536. def bbox_iou_np_expand(box1, box2, x1y1x2y2=True, eps=1e-16):
  537. """
  538. Calculate the iou of box1 and box2 with numpy.
  539. Args:
  540. box1 (ndarray): [N, 4]
  541. box2 (ndarray): [M, 4], usually N != M
  542. x1y1x2y2 (bool): whether in x1y1x2y2 stype, default True
  543. eps (float): epsilon to avoid divide by zero
  544. Return:
  545. iou (ndarray): iou of box1 and box2, [N, M]
  546. """
  547. N, M = len(box1), len(box2) # usually N != M
  548. if x1y1x2y2:
  549. b1_x1, b1_y1 = box1[:, 0], box1[:, 1]
  550. b1_x2, b1_y2 = box1[:, 2], box1[:, 3]
  551. b2_x1, b2_y1 = box2[:, 0], box2[:, 1]
  552. b2_x2, b2_y2 = box2[:, 2], box2[:, 3]
  553. else:
  554. # cxcywh style
  555. # Transform from center and width to exact coordinates
  556. b1_x1, b1_x2 = box1[:, 0] - box1[:, 2] / 2, box1[:, 0] + box1[:, 2] / 2
  557. b1_y1, b1_y2 = box1[:, 1] - box1[:, 3] / 2, box1[:, 1] + box1[:, 3] / 2
  558. b2_x1, b2_x2 = box2[:, 0] - box2[:, 2] / 2, box2[:, 0] + box2[:, 2] / 2
  559. b2_y1, b2_y2 = box2[:, 1] - box2[:, 3] / 2, box2[:, 1] + box2[:, 3] / 2
  560. # get the coordinates of the intersection rectangle
  561. inter_rect_x1 = np.zeros((N, M), dtype=np.float32)
  562. inter_rect_y1 = np.zeros((N, M), dtype=np.float32)
  563. inter_rect_x2 = np.zeros((N, M), dtype=np.float32)
  564. inter_rect_y2 = np.zeros((N, M), dtype=np.float32)
  565. for i in range(len(box2)):
  566. inter_rect_x1[:, i] = np.maximum(b1_x1, b2_x1[i])
  567. inter_rect_y1[:, i] = np.maximum(b1_y1, b2_y1[i])
  568. inter_rect_x2[:, i] = np.minimum(b1_x2, b2_x2[i])
  569. inter_rect_y2[:, i] = np.minimum(b1_y2, b2_y2[i])
  570. # Intersection area
  571. inter_area = np.maximum(inter_rect_x2 - inter_rect_x1, 0) * np.maximum(
  572. inter_rect_y2 - inter_rect_y1, 0)
  573. # Union Area
  574. b1_area = np.repeat(
  575. ((b1_x2 - b1_x1) * (b1_y2 - b1_y1)).reshape(-1, 1), M, axis=-1)
  576. b2_area = np.repeat(
  577. ((b2_x2 - b2_x1) * (b2_y2 - b2_y1)).reshape(1, -1), N, axis=0)
  578. ious = inter_area / (b1_area + b2_area - inter_area + eps)
  579. return ious
  580. def bbox2distance(points, bbox, max_dis=None, eps=0.1):
  581. """Decode bounding box based on distances.
  582. Args:
  583. points (Tensor): Shape (n, 2), [x, y].
  584. bbox (Tensor): Shape (n, 4), "xyxy" format
  585. max_dis (float): Upper bound of the distance.
  586. eps (float): a small value to ensure target < max_dis, instead <=
  587. Returns:
  588. Tensor: Decoded distances.
  589. """
  590. left = points[:, 0] - bbox[:, 0]
  591. top = points[:, 1] - bbox[:, 1]
  592. right = bbox[:, 2] - points[:, 0]
  593. bottom = bbox[:, 3] - points[:, 1]
  594. if max_dis is not None:
  595. left = left.clip(min=0, max=max_dis - eps)
  596. top = top.clip(min=0, max=max_dis - eps)
  597. right = right.clip(min=0, max=max_dis - eps)
  598. bottom = bottom.clip(min=0, max=max_dis - eps)
  599. return paddle.stack([left, top, right, bottom], -1)
  600. def distance2bbox(points, distance, max_shape=None):
  601. """Decode distance prediction to bounding box.
  602. Args:
  603. points (Tensor): Shape (n, 2), [x, y].
  604. distance (Tensor): Distance from the given point to 4
  605. boundaries (left, top, right, bottom).
  606. max_shape (tuple): Shape of the image.
  607. Returns:
  608. Tensor: Decoded bboxes.
  609. """
  610. x1 = points[:, 0] - distance[:, 0]
  611. y1 = points[:, 1] - distance[:, 1]
  612. x2 = points[:, 0] + distance[:, 2]
  613. y2 = points[:, 1] + distance[:, 3]
  614. if max_shape is not None:
  615. x1 = x1.clip(min=0, max=max_shape[1])
  616. y1 = y1.clip(min=0, max=max_shape[0])
  617. x2 = x2.clip(min=0, max=max_shape[1])
  618. y2 = y2.clip(min=0, max=max_shape[0])
  619. return paddle.stack([x1, y1, x2, y2], -1)
  620. def bbox_center(boxes):
  621. """Get bbox centers from boxes.
  622. Args:
  623. boxes (Tensor): boxes with shape (..., 4), "xmin, ymin, xmax, ymax" format.
  624. Returns:
  625. Tensor: boxes centers with shape (..., 2), "cx, cy" format.
  626. """
  627. boxes_cx = (boxes[..., 0] + boxes[..., 2]) / 2
  628. boxes_cy = (boxes[..., 1] + boxes[..., 3]) / 2
  629. return paddle.stack([boxes_cx, boxes_cy], axis=-1)
  630. def batch_distance2bbox(points, distance, max_shapes=None):
  631. """Decode distance prediction to bounding box for batch.
  632. Args:
  633. points (Tensor): [B, ..., 2], "xy" format
  634. distance (Tensor): [B, ..., 4], "ltrb" format
  635. max_shapes (Tensor): [B, 2], "h,w" format, Shape of the image.
  636. Returns:
  637. Tensor: Decoded bboxes, "x1y1x2y2" format.
  638. """
  639. lt, rb = paddle.split(distance, 2, -1)
  640. # while tensor add parameters, parameters should be better placed on the second place
  641. x1y1 = -lt + points
  642. x2y2 = rb + points
  643. out_bbox = paddle.concat([x1y1, x2y2], -1)
  644. if max_shapes is not None:
  645. max_shapes = max_shapes.flip(-1).tile([1, 2])
  646. delta_dim = out_bbox.ndim - max_shapes.ndim
  647. for _ in range(delta_dim):
  648. max_shapes.unsqueeze_(1)
  649. out_bbox = paddle.where(out_bbox < max_shapes, out_bbox, max_shapes)
  650. out_bbox = paddle.where(out_bbox > 0, out_bbox,
  651. paddle.zeros_like(out_bbox))
  652. return out_bbox
  653. def delta2bbox_v2(rois,
  654. deltas,
  655. means=(0.0, 0.0, 0.0, 0.0),
  656. stds=(1.0, 1.0, 1.0, 1.0),
  657. max_shape=None,
  658. wh_ratio_clip=16.0 / 1000.0,
  659. ctr_clip=None):
  660. """Transform network output(delta) to bboxes.
  661. Based on https://github.com/open-mmlab/mmdetection/blob/master/mmdet/core/
  662. bbox/coder/delta_xywh_bbox_coder.py
  663. Args:
  664. rois (Tensor): shape [..., 4], base bboxes, typical examples include
  665. anchor and rois
  666. deltas (Tensor): shape [..., 4], offset relative to base bboxes
  667. means (list[float]): the mean that was used to normalize deltas,
  668. must be of size 4
  669. stds (list[float]): the std that was used to normalize deltas,
  670. must be of size 4
  671. max_shape (list[float] or None): height and width of image, will be
  672. used to clip bboxes if not None
  673. wh_ratio_clip (float): to clip delta wh of decoded bboxes
  674. ctr_clip (float or None): whether to clip delta xy of decoded bboxes
  675. """
  676. if rois.size == 0:
  677. return paddle.empty_like(rois)
  678. means = paddle.to_tensor(means)
  679. stds = paddle.to_tensor(stds)
  680. deltas = deltas * stds + means
  681. dxy = deltas[..., :2]
  682. dwh = deltas[..., 2:]
  683. pxy = (rois[..., :2] + rois[..., 2:]) * 0.5
  684. pwh = rois[..., 2:] - rois[..., :2]
  685. dxy_wh = pwh * dxy
  686. max_ratio = np.abs(np.log(wh_ratio_clip))
  687. if ctr_clip is not None:
  688. dxy_wh = paddle.clip(dxy_wh, max=ctr_clip, min=-ctr_clip)
  689. dwh = paddle.clip(dwh, max=max_ratio)
  690. else:
  691. dwh = dwh.clip(min=-max_ratio, max=max_ratio)
  692. gxy = pxy + dxy_wh
  693. gwh = pwh * dwh.exp()
  694. x1y1 = gxy - (gwh * 0.5)
  695. x2y2 = gxy + (gwh * 0.5)
  696. bboxes = paddle.concat([x1y1, x2y2], axis=-1)
  697. if max_shape is not None:
  698. bboxes[..., 0::2] = bboxes[..., 0::2].clip(min=0, max=max_shape[1])
  699. bboxes[..., 1::2] = bboxes[..., 1::2].clip(min=0, max=max_shape[0])
  700. return bboxes
  701. def bbox2delta_v2(src_boxes,
  702. tgt_boxes,
  703. means=(0.0, 0.0, 0.0, 0.0),
  704. stds=(1.0, 1.0, 1.0, 1.0)):
  705. """Encode bboxes to deltas.
  706. Modified from ppdet.modeling.bbox_utils.bbox2delta.
  707. Args:
  708. src_boxes (Tensor[..., 4]): base bboxes
  709. tgt_boxes (Tensor[..., 4]): target bboxes
  710. means (list[float]): the mean that will be used to normalize delta
  711. stds (list[float]): the std that will be used to normalize delta
  712. """
  713. if src_boxes.size == 0:
  714. return paddle.empty_like(src_boxes)
  715. src_w = src_boxes[..., 2] - src_boxes[..., 0]
  716. src_h = src_boxes[..., 3] - src_boxes[..., 1]
  717. src_ctr_x = src_boxes[..., 0] + 0.5 * src_w
  718. src_ctr_y = src_boxes[..., 1] + 0.5 * src_h
  719. tgt_w = tgt_boxes[..., 2] - tgt_boxes[..., 0]
  720. tgt_h = tgt_boxes[..., 3] - tgt_boxes[..., 1]
  721. tgt_ctr_x = tgt_boxes[..., 0] + 0.5 * tgt_w
  722. tgt_ctr_y = tgt_boxes[..., 1] + 0.5 * tgt_h
  723. dx = (tgt_ctr_x - src_ctr_x) / src_w
  724. dy = (tgt_ctr_y - src_ctr_y) / src_h
  725. dw = paddle.log(tgt_w / src_w)
  726. dh = paddle.log(tgt_h / src_h)
  727. deltas = paddle.stack((dx, dy, dw, dh), axis=1) # [n, 4]
  728. means = paddle.to_tensor(means, place=src_boxes.place)
  729. stds = paddle.to_tensor(stds, place=src_boxes.place)
  730. deltas = (deltas - means) / stds
  731. return deltas