- import os
- import math
- import inspect
- import torch
- import importlib
- import amp_C
- from apex.multi_tensor_apply import multi_tensor_applier
- import torch.distributed.distributed_c10d as c10d
- # Fallback to private fields if using older PyTorch version
- try:
- from torch.distributed.distributed_c10d import get_process_group_ranks
- except ImportError:
- def get_process_group_ranks(group):
- return list(c10d._pg_group_ranks[group].keys())
- _make_nccl_premul_sum = getattr(torch.distributed, "_make_nccl_premul_sum", None)
- # Ref: https://github.com/pytorch/pytorch/pull/81272
- if _make_nccl_premul_sum is None:
- if hasattr(torch.distributed, "make_nccl_premul_sum"):
- _make_nccl_premul_sum = torch.distributed.make_nccl_premul_sum
- class DistributedFusedLAMB(torch.optim.Optimizer):
- """Implements LAMB algorithm.
- Currently GPU-only. Requires Apex to be installed via
- ``pip install -v --no-cache-dir --global-option="--cpp_ext" --global-option="--cuda_ext" ./``.
- This version of fused LAMB implements two fusions:
- * Fusion of the LAMB update's elementwise operations
- * A multi-tensor apply launch that batches the elementwise updates applied to all the model's parameters into one or a few kernel launches.
- :class:`DistributedFusedLAMB`'s basic usage is identical to any ordinary Pytorch optimizer::
- opt = DistributedFusedLAMB(model.parameters(), lr = ....)
- ...
- opt.step()
- :class:`DistributedFusedLAMB` may be used with or without Amp. If you wish to use it with Amp,
- you may choose any ``opt_level``::
- opt = DistributedFusedLAMB(model.parameters(), lr = ....)
- model, opt = amp.initialize(model, opt, opt_level="O0" or "O1" or "O2")
- ...
- opt.step()
- In general, ``opt_level="O1"`` is recommended.
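- A minimal sketch of one training iteration with the distributed pipeline is shown below
- (``model``, ``loss`` and ``loss_scale`` are placeholders from a typical mixed-precision
- training loop; only the ``opt`` calls are part of this class's API)::
- opt.set_global_scale(loss_scale)   # 1.0 if no loss scaling is used
- loss.backward()                    # backward hooks feed the overlapped gradient reductions
- opt.complete_reductions()          # flush any reductions that have not run yet
- opt.step()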
- LAMB was proposed in `Large Batch Optimization for Deep Learning: Training BERT in 76 minutes`_.
- Arguments:
- params (iterable): iterable of parameters to optimize or dicts defining
- parameter groups.
- lr (float, optional): learning rate. (default: 1e-3)
- betas (Tuple[float, float], optional): coefficients used for computing
- running averages of gradient and its norm. (default: (0.9, 0.999))
- eps (float, optional): term added to the denominator to improve
- numerical stability. (default: 1e-8)
- weight_decay (float, optional): weight decay (L2 penalty) (default: 0)
- amsgrad (boolean, optional): whether to use the AMSGrad variant of this
- algorithm from the paper `On the Convergence of Adam and Beyond`_.
- NOT SUPPORTED now! (default: False)
- adam_w_mode (boolean, optional): whether to use decoupled weight decay
- (also known as AdamW) instead of L2 regularization (default: True)
- grad_averaging (bool, optional): whether to apply (1-beta1) to the gradient when
- computing the running average of the gradient (default: True)
- set_grad_none (bool, optional): whether to set grads to None when the zero_grad()
- method is called (default: True)
- max_grad_norm (float, optional): value used to clip global grad norm
- (default: 1.0)
- use_nvlamb (boolean, optional): whether to apply the adaptive learning rate to
- parameters with 0.0 weight decay (default: False)
- step_supports_amp_scaling (boolean, optional): whether to use customized
- gradient unscaling logic (default: True)
- .. _Large Batch Optimization for Deep Learning - Training BERT in 76 minutes:
- https://arxiv.org/abs/1904.00962
- .. _On the Convergence of Adam and Beyond:
- https://openreview.net/forum?id=ryQu7f-RZ
- """
- class AtomicCounter(object):
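- # Thread-safe counter that records the order in which parameter gradients become ready
- # during the first backward pass; this ordering drives the overlapped reduction schedule
- # on subsequent steps.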
- def __init__(self):
- self.value = 0
- self.order = []
- import threading
- self._lock = threading.Lock()
- def add(self, idx):
- with self._lock:
- self.value += 1
- self.order.append(idx)
- def __init__(self, params,
- lr=1e-3, bias_correction = True, grad_averaging=True,
- betas=(0.9, 0.999), eps=1e-8,
- weight_decay=0., max_grad_norm=0.,
- adam_w_mode=True, use_nvlamb=False,
- step_supports_amp_scaling=True, overlap_reductions=True,
- dwu_group_size=0, dwu_num_blocks=4, dwu_num_chunks=4,
- dwu_num_rs_pg=1, dwu_num_ar_pg=4, dwu_num_ag_pg=0, fused_norm=False,
- e5m2_allgather=False, verbose=False, clip_after_ar=True,
- full_ar=False, set_param_views_to_flat_buffer=False, skip_allgather=False,
- fuse_scale=False, param_order=None, nccl_allgather_channels=0):
- defaults = dict(lr=lr, bias_correction=bias_correction,
- betas=betas, eps=eps, weight_decay=weight_decay,
- grad_averaging=grad_averaging,
- max_grad_norm=max_grad_norm)
- super(DistributedFusedLAMB, self).__init__(params, defaults)
- global fused_adam_cuda, distributed_lamb_cuda
- fused_adam_cuda = importlib.import_module("fused_adam_cuda")
- distributed_lamb_cuda = importlib.import_module("distributed_lamb_cuda")
- self._overflow_buf = torch.cuda.IntTensor([0])
- self._has_overflow = False
- self.multi_tensor_lamb_compute_update_term = distributed_lamb_cuda.multi_tensor_lamb_compute_update_term
- self.multi_tensor_lamb_update_weights = distributed_lamb_cuda.multi_tensor_lamb_update_weights
- import amp_C
- self.multi_tensor_l2norm = amp_C.multi_tensor_l2norm
- self._grad_averaging = grad_averaging
- self._adam_w_mode = 1 if adam_w_mode else 0
- self._use_nvlamb = use_nvlamb
- self._step_supports_amp_scaling = step_supports_amp_scaling
- self._is_accumulation_step = False
- self._last_step = False
- self._overlap_reductions = overlap_reductions
- self._global_scale = None
- self._num_blocks = dwu_num_blocks
- self._num_chunks = dwu_num_chunks
- self._e5m2_allgather = e5m2_allgather
- self._verbose = verbose
- self._clip_after_ar = clip_after_ar
- self._full_ar = full_ar
- self._fuse_scale = fuse_scale
- self._L2_grad_norm = None
- self._set_flat_param_view = set_param_views_to_flat_buffer
- self._skip_ag = skip_allgather
- self._fused_norm = fused_norm if not clip_after_ar else False
- self._current_process_group = c10d._get_default_group()
- self._available_ranks = get_process_group_ranks(self._current_process_group)
- self._group_size = torch.cuda.device_count() if dwu_group_size <= 0 else dwu_group_size
- self._world_size = torch.distributed.get_world_size()
- self._num_groups = self._world_size // self._group_size
- self._rank_in_group = torch.distributed.get_rank() % self._group_size
- self._lr = torch.tensor(0.0, dtype=torch.float32, device='cuda')
- self._resume_from_checkpoint = False
- self._step = torch.cuda.IntTensor([0])
- # Master weight, moment, gradient buffers
- self._fp32_p, self._fp32_m, self._fp32_v, self._fp16_p, self._fp16_g = None, None, None, None, None
- # Check if collectives have no_copy option
- self._reduce_scatter_no_copy = (
- 'no_copy' in inspect.getfullargspec(torch.distributed.reduce_scatter).args
- )
- self._all_gather_no_copy = (
- 'no_copy' in inspect.getfullargspec(torch.distributed.all_gather).args
- )
- if "reduce_scatter_tensor" not in dir(torch.distributed):
- torch.distributed.reduce_scatter_tensor = torch.distributed._reduce_scatter_base
- if "all_gather_into_tensor" not in dir(torch.distributed):
- torch.distributed.all_gather_into_tensor = torch.distributed._all_gather_base
- self._num_rs_pg = dwu_num_rs_pg
- self._num_ar_pg = dwu_num_ar_pg
- self._num_ag_pg = dwu_num_ag_pg
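- # Each collective type may use several process groups, each paired with its own CUDA
- # stream, so reductions of different chunks run on separate NCCL communicators and can overlap.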
- if self._full_ar: # full all reduce, only need AR and AG groups
- # l2_grad_norm may be reduced within a node to limit device memory reads
- for group_i in range(self._num_groups):
- ranks = [group_i*self._group_size+j for j in range(self._group_size)]
- l2_grad_norm_pg = torch.distributed.new_group(ranks=ranks)
- if torch.distributed.get_rank() in ranks:
- self._l2_grad_norm_pg = l2_grad_norm_pg
- self._ar_pg = []
- # consider all the ranks
- ranks = list(range(0, self._world_size))
- for i in range(self._num_ar_pg):
- if self._verbose:
- print(f"creating new AR group {i}: {ranks}")
- grp = torch.distributed.new_group(ranks=ranks)
- if grp != torch.distributed.GroupMember.NON_GROUP_MEMBER:
- if self._verbose:
- print(f"group {i}: init barrier (device: {torch.cuda.current_device()})")
- torch.distributed.barrier(group=grp, device_ids=[torch.cuda.current_device()])
- if self._verbose:
- print(f"created new AR group {i}: {ranks}")
- if torch.distributed.get_rank() in ranks:
- self._ar_pg.append(grp)
- self._ar_st = [torch.cuda.Stream() for _ in range(self._num_ar_pg)]
- if nccl_allgather_channels > 0:
- os.putenv('NCCL_MAX_NCHANNELS', str(nccl_allgather_channels))
- if self._num_ag_pg == 0:
- self._ag_pg = self._ar_pg
- self._ag_st = self._ar_st
- self._num_ag_pg = self._num_ar_pg
- else:
- self._ag_pg = []
- ranks = []
- stride = torch.cuda.device_count()
- for i in range(self._num_groups):
- rs = list(range(i*stride, (i+1)*stride))
- ranks.append(rs)
- for rs in ranks:
- for i in range(self._num_ag_pg):
- grp = torch.distributed.new_group(ranks=rs)
- if torch.distributed.get_rank() in rs:
- if self._verbose:
- print(f"creating AG group {i}: {rs}")
- self._ag_pg.append(grp)
- self._ag_st = [torch.cuda.Stream() for _ in range(self._num_ag_pg)]
- else: # reduce-scatter + all-reduce, need RS, AR, AG groups
- if self._num_groups > 1:
- self._ar_pg = []
- for dev_i in range(self._group_size):
- ranks = [dev_i+j*self._group_size for j in range(self._num_groups)]
- for i in range(self._num_ar_pg):
- if self._verbose:
- print(f"creating new AR group {i}: {ranks}")
- grp = torch.distributed.new_group(ranks=ranks)
- if grp != torch.distributed.GroupMember.NON_GROUP_MEMBER:
- if self._verbose:
- print(f"group {i}: init barrier (device: {torch.cuda.current_device()})")
- torch.distributed.barrier(group=grp, device_ids=[torch.cuda.current_device()])
- if self._verbose:
- print(f"created new AR group {i}: {ranks}")
- if torch.distributed.get_rank() in ranks:
- self._ar_pg.append(grp)
- self._ar_st = [torch.cuda.Stream() for _ in range(self._num_ar_pg)]
- rs_ranks = []
- for group_i in range(self._num_groups):
- rs_ranks.append([group_i*self._group_size+j for j in range(self._group_size)])
- self._rs_pg = []
- for group_i in range(self._num_groups):
- ranks = rs_ranks[group_i]
- for i in range(self._num_rs_pg):
- grp = torch.distributed.new_group(ranks=ranks)
- if torch.distributed.get_rank() in ranks:
- self._rs_pg.append(grp)
- if self._verbose:
- print(f"creating RS group : {ranks}")
- l2_grad_norm_pg = torch.distributed.new_group(ranks=ranks)
- if torch.distributed.get_rank() in ranks:
- self._l2_grad_norm_pg = l2_grad_norm_pg
- self._rs_st = [torch.cuda.Stream() for _ in range(self._num_rs_pg)]
- if self._num_ag_pg == 0:
- self._ag_pg = self._rs_pg
- self._ag_st = self._rs_st
- self._num_ag_pg = self._num_rs_pg
- else:
- self._ag_pg = []
- for group_i in range(self._num_groups):
- ranks = rs_ranks[group_i]
- for i in range(self._num_ag_pg):
- grp = torch.distributed.new_group(ranks=ranks)
- if torch.distributed.get_rank() in ranks:
- self._ag_pg.append(grp)
- if self._verbose:
- print(f"creating AG group : {ranks}")
- self._ag_st = [torch.cuda.Stream() for _ in range(self._num_ag_pg)]
- for ag_pg in self._ag_pg:
- torch.distributed.barrier(group=ag_pg)
- self._l2_grad_norm_st = torch.cuda.Stream()
- self._completion_st = torch.cuda.Stream()
- self._step.record_stream(self._completion_st)
- self._reductions_works = [None]*self._num_blocks
- self._allgather_works = [None]*self._num_blocks
- self._one = torch.cuda.IntTensor([1])
- self._first_step = True
- self._lazy_init_stage1_done, self._lazy_init_stage2_done = False, False
- self._param_order = self.AtomicCounter()
- p_offset = 0
- p_i = 0
- self._model_params = []
- self._grad_accs = []
- self._group_properties = []
- for group in self.param_groups:
- prev = None
- beta1, beta2 = group['betas']
- beta3 = 1.0 - beta1 if self._grad_averaging else 1.0
- bias_correction = 1 if group['bias_correction'] else 0
- eps = group['eps']
- weight_decay = group['weight_decay']
- for p in group['params']:
- if not p.requires_grad:
- continue
- self._model_params.append(p)
- self._group_properties.append((
- weight_decay,
- bias_correction,
- beta1,
- beta2,
- beta3,
- eps
- ))
- p_grads_size = p.numel()
- if self._set_flat_param_view:
- if param_order:
- # this is executed when param_order is specified by the user
- self._param_order.add(param_order[p])
- else:
- self._param_order.add(p_i)
- p_offset += p_grads_size
- # Only enforce 128-byte alignment (64 fp16 elements) for non-consecutive parameters
- # RNN is one example of consecutive parameters:
- # (weight_ih, weight_hh, bias_ih, bias_hh)
- if prev is not None and (prev.data_ptr() + prev.numel() * prev.element_size() != p.data_ptr()):
- p_offset = ((p_offset + 63) // 64) * 64
- prev = p
- p_i += 1
- if param_order:
- self._param_order.order = torch.argsort(torch.tensor(self._param_order.order)).tolist()
- self._grads_generated = [False]*len(self._model_params)
- self._grads_fp16, self._grads_fp32 = [], []
- if self._overlap_reductions:
- self._current_block = self._num_blocks
- self._net_total_param_size = p_offset
- self._total_param_size = p_offset
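- # Pad the flat buffer so it divides evenly into blocks, chunks and per-rank shards,
- # with each shard a multiple of 256 elements.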
- dwu_min_page_size = 256 * self._num_blocks * self._num_chunks * self._group_size
- self._total_param_size = ((self._total_param_size + dwu_min_page_size - 1) // dwu_min_page_size) * dwu_min_page_size
- self._new_params = torch.zeros([self._total_param_size], dtype=torch.uint8 if self._e5m2_allgather else torch.float16, device='cuda')
- def _lazy_init_stage1(self):
- if self._lazy_init_stage1_done: return
- p_i = 0
- #self._model_params = []
- #self._grad_accs = []
- #self._group_properties = []
- for group in self.param_groups:
- for p in group['params']:
- torch.distributed.broadcast(p, 0)
- if not p.requires_grad:
- continue
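- # Register a hook on each parameter's gradient accumulator: on the first step it records
- # the order in which gradients arrive (unless a fixed flat-param order was supplied);
- # on later steps each ready gradient feeds the overlapped reduction pipeline.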
- def wrapper(param, param_i):
- param_tmp = param.expand_as(param)
- grad_acc = param_tmp.grad_fn.next_functions[0][0]
- def allreduce_hook(*unused):
- if not self._set_flat_param_view:
- if self._first_step:
- # first time
- self._param_order.add(param_i)
- else:
- idx = self._param_order.order.index(param_i)
- self._do_overlapped_reduction(idx, param)
- else:
- if not self._first_step:
- idx = self._param_order.order.index(param_i)
- self._do_overlapped_reduction(idx, param)
- grad_acc.register_hook(allreduce_hook)
- self._grad_accs.append(grad_acc)
- wrapper(p, p_i)
- p_i += 1
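- # Buffer geometry: the flat gradient buffer is split into blocks, each block into chunks,
- # and each chunk into one shard per rank in the group; each rank's shards are packed
- # contiguously into a "mega shard" of num_blocks * num_chunks * shard_size elements.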
- self._block_size = self._total_param_size // self._num_blocks
- self._chunk_size = self._block_size // self._num_chunks
- self._shard_size = self._chunk_size // self._group_size
- self._flat_grads = torch.zeros([self._total_param_size], dtype=torch.float16, device='cuda')
- self._mega_shard_size = self._num_blocks * self._num_chunks * self._shard_size
- # initialize master weights, moments buffers if not loaded from checkpoint
- if self._fp32_p is None:
- self._fp32_p = torch.zeros([self._mega_shard_size], dtype=torch.float32, device='cuda')
- self._fp32_m = torch.zeros([self._mega_shard_size], dtype=torch.float32, device='cuda')
- self._fp32_v = torch.zeros([self._mega_shard_size], dtype=torch.float32, device='cuda')
- self._fp32_u = torch.zeros([self._mega_shard_size], dtype=torch.float32, device='cuda')
- # FIXME: Rethink fp16 label since it's either uint8 or fp16
- self._fp16_p = torch.zeros([self._mega_shard_size], dtype=torch.uint8 if self._e5m2_allgather else torch.float16, device='cuda')
- self._fp16_g = torch.zeros([self._mega_shard_size], dtype=torch.float16, device='cuda')
- def _flat_split(p):
- def __blockify(p):
- return [p[block_id*self._block_size:(block_id+1)*self._block_size] for block_id in range(self._num_blocks)]
- def __chunkify(p):
- return [p[chunk_id*self._chunk_size:(chunk_id+1)*self._chunk_size] for chunk_id in range(self._num_chunks)]
- def __shardify(p):
- return [p[shard_id*self._shard_size:(shard_id+1)*self._shard_size] for shard_id in range(self._group_size)]
- list_of_blocks = __blockify(p)
- list_of_list_of_chunks = [__chunkify(block) for block in list_of_blocks]
- list_of_list_of_list_of_shards = [[__shardify(chunk) for chunk in chunks] for chunks in list_of_list_of_chunks]
- return list_of_blocks, list_of_list_of_chunks, list_of_list_of_list_of_shards
- # note(crcrpar): the function below doesn't seem to be used at all.
- # def _flat_split_no_shards(p):
- # def __blockify(p):
- # return [p[block_id*self._block_size:(block_id+1)*self._block_size] for block_id in range(self._num_blocks)]
- # def __chunkify(p):
- # return [p[chunk_id*self._chunk_size:(chunk_id+1)*self._chunk_size] for chunk_id in range(self._num_chunks)]
- # list_of_blocks = __blockify(self._flat_grads)
- # list_of_list_of_chunks = [__chunkify(block) for block in list_of_blocks]
- # return list_of_blocks, list_of_list_of_chunks
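- # _full_packed_split views a full-size flat buffer as [rank mega-shard][block][chunk],
- # i.e. the layout obtained after all-gathering each rank's packed shard.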
- def _full_packed_split(p):
- def __shardify(p):
- return [p[mega_shard*self._mega_shard_size:(mega_shard+1)*self._mega_shard_size] for mega_shard in range(self._group_size)]
- def __blockify(p):
- return [p[block_id*self._num_chunks*self._shard_size:(block_id+1)*self._num_chunks*self._shard_size] for block_id in range(self._num_blocks)]
- def __chunkify(p):
- return [p[chunk_id*self._shard_size:(chunk_id+1)*self._shard_size] for chunk_id in range(self._num_chunks)]
- list_of_mega_shards = __shardify(p)
- list_of_list_of_mega_blocks = [__blockify(mega_shard) for mega_shard in list_of_mega_shards]
- list_of_list_of_list_of_mega_chunks = [[__chunkify(mega_block) for mega_block in mega_blocks] for mega_blocks in list_of_list_of_mega_blocks]
- return list_of_mega_shards, list_of_list_of_mega_blocks, list_of_list_of_list_of_mega_chunks
- def _packed_split(p):
- def __packed_blockify(p):
- packed_block_size = self._num_chunks*self._shard_size
- return [p[block_id*packed_block_size:(block_id+1)*packed_block_size] for block_id in range(self._num_blocks)]
- def __packed_chunkify(p):
- # in the packed format, each chunk contains one shard, so packed_chunk_size == self._shard_size
- return [p[chunk_id*self._shard_size:(chunk_id+1)*self._shard_size] for chunk_id in range(self._num_chunks)]
- list_of_blocks = __packed_blockify(p)
- list_of_list_of_chunks = [__packed_chunkify(block) for block in list_of_blocks]
- return list_of_blocks, list_of_list_of_chunks
- def _split_assign(shards):
- packed_block_size = self._num_chunks*self._shard_size
- list_of_list_of_chunks=[]
- for block_id in range(self._num_blocks):
- list_of_chunks=[]
- for chunk_id in range(self._num_chunks):
- #self._fp16_g[block_id*packed_block_size+chunk_id*self._shard_size:block_id*packed_block_size+(chunk_id+1)*self._shard_size] = shards[block_id][chunk_id][self._rank_in_group]
- list_of_chunks.append( shards[block_id][chunk_id][self._rank_in_group])
- list_of_list_of_chunks.append(list_of_chunks)
- return list_of_list_of_chunks
- self._new_params_mega_shards, self._new_params_mega_blocks, self._new_params_mega_chunks = _full_packed_split(self._new_params)
- # this splitting scheme is needed when allgather needs to be split into multiple chunks in a contiguous way
- self._new_params2_blocks, self._new_params2_chunks, self._new_params2_shards = _flat_split(self._new_params)
- self._fp32_p_blocks, self._fp32_p_chunks = _packed_split(self._fp32_p)
- self._fp32_m_blocks, self._fp32_m_chunks = _packed_split(self._fp32_m)
- self._fp32_v_blocks, self._fp32_v_chunks = _packed_split(self._fp32_v)
- self._fp32_u_blocks, self._fp32_u_chunks = _packed_split(self._fp32_u)
- self._fp16_p_blocks, self._fp16_p_chunks = _packed_split(self._fp16_p)
- if self._full_ar:
- # for gradient all-reduce
- self._flat_grads_blocks, self._flat_grads_chunks, self._flat_grads_shards = _flat_split(self._flat_grads)
- # for weight update
- self._fp16_g_chunks = _split_assign(self._flat_grads_shards)
- else:
- self._flat_grads_blocks, self._flat_grads_chunks, self._flat_grads_shards = _flat_split(self._flat_grads)
- self._fp16_g_blocks, self._fp16_g_chunks = _packed_split(self._fp16_g)
- self._lazy_init_stage1_done = True
- def _lazy_init_stage2(self):
- if self._lazy_init_stage2_done: return
- if not self._set_flat_param_view:
- # reversing is needed for overlapping allreduce and backprop, but currently not supported for flat param view
- self._param_order.order.reverse()
- # re-order model_params, grad_accs, group_properties lists
- self._model_params = [self._model_params[i] for i in self._param_order.order]
- self._grad_accs = [self._grad_accs[i] for i in self._param_order.order]
- self._group_properties = [self._group_properties[i] for i in self._param_order.order]
- def _get_flat_view(param):
- if param.is_contiguous(memory_format=torch.channels_last):
- K, C, H, W = param.shape
- pv = param.as_strided(size=(K,H,W,C), stride=(H*W*C, W*C, C, 1))
- elif param.is_contiguous(memory_format=torch.channels_last_3d):
- K, C, D, H, W = param.shape
- pv = param.as_strided(size=(K,D,H,W,C), stride=(D*H*W*C, H*W*C, W*C, C, 1))
- else:
- pv = param
- return pv.view(-1)
- # re-collect grads info (size, offset) after ordering
- prev = None
- p_offset = 0
- self._grads_info = []
- self._individual_flat_grads = []
- for i, p in enumerate(self._model_params):
- p_grads_size = p.numel()
- self._grads_info.append({"param_grads_size":p_grads_size, "param_offset":p_offset})
- self._individual_flat_grads.append(self._flat_grads[p_offset:p_offset+p_grads_size].view_as(p))
- # for the first iteration
- self._do_overlapped_reduction(i, p)
- p_offset += p_grads_size
- # Only enforce 128-byte alignment (64 fp16 elements) for non-consecutive parameters
- # RNN is one example of consecutive parameters:
- # (weight_ih, weight_hh, bias_ih, bias_hh)
- if prev is not None and (prev.data_ptr() + prev.numel() * prev.element_size() != p.data_ptr()):
- p_offset = ((p_offset + 63) // 64) * 64
- prev = p
- self._low_param_i = [0]*self._num_blocks
- for block_id in range(self._num_blocks-1,-1,-1):
- p_i = len(self._grads_info)-1
- while p_i > 0 and self._grads_info[p_i]["param_offset"] > block_id*self._block_size:
- p_i -= 1
- self._low_param_i[block_id] = p_i
- #print("self._low_param_i", self._low_param_i)
- # This loop nest does two things:
- # 1) Copy model parameters into master buffer
- # 2) Create tensor lists for unpacking new parameter tensor after all-gather
- self._packed_flat_to_model_params_fp16 = []
- self._packed_flat_to_model_params_fp32 = []
- self._model_params_num = len(self._model_params)
- self._contrib_tensor_list = []
- self._contrib_min_param_i, self._contrib_max_param_i = -1, -1
- self._contrib_update_frag_for_norm = []
- self._contrib_model_param_for_norm_fp16 = []
- self._contrib_model_param_for_norm_fp32 = []
- self._contrib_model_param_for_norm_is_fp16 = []
- self._model_param_is_contrib = []
- self._contrib_group_properties = []
- for shard_id in range(self._group_size):
- for block_id in range(self._num_blocks):
- for chunk_id in range(self._num_chunks):
- flat_shard_start = (((block_id * self._num_chunks + chunk_id) * self._group_size) + shard_id) * self._shard_size
- flat_shard_end = flat_shard_start + self._shard_size
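- # Intersect each parameter's [start, end) range in the flat gradient buffer with this
- # shard's range; every non-empty overlap becomes a fragment mapping between model
- # parameters and the packed per-rank buffers.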
- for param_i, (p, grads_info, group_props) in enumerate(zip(self._model_params, self._grads_info, self._group_properties)):
- flat_grad_start = grads_info["param_offset"]
- flat_grad_end = flat_grad_start + grads_info["param_grads_size"]
- clipped_start = max(flat_grad_start, flat_shard_start)
- clipped_end = min(flat_grad_end, flat_shard_end)
- if clipped_start < clipped_end:
- grad_offset = clipped_start - flat_grad_start
- grad_length = clipped_end - clipped_start
- shard_offset = clipped_start - flat_shard_start
- pf = _get_flat_view(p)
- model_param_fragment = pf[grad_offset:grad_offset+grad_length]
- new_param_packed_fragment = self._new_params_mega_chunks[shard_id][block_id][chunk_id][shard_offset:shard_offset+grad_length]
- if model_param_fragment.dtype == torch.float16:
- self._packed_flat_to_model_params_fp16.append( (new_param_packed_fragment, model_param_fragment) )
- else:
- self._packed_flat_to_model_params_fp32.append( (new_param_packed_fragment, model_param_fragment) )
- if shard_id == self._rank_in_group:
- self._model_param_is_contrib.append(param_i)
- # copy model parameters into master buffer
- master_param_fragment = self._fp32_p_chunks[block_id][chunk_id][shard_offset:shard_offset+grad_length]
- opti_state_m_fragment = self._fp32_m_chunks[block_id][chunk_id][shard_offset:shard_offset+grad_length]
- opti_state_v_fragment = self._fp32_v_chunks[block_id][chunk_id][shard_offset:shard_offset+grad_length]
- opti_state_u_fragment = self._fp32_u_chunks[block_id][chunk_id][shard_offset:shard_offset+grad_length]
- opti_state_g_fragment = self._fp16_g_chunks[block_id][chunk_id][shard_offset:shard_offset+grad_length]
- opti_state_p_fragment = self._fp16_p_chunks[block_id][chunk_id][shard_offset:shard_offset+grad_length]
- #print("model_param_fragment.size()=%s, new_param_packed_fragment.size()=%s, master_param_fragment.size()=%s" % (str(model_param_fragment.size()), str(new_param_packed_fragment.size()), str(master_param_fragment.size())))
- if not self._resume_from_checkpoint:
- master_param_fragment.copy_(model_param_fragment)
- self._contrib_group_properties.append(group_props)
- self._contrib_tensor_list.append((master_param_fragment, opti_state_m_fragment, opti_state_v_fragment, opti_state_u_fragment, opti_state_g_fragment, opti_state_p_fragment)) # p, m, v, u, g, p_copy
- self._contrib_update_frag_for_norm.append(opti_state_u_fragment)
- if p.dtype == torch.float16:
- self._contrib_model_param_for_norm_fp16.append(p)
- else:
- self._contrib_model_param_for_norm_fp32.append(p)
- self._contrib_model_param_for_norm_is_fp16.append(True if p.dtype == torch.float16 else False)
- if self._contrib_min_param_i < 0: self._contrib_min_param_i = param_i
- self._contrib_max_param_i = param_i
- self._contrib_model_param_for_norm_num = len(self._contrib_model_param_for_norm_is_fp16)
- if len(self._contrib_model_param_for_norm_fp16) == 0: self._contrib_model_param_for_norm_fp16 = None
- if len(self._contrib_model_param_for_norm_fp32) == 0: self._contrib_model_param_for_norm_fp32 = None
- self._contrib_model_param_for_norm_is_fp32 = torch.tensor([not is_fp16 for is_fp16 in self._contrib_model_param_for_norm_is_fp16], dtype=torch.bool, device='cuda')
- self._contrib_model_param_for_norm_is_fp16 = torch.tensor([is_fp16 for is_fp16 in self._contrib_model_param_for_norm_is_fp16], dtype=torch.bool, device='cuda')
- self._offsets = torch.tensor(self._model_param_is_contrib, dtype=torch.int64, device='cuda')
- p, m, v, u, g, p_copy = list(zip(*self._contrib_tensor_list))
- self._contrib_compute_update_term_tensor_list = [g, p, m, v, u]
- self._contrib_update_weights_tensor_list = [u, p, p_copy]
- math_type = self._fp32_u.dtype
- decay, bias_correction, beta1, beta2, beta3, epsilon = list(zip(*self._contrib_group_properties))
- self._contrib_beta1 = torch.tensor(beta1, dtype=math_type, device='cuda')
- self._contrib_beta2 = torch.tensor(beta2, dtype=math_type, device='cuda')
- self._contrib_beta3 = torch.tensor(beta3, dtype=math_type, device='cuda')
- self._contrib_bias_correction = torch.tensor(bias_correction, dtype=torch.int, device='cuda')
- self._contrib_epsilon = torch.tensor(epsilon, dtype=math_type, device='cuda')
- self._contrib_weight_decay = torch.tensor(decay, dtype=math_type, device='cuda')
- self._packed_flat_to_model_params_fp16 = list(zip(*self._packed_flat_to_model_params_fp16)) if len(self._packed_flat_to_model_params_fp16) > 0 else None
- self._packed_flat_to_model_params_fp32 = list(zip(*self._packed_flat_to_model_params_fp32)) if len(self._packed_flat_to_model_params_fp32) > 0 else None
- self._lazy_init_stage2_done = True
- self.complete_reductions()
- self._first_step = False
- def set_is_accumulation_step(self, is_accumulation_step):
- self._is_accumulation_step = is_accumulation_step
- def set_last_step(self, last_step):
- self._last_step = last_step
- def _get_flush_block(self):
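- # Return the [start, end) range of the next block whose gradients are all available,
- # walking blocks from the end of the flat buffer; return an empty list if none is ready.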
- flush_block = []
- if self._current_block > 0 and self._grads_generated[self._low_param_i[self._current_block-1]]:
- num_grads = len(self._grads_generated)
- contiguous_idx = num_grads
- while contiguous_idx > 0 and self._grads_generated[contiguous_idx-1]:
- contiguous_idx -= 1
- if contiguous_idx < num_grads and self._grads_info[contiguous_idx]["param_offset"] <= (self._current_block-1)*self._block_size:
- self._current_block -= 1
- start = self._current_block * self._block_size
- end = (self._current_block+1) * self._block_size
- flush_block = [start, end]
- return flush_block
- def _full_all_reduce_scale(self, block_id, scale):
- works = [None]*self._num_chunks
- if self._clip_after_ar:
- for chunk_id in range(self._num_chunks):
- glob_chunk_id = block_id * self._num_chunks + chunk_id
- ar_stream = self._ar_st[glob_chunk_id%self._num_ar_pg]
- ar_stream.wait_stream(torch.cuda.current_stream())
- with torch.cuda.stream(ar_stream):
- works[chunk_id] = torch.distributed.all_reduce(self._flat_grads_chunks[block_id][chunk_id],group=self._ar_pg[glob_chunk_id%self._num_ar_pg],async_op=True,op=_make_nccl_premul_sum(scale))
- else:
- glob_chunk_id = block_id
- ar_stream = self._ar_st[glob_chunk_id%self._num_ar_pg]
- ar_stream.wait_stream(torch.cuda.current_stream())
- with torch.cuda.stream(ar_stream):
- works0 = torch.distributed.all_reduce(self._flat_grads_blocks[block_id],group=self._ar_pg[glob_chunk_id%self._num_ar_pg],async_op=True,op=_make_nccl_premul_sum(scale))
- for i in range(self._num_chunks):
- works[i]=works0
- self._reductions_works[block_id] = works
- def _full_all_reduce(self, block_id):
- works = [None]*self._num_chunks
- for chunk_id in range(self._num_chunks):
- glob_chunk_id = block_id * self._num_chunks + chunk_id
- ar_stream = self._ar_st[glob_chunk_id%self._num_ar_pg]
- ar_stream.wait_stream(torch.cuda.current_stream())
- with torch.cuda.stream(ar_stream):
- works[chunk_id] = torch.distributed.all_reduce(self._flat_grads_chunks[block_id][chunk_id],group=self._ar_pg[glob_chunk_id%self._num_ar_pg],async_op=True)
- self._reductions_works[block_id] = works
- def _reduce_scatter_and_all_reduce_scale(self, block_id, scale):
- # Reduction within each node
- # Changes gradient format from [block * chunk * shard] to [shard * block * chunk]
- # The output format is the same as the fp32 master parameters
- works = [None]*self._num_chunks
- for chunk_id in range(self._num_chunks):
- glob_chunk_id = block_id * self._num_chunks + chunk_id
- rs_stream = self._rs_st[glob_chunk_id%self._num_rs_pg]
- rs_stream.wait_stream(torch.cuda.current_stream())
- rs_stream.wait_stream(self._l2_grad_norm_st)
- with torch.cuda.stream(rs_stream):
- if self._reduce_scatter_no_copy:
- works[chunk_id] = torch.distributed.reduce_scatter(
- output=self._fp16_g_chunks[block_id][chunk_id],
- input_list=self._flat_grads_shards[block_id][chunk_id],
- group=self._rs_pg[glob_chunk_id%self._num_rs_pg],
- async_op=True,
- no_copy=True,
- op=_make_nccl_premul_sum(scale),
- )
- else:
- works[chunk_id] = torch.distributed.reduce_scatter_tensor(
- output=self._fp16_g_chunks[block_id][chunk_id],
- input=self._flat_grads_chunks[block_id][chunk_id],
- group=self._rs_pg[glob_chunk_id%self._num_rs_pg],
- async_op=True,
- op=_make_nccl_premul_sum(scale),
- )
- # Reduction across nodes for each rank
- if self._num_groups > 1:
- for chunk_id in range(self._num_chunks):
- glob_chunk_id = block_id * self._num_chunks + chunk_id
- ar_stream = self._ar_st[glob_chunk_id%self._num_ar_pg]
- with torch.cuda.stream(ar_stream):
- works[chunk_id].wait()
- works[chunk_id] = torch.distributed.all_reduce(self._fp16_g_chunks[block_id][chunk_id],group=self._ar_pg[glob_chunk_id%self._num_ar_pg],async_op=True)
- self._reductions_works[block_id] = works
- def _reduce_scatter_and_all_reduce(self, block_id):
- # Reduction within each node
- # Changes gradient format from [block * chunk * shard] to [shard * block * chunk]
- # The output format is the same as the fp32 master parameters
- works = [None]*self._num_chunks
- for chunk_id in range(self._num_chunks):
- glob_chunk_id = block_id * self._num_chunks + chunk_id
- rs_stream = self._rs_st[glob_chunk_id%self._num_rs_pg]
- rs_stream.wait_stream(torch.cuda.current_stream())
- with torch.cuda.stream(rs_stream):
- if self._reduce_scatter_no_copy:
- works[chunk_id] = torch.distributed.reduce_scatter(
- output=self._fp16_g_chunks[block_id][chunk_id],
- input_list=self._flat_grads_shards[block_id][chunk_id],
- group=self._rs_pg[glob_chunk_id%self._num_rs_pg],
- async_op=True,
- no_copy=True,
- )
- else:
- works[chunk_id] = torch.distributed.reduce_scatter_tensor(
- output = self._fp16_g_chunks[block_id][chunk_id],
- input = self._flat_grads_chunks[block_id][chunk_id],
- group = self._rs_pg[glob_chunk_id%self._num_rs_pg],
- async_op = True,
- )
- # Reduction across nodes for each rank
- if self._num_groups > 1:
- for chunk_id in range(self._num_chunks):
- glob_chunk_id = block_id * self._num_chunks + chunk_id
- ar_stream = self._ar_st[glob_chunk_id%self._num_ar_pg]
- with torch.cuda.stream(ar_stream):
- works[chunk_id].wait()
- works[chunk_id] = torch.distributed.all_reduce(self._fp16_g_chunks[block_id][chunk_id],group=self._ar_pg[glob_chunk_id%self._num_ar_pg],async_op=True)
- self._reductions_works[block_id] = works
- def _pipeline_block_reductions(self, block_id):
- if self._clip_after_ar:
- self._flatten_grad_mt(1.0/self._world_size)
- if self._full_ar:
- self._full_all_reduce(block_id)
- else:
- self._reduce_scatter_and_all_reduce(block_id)
- # Compute L2 grad norm
- if block_id == 0:
- with torch.cuda.stream(self._l2_grad_norm_st):
- for block_id in range(self._num_blocks):
- for chunk_id in range(self._num_chunks):
- self._reductions_works[block_id][chunk_id].wait()
- # Since the packed format is contiguous after reductions, only one norm is needed
- l2_grad_norm_sq = torch.empty([1], device='cuda')
- if self._full_ar:
- # flatten the nested list to keep multi_tensor_apply happy: it expects a depth-1 tensor list for the L2 norm computation
- flat_list = [item for sublist in self._fp16_g_chunks for item in sublist]
- l2_grad_norm_sq = multi_tensor_applier(self.multi_tensor_l2norm, self._overflow_buf, [flat_list], False)[0]**2
- else:
- l2_grad_norm_sq = self._fp16_g.norm(dtype=torch.float32, p=2)**2
- torch.distributed.all_reduce(l2_grad_norm_sq, group=self._l2_grad_norm_pg)
- self._L2_grad_norm = l2_grad_norm_sq.sqrt()
- else:
- # Copy model grads to flat grads buffer
- self._flatten_grad_mt(1.0)
- # Compute L2 grad norm
- self._l2_grad_norm_st.wait_stream(torch.cuda.current_stream())
- with torch.cuda.stream(self._l2_grad_norm_st):
- if not self._fused_norm:
- self._L2_grad_norm = self._flat_grads.norm(dtype=torch.float16, p=2).float()
- torch.cuda.current_stream().wait_stream(self._l2_grad_norm_st)
- # Apply clipping & pre-reduction scaling on grads
- loss_scale = self.global_scale
- max_grad_norm = loss_scale*self.defaults['max_grad_norm']
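- # Clamp the clipping coefficient to at most 1; when the norm is not finite the
- # coeff+1 > coeff test fails and a coefficient of 1 is selected instead. The result is
- # divided by world_size as pre-scaling for the upcoming sum reduction.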
- coeff = max_grad_norm /(1e-6+self.L2_grad_norm)
- coeff = (coeff>1) * self._one + (coeff<=1) * coeff
- tmp = torch.cat(((self._one), (coeff)))
- index = (coeff+1>coeff).int()
- scale = tmp.index_select(0, index).half()/self._world_size
- if not self._fuse_scale:
- self._flat_grads.mul_(scale)
- if self._full_ar:
- if self._fuse_scale:
- self._full_all_reduce_scale(block_id, scale)
- else:
- self._full_all_reduce(block_id)
- else:
- if self._fuse_scale:
- self._reduce_scatter_and_all_reduce_scale(block_id, scale)
- else:
- self._reduce_scatter_and_all_reduce(block_id)
- if block_id == 0:
- for block_id in range(self._num_blocks):
- for chunk_id in range(self._num_chunks):
- self._reductions_works[block_id][chunk_id].wait()
- def __compute_contrib_param_norm(self):
- if self._contrib_model_param_for_norm_fp16 is not None and self._contrib_model_param_for_norm_fp32 is not None:
- gnorm_fp16 = multi_tensor_applier(self.multi_tensor_l2norm, self._overflow_buf, [self._contrib_model_param_for_norm_fp16], True)[1]
- gnorm_fp32 = multi_tensor_applier(self.multi_tensor_l2norm, self._overflow_buf, [self._contrib_model_param_for_norm_fp32], True)[1]
- gnorm = torch.empty(size=[self._contrib_model_param_for_norm_num], dtype=torch.float32, device='cuda') # buffer of per-parameter L2 norms
- gnorm.masked_scatter_(self._contrib_model_param_for_norm_is_fp16, gnorm_fp16)
- gnorm.masked_scatter_(self._contrib_model_param_for_norm_is_fp32, gnorm_fp32)
- elif self._contrib_model_param_for_norm_fp16 is not None:
- gnorm = multi_tensor_applier(self.multi_tensor_l2norm, self._overflow_buf, [self._contrib_model_param_for_norm_fp16], True)[1]
- elif self._contrib_model_param_for_norm_fp32 is not None:
- gnorm = multi_tensor_applier(self.multi_tensor_l2norm, self._overflow_buf, [self._contrib_model_param_for_norm_fp32], True)[1]
- return gnorm
- def __compute_contrib_update_norm(self):
- l2_norm = torch.zeros(size=[self._model_params_num], dtype=torch.float32, device='cuda')
- local_contrib_l2_norm = multi_tensor_applier(self.multi_tensor_l2norm, self._overflow_buf, [self._contrib_update_frag_for_norm], True)[1] ** 2
- l2_norm.scatter_(dim=0, index=self._offsets, src=local_contrib_l2_norm)
- torch.distributed.all_reduce(l2_norm, group=self._ag_pg[0])
- l2_norm = torch.sqrt(l2_norm)
- return l2_norm
- def _pipeline_step(self):
- global_scale = self.global_scale
- # if clip before ar, set max_grad_norm to 0
- max_grad_norm = self.defaults['max_grad_norm'] * self._clip_after_ar
- self._completion_st.wait_stream(self._l2_grad_norm_st)
- global_grad_norm = self.L2_grad_norm
- # check global_grad_norm and fill overflow_buf
- is_finite = (global_grad_norm + 1 > global_grad_norm).int()
- self._overflow_buf = self._one * (is_finite ^ self._one) # 1 when the grad norm is not finite, else 0
- if not self._clip_after_ar:
- torch.distributed.all_reduce(is_finite,
- op=torch.distributed.ReduceOp.MIN,
- group=self._current_process_group)
- torch.distributed.all_reduce(self._overflow_buf,
- op=torch.distributed.ReduceOp.MAX,
- group=self._current_process_group)
- # increment step counter if no overflow
- self._step += is_finite
- self._completion_st.wait_stream(torch.cuda.current_stream())
- self._completion_st.wait_stream(self._l2_grad_norm_st)
- # Call step kernel once per step
- # Call all-gather once per step
- with torch.cuda.stream(self._completion_st):
- for block_id in range(self._num_blocks):
- for chunk_id in range(self._num_chunks):
- self._reductions_works[block_id][chunk_id].wait()
- param_norm = self.__compute_contrib_param_norm()
- multi_tensor_applier(self.multi_tensor_lamb_compute_update_term,
- self._overflow_buf,
- self._contrib_compute_update_term_tensor_list, # g, p, m, v, u
- self._contrib_beta1,
- self._contrib_beta2,
- self._contrib_beta3,
- self._contrib_bias_correction,
- self._step,
- self._contrib_epsilon,
- self._adam_w_mode,
- self._contrib_weight_decay,
- global_scale,
- global_grad_norm,
- max_grad_norm)
- upd_norm = self.__compute_contrib_update_norm()
- multi_tensor_applier(self.multi_tensor_lamb_update_weights,
- self._overflow_buf,
- self._contrib_update_weights_tensor_list, # u, p, p_copy
- param_norm,
- upd_norm,
- self._offsets,
- self._lr,
- self._contrib_weight_decay,
- global_grad_norm,
- self._use_nvlamb)
- if not self._skip_ag:
- # allgather chunking is currently not supported for clip after allreduce
- if not self._clip_after_ar:
- for block in range(self._num_blocks):
- for chunk in range(self._num_chunks):
- if self._all_gather_no_copy:
- torch.distributed.all_gather(
- tensor_list = self._new_params2_shards[block][chunk],
- tensor = self._fp16_p_chunks[block][chunk],
- group = self._ag_pg[0],
- no_copy = True,
- )
- else:
- torch.distributed.all_gather_into_tensor(
- output_tensor = self._new_params2_blocks[block],
- input_tensor = self._fp16_p_chunks[block][chunk],
- group = self._ag_pg[0],
- )
- else:
- if self._all_gather_no_copy:
- torch.distributed.all_gather(
- tensor_list = self._new_params_mega_shards,
- tensor = self._fp16_p,
- group = self._ag_pg[0],
- no_copy = True,
- )
- else:
- torch.distributed.all_gather_into_tensor(
- output_tensor = self._new_params,
- input_tensor = self._fp16_p,
- group = self._ag_pg[0],
- )
- def _flatten_grad_mt(self, scale):
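- # Scale each produced gradient and copy it into its slot of the flat buffer with
- # multi_tensor_scale; when fused_norm is enabled, multi_tensor_l2norm_scale performs the
- # same scaled copy and also returns the gradients' L2 norm.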
- if len(self._grads_fp16) > 0:
- self._overflow_buf.zero_()
- if not self._fused_norm:
- multi_tensor_applier(
- amp_C.multi_tensor_scale,
- self._overflow_buf,
- list(zip(*self._grads_fp16)),
- scale)
- else:
- self._L2_grad_norm=multi_tensor_applier(
- amp_C.multi_tensor_l2norm_scale,
- self._overflow_buf,
- list(zip(*self._grads_fp16)),
- scale, False)[0].float()
- self._grads_fp16 = []
- if len(self._grads_fp32) > 0:
- self._overflow_buf.zero_()
- if not self._fused_norm:
- multi_tensor_applier(
- amp_C.multi_tensor_scale,
- self._overflow_buf,
- list(zip(*self._grads_fp32)),
- scale)
- else:
- self._L2_grad_norm=multi_tensor_applier(
- amp_C.multi_tensor_l2norm_scale,
- self._overflow_buf,
- list(zip(*self._grads_fp32)),
- scale, False)[0].float()
- self._grads_fp32 = []
- def _do_overlapped_reduction(self, param_i, param):
- if not self._is_accumulation_step:
- # handle overlapped reductions
- if param.dtype == torch.float16:
- self._grads_fp16.append( (param.grad, self._individual_flat_grads[param_i]) )
- else:
- self._grads_fp32.append( (param.grad, self._individual_flat_grads[param_i]) )
- self._grads_generated[param_i]=True
- if not self._first_step and not self._last_step:
- if self._overlap_reductions:
- flush_block = self._get_flush_block()
- while flush_block:
- block_id = flush_block[0] // self._block_size
- self._pipeline_block_reductions(block_id)
- flush_block = self._get_flush_block()
- def set_global_scale(self, global_scale):
- """Set global scale.
- """
- self._global_scale = global_scale
- @property
- def global_scale(self):
- return self._global_scale
- @property
- def L2_grad_norm(self):
- torch.cuda.current_stream().wait_stream(self._l2_grad_norm_st)
- return self._L2_grad_norm
- def complete_reductions(self):
- """Complete reductions if full pipeline is not selected or overlap is not allowed.
- """
- if self._last_step:
- # zero out gradients that have not been completed yet
- for param_i, grad_generated in enumerate(self._grads_generated):
- if not grad_generated:
- grad_info = self._grads_info[param_i]
- param_offset = grad_info["param_offset"]
- param_size = grad_info["param_grads_size"]
- self._flat_grads[param_offset:param_offset+param_size].zero_()
- self._grads_generated[param_i] = True
- if self._first_step or self._last_step or not self._overlap_reductions:
- # nothing done so far, run full pipeline after reductions
- for block_id in range(self._num_blocks-1,-1,-1):
- self._pipeline_block_reductions(block_id)
- torch.cuda.current_stream().wait_stream(self._l2_grad_norm_st)
- self._current_block = self._num_blocks
- self._grads_generated = [False]*len(self._grads_info)
- def step(self, closure=None, grad_scaler=None):
- loss = None
- if closure is not None:
- loss = closure()
- self._pipeline_step()
- if grad_scaler is not None:
- found_inf = self._overflow_buf.float()
- optimizer_state = grad_scaler._per_optimizer_states[id(self)]
- current_device = torch.device('cuda', torch.cuda.current_device())
- optimizer_state["found_inf_per_device"][current_device] = found_inf
- self._completion_st.wait_stream(torch.cuda.current_stream())
- if not self._set_flat_param_view:
- with torch.cuda.stream(self._completion_st):
- # Copy self._new_params to model params
- with torch.no_grad():
- if self._packed_flat_to_model_params_fp16 is not None:
- multi_tensor_applier(
- fused_adam_cuda.maybe_cast_mt,
- self._overflow_buf,
- self._packed_flat_to_model_params_fp16)
- if self._packed_flat_to_model_params_fp32 is not None:
- multi_tensor_applier(
- fused_adam_cuda.maybe_cast_mt,
- self._overflow_buf,
- self._packed_flat_to_model_params_fp32)
- torch.cuda.current_stream().wait_stream(self._completion_st)
- self._reductions_works = [None]*self._num_blocks
- self._allgather_works = [None]*self._num_blocks
- return loss
- def state_dict(self):
- """
- Returns a dict containing the current state of this :class:`DistributedFusedLAMB` instance.
- Example::
- checkpoint = {}
- checkpoint['model'] = model.state_dict()
- checkpoint['optimizer'] = optimizer.state_dict()
- torch.save(checkpoint, "saved.pth")
- """
- # save step, master weights and first/second moments
- state_dict = {}
- state_dict['step'] = self._step
- state_dict['fp32_p'] = self._fp32_p
- state_dict['fp32_m'] = self._fp32_m
- state_dict['fp32_v'] = self._fp32_v
- return state_dict
- def load_state_dict(self, state_dict):
- """
- Loads a state_dict created by an earlier call to state_dict().
- If a DistributedFusedLAMB instance was constructed from some ``model``'s
- parameters, it is expected that the user
- will call ``model.load_state_dict()`` before
- ``optimizer.load_state_dict()`` is called.
- Example::
- model = torch.nn.Linear(D_in, D_out).cuda().half()
- optimizer = DistributedFusedLAMB(model.parameters(), lr=1e-3)
- ...
- checkpoint = torch.load("saved.pth")
- model.load_state_dict(checkpoint['model'])
- optimizer.load_state_dict(checkpoint['optimizer'])
- """
- # restore step, master weights and first/second moments
- self._step = state_dict['step']
- self._fp32_p = state_dict['fp32_p'].to(device="cuda")
- self._fp32_m = state_dict['fp32_m'].to(device="cuda")
- self._fp32_v = state_dict['fp32_v'].to(device="cuda")
- self._resume_from_checkpoint = True