single_gpu_unit_test.py 7.0 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159
  1. import torch
  2. import numpy as np
  3. import apex
  4. if True:
  5. print("using setup tools")
  6. import syncbn
  7. else:
  8. print("using jit")
  9. from torch.utils.cpp_extension import load
  10. syncbn = load(name='syncbn', sources=['../../csrc/syncbn.cpp', '../../csrc/welford.cu'])
  11. def compare(desc, inp1, inp2, error):
  12. a = inp1.clone().detach().cpu().numpy()
  13. b = inp2.clone().detach().cpu().numpy()
  14. close = np.allclose(a,b, error, error)
  15. if not close:
  16. print(desc, close)
  17. z = a - b
  18. index = (np.abs(z) >= error + error * np.abs(b)).nonzero()
  19. print("dif : ", z[index])
  20. print("inp1 : ", a[index])
  21. print("inp2 : ", b[index])
  22. return close
  23. feature_size = 10
  24. space_size = 16
  25. batch_size = 5
  26. error = 1e-5
  27. np.random.seed(1)
  28. dtype = np.float32
  29. inp = (np.random.randn(batch_size, feature_size, space_size, space_size)).astype(dtype)
  30. grad = (np.random.randn(batch_size, feature_size, space_size, space_size)).astype(dtype)
  31. weight = (np.random.randn(feature_size)).astype(dtype)
  32. bias = (np.random.randn(feature_size)).astype(dtype)
  33. count = torch.cuda.IntTensor([batch_size*space_size**2])
  34. type_tensor = torch.cuda.FloatTensor
  35. ref_tensor = torch.cuda.DoubleTensor
  36. inp_t = type_tensor(inp)
  37. weight_t = type_tensor(weight)
  38. bias_t = type_tensor(bias)
  39. inp_r = ref_tensor(inp.transpose(1, 0, 2, 3).reshape(feature_size, -1))
  40. inp2_r = ref_tensor(inp)
  41. weight_r = ref_tensor(weight).view(-1, 1, 1)
  42. bias_r = ref_tensor(bias).view(-1, 1, 1)
  43. grad_output_t = type_tensor(grad)
  44. m = inp_r.mean(1)
  45. b_v = inp_r.var(1, unbiased=False)
  46. unb_v = inp_r.var(1, unbiased=True)
  47. eps = 1e-5
  48. #mean, var, var_biased = syncbn.welford_mean_var(inp_t)
  49. mean, var_biased = syncbn.welford_mean_var(inp_t)
  50. inv_std = 1.0 / torch.sqrt(var_biased + eps)
  51. bn = torch.nn.BatchNorm2d(feature_size).cuda()
  52. bn.momentum = 1.0
  53. bn.weight.data = weight_t.clone()
  54. bn.bias.data = bias_t.clone()
  55. inp_bn = inp_t.clone().requires_grad_()
  56. grad_bn = grad_output_t.clone().detach()
  57. out_bn = bn(inp_bn)
  58. out_bn.backward(grad_bn)
  59. sbn = apex.parallel.SyncBatchNorm(feature_size).cuda()
  60. sbn.momentum = 1.0
  61. sbn.weight.data = weight_t.clone()
  62. sbn.bias.data = bias_t.clone()
  63. inp_sbn = inp_t.clone().requires_grad_()
  64. grad_sbn = grad_output_t.clone().detach()
  65. out_sbn = sbn(inp_sbn)
  66. out_sbn.backward(grad_sbn)
  67. sbn_c_last = apex.parallel.SyncBatchNorm(feature_size, channel_last=True).cuda()
  68. sbn_c_last.momentum = 1.0
  69. sbn_c_last.weight.data = weight_t.clone()
  70. sbn_c_last.bias.data = bias_t.clone()
  71. inp_sbn_c_last = inp_t.clone().transpose(-1, 1).contiguous().requires_grad_()
  72. grad_sbn_c_last = grad_output_t.clone().transpose(-1, 1).contiguous().detach()
  73. out_sbn_c_last = sbn_c_last(inp_sbn_c_last)
  74. out_sbn_c_last.backward(grad_sbn_c_last)
  75. sbn_result = True
  76. sbn_result_c_last = True
  77. bn_result = True
  78. sbn_result = compare("comparing mean: ", mean, m, error) and sbn_result
  79. #sbn_result = compare("comparing variance: ", var, unb_v, error) and sbn_result
  80. sbn_result = compare("comparing biased variance: ", var_biased, b_v, error) and sbn_result
  81. out = syncbn.batchnorm_forward(inp_t, mean, inv_std, weight_t, bias_t)
  82. out_r = weight_r * (inp2_r - m.view(-1, 1, 1)) * torch.rsqrt(b_v.view(-1,1,1) + eps) + bias_r
  83. sbn_result = compare("comparing output: ", out, out_r, error) and sbn_result
  84. compare("comparing bn output: ", out_bn, out_r, error)
  85. grad_output_t = type_tensor(grad)
  86. grad_output_r = ref_tensor(grad.transpose(1, 0, 2, 3).reshape(feature_size, -1))
  87. grad_output2_r = ref_tensor(grad)
  88. grad_bias_r = grad_output_r.sum(1)
  89. grad_weight_r = ((inp2_r - m.view(-1, 1, 1)) * torch.rsqrt(b_v.view(-1,1,1) + eps) * grad_output2_r).transpose(1,0).contiguous().view(feature_size, -1).sum(1)
  90. sum_dy_r = grad_output_r.sum(1)
  91. mean_dy_r = grad_output_r.mean(1)
  92. sum_dy_xmu_r = ((inp2_r - m.view(-1, 1, 1)) * grad_output2_r).transpose(1,0).contiguous().view(feature_size, -1).sum(1)
  93. mean_dy_xmu_r = ((inp2_r - m.view(-1, 1, 1)) * grad_output2_r).transpose(1,0).contiguous().view(feature_size, -1).mean(1)
  94. grad_input_r = (grad_output2_r - mean_dy_r.view(-1, 1, 1) - (inp2_r - m.view(-1, 1, 1)) / (b_v.view(-1,1,1) + eps) * mean_dy_xmu_r.view(-1, 1, 1) ) * torch.rsqrt(b_v.view(-1,1,1) + eps) * weight_r.view(-1,1,1)
  95. sum_dy, sum_dy_xmu, grad_weight, grad_bias = syncbn.reduce_bn(grad_output_t, inp_t, mean, inv_std, weight_t)
  96. grad_input = syncbn.batchnorm_backward(grad_output_t, inp_t, mean, inv_std, weight_t, sum_dy, sum_dy_xmu, count)
  97. sbn_result = compare("comparing bias grad: ", grad_bias, grad_bias_r, error) and sbn_result
  98. sbn_result = compare("comparing weight grad: ", grad_weight, grad_weight_r, error) and sbn_result
  99. sbn_result = compare("comparing sum_dy grad: ", sum_dy, sum_dy_r, error) and sbn_result
  100. sbn_result = compare("comparing sum_dy_xmu grad: ", sum_dy_xmu, sum_dy_xmu_r, error) and sbn_result
  101. sbn_result = compare("comparing input grad: ", grad_input, grad_input_r, error) and sbn_result
  102. compare("comparing bn input grad: ", inp_bn.grad, grad_input_r, error)
  103. sbn_result = compare("comparing sbn input grad: ", inp_sbn.grad, grad_input_r, error) and sbn_result
  104. compare("comparing bn/sbn output: ", out_bn, out_sbn, error)
  105. sbn_result = compare("comparing running_mean: ", bn.running_mean.data, sbn.running_mean.data, error) and sbn_result
  106. sbn_result = compare("comparing running_variance: ", bn.running_var.data, sbn.running_var.data, error) and sbn_result
  107. compare("comparing grad_input: ", inp_bn.grad, inp_sbn.grad, error)
  108. compare("comparing grad_bias: ", bn.bias.grad, sbn.bias.grad, error)
  109. compare("comparing grad_bias bn to ref: ", bn.bias.grad, grad_bias_r, error)
  110. sbn_result = compare("comparing grad_bias sbn to ref: ", sbn.bias.grad, grad_bias_r, error) and sbn_result
  111. compare("comparing grad_weight: ", bn.weight.grad, sbn.weight.grad, error)
  112. compare("comparing grad_weight bn to ref: ", bn.weight.grad, grad_weight_r, error)
  113. sbn_result = compare("comparing grad_weight sbn to ref: ", sbn.weight.grad, grad_weight_r, error) and sbn_result
  114. compare("comparing channel last bn/sbn output: ", out_bn, out_sbn_c_last.transpose(-1, 1).contiguous(), error)
  115. sbn_result_c_last = compare("comparing channel last running_mean: ", bn.running_mean.data, sbn_c_last.running_mean.data, error) and sbn_result_c_last
  116. sbn_result_c_last = compare("comparing channel last running_variance: ", bn.running_var.data, sbn_c_last.running_var.data, error) and sbn_result_c_last
  117. compare("comparing channel last grad_input: ", inp_bn.grad, inp_sbn_c_last.grad.transpose(-1, 1).contiguous(), error)
  118. compare("comparing channel last grad_bias: ", bn.bias.grad, sbn_c_last.bias.grad, error)
  119. sbn_result_c_last = compare("comparing channel last grad_bias sbn to ref: ", sbn_c_last.bias.grad, grad_bias_r, error) and sbn_result_c_last
  120. compare("comparing channel last grad_weight: ", bn.weight.grad, sbn_c_last.weight.grad, error)
  121. sbn_result_c_last = compare("comparing channel last grad_weight sbn to ref: ", sbn_c_last.weight.grad, grad_weight_r, error) and sbn_result_c_last
  122. if sbn_result:
  123. print("====SBN single gpu passed tests")
  124. else:
  125. print("*SBN single gpu failed*")
  126. if sbn_result_c_last:
  127. print("====SBN channel last single gpu passed tests")
  128. else:
  129. print("*SBN channel last single gpu failed*")