scaled_masked_softmax.cpp

/* coding=utf-8
 * Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved.
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
#include <cuda_fp16.h>
#include <torch/extension.h>

#include <vector>

namespace multihead_attn {
namespace fused_softmax {
namespace scaled_masked_softmax {

// Forward declarations of the CUDA implementations, defined in the
// accompanying .cu source file.
torch::Tensor fwd_cuda(
    torch::Tensor const& input,
    torch::Tensor const& mask,
    float scale_factor);

torch::Tensor bwd_cuda(
    torch::Tensor const& output_grads,
    torch::Tensor const& softmax_results,
    float scale_factor);

int get_batch_per_block_cuda(
    int query_seq_len,
    int key_seq_len,
    int batches,
    int attn_heads);
torch::Tensor fwd(
    torch::Tensor & input,
    torch::Tensor & mask,
    float scale_factor) {
  TORCH_CHECK(input.dim() == 4, "expected 4D tensor");
  TORCH_CHECK((input.scalar_type() == at::ScalarType::Half) ||
              (input.scalar_type() == at::ScalarType::BFloat16),
              "Only fp16 and bf16 are supported");
  TORCH_CHECK(mask.dim() == 4, "expected 4D tensor");

  // The CUDA kernels expect contiguous memory.
  if (!input.is_contiguous())
    input = input.contiguous();
  if (!mask.is_contiguous())
    mask = mask.contiguous();

  return fwd_cuda(input, mask, scale_factor);
}
torch::Tensor bwd(
    torch::Tensor & output_grads,
    torch::Tensor & softmax_results,
    float scale_factor) {
  TORCH_CHECK(output_grads.dim() == 4, "expected 4D tensor");
  TORCH_CHECK(softmax_results.dim() == 4, "expected 4D tensor");
  TORCH_CHECK((output_grads.scalar_type() == at::ScalarType::Half) ||
              (output_grads.scalar_type() == at::ScalarType::BFloat16),
              "Only fp16 and bf16 are supported");
  TORCH_CHECK((softmax_results.scalar_type() == at::ScalarType::Half) ||
              (softmax_results.scalar_type() == at::ScalarType::BFloat16),
              "Only fp16 and bf16 are supported");

  if (!output_grads.is_contiguous())
    output_grads = output_grads.contiguous();
  if (!softmax_results.is_contiguous())
    softmax_results = softmax_results.contiguous();

  return bwd_cuda(output_grads, softmax_results, scale_factor);
}
int get_batch_per_block(
    int query_seq_len,
    int key_seq_len,
    int batches,
    int attn_heads) {
  return get_batch_per_block_cuda(query_seq_len, key_seq_len, batches, attn_heads);
}

} // end namespace scaled_masked_softmax
} // end namespace fused_softmax
} // end namespace multihead_attn
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
  m.def("forward",
        &multihead_attn::fused_softmax::scaled_masked_softmax::fwd,
        "Self Multihead Attention scaled, time masked softmax -- Forward.");
  m.def("backward",
        &multihead_attn::fused_softmax::scaled_masked_softmax::bwd,
        "Self Multihead Attention scaled, time masked softmax -- Backward.");
  m.def("get_batch_per_block",
        &multihead_attn::fused_softmax::scaled_masked_softmax::get_batch_per_block,
        "Return batch per block size.");
}
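
Once compiled as a PyTorch extension, the bindings above are callable from Python. The sketch below is illustrative only: the extension name, the companion .cu file name, and the mask shape/dtype convention (a [batch, 1, q_len, k_len] uint8 mask where 1 marks positions to exclude) are assumptions, since the wrapper shown here only checks that the tensors are 4-D and fp16/bf16.

# Illustrative sketch only; the source file names and the mask convention are
# assumptions -- the C++ wrapper above only enforces 4-D fp16/bf16 tensors.
import torch
from torch.utils.cpp_extension import load

# JIT-build the extension; the .cu file implementing fwd_cuda/bwd_cuda is
# assumed to sit next to this .cpp file under this (hypothetical) name.
scaled_masked_softmax = load(
    name="scaled_masked_softmax",
    sources=["scaled_masked_softmax.cpp", "scaled_masked_softmax_cuda.cu"],
    extra_cuda_cflags=["-O3"],
)

b, h, sq, sk = 2, 8, 128, 128
logits = torch.randn(b, h, sq, sk, device="cuda", dtype=torch.float16)

# Assumed mask convention: uint8, broadcast over heads, 1 = masked-out position
# (a causal upper-triangular mask here). The wrapper makes it contiguous itself.
mask = torch.triu(
    torch.ones(sq, sk, device="cuda", dtype=torch.uint8), diagonal=1
).expand(b, 1, sq, sk)

probs = scaled_masked_softmax.forward(logits, mask, 1.0)                    # scale_factor = 1.0
grad_in = scaled_masked_softmax.backward(torch.ones_like(probs), probs, 1.0)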