import unittest

import functools as ft
import itertools as it

from apex import amp
import torch
from torch import nn
import torch.nn.functional as F

from utils import (common_init, HALF, FLOAT,
                   ALWAYS_HALF, ALWAYS_FLOAT, MATCH_INPUT)
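
# `utils` is this suite's local helper module (not shown here). From how its
# names are used below: HALF and FLOAT are tensor-type strings (e.g.
# 'torch.cuda.HalfTensor' / 'torch.cuda.FloatTensor'); ALWAYS_HALF,
# ALWAYS_FLOAT, and MATCH_INPUT map an input dtype to the expected output
# type string; and common_init attaches the shared sizes (self.b, self.c,
# self.h, self.k) to the test case. A stand-in, assuming CUDA tensors:
#
#   ALWAYS_HALF  = {torch.float: HALF,  torch.half: HALF}
#   ALWAYS_FLOAT = {torch.float: FLOAT, torch.half: FLOAT}
#   MATCH_INPUT  = {torch.float: FLOAT, torch.half: HALF}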

def run_layer_test(test_case, fns, expected, input_shape, test_backward=True):
    """For each fn and each input dtype in `expected`, check that fn's output
    has the expected tensor type and, optionally, that backward produces a
    gradient whose type matches the input."""
    for fn, typ in it.product(fns, expected.keys()):
        x = torch.randn(input_shape, dtype=typ).requires_grad_()
        y = fn(x)
        test_case.assertEqual(y.type(), expected[typ])
        if test_backward:
            y.float().sum().backward()
            test_case.assertEqual(x.grad.type(), MATCH_INPUT[typ])

class TestBasicCasts(unittest.TestCase):
    def setUp(self):
        self.handle = amp.init(enabled=True)
        common_init(self)

    def tearDown(self):
        self.handle._deactivate()

    def test_linear_is_half(self):
        m = nn.Linear(self.h, self.h)
        f = ft.partial(F.linear, weight=m.weight, bias=m.bias)
        run_layer_test(self, [m, f], ALWAYS_HALF, (self.b, self.h))

    def test_conv2d_is_half(self):
        m = nn.Conv2d(self.c, self.c, self.k)
        f = ft.partial(F.conv2d, weight=m.weight, bias=m.bias)
        run_layer_test(self, [m, f], ALWAYS_HALF, (self.b, self.c, self.h, self.h))

    def test_softmax_is_float(self):
        m = nn.Softmax(dim=1)
        f = ft.partial(F.softmax, dim=1)
        run_layer_test(self, [m, f], ALWAYS_FLOAT, (self.b, self.h))

    def test_group_norm_is_float(self):
        m = nn.GroupNorm(num_groups=4, num_channels=self.c)
        run_layer_test(self, [m], ALWAYS_FLOAT, (self.b, self.c, self.h, self.h))

    def test_mse_loss_is_float(self):
        shape = (self.b, self.h)
        target = torch.randn(shape)
        mod = nn.MSELoss()
        m = lambda x: mod(x, target)
        f = ft.partial(F.mse_loss, target=target)
        run_layer_test(self, [m, f], ALWAYS_FLOAT, shape)

    def test_relu_is_match(self):
        run_layer_test(self, [nn.ReLU(), F.relu], MATCH_INPUT, (self.b, self.h))

    def test_batch_norm_is_match(self):
        m = nn.BatchNorm2d(num_features=self.c)
        # Training-mode BN is exercised through the module; the functional
        # form is covered in eval mode below.
        run_layer_test(self, [m], MATCH_INPUT, (self.b, self.c, self.h, self.h))

        # Test forward-only for BN inference.
        m.eval()
        f = ft.partial(F.batch_norm, running_mean=m.running_mean, running_var=m.running_var,
                       weight=m.weight, bias=m.bias, training=False)
        run_layer_test(self, [m, f], MATCH_INPUT, (self.b, self.c, self.h, self.h),
                       test_backward=False)
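
# amp "bans" F.binary_cross_entropy under mixed precision: BCE on fp16 inputs
# is numerically unsafe, so by default the wrapped function raises
# NotImplementedError (binary_cross_entropy_with_logits is the recommended
# alternative). Initializing with allow_banned=True opts back in and runs the
# op in fp32, which is what these tests pin down.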
class TestBannedMethods(unittest.TestCase):
    def setUp(self):
        self.handle = amp.init(enabled=True)
        common_init(self)

    def tearDown(self):
        self.handle._deactivate()

    def bce_common(self, assertion):
        shape = (self.b, self.h)
        target = torch.rand(shape)
        mod = nn.BCELoss()
        m = lambda x: mod(x, target)
        f = ft.partial(F.binary_cross_entropy, target=target)
        for fn in [m, f]:
            x = torch.rand(shape, dtype=torch.half)
            assertion(fn, x)

    def test_bce_raises_by_default(self):
        assertion = lambda fn, x: self.assertRaises(NotImplementedError, fn, x)
        self.bce_common(assertion)

    def test_bce_is_float_with_allow_banned(self):
        self.handle._deactivate()
        self.handle = amp.init(enabled=True, allow_banned=True)
        assertion = lambda fn, x: self.assertEqual(fn(x).type(), FLOAT)
        self.bce_common(assertion)

class TestTensorCasts(unittest.TestCase):
    def setUp(self):
        self.handle = amp.init(enabled=True)
        common_init(self)

    def tearDown(self):
        self.handle._deactivate()

    def test_matmul_method_is_half(self):
        other = torch.randn(self.h, self.h)
        lhs = lambda x: x.matmul(other)
        rhs = lambda x: other.matmul(x)
        run_layer_test(self, [lhs, rhs], ALWAYS_HALF, (self.h, self.h))

    def test_matmul_op_is_half(self):
        other = torch.randn(self.h, self.h)
        lhs = lambda x: x @ other
        rhs = lambda x: other @ x
        run_layer_test(self, [lhs, rhs], ALWAYS_HALF, (self.h, self.h))

    def test_pow_method_is_float(self):
        fn = lambda x: x.pow(2.)
        run_layer_test(self, [fn], ALWAYS_FLOAT, (self.b, self.h))

    def test_pow_op_is_float(self):
        fn = lambda x: x ** 2.
        run_layer_test(self, [fn], ALWAYS_FLOAT, (self.b, self.h))

    def test_cpu_is_float(self):
        fn = lambda x: x.cpu()
        always_cpu_float = {torch.float: 'torch.FloatTensor',
                            torch.half: 'torch.FloatTensor'}
        run_layer_test(self, [fn], always_cpu_float, (self.b, self.h))

    def test_sum_is_float(self):
        fn = lambda x: x.sum()
        run_layer_test(self, [fn], ALWAYS_FLOAT, (self.b, self.h))

# TODO: maybe more tests on disabled casting?
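
# A minimal sketch of the disabled-casting coverage the TODO above suggests.
# Assumptions (not verified against the amp source): amp.init(enabled=False)
# returns a no-op handle whose _deactivate() is safe to call, and installs no
# casts, so every op simply runs in, and returns, its input's dtype, i.e.
# MATCH_INPUT is the expected behavior everywhere.
class TestDisabledCasting(unittest.TestCase):
    def setUp(self):
        self.handle = amp.init(enabled=False)
        common_init(self)

    def tearDown(self):
        self.handle._deactivate()

    def test_matmul_matches_input_when_disabled(self):
        # ALWAYS_HALF when amp is enabled; disabled, a float input should
        # stay float instead of being cast down.
        fn = lambda x: x @ x.transpose(0, 1)
        run_layer_test(self, [fn], MATCH_INPUT, (self.h, self.h))

    def test_sum_matches_input_when_disabled(self):
        # ALWAYS_FLOAT when amp is enabled; disabled, a half input should
        # stay half instead of being promoted.
        fn = lambda x: x.sum()
        run_layer_test(self, [fn], MATCH_INPUT, (self.b, self.h))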

if __name__ == '__main__':
    unittest.main()