test_fp16util.py

import unittest

import torch
import torch.nn as nn

from apex.fp16_utils import FP16Model


class DummyBlock(nn.Module):
    def __init__(self):
        super(DummyBlock, self).__init__()
        self.conv = nn.Conv2d(10, 10, 2)
        # Affine BatchNorm: the test below expects FP16Model to leave it in FP32.
        self.bn = nn.BatchNorm2d(10, affine=True)

    def forward(self, x):
        return self.conv(self.bn(x))


class DummyNet(nn.Module):
    def __init__(self):
        super(DummyNet, self).__init__()
        self.conv1 = nn.Conv2d(3, 10, 2)
        # Non-affine BatchNorm: it has no parameters, only running-stat buffers,
        # and it is not exempted from the half conversion in the test below.
        self.bn1 = nn.BatchNorm2d(10, affine=False)
        self.db1 = DummyBlock()
        self.db2 = DummyBlock()

    def forward(self, x):
        out = x
        out = self.conv1(out)
        out = self.bn1(out)
        out = self.db1(out)
        out = self.db2(out)
        return out


class DummyNetWrapper(nn.Module):
    def __init__(self):
        super(DummyNetWrapper, self).__init__()
        self.bn = nn.BatchNorm2d(3, affine=True)
        self.dn = DummyNet()

    def forward(self, x):
        return self.dn(self.bn(x))


class TestFP16Model(unittest.TestCase):
    def setUp(self):
        self.N = 64
        self.C_in = 3
        self.H_in = 16
        self.W_in = 32
        self.in_tensor = torch.randn((self.N, self.C_in, self.H_in, self.W_in)).cuda()
        self.orig_model = DummyNetWrapper().cuda()
        self.fp16_model = FP16Model(self.orig_model)

    def test_params_and_buffers(self):
        # Affine BatchNorm modules should remain in FP32; every other module's
        # parameters and buffers should have been converted to FP16.
        exempted_modules = [
            self.fp16_model.network.bn,
            self.fp16_model.network.dn.db1.bn,
            self.fp16_model.network.dn.db2.bn,
        ]
        for m in self.fp16_model.modules():
            expected_dtype = torch.float if (m in exempted_modules) else torch.half
            for p in m.parameters(recurse=False):
                assert p.dtype == expected_dtype
            for b in m.buffers(recurse=False):
                # num_batches_tracked buffers stay int64 regardless of conversion.
                assert b.dtype in (expected_dtype, torch.int64)

    def test_output_is_half(self):
        out_tensor = self.fp16_model(self.in_tensor)
        assert out_tensor.dtype == torch.half
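

# Hypothetical convenience addition (not part of the file as shown above):
# a standard unittest entry point so the module can also be run directly with
# `python test_fp16util.py`, assuming apex and a CUDA device are available.
if __name__ == '__main__':
    unittest.main()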