__init__.py

import logging
import warnings

# May help avoid undefined symbol errors https://pytorch.org/cppdocs/notes/faq.html#undefined-symbol-errors-from-pytorch-aten
import torch

__all__ = ["amp", "fp16_utils", "optimizers", "normalization", "transformer"]

if torch.distributed.is_available():
    from . import parallel
    __all__.append("parallel")

from . import amp
from . import fp16_utils

# For optimizers and normalization there is no Python fallback, so the absence of the
# CUDA backend is a hard error. The errors from importing fused_adam_cuda or
# fused_layer_norm_cuda should surface here, at import time, rather than lazily:
# if someone installed with --cpp_ext and --cuda_ext, they expect those backends to be
# available, and if for some reason they actually aren't (for example because they were
# built improperly in a way that isn't revealed until load time), the error message
# should be timely and visible.
from . import optimizers
from . import normalization
from . import transformer


# Logging utilities for apex.transformer module
class RankInfoFormatter(logging.Formatter):

    def format(self, record):
        from apex.transformer.parallel_state import get_rank_info
        record.rank_info = get_rank_info()
        return super().format(record)


_library_root_logger = logging.getLogger(__name__)
handler = logging.StreamHandler()
handler.setFormatter(
    RankInfoFormatter(
        "%(asctime)s - PID:%(process)d - rank:%(rank_info)s - %(filename)s:%(lineno)d - %(levelname)s - %(message)s",
        "%y-%m-%d %H:%M:%S",
    )
)
_library_root_logger.addHandler(handler)
_library_root_logger.propagate = False
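

# Usage sketch (illustrative only; the child-logger name and the log call are assumptions,
# and RankInfoFormatter resolves rank info via get_rank_info(), which assumes
# apex.transformer's parallel state has been set up): code under apex.transformer can log
# through a child of this library root logger and pick up the handler configured above.
#
#     logger = logging.getLogger("apex.transformer.pipeline_parallel")
#     logger.setLevel(logging.INFO)
#     logger.info("initialized model parallel groups")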


def check_cudnn_version_and_warn(global_option: str, required_cudnn_version: int) -> bool:
    """Warn and return ``False`` if cuDNN is missing or older than ``required_cudnn_version``."""
    cudnn_available = torch.backends.cudnn.is_available()
    cudnn_version = torch.backends.cudnn.version() if cudnn_available else None
    if not (cudnn_available and (cudnn_version >= required_cudnn_version)):
        found = "cuDNN is not available" if not cudnn_available else f"the found cuDNN version is {cudnn_version}"
        warnings.warn(
            f"`{global_option}` depends on cuDNN {required_cudnn_version} or later, but {found}"
        )
        return False
    return True
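

# Usage sketch (illustrative only; the option name, version number, and helper functions
# below are placeholders, not part of this module): a feature requiring cuDNN >= 8.0
# (reported by torch as 8000) could gate itself on this check and fall back otherwise.
#
#     if check_cudnn_version_and_warn("--some_cudnn_extension", 8000):
#         run_cudnn_backed_path()
#     else:
#         run_plain_pytorch_fallback()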


class DeprecatedFeatureWarning(FutureWarning):
    pass


def deprecated_warning(msg: str) -> None:
    # Emit the warning only once in distributed runs: any process warns if
    # torch.distributed is unavailable or uninitialized; otherwise only global rank 0 warns.
    if (
        not torch.distributed.is_available()
        or not torch.distributed.is_initialized()
        or (torch.distributed.is_initialized() and torch.distributed.get_rank() == 0)
    ):
        warnings.warn(msg, DeprecatedFeatureWarning)
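

# Usage sketch (illustrative only; the function and message are placeholders): a deprecated
# code path can call deprecated_warning() once on entry, and in distributed runs only
# global rank 0 will emit the DeprecatedFeatureWarning.
#
#     def old_entry_point(*args, **kwargs):
#         deprecated_warning(
#             "`old_entry_point` is deprecated and will be removed in a future release."
#         )
#         ...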