main_amp.py

from __future__ import print_function
import argparse
import os
import random

import torch
import torch.nn as nn
import torch.nn.parallel
import torch.backends.cudnn as cudnn
import torch.optim as optim
import torch.utils.data
import torchvision.datasets as dset
import torchvision.transforms as transforms
import torchvision.utils as vutils

try:
    from apex import amp
except ImportError:
    raise ImportError("Please install apex from https://www.github.com/nvidia/apex to run this example.")
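
# Apex AMP wraps the models and optimizers (amp.initialize) and scales each
# loss (amp.scale_loss) so fp16 gradients don't underflow; see the training
# loop below for how the two calls fit together.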
parser = argparse.ArgumentParser()
parser.add_argument('--dataset', default='cifar10', help='cifar10 | lsun | mnist | imagenet | folder | lfw | fake')
parser.add_argument('--dataroot', default='./', help='path to dataset')
parser.add_argument('--workers', type=int, help='number of data loading workers', default=2)
parser.add_argument('--batchSize', type=int, default=64, help='input batch size')
parser.add_argument('--imageSize', type=int, default=64, help='the height / width of the input image to network')
parser.add_argument('--nz', type=int, default=100, help='size of the latent z vector')
parser.add_argument('--ngf', type=int, default=64)
parser.add_argument('--ndf', type=int, default=64)
parser.add_argument('--niter', type=int, default=25, help='number of epochs to train for')
parser.add_argument('--lr', type=float, default=0.0002, help='learning rate, default=0.0002')
parser.add_argument('--beta1', type=float, default=0.5, help='beta1 for adam. default=0.5')
parser.add_argument('--ngpu', type=int, default=1, help='number of GPUs to use')
parser.add_argument('--netG', default='', help="path to netG (to continue training)")
parser.add_argument('--netD', default='', help="path to netD (to continue training)")
parser.add_argument('--outf', default='.', help='folder to output images and model checkpoints')
parser.add_argument('--manualSeed', type=int, help='manual seed')
parser.add_argument('--classes', default='bedroom', help='comma separated list of classes for the lsun data set')
parser.add_argument('--opt_level', default='O1', help='amp opt_level, default="O1"')
opt = parser.parse_args()
print(opt)
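
# Example invocation (assumes a CUDA GPU and an installed apex; paths illustrative):
#   python main_amp.py --dataset cifar10 --dataroot ./data --opt_level O1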

try:
    os.makedirs(opt.outf)
except OSError:
    pass

if opt.manualSeed is None:
    opt.manualSeed = 2809
print("Random Seed: ", opt.manualSeed)
random.seed(opt.manualSeed)
torch.manual_seed(opt.manualSeed)
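
# cudnn.benchmark lets cuDNN autotune convolution kernels; a safe win here
# because every batch has the same fixed opt.imageSize shape.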
cudnn.benchmark = True

if opt.dataset in ['imagenet', 'folder', 'lfw']:
    # folder dataset
    dataset = dset.ImageFolder(root=opt.dataroot,
                               transform=transforms.Compose([
                                   transforms.Resize(opt.imageSize),
                                   transforms.CenterCrop(opt.imageSize),
                                   transforms.ToTensor(),
                                   transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
                               ]))
    nc = 3
elif opt.dataset == 'lsun':
    classes = [c + '_train' for c in opt.classes.split(',')]
    dataset = dset.LSUN(root=opt.dataroot, classes=classes,
                        transform=transforms.Compose([
                            transforms.Resize(opt.imageSize),
                            transforms.CenterCrop(opt.imageSize),
                            transforms.ToTensor(),
                            transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
                        ]))
    nc = 3
elif opt.dataset == 'cifar10':
    dataset = dset.CIFAR10(root=opt.dataroot, download=True,
                           transform=transforms.Compose([
                               transforms.Resize(opt.imageSize),
                               transforms.ToTensor(),
                               transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
                           ]))
    nc = 3
elif opt.dataset == 'mnist':
    dataset = dset.MNIST(root=opt.dataroot, download=True,
                         transform=transforms.Compose([
                             transforms.Resize(opt.imageSize),
                             transforms.ToTensor(),
                             transforms.Normalize((0.5,), (0.5,)),
                         ]))
    nc = 1
elif opt.dataset == 'fake':
    dataset = dset.FakeData(image_size=(3, opt.imageSize, opt.imageSize),
                            transform=transforms.ToTensor())
    nc = 3
assert dataset

dataloader = torch.utils.data.DataLoader(dataset, batch_size=opt.batchSize,
                                         shuffle=True, num_workers=int(opt.workers))
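
# Apex AMP is CUDA-only, so this example pins everything to GPU 0.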
device = torch.device("cuda:0")
ngpu = int(opt.ngpu)
nz = int(opt.nz)
ngf = int(opt.ngf)
ndf = int(opt.ndf)
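
# Following the DCGAN paper: conv weights are drawn from N(0, 0.02) and
# BatchNorm scale from N(1.0, 0.02) with zero bias.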
# custom weights initialization called on netG and netD
def weights_init(m):
    classname = m.__class__.__name__
    if classname.find('Conv') != -1:
        m.weight.data.normal_(0.0, 0.02)
    elif classname.find('BatchNorm') != -1:
        m.weight.data.normal_(1.0, 0.02)
        m.bias.data.fill_(0)

class Generator(nn.Module):
    def __init__(self, ngpu):
        super(Generator, self).__init__()
        self.ngpu = ngpu
        self.main = nn.Sequential(
            # input is Z, going into a convolution
            nn.ConvTranspose2d(nz, ngf * 8, 4, 1, 0, bias=False),
            nn.BatchNorm2d(ngf * 8),
            nn.ReLU(True),
            # state size. (ngf*8) x 4 x 4
            nn.ConvTranspose2d(ngf * 8, ngf * 4, 4, 2, 1, bias=False),
            nn.BatchNorm2d(ngf * 4),
            nn.ReLU(True),
            # state size. (ngf*4) x 8 x 8
            nn.ConvTranspose2d(ngf * 4, ngf * 2, 4, 2, 1, bias=False),
            nn.BatchNorm2d(ngf * 2),
            nn.ReLU(True),
            # state size. (ngf*2) x 16 x 16
            nn.ConvTranspose2d(ngf * 2, ngf, 4, 2, 1, bias=False),
            nn.BatchNorm2d(ngf),
            nn.ReLU(True),
            # state size. (ngf) x 32 x 32
            nn.ConvTranspose2d(ngf, nc, 4, 2, 1, bias=False),
            nn.Tanh()
            # state size. (nc) x 64 x 64
        )
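
    # data_parallel replicates the module and scatters the batch across
    # range(ngpu) devices on every call; with ngpu == 1 we skip the overhead.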
    def forward(self, input):
        if input.is_cuda and self.ngpu > 1:
            output = nn.parallel.data_parallel(self.main, input, range(self.ngpu))
        else:
            output = self.main(input)
        return output

netG = Generator(ngpu).to(device)
netG.apply(weights_init)
if opt.netG != '':
    netG.load_state_dict(torch.load(opt.netG))
print(netG)

class Discriminator(nn.Module):
    def __init__(self, ngpu):
        super(Discriminator, self).__init__()
        self.ngpu = ngpu
        self.main = nn.Sequential(
            # input is (nc) x 64 x 64
            nn.Conv2d(nc, ndf, 4, 2, 1, bias=False),
            nn.LeakyReLU(0.2, inplace=True),
            # state size. (ndf) x 32 x 32
            nn.Conv2d(ndf, ndf * 2, 4, 2, 1, bias=False),
            nn.BatchNorm2d(ndf * 2),
            nn.LeakyReLU(0.2, inplace=True),
            # state size. (ndf*2) x 16 x 16
            nn.Conv2d(ndf * 2, ndf * 4, 4, 2, 1, bias=False),
            nn.BatchNorm2d(ndf * 4),
            nn.LeakyReLU(0.2, inplace=True),
            # state size. (ndf*4) x 8 x 8
            nn.Conv2d(ndf * 4, ndf * 8, 4, 2, 1, bias=False),
            nn.BatchNorm2d(ndf * 8),
            nn.LeakyReLU(0.2, inplace=True),
            # state size. (ndf*8) x 4 x 4
            nn.Conv2d(ndf * 8, 1, 4, 1, 0, bias=False),
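            # no Sigmoid here: the raw logit feeds BCEWithLogitsLoss below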
        )

    def forward(self, input):
        if input.is_cuda and self.ngpu > 1:
            output = nn.parallel.data_parallel(self.main, input, range(self.ngpu))
        else:
            output = self.main(input)
        return output.view(-1, 1).squeeze(1)

netD = Discriminator(ngpu).to(device)
netD.apply(weights_init)
if opt.netD != '':
    netD.load_state_dict(torch.load(opt.netD))
print(netD)
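
# BCEWithLogitsLoss fuses the sigmoid into the loss, which is numerically
# safer than a separate nn.Sigmoid when activations may be fp16 under AMP.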
criterion = nn.BCEWithLogitsLoss()

fixed_noise = torch.randn(opt.batchSize, nz, 1, 1, device=device)
real_label = 1
fake_label = 0

# setup optimizer
optimizerD = optim.Adam(netD.parameters(), lr=opt.lr, betas=(opt.beta1, 0.999))
optimizerG = optim.Adam(netG.parameters(), lr=opt.lr, betas=(opt.beta1, 0.999))
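
# amp.initialize also accepts lists of models and optimizers. num_losses=3
# requests a separate loss scaler for each of the three backward passes below
# (D on real, D on fake, G), so one overflowing loss doesn't throttle the others.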
[netD, netG], [optimizerD, optimizerG] = amp.initialize(
    [netD, netG], [optimizerD, optimizerG], opt_level=opt.opt_level, num_losses=3)

for epoch in range(opt.niter):
    for i, data in enumerate(dataloader, 0):
        ############################
        # (1) Update D network: maximize log(D(x)) + log(1 - D(G(z)))
        ###########################
        # train with real
        netD.zero_grad()
        real_cpu = data[0].to(device)
        batch_size = real_cpu.size(0)
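        # BCEWithLogitsLoss needs float targets; without an explicit dtype,
        # torch.full with an integer fill value builds a LongTensor on newer PyTorch.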
        label = torch.full((batch_size,), real_label, dtype=torch.float32, device=device)
        output = netD(real_cpu)
        errD_real = criterion(output, label)
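        # amp.scale_loss multiplies the loss by the current loss scale so fp16
        # gradients don't underflow, and unscales them before optimizer.step();
        # loss_id selects this backward pass's dedicated scaler.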
        with amp.scale_loss(errD_real, optimizerD, loss_id=0) as errD_real_scaled:
            errD_real_scaled.backward()
        D_x = output.mean().item()

        # train with fake
        noise = torch.randn(batch_size, nz, 1, 1, device=device)
        fake = netG(noise)
        label.fill_(fake_label)
        output = netD(fake.detach())
        errD_fake = criterion(output, label)
        with amp.scale_loss(errD_fake, optimizerD, loss_id=1) as errD_fake_scaled:
            errD_fake_scaled.backward()
        D_G_z1 = output.mean().item()
        errD = errD_real + errD_fake
        optimizerD.step()

        ############################
        # (2) Update G network: maximize log(D(G(z)))
        ###########################
        netG.zero_grad()
        label.fill_(real_label)  # fake labels are real for generator cost
        output = netD(fake)
        errG = criterion(output, label)
        with amp.scale_loss(errG, optimizerG, loss_id=2) as errG_scaled:
            errG_scaled.backward()
        D_G_z2 = output.mean().item()
        optimizerG.step()
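
        # D_G_z1 vs D_G_z2: D's mean score on the same fakes before and after
        # the discriminator update.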
        print('[%d/%d][%d/%d] Loss_D: %.4f Loss_G: %.4f D(x): %.4f D(G(z)): %.4f / %.4f'
              % (epoch, opt.niter, i, len(dataloader),
                 errD.item(), errG.item(), D_x, D_G_z1, D_G_z2))
        if i % 100 == 0:
            vutils.save_image(real_cpu,
                              '%s/real_samples.png' % opt.outf,
                              normalize=True)
            fake = netG(fixed_noise)
            vutils.save_image(fake.detach(),
                              '%s/amp_fake_samples_epoch_%03d.png' % (opt.outf, epoch),
                              normalize=True)

    # do checkpointing
    torch.save(netG.state_dict(), '%s/netG_epoch_%d.pth' % (opt.outf, epoch))
    torch.save(netD.state_dict(), '%s/netD_epoch_%d.pth' % (opt.outf, epoch))