"""Checkpoint migration tool for Helios (module ``helios.chkpt_migrator``)."""

import argparse
import pathlib
import platform

import torch
import tqdm

from helios import core

from ._version import __version__
from .trainer import TrainingState, register_trainer_types_for_safe_load


def migrate_checkpoints_to_current_version(root: pathlib.Path) -> None:
    """
    Migrate existing checkpoints from a previous version of Helios to the current version.

    This function exists to provide backwards compatibility with checkpoints produced by
    older versions of Helios. Every ``*.pth`` file directly inside ``root`` is loaded,
    upgraded in memory and written back to the same path.

    The upgraded checkpoint is first written to a temporary sibling file
    (``<name>.pth.tmp``) and then moved over the original. An interruption while saving
    (crash, Ctrl-C, full disk) therefore cannot leave a truncated checkpoint behind; the
    original file stays intact and the temporary file is removed.

    Args:
        root: the root where the checkpoints are stored.
    """
    register_trainer_types_for_safe_load()

    # Materialise the list up front so tqdm knows the total and so the temporary files
    # created below are never picked up by the glob.
    chkpt_paths = list(root.glob("*.pth"))
    for chkpt_path in tqdm.tqdm(chkpt_paths, desc="Migrating checkpoints", unit="chkpt"):
        state = core.safe_torch_load(chkpt_path)

        # Pre-v1.0, checkpoints didn't have a version key.
        if "version" not in state:
            state["version"] = __version__

        # Pre-v1.1, the TrainingState struct is saved as a dictionary, not the object
        # itself. Checkpoints without a training state have nothing to convert here.
        if isinstance(state.get("training_state"), dict):
            state["training_state"] = TrainingState(**state["training_state"])

        # Write-then-rename so the original checkpoint is only replaced once the new
        # one has been fully written.
        tmp_path = chkpt_path.with_name(chkpt_path.name + ".tmp")
        try:
            torch.save(state, tmp_path)
            tmp_path.replace(chkpt_path)
        except BaseException:
            # Includes KeyboardInterrupt: don't leave partial temp files lying around.
            tmp_path.unlink(missing_ok=True)
            raise
def _main() -> None:
    """
    Command-line entry point for the checkpoint migration tool.

    Parses a single positional ``ROOT`` argument and migrates every checkpoint found
    there via ``migrate_checkpoints_to_current_version``.
    """
    parser = argparse.ArgumentParser(
        description="Migration tool to convert checkpoints generated by versions of "
        f"Helios prior to {__version__}"
    )
    parser.add_argument(
        "root",
        metavar="ROOT",
        nargs=1,
        type=str,
        help="Root where the checkpoints are stored",
    )
    args = parser.parse_args()
    root = args.root[0]

    # Checkpoints may contain concrete PosixPath/WindowsPath objects pickled on another
    # OS. The "foreign" concrete class cannot be instantiated on the current platform,
    # so temporarily alias it to the native one while loading. Every non-Windows
    # platform (Linux, macOS, BSD, ...) is treated alike.
    is_windows = platform.system() == "Windows"
    foreign_name = "PosixPath" if is_windows else "WindowsPath"
    native_cls = pathlib.WindowsPath if is_windows else pathlib.PosixPath

    original_cls = getattr(pathlib, foreign_name)
    setattr(pathlib, foreign_name, native_cls)
    try:
        migrate_checkpoints_to_current_version(pathlib.Path(root))
    finally:
        # Restore to default values, even if the migration fails part-way.
        setattr(pathlib, foreign_name, original_cls)


if __name__ == "__main__":
    _main()