// syncbn.cpp
  1. #include <torch/extension.h>
  2. #include <ATen/ATen.h>
  3. #include <vector>
  4. // returns {mean,biased_var}
  5. // implemented using welford
  6. std::vector<at::Tensor> welford_mean_var_CUDA(const at::Tensor input);
  7. // reduces array of mean/var across processes
  8. // returns global {mean,inv_std,biased_var}
  9. // implemented using welford
  10. std::vector<at::Tensor> welford_parallel_CUDA(const at::Tensor mean_feature_nodes,
  11. const at::Tensor var_biased_feature_nodes,
  12. const at::Tensor numel,
  13. const float eps);
  14. // elementwise BN operation, returns output
  15. // input/weight/shift should have identical data type;
  16. // mean/inv_std have promoted data type (dtype==fp16?fp32:dtype)
  17. at::Tensor batchnorm_forward_CUDA(const at::Tensor input,
  18. const at::Tensor mean,
  19. const at::Tensor inv_std,
  20. const at::optional<at::Tensor> weight,
  21. const at::optional<at::Tensor> shift);
  22. // backward BN operation, returns {sum_dy, sum_dy_xmu, grad_weight, grad_bias}
  23. // grad_output/input should have identical data type;
  24. // mean/inv_std have promoted data type (dtype==fp16?fp32:dtype)
  25. // implemented using kahan summation
  26. std::vector<at::Tensor> reduce_bn_CUDA(const at::Tensor grad_output,
  27. const at::Tensor input,
  28. const at::Tensor mean,
  29. const at::Tensor inv_std,
  30. const at::optional<at::Tensor> weight);
  31. // elementwise backward BN operation, returns grad_input
  32. // grad_output/input/weight precision could be fp16/fp32;
  33. // mean/inv_std/sum_dy/sum_dy_xmu precision is fp32
  34. at::Tensor batchnorm_backward_CUDA(const at::Tensor grad_output,
  35. const at::Tensor input,
  36. const at::Tensor mean,
  37. const at::Tensor inv_std,
  38. const at::optional<at::Tensor> weight,
  39. const at::Tensor sum_dy,
  40. const at::Tensor sum_dy_xmu,
  41. const at::Tensor count);
  42. // returns {mean, biased_var}
  43. // implemented using welford
  44. // expect data to be in n+c format (channel last) and applies CUDNN_BATCHNORM_SPATIAL
  45. std::vector<at::Tensor> welford_mean_var_c_last_CUDA(const at::Tensor input);
  46. // elementwise BN operation, returns output
  47. // input/weight/shift should have identical data type;
  48. // mean/inv_std have promoted data type (dtype==fp16?fp32:dtype)
  49. // expect data to be in n+c format (channel last) and applies CUDNN_BATCHNORM_SPATIAL
  50. at::Tensor batchnorm_forward_c_last_CUDA(const at::Tensor input,
  51. const at::optional<at::Tensor> z,
  52. const at::Tensor mean,
  53. const at::Tensor inv_std,
  54. const at::optional<at::Tensor> weight,
  55. const at::optional<at::Tensor> shift,
  56. const bool fuse_relu);
  57. // backward BN operation, returns {sum_dy, sum_dy_xmu, grad_weight, grad_bias}
  58. // grad_output/input should have identical data type;
  59. // mean/inv_std have promoted data type (dtype==fp16?fp32:dtype)
  60. // expect data to be in n+c format (channel last) and applies CUDNN_BATCHNORM_SPATIAL
  61. std::vector<at::Tensor> reduce_bn_c_last_CUDA(const at::Tensor grad_output,
  62. const at::Tensor input,
  63. const at::Tensor mean,
  64. const at::Tensor inv_std,
  65. const at::optional<at::Tensor> weight);
  66. // elementwise backward BN operation, returns grad_input
  67. // grad_output/input/weight precision could be fp16/fp32;
  68. // mean/inv_std/sum_dy/sum_dy_xmu precision is fp32
  69. // expect data to be in n+c format (channel last) and applies CUDNN_BATCHNORM_SPATIAL
  70. at::Tensor batchnorm_backward_c_last_CUDA(const at::Tensor grad_output,
  71. const at::Tensor input,
  72. const at::Tensor mean,
  73. const at::Tensor inv_std,
  74. const at::optional<at::Tensor> weight,
  75. const at::Tensor sum_dy,
  76. const at::Tensor sum_dy_xmu,
  77. const at::Tensor count);
  78. at::Tensor relu_backward_c_last_CUDA(const at::Tensor grad_output,
  79. const at::Tensor input,
  80. const at::optional<at::Tensor> z,
  81. const at::Tensor mean,
  82. const at::Tensor inv_std,
  83. const at::optional<at::Tensor> weight,
  84. const at::optional<at::Tensor> shift);
  85. PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
  86. m.def("welford_mean_var", &welford_mean_var_CUDA, "welford mean variance");
  87. m.def("welford_parallel", &welford_parallel_CUDA, "welford parallel reduce mean variance");
  88. m.def("batchnorm_forward", &batchnorm_forward_CUDA, "batchnorm forward");
  89. m.def("reduce_bn", &reduce_bn_CUDA, "batchnorm backward reduce grad sum and bias/weight grad");
  90. m.def("batchnorm_backward", &batchnorm_backward_CUDA, "batchnorm backward dgrad");
  91. m.def("welford_mean_var_c_last", &welford_mean_var_c_last_CUDA, "welford mean variance nhwc");
  92. m.def("batchnorm_forward_c_last", &batchnorm_forward_c_last_CUDA, "batchnorm forward nhwc");
  93. m.def("reduce_bn_c_last", &reduce_bn_c_last_CUDA, "batchnorm backwards reduce grad sum and bias/weight grad nhwc");
  94. m.def("batchnorm_backward_c_last", &batchnorm_backward_c_last_CUDA, "batchnorm backward dgrad nhwc");
  95. m.def("relu_bw_c_last", &relu_backward_c_last_CUDA, "relu_bw_c_last");
  96. }