Parcourir la source

hotfix: fix matching bug

xmba15 il y a 1 an
Parent
commit
2d9643d596

+ 6 - 0
README.md

@@ -91,6 +91,12 @@ if __name__ == "__main__":
 - [Notebook with detailed sample code for SuperPoint](notebooks/demo_superpoint.ipynb)
 - [Notebook with detailed sample code for SuperGlue](notebooks/demo_superglue.ipynb)
 
+- Command line to test matching two images after installing the library
+
+```bash
+match_two_images --query_path [path/to/query/image] --ref_path [path/to/reference/image] --use_gpu
+```
+
 ## 🎛 Development environment
 
 ---

BIN
docs/images/matched_image.jpg


Fichier diff supprimé car celui-ci est trop grand
+ 4 - 4
notebooks/demo_superglue.ipynb


+ 5 - 0
setup.py

@@ -35,6 +35,11 @@ def main():
         ],
         packages=find_packages(exclude=["tests"]),
         install_requires=_INSTALL_REQUIRES,
+        entry_points={
+            "console_scripts": [
+                "match_two_images=superpoint_superglue_deployment.__main__:main",
+            ]
+        },
     )
 
 

+ 70 - 0
superpoint_superglue_deployment/__main__.py

@@ -0,0 +1,70 @@
+import cv2
+import numpy as np
+from loguru import logger
+
+from superpoint_superglue_deployment import Matcher
+
+
+def get_args():
+    import argparse
+
+    parser = argparse.ArgumentParser("test matching two images")
+    parser.add_argument("--query_path", "-q", type=str, required=True, help="path to query image")
+    parser.add_argument("--ref_path", "-r", type=str, required=True, help="path to reference image")
+    parser.add_argument("--use_gpu", action="store_true")
+
+    return parser.parse_args()
+
+
+def main():
+    args = get_args()
+
+    query_image = cv2.imread(args.query_path)
+    ref_image = cv2.imread(args.ref_path)
+
+    query_gray = cv2.imread(args.query_path, 0)
+    ref_gray = cv2.imread(args.ref_path, 0)
+
+    superglue_matcher = Matcher(
+        {
+            "superpoint": {
+                "input_shape": (-1, -1),
+                "keypoint_threshold": 0.005,
+            },
+            "superglue": {
+                "match_threshold": 0.2,
+            },
+            "use_gpu": args.use_gpu,
+        }
+    )
+    query_kpts, ref_kpts, _, _, matches = superglue_matcher.match(query_gray, ref_gray)
+    logger.info(f"number of matches by superpoint+superglue: {len(matches)}")
+    _, mask = cv2.findHomography(
+        np.array([query_kpts[m.queryIdx].pt for m in matches], dtype=np.float64).reshape(-1, 1, 2),
+        np.array([ref_kpts[m.trainIdx].pt for m in matches], dtype=np.float64).reshape(-1, 1, 2),
+        method=cv2.USAC_MAGSAC,
+        ransacReprojThreshold=5.0,
+        maxIters=10000,
+        confidence=0.95,
+    )
+    logger.info(f"number of inliers: {mask.sum()}")
+    matches = np.array(matches)[np.all(mask > 0, axis=1)]
+
+    matches = sorted(matches, key=lambda match: match.distance)
+    matched_image = cv2.drawMatches(
+        query_image,
+        query_kpts,
+        ref_image,
+        ref_kpts,
+        matches[:100],
+        None,
+        flags=2,
+    )
+    cv2.imwrite("matched_image.jpg", matched_image)
+    cv2.imshow("matched_image", matched_image)
+    cv2.waitKey(0)
+    cv2.destroyAllWindows()
+
+
+if __name__ == "__main__":
+    main()

+ 8 - 4
superpoint_superglue_deployment/superglue_handler.py

@@ -114,19 +114,23 @@ class SuperGlueHandler:
         query_shape: Tuple[int, int],
         ref_shape: Tuple[int, int],
     ) -> List[cv2.DMatch]:
+        num_query_kpts = query_pred["keypoints"][0].size()[0]
+        num_ref_kpts = ref_pred["keypoints"][0].size()[0]
         pred = self.run(
             query_pred,
             ref_pred,
             query_shape,
             ref_shape,
         )
-        query_indices = pred["matches0"].cpu().numpy().squeeze(0)
-        ref_indices = pred["matches1"].cpu().numpy().squeeze(0)
+        matches0_to_1 = pred["matches0"].cpu().numpy().squeeze(0)
         query_matching_scores = pred["matching_scores0"].cpu().numpy().squeeze(0)
+        valid = matches0_to_1 > -1
 
         del pred
-        matched_query_indices = np.where(query_indices > -1)[0]
-        matched_ref_indices = np.where(ref_indices > -1)[0]
+
+        matched_query_indices = np.arange(num_query_kpts)[valid]
+        matched_ref_indices = np.arange(num_ref_kpts)[matches0_to_1[valid]]
+
         matches = [
             cv2.DMatch(
                 _distance=1 - query_matching_scores[matched_query_idx],

+ 1 - 1
superpoint_superglue_deployment/version.py

@@ -1 +1 @@
-__version__ = "v0.0.2"
+__version__ = "v0.0.3"

Certains fichiers n'ont pas été affichés car il y a eu trop de fichiers modifiés dans ce diff