test_batchnorm1d.py 484 B

123456789101112131415161718
  1. import torch
  2. import apex
  3. model = apex.parallel.SyncBatchNorm(4).cuda()
  4. model.weight.data.uniform_()
  5. model.bias.data.uniform_()
  6. data = torch.rand((8,4)).cuda()
  7. model_ref = torch.nn.BatchNorm1d(4).cuda()
  8. model_ref.load_state_dict(model.state_dict())
  9. data_ref = data.clone()
  10. output = model(data)
  11. output_ref = model_ref(data_ref)
  12. assert(output.allclose(output_ref))
  13. assert(model.running_mean.allclose(model_ref.running_mean))
  14. assert(model.running_var.allclose(model_ref.running_var))