main_amp.py 21 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543
  1. import argparse
  2. import os
  3. import shutil
  4. import time
  5. import torch
  6. import torch.nn as nn
  7. import torch.nn.parallel
  8. import torch.backends.cudnn as cudnn
  9. import torch.distributed as dist
  10. import torch.optim
  11. import torch.utils.data
  12. import torch.utils.data.distributed
  13. import torchvision.transforms as transforms
  14. import torchvision.datasets as datasets
  15. import torchvision.models as models
  16. import numpy as np
  17. try:
  18. from apex.parallel import DistributedDataParallel as DDP
  19. from apex.fp16_utils import *
  20. from apex import amp, optimizers
  21. from apex.multi_tensor_apply import multi_tensor_applier
  22. except ImportError:
  23. raise ImportError("Please install apex from https://www.github.com/nvidia/apex to run this example.")
  def fast_collate(batch, memory_format):
      """Collate ``(image, label)`` samples into a uint8 image batch and int64 labels.

      No ``ToTensor``/normalization is applied here: raw uint8 pixel values are copied
      into a zero-initialized ``(N, 3, H, W)`` tensor laid out with ``memory_format``.
      W and H are read from ``images[0].size[0]`` and ``images[0].size[1]``
      (PIL-style ``(width, height)`` is assumed), so every image in the batch is
      expected to have that size. Single-channel images are broadcast to 3 channels.

      Returns:
          ``(images, labels)``: a uint8 tensor of shape ``(N, 3, H, W)`` and an int64
          tensor of shape ``(N,)``.
      """
      images = [sample[0] for sample in batch]
      labels = torch.tensor([sample[1] for sample in batch], dtype=torch.int64)
      width, height = images[0].size[0], images[0].size[1]

      out = torch.zeros((len(images), 3, height, width), dtype=torch.uint8)
      out = out.contiguous(memory_format=memory_format)

      for slot, image in enumerate(images):
          pixels = np.asarray(image, dtype=np.uint8)
          if pixels.ndim < 3:
              # HxW grayscale -> HxWx1; the in-place add below broadcasts it to 3 channels.
              pixels = pixels[..., np.newaxis]
          # HWC -> CHW, accumulated into the zero-filled slot.
          out[slot] += torch.from_numpy(np.rollaxis(pixels, 2))

      return out, labels
  def _str2bool(value):
      """Parse a command-line boolean such as ``True``, ``false``, ``1`` or ``no``.

      Used instead of ``type=bool``: argparse would call ``bool(s)``, which maps every
      non-empty string (including ``"False"``) to ``True``. The empty string maps to
      ``False``, as ``bool("")`` did.

      Raises:
          argparse.ArgumentTypeError: if ``value`` is not a recognised boolean spelling.
      """
      if isinstance(value, bool):
          return value
      lowered = value.strip().lower()
      if lowered in ('true', 't', 'yes', 'y', '1', 'on'):
          return True
      if lowered in ('false', 'f', 'no', 'n', '0', 'off', ''):
          return False
      raise argparse.ArgumentTypeError(
          "expected a boolean value, got {!r}".format(value))


  def parse():
      """Build the command-line parser for this ImageNet AMP example and parse ``sys.argv``.

      Returns:
          argparse.Namespace holding the parsed options.
      """
      # Every lowercase, non-dunder callable exported by torchvision.models is
      # offered as an ``--arch`` choice.
      model_names = sorted(name for name in models.__dict__
                           if name.islower() and not name.startswith("__")
                           and callable(models.__dict__[name]))

      parser = argparse.ArgumentParser(description='PyTorch ImageNet Training')
      parser.add_argument('data', metavar='DIR',
                          help='path to dataset')
      parser.add_argument('--arch', '-a', metavar='ARCH', default='resnet18',
                          choices=model_names,
                          help='model architecture: ' +
                          ' | '.join(model_names) +
                          ' (default: resnet18)')
      parser.add_argument('-j', '--workers', default=4, type=int, metavar='N',
                          help='number of data loading workers (default: 4)')
      parser.add_argument('--epochs', default=90, type=int, metavar='N',
                          help='number of total epochs to run')
      parser.add_argument('--start-epoch', default=0, type=int, metavar='N',
                          help='manual epoch number (useful on restarts)')
      parser.add_argument('-b', '--batch-size', default=256, type=int,
                          metavar='N', help='mini-batch size per process (default: 256)')
      parser.add_argument('--lr', '--learning-rate', default=0.1, type=float,
                          metavar='LR', help='Initial learning rate.  Will be scaled by <global batch size>/256: args.lr = args.lr*float(args.batch_size*args.world_size)/256.  A warmup schedule will also be applied over the first 5 epochs.')
      parser.add_argument('--momentum', default=0.9, type=float, metavar='M',
                          help='momentum')
      parser.add_argument('--weight-decay', '--wd', default=1e-4, type=float,
                          metavar='W', help='weight decay (default: 1e-4)')
      parser.add_argument('--print-freq', '-p', default=10, type=int,
                          metavar='N', help='print frequency (default: 10)')
      parser.add_argument('--resume', default='', type=str, metavar='PATH',
                          help='path to latest checkpoint (default: none)')
      parser.add_argument('-e', '--evaluate', dest='evaluate', action='store_true',
                          help='evaluate model on validation set')
      parser.add_argument('--pretrained', dest='pretrained', action='store_true',
                          help='use pre-trained model')

      parser.add_argument('--prof', default=-1, type=int,
                          help='Only run 10 iterations for profiling.')
      parser.add_argument('--deterministic', action='store_true')

      # torch.distributed launchers may provide the rank through LOCAL_RANK instead
      # of a --local_rank flag; argparse applies type=int to the string default.
      parser.add_argument("--local_rank", default=os.getenv('LOCAL_RANK', 0), type=int)
      parser.add_argument('--sync_bn', action='store_true',
                          help='enabling apex sync BN.')

      # Amp accepts strings for these overrides, so they are passed through unparsed.
      parser.add_argument('--opt-level', type=str)
      parser.add_argument('--keep-batchnorm-fp32', type=str, default=None)
      parser.add_argument('--loss-scale', type=str, default=None)
      # type=bool would turn "--channels-last False" into True. Parse the value
      # explicitly; a bare "--channels-last" (no value) means True.
      parser.add_argument('--channels-last', type=_str2bool, nargs='?', const=True,
                          default=False,
                          help='use channels_last memory format (optional boolean value)')
      args = parser.parse_args()
      return args
  def main():
      """Script entry point: parse flags, build the AMP model, then train and/or evaluate.

      Assigns the module-level globals ``args`` (the parsed CLI namespace, extended here
      with ``distributed``, ``gpu`` and ``world_size``) and ``best_prec1`` (the best
      validation top-1 so far). ``train``, ``validate`` and ``adjust_learning_rate`` read
      ``args``.

      Checkpoints also store the Apex AMP state under ``'amp'`` so that a resumed run
      keeps its dynamic loss scale. Checkpoints without that key are still accepted.
      """
      global best_prec1, args
      # Function-scope import: a functools.partial collate_fn, unlike a lambda, can be
      # pickled when DataLoader workers are started with the "spawn" method.
      import functools

      args = parse()
      print("opt_level = {}".format(args.opt_level))
      print("keep_batchnorm_fp32 = {}".format(args.keep_batchnorm_fp32), type(args.keep_batchnorm_fp32))
      print("loss_scale = {}".format(args.loss_scale), type(args.loss_scale))

      print("\nCUDNN VERSION: {}\n".format(torch.backends.cudnn.version()))

      cudnn.benchmark = True
      best_prec1 = 0
      if args.deterministic:
          cudnn.benchmark = False
          cudnn.deterministic = True
          torch.manual_seed(args.local_rank)
          torch.set_printoptions(precision=10)

      # Distributed mode is inferred from the WORLD_SIZE environment variable.
      args.distributed = False
      if 'WORLD_SIZE' in os.environ:
          args.distributed = int(os.environ['WORLD_SIZE']) > 1

      args.gpu = 0
      args.world_size = 1

      if args.distributed:
          args.gpu = args.local_rank
          torch.cuda.set_device(args.gpu)
          torch.distributed.init_process_group(backend='nccl',
                                               init_method='env://')
          args.world_size = torch.distributed.get_world_size()

      assert torch.backends.cudnn.enabled, "Amp requires cudnn backend to be enabled."

      if args.channels_last:
          memory_format = torch.channels_last
      else:
          memory_format = torch.contiguous_format

      # create model
      if args.pretrained:
          print("=> using pre-trained model '{}'".format(args.arch))
          model = models.__dict__[args.arch](pretrained=True)
      else:
          print("=> creating model '{}'".format(args.arch))
          model = models.__dict__[args.arch]()

      if args.sync_bn:
          import apex
          print("using apex synced BN")
          model = apex.parallel.convert_syncbn_model(model)

      model = model.cuda().to(memory_format=memory_format)

      # Scale learning rate based on global batch size
      args.lr = args.lr*float(args.batch_size*args.world_size)/256.
      optimizer = torch.optim.SGD(model.parameters(), args.lr,
                                  momentum=args.momentum,
                                  weight_decay=args.weight_decay)

      # Initialize Amp. Amp accepts either values or strings for the optional override
      # arguments, for convenient interoperation with argparse.
      model, optimizer = amp.initialize(model, optimizer,
                                        opt_level=args.opt_level,
                                        keep_batchnorm_fp32=args.keep_batchnorm_fp32,
                                        loss_scale=args.loss_scale
                                        )

      # For distributed training, wrap the model with apex.parallel.DistributedDataParallel.
      # This must be done AFTER amp.initialize: amp.initialize may change the types of
      # the model's parameters in a way that disrupts DDP's allreduce hooks.
      if args.distributed:
          # delay_allreduce delays all communication to the end of the backward pass
          # (the default, DDP(model), overlaps communication with the backward pass).
          model = DDP(model, delay_allreduce=True)

      # define loss function (criterion)
      criterion = nn.CrossEntropyLoss().cuda()

      # Optionally resume from a checkpoint
      if args.resume:
          # Use a local scope to avoid dangling references
          def resume():
              global best_prec1
              if os.path.isfile(args.resume):
                  print("=> loading checkpoint '{}'".format(args.resume))
                  checkpoint = torch.load(args.resume, map_location = lambda storage, loc: storage.cuda(args.gpu))
                  args.start_epoch = checkpoint['epoch']
                  best_prec1 = checkpoint['best_prec1']
                  model.load_state_dict(checkpoint['state_dict'])
                  optimizer.load_state_dict(checkpoint['optimizer'])
                  # Restore the AMP loss-scaler state. This runs after amp.initialize,
                  # and checkpoints written without an 'amp' entry are still accepted.
                  if 'amp' in checkpoint:
                      amp.load_state_dict(checkpoint['amp'])
                  print("=> loaded checkpoint '{}' (epoch {})"
                        .format(args.resume, checkpoint['epoch']))
              else:
                  print("=> no checkpoint found at '{}'".format(args.resume))
          resume()

      # Data loading code
      traindir = os.path.join(args.data, 'train')
      valdir = os.path.join(args.data, 'val')

      if args.arch == "inception_v3":
          # inception_v3 would need crop_size = 299 and a larger val_size.
          raise RuntimeError("Currently, inception_v3 is not supported by this example.")
      crop_size = 224
      val_size = 256

      # ToTensor()/normalization are omitted on purpose: fast_collate keeps uint8
      # pixels and data_prefetcher normalizes them on the GPU.
      train_dataset = datasets.ImageFolder(
          traindir,
          transforms.Compose([
              transforms.RandomResizedCrop(crop_size),
              transforms.RandomHorizontalFlip(),
          ]))
      val_dataset = datasets.ImageFolder(valdir, transforms.Compose([
          transforms.Resize(val_size),
          transforms.CenterCrop(crop_size),
      ]))

      train_sampler = None
      val_sampler = None
      if args.distributed:
          train_sampler = torch.utils.data.distributed.DistributedSampler(train_dataset)
          val_sampler = torch.utils.data.distributed.DistributedSampler(val_dataset)

      collate_fn = functools.partial(fast_collate, memory_format=memory_format)

      train_loader = torch.utils.data.DataLoader(
          train_dataset, batch_size=args.batch_size, shuffle=(train_sampler is None),
          num_workers=args.workers, pin_memory=True, sampler=train_sampler, collate_fn=collate_fn)

      val_loader = torch.utils.data.DataLoader(
          val_dataset,
          batch_size=args.batch_size, shuffle=False,
          num_workers=args.workers, pin_memory=True,
          sampler=val_sampler,
          collate_fn=collate_fn)

      if args.evaluate:
          validate(val_loader, model, criterion)
          return

      for epoch in range(args.start_epoch, args.epochs):
          if args.distributed:
              train_sampler.set_epoch(epoch)

          # train for one epoch
          train(train_loader, model, criterion, optimizer, epoch)

          # evaluate on validation set
          prec1 = validate(val_loader, model, criterion)

          # remember best prec@1 and save checkpoint
          if args.local_rank == 0:
              is_best = prec1 > best_prec1
              best_prec1 = max(prec1, best_prec1)
              save_checkpoint({
                  'epoch': epoch + 1,
                  'arch': args.arch,
                  'state_dict': model.state_dict(),
                  'best_prec1': best_prec1,
                  'optimizer' : optimizer.state_dict(),
                  'amp': amp.state_dict(),
              }, is_best)
  class data_prefetcher():
      """Overlap the host-to-device copy of the next batch with work on the current one.

      Wraps an iterator over ``loader``. Each batch is moved to the GPU, converted to
      float and normalized on a dedicated side stream (``self.stream``). ``next()``
      hands it to the caller's current stream. ``next()`` returns ``(None, None)``
      once the loader is exhausted.
      """

      def __init__(self, loader):
          self.loader = iter(loader)
          self.stream = torch.cuda.Stream()
          # Channel statistics multiplied by 255, presumably because inputs arrive as raw
          # 0-255 pixel values (as produced by fast_collate). The (1, 3, 1, 1) view
          # broadcasts over an NCHW batch.
          self.mean = torch.tensor([m * 255 for m in (0.485, 0.456, 0.406)]).cuda().view(1, 3, 1, 1)
          self.std = torch.tensor([s * 255 for s in (0.229, 0.224, 0.225)]).cuda().view(1, 3, 1, 1)
          # With Amp there is no need to convert inputs or statistics to half manually.
          self.preload()

      def preload(self):
          """Fetch the next batch and enqueue its copy and normalization on ``self.stream``.

          Sets ``next_input``/``next_target`` to ``None`` when the loader is exhausted.
          """
          try:
              batch = next(self.loader)
          except StopIteration:
              self.next_input = self.next_target = None
              return
          self.next_input, self.next_target = batch

          # The device tensors are allocated on the side stream. next() calls
          # record_stream() so the caching allocator knows they are also used on the
          # consumer stream. If record_stream() proves unreliable, the alternative is to
          # allocate device buffers on the main stream, make self.stream wait on it, and
          # copy_ the pinned host tensors into those buffers here.
          with torch.cuda.stream(self.stream):
              self.next_input = self.next_input.cuda(non_blocking=True)
              self.next_target = self.next_target.cuda(non_blocking=True)
              self.next_input = self.next_input.float().sub_(self.mean).div_(self.std)

      def next(self):
          """Return the prefetched ``(input, target)`` pair and begin prefetching the next one."""
          consumer = torch.cuda.current_stream()
          # Work queued on the consumer stream must not start before the side-stream
          # copy and normalization have finished.
          consumer.wait_stream(self.stream)
          batch = (self.next_input, self.next_target)
          for tensor in batch:
              if tensor is not None:
                  tensor.record_stream(consumer)
          self.preload()
          return batch
  def train(train_loader, model, criterion, optimizer, epoch):
      """Run one training epoch of ``model`` over ``train_loader`` with Apex AMP loss scaling.

      Reads the module-level ``args``. Loss and precision are only gathered every
      ``args.print_freq`` iterations, because doing so forces host<->device syncs (and an
      allreduce when ``args.distributed``). When ``args.prof >= 0``, NVTX ranges are
      emitted and the CUDA profiler is started at iteration ``args.prof``. After
      iteration ``args.prof + 10`` the profiler is stopped and the process exits.
      """
      def nvtx_push(label):
          # NVTX ranges are only emitted while profiling is enabled.
          if args.prof >= 0:
              torch.cuda.nvtx.range_push(label)

      def nvtx_pop():
          if args.prof >= 0:
              torch.cuda.nvtx.range_pop()

      batch_time, losses, top1, top5 = [AverageMeter() for _ in range(4)]

      # switch to train mode
      model.train()
      end = time.time()

      prefetcher = data_prefetcher(train_loader)
      images, labels = prefetcher.next()
      step = 0
      while images is not None:
          step += 1
          if args.prof >= 0 and step == args.prof:
              print("Profiling begun at iteration {}".format(step))
              torch.cuda.cudart().cudaProfilerStart()

          nvtx_push("Body of iteration {}".format(step))

          adjust_learning_rate(optimizer, epoch, step, len(train_loader))

          # forward pass
          nvtx_push("forward")
          output = model(images)
          nvtx_pop()
          loss = criterion(output, labels)

          # backward pass with AMP loss scaling, then the SGD step
          optimizer.zero_grad()
          nvtx_push("backward")
          with amp.scale_loss(loss, optimizer) as scaled_loss:
              scaled_loss.backward()
          nvtx_pop()

          nvtx_push("optimizer.step()")
          optimizer.step()
          nvtx_pop()

          if step % args.print_freq == 0:
              prec1, prec5 = accuracy(output.data, labels, topk=(1, 5))

              # Average loss and accuracy across processes for logging
              if args.distributed:
                  reduced_loss, prec1, prec5 = [reduce_tensor(t) for t in (loss.data, prec1, prec5)]
              else:
                  reduced_loss = loss.data

              # to_python_float incurs a host<->device sync
              local_batch = images.size(0)
              for meter, value in ((losses, reduced_loss), (top1, prec1), (top5, prec5)):
                  meter.update(to_python_float(value), local_batch)

              torch.cuda.synchronize()
              batch_time.update((time.time() - end) / args.print_freq)
              end = time.time()

              if args.local_rank == 0:
                  global_batch = args.world_size * args.batch_size
                  print('Epoch: [{0}][{1}/{2}]\t'
                        'Time {batch_time.val:.3f} ({batch_time.avg:.3f})\t'
                        'Speed {3:.3f} ({4:.3f})\t'
                        'Loss {loss.val:.10f} ({loss.avg:.4f})\t'
                        'Prec@1 {top1.val:.3f} ({top1.avg:.3f})\t'
                        'Prec@5 {top5.val:.3f} ({top5.avg:.3f})'.format(
                            epoch, step, len(train_loader),
                            global_batch / batch_time.val,
                            global_batch / batch_time.avg,
                            batch_time=batch_time,
                            loss=losses, top1=top1, top5=top5))

          nvtx_push("prefetcher.next()")
          images, labels = prefetcher.next()
          nvtx_pop()

          # Closes the "Body of iteration" range opened at the top of the loop.
          nvtx_pop()

          if args.prof >= 0 and step == args.prof + 10:
              print("Profiling ended at iteration {}".format(step))
              torch.cuda.cudart().cudaProfilerStop()
              quit()
  def validate(val_loader, model, criterion):
      """Evaluate ``model`` on ``val_loader`` and return the mean top-1 precision (percent).

      Reads the module-level ``args``. When ``args.distributed`` is set, loss and
      precision are averaged across processes with ``reduce_tensor``. Only local
      rank 0 prints, both the periodic progress lines and the final summary, which
      matches ``train``.
      """
      batch_time = AverageMeter()
      losses = AverageMeter()
      top1 = AverageMeter()
      top5 = AverageMeter()

      # switch to evaluate mode
      model.eval()

      end = time.time()

      prefetcher = data_prefetcher(val_loader)
      images, target = prefetcher.next()
      i = 0
      while images is not None:
          i += 1

          # compute output
          with torch.no_grad():
              output = model(images)
              loss = criterion(output, target)

          # measure accuracy and record loss
          prec1, prec5 = accuracy(output.data, target, topk=(1, 5))

          if args.distributed:
              reduced_loss = reduce_tensor(loss.data)
              prec1 = reduce_tensor(prec1)
              prec5 = reduce_tensor(prec5)
          else:
              reduced_loss = loss.data

          losses.update(to_python_float(reduced_loss), images.size(0))
          top1.update(to_python_float(prec1), images.size(0))
          top5.update(to_python_float(prec5), images.size(0))

          # measure elapsed time
          batch_time.update(time.time() - end)
          end = time.time()

          # TODO: Change timings to mirror train().
          if args.local_rank == 0 and i % args.print_freq == 0:
              print('Test: [{0}/{1}]\t'
                    'Time {batch_time.val:.3f} ({batch_time.avg:.3f})\t'
                    'Speed {2:.3f} ({3:.3f})\t'
                    'Loss {loss.val:.4f} ({loss.avg:.4f})\t'
                    'Prec@1 {top1.val:.3f} ({top1.avg:.3f})\t'
                    'Prec@5 {top5.val:.3f} ({top5.avg:.3f})'.format(
                        i, len(val_loader),
                        args.world_size * args.batch_size / batch_time.val,
                        args.world_size * args.batch_size / batch_time.avg,
                        batch_time=batch_time, loss=losses,
                        top1=top1, top5=top5))

          images, target = prefetcher.next()

      # In distributed runs the meters already hold cross-rank averages, so a single
      # summary line from rank 0 is enough (previously every rank printed it).
      if args.local_rank == 0:
          print(' * Prec@1 {top1.avg:.3f} Prec@5 {top5.avg:.3f}'
                .format(top1=top1, top5=top5))

      return top1.avg
  def save_checkpoint(state, is_best, filename='checkpoint.pth.tar', best_filename=None):
      """Save ``state`` to ``filename`` and, if ``is_best``, copy it as the best checkpoint.

      Args:
          state: Object passed to ``torch.save`` (in this script, a dict with epoch,
              arch, model/optimizer state and best_prec1).
          is_best: Whether this checkpoint has the best validation score so far.
          filename: Path of the regular checkpoint file.
          best_filename: Path of the best-model copy. Defaults to ``model_best.pth.tar``
              in the same directory as ``filename``. For the default ``filename`` that
              is the current working directory, as before.
      """
      torch.save(state, filename)
      if is_best:
          if best_filename is None:
              # Keep the best copy next to the checkpoint instead of always in the CWD.
              best_filename = os.path.join(os.path.dirname(filename), 'model_best.pth.tar')
          shutil.copyfile(filename, best_filename)
  class AverageMeter(object):
      """Track the latest value of a metric and its sample-weighted running mean."""

      def __init__(self):
          self.reset()

      def reset(self):
          """Zero the latest value, the running mean, the running sum and the sample count."""
          self.val = self.avg = self.sum = self.count = 0

      def update(self, val, n=1):
          """Record ``val`` as the mean over ``n`` samples and refresh ``avg``."""
          self.val = val
          self.sum += val * n
          self.count += n
          self.avg = self.sum / self.count
  def adjust_learning_rate(optimizer, epoch, step, len_epoch, base_lr=None):
      """Set the learning rate of every param group for the given epoch and step.

      The base LR is multiplied by ``0.1 ** (epoch // 30)``, with one extra factor of
      0.1 from epoch 80 onward. During the first 5 epochs it is also scaled by a linear
      warmup ramp. The original author notes this schedule should yield 76% converged
      accuracy with batch size 256.

      Args:
          optimizer: Optimizer whose ``param_groups[*]['lr']`` values are overwritten.
          epoch: Epoch index (``main`` counts from ``args.start_epoch``, 0 by default).
          step: Iteration index within the epoch (``train`` passes values starting at 1).
          len_epoch: Number of iterations per epoch.
          base_lr: LR before decay and warmup. Defaults to the module-level ``args.lr``,
              which ``main`` has already scaled by the global batch size.
      """
      if base_lr is None:
          base_lr = args.lr

      # Step decay: 10x every 30 epochs, plus one more 10x from epoch 80 onward.
      factor = epoch // 30
      if epoch >= 80:
          factor = factor + 1
      lr = base_lr*(0.1**factor)

      # Warmup: ramp linearly over the first 5 epochs (5 * len_epoch iterations).
      if epoch < 5:
          lr = lr*float(1 + step + epoch*len_epoch)/(5.*len_epoch)

      for param_group in optimizer.param_groups:
          param_group['lr'] = lr
  def accuracy(output, target, topk=(1,)):
      """Compute precision@k, in percent, for each k in ``topk``.

      ``output`` holds per-class scores along dim 1 and ``target`` holds one class
      index per row of ``output``. Returns a list with one 1-element tensor per k, in
      ``topk`` order.
      """
      largest_k = max(topk)
      n_samples = target.size(0)

      # (largest_k, N): row j holds every sample's j-th highest-scoring class index.
      ranked = output.topk(largest_k, 1, True, True)[1].t()
      hits = ranked.eq(target.view(1, -1).expand_as(ranked))

      return [hits[:k].reshape(-1).float().sum(0, keepdim=True).mul_(100.0 / n_samples)
              for k in topk]
  def reduce_tensor(tensor):
      """Return the mean of ``tensor`` across all processes in the default process group.

      A copy is summed with ``all_reduce``, so the input is left unchanged, and the sum
      is divided by the module-level ``args.world_size``. ``all_reduce`` is a
      collective, so every rank must call this.
      """
      rt = tensor.clone()
      # dist.ReduceOp replaces the deprecated dist.reduce_op alias.
      dist.all_reduce(rt, op=dist.ReduceOp.SUM)
      rt /= args.world_size
      return rt
  if __name__ == '__main__':
      # Run training/evaluation only when executed as a script, not when imported.
      main()