trainer.py

#!/usr/bin/env python3
# -*- coding:utf-8 -*-
# Copyright (c) Megvii, Inc. and its affiliates.

from loguru import logger

import torch
from torch.nn.parallel import DistributedDataParallel as DDP
from torch.utils.tensorboard import SummaryWriter

from yolox.data import DataPrefetcher
from yolox.utils import (
    MeterBuffer,
    ModelEMA,
    all_reduce_norm,
    get_model_info,
    get_rank,
    get_world_size,
    gpu_mem_usage,
    load_ckpt,
    occupy_mem,
    save_checkpoint,
    setup_logger,
    synchronize,
)

import datetime
import os
import time


class Trainer:
    def __init__(self, exp, args):
        # __init__ only defines basic attributes; heavier attributes such as the
        # model and optimizer are built in the before_train method.
        self.exp = exp
        self.args = args

        # training related attr
        self.max_epoch = exp.max_epoch
        self.amp_training = args.fp16
        self.scaler = torch.cuda.amp.GradScaler(enabled=args.fp16)
        self.is_distributed = get_world_size() > 1
        self.rank = get_rank()
        self.local_rank = args.local_rank
        self.device = "cuda:{}".format(self.local_rank)
        self.use_model_ema = exp.ema

        # data/dataloader related attr
        self.data_type = torch.float16 if args.fp16 else torch.float32
        self.input_size = exp.input_size
        self.best_ap = 0

        # metric record
        self.meter = MeterBuffer(window_size=exp.print_interval)
        self.file_name = os.path.join(exp.output_dir, args.experiment_name)

        if self.rank == 0:
            os.makedirs(self.file_name, exist_ok=True)

        setup_logger(
            self.file_name,
            distributed_rank=self.rank,
            filename="train_log.txt",
            mode="a",
        )

    def train(self):
        self.before_train()
        try:
            self.train_in_epoch()
        except Exception:
            raise
        finally:
            self.after_train()

    def train_in_epoch(self):
        for self.epoch in range(self.start_epoch, self.max_epoch):
            self.before_epoch()
            self.train_in_iter()
            self.after_epoch()

    def train_in_iter(self):
        for self.iter in range(self.max_iter):
            self.before_iter()
            self.train_one_iter()
            self.after_iter()

    def train_one_iter(self):
        iter_start_time = time.time()

        inps, targets = self.prefetcher.next()
        # split the track ids (last target column) from the detection targets;
        # only the first five columns are fed to the detection loss
        track_ids = targets[:, :, 5]
        targets = targets[:, :, :5]
        inps = inps.to(self.data_type)
        targets = targets.to(self.data_type)
        targets.requires_grad = False
        data_end_time = time.time()

        with torch.cuda.amp.autocast(enabled=self.amp_training):
            outputs = self.model(inps, targets)
            loss = outputs["total_loss"]

        # standard GradScaler flow: scale the loss, backprop, unscale and step
        # (the step is skipped on inf/NaN gradients), then update the loss scale
        self.optimizer.zero_grad()
        self.scaler.scale(loss).backward()
        self.scaler.step(self.optimizer)
        self.scaler.update()

        if self.use_model_ema:
            self.ema_model.update(self.model)

        lr = self.lr_scheduler.update_lr(self.progress_in_iter + 1)
        for param_group in self.optimizer.param_groups:
            param_group["lr"] = lr

        iter_end_time = time.time()
        self.meter.update(
            iter_time=iter_end_time - iter_start_time,
            data_time=data_end_time - iter_start_time,
            lr=lr,
            **outputs,
        )

    def before_train(self):
        logger.info("args: {}".format(self.args))
        logger.info("exp value:\n{}".format(self.exp))

        # model related init
        torch.cuda.set_device(self.local_rank)
        model = self.exp.get_model()
        logger.info(
            "Model Summary: {}".format(get_model_info(model, self.exp.test_size))
        )
        model.to(self.device)

        # solver related init
        self.optimizer = self.exp.get_optimizer(self.args.batch_size)

        # value of epoch will be set in `resume_train`
        model = self.resume_train(model)

        # data related init
        self.no_aug = self.start_epoch >= self.max_epoch - self.exp.no_aug_epochs
        self.train_loader = self.exp.get_data_loader(
            batch_size=self.args.batch_size,
            is_distributed=self.is_distributed,
            no_aug=self.no_aug,
        )
        logger.info("init prefetcher, this might take one minute or less...")
        self.prefetcher = DataPrefetcher(self.train_loader)
        # max_iter means iters per epoch
        self.max_iter = len(self.train_loader)

        self.lr_scheduler = self.exp.get_lr_scheduler(
            self.exp.basic_lr_per_img * self.args.batch_size, self.max_iter
        )
        if self.args.occupy:
            occupy_mem(self.local_rank)

        if self.is_distributed:
            model = DDP(model, device_ids=[self.local_rank], broadcast_buffers=False)

        if self.use_model_ema:
            self.ema_model = ModelEMA(model, 0.9998)
            self.ema_model.updates = self.max_iter * self.start_epoch

        self.model = model
        self.model.train()

        self.evaluator = self.exp.get_evaluator(
            batch_size=self.args.batch_size, is_distributed=self.is_distributed
        )
        # Tensorboard logger
        if self.rank == 0:
            self.tblogger = SummaryWriter(self.file_name)

        logger.info("Training start...")
        # logger.info("\n{}".format(model))

    def after_train(self):
        logger.info(
            "Training of experiment is done and the best AP is {:.2f}".format(
                self.best_ap * 100
            )
        )

    def before_epoch(self):
        logger.info("---> start train epoch{}".format(self.epoch + 1))

        if self.epoch + 1 == self.max_epoch - self.exp.no_aug_epochs or self.no_aug:
            logger.info("--->No mosaic aug now!")
            self.train_loader.close_mosaic()
            logger.info("--->Add additional L1 loss now!")
            if self.is_distributed:
                self.model.module.head.use_l1 = True
            else:
                self.model.head.use_l1 = True
            self.exp.eval_interval = 1
            if not self.no_aug:
                self.save_ckpt(ckpt_name="last_mosaic_epoch")

    def after_epoch(self):
        if self.use_model_ema:
            self.ema_model.update_attr(self.model)

        self.save_ckpt(ckpt_name="latest")

        if (self.epoch + 1) % self.exp.eval_interval == 0:
            all_reduce_norm(self.model)
            self.evaluate_and_save_model()

    def before_iter(self):
        pass

    def after_iter(self):
        """
        `after_iter` contains two parts of logic:
        * log information
        * reset setting of resize
        """
        # log needed information
        if (self.iter + 1) % self.exp.print_interval == 0:
            # TODO check ETA logic
            left_iters = self.max_iter * self.max_epoch - (self.progress_in_iter + 1)
            eta_seconds = self.meter["iter_time"].global_avg * left_iters
            eta_str = "ETA: {}".format(datetime.timedelta(seconds=int(eta_seconds)))

            progress_str = "epoch: {}/{}, iter: {}/{}".format(
                self.epoch + 1, self.max_epoch, self.iter + 1, self.max_iter
            )
            loss_meter = self.meter.get_filtered_meter("loss")
            loss_str = ", ".join(
                ["{}: {:.3f}".format(k, v.latest) for k, v in loss_meter.items()]
            )

            time_meter = self.meter.get_filtered_meter("time")
            time_str = ", ".join(
                ["{}: {:.3f}s".format(k, v.avg) for k, v in time_meter.items()]
            )

            logger.info(
                "{}, mem: {:.0f}Mb, {}, {}, lr: {:.3e}".format(
                    progress_str,
                    gpu_mem_usage(),
                    time_str,
                    loss_str,
                    self.meter["lr"].latest,
                )
                + (", size: {:d}, {}".format(self.input_size[0], eta_str))
            )
            self.meter.clear_meters()

        # random resizing
        if self.exp.random_size is not None and (self.progress_in_iter + 1) % 10 == 0:
            self.input_size = self.exp.random_resize(
                self.train_loader, self.epoch, self.rank, self.is_distributed
            )

    @property
    def progress_in_iter(self):
        return self.epoch * self.max_iter + self.iter

    def resume_train(self, model):
        if self.args.resume:
            logger.info("resume training")
            if self.args.ckpt is None:
                ckpt_file = os.path.join(self.file_name, "latest" + "_ckpt.pth.tar")
            else:
                ckpt_file = self.args.ckpt

            ckpt = torch.load(ckpt_file, map_location=self.device)
            # resume the model/optimizer state dict
            model.load_state_dict(ckpt["model"])
            self.optimizer.load_state_dict(ckpt["optimizer"])
            start_epoch = (
                self.args.start_epoch - 1
                if self.args.start_epoch is not None
                else ckpt["start_epoch"]
            )
            self.start_epoch = start_epoch
            logger.info(
                "loaded checkpoint '{}' (epoch {})".format(
                    self.args.resume, self.start_epoch
                )
            )  # noqa
        else:
            if self.args.ckpt is not None:
                logger.info("loading checkpoint for fine tuning")
                ckpt_file = self.args.ckpt
                ckpt = torch.load(ckpt_file, map_location=self.device)["model"]
                model = load_ckpt(model, ckpt)
            self.start_epoch = 0

        return model

    def evaluate_and_save_model(self):
        evalmodel = self.ema_model.ema if self.use_model_ema else self.model
        ap50_95, ap50, summary = self.exp.eval(
            evalmodel, self.evaluator, self.is_distributed
        )
        self.model.train()
        if self.rank == 0:
            self.tblogger.add_scalar("val/COCOAP50", ap50, self.epoch + 1)
            self.tblogger.add_scalar("val/COCOAP50_95", ap50_95, self.epoch + 1)
            logger.info("\n" + summary)
        synchronize()

        # self.best_ap = max(self.best_ap, ap50_95)
        self.save_ckpt("last_epoch", ap50 > self.best_ap)
        self.best_ap = max(self.best_ap, ap50)

    def save_ckpt(self, ckpt_name, update_best_ckpt=False):
        if self.rank == 0:
            save_model = self.ema_model.ema if self.use_model_ema else self.model
            logger.info("Save weights to {}".format(self.file_name))
            ckpt_state = {
                "start_epoch": self.epoch + 1,
                "model": save_model.state_dict(),
                "optimizer": self.optimizer.state_dict(),
            }
            save_checkpoint(
                ckpt_state,
                update_best_ckpt,
                self.file_name,
                ckpt_name,
            )
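

# Usage sketch (hypothetical): a Trainer is typically built from an experiment
# description and parsed command-line arguments, then driven by a single call
# to train(). `MyExp` and `make_parser` below are placeholder names for an Exp
# subclass and an argparse parser that supplies the attributes this class
# reads (fp16, local_rank, batch_size, occupy, resume, ckpt, start_epoch,
# experiment_name).
#
#     exp = MyExp()
#     args = make_parser().parse_args()
#     trainer = Trainer(exp, args)
#     trainer.train()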