| # Licensed to the Apache Software Foundation (ASF) under one or more |
| # contributor license agreements. See the NOTICE file distributed with |
| # this work for additional information regarding copyright ownership. |
| # The ASF licenses this file to You under the Apache License, Version 2.0 |
| # (the "License"); you may not use this file except in compliance with |
| # the License. You may obtain a copy of the License at |
| # |
| # http://www.apache.org/licenses/LICENSE-2.0 |
| # |
| # Unless required by applicable law or agreed to in writing, software |
| # distributed under the License is distributed on an "AS IS" BASIS, |
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
| # See the License for the specific language governing permissions and |
| # limitations under the License. |
| |
| import atexit |
| import logging |
| import os |
| import sys |
| import time |
| from collections import Counter |
| from functools import lru_cache |
| from logging.handlers import RotatingFileHandler |
| |
| from rich.logging import RichHandler |
| |
| __all__ = [ |
| "init_logger", |
| "fetch_log_level", |
| "log_first_n_times", |
| "log_every_n_times", |
| "log_every_n_secs", |
| ] |
| |
| LOG_BUFFER_SIZE_ENV: str = "LOG_BUFFER_SIZE" |
| DEFAULT_BUFFER_SIZE: int = 1024 * 1024 # 1MB |
| |
| |
| @lru_cache() # avoid creating multiple handlers when calling init_logger() |
| def init_logger( |
| log_output=None, |
| log_level=logging.INFO, |
| rank=0, |
| *, |
| logger_name="client", # users should set logger name for modules |
| propagate_logs: bool = False, |
| stdout_logging: bool = True, |
| max_log_size=50 * 1024 * 1024, # 50 MB |
| backup_logs=5, |
| ): |
| """ |
| Initialize the logger and set its verbosity level to "DEBUG". |
| |
| Args: |
| log_output (str): a file name or a directory to save log. If None, will not save a log file. |
| If it ends with ".txt" or ".log", assumed to be a file name. |
| Otherwise, logs will be saved to `log_output/log.txt`. |
| logger_name (str): the root module name of this logger |
| propagate_logs (bool): whether to propagate logs to the parent logger. |
| stdout_logging (bool): whether to configure logging to stdout. |
| |
| Returns: |
| logging.Logger: a logger |
| """ |
| log_instance = logging.getLogger(logger_name) |
| log_instance.setLevel(log_level) |
| log_instance.propagate = propagate_logs |
| |
| if log_instance.hasHandlers(): |
| log_instance.handlers.clear() |
| |
| # stdout logging: master only |
| if stdout_logging and rank == 0: |
| rich_handler = RichHandler(log_level) |
| rich_handler.setFormatter(logging.Formatter("%(name)s: %(message)s")) |
| log_instance.addHandler(rich_handler) |
| |
| # file logging: all workers |
| if log_output is not None: |
| if log_output.endswith(".txt") or log_output.endswith(".log"): |
| log_filename = log_output |
| else: |
| log_filename = os.path.join(log_output, "log.txt") |
| |
| if rank > 0: |
| log_filename = f"{log_filename}.rank{rank}" |
| |
| os.makedirs(os.path.dirname(log_filename), exist_ok=True) |
| file_handler = RotatingFileHandler( |
| log_filename, |
| maxBytes=max_log_size, |
| backupCount=backup_logs, |
| ) |
| file_handler.setLevel(logging.DEBUG) |
| formatter = logging.Formatter( |
| "[%(asctime)s] %(levelname)s [%(name)s:%(filename)s:%(lineno)d] %(message)s", |
| datefmt="%m/%d/%y %H:%M:%S", |
| ) |
| file_handler.setFormatter(formatter) |
| log_instance.addHandler(file_handler) |
| return log_instance |
| |
| |
| # Cache the opened file object, so that different calls to `initialize_logger` |
| # with the same file name can safely write to the same file. |
| @lru_cache(maxsize=None) |
| def _cached_log_file(filename): |
| """Cache the opened file object""" |
| # Use 1K buffer if writing to cloud storage |
| with open(filename, "a", buffering=_determine_buffer_size(filename), encoding="utf-8") as file_io: |
| atexit.register(file_io.close) |
| return file_io |
| |
| |
| def _determine_buffer_size(filename: str) -> int: |
| """Determine the buffer size for the log stream""" |
| if "://" not in filename: |
| # Local file, no extra caching is necessary |
| return -1 |
| # Remote file requires a larger cache to avoid many smalls writes. |
| if LOG_BUFFER_SIZE_ENV in os.environ: |
| return int(os.environ[LOG_BUFFER_SIZE_ENV]) |
| return DEFAULT_BUFFER_SIZE |
| |
| |
| def _identify_caller(): |
| """ |
| Returns: |
| str: module name of the caller |
| tuple: a hashable key to be used to identify different callers |
| """ |
| frame = sys._getframe(2) # pylint: disable=protected-access |
| while frame: |
| code = frame.f_code |
| if os.path.join("utils", "logger.") not in code.co_filename: |
| module_name = frame.f_globals["__name__"] |
| if module_name == "__main__": |
| module_name = "core" |
| return module_name, (code.co_filename, frame.f_lineno, code.co_name) |
| frame = frame.f_back |
| return None, None |
| |
| |
| LOG_COUNTER = Counter() |
| LOG_TIMERS = {} |
| |
| |
| def log_first_n_times(level, message, n=1, *, logger_name=None, key="caller"): |
| """ |
| Log only for the first n times. |
| |
| Args: |
| logger_name (str): name of the logger to use. Will use the caller's module by default. |
| key (str or tuple[str]): the string(s) can be one of "callers" or |
| "message", which defines how to identify duplicated logs. |
| For example, if called with `n=1, key="caller"`, this function |
| will only log the first call from the same caller, regardless of |
| the message content. |
| If called with `n=1, key="message"`, this function will log the |
| same content only once, even if they are called from different places. |
| If called with `n=1, key=("caller", "message")`, this function |
| will not log only if the same caller has logged the same message before. |
| """ |
| if isinstance(key, str): |
| key = (key,) |
| assert len(key) > 0 |
| |
| caller_module, caller_key = _identify_caller() |
| hash_key = () |
| if "caller" in key: |
| hash_key = hash_key + caller_key |
| if "message" in key: |
| hash_key = hash_key + (message,) |
| |
| LOG_COUNTER[hash_key] += 1 |
| if LOG_COUNTER[hash_key] <= n: |
| logging.getLogger(logger_name or caller_module).log(level, message) |
| |
| |
| def log_every_n_times(level, message, n=1, *, logger_name=None): |
| caller_module, key = _identify_caller() |
| LOG_COUNTER[key] += 1 |
| if n == 1 or LOG_COUNTER[key] % n == 1: |
| logging.getLogger(logger_name or caller_module).log(level, message) |
| |
| |
| def log_every_n_secs(level, message, n=1, *, logger_name=None): |
| caller_module, key = _identify_caller() |
| last_logged = LOG_TIMERS.get(key, None) |
| current_time = time.time() |
| if last_logged is None or current_time - last_logged >= n: |
| logging.getLogger(logger_name or caller_module).log(level, message) |
| LOG_TIMERS[key] = current_time |
| |
| |
| def fetch_log_level(level_name: str): |
| """Fetch the logging level by its name""" |
| level = getattr(logging, level_name.upper(), None) |
| if not isinstance(level, int): |
| raise ValueError(f"Invalid log level: {level_name}") |
| return level |
| |
| log = init_logger(log_output="logs/output.log", log_level=logging.INFO) |