# Source code for homura.liblog

""" Logging tools, heavily inspired by (and borrowing from) Optuna and Transformers.
"""
import io
import logging
import sys
import threading
import warnings
from typing import Optional, TextIO, Union

import tqdm as _tqdm
from tqdm.contrib import DummyTqdmFile

try:
    import colorlog

    _has_colorlog = True
except ImportError:
    _has_colorlog = False

# private APIs
# Lower-case level names accepted by the setters below, mapped to stdlib numeric levels.
_LOG_LEVEL = dict(
    debug=logging.DEBUG,
    info=logging.INFO,
    warning=logging.WARNING,
    error=logging.ERROR,
    critical=logging.CRITICAL,
)

# Handler created by _configure_root_logger; None until it runs (and again after a reset).
_default_handler = None
# The streams present at import time, kept so they can be used even after
# set_tqdm_stdout_stderr swaps sys.stdout / sys.stderr for wrappers.
_original_stds = (sys.stdout, sys.stderr)
# Serializes creation and removal of _default_handler.
_lock = threading.Lock()


def _name() -> str:
    """Return the top-level package name of this module (text before the first dot)."""
    head, _sep, _rest = __name__.partition('.')
    return head


def _create_default_formatter() -> logging.Formatter:
    """Build the formatter used by the library's default handlers.

    Layout: ``[<logger name>|<YYYY-mm-dd HH:MM:SS>|<LEVEL>] <message>``.
    When ``colorlog`` is importable a ``colorlog.ColoredFormatter`` (with a leading
    ``%(log_color)s``) is returned; otherwise a plain ``logging.Formatter`` with the same
    layout is returned instead of failing with ``NameError``.

    :return: a ``logging.Formatter`` (or subclass) instance.
    """
    datefmt = "%Y-%m-%d %H:%M:%S"
    layout = "[%(name)s|%(asctime)s|%(levelname)s] %(message)s"
    try:
        import colorlog
    except ImportError:
        # colorlog is an optional dependency; degrade to an uncolored formatter.
        return logging.Formatter(layout, datefmt=datefmt)
    return colorlog.ColoredFormatter('%(log_color)s' + layout, datefmt=datefmt)


def _get_root_logger() -> logging.Logger:
    """Return the logger named after the top-level package (the library's root logger)."""
    package_name = _name()
    return logging.getLogger(package_name)


def _configure_root_logger() -> None:
    """Lazily create the default handler and, if appropriate, install it.

    Runs at most once per reset (guarded by ``_lock`` and ``_default_handler``):

    * creates ``_default_handler`` (a ``StreamHandler`` with the default formatter);
    * if the Python root logger already has handlers, stops there, leaving the
      library logger untouched so the user's configuration is respected;
    * otherwise attaches the handler to the library root logger, sets it to INFO and
      turns off propagation.
    """
    global _default_handler
    with _lock:
        if _default_handler is None:
            # Assigned before the check below: later calls must see it as "already done"
            # even when we decide not to attach it.
            _default_handler = logging.StreamHandler()
            _default_handler.setFormatter(_create_default_formatter())
            if not logging.getLogger().handlers:
                lib_root = _get_root_logger()
                lib_root.addHandler(_default_handler)
                lib_root.setLevel(logging.INFO)
                lib_root.propagate = False


def _reset_root_logger() -> None:
    """Undo ``_configure_root_logger``.

    Detaches the default handler from the library root logger, sets that logger's level
    back to ``NOTSET`` and forgets the handler. The ``propagate`` flag is not restored.
    Does nothing if no default handler exists.
    """
    global _default_handler
    with _lock:
        if _default_handler is not None:
            lib_root = _get_root_logger()
            lib_root.removeHandler(_default_handler)
            lib_root.setLevel(logging.NOTSET)
            _default_handler = None


# public APIs

def get_logger(name: str = None) -> logging.Logger:
    """Return a logger after making sure the library root logger is configured.

    :param name: logger name; if omitted (``None``), the top-level package name is used.
    :return: ``logging.getLogger`` for the resolved name.
    """
    resolved = _name() if name is None else name
    _configure_root_logger()
    return logging.getLogger(resolved)
def get_verb_level() -> int:
    """Return the effective numeric level of the library root logger.

    :return: the value of ``getEffectiveLevel()`` on the library root logger.
    """
    _configure_root_logger()
    lib_root = _get_root_logger()
    return lib_root.getEffectiveLevel()
def set_verb_level(level: Union[str, int]) -> None:
    """Set the level of the library root logger.

    :param level: a numeric logging level, or one of ``"debug"``, ``"info"``,
        ``"warning"``, ``"error"``, ``"critical"`` (matched case-insensitively).
    :raises KeyError: if ``level`` is an unknown level name.
    """
    if isinstance(level, str):
        # Accept "INFO" as well as "info".
        level = _LOG_LEVEL[level.lower()]
    _configure_root_logger()
    _get_root_logger().setLevel(level)
def enable_default_handler() -> None:
    """Attach the library's default stream handler to the library root logger.

    Useful after ``disable_default_handler()``, or when the handler was created but not
    attached because the Python root logger already had handlers.

    :raises RuntimeWarning: if no default handler exists after configuration.
    """
    _configure_root_logger()
    if _default_handler is None:
        # Exception type kept as-is for backward compatibility; now with a message.
        raise RuntimeWarning("the default log handler is not configured")
    _get_root_logger().addHandler(_default_handler)
def disable_default_handler() -> None:
    """Detach the library's default stream handler from the library root logger.

    :raises RuntimeWarning: if no default handler exists after configuration.
    """
    _configure_root_logger()
    if _default_handler is None:
        # Exception type kept as-is for backward compatibility; now with a message.
        raise RuntimeWarning("the default log handler is not configured")
    _get_root_logger().removeHandler(_default_handler)
def enable_propagation() -> None:
    """Let records from the library root logger propagate to ancestor loggers' handlers."""
    _configure_root_logger()
    lib_root = _get_root_logger()
    lib_root.propagate = True
def disable_propagation() -> None:
    """Stop records from the library root logger reaching ancestor loggers' handlers."""
    _configure_root_logger()
    lib_root = _get_root_logger()
    lib_root.propagate = False
def set_file_handler(log_file: Union[str, TextIO],
                     level: Union[str, int] = logging.DEBUG,
                     formatter: Optional[logging.Formatter] = None
                     ) -> None:
    """Add a handler that writes library logs to a file.

    :param log_file: a path (anything ``logging.FileHandler`` accepts), or an already
        open text stream (any object with a ``write`` method), which is wrapped in a
        ``logging.StreamHandler`` because ``FileHandler`` cannot take an open stream.
    :param level: numeric level or level name (case-insensitive); defaults to DEBUG.
    :param formatter: formatter for this handler; defaults to the library's default one.
    :raises KeyError: if ``level`` is an unknown level name.
    """
    _configure_root_logger()
    if hasattr(log_file, "write"):
        fh: logging.Handler = logging.StreamHandler(log_file)
    else:
        fh = logging.FileHandler(log_file)
    if isinstance(level, str):
        level = _LOG_LEVEL[level.lower()]
    fh.setLevel(level)
    if formatter is None:
        # NOTE(review): the default formatter may be colorlog-based, in which case color
        # escape codes can end up in the file — consider passing a plain formatter.
        formatter = _create_default_formatter()
    fh.setFormatter(formatter)
    _get_root_logger().addHandler(fh)
# internal APIs
def set_tqdm_handler(level: Union[str, int] = logging.INFO,
                     formatter: Optional[logging.Formatter] = None
                     ) -> None:
    """ An alternative handler to avoid disturbing tqdm

    Installs a handler that writes each formatted record via ``tqdm.tqdm.write`` so log
    lines do not break active progress bars, and removes the library's default handler
    (if one exists) to avoid duplicated output.

    :param level: numeric level or level name (case-insensitive); defaults to INFO.
    :param formatter: formatter for the handler; defaults to the library's default one.
    :raises KeyError: if ``level`` is an unknown level name.
    """
    import tqdm as tqdm_lib

    class TQDMHandler(logging.StreamHandler):
        """Handler that routes formatted records through ``tqdm.tqdm.write``."""

        def emit(self, record):
            # Mirror logging.StreamHandler.emit: never let a logging failure propagate
            # into the caller; report it via handleError instead.
            try:
                msg = self.format(record)
                tqdm_lib.tqdm.write(msg)
            except RecursionError:
                raise
            except Exception:
                self.handleError(record)

    _configure_root_logger()
    th = TQDMHandler()
    if isinstance(level, str):
        level = _LOG_LEVEL[level.lower()]
    th.setLevel(level)
    if _default_handler is not None:
        # to avoid multiple logs!
        _get_root_logger().removeHandler(_default_handler)
    if formatter is None:
        formatter = _create_default_formatter()
    th.setFormatter(formatter)
    _get_root_logger().addHandler(th)
# tqdm
def set_tqdm_stdout_stderr():
    """Replace ``sys.stdout``/``sys.stderr`` with ``DummyTqdmFile`` wrappers.

    The wrappers are built around the streams captured at import time
    (``_original_stds``). Wrapping only happens when ``sys.stdout`` is an
    ``io.TextIOWrapper``; if it is already a ``DummyTqdmFile`` nothing is done, and
    for any other type a ``UserWarning`` is issued and the streams are left alone.
    """
    # https://github.com/tqdm/tqdm/blob/master/examples/redirect_print.py
    # Some libraries override sys.stdout, which causes OSError: [Errno 9] Bad file descriptor.
    # Hence only a genuine TextIOWrapper is replaced.
    current = sys.stdout
    if isinstance(current, io.TextIOWrapper):
        new_out, new_err = (DummyTqdmFile(stream) for stream in _original_stds)
        sys.stdout, sys.stderr = new_out, new_err
        return
    if isinstance(current, DummyTqdmFile):
        return  # already wrapped
    warnings.warn(f"sys.stdout is unexpected type: {type(current)}.\n"
                  f"If you use wandb, set WANDB_CONSOLE=off to avoid tqdm-related problems.", UserWarning)
def tqdm(*args, **kwargs):
    """Create a ``tqdm.tqdm`` progress bar with library defaults.

    * ``file`` defaults to the stdout captured at import time (``_original_stds[0]``),
      so the bar still reaches the terminal after stdout has been wrapped.
    * ``dynamic_ncols`` defaults to ``True`` unless ``ncols`` or ``dynamic_ncols`` is given.

    All arguments are forwarded to ``tqdm.tqdm``.
    """
    # https://github.com/tqdm/tqdm/blob/master/examples/redirect_print.py
    if kwargs.get("file") is None:
        kwargs.update(file=_original_stds[0])
    # tqdm seems to prioritize dynamic_ncols over ncols
    no_width_hint = all(kwargs.get(opt) is None for opt in ("ncols", "dynamic_ncols"))
    if no_width_hint:
        kwargs.update(dynamic_ncols=True)
    return _tqdm.tqdm(*args, **kwargs)
# log once
# Keys already logged by log_once; grows with every distinct key and is never cleared.
_LOG_CACHE = set()


def log_once(logger, message: str, key: Optional[str] = None) -> None:
    """ Log message only once.

    :param logger: callable that receives the message, e.g., `print`, `logger.info`
    :param message: message passed to ``logger`` the first time ``key`` is seen
    :param key: deduplication key; if `key=None`, `message` is used as `key`.
    :return: None
    """
    if key is None:
        key = message
    if key in _LOG_CACHE:
        return
    logger(message)
    # Recorded only after a successful call, so a failing logger can retry later.
    _LOG_CACHE.add(key)