kdtree_single_index.h 20 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606607608609610611612613614615616617618619620621622623624625626627628629630631632633634635
  1. /***********************************************************************
  2. * Software License Agreement (BSD License)
  3. *
  4. * Copyright 2008-2009 Marius Muja (mariusm@cs.ubc.ca). All rights reserved.
  5. * Copyright 2008-2009 David G. Lowe (lowe@cs.ubc.ca). All rights reserved.
  6. *
  7. * THE BSD LICENSE
  8. *
  9. * Redistribution and use in source and binary forms, with or without
  10. * modification, are permitted provided that the following conditions
  11. * are met:
  12. *
  13. * 1. Redistributions of source code must retain the above copyright
  14. * notice, this list of conditions and the following disclaimer.
  15. * 2. Redistributions in binary form must reproduce the above copyright
  16. * notice, this list of conditions and the following disclaimer in the
  17. * documentation and/or other materials provided with the distribution.
  18. *
  19. * THIS SOFTWARE IS PROVIDED BY THE AUTHOR ``AS IS'' AND ANY EXPRESS OR
  20. * IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES
  21. * OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED.
  22. * IN NO EVENT SHALL THE AUTHOR BE LIABLE FOR ANY DIRECT, INDIRECT,
  23. * INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT
  24. * NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE,
  25. * DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY
  26. * THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
  27. * (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF
  28. * THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
  29. *************************************************************************/
  30. #ifndef OPENCV_FLANN_KDTREE_SINGLE_INDEX_H_
  31. #define OPENCV_FLANN_KDTREE_SINGLE_INDEX_H_
  32. #include <algorithm>
  33. #include <map>
  34. #include <cassert>
  35. #include <cstring>
  36. #include "general.h"
  37. #include "nn_index.h"
  38. #include "matrix.h"
  39. #include "result_set.h"
  40. #include "heap.h"
  41. #include "allocator.h"
  42. #include "random.h"
  43. #include "saving.h"
  44. namespace cvflann
  45. {
  46. struct KDTreeSingleIndexParams : public IndexParams
  47. {
  48. KDTreeSingleIndexParams(int leaf_max_size = 10, bool reorder = true, int dim = -1)
  49. {
  50. (*this)["algorithm"] = FLANN_INDEX_KDTREE_SINGLE;
  51. (*this)["leaf_max_size"] = leaf_max_size;
  52. (*this)["reorder"] = reorder;
  53. (*this)["dim"] = dim;
  54. }
  55. };
  56. /**
  57. * Single kd-tree index
  58. *
  59. * Contains one k-d tree and other information for indexing a set of points
  60. * for nearest-neighbor matching.
  61. */
  62. template <typename Distance>
  63. class KDTreeSingleIndex : public NNIndex<Distance>
  64. {
  65. public:
  66. typedef typename Distance::ElementType ElementType;
  67. typedef typename Distance::ResultType DistanceType;
  68. /**
  69. * KDTree constructor
  70. *
  71. * Params:
  72. * inputData = dataset with the input features
  73. * params = parameters passed to the kdtree algorithm
  74. */
  75. KDTreeSingleIndex(const Matrix<ElementType>& inputData, const IndexParams& params = KDTreeSingleIndexParams(),
  76. Distance d = Distance() ) :
  77. dataset_(inputData), index_params_(params), distance_(d)
  78. {
  79. size_ = dataset_.rows;
  80. dim_ = dataset_.cols;
  81. root_node_ = 0;
  82. int dim_param = get_param(params,"dim",-1);
  83. if (dim_param>0) dim_ = dim_param;
  84. leaf_max_size_ = get_param(params,"leaf_max_size",10);
  85. reorder_ = get_param(params,"reorder",true);
  86. // Create a permutable array of indices to the input vectors.
  87. vind_.resize(size_);
  88. for (size_t i = 0; i < size_; i++) {
  89. vind_[i] = (int)i;
  90. }
  91. }
  92. KDTreeSingleIndex(const KDTreeSingleIndex&);
  93. KDTreeSingleIndex& operator=(const KDTreeSingleIndex&);
  94. /**
  95. * Standard destructor
  96. */
  97. ~KDTreeSingleIndex()
  98. {
  99. if (reorder_) delete[] data_.data;
  100. }
  101. /**
  102. * Builds the index
  103. */
  104. void buildIndex() CV_OVERRIDE
  105. {
  106. computeBoundingBox(root_bbox_);
  107. root_node_ = divideTree(0, (int)size_, root_bbox_ ); // construct the tree
  108. if (reorder_) {
  109. delete[] data_.data;
  110. data_ = cvflann::Matrix<ElementType>(new ElementType[size_*dim_], size_, dim_);
  111. for (size_t i=0; i<size_; ++i) {
  112. for (size_t j=0; j<dim_; ++j) {
  113. data_[i][j] = dataset_[vind_[i]][j];
  114. }
  115. }
  116. }
  117. else {
  118. data_ = dataset_;
  119. }
  120. }
  121. flann_algorithm_t getType() const CV_OVERRIDE
  122. {
  123. return FLANN_INDEX_KDTREE_SINGLE;
  124. }
  125. void saveIndex(FILE* stream) CV_OVERRIDE
  126. {
  127. save_value(stream, size_);
  128. save_value(stream, dim_);
  129. save_value(stream, root_bbox_);
  130. save_value(stream, reorder_);
  131. save_value(stream, leaf_max_size_);
  132. save_value(stream, vind_);
  133. if (reorder_) {
  134. save_value(stream, data_);
  135. }
  136. save_tree(stream, root_node_);
  137. }
  138. void loadIndex(FILE* stream) CV_OVERRIDE
  139. {
  140. load_value(stream, size_);
  141. load_value(stream, dim_);
  142. load_value(stream, root_bbox_);
  143. load_value(stream, reorder_);
  144. load_value(stream, leaf_max_size_);
  145. load_value(stream, vind_);
  146. if (reorder_) {
  147. load_value(stream, data_);
  148. }
  149. else {
  150. data_ = dataset_;
  151. }
  152. load_tree(stream, root_node_);
  153. index_params_["algorithm"] = getType();
  154. index_params_["leaf_max_size"] = leaf_max_size_;
  155. index_params_["reorder"] = reorder_;
  156. }
  157. /**
  158. * Returns size of index.
  159. */
  160. size_t size() const CV_OVERRIDE
  161. {
  162. return size_;
  163. }
  164. /**
  165. * Returns the length of an index feature.
  166. */
  167. size_t veclen() const CV_OVERRIDE
  168. {
  169. return dim_;
  170. }
  171. /**
  172. * Computes the inde memory usage
  173. * Returns: memory used by the index
  174. */
  175. int usedMemory() const CV_OVERRIDE
  176. {
  177. return (int)(pool_.usedMemory+pool_.wastedMemory+dataset_.rows*sizeof(int)); // pool memory and vind array memory
  178. }
  179. /**
  180. * \brief Perform k-nearest neighbor search
  181. * \param[in] queries The query points for which to find the nearest neighbors
  182. * \param[out] indices The indices of the nearest neighbors found
  183. * \param[out] dists Distances to the nearest neighbors found
  184. * \param[in] knn Number of nearest neighbors to return
  185. * \param[in] params Search parameters
  186. */
  187. void knnSearch(const Matrix<ElementType>& queries, Matrix<int>& indices, Matrix<DistanceType>& dists, int knn, const SearchParams& params) CV_OVERRIDE
  188. {
  189. assert(queries.cols == veclen());
  190. assert(indices.rows >= queries.rows);
  191. assert(dists.rows >= queries.rows);
  192. assert(int(indices.cols) >= knn);
  193. assert(int(dists.cols) >= knn);
  194. KNNSimpleResultSet<DistanceType> resultSet(knn);
  195. for (size_t i = 0; i < queries.rows; i++) {
  196. resultSet.init(indices[i], dists[i]);
  197. findNeighbors(resultSet, queries[i], params);
  198. }
  199. }
  200. IndexParams getParameters() const CV_OVERRIDE
  201. {
  202. return index_params_;
  203. }
  204. /**
  205. * Find set of nearest neighbors to vec. Their indices are stored inside
  206. * the result object.
  207. *
  208. * Params:
  209. * result = the result object in which the indices of the nearest-neighbors are stored
  210. * vec = the vector for which to search the nearest neighbors
  211. * maxCheck = the maximum number of restarts (in a best-bin-first manner)
  212. */
  213. void findNeighbors(ResultSet<DistanceType>& result, const ElementType* vec, const SearchParams& searchParams) CV_OVERRIDE
  214. {
  215. float epsError = 1+get_param(searchParams,"eps",0.0f);
  216. std::vector<DistanceType> dists(dim_,0);
  217. DistanceType distsq = computeInitialDistances(vec, dists);
  218. searchLevel(result, vec, root_node_, distsq, dists, epsError);
  219. }
  220. private:
  221. /*--------------------- Internal Data Structures --------------------------*/
  222. struct Node
  223. {
  224. /**
  225. * Indices of points in leaf node
  226. */
  227. int left, right;
  228. /**
  229. * Dimension used for subdivision.
  230. */
  231. int divfeat;
  232. /**
  233. * The values used for subdivision.
  234. */
  235. DistanceType divlow, divhigh;
  236. /**
  237. * The child nodes.
  238. */
  239. Node* child1, * child2;
  240. };
  241. typedef Node* NodePtr;
  242. struct Interval
  243. {
  244. DistanceType low, high;
  245. };
  246. typedef std::vector<Interval> BoundingBox;
  247. typedef BranchStruct<NodePtr, DistanceType> BranchSt;
  248. typedef BranchSt* Branch;
  249. void save_tree(FILE* stream, NodePtr tree)
  250. {
  251. save_value(stream, *tree);
  252. if (tree->child1!=NULL) {
  253. save_tree(stream, tree->child1);
  254. }
  255. if (tree->child2!=NULL) {
  256. save_tree(stream, tree->child2);
  257. }
  258. }
  259. void load_tree(FILE* stream, NodePtr& tree)
  260. {
  261. tree = pool_.allocate<Node>();
  262. load_value(stream, *tree);
  263. if (tree->child1!=NULL) {
  264. load_tree(stream, tree->child1);
  265. }
  266. if (tree->child2!=NULL) {
  267. load_tree(stream, tree->child2);
  268. }
  269. }
  270. void computeBoundingBox(BoundingBox& bbox)
  271. {
  272. bbox.resize(dim_);
  273. for (size_t i=0; i<dim_; ++i) {
  274. bbox[i].low = (DistanceType)dataset_[0][i];
  275. bbox[i].high = (DistanceType)dataset_[0][i];
  276. }
  277. for (size_t k=1; k<dataset_.rows; ++k) {
  278. for (size_t i=0; i<dim_; ++i) {
  279. if (dataset_[k][i]<bbox[i].low) bbox[i].low = (DistanceType)dataset_[k][i];
  280. if (dataset_[k][i]>bbox[i].high) bbox[i].high = (DistanceType)dataset_[k][i];
  281. }
  282. }
  283. }
  284. /**
  285. * Create a tree node that subdivides the list of vecs from vind[first]
  286. * to vind[last]. The routine is called recursively on each sublist.
  287. * Place a pointer to this new tree node in the location pTree.
  288. *
  289. * Params: pTree = the new node to create
  290. * first = index of the first vector
  291. * last = index of the last vector
  292. */
  293. NodePtr divideTree(int left, int right, BoundingBox& bbox)
  294. {
  295. NodePtr node = pool_.allocate<Node>(); // allocate memory
  296. /* If too few exemplars remain, then make this a leaf node. */
  297. if ( (right-left) <= leaf_max_size_) {
  298. node->child1 = node->child2 = NULL; /* Mark as leaf node. */
  299. node->left = left;
  300. node->right = right;
  301. // compute bounding-box of leaf points
  302. for (size_t i=0; i<dim_; ++i) {
  303. bbox[i].low = (DistanceType)dataset_[vind_[left]][i];
  304. bbox[i].high = (DistanceType)dataset_[vind_[left]][i];
  305. }
  306. for (int k=left+1; k<right; ++k) {
  307. for (size_t i=0; i<dim_; ++i) {
  308. if (bbox[i].low>dataset_[vind_[k]][i]) bbox[i].low=(DistanceType)dataset_[vind_[k]][i];
  309. if (bbox[i].high<dataset_[vind_[k]][i]) bbox[i].high=(DistanceType)dataset_[vind_[k]][i];
  310. }
  311. }
  312. }
  313. else {
  314. int idx;
  315. int cutfeat;
  316. DistanceType cutval;
  317. middleSplit_(&vind_[0]+left, right-left, idx, cutfeat, cutval, bbox);
  318. node->divfeat = cutfeat;
  319. BoundingBox left_bbox(bbox);
  320. left_bbox[cutfeat].high = cutval;
  321. node->child1 = divideTree(left, left+idx, left_bbox);
  322. BoundingBox right_bbox(bbox);
  323. right_bbox[cutfeat].low = cutval;
  324. node->child2 = divideTree(left+idx, right, right_bbox);
  325. node->divlow = left_bbox[cutfeat].high;
  326. node->divhigh = right_bbox[cutfeat].low;
  327. for (size_t i=0; i<dim_; ++i) {
  328. bbox[i].low = std::min(left_bbox[i].low, right_bbox[i].low);
  329. bbox[i].high = std::max(left_bbox[i].high, right_bbox[i].high);
  330. }
  331. }
  332. return node;
  333. }
  334. void computeMinMax(int* ind, int count, int dim, ElementType& min_elem, ElementType& max_elem)
  335. {
  336. min_elem = dataset_[ind[0]][dim];
  337. max_elem = dataset_[ind[0]][dim];
  338. for (int i=1; i<count; ++i) {
  339. ElementType val = dataset_[ind[i]][dim];
  340. if (val<min_elem) min_elem = val;
  341. if (val>max_elem) max_elem = val;
  342. }
  343. }
  344. void middleSplit(int* ind, int count, int& index, int& cutfeat, DistanceType& cutval, const BoundingBox& bbox)
  345. {
  346. // find the largest span from the approximate bounding box
  347. ElementType max_span = bbox[0].high-bbox[0].low;
  348. cutfeat = 0;
  349. cutval = (bbox[0].high+bbox[0].low)/2;
  350. for (size_t i=1; i<dim_; ++i) {
  351. ElementType span = bbox[i].high-bbox[i].low;
  352. if (span>max_span) {
  353. max_span = span;
  354. cutfeat = i;
  355. cutval = (bbox[i].high+bbox[i].low)/2;
  356. }
  357. }
  358. // compute exact span on the found dimension
  359. ElementType min_elem, max_elem;
  360. computeMinMax(ind, count, cutfeat, min_elem, max_elem);
  361. cutval = (min_elem+max_elem)/2;
  362. max_span = max_elem - min_elem;
  363. // check if a dimension of a largest span exists
  364. size_t k = cutfeat;
  365. for (size_t i=0; i<dim_; ++i) {
  366. if (i==k) continue;
  367. ElementType span = bbox[i].high-bbox[i].low;
  368. if (span>max_span) {
  369. computeMinMax(ind, count, i, min_elem, max_elem);
  370. span = max_elem - min_elem;
  371. if (span>max_span) {
  372. max_span = span;
  373. cutfeat = i;
  374. cutval = (min_elem+max_elem)/2;
  375. }
  376. }
  377. }
  378. int lim1, lim2;
  379. planeSplit(ind, count, cutfeat, cutval, lim1, lim2);
  380. if (lim1>count/2) index = lim1;
  381. else if (lim2<count/2) index = lim2;
  382. else index = count/2;
  383. }
  384. void middleSplit_(int* ind, int count, int& index, int& cutfeat, DistanceType& cutval, const BoundingBox& bbox)
  385. {
  386. const float EPS=0.00001f;
  387. DistanceType max_span = bbox[0].high-bbox[0].low;
  388. for (size_t i=1; i<dim_; ++i) {
  389. DistanceType span = bbox[i].high-bbox[i].low;
  390. if (span>max_span) {
  391. max_span = span;
  392. }
  393. }
  394. DistanceType max_spread = -1;
  395. cutfeat = 0;
  396. for (size_t i=0; i<dim_; ++i) {
  397. DistanceType span = bbox[i].high-bbox[i].low;
  398. if (span>(DistanceType)((1-EPS)*max_span)) {
  399. ElementType min_elem, max_elem;
  400. computeMinMax(ind, count, cutfeat, min_elem, max_elem);
  401. DistanceType spread = (DistanceType)(max_elem-min_elem);
  402. if (spread>max_spread) {
  403. cutfeat = (int)i;
  404. max_spread = spread;
  405. }
  406. }
  407. }
  408. // split in the middle
  409. DistanceType split_val = (bbox[cutfeat].low+bbox[cutfeat].high)/2;
  410. ElementType min_elem, max_elem;
  411. computeMinMax(ind, count, cutfeat, min_elem, max_elem);
  412. if (split_val<min_elem) cutval = (DistanceType)min_elem;
  413. else if (split_val>max_elem) cutval = (DistanceType)max_elem;
  414. else cutval = split_val;
  415. int lim1, lim2;
  416. planeSplit(ind, count, cutfeat, cutval, lim1, lim2);
  417. if (lim1>count/2) index = lim1;
  418. else if (lim2<count/2) index = lim2;
  419. else index = count/2;
  420. }
  421. /**
  422. * Subdivide the list of points by a plane perpendicular on axe corresponding
  423. * to the 'cutfeat' dimension at 'cutval' position.
  424. *
  425. * On return:
  426. * dataset[ind[0..lim1-1]][cutfeat]<cutval
  427. * dataset[ind[lim1..lim2-1]][cutfeat]==cutval
  428. * dataset[ind[lim2..count]][cutfeat]>cutval
  429. */
  430. void planeSplit(int* ind, int count, int cutfeat, DistanceType cutval, int& lim1, int& lim2)
  431. {
  432. /* Move vector indices for left subtree to front of list. */
  433. int left = 0;
  434. int right = count-1;
  435. for (;; ) {
  436. while (left<=right && dataset_[ind[left]][cutfeat]<cutval) ++left;
  437. while (left<=right && dataset_[ind[right]][cutfeat]>=cutval) --right;
  438. if (left>right) break;
  439. std::swap(ind[left], ind[right]); ++left; --right;
  440. }
  441. /* If either list is empty, it means that all remaining features
  442. * are identical. Split in the middle to maintain a balanced tree.
  443. */
  444. lim1 = left;
  445. right = count-1;
  446. for (;; ) {
  447. while (left<=right && dataset_[ind[left]][cutfeat]<=cutval) ++left;
  448. while (left<=right && dataset_[ind[right]][cutfeat]>cutval) --right;
  449. if (left>right) break;
  450. std::swap(ind[left], ind[right]); ++left; --right;
  451. }
  452. lim2 = left;
  453. }
  454. DistanceType computeInitialDistances(const ElementType* vec, std::vector<DistanceType>& dists)
  455. {
  456. DistanceType distsq = 0.0;
  457. for (size_t i = 0; i < dim_; ++i) {
  458. if (vec[i] < root_bbox_[i].low) {
  459. dists[i] = distance_.accum_dist(vec[i], root_bbox_[i].low, (int)i);
  460. distsq += dists[i];
  461. }
  462. if (vec[i] > root_bbox_[i].high) {
  463. dists[i] = distance_.accum_dist(vec[i], root_bbox_[i].high, (int)i);
  464. distsq += dists[i];
  465. }
  466. }
  467. return distsq;
  468. }
  469. /**
  470. * Performs an exact search in the tree starting from a node.
  471. */
  472. void searchLevel(ResultSet<DistanceType>& result_set, const ElementType* vec, const NodePtr node, DistanceType mindistsq,
  473. std::vector<DistanceType>& dists, const float epsError)
  474. {
  475. /* If this is a leaf node, then do check and return. */
  476. if ((node->child1 == NULL)&&(node->child2 == NULL)) {
  477. DistanceType worst_dist = result_set.worstDist();
  478. for (int i=node->left; i<node->right; ++i) {
  479. int index = reorder_ ? i : vind_[i];
  480. DistanceType dist = distance_(vec, data_[index], dim_, worst_dist);
  481. if (dist<worst_dist) {
  482. result_set.addPoint(dist,vind_[i]);
  483. }
  484. }
  485. return;
  486. }
  487. /* Which child branch should be taken first? */
  488. int idx = node->divfeat;
  489. ElementType val = vec[idx];
  490. DistanceType diff1 = val - node->divlow;
  491. DistanceType diff2 = val - node->divhigh;
  492. NodePtr bestChild;
  493. NodePtr otherChild;
  494. DistanceType cut_dist;
  495. if ((diff1+diff2)<0) {
  496. bestChild = node->child1;
  497. otherChild = node->child2;
  498. cut_dist = distance_.accum_dist(val, node->divhigh, idx);
  499. }
  500. else {
  501. bestChild = node->child2;
  502. otherChild = node->child1;
  503. cut_dist = distance_.accum_dist( val, node->divlow, idx);
  504. }
  505. /* Call recursively to search next level down. */
  506. searchLevel(result_set, vec, bestChild, mindistsq, dists, epsError);
  507. DistanceType dst = dists[idx];
  508. mindistsq = mindistsq + cut_dist - dst;
  509. dists[idx] = cut_dist;
  510. if (mindistsq*epsError<=result_set.worstDist()) {
  511. searchLevel(result_set, vec, otherChild, mindistsq, dists, epsError);
  512. }
  513. dists[idx] = dst;
  514. }
  515. private:
  516. /**
  517. * The dataset used by this index
  518. */
  519. const Matrix<ElementType> dataset_;
  520. IndexParams index_params_;
  521. int leaf_max_size_;
  522. bool reorder_;
  523. /**
  524. * Array of indices to vectors in the dataset.
  525. */
  526. std::vector<int> vind_;
  527. Matrix<ElementType> data_;
  528. size_t size_;
  529. size_t dim_;
  530. /**
  531. * Root node of the k-d tree used to find neighbours.
  532. */
  533. NodePtr root_node_;
  534. BoundingBox root_bbox_;
  535. /**
  536. * Pooled memory allocator.
  537. *
  538. * Using a pooled memory allocator is more efficient
  539. * than allocating memory directly when there is a large
  540. * number of small memory allocations.
  541. */
  542. PooledAllocator pool_;
  543. Distance distance_;
  544. }; // class KDTree
  545. }
  546. #endif //OPENCV_FLANN_KDTREE_SINGLE_INDEX_H_