# test_update_scale_hysteresis.py
import itertools
import math
import random
import unittest

import torch
  5. try:
  6. import amp_C
  7. from amp_C import update_scale_hysteresis
  8. disabled = False
  9. except ImportError as err:
  10. print("amp_C fused kernels unavailable, disabling TestUpdateScaleHysteresis. ImportError was ", err)
  11. disabled = True
  12. def isfinite(val):
  13. return ((val >= torch.finfo(torch.float32).smallest_normal) and (val <= torch.finfo(torch.float32).max))
  14. class TestUpdateScaleHysteresis(unittest.TestCase):
  15. def setUp(self):
  16. pass
  17. def tearDown(self):
  18. pass
  19. def update_scale_hysteresis_body(self, init_scale, growth_factor, backoff_factor,
  20. growth_interval, hysteresis):
  21. scale_ref = float(init_scale)
  22. grow_tracker_ref = 0
  23. hysteresis_tracker_ref = 0
  24. scale = torch.tensor([init_scale], dtype=torch.float32, device='cuda')
  25. growth_tracker = torch.tensor([0], dtype=torch.int32, device='cuda')
  26. hysteresis_tracker = torch.tensor([hysteresis], dtype=torch.int32, device='cuda')
  27. # Infs appear for hysteresis-1 iterations, scale shouldn't change
  28. found_inf = torch.tensor([1], dtype=torch.float32, device='cuda')
  29. for i in range(hysteresis-1):
  30. update_scale_hysteresis(scale, growth_tracker, hysteresis_tracker,
  31. found_inf, growth_factor, backoff_factor, growth_interval, hysteresis)
  32. self.assertTrue(scale.item() == init_scale)
  33. # No infs for growth_interval-1 iterations, scale shouldn't change
  34. found_inf.zero_()
  35. for i in range(growth_interval-1):
  36. update_scale_hysteresis(scale, growth_tracker, hysteresis_tracker,
  37. found_inf, growth_factor, backoff_factor, growth_interval, hysteresis)
  38. self.assertTrue(scale.item() == init_scale)
  39. # Infs appear for more than hysteresis iterations, scale should be backed off
  40. found_inf.fill_(1)
  41. extra_iters = random.randint(0, 1000)
  42. scale_before = scale.detach().item()
  43. scale_ref = scale_before
  44. for i in range(hysteresis + extra_iters):
  45. update_scale_hysteresis(scale, growth_tracker, hysteresis_tracker,
  46. found_inf, growth_factor, backoff_factor, growth_interval, hysteresis)
  47. for i in range(1 + extra_iters):
  48. # Scale is continuously backed off for each iteration with an inf
  49. scale_new = scale_ref * backoff_factor
  50. if isfinite(scale_new):
  51. scale_ref = scale_new
  52. else:
  53. scale_ref = 0 # Scale update kernel does not check for underflow when backing off, which results in zero
  54. self.assertTrue(scale.item() == scale_ref)
  55. # No infs for more than growth_interval iterations, scale should be increased
  56. found_inf.fill_(0)
  57. extra_iters = random.randint(0, 1000)
  58. scale_before = scale.detach().item()
  59. scale_ref = scale_before
  60. for i in range(growth_interval + extra_iters):
  61. update_scale_hysteresis(scale, growth_tracker, hysteresis_tracker,
  62. found_inf, growth_factor, backoff_factor, growth_interval, hysteresis)
  63. for i in range(1 + int(math.floor(extra_iters / growth_interval))):
  64. # Scale is grown every growth_interval iterations
  65. scale_new = scale_ref * growth_factor
  66. if isfinite(scale_new):
  67. scale_ref = scale_new
  68. self.assertTrue(scale.item() == scale_ref)
  69. @unittest.skipIf(disabled, "amp_C is unavailable")
  70. def test_fuzz(self):
  71. init_scale_list = [1, 1024, 65536]
  72. growth_factor_list = [1.0, 2.0, 4.0]
  73. backoff_factor_list = [0.5, 0.25]
  74. growth_interval_list = [10, 100]
  75. hysteresis_list = [10, 100]
  76. for init_scale in init_scale_list:
  77. for growth_factor in growth_factor_list:
  78. for backoff_factor in backoff_factor_list:
  79. for growth_interval in growth_interval_list:
  80. for hysteresis in hysteresis_list:
  81. self.update_scale_hysteresis_body(init_scale, growth_factor,
  82. backoff_factor, growth_interval, hysteresis)
  83. if __name__ == '__main__':
  84. unittest.main()