3
0

multi_tensor_sgd_kernel.cu 7.9 KB


  1. #include <ATen/ATen.h>
  2. #include <ATen/AccumulateType.h>
  3. #include <ATen/cuda/CUDAContext.h>
  4. #include <ATen/cuda/Exceptions.h>
  5. #include "multi_tensor_apply.cuh"
  6. #include "compat.h"
  7. #include <assert.h>
  8. #include <cuda_runtime.h>
  9. #define BLOCK_SIZE 512
  10. #define ILP 4
  /**
   * Perform fused SGD on multiple buffers.
   *
   * Invoked per CUDA block by multi_tensor_apply: block blockIdx.x handles
   * chunk tl.block_to_chunk[blockIdx.x] (at most chunk_size elements) of
   * tensor tl.block_to_tensor[blockIdx.x]. Each thread handles ILP elements
   * per iteration, strided by blockDim.x.
   *
   * N: number of tensor lists (3, or 4 when an fp16 weight copy is written)
   * tl[0] : gradients (T_grad), read only
   * tl[1] : weights (T_weight), updated in place
   * tl[2] : momentum buffers (T_weight), written back only when momentum != 0
   * tl[3] : fp16 copy of the updated weights (only used when N == 4)
   * noop_gmem : if nonzero on entry, the block does nothing
   * wd : weight_decay (scalar); 0 disables weight decay
   * momentum : momentum (scalar); 0 disables momentum (buffers untouched)
   * dampening : momentum dampening (scalar)
   * lr : learning rate (scalar)
   * nesterov : enable nesterov (bool)
   * first_run : initialize momentum buffers to the current (decayed) gradient
   *             instead of reading them; their prior contents are ignored
   * wd_after_momentum : apply weight decay _after_ momentum instead of before
   * scale : multiplier applied to every gradient element as it is loaded
   **/
  template<int N, typename T_grad, typename T_weight>
  struct SGDFunctor
  {
    __device__ __forceinline__ void operator()(
      int chunk_size,
      volatile int* noop_gmem,
      TensorListMetadata<N>& tl,
      float wd,
      float momentum,
      float dampening,
      float lr,
      bool nesterov,
      bool first_run,
      bool wd_after_momentum,
      float scale)
    {
      // Early exit if the no-op flag is already set: nothing is read or written.
      if (*noop_gmem) return;

      int tensor_loc = tl.block_to_tensor[blockIdx.x];
      int chunk_idx = tl.block_to_chunk[blockIdx.x];
      int n = tl.sizes[tensor_loc];

      T_grad* grad_in = (T_grad*)tl.addresses[0][tensor_loc];
      grad_in += chunk_idx*chunk_size;

      T_weight* weight_in = (T_weight*)tl.addresses[1][tensor_loc];
      weight_in += chunk_idx*chunk_size;

      T_weight* mom_in = (T_weight*)tl.addresses[2][tensor_loc];
      mom_in += chunk_idx*chunk_size;

      at::Half *model_weights_out = nullptr;
      if(N == 4)
      {
        model_weights_out = (at::Half*)tl.addresses[3][tensor_loc];
        model_weights_out += chunk_idx*chunk_size;
      }

      // Elements remaining in this tensor from the start of this chunk;
      // the loops below additionally stop at chunk_size.
      n -= chunk_idx*chunk_size;

      // The stored momentum value is only consumed when momentum is enabled
      // and this is not the first step, so skip the global load otherwise.
      const bool read_momentum = (momentum != 0.f) && !first_run;

      float incoming_grads[ILP];
      float incoming_weights[ILP];
      float incoming_moms[ILP];

      for(int i_start = 0;
          i_start < n && i_start < chunk_size;
          i_start += blockDim.x*ILP)
      {
        #pragma unroll
        for(int ii = 0; ii < ILP; ii++)
        {
          incoming_grads[ii] = 0;
          incoming_weights[ii] = 0;
          incoming_moms[ii] = 0;
          int i = i_start + threadIdx.x + ii*blockDim.x;
          if(i < n && i < chunk_size)
          {
            incoming_grads[ii] = static_cast<float>(grad_in[i])*scale;
            incoming_weights[ii] = static_cast<float>(weight_in[i]);
            if(read_momentum)
              incoming_moms[ii] = static_cast<float>(mom_in[i]);
          }
        }

        // Note: from a pure memory-dependency perspective there is likely no
        // point unrolling the write loop, since writes just fire off once their
        // LDGs arrive (the STGs depend on the LDGs, not on each other). There is
        // still compute ILP benefit from unrolling it, though.
        #pragma unroll
        for(int ii = 0; ii < ILP; ii++)
        {
          int i = i_start + threadIdx.x + ii*blockDim.x;
          if(i < n && i < chunk_size)
          {
            // apply weight decay before momentum if necessary
            if(wd != 0.f && !wd_after_momentum)
              incoming_grads[ii] += wd * incoming_weights[ii];

            if(momentum != 0.f)
            {
              if(!first_run)
                incoming_moms[ii] = incoming_moms[ii] * momentum + (1.f - dampening) * incoming_grads[ii];
              else // initialize momentums to current incoming grads
                incoming_moms[ii] = incoming_grads[ii];

              if(nesterov)
                incoming_grads[ii] += momentum * incoming_moms[ii];
              else
                incoming_grads[ii] = incoming_moms[ii];
            }

            // Apply WD after momentum if desired
            if(wd != 0.f && wd_after_momentum)
              incoming_grads[ii] += wd * incoming_weights[ii];

            // Compute the updated weight in fp32 from the register copy and
            // round it to T_weight exactly once. This avoids re-reading
            // weight_in[i] from global memory, and for half weights it avoids
            // rounding the update term to half before the addition.
            const T_weight new_weight =
              static_cast<T_weight>(incoming_weights[ii] + (-lr * incoming_grads[ii]));
            weight_in[i] = new_weight;

            // if necessary, write out an fp16 copy of the stored weight
            if(N == 4)
              model_weights_out[i] = static_cast<at::Half>(new_weight);

            // also write out the new momentum
            if(momentum != 0.f)
              mom_in[i] = incoming_moms[ii];
          }
        }
      }
    }
  };
  /**
   * Host entry point for fused multi-tensor SGD.
   *
   * tensor_lists[0] holds gradients, [1] weights, [2] momentum buffers and,
   * when present, [3] fp16 output copies of the updated weights (every tensor
   * in [3] must be Half). Only these combinations are dispatched; anything
   * else raises an error:
   *   grad_type  weight/momentum_type  num_lists
   *   fp16       fp16                  3
   *   fp32       fp32                  3
   *   fp16       fp32                  4
   *   fp32       fp32                  4   (the materialize_master_grads=True case)
   * The update is launched through multi_tensor_apply with BLOCK_SIZE threads
   * per block, and launch errors are surfaced via cudaGetLastError().
   * noop_flag must live on the same device as the tensors.
   */
  void multi_tensor_sgd_cuda(
    int chunk_size,
    at::Tensor noop_flag,
    std::vector<std::vector<at::Tensor>> tensor_lists,
    float wd,
    float momentum,
    float dampening,
    float lr,
    bool nesterov,
    bool first_run,
    bool wd_after_momentum,
    float scale)
  {
    auto num_tensors = tensor_lists.size();

    // Validate the layout before touching tensor_lists[0][0] / tensor_lists[1][0];
    // indexing an empty or missing list would be out of bounds.
    TORCH_CHECK(num_tensors >= 2 && !tensor_lists[0].empty() && !tensor_lists[1].empty(),
                "multi_tensor_sgd expects non-empty gradient and weight lists, got ",
                num_tensors, " list(s)");

    auto grad_type = tensor_lists[0][0].scalar_type();
    auto weight_type = tensor_lists[1][0].scalar_type();

    if(num_tensors == 4)
      for(const auto& out : tensor_lists[3])
        TORCH_CHECK(out.scalar_type() == at::ScalarType::Half,
                    "Additional output tensors should always be fp16.");

    TORCH_CHECK(noop_flag.device() == tensor_lists[0][0].device(), "expected noop flag to be on the same device as tensors");

    // There are 4 supported cases, in terms of
    // (grad_type, param_type, momentum_type, requires_fp16_copy):
    // 1. fp16, fp16, fp16, No
    // 2. fp32, fp32, fp32, No
    // 3. fp16, fp32, fp32, Yes
    // 4. fp32, fp32, fp32, Yes // this is the materialize_master_grads=True case
    // It's easier to hardcode these possibilities than to use switches etc.
    // to handle the cross-product of cases, most of which we don't want.
    // (fp16 grads with fp32 weights and no fp16 copy is intentionally not
    // supported and falls through to the error below.)

    // Case 1. fp16, fp16, fp16, No
    if(grad_type == at::ScalarType::Half &&
       weight_type == at::ScalarType::Half &&
       num_tensors == 3)
    {
      multi_tensor_apply<3>(
        BLOCK_SIZE, chunk_size, noop_flag, tensor_lists,
        SGDFunctor<3, at::Half, at::Half>(),
        wd, momentum, dampening, lr, nesterov, first_run, wd_after_momentum, scale);
    }
    // Case 2. fp32, fp32, fp32, No
    else if(grad_type == at::ScalarType::Float &&
            weight_type == at::ScalarType::Float &&
            num_tensors == 3)
    {
      multi_tensor_apply<3>(
        BLOCK_SIZE, chunk_size, noop_flag, tensor_lists,
        SGDFunctor<3, float, float>(),
        wd, momentum, dampening, lr, nesterov, first_run, wd_after_momentum, scale);
    }
    // Case 3. fp16, fp32, fp32, Yes
    else if(grad_type == at::ScalarType::Half &&
            weight_type == at::ScalarType::Float &&
            num_tensors == 4)
    {
      multi_tensor_apply<4>(
        BLOCK_SIZE, chunk_size, noop_flag, tensor_lists,
        SGDFunctor<4, at::Half, float>(),
        wd, momentum, dampening, lr, nesterov, first_run, wd_after_momentum, scale);
    }
    // Case 4. fp32, fp32, fp32, Yes
    else if(grad_type == at::ScalarType::Float &&
            weight_type == at::ScalarType::Float &&
            num_tensors == 4)
    {
      multi_tensor_apply<4>(
        BLOCK_SIZE, chunk_size, noop_flag, tensor_lists,
        SGDFunctor<4, float, float>(),
        wd, momentum, dampening, lr, nesterov, first_run, wd_after_momentum, scale);
    }
    else
    {
      AT_ERROR("multi_tensor_sgd only supports some combinations of gradient & weight types. Given: ",
               "gradient: ", grad_type, ", weight: ", weight_type, ", num_lists: ", num_tensors);
    }

    AT_CUDA_CHECK(cudaGetLastError());
  }