image_registration.py 7.2 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156
  1. #!/usr/bin/env python
  2. # -*- coding: utf-8 -*-
  3. # @Time : 2024/6/26 0026 下午 3:21
  4. # @Author : liudan
  5. # @File : image_registration.py
  6. # @Software: pycharm
  7. import os
  8. import cv2
  9. import json
  10. import numpy as np
  11. from loguru import logger
  12. from superpoint_superglue_deployment import Matcher
  13. class GenerateTestImage:
  14. def __init__(self, template_img_dir, template_mask_json_dir, save_dir):
  15. self.template_img_dir = template_img_dir
  16. self.template_mask_json_dir = template_mask_json_dir
  17. self.save_dir = save_dir
  18. if not os.path.exists(self.save_dir):
  19. os.mkdir(self.save_dir)
  20. self.template_img_paths = sorted([os.path.join(self.template_img_dir, f) for f in os.listdir(self.template_img_dir) if
  21. f.endswith('.jpg') or f.endswith('.JPG') ])
  22. self.template_mask_json_paths = sorted([os.path.join(self.template_mask_json_dir, f) for f in os.listdir(self.template_mask_json_dir) if
  23. f.endswith('.json')])
  24. def generate_test_image(self, query_image_dir):
  25. for file_name in os.listdir(query_image_dir):
  26. query_image_path = os.path.join(query_image_dir, file_name)
  27. best_matched_index = self.find_best_matched_template_image(query_image_path)
  28. best_matched_template_img_path = self.template_img_paths[best_matched_index]
  29. best_matched_template_mask_json_path = self.template_mask_json_paths[best_matched_index]
  30. query_image = cv2.imread(query_image_path)
  31. ref_image = cv2.imread(best_matched_template_img_path)
  32. height, width = ref_image.shape[:2]
  33. scale_factor = 0.37
  34. query_image_resize = cv2.resize(query_image, dsize=None, fx=scale_factor, fy=scale_factor)
  35. ref_image_resize = cv2.resize(ref_image, dsize=None, fx=scale_factor, fy=scale_factor)
  36. query_gray = cv2.cvtColor(query_image_resize, cv2.COLOR_BGR2GRAY)
  37. ref_gray = cv2.cvtColor(ref_image_resize, cv2.COLOR_BGR2GRAY)
  38. if os.path.exists(best_matched_template_mask_json_path):
  39. ref_gray = self.process_image_with_mask(ref_image, scale_factor, best_matched_template_mask_json_path,
  40. ref_gray)
  41. superglue_matcher = Matcher(
  42. {
  43. "superpoint": {
  44. "input_shape": (-1, -1),
  45. "keypoint_threshold": 0.003,
  46. },
  47. "superglue": {
  48. "match_threshold": 0.5,
  49. },
  50. "use_gpu": True,
  51. }
  52. )
  53. query_kpts, ref_kpts, _, _, matches = superglue_matcher.match(query_gray, ref_gray)
  54. try:
  55. if len(matches) < 4:
  56. raise ValueError("Not enough matches to compute homography.")
  57. M, mask = cv2.findHomography(
  58. np.float64([query_kpts[m.queryIdx].pt for m in matches]).reshape(-1, 1, 2),
  59. np.float64([ref_kpts[m.trainIdx].pt for m in matches]).reshape(-1, 1, 2),
  60. method=cv2.USAC_MAGSAC,
  61. ransacReprojThreshold=5.0,
  62. maxIters=10000,
  63. confidence=0.95,
  64. )
  65. if M is None:
  66. raise ValueError("Unable to compute homography.")
  67. logger.info(f"number of inliers: {mask.sum()}")
  68. matches = np.array(matches)[np.all(mask > 0, axis=1)]
  69. matches = sorted(matches, key=lambda match: match.distance)
  70. matched_image = cv2.drawMatches(query_image_resize,query_kpts,ref_image_resize,ref_kpts,matches[:50],None,flags=2)
  71. wrap_image = cv2.warpPerspective(query_image_resize, M,(ref_image_resize.shape[1], ref_image_resize.shape[0]))
  72. wrap_image = cv2.resize(wrap_image, (ref_image.shape[1], ref_image.shape[0]))
  73. sub_image = cv2.subtract(ref_image, wrap_image)
  74. cv2.imwrite(os.path.join(self.save_dir, f'match_{file_name[:-4]}.jpg'), matched_image)
  75. cv2.imwrite(os.path.join(self.save_dir, f'wrap_{file_name[:-4]}.jpg'), wrap_image)
  76. cv2.imwrite(os.path.join(self.save_dir, f'sub_{file_name[:-4]}.jpg'), sub_image)
  77. except ValueError as e:
  78. print(e)
  79. except Exception as e:
  80. print(f"An error occurred: {e}")
  81. def process_image_with_mask(self, ref_image, scale_factor, json_ref_path, ref_gray):
  82. with open(json_ref_path, 'r') as f:
  83. data = json.load(f)
  84. shapes = data['shapes']
  85. shape = shapes[0]
  86. if shape['shape_type'] == 'polygon':
  87. coords = [(int(round(x * scale_factor)), int(round(y * scale_factor))) for x, y in shape['points']]
  88. else:
  89. coords = []
  90. mask = np.zeros(ref_gray.shape, dtype=np.uint8) * 255
  91. pts = np.array(coords, np.int32)
  92. cv2.fillPoly(mask, [pts], 1)
  93. ref_gray_masked = cv2.bitwise_and(ref_gray, ref_gray, mask=mask)
  94. return ref_gray_masked
  95. def find_best_matched_template_image(self, query_img_path):
  96. number_match_pts_list = []
  97. for index, item in enumerate(self.template_img_paths):
  98. ref_image_path = item
  99. query_image = cv2.imread(query_img_path)
  100. ref_image = cv2.imread(ref_image_path)
  101. height, width = ref_image.shape[:2]
  102. scale_factor = 0.37
  103. query_image_resize = cv2.resize(query_image, dsize=None, fx=scale_factor, fy=scale_factor)
  104. ref_image_resize = cv2.resize(ref_image, dsize=None, fx=scale_factor, fy=scale_factor)
  105. query_gray = cv2.cvtColor(query_image_resize, cv2.COLOR_BGR2GRAY)
  106. ref_gray = cv2.cvtColor(ref_image_resize, cv2.COLOR_BGR2GRAY)
  107. if os.path.exists( self.template_mask_json_paths[index]):
  108. ref_gray = self.process_image_with_mask(ref_image, scale_factor,self.template_mask_json_paths[index],
  109. ref_gray)
  110. superglue_matcher = Matcher(
  111. {
  112. "superpoint": {
  113. "input_shape": (-1, -1),
  114. "keypoint_threshold": 0.003,
  115. },
  116. "superglue": {
  117. "match_threshold": 0.5,
  118. },
  119. "use_gpu": True,
  120. }
  121. )
  122. query_kpts, ref_kpts, _, _, matches = superglue_matcher.match(query_gray, ref_gray)
  123. number_match_pts_list.append(len(matches))
  124. best_match_index = number_match_pts_list.index(max(number_match_pts_list))
  125. return best_match_index
  126. if __name__ =="__main__":
  127. template_img_dir = './data/best_registration/template_image/'
  128. template_mask_json_dir = './data/best_registration/mask_json/'
  129. save_dir = "./data/best_registration/result"
  130. query_image_dir = "./data/best_registration/query_image"
  131. gti =GenerateTestImage(template_img_dir, template_mask_json_dir, save_dir)
  132. gti.generate_test_image(query_image_dir)