rorate_image_detection.py 5.8 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168
  1. #!/usr/bin/env python
  2. # -*- coding: utf-8 -*-
  3. # @Time : 2024/6/19 0019 下午 4:35
  4. # @Author : liudan
  5. # @File : rorate_image_detection.py
  6. # @Software: pycharm
  7. import time
  8. import cv2 as cv
  9. import cv2
  10. import numpy as np
  11. from loguru import logger
  12. import os
  13. from superpoint_superglue_deployment import Matcher
  14. from datetime import datetime
  15. import random
  16. import cv2
  17. import numpy as np
  18. import random
  19. # 在模板图的基础上创造负样本(翻转图片)
  20. def rotate_image(image_path):
  21. rotate_result_images = []
  22. for path in os.listdir(image_path):
  23. ref_image_path = os.path.join(image_path,path)
  24. ref_image = cv2.imread(ref_image_path, 0)
  25. if random.random() < 1: # 以0.5的概率决定是否翻转
  26. rotated_image = cv2.rotate(ref_image, cv2.ROTATE_180)
  27. rotate_result_images.append((rotated_image, path))
  28. else:
  29. print(f"Skipped")
  30. print(len(rotate_result_images))
  31. return rotate_result_images
  32. # 配准操作,计算模板图与查询图配准点的距离差异(翻转后的距离大于未翻转的)
  33. def match_image(ref_img, query_img):
  34. superglue_matcher = Matcher(
  35. {
  36. "superpoint": {
  37. "input_shape": (-1, -1),
  38. "keypoint_threshold": 0.003,
  39. },
  40. "superglue": {
  41. "match_threshold": 0.5,
  42. },
  43. "use_gpu": True,
  44. }
  45. )
  46. query_kpts, ref_kpts, _, _, matches = superglue_matcher.match(ref_img, query_img)
  47. matched_query_kpts = [query_kpts[m.queryIdx].pt for m in matches]
  48. matched_ref_kpts = [ref_kpts[m.trainIdx].pt for m in matches]
  49. diff_query_ref = np.array(matched_ref_kpts) - np.array(matched_query_kpts)
  50. print(diff_query_ref)
  51. if len(matches)!=0:
  52. diff_query_ref = np.linalg.norm(diff_query_ref, axis=1, ord=2)
  53. diff_query_ref = np.sqrt(diff_query_ref)
  54. diff_query_ref = np.mean(diff_query_ref)
  55. else:
  56. diff_query_ref = np.inf
  57. return diff_query_ref # 返回差异值,用于后续做比较
  58. def registration(ref_image_dir, ref_selection_dir):
  59. # 计算recall等需要的的参数
  60. true_positives = 0
  61. false_positives = 0
  62. false_negatives = 0
  63. true_negatives = 0
  64. # 返回预测错误的样本
  65. positive_false = []
  66. negative_false = []
  67. # 循环遍历模板图
  68. for file_name in os.listdir(ref_selection_dir):
  69. ref_image_path = os.path.join(ref_selection_dir,file_name)
  70. ref_image = cv2.imread(ref_image_path, 0)
  71. save_path = './data/region_rotate/vague/A' # 保存预测错误的原图片,可不用
  72. # 循环遍历正样本(这里正样本是查询图,因为查询图都是没有翻转的,所以可当作正样本)
  73. # 这是在正样本中预测
  74. for file in os.listdir(ref_image_dir):
  75. query_image_path = os.path.join(ref_image_dir, file)
  76. query_image = cv2.imread(query_image_path, 0)
  77. query_image_rotate = cv2.rotate(query_image, cv2.ROTATE_180)
  78. # cv2.imwrite('1111111111.jpg', query_image_rotate)
  79. diff1 = match_image(ref_image, query_image) # 计算模板图与查询图的配准点差异
  80. diff2 = match_image(ref_image,query_image_rotate) # 计算模板图与翻转180度图的配准点差异
  81. # if (len(matches1) > len(matches2)) or (diff1<diff2) :
  82. if diff1 < diff2:
  83. flag = True
  84. else:
  85. flag = False
  86. print(flag)
  87. if flag == True:
  88. true_positives += 1
  89. else:
  90. false_negatives += 1
  91. positive_false.append((file,file_name))
  92. cv2.imwrite(os.path.join(save_path, f'{file[:-4]}.jpg'),query_image)
  93. # 这是在负样本中预测
  94. # 因为缺少负样本,所以用rotate_image方法创造负样本
  95. nagetive_image = rotate_image(ref_image_dir)
  96. for i, item in enumerate(nagetive_image):
  97. file, path = item
  98. # query_image = cv2.cvtColor(file,cv2.COLOR_BGR2GRAY)
  99. query_image = file
  100. query_image_rotate = cv2.rotate(query_image, cv2.ROTATE_180)
  101. diff1 = match_image(ref_image, query_image)
  102. diff2 = match_image(ref_image, query_image_rotate)
  103. # if (len(matches1) > len(matches2)) or (diff1<diff2):
  104. if diff1 < diff2:
  105. flag = True
  106. else:
  107. flag = False
  108. print(flag)
  109. if flag == False:
  110. true_negatives += 1
  111. else:
  112. false_positives += 1
  113. negative_false.append((path, file_name))
  114. cv2.imwrite(os.path.join(save_path, f'{path[:-4]}.jpg'),query_image)
  115. Accurary = (true_positives + true_negatives) / (true_positives + true_negatives + false_positives + false_negatives)
  116. Precision = true_negatives / (true_negatives + false_negatives)
  117. Recall = true_negatives / (true_negatives + false_positives)
  118. F1_score = 2 * (Precision * Recall) / (Precision + Recall)
  119. print(positive_false)
  120. print(negative_false)
  121. # print(file_name)
  122. print(f"Accurary:{Accurary: .4f}")
  123. print(f"Precision: {Precision:.4f}")
  124. print(f"Recall: {Recall:.4f}")
  125. print(f"F1 Score: {F1_score:.4f}")
  126. if __name__ == "__main__":
  127. # ref_image_path = './data/region_rotate/vague/E2'
  128. # ref_selection_dir = './data/region_rotate/vague/E22'
  129. ref_image_path = './data/big_624/G1'
  130. ref_selection_dir = './data/big_624/G2'
  131. # start_time = time.time()
  132. registration(ref_image_path, ref_selection_dir)
  133. # end_time = time.time()
  134. # elapsed_time = end_time - start_time
  135. # print(f"程序用时: {elapsed_time:.2f} 秒")