test_multi_tensor_scale.py
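
# Tests for the amp_C multi_tensor_scale fused CUDA kernel, driven through
# apex.multi_tensor_apply.MultiTensorApply.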

import unittest
import functools as ft
import itertools as it

from apex import amp
import torch
from torch import nn
import torch.nn.functional as F

from utils import common_init, HALF, FLOAT,\
    ALWAYS_HALF, ALWAYS_FLOAT, MATCH_INPUT

try:
    import amp_C
    from amp_C import multi_tensor_scale
    from apex.multi_tensor_apply import MultiTensorApply
    disabled = False
except ImportError as err:
    print("amp_C fused kernels unavailable, disabling TestMultiTensorApply. ImportError was ", err)
    disabled = True


class TestMultiTensorScale(unittest.TestCase):
    def setUp(self):
        common_init(self)
        self.scale = 4.0
        self.overflow_buf = torch.cuda.IntTensor(1).zero_()
        self.ref = torch.cuda.FloatTensor([1.0])

    def tearDown(self):
        pass
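
    # downscale() fills two tensors with self.scale, replicates them
    # repeat_tensors times, applies multi_tensor_scale with a factor of
    # 1./self.scale, and checks that every output equals self.ref (1.0)
    # with no overflow reported.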
    # The tensor creation here is written for convenience, not speed.
    def downscale(self, sizea, sizeb, applier, repeat_tensors, in_type, out_type, inplace=False):
        self.overflow_buf.zero_()
        a = torch.cuda.FloatTensor(sizea).fill_(self.scale)
        b = torch.cuda.FloatTensor(sizeb).fill_(self.scale)

        out_list = []
        for i in range(repeat_tensors):
            out_list += [a.clone().to(out_type), b.clone().to(out_type)]

        if inplace:
            in_list = out_list
        else:
            in_list = [out.clone().to(in_type) for out in out_list]

        applier(multi_tensor_scale, self.overflow_buf, [in_list, out_list], 1./self.scale)

        self.assertTrue(all([torch.allclose(out, self.ref.to(out_type)) for out in out_list]))
        self.assertTrue(self.overflow_buf.item() == 0)
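
    # find_inf() runs a clean downscale first, then poisons one element of
    # in_list (tensor index t, flat index ind) with NaN or Inf and re-applies
    # the kernel, expecting overflow_buf to be set nonzero.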
    def find_inf(self, sizea, sizeb, applier, repeat_tensors, in_type, out_type, t, ind, val, inplace=False):
        self.overflow_buf.zero_()
        a = torch.cuda.FloatTensor(sizea).fill_(self.scale)
        b = torch.cuda.FloatTensor(sizeb).fill_(self.scale)

        out_list = []
        for i in range(repeat_tensors):
            out_list += [a.clone().to(out_type), b.clone().to(out_type)]

        if inplace:
            in_list = out_list
        else:
            in_list = [out.clone().to(in_type) for out in out_list]

        applier(multi_tensor_scale, self.overflow_buf, [in_list, out_list], 1./self.scale)

        self.overflow_buf.zero_()
        in_list[t][ind] = val
        applier(multi_tensor_scale, self.overflow_buf, [in_list, out_list], 1./self.scale)
        self.assertTrue(self.overflow_buf.item())

    # Currently, the fused kernel gives a hard error if you attempt to downscale
    # into fp16 output, which imo is the desired behavior.  Maybe someday we
    # will learn otherwise.
    # @unittest.skipIf(disabled, "amp_C is unavailable")
    # def test_fp16_to_fp16(self):
    #     self.downscale(self.fp16, self.fp16, self.fp16_ref)
    #
    # @unittest.skipIf(disabled, "amp_C is unavailable")
    # def test_fp32_to_fp16(self):
    #     self.downscale(self.fp32, self.fp16, self.fp16_ref)
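
    # test_fuzz sweeps tensor sizes around the kernel's chunk boundary (2048*32),
    # several MultiTensorApply chunk sizes, tensor-list lengths, fp16/fp32
    # input/output combinations, and in-place vs. out-of-place application.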
    @unittest.skipIf(disabled, "amp_C is unavailable")
    def test_fuzz(self):
        input_size_pairs = (
            (7777*77, 555*555),
            (777, 555),
            (555, 2048*32+1),
            (2048*32+1, 555),
            (555, 2048*32),
            (2048*32, 555),
            (33333, 555),
            (555, 33333))
        appliers = (
            MultiTensorApply(2048*32),
            MultiTensorApply(333),
            MultiTensorApply(33333))
        repeat_tensors = (
            1,
            55)

        for sizea, sizeb in input_size_pairs:
            for applier in appliers:
                for repeat in repeat_tensors:
                    for in_type in (torch.float32, torch.float16):
                        for out_type in (torch.float32, torch.float16):
                            for inplace in (True, False):
                                if inplace is True and (out_type is not in_type):
                                    continue
                                else:
                                    self.downscale(sizea, sizeb, applier, repeat, in_type, out_type, inplace=inplace)
                                    self.find_inf(sizea, sizeb, applier, repeat, in_type, out_type,
                                                  0, 0, float('nan'), inplace=inplace)
                                    self.find_inf(sizea, sizeb, applier, repeat, in_type, out_type,
                                                  2*repeat-1, sizeb-1, float('inf'), inplace=inplace)
                                    self.find_inf(sizea, sizeb, applier, repeat, in_type, out_type,
                                                  2*(repeat//2), sizea//2, float('inf'), inplace=inplace)


if __name__ == '__main__':
    unittest.main()