test_promotion.py

import unittest

import itertools as it

from apex import amp
import torch
from torch import nn
import torch.nn.functional as F

from utils import common_init, HALF, FLOAT, DTYPES
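
# These tests exercise amp's automatic type promotion for mixed-precision
# tensors: out-of-place binary ops should produce the widest input type,
# in-place ops should keep the type of the tensor being written, and unsafe
# in-place casts should fail loudly. HALF, FLOAT, DTYPES, and common_init
# come from the test suite's shared utils module.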
class TestPromotion(unittest.TestCase):
    def setUp(self):
        self.handle = amp.init(enabled=True)
        common_init(self)

    def tearDown(self):
        self.handle._deactivate()
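
    # run_binary_promote_test drives each op over every (xtype, ytype) pair
    # from DTYPES and checks the promotion rule: out-of-place results match
    # the widest input type, in-place results match the tensor being
    # modified, and the gradient reaching the leaf keeps the leaf's dtype.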
    def run_binary_promote_test(self, fns, input_shape, x_inplace=False):
        type_pairs = it.product(DTYPES, DTYPES)
        for fn, (xtype, ytype) in it.product(fns, type_pairs):
            x = torch.randn(input_shape, dtype=xtype).requires_grad_()
            x_leaf = x
            if x_inplace:
                # We need a non-leaf to call in-place on
                x = x.clone()
            y = torch.randn(input_shape, dtype=ytype)
            out = fn(x, y)
            if x_inplace:
                # In-place: always match xtype
                self.assertEqual(out.type(), x.type())
            else:
                # Out-of-place: match widest type
                if xtype == torch.float or ytype == torch.float:
                    self.assertEqual(out.type(), FLOAT)
                else:
                    self.assertEqual(out.type(), HALF)
            out.float().sum().backward()
            self.assertEqual(x_leaf.grad.dtype, xtype)

    def test_atan2_matches_widest(self):
        fns = [lambda x, y: torch.atan2(x, y),
               lambda x, y: x.atan2(y)]
        self.run_binary_promote_test(fns, (self.b,))

    def test_mul_matches_widest(self):
        fns = [lambda x, y: torch.mul(x, y),
               lambda x, y: x.mul(y)]
        self.run_binary_promote_test(fns, (self.b,))
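
    # torch.cat sees its whole input list at once, so a single float tensor
    # among half tensors is enough to promote the concatenated result; an
    # all-half list stays half.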
    def test_cat_matches_widest(self):
        shape = self.b
        ys = [torch.randn(shape, dtype=torch.half) for _ in range(5)]
        x_float = torch.randn(shape)
        out = torch.cat(ys + [x_float])
        self.assertEqual(out.type(), FLOAT)
        x_half = torch.randn(shape, dtype=torch.half)
        out = torch.cat(ys + [x_half])
        self.assertEqual(out.type(), HALF)
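
    # amp wants exp computed in FP32; a float tensor can absorb that result
    # in place, but a half tensor cannot, so the in-place half case is
    # expected to raise NotImplementedError rather than silently lose
    # precision.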
    def test_inplace_exp_is_error_for_half(self):
        xs = torch.randn(self.b)
        xs.exp_()
        self.assertEqual(xs.type(), FLOAT)
        xs = torch.randn(self.b, dtype=torch.half)
        with self.assertRaises(NotImplementedError):
            xs.exp_()

    def test_inplace_add_matches_self(self):
        fn = lambda x, y: x.add_(y)
        self.run_binary_promote_test([fn], (self.b,), x_inplace=True)

if __name__ == '__main__':
    unittest.main()