#!/usr/bin/env python
# -*- coding: utf-8 -*-
# @Time    : 2024/6/26 0026 3:21 PM
# @Author  : liudan
# @File    : image_registration.py
# @Software: pycharm
"""Register query images against their best-matching template image.

For every readable image in a query directory, the template with the most
SuperPoint + SuperGlue matches is selected, a query->template homography is
estimated with MAGSAC, and three diagnostic images are written to the output
directory: ``match_<name>.jpg`` (inlier matches), ``wrap_<name>.jpg`` (query
warped onto the template) and ``sub_<name>.jpg`` (template minus warped query).
"""
import json
import os

import cv2
import numpy as np
from loguru import logger
from superpoint_superglue_deployment import Matcher

# Query and template are both downscaled by this factor before matching.
DEFAULT_SCALE_FACTOR = 0.37
# Matcher settings (``use_gpu`` is supplied per instance).
SUPERPOINT_CONFIG = {"input_shape": (-1, -1), "keypoint_threshold": 0.003}
SUPERGLUE_CONFIG = {"match_threshold": 0.5}
# A homography needs at least four point correspondences.
MIN_HOMOGRAPHY_MATCHES = 4
# Only the lowest-distance inliers are drawn in the match visualisation.
MAX_DRAWN_MATCHES = 50


class GenerateTestImage:
    """Pick the best template for each query image and write registration results.

    Args:
        template_img_dir: Directory containing template ``.jpg`` images.
        template_mask_json_dir: Directory containing optional per-template mask
            JSON files (a ``shapes`` list whose entries carry ``shape_type`` and
            ``points``). A template without a mask is matched on the whole image.
        save_dir: Output directory; created (including parents) when missing.
        scale_factor: Resize factor applied to query and template before matching.
        use_gpu: Forwarded to the SuperPoint/SuperGlue ``Matcher``.
    """

    def __init__(self, template_img_dir, template_mask_json_dir, save_dir,
                 scale_factor=DEFAULT_SCALE_FACTOR, use_gpu=True):
        self.template_img_dir = template_img_dir
        self.template_mask_json_dir = template_mask_json_dir
        self.save_dir = save_dir
        self.scale_factor = scale_factor
        self.use_gpu = use_gpu
        # Created lazily on first use, then shared by every match call.
        self._matcher = None
        os.makedirs(self.save_dir, exist_ok=True)

        self.template_img_paths = sorted(
            os.path.join(self.template_img_dir, f)
            for f in os.listdir(self.template_img_dir)
            if f.lower().endswith(".jpg")
        )
        # Masks are optional, so a missing mask directory simply means "no masks".
        if os.path.isdir(self.template_mask_json_dir):
            self.template_mask_json_paths = sorted(
                os.path.join(self.template_mask_json_dir, f)
                for f in os.listdir(self.template_mask_json_dir)
                if f.endswith(".json")
            )
        else:
            self.template_mask_json_paths = []

    def generate_test_image(self, query_image_dir):
        """Register every readable image in ``query_image_dir`` and save the results.

        Failures are logged per file and do not stop the remaining images.
        """
        for file_name in sorted(os.listdir(query_image_dir)):
            query_image_path = os.path.join(query_image_dir, file_name)
            if not os.path.isfile(query_image_path):
                continue
            try:
                self._register_query(query_image_path, file_name)
            except ValueError as e:
                logger.warning(f"{file_name}: {e}")
            except Exception:
                # Batch boundary: continue with the next image but keep the traceback.
                logger.exception(f"An error occurred while processing {file_name}")

    def _register_query(self, query_image_path, file_name):
        """Register one query image and write its match/wrap/sub images.

        Raises:
            ValueError: if an image cannot be read, too few matches are found,
                or no homography can be estimated.
        """
        query_image = self._read_image(query_image_path)
        query_resized, query_gray = self._prepare(query_image, None)

        best_index = self._best_template_index(query_gray)
        best_template_path = self.template_img_paths[best_index]
        logger.info(f"{file_name}: best template {best_template_path}")
        ref_image = self._read_image(best_template_path)
        ref_resized, ref_gray = self._prepare(ref_image, self._mask_json_for(best_index))

        query_kpts, ref_kpts, _, _, matches = self._get_matcher().match(query_gray, ref_gray)
        if len(matches) < MIN_HOMOGRAPHY_MATCHES:
            raise ValueError("Not enough matches to compute homography.")

        homography, inlier_mask = cv2.findHomography(
            np.float64([query_kpts[m.queryIdx].pt for m in matches]).reshape(-1, 1, 2),
            np.float64([ref_kpts[m.trainIdx].pt for m in matches]).reshape(-1, 1, 2),
            method=cv2.USAC_MAGSAC,
            ransacReprojThreshold=5.0,
            maxIters=10000,
            confidence=0.95,
        )
        if homography is None:
            raise ValueError("Unable to compute homography.")
        logger.info(f"number of inliers: {inlier_mask.sum()}")

        inliers = sorted(
            (m for m, keep in zip(matches, inlier_mask.ravel()) if keep),
            key=lambda m: m.distance,
        )
        # flags=2 is cv2.DrawMatchesFlags_NOT_DRAW_SINGLE_POINTS (skip unmatched keypoints).
        matched_image = cv2.drawMatches(
            query_resized, query_kpts, ref_resized, ref_kpts,
            inliers[:MAX_DRAWN_MATCHES], None, flags=2,
        )

        # Warp at matching resolution, then upsample to the full template size so
        # the difference image is computed against the full-resolution template.
        ref_h, ref_w = ref_resized.shape[:2]
        warped = cv2.warpPerspective(query_resized, homography, (ref_w, ref_h))
        warped = cv2.resize(warped, (ref_image.shape[1], ref_image.shape[0]))
        difference = cv2.subtract(ref_image, warped)

        stem = os.path.splitext(file_name)[0]
        for prefix, image in (("match", matched_image), ("wrap", warped), ("sub", difference)):
            cv2.imwrite(os.path.join(self.save_dir, f"{prefix}_{stem}.jpg"), image)

    def process_image_with_mask(self, ref_image, scale_factor, json_ref_path, ref_gray):
        """Zero every pixel of ``ref_gray`` outside the first shape of a mask JSON.

        Args:
            ref_image: Full-resolution template image (unused; kept for API compatibility).
            scale_factor: Factor mapping full-resolution mask coordinates onto ``ref_gray``.
            json_ref_path: Mask JSON containing a ``shapes`` list.
            ref_gray: Downscaled grayscale template to be masked.

        Returns:
            ``ref_gray`` masked by the first shape when it is a polygon (>= 3 points)
            or a two-corner rectangle; otherwise ``ref_gray`` unchanged, with a warning.
        """
        with open(json_ref_path, "r", encoding="utf-8") as f:
            data = json.load(f)
        shapes = data.get("shapes") or []
        if not shapes:
            logger.warning(f"No shapes in {json_ref_path}; template left unmasked")
            return ref_gray

        shape = shapes[0]
        shape_type = shape.get("shape_type")
        points = [(int(round(x * scale_factor)), int(round(y * scale_factor)))
                  for x, y in shape.get("points", [])]
        if shape_type == "rectangle" and len(points) == 2:
            (x0, y0), (x1, y1) = points
            points = [(x0, y0), (x1, y0), (x1, y1), (x0, y1)]
        elif shape_type != "polygon" or len(points) < 3:
            # An empty/degenerate point list would make cv2.fillPoly fail or mask everything.
            logger.warning(f"Unusable {shape_type!r} shape in {json_ref_path}; template left unmasked")
            return ref_gray

        mask = np.zeros(ref_gray.shape[:2], dtype=np.uint8)
        cv2.fillPoly(mask, [np.array(points, dtype=np.int32)], 255)
        return cv2.bitwise_and(ref_gray, ref_gray, mask=mask)

    def find_best_matched_template_image(self, query_img_path):
        """Return the index into ``template_img_paths`` of the best template for a query.

        The best template is the one with the most SuperGlue matches; ties go to
        the lowest index.

        Raises:
            ValueError: if the query cannot be read or no templates were found.
        """
        query_image = self._read_image(query_img_path)
        _, query_gray = self._prepare(query_image, None)
        return self._best_template_index(query_gray)

    def _best_template_index(self, query_gray):
        """Return the index of the template with the most matches to ``query_gray``."""
        if not self.template_img_paths:
            raise ValueError(f"No .jpg template images found in {self.template_img_dir}")
        matcher = self._get_matcher()
        match_counts = []
        for index, ref_image_path in enumerate(self.template_img_paths):
            ref_image = cv2.imread(ref_image_path)
            if ref_image is None:
                logger.warning(f"Skipping unreadable template: {ref_image_path}")
                match_counts.append(-1)
                continue
            _, ref_gray = self._prepare(ref_image, self._mask_json_for(index))
            _, _, _, _, matches = matcher.match(query_gray, ref_gray)
            match_counts.append(len(matches))
        return match_counts.index(max(match_counts))

    def _get_matcher(self):
        """Return the shared SuperPoint/SuperGlue matcher, creating it on first call."""
        # Constructing a Matcher presumably sets up both networks, so one instance is reused.
        if getattr(self, "_matcher", None) is None:
            self._matcher = Matcher(
                {
                    "superpoint": dict(SUPERPOINT_CONFIG),
                    "superglue": dict(SUPERGLUE_CONFIG),
                    "use_gpu": getattr(self, "use_gpu", True),
                }
            )
        return self._matcher

    def _mask_json_for(self, index):
        """Return the mask JSON path for template ``index``, or None if it has none.

        A JSON sharing the template's file stem is preferred. If none exists and
        there is exactly one JSON per template, fall back to sorted-index pairing
        so datasets with differently named masks keep working.
        """
        stem = os.path.splitext(os.path.basename(self.template_img_paths[index]))[0]
        same_stem = os.path.join(self.template_mask_json_dir, stem + ".json")
        if os.path.exists(same_stem):
            return same_stem
        if len(self.template_mask_json_paths) == len(self.template_img_paths):
            return self.template_mask_json_paths[index]
        return None

    def _prepare(self, image, mask_json_path):
        """Downscale ``image``; return (resized BGR image, grayscale used for matching).

        The grayscale copy is masked when ``mask_json_path`` names an existing file.
        """
        resized = cv2.resize(image, dsize=None, fx=self.scale_factor, fy=self.scale_factor)
        gray = cv2.cvtColor(resized, cv2.COLOR_BGR2GRAY)
        if mask_json_path is not None and os.path.exists(mask_json_path):
            gray = self.process_image_with_mask(image, self.scale_factor, mask_json_path, gray)
        return resized, gray

    @staticmethod
    def _read_image(path):
        """Read ``path`` with ``cv2.imread``; raise ValueError if it cannot be decoded."""
        image = cv2.imread(path)
        if image is None:
            raise ValueError(f"Unable to read image: {path}")
        return image


if __name__ == "__main__":
    template_img_dir = './data/best_registration/template_image/'
    template_mask_json_dir = './data/best_registration/mask_json/'
    save_dir = "./data/best_registration/result"
    query_image_dir = "./data/best_registration/query_image"
    gti = GenerateTestImage(template_img_dir, template_mask_json_dir, save_dir)
    gti.generate_test_image(query_image_dir)