"""Compare apex.parallel.SyncBatchNorm against torch.nn.BatchNorm1d.

Runs one forward pass of each module on identical CUDA input with identical
parameters, then asserts that the outputs and the updated running statistics
agree. No process group is initialized in this script.
"""
import torch
import apex

# Module under test: 4 features, with weight and bias filled in place with
# uniform random values.
model = apex.parallel.SyncBatchNorm(4).cuda()
model.weight.data.uniform_()
model.bias.data.uniform_()

# One batch of 8 samples with 4 features each.
data = torch.rand((8, 4)).cuda()

# Reference module that receives the exact same state (parameters and
# buffers) as the module under test.
model_ref = torch.nn.BatchNorm1d(4).cuda()
model_ref.load_state_dict(model.state_dict())

# Separate copy of the input so the two forward passes do not share a tensor.
data_ref = data.clone()

# Neither module is switched to eval mode here, so the forward passes are
# also what update the running statistics compared below.
output = model(data)
output_ref = model_ref(data_ref)

assert output.allclose(output_ref), \
    "SyncBatchNorm output differs from BatchNorm1d output"
assert model.running_mean.allclose(model_ref.running_mean), \
    "SyncBatchNorm running_mean differs from BatchNorm1d running_mean"
assert model.running_var.allclose(model_ref.running_var), \
    "SyncBatchNorm running_var differs from BatchNorm1d running_var"