123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513 |
- #include <ATen/ATen.h>
- #include <ATen/AccumulateType.h>
- #include <ATen/cuda/CUDAContext.h>
- #include <ATen/cuda/Exceptions.h>
- // Another possibility:
- // #include <torch/all.h>
- #include <assert.h>
- #include "type_shim.h"
- #include "multi_tensor_apply.cuh"
// Block size passed to every multi_tensor_apply launch in this file.
#define BLOCK_SIZE 512
// Number of elements each thread handles per iteration of the chunk loop.
#define ILP 4

// Selects how weight decay enters the Adam update.
typedef enum {
  ADAM_MODE_0 = 0,  // L2 regularization: decay * param is added to the gradient
  ADAM_MODE_1 = 1   // Decoupled weight decay (AdamW): decay * param is added to the update
} adamMode_t;

// Precision used for all in-register optimizer arithmetic.
typedef float MATH_T;
// Per-chunk fused Adam/AdamW update, launched through multi_tensor_apply by
// multi_tensor_adam_cuda.
//
// Tensor lists in tl.addresses (one entry per tensor):
//   [0] gradients g      (T, only read here)
//   [1] parameters p     (T, updated in place)
//   [2] first moment m   (FULL_T, updated in place)
//   [3] second moment v  (FULL_T, updated in place)
// Each CUDA block processes one chunk of one tensor, as given by
// tl.block_to_tensor / tl.block_to_chunk. All arithmetic is done in MATH_T.
//
// beta1_correction / beta2_correction: bias-correction denominators computed
//   on the host (1.0f when bias correction is disabled).
// mode: ADAM_MODE_0 adds decay * p to the gradient (L2); any other value adds
//   decay * p to the update (decoupled weight decay).
// index_t: integer type for chunk offsets and element indices; the host
//   wrapper instantiates int64_t when some tensor has >= INT_MAX elements.
template<typename T, typename FULL_T, typename index_t>
struct AdamFunctor
{
  __device__ __forceinline__ void operator()(
    index_t chunk_size,
    volatile int* noop_gmem,
    TensorListMetadata<4>& tl,
    const float beta1,
    const float beta2,
    const float beta1_correction,
    const float beta2_correction,
    const float epsilon,
    const float lr,
    adamMode_t mode,
    const float decay)
  {
    // The noop flag is deliberately not checked: this kernel should
    // propagate infs/nans rather than skip the step.
    // if(*noop_gmem == 1)
    //   return;

    index_t tensor_loc = tl.block_to_tensor[blockIdx.x];
    // potentially use to pass in list of scalar
    // int tensor_num = tl.start_tensor_this_launch + tensor_loc;
    index_t chunk_idx = tl.block_to_chunk[blockIdx.x];
    index_t n = tl.sizes[tensor_loc];

    // Advance every pointer to the first element of this block's chunk and
    // turn n into the number of elements remaining from that point.
    T* g = (T*)tl.addresses[0][tensor_loc];
    g += chunk_idx*chunk_size;
    T* p = (T*)tl.addresses[1][tensor_loc];
    p += chunk_idx*chunk_size;
    FULL_T* m = (FULL_T*)tl.addresses[2][tensor_loc];
    m += chunk_idx*chunk_size;
    FULL_T* v = (FULL_T*)tl.addresses[3][tensor_loc];
    v += chunk_idx*chunk_size;
    n -= chunk_idx*chunk_size;

    // see note in multi_tensor_scale_kernel.cu
    for(index_t i_start = 0;
        i_start < n && i_start < chunk_size;
        i_start += blockDim.x*ILP)
    {
      MATH_T r_g[ILP];
      MATH_T r_p[ILP];
      MATH_T r_m[ILP];
      MATH_T r_v[ILP];

      // Load phase: out-of-range lanes get zeros so the compute phase can run
      // unconditionally; their results are discarded by the store guard.
#pragma unroll
      for(int ii = 0; ii < ILP; ii++)
      {
        // Keep the element index in index_t (not int) so index arithmetic
        // stays in the width the host selected for this instantiation.
        index_t i = i_start + threadIdx.x + ii*blockDim.x;
        if(i < n && i < chunk_size)
        {
          r_g[ii] = g[i];
          r_p[ii] = p[i];
          r_m[ii] = m[i];
          r_v[ii] = v[i];
        } else {
          r_g[ii] = MATH_T(0);
          r_p[ii] = MATH_T(0);
          r_m[ii] = MATH_T(0);
          r_v[ii] = MATH_T(0);
        }
      }

      // Compute phase: register-only Adam/AdamW step.
#pragma unroll
      for(int ii = 0; ii < ILP; ii++)
      {
        if(mode == ADAM_MODE_0) { // L2
          r_g[ii] = r_g[ii] + (decay * r_p[ii]);
          r_m[ii] = beta1 * r_m[ii] + (1-beta1) * r_g[ii];
          r_v[ii] = beta2 * r_v[ii] + (1-beta2) * r_g[ii] * r_g[ii];
          MATH_T next_m_unbiased = r_m[ii] / beta1_correction;
          MATH_T next_v_unbiased = r_v[ii] / beta2_correction;
          MATH_T denom = sqrtf(next_v_unbiased) + epsilon;
          MATH_T update = next_m_unbiased / denom;
          r_p[ii] = r_p[ii] - (lr * update);
        }
        else { // weight decay
          r_m[ii] = beta1 * r_m[ii] + (1-beta1) * r_g[ii];
          r_v[ii] = beta2 * r_v[ii] + (1-beta2) * r_g[ii] * r_g[ii];
          MATH_T next_m_unbiased = r_m[ii] / beta1_correction;
          MATH_T next_v_unbiased = r_v[ii] / beta2_correction;
          MATH_T denom = sqrtf(next_v_unbiased) + epsilon;
          MATH_T update = (next_m_unbiased / denom) + (decay * r_p[ii]);
          r_p[ii] = r_p[ii] - (lr * update);
        }
      }

      // Store phase: write back p, m and v for in-range elements only.
#pragma unroll
      for(int ii = 0; ii < ILP; ii++)
      {
        index_t i = i_start + threadIdx.x + ii*blockDim.x;
        if(i < n && i < chunk_size)
        {
          p[i] = r_p[ii];
          m[i] = r_m[ii];
          v[i] = r_v[ii];
        }
      }
    }
  }
};
// Per-chunk fused Adam/AdamW update for the capturable path
// (multi_tensor_adam_capturable_cuda). lr, step and inv_scale are read
// through device pointers when the kernel runs rather than passed by value.
//
// Tensor lists in tl.addresses:
//   [0] gradients g      (T; multiplied by *inv_scale and written back)
//   [1] parameters p     (T, updated in place)
//   [2] first moment m   (FULL_T, updated in place)
//   [3] second moment v  (FULL_T, updated in place)
// If *noop_gmem == 1 the block returns before touching any tensor.
// Bias correction is applied only when bias_correction == 1.
// mode: ADAM_MODE_0 adds decay * p to the gradient (L2); any other value adds
//   decay * p to the update (decoupled weight decay).
template<typename T, typename FULL_T>
struct AdamCapturableFunctor
{
  __device__ __forceinline__ void operator()(
    int chunk_size,
    volatile int* noop_gmem,
    TensorListMetadata<4>& tl,
    const float beta1,
    const float beta2,
    const int* step,
    const int bias_correction,
    const float epsilon,
    const float* lr,
    adamMode_t mode,
    const float decay,
    const float* inv_scale)
  {
    if(*noop_gmem == 1)
      return;

    float beta1_correction = 1.0f, beta2_correction = 1.0f;
    if (bias_correction == 1) {
      beta1_correction = 1 - pow(beta1, *step);
      beta2_correction = 1 - pow(beta2, *step);
    }

    int tensor_loc = tl.block_to_tensor[blockIdx.x];
    // potentially use to pass in list of scalar
    // int tensor_num = tl.start_tensor_this_launch + tensor_loc;
    int chunk_idx = tl.block_to_chunk[blockIdx.x];
    int n = tl.sizes[tensor_loc];

    // Advance every pointer to this block's chunk; n becomes the number of
    // elements remaining from that point.
    T* g = (T*)tl.addresses[0][tensor_loc];
    g += chunk_idx*chunk_size;
    T* p = (T*)tl.addresses[1][tensor_loc];
    p += chunk_idx*chunk_size;
    FULL_T* m = (FULL_T*)tl.addresses[2][tensor_loc];
    m += chunk_idx*chunk_size;
    FULL_T* v = (FULL_T*)tl.addresses[3][tensor_loc];
    v += chunk_idx*chunk_size;
    n -= chunk_idx*chunk_size;

    // see note in multi_tensor_scale_kernel.cu
    for(int i_start = 0;
        i_start < n && i_start < chunk_size;
        i_start += blockDim.x*ILP)
    {
      MATH_T r_g[ILP];
      MATH_T r_p[ILP];
      MATH_T r_m[ILP];
      MATH_T r_v[ILP];

      // Load phase: unscale gradients (and write the unscaled values back);
      // out-of-range lanes get zeros and are never stored.
#pragma unroll
      for(int ii = 0; ii < ILP; ii++)
      {
        int i = i_start + threadIdx.x + ii*blockDim.x;
        if(i < n && i < chunk_size)
        {
          r_g[ii] = static_cast<MATH_T>(g[i]) * (*inv_scale);
          g[i] = static_cast<T>(r_g[ii]);
          r_p[ii] = static_cast<MATH_T>(p[i]);
          r_m[ii] = static_cast<MATH_T>(m[i]);
          r_v[ii] = static_cast<MATH_T>(v[i]);
        } else {
          r_g[ii] = MATH_T(0);
          r_p[ii] = MATH_T(0);
          r_m[ii] = MATH_T(0);
          r_v[ii] = MATH_T(0);
        }
      }

      // Compute phase: register-only Adam/AdamW step.
#pragma unroll
      for(int ii = 0; ii < ILP; ii++)
      {
        if(mode == ADAM_MODE_0) { // L2
          r_g[ii] = r_g[ii] + (decay * r_p[ii]);
          r_m[ii] = beta1 * r_m[ii] + (1-beta1) * r_g[ii];
          r_v[ii] = beta2 * r_v[ii] + (1-beta2) * r_g[ii] * r_g[ii];
          MATH_T next_m_unbiased = r_m[ii] / beta1_correction;
          MATH_T next_v_unbiased = r_v[ii] / beta2_correction;
          MATH_T denom = sqrtf(next_v_unbiased) + epsilon;
          MATH_T update = next_m_unbiased / denom;
          r_p[ii] = r_p[ii] - (*lr * update);
        }
        else { // weight decay
          r_m[ii] = beta1 * r_m[ii] + (1-beta1) * r_g[ii];
          r_v[ii] = beta2 * r_v[ii] + (1-beta2) * r_g[ii] * r_g[ii];
          MATH_T next_m_unbiased = r_m[ii] / beta1_correction;
          MATH_T next_v_unbiased = r_v[ii] / beta2_correction;
          MATH_T denom = sqrtf(next_v_unbiased) + epsilon;
          MATH_T update = (next_m_unbiased / denom) + (decay * r_p[ii]);
          r_p[ii] = r_p[ii] - (*lr * update);
        }
      }

      // Store phase: write back p, m and v for in-range elements only.
#pragma unroll
      for(int ii = 0; ii < ILP; ii++)
      {
        int i = i_start + threadIdx.x + ii*blockDim.x;
        if(i < n && i < chunk_size)
        {
          p[i] = static_cast<T>(r_p[ii]);
          // m and v live in FULL_T storage: store them directly as FULL_T.
          // Converting through T first would round the optimizer state to
          // the parameter precision (e.g. fp16/bf16) on every step and can
          // flush tiny second-moment values to zero.
          m[i] = static_cast<FULL_T>(r_m[ii]);
          v[i] = static_cast<FULL_T>(r_v[ii]);
        }
      }
    }
  }
};
// Per-chunk fused Adam/AdamW update for the capturable path that keeps a
// separate FULL_T master copy of the parameters
// (multi_tensor_adam_capturable_master_cuda).
//
// Tensor lists in tl.addresses:
//   [0] gradients          (T; multiplied by *inv_scale and written back)
//   [1] model parameters   (T; overwritten with the updated master value)
//   [2] first moment m     (FULL_T, updated in place)
//   [3] second moment v    (FULL_T, updated in place)
//   [4] master parameters  (FULL_T; the update reads and writes this copy)
// The block exits immediately when *noop_gmem == 1. Bias correction is
// applied only when bias_correction == 1. ADAM_MODE_0 folds decay * param
// into the gradient (L2); any other mode adds it to the update (AdamW).
template<typename T, typename FULL_T>
struct AdamCapturableMasterFunctor
{
  __device__ __forceinline__ void operator()(
    int chunk_size,
    volatile int* noop_gmem,
    TensorListMetadata<5>& tl,
    const float beta1,
    const float beta2,
    const int* step,
    const int bias_correction,
    const float epsilon,
    const float* lr,
    adamMode_t mode,
    const float decay,
    const float* inv_scale)
  {
    // Skip the step entirely when the noop flag is set.
    if (*noop_gmem == 1) {
      return;
    }

    float beta1_correction = 1.0f;
    float beta2_correction = 1.0f;
    if (bias_correction == 1) {
      beta1_correction = 1 - pow(beta1, *step);
      beta2_correction = 1 - pow(beta2, *step);
    }

    // Locate this block's chunk inside its tensor.
    const int tensor_loc = tl.block_to_tensor[blockIdx.x];
    const int chunk_idx = tl.block_to_chunk[blockIdx.x];
    const int offset = chunk_idx * chunk_size;
    int remaining = tl.sizes[tensor_loc];
    remaining -= offset;

    T* grads           = (T*)tl.addresses[0][tensor_loc] + offset;
    T* params          = (T*)tl.addresses[1][tensor_loc] + offset;
    FULL_T* exp_avg    = (FULL_T*)tl.addresses[2][tensor_loc] + offset;
    FULL_T* exp_avg_sq = (FULL_T*)tl.addresses[3][tensor_loc] + offset;
    FULL_T* master     = (FULL_T*)tl.addresses[4][tensor_loc] + offset;

    const bool l2_mode = (mode == ADAM_MODE_0);

    // see note in multi_tensor_scale_kernel.cu
    for (int base = 0;
         base < remaining && base < chunk_size;
         base += blockDim.x * ILP)
    {
      MATH_T grad[ILP];
      MATH_T param[ILP];
      MATH_T mom1[ILP];
      MATH_T mom2[ILP];

      // Gather: lanes past the end of the chunk keep zeros.
#pragma unroll
      for (int k = 0; k < ILP; ++k) {
        const int idx = base + threadIdx.x + k * blockDim.x;
        grad[k] = MATH_T(0);
        param[k] = MATH_T(0);
        mom1[k] = MATH_T(0);
        mom2[k] = MATH_T(0);
        if (idx < remaining && idx < chunk_size) {
          grad[k] = static_cast<MATH_T>(grads[idx]) * (*inv_scale);
          grads[idx] = static_cast<T>(grad[k]);
          param[k] = static_cast<MATH_T>(master[idx]);
          mom1[k] = static_cast<MATH_T>(exp_avg[idx]);
          mom2[k] = static_cast<MATH_T>(exp_avg_sq[idx]);
        }
      }

      // Update in registers; the two modes differ only in where decay enters.
#pragma unroll
      for (int k = 0; k < ILP; ++k) {
        if (l2_mode) {
          grad[k] = grad[k] + (decay * param[k]);
        }
        mom1[k] = beta1 * mom1[k] + (1-beta1) * grad[k];
        mom2[k] = beta2 * mom2[k] + (1-beta2) * grad[k] * grad[k];
        const MATH_T m_hat = mom1[k] / beta1_correction;
        const MATH_T v_hat = mom2[k] / beta2_correction;
        const MATH_T denom = sqrtf(v_hat) + epsilon;
        MATH_T delta = m_hat / denom;
        if (!l2_mode) {
          delta = delta + (decay * param[k]);
        }
        param[k] = param[k] - (*lr * delta);
      }

      // Scatter: only lanes that loaded real data write back.
#pragma unroll
      for (int k = 0; k < ILP; ++k) {
        const int idx = base + threadIdx.x + k * blockDim.x;
        if (idx >= remaining || idx >= chunk_size) {
          continue;
        }
        params[idx] = static_cast<T>(param[k]);
        master[idx] = static_cast<FULL_T>(param[k]);
        exp_avg[idx] = static_cast<FULL_T>(mom1[k]);
        exp_avg_sq[idx] = static_cast<FULL_T>(mom2[k]);
      }
    }
  }
};
// Host entry point for the non-capturable fused Adam/AdamW step.
//
// tensor_lists holds, in order: gradients, parameters, first moments (m),
// second moments (v). All tensors are assumed to share the scalar type of
// tensor_lists[0][0]. Hyperparameters are host scalars; when
// bias_correction == 1 the bias-correction denominators are computed here
// from `step`, otherwise they are 1.0f. mode is cast to adamMode_t.
// 64-bit indexing (AdamFunctor<..., int64_t>) is selected when any tensor has
// at least INT_MAX elements; otherwise 32-bit indexing is used.
// An empty tensor list is a no-op.
void multi_tensor_adam_cuda(
  int chunk_size,
  at::Tensor noop_flag,
  std::vector<std::vector<at::Tensor>> tensor_lists,
  const float lr,
  const float beta1,
  const float beta2,
  const float epsilon,
  const int step,
  const int mode,
  const int bias_correction,
  const float weight_decay)
{
  using namespace at;

  // Nothing to update (e.g. an empty parameter group); the dtype lookup on
  // tensor_lists[0][0] below would otherwise read out of bounds.
  if (tensor_lists.empty() || tensor_lists[0].empty()) {
    return;
  }

  // Handle bias correction mode
  float bias_correction1 = 1.0f, bias_correction2 = 1.0f;
  if (bias_correction == 1) {
    bias_correction1 = 1 - std::pow(beta1, step);
    bias_correction2 = 1 - std::pow(beta2, step);
  }

  // 64-bit indexing is needed as soon as any single tensor reaches INT_MAX
  // elements.
  const bool requires_64bit_indexing = [&tensor_lists]() {
    for (const auto& list : tensor_lists) {
      for (const auto& t : list) {
        if (t.numel() >= INT_MAX) {
          return true;
        }
      }
    }
    return false;
  }();

  if (requires_64bit_indexing) {
    // Assume single type across p,g,m1,m2 now
    DISPATCH_DOUBLE_FLOAT_HALF_AND_BFLOAT(
      tensor_lists[0][0].scalar_type(), 0, "adam",
      multi_tensor_apply<4>(
        (int64_t) BLOCK_SIZE,
        (int64_t) chunk_size,
        noop_flag,
        tensor_lists,
        AdamFunctor<scalar_t_0, float, int64_t>(),
        beta1,
        beta2,
        bias_correction1,
        bias_correction2,
        epsilon,
        lr,
        (adamMode_t) mode,
        weight_decay); )
  } else {
    // Assume single type across p,g,m1,m2 now
    DISPATCH_DOUBLE_FLOAT_HALF_AND_BFLOAT(
      tensor_lists[0][0].scalar_type(), 0, "adam",
      multi_tensor_apply<4>(
        BLOCK_SIZE,
        chunk_size,
        noop_flag,
        tensor_lists,
        AdamFunctor<scalar_t_0, float, int32_t>(),
        beta1,
        beta2,
        bias_correction1,
        bias_correction2,
        epsilon,
        lr,
        (adamMode_t) mode,
        weight_decay); )
  }
  AT_CUDA_CHECK(cudaGetLastError());
}
// Host entry point for the capturable fused Adam/AdamW step.
//
// tensor_lists holds, in order: gradients, parameters, first moments (m),
// second moments (v); the scalar type of tensor_lists[0][0] selects the
// kernel instantiation. lr, step and inv_scale are tensors whose data
// pointers (float, int, float) are handed to the kernel, so their values are
// read when the kernel runs. Bias correction is computed inside the kernel
// when bias_correction == 1. An empty tensor list is a no-op.
// NOTE(review): the data_ptr<>() calls assume lr/inv_scale are float32 and
// step is int32 tensors on the device; confirm with the Python caller.
void multi_tensor_adam_capturable_cuda(
  int chunk_size,
  at::Tensor noop_flag,
  std::vector<std::vector<at::Tensor>> tensor_lists,
  at::Tensor lr,
  const float beta1,
  const float beta2,
  const float epsilon,
  at::Tensor step,
  const int mode,
  const int bias_correction,
  const float weight_decay,
  at::Tensor inv_scale)
{
  using namespace at;

  // Nothing to update (e.g. an empty parameter group); the dtype lookup on
  // tensor_lists[0][0] below would otherwise read out of bounds.
  if (tensor_lists.empty() || tensor_lists[0].empty()) {
    return;
  }

  DISPATCH_DOUBLE_FLOAT_HALF_AND_BFLOAT(
    tensor_lists[0][0].scalar_type(), 0, "adam",
    multi_tensor_apply<4>(
      BLOCK_SIZE,
      chunk_size,
      noop_flag,
      tensor_lists,
      AdamCapturableFunctor<scalar_t_0, float>(),
      beta1,
      beta2,
      step.data_ptr<int>(),
      bias_correction,
      epsilon,
      lr.data_ptr<float>(),
      (adamMode_t) mode,
      weight_decay,
      inv_scale.data_ptr<float>()); )
  AT_CUDA_CHECK(cudaGetLastError());
}
// Host entry point for the capturable fused Adam/AdamW step with a master
// copy of the parameters.
//
// tensor_lists holds, in order: gradients, model parameters, first moments
// (m), second moments (v), master parameters; the scalar type of
// tensor_lists[0][0] selects the kernel instantiation (moments and master
// copy are accessed as float). lr, step and inv_scale are tensors whose data
// pointers (float, int, float) are handed to the kernel, so their values are
// read when the kernel runs. An empty tensor list is a no-op.
// NOTE(review): the data_ptr<>() calls assume lr/inv_scale are float32 and
// step is int32 tensors on the device; confirm with the Python caller.
void multi_tensor_adam_capturable_master_cuda(
  int chunk_size,
  at::Tensor noop_flag,
  std::vector<std::vector<at::Tensor>> tensor_lists,
  at::Tensor lr,
  const float beta1,
  const float beta2,
  const float epsilon,
  at::Tensor step,
  const int mode,
  const int bias_correction,
  const float weight_decay,
  at::Tensor inv_scale)
{
  using namespace at;

  // Nothing to update (e.g. an empty parameter group); the dtype lookup on
  // tensor_lists[0][0] below would otherwise read out of bounds.
  if (tensor_lists.empty() || tensor_lists[0].empty()) {
    return;
  }

  DISPATCH_DOUBLE_FLOAT_HALF_AND_BFLOAT(
    tensor_lists[0][0].scalar_type(), 0, "adam",
    multi_tensor_apply<5>(
      BLOCK_SIZE,
      chunk_size,
      noop_flag,
      tensor_lists,
      AdamCapturableMasterFunctor<scalar_t_0, float>(),
      beta1,
      beta2,
      step.data_ptr<int>(),
      bias_correction,
      epsilon,
      lr.data_ptr<float>(),
      (adamMode_t) mode,
      weight_decay,
      inv_scale.data_ptr<float>()); )
  AT_CUDA_CHECK(cudaGetLastError());
}
|