#!/usr/bin/env python
# -*- coding: utf-8 -*-
# @Time    : 2024/6/18 0018 PM 3:49
# @Author  : liudan
# @File    : generate_test_images.py
# @Software: pycharm
"""Crop labelled regions out of query images after aligning them to a template.

For each query image the template with the most SuperPoint/SuperGlue matches is
selected, the query is warped onto that template with a homography, and every
labelled shape from the template's label JSON is cropped from the warped image
and written to ``<save_dir>/test_img/<label>/<query stem>_<label>.jpg``.
"""
import json
import os

import cv2
import numpy as np
from loguru import logger
from superpoint_superglue_deployment import Matcher


class GenerateTestImage:
    """Align query images to reference templates and save the labelled crops.

    Template images, mask JSONs and label JSONs are paired by their position in
    the sorted directory listings, so the three directories must sort their
    files in the same order. A template with no paired mask is matched
    unmasked; a template with no paired label JSON cannot produce crops.
    """

    def __init__(self, template_img_dir, template_mask_json_dir,
                 template_label_json_dir, save_dir, *, scale_factor=0.37,
                 use_gpu=True):
        """Index the template files and build the feature matcher.

        Args:
            template_img_dir: Directory of reference images (``.jpg``, any case).
            template_mask_json_dir: Directory of mask JSONs. If it does not
                exist, templates are matched without a mask.
            template_label_json_dir: Directory of label JSONs whose ``shapes``
                define the regions to crop.
            save_dir: Output root; created (including parents) if missing.
            scale_factor: Downscale factor applied before feature matching.
            use_gpu: Forwarded to the SuperPoint/SuperGlue ``Matcher`` config.
        """
        self.template_img_dir = template_img_dir
        self.template_mask_json_dir = template_mask_json_dir
        self.template_label_json_dir = template_label_json_dir
        self.save_dir = save_dir
        self.scale_factor = scale_factor
        os.makedirs(self.save_dir, exist_ok=True)

        self.template_img_paths = self._list_files(template_img_dir, ('.jpg',))
        self.template_mask_json_paths = self._list_files(
            template_mask_json_dir, ('.json',), missing_ok=True)
        self.template_label_json_paths = self._list_files(
            template_label_json_dir, ('.json',))

        # Pairing is positional, so a count mismatch means at least some
        # templates would be paired with the wrong (or no) JSON file.
        n_templates = len(self.template_img_paths)
        if self.template_mask_json_paths and len(self.template_mask_json_paths) != n_templates:
            logger.warning("Found {} mask JSONs for {} templates; pairing is by sorted order",
                           len(self.template_mask_json_paths), n_templates)
        if len(self.template_label_json_paths) != n_templates:
            logger.warning("Found {} label JSONs for {} templates; pairing is by sorted order",
                           len(self.template_label_json_paths), n_templates)

        # Build the matcher once; the original constructed a new one for every
        # query/template pair, which presumably reloads the networks each time.
        self.matcher = Matcher(
            {
                "superpoint": {
                    "input_shape": (-1, -1),
                    "keypoint_threshold": 0.003,
                },
                "superglue": {
                    "match_threshold": 0.5,
                },
                "use_gpu": use_gpu,
            }
        )
        # index -> ((full_w, full_h), (resized_w, resized_h), masked_gray)
        self._template_cache = {}

    @staticmethod
    def _list_files(directory, extensions, missing_ok=False):
        """Return sorted full paths of files in *directory* ending with one of *extensions* (case-insensitive)."""
        if missing_ok and not os.path.isdir(directory):
            return []
        return sorted(
            os.path.join(directory, name)
            for name in os.listdir(directory)
            if name.lower().endswith(extensions)
        )

    @staticmethod
    def _paired_path(paths, index):
        """Return ``paths[index]``, or None when fewer files than templates were found."""
        return paths[index] if index < len(paths) else None

    def _resize_and_gray(self, image):
        """Downscale a BGR image by ``self.scale_factor``; return ``(resized_bgr, resized_gray)``."""
        resized = cv2.resize(image, dsize=None, fx=self.scale_factor, fy=self.scale_factor)
        return resized, cv2.cvtColor(resized, cv2.COLOR_BGR2GRAY)

    def _template(self, index):
        """Load, downscale and mask template *index*, caching the result.

        Returns:
            ``((full_w, full_h), (resized_w, resized_h), masked_gray)``.

        Raises:
            ValueError: if the template image cannot be read.
        """
        cached = self._template_cache.get(index)
        if cached is not None:
            return cached
        path = self.template_img_paths[index]
        ref_image = cv2.imread(path)
        if ref_image is None:
            raise ValueError(f"Unable to read template image: {path}")
        ref_resize, ref_gray = self._resize_and_gray(ref_image)
        mask_path = self._paired_path(self.template_mask_json_paths, index)
        if mask_path is not None:
            ref_gray = self.process_image_with_mask(ref_image, self.scale_factor, mask_path, ref_gray)
        full_h, full_w = ref_image.shape[:2]
        resized_h, resized_w = ref_resize.shape[:2]
        cached = ((full_w, full_h), (resized_w, resized_h), ref_gray)
        self._template_cache[index] = cached
        return cached

    def _find_best_match(self, query_gray):
        """Match *query_gray* against every template and keep the one with the most matches.

        Returns:
            ``(index, query_kpts, ref_kpts, matches)`` for the best template. On
            ties the first template in sorted order wins.

        Raises:
            ValueError: if no template could be loaded.
        """
        best = None
        for index in range(len(self.template_img_paths)):
            try:
                _, _, ref_gray = self._template(index)
            except ValueError as e:
                logger.warning("{}", e)
                continue
            query_kpts, ref_kpts, _, _, matches = self.matcher.match(query_gray, ref_gray)
            # Strict ">" keeps the earliest template on ties.
            if best is None or len(matches) > len(best[3]):
                best = (index, query_kpts, ref_kpts, matches)
        if best is None:
            raise ValueError("No template image available for matching.")
        return best

    def generate_test_image(self, query_image_dir):
        """Align every image in *query_image_dir* and save its labelled crops.

        Files that OpenCV cannot decode are skipped. Per-image failures (too
        few matches, no homography, bad label JSON, ...) are logged and the
        remaining images are still processed.
        """
        for file_name in sorted(os.listdir(query_image_dir)):
            query_image_path = os.path.join(query_image_dir, file_name)
            query_image = cv2.imread(query_image_path)
            if query_image is None:
                logger.warning("Skipping unreadable image: {}", query_image_path)
                continue
            try:
                self._process_query(file_name, query_image)
            except ValueError as e:
                logger.warning("{}: {}", file_name, e)
            except Exception:
                # Batch boundary: report and move on to the next query image.
                logger.exception("An error occurred while processing {}", file_name)

    def _process_query(self, file_name, query_image):
        """Warp one query image onto its best template and save the labelled crops."""
        query_resize, query_gray = self._resize_and_gray(query_image)
        index, query_kpts, ref_kpts, matches = self._find_best_match(query_gray)
        label_json_path = self._paired_path(self.template_label_json_paths, index)
        if label_json_path is None:
            raise ValueError(f"No label JSON paired with template {self.template_img_paths[index]}")
        full_size, resized_size, _ = self._template(index)
        warped = self._warp_to_template(query_resize, query_kpts, ref_kpts, matches,
                                        resized_size, full_size)
        self._save_regions(warped, label_json_path, os.path.splitext(file_name)[0])

    @staticmethod
    def _warp_to_template(query_resize, query_kpts, ref_kpts, matches, resized_size, full_size):
        """Warp the downscaled query into template space and upscale it to full template size.

        ``resized_size`` and ``full_size`` are ``(width, height)`` tuples.

        Raises:
            ValueError: if there are fewer than 4 matches or no homography is found.
        """
        if len(matches) < 4:
            raise ValueError("Not enough matches to compute homography.")
        src = np.float64([query_kpts[m.queryIdx].pt for m in matches]).reshape(-1, 1, 2)
        dst = np.float64([ref_kpts[m.trainIdx].pt for m in matches]).reshape(-1, 1, 2)
        homography, _ = cv2.findHomography(
            src,
            dst,
            method=cv2.USAC_MAGSAC,
            ransacReprojThreshold=5.0,
            maxIters=10000,
            confidence=0.95,
        )
        if homography is None:
            raise ValueError("Unable to compute homography.")
        warped = cv2.warpPerspective(query_resize, homography, resized_size)
        return cv2.resize(warped, full_size)

    def _save_regions(self, warped, label_json_path, stem):
        """Crop each labelled shape's bounding box from *warped* and save it.

        Label coordinates are in full-resolution template pixels. Crops are
        clipped to the image, and crops containing any pure-black pixel are
        discarded (see comment below).
        """
        with open(label_json_path, encoding='utf-8') as json_file:
            data = json.load(json_file)
        img_h, img_w = warped.shape[:2]
        for shape in data['shapes']:
            points = shape.get('points')
            if not points or len(points) < 2:
                continue
            # Bounding box of all points: for a 2-point rectangle this is the
            # rectangle itself regardless of the direction it was drawn in.
            xs = [int(round(x)) for x, _ in points]
            ys = [int(round(y)) for _, y in points]
            x1, x2 = max(min(xs), 0), min(max(xs), img_w)
            y1, y2 = max(min(ys), 0), min(max(ys), img_h)
            if x2 <= x1 or y2 <= y1:
                logger.warning("Shape '{}' lies outside the image; skipped", shape.get("label"))
                continue
            region = warped[y1:y2, x1:x2]
            # warpPerspective's default constant border is black, so a pure-black
            # pixel most likely means the crop reaches outside the warped query.
            if np.any(cv2.cvtColor(region, cv2.COLOR_BGR2GRAY) == 0):
                continue
            label = shape["label"]
            save_subdir = os.path.join(self.save_dir, "test_img", label)
            os.makedirs(save_subdir, exist_ok=True)
            out_path = os.path.join(save_subdir, f'{stem}_{label}.jpg')
            if not cv2.imwrite(out_path, region):
                logger.warning("Failed to write {}", out_path)

    def process_image_with_mask(self, ref_image, scale_factor, json_ref_path, ref_gray):
        """Zero every pixel of *ref_gray* outside the first polygon in *json_ref_path*.

        Args:
            ref_image: Unused; kept for backward compatibility.
            scale_factor: Factor applied to the JSON's full-resolution points so
                they line up with the downscaled *ref_gray*.
            json_ref_path: Mask JSON with a ``shapes`` list; only ``shapes[0]``
                is used, and only when its ``shape_type`` is ``'polygon'``.
            ref_gray: Downscaled grayscale template.

        Returns:
            The masked grayscale image, or *ref_gray* unchanged when the JSON has
            no usable polygon (previously this blacked out the whole image).
        """
        with open(json_ref_path, 'r', encoding='utf-8') as f:
            data = json.load(f)
        shapes = data.get('shapes') or []
        if not shapes or shapes[0].get('shape_type') != 'polygon' or not shapes[0].get('points'):
            logger.warning("No polygon mask in {}; matching without mask", json_ref_path)
            return ref_gray
        pts = np.array(
            [(int(round(x * scale_factor)), int(round(y * scale_factor)))
             for x, y in shapes[0]['points']],
            np.int32,
        )
        mask = np.zeros(ref_gray.shape, dtype=np.uint8)
        cv2.fillPoly(mask, [pts], 255)
        return cv2.bitwise_and(ref_gray, ref_gray, mask=mask)

    def find_best_matched_template_image(self, query_img_path):
        """Return the index of the template with the most matches to *query_img_path*.

        Raises:
            ValueError: if the query cannot be read or no template is available.
        """
        query_image = cv2.imread(query_img_path)
        if query_image is None:
            raise ValueError(f"Unable to read query image: {query_img_path}")
        _, query_gray = self._resize_and_gray(query_image)
        return self._find_best_match(query_gray)[0]


if __name__ == "__main__":
    template_img_dir = './data/ref/'
    template_mask_json_dir = './data/mask_json/'
    template_label_json_dir = './data/json/'
    save_dir = "./data/D_result"
    query_image_dir = "./data/query"

    gti = GenerateTestImage(template_img_dir, template_mask_json_dir,
                            template_label_json_dir, save_dir)
    gti.generate_test_image(query_image_dir)