nms_locality.py 1.8 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758
  1. # coding=utf-8
  2. import numpy as np
  3. from shapely.geometry import Polygon
  4. def intersection(g, p):
  5. # 取g,p中的几何体信息组成多边形
  6. g = Polygon(g[:8].reshape((4, 2)))
  7. p = Polygon(p[:8].reshape((4, 2)))
  8. # 判断g,p是否为有效的多边形几何体
  9. if not g.is_valid or not p.is_valid:
  10. return 0
  11. # 取两个几何体的交集和并集
  12. inter = Polygon(g).intersection(Polygon(p)).area
  13. union = g.area + p.area - inter
  14. if union == 0:
  15. return 0
  16. else:
  17. return inter / union
  18. def weighted_merge(g, p):
  19. # 取g,p两个几何体的加权(权重根据对应的检测得分计算得到)
  20. g[:8] = (g[8] * g[:8] + p[8] * p[:8]) / (g[8] + p[8])
  21. # 合并后的几何体的得分为两个几何体得分的总和
  22. g[8] = (g[8] + p[8])
  23. return g
  24. def standard_nms(S, thres):
  25. # 标准NMS
  26. order = np.argsort(S[:, 8])[::-1]
  27. keep = []
  28. while order.size > 0:
  29. i = order[0]
  30. keep.append(i)
  31. ovr = np.array([intersection(S[i], S[t]) for t in order[1:]])
  32. inds = np.where(ovr <= thres)[0]
  33. order = order[inds + 1]
  34. return S[keep]
  35. def nms_locality(polys, thres=0.3):
  36. '''
  37. locality aware nms of EAST
  38. :param polys: a N*9 numpy array. first 8 coordinates, then prob
  39. :return: boxes after nms
  40. '''
  41. S = [] # 合并后的几何体集合
  42. p = None # 合并后的几何体
  43. for g in polys:
  44. if p is not None and intersection(g, p) > thres: # 若两个几何体的相交面积大于指定的阈值,则进行合并
  45. p = weighted_merge(g, p)
  46. else: # 反之,则保留当前的几何体
  47. if p is not None:
  48. S.append(p)
  49. p = g
  50. if p is not None:
  51. S.append(p)
  52. if len(S) == 0:
  53. return np.array([])
  54. return standard_nms(np.array(S), thres)