gpt_scaling_test.py

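"""Scaling test for GPT pretraining with Apex's transformer utilities.

Sweeps progressively deeper GPT-2 configurations across several data-,
tensor-, and pipeline-parallel settings (with and without CPU offloading),
launches run_gpt_minimal_test.py for each one, and plots average training
iteration time against parameter count.
"""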
import os
import subprocess

from apex.transformer.testing.commons import TEST_SUCCESS_MESSAGE


def run_gpt(cmd):
    """Launch one training run and parse its stdout for runtime and model size."""
    args = cmd.split(" ")
    p = subprocess.Popen(args, stdout=subprocess.PIPE, stderr=subprocess.PIPE)
    outs, errs = p.communicate()
    outs = outs.decode("utf-8").splitlines()
    success = False
    runtime = 0
    num_params = 0
    for out in outs:
        # Parse the average time per training iteration.
        if "Average Iteration Time:" in out:
            slicey = out[out.find(":") + 2 :]
            try:
                runtime = float(slicey)
            except ValueError:
                print(slicey)
                quit()
        # Parse the total parameter count.
        if "Number of Parameters:" in out:
            slicey = out[out.find(":") + 2 :]
            try:
                num_params = int(slicey)
            except ValueError:
                print(slicey)
                quit()
        if out == str(TEST_SUCCESS_MESSAGE):
            success = True
    # Report the parameter count in billions, rounded to three decimal places.
    return runtime, round(num_params / 10.0 ** 9, 3), success, errs


def plot(runtimes):
    """Scatter-plot iteration time against model size for each distributed setting."""
    import matplotlib.pyplot as plt

    for distributed_setting in runtimes.keys():
        plt.scatter(
            runtimes[distributed_setting].keys(),
            runtimes[distributed_setting].values(),
            label=distributed_setting,
        )
    plt.legend()
    plt.xlabel("Parameters (Billions)")
    plt.ylabel("Training Iteration time (s)")
    plt.title("GPT Scaling w/ Offloading")
    plt.savefig("offload_gpt_scaling.png")
    plt.close()
    # Copy the figure(s) into the shared workspace directory, creating it if needed.
    if not os.path.exists("/my_workspace/"):
        os.makedirs("/my_workspace/")
    os.system("cp *.png /my_workspace/")


def main():
    """Sweep model depth across distributed settings and record iteration times."""
    runtimes = {}
    # Layer counts to test: finer steps for small models, coarser as depth grows.
    nlist = (
        list(range(2000, 10000, 2000))
        + list(range(10000, 50000, 5000))
        + list(range(50000, 100000, 10000))
    )
    print("N-List:", nlist)
    # (data-parallel, tensor-parallel, pipeline-parallel) degrees, 8 GPUs each.
    for data_parr, tens_parr, pipe_parr in [(8, 1, 1), (4, 2, 1), (2, 1, 4), (1, 2, 4)]:
        for offload in [True, False]:
            dist_setting = (
                f"ddp={data_parr}, tensor_parr={tens_parr}, "
                f"pipe_parr={pipe_parr}, offload={offload}"
            )
            runtimes[dist_setting] = {}
            print("Beginning Testing for", dist_setting)
            for n in nlist:
                # Build the launch command for an n-layer GPT-2 configuration.
                cmd = (
                    "python3 -m torch.distributed.launch --nproc_per_node=8"
                    " run_gpt_minimal_test.py"
                    " --micro-batch-size 1"
                    f" --num-layers {n}"
                    " --hidden-size 128 --num-attention-heads 16"
                    " --max-position-embeddings 128 --seq-length 128"
                    f" --tensor-model-parallel-size {tens_parr}"
                    f" --pipeline-model-parallel-size {pipe_parr}"
                    + (" --cpu-offload" if offload else "")
                )
                print(cmd)
                runtime, bill_params, success, errs = run_gpt(cmd)
                if success:
                    runtimes[dist_setting][bill_params] = runtime
                    print(
                        str(runtime) + "s per training iter for",
                        str(bill_params) + "B parameter GPT-2",
                    )
                    # Refresh the plot periodically once the models get large.
                    if n >= 10000:
                        plot(runtimes)
                else:
                    # A failure (e.g. out of memory) means every larger model would
                    # also fail, so skip ahead to the next distributed setting.
                    print("GPT-2 w/", n, "layers failed using", dist_setting)
                    print("Moving on to the next distributed setting...")
                    print("#" * 25)
                    print()
                    plot(runtimes)
                    break
    print(runtimes)
    plot(runtimes)


if __name__ == "__main__":
    main()
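

# Example invocation (a sketch: assumes an 8-GPU node, since the launcher above
# passes --nproc_per_node=8, and that run_gpt_minimal_test.py is present in the
# working directory):
#
#   python3 gpt_scaling_test.py
#
# The scaling plot is saved as offload_gpt_scaling.png and copied to /my_workspace/.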