  1. """Tests for c++ MLP"""
  2. from itertools import product
  3. from time import time
  4. import torch
  5. from torch import nn
  6. from torch.testing._internal import common_utils
  7. from torch.testing._internal.common_device_type import instantiate_device_type_tests
  8. from torch.testing._internal.common_device_type import onlyCUDA
  9. from apex.mlp import MLP
  10. batch_size = 1024
  11. mlp_sizes = [480, 1024, 1024, 512, 256, 1]
  12. num_iters = 10
  13. # note(crcrpar): On Ampere, this test should be run without TF32 enabled.
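# For example (an assumption; this file does not set these flags itself), TF32
# matmuls could be turned off before running the tests with:
#     torch.backends.cuda.matmul.allow_tf32 = False
#     torch.backends.cudnn.allow_tf32 = False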
class TestMLP(common_utils.TestCase):
    def test_creation(self):
        MLP(mlp_sizes)

    def test_numeric(self):
        mlp = MLP(mlp_sizes).cuda()

        mlp_layers = []
        for i in range(mlp.num_layers):
            linear = nn.Linear(mlp_sizes[i], mlp_sizes[i + 1])
            with torch.no_grad():
                mlp.weights[i].copy_(linear.weight)
                mlp.biases[i].copy_(linear.bias)
            mlp_layers.append(linear)
            mlp_layers.append(nn.ReLU())

        ref_mlp = nn.Sequential(*mlp_layers).cuda()

        test_input = (
            torch.empty(batch_size, mlp_sizes[0], device="cuda")
            .uniform_(-1.0, 1.0)
            .requires_grad_()
        )
        ref_input = test_input.clone().detach().requires_grad_()
        mlp_out = mlp(test_input)
        ref_out = ref_mlp(ref_input)
        self.assertEqual(mlp_out, ref_out)

        # Use the mean as a scalar loss; multiply by 10 so the gradients are not vanishingly small.
        mlp_out.mean().mul(10.0).backward()
        ref_out.mean().mul(10.0).backward()
        self.assertEqual(test_input.grad, ref_input.grad)
        self.assertEqual(mlp.biases[0].grad, ref_mlp[0].bias.grad)
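
    # Shared helper for the parametrized tests below: it builds an apex MLP and an
    # equivalent nn.Sequential reference with the same parameters, then compares
    # outputs and gradients (under autocast only the output dtypes are checked,
    # since exact equality is not expected in fp16).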
    def _test_mlp_impl(self, use_activation: str, bias: bool, enable_autocast: bool):
        mlp = MLP(mlp_sizes, bias=bias, activation=use_activation).cuda()

        mlp_layers = []
        for i in range(mlp.num_layers):
            linear = nn.Linear(mlp_sizes[i], mlp_sizes[i + 1], bias=bias)
            with torch.no_grad():
                mlp.weights[i].copy_(linear.weight)
                if bias:
                    mlp.biases[i].copy_(linear.bias)
            mlp_layers.append(linear)
            if use_activation == "relu":
                mlp_layers.append(nn.ReLU())
            if use_activation == "sigmoid":
                mlp_layers.append(nn.Sigmoid())

        ref_mlp = nn.Sequential(*mlp_layers).cuda()

        test_input = (
            torch.empty(batch_size, mlp_sizes[0], device="cuda")
            .uniform_(-1.0, 1.0)
            .requires_grad_()
        )
        ref_input = test_input.clone().detach().requires_grad_()

        with torch.cuda.amp.autocast_mode.autocast(enabled=enable_autocast):
            mlp_out = mlp(test_input)
            # Use the mean as a scalar loss; multiply by 10 so the gradients are not vanishingly small.
            mlp_loss = mlp_out.mean().mul(10.0)

            ref_out = ref_mlp(ref_input)
            ref_loss = ref_out.mean().mul(10.0)

        mlp_loss.backward()
        ref_loss.backward()

        if enable_autocast:
            self.assertEqual(mlp_out.dtype, torch.float16)
            self.assertEqual(ref_out.dtype, torch.float16)
        else:
            self.assertEqual(mlp_out, ref_out)
            self.assertEqual(test_input.grad, ref_input.grad)
            self.assertEqual(mlp.weights[0].grad, ref_mlp[0].weight.grad)

    @common_utils.parametrize(
        "use_activation,bias",
        list(product(("none", "relu", "sigmoid"), (True, False))),
    )
    def test_mlp(self, use_activation: str, bias: bool):
        self._test_mlp_impl(use_activation, bias, enable_autocast=False)

    @common_utils.parametrize(
        "use_activation,bias",
        list(product(("none", "relu", "sigmoid"), (True, False))),
    )
    def test_mlp_autocast_fp16(self, use_activation: str, bias: bool):
        self._test_mlp_impl(use_activation, bias, enable_autocast=True)
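
    # Inputs here deliberately do not require grad: only the first layer's
    # parameter gradients are compared after backward().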
    def test_no_grad(self):
        mlp = MLP(mlp_sizes).cuda()

        mlp_layers = []
        for i in range(mlp.num_layers):
            linear = nn.Linear(mlp_sizes[i], mlp_sizes[i + 1])
            with torch.no_grad():
                mlp.weights[i].copy_(linear.weight)
                mlp.biases[i].copy_(linear.bias)
            mlp_layers.append(linear)
            mlp_layers.append(nn.ReLU(inplace=True))

        ref_mlp = nn.Sequential(*mlp_layers).cuda()

        test_input = torch.empty(batch_size, mlp_sizes[0], device="cuda").uniform_(-1.0, 1.0)
        ref_input = test_input.clone().detach()
        mlp_out = mlp(test_input)
        ref_out = ref_mlp(ref_input)
        self.assertEqual(mlp_out, ref_out)

        # Use the mean as a scalar loss; multiply by 10 so the gradients are not vanishingly small.
        mlp_out.mean().mul(10.0).backward()
        ref_out.mean().mul(10.0).backward()
        self.assertEqual(mlp.weights[0].grad, ref_mlp[0].weight.grad)
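
    # A rough wall-clock comparison in fp16: the custom extension is expected to
    # be at least as fast as the equivalent nn.Sequential stack at this problem size.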
    def test_performance_half(self):
        mlp = MLP(mlp_sizes).cuda().half()

        mlp_layers = []
        for i in range(mlp.num_layers):
            linear = nn.Linear(mlp_sizes[i], mlp_sizes[i + 1])
            mlp.weights[i].data.copy_(linear.weight)
            mlp.biases[i].data.copy_(linear.bias)
            mlp_layers.append(linear)
            mlp_layers.append(nn.ReLU(inplace=True))

        ref_mlp = nn.Sequential(*mlp_layers).cuda().half()

        test_input = (
            torch.empty(batch_size, mlp_sizes[0], device="cuda", dtype=torch.half)
            .fill_(10.0)
            .requires_grad_()
        )
        ref_input = (
            torch.empty(batch_size, mlp_sizes[0], device="cuda", dtype=torch.half)
            .fill_(10.0)
            .requires_grad_()
        )

        # Warm up GPU
        for _ in range(100):
            ref_out = ref_mlp(ref_input)
            ref_loss = ref_out.mean()
            ref_mlp.zero_grad()
            ref_loss.backward()
            mlp_out = mlp(test_input)
            test_loss = mlp_out.mean()
            mlp.zero_grad()
            test_loss.backward()

        torch.cuda.profiler.start()
        torch.cuda.synchronize()
        start_time = time()
        for _ in range(num_iters):
            ref_out = ref_mlp(ref_input)
            ref_loss = ref_out.mean()
            ref_mlp.zero_grad()
            ref_loss.backward()
        torch.cuda.synchronize()
        stop_time = time()
        ref_time = (stop_time - start_time) * 1000.0 / num_iters
        print(f"\nPytorch MLP time {ref_time:.4f} ms")

        torch.cuda.synchronize()
        start_time = time()
        for _ in range(num_iters):
            mlp_out = mlp(test_input)
            test_loss = mlp_out.mean()
            mlp.zero_grad()
            test_loss.backward()
        torch.cuda.synchronize()
        stop_time = time()
        actual_time = (stop_time - start_time) * 1000.0 / num_iters
        print(f"C++ MLP time {actual_time:.4f} ms")
        torch.cuda.profiler.stop()

        self.assertLessEqual(
            actual_time,
            ref_time,
            msg=f"Custom extension took {actual_time:.4f} while PyTorch took {ref_time:.4f}",
        )
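

# instantiate_device_type_tests generates per-device copies of TestMLP;
# only_for=("cuda",) keeps only the CUDA variants, matching the CUDA-only extension under test.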
instantiate_device_type_tests(TestMLP, globals(), only_for=("cuda",))


if __name__ == "__main__":
    common_utils.run_tests()