rbox_iou_op.h 9.9 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356
  1. // Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
  2. //
  3. // Licensed under the Apache License, Version 2.0 (the "License");
  4. // you may not use this file except in compliance with the License.
  5. // You may obtain a copy of the License at
  6. //
  7. // http://www.apache.org/licenses/LICENSE-2.0
  8. //
  9. // Unless required by applicable law or agreed to in writing, software
  10. // distributed under the License is distributed on an "AS IS" BASIS,
  11. // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. // See the License for the specific language governing permissions and
  13. // limitations under the License.
  14. //
  15. // The code is based on https://github.com/csuhan/s2anet/blob/master/mmdet/ops/box_iou_rotated
  16. #pragma once
  17. #include <cassert>
  18. #include <cmath>
  19. #include <vector>
  20. #ifdef __CUDACC__
  21. // Designates functions callable from the host (CPU) and the device (GPU)
  22. #define HOST_DEVICE __host__ __device__
  23. #define HOST_DEVICE_INLINE HOST_DEVICE __forceinline__
  24. #else
  25. #include <algorithm>
  26. #define HOST_DEVICE
  27. #define HOST_DEVICE_INLINE HOST_DEVICE inline
  28. #endif
  29. namespace {
  30. template <typename T>
  31. struct RotatedBox {
  32. T x_ctr, y_ctr, w, h, a;
  33. };
  34. template <typename T>
  35. struct Point {
  36. T x, y;
  37. HOST_DEVICE_INLINE Point(const T& px = 0, const T& py = 0) : x(px), y(py) {}
  38. HOST_DEVICE_INLINE Point operator+(const Point& p) const {
  39. return Point(x + p.x, y + p.y);
  40. }
  41. HOST_DEVICE_INLINE Point& operator+=(const Point& p) {
  42. x += p.x;
  43. y += p.y;
  44. return *this;
  45. }
  46. HOST_DEVICE_INLINE Point operator-(const Point& p) const {
  47. return Point(x - p.x, y - p.y);
  48. }
  49. HOST_DEVICE_INLINE Point operator*(const T coeff) const {
  50. return Point(x * coeff, y * coeff);
  51. }
  52. };
  53. template <typename T>
  54. HOST_DEVICE_INLINE T dot_2d(const Point<T>& A, const Point<T>& B) {
  55. return A.x * B.x + A.y * B.y;
  56. }
  57. template <typename T>
  58. HOST_DEVICE_INLINE T cross_2d(const Point<T>& A, const Point<T>& B) {
  59. return A.x * B.y - B.x * A.y;
  60. }
  61. template <typename T>
  62. HOST_DEVICE_INLINE void get_rotated_vertices(
  63. const RotatedBox<T>& box,
  64. Point<T> (&pts)[4]) {
  65. // M_PI / 180. == 0.01745329251
  66. //double theta = box.a * 0.01745329251;
  67. //MODIFIED
  68. double theta = box.a;
  69. T cosTheta2 = (T)cos(theta) * 0.5f;
  70. T sinTheta2 = (T)sin(theta) * 0.5f;
  71. // y: top --> down; x: left --> right
  72. pts[0].x = box.x_ctr - sinTheta2 * box.h - cosTheta2 * box.w;
  73. pts[0].y = box.y_ctr + cosTheta2 * box.h - sinTheta2 * box.w;
  74. pts[1].x = box.x_ctr + sinTheta2 * box.h - cosTheta2 * box.w;
  75. pts[1].y = box.y_ctr - cosTheta2 * box.h - sinTheta2 * box.w;
  76. pts[2].x = 2 * box.x_ctr - pts[0].x;
  77. pts[2].y = 2 * box.y_ctr - pts[0].y;
  78. pts[3].x = 2 * box.x_ctr - pts[1].x;
  79. pts[3].y = 2 * box.y_ctr - pts[1].y;
  80. }
  81. template <typename T>
  82. HOST_DEVICE_INLINE int get_intersection_points(
  83. const Point<T> (&pts1)[4],
  84. const Point<T> (&pts2)[4],
  85. Point<T> (&intersections)[24]) {
  86. // Line vector
  87. // A line from p1 to p2 is: p1 + (p2-p1)*t, t=[0,1]
  88. Point<T> vec1[4], vec2[4];
  89. for (int i = 0; i < 4; i++) {
  90. vec1[i] = pts1[(i + 1) % 4] - pts1[i];
  91. vec2[i] = pts2[(i + 1) % 4] - pts2[i];
  92. }
  93. // Line test - test all line combos for intersection
  94. int num = 0; // number of intersections
  95. for (int i = 0; i < 4; i++) {
  96. for (int j = 0; j < 4; j++) {
  97. // Solve for 2x2 Ax=b
  98. T det = cross_2d<T>(vec2[j], vec1[i]);
  99. // This takes care of parallel lines
  100. if (fabs(det) <= 1e-14) {
  101. continue;
  102. }
  103. auto vec12 = pts2[j] - pts1[i];
  104. T t1 = cross_2d<T>(vec2[j], vec12) / det;
  105. T t2 = cross_2d<T>(vec1[i], vec12) / det;
  106. if (t1 >= 0.0f && t1 <= 1.0f && t2 >= 0.0f && t2 <= 1.0f) {
  107. intersections[num++] = pts1[i] + vec1[i] * t1;
  108. }
  109. }
  110. }
  111. // Check for vertices of rect1 inside rect2
  112. {
  113. const auto& AB = vec2[0];
  114. const auto& DA = vec2[3];
  115. auto ABdotAB = dot_2d<T>(AB, AB);
  116. auto ADdotAD = dot_2d<T>(DA, DA);
  117. for (int i = 0; i < 4; i++) {
  118. // assume ABCD is the rectangle, and P is the point to be judged
  119. // P is inside ABCD iff. P's projection on AB lies within AB
  120. // and P's projection on AD lies within AD
  121. auto AP = pts1[i] - pts2[0];
  122. auto APdotAB = dot_2d<T>(AP, AB);
  123. auto APdotAD = -dot_2d<T>(AP, DA);
  124. if ((APdotAB >= 0) && (APdotAD >= 0) && (APdotAB <= ABdotAB) &&
  125. (APdotAD <= ADdotAD)) {
  126. intersections[num++] = pts1[i];
  127. }
  128. }
  129. }
  130. // Reverse the check - check for vertices of rect2 inside rect1
  131. {
  132. const auto& AB = vec1[0];
  133. const auto& DA = vec1[3];
  134. auto ABdotAB = dot_2d<T>(AB, AB);
  135. auto ADdotAD = dot_2d<T>(DA, DA);
  136. for (int i = 0; i < 4; i++) {
  137. auto AP = pts2[i] - pts1[0];
  138. auto APdotAB = dot_2d<T>(AP, AB);
  139. auto APdotAD = -dot_2d<T>(AP, DA);
  140. if ((APdotAB >= 0) && (APdotAD >= 0) && (APdotAB <= ABdotAB) &&
  141. (APdotAD <= ADdotAD)) {
  142. intersections[num++] = pts2[i];
  143. }
  144. }
  145. }
  146. return num;
  147. }
  148. template <typename T>
  149. HOST_DEVICE_INLINE int convex_hull_graham(
  150. const Point<T> (&p)[24],
  151. const int& num_in,
  152. Point<T> (&q)[24],
  153. bool shift_to_zero = false) {
  154. assert(num_in >= 2);
  155. // Step 1:
  156. // Find point with minimum y
  157. // if more than 1 points have the same minimum y,
  158. // pick the one with the minimum x.
  159. int t = 0;
  160. for (int i = 1; i < num_in; i++) {
  161. if (p[i].y < p[t].y || (p[i].y == p[t].y && p[i].x < p[t].x)) {
  162. t = i;
  163. }
  164. }
  165. auto& start = p[t]; // starting point
  166. // Step 2:
  167. // Subtract starting point from every points (for sorting in the next step)
  168. for (int i = 0; i < num_in; i++) {
  169. q[i] = p[i] - start;
  170. }
  171. // Swap the starting point to position 0
  172. auto tmp = q[0];
  173. q[0] = q[t];
  174. q[t] = tmp;
  175. // Step 3:
  176. // Sort point 1 ~ num_in according to their relative cross-product values
  177. // (essentially sorting according to angles)
  178. // If the angles are the same, sort according to their distance to origin
  179. T dist[24];
  180. for (int i = 0; i < num_in; i++) {
  181. dist[i] = dot_2d<T>(q[i], q[i]);
  182. }
  183. #ifdef __CUDACC__
  184. // CUDA version
  185. // In the future, we can potentially use thrust
  186. // for sorting here to improve speed (though not guaranteed)
  187. for (int i = 1; i < num_in - 1; i++) {
  188. for (int j = i + 1; j < num_in; j++) {
  189. T crossProduct = cross_2d<T>(q[i], q[j]);
  190. if ((crossProduct < -1e-6) ||
  191. (fabs(crossProduct) < 1e-6 && dist[i] > dist[j])) {
  192. auto q_tmp = q[i];
  193. q[i] = q[j];
  194. q[j] = q_tmp;
  195. auto dist_tmp = dist[i];
  196. dist[i] = dist[j];
  197. dist[j] = dist_tmp;
  198. }
  199. }
  200. }
  201. #else
  202. // CPU version
  203. std::sort(
  204. q + 1, q + num_in, [](const Point<T>& A, const Point<T>& B) -> bool {
  205. T temp = cross_2d<T>(A, B);
  206. if (fabs(temp) < 1e-6) {
  207. return dot_2d<T>(A, A) < dot_2d<T>(B, B);
  208. } else {
  209. return temp > 0;
  210. }
  211. });
  212. #endif
  213. // Step 4:
  214. // Make sure there are at least 2 points (that don't overlap with each other)
  215. // in the stack
  216. int k; // index of the non-overlapped second point
  217. for (k = 1; k < num_in; k++) {
  218. if (dist[k] > 1e-8) {
  219. break;
  220. }
  221. }
  222. if (k == num_in) {
  223. // We reach the end, which means the convex hull is just one point
  224. q[0] = p[t];
  225. return 1;
  226. }
  227. q[1] = q[k];
  228. int m = 2; // 2 points in the stack
  229. // Step 5:
  230. // Finally we can start the scanning process.
  231. // When a non-convex relationship between the 3 points is found
  232. // (either concave shape or duplicated points),
  233. // we pop the previous point from the stack
  234. // until the 3-point relationship is convex again, or
  235. // until the stack only contains two points
  236. for (int i = k + 1; i < num_in; i++) {
  237. while (m > 1 && cross_2d<T>(q[i] - q[m - 2], q[m - 1] - q[m - 2]) >= 0) {
  238. m--;
  239. }
  240. q[m++] = q[i];
  241. }
  242. // Step 6 (Optional):
  243. // In general sense we need the original coordinates, so we
  244. // need to shift the points back (reverting Step 2)
  245. // But if we're only interested in getting the area/perimeter of the shape
  246. // We can simply return.
  247. if (!shift_to_zero) {
  248. for (int i = 0; i < m; i++) {
  249. q[i] += start;
  250. }
  251. }
  252. return m;
  253. }
  254. template <typename T>
  255. HOST_DEVICE_INLINE T polygon_area(const Point<T> (&q)[24], const int& m) {
  256. if (m <= 2) {
  257. return 0;
  258. }
  259. T area = 0;
  260. for (int i = 1; i < m - 1; i++) {
  261. area += fabs(cross_2d<T>(q[i] - q[0], q[i + 1] - q[0]));
  262. }
  263. return area / 2.0;
  264. }
  265. template <typename T>
  266. HOST_DEVICE_INLINE T rboxes_intersection(
  267. const RotatedBox<T>& box1,
  268. const RotatedBox<T>& box2) {
  269. // There are up to 4 x 4 + 4 + 4 = 24 intersections (including dups) returned
  270. // from rotated_rect_intersection_pts
  271. Point<T> intersectPts[24], orderedPts[24];
  272. Point<T> pts1[4];
  273. Point<T> pts2[4];
  274. get_rotated_vertices<T>(box1, pts1);
  275. get_rotated_vertices<T>(box2, pts2);
  276. int num = get_intersection_points<T>(pts1, pts2, intersectPts);
  277. if (num <= 2) {
  278. return 0.0;
  279. }
  280. // Convex Hull to order the intersection points in clockwise order and find
  281. // the contour area.
  282. int num_convex = convex_hull_graham<T>(intersectPts, num, orderedPts, true);
  283. return polygon_area<T>(orderedPts, num_convex);
  284. }
  285. } // namespace
  286. template <typename T>
  287. HOST_DEVICE_INLINE T
  288. rbox_iou_single(T const* const box1_raw, T const* const box2_raw) {
  289. // shift center to the middle point to achieve higher precision in result
  290. RotatedBox<T> box1, box2;
  291. auto center_shift_x = (box1_raw[0] + box2_raw[0]) / 2.0;
  292. auto center_shift_y = (box1_raw[1] + box2_raw[1]) / 2.0;
  293. box1.x_ctr = box1_raw[0] - center_shift_x;
  294. box1.y_ctr = box1_raw[1] - center_shift_y;
  295. box1.w = box1_raw[2];
  296. box1.h = box1_raw[3];
  297. box1.a = box1_raw[4];
  298. box2.x_ctr = box2_raw[0] - center_shift_x;
  299. box2.y_ctr = box2_raw[1] - center_shift_y;
  300. box2.w = box2_raw[2];
  301. box2.h = box2_raw[3];
  302. box2.a = box2_raw[4];
  303. const T area1 = box1.w * box1.h;
  304. const T area2 = box2.w * box2.h;
  305. if (area1 < 1e-14 || area2 < 1e-14) {
  306. return 0.f;
  307. }
  308. const T intersection = rboxes_intersection<T>(box1, box2);
  309. const T iou = intersection / (area1 + area2 - intersection);
  310. return iou;
  311. }