scaled_upper_triang_masked_softmax_cuda.cu 3.1 KB

1234567891011121314151617181920212223242526272829303132333435363738394041424344454647484950515253545556575859606162636465666768697071727374757677787980818283848586878889909192939495969798
  1. /* coding=utf-8
  2. * Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved.
  3. *
  4. * Licensed under the Apache License, Version 2.0 (the "License");
  5. * you may not use this file except in compliance with the License.
  6. * You may obtain a copy of the License at
  7. *
  8. * http://www.apache.org/licenses/LICENSE-2.0
  9. *
  10. * Unless required by applicable law or agreed to in writing, software
  11. * distributed under the License is distributed on an "AS IS" BASIS,
  12. * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  13. * See the License for the specific language governing permissions and
  14. * limitations under the License.
  15. */
  16. #include <ATen/ATen.h>
  17. #include <cuda.h>
  18. #include <cuda_runtime.h>
  19. #include <cuda_fp16.h>
  20. #include <cuda_profiler_api.h>
  21. #include <ATen/cuda/CUDAContext.h>
  22. #include <torch/extension.h>
  23. #include "scaled_upper_triang_masked_softmax.h"
  24. #include "type_shim.h"
  25. namespace multihead_attn {
  26. namespace fused_softmax {
  27. namespace scaled_upper_triang_masked_softmax {
  // Forward pass of the scaled upper-triangular masked softmax.
  //
  // Parameters:
  //   input        - 3d tensor with dimensions [attn_batches, seq_len, seq_len].
  //                  Its scalar type must be one handled by
  //                  DISPATCH_HALF_AND_BFLOAT (presumably half / bfloat16; see
  //                  type_shim.h to confirm).
  //   scale_factor - forwarded unchanged to the dispatch routine.
  //
  // Returns a newly allocated tensor of shape [attn_batches, seq_len, seq_len]
  // created with the input's options (requires_grad disabled) that the dispatch
  // routine writes its results into.
  //
  // Asserts (TORCH_INTERNAL_ASSERT) that the last two dimensions are equal and
  // that seq_len <= 16384.
  torch::Tensor fwd_cuda(
      torch::Tensor const& input,
      float scale_factor)
  {
    // input is a 3d tensor with dimensions [attn_batches, seq_len, seq_len].
    // The output below is sized from size(1) only, so a non-square input would
    // make the dispatch routine read the wrong number of elements per row.
    TORCH_INTERNAL_ASSERT(input.size(1) == input.size(2));
    const int attn_batches = input.size(0);
    const int seq_len = input.size(1);
    TORCH_INTERNAL_ASSERT(seq_len <= 16384);

    // Only a raw pointer and the sizes are handed to the dispatch routine (no
    // strides), so make sure the data is densely packed first. This is a no-op
    // copy-wise when the input is already contiguous.
    auto input_contig = input.contiguous();

    // Output
    auto act_options = input.options().requires_grad(false);
    torch::Tensor softmax_results =
        torch::empty({attn_batches, seq_len, seq_len}, act_options);

    // Raw pointers; they are cast to the concrete scalar type inside the
    // dispatch macro where scalar_t is known.
    void* input_ptr = input_contig.data_ptr();
    void* softmax_results_ptr = softmax_results.data_ptr();

    DISPATCH_HALF_AND_BFLOAT(
        input.scalar_type(),
        "dispatch_scaled_upper_triang_masked_softmax_forward",
        dispatch_scaled_upper_triang_masked_softmax_forward<scalar_t, scalar_t, float>(
            reinterpret_cast<scalar_t*>(softmax_results_ptr),
            reinterpret_cast<const scalar_t*>(input_ptr),
            scale_factor,
            seq_len,
            seq_len,
            attn_batches);
        );
    return softmax_results;
  }
  // Backward pass of the scaled upper-triangular masked softmax.
  //
  // Parameters:
  //   output_grads_    - 3d tensor with dimensions
  //                      [attn_batches, seq_len, seq_len]. Its scalar type must
  //                      be one handled by DISPATCH_HALF_AND_BFLOAT.
  //   softmax_results_ - tensor with the same shape and scalar type as
  //                      output_grads_ (typically the result of fwd_cuda).
  //   scale_factor     - forwarded unchanged to the dispatch routine.
  //
  // Returns the contiguous gradient tensor after the dispatch routine has
  // written its result into it. The same buffer is passed as both the output
  // and the gradient input, so the computation is in-place.
  // NOTE(review): when output_grads_ is already contiguous, contiguous() is
  // expected to return the same tensor, in which case the caller's tensor is
  // overwritten - confirm callers rely on / tolerate this.
  //
  // Asserts (TORCH_INTERNAL_ASSERT) that the last two dimensions are equal and
  // that softmax_results_ matches output_grads_ in shape and scalar type.
  torch::Tensor bwd_cuda(
      torch::Tensor const& output_grads_,
      torch::Tensor const& softmax_results_,
      float scale_factor) {
    // Only raw pointers and sizes are handed to the dispatch routine, so both
    // operands must be densely packed.
    auto output_grads = output_grads_.contiguous();
    auto softmax_results = softmax_results_.contiguous();

    // output grads is a 3d tensor with dimensions [attn_batches, seq_len, seq_len]
    const int attn_batches = output_grads.size(0);
    const int seq_len = output_grads.size(1);
    TORCH_INTERNAL_ASSERT(output_grads.size(1) == output_grads.size(2));

    // softmax_results is reinterpreted with the gradient's scalar type and
    // indexed with the gradient's sizes below; reject mismatches instead of
    // reading out of bounds or misinterpreting the memory.
    TORCH_INTERNAL_ASSERT(softmax_results.sizes() == output_grads.sizes());
    TORCH_INTERNAL_ASSERT(softmax_results.scalar_type() == output_grads.scalar_type());

    void* output_grads_ptr = output_grads.data_ptr();

    // Softmax Grad
    DISPATCH_HALF_AND_BFLOAT(
        output_grads.scalar_type(),
        "dispatch_scaled_upper_triang_masked_softmax_backward",
        dispatch_scaled_upper_triang_masked_softmax_backward<scalar_t, scalar_t, float>(
            reinterpret_cast<scalar_t*>(output_grads_ptr),
            reinterpret_cast<scalar_t*>(output_grads_ptr),
            reinterpret_cast<scalar_t const*>(softmax_results.data_ptr()),
            scale_factor,
            seq_len,
            seq_len,
            attn_batches);
        );

    // backward pass is completely in-place
    return output_grads;
  }
  83. }
  84. }
  85. }