demo_env.py 5.7 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143
  1. #!/usr/bin/env python
  2. # -*- coding: utf-8 -*-
  3. # @Time : 2024/6/6 0006 上午 9:21
  4. # @Author : liudan
  5. # @File : demo_env.py
  6. # @Software: pycharm
  7. import cv2 as cv
  8. import cv2
  9. import numpy as np
  10. from loguru import logger
  11. import os
  12. import json
  13. from superpoint_superglue_deployment import Matcher
  14. from datetime import datetime
  15. import random
  16. import yaml
  17. # timestamp = datetime.now().strftime('%Y%m%d_%H%M%S')
  18. #import image_similarity_count
  19. def process_image_with_mask(ref_image, scale_factor, json_ref_path, ref_gray):
  20. with open(json_ref_path, 'r') as f:
  21. data = json.load(f)
  22. shapes = data['shapes']
  23. shape = shapes[0]
  24. if shape['shape_type'] == 'polygon':
  25. coords = [(int(round(x * scale_factor)), int(round(y * scale_factor))) for x, y in shape['points']]
  26. else:
  27. coords = []
  28. mask = np.zeros(ref_gray.shape, dtype=np.uint8) * 255
  29. pts = np.array(coords, np.int32)
  30. cv2.fillPoly(mask, [pts], 1)
  31. ref_gray_masked = cv2.bitwise_and(ref_gray, ref_gray, mask=mask)
  32. # cv2.imwrite('1111111.jpg',ref_gray_masked)
  33. return ref_gray_masked
  34. def registration_demo(image_dir,demo_image_path, json_ref_path, ref_image_path):
  35. wraped_image_list =[]
  36. for filename in os.listdir(image_dir):
  37. # 检查文件是否是图片(这里假设是.jpg格式)
  38. if filename.endswith('.jpg') or filename.endswith('.JPG'):
  39. # 构建图片的完整路径
  40. image1_path = os.path.join(image_dir, filename)
  41. query_image = cv2.imread(image1_path)
  42. ref_image = cv2.imread(ref_image_path)
  43. height, width = ref_image.shape[:2]
  44. scale_factor = 0.37
  45. query_image_resize = cv2.resize(query_image, dsize=None, fx=scale_factor, fy=scale_factor)
  46. ref_image_resize = cv2.resize(ref_image, dsize=None, fx=scale_factor, fy=scale_factor)
  47. query_gray = cv2.cvtColor(query_image_resize, cv2.COLOR_BGR2GRAY)
  48. ref_gray = cv2.cvtColor(ref_image_resize, cv2.COLOR_BGR2GRAY)
  49. query_gray = cv2.cvtColor(query_image, cv2.COLOR_BGR2GRAY)
  50. ref_gray = cv2.cvtColor(ref_image, cv2.COLOR_BGR2GRAY)
  51. if os.path.exists(json_ref_path):
  52. ref_gray= process_image_with_mask(ref_image, scale_factor, json_ref_path,ref_gray)
  53. superglue_matcher = Matcher(
  54. {
  55. "superpoint": {
  56. "input_shape": (-1, -1),
  57. "keypoint_threshold": 0.003,
  58. },
  59. "superglue": {
  60. "match_threshold": 0.5,
  61. },
  62. "use_gpu": True,
  63. }
  64. )
  65. query_kpts, ref_kpts, _, _, matches = superglue_matcher.match(query_gray, ref_gray)
  66. # print(f"{filename}\n")
  67. # matched_query_kpts =[query_kpts[m.queryIdx].pt for m in matches]
  68. # matched_ref_kpts =[ref_kpts[m.trainIdx].pt for m in matches]
  69. # diff_query_ref=np.array(matched_ref_kpts) -np.array(matched_query_kpts)
  70. # # 计算每一列的模
  71. # diff_query_ref = np.linalg.norm(diff_query_ref, axis=0, ord=1)
  72. # # 计算所有行的平均值
  73. # diff_query_ref = np.mean(diff_query_ref)
  74. # # for index in range(len(matched_query_kpts)):
  75. # # diff_query_ref.append(matched_query_kpts[index]-matched_ref_kpts[index])
  76. #
  77. # print(matched_query_kpts)
  78. # print(matched_ref_kpts)
  79. # print(diff_query_ref)
  80. M, mask = cv2.findHomography(
  81. np.float64([query_kpts[m.queryIdx].pt for m in matches]).reshape(-1, 1, 2),
  82. np.float64([ref_kpts[m.trainIdx].pt for m in matches]).reshape(-1, 1, 2),
  83. method=cv2.USAC_MAGSAC,
  84. ransacReprojThreshold=5.0,
  85. maxIters=10000,
  86. confidence=0.95,
  87. )
  88. logger.info(f"number of inliers: {mask.sum()}")
  89. matches = np.array(matches)[np.all(mask > 0, axis=1)]
  90. matches = sorted(matches, key=lambda match: match.distance)
  91. matched_image = cv2.drawMatches(
  92. query_image_resize,
  93. query_kpts,
  94. ref_image_resize,
  95. ref_kpts,
  96. matches[:50],
  97. None,
  98. flags=2,
  99. )
  100. match_file_name = f"match_image_{filename}.jpg"
  101. cv2.imwrite(os.path.join(demo_image_path , match_file_name), matched_image)
  102. wrap_image = cv.warpPerspective(query_image_resize, M, (ref_image_resize.shape[1], ref_image_resize.shape[0]))
  103. # wrap_image = cv.warpPerspective(query_image, M,(ref_image.shape[1], ref_image.shape[0]))
  104. wrap_image = cv2.resize(wrap_image,(ref_image.shape[1], ref_image.shape[0]))
  105. wrap_filename = f"wrap_image_{filename}.jpg"
  106. cv2.imwrite(os.path.join(demo_image_path, wrap_filename), wrap_image)
  107. wraped_image_list.append((wrap_image,wrap_filename))
  108. result_image = cv2.subtract(ref_image, wrap_image)
  109. result_file_name = f"result_image_{filename}.jpg"
  110. cv2.imwrite(os.path.join(demo_image_path, result_file_name), result_image)
  111. return wraped_image_list
  112. def read_params_from_yml(yml_file_path):
  113. with open(yml_file_path, 'r') as file:
  114. params = yaml.safe_load(file)
  115. return params
  116. if __name__ == "__main__":
  117. yml_file_path = 'params.yml'
  118. params = read_params_from_yml(yml_file_path)
  119. registration_demo(params['image_dir'],params['demo_image_path'], params['json_ref_path'], params['ref_image_path'])