123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526 |
- import argparse
- import os
- import shutil
- import time
- import torch
- import torch.nn as nn
- import torch.nn.parallel
- import torch.backends.cudnn as cudnn
- import torch.distributed as dist
- import torch.optim
- import torch.utils.data
- import torch.utils.data.distributed
- import torchvision.transforms as transforms
- import torchvision.datasets as datasets
- import torchvision.models as models
- import numpy as np
- try:
- from apex.parallel import DistributedDataParallel as DDP
- from apex.fp16_utils import *
- from apex import amp, optimizers
- from apex.multi_tensor_apply import multi_tensor_applier
- except ImportError:
- raise ImportError("Please install apex from https://www.github.com/nvidia/apex to run this example.")
# Architectures offered for --arch: every lowercase, non-dunder, callable
# entry of torchvision.models.
model_names = sorted(name for name in models.__dict__
                     if name.islower() and not name.startswith("__")
                     and callable(models.__dict__[name]))

parser = argparse.ArgumentParser(description='PyTorch ImageNet Training')
parser.add_argument('data', metavar='DIR',
                    help='path to dataset')
parser.add_argument('--arch', '-a', metavar='ARCH', default='resnet18',
                    choices=model_names,
                    help='model architecture: ' +
                         ' | '.join(model_names) +
                         ' (default: resnet18)')
parser.add_argument('-j', '--workers', default=4, type=int, metavar='N',
                    help='number of data loading workers (default: 4)')
parser.add_argument('--epochs', default=90, type=int, metavar='N',
                    help='number of total epochs to run')
parser.add_argument('--start-epoch', default=0, type=int, metavar='N',
                    help='manual epoch number (useful on restarts)')
parser.add_argument('-b', '--batch-size', default=256, type=int,
                    metavar='N', help='mini-batch size per process (default: 256)')
# This test keeps the warmup schedule (adjust_learning_rate) disabled in
# train(), so the help text must not promise one.
parser.add_argument('--lr', '--learning-rate', default=0.1, type=float,
                    metavar='LR', help='Initial learning rate. Will be scaled by <global batch size>/256: args.lr = args.lr*float(args.batch_size*args.world_size)/256. No learning-rate warmup is applied by this test script.')
parser.add_argument('--momentum', default=0.9, type=float, metavar='M',
                    help='momentum')
parser.add_argument('--weight-decay', '--wd', default=1e-4, type=float,
                    metavar='W', help='weight decay (default: 1e-4)')
parser.add_argument('--print-freq', '-p', default=10, type=int,
                    metavar='N', help='print frequency (default: 10)')
parser.add_argument('--resume', default='', type=str, metavar='PATH',
                    help='path to latest checkpoint (default: none)')
parser.add_argument('-e', '--evaluate', dest='evaluate', action='store_true',
                    help='evaluate model on validation set')
parser.add_argument('--pretrained', dest='pretrained', action='store_true',
                    help='use pre-trained model')
parser.add_argument('--prof', dest='prof', action='store_true',
                    help='Only run 10 iterations for profiling.')
parser.add_argument('--deterministic', action='store_true',
                    help='disable cudnn.benchmark, enable cudnn.deterministic '
                         'and seed torch with --local_rank')
parser.add_argument("--local_rank", default=0, type=int,
                    help='local process rank; selects the GPU in distributed '
                         'mode, and only rank 0 prints progress and saves results')
parser.add_argument('--sync_bn', action='store_true',
                    help='enabling apex sync BN.')
parser.add_argument('--has-ext', action='store_true',
                    help='expected value of apex multi_tensor_applier.available; '
                         'startup fails if they differ')
parser.add_argument('--opt-level', type=str,
                    help='opt_level passed to amp.initialize')
parser.add_argument('--keep-batchnorm-fp32', type=str, default=None,
                    help='keep_batchnorm_fp32 value passed to amp.initialize')
parser.add_argument('--loss-scale', type=str, default=None,
                    help='loss_scale value passed to amp.initialize')
parser.add_argument('--fused-adam', action='store_true',
                    help='use apex optimizers.FusedAdam instead of SGD')
parser.add_argument('--prints-to-process', type=int, default=10,
                    help='number of logged training intervals to record before '
                         'saving the run statistics and exiting (default: 10)')

# cuDNN benchmark mode is on by default; the --deterministic path turns it off.
cudnn.benchmark = True
def fast_collate(batch):
    """Collate ``(image, label)`` samples into a uint8 NCHW tensor plus labels.

    Images are expected to be PIL-style objects whose ``.size`` is
    ``(width, height)`` and which ``np.asarray`` can convert; every image must
    match the size of the first one. Pixels stay uint8 here; conversion to
    float and normalization happen later on the GPU (see ``data_prefetcher``).

    Returns:
        ``(images, labels)``: a ``(N, 3, H, W)`` uint8 tensor and an int64
        tensor of the ``N`` labels.
    """
    images = [sample[0] for sample in batch]
    labels = torch.tensor([sample[1] for sample in batch], dtype=torch.int64)

    first_size = images[0].size
    width, height = first_size[0], first_size[1]
    stacked = torch.zeros((len(images), 3, height, width), dtype=torch.uint8)

    for index, image in enumerate(images):
        pixels = np.asarray(image, dtype=np.uint8)
        if pixels.ndim < 3:
            # Single-channel image: add a trailing channel axis; it then
            # broadcasts across the 3 output channels below.
            pixels = pixels[..., np.newaxis]
        # HWC -> CHW
        stacked[index] += torch.from_numpy(np.moveaxis(pixels, 2, 0))

    return stacked, labels
# Best top-1 validation precision seen so far; main() updates it.
best_prec1 = 0
args = parser.parse_args()

# Let multi_tensor_applier be the canary in the coal mine that verifies the
# apex backend is what we think it is. An explicit raise (rather than a bare
# ``assert``) keeps this check active when Python runs with -O.
if multi_tensor_applier.available != args.has_ext:
    raise AssertionError(
        "multi_tensor_applier.available ({}) does not match --has-ext ({})"
        .format(multi_tensor_applier.available, args.has_ext))

print("opt_level = {}".format(args.opt_level))
print("keep_batchnorm_fp32 = {}".format(args.keep_batchnorm_fp32), type(args.keep_batchnorm_fp32))
print("loss_scale = {}".format(args.loss_scale), type(args.loss_scale))

print("\nCUDNN VERSION: {}\n".format(torch.backends.cudnn.version()))

if args.deterministic:
    # Trade cuDNN autotuning for reproducibility and seed torch with the
    # local rank.
    cudnn.benchmark = False
    cudnn.deterministic = True
    torch.manual_seed(args.local_rank)
    torch.set_printoptions(precision=10)
def main():
    """Build the model and data pipeline, then train and/or validate.

    Reads the module-level ``args`` (adding ``distributed``, ``gpu`` and
    ``world_size`` to it) and updates the module-level ``best_prec1``.
    """
    global best_prec1, args

    # Multi-process mode is enabled when the WORLD_SIZE environment variable
    # reports more than one process.
    args.distributed = False
    if 'WORLD_SIZE' in os.environ:
        args.distributed = int(os.environ['WORLD_SIZE']) > 1

    args.gpu = 0
    args.world_size = 1

    if args.distributed:
        args.gpu = args.local_rank % torch.cuda.device_count()
        torch.cuda.set_device(args.gpu)
        torch.distributed.init_process_group(backend='nccl',
                                             init_method='env://')
        args.world_size = torch.distributed.get_world_size()

    assert torch.backends.cudnn.enabled, "Amp requires cudnn backend to be enabled."

    # create model
    if args.pretrained:
        print("=> using pre-trained model '{}'".format(args.arch))
        model = models.__dict__[args.arch](pretrained=True)
    else:
        print("=> creating model '{}'".format(args.arch))
        model = models.__dict__[args.arch]()

    if args.sync_bn:
        import apex
        print("using apex synced BN")
        model = apex.parallel.convert_syncbn_model(model)

    model = model.cuda()

    # Scale learning rate based on global batch size
    args.lr = args.lr*float(args.batch_size*args.world_size)/256.
    if args.fused_adam:
        # Only the parameters are passed: --lr, --momentum and --weight-decay
        # are not forwarded to FusedAdam.
        optimizer = optimizers.FusedAdam(model.parameters())
    else:
        optimizer = torch.optim.SGD(model.parameters(), args.lr,
                                    momentum=args.momentum,
                                    weight_decay=args.weight_decay)

    model, optimizer = amp.initialize(
        model, optimizer,
        opt_level=args.opt_level,
        keep_batchnorm_fp32=args.keep_batchnorm_fp32,
        loss_scale=args.loss_scale
    )

    if args.distributed:
        # By default, apex.parallel.DistributedDataParallel overlaps communication with
        # computation in the backward pass.
        # model = DDP(model)
        # delay_allreduce delays all communication to the end of the backward pass.
        model = DDP(model, delay_allreduce=True)

    # define loss function (criterion) and optimizer
    criterion = nn.CrossEntropyLoss().cuda()

    # Optionally resume from a checkpoint
    if args.resume:
        # Use a local scope so the loaded checkpoint dict is released when
        # resume() returns.
        def resume():
            # Without this declaration the assignment below would bind a
            # local name and the restored best_prec1 would be discarded.
            global best_prec1
            if os.path.isfile(args.resume):
                print("=> loading checkpoint '{}'".format(args.resume))
                checkpoint = torch.load(args.resume, map_location = lambda storage, loc: storage.cuda(args.gpu))
                args.start_epoch = checkpoint['epoch']
                best_prec1 = checkpoint['best_prec1']
                model.load_state_dict(checkpoint['state_dict'])
                optimizer.load_state_dict(checkpoint['optimizer'])
                print("=> loaded checkpoint '{}' (epoch {})"
                      .format(args.resume, checkpoint['epoch']))
            else:
                print("=> no checkpoint found at '{}'".format(args.resume))
        resume()

    # Data loading code
    traindir = os.path.join(args.data, 'train')
    valdir = os.path.join(args.data, 'val')

    if args.arch == "inception_v3":
        crop_size = 299
        val_size = 320 # I chose this value arbitrarily, we can adjust.
    else:
        crop_size = 224
        val_size = 256

    # No ToTensor/normalize here: fast_collate builds uint8 tensors and the
    # prefetcher normalizes on the GPU.
    train_dataset = datasets.ImageFolder(
        traindir,
        transforms.Compose([
            transforms.RandomResizedCrop(crop_size),
            transforms.RandomHorizontalFlip(),
        ]))
    val_dataset = datasets.ImageFolder(valdir, transforms.Compose([
            transforms.Resize(val_size),
            transforms.CenterCrop(crop_size),
        ]))

    train_sampler = None
    val_sampler = None
    if args.distributed:
        train_sampler = torch.utils.data.distributed.DistributedSampler(train_dataset)
        val_sampler = torch.utils.data.distributed.DistributedSampler(val_dataset)

    train_loader = torch.utils.data.DataLoader(
        train_dataset, batch_size=args.batch_size, shuffle=(train_sampler is None),
        num_workers=args.workers, pin_memory=True, sampler=train_sampler, collate_fn=fast_collate)

    val_loader = torch.utils.data.DataLoader(
        val_dataset,
        batch_size=args.batch_size, shuffle=False,
        num_workers=args.workers, pin_memory=True,
        sampler=val_sampler,
        collate_fn=fast_collate)

    if args.evaluate:
        validate(val_loader, model, criterion)
        return

    for epoch in range(args.start_epoch, args.epochs):
        if args.distributed:
            train_sampler.set_epoch(epoch)

        # train for one epoch
        train(train_loader, model, criterion, optimizer, epoch)
        if args.prof:
            break
        # evaluate on validation set
        prec1 = validate(val_loader, model, criterion)

        # remember best prec@1 and save checkpoint
        if args.local_rank == 0:
            is_best = prec1 > best_prec1
            best_prec1 = max(prec1, best_prec1)
            save_checkpoint({
                'epoch': epoch + 1,
                'arch': args.arch,
                'state_dict': model.state_dict(),
                'best_prec1': best_prec1,
                'optimizer' : optimizer.state_dict(),
            }, is_best)
class data_prefetcher():
    """Prefetch and normalize batches on a side CUDA stream.

    Wraps a loader yielding CPU ``(input, target)`` batches (uint8 images from
    ``fast_collate`` in this script). While the caller works on one batch, the
    next one is copied to the GPU, converted to float and normalized on
    ``self.stream``. ``next()`` returns ``(None, None)`` once the loader is
    exhausted.
    """

    def __init__(self, loader):
        self.loader = iter(loader)
        self.stream = torch.cuda.Stream()
        # Per-channel ImageNet mean/std scaled by 255 to match 0-255 pixel
        # values; shaped (1, 3, 1, 1) to broadcast over NCHW batches.
        self.mean = torch.tensor([0.485 * 255, 0.456 * 255, 0.406 * 255]).cuda().view(1,3,1,1)
        self.std = torch.tensor([0.229 * 255, 0.224 * 255, 0.225 * 255]).cuda().view(1,3,1,1)
        # With Amp, it isn't necessary to manually convert data to half.
        self.preload()

    def preload(self):
        """Fetch the next CPU batch and start its GPU copy/normalization on the side stream."""
        try:
            self.next_input, self.next_target = next(self.loader)
        except StopIteration:
            self.next_input = None
            self.next_target = None
            return
        with torch.cuda.stream(self.stream):
            self.next_input = self.next_input.cuda(non_blocking=True)
            self.next_target = self.next_target.cuda(non_blocking=True)
            # With Amp, it isn't necessary to manually convert data to half.
            self.next_input = self.next_input.float()
            self.next_input = self.next_input.sub_(self.mean).div_(self.std)

    def next(self):
        """Return the prefetched ``(input, target)`` and begin prefetching the next pair."""
        current = torch.cuda.current_stream()
        current.wait_stream(self.stream)
        input_batch = self.next_input
        target_batch = self.next_target
        # These tensors were allocated on the side stream but are consumed on
        # the current stream. record_stream informs the caching allocator so
        # their memory is not handed to a new tensor while work queued on the
        # current stream may still be using it.
        if input_batch is not None:
            input_batch.record_stream(current)
        if target_batch is not None:
            target_batch.record_stream(current)
        self.preload()
        return input_batch, target_batch
def train(train_loader, model, criterion, optimizer, epoch):
    """Run one training epoch with Amp loss scaling.

    Every ``args.print_freq`` iterations (after the first two) rank 0 prints
    progress and all processes record ``(iteration, loss, speed)``. Once
    ``args.prints_to_process`` records exist, rank 0 saves them with
    ``torch.save`` and every process exits.
    """
    import sys

    batch_time = AverageMeter()
    data_time = AverageMeter()
    losses = AverageMeter()
    top1 = AverageMeter()
    top5 = AverageMeter()

    # switch to train mode
    model.train()
    end = time.time()

    run_info_dict = {"Iteration" : [],
                     "Loss" : [],
                     "Speed" : []}

    prefetcher = data_prefetcher(train_loader)
    input, target = prefetcher.next()
    i = -1
    while input is not None:
        i += 1

        # No learning rate warmup for this test, to expose bitwise inaccuracies more quickly
        # adjust_learning_rate(optimizer, epoch, i, len(train_loader))

        if args.prof:
            if i > 10:
                break
        # measure data loading time
        data_time.update(time.time() - end)

        # compute output
        output = model(input)
        loss = criterion(output, target)

        # measure accuracy and record loss
        prec1, prec5 = accuracy(output.data, target, topk=(1, 5))

        if args.distributed:
            reduced_loss = reduce_tensor(loss.data)
            prec1 = reduce_tensor(prec1)
            prec5 = reduce_tensor(prec5)
        else:
            reduced_loss = loss.data

        losses.update(to_python_float(reduced_loss), input.size(0))
        top1.update(to_python_float(prec1), input.size(0))
        top5.update(to_python_float(prec5), input.size(0))

        # compute gradient and do SGD step
        optimizer.zero_grad()
        with amp.scale_loss(loss, optimizer) as scaled_loss:
            scaled_loss.backward()

        torch.cuda.nvtx.range_push("step")
        optimizer.step()
        torch.cuda.nvtx.range_pop()

        # Wait for queued GPU work so the batch time measured below includes it.
        torch.cuda.synchronize()

        # measure elapsed time
        batch_time.update(time.time() - end)
        end = time.time()

        # If you decide to refactor this test, like examples/imagenet, to sample the loss every
        # print_freq iterations, make sure to move this prefetching below the accuracy calculation.
        input, target = prefetcher.next()

        if i % args.print_freq == 0 and i > 1:
            if args.local_rank == 0:
                print('Epoch: [{0}][{1}/{2}]\t'
                      'Time {batch_time.val:.3f} ({batch_time.avg:.3f})\t'
                      'Speed {3:.3f} ({4:.3f})\t'
                      'Data {data_time.val:.3f} ({data_time.avg:.3f})\t'
                      'Loss {loss.val:.10f} ({loss.avg:.4f})\t'
                      'Prec@1 {top1.val:.3f} ({top1.avg:.3f})\t'
                      'Prec@5 {top5.val:.3f} ({top5.avg:.3f})'.format(
                       epoch, i, len(train_loader),
                       args.world_size * args.batch_size / batch_time.val,
                       args.world_size * args.batch_size / batch_time.avg,
                       batch_time=batch_time,
                       data_time=data_time, loss=losses, top1=top1, top5=top5))
            run_info_dict["Iteration"].append(i)
            run_info_dict["Loss"].append(losses.val)
            run_info_dict["Speed"].append(args.world_size * args.batch_size / batch_time.val)
            if len(run_info_dict["Loss"]) == args.prints_to_process:
                if args.local_rank == 0:
                    torch.save(run_info_dict,
                               str(args.has_ext) + "_" + str(args.opt_level) + "_" +
                               str(args.loss_scale) + "_" + str(args.keep_batchnorm_fp32) + "_" +
                               str(args.fused_adam))
                # quit() is an interactive-shell helper installed by `site`;
                # sys.exit() is the supported way for a program to exit.
                sys.exit()
def validate(val_loader, model, criterion):
    """Evaluate ``model`` on ``val_loader`` without gradients.

    Loss and precision are averaged across processes via ``reduce_tensor``
    when running distributed. Rank 0 prints progress every
    ``args.print_freq`` batches.

    Returns:
        The running average of top-1 precision over all batches.
    """
    batch_time, loss_meter, top1, top5 = (AverageMeter() for _ in range(4))

    # switch to evaluate mode
    model.eval()

    tick = time.time()
    prefetcher = data_prefetcher(val_loader)
    images, labels = prefetcher.next()
    step = 0
    while images is not None:
        with torch.no_grad():
            logits = model(images)
            loss = criterion(logits, labels)

        # measure accuracy and record loss
        prec1, prec5 = accuracy(logits.data, labels, topk=(1, 5))
        if not args.distributed:
            loss_value = loss.data
        else:
            loss_value = reduce_tensor(loss.data)
            prec1 = reduce_tensor(prec1)
            prec5 = reduce_tensor(prec5)

        count = images.size(0)
        loss_meter.update(to_python_float(loss_value), count)
        top1.update(to_python_float(prec1), count)
        top5.update(to_python_float(prec5), count)

        # measure elapsed time
        batch_time.update(time.time() - tick)
        tick = time.time()

        if args.local_rank == 0 and step % args.print_freq == 0:
            samples = args.world_size * args.batch_size
            template = ('Test: [{0}/{1}]\t'
                        'Time {batch_time.val:.3f} ({batch_time.avg:.3f})\t'
                        'Speed {2:.3f} ({3:.3f})\t'
                        'Loss {loss.val:.4f} ({loss.avg:.4f})\t'
                        'Prec@1 {top1.val:.3f} ({top1.avg:.3f})\t'
                        'Prec@5 {top5.val:.3f} ({top5.avg:.3f})')
            print(template.format(
                step, len(val_loader),
                samples / batch_time.val,
                samples / batch_time.avg,
                batch_time=batch_time, loss=loss_meter,
                top1=top1, top5=top5))

        images, labels = prefetcher.next()
        step += 1

    summary = ' * Prec@1 {top1.avg:.3f} Prec@5 {top5.avg:.3f}'
    print(summary.format(top1=top1, top5=top5))

    return top1.avg
def save_checkpoint(state, is_best, filename='checkpoint.pth.tar'):
    """Serialize ``state`` to ``filename`` with ``torch.save``.

    When ``is_best`` is true the file is also copied to ``model_best.pth.tar``
    (a relative path, i.e. in the current working directory).
    """
    torch.save(state, filename)
    if not is_best:
        return
    shutil.copyfile(filename, 'model_best.pth.tar')
class AverageMeter(object):
    """Track the latest value and a count-weighted running average."""

    def __init__(self):
        self.reset()

    def reset(self):
        """Zero the latest value, running sum, sample count and average."""
        self.val = self.avg = self.sum = self.count = 0

    def update(self, val, n=1):
        """Record observation ``val`` with weight ``n`` (e.g. a mean over ``n`` samples)."""
        self.val = val
        self.sum += val * n
        self.count += n
        self.avg = self.sum / self.count
def adjust_learning_rate(optimizer, epoch, step, len_epoch):
    """Set every param group's LR from a step-decay schedule with warmup.

    Original note: this schedule should yield 76% converged accuracy with
    batch size 256.

    The base ``args.lr`` is multiplied by 0.1 once per 30 epochs, with one
    extra decay from epoch 80 onward. During epochs 0-4 the result is further
    scaled linearly by ``(1 + step + epoch*len_epoch) / (5*len_epoch)``.
    """
    decays = epoch // 30 + (1 if epoch >= 80 else 0)
    lr = args.lr*(0.1**decays)

    # Warmup
    if epoch < 5:
        lr = lr*float(1 + step + epoch*len_epoch)/(5.*len_epoch)

    for group in optimizer.param_groups:
        group['lr'] = lr
def accuracy(output, target, topk=(1,)):
    """Compute precision@k (in percent) for each k in ``topk``.

    Args:
        output: class scores; ``topk`` is taken along dim 1.
        target: ground-truth class indices, one per row of ``output``.
        topk: the k values to evaluate.

    Returns:
        A list with one single-element tensor per k: the percentage of rows
        whose target is among the k highest scores.
    """
    maxk = max(topk)
    batch_size = target.size(0)

    _, pred = output.topk(maxk, 1, True, True)
    pred = pred.t()  # (maxk, batch)
    correct = pred.eq(target.view(1, -1).expand_as(pred))

    res = []
    for k in topk:
        # ``correct`` inherits the transposed, non-contiguous layout of
        # ``pred``, so ``view(-1)`` fails for k > 1; ``reshape`` copies when needed.
        correct_k = correct[:k].reshape(-1).float().sum(0, keepdim=True)
        res.append(correct_k.mul_(100.0 / batch_size))
    return res
def reduce_tensor(tensor):
    """Return the mean of ``tensor`` across all processes in the default group.

    The input is cloned (and left untouched); the clone is all-reduced with
    SUM and divided by ``args.world_size``. Requires an initialized process
    group.
    """
    rt = tensor.clone()
    # ``dist.reduce_op`` is a deprecated alias; ``ReduceOp`` is the supported enum.
    dist.all_reduce(rt, op=dist.ReduceOp.SUM)
    rt /= args.world_size
    return rt
if __name__ == "__main__":
    # Run training only when executed as a script, not on import.
    main()
|