mlp.cpp 4.6 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166
  1. #include <torch/extension.h>
  2. #include <torch/torch.h>
  3. #include <vector>
  4. #include <stdio.h>
  // Size of the scratch buffer that the forward pass needs for the given
  // batch size and per-layer output widths. Defined in another translation
  // unit. mlp_forward uses the returned value as the *element count* of the
  // `reserved_space` tensor it allocates.
  //
  //   batch_size       number of rows in the input
  //   num_layers       number of entries readable through `output_features`
  //   output_features  output width of each layer
  size_t get_mlp_reserved_space(
      int64_t batch_size,
      int num_layers,
      const int* output_features);
  // Workspace requirement of the backward pass, expressed in bytes (per the
  // function name), for element type T. Defined in another translation unit.
  // mlp_backward divides the result by sizeof(T) to obtain an element count.
  //
  //   batch_size       number of rows in the input
  //   num_layers       number of entries readable through `output_features`
  //   output_features  output width of each layer
  template <class T>
  size_t get_mlp_bp_workspace_in_bytes(
      int batch_size,
      int num_layers,
      const int* output_features);
  // Forward-pass kernel entry point for element type T; defined in another
  // translation unit. Parameter roles below are described as mlp_forward
  // supplies them.
  //
  //   X                input activations (inputs[0])
  //   input_features   inputs[0].size(1)
  //   batch_size       inputs[0].size(0)
  //   WPtr             array of `num_layers` weight pointers
  //   num_layers       number of layers
  //   output_features  array of `num_layers` widths (weight.size(0) per layer)
  //   BPtr             array of bias pointers; holds no entries when use_bias
  //                    is zero
  //   Y                output buffer, {batch_size, output_features[last]}
  //   reserved_space   scratch buffer sized via get_mlp_reserved_space
  //   use_bias         non-zero when biases are supplied
  //   activation       activation selector, forwarded unchanged
  //   lt_workspace     extra workspace buffer (cuBLASLt, per the caller's
  //                    comment)
  //
  // Returns an int whose meaning is not visible from this file; the caller
  // here does not inspect it.
  template <class T>
  int mlp_fp(
      T* X, int input_features, int batch_size,
      T** WPtr, int num_layers, int* output_features,
      T** BPtr,
      T* Y,
      T* reserved_space,
      int use_bias, int activation,
      void* lt_workspace);
  // Backward-pass kernel entry point for element type T; defined in another
  // translation unit. Parameter roles below are described as mlp_backward
  // supplies them.
  //
  //   X                input activations (inputs[0])
  //   Y                forward output (fprop_outputs[0])
  //   input_features   inputs[0].size(1)
  //   batch_size       inputs[0].size(0)
  //   WPtr             array of `num_layers` weight pointers
  //   num_layers       number of layers
  //   output_features  array of `num_layers` widths (weight.size(0) per layer)
  //   dY               gradient w.r.t. the output (contiguous grad_o)
  //   reserved_space   scratch buffer saved by the forward pass
  //                    (fprop_outputs[1])
  //   work_space       temporary buffer sized via
  //                    get_mlp_bp_workspace_in_bytes
  //   dX               destination for the input gradient
  //   dwPtr            array of destinations for the weight gradients
  //   dbPtr            array of destinations for the bias gradients
  //   requires_grad    inputs[0].requires_grad()
  //   use_bias         non-zero when biases are present
  //   activation       activation selector, forwarded unchanged
  //
  // Returns an int whose meaning is not visible from this file; the caller
  // here does not inspect it.
  template <class T>
  int mlp_bp(
      T* X, T* Y,
      int input_features, int batch_size,
      T** WPtr, int num_layers, int* output_features,
      T* dY,
      T* reserved_space, T* work_space,
      T* dX, T** dwPtr, T** dbPtr,
      bool requires_grad,
      int use_bias, int activation);
  40. std::vector<at::Tensor> mlp_forward(int use_bias, int activation, std::vector<at::Tensor> inputs) {
  41. auto num_layers = inputs.size() - 1;
  42. if (use_bias) {
  43. // inputs contains (input, weights, biases)
  44. num_layers /= 2;
  45. }
  46. auto batch_size = inputs[0].size(0);
  47. auto input_features = inputs[0].size(1);
  48. std::vector<int> output_features;
  49. for (int i = 0; i < num_layers; i++) {
  50. output_features.push_back(inputs[i + 1].size(0));
  51. }
  52. auto reserved_size = get_mlp_reserved_space(batch_size, num_layers, output_features.data());
  53. // create output/workspace tensor
  54. auto out = at::empty({batch_size, output_features.back()}, inputs[0].type());
  55. auto reserved_space = at::empty({static_cast<long>(reserved_size)}, inputs[0].type());
  56. // allocate fixed 4MB workspace for cublaslt for now, and this gets at least 4 MB
  57. auto lt_workspace = at::empty({1 << 22}, inputs[0].type());
  58. AT_DISPATCH_FLOATING_TYPES_AND_HALF(inputs[0].type(), "mlp_forward", [&] {
  59. std::vector<scalar_t*> w_ptr;
  60. std::vector<scalar_t*> b_ptr;
  61. for (int i = 0; i < num_layers; i++) {
  62. w_ptr.push_back(inputs[i + 1].data_ptr<scalar_t>());
  63. if (use_bias) {
  64. b_ptr.push_back(inputs[i + 1 + num_layers].data_ptr<scalar_t>());
  65. }
  66. }
  67. auto result = mlp_fp<scalar_t>(
  68. inputs[0].data_ptr<scalar_t>(),
  69. input_features,
  70. batch_size,
  71. w_ptr.data(),
  72. num_layers,
  73. output_features.data(),
  74. b_ptr.data(),
  75. out.data_ptr<scalar_t>(),
  76. reserved_space.data_ptr<scalar_t>(),
  77. use_bias,
  78. activation,
  79. (void*) (lt_workspace.data_ptr<scalar_t>()));
  80. });
  81. return {out, reserved_space};
  82. }
  83. std::vector<at::Tensor> mlp_backward(
  84. int use_bias,
  85. int activation,
  86. at::Tensor grad_o,
  87. std::vector<at::Tensor> fprop_outputs,
  88. std::vector<at::Tensor> inputs) {
  89. auto num_layers = inputs.size() - 1;
  90. if (use_bias) {
  91. // inputs contains (input, weights, biases)
  92. num_layers /= 2;
  93. }
  94. auto batch_size = inputs[0].size(0);
  95. auto input_features = inputs[0].size(1);
  96. bool requires_grad = inputs[0].requires_grad();
  97. std::vector<int> output_features;
  98. for (int i = 0; i < num_layers; i++) {
  99. output_features.push_back(inputs[i + 1].size(0));
  100. }
  101. // create outputs, length of inputs
  102. std::vector<at::Tensor> outputs;
  103. for (int i = 0; i < inputs.size(); i++) {
  104. outputs.push_back(at::empty(inputs[i].sizes(), inputs[i].type())); // clone for testing now
  105. }
  106. AT_DISPATCH_FLOATING_TYPES_AND_HALF(inputs[0].type(), "mlp_backward", [&] {
  107. std::vector<scalar_t*> w_ptr;
  108. for (int i = 0; i < num_layers; i++) {
  109. w_ptr.push_back(inputs[i + 1].data_ptr<scalar_t>());
  110. }
  111. std::vector<scalar_t*> outputs_ptr;
  112. for (int i = 0; i < inputs.size(); i++) {
  113. outputs_ptr.push_back(outputs[i].data_ptr<scalar_t>());
  114. }
  115. auto work_size =
  116. get_mlp_bp_workspace_in_bytes<scalar_t>(batch_size, num_layers, output_features.data());
  117. // auto work_space = at::empty({work_size*4}, at::kByte);
  118. auto work_space = at::empty({static_cast<long>(work_size / sizeof(scalar_t))}, inputs[0].type());
  119. auto result = mlp_bp<scalar_t>(
  120. inputs[0].data_ptr<scalar_t>(),
  121. fprop_outputs[0].data_ptr<scalar_t>(),
  122. input_features,
  123. batch_size,
  124. w_ptr.data(),
  125. num_layers,
  126. output_features.data(),
  127. grad_o.contiguous().data_ptr<scalar_t>(),
  128. fprop_outputs[1].data_ptr<scalar_t>(),
  129. work_space.data_ptr<scalar_t>(),
  130. outputs_ptr[0],
  131. outputs_ptr.data() + 1,
  132. outputs_ptr.data() + 1 + num_layers,
  133. requires_grad,
  134. use_bias,
  135. activation);
  136. });
  137. return outputs;
  138. }
  139. PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
  140. m.def("forward", &mlp_forward, "MLP forward");
  141. m.def("backward", &mlp_backward, "MLP backward");
  142. }