test_multi_tensor_axpby.py

import unittest

import functools as ft
import itertools as it

from apex import amp
import torch
from torch import nn
import torch.nn.functional as F
from math import floor

from utils import common_init, HALF, FLOAT,\
    ALWAYS_HALF, ALWAYS_FLOAT, MATCH_INPUT

try:
    import amp_C
    from amp_C import multi_tensor_axpby
    from apex.multi_tensor_apply import MultiTensorApply
    disabled = False
except ImportError as err:
    print("amp_C fused kernels unavailable, disabling TestMultiTensorApply. ImportError was ", err)
    disabled = True

TORCH_MAJOR = int(torch.__version__.split('.')[0])
TORCH_MINOR = int(torch.__version__.split('.')[1])
try_nhwc = (TORCH_MAJOR > 1) or (TORCH_MAJOR == 1 and TORCH_MINOR > 4)
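

# For reference: multi_tensor_axpby, as exercised below, is expected to compute
# out_i = a * x_i + b * y_i elementwise for each matched triple of tensors,
# leaving the overflow buffer at zero when all inputs are finite. The helper
# below is a minimal pure-PyTorch sketch of that expected behavior;
# axpby_reference is a hypothetical illustration, not part of apex, and the
# tests themselves call only the fused kernel.
def axpby_reference(a, x_list, b, y_list):
    # Elementwise a*x + b*y for each matched pair of tensors.
    return [a * x + b * y for x, y in zip(x_list, y_list)]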


class TestMultiTensorAxpby(unittest.TestCase):
    def setUp(self):
        common_init(self)

        self.a = 2.0
        self.b = 8.0
        self.xval = 4.0
        self.yval = 16.0
        self.overflow_buf = torch.cuda.IntTensor(1).zero_()
        # Expected result of a*x + b*y for the constants above: 2.0*4.0 + 8.0*16.0 = 136.0.
        self.ref = torch.full((1,), 136.0, device="cuda", dtype=torch.float32)

    def tearDown(self):
        pass

    # The tensor creation here is written for convenience, not speed.
    def axpby(self, sizea, sizeb, applier, repeat_tensors,
              x_type, y_type, out_type, inplace=False, nhwc=False):
        self.overflow_buf.zero_()
        sizea = sizea if isinstance(sizea, tuple) else (sizea,)
        sizeb = sizeb if isinstance(sizeb, tuple) else (sizeb,)

        t1 = torch.full(sizea, 1.0, device="cuda", dtype=torch.float32)
        t2 = torch.full(sizeb, 1.0, device="cuda", dtype=torch.float32)

        def to_fmt(t, tp):
            if nhwc:
                return t.clone().to(tp, memory_format=torch.channels_last)
            else:
                return t.clone().to(tp)

        # The y tensors hold yval and the x tensors hold xval, so after the fused
        # kernel every output should equal a*xval + b*yval (== self.ref).
        y_list = []
        for i in range(repeat_tensors):
            y_list += [to_fmt(t1, y_type)*self.yval, to_fmt(t2, y_type)*self.yval]

        x_list = [to_fmt(x, x_type)*(self.xval/self.yval) for x in y_list]

        if inplace:
            out_list = y_list
        else:
            out_list = [to_fmt(out, out_type)*3.0 for out in y_list]

        applier(multi_tensor_axpby, self.overflow_buf, [x_list, y_list, out_list], self.a, self.b, -1)

        self.assertTrue(all([torch.allclose(out, self.ref.to(out_type)) for out in out_list]),
                        msg="{} {} {} {} {} {} {}".format(sizea, sizeb, repeat_tensors,
                                                          x_type, y_type, out_type, inplace))
        self.assertTrue(self.overflow_buf.item() == 0,
                        msg="{} {} {} {} {} {} {}".format(sizea, sizeb, repeat_tensors,
                                                          x_type, y_type, out_type, inplace))

    # def find_inf(self, sizea, sizeb, applier, repeat_tensors, in_type, out_type, t, ind, val, inplace=False):
    #     self.overflow_buf.zero_()
    #     a = torch.cuda.FloatTensor(sizea).fill_(self.scale)
    #     b = torch.cuda.FloatTensor(sizeb).fill_(self.scale)
    #     out_list = []
    #     for i in range(repeat_tensors):
    #         out_list += [a.clone().to(out_type), b.clone().to(out_type)]
    #     if inplace:
    #         in_list = out_list
    #     else:
    #         in_list = [out.clone().to(in_type) for out in out_list]
    #     applier(multi_tensor_scale, self.overflow_buf, [in_list, out_list], 1./self.scale)
    #     self.overflow_buf.zero_()
    #     in_list[t][ind] = val
    #     applier(multi_tensor_scale, self.overflow_buf, [in_list, out_list], 1./self.scale)
    #     self.assertTrue(self.overflow_buf.item())

    @unittest.skipIf(disabled, "amp_C is unavailable")
    def test_fuzz(self):
        input_size_pairs = (
            (7777*77, 555*555),
            (777, 555),
            (555, 2048*32+1),
            (2048*32+1, 555),
            (555, 2048*32),
            (2048*32, 555),
            (33333, 555),
            (555, 33333))
        appliers = (
            MultiTensorApply(2048*32),
            MultiTensorApply(333),
            MultiTensorApply(33333))
        repeat_tensors = (
            1,
            55)

        for sizea, sizeb in input_size_pairs:
            for applier in appliers:
                for repeat in repeat_tensors:
                    for x_type in (torch.float32, torch.float16):
                        for y_type in (torch.float32, torch.float16):
                            for out_type in (torch.float32, torch.float16):
                                for inplace in (True, False):
                                    if inplace is True and (y_type is not out_type):
                                        continue
                                    else:
                                        self.axpby(sizea, sizeb, applier, repeat,
                                                   x_type, y_type, out_type, inplace=inplace)
                                        # self.find_inf(sizea, sizeb, applier, repeat, in_type, out_type,
                                        #               0, 0, float('nan'), inplace=inplace)
                                        # self.find_inf(sizea, sizeb, applier, repeat, in_type, out_type,
                                        #               2*repeat-1, sizeb-1, float('inf'), inplace=inplace)
                                        # self.find_inf(sizea, sizeb, applier, repeat, in_type, out_type,
                                        #               2*(repeat//2), sizea//2, float('inf'), inplace=inplace)

    @unittest.skipIf(disabled, "amp_C is unavailable")
    @unittest.skipIf(not try_nhwc, "torch version is 1.4 or earlier, may not support nhwc")
    def test_fuzz_nhwc(self):
        input_size_pairs = (
            ((7, 77, 7, 77), (5, 55, 5, 55)),
            ((1, 1, 777, 1), (1, 1, 555, 1)),
            ((5, 47, 5, 55), (1, 1, 1, 2048*32 + 1)),
            ((1, 1, 1, 2048*32 + 1), (55, 47, 5, 55)),
            ((555, 1, 1, 1), (32, 8, 32, 8)),
            ((32, 8, 32, 8), (55, 47, 5, 55)),
            ((1, 1, 33333, 1), (55, 47, 55, 5)),
            ((55, 47, 55, 5), (1, 1, 33333, 1)))
        appliers = (
            MultiTensorApply(2048*32),
            MultiTensorApply(333),
            MultiTensorApply(33333))
        repeat_tensors = (
            1,
            55)

        for sizea, sizeb in input_size_pairs:
            for applier in appliers:
                for repeat in repeat_tensors:
                    for x_type in (torch.float32, torch.float16):
                        for y_type in (torch.float32, torch.float16):
                            for out_type in (torch.float32, torch.float16):
                                for inplace in (True, False):
                                    if inplace is True and (y_type is not out_type):
                                        continue
                                    else:
                                        self.axpby(sizea, sizeb, applier, repeat,
                                                   x_type, y_type, out_type, inplace=inplace, nhwc=True)
                                        # self.find_inf(sizea, sizeb, applier, repeat, in_type, out_type,
                                        #               0, 0, float('nan'), inplace=inplace)
                                        # self.find_inf(sizea, sizeb, applier, repeat, in_type, out_type,
                                        #               2*repeat-1, sizeb-1, float('inf'), inplace=inplace)
                                        # self.find_inf(sizea, sizeb, applier, repeat, in_type, out_type,
                                        #               2*(repeat//2), sizea//2, float('inf'), inplace=inplace)


if __name__ == '__main__':
    unittest.main()
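
# Usage note (assumptions: a CUDA-capable GPU, an apex install built with its CUDA
# extensions so that amp_C imports, and the local `utils` helper module importable,
# e.g. by running from the directory that contains it):
#
#     python test_multi_tensor_axpby.py
#
# or, equivalently, through unittest:
#
#     python -m unittest test_multi_tensor_axpby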