run_test.py 3.2 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112
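"""L0 test runner.

How to run this script:

1. Run the default set of tests (run_optimizers, run_fused_layer_norm, run_mlp
   and run_transformer): `python /path/to/apex/tests/L0/run_test.py`.
   The remaining suites (run_amp, run_deprecated, run_fp16util) only run when
   named with `--include`.
   To also get JUnit XML reports (requires `xmlrunner`), pass `--xml-dir DIR`,
   i.e. `python /path/to/apex/tests/L0/run_test.py --xml-dir DIR`; one report
   directory named `TEST_<test dir>_<yymmdd>` is created inside `DIR` for each
   test directory that is run. The deprecated `--xml-report` flag does the same
   with `DIR` set to `/path/to/apex/tests/L0`.
2. Run one of the tests (e.g. fused layer norm):
   `python /path/to/apex/tests/L0/run_test.py --include run_fused_layer_norm`
3. Run two or more of the tests (e.g. optimizers and fused layer norm):
   `python /path/to/apex/tests/L0/run_test.py --include run_optimizers run_fused_layer_norm`
"""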
  11. import argparse
  12. import os
  13. import unittest
  14. import sys
  15. TEST_ROOT = os.path.dirname(os.path.abspath(__file__))
  16. TEST_DIRS = [
  17. "run_amp",
  18. "run_deprecated",
  19. "run_fp16util",
  20. "run_optimizers",
  21. "run_fused_layer_norm",
  22. "run_mlp",
  23. "run_transformer",
  24. ]
  25. DEFAULT_TEST_DIRS = [
  26. "run_optimizers",
  27. "run_fused_layer_norm",
  28. "run_mlp",
  29. "run_transformer",
  30. ]
  31. def parse_args():
  32. parser = argparse.ArgumentParser(
  33. description="L0 test runner",
  34. formatter_class=argparse.ArgumentDefaultsHelpFormatter,
  35. )
  36. parser.add_argument(
  37. "--include",
  38. nargs="+",
  39. choices=TEST_DIRS,
  40. default=DEFAULT_TEST_DIRS,
  41. help="select a set of tests to run (defaults to ALL tests).",
  42. )
  43. parser.add_argument(
  44. "--xml-report",
  45. default=None,
  46. action="store_true",
  47. help="[deprecated] pass this argument to get a junit xml report. Use `--xml-dir`. (requires `xmlrunner`)",
  48. )
  49. parser.add_argument(
  50. "--xml-dir",
  51. default=None,
  52. type=str,
  53. help="Directory to save junit test reports. (requires `xmlrunner`)",
  54. )
  55. args, _ = parser.parse_known_args()
  56. return args
  57. def main(args: argparse.Namespace) -> None:
  58. test_runner_kwargs = {"verbosity": 2}
  59. Runner = unittest.TextTestRunner
  60. xml_dir = None
  61. if (args.xml_report is not None) or (args.xml_dir is not None):
  62. if args.xml_report is not None:
  63. import warnings
  64. warnings.warn("The option of `--xml-report` is deprecated", FutureWarning)
  65. import xmlrunner
  66. from datetime import date # NOQA
  67. Runner = xmlrunner.XMLTestRunner
  68. if args.xml_report:
  69. xml_dir = os.path.abspath(os.path.dirname(__file__))
  70. else:
  71. xml_dir = os.path.abspath(args.xml_dir)
  72. if not os.path.exists(xml_dir):
  73. os.makedirs(xml_dir)
  74. errcode = 0
  75. for test_dir in args.include:
  76. if xml_dir is not None:
  77. xml_output = os.path.join(
  78. xml_dir,
  79. f"""TEST_{test_dir}_{date.today().strftime("%y%m%d")}""",
  80. )
  81. if not os.path.exists(xml_output):
  82. os.makedirs(xml_output)
  83. test_runner_kwargs["output"] = xml_output
  84. runner = Runner(**test_runner_kwargs)
  85. test_dir = os.path.join(TEST_ROOT, test_dir)
  86. suite = unittest.TestLoader().discover(test_dir)
  87. print("\nExecuting tests from " + test_dir)
  88. result = runner.run(suite)
  89. if not result.wasSuccessful():
  90. errcode = 1
  91. sys.exit(errcode)
  92. if __name__ == '__main__':
  93. args = parse_args()
  94. main(args)