# --------------------------------------------------------
# Swin Transformer
# Copyright (c) 2021 Microsoft
# Licensed under The MIT License [see LICENSE for details]
# Written by Ze Liu
# --------------------------------------------------------

import os
import time
import json
import random
import argparse
import datetime
import numpy as np

import torch
import torch.backends.cudnn as cudnn
import torch.distributed as dist

from timm.loss import LabelSmoothingCrossEntropy, SoftTargetCrossEntropy
from timm.utils import accuracy, AverageMeter

from config import get_config
from models import build_model
from data import build_loader
from lr_scheduler import build_scheduler
from optimizer import build_optimizer
from logger import create_logger
from utils import load_checkpoint, load_pretrained, save_checkpoint, NativeScalerWithGradNormCount, auto_resume_helper, \
    reduce_tensor

# pytorch major version (1.x or 2.x)
PYTORCH_MAJOR_VERSION = int(torch.__version__.split('.')[0])


def parse_option():
    parser = argparse.ArgumentParser('Swin Transformer training and evaluation script', add_help=False)
    parser.add_argument('--cfg', type=str, required=True, metavar="FILE", help='path to config file')
    parser.add_argument(
        "--opts",
        help="Modify config options by adding 'KEY VALUE' pairs.",
        default=None,
        nargs='+',
    )

    # easy config modification
    parser.add_argument('--batch-size', type=int, help="batch size for single GPU")
    parser.add_argument('--data-path', type=str, help='path to dataset')
    parser.add_argument('--zip', action='store_true', help='use zipped dataset instead of folder dataset')
    parser.add_argument('--cache-mode', type=str, default='part', choices=['no', 'full', 'part'],
                        help='no: no cache, '
                             'full: cache all data, '
                             'part: shard the dataset into non-overlapping pieces and only cache one piece')
    parser.add_argument('--pretrained',
                        help='pretrained weight from checkpoint, could be imagenet22k pretrained weight')
    parser.add_argument('--resume', help='resume from checkpoint')
    parser.add_argument('--accumulation-steps', type=int, help="gradient accumulation steps")
    parser.add_argument('--use-checkpoint', action='store_true',
                        help="whether to use gradient checkpointing to save memory")
    parser.add_argument('--disable_amp', action='store_true', help='Disable pytorch amp')
    parser.add_argument('--amp-opt-level', type=str, choices=['O0', 'O1', 'O2'],
                        help='mixed precision opt level, if O0, no amp is used (deprecated!)')
    parser.add_argument('--output', default='output', type=str, metavar='PATH',
                        help='root of output folder, the full path is <output>/<model_name>/<tag> (default: output)')
    parser.add_argument('--tag', help='tag of experiment')
    parser.add_argument('--eval', action='store_true', help='Perform evaluation only')
    parser.add_argument('--throughput', action='store_true', help='Test throughput only')

    # distributed training
    # for pytorch >= 2.0, use `os.environ['LOCAL_RANK']` instead
    # (see https://pytorch.org/docs/stable/distributed.html#launch-utility)
    if PYTORCH_MAJOR_VERSION == 1:
        parser.add_argument("--local_rank", type=int, required=True, help='local rank for DistributedDataParallel')

    # for acceleration
    parser.add_argument('--fused_window_process', action='store_true',
                        help='Fused window shift & window partition, similar for reversed part.')
    parser.add_argument('--fused_layernorm', action='store_true', help='Use fused layernorm.')
    # overwrite optimizer in config (*.yaml) if specified, e.g., fused_adam/fused_lamb
    parser.add_argument('--optim', type=str,
                        help='overwrite optimizer if provided, can be adamw/sgd/fused_adam/fused_lamb.')

    args, unparsed = parser.parse_known_args()

    config = get_config(args)

    return args, config
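
# A typical multi-GPU launch that supplies the arguments parsed above (the config name,
# dataset path, port and GPU count are illustrative, not fixed by this script):
#
#   torchrun --nproc_per_node 8 --master_port 12345 main.py \
#       --cfg configs/swin/swin_tiny_patch4_window7_224.yaml \
#       --data-path /path/to/imagenet --batch-size 128
#
# With PyTorch 1.x, `python -m torch.distributed.launch --nproc_per_node 8 ...` is used
# instead, and it passes the extra --local_rank argument registered above.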


def main(config):
    dataset_train, dataset_val, data_loader_train, data_loader_val, mixup_fn = build_loader(config)

    logger.info(f"Creating model:{config.MODEL.TYPE}/{config.MODEL.NAME}")
    model = build_model(config)
    logger.info(str(model))

    n_parameters = sum(p.numel() for p in model.parameters() if p.requires_grad)
    logger.info(f"number of params: {n_parameters}")
    if hasattr(model, 'flops'):
        flops = model.flops()
        logger.info(f"number of GFLOPs: {flops / 1e9}")

    model.cuda()
    model_without_ddp = model

    optimizer = build_optimizer(config, model)
    model = torch.nn.parallel.DistributedDataParallel(model, device_ids=[config.LOCAL_RANK], broadcast_buffers=False)
    loss_scaler = NativeScalerWithGradNormCount()

    if config.TRAIN.ACCUMULATION_STEPS > 1:
        lr_scheduler = build_scheduler(config, optimizer, len(data_loader_train) // config.TRAIN.ACCUMULATION_STEPS)
    else:
        lr_scheduler = build_scheduler(config, optimizer, len(data_loader_train))

    if config.AUG.MIXUP > 0.:
        # smoothing is handled with mixup label transform
        criterion = SoftTargetCrossEntropy()
    elif config.MODEL.LABEL_SMOOTHING > 0.:
        criterion = LabelSmoothingCrossEntropy(smoothing=config.MODEL.LABEL_SMOOTHING)
    else:
        criterion = torch.nn.CrossEntropyLoss()

    max_accuracy = 0.0

    if config.TRAIN.AUTO_RESUME:
        resume_file = auto_resume_helper(config.OUTPUT)
        if resume_file:
            if config.MODEL.RESUME:
                logger.warning(f"auto-resume changing resume file from {config.MODEL.RESUME} to {resume_file}")
            config.defrost()
            config.MODEL.RESUME = resume_file
            config.freeze()
            logger.info(f'auto resuming from {resume_file}')
        else:
            logger.info(f'no checkpoint found in {config.OUTPUT}, ignoring auto resume')

    if config.MODEL.RESUME:
        max_accuracy = load_checkpoint(config, model_without_ddp, optimizer, lr_scheduler, loss_scaler, logger)
        acc1, acc5, loss = validate(config, data_loader_val, model)
        logger.info(f"Accuracy of the network on the {len(dataset_val)} test images: {acc1:.1f}%")
        if config.EVAL_MODE:
            return

    if config.MODEL.PRETRAINED and (not config.MODEL.RESUME):
        load_pretrained(config, model_without_ddp, logger)
        acc1, acc5, loss = validate(config, data_loader_val, model)
        logger.info(f"Accuracy of the network on the {len(dataset_val)} test images: {acc1:.1f}%")

    if config.THROUGHPUT_MODE:
        throughput(data_loader_val, model, logger)
        return

    logger.info("Start training")
    start_time = time.time()
    for epoch in range(config.TRAIN.START_EPOCH, config.TRAIN.EPOCHS):
        data_loader_train.sampler.set_epoch(epoch)

        train_one_epoch(config, model, criterion, data_loader_train, optimizer, epoch, mixup_fn, lr_scheduler,
                        loss_scaler)
        if dist.get_rank() == 0 and (epoch % config.SAVE_FREQ == 0 or epoch == (config.TRAIN.EPOCHS - 1)):
            save_checkpoint(config, epoch, model_without_ddp, max_accuracy, optimizer, lr_scheduler, loss_scaler,
                            logger)

        acc1, acc5, loss = validate(config, data_loader_val, model)
        logger.info(f"Accuracy of the network on the {len(dataset_val)} test images: {acc1:.1f}%")
        max_accuracy = max(max_accuracy, acc1)
        logger.info(f'Max accuracy: {max_accuracy:.2f}%')

    total_time = time.time() - start_time
    total_time_str = str(datetime.timedelta(seconds=int(total_time)))
    logger.info('Training time {}'.format(total_time_str))
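
# NativeScalerWithGradNormCount (imported from utils) is used in train_one_epoch as a
# callable AMP helper that backpropagates, optionally clips gradients, and steps the
# optimizer only on update steps. A minimal sketch of the assumed contract, inferred
# from how it is called below (this is not the actual utils implementation):
#
#     scaler = torch.cuda.amp.GradScaler()
#     def loss_scaler(loss, optimizer, clip_grad, parameters, create_graph, update_grad):
#         scaler.scale(loss).backward(create_graph=create_graph)
#         if not update_grad:
#             return None  # gradients are only accumulated on this step
#         scaler.unscale_(optimizer)
#         norm = torch.nn.utils.clip_grad_norm_(parameters, clip_grad)
#         scaler.step(optimizer)
#         scaler.update()
#         return norm  # surfaced as grad_norm in the training log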


def train_one_epoch(config, model, criterion, data_loader, optimizer, epoch, mixup_fn, lr_scheduler, loss_scaler):
    model.train()
    optimizer.zero_grad()

    num_steps = len(data_loader)
    batch_time = AverageMeter()
    loss_meter = AverageMeter()
    norm_meter = AverageMeter()
    scaler_meter = AverageMeter()

    start = time.time()
    end = time.time()
    for idx, (samples, targets) in enumerate(data_loader):
        samples = samples.cuda(non_blocking=True)
        targets = targets.cuda(non_blocking=True)

        if mixup_fn is not None:
            samples, targets = mixup_fn(samples, targets)

        with torch.cuda.amp.autocast(enabled=config.AMP_ENABLE):
            outputs = model(samples)
        loss = criterion(outputs, targets)
        # normalize the loss so that accumulated gradients match a full-batch update
        loss = loss / config.TRAIN.ACCUMULATION_STEPS

        # this attribute is added by timm for one specific optimizer (adahessian)
        is_second_order = hasattr(optimizer, 'is_second_order') and optimizer.is_second_order
        grad_norm = loss_scaler(loss, optimizer, clip_grad=config.TRAIN.CLIP_GRAD,
                                parameters=model.parameters(), create_graph=is_second_order,
                                update_grad=(idx + 1) % config.TRAIN.ACCUMULATION_STEPS == 0)
        if (idx + 1) % config.TRAIN.ACCUMULATION_STEPS == 0:
            optimizer.zero_grad()
            lr_scheduler.step_update((epoch * num_steps + idx) // config.TRAIN.ACCUMULATION_STEPS)
        loss_scale_value = loss_scaler.state_dict()["scale"]

        torch.cuda.synchronize()

        loss_meter.update(loss.item(), targets.size(0))
        if grad_norm is not None:  # loss_scaler returns None when no optimizer update was performed
            norm_meter.update(grad_norm)
        scaler_meter.update(loss_scale_value)
        batch_time.update(time.time() - end)
        end = time.time()

        if idx % config.PRINT_FREQ == 0:
            lr = optimizer.param_groups[0]['lr']
            wd = optimizer.param_groups[0]['weight_decay']
            memory_used = torch.cuda.max_memory_allocated() / (1024.0 * 1024.0)
            etas = batch_time.avg * (num_steps - idx)
            logger.info(
                f'Train: [{epoch}/{config.TRAIN.EPOCHS}][{idx}/{num_steps}]\t'
                f'eta {datetime.timedelta(seconds=int(etas))} lr {lr:.6f}\t wd {wd:.4f}\t'
                f'time {batch_time.val:.4f} ({batch_time.avg:.4f})\t'
                f'loss {loss_meter.val:.4f} ({loss_meter.avg:.4f})\t'
                f'grad_norm {norm_meter.val:.4f} ({norm_meter.avg:.4f})\t'
                f'loss_scale {scaler_meter.val:.4f} ({scaler_meter.avg:.4f})\t'
                f'mem {memory_used:.0f}MB')
    epoch_time = time.time() - start
    logger.info(f"EPOCH {epoch} training takes {datetime.timedelta(seconds=int(epoch_time))}")


@torch.no_grad()
def validate(config, data_loader, model):
    criterion = torch.nn.CrossEntropyLoss()
    model.eval()

    batch_time = AverageMeter()
    loss_meter = AverageMeter()
    acc1_meter = AverageMeter()
    acc5_meter = AverageMeter()

    end = time.time()
    for idx, (images, target) in enumerate(data_loader):
        images = images.cuda(non_blocking=True)
        target = target.cuda(non_blocking=True)

        # compute output
        with torch.cuda.amp.autocast(enabled=config.AMP_ENABLE):
            output = model(images)

        # measure accuracy and record loss
        loss = criterion(output, target)
        acc1, acc5 = accuracy(output, target, topk=(1, 5))

        acc1 = reduce_tensor(acc1)
        acc5 = reduce_tensor(acc5)
        loss = reduce_tensor(loss)

        loss_meter.update(loss.item(), target.size(0))
        acc1_meter.update(acc1.item(), target.size(0))
        acc5_meter.update(acc5.item(), target.size(0))

        # measure elapsed time
        batch_time.update(time.time() - end)
        end = time.time()

        if idx % config.PRINT_FREQ == 0:
            memory_used = torch.cuda.max_memory_allocated() / (1024.0 * 1024.0)
            logger.info(
                f'Test: [{idx}/{len(data_loader)}]\t'
                f'Time {batch_time.val:.3f} ({batch_time.avg:.3f})\t'
                f'Loss {loss_meter.val:.4f} ({loss_meter.avg:.4f})\t'
                f'Acc@1 {acc1_meter.val:.3f} ({acc1_meter.avg:.3f})\t'
                f'Acc@5 {acc5_meter.val:.3f} ({acc5_meter.avg:.3f})\t'
                f'Mem {memory_used:.0f}MB')
    logger.info(f' * Acc@1 {acc1_meter.avg:.3f} Acc@5 {acc5_meter.avg:.3f}')
    return acc1_meter.avg, acc5_meter.avg, loss_meter.avg
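
# reduce_tensor (imported from utils) is assumed to average a metric tensor across all
# DDP ranks, so the logged loss/Acc@1/Acc@5 reflect the full validation set rather than
# a single rank's shard. A minimal sketch of that behaviour (an assumption, not the
# actual utils code):
#
#     def reduce_tensor(tensor):
#         rt = tensor.clone()
#         dist.all_reduce(rt, op=dist.ReduceOp.SUM)
#         rt /= dist.get_world_size()
#         return rt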


@torch.no_grad()
def throughput(data_loader, model, logger):
    model.eval()

    for idx, (images, _) in enumerate(data_loader):
        images = images.cuda(non_blocking=True)
        batch_size = images.shape[0]
        # warm up before timing
        for i in range(50):
            model(images)
        torch.cuda.synchronize()
        logger.info("throughput averaged over 30 iterations")
        tic1 = time.time()
        for i in range(30):
            model(images)
        torch.cuda.synchronize()
        tic2 = time.time()
        logger.info(f"batch_size {batch_size} throughput {30 * batch_size / (tic2 - tic1)}")
        return


if __name__ == '__main__':
    args, config = parse_option()

    if config.AMP_OPT_LEVEL:
        print("[warning] Apex amp has been deprecated, please use pytorch amp instead!")

    if 'RANK' in os.environ and 'WORLD_SIZE' in os.environ:
        rank = int(os.environ["RANK"])
        world_size = int(os.environ['WORLD_SIZE'])
        print(f"RANK and WORLD_SIZE in environ: {rank}/{world_size}")
    else:
        rank = -1
        world_size = -1

    torch.cuda.set_device(config.LOCAL_RANK)
    torch.distributed.init_process_group(backend='nccl', init_method='env://', world_size=world_size, rank=rank)
    torch.distributed.barrier()
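
    # The env:// rendezvous above relies on environment variables that torchrun /
    # torch.distributed.launch export automatically. When starting processes by hand,
    # they have to be set explicitly, e.g. for a single-process debug run (values are
    # illustrative, not prescribed by this script):
    #
    #   MASTER_ADDR=127.0.0.1 MASTER_PORT=12345 RANK=0 WORLD_SIZE=1 LOCAL_RANK=0 \
    #       python main.py --cfg <config-file> --data-path <dataset-path>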

    seed = config.SEED + dist.get_rank()
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)
    np.random.seed(seed)
    random.seed(seed)
    cudnn.benchmark = True

    # linearly scale the learning rate with the total batch size; this heuristic may not be optimal
    linear_scaled_lr = config.TRAIN.BASE_LR * config.DATA.BATCH_SIZE * dist.get_world_size() / 512.0
    linear_scaled_warmup_lr = config.TRAIN.WARMUP_LR * config.DATA.BATCH_SIZE * dist.get_world_size() / 512.0
    linear_scaled_min_lr = config.TRAIN.MIN_LR * config.DATA.BATCH_SIZE * dist.get_world_size() / 512.0
    # gradient accumulation also needs to scale the learning rate
    if config.TRAIN.ACCUMULATION_STEPS > 1:
        linear_scaled_lr = linear_scaled_lr * config.TRAIN.ACCUMULATION_STEPS
        linear_scaled_warmup_lr = linear_scaled_warmup_lr * config.TRAIN.ACCUMULATION_STEPS
        linear_scaled_min_lr = linear_scaled_min_lr * config.TRAIN.ACCUMULATION_STEPS
    config.defrost()
    config.TRAIN.BASE_LR = linear_scaled_lr
    config.TRAIN.WARMUP_LR = linear_scaled_warmup_lr
    config.TRAIN.MIN_LR = linear_scaled_min_lr
    config.freeze()
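
    # Worked example of the scaling rule above (numbers are illustrative, not defaults
    # guaranteed by this repo): with a base LR of 5e-4, a per-GPU batch size of 128 and
    # 8 GPUs, the effective LR becomes 5e-4 * 128 * 8 / 512 = 1e-3; with
    # ACCUMULATION_STEPS = 2 it is doubled again to 2e-3.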

    os.makedirs(config.OUTPUT, exist_ok=True)
    logger = create_logger(output_dir=config.OUTPUT, dist_rank=dist.get_rank(), name=f"{config.MODEL.NAME}")

    if dist.get_rank() == 0:
        path = os.path.join(config.OUTPUT, "config.json")
        with open(path, "w") as f:
            f.write(config.dump())
        logger.info(f"Full config saved to {path}")

    # print config
    logger.info(config.dump())
    logger.info(json.dumps(vars(args)))

    main(config)