metric.py 2.5 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485
  1. import numpy as np
  2. import cv2
  3. import torch
  4. import ot
  5. from basic_ops import Line
  6. from chamfer_distance import ChamferDistance
# Shared ChamferDistance instance. Chamfer_metric calls it as cd(p1, p2) and
# unpacks two results (d1, d2) whose .mean() values it averages.
cd = ChamferDistance()
  8. def sa_metric(angle_p, angle_g):
  9. d_angle = np.abs(angle_p - angle_g)
  10. d_angle = min(d_angle, np.pi - d_angle)
  11. d_angle = d_angle * 2 / np.pi
  12. return max(0, (1 - d_angle)) ** 2
  13. def se_metric(coord_p, coord_g, size=(400, 400)):
  14. c_p = [(coord_p[0] + coord_p[2]) / 2, (coord_p[1] + coord_p[3]) / 2]
  15. c_g = [(coord_g[0] + coord_g[2]) / 2, (coord_g[1] + coord_g[3]) / 2]
  16. d_coord = np.abs(c_p[0] - c_g[0])**2 + np.abs(c_p[1] - c_g[1])**2
  17. d_coord = np.sqrt(d_coord) / max(size[0], size[1])
  18. return max(0, (1 - d_coord)) ** 2
  19. def EA_metric(l_pred, l_gt, size=(400, 400)):
  20. se = se_metric(l_pred.coord, l_gt.coord, size=size)
  21. sa = sa_metric(l_pred.angle(), l_gt.angle())
  22. return sa * se
  23. def Chamfer_metric(l_pred, l_gt, size=(400, 400)):
  24. points1 = get_points_coords(l_pred)
  25. points2 = get_points_coords(l_gt)
  26. #add z-axis
  27. points1 = np.insert(points1, 0, values=0, axis=1)
  28. points2 = np.insert(points2, 0, values=0, axis=1)
  29. p1 = torch.from_numpy(points1).unsqueeze(0).float()
  30. p2 = torch.from_numpy(points2).unsqueeze(0).float()
  31. d1, d2 = cd(p1, p2)
  32. d = (d1.mean().item() + d2.mean().item()) / 2
  33. mmax = size[0] * size[0] + size[1] * size[1]
  34. return 1 - d / mmax
  35. def Emd_metric(l_pred, l_gt, size=(400, 400)):
  36. points1 = get_points_coords(l_pred)
  37. points2 = get_points_coords(l_gt)
  38. M = ot.dist(points1, points2, metric='euclidean')
  39. _, log = ot.emd([], [], M, log=True)
  40. cost = log['cost']
  41. return 1 - cost / np.sqrt(size[0] * size[0] + size[1] * size[1])
  42. def get_points_coords(l):
  43. points = []
  44. y0, x0, y1, x1 = l.coord
  45. dx = x1 - x0
  46. dy = y1 - y0
  47. length = int(np.sqrt(dx * dx + dy * dy))
  48. for _ in range(length + 1):
  49. points.append([int(np.round(x0)), int(np.round(y0))])
  50. x0 += (dx / length)
  51. y0 += (dy / length)
  52. return points
  53. if __name__ == "__main__":
  54. # l1 = Line([0, 200, 400, 200])
  55. # l2 = Line([200, 0, 200, 400])
  56. l1 = Line([200, 0, 190, 399])
  57. l2 = Line([190, 0, 200, 399])
  58. print(EA_metric(l1, l2))
  59. mask = np.zeros((400, 400))
  60. cv2.line(mask, (5, 0), (0, 5), 255, 1)
  61. cv2.line(mask, (394, 399), (399, 394), 255, 1)
  62. cv2.imwrite('debug.png', mask)
  63. cd_score = Chamfer_metric(l1, l2)
  64. emd_score = Emd_metric(l1, l2)
  65. print(cd_score, emd_score)