# test_layers.py

import logging
import unittest
import typing

import torch
import torch.nn as nn
from torch.testing._internal import common_utils

from apex.transformer import parallel_state
from apex.transformer.tensor_parallel import layers
from apex.transformer.testing.commons import set_random_seed
from apex.transformer.testing.distributed_test_base import NcclDistributedTestBase
from apex.transformer.testing.distributed_test_base import UccDistributedTestBase

logging.getLogger("torch").setLevel(logging.WARNING)
logging.getLogger("apex").setLevel(logging.WARNING)

# N.B.(mkozuki): Disable TF32 matrix multiply.
# Matrices used in this test are so small that TF32 matmul
# can be less precise, which makes `self.assertEqual` fail.
torch.backends.cuda.matmul.allow_tf32 = False
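# NOTE: on PyTorch >= 1.12 (an assumption about the runtime, not something this
# file requires), `torch.set_float32_matmul_precision("highest")` is an
# equivalent way to keep fp32 matmuls out of TF32 mode.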


class TensorParallelLayerTestBase:

    BATCH_SIZE: int = 8
    SEQUENCE_LENGTH: int = 128
    VOCAB_SIZE: int = 1024
    HIDDEN_SIZE: int = 256
    INPUT_SIZE_COEFF: int = 256
    OUTPUT_SIZE_COEFF: int = 256
    SEED: int = 123456

    @property
    def tensor_shape(self) -> typing.Sequence[int]:
        return [self.SEQUENCE_LENGTH, self.BATCH_SIZE, self.HIDDEN_SIZE]
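    # `tensor_shape` is (sequence length, batch size, hidden size); the shape
    # checks and chunking in the tests below assume this dimension ordering.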
    @torch.no_grad()
    @unittest.skipIf(torch.cuda.device_count() < 2, "Requires >=2 GPUs")
    def test_all_gather_parity(self) -> None:
        if self.DISTRIBUTED_BACKEND == "ucc":
            self.skipTest("torch_ucc does NOT support `torch.distributed._all_gather_base` as of 2022/06/15")
        from torch.distributed.distributed_c10d import all_gather, _all_gather_base  # NOQA

        for tensor_model_parallel_world_size in range(1, self.world_size + 1):
            if self.world_size % tensor_model_parallel_world_size:
                continue
            parallel_state.initialize_model_parallel(
                tensor_model_parallel_size_=tensor_model_parallel_world_size,
            )
            tensor_model_parallel_rank = parallel_state.get_tensor_model_parallel_rank()
            cur_tensor_model_device = torch.device(f"cuda:{tensor_model_parallel_rank}")
            with torch.no_grad():
                tensor = tensor_model_parallel_rank * torch.ones(
                    self.tensor_shape, dtype=torch.float32, device=cur_tensor_model_device)
            numel = tensor.numel()
            numel_gathered = tensor_model_parallel_world_size * numel
            gathered = torch.empty(
                torch.Size((numel_gathered,)),
                device=cur_tensor_model_device,
                dtype=torch.float32,
                requires_grad=False,
            )
            chunks = [
                gathered[i * numel : (i + 1) * numel]
                for i in range(tensor_model_parallel_world_size)
            ]
            all_gather(chunks, tensor, group=parallel_state.get_tensor_model_parallel_group())

            gathered_for_base = torch.empty(
                torch.Size((numel_gathered,)),
                device=cur_tensor_model_device,
                dtype=torch.float32,
                requires_grad=False,
            )
            _all_gather_base(
                gathered_for_base,
                tensor,
                group=parallel_state.get_tensor_model_parallel_group(),
            )

            msg = f"tensor_model_parallel_world_size: {tensor_model_parallel_world_size}"
            self.assertEqual(gathered, gathered_for_base, msg=msg)
            parallel_state.destroy_model_parallel()
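    # Same kind of parity check as above, this time comparing `reduce_scatter`
    # (explicit input list) against the fused `_reduce_scatter_base`.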
    @torch.no_grad()
    @unittest.skipIf(torch.cuda.device_count() < 2, "Requires >=2 GPUs")
    def test_reduce_scatter_parity(self) -> None:
        if self.DISTRIBUTED_BACKEND == "ucc":
            self.skipTest("torch_ucc does NOT support `torch.distributed._reduce_scatter_base` as of 2022/06/15")
        from torch.distributed.distributed_c10d import reduce_scatter, _reduce_scatter_base  # NOQA

        for tensor_model_parallel_world_size in range(2, self.world_size + 1):
            if self.world_size % tensor_model_parallel_world_size:
                continue
            parallel_state.initialize_model_parallel(
                tensor_model_parallel_size_=tensor_model_parallel_world_size,
            )
            tensor_model_parallel_rank = parallel_state.get_tensor_model_parallel_rank()
            cur_tensor_model_device = torch.device(f"cuda:{tensor_model_parallel_rank}")
            with torch.no_grad():
                input = torch.cat([
                    i * torch.ones(self.tensor_shape, dtype=torch.float32, device=cur_tensor_model_device)
                    for i in range(tensor_model_parallel_world_size)
                ])
            input_list = [t.clone() for t in input.chunk(tensor_model_parallel_world_size)]
            output = torch.empty(
                self.tensor_shape,
                device=cur_tensor_model_device,
                dtype=torch.float32,
                requires_grad=False,
            )
            reduce_scatter(
                output, input_list,
                group=parallel_state.get_tensor_model_parallel_group(),
            )

            output_for_base = torch.empty(
                self.tensor_shape,
                device=cur_tensor_model_device,
                dtype=torch.float32,
                requires_grad=False,
            )
            _reduce_scatter_base(
                output_for_base,
                input,
                group=parallel_state.get_tensor_model_parallel_group(),
            )

            msg = f"tensor_model_parallel_world_size: {tensor_model_parallel_world_size}"
            self.assertEqual(output, output_for_base, msg=msg)
            self.assertEqual(input, torch.cat(input_list), msg=msg)
            parallel_state.destroy_model_parallel()
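    # Parity of `VocabParallelEmbedding` against `torch.nn.Embedding`: forward
    # output, loss, and the rank-local shard of the weight gradient must match.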
    def test_parallel_embedding(self) -> None:
        for tensor_model_parallel_world_size in range(1, self.world_size + 1):
            if self.world_size % tensor_model_parallel_world_size:
                continue
            parallel_state.initialize_model_parallel(
                tensor_model_parallel_size_=tensor_model_parallel_world_size,
            )
            set_random_seed(self.SEED + 1)
            input_tensor = torch.randint(
                0,
                self.VOCAB_SIZE,
                (
                    self.BATCH_SIZE,
                    self.SEQUENCE_LENGTH,
                ),
                device="cuda",
            )
            loss_weight = torch.randn(
                (
                    self.BATCH_SIZE,
                    self.SEQUENCE_LENGTH,
                    self.HIDDEN_SIZE,
                ),
                device="cuda",
            )

            set_random_seed(self.SEED)
            embedding_torch = nn.Embedding(
                self.VOCAB_SIZE,
                self.HIDDEN_SIZE,
            ).cuda()
            output_torch = embedding_torch(input_tensor)
            loss_torch = torch.mul(output_torch, loss_weight).sum()
            loss_torch.backward()

            # N.B.(mkozuki): With affine weight initialization on GPU,
            # it's super difficult to keep consistency with nn.Embedding.
            # Thus, turning on `use_cpu_initialization`.
            set_random_seed(self.SEED)
            embedding_vocab_parallel = layers.VocabParallelEmbedding(
                self.VOCAB_SIZE,
                self.HIDDEN_SIZE,
                init_method=nn.init.normal_,
                use_cpu_initialization=True,
            ).cuda()
            output_vocab_parallel = embedding_vocab_parallel(input_tensor)
            loss_vocab_parallel = torch.mul(
                output_vocab_parallel, loss_weight
            ).sum()
            loss_vocab_parallel.backward()

            msg = f"tensor_model_parallel_world_size: {tensor_model_parallel_world_size}"
            self.assertEqual(output_torch, output_vocab_parallel, msg=msg)
            self.assertEqual(loss_torch, loss_vocab_parallel, msg=msg)

            splitted_weight_torch = torch.split(
                embedding_torch.weight.grad,
                self.VOCAB_SIZE
                // tensor_model_parallel_world_size,
                0,
            )[parallel_state.get_tensor_model_parallel_rank()]
            self.assertEqual(
                splitted_weight_torch, embedding_vocab_parallel.weight.grad, msg=msg,
            )

            parallel_state.destroy_model_parallel()
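    # Checks `_initialize_affine_weight_cpu` / `_initialize_affine_weight_gpu`
    # against a reference initialization with the same seed: the rank's shard of
    # a full master weight for the CPU path, a directly initialized shard-shaped
    # tensor for the GPU path.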
    def _affine_weight_init_test_impl(
        self, init_device: str, is_column_parallel: bool
    ) -> None:
        dim = int(not is_column_parallel)
        for tensor_model_parallel_world_size in range(1, self.world_size + 1):
            if self.world_size % tensor_model_parallel_world_size:
                continue
            parallel_state.initialize_model_parallel(
                tensor_model_parallel_size_=tensor_model_parallel_world_size
            )
            input_size: int = self.INPUT_SIZE_COEFF * tensor_model_parallel_world_size
            output_size: int = self.OUTPUT_SIZE_COEFF * tensor_model_parallel_world_size

            weight_shape = (
                (self.OUTPUT_SIZE_COEFF, input_size)
                if is_column_parallel
                else (output_size, self.INPUT_SIZE_COEFF)
            )
            weight = torch.empty(weight_shape)
            set_random_seed(self.SEED)

            sharding_dim_size = (
                self.OUTPUT_SIZE_COEFF
                if is_column_parallel
                else self.INPUT_SIZE_COEFF
            )

            if init_device == "cpu":
                layers._initialize_affine_weight_cpu(
                    weight,
                    output_size,
                    input_size,
                    sharding_dim_size,
                    dim,
                    nn.init.normal_,
                    params_dtype=torch.float32,
                )
            else:
                layers._initialize_affine_weight_gpu(
                    weight, torch.nn.init.normal_, dim
                )

            # Target
            set_random_seed(self.SEED)
            if init_device == "cpu":
                main_weight = torch.empty(output_size, input_size)
                nn.init.normal_(main_weight)
                curr_weight = torch.split(main_weight, sharding_dim_size, dim=dim)[
                    parallel_state.get_tensor_model_parallel_rank()
                ]
            else:
                curr_weight = torch.empty(*weight_shape)
                nn.init.normal_(curr_weight)

            self.assertEqual(
                curr_weight, weight, msg=f"tensor_model_parallel_world_size: {tensor_model_parallel_world_size}")
            parallel_state.destroy_model_parallel()

    def test_affine_weight_init_column_parallel_cpu(self) -> None:
        self._affine_weight_init_test_impl(init_device="cpu", is_column_parallel=True)

    def test_affine_weight_init_column_parallel_gpu(self) -> None:
        self._affine_weight_init_test_impl(init_device="gpu", is_column_parallel=True)

    def test_affine_weight_init_row_parallel_cpu(self) -> None:
        self._affine_weight_init_test_impl(init_device="cpu", is_column_parallel=False)

    def test_affine_weight_init_row_parallel_gpu(self) -> None:
        self._affine_weight_init_test_impl(init_device="gpu", is_column_parallel=False)
    def test_row_parallel_linear(self) -> None:
        self._row_parallel_linear_test_impl(False, False, False)

    def test_row_parallel_linear_gradient_accumulation_fusion(self) -> None:
        self._row_parallel_linear_test_impl(True, False, False)

    def test_row_parallel_linear_gradient_accumulation_fusion_in_fp16(self) -> None:
        self._row_parallel_linear_test_impl(True, True, False)

    # fails on native ucc and torch ucc: ucc does not support reduce scatter
    @unittest.skipIf(torch.cuda.device_count() < 2, "Sequence Parallel requires >=2 GPUs")
    def test_row_parallel_linear_sequence_parallel(self) -> None:
        self._row_parallel_linear_test_impl(False, False, True)

    # TODO(mkozuki): Merge this with `_column_parallel_linear_test_impl`.
    # Note that `input_is_parallel` is unique to `RowParallelLinear`, which could make the merge complicated.
    def _row_parallel_linear_test_impl(
        self,
        gradient_accumulation_fusion: bool,
        accumulation_in_fp16: bool,
        sequence_parallel_enabled: bool,
    ) -> None:
        tensor_shape = (
            self.SEQUENCE_LENGTH,
            self.BATCH_SIZE,
            self.HIDDEN_SIZE,
        )
        for tensor_model_parallel_world_size in range(
            1 + int(sequence_parallel_enabled), self.world_size + 1
        ):
            if self.world_size % tensor_model_parallel_world_size:
                continue
            parallel_state.initialize_model_parallel(
                tensor_model_parallel_size_=tensor_model_parallel_world_size,
            )
            set_random_seed(self.SEED)
            linear = layers.RowParallelLinear(
                self.HIDDEN_SIZE,
                self.HIDDEN_SIZE,
                keep_master_weight_for_test=True,
                params_dtype=torch.float32,
                use_cpu_initialization=True,
                gradient_accumulation_fusion=gradient_accumulation_fusion,
                accumulation_in_fp16=accumulation_in_fp16,
                sequence_parallel_enabled=sequence_parallel_enabled,
                # n.b.(mkozuki): RowParallelLinear is constructed with `input_is_parallel=True`
                # by default, e.g. https://github.com/NVIDIA/NeMo/blob/782b4e1652aaa43c8be390d9\
                # db0dc89544afa080/nemo/collections/nlp/modules/common/megatron/transformer.py#L204
                input_is_parallel=True,
            ).cuda()
            if accumulation_in_fp16:
                linear = linear.half()
            # Simulate the situation where fusion of weight grad calculation and gradient accumulation is enabled.
            if gradient_accumulation_fusion:
                with torch.no_grad():
                    linear.weight.main_grad = torch.zeros_like(linear.weight)

            msg = f"tensor_model_parallel_world_size: {tensor_model_parallel_world_size}"

            with torch.no_grad():
                orig_input_tensor = torch.randn(tensor_shape, requires_grad=True, device="cuda")
                orig_loss_weight = torch.randn(tensor_shape, device="cuda")
                input_tensor = orig_input_tensor.chunk(
                    chunks=tensor_model_parallel_world_size,
                    dim=2,
                )[parallel_state.get_tensor_model_parallel_rank()].contiguous()
                if sequence_parallel_enabled:
                    loss_weight = orig_loss_weight.chunk(
                        chunks=tensor_model_parallel_world_size,
                        dim=0,
                    )[parallel_state.get_tensor_model_parallel_rank()]
                else:
                    loss_weight = orig_loss_weight
            if accumulation_in_fp16:
                orig_input_tensor = orig_input_tensor.half()
                input_tensor = input_tensor.half()
                loss_weight = loss_weight.half()
            input_tensor.requires_grad_()
            output, _ = linear(input_tensor)
            loss = torch.mul(output, loss_weight).sum()
            loss.backward()
            self.assertIsNotNone(input_tensor.grad, msg=msg)

            ref_linear = nn.Linear(
                in_features=self.HIDDEN_SIZE,
                out_features=self.HIDDEN_SIZE,
                bias=False,
                device="cuda",
            )
            with torch.no_grad():
                dldy = orig_loss_weight.clone()
                x = orig_input_tensor.clone()
                ref_linear.weight.copy_(linear.master_weight)
                if accumulation_in_fp16:
                    ref_linear = ref_linear.half()
            x.requires_grad_()
            expected_output = ref_linear(x)
            expected_loss = torch.mul(expected_output, dldy).sum()
            expected_loss.backward()

            if not accumulation_in_fp16:
                if sequence_parallel_enabled:
                    self.assertEqual(
                        x=output,
                        y=expected_output.chunk(
                            chunks=tensor_model_parallel_world_size,
                            dim=0,
                        )[parallel_state.get_tensor_model_parallel_rank()],
                        msg=msg,
                    )
                else:
                    self.assertEqual(
                        x=output,
                        y=expected_output,
                        msg=msg,
                    )

            grad_attr_name = "main_grad" if gradient_accumulation_fusion else "grad"
            # NOTE(mkozuki): Numerical errors seem to be amplified by tensor model parallelism.
            if tensor_model_parallel_world_size == 1:
                self.assertEqual(
                    x=getattr(linear.weight, grad_attr_name),
                    y=ref_linear.weight.grad.chunk(
                        chunks=tensor_model_parallel_world_size,
                        dim=0,
                    )[parallel_state.get_tensor_model_parallel_rank()],
                    msg=msg,
                )

            parallel_state.destroy_model_parallel()
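    # `ColumnParallelLinear` variants: plain, async tensor-model-parallel
    # all-reduce, gradient accumulation fusion (fp32 and fp16), and sequence
    # parallel; the async + sequence-parallel combination must raise.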
    def test_column_parallel_linear(self):
        self._column_parallel_linear_test_impl(False, False, False, False)

    def test_column_parallel_linear_async(self):
        self._column_parallel_linear_test_impl(True, False, False, False)

    def test_column_parallel_linear_gradient_accumulation_fusion(self):
        self._column_parallel_linear_test_impl(False, True, False, False)

    def test_column_parallel_linear_gradient_accumulation_fusion_in_fp16(self):
        self._column_parallel_linear_test_impl(False, True, True, False)

    def test_column_parallel_linear_sequence_parallel(self):
        if self.DISTRIBUTED_BACKEND == "ucc":
            self.skipTest("Backward's reduce_scatter fails as of 2022/06/15")
        self._column_parallel_linear_test_impl(False, False, False, True)

    @unittest.skipIf(torch.cuda.device_count() < 2, "Sequence Parallel requires >= 2 GPUs")
    def test_column_parallel_linear_exception(self):
        with self.assertRaisesRegex(
            RuntimeError,
            "`async_tensor_model_parallel_allreduce` and `sequence_parallel_enabled` cannot be enabled at the same time.",
        ):
            self._column_parallel_linear_test_impl(True, False, False, True)
    def _column_parallel_linear_test_impl(
        self,
        async_tensor_model_parallel_allreduce: bool,
        gradient_accumulation_fusion: bool,
        accumulation_in_fp16: bool,
        sequence_parallel_enabled: bool,
    ):
        for tensor_model_parallel_world_size in range(1, self.world_size + 1):
            if async_tensor_model_parallel_allreduce and sequence_parallel_enabled:
                if tensor_model_parallel_world_size == 1:
                    continue
            if self.world_size % tensor_model_parallel_world_size:
                continue
            msg = f"tensor_model_parallel_world_size: {tensor_model_parallel_world_size}"
            parallel_state.initialize_model_parallel(
                tensor_model_parallel_size_=tensor_model_parallel_world_size,
            )

            input_tensor_shape = self.tensor_shape
            expected_output_shape = self.tensor_shape
            # When sequence parallel is enabled, `gather_output` is disabled, i.e.,
            # the matmul output is not gathered along the feature/hidden (last) dimension.
            if sequence_parallel_enabled:
                expected_output_shape[-1] //= tensor_model_parallel_world_size

            # tensor's shape is [sequence length, batch size, hidden size]
            set_random_seed(self.SEED)
            linear = layers.ColumnParallelLinear(
                self.HIDDEN_SIZE,
                self.HIDDEN_SIZE,
                bias=False,
                keep_master_weight_for_test=True,
                params_dtype=torch.float32,
                use_cpu_initialization=True,
                gather_output=not sequence_parallel_enabled,
                no_async_tensor_model_parallel_allreduce=not async_tensor_model_parallel_allreduce,
                gradient_accumulation_fusion=gradient_accumulation_fusion,
                accumulation_in_fp16=accumulation_in_fp16,
                sequence_parallel_enabled=sequence_parallel_enabled,
            ).cuda()
            if accumulation_in_fp16:
                linear = linear.half()

            # Simulate the situation where fusion of weight grad calculation and gradient accumulation happens.
            if gradient_accumulation_fusion:
                with torch.no_grad():
                    linear.weight.main_grad = torch.zeros_like(linear.weight)

            orig_input_tensor = torch.randn(input_tensor_shape, device="cuda", requires_grad=True)
            if accumulation_in_fp16:
                orig_input_tensor = orig_input_tensor.half()
            if sequence_parallel_enabled:
                input_tensor = list(
                    orig_input_tensor.chunk(tensor_model_parallel_world_size, dim=0)
                )[parallel_state.get_tensor_model_parallel_rank()]
            else:
                input_tensor = orig_input_tensor
            output, _ = linear(input_tensor)
            # The order of dimension is expected to be (sequence, batch, hidden)
            self.assertEqual(output.shape, expected_output_shape, msg=msg)

            orig_loss_weight = torch.randn(input_tensor_shape, device="cuda")
            if accumulation_in_fp16:
                orig_loss_weight = orig_loss_weight.half()
            if sequence_parallel_enabled:
                loss_weight = orig_loss_weight.chunk(
                    tensor_model_parallel_world_size, dim=2,
                )[parallel_state.get_tensor_model_parallel_rank()]
            else:
                loss_weight = orig_loss_weight
            loss = torch.mul(output, loss_weight).sum()
            loss.backward()

            with torch.no_grad():
                dldy = orig_loss_weight.clone()
                x = orig_input_tensor.clone()
                ref_linear = nn.Linear(
                    in_features=self.HIDDEN_SIZE,
                    out_features=self.HIDDEN_SIZE,
                    bias=False,
                    device="cuda",
                )
                if accumulation_in_fp16:
                    ref_linear = ref_linear.half()
                # NOTE(mkozuki): `master_weight` is available because `keep_master_weight_for_test` is set.
                ref_linear.weight.copy_(linear.master_weight)
            x.requires_grad_()
            expected_output = ref_linear(x)
            if sequence_parallel_enabled:
                chunk = expected_output.chunk(
                    tensor_model_parallel_world_size,
                    dim=2,
                )[parallel_state.get_tensor_model_parallel_rank()]
                self.assertEqual(
                    x=output,
                    y=chunk,
                    msg=msg,
                )
            else:
                self.assertEqual(
                    x=output,
                    y=expected_output,
                    msg=msg,
                )
            expected_loss = torch.mul(expected_output, dldy).sum()
            expected_loss.backward()

            grad_attr_name = "main_grad" if gradient_accumulation_fusion else "grad"
            # NOTE(mkozuki): Numerical errors seem to be amplified by tensor model parallelism.
            if tensor_model_parallel_world_size == 1:
                self.assertEqual(
                    x=getattr(linear.weight, grad_attr_name),
                    y=ref_linear.weight.grad.chunk(
                        chunks=tensor_model_parallel_world_size,
                        dim=0,
                    )[parallel_state.get_tensor_model_parallel_rank()],
                    msg=msg,
                )

            parallel_state.destroy_model_parallel()
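
# Concrete test classes: the shared suite above is run once against the NCCL
# backend and once against the UCC backend.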
class NcclTensorParallelLayerTest(TensorParallelLayerTestBase, NcclDistributedTestBase):
    pass


class UccTensorParallelLayerTest(TensorParallelLayerTestBase, UccDistributedTestBase):
    pass


if __name__ == "__main__":
    common_utils.run_tests()
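
# Usage sketch (an assumption about the surrounding test setup, not stated in
# this file): with at least 2 GPUs visible, the file can be run directly, e.g.
#   python test_layers.py
# where the Nccl/Ucc distributed test base classes are expected to take care of
# process spawning and process-group initialization for each backend.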