1234567891011121314151617181920212223242526272829303132333435363738394041424344454647484950515253545556575859606162636465666768697071727374757677787980818283848586878889909192939495969798 |
- /* coding=utf-8
- * Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved.
- *
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
- #include <ATen/ATen.h>
- #include <cuda.h>
- #include <cuda_runtime.h>
- #include <cuda_fp16.h>
- #include <cuda_profiler_api.h>
- #include <ATen/cuda/CUDAContext.h>
- #include <torch/extension.h>
- #include "scaled_upper_triang_masked_softmax.h"
- #include "type_shim.h"
- namespace multihead_attn {
- namespace fused_softmax {
- namespace scaled_upper_triang_masked_softmax {
/*
 * Forward pass wrapper for the scaled upper-triangular masked softmax.
 *
 * input:        3d tensor with dimensions [attn_batches, seq_len, seq_len].
 *               Its scalar type selects the instantiation through
 *               DISPATCH_HALF_AND_BFLOAT.
 * scale_factor: forwarded unchanged to the forward dispatch routine.
 *
 * Returns a freshly allocated tensor of shape
 * [attn_batches, seq_len, seq_len] created with the input's options
 * (requires_grad disabled) and filled by the dispatch routine.
 *
 * Fails a TORCH_INTERNAL_ASSERT when the input is not 3d, is not square in
 * its last two dimensions, or has seq_len greater than 16384.
 */
torch::Tensor fwd_cuda(
    torch::Tensor const& input,
    float scale_factor)
{
  // input is a 3d tensor with dimensions [attn_batches, seq_len, seq_len]
  TORCH_INTERNAL_ASSERT(input.dim() == 3);
  // The output below is allocated as seq_len x seq_len per batch and the
  // dispatch call receives seq_len for both trailing extents, so the input
  // itself has to be square (bwd_cuda enforces the same on its gradients).
  TORCH_INTERNAL_ASSERT(input.size(1) == input.size(2));

  const int attn_batches = input.size(0);
  const int seq_len = input.size(1);
  TORCH_INTERNAL_ASSERT(seq_len <= 16384);

  // Only a raw pointer (no strides) is handed to the dispatch routine, so use
  // a dense copy when the caller passes a non-contiguous view. For an already
  // contiguous input this returns the same tensor without copying.
  auto input_contig = input.contiguous();

  // Output
  auto act_options = input_contig.options().requires_grad(false);
  torch::Tensor softmax_results =
      torch::empty({attn_batches, seq_len, seq_len}, act_options);

  // Raw buffers handed to the dispatch routine
  void* input_ptr = static_cast<void*>(input_contig.data_ptr());
  void* softmax_results_ptr = static_cast<void*>(softmax_results.data_ptr());

  DISPATCH_HALF_AND_BFLOAT(
      input_contig.scalar_type(),
      "dispatch_scaled_upper_triang_masked_softmax_forward",
      dispatch_scaled_upper_triang_masked_softmax_forward<scalar_t, scalar_t, float>(
          reinterpret_cast<scalar_t*>(softmax_results_ptr),
          reinterpret_cast<const scalar_t*>(input_ptr),
          scale_factor,
          seq_len,
          seq_len,
          attn_batches);
      );
  return softmax_results;
}
-
/*
 * Backward pass wrapper for the scaled upper-triangular masked softmax.
 *
 * output_grads_:    3d tensor with dimensions [attn_batches, seq_len, seq_len].
 * softmax_results_: tensor passed as the softmax-results buffer; must have
 *                   the same scalar type and number of elements as
 *                   output_grads_.
 * scale_factor:     forwarded unchanged to the backward dispatch routine.
 *
 * Returns the contiguous form of output_grads_. The same buffer is passed to
 * the dispatch routine as both of its first two pointer arguments, so the
 * result is written in place; when output_grads_ is already contiguous the
 * returned tensor shares storage with the caller's tensor.
 *
 * Fails a TORCH_INTERNAL_ASSERT when the gradients are not 3d / not square in
 * their last two dimensions, or when softmax_results_ does not match them in
 * scalar type or element count.
 */
torch::Tensor bwd_cuda(
    torch::Tensor const& output_grads_,
    torch::Tensor const& softmax_results_,
    float scale_factor) {

  // Raw pointers (no strides) are handed to the dispatch routine below.
  auto output_grads = output_grads_.contiguous();
  auto softmax_results = softmax_results_.contiguous();

  //output grads is a 3d tensor with dimensions [attn_batches, seq_len, seq_len]
  TORCH_INTERNAL_ASSERT(output_grads.dim() == 3);
  TORCH_INTERNAL_ASSERT(output_grads.size(1) == output_grads.size(2));

  // softmax_results is reinterpreted below with the scalar type selected from
  // the gradients and is accompanied only by the gradients' extents, so a
  // tensor of another dtype or size would be misread through the raw pointer.
  TORCH_INTERNAL_ASSERT(softmax_results.scalar_type() == output_grads.scalar_type());
  TORCH_INTERNAL_ASSERT(softmax_results.numel() == output_grads.numel());

  const int attn_batches = output_grads.size(0);
  const int seq_len = output_grads.size(1);

  void* output_grads_ptr = static_cast<void*>(output_grads.data_ptr());

  //Softmax Grad
  DISPATCH_HALF_AND_BFLOAT(
      output_grads_.scalar_type(),
      "dispatch_scaled_upper_triang_masked_softmax_backward",
      dispatch_scaled_upper_triang_masked_softmax_backward<scalar_t, scalar_t, float>(
          reinterpret_cast<scalar_t*>(output_grads_ptr),
          reinterpret_cast<scalar_t*>(output_grads_ptr),
          reinterpret_cast<scalar_t const*>(softmax_results.data_ptr()),
          scale_factor,
          seq_len,
          seq_len,
          attn_batches);
      );

  //backward pass is completely in-place
  return output_grads;
}
- }
- }
- }
|