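# Regression test for optimizer.add_param_group() under Amp: a pure-FP32
# reference run is compared against runs through amp.initialize() at every
# opt_level, with several gradient-zeroing styles and an optional extra
# backward pass (gradient accumulation).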
import unittest

import functools as ft
import itertools as it

from apex import amp
from apex.amp import _amp_state
import torch
from torch import nn
import torch.nn.functional as F
from torch.nn import Parameter

from utils import common_init, HALF, FLOAT, \
    ALWAYS_HALF, ALWAYS_FLOAT, MATCH_INPUT
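
# Toy module with one FP32 and one FP16 parameter, so add_param_group() is
# exercised with mixed parameter dtypes.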
class MyModel(torch.nn.Module):
    def __init__(self, unique):
        super(MyModel, self).__init__()
        self.weight0 = Parameter(unique +
                                 torch.arange(2, device='cuda', dtype=torch.float32))
        self.weight1 = Parameter(1. + unique +
                                 torch.arange(2, device='cuda', dtype=torch.float16))

    @staticmethod
    def ops(input, weight0, weight1):
        return ((input * (weight0.float())) * (weight1.float())).sum()

    def forward(self, input):
        return self.ops(input, self.weight0, self.weight1)

# Abandon all hope, ye who enter here.

class TestAddParamGroup(unittest.TestCase):
    def setUp(self):
        self.x = torch.ones((2), device='cuda', dtype=torch.float32)
        common_init(self)

    def tearDown(self):
        pass
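
    # Zero gradients three different ways; the final parameters should come
    # out the same regardless of which style the user picks.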
    def zero_grad(self, models, optimizer, how_to_zero):
        if how_to_zero == "none":
            for model in models:
                for param in model.parameters():
                    param.grad = None
        elif how_to_zero == "model":
            for model in models:
                model.zero_grad()
        elif how_to_zero == "optimizer":
            optimizer.zero_grad()
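
    # Build an FP32 reference run without Amp, then repeat the same sequence
    # under amp.initialize() and require matching final parameters.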
    def test_add_param_group(self):
        for opt_level in ("O0", "O1", "O2", "O3"):
            for zero_before_add in (True, False):
                for try_accumulation in (True, False):
                    model0 = MyModel(1)
                    model1 = MyModel(2)

                    optimizer = torch.optim.SGD(
                        [{'params': model0.parameters(), 'lr': 0.25}],
                        momentum=0.125)

                    optimizer.zero_grad()
                    loss = model0(self.x)
                    loss.backward()
                    optimizer.step()

                    if zero_before_add:
                        optimizer.zero_grad()
                    optimizer.add_param_group({'params': model1.parameters(), 'lr': 0.5})
                    if not zero_before_add:
                        optimizer.zero_grad()

                    loss = model0(self.x) + model1(self.x)
                    loss.backward(retain_graph=try_accumulation)
                    if try_accumulation:
                        loss.backward()
                    optimizer.step()

                    # Once more to make sure the new params pick up momentums properly
                    optimizer.zero_grad()
                    loss = model0(self.x) + model1(self.x)
                    loss.backward(retain_graph=try_accumulation)
                    if try_accumulation:
                        loss.backward()
                    optimizer.step()

                    reference_params = [param.data.clone() for param in model0.parameters()] + \
                                       [param.data.clone() for param in model1.parameters()]
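
                    # Amp pass: repeat the same sequence through amp.initialize()
                    # for each gradient-zeroing style and compare against the
                    # FP32 reference parameters at the end.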
                    for how_to_zero in ("none", "model", "optimizer"):
                        model0 = MyModel(1)
                        model1 = MyModel(2)

                        optimizer = torch.optim.SGD(
                            [{'params': model0.parameters(), 'lr': 0.25}],
                            momentum=0.125)
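
                        # MyModel carries an FP16 parameter, which amp.initialize()
                        # would normally flag; temporarily allow it and keep the
                        # parameter dtypes as-is via cast_model_type=False.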
                        _amp_state.allow_incoming_model_not_fp32 = True
                        [model0, model1], optimizer = amp.initialize(
                            [model0, model1],
                            optimizer,
                            opt_level=opt_level,
                            verbosity=0,
                            cast_model_type=False)
                        _amp_state.allow_incoming_model_not_fp32 = False
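
                        # Pin a known, non-trivial loss scale so the test exercises
                        # gradient unscaling deterministically.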
                        _amp_state.loss_scalers[0]._loss_scale = 4.0

                        self.zero_grad([model0, model1], optimizer, how_to_zero)
                        loss = model0(self.x)
                        with amp.scale_loss(loss, optimizer) as scaled_loss:
                            scaled_loss.backward()
                        optimizer.step()

                        if zero_before_add:
                            self.zero_grad([model0, model1], optimizer, how_to_zero)
                        optimizer.add_param_group({'params': model1.parameters(), 'lr': 0.5})
                        if not zero_before_add:
                            self.zero_grad([model0, model1], optimizer, how_to_zero)

                        loss = model0(self.x) + model1(self.x)
                        with amp.scale_loss(loss, optimizer) as scaled_loss:
                            scaled_loss.backward(retain_graph=try_accumulation)
                        if try_accumulation:
                            with amp.scale_loss(loss, optimizer) as scaled_loss:
                                scaled_loss.backward()
                        optimizer.step()

                        # Once more to make sure the new params pick up momentums properly
                        self.zero_grad([model0, model1], optimizer, how_to_zero)
                        loss = model0(self.x) + model1(self.x)
                        with amp.scale_loss(loss, optimizer) as scaled_loss:
                            scaled_loss.backward(retain_graph=try_accumulation)
                        if try_accumulation:
                            with amp.scale_loss(loss, optimizer) as scaled_loss:
                                scaled_loss.backward()
                        optimizer.step()

                        final_params = [param.data.clone() for param in model0.parameters()] + \
                                       [param.data.clone() for param in model1.parameters()]

                        for reference, final in zip(reference_params, final_params):
                            torch.testing.assert_close(
                                reference.to(final.dtype), final,
                                msg="opt_level = {}, how_to_zero = {}, zero_before_add = {}".format(
                                    opt_level, how_to_zero, zero_before_add))

if __name__ == '__main__':
    unittest.main()