// multi_tensor_lamb_mp.cu

#include <ATen/ATen.h>
#include <ATen/AccumulateType.h>
#include <ATen/cuda/CUDAContext.h>
#include <ATen/cuda/Exceptions.h>
// Another possibility:
// #include <torch/all.h>

#include <assert.h>

#include "type_shim.h"
#include "multi_tensor_apply.cuh"

#define BLOCK_SIZE 512
#define ILP 4

template<typename T>
__device__ __forceinline__ bool is_aligned(T* p){
  return ((uint64_t)p) % (ILP*sizeof(T)) == 0;
}

template<typename T>
__device__ __forceinline__ void load_store(T* dst, T* src, int dst_offset, int src_offset){
  typedef typename std::aligned_storage<ILP*sizeof(T), ILP*alignof(T)>::type LT;
  ((LT*)dst)[dst_offset] = ((LT*)src)[src_offset];
}
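
// is_aligned() and load_store() support the ILP (instruction-level parallelism)
// fast path below: when a pointer is aligned to ILP*sizeof(T) bytes, ILP
// consecutive elements can be moved as one vectorized load/store instead of
// ILP scalar accesses.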

typedef enum{
  MOMENT_MODE_0 = 0, // L2 regularization mode
  MOMENT_MODE_1 = 1  // Decoupled weight decay mode
} adamMode_t;
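
// MOMENT_MODE_0 folds the weight decay into the gradient before the moment
// updates (classic L2 regularization):      g <- g/clip + decay * p
// MOMENT_MODE_1 keeps the moments decay-free and adds the decay term to the
// final update instead (decoupled, AdamW-style): update <- m_hat/denom + decay * p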

std::tuple<at::Tensor, at::Tensor> multi_tensor_l2norm_mp_cuda(
  int chunk_size,
  at::Tensor noop_flag,
  std::vector<std::vector<at::Tensor>> tensor_lists,
  at::optional<bool> per_tensor_python);

using MATH_T = float;
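// All per-element math in the functors below is done in MATH_T (fp32),
// regardless of the storage dtypes T (grads/updates) and param_t (params and moments).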

template<typename T, typename param_t>
struct LAMBStage1Functor
{
  __device__ __forceinline__ void operator()(
    int chunk_size,
    volatile int* noop_gmem,
    TensorListMetadata<4>& tl,
    const float beta1,
    const float beta2,
    const float beta3,
    const int* step_ptr,
    const int bias_correction,
    const float epsilon,
    adamMode_t mode,
    const float decay,
    const float* global_grad_norm,
    const float* max_global_grad_norm,
    const float* found_inf,
    const float* inv_scale)
  {
    if (*noop_gmem) {
      return;
    }

    float beta1_correction = 1.0f;
    float beta2_correction = 1.0f;
    if (bias_correction == 1) {
      int step = *step_ptr;
      beta1_correction = 1 - std::pow(beta1, step);
      beta2_correction = 1 - std::pow(beta2, step);
    }

    int tensor_loc = tl.block_to_tensor[blockIdx.x];
    int chunk_idx = tl.block_to_chunk[blockIdx.x];
    int n = tl.sizes[tensor_loc];

    float clipped_global_grad_norm = (*global_grad_norm) > (*max_global_grad_norm) ? (*global_grad_norm) / (*max_global_grad_norm) : 1.0f;
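    // clipped_global_grad_norm is the factor gradients are divided by: if the
    // global grad norm exceeds max_global_grad_norm, every gradient is scaled
    // down by ||g||_global / max, i.e. the global norm is clipped to max.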

    T* g = (T*)tl.addresses[0][tensor_loc];
    g += chunk_idx*chunk_size;

    param_t* p = (param_t*)tl.addresses[1][tensor_loc];
    p += chunk_idx*chunk_size;

    param_t* m = (param_t*)tl.addresses[2][tensor_loc];
    m += chunk_idx*chunk_size;

    param_t* v = (param_t*)tl.addresses[3][tensor_loc];
    v += chunk_idx*chunk_size;

    n -= chunk_idx*chunk_size;

    MATH_T r_g[ILP];
    MATH_T r_p[ILP];
    MATH_T r_m[ILP];
    MATH_T r_v[ILP];

    // to make things simple, we put aligned case in a different code path
    if(n % ILP == 0 &&
       chunk_size % ILP == 0 &&
       is_aligned(g) &&
       is_aligned(p) &&
       is_aligned(m) &&
       is_aligned(v))
    {
      T l_g[ILP];
      param_t l_p[ILP];
      param_t l_m[ILP];
      param_t l_v[ILP];
      for(int i_start = threadIdx.x; i_start*ILP < n && i_start*ILP < chunk_size; i_start += blockDim.x)
      {
        // load
        load_store(l_g, g, 0, i_start);
        if (decay != 0)
          load_store(l_p, p, 0, i_start);
        load_store(l_m, m, 0, i_start);
        load_store(l_v, v, 0, i_start);
        // unpack
        #pragma unroll
        for(int ii = 0; ii < ILP; ii++)
        {
          r_g[ii] = l_g[ii] * (*inv_scale);
          if (decay == 0) {
            r_p[ii] = MATH_T(0);
          }
          else {
            r_p[ii] = l_p[ii];
          }
          r_m[ii] = l_m[ii];
          r_v[ii] = l_v[ii];
        }
        #pragma unroll
        for(int ii = 0; ii < ILP; ii++)
        {
          if (mode == MOMENT_MODE_0) {
            MATH_T scaled_grad = r_g[ii] / clipped_global_grad_norm;
            // L2 on scaled grad
            scaled_grad = scaled_grad + decay*r_p[ii];
            r_m[ii] = r_m[ii] * beta1 + beta3 * scaled_grad;
            r_v[ii] = r_v[ii] * beta2 + (1-beta2) * scaled_grad * scaled_grad;
            MATH_T next_m_unbiased = r_m[ii] / beta1_correction;
            MATH_T next_v_unbiased = r_v[ii] / beta2_correction;
            MATH_T denom = sqrtf(next_v_unbiased) + epsilon;
            r_p[ii] = next_m_unbiased / denom;
          }
          else {
            MATH_T scaled_grad = r_g[ii] / clipped_global_grad_norm;
            r_m[ii] = r_m[ii] * beta1 + beta3 * scaled_grad;
            r_v[ii] = r_v[ii] * beta2 + (1-beta2) * scaled_grad * scaled_grad;
            MATH_T next_m_unbiased = r_m[ii] / beta1_correction;
            MATH_T next_v_unbiased = r_v[ii] / beta2_correction;
            MATH_T denom = sqrtf(next_v_unbiased) + epsilon;
            r_p[ii] = (next_m_unbiased/denom) + (decay*r_p[ii]);
          }
        }
        #pragma unroll
        for(int ii = 0; ii < ILP; ii++)
        {
          l_p[ii] = r_p[ii];
          // Difference from APEX's LAMB kernel. `g` and `p` can be different dtypes.
          l_g[ii] = r_p[ii];
          l_m[ii] = r_m[ii];
          l_v[ii] = r_v[ii];
        }
        // store
        load_store(g, l_g, i_start, 0);
        load_store(m, l_m, i_start, 0);
        load_store(v, l_v, i_start, 0);
      }
    }
    else
    {
      // see note in multi_tensor_scale_kernel.cu
      for(int i_start = 0;
          i_start < n && i_start < chunk_size;
          i_start += blockDim.x*ILP)
      {
        MATH_T r_g[ILP];
        MATH_T r_p[ILP];
        MATH_T r_m[ILP];
        MATH_T r_v[ILP];
        #pragma unroll
        for(int ii = 0; ii < ILP; ii++)
        {
          int i = i_start + threadIdx.x + ii*blockDim.x;
          if(i < n && i < chunk_size)
          {
            r_g[ii] = g[i] * (*inv_scale);
            // Special case for LAMB stage 1: when decay == 0, r_p is only ever
            // multiplied by decay, so skip loading p and use 0 instead.
            if (decay == 0) {
              r_p[ii] = MATH_T(0);
            }
            else {
              r_p[ii] = p[i];
            }
            r_m[ii] = m[i];
            r_v[ii] = v[i];
          } else {
            r_g[ii] = MATH_T(0);
            r_p[ii] = MATH_T(0);
            r_m[ii] = MATH_T(0);
            r_v[ii] = MATH_T(0);
          }
        }
        #pragma unroll
        for(int ii = 0; ii < ILP; ii++)
        {
          if (mode == MOMENT_MODE_0) {
            MATH_T scaled_grad = r_g[ii] / clipped_global_grad_norm;
            // L2 on scaled grad
            scaled_grad = scaled_grad + decay*r_p[ii];
            r_m[ii] = r_m[ii] * beta1 + beta3 * scaled_grad;
            r_v[ii] = r_v[ii] * beta2 + (1-beta2) * scaled_grad * scaled_grad;
            MATH_T next_m_unbiased = r_m[ii] / beta1_correction;
            MATH_T next_v_unbiased = r_v[ii] / beta2_correction;
            MATH_T denom = sqrtf(next_v_unbiased) + epsilon;
            r_p[ii] = next_m_unbiased / denom;
          }
          else {
            MATH_T scaled_grad = r_g[ii] / clipped_global_grad_norm;
            r_m[ii] = r_m[ii] * beta1 + beta3 * scaled_grad;
            r_v[ii] = r_v[ii] * beta2 + (1-beta2) * scaled_grad * scaled_grad;
            MATH_T next_m_unbiased = r_m[ii] / beta1_correction;
            MATH_T next_v_unbiased = r_v[ii] / beta2_correction;
            MATH_T denom = sqrtf(next_v_unbiased) + epsilon;
            r_p[ii] = (next_m_unbiased/denom) + (decay*r_p[ii]);
          }
        }
        #pragma unroll
        for(int ii = 0; ii < ILP; ii++)
        {
          int i = i_start + threadIdx.x + ii*blockDim.x;
          if(i < n && i < chunk_size)
          {
            g[i] = r_p[ii];
            m[i] = r_m[ii];
            v[i] = r_v[ii];
          }
        }
      }
    }
  }
};

// Step 2 reads in the 'update' value and the per-tensor param_norm and update_norm.
// It computes the new parameter value.
// N == 2: FP32 params, no master params
// N == 3: FP16 params, FP32 master params.
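// The applied step is p <- p - ratio * update, where
//   ratio = lr * (||p|| / ||update||)   if both norms are non-zero
//   ratio = lr                          otherwise
// and, unless use_nvlamb is set, the trust ratio is only applied to tensors
// with non-zero weight decay.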
template<typename T, int N, typename param_t>
struct LAMBStage2Functor
{
  static_assert((N == 2 && std::is_same<T, param_t>::value) || (N == 3 && std::is_same<param_t, float>::value), "");
  __device__ __forceinline__ void operator()(
    int chunk_size,
    volatile int* noop_gmem,
    TensorListMetadata<N>& tl,
    const float* per_tensor_param_norm,
    const float* per_tensor_update_norm,
    const float* learning_rate,
    const float decay,
    bool use_nvlamb)
  {
    if (*noop_gmem) {
      return;
    }

    int tensor_loc = tl.block_to_tensor[blockIdx.x];
    int tensor_num = tl.start_tensor_this_launch + tensor_loc;
    int chunk_idx = tl.block_to_chunk[blockIdx.x];
    int n = tl.sizes[tensor_loc];

    MATH_T ratio = *learning_rate;
    // nvlamb: apply adaptive learning rate to all parameters
    // otherwise, only apply to those with non-zero weight decay
    if (use_nvlamb || (decay != 0.0))
    {
      float param_norm = per_tensor_param_norm[tensor_num];
      float update_norm = per_tensor_update_norm[tensor_num];
      ratio = (update_norm != 0.0f && param_norm != 0.0f) ? *learning_rate * (param_norm / update_norm) : *learning_rate;
    }

    T* update = (T*)tl.addresses[0][tensor_loc];
    update += chunk_idx*chunk_size;

    param_t* p = (param_t*)tl.addresses[1][tensor_loc];
    p += chunk_idx*chunk_size;

    T* out_p;
    if (N == 3) {
      out_p = (T*)tl.addresses[2][tensor_loc];
      out_p += chunk_idx*chunk_size;
    }

    n -= chunk_idx*chunk_size;

    // to make things simple, we put aligned case in a different code path
    bool can_use_aligned_path = n % ILP == 0 && chunk_size % ILP == 0 && is_aligned(p) && is_aligned(update);
    if (N == 3) {
      can_use_aligned_path = can_use_aligned_path && is_aligned(out_p);
    }
    if(can_use_aligned_path)
    {
      param_t r_p[ILP];
      T r_update[ILP];
      T r_out_p[ILP];
      for(int i_start = threadIdx.x; i_start*ILP < n && i_start*ILP < chunk_size; i_start += blockDim.x)
      {
        // load
        load_store(r_p, p, 0, i_start);
        load_store(r_update, update, 0, i_start);
        if (N == 3) {
          load_store(r_out_p, out_p, 0, i_start);
        }
        #pragma unroll
        for(int ii = 0; ii < ILP; ii++)
        {
          r_p[ii] = static_cast<MATH_T>(r_p[ii]) - (ratio * static_cast<MATH_T>(r_update[ii]));
          if (N == 3) {
            r_out_p[ii] = r_p[ii];
          }
        }
        load_store(p, r_p, i_start, 0);
        if (N == 3) {
          load_store(out_p, r_out_p, i_start, 0);
        }
      }
    }
    else
    {
      for(int i_start = 0;
          i_start < n && i_start < chunk_size;
          i_start += blockDim.x*ILP)
      {
        MATH_T r_p[ILP];
        MATH_T r_update[ILP];
        #pragma unroll
        for(int ii = 0; ii < ILP; ii++)
        {
          int i = i_start + threadIdx.x + ii*blockDim.x;
          if(i < n && i < chunk_size)
          {
            r_p[ii] = p[i];
            r_update[ii] = update[i];
          }
        }
        #pragma unroll
        for(int ii = 0; ii < ILP; ii++)
        {
          r_p[ii] = r_p[ii] - (ratio * r_update[ii]);
        }
        #pragma unroll
        for(int ii = 0; ii < ILP; ii++)
        {
          int i = i_start + threadIdx.x + ii*blockDim.x;
          if(i < n && i < chunk_size)
          {
            p[i] = r_p[ii];
            if (N == 3) {
              out_p[i] = r_p[ii];
            }
          }
        }
      }
    }
  }
};
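
// Host entry point. Expected tensor_lists layout (inferred from the kernels above):
//   tensor_lists[0]: gradients, overwritten in-place with the LAMB update
//   tensor_lists[1]: parameters (FP32 master copies when 5 lists are passed)
//   tensor_lists[2]: Adam first moment m
//   tensor_lists[3]: Adam second moment v
//   tensor_lists[4]: (optional) low-precision model params written out in stage 2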
void multi_tensor_lamb_mp_cuda(
  int chunk_size,
  at::Tensor noop_flag,
  std::vector<std::vector<at::Tensor>> tensor_lists,
  at::Tensor lr,
  const float beta1,
  const float beta2,
  const float epsilon,
  at::Tensor step,
  const int bias_correction,
  const float weight_decay,
  const int grad_averaging,
  const int mode,
  at::Tensor global_grad_norm,
  at::Tensor max_grad_norm,
  at::optional<bool> use_nvlamb_python,
  at::Tensor found_inf,
  at::Tensor inv_scale)
{
  // n_tensors == 5: FP16 model params & FP32 master params
  // n_tensors == 4: FP32 model params & NO FP32 master params
  const auto n_tensors = tensor_lists.size();
  assert(n_tensors == 4 || n_tensors == 5);
  using namespace at;

  bool use_nvlamb = use_nvlamb_python.has_value() ? use_nvlamb_python.value() : false;

  // note(mkozuki): move bias handling below to functor
  // Handle bias correction mode
  // float bias_correction1 = 1.0f, bias_correction2 = 1.0f;
  // if (bias_correction == 1) {
  //   bias_correction1 = 1 - std::pow(beta1, step);
  //   bias_correction2 = 1 - std::pow(beta2, step);
  // }

  // Handle grad averaging mode
  float beta3 = 1.0f;
  if (grad_averaging == 1) beta3 = 1 - beta1;

  std::vector<std::vector<at::Tensor>> stage1_tensor_lists(tensor_lists.begin(), tensor_lists.begin() + 4);
  std::vector<std::vector<at::Tensor>> grad_list(tensor_lists.begin(), tensor_lists.begin()+1);
  std::vector<std::vector<at::Tensor>> param_list(tensor_lists.begin()+1, tensor_lists.begin()+2);
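  // grad_list and param_list are one-list views (gradients and parameters,
  // respectively) used only for the per-tensor L2 norm computations below.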

  // Compute per-tensor param norm
  auto param_norm_tuple = multi_tensor_l2norm_mp_cuda(chunk_size, noop_flag, param_list, true);

  // We now modify grad in-place to store the update before computing its norm.
  // Generally this is not an issue since people modify grad in the step() method all the time.
  // We could also grab a list of empty tensors to avoid this, but I'd like to save space/CPU code.
  if (n_tensors == 4) {
    DISPATCH_FLOAT_AND_HALF(tensor_lists[0][0].scalar_type(), 0, "lamb_stage_1",
      multi_tensor_apply<4>(
        BLOCK_SIZE,
        chunk_size,
        noop_flag,
        stage1_tensor_lists,
        LAMBStage1Functor<scalar_t_0, scalar_t_0>(),
        beta1,
        beta2,
        beta3, // 1-beta1 or 1 depends on averaging mode
        // bias_correction1,
        // bias_correction2,
        step.data_ptr<int>(),
        bias_correction,
        epsilon,
        (adamMode_t) mode,
        weight_decay,
        global_grad_norm.data_ptr<float>(),
        max_grad_norm.data_ptr<float>(),
        found_inf.data_ptr<float>(),
        inv_scale.data_ptr<float>()); )
  } else {
    DISPATCH_FLOAT_HALF_AND_BFLOAT(tensor_lists[0][0].scalar_type(), 0, "lamb_stage_1",
      multi_tensor_apply<4>(
        BLOCK_SIZE,
        chunk_size,
        noop_flag,
        stage1_tensor_lists,
        LAMBStage1Functor<scalar_t_0, float>(),
        beta1,
        beta2,
        beta3, // 1-beta1 or 1 depends on averaging mode
        // bias_correction1,
        // bias_correction2,
        step.data_ptr<int>(),
        bias_correction,
        epsilon,
        (adamMode_t) mode,
        weight_decay,
        global_grad_norm.data_ptr<float>(),
        max_grad_norm.data_ptr<float>(),
        found_inf.data_ptr<float>(),
        inv_scale.data_ptr<float>()); )
  }

  // Compute update norms
  auto update_norm_tuple = multi_tensor_l2norm_mp_cuda(chunk_size, noop_flag, grad_list, true);

  std::vector<std::vector<at::Tensor>> grad_param_list(tensor_lists.begin(), tensor_lists.begin()+2);

  if (n_tensors == 4) {
    DISPATCH_FLOAT_AND_HALF(tensor_lists[0][0].scalar_type(), 0, "lamb_stage_2",
      multi_tensor_apply<2>(
        BLOCK_SIZE,
        chunk_size,
        noop_flag,
        grad_param_list,
        LAMBStage2Functor<scalar_t_0, 2, scalar_t_0>(),
        std::get<1>(param_norm_tuple).data_ptr<float>(),
        std::get<1>(update_norm_tuple).data_ptr<float>(),
        lr.data_ptr<float>(),
        weight_decay,
        use_nvlamb); )
  } else {
    grad_param_list.push_back(tensor_lists[4]);
    DISPATCH_FLOAT_HALF_AND_BFLOAT(tensor_lists[0][0].scalar_type(), 0, "lamb_stage_2",
      multi_tensor_apply<3>(
        BLOCK_SIZE,
        chunk_size,
        noop_flag,
        grad_param_list,
        LAMBStage2Functor<scalar_t_0, 3, float>(),
        std::get<1>(param_norm_tuple).data_ptr<float>(),
        std::get<1>(update_norm_tuple).data_ptr<float>(),
        lr.data_ptr<float>(),
        weight_decay,
        use_nvlamb); )
  }

  AT_CUDA_CHECK(cudaGetLastError());
}
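
// A minimal sketch of how this entry point might be exposed to Python via a
// pybind11 extension module. Bindings typically live in a separate frontend
// translation unit; the module exposure shown here is only illustrative and
// the exported name "multi_tensor_lamb_mp" is an assumption, not part of this file:
//
//   #include <torch/extension.h>
//
//   void multi_tensor_lamb_mp_cuda(...);  // declaration matching the definition above
//
//   PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
//     m.def("multi_tensor_lamb_mp", &multi_tensor_lamb_mp_cuda,
//           "Fused mixed-precision multi-tensor LAMB step");
//   }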