123456789101112131415161718 |
- import torch
- import apex
- model = apex.parallel.SyncBatchNorm(4).cuda()
- model.weight.data.uniform_()
- model.bias.data.uniform_()
- data = torch.rand((8,4)).cuda()
- model_ref = torch.nn.BatchNorm1d(4).cuda()
- model_ref.load_state_dict(model.state_dict())
- data_ref = data.clone()
- output = model(data)
- output_ref = model_ref(data_ref)
- assert(output.allclose(output_ref))
- assert(model.running_mean.allclose(model_ref.running_mean))
- assert(model.running_var.allclose(model_ref.running_var))
|