distributed_data_parallel.py
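# Example launch (assuming a single node with 2 GPUs): torch.distributed.launch spawns
# one process per GPU and passes --local_rank to each, e.g.
#   python -m torch.distributed.launch --nproc_per_node=2 distributed_data_parallel.py
# Running the script directly stays single-process, so args.distributed remains False.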

import torch
import argparse
import os
from apex import amp
# FOR DISTRIBUTED: (can also use torch.nn.parallel.DistributedDataParallel instead)
from apex.parallel import DistributedDataParallel

parser = argparse.ArgumentParser()
# FOR DISTRIBUTED: Parse for the local_rank argument, which will be supplied
# automatically by torch.distributed.launch.
parser.add_argument("--local_rank", default=0, type=int)
args = parser.parse_args()

# FOR DISTRIBUTED: If we are running under torch.distributed.launch,
# the 'WORLD_SIZE' environment variable will also be set automatically.
args.distributed = False
if 'WORLD_SIZE' in os.environ:
    args.distributed = int(os.environ['WORLD_SIZE']) > 1

if args.distributed:
    # FOR DISTRIBUTED: Set the device according to local_rank.
    torch.cuda.set_device(args.local_rank)

    # FOR DISTRIBUTED: Initialize the backend. torch.distributed.launch will provide
    # environment variables, and requires that you use init_method=`env://`.
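    # (With env:// initialization, the process group reads MASTER_ADDR, MASTER_PORT,
    # RANK, and WORLD_SIZE from the environment, all of which torch.distributed.launch
    # sets for each process it spawns.)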
    torch.distributed.init_process_group(backend='nccl',
                                         init_method='env://')

torch.backends.cudnn.benchmark = True

N, D_in, D_out = 64, 1024, 16

# Each process receives its own batch of "fake input data" and "fake target data."
# The "training loop" in each process just uses this fake batch over and over.
# https://github.com/NVIDIA/apex/tree/master/examples/imagenet provides a more realistic
# example of distributed data sampling for both training and validation.
x = torch.randn(N, D_in, device='cuda')
y = torch.randn(N, D_out, device='cuda')

model = torch.nn.Linear(D_in, D_out).cuda()
optimizer = torch.optim.SGD(model.parameters(), lr=1e-3)

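# amp.initialize returns a patched model and optimizer. opt_level="O1" is Amp's
# mixed-precision mode: whitelisted ops are cast to fp16, everything else stays
# in fp32 (see the Apex docs for the other opt levels).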
model, optimizer = amp.initialize(model, optimizer, opt_level="O1")

if args.distributed:
    # FOR DISTRIBUTED: After amp.initialize, wrap the model with
    # apex.parallel.DistributedDataParallel.
    model = DistributedDataParallel(model)
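    # apex's DistributedDataParallel averages gradients across processes during
    # backward(); by default (delay_allreduce=False) it overlaps the allreduces
    # with backprop computation.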
    # torch.nn.parallel.DistributedDataParallel is also fine, with some added args:
    # model = torch.nn.parallel.DistributedDataParallel(model,
    #                                                   device_ids=[args.local_rank],
    #                                                   output_device=args.local_rank)

loss_fn = torch.nn.MSELoss()

for t in range(500):
    optimizer.zero_grad()
    y_pred = model(x)
    loss = loss_fn(y_pred, y)
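    # amp.scale_loss scales the loss so fp16 gradients don't underflow; the gradients
    # are unscaled again when the context exits, before optimizer.step() applies them.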
    with amp.scale_loss(loss, optimizer) as scaled_loss:
        scaled_loss.backward()
    optimizer.step()

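# FOR DISTRIBUTED: every process computes its own loss on its own fake batch;
# only rank 0 prints so the output isn't duplicated once per process.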
if args.local_rank == 0:
    print("final loss = ", loss)