import unittest
import functools as ft
import itertools as it

from apex import amp
import torch
from torch import nn
import torch.nn.functional as F

from utils import common_init, HALF, FLOAT,\
    ALWAYS_HALF, ALWAYS_FLOAT, MATCH_INPUT

try:
    import amp_C
    from amp_C import multi_tensor_scale
    from apex.multi_tensor_apply import MultiTensorApply
    disabled = False
except ImportError as err:
    print("amp_C fused kernels unavailable, disabling TestMultiTensorApply. ImportError was ", err)
    disabled = True
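
# Rough sketch of the call pattern these tests exercise (assumes a CUDA device
# and a successful amp_C build; the chunk size 2048*32 is just an example value):
#
#     applier = MultiTensorApply(2048*32)
#     applier(multi_tensor_scale, overflow_buf, [in_list, out_list], 1./scale)
#
# multi_tensor_scale writes in_list[i]*(1./scale) into out_list[i] for each pair
# of tensors, and flips overflow_buf to a nonzero value if it encounters inf/nan.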


class TestMultiTensorScale(unittest.TestCase):

    def setUp(self):
        common_init(self)
        self.scale = 4.0
        self.overflow_buf = torch.cuda.IntTensor(1).zero_()
        self.ref = torch.cuda.FloatTensor([1.0])

    def tearDown(self):
        pass

    # The tensor creation here is written for convenience, not speed.
    def downscale(self, sizea, sizeb, applier, repeat_tensors, in_type, out_type, inplace=False):
        self.overflow_buf.zero_()
        a = torch.cuda.FloatTensor(sizea).fill_(self.scale)
        b = torch.cuda.FloatTensor(sizeb).fill_(self.scale)

        out_list = []
        for i in range(repeat_tensors):
            out_list += [a.clone().to(out_type), b.clone().to(out_type)]

        if inplace:
            in_list = out_list
        else:
            in_list = [out.clone().to(in_type) for out in out_list]

        applier(multi_tensor_scale, self.overflow_buf, [in_list, out_list], 1./self.scale)

        self.assertTrue(all([torch.allclose(out, self.ref.to(out_type)) for out in out_list]))
        self.assertTrue(self.overflow_buf.item() == 0)

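    # find_inf first runs a clean downscale, then poisons a single element of
    # in_list[t] at index ind with val (inf or nan), reruns the applier, and
    # asserts that the overflow buffer was raised.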
    def find_inf(self, sizea, sizeb, applier, repeat_tensors, in_type, out_type, t, ind, val, inplace=False):
        self.overflow_buf.zero_()
        a = torch.cuda.FloatTensor(sizea).fill_(self.scale)
        b = torch.cuda.FloatTensor(sizeb).fill_(self.scale)

        out_list = []
        for i in range(repeat_tensors):
            out_list += [a.clone().to(out_type), b.clone().to(out_type)]

        if inplace:
            in_list = out_list
        else:
            in_list = [out.clone().to(in_type) for out in out_list]

        applier(multi_tensor_scale, self.overflow_buf, [in_list, out_list], 1./self.scale)

        self.overflow_buf.zero_()
        in_list[t][ind] = val
        applier(multi_tensor_scale, self.overflow_buf, [in_list, out_list], 1./self.scale)
        self.assertTrue(self.overflow_buf.item())

    # Currently, the fused kernel gives a hard error if you attempt to downscale
    # into fp16 output, which imo is the desired behavior. Maybe someday we
    # will learn otherwise.
    # @unittest.skipIf(disabled, "amp_C is unavailable")
    # def test_fp16_to_fp16(self):
    #     self.downscale(self.fp16, self.fp16, self.fp16_ref)
    #
    # @unittest.skipIf(disabled, "amp_C is unavailable")
    # def test_fp32_to_fp16(self):
    #     self.downscale(self.fp32, self.fp16, self.fp16_ref)
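
    # The fuzz test sweeps tensor sizes that straddle the applier chunk size,
    # several chunk sizes, short and long tensor lists, both dtypes in each
    # direction, and in-place vs. out-of-place application.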
    @unittest.skipIf(disabled, "amp_C is unavailable")
    def test_fuzz(self):
        input_size_pairs = (
            (7777*77, 555*555),
            (777, 555),
            (555, 2048*32+1),
            (2048*32+1, 555),
            (555, 2048*32),
            (2048*32, 555),
            (33333, 555),
            (555, 33333))
        appliers = (
            MultiTensorApply(2048*32),
            MultiTensorApply(333),
            MultiTensorApply(33333))
        repeat_tensors = (
            1,
            55)

        for sizea, sizeb in input_size_pairs:
            for applier in appliers:
                for repeat in repeat_tensors:
                    for in_type in (torch.float32, torch.float16):
                        for out_type in (torch.float32, torch.float16):
                            for inplace in (True, False):
                                if inplace is True and (out_type is not in_type):
                                    continue
                                else:
                                    self.downscale(sizea, sizeb, applier, repeat, in_type, out_type, inplace=inplace)
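                                    # Check three poisoned positions: the first element of the first
                                    # tensor (nan), the last element of the last tensor (inf), and the
                                    # midpoint of a tensor partway through the list (inf).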
                                    self.find_inf(sizea, sizeb, applier, repeat, in_type, out_type,
                                                  0, 0, float('nan'), inplace=inplace)
                                    self.find_inf(sizea, sizeb, applier, repeat, in_type, out_type,
                                                  2*repeat-1, sizeb-1, float('inf'), inplace=inplace)
                                    self.find_inf(sizea, sizeb, applier, repeat, in_type, out_type,
                                                  2*(repeat//2), sizea//2, float('inf'), inplace=inplace)


if __name__ == '__main__':
    unittest.main()