# main_amp.py — mixed-precision (apex amp) ImageNet training example.
import argparse
import os
import shutil
import time

import torch
import torch.nn as nn
import torch.nn.parallel
import torch.backends.cudnn as cudnn
import torch.distributed as dist
import torch.optim
import torch.utils.data
import torch.utils.data.distributed
import torchvision.transforms as transforms
import torchvision.datasets as datasets
import torchvision.models as models

import numpy as np

# apex supplies the mixed-precision utilities (amp, DDP, fused optimizers)
# this example demonstrates; fail fast with an actionable message otherwise.
# NOTE: `from apex.fp16_utils import *` also provides to_python_float, used below.
try:
    from apex.parallel import DistributedDataParallel as DDP
    from apex.fp16_utils import *
    from apex import amp, optimizers
    from apex.multi_tensor_apply import multi_tensor_applier
except ImportError:
    raise ImportError("Please install apex from https://www.github.com/nvidia/apex to run this example.")
# All lowercase, callable, non-dunder names exported by torchvision.models
# are selectable architectures.
model_names = sorted(name for name in models.__dict__
                     if name.islower() and not name.startswith("__")
                     and callable(models.__dict__[name]))

parser = argparse.ArgumentParser(description='PyTorch ImageNet Training')
parser.add_argument('data', metavar='DIR',
                    help='path to dataset')
parser.add_argument('--arch', '-a', metavar='ARCH', default='resnet18',
                    choices=model_names,
                    help='model architecture: ' +
                    ' | '.join(model_names) +
                    ' (default: resnet18)')
parser.add_argument('-j', '--workers', default=4, type=int, metavar='N',
                    help='number of data loading workers (default: 4)')
parser.add_argument('--epochs', default=90, type=int, metavar='N',
                    help='number of total epochs to run')
parser.add_argument('--start-epoch', default=0, type=int, metavar='N',
                    help='manual epoch number (useful on restarts)')
parser.add_argument('-b', '--batch-size', default=256, type=int,
                    metavar='N', help='mini-batch size per process (default: 256)')
parser.add_argument('--lr', '--learning-rate', default=0.1, type=float,
                    metavar='LR', help='Initial learning rate. Will be scaled by <global batch size>/256: args.lr = args.lr*float(args.batch_size*args.world_size)/256. A warmup schedule will also be applied over the first 5 epochs.')
parser.add_argument('--momentum', default=0.9, type=float, metavar='M',
                    help='momentum')
parser.add_argument('--weight-decay', '--wd', default=1e-4, type=float,
                    metavar='W', help='weight decay (default: 1e-4)')
parser.add_argument('--print-freq', '-p', default=10, type=int,
                    metavar='N', help='print frequency (default: 10)')
parser.add_argument('--resume', default='', type=str, metavar='PATH',
                    help='path to latest checkpoint (default: none)')
parser.add_argument('-e', '--evaluate', dest='evaluate', action='store_true',
                    help='evaluate model on validation set')
parser.add_argument('--pretrained', dest='pretrained', action='store_true',
                    help='use pre-trained model')
parser.add_argument('--prof', dest='prof', action='store_true',
                    help='Only run 10 iterations for profiling.')
parser.add_argument('--deterministic', action='store_true')
# local_rank is injected per-process by the torch.distributed launcher.
parser.add_argument("--local_rank", default=0, type=int)
parser.add_argument('--sync_bn', action='store_true',
                    help='enabling apex sync BN.')
parser.add_argument('--has-ext', action='store_true')
parser.add_argument('--opt-level', type=str)
parser.add_argument('--keep-batchnorm-fp32', type=str, default=None)
parser.add_argument('--loss-scale', type=str, default=None)
parser.add_argument('--fused-adam', action='store_true')
parser.add_argument('--prints-to-process', type=int, default=10)

# Autotune cudnn kernels for the (fixed) input shape; overridden below
# when --deterministic is requested.
cudnn.benchmark = True
  70. def fast_collate(batch):
  71. imgs = [img[0] for img in batch]
  72. targets = torch.tensor([target[1] for target in batch], dtype=torch.int64)
  73. w = imgs[0].size[0]
  74. h = imgs[0].size[1]
  75. tensor = torch.zeros( (len(imgs), 3, h, w), dtype=torch.uint8 )
  76. for i, img in enumerate(imgs):
  77. nump_array = np.asarray(img, dtype=np.uint8)
  78. if(nump_array.ndim < 3):
  79. nump_array = np.expand_dims(nump_array, axis=-1)
  80. nump_array = np.rollaxis(nump_array, 2)
  81. tensor[i] += torch.from_numpy(nump_array)
  82. return tensor, targets
# Best top-1 accuracy seen so far; updated in main() and restored on resume.
best_prec1 = 0
args = parser.parse_args()

# Let multi_tensor_applier be the canary in the coalmine
# that verifies if the backend is what we think it is
assert multi_tensor_applier.available == args.has_ext

print("opt_level = {}".format(args.opt_level))
print("keep_batchnorm_fp32 = {}".format(args.keep_batchnorm_fp32), type(args.keep_batchnorm_fp32))
print("loss_scale = {}".format(args.loss_scale), type(args.loss_scale))

print("\nCUDNN VERSION: {}\n".format(torch.backends.cudnn.version()))

# Deterministic mode: disable cudnn autotuning, force deterministic kernels,
# and seed per rank so each process is reproducible.
if args.deterministic:
    cudnn.benchmark = False
    cudnn.deterministic = True
    torch.manual_seed(args.local_rank)
    torch.set_printoptions(precision=10)
  97. def main():
  98. global best_prec1, args
  99. args.distributed = False
  100. if 'WORLD_SIZE' in os.environ:
  101. args.distributed = int(os.environ['WORLD_SIZE']) > 1
  102. args.gpu = 0
  103. args.world_size = 1
  104. if args.distributed:
  105. args.gpu = args.local_rank % torch.cuda.device_count()
  106. torch.cuda.set_device(args.gpu)
  107. torch.distributed.init_process_group(backend='nccl',
  108. init_method='env://')
  109. args.world_size = torch.distributed.get_world_size()
  110. assert torch.backends.cudnn.enabled, "Amp requires cudnn backend to be enabled."
  111. # create model
  112. if args.pretrained:
  113. print("=> using pre-trained model '{}'".format(args.arch))
  114. model = models.__dict__[args.arch](pretrained=True)
  115. else:
  116. print("=> creating model '{}'".format(args.arch))
  117. model = models.__dict__[args.arch]()
  118. if args.sync_bn:
  119. import apex
  120. print("using apex synced BN")
  121. model = apex.parallel.convert_syncbn_model(model)
  122. model = model.cuda()
  123. # Scale learning rate based on global batch size
  124. args.lr = args.lr*float(args.batch_size*args.world_size)/256.
  125. if args.fused_adam:
  126. optimizer = optimizers.FusedAdam(model.parameters())
  127. else:
  128. optimizer = torch.optim.SGD(model.parameters(), args.lr,
  129. momentum=args.momentum,
  130. weight_decay=args.weight_decay)
  131. model, optimizer = amp.initialize(
  132. model, optimizer,
  133. # enabled=False,
  134. opt_level=args.opt_level,
  135. keep_batchnorm_fp32=args.keep_batchnorm_fp32,
  136. loss_scale=args.loss_scale
  137. )
  138. if args.distributed:
  139. # By default, apex.parallel.DistributedDataParallel overlaps communication with
  140. # computation in the backward pass.
  141. # model = DDP(model)
  142. # delay_allreduce delays all communication to the end of the backward pass.
  143. model = DDP(model, delay_allreduce=True)
  144. # define loss function (criterion) and optimizer
  145. criterion = nn.CrossEntropyLoss().cuda()
  146. # Optionally resume from a checkpoint
  147. if args.resume:
  148. # Use a local scope to avoid dangling references
  149. def resume():
  150. if os.path.isfile(args.resume):
  151. print("=> loading checkpoint '{}'".format(args.resume))
  152. checkpoint = torch.load(args.resume, map_location = lambda storage, loc: storage.cuda(args.gpu))
  153. args.start_epoch = checkpoint['epoch']
  154. best_prec1 = checkpoint['best_prec1']
  155. model.load_state_dict(checkpoint['state_dict'])
  156. optimizer.load_state_dict(checkpoint['optimizer'])
  157. print("=> loaded checkpoint '{}' (epoch {})"
  158. .format(args.resume, checkpoint['epoch']))
  159. else:
  160. print("=> no checkpoint found at '{}'".format(args.resume))
  161. resume()
  162. # Data loading code
  163. traindir = os.path.join(args.data, 'train')
  164. valdir = os.path.join(args.data, 'val')
  165. if(args.arch == "inception_v3"):
  166. crop_size = 299
  167. val_size = 320 # I chose this value arbitrarily, we can adjust.
  168. else:
  169. crop_size = 224
  170. val_size = 256
  171. train_dataset = datasets.ImageFolder(
  172. traindir,
  173. transforms.Compose([
  174. transforms.RandomResizedCrop(crop_size),
  175. transforms.RandomHorizontalFlip(),
  176. # transforms.ToTensor(), Too slow
  177. # normalize,
  178. ]))
  179. val_dataset = datasets.ImageFolder(valdir, transforms.Compose([
  180. transforms.Resize(val_size),
  181. transforms.CenterCrop(crop_size),
  182. ]))
  183. train_sampler = None
  184. val_sampler = None
  185. if args.distributed:
  186. train_sampler = torch.utils.data.distributed.DistributedSampler(train_dataset)
  187. val_sampler = torch.utils.data.distributed.DistributedSampler(val_dataset)
  188. train_loader = torch.utils.data.DataLoader(
  189. train_dataset, batch_size=args.batch_size, shuffle=(train_sampler is None),
  190. num_workers=args.workers, pin_memory=True, sampler=train_sampler, collate_fn=fast_collate)
  191. val_loader = torch.utils.data.DataLoader(
  192. val_dataset,
  193. batch_size=args.batch_size, shuffle=False,
  194. num_workers=args.workers, pin_memory=True,
  195. sampler=val_sampler,
  196. collate_fn=fast_collate)
  197. if args.evaluate:
  198. validate(val_loader, model, criterion)
  199. return
  200. for epoch in range(args.start_epoch, args.epochs):
  201. if args.distributed:
  202. train_sampler.set_epoch(epoch)
  203. # train for one epoch
  204. train(train_loader, model, criterion, optimizer, epoch)
  205. if args.prof:
  206. break
  207. # evaluate on validation set
  208. prec1 = validate(val_loader, model, criterion)
  209. # remember best prec@1 and save checkpoint
  210. if args.local_rank == 0:
  211. is_best = prec1 > best_prec1
  212. best_prec1 = max(prec1, best_prec1)
  213. save_checkpoint({
  214. 'epoch': epoch + 1,
  215. 'arch': args.arch,
  216. 'state_dict': model.state_dict(),
  217. 'best_prec1': best_prec1,
  218. 'optimizer' : optimizer.state_dict(),
  219. }, is_best)
class data_prefetcher():
    """Wraps a DataLoader iterator and stages the next batch onto the GPU
    on a side CUDA stream, overlapping host->device copies with compute.
    Also performs the normalization that the CPU transform pipeline skips."""

    def __init__(self, loader):
        self.loader = iter(loader)
        self.stream = torch.cuda.Stream()
        # ImageNet channel means/stds pre-scaled to uint8 [0, 255] range,
        # shaped for broadcasting over (B, 3, H, W).
        self.mean = torch.tensor([0.485 * 255, 0.456 * 255, 0.406 * 255]).cuda().view(1,3,1,1)
        self.std = torch.tensor([0.229 * 255, 0.224 * 255, 0.225 * 255]).cuda().view(1,3,1,1)
        # With Amp, it isn't necessary to manually convert data to half.
        # if args.fp16:
        #     self.mean = self.mean.half()
        #     self.std = self.std.half()
        self.preload()

    def preload(self):
        """Fetch the next batch and start its GPU upload/normalize on the
        side stream. Sets next_input/next_target to None at end of data."""
        try:
            self.next_input, self.next_target = next(self.loader)
        except StopIteration:
            self.next_input = None
            self.next_target = None
            return
        # All device work below is queued on self.stream so it can overlap
        # with compute on the default stream.
        with torch.cuda.stream(self.stream):
            self.next_input = self.next_input.cuda(non_blocking=True)
            self.next_target = self.next_target.cuda(non_blocking=True)
            # With Amp, it isn't necessary to manually convert data to half.
            # if args.fp16:
            #     self.next_input = self.next_input.half()
            # else:
            self.next_input = self.next_input.float()
            self.next_input = self.next_input.sub_(self.mean).div_(self.std)

    def next(self):
        """Return the staged (input, target) pair and kick off the next
        prefetch. Waits on the side stream so the copy is safe to use."""
        torch.cuda.current_stream().wait_stream(self.stream)
        input = self.next_input
        target = self.next_target
        self.preload()
        return input, target
def train(train_loader, model, criterion, optimizer, epoch):
    """Run one training epoch with amp loss scaling.

    Records Iteration/Loss/Speed and saves them to a file named from the
    amp configuration after args.prints_to_process print intervals, then
    exits the process (this file is a convergence/bitwise regression test,
    not a full training script).
    """
    batch_time = AverageMeter()
    data_time = AverageMeter()
    losses = AverageMeter()
    top1 = AverageMeter()
    top5 = AverageMeter()

    # switch to train mode
    model.train()
    end = time.time()

    run_info_dict = {"Iteration" : [],
                     "Loss" : [],
                     "Speed" : []}

    prefetcher = data_prefetcher(train_loader)
    input, target = prefetcher.next()
    i = -1
    while input is not None:
        i += 1

        # No learning rate warmup for this test, to expose bitwise inaccuracies more quickly
        # adjust_learning_rate(optimizer, epoch, i, len(train_loader))

        if args.prof:
            if i > 10:
                break
        # measure data loading time
        data_time.update(time.time() - end)

        # compute output
        output = model(input)
        loss = criterion(output, target)

        # measure accuracy and record loss
        prec1, prec5 = accuracy(output.data, target, topk=(1, 5))

        if args.distributed:
            # Average metrics across ranks so every process logs global values.
            reduced_loss = reduce_tensor(loss.data)
            prec1 = reduce_tensor(prec1)
            prec5 = reduce_tensor(prec5)
        else:
            reduced_loss = loss.data

        losses.update(to_python_float(reduced_loss), input.size(0))
        top1.update(to_python_float(prec1), input.size(0))
        top5.update(to_python_float(prec5), input.size(0))

        # compute gradient and do SGD step
        optimizer.zero_grad()

        # amp scales the loss to keep fp16 gradients from underflowing,
        # then unscales before optimizer.step().
        with amp.scale_loss(loss, optimizer) as scaled_loss:
            scaled_loss.backward()

        # for param in model.parameters():
        #     print(param.data.double().sum().item(), param.grad.data.double().sum().item())

        # torch.cuda.synchronize()
        torch.cuda.nvtx.range_push("step")
        optimizer.step()
        torch.cuda.nvtx.range_pop()

        torch.cuda.synchronize()
        # measure elapsed time
        batch_time.update(time.time() - end)
        end = time.time()

        # If you decide to refactor this test, like examples/imagenet, to sample the loss every
        # print_freq iterations, make sure to move this prefetching below the accuracy calculation.
        input, target = prefetcher.next()

        if i % args.print_freq == 0 and i > 1:
            if args.local_rank == 0:
                print('Epoch: [{0}][{1}/{2}]\t'
                      'Time {batch_time.val:.3f} ({batch_time.avg:.3f})\t'
                      'Speed {3:.3f} ({4:.3f})\t'
                      'Data {data_time.val:.3f} ({data_time.avg:.3f})\t'
                      'Loss {loss.val:.10f} ({loss.avg:.4f})\t'
                      'Prec@1 {top1.val:.3f} ({top1.avg:.3f})\t'
                      'Prec@5 {top5.val:.3f} ({top5.avg:.3f})'.format(
                      epoch, i, len(train_loader),
                      args.world_size * args.batch_size / batch_time.val,
                      args.world_size * args.batch_size / batch_time.avg,
                      batch_time=batch_time,
                      data_time=data_time, loss=losses, top1=top1, top5=top5))
            run_info_dict["Iteration"].append(i)
            run_info_dict["Loss"].append(losses.val)
            run_info_dict["Speed"].append(args.world_size * args.batch_size / batch_time.val)
            if len(run_info_dict["Loss"]) == args.prints_to_process:
                if args.local_rank == 0:
                    # Filename encodes the amp configuration under test.
                    torch.save(run_info_dict,
                               str(args.has_ext) + "_" + str(args.opt_level) + "_" +
                               str(args.loss_scale) + "_" + str(args.keep_batchnorm_fp32) + "_" +
                               str(args.fused_adam))
                quit()
def validate(val_loader, model, criterion):
    """Evaluate the model on the validation set.

    Returns:
        Average top-1 precision (percent) over the dataset.
    """
    batch_time = AverageMeter()
    losses = AverageMeter()
    top1 = AverageMeter()
    top5 = AverageMeter()

    # switch to evaluate mode
    model.eval()

    end = time.time()

    prefetcher = data_prefetcher(val_loader)
    input, target = prefetcher.next()
    i = -1
    while input is not None:
        i += 1

        # compute output (no autograd graph needed during evaluation)
        with torch.no_grad():
            output = model(input)
            loss = criterion(output, target)

        # measure accuracy and record loss
        prec1, prec5 = accuracy(output.data, target, topk=(1, 5))

        if args.distributed:
            # Average metrics across ranks so the printed values are global.
            reduced_loss = reduce_tensor(loss.data)
            prec1 = reduce_tensor(prec1)
            prec5 = reduce_tensor(prec5)
        else:
            reduced_loss = loss.data

        losses.update(to_python_float(reduced_loss), input.size(0))
        top1.update(to_python_float(prec1), input.size(0))
        top5.update(to_python_float(prec5), input.size(0))

        # measure elapsed time
        batch_time.update(time.time() - end)
        end = time.time()

        if args.local_rank == 0 and i % args.print_freq == 0:
            print('Test: [{0}/{1}]\t'
                  'Time {batch_time.val:.3f} ({batch_time.avg:.3f})\t'
                  'Speed {2:.3f} ({3:.3f})\t'
                  'Loss {loss.val:.4f} ({loss.avg:.4f})\t'
                  'Prec@1 {top1.val:.3f} ({top1.avg:.3f})\t'
                  'Prec@5 {top5.val:.3f} ({top5.avg:.3f})'.format(
                  i, len(val_loader),
                  args.world_size * args.batch_size / batch_time.val,
                  args.world_size * args.batch_size / batch_time.avg,
                  batch_time=batch_time, loss=losses,
                  top1=top1, top5=top5))

        input, target = prefetcher.next()

    print(' * Prec@1 {top1.avg:.3f} Prec@5 {top5.avg:.3f}'
          .format(top1=top1, top5=top5))

    return top1.avg
  379. def save_checkpoint(state, is_best, filename='checkpoint.pth.tar'):
  380. torch.save(state, filename)
  381. if is_best:
  382. shutil.copyfile(filename, 'model_best.pth.tar')
  383. class AverageMeter(object):
  384. """Computes and stores the average and current value"""
  385. def __init__(self):
  386. self.reset()
  387. def reset(self):
  388. self.val = 0
  389. self.avg = 0
  390. self.sum = 0
  391. self.count = 0
  392. def update(self, val, n=1):
  393. self.val = val
  394. self.sum += val * n
  395. self.count += n
  396. self.avg = self.sum / self.count
  397. def adjust_learning_rate(optimizer, epoch, step, len_epoch):
  398. """LR schedule that should yield 76% converged accuracy with batch size 256"""
  399. factor = epoch // 30
  400. if epoch >= 80:
  401. factor = factor + 1
  402. lr = args.lr*(0.1**factor)
  403. """Warmup"""
  404. if epoch < 5:
  405. lr = lr*float(1 + step + epoch*len_epoch)/(5.*len_epoch)
  406. # if(args.local_rank == 0):
  407. # print("epoch = {}, step = {}, lr = {}".format(epoch, step, lr))
  408. for param_group in optimizer.param_groups:
  409. param_group['lr'] = lr
  410. def accuracy(output, target, topk=(1,)):
  411. """Computes the precision@k for the specified values of k"""
  412. maxk = max(topk)
  413. batch_size = target.size(0)
  414. _, pred = output.topk(maxk, 1, True, True)
  415. pred = pred.t()
  416. correct = pred.eq(target.view(1, -1).expand_as(pred))
  417. res = []
  418. for k in topk:
  419. correct_k = correct[:k].view(-1).float().sum(0, keepdim=True)
  420. res.append(correct_k.mul_(100.0 / batch_size))
  421. return res
  422. def reduce_tensor(tensor):
  423. rt = tensor.clone()
  424. dist.all_reduce(rt, op=dist.reduce_op.SUM)
  425. rt /= args.world_size
  426. return rt
# Standard script entry guard: run training only when executed directly.
if __name__ == '__main__':
    main()