"""
Logging sub-package for Helios.
Holds all of the logger classes.
"""
import enum
import pathlib
import typing
from .base import Logger, get_default_log_name
from .root import RootLogger
from .tensorboard import TensorboardWriter
from .wandb import WandbArgs, WandbWriter
class LoggerType(enum.Enum):
    """
    Identifies each kind of logger that can be held in the active table.

    The string value of each member is also the key under which that logger's state
    is stored by :py:func:`get_logger_state_dicts` and looked up by
    :py:func:`restore_loggers`.
    """

    # Always-present logger, backed by :py:class:`RootLogger`.
    ROOT = "root"
    # Optional writer, backed by :py:class:`TensorboardWriter`.
    TENSORBOARD = "tensorboard"
    # Optional writer, backed by :py:class:`WandbWriter`.
    WANDB = "wandb"
# Module-level table of the loggers that currently exist, keyed by their type. It is
# filled in by ``create_loggers`` and emptied again by ``close_loggers``; every other
# function in this module only reads from it.
_ACTIVE_LOGGERS: dict[LoggerType, Logger] = {}
def create_loggers(
    enable_tensorboard: bool = True,
    capture_warnings: bool = True,
    wandb_args: WandbArgs | None = None,
) -> None:
    """
    Construct the logger instances and add them to the active table.

    The :py:class:`RootLogger` is always created, while additional loggers are only
    created when they are requested through the corresponding argument. If a logger
    has already been created then it is left untouched, making this function safe to
    call multiple times.

    In distributed training, this function should be called *after* the processes have
    been created to ensure each process gets a copy of the loggers.

    Args:
        enable_tensorboard: enable the Tensorboard writer. Defaults to
            ``True``.
        capture_warnings: if ``True``, :py:func:`warnings.warn` output is
            captured by the root logger. Defaults to ``True``.
        wandb_args: the arguments used to construct the :py:class:`WandbWriter`
            (they are unpacked as keyword arguments into its constructor), or
            ``None`` to skip W&B logging. Defaults to ``None``.
    """
    # Each logger is only constructed when it is missing from the table, so repeated
    # calls never replace an existing instance.
    if LoggerType.ROOT not in _ACTIVE_LOGGERS:
        _ACTIVE_LOGGERS[LoggerType.ROOT] = RootLogger(capture_warnings=capture_warnings)

    if enable_tensorboard and LoggerType.TENSORBOARD not in _ACTIVE_LOGGERS:
        _ACTIVE_LOGGERS[LoggerType.TENSORBOARD] = TensorboardWriter()

    if wandb_args is not None and LoggerType.WANDB not in _ACTIVE_LOGGERS:
        _ACTIVE_LOGGERS[LoggerType.WANDB] = WandbWriter(**wandb_args)
def setup_loggers(
    run_name: str,
    log_root: pathlib.Path | None = None,
) -> None:
    """
    Prepare every active logger for a brand new run.

    Each logger in the active table has :py:meth:`~Logger.setup` invoked with
    ``is_resume=False``. Use :py:func:`restore_loggers` instead when the loggers must
    continue from a previous run.

    Args:
        run_name: the name of the current run.
        log_root: root directory under which each logger will create its own
            subfolder. ``None`` disables on-disk output.
    """
    for active in _ACTIVE_LOGGERS.values():
        active.setup(run_name, log_root, is_resume=False)
def restore_loggers(
    run_name: str,
    log_root: pathlib.Path | None = None,
    loggers_state: dict[str, dict[str, typing.Any]] | None = None,
) -> None:
    """
    Restore the active loggers from a previous run.

    For each active logger whose name appears in the ``loggers_state`` dictionary:

    1. Call :py:meth:`~Logger.load_state_dict` so that its previous state is loaded.
    2. Call :py:meth:`~Logger.setup` with ``is_resume=True`` so the logger re-uses the
       original paths.

    An active logger without an entry in the dictionary is set up with
    ``is_resume=False``, i.e. it starts fresh.

    Args:
        run_name: the name of the current run.
        log_root: root directory under which each logger will look for its
            subfolder. ``None`` disables on-disk output.
        loggers_state: mapping of ``{logger_name: state_dict}`` as returned
            by a prior call to :py:func:`get_logger_state_dicts`. ``None``
            is treated the same as an empty mapping.
    """
    saved_states = {} if loggers_state is None else loggers_state

    for kind, active in _ACTIVE_LOGGERS.items():
        # State is stored under the enum's string value, not the enum member itself.
        state_key = kind.value
        has_saved_state = state_key in saved_states
        if has_saved_state:
            # The state must be loaded before setup runs.
            active.load_state_dict(saved_states[state_key])
        active.setup(run_name, log_root, is_resume=has_saved_state)
def get_logger(name: LoggerType) -> Logger:
    """
    Return the active logger identified by *name*.

    Args:
        name: the :py:class:`LoggerType` value identifying the desired logger.

    Returns:
        The requested :py:class:`Logger` instance.

    Raises:
        KeyError: if the requested logger has not been created.
    """
    try:
        return _ACTIVE_LOGGERS[name]
    except KeyError:
        # Re-raise with a message that names the function which populates the table;
        # ``from None`` hides the uninformative bare-key KeyError from the traceback.
        raise KeyError(
            f"error: logger '{name.value}' has not been created. "
            "Did you forget to call create_loggers?"
        ) from None
def get_logger_state_dicts() -> dict[str, dict[str, typing.Any]]:
    """
    Collect the state dictionary of every active logger.

    Returns:
        A dictionary mapping each active logger's type value (e.g. ``"root"``) to the
        result of its :py:meth:`~Logger.state_dict` call.
    """
    states: dict[str, dict[str, typing.Any]] = {}
    for kind, active in _ACTIVE_LOGGERS.items():
        states[kind.value] = active.state_dict()
    return states
def flush_loggers() -> None:
    """Call :py:meth:`~Logger.flush` on every logger in the active table."""
    for active in _ACTIVE_LOGGERS.values():
        active.flush()
def close_loggers() -> None:
    """
    Close every active logger and empty the active table.

    The table is only cleared once every logger's :py:meth:`~Logger.close` has
    returned, after which the loggers have to be created again before use.
    """
    for active in _ACTIVE_LOGGERS.values():
        active.close()

    _ACTIVE_LOGGERS.clear()
def is_root_logger_active() -> bool:
    """
    Check whether the root logger is present in the active table.

    Returns:
        ``True`` if the root logger has been created, ``False`` otherwise.
    """
    root_exists = LoggerType.ROOT in _ACTIVE_LOGGERS
    return root_exists
def get_root_logger() -> RootLogger:
    """
    Return the root logger.

    This is a typed convenience wrapper around :py:func:`get_logger`.

    Returns:
        The active :py:class:`RootLogger` instance.

    Raises:
        KeyError: if the root logger has not been created.
    """
    root = get_logger(LoggerType.ROOT)
    # The cast only narrows the static type; it performs no runtime check.
    return typing.cast(RootLogger, root)
def get_tensorboard_writer() -> TensorboardWriter | None:
    """
    Return the Tensorboard writer, if there is one.

    Unlike :py:func:`get_root_logger`, a missing writer is not an error.

    Returns:
        The active :py:class:`TensorboardWriter`, or ``None`` if it has not been
        created.
    """
    writer = _ACTIVE_LOGGERS.get(LoggerType.TENSORBOARD)
    if writer is None:
        return None
    # The cast only narrows the static type; it performs no runtime check.
    return typing.cast(TensorboardWriter, writer)
def get_wandb_writer() -> WandbWriter | None:
    """
    Return the W&B writer, if there is one.

    Unlike :py:func:`get_root_logger`, a missing writer is not an error.

    Returns:
        The active :py:class:`WandbWriter`, or ``None`` if it has not been created.
    """
    writer = _ACTIVE_LOGGERS.get(LoggerType.WANDB)
    if writer is None:
        return None
    # The cast only narrows the static type; it performs no runtime check.
    return typing.cast(WandbWriter, writer)
# Public API of the logging sub-package: the logger classes and helpers re-exported
# from the sub-modules, followed by the functions defined in this module.
__all__ = [
"Logger",
"LoggerType",
"RootLogger",
"TensorboardWriter",
"WandbArgs",
"WandbWriter",
"get_default_log_name",
"create_loggers",
"setup_loggers",
"restore_loggers",
"get_logger",
"get_logger_state_dicts",
"flush_loggers",
"close_loggers",
"is_root_logger_active",
"get_root_logger",
"get_tensorboard_writer",
"get_wandb_writer",
]