// multi_tensor_novograd.cu
  #include <assert.h>
  #include <cmath>
  #include <vector>

  #include <ATen/ATen.h>
  #include <ATen/AccumulateType.h>
  #include <ATen/cuda/CUDAContext.h>
  #include <ATen/cuda/Exceptions.h>
  // Another possibility:
  // #include <torch/all.h>

  #include "type_shim.h"
  #include "multi_tensor_apply.cuh"
  // Passed to multi_tensor_apply as its block-size argument (presumably the
  // number of threads per block -- NOTE(review): confirm in multi_tensor_apply.cuh).
  #define BLOCK_SIZE 512
  // Number of elements each thread loads, updates and stores per outer-loop
  // iteration of NovoGradFunctor (its loops advance by blockDim.x * ILP).
  #define ILP 4

  // Selects how weight decay is combined with the normalized gradient in
  // NovoGradFunctor.
  enum momentMode_t {
    // Novograd paper mode: the gradient is divided by the denominator and the
    // decay term is added to it before the momentum calculation.
    MOMENT_MODE_0 = 0,
    // Decoupled weight decay mode: the decay term is added to the normalized
    // momentum when the parameters are updated.
    MOMENT_MODE_1 = 1
  };
  // Forward declaration; the definition is not in this file (presumably another
  // translation unit of the same extension). Per its use in
  // multi_tensor_novograd_cuda, it computes a per-tensor norm of the tensors in
  // tensor_lists and blends it into `out`, with alpha weighting the previous
  // value and beta the new norm -- NOTE(review): confirm against the definition.
  void multi_tensor_norm_out_cuda(
    int chunk_size, at::Tensor noop_flag,
    std::vector<std::vector<at::Tensor>> tensor_lists,
    at::Tensor out,
    const float alpha, const float beta,
    const int norm_type);

  // Type used for all in-register arithmetic in NovoGradFunctor, whatever the
  // storage type T of the tensors.
  typedef float MATH_T;
  // NovoGrad update for one chunk of one tensor. This is the functor that
  // multi_tensor_novograd_cuda hands to multi_tensor_apply<3>.
  //
  // Address slots of `tl`:
  //   [0] gradients g (read only here)
  //   [1] parameters p (updated in place)
  //   [2] first moments m (updated in place)
  // Values are loaded from and stored back as T. All arithmetic is done in MATH_T.
  //
  // Block blockIdx.x processes chunk tl.block_to_chunk[blockIdx.x] (at most
  // chunk_size elements) of tensor tl.block_to_tensor[blockIdx.x]. Each thread
  // handles ILP elements per outer iteration, spaced blockDim.x apart.
  //
  // Per element, with denom = grad_norm / beta2_correction + epsilon:
  //   MOMENT_MODE_0: g' = g / denom + decay * p
  //                  m  = beta1 * m + beta3 * g'
  //                  p  = p - lr * (m / beta1_correction)
  //   MOMENT_MODE_1: m  = beta1 * m + beta3 * g
  //                  p  = p - lr * ((m / beta1_correction) / denom + decay * p)
  //
  // per_tensor_grad_norm is indexed by the tensor's position in the full list
  // (tl.start_tensor_this_launch + local index). It therefore needs one float
  // per tensor across all launches, not just the tensors of this launch.
  // beta2 and noop_gmem are accepted but not used by this functor.
  template<typename T>
  struct NovoGradFunctor
  {
    __device__ __forceinline__ void operator()(
      int chunk_size,
      volatile int* noop_gmem,
      TensorListMetadata<3>& tl,
      const float beta1,
      const float beta2,
      const float beta3,
      const float beta1_correction,
      const float beta2_correction,
      const float epsilon,
      const float lr,
      momentMode_t m_mode,
      const float decay,
      const float* per_tensor_grad_norm)
    {
      // I'd like this kernel to propagate infs/nans, so the noop flag is
      // intentionally not consulted:
      // if(*noop_gmem == 1)
      //   return;

      int tensor_loc = tl.block_to_tensor[blockIdx.x];
      int tensor_num = tl.start_tensor_this_launch + tensor_loc;
      int chunk_idx = tl.block_to_chunk[blockIdx.x];
      int n = tl.sizes[tensor_loc];

      // The denominator depends only on this tensor's gradient norm, so it is
      // loop-invariant. Compute it once per block instead of once per element.
      // These are the same float operations the per-element form performed.
      float grad_norm = per_tensor_grad_norm[tensor_num];
      MATH_T next_v_unbiased = grad_norm / beta2_correction;
      MATH_T denom = next_v_unbiased + epsilon;

      // Advance to the start of this block's chunk. n becomes the number of
      // elements remaining from that point on.
      T* g = (T*)tl.addresses[0][tensor_loc];
      g += chunk_idx*chunk_size;
      T* p = (T*)tl.addresses[1][tensor_loc];
      p += chunk_idx*chunk_size;
      T* m = (T*)tl.addresses[2][tensor_loc];
      m += chunk_idx*chunk_size;
      n -= chunk_idx*chunk_size;

      // see note in multi_tensor_scale_kernel.cu
      for(int i_start = 0;
          i_start < n && i_start < chunk_size;
          i_start += blockDim.x*ILP)
      {
        MATH_T r_g[ILP];
        MATH_T r_p[ILP];
        MATH_T r_m[ILP];

        // Load. Out-of-range slots are zero-filled and never stored back.
        #pragma unroll
        for(int ii = 0; ii < ILP; ii++)
        {
          int i = i_start + threadIdx.x + ii*blockDim.x;
          if(i < n && i < chunk_size)
          {
            r_g[ii] = g[i];
            r_p[ii] = p[i];
            r_m[ii] = m[i];
          } else {
            r_g[ii] = MATH_T(0);
            r_p[ii] = MATH_T(0);
            r_m[ii] = MATH_T(0);
          }
        }

        // Update. m_mode is a scalar kernel argument, so every thread takes
        // the same branch.
        #pragma unroll
        for(int ii = 0; ii < ILP; ii++)
        {
          if (m_mode == MOMENT_MODE_0) {
            // Normalize the gradient and add decay before the momentum update.
            r_g[ii] = (r_g[ii] / denom) + (decay * r_p[ii]);
            r_m[ii] = beta1 * r_m[ii] + beta3 * r_g[ii];
            MATH_T next_m_unbiased = r_m[ii] / beta1_correction;
            r_p[ii] = r_p[ii] - (lr * next_m_unbiased);
          }
          else {
            // Decoupled decay: apply it to the parameter step, not the momentum.
            r_m[ii] = beta1 * r_m[ii] + beta3 * r_g[ii];
            MATH_T next_m_unbiased = r_m[ii] / beta1_correction;
            MATH_T update = (next_m_unbiased / denom) + (decay * r_p[ii]);
            r_p[ii] = r_p[ii] - (lr * update);
          }
        }

        // Store parameters and moments. Gradients are left untouched.
        #pragma unroll
        for(int ii = 0; ii < ILP; ii++)
        {
          int i = i_start + threadIdx.x + ii*blockDim.x;
          if(i < n && i < chunk_size)
          {
            p[i] = r_p[ii];
            m[i] = r_m[ii];
          }
        }
      }
    }
  };
  // Fused NovoGrad step over lists of CUDA tensors.
  //
  // tensor_lists holds, in order: gradients, parameters, first moments. These
  // become address slots 0, 1 and 2, which NovoGradFunctor reads as g, p and m.
  // Parameters and moments are updated in place. The launch is dispatched on
  // the dtype of the first gradient, so all tensors are assumed to share it.
  //
  // grad_norms holds one float per tensor. It is first updated in place by
  // multi_tensor_norm_out_cuda, which blends the previous value (gn) with the
  // current gradient norm (n) using a = beta2 and b = 1 - beta2:
  //   L-2:   gn = sqrt(a * gn^2 + b * n^2)
  //   L-inf: gn = a * gn + b * n
  // norm_type is forwarded unchanged to choose between them.
  // NOTE(review): the value encoding is defined by multi_tensor_norm_out_cuda.
  //
  // bias_correction == 1 enables bias correction:
  //   bc1 = 1 - beta1^step, bc2 = sqrt(1 - beta2^step)
  // This requires step >= 1.
  // grad_averaging == 1 scales the incoming gradient by (1 - beta1) in the
  // momentum update. moment_mode is cast to momentMode_t.
  //
  // Throws (TORCH_CHECK) in three cases:
  //   - tensor_lists does not contain exactly three lists;
  //   - the gradient list is empty;
  //   - bias correction is requested with step < 1.
  void multi_tensor_novograd_cuda(
    int chunk_size,
    at::Tensor noop_flag,
    std::vector<std::vector<at::Tensor>> tensor_lists,
    at::Tensor grad_norms,
    const float lr,
    const float beta1,
    const float beta2,
    const float epsilon,
    const int step,
    const int bias_correction,
    const float weight_decay,
    const int grad_averaging,
    const int moment_mode,
    const int norm_type)
  {
    using namespace at;

    // The lists are sliced and tensor_lists[0][0] is dereferenced below, before
    // multi_tensor_apply gets a chance to validate them. Check the layout first.
    TORCH_CHECK(tensor_lists.size() == 3,
                "multi_tensor_novograd_cuda: expected 3 tensor lists "
                "(grads, params, exp_avg), got ", tensor_lists.size());
    TORCH_CHECK(!tensor_lists[0].empty(),
                "multi_tensor_novograd_cuda: the gradient list is empty");

    // Handle bias correction mode. At step 0 both corrections would be 0, and
    // the kernel would divide by them.
    float bias_correction1 = 1.0f, bias_correction2 = 1.0f;
    if (bias_correction == 1) {
      TORCH_CHECK(step >= 1,
                  "multi_tensor_novograd_cuda: step must be >= 1 when "
                  "bias_correction is enabled, got ", step);
      bias_correction1 = static_cast<float>(1 - std::pow(beta1, step));
      bias_correction2 = static_cast<float>(std::sqrt(1 - std::pow(beta2, step)));
    }

    // Handle grad averaging mode. beta3 scales the incoming gradient in the
    // momentum update.
    const float beta3 = (grad_averaging == 1) ? (1.0f - beta1) : 1.0f;

    // Compute and update grad norm. This uses a per-tensor norm and blends the
    // new norm (n) into the old norm (gn) with a = beta2, b = 1 - beta2:
    //   L-2:   gn = sqrt(a * gn^2 + b * n^2)
    //   L-inf: gn = a * gn + b * n
    std::vector<std::vector<at::Tensor>> grad_list(tensor_lists.begin(), tensor_lists.begin() + 1);
    multi_tensor_norm_out_cuda(chunk_size, noop_flag, grad_list, grad_norms,
                               beta2, (1.0f - beta2), norm_type);

    // Assume a single dtype across g, p and m. Dispatch on the first gradient.
    DISPATCH_DOUBLE_FLOAT_AND_HALF(
      tensor_lists[0][0].scalar_type(), 0, "novograd",
      multi_tensor_apply<3>(
        BLOCK_SIZE,
        chunk_size,
        noop_flag,
        tensor_lists,
        NovoGradFunctor<scalar_t_0>(),
        beta1,
        beta2,
        beta3, // 1-beta1 or 1 depends on averaging mode
        bias_correction1,
        bias_correction2,
        epsilon,
        lr,
        (momentMode_t) moment_mode,
        weight_decay,
        grad_norms.DATA_PTR<float>()); )

    // Surface kernel launch errors immediately.
    AT_CUDA_CHECK(cudaGetLastError());
  }