import unittest

import functools as ft
import itertools as it

from apex import amp
from apex.amp import _amp_state
import torch
from torch import nn
import torch.nn.functional as F
from torch.nn import Parameter

from utils import common_init, HALF, FLOAT,\
    ALWAYS_HALF, ALWAYS_FLOAT, MATCH_INPUT


class MyModel(torch.nn.Module):
    """Tiny module with one float32 and one float16 parameter, both on CUDA.

    ``unique`` offsets the initial parameter values, so models built with
    different ``unique`` arguments start from different weights.
    """

    def __init__(self, unique):
        super().__init__()
        base_fp32 = torch.arange(2, device='cuda', dtype=torch.float32)
        base_fp16 = torch.arange(2, device='cuda', dtype=torch.float16)
        # weight0 is stored in float32, weight1 in float16.
        self.weight0 = Parameter(unique + base_fp32)
        self.weight1 = Parameter(1. + unique + base_fp16)

    @staticmethod
    def ops(input, weight0, weight1):
        """Return the sum of ``input * weight0 * weight1``, with both weights cast to float."""
        scaled = input * weight0.float()
        scaled = scaled * weight1.float()
        return scaled.sum()

    def forward(self, input):
        """Apply :meth:`ops` to ``input`` using this module's two parameters."""
        return self.ops(input, self.weight0, self.weight1)

# Abandon all hope, ye who enter here.

# This is hands down the ugliest code I have ever written, but it succeeds in testing
# multiple models/optimizers/losses fairly thoroughly.  Many of the different test cases
# require slightly divergent code in a way that seems near-impossible to genericize into a simple
# cross product or nested loops.

class TestMultipleModelsOptimizersLosses(unittest.TestCase):
    """Compare amp training against plain training for several models/optimizers/losses.

    Every test first runs a reference loop without amp and records the
    gradients of each iteration plus the final parameters.  It then repeats
    the run through ``amp.initialize`` / ``amp.scale_loss`` for every
    configuration yielded by :meth:`_amp_configs` and asserts that the master
    gradients and the final parameters match the reference.

    For runs that inject an ``inf`` into a gradient, one extra iteration is
    executed and the injected iteration is excluded from the gradient checks.
    The single-optimizer tests compare against the plain two-iteration
    reference; the two-optimizer tests compare against dedicated reference
    runs in which the affected optimizer(s) do not step on that iteration.
    """

    def setUp(self):
        """Create the shared input tensor and run the suite's common initialisation."""
        self.x = torch.ones((2), device='cuda', dtype=torch.float32)
        common_init(self)

    def tearDown(self):
        """Nothing to clean up here; each test deactivates its own amp handle."""
        pass

    # ------------------------------------------------------------------
    # Private helpers shared by all tests
    # ------------------------------------------------------------------

    @staticmethod
    def _amp_configs():
        """Yield every amp configuration under test.

        Each item is the tuple ``(opt_level, how_to_zero,
        use_multiple_loss_scalers, inject_inf, inject_inf_loc,
        which_backward)``.  ``inject_inf`` is the iteration at which an inf
        is injected (``-1`` for "never"); injection is only exercised for
        opt levels O1 and O2.  Without injection ``inject_inf_loc`` is the
        placeholder ``"fdsa"`` and ``which_backward`` is ``None``.
        """
        for opt_level in ("O0", "O1", "O2", "O3"):
            inject_inf_iters = (-1, 0, 1) if opt_level in ("O1", "O2") else (-1,)
            for how_to_zero in ("none", "model", "optimizer"):
                for use_multiple_loss_scalers in (True, False):
                    for inject_inf in inject_inf_iters:
                        if inject_inf >= 0:
                            inject_inf_locs = ("fp16", "fp32")
                            which_backwards = (0, 1)
                        else:
                            inject_inf_locs = ("fdsa",)
                            which_backwards = (None,)
                        for inject_inf_loc in inject_inf_locs:
                            for which_backward in which_backwards:
                                yield (opt_level, how_to_zero, use_multiple_loss_scalers,
                                       inject_inf, inject_inf_loc, which_backward)

    @staticmethod
    def _amp_initialize(models, optimizers, opt_level, use_multiple_loss_scalers):
        """Run ``amp.initialize`` and pin the loss scales used by the tests.

        Returns whatever ``amp.initialize`` returns for ``models`` and
        ``optimizers``.  One loss scaler is requested, or two when
        ``use_multiple_loss_scalers`` is true; their scales are set to 4.0
        and 16.0 respectively.
        """
        num_losses = 2 if use_multiple_loss_scalers else 1

        # MyModel carries a float16 parameter, which is presumably why the
        # incoming-model dtype check has to be relaxed around initialize.
        _amp_state.allow_incoming_model_not_fp32 = True
        try:
            models, optimizers = amp.initialize(
                models,
                optimizers,
                opt_level=opt_level,
                verbosity=0,
                cast_model_type=False,
                num_losses=num_losses)
        finally:
            # Restore the flag even if initialize raises, so that later
            # tests do not inherit the relaxed check.
            _amp_state.allow_incoming_model_not_fp32 = False

        _amp_state.loss_scalers[0]._loss_scale = 4.0
        if use_multiple_loss_scalers:
            _amp_state.loss_scalers[1]._loss_scale = 16.0

        return models, optimizers

    @staticmethod
    def _zero_grads(how_to_zero, models, optimizers):
        """Reset gradients via the strategy named by ``how_to_zero``.

        ``"none"`` sets every ``param.grad`` to ``None``, ``"model"`` calls
        ``model.zero_grad()``, anything else calls ``optimizer.zero_grad()``.
        """
        if how_to_zero == "none":
            for model in models:
                for param in model.parameters():
                    param.grad = None
        elif how_to_zero == "model":
            for model in models:
                model.zero_grad()
        else:
            for optimizer in optimizers:
                optimizer.zero_grad()

    @staticmethod
    def _inject_inf(model, inject_inf_loc):
        """Overwrite the first gradient element of one parameter of ``model`` with inf.

        ``"fp32"`` targets ``weight0`` and ``"fp16"`` targets ``weight1``;
        any other location leaves the gradients untouched.
        """
        if inject_inf_loc == "fp32":
            model.weight0.grad[0] = float('inf')
        elif inject_inf_loc == "fp16":
            model.weight1.grad[0] = float('inf')

    @staticmethod
    def _select_model(which_model, candidates, loss_index):
        """Return ``candidates[which_model]`` or raise ``RuntimeError`` if it is not a candidate."""
        try:
            return candidates[which_model]
        except KeyError:
            # The message is built with format(): which_model is an int or
            # None, so concatenating it with a str would raise TypeError.
            raise RuntimeError(
                "{} invalid for loss {}".format(which_model, loss_index)) from None

    @staticmethod
    def _cloned_grads(*models):
        """Return clones of the gradients of all parameters of ``models``, in order."""
        return [param.grad.data.clone() for model in models for param in model.parameters()]

    @staticmethod
    def _cloned_params(*models):
        """Return clones of all parameters of ``models``, in order."""
        return [param.data.clone() for model in models for param in model.parameters()]

    @staticmethod
    def _master_params(*optimizers):
        """Return the concatenated ``amp.master_params`` of ``optimizers`` as a list."""
        return [param for optimizer in optimizers for param in amp.master_params(optimizer)]

    @staticmethod
    def _check_grads(master_params, reference_grads):
        """Assert that each master gradient matches its reference (compared in float32).

        The comparison stops at the shorter of the two sequences.
        """
        for param, reference_grad in zip(master_params, reference_grads):
            torch.testing.assert_close(param.grad.float(), reference_grad.float())

    @staticmethod
    def _check_params(models, master_params, reference_params):
        """Assert that model params match both the reference and the master params."""
        model_params = [param for model in models for param in model.parameters()]
        for model_param, master, reference in zip(model_params, master_params, reference_params):
            torch.testing.assert_close(model_param, reference)
            torch.testing.assert_close(model_param, master.to(model_param.dtype))

    # ------------------------------------------------------------------
    # Tests
    # ------------------------------------------------------------------

    def test_2models2losses1optimizer(self):
        """Two models, one loss per model, a single optimizer owning both models."""
        def build():
            model0 = MyModel(1)
            model1 = MyModel(2)
            optimizer = torch.optim.SGD([{'params': model0.parameters(), 'lr': 0.25},
                                         {'params': model1.parameters(), 'lr': 0.5}],
                                        momentum=0.125)
            return model0, model1, optimizer

        # Reference run without amp.
        model0, model1, optimizer = build()

        reference_grads = []
        for _ in range(2):
            optimizer.zero_grad()
            loss0 = model0(self.x)
            loss1 = model1(self.x)
            loss0.backward()
            loss1.backward()

            reference_grads.append(self._cloned_grads(model0, model1))

            optimizer.step()

        final_params = self._cloned_params(model0, model1)

        for (opt_level, how_to_zero, use_multiple_loss_scalers,
             inject_inf, inject_inf_loc, which_backward) in self._amp_configs():
            loss_ids = [0, 1] if use_multiple_loss_scalers else [0, 0]
            # One extra iteration when an inf is injected, so that two
            # iterations remain to compare against the reference.
            iters = 3 if inject_inf >= 0 else 2

            model0, model1, optimizer = build()
            models = [model0, model1]

            [model0, model1], optimizer = self._amp_initialize(
                [model0, model1], optimizer, opt_level, use_multiple_loss_scalers)

            try:
                unskipped = 0
                for i in range(iters):
                    self._zero_grads(how_to_zero, models, [optimizer])

                    loss0 = model0(self.x)
                    loss1 = model1(self.x)

                    # The inf is written inside the scale_loss context,
                    # right after backward and before the context exits.
                    with amp.scale_loss(loss0, optimizer, loss_id=loss_ids[0]) as scaled_loss:
                        scaled_loss.backward()
                        if i == inject_inf and which_backward == 0:
                            self._inject_inf(model0, inject_inf_loc)
                    with amp.scale_loss(loss1, optimizer, loss_id=loss_ids[1]) as scaled_loss:
                        scaled_loss.backward()
                        if i == inject_inf and which_backward == 1:
                            self._inject_inf(model1, inject_inf_loc)

                    if i != inject_inf:
                        self._check_grads(amp.master_params(optimizer),
                                          reference_grads[unskipped])
                        unskipped += 1
                    optimizer.step()

                self._check_params([model0, model1],
                                   amp.master_params(optimizer),
                                   final_params)
            finally:
                # Deactivate even when an assertion fails, so the O1 handle
                # is not left active for the following tests.
                if opt_level == "O1":
                    _amp_state.handle._deactivate()

    def test_3models2losses1optimizer(self):
        """Three models, two losses that both include model2, a single optimizer."""
        def build():
            model0 = MyModel(1)
            model1 = MyModel(2)
            model2 = MyModel(3)
            optimizer = torch.optim.SGD([{'params': model0.parameters(), 'lr': 0.25},
                                         {'params': model1.parameters(), 'lr': 0.5},
                                         {'params': model2.parameters(), 'lr': 0.125}],
                                        momentum=0.125)
            return model0, model1, model2, optimizer

        # Reference run without amp.
        model0, model1, model2, optimizer = build()

        reference_grads = []
        for _ in range(2):
            optimizer.zero_grad()
            loss0 = model0(self.x) + model2(self.x)
            loss1 = model1(self.x) + model2(self.x)
            loss0.backward()
            loss1.backward()

            reference_grads.append(self._cloned_grads(model0, model1, model2))

            optimizer.step()

        final_params = self._cloned_params(model0, model1, model2)

        for (opt_level, how_to_zero, use_multiple_loss_scalers,
             inject_inf, inject_inf_loc, which_backward) in self._amp_configs():
            loss_ids = [0, 1] if use_multiple_loss_scalers else [0, 0]

            if inject_inf >= 0:
                iters = 3
                # Models that contribute to the loss receiving the inf.
                which_models = {0: (0, 2), 1: (1, 2)}[which_backward]
            else:
                iters = 2
                which_models = (None,)

            for which_model in which_models:
                model0, model1, model2, optimizer = build()
                models = [model0, model1, model2]

                [model0, model1, model2], optimizer = self._amp_initialize(
                    [model0, model1, model2], optimizer, opt_level, use_multiple_loss_scalers)

                try:
                    unskipped = 0
                    for i in range(iters):
                        self._zero_grads(how_to_zero, models, [optimizer])

                        loss0 = model0(self.x) + model2(self.x)
                        loss1 = model1(self.x) + model2(self.x)

                        with amp.scale_loss(loss0, optimizer, loss_id=loss_ids[0]) as scaled_loss:
                            scaled_loss.backward()
                            if i == inject_inf and which_backward == 0:
                                inj_model = self._select_model(
                                    which_model, {0: model0, 2: model2}, 0)
                                self._inject_inf(inj_model, inject_inf_loc)
                        with amp.scale_loss(loss1, optimizer, loss_id=loss_ids[1]) as scaled_loss:
                            scaled_loss.backward()
                            if i == inject_inf and which_backward == 1:
                                inj_model = self._select_model(
                                    which_model, {1: model1, 2: model2}, 1)
                                self._inject_inf(inj_model, inject_inf_loc)

                        if i != inject_inf:
                            self._check_grads(amp.master_params(optimizer),
                                              reference_grads[unskipped])
                            unskipped += 1

                        optimizer.step()

                    self._check_params([model0, model1, model2],
                                       amp.master_params(optimizer),
                                       final_params)
                finally:
                    # Deactivate even when an assertion fails, so the O1
                    # handle is not left active for the following tests.
                    if opt_level == "O1":
                        _amp_state.handle._deactivate()

    def test_2models2losses2optimizers(self):
        """Two models, each with its own loss and its own optimizer."""
        def build():
            model0 = MyModel(1)
            model1 = MyModel(2)
            optimizer0 = torch.optim.SGD([{'params': model0.parameters(), 'lr': 0.25}],
                                         momentum=0.125)
            optimizer1 = torch.optim.SGD([{'params': model1.parameters(), 'lr': 0.5}],
                                         momentum=0.25)
            return model0, model1, optimizer0, optimizer1

        def what_got_skipped(which_iter, which_backward):
            """Map an (iteration, backward) injection to its reference slot; 0 means none."""
            slots = {(0, 0): 1, (0, 1): 2, (1, 0): 3, (1, 1): 4}
            return slots.get((which_iter, which_backward), 0)

        # Don't do it like this:  reference_grads = [[]]*5
        # because that creates a list of 5 references to the same "[]", and appending
        # to any of them effectively appends to all of them, which multiplies
        # the resulting size of reference_grads by 5x and makes the test fail.
        reference_grads = [[] for _ in range(5)]
        final_params = [None] * 5

        # Slot 0: reference run in which nothing is skipped.
        model0, model1, optimizer0, optimizer1 = build()
        for _ in range(2):
            optimizer0.zero_grad()
            optimizer1.zero_grad()
            loss0 = model0(self.x)
            loss1 = model1(self.x)
            loss0.backward()
            loss1.backward()

            reference_grads[0].append(self._cloned_grads(model0, model1))

            optimizer0.step()
            optimizer1.step()

        final_params[0] = self._cloned_params(model0, model1)

        # Slots 1-4: reference runs in which the optimizer belonging to
        # ``which_backward`` does not step at iteration ``which_iter``.
        for which_iter in (0, 1):
            for which_backward in (0, 1):
                slot = what_got_skipped(which_iter, which_backward)
                model0, model1, optimizer0, optimizer1 = build()

                for i in range(3):
                    optimizer0.zero_grad()
                    optimizer1.zero_grad()
                    loss0 = model0(self.x)
                    loss1 = model1(self.x)
                    loss0.backward()
                    loss1.backward()

                    if i != which_iter:
                        reference_grads[slot].append(self._cloned_grads(model0, model1))
                        optimizer0.step()
                        optimizer1.step()
                    elif which_backward == 0:
                        optimizer1.step()
                    else:
                        optimizer0.step()

                final_params[slot] = self._cloned_params(model0, model1)

        for (opt_level, how_to_zero, use_multiple_loss_scalers,
             inject_inf, inject_inf_loc, which_backward) in self._amp_configs():
            loss_ids = [0, 1] if use_multiple_loss_scalers else [0, 0]
            iters = 3 if inject_inf >= 0 else 2
            slot = what_got_skipped(inject_inf, which_backward)

            model0, model1, optimizer0, optimizer1 = build()
            models = [model0, model1]

            [model0, model1], [optimizer0, optimizer1] = self._amp_initialize(
                [model0, model1], [optimizer0, optimizer1],
                opt_level, use_multiple_loss_scalers)

            try:
                unskipped = 0
                for i in range(iters):
                    self._zero_grads(how_to_zero, models, [optimizer0, optimizer1])

                    loss0 = model0(self.x)
                    loss1 = model1(self.x)

                    with amp.scale_loss(loss0, optimizer0, loss_id=loss_ids[0]) as scaled_loss:
                        scaled_loss.backward()
                        if i == inject_inf and which_backward == 0:
                            self._inject_inf(model0, inject_inf_loc)
                    with amp.scale_loss(loss1, optimizer1, loss_id=loss_ids[1]) as scaled_loss:
                        scaled_loss.backward()
                        if i == inject_inf and which_backward == 1:
                            self._inject_inf(model1, inject_inf_loc)

                    if i != inject_inf:
                        self._check_grads(self._master_params(optimizer0, optimizer1),
                                          reference_grads[slot][unskipped])
                        unskipped += 1

                    optimizer0.step()
                    optimizer1.step()

                self._check_params([model0, model1],
                                   self._master_params(optimizer0, optimizer1),
                                   final_params[slot])
            finally:
                # Deactivate even when an assertion fails, so the O1 handle
                # is not left active for the following tests.
                if opt_level == "O1":
                    _amp_state.handle._deactivate()

    def test_3models2losses2optimizers(self):
        """Three models and two optimizers; model1 contributes to both losses."""
        def build():
            model0 = MyModel(1)
            model1 = MyModel(2)
            model2 = MyModel(3)
            optimizer0 = torch.optim.SGD([{'params': model0.parameters(), 'lr': 0.25},
                                          {'params': model1.parameters(), 'lr': 1.0}],
                                         momentum=0.5)
            optimizer1 = torch.optim.SGD([{'params': model2.parameters(), 'lr': 0.5}],
                                         momentum=0.25)
            return model0, model1, model2, optimizer0, optimizer1

        def what_got_skipped(which_iter, which_backward, which_model):
            """Map an (iteration, backward, model) injection to its reference slot; 0 means none."""
            slots = {(0, 0, 0): 1, (0, 0, 1): 2, (0, 1, 2): 3, (0, 1, 1): 4,
                     (1, 0, 0): 5, (1, 0, 1): 6, (1, 1, 2): 7, (1, 1, 1): 8}
            return slots.get((which_iter, which_backward, which_model), 0)

        # Models that contribute to each loss, keyed by the index of the backward pass.
        models_for_backward = {0: (0, 1), 1: (2, 1)}

        # Again, can't do this:  reference_grads = [[]]*9
        reference_grads = [[] for _ in range(9)]
        final_params = [None] * 9

        # NOTE(review): the reference gradient lists hold model0 and model1
        # only, so the zip in the gradient check stops before any master
        # params that follow them (presumably model2's) -- confirm whether
        # that omission is intended.

        # Slot 0: reference run in which nothing is skipped.
        model0, model1, model2, optimizer0, optimizer1 = build()
        for _ in range(2):
            optimizer0.zero_grad()
            optimizer1.zero_grad()
            loss0 = model0(self.x) + model1(self.x)
            loss1 = model2(self.x) + model1(self.x)
            loss0.backward()
            loss1.backward()

            reference_grads[0].append(self._cloned_grads(model0, model1))

            optimizer0.step()
            optimizer1.step()

        final_params[0] = self._cloned_params(model0, model1, model2)

        # Slots 1-8: reference runs with a skipped step at ``which_iter``.
        for which_iter in (0, 1):
            for which_backward in (0, 1):
                for which_model in models_for_backward[which_backward]:
                    slot = what_got_skipped(which_iter, which_backward, which_model)
                    model0, model1, model2, optimizer0, optimizer1 = build()

                    for i in range(3):
                        optimizer0.zero_grad()
                        optimizer1.zero_grad()
                        loss0 = model0(self.x) + model1(self.x)
                        loss1 = model2(self.x) + model1(self.x)
                        loss0.backward()
                        loss1.backward()

                        if i != which_iter:
                            reference_grads[slot].append(self._cloned_grads(model0, model1))
                            optimizer0.step()
                            optimizer1.step()
                        elif which_backward == 0:
                            # loss0 is scaled against optimizer0 only, so
                            # only optimizer1 steps on this iteration.
                            optimizer1.step()
                        # which_backward == 1: loss1 is scaled against both
                        # optimizers, so neither of them steps.

                    final_params[slot] = self._cloned_params(model0, model1, model2)

        for (opt_level, how_to_zero, use_multiple_loss_scalers,
             inject_inf, inject_inf_loc, which_backward) in self._amp_configs():
            loss_ids = [0, 1] if use_multiple_loss_scalers else [0, 0]

            if inject_inf >= 0:
                iters = 3
                which_models = models_for_backward[which_backward]
            else:
                iters = 2
                which_models = (None,)

            for which_model in which_models:
                slot = what_got_skipped(inject_inf, which_backward, which_model)

                model0, model1, model2, optimizer0, optimizer1 = build()
                models = [model0, model1, model2]

                [model0, model1, model2], [optimizer0, optimizer1] = self._amp_initialize(
                    [model0, model1, model2], [optimizer0, optimizer1],
                    opt_level, use_multiple_loss_scalers)

                try:
                    unskipped = 0
                    for i in range(iters):
                        self._zero_grads(how_to_zero, models, [optimizer0, optimizer1])

                        loss0 = model0(self.x) + model1(self.x)
                        loss1 = model2(self.x) + model1(self.x)

                        with amp.scale_loss(loss0, optimizer0, loss_id=loss_ids[0]) as scaled_loss:
                            scaled_loss.backward()
                            if i == inject_inf and which_backward == 0:
                                inj_model = self._select_model(
                                    which_model, {0: model0, 1: model1}, 0)
                                self._inject_inf(inj_model, inject_inf_loc)
                        # loss1 reaches params of both optimizers (model1 and model2).
                        with amp.scale_loss(loss1, [optimizer0, optimizer1],
                                            loss_id=loss_ids[1]) as scaled_loss:
                            scaled_loss.backward()
                            if i == inject_inf and which_backward == 1:
                                inj_model = self._select_model(
                                    which_model, {2: model2, 1: model1}, 1)
                                self._inject_inf(inj_model, inject_inf_loc)

                        if i != inject_inf:
                            self._check_grads(self._master_params(optimizer0, optimizer1),
                                              reference_grads[slot][unskipped])
                            unskipped += 1

                        optimizer0.step()
                        optimizer1.step()

                    self._check_params([model0, model1, model2],
                                       self._master_params(optimizer0, optimizer1),
                                       final_params[slot])
                finally:
                    # Deactivate even when an assertion fails, so the O1
                    # handle is not left active for the following tests.
                    if opt_level == "O1":
                        _amp_state.handle._deactivate()


if __name__ == '__main__':
    unittest.main()