- """
- # File : test.py
- # Time :2024-06-11 10:15
- # Author :FEANGYANG
- # version :python 3.7
- # Contact :1071082183@qq.com
- # Description:
- """
- import json
- import itertools
- import random
- import cv2
- from PIL import Image, ImageDraw
- import os
- from superpoint_superglue_deployment import Matcher
- import numpy as np
- from loguru import logger
- import cv2 as cv
- import yaml
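
# Pipeline overview:
# 1. StructureClass loads a reference image, a query image and the JSON box
#    annotations, registers the query image onto the reference image with
#    SuperPoint + SuperGlue, and groups the annotated boxes into weldments.
# 2. Each detection class (SSIMDet, VarianceDet, RorateDet, NumPixel) crops one
#    annotated box from both images and decides whether that region looks defective.
# 3. calculate_det() reads a YAML file that maps each label to the list of
#    detection classes to run, instantiates them by name and aggregates the results.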


class StructureClass:
    def __init__(self, ref_image_path, query_image_path, json_path, save_image_path, json_mask_path, scale_factor=0.37):
        self.weldmentclasses = []
        self.ref_image = []
        self.query_image = []
        self.boxes_xy_label = {}
        self.scale_factor = scale_factor
        self.superglue_matcher = Matcher({
            "superpoint": {
                "input_shape": (-1, -1),
                "keypoint_threshold": 0.003,
            },
            "superglue": {
                "match_threshold": 0.5,
            },
            "use_gpu": True,
        })
        self.read_ref_image(ref_image_path)
        self.read_query_image(query_image_path)
        self.read_json(json_path)
        _, wrap_image, _ = self.registration_demo(save_image_path, json_mask_path)
        self.replace_query_image(wrap_image)
        self.group_weldment()

    # Build the grouping key for a label: the first two '-'-separated segments,
    # e.g. 'A-1-2' -> 'A-1'.
    def get_first_letter(self, item):
        return '-'.join(item.split('-')[:2])
    def read_ref_image(self, path):
        self.ref_image = self.process_image_data(path)

    def read_query_image(self, path):
        self.query_image = self.process_image_data(path)

    def replace_query_image(self, wrap_image):
        self.query_image = wrap_image

    def read_json(self, json_path):
        """
        Read the weldment annotations: every shape is stored as
        {label: [x1, y1, x2, y2]} in self.boxes_xy_label.
        """
        with open(json_path, 'r') as f:
            data = json.load(f)
        for shape in data['shapes']:
            if 'points' in shape:
                shape['points'] = [[int(round(x)), int(round(y))] for x, y in shape['points']]
                x1, y1 = shape['points'][0]
                x2, y2 = shape['points'][1]
                label = shape['label']
                self.boxes_xy_label[label] = [x1, y1, x2, y2]
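
    # For reference, read_json() expects a labelme-style annotation file. The real
    # files are not included here; an assumed, minimal example of the structure it
    # relies on (keys 'shapes', 'points', 'label') looks like:
    #
    # {
    #   "shapes": [
    #     {"label": "A-1-weld", "shape_type": "rectangle",
    #      "points": [[100.0, 200.0], [340.0, 420.0]]}
    #   ]
    # }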
    def process_image_data(self, data):
        """
        Load an image: if `data` is a path to an image file, read it with OpenCV;
        if it is already a numpy array, return it unchanged.
        """
        extensions = ['.PNG', '.png', '.jpg', '.jpeg', '.JPG', '.JPEG']
        if isinstance(data, str) and any(data.endswith(ext) for ext in extensions):
            if not os.path.exists(data):
                raise FileNotFoundError(f"Image file {data} not found")
            return cv2.imread(data)
        if isinstance(data, np.ndarray):
            return data
        raise TypeError(f"Unsupported image input: {data!r}")
    def mask2image(self, json_mask_path, ref_gray):
        """
        If an image-mask JSON file exists, keep only the polygon region of the
        (resized) reference image and black out everything outside it.
        """
        with open(json_mask_path, 'r') as f:
            data = json.load(f)
        shape = data['shapes'][0]
        if shape['shape_type'] == 'polygon':
            coords = [(int(round(x * self.scale_factor)), int(round(y * self.scale_factor)))
                      for x, y in shape['points']]
            mask = np.zeros(ref_gray.shape, np.uint8)
            cv2.fillPoly(mask, [np.array(coords, np.int32)], 255)
            ref_gray = cv2.bitwise_and(ref_gray, mask)
        return ref_gray
    def registration_demo(self, save_image_path, json_mask_path):
        # Work on downscaled grayscale copies to speed up keypoint matching.
        ref_image_resize = cv2.resize(self.ref_image, None, fx=self.scale_factor, fy=self.scale_factor)
        query_image_resize = cv2.resize(self.query_image, None, fx=self.scale_factor, fy=self.scale_factor)
        ref_gray = cv2.cvtColor(ref_image_resize, cv2.COLOR_BGR2GRAY)
        query_gray = cv2.cvtColor(query_image_resize, cv2.COLOR_BGR2GRAY)
        if os.path.exists(json_mask_path):
            ref_gray = self.mask2image(json_mask_path, ref_gray)
        query_kpts, ref_kpts, _, _, matches = self.superglue_matcher.match(query_gray, ref_gray)
        # Estimate a query-to-reference homography with robust (MAGSAC) fitting.
        src_pts = np.float32([query_kpts[m.queryIdx].pt for m in matches]).reshape(-1, 1, 2)
        dst_pts = np.float32([ref_kpts[m.trainIdx].pt for m in matches]).reshape(-1, 1, 2)
        M, mask = cv2.findHomography(src_pts, dst_pts, cv2.USAC_MAGSAC, 5.0, maxIters=10000, confidence=0.95)
        logger.info(f"Number of inliers: {mask.sum()}")
        matches = [m for m, m_mask in zip(matches, mask) if m_mask]
        matches.sort(key=lambda m: m.distance)
        matched_image = cv2.drawMatches(query_image_resize, query_kpts, ref_image_resize, ref_kpts, matches[:50], None, flags=2)
        # Warp the query image into the reference frame, then resize back to full resolution.
        wrap_image = cv2.warpPerspective(query_image_resize, M, (ref_image_resize.shape[1], ref_image_resize.shape[0]))
        wrap_image = cv2.resize(wrap_image, (self.ref_image.shape[1], self.ref_image.shape[0]))
        sub_image = cv2.subtract(self.ref_image, wrap_image)
        cv2.imwrite(os.path.join(save_image_path, "match.jpg"), matched_image)
        cv2.imwrite(os.path.join(save_image_path, "wrap.jpg"), wrap_image)
        cv2.imwrite(os.path.join(save_image_path, "result.jpg"), sub_image)
        return matched_image, wrap_image, sub_image
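
    # Note: after registration the warped query image shares the reference image's
    # geometry, so the boxes from the reference annotations can be used to crop
    # corresponding regions from both images in the detection classes below.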
    def group_weldment(self):
        # Group labels by their common prefix and create one WeldmentClass per group.
        grouped_data = {}
        for key, group in itertools.groupby(sorted(self.boxes_xy_label), self.get_first_letter):
            grouped_data[key] = list(group)
        for key, group in grouped_data.items():
            subclass = WeldmentClass(key)
            for g in group:
                subclass.addshapelist(g, self.boxes_xy_label.get(g))
            self.add_weldmentclass(subclass)

    def add_weldmentclass(self, weldmentclass):
        self.weldmentclasses.append(weldmentclass)


# Weldment: one group of annotated boxes that is checked as a unit.
class WeldmentClass(StructureClass):
    def __init__(self, name):
        # Note: StructureClass.__init__ is not called here; this class only stores
        # the labels, boxes and detection results of one group.
        self.shapelist = []
        self.xylist = []
        self.methodclasses = []
        self.name = name
        self.flaglist = []
        self.result = None

    def addshapelist(self, shape, box_xy):
        self.shapelist.append(shape)
        self.xylist.append(box_xy)

    def add_method(self, methodclass):
        self.methodclasses.append(methodclass)


class SSIMDet:
    def __init__(self, ref_image, query_image, label, box_xy):  # box_xy = [x1, y1, x2, y2]
        self.name = 'SSIM'
        self.label = label
        self.x1, self.y1, self.x2, self.y2 = box_xy
        self.cut_ref_image = self.cut_image(ref_image)
        self.cut_query_image = self.cut_image(query_image)
        self.result = self.ssim_func(self.cut_ref_image, self.cut_query_image)

    def cut_image(self, image):
        return image[self.y1:self.y2, self.x1:self.x2]

    def ssim_func(self, im1, im2):
        # Row-wise statistics taken over width and channels of the BGR crops.
        imgsize = im1.shape[1] * im1.shape[2]
        avg1 = im1.mean((1, 2), keepdims=1)
        avg2 = im2.mean((1, 2), keepdims=1)
        std1 = im1.std((1, 2), ddof=1)
        std2 = im2.std((1, 2), ddof=1)
        cov = ((im1 - avg1) * (im2 - avg2)).mean((1, 2)) * imgsize / (imgsize - 1)
        avg1 = np.squeeze(avg1)
        avg2 = np.squeeze(avg2)
        k1 = 0.01
        k2 = 0.03
        c1 = (k1 * 255) ** 2
        c2 = (k2 * 255) ** 2
        c3 = c2 / 2
        # return np.mean((cov + c3) / (std1 * std2 + c3))
        return np.mean(
            (2 * avg1 * avg2 + c1) * 2 * (cov + c3) / (avg1 ** 2 + avg2 ** 2 + c1) / (std1 ** 2 + std2 ** 2 + c2))
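

# For reference, the standard SSIM between two patches x and y is
#   SSIM(x, y) = ((2*mu_x*mu_y + c1) * (2*cov_xy + c2)) / ((mu_x**2 + mu_y**2 + c1) * (sigma_x**2 + sigma_y**2 + c2))
# Since c3 = c2 / 2, the factor 2 * (cov + c3) above equals 2 * cov + c2, so ssim_func
# evaluates this expression with row-wise statistics (mean over width and channels)
# and averages over rows, instead of using local sliding windows.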


class VarianceDet:
    def __init__(self, ref_image, query_image, label, box_xy):
        self.name = 'VarianceDet'
        self.label = label
        self.x1, self.y1, self.x2, self.y2 = box_xy
        self.cut_ref_image = self.cut_image(ref_image)
        self.cut_query_image = self.cut_image(query_image)
        self.proportion = self.black_pixels_proportion(self.cut_query_image)
        # If more than 5% of the query crop is pure black (e.g. outside the warped
        # area), skip the variance comparison and treat the region as passing.
        if self.proportion > 0.05:
            self.result = 1
        else:
            self.result = self.variance_det_func(self.cut_ref_image, self.cut_query_image)

    def cut_image(self, image):
        return image[self.y1:self.y2, self.x1:self.x2]

    def black_pixels_proportion(self, cut_query_image):
        gray = cv2.cvtColor(cut_query_image, cv2.COLOR_BGR2GRAY)
        black_pixels = np.sum(gray == 0)
        other_pixels = np.sum(gray != 0)
        return black_pixels / (other_pixels + black_pixels)

    # Mean-normalized gray-level variance of a single image.
    def calculate_variance(self, image):
        gray_image = cv2.cvtColor(image, cv2.COLOR_BGR2GRAY)
        variance = np.var(gray_image)
        mean = np.mean(image)
        return variance / mean

    def variance_det_func(self, ref_image, query_image):
        variance1 = self.calculate_variance(ref_image)
        variance2 = self.calculate_variance(query_image)
        variance_diff = abs(variance1 - variance2)
        max_variance = max(variance1, variance2)
        normalized_diff = variance_diff / max_variance if max_variance != 0 else 0
        # A relative variance difference below 0.9 is treated as "no defect".
        return normalized_diff < 0.9


class RorateDet:
    def __init__(self, ref_image, query_image, label, box_xy):
        self.name = 'RorateDet'
        self.label = label
        self.x1, self.y1, self.x2, self.y2 = box_xy
        self.cut_ref_image = self.cut_image(ref_image)
        self.cut_query_image = self.cut_image(query_image)
        self.query_image_rotate = cv2.rotate(self.cut_query_image, cv2.ROTATE_180)
        self.result = self.rorate_det_func()

    def cut_image(self, image):
        return image[self.y1:self.y2, self.x1:self.x2]

    # Register the template crop against a query crop and return the mean keypoint
    # displacement; a correctly oriented part should give a smaller displacement
    # than its 180-degree-rotated version.
    def match_image(self, ref_img, query_img):
        superglue_matcher = Matcher(
            {
                "superpoint": {
                    "input_shape": (-1, -1),
                    "keypoint_threshold": 0.003,
                },
                "superglue": {
                    "match_threshold": 0.5,
                },
                "use_gpu": True,
            }
        )
        ref_img = cv2.cvtColor(ref_img, cv2.COLOR_BGR2GRAY)
        query_img = cv2.cvtColor(query_img, cv2.COLOR_BGR2GRAY)
        query_kpts, ref_kpts, _, _, matches = superglue_matcher.match(ref_img, query_img)
        matched_query_kpts = [query_kpts[m.queryIdx].pt for m in matches]
        matched_ref_kpts = [ref_kpts[m.trainIdx].pt for m in matches]
        diff_query_ref = np.array(matched_ref_kpts) - np.array(matched_query_kpts)
        print(diff_query_ref)
        if len(matches) != 0:
            diff_query_ref = np.linalg.norm(diff_query_ref, axis=1, ord=2)
            diff_query_ref = np.sqrt(diff_query_ref)
            diff_query_ref = np.mean(diff_query_ref)
        else:
            diff_query_ref = np.inf
        return diff_query_ref  # difference value used for the comparison below

    def rorate_det_func(self):
        """
        True: no defect (the original orientation matches better)
        False: defect (the 180-degree-rotated crop matches better)
        """
        diff1 = self.match_image(self.cut_ref_image, self.cut_query_image)
        diff2 = self.match_image(self.cut_ref_image, self.query_image_rotate)
        return diff1 < diff2


class NumPixel:
    def __init__(self, ref_image, query_image, label, box_xy, threshold=0.15, x_scale=120, y_scale=80):
        self.name = 'NumPixel'
        self.label = label
        self.scale_box = []
        self.x_scale, self.y_scale, self.threshold = x_scale, y_scale, threshold
        self.ref_h, self.ref_w, _ = ref_image.shape
        # Enlarge the annotated box by (x_scale, y_scale) on each side, clipped to the
        # image border; scale_box keeps the original box coordinates inside the crop.
        self.x1, self.y1, self.x2, self.y2 = self.big_box(box_xy)
        # self.cut_ref_image = self.cut_image(ref_image)
        self.cut_query_image = self.cut_image(query_image)
        self.ostu_query_image = self.otsu_binarize(self.cut_query_image)
        self.ostu_query_image = self.ostu_query_image[self.scale_box[1]:self.scale_box[3],
                                                      self.scale_box[0]:self.scale_box[2]]
        self.result = self.num_pixel_func(self.ostu_query_image)

    def big_box(self, box_xy):
        x1, y1, x2, y2 = box_xy
        if x1 >= self.x_scale:
            nx1 = x1 - self.x_scale
            self.scale_box.append(self.x_scale)
        else:
            nx1 = 0
            self.scale_box.append(x1)
        if y1 >= self.y_scale:
            ny1 = y1 - self.y_scale
            self.scale_box.append(self.y_scale)
        else:
            ny1 = 0
            self.scale_box.append(y1)
        if x2 + self.x_scale <= self.ref_w:
            nx2 = x2 + self.x_scale
        else:
            nx2 = self.ref_w
        self.scale_box.append(self.scale_box[0] + (x2 - x1))
        if y2 + self.y_scale <= self.ref_h:
            ny2 = y2 + self.y_scale
        else:
            ny2 = self.ref_h
        self.scale_box.append(self.scale_box[1] + (y2 - y1))
        return nx1, ny1, nx2, ny2

    def num_pixel_func(self, ostu_query_image):
        """
        True: no defect (the proportion of dark pixels reaches the threshold)
        False: defect
        """
        num_pixel_region_query = round(
            np.sum(ostu_query_image == 0) / (ostu_query_image.shape[0] * ostu_query_image.shape[1]), 2)
        return num_pixel_region_query >= self.threshold

    def cut_image(self, image):
        return image[self.y1:self.y2, self.x1:self.x2]

    def otsu_binarize(self, image):
        gray_image = cv2.cvtColor(image, cv2.COLOR_RGB2GRAY)
        gray_image = cv2.equalizeHist(gray_image)
        # ret1, mask1 = cv2.threshold(gray_image, 0, 255, cv2.THRESH_BINARY + cv2.THRESH_OTSU)
        ret2, mask = cv2.threshold(gray_image, 60, 255, cv2.THRESH_BINARY)
        # mask = mask1 & mask2
        kernel = np.ones((3, 3), np.uint8)
        mask = cv2.dilate(mask, kernel, iterations=3)
        mask = cv2.erode(mask, kernel, iterations=3)
        return mask


def read_yaml(yaml_file):
    with open(yaml_file, 'r') as file:
        data = yaml.safe_load(file)
    return data
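

# The YAML configuration maps an image name to a mapping from annotation label to
# the list of detection classes to run on that label. The real test.yaml is not
# shown here; an assumed, illustrative structure (label names are placeholders):
#
# image165214-001:
#   A-1-bolt: [SSIMDet, NumPixel]
#   A-1-bracket: [VarianceDet, RorateDet]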


def calculate_det(struct, yaml_data):
    # For every weldment, run each configured detection class on each annotated box.
    # Detection classes are looked up by name in the module globals, so the names in
    # the YAML file must match the class names defined above.
    for weldment in struct.weldmentclasses:
        shapelist = weldment.shapelist
        xylist = weldment.xylist
        for i in range(len(shapelist)):
            for method in yaml_data.get(shapelist[i]):
                class_obj = globals()[method]
                instance = class_obj(struct.ref_image, struct.query_image, shapelist[i], xylist[i])
                weldment.flaglist.append(instance.result)
                weldment.add_method(instance)
        # The weldment passes only if every check on every box returned a truthy result.
        weldment.result = all(weldment.flaglist)


if __name__ == '__main__':
    ref_image_path = './data/yongsheng_image/ref_image/image165214-001.jpg'
    query_image_path = './data/yongsheng_image/test_image_query/image165214-011.jpg'
    json_path = './data/yongsheng_image/json/image165214-001.json'
    save_image_path = './data/yongsheng_image/test_regis_result'
    json_mask_path = './data/yongsheng_image/json/image165214-001_mask.json'
    # for filename in os.listdir(image_dir):
    struct = StructureClass(ref_image_path, query_image_path, json_path, save_image_path, json_mask_path)
    yaml_data = read_yaml('./test.yaml')
    calculate_det(struct, yaml_data.get('image165214-001'))
    print()
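    # Minimal sketch for inspecting the outcome; it only uses attributes defined
    # above (name, result, flaglist) and assumes calculate_det() has been run.
    for weldment in struct.weldmentclasses:
        print(weldment.name, weldment.result, weldment.flaglist)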