test_adam.py

import copy
import math
import random
import unittest

import torch
import torch.nn.functional as F
from torch import nn
from torch.testing._internal.common_device_type import largeTensorTest

try:
    import apex
except ImportError as e:
    HAS_APEX = False
else:
    HAS_APEX = True
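

# A small LeNet-style CNN (1x28x28 input -> 10 logits) that serves as the shared
# workload so the reference and device-under-test optimizers see identical models.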
class Model(torch.nn.Module):
    def __init__(self):
        super(Model, self).__init__()
        self.conv1 = nn.Conv2d(1, 6, 5)
        self.relu1 = nn.ReLU()
        self.pool1 = nn.MaxPool2d(2)
        self.conv2 = nn.Conv2d(6, 16, 5)
        self.relu2 = nn.ReLU()
        self.pool2 = nn.MaxPool2d(2)
        self.fc1 = nn.Linear(256, 120)
        self.relu3 = nn.ReLU()
        self.fc2 = nn.Linear(120, 84)
        self.relu4 = nn.ReLU()
        self.fc3 = nn.Linear(84, 10)
        self.relu5 = nn.ReLU()

    def forward(self, x):
        y = self.conv1(x)
        y = self.relu1(y)
        y = self.pool1(y)
        y = self.conv2(y)
        y = self.relu2(y)
        y = self.pool2(y)
        y = y.reshape(y.shape[0], -1)
        y = self.fc1(y)
        y = self.relu3(y)
        y = self.fc2(y)
        y = self.relu4(y)
        y = self.fc3(y)
        y = self.relu5(y)
        return y
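

# Each test trains two copies of the same model for 100 steps: a reference driven by
# torch.optim.Adam and a device-under-test (DUT) driven by apex.optimizers.FusedAdam,
# and checks that weights and gradients agree within atol/rtol of 1e-3 at every step.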
@unittest.skipIf(not HAS_APEX, "`apex` is not found.")
class AdamTest(unittest.TestCase):
    def setUp(self, seed=0):
        super().setUp()
        torch.manual_seed(seed)

        self.model = Model().cuda()
        self.model_ = Model().cuda()
        self.model_.load_state_dict(copy.deepcopy(self.model.state_dict()))

        self.lr = 0.00001

        params = [p for p in self.model.parameters() if p.requires_grad]
        self.optimizer = torch.optim.Adam(params, lr=self.lr)
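
    # FusedAdam (capturable=False) vs. torch.optim.Adam under AMP autocast with a GradScaler.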
    def testGradScaler(self):
        params_ = [p for p in self.model_.parameters() if p.requires_grad]
        optimizer_ = apex.optimizers.FusedAdam(params_, lr=self.lr, capturable=False)
        scaler = torch.cuda.amp.GradScaler(enabled=True)
        scaler_ = torch.cuda.amp.GradScaler(enabled=True)

        for i in range(100):
            x = torch.rand([32, 1, 28, 28]).cuda().to(memory_format=torch.channels_last)
            x_ = x.clone()
            gt = torch.rand([32, 10]).cuda()
            gt_ = gt.clone()

            # Reference
            with torch.cuda.amp.autocast(enabled=True):
                y = self.model(x)
                loss = ((gt - y) ** 2).mean()

            scaler.scale(loss).backward()
            scaler.step(self.optimizer)
            scaler.update()

            # DUT
            with torch.cuda.amp.autocast(enabled=True):
                y = self.model_(x)
                loss_ = ((gt_ - y) ** 2).mean()

            scaler_.scale(loss_).backward()
            scaler_.step(optimizer_)
            scaler_.update()

            for module in zip(self.model.modules(), self.model_.modules()):
                m = module[0]
                m_ = module[1]
                if isinstance(m, nn.Conv2d) or isinstance(m_, nn.Linear):
                    torch.testing.assert_close(m.weight, m_.weight, atol=1e-3, rtol=1e-3, equal_nan=True)
                    torch.testing.assert_close(m.weight.grad, m_.weight.grad, atol=1e-3, rtol=1e-3, equal_nan=True)

            # Init for next iteration
            self.optimizer.zero_grad()
            optimizer_.zero_grad()

            self.model_.load_state_dict(copy.deepcopy(self.model.state_dict()))
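
    # Same comparison as testGradScaler, but using FusedAdam's capturable=True path,
    # i.e. the code path intended for use with CUDA Graph capture.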
    def testGradScalerCapturable(self):
        params_ = [p for p in self.model_.parameters() if p.requires_grad]
        optimizer_ = apex.optimizers.FusedAdam(params_, lr=self.lr, capturable=True)
        scaler = torch.cuda.amp.GradScaler(enabled=True)
        scaler_ = torch.cuda.amp.GradScaler(enabled=True)

        for i in range(100):
            x = torch.rand([32, 1, 28, 28]).cuda().to(memory_format=torch.channels_last)
            x_ = x.clone()
            gt = torch.rand([32, 10]).cuda()
            gt_ = gt.clone()

            # Reference
            with torch.cuda.amp.autocast(enabled=True):
                y = self.model(x)
                loss = ((gt - y) ** 2).mean()

            scaler.scale(loss).backward()
            scaler.step(self.optimizer)
            scaler.update()

            # DUT
            with torch.cuda.amp.autocast(enabled=True):
                y = self.model_(x)
                loss_ = ((gt_ - y) ** 2).mean()

            scaler_.scale(loss_).backward()
            scaler_.step(optimizer_)
            scaler_.update()

            for module in zip(self.model.modules(), self.model_.modules()):
                m = module[0]
                m_ = module[1]
                if isinstance(m, nn.Conv2d) or isinstance(m_, nn.Linear):
                    torch.testing.assert_close(m.weight, m_.weight, atol=1e-3, rtol=1e-3, equal_nan=True)
                    torch.testing.assert_close(m.weight.grad, m_.weight.grad, atol=1e-3, rtol=1e-3, equal_nan=True)

            # Init for next iteration
            self.optimizer.zero_grad()
            optimizer_.zero_grad()

            self.model_.load_state_dict(copy.deepcopy(self.model.state_dict()))
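
    # Capturable FusedAdam with master_weights=True: the DUT's conv layers run in FP16
    # while the optimizer keeps FP32 master copies, so the DUT's weights and gradients
    # are cast back with .float() before being compared against the FP32 reference.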
    def testGradScalerCapturableMaster(self):
        # Cast conv layers to FP16
        for m in self.model_.modules():
            if m.__class__ in [torch.nn.Conv2d]:
                m.half()
        params_ = [p for p in self.model_.parameters() if p.requires_grad]
        optimizer_ = apex.optimizers.FusedAdam(params_, lr=self.lr, capturable=True, master_weights=True)
        scaler = torch.cuda.amp.GradScaler(enabled=True)
        scaler_ = torch.cuda.amp.GradScaler(enabled=True)

        for i in range(100):
            x = torch.rand([32, 1, 28, 28]).cuda().to(memory_format=torch.channels_last)
            x_ = x.clone()
            gt = torch.rand([32, 10]).cuda()
            gt_ = gt.clone()

            # Reference
            with torch.cuda.amp.autocast(enabled=True):
                y = self.model(x)
                loss = ((gt - y) ** 2).mean()

            scaler.scale(loss).backward()
            scaler.step(self.optimizer)
            scaler.update()

            # DUT
            with torch.cuda.amp.autocast(enabled=True):
                y = self.model_(x)
                loss_ = ((gt_ - y) ** 2).mean()

            scaler_.scale(loss_).backward()
            scaler_.step(optimizer_)
            scaler_.update()

            for module in zip(self.model.modules(), self.model_.modules()):
                m = module[0]
                m_ = module[1]
                if isinstance(m, nn.Conv2d) or isinstance(m_, nn.Linear):
                    torch.testing.assert_close(m.weight, m_.weight.float(), atol=1e-3, rtol=1e-3, equal_nan=True)
                    torch.testing.assert_close(m.weight.grad, m_.weight.grad.float(), atol=1e-3, rtol=1e-3, equal_nan=True)

            # Init for next iteration
            self.optimizer.zero_grad()
            optimizer_.zero_grad()

            self.model_.load_state_dict(copy.deepcopy(self.model.state_dict()))
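
    # Full-precision comparison without autocast or GradScaler.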
    def testNative(self):
        params_ = [p for p in self.model_.parameters() if p.requires_grad]
        optimizer_ = apex.optimizers.FusedAdam(params_, lr=self.lr, capturable=False)

        for i in range(100):
            x = torch.rand([32, 1, 28, 28]).cuda().to(memory_format=torch.channels_last)
            x_ = x.clone()
            gt = torch.rand([32, 10]).cuda()
            gt_ = gt.clone()

            # Reference
            y = self.model(x)
            loss = ((gt - y) ** 2).mean()

            loss.backward()
            self.optimizer.step()

            # DUT
            y = self.model_(x)
            loss_ = ((gt_ - y) ** 2).mean()

            loss_.backward()
            optimizer_.step()

            for module in zip(self.model.modules(), self.model_.modules()):
                m = module[0]
                m_ = module[1]
                if isinstance(m, nn.Conv2d) or isinstance(m_, nn.Linear):
                    torch.testing.assert_close(m.weight, m_.weight, atol=1e-3, rtol=1e-3, equal_nan=True)
                    torch.testing.assert_close(m.weight.grad, m_.weight.grad, atol=1e-3, rtol=1e-3, equal_nan=True)

            # Init for next iteration
            self.optimizer.zero_grad()
            optimizer_.zero_grad()

            self.model_.load_state_dict(copy.deepcopy(self.model.state_dict()))
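
    # Single parameter with more than 2**31 elements, presumably to exercise 64-bit
    # indexing in the fused kernel; the decorator gates the test on ~60 GB of GPU memory.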
    @largeTensorTest('60GB', 'cuda')
    def testLargeTensor(self):
        t = torch.zeros(2359332864, dtype=torch.half, device='cuda')
        t2 = torch.zeros(2359332864, dtype=torch.half, device='cuda')
        grad = torch.randn_like(t)
        t.grad = grad
        t2.grad = grad
        params = [t]
        params2 = [t2]
        optimizer = apex.optimizers.FusedAdam(params, lr=self.lr)
        optimizer.step()
        optimizer2 = torch.optim.Adam(params2, lr=self.lr)
        # Step the reference optimizer as well so the two tensors can be compared.
        optimizer2.step()
        torch.testing.assert_close(t, t2)
        torch.cuda.synchronize()


if __name__ == '__main__':
    unittest.main()