simplex_downhill.h 5.8 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186
  1. /***********************************************************************
  2. * Software License Agreement (BSD License)
  3. *
  4. * Copyright 2008-2009 Marius Muja (mariusm@cs.ubc.ca). All rights reserved.
  5. * Copyright 2008-2009 David G. Lowe (lowe@cs.ubc.ca). All rights reserved.
  6. *
  7. * THE BSD LICENSE
  8. *
  9. * Redistribution and use in source and binary forms, with or without
  10. * modification, are permitted provided that the following conditions
  11. * are met:
  12. *
  13. * 1. Redistributions of source code must retain the above copyright
  14. * notice, this list of conditions and the following disclaimer.
  15. * 2. Redistributions in binary form must reproduce the above copyright
  16. * notice, this list of conditions and the following disclaimer in the
  17. * documentation and/or other materials provided with the distribution.
  18. *
  19. * THIS SOFTWARE IS PROVIDED BY THE AUTHOR ``AS IS'' AND ANY EXPRESS OR
  20. * IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES
  21. * OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED.
  22. * IN NO EVENT SHALL THE AUTHOR BE LIABLE FOR ANY DIRECT, INDIRECT,
  23. * INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT
  24. * NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE,
  25. * DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY
  26. * THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
  27. * (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF
  28. * THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
  29. *************************************************************************/
  #ifndef OPENCV_FLANN_SIMPLEX_DOWNHILL_H_
#define OPENCV_FLANN_SIMPLEX_DOWNHILL_H_

#include <algorithm>
#include <cassert>
#include <cstddef>
#include <utility>
#include <vector>
  32. namespace cvflann
  33. {
  34. /**
  35. Adds val to array vals (and point to array points) and keeping the arrays sorted by vals.
  36. */
  37. template <typename T>
  38. void addValue(int pos, float val, float* vals, T* point, T* points, int n)
  39. {
  40. vals[pos] = val;
  41. for (int i=0; i<n; ++i) {
  42. points[pos*n+i] = point[i];
  43. }
  44. // bubble down
  45. int j=pos;
  46. while (j>0 && vals[j]<vals[j-1]) {
  47. swap(vals[j],vals[j-1]);
  48. for (int i=0; i<n; ++i) {
  49. swap(points[j*n+i],points[(j-1)*n+i]);
  50. }
  51. --j;
  52. }
  53. }
  54. /**
  55. Simplex downhill optimization function.
  56. Preconditions: points is a 2D mattrix of size (n+1) x n
  57. func is the cost function taking n an array of n params and returning float
  58. vals is the cost function in the n+1 simplex points, if NULL it will be computed
  59. Postcondition: returns optimum value and points[0..n] are the optimum parameters
  60. */
  61. template <typename T, typename F>
  62. float optimizeSimplexDownhill(T* points, int n, F func, float* vals = NULL )
  63. {
  64. const int MAX_ITERATIONS = 10;
  65. assert(n>0);
  66. T* p_o = new T[n];
  67. T* p_r = new T[n];
  68. T* p_e = new T[n];
  69. int alpha = 1;
  70. int iterations = 0;
  71. bool ownVals = false;
  72. if (vals == NULL) {
  73. ownVals = true;
  74. vals = new float[n+1];
  75. for (int i=0; i<n+1; ++i) {
  76. float val = func(points+i*n);
  77. addValue(i, val, vals, points+i*n, points, n);
  78. }
  79. }
  80. int nn = n*n;
  81. while (true) {
  82. if (iterations++ > MAX_ITERATIONS) break;
  83. // compute average of simplex points (except the highest point)
  84. for (int j=0; j<n; ++j) {
  85. p_o[j] = 0;
  86. for (int i=0; i<n; ++i) {
  87. p_o[i] += points[j*n+i];
  88. }
  89. }
  90. for (int i=0; i<n; ++i) {
  91. p_o[i] /= n;
  92. }
  93. bool converged = true;
  94. for (int i=0; i<n; ++i) {
  95. if (p_o[i] != points[nn+i]) {
  96. converged = false;
  97. }
  98. }
  99. if (converged) break;
  100. // trying a reflection
  101. for (int i=0; i<n; ++i) {
  102. p_r[i] = p_o[i] + alpha*(p_o[i]-points[nn+i]);
  103. }
  104. float val_r = func(p_r);
  105. if ((val_r>=vals[0])&&(val_r<vals[n])) {
  106. // reflection between second highest and lowest
  107. // add it to the simplex
  108. Logger::info("Choosing reflection\n");
  109. addValue(n, val_r,vals, p_r, points, n);
  110. continue;
  111. }
  112. if (val_r<vals[0]) {
  113. // value is smaller than smalest in simplex
  114. // expand some more to see if it drops further
  115. for (int i=0; i<n; ++i) {
  116. p_e[i] = 2*p_r[i]-p_o[i];
  117. }
  118. float val_e = func(p_e);
  119. if (val_e<val_r) {
  120. Logger::info("Choosing reflection and expansion\n");
  121. addValue(n, val_e,vals,p_e,points,n);
  122. }
  123. else {
  124. Logger::info("Choosing reflection\n");
  125. addValue(n, val_r,vals,p_r,points,n);
  126. }
  127. continue;
  128. }
  129. if (val_r>=vals[n]) {
  130. for (int i=0; i<n; ++i) {
  131. p_e[i] = (p_o[i]+points[nn+i])/2;
  132. }
  133. float val_e = func(p_e);
  134. if (val_e<vals[n]) {
  135. Logger::info("Choosing contraction\n");
  136. addValue(n,val_e,vals,p_e,points,n);
  137. continue;
  138. }
  139. }
  140. {
  141. Logger::info("Full contraction\n");
  142. for (int j=1; j<=n; ++j) {
  143. for (int i=0; i<n; ++i) {
  144. points[j*n+i] = (points[j*n+i]+points[i])/2;
  145. }
  146. float val = func(points+j*n);
  147. addValue(j,val,vals,points+j*n,points,n);
  148. }
  149. }
  150. }
  151. float bestVal = vals[0];
  152. delete[] p_r;
  153. delete[] p_o;
  154. delete[] p_e;
  155. if (ownVals) delete[] vals;
  156. return bestVal;
  157. }
  158. }
  159. #endif //OPENCV_FLANN_SIMPLEX_DOWNHILL_H_