Przeglądaj źródła

get torch Tensor prediction interface from opencv keypoints, numpy descriptors

xmba15 1 rok temu
rodzic
commit
8ab0990dd6

+ 16 - 1
superpoint_superglue_deployment/superpoint_handler.py

@@ -1,5 +1,5 @@
 import os
-from typing import Any, Dict, List, Optional, Tuple
+from typing import Any, Dict, List, Optional, Tuple, Union
 
 import cv2
 import numpy as np
@@ -107,6 +107,21 @@ class SuperPointHandler:
             keypoints.append(keypoint)
         return keypoints, descriptors_arr.transpose(1, 0)
 
+    def to_prediction(
+        self,
+        keypoints: List[cv2.KeyPoint],
+        descriptors: np.ndarray,
+    ) -> Dict[str, Union[Tuple[torch.Tensor], List[torch.Tensor]]]:
+        pred: Dict[str, Union[Tuple[torch.Tensor], List[torch.Tensor]]] = dict()
+        pred["keypoints"] = [
+            torch.from_numpy(np.array([keypoint.pt for keypoint in keypoints])).float().to(self._device)
+        ]
+        pred["scores"] = (
+            torch.from_numpy(np.array([keypoint.response for keypoint in keypoints])).float().to(self._device),
+        )
+        pred["descriptors"] = [torch.from_numpy(descriptors.transpose(1, 0)).float().to(self._device)]
+        return pred
+
     def detect_and_compute(self, image: np.ndarray) -> Tuple[List[cv2.KeyPoint], np.ndarray]:
         pred = self.run(image)
         return self.process_prediction(pred)