# joint.py

# ------------------------------------------------------------------------
# Copyright (c) 2021 megvii-model. All Rights Reserved.
# ------------------------------------------------------------------------
# Modified from Deformable DETR (https://github.com/fundamentalvision/Deformable-DETR)
# Copyright (c) 2020 SenseTime. All Rights Reserved.
# ------------------------------------------------------------------------
# Modified from DETR (https://github.com/facebookresearch/detr)
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
# ------------------------------------------------------------------------
"""
MOT dataset which returns image_id for evaluation.
"""
from pathlib import Path
import cv2
import numpy as np
import torch
import torch.utils.data
import os.path as osp
from PIL import Image, ImageDraw
import copy
import datasets.transforms as T
from models.structures import Instances


class DetMOTDetection:
    def __init__(self, args, data_txt_path: str, seqs_folder, dataset2transform):
        self.args = args
        self.dataset2transform = dataset2transform
        self.num_frames_per_batch = max(args.sampler_lengths)
        self.sample_mode = args.sample_mode
        self.sample_interval = args.sample_interval
        self.vis = args.vis
        self.video_dict = {}

        with open(data_txt_path, 'r') as file:
            self.img_files = file.readlines()
            self.img_files = [osp.join(seqs_folder, x.strip()) for x in self.img_files]
            self.img_files = list(filter(lambda x: len(x) > 0, self.img_files))

        self.label_files = [(x.replace('images', 'labels_with_ids').replace('.png', '.txt').replace('.jpg', '.txt'))
                            for x in self.img_files]
        # The number of images per sample: 1 + (num_frames - 1) * interval.
        # The number of valid samples: num_images - num_image_per_sample + 1.
        self.item_num = len(self.img_files) - (self.num_frames_per_batch - 1) * self.sample_interval
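        # Worked example (illustrative numbers, not from any real config): with
        # num_frames_per_batch = 5 and sample_interval = 10, each sample spans
        # 1 + (5 - 1) * 10 = 41 consecutive source images, so a list of 1000
        # images yields 1000 - 40 = 960 valid starting indices.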
        self._register_videos()

        # video sampler.
        self.sampler_steps: list = args.sampler_steps
        self.lengths: list = args.sampler_lengths
        print("sampler_steps={} lengths={}".format(self.sampler_steps, self.lengths))
        if self.sampler_steps is not None and len(self.sampler_steps) > 0:
            # Enable sampling length adjustment.
            assert len(self.lengths) > 0
            assert len(self.lengths) == len(self.sampler_steps) + 1
            for i in range(len(self.sampler_steps) - 1):
                assert self.sampler_steps[i] < self.sampler_steps[i + 1]
            self.item_num = len(self.img_files) - (self.lengths[-1] - 1) * self.sample_interval
            self.period_idx = 0
            self.num_frames_per_batch = self.lengths[0]
            self.current_epoch = 0

    def _register_videos(self):
        for label_name in self.label_files:
            video_name = '/'.join(label_name.split('/')[:-1])
            if video_name not in self.video_dict:
                print("register {}-th video: {} ".format(len(self.video_dict) + 1, video_name))
                self.video_dict[video_name] = len(self.video_dict)
                # assert len(self.video_dict) <= 300

    def set_epoch(self, epoch):
        self.current_epoch = epoch
        if self.sampler_steps is None or len(self.sampler_steps) == 0:
            # fixed sampling length.
            return

        for i in range(len(self.sampler_steps)):
            if epoch >= self.sampler_steps[i]:
                self.period_idx = i + 1
        print("set epoch: epoch {} period_idx={}".format(epoch, self.period_idx))
        self.num_frames_per_batch = self.lengths[self.period_idx]

    def step_epoch(self):
        # one epoch finishes.
        print("Dataset: epoch {} finishes".format(self.current_epoch))
        self.set_epoch(self.current_epoch + 1)
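
    # A minimal schedule sketch with assumed values (not taken from any real
    # config): sampler_steps=[2, 4] and sampler_lengths=[2, 3, 5] would sample
    # 2 frames per clip for epochs 0-1, 3 frames for epochs 2-3, and 5 frames
    # from epoch 4 on. The training loop is expected to call step_epoch() once
    # per finished epoch (or set_epoch() when resuming) so that
    # num_frames_per_batch tracks the current period.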

    @staticmethod
    def _targets_to_instances(targets: dict, img_shape) -> Instances:
        gt_instances = Instances(tuple(img_shape))
        gt_instances.boxes = targets['boxes']
        gt_instances.labels = targets['labels']
        gt_instances.obj_ids = targets['obj_ids']
        gt_instances.area = targets['area']
        return gt_instances
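
    # Instances (from models.structures; a detectron2-style container as far
    # as this file is concerned) simply bundles the per-frame ground-truth
    # fields (boxes, labels, obj_ids, area) under one image size, so each
    # frame of a clip yields one Instances object for downstream matching.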

    def _pre_single_frame(self, idx: int):
        img_path = self.img_files[idx]
        label_path = self.label_files[idx]
        if 'crowdhuman' in img_path:
            img_path = img_path.replace('.jpg', '.png')
        img = Image.open(img_path)
        targets = {}
        w, h = img.size  # PIL gives (width, height); avoid the private _size attribute
        assert w > 0 and h > 0, "invalid image {} with shape {} {}".format(img_path, w, h)
        if osp.isfile(label_path):
            labels0 = np.loadtxt(label_path, dtype=np.float32).reshape(-1, 6)
            # normalized cxcywh to pixel xyxy format
            labels = labels0.copy()
            labels[:, 2] = w * (labels0[:, 2] - labels0[:, 4] / 2)
            labels[:, 3] = h * (labels0[:, 3] - labels0[:, 5] / 2)
            labels[:, 4] = w * (labels0[:, 2] + labels0[:, 4] / 2)
            labels[:, 5] = h * (labels0[:, 3] + labels0[:, 5] / 2)
        else:
            raise ValueError('invalid label path: {}'.format(label_path))

        video_name = '/'.join(label_path.split('/')[:-1])
        obj_idx_offset = self.video_dict[video_name] * 1000000  # 1000000 unique ids is enough for a video.

        if 'crowdhuman' in img_path:
            targets['dataset'] = 'CrowdHuman'
        elif 'MOT17' in img_path:
            targets['dataset'] = 'MOT17'
        else:
            raise NotImplementedError()
        targets['boxes'] = []
        targets['area'] = []
        targets['iscrowd'] = []
        targets['labels'] = []
        targets['obj_ids'] = []
        targets['image_id'] = torch.as_tensor(idx)
        targets['size'] = torch.as_tensor([h, w])
        targets['orig_size'] = torch.as_tensor([h, w])
        for label in labels:
            targets['boxes'].append(label[2:6].tolist())
            targets['area'].append(label[4] * label[5])
            targets['iscrowd'].append(0)
            targets['labels'].append(0)
            obj_id = label[1] + obj_idx_offset if label[1] >= 0 else label[1]
            targets['obj_ids'].append(obj_id)  # relative id

        targets['area'] = torch.as_tensor(targets['area'])
        targets['iscrowd'] = torch.as_tensor(targets['iscrowd'])
        targets['labels'] = torch.as_tensor(targets['labels'])
        targets['obj_ids'] = torch.as_tensor(targets['obj_ids'])
        targets['boxes'] = torch.as_tensor(targets['boxes'], dtype=torch.float32).reshape(-1, 4)
        # targets['boxes'][:, 0::2].clamp_(min=0, max=w)
        # targets['boxes'][:, 1::2].clamp_(min=0, max=h)
        return img, targets
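
    # Label files are assumed to follow the JDE/FairMOT "labels_with_ids"
    # layout implied by the loader above: one row per box,
    # [class, track_id, cx, cy, w, h] with coordinates normalized to [0, 1].
    # _pre_single_frame converts rows to pixel xyxy, and obj_idx_offset shifts
    # per-video track ids into a globally unique range
    # (video_index * 1000000 + track_id), e.g. track 7 of the 3rd registered
    # video becomes 2 * 1000000 + 7 = 2000007.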

    def _get_sample_range(self, start_idx):
        # take default sampling method for normal dataset.
        assert self.sample_mode in ['fixed_interval', 'random_interval'], 'invalid sample mode: {}'.format(self.sample_mode)
        if self.sample_mode == 'fixed_interval':
            sample_interval = self.sample_interval
        elif self.sample_mode == 'random_interval':
            sample_interval = np.random.randint(1, self.sample_interval + 1)
        default_range = start_idx, start_idx + (self.num_frames_per_batch - 1) * sample_interval + 1, sample_interval
        return default_range
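
    # Example (illustrative numbers): start_idx=100, num_frames_per_batch=5 and
    # sample_interval=10 in 'fixed_interval' mode return (100, 141, 10), i.e.
    # range(100, 141, 10) -> frames 100, 110, 120, 130, 140. In
    # 'random_interval' mode the stride is drawn uniformly from [1, 10].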

    def pre_continuous_frames(self, start, end, interval=1):
        targets = []
        images = []
        for i in range(start, end, interval):
            img_i, targets_i = self._pre_single_frame(i)
            images.append(img_i)
            targets.append(targets_i)
        return images, targets

    def __getitem__(self, idx):
        sample_start, sample_end, sample_interval = self._get_sample_range(idx)
        images, targets = self.pre_continuous_frames(sample_start, sample_end, sample_interval)
        data = {}
        dataset_name = targets[0]['dataset']
        transform = self.dataset2transform[dataset_name]
        if transform is not None:
            images, targets = transform(images, targets)
        gt_instances = []
        for img_i, targets_i in zip(images, targets):
            gt_instances_i = self._targets_to_instances(targets_i, img_i.shape[1:3])
            gt_instances.append(gt_instances_i)
        data.update({
            'imgs': images,
            'gt_instances': gt_instances,
        })
        if self.args.vis:
            data['ori_img'] = [target_i['ori_img'] for target_i in targets]
        return data

    def __len__(self):
        return self.item_num
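
    # Consumption sketch (hedged): each dataset[i] is a dict with 'imgs' (a
    # list of num_frames_per_batch image tensors after MotToTensor /
    # MotNormalize) and 'gt_instances' (one Instances per frame), so the
    # training loop is expected to use a collate function that keeps clips as
    # lists rather than stacking them into a single batch tensor.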


class DetMOTDetectionValidation(DetMOTDetection):
    def __init__(self, args, seqs_folder, dataset2transform):
        args.data_txt_path = args.val_data_txt_path
        # Pass the validation list explicitly: the parent __init__ expects
        # data_txt_path as its second argument.
        super().__init__(args, data_txt_path=args.val_data_txt_path,
                         seqs_folder=seqs_folder, dataset2transform=dataset2transform)


def make_transforms_for_mot17(image_set, args=None):
    normalize = T.MotCompose([
        T.MotToTensor(),
        T.MotNormalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
    ])
    scales = [608, 640, 672, 704, 736, 768, 800, 832, 864, 896, 928, 960, 992]

    if image_set == 'train':
        return T.MotCompose([
            T.MotRandomHorizontalFlip(),
            T.MotRandomSelect(
                T.MotRandomResize(scales, max_size=1536),
                T.MotCompose([
                    T.MotRandomResize([400, 500, 600]),
                    T.FixedMotRandomCrop(384, 600),
                    T.MotRandomResize(scales, max_size=1536),
                ])
            ),
            normalize,
        ])

    if image_set == 'val':
        return T.MotCompose([
            T.MotRandomResize([800], max_size=1333),
            normalize,
        ])

    raise ValueError(f'unknown {image_set}')


def make_transforms_for_crowdhuman(image_set, args=None):
    normalize = T.MotCompose([
        T.MotToTensor(),
        T.MotNormalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
    ])
    scales = [608, 640, 672, 704, 736, 768, 800, 832, 864, 896, 928, 960, 992]

    if image_set == 'train':
        return T.MotCompose([
            T.MotRandomHorizontalFlip(),
            T.FixedMotRandomShift(bs=1),
            T.MotRandomSelect(
                T.MotRandomResize(scales, max_size=1536),
                T.MotCompose([
                    T.MotRandomResize([400, 500, 600]),
                    T.FixedMotRandomCrop(384, 600),
                    T.MotRandomResize(scales, max_size=1536),
                ])
            ),
            normalize,
        ])

    if image_set == 'val':
        return T.MotCompose([
            T.MotRandomResize([800], max_size=1333),
            normalize,
        ])

    raise ValueError(f'unknown {image_set}')


def build_dataset2transform(args, image_set):
    mot17_train = make_transforms_for_mot17('train', args)
    mot17_test = make_transforms_for_mot17('val', args)

    crowdhuman_train = make_transforms_for_crowdhuman('train', args)
    dataset2transform_train = {'MOT17': mot17_train, 'CrowdHuman': crowdhuman_train}
    dataset2transform_val = {'MOT17': mot17_test, 'CrowdHuman': mot17_test}
    if image_set == 'train':
        return dataset2transform_train
    elif image_set == 'val':
        return dataset2transform_val
    else:
        raise NotImplementedError()


def build(image_set, args):
    root = Path(args.mot_path)
    assert root.exists(), f'provided MOT path {root} does not exist'
    dataset2transform = build_dataset2transform(args, image_set)
    if image_set == 'train':
        data_txt_path = args.data_txt_path_train
        dataset = DetMOTDetection(args, data_txt_path=data_txt_path, seqs_folder=root, dataset2transform=dataset2transform)
    if image_set == 'val':
        data_txt_path = args.data_txt_path_val
        dataset = DetMOTDetection(args, data_txt_path=data_txt_path, seqs_folder=root, dataset2transform=dataset2transform)
    return dataset
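

# A minimal smoke-test sketch, not part of the original training entry point.
# The Namespace fields below are the ones this file actually reads; the
# concrete values (paths, lengths, interval) are illustrative assumptions and
# must be adapted to a real data layout before running.
if __name__ == '__main__':
    from argparse import Namespace

    args = Namespace(
        mot_path='/data/Dataset/mot',                             # assumed dataset root
        data_txt_path_train='./datasets/data_path/joint.train',   # assumed image-list file
        data_txt_path_val='./datasets/data_path/mot17.val',       # assumed image-list file
        sampler_lengths=[2, 3, 4, 5],
        sampler_steps=[50, 90, 150],
        sample_mode='random_interval',
        sample_interval=10,
        vis=False,
    )
    dataset = build('train', args)
    print('dataset length:', len(dataset))
    sample = dataset[0]
    print('frames per sample:', len(sample['imgs']))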