# test_batch_sampler.py

import torch
from torch.testing._internal import common_utils
from torch.utils.data import Dataset
from torch.utils.data import DataLoader

from apex.transformer.pipeline_parallel.utils import _split_batch_into_microbatch as split_batch_into_microbatch

class MyIterableDataset(Dataset):
    def __init__(self, start, end):
        super().__init__()
        assert end > start, "this example code only works with end > start"
        self.start = start
        self.end = end
        self.samples = list(range(self.start, self.end))

    def __iter__(self):
        return iter(range(self.start, self.end))

    def __getitem__(self, index):
        return self.samples[index]
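
# The sampler below appears to be a test-local copy of Megatron-LM's
# MegatronPretrainingRandomSampler (the naming suggests so; this file does not
# say it explicitly). It yields lists of dataset *indices*, one micro-batch at
# a time for a single data-parallel rank, which is why it is passed to
# DataLoader through the `batch_sampler=` argument in the tests below.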

class MegatronPretrainingRandomSampler:
    def __init__(self, total_samples, consumed_samples, micro_batch_size,
                 data_parallel_rank, data_parallel_size):
        # Keep a copy of input params for later use.
        self.total_samples = total_samples
        self.consumed_samples = consumed_samples
        self.micro_batch_size = micro_batch_size
        self.data_parallel_rank = data_parallel_rank
        self.data_parallel_size = data_parallel_size
        self.micro_batch_times_data_parallel_size = \
            self.micro_batch_size * data_parallel_size
        self.last_batch_size = \
            self.total_samples % self.micro_batch_times_data_parallel_size

        # Sanity checks.
        assert self.total_samples > 0, \
            'no sample to consume: {}'.format(self.total_samples)
        assert self.micro_batch_size > 0
        assert data_parallel_size > 0
        assert self.data_parallel_rank < data_parallel_size, \
            'data_parallel_rank should be smaller than data size: {}, ' \
            '{}'.format(self.data_parallel_rank, data_parallel_size)

    def __len__(self):
        return self.total_samples

    def __iter__(self):
        active_total_samples = self.total_samples - self.last_batch_size
        self.epoch = self.consumed_samples // active_total_samples
        current_epoch_samples = self.consumed_samples % active_total_samples
        assert current_epoch_samples % self.micro_batch_times_data_parallel_size == 0

        # Data sharding and random sampling.
        bucket_size = (self.total_samples // self.micro_batch_times_data_parallel_size) * self.micro_batch_size
        bucket_offset = current_epoch_samples // self.data_parallel_size
        start_idx = self.data_parallel_rank * bucket_size

        g = torch.Generator()
        g.manual_seed(self.epoch)
        random_idx = torch.randperm(bucket_size, generator=g).tolist()
        idx_range = [start_idx + x for x in random_idx[bucket_offset:]]

        batch = []
        # Last batch if not complete will be dropped.
        for idx in idx_range:
            batch.append(idx)
            if len(batch) == self.micro_batch_size:
                self.consumed_samples += self.micro_batch_times_data_parallel_size
                yield batch
                batch = []
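
# A minimal worked example of the index math above (illustrative numbers, not
# used by the tests): with total_samples=10, consumed_samples=0,
# micro_batch_size=4 and data_parallel_size=1, we get
# last_batch_size = 10 % 4 = 2 and bucket_size = (10 // 4) * 4 = 8, so
# __iter__ shuffles indices 0..7 with torch.randperm(8) and yields two
# micro-batches of 4 indices each; samples 8 and 9 fall outside the bucket and
# are dropped, matching the "last batch ... will be dropped" comment.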

# Samples 8 tensors in total.
# First sample 4 tensors twice, then sample 2 tensors four times.
class TestBatchSamplerBehavior(common_utils.TestCase):
    def tearDown(self) -> None:
        torch.cuda.empty_cache()
        super().tearDown()

    def test_batch_sampler_behavior(self):
        dataset = MyIterableDataset(0, 100)

        for num_workers in (1, 2, 4):
            torch.manual_seed(42)
            loader = DataLoader(dataset, batch_sampler=MegatronPretrainingRandomSampler(100, 0, 4, 0, 1), num_workers=num_workers)
            samples = []
            for i, batch in enumerate(loader):
                samples.append(batch)
                if i == 2 - 1:
                    break

            torch.manual_seed(42)
            loader = DataLoader(dataset, batch_sampler=MegatronPretrainingRandomSampler(100, 0, 2, 0, 1), num_workers=num_workers)
            samples2 = []
            for i, batch in enumerate(loader):
                samples2.append(batch)
                if i == 4 - 1:
                    break

            self.assertEqual(torch.cat(samples), torch.cat(samples2), msg=f"num_workers={num_workers}")
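
    # For context on the next test: `batch` is a 2-tuple of tensors, each with
    # global_batch_size rows, produced by the default collate_fn. Judging from
    # the assertions (not from apex documentation), split_batch_into_microbatch
    # is expected to cut that batch into global_batch_size // _micro_batch_size
    # microbatches whose tensors each hold _micro_batch_size rows.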

    def test_split_batch(self):

        class MyIterableDataset(Dataset):
            def __init__(self, start, end):
                super().__init__()
                assert end > start, "this example code only works with end > start"
                self.start = start
                self.end = end
                self.samples = list(range(self.start, self.end))

            def __len__(self):
                return self.end - self.start

            def __iter__(self):
                return iter(range(self.start, self.end))

            def __getitem__(self, index):
                return (torch.tensor([index, index]), torch.tensor([index // 2, index // 2]))

        dataset = MyIterableDataset(0, 100)
        torch.manual_seed(42)
        global_batch_size = 16
        loader = DataLoader(dataset, batch_sampler=MegatronPretrainingRandomSampler(100, 0, global_batch_size, 0, 1), num_workers=2)
        batch = next(iter(loader))

        for _micro_batch_size in (1, 2, 4, 8):
            microbatches = list(split_batch_into_microbatch(
                batch,
                _micro_batch_size=_micro_batch_size,
                _global_batch_size=global_batch_size,
            ))
            self.assertEqual(len(microbatches), global_batch_size // _micro_batch_size)
            self.assertEqual(len(microbatches[0][0]), _micro_batch_size)


if __name__ == "__main__":
    common_utils.run_tests()