3
0

test_tutel.py 11 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209
  1. # Copyright (c) Microsoft Corporation.
  2. # Licensed under the MIT license.
  3. import contextlib
  4. import io
  5. import itertools
  6. import json
  7. import math
  8. import os
  9. import subprocess
  10. import unittest
  11. from unittest.mock import patch
  12. try:
  13. import GPUtil
  14. GPU_NAME = GPUtil.getGPUs()[0].name
  15. except:
  16. GPU_NAME = 'unknown'
  17. class HelloworldCaller():
  18. """A class for run tutel helloworld example with different arguments"""
  19. def run(
  20. self,
  21. nproc_per_node=1,
  22. helloworld_file='helloworld',
  23. top=2, dtype='float32',
  24. num_local_experts='2',
  25. hidden_size=2048,
  26. show_step_time=True,
  27. batch_size=16,
  28. is_round=True,
  29. a2a_ffn_overlap_degree=1,
  30. num_steps=100,
  31. use_model_parallel=False,
  32. device='cuda'
  33. ):
  34. # Disable NCCL SHM because it's capacity is limited in Azure pipeline
  35. new_env = os.environ.copy()
  36. new_env['NCCL_SHM_DISABLE'] = '1'
  37. """Run helloworld example"""
  38. if helloworld_file == 'helloworld':
  39. command = 'python3 -m torch.distributed.run --nproc_per_node=' + str(nproc_per_node) + ' tutel/examples/helloworld.py --top ' + str(top) + ' --dtype ' + dtype + ' --num_local_experts ' + str(num_local_experts) + ' --hidden_size ' + str(hidden_size) + ' --batch_size ' + str(batch_size) + ' --a2a_ffn_overlap_degree ' + str(a2a_ffn_overlap_degree) + ' --num_steps ' + str(num_steps) + ' --device ' + str(device) + ' --num_tokens 1024'
  40. if use_model_parallel:
  41. command += ' --parallel_type model'
  42. else:
  43. command += ' --parallel_type data'
  44. else:
  45. raise Exception('Unhandled helloworld_file: %s' % helloworld_file)
  46. p = subprocess.Popen(command, shell=True, stdout=subprocess.PIPE, stderr=subprocess.PIPE, env=new_env)
  47. losses = []
  48. while p.poll() is None:
  49. line = p.stdout.readline().decode("utf8").split()
  50. for i in range(len(line) - 1):
  51. if line[i] == 'loss':
  52. if is_round:
  53. if dtype == 'float32':
  54. losses.append(round(float(line[i + 2][:-1]), 3))
  55. else:
  56. losses.append(round(float(line[i + 2][:-1]), 1))
  57. else:
  58. losses.append(line[i + 2][:-1])
  59. break
  60. if show_step_time and line[0] == '[Summary]':
  61. print('step time:', line[5])
  62. p.stdout.close()
  63. assert len(losses) > 0, "No valid loss result found for this unit test: %s" % command
  64. return losses
  65. class TutelTestCase(unittest.TestCase):
  66. """A class for tutel test cases."""
  67. def setUp(self):
  68. """Hook method for setting up the test"""
  69. self.GPUtype = GPU_NAME
  70. with open('tests/test_baseline.json') as f:
  71. self.data = json.load(f)
  72. for i in range(9):
  73. for j in range(len(self.data[i]['losses'])):
  74. if '32' in self.data[i]['dtype']:
  75. self.data[i]['losses'][j] = round(float(self.data[i]['losses'][j]), 3)
  76. else:
  77. self.data[i]['losses'][j] = round(float(self.data[i]['losses'][j]), 1)
  78. self.tutelCaller = HelloworldCaller()
  79. def test_cpu_kernel(self):
  80. """Test cpu kernel"""
  81. cuda_losses = self.tutelCaller.run(nproc_per_node=1, num_steps=10, device='cuda', show_step_time=False)
  82. cpu_losses = self.tutelCaller.run(nproc_per_node=1, num_steps=10, device='cpu', show_step_time=False)
  83. for i in range(10):
  84. cuda_losses[i] = round(cuda_losses[i],2)
  85. cpu_losses[i] = round(cpu_losses[i],2)
  86. self.assertEqual(cuda_losses, cpu_losses)
  87. def test_top1_fp32_1_expert(self):
  88. """Test helloworld with top1 gate, float32 dtype and 1 expert(s)."""
  89. for i in range(len(self.data[2]['step_time'])):
  90. if self.data[2]['step_time'][i]['GPU'] in self.GPUtype:
  91. print('\nexpected time:', self.data[2]['step_time'][i]['value'])
  92. self.assertEqual(self.tutelCaller.run(nproc_per_node=1, helloworld_file='helloworld', top=1, dtype='float32', num_local_experts=1), self.data[2]['losses'])
  93. def test_top1_fp32_2_experts(self):
  94. """Test helloworld with top1 gate, float32 dtype and 2 expert(s)."""
  95. for i in range(len(self.data[3]['step_time'])):
  96. if self.data[3]['step_time'][i]['GPU'] in self.GPUtype:
  97. print('\nexpected time:', self.data[3]['step_time'][i]['value'])
  98. self.assertEqual(self.tutelCaller.run(nproc_per_node=1, helloworld_file='helloworld', top=1, dtype='float32', num_local_experts=2), self.data[3]['losses'])
  99. def test_top1_fp16_1_expert(self):
  100. """Test helloworld with top1 gate, float16 dtype and 1 expert(s)."""
  101. for i in range(len(self.data[0]['step_time'])):
  102. if self.data[0]['step_time'][i]['GPU'] in self.GPUtype:
  103. print('\nexpected time:', self.data[0]['step_time'][i]['value'])
  104. self.assertEqual(self.tutelCaller.run(nproc_per_node=1, helloworld_file='helloworld', top=1, dtype='float16', num_local_experts=1)[0:2], self.data[0]['losses'][0:2])
  105. def test_top1_fp16_2_experts(self):
  106. """Test helloworld with top1 gate, float16 dtype and 2 expert(s)."""
  107. for i in range(len(self.data[1]['step_time'])):
  108. if self.data[1]['step_time'][i]['GPU'] in self.GPUtype:
  109. print('\nexpected time:', self.data[1]['step_time'][i]['value'])
  110. self.assertEqual(self.tutelCaller.run(nproc_per_node=1, helloworld_file='helloworld', top=1, dtype='float16', num_local_experts=2)[0:2], self.data[1]['losses'][0:2])
  111. def test_top2_fp32_1_expert(self):
  112. """Test helloworld with top2 gate, float32 dtype and 1 expert(s)."""
  113. for i in range(len(self.data[6]['step_time'])):
  114. if self.data[6]['step_time'][i]['GPU'] in self.GPUtype:
  115. print('\nexpected time:', self.data[6]['step_time'][i]['value'])
  116. self.assertEqual(self.tutelCaller.run(nproc_per_node=1, helloworld_file='helloworld', top=2, dtype='float32', num_local_experts=1), self.data[6]['losses'])
  117. def test_top2_fp32_2_experts(self):
  118. """Test helloworld with top2 gate, float32 dtype and 2 expert(s)."""
  119. for i in range(len(self.data[7]['step_time'])):
  120. if self.data[7]['step_time'][i]['GPU'] in self.GPUtype:
  121. print('\nexpected time:', self.data[7]['step_time'][i]['value'])
  122. self.assertEqual(self.tutelCaller.run(nproc_per_node=1, helloworld_file='helloworld', top=2, dtype='float32', num_local_experts=2), self.data[7]['losses'])
  123. def test_top2_fp16_1_expert(self):
  124. """Test helloworld with top2 gate, float16 dtype and 1 expert(s)."""
  125. for i in range(len(self.data[4]['step_time'])):
  126. if self.data[4]['step_time'][i]['GPU'] in self.GPUtype:
  127. print('\nexpected time:', self.data[4]['step_time'][i]['value'])
  128. self.assertEqual(self.tutelCaller.run(nproc_per_node=1, helloworld_file='helloworld', top=2, dtype='float16', num_local_experts=1)[0:2], self.data[4]['losses'][0:2])
  129. def test_top2_fp16_2_experts(self):
  130. """Test helloworld with top2 gate, float16 dtype and 2 expert(s)."""
  131. for i in range(len(self.data[5]['step_time'])):
  132. if self.data[5]['step_time'][i]['GPU'] in self.GPUtype:
  133. print('\nexpected time:', self.data[5]['step_time'][i]['value'])
  134. self.assertEqual(self.tutelCaller.run(nproc_per_node=1, helloworld_file='helloworld', top=2, dtype='float16', num_local_experts=2)[0:2], self.data[5]['losses'][0:2])
  135. def test_top2_fp64_2_experts(self):
  136. """Test helloworld with top2 gate, float64 dtype and 2 expert(s)."""
  137. self.assertEqual(self.tutelCaller.run(nproc_per_node=1, helloworld_file='helloworld', top=2, dtype='float64', num_local_experts=2, show_step_time=False, batch_size=1), self.data[8]['losses'])
  138. def test_compare_data_model_parallel(self):
  139. """Test helloworld data parallel and helloworld model parallel which should have same result"""
  140. self.assertEqual(
  141. self.tutelCaller.run(nproc_per_node=2, helloworld_file='helloworld', top=2, dtype='float32', num_local_experts=-2, show_step_time=False, use_model_parallel=False),
  142. self.tutelCaller.run(nproc_per_node=2, helloworld_file='helloworld', top=2, dtype='float32', num_local_experts=-2, show_step_time=False, use_model_parallel=True),
  143. )
  144. def test_a2a_ffn_overlap(self):
  145. """Test whether AllToAll-FFN overlapping works properly. Note that too small batch size might cause precision issue."""
  146. self.assertEqual(
  147. self.tutelCaller.run(nproc_per_node=2, helloworld_file='helloworld', top=2, dtype='float64', num_local_experts=-2, show_step_time=False, batch_size=1, a2a_ffn_overlap_degree=1),
  148. self.tutelCaller.run(nproc_per_node=2, helloworld_file='helloworld', top=2, dtype='float64', num_local_experts=-2, show_step_time=False, batch_size=1, a2a_ffn_overlap_degree=2)
  149. )
  150. self.assertEqual(
  151. self.tutelCaller.run(nproc_per_node=2, helloworld_file='helloworld', top=2, dtype='float64', num_local_experts=1, show_step_time=False, batch_size=1, a2a_ffn_overlap_degree=1),
  152. self.tutelCaller.run(nproc_per_node=2, helloworld_file='helloworld', top=2, dtype='float64', num_local_experts=1, show_step_time=False, batch_size=1, a2a_ffn_overlap_degree=2)
  153. )
  154. self.assertEqual(
  155. self.tutelCaller.run(nproc_per_node=2, helloworld_file='helloworld', top=2, dtype='float64', num_local_experts=2, show_step_time=False, batch_size=1, a2a_ffn_overlap_degree=1),
  156. self.tutelCaller.run(nproc_per_node=2, helloworld_file='helloworld', top=2, dtype='float64', num_local_experts=2, show_step_time=False, batch_size=1, a2a_ffn_overlap_degree=2)
  157. )
  158. def test_a2a_algos(self):
  159. def get_loss_and_step_time(args):
  160. with contextlib.redirect_stdout(io.StringIO()) as f:
  161. loss = self.tutelCaller.run(**args)
  162. step_time = float(f.getvalue().strip().split()[-1])
  163. return loss, step_time
  164. for nproc_per_node, dtype, num_local_experts in itertools.product(
  165. [1, 2],
  166. ['float32', 'float16'],
  167. [1, 2],
  168. ):
  169. test_case = {
  170. 'nproc_per_node': nproc_per_node,
  171. 'helloworld_file': 'helloworld',
  172. 'top': 2,
  173. 'dtype': dtype,
  174. 'num_local_experts': num_local_experts,
  175. 'show_step_time': True,
  176. 'num_steps': 50,
  177. }
  178. with self.subTest(msg='Testing a2a algo with setting', test_case=test_case):
  179. loss_expected, step_time_expected = get_loss_and_step_time(test_case)
  180. for algo in ['LINEAR', '2D']:
  181. with patch.dict('os.environ', {
  182. 'TUTEL_ALLTOALL_ALGO': algo,
  183. 'LOCAL_SIZE': str(nproc_per_node),
  184. }):
  185. loss, step_time = get_loss_and_step_time(test_case)
  186. self.assertEqual(loss, loss_expected)
  187. print('\nsubcase(ndevs=%s, dtype=%s, local_experts=%s, algo=%s): step_time = %s (LINEAR = %s)' % (nproc_per_node, dtype, num_local_experts, algo, step_time, step_time_expected))
  188. self.assertTrue(math.isclose(step_time, step_time_expected, rel_tol=0.01, abs_tol=0.01))