rorate_image_detection.py 4.8 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155
  1. #!/usr/bin/env python
  2. # -*- coding: utf-8 -*-
  3. # @Time : 2024/6/19 0019 下午 4:35
  4. # @Author : liudan
  5. # @File : rorate_image_detection.py
  6. # @Software: pycharm
  7. import time
  8. import cv2 as cv
  9. import cv2
  10. import numpy as np
  11. from loguru import logger
  12. import os
  13. from superpoint_superglue_deployment import Matcher
  14. from datetime import datetime
  15. import random
  16. import cv2
  17. import numpy as np
  18. import random
  19. def rotate_image(image_path):
  20. rotate_result_images = []
  21. for path in os.listdir(image_path):
  22. ref_image_path = os.path.join(image_path,path)
  23. ref_image = cv2.imread(ref_image_path, 0)
  24. if random.random() < 0.5:
  25. rotated_image = cv2.rotate(ref_image, cv2.ROTATE_180)
  26. rotate_result_images.append((rotated_image, path))
  27. else:
  28. print(f"Skipped")
  29. print(len(rotate_result_images))
  30. return rotate_result_images
  31. def match_image(ref_img, query_img):
  32. superglue_matcher = Matcher(
  33. {
  34. "superpoint": {
  35. "input_shape": (-1, -1),
  36. "keypoint_threshold": 0.003,
  37. },
  38. "superglue": {
  39. "match_threshold": 0.5,
  40. },
  41. "use_gpu": True,
  42. }
  43. )
  44. _, _, _, _, matches = superglue_matcher.match(ref_img, query_img)
  45. return matches
  46. def registration(ref_image_dir, ref_selection_dir):
  47. # filename = os.listdir(ref_image_dir)
  48. # image_files = [f for f in filename if f.endswith(('.jpg', '.jpeg', '.png', '.JPG'))]
  49. # image_file = random.choice(image_files)
  50. # ref_image_path = os.path.join(ref_image_dir, image_file)
  51. # ref_image = cv2.imread(ref_image_path, 0)
  52. true_positives = 0
  53. false_positives = 0
  54. false_negatives = 0
  55. true_negatives = 0
  56. positive_false = []
  57. negative_false = []
  58. for file_name in os.listdir(ref_selection_dir):
  59. ref_image_path = os.path.join(ref_selection_dir,file_name)
  60. ref_image = cv2.imread(ref_image_path, 0)
  61. # positive_false = []
  62. save_path = './data/region_rotate/A'
  63. for file in os.listdir(ref_image_dir):
  64. query_image_path = os.path.join(ref_image_dir, file)
  65. query_image = cv2.imread(query_image_path, 0)
  66. query_image_rotate = cv2.rotate(query_image, cv2.ROTATE_180)
  67. # cv2.imwrite('1111111111.jpg', query_image_rotate)
  68. matches1 = match_image(ref_image, query_image)
  69. matches2 = match_image(ref_image,query_image_rotate)
  70. if len(matches1) > len(matches2):
  71. flag = True
  72. else:
  73. flag = False
  74. print(flag)
  75. if flag == True:
  76. true_positives += 1
  77. else:
  78. false_negatives += 1
  79. positive_false.append((file,file_name))
  80. cv2.imwrite(os.path.join(save_path, f'{file[:-4]}.jpg'),query_image)
  81. # print(positive_false)
  82. # print(file_name)
  83. # negative_false = []
  84. nagetive_image = rotate_image(ref_image_dir)
  85. for i, item in enumerate(nagetive_image):
  86. file, path = item
  87. # query_image = cv2.cvtColor(file,cv2.COLOR_BGR2GRAY)
  88. query_image = file
  89. query_image_rotate = cv2.rotate(query_image, cv2.ROTATE_180)
  90. matches1 = match_image(ref_image, query_image)
  91. matches2 = match_image(ref_image, query_image_rotate)
  92. if len(matches1) > len(matches2):
  93. flag = True
  94. else:
  95. flag = False
  96. print(flag)
  97. if flag == False:
  98. true_negatives += 1
  99. else:
  100. false_positives += 1
  101. negative_false.append((path, file_name))
  102. cv2.imwrite(os.path.join(save_path, f'{path[:-4]}.jpg'),query_image)
  103. # print(negative_false)
  104. # print(file_name)
  105. Accurary = (true_positives + true_negatives) / (true_positives + true_negatives + false_positives + false_negatives)
  106. Precision = true_negatives / (true_negatives + false_negatives)
  107. Recall = true_negatives / (true_negatives + false_positives)
  108. F1_score = 2 * (Precision * Recall) / (Precision + Recall)
  109. print(positive_false)
  110. print(negative_false)
  111. # print(file_name)
  112. print(f"Accurary:{Accurary: .4f}")
  113. print(f"Precision: {Precision:.4f}")
  114. print(f"Recall: {Recall:.4f}")
  115. print(f"F1 Score: {F1_score:.4f}")
  116. if __name__ == "__main__":
  117. ref_image_path = './data/region_rotate/E1'
  118. ref_selection_dir = './data/region_rotate/E'
  119. # start_time = time.time()
  120. registration(ref_image_path, ref_selection_dir)
  121. # end_time = time.time()
  122. # elapsed_time = end_time - start_time
  123. # print(f"程序用时: {elapsed_time:.2f} 秒")