update_scale_hysteresis.cu 2.5 KB

1234567891011121314151617181920212223242526272829303132333435363738394041424344454647484950515253545556575859606162636465666768697071
  1. #include <ATen/ATen.h>
  2. #include <ATen/cuda/Exceptions.h>
  3. #include <ATen/cuda/CUDAContext.h>
  // One step of dynamic loss-scale update with hysteresis.
  //
  // Launch with a single thread (<<<1, 1>>>): every pointer argument refers to
  // one element, and the reads/writes below are not synchronized.
  //
  // current_scale      [in/out] loss scale. On an overflow step it is multiplied
  //                    by backoff_factor once the hysteresis budget is used up.
  //                    On the growth_interval-th consecutive clean step it is
  //                    multiplied by growth_factor, but only if the fp32 result
  //                    is finite.
  // growth_tracker     [in/out] number of consecutive clean steps. Zeroed on any
  //                    overflow step and whenever the scale grows.
  // hysteresis_tracker [in/out] decremented on every overflow step; the scale
  //                    only backs off once it is no longer > 0. Reset to
  //                    `hysteresis` on every clean step.
  // found_inf          [in] overflow flag for this step. Any nonzero value
  //                    (including NaN) counts as an overflow.
  // growth_factor,
  // backoff_factor     scale multipliers. The product is formed in double and
  //                    rounded to float when stored.
  // growth_interval    consecutive clean steps required before growing.
  // hysteresis         value hysteresis_tracker is reset to on a clean step.
  __global__ void update_scale_hysteresis_cuda_kernel(float* current_scale,
                                                      int* growth_tracker,
                                                      int* hysteresis_tracker,
                                                      const float* found_inf,
                                                      double growth_factor,
                                                      double backoff_factor,
                                                      int growth_interval,
                                                      int hysteresis)
  {
    // Evaluate the flag exactly once, so the backoff decision and the
    // hysteresis bookkeeping can never disagree (e.g. for a NaN flag).
    const bool overflow = (*found_inf != 0.0f);

    if (overflow) {
      *hysteresis_tracker -= 1;
      // The clean-step streak is broken either way, so growth counting restarts.
      *growth_tracker = 0;
      // Still within the hysteresis budget: tolerate this overflow without
      // touching the scale.
      if (*hysteresis_tracker > 0) {
        return;
      }
      *current_scale = static_cast<float>((*current_scale) * backoff_factor);
      return;
    }

    // Clean step: count it before comparing against growth_interval.
    const int successful = (*growth_tracker) + 1;
    if (successful == growth_interval) {
      const float new_scale = static_cast<float>((*current_scale) * growth_factor);
      // Do not grow the scale past fp32 bounds to inf.
      if (isfinite(new_scale)) {
        *current_scale = new_scale;
      }
      *growth_tracker = 0;
    } else {
      *growth_tracker = successful;
    }

    // No overflow this step, so the full hysteresis budget is restored.
    *hysteresis_tracker = hysteresis;
  }
  // Host launcher for update_scale_hysteresis_cuda_kernel.
  //
  // Updates current_scale, growth_tracker and hysteresis_tracker in place and
  // returns current_scale. Inputs must be single-element CUDA tensors on one
  // device:
  //   - current_scale and found_inf are float32;
  //   - growth_tracker and hysteresis_tracker are int32.
  // growth_interval must fit in a 32-bit int, because the kernel takes an int.
  //
  // The kernel is enqueued on at::cuda::getCurrentCUDAStream(), so the update is
  // asynchronous with respect to the host.
  // NOTE(review): the launch targets the current CUDA device; callers presumably
  // make the tensors' device current — confirm.
  //
  // Throws c10::Error (TORCH_CHECK) on invalid inputs. AT_CUDA_CHECK raises if
  // the launch itself reports an error.
  at::Tensor update_scale_hysteresis_cuda(at::Tensor current_scale,
                                          at::Tensor growth_tracker,
                                          at::Tensor hysteresis_tracker,
                                          at::Tensor found_inf,
                                          const double growth_factor,
                                          const double backoff_factor,
                                          const int64_t growth_interval,
                                          const int hysteresis)
  {
    // The kernel dereferences element 0 of each tensor on the GPU. A CPU or
    // empty tensor would otherwise surface as an illegal memory access later,
    // not as a catchable error here.
    TORCH_CHECK(current_scale.is_cuda(), "current_scale must be a CUDA tensor.");
    TORCH_CHECK(growth_tracker.is_cuda(), "growth_tracker must be a CUDA tensor.");
    TORCH_CHECK(hysteresis_tracker.is_cuda(), "hysteresis_tracker must be a CUDA tensor.");
    TORCH_CHECK(found_inf.is_cuda(), "found_inf must be a CUDA tensor.");

    TORCH_CHECK(growth_tracker.device() == current_scale.device() &&
                    hysteresis_tracker.device() == current_scale.device() &&
                    found_inf.device() == current_scale.device(),
                "current_scale, growth_tracker, hysteresis_tracker and found_inf "
                "must be on the same device.");

    TORCH_CHECK(current_scale.numel() == 1, "current_scale must be a 1-element tensor.");
    TORCH_CHECK(growth_tracker.numel() == 1, "growth_tracker must be a 1-element tensor.");
    TORCH_CHECK(hysteresis_tracker.numel() == 1, "hysteresis_tracker must be a 1-element tensor.");
    TORCH_CHECK(found_inf.numel() == 1, "found_inf must be a 1-element tensor.");

    // mutable_data_ptr<T>() also rejects a wrong dtype; these checks give a
    // clearer message naming the offending argument.
    TORCH_CHECK(current_scale.scalar_type() == at::ScalarType::Float,
                "current_scale must be a float tensor.");
    TORCH_CHECK(found_inf.scalar_type() == at::ScalarType::Float,
                "found_inf must be a float tensor.");
    TORCH_CHECK(growth_tracker.scalar_type() == at::ScalarType::Int,
                "growth_tracker must be an int tensor.");
    TORCH_CHECK(hysteresis_tracker.scalar_type() == at::ScalarType::Int,
                "hysteresis_tracker must be an int tensor.");

    // The kernel compares against a 32-bit growth_interval. Reject values that
    // the narrowing conversion would silently wrap into a different interval.
    const int growth_interval_i32 = static_cast<int>(growth_interval);
    TORCH_CHECK(static_cast<int64_t>(growth_interval_i32) == growth_interval,
                "growth_interval must fit in a 32-bit int, got ", growth_interval);

    update_scale_hysteresis_cuda_kernel<<<1, 1, 0, at::cuda::getCurrentCUDAStream()>>>(
        current_scale.mutable_data_ptr<float>(),
        growth_tracker.mutable_data_ptr<int>(),
        hysteresis_tracker.mutable_data_ptr<int>(),
        found_inf.const_data_ptr<float>(),
        growth_factor,
        backoff_factor,
        growth_interval_i32,
        hysteresis);
    // Catches launch-configuration errors only; in-kernel faults surface at the
    // next synchronizing call.
    AT_CUDA_CHECK(cudaGetLastError());
    return current_scale;
  }