__main__.py 2.0 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970
  1. import cv2
  2. import numpy as np
  3. from loguru import logger
  4. from superpoint_superglue_deployment import Matcher
  5. def get_args():
  6. import argparse
  7. parser = argparse.ArgumentParser("test matching two images")
  8. parser.add_argument("--query_path", "-q", type=str, required=True, help="path to query image")
  9. parser.add_argument("--ref_path", "-r", type=str, required=True, help="path to reference image")
  10. parser.add_argument("--use_gpu", action="store_true")
  11. return parser.parse_args()
  12. def main():
  13. args = get_args()
  14. query_image = cv2.imread(args.query_path)
  15. ref_image = cv2.imread(args.ref_path)
  16. query_gray = cv2.imread(args.query_path, 0)
  17. ref_gray = cv2.imread(args.ref_path, 0)
  18. superglue_matcher = Matcher(
  19. {
  20. "superpoint": {
  21. "input_shape": (-1, -1),
  22. "keypoint_threshold": 0.005,
  23. },
  24. "superglue": {
  25. "match_threshold": 0.2,
  26. },
  27. "use_gpu": args.use_gpu,
  28. }
  29. )
  30. query_kpts, ref_kpts, _, _, matches = superglue_matcher.match(query_gray, ref_gray)
  31. logger.info(f"number of matches by superpoint+superglue: {len(matches)}")
  32. _, mask = cv2.findHomography(
  33. np.array([query_kpts[m.queryIdx].pt for m in matches], dtype=np.float64).reshape(-1, 1, 2),
  34. np.array([ref_kpts[m.trainIdx].pt for m in matches], dtype=np.float64).reshape(-1, 1, 2),
  35. method=cv2.USAC_MAGSAC,
  36. ransacReprojThreshold=5.0,
  37. maxIters=10000,
  38. confidence=0.95,
  39. )
  40. logger.info(f"number of inliers: {mask.sum()}")
  41. matches = np.array(matches)[np.all(mask > 0, axis=1)]
  42. matches = sorted(matches, key=lambda match: match.distance)
  43. matched_image = cv2.drawMatches(
  44. query_image,
  45. query_kpts,
  46. ref_image,
  47. ref_kpts,
  48. matches[:100],
  49. None,
  50. flags=2,
  51. )
  52. cv2.imwrite("matched_image.jpg", matched_image)
  53. cv2.imshow("matched_image", matched_image)
  54. cv2.waitKey(0)
  55. cv2.destroyAllWindows()
  56. if __name__ == "__main__":
  57. main()