Source code for helios.core.loggers.wandb
import pathlib
import typing
try:
import wandb as wandb_sdk
_WANDB_AVAILABLE = True
except ImportError:
_WANDB_AVAILABLE = False
from ..distributed import get_global_rank
from .base import Logger
[docs]
class WandbArgs(typing.TypedDict, total=False):
"""
Arguments for constructing a :py:class:`WandbWriter`.
``project`` is the only required key; all others are optional.
Args:
project: W&B project name.
name: display name for the run shown in the W&B UI. If not provided, defaults to
``run_name`` provided by :py:meth:`~WandbWriter.setup`.
config: hyper-parameter dictionary to associate with the run.
extra_args: additional keyword arguments forwarded verbatim to
:func:`wandb.init`.
"""
project: typing.Required[str]
name: str
config: dict[str, typing.Any]
extra_args: dict[str, typing.Any]
[docs]
class WandbWriter(Logger):
    """
    Wrapper for the Weights & Biases ``wandb.init`` run.

    Data for the logger will be placed under ``log_root/wandb``. When
    resuming, the original run ID is restored and new data is appended to it.

    Requires the ``wandb`` package. Install it with::

        pip install wandb

    Args:
        project: W&B project name.
        name: (optional) display name for the run.
        config: (optional) hyper-parameter configuration dictionary to associate
            with the run.
        extra_args: (optional) extra keyword arguments forwarded to
            :func:`wandb.init`. Defaults to ``{}``. If it contains a key the
            writer sets itself (``project``, ``name``, ``config``, ``dir``,
            ``id`` or ``resume``), the writer's value takes precedence unless
            that value is ``None``.
    """

    def __init__(
        self,
        project: str,
        name: str | None = None,
        config: dict[str, typing.Any] | None = None,
        extra_args: dict[str, typing.Any] | None = None,
    ) -> None:
        """
        Create the W&B writer.

        Args:
            project: W&B project name.
            name: (optional) display name for the run.
            config: (optional) hyper-parameter configuration dictionary.
            extra_args: (optional) extra keyword arguments forwarded to
                :func:`wandb.init`. Defaults to ``{}``.

        Raises:
            ImportError: if ``wandb`` is not installed.
        """
        if not _WANDB_AVAILABLE:
            raise ImportError(
                "wandb is required to use the WandbWriter. "
                "Install it with: pip install wandb"
            )

        self._project = project
        self._name = name
        self._config = config
        self._extra_args = extra_args if extra_args is not None else {}
        # Captured once so setup() can skip W&B entirely on non-zero ranks.
        self._rank = get_global_rank()
        self._run: typing.Any = None
        # ID of the run started by setup(); persisted through state_dict().
        self._run_id: str | None = None
        # ID restored by load_state_dict(); used by setup() when resuming.
        self._saved_run_id: str | None = None

    @staticmethod
    def _merge_init_kwargs(
        managed: dict[str, typing.Any], extra: dict[str, typing.Any]
    ) -> dict[str, typing.Any]:
        """
        Combine writer-managed :func:`wandb.init` arguments with user extras.

        Passing a keyword both explicitly and through ``**extra`` raises
        ``TypeError``, so both sources are merged into a single dictionary.
        A managed value overrides the user's entry unless the managed value is
        ``None`` and the user supplied that key, in which case the user's value
        is kept.

        Args:
            managed: arguments computed by the writer.
            extra: user-supplied ``extra_args``.

        Returns:
            A new dictionary suitable for ``wandb.init(**kwargs)``.
        """
        merged = dict(extra)
        for key, value in managed.items():
            if value is not None or key not in merged:
                merged[key] = value
        return merged

    def setup(
        self, run_name: str, log_root: pathlib.Path | None, is_resume: bool
    ) -> None:
        """
        Finish configuring the ``WandbWriter``.

        In particular, this function will call :func:`wandb.init`. If a run ID was
        previously restored through :py:meth:`load_state_dict` and ``is_resume`` is
        ``True``, it is forwarded to W&B (with ``resume="allow"``) so the run
        continues in place. In distributed training, :func:`wandb.init` is only
        called on rank 0; other ranks return immediately.

        Args:
            run_name: the name of the current run; used as the W&B run name
                when no explicit ``name`` was given in ``__init__``.
            log_root: root directory for logs. W&B data will be written
                under ``log_root/wandb/``, which is created if needed. ``None``
                lets W&B choose its own default directory.
            is_resume: ``True`` when continuing a previous run.
        """
        if self._rank != 0:
            return

        log_dir: str | None = None
        if log_root is not None:
            wandb_root = log_root / "wandb"
            wandb_root.mkdir(parents=True, exist_ok=True)
            log_dir = str(wandb_root)

        # _saved_run_id is None unless load_state_dict() restored one.
        run_id = self._saved_run_id if is_resume else None
        display_name = self._name if self._name is not None else run_name

        managed_args: dict[str, typing.Any] = {
            "project": self._project,
            "name": display_name,
            "config": self._config,
            "dir": log_dir,
            "id": run_id,
            "resume": "allow" if run_id is not None else None,
        }
        self._run = wandb_sdk.init(
            **self._merge_init_kwargs(managed_args, self._extra_args)
        )
        self._run_id = self._run.id

    def flush(self) -> None:
        """No-op, W&B syncs its own data automatically."""

    def close(self) -> None:
        """Finish and close the W&B run. Does nothing if no run is active."""
        if self._run is None:
            return
        self._run.finish()
        self._run = None

    def state_dict(self) -> dict[str, typing.Any]:
        """
        Return a dictionary containing the writer state.

        The state is stored under the key ``"run_id"``, holding the ID of the run
        started by :py:meth:`setup`. It is ``None`` if :py:meth:`setup` has not
        started a run, e.g. before it is called or on non-zero ranks.

        Returns:
            A dictionary with the logger state.
        """
        return {"run_id": self._run_id}

    def load_state_dict(self, state_dict: dict[str, typing.Any]) -> None:
        """
        Restore the writer state from a previously saved dictionary.

        The restored run ID is only used by the next :py:meth:`setup` call made
        with ``is_resume=True``. A missing ``"run_id"`` key is treated as ``None``.

        Args:
            state_dict: the state dictionary returned by :py:meth:`state_dict`.
        """
        self._saved_run_id = state_dict.get("run_id")