# test_multiple_models_optimizers_losses.py

import unittest

import functools as ft
import itertools as it

from apex import amp
from apex.amp import _amp_state
import torch
from torch import nn
import torch.nn.functional as F
from torch.nn import Parameter

from utils import common_init, HALF, FLOAT,\
    ALWAYS_HALF, ALWAYS_FLOAT, MATCH_INPUT


class MyModel(torch.nn.Module):
    def __init__(self, unique):
        super(MyModel, self).__init__()
        self.weight0 = Parameter(unique +
            torch.arange(2, device='cuda', dtype=torch.float32))
        self.weight1 = Parameter(1. + unique + torch.arange(2, device='cuda', dtype=torch.float16))

    @staticmethod
    def ops(input, weight0, weight1):
        return ((input*(weight0.float()))*(weight1.float())).sum()

    def forward(self, input):
        return self.ops(input, self.weight0, self.weight1)


# Abandon all hope, ye who enter here.
# This is hands down the ugliest code I have ever written, but it succeeds in testing
# multiple models/optimizers/losses fairly thoroughly.  Many of the different test cases
# require slightly divergent code in a way that seems near-impossible to genericize into a simple
# cross product or nested loops.
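
# For orientation, a minimal sketch of the multi-model / multi-optimizer / multi-loss amp
# pattern that every test below exercises.  The helper name is illustrative and nothing in
# this file calls it; the real tests additionally toggle _amp_state.allow_incoming_model_not_fp32
# (because MyModel mixes fp32 and fp16 parameters), pin the loss scales, and inject infs to
# force skipped steps.
def _amp_two_losses_sketch(x):
    model0, model1 = MyModel(1), MyModel(2)
    optimizer0 = torch.optim.SGD(model0.parameters(), lr=0.25, momentum=0.125)
    optimizer1 = torch.optim.SGD(model1.parameters(), lr=0.5, momentum=0.25)
    # One initialize call registers both models and both optimizers, and reserves a
    # separate loss scaler per loss via num_losses.
    [model0, model1], [optimizer0, optimizer1] = amp.initialize(
        [model0, model1], [optimizer0, optimizer1],
        opt_level="O1", num_losses=2, verbosity=0)
    loss0 = model0(x)
    loss1 = model1(x)
    # Each backward pass is routed through its own scaler with loss_id.
    with amp.scale_loss(loss0, optimizer0, loss_id=0) as scaled_loss:
        scaled_loss.backward()
    with amp.scale_loss(loss1, optimizer1, loss_id=1) as scaled_loss:
        scaled_loss.backward()
    optimizer0.step()
    optimizer1.step()

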
class TestMultipleModelsOptimizersLosses(unittest.TestCase):
    def setUp(self):
        self.x = torch.ones((2), device='cuda', dtype=torch.float32)
        common_init(self)

    def tearDown(self):
        pass
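
    # Two models and two losses share a single optimizer (one param group per model).
    # A plain fp32 run records per-iteration grads and final params; the amp run is then
    # checked against that reference across opt_levels, grad-zeroing styles, one vs. two
    # loss scalers, and injected inf gradients (whose iteration amp should skip entirely,
    # leaving the final params identical to the reference).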
    def test_2models2losses1optimizer(self):
        model0 = MyModel(1)
        model1 = MyModel(2)
        optimizer = torch.optim.SGD([{'params' : model0.parameters(), 'lr' : 0.25},
                                     {'params' : model1.parameters(), 'lr' : 0.5}],
                                    momentum=0.125)

        reference_grads = []
        for i in range(2):
            optimizer.zero_grad()
            loss0 = model0(self.x)
            loss1 = model1(self.x)
            loss0.backward()
            loss1.backward()

            reference_grads.append([param.grad.data.clone() for param in model0.parameters()] +
                                   [param.grad.data.clone() for param in model1.parameters()])

            optimizer.step()

        final_params = [param.data.clone() for param in model0.parameters()] + \
                       [param.data.clone() for param in model1.parameters()]

        for opt_level in ("O0", "O1", "O2", "O3"):
            for how_to_zero in ("none", "model", "optimizer"):
                for use_multiple_loss_scalers in (True, False):
                    if opt_level == "O1" or opt_level == "O2":
                        inject_inf_iters = (-1, 0, 1)
                    else:
                        inject_inf_iters = (-1,)

                    for inject_inf in inject_inf_iters:
                        if inject_inf >= 0:
                            inject_inf_locs = ("fp16", "fp32")
                            which_backwards = (0, 1)
                        else:
                            inject_inf_locs = ("fdsa",)
                            which_backwards = (None,)

                        for inject_inf_loc in inject_inf_locs:
                            for which_backward in which_backwards:
                                if use_multiple_loss_scalers:
                                    num_losses = 2
                                    loss_ids = [0, 1]
                                else:
                                    num_losses = 1
                                    loss_ids = [0, 0]

                                if inject_inf >= 0:
                                    iters = 3
                                else:
                                    iters = 2

                                model0 = MyModel(1)
                                model1 = MyModel(2)

                                models = [model0, model1]

                                optimizer = torch.optim.SGD([{'params' : model0.parameters(), 'lr' : 0.25},
                                                             {'params' : model1.parameters(), 'lr' : 0.5}],
                                                            momentum=0.125)

                                _amp_state.allow_incoming_model_not_fp32 = True
                                [model0, model1], optimizer = amp.initialize(
                                    [model0, model1],
                                    optimizer,
                                    opt_level=opt_level,
                                    verbosity=0,
                                    cast_model_type=False,
                                    num_losses=num_losses)
                                _amp_state.allow_incoming_model_not_fp32 = False

                                _amp_state.loss_scalers[0]._loss_scale = 4.0
                                if use_multiple_loss_scalers:
                                    _amp_state.loss_scalers[1]._loss_scale = 16.0

                                unskipped = 0
                                for i in range(iters):
                                    if how_to_zero == "none":
                                        for model in models:
                                            for param in model.parameters():
                                                param.grad = None
                                    elif how_to_zero == "model":
                                        for model in models:
                                            model.zero_grad()
                                    else:
                                        optimizer.zero_grad()

                                    loss0 = model0(self.x)
                                    loss1 = model1(self.x)

                                    with amp.scale_loss(loss0, optimizer, loss_id=loss_ids[0]) as scaled_loss:
                                        scaled_loss.backward()
                                        if i == inject_inf and which_backward == 0:
                                            if inject_inf_loc == "fp32":
                                                model0.weight0.grad[0] = float('inf')
                                            elif inject_inf_loc == "fp16":
                                                model0.weight1.grad[0] = float('inf')
                                    with amp.scale_loss(loss1, optimizer, loss_id=loss_ids[1]) as scaled_loss:
                                        scaled_loss.backward()
                                        if i == inject_inf and which_backward == 1:
                                            if inject_inf_loc == "fp32":
                                                model1.weight0.grad[0] = float('inf')
                                            elif inject_inf_loc == "fp16":
                                                model1.weight1.grad[0] = float('inf')

                                    if i != inject_inf:
                                        for param, reference_grad in zip(amp.master_params(optimizer),
                                                                         reference_grads[unskipped]):
                                            torch.testing.assert_close(param.grad.float(), reference_grad.float())
                                        unskipped += 1

                                    optimizer.step()

                                model_params = [p for p in model0.parameters()] + [p for p in model1.parameters()]
                                for model, master, reference in zip(
                                        model_params,
                                        amp.master_params(optimizer),
                                        final_params):
                                    torch.testing.assert_close(model, reference)
                                    torch.testing.assert_close(model, master.to(model.dtype))

                                if opt_level == "O1":
                                    _amp_state.handle._deactivate()
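
    # Same checks as above, plus a third model that contributes to both losses, so the inf
    # can be injected through either loss into model0/model1 or into the shared model2.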
    def test_3models2losses1optimizer(self):
        model0 = MyModel(1)
        model1 = MyModel(2)
        model2 = MyModel(3)
        optimizer = torch.optim.SGD([{'params' : model0.parameters(), 'lr' : 0.25},
                                     {'params' : model1.parameters(), 'lr' : 0.5},
                                     {'params' : model2.parameters(), 'lr' : 0.125}],
                                    momentum=0.125)

        reference_grads = []
        for i in range(2):
            optimizer.zero_grad()
            loss0 = model0(self.x) + model2(self.x)
            loss1 = model1(self.x) + model2(self.x)
            loss0.backward()
            loss1.backward()

            reference_grads.append([param.grad.data.clone() for param in model0.parameters()] +
                                   [param.grad.data.clone() for param in model1.parameters()] +
                                   [param.grad.data.clone() for param in model2.parameters()])

            optimizer.step()

        final_params = [param.data.clone() for param in model0.parameters()] + \
                       [param.data.clone() for param in model1.parameters()] + \
                       [param.data.clone() for param in model2.parameters()]

        for opt_level in ("O0", "O1", "O2", "O3"):
            for how_to_zero in ("none", "model", "optimizer"):
                for use_multiple_loss_scalers in (True, False):
                    if opt_level == "O1" or opt_level == "O2":
                        inject_inf_iters = (-1, 0, 1)
                    else:
                        inject_inf_iters = (-1,)

                    for inject_inf in inject_inf_iters:
                        if inject_inf >= 0:
                            inject_inf_locs = ("fp16", "fp32")
                            which_backwards = (0, 1)
                        else:
                            inject_inf_locs = ("fdsa",)
                            which_backwards = (None,)

                        for inject_inf_loc in inject_inf_locs:
                            for which_backward in which_backwards:
                                if use_multiple_loss_scalers:
                                    num_losses = 2
                                    loss_ids = [0, 1]
                                else:
                                    num_losses = 1
                                    loss_ids = [0, 0]

                                if inject_inf >= 0:
                                    iters = 3
                                    if which_backward == 0:
                                        which_models = (0, 2)
                                    elif which_backward == 1:
                                        which_models = (1, 2)
                                else:
                                    iters = 2
                                    which_models = (None,)

                                for which_model in which_models:
                                    model0 = MyModel(1)
                                    model1 = MyModel(2)
                                    model2 = MyModel(3)

                                    models = [model0, model1, model2]

                                    optimizer = torch.optim.SGD([{'params' : model0.parameters(), 'lr' : 0.25},
                                                                 {'params' : model1.parameters(), 'lr' : 0.5},
                                                                 {'params' : model2.parameters(), 'lr' : 0.125}],
                                                                momentum=0.125)

                                    _amp_state.allow_incoming_model_not_fp32 = True
                                    [model0, model1, model2], optimizer = amp.initialize(
                                        [model0, model1, model2],
                                        optimizer,
                                        opt_level=opt_level,
                                        verbosity=0,
                                        cast_model_type=False,
                                        num_losses=num_losses)
                                    _amp_state.allow_incoming_model_not_fp32 = False

                                    _amp_state.loss_scalers[0]._loss_scale = 4.0
                                    if use_multiple_loss_scalers:
                                        _amp_state.loss_scalers[1]._loss_scale = 16.0

                                    unskipped = 0
                                    for i in range(iters):
                                        if how_to_zero == "none":
                                            for model in models:
                                                for param in model.parameters():
                                                    param.grad = None
                                        elif how_to_zero == "model":
                                            for model in models:
                                                model.zero_grad()
                                        else:
                                            optimizer.zero_grad()

                                        # print("opt_level {} i {} inject_inf {} which_backward {} inject_inf_loc {} which_model {} use_multiple_loss_scalers {}".format(opt_level, i, inject_inf, which_backward, inject_inf_loc, which_model, use_multiple_loss_scalers))

                                        loss0 = model0(self.x) + model2(self.x)
                                        loss1 = model1(self.x) + model2(self.x)

                                        with amp.scale_loss(loss0, optimizer, loss_id=loss_ids[0]) as scaled_loss:
                                            scaled_loss.backward()
                                            if i == inject_inf and which_backward == 0:
                                                if which_model == 0:
                                                    inj_model = model0
                                                elif which_model == 2:
                                                    inj_model = model2
                                                else:
                                                    raise RuntimeError("{} invalid for loss 0".format(which_model))
                                                if inject_inf_loc == "fp32":
                                                    inj_model.weight0.grad[0] = float('inf')
                                                elif inject_inf_loc == "fp16":
                                                    inj_model.weight1.grad[0] = float('inf')
                                        with amp.scale_loss(loss1, optimizer, loss_id=loss_ids[1]) as scaled_loss:
                                            scaled_loss.backward()
                                            if i == inject_inf and which_backward == 1:
                                                if which_model == 1:
                                                    inj_model = model1
                                                elif which_model == 2:
                                                    inj_model = model2
                                                else:
                                                    raise RuntimeError("{} invalid for loss 1".format(which_model))
                                                if inject_inf_loc == "fp32":
                                                    inj_model.weight0.grad[0] = float('inf')
                                                elif inject_inf_loc == "fp16":
                                                    inj_model.weight1.grad[0] = float('inf')

                                        if i != inject_inf:
                                            for param, reference_grad in zip(amp.master_params(optimizer),
                                                                             reference_grads[unskipped]):
                                                torch.testing.assert_close(param.grad.float(), reference_grad.float())
                                            unskipped += 1

                                        optimizer.step()

                                    model_params = [p for p in model0.parameters()] + \
                                                   [p for p in model1.parameters()] + \
                                                   [p for p in model2.parameters()]
                                    for model, master, reference in zip(
                                            model_params,
                                            amp.master_params(optimizer),
                                            final_params):
                                        torch.testing.assert_close(model, reference)
                                        torch.testing.assert_close(model, master.to(model.dtype))

                                    if opt_level == "O1":
                                        _amp_state.handle._deactivate()
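
    # Two models, two losses, and a dedicated optimizer per model.  A skipped backward now
    # stalls only its own optimizer, so separate references are recorded for every
    # (skipped iteration, skipped backward) combination via what_got_skipped below.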
    def test_2models2losses2optimizers(self):
        model0 = MyModel(1)
        model1 = MyModel(2)
        optimizer0 = torch.optim.SGD([{'params' : model0.parameters(), 'lr' : 0.25}],
                                     momentum=0.125)
        optimizer1 = torch.optim.SGD([{'params' : model1.parameters(), 'lr' : 0.5}],
                                     momentum=0.25)

        # Don't do it like this:  reference_grads = [[]]*5
        # because then it creates a list of 5 references to the same "[]" and appending
        # to any of them effectively makes you append to all of them, which multiplies
        # the resulting size of reference_grads by 5x and needless to say makes the test fail.
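        # For illustration (comment only, not executed):
        #     shared = [[]] * 5            # five references to one list
        #     shared[0].append(1)          # now shared == [[1], [1], [1], [1], [1]]
        #     separate = [[] for _ in range(5)]
        #     separate[0].append(1)        # separate == [[1], [], [], [], []]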
        reference_grads = [[], [], [], [], []]
        final_params = [None, None, None, None, None]
        for i in range(2):
            optimizer0.zero_grad()
            optimizer1.zero_grad()
            loss0 = model0(self.x)
            loss1 = model1(self.x)
            loss0.backward()
            loss1.backward()

            reference_grads[0].append([param.grad.data.clone() for param in model0.parameters()] +
                                      [param.grad.data.clone() for param in model1.parameters()])

            optimizer0.step()
            optimizer1.step()

        final_params[0] = [param.data.clone() for param in model0.parameters()] + \
                          [param.data.clone() for param in model1.parameters()]
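
        # what_got_skipped maps the (iteration, backward pass) that got skipped to an index into
        # reference_grads / final_params; index 0 is the "nothing skipped" reference recorded above.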
        def what_got_skipped(which_iter, which_backward):
            if which_iter == 0 and which_backward == 0:
                return 1
            if which_iter == 0 and which_backward == 1:
                return 2
            if which_iter == 1 and which_backward == 0:
                return 3
            if which_iter == 1 and which_backward == 1:
                return 4
            return 0

        for which_iter in (0,1):
            for which_backward in (0,1):
                model0 = MyModel(1)
                model1 = MyModel(2)
                optimizer0 = torch.optim.SGD([{'params' : model0.parameters(), 'lr' : 0.25}],
                                             momentum=0.125)
                optimizer1 = torch.optim.SGD([{'params' : model1.parameters(), 'lr' : 0.5}],
                                             momentum=0.25)

                for i in range(3):
                    optimizer0.zero_grad()
                    optimizer1.zero_grad()
                    loss0 = model0(self.x)
                    loss1 = model1(self.x)
                    loss0.backward()
                    loss1.backward()

                    if i != which_iter:
                        reference_grads[what_got_skipped(which_iter, which_backward)].append(
                            [param.grad.data.clone() for param in model0.parameters()] +
                            [param.grad.data.clone() for param in model1.parameters()])

                    if i == which_iter:
                        if which_backward == 0:
                            optimizer1.step()
                        else:
                            optimizer0.step()
                    else:
                        optimizer0.step()
                        optimizer1.step()

                final_params[what_got_skipped(which_iter, which_backward)] = \
                    [param.data.clone() for param in model0.parameters()] + \
                    [param.data.clone() for param in model1.parameters()]

        for opt_level in ("O0", "O1", "O2", "O3"):
            for how_to_zero in ("none", "model", "optimizer"):
                for use_multiple_loss_scalers in (True, False):
                    if opt_level == "O1" or opt_level == "O2":
                        inject_inf_iters = (-1, 0, 1)
                    else:
                        inject_inf_iters = (-1,)

                    for inject_inf in inject_inf_iters:
                        if inject_inf >= 0:
                            inject_inf_locs = ("fp16", "fp32")
                            which_backwards = (0, 1)
                        else:
                            inject_inf_locs = ("fdsa",)
                            which_backwards = (None,)

                        for inject_inf_loc in inject_inf_locs:
                            for which_backward in which_backwards:
                                if use_multiple_loss_scalers:
                                    num_losses = 2
                                    loss_ids = [0, 1]
                                else:
                                    num_losses = 1
                                    loss_ids = [0, 0]

                                if inject_inf >= 0:
                                    iters = 3
                                else:
                                    iters = 2

                                model0 = MyModel(1)
                                model1 = MyModel(2)

                                models = [model0, model1]

                                optimizer0 = torch.optim.SGD([{'params' : model0.parameters(), 'lr' : 0.25}],
                                                             momentum=0.125)
                                optimizer1 = torch.optim.SGD([{'params' : model1.parameters(), 'lr' : 0.5}],
                                                             momentum=0.25)

                                _amp_state.allow_incoming_model_not_fp32 = True
                                [model0, model1], [optimizer0, optimizer1] = amp.initialize(
                                    [model0, model1],
                                    [optimizer0, optimizer1],
                                    opt_level=opt_level,
                                    verbosity=0,
                                    cast_model_type=False,
                                    num_losses=num_losses)
                                _amp_state.allow_incoming_model_not_fp32 = False

                                _amp_state.loss_scalers[0]._loss_scale = 4.0
                                if use_multiple_loss_scalers:
                                    _amp_state.loss_scalers[1]._loss_scale = 16.0

                                unskipped = 0
                                for i in range(iters):
                                    if how_to_zero == "none":
                                        for model in models:
                                            for param in model.parameters():
                                                param.grad = None
                                    elif how_to_zero == "model":
                                        for model in models:
                                            model.zero_grad()
                                    else:
                                        optimizer0.zero_grad()
                                        optimizer1.zero_grad()

                                    loss0 = model0(self.x)
                                    loss1 = model1(self.x)

                                    with amp.scale_loss(loss0, optimizer0, loss_id=loss_ids[0]) as scaled_loss:
                                        scaled_loss.backward()
                                        if i == inject_inf and which_backward == 0:
                                            if inject_inf_loc == "fp32":
                                                model0.weight0.grad[0] = float('inf')
                                            elif inject_inf_loc == "fp16":
                                                model0.weight1.grad[0] = float('inf')
                                    with amp.scale_loss(loss1, optimizer1, loss_id=loss_ids[1]) as scaled_loss:
                                        scaled_loss.backward()
                                        if i == inject_inf and which_backward == 1:
                                            if inject_inf_loc == "fp32":
                                                model1.weight0.grad[0] = float('inf')
                                            elif inject_inf_loc == "fp16":
                                                model1.weight1.grad[0] = float('inf')

                                    # print("opt_level {} i {} inject_inf {} which_backward {} inject_inf_loc {} use_multiple_loss_scalers {}".format(opt_level, i, inject_inf, which_backward, inject_inf_loc, use_multiple_loss_scalers))

                                    if i != inject_inf:
                                        master_params = list(amp.master_params(optimizer0)) + \
                                                        list(amp.master_params(optimizer1))
                                        for param, reference_grad in zip(master_params,
                                                reference_grads[what_got_skipped(inject_inf, which_backward)][unskipped]):
                                            torch.testing.assert_close(param.grad.float(), reference_grad.float())
                                        unskipped += 1

                                    optimizer0.step()
                                    optimizer1.step()

                                model_params = [p for p in model0.parameters()] + [p for p in model1.parameters()]
                                master_params = [p for p in amp.master_params(optimizer0)] + \
                                                [p for p in amp.master_params(optimizer1)]
                                for model, master, reference in zip(
                                        model_params,
                                        master_params,
                                        final_params[what_got_skipped(inject_inf, which_backward)]):
                                    torch.testing.assert_close(model, reference)
                                    torch.testing.assert_close(model, master.to(model.dtype))

                                if opt_level == "O1":
                                    _amp_state.handle._deactivate()
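
    # Three models and two optimizers: optimizer0 owns model0 and model1, optimizer1 owns
    # model2, and model1 feeds both losses (so loss1 is scaled against both optimizers).
    # References are recorded for every (skipped iteration, skipped backward, model hit by
    # the inf) combination via the three-argument what_got_skipped below.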
    def test_3models2losses2optimizers(self):
        model0 = MyModel(1)
        model1 = MyModel(2)
        model2 = MyModel(3)
        optimizer0 = torch.optim.SGD([{'params' : model0.parameters(), 'lr' : 0.25},
                                      {'params' : model1.parameters(), 'lr' : 1.0}],
                                     momentum=0.5)
        optimizer1 = torch.optim.SGD([{'params' : model2.parameters(), 'lr' : 0.5}],
                                     momentum=0.25)

        # Again, can't do this:  reference_grads = [[]]*9
        reference_grads = [[], [], [], [], [], [], [], [], []]
        final_params = [None, None, None, None, None, None, None, None, None]
        for i in range(2):
            optimizer0.zero_grad()
            optimizer1.zero_grad()
            loss0 = model0(self.x) + model1(self.x)
            loss1 = model2(self.x) + model1(self.x)
            loss0.backward()
            loss1.backward()

            reference_grads[0].append([param.grad.data.clone() for param in model0.parameters()] +
                                      [param.grad.data.clone() for param in model1.parameters()])

            optimizer0.step()
            optimizer1.step()

        final_params[0] = \
            [param.data.clone() for param in model0.parameters()] + \
            [param.data.clone() for param in model1.parameters()] + \
            [param.data.clone() for param in model2.parameters()]
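
        # what_got_skipped maps (skipped iteration, skipped backward pass, model the inf was
        # injected into) to an index into reference_grads / final_params; index 0 is the
        # "nothing skipped" reference recorded above.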
        def what_got_skipped(which_iter, which_backward, which_model):
            if which_iter == 0:
                if which_backward == 0:
                    if which_model == 0:
                        return 1
                    if which_model == 1:
                        return 2
                if which_backward == 1:
                    if which_model == 2:
                        return 3
                    if which_model == 1:
                        return 4
            if which_iter == 1:
                if which_backward == 0:
                    if which_model == 0:
                        return 5
                    if which_model == 1:
                        return 6
                if which_backward == 1:
                    if which_model == 2:
                        return 7
                    if which_model == 1:
                        return 8
            return 0

        for which_iter in (0,1):
            for which_backward in (0,1):
                if which_backward == 0:
                    which_models = (0,1)
                if which_backward == 1:
                    which_models = (2,1)
                for which_model in which_models:
                    model0 = MyModel(1)
                    model1 = MyModel(2)
                    model2 = MyModel(3)
                    optimizer0 = torch.optim.SGD([{'params' : model0.parameters(), 'lr' : 0.25},
                                                  {'params' : model1.parameters(), 'lr' : 1.0}],
                                                 momentum=0.5)
                    optimizer1 = torch.optim.SGD([{'params' : model2.parameters(), 'lr' : 0.5}],
                                                 momentum=0.25)

                    for i in range(3):
                        optimizer0.zero_grad()
                        optimizer1.zero_grad()
                        loss0 = model0(self.x) + model1(self.x)
                        loss1 = model2(self.x) + model1(self.x)
                        loss0.backward()
                        loss1.backward()

                        if i != which_iter:
                            reference_grads[what_got_skipped(which_iter,
                                which_backward, which_model)].append(
                                [param.grad.data.clone() for param in model0.parameters()] +
                                [param.grad.data.clone() for param in model1.parameters()])

                        if i == which_iter:
                            if which_backward == 0:
                                # if which_model == 0:
                                # if which_model == 1:
                                optimizer1.step()
                            if which_backward == 1:
                                # if which_model == 2:
                                #     optimizer0.step()
                                # if which_model == 1:
                                continue
                        else:
                            optimizer0.step()
                            optimizer1.step()

                    final_params[what_got_skipped(which_iter, which_backward, which_model)] = \
                        [param.data.clone() for param in model0.parameters()] + \
                        [param.data.clone() for param in model1.parameters()] + \
                        [param.data.clone() for param in model2.parameters()]

        for opt_level in ("O0", "O1", "O2", "O3"):
            for how_to_zero in ("none", "model", "optimizer"):
                for use_multiple_loss_scalers in (True, False):
                    if opt_level == "O1" or opt_level == "O2":
                        inject_inf_iters = (-1, 0, 1)
                    else:
                        inject_inf_iters = (-1,)

                    for inject_inf in inject_inf_iters:
                        if inject_inf >= 0:
                            inject_inf_locs = ("fp16", "fp32")
                            which_backwards = (0, 1)
                        else:
                            inject_inf_locs = ("fdsa",)
                            which_backwards = (None,)

                        for inject_inf_loc in inject_inf_locs:
                            for which_backward in which_backwards:
                                if use_multiple_loss_scalers:
                                    num_losses = 2
                                    loss_ids = [0, 1]
                                else:
                                    num_losses = 1
                                    loss_ids = [0, 0]

                                if inject_inf >= 0:
                                    iters = 3
                                    if which_backward == 0:
                                        which_models = (0, 1)
                                    elif which_backward == 1:
                                        which_models = (2, 1)
                                else:
                                    iters = 2
                                    which_models = (None,)

                                for which_model in which_models:
                                    model0 = MyModel(1)
                                    model1 = MyModel(2)
                                    model2 = MyModel(3)

                                    models = [model0, model1, model2]

                                    optimizer0 = torch.optim.SGD([{'params' : model0.parameters(), 'lr' : 0.25},
                                                                  {'params' : model1.parameters(), 'lr' : 1.0}],
                                                                 momentum=0.5)
                                    optimizer1 = torch.optim.SGD([{'params' : model2.parameters(), 'lr' : 0.5}],
                                                                 momentum=0.25)

                                    _amp_state.allow_incoming_model_not_fp32 = True
                                    [model0, model1, model2], [optimizer0, optimizer1] = amp.initialize(
                                        [model0, model1, model2],
                                        [optimizer0, optimizer1],
                                        opt_level=opt_level,
                                        verbosity=0,
                                        cast_model_type=False,
                                        num_losses=num_losses)
                                    _amp_state.allow_incoming_model_not_fp32 = False

                                    _amp_state.loss_scalers[0]._loss_scale = 4.0
                                    if use_multiple_loss_scalers:
                                        _amp_state.loss_scalers[1]._loss_scale = 16.0

                                    unskipped = 0
                                    for i in range(iters):
                                        if how_to_zero == "none":
                                            for model in models:
                                                for param in model.parameters():
                                                    param.grad = None
                                        elif how_to_zero == "model":
                                            for model in models:
                                                model.zero_grad()
                                        else:
                                            optimizer0.zero_grad()
                                            optimizer1.zero_grad()

                                        loss0 = model0(self.x) + model1(self.x)
                                        loss1 = model2(self.x) + model1(self.x)

                                        with amp.scale_loss(loss0, optimizer0, loss_id=loss_ids[0]) as scaled_loss:
                                            scaled_loss.backward()
                                            if i == inject_inf and which_backward == 0:
                                                if which_model == 0:
                                                    inj_model = model0
                                                elif which_model == 1:
                                                    inj_model = model1
                                                else:
                                                    raise RuntimeError("{} invalid for loss 0".format(which_model))
                                                if inject_inf_loc == "fp32":
                                                    inj_model.weight0.grad[0] = float('inf')
                                                elif inject_inf_loc == "fp16":
                                                    inj_model.weight1.grad[0] = float('inf')
                                        with amp.scale_loss(loss1, [optimizer0, optimizer1], loss_id=loss_ids[1]) as scaled_loss:
                                            scaled_loss.backward()
                                            if i == inject_inf and which_backward == 1:
                                                if which_model == 2:
                                                    inj_model = model2
                                                elif which_model == 1:
                                                    inj_model = model1
                                                else:
                                                    raise RuntimeError("{} invalid for loss 1".format(which_model))
                                                if inject_inf_loc == "fp32":
                                                    inj_model.weight0.grad[0] = float('inf')
                                                elif inject_inf_loc == "fp16":
                                                    inj_model.weight1.grad[0] = float('inf')

                                        if i != inject_inf:
                                            master_params = list(amp.master_params(optimizer0)) + \
                                                            list(amp.master_params(optimizer1))
                                            for param, reference_grad in zip(master_params,
                                                    reference_grads[what_got_skipped(inject_inf,
                                                        which_backward, which_model)][unskipped]):
                                                torch.testing.assert_close(param.grad.float(), reference_grad.float())
                                            unskipped += 1

                                        optimizer0.step()
                                        optimizer1.step()

                                    model_params = [p for p in model0.parameters()] + \
                                                   [p for p in model1.parameters()] + \
                                                   [p for p in model2.parameters()]
                                    master_params = [p for p in amp.master_params(optimizer0)] + \
                                                    [p for p in amp.master_params(optimizer1)]

                                    # print("opt_level {} i {} inject_inf {} which_backward {} inject_inf_loc {} use_multiple_loss_scalers {} which_model {}".format(opt_level, i, inject_inf, which_backward, inject_inf_loc, use_multiple_loss_scalers, which_model))

                                    for model, master, reference in zip(
                                            model_params,
                                            master_params,
                                            final_params[what_got_skipped(inject_inf, which_backward, which_model)]):
                                        torch.testing.assert_close(model, reference)
                                        torch.testing.assert_close(model, master.to(model.dtype))

                                    if opt_level == "O1":
                                        _amp_state.handle._deactivate()


if __name__ == '__main__':
    unittest.main()