helios.trainer
==============

.. py:module:: helios.trainer


Classes
-------

.. autoapisummary::

   helios.trainer.TrainingUnit
   helios.trainer.TrainingState
   helios.trainer.Trainer


Functions
---------

.. autoapisummary::

   helios.trainer.get_trainer_safe_types_for_load
   helios.trainer.register_trainer_types_for_safe_load
   helios.trainer.find_last_checkpoint


Module Contents
---------------

.. py:class:: TrainingUnit(*args, **kwds)

   Bases: :py:obj:`enum.Enum`

   Defines the types of units for training steps.

   .. py:attribute:: ITERATION
      :value: 0

   .. py:attribute:: EPOCH
      :value: 1

   .. py:method:: from_str(label: str) -> TrainingUnit
      :classmethod:

      Convert the given string to the corresponding enum value.

      The label must be one of ``"iteration"`` or ``"epoch"``.

      :param label: the label to convert.
      :returns: The corresponding value.
      :raises ValueError: if the given value is not one of ``"iteration"`` or ``"epoch"``.

.. py:class:: TrainingState

   The training state.

   :param current_iteration: the current iteration number. Note that this refers to the iteration in which gradients are updated, which may differ from the :py:attr:`global_iteration` count (for example, when gradient accumulation is used).
   :param global_iteration: the total iteration count.
   :param global_epoch: the total epoch count.
   :param validation_cycles: the number of validation cycles.
   :param dataset_iter: the current batch index of the dataset. This is reset every epoch.
   :param early_stop_count: the current number of validation cycles counted towards early stopping.
   :param average_iter_time: the average time per iteration.
   :param running_iter: the iteration count in the current validation cycle. Useful for computing running averages of loss functions.

   .. py:attribute:: current_iteration
      :type: int
      :value: 0

   .. py:attribute:: global_iteration
      :type: int
      :value: 0

   .. py:attribute:: global_epoch
      :type: int
      :value: 0

   .. py:attribute:: validation_cycles
      :type: int
      :value: 0

   .. py:attribute:: dataset_iter
      :type: int
      :value: 0

   .. py:attribute:: early_stop_count
      :type: int
      :value: 0

   .. py:attribute:: average_iter_time
      :type: float
      :value: 0

   .. py:attribute:: running_iter
      :type: int
      :value: 0

   .. py:attribute:: dict

.. py:function:: get_trainer_safe_types_for_load() -> list[type]

   Return the list of types the trainer needs for safe loading.

   :returns: The list of types that need to be registered for safe loading.

.. py:function:: register_trainer_types_for_safe_load() -> None

   Register the trainer types for safe loading.

.. py:function:: find_last_checkpoint(root: pathlib.Path | None) -> pathlib.Path | None

   Find the last saved checkpoint (if available).

   The function assumes that checkpoint names contain ``epoch_`` and ``iter_``, in which case it returns the path to the checkpoint with the highest epoch and/or iteration count.

   :param root: the path where the checkpoints are stored.
   :returns: The path to the last checkpoint (if any).

.. py:class:: Trainer(run_name: str = '', train_unit: TrainingUnit | str = TrainingUnit.EPOCH, total_steps: int | float = 0, valid_frequency: int | None = None, chkpt_frequency: int | None = None, print_frequency: int | None = None, accumulation_steps: int = 1, enable_cudnn_benchmark: bool = False, enable_deterministic: bool = False, early_stop_cycles: int | None = None, use_cpu: bool | None = None, gpus: list[int] | None = None, random_seed: int | None = None, enable_tensorboard: bool = False, enable_file_logging: bool = False, enable_progress_bar: bool = False, chkpt_root: pathlib.Path | None = None, log_path: pathlib.Path | None = None, run_path: pathlib.Path | None = None, src_root: pathlib.Path | None = None, import_prefix: str = '', print_banner: bool = True)

   Automates the training, validation, and testing code.

   The trainer handles all of the steps required to set up the correct environment (including handling distributed training), the training/validation/testing loops, and any clean-up afterwards.

   :param run_name: name of the current run. Defaults to empty.
   :param train_unit: the unit used for training. Defaults to :py:attr:`TrainingUnit.EPOCH`.
   :param total_steps: the total number of steps to train for. Defaults to 0.
   :param valid_frequency: (optional) frequency with which to perform validation.
   :param chkpt_frequency: (optional) frequency with which to save checkpoints.
   :param print_frequency: (optional) frequency with which to log.
   :param accumulation_steps: number of steps for gradient accumulation. Defaults to 1.
   :param enable_cudnn_benchmark: enable/disable cuDNN benchmark mode. Defaults to false.
   :param enable_deterministic: enable/disable PyTorch deterministic algorithms. Defaults to false.
   :param early_stop_cycles: (optional) number of validation cycles without improvement after which training will stop.
   :param use_cpu: (optional) if true, the CPU will be used.
   :param gpus: (optional) IDs of the GPUs to use.
   :param random_seed: (optional) the seed to use for RNGs.
   :param enable_tensorboard: enable/disable TensorBoard logging. Defaults to false.
   :param enable_file_logging: enable/disable file logging. Defaults to false.
   :param enable_progress_bar: enable/disable the progress bar(s). Defaults to false.
   :param chkpt_root: (optional) root folder in which checkpoints will be saved.
   :param log_path: (optional) root folder in which logs will be saved.
   :param run_path: (optional) root folder in which TensorBoard runs will be saved.
   :param src_root: (optional) root folder where the code is located. This is used to automatically populate the registries using :py:func:`~helios.core.utils.update_all_registries`.
   :param import_prefix: prefix to use when importing modules. See :py:func:`~helios.core.utils.update_all_registries` for details.
   :param print_banner: if true, the Helios banner with system info will be printed. Defaults to true.

   .. py:property:: model
      :type: helios.model.Model

      Return the model.

   .. py:property:: datamodule
      :type: helios.data.DataModule

      Return the datamodule.

   .. py:property:: local_rank
      :type: int

      Return the local rank of the trainer.

   .. py:property:: rank
      :type: int

      Return the global rank of the trainer.

   .. py:property:: gpu_ids
      :type: list[int]

      Return the list of GPU IDs to use for training.

   .. py:property:: train_exceptions
      :type: list[type[Exception]]

      Return the list of valid exceptions for training.

   .. py:property:: test_exceptions
      :type: list[type[Exception]]

      Return the list of valid exceptions for testing.

   .. py:property:: plugins
      :type: dict[str, helios.plugins.Plugin]

      Return the dictionary of plug-ins, keyed by name.

   .. py:property:: queue
      :type: torch.multiprocessing.Queue | None

      Return the multi-processing queue instance.

      .. note::

         If training isn't distributed or if ``torchrun`` is used, then ``None`` is returned instead.

   .. py:method:: fit(model: helios.model.Model, datamodule: helios.data.DataModule) -> None

      Run the full training routine.

      :param model: the model to run on.
      :param datamodule: the datamodule to use.

   .. py:method:: test(model: helios.model.Model, datamodule: helios.data.DataModule) -> None

      Run the full testing routine.

      :param model: the model to run on.
      :param datamodule: the datamodule to use.
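
The note on ``current_iteration`` versus ``global_iteration`` in :py:class:`TrainingState` can be illustrated with a small simulation. This is a hypothetical sketch, not the Helios implementation: it assumes that ``global_iteration`` advances once per batch and that ``current_iteration`` advances only when gradients are applied, i.e. once every ``accumulation_steps`` batches. The helper ``simulate_counters`` is invented for illustration.

```python
def simulate_counters(num_batches: int, accumulation_steps: int) -> tuple[int, int]:
    """Return (global_iteration, current_iteration) after ``num_batches`` batches.

    Hypothetical model of the counters: every batch is one global iteration, but
    gradients are only applied (one "current" iteration) every ``accumulation_steps``
    batches.
    """
    if accumulation_steps < 1:
        raise ValueError("accumulation_steps must be at least 1")

    global_iteration = 0
    current_iteration = 0
    for _ in range(num_batches):
        global_iteration += 1
        # An optimizer step happens only at the end of each accumulation window.
        if global_iteration % accumulation_steps == 0:
            current_iteration += 1
    return global_iteration, current_iteration


# Without accumulation the two counters match; with 4-step accumulation they diverge.
print(simulate_counters(10, 1))  # (10, 10)
print(simulate_counters(10, 4))  # (10, 2)
```

Under this assumption, ``current_iteration`` equals ``global_iteration // accumulation_steps``, so the two counts coincide only when ``accumulation_steps`` is 1.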
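
The checkpoint search performed by :py:func:`find_last_checkpoint` can be approximated as below. This is a minimal sketch under stated assumptions, not the Helios implementation: the helper ``find_last_checkpoint_sketch`` is hypothetical, it scans only the top level of ``root``, and it assumes checkpoint names embed the counters as ``epoch_<N>`` and/or ``iter_<N>`` (for example ``net_epoch_2_iter_100.pth``).

```python
from __future__ import annotations

import pathlib
import re
import tempfile

# Hypothetical naming scheme: names such as "net_epoch_2_iter_100.pth".
_EPOCH_RE = re.compile(r"epoch_(\d+)")
_ITER_RE = re.compile(r"iter_(\d+)")


def find_last_checkpoint_sketch(root: pathlib.Path | None) -> pathlib.Path | None:
    """Return the file under ``root`` with the highest (epoch, iteration) pair."""
    if root is None or not root.is_dir():
        return None

    best_key: tuple[int, int] | None = None
    best_path: pathlib.Path | None = None
    for path in root.iterdir():
        epoch = _EPOCH_RE.search(path.name)
        iteration = _ITER_RE.search(path.name)
        if epoch is None and iteration is None:
            continue  # Not a checkpoint under the assumed naming scheme.
        # A missing counter is treated as 0, so "and/or" names can still be compared.
        key = (
            int(epoch.group(1)) if epoch else 0,
            int(iteration.group(1)) if iteration else 0,
        )
        if best_key is None or key > best_key:
            best_key, best_path = key, path
    return best_path


with tempfile.TemporaryDirectory() as tmp:
    root = pathlib.Path(tmp)
    for name in ("net_epoch_1_iter_900.pth", "net_epoch_2_iter_50.pth",
                 "net_epoch_2_iter_100.pth", "notes.txt"):
        (root / name).touch()
    last = find_last_checkpoint_sketch(root)

print(last.name)  # net_epoch_2_iter_100.pth
```

Comparing ``(epoch, iteration)`` tuples orders checkpoints by epoch first and breaks ties by iteration; the actual function may order or filter candidates differently.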