generate_test_images.py 8.6 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177
  1. #!/usr/bin/env python
  2. # -*- coding: utf-8 -*-
  3. # @Time : 2024/6/18 0018 下午 3:49
  4. # @Author : liudan
  5. # @File : generate_test_images.py
  6. # @Software: pycharm
  7. import os
  8. import cv2
  9. import json
  10. import numpy as np
  11. from loguru import logger
  12. from superpoint_superglue_deployment import Matcher
  13. class GenerateTestImage:
  14. def __init__(self, template_img_dir, template_mask_json_dir, template_label_json_dir, save_dir):
  15. # self.template_img_paths = ['./data/ref/image165214-001.jpg','./data/ref/image165214-019.jpg', './data/ref/image165214-037.jpg']
  16. # self.template_mask_json_paths = ['./data/json/image165214-001_mask.json','./data/json/image165214-019_mask.json', './data/json/image165214-037_mask.json']
  17. # self.template_label_json_paths = ['./data/json/image165214-001.json','./data/json/image165214-019.json', './data/json/image165214-037.json']
  18. # self.save_dir="./data/generated_mask_test_images"
  19. # if not os.path.exists(self.save_dir):
  20. # os.mkdir(self.save_dir)
  21. self.template_img_dir = template_img_dir
  22. self.template_mask_json_dir = template_mask_json_dir
  23. self.template_label_json_dir = template_label_json_dir
  24. self.save_dir = save_dir
  25. if not os.path.exists(self.save_dir):
  26. os.mkdir(self.save_dir)
  27. self.template_img_paths = sorted([os.path.join(self.template_img_dir, f) for f in os.listdir(self.template_img_dir) if
  28. f.endswith('.jpg') or f.endswith('.JPG') ])
  29. self.template_mask_json_paths = sorted([os.path.join(self.template_mask_json_dir, f) for f in os.listdir(self.template_mask_json_dir) if
  30. f.endswith('.json')])
  31. self.template_label_json_paths = sorted([os.path.join(self.template_label_json_dir, f) for f in os.listdir(self.template_label_json_dir) if
  32. f.endswith('.json')])
  33. def generate_test_image(self, query_image_dir):
  34. for file_name in os.listdir(query_image_dir):
  35. query_image_path = os.path.join(query_image_dir, file_name)
  36. best_matched_index = self.find_best_matched_template_image(query_image_path)
  37. best_matched_template_img_path = self.template_img_paths[best_matched_index]
  38. best_matched_template_mask_json_path = self.template_mask_json_paths[best_matched_index]
  39. best_matched_template_label_json_path = self.template_label_json_paths[best_matched_index]
  40. query_image = cv2.imread(query_image_path)
  41. ref_image = cv2.imread(best_matched_template_img_path)
  42. height, width = ref_image.shape[:2]
  43. scale_factor = 0.37
  44. query_image_resize = cv2.resize(query_image, dsize=None, fx=scale_factor, fy=scale_factor)
  45. ref_image_resize = cv2.resize(ref_image, dsize=None, fx=scale_factor, fy=scale_factor)
  46. query_gray = cv2.cvtColor(query_image_resize, cv2.COLOR_BGR2GRAY)
  47. ref_gray = cv2.cvtColor(ref_image_resize, cv2.COLOR_BGR2GRAY)
  48. if os.path.exists(best_matched_template_mask_json_path):
  49. ref_gray = self.process_image_with_mask(ref_image, scale_factor, best_matched_template_mask_json_path,
  50. ref_gray)
  51. superglue_matcher = Matcher(
  52. {
  53. "superpoint": {
  54. "input_shape": (-1, -1),
  55. "keypoint_threshold": 0.003,
  56. },
  57. "superglue": {
  58. "match_threshold": 0.5,
  59. },
  60. "use_gpu": True,
  61. }
  62. )
  63. query_kpts, ref_kpts, _, _, matches = superglue_matcher.match(query_gray, ref_gray)
  64. try:
  65. if len(matches) < 4:
  66. raise ValueError("Not enough matches to compute homography.")
  67. M, mask = cv2.findHomography(
  68. np.float64([query_kpts[m.queryIdx].pt for m in matches]).reshape(-1, 1, 2),
  69. np.float64([ref_kpts[m.trainIdx].pt for m in matches]).reshape(-1, 1, 2),
  70. method=cv2.USAC_MAGSAC,
  71. ransacReprojThreshold=5.0,
  72. maxIters=10000,
  73. confidence=0.95,
  74. )
  75. if M is None:
  76. raise ValueError("Unable to compute homography.")
  77. wrap_image = cv2.warpPerspective(query_image_resize, M,
  78. (ref_image_resize.shape[1], ref_image_resize.shape[0]))
  79. wrap_image = cv2.resize(wrap_image, (ref_image.shape[1], ref_image.shape[0]))
  80. with open(best_matched_template_label_json_path) as json_file:
  81. data = json.load(json_file)
  82. for shape in data['shapes']:
  83. if 'points' in shape:
  84. shape['points'] = [[int(round(x)), int(round(y))] for x, y in shape['points']]
  85. x1, y1 = shape['points'][0]
  86. x2, y2 = shape['points'][1]
  87. region = wrap_image[y1:y2, x1:x2]
  88. filename = f'{file_name[:-4]}_{shape["label"]}.jpg'
  89. save_subdir =os.path.join(self.save_dir, "test_img", shape["label"])
  90. if not os.path.exists(save_subdir):
  91. os.makedirs(save_subdir)
  92. black_pixels = np.sum(cv2.cvtColor(region, cv2.COLOR_BGR2GRAY) == 0)
  93. if black_pixels == 0:
  94. cv2.imwrite(os.path.join(save_subdir, filename), region)
  95. except ValueError as e:
  96. print(e)
  97. except Exception as e:
  98. print(f"An error occurred: {e}")
  99. def process_image_with_mask(self, ref_image, scale_factor, json_ref_path, ref_gray):
  100. with open(json_ref_path, 'r') as f:
  101. data = json.load(f)
  102. shapes = data['shapes']
  103. shape = shapes[0]
  104. if shape['shape_type'] == 'polygon':
  105. coords = [(int(round(x * scale_factor)), int(round(y * scale_factor))) for x, y in shape['points']]
  106. else:
  107. coords = []
  108. mask = np.zeros(ref_gray.shape, dtype=np.uint8) * 255
  109. pts = np.array(coords, np.int32)
  110. cv2.fillPoly(mask, [pts], 1)
  111. ref_gray_masked = cv2.bitwise_and(ref_gray, ref_gray, mask=mask)
  112. return ref_gray_masked
  113. def find_best_matched_template_image(self, query_img_path):
  114. number_match_pts_list = []
  115. for index, item in enumerate(self.template_img_paths):
  116. ref_image_path = item
  117. query_image = cv2.imread(query_img_path)
  118. ref_image = cv2.imread(ref_image_path)
  119. height, width = ref_image.shape[:2]
  120. scale_factor = 0.37
  121. query_image_resize = cv2.resize(query_image, dsize=None, fx=scale_factor, fy=scale_factor)
  122. ref_image_resize = cv2.resize(ref_image, dsize=None, fx=scale_factor, fy=scale_factor)
  123. query_gray = cv2.cvtColor(query_image_resize, cv2.COLOR_BGR2GRAY)
  124. ref_gray = cv2.cvtColor(ref_image_resize, cv2.COLOR_BGR2GRAY)
  125. if os.path.exists( self.template_mask_json_paths[index]):
  126. ref_gray = self.process_image_with_mask(ref_image, scale_factor,self.template_mask_json_paths[index],
  127. ref_gray)
  128. superglue_matcher = Matcher(
  129. {
  130. "superpoint": {
  131. "input_shape": (-1, -1),
  132. "keypoint_threshold": 0.003,
  133. },
  134. "superglue": {
  135. "match_threshold": 0.5,
  136. },
  137. "use_gpu": True,
  138. }
  139. )
  140. query_kpts, ref_kpts, _, _, matches = superglue_matcher.match(query_gray, ref_gray)
  141. number_match_pts_list.append(len(matches))
  142. best_match_index = number_match_pts_list.index(max(number_match_pts_list))
  143. return best_match_index
  144. if __name__ =="__main__":
  145. template_img_dir = './data/ref/'
  146. template_mask_json_dir = './data/mask_json/'
  147. template_label_json_dir = './data/json/'
  148. save_dir = "./data/D_result"
  149. query_image_dir = "./data/query"
  150. # match_path = './data/vague_match'
  151. gti =GenerateTestImage(template_img_dir, template_mask_json_dir, template_label_json_dir, save_dir)
  152. gti.generate_test_image(query_image_dir)