123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596 |
- #!/usr/bin/env python3
- # -*- coding:utf-8 -*-
- # Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
- from loguru import logger
- import inspect
- import os
- import sys
def get_caller_name(depth=0):
    """
    Return the module name (``__name__``) of a calling frame.

    Walking ``f_back`` pointers directly is cheaper than building the full
    list of frames that ``inspect.stack()`` would return.

    Args:
        depth (int): number of extra frames to climb above the immediate
            caller; 0 selects the function that called ``get_caller_name``.
            Default value: 0.

    Returns:
        str: the ``__name__`` global of the selected frame's module.
    """
    caller_frame = inspect.currentframe().f_back  # frame of our direct caller
    hops = range(depth)
    for _ in hops:
        caller_frame = caller_frame.f_back
    module_globals = caller_frame.f_globals
    return module_globals["__name__"]
class StreamToLoguru:
    """
    Stream object that redirects writes to a logger instance.

    Text written by a module whose top-level package is listed in
    ``caller_names`` is forwarded to loguru line by line; all other text is
    passed through unchanged to the real stdout (``sys.__stdout__``).
    """

    def __init__(self, level="INFO", caller_names=("apex", "pycocotools")):
        """
        Args:
            level(str): log level string of loguru. Default value: "INFO".
            caller_names(tuple): top-level package names whose output is
                redirected to loguru. Default value: (apex, pycocotools).
        """
        self.level = level
        self.linebuf = ""  # kept for backward compatibility; not used here
        self.caller_names = caller_names

    def write(self, buf):
        """
        Route ``buf`` to loguru or to the real stdout depending on the caller.

        Args:
            buf(str): text handed to ``sys.stdout.write``/``sys.stderr.write``.

        Returns:
            int: number of characters consumed, as the text-stream ``write``
            protocol expects.
        """
        full_name = get_caller_name(depth=1)
        # Only the top-level package matters: "apex.amp.handle" -> "apex".
        module_name = full_name.partition(".")[0]
        if module_name in self.caller_names:
            for line in buf.rstrip().splitlines():
                # depth=2 attributes the record to a frame above write()
                # rather than to this method (use caller level log).
                logger.opt(depth=2).log(self.level, line.rstrip())
        else:
            sys.__stdout__.write(buf)
        return len(buf)

    def flush(self):
        """Flush the real stdout, which receives all pass-through writes."""
        return sys.__stdout__.flush()

    def isatty(self):
        """
        Report whether the real stdout is a terminal.

        Without this, code probing ``sys.stdout.isatty()`` or
        ``sys.stderr.isatty()`` would raise AttributeError once this object
        is installed.
        """
        return sys.__stdout__.isatty()

    def fileno(self):
        """Return the file descriptor of the real stdout for callers needing one."""
        return sys.__stdout__.fileno()
def redirect_sys_output(log_level="INFO"):
    """
    Install a single StreamToLoguru instance as both stderr and stdout.

    Args:
        log_level(str): loguru level used for redirected lines.
            Default value: "INFO".
    """
    shared_stream = StreamToLoguru(log_level)
    # Chained assignment binds left to right: sys.stderr first, then sys.stdout.
    sys.stderr = sys.stdout = shared_stream
def setup_logger(save_dir, distributed_rank=0, filename="log.txt", mode="a"):
    """setup logger for training and testing.

    Args:
        save_dir(str): location to save log file
        distributed_rank(int): device rank when multi-gpu environment;
            only rank 0 registers loguru sinks.
        filename (string): log save name.
        mode(str): log file write mode. "a" appends to an existing file,
            "o" overwrites it (the old file is deleted first). Any other
            value behaves like "a". Default is "a".

    Return:
        None. The global loguru ``logger`` is configured in place, and
        ``sys.stdout``/``sys.stderr`` are redirected in every rank.
    """
    loguru_format = (
        "<green>{time:YYYY-MM-DD HH:mm:ss}</green> | "
        "<level>{level: <8}</level> | "
        "<cyan>{name}</cyan>:<cyan>{line}</cyan> - <level>{message}</level>"
    )

    # Drop every previously registered sink before adding ours.
    logger.remove()
    save_file = os.path.join(save_dir, filename)

    # only keep logger in rank0 process
    if distributed_rank == 0:
        if mode == "o":
            # Only rank 0 deletes the old file. Otherwise another rank that
            # runs this later could delete the file after rank 0 has
            # already started logging to it.
            try:
                os.remove(save_file)
            except FileNotFoundError:
                pass  # nothing to overwrite
        logger.add(
            sys.stderr,
            format=loguru_format,
            level="INFO",
            enqueue=True,
        )
        logger.add(save_file)

    # redirect stdout/stderr to loguru
    redirect_sys_output("INFO")
|