3
0

multi_tensor_adagrad.cu 3.0 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100
  1. #include <ATen/ATen.h>
  2. #include <ATen/AccumulateType.h>
  3. #include <ATen/cuda/CUDAContext.h>
  4. #include <ATen/cuda/Exceptions.h>
  5. // Another possibility:
  6. // #include <torch/all.h>
  7. #include <assert.h>
  8. #include "multi_tensor_apply.cuh"
  9. #include "type_shim.h"
  // Launch block size handed to multi_tensor_apply by the host wrapper.
  #define BLOCK_SIZE 1024
  // Elements each thread handles per loop iteration (instruction-level
  // parallelism factor used by the functor's unrolled loops).
  #define ILP 4

  // How weight_decay is combined with the Adagrad update.
  enum adagradMode_t {
    ADAGRAD_MODE_0 = 0, // decay added to the gradient (L2 regularization)
    ADAGRAD_MODE_1 = 1, // decay applied directly to the parameter (AdamW-style)
  };

  // Single-precision math type available to kernels in this file.
  typedef float MATH_T;
  // Per-chunk Adagrad update functor, invoked through multi_tensor_apply<3>.
  //
  // Tensor slots read from the metadata:
  //   addresses[0] = gradients g                    (read only)
  //   addresses[1] = parameters p                   (read, updated in place)
  //   addresses[2] = squared-gradient accumulator h (read, updated in place)
  //
  // Each CUDA block handles one chunk (tl.block_to_chunk[blockIdx.x]) of one
  // tensor (tl.block_to_tensor[blockIdx.x]). Within a chunk, each thread
  // processes ILP elements per iteration, spaced blockDim.x apart.
  //
  // Update rules:
  //   ADAGRAD_MODE_0 (L2):   g' = g + wd*p;  h += g'^2;  p -= lr * g'/(sqrt(h)+eps)
  //   otherwise (AdamW):     h += g^2;       p -= lr * (g/(sqrt(h)+eps) + wd*p)
  //
  // Arithmetic is done in at::acc_type<T, /*is_cuda=*/true>:
  //   - float for Half and float inputs, so results for those types are
  //     unchanged;
  //   - double for double inputs, so double parameters and accumulators are
  //     no longer rounded to float precision on every step.
  //
  // noop_gmem is part of the functor signature but is not consulted here.
  template <typename T> struct AdagradFunctor {
    // Compute type: float for Half/float, double for double.
    using acc_t = at::acc_type<T, /*is_cuda=*/true>;

    // Precision-matched square roots; the float path keeps using sqrtf.
    static __device__ __forceinline__ float adagrad_sqrt(float x) {
      return sqrtf(x);
    }
    static __device__ __forceinline__ double adagrad_sqrt(double x) {
      return sqrt(x);
    }

    __device__ __forceinline__ void
    operator()(int chunk_size, volatile int *noop_gmem, TensorListMetadata<3> &tl,
               const float epsilon, const float lr, adagradMode_t mode,
               const float weight_decay) {
      int tensor_loc = tl.block_to_tensor[blockIdx.x];
      int chunk_idx = tl.block_to_chunk[blockIdx.x];
      int n = tl.sizes[tensor_loc];

      // Advance each tensor pointer to the start of this block's chunk.
      T *g = (T *)tl.addresses[0][tensor_loc];
      g += chunk_idx * chunk_size;
      T *p = (T *)tl.addresses[1][tensor_loc];
      p += chunk_idx * chunk_size;
      T *h = (T *)tl.addresses[2][tensor_loc];
      h += chunk_idx * chunk_size;
      // Elements remaining in this tensor from the chunk start onward.
      n -= chunk_idx * chunk_size;

      // Hyper-parameters converted once to the compute type (identity for float).
      const acc_t eps_a = static_cast<acc_t>(epsilon);
      const acc_t lr_a = static_cast<acc_t>(lr);
      const acc_t wd_a = static_cast<acc_t>(weight_decay);

      // see note in multi_tensor_scale_kernel.cu
      for (int i_start = 0; i_start < n && i_start < chunk_size;
           i_start += blockDim.x * ILP) {
        acc_t r_g[ILP];
        acc_t r_p[ILP];
        acc_t r_h[ILP];

        // Load; out-of-range lanes get zeros and are never written back.
  #pragma unroll
        for (int ii = 0; ii < ILP; ii++) {
          int i = i_start + threadIdx.x + ii * blockDim.x;
          if (i < n && i < chunk_size) {
            r_g[ii] = static_cast<acc_t>(g[i]);
            r_p[ii] = static_cast<acc_t>(p[i]);
            r_h[ii] = static_cast<acc_t>(h[i]);
          } else {
            r_g[ii] = acc_t(0);
            r_p[ii] = acc_t(0);
            r_h[ii] = acc_t(0);
          }
        }

        // Compute the update in registers.
  #pragma unroll
        for (int ii = 0; ii < ILP; ii++) {
          if (mode == ADAGRAD_MODE_0) { // L2
            r_g[ii] = r_g[ii] + wd_a * r_p[ii];
            r_h[ii] = r_h[ii] + r_g[ii] * r_g[ii];
            r_p[ii] = r_p[ii] - lr_a * (r_g[ii] / (adagrad_sqrt(r_h[ii]) + eps_a));
          } else { // AdamW-style
            r_h[ii] = r_h[ii] + r_g[ii] * r_g[ii];
            r_p[ii] = r_p[ii] - lr_a * (r_g[ii] / (adagrad_sqrt(r_h[ii]) + eps_a) +
                                        wd_a * r_p[ii]);
          }
        }

        // Store parameters and accumulator back (gradients are not modified).
  #pragma unroll
        for (int ii = 0; ii < ILP; ii++) {
          int i = i_start + threadIdx.x + ii * blockDim.x;
          if (i < n && i < chunk_size) {
            p[i] = static_cast<T>(r_p[ii]);
            h[i] = static_cast<T>(r_h[ii]);
          }
        }
      }
    }
  };
  // Host entry point: applies one fused Adagrad step to every tensor in
  // tensor_lists by launching AdagradFunctor through multi_tensor_apply<3>.
  //
  // tensor_lists: exactly three lists, assumed to map to the functor's
  //   addresses[0..2] = {gradients, parameters, squared-grad accumulators}.
  //   The scalar type of tensor_lists[0][0] selects the instantiation via
  //   DISPATCH_DOUBLE_FLOAT_AND_HALF. All tensors are assumed to share it.
  // mode: 0 = ADAGRAD_MODE_0 (L2 regularization),
  //       1 = ADAGRAD_MODE_1 (AdamW-style decay). Other values are rejected.
  // noop_flag, chunk_size: forwarded unchanged to multi_tensor_apply.
  //
  // Throws (TORCH_CHECK) on a malformed tensor_lists or an unknown mode,
  // instead of indexing an empty vector or silently picking the AdamW path.
  void multi_tensor_adagrad_cuda(
      int chunk_size, at::Tensor noop_flag,
      std::vector<std::vector<at::Tensor>> tensor_lists, const float lr,
      const float epsilon, const int mode, const float weight_decay) {
    using namespace at;
    // Validate before reading tensor_lists[0][0] below, which is undefined
    // behavior on an empty outer or inner vector.
    TORCH_CHECK(tensor_lists.size() == 3,
                "multi_tensor_adagrad: expected 3 tensor lists (g, p, h), got ",
                tensor_lists.size());
    TORCH_CHECK(!tensor_lists[0].empty(),
                "multi_tensor_adagrad: tensor lists must not be empty");
    TORCH_CHECK(mode == ADAGRAD_MODE_0 || mode == ADAGRAD_MODE_1,
                "multi_tensor_adagrad: unsupported mode ", mode);
    // Assume single type across p,g,h now
    DISPATCH_DOUBLE_FLOAT_AND_HALF(
        tensor_lists[0][0].scalar_type(), 0, "adagrad",
        multi_tensor_apply<3>(BLOCK_SIZE, chunk_size, noop_flag, tensor_lists,
                              AdagradFunctor<scalar_t_0>(), epsilon, lr,
                              (adagradMode_t)mode, weight_decay);)
    // Surface launch-configuration errors immediately.
    AT_CUDA_CHECK(cudaGetLastError());
  }