def migrate_checkpoints_to_current_version(root: pathlib.Path) -> None:
    """
    Migrate existing checkpoints from a previous version of Helios to the current version.

    This function exists to provide backwards compatibility with checkpoints produced by
    older versions of Helios. Every ``*.pth`` file directly inside ``root`` is loaded,
    upgraded in memory, and written back to the same path.

    Each checkpoint is first written to a sibling ``<name>.tmp`` file and then swapped
    into place with a single rename, so an interrupted or failed save never leaves a
    truncated checkpoint behind.

    Args:
        root: the root where the checkpoints are stored.
    """
    register_trainer_types_for_safe_load()

    # Materialise the glob up front so tqdm knows the total count.
    checkpoint_paths = list(root.glob("*.pth"))
    for chkpt_path in tqdm.tqdm(
        checkpoint_paths, desc="Migrating checkpoints", unit="chkpt"
    ):
        state = core.safe_torch_load(chkpt_path)

        # Pre-v1.0, checkpoints didn't have a version key.
        if "version" not in state:
            state["version"] = __version__

        # Pre-v1.1, the TrainingState struct is saved as a dictionary, not the object
        # itself.
        if isinstance(state["training_state"], dict):
            state["training_state"] = TrainingState(**state["training_state"])

        # Never overwrite the original in place: write to a temp file ("x.pth.tmp",
        # which the "*.pth" glob can never match) and only then replace the original.
        tmp_path = chkpt_path.with_name(chkpt_path.name + ".tmp")
        try:
            torch.save(state, tmp_path)
        except BaseException:
            # Covers KeyboardInterrupt too; drop the partial file and re-raise.
            tmp_path.unlink(missing_ok=True)
            raise
        tmp_path.replace(chkpt_path)
def _main() -> None:
    """
    Command-line entry point for the checkpoint migration tool.

    Parses a single positional ``ROOT`` argument and runs
    ``migrate_checkpoints_to_current_version`` on it.
    """
    parser = argparse.ArgumentParser(
        description="Migration tool to convert checkpoints generated by versions of "
        f"Helios prior to {__version__}"
    )
    parser.add_argument(
        "root",
        metavar="ROOT",
        nargs=1,
        type=str,
        help="Root where the checkpoints are stored",
    )
    args = parser.parse_args()
    root = args.root[0]

    # Temporarily alias the "foreign" concrete path class to the native one so that
    # path objects pickled on another OS can be loaded. Every non-Windows platform
    # (Linux, macOS, BSD, ...) uses PosixPath, so all of them get the POSIX branch.
    on_windows = platform.system() == "Windows"
    tmp: type[pathlib.WindowsPath] | type[pathlib.PosixPath]
    if on_windows:
        tmp = pathlib.PosixPath
        pathlib.PosixPath = pathlib.WindowsPath  # type: ignore[assignment, misc]
    else:
        tmp = pathlib.WindowsPath
        pathlib.WindowsPath = pathlib.PosixPath  # type: ignore[assignment, misc]

    try:
        migrate_checkpoints_to_current_version(pathlib.Path(root))
    finally:
        # Restore to default values, even if the migration raised part-way.
        if on_windows:
            pathlib.PosixPath = tmp  # type: ignore[assignment, misc]
        else:
            pathlib.WindowsPath = tmp  # type: ignore[assignment, misc]


if __name__ == "__main__":
    _main()