
import os
import math
import inspect
import torch
import importlib
import amp_C
from apex.multi_tensor_apply import multi_tensor_applier

import torch.distributed.distributed_c10d as c10d

# Fallback to private fields if using older PyTorch version
try:
    from torch.distributed.distributed_c10d import get_process_group_ranks
except ImportError:
    def get_process_group_ranks(group):
        return list(c10d._pg_group_ranks[group].keys())

_make_nccl_premul_sum = getattr(torch.distributed, "_make_nccl_premul_sum", None)
# Ref: https://github.com/pytorch/pytorch/pull/81272
if _make_nccl_premul_sum is None:
    if hasattr(torch.distributed, "make_nccl_premul_sum"):
        _make_nccl_premul_sum = torch.distributed.make_nccl_premul_sum
class DistributedFusedLAMB(torch.optim.Optimizer):
    """Implements the LAMB algorithm.

    Currently GPU-only. Requires Apex to be installed via
    ``pip install -v --no-cache-dir --global-option="--cpp_ext" --global-option="--cuda_ext" ./``.

    This version of fused LAMB implements 2 fusions:

      * Fusion of the LAMB update's elementwise operations
      * A multi-tensor apply launch that batches the elementwise updates applied to all the model's parameters into one or a few kernel launches.

    :class:`apex.optimizers.FusedLAMB`'s usage is identical to any ordinary Pytorch optimizer::

        opt = apex.optimizers.FusedLAMB(model.parameters(), lr = ....)
        ...
        opt.step()

    :class:`apex.optimizers.FusedLAMB` may be used with or without Amp. If you wish to use :class:`FusedLAMB` with Amp,
    you may choose any ``opt_level``::

        opt = apex.optimizers.FusedLAMB(model.parameters(), lr = ....)
        model, opt = amp.initialize(model, opt, opt_level="O0" or "O1" or "O2")
        ...
        opt.step()

    In general, ``opt_level="O1"`` is recommended.

    LAMB was proposed in `Large Batch Optimization for Deep Learning: Training BERT in 76 minutes`_.

    Arguments:
        params (iterable): iterable of parameters to optimize or dicts defining
            parameter groups.
        lr (float, optional): learning rate. (default: 1e-3)
        betas (Tuple[float, float], optional): coefficients used for computing
            running averages of gradient and its norm. (default: (0.9, 0.999))
        eps (float, optional): term added to the denominator to improve
            numerical stability. (default: 1e-8)
        weight_decay (float, optional): weight decay (L2 penalty) (default: 0)
        amsgrad (boolean, optional): whether to use the AMSGrad variant of this
            algorithm from the paper `On the Convergence of Adam and Beyond`_.
            NOT SUPPORTED now! (default: False)
        adam_w_mode (boolean, optional): Apply L2 regularization or weight decay.
            True for decoupled weight decay (also known as AdamW). (default: True)
        grad_averaging (bool, optional): whether to apply (1-beta1) to grad when
            calculating running averages of gradient. (default: True)
        set_grad_none (bool, optional): whether to set grad to None when zero_grad()
            method is called. (default: True)
        max_grad_norm (float, optional): value used to clip global grad norm
            (default: 1.0)
        use_nvlamb (boolean, optional): Apply the adaptive learning rate to parameters
            with 0.0 weight decay (default: False)
        step_supports_amp_scaling (boolean, optional): whether to use customized
            gradient unscaling logic (default: True)

    .. _Large Batch Optimization for Deep Learning - Training BERT in 76 minutes:
        https://arxiv.org/abs/1904.00962
    .. _On the Convergence of Adam and Beyond:
        https://openreview.net/forum?id=ryQu7f-RZ
    """
    class AtomicCounter(object):
        def __init__(self):
            self.value = 0
            self.order = []
            import threading
            self._lock = threading.Lock()

        def add(self, idx):
            with self._lock:
                self.value += 1
                self.order.append(idx)
    def __init__(self, params,
                 lr=1e-3, bias_correction=True, grad_averaging=True,
                 betas=(0.9, 0.999), eps=1e-8,
                 weight_decay=0., max_grad_norm=0.,
                 adam_w_mode=True, use_nvlamb=False,
                 step_supports_amp_scaling=True, overlap_reductions=True,
                 dwu_group_size=0, dwu_num_blocks=4, dwu_num_chunks=4,
                 dwu_num_rs_pg=1, dwu_num_ar_pg=4, dwu_num_ag_pg=0, fused_norm=False,
                 e5m2_allgather=False, verbose=False, clip_after_ar=True,
                 full_ar=False, set_param_views_to_flat_buffer=False, skip_allgather=False,
                 fuse_scale=False, param_order=None, nccl_allgather_channels=0):
        defaults = dict(lr=lr, bias_correction=bias_correction,
                        betas=betas, eps=eps, weight_decay=weight_decay,
                        grad_averaging=grad_averaging,
                        max_grad_norm=max_grad_norm)

        super(DistributedFusedLAMB, self).__init__(params, defaults)

        global fused_adam_cuda, distributed_lamb_cuda
        fused_adam_cuda = importlib.import_module("fused_adam_cuda")
        distributed_lamb_cuda = importlib.import_module("distributed_lamb_cuda")

        self._overflow_buf = torch.cuda.IntTensor([0])
        self._has_overflow = False
        self.multi_tensor_lamb_compute_update_term = distributed_lamb_cuda.multi_tensor_lamb_compute_update_term
        self.multi_tensor_lamb_update_weights = distributed_lamb_cuda.multi_tensor_lamb_update_weights
        import amp_C
        self.multi_tensor_l2norm = amp_C.multi_tensor_l2norm

        self._grad_averaging = grad_averaging
        self._adam_w_mode = 1 if adam_w_mode else 0
        self._use_nvlamb = use_nvlamb
        self._step_supports_amp_scaling = step_supports_amp_scaling
        self._is_accumulation_step = False
        self._last_step = False
        self._overlap_reductions = overlap_reductions
        self._global_scale = None
        self._num_blocks = dwu_num_blocks
        self._num_chunks = dwu_num_chunks
        self._e5m2_allgather = e5m2_allgather
        self._verbose = verbose
        self._clip_after_ar = clip_after_ar
        self._full_ar = full_ar
        self._fuse_scale = fuse_scale
        self._L2_grad_norm = None
        self._set_flat_param_view = set_param_views_to_flat_buffer
        self._skip_ag = skip_allgather
        self._fused_norm = fused_norm if not clip_after_ar else False
        self._current_process_group = c10d._get_default_group()
        self._available_ranks = get_process_group_ranks(self._current_process_group)
        self._group_size = torch.cuda.device_count() if dwu_group_size <= 0 else dwu_group_size
        self._world_size = torch.distributed.get_world_size()
        self._num_groups = self._world_size // self._group_size
        self._rank_in_group = torch.distributed.get_rank() % self._group_size

        self._lr = torch.tensor(0.0, dtype=torch.float32, device='cuda')

        self._resume_from_checkpoint = False
        self._step = torch.cuda.IntTensor([0])

        # Master weight, moment, gradient buffers
        self._fp32_p, self._fp32_m, self._fp32_v, self._fp16_p, self._fp16_g = None, None, None, None, None

        # Check if collectives have no_copy option
        self._reduce_scatter_no_copy = (
            'no_copy' in inspect.getfullargspec(torch.distributed.reduce_scatter).args
        )
        self._all_gather_no_copy = (
            'no_copy' in inspect.getfullargspec(torch.distributed.all_gather).args
        )

        if "reduce_scatter_tensor" not in dir(torch.distributed):
            torch.distributed.reduce_scatter_tensor = torch.distributed._reduce_scatter_base
        if "all_gather_into_tensor" not in dir(torch.distributed):
            torch.distributed.all_gather_into_tensor = torch.distributed._all_gather_base

        self._num_rs_pg = dwu_num_rs_pg
        self._num_ar_pg = dwu_num_ar_pg
        self._num_ag_pg = dwu_num_ag_pg

        if self._full_ar:  # full all reduce, only need AR and AG groups
            # l2_grad_norm may be reduced within a node to limit from memory reads
            for group_i in range(self._num_groups):
                ranks = [group_i*self._group_size+j for j in range(self._group_size)]
                l2_grad_norm_pg = torch.distributed.new_group(ranks=ranks)
                if torch.distributed.get_rank() in ranks:
                    self._l2_grad_norm_pg = l2_grad_norm_pg

            self._ar_pg = []
            # consider all the ranks
            ranks = list(range(0, self._world_size))
            for i in range(self._num_ar_pg):
                if self._verbose:
                    print(f"creating new AR group {i}: {ranks}")
                grp = torch.distributed.new_group(ranks=ranks)
                if grp != torch.distributed.GroupMember.NON_GROUP_MEMBER:
                    if self._verbose:
                        print(f"group {i}: init barrier (device: {torch.cuda.current_device()})")
                    torch.distributed.barrier(group=grp, device_ids=[torch.cuda.current_device()])
                if self._verbose:
                    print(f"created new AR group {i}: {ranks}")

                if torch.distributed.get_rank() in ranks:
                    self._ar_pg.append(grp)
            self._ar_st = [torch.cuda.Stream() for _ in range(self._num_ar_pg)]

            if nccl_allgather_channels > 0:
                os.putenv('NCCL_MAX_NCHANNELS', str(nccl_allgather_channels))
            if self._num_ag_pg == 0:
                self._ag_pg = self._ar_pg
                self._ag_st = self._ar_st
                self._num_ag_pg = self._num_ar_pg
            else:
                self._ag_pg = []
                ranks = []
                stride = torch.cuda.device_count()
                for i in range(self._num_groups):
                    rs = list(range(i*stride, (i+1)*stride))
                    ranks.append(rs)
                for rs in ranks:
                    for i in range(self._num_ag_pg):
                        grp = torch.distributed.new_group(ranks=rs)
                        if torch.distributed.get_rank() in rs:
                            if self._verbose:
                                print(f"creating AG group {i}: {rs}")
                            self._ag_pg.append(grp)
                self._ag_st = [torch.cuda.Stream() for _ in range(self._num_ag_pg)]
        else:  # reduce-scatter + all-reduce, need RS, AR, AG groups
            if self._num_groups > 1:
                self._ar_pg = []
                for dev_i in range(self._group_size):
                    ranks = [dev_i+j*self._group_size for j in range(self._num_groups)]
                    for i in range(self._num_ar_pg):
                        if self._verbose:
                            print(f"creating new AR group {i}: {ranks}")
                        grp = torch.distributed.new_group(ranks=ranks)
                        if grp != torch.distributed.GroupMember.NON_GROUP_MEMBER:
                            if self._verbose:
                                print(f"group {i}: init barrier (device: {torch.cuda.current_device()})")
                            torch.distributed.barrier(group=grp, device_ids=[torch.cuda.current_device()])
                        if self._verbose:
                            print(f"created new AR group {i}: {ranks}")

                        if torch.distributed.get_rank() in ranks:
                            self._ar_pg.append(grp)
                self._ar_st = [torch.cuda.Stream() for _ in range(self._num_ar_pg)]

            rs_ranks = []
            for group_i in range(self._num_groups):
                rs_ranks.append([group_i*self._group_size+j for j in range(self._group_size)])
            self._rs_pg = []
            for group_i in range(self._num_groups):
                ranks = rs_ranks[group_i]
                for i in range(self._num_rs_pg):
                    grp = torch.distributed.new_group(ranks=ranks)
                    if torch.distributed.get_rank() in ranks:
                        self._rs_pg.append(grp)
                    if self._verbose:
                        print(f"creating RS group : {ranks}")
                l2_grad_norm_pg = torch.distributed.new_group(ranks=ranks)
                if torch.distributed.get_rank() in ranks:
                    self._l2_grad_norm_pg = l2_grad_norm_pg
            self._rs_st = [torch.cuda.Stream() for _ in range(self._num_rs_pg)]
            if self._num_ag_pg == 0:
                self._ag_pg = self._rs_pg
                self._ag_st = self._rs_st
                self._num_ag_pg = self._num_rs_pg
            else:
                self._ag_pg = []
                for group_i in range(self._num_groups):
                    ranks = rs_ranks[group_i]
                    for i in range(self._num_ag_pg):
                        grp = torch.distributed.new_group(ranks=ranks)
                        if torch.distributed.get_rank() in ranks:
                            self._ag_pg.append(grp)
                        if self._verbose:
                            print(f"creating AG group : {ranks}")
                self._ag_st = [torch.cuda.Stream() for _ in range(self._num_ag_pg)]
        for ag_pg in self._ag_pg:
            torch.distributed.barrier(group=ag_pg)

        self._l2_grad_norm_st = torch.cuda.Stream()
        self._completion_st = torch.cuda.Stream()
        self._step.record_stream(self._completion_st)

        self._reductions_works = [None]*self._num_blocks
        self._allgather_works = [None]*self._num_blocks

        self._one = torch.cuda.IntTensor([1])

        self._first_step = True
        self._lazy_init_stage1_done, self._lazy_init_stage2_done = False, False
        self._param_order = self.AtomicCounter()

        p_offset = 0
        p_i = 0
        self._model_params = []
        self._grad_accs = []
        self._group_properties = []
        for group in self.param_groups:
            prev = None
            beta1, beta2 = group['betas']
            beta3 = 1.0 - beta1 if self._grad_averaging else 1.0
            bias_correction = 1 if group['bias_correction'] else 0
            eps = group['eps']
            weight_decay = group['weight_decay']
            for p in group['params']:
                if not p.requires_grad:
                    continue
                self._model_params.append(p)
                self._group_properties.append((
                    weight_decay,
                    bias_correction,
                    beta1,
                    beta2,
                    beta3,
                    eps
                ))
                p_grads_size = p.numel()
                if self._set_flat_param_view:
                    if param_order:
                        # this is executed when param_order is specified by the user
                        self._param_order.add(param_order[p])
                    else:
                        self._param_order.add(p_i)
                p_offset += p_grads_size
                # Only enforce 128b alignment (64 * fp16) for non-consecutive parameters
                # RNN is one example of consecutive parameters:
                # (weight_ih, weight_hh, bias_ih, bias_hh)
                if prev is not None and (prev.data_ptr() + prev.numel() * prev.element_size() != p.data_ptr()):
                    p_offset = ((p_offset + 63) // 64) * 64
                prev = p
                p_i += 1
        if param_order:
            self._param_order.order = torch.argsort(torch.tensor(self._param_order.order)).tolist()
        self._grads_generated = [False]*len(self._model_params)
        self._grads_fp16, self._grads_fp32 = [], []
        if self._overlap_reductions:
            self._current_block = self._num_blocks

        self._net_total_param_size = p_offset
        self._total_param_size = p_offset
        dwu_min_page_size = 256 * self._num_blocks * self._num_chunks * self._group_size
        self._total_param_size = ((self._total_param_size + dwu_min_page_size - 1) // dwu_min_page_size) * dwu_min_page_size
        self._new_params = torch.zeros([self._total_param_size], dtype=torch.uint8 if self._e5m2_allgather else torch.float16, device='cuda')
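    # Layout note (illustrative numbers, not taken from this file): the padded flat
    # gradient buffer is split into dwu_num_blocks blocks, each block into
    # dwu_num_chunks chunks, and each chunk into group_size shards. For example,
    # with 4 blocks, 4 chunks and a group size of 8, dwu_min_page_size is
    # 256*4*4*8 = 32768; a model that pads up to exactly 32768 elements gets
    # block_size = 8192, chunk_size = 2048 and shard_size = 256, so each rank owns
    # num_blocks * num_chunks * shard_size = 4096 elements of optimizer state.
    # The padding above guarantees these divisions are always exact.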
    def _lazy_init_stage1(self):
        if self._lazy_init_stage1_done: return

        p_i = 0
        #self._model_params = []
        #self._grad_accs = []
        #self._group_properties = []
        for group in self.param_groups:
            for p in group['params']:
                torch.distributed.broadcast(p, 0)
                if not p.requires_grad:
                    continue

                def wrapper(param, param_i):
                    param_tmp = param.expand_as(param)
                    grad_acc = param_tmp.grad_fn.next_functions[0][0]

                    def allreduce_hook(*unused):
                        if not self._set_flat_param_view:
                            if self._first_step:
                                # first time
                                self._param_order.add(param_i)
                            else:
                                idx = self._param_order.order.index(param_i)
                                self._do_overlapped_reduction(idx, param)
                        else:
                            if not self._first_step:
                                idx = self._param_order.order.index(param_i)
                                self._do_overlapped_reduction(idx, param)
                    grad_acc.register_hook(allreduce_hook)
                    self._grad_accs.append(grad_acc)
                wrapper(p, p_i)
                p_i += 1

        self._block_size = self._total_param_size // self._num_blocks
        self._chunk_size = self._block_size // self._num_chunks
        self._shard_size = self._chunk_size // self._group_size

        self._flat_grads = torch.zeros([self._total_param_size], dtype=torch.float16, device='cuda')
        self._mega_shard_size = self._num_blocks * self._num_chunks * self._shard_size
        # initialize master weights, moments buffers if not loaded from checkpoint
        if self._fp32_p is None:
            self._fp32_p = torch.zeros([self._mega_shard_size], dtype=torch.float32, device='cuda')
            self._fp32_m = torch.zeros([self._mega_shard_size], dtype=torch.float32, device='cuda')
            self._fp32_v = torch.zeros([self._mega_shard_size], dtype=torch.float32, device='cuda')
            self._fp32_u = torch.zeros([self._mega_shard_size], dtype=torch.float32, device='cuda')
        # FIXME: Rethink fp16 label since it's either uint8 or fp16
        self._fp16_p = torch.zeros([self._mega_shard_size], dtype=torch.uint8 if self._e5m2_allgather else torch.float16, device='cuda')
        self._fp16_g = torch.zeros([self._mega_shard_size], dtype=torch.float16, device='cuda')

        def _flat_split(p):
            def __blockify(p):
                return [p[block_id*self._block_size:(block_id+1)*self._block_size] for block_id in range(self._num_blocks)]
            def __chunkify(p):
                return [p[chunk_id*self._chunk_size:(chunk_id+1)*self._chunk_size] for chunk_id in range(self._num_chunks)]
            def __shardify(p):
                return [p[shard_id*self._shard_size:(shard_id+1)*self._shard_size] for shard_id in range(self._group_size)]
            list_of_blocks = __blockify(p)
            list_of_list_of_chunks = [__chunkify(block) for block in list_of_blocks]
            list_of_list_of_list_of_shards = [[__shardify(chunk) for chunk in chunks] for chunks in list_of_list_of_chunks]
            return list_of_blocks, list_of_list_of_chunks, list_of_list_of_list_of_shards
        # note(crcrpar): the function below doesn't seem to be used at all.
        # def _flat_split_no_shards(p):
        #     def __blockify(p):
        #         return [p[block_id*self._block_size:(block_id+1)*self._block_size] for block_id in range(self._num_blocks)]
        #     def __chunkify(p):
        #         return [p[chunk_id*self._chunk_size:(chunk_id+1)*self._chunk_size] for chunk_id in range(self._num_chunks)]
        #     list_of_blocks = __blockify(self._flat_grads)
        #     list_of_list_of_chunks = [__chunkify(block) for block in list_of_blocks]
        #     return list_of_blocks, list_of_list_of_chunks
        def _full_packed_split(p):
            def __shardify(p):
                return [p[mega_shard*self._mega_shard_size:(mega_shard+1)*self._mega_shard_size] for mega_shard in range(self._group_size)]
            def __blockify(p):
                return [p[block_id*self._num_chunks*self._shard_size:(block_id+1)*self._num_chunks*self._shard_size] for block_id in range(self._num_blocks)]
            def __chunkify(p):
                return [p[chunk_id*self._shard_size:(chunk_id+1)*self._shard_size] for chunk_id in range(self._num_chunks)]
            list_of_mega_shards = __shardify(p)
            list_of_list_of_mega_blocks = [__blockify(mega_shard) for mega_shard in list_of_mega_shards]
            list_of_list_of_list_of_mega_chunks = [[__chunkify(mega_block) for mega_block in mega_blocks] for mega_blocks in list_of_list_of_mega_blocks]
            return list_of_mega_shards, list_of_list_of_mega_blocks, list_of_list_of_list_of_mega_chunks
        def _packed_split(p):
            def __packed_blockify(p):
                packed_block_size = self._num_chunks*self._shard_size
                return [p[block_id*packed_block_size:(block_id+1)*packed_block_size] for block_id in range(self._num_blocks)]
            def __packed_chunkify(p):
                # in the packed format, each chunk contains one shard, so packed_chunk_size == self._shard_size
                return [p[chunk_id*self._shard_size:(chunk_id+1)*self._shard_size] for chunk_id in range(self._num_chunks)]
            list_of_blocks = __packed_blockify(p)
            list_of_list_of_chunks = [__packed_chunkify(block) for block in list_of_blocks]
            return list_of_blocks, list_of_list_of_chunks
        def _split_assign(shards):
            packed_block_size = self._num_chunks*self._shard_size
            list_of_list_of_chunks = []
            for block_id in range(self._num_blocks):
                list_of_chunks = []
                for chunk_id in range(self._num_chunks):
                    #self._fp16_g[block_id*packed_block_size+chunk_id*self._shard_size:block_id*packed_block_size+(chunk_id+1)*self._shard_size] = shards[block_id][chunk_id][self._rank_in_group]
                    list_of_chunks.append(shards[block_id][chunk_id][self._rank_in_group])
                list_of_list_of_chunks.append(list_of_chunks)
            return list_of_list_of_chunks

        self._new_params_mega_shards, self._new_params_mega_blocks, self._new_params_mega_chunks = _full_packed_split(self._new_params)
        # this splitting scheme is needed when allgather needs to be split into multiple chunks in a contiguous way
        self._new_params2_blocks, self._new_params2_chunks, self._new_params2_shards = _flat_split(self._new_params)

        self._fp32_p_blocks, self._fp32_p_chunks = _packed_split(self._fp32_p)
        self._fp32_m_blocks, self._fp32_m_chunks = _packed_split(self._fp32_m)
        self._fp32_v_blocks, self._fp32_v_chunks = _packed_split(self._fp32_v)
        self._fp32_u_blocks, self._fp32_u_chunks = _packed_split(self._fp32_u)
        self._fp16_p_blocks, self._fp16_p_chunks = _packed_split(self._fp16_p)

        if self._full_ar:
            # for gradient all-reduce
            self._flat_grads_blocks, self._flat_grads_chunks, self._flat_grads_shards = _flat_split(self._flat_grads)
            # for weight update
            self._fp16_g_chunks = _split_assign(self._flat_grads_shards)
        else:
            self._flat_grads_blocks, self._flat_grads_chunks, self._flat_grads_shards = _flat_split(self._flat_grads)
            self._fp16_g_blocks, self._fp16_g_chunks = _packed_split(self._fp16_g)

        self._lazy_init_stage1_done = True
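    # Note on the two layouts built above: _flat_split views a buffer in model
    # order ([block][chunk][shard]), which is what the gradient collectives
    # consume, while _packed_split / _full_packed_split view the per-rank
    # "mega shard" buffers ([shard][block][chunk]), the contiguous layout that the
    # fused LAMB kernels and the final all-gather operate on.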
    def _lazy_init_stage2(self):
        if self._lazy_init_stage2_done: return

        if not self._set_flat_param_view:
            # reversing is needed for overlapping allreduce and backprop, but currently not supported for flat param view
            self._param_order.order.reverse()

        # re-order model_params, grad_accs, group_properties lists
        self._model_params = [self._model_params[i] for i in self._param_order.order]
        self._grad_accs = [self._grad_accs[i] for i in self._param_order.order]
        self._group_properties = [self._group_properties[i] for i in self._param_order.order]

        def _get_flat_view(param):
            if param.is_contiguous(memory_format=torch.channels_last):
                K, C, H, W = param.shape
                pv = param.as_strided(size=(K,H,W,C), stride=(H*W*C, W*C, C, 1))
            elif param.is_contiguous(memory_format=torch.channels_last_3d):
                K, C, D, H, W = param.shape
                pv = param.as_strided(size=(K,D,H,W,C), stride=(D*H*W*C, H*W*C, W*C, C, 1))
            else:
                pv = param
            return pv.view(-1)

        # re-collect grads info (size, offset) after ordering
        prev = None
        p_offset = 0
        self._grads_info = []
        self._individual_flat_grads = []
        for i, p in enumerate(self._model_params):
            p_grads_size = p.numel()
            self._grads_info.append({"param_grads_size":p_grads_size, "param_offset":p_offset})
            self._individual_flat_grads.append(self._flat_grads[p_offset:p_offset+p_grads_size].view_as(p))
            # for the first iteration
            self._do_overlapped_reduction(i, p)
            p_offset += p_grads_size
            # Only enforce 128b alignment (64 * fp16) for non-consecutive parameters
            # RNN is one example of consecutive parameters:
            # (weight_ih, weight_hh, bias_ih, bias_hh)
            if prev is not None and (prev.data_ptr() + prev.numel() * prev.element_size() != p.data_ptr()):
                p_offset = ((p_offset + 63) // 64) * 64
            prev = p

        self._low_param_i = [0]*self._num_blocks
        for block_id in range(self._num_blocks-1,-1,-1):
            p_i = len(self._grads_info)-1
            while p_i > 0 and self._grads_info[p_i]["param_offset"] > block_id*self._block_size:
                p_i -= 1
            self._low_param_i[block_id] = p_i
        #print("self._low_param_i", self._low_param_i)

        # This paragraph does two things:
        # 1) Copy model parameters into master buffer
        # 2) Create tensor lists for unpacking new parameter tensor after all-gather
        self._packed_flat_to_model_params_fp16 = []
        self._packed_flat_to_model_params_fp32 = []
        self._model_params_num = len(self._model_params)
        self._contrib_tensor_list = []
        self._contrib_min_param_i, self._contrib_max_param_i = -1, -1
        self._contrib_update_frag_for_norm = []
        self._contrib_model_param_for_norm_fp16 = []
        self._contrib_model_param_for_norm_fp32 = []
        self._contrib_model_param_for_norm_is_fp16 = []
        self._model_param_is_contrib = []
        self._contrib_group_properties = []
        for shard_id in range(self._group_size):
            for block_id in range(self._num_blocks):
                for chunk_id in range(self._num_chunks):
                    flat_shard_start = (((block_id * self._num_chunks + chunk_id) * self._group_size) + shard_id) * self._shard_size
                    flat_shard_end = flat_shard_start + self._shard_size
                    for param_i, (p, grads_info, group_props) in enumerate(zip(self._model_params, self._grads_info, self._group_properties)):
                        flat_grad_start = grads_info["param_offset"]
                        flat_grad_end = flat_grad_start + grads_info["param_grads_size"]
                        clipped_start = (lambda a,b: a if a > b else b)(flat_grad_start, flat_shard_start)
                        clipped_end = (lambda a,b: a if a < b else b)(flat_grad_end, flat_shard_end)
                        if clipped_start < clipped_end:
                            grad_offset = clipped_start - flat_grad_start
                            grad_length = clipped_end - clipped_start
                            shard_offset = clipped_start - flat_shard_start
                            pf = _get_flat_view(p)
                            model_param_fragment = pf[grad_offset:grad_offset+grad_length]
                            new_param_packed_fragment = self._new_params_mega_chunks[shard_id][block_id][chunk_id][shard_offset:shard_offset+grad_length]
                            if model_param_fragment.dtype == torch.float16:
                                self._packed_flat_to_model_params_fp16.append( (new_param_packed_fragment, model_param_fragment) )
                            else:
                                self._packed_flat_to_model_params_fp32.append( (new_param_packed_fragment, model_param_fragment) )
                            if shard_id == self._rank_in_group:
                                self._model_param_is_contrib.append(param_i)
                                # copy model parameters into master buffer
                                master_param_fragment = self._fp32_p_chunks[block_id][chunk_id][shard_offset:shard_offset+grad_length]
                                opti_state_m_fragment = self._fp32_m_chunks[block_id][chunk_id][shard_offset:shard_offset+grad_length]
                                opti_state_v_fragment = self._fp32_v_chunks[block_id][chunk_id][shard_offset:shard_offset+grad_length]
                                opti_state_u_fragment = self._fp32_u_chunks[block_id][chunk_id][shard_offset:shard_offset+grad_length]
                                opti_state_g_fragment = self._fp16_g_chunks[block_id][chunk_id][shard_offset:shard_offset+grad_length]
                                opti_state_p_fragment = self._fp16_p_chunks[block_id][chunk_id][shard_offset:shard_offset+grad_length]
                                #print("model_param_fragment.size()=%s, new_param_packed_fragment.size()=%s, master_param_fragment.size()=%s" % (str(model_param_fragment.size()), str(new_param_packed_fragment.size()), str(master_param_fragment.size())))
                                if not self._resume_from_checkpoint:
                                    master_param_fragment.copy_(model_param_fragment)
                                self._contrib_group_properties.append(group_props)
                                self._contrib_tensor_list.append((master_param_fragment, opti_state_m_fragment, opti_state_v_fragment, opti_state_u_fragment, opti_state_g_fragment, opti_state_p_fragment)) # p, m, v, u, g, p_copy
                                self._contrib_update_frag_for_norm.append(opti_state_u_fragment)
                                if p.dtype == torch.float16:
                                    self._contrib_model_param_for_norm_fp16.append(p)
                                else:
                                    self._contrib_model_param_for_norm_fp32.append(p)
                                self._contrib_model_param_for_norm_is_fp16.append(True if p.dtype == torch.float16 else False)
                                if self._contrib_min_param_i < 0: self._contrib_min_param_i = param_i
                                self._contrib_max_param_i = param_i
        self._contrib_model_param_for_norm_num = len(self._contrib_model_param_for_norm_is_fp16)
        if len(self._contrib_model_param_for_norm_fp16) == 0: self._contrib_model_param_for_norm_fp16 = None
        if len(self._contrib_model_param_for_norm_fp32) == 0: self._contrib_model_param_for_norm_fp32 = None
        self._contrib_model_param_for_norm_is_fp32 = torch.tensor([not is_fp16 for is_fp16 in self._contrib_model_param_for_norm_is_fp16], dtype=torch.bool, device='cuda')
        self._contrib_model_param_for_norm_is_fp16 = torch.tensor([is_fp16 for is_fp16 in self._contrib_model_param_for_norm_is_fp16], dtype=torch.bool, device='cuda')
        self._offsets = torch.tensor(self._model_param_is_contrib, dtype=torch.int64, device='cuda')

        p, m, v, u, g, p_copy = list(zip(*self._contrib_tensor_list))
        self._contrib_compute_update_term_tensor_list = [g, p, m, v, u]
        self._contrib_update_weights_tensor_list = [u, p, p_copy]

        math_type = self._fp32_u.dtype
        decay, bias_correction, beta1, beta2, beta3, epsilon = list(zip(*self._contrib_group_properties))
        self._contrib_beta1 = torch.tensor(beta1, dtype=math_type, device='cuda')
        self._contrib_beta2 = torch.tensor(beta2, dtype=math_type, device='cuda')
        self._contrib_beta3 = torch.tensor(beta3, dtype=math_type, device='cuda')
        self._contrib_bias_correction = torch.tensor(bias_correction, dtype=torch.int, device='cuda')
        self._contrib_epsilon = torch.tensor(epsilon, dtype=math_type, device='cuda')
        self._contrib_weight_decay = torch.tensor(decay, dtype=math_type, device='cuda')

        self._packed_flat_to_model_params_fp16 = list(zip(*self._packed_flat_to_model_params_fp16)) if len(self._packed_flat_to_model_params_fp16) > 0 else None
        self._packed_flat_to_model_params_fp32 = list(zip(*self._packed_flat_to_model_params_fp32)) if len(self._packed_flat_to_model_params_fp32) > 0 else None

        self._lazy_init_stage2_done = True

        self.complete_reductions()
        self._first_step = False
    def set_is_accumulation_step(self, is_accumulation_step):
        self._is_accumulation_step = is_accumulation_step

    def set_last_step(self, last_step):
        self._last_step = last_step
    def _get_flush_block(self):
        flush_block = []
        if self._current_block > 0 and self._grads_generated[self._low_param_i[self._current_block-1]]:
            num_grads = len(self._grads_generated)
            contiguous_idx = num_grads
            while contiguous_idx > 0 and self._grads_generated[contiguous_idx-1]:
                contiguous_idx -= 1

            if contiguous_idx < num_grads and self._grads_info[contiguous_idx]["param_offset"] <= (self._current_block-1)*self._block_size:
                self._current_block -= 1
                start = self._current_block * self._block_size
                end = (self._current_block+1) * self._block_size
                flush_block = [start, end]

        return flush_block
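    # Illustration (hypothetical sizes, not from this file): with 4 blocks over a
    # 32768-element flat buffer, the last block covers offsets [24576, 32768).
    # Gradients arrive roughly in reverse parameter order during backprop, so a
    # block is flushed once a contiguous run of generated gradients reaches back
    # to the block's start offset; _low_param_i caches, per block, the lowest
    # parameter index that still overlaps the block, which is the cheap check
    # performed first above.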
    def _full_all_reduce_scale(self, block_id, scale):
        works = [None]*self._num_chunks
        if self._clip_after_ar:
            for chunk_id in range(self._num_chunks):
                glob_chunk_id = block_id * self._num_chunks + chunk_id
                ar_stream = self._ar_st[glob_chunk_id%self._num_ar_pg]
                ar_stream.wait_stream(torch.cuda.current_stream())
                with torch.cuda.stream(ar_stream):
                    works[chunk_id] = torch.distributed.all_reduce(self._flat_grads_chunks[block_id][chunk_id],group=self._ar_pg[glob_chunk_id%self._num_ar_pg],async_op=True,op=_make_nccl_premul_sum(scale))
        else:
            glob_chunk_id = block_id
            ar_stream = self._ar_st[glob_chunk_id%self._num_ar_pg]
            ar_stream.wait_stream(torch.cuda.current_stream())
            with torch.cuda.stream(ar_stream):
                works0 = torch.distributed.all_reduce(self._flat_grads_blocks[block_id],group=self._ar_pg[glob_chunk_id%self._num_ar_pg],async_op=True,op=_make_nccl_premul_sum(scale))
            for i in range(self._num_chunks):
                works[i] = works0
        self._reductions_works[block_id] = works

    def _full_all_reduce(self, block_id):
        works = [None]*self._num_chunks

        for chunk_id in range(self._num_chunks):
            glob_chunk_id = block_id * self._num_chunks + chunk_id
            ar_stream = self._ar_st[glob_chunk_id%self._num_ar_pg]
            ar_stream.wait_stream(torch.cuda.current_stream())
            with torch.cuda.stream(ar_stream):
                works[chunk_id] = torch.distributed.all_reduce(self._flat_grads_chunks[block_id][chunk_id],group=self._ar_pg[glob_chunk_id%self._num_ar_pg],async_op=True)
        self._reductions_works[block_id] = works
    def _reduce_scatter_and_all_reduce_scale(self, block_id, scale):
        # Reduction within each node
        # Changes gradient format from [block * chunk * shard] to [shard * block * chunk]
        # The output format is the same as the fp32 master parameters
        works = [None]*self._num_chunks
        for chunk_id in range(self._num_chunks):
            glob_chunk_id = block_id * self._num_chunks + chunk_id
            rs_stream = self._rs_st[glob_chunk_id%self._num_rs_pg]
            rs_stream.wait_stream(torch.cuda.current_stream())
            rs_stream.wait_stream(self._l2_grad_norm_st)
            with torch.cuda.stream(rs_stream):
                if self._reduce_scatter_no_copy:
                    works[chunk_id] = torch.distributed.reduce_scatter(
                        output=self._fp16_g_chunks[block_id][chunk_id],
                        input_list=self._flat_grads_shards[block_id][chunk_id],
                        group=self._rs_pg[glob_chunk_id%self._num_rs_pg],
                        async_op=True,
                        no_copy=True,
                        op=_make_nccl_premul_sum(scale),
                    )
                else:
                    works[chunk_id] = torch.distributed.reduce_scatter_tensor(
                        output=self._fp16_g_chunks[block_id][chunk_id],
                        input=self._flat_grads_chunks[block_id][chunk_id],
                        group=self._rs_pg[glob_chunk_id%self._num_rs_pg],
                        async_op=True,
                        op=_make_nccl_premul_sum(scale),
                    )

        # Reduction across nodes for each rank
        if self._num_groups > 1:
            for chunk_id in range(self._num_chunks):
                glob_chunk_id = block_id * self._num_chunks + chunk_id
                ar_stream = self._ar_st[glob_chunk_id%self._num_ar_pg]
                with torch.cuda.stream(ar_stream):
                    works[chunk_id].wait()
                    works[chunk_id] = torch.distributed.all_reduce(self._fp16_g_chunks[block_id][chunk_id],group=self._ar_pg[glob_chunk_id%self._num_ar_pg],async_op=True)
        self._reductions_works[block_id] = works

    def _reduce_scatter_and_all_reduce(self, block_id):
        # Reduction within each node
        # Changes gradient format from [block * chunk * shard] to [shard * block * chunk]
        # The output format is the same as the fp32 master parameters
        works = [None]*self._num_chunks
        for chunk_id in range(self._num_chunks):
            glob_chunk_id = block_id * self._num_chunks + chunk_id
            rs_stream = self._rs_st[glob_chunk_id%self._num_rs_pg]
            rs_stream.wait_stream(torch.cuda.current_stream())
            with torch.cuda.stream(rs_stream):
                if self._reduce_scatter_no_copy:
                    works[chunk_id] = torch.distributed.reduce_scatter(
                        output=self._fp16_g_chunks[block_id][chunk_id],
                        input_list=self._flat_grads_shards[block_id][chunk_id],
                        group=self._rs_pg[glob_chunk_id%self._num_rs_pg],
                        async_op=True,
                        no_copy=True,
                    )
                else:
                    works[chunk_id] = torch.distributed.reduce_scatter_tensor(
                        output=self._fp16_g_chunks[block_id][chunk_id],
                        input=self._flat_grads_chunks[block_id][chunk_id],
                        group=self._rs_pg[glob_chunk_id%self._num_rs_pg],
                        async_op=True,
                    )

        # Reduction across nodes for each rank
        if self._num_groups > 1:
            for chunk_id in range(self._num_chunks):
                glob_chunk_id = block_id * self._num_chunks + chunk_id
                ar_stream = self._ar_st[glob_chunk_id%self._num_ar_pg]
                with torch.cuda.stream(ar_stream):
                    works[chunk_id].wait()
                    works[chunk_id] = torch.distributed.all_reduce(self._fp16_g_chunks[block_id][chunk_id],group=self._ar_pg[glob_chunk_id%self._num_ar_pg],async_op=True)
        self._reductions_works[block_id] = works
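    # The *_scale variants above differ from the plain versions only in that they
    # fold the pre-reduction scaling into the collective via a NCCL premul-sum
    # reduce op (when the running PyTorch exposes it), avoiding a separate
    # element-wise multiply over the flat gradient buffer.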
    def _pipeline_block_reductions(self, block_id):
        if self._clip_after_ar:
            self._flatten_grad_mt(1.0/self._world_size)

            if self._full_ar:
                self._full_all_reduce(block_id)
            else:
                self._reduce_scatter_and_all_reduce(block_id)

            # Compute L2 grad norm
            if block_id == 0:
                with torch.cuda.stream(self._l2_grad_norm_st):
                    for block_id in range(self._num_blocks):
                        for chunk_id in range(self._num_chunks):
                            self._reductions_works[block_id][chunk_id].wait()
                    # Since the packed format is contiguous after reductions, only one norm is needed
                    l2_grad_norm_sq = torch.empty([1], device='cuda')
                    if self._full_ar:
                        # this flattening of lists is to keep multi_tensor_apply function happy, it wants depth=1 for l2 norm computation
                        flat_list = [item for sublist in self._fp16_g_chunks for item in sublist]
                        l2_grad_norm_sq = multi_tensor_applier(self.multi_tensor_l2norm, self._overflow_buf, [flat_list], False)[0]**2
                    else:
                        l2_grad_norm_sq = self._fp16_g.norm(dtype=torch.float32, p=2)**2
                    torch.distributed.all_reduce(l2_grad_norm_sq, group=self._l2_grad_norm_pg)
                    self._L2_grad_norm = l2_grad_norm_sq.sqrt()
        else:
            # Copy model grads to flat grads buffer
            self._flatten_grad_mt(1.0)

            # Compute L2 grad norm
            self._l2_grad_norm_st.wait_stream(torch.cuda.current_stream())
            with torch.cuda.stream(self._l2_grad_norm_st):
                if not self._fused_norm:
                    self._L2_grad_norm = self._flat_grads.norm(dtype=torch.float16, p=2).float()
            torch.cuda.current_stream().wait_stream(self._l2_grad_norm_st)

            # Apply clipping & pre-reduction scaling on grads
            loss_scale = self.global_scale
            max_grad_norm = loss_scale*self.defaults['max_grad_norm']
            coeff = max_grad_norm / (1e-6+self.L2_grad_norm)
            coeff = (coeff>1) * self._one + (coeff<=1) * coeff
            tmp = torch.cat(((self._one), (coeff)))
            index = (coeff+1>coeff).int()
            scale = tmp.index_select(0, index).half()/self._world_size
            if not self._fuse_scale:
                self._flat_grads.mul_(scale)

            if self._full_ar:
                if self._fuse_scale:
                    self._full_all_reduce_scale(block_id, scale)
                else:
                    self._full_all_reduce(block_id)
            else:
                if self._fuse_scale:
                    self._reduce_scatter_and_all_reduce_scale(block_id, scale)
                else:
                    self._reduce_scatter_and_all_reduce(block_id)

            if block_id == 0:
                for block_id in range(self._num_blocks):
                    for chunk_id in range(self._num_chunks):
                        self._reductions_works[block_id][chunk_id].wait()
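    # Summary of the two paths above: with clip_after_ar the raw gradients are
    # averaged (1/world_size), reduced, and the global L2 norm is computed from
    # the reduced shards, so clipping is deferred to the step kernel. Without it,
    # the norm is computed locally first, a combined clip/average factor is
    # applied (or fused into the collective when fuse_scale is set), and only then
    # are the gradients reduced.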
    def __compute_contrib_param_norm(self):
        if self._contrib_model_param_for_norm_fp16 is not None and self._contrib_model_param_for_norm_fp32 is not None:
            gnorm_fp16 = multi_tensor_applier(self.multi_tensor_l2norm, self._overflow_buf, [self._contrib_model_param_for_norm_fp16], True)[1]
            gnorm_fp32 = multi_tensor_applier(self.multi_tensor_l2norm, self._overflow_buf, [self._contrib_model_param_for_norm_fp32], True)[1]
            # accumulator must be floating point; the per-tensor norms are scattered into it
            gnorm = torch.empty(size=[self._contrib_model_param_for_norm_num], dtype=torch.float32, device='cuda')
            gnorm.masked_scatter_(self._contrib_model_param_for_norm_is_fp16, gnorm_fp16)
            gnorm.masked_scatter_(self._contrib_model_param_for_norm_is_fp32, gnorm_fp32)
        elif self._contrib_model_param_for_norm_fp16 is not None:
            gnorm = multi_tensor_applier(self.multi_tensor_l2norm, self._overflow_buf, [self._contrib_model_param_for_norm_fp16], True)[1]
        elif self._contrib_model_param_for_norm_fp32 is not None:
            gnorm = multi_tensor_applier(self.multi_tensor_l2norm, self._overflow_buf, [self._contrib_model_param_for_norm_fp32], True)[1]
        return gnorm

    def __compute_contrib_update_norm(self):
        l2_norm = torch.zeros(size=[self._model_params_num], dtype=torch.float32, device='cuda')
        local_contrib_l2_norm = multi_tensor_applier(self.multi_tensor_l2norm, self._overflow_buf, [self._contrib_update_frag_for_norm], True)[1] ** 2
        l2_norm.scatter_(dim=0, index=self._offsets, src=local_contrib_l2_norm)
        torch.distributed.all_reduce(l2_norm, group=self._ag_pg[0])
        l2_norm = torch.sqrt(l2_norm)
        return l2_norm
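    # These two norms feed the per-parameter LAMB trust ratio applied inside
    # multi_tensor_lamb_update_weights; conceptually (exact handling of zero norms
    # and use_nvlamb lives in the CUDA kernel, not here):
    #   ratio_i = ||w_i|| / ||update_i||
    #   w_i    <- w_i - lr * ratio_i * update_i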
    def _pipeline_step(self):
        global_scale = self.global_scale
        # if clip before ar, set max_grad_norm to 0
        max_grad_norm = self.defaults['max_grad_norm'] * self._clip_after_ar
        self._completion_st.wait_stream(self._l2_grad_norm_st)
        global_grad_norm = self.L2_grad_norm

        # check global_grad_norm and fill overflow_buf
        is_finite = (global_grad_norm + 1 > global_grad_norm).int()
        self._overflow_buf = self._one * (is_finite ^ self._one)  # toggle between 0 and 1

        if not self._clip_after_ar:
            torch.distributed.all_reduce(is_finite,
                                         op=torch.distributed.ReduceOp.MIN,
                                         group=self._current_process_group)
            torch.distributed.all_reduce(self._overflow_buf,
                                         op=torch.distributed.ReduceOp.MAX,
                                         group=self._current_process_group)

        # increment step counter if no overflow
        self._step += is_finite
        self._completion_st.wait_stream(torch.cuda.current_stream())
        self._completion_st.wait_stream(self._l2_grad_norm_st)

        # Call step kernel once per step
        # Call all-gather once per step
        with torch.cuda.stream(self._completion_st):
            for block_id in range(self._num_blocks):
                for chunk_id in range(self._num_chunks):
                    self._reductions_works[block_id][chunk_id].wait()
            param_norm = self.__compute_contrib_param_norm()
            multi_tensor_applier(self.multi_tensor_lamb_compute_update_term,
                    self._overflow_buf,
                    self._contrib_compute_update_term_tensor_list, # g, p, m, v, u
                    self._contrib_beta1,
                    self._contrib_beta2,
                    self._contrib_beta3,
                    self._contrib_bias_correction,
                    self._step,
                    self._contrib_epsilon,
                    self._adam_w_mode,
                    self._contrib_weight_decay,
                    global_scale,
                    global_grad_norm,
                    max_grad_norm)
            upd_norm = self.__compute_contrib_update_norm()
            multi_tensor_applier(self.multi_tensor_lamb_update_weights,
                    self._overflow_buf,
                    self._contrib_update_weights_tensor_list, # u, p, p_copy
                    param_norm,
                    upd_norm,
                    self._offsets,
                    self._lr,
                    self._contrib_weight_decay,
                    global_grad_norm,
                    self._use_nvlamb)
            if not self._skip_ag:
                # allgather chunking is currently not supported for clip after allreduce
                if not self._clip_after_ar:
                    for block in range(self._num_blocks):
                        for chunk in range(self._num_chunks):
                            if self._all_gather_no_copy:
                                torch.distributed.all_gather(
                                    tensor_list = self._new_params2_shards[block][chunk],
                                    tensor = self._fp16_p_chunks[block][chunk],
                                    group = self._ag_pg[0],
                                    no_copy = True,
                                )
                            else:
                                torch.distributed.all_gather_into_tensor(
                                    output_tensor = self._new_params2_blocks[block],
                                    input_tensor = self._fp16_p_chunks[block][chunk],
                                    group = self._ag_pg[0],
                                )
                else:
                    if self._all_gather_no_copy:
                        torch.distributed.all_gather(
                            tensor_list = self._new_params_mega_shards,
                            tensor = self._fp16_p,
                            group = self._ag_pg[0],
                            no_copy = True,
                        )
                    else:
                        torch.distributed.all_gather_into_tensor(
                            output_tensor = self._new_params,
                            input_tensor = self._fp16_p,
                            group = self._ag_pg[0],
                        )
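    # After the update kernels, each rank holds only its own updated shard in
    # self._fp16_p; the all-gather above reassembles the full parameter buffer
    # into self._new_params, which step() later unpacks back into the model
    # parameters.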
    def _flatten_grad_mt(self, scale):
        if len(self._grads_fp16) > 0:
            self._overflow_buf.zero_()
            if not self._fused_norm:
                multi_tensor_applier(
                        amp_C.multi_tensor_scale,
                        self._overflow_buf,
                        list(zip(*self._grads_fp16)),
                        scale)
            else:
                self._L2_grad_norm = multi_tensor_applier(
                        amp_C.multi_tensor_l2norm_scale,
                        self._overflow_buf,
                        list(zip(*self._grads_fp16)),
                        scale, False)[0].float()
            self._grads_fp16 = []
        if len(self._grads_fp32) > 0:
            self._overflow_buf.zero_()
            if not self._fused_norm:
                multi_tensor_applier(
                        amp_C.multi_tensor_scale,
                        self._overflow_buf,
                        list(zip(*self._grads_fp32)),
                        scale)
            else:
                self._L2_grad_norm = multi_tensor_applier(
                        amp_C.multi_tensor_l2norm_scale,
                        self._overflow_buf,
                        list(zip(*self._grads_fp32)),
                        scale, False)[0].float()
            self._grads_fp32 = []
    def _do_overlapped_reduction(self, param_i, param):
        if not self._is_accumulation_step:
            # handle overlapped reductions
            if param.dtype == torch.float16:
                self._grads_fp16.append( (param.grad, self._individual_flat_grads[param_i]) )
            else:
                self._grads_fp32.append( (param.grad, self._individual_flat_grads[param_i]) )
            self._grads_generated[param_i] = True
            if not self._first_step and not self._last_step:
                if self._overlap_reductions:
                    flush_block = self._get_flush_block()
                    while flush_block:
                        block_id = flush_block[0] // self._block_size
                        self._pipeline_block_reductions(block_id)
                        flush_block = self._get_flush_block()
    def set_global_scale(self, global_scale):
        """Set global scale.
        """
        self._global_scale = global_scale

    @property
    def global_scale(self):
        return self._global_scale

    @property
    def L2_grad_norm(self):
        torch.cuda.current_stream().wait_stream(self._l2_grad_norm_st)
        return self._L2_grad_norm
    def complete_reductions(self):
        """Complete reductions if full pipeline is not selected or overlap is not allowed.
        """
        if self._last_step:
            # zero out gradients that have not been completed yet
            for param_i, grad_generated in enumerate(self._grads_generated):
                if not grad_generated:
                    grad_info = self._grads_info[param_i]
                    param_offset = grad_info["param_offset"]
                    param_size = grad_info["param_grads_size"]
                    self._flat_grads[param_offset:param_offset+param_size].zero_()
                    self._grads_generated[param_i] = True

        if self._first_step or self._last_step or not self._overlap_reductions:
            # nothing done so far, run full pipeline after reductions
            for block_id in range(self._num_blocks-1,-1,-1):
                self._pipeline_block_reductions(block_id)

        torch.cuda.current_stream().wait_stream(self._l2_grad_norm_st)

        self._current_block = self._num_blocks
        self._grads_generated = [False]*len(self._grads_info)
    def step(self, closure=None, grad_scaler=None):
        loss = None
        if closure is not None:
            loss = closure()

        self._pipeline_step()

        if grad_scaler is not None:
            found_inf = self._overflow_buf.float()
            optimizer_state = grad_scaler._per_optimizer_states[id(self)]
            current_device = torch.device('cuda', torch.cuda.current_device())
            optimizer_state["found_inf_per_device"][current_device] = found_inf

        self._completion_st.wait_stream(torch.cuda.current_stream())
        if not self._set_flat_param_view:
            with torch.cuda.stream(self._completion_st):
                # Copy self._new_params to model params
                with torch.no_grad():
                    if self._packed_flat_to_model_params_fp16 is not None:
                        multi_tensor_applier(
                                fused_adam_cuda.maybe_cast_mt,
                                self._overflow_buf,
                                self._packed_flat_to_model_params_fp16)
                    if self._packed_flat_to_model_params_fp32 is not None:
                        multi_tensor_applier(
                                fused_adam_cuda.maybe_cast_mt,
                                self._overflow_buf,
                                self._packed_flat_to_model_params_fp32)

        torch.cuda.current_stream().wait_stream(self._completion_st)

        self._reductions_works = [None]*self._num_blocks
        self._allgather_works = [None]*self._num_blocks

        return loss
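    # Sketch of use with torch.cuda.amp (an assumption about the caller, not part
    # of this file): since step() reports overflow through the scaler's
    # per-optimizer state and _step_supports_amp_scaling is set, the usual
    # GradScaler flow applies, e.g.
    #
    #   scaler = torch.cuda.amp.GradScaler()
    #   with torch.autocast("cuda"):
    #       loss = model(batch)                      # hypothetical forward pass
    #   scaler.scale(loss).backward()
    #   optimizer.set_global_scale(scaler.get_scale())
    #   optimizer.complete_reductions()
    #   scaler.step(optimizer)
    #   scaler.update()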
    def state_dict(self):
        """
        Returns a dict containing the current state of this :class:`DistributedFusedLAMB` instance.
        Example::

            checkpoint = {}
            checkpoint['model'] = model.state_dict()
            checkpoint['optimizer'] = optimizer.state_dict()
            torch.save(checkpoint, "saved.pth")
        """
        # save step, master weights and first/second moments
        state_dict = {}
        state_dict['step'] = self._step
        state_dict['fp32_p'] = self._fp32_p
        state_dict['fp32_m'] = self._fp32_m
        state_dict['fp32_v'] = self._fp32_v
        return state_dict
    def load_state_dict(self, state_dict):
        """
        Loads a state_dict created by an earlier call to state_dict().
        If a :class:`DistributedFusedLAMB` instance was constructed from some ``init_optimizer``,
        whose parameters in turn came from ``model``, it is expected that the user
        will call ``model.load_state_dict()`` before
        ``optimizer.load_state_dict()`` is called.
        Example::

            model = torch.nn.Linear(D_in, D_out).cuda().half()
            optimizer = torch.optim.SGD(model.parameters(), lr=1e-3)
            optimizer = FP16_Optimizer(optimizer, static_loss_scale = 128.0)
            ...
            checkpoint = torch.load("saved.pth")
            model.load_state_dict(checkpoint['model'])
            optimizer.load_state_dict(checkpoint['optimizer'])
        """
        # restore step, master weights and first/second moments
        self._step = state_dict['step']
        self._fp32_p = state_dict['fp32_p'].to(device="cuda")
        self._fp32_m = state_dict['fp32_m'].to(device="cuda")
        self._fp32_v = state_dict['fp32_v'].to(device="cuda")
        self._resume_from_checkpoint = True
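# Checkpoint note: state_dict()/load_state_dict() above carry only the step count
# and this rank's shard of the fp32 master weights and moments, so each rank must
# save and reload its own shard; setting _resume_from_checkpoint then prevents the
# master weights from being re-initialized from the model parameters during
# _lazy_init_stage2.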