# superglue_handler.py (5.1 KB)
  1. import os
  2. from typing import Any, Dict, List, Optional, Tuple
  3. import cv2
  4. import numpy as np
  5. import torch
  6. from loguru import logger
  7. from superpoint_superglue_deployment.superglue import SuperGlue
  8. __all__ = ["SuperGlueHandler"]
  9. class SuperGlueHandler:
  10. __CACHED_DIR = os.path.join(os.path.expanduser("~"), ".cache/torch/hub/checkpoints")
  11. __MODEL_WEIGHTS_DICT: Dict[str, Any] = {
  12. "indoor": {
  13. "name": "superglue_indoor.pth",
  14. "url": "https://github.com/xmba15/superpoint_superglue_deployment/releases/download/model_weights/superglue_indoor.pth", # noqa: E501
  15. },
  16. "outdoor": {
  17. "name": "superglue_outdoor.pth",
  18. "url": "https://github.com/xmba15/superpoint_superglue_deployment/releases/download/model_weights/superglue_outdoor.pth", # noqa: E501
  19. },
  20. }
  21. __MODEL_WEIGHTS_OUTDOOR_FILE_NAME = "superglue_outdoor.pth"
  22. __DEFAULT_CONFIG: Dict[str, Any] = {
  23. "descriptor_dim": 256,
  24. "weights": "outdoor",
  25. "keypoint_encoder": [32, 64, 128, 256],
  26. "GNN_layers": ["self", "cross"] * 9,
  27. "sinkhorn_iterations": 100,
  28. "match_threshold": 0.2,
  29. "use_gpu": False,
  30. }
  31. def __init__(
  32. self,
  33. config: Optional[Dict[str, Any]] = None,
  34. ):
  35. self._config = self.__DEFAULT_CONFIG.copy()
  36. if config is not None:
  37. self._config.update(config)
  38. assert self._config["weights"] in self.__MODEL_WEIGHTS_DICT
  39. os.makedirs(self.__CACHED_DIR, exist_ok=True)
  40. if self._config["use_gpu"] and not torch.cuda.is_available():
  41. logger.info("gpu environment is not available, falling back to cpu")
  42. self._config["use_gpu"] = False
  43. self._device = torch.device("cuda" if self._config["use_gpu"] else "cpu")
  44. self._superglue_engine = SuperGlue(self._config)
  45. if not os.path.isfile(
  46. os.path.join(self.__CACHED_DIR, self.__MODEL_WEIGHTS_DICT[self._config["weights"]]["name"])
  47. ):
  48. torch.hub.load_state_dict_from_url(
  49. self.__MODEL_WEIGHTS_DICT[self._config["weights"]]["url"], map_location=lambda storage, loc: storage
  50. )
  51. self._superglue_engine.load_state_dict(
  52. torch.load(os.path.join(self.__CACHED_DIR, self.__MODEL_WEIGHTS_DICT[self._config["weights"]]["name"]))
  53. )
  54. self._superglue_engine = self._superglue_engine.eval().to(self._device)
  55. logger.info(f"loaded superglue weights {self.__MODEL_WEIGHTS_DICT[self._config['weights']]['name']}")
  56. @property
  57. def device(self):
  58. return self._device
  59. def run(
  60. self,
  61. query_pred: Dict[str, torch.Tensor],
  62. ref_pred: Dict[str, torch.Tensor],
  63. query_shape: Tuple[int, int],
  64. ref_shape: Tuple[int, int],
  65. ) -> Dict[str, torch.Tensor]:
  66. """
  67. Parameters
  68. ----------
  69. query_pred
  70. dict data in the following form
  71. {
  72. "keypoints": List[torch.Tensor] # tensor has shape: num keypoints x 2
  73. "descriptors": List[torch.Tensor] # tensor has shape: 256 x num keypoints
  74. }
  75. ref_pred
  76. dict data in the same form as query_pred's
  77. """
  78. data_dict: Dict[str, Any] = dict()
  79. data_dict = {**data_dict, **{k + "0": v for k, v in query_pred.items()}}
  80. data_dict = {**data_dict, **{k + "1": v for k, v in ref_pred.items()}}
  81. for k in data_dict:
  82. if isinstance(data_dict[k], (list, tuple)):
  83. data_dict[k] = torch.stack(data_dict[k])
  84. del query_pred, ref_pred
  85. for k in data_dict:
  86. if isinstance(data_dict[k], torch.Tensor) and data_dict[k].device.type != self._device.type:
  87. data_dict[k] = data_dict[k].to(self._device)
  88. data_dict["image0_shape"] = query_shape
  89. data_dict["image1_shape"] = ref_shape
  90. with torch.no_grad():
  91. return self._superglue_engine(data_dict)
  92. def match(
  93. self,
  94. query_pred: Dict[str, torch.Tensor],
  95. ref_pred: Dict[str, torch.Tensor],
  96. query_shape: Tuple[int, int],
  97. ref_shape: Tuple[int, int],
  98. ) -> List[cv2.DMatch]:
  99. num_query_kpts = query_pred["keypoints"][0].size()[0]
  100. num_ref_kpts = ref_pred["keypoints"][0].size()[0]
  101. pred = self.run(
  102. query_pred,
  103. ref_pred,
  104. query_shape,
  105. ref_shape,
  106. )
  107. matches0_to_1 = pred["matches0"].cpu().numpy().squeeze(0)
  108. query_matching_scores = pred["matching_scores0"].cpu().numpy().squeeze(0)
  109. valid = matches0_to_1 > -1
  110. del pred
  111. matched_query_indices = np.arange(num_query_kpts)[valid]
  112. matched_ref_indices = np.arange(num_ref_kpts)[matches0_to_1[valid]]
  113. matches = [
  114. cv2.DMatch(
  115. _distance=1 - query_matching_scores[matched_query_idx],
  116. _queryIdx=matched_query_idx,
  117. _trainIdx=matched_ref_idx,
  118. )
  119. for matched_query_idx, matched_ref_idx in zip(matched_query_indices, matched_ref_indices)
  120. ]
  121. return matches