test.py 11 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336
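"""
# File : test.py
# Time :2024-06-11 10:15
# Author :FEANGYANG
# version :python 3.7
# Contact :1071082183@qq.com
# Description: Register a query image onto a reference image (SuperPoint/SuperGlue
#   matching + MAGSAC homography), save the match/warp/difference images, and group
#   the labelled boxes from a LabelMe-style JSON into weldment and check classes.
"""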
  9. import json
  10. import itertools
  11. import random
  12. import cv2
  13. from PIL import Image, ImageDraw
  14. import os
  15. from superpoint_superglue_deployment import Matcher
  16. import numpy as np
  17. from loguru import logger
  18. import cv2 as cv
  19. # 定义大类
  20. class StructureClass:
  21. def __init__(self, ref_image_path, query_image_path, json_path, save_image_path, json_mask_path):
  22. self.weldmentclasses = []
  23. self.ref_image = []
  24. self.query_image = []
  25. self.boxes_xy_label = {}
  26. self.scale_factor = 0.37
  27. self.read_ref_image(ref_image_path)
  28. self.read_query_image(query_image_path)
  29. self.read_json(json_path)
  30. self.registration_demo(save_image_path, json_mask_path)
  31. def read_ref_image(self, path):
  32. self.ref_image = self.process_image_data(path)
  33. def read_query_image(self, path):
  34. self.query_image = self.process_image_data(path)
  35. def replace_query_image(self, wrap_image):
  36. self.query_image = wrap_image
  37. def read_json(self, json_path):
  38. with open(json_path, 'r') as f:
  39. data = json.load(f)
  40. for shape in data['shapes']:
  41. if 'points' in shape:
  42. shape['points'] = [[int(round(x)), int(round(y))] for x, y in shape['points']]
  43. x1, y1 = shape['points'][0]
  44. x2, y2 = shape['points'][1]
  45. label = shape['label']
  46. self.boxes_xy_label[label] = [x1, y1, x2, y2]
  47. def process_image_data(self, data):
  48. extensions = ['.PNG', '.png', '.jpg', '.jpeg', '.JPG', '.JPEG']
  49. if any(data.endswith(ext) for ext in extensions):
  50. if not os.path.exists(data):
  51. raise FileNotFoundError(f"Image file {data} not found")
  52. image = cv2.imread(data)
  53. return image
  54. else:
  55. if isinstance(data, np.ndarray):
  56. return data
  57. else:
  58. raise FileNotFoundError(f"Image file {data} not found")
  59. def registration_demo(self,save_image_path, json_mask_path):
  60. height, width = self.ref_image.shape[:2]
  61. ref_image_resize = cv2.resize(self.ref_image, dsize=None, fx=self.scale_factor, fy=self.scale_factor)
  62. query_image_resize = cv2.resize(self.query_image, dsize=None, fx=self.scale_factor, fy=self.scale_factor)
  63. ref_gray = cv2.cvtColor(ref_image_resize, cv2.COLOR_BGR2GRAY)
  64. query_gray = cv2.cvtColor(query_image_resize, cv2.COLOR_BGR2GRAY)
  65. if os.path.exists(json_path):
  66. with open(json_mask_path, 'r') as f:
  67. data = json.load(f)
  68. shapes = data['shapes']
  69. shape = shapes[0]
  70. if shape['shape_type'] == 'polygon':
  71. coords = [(int(round(x * self.scale_factor)), int(round(y * self.scale_factor))) for x, y in
  72. shape['points']]
  73. else:
  74. coords = []
  75. mask = np.zeros(ref_gray.shape, dtype=np.uint8) * 255
  76. pts = np.array(coords, np.int32)
  77. cv2.fillPoly(mask, [pts], 1)
  78. ref_gray = cv2.bitwise_and(ref_gray, ref_gray, mask=mask)
  79. superglue_matcher = Matcher(
  80. {
  81. "superpoint": {
  82. "input_shape": (-1, -1),
  83. "keypoint_threshold": 0.003,
  84. },
  85. "superglue": {
  86. "match_threshold": 0.5,
  87. },
  88. "use_gpu": True,
  89. }
  90. )
  91. query_kpts, ref_kpts, _, _, matches = superglue_matcher.match(query_gray, ref_gray)
  92. M, mask = cv2.findHomography(
  93. np.float64([query_kpts[m.queryIdx].pt for m in matches]).reshape(-1, 1, 2),
  94. np.float64([ref_kpts[m.trainIdx].pt for m in matches]).reshape(-1, 1, 2),
  95. method=cv2.USAC_MAGSAC,
  96. ransacReprojThreshold=5.0,
  97. maxIters=10000,
  98. confidence=0.95,
  99. )
  100. logger.info(f"number of inliers: {mask.sum()}")
  101. matches = np.array(matches)[np.all(mask > 0, axis=1)]
  102. matches = sorted(matches, key=lambda match: match.distance)
  103. matched_image = cv2.drawMatches(
  104. query_image_resize,
  105. query_kpts,
  106. ref_image_resize,
  107. ref_kpts,
  108. matches[:50],
  109. None,
  110. flags=2,
  111. )
  112. match_file_name = f"match.jpg"
  113. cv2.imwrite(os.path.join(save_image_path, match_file_name), matched_image)
  114. wrap_file_name = f"wrap.jpg"
  115. wrap_image = cv.warpPerspective(query_image_resize, M, (ref_image_resize.shape[1], ref_image_resize.shape[0]))
  116. wrap_image = cv2.resize(wrap_image, (self.ref_image.shape[1], self.ref_image.shape[0]))
  117. cv2.imwrite(os.path.join(save_image_path, wrap_file_name), wrap_image)
  118. sub_file_name = f"result.jpg"
  119. sub_image = cv2.subtract(self.ref_image, wrap_image)
  120. cv2.imwrite(os.path.join(save_image_path, sub_file_name), sub_image)
  121. return matched_image, wrap_image, sub_image
  122. # def process_image_with_mask(self, json_mask_path, ref_gray):
  123. # with open(json_mask_path, 'r') as f:
  124. # data = json.load(f)
  125. # shapes = data['shapes']
  126. # shape = shapes[0]
  127. # if shape['shape_type'] == 'polygon':
  128. # coords = [(int(round(x * self.scale_factor)), int(round(y * self.scale_factor))) for x, y in shape['points']]
  129. # else:
  130. # coords = []
  131. # mask = np.zeros(ref_gray.shape, dtype=np.uint8) * 255
  132. # pts = np.array(coords, np.int32)
  133. # cv2.fillPoly(mask, [pts], 1)
  134. # ref_gray_masked = cv2.bitwise_and(ref_gray, ref_gray, mask=mask)
  135. # cv2.imwrite('ref_gray_mask.jpg', ref_gray_masked)
  136. # return ref_gray_masked
  137. def add_weldmentclass(self, weldmentclass):
  138. self.weldmentclasses.append(weldmentclass)
  139. # 焊接件类
  140. class WeldmentClass(StructureClass):
  141. def __init__(self, name):
  142. self.ployclasses = []
  143. self.name = name
  144. def add_ploy(self, ployclass):
  145. self.ployclasses.append(ployclass)
  146. class SSIMClass:
  147. def __init__(self, label, x1, y1, x2, y2):
  148. self.name = 'SSIM'
  149. self.label = label
  150. self.x1 = x1
  151. self.x2 = x2
  152. self.y1 = y1
  153. self.y2 = y2
  154. self.result = self.SSIMfunc()
  155. def SSIMfunc(self):
  156. return self.label, self.x1, self.x2, self.y1, self.y2
  157. class Ploy1Class:
  158. def __init__(self, label, x1, y1, x2, y2):
  159. self.name = 'Ploy1'
  160. self.label = label
  161. self.x1 = x1
  162. self.x2 = x2
  163. self.y1 = y1
  164. self.y2 = y2
  165. self.result = self.ploy1func()
  166. def ploy1func(self):
  167. return self.label, self.x1, self.x2, self.y1, self.y2
  168. class Ploy2Class:
  169. def __init__(self, label, x1, y1, x2, y2):
  170. self.name = 'Ploy2'
  171. self.label = label
  172. self.x1 = x1
  173. self.x2 = x2
  174. self.y1 = y1
  175. self.y2 = y2
  176. self.result = self.ploy2func()
  177. def ploy2func(self):
  178. return self.label, self.x1, self.x2, self.y1, self.y2
  179. class Ploy3Class:
  180. def __init__(self, label, x1, y1, x2, y2):
  181. self.name = 'Ploy3'
  182. self.label = label
  183. self.x1 = x1
  184. self.x2 = x2
  185. self.y1 = y1
  186. self.y2 = y2
  187. self.result = self.ploy3func()
  188. def ploy3func(self):
  189. return self.label, self.x1, self.x2, self.y1, self.y2
  190. # 定义一个函数来获取每个元素的首字母
  191. def get_first_letter(item):
  192. return item[0]
  193. if __name__ == '__main__':
  194. ref_image_path = './data/yongsheng_image/ref_image/DSC_0452.JPG'
  195. query_image_path = './data/yongsheng_image/test_image_query/DSC_0445.JPG'
  196. json_path = './data/yongsheng_image/json/DSC_0452.json'
  197. save_image_path = './data/yongsheng_image/test_regis_result'
  198. json_mask_path = './data/yongsheng_image/json/DSC_0452_mask.json'
  199. for filename in os.listdir(image_dir):
  200. # struct = StructureClass('./data/yongsheng_image/ref_image/DSC_0452.JPG', './data/yongsheng_image/test_image_query/DSC_0445.JPG', './data/yongsheng_image/json/DSC_0452.json')
  201. struct = StructureClass(ref_image_path, query_image_path, json_path, save_image_path, json_mask_path)
  202. grouped_data = {}
  203. for key, group in itertools.groupby(sorted(struct.boxes_xy_label), get_first_letter):
  204. grouped_data[key] = list(group)
  205. # 创建子类实例并添加到大类中
  206. for key, group in grouped_data.items():
  207. subclass = WeldmentClass(key)
  208. for g in group:
  209. if len(g) == 1:
  210. xy = struct.boxes_xy_label.get(g)
  211. ssim = SSIMClass(g, xy[0], xy[1], xy[2], xy[3])
  212. subclass.add_ploy(ssim)
  213. else:
  214. xy = struct.boxes_xy_label.get(g)
  215. if str(g).endswith('1'):
  216. poly = Ploy1Class(g, xy[0], xy[1], xy[2], xy[3])
  217. elif str(g).endswith('2'):
  218. poly = SSIMClass(g, xy[0], xy[1], xy[2], xy[3])
  219. else:
  220. poly = Ploy3Class(g, xy[0], xy[1], xy[2], xy[3])
  221. subclass.add_ploy(poly)
  222. struct.add_weldmentclass(subclass)
  223. w = WeldmentClass('A')
  224. struct.add_weldmentclass(w)
  225. print()
  226. # with open('./DSC_0452.json', 'r') as f:
  227. # data = json.load(f)
  228. # save_value = {}
  229. # for shape in data['shapes']:
  230. # if 'points' in shape:
  231. # shape['points'] = [[int(round(x)), int(round(y))] for x, y in shape['points']]
  232. # x1, y1 = shape['points'][0]
  233. # x2, y2 = shape['points'][1]
  234. # label = shape['label']
  235. # save_value[label] = [x1, y1, x2, y2]
  236. #
  237. # # 使用groupby函数根据首字母分组
  238. # grouped_data = {}
  239. # for key, group in itertools.groupby(sorted(save_value), get_first_letter):
  240. # grouped_data[key] = list(group)
  241. #
  242. # # 打印分组后的结果
  243. # for key, group in grouped_data.items():
  244. # print(f"{key}: {group}")
  245. #
  246. # # 创建大类实例
  247. # big_class = StructureClass()
  248. # # 创建子类实例并添加到大类中
  249. # for key, group in grouped_data.items():
  250. # subclass = WeldmentClass(key)
  251. # for g in group:
  252. # if len(g) == 1:
  253. # xy = save_value.get(g)
  254. # ssim = SSIMClass(g, xy[0], xy[1], xy[2], xy[3])
  255. # subclass.add_ploy(ssim)
  256. # else:
  257. # xy = save_value.get(g)
  258. # r = random.randint(1, 4)
  259. # if r == 1:
  260. # poly = Ploy1Class(g, xy[0], xy[1], xy[2], xy[3])
  261. # elif r == 2:
  262. # poly = Ploy2Class(g, xy[0], xy[1], xy[2], xy[3])
  263. # else:
  264. # poly = Ploy3Class(g, xy[0], xy[1], xy[2], xy[3])
  265. #
  266. # subclass.add_ploy(poly)
  267. # big_class.add_weldmentclass(subclass)
  268. #
  269. # for subclass in big_class.weldmentclasses:
  270. # print(subclass)