test_larc.py

import unittest

import torch
from torch import nn
from torch.nn import Parameter

from apex import amp
from apex.parallel.LARC import LARC

from utils import common_init
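

# Toy module with one two-element CUDA weight; the "unique" offset gives
# each instance distinguishable parameter values ([unique, unique + 1]).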
class MyModel(torch.nn.Module):
    def __init__(self, unique):
        super(MyModel, self).__init__()
        self.weight0 = Parameter(
            unique + torch.arange(2, device="cuda", dtype=torch.float32)
        )

    def forward(self, input):
        return (input * self.weight0).sum()
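

# Checks that LARC composes with apex amp: one optimizer step should run
# cleanly at every mixed-precision opt level.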
class TestLARC(unittest.TestCase):
    def setUp(self):
        self.x = torch.ones((2), device="cuda", dtype=torch.float32)
        common_init(self)

    def tearDown(self):
        pass
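
    # One forward/backward/step cycle per amp opt level, from pure fp32
    # ("O0") through pure fp16 ("O3").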
    def test_larc_mixed_precision(self):
        for opt_level in ["O0", "O1", "O2", "O3"]:
            model = MyModel(1)
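            # LARC wraps a standard optimizer and rescales each parameter
            # group's learning rate by a layer-wise trust ratio at step time.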
            optimizer = LARC(
                torch.optim.SGD(
                    [{"params": model.parameters(), "lr": 0.25}], momentum=0.125
                )
            )
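            # amp.initialize casts model and optimizer for the requested
            # opt level and patches the optimizer for loss scaling.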
            model, optimizer = amp.initialize(
                model, optimizer, opt_level=opt_level, verbosity=0
            )
            optimizer.zero_grad()
            loss = model(self.x)
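            # Scale the loss before backward so fp16 gradients don't underflow.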
            with amp.scale_loss(loss, optimizer) as scaled_loss:
                scaled_loss.backward()
            optimizer.step()


if __name__ == "__main__":
    unittest.main()