# Copyright (c) Microsoft Corporation. # Licensed under the MIT license. import contextlib import io import itertools import json import math import os import subprocess import unittest from unittest.mock import patch try: import GPUtil GPU_NAME = GPUtil.getGPUs()[0].name except: GPU_NAME = 'unknown' class HelloworldCaller(): """A class for run tutel helloworld example with different arguments""" def run( self, nproc_per_node=1, helloworld_file='helloworld', top=2, dtype='float32', num_local_experts='2', hidden_size=2048, show_step_time=True, batch_size=16, is_round=True, a2a_ffn_overlap_degree=1, num_steps=100, use_model_parallel=False, device='cuda' ): # Disable NCCL SHM because it's capacity is limited in Azure pipeline new_env = os.environ.copy() new_env['NCCL_SHM_DISABLE'] = '1' """Run helloworld example""" if helloworld_file == 'helloworld': command = 'python3 -m torch.distributed.run --nproc_per_node=' + str(nproc_per_node) + ' tutel/examples/helloworld.py --top ' + str(top) + ' --dtype ' + dtype + ' --num_local_experts ' + str(num_local_experts) + ' --hidden_size ' + str(hidden_size) + ' --batch_size ' + str(batch_size) + ' --a2a_ffn_overlap_degree ' + str(a2a_ffn_overlap_degree) + ' --num_steps ' + str(num_steps) + ' --device ' + str(device) + ' --num_tokens 1024' if use_model_parallel: command += ' --parallel_type model' else: command += ' --parallel_type data' else: raise Exception('Unhandled helloworld_file: %s' % helloworld_file) p = subprocess.Popen(command, shell=True, stdout=subprocess.PIPE, stderr=subprocess.PIPE, env=new_env) losses = [] while p.poll() is None: line = p.stdout.readline().decode("utf8").split() for i in range(len(line) - 1): if line[i] == 'loss': if is_round: if dtype == 'float32': losses.append(round(float(line[i + 2][:-1]), 3)) else: losses.append(round(float(line[i + 2][:-1]), 1)) else: losses.append(line[i + 2][:-1]) break if show_step_time and line[0] == '[Summary]': print('step time:', line[5]) p.stdout.close() assert len(losses) > 0, "No valid loss result found for this unit test: %s" % command return losses class TutelTestCase(unittest.TestCase): """A class for tutel test cases.""" def setUp(self): """Hook method for setting up the test""" self.GPUtype = GPU_NAME with open('tests/test_baseline.json') as f: self.data = json.load(f) for i in range(9): for j in range(len(self.data[i]['losses'])): if '32' in self.data[i]['dtype']: self.data[i]['losses'][j] = round(float(self.data[i]['losses'][j]), 3) else: self.data[i]['losses'][j] = round(float(self.data[i]['losses'][j]), 1) self.tutelCaller = HelloworldCaller() def test_cpu_kernel(self): """Test cpu kernel""" cuda_losses = self.tutelCaller.run(nproc_per_node=1, num_steps=10, device='cuda', show_step_time=False) cpu_losses = self.tutelCaller.run(nproc_per_node=1, num_steps=10, device='cpu', show_step_time=False) for i in range(10): cuda_losses[i] = round(cuda_losses[i],2) cpu_losses[i] = round(cpu_losses[i],2) self.assertEqual(cuda_losses, cpu_losses) def test_top1_fp32_1_expert(self): """Test helloworld with top1 gate, float32 dtype and 1 expert(s).""" for i in range(len(self.data[2]['step_time'])): if self.data[2]['step_time'][i]['GPU'] in self.GPUtype: print('\nexpected time:', self.data[2]['step_time'][i]['value']) self.assertEqual(self.tutelCaller.run(nproc_per_node=1, helloworld_file='helloworld', top=1, dtype='float32', num_local_experts=1), self.data[2]['losses']) def test_top1_fp32_2_experts(self): """Test helloworld with top1 gate, float32 dtype and 2 expert(s).""" for i in range(len(self.data[3]['step_time'])): if self.data[3]['step_time'][i]['GPU'] in self.GPUtype: print('\nexpected time:', self.data[3]['step_time'][i]['value']) self.assertEqual(self.tutelCaller.run(nproc_per_node=1, helloworld_file='helloworld', top=1, dtype='float32', num_local_experts=2), self.data[3]['losses']) def test_top1_fp16_1_expert(self): """Test helloworld with top1 gate, float16 dtype and 1 expert(s).""" for i in range(len(self.data[0]['step_time'])): if self.data[0]['step_time'][i]['GPU'] in self.GPUtype: print('\nexpected time:', self.data[0]['step_time'][i]['value']) self.assertEqual(self.tutelCaller.run(nproc_per_node=1, helloworld_file='helloworld', top=1, dtype='float16', num_local_experts=1)[0:2], self.data[0]['losses'][0:2]) def test_top1_fp16_2_experts(self): """Test helloworld with top1 gate, float16 dtype and 2 expert(s).""" for i in range(len(self.data[1]['step_time'])): if self.data[1]['step_time'][i]['GPU'] in self.GPUtype: print('\nexpected time:', self.data[1]['step_time'][i]['value']) self.assertEqual(self.tutelCaller.run(nproc_per_node=1, helloworld_file='helloworld', top=1, dtype='float16', num_local_experts=2)[0:2], self.data[1]['losses'][0:2]) def test_top2_fp32_1_expert(self): """Test helloworld with top2 gate, float32 dtype and 1 expert(s).""" for i in range(len(self.data[6]['step_time'])): if self.data[6]['step_time'][i]['GPU'] in self.GPUtype: print('\nexpected time:', self.data[6]['step_time'][i]['value']) self.assertEqual(self.tutelCaller.run(nproc_per_node=1, helloworld_file='helloworld', top=2, dtype='float32', num_local_experts=1), self.data[6]['losses']) def test_top2_fp32_2_experts(self): """Test helloworld with top2 gate, float32 dtype and 2 expert(s).""" for i in range(len(self.data[7]['step_time'])): if self.data[7]['step_time'][i]['GPU'] in self.GPUtype: print('\nexpected time:', self.data[7]['step_time'][i]['value']) self.assertEqual(self.tutelCaller.run(nproc_per_node=1, helloworld_file='helloworld', top=2, dtype='float32', num_local_experts=2), self.data[7]['losses']) def test_top2_fp16_1_expert(self): """Test helloworld with top2 gate, float16 dtype and 1 expert(s).""" for i in range(len(self.data[4]['step_time'])): if self.data[4]['step_time'][i]['GPU'] in self.GPUtype: print('\nexpected time:', self.data[4]['step_time'][i]['value']) self.assertEqual(self.tutelCaller.run(nproc_per_node=1, helloworld_file='helloworld', top=2, dtype='float16', num_local_experts=1)[0:2], self.data[4]['losses'][0:2]) def test_top2_fp16_2_experts(self): """Test helloworld with top2 gate, float16 dtype and 2 expert(s).""" for i in range(len(self.data[5]['step_time'])): if self.data[5]['step_time'][i]['GPU'] in self.GPUtype: print('\nexpected time:', self.data[5]['step_time'][i]['value']) self.assertEqual(self.tutelCaller.run(nproc_per_node=1, helloworld_file='helloworld', top=2, dtype='float16', num_local_experts=2)[0:2], self.data[5]['losses'][0:2]) def test_top2_fp64_2_experts(self): """Test helloworld with top2 gate, float64 dtype and 2 expert(s).""" self.assertEqual(self.tutelCaller.run(nproc_per_node=1, helloworld_file='helloworld', top=2, dtype='float64', num_local_experts=2, show_step_time=False, batch_size=1), self.data[8]['losses']) def test_compare_data_model_parallel(self): """Test helloworld data parallel and helloworld model parallel which should have same result""" self.assertEqual( self.tutelCaller.run(nproc_per_node=2, helloworld_file='helloworld', top=2, dtype='float32', num_local_experts=-2, show_step_time=False, use_model_parallel=False), self.tutelCaller.run(nproc_per_node=2, helloworld_file='helloworld', top=2, dtype='float32', num_local_experts=-2, show_step_time=False, use_model_parallel=True), ) def test_a2a_ffn_overlap(self): """Test whether AllToAll-FFN overlapping works properly. Note that too small batch size might cause precision issue.""" self.assertEqual( self.tutelCaller.run(nproc_per_node=2, helloworld_file='helloworld', top=2, dtype='float64', num_local_experts=-2, show_step_time=False, batch_size=1, a2a_ffn_overlap_degree=1), self.tutelCaller.run(nproc_per_node=2, helloworld_file='helloworld', top=2, dtype='float64', num_local_experts=-2, show_step_time=False, batch_size=1, a2a_ffn_overlap_degree=2) ) self.assertEqual( self.tutelCaller.run(nproc_per_node=2, helloworld_file='helloworld', top=2, dtype='float64', num_local_experts=1, show_step_time=False, batch_size=1, a2a_ffn_overlap_degree=1), self.tutelCaller.run(nproc_per_node=2, helloworld_file='helloworld', top=2, dtype='float64', num_local_experts=1, show_step_time=False, batch_size=1, a2a_ffn_overlap_degree=2) ) self.assertEqual( self.tutelCaller.run(nproc_per_node=2, helloworld_file='helloworld', top=2, dtype='float64', num_local_experts=2, show_step_time=False, batch_size=1, a2a_ffn_overlap_degree=1), self.tutelCaller.run(nproc_per_node=2, helloworld_file='helloworld', top=2, dtype='float64', num_local_experts=2, show_step_time=False, batch_size=1, a2a_ffn_overlap_degree=2) ) def test_a2a_algos(self): def get_loss_and_step_time(args): with contextlib.redirect_stdout(io.StringIO()) as f: loss = self.tutelCaller.run(**args) step_time = float(f.getvalue().strip().split()[-1]) return loss, step_time for nproc_per_node, dtype, num_local_experts in itertools.product( [1, 2], ['float32', 'float16'], [1, 2], ): test_case = { 'nproc_per_node': nproc_per_node, 'helloworld_file': 'helloworld', 'top': 2, 'dtype': dtype, 'num_local_experts': num_local_experts, 'show_step_time': True, 'num_steps': 50, } with self.subTest(msg='Testing a2a algo with setting', test_case=test_case): loss_expected, step_time_expected = get_loss_and_step_time(test_case) for algo in ['LINEAR', '2D']: with patch.dict('os.environ', { 'TUTEL_ALLTOALL_ALGO': algo, 'LOCAL_SIZE': str(nproc_per_node), }): loss, step_time = get_loss_and_step_time(test_case) self.assertEqual(loss, loss_expected) print('\nsubcase(ndevs=%s, dtype=%s, local_experts=%s, algo=%s): step_time = %s (LINEAR = %s)' % (nproc_per_node, dtype, num_local_experts, algo, step_time, step_time_expected)) self.assertTrue(math.isclose(step_time, step_time_expected, rel_tol=0.01, abs_tol=0.01))