# main_test.py

# --------------------------------------------------------
# Swin Transformer
# Copyright (c) 2021 Microsoft
# Licensed under The MIT License [see LICENSE for details]
# Written by Ze Liu
# --------------------------------------------------------
import os
import time
import json
import random
import argparse
import datetime
import numpy as np

import torch
import torch.backends.cudnn as cudnn
import torch.distributed as dist

from timm.loss import LabelSmoothingCrossEntropy, SoftTargetCrossEntropy
from timm.utils import accuracy, AverageMeter

from config import get_config
from models import build_model
from data import build_loader
from lr_scheduler import build_scheduler
from optimizer import build_optimizer
from logger import create_logger
from utils import load_checkpoint, load_pretrained, save_checkpoint, NativeScalerWithGradNormCount, auto_resume_helper, \
    reduce_tensor

# pytorch major version (1.x or 2.x)
PYTORCH_MAJOR_VERSION = int(torch.__version__.split('.')[0])
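
# NOTE: this script follows the upstream Swin Transformer main.py but adds two
# debug hooks for evaluation: main() dumps the validation image file names to
# "v.txt", and validate() collects per-batch predictions that are written to
# "k.txt" (see prediction_lines below).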

# per-batch prediction lines collected in validate() and written to k.txt
prediction_lines = []


def parse_option():
    parser = argparse.ArgumentParser('Swin Transformer training and evaluation script', add_help=False)
    parser.add_argument('--cfg', type=str, required=True, metavar="FILE", help='path to config file', )
    parser.add_argument(
        "--opts",
        help="Modify config options by adding 'KEY VALUE' pairs. ",
        default=None,
        nargs='+',
    )

    # easy config modification
    parser.add_argument('--batch-size', type=int, help="batch size for single GPU")
    parser.add_argument('--data-path', type=str, help='path to dataset')
    parser.add_argument('--zip', action='store_true', help='use zipped dataset instead of folder dataset')
    parser.add_argument('--cache-mode', type=str, default='part', choices=['no', 'full', 'part'],
                        help='no: no cache, '
                             'full: cache all data, '
                             'part: sharding the dataset into nonoverlapping pieces and only cache one piece')
    parser.add_argument('--pretrained',
                        help='pretrained weight from checkpoint, could be imagenet22k pretrained weight')
    parser.add_argument('--resume', help='resume from checkpoint')
    parser.add_argument('--accumulation-steps', type=int, help="gradient accumulation steps")
    parser.add_argument('--use-checkpoint', action='store_true',
                        help="whether to use gradient checkpointing to save memory")
    parser.add_argument('--disable_amp', action='store_true', help='Disable pytorch amp')
    parser.add_argument('--amp-opt-level', type=str, choices=['O0', 'O1', 'O2'],
                        help='mixed precision opt level, if O0, no amp is used (deprecated!)')
    parser.add_argument('--output', default='output', type=str, metavar='PATH',
                        help='root of output folder, the full path is <output>/<model_name>/<tag> (default: output)')
    parser.add_argument('--tag', help='tag of experiment')
    parser.add_argument('--eval', action='store_true', help='Perform evaluation only')
    parser.add_argument('--throughput', action='store_true', help='Test throughput only')

    # distributed training
    # for pytorch >= 2.0, use `os.environ['LOCAL_RANK']` instead
    # (see https://pytorch.org/docs/stable/distributed.html#launch-utility)
    if PYTORCH_MAJOR_VERSION == 1:
        parser.add_argument("--local_rank", type=int, required=True, help='local rank for DistributedDataParallel')

    # for acceleration
    parser.add_argument('--fused_window_process', action='store_true',
                        help='Fused window shift & window partition, similar for reversed part.')
    parser.add_argument('--fused_layernorm', action='store_true', help='Use fused layernorm.')
    ## overwrite optimizer in config (*.yaml) if specified, e.g., fused_adam/fused_lamb
    parser.add_argument('--optim', type=str,
                        help='overwrite optimizer if provided, can be adamw/sgd/fused_adam/fused_lamb.')

    args, unparsed = parser.parse_known_args()

    config = get_config(args)

    return args, config
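
# Example launch (illustrative; adjust the config path, data path and GPU count
# to your setup). For PyTorch >= 2.0 use torchrun, which sets RANK, WORLD_SIZE
# and LOCAL_RANK in the environment as expected below:
#
#   torchrun --nproc_per_node=8 main_test.py \
#       --cfg configs/swin/swin_tiny_patch4_window7_224.yaml \
#       --data-path /path/to/imagenet --batch-size 128
#
# For PyTorch 1.x, python -m torch.distributed.launch passes --local_rank,
# which is why that argument is only registered when PYTORCH_MAJOR_VERSION == 1.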


def main(config):
    dataset_train, dataset_val, data_loader_train, data_loader_val, mixup_fn = build_loader(config)

    # debug hook: dump the validation image file names (one per line) to v.txt
    with open("v.txt", "w") as f:
        f.write('\n'.join([os.path.split(val[0])[-1] for val in dataset_val.imgs]))

    logger.info(f"Creating model:{config.MODEL.TYPE}/{config.MODEL.NAME}")
    model = build_model(config)
    logger.info(str(model))

    n_parameters = sum(p.numel() for p in model.parameters() if p.requires_grad)
    logger.info(f"number of params: {n_parameters}")
    if hasattr(model, 'flops'):
        flops = model.flops()
        logger.info(f"number of GFLOPs: {flops / 1e9}")

    model.cuda()
    model_without_ddp = model

    optimizer = build_optimizer(config, model)
    model = torch.nn.parallel.DistributedDataParallel(model, device_ids=[config.LOCAL_RANK], broadcast_buffers=False)
    loss_scaler = NativeScalerWithGradNormCount()

    if config.TRAIN.ACCUMULATION_STEPS > 1:
        lr_scheduler = build_scheduler(config, optimizer, len(data_loader_train) // config.TRAIN.ACCUMULATION_STEPS)
    else:
        lr_scheduler = build_scheduler(config, optimizer, len(data_loader_train))

    if config.AUG.MIXUP > 0.:
        # smoothing is handled with mixup label transform
        criterion = SoftTargetCrossEntropy()
    elif config.MODEL.LABEL_SMOOTHING > 0.:
        criterion = LabelSmoothingCrossEntropy(smoothing=config.MODEL.LABEL_SMOOTHING)
    else:
        criterion = torch.nn.CrossEntropyLoss()

    max_accuracy = 0.0

    if config.TRAIN.AUTO_RESUME:
        resume_file = auto_resume_helper(config.OUTPUT)
        if resume_file:
            if config.MODEL.RESUME:
                logger.warning(f"auto-resume changing resume file from {config.MODEL.RESUME} to {resume_file}")
            config.defrost()
            config.MODEL.RESUME = resume_file
            config.freeze()
            logger.info(f'auto resuming from {resume_file}')
        else:
            logger.info(f'no checkpoint found in {config.OUTPUT}, ignoring auto resume')

    if config.MODEL.RESUME:
        max_accuracy = load_checkpoint(config, model_without_ddp, optimizer, lr_scheduler, loss_scaler, logger)
        acc1, acc5, loss = validate(config, data_loader_val, model)
        logger.info(f"Accuracy of the network on the {len(dataset_val)} test images: {acc1:.1f}%")
        if config.EVAL_MODE:
            return

    if config.MODEL.PRETRAINED and (not config.MODEL.RESUME):
        load_pretrained(config, model_without_ddp, logger)
        acc1, acc5, loss = validate(config, data_loader_val, model)
        logger.info(f"Accuracy of the network on the {len(dataset_val)} test images: {acc1:.1f}%")

    if config.THROUGHPUT_MODE:
        throughput(data_loader_val, model, logger)
        return

    logger.info("Start training")
    start_time = time.time()
    for epoch in range(config.TRAIN.START_EPOCH, config.TRAIN.EPOCHS):
        data_loader_train.sampler.set_epoch(epoch)

        train_one_epoch(config, model, criterion, data_loader_train, optimizer, epoch, mixup_fn, lr_scheduler,
                        loss_scaler)
        if dist.get_rank() == 0 and (epoch % config.SAVE_FREQ == 0 or epoch == (config.TRAIN.EPOCHS - 1)):
            save_checkpoint(config, epoch, model_without_ddp, max_accuracy, optimizer, lr_scheduler, loss_scaler,
                            logger)

        acc1, acc5, loss = validate(config, data_loader_val, model)
        logger.info(f"Accuracy of the network on the {len(dataset_val)} test images: {acc1:.1f}%")
        max_accuracy = max(max_accuracy, acc1)
        logger.info(f'Max accuracy: {max_accuracy:.2f}%')

    total_time = time.time() - start_time
    total_time_str = str(datetime.timedelta(seconds=int(total_time)))
    logger.info('Training time {}'.format(total_time_str))


def train_one_epoch(config, model, criterion, data_loader, optimizer, epoch, mixup_fn, lr_scheduler, loss_scaler):
    model.train()
    optimizer.zero_grad()

    num_steps = len(data_loader)
    batch_time = AverageMeter()
    loss_meter = AverageMeter()
    norm_meter = AverageMeter()
    scaler_meter = AverageMeter()

    start = time.time()
    end = time.time()
    for idx, (samples, targets) in enumerate(data_loader):
        samples = samples.cuda(non_blocking=True)
        targets = targets.cuda(non_blocking=True)

        if mixup_fn is not None:
            samples, targets = mixup_fn(samples, targets)

        with torch.cuda.amp.autocast(enabled=config.AMP_ENABLE):
            outputs = model(samples)
        loss = criterion(outputs, targets)
        loss = loss / config.TRAIN.ACCUMULATION_STEPS

        # this attribute is added by timm on one optimizer (adahessian)
        is_second_order = hasattr(optimizer, 'is_second_order') and optimizer.is_second_order
        grad_norm = loss_scaler(loss, optimizer, clip_grad=config.TRAIN.CLIP_GRAD,
                                parameters=model.parameters(), create_graph=is_second_order,
                                update_grad=(idx + 1) % config.TRAIN.ACCUMULATION_STEPS == 0)
        if (idx + 1) % config.TRAIN.ACCUMULATION_STEPS == 0:
            optimizer.zero_grad()
            lr_scheduler.step_update((epoch * num_steps + idx) // config.TRAIN.ACCUMULATION_STEPS)
        loss_scale_value = loss_scaler.state_dict()["scale"]

        torch.cuda.synchronize()

        loss_meter.update(loss.item(), targets.size(0))
        if grad_norm is not None:  # loss_scaler return None if not update
            norm_meter.update(grad_norm)
        scaler_meter.update(loss_scale_value)
        batch_time.update(time.time() - end)
        end = time.time()

        if idx % config.PRINT_FREQ == 0:
            lr = optimizer.param_groups[0]['lr']
            wd = optimizer.param_groups[0]['weight_decay']
            memory_used = torch.cuda.max_memory_allocated() / (1024.0 * 1024.0)
            etas = batch_time.avg * (num_steps - idx)
            logger.info(
                f'Train: [{epoch}/{config.TRAIN.EPOCHS}][{idx}/{num_steps}]\t'
                f'eta {datetime.timedelta(seconds=int(etas))} lr {lr:.6f}\t wd {wd:.4f}\t'
                f'time {batch_time.val:.4f} ({batch_time.avg:.4f})\t'
                f'loss {loss_meter.val:.4f} ({loss_meter.avg:.4f})\t'
                f'grad_norm {norm_meter.val:.4f} ({norm_meter.avg:.4f})\t'
                f'loss_scale {scaler_meter.val:.4f} ({scaler_meter.avg:.4f})\t'
                f'mem {memory_used:.0f}MB')
    epoch_time = time.time() - start
    logger.info(f"EPOCH {epoch} training takes {datetime.timedelta(seconds=int(epoch_time))}")
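
# Note on gradient accumulation in train_one_epoch (illustrative numbers, not
# taken from any config): with TRAIN.ACCUMULATION_STEPS = 4 the loss of each
# micro-batch is divided by 4, loss_scaler only performs the optimizer step on
# every 4th micro-batch (update_grad is True when (idx + 1) % 4 == 0), and the
# scheduler advances once per effective step, i.e. (epoch * num_steps + idx) // 4.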


@torch.no_grad()
def validate(config, data_loader, model):
    criterion = torch.nn.CrossEntropyLoss()
    model.eval()

    batch_time = AverageMeter()
    loss_meter = AverageMeter()
    acc1_meter = AverageMeter()
    acc5_meter = AverageMeter()

    end = time.time()
    for idx, (images, target) in enumerate(data_loader):
        images = images.cuda(non_blocking=True)
        target = target.cuda(non_blocking=True)

        # compute output
        with torch.cuda.amp.autocast(enabled=config.AMP_ENABLE):
            output = model(images)

        # debug hook: record the predicted class for this batch; .item() only
        # works when the evaluation batch size is 1
        prediction_lines.append('{} {}'.format(idx, torch.softmax(output, 1, torch.float).argmax(dim=1).item()))

        # measure accuracy and record loss
        loss = criterion(output, target)
        # note: reported below as Acc@5, but computed here with topk=(1, 2)
        acc1, acc5 = accuracy(output, target, topk=(1, 2))

        acc1 = reduce_tensor(acc1)
        acc5 = reduce_tensor(acc5)
        loss = reduce_tensor(loss)

        loss_meter.update(loss.item(), target.size(0))
        acc1_meter.update(acc1.item(), target.size(0))
        acc5_meter.update(acc5.item(), target.size(0))

        # measure elapsed time
        batch_time.update(time.time() - end)
        end = time.time()

        # debug hook: flush the accumulated predictions to k.txt
        # (40655 presumably matches the size of the validation set being used)
        if idx % 40655 == 0:
            with open("k.txt", "w") as f:
                f.write('\n'.join(prediction_lines))

        # log every 1000 batches (hard-coded here in place of config.PRINT_FREQ)
        if idx % 1000 == 0:
            memory_used = torch.cuda.max_memory_allocated() / (1024.0 * 1024.0)
            logger.info(
                f'Test: [{idx}/{len(data_loader)}]\t'
                f'Time {batch_time.val:.3f} ({batch_time.avg:.3f})\t'
                f'Loss {loss_meter.val:.4f} ({loss_meter.avg:.4f})\t'
                f'Acc@1 {acc1_meter.val:.3f} ({acc1_meter.avg:.3f})\t'
                f'Acc@5 {acc5_meter.val:.3f} ({acc5_meter.avg:.3f})\t'
                f'Mem {memory_used:.0f}MB')
    logger.info(f' * Acc@1 {acc1_meter.avg:.3f} Acc@5 {acc5_meter.avg:.3f}')
    return acc1_meter.avg, acc5_meter.avg, loss_meter.avg
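
# Background on the metric reduction in validate(): reduce_tensor (from
# utils.py) is expected to all-reduce a scalar tensor across ranks and divide
# by the world size, so acc1/acc5/loss are averaged over all GPUs before being
# fed into the AverageMeters; with a single process it is effectively a no-op.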


@torch.no_grad()
def throughput(data_loader, model, logger):
    model.eval()

    for idx, (images, _) in enumerate(data_loader):
        images = images.cuda(non_blocking=True)
        batch_size = images.shape[0]
        for i in range(50):
            model(images)
        torch.cuda.synchronize()
        logger.info(f"throughput averaged with 30 times")
        tic1 = time.time()
        for i in range(30):
            model(images)
        torch.cuda.synchronize()
        tic2 = time.time()
        logger.info(f"batch_size {batch_size} throughput {30 * batch_size / (tic2 - tic1)}")
        return
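
# Throughput arithmetic (illustrative numbers): the loop above warms up with 50
# forward passes, then times 30 more on the same batch; e.g. batch_size = 64 and
# an elapsed time of 1.5 s gives 30 * 64 / 1.5 = 1280 images/s. Only the first
# batch of the loader is measured because of the early return.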


if __name__ == '__main__':
    args, config = parse_option()

    if config.AMP_OPT_LEVEL:
        print("[warning] Apex amp has been deprecated, please use pytorch amp instead!")

    if 'RANK' in os.environ and 'WORLD_SIZE' in os.environ:
        rank = int(os.environ["RANK"])
        world_size = int(os.environ['WORLD_SIZE'])
        print(f"RANK and WORLD_SIZE in environ: {rank}/{world_size}")
    else:
        rank = -1
        world_size = -1
    torch.cuda.set_device(config.LOCAL_RANK)
    torch.distributed.init_process_group(backend='nccl', init_method='env://', world_size=world_size, rank=rank)
    torch.distributed.barrier()

    seed = config.SEED + dist.get_rank()
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)
    np.random.seed(seed)
    random.seed(seed)
    cudnn.benchmark = True

    # linearly scale the learning rate according to total batch size, may not be optimal
    linear_scaled_lr = config.TRAIN.BASE_LR * config.DATA.BATCH_SIZE * dist.get_world_size() / 512.0
    linear_scaled_warmup_lr = config.TRAIN.WARMUP_LR * config.DATA.BATCH_SIZE * dist.get_world_size() / 512.0
    linear_scaled_min_lr = config.TRAIN.MIN_LR * config.DATA.BATCH_SIZE * dist.get_world_size() / 512.0
    # gradient accumulation also needs to scale the learning rate
    if config.TRAIN.ACCUMULATION_STEPS > 1:
        linear_scaled_lr = linear_scaled_lr * config.TRAIN.ACCUMULATION_STEPS
        linear_scaled_warmup_lr = linear_scaled_warmup_lr * config.TRAIN.ACCUMULATION_STEPS
        linear_scaled_min_lr = linear_scaled_min_lr * config.TRAIN.ACCUMULATION_STEPS
    config.defrost()
    config.TRAIN.BASE_LR = linear_scaled_lr
    config.TRAIN.WARMUP_LR = linear_scaled_warmup_lr
    config.TRAIN.MIN_LR = linear_scaled_min_lr
    config.freeze()
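
    # Worked example of the linear LR scaling above (illustrative numbers): with
    # BASE_LR = 5e-4, DATA.BATCH_SIZE = 128 per GPU and 8 GPUs, the effective LR
    # is 5e-4 * 128 * 8 / 512 = 1e-3; with ACCUMULATION_STEPS = 2 it doubles to
    # 2e-3, matching the larger effective batch of 128 * 8 * 2 = 2048 images.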

    os.makedirs(config.OUTPUT, exist_ok=True)
    logger = create_logger(output_dir=config.OUTPUT, dist_rank=dist.get_rank(), name=f"{config.MODEL.NAME}")

    if dist.get_rank() == 0:
        path = os.path.join(config.OUTPUT, "config.json")
        with open(path, "w") as f:
            f.write(config.dump())
        logger.info(f"Full config saved to {path}")

    # print config
    logger.info(config.dump())
    logger.info(json.dumps(vars(args)))

    main(config)