# test_fused_sgd.py

import unittest

import functools as ft
import itertools as it

from apex import amp
from apex.amp import _amp_state
import torch
from torch import nn
import torch.nn.functional as F
from torch.nn import Parameter

from utils import common_init, HALF, FLOAT,\
    ALWAYS_HALF, ALWAYS_FLOAT, MATCH_INPUT

try:
    import amp_C
    disabled = False
    from apex.optimizers import FusedSGD as FusedSGD
except ImportError as err:
    print("amp_C fused kernels unavailable, disabling TestMultipleModelsOptimizersLosses. ImportError was ", err)
    disabled = True


class MyModel(torch.nn.Module):
    def __init__(self, unique):
        super(MyModel, self).__init__()
        # One fp32 parameter and one fp16 parameter, offset by `unique` so that
        # each model instance starts from distinct values.
        self.weight0 = Parameter(unique +
                                 torch.arange(2, device='cuda', dtype=torch.float32))
        self.weight1 = Parameter(1. + unique + torch.arange(2, device='cuda', dtype=torch.float16))

    @staticmethod
    def ops(input, weight0, weight1):
        return ((input*(weight0.float()))*(weight1.float())).sum()

    def forward(self, input):
        return self.ops(input, self.weight0, self.weight1)


# Abandon all hope, ye who enter here.
# This is hands down the ugliest code I have ever written, but it succeeds in testing
# multiple models/optimizers/losses fairly thoroughly.  Many of the different test cases
# require slightly divergent code in a way that seems near-impossible to genericize into a simple
# cross product or nested loops.
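#
# A minimal sketch (illustration only, not how the tests below are written, and assuming
# nothing beyond the `itertools as it` import above): the four outermost configuration
# dimensions that every test sweeps could in principle be collapsed into one it.product loop:
#
#     for materialize_master_grads, opt_level, how_to_zero, use_multiple_loss_scalers in \
#             it.product((False, True),
#                        ("O0", "O1", "O2", "O3"),
#                        ("none", "model", "optimizer"),
#                        (False, True)):
#         ...
#
# The remaining dimensions (inject_inf, inject_inf_loc, which_backward, which_model) depend
# on the values chosen above, which is why the tests stick to explicit nested loops.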


class TestMultipleModelsOptimizersLosses(unittest.TestCase):
    def setUp(self):
        self.x = torch.ones((2), device='cuda', dtype=torch.float32)
        common_init(self)

    def tearDown(self):
        pass

    @unittest.skipIf(disabled, "amp_C is unavailable")
    def test_2models2losses1optimizer(self):
        model0 = MyModel(1)
        model1 = MyModel(2)

        optimizer = torch.optim.SGD([{'params' : model0.parameters(), 'lr' : 0.25},
                                     {'params' : model1.parameters(), 'lr' : 0.5}],
                                    momentum=0.125)

        # Build reference gradients and final parameter values with plain torch.optim.SGD.
        reference_grads = []
        for i in range(2):
            optimizer.zero_grad()
            loss0 = model0(self.x)
            loss1 = model1(self.x)
            loss0.backward()
            loss1.backward()

            reference_grads.append([param.grad.data.clone() for param in model0.parameters()] +
                                   [param.grad.data.clone() for param in model1.parameters()])

            optimizer.step()

        final_params = [param.data.clone() for param in model0.parameters()] + \
                       [param.data.clone() for param in model1.parameters()]

        # Now run the same schedule through amp with FusedSGD under a sweep of configurations
        # and check gradients and final parameters against the references above.
        for materialize_master_grads in (False, True):
          for opt_level in ("O0", "O1", "O2", "O3"):
            for how_to_zero in ("none", "model", "optimizer"):
              for use_multiple_loss_scalers in (False, True):
                if opt_level == "O1" or opt_level == "O2":
                    inject_inf_iters = (-1, 0, 1)
                else:
                    inject_inf_iters = (-1,)

                for inject_inf in inject_inf_iters:
                    if inject_inf >= 0:
                        inject_inf_locs = ("fp16", "fp32")
                        which_backwards = (0, 1)
                    else:
                        inject_inf_locs = ("fdsa",)
                        which_backwards = (None,)

                    for inject_inf_loc in inject_inf_locs:
                        for which_backward in which_backwards:
                            if use_multiple_loss_scalers:
                                num_losses = 2
                                loss_ids = [0, 1]
                            else:
                                num_losses = 1
                                loss_ids = [0, 0]

                            if inject_inf >= 0:
                                iters = 3
                            else:
                                iters = 2

                            model0 = MyModel(1)
                            model1 = MyModel(2)

                            models = [model0, model1]

                            optimizer = FusedSGD([{'params' : model0.parameters(), 'lr' : 0.25},
                                                  {'params' : model1.parameters(), 'lr' : 0.5}],
                                                 momentum=0.125,
                                                 materialize_master_grads=materialize_master_grads)

                            _amp_state.allow_incoming_model_not_fp32 = True
                            [model0, model1], optimizer = amp.initialize(
                                [model0, model1],
                                optimizer,
                                opt_level=opt_level,
                                verbosity=0,
                                cast_model_type=False,
                                num_losses=num_losses)
                            _amp_state.allow_incoming_model_not_fp32 = False

                            _amp_state.loss_scalers[0]._loss_scale = 4.0
                            if use_multiple_loss_scalers:
                                _amp_state.loss_scalers[1]._loss_scale = 16.0

                            unskipped = 0
                            for i in range(iters):
                                if how_to_zero == "none":
                                    for model in models:
                                        for param in model.parameters():
                                            param.grad = None
                                elif how_to_zero == "model":
                                    for model in models:
                                        model.zero_grad()
                                else:
                                    optimizer.zero_grad()

                                loss0 = model0(self.x)
                                loss1 = model1(self.x)

                                with amp.scale_loss(loss0, optimizer, loss_id=loss_ids[0]) as scaled_loss:
                                    scaled_loss.backward()
                                    if i == inject_inf and which_backward == 0:
                                        if inject_inf_loc == "fp32":
                                            model0.weight0.grad[0] = float('inf')
                                        elif inject_inf_loc == "fp16":
                                            model0.weight1.grad[0] = float('inf')
                                with amp.scale_loss(loss1, optimizer, loss_id=loss_ids[1]) as scaled_loss:
                                    scaled_loss.backward()
                                    if i == inject_inf and which_backward == 1:
                                        if inject_inf_loc == "fp32":
                                            model1.weight0.grad[0] = float('inf')
                                        elif inject_inf_loc == "fp16":
                                            model1.weight1.grad[0] = float('inf')

                                if i != inject_inf:
                                    master_params = amp.master_params(optimizer)
                                    for param, reference_grad in zip(master_params, reference_grads[unskipped]):
                                        if opt_level == "O2" and not materialize_master_grads:
                                            continue
                                        else:
                                            torch.testing.assert_close(param.grad.float(), reference_grad.float(),
                                                msg="opt_level {} i {} inject_inf {} which_backward {} inject_inf_loc {} use_multiple_loss_scalers {}".format(opt_level, i, inject_inf, which_backward, inject_inf_loc, use_multiple_loss_scalers))
                                    unskipped += 1
                                optimizer.step()

                            model_params = [p for p in model0.parameters()] + [p for p in model1.parameters()]
                            for model, master, reference in zip(
                                    model_params,
                                    amp.master_params(optimizer),
                                    final_params):
                                torch.testing.assert_close(model, reference)
                                torch.testing.assert_close(model, master.to(model.dtype))

                            if opt_level == "O1":
                                _amp_state.handle._deactivate()

    @unittest.skipIf(disabled, "amp_C is unavailable")
    def test_3models2losses1optimizer(self):
        model0 = MyModel(1)
        model1 = MyModel(2)
        model2 = MyModel(3)

        optimizer = torch.optim.SGD([{'params' : model0.parameters(), 'lr' : 0.25},
                                     {'params' : model1.parameters(), 'lr' : 0.5},
                                     {'params' : model2.parameters(), 'lr' : 0.125}],
                                    momentum=0.125)

        # Reference run with plain torch.optim.SGD; model2 contributes to both losses.
        reference_grads = []
        for i in range(2):
            optimizer.zero_grad()
            loss0 = model0(self.x) + model2(self.x)
            loss1 = model1(self.x) + model2(self.x)
            loss0.backward()
            loss1.backward()

            reference_grads.append([param.grad.data.clone() for param in model0.parameters()] +
                                   [param.grad.data.clone() for param in model1.parameters()] +
                                   [param.grad.data.clone() for param in model2.parameters()])

            optimizer.step()

        final_params = [param.data.clone() for param in model0.parameters()] + \
                       [param.data.clone() for param in model1.parameters()] + \
                       [param.data.clone() for param in model2.parameters()]

        for materialize_master_grads in (False, True):
          for opt_level in ("O0", "O1", "O2", "O3"):
            for how_to_zero in ("none", "model", "optimizer"):
              for use_multiple_loss_scalers in (False, True):
                if opt_level == "O1" or opt_level == "O2":
                    inject_inf_iters = (-1, 0, 1)
                else:
                    inject_inf_iters = (-1,)

                for inject_inf in inject_inf_iters:
                    if inject_inf >= 0:
                        inject_inf_locs = ("fp16", "fp32")
                        which_backwards = (0, 1)
                    else:
                        inject_inf_locs = ("fdsa",)
                        which_backwards = (None,)

                    for inject_inf_loc in inject_inf_locs:
                        for which_backward in which_backwards:
                            if use_multiple_loss_scalers:
                                num_losses = 2
                                loss_ids = [0, 1]
                            else:
                                num_losses = 1
                                loss_ids = [0, 0]

                            if inject_inf >= 0:
                                iters = 3
                                if which_backward == 0:
                                    which_models = (0, 2)
                                elif which_backward == 1:
                                    which_models = (1, 2)
                            else:
                                iters = 2
                                which_models = (None,)

                            for which_model in which_models:
                                model0 = MyModel(1)
                                model1 = MyModel(2)
                                model2 = MyModel(3)

                                models = [model0, model1, model2]

                                optimizer = FusedSGD([{'params' : model0.parameters(), 'lr' : 0.25},
                                                      {'params' : model1.parameters(), 'lr' : 0.5},
                                                      {'params' : model2.parameters(), 'lr' : 0.125}],
                                                     momentum=0.125,
                                                     materialize_master_grads=materialize_master_grads)

                                _amp_state.allow_incoming_model_not_fp32 = True
                                [model0, model1, model2], optimizer = amp.initialize(
                                    [model0, model1, model2],
                                    optimizer,
                                    opt_level=opt_level,
                                    verbosity=0,
                                    cast_model_type=False,
                                    num_losses=num_losses)
                                _amp_state.allow_incoming_model_not_fp32 = False

                                _amp_state.loss_scalers[0]._loss_scale = 4.0
                                if use_multiple_loss_scalers:
                                    _amp_state.loss_scalers[1]._loss_scale = 16.0

                                unskipped = 0
                                for i in range(iters):
                                    if how_to_zero == "none":
                                        for model in models:
                                            for param in model.parameters():
                                                param.grad = None
                                    elif how_to_zero == "model":
                                        for model in models:
                                            model.zero_grad()
                                    else:
                                        optimizer.zero_grad()

                                    loss0 = model0(self.x) + model2(self.x)
                                    loss1 = model1(self.x) + model2(self.x)

                                    with amp.scale_loss(loss0, optimizer, loss_id=loss_ids[0]) as scaled_loss:
                                        scaled_loss.backward()
                                        if i == inject_inf and which_backward == 0:
                                            if which_model == 0:
                                                inj_model = model0
                                            elif which_model == 2:
                                                inj_model = model2
                                            else:
                                                raise RuntimeError("{} invalid for loss 0".format(which_model))
                                            if inject_inf_loc == "fp32":
                                                inj_model.weight0.grad[0] = float('inf')
                                            elif inject_inf_loc == "fp16":
                                                inj_model.weight1.grad[0] = float('inf')
                                    with amp.scale_loss(loss1, optimizer, loss_id=loss_ids[1]) as scaled_loss:
                                        scaled_loss.backward()
                                        if i == inject_inf and which_backward == 1:
                                            if which_model == 1:
                                                inj_model = model1
                                            elif which_model == 2:
                                                inj_model = model2
                                            else:
                                                raise RuntimeError("{} invalid for loss 1".format(which_model))
                                            if inject_inf_loc == "fp32":
                                                inj_model.weight0.grad[0] = float('inf')
                                            elif inject_inf_loc == "fp16":
                                                inj_model.weight1.grad[0] = float('inf')

                                    if i != inject_inf:
                                        master_params = amp.master_params(optimizer)
                                        for param, reference_grad in zip(master_params, reference_grads[unskipped]):
                                            if opt_level == "O2" and not materialize_master_grads:
                                                continue
                                            else:
                                                torch.testing.assert_close(param.grad.float(), reference_grad.float(),
                                                    msg="opt_level {} i {} inject_inf {} which_backward {} inject_inf_loc {} which_model {} use_multiple_loss_scalers {}".format(opt_level, i, inject_inf, which_backward, inject_inf_loc, which_model, use_multiple_loss_scalers))
                                        unskipped += 1
                                    optimizer.step()

                                model_params = [p for p in model0.parameters()] + \
                                               [p for p in model1.parameters()] + \
                                               [p for p in model2.parameters()]
                                for model, master, reference in zip(
                                        model_params,
                                        amp.master_params(optimizer),
                                        final_params):
                                    torch.testing.assert_close(model, reference)
                                    torch.testing.assert_close(model, master.to(model.dtype))

                                if opt_level == "O1":
                                    _amp_state.handle._deactivate()

    @unittest.skipIf(disabled, "amp_C is unavailable")
    def test_2models2losses2optimizers(self):
        model0 = MyModel(1)
        model1 = MyModel(2)

        optimizer0 = torch.optim.SGD([{'params' : model0.parameters(), 'lr' : 0.25}],
                                     momentum=0.125)
        optimizer1 = torch.optim.SGD([{'params' : model1.parameters(), 'lr' : 0.5}],
                                     momentum=0.25)

        # Don't do it like this: reference_grads = [[]]*5.
        # That creates a list of 5 references to the same "[]", so appending to any of them
        # effectively appends to all of them, which multiplies the resulting size of
        # reference_grads by 5x and, needless to say, makes the test fail.
        # (A short illustration follows the two assignments below.)
        reference_grads = [[], [], [], [], []]
        final_params = [None, None, None, None, None]
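        # Illustrative only (not executed by the test): the aliasing pitfall described above.
        #     >>> aliased = [[]] * 2             # two names for one and the same list
        #     >>> aliased[0].append("grad")
        #     >>> aliased
        #     [['grad'], ['grad']]
        #     >>> separate = [[] for _ in range(2)]
        #     >>> separate[0].append("grad")
        #     >>> separate
        #     [['grad'], []]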

        for i in range(2):
            optimizer0.zero_grad()
            optimizer1.zero_grad()
            loss0 = model0(self.x)
            loss1 = model1(self.x)
            loss0.backward()
            loss1.backward()

            reference_grads[0].append([param.grad.data.clone() for param in model0.parameters()] +
                                      [param.grad.data.clone() for param in model1.parameters()])

            optimizer0.step()
            optimizer1.step()

        final_params[0] = [param.data.clone() for param in model0.parameters()] + \
                          [param.data.clone() for param in model1.parameters()]

        # Maps (which iteration was skipped, which backward pass saw the inf) to an index
        # into reference_grads/final_params; index 0 holds the no-skip reference.
        def what_got_skipped(which_iter, which_backward):
            if which_iter == 0 and which_backward == 0:
                return 1
            if which_iter == 0 and which_backward == 1:
                return 2
            if which_iter == 1 and which_backward == 0:
                return 3
            if which_iter == 1 and which_backward == 1:
                return 4
            return 0

        for which_iter in (0,1):
            for which_backward in (0,1):
                model0 = MyModel(1)
                model1 = MyModel(2)

                optimizer0 = torch.optim.SGD([{'params' : model0.parameters(), 'lr' : 0.25}],
                                             momentum=0.125)
                optimizer1 = torch.optim.SGD([{'params' : model1.parameters(), 'lr' : 0.5}],
                                             momentum=0.25)

                for i in range(3):
                    optimizer0.zero_grad()
                    optimizer1.zero_grad()
                    loss0 = model0(self.x)
                    loss1 = model1(self.x)
                    loss0.backward()
                    loss1.backward()

                    if i != which_iter:
                        reference_grads[what_got_skipped(which_iter, which_backward)].append(
                            [param.grad.data.clone() for param in model0.parameters()] +
                            [param.grad.data.clone() for param in model1.parameters()])

                    if i == which_iter:
                        if which_backward == 0:
                            optimizer1.step()
                        else:
                            optimizer0.step()
                    else:
                        optimizer0.step()
                        optimizer1.step()

                final_params[what_got_skipped(which_iter, which_backward)] = \
                    [param.data.clone() for param in model0.parameters()] + \
                    [param.data.clone() for param in model1.parameters()]

        for materialize_master_grads in (False, True):
          for opt_level in ("O0", "O1", "O2", "O3"):
            for how_to_zero in ("none", "model", "optimizer"):
              for use_multiple_loss_scalers in (False, True):
                if opt_level == "O1" or opt_level == "O2":
                    inject_inf_iters = (-1, 0, 1)
                else:
                    inject_inf_iters = (-1,)

                for inject_inf in inject_inf_iters:
                    if inject_inf >= 0:
                        inject_inf_locs = ("fp16", "fp32")
                        which_backwards = (0, 1)
                    else:
                        inject_inf_locs = ("fdsa",)
                        which_backwards = (None,)

                    for inject_inf_loc in inject_inf_locs:
                        for which_backward in which_backwards:
                            if use_multiple_loss_scalers:
                                num_losses = 2
                                loss_ids = [0, 1]
                            else:
                                num_losses = 1
                                loss_ids = [0, 0]

                            if inject_inf >= 0:
                                iters = 3
                            else:
                                iters = 2

                            model0 = MyModel(1)
                            model1 = MyModel(2)

                            models = [model0, model1]

                            optimizer0 = FusedSGD([{'params' : model0.parameters(), 'lr' : 0.25}],
                                                  momentum=0.125, materialize_master_grads=materialize_master_grads)
                            optimizer1 = FusedSGD([{'params' : model1.parameters(), 'lr' : 0.5}],
                                                  momentum=0.25, materialize_master_grads=materialize_master_grads)

                            _amp_state.allow_incoming_model_not_fp32 = True
                            [model0, model1], [optimizer0, optimizer1] = amp.initialize(
                                [model0, model1],
                                [optimizer0, optimizer1],
                                opt_level=opt_level,
                                verbosity=0,
                                cast_model_type=False,
                                num_losses=num_losses)
                            _amp_state.allow_incoming_model_not_fp32 = False

                            _amp_state.loss_scalers[0]._loss_scale = 4.0
                            if use_multiple_loss_scalers:
                                _amp_state.loss_scalers[1]._loss_scale = 16.0

                            unskipped = 0
                            for i in range(iters):
                                if how_to_zero == "none":
                                    for model in models:
                                        for param in model.parameters():
                                            param.grad = None
                                elif how_to_zero == "model":
                                    for model in models:
                                        model.zero_grad()
                                else:
                                    optimizer0.zero_grad()
                                    optimizer1.zero_grad()

                                loss0 = model0(self.x)
                                loss1 = model1(self.x)

                                with amp.scale_loss(loss0, optimizer0, loss_id=loss_ids[0]) as scaled_loss:
                                    scaled_loss.backward()
                                    if i == inject_inf and which_backward == 0:
                                        if inject_inf_loc == "fp32":
                                            model0.weight0.grad[0] = float('inf')
                                        elif inject_inf_loc == "fp16":
                                            model0.weight1.grad[0] = float('inf')
                                with amp.scale_loss(loss1, optimizer1, loss_id=loss_ids[1]) as scaled_loss:
                                    scaled_loss.backward()
                                    if i == inject_inf and which_backward == 1:
                                        if inject_inf_loc == "fp32":
                                            model1.weight0.grad[0] = float('inf')
                                        elif inject_inf_loc == "fp16":
                                            model1.weight1.grad[0] = float('inf')

                                # print("opt_level {} i {} inject_inf {} which_backward {} inject_inf_loc {} use_multiple_loss_scalers {}".format(opt_level, i, inject_inf, which_backward, inject_inf_loc, use_multiple_loss_scalers))

                                if i != inject_inf:
                                    master_params = list(amp.master_params(optimizer0)) + \
                                                    list(amp.master_params(optimizer1))
                                    for param, reference_grad in zip(master_params,
                                            reference_grads[what_got_skipped(inject_inf, which_backward)][unskipped]):
                                        if opt_level == "O2" and not materialize_master_grads:
                                            continue
                                        else:
                                            torch.testing.assert_close(param.grad.float(), reference_grad.float())
                                    unskipped += 1

                                optimizer0.step()
                                optimizer1.step()

                            model_params = [p for p in model0.parameters()] + [p for p in model1.parameters()]
                            master_params = [p for p in amp.master_params(optimizer0)] + \
                                            [p for p in amp.master_params(optimizer1)]
                            for model, master, reference in zip(
                                    model_params,
                                    master_params,
                                    final_params[what_got_skipped(inject_inf, which_backward)]):
                                torch.testing.assert_close(model, reference)
                                torch.testing.assert_close(model, master.to(model.dtype))

                            if opt_level == "O1":
                                _amp_state.handle._deactivate()

    @unittest.skipIf(disabled, "amp_C is unavailable")
    def test_3models2losses2optimizers(self):
        model0 = MyModel(1)
        model1 = MyModel(2)
        model2 = MyModel(3)

        optimizer0 = torch.optim.SGD([{'params' : model0.parameters(), 'lr' : 0.25},
                                      {'params' : model1.parameters(), 'lr' : 1.0}],
                                     momentum=0.5)
        optimizer1 = torch.optim.SGD([{'params' : model2.parameters(), 'lr' : 0.5}],
                                     momentum=0.25)

        # Again, can't do this: reference_grads = [[]]*9
        reference_grads = [[], [], [], [], [], [], [], [], []]
        final_params = [None, None, None, None, None, None, None, None, None]
        for i in range(2):
            optimizer0.zero_grad()
            optimizer1.zero_grad()
            loss0 = model0(self.x) + model1(self.x)
            loss1 = model2(self.x) + model1(self.x)
            loss0.backward()
            loss1.backward()

            reference_grads[0].append([param.grad.data.clone() for param in model0.parameters()] +
                                      [param.grad.data.clone() for param in model1.parameters()])

            optimizer0.step()
            optimizer1.step()

        final_params[0] = \
            [param.data.clone() for param in model0.parameters()] + \
            [param.data.clone() for param in model1.parameters()] + \
            [param.data.clone() for param in model2.parameters()]

        # Maps (skipped iteration, which backward saw the inf, which model had the inf) to an
        # index into reference_grads/final_params; index 0 holds the no-skip reference.
        def what_got_skipped(which_iter, which_backward, which_model):
            if which_iter == 0:
                if which_backward == 0:
                    if which_model == 0:
                        return 1
                    if which_model == 1:
                        return 2
                if which_backward == 1:
                    if which_model == 2:
                        return 3
                    if which_model == 1:
                        return 4
            if which_iter == 1:
                if which_backward == 0:
                    if which_model == 0:
                        return 5
                    if which_model == 1:
                        return 6
                if which_backward == 1:
                    if which_model == 2:
                        return 7
                    if which_model == 1:
                        return 8
            return 0

        for which_iter in (0,1):
            for which_backward in (0,1):
                if which_backward == 0:
                    which_models = (0,1)
                if which_backward == 1:
                    which_models = (2,1)
                for which_model in which_models:
                    model0 = MyModel(1)
                    model1 = MyModel(2)
                    model2 = MyModel(3)

                    optimizer0 = torch.optim.SGD([{'params' : model0.parameters(), 'lr' : 0.25},
                                                  {'params' : model1.parameters(), 'lr' : 1.0}],
                                                 momentum=0.5)
                    optimizer1 = torch.optim.SGD([{'params' : model2.parameters(), 'lr' : 0.5}],
                                                 momentum=0.25)

                    for i in range(3):
                        optimizer0.zero_grad()
                        optimizer1.zero_grad()
                        loss0 = model0(self.x) + model1(self.x)
                        loss1 = model2(self.x) + model1(self.x)
                        loss0.backward()
                        loss1.backward()

                        if i != which_iter:
                            reference_grads[what_got_skipped(which_iter,
                                which_backward, which_model)].append(
                                [param.grad.data.clone() for param in model0.parameters()] +
                                [param.grad.data.clone() for param in model1.parameters()])

                        if i == which_iter:
                            if which_backward == 0:
                                # if which_model == 0:
                                optimizer1.step()
                                # if which_model == 1:
                                #     optimizer1.step()
                            if which_backward == 1:
                                # if which_model == 2:
                                #     optimizer0.step()
                                # if which_model == 1:
                                continue
                        else:
                            optimizer0.step()
                            optimizer1.step()

                    final_params[what_got_skipped(which_iter, which_backward, which_model)] = \
                        [param.data.clone() for param in model0.parameters()] + \
                        [param.data.clone() for param in model1.parameters()] + \
                        [param.data.clone() for param in model2.parameters()]

        for materialize_master_grads in (False, True):
          for opt_level in ("O0", "O1", "O2", "O3"):
            for how_to_zero in ("none", "model", "optimizer"):
              for use_multiple_loss_scalers in (False, True):
                if opt_level == "O1" or opt_level == "O2":
                    inject_inf_iters = (-1, 0, 1)
                else:
                    inject_inf_iters = (-1,)

                for inject_inf in inject_inf_iters:
                    if inject_inf >= 0:
                        inject_inf_locs = ("fp16", "fp32")
                        which_backwards = (0, 1)
                    else:
                        inject_inf_locs = ("fdsa",)
                        which_backwards = (None,)

                    for inject_inf_loc in inject_inf_locs:
                        for which_backward in which_backwards:
                            if use_multiple_loss_scalers:
                                num_losses = 2
                                loss_ids = [0, 1]
                            else:
                                num_losses = 1
                                loss_ids = [0, 0]

                            if inject_inf >= 0:
                                iters = 3
                                if which_backward == 0:
                                    which_models = (0, 1)
                                elif which_backward == 1:
                                    which_models = (2, 1)
                            else:
                                iters = 2
                                which_models = (None,)

                            for which_model in which_models:
                                model0 = MyModel(1)
                                model1 = MyModel(2)
                                model2 = MyModel(3)

                                models = [model0, model1, model2]

                                optimizer0 = FusedSGD([{'params' : model0.parameters(), 'lr' : 0.25},
                                                       {'params' : model1.parameters(), 'lr' : 1.0}],
                                                      momentum=0.5, materialize_master_grads=materialize_master_grads)
                                optimizer1 = FusedSGD([{'params' : model2.parameters(), 'lr' : 0.5}],
                                                      momentum=0.25, materialize_master_grads=materialize_master_grads)

                                _amp_state.allow_incoming_model_not_fp32 = True
                                [model0, model1, model2], [optimizer0, optimizer1] = amp.initialize(
                                    [model0, model1, model2],
                                    [optimizer0, optimizer1],
                                    opt_level=opt_level,
                                    verbosity=0,
                                    cast_model_type=False,
                                    num_losses=num_losses)
                                _amp_state.allow_incoming_model_not_fp32 = False

                                _amp_state.loss_scalers[0]._loss_scale = 4.0
                                if use_multiple_loss_scalers:
                                    _amp_state.loss_scalers[1]._loss_scale = 16.0

                                unskipped = 0
                                for i in range(iters):
                                    if how_to_zero == "none":
                                        for model in models:
                                            for param in model.parameters():
                                                param.grad = None
                                    elif how_to_zero == "model":
                                        for model in models:
                                            model.zero_grad()
                                    else:
                                        optimizer0.zero_grad()
                                        optimizer1.zero_grad()

                                    loss0 = model0(self.x) + model1(self.x)
                                    loss1 = model2(self.x) + model1(self.x)

                                    with amp.scale_loss(loss0, optimizer0, loss_id=loss_ids[0]) as scaled_loss:
                                        scaled_loss.backward()
                                        if i == inject_inf and which_backward == 0:
                                            if which_model == 0:
                                                inj_model = model0
                                            elif which_model == 1:
                                                inj_model = model1
                                            else:
                                                raise RuntimeError("{} invalid for loss 0".format(which_model))
                                            if inject_inf_loc == "fp32":
                                                inj_model.weight0.grad[0] = float('inf')
                                            elif inject_inf_loc == "fp16":
                                                inj_model.weight1.grad[0] = float('inf')
                                    with amp.scale_loss(loss1, [optimizer0, optimizer1], loss_id=loss_ids[1]) as scaled_loss:
                                        scaled_loss.backward()
                                        if i == inject_inf and which_backward == 1:
                                            if which_model == 2:
                                                inj_model = model2
                                            elif which_model == 1:
                                                inj_model = model1
                                            else:
                                                raise RuntimeError("{} invalid for loss 1".format(which_model))
                                            if inject_inf_loc == "fp32":
                                                inj_model.weight0.grad[0] = float('inf')
                                            elif inject_inf_loc == "fp16":
                                                inj_model.weight1.grad[0] = float('inf')

                                    if i != inject_inf:
                                        master_params = list(amp.master_params(optimizer0)) + \
                                                        list(amp.master_params(optimizer1))
                                        for param, reference_grad in zip(master_params,
                                                reference_grads[what_got_skipped(inject_inf,
                                                    which_backward, which_model)][unskipped]):
                                            if opt_level == "O2" and not materialize_master_grads:
                                                continue
                                            else:
                                                torch.testing.assert_close(param.grad.float(), reference_grad.float())
                                        unskipped += 1

                                    optimizer0.step()
                                    optimizer1.step()

                                model_params = [p for p in model0.parameters()] + \
                                               [p for p in model1.parameters()] + \
                                               [p for p in model2.parameters()]
                                master_params = [p for p in amp.master_params(optimizer0)] + \
                                                [p for p in amp.master_params(optimizer1)]

                                # print("opt_level {} i {} inject_inf {} which_backward {} inject_inf_loc {} use_multiple_loss_scalers {} which_model {}".format(opt_level, i, inject_inf, which_backward, inject_inf_loc, use_multiple_loss_scalers, which_model))

                                for model, master, reference in zip(
                                        model_params,
                                        master_params,
                                        final_params[what_got_skipped(inject_inf, which_backward, which_model)]):
                                    torch.testing.assert_close(model, reference)
                                    torch.testing.assert_close(model, master.to(model.dtype))

                                if opt_level == "O1":
                                    _amp_state.handle._deactivate()


if __name__ == '__main__':
    unittest.main()