test.py 16 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444
  1. """
  2. # File : test.py
  3. # Time :2024-06-11 10:15
  4. # Author :FEANGYANG
  5. # version :python 3.7
  6. # Contact :1071082183@qq.com
  7. # Description:
  8. """
  9. import json
  10. import itertools
  11. import random
  12. import cv2
  13. from PIL import Image, ImageDraw
  14. import os
  15. from superpoint_superglue_deployment import Matcher
  16. import numpy as np
  17. from loguru import logger
  18. import cv2 as cv
  19. import yaml
  20. class StructureClass:
  21. def __init__(self, ref_image_path, query_image_path, json_path, save_image_path, json_mask_path, scale_factor=0.37):
  22. self.weldmentclasses = []
  23. self.ref_image = []
  24. self.query_image = []
  25. self.boxes_xy_label = {}
  26. self.scale_factor = scale_factor
  27. self.superglue_matcher = Matcher({
  28. "superpoint": {
  29. "input_shape": (-1, -1),
  30. "keypoint_threshold": 0.003,
  31. },
  32. "superglue": {
  33. "match_threshold": 0.5,
  34. },
  35. "use_gpu": True,
  36. })
  37. self.read_ref_image(ref_image_path)
  38. self.read_query_image(query_image_path)
  39. self.read_json(json_path)
  40. _, wrap_image, _ = self.registration_demo(save_image_path, json_mask_path)
  41. self.replace_query_image(wrap_image)
  42. self.group_weldment()
  43. # 定义一个函数来获取每个元素的首字母
  44. def get_first_letter(self, item):
  45. it = '-'.join(item.split('-')[:2])
  46. return it
  47. def read_ref_image(self, path):
  48. self.ref_image = self.process_image_data(path)
  49. def read_query_image(self, path):
  50. self.query_image = self.process_image_data(path)
  51. def replace_query_image(self, wrap_image):
  52. self.query_image = wrap_image
  53. def read_json(self, json_path):
  54. """
  55. 读取焊接件的标注信息
  56. :param json_path:
  57. :type json_path:
  58. :return:
  59. :rtype:
  60. """
  61. with open(json_path, 'r') as f:
  62. data = json.load(f)
  63. for shape in data['shapes']:
  64. if 'points' in shape:
  65. shape['points'] = [[int(round(x)), int(round(y))] for x, y in shape['points']]
  66. x1, y1 = shape['points'][0]
  67. x2, y2 = shape['points'][1]
  68. label = shape['label']
  69. self.boxes_xy_label[label] = [x1, y1, x2, y2]
  70. def process_image_data(self, data):
  71. """
  72. 读取图像,如果是文件则用cv读取,如果已经读取过则直接返回
  73. :param data:
  74. :type data:
  75. :return:
  76. :rtype:
  77. """
  78. extensions = ['.PNG', '.png', '.jpg', '.jpeg', '.JPG', '.JPEG']
  79. if any(data.endswith(ext) for ext in extensions):
  80. if not os.path.exists(data):
  81. raise FileNotFoundError(f"Image file {data} not found")
  82. image = cv2.imread(data)
  83. return image
  84. else:
  85. if isinstance(data, np.ndarray):
  86. return data
  87. else:
  88. raise FileNotFoundError(f"Image file {data} not found")
  89. def mask2image(self, json_mask_path, ref_gray):
  90. """
  91. 如果image_mask json文件存在,则将ref_image按照mask的形状扣出来
  92. :param json_mask_path:
  93. :type json_mask_path:
  94. :param ref_gray:
  95. :type ref_gray:
  96. :return:
  97. :rtype:
  98. """
  99. with open(json_mask_path, 'r') as f:
  100. data = json.load(f)
  101. shape = data['shapes'][0]
  102. if shape['shape_type'] == 'polygon':
  103. coords = [(int(round(x * self.scale_factor)), int(round(y * self.scale_factor))) for x, y in
  104. shape['points']]
  105. mask = np.zeros(ref_gray.shape, np.uint8)
  106. cv2.fillPoly(mask, [np.array(coords, np.int32)], 255)
  107. ref_gray = cv2.bitwise_and(ref_gray, mask)
  108. return ref_gray
  109. def registration_demo(self, save_image_path, json_mask_path):
  110. ref_image_resize = cv2.resize(self.ref_image, None, fx=self.scale_factor, fy=self.scale_factor)
  111. query_image_resize = cv2.resize(self.query_image, None, fx=self.scale_factor, fy=self.scale_factor)
  112. ref_gray = cv2.cvtColor(ref_image_resize, cv2.COLOR_BGR2GRAY)
  113. query_gray = cv2.cvtColor(query_image_resize, cv2.COLOR_BGR2GRAY)
  114. if os.path.exists(json_mask_path):
  115. ref_gray = self.mask2image(json_mask_path, ref_gray)
  116. query_kpts, ref_kpts, _, _, matches = self.superglue_matcher.match(query_gray, ref_gray)
  117. src_pts = np.float32([query_kpts[m.queryIdx].pt for m in matches]).reshape(-1, 1, 2)
  118. dst_pts = np.float32([ref_kpts[m.trainIdx].pt for m in matches]).reshape(-1, 1, 2)
  119. M, mask = cv2.findHomography(src_pts, dst_pts, cv2.USAC_MAGSAC, 5.0, maxIters=10000, confidence=0.95)
  120. logger.info(f"Number of inliers: {mask.sum()}")
  121. matches = [m for m, m_mask in zip(matches, mask) if m_mask]
  122. matches.sort(key=lambda m: m.distance)
  123. matched_image = cv2.drawMatches(query_image_resize, query_kpts, ref_image_resize, ref_kpts, matches[:50], None, flags=2)
  124. wrap_image = cv2.warpPerspective(query_image_resize, M, (ref_image_resize.shape[1], ref_image_resize.shape[0]))
  125. wrap_image = cv2.resize(wrap_image, (self.ref_image.shape[1], self.ref_image.shape[0]))
  126. sub_image = cv2.subtract(self.ref_image, wrap_image)
  127. cv2.imwrite(os.path.join(save_image_path, "match.jpg"), matched_image)
  128. cv2.imwrite(os.path.join(save_image_path, "wrap.jpg"), wrap_image)
  129. cv2.imwrite(os.path.join(save_image_path, "result.jpg"), sub_image)
  130. return matched_image, wrap_image, sub_image
  131. def group_weldment(self):
  132. grouped_data = {}
  133. for key, group in itertools.groupby(sorted(self.boxes_xy_label), self.get_first_letter):
  134. grouped_data[key] = list(group)
  135. # 创建子类实例并添加到大类中
  136. for key, group in grouped_data.items():
  137. subclass = WeldmentClass(key)
  138. for g in group:
  139. subclass.addshapelist(g, self.boxes_xy_label.get(g))
  140. self.add_weldmentclass(subclass)
  141. def add_weldmentclass(self, weldmentclass):
  142. self.weldmentclasses.append(weldmentclass)
  143. # 焊接件类
  144. class WeldmentClass(StructureClass):
  145. def __init__(self, name):
  146. self.shapelist = []
  147. self.xylist = []
  148. self.methodclasses = []
  149. self.name = name
  150. self.flaglist = []
  151. self.result = None
  152. def addshapelist(self, shape, box_xy):
  153. self.shapelist.append(shape)
  154. self.xylist.append(box_xy)
  155. def add_method(self, methodclass):
  156. self.methodclasses.append(methodclass)
  157. class SSIMDet:
  158. def __init__(self, ref_image, query_image, label, box_xy): # x1, y1, x2, y2
  159. self.name = 'SSIM'
  160. self.label = label
  161. self.x1, self.y1, self.x2, self.y2 = box_xy
  162. self.cut_ref_image = self.cut_image(ref_image)
  163. self.cut_query_image = self.cut_image(query_image)
  164. self.result = self.ssim_func(self.cut_ref_image, self.cut_query_image)
  165. def cut_image(self, image):
  166. return image[self.y1:self.y2, self.x1:self.x2]
  167. def ssim_func(self, im1, im2):
  168. imgsize = im1.shape[1] * im1.shape[2]
  169. avg1 = im1.mean((1, 2), keepdims=1)
  170. avg2 = im2.mean((1, 2), keepdims=1)
  171. std1 = im1.std((1, 2), ddof=1)
  172. std2 = im2.std((1, 2), ddof=1)
  173. cov = ((im1 - avg1) * (im2 - avg2)).mean((1, 2)) * imgsize / (imgsize - 1)
  174. avg1 = np.squeeze(avg1)
  175. avg2 = np.squeeze(avg2)
  176. k1 = 0.01
  177. k2 = 0.03
  178. c1 = (k1 * 255) ** 2
  179. c2 = (k2 * 255) ** 2
  180. c3 = c2 / 2
  181. # return np.mean((cov + c3) / (std1 * std2 + c3))
  182. return np.mean(
  183. (2 * avg1 * avg2 + c1) * 2 * (cov + c3) / (avg1 ** 2 + avg2 ** 2 + c1) / (std1 ** 2 + std2 ** 2 + c2))
  184. class VarianceDet:
  185. def __init__(self, ref_image, query_image, label, box_xy):
  186. self.name = 'VarianceDet'
  187. self.label = label
  188. self.x1, self.y1, self.x2, self.y2 = box_xy
  189. self.cut_ref_image = self.cut_image(ref_image)
  190. self.cut_query_image = self.cut_image(query_image)
  191. self.proportion = self.black_pixels_proportion(self.cut_query_image)
  192. if self.proportion > 0.05:
  193. self.result = 1
  194. else:
  195. self.result = self.variance_det_func(self.cut_ref_image, self.cut_query_image)
  196. def cut_image(self, image):
  197. return image[self.y1:self.y2, self.x1:self.x2]
  198. def black_pixels_proportion(self, cut_query_image):
  199. black_pixels = np.sum(cv2.cvtColor(cut_query_image, cv2.COLOR_BGR2GRAY) == 0)
  200. other_pixels = np.sum(cv2.cvtColor(cut_query_image, cv2.COLOR_BGR2GRAY) != 0)
  201. proportion = black_pixels / (other_pixels + black_pixels)
  202. return proportion
  203. # 计算两张图片的方差
  204. def calculate_variance(self, image):
  205. gray_image = cv2.cvtColor(image, cv2.COLOR_BGR2GRAY)
  206. variance = np.var(gray_image)
  207. mean = np.mean(image)
  208. variance = variance / mean
  209. return variance
  210. def variance_det_func(self, ref_image, query_image):
  211. variance1 = self.calculate_variance(ref_image)
  212. variance2 = self.calculate_variance(query_image)
  213. variance_diff = abs(variance1 - variance2)
  214. max_variance = max(variance1, variance2)
  215. normalized_diff = variance_diff / max_variance if max_variance != 0 else 0
  216. if normalized_diff < 0.9:
  217. return True
  218. else:
  219. return False
  220. class RorateDet:
  221. def __init__(self, ref_image, query_image, label, box_xy):
  222. self.name = 'RorateDet'
  223. self.label = label
  224. self.x1, self.y1, self.x2, self.y2 = box_xy
  225. self.cut_ref_image = self.cut_image(ref_image)
  226. self.cut_query_image = self.cut_image(query_image)
  227. self.query_image_rotate = cv2.rotate(self.cut_query_image, cv2.ROTATE_180)
  228. self.result = self.rorate_det_func()
  229. def cut_image(self, image):
  230. return image[self.y1:self.y2, self.x1:self.x2]
  231. # 配准操作,计算模板图与查询图配准点的距离差异(翻转后的距离大于未翻转的)
  232. def match_image(self, ref_img, query_img):
  233. superglue_matcher = Matcher(
  234. {
  235. "superpoint": {
  236. "input_shape": (-1, -1),
  237. "keypoint_threshold": 0.003,
  238. },
  239. "superglue": {
  240. "match_threshold": 0.5,
  241. },
  242. "use_gpu": True,
  243. },shape=20
  244. )
  245. ref_img = cv2.cvtColor(ref_img, cv2.COLOR_BGR2GRAY)
  246. query_img = cv2.cvtColor(query_img, cv2.COLOR_BGR2GRAY)
  247. query_kpts, ref_kpts, _, _, matches = superglue_matcher.match(ref_img, query_img)
  248. matched_query_kpts = [query_kpts[m.queryIdx].pt for m in matches]
  249. matched_ref_kpts = [ref_kpts[m.trainIdx].pt for m in matches]
  250. diff_query_ref = np.array(matched_ref_kpts) - np.array(matched_query_kpts)
  251. print(diff_query_ref)
  252. if len(matches) != 0:
  253. diff_query_ref = np.linalg.norm(diff_query_ref, axis=1, ord=2)
  254. diff_query_ref = np.sqrt(diff_query_ref)
  255. diff_query_ref = np.mean(diff_query_ref)
  256. else:
  257. diff_query_ref = np.inf
  258. return diff_query_ref # 返回差异值,用于后续做比较
  259. def rorate_det_func(self):
  260. """
  261. True: 没有缺陷
  262. False:有缺陷
  263. :return:
  264. """
  265. diff1 = self.match_image(self.cut_ref_image, self.cut_query_image) # 计算模板图与查询图的配准点差异
  266. diff2 = self.match_image(self.cut_ref_image, self.query_image_rotate) # 计算模板图与翻转180度图的配准点差异
  267. if diff1 < diff2:
  268. return True
  269. return False
  270. class NumPixel:
  271. def __init__(self, ref_image, query_image, label, box_xy, threshld=0.15, x_scale = 120, y_scale = 80):
  272. self.name = 'NumPixel'
  273. self.label = label
  274. self.scale_box = []
  275. self.x_scale, self.y_scale, self.threshld = x_scale, y_scale, threshld
  276. self.ref_h, self.ref_w, _ = ref_image.shape
  277. self.x1, self.y1, self.x2, self.y2 = self.big_box(box_xy)
  278. # self.cut_ref_image = self.cut_image(ref_image)
  279. self.cut_query_image = self.cut_image(query_image)
  280. self.ostu_query_image = self.otsu_binarize(self.cut_query_image)
  281. self.ostu_query_image = self.ostu_query_image[self.scale_box[1]:self.scale_box[3], self.scale_box[0]:self.scale_box[2]]
  282. self.result = self.num_pixel_func(self.ostu_query_image)
  283. def big_box(self, box_xy):
  284. x1, y1, x2, y2 = box_xy
  285. nx1, ny1, nx2, ny2 = 0,0,0,0
  286. if x1 >= self.x_scale:
  287. nx1 = x1-self.x_scale
  288. self.scale_box.append(self.x_scale)
  289. else:
  290. nx1 = 0
  291. self.scale_box.append(x1)
  292. if y1 >= self.y_scale:
  293. ny1 = y1-self.y_scale
  294. self.scale_box.append(self.y_scale)
  295. else:
  296. ny1 = 0
  297. self.scale_box.append(y1)
  298. if x2 + self.x_scale <= self.ref_w:
  299. nx2 = x2 + self.x_scale
  300. self.scale_box.append(self.scale_box[0]+(x2-x1))
  301. else:
  302. nx2 = self.ref_w
  303. self.scale_box.append(self.scale_box[0]+(x2-x1))
  304. if y2 + self.y_scale <= self.ref_h:
  305. ny2 = y2 + self.y_scale
  306. self.scale_box.append(self.scale_box[1]+(y2-y1))
  307. else:
  308. ny2 = self.ref_h
  309. self.scale_box.append(self.scale_box[1]+(y2-y1))
  310. return nx1, ny1, nx2, ny2
  311. def num_pixel_func(self, ostu_query_image):
  312. """
  313. True: 无缺陷
  314. False:有缺陷
  315. :return:
  316. :rtype:
  317. """
  318. num_pixel_region_query = round((np.sum(ostu_query_image == 0) / (ostu_query_image.shape[0] * ostu_query_image.shape[1])), 2)
  319. if num_pixel_region_query >= self.threshld:
  320. return True
  321. return False
  322. def cut_image(self, image):
  323. return image[self.y1:self.y2, self.x1:self.x2]
  324. def otsu_binarize(self, image):
  325. gray_image = cv2.cvtColor(image, cv2.COLOR_RGB2GRAY)
  326. gray_image = cv2.equalizeHist(gray_image)
  327. # ret1, mask1 = cv2.threshold(gray_image, 0, 255, cv2.THRESH_BINARY +cv2.THRESH_OTSU)
  328. ret2, mask = cv2.threshold(gray_image, 60, 255, cv2.THRESH_BINARY)
  329. # mask =mask1 & mask2
  330. kernel = np.ones((3, 3), np.uint8)
  331. mask = cv2.dilate(mask, kernel, iterations=3)
  332. mask = cv2.erode(mask, kernel, iterations=3)
  333. return mask
  334. def read_yaml(yaml_file):
  335. with open(yaml_file, 'r') as file:
  336. data = yaml.safe_load(file)
  337. return data
  338. def calculate_det(struct, yaml_data):
  339. for weldment in struct.weldmentclasses:
  340. shapelist = weldment.shapelist
  341. xylist = weldment.xylist
  342. for i in range(len(shapelist)):
  343. for method in yaml_data.get(shapelist[i]):
  344. class_obj = globals()[method]
  345. instance = class_obj(struct.ref_image, struct.query_image, shapelist[i], xylist[i])
  346. weldment.flaglist.append(instance.result)
  347. weldment.result = all(weldment.flaglist)
  348. weldment.add_method(instance)
  349. if __name__ == '__main__':
  350. ref_image_path = './data/yongsheng_image/ref_image/image165214-001.jpg'
  351. query_image_path = './data/yongsheng_image/test_image_query/image165214-011.jpg'
  352. json_path = './data/yongsheng_image/json/image165214-001.json'
  353. save_image_path = './data/yongsheng_image/test_regis_result'
  354. json_mask_path = './data/yongsheng_image/json/image165214-001_mask.json'
  355. # for filename in os.listdir(image_dir):
  356. struct = StructureClass(ref_image_path, query_image_path, json_path, save_image_path, json_mask_path)
  357. yaml_data = read_yaml('./test.yaml')
  358. calculate_det(struct, yaml_data.get('image165214-001'))
  359. print()