1234567891011121314151617181920212223242526272829303132333435363738394041424344454647484950515253545556575859606162636465 |
"""Minimal Apex AMP + DistributedDataParallel training example.

Trains a single ``torch.nn.Linear`` layer on a fixed random batch using Apex
mixed precision (``opt_level="O1"``). When launched directly it runs on the
default CUDA device; when launched with ``torchrun`` or
``python -m torch.distributed.launch`` with more than one process, each
process drives one GPU and the model is wrapped in Apex DDP.
"""
import argparse
import os

import torch
from apex import amp
# FOR DISTRIBUTED: (can also use torch.nn.parallel.DistributedDataParallel instead)
from apex.parallel import DistributedDataParallel


def _parse_args(argv=None):
    """Parse command-line options and derive ``args.distributed``.

    Args:
        argv: Argument list to parse; ``None`` means ``sys.argv[1:]``.

    Returns:
        An ``argparse.Namespace`` with ``local_rank`` (int) and
        ``distributed`` (bool, True when ``WORLD_SIZE`` > 1).
    """
    parser = argparse.ArgumentParser()
    # FOR DISTRIBUTED: the launcher tells each process which local GPU to use.
    # Older torch.distributed.launch passes --local_rank=N, PyTorch 2.x passes
    # --local-rank=N, and torchrun passes neither but exports LOCAL_RANK, so
    # accept both spellings and fall back to the environment variable.
    parser.add_argument(
        "--local_rank",
        "--local-rank",
        type=int,
        default=int(os.environ.get("LOCAL_RANK", 0)),
    )
    args = parser.parse_args(argv)
    # FOR DISTRIBUTED: the launcher also sets WORLD_SIZE; more than one
    # process means we are running distributed.
    args.distributed = int(os.environ.get("WORLD_SIZE", 1)) > 1
    return args


def main(argv=None):
    """Run 500 AMP-scaled SGD steps and print the final loss on local rank 0."""
    args = _parse_args(argv)

    if args.distributed:
        # FOR DISTRIBUTED: bind this process to its GPU before any CUDA work.
        torch.cuda.set_device(args.local_rank)
        # FOR DISTRIBUTED: the launcher provides the rendezvous information
        # through environment variables, which requires init_method="env://".
        torch.distributed.init_process_group(backend="nccl", init_method="env://")

    torch.backends.cudnn.benchmark = True

    N, D_in, D_out = 64, 1024, 16
    # Each process builds its own batch of fake inputs and targets and trains
    # on that same batch every iteration. For realistic distributed data
    # sampling (training and validation) see
    # https://github.com/NVIDIA/apex/tree/master/examples/imagenet
    x = torch.randn(N, D_in, device="cuda")
    y = torch.randn(N, D_out, device="cuda")

    model = torch.nn.Linear(D_in, D_out).cuda()
    optimizer = torch.optim.SGD(model.parameters(), lr=1e-3)
    model, optimizer = amp.initialize(model, optimizer, opt_level="O1")

    if args.distributed:
        # FOR DISTRIBUTED: wrap the model with Apex DDP only after
        # amp.initialize has returned the patched model.
        model = DistributedDataParallel(model)
        # torch.nn.parallel.DistributedDataParallel also works, with extra args:
        # model = torch.nn.parallel.DistributedDataParallel(
        #     model,
        #     device_ids=[args.local_rank],
        #     output_device=args.local_rank,
        # )

    loss_fn = torch.nn.MSELoss()
    for _ in range(500):
        optimizer.zero_grad()
        y_pred = model(x)
        loss = loss_fn(y_pred, y)
        # amp.scale_loss yields a scaled copy of the loss; backpropagate
        # through that copy rather than through `loss` directly.
        with amp.scale_loss(loss, optimizer) as scaled_loss:
            scaled_loss.backward()
        optimizer.step()

    # Only local rank 0 reports, so single-node runs print once.
    if args.local_rank == 0:
        print("final loss = ", loss)


if __name__ == "__main__":
    main()
|