rotate_env.py 5.9 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173
  1. #!/usr/bin/env python
  2. # -*- coding: utf-8 -*-
  3. # @Time : 2024/6/24 0024 上午 10:38
  4. # @Author : liudan
  5. # @File : rotate_env.py
  6. # @Software: pycharm
  7. #!/usr/bin/env python
  8. # -*- coding: utf-8 -*-
  9. # @Time : 2024/6/19 0019 下午 4:35
  10. # @Author : liudan
  11. # @File : rorate_image_detection.py
  12. # @Software: pycharm
  13. import time
  14. import cv2 as cv
  15. import cv2
  16. import numpy as np
  17. from loguru import logger
  18. import os
  19. from superpoint_superglue_deployment import Matcher
  20. from datetime import datetime
  21. import random
  22. import cv2
  23. import numpy as np
  24. import random
  25. # 在模板图的基础上创造负样本(翻转图片)
  26. def rotate_image(image_path):
  27. rotate_result_images = []
  28. for path in os.listdir(image_path):
  29. ref_image_path = os.path.join(image_path,path)
  30. ref_image = cv2.imread(ref_image_path, 0)
  31. if random.random() < 0.5: # 以0.5的概率决定是否翻转
  32. rotated_image = cv2.rotate(ref_image, cv2.ROTATE_180)
  33. rotate_result_images.append((rotated_image, path))
  34. else:
  35. print(f"Skipped")
  36. print(len(rotate_result_images))
  37. return rotate_result_images
  38. # 配准操作,计算模板图与查询图配准点的距离差异(翻转后的距离大于未翻转的)
  39. def match_image(ref_img, query_img):
  40. superglue_matcher = Matcher(
  41. {
  42. "superpoint": {
  43. "input_shape": (-1, -1),
  44. "keypoint_threshold": 0.003,
  45. },
  46. "superglue": {
  47. "match_threshold": 0.5,
  48. },
  49. "use_gpu": True,
  50. }
  51. )
  52. query_kpts, ref_kpts, _, _, matches = superglue_matcher.match(ref_img, query_img)
  53. matched_query_kpts = [query_kpts[m.queryIdx].pt for m in matches]
  54. matched_ref_kpts = [ref_kpts[m.trainIdx].pt for m in matches]
  55. diff_query_ref = np.array(matched_ref_kpts) - np.array(matched_query_kpts)
  56. print(diff_query_ref)
  57. if len(matches)!=0:
  58. diff_query_ref = np.linalg.norm(diff_query_ref, axis=1, ord=2)
  59. diff_query_ref = np.sqrt(diff_query_ref)
  60. diff_query_ref = np.mean(diff_query_ref)
  61. else:
  62. diff_query_ref = np.inf
  63. return diff_query_ref, matches # 返回差异值,用于后续做比较
  64. def registration(ref_image_dir, ref_selection_dir):
  65. # 计算recall等需要的的参数
  66. true_positives = 0
  67. false_positives = 0
  68. false_negatives = 0
  69. true_negatives = 0
  70. # 返回预测错误的样本
  71. positive_false = []
  72. negative_false = []
  73. # 循环遍历模板图
  74. for file_name in os.listdir(ref_selection_dir):
  75. ref_image_path = os.path.join(ref_selection_dir,file_name)
  76. ref_image = cv2.imread(ref_image_path, 0)
  77. save_path = './data/region_rotate/vague/A' # 保存预测错误的原图片,可不用
  78. # 循环遍历正样本(这里正样本是查询图,因为查询图都是没有翻转的,所以可当作正样本)
  79. # 这是在正样本中预测
  80. for file in os.listdir(ref_image_dir):
  81. query_image_path = os.path.join(ref_image_dir, file)
  82. query_image = cv2.imread(query_image_path, 0)
  83. query_image_rotate = cv2.rotate(query_image, cv2.ROTATE_180)
  84. # cv2.imwrite('1111111111.jpg', query_image_rotate)
  85. diff1, matches1 = match_image(ref_image, query_image) # 计算模板图与查询图的配准点差异
  86. diff2, matches2 = match_image(ref_image,query_image_rotate) # 计算模板图与翻转180度图的配准点差异
  87. # if (len(matches1) > len(matches2)) or (diff1<diff2) :
  88. if diff1 < diff2:
  89. flag = True
  90. else:
  91. flag = False
  92. print(flag)
  93. if flag == True:
  94. true_positives += 1
  95. else:
  96. false_negatives += 1
  97. positive_false.append((file,file_name))
  98. cv2.imwrite(os.path.join(save_path, f'{file[:-4]}.jpg'),query_image)
  99. # 这是在负样本中预测
  100. # 因为缺少负样本,所以用rotate_image方法创造负样本
  101. nagetive_image = rotate_image(ref_image_dir)
  102. for i, item in enumerate(nagetive_image):
  103. file, path = item
  104. # query_image = cv2.cvtColor(file,cv2.COLOR_BGR2GRAY)
  105. query_image = file
  106. query_image_rotate = cv2.rotate(query_image, cv2.ROTATE_180)
  107. diff1, matches1 = match_image(ref_image, query_image)
  108. diff2, matches2 = match_image(ref_image, query_image_rotate)
  109. # if (len(matches1) > len(matches2)) or (diff1<diff2):
  110. if diff1 < diff2:
  111. flag = True
  112. else:
  113. flag = False
  114. print(flag)
  115. if flag == False:
  116. true_negatives += 1
  117. else:
  118. false_positives += 1
  119. negative_false.append((path, file_name))
  120. cv2.imwrite(os.path.join(save_path, f'{path[:-4]}.jpg'),query_image)
  121. Accurary = (true_positives + true_negatives) / (true_positives + true_negatives + false_positives + false_negatives)
  122. Precision = true_negatives / (true_negatives + false_negatives)
  123. Recall = true_negatives / (true_negatives + false_positives)
  124. F1_score = 2 * (Precision * Recall) / (Precision + Recall)
  125. print(positive_false)
  126. print(negative_false)
  127. # print(file_name)
  128. print(f"Accurary:{Accurary: .4f}")
  129. print(f"Precision: {Precision:.4f}")
  130. print(f"Recall: {Recall:.4f}")
  131. print(f"F1 Score: {F1_score:.4f}")
  132. if __name__ == "__main__":
  133. ref_image_path = './data/region_rotate/vague/E2'
  134. ref_selection_dir = './data/region_rotate/vague/E22'
  135. # start_time = time.time()
  136. registration(ref_image_path, ref_selection_dir)
  137. # end_time = time.time()
  138. # elapsed_time = end_time - start_time
  139. # print(f"程序用时: {elapsed_time:.2f} 秒")