test_cache.py

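# Tests for amp's opt_level O1 parameter-cast cache: each test runs a short
# train -> eval -> train sequence on a one-weight module and checks that the
# weight's gradient matches an FP32 reference computed without the cache
# (see train_eval_train_test below).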
import unittest
import functools as ft
import itertools as it

from apex import amp
from apex.amp import _amp_state
import torch
from torch import nn
import torch.nn.functional as F

from utils import common_init, HALF, FLOAT, \
    ALWAYS_HALF, ALWAYS_FLOAT, MATCH_INPUT


def get_reference_grad(i, w, ops):
    # Creating new tensors ensures, among other things, that the new tensors are not in the cache.
    # In fact, they are guaranteed not to use the cache because they are not torch.nn.Parameters.
    fp32_i = i.detach().clone().float()
    fp32_w = w.detach().clone().float().requires_grad_()
    loss = ops(fp32_i, fp32_w)
    loss.backward()
    return fp32_w.grad
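

# mm is on amp's FP16 ("whitelist") list, so under O1 its inputs are cast to
# half precision; an FP32 weight goes through amp's cast cache on each forward.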
class WhitelistModule(torch.nn.Module):
    def __init__(self, dtype):
        super(WhitelistModule, self).__init__()
        self.weight = torch.nn.Parameter(torch.arange(8*8, device='cuda', dtype=dtype).view(8, 8))

    @staticmethod
    def ops(input, weight):
        return (input.mm(weight)).mm(weight).sum()

    def forward(self, input):
        return self.ops(input, self.weight)
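

# pow is on amp's FP32 ("blacklist") list, so its arguments should be computed
# in FP32 regardless of the weight's storage dtype.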
class BlacklistModule(torch.nn.Module):
    def __init__(self, dtype):
        super(BlacklistModule, self).__init__()
        self.weight = torch.nn.Parameter(torch.arange(2*8, device='cuda', dtype=dtype).view(2, 8))

    @staticmethod
    def ops(input, weight):
        return (input + torch.pow(weight, 2) + torch.pow(weight, 2)).sum()

    def forward(self, input):
        return self.ops(input, self.weight)
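

# Elementwise multiply is one of amp's "promote" ops: operands are cast to the
# widest participating dtype (FP32 here, since the test input is FP32).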
class PromoteModule(torch.nn.Module):
    def __init__(self, dtype):
        super(PromoteModule, self).__init__()
        self.weight = torch.nn.Parameter(torch.arange(2*8, device='cuda', dtype=dtype).view(2, 8))

    @staticmethod
    def ops(input, weight):
        return ((input * weight) * weight).sum()

    def forward(self, input):
        return self.ops(input, self.weight)


class TestCache(unittest.TestCase):
    def setUp(self):
        self.x = torch.ones((2, 8), device='cuda', dtype=torch.float32)
        common_init(self)

    def tearDown(self):
        pass

    def train_eval_train_test(self, module, t):
        model = module(t).cuda()
        optimizer = torch.optim.SGD(model.parameters(), lr=1.0)
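
        # Some tests construct the model with FP16 weights, which amp.initialize
        # would normally warn about at O1; temporarily allow it for this call.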
        _amp_state.allow_incoming_model_not_fp32 = True
        model, optimizer = amp.initialize(model, optimizer, opt_level="O1", verbosity=0)
        _amp_state.allow_incoming_model_not_fp32 = False
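
        # One optimization step: zero grads, run forward/backward under amp.scale_loss,
        # compare the weight grad against a fresh FP32 reference, then perturb the
        # master weight so a stale cached cast would surface on the next step.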
        def training_step():
            for param in model.parameters():
                param.grad = None

            loss = model(self.x).sum()
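            # Pin the loss scale at a known value so the scale/unscale path
            # behaves deterministically across steps.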
            _amp_state.loss_scalers[0]._loss_scale = 4.0
            with amp.scale_loss(loss, optimizer) as scaled_loss:
                scaled_loss.backward()

            self.assertEqual(len([p.grad for p in model.parameters() if p.grad is not None]), 1)
            self.assertEqual(model.weight.grad.type(), model.weight.type())

            reference_grad = get_reference_grad(self.x, model.weight, model.ops)
            # Currently there's no difference between the two assert_close calls, so no need for
            # branching, but I'm keeping this in case we want different tolerances for fp16 and
            # fp32 checks.
            if model.weight.grad.type() == "torch.cuda.HalfTensor":
                torch.testing.assert_close(model.weight.grad.float(), reference_grad)
            elif model.weight.grad.type() == "torch.cuda.FloatTensor":
                torch.testing.assert_close(model.weight.grad.float(), reference_grad)
            else:
                raise RuntimeError("model.weight.grad.type = {}".format(model.weight.grad.type()))
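
            # Changing the master weight in place means the cached FP16 copy must be
            # refreshed; a stale cache entry would show up as a wrong gradient in the
            # checks above on the following step.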
            model.weight.data -= 1.

        # Simulates first epoch
        training_step()

        # Simulates eval
        with torch.no_grad():
            loss = model(self.x).sum()

        # Simulates resuming training after eval
        training_step()

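        # Undo amp's patching so subsequent tests start from a clean state.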
        _amp_state.handle._deactivate()

    # I could easily have these as a set of for loops in a single test,
    # instead of going for granularity.
    def test_whitelist_module_fp16_weight(self):
        self.train_eval_train_test(WhitelistModule, torch.float16)

    def test_whitelist_module_fp32_weight(self):
        self.train_eval_train_test(WhitelistModule, torch.float32)

    def test_blacklist_module_fp16_weight(self):
        self.train_eval_train_test(BlacklistModule, torch.float16)

    def test_blacklist_module_fp32_weight(self):
        self.train_eval_train_test(BlacklistModule, torch.float32)

    def test_promote_module_fp16_weight(self):
        self.train_eval_train_test(PromoteModule, torch.float16)

    def test_promote_module_fp32_weight(self):
        self.train_eval_train_test(PromoteModule, torch.float32)


if __name__ == '__main__':
    unittest.main()