fused_dense.cpp 8.0 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193
  #include <torch/extension.h>
  #include <torch/torch.h>

  #include <stdio.h>

  #include <cstdint>
  #include <limits>
  #include <vector>
  5. template <typename T>
  6. int linear_bias_forward_cuda(at::Tensor input, T *weight, at::Tensor bias, int in_features, int batch_size, int out_features, at::Tensor output, void *lt_workspace);
  7. template <typename T>
  8. int linear_bias_backward_cuda(T *input, T *weight, T *d_output, int in_features, int batch_size, int out_features, T *d_weight, T *d_bias, T *d_input, void *lt_workspace);
  9. template <typename T>
  10. int linear_gelu_linear_forward_cuda(T *input, T *weight1, T *bias1, T *weight2, T *bias2, int in_features, int hidden_features, int batch_size, int out_features, T *output1, T *output2, T *gelu_in, void *lt_workspace) ;
  11. template <typename T>
  12. int linear_gelu_linear_backward_cuda(T *input, T *gelu_in, T *output1, T *weight1, T *weight2, T *d_output1, T *d_output2, int in_features, int batch_size, int hidden_features, int out_features, T *d_weight1, T *d_weight2, T *d_bias1, T *d_bias2, T *d_input, void *lt_workspace);
  13. at::Tensor linear_bias_forward(at::Tensor input, at::Tensor weight, at::Tensor bias) {
  14. auto batch_size = input.size(0);
  15. auto in_features = input.size(1);
  16. int out_features = weight.size(0);
  17. //auto reserved_size = get_mlp_reserved_space(batch_size, num_layers, output_features.data());
  18. // create output/workspace tensor
  19. auto out = at::empty({batch_size, out_features}, input.type());
  20. //auto reserved_space = at::empty({reserved_size}, inputs[0].type());
  21. // allocate fixed 4MB workspace for cublaslt for now, and this gets at least 4 MB
  22. auto lt_workspace = at::empty({1 << 22}, input.type());
  23. AT_DISPATCH_FLOATING_TYPES_AND2(at::ScalarType::Half, at::ScalarType::BFloat16, input.scalar_type(), "linear_bias_forward", [&] {
  24. scalar_t* w_ptr = weight.data_ptr<scalar_t>();
  25. scalar_t* b_ptr = bias.data_ptr<scalar_t>();
  26. auto result = linear_bias_forward_cuda<scalar_t>(
  27. input,
  28. w_ptr,
  29. bias,
  30. in_features,
  31. batch_size,
  32. out_features,
  33. out,
  34. //out.data_ptr<scalar_t>(),
  35. // reserved_space.data_ptr<scalar_t>(),
  36. (void*) (lt_workspace.data_ptr<scalar_t>()));
  37. });
  38. return {out};
  39. }
  40. std::vector<at::Tensor> linear_bias_backward(at::Tensor input, at::Tensor weight, at::Tensor d_output) {
  41. auto batch_size = input.size(0);
  42. auto in_features = input.size(1);
  43. int out_features = weight.size(0);
  44. //auto reserved_size = get_mlp_reserved_space(batch_size, num_layers, output_features.data());
  45. // create output/workspace tensor
  46. auto d_weight = at::empty({out_features, in_features}, input.type());
  47. #if defined(CUBLAS_VERSION) && CUBLAS_VERSION < 11600
  48. auto d_bias = d_output.view({-1, out_features}).sum(0, false);
  49. #else
  50. auto d_bias = at::empty({out_features}, input.type());
  51. #endif
  52. auto d_input = at::empty({batch_size, in_features}, input.type());
  53. //auto reserved_space = at::empty({reserved_size}, inputs[0].type());
  54. // allocate fixed 4MB workspace for cublaslt for now, and this gets at least 4 MB
  55. auto lt_workspace = at::empty({1 << 22}, input.type());
  56. AT_DISPATCH_FLOATING_TYPES_AND2(at::ScalarType::Half, at::ScalarType::BFloat16, input.scalar_type(), "linear_bias_backward", [&] {
  57. scalar_t* w_ptr = weight.data_ptr<scalar_t>();
  58. scalar_t* d_b_ptr = d_bias.data_ptr<scalar_t>();
  59. auto result = linear_bias_backward_cuda<scalar_t>(
  60. input.data_ptr<scalar_t>(),
  61. w_ptr,
  62. d_output.data_ptr<scalar_t>(),
  63. in_features,
  64. batch_size,
  65. out_features,
  66. d_weight.data_ptr<scalar_t>(),
  67. d_bias.data_ptr<scalar_t>(),
  68. d_input.data_ptr<scalar_t>(),
  69. // reserved_space.data_ptr<scalar_t>(),
  70. (void*) (lt_workspace.data_ptr<scalar_t>()));
  71. });
  72. return {d_input, d_weight, d_bias};
  73. }
  74. std::vector<at::Tensor> linear_gelu_linear_forward(at::Tensor input, at::Tensor weight1, at::Tensor bias1, at::Tensor weight2, at::Tensor bias2) {
  75. auto batch_size = input.size(0);
  76. auto in_features = input.size(1);
  77. int hidden_features = weight1.size(0);
  78. int out_features = weight2.size(0);
  79. //auto reserved_size = get_mlp_reserved_space(batch_size, num_layers, output_features.data());
  80. // create output/workspace tensor
  81. auto output1 = at::empty({batch_size, hidden_features}, input.type());
  82. auto gelu_in = at::empty({batch_size, hidden_features}, input.type());
  83. auto output2 = at::empty({batch_size, out_features}, input.type());
  84. //auto reserved_space = at::empty({reserved_size}, inputs[0].type());
  85. // allocate fixed 4MB workspace for cublaslt for now, and this gets at least 4 MB
  86. auto lt_workspace = at::empty({1 << 22}, input.type());
  87. AT_DISPATCH_FLOATING_TYPES_AND2(at::ScalarType::Half, at::ScalarType::BFloat16, input.scalar_type(), "linear_gelu_linear_forward", [&] {
  88. scalar_t* w1_ptr = weight1.data_ptr<scalar_t>();
  89. scalar_t* b1_ptr = bias1.data_ptr<scalar_t>();
  90. scalar_t* w2_ptr = weight2.data_ptr<scalar_t>();
  91. scalar_t* b2_ptr = bias2.data_ptr<scalar_t>();
  92. auto result = linear_gelu_linear_forward_cuda<scalar_t>(
  93. input.data_ptr<scalar_t>(),
  94. w1_ptr,
  95. b1_ptr,
  96. w2_ptr,
  97. b2_ptr,
  98. in_features,
  99. hidden_features,
  100. batch_size,
  101. out_features,
  102. output1.data_ptr<scalar_t>(),
  103. output2.data_ptr<scalar_t>(),
  104. gelu_in.data_ptr<scalar_t>(),
  105. // reserved_space.data_ptr<scalar_t>(),
  106. (void*) (lt_workspace.data_ptr<scalar_t>()));
  107. });
  108. return {output1, output2, gelu_in};
  109. }
  110. std::vector<at::Tensor> linear_gelu_linear_backward(at::Tensor input, at::Tensor gelu_in, at::Tensor output1, at::Tensor weight1, at::Tensor weight2, at::Tensor d_output2) {
  111. auto batch_size = input.size(0);
  112. auto in_features = input.size(1);
  113. int hidden_features = weight1.size(0);
  114. int out_features = weight2.size(0);
  115. //auto reserved_size = get_mlp_reserved_space(batch_size, num_layers, output_features.data());
  116. // create output/workspace tensor
  117. auto d_weight1 = at::empty({hidden_features, in_features}, input.type());
  118. auto d_weight2 = at::empty({out_features, hidden_features}, input.type());
  119. auto d_bias1 = at::empty({hidden_features}, input.type());
  120. auto d_bias2 = at::empty({out_features}, input.type());
  121. auto d_input = at::empty({batch_size, in_features}, input.type());
  122. auto d_output1 = at::empty({batch_size, hidden_features}, input.type());
  123. //auto reserved_space = at::empty({reserved_size}, inputs[0].type());
  124. // allocate fixed 4MB workspace for cublaslt for now, and this gets at least 4 MB
  125. auto lt_workspace = at::empty({1 << 22}, input.type());
  126. AT_DISPATCH_FLOATING_TYPES_AND2(at::ScalarType::Half, at::ScalarType::BFloat16, input.scalar_type(), "linear_bias_backward", [&] {
  127. //scalar_t* w_ptr = weight.data_ptr<scalar_t>();
  128. //scalar_t* d_b_ptr = d_bias.data_ptr<scalar_t>();
  129. auto result = linear_gelu_linear_backward_cuda<scalar_t>(
  130. input.data_ptr<scalar_t>(),
  131. gelu_in.data_ptr<scalar_t>(),
  132. output1.data_ptr<scalar_t>(),
  133. weight1.data_ptr<scalar_t>(),
  134. weight2.data_ptr<scalar_t>(),
  135. d_output1.data_ptr<scalar_t>(),
  136. d_output2.data_ptr<scalar_t>(),
  137. in_features,
  138. batch_size,
  139. hidden_features,
  140. out_features,
  141. d_weight1.data_ptr<scalar_t>(),
  142. d_weight2.data_ptr<scalar_t>(),
  143. d_bias1.data_ptr<scalar_t>(),
  144. d_bias2.data_ptr<scalar_t>(),
  145. d_input.data_ptr<scalar_t>(),
  146. // reserved_space.data_ptr<scalar_t>(),
  147. (void*) (lt_workspace.data_ptr<scalar_t>()));
  148. });
  149. return {d_input, d_weight1, d_bias1, d_weight2, d_bias2};
  150. }
  151. PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
  152. m.def("linear_bias_forward", &linear_bias_forward, "linear bias forward");
  153. m.def("linear_bias_backward", &linear_bias_backward, "linear bias backward");
  154. m.def("linear_gelu_linear_forward", &linear_gelu_linear_forward, "linear gelu linear forward");
  155. m.def("linear_gelu_linear_backward", &linear_gelu_linear_backward, "linear gelu linear backward");
  156. }