#include <torch/extension.h>

#include <cassert>
#include <sstream>
#include <stdexcept>
#include <vector>

#include "compat.h"
- namespace {
- void compute_n1_n2(
- at::Tensor input,
- #ifdef VERSION_GE_1_1
- at::IntArrayRef normalized_shape,
- #else
- at::IntList normalized_shape,
- #endif
- int& n1,
- int& n2)
- {
- int idiff = input.ndimension() - normalized_shape.size();
- n2 = 1;
- for (int i = 0; i < (int)normalized_shape.size(); ++i) {
- assert( input.sizes()[i+idiff] == normalized_shape[i] );
- n2 *= normalized_shape[i];
- }
- n1 = 1;
- for (int i = 0; i < idiff; ++i) {
- n1 *= input.sizes()[i];
- }
- }
- void check_args(
- #ifdef VERSION_GE_1_1
- at::IntArrayRef normalized_shape,
- #else
- at::IntList normalized_shape,
- #endif
- at::Tensor gamma,
- at::Tensor beta
- )
- {
- TORCH_CHECK(!gamma.defined() || gamma.sizes().equals(normalized_shape));
- TORCH_CHECK(!beta.defined() || beta.sizes().equals(normalized_shape));
- }
- void check_args(
- #ifdef VERSION_GE_1_1
- at::IntArrayRef normalized_shape,
- #else
- at::IntList normalized_shape,
- #endif
- at::Tensor gamma
- )
- {
- TORCH_CHECK(!gamma.defined() || gamma.sizes().equals(normalized_shape));
- }
- void check_args(
- at::Tensor input,
- #ifdef VERSION_GE_1_1
- at::IntArrayRef normalized_shape,
- #else
- at::IntList normalized_shape,
- #endif
- int& n1,
- int& n2
- )
- {
- int64_t normalized_ndim = normalized_shape.size();
- if (normalized_ndim < 1) {
- std::stringstream ss;
- ss << "Expected normalized_shape to be at least 1-dimensional, i.e., "
- << "containing at least one element, but got normalized_shape="
- << normalized_shape;
- throw std::runtime_error(ss.str());
- }
- auto input_shape = input.sizes();
- auto input_ndim = input.dim();
- if (input_ndim < normalized_ndim ||
- !input_shape.slice(input_ndim - normalized_ndim).equals(normalized_shape)) {
- std::stringstream ss;
- ss << "Given normalized_shape=" << normalized_shape
- << ", expected input with shape [*";
- for (auto size : normalized_shape) {
- ss << ", " << size;
- }
- ss << "], but got input of size" << input_shape;
- throw std::runtime_error(ss.str());
- }
- compute_n1_n2(input,normalized_shape,n1,n2);
- }
- void check_args(
- at::Tensor input,
- #ifdef VERSION_GE_1_1
- at::IntArrayRef normalized_shape,
- #else
- at::IntList normalized_shape,
- #endif
- at::Tensor gamma,
- at::Tensor beta,
- int& n1,
- int& n2
- )
- {
- check_args(input,normalized_shape,n1,n2);
- check_args(normalized_shape,gamma,beta);
- }
- void check_args(
- at::Tensor input,
- #ifdef VERSION_GE_1_1
- at::IntArrayRef normalized_shape,
- #else
- at::IntList normalized_shape,
- #endif
- at::Tensor gamma,
- int& n1,
- int& n2
- )
- {
- check_args(input,normalized_shape,n1,n2);
- check_args(normalized_shape,gamma);
- }
- }
// LayerNorm forward kernel launcher, implemented in the companion .cu file.
// n1 = number of rows, n2 = row length (see compute_n1_n2 above).
// gamma/beta may be NULL for the non-affine path (see layer_norm() below).
void cuda_layer_norm(
    at::Tensor* output,
    at::Tensor* mean,
    at::Tensor* invvar,
    at::Tensor* input,
    int n1,
    int n2,
#ifdef VERSION_GE_1_1
    at::IntArrayRef normalized_shape,
#else
    at::IntList normalized_shape,
#endif
    at::Tensor* gamma,
    at::Tensor* beta,
    double epsilon);
// Argument validation for tensors handed to the CUDA kernels.
// NOTE(review): CHECK_INPUT expands to two statements without a do{}while(0)
// wrapper, so it is unsafe inside an unbraced if/else; all uses in this file
// are standalone statements, where it is fine.
#define CHECK_CUDA(x) TORCH_CHECK(x.is_cuda(), #x " must be a CUDA tensor")
#define CHECK_CONTIGUOUS(x) TORCH_CHECK(x.is_contiguous(), #x " must be contiguous")
#define CHECK_INPUT(x) CHECK_CUDA(x); CHECK_CONTIGUOUS(x)
- std::vector<at::Tensor> layer_norm(
- at::Tensor input,
- #ifdef VERSION_GE_1_1
- at::IntArrayRef normalized_shape,
- #else
- at::IntList normalized_shape,
- #endif
- double epsilon) {
- CHECK_INPUT(input);
- int n1,n2;
- check_args(input,normalized_shape,n1,n2);
- at::Tensor output = at::empty_like(input);
- at::Tensor mean = at::empty({n1}, input.options().dtype(input.scalar_type()==at::ScalarType::Half || input.scalar_type()==at::ScalarType::BFloat16 ? at::ScalarType::Float : input.scalar_type()));
- at::Tensor invvar = at::empty_like(mean);
- cuda_layer_norm(&output,&mean,&invvar,&input,n1,n2,
- normalized_shape,NULL,NULL,epsilon);
- return {output, mean, invvar};
- }
- std::vector<at::Tensor> layer_norm_affine(
- at::Tensor input,
- #ifdef VERSION_GE_1_1
- at::IntArrayRef normalized_shape,
- #else
- at::IntList normalized_shape,
- #endif
- at::Tensor gamma,
- at::Tensor beta,
- double epsilon) {
- CHECK_INPUT(input);
- CHECK_INPUT(gamma);
- CHECK_INPUT(beta);
- int n1,n2;
- check_args(input,normalized_shape,gamma,beta,n1,n2);
- at::Tensor output = at::empty_like(input);
- const auto stats_dtype = (input.scalar_type() == at::ScalarType::Half || input.scalar_type() == at::ScalarType::BFloat16) ? at::ScalarType::Float : input.scalar_type();
- at::Tensor mean = at::empty({n1}, input.options().dtype(stats_dtype));
- at::Tensor invvar = at::empty_like(mean);
- cuda_layer_norm(&output,&mean,&invvar,&input,n1,n2,
- normalized_shape,&gamma,&beta,epsilon);
- return {output, mean, invvar};
- }
- std::vector<at::Tensor> layer_norm_affine_mixed_dtypes(
- at::Tensor input,
- #ifdef VERSION_GE_1_1
- at::IntArrayRef normalized_shape,
- #else
- at::IntList normalized_shape,
- #endif
- at::Tensor gamma,
- at::Tensor beta,
- double epsilon) {
- CHECK_INPUT(input);
- int n1, n2;
- check_args(input, normalized_shape, n1, n2);
- at::Tensor output = at::empty_like(input, gamma.options().dtype(gamma.scalar_type()));
- at::Tensor mean = at::empty({n1}, input.options().dtype(input.scalar_type() == at::ScalarType::Half || input.scalar_type() == at::ScalarType::BFloat16 ? at::ScalarType::Float : input.scalar_type()));
- at::Tensor invvar = at::empty_like(mean);
- cuda_layer_norm(&output, &mean, &invvar, &input, n1, n2,
- normalized_shape, &gamma, &beta, epsilon);
- return {output, mean, invvar};
- }
// LayerNorm backward kernel launcher, implemented in the companion .cu file.
// `mean` may be NULL in memory-efficient mode; gamma/beta/grad_gamma/grad_beta
// may be NULL for the non-affine path (see callers below).
void cuda_layer_norm_gradient(
    at::Tensor* dout,
    at::Tensor* mean,
    at::Tensor* invvar,
    at::Tensor* input_or_output,
    int n1,
    int n2,
#ifdef VERSION_GE_1_1
    at::IntArrayRef normalized_shape,
#else
    at::IntList normalized_shape,
#endif
    at::Tensor* gamma,
    at::Tensor* beta,
    double epsilon,
    at::Tensor* grad_input,
    at::Tensor* grad_gamma,
    at::Tensor* grad_beta,
    bool memory_efficient
    );
- at::Tensor layer_norm_gradient(
- at::Tensor dout,
- c10::optional<at::Tensor> mean_,
- at::Tensor invvar,
- at::Tensor input_or_output,
- #ifdef VERSION_GE_1_1
- at::IntArrayRef normalized_shape,
- #else
- at::IntList normalized_shape,
- #endif
- double epsilon,
- bool memory_efficient) {
- CHECK_INPUT(dout);
- CHECK_INPUT(invvar);
- CHECK_INPUT(input_or_output);
- int n1,n2;
- check_args(input_or_output,normalized_shape,n1,n2);
- at::Tensor grad_input = at::empty_like(input_or_output);
- if (mean_.has_value()) {
- cuda_layer_norm_gradient(&dout,&mean_.value(),&invvar,&input_or_output,n1,n2,
- normalized_shape,NULL,NULL,epsilon,
- &grad_input,NULL,NULL,memory_efficient);
- } else {
- cuda_layer_norm_gradient(&dout,NULL,&invvar,&input_or_output,n1,n2,
- normalized_shape,NULL,NULL,epsilon,
- &grad_input,NULL,NULL,memory_efficient);
- }
- return grad_input;
- }
- std::vector<at::Tensor> layer_norm_gradient_affine(
- at::Tensor dout,
- c10::optional<at::Tensor> mean_,
- at::Tensor invvar,
- at::Tensor input_or_output,
- #ifdef VERSION_GE_1_1
- at::IntArrayRef normalized_shape,
- #else
- at::IntList normalized_shape,
- #endif
- at::Tensor gamma,
- at::Tensor beta,
- double epsilon,
- bool memory_efficient) {
- CHECK_INPUT(dout);
- CHECK_INPUT(invvar);
- CHECK_INPUT(input_or_output);
- CHECK_INPUT(gamma);
- CHECK_INPUT(beta);
- int n1,n2;
- check_args(input_or_output,normalized_shape,gamma,beta,n1,n2);
- at::Tensor grad_input = at::empty_like(input_or_output);
- at::Tensor grad_gamma = at::empty_like(gamma);
- at::Tensor grad_beta = at::empty_like(beta);
- // at::Tensor *mean = mean_.has_value() ? &mean_.value() : NULL;
- if (mean_.has_value()) {
- cuda_layer_norm_gradient(&dout,&mean_.value(),&invvar,&input_or_output,n1,n2,
- normalized_shape,&gamma,&beta,epsilon,
- &grad_input,&grad_gamma,&grad_beta,memory_efficient);
- } else {
- cuda_layer_norm_gradient(&dout,NULL,&invvar,&input_or_output,n1,n2,
- normalized_shape,&gamma,&beta,epsilon,
- &grad_input,&grad_gamma,&grad_beta,memory_efficient);
- }
- return {grad_input, grad_gamma, grad_beta};
- }
// RMSNorm forward kernel launcher, implemented in the companion .cu file.
// RMSNorm has no mean/beta; gamma may be NULL for the non-affine path
// (see rms_norm() below).
void cuda_rms_norm(
    at::Tensor* output,
    at::Tensor* invvar,
    at::Tensor* input,
    int n1,
    int n2,
#ifdef VERSION_GE_1_1
    at::IntArrayRef normalized_shape,
#else
    at::IntList normalized_shape,
#endif
    at::Tensor* gamma,
    double epsilon);
// NOTE(review): identical redefinitions of the CHECK_* macros defined above
// (likely a copy-paste from the LayerNorm section). Token-identical macro
// redefinition is well-formed, so this is redundant but harmless.
#define CHECK_CUDA(x) TORCH_CHECK(x.is_cuda(), #x " must be a CUDA tensor")
#define CHECK_CONTIGUOUS(x) TORCH_CHECK(x.is_contiguous(), #x " must be contiguous")
#define CHECK_INPUT(x) CHECK_CUDA(x); CHECK_CONTIGUOUS(x)
- std::vector<at::Tensor> rms_norm(
- at::Tensor input,
- #ifdef VERSION_GE_1_1
- at::IntArrayRef normalized_shape,
- #else
- at::IntList normalized_shape,
- #endif
- double epsilon) {
- CHECK_INPUT(input);
- int n1,n2;
- check_args(input,normalized_shape,n1,n2);
- at::Tensor output = at::empty_like(input);
- at::Tensor invvar = at::empty({n1}, input.options().dtype(input.scalar_type()==at::ScalarType::Half || input.scalar_type()==at::ScalarType::BFloat16 ? at::ScalarType::Float : input.scalar_type()));
- cuda_rms_norm(&output,&invvar,&input,n1,n2,
- normalized_shape,NULL,epsilon);
- return {output, invvar};
- }
- std::vector<at::Tensor> rms_norm_affine(
- at::Tensor input,
- #ifdef VERSION_GE_1_1
- at::IntArrayRef normalized_shape,
- #else
- at::IntList normalized_shape,
- #endif
- at::Tensor gamma,
- double epsilon) {
- CHECK_INPUT(input);
- CHECK_INPUT(gamma);
- int n1,n2;
- check_args(input,normalized_shape,gamma,n1,n2);
- at::Tensor output = at::empty_like(input);
- const auto stats_dtype = (input.scalar_type() == at::ScalarType::Half || input.scalar_type() == at::ScalarType::BFloat16) ? at::ScalarType::Float : input.scalar_type();
- at::Tensor invvar = at::empty({n1}, input.options().dtype(stats_dtype));
- cuda_rms_norm(&output,&invvar,&input,n1,n2,
- normalized_shape,&gamma,epsilon);
- return {output, invvar};
- }
- std::vector<at::Tensor> rms_norm_affine_mixed_dtypes(
- at::Tensor input,
- #ifdef VERSION_GE_1_1
- at::IntArrayRef normalized_shape,
- #else
- at::IntList normalized_shape,
- #endif
- at::Tensor gamma,
- double epsilon) {
- CHECK_INPUT(input);
- int n1, n2;
- check_args(input, normalized_shape, n1, n2);
- at::Tensor output = at::empty_like(input, gamma.options().dtype(gamma.scalar_type()));
- at::Tensor invvar = at::empty({n1}, input.options().dtype(input.scalar_type() == at::ScalarType::Half || input.scalar_type() == at::ScalarType::BFloat16 ? at::ScalarType::Float : input.scalar_type()));
- cuda_rms_norm(&output,&invvar, &input, n1, n2,
- normalized_shape, &gamma,epsilon);
- return {output,invvar};
- }
// RMSNorm backward kernel launcher, implemented in the companion .cu file.
// gamma/grad_gamma may be NULL for the non-affine path (see callers below).
void cuda_rms_norm_gradient(
    at::Tensor* dout,
    at::Tensor* invvar,
    at::Tensor* input_or_output,
    int n1,
    int n2,
#ifdef VERSION_GE_1_1
    at::IntArrayRef normalized_shape,
#else
    at::IntList normalized_shape,
#endif
    at::Tensor* gamma,
    double epsilon,
    at::Tensor* grad_input,
    at::Tensor* grad_gamma,
    bool memory_efficient);
- at::Tensor rms_norm_gradient(
- at::Tensor dout,
- at::Tensor invvar,
- at::Tensor input_or_output,
- #ifdef VERSION_GE_1_1
- at::IntArrayRef normalized_shape,
- #else
- at::IntList normalized_shape,
- #endif
- double epsilon,
- bool memory_efficient) {
- CHECK_INPUT(dout);
- CHECK_INPUT(invvar);
- CHECK_INPUT(input_or_output);
- int n1,n2;
- check_args(input_or_output,normalized_shape,n1,n2);
- at::Tensor grad_input = at::empty_like(input_or_output);
- cuda_rms_norm_gradient(&dout,&invvar,&input_or_output,n1,n2,
- normalized_shape,NULL,epsilon,
- &grad_input,NULL,memory_efficient);
- return grad_input;
- }
- std::vector<at::Tensor> rms_norm_gradient_affine(
- at::Tensor dout,
- at::Tensor invvar,
- at::Tensor input_or_output,
- #ifdef VERSION_GE_1_1
- at::IntArrayRef normalized_shape,
- #else
- at::IntList normalized_shape,
- #endif
- at::Tensor gamma,
- double epsilon,
- bool memory_efficient) {
- CHECK_INPUT(dout);
- CHECK_INPUT(invvar);
- CHECK_INPUT(input_or_output);
- CHECK_INPUT(gamma);
- int n1,n2;
- check_args(input_or_output,normalized_shape,gamma,n1,n2);
- at::Tensor grad_input = at::empty_like(input_or_output);
- at::Tensor grad_gamma = at::empty_like(gamma);
- cuda_rms_norm_gradient(&dout,&invvar,&input_or_output,n1,n2,
- normalized_shape,&gamma,epsilon,
- &grad_input,&grad_gamma,memory_efficient);
- return {grad_input, grad_gamma};
- }
// Python bindings. "forward*"/"backward*" are the LayerNorm entry points and
// "rms_*" the RMSNorm counterparts; the *_affine variants take gamma/beta
// (gamma only for RMSNorm), and *_mixed_dtypes allows gamma's dtype to
// differ from the input's.
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
  m.def("forward_affine", &layer_norm_affine, "LayerNorm forward (CUDA)");
  m.def("forward", &layer_norm, "LayerNorm forward (CUDA)");
  m.def("backward_affine", &layer_norm_gradient_affine, "LayerNorm backward (CUDA)");
  m.def("backward", &layer_norm_gradient, "LayerNorm backward (CUDA)");
  m.def("forward_affine_mixed_dtypes", &layer_norm_affine_mixed_dtypes, "LayerNorm forward with mixed dtypes (CUDA) compatible with Megatron's implementation");
  m.def("rms_forward_affine", &rms_norm_affine, "RMSNorm forward (CUDA)");
  m.def("rms_forward", &rms_norm, "RMSNorm forward (CUDA)");
  m.def("rms_backward_affine", &rms_norm_gradient_affine, "RMSNorm backward (CUDA)");
  m.def("rms_backward", &rms_norm_gradient, "RMSNorm backward (CUDA)");
  m.def("rms_forward_affine_mixed_dtypes", &rms_norm_affine_mixed_dtypes, "RMSNorm forward with mixed dtypes (CUDA) compatible with Megatron's implementation");
}