from __future__ import print_function
import argparse
import os
import random

import torch
import torch.nn as nn
import torch.nn.parallel
import torch.backends.cudnn as cudnn
import torch.optim as optim
import torch.utils.data
import torchvision.datasets as dset
import torchvision.transforms as transforms
import torchvision.utils as vutils

try:
    from apex import amp
except ImportError:
    raise ImportError("Please install apex from https://www.github.com/nvidia/apex to run this example.")

parser = argparse.ArgumentParser()
parser.add_argument('--dataset', default='cifar10', help='cifar10 | lsun | mnist | imagenet | folder | lfw | fake')
parser.add_argument('--dataroot', default='./', help='path to dataset')
parser.add_argument('--workers', type=int, help='number of data loading workers', default=2)
parser.add_argument('--batchSize', type=int, default=64, help='input batch size')
parser.add_argument('--imageSize', type=int, default=64, help='the height / width of the input image to network')
parser.add_argument('--nz', type=int, default=100, help='size of the latent z vector')
parser.add_argument('--ngf', type=int, default=64)
parser.add_argument('--ndf', type=int, default=64)
parser.add_argument('--niter', type=int, default=25, help='number of epochs to train for')
parser.add_argument('--lr', type=float, default=0.0002, help='learning rate, default=0.0002')
parser.add_argument('--beta1', type=float, default=0.5, help='beta1 for adam, default=0.5')
parser.add_argument('--ngpu', type=int, default=1, help='number of GPUs to use')
parser.add_argument('--netG', default='', help="path to netG (to continue training)")
parser.add_argument('--netD', default='', help="path to netD (to continue training)")
parser.add_argument('--outf', default='.', help='folder to output images and model checkpoints')
parser.add_argument('--manualSeed', type=int, help='manual seed')
parser.add_argument('--classes', default='bedroom', help='comma separated list of classes for the lsun data set')
parser.add_argument('--opt_level', default='O1', help='amp opt_level, default="O1"')

opt = parser.parse_args()
print(opt)

try:
    os.makedirs(opt.outf)
except OSError:
    pass

if opt.manualSeed is None:
    opt.manualSeed = 2809
print("Random Seed: ", opt.manualSeed)
random.seed(opt.manualSeed)
torch.manual_seed(opt.manualSeed)

cudnn.benchmark = True

if opt.dataset in ['imagenet', 'folder', 'lfw']:
    # folder dataset
    dataset = dset.ImageFolder(root=opt.dataroot,
                               transform=transforms.Compose([
                                   transforms.Resize(opt.imageSize),
                                   transforms.CenterCrop(opt.imageSize),
                                   transforms.ToTensor(),
                                   transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
                               ]))
    nc = 3
elif opt.dataset == 'lsun':
    classes = [c + '_train' for c in opt.classes.split(',')]
    dataset = dset.LSUN(root=opt.dataroot, classes=classes,
                        transform=transforms.Compose([
                            transforms.Resize(opt.imageSize),
                            transforms.CenterCrop(opt.imageSize),
                            transforms.ToTensor(),
                            transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
                        ]))
    nc = 3
elif opt.dataset == 'cifar10':
    dataset = dset.CIFAR10(root=opt.dataroot, download=True,
                           transform=transforms.Compose([
                               transforms.Resize(opt.imageSize),
                               transforms.ToTensor(),
                               transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
                           ]))
    nc = 3
elif opt.dataset == 'mnist':
    dataset = dset.MNIST(root=opt.dataroot, download=True,
                         transform=transforms.Compose([
                             transforms.Resize(opt.imageSize),
                             transforms.ToTensor(),
                             transforms.Normalize((0.5,), (0.5,)),
                         ]))
    nc = 1
elif opt.dataset == 'fake':
    dataset = dset.FakeData(image_size=(3, opt.imageSize, opt.imageSize),
                            transform=transforms.ToTensor())
    nc = 3
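# Note: Normalize((0.5, ...), (0.5, ...)) maps ToTensor's [0, 1] output to
# [-1, 1], which matches the Tanh activation at the end of the generator, so
# real and generated images occupy the same value range.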
assert dataset
dataloader = torch.utils.data.DataLoader(dataset, batch_size=opt.batchSize,
                                         shuffle=True, num_workers=int(opt.workers))

# apex amp requires CUDA; this example assumes at least one visible GPU.
device = torch.device("cuda:0")
ngpu = int(opt.ngpu)
nz = int(opt.nz)
ngf = int(opt.ngf)
ndf = int(opt.ndf)


# custom weights initialization called on netG and netD
def weights_init(m):
    classname = m.__class__.__name__
    if classname.find('Conv') != -1:
        m.weight.data.normal_(0.0, 0.02)
    elif classname.find('BatchNorm') != -1:
        m.weight.data.normal_(1.0, 0.02)
        m.bias.data.fill_(0)


class Generator(nn.Module):
    def __init__(self, ngpu):
        super(Generator, self).__init__()
        self.ngpu = ngpu
        self.main = nn.Sequential(
            # input is Z, going into a convolution
            nn.ConvTranspose2d(nz, ngf * 8, 4, 1, 0, bias=False),
            nn.BatchNorm2d(ngf * 8),
            nn.ReLU(True),
            # state size. (ngf*8) x 4 x 4
            nn.ConvTranspose2d(ngf * 8, ngf * 4, 4, 2, 1, bias=False),
            nn.BatchNorm2d(ngf * 4),
            nn.ReLU(True),
            # state size. (ngf*4) x 8 x 8
            nn.ConvTranspose2d(ngf * 4, ngf * 2, 4, 2, 1, bias=False),
            nn.BatchNorm2d(ngf * 2),
            nn.ReLU(True),
            # state size. (ngf*2) x 16 x 16
            nn.ConvTranspose2d(ngf * 2, ngf, 4, 2, 1, bias=False),
            nn.BatchNorm2d(ngf),
            nn.ReLU(True),
            # state size. (ngf) x 32 x 32
            nn.ConvTranspose2d(ngf, nc, 4, 2, 1, bias=False),
            nn.Tanh()
            # state size. (nc) x 64 x 64
        )

    def forward(self, input):
        if input.is_cuda and self.ngpu > 1:
            output = nn.parallel.data_parallel(self.main, input, range(self.ngpu))
        else:
            output = self.main(input)
        return output


netG = Generator(ngpu).to(device)
netG.apply(weights_init)
if opt.netG != '':
    netG.load_state_dict(torch.load(opt.netG))
print(netG)
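# A minimal sketch of sampling from a trained generator, assuming a checkpoint
# produced by a previous run (the filename below is illustrative):
#
#   netG.load_state_dict(torch.load('netG_epoch_24.pth'))
#   netG.eval()
#   with torch.no_grad():
#       samples = netG(torch.randn(64, nz, 1, 1, device=device))
#   vutils.save_image(samples, 'samples.png', normalize=True)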
class Discriminator(nn.Module):
    def __init__(self, ngpu):
        super(Discriminator, self).__init__()
        self.ngpu = ngpu
        self.main = nn.Sequential(
            # input is (nc) x 64 x 64
            nn.Conv2d(nc, ndf, 4, 2, 1, bias=False),
            nn.LeakyReLU(0.2, inplace=True),
            # state size. (ndf) x 32 x 32
            nn.Conv2d(ndf, ndf * 2, 4, 2, 1, bias=False),
            nn.BatchNorm2d(ndf * 2),
            nn.LeakyReLU(0.2, inplace=True),
            # state size. (ndf*2) x 16 x 16
            nn.Conv2d(ndf * 2, ndf * 4, 4, 2, 1, bias=False),
            nn.BatchNorm2d(ndf * 4),
            nn.LeakyReLU(0.2, inplace=True),
            # state size. (ndf*4) x 8 x 8
            nn.Conv2d(ndf * 4, ndf * 8, 4, 2, 1, bias=False),
            nn.BatchNorm2d(ndf * 8),
            nn.LeakyReLU(0.2, inplace=True),
            # state size. (ndf*8) x 4 x 4
            nn.Conv2d(ndf * 8, 1, 4, 1, 0, bias=False),
        )

    def forward(self, input):
        if input.is_cuda and self.ngpu > 1:
            output = nn.parallel.data_parallel(self.main, input, range(self.ngpu))
        else:
            output = self.main(input)
        return output.view(-1, 1).squeeze(1)


netD = Discriminator(ngpu).to(device)
netD.apply(weights_init)
if opt.netD != '':
    netD.load_state_dict(torch.load(opt.netD))
print(netD)

# BCEWithLogitsLoss folds the sigmoid into the loss, which is numerically
# safer under mixed precision than a separate Sigmoid + BCELoss.
criterion = nn.BCEWithLogitsLoss()

fixed_noise = torch.randn(opt.batchSize, nz, 1, 1, device=device)
real_label = 1
fake_label = 0

# setup optimizer
optimizerD = optim.Adam(netD.parameters(), lr=opt.lr, betas=(opt.beta1, 0.999))
optimizerG = optim.Adam(netG.parameters(), lr=opt.lr, betas=(opt.beta1, 0.999))

# Initialize both models and both optimizers in a single call; num_losses=3
# lets amp keep a separate loss scale for each of the three backward passes
# below (loss_id 0, 1 and 2).
[netD, netG], [optimizerD, optimizerG] = amp.initialize(
    [netD, netG], [optimizerD, optimizerG], opt_level=opt.opt_level, num_losses=3)

for epoch in range(opt.niter):
    for i, data in enumerate(dataloader, 0):
        ############################
        # (1) Update D network: maximize log(D(x)) + log(1 - D(G(z)))
        ###########################
        # train with real
        netD.zero_grad()
        real_cpu = data[0].to(device)
        batch_size = real_cpu.size(0)
        # BCEWithLogitsLoss expects float targets, so build the labels as float.
        label = torch.full((batch_size,), real_label, dtype=torch.float, device=device)

        output = netD(real_cpu)
        errD_real = criterion(output, label)
        with amp.scale_loss(errD_real, optimizerD, loss_id=0) as errD_real_scaled:
            errD_real_scaled.backward()
        D_x = output.mean().item()

        # train with fake
        noise = torch.randn(batch_size, nz, 1, 1, device=device)
        fake = netG(noise)
        label.fill_(fake_label)
        output = netD(fake.detach())
        errD_fake = criterion(output, label)
        with amp.scale_loss(errD_fake, optimizerD, loss_id=1) as errD_fake_scaled:
            errD_fake_scaled.backward()
        D_G_z1 = output.mean().item()
        errD = errD_real + errD_fake
        optimizerD.step()

        ############################
        # (2) Update G network: maximize log(D(G(z)))
        ###########################
        netG.zero_grad()
        label.fill_(real_label)  # fake labels are real for generator cost
        output = netD(fake)
        errG = criterion(output, label)
        with amp.scale_loss(errG, optimizerG, loss_id=2) as errG_scaled:
            errG_scaled.backward()
        D_G_z2 = output.mean().item()
        optimizerG.step()

        print('[%d/%d][%d/%d] Loss_D: %.4f Loss_G: %.4f D(x): %.4f D(G(z)): %.4f / %.4f'
              % (epoch, opt.niter, i, len(dataloader),
                 errD.item(), errG.item(), D_x, D_G_z1, D_G_z2))
        if i % 100 == 0:
            vutils.save_image(real_cpu,
                              '%s/real_samples.png' % opt.outf,
                              normalize=True)
            fake = netG(fixed_noise)
            vutils.save_image(fake.detach(),
                              '%s/amp_fake_samples_epoch_%03d.png' % (opt.outf, epoch),
                              normalize=True)

    # do checkpointing
    torch.save(netG.state_dict(), '%s/netG_epoch_%d.pth' % (opt.outf, epoch))
    torch.save(netD.state_dict(), '%s/netD_epoch_%d.pth' % (opt.outf, epoch))
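# Example invocation (assumes a CUDA-capable GPU with apex installed; the
# script filename is illustrative):
#
#   python main_amp.py --dataset cifar10 --dataroot ./data --niter 25 --opt_level O1
#
# CIFAR-10 is downloaded to --dataroot automatically on first use; generated
# sample grids and per-epoch checkpoints are written to --outf.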