superpoint_handler.py 5.5 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141
  1. import os
  2. from typing import Any, Dict, List, Optional, Tuple, Union
  3. import cv2
  4. import numpy as np
  5. import torch
  6. from loguru import logger
  7. from superpoint_superglue_deployment.core import assert_single_channel
  8. from superpoint_superglue_deployment.superpoint import SuperPoint
  9. __all__ = ["SuperPointHandler"]
  10. class SuperPointHandler:
  11. __CACHED_DIR = os.path.join(os.path.expanduser("~"), ".cache/torch/hub/checkpoints")
  12. __MODEL_WEIGHTS_FILE_NAME = "superpoint_v1.pth"
  13. __MODEL_WEIGHTS_URL = (
  14. "https://github.com/xmba15/superpoint_superglue_deployment/releases/download/model_weights/superpoint_v1.pth"
  15. )
  16. __DEFAULT_CONFIG: Dict[str, Any] = {
  17. "descriptor_dim": 256,
  18. "nms_radius": 4,
  19. "keypoint_threshold": 0.005,
  20. "max_keypoints": -1,
  21. "remove_borders": 4,
  22. "input_shape": (-1, -1),
  23. "use_gpu": True,
  24. }
  25. def __init__(
  26. self,
  27. config: Optional[Dict[str, Any]] = None, shape=160
  28. ):
  29. self.shape = shape
  30. self._config = self.__DEFAULT_CONFIG.copy()
  31. if config is not None:
  32. self._config.update(config)
  33. os.makedirs(self.__CACHED_DIR, exist_ok=True)
  34. if all([e > 0 for e in self._config["input_shape"]]):
  35. self._validate_input_shape(self._config["input_shape"])
  36. if self._config["use_gpu"] and not torch.cuda.is_available():
  37. logger.info("gpu environment is not available, falling back to cpu")
  38. self._config["use_gpu"] = False
  39. self._device = torch.device("cuda" if self._config["use_gpu"] else "cpu")
  40. self._superpoint_engine = SuperPoint(self._config)
  41. if not os.path.isfile(os.path.join(self.__CACHED_DIR, self.__MODEL_WEIGHTS_FILE_NAME)):
  42. torch.hub.load_state_dict_from_url(self.__MODEL_WEIGHTS_URL, map_location=lambda storage, loc: storage)
  43. self._superpoint_engine.load_state_dict(
  44. torch.load(os.path.join(self.__CACHED_DIR, self.__MODEL_WEIGHTS_FILE_NAME))
  45. )
  46. self._superpoint_engine = self._superpoint_engine.eval().to(self._device)
  47. logger.info(f"loaded superpoint weights {self.__MODEL_WEIGHTS_FILE_NAME}")
  48. def _validate_input_shape(self, image_shape: Tuple[int, int]):
  49. assert (
  50. max(image_shape) >= self.shape and max(image_shape) <= 4100
  51. ), f"input resolution {image_shape} is too small or too large"
  52. @property
  53. def device(self):
  54. return self._device
  55. def run(self, image: np.ndarray) -> Dict[str, Tuple[torch.Tensor]]:
  56. """
  57. Returns
  58. -------
  59. Dict[str, Tuple[torch.Tensor]]
  60. dict data in the following form:
  61. {
  62. "keypoints": List[torch.Tensor] # tensor has shape: num keypoints x 2
  63. "scores": Tuple[torch.Tensor] # tensor has shape: num keypoints
  64. "descriptors": List[torch.Tensor] # tensor has shape: 256 x num keypoints
  65. }
  66. """
  67. assert_single_channel(image)
  68. self._validate_input_shape(image.shape[:2])
  69. with torch.no_grad():
  70. pred = self._superpoint_engine({"image": self._to_tensor(image)})
  71. if all([e > 0 for e in self._config["input_shape"]]):
  72. pred["keypoints"][0] = torch.mul(
  73. pred["keypoints"][0],
  74. torch.from_numpy(np.divide(image.shape[:2][::-1], self._config["input_shape"][::-1])).to(self._device),
  75. )
  76. return pred
  77. def process_prediction(self, pred: Dict[str, torch.Tensor]) -> Tuple[List[cv2.KeyPoint], np.ndarray]:
  78. keypoints_arr = pred["keypoints"][0].cpu().numpy() # num keypoints x 2
  79. scores_arr = pred["scores"][0].cpu().numpy() # num keypoints
  80. descriptors_arr = pred["descriptors"][0].cpu().numpy() # 256 x num keypoints
  81. del pred
  82. num_keypoints = keypoints_arr.shape[0]
  83. if num_keypoints == 0:
  84. return [], np.array([])
  85. keypoints = []
  86. for idx in range(num_keypoints):
  87. keypoint = cv2.KeyPoint()
  88. keypoint.pt = keypoints_arr[idx]
  89. keypoint.response = scores_arr[idx]
  90. keypoints.append(keypoint)
  91. return keypoints, descriptors_arr.transpose(1, 0)
  92. def to_prediction(
  93. self,
  94. keypoints: List[cv2.KeyPoint],
  95. descriptors: np.ndarray,
  96. ) -> Dict[str, Union[Tuple[torch.Tensor], List[torch.Tensor]]]:
  97. pred: Dict[str, Union[Tuple[torch.Tensor], List[torch.Tensor]]] = dict()
  98. pred["keypoints"] = [
  99. torch.from_numpy(np.array([keypoint.pt for keypoint in keypoints])).float().to(self._device)
  100. ]
  101. pred["scores"] = (
  102. torch.from_numpy(np.array([keypoint.response for keypoint in keypoints])).float().to(self._device),
  103. )
  104. pred["descriptors"] = [torch.from_numpy(descriptors.transpose(1, 0)).float().to(self._device)]
  105. return pred
  106. def detect_and_compute(self, image: np.ndarray) -> Tuple[List[cv2.KeyPoint], np.ndarray]:
  107. pred = self.run(image)
  108. return self.process_prediction(pred)
  109. def detect(self, image) -> List[cv2.KeyPoint]:
  110. return self.detect_and_compute(image)[0]
  111. def _to_tensor(self, image: np.ndarray):
  112. if all([e > 0 for e in self._config["input_shape"]]):
  113. return (
  114. torch.from_numpy(cv2.resize(image, self._config["input_shape"][::-1]).astype(np.float32) / 255.0)
  115. .float()[None, None]
  116. .to(self._device)
  117. )
  118. else:
  119. return torch.from_numpy(image.astype(np.float32) / 255.0).float()[None, None].to(self._device)