generate_test_images.py 8.5 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175
#!/usr/bin/env python
# -*- coding: utf-8 -*-
# @Time : 2024/6/18 3:49 PM
# @Author : liudan
# @File : generate_test_images.py
# @Software: pycharm
import argparse
import json
import os

import cv2
import numpy as np

from superpoint_superglue_deployment import Matcher
  12. class GenerateTestImage:
  13. def __init__(self, template_img_dir, template_mask_json_dir, template_label_json_dir, save_dir):
  14. # self.template_img_paths = ['./data/ref/image165214-001.jpg','./data/ref/image165214-019.jpg', './data/ref/image165214-037.jpg']
  15. # self.template_mask_json_paths = ['./data/json/image165214-001_mask.json','./data/json/image165214-019_mask.json', './data/json/image165214-037_mask.json']
  16. # self.template_label_json_paths = ['./data/json/image165214-001.json','./data/json/image165214-019.json', './data/json/image165214-037.json']
  17. # self.save_dir="./data/generated_mask_test_images"
  18. # if not os.path.exists(self.save_dir):
  19. # os.mkdir(self.save_dir)
  20. self.template_img_dir = template_img_dir
  21. self.template_mask_json_dir = template_mask_json_dir
  22. self.template_label_json_dir = template_label_json_dir
  23. self.save_dir = save_dir
  24. if not os.path.exists(self.save_dir):
  25. os.mkdir(self.save_dir)
  26. self.template_img_paths = [os.path.join(self.template_img_dir, f) for f in os.listdir(self.template_img_dir) if
  27. f.endswith('.jpg') or f.endswith('.JPG') ]
  28. self.template_mask_json_paths = [os.path.join(self.template_mask_json_dir, f) for f in os.listdir(self.template_mask_json_dir) if
  29. f.endswith('.json')]
  30. self.template_label_json_paths = [os.path.join(self.template_label_json_dir, f) for f in os.listdir(self.template_label_json_dir) if
  31. f.endswith('.json')]
  32. def generate_test_image(self, query_image_dir):
  33. for file_name in os.listdir(query_image_dir):
  34. query_image_path = os.path.join(query_image_dir, file_name)
  35. best_matched_index = self.find_best_matched_template_image(query_image_path)
  36. best_matched_template_img_path = self.template_img_paths[best_matched_index]
  37. best_matched_template_mask_json_path = self.template_mask_json_paths[best_matched_index]
  38. best_matched_template_label_json_path = self.template_label_json_paths[best_matched_index]
  39. query_image = cv2.imread(query_image_path)
  40. ref_image = cv2.imread(best_matched_template_img_path)
  41. height, width = ref_image.shape[:2]
  42. scale_factor = 0.37
  43. query_image_resize = cv2.resize(query_image, dsize=None, fx=scale_factor, fy=scale_factor)
  44. ref_image_resize = cv2.resize(ref_image, dsize=None, fx=scale_factor, fy=scale_factor)
  45. query_gray = cv2.cvtColor(query_image_resize, cv2.COLOR_BGR2GRAY)
  46. ref_gray = cv2.cvtColor(ref_image_resize, cv2.COLOR_BGR2GRAY)
  47. if os.path.exists(best_matched_template_mask_json_path):
  48. ref_gray = self.process_image_with_mask(ref_image, scale_factor, best_matched_template_mask_json_path,
  49. ref_gray)
  50. superglue_matcher = Matcher(
  51. {
  52. "superpoint": {
  53. "input_shape": (-1, -1),
  54. "keypoint_threshold": 0.003,
  55. },
  56. "superglue": {
  57. "match_threshold": 0.5,
  58. },
  59. "use_gpu": True,
  60. }
  61. )
  62. query_kpts, ref_kpts, _, _, matches = superglue_matcher.match(query_gray, ref_gray)
  63. try:
  64. if len(matches) < 4:
  65. raise ValueError("Not enough matches to compute homography.")
  66. M, mask = cv2.findHomography(
  67. np.float64([query_kpts[m.queryIdx].pt for m in matches]).reshape(-1, 1, 2),
  68. np.float64([ref_kpts[m.trainIdx].pt for m in matches]).reshape(-1, 1, 2),
  69. method=cv2.USAC_MAGSAC,
  70. ransacReprojThreshold=5.0,
  71. maxIters=10000,
  72. confidence=0.95,
  73. )
  74. if M is None:
  75. raise ValueError("Unable to compute homography.")
  76. wrap_image = cv2.warpPerspective(query_image_resize, M,
  77. (ref_image_resize.shape[1], ref_image_resize.shape[0]))
  78. wrap_image = cv2.resize(wrap_image, (ref_image.shape[1], ref_image.shape[0]))
  79. with open(best_matched_template_label_json_path) as json_file:
  80. data = json.load(json_file)
  81. for shape in data['shapes']:
  82. if 'points' in shape:
  83. shape['points'] = [[int(round(x)), int(round(y))] for x, y in shape['points']]
  84. x1, y1 = shape['points'][0]
  85. x2, y2 = shape['points'][1]
  86. region = wrap_image[y1:y2, x1:x2]
  87. filename = f'{file_name[:-4]}_{shape["label"]}.jpg'
  88. save_subdir =os.path.join(self.save_dir, "test_img", shape["label"])
  89. if not os.path.exists(save_subdir):
  90. os.makedirs(save_subdir)
  91. black_pixels = np.sum(cv2.cvtColor(region, cv2.COLOR_BGR2GRAY) == 0)
  92. if black_pixels == 0:
  93. cv2.imwrite(os.path.join(save_subdir, filename), region)
  94. except ValueError as e:
  95. print(e)
  96. except Exception as e:
  97. print(f"An error occurred: {e}")
  98. def process_image_with_mask(self, ref_image, scale_factor, json_ref_path, ref_gray):
  99. with open(json_ref_path, 'r') as f:
  100. data = json.load(f)
  101. shapes = data['shapes']
  102. shape = shapes[0]
  103. if shape['shape_type'] == 'polygon':
  104. coords = [(int(round(x * scale_factor)), int(round(y * scale_factor))) for x, y in shape['points']]
  105. else:
  106. coords = []
  107. mask = np.zeros(ref_gray.shape, dtype=np.uint8) * 255
  108. pts = np.array(coords, np.int32)
  109. cv2.fillPoly(mask, [pts], 1)
  110. ref_gray_masked = cv2.bitwise_and(ref_gray, ref_gray, mask=mask)
  111. return ref_gray_masked
  112. def find_best_matched_template_image(self, query_img_path):
  113. number_match_pts_list = []
  114. for index, item in enumerate(self.template_img_paths):
  115. ref_image_path = item
  116. query_image = cv2.imread(query_img_path)
  117. ref_image = cv2.imread(ref_image_path)
  118. height, width = ref_image.shape[:2]
  119. scale_factor = 0.37
  120. query_image_resize = cv2.resize(query_image, dsize=None, fx=scale_factor, fy=scale_factor)
  121. ref_image_resize = cv2.resize(ref_image, dsize=None, fx=scale_factor, fy=scale_factor)
  122. query_gray = cv2.cvtColor(query_image_resize, cv2.COLOR_BGR2GRAY)
  123. ref_gray = cv2.cvtColor(ref_image_resize, cv2.COLOR_BGR2GRAY)
  124. if os.path.exists( self.template_mask_json_paths[index]):
  125. ref_gray = self.process_image_with_mask(ref_image, scale_factor,self.template_mask_json_paths[index],
  126. ref_gray)
  127. superglue_matcher = Matcher(
  128. {
  129. "superpoint": {
  130. "input_shape": (-1, -1),
  131. "keypoint_threshold": 0.003,
  132. },
  133. "superglue": {
  134. "match_threshold": 0.5,
  135. },
  136. "use_gpu": True,
  137. }
  138. )
  139. query_kpts, ref_kpts, _, _, matches = superglue_matcher.match(query_gray, ref_gray)
  140. number_match_pts_list.append(len(matches))
  141. best_match_index = number_match_pts_list.index(max(number_match_pts_list))
  142. return best_match_index
  143. if __name__ =="__main__":
  144. template_img_dir = './data/REF'
  145. template_mask_json_dir = './data/JSON_MASK/'
  146. template_label_json_dir = './data/JSON/'
  147. save_dir = "./data/generated"
  148. query_image_dir = "/data2/liudan/superpoint_superglue_deployment/data/QUERY"
  149. gti =GenerateTestImage(template_img_dir, template_mask_json_dir, template_label_json_dir, save_dir)
  150. gti.generate_test_image(query_image_dir)