# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
This code is based on https://github.com/open-mmlab/mmpose
"""
import os
import copy
import json

import cv2
import numpy as np
import pycocotools.mask
from pycocotools.coco import COCO

from .dataset import DetDataset
from ppdet.core.workspace import register, serializable


@serializable
class KeypointBottomUpBaseDataset(DetDataset):
    """Base class for bottom-up datasets. All datasets should subclass it,
    and all subclasses should overwrite `_get_imganno`.

    Args:
        dataset_dir (str): Root path to the dataset.
        image_dir (str): Path to the directory where images are held.
        anno_path (str): Relative path to the annotation file.
        num_joints (int): Number of keypoints.
        transform (composed(operators)): A sequence of data transforms.
        shard (list): [rank, world_size], the distributed env params.
            Default: [0, 1].
        test_mode (bool): Store True when building a test or
            validation dataset. Default: False.
    """

    def __init__(self,
                 dataset_dir,
                 image_dir,
                 anno_path,
                 num_joints,
                 transform=[],
                 shard=[0, 1],
                 test_mode=False):
        super().__init__(dataset_dir, image_dir, anno_path)
        self.image_info = {}
        self.ann_info = {}
        self.img_prefix = os.path.join(dataset_dir, image_dir)
        self.transform = transform
        self.test_mode = test_mode

        self.ann_info['num_joints'] = num_joints
        self.img_ids = []

    def parse_dataset(self):
        pass

    def __len__(self):
        """Get dataset length."""
        return len(self.img_ids)

    def _get_imganno(self, idx):
        """Get anno for a single image."""
        raise NotImplementedError

    def __getitem__(self, idx):
        """Prepare image for training given the index."""
        records = copy.deepcopy(self._get_imganno(idx))
        records['image'] = cv2.imread(records['image_file'])
        records['image'] = cv2.cvtColor(records['image'], cv2.COLOR_BGR2RGB)
        # `_get_imganno` returns a boolean ignore mask; the transform
        # pipeline expects uint8.
        records['mask'] = records['mask'].astype('uint8')
        records = self.transform(records)
        return records
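
# A minimal subclass sketch (hypothetical, for illustration only): a concrete
# bottom-up dataset mainly needs to fill `self.img_ids` in `parse_dataset` and
# return a record dict from `_get_imganno`:
#
#     class MyBottomUpDataset(KeypointBottomUpBaseDataset):
#         def parse_dataset(self):
#             self.img_ids = [0]  # one image, for the sake of the sketch
#
#         def _get_imganno(self, idx):
#             return {
#                 'im_id': self.img_ids[idx],
#                 'image_file': os.path.join(self.img_prefix, 'demo.jpg'),
#                 'mask': np.zeros((480, 640), dtype=bool),
#                 'joints': np.zeros((1, self.ann_info['num_joints'], 3),
#                                    dtype=np.float32),
#                 'im_shape': np.array([480, 640]),
#             }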


@register
@serializable
class KeypointBottomUpCocoDataset(KeypointBottomUpBaseDataset):
    """COCO dataset for bottom-up pose estimation.

    The dataset loads raw features and applies the specified transforms
    to return a dict containing the image tensors and other information.

    COCO keypoint indexes::

        0: 'nose',
        1: 'left_eye',
        2: 'right_eye',
        3: 'left_ear',
        4: 'right_ear',
        5: 'left_shoulder',
        6: 'right_shoulder',
        7: 'left_elbow',
        8: 'right_elbow',
        9: 'left_wrist',
        10: 'right_wrist',
        11: 'left_hip',
        12: 'right_hip',
        13: 'left_knee',
        14: 'right_knee',
        15: 'left_ankle',
        16: 'right_ankle'

    Args:
        dataset_dir (str): Root path to the dataset.
        image_dir (str): Path to the directory where images are held.
        anno_path (str): Relative path to the annotation file.
        num_joints (int): Number of keypoints.
        transform (composed(operators)): A sequence of data transforms.
        shard (list): [rank, world_size], the distributed env params.
            Default: [0, 1].
        test_mode (bool): Store True when building a test or
            validation dataset. Default: False.
    """

    def __init__(self,
                 dataset_dir,
                 image_dir,
                 anno_path,
                 num_joints,
                 transform=[],
                 shard=[0, 1],
                 test_mode=False):
        super().__init__(dataset_dir, image_dir, anno_path, num_joints,
                         transform, shard, test_mode)
        self.ann_file = os.path.join(dataset_dir, anno_path)
        self.shard = shard
        self.test_mode = test_mode

    def parse_dataset(self):
        self.coco = COCO(self.ann_file)
        self.img_ids = self.coco.getImgIds()
        if not self.test_mode:
            # For training, keep only images with at least one annotation.
            self.img_ids = [
                img_id for img_id in self.img_ids
                if len(self.coco.getAnnIds(
                    imgIds=img_id, iscrowd=None)) > 0
            ]
        # Split the image list evenly across distributed shards.
        blocknum = len(self.img_ids) // self.shard[1]
        self.img_ids = self.img_ids[(blocknum * self.shard[0]):(blocknum * (
            self.shard[0] + 1))]
        self.num_images = len(self.img_ids)
        self.id2name, self.name2id = self._get_mapping_id_name(self.coco.imgs)
        self.dataset_name = 'coco'

        cat_ids = self.coco.getCatIds()
        self.catid2clsid = {catid: i for i, catid in enumerate(cat_ids)}
        print('=> num_images: {}'.format(self.num_images))

    @staticmethod
    def _get_mapping_id_name(imgs):
        """
        Args:
            imgs (dict): dict of image info.

        Returns:
            tuple: Image name & id mapping dicts.

                - id2name (dict): Mapping image id to name.
                - name2id (dict): Mapping image name to id.
        """
        id2name = {}
        name2id = {}
        for image_id, image in imgs.items():
            file_name = image['file_name']
            id2name[image_id] = file_name
            name2id[file_name] = image_id
        return id2name, name2id
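
    # For example (hypothetical data), imgs like
    #     {397133: {'file_name': '000000397133.jpg', ...}}
    # yields id2name == {397133: '000000397133.jpg'} and
    # name2id == {'000000397133.jpg': 397133}.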

    def _get_imganno(self, idx):
        """Get anno for a single image.

        Args:
            idx (int): image idx

        Returns:
            dict: info for model training
        """
        coco = self.coco
        img_id = self.img_ids[idx]
        ann_ids = coco.getAnnIds(imgIds=img_id)
        anno = coco.loadAnns(ann_ids)
        # Build the ignore mask before filtering, so crowd regions are covered.
        mask = self._get_mask(anno, idx)
        # Drop crowd annotations that carry no keypoints.
        anno = [
            obj for obj in anno
            if obj['iscrowd'] == 0 or obj['num_keypoints'] > 0
        ]
        joints, orgsize = self._get_joints(anno, idx)

        db_rec = {}
        db_rec['im_id'] = img_id
        db_rec['image_file'] = os.path.join(self.img_prefix,
                                            self.id2name[img_id])
        db_rec['mask'] = mask
        db_rec['joints'] = joints
        db_rec['im_shape'] = orgsize
        return db_rec

    def _get_joints(self, anno, idx):
        """Get joints for all people in an image."""
        num_people = len(anno)
        joints = np.zeros(
            (num_people, self.ann_info['num_joints'], 3), dtype=np.float32)
        for i, obj in enumerate(anno):
            joints[i, :self.ann_info['num_joints'], :3] = \
                np.array(obj['keypoints']).reshape([-1, 3])
        img_info = self.coco.loadImgs(self.img_ids[idx])[0]
        # Normalize x/y coordinates to [0, 1] by the image width/height.
        joints[..., 0] /= img_info['width']
        joints[..., 1] /= img_info['height']
        orgsize = np.array([img_info['height'], img_info['width']])
        return joints, orgsize

    def _get_mask(self, anno, idx):
        """Get ignore masks to mask out losses."""
        coco = self.coco
        img_info = coco.loadImgs(self.img_ids[idx])[0]
        m = np.zeros((img_info['height'], img_info['width']), dtype=np.float32)
        for obj in anno:
            if 'segmentation' in obj:
                if obj['iscrowd']:
                    rle = pycocotools.mask.frPyObjects(obj['segmentation'],
                                                       img_info['height'],
                                                       img_info['width'])
                    m += pycocotools.mask.decode(rle)
                elif obj['num_keypoints'] == 0:
                    rles = pycocotools.mask.frPyObjects(obj['segmentation'],
                                                        img_info['height'],
                                                        img_info['width'])
                    for rle in rles:
                        m += pycocotools.mask.decode(rle)
        # True marks pixels kept for the loss; crowd regions and people
        # without labeled keypoints are masked out.
        return m < 0.5
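
# A minimal usage sketch (hypothetical paths; `transform` is normally a
# composed operator pipeline built from a config, stubbed out here):
#
#     dataset = KeypointBottomUpCocoDataset(
#         dataset_dir='dataset/coco',
#         image_dir='train2017',
#         anno_path='annotations/person_keypoints_train2017.json',
#         num_joints=17,
#         transform=lambda records: records)  # identity stand-in
#     dataset.parse_dataset()
#     sample = dataset[0]  # dict with 'image', 'mask', 'joints', 'im_shape'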


@register
@serializable
class KeypointBottomUpCrowdPoseDataset(KeypointBottomUpCocoDataset):
    """CrowdPose dataset for bottom-up pose estimation.

    The dataset loads raw features and applies the specified transforms
    to return a dict containing the image tensors and other information.

    CrowdPose keypoint indexes::

        0: 'left_shoulder',
        1: 'right_shoulder',
        2: 'left_elbow',
        3: 'right_elbow',
        4: 'left_wrist',
        5: 'right_wrist',
        6: 'left_hip',
        7: 'right_hip',
        8: 'left_knee',
        9: 'right_knee',
        10: 'left_ankle',
        11: 'right_ankle',
        12: 'top_head',
        13: 'neck'

    Args:
        dataset_dir (str): Root path to the dataset.
        image_dir (str): Path to the directory where images are held.
        anno_path (str): Relative path to the annotation file.
        num_joints (int): Number of keypoints.
        transform (composed(operators)): A sequence of data transforms.
        shard (list): [rank, world_size], the distributed env params.
            Default: [0, 1].
        test_mode (bool): Store True when building a test or
            validation dataset. Default: False.
    """

    def __init__(self,
                 dataset_dir,
                 image_dir,
                 anno_path,
                 num_joints,
                 transform=[],
                 shard=[0, 1],
                 test_mode=False):
        super().__init__(dataset_dir, image_dir, anno_path, num_joints,
                         transform, shard, test_mode)
        self.ann_file = os.path.join(dataset_dir, anno_path)
        self.shard = shard
        self.test_mode = test_mode

    def parse_dataset(self):
        self.coco = COCO(self.ann_file)
        self.img_ids = self.coco.getImgIds()
        if not self.test_mode:
            self.img_ids = [
                img_id for img_id in self.img_ids
                if len(self.coco.getAnnIds(
                    imgIds=img_id, iscrowd=None)) > 0
            ]
        # Split the image list evenly across distributed shards.
        blocknum = len(self.img_ids) // self.shard[1]
        self.img_ids = self.img_ids[(blocknum * self.shard[0]):(blocknum * (
            self.shard[0] + 1))]
        self.num_images = len(self.img_ids)
        self.id2name, self.name2id = self._get_mapping_id_name(self.coco.imgs)
        self.dataset_name = 'crowdpose'
        print('=> num_images: {}'.format(self.num_images))
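
# Shard arithmetic worked through on made-up numbers: with 1000 images and
# shard=[2, 4] (rank 2 of world size 4), blocknum = 1000 // 4 = 250, so this
# rank keeps img_ids[500:750]. When the image count is not divisible by the
# world size, the remainder images are dropped.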


@serializable
class KeypointTopDownBaseDataset(DetDataset):
    """Base class for top-down datasets. All datasets should subclass it,
    and all subclasses should overwrite `_get_db`.

    Args:
        dataset_dir (str): Root path to the dataset.
        image_dir (str): Path to the directory where images are held.
        anno_path (str): Relative path to the annotation file.
        num_joints (int): Number of keypoints.
        transform (composed(operators)): A sequence of data transforms.
    """

    def __init__(self,
                 dataset_dir,
                 image_dir,
                 anno_path,
                 num_joints,
                 transform=[]):
        super().__init__(dataset_dir, image_dir, anno_path)
        self.image_info = {}
        self.ann_info = {}
        self.img_prefix = os.path.join(dataset_dir, image_dir)
        self.transform = transform

        self.ann_info['num_joints'] = num_joints
        self.db = []

    def __len__(self):
        """Get dataset length."""
        return len(self.db)

    def _get_db(self):
        """Get a sample."""
        raise NotImplementedError

    def __getitem__(self, idx):
        """Prepare sample for training given the index."""
        records = copy.deepcopy(self.db[idx])
        records['image'] = cv2.imread(records['image_file'], cv2.IMREAD_COLOR |
                                      cv2.IMREAD_IGNORE_ORIENTATION)
        records['image'] = cv2.cvtColor(records['image'], cv2.COLOR_BGR2RGB)
        # Detection score for detected boxes; ground-truth boxes default to 1.
        records['score'] = records.get('score', 1)
        records = self.transform(records)
        return records
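
# Each entry of `self.db` is expected to look roughly like this
# (illustrative values, not from a real annotation file):
#
#     {'image_file': 'dataset/coco/val2017/000000397133.jpg',
#      'im_id': 397133,
#      'center': np.array([100., 140.], dtype=np.float32),
#      'scale': np.array([0.9375, 1.25], dtype=np.float32),
#      'joints': np.zeros((17, 3)),       # (num_joints, 3)
#      'joints_vis': np.ones((17, 3)),    # (num_joints, 3)
#      'score': 0.99}                     # optional, defaults to 1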


@register
@serializable
class KeypointTopDownCocoDataset(KeypointTopDownBaseDataset):
    """COCO dataset for top-down pose estimation.

    The dataset loads raw features and applies the specified transforms
    to return a dict containing the image tensors and other information.

    COCO keypoint indexes::

        0: 'nose',
        1: 'left_eye',
        2: 'right_eye',
        3: 'left_ear',
        4: 'right_ear',
        5: 'left_shoulder',
        6: 'right_shoulder',
        7: 'left_elbow',
        8: 'right_elbow',
        9: 'left_wrist',
        10: 'right_wrist',
        11: 'left_hip',
        12: 'right_hip',
        13: 'left_knee',
        14: 'right_knee',
        15: 'left_ankle',
        16: 'right_ankle'

    Args:
        dataset_dir (str): Root path to the dataset.
        image_dir (str): Path to the directory where images are held.
        anno_path (str): Relative path to the annotation file.
        num_joints (int): Number of keypoints.
        trainsize (list): [w, h] image target size.
        transform (composed(operators)): A sequence of data transforms.
        bbox_file (str): Path to a detection bbox file.
            Default: None.
        use_gt_bbox (bool): Whether to use the ground truth bbox.
            Default: True.
        pixel_std (int): The pixel std of the scale.
            Default: 200.
        image_thre (float): The score threshold to filter detection boxes.
            Default: 0.0.
    """

    def __init__(self,
                 dataset_dir,
                 image_dir,
                 anno_path,
                 num_joints,
                 trainsize,
                 transform=[],
                 bbox_file=None,
                 use_gt_bbox=True,
                 pixel_std=200,
                 image_thre=0.0):
        super().__init__(dataset_dir, image_dir, anno_path, num_joints,
                         transform)
        self.bbox_file = bbox_file
        self.use_gt_bbox = use_gt_bbox
        self.trainsize = trainsize
        self.pixel_std = pixel_std
        self.image_thre = image_thre
        self.dataset_name = 'coco'

    def parse_dataset(self):
        if self.use_gt_bbox:
            self.db = self._load_coco_keypoint_annotations()
        else:
            self.db = self._load_coco_person_detection_results()

    def _load_coco_keypoint_annotations(self):
        coco = COCO(self.get_anno())
        img_ids = coco.getImgIds()
        gt_db = []
        for index in img_ids:
            im_ann = coco.loadImgs(index)[0]
            width = im_ann['width']
            height = im_ann['height']
            file_name = im_ann['file_name']
            im_id = int(im_ann["id"])

            annIds = coco.getAnnIds(imgIds=index, iscrowd=False)
            objs = coco.loadAnns(annIds)

            # Clip boxes to the image and drop degenerate ones.
            valid_objs = []
            for obj in objs:
                x, y, w, h = obj['bbox']
                x1 = np.max((0, x))
                y1 = np.max((0, y))
                x2 = np.min((width - 1, x1 + np.max((0, w - 1))))
                y2 = np.min((height - 1, y1 + np.max((0, h - 1))))
                if obj['area'] > 0 and x2 >= x1 and y2 >= y1:
                    obj['clean_bbox'] = [x1, y1, x2 - x1, y2 - y1]
                    valid_objs.append(obj)
            objs = valid_objs

            rec = []
            for obj in objs:
                if max(obj['keypoints']) == 0:
                    continue

                joints = np.zeros(
                    (self.ann_info['num_joints'], 3), dtype=np.float32)
                joints_vis = np.zeros(
                    (self.ann_info['num_joints'], 3), dtype=np.float32)
                for ipt in range(self.ann_info['num_joints']):
                    joints[ipt, 0] = obj['keypoints'][ipt * 3 + 0]
                    joints[ipt, 1] = obj['keypoints'][ipt * 3 + 1]
                    joints[ipt, 2] = 0
                    # COCO visibility: 0 not labeled, 1 labeled but not
                    # visible, 2 visible; fold 2 into 1 for the vis flag.
                    t_vis = obj['keypoints'][ipt * 3 + 2]
                    if t_vis > 1:
                        t_vis = 1
                    joints_vis[ipt, 0] = t_vis
                    joints_vis[ipt, 1] = t_vis
                    joints_vis[ipt, 2] = 0

                center, scale = self._box2cs(obj['clean_bbox'][:4])
                rec.append({
                    'image_file': os.path.join(self.img_prefix, file_name),
                    'center': center,
                    'scale': scale,
                    'joints': joints,
                    'joints_vis': joints_vis,
                    'im_id': im_id,
                })
            gt_db.extend(rec)
        return gt_db

    def _box2cs(self, box):
        x, y, w, h = box[:4]
        center = np.zeros((2), dtype=np.float32)
        center[0] = x + w * 0.5
        center[1] = y + h * 0.5
        # Expand the box to the training aspect ratio before converting
        # width/height into a pixel_std-normalized scale.
        aspect_ratio = self.trainsize[0] * 1.0 / self.trainsize[1]

        if w > aspect_ratio * h:
            h = w * 1.0 / aspect_ratio
        elif w < aspect_ratio * h:
            w = h * aspect_ratio
        scale = np.array(
            [w * 1.0 / self.pixel_std, h * 1.0 / self.pixel_std],
            dtype=np.float32)
        if center[0] != -1:
            scale = scale * 1.25
        return center, scale
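
    # Worked example (made-up numbers): with trainsize=[288, 384] and
    # pixel_std=200, box [50, 40, 100, 200] gives center (100, 140);
    # aspect_ratio = 288 / 384 = 0.75, and since w=100 < 0.75 * 200 = 150,
    # w widens to 150; scale = [150/200, 200/200] * 1.25 = [0.9375, 1.25].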

    def _load_coco_person_detection_results(self):
        all_boxes = None
        bbox_file_path = os.path.join(self.dataset_dir, self.bbox_file)
        with open(bbox_file_path, 'r') as f:
            all_boxes = json.load(f)

        if not all_boxes:
            print('=> Load %s fail!' % bbox_file_path)
            return None

        kpt_db = []
        for n_img in range(0, len(all_boxes)):
            det_res = all_boxes[n_img]
            # Keep only the person category.
            if det_res['category_id'] != 1:
                continue
            file_name = det_res[
                'filename'] if 'filename' in det_res else '%012d.jpg' % det_res[
                    'image_id']
            img_name = os.path.join(self.img_prefix, file_name)
            box = det_res['bbox']
            score = det_res['score']
            im_id = int(det_res['image_id'])

            if score < self.image_thre:
                continue

            center, scale = self._box2cs(box)
            # Keypoints are unknown for detected boxes; store zero joints
            # with all-visible flags as placeholders to be predicted.
            joints = np.zeros(
                (self.ann_info['num_joints'], 3), dtype=np.float32)
            joints_vis = np.ones(
                (self.ann_info['num_joints'], 3), dtype=np.float32)
            kpt_db.append({
                'image_file': img_name,
                'im_id': im_id,
                'center': center,
                'scale': scale,
                'score': score,
                'joints': joints,
                'joints_vis': joints_vis,
            })
        return kpt_db
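
# The bbox_file consumed above is a JSON list in the standard COCO
# detection-results format, e.g. (illustrative values):
#
#     [{"image_id": 397133, "category_id": 1,
#       "bbox": [388.66, 69.92, 109.41, 277.62], "score": 0.99}]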


@register
@serializable
class KeypointTopDownMPIIDataset(KeypointTopDownBaseDataset):
    """MPII dataset for top-down pose estimation.

    The dataset loads raw features and applies the specified transforms
    to return a dict containing the image tensors and other information.

    MPII keypoint indexes::

        0: 'right_ankle',
        1: 'right_knee',
        2: 'right_hip',
        3: 'left_hip',
        4: 'left_knee',
        5: 'left_ankle',
        6: 'pelvis',
        7: 'thorax',
        8: 'upper_neck',
        9: 'head_top',
        10: 'right_wrist',
        11: 'right_elbow',
        12: 'right_shoulder',
        13: 'left_shoulder',
        14: 'left_elbow',
        15: 'left_wrist'

    Args:
        dataset_dir (str): Root path to the dataset.
        image_dir (str): Path to the directory where images are held.
        anno_path (str): Relative path to the annotation file.
        num_joints (int): Number of keypoints.
        transform (composed(operators)): A sequence of data transforms.
    """

    def __init__(self,
                 dataset_dir,
                 image_dir,
                 anno_path,
                 num_joints,
                 transform=[]):
        super().__init__(dataset_dir, image_dir, anno_path, num_joints,
                         transform)
        self.dataset_name = 'mpii'

    def parse_dataset(self):
        with open(self.get_anno()) as anno_file:
            anno = json.load(anno_file)

        gt_db = []
        for a in anno:
            image_name = a['image']
            im_id = a['image_id'] if 'image_id' in a else int(
                os.path.splitext(image_name)[0])

            c = np.array(a['center'], dtype=np.float32)
            s = np.array([a['scale'], a['scale']], dtype=np.float32)

            # Adjust center/scale slightly to avoid cropping limbs.
            if c[0] != -1:
                c[1] = c[1] + 15 * s[1]
                s = s * 1.25
            # MPII annotations are 1-based; convert to 0-based indexing.
            c = c - 1

            joints = np.zeros(
                (self.ann_info['num_joints'], 3), dtype=np.float32)
            joints_vis = np.zeros(
                (self.ann_info['num_joints'], 3), dtype=np.float32)
            if 'joints' in a:
                joints_ = np.array(a['joints'])
                joints_[:, 0:2] = joints_[:, 0:2] - 1
                joints_vis_ = np.array(a['joints_vis'])
                assert len(joints_) == self.ann_info[
                    'num_joints'], 'joint num diff: {} vs {}'.format(
                        len(joints_), self.ann_info['num_joints'])
                joints[:, 0:2] = joints_[:, 0:2]
                joints_vis[:, 0] = joints_vis_[:]
                joints_vis[:, 1] = joints_vis_[:]

            gt_db.append({
                'image_file': os.path.join(self.img_prefix, image_name),
                'im_id': im_id,
                'center': c,
                'scale': s,
                'joints': joints,
                'joints_vis': joints_vis
            })
        print('=> num_samples: {}'.format(len(gt_db)))
        self.db = gt_db
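
# Each entry in the MPII-style annotation JSON is expected to look roughly
# like this (illustrative values; coordinates are 1-based, matching the
# original MATLAB annotations):
#
#     {"image": "005808361.jpg", "image_id": 5808361,
#      "center": [594.0, 257.0], "scale": 3.8,
#      "joints": [[620.0, 394.0], ...],   # 16 [x, y] pairs
#      "joints_vis": [1, 1, 0, ...]}      # 16 visibility flags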