- TrainingUnit – Defines the types of units for training steps.
- TrainingState – The training state.
- Trainer – Automates the training, validation, and testing code.
- get_trainer_safe_types_for_load – Return the list of safe types for loading needed by the trainer.
- register_trainer_types_for_safe_load – Register trainer types for safe loading.
- find_last_checkpoint – Find the last saved checkpoint (if available).
Module Contents
- class helios.trainer.TrainingUnit(*args, **kwds)
Defines the types of units for training steps.
- EPOCH = 1
- classmethod from_str(label: str) → TrainingUnit
Convert the given string to the corresponding enum value.
Must be one of “iteration” or “epoch”.
- Parameters:
label – the label to convert.
- Returns:
The corresponding value.
- Raises:
ValueError – if the given value is not one of “iteration” or “epoch”.
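The string-to-enum conversion described above can be sketched with the standard `enum` module. This is a minimal stand-in, not the helios implementation: the class name `TrainingUnitSketch` is hypothetical, and the value given to `ITERATION` is an assumption (only `EPOCH = 1` is documented here).

```python
import enum


class TrainingUnitSketch(enum.Enum):
    """Hypothetical stand-in for helios.trainer.TrainingUnit."""

    ITERATION = 0  # assumed value; not documented above
    EPOCH = 1

    @classmethod
    def from_str(cls, label: str) -> "TrainingUnitSketch":
        """Convert "iteration" or "epoch" to the matching member."""
        # Match the label against the lowercase member names; reject anything else.
        for unit in cls:
            if unit.name.lower() == label:
                return unit
        raise ValueError(
            f"invalid training unit {label!r}: expected 'iteration' or 'epoch'"
        )


unit = TrainingUnitSketch.from_str("epoch")  # TrainingUnitSketch.EPOCH
```

Looking members up by name keeps the accepted labels in sync with the enum definition, so adding a member does not require a separate mapping table.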
- class helios.trainer.TrainingState
The training state.
- Parameters:
current_iteration – the current iteration number. Note that this refers to the iteration in which gradients are updated. This may or may not be equal to the global_iteration count (the two can differ when gradient accumulation is used).
global_iteration – the total iteration count.
global_epoch – the total epoch count.
validation_cycles – the number of validation cycles.
dataset_iter – the current batch index of the dataset. This is reset every epoch.
early_stop_count – the current number of validation cycles for early stop.
average_iter_time – average time per iteration.
running_iter – iteration count in the current validation cycle. Useful for computing running averages of loss functions.
- current_iteration: int = 0
- global_iteration: int = 0
- global_epoch: int = 0
- validation_cycles: int = 0
- dataset_iter: int = 0
- early_stop_count: int = 0
- average_iter_time: float = 0
- running_iter: int = 0
- dict
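To illustrate how the counters relate, here is a minimal sketch that mirrors the documented fields in a `dataclasses.dataclass`. The class name `TrainingStateSketch`, the `advance` helper, the behaviour of `dict()`, and the update rule (that `current_iteration` advances once every `accumulation_steps` batches while `global_iteration` advances every batch) are assumptions for illustration, not the helios implementation.

```python
import dataclasses
from typing import Any


@dataclasses.dataclass
class TrainingStateSketch:
    """Hypothetical mirror of the documented TrainingState fields."""

    current_iteration: int = 0
    global_iteration: int = 0
    global_epoch: int = 0
    validation_cycles: int = 0
    dataset_iter: int = 0
    early_stop_count: int = 0
    average_iter_time: float = 0
    running_iter: int = 0

    def dict(self) -> dict[str, Any]:
        # Assumed behaviour: expose the fields as a plain dictionary.
        return dataclasses.asdict(self)

    def advance(self, accumulation_steps: int) -> None:
        """Hypothetical update after processing one batch."""
        self.global_iteration += 1
        self.dataset_iter += 1
        self.running_iter += 1
        # Gradients are applied only once every `accumulation_steps` batches,
        # so current_iteration lags behind global_iteration when accumulating.
        if self.global_iteration % accumulation_steps == 0:
            self.current_iteration += 1


state = TrainingStateSketch()
for _ in range(8):
    state.advance(accumulation_steps=4)
# state.global_iteration == 8, state.current_iteration == 2
```

With `accumulation_steps=1` the two counters stay equal, which matches the note that they "may or may not" coincide.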
- helios.trainer.get_trainer_safe_types_for_load() → list[Callable | tuple[Callable, str]]
Return the list of safe types for loading needed by the trainer.
- Returns:
The list of types that need to be registered for safe loading.
- helios.trainer.register_trainer_types_for_safe_load() → None
Register trainer types for safe loading.
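The `list[Callable | tuple[Callable, str]]` shape matches what `torch.serialization.add_safe_globals` accepts in recent PyTorch releases, which allow-lists non-tensor types so that checkpoints can be read with `torch.load(..., weights_only=True)`. The sketch below assumes registration goes through that function; the helper name `register_types_for_safe_load` and its injectable `add_safe_globals` argument (used so the example runs without PyTorch) are hypothetical.

```python
from typing import Any, Callable, Optional, Union

# One entry is either a type/callable or a (callable, fully qualified name) pair.
SafeType = Union[Callable[..., Any], tuple[Callable[..., Any], str]]


def register_types_for_safe_load(
    types: list[SafeType],
    add_safe_globals: Optional[Callable[[list[SafeType]], None]] = None,
) -> None:
    """Hypothetical helper: allow-list `types` for weights-only loading."""
    if add_safe_globals is None:
        # Assumes a PyTorch release that provides add_safe_globals.
        import torch

        add_safe_globals = torch.serialization.add_safe_globals
    add_safe_globals(types)
```

With helios and PyTorch installed, `register_types_for_safe_load(get_trainer_safe_types_for_load())` would presumably have the same effect as calling `register_trainer_types_for_safe_load()` directly.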
- helios.trainer.find_last_checkpoint(root: pathlib.Path | None) → pathlib.Path | None
Find the last saved checkpoint (if available).
The function assumes that checkpoint names contain the epoch and/or iteration count, in which case it returns the path to the checkpoint with the highest epoch and/or iteration count.
- Parameters:
root – the path where the checkpoints are stored.
- Returns:
The path to the last checkpoint (if any).
- class helios.trainer.Trainer(run_name: str = '', train_unit: TrainingUnit | str = TrainingUnit.EPOCH, total_steps: int | float = 0, valid_frequency: int | None = None, chkpt_frequency: int | None = None, print_frequency: int | None = None, accumulation_steps: int = 1, enable_cudnn_benchmark: bool = False, enable_deterministic: bool = False, early_stop_cycles: int | None = None, use_cpu: bool | None = None, gpus: list[int] | None = None, random_seed: int | None = None, enable_tensorboard: bool = False, enable_file_logging: bool = False, enable_progress_bar: bool = False, chkpt_root: pathlib.Path | None = None, log_path: pathlib.Path | None = None, run_path: pathlib.Path | None = None, src_root: pathlib.Path | None = None, import_prefix: str = '', print_banner: bool = True)
Automates the training, validation, and testing code.
The trainer handles all of the steps required to set up the correct environment (including handling distributed training), the training/validation/testing loops, and any cleanup afterwards.
- Parameters:
run_name – name of the current run. Defaults to empty.
train_unit – the unit used for training. Defaults to TrainingUnit.EPOCH.
total_steps – the total number of steps to train for. Defaults to 0.
valid_frequency – (optional) frequency with which to perform validation.
chkpt_frequency – (optional) frequency with which to save checkpoints.
print_frequency – (optional) frequency with which to log.
accumulation_steps – number of steps for gradient accumulation. Defaults to 1.
enable_cudnn_benchmark – enable/disable cuDNN benchmark mode. Defaults to false.
enable_deterministic – enable/disable PyTorch deterministic mode. Defaults to false.
early_stop_cycles – (optional) number of cycles after which training will stop if no improvement is seen during validation.
use_cpu – (optional) if true, CPU will be used.
gpus – (optional) IDs of GPUs to use.
random_seed – (optional) the seed to use for RNGs.
enable_tensorboard – enable/disable TensorBoard logging. Defaults to false.
enable_file_logging – enable/disable file logging. Defaults to false.
enable_progress_bar – enable/disable the progress bar(s). Defaults to false.
chkpt_root – (optional) root folder in which checkpoints will be placed.
log_path – (optional) root folder in which logs will be saved.
run_path – (optional) root folder in which TensorBoard runs will be saved.
src_root – (optional) root folder where the code is located. This is used to automatically populate the registries.
import_prefix – prefix to use when importing modules.
print_banner – if true, the Helios banner with system info will be printed. Defaults to true.
- property model: helios.model.Model
Return the model.
- property datamodule: helios.data.DataModule
Return the datamodule.
- property local_rank: int
Return the local rank of the trainer.
- property rank: int
Return the global rank of the trainer.
- property gpu_ids: list[int]
Return the list of GPU IDs to use for training.
- property train_exceptions: list[type[Exception]]
Return the list of valid exceptions for training.
- property test_exceptions: list[type[Exception]]
Return the list of valid exceptions for testing.
- property plugins: dict[str, helios.plugins.Plugin]
Return the dictionary of plug-ins.
- property queue: torch.multiprocessing.Queue | None
Return the multiprocessing queue instance.
If training isn’t distributed, or if it was launched with torchrun, None is returned instead.