main_track.py 18 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375
  1. # Modified by Peize Sun, Rufeng Zhang
  2. # ------------------------------------------------------------------------
  3. # Deformable DETR
  4. # Copyright (c) 2020 SenseTime. All Rights Reserved.
  5. # Licensed under the Apache License, Version 2.0 [see LICENSE for details]
  6. # ------------------------------------------------------------------------
  7. # Modified from DETR (https://github.com/facebookresearch/detr)
  8. # Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
  9. # ------------------------------------------------------------------------
  10. import argparse
  11. import datetime
  12. import json
  13. import random
  14. import time
  15. from pathlib import Path
  16. import numpy as np
  17. import torch
  18. from torch.utils.data import DataLoader
  19. import datasets
  20. import util.misc as utils
  21. import datasets.samplers as samplers
  22. from datasets import build_dataset, get_coco_api_from_dataset
  23. from engine_track import evaluate, train_one_epoch, evaluate_track
  24. from models import build_tracktrain_model, build_tracktest_model, build_model
  25. from models import Tracker
  26. from models import save_track
  27. from mot_online.byte_tracker import BYTETracker
  28. from collections import defaultdict
  29. def get_args_parser():
  30. parser = argparse.ArgumentParser('Deformable DETR Detector', add_help=False)
  31. parser.add_argument('--lr', default=2e-4, type=float)
  32. parser.add_argument('--lr_backbone_names', default=["backbone.0"], type=str, nargs='+')
  33. parser.add_argument('--lr_backbone', default=2e-5, type=float)
  34. parser.add_argument('--lr_linear_proj_names', default=['reference_points', 'sampling_offsets'], type=str, nargs='+')
  35. parser.add_argument('--lr_linear_proj_mult', default=0.1, type=float)
  36. parser.add_argument('--batch_size', default=1, type=int)
  37. parser.add_argument('--weight_decay', default=1e-4, type=float)
  38. parser.add_argument('--epochs', default=50, type=int)
  39. parser.add_argument('--lr_drop', default=40, type=int)
  40. parser.add_argument('--lr_drop_epochs', default=None, type=int, nargs='+')
  41. parser.add_argument('--clip_max_norm', default=0.1, type=float,
  42. help='gradient clipping max norm')
  43. parser.add_argument('--sgd', action='store_true')
  44. # Variants of Deformable DETR
  45. parser.add_argument('--with_box_refine', default=True, action='store_true')
  46. parser.add_argument('--two_stage', default=False, action='store_true')
  47. # Model parameters
  48. parser.add_argument('--frozen_weights', type=str, default=None,
  49. help="Path to the pretrained model. If set, only the mask head will be trained")
  50. # * Backbone
  51. parser.add_argument('--backbone', default='resnet50', type=str,
  52. help="Name of the convolutional backbone to use")
  53. parser.add_argument('--dilation', action='store_true',
  54. help="If true, we replace stride with dilation in the last convolutional block (DC5)")
  55. parser.add_argument('--position_embedding', default='sine', type=str, choices=('sine', 'learned'),
  56. help="Type of positional embedding to use on top of the image features")
  57. parser.add_argument('--position_embedding_scale', default=2 * np.pi, type=float,
  58. help="position / size * scale")
  59. parser.add_argument('--num_feature_levels', default=4, type=int, help='number of feature levels')
  60. # * Transformer
  61. parser.add_argument('--enc_layers', default=6, type=int,
  62. help="Number of encoding layers in the transformer")
  63. parser.add_argument('--dec_layers', default=6, type=int,
  64. help="Number of decoding layers in the transformer")
  65. parser.add_argument('--dim_feedforward', default=1024, type=int,
  66. help="Intermediate size of the feedforward layers in the transformer blocks")
  67. parser.add_argument('--hidden_dim', default=256, type=int,
  68. help="Size of the embeddings (dimension of the transformer)")
  69. parser.add_argument('--dropout', default=0.1, type=float,
  70. help="Dropout applied in the transformer")
  71. parser.add_argument('--nheads', default=8, type=int,
  72. help="Number of attention heads inside the transformer's attentions")
  73. parser.add_argument('--num_queries', default=500, type=int,
  74. help="Number of query slots")
  75. parser.add_argument('--dec_n_points', default=4, type=int)
  76. parser.add_argument('--enc_n_points', default=4, type=int)
  77. # * Segmentation
  78. parser.add_argument('--masks', action='store_true',
  79. help="Train segmentation head if the flag is provided")
  80. # Loss
  81. parser.add_argument('--no_aux_loss', dest='aux_loss', action='store_false',
  82. help="Disables auxiliary decoding losses (loss at each layer)")
  83. # * Matcher
  84. parser.add_argument('--set_cost_class', default=2, type=float,
  85. help="Class coefficient in the matching cost")
  86. parser.add_argument('--set_cost_bbox', default=5, type=float,
  87. help="L1 box coefficient in the matching cost")
  88. parser.add_argument('--set_cost_giou', default=2, type=float,
  89. help="giou box coefficient in the matching cost")
  90. # * Loss coefficients
  91. parser.add_argument('--mask_loss_coef', default=1, type=float)
  92. parser.add_argument('--dice_loss_coef', default=1, type=float)
  93. parser.add_argument('--cls_loss_coef', default=2, type=float)
  94. parser.add_argument('--bbox_loss_coef', default=5, type=float)
  95. parser.add_argument('--giou_loss_coef', default=2, type=float)
  96. parser.add_argument('--focal_alpha', default=0.25, type=float)
  97. parser.add_argument('--id_loss_coef', default=1, type=float)
  98. # dataset parameters
  99. parser.add_argument('--dataset_file', default='coco')
  100. parser.add_argument('--coco_path', default='./data/coco', type=str)
  101. parser.add_argument('--coco_panoptic_path', type=str)
  102. parser.add_argument('--remove_difficult', action='store_true')
  103. parser.add_argument('--output_dir', default='',
  104. help='path where to save, empty for no saving')
  105. parser.add_argument('--device', default='cuda',
  106. help='device to use for training / testing')
  107. parser.add_argument('--seed', default=42, type=int)
  108. parser.add_argument('--resume', default='', help='resume from checkpoint')
  109. parser.add_argument('--start_epoch', default=0, type=int, metavar='N',
  110. help='start epoch')
  111. parser.add_argument('--eval', action='store_true')
  112. parser.add_argument('--num_workers', default=2, type=int)
  113. parser.add_argument('--cache_mode', default=False, action='store_true', help='whether to cache images on memory')
  114. # PyTorch checkpointing for saving memory (torch.utils.checkpoint.checkpoint)
  115. parser.add_argument('--checkpoint_enc_ffn', default=False, action='store_true')
  116. parser.add_argument('--checkpoint_dec_ffn', default=False, action='store_true')
  117. # appended for track.
  118. parser.add_argument('--track_train_split', default='train', type=str)
  119. parser.add_argument('--track_eval_split', default='val', type=str)
  120. parser.add_argument('--track_thresh', default=0.4, type=float)
  121. parser.add_argument('--reid_shared', default=False, type=bool)
  122. parser.add_argument('--reid_dim', default=128, type=int)
  123. parser.add_argument('--num_ids', default=360, type=int)
  124. # detector for track.
  125. parser.add_argument('--det_val', default=False, action='store_true')
  126. return parser
  127. def main(args):
  128. utils.init_distributed_mode(args)
  129. print("git:\n {}\n".format(utils.get_sha()))
  130. if args.frozen_weights is not None:
  131. assert args.masks, "Frozen training is meant for segmentation only"
  132. print(args)
  133. device = torch.device(args.device)
  134. # fix the seed for reproducibility
  135. seed = args.seed + utils.get_rank()
  136. torch.manual_seed(seed)
  137. np.random.seed(seed)
  138. random.seed(seed)
  139. if args.det_val:
  140. assert args.eval, 'only support eval mode of detector for track'
  141. model, criterion, postprocessors = build_model(args)
  142. elif args.eval:
  143. model, criterion, postprocessors = build_tracktest_model(args)
  144. else:
  145. model, criterion, postprocessors = build_tracktrain_model(args)
  146. model.to(device)
  147. model_without_ddp = model
  148. n_parameters = sum(p.numel() for p in model.parameters() if p.requires_grad)
  149. print('number of params:', n_parameters)
  150. dataset_train = build_dataset(image_set=args.track_train_split, args=args)
  151. dataset_val = build_dataset(image_set=args.track_eval_split, args=args)
  152. if args.distributed:
  153. if args.cache_mode:
  154. sampler_train = samplers.NodeDistributedSampler(dataset_train)
  155. sampler_val = samplers.NodeDistributedSampler(dataset_val, shuffle=False)
  156. else:
  157. sampler_train = samplers.DistributedSampler(dataset_train)
  158. sampler_val = samplers.DistributedSampler(dataset_val, shuffle=False)
  159. else:
  160. sampler_train = torch.utils.data.RandomSampler(dataset_train)
  161. sampler_val = torch.utils.data.SequentialSampler(dataset_val)
  162. batch_sampler_train = torch.utils.data.BatchSampler(
  163. sampler_train, args.batch_size, drop_last=True)
  164. data_loader_train = DataLoader(dataset_train, batch_sampler=batch_sampler_train,
  165. collate_fn=utils.collate_fn, num_workers=args.num_workers,
  166. pin_memory=True)
  167. data_loader_val = DataLoader(dataset_val, args.batch_size, sampler=sampler_val,
  168. drop_last=False, collate_fn=utils.collate_fn, num_workers=args.num_workers,
  169. pin_memory=True)
  170. # lr_backbone_names = ["backbone.0", "backbone.neck", "input_proj", "transformer.encoder"]
  171. def match_name_keywords(n, name_keywords):
  172. out = False
  173. for b in name_keywords:
  174. if b in n:
  175. out = True
  176. break
  177. return out
  178. for n, p in model_without_ddp.named_parameters():
  179. print(n)
  180. param_dicts = [
  181. {
  182. "params":
  183. [p for n, p in model_without_ddp.named_parameters()
  184. if not match_name_keywords(n, args.lr_backbone_names) and not match_name_keywords(n, args.lr_linear_proj_names) and p.requires_grad],
  185. "lr": args.lr,
  186. },
  187. {
  188. "params": [p for n, p in model_without_ddp.named_parameters() if match_name_keywords(n, args.lr_backbone_names) and p.requires_grad],
  189. "lr": args.lr_backbone,
  190. },
  191. {
  192. "params": [p for n, p in model_without_ddp.named_parameters() if match_name_keywords(n, args.lr_linear_proj_names) and p.requires_grad],
  193. "lr": args.lr * args.lr_linear_proj_mult,
  194. }
  195. ]
  196. if args.sgd:
  197. optimizer = torch.optim.SGD(param_dicts, lr=args.lr, momentum=0.9,
  198. weight_decay=args.weight_decay)
  199. else:
  200. optimizer = torch.optim.AdamW(param_dicts, lr=args.lr,
  201. weight_decay=args.weight_decay)
  202. lr_scheduler = torch.optim.lr_scheduler.StepLR(optimizer, args.lr_drop)
  203. if args.distributed:
  204. model = torch.nn.parallel.DistributedDataParallel(model, device_ids=[args.gpu], find_unused_parameters=True)
  205. model_without_ddp = model.module
  206. if args.dataset_file == "coco_panoptic":
  207. # We also evaluate AP during panoptic training, on original coco DS
  208. coco_val = datasets.coco.build("val", args)
  209. base_ds = get_coco_api_from_dataset(coco_val)
  210. else:
  211. base_ds = get_coco_api_from_dataset(dataset_val)
  212. if args.frozen_weights is not None:
  213. checkpoint = torch.load(args.frozen_weights, map_location='cpu')
  214. model_without_ddp.detr.load_state_dict(checkpoint['model'])
  215. output_dir = Path(args.output_dir)
  216. if args.resume:
  217. if args.resume.startswith('https'):
  218. checkpoint = torch.hub.load_state_dict_from_url(
  219. args.resume, map_location='cpu', check_hash=True)
  220. else:
  221. checkpoint = torch.load(args.resume, map_location='cpu')
  222. missing_keys, unexpected_keys = model_without_ddp.load_state_dict(checkpoint['model'], strict=False)
  223. unexpected_keys = [k for k in unexpected_keys if not (k.endswith('total_params') or k.endswith('total_ops'))]
  224. if len(missing_keys) > 0:
  225. print('Missing Keys: {}'.format(missing_keys))
  226. if len(unexpected_keys) > 0:
  227. print('Unexpected Keys: {}'.format(unexpected_keys))
  228. if not args.eval and 'optimizer' in checkpoint and 'lr_scheduler' in checkpoint and 'epoch' in checkpoint:
  229. import copy
  230. p_groups = copy.deepcopy(optimizer.param_groups)
  231. optimizer.load_state_dict(checkpoint['optimizer'])
  232. for pg, pg_old in zip(optimizer.param_groups, p_groups):
  233. pg['lr'] = pg_old['lr']
  234. pg['initial_lr'] = pg_old['initial_lr']
  235. print(optimizer.param_groups)
  236. lr_scheduler.load_state_dict(checkpoint['lr_scheduler'])
  237. # todo: this is a hack for doing experiment that resume from checkpoint and also modify lr scheduler (e.g., decrease lr in advance).
  238. args.override_resumed_lr_drop = True
  239. if args.override_resumed_lr_drop:
  240. print('Warning: (hack) args.override_resumed_lr_drop is set to True, so args.lr_drop would override lr_drop in resumed lr_scheduler.')
  241. lr_scheduler.step_size = args.lr_drop
  242. lr_scheduler.base_lrs = list(map(lambda group: group['initial_lr'], optimizer.param_groups))
  243. lr_scheduler.step(lr_scheduler.last_epoch)
  244. args.start_epoch = checkpoint['epoch'] + 1
  245. # check the resumed model
  246. # if not args.eval:
  247. # test_stats, coco_evaluator, _ = evaluate(
  248. # model, criterion, postprocessors, data_loader_val, base_ds, device, args.output_dir
  249. # )
  250. if args.eval:
  251. assert args.batch_size == 1, print("Now only support 1.")
  252. # tracker = MOTXTracker(score_thresh=args.track_thresh)
  253. # test_stats, coco_evaluator, res_tracks = evaluate(model, criterion, postprocessors, data_loader_val,
  254. # base_ds, device, args.output_dir, tracker=tracker,
  255. # phase='eval', det_val=args.det_val)
  256. tracker = BYTETracker(args)
  257. test_stats, coco_evaluator, res_tracks = evaluate_track(args, model, criterion, postprocessors, data_loader_val,
  258. base_ds, device, args.output_dir, tracker=tracker,
  259. phase='eval', det_val=args.det_val)
  260. if args.output_dir:
  261. utils.save_on_master(coco_evaluator.coco_eval["bbox"].eval, output_dir / "eval.pth")
  262. if res_tracks is not None:
  263. print("Creating video index for {}.".format(args.dataset_file))
  264. video_to_images = defaultdict(list)
  265. video_names = defaultdict()
  266. for _, info in dataset_val.coco.imgs.items():
  267. video_to_images[info["video_id"]].append({"image_id": info["id"],
  268. "frame_id": info["frame_id"]})
  269. video_name = info["file_name"].split("/")[0]
  270. if video_name not in video_names:
  271. video_names[info["video_id"]] = video_name
  272. assert len(video_to_images) == len(video_names)
  273. # save mot results.
  274. save_track(res_tracks, args.output_dir, video_to_images, video_names, args.track_eval_split)
  275. return
  276. print("Start training")
  277. start_time = time.time()
  278. for epoch in range(args.start_epoch, args.epochs):
  279. if args.distributed:
  280. sampler_train.set_epoch(epoch)
  281. train_stats = train_one_epoch(
  282. model, criterion, data_loader_train, optimizer, device, epoch, args.clip_max_norm)
  283. lr_scheduler.step()
  284. if args.output_dir:
  285. checkpoint_paths = [output_dir / 'checkpoint.pth']
  286. # extra checkpoint before LR drop and every 5 epochs
  287. if (epoch + 1) % args.lr_drop == 0 or (epoch + 1) % 5 == 0:
  288. checkpoint_paths.append(output_dir / f'checkpoint{epoch:04}.pth')
  289. for checkpoint_path in checkpoint_paths:
  290. utils.save_on_master({
  291. 'model': model_without_ddp.state_dict(),
  292. 'optimizer': optimizer.state_dict(),
  293. 'lr_scheduler': lr_scheduler.state_dict(),
  294. 'epoch': epoch,
  295. 'args': args,
  296. }, checkpoint_path)
  297. if epoch % 10 == 0 or epoch > args.epochs - 5:
  298. test_stats, coco_evaluator, _ = evaluate(
  299. model, criterion, postprocessors, data_loader_val, base_ds, device, args.output_dir,
  300. )
  301. log_stats = {**{f'train_{k}': v for k, v in train_stats.items()},
  302. **{f'test_{k}': v for k, v in test_stats.items()},
  303. 'epoch': epoch,
  304. 'n_parameters': n_parameters}
  305. if args.output_dir and utils.is_main_process():
  306. with (output_dir / "log.txt").open("a") as f:
  307. f.write(json.dumps(log_stats) + "\n")
  308. # for evaluation logs
  309. if coco_evaluator is not None:
  310. (output_dir / 'eval').mkdir(exist_ok=True)
  311. if "bbox" in coco_evaluator.coco_eval:
  312. filenames = ['latest.pth']
  313. if epoch % 50 == 0:
  314. filenames.append(f'{epoch:03}.pth')
  315. for name in filenames:
  316. torch.save(coco_evaluator.coco_eval["bbox"].eval,
  317. output_dir / "eval" / name)
  318. total_time = time.time() - start_time
  319. total_time_str = str(datetime.timedelta(seconds=int(total_time)))
  320. print('Training time {}'.format(total_time_str))
  321. if __name__ == '__main__':
  322. parser = argparse.ArgumentParser('Deformable DETR training and evaluation script', parents=[get_args_parser()])
  323. args = parser.parse_args()
  324. if args.output_dir:
  325. Path(args.output_dir).mkdir(parents=True, exist_ok=True)
  326. main(args)