sniper_params_stats.py 6.3 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178
  1. # Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. import sys
  15. import json
  16. import logging
  17. import numpy as np
  18. from ppdet.utils.logger import setup_logger
  19. logger = setup_logger('sniper_params_stats')
  20. def get_default_params(architecture):
  21. """get_default_params"""
  22. if architecture == "FasterRCNN":
  23. anchor_range = np.array([64., 512.]) # for frcnn-fpn
  24. # anchor_range = np.array([16., 373.]) # for yolov3
  25. # anchor_range = np.array([32., 373.]) # for yolov3
  26. default_crop_size = 1536 # mod 32 for frcnn-fpn
  27. default_max_bbox_size = 352
  28. elif architecture == "YOLOv3":
  29. anchor_range = np.array([32., 373.]) # for yolov3
  30. default_crop_size = 800 # mod 32 for yolov3
  31. default_max_bbox_size = 352
  32. else:
  33. raise NotImplementedError
  34. return anchor_range, default_crop_size, default_max_bbox_size
  35. def get_box_ratios(anno_file):
  36. """
  37. get_size_ratios
  38. :param anno_file: coco anno flile
  39. :return: size_ratio: (box_long_size / pic_long_size)
  40. """
  41. coco_dict = json.load(open(anno_file))
  42. image_list = coco_dict['images']
  43. anno_list = coco_dict['annotations']
  44. image_id2hw = {}
  45. for im_dict in image_list:
  46. im_id = im_dict['id']
  47. h, w = im_dict['height'], im_dict['width']
  48. image_id2hw[im_id] = (h, w)
  49. box_ratios = []
  50. for a_dict in anno_list:
  51. im_id = a_dict['image_id']
  52. im_h, im_w = image_id2hw[im_id]
  53. bbox = a_dict['bbox']
  54. x1, y1, w, h = bbox
  55. pic_long = max(im_h, im_w)
  56. box_long = max(w, h)
  57. box_ratios.append(box_long / pic_long)
  58. return np.array(box_ratios)
  59. def get_target_size_and_valid_box_ratios(anchor_range, box_ratio_p2, box_ratio_p98):
  60. """get_scale_and_ratios"""
  61. anchor_better_low, anchor_better_high = anchor_range # (60., 512.)
  62. anchor_center = np.sqrt(anchor_better_high * anchor_better_low)
  63. anchor_log_range = np.log10(anchor_better_high) - np.log10(anchor_better_low)
  64. box_ratio_log_range = np.log10(box_ratio_p98) - np.log10(box_ratio_p2)
  65. logger.info("anchor_log_range:{}, box_ratio_log_range:{}".format(anchor_log_range, box_ratio_log_range))
  66. box_cut_num = int(np.ceil(box_ratio_log_range / anchor_log_range))
  67. box_ratio_log_window = box_ratio_log_range / box_cut_num
  68. logger.info("box_cut_num:{}, box_ratio_log_window:{}".format(box_cut_num, box_ratio_log_window))
  69. image_target_sizes = []
  70. valid_ratios = []
  71. for i in range(box_cut_num):
  72. # # method1: align center
  73. # box_ratio_log_center = np.log10(p2) + 0.5 * box_ratio_log_window + i * box_ratio_log_window
  74. # box_ratio_center = np.power(10, box_ratio_log_center)
  75. # scale = anchor_center / box_ratio_center
  76. # method2: align left low
  77. box_ratio_low = np.power(10, np.log10(box_ratio_p2) + i * box_ratio_log_window)
  78. image_target_size = anchor_better_low / box_ratio_low
  79. image_target_sizes.append(int(image_target_size))
  80. valid_ratio = anchor_range / image_target_size
  81. valid_ratios.append(valid_ratio.tolist())
  82. logger.info("Box cut {}".format(i))
  83. logger.info("box_ratio_low: {}".format(box_ratio_low))
  84. logger.info("image_target_size: {}".format(image_target_size))
  85. logger.info("valid_ratio: {}".format(valid_ratio))
  86. return image_target_sizes, valid_ratios
  87. def get_valid_ranges(valid_ratios):
  88. """
  89. get_valid_box_ratios_range
  90. :param valid_ratios:
  91. :return:
  92. """
  93. valid_ranges = []
  94. if len(valid_ratios) == 1:
  95. valid_ranges.append([-1, -1])
  96. else:
  97. for i, vratio in enumerate(valid_ratios):
  98. if i == 0:
  99. valid_ranges.append([-1, vratio[1]])
  100. elif i == len(valid_ratios) - 1:
  101. valid_ranges.append([vratio[0], -1])
  102. else:
  103. valid_ranges.append(vratio)
  104. return valid_ranges
  105. def get_percentile(a_array, low_percent, high_percent):
  106. """
  107. get_percentile
  108. :param low_percent:
  109. :param high_percent:
  110. :return:
  111. """
  112. array_p0 = min(a_array)
  113. array_p100 = max(a_array)
  114. array_plow = np.percentile(a_array, low_percent)
  115. array_phigh = np.percentile(a_array, high_percent)
  116. logger.info(
  117. "array_percentile(0): {},array_percentile low({}): {}, "
  118. "array_percentile high({}): {}, array_percentile 100: {}".format(
  119. array_p0, low_percent, array_plow, high_percent, array_phigh, array_p100))
  120. return array_plow, array_phigh
  121. def sniper_anno_stats(architecture, anno_file):
  122. """
  123. sniper_anno_stats
  124. :param anno_file:
  125. :return:
  126. """
  127. anchor_range, default_crop_size, default_max_bbox_size = get_default_params(architecture)
  128. box_ratios = get_box_ratios(anno_file)
  129. box_ratio_p8, box_ratio_p92 = get_percentile(box_ratios, 8, 92)
  130. image_target_sizes, valid_box_ratios = get_target_size_and_valid_box_ratios(anchor_range, box_ratio_p8, box_ratio_p92)
  131. valid_ranges = get_valid_ranges(valid_box_ratios)
  132. crop_size = min(default_crop_size, min([item for item in image_target_sizes]))
  133. crop_size = int(np.ceil(crop_size / 32.) * 32.)
  134. crop_stride = max(min(default_max_bbox_size, crop_size), crop_size - default_max_bbox_size)
  135. logger.info("Result".center(100, '-'))
  136. logger.info("image_target_sizes: {}".format(image_target_sizes))
  137. logger.info("valid_box_ratio_ranges: {}".format(valid_ranges))
  138. logger.info("chip_target_size: {}, chip_target_stride: {}".format(crop_size, crop_stride))
  139. return {
  140. "image_target_sizes": image_target_sizes,
  141. "valid_box_ratio_ranges": valid_ranges,
  142. "chip_target_size": crop_size,
  143. "chip_target_stride": crop_stride
  144. }
  145. if __name__=="__main__":
  146. architecture, anno_file = sys.argv[1], sys.argv[2]
  147. sniper_anno_stats(architecture, anno_file)