result_set.h 15 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543
  1. /***********************************************************************
  2. * Software License Agreement (BSD License)
  3. *
  4. * Copyright 2008-2009 Marius Muja (mariusm@cs.ubc.ca). All rights reserved.
  5. * Copyright 2008-2009 David G. Lowe (lowe@cs.ubc.ca). All rights reserved.
  6. *
  7. * THE BSD LICENSE
  8. *
  9. * Redistribution and use in source and binary forms, with or without
  10. * modification, are permitted provided that the following conditions
  11. * are met:
  12. *
  13. * 1. Redistributions of source code must retain the above copyright
  14. * notice, this list of conditions and the following disclaimer.
  15. * 2. Redistributions in binary form must reproduce the above copyright
  16. * notice, this list of conditions and the following disclaimer in the
  17. * documentation and/or other materials provided with the distribution.
  18. *
  19. * THIS SOFTWARE IS PROVIDED BY THE AUTHOR ``AS IS'' AND ANY EXPRESS OR
  20. * IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES
  21. * OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED.
  22. * IN NO EVENT SHALL THE AUTHOR BE LIABLE FOR ANY DIRECT, INDIRECT,
  23. * INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT
  24. * NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE,
  25. * DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY
  26. * THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
  27. * (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF
  28. * THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
  29. *************************************************************************/
  30. #ifndef OPENCV_FLANN_RESULTSET_H
  31. #define OPENCV_FLANN_RESULTSET_H
  32. #include <algorithm>
  33. #include <cstring>
  34. #include <iostream>
  35. #include <limits>
  36. #include <set>
  37. #include <vector>
  38. namespace cvflann
  39. {
  /* A branch point recorded while searching a tree for neighbours: the node
   * at which the search can resume, together with the minimum distance from
   * the query to any node below it. Branches are ordered by that distance only.
   */
  template <typename T, typename DistanceType>
  struct BranchStruct
  {
      T node;               /* Tree node at which search resumes */
      DistanceType mindist; /* Minimum distance to query for all nodes below. */

      /* Leaves both members default-initialized (indeterminate for built-in types). */
      BranchStruct() {}

      BranchStruct(const T& resumeAt, DistanceType lowerBound)
          : node(resumeAt), mindist(lowerBound)
      {
      }

      /* Compares by mindist alone; the node takes no part in the ordering. */
      bool operator<(const BranchStruct& other) const
      {
          const DistanceType& ownBound = this->mindist;
          return ownBound < other.mindist;
      }
  };
  /* Interface through which a nearest-neighbour search reports candidate points. */
  template <typename DistanceType>
  class ResultSet
  {
  public:
      /* Virtual so that implementations can be destroyed through this interface. */
      virtual ~ResultSet()
      {
      }

      /* Whether the set considers itself full; each implementation defines what that means. */
      virtual bool full() const = 0;

      /* Offer a candidate point with its distance to the query and its index. */
      virtual void addPoint(DistanceType dist, int index) = 0;

      /* Distance bound reported by the implementation (e.g. the current k-th best
       * distance, or the search radius). */
      virtual DistanceType worstDist() const = 0;
  };
  65. /**
  66. * KNNSimpleResultSet does not ensure that the element it holds are unique.
  67. * Is used in those cases where the nearest neighbour algorithm used does not
  68. * attempt to insert the same element multiple times.
  69. */
  70. template <typename DistanceType>
  71. class KNNSimpleResultSet : public ResultSet<DistanceType>
  72. {
  73. int* indices;
  74. DistanceType* dists;
  75. int capacity;
  76. int count;
  77. DistanceType worst_distance_;
  78. public:
  79. KNNSimpleResultSet(int capacity_) : capacity(capacity_), count(0)
  80. {
  81. }
  82. void init(int* indices_, DistanceType* dists_)
  83. {
  84. indices = indices_;
  85. dists = dists_;
  86. count = 0;
  87. worst_distance_ = (std::numeric_limits<DistanceType>::max)();
  88. dists[capacity-1] = worst_distance_;
  89. }
  90. size_t size() const
  91. {
  92. return count;
  93. }
  94. bool full() const CV_OVERRIDE
  95. {
  96. return count == capacity;
  97. }
  98. void addPoint(DistanceType dist, int index) CV_OVERRIDE
  99. {
  100. if (dist >= worst_distance_) return;
  101. int i;
  102. for (i=count; i>0; --i) {
  103. #ifdef FLANN_FIRST_MATCH
  104. if ( (dists[i-1]>dist) || ((dist==dists[i-1])&&(indices[i-1]>index)) )
  105. #else
  106. if (dists[i-1]>dist)
  107. #endif
  108. {
  109. if (i<capacity) {
  110. dists[i] = dists[i-1];
  111. indices[i] = indices[i-1];
  112. }
  113. }
  114. else break;
  115. }
  116. if (count < capacity) ++count;
  117. dists[i] = dist;
  118. indices[i] = index;
  119. worst_distance_ = dists[capacity-1];
  120. }
  121. DistanceType worstDist() const CV_OVERRIDE
  122. {
  123. return worst_distance_;
  124. }
  125. };
  126. /**
  127. * K-Nearest neighbour result set. Ensures that the elements inserted are unique
  128. */
  129. template <typename DistanceType>
  130. class KNNResultSet : public ResultSet<DistanceType>
  131. {
  132. int* indices;
  133. DistanceType* dists;
  134. int capacity;
  135. int count;
  136. DistanceType worst_distance_;
  137. public:
  138. KNNResultSet(int capacity_) : capacity(capacity_), count(0)
  139. {
  140. }
  141. void init(int* indices_, DistanceType* dists_)
  142. {
  143. indices = indices_;
  144. dists = dists_;
  145. count = 0;
  146. worst_distance_ = (std::numeric_limits<DistanceType>::max)();
  147. dists[capacity-1] = worst_distance_;
  148. }
  149. size_t size() const
  150. {
  151. return count;
  152. }
  153. bool full() const CV_OVERRIDE
  154. {
  155. return count == capacity;
  156. }
  157. void addPoint(DistanceType dist, int index) CV_OVERRIDE
  158. {
  159. if (dist >= worst_distance_) return;
  160. int i;
  161. for (i = count; i > 0; --i) {
  162. #ifdef FLANN_FIRST_MATCH
  163. if ( (dists[i-1]<=dist) && ((dist!=dists[i-1])||(indices[i-1]<=index)) )
  164. #else
  165. if (dists[i-1]<=dist)
  166. #endif
  167. {
  168. // Check for duplicate indices
  169. int j = i - 1;
  170. while ((j >= 0) && (dists[j] == dist)) {
  171. if (indices[j] == index) {
  172. return;
  173. }
  174. --j;
  175. }
  176. break;
  177. }
  178. }
  179. if (count < capacity) ++count;
  180. for (int j = count-1; j > i; --j) {
  181. dists[j] = dists[j-1];
  182. indices[j] = indices[j-1];
  183. }
  184. dists[i] = dist;
  185. indices[i] = index;
  186. worst_distance_ = dists[capacity-1];
  187. }
  188. DistanceType worstDist() const CV_OVERRIDE
  189. {
  190. return worst_distance_;
  191. }
  192. };
  193. /**
  194. * A result-set class used when performing a radius based search.
  195. */
  196. template <typename DistanceType>
  197. class RadiusResultSet : public ResultSet<DistanceType>
  198. {
  199. DistanceType radius;
  200. int* indices;
  201. DistanceType* dists;
  202. size_t capacity;
  203. size_t count;
  204. public:
  205. RadiusResultSet(DistanceType radius_, int* indices_, DistanceType* dists_, int capacity_) :
  206. radius(radius_), indices(indices_), dists(dists_), capacity(capacity_)
  207. {
  208. init();
  209. }
  210. ~RadiusResultSet()
  211. {
  212. }
  213. void init()
  214. {
  215. count = 0;
  216. }
  217. size_t size() const
  218. {
  219. return count;
  220. }
  221. bool full() const
  222. {
  223. return true;
  224. }
  225. void addPoint(DistanceType dist, int index)
  226. {
  227. if (dist<radius) {
  228. if ((capacity>0)&&(count < capacity)) {
  229. dists[count] = dist;
  230. indices[count] = index;
  231. }
  232. count++;
  233. }
  234. }
  235. DistanceType worstDist() const
  236. {
  237. return radius;
  238. }
  239. };
  240. ////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////
  241. /** Class that holds the k NN neighbors
  242. * Faster than KNNResultSet as it uses a binary heap and does not maintain two arrays
  243. */
  244. template<typename DistanceType>
  245. class UniqueResultSet : public ResultSet<DistanceType>
  246. {
  247. public:
  248. struct DistIndex
  249. {
  250. DistIndex(DistanceType dist, unsigned int index) :
  251. dist_(dist), index_(index)
  252. {
  253. }
  254. bool operator<(const DistIndex dist_index) const
  255. {
  256. return (dist_ < dist_index.dist_) || ((dist_ == dist_index.dist_) && index_ < dist_index.index_);
  257. }
  258. DistanceType dist_;
  259. unsigned int index_;
  260. };
  261. /** Default cosntructor */
  262. UniqueResultSet() :
  263. is_full_(false), worst_distance_(std::numeric_limits<DistanceType>::max())
  264. {
  265. }
  266. /** Check the status of the set
  267. * @return true if we have k NN
  268. */
  269. inline bool full() const CV_OVERRIDE
  270. {
  271. return is_full_;
  272. }
  273. /** Remove all elements in the set
  274. */
  275. virtual void clear() = 0;
  276. /** Copy the set to two C arrays
  277. * @param indices pointer to a C array of indices
  278. * @param dist pointer to a C array of distances
  279. * @param n_neighbors the number of neighbors to copy
  280. */
  281. virtual void copy(int* indices, DistanceType* dist, int n_neighbors = -1) const
  282. {
  283. if (n_neighbors < 0) {
  284. for (typename std::set<DistIndex>::const_iterator dist_index = dist_indices_.begin(), dist_index_end =
  285. dist_indices_.end(); dist_index != dist_index_end; ++dist_index, ++indices, ++dist) {
  286. *indices = dist_index->index_;
  287. *dist = dist_index->dist_;
  288. }
  289. }
  290. else {
  291. int i = 0;
  292. for (typename std::set<DistIndex>::const_iterator dist_index = dist_indices_.begin(), dist_index_end =
  293. dist_indices_.end(); (dist_index != dist_index_end) && (i < n_neighbors); ++dist_index, ++indices, ++dist, ++i) {
  294. *indices = dist_index->index_;
  295. *dist = dist_index->dist_;
  296. }
  297. }
  298. }
  299. /** Copy the set to two C arrays but sort it according to the distance first
  300. * @param indices pointer to a C array of indices
  301. * @param dist pointer to a C array of distances
  302. * @param n_neighbors the number of neighbors to copy
  303. */
  304. virtual void sortAndCopy(int* indices, DistanceType* dist, int n_neighbors = -1) const
  305. {
  306. copy(indices, dist, n_neighbors);
  307. }
  308. /** The number of neighbors in the set
  309. * @return
  310. */
  311. size_t size() const
  312. {
  313. return dist_indices_.size();
  314. }
  315. /** The distance of the furthest neighbor
  316. * If we don't have enough neighbors, it returns the max possible value
  317. * @return
  318. */
  319. inline DistanceType worstDist() const CV_OVERRIDE
  320. {
  321. return worst_distance_;
  322. }
  323. protected:
  324. /** Flag to say if the set is full */
  325. bool is_full_;
  326. /** The worst distance found so far */
  327. DistanceType worst_distance_;
  328. /** The best candidates so far */
  329. std::set<DistIndex> dist_indices_;
  330. };
  331. ////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////
  332. /** Class that holds the k NN neighbors
  333. * Faster than KNNResultSet as it uses a binary heap and does not maintain two arrays
  334. */
  335. template<typename DistanceType>
  336. class KNNUniqueResultSet : public UniqueResultSet<DistanceType>
  337. {
  338. public:
  339. /** Constructor
  340. * @param capacity the number of neighbors to store at max
  341. */
  342. KNNUniqueResultSet(unsigned int capacity) : capacity_(capacity)
  343. {
  344. this->is_full_ = false;
  345. this->clear();
  346. }
  347. /** Add a possible candidate to the best neighbors
  348. * @param dist distance for that neighbor
  349. * @param index index of that neighbor
  350. */
  351. inline void addPoint(DistanceType dist, int index) CV_OVERRIDE
  352. {
  353. // Don't do anything if we are worse than the worst
  354. if (dist >= worst_distance_) return;
  355. dist_indices_.insert(DistIndex(dist, index));
  356. if (is_full_) {
  357. if (dist_indices_.size() > capacity_) {
  358. dist_indices_.erase(*dist_indices_.rbegin());
  359. worst_distance_ = dist_indices_.rbegin()->dist_;
  360. }
  361. }
  362. else if (dist_indices_.size() == capacity_) {
  363. is_full_ = true;
  364. worst_distance_ = dist_indices_.rbegin()->dist_;
  365. }
  366. }
  367. /** Remove all elements in the set
  368. */
  369. void clear() CV_OVERRIDE
  370. {
  371. dist_indices_.clear();
  372. worst_distance_ = std::numeric_limits<DistanceType>::max();
  373. is_full_ = false;
  374. }
  375. protected:
  376. typedef typename UniqueResultSet<DistanceType>::DistIndex DistIndex;
  377. using UniqueResultSet<DistanceType>::is_full_;
  378. using UniqueResultSet<DistanceType>::worst_distance_;
  379. using UniqueResultSet<DistanceType>::dist_indices_;
  380. /** The number of neighbors to keep */
  381. unsigned int capacity_;
  382. };
  383. ////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////
  384. /** Class that holds the radius nearest neighbors
  385. * It is more accurate than RadiusResult as it is not limited in the number of neighbors
  386. */
  387. template<typename DistanceType>
  388. class RadiusUniqueResultSet : public UniqueResultSet<DistanceType>
  389. {
  390. public:
  391. /** Constructor
  392. * @param radius the maximum distance of a neighbor
  393. */
  394. RadiusUniqueResultSet(DistanceType radius) :
  395. radius_(radius)
  396. {
  397. is_full_ = true;
  398. }
  399. /** Add a possible candidate to the best neighbors
  400. * @param dist distance for that neighbor
  401. * @param index index of that neighbor
  402. */
  403. void addPoint(DistanceType dist, int index) CV_OVERRIDE
  404. {
  405. if (dist <= radius_) dist_indices_.insert(DistIndex(dist, index));
  406. }
  407. /** Remove all elements in the set
  408. */
  409. inline void clear() CV_OVERRIDE
  410. {
  411. dist_indices_.clear();
  412. }
  413. /** Check the status of the set
  414. * @return alwys false
  415. */
  416. inline bool full() const CV_OVERRIDE
  417. {
  418. return true;
  419. }
  420. /** The distance of the furthest neighbor
  421. * If we don't have enough neighbors, it returns the max possible value
  422. * @return
  423. */
  424. inline DistanceType worstDist() const CV_OVERRIDE
  425. {
  426. return radius_;
  427. }
  428. private:
  429. typedef typename UniqueResultSet<DistanceType>::DistIndex DistIndex;
  430. using UniqueResultSet<DistanceType>::dist_indices_;
  431. using UniqueResultSet<DistanceType>::is_full_;
  432. /** The furthest distance a neighbor can be */
  433. DistanceType radius_;
  434. };
  435. ////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////
  436. /** Class that holds the k NN neighbors within a radius distance
  437. */
  438. template<typename DistanceType>
  439. class KNNRadiusUniqueResultSet : public KNNUniqueResultSet<DistanceType>
  440. {
  441. public:
  442. /** Constructor
  443. * @param capacity the number of neighbors to store at max
  444. * @param radius the maximum distance of a neighbor
  445. */
  446. KNNRadiusUniqueResultSet(unsigned int capacity, DistanceType radius)
  447. {
  448. this->capacity_ = capacity;
  449. this->radius_ = radius;
  450. this->dist_indices_.reserve(capacity_);
  451. this->clear();
  452. }
  453. /** Remove all elements in the set
  454. */
  455. void clear()
  456. {
  457. dist_indices_.clear();
  458. worst_distance_ = radius_;
  459. is_full_ = false;
  460. }
  461. private:
  462. using KNNUniqueResultSet<DistanceType>::dist_indices_;
  463. using KNNUniqueResultSet<DistanceType>::is_full_;
  464. using KNNUniqueResultSet<DistanceType>::worst_distance_;
  465. /** The maximum number of neighbors to consider */
  466. unsigned int capacity_;
  467. /** The maximum distance of a neighbor */
  468. DistanceType radius_;
  469. };
  470. }
  471. #endif //OPENCV_FLANN_RESULTSET_H