# test_nkl.py
  1. import argparse
  2. import numpy as np
  3. import os
  4. from hungarian_matching import caculate_tp_fp_fn
  5. gt_path = './data/training/SL5K_resize_100_100'
  6. #gt_path = '/home/hanqi/work/semantic/data/crawl/JTLEE_resize_100_100'
  7. parser = argparse.ArgumentParser(description='PyTorch Semantic-Line Training')
  8. parser.add_argument('--pred', type=str, required=True)
  9. parser.add_argument('--gt', type=str, required=True)
  10. parser.add_argument('--align', default=False, action='store_true')
  11. arg = parser.parse_args()
  12. #
  13. pred_path = arg.pred
  14. gt_path = arg.gt
  15. filenames = sorted(os.listdir(pred_path))
  16. total_tp = np.zeros(99)
  17. total_fp = np.zeros(99)
  18. total_fn = np.zeros(99)
  19. total_tp_align = np.zeros(99)
  20. total_fp_align = np.zeros(99)
  21. total_fn_align = np.zeros(99)
  22. for filename in filenames:
  23. if 'npy' not in filename:
  24. continue
  25. if 'align' in filename:
  26. continue
  27. pred = np.load(os.path.join(pred_path, filename))
  28. if arg.align:
  29. pred_align = np.load(os.path.join(pred_path, filename.split('.')[0]+'_align.npy'))
  30. gt_txt = open(os.path.join(gt_path, filename.split('.')[0] + '.txt'))
  31. gt_coords = gt_txt.readlines()
  32. gt = []
  33. for i in range(int(gt_coords[0].split(' ')[0])):
  34. gt.append([int(float(gt_coords[0].split(' ')[i*4+2])), int(float(gt_coords[0].split(' ')[i*4+1])), int(float(gt_coords[0].split(' ')[i*4+4])), int(float(gt_coords[0].split(' ')[i*4+3]))])
  35. for i in range(1, 100):
  36. tp, fp, fn = caculate_tp_fp_fn(pred.tolist(), gt, thresh=i*0.01)
  37. total_tp[i-1] += tp
  38. total_fp[i-1] += fp
  39. total_fn[i-1] += fn
  40. if arg.align:
  41. tp, fp, fn = caculate_tp_fp_fn(pred_align.tolist(), gt, thresh=i*0.01)
  42. total_tp_align[i-1] += tp
  43. total_fp_align[i-1] += fp
  44. total_fn_align[i-1] += fn
  45. total_recall = total_tp / (total_tp + total_fn)
  46. total_precision = total_tp / (total_tp + total_fp)
  47. f = 2 * total_recall * total_precision / (total_recall + total_precision + 1e-6)
  48. if arg.align:
  49. total_recall_align = total_tp_align / (total_tp_align + total_fn_align)
  50. total_precision_align = total_tp_align / (total_tp_align + total_fp_align)
  51. f_align = 2 * total_recall_align * total_precision_align / (total_recall_align + total_precision_align + 1e-6)
  52. print('Mean P:', total_precision.mean())
  53. print('Mean R:', total_recall.mean())
  54. print('Mean F:', f.mean())
  55. print('F@0.95:', f[94])
  56. #np.savetxt('precision.csv', total_precision)
  57. #np.savetxt('recall.csv', total_recall)
  58. #np.savetxt('fscore.csv', total_f)
  59. if arg.align:
  60. print('Mean P_align:', total_precision_align.mean())
  61. print('Mean R_align:', total_recall_align.mean())
  62. print('Mean F_align:', f_align.mean())
  63. print('F_align@0.95:', f_align[94])
  64. #np.savetxt('total_precision_refine.csv', total_precision_align)
  65. #np.savetxt('total_recall_refine.csv', total_recall_align)
  66. #np.savetxt('total_f_refine.csv', f_align)