import unittest

import torch
from torch import nn
from torch.nn import Parameter

from apex import amp
from apex.parallel.LARC import LARC

from utils import common_init


class MyModel(torch.nn.Module):
    """Minimal one-parameter CUDA model used to drive the optimizer test.

    Args:
        unique: Offset added to ``torch.arange(2)`` to form the initial
            value of ``weight0``.
    """

    def __init__(self, unique):
        super().__init__()
        self.weight0 = Parameter(
            unique + torch.arange(2, device="cuda", dtype=torch.float32)
        )

    def forward(self, input):
        """Return the scalar sum of ``input * weight0``."""
        return (input * self.weight0).sum()


class TestLARC(unittest.TestCase):
    """Smoke test: a LARC-wrapped SGD optimizer can take a step under amp."""

    def setUp(self):
        # Two-element float32 CUDA input matching the shape of MyModel.weight0.
        self.x = torch.ones(2, device="cuda", dtype=torch.float32)
        # common_init is a project helper defined outside this file;
        # NOTE(review): its effect on the test case is not visible here.
        common_init(self)

    def tearDown(self):
        pass

    def test_larc_mixed_precision(self):
        """Run one forward/backward/step cycle for every amp opt_level.

        The test makes no numeric assertions; an opt level passes when the
        whole cycle completes without raising.
        """
        for opt_level in ("O0", "O1", "O2", "O3"):
            # subTest reports the failing opt_level by name and lets the
            # remaining levels run instead of aborting at the first failure.
            with self.subTest(opt_level=opt_level):
                model = MyModel(1)
                optimizer = LARC(
                    torch.optim.SGD(
                        [{"params": model.parameters(), "lr": 0.25}],
                        momentum=0.125,
                    )
                )

                model, optimizer = amp.initialize(
                    model, optimizer, opt_level=opt_level, verbosity=0
                )

                optimizer.zero_grad()
                loss = model(self.x)
                # Backpropagate through the scaled loss yielded by amp.
                with amp.scale_loss(loss, optimizer) as scaled_loss:
                    scaled_loss.backward()
                optimizer.step()


if __name__ == "__main__":
    unittest.main()