sniper_coco.py 6.7 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194
  1. # Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. import os
  15. import cv2
  16. import json
  17. import copy
  18. import numpy as np
  19. try:
  20. from collections.abc import Sequence
  21. except Exception:
  22. from collections import Sequence
  23. from ppdet.core.workspace import register, serializable
  24. from ppdet.data.crop_utils.annotation_cropper import AnnoCropper
  25. from .coco import COCODataSet
  26. from .dataset import _make_dataset, _is_valid_file
  27. from ppdet.utils.logger import setup_logger
  28. logger = setup_logger('sniper_coco_dataset')
  29. @register
  30. @serializable
  31. class SniperCOCODataSet(COCODataSet):
  32. """SniperCOCODataSet"""
  33. def __init__(self,
  34. dataset_dir=None,
  35. image_dir=None,
  36. anno_path=None,
  37. proposals_file=None,
  38. data_fields=['image'],
  39. sample_num=-1,
  40. load_crowd=False,
  41. allow_empty=True,
  42. empty_ratio=1.,
  43. is_trainset=True,
  44. image_target_sizes=[2000, 1000],
  45. valid_box_ratio_ranges=[[-1, 0.1],[0.08, -1]],
  46. chip_target_size=500,
  47. chip_target_stride=200,
  48. use_neg_chip=False,
  49. max_neg_num_per_im=8,
  50. max_per_img=-1,
  51. nms_thresh=0.5):
  52. super(SniperCOCODataSet, self).__init__(
  53. dataset_dir=dataset_dir,
  54. image_dir=image_dir,
  55. anno_path=anno_path,
  56. data_fields=data_fields,
  57. sample_num=sample_num,
  58. load_crowd=load_crowd,
  59. allow_empty=allow_empty,
  60. empty_ratio=empty_ratio
  61. )
  62. self.proposals_file = proposals_file
  63. self.proposals = None
  64. self.anno_cropper = None
  65. self.is_trainset = is_trainset
  66. self.image_target_sizes = image_target_sizes
  67. self.valid_box_ratio_ranges = valid_box_ratio_ranges
  68. self.chip_target_size = chip_target_size
  69. self.chip_target_stride = chip_target_stride
  70. self.use_neg_chip = use_neg_chip
  71. self.max_neg_num_per_im = max_neg_num_per_im
  72. self.max_per_img = max_per_img
  73. self.nms_thresh = nms_thresh
  74. def parse_dataset(self):
  75. if not hasattr(self, "roidbs"):
  76. super(SniperCOCODataSet, self).parse_dataset()
  77. if self.is_trainset:
  78. self._parse_proposals()
  79. self._merge_anno_proposals()
  80. self.ori_roidbs = copy.deepcopy(self.roidbs)
  81. self.init_anno_cropper()
  82. self.roidbs = self.generate_chips_roidbs(self.roidbs, self.is_trainset)
  83. def set_proposals_file(self, file_path):
  84. self.proposals_file = file_path
  85. def init_anno_cropper(self):
  86. logger.info("Init AnnoCropper...")
  87. self.anno_cropper = AnnoCropper(
  88. image_target_sizes=self.image_target_sizes,
  89. valid_box_ratio_ranges=self.valid_box_ratio_ranges,
  90. chip_target_size=self.chip_target_size,
  91. chip_target_stride=self.chip_target_stride,
  92. use_neg_chip=self.use_neg_chip,
  93. max_neg_num_per_im=self.max_neg_num_per_im,
  94. max_per_img=self.max_per_img,
  95. nms_thresh=self.nms_thresh
  96. )
  97. def generate_chips_roidbs(self, roidbs, is_trainset):
  98. if is_trainset:
  99. roidbs = self.anno_cropper.crop_anno_records(roidbs)
  100. else:
  101. roidbs = self.anno_cropper.crop_infer_anno_records(roidbs)
  102. return roidbs
  103. def _parse_proposals(self):
  104. if self.proposals_file:
  105. self.proposals = {}
  106. logger.info("Parse proposals file:{}".format(self.proposals_file))
  107. with open(self.proposals_file, 'r') as f:
  108. proposals = json.load(f)
  109. for prop in proposals:
  110. image_id = prop["image_id"]
  111. if image_id not in self.proposals:
  112. self.proposals[image_id] = []
  113. x, y, w, h = prop["bbox"]
  114. self.proposals[image_id].append([x, y, x + w, y + h])
  115. def _merge_anno_proposals(self):
  116. assert self.roidbs
  117. if self.proposals and len(self.proposals.keys()) > 0:
  118. logger.info("merge proposals to annos")
  119. for id, record in enumerate(self.roidbs):
  120. image_id = int(record["im_id"])
  121. if image_id not in self.proposals.keys():
  122. logger.info("image id :{} no proposals".format(image_id))
  123. record["proposals"] = np.array(self.proposals.get(image_id, []), dtype=np.float32)
  124. self.roidbs[id] = record
  125. def get_ori_roidbs(self):
  126. if not hasattr(self, "ori_roidbs"):
  127. return None
  128. return self.ori_roidbs
  129. def get_roidbs(self):
  130. if not hasattr(self, "roidbs"):
  131. self.parse_dataset()
  132. return self.roidbs
  133. def set_roidbs(self, roidbs):
  134. self.roidbs = roidbs
  135. def check_or_download_dataset(self):
  136. return
  137. def _parse(self):
  138. image_dir = self.image_dir
  139. if not isinstance(image_dir, Sequence):
  140. image_dir = [image_dir]
  141. images = []
  142. for im_dir in image_dir:
  143. if os.path.isdir(im_dir):
  144. im_dir = os.path.join(self.dataset_dir, im_dir)
  145. images.extend(_make_dataset(im_dir))
  146. elif os.path.isfile(im_dir) and _is_valid_file(im_dir):
  147. images.append(im_dir)
  148. return images
  149. def _load_images(self):
  150. images = self._parse()
  151. ct = 0
  152. records = []
  153. for image in images:
  154. assert image != '' and os.path.isfile(image), \
  155. "Image {} not found".format(image)
  156. if self.sample_num > 0 and ct >= self.sample_num:
  157. break
  158. im = cv2.imread(image)
  159. h, w, c = im.shape
  160. rec = {'im_id': np.array([ct]), 'im_file': image, "h": h, "w": w}
  161. self._imid2path[ct] = image
  162. ct += 1
  163. records.append(rec)
  164. assert len(records) > 0, "No image file found"
  165. return records
  166. def get_imid2path(self):
  167. return self._imid2path
  168. def set_images(self, images):
  169. self._imid2path = {}
  170. self.image_dir = images
  171. self.roidbs = self._load_images()