munkres.py 13 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428
  1. # Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. """
  15. This code is borrowed from https://github.com/xingyizhou/CenterTrack/blob/master/src/tools/eval_kitti_track/munkres.py
  16. """
  17. import sys
  18. __all__ = ['Munkres', 'make_cost_matrix']
  19. class Munkres:
  20. """
  21. Calculate the Munkres solution to the classical assignment problem.
  22. See the module documentation for usage.
  23. """
  24. def __init__(self):
  25. """Create a new instance"""
  26. self.C = None
  27. self.row_covered = []
  28. self.col_covered = []
  29. self.n = 0
  30. self.Z0_r = 0
  31. self.Z0_c = 0
  32. self.marked = None
  33. self.path = None
  34. def make_cost_matrix(profit_matrix, inversion_function):
  35. """
  36. **DEPRECATED**
  37. Please use the module function ``make_cost_matrix()``.
  38. """
  39. import munkres
  40. return munkres.make_cost_matrix(profit_matrix, inversion_function)
  41. make_cost_matrix = staticmethod(make_cost_matrix)
  42. def pad_matrix(self, matrix, pad_value=0):
  43. """
  44. Pad a possibly non-square matrix to make it square.
  45. :Parameters:
  46. matrix : list of lists
  47. matrix to pad
  48. pad_value : int
  49. value to use to pad the matrix
  50. :rtype: list of lists
  51. :return: a new, possibly padded, matrix
  52. """
  53. max_columns = 0
  54. total_rows = len(matrix)
  55. for row in matrix:
  56. max_columns = max(max_columns, len(row))
  57. total_rows = max(max_columns, total_rows)
  58. new_matrix = []
  59. for row in matrix:
  60. row_len = len(row)
  61. new_row = row[:]
  62. if total_rows > row_len:
  63. # Row too short. Pad it.
  64. new_row += [0] * (total_rows - row_len)
  65. new_matrix += [new_row]
  66. while len(new_matrix) < total_rows:
  67. new_matrix += [[0] * total_rows]
  68. return new_matrix
  69. def compute(self, cost_matrix):
  70. """
  71. Compute the indexes for the lowest-cost pairings between rows and
  72. columns in the database. Returns a list of (row, column) tuples
  73. that can be used to traverse the matrix.
  74. :Parameters:
  75. cost_matrix : list of lists
  76. The cost matrix. If this cost matrix is not square, it
  77. will be padded with zeros, via a call to ``pad_matrix()``.
  78. (This method does *not* modify the caller's matrix. It
  79. operates on a copy of the matrix.)
  80. **WARNING**: This code handles square and rectangular
  81. matrices. It does *not* handle irregular matrices.
  82. :rtype: list
  83. :return: A list of ``(row, column)`` tuples that describe the lowest
  84. cost path through the matrix
  85. """
  86. self.C = self.pad_matrix(cost_matrix)
  87. self.n = len(self.C)
  88. self.original_length = len(cost_matrix)
  89. self.original_width = len(cost_matrix[0])
  90. self.row_covered = [False for i in range(self.n)]
  91. self.col_covered = [False for i in range(self.n)]
  92. self.Z0_r = 0
  93. self.Z0_c = 0
  94. self.path = self.__make_matrix(self.n * 2, 0)
  95. self.marked = self.__make_matrix(self.n, 0)
  96. done = False
  97. step = 1
  98. steps = {
  99. 1: self.__step1,
  100. 2: self.__step2,
  101. 3: self.__step3,
  102. 4: self.__step4,
  103. 5: self.__step5,
  104. 6: self.__step6
  105. }
  106. while not done:
  107. try:
  108. func = steps[step]
  109. step = func()
  110. except KeyError:
  111. done = True
  112. # Look for the starred columns
  113. results = []
  114. for i in range(self.original_length):
  115. for j in range(self.original_width):
  116. if self.marked[i][j] == 1:
  117. results += [(i, j)]
  118. return results
  119. def __copy_matrix(self, matrix):
  120. """Return an exact copy of the supplied matrix"""
  121. return copy.deepcopy(matrix)
  122. def __make_matrix(self, n, val):
  123. """Create an *n*x*n* matrix, populating it with the specific value."""
  124. matrix = []
  125. for i in range(n):
  126. matrix += [[val for j in range(n)]]
  127. return matrix
  128. def __step1(self):
  129. """
  130. For each row of the matrix, find the smallest element and
  131. subtract it from every element in its row. Go to Step 2.
  132. """
  133. C = self.C
  134. n = self.n
  135. for i in range(n):
  136. minval = min(self.C[i])
  137. # Find the minimum value for this row and subtract that minimum
  138. # from every element in the row.
  139. for j in range(n):
  140. self.C[i][j] -= minval
  141. return 2
  142. def __step2(self):
  143. """
  144. Find a zero (Z) in the resulting matrix. If there is no starred
  145. zero in its row or column, star Z. Repeat for each element in the
  146. matrix. Go to Step 3.
  147. """
  148. n = self.n
  149. for i in range(n):
  150. for j in range(n):
  151. if (self.C[i][j] == 0) and \
  152. (not self.col_covered[j]) and \
  153. (not self.row_covered[i]):
  154. self.marked[i][j] = 1
  155. self.col_covered[j] = True
  156. self.row_covered[i] = True
  157. self.__clear_covers()
  158. return 3
  159. def __step3(self):
  160. """
  161. Cover each column containing a starred zero. If K columns are
  162. covered, the starred zeros describe a complete set of unique
  163. assignments. In this case, Go to DONE, otherwise, Go to Step 4.
  164. """
  165. n = self.n
  166. count = 0
  167. for i in range(n):
  168. for j in range(n):
  169. if self.marked[i][j] == 1:
  170. self.col_covered[j] = True
  171. count += 1
  172. if count >= n:
  173. step = 7 # done
  174. else:
  175. step = 4
  176. return step
  177. def __step4(self):
  178. """
  179. Find a noncovered zero and prime it. If there is no starred zero
  180. in the row containing this primed zero, Go to Step 5. Otherwise,
  181. cover this row and uncover the column containing the starred
  182. zero. Continue in this manner until there are no uncovered zeros
  183. left. Save the smallest uncovered value and Go to Step 6.
  184. """
  185. step = 0
  186. done = False
  187. row = -1
  188. col = -1
  189. star_col = -1
  190. while not done:
  191. (row, col) = self.__find_a_zero()
  192. if row < 0:
  193. done = True
  194. step = 6
  195. else:
  196. self.marked[row][col] = 2
  197. star_col = self.__find_star_in_row(row)
  198. if star_col >= 0:
  199. col = star_col
  200. self.row_covered[row] = True
  201. self.col_covered[col] = False
  202. else:
  203. done = True
  204. self.Z0_r = row
  205. self.Z0_c = col
  206. step = 5
  207. return step
  208. def __step5(self):
  209. """
  210. Construct a series of alternating primed and starred zeros as
  211. follows. Let Z0 represent the uncovered primed zero found in Step 4.
  212. Let Z1 denote the starred zero in the column of Z0 (if any).
  213. Let Z2 denote the primed zero in the row of Z1 (there will always
  214. be one). Continue until the series terminates at a primed zero
  215. that has no starred zero in its column. Unstar each starred zero
  216. of the series, star each primed zero of the series, erase all
  217. primes and uncover every line in the matrix. Return to Step 3
  218. """
  219. count = 0
  220. path = self.path
  221. path[count][0] = self.Z0_r
  222. path[count][1] = self.Z0_c
  223. done = False
  224. while not done:
  225. row = self.__find_star_in_col(path[count][1])
  226. if row >= 0:
  227. count += 1
  228. path[count][0] = row
  229. path[count][1] = path[count - 1][1]
  230. else:
  231. done = True
  232. if not done:
  233. col = self.__find_prime_in_row(path[count][0])
  234. count += 1
  235. path[count][0] = path[count - 1][0]
  236. path[count][1] = col
  237. self.__convert_path(path, count)
  238. self.__clear_covers()
  239. self.__erase_primes()
  240. return 3
  241. def __step6(self):
  242. """
  243. Add the value found in Step 4 to every element of each covered
  244. row, and subtract it from every element of each uncovered column.
  245. Return to Step 4 without altering any stars, primes, or covered
  246. lines.
  247. """
  248. minval = self.__find_smallest()
  249. for i in range(self.n):
  250. for j in range(self.n):
  251. if self.row_covered[i]:
  252. self.C[i][j] += minval
  253. if not self.col_covered[j]:
  254. self.C[i][j] -= minval
  255. return 4
  256. def __find_smallest(self):
  257. """Find the smallest uncovered value in the matrix."""
  258. minval = 2e9 # sys.maxint
  259. for i in range(self.n):
  260. for j in range(self.n):
  261. if (not self.row_covered[i]) and (not self.col_covered[j]):
  262. if minval > self.C[i][j]:
  263. minval = self.C[i][j]
  264. return minval
  265. def __find_a_zero(self):
  266. """Find the first uncovered element with value 0"""
  267. row = -1
  268. col = -1
  269. i = 0
  270. n = self.n
  271. done = False
  272. while not done:
  273. j = 0
  274. while True:
  275. if (self.C[i][j] == 0) and \
  276. (not self.row_covered[i]) and \
  277. (not self.col_covered[j]):
  278. row = i
  279. col = j
  280. done = True
  281. j += 1
  282. if j >= n:
  283. break
  284. i += 1
  285. if i >= n:
  286. done = True
  287. return (row, col)
  288. def __find_star_in_row(self, row):
  289. """
  290. Find the first starred element in the specified row. Returns
  291. the column index, or -1 if no starred element was found.
  292. """
  293. col = -1
  294. for j in range(self.n):
  295. if self.marked[row][j] == 1:
  296. col = j
  297. break
  298. return col
  299. def __find_star_in_col(self, col):
  300. """
  301. Find the first starred element in the specified row. Returns
  302. the row index, or -1 if no starred element was found.
  303. """
  304. row = -1
  305. for i in range(self.n):
  306. if self.marked[i][col] == 1:
  307. row = i
  308. break
  309. return row
  310. def __find_prime_in_row(self, row):
  311. """
  312. Find the first prime element in the specified row. Returns
  313. the column index, or -1 if no starred element was found.
  314. """
  315. col = -1
  316. for j in range(self.n):
  317. if self.marked[row][j] == 2:
  318. col = j
  319. break
  320. return col
  321. def __convert_path(self, path, count):
  322. for i in range(count + 1):
  323. if self.marked[path[i][0]][path[i][1]] == 1:
  324. self.marked[path[i][0]][path[i][1]] = 0
  325. else:
  326. self.marked[path[i][0]][path[i][1]] = 1
  327. def __clear_covers(self):
  328. """Clear all covered matrix cells"""
  329. for i in range(self.n):
  330. self.row_covered[i] = False
  331. self.col_covered[i] = False
  332. def __erase_primes(self):
  333. """Erase all prime markings"""
  334. for i in range(self.n):
  335. for j in range(self.n):
  336. if self.marked[i][j] == 2:
  337. self.marked[i][j] = 0
  338. def make_cost_matrix(profit_matrix, inversion_function):
  339. """
  340. Create a cost matrix from a profit matrix by calling
  341. 'inversion_function' to invert each value. The inversion
  342. function must take one numeric argument (of any type) and return
  343. another numeric argument which is presumed to be the cost inverse
  344. of the original profit.
  345. This is a static method. Call it like this:
  346. .. python::
  347. cost_matrix = Munkres.make_cost_matrix(matrix, inversion_func)
  348. For example:
  349. .. python::
  350. cost_matrix = Munkres.make_cost_matrix(matrix, lambda x : sys.maxint - x)
  351. :Parameters:
  352. profit_matrix : list of lists
  353. The matrix to convert from a profit to a cost matrix
  354. inversion_function : function
  355. The function to use to invert each entry in the profit matrix
  356. :rtype: list of lists
  357. :return: The converted matrix
  358. """
  359. cost_matrix = []
  360. for row in profit_matrix:
  361. cost_matrix.append([inversion_function(value) for value in row])
  362. return cost_matrix