# %BANNER_BEGIN%
# ---------------------------------------------------------------------
# %COPYRIGHT_BEGIN%
#
# Magic Leap, Inc. ("COMPANY") CONFIDENTIAL
#
# Unpublished Copyright (c) 2020
# Magic Leap, Inc., All Rights Reserved.
#
# NOTICE: All information contained herein is, and remains the property
# of COMPANY. The intellectual and technical concepts contained herein
# are proprietary to COMPANY and may be covered by U.S. and Foreign
# Patents, patents in process, and are protected by trade secret or
# copyright law. Dissemination of this information or reproduction of
# this material is strictly forbidden unless prior written permission is
# obtained from COMPANY. Access to the source code contained herein is
# hereby forbidden to anyone except current COMPANY employees, managers
# or contractors who have executed Confidentiality and Non-disclosure
# agreements explicitly covering such access.
#
# The copyright notice above does not evidence any actual or intended
# publication or disclosure of this source code, which includes
# information that is confidential and/or proprietary, and is a trade
# secret, of COMPANY. ANY REPRODUCTION, MODIFICATION, DISTRIBUTION,
# PUBLIC PERFORMANCE, OR PUBLIC DISPLAY OF OR THROUGH USE OF THIS
# SOURCE CODE WITHOUT THE EXPRESS WRITTEN CONSENT OF COMPANY IS
# STRICTLY PROHIBITED, AND IN VIOLATION OF APPLICABLE LAWS AND
# INTERNATIONAL TREATIES. THE RECEIPT OR POSSESSION OF THIS SOURCE
# CODE AND/OR RELATED INFORMATION DOES NOT CONVEY OR IMPLY ANY RIGHTS
# TO REPRODUCE, DISCLOSE OR DISTRIBUTE ITS CONTENTS, OR TO MANUFACTURE,
# USE, OR SELL ANYTHING THAT IT MAY DESCRIBE, IN WHOLE OR IN PART.
#
# %COPYRIGHT_END%
# ----------------------------------------------------------------------
# %AUTHORS_BEGIN%
#
# Originating Authors: Paul-Edouard Sarlin
#
# %AUTHORS_END%
# --------------------------------------------------------------------*/
# %BANNER_END%

import torch
from torch import nn


def simple_nms(scores, nms_radius: int):
    """ Fast Non-maximum suppression to remove nearby points """
    assert nms_radius >= 0

    def max_pool(x):
        return torch.nn.functional.max_pool2d(
            x, kernel_size=nms_radius*2+1, stride=1, padding=nms_radius)

    zeros = torch.zeros_like(scores)
    max_mask = scores == max_pool(scores)
    for _ in range(2):
        supp_mask = max_pool(max_mask.float()) > 0
        supp_scores = torch.where(supp_mask, zeros, scores)
        new_max_mask = supp_scores == max_pool(supp_scores)
        max_mask = max_mask | (new_max_mask & (~supp_mask))
    return torch.where(max_mask, scores, zeros)


def remove_borders(keypoints, scores, border: int, height: int, width: int):
    """ Removes keypoints too close to the border """
    mask_h = (keypoints[:, 0] >= border) & (keypoints[:, 0] < (height - border))
    mask_w = (keypoints[:, 1] >= border) & (keypoints[:, 1] < (width - border))
    mask = mask_h & mask_w
    return keypoints[mask], scores[mask]


def top_k_keypoints(keypoints, scores, k: int):
    if k >= len(keypoints):
        return keypoints, scores
    scores, indices = torch.topk(scores, k, dim=0)
    return keypoints[indices], scores


def sample_descriptors(keypoints, descriptors, s: int = 8):
    """ Interpolate descriptors at keypoint locations """
    b, c, h, w = descriptors.shape
    # Map keypoint (x, y) pixel coordinates onto the coarse descriptor grid,
    # then rescale to the (-1, 1) range expected by grid_sample.
    keypoints = keypoints - s / 2 + 0.5
    keypoints /= torch.tensor([(w*s - s/2 - 0.5), (h*s - s/2 - 0.5)],
                              ).to(keypoints)[None]
    keypoints = keypoints*2 - 1  # normalize to (-1, 1)
    # 'align_corners' was added to grid_sample in torch 1.3; compare parsed
    # version numbers rather than strings so e.g. '1.13' is handled correctly.
    major, minor = (int(v) for v in torch.__version__.split('.')[:2])
    args = {'align_corners': True} if (major, minor) >= (1, 3) else {}
    descriptors = torch.nn.functional.grid_sample(
        descriptors, keypoints.view(b, 1, -1, 2), mode='bilinear', **args)
    descriptors = torch.nn.functional.normalize(
        descriptors.reshape(b, c, -1), p=2, dim=1)
    return descriptors


class SuperPoint(nn.Module):
    """SuperPoint Convolutional Detector and Descriptor

    SuperPoint: Self-Supervised Interest Point Detection and
    Description. Daniel DeTone, Tomasz Malisiewicz, and Andrew
    Rabinovich. In CVPRW, 2018. https://arxiv.org/abs/1712.07629
    """
    default_config = {
        'descriptor_dim': 256,
        'nms_radius': 4,
        'keypoint_threshold': 0.005,
        'max_keypoints': -1,
        'remove_borders': 4,
    }

    def __init__(self, config):
        super().__init__()
        self.config = {**self.default_config, **config}

        self.relu = nn.ReLU(inplace=True)
        self.pool = nn.MaxPool2d(kernel_size=2, stride=2)
        c1, c2, c3, c4, c5 = 64, 64, 128, 128, 256

        # Shared encoder
        self.conv1a = nn.Conv2d(1, c1, kernel_size=3, stride=1, padding=1)
        self.conv1b = nn.Conv2d(c1, c1, kernel_size=3, stride=1, padding=1)
        self.conv2a = nn.Conv2d(c1, c2, kernel_size=3, stride=1, padding=1)
        self.conv2b = nn.Conv2d(c2, c2, kernel_size=3, stride=1, padding=1)
        self.conv3a = nn.Conv2d(c2, c3, kernel_size=3, stride=1, padding=1)
        self.conv3b = nn.Conv2d(c3, c3, kernel_size=3, stride=1, padding=1)
        self.conv4a = nn.Conv2d(c3, c4, kernel_size=3, stride=1, padding=1)
        self.conv4b = nn.Conv2d(c4, c4, kernel_size=3, stride=1, padding=1)

        # Detector head: 65 channels = one per cell of an 8x8 patch + dustbin
        self.convPa = nn.Conv2d(c4, c5, kernel_size=3, stride=1, padding=1)
        self.convPb = nn.Conv2d(c5, 65, kernel_size=1, stride=1, padding=0)

        # Descriptor head
        self.convDa = nn.Conv2d(c4, c5, kernel_size=3, stride=1, padding=1)
        self.convDb = nn.Conv2d(
            c5, self.config['descriptor_dim'],
            kernel_size=1, stride=1, padding=0)

        mk = self.config['max_keypoints']
        if mk == 0 or mk < -1:
            raise ValueError('"max_keypoints" must be positive or "-1"')

    def forward(self, data):
        """ Compute keypoints, scores, descriptors for image """
        # Shared Encoder
        x = self.relu(self.conv1a(data['image']))
        x = self.relu(self.conv1b(x))
        x = self.pool(x)
        x = self.relu(self.conv2a(x))
        x = self.relu(self.conv2b(x))
        x = self.pool(x)
        x = self.relu(self.conv3a(x))
        x = self.relu(self.conv3b(x))
        x = self.pool(x)
        x = self.relu(self.conv4a(x))
        x = self.relu(self.conv4b(x))

        # Compute the dense keypoint scores
        cPa = self.relu(self.convPa(x))
        scores = self.convPb(cPa)
        scores = torch.nn.functional.softmax(scores, 1)[:, :-1]
        b, _, h, w = scores.shape
        scores = scores.permute(0, 2, 3, 1).reshape(b, h, w, 8, 8)
        scores = scores.permute(0, 1, 3, 2, 4).reshape(b, h*8, w*8)
        scores = simple_nms(scores, self.config['nms_radius'])

        # Extract keypoints
        keypoints = [
            torch.nonzero(s > self.config['keypoint_threshold'])
            for s in scores]
        scores = [s[tuple(k.t())] for s, k in zip(scores, keypoints)]

        # Discard keypoints near the image borders
        keypoints, scores = list(zip(*[
            remove_borders(k, s, self.config['remove_borders'], h*8, w*8)
            for k, s in zip(keypoints, scores)]))

        # Keep the k keypoints with highest score
        if self.config['max_keypoints'] >= 0:
            keypoints, scores = list(zip(*[
                top_k_keypoints(k, s, self.config['max_keypoints'])
                for k, s in zip(keypoints, scores)]))

        # Convert (h, w) to (x, y)
        keypoints = [torch.flip(k, [1]).float() for k in keypoints]

        # Compute the dense descriptors
        cDa = self.relu(self.convDa(x))
        descriptors = self.convDb(cDa)
        descriptors = torch.nn.functional.normalize(descriptors, p=2, dim=1)

        # Extract descriptors
        descriptors = [sample_descriptors(k[None], d[None], 8)[0]
                       for k, d in zip(keypoints, descriptors)]

        return {
            'keypoints': keypoints,
            'scores': scores,
            'descriptors': descriptors,
        }
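

if __name__ == '__main__':
    # Minimal usage sketch (not part of the original file): run the network on
    # a random grayscale image to illustrate the expected input and the format
    # of the returned dictionary. No pretrained weights are loaded here, so the
    # detected keypoints and descriptors are meaningless; only the shapes and
    # the per-image list structure of the outputs are informative.
    torch.manual_seed(0)
    model = SuperPoint({'max_keypoints': 1024}).eval()
    # Input: (batch, 1, height, width), grayscale in [0, 1], dims divisible by 8.
    image = torch.rand(1, 1, 480, 640)
    with torch.no_grad():
        pred = model({'image': image})
    print('keypoints:  ', pred['keypoints'][0].shape)    # (N, 2), (x, y) order
    print('scores:     ', pred['scores'][0].shape)       # (N,)
    print('descriptors:', pred['descriptors'][0].shape)  # (descriptor_dim, N)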