helios.trainer

Classes

TrainingUnit

Defines the types of units for training steps.

TrainingState

The training state.

Trainer

Automates the training, validation, and testing code.

Functions

get_trainer_safe_types_for_load(→ list[type])

Return the list of safe types for loading needed by the trainer.

register_trainer_types_for_safe_load(→ None)

Register trainer types for safe loading.

find_last_checkpoint(→ pathlib.Path | None)

Find the last saved checkpoint (if available).

Module Contents

class helios.trainer.TrainingUnit(*args, **kwds)[source]

Bases: enum.Enum

Defines the types of units for training steps.

ITERATION = 0
EPOCH = 1
classmethod from_str(label: str) → TrainingUnit[source]

Convert the given string to the corresponding enum value.

Must be one of “iteration” or “epoch”.

Parameters:

label – the label to convert.

Returns:

The corresponding value.

Raises:

ValueError – if the given value is not one of “iteration” or “epoch”.
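
A short usage sketch of the conversion; the "step" label is deliberately invalid to show the documented ValueError:

    from helios.trainer import TrainingUnit

    unit = TrainingUnit.from_str("epoch")
    assert unit is TrainingUnit.EPOCH

    try:
        TrainingUnit.from_str("step")  # not a recognised label
    except ValueError as exc:
        print(f"rejected: {exc}")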

class helios.trainer.TrainingState[source]

The training state.

Parameters:
  • current_iteration – the current iteration number. Note that this refers to the iteration in which gradients are updated. This may or may not be equal to the global_iteration count.

  • global_iteration – the total iteration count.

  • global_epoch – the total epoch count.

  • validation_cycles – the number of validation cycles.

  • dataset_iter – the current batch index of the dataset. This is reset every epoch.

  • early_stop_count – the current number of validation cycles for early stop.

  • average_iter_time – average time per iteration.

  • running_iter – iteration count in the current validation cycle. Useful for computing running averages of loss functions.

current_iteration: int = 0
global_iteration: int = 0
global_epoch: int = 0
validation_cycles: int = 0
dataset_iter: int = 0
early_stop_count: int = 0
average_iter_time: float = 0
running_iter: int = 0
dict
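
A small sketch showing how the state might be created and updated inside a hypothetical bookkeeping step; it assumes TrainingState can be constructed with the defaults listed above:

    from helios.trainer import TrainingState

    state = TrainingState()
    assert state.global_epoch == 0

    # Hypothetical bookkeeping after one optimisation step.
    state.current_iteration += 1
    state.global_iteration += 1
    state.dataset_iter += 1
    state.average_iter_time = 0.25  # seconds (illustrative value)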
helios.trainer.get_trainer_safe_types_for_load() → list[type][source]

Return the list of safe types for loading needed by the trainer.

Returns:

The list of types that need to be registered for safe loading.

helios.trainer.register_trainer_types_for_safe_load() → None[source]

Register trainer types for safe loading.
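
A hedged sketch of how these two helpers might be used; the call to torch.serialization.add_safe_globals is an assumption about how the returned types could be consumed, not part of the helios API documented here:

    import torch

    from helios.trainer import (
        get_trainer_safe_types_for_load,
        register_trainer_types_for_safe_load,
    )

    # Option A: let the trainer register its own types for safe loading.
    register_trainer_types_for_safe_load()

    # Option B (assumed to be equivalent): register the returned types manually.
    torch.serialization.add_safe_globals(get_trainer_safe_types_for_load())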

helios.trainer.find_last_checkpoint(root: pathlib.Path | None) → pathlib.Path | None[source]

Find the last saved checkpoint (if available).

The function assumes that checkpoint names contain epoch_<epoch> and iter_<iter>, and returns the path to the checkpoint with the highest epoch and/or iteration count.

Parameters:

root – the path where the checkpoints are stored.

Returns:

The path to the last checkpoint (if any).
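
A usage sketch; the directory and file names are hypothetical and only illustrate the epoch_<epoch>/iter_<iter> naming convention described above:

    import pathlib

    from helios.trainer import find_last_checkpoint

    chkpt_root = pathlib.Path("chkpt/example_run")  # hypothetical location
    # e.g. containing: net_epoch_1_iter_500.pth, net_epoch_2_iter_1000.pth
    last = find_last_checkpoint(chkpt_root)
    if last is None:
        print("no checkpoint found, training from scratch")
    else:
        print(f"resuming from {last.name}")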

class helios.trainer.Trainer(run_name: str = '', train_unit: TrainingUnit | str = TrainingUnit.EPOCH, total_steps: int | float = 0, valid_frequency: int | None = None, chkpt_frequency: int | None = None, print_frequency: int | None = None, accumulation_steps: int = 1, enable_cudnn_benchmark: bool = False, enable_deterministic: bool = False, early_stop_cycles: int | None = None, use_cpu: bool | None = None, gpus: list[int] | None = None, random_seed: int | None = None, enable_tensorboard: bool = False, enable_file_logging: bool = False, enable_progress_bar: bool = False, chkpt_root: pathlib.Path | None = None, log_path: pathlib.Path | None = None, run_path: pathlib.Path | None = None, src_root: pathlib.Path | None = None, import_prefix: str = '', print_banner: bool = True)[source]

Automates the training, validation, and testing code.

The trainer handles all of the steps required to set up the correct environment (including handling distributed training), the training/validation/testing loops, and any clean-up afterwards. A minimal construction sketch is shown after the parameter list below.

Parameters:
  • run_name – name of the current run. Defaults to empty.

  • train_unit – the unit used for training. Defaults to TrainingUnit.EPOCH.

  • total_steps – the total number of steps to train for. Defaults to 0.

  • valid_frequency – (optional) frequency with which to perform validation.

  • chkpt_frequency – (optional) frequency with which to save checkpoints.

  • print_frequency – (optional) frequency with which to log.

  • accumulation_steps – number of steps for gradient accumulation. Defaults to 1.

  • enable_cudnn_benchmark – enable/disable CuDNN benchmark. Defaults to false.

  • enable_deterministic – enable/disable PyTorch deterministic. Defaults to false.

  • early_stop_cycles – (optional) number of cycles after which training will stop if no improvement is seen during validation.

  • use_cpu – (optional) if true, CPU will be used.

  • gpus – (optional) IDs of GPUs to use.

  • random_seed – (optional) the seed to use for RNGs.

  • enable_tensorboard – enable/disable Tensorboard logging. Defaults to false.

  • enable_file_logging – enable/disable file logging. Defaults to false.

  • enable_progress_bar – enable/disable the progress bar(s). Defaults to false.

  • chkpt_root – (optional) root folder in which checkpoints will be placed.

  • log_path – (optional) root folder in which logs will be saved.

  • run_path – (optional) root folder in which Tensorboard runs will be saved.

  • src_root – (optional) root folder where the code is located. This is used to automatically populate the registries using update_all_registries().

  • import_prefix – prefix to use when importing modules. See update_all_registries() for details.

  • print_banner – if true, the Helios banner with system info will be printed. Defaults to true.
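
A minimal construction sketch, assuming an iteration-based run; every value below is illustrative rather than a recommendation:

    import pathlib

    from helios.trainer import Trainer, TrainingUnit

    trainer = Trainer(
        run_name="example_run",            # hypothetical run name
        train_unit=TrainingUnit.ITERATION,
        total_steps=100_000,
        valid_frequency=1_000,
        chkpt_frequency=1_000,
        print_frequency=100,
        enable_tensorboard=True,
        enable_file_logging=True,
        chkpt_root=pathlib.Path("chkpt"),  # hypothetical folders
        log_path=pathlib.Path("logs"),
        run_path=pathlib.Path("runs"),
    )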

property model: helios.model.Model

Return the model.

property datamodule: helios.data.DataModule

Return the datamodule.

property local_rank: int

Return the local rank of the trainer.

property rank: int

Return the global rank of the trainer.

property gpu_ids: list[int]

Return the list of GPU IDs to use for training.

property train_exceptions: list[type[Exception]]

Return the list of valid exceptions for training.

property test_exceptions: list[type[Exception]]

Return the list of valid exceptions for testing.

property plugins: dict[str, helios.plugins.Plugin]

Return the dictionary of plug-ins.

property queue: torch.multiprocessing.Queue | None

Return the multi-processing queue instance.

Note

If training isn't distributed, or if torchrun is used, then None is returned instead.

fit(model: helios.model.Model, datamodule: helios.data.DataModule) → bool[source]

Run the full training routine.

Parameters:
  • model – the model to run on.

  • datamodule – the datamodule to use.

Returns:

True if the training process completed successfully, false otherwise.

test(model: helios.model.Model, datamodule: helios.data.DataModule) → bool[source]

Run the full testing routine.

Parameters:
  • model – the model to run on.

  • datamodule – the datamodule to use.

Returns:

True if the testing process completed successfully, false otherwise.
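
A sketch of the full flow; MyModel and MyDataModule are hypothetical user-defined subclasses of helios.model.Model and helios.data.DataModule, and the trainer configuration is kept to the defaults for brevity:

    from helios.trainer import Trainer

    trainer = Trainer(run_name="example_run")  # hypothetical configuration

    model = MyModel()            # hypothetical helios.model.Model subclass
    datamodule = MyDataModule()  # hypothetical helios.data.DataModule subclass

    if trainer.fit(model, datamodule):
        trainer.test(model, datamodule)
    else:
        print("training did not complete successfully")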