|
@@ -1,5 +1,5 @@
|
|
|
import os
|
|
|
-from typing import Any, Dict, List, Optional, Tuple
|
|
|
+from typing import Any, Dict, List, Optional, Tuple, Union
|
|
|
|
|
|
import cv2
|
|
|
import numpy as np
|
|
@@ -107,6 +107,21 @@ class SuperPointHandler:
|
|
|
keypoints.append(keypoint)
|
|
|
return keypoints, descriptors_arr.transpose(1, 0)
|
|
|
|
|
|
+ def to_prediction(
|
|
|
+ self,
|
|
|
+ keypoints: List[cv2.KeyPoint],
|
|
|
+ descriptors: np.ndarray,
|
|
|
+ ) -> Dict[str, Union[Tuple[torch.Tensor], List[torch.Tensor]]]:
|
|
|
+ pred: Dict[str, Union[Tuple[torch.Tensor], List[torch.Tensor]]] = dict()
|
|
|
+ pred["keypoints"] = [
|
|
|
+ torch.from_numpy(np.array([keypoint.pt for keypoint in keypoints])).float().to(self._device)
|
|
|
+ ]
|
|
|
+ pred["scores"] = (
|
|
|
+ torch.from_numpy(np.array([keypoint.response for keypoint in keypoints])).float().to(self._device),
|
|
|
+ )
|
|
|
+ pred["descriptors"] = [torch.from_numpy(descriptors.transpose(1, 0)).float().to(self._device)]
|
|
|
+ return pred
|
|
|
+
|
|
|
def detect_and_compute(self, image: np.ndarray) -> Tuple[List[cv2.KeyPoint], np.ndarray]:
|
|
|
pred = self.run(image)
|
|
|
return self.process_prediction(pred)
|