# superpoint_handler.py

  1. import os
  2. from typing import Any, Dict, List, Optional, Tuple
  3. import cv2
  4. import numpy as np
  5. import torch
  6. from loguru import logger
  7. from superpoint_superglue_deployment.core import assert_single_channel
  8. from superpoint_superglue_deployment.superpoint import SuperPoint
# Public API of this module: `from ... import *` exports only the handler class.
__all__ = ["SuperPointHandler"]
  10. class SuperPointHandler:
  11. __CACHED_DIR = os.path.join(os.path.expanduser("~"), ".cache/torch/hub/checkpoints")
  12. __MODEL_WEIGHTS_FILE_NAME = "superpoint_v1.pth"
  13. __MODEL_WEIGHTS_URL = (
  14. "https://github.com/xmba15/superpoint_superglue_deployment/releases/download/model_weights/superpoint_v1.pth"
  15. )
  16. __DEFAULT_CONFIG: Dict[str, Any] = {
  17. "descriptor_dim": 256,
  18. "nms_radius": 4,
  19. "keypoint_threshold": 0.005,
  20. "max_keypoints": -1,
  21. "remove_borders": 4,
  22. "input_shape": (-1, -1),
  23. "use_gpu": False,
  24. }
  25. def __init__(
  26. self,
  27. config: Optional[Dict[str, Any]] = None,
  28. ):
  29. self._config = self.__DEFAULT_CONFIG.copy()
  30. if config is not None:
  31. self._config.update(config)
  32. os.makedirs(self.__CACHED_DIR, exist_ok=True)
  33. if all([e > 0 for e in self._config["input_shape"]]):
  34. self._validate_input_shape(self._config["input_shape"])
  35. if self._config["use_gpu"] and not torch.cuda.is_available():
  36. logger.info("gpu environment is not available, falling back to cpu")
  37. self._config["use_gpu"] = False
  38. self._device = torch.device("cuda" if self._config["use_gpu"] else "cpu")
  39. self._superpoint_engine = SuperPoint(self._config)
  40. if not os.path.isfile(os.path.join(self.__CACHED_DIR, self.__MODEL_WEIGHTS_FILE_NAME)):
  41. torch.hub.load_state_dict_from_url(self.__MODEL_WEIGHTS_URL, map_location=lambda storage, loc: storage)
  42. self._superpoint_engine.load_state_dict(
  43. torch.load(os.path.join(self.__CACHED_DIR, self.__MODEL_WEIGHTS_FILE_NAME))
  44. )
  45. self._superpoint_engine = self._superpoint_engine.eval().to(self._device)
  46. logger.info(f"loaded superpoint weights {self.__MODEL_WEIGHTS_FILE_NAME}")
  47. def _validate_input_shape(self, image_shape: Tuple[int, int]):
  48. assert (
  49. max(image_shape) >= 160 and max(image_shape) <= 2000
  50. ), f"input resolution {image_shape} is too small or too large"
  51. @property
  52. def device(self):
  53. return self._device
  54. def run(self, image: np.ndarray) -> Dict[str, Tuple[torch.Tensor]]:
  55. """
  56. Returns
  57. -------
  58. Dict[str, Tuple[torch.Tensor]]
  59. dict data in the following form:
  60. {
  61. "keypoints": List[torch.Tensor] # tensor has shape: num keypoints x 2
  62. "scores": Tuple[torch.Tensor] # tensor has shape: num keypoints
  63. "descriptors": List[torch.Tensor] # tensor has shape: 256 x num keypoints
  64. }
  65. """
  66. assert_single_channel(image)
  67. self._validate_input_shape(image.shape[:2])
  68. with torch.no_grad():
  69. pred = self._superpoint_engine({"image": self._to_tensor(image)})
  70. if all([e > 0 for e in self._config["input_shape"]]):
  71. pred["keypoints"][0] = torch.mul(
  72. pred["keypoints"][0],
  73. torch.from_numpy(np.divide(image.shape[:2][::-1], self._config["input_shape"][::-1])).to(self._device),
  74. )
  75. return pred
  76. def process_prediction(self, pred: Dict[str, torch.Tensor]) -> Tuple[List[cv2.KeyPoint], np.ndarray]:
  77. keypoints_arr = pred["keypoints"][0].cpu().numpy() # num keypoints x 2
  78. scores_arr = pred["scores"][0].cpu().numpy() # num keypoints
  79. descriptors_arr = pred["descriptors"][0].cpu().numpy() # 256 x num keypoints
  80. del pred
  81. num_keypoints = keypoints_arr.shape[0]
  82. if num_keypoints == 0:
  83. return [], np.array([])
  84. keypoints = []
  85. for idx in range(num_keypoints):
  86. keypoint = cv2.KeyPoint()
  87. keypoint.pt = keypoints_arr[idx]
  88. keypoint.response = scores_arr[idx]
  89. keypoints.append(keypoint)
  90. return keypoints, descriptors_arr.transpose(1, 0)
  91. def detect_and_compute(self, image: np.ndarray) -> Tuple[List[cv2.KeyPoint], np.ndarray]:
  92. pred = self.run(image)
  93. return self.process_prediction(pred)
  94. def detect(self, image) -> List[cv2.KeyPoint]:
  95. return self.detect_and_compute(image)[0]
  96. def _to_tensor(self, image: np.ndarray):
  97. if all([e > 0 for e in self._config["input_shape"]]):
  98. return (
  99. torch.from_numpy(cv2.resize(image, self._config["input_shape"][::-1]).astype(np.float32) / 255.0)
  100. .float()[None, None]
  101. .to(self._device)
  102. )
  103. else:
  104. return torch.from_numpy(image.astype(np.float32) / 255.0).float()[None, None].to(self._device)