12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273 |
- from typing import Any, Dict, List, Optional, Tuple
- import cv2
- import numpy as np
- from superpoint_superglue_deployment.superglue_handler import SuperGlueHandler
- from superpoint_superglue_deployment.superpoint_handler import SuperPointHandler
# Public API of this module: only `Matcher` is exported by a star-import.
- __all__ = ["Matcher"]
class Matcher:
    """Run SuperPoint on two images and match the results with SuperGlue.

    The matcher owns one ``SuperPointHandler`` and one ``SuperGlueHandler``,
    each built from its own section of the configuration.  The top-level
    ``use_gpu`` flag is copied into both sections before the handlers are
    constructed.
    """

    # Defaults used when the caller supplies no (or only a partial) config.
    # Treated as read-only: instances always work on their own copy.
    __DEFAULT_CONFIG: Dict[str, Any] = {
        "superpoint": {
            "descriptor_dim": 256,
            "nms_radius": 4,
            "keypoint_threshold": 0.005,
            "max_keypoints": -1,
            "remove_borders": 4,
            "input_shape": (-1, -1),
        },
        "superglue": {
            "descriptor_dim": 256,
            "weights": "outdoor",
            "keypoint_encoder": [32, 64, 128, 256],
            "GNN_layers": ["self", "cross"] * 9,
            "sinkhorn_iterations": 100,
            "match_threshold": 0.2,
        },
        "use_gpu": True,
    }

    def __init__(
        self,
        config: Optional[Dict[str, Any]] = None,
    ):
        """Build the SuperPoint and SuperGlue handlers.

        Parameters
        ----------
        config:
            Optional overrides for the default configuration.  Top-level keys
            (``"superpoint"``, ``"superglue"``, ``"use_gpu"``) replace the
            corresponding default entry as a whole.  The dictionary passed in
            is not modified.
        """
        self._config = self._resolve_config(config)
        self._superpoint_handler = SuperPointHandler(self._config["superpoint"])
        self._superglue_handler = SuperGlueHandler(self._config["superglue"])

    @classmethod
    def _resolve_config(cls, config: Optional[Dict[str, Any]]) -> Dict[str, Any]:
        """Return a fresh config: defaults, overridden by ``config``, with ``use_gpu`` propagated."""
        import copy

        # Deep copy so nested default dicts/lists are never shared between
        # instances or written back into the class-level defaults.
        resolved = copy.deepcopy(cls.__DEFAULT_CONFIG)
        if config is not None:
            resolved.update(config)

        use_gpu = resolved["use_gpu"]
        for section in ("superpoint", "superglue"):
            # Build a new dict instead of updating in place: the section may
            # be the caller's own object, which must stay untouched.
            resolved[section] = {**resolved[section], "use_gpu": use_gpu}
        return resolved

    def match(
        self,
        query_image: np.ndarray,
        ref_image: np.ndarray,
    ) -> Tuple[List[cv2.KeyPoint], List[cv2.KeyPoint], np.ndarray, np.ndarray, List[cv2.DMatch]]:
        """Extract features from both images and match them.

        Parameters
        ----------
        query_image:
            Single channel 8bit image
        ref_image:
            Single channel 8bit image

        Returns
        -------
        tuple
            ``(query_kpts, ref_kpts, query_descs, ref_descs, matches)``:
            keypoints and descriptors for the query and reference images as
            produced by the SuperPoint handler, followed by the matches
            returned by the SuperGlue handler.
        """
        query_pred = self._superpoint_handler.run(query_image)
        ref_pred = self._superpoint_handler.run(ref_image)
        query_kpts, query_descs = self._superpoint_handler.process_prediction(query_pred)
        ref_kpts, ref_descs = self._superpoint_handler.process_prediction(ref_pred)

        # SuperGlue receives the raw SuperPoint predictions plus the first two
        # shape entries of each image (presumably height and width - confirm
        # against SuperGlueHandler.match).
        matches = self._superglue_handler.match(
            query_pred,
            ref_pred,
            query_image.shape[:2],
            ref_image.shape[:2],
        )
        return query_kpts, ref_kpts, query_descs, ref_descs, matches
|