matcher.py 2.3 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273
  1. from typing import Any, Dict, List, Optional, Tuple
  2. import cv2
  3. import numpy as np
  4. from superpoint_superglue_deployment.superglue_handler import SuperGlueHandler
  5. from superpoint_superglue_deployment.superpoint_handler import SuperPointHandler
  6. __all__ = ["Matcher"]
  7. class Matcher:
  8. __DEFAULT_CONFIG: Dict[str, Any] = {
  9. "superpoint": {
  10. "descriptor_dim": 256,
  11. "nms_radius": 4,
  12. "keypoint_threshold": 0.005,
  13. "max_keypoints": -1,
  14. "remove_borders": 4,
  15. "input_shape": (-1, -1),
  16. },
  17. "superglue": {
  18. "descriptor_dim": 256,
  19. "weights": "outdoor",
  20. "keypoint_encoder": [32, 64, 128, 256],
  21. "GNN_layers": ["self", "cross"] * 9,
  22. "sinkhorn_iterations": 100,
  23. "match_threshold": 0.2,
  24. },
  25. "use_gpu": True,
  26. }
  27. def __init__(
  28. self,
  29. config: Optional[Dict[str, Any]] = None,
  30. ):
  31. self._config = self.__DEFAULT_CONFIG.copy()
  32. if config is not None:
  33. self._config.update(config)
  34. self._config["superpoint"].update({"use_gpu": self._config["use_gpu"]})
  35. self._config["superglue"].update({"use_gpu": self._config["use_gpu"]})
  36. self._superpoint_handler = SuperPointHandler(self._config["superpoint"])
  37. self._superglue_handler = SuperGlueHandler(self._config["superglue"])
  38. def match(
  39. self,
  40. query_image: np.ndarray,
  41. ref_image: np.ndarray,
  42. ) -> Tuple[List[cv2.KeyPoint], List[cv2.KeyPoint], np.ndarray, np.ndarray, List[cv2.DMatch]]:
  43. """
  44. Parameters
  45. ----------
  46. query_image:
  47. Single channel 8bit image
  48. ref_image:
  49. Single channel 8bit image
  50. """
  51. query_pred = self._superpoint_handler.run(query_image)
  52. ref_pred = self._superpoint_handler.run(ref_image)
  53. query_kpts, query_descs = self._superpoint_handler.process_prediction(query_pred)
  54. ref_kpts, ref_descs = self._superpoint_handler.process_prediction(ref_pred)
  55. return (
  56. query_kpts,
  57. ref_kpts,
  58. query_descs,
  59. ref_descs,
  60. self._superglue_handler.match(
  61. query_pred,
  62. ref_pred,
  63. query_image.shape[:2],
  64. ref_image.shape[:2],
  65. ),
  66. )