# ddp_race_condition_test.py
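# A minimal sketch of how this script is typically launched (one process per GPU);
# the exact command is an assumption, but the expected-gradient check below is
# written for exactly 2 ranks:
#   python -m torch.distributed.launch --nproc_per_node=2 ddp_race_condition_test.py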

import torch
import torch.distributed as dist
from torch.nn import Parameter
from torch.nn import Module
from apex.parallel import DistributedDataParallel as DDP
import argparse
import os

parser = argparse.ArgumentParser(description='allreduce hook example')
parser.add_argument("--local_rank", default=0, type=int)
args = parser.parse_args()
args.distributed = False
if 'WORLD_SIZE' in os.environ:
    args.distributed = int(os.environ['WORLD_SIZE']) > 1

if args.distributed:
    args.gpu = args.local_rank % torch.cuda.device_count()
    torch.cuda.set_device(args.gpu)
    torch.distributed.init_process_group(backend='nccl',
                                         init_method='env://')
    args.world_size = torch.distributed.get_world_size()

torch.set_printoptions(precision=10)
torch.manual_seed(args.local_rank)
class Model(Module):
    def __init__(self):
        super(Model, self).__init__()
        self.a = Parameter(torch.cuda.FloatTensor(4096*4096).fill_(1.0))
        self.b = Parameter(torch.cuda.FloatTensor(4096*4096).fill_(2.0))

    def forward(self, input):
        return (input*self.a)*self.b
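# The commented-out wrappers below are alternative apex DDP configurations this test
# has been used to exercise (different bucketing thresholds, delayed allreduce,
# gradient predivide, and trigger-parameter settings); only the final, uncommented
# line is active.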
model = Model()
# model = DDP(model, message_size=1, gradient_predivide_factor=8.0)
# model = DDP(model, delay_allreduce=True)
# model = DDP(model, message_size=1, allreduce_trigger_params=[model.b])
model = DDP(model, message_size=1, allreduce_trigger_params=[model.b], num_allreduce_streams=3)
x = torch.cuda.FloatTensor(4096*4096)

passed = True
torch.cuda.cudart().cudaProfilerStart()
for i in range(10):
    x.fill_(i + args.local_rank)  # fill x with new values every iteration for sanity
    model.zero_grad()
    out = model(x)
    loss = out.sum()
    # torch.cuda.nvtx.range_push("backward")
    loss.backward()
    # torch.cuda.nvtx.range_pop()

    # torch.cuda.nvtx.range_push("synchronize() + info")
    # torch.cuda.synchronize()
    print("i = {}".format(i))

    def info(name, param, val):
        # With a = 1, b = 2 and x filled with (i + local_rank), the per-rank gradient
        # sums are 2*(i + local_rank)*4096*4096 for a and (i + local_rank)*4096*4096
        # for b. After averaging across two ranks (local_rank 0 and 1), the expected
        # sum is val*4096*4096*(2*i + 1)/2, so this check assumes exactly 2 processes.
        expected = val*4096*4096*(2.*i+1)/2.
        actual = param.grad.data.sum().item()
        print(name+": grad.data_ptr() = {}, expected sum {}, got {}".format(
            param.grad.data_ptr(), expected, actual))
        return (expected == actual)

    if not info("model.a", model.module.a, 2.): passed = False
    if not info("model.b", model.module.b, 1.): passed = False
    # torch.cuda.nvtx.range_pop()
torch.cuda.cudart().cudaProfilerStop()

print("passed = ", passed)