import torch
import torch.distributed as dist
from torch.nn import Parameter
from torch.nn import Module
from apex.parallel import DistributedDataParallel as DDP
import argparse
import os

parser = argparse.ArgumentParser(description='allreduce hook example')
parser.add_argument("--local_rank", default=0, type=int)
args = parser.parse_args()

args.distributed = False
if 'WORLD_SIZE' in os.environ:
    args.distributed = int(os.environ['WORLD_SIZE']) > 1

if args.distributed:
    args.gpu = args.local_rank % torch.cuda.device_count()
    torch.cuda.set_device(args.gpu)
    torch.distributed.init_process_group(backend='nccl',
                                         init_method='env://')
    args.world_size = torch.distributed.get_world_size()

torch.set_printoptions(precision=10)
torch.manual_seed(args.local_rank)


class Model(Module):
    """Trivial two-parameter model used to exercise apex DDP's allreduce hooks."""
    def __init__(self):
        super(Model, self).__init__()
        self.a = Parameter(torch.cuda.FloatTensor(4096*4096).fill_(1.0))
        self.b = Parameter(torch.cuda.FloatTensor(4096*4096).fill_(2.0))

    def forward(self, input):
        return (input*self.a)*self.b


model = Model()
# Alternative DDP configurations kept for reference:
# model = DDP(model, message_size=1, gradient_predivide_factor=8.0)
# model = DDP(model, delay_allreduce=True)
# model = DDP(model, message_size=1, allreduce_trigger_params=[model.b])
model = DDP(model,
            message_size=1,
            allreduce_trigger_params=[model.b],
            num_allreduce_streams=3)

x = torch.cuda.FloatTensor(4096*4096)

passed = True
torch.cuda.cudart().cudaProfilerStart()
for i in range(10):
    x.fill_(i + args.local_rank)  # fill x with new values every iteration for sanity
    model.zero_grad()
    out = model(x)
    loss = out.sum()
    # torch.cuda.nvtx.range_push("backward")
    loss.backward()
    # torch.cuda.nvtx.range_pop()

    # torch.cuda.nvtx.range_push("synchronize() + info")
    # torch.cuda.synchronize()
    print("i = {}".format(i))

    def info(name, param, val):
        # Compare the gradient sum against the analytically expected value
        # for this iteration.
        expected = val*4096*4096*(2.*i+1)/2.
        actual = param.grad.data.sum().item()
        print(name+": grad.data_ptr() = {}, expected sum {}, got {}".format(
              param.grad.data_ptr(), expected, actual))
        return (expected == actual)

    if not info("model.a", model.module.a, 2.):
        passed = False
    if not info("model.b", model.module.b, 1.):
        passed = False
    # torch.cuda.nvtx.range_pop()
torch.cuda.cudart().cudaProfilerStop()

print("passed = ", passed)
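
# Usage sketch: apex's parallel examples are typically launched with PyTorch's
# distributed launcher, which sets WORLD_SIZE and passes --local_rank to each
# process. Assuming a node with two GPUs and this file saved under the
# hypothetical name allreduce_hook_example.py, a launch would look like:
#
#   python -m torch.distributed.launch --nproc_per_node=2 allreduce_hook_example.py
#
# Note that the hard-coded expected sums, val*4096*4096*(2*i+1)/2, match the
# averaged gradients only for a world size of 2, so the final "passed" check
# assumes a two-process run.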