amp_master_params.py

import torch
import argparse
import os
from apex import amp
# FOR DISTRIBUTED: (can also use torch.nn.parallel.DistributedDataParallel instead)
from apex.parallel import DistributedDataParallel

parser = argparse.ArgumentParser()
# FOR DISTRIBUTED: Parse for the local_rank argument, which will be supplied
# automatically by torch.distributed.launch.
parser.add_argument("--local_rank", default=0, type=int)
args = parser.parse_args()
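# torch.distributed.launch spawns one process per GPU and passes --local_rank to each
# of them, e.g. (an illustrative two-GPU invocation, not part of the original file):
#   python -m torch.distributed.launch --nproc_per_node=2 amp_master_params.py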
# FOR DISTRIBUTED: If we are running under torch.distributed.launch,
# the 'WORLD_SIZE' environment variable will also be set automatically.
args.distributed = False
if 'WORLD_SIZE' in os.environ:
    args.distributed = int(os.environ['WORLD_SIZE']) > 1

if args.distributed:
    # FOR DISTRIBUTED: Set the device according to local_rank.
    torch.cuda.set_device(args.local_rank)

    # FOR DISTRIBUTED: Initialize the backend. torch.distributed.launch will provide
    # environment variables, and requires that you use init_method=`env://`.
    torch.distributed.init_process_group(backend='nccl',
                                         init_method='env://')

# Seed each rank differently so every process generates its own fake data below.
torch.manual_seed(torch.distributed.get_rank())

torch.backends.cudnn.benchmark = True
N, D_in, D_out = 64, 1024, 16

# Each process receives its own batch of "fake input data" and "fake target data."
# The "training loop" in each process just uses this fake batch over and over.
# https://github.com/NVIDIA/apex/tree/master/examples/imagenet provides a more realistic
# example of distributed data sampling for both training and validation.
x = torch.randn(N, D_in, device='cuda')
y = torch.randn(N, D_out, device='cuda')

model = torch.nn.Linear(D_in, D_out).cuda()
optimizer = torch.optim.SGD(model.parameters(), lr=1e-3)

model, optimizer = amp.initialize(model, optimizer, opt_level="O2")
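# With opt_level="O2", amp casts the model's weights to FP16 and keeps FP32 "master"
# copies of the parameters; the optimizer steps on those masters, and
# amp.master_params(optimizer) iterates over them (used for the checkpoint at the
# bottom of this script).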
if args.distributed:
    # FOR DISTRIBUTED: After amp.initialize, wrap the model with
    # apex.parallel.DistributedDataParallel.
    model = DistributedDataParallel(model)
    # torch.nn.parallel.DistributedDataParallel is also fine, with some added args:
    # model = torch.nn.parallel.DistributedDataParallel(model,
    #                                                   device_ids=[args.local_rank],
    #                                                   output_device=args.local_rank)

loss_fn = torch.nn.MSELoss()

for t in range(500):
    optimizer.zero_grad()
    y_pred = model(x)
    loss = loss_fn(y_pred, y)
    with amp.scale_loss(loss, optimizer) as scaled_loss:
        scaled_loss.backward()
    optimizer.step()

if args.local_rank == 0:
    print("final loss = ", loss)

# Save this rank's (FP16) model parameters and the FP32 master parameters maintained
# by amp, so the two can be compared after the run.
torch.save(list(model.parameters()), "rank{}model.pth".format(torch.distributed.get_rank()))
torch.save(list(amp.master_params(optimizer)), "rank{}master.pth".format(torch.distributed.get_rank()))
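After a run, each rank leaves behind a rank{N}model.pth / rank{N}master.pth pair. Below is a minimal sketch of how those files could be inspected; the companion script name, the choice of rank 0, and the tolerances are illustrative assumptions, not part of the original example.

# check_master_params.py (illustrative companion, not part of amp_master_params.py)
import torch

# Load the tensors rank 0 saved above; map_location avoids needing the original GPU.
model_params = torch.load("rank0model.pth", map_location="cpu")
master_params = torch.load("rank0master.pth", map_location="cpu")

for model_p, master_p in zip(model_params, master_params):
    # Under opt_level="O2" the model weights are FP16 copies of the FP32 masters,
    # so they should agree to roughly FP16 precision.
    assert model_p.dtype == torch.float16 and master_p.dtype == torch.float32
    assert torch.allclose(model_p.float(), master_p, rtol=1e-3, atol=1e-3)

print("rank 0 model params match their FP32 masters to FP16 precision")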