test_checkpointing.py 12 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267
  1. import unittest
  2. import torch
  3. import torch.nn as nn
  4. import torch.nn.functional as F
  5. import torch.optim as optim
  6. from apex import amp
  7. from utils import common_init, FLOAT
  8. class MyModel(torch.nn.Module):
  9. def __init__(self):
  10. super(MyModel, self).__init__()
  11. self.conv1 = nn.Conv2d(3, 6, 3, 1, 1)
  12. self.bn1 = nn.BatchNorm2d(6)
  13. self.param = nn.Parameter(torch.randn(1))
  14. def forward(self, x):
  15. x = x * self.param
  16. x = F.relu(self.conv1(x))
  17. x = self.bn1(x)
  18. return x
  19. class TestCheckpointing(unittest.TestCase):
  20. def setUp(self):
  21. self.initial_lr = 1e-3
  22. self.test_opt_levels = ("O0", "O1", "O2", "O3")
  23. def seed(self):
  24. torch.manual_seed(2809)
  25. torch.backends.cudnn.benchmark = False
  26. torch.backends.cudnn.deterministic = True
  27. def check_state_dict_fp32(self, state_dict):
  28. for key in state_dict:
  29. if 'num_batches_tracked' in key:
  30. continue
  31. param = state_dict[key]
  32. self.assertEqual(param.type(), FLOAT,
  33. 'Parameter in state_dict not FLOAT')
  34. def train_step(self, model, optimizer, data, loss_ids):
  35. optimizer.zero_grad()
  36. output = model(data)
  37. # Call backward for num_losses-1
  38. for idx in loss_ids:
  39. loss = output.mean()
  40. with amp.scale_loss(loss, optimizer, loss_id=idx) as scaled_loss:
  41. scaled_loss.backward(retain_graph=True)
  42. optimizer.step()
  43. return output
  44. def compare_models(self, modelA, modelB, test_setup=''):
  45. state_dictA = modelA.state_dict()
  46. state_dictB = modelB.state_dict()
  47. self.assertEqual(len(state_dictA), len(state_dictB),
  48. 'state_dicts have different lengths' + test_setup)
  49. for key in state_dictA:
  50. paramA = state_dictA[key]
  51. paramB = state_dictB[key]
  52. self.assertTrue((paramA==paramB).all(),
  53. msg='Parameters in state_dices not equal.' +
  54. 'key: {}\nparam: {}\nrestored: {}\ndiff: {} for {}'.format(
  55. key, paramA, paramB, paramA - paramB, test_setup))
  56. def test_restoring(self):
  57. nb_epochs = 10
  58. nb_epochs_restore = nb_epochs // 2
  59. for opt_level in self.test_opt_levels:
  60. for res_opt_level in self.test_opt_levels:
  61. for amp_before_load in [True, False]:
  62. for num_losses in range(1, 3):
  63. test_setup = ('#' * 75 + '\n' + \
  64. f'opt_level {opt_level}\n' + \
  65. f'restore_opt_level {res_opt_level}\n' + \
  66. f'amp_before_load {amp_before_load}\n' + \
  67. f'num_losses {num_losses}\n')
  68. self.seed()
  69. # Create reference model
  70. model = MyModel().to('cuda')
  71. optimizer = optim.SGD(model.parameters(),
  72. lr=self.initial_lr)
  73. # Initialize with num_losses*2 for the original model and the restored one
  74. model, optimizer = amp.initialize(
  75. model, optimizer, opt_level=opt_level,
  76. num_losses=num_losses*2, verbosity=0)
  77. # Compare training behavior for same restore option
  78. # We cannot really generalize it, since a saved model in O0
  79. # would introduce a skipped step in O1, which will raise an error
  80. if opt_level == res_opt_level:
  81. # train for nb_epochs and restore after nb_epochs_restore
  82. for epoch in range(nb_epochs):
  83. x = torch.randn(16, 3, 24, 24, device='cuda')
  84. output = self.train_step(
  85. model, optimizer, x, range(num_losses))
  86. # Initialize model one step before comparing.
  87. # Otherwise the batchnorm layers will be updated
  88. # additionally in restore_model
  89. if epoch == (nb_epochs_restore - 1):
  90. # Load model and optimizer
  91. checkpoint = {
  92. 'model': model.state_dict(),
  93. 'optimizer': optimizer.state_dict(),
  94. 'amp': amp.state_dict()
  95. }
  96. # Check state_dict for FP32 tensors
  97. self.check_state_dict_fp32(checkpoint['model'])
  98. # Restore model
  99. restore_model = MyModel().to('cuda')
  100. restore_optimizer = optim.SGD(
  101. restore_model.parameters(),
  102. lr=self.initial_lr)
  103. if amp_before_load:
  104. restore_model, restore_optimizer = amp.initialize(
  105. restore_model,
  106. restore_optimizer,
  107. opt_level=res_opt_level,
  108. num_losses=num_losses*2,
  109. verbosity=0)
  110. restore_model.load_state_dict(checkpoint['model'])
  111. restore_optimizer.load_state_dict(checkpoint['optimizer'])
  112. # FIXME: We cannot test the amp.state_dict in the same script
  113. # amp.load_state_dict(checkpoint['amp'])
  114. if not amp_before_load:
  115. restore_model, restore_optimizer = amp.initialize(
  116. restore_model,
  117. restore_optimizer,
  118. opt_level=res_opt_level,
  119. num_losses=num_losses*2,
  120. verbosity=0)
  121. elif epoch >= nb_epochs_restore:
  122. restore_output = self.train_step(
  123. restore_model,
  124. restore_optimizer,
  125. x,
  126. range(num_losses, num_losses*2))
  127. self.assertTrue(
  128. torch.allclose(output.float(), restore_output.float()),
  129. 'Output of reference and restored models differ for ' + test_setup)
  130. self.compare_models(model, restore_model, test_setup)
  131. # if opt_level != res_opt_level
  132. else:
  133. # skip tests for different opt_levels
  134. continue
  135. def test_loss_scale_decrease(self):
  136. num_losses = 3
  137. nb_decrease_loss_scales = [0, 1, 2]
  138. for opt_level in self.test_opt_levels:
  139. #print('#' * 75 + f'\n opt_level {opt_level}\n')
  140. # Create new tmp copy for this run
  141. nb_decrease_loss_scales_tmp = list(nb_decrease_loss_scales)
  142. model = MyModel().to('cuda')
  143. optimizer = optim.SGD(model.parameters(),
  144. lr=self.initial_lr)
  145. model, optimizer = amp.initialize(
  146. model, optimizer, opt_level=opt_level, num_losses=num_losses,
  147. verbosity=0)
  148. if amp._amp_state.opt_properties.loss_scale != 'dynamic':
  149. #print('Static loss scale set. Skipping opt_level.')
  150. continue
  151. # force to skip some updates to decrease the loss_scale
  152. initial_loss_scales = []
  153. for idx in range(num_losses):
  154. initial_loss_scales.append(
  155. amp._amp_state.loss_scalers[idx].loss_scale())
  156. for _ in range(len(nb_decrease_loss_scales)):
  157. x = torch.randn(16, 3, 24, 24, device='cuda')
  158. for idx in range(num_losses):
  159. while nb_decrease_loss_scales_tmp[idx] > 0:
  160. optimizer.zero_grad()
  161. output = model(x * 2**17)
  162. loss = output.mean()
  163. with amp.scale_loss(loss, optimizer, loss_id=idx) as scaled_loss:
  164. scaled_loss.backward(retain_graph=True)
  165. optimizer.step()
  166. nb_decrease_loss_scales_tmp[idx] -= 1
  167. # Check loss scales afterwards
  168. updated_loss_scales = []
  169. for idx in range(num_losses):
  170. updated_loss_scales.append(
  171. amp._amp_state.loss_scalers[idx].loss_scale())
  172. for factor, update_ls, init_ls in zip(nb_decrease_loss_scales,
  173. updated_loss_scales,
  174. initial_loss_scales):
  175. self.assertEqual(update_ls, init_ls / 2**factor)
  176. # Check state dict
  177. amp_state_dict = amp.state_dict()
  178. for scaler_idx, factor, init_ls in zip(amp_state_dict,
  179. nb_decrease_loss_scales,
  180. initial_loss_scales):
  181. scaler = amp_state_dict[scaler_idx]
  182. self.assertEqual(scaler['loss_scale'], init_ls / 2**factor)
  183. unskipped_target = 0
  184. self.assertEqual(scaler['unskipped'], unskipped_target)
  185. def test_state_dict(self):
  186. for opt_level in self.test_opt_levels:
  187. # Skip O3
  188. if opt_level == 'O3':
  189. continue
  190. model = MyModel().to('cuda')
  191. optimizer = optim.Adam(model.parameters(), lr=1e-3)
  192. model, optimizer = amp.initialize(
  193. model, optimizer, opt_level=opt_level, verbosity=0)
  194. # Export state_dict and check for Half
  195. state_dict = model.state_dict()
  196. for key in state_dict:
  197. self.assertFalse('Half' in state_dict[key].type())
  198. # Check, if model is still trainable
  199. # Create dummy data
  200. data = torch.randn(10, 3, 4, 4, device='cuda')
  201. target = torch.randn(10, 6, 4, 4, device='cuda')
  202. # Get initnial loss
  203. optimizer.zero_grad()
  204. output = model(data)
  205. loss = F.mse_loss(output, target)
  206. with amp.scale_loss(loss, optimizer) as scaled_loss:
  207. scaled_loss.backward()
  208. optimizer.step()
  209. last_loss = loss.item()
  210. # train for some epochs
  211. for epoch in range(10):
  212. optimizer.zero_grad()
  213. output = model(data)
  214. loss = F.mse_loss(output, target)
  215. with amp.scale_loss(loss, optimizer) as scaled_loss:
  216. scaled_loss.backward()
  217. optimizer.step()
  218. self.assertTrue(loss.item() < last_loss)
  219. last_loss = loss.item()
  220. if __name__=='__main__':
  221. unittest.main()