main_simmim_ft.py

# --------------------------------------------------------
# SimMIM
# Copyright (c) 2021 Microsoft
# Licensed under The MIT License [see LICENSE for details]
# Written by Ze Liu
# Modified by Zhenda Xie
# --------------------------------------------------------

import os
import time
import argparse
import datetime
import numpy as np

import torch
import torch.backends.cudnn as cudnn
import torch.distributed as dist
import torch.cuda.amp as amp
from timm.loss import LabelSmoothingCrossEntropy, SoftTargetCrossEntropy
from timm.utils import accuracy, AverageMeter

from config import get_config
from models import build_model
from data import build_loader
from lr_scheduler import build_scheduler
from optimizer import build_optimizer
from logger import create_logger
from utils_simmim import load_checkpoint, load_pretrained, save_checkpoint, get_grad_norm, auto_resume_helper, \
    reduce_tensor

# pytorch major version (1.x or 2.x)
PYTORCH_MAJOR_VERSION = int(torch.__version__.split('.')[0])


def parse_option():
    parser = argparse.ArgumentParser('SimMIM fine-tuning script', add_help=False)
    parser.add_argument('--cfg', type=str, required=True, metavar="FILE", help='path to config file', )
    parser.add_argument(
        "--opts",
        help="Modify config options by adding 'KEY VALUE' pairs. ",
        default=None,
        nargs='+',
    )

    # easy config modification
    parser.add_argument('--batch-size', type=int, help="batch size for single GPU")
    parser.add_argument('--data-path', type=str, help='path to dataset')
    parser.add_argument('--pretrained', type=str, help='path to pre-trained model')
    parser.add_argument('--resume', help='resume from checkpoint')
    parser.add_argument('--accumulation-steps', type=int, help="gradient accumulation steps")
    parser.add_argument('--use-checkpoint', action='store_true',
                        help="whether to use gradient checkpointing to save memory")
    parser.add_argument('--enable-amp', action='store_true')
    parser.add_argument('--disable-amp', action='store_false', dest='enable_amp')
    parser.set_defaults(enable_amp=True)
    parser.add_argument('--output', default='output', type=str, metavar='PATH',
                        help='root of output folder, the full path is <output>/<model_name>/<tag> (default: output)')
    parser.add_argument('--tag', help='tag of experiment')
    parser.add_argument('--eval', action='store_true', help='Perform evaluation only')
    parser.add_argument('--throughput', action='store_true', help='Test throughput only')

    # distributed training
    # for pytorch >= 2.0, use `os.environ['LOCAL_RANK']` instead
    # (see https://pytorch.org/docs/stable/distributed.html#launch-utility)
    if PYTORCH_MAJOR_VERSION == 1:
        parser.add_argument("--local_rank", type=int, required=True, help='local rank for DistributedDataParallel')

    args = parser.parse_args()
    config = get_config(args)

    return args, config
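

# Illustrative launch commands (added note, not part of the original script; GPU counts
# and file paths below are placeholders):
#   PyTorch >= 2.0:
#     torchrun --nproc_per_node=8 main_simmim_ft.py --cfg <config.yaml> \
#         --data-path <imagenet_dir> --pretrained <simmim_pretrained.pth>
#   PyTorch 1.x (passes --local_rank to each process):
#     python -m torch.distributed.launch --nproc_per_node=8 main_simmim_ft.py --cfg <config.yaml> \
#         --data-path <imagenet_dir> --pretrained <simmim_pretrained.pth>
# Both launchers set the RANK and WORLD_SIZE environment variables read in the
# `if __name__ == '__main__'` block below; torchrun also sets LOCAL_RANK, while
# torch.distributed.launch supplies --local_rank as an argument.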


def main(config):
    dataset_train, dataset_val, data_loader_train, data_loader_val, mixup_fn = build_loader(config, simmim=True,
                                                                                             is_pretrain=False)

    logger.info(f"Creating model:{config.MODEL.TYPE}/{config.MODEL.NAME}")
    model = build_model(config, is_pretrain=False)
    model.cuda()
    logger.info(str(model))

    optimizer = build_optimizer(config, model, simmim=True, is_pretrain=False)
    model = torch.nn.parallel.DistributedDataParallel(model, device_ids=[config.LOCAL_RANK], broadcast_buffers=False)
    model_without_ddp = model.module

    n_parameters = sum(p.numel() for p in model.parameters() if p.requires_grad)
    logger.info(f"number of params: {n_parameters}")
    if hasattr(model_without_ddp, 'flops'):
        flops = model_without_ddp.flops()
        logger.info(f"number of GFLOPs: {flops / 1e9}")

    lr_scheduler = build_scheduler(config, optimizer, len(data_loader_train))
    scaler = amp.GradScaler()

    if config.AUG.MIXUP > 0.:
        # smoothing is handled with mixup label transform
        criterion = SoftTargetCrossEntropy()
    elif config.MODEL.LABEL_SMOOTHING > 0.:
        criterion = LabelSmoothingCrossEntropy(smoothing=config.MODEL.LABEL_SMOOTHING)
    else:
        criterion = torch.nn.CrossEntropyLoss()

    max_accuracy = 0.0

    if config.TRAIN.AUTO_RESUME:
        resume_file = auto_resume_helper(config.OUTPUT, logger)
        if resume_file:
            if config.MODEL.RESUME:
                logger.warning(f"auto-resume changing resume file from {config.MODEL.RESUME} to {resume_file}")
            config.defrost()
            config.MODEL.RESUME = resume_file
            config.freeze()
            logger.info(f'auto resuming from {resume_file}')
        else:
            logger.info(f'no checkpoint found in {config.OUTPUT}, ignoring auto resume')

    if config.MODEL.RESUME:
        max_accuracy = load_checkpoint(config, model_without_ddp, optimizer, lr_scheduler, scaler, logger)
        acc1, acc5, loss = validate(config, data_loader_val, model)
        logger.info(f"Accuracy of the network on the {len(dataset_val)} test images: {acc1:.1f}%")
        if config.EVAL_MODE:
            return

    if config.MODEL.PRETRAINED and (not config.MODEL.RESUME):
        load_pretrained(config, model_without_ddp, logger)
        acc1, acc5, loss = validate(config, data_loader_val, model)
        logger.info(f"Accuracy of the network on the {len(dataset_val)} test images: {acc1:.1f}%")

    if config.THROUGHPUT_MODE:
        throughput(data_loader_val, model, logger)
        return

    logger.info("Start training")
    start_time = time.time()
    for epoch in range(config.TRAIN.START_EPOCH, config.TRAIN.EPOCHS):
        data_loader_train.sampler.set_epoch(epoch)

        train_one_epoch(config, model, criterion, data_loader_train, optimizer, epoch, mixup_fn, lr_scheduler, scaler)
        if dist.get_rank() == 0 and (epoch % config.SAVE_FREQ == 0 or epoch == (config.TRAIN.EPOCHS - 1)):
            save_checkpoint(config, epoch, model_without_ddp, max_accuracy, optimizer, lr_scheduler, scaler, logger)

        acc1, acc5, loss = validate(config, data_loader_val, model)
        logger.info(f"Accuracy of the network on the {len(dataset_val)} test images: {acc1:.1f}%")
        max_accuracy = max(max_accuracy, acc1)
        logger.info(f'Max accuracy: {max_accuracy:.2f}%')

    total_time = time.time() - start_time
    total_time_str = str(datetime.timedelta(seconds=int(total_time)))
    logger.info('Training time {}'.format(total_time_str))


def train_one_epoch(config, model, criterion, data_loader, optimizer, epoch, mixup_fn, lr_scheduler, scaler):
    model.train()
    optimizer.zero_grad()

    logger.info(f'Current learning rate for different parameter groups: {[it["lr"] for it in optimizer.param_groups]}')

    num_steps = len(data_loader)
    batch_time = AverageMeter()
    loss_meter = AverageMeter()
    norm_meter = AverageMeter()
    loss_scale_meter = AverageMeter()

    start = time.time()
    end = time.time()
    for idx, (samples, targets) in enumerate(data_loader):
        samples = samples.cuda(non_blocking=True)
        targets = targets.cuda(non_blocking=True)

        if mixup_fn is not None:
            samples, targets = mixup_fn(samples, targets)

        outputs = model(samples)

        if config.TRAIN.ACCUMULATION_STEPS > 1:
            loss = criterion(outputs, targets)
            loss = loss / config.TRAIN.ACCUMULATION_STEPS
            scaler.scale(loss).backward()
            if config.TRAIN.CLIP_GRAD:
                scaler.unscale_(optimizer)
                grad_norm = torch.nn.utils.clip_grad_norm_(model.parameters(), config.TRAIN.CLIP_GRAD)
            else:
                grad_norm = get_grad_norm(model.parameters())
            if (idx + 1) % config.TRAIN.ACCUMULATION_STEPS == 0:
                scaler.step(optimizer)
                optimizer.zero_grad()
                scaler.update()
                lr_scheduler.step_update(epoch * num_steps + idx)
        else:
            loss = criterion(outputs, targets)
            optimizer.zero_grad()
            scaler.scale(loss).backward()
            if config.TRAIN.CLIP_GRAD:
                scaler.unscale_(optimizer)
                grad_norm = torch.nn.utils.clip_grad_norm_(model.parameters(), config.TRAIN.CLIP_GRAD)
            else:
                grad_norm = get_grad_norm(model.parameters())
            scaler.step(optimizer)
            scaler.update()
            lr_scheduler.step_update(epoch * num_steps + idx)

        torch.cuda.synchronize()

        loss_meter.update(loss.item(), targets.size(0))
        norm_meter.update(grad_norm)
        loss_scale_meter.update(scaler.get_scale())
        batch_time.update(time.time() - end)
        end = time.time()

        if idx % config.PRINT_FREQ == 0:
            lr = optimizer.param_groups[-1]['lr']
            memory_used = torch.cuda.max_memory_allocated() / (1024.0 * 1024.0)
            etas = batch_time.avg * (num_steps - idx)
            logger.info(
                f'Train: [{epoch}/{config.TRAIN.EPOCHS}][{idx}/{num_steps}]\t'
                f'eta {datetime.timedelta(seconds=int(etas))} lr {lr:.6f}\t'
                f'time {batch_time.val:.4f} ({batch_time.avg:.4f})\t'
                f'loss {loss_meter.val:.4f} ({loss_meter.avg:.4f})\t'
                f'grad_norm {norm_meter.val:.4f} ({norm_meter.avg:.4f})\t'
                f'loss_scale {loss_scale_meter.val:.4f} ({loss_scale_meter.avg:.4f})\t'
                f'mem {memory_used:.0f}MB')
    epoch_time = time.time() - start
    logger.info(f"EPOCH {epoch} training takes {datetime.timedelta(seconds=int(epoch_time))}")
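

# Explanatory note (added comment, not original code): with TRAIN.ACCUMULATION_STEPS > 1 the
# loss above is divided by the number of accumulation steps and gradients accumulate across
# iterations; scaler.step(), optimizer.zero_grad(), scaler.update() and
# lr_scheduler.step_update() run only on every ACCUMULATION_STEPS-th batch. In the
# non-accumulating branch each batch performs a full backward/step/update cycle. In both
# branches scaler.unscale_(optimizer) precedes clip_grad_norm_ so that clipping acts on
# unscaled gradients.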


@torch.no_grad()
def validate(config, data_loader, model):
    criterion = torch.nn.CrossEntropyLoss()
    model.eval()

    batch_time = AverageMeter()
    loss_meter = AverageMeter()
    acc1_meter = AverageMeter()
    acc5_meter = AverageMeter()

    end = time.time()
    for idx, (images, target) in enumerate(data_loader):
        images = images.cuda(non_blocking=True)
        target = target.cuda(non_blocking=True)

        # compute output
        output = model(images)

        # measure accuracy and record loss
        loss = criterion(output, target)
        acc1, acc5 = accuracy(output, target, topk=(1, 5))

        acc1 = reduce_tensor(acc1)
        acc5 = reduce_tensor(acc5)
        loss = reduce_tensor(loss)

        loss_meter.update(loss.item(), target.size(0))
        acc1_meter.update(acc1.item(), target.size(0))
        acc5_meter.update(acc5.item(), target.size(0))

        # measure elapsed time
        batch_time.update(time.time() - end)
        end = time.time()

        if idx % config.PRINT_FREQ == 0:
            memory_used = torch.cuda.max_memory_allocated() / (1024.0 * 1024.0)
            logger.info(
                f'Test: [{idx}/{len(data_loader)}]\t'
                f'Time {batch_time.val:.3f} ({batch_time.avg:.3f})\t'
                f'Loss {loss_meter.val:.4f} ({loss_meter.avg:.4f})\t'
                f'Acc@1 {acc1_meter.val:.3f} ({acc1_meter.avg:.3f})\t'
                f'Acc@5 {acc5_meter.val:.3f} ({acc5_meter.avg:.3f})\t'
                f'Mem {memory_used:.0f}MB')
    logger.info(f' * Acc@1 {acc1_meter.avg:.3f} Acc@5 {acc5_meter.avg:.3f}')
    return acc1_meter.avg, acc5_meter.avg, loss_meter.avg
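

# Note on the throughput helper below (added comment, not original code): it warms the model
# up with 50 forward passes on a single batch, then times 30 more; torch.cuda.synchronize()
# is called before reading the clock because CUDA kernels launch asynchronously.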


@torch.no_grad()
def throughput(data_loader, model, logger):
    model.eval()

    for idx, (images, _) in enumerate(data_loader):
        images = images.cuda(non_blocking=True)
        batch_size = images.shape[0]
        for i in range(50):
            model(images)
        torch.cuda.synchronize()
        logger.info("throughput averaged over 30 runs")
        tic1 = time.time()
        for i in range(30):
            model(images)
        torch.cuda.synchronize()
        tic2 = time.time()
        logger.info(f"batch_size {batch_size} throughput {30 * batch_size / (tic2 - tic1)}")
        return


if __name__ == '__main__':
    _, config = parse_option()

    if 'RANK' in os.environ and 'WORLD_SIZE' in os.environ:
        rank = int(os.environ["RANK"])
        world_size = int(os.environ['WORLD_SIZE'])
        print(f"RANK and WORLD_SIZE in environ: {rank}/{world_size}")
    else:
        rank = -1
        world_size = -1
    torch.cuda.set_device(config.LOCAL_RANK)
    torch.distributed.init_process_group(backend='nccl', init_method='env://', world_size=world_size, rank=rank)
    torch.distributed.barrier()

    seed = config.SEED + dist.get_rank()
    torch.manual_seed(seed)
    np.random.seed(seed)
    cudnn.benchmark = True

    # linearly scale the learning rate according to the total batch size; may not be optimal
    linear_scaled_lr = config.TRAIN.BASE_LR * config.DATA.BATCH_SIZE * dist.get_world_size() / 512.0
    linear_scaled_warmup_lr = config.TRAIN.WARMUP_LR * config.DATA.BATCH_SIZE * dist.get_world_size() / 512.0
    linear_scaled_min_lr = config.TRAIN.MIN_LR * config.DATA.BATCH_SIZE * dist.get_world_size() / 512.0
    # gradient accumulation also needs to scale the learning rate
    if config.TRAIN.ACCUMULATION_STEPS > 1:
        linear_scaled_lr = linear_scaled_lr * config.TRAIN.ACCUMULATION_STEPS
        linear_scaled_warmup_lr = linear_scaled_warmup_lr * config.TRAIN.ACCUMULATION_STEPS
        linear_scaled_min_lr = linear_scaled_min_lr * config.TRAIN.ACCUMULATION_STEPS
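    # Worked example with hypothetical numbers (added comment): for BASE_LR = 5e-4, a per-GPU
    # batch size of 128 and 8 GPUs, the effective batch size is 128 * 8 = 1024, so
    # linear_scaled_lr = 5e-4 * 1024 / 512 = 1e-3; with ACCUMULATION_STEPS = 2 it would be
    # further doubled to 2e-3.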
    config.defrost()
    config.TRAIN.BASE_LR = linear_scaled_lr
    config.TRAIN.WARMUP_LR = linear_scaled_warmup_lr
    config.TRAIN.MIN_LR = linear_scaled_min_lr
    config.freeze()

    os.makedirs(config.OUTPUT, exist_ok=True)
    logger = create_logger(output_dir=config.OUTPUT, dist_rank=dist.get_rank(), name=f"{config.MODEL.NAME}")

    if dist.get_rank() == 0:
        path = os.path.join(config.OUTPUT, "config.json")
        with open(path, "w") as f:
            f.write(config.dump())
        logger.info(f"Full config saved to {path}")

    # print config
    logger.info(config.dump())

    main(config)