multi_tensor_lamb.cu

#include <ATen/ATen.h>
#include <ATen/AccumulateType.h>
#include <ATen/cuda/CUDAContext.h>
#include <ATen/cuda/Exceptions.h>
// Another possibility:
// #include <torch/all.h>

#include <assert.h>

#include "type_shim.h"
#include "multi_tensor_apply.cuh"

#define BLOCK_SIZE 512
#define ILP 4
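
// Helpers for ILP-wide vectorized memory access: is_aligned() checks that a
// pointer sits on an ILP*sizeof(T) boundary, and load_store() moves ILP
// consecutive elements with a single aligned_storage-sized load/store.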
template<typename T>
__device__ __forceinline__ bool is_aligned(T* p){
  return ((uint64_t)p) % (ILP*sizeof(T)) == 0;
}

template<typename T>
__device__ __forceinline__ void load_store(T* dst, T* src, int dst_offset, int src_offset){
  typedef typename std::aligned_storage<ILP*sizeof(T), ILP*alignof(T)>::type LT;
  ((LT*)dst)[dst_offset] = ((LT*)src)[src_offset];
}

typedef enum{
  MOMENT_MODE_0 = 0, // L2 regularization mode
  MOMENT_MODE_1 = 1  // Decoupled weight decay mode
} adamMode_t;
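
// MOMENT_MODE_0 folds weight decay into the gradient before the moment update
// (classic L2 regularization); MOMENT_MODE_1 keeps the moments decay-free and
// adds decay*param to the update afterwards (decoupled weight decay, as in AdamW).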
std::tuple<at::Tensor, at::Tensor> multi_tensor_l2norm_cuda(
  int chunk_size,
  at::Tensor noop_flag,
  std::vector<std::vector<at::Tensor>> tensor_lists,
  at::optional<bool> per_tensor_python);

using MATH_T = float;
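
// All optimizer math below is done in MATH_T (fp32), regardless of the storage
// type T of the tensors; e.g. fp16 grads/params are upcast element by element.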
template<typename T>
struct LAMBStage1Functor
{
  __device__ __forceinline__ void operator()(
    int chunk_size,
    volatile int* noop_gmem,
    TensorListMetadata<4>& tl,
    const float beta1,
    const float beta2,
    const float beta3,
    const float beta1_correction,
    const float beta2_correction,
    const float epsilon,
    adamMode_t mode,
    const float decay,
    const float* global_grad_norm,
    const float max_global_grad_norm)
  {
    // I'd like this kernel to propagate infs/nans.
    // if(*noop_gmem == 1)
    //   return;

    int tensor_loc = tl.block_to_tensor[blockIdx.x];
    int chunk_idx = tl.block_to_chunk[blockIdx.x];
    int n = tl.sizes[tensor_loc];

    float clipped_global_grad_norm = (*global_grad_norm) > max_global_grad_norm ? (*global_grad_norm) / max_global_grad_norm : 1.0f;
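    // Global gradient clipping: if the global grad norm exceeds
    // max_global_grad_norm, every gradient is later divided by
    // (global_norm / max_global_grad_norm); otherwise it is left unscaled.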
    T* g = (T*)tl.addresses[0][tensor_loc];
    g += chunk_idx*chunk_size;

    T* p = (T*)tl.addresses[1][tensor_loc];
    p += chunk_idx*chunk_size;

    T* m = (T*)tl.addresses[2][tensor_loc];
    m += chunk_idx*chunk_size;

    T* v = (T*)tl.addresses[3][tensor_loc];
    v += chunk_idx*chunk_size;

    n -= chunk_idx*chunk_size;

    MATH_T r_g[ILP];
    MATH_T r_p[ILP];
    MATH_T r_m[ILP];
    MATH_T r_v[ILP];

    // to make things simple, we put the aligned case in a different code path
    if(n % ILP == 0 &&
       chunk_size % ILP == 0 &&
       is_aligned(g) &&
       is_aligned(p) &&
       is_aligned(m) &&
       is_aligned(v))
    {
      T l_g[ILP];
      T l_p[ILP];
      T l_m[ILP];
      T l_v[ILP];
      for(int i_start = threadIdx.x; i_start*ILP < n && i_start*ILP < chunk_size; i_start += blockDim.x)
      {
        // load
        load_store(l_g, g, 0, i_start);
        if (decay != 0)
          load_store(l_p, p, 0, i_start);
        load_store(l_m, m, 0, i_start);
        load_store(l_v, v, 0, i_start);

        // unpack
        #pragma unroll
        for(int ii = 0; ii < ILP; ii++)
        {
          r_g[ii] = l_g[ii];
          if (decay == 0) {
            r_p[ii] = MATH_T(0);
          }
          else {
            r_p[ii] = l_p[ii];
          }
          r_m[ii] = l_m[ii];
          r_v[ii] = l_v[ii];
        }
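
        // Fused Adam-style moment update, computed in fp32:
        //   m <- beta1*m + beta3*scaled_grad        (beta3 = 1-beta1 with grad averaging, else 1)
        //   v <- beta2*v + (1-beta2)*scaled_grad^2
        //   update = (m/beta1_correction) / (sqrt(v/beta2_correction) + epsilon)
        // MOMENT_MODE_0 adds decay*p to the gradient before the moment update;
        // MOMENT_MODE_1 adds decay*p to the update afterwards.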
        #pragma unroll
        for(int ii = 0; ii < ILP; ii++)
        {
          if (mode == MOMENT_MODE_0) {
            MATH_T scaled_grad = r_g[ii] / clipped_global_grad_norm;
            // L2 on scaled grad
            scaled_grad = scaled_grad + decay*r_p[ii];
            r_m[ii] = r_m[ii] * beta1 + beta3 * scaled_grad;
            r_v[ii] = r_v[ii] * beta2 + (1-beta2) * scaled_grad * scaled_grad;
            MATH_T next_m_unbiased = r_m[ii] / beta1_correction;
            MATH_T next_v_unbiased = r_v[ii] / beta2_correction;
            MATH_T denom = sqrtf(next_v_unbiased) + epsilon;
            r_p[ii] = next_m_unbiased / denom;
          }
          else {
            MATH_T scaled_grad = r_g[ii] / clipped_global_grad_norm;
            r_m[ii] = r_m[ii] * beta1 + beta3 * scaled_grad;
            r_v[ii] = r_v[ii] * beta2 + (1-beta2) * scaled_grad * scaled_grad;
            MATH_T next_m_unbiased = r_m[ii] / beta1_correction;
            MATH_T next_v_unbiased = r_v[ii] / beta2_correction;
            MATH_T denom = sqrtf(next_v_unbiased) + epsilon;
            r_p[ii] = (next_m_unbiased/denom) + (decay*r_p[ii]);
          }
        }
        #pragma unroll
        for(int ii = 0; ii < ILP; ii++)
        {
          l_p[ii] = r_p[ii];
          l_m[ii] = r_m[ii];
          l_v[ii] = r_v[ii];
        }

        // store
        load_store(g, l_p, i_start, 0);
        load_store(m, l_m, i_start, 0);
        load_store(v, l_v, i_start, 0);
      }
    }
    else
    {
      // see note in multi_tensor_scale_kernel.cu
      for(int i_start = 0;
          i_start < n && i_start < chunk_size;
          i_start += blockDim.x*ILP)
      {
        MATH_T r_g[ILP];
        MATH_T r_p[ILP];
        MATH_T r_m[ILP];
        MATH_T r_v[ILP];
        #pragma unroll
        for(int ii = 0; ii < ILP; ii++)
        {
          int i = i_start + threadIdx.x + ii*blockDim.x;
          if(i < n && i < chunk_size)
          {
            r_g[ii] = g[i];
            // special case for lamb stage 1: when decay == 0 the old parameter
            // value never enters the update, so skip loading it
            if (decay == 0) {
              r_p[ii] = MATH_T(0);
            }
            else {
              r_p[ii] = p[i];
            }
            r_m[ii] = m[i];
            r_v[ii] = v[i];
          } else {
            r_g[ii] = MATH_T(0);
            r_p[ii] = MATH_T(0);
            r_m[ii] = MATH_T(0);
            r_v[ii] = MATH_T(0);
          }
        }
        #pragma unroll
        for(int ii = 0; ii < ILP; ii++)
        {
          if (mode == MOMENT_MODE_0) {
            MATH_T scaled_grad = r_g[ii] / clipped_global_grad_norm;
            // L2 on scaled grad
            scaled_grad = scaled_grad + decay*r_p[ii];
            r_m[ii] = r_m[ii] * beta1 + beta3 * scaled_grad;
            r_v[ii] = r_v[ii] * beta2 + (1-beta2) * scaled_grad * scaled_grad;
            MATH_T next_m_unbiased = r_m[ii] / beta1_correction;
            MATH_T next_v_unbiased = r_v[ii] / beta2_correction;
            MATH_T denom = sqrtf(next_v_unbiased) + epsilon;
            r_p[ii] = next_m_unbiased / denom;
          }
          else {
            MATH_T scaled_grad = r_g[ii] / clipped_global_grad_norm;
            r_m[ii] = r_m[ii] * beta1 + beta3 * scaled_grad;
            r_v[ii] = r_v[ii] * beta2 + (1-beta2) * scaled_grad * scaled_grad;
            MATH_T next_m_unbiased = r_m[ii] / beta1_correction;
            MATH_T next_v_unbiased = r_v[ii] / beta2_correction;
            MATH_T denom = sqrtf(next_v_unbiased) + epsilon;
            r_p[ii] = (next_m_unbiased/denom) + (decay*r_p[ii]);
          }
        }
        #pragma unroll
        for(int ii = 0; ii < ILP; ii++)
        {
          int i = i_start + threadIdx.x + ii*blockDim.x;
          if(i < n && i < chunk_size)
          {
            g[i] = r_p[ii];
            m[i] = r_m[ii];
            v[i] = r_v[ii];
          }
        }
      }
    }
  }
};
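
// After stage 1, the gradient buffers hold the raw (learning-rate-free) LAMB
// update for each element; m and v have been updated in place.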
// Step 2 reads in the 'update' value and the per-tensor param_norm and update_norm.
// It computes the new parameter value.
template<typename T>
struct LAMBStage2Functor
{
  __device__ __forceinline__ void operator()(
    int chunk_size,
    volatile int* noop_gmem,
    TensorListMetadata<2>& tl,
    const float* per_tensor_param_norm,
    const float* per_tensor_update_norm,
    const float learning_rate,
    const float decay,
    bool use_nvlamb)
  {
    // I'd like this kernel to propagate infs/nans.
    // if(*noop_gmem == 1)
    //   return;

    int tensor_loc = tl.block_to_tensor[blockIdx.x];
    int tensor_num = tl.start_tensor_this_launch + tensor_loc;
    int chunk_idx = tl.block_to_chunk[blockIdx.x];
    int n = tl.sizes[tensor_loc];

    MATH_T ratio = learning_rate;
    // nvlamb: apply the adaptive learning rate to all parameters;
    // otherwise, only apply it to those with non-zero weight decay
    if (use_nvlamb || (decay != 0.0))
    {
      float param_norm = per_tensor_param_norm[tensor_num];
      float update_norm = per_tensor_update_norm[tensor_num];
      ratio = (update_norm != 0.0f && param_norm != 0.0f) ? learning_rate * (param_norm / update_norm) : learning_rate;
    }
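    // 'ratio' now holds lr * trust_ratio, where the LAMB trust ratio is
    // ||param|| / ||update|| (falling back to plain lr if either norm is zero).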
    T* update = (T*)tl.addresses[0][tensor_loc];
    update += chunk_idx*chunk_size;

    T* p = (T*)tl.addresses[1][tensor_loc];
    p += chunk_idx*chunk_size;

    n -= chunk_idx*chunk_size;

    // to make things simple, we put the aligned case in a different code path
    if(n % ILP == 0 &&
       chunk_size % ILP == 0 &&
       is_aligned(p) &&
       is_aligned(update))
    {
      T r_p[ILP];
      T r_update[ILP];
      for(int i_start = threadIdx.x; i_start*ILP < n && i_start*ILP < chunk_size; i_start += blockDim.x)
      {
        // load
        load_store(r_p, p, 0, i_start);
        load_store(r_update, update, 0, i_start);
        #pragma unroll
        for(int ii = 0; ii < ILP; ii++)
        {
          r_p[ii] = static_cast<MATH_T>(r_p[ii]) - (ratio * static_cast<MATH_T>(r_update[ii]));
        }
        load_store(p, r_p, i_start, 0);
      }
    }
    else
    {
      for(int i_start = 0;
          i_start < n && i_start < chunk_size;
          i_start += blockDim.x*ILP)
      {
        MATH_T r_p[ILP];
        MATH_T r_update[ILP];
        #pragma unroll
        for(int ii = 0; ii < ILP; ii++)
        {
          int i = i_start + threadIdx.x + ii*blockDim.x;
          if(i < n && i < chunk_size)
          {
            r_p[ii] = p[i];
            r_update[ii] = update[i];
          }
        }
        #pragma unroll
        for(int ii = 0; ii < ILP; ii++)
        {
          r_p[ii] = r_p[ii] - (ratio * r_update[ii]);
        }
        #pragma unroll
        for(int ii = 0; ii < ILP; ii++)
        {
          int i = i_start + threadIdx.x + ii*blockDim.x;
          if(i < n && i < chunk_size)
          {
            p[i] = r_p[ii];
          }
        }
      }
    }
  }
};
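
// Host entry point. tensor_lists is expected to be ordered as
// [gradients, params, exp_avg (m), exp_avg_sq (v)], with all tensors of the
// same dtype. Stage 1 writes the raw update into the gradient buffers; stage 2
// scales it by the per-tensor trust ratio and applies it to the params.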
void multi_tensor_lamb_cuda(
  int chunk_size,
  at::Tensor noop_flag,
  std::vector<std::vector<at::Tensor>> tensor_lists,
  const float lr,
  const float beta1,
  const float beta2,
  const float epsilon,
  const int step,
  const int bias_correction,
  const float weight_decay,
  const int grad_averaging,
  const int mode,
  at::Tensor global_grad_norm,
  const float max_grad_norm,
  at::optional<bool> use_nvlamb_python)
{
  using namespace at;
  // Master weights and 32-bit momenta (which may be a different dtype) are not handled here,
  // so we assume every tensor in tensor_lists has the same type.
  bool use_nvlamb = use_nvlamb_python.has_value() ? use_nvlamb_python.value() : false;

  // Handle bias correction mode
  float bias_correction1 = 1.0f, bias_correction2 = 1.0f;
  if (bias_correction == 1) {
    bias_correction1 = 1 - std::pow(beta1, step);
    bias_correction2 = 1 - std::pow(beta2, step);
  }
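  // With bias correction enabled, the moments are unbiased as in Adam:
  //   bias_correction1 = 1 - beta1^step,  bias_correction2 = 1 - beta2^step,
  // and stage 1 uses them to unbias m and v when forming the update
  // (the stored moments themselves stay biased).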
  // Handle grad averaging mode
  float beta3 = 1.0f;
  if (grad_averaging == 1) beta3 = 1 - beta1;

  std::vector<std::vector<at::Tensor>> grad_list(tensor_lists.begin(), tensor_lists.begin()+1);
  std::vector<std::vector<at::Tensor>> param_list(tensor_lists.begin()+1, tensor_lists.begin()+2);

  // Compute per-tensor param norm
  auto param_norm_tuple = multi_tensor_l2norm_cuda(chunk_size, noop_flag, param_list, true);

  // We now modify grad in place to store the update before computing its norm.
  // Generally this is not an issue, since grads are routinely modified inside step() anyway.
  // We could also allocate a separate list of empty tensors to avoid this, but that would cost
  // extra memory and host-side code.
  DISPATCH_FLOAT_AND_HALF(tensor_lists[0][0].scalar_type(), 0, "lamb_stage_1",
    multi_tensor_apply<4>(
      BLOCK_SIZE,
      chunk_size,
      noop_flag,
      tensor_lists,
      LAMBStage1Functor<scalar_t_0>(),
      beta1,
      beta2,
      beta3, // 1-beta1 or 1, depending on averaging mode
      bias_correction1,
      bias_correction2,
      epsilon,
      (adamMode_t) mode,
      weight_decay,
      global_grad_norm.DATA_PTR<float>(),
      max_grad_norm); )
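
  // After stage 1, tensor_lists[0] (the gradients) holds the raw updates, so the
  // L2 norm of grad_list below is the per-tensor update norm used for the trust ratio.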
  // Compute update norms
  auto update_norm_tuple = multi_tensor_l2norm_cuda(chunk_size, noop_flag, grad_list, true);

  std::vector<std::vector<at::Tensor>> grad_param_list(tensor_lists.begin(), tensor_lists.begin()+2);

  DISPATCH_FLOAT_AND_HALF(tensor_lists[0][0].scalar_type(), 0, "lamb_stage_2",
    multi_tensor_apply<2>(
      BLOCK_SIZE,
      chunk_size,
      noop_flag,
      grad_param_list,
      LAMBStage2Functor<scalar_t_0>(),
      std::get<1>(param_norm_tuple).DATA_PTR<float>(),
      std::get<1>(update_norm_tuple).DATA_PTR<float>(),
      lr,
      weight_decay,
      use_nvlamb); )

  AT_CUDA_CHECK(cudaGetLastError());
}
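
// A minimal sketch (an assumption, not part of this file) of how an entry point
// like this is typically exposed to Python through a pybind11 extension module;
// the module and function names here are illustrative only:
//
//   #include <torch/extension.h>
//
//   PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
//     m.def("multi_tensor_lamb", &multi_tensor_lamb_cuda,
//           "Fused LAMB optimizer step over lists of tensors");
//   }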