# train.py
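"""Launch script for YOLOX training.

Parses command-line arguments, builds an experiment description via
yolox.exp.get_exp, and dispatches (optionally distributed) training
through yolox.core.launch.
"""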

import argparse
import random
import warnings

from loguru import logger

import torch
import torch.backends.cudnn as cudnn

from yolox.core import Trainer, launch
from yolox.exp import get_exp

def make_parser():
    parser = argparse.ArgumentParser("YOLOX train parser")
    parser.add_argument("-expn", "--experiment-name", type=str, default=None)
    parser.add_argument("-n", "--name", type=str, default=None, help="model name")

    # distributed
    parser.add_argument(
        "--dist-backend", default="nccl", type=str, help="distributed backend"
    )
    parser.add_argument(
        "--dist-url",
        default=None,
        type=str,
        help="url used to set up distributed training",
    )
    parser.add_argument("-b", "--batch-size", type=int, default=64, help="batch size")
    parser.add_argument(
        "-d", "--devices", default=None, type=int, help="number of devices for training"
    )
    parser.add_argument(
        "--local_rank", default=0, type=int, help="local rank for dist training"
    )
    parser.add_argument(
        "-f",
        "--exp_file",
        default=None,
        type=str,
        help="please input your experiment description file",
    )
    parser.add_argument(
        "--resume", default=False, action="store_true", help="resume training"
    )
    parser.add_argument("-c", "--ckpt", default=None, type=str, help="checkpoint file")
    parser.add_argument(
        "-e",
        "--start_epoch",
        default=None,
        type=int,
        help="resume training start epoch",
    )
    parser.add_argument(
        "--num_machines", default=1, type=int, help="number of nodes for training"
    )
    parser.add_argument(
        "--machine_rank", default=0, type=int, help="node rank for multi-node training"
    )
    parser.add_argument(
        "--fp16",
        dest="fp16",
        # store_true with default=True would make the flag a no-op, so default to False
        default=False,
        action="store_true",
        help="Adopt mixed-precision training.",
    )
    parser.add_argument(
        "-o",
        "--occupy",
        dest="occupy",
        default=False,
        action="store_true",
        help="occupy GPU memory first for training.",
    )
    parser.add_argument(
        "opts",
        help="Modify config options using the command-line",
        default=None,
        nargs=argparse.REMAINDER,
    )
    return parser
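
# Quick sanity check of the parser (illustrative; uses only flags defined above):
#   args = make_parser().parse_args(["-n", "yolox-s", "-b", "8", "--fp16"])
#   assert args.batch_size == 8 and args.fp16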


@logger.catch
def main(exp, args):
    if exp.seed is not None:
        random.seed(exp.seed)
        torch.manual_seed(exp.seed)
        cudnn.deterministic = True
        warnings.warn(
            "You have chosen to seed training. This will turn on the CUDNN deterministic setting, "
            "which can slow down your training considerably! You may see unexpected behavior "
            "when restarting from checkpoints."
        )

    # let cudnn benchmark and cache the fastest convolution algorithms
    # for the (fixed) input sizes used during training
    cudnn.benchmark = True

    trainer = Trainer(exp, args)
    trainer.train()


if __name__ == "__main__":
    args = make_parser().parse_args()
    exp = get_exp(args.exp_file, args.name)
    exp.merge(args.opts)

    if not args.experiment_name:
        args.experiment_name = exp.exp_name

    num_gpu = torch.cuda.device_count() if args.devices is None else args.devices
    assert num_gpu <= torch.cuda.device_count(), "requested more GPUs than are available"

    launch(
        main,
        num_gpu,
        args.num_machines,
        args.machine_rank,
        backend=args.dist_backend,
        dist_url=args.dist_url,
        args=(exp, args),
    )
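
# Example invocation (a sketch; the experiment file path and flag values are illustrative):
#   python train.py -f exps/default/yolox_s.py -d 8 -b 64 --fp16 -o
# Trailing key/value pairs are merged into the experiment config through the
# "opts" REMAINDER argument, e.g.:
#   python train.py -f exps/default/yolox_s.py -d 8 -b 64 max_epoch 300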