"""
Weld inspection prototype built on SuperPoint/SuperGlue image registration.

A query image is registered onto a reference image, the annotated boxes read
from a labelme-style JSON file (``shapes`` / ``points`` / ``label``) are
grouped into weldments, and each box is checked by the detector classes that a
YAML config lists for its label.

# File       : test.py
# Time       : 2024-06-11 10:15
# Author     : FEANGYANG
# version    : python 3.7
# Contact    : 1071082183@qq.com
"""
import itertools
import json
import os
import random
from collections import defaultdict

import cv2
import cv2 as cv
import numpy as np
import yaml
from loguru import logger
from PIL import Image, ImageDraw
from superpoint_superglue_deployment import Matcher

# File extensions (compared case-insensitively) that process_image_data treats as image paths.
_IMAGE_EXTENSIONS = ('.png', '.jpg', '.jpeg')


class StructureClass:
    """Register a query image onto a reference image and group the annotated boxes.

    Construction runs the whole preprocessing pipeline: both images are loaded,
    the box annotations are read from ``json_path``, the query image is warped
    into the reference frame (replacing ``self.query_image``) and the boxes are
    grouped into :class:`WeldmentClass` objects stored in ``self.weldmentclasses``.

    :param ref_image_path: reference image path, or an already-loaded ndarray.
    :param query_image_path: query image path, or an already-loaded ndarray.
    :param json_path: annotation JSON; each shape with ``points`` gives one box.
    :param save_image_path: directory receiving match.jpg / wrap.jpg / result.jpg
        (created if missing).
    :param json_mask_path: optional polygon-mask JSON restricting the reference
        image used for matching; ignored when ``None`` or the file does not exist.
    :param scale_factor: resize factor applied to both images before matching.
    """

    def __init__(self, ref_image_path, query_image_path, json_path, save_image_path, json_mask_path,
                 scale_factor=0.37):
        self.weldmentclasses = []   # WeldmentClass objects, filled by group_weldment
        self.ref_image = []
        self.query_image = []
        self.boxes_xy_label = {}    # label -> [x1, y1, x2, y2] in full-resolution pixels
        self.scale_factor = scale_factor
        self.superglue_matcher = Matcher({
            "superpoint": {
                "input_shape": (-1, -1),
                "keypoint_threshold": 0.003,
            },
            "superglue": {
                "match_threshold": 0.5,
            },
            "use_gpu": True,
        })
        self.read_ref_image(ref_image_path)
        self.read_query_image(query_image_path)
        self.read_json(json_path)
        _, wrap_image, _ = self.registration_demo(save_image_path, json_mask_path)
        self.replace_query_image(wrap_image)
        self.group_weldment()

    def get_first_letter(self, item):
        """Return the group key of a label: its first two '-'-separated fields.

        Despite the name this is not a single letter: ``'A-1-3'`` -> ``'A-1'``;
        a label without '-' is returned unchanged.
        """
        return '-'.join(item.split('-')[:2])

    def read_ref_image(self, path):
        """Load the reference image (path or ndarray) into ``self.ref_image``."""
        self.ref_image = self.process_image_data(path)

    def read_query_image(self, path):
        """Load the query image (path or ndarray) into ``self.query_image``."""
        self.query_image = self.process_image_data(path)

    def replace_query_image(self, wrap_image):
        """Replace the query image, e.g. with its registered (warped) version."""
        self.query_image = wrap_image

    def read_json(self, json_path):
        """Read the weldment box annotations into ``self.boxes_xy_label``.

        Every shape that has ``points`` contributes ``label -> [x1, y1, x2, y2]``
        built from its first two points, rounded to ints. The corners are
        reordered so that ``x1 <= x2`` and ``y1 <= y2``: the two stored points
        are not guaranteed to be top-left / bottom-right, and an inverted box
        would slice to an empty crop. A repeated label keeps the last box.
        """
        with open(json_path, 'r', encoding='utf-8') as f:
            data = json.load(f)
        for shape in data['shapes']:
            if 'points' not in shape:
                continue
            xa, ya = (int(round(v)) for v in shape['points'][0])
            xb, yb = (int(round(v)) for v in shape['points'][1])
            label = shape['label']
            if label in self.boxes_xy_label:
                logger.warning(f"Duplicate label {label} in {json_path}; keeping the last box")
            self.boxes_xy_label[label] = [min(xa, xb), min(ya, yb), max(xa, xb), max(ya, yb)]

    def process_image_data(self, data):
        """Return an image: decode it from a path, or pass an ndarray through.

        :param data: an ``np.ndarray`` (returned unchanged) or a str / PathLike
            ending in .png/.jpg/.jpeg (any case), read with ``cv2.imread``.
        :raises FileNotFoundError: the path does not exist, or ``data`` is neither
            an ndarray nor a path with a supported extension.
        :raises ValueError: the file exists but OpenCV could not decode it.
        """
        # Check for an ndarray first: ndarrays have no .endswith.
        if isinstance(data, np.ndarray):
            return data
        if isinstance(data, (str, os.PathLike)):
            path = os.fspath(data)
            if path.lower().endswith(_IMAGE_EXTENSIONS):
                if not os.path.exists(path):
                    raise FileNotFoundError(f"Image file {data} not found")
                image = cv2.imread(path)
                # cv2.imread signals failure by returning None instead of raising.
                if image is None:
                    raise ValueError(f"Image file {data} could not be decoded")
                return image
        raise FileNotFoundError(f"Image file {data} not found")

    def mask2image(self, json_mask_path, ref_gray):
        """Mask the (downscaled) grey reference image with a polygon from JSON.

        Only the first shape of ``json_mask_path`` is used, and only when it is a
        polygon. Its points are in full-resolution pixels and are scaled by
        ``self.scale_factor`` to match ``ref_gray``; pixels outside the polygon
        become 0. Otherwise ``ref_gray`` is returned unchanged.
        """
        with open(json_mask_path, 'r', encoding='utf-8') as f:
            data = json.load(f)
        shape = data['shapes'][0]
        if shape['shape_type'] == 'polygon':
            coords = [(int(round(x * self.scale_factor)), int(round(y * self.scale_factor)))
                      for x, y in shape['points']]
            mask = np.zeros(ref_gray.shape, np.uint8)
            cv2.fillPoly(mask, [np.array(coords, np.int32)], 255)
            ref_gray = cv2.bitwise_and(ref_gray, mask)
        return ref_gray

    def registration_demo(self, save_image_path, json_mask_path):
        """Warp the query image into the reference frame and save diagnostics.

        Matching runs on grey copies downscaled by ``self.scale_factor``. The
        query -> reference homography is estimated with USAC-MAGSAC, the
        downscaled query is warped and then resized to the full reference size.
        ``match.jpg``, ``wrap.jpg`` and ``result.jpg`` (reference minus warped
        query, saturated) are written to ``save_image_path``.

        :return: ``(matched_image, wrap_image, sub_image)``.
        :raises RuntimeError: fewer than 4 matches, or no homography found.
        """
        ref_image_resize = cv2.resize(self.ref_image, None, fx=self.scale_factor, fy=self.scale_factor)
        query_image_resize = cv2.resize(self.query_image, None, fx=self.scale_factor, fy=self.scale_factor)
        ref_gray = cv2.cvtColor(ref_image_resize, cv2.COLOR_BGR2GRAY)
        query_gray = cv2.cvtColor(query_image_resize, cv2.COLOR_BGR2GRAY)
        if json_mask_path and os.path.exists(json_mask_path):
            ref_gray = self.mask2image(json_mask_path, ref_gray)

        query_kpts, ref_kpts, _, _, matches = self.superglue_matcher.match(query_gray, ref_gray)
        # A homography needs at least four point correspondences.
        if len(matches) < 4:
            raise RuntimeError(f"Registration failed: only {len(matches)} matches found")
        src_pts = np.float32([query_kpts[m.queryIdx].pt for m in matches]).reshape(-1, 1, 2)
        dst_pts = np.float32([ref_kpts[m.trainIdx].pt for m in matches]).reshape(-1, 1, 2)
        M, inlier_mask = cv2.findHomography(src_pts, dst_pts, cv2.USAC_MAGSAC, 5.0,
                                            maxIters=10000, confidence=0.95)
        if M is None:
            raise RuntimeError("Registration failed: no homography could be estimated")
        logger.info(f"Number of inliers: {inlier_mask.sum()}")

        # Draw only the inlier matches, lowest descriptor distance first (top 50).
        inliers = [m for m, keep in zip(matches, inlier_mask.ravel()) if keep]
        inliers.sort(key=lambda m: m.distance)
        # flags=2 is cv2.DrawMatchesFlags_NOT_DRAW_SINGLE_POINTS.
        matched_image = cv2.drawMatches(query_image_resize, query_kpts, ref_image_resize, ref_kpts,
                                        inliers[:50], None, flags=2)
        wrap_image = cv2.warpPerspective(query_image_resize, M,
                                         (ref_image_resize.shape[1], ref_image_resize.shape[0]))
        # Back to full reference resolution so the annotated box coordinates apply.
        wrap_image = cv2.resize(wrap_image, (self.ref_image.shape[1], self.ref_image.shape[0]))
        sub_image = cv2.subtract(self.ref_image, wrap_image)

        os.makedirs(save_image_path, exist_ok=True)
        for file_name, image in (("match.jpg", matched_image),
                                 ("wrap.jpg", wrap_image),
                                 ("result.jpg", sub_image)):
            out_path = os.path.join(save_image_path, file_name)
            # cv2.imwrite reports failure through its return value only.
            if not cv2.imwrite(out_path, image):
                logger.warning(f"Failed to write {out_path}")
        return matched_image, wrap_image, sub_image

    def group_weldment(self):
        """Group the box labels by :meth:`get_first_letter` into WeldmentClass objects.

        Labels are visited in sorted order and collected per key in a dict, so a
        group stays whole even when its members are not adjacent in sort order
        (e.g. 'A-1', 'A-1(b)', 'A-1-2'); ``itertools.groupby`` would split such a
        group and the later run would overwrite the earlier one.
        """
        grouped_data = defaultdict(list)
        for label in sorted(self.boxes_xy_label):
            grouped_data[self.get_first_letter(label)].append(label)

        for key, labels in grouped_data.items():
            subclass = WeldmentClass(key)
            for label in labels:
                subclass.addshapelist(label, self.boxes_xy_label.get(label))
            self.add_weldmentclass(subclass)

    def add_weldmentclass(self, weldmentclass):
        """Append a WeldmentClass to ``self.weldmentclasses``."""
        self.weldmentclasses.append(weldmentclass)


class WeldmentClass(StructureClass):
    """One weldment: the annotated boxes that share a label prefix.

    NOTE(review): inherits StructureClass but does not call its ``__init__``
    (which would rerun loading and registration); only the inherited helper
    methods are available. The base class is kept for backward compatibility.
    """

    def __init__(self, name):
        self.shapelist = []      # box labels belonging to this weldment
        self.xylist = []         # [x1, y1, x2, y2] per label, parallel to shapelist
        self.methodclasses = []  # detector instances that were run
        self.name = name         # shared label prefix
        self.flaglist = []       # detector results, parallel to methodclasses
        self.result = None       # all(flaglist) once at least one detector ran

    def addshapelist(self, shape, box_xy):
        """Add one box label and its coordinates."""
        self.shapelist.append(shape)
        self.xylist.append(box_xy)

    def add_method(self, methodclass):
        """Record a detector instance that was run on this weldment."""
        self.methodclasses.append(methodclass)


class SSIMDet:
    """Global (non-windowed) SSIM between the reference and query crops of one box.

    ``result`` is the SSIM score averaged over colour channels, not a boolean.
    NOTE(review): calculate_det folds results with ``all()``, so any non-zero
    score counts as a pass; confirm whether a threshold was intended.
    """

    def __init__(self, ref_image, query_image, label, box_xy):
        self.name = 'SSIM'
        self.label = label
        self.x1, self.y1, self.x2, self.y2 = box_xy  # x1, y1, x2, y2
        self.cut_ref_image = self.cut_image(ref_image)
        self.cut_query_image = self.cut_image(query_image)
        self.result = self.ssim_func(self.cut_ref_image, self.cut_query_image)

    def cut_image(self, image):
        """Return the box region ``image[y1:y2, x1:x2]``."""
        return image[self.y1:self.y2, self.x1:self.x2]

    def ssim_func(self, im1, im2):
        """Return the global SSIM of two same-shaped images, averaged over channels.

        Inputs are ``(H, W)`` or ``(H, W, C)`` arrays as produced by OpenCV.
        Means, variances and the covariance are computed per channel over all
        pixels (axes 0 and 1) with the unbiased n - 1 normalisation; reducing
        over axes (1, 2) would instead give per-row statistics for HWC input.
        Constants are K1 = 0.01, K2 = 0.03 with dynamic range L = 255.
        """
        im1 = np.asarray(im1, dtype=np.float64)
        im2 = np.asarray(im2, dtype=np.float64)
        if im1.ndim == 2:
            im1 = im1[:, :, np.newaxis]
        if im2.ndim == 2:
            im2 = im2[:, :, np.newaxis]
        n_pixels = im1.shape[0] * im1.shape[1]
        mu1 = im1.mean(axis=(0, 1))
        mu2 = im2.mean(axis=(0, 1))
        var1 = im1.var(axis=(0, 1), ddof=1)
        var2 = im2.var(axis=(0, 1), ddof=1)
        cov = ((im1 - mu1) * (im2 - mu2)).sum(axis=(0, 1)) / (n_pixels - 1)
        c1 = (0.01 * 255) ** 2
        c2 = (0.03 * 255) ** 2
        ssim = ((2 * mu1 * mu2 + c1) * (2 * cov + c2)) / ((mu1 ** 2 + mu2 ** 2 + c1) * (var1 + var2 + c2))
        return float(np.mean(ssim))


class VarianceDet:
    """Compare the normalised grey-level variance of the reference and query crops.

    ``result`` is True (pass) when more than ``black_ratio`` of the query crop
    is pure black (presumably warp padding outside the registered area, so the
    box cannot be judged), or when the relative difference of the two
    variance/mean values is below ``diff_threshold``. Otherwise it is False.
    """

    def __init__(self, ref_image, query_image, label, box_xy, black_ratio=0.05, diff_threshold=0.9):
        self.name = 'VarianceDet'
        self.label = label
        self.black_ratio = black_ratio
        self.diff_threshold = diff_threshold
        self.x1, self.y1, self.x2, self.y2 = box_xy
        self.cut_ref_image = self.cut_image(ref_image)
        self.cut_query_image = self.cut_image(query_image)
        self.proportion = self.black_pixels_proportion(self.cut_query_image)
        if self.proportion > self.black_ratio:
            # Too much of the crop is pure black to compare variances: treat as a pass.
            self.result = True
        else:
            self.result = self.variance_det_func(self.cut_ref_image, self.cut_query_image)

    def cut_image(self, image):
        """Return the box region ``image[y1:y2, x1:x2]``."""
        return image[self.y1:self.y2, self.x1:self.x2]

    def black_pixels_proportion(self, cut_query_image):
        """Return the fraction of pixels whose grey value is exactly 0."""
        gray = cv2.cvtColor(cut_query_image, cv2.COLOR_BGR2GRAY)
        return np.count_nonzero(gray == 0) / gray.size

    def calculate_variance(self, image):
        """Return variance / mean (index of dispersion) of the grey version of ``image``.

        Both statistics are taken from the same grey image. A zero mean means
        the crop is entirely black; 0.0 is returned instead of NaN.
        """
        gray_image = cv2.cvtColor(image, cv2.COLOR_BGR2GRAY)
        mean = float(np.mean(gray_image))
        if mean == 0:
            return 0.0
        return float(np.var(gray_image)) / mean

    def variance_det_func(self, ref_image, query_image):
        """Return True when the relative variance difference is below ``diff_threshold``."""
        variance1 = self.calculate_variance(ref_image)
        variance2 = self.calculate_variance(query_image)
        max_variance = max(variance1, variance2)
        normalized_diff = abs(variance1 - variance2) / max_variance if max_variance != 0 else 0
        return bool(normalized_diff < self.diff_threshold)


class RorateDet:
    """Detect a part mounted upside down (rotated by 180 degrees).

    The query crop is matched against the reference crop both as-is and rotated
    by 180 degrees. ``result`` is True (no defect) when the un-rotated crop
    yields the smaller keypoint-displacement score.
    """

    # SuperGlue matcher shared by all instances, built lazily on first use so
    # it is constructed once rather than twice per detection.
    _matcher = None

    def __init__(self, ref_image, query_image, label, box_xy):
        self.name = 'RorateDet'
        self.label = label
        self.x1, self.y1, self.x2, self.y2 = box_xy
        self.cut_ref_image = self.cut_image(ref_image)
        self.cut_query_image = self.cut_image(query_image)
        self.query_image_rotate = cv2.rotate(self.cut_query_image, cv2.ROTATE_180)
        self.result = self.rorate_det_func()

    @classmethod
    def _get_matcher(cls):
        """Return the shared Matcher, creating it on first call."""
        if cls._matcher is None:
            cls._matcher = Matcher(
                {
                    "superpoint": {
                        "input_shape": (-1, -1),
                        "keypoint_threshold": 0.003,
                    },
                    "superglue": {
                        "match_threshold": 0.5,
                    },
                    "use_gpu": True,
                }, shape=20
            )
        return cls._matcher

    def cut_image(self, image):
        """Return the box region ``image[y1:y2, x1:x2]``."""
        return image[self.y1:self.y2, self.x1:self.x2]

    def match_image(self, ref_img, query_img):
        """Return a displacement score between matched keypoints of two crops.

        The score is the mean of the square roots of the per-match Euclidean
        keypoint displacements, in pixels; ``np.inf`` when nothing matches.
        Smaller means the crops line up better.
        """
        ref_gray = cv2.cvtColor(ref_img, cv2.COLOR_BGR2GRAY)
        query_gray = cv2.cvtColor(query_img, cv2.COLOR_BGR2GRAY)
        # Keypoint lists are returned in argument order: first for ref, second for query.
        kpts_a, kpts_b, _, _, matches = self._get_matcher().match(ref_gray, query_gray)
        if len(matches) == 0:
            return np.inf
        pts_a = np.array([kpts_a[m.queryIdx].pt for m in matches])
        pts_b = np.array([kpts_b[m.trainIdx].pt for m in matches])
        displacement = np.linalg.norm(pts_b - pts_a, axis=1, ord=2)
        logger.debug(f"Keypoint displacements: {displacement}")
        return np.mean(np.sqrt(displacement))

    def rorate_det_func(self):
        """Return True (no defect) if the un-rotated query matches better than the rotated one."""
        diff_plain = self.match_image(self.cut_ref_image, self.cut_query_image)
        diff_rotated = self.match_image(self.cut_ref_image, self.query_image_rotate)
        return bool(diff_plain < diff_rotated)


class NumPixel:
    """Check the share of dark pixels inside one box of the query image.

    The box is enlarged by ``x_scale`` / ``y_scale`` pixels per side (clamped to
    the image) before equalisation and thresholding. The binary mask is then
    cropped back to the original box. ``result`` is True (no defect) when the
    share of 0-valued mask pixels, rounded to 2 decimals, is at least ``threshld``.
    """

    def __init__(self, ref_image, query_image, label, box_xy, threshld=0.15, x_scale=120, y_scale=80):
        self.name = 'NumPixel'
        self.label = label
        self.scale_box = []  # original box in enlarged-crop coordinates, set by big_box
        self.x_scale, self.y_scale, self.threshld = x_scale, y_scale, threshld
        self.ref_h, self.ref_w = ref_image.shape[:2]
        self.x1, self.y1, self.x2, self.y2 = self.big_box(box_xy)
        self.cut_query_image = self.cut_image(query_image)
        binary = self.otsu_binarize(self.cut_query_image)
        bx1, by1, bx2, by2 = self.scale_box
        self.ostu_query_image = binary[by1:by2, bx1:bx2]
        self.result = self.num_pixel_func(self.ostu_query_image)

    def big_box(self, box_xy):
        """Return the enlarged box ``(nx1, ny1, nx2, ny2)`` clamped to the image.

        Also sets ``self.scale_box`` to ``[bx1, by1, bx2, by2]``: the original box
        expressed in coordinates of the enlarged crop.
        """
        x1, y1, x2, y2 = box_xy
        nx1 = max(x1 - self.x_scale, 0)
        ny1 = max(y1 - self.y_scale, 0)
        nx2 = min(x2 + self.x_scale, self.ref_w)
        ny2 = min(y2 + self.y_scale, self.ref_h)
        off_x, off_y = x1 - nx1, y1 - ny1
        self.scale_box = [off_x, off_y, off_x + (x2 - x1), off_y + (y2 - y1)]
        return nx1, ny1, nx2, ny2

    def num_pixel_func(self, ostu_query_image):
        """Return True (no defect) when the dark-pixel share reaches ``threshld``.

        An empty mask returns False.
        """
        if ostu_query_image.size == 0:
            return False
        dark_ratio = np.round(np.count_nonzero(ostu_query_image == 0) / ostu_query_image.size, 2)
        return bool(dark_ratio >= self.threshld)

    def cut_image(self, image):
        """Return the (enlarged) box region ``image[y1:y2, x1:x2]``."""
        return image[self.y1:self.y2, self.x1:self.x2]

    def otsu_binarize(self, image):
        """Return a binary mask (0 / 255) of a BGR image.

        Despite the name, this uses a fixed threshold of 60 after histogram
        equalisation (Otsu is not used). Three 3x3 dilations followed by three
        erosions (a morphological closing) remove small dark specks. Images in
        this module come from cv2.imread, i.e. BGR, hence COLOR_BGR2GRAY.
        """
        gray_image = cv2.cvtColor(image, cv2.COLOR_BGR2GRAY)
        gray_image = cv2.equalizeHist(gray_image)
        _, mask = cv2.threshold(gray_image, 60, 255, cv2.THRESH_BINARY)
        kernel = np.ones((3, 3), np.uint8)
        mask = cv2.dilate(mask, kernel, iterations=3)
        mask = cv2.erode(mask, kernel, iterations=3)
        return mask


def read_yaml(yaml_file):
    """Load a YAML file with ``yaml.safe_load`` and return the parsed data."""
    with open(yaml_file, 'r', encoding='utf-8') as file:
        data = yaml.safe_load(file)
    return data


# Detector classes a YAML config may name; calculate_det instantiates only these.
DETECTORS = {cls.__name__: cls for cls in (SSIMDet, VarianceDet, RorateDet, NumPixel)}


def calculate_det(struct, yaml_data):
    """Run the configured detectors on every box of every weldment.

    :param struct: a constructed StructureClass (registered images + weldments).
    :param yaml_data: mapping ``label -> list of detector class names`` (keys of
        ``DETECTORS``). Labels without an entry are skipped with a warning.
    :raises KeyError: if the config names an unknown detector.

    Each detector instance is appended to the weldment's ``methodclasses`` and
    its result to ``flaglist``. ``weldment.result`` becomes ``all(flaglist)``
    once any detector has run.
    """
    for weldment in struct.weldmentclasses:
        for label, box_xy in zip(weldment.shapelist, weldment.xylist):
            methods = yaml_data.get(label)
            if not methods:
                logger.warning(f"No detectors configured for label {label}; skipped")
                continue
            for method in methods:
                try:
                    class_obj = DETECTORS[method]
                except KeyError:
                    raise KeyError(f"Unknown detector {method!r} configured for label {label}") from None
                instance = class_obj(struct.ref_image, struct.query_image, label, box_xy)
                weldment.flaglist.append(instance.result)
                weldment.add_method(instance)
        if weldment.flaglist:
            weldment.result = all(weldment.flaglist)


if __name__ == '__main__':
    ref_image_path = './data/yongsheng_image/ref_image/image165214-001.jpg'
    query_image_path = './data/yongsheng_image/test_image_query/image165214-011.jpg'
    json_path = './data/yongsheng_image/json/image165214-001.json'
    save_image_path = './data/yongsheng_image/test_regis_result'
    json_mask_path = './data/yongsheng_image/json/image165214-001_mask.json'

    # Load the config first so a missing entry fails before the costly registration.
    yaml_data = read_yaml('./test.yaml')
    config = (yaml_data or {}).get('image165214-001')
    if config is None:
        raise SystemExit("No detector config for 'image165214-001' in ./test.yaml")

    struct = StructureClass(ref_image_path, query_image_path, json_path, save_image_path, json_mask_path)
    calculate_det(struct, config)
    for weldment in struct.weldmentclasses:
        logger.info(f"{weldment.name}: result={weldment.result} flags={weldment.flaglist}")