helios.trainer
Classes
TrainingUnit | Defines the types of units for training steps.
TrainingState | The training state.
Trainer | Automates the training, validation, and testing code.
Functions
get_trainer_safe_types_for_load | Return the list of safe types for loading needed by the trainer.
register_trainer_types_for_safe_load | Register trainer types for safe loading.
find_last_checkpoint | Find the last saved checkpoint (if available).
Module Contents
- class helios.trainer.TrainingUnit(*args, **kwds)
Bases: enum.Enum
Defines the types of units for training steps.
- ITERATION = 0
- EPOCH = 1
- classmethod from_str(label: str) → TrainingUnit
Convert the given string to the corresponding enum value.
Must be one of “iteration” or “epoch”.
- Parameters:
label – the label to convert.
- Returns:
The corresponding value.
- Raises:
ValueError – if the given value is not one of “iteration” or “epoch”.
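The conversion can be illustrated with a stand-in enum. This is a minimal sketch, assuming exact, case-sensitive matching of the label; the class below only mirrors the documented members and is not the helios source.

```python
import enum


class TrainingUnit(enum.Enum):
    """Stand-in that mirrors the documented members (not the helios source)."""

    ITERATION = 0
    EPOCH = 1

    @classmethod
    def from_str(cls, label: str) -> "TrainingUnit":
        # Assumption: labels are matched exactly and case-sensitively.
        if label == "iteration":
            return cls.ITERATION
        if label == "epoch":
            return cls.EPOCH
        raise ValueError(
            f"invalid training unit {label!r}; expected 'iteration' or 'epoch'"
        )


unit = TrainingUnit.from_str("epoch")
print(unit)  # TrainingUnit.EPOCH

try:
    TrainingUnit.from_str("step")
except ValueError as exc:
    print(exc)  # any other label is rejected
```

Since the Trainer accepts train_unit as either a TrainingUnit or a str, a conversion of this kind lets the unit be given as a plain string.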
- class helios.trainer.TrainingState
The training state.
- Parameters:
current_iteration – the current iteration number. Note that this refers to the iteration in which gradients are updated. This may or may not be equal to the global_iteration count.
global_iteration – the total iteration count.
global_epoch – the total epoch count.
validation_cycles – the number of validation cycles.
dataset_iter – the current batch index of the dataset. This is reset every epoch.
early_stop_count – the current number of validation cycles for early stop.
average_iter_time – average time per iteration.
running_iter – iteration count in the current validation cycle. Useful for computing running averages of loss functions.
- current_iteration: int = 0
- global_iteration: int = 0
- global_epoch: int = 0
- validation_cycles: int = 0
- dataset_iter: int = 0
- early_stop_count: int = 0
- average_iter_time: float = 0
- running_iter: int = 0
- dict
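One way to read the difference between current_iteration and global_iteration is through gradient accumulation. The sketch below assumes that global_iteration counts every processed batch while current_iteration counts only the steps in which gradients are updated; the stand-in dataclass and the simulate_epoch helper are hypothetical and are not part of helios.

```python
import dataclasses


@dataclasses.dataclass
class TrainingState:
    """Stand-in with the documented fields and defaults (not the helios source)."""

    current_iteration: int = 0
    global_iteration: int = 0
    global_epoch: int = 0
    validation_cycles: int = 0
    dataset_iter: int = 0
    early_stop_count: int = 0
    average_iter_time: float = 0
    running_iter: int = 0


def simulate_epoch(
    state: TrainingState, num_batches: int, accumulation_steps: int
) -> TrainingState:
    """Hypothetical helper: advance the counters over one epoch."""
    state.dataset_iter = 0  # documented: reset every epoch
    for _ in range(num_batches):
        state.global_iteration += 1
        state.dataset_iter += 1
        # Assumption: gradients are updated once every `accumulation_steps` batches.
        if state.global_iteration % accumulation_steps == 0:
            state.current_iteration += 1
    state.global_epoch += 1
    return state


state = simulate_epoch(TrainingState(), num_batches=8, accumulation_steps=4)
print(state.global_iteration, state.current_iteration)  # 8 2
```

With accumulation_steps set to 1, the two counters would stay equal under this reading.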
- helios.trainer.get_trainer_safe_types_for_load() → list[type]
Return the list of safe types for loading needed by the trainer.
- Returns:
The list of types that need to be registered for safe loading.
- helios.trainer.register_trainer_types_for_safe_load() → None
Register trainer types for safe loading.
- helios.trainer.find_last_checkpoint(root: pathlib.Path | None) → pathlib.Path | None
Find the last saved checkpoint (if available).
The function assumes that checkpoint names contain epoch_<epoch> and iter_<iter> in the name, in which case it will return the path to the checkpoint with the highest epoch and/or iteration count.
- Parameters:
root – the path where the checkpoints are stored.
- Returns:
The path to the last checkpoint (if any).
- class helios.trainer.Trainer(run_name: str = '', train_unit: TrainingUnit | str = TrainingUnit.EPOCH, total_steps: int | float = 0, valid_frequency: int | None = None, chkpt_frequency: int | None = None, print_frequency: int | None = None, accumulation_steps: int = 1, enable_cudnn_benchmark: bool = False, enable_deterministic: bool = False, early_stop_cycles: int | None = None, use_cpu: bool | None = None, gpus: list[int] | None = None, random_seed: int | None = None, enable_tensorboard: bool = False, enable_file_logging: bool = False, enable_progress_bar: bool = False, chkpt_root: pathlib.Path | None = None, log_path: pathlib.Path | None = None, run_path: pathlib.Path | None = None, src_root: pathlib.Path | None = None, import_prefix: str = '', print_banner: bool = True)
Automates the training, validation, and testing code.
The trainer handles all of the steps required to set up the correct environment (including handling distributed training), the training/validation/testing loops, and any cleanup afterwards.
- Parameters:
run_name – name of the current run. Defaults to empty.
train_unit – the unit used for training. Defaults to TrainingUnit.EPOCH.
total_steps – the total number of steps to train for. Defaults to 0.
valid_frequency – (optional) frequency with which to perform validation.
chkpt_frequency – (optional) frequency with which to save checkpoints.
print_frequency – (optional) frequency with which to log.
accumulation_steps – number of steps for gradient accumulation. Defaults to 1.
enable_cudnn_benchmark – enable/disable CuDNN benchmark. Defaults to false.
enable_deterministic – enable/disable PyTorch deterministic. Defaults to false.
early_stop_cycles – (optional) number of cycles after which training will stop if no improvement is seen during validation.
use_cpu – (optional) if true, CPU will be used.
gpus – (optional) IDs of GPUs to use.
random_seed – (optional) the seed to use for RNGs.
enable_tensorboard – enable/disable Tensorboard logging. Defaults to false.
enable_file_logging – enable/disable file logging. Defaults to false.
enable_progress_bar – enable/disable the progress bar(s). Defaults to false.
chkpt_root – (optional) root folder in which checkpoints will be placed.
log_path – (optional) root folder in which logs will be saved.
run_path – (optional) root folder in which Tensorboard runs will be saved.
src_root – (optional) root folder where the code is located. This is used to automatically populate the registries using update_all_registries().
import_prefix – prefix to use when importing modules. See update_all_registries() for details.
print_banner – if true, the Helios banner with system info will be printed. Defaults to true.
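As an illustration of how the parameters above fit together, the following sketch collects a set of keyword arguments for the constructor. The parameter names come from the documented signature; the values, folder names, and the hlt import alias are assumptions made for the example.

```python
import pathlib

# Keyword arguments named after the documented signature; the values and
# folder names are illustrative assumptions.
trainer_kwargs = {
    "run_name": "demo",
    "train_unit": "epoch",  # a TrainingUnit or its string label
    "total_steps": 10,
    "valid_frequency": 1,
    "chkpt_frequency": 1,
    "print_frequency": 1,
    "accumulation_steps": 1,
    "random_seed": 1234,
    "enable_progress_bar": True,
    "chkpt_root": pathlib.Path("chkpt"),
    "log_path": pathlib.Path("logs"),
}

# With helios installed, the trainer would then be built along these lines:
#   import helios.trainer as hlt
#   trainer = hlt.Trainer(**trainer_kwargs)
```

Any parameter left out of the dictionary keeps the default shown in the signature.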
- property model: helios.model.Model
Return the model.
- property datamodule: helios.data.DataModule
Return the datamodule.
- property local_rank: int
Return the local rank of the trainer.
- property rank: int
Return the global rank of the trainer.
- property gpu_ids: list[int]
Return the list of GPU IDs to use for training.
- property train_exceptions: list[type[Exception]]
Return the list of valid exceptions for training.
- property test_exceptions: list[type[Exception]]
Return the list of valid exceptions for testing.
- property plugins: dict[str, helios.plugins.Plugin]
Return the dictionary of plug-ins.
- property queue: torch.multiprocessing.Queue | None
Return the multi-processing queue instance.
Note
If training isn't distributed, or if torchrun is used, None is returned instead.