image_similarity_count.py 6.7 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164
  1. #!/usr/bin/env python
  2. # -*- coding: utf-8 -*-
  3. # @Time : 2024/5/24 0024 上午 10:09
  4. # @Author : liudan
  5. # @File : image_similarity_count.py
  6. # @Software: pycharm
  7. import cv2
  8. import os
  9. from PIL import Image, ImageDraw
  10. import numpy as np
  11. from skimage.metrics import structural_similarity as compare_ssim
  12. import json
  13. import random
  14. import yaml
  15. import demo_env
  16. from demo_env import registration_demo
  17. from sklearn.metrics.pairwise import cosine_similarity
  18. def compare_boxes_similarity(image1_path, image2_path, json_file_path, similarity_threshold=0.4):
  19. try:
  20. if not os.path.exists(image1_path):
  21. raise FileNotFoundError(f"Image file {image1_path} not found")
  22. image1 = cv2.imread(image1_path)
  23. # image1 = np.array(Image.open(image1_path)) # 原图尺寸,未resize
  24. # image1 = image1[:, :, ::-1]
  25. image2 = wrap_image
  26. # draw1 = ImageDraw.Draw(image1)
  27. # 存储相似度结果和是否相同的判断
  28. similarity_results = []
  29. same_content_boxes = []
  30. with open(json_file_path, 'r') as f:
  31. data = json.load(f)
  32. for shape in data['shapes']:
  33. if 'points' in shape:
  34. shape['points'] = [[int(round(x)), int(round(y))] for x, y in shape['points']]
  35. x1, y1 = shape['points'][0]
  36. x2, y2 = shape['points'][1]
  37. # 从两幅图像中截取对应区域
  38. # region1 = image1.crop((x1, y1, x2, y2))
  39. # # region1 = image1.crop((x1, y1, x2, y2)).convert('L')
  40. # draw1.rectangle([x1, y1, x2, y2], outline='red', width=2)
  41. # image1.save(os.path.join(params['save_dir'], f'save_annotated1_{i}.jpg'))
  42. # region1.save(os.path.join(params['save_dir'], f'111111{i}.jpg'))
  43. #
  44. region1 = image1[y1:y2, x1:x2]
  45. filename = f'json_image1_{shape["label"]}_{i}.jpg'
  46. cv2.imwrite(os.path.join(params['save_dir'], filename),region1)
  47. cv2.rectangle(image1, (x1, y1), (x2, y2), (0, 255, 0), 2)
  48. cv2.imwrite(os.path.join(params['save_dir'], f'save_annotated1_{i}.jpg'), image2)
  49. # region2 = image2.crop((left-80, top, right-80, bottom))
  50. region2 = image2[y1:y2, x1:x2]
  51. # region2 = cv2.cvtColor(region2, cv2.COLOR_BGR2GRAY)
  52. # region2= region2.transpose(Image.FLIP_TOP_BOTTOM) #旋转180°针对pillowImage对象
  53. # region2 = cv2.rotate(region2, cv2.ROTATE_180)
  54. filename = f'json_image2_{shape["label"]}_{i}.jpg'
  55. cv2.imwrite(os.path.join(params['save_dir'], filename), region2)
  56. cv2.rectangle(image2, (x1, y1), (x2, y2),(0,255,0), 2)
  57. cv2.imwrite(os.path.join(params['save_dir'], f'save_annotated2_{i}.jpg'), image2)
  58. # 将PIL图像转换为numpy数组,以便进行计算
  59. arr1 = np.array(region1)
  60. arr2 = region2 # region2一直是numpy数组,所以上述image1和image2处理方式不同
  61. # 确保两个数组的形状是相同的
  62. assert arr1.shape == arr2.shape, "Images do not have the same size for the given box"
  63. # 使用SSIM计算相似度(范围在-1到1之间,1表示完全相似)
  64. # ssim = compare_ssim(arr1, arr2, multichannel=False) # 这是旧版,可以计算灰度图相似度,对于计算彩色图像即使设置multichannel=True也错
  65. # ssim = compare_ssim(arr1, arr2, channel_axis=2)
  66. ssim = batch_ssim(arr1, arr2)
  67. similarity_results.append(ssim)
  68. if ssim > similarity_threshold:
  69. same_content_boxes.append(shape)
  70. cv2.rectangle(image2, (x1, y1), (x2, y2),(0,255,0), 2)
  71. text = "score: " + str(round(ssim, 3))
  72. text_pos = (x1, y1 - 5)
  73. # 参数:图像, 文本, 文本位置, 字体类型, 字体大小, 字体颜色, 字体粗细
  74. cv2.putText(image2, text, text_pos, cv2.FONT_HERSHEY_SIMPLEX, 1.5, (0, 255, 0), 2)
  75. cv2.imwrite(os.path.join(params['visualization_dir'],f'{wrap_images_name[:-8]}_{i}.jpg'), image2)
  76. else:
  77. cv2.rectangle(image2, (x1, y1), (x2, y2), (0, 0, 255), 2)
  78. text = "score: " + str(round(ssim, 3))
  79. text_pos = (x1, y1 - 5)
  80. # 参数:图像, 文本, 文本位置, 字体类型, 字体大小, 字体颜色, 字体粗细
  81. cv2.putText(image2, text, text_pos, cv2.FONT_HERSHEY_SIMPLEX, 1.5, (0, 0, 255), 2)
  82. cv2.imwrite(os.path.join(params['visualization_dir'], f'{wrap_images_name[:-8]}_{i}.jpg'), image2)
  83. return similarity_results, same_content_boxes
  84. except FileNotFoundError as e:
  85. print(f"An error occurred: {e}")
  86. except Exception as e:
  87. print(f"An unexpected error occurred: {e}")
  88. return None, None
  89. def read_params_from_yml(yml_file_path):
  90. with open(yml_file_path, 'r') as file:
  91. params = yaml.safe_load(file)
  92. return params
  93. def batch_ssim(im1, im2):
  94. imgsize = im1.shape[1] * im1.shape[2]
  95. avg1 = im1.mean((1, 2), keepdims=1)
  96. avg2 = im2.mean((1, 2), keepdims=1)
  97. std1 = im1.std((1, 2), ddof=1)
  98. std2 = im2.std((1, 2), ddof=1)
  99. cov = ((im1 - avg1) * (im2 - avg2)).mean((1, 2)) * imgsize / (imgsize - 1)
  100. avg1 = np.squeeze(avg1)
  101. avg2 = np.squeeze(avg2)
  102. k1 = 0.01
  103. k2 = 0.03
  104. c1 = (k1 * 255) ** 2
  105. c2 = (k2 * 255) ** 2
  106. c3 = c2 / 2
  107. # return np.mean((cov + c3) / (std1 * std2 + c3))
  108. return np.mean(
  109. (2 * avg1 * avg2 + c1) * 2 * (cov + c3) / (avg1 ** 2 + avg2 ** 2 + c1) / (std1 ** 2 + std2 ** 2 + c2))
  110. if __name__ == "__main__":
  111. yml_file_path = 'params.yml'
  112. params = read_params_from_yml(yml_file_path)
  113. wrap_images_all = registration_demo(params['image_dir'],params['demo_image_path'], params['json_ref_path'], params['ref_image_path'])
  114. for i, item in enumerate(wrap_images_all):
  115. wrap_image,wrap_images_name = item
  116. similarity_results, same_content_boxes = compare_boxes_similarity(params['path_to_image1'], wrap_image, params['json_file_path'],
  117. params['similarity_threshold'])
  118. # 打印所有坐标框的相似度结果
  119. print(f"{wrap_images_name}\n")
  120. for idx, score in enumerate(similarity_results, 1):
  121. print(f"Similarity Score for Box {idx}: {score}")
  122. # 打印被认为是相同内容的坐标框
  123. print("Boxes with the same content:")
  124. for shape in same_content_boxes:
  125. print(shape['label'] + ' object is same as template')