// fused_weight_gradient_dense.cpp (540 B)
  1. #include <torch/extension.h>
  2. #include <cstdio>
  3. #include <vector>
  4. void wgrad_gemm_accum_fp32_cuda_stub(
  5. at::Tensor &input_2d,
  6. at::Tensor &d_output_2d,
  7. at::Tensor &d_weight
  8. );
  9. void wgrad_gemm_accum_fp16_cuda_stub(
  10. at::Tensor &input_2d,
  11. at::Tensor &d_output_2d,
  12. at::Tensor &d_weight
  13. );
  14. PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
  15. m.def("wgrad_gemm_accum_fp32", &wgrad_gemm_accum_fp32_cuda_stub, "wgrad gemm accum in fp32");
  16. m.def("wgrad_gemm_accum_fp16", &wgrad_gemm_accum_fp16_cuda_stub, "wgrad gemm accum in fp16");
  17. }