#!/usr/bin/env python3
# -*- coding:utf-8 -*-
# Code are based on
# https://github.com/facebookresearch/detectron2/blob/master/detectron2/engine/launch.py
# Copyright (c) Facebook, Inc. and its affiliates.
# Copyright (c) Megvii, Inc. and its affiliates.
from loguru import logger

import torch
import torch.distributed as dist
import torch.multiprocessing as mp

import yolox.utils.dist as comm
from yolox.utils import configure_nccl

import os
import subprocess
import sys
import time

__all__ = ["launch"]
  18. def _find_free_port():
  19. """
  20. Find an available port of current machine / node.
  21. """
  22. import socket
  23. sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
  24. # Binding to port 0 will cause the OS to find an available port for us
  25. sock.bind(("", 0))
  26. port = sock.getsockname()[1]
  27. sock.close()
  28. # NOTE: there is still a chance the port could be taken by other processes.
  29. return port
  30. def launch(
  31. main_func,
  32. num_gpus_per_machine,
  33. num_machines=1,
  34. machine_rank=0,
  35. backend="nccl",
  36. dist_url=None,
  37. args=(),
  38. ):
  39. """
  40. Args:
  41. main_func: a function that will be called by `main_func(*args)`
  42. num_machines (int): the total number of machines
  43. machine_rank (int): the rank of this machine (one per machine)
  44. dist_url (str): url to connect to for distributed training, including protocol
  45. e.g. "tcp://127.0.0.1:8686".
  46. Can be set to auto to automatically select a free port on localhost
  47. args (tuple): arguments passed to main_func
  48. """
  49. world_size = num_machines * num_gpus_per_machine
  50. if world_size > 1:
  51. if int(os.environ.get("WORLD_SIZE", "1")) > 1:
  52. dist_url = "{}:{}".format(
  53. os.environ.get("MASTER_ADDR", None),
  54. os.environ.get("MASTER_PORT", "None"),
  55. )
  56. local_rank = int(os.environ.get("LOCAL_RANK", "0"))
  57. world_size = int(os.environ.get("WORLD_SIZE", "1"))
  58. _distributed_worker(
  59. local_rank,
  60. main_func,
  61. world_size,
  62. num_gpus_per_machine,
  63. num_machines,
  64. machine_rank,
  65. backend,
  66. dist_url,
  67. args,
  68. )
  69. exit()
  70. launch_by_subprocess(
  71. sys.argv,
  72. world_size,
  73. num_machines,
  74. machine_rank,
  75. num_gpus_per_machine,
  76. dist_url,
  77. args,
  78. )
  79. else:
  80. main_func(*args)
  81. def launch_by_subprocess(
  82. raw_argv,
  83. world_size,
  84. num_machines,
  85. machine_rank,
  86. num_gpus_per_machine,
  87. dist_url,
  88. args,
  89. ):
  90. assert (
  91. world_size > 1
  92. ), "subprocess mode doesn't support single GPU, use spawn mode instead"
  93. if dist_url is None:
  94. # ------------------------hack for multi-machine training -------------------- #
  95. if num_machines > 1:
  96. master_ip = subprocess.check_output(["hostname", "--fqdn"]).decode("utf-8")
  97. master_ip = str(master_ip).strip()
  98. dist_url = "tcp://{}".format(master_ip)
  99. ip_add_file = "./" + args[1].experiment_name + "_ip_add.txt"
  100. if machine_rank == 0:
  101. port = _find_free_port()
  102. with open(ip_add_file, "w") as ip_add:
  103. ip_add.write(dist_url+'\n')
  104. ip_add.write(str(port))
  105. else:
  106. while not os.path.exists(ip_add_file):
  107. time.sleep(0.5)
  108. with open(ip_add_file, "r") as ip_add:
  109. dist_url = ip_add.readline().strip()
  110. port = ip_add.readline()
  111. else:
  112. dist_url = "tcp://127.0.0.1"
  113. port = _find_free_port()
  114. # set PyTorch distributed related environmental variables
  115. current_env = os.environ.copy()
  116. current_env["MASTER_ADDR"] = dist_url
  117. current_env["MASTER_PORT"] = str(port)
  118. current_env["WORLD_SIZE"] = str(world_size)
  119. assert num_gpus_per_machine <= torch.cuda.device_count()
  120. if "OMP_NUM_THREADS" not in os.environ and num_gpus_per_machine > 1:
  121. current_env["OMP_NUM_THREADS"] = str(1)
  122. logger.info(
  123. "\n*****************************************\n"
  124. "Setting OMP_NUM_THREADS environment variable for each process "
  125. "to be {} in default, to avoid your system being overloaded, "
  126. "please further tune the variable for optimal performance in "
  127. "your application as needed. \n"
  128. "*****************************************".format(
  129. current_env["OMP_NUM_THREADS"]
  130. )
  131. )
  132. processes = []
  133. for local_rank in range(0, num_gpus_per_machine):
  134. # each process's rank
  135. dist_rank = machine_rank * num_gpus_per_machine + local_rank
  136. current_env["RANK"] = str(dist_rank)
  137. current_env["LOCAL_RANK"] = str(local_rank)
  138. # spawn the processes
  139. cmd = ["python3", *raw_argv]
  140. process = subprocess.Popen(cmd, env=current_env)
  141. processes.append(process)
  142. for process in processes:
  143. process.wait()
  144. if process.returncode != 0:
  145. raise subprocess.CalledProcessError(returncode=process.returncode, cmd=cmd)
def _distributed_worker(
    local_rank,
    main_func,
    world_size,
    num_gpus_per_machine,
    num_machines,
    machine_rank,
    backend,
    dist_url,
    args,
):
    """
    Entry point of one distributed worker process: initialize the process
    group, bind this process to its GPU, then run `main_func(*args)`.

    Args:
        local_rank (int): rank of this process on its machine; also used as
            the CUDA device index for this worker.
        main_func: callable invoked as `main_func(*args)` after setup.
        world_size (int): total number of workers across all machines.
        num_gpus_per_machine (int): workers (GPUs) per machine.
        num_machines (int): total number of machines.
        machine_rank (int): rank of this machine.
        backend (str): torch.distributed backend name, e.g. "nccl".
        dist_url (str): init_method URL for the process group.
        args (tuple): forwarded to `main_func`; args[1] is expected to be a
            namespace-like object (local_rank / num_machines are written onto
            it, experiment_name is read) -- TODO confirm against callers.
    """
    assert (
        torch.cuda.is_available()
    ), "cuda is not available. Please check your installation."
    configure_nccl()
    # Global rank is the machine offset plus this worker's local slot.
    global_rank = machine_rank * num_gpus_per_machine + local_rank
    logger.info("Rank {} initialization finished.".format(global_rank))
    try:
        dist.init_process_group(
            backend=backend,
            init_method=dist_url,
            world_size=world_size,
            rank=global_rank,
        )
    except Exception:
        # Log the rendezvous URL to make connection failures diagnosable,
        # then re-raise the original error.
        logger.error("Process group URL: {}".format(dist_url))
        raise
    # synchronize is needed here to prevent a possible timeout after calling init_process_group
    # See: https://github.com/facebookresearch/maskrcnn-benchmark/issues/172
    comm.synchronize()
    # Rank 0 removes the rendezvous file written by launch_by_subprocess for
    # multi-machine jobs; it has served its purpose once init succeeds.
    if global_rank == 0 and os.path.exists(
        "./" + args[1].experiment_name + "_ip_add.txt"
    ):
        os.remove("./" + args[1].experiment_name + "_ip_add.txt")
    assert num_gpus_per_machine <= torch.cuda.device_count()
    torch.cuda.set_device(local_rank)
    # Expose this worker's placement to the training code via args[1].
    args[1].local_rank = local_rank
    args[1].num_machines = num_machines
    # Setup the local process group (which contains ranks within the same machine)
    # assert comm._LOCAL_PROCESS_GROUP is None
    # num_machines = world_size // num_gpus_per_machine
    # for i in range(num_machines):
    #     ranks_on_i = list(range(i * num_gpus_per_machine, (i + 1) * num_gpus_per_machine))
    #     pg = dist.new_group(ranks_on_i)
    #     if i == machine_rank:
    #         comm._LOCAL_PROCESS_GROUP = pg
    main_func(*args)