from __future__ import print_function
import argparse
import os
import random
import torch
import torch.nn as nn
import torch.nn.parallel
import torch.backends.cudnn as cudnn
import torch.optim as optim
import torch.utils.data
import torchvision.datasets as dset
import torchvision.transforms as transforms
import torchvision.utils as vutils

try:
    from apex import amp
except ImportError:
    raise ImportError("Please install apex from https://www.github.com/nvidia/apex to run this example.")
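
# apex's amp API handles mixed-precision training: models and optimizers are
# wrapped once via amp.initialize(), and each backward pass is routed through
# amp.scale_loss() so gradients are loss-scaled to avoid fp16 underflow.
# The --opt_level flag below selects the amp mode ("O0" = fp32 baseline,
# "O1" = conservative mixed precision, "O2" = aggressive, "O3" = pure fp16).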
parser = argparse.ArgumentParser()
parser.add_argument('--dataset', default='cifar10', help='cifar10 | lsun | mnist | imagenet | folder | lfw | fake')
parser.add_argument('--dataroot', default='./', help='path to dataset')
parser.add_argument('--workers', type=int, default=2, help='number of data loading workers')
parser.add_argument('--batchSize', type=int, default=64, help='input batch size')
parser.add_argument('--imageSize', type=int, default=64, help='the height / width of the input image to network')
parser.add_argument('--nz', type=int, default=100, help='size of the latent z vector')
parser.add_argument('--ngf', type=int, default=64, help='base number of generator feature maps')
parser.add_argument('--ndf', type=int, default=64, help='base number of discriminator feature maps')
parser.add_argument('--niter', type=int, default=25, help='number of epochs to train for')
parser.add_argument('--lr', type=float, default=0.0002, help='learning rate, default=0.0002')
parser.add_argument('--beta1', type=float, default=0.5, help='beta1 for adam, default=0.5')
parser.add_argument('--ngpu', type=int, default=1, help='number of GPUs to use')
parser.add_argument('--netG', default='', help='path to netG (to continue training)')
parser.add_argument('--netD', default='', help='path to netD (to continue training)')
parser.add_argument('--outf', default='.', help='folder to output images and model checkpoints')
parser.add_argument('--manualSeed', type=int, help='manual seed')
parser.add_argument('--classes', default='bedroom', help='comma separated list of classes for the lsun data set')
parser.add_argument('--opt_level', default='O1', help='amp opt_level, default="O1"')

opt = parser.parse_args()
print(opt)
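
# Create the output directory (ignoring "already exists" errors) and seed
# Python's and PyTorch's RNGs; a fixed fallback seed of 2809 keeps runs
# reproducible when no --manualSeed is given.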
try:
    os.makedirs(opt.outf)
except OSError:
    pass

if opt.manualSeed is None:
    opt.manualSeed = 2809
print("Random Seed: ", opt.manualSeed)
random.seed(opt.manualSeed)
torch.manual_seed(opt.manualSeed)

cudnn.benchmark = True
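
# Every image branch below resizes/crops to opt.imageSize and normalizes
# pixels to [-1, 1], matching the Tanh output range of the generator.
# nc records the channel count (3 for RGB datasets, 1 for MNIST).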
if opt.dataset in ['imagenet', 'folder', 'lfw']:
    # folder dataset
    dataset = dset.ImageFolder(root=opt.dataroot,
                               transform=transforms.Compose([
                                   transforms.Resize(opt.imageSize),
                                   transforms.CenterCrop(opt.imageSize),
                                   transforms.ToTensor(),
                                   transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
                               ]))
    nc = 3
elif opt.dataset == 'lsun':
    classes = [c + '_train' for c in opt.classes.split(',')]
    dataset = dset.LSUN(root=opt.dataroot, classes=classes,
                        transform=transforms.Compose([
                            transforms.Resize(opt.imageSize),
                            transforms.CenterCrop(opt.imageSize),
                            transforms.ToTensor(),
                            transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
                        ]))
    nc = 3
elif opt.dataset == 'cifar10':
    dataset = dset.CIFAR10(root=opt.dataroot, download=True,
                           transform=transforms.Compose([
                               transforms.Resize(opt.imageSize),
                               transforms.ToTensor(),
                               transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
                           ]))
    nc = 3
elif opt.dataset == 'mnist':
    dataset = dset.MNIST(root=opt.dataroot, download=True,
                         transform=transforms.Compose([
                             transforms.Resize(opt.imageSize),
                             transforms.ToTensor(),
                             transforms.Normalize((0.5,), (0.5,)),
                         ]))
    nc = 1
elif opt.dataset == 'fake':
    dataset = dset.FakeData(image_size=(3, opt.imageSize, opt.imageSize),
                            transform=transforms.ToTensor())
    nc = 3

assert dataset

dataloader = torch.utils.data.DataLoader(dataset, batch_size=opt.batchSize,
                                         shuffle=True, num_workers=int(opt.workers))
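
# NOTE: a CUDA device is assumed unconditionally here; apex amp targets GPU
# execution, so this example will not run on a CPU-only machine.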
device = torch.device("cuda:0")
ngpu = int(opt.ngpu)
nz = int(opt.nz)
ngf = int(opt.ngf)
ndf = int(opt.ndf)
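
# DCGAN weight-init scheme (Radford et al., 2015): all conv / conv-transpose
# weights are drawn from N(0, 0.02); BatchNorm scales from N(1.0, 0.02) with
# biases zeroed.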
# custom weights initialization called on netG and netD
def weights_init(m):
    classname = m.__class__.__name__
    if classname.find('Conv') != -1:
        m.weight.data.normal_(0.0, 0.02)
    elif classname.find('BatchNorm') != -1:
        m.weight.data.normal_(1.0, 0.02)
        m.bias.data.fill_(0)
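
# The generator upsamples a (nz x 1 x 1) latent vector to an (nc x 64 x 64)
# image with strided transposed convolutions. Spatial size per layer follows
# out = (in - 1)*stride - 2*padding + kernel:
#   first layer:  (1 - 1)*1 - 2*0 + 4 = 4
#   later layers: (n - 1)*2 - 2*1 + 4 = 2n   (4 -> 8 -> 16 -> 32 -> 64)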
class Generator(nn.Module):
    def __init__(self, ngpu):
        super(Generator, self).__init__()
        self.ngpu = ngpu
        self.main = nn.Sequential(
            # input is Z, going into a convolution
            nn.ConvTranspose2d(nz, ngf * 8, 4, 1, 0, bias=False),
            nn.BatchNorm2d(ngf * 8),
            nn.ReLU(True),
            # state size. (ngf*8) x 4 x 4
            nn.ConvTranspose2d(ngf * 8, ngf * 4, 4, 2, 1, bias=False),
            nn.BatchNorm2d(ngf * 4),
            nn.ReLU(True),
            # state size. (ngf*4) x 8 x 8
            nn.ConvTranspose2d(ngf * 4, ngf * 2, 4, 2, 1, bias=False),
            nn.BatchNorm2d(ngf * 2),
            nn.ReLU(True),
            # state size. (ngf*2) x 16 x 16
            nn.ConvTranspose2d(ngf * 2, ngf, 4, 2, 1, bias=False),
            nn.BatchNorm2d(ngf),
            nn.ReLU(True),
            # state size. (ngf) x 32 x 32
            nn.ConvTranspose2d(ngf, nc, 4, 2, 1, bias=False),
            nn.Tanh()
            # state size. (nc) x 64 x 64
        )

    def forward(self, input):
        if input.is_cuda and self.ngpu > 1:
            output = nn.parallel.data_parallel(self.main, input, range(self.ngpu))
        else:
            output = self.main(input)
        return output
netG = Generator(ngpu).to(device)
netG.apply(weights_init)
if opt.netG != '':
    netG.load_state_dict(torch.load(opt.netG))
print(netG)
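
# The discriminator mirrors the generator: strided convolutions halve the
# spatial size each layer, out = floor((in + 2*padding - kernel)/stride) + 1,
# e.g. (64 + 2 - 4)//2 + 1 = 32. The final 4x4 conv reduces each image to a
# single logit; note there is no Sigmoid here (see the loss choice below).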
class Discriminator(nn.Module):
    def __init__(self, ngpu):
        super(Discriminator, self).__init__()
        self.ngpu = ngpu
        self.main = nn.Sequential(
            # input is (nc) x 64 x 64
            nn.Conv2d(nc, ndf, 4, 2, 1, bias=False),
            nn.LeakyReLU(0.2, inplace=True),
            # state size. (ndf) x 32 x 32
            nn.Conv2d(ndf, ndf * 2, 4, 2, 1, bias=False),
            nn.BatchNorm2d(ndf * 2),
            nn.LeakyReLU(0.2, inplace=True),
            # state size. (ndf*2) x 16 x 16
            nn.Conv2d(ndf * 2, ndf * 4, 4, 2, 1, bias=False),
            nn.BatchNorm2d(ndf * 4),
            nn.LeakyReLU(0.2, inplace=True),
            # state size. (ndf*4) x 8 x 8
            nn.Conv2d(ndf * 4, ndf * 8, 4, 2, 1, bias=False),
            nn.BatchNorm2d(ndf * 8),
            nn.LeakyReLU(0.2, inplace=True),
            # state size. (ndf*8) x 4 x 4
            nn.Conv2d(ndf * 8, 1, 4, 1, 0, bias=False),
        )

    def forward(self, input):
        if input.is_cuda and self.ngpu > 1:
            output = nn.parallel.data_parallel(self.main, input, range(self.ngpu))
        else:
            output = self.main(input)
        return output.view(-1, 1).squeeze(1)
netD = Discriminator(ngpu).to(device)
netD.apply(weights_init)
if opt.netD != '':
    netD.load_state_dict(torch.load(opt.netD))
print(netD)
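
# BCEWithLogitsLoss fuses the sigmoid into the binary cross-entropy in a
# numerically stable form. This matters under amp: a separate Sigmoid +
# BCELoss pair can overflow/underflow in fp16, so the logits-based loss
# (paired with the discriminator's raw logit output) is the safe choice.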
criterion = nn.BCEWithLogitsLoss()

fixed_noise = torch.randn(opt.batchSize, nz, 1, 1, device=device)
real_label = 1.0  # float labels, so torch.full below infers a float tensor
fake_label = 0.0

# setup optimizer
optimizerD = optim.Adam(netD.parameters(), lr=opt.lr, betas=(opt.beta1, 0.999))
optimizerG = optim.Adam(netG.parameters(), lr=opt.lr, betas=(opt.beta1, 0.999))

[netD, netG], [optimizerD, optimizerG] = amp.initialize(
    [netD, netG], [optimizerD, optimizerG], opt_level=opt.opt_level, num_losses=3)
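
# amp.initialize accepts lists of models and optimizers and returns their
# patched counterparts. num_losses=3 gives each of the three backward passes
# below (D on real, D on fake, G) its own loss scale, selected by the
# loss_id passed to amp.scale_loss.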
for epoch in range(opt.niter):
    for i, data in enumerate(dataloader, 0):
        ############################
        # (1) Update D network: maximize log(D(x)) + log(1 - D(G(z)))
        ###########################
        # train with real
        netD.zero_grad()
        real_cpu = data[0].to(device)
        batch_size = real_cpu.size(0)
        label = torch.full((batch_size,), real_label, device=device)

        output = netD(real_cpu)
        errD_real = criterion(output, label)
        with amp.scale_loss(errD_real, optimizerD, loss_id=0) as errD_real_scaled:
            errD_real_scaled.backward()
        D_x = output.mean().item()

        # train with fake
        noise = torch.randn(batch_size, nz, 1, 1, device=device)
        fake = netG(noise)
        label.fill_(fake_label)
        output = netD(fake.detach())
        errD_fake = criterion(output, label)
        with amp.scale_loss(errD_fake, optimizerD, loss_id=1) as errD_fake_scaled:
            errD_fake_scaled.backward()
        D_G_z1 = output.mean().item()
        errD = errD_real + errD_fake
        optimizerD.step()

        ############################
        # (2) Update G network: maximize log(D(G(z)))
        ###########################
        netG.zero_grad()
        label.fill_(real_label)  # fake labels are real for generator cost
        output = netD(fake)
        errG = criterion(output, label)
        with amp.scale_loss(errG, optimizerG, loss_id=2) as errG_scaled:
            errG_scaled.backward()
        D_G_z2 = output.mean().item()
        optimizerG.step()

        print('[%d/%d][%d/%d] Loss_D: %.4f Loss_G: %.4f D(x): %.4f D(G(z)): %.4f / %.4f'
              % (epoch, opt.niter, i, len(dataloader),
                 errD.item(), errG.item(), D_x, D_G_z1, D_G_z2))
        if i % 100 == 0:
            vutils.save_image(real_cpu,
                              '%s/real_samples.png' % opt.outf,
                              normalize=True)
            fake = netG(fixed_noise)
            vutils.save_image(fake.detach(),
                              '%s/amp_fake_samples_epoch_%03d.png' % (opt.outf, epoch),
                              normalize=True)

    # do checkpointing
    torch.save(netG.state_dict(), '%s/netG_epoch_%d.pth' % (opt.outf, epoch))
    torch.save(netD.state_dict(), '%s/netD_epoch_%d.pth' % (opt.outf, epoch))
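
# Example invocation (the script filename is an assumption; adjust to yours):
#   python main_amp.py --dataset cifar10 --dataroot ./data --opt_level O1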