main_simmim_pt.py

# --------------------------------------------------------
# SimMIM
# Copyright (c) 2021 Microsoft
# Licensed under The MIT License [see LICENSE for details]
# Written by Ze Liu
# Modified by Zhenda Xie
# --------------------------------------------------------

import os
import time
import argparse
import datetime
import numpy as np

import torch
import torch.backends.cudnn as cudnn
import torch.distributed as dist
import torch.cuda.amp as amp
from timm.utils import AverageMeter

from config import get_config
from models import build_model
from data import build_loader
from lr_scheduler import build_scheduler
from optimizer import build_optimizer
from logger import create_logger
from utils_simmim import load_checkpoint, save_checkpoint, get_grad_norm, auto_resume_helper

# pytorch major version (1.x or 2.x)
PYTORCH_MAJOR_VERSION = int(torch.__version__.split('.')[0])
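

# Command-line interface: parses the config file path plus a few frequently
# overridden options (batch size, data path, resume, AMP, output dir, tag) and
# merges them into the experiment config via get_config().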
def parse_option():
    parser = argparse.ArgumentParser('SimMIM pre-training script', add_help=False)
    parser.add_argument('--cfg', type=str, required=True, metavar="FILE", help='path to config file')
    parser.add_argument(
        "--opts",
        help="Modify config options by adding 'KEY VALUE' pairs.",
        default=None,
        nargs='+',
    )

    # easy config modification
    parser.add_argument('--batch-size', type=int, help="batch size for single GPU")
    parser.add_argument('--data-path', type=str, help='path to dataset')
    parser.add_argument('--resume', help='resume from checkpoint')
    parser.add_argument('--accumulation-steps', type=int, help="gradient accumulation steps")
    parser.add_argument('--use-checkpoint', action='store_true',
                        help="whether to use gradient checkpointing to save memory")
    parser.add_argument('--enable-amp', action='store_true')
    parser.add_argument('--disable-amp', action='store_false', dest='enable_amp')
    parser.set_defaults(enable_amp=True)
    parser.add_argument('--output', default='output', type=str, metavar='PATH',
                        help='root of output folder, the full path is <output>/<model_name>/<tag> (default: output)')
    parser.add_argument('--tag', help='tag of experiment')

    # distributed training
    # for pytorch >= 2.0, use `os.environ['LOCAL_RANK']` instead
    # (see https://pytorch.org/docs/stable/distributed.html#launch-utility)
    if PYTORCH_MAJOR_VERSION == 1:
        parser.add_argument("--local_rank", type=int, required=True, help='local rank for DistributedDataParallel')

    args = parser.parse_args()

    config = get_config(args)

    return args, config
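

# Pre-training driver: builds the data loader, SimMIM model, optimizer and LR
# scheduler, wraps the model in DistributedDataParallel, optionally resumes from
# a checkpoint, then runs the epoch loop and saves checkpoints from rank 0.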
def main(config):
    data_loader_train = build_loader(config, simmim=True, is_pretrain=True)

    logger.info(f"Creating model:{config.MODEL.TYPE}/{config.MODEL.NAME}")
    model = build_model(config, is_pretrain=True)
    model.cuda()
    logger.info(str(model))

    optimizer = build_optimizer(config, model, simmim=True, is_pretrain=True)
    model = torch.nn.parallel.DistributedDataParallel(model, device_ids=[config.LOCAL_RANK], broadcast_buffers=False)
    model_without_ddp = model.module

    n_parameters = sum(p.numel() for p in model.parameters() if p.requires_grad)
    logger.info(f"number of params: {n_parameters}")
    if hasattr(model_without_ddp, 'flops'):
        flops = model_without_ddp.flops()
        logger.info(f"number of GFLOPs: {flops / 1e9}")

    lr_scheduler = build_scheduler(config, optimizer, len(data_loader_train))
    scaler = amp.GradScaler()

    if config.TRAIN.AUTO_RESUME:
        resume_file = auto_resume_helper(config.OUTPUT, logger)
        if resume_file:
            if config.MODEL.RESUME:
                logger.warning(f"auto-resume changing resume file from {config.MODEL.RESUME} to {resume_file}")
            config.defrost()
            config.MODEL.RESUME = resume_file
            config.freeze()
            logger.info(f'auto resuming from {resume_file}')
        else:
            logger.info(f'no checkpoint found in {config.OUTPUT}, ignoring auto resume')

    if config.MODEL.RESUME:
        load_checkpoint(config, model_without_ddp, optimizer, lr_scheduler, scaler, logger)

    logger.info("Start training")
    start_time = time.time()
    for epoch in range(config.TRAIN.START_EPOCH, config.TRAIN.EPOCHS):
        data_loader_train.sampler.set_epoch(epoch)

        train_one_epoch(config, model, data_loader_train, optimizer, epoch, lr_scheduler, scaler)
        if dist.get_rank() == 0 and (epoch % config.SAVE_FREQ == 0 or epoch == (config.TRAIN.EPOCHS - 1)):
            save_checkpoint(config, epoch, model_without_ddp, 0., optimizer, lr_scheduler, scaler, logger)

    total_time = time.time() - start_time
    total_time_str = str(datetime.timedelta(seconds=int(total_time)))
    logger.info('Training time {}'.format(total_time_str))
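

# Runs one epoch of masked-image pre-training under torch.cuda.amp autocast, with
# optional gradient accumulation and gradient clipping; tracks loss, gradient
# norm, AMP loss scale, iteration time and GPU memory for periodic logging.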
def train_one_epoch(config, model, data_loader, optimizer, epoch, lr_scheduler, scaler):
    model.train()
    optimizer.zero_grad()

    num_steps = len(data_loader)
    batch_time = AverageMeter()
    loss_meter = AverageMeter()
    norm_meter = AverageMeter()
    loss_scale_meter = AverageMeter()

    start = time.time()
    end = time.time()
    for idx, (img, mask, _) in enumerate(data_loader):
        img = img.cuda(non_blocking=True)
        mask = mask.cuda(non_blocking=True)

        with amp.autocast(enabled=config.ENABLE_AMP):
            loss = model(img, mask)
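
        # With gradient accumulation enabled, the loss is divided by
        # ACCUMULATION_STEPS and the optimizer steps only every ACCUMULATION_STEPS
        # iterations; otherwise the scaled loss is backpropagated and the
        # optimizer steps on every iteration.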
        if config.TRAIN.ACCUMULATION_STEPS > 1:
            loss = loss / config.TRAIN.ACCUMULATION_STEPS
            scaler.scale(loss).backward()
            if config.TRAIN.CLIP_GRAD:
                scaler.unscale_(optimizer)
                grad_norm = torch.nn.utils.clip_grad_norm_(model.parameters(), config.TRAIN.CLIP_GRAD)
            else:
                grad_norm = get_grad_norm(model.parameters())
            if (idx + 1) % config.TRAIN.ACCUMULATION_STEPS == 0:
                scaler.step(optimizer)
                optimizer.zero_grad()
                scaler.update()
                lr_scheduler.step_update(epoch * num_steps + idx)
        else:
            optimizer.zero_grad()
            scaler.scale(loss).backward()
            if config.TRAIN.CLIP_GRAD:
                scaler.unscale_(optimizer)
                grad_norm = torch.nn.utils.clip_grad_norm_(model.parameters(), config.TRAIN.CLIP_GRAD)
            else:
                grad_norm = get_grad_norm(model.parameters())
            scaler.step(optimizer)
            scaler.update()
            lr_scheduler.step_update(epoch * num_steps + idx)

        torch.cuda.synchronize()

        loss_meter.update(loss.item(), img.size(0))
        norm_meter.update(grad_norm)
        loss_scale_meter.update(scaler.get_scale())
        batch_time.update(time.time() - end)
        end = time.time()

        if idx % config.PRINT_FREQ == 0:
            lr = optimizer.param_groups[0]['lr']
            memory_used = torch.cuda.max_memory_allocated() / (1024.0 * 1024.0)
            etas = batch_time.avg * (num_steps - idx)
            logger.info(
                f'Train: [{epoch}/{config.TRAIN.EPOCHS}][{idx}/{num_steps}]\t'
                f'eta {datetime.timedelta(seconds=int(etas))} lr {lr:.6f}\t'
                f'time {batch_time.val:.4f} ({batch_time.avg:.4f})\t'
                f'loss {loss_meter.val:.4f} ({loss_meter.avg:.4f})\t'
                f'grad_norm {norm_meter.val:.4f} ({norm_meter.avg:.4f})\t'
                f'loss_scale {loss_scale_meter.val:.4f} ({loss_scale_meter.avg:.4f})\t'
                f'mem {memory_used:.0f}MB')

    epoch_time = time.time() - start
    logger.info(f"EPOCH {epoch} training takes {datetime.timedelta(seconds=int(epoch_time))}")
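

# Entry point: initializes the NCCL process group from the RANK/WORLD_SIZE
# environment variables, seeds each rank, linearly scales the base/warmup/min
# learning rates with the total batch size (and accumulation steps), creates the
# output directory and logger, dumps the config on rank 0, and starts training.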
if __name__ == '__main__':
    _, config = parse_option()

    if 'RANK' in os.environ and 'WORLD_SIZE' in os.environ:
        rank = int(os.environ["RANK"])
        world_size = int(os.environ['WORLD_SIZE'])
        print(f"RANK and WORLD_SIZE in environ: {rank}/{world_size}")
    else:
        rank = -1
        world_size = -1

    torch.cuda.set_device(config.LOCAL_RANK)
    torch.distributed.init_process_group(backend='nccl', init_method='env://', world_size=world_size, rank=rank)
    torch.distributed.barrier()

    seed = config.SEED + dist.get_rank()
    torch.manual_seed(seed)
    np.random.seed(seed)
    cudnn.benchmark = True

    # linearly scale the learning rate according to total batch size, may not be optimal
    linear_scaled_lr = config.TRAIN.BASE_LR * config.DATA.BATCH_SIZE * dist.get_world_size() / 512.0
    linear_scaled_warmup_lr = config.TRAIN.WARMUP_LR * config.DATA.BATCH_SIZE * dist.get_world_size() / 512.0
    linear_scaled_min_lr = config.TRAIN.MIN_LR * config.DATA.BATCH_SIZE * dist.get_world_size() / 512.0
    # gradient accumulation also needs to scale the learning rate
    if config.TRAIN.ACCUMULATION_STEPS > 1:
        linear_scaled_lr = linear_scaled_lr * config.TRAIN.ACCUMULATION_STEPS
        linear_scaled_warmup_lr = linear_scaled_warmup_lr * config.TRAIN.ACCUMULATION_STEPS
        linear_scaled_min_lr = linear_scaled_min_lr * config.TRAIN.ACCUMULATION_STEPS
    config.defrost()
    config.TRAIN.BASE_LR = linear_scaled_lr
    config.TRAIN.WARMUP_LR = linear_scaled_warmup_lr
    config.TRAIN.MIN_LR = linear_scaled_min_lr
    config.freeze()

    os.makedirs(config.OUTPUT, exist_ok=True)
    logger = create_logger(output_dir=config.OUTPUT, dist_rank=dist.get_rank(), name=f"{config.MODEL.NAME}")

    if dist.get_rank() == 0:
        path = os.path.join(config.OUTPUT, "config.json")
        with open(path, "w") as f:
            f.write(config.dump())
        logger.info(f"Full config saved to {path}")

    # print config
    logger.info(config.dump())

    main(config)