# main_moe.py
# --------------------------------------------------------
# Swin Transformer
# Copyright (c) 2021 Microsoft
# Licensed under The MIT License [see LICENSE for details]
# Written by Ze Liu
# --------------------------------------------------------
import argparse
import datetime
import json
import os
import random
import time
from functools import partial

import numpy as np
import torch
import torch.backends.cudnn as cudnn
import torch.distributed as dist
from timm.loss import LabelSmoothingCrossEntropy, SoftTargetCrossEntropy
from timm.utils import accuracy, AverageMeter
from tutel import system

from config import get_config
from data import build_loader
from logger import create_logger
from lr_scheduler import build_scheduler
from models import build_model
from optimizer import build_optimizer
from utils import NativeScalerWithGradNormCount, reduce_tensor
from utils_moe import load_checkpoint, load_pretrained, save_checkpoint, auto_resume_helper, hook_scale_grad
  29. assert torch.__version__ >= '1.8.0', "DDP-based MoE requires Pytorch >= 1.8.0"
  30. # pytorch major version (1.x or 2.x)
  31. PYTORCH_MAJOR_VERSION = int(torch.__version__.split('.')[0])
  32. def parse_option():
  33. parser = argparse.ArgumentParser('Swin Transformer training and evaluation script', add_help=False)
  34. parser.add_argument('--cfg', type=str, required=True, metavar="FILE", help='path to config file', )
  35. parser.add_argument(
  36. "--opts",
  37. help="Modify config options by adding 'KEY VALUE' pairs. ",
  38. default=None,
  39. nargs='+',
  40. )
  41. # easy config modification
  42. parser.add_argument('--batch-size', type=int, help="batch size for single GPU")
  43. parser.add_argument('--data-path', type=str, help='path to dataset')
  44. parser.add_argument('--zip', action='store_true', help='use zipped dataset instead of folder dataset')
  45. parser.add_argument('--cache-mode', type=str, default='part', choices=['no', 'full', 'part'],
  46. help='no: no cache, '
  47. 'full: cache all data, '
  48. 'part: sharding the dataset into nonoverlapping pieces and only cache one piece')
  49. parser.add_argument('--pretrained',
  50. help='pretrained weight from checkpoint, could be imagenet22k pretrained weight')
  51. parser.add_argument('--resume', help='resume from checkpoint')
  52. parser.add_argument('--accumulation-steps', type=int, help="gradient accumulation steps")
  53. parser.add_argument('--use-checkpoint', action='store_true',
  54. help="whether to use gradient checkpointing to save memory")
  55. parser.add_argument('--disable_amp', action='store_true', help='Disable pytorch amp')
  56. parser.add_argument('--amp-opt-level', type=str, choices=['O0', 'O1', 'O2'],
  57. help='mixed precision opt level, if O0, no amp is used (deprecated!)')
  58. parser.add_argument('--output', default='output', type=str, metavar='PATH',
  59. help='root of output folder, the full path is <output>/<model_name>/<tag> (default: output)')
  60. parser.add_argument('--tag', help='tag of experiment')
  61. parser.add_argument('--eval', action='store_true', help='Perform evaluation only')
  62. parser.add_argument('--throughput', action='store_true', help='Test throughput only')
  63. # distributed training
  64. # for pytorch >= 2.0, use `os.environ['LOCAL_RANK']` instead
  65. # (see https://pytorch.org/docs/stable/distributed.html#launch-utility)
  66. if PYTORCH_MAJOR_VERSION == 1:
  67. parser.add_argument("--local_rank", type=int, required=True, help='local rank for DistributedDataParallel')
  68. args, unparsed = parser.parse_known_args()
  69. config = get_config(args)
  70. return args, config
def main(config):
    """End-to-end driver: build data/model/optimizer, resume if needed, then
    evaluate, benchmark throughput, or run the training loop.

    Args:
        config: frozen yacs-style config node (LR already scaled by the caller).
    """
    dataset_train, dataset_val, data_loader_train, data_loader_val, mixup_fn = build_loader(config)
    logger.info(f"Creating model:{config.MODEL.TYPE}/{config.MODEL.NAME}")
    model = build_model(config)
    logger.info(str(model))
    # For Tutel MoE: parameters flagged with `skip_allreduce` are excluded from
    # DDP's all-reduce; their gradients are instead divided by world_size via a
    # backward hook (see the log message) so scaling matches all-reduced params.
    for name, param in model.named_parameters():
        if param.requires_grad == True and hasattr(param, 'skip_allreduce') and param.skip_allreduce is True:
            model.add_param_to_skip_allreduce(name)
            param.register_hook(partial(hook_scale_grad, dist.get_world_size()))
            logger.info(f"[rank{dist.get_rank()}] [{name}] skip all_reduce and div {dist.get_world_size()} for grad")
    # Parameter counts: expert params are weighted by model.sharded_count
    # ("single" = per-rank view) and additionally by model.global_experts
    # ("whole" = all experts). NOTE(review): semantics of sharded_count /
    # global_experts come from the model implementation — confirm there.
    n_parameters_single = sum(p.numel() * model.sharded_count if hasattr(p, 'skip_allreduce')
                              else p.numel() for p in model.parameters() if p.requires_grad)
    logger.info(f"number of params single: {n_parameters_single}")
    n_parameters_whole = sum(p.numel() * model.sharded_count * model.global_experts if hasattr(p, 'skip_allreduce')
                             else p.numel() for p in model.parameters() if p.requires_grad)
    logger.info(f"number of params whole: {n_parameters_whole}")
    if hasattr(model, 'flops'):
        flops = model.flops()
        logger.info(f"number of GFLOPs: {flops / 1e9}")
    model.cuda(config.LOCAL_RANK)
    # Keep an unwrapped handle for checkpoint save/load (DDP adds a `module.` prefix).
    model_without_ddp = model
    # Optimizer is built before the DDP wrap.
    optimizer = build_optimizer(config, model)
    model = torch.nn.parallel.DistributedDataParallel(model, device_ids=[config.LOCAL_RANK], broadcast_buffers=False)
    loss_scaler = NativeScalerWithGradNormCount()
    # Scheduler steps once per optimizer update, so with gradient accumulation
    # the number of updates per epoch shrinks accordingly.
    if config.TRAIN.ACCUMULATION_STEPS > 1:
        lr_scheduler = build_scheduler(config, optimizer, len(data_loader_train) // config.TRAIN.ACCUMULATION_STEPS)
    else:
        lr_scheduler = build_scheduler(config, optimizer, len(data_loader_train))
    if config.AUG.MIXUP > 0.:
        # smoothing is handled with mixup label transform
        criterion = SoftTargetCrossEntropy()
    elif config.MODEL.LABEL_SMOOTHING > 0.:
        criterion = LabelSmoothingCrossEntropy(smoothing=config.MODEL.LABEL_SMOOTHING)
    else:
        criterion = torch.nn.CrossEntropyLoss()
    max_accuracy = 0.0
    # Auto-resume: prefer the newest checkpoint found in the output dir over
    # any explicitly configured MODEL.RESUME path.
    if config.TRAIN.AUTO_RESUME:
        resume_file = auto_resume_helper(config.OUTPUT, config.TRAIN.MOE.SAVE_MASTER)
        if resume_file:
            if config.MODEL.RESUME:
                logger.warning(f"auto-resume changing resume file from {config.MODEL.RESUME} to {resume_file}")
            config.defrost()
            config.MODEL.RESUME = resume_file
            config.freeze()
            logger.info(f'auto resuming from {resume_file}')
        else:
            logger.info(f'no checkpoint found in {config.OUTPUT}, ignoring auto resume')
    if config.MODEL.RESUME:
        max_accuracy = load_checkpoint(config, model_without_ddp, optimizer, lr_scheduler, loss_scaler, logger)
        acc1, acc5, loss = validate(config, data_loader_val, model)
        logger.info(f"Accuracy of the network on the {len(dataset_val)} test images: {acc1:.1f}%")
        if config.EVAL_MODE:
            return
    # Pretrained weights are only loaded when not resuming a run.
    if config.MODEL.PRETRAINED and (not config.MODEL.RESUME):
        load_pretrained(config, model_without_ddp, logger)
        acc1, acc5, loss = validate(config, data_loader_val, model)
        logger.info(f"Accuracy of the network on the {len(dataset_val)} test images: {acc1:.1f}%")
        if config.EVAL_MODE:
            return
    if config.THROUGHPUT_MODE:
        throughput(data_loader_val, model, logger)
        return
    logger.info("Start training")
    start_time = time.time()
    for epoch in range(config.TRAIN.START_EPOCH, config.TRAIN.EPOCHS):
        # Reseed the distributed sampler so each epoch sees a new shuffle.
        data_loader_train.sampler.set_epoch(epoch)
        train_one_epoch(config, model, criterion, data_loader_train, optimizer, epoch, mixup_fn, lr_scheduler,
                        loss_scaler)
        if (epoch % config.SAVE_FREQ == 0 or epoch == (config.TRAIN.EPOCHS - 1)):
            save_checkpoint(config, epoch, model_without_ddp, max_accuracy, optimizer, lr_scheduler, loss_scaler,
                            logger)
        acc1, acc5, loss = validate(config, data_loader_val, model)
        logger.info(f"Accuracy of the network on the {len(dataset_val)} test images: {acc1:.1f}%")
        max_accuracy = max(max_accuracy, acc1)
        logger.info(f'Max accuracy: {max_accuracy:.2f}%')
    # Final checkpoint; zero_redundancy=True presumably consolidates sharded
    # optimizer state — confirm against utils_moe.save_checkpoint.
    save_checkpoint(config, 'final', model_without_ddp, max_accuracy, optimizer, lr_scheduler, loss_scaler,
                    logger, zero_redundancy=True)
    total_time = time.time() - start_time
    total_time_str = str(datetime.timedelta(seconds=int(total_time)))
    logger.info('Training time {}'.format(total_time_str))
def train_one_epoch(config, model, criterion, data_loader, optimizer, epoch, mixup_fn, lr_scheduler, loss_scaler):
    """Train `model` for one epoch with AMP and optional gradient accumulation.

    Args:
        config: experiment config (AMP flag, accumulation steps, clip-grad, print freq).
        model: DDP-wrapped model; its forward returns (outputs, l_aux), where
            l_aux is an auxiliary loss term added to the classification loss.
        criterion: classification loss.
        data_loader: training data loader.
        optimizer: optimizer, stepped indirectly through `loss_scaler`.
        epoch: current epoch index (used for logging and scheduler updates).
        mixup_fn: optional mixup/cutmix transform applied to (samples, targets).
        lr_scheduler: scheduler stepped once per optimizer update.
        loss_scaler: AMP scaler wrapper; returns the grad norm on update steps,
            None otherwise.
    """
    model.train()
    optimizer.zero_grad()
    num_steps = len(data_loader)
    batch_time = AverageMeter()
    loss_meter = AverageMeter()
    loss_aux_meter = AverageMeter()
    loss_cls_meter = AverageMeter()
    norm_meter = AverageMeter()
    scaler_meter = AverageMeter()
    start = time.time()
    end = time.time()
    for idx, (samples, targets) in enumerate(data_loader):
        samples = samples.cuda(non_blocking=True)
        targets = targets.cuda(non_blocking=True)
        if mixup_fn is not None:
            samples, targets = mixup_fn(samples, targets)
        with torch.cuda.amp.autocast(enabled=config.AMP_ENABLE):
            outputs, l_aux = model(samples)
        l_cls = criterion(outputs, targets)
        # Total loss = classification + auxiliary; divided by the accumulation
        # steps so accumulated gradients average instead of summing.
        loss = l_cls + l_aux
        loss = loss / config.TRAIN.ACCUMULATION_STEPS
        # this attribute is added by timm on one optimizer (adahessian)
        is_second_order = hasattr(optimizer, 'is_second_order') and optimizer.is_second_order
        # Backward + (conditionally) optimizer step; the step only happens when
        # update_grad is True, i.e. on every ACCUMULATION_STEPS-th batch.
        grad_norm = loss_scaler(loss, optimizer, clip_grad=config.TRAIN.CLIP_GRAD,
                                parameters=model.parameters(), create_graph=is_second_order,
                                update_grad=(idx + 1) % config.TRAIN.ACCUMULATION_STEPS == 0)
        if (idx + 1) % config.TRAIN.ACCUMULATION_STEPS == 0:
            optimizer.zero_grad()
            lr_scheduler.step_update((epoch * num_steps + idx) // config.TRAIN.ACCUMULATION_STEPS)
        loss_scale_value = loss_scaler.state_dict()["scale"]
        torch.cuda.synchronize()
        # NOTE(review): this records the accumulation-scaled loss, so logged
        # values are loss/ACCUMULATION_STEPS whenever accumulation > 1.
        loss_meter.update(loss.item(), targets.size(0))
        loss_cls_meter.update(l_cls.item(), targets.size(0))
        # l_aux may be a plain float rather than a tensor.
        loss_aux_meter.update(l_aux if isinstance(l_aux, float) else l_aux.item(), targets.size(0))
        if grad_norm is not None:  # loss_scaler return None if not update
            norm_meter.update(grad_norm)
        scaler_meter.update(loss_scale_value)
        batch_time.update(time.time() - end)
        end = time.time()
        if idx % config.PRINT_FREQ == 0:
            lr = optimizer.param_groups[0]['lr']
            wd = optimizer.param_groups[0]['weight_decay']
            memory_used = torch.cuda.max_memory_allocated() / (1024.0 * 1024.0)
            etas = batch_time.avg * (num_steps - idx)
            logger.info(
                f'Train: [{epoch}/{config.TRAIN.EPOCHS}][{idx}/{num_steps}]\t'
                f'eta {datetime.timedelta(seconds=int(etas))} lr {lr:.6f}\t wd {wd:.4f}\t'
                f'time {batch_time.val:.4f} ({batch_time.avg:.4f})\t'
                f'loss {loss_meter.val:.4f} ({loss_meter.avg:.4f})\t'
                f'loss-cls {loss_cls_meter.val:.4f} ({loss_cls_meter.avg:.4f})\t'
                f'loss-aux {loss_aux_meter.val:.4f} ({loss_aux_meter.avg:.4f})\t'
                f'grad_norm {norm_meter.val:.4f} ({norm_meter.avg:.4f})\t'
                f'loss_scale {scaler_meter.val:.4f} ({scaler_meter.avg:.4f})\t'
                f'mem {memory_used:.0f}MB')
    epoch_time = time.time() - start
    logger.info(f"EPOCH {epoch} training takes {datetime.timedelta(seconds=int(epoch_time))}")
@torch.no_grad()
def validate(config, data_loader, model):
    """Evaluate `model` on `data_loader`.

    Returns:
        (acc1_avg, acc5_avg, loss_cls_avg): top-1/top-5 accuracy averages and
        the average classification loss.

    NOTE(review): acc1/acc5 are passed through reduce_tensor (cross-rank
    reduction from utils) but l_cls is not, so the returned loss appears to be
    rank-local — confirm this is intended.
    """
    criterion = torch.nn.CrossEntropyLoss()
    model.eval()
    batch_time = AverageMeter()
    loss_cls_meter = AverageMeter()
    loss_aux_meter = AverageMeter()
    acc1_meter = AverageMeter()
    acc5_meter = AverageMeter()
    end = time.time()
    for idx, (images, target) in enumerate(data_loader):
        images = images.cuda(non_blocking=True)
        target = target.cuda(non_blocking=True)
        # compute output
        with torch.cuda.amp.autocast(enabled=config.AMP_ENABLE):
            output, l_aux = model(images)
        # measure accuracy and record loss
        l_cls = criterion(output, target)
        acc1, acc5 = accuracy(output, target, topk=(1, 5))
        acc1 = reduce_tensor(acc1)
        acc5 = reduce_tensor(acc5)
        loss_cls_meter.update(l_cls.item(), target.size(0))
        # l_aux may be a plain float rather than a tensor.
        loss_aux_meter.update(l_aux if isinstance(l_aux, float) else l_aux.item(), target.size(0))
        acc1_meter.update(acc1.item(), target.size(0))
        acc5_meter.update(acc5.item(), target.size(0))
        # measure elapsed time
        batch_time.update(time.time() - end)
        end = time.time()
        if idx % config.PRINT_FREQ == 0:
            memory_used = torch.cuda.max_memory_allocated() / (1024.0 * 1024.0)
            logger.info(
                f'Test: [{idx}/{len(data_loader)}]\t'
                f'Time {batch_time.val:.3f} ({batch_time.avg:.3f})\t'
                f'Loss-Cls {loss_cls_meter.val:.4f} ({loss_cls_meter.avg:.4f})\t'
                f'Loss-Aux {loss_aux_meter.val:.4f} ({loss_aux_meter.avg:.4f})\t'
                f'Acc@1 {acc1_meter.val:.3f} ({acc1_meter.avg:.3f})\t'
                f'Acc@5 {acc5_meter.val:.3f} ({acc5_meter.avg:.3f})\t'
                f'Mem {memory_used:.0f}MB')
    logger.info(f' * Acc@1 {acc1_meter.avg:.3f} Acc@5 {acc5_meter.avg:.3f}')
    return acc1_meter.avg, acc5_meter.avg, loss_cls_meter.avg
  249. @torch.no_grad()
  250. def throughput(data_loader, model, logger):
  251. model.eval()
  252. for idx, (images, _) in enumerate(data_loader):
  253. images = images.cuda(non_blocking=True)
  254. batch_size = images.shape[0]
  255. for i in range(50):
  256. model(images)
  257. torch.cuda.synchronize()
  258. logger.info(f"throughput averaged with 30 times")
  259. tic1 = time.time()
  260. for i in range(30):
  261. model(images)
  262. torch.cuda.synchronize()
  263. tic2 = time.time()
  264. logger.info(f"batch_size {batch_size} throughput {30 * batch_size / (tic2 - tic1)}")
  265. return
if __name__ == '__main__':
    args, config = parse_option()

    if config.AMP_OPT_LEVEL:
        print("[warning] Apex amp has been deprecated, please use pytorch amp instead!")

    # Rank/world size come from the launcher's environment (torchrun / launch
    # utility); -1 falls through to init_process_group's env:// resolution.
    if 'RANK' in os.environ and 'WORLD_SIZE' in os.environ:
        rank = int(os.environ["RANK"])
        world_size = int(os.environ['WORLD_SIZE'])
        print(f"RANK and WORLD_SIZE in environ: {rank}/{world_size}")
    else:
        rank = -1
        world_size = -1
    torch.cuda.set_device(config.LOCAL_RANK)
    torch.distributed.init_process_group(backend='nccl', init_method='env://', world_size=world_size, rank=rank)
    torch.distributed.barrier()

    # Per-rank seed so data augmentation differs across ranks while staying
    # reproducible for a fixed config.SEED.
    seed = config.SEED + dist.get_rank()
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)
    np.random.seed(seed)
    random.seed(seed)
    cudnn.benchmark = True

    # linear scale the learning rate according to total batch size, may not be optimal
    linear_scaled_lr = config.TRAIN.BASE_LR * config.DATA.BATCH_SIZE * dist.get_world_size() / 512.0
    linear_scaled_warmup_lr = config.TRAIN.WARMUP_LR * config.DATA.BATCH_SIZE * dist.get_world_size() / 512.0
    linear_scaled_min_lr = config.TRAIN.MIN_LR * config.DATA.BATCH_SIZE * dist.get_world_size() / 512.0
    # gradient accumulation also need to scale the learning rate
    if config.TRAIN.ACCUMULATION_STEPS > 1:
        linear_scaled_lr = linear_scaled_lr * config.TRAIN.ACCUMULATION_STEPS
        linear_scaled_warmup_lr = linear_scaled_warmup_lr * config.TRAIN.ACCUMULATION_STEPS
        linear_scaled_min_lr = linear_scaled_min_lr * config.TRAIN.ACCUMULATION_STEPS
    # Write the scaled LRs back into the (frozen) config.
    config.defrost()
    config.TRAIN.BASE_LR = linear_scaled_lr
    config.TRAIN.WARMUP_LR = linear_scaled_warmup_lr
    config.TRAIN.MIN_LR = linear_scaled_min_lr
    config.freeze()

    os.makedirs(config.OUTPUT, exist_ok=True)
    # Module-level logger used by main/train/validate above.
    logger = create_logger(output_dir=config.OUTPUT, dist_rank=dist.get_rank(), name=f"{config.MODEL.NAME}")

    # Only rank 0 persists the resolved config to disk.
    if dist.get_rank() == 0:
        path = os.path.join(config.OUTPUT, "config.json")
        with open(path, "w") as f:
            f.write(config.dump())
        logger.info(f"Full config saved to {path}")

    # print config
    logger.info(config.dump())
    logger.info(json.dumps(vars(args)))

    main(config)