# anchor_cluster.py
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
  14. from __future__ import absolute_import
  15. from __future__ import division
  16. from __future__ import print_function
  17. import os
  18. import sys
  19. # add python path of PadleDetection to sys.path
  20. parent_path = os.path.abspath(os.path.join(__file__, *(['..'] * 2)))
  21. if parent_path not in sys.path:
  22. sys.path.append(parent_path)
  23. import logging
  24. FORMAT = '%(asctime)s-%(levelname)s: %(message)s'
  25. logging.basicConfig(level=logging.INFO, format=FORMAT)
  26. logger = logging.getLogger(__name__)
  27. from scipy.cluster.vq import kmeans
  28. import numpy as np
  29. from tqdm import tqdm
  30. try:
  31. from ppdet.utils.cli import ArgsParser
  32. from ppdet.utils.check import check_gpu, check_version, check_config
  33. from ppdet.core.workspace import load_config, merge_config, create
  34. except ImportError as e:
  35. if sys.argv[0].find('static') >= 0:
  36. logger.error("Importing ppdet failed when running static model "
  37. "with error: {}\n"
  38. "please try:\n"
  39. "\t1. run static model under PaddleDetection/static "
  40. "directory\n"
  41. "\t2. run 'pip uninstall ppdet' to uninstall ppdet "
  42. "dynamic version firstly.".format(e))
  43. sys.exit(-1)
  44. else:
  45. raise e
  46. class BaseAnchorCluster(object):
  47. def __init__(self, n, cache_path, cache, verbose=True):
  48. """
  49. Base Anchor Cluster
  50. Args:
  51. n (int): number of clusters
  52. cache_path (str): cache directory path
  53. cache (bool): whether using cache
  54. verbose (bool): whether print results
  55. """
  56. super(BaseAnchorCluster, self).__init__()
  57. self.n = n
  58. self.cache_path = cache_path
  59. self.cache = cache
  60. self.verbose = verbose
  61. def print_result(self, centers):
  62. raise NotImplementedError('%s.print_result is not available' %
  63. self.__class__.__name__)
  64. def get_whs(self):
  65. whs_cache_path = os.path.join(self.cache_path, 'whs.npy')
  66. shapes_cache_path = os.path.join(self.cache_path, 'shapes.npy')
  67. if self.cache and os.path.exists(whs_cache_path) and os.path.exists(
  68. shapes_cache_path):
  69. self.whs = np.load(whs_cache_path)
  70. self.shapes = np.load(shapes_cache_path)
  71. return self.whs, self.shapes
  72. whs = np.zeros((0, 2))
  73. shapes = np.zeros((0, 2))
  74. roidbs = self.dataset.get_roidb()
  75. for rec in tqdm(roidbs):
  76. h, w = rec['h'], rec['w']
  77. bbox = rec['gt_bbox']
  78. wh = bbox[:, 2:4] - bbox[:, 0:2] + 1
  79. wh = wh / np.array([[w, h]])
  80. shape = np.ones_like(wh) * np.array([[w, h]])
  81. whs = np.vstack((whs, wh))
  82. shapes = np.vstack((shapes, shape))
  83. if self.cache:
  84. os.makedirs(self.cache_path, exist_ok=True)
  85. np.save(whs_cache_path, whs)
  86. np.save(shapes_cache_path, shapes)
  87. self.whs = whs
  88. self.shapes = shapes
  89. return self.whs, self.shapes
  90. def calc_anchors(self):
  91. raise NotImplementedError('%s.calc_anchors is not available' %
  92. self.__class__.__name__)
  93. def __call__(self):
  94. self.get_whs()
  95. centers = self.calc_anchors()
  96. if self.verbose:
  97. self.print_result(centers)
  98. return centers
  99. class YOLOv2AnchorCluster(BaseAnchorCluster):
  100. def __init__(self,
  101. n,
  102. dataset,
  103. size,
  104. cache_path,
  105. cache,
  106. iters=1000,
  107. verbose=True):
  108. super(YOLOv2AnchorCluster, self).__init__(
  109. n, cache_path, cache, verbose=verbose)
  110. """
  111. YOLOv2 Anchor Cluster
  112. The code is based on https://github.com/AlexeyAB/darknet/blob/master/scripts/gen_anchors.py
  113. Args:
  114. n (int): number of clusters
  115. dataset (DataSet): DataSet instance, VOC or COCO
  116. size (list): [w, h]
  117. cache_path (str): cache directory path
  118. cache (bool): whether using cache
  119. iters (int): kmeans algorithm iters
  120. verbose (bool): whether print results
  121. """
  122. self.dataset = dataset
  123. self.size = size
  124. self.iters = iters
  125. def print_result(self, centers):
  126. logger.info('%d anchor cluster result: [w, h]' % self.n)
  127. for w, h in centers:
  128. logger.info('[%d, %d]' % (round(w), round(h)))
  129. def metric(self, whs, centers):
  130. wh1 = whs[:, None]
  131. wh2 = centers[None]
  132. inter = np.minimum(wh1, wh2).prod(2)
  133. return inter / (wh1.prod(2) + wh2.prod(2) - inter)
  134. def kmeans_expectation(self, whs, centers, assignments):
  135. dist = self.metric(whs, centers)
  136. new_assignments = dist.argmax(1)
  137. converged = (new_assignments == assignments).all()
  138. return converged, new_assignments
  139. def kmeans_maximizations(self, whs, centers, assignments):
  140. new_centers = np.zeros_like(centers)
  141. for i in range(centers.shape[0]):
  142. mask = (assignments == i)
  143. if mask.sum():
  144. new_centers[i, :] = whs[mask].mean(0)
  145. return new_centers
  146. def calc_anchors(self):
  147. self.whs = self.whs * np.array([self.size])
  148. # random select k centers
  149. whs, n, iters = self.whs, self.n, self.iters
  150. logger.info('Running kmeans for %d anchors on %d points...' %
  151. (n, len(whs)))
  152. idx = np.random.choice(whs.shape[0], size=n, replace=False)
  153. centers = whs[idx]
  154. assignments = np.zeros(whs.shape[0:1]) * -1
  155. # kmeans
  156. if n == 1:
  157. return self.kmeans_maximizations(whs, centers, assignments)
  158. pbar = tqdm(range(iters), desc='Cluster anchors with k-means algorithm')
  159. for _ in pbar:
  160. # E step
  161. converged, assignments = self.kmeans_expectation(whs, centers,
  162. assignments)
  163. if converged:
  164. break
  165. # M step
  166. centers = self.kmeans_maximizations(whs, centers, assignments)
  167. ious = self.metric(whs, centers)
  168. pbar.desc = 'avg_iou: %.4f' % (ious.max(1).mean())
  169. centers = sorted(centers, key=lambda x: x[0] * x[1])
  170. return centers
  171. def main():
  172. parser = ArgsParser()
  173. parser.add_argument(
  174. '--n', '-n', default=9, type=int, help='num of clusters')
  175. parser.add_argument(
  176. '--iters',
  177. '-i',
  178. default=1000,
  179. type=int,
  180. help='num of iterations for kmeans')
  181. parser.add_argument(
  182. '--verbose', '-v', default=True, type=bool, help='whether print result')
  183. parser.add_argument(
  184. '--size',
  185. '-s',
  186. default=None,
  187. type=str,
  188. help='image size: w,h, using comma as delimiter')
  189. parser.add_argument(
  190. '--method',
  191. '-m',
  192. default='v2',
  193. type=str,
  194. help='cluster method, v2 is only supported now')
  195. parser.add_argument(
  196. '--cache_path', default='cache', type=str, help='cache path')
  197. parser.add_argument(
  198. '--cache', action='store_true', help='whether use cache')
  199. FLAGS = parser.parse_args()
  200. cfg = load_config(FLAGS.config)
  201. merge_config(FLAGS.opt)
  202. check_config(cfg)
  203. # check if set use_gpu=True in paddlepaddle cpu version
  204. check_gpu(cfg.use_gpu)
  205. # check if paddlepaddle version is satisfied
  206. check_version()
  207. # get dataset
  208. dataset = cfg['TrainReader']['dataset']
  209. if FLAGS.size:
  210. if ',' in FLAGS.size:
  211. size = list(map(int, FLAGS.size.split(',')))
  212. assert len(size) == 2, "the format of size is incorrect"
  213. else:
  214. size = int(FLAGS.size)
  215. size = [size, size]
  216. elif 'image_shape' in cfg['TestReader']['inputs_def']:
  217. size = cfg['TestReader']['inputs_def']['image_shape'][1:]
  218. else:
  219. raise ValueError('size is not specified')
  220. if FLAGS.method == 'v2':
  221. cluster = YOLOv2AnchorCluster(FLAGS.n, dataset, size, FLAGS.cache_path,
  222. FLAGS.cache, FLAGS.iters, FLAGS.verbose)
  223. else:
  224. raise ValueError('cluster method: %s is not supported' % FLAGS.method)
  225. anchors = cluster()
  226. if __name__ == "__main__":
  227. main()