two_gpu_unit_test.py 6.5 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180
  1. import torch
  2. import numpy as np
  3. import apex
  4. import syncbn
  5. import os
  6. import argparse
  7. import torch.optim as optim
  8. def compare(desc, inp1, inp2, error):
  9. a = inp1.clone().detach().cpu().numpy()
  10. b = inp2.clone().detach().cpu().numpy()
  11. close = np.allclose(a,b, error, error)
  12. if not close:
  13. print(desc, close)
  14. z = a - b
  15. index = (np.abs(z) >= error + error * np.abs(b)).nonzero()
  16. print("dif : ", z[index])
  17. print("inp1 : ", a[index])
  18. print("inp2 : ", b[index])
  19. return close
  20. feature_size = 10
  21. space_size = 40
  22. batch_size = 32
  23. from apex.parallel import DistributedDataParallel as DDP
  24. parser = argparse.ArgumentParser()
  25. parser.add_argument("--local_rank", default=0, type=int)
  26. parser.add_argument("--fp16", action='store_true', default=False)
  27. parser.add_argument("--fp64", action='store_true', default=False)
  28. args = parser.parse_args()
  29. args.world_size = int(os.environ['WORLD_SIZE'])
  30. torch.cuda.set_device(args.local_rank)
  31. torch.distributed.init_process_group(backend='nccl', init_method='env://')
  32. start = args.local_rank * batch_size//args.world_size
  33. finish = (args.local_rank + 1) * batch_size//args.world_size
  34. error = 1e-5
  35. dtype = np.float32
  36. if args.fp16:
  37. error = 1e-3
  38. dtype = np.float16
  39. elif args.fp64:
  40. error = 1e-8
  41. dtype = np.float64
  42. np.random.seed(18)
  43. inp = np.random.randn(batch_size, feature_size, space_size, space_size).astype(dtype)
  44. grad = np.random.randn(batch_size, feature_size, space_size, space_size).astype(dtype)
  45. weight = np.random.randn(feature_size).astype(dtype)
  46. bias = np.random.randn(feature_size).astype(dtype)
  47. type_tensor = torch.cuda.FloatTensor
  48. if args.fp16:
  49. type_tensor = torch.cuda.HalfTensor
  50. if args.fp64:
  51. type_tensor = torch.cuda.DoubleTensor
  52. ref_tensor = torch.cuda.DoubleTensor
  53. inp_t = type_tensor(inp)
  54. weight_t = type_tensor(weight)
  55. bias_t = type_tensor(bias)
  56. inp_r = ref_tensor(inp.transpose(1, 0, 2, 3).reshape(feature_size, -1))
  57. inp2_r = ref_tensor(inp)
  58. weight_r = ref_tensor(weight).view(-1, 1, 1)
  59. bias_r = ref_tensor(bias).view(-1, 1, 1)
  60. grad_output_t = type_tensor(grad)
  61. m = inp_r.mean(1)
  62. b_v = inp_r.var(1, unbiased=False)
  63. unb_v = inp_r.var(1, unbiased=True)
  64. eps = 1e-5
  65. mean, var_biased = syncbn.welford_mean_var(inp_t)
  66. inv_std = 1.0 / torch.sqrt(var_biased + eps)
  67. bn = torch.nn.BatchNorm2d(feature_size).cuda()
  68. bn.momentum = 1.0
  69. bn.weight.data = weight_t.clone()
  70. bn.bias.data = bias_t.clone()
  71. if args.fp16:
  72. bn.half()
  73. if args.fp64:
  74. bn.double()
  75. inp_bn = inp_t.clone().requires_grad_()
  76. grad_bn = grad_output_t.clone().detach()
  77. out_bn = bn(inp_bn)
  78. out_bn.backward(grad_bn)
  79. # compensating the averaging over processes done by DDP
  80. # in order to produce mathematically equivalent result
  81. # https://github.com/NVIDIA/apex/issues/134#issuecomment-458307368
  82. for param in bn.parameters():
  83. param.grad = param.grad / args.world_size
  84. bn_opt = optim.SGD(bn.parameters(), lr=1.0)
  85. sbn = apex.parallel.SyncBatchNorm(feature_size).cuda()
  86. sbn.momentum = 1.0
  87. sbn.weight.data = weight_t.clone()
  88. sbn.bias.data = bias_t.clone()
  89. if args.fp16:
  90. sbn.half()
  91. if args.fp64:
  92. sbn.double()
  93. sbn = DDP(sbn)
  94. sbn_opt = optim.SGD(sbn.parameters(), lr=1.0)
  95. inp_sbn = inp_t.clone().requires_grad_()
  96. grad_sbn = grad_output_t.clone().detach()
  97. out_sbn = sbn(inp_sbn[start:finish])
  98. out_sbn.backward(grad_sbn[start:finish])
  99. count = [ space_size**2 * ( (i+1) * batch_size // args.world_size - i * batch_size // args.world_size ) for i in range(0, args.world_size)]
  100. count = torch.cuda.IntTensor(count)
  101. print("--- count : " , count)
  102. sbn_result = True
  103. bn_result = True
  104. if args.local_rank == 0:
  105. sbn_result = compare("comparing mean: ", mean, m, error) and sbn_result
  106. sbn_result = compare("comparing biased variance: ", var_biased, b_v, error) and sbn_result
  107. out = syncbn.batchnorm_forward(inp_t, mean, inv_std, weight_t, bias_t)
  108. out_r = weight_r * (inp2_r - m.view(-1, 1, 1)) * torch.rsqrt(b_v.view(-1,1,1) + eps) + bias_r
  109. if args.local_rank == 0:
  110. sbn_result = compare("comparing output: ", out, out_r, error) and sbn_result
  111. compare("comparing bn output: ", out_bn, out_r, error)
  112. grad_output_t = type_tensor(grad)
  113. grad_output_r = ref_tensor(grad.transpose(1, 0, 2, 3).reshape(feature_size, -1))
  114. grad_output2_r = ref_tensor(grad)
  115. grad_bias_r = grad_output_r.sum(1)
  116. grad_weight_r = ((inp2_r - m.view(-1, 1, 1)) * torch.rsqrt(b_v.view(-1,1,1) + eps) * grad_output2_r).transpose(1,0).contiguous().view(feature_size, -1).sum(1)
  117. sum_dy_r = grad_output_r.sum(1)
  118. mean_dy_r = grad_output_r.mean(1)
  119. mean_dy_xmu_r = ((inp2_r - m.view(-1, 1, 1)) * grad_output2_r).transpose(1,0).contiguous().view(feature_size, -1).mean(1)
  120. sum_dy_xmu_r = ((inp2_r - m.view(-1, 1, 1)) * grad_output2_r).transpose(1,0).contiguous().view(feature_size, -1).sum(1)
  121. grad_input_r = (grad_output2_r - mean_dy_r.view(-1, 1, 1) - (inp2_r - m.view(-1, 1, 1)) / (b_v.view(-1,1,1) + eps) * mean_dy_xmu_r.view(-1, 1, 1) ) * torch.rsqrt(b_v.view(-1,1,1) + eps) * weight_r.view(-1,1,1)
  122. sum_dy, sum_dy_xmu, grad_weight, grad_bias = syncbn.reduce_bn(grad_output_t, inp_t, mean, inv_std, weight_t)
  123. grad_input = syncbn.batchnorm_backward(grad_output_t, inp_t, mean, inv_std, weight_t, sum_dy, sum_dy_xmu, count)
  124. if args.local_rank == 0:
  125. sbn_result = compare("comparing bias grad: ", grad_bias, grad_bias_r, error) and sbn_result
  126. sbn_result = compare("comparing weight grad: ", grad_weight, grad_weight_r, error) and sbn_result
  127. sbn_result = compare("comparing sum_dy grad: ", sum_dy, sum_dy_r, error) and sbn_result
  128. sbn_result = compare("comparing sum_dy_xmu grad: ", sum_dy_xmu, sum_dy_xmu_r, error) and sbn_result
  129. sbn_result = compare("comparing input grad: ", grad_input, grad_input_r, error) and sbn_result
  130. compare("comparing bn input grad: ", inp_bn.grad, grad_input_r, error)
  131. if args.local_rank == 0:
  132. sbn_result = compare("comparing running_mean: ", bn.running_mean.data, sbn.module.running_mean.data, error) and sbn_result
  133. sbn_result = compare("comparing running_variance: ", bn.running_var.data, sbn.module.running_var.data, error) and sbn_result
  134. # execute by both
  135. compare("comparing layers output: ", out_bn[start:finish], out_sbn, error) and sbn_result
  136. compare("comparing layers grad_input: ", inp_bn.grad[start:finish], inp_sbn.grad[start:finish], error) and sbn_result
  137. bn_opt.step()
  138. sbn_opt.step()
  139. if args.local_rank == 0:
  140. compare("comparing bn vs sbn bias: ", bn.bias, sbn.module.bias, error)
  141. compare("comparing bn vs sbn weight: ", bn.weight, sbn.module.weight, error)
  142. if sbn_result:
  143. print("====SBN two gpu passed tests")
  144. else:
  145. print("*SBN two gpu failed*")