# test_groups.py
  1. import torch
  2. import numpy as np
  3. import apex
  4. import syncbn
  5. import os
  6. import argparse
  7. import torch.optim as optim
  8. def compare(desc, inp1, inp2, error):
  9. a = inp1.clone().detach().cpu().numpy()
  10. b = inp2.clone().detach().cpu().numpy()
  11. close = np.allclose(a,b, error, error)
  12. if not close:
  13. print(desc, close)
  14. z = a - b
  15. index = (np.abs(z) >= error + error * np.abs(b)).nonzero()
  16. print("dif : ", z[index])
  17. print("inp1 : ", a[index])
  18. print("inp2 : ", b[index])
  19. return close
  20. feature_size = 10
  21. space_size = 40
  22. batch_size = 32
  23. from apex.parallel import DistributedDataParallel as DDP
  24. parser = argparse.ArgumentParser()
  25. parser.add_argument("--local_rank", default=0, type=int)
  26. parser.add_argument("--fp16", action='store_true', default=False)
  27. parser.add_argument("--fp64", action='store_true', default=False)
  28. parser.add_argument("--group_size", default=0, type=int)
  29. args = parser.parse_args()
  30. try:
  31. args.world_size = int(os.environ['WORLD_SIZE'])
  32. except:
  33. print("This is a multi-gpu test. To run it please use 'python -m torch.distributed.launch --nproc_per_node=<num gpus> test_groups.py <more options>'")
  34. exit(1)
  35. torch.cuda.set_device(args.local_rank)
  36. torch.distributed.init_process_group(backend='nccl', init_method='env://')
  37. start = (args.local_rank%args.group_size) * batch_size//args.group_size
  38. finish = (args.local_rank%args.group_size + 1) * batch_size//args.group_size
  39. error = 1e-5
  40. dtype = np.float32
  41. if args.fp16:
  42. error = 1e-3
  43. dtype = np.float16
  44. elif args.fp64:
  45. error = 1e-8
  46. dtype = np.float64
  47. np.random.seed(18 + args.local_rank//args.group_size)
  48. inp = np.random.randn(batch_size, feature_size, space_size, space_size).astype(dtype)
  49. grad = np.random.randn(batch_size, feature_size, space_size, space_size).astype(dtype)
  50. weight = np.random.randn(feature_size).astype(dtype)
  51. bias = np.random.randn(feature_size).astype(dtype)
  52. type_tensor = torch.cuda.FloatTensor
  53. if args.fp16:
  54. type_tensor = torch.cuda.HalfTensor
  55. if args.fp64:
  56. type_tensor = torch.cuda.DoubleTensor
  57. ref_tensor = torch.cuda.DoubleTensor
  58. inp_t = type_tensor(inp)
  59. weight_t = type_tensor(weight)
  60. bias_t = type_tensor(bias)
  61. inp_r = ref_tensor(inp.transpose(1, 0, 2, 3).reshape(feature_size, -1))
  62. inp2_r = ref_tensor(inp)
  63. weight_r = ref_tensor(weight).view(-1, 1, 1)
  64. bias_r = ref_tensor(bias).view(-1, 1, 1)
  65. grad_output_t = type_tensor(grad)
  66. m = inp_r.mean(1)
  67. b_v = inp_r.var(1, unbiased=False)
  68. unb_v = inp_r.var(1, unbiased=True)
  69. eps = 1e-5
  70. mean, var_biased = syncbn.welford_mean_var(inp_t)
  71. inv_std = 1.0 / torch.sqrt(var_biased + eps)
  72. bn = torch.nn.BatchNorm2d(feature_size).cuda()
  73. bn.momentum = 1.0
  74. bn.weight.data = weight_t.clone()
  75. bn.bias.data = bias_t.clone()
  76. if args.fp16:
  77. bn.half()
  78. if args.fp64:
  79. bn.double()
  80. bn = DDP(bn)
  81. inp_bn = inp_t.clone().requires_grad_()
  82. grad_bn = grad_output_t.clone().detach()
  83. out_bn = bn(inp_bn)
  84. out_bn.backward(grad_bn)
  85. # compensating the averaging over processes done by DDP
  86. # in order to produce mathematically equivalent result
  87. # https://github.com/NVIDIA/apex/issues/134#issuecomment-458307368
  88. for param in bn.parameters():
  89. param.grad = param.grad / args.group_size
  90. bn_opt = optim.SGD(bn.parameters(), lr=1.0)
  91. sbn = apex.parallel.SyncBatchNorm(feature_size, process_group=apex.parallel.create_syncbn_process_group(args.group_size)).cuda()
  92. sbn.momentum = 1.0
  93. sbn.weight.data = weight_t.clone()
  94. sbn.bias.data = bias_t.clone()
  95. if args.fp16:
  96. sbn.half()
  97. if args.fp64:
  98. sbn.double()
  99. sbn = DDP(sbn)
  100. sbn_opt = optim.SGD(sbn.parameters(), lr=1.0)
  101. inp_sbn = inp_t.clone().requires_grad_()
  102. grad_sbn = grad_output_t.clone().detach()
  103. out_sbn = sbn(inp_sbn[start:finish])
  104. out_sbn.backward(grad_sbn[start:finish])
  105. sbn_result = True
  106. bn_result = True
  107. if args.local_rank == 0:
  108. sbn_result = compare("comparing mean: ", mean, m, error) and sbn_result
  109. sbn_result = compare("comparing biased variance: ", var_biased, b_v, error) and sbn_result
  110. out = syncbn.batchnorm_forward(inp_t, mean, inv_std, weight_t, bias_t)
  111. out_r = weight_r * (inp2_r - m.view(-1, 1, 1)) * torch.rsqrt(b_v.view(-1,1,1) + eps) + bias_r
  112. if args.local_rank == 0:
  113. sbn_result = compare("comparing output: ", out, out_r, error) and sbn_result
  114. compare("comparing bn output: ", out_bn, out_r, error)
  115. grad_output_t = type_tensor(grad)
  116. grad_output_r = ref_tensor(grad.transpose(1, 0, 2, 3).reshape(feature_size, -1))
  117. grad_output2_r = ref_tensor(grad)
  118. grad_bias_r = grad_output_r.sum(1)
  119. grad_weight_r = ((inp2_r - m.view(-1, 1, 1)) * torch.rsqrt(b_v.view(-1,1,1) + eps) * grad_output2_r).transpose(1,0).contiguous().view(feature_size, -1).sum(1)
  120. mean_dy_r = grad_output_r.mean(1)
  121. mean_dy_xmu_r = ((inp2_r - m.view(-1, 1, 1)) * grad_output2_r).transpose(1,0).contiguous().view(feature_size, -1).mean(1)
  122. grad_input_r = (grad_output2_r - mean_dy_r.view(-1, 1, 1) - (inp2_r - m.view(-1, 1, 1)) / (b_v.view(-1,1,1) + eps) * mean_dy_xmu_r.view(-1, 1, 1) ) * torch.rsqrt(b_v.view(-1,1,1) + eps) * weight_r.view(-1,1,1)
  123. mean_dy, mean_dy_xmu, grad_weight, grad_bias = syncbn.reduce_bn(grad_output_t, inp_t, mean, inv_std, weight_t)
  124. grad_input = syncbn.batchnorm_backward(grad_output_t, inp_t, mean, inv_std, weight_t, mean_dy, mean_dy_xmu)
  125. if args.local_rank == 0:
  126. sbn_result = compare("comparing bias grad: ", grad_bias, grad_bias_r, error) and sbn_result
  127. sbn_result = compare("comparing weight grad: ", grad_weight, grad_weight_r, error) and sbn_result
  128. sbn_result = compare("comparing mean_dy grad: ", mean_dy, mean_dy_r, error) and sbn_result
  129. sbn_result = compare("comparing mean_dy_xmu grad: ", mean_dy_xmu, mean_dy_xmu_r, error) and sbn_result
  130. sbn_result = compare("comparing input grad: ", grad_input, grad_input_r, error) and sbn_result
  131. compare("comparing bn input grad: ", inp_bn.grad, grad_input_r, error)
  132. if args.local_rank == 0:
  133. sbn_result = compare("comparing running_mean: ", bn.module.running_mean.data, sbn.module.running_mean.data, error) and sbn_result
  134. sbn_result = compare("comparing running_variance: ", bn.module.running_var.data, sbn.module.running_var.data, error) and sbn_result
  135. # execute by both
  136. compare("comparing layers output: ", out_bn[start:finish], out_sbn, error) and sbn_result
  137. compare("comparing layers grad_input: ", inp_bn.grad[start:finish], inp_sbn.grad[start:finish], error) and sbn_result
  138. bn_opt.step()
  139. sbn_opt.step()
  140. if args.local_rank == 0:
  141. compare("comparing bn vs sbn bias: ", bn.module.bias, sbn.module.bias, error)
  142. compare("comparing bn vs sbn weight: ", bn.module.weight, sbn.module.weight, error)
  143. if sbn_result:
  144. print("====SBN group test passed")
  145. else:
  146. print("*SBN group test failed*")