123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102 |
import math
import random
import unittest

import torch

# ``amp_C`` is an optional compiled extension. If it cannot be imported,
# record that in ``disabled`` so the test class can skip itself instead of
# failing at collection time.
try:
    import amp_C
    from amp_C import update_scale_hysteresis
except ImportError as err:
    print("amp_C fused kernels unavailable, disabling TestUpdateScaleHysteresis. ImportError was ", err)
    disabled = True
else:
    disabled = False
def isfinite(val):
    """Return True iff ``val`` lies in the positive *normal* float32 range.

    Despite its name, this is stricter than IEEE finiteness. Zero, negative
    numbers, values below ``torch.finfo(torch.float32).smallest_normal``
    (float32 subnormals), NaN, and anything above the float32 maximum all
    yield False. The Python reference model in this file uses it for two
    decisions: whether a backed-off scale should be treated as zero, and
    whether a grown scale is accepted.
    """
    f32 = torch.finfo(torch.float32)
    return f32.smallest_normal <= val <= f32.max
class TestUpdateScaleHysteresis(unittest.TestCase):
    """Fuzz test for the fused ``update_scale_hysteresis`` kernel from ``amp_C``.

    Each case drives the kernel through four phases and checks the resulting
    scale against a pure-Python reference:

    1. ``hysteresis - 1`` steps with infs: the scale must not change.
    2. ``growth_interval - 1`` steps without infs: the scale must not change.
    3. ``hysteresis + k`` steps with infs: the scale is expected to be backed
       off ``1 + k`` times.
    4. ``growth_interval + k`` steps without infs: the scale is expected to be
       grown ``1 + k // growth_interval`` times.

    ``k`` is random. ``test_fuzz`` draws it from a seeded ``random.Random`` and
    records the seed in the ``subTest`` parameters, so a failing case can be
    replayed exactly.
    """

    def update_scale_hysteresis_body(self, init_scale, growth_factor,
                                     backoff_factor, growth_interval,
                                     hysteresis, rng=None):
        """Run one four-phase fuzz case for a single parameter combination.

        Args:
            init_scale: Initial value of the float32 scale tensor.
            growth_factor: ``growth_factor`` argument passed to the kernel.
            backoff_factor: ``backoff_factor`` argument passed to the kernel.
            growth_interval: ``growth_interval`` argument passed to the kernel.
            hysteresis: ``hysteresis`` argument passed to the kernel; also the
                initial value of the hysteresis tracker tensor.
            rng: Object with a ``randint`` method used to pick the extra
                iteration counts. Defaults to the global ``random`` module
                (the previous behaviour). Pass a seeded ``random.Random`` for a
                reproducible run.
        """
        if rng is None:
            rng = random

        # Kernel state tensors. The kernel mutates them in place; the
        # assertions below read ``scale`` back after each phase.
        scale = torch.tensor([init_scale], dtype=torch.float32, device='cuda')
        growth_tracker = torch.tensor([0], dtype=torch.int32, device='cuda')
        hysteresis_tracker = torch.tensor([hysteresis], dtype=torch.int32, device='cuda')
        found_inf = torch.tensor([1], dtype=torch.float32, device='cuda')

        def run_kernel(iters):
            # Apply the fused update ``iters`` times with the current found_inf.
            for _ in range(iters):
                update_scale_hysteresis(scale, growth_tracker, hysteresis_tracker,
                                        found_inf, growth_factor, backoff_factor,
                                        growth_interval, hysteresis)

        # Phase 1: infs for hysteresis-1 iterations; the scale must not change.
        run_kernel(hysteresis - 1)
        self.assertEqual(scale.item(), init_scale, msg="hysteresis phase")

        # Phase 2: no infs for growth_interval-1 iterations; the scale must
        # not change.
        found_inf.zero_()
        run_kernel(growth_interval - 1)
        self.assertEqual(scale.item(), init_scale, msg="pre-growth phase")

        # Phase 3: infs for more than `hysteresis` iterations. The reference
        # expects the scale to be backed off once per step after the
        # hysteresis budget is used up, i.e. 1 + extra_iters times.
        found_inf.fill_(1)
        extra_iters = rng.randint(0, 1000)
        scale_ref = scale.item()
        run_kernel(hysteresis + extra_iters)
        for _ in range(1 + extra_iters):
            scale_new = scale_ref * backoff_factor
            # The scale update kernel does not check for underflow when
            # backing off, which results in zero. The reference models this
            # as: anything below the smallest normal float32 becomes 0.
            scale_ref = scale_new if isfinite(scale_new) else 0
        self.assertEqual(scale.item(), scale_ref,
                         msg="backoff phase, extra_iters=%d" % extra_iters)

        # Phase 4: no infs for more than `growth_interval` iterations. The
        # scale is grown once per completed interval; a growth whose result
        # leaves the positive normal float32 range is skipped.
        found_inf.zero_()
        extra_iters = rng.randint(0, 1000)
        scale_ref = scale.item()
        run_kernel(growth_interval + extra_iters)
        for _ in range(1 + extra_iters // growth_interval):
            scale_new = scale_ref * growth_factor
            if isfinite(scale_new):
                scale_ref = scale_new
        self.assertEqual(scale.item(), scale_ref,
                         msg="growth phase, extra_iters=%d" % extra_iters)

    @unittest.skipIf(disabled, "amp_C is unavailable")
    @unittest.skipIf(not torch.cuda.is_available(), "CUDA is unavailable")
    def test_fuzz(self):
        """Sweep the parameter grid, running each combination as a seeded sub-test."""
        import itertools  # local: the module header does not import itertools

        grid = itertools.product(
            [1, 1024, 65536],  # init_scale
            [1.0, 2.0, 4.0],   # growth_factor
            [0.5, 0.25],       # backoff_factor
            [10, 100],         # growth_interval
            [10, 100],         # hysteresis
        )
        for init_scale, growth_factor, backoff_factor, growth_interval, hysteresis in grid:
            # Fresh seed per case, recorded in the sub-test parameters so a
            # failure can be replayed with rng=random.Random(seed).
            seed = random.randrange(2 ** 32)
            with self.subTest(init_scale=init_scale, growth_factor=growth_factor,
                              backoff_factor=backoff_factor,
                              growth_interval=growth_interval,
                              hysteresis=hysteresis, seed=seed):
                self.update_scale_hysteresis_body(
                    init_scale, growth_factor, backoff_factor,
                    growth_interval, hysteresis, rng=random.Random(seed))
# Allow running this file directly, in addition to discovery by a test runner.
if __name__ == "__main__":
    unittest.main()
|