// multi_tensor_adam.cu - fused multi-tensor Adam / AdamW CUDA functors and host launchers.
  1. #include <ATen/ATen.h>
  2. #include <ATen/AccumulateType.h>
  3. #include <ATen/cuda/CUDAContext.h>
  4. #include <ATen/cuda/Exceptions.h>
  5. // Another possibility:
  6. // #include <torch/all.h>
  7. #include <assert.h>
  8. #include "type_shim.h"
  9. #include "multi_tensor_apply.cuh"
  // Launch / unroll constants shared by the functors and host launchers in this file.
  // BLOCK_SIZE is handed to multi_tensor_apply as the block size.
  #define BLOCK_SIZE 512
  // ILP is the number of elements each thread handles per loop iteration.
  #define ILP 4

  // Selects how weight decay enters the Adam update.
  enum adamMode_t {
    ADAM_MODE_0 = 0,  // L2 regularization mode: decay * p is added to the gradient
    ADAM_MODE_1 = 1   // Decoupled weight decay mode (AdamW): decay * p is added to the update
  };

  // Arithmetic type used for the per-element register math inside the functors.
  typedef float MATH_T;
  // Adam / AdamW update for one chunk of one tensor. The block looks up its
  // tensor and chunk through tl.block_to_tensor / tl.block_to_chunk using
  // blockIdx.x, so it is meant to be driven by multi_tensor_apply (see the host
  // launcher multi_tensor_adam_cuda in this file).
  //
  // Template parameters:
  //   T       - element type of the gradient (tl.addresses[0]) and parameter
  //             (tl.addresses[1]) buffers.
  //   FULL_T  - element type of the first / second moment buffers
  //             (tl.addresses[2] / tl.addresses[3]).
  //   index_t - integer type used for sizes, chunk offsets and element indices.
  //
  // Arguments:
  //   chunk_size               - number of elements per chunk.
  //   noop_gmem                - no-op flag; deliberately NOT consulted here (see below).
  //   tl                       - per-launch metadata (addresses, sizes, block maps).
  //   beta1, beta2             - exponential decay rates of the two moments.
  //   beta1_correction,
  //   beta2_correction         - divisors applied to the moments (bias correction).
  //   epsilon                  - added to sqrt(v_hat) in the denominator.
  //   lr                       - learning rate.
  //   mode                     - ADAM_MODE_0 (L2) or anything else (decoupled decay).
  //   decay                    - weight decay coefficient.
  //
  // Gradients are only read; p, m and v are updated in place. All arithmetic is
  // carried out in MATH_T regardless of T / FULL_T.
  template<typename T, typename FULL_T, typename index_t>
  struct AdamFunctor
  {
    __device__ __forceinline__ void operator()(
      index_t chunk_size,
      volatile int* noop_gmem,
      TensorListMetadata<4>& tl,
      const float beta1,
      const float beta2,
      const float beta1_correction,
      const float beta2_correction,
      const float epsilon,
      const float lr,
      adamMode_t mode,
      const float decay)
    {
      // I'd like this kernel to propagate infs/nans.
      // if(*noop_gmem == 1)
      //   return;

      index_t tensor_loc = tl.block_to_tensor[blockIdx.x];

      // potentially use to pass in list of scalar
      // int tensor_num = tl.start_tensor_this_launch + tensor_loc;

      index_t chunk_idx = tl.block_to_chunk[blockIdx.x];
      index_t n = tl.sizes[tensor_loc];

      // Advance every buffer to the start of this block's chunk.
      T* g = (T*)tl.addresses[0][tensor_loc];
      g += chunk_idx*chunk_size;

      T* p = (T*)tl.addresses[1][tensor_loc];
      p += chunk_idx*chunk_size;

      FULL_T* m = (FULL_T*)tl.addresses[2][tensor_loc];
      m += chunk_idx*chunk_size;

      FULL_T* v = (FULL_T*)tl.addresses[3][tensor_loc];
      v += chunk_idx*chunk_size;

      // n becomes the number of elements remaining from the chunk start.
      n -= chunk_idx*chunk_size;

      // see note in multi_tensor_scale_kernel.cu
      for(index_t i_start = 0;
          i_start < n && i_start < chunk_size;
          i_start += blockDim.x*ILP)
      {
        MATH_T r_g[ILP];
        MATH_T r_p[ILP];
        MATH_T r_m[ILP];
        MATH_T r_v[ILP];

        // Stage 1: load ILP elements per thread; out-of-range slots are zeroed
        // so the math stage can run unconditionally (they are never stored).
  #pragma unroll
        for(int ii = 0; ii < ILP; ii++)
        {
          // index_t (not int) so the 64-bit instantiation never narrows the index.
          const index_t i = i_start + threadIdx.x + ii*blockDim.x;
          if(i < n && i < chunk_size)
          {
            r_g[ii] = g[i];
            r_p[ii] = p[i];
            r_m[ii] = m[i];
            r_v[ii] = v[i];
          } else {
            r_g[ii] = MATH_T(0);
            r_p[ii] = MATH_T(0);
            r_m[ii] = MATH_T(0);
            r_v[ii] = MATH_T(0);
          }
        }

        // Stage 2: the Adam update, entirely in registers.
  #pragma unroll
        for(int ii = 0; ii < ILP; ii++)
        {
          if(mode == ADAM_MODE_0) { // L2
            // L2: fold decay * p into the gradient before the moment updates.
            r_g[ii] = r_g[ii] + (decay * r_p[ii]);
            r_m[ii] = beta1 * r_m[ii] + (1-beta1) * r_g[ii];
            r_v[ii] = beta2 * r_v[ii] + (1-beta2) * r_g[ii] * r_g[ii];
            MATH_T next_m_unbiased = r_m[ii] / beta1_correction;
            MATH_T next_v_unbiased = r_v[ii] / beta2_correction;
            MATH_T denom = sqrtf(next_v_unbiased) + epsilon;
            MATH_T update = next_m_unbiased / denom;
            r_p[ii] = r_p[ii] - (lr * update);
          }
          else { // weight decay
            // Decoupled: moments see the raw gradient; decay * p joins the update.
            r_m[ii] = beta1 * r_m[ii] + (1-beta1) * r_g[ii];
            r_v[ii] = beta2 * r_v[ii] + (1-beta2) * r_g[ii] * r_g[ii];
            MATH_T next_m_unbiased = r_m[ii] / beta1_correction;
            MATH_T next_v_unbiased = r_v[ii] / beta2_correction;
            MATH_T denom = sqrtf(next_v_unbiased) + epsilon;
            MATH_T update = (next_m_unbiased / denom) + (decay * r_p[ii]);
            r_p[ii] = r_p[ii] - (lr * update);
          }
        }

        // Stage 3: write back parameter and both moments for in-range elements.
  #pragma unroll
        for(int ii = 0; ii < ILP; ii++)
        {
          const index_t i = i_start + threadIdx.x + ii*blockDim.x;
          if(i < n && i < chunk_size)
          {
            p[i] = r_p[ii];
            m[i] = r_m[ii];
            v[i] = r_v[ii];
          }
        }
      }
    }
  };
  // Adam / AdamW update for one chunk of one tensor, variant whose step count,
  // learning rate and gradient inverse-scale are read through pointers inside the
  // kernel instead of being passed by value.
  // NOTE(review): presumably this is what makes the launch CUDA-graph capturable,
  // as the name suggests - confirm against the Python caller.
  //
  // Template parameters:
  //   T      - element type of the gradient (tl.addresses[0]) and parameter
  //            (tl.addresses[1]) buffers.
  //   FULL_T - element type of the moment buffers (tl.addresses[2] / [3]).
  //
  // Arguments:
  //   chunk_size      - number of elements per chunk.
  //   noop_gmem       - if *noop_gmem == 1 the functor returns without touching memory.
  //   tl              - per-launch metadata (addresses, sizes, block maps).
  //   beta1, beta2    - exponential decay rates of the two moments.
  //   step            - pointer to the step count used for bias correction.
  //   bias_correction - 1 enables the (1 - beta^step) divisors, otherwise they are 1.
  //   epsilon         - added to sqrt(v_hat) in the denominator.
  //   lr              - pointer to the learning rate.
  //   mode            - ADAM_MODE_0 (L2) or anything else (decoupled decay).
  //   decay           - weight decay coefficient.
  //   inv_scale       - pointer to the factor every gradient is multiplied by.
  //
  // Side effects: g is overwritten with the rescaled gradient; p, m and v are
  // updated in place. All arithmetic is carried out in MATH_T.
  template<typename T, typename FULL_T>
  struct AdamCapturableFunctor
  {
    __device__ __forceinline__ void operator()(
      int chunk_size,
      volatile int* noop_gmem,
      TensorListMetadata<4>& tl,
      const float beta1,
      const float beta2,
      const int* step,
      const int bias_correction,
      const float epsilon,
      const float* lr,
      adamMode_t mode,
      const float decay,
      const float* inv_scale)
    {
      if(*noop_gmem == 1)
        return;

      float beta1_correction = 1.0f, beta2_correction = 1.0f;
      if (bias_correction == 1) {
        beta1_correction = 1 - pow(beta1, *step);
        beta2_correction = 1 - pow(beta2, *step);
      }

      // Both scalars are constant for the whole launch. Read them once instead of
      // once per element: the stores through g/p/m/v below may alias them as far
      // as the compiler can tell, which blocks automatic hoisting.
      const float lr_val = *lr;
      const float grad_inv_scale = *inv_scale;

      int tensor_loc = tl.block_to_tensor[blockIdx.x];

      // potentially use to pass in list of scalar
      // int tensor_num = tl.start_tensor_this_launch + tensor_loc;

      int chunk_idx = tl.block_to_chunk[blockIdx.x];
      int n = tl.sizes[tensor_loc];

      // Advance every buffer to the start of this block's chunk.
      T* g = (T*)tl.addresses[0][tensor_loc];
      g += chunk_idx*chunk_size;

      T* p = (T*)tl.addresses[1][tensor_loc];
      p += chunk_idx*chunk_size;

      FULL_T* m = (FULL_T*)tl.addresses[2][tensor_loc];
      m += chunk_idx*chunk_size;

      FULL_T* v = (FULL_T*)tl.addresses[3][tensor_loc];
      v += chunk_idx*chunk_size;

      // n becomes the number of elements remaining from the chunk start.
      n -= chunk_idx*chunk_size;

      // see note in multi_tensor_scale_kernel.cu
      for(int i_start = 0;
          i_start < n && i_start < chunk_size;
          i_start += blockDim.x*ILP)
      {
        MATH_T r_g[ILP];
        MATH_T r_p[ILP];
        MATH_T r_m[ILP];
        MATH_T r_v[ILP];

        // Stage 1: load, rescale the gradient and store the rescaled value back.
  #pragma unroll
        for(int ii = 0; ii < ILP; ii++)
        {
          int i = i_start + threadIdx.x + ii*blockDim.x;
          if(i < n && i < chunk_size)
          {
            r_g[ii] = static_cast<MATH_T>(g[i]) * grad_inv_scale;
            g[i] = static_cast<T>(r_g[ii]);
            r_p[ii] = static_cast<MATH_T>(p[i]);
            r_m[ii] = static_cast<MATH_T>(m[i]);
            r_v[ii] = static_cast<MATH_T>(v[i]);
          } else {
            r_g[ii] = MATH_T(0);
            r_p[ii] = MATH_T(0);
            r_m[ii] = MATH_T(0);
            r_v[ii] = MATH_T(0);
          }
        }

        // Stage 2: the Adam update, entirely in registers.
  #pragma unroll
        for(int ii = 0; ii < ILP; ii++)
        {
          if(mode == ADAM_MODE_0) { // L2
            // L2: fold decay * p into the gradient before the moment updates.
            r_g[ii] = r_g[ii] + (decay * r_p[ii]);
            r_m[ii] = beta1 * r_m[ii] + (1-beta1) * r_g[ii];
            r_v[ii] = beta2 * r_v[ii] + (1-beta2) * r_g[ii] * r_g[ii];
            MATH_T next_m_unbiased = r_m[ii] / beta1_correction;
            MATH_T next_v_unbiased = r_v[ii] / beta2_correction;
            MATH_T denom = sqrtf(next_v_unbiased) + epsilon;
            MATH_T update = next_m_unbiased / denom;
            r_p[ii] = r_p[ii] - (lr_val * update);
          }
          else { // weight decay
            // Decoupled: moments see the raw gradient; decay * p joins the update.
            r_m[ii] = beta1 * r_m[ii] + (1-beta1) * r_g[ii];
            r_v[ii] = beta2 * r_v[ii] + (1-beta2) * r_g[ii] * r_g[ii];
            MATH_T next_m_unbiased = r_m[ii] / beta1_correction;
            MATH_T next_v_unbiased = r_v[ii] / beta2_correction;
            MATH_T denom = sqrtf(next_v_unbiased) + epsilon;
            MATH_T update = (next_m_unbiased / denom) + (decay * r_p[ii]);
            r_p[ii] = r_p[ii] - (lr_val * update);
          }
        }

        // Stage 3: write back. m and v live in FULL_T buffers, so convert to
        // FULL_T directly; going through T first would round the moments to the
        // (possibly lower-precision) parameter type before storing them.
  #pragma unroll
        for(int ii = 0; ii < ILP; ii++)
        {
          int i = i_start + threadIdx.x + ii*blockDim.x;
          if(i < n && i < chunk_size)
          {
            p[i] = static_cast<T>(r_p[ii]);
            m[i] = static_cast<FULL_T>(r_m[ii]);
            v[i] = static_cast<FULL_T>(r_v[ii]);
          }
        }
      }
    }
  };
  // Adam / AdamW update for one chunk of one tensor, variant that keeps a second
  // copy of the parameters in FULL_T (tl.addresses[4]). The update is computed
  // from that copy and written to both the FULL_T copy and the T parameter buffer.
  // Step count, learning rate and gradient inverse-scale are read through
  // pointers inside the kernel.
  //
  // Template parameters:
  //   T      - element type of the gradient (tl.addresses[0]) and parameter
  //            (tl.addresses[1]) buffers.
  //   FULL_T - element type of the moment buffers (tl.addresses[2] / [3]) and of
  //            the master parameter buffer (tl.addresses[4]).
  //
  // Arguments:
  //   chunk_size      - number of elements per chunk.
  //   noop_gmem       - if *noop_gmem == 1 the functor returns without touching memory.
  //   tl              - per-launch metadata (addresses, sizes, block maps).
  //   beta1, beta2    - exponential decay rates of the two moments.
  //   step            - pointer to the step count used for bias correction.
  //   bias_correction - 1 enables the (1 - beta^step) divisors, otherwise they are 1.
  //   epsilon         - added to sqrt(v_hat) in the denominator.
  //   lr              - pointer to the learning rate.
  //   mode            - ADAM_MODE_0 (L2) or anything else (decoupled decay).
  //   decay           - weight decay coefficient.
  //   inv_scale       - pointer to the factor every gradient is multiplied by.
  //
  // Side effects: g is overwritten with the rescaled gradient; p, p_master, m and
  // v are updated in place. The previous contents of p are never read.
  template<typename T, typename FULL_T>
  struct AdamCapturableMasterFunctor
  {
    __device__ __forceinline__ void operator()(
      int chunk_size,
      volatile int* noop_gmem,
      TensorListMetadata<5>& tl,
      const float beta1,
      const float beta2,
      const int* step,
      const int bias_correction,
      const float epsilon,
      const float* lr,
      adamMode_t mode,
      const float decay,
      const float* inv_scale)
    {
      if(*noop_gmem == 1)
        return;

      float beta1_correction = 1.0f, beta2_correction = 1.0f;
      if (bias_correction == 1) {
        beta1_correction = 1 - pow(beta1, *step);
        beta2_correction = 1 - pow(beta2, *step);
      }

      // Both scalars are constant for the whole launch. Read them once instead of
      // once per element: the stores through g/p/m/v below may alias them as far
      // as the compiler can tell, which blocks automatic hoisting.
      const float lr_val = *lr;
      const float grad_inv_scale = *inv_scale;

      int tensor_loc = tl.block_to_tensor[blockIdx.x];

      // potentially use to pass in list of scalar
      // int tensor_num = tl.start_tensor_this_launch + tensor_loc;

      int chunk_idx = tl.block_to_chunk[blockIdx.x];
      int n = tl.sizes[tensor_loc];

      // Advance every buffer to the start of this block's chunk.
      T* g = (T*)tl.addresses[0][tensor_loc];
      g += chunk_idx*chunk_size;

      T* p = (T*)tl.addresses[1][tensor_loc];
      p += chunk_idx*chunk_size;

      FULL_T* m = (FULL_T*)tl.addresses[2][tensor_loc];
      m += chunk_idx*chunk_size;

      FULL_T* v = (FULL_T*)tl.addresses[3][tensor_loc];
      v += chunk_idx*chunk_size;

      FULL_T* p_master = (FULL_T*)tl.addresses[4][tensor_loc];
      p_master += chunk_idx*chunk_size;

      // n becomes the number of elements remaining from the chunk start.
      n -= chunk_idx*chunk_size;

      // see note in multi_tensor_scale_kernel.cu
      for(int i_start = 0;
          i_start < n && i_start < chunk_size;
          i_start += blockDim.x*ILP)
      {
        MATH_T r_g[ILP];
        MATH_T r_p[ILP];
        MATH_T r_m[ILP];
        MATH_T r_v[ILP];

        // Stage 1: load, rescale the gradient and store the rescaled value back.
        // The parameter value comes from the master copy, not from p.
  #pragma unroll
        for(int ii = 0; ii < ILP; ii++)
        {
          int i = i_start + threadIdx.x + ii*blockDim.x;
          if(i < n && i < chunk_size)
          {
            r_g[ii] = static_cast<MATH_T>(g[i]) * grad_inv_scale;
            g[i] = static_cast<T>(r_g[ii]);
            r_p[ii] = static_cast<MATH_T>(p_master[i]);
            r_m[ii] = static_cast<MATH_T>(m[i]);
            r_v[ii] = static_cast<MATH_T>(v[i]);
          } else {
            r_g[ii] = MATH_T(0);
            r_p[ii] = MATH_T(0);
            r_m[ii] = MATH_T(0);
            r_v[ii] = MATH_T(0);
          }
        }

        // Stage 2: the Adam update, entirely in registers.
  #pragma unroll
        for(int ii = 0; ii < ILP; ii++)
        {
          if(mode == ADAM_MODE_0) { // L2
            // L2: fold decay * p into the gradient before the moment updates.
            r_g[ii] = r_g[ii] + (decay * r_p[ii]);
            r_m[ii] = beta1 * r_m[ii] + (1-beta1) * r_g[ii];
            r_v[ii] = beta2 * r_v[ii] + (1-beta2) * r_g[ii] * r_g[ii];
            MATH_T next_m_unbiased = r_m[ii] / beta1_correction;
            MATH_T next_v_unbiased = r_v[ii] / beta2_correction;
            MATH_T denom = sqrtf(next_v_unbiased) + epsilon;
            MATH_T update = next_m_unbiased / denom;
            r_p[ii] = r_p[ii] - (lr_val * update);
          }
          else { // weight decay
            // Decoupled: moments see the raw gradient; decay * p joins the update.
            r_m[ii] = beta1 * r_m[ii] + (1-beta1) * r_g[ii];
            r_v[ii] = beta2 * r_v[ii] + (1-beta2) * r_g[ii] * r_g[ii];
            MATH_T next_m_unbiased = r_m[ii] / beta1_correction;
            MATH_T next_v_unbiased = r_v[ii] / beta2_correction;
            MATH_T denom = sqrtf(next_v_unbiased) + epsilon;
            MATH_T update = (next_m_unbiased / denom) + (decay * r_p[ii]);
            r_p[ii] = r_p[ii] - (lr_val * update);
          }
        }

        // Stage 3: write the new parameter to both copies, then the moments.
  #pragma unroll
        for(int ii = 0; ii < ILP; ii++)
        {
          int i = i_start + threadIdx.x + ii*blockDim.x;
          if(i < n && i < chunk_size)
          {
            p[i] = static_cast<T>(r_p[ii]);
            p_master[i] = static_cast<FULL_T>(r_p[ii]);
            m[i] = static_cast<FULL_T>(r_m[ii]);
            v[i] = static_cast<FULL_T>(r_v[ii]);
          }
        }
      }
    }
  };
  // Host launcher for AdamFunctor.
  //
  // Computes the bias-correction divisors on the host, decides whether any tensor
  // is large enough to need 64-bit indexing, and dispatches multi_tensor_apply<4>
  // on the dtype of tensor_lists[0][0].
  //
  // Arguments:
  //   chunk_size      - elements per chunk, forwarded to multi_tensor_apply.
  //   noop_flag       - forwarded to multi_tensor_apply (AdamFunctor ignores it).
  //   tensor_lists    - four lists; AdamFunctor reads them as g, p, m, v in that order.
  //   lr, beta1, beta2, epsilon, weight_decay - Adam hyper-parameters.
  //   step            - step count used for bias correction.
  //   mode            - cast to adamMode_t (0: L2, 1: decoupled weight decay).
  //   bias_correction - 1 enables bias correction; any other value disables it.
  //
  // NOTE(review): tensor_lists[0][0] is read unconditionally, so this assumes the
  // first list is non-empty - confirm the caller guarantees that.
  void multi_tensor_adam_cuda(
    int chunk_size,
    at::Tensor noop_flag,
    std::vector<std::vector<at::Tensor>> tensor_lists,
    const float lr,
    const float beta1,
    const float beta2,
    const float epsilon,
    const int step,
    const int mode,
    const int bias_correction,
    const float weight_decay)
  {
    using namespace at;

    // Divisors stay at 1 (no correction) unless bias correction is requested.
    float bias_correction1 = 1.0f;
    float bias_correction2 = 1.0f;
    if (bias_correction == 1) {
      bias_correction1 = 1 - std::pow(beta1, step);
      bias_correction2 = 1 - std::pow(beta2, step);
    }

    // 64-bit indexing is needed as soon as any tensor in any list holds at
    // least INT_MAX elements.
    const bool requires_64bit_indexing = [&tensor_lists]() {
      for (const auto& list : tensor_lists) {
        for (const auto& tensor : list) {
          if (tensor.numel() >= INT_MAX) {
            return true;
          }
        }
      }
      return false;
    }();

    const adamMode_t adam_mode = static_cast<adamMode_t>(mode);

    if (!requires_64bit_indexing) {
      // Common case: 32-bit indices.
      // Assume single type across p,g,m1,m2 now
      DISPATCH_DOUBLE_FLOAT_HALF_AND_BFLOAT(
        tensor_lists[0][0].scalar_type(), 0, "adam",
        multi_tensor_apply<4>(
          BLOCK_SIZE,
          chunk_size,
          noop_flag,
          tensor_lists,
          AdamFunctor<scalar_t_0, float, int32_t>(),
          beta1,
          beta2,
          bias_correction1,
          bias_correction2,
          epsilon,
          lr,
          adam_mode,
          weight_decay); )
    } else {
      // Large tensors: widen block size, chunk size and the functor's index type.
      // Assume single type across p,g,m1,m2 now
      DISPATCH_DOUBLE_FLOAT_HALF_AND_BFLOAT(
        tensor_lists[0][0].scalar_type(), 0, "adam",
        multi_tensor_apply<4>(
          static_cast<int64_t>(BLOCK_SIZE),
          static_cast<int64_t>(chunk_size),
          noop_flag,
          tensor_lists,
          AdamFunctor<scalar_t_0, float, int64_t>(),
          beta1,
          beta2,
          bias_correction1,
          bias_correction2,
          epsilon,
          lr,
          adam_mode,
          weight_decay); )
    }

    // Surface launch-configuration errors from the launch above.
    AT_CUDA_CHECK(cudaGetLastError());
  }
  // Host launcher for AdamCapturableFunctor.
  //
  // Unlike multi_tensor_adam_cuda, lr, step and inv_scale arrive as tensors and
  // only their data pointers are handed to the functor, which dereferences them
  // on the device. Bias correction is therefore computed inside the functor.
  //
  // Arguments:
  //   chunk_size      - elements per chunk, forwarded to multi_tensor_apply.
  //   noop_flag       - forwarded to multi_tensor_apply; the functor returns
  //                     early when the flag it receives is 1.
  //   tensor_lists    - four lists; the functor reads them as g, p, m, v.
  //   lr              - tensor accessed as float.
  //   step            - tensor accessed as int.
  //   inv_scale       - tensor accessed as float; multiplies every gradient.
  //   beta1, beta2, epsilon, weight_decay - Adam hyper-parameters.
  //   mode            - cast to adamMode_t (0: L2, 1: decoupled weight decay).
  //   bias_correction - 1 enables bias correction inside the functor.
  //
  // NOTE(review): tensor_lists[0][0] is read unconditionally, and this path has no
  // 64-bit-indexing branch - confirm callers keep tensors below INT_MAX elements.
  void multi_tensor_adam_capturable_cuda(
    int chunk_size,
    at::Tensor noop_flag,
    std::vector<std::vector<at::Tensor>> tensor_lists,
    at::Tensor lr,
    const float beta1,
    const float beta2,
    const float epsilon,
    at::Tensor step,
    const int mode,
    const int bias_correction,
    const float weight_decay,
    at::Tensor inv_scale)
  {
    using namespace at;

    const adamMode_t adam_mode = static_cast<adamMode_t>(mode);

    // Dispatch on the dtype of the first tensor of the first list.
    DISPATCH_DOUBLE_FLOAT_HALF_AND_BFLOAT(
      tensor_lists[0][0].scalar_type(), 0, "adam",
      multi_tensor_apply<4>(
        BLOCK_SIZE,
        chunk_size,
        noop_flag,
        tensor_lists,
        AdamCapturableFunctor<scalar_t_0, float>(),
        beta1,
        beta2,
        step.data_ptr<int>(),
        bias_correction,
        epsilon,
        lr.data_ptr<float>(),
        adam_mode,
        weight_decay,
        inv_scale.data_ptr<float>()); )

    // Surface launch-configuration errors from the launch above.
    AT_CUDA_CHECK(cudaGetLastError());
  }
  // Host launcher for AdamCapturableMasterFunctor.
  //
  // Same calling convention as multi_tensor_adam_capturable_cuda, but dispatches
  // multi_tensor_apply<5>: the functor reads a fifth list as the FULL_T (float)
  // master copy of the parameters.
  //
  // Arguments:
  //   chunk_size      - elements per chunk, forwarded to multi_tensor_apply.
  //   noop_flag       - forwarded to multi_tensor_apply; the functor returns
  //                     early when the flag it receives is 1.
  //   tensor_lists    - five lists; the functor reads them as g, p, m, v, p_master.
  //   lr              - tensor accessed as float.
  //   step            - tensor accessed as int.
  //   inv_scale       - tensor accessed as float; multiplies every gradient.
  //   beta1, beta2, epsilon, weight_decay - Adam hyper-parameters.
  //   mode            - cast to adamMode_t (0: L2, 1: decoupled weight decay).
  //   bias_correction - 1 enables bias correction inside the functor.
  //
  // NOTE(review): tensor_lists[0][0] is read unconditionally, and this path has no
  // 64-bit-indexing branch - confirm callers keep tensors below INT_MAX elements.
  void multi_tensor_adam_capturable_master_cuda(
    int chunk_size,
    at::Tensor noop_flag,
    std::vector<std::vector<at::Tensor>> tensor_lists,
    at::Tensor lr,
    const float beta1,
    const float beta2,
    const float epsilon,
    at::Tensor step,
    const int mode,
    const int bias_correction,
    const float weight_decay,
    at::Tensor inv_scale)
  {
    using namespace at;

    const adamMode_t adam_mode = static_cast<adamMode_t>(mode);

    // Dispatch on the dtype of the first tensor of the first list.
    DISPATCH_DOUBLE_FLOAT_HALF_AND_BFLOAT(
      tensor_lists[0][0].scalar_type(), 0, "adam",
      multi_tensor_apply<5>(
        BLOCK_SIZE,
        chunk_size,
        noop_flag,
        tensor_lists,
        AdamCapturableMasterFunctor<scalar_t_0, float>(),
        beta1,
        beta2,
        step.data_ptr<int>(),
        bias_correction,
        epsilon,
        lr.data_ptr<float>(),
        adam_mode,
        weight_decay,
        inv_scale.data_ptr<float>()); )

    // Surface launch-configuration errors from the launch above.
    AT_CUDA_CHECK(cudaGetLastError());
  }