test_add_param_group.py

import unittest

import functools as ft
import itertools as it

from apex import amp
from apex.amp import _amp_state
import torch
from torch import nn
import torch.nn.functional as F
from torch.nn import Parameter

from utils import common_init, HALF, FLOAT, \
    ALWAYS_HALF, ALWAYS_FLOAT, MATCH_INPUT

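# Toy model with one fp32 parameter and one fp16 parameter; `unique` offsets
# the initial values so that separate MyModel instances differ.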
class MyModel(torch.nn.Module):
    def __init__(self, unique):
        super(MyModel, self).__init__()
        self.weight0 = Parameter(unique +
            torch.arange(2, device='cuda', dtype=torch.float32))
        self.weight1 = Parameter(1. + unique +
            torch.arange(2, device='cuda', dtype=torch.float16))

    @staticmethod
    def ops(input, weight0, weight1):
        return ((input*(weight0.float()))*(weight1.float())).sum()

    def forward(self, input):
        return self.ops(input, self.weight0, self.weight1)

# Abandon all hope, ye who enter here.

class TestAddParamGroup(unittest.TestCase):
    def setUp(self):
        self.x = torch.ones((2), device='cuda', dtype=torch.float32)
        common_init(self)

    def tearDown(self):
        pass

    def zero_grad(self, models, optimizer, how_to_zero):
        if how_to_zero == "none":
            for model in models:
                for param in model.parameters():
                    param.grad = None
        elif how_to_zero == "model":
            for model in models:
                model.zero_grad()
        elif how_to_zero == "optimizer":
            optimizer.zero_grad()

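    # Plan: for each opt_level, run a reference sequence in plain fp32
    # (step, add_param_group, then two more steps), then repeat the same
    # sequence under amp for each gradient-zeroing strategy and check that
    # the final parameters match the reference.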
    def test_add_param_group(self):
        for opt_level in ("O0", "O1", "O2", "O3"):
            for zero_before_add in (True, False):
                for try_accumulation in (True, False):
                    model0 = MyModel(1)
                    model1 = MyModel(2)

                    optimizer = torch.optim.SGD([{'params' : model0.parameters(), 'lr' : 0.25}],
                                                momentum=0.125)

                    optimizer.zero_grad()
                    loss = model0(self.x)
                    loss.backward()
                    optimizer.step()

                    if zero_before_add:
                        optimizer.zero_grad()
                    optimizer.add_param_group({'params' : model1.parameters(), 'lr' : 0.5})
                    if not zero_before_add:
                        optimizer.zero_grad()

                    loss = model0(self.x) + model1(self.x)
                    loss.backward(retain_graph=try_accumulation)
                    if try_accumulation:
                        loss.backward()
                    optimizer.step()

                    # Once more to make sure the new params pick up momentums properly
                    optimizer.zero_grad()
                    loss = model0(self.x) + model1(self.x)
                    loss.backward(retain_graph=try_accumulation)
                    if try_accumulation:
                        loss.backward()
                    optimizer.step()

                    reference_params = [param.data.clone() for param in model0.parameters()] + \
                                       [param.data.clone() for param in model1.parameters()]

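                    # Repeat the same sequence under amp, once per gradient-zeroing
                    # strategy, and compare the resulting parameters to the reference.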
                    for how_to_zero in "none", "model", "optimizer":
                        model0 = MyModel(1)
                        model1 = MyModel(2)

                        optimizer = torch.optim.SGD([{'params' : model0.parameters(), 'lr' : 0.25}],
                                                    momentum=0.125)

                        _amp_state.allow_incoming_model_not_fp32 = True
                        [model0, model1], optimizer = amp.initialize([model0, model1],
                                                                     optimizer,
                                                                     opt_level=opt_level,
                                                                     verbosity=0,
                                                                     cast_model_type=False)
                        _amp_state.allow_incoming_model_not_fp32 = False

                        _amp_state.loss_scalers[0]._loss_scale = 4.0

                        self.zero_grad([model0, model1], optimizer, how_to_zero)
                        loss = model0(self.x)
                        with amp.scale_loss(loss, optimizer) as scaled_loss:
                            scaled_loss.backward()
                        optimizer.step()

                        if zero_before_add:
                            self.zero_grad([model0, model1], optimizer, how_to_zero)
                        optimizer.add_param_group({'params' : model1.parameters(), 'lr' : 0.5})
                        if not zero_before_add:
                            self.zero_grad([model0, model1], optimizer, how_to_zero)

                        loss = model0(self.x) + model1(self.x)
                        with amp.scale_loss(loss, optimizer) as scaled_loss:
                            scaled_loss.backward(retain_graph=try_accumulation)
                        if try_accumulation:
                            with amp.scale_loss(loss, optimizer) as scaled_loss:
                                scaled_loss.backward()
                        optimizer.step()

                        # Once more to make sure the new params pick up momentums properly
                        self.zero_grad([model0, model1], optimizer, how_to_zero)
                        loss = model0(self.x) + model1(self.x)
                        with amp.scale_loss(loss, optimizer) as scaled_loss:
                            scaled_loss.backward(retain_graph=try_accumulation)
                        if try_accumulation:
                            with amp.scale_loss(loss, optimizer) as scaled_loss:
                                scaled_loss.backward()
                        optimizer.step()

                        final_params = [param.data.clone() for param in model0.parameters()] + \
                                       [param.data.clone() for param in model1.parameters()]

                        for reference, final in zip(reference_params, final_params):
                            torch.testing.assert_close(reference.to(final.dtype), final,
                                msg="opt_level = {}, how_to_zero = {}, zero_before_add = {}".format(
                                    opt_level, how_to_zero, zero_before_add))

if __name__ == '__main__':
    unittest.main()