# hungarian_matching.py
  1. import numpy as np
  2. from scipy.optimize import linear_sum_assignment
  3. # https://docs.scipy.org/doc/scipy-0.18.1/reference/generated/scipy.optimize.linear_sum_assignment.html
  4. from metric import EA_metric, Chamfer_metric, Emd_metric
  5. from basic_ops import *
  6. def build_graph(p_lines, g_lines, threshold):
  7. prediction_len = len(p_lines)
  8. gt_len = len(g_lines)
  9. G = np.zeros((prediction_len, gt_len))
  10. for i in range(prediction_len):
  11. for j in range(gt_len):
  12. if EA_metric(p_lines[i], g_lines[j]) >= threshold:
  13. # if Chamfer_metric(p_lines[i], g_lines[j]) >= threshold:
  14. # if Emd_metric(p_lines[i], g_lines[j]) >= threshold:
  15. G[i][j] = 1
  16. return G
  17. def caculate_tp_fp_fn(b_points, gt_coords, thresh=0.90):
  18. p_lines = []
  19. g_lines = []
  20. for points in b_points:
  21. if len(points) == 0:
  22. continue
  23. if points[0] == points[2] and points[1] == points[3]:
  24. continue
  25. else:
  26. p_lines.append(Line(list(points)))
  27. for points in gt_coords:
  28. if len(points) == 0:
  29. continue
  30. if points[0] == points[2] and points[1] == points[3]:
  31. continue
  32. else:
  33. g_lines.append(Line(list(points)))
  34. G = build_graph(p_lines, g_lines, thresh)
  35. # convert G to -G to caculate maximum matching.
  36. row_ind, col_ind = linear_sum_assignment(-G)
  37. pair_nums = G[row_ind, col_ind].sum()
  38. tp = pair_nums
  39. fp = len(p_lines) - pair_nums
  40. fn = len(g_lines) - pair_nums
  41. return tp, fp, fn