engine_track.py 12 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277
  1. # Modified by Peize Sun, Rufeng Zhang
  2. # ------------------------------------------------------------------------
  3. # Deformable DETR
  4. # Copyright (c) 2020 SenseTime. All Rights Reserved.
  5. # Licensed under the Apache License, Version 2.0 [see LICENSE for details]
  6. # ------------------------------------------------------------------------
  7. # Modified from DETR (https://github.com/facebookresearch/detr)
  8. # Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
  9. # ------------------------------------------------------------------------
  10. """
  11. Train and eval functions used in main.py
  12. """
  13. import math
  14. import os
  15. import sys
  16. from typing import Iterable
  17. import torch
  18. import util.misc as utils
  19. from datasets.coco_eval import CocoEvaluator
  20. from datasets.panoptic_eval import PanopticEvaluator
  21. from datasets.data_prefetcher import data_prefetcher
  22. from mot_online.byte_tracker import BYTETracker
  23. def train_one_epoch(model: torch.nn.Module, criterion: torch.nn.Module,
  24. data_loader: Iterable, optimizer: torch.optim.Optimizer,
  25. device: torch.device, epoch: int, max_norm: float = 0):
  26. model.train()
  27. criterion.train()
  28. metric_logger = utils.MetricLogger(delimiter=" ")
  29. metric_logger.add_meter('lr', utils.SmoothedValue(window_size=1, fmt='{value:.6f}'))
  30. metric_logger.add_meter('class_error', utils.SmoothedValue(window_size=1, fmt='{value:.2f}'))
  31. metric_logger.add_meter('grad_norm', utils.SmoothedValue(window_size=1, fmt='{value:.2f}'))
  32. header = 'Epoch: [{}]'.format(epoch)
  33. print_freq = 10
  34. prefetcher = data_prefetcher(data_loader, device, prefetch=True)
  35. samples, targets = prefetcher.next()
  36. # for samples, targets in metric_logger.log_every(data_loader, print_freq, header):
  37. for _ in metric_logger.log_every(range(len(data_loader)), print_freq, header):
  38. outputs, pre_outputs, pre_targets = model([samples, targets])
  39. loss_dict = criterion(outputs, targets, pre_outputs, pre_targets)
  40. weight_dict = criterion.weight_dict
  41. losses = sum(loss_dict[k] * weight_dict[k] for k in loss_dict.keys() if k in weight_dict)
  42. # reduce losses over all GPUs for logging purposes
  43. loss_dict_reduced = utils.reduce_dict(loss_dict)
  44. loss_dict_reduced_unscaled = {f'{k}_unscaled': v
  45. for k, v in loss_dict_reduced.items()}
  46. loss_dict_reduced_scaled = {k: v * weight_dict[k]
  47. for k, v in loss_dict_reduced.items() if k in weight_dict}
  48. losses_reduced_scaled = sum(loss_dict_reduced_scaled.values())
  49. loss_value = losses_reduced_scaled.item()
  50. if not math.isfinite(loss_value):
  51. print("Loss is {}, stopping training".format(loss_value))
  52. print(loss_dict_reduced)
  53. sys.exit(1)
  54. optimizer.zero_grad()
  55. losses.backward()
  56. if max_norm > 0:
  57. grad_total_norm = torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm)
  58. else:
  59. grad_total_norm = utils.get_total_grad_norm(model.parameters(), max_norm)
  60. optimizer.step()
  61. metric_logger.update(loss=loss_value, **loss_dict_reduced_scaled, **loss_dict_reduced_unscaled)
  62. metric_logger.update(class_error=loss_dict_reduced['class_error'])
  63. metric_logger.update(lr=optimizer.param_groups[0]["lr"])
  64. metric_logger.update(grad_norm=grad_total_norm)
  65. samples, targets = prefetcher.next()
  66. # gather the stats from all processes
  67. metric_logger.synchronize_between_processes()
  68. print("Averaged stats:", metric_logger)
  69. return {k: meter.global_avg for k, meter in metric_logger.meters.items()}
  70. @torch.no_grad()
  71. def evaluate(model, criterion, postprocessors, data_loader, base_ds, device, output_dir, tracker=None,
  72. phase='train', det_val=False):
  73. model.eval()
  74. criterion.eval()
  75. metric_logger = utils.MetricLogger(delimiter=" ")
  76. metric_logger.add_meter('class_error', utils.SmoothedValue(window_size=1, fmt='{value:.2f}'))
  77. header = 'Test:'
  78. iou_types = tuple(k for k in ('segm', 'bbox') if k in postprocessors.keys())
  79. coco_evaluator = CocoEvaluator(base_ds, iou_types)
  80. # coco_evaluator.coco_eval[iou_types[0]].params.iouThrs = [0, 0.1, 0.5, 0.75]
  81. panoptic_evaluator = None
  82. if 'panoptic' in postprocessors.keys():
  83. panoptic_evaluator = PanopticEvaluator(
  84. data_loader.dataset.ann_file,
  85. data_loader.dataset.ann_folder,
  86. output_dir=os.path.join(output_dir, "panoptic_eval"),
  87. )
  88. res_tracks = dict()
  89. pre_embed = None
  90. for samples, targets in metric_logger.log_every(data_loader, 10, header):
  91. # pre process for track.
  92. if tracker is not None:
  93. if phase != 'train':
  94. assert samples.tensors.shape[0] == 1, "Now only support inference of batchsize 1."
  95. frame_id = targets[0].get("frame_id", None)
  96. assert frame_id is not None
  97. frame_id = frame_id.item()
  98. if frame_id == 1:
  99. tracker.reset_all()
  100. pre_embed = None
  101. samples = samples.to(device)
  102. targets = [{k: v.to(device) for k, v in t.items()} for t in targets]
  103. if det_val:
  104. outputs = model(samples)
  105. else:
  106. outputs, pre_embed = model(samples, pre_embed)
  107. loss_dict = criterion(outputs, targets)
  108. weight_dict = criterion.weight_dict
  109. # reduce losses over all GPUs for logging purposes
  110. loss_dict_reduced = utils.reduce_dict(loss_dict)
  111. loss_dict_reduced_scaled = {k: v * weight_dict[k]
  112. for k, v in loss_dict_reduced.items() if k in weight_dict}
  113. loss_dict_reduced_unscaled = {f'{k}_unscaled': v
  114. for k, v in loss_dict_reduced.items()}
  115. metric_logger.update(loss=sum(loss_dict_reduced_scaled.values()),
  116. **loss_dict_reduced_scaled,
  117. **loss_dict_reduced_unscaled)
  118. metric_logger.update(class_error=loss_dict_reduced['class_error'])
  119. orig_target_sizes = torch.stack([t["orig_size"] for t in targets], dim=0)
  120. results = postprocessors['bbox'](outputs, orig_target_sizes)
  121. if 'segm' in postprocessors.keys():
  122. target_sizes = torch.stack([t["size"] for t in targets], dim=0)
  123. results = postprocessors['segm'](results, outputs, orig_target_sizes, target_sizes)
  124. res = {target['image_id'].item(): output for target, output in zip(targets, results)}
  125. # post process for track.
  126. if tracker is not None:
  127. if frame_id == 1:
  128. res_track = tracker.init_track(results[0])
  129. else:
  130. res_track = tracker.step(results[0])
  131. res_tracks[targets[0]['image_id'].item()] = res_track
  132. if coco_evaluator is not None:
  133. coco_evaluator.update(res)
  134. if panoptic_evaluator is not None:
  135. res_pano = postprocessors["panoptic"](outputs, target_sizes, orig_target_sizes)
  136. for i, target in enumerate(targets):
  137. image_id = target["image_id"].item()
  138. file_name = f"{image_id:012d}.png"
  139. res_pano[i]["image_id"] = image_id
  140. res_pano[i]["file_name"] = file_name
  141. panoptic_evaluator.update(res_pano)
  142. # gather the stats from all processes
  143. metric_logger.synchronize_between_processes()
  144. print("Averaged stats:", metric_logger)
  145. if coco_evaluator is not None:
  146. coco_evaluator.synchronize_between_processes()
  147. if panoptic_evaluator is not None:
  148. panoptic_evaluator.synchronize_between_processes()
  149. # accumulate predictions from all images
  150. if coco_evaluator is not None:
  151. coco_evaluator.accumulate()
  152. coco_evaluator.summarize()
  153. panoptic_res = None
  154. if panoptic_evaluator is not None:
  155. panoptic_res = panoptic_evaluator.summarize()
  156. stats = {k: meter.global_avg for k, meter in metric_logger.meters.items()}
  157. if coco_evaluator is not None:
  158. if 'bbox' in postprocessors.keys():
  159. stats['coco_eval_bbox'] = coco_evaluator.coco_eval['bbox'].stats.tolist()
  160. if 'segm' in postprocessors.keys():
  161. stats['coco_eval_masks'] = coco_evaluator.coco_eval['segm'].stats.tolist()
  162. if panoptic_res is not None:
  163. stats['PQ_all'] = panoptic_res["All"]
  164. stats['PQ_th'] = panoptic_res["Things"]
  165. stats['PQ_st'] = panoptic_res["Stuff"]
  166. return stats, coco_evaluator, res_tracks
  167. @torch.no_grad()
  168. def evaluate_track(args, model, criterion, postprocessors, data_loader, base_ds, device, output_dir, tracker=None,
  169. phase='train', det_val=False):
  170. model.eval()
  171. criterion.eval()
  172. metric_logger = utils.MetricLogger(delimiter=" ")
  173. metric_logger.add_meter('class_error', utils.SmoothedValue(window_size=1, fmt='{value:.2f}'))
  174. header = 'Test:'
  175. iou_types = tuple(k for k in ('segm', 'bbox') if k in postprocessors.keys())
  176. coco_evaluator = CocoEvaluator(base_ds, iou_types)
  177. # coco_evaluator.coco_eval[iou_types[0]].params.iouThrs = [0, 0.1, 0.5, 0.75]
  178. res_tracks = dict()
  179. pre_embed = None
  180. for samples, targets in metric_logger.log_every(data_loader, 50, header):
  181. # pre process for track.
  182. if tracker is not None:
  183. frame_id = targets[0].get("frame_id", None)
  184. assert frame_id is not None
  185. frame_id = frame_id.item()
  186. if frame_id == 1:
  187. tracker = BYTETracker(args)
  188. pre_embed = None
  189. samples = samples.to(device)
  190. targets = [{k: v.to(device) for k, v in t.items()} for t in targets]
  191. if det_val:
  192. outputs = model(samples)
  193. else:
  194. outputs, pre_embed = model(samples, pre_embed)
  195. loss_dict = criterion(outputs, targets)
  196. weight_dict = criterion.weight_dict
  197. # reduce losses over all GPUs for logging purposes
  198. loss_dict_reduced = utils.reduce_dict(loss_dict)
  199. loss_dict_reduced_scaled = {k: v * weight_dict[k]
  200. for k, v in loss_dict_reduced.items() if k in weight_dict}
  201. loss_dict_reduced_unscaled = {f'{k}_unscaled': v
  202. for k, v in loss_dict_reduced.items()}
  203. metric_logger.update(loss=sum(loss_dict_reduced_scaled.values()),
  204. **loss_dict_reduced_scaled,
  205. **loss_dict_reduced_unscaled)
  206. metric_logger.update(class_error=loss_dict_reduced['class_error'])
  207. orig_target_sizes = torch.stack([t["orig_size"] for t in targets], dim=0)
  208. results = postprocessors['bbox'](outputs, orig_target_sizes)
  209. if 'segm' in postprocessors.keys():
  210. target_sizes = torch.stack([t["size"] for t in targets], dim=0)
  211. results = postprocessors['segm'](results, outputs, orig_target_sizes, target_sizes)
  212. res = {target['image_id'].item(): output for target, output in zip(targets, results)}
  213. # post process for track.
  214. if tracker is not None:
  215. res_track = tracker.update(results[0])
  216. res_tracks[targets[0]['image_id'].item()] = res_track
  217. if coco_evaluator is not None:
  218. coco_evaluator.update(res)
  219. # gather the stats from all processes
  220. metric_logger.synchronize_between_processes()
  221. print("Averaged stats:", metric_logger)
  222. if coco_evaluator is not None:
  223. coco_evaluator.synchronize_between_processes()
  224. # accumulate predictions from all images
  225. if coco_evaluator is not None:
  226. coco_evaluator.accumulate()
  227. coco_evaluator.summarize()
  228. stats = {k: meter.global_avg for k, meter in metric_logger.meters.items()}
  229. if coco_evaluator is not None:
  230. if 'bbox' in postprocessors.keys():
  231. stats['coco_eval_bbox'] = coco_evaluator.coco_eval['bbox'].stats.tolist()
  232. if 'segm' in postprocessors.keys():
  233. stats['coco_eval_masks'] = coco_evaluator.coco_eval['segm'].stats.tolist()
  234. return stats, coco_evaluator, res_tracks