Source code for helios.trainer

from __future__ import annotations

import dataclasses as dc
import enum
import itertools
import os
import pathlib
import re
import time
import typing

import packaging.version as pv
import torch
import torch.multiprocessing as mp
import torch.utils.data as tud
import tqdm

import helios.model as hlm
import helios.plugins as hlp
from helios import core, data
from helios.core import distributed as dist
from helios.core import logging, rng
from helios.data.samplers import ResumableSamplerType

from ._version import __version__


class TrainingUnit(enum.Enum):
    """Defines the types of units for training steps."""

    ITERATION = 0
    EPOCH = 1

    @classmethod
    def from_str(cls, label: str) -> TrainingUnit:
        """
        Convert the given string to the corresponding enum value.

        Must be one of "iteration" or "epoch".

        Args:
            label: the label to convert.

        Returns:
            The corresponding value.

        Raises:
            ValueError: if the given value is not one of "iteration" or "epoch".
        """
        if label == "iteration":
            return cls.ITERATION
        if label == "epoch":
            return cls.EPOCH
        raise ValueError(
            "invalid training unit. Expected one of 'iteration' or 'epoch' but "
            f"received '{label}'"
        )
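# Example (illustrative, not part of the module): string conversion is what lets the
# Trainer accept ``train_unit="epoch"`` in place of the enum value.
#
#   unit = TrainingUnit.from_str("iteration")
#   assert unit is TrainingUnit.ITERATION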
class _TrainerMode(enum.Enum):
    """
    Defines the types of actions that the trainer can do.

    These values are used internally to figure out which function should be invoked from
    the distributed handler.
    """

    TRAIN = 0
    TEST = 1


@dc.dataclass
class _DistributedErrorState:
    """
    Holds the state necessary to correctly handle errors in distributed training.

    This class is pushed into the internal error handling queue and used by the main
    process when an error occurs to ensure the exception is handled correctly.

    Args:
        log_path: the path used by the root logger (if using).
    """

    log_path: pathlib.Path | None = None
@dc.dataclass
class TrainingState:
    """
    The training state.

    Args:
        current_iteration: the current iteration number. Note that this refers to the
            iteration in which gradients are updated. This may or may not be equal to
            the :py:attr:`global_iteration` count.
        global_iteration: the total iteration count.
        global_epoch: the total epoch count.
        validation_cycles: the number of validation cycles.
        dataset_iter: the current batch index of the dataset. This is reset every epoch.
        early_stop_count: the current number of validation cycles for early stop.
        average_iter_time: average time per iteration.
        running_iter: iteration count in the current validation cycle. Useful for
            computing running averages of loss functions.
    """

    current_iteration: int = 0
    global_iteration: int = 0
    global_epoch: int = 0
    validation_cycles: int = 0
    dataset_iter: int = 0
    early_stop_count: int = 0
    average_iter_time: float = 0
    running_iter: int = 0

    dict = dc.asdict
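# Example (illustrative): TrainingState is a plain dataclass, so it can be converted to
# a dictionary either through dataclasses.asdict or through the ``dict`` alias defined
# above.
#
#   state = TrainingState(current_iteration=10, global_epoch=2)
#   as_dict = state.dict()  # equivalent to dc.asdict(state)
#   assert as_dict["global_epoch"] == 2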
def get_trainer_safe_types_for_load() -> list[type]:
    """
    Return the list of safe types for loading needed by the trainer.

    Returns:
        The list of types that need to be registered for safe loading.
    """
    return [TrainingState, pathlib.PosixPath, pathlib.WindowsPath]
def register_trainer_types_for_safe_load() -> None:
    """Register trainer types for safe loading."""
    torch.serialization.add_safe_globals(get_trainer_safe_types_for_load())
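# Example (illustrative): the trainer registers these types before reading checkpoints so
# that they can be de-serialised under PyTorch's ``weights_only`` loading. Loading a
# trainer checkpoint manually follows the same pattern (the file name below is made up):
#
#   register_trainer_types_for_safe_load()
#   state_dict = core.safe_torch_load("my_net_epoch_1_iter_100.pth", map_location="cpu")
#   training_state = state_dict["training_state"]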
def find_last_checkpoint(root: pathlib.Path | None) -> pathlib.Path | None:
    """
    Find the last saved checkpoint (if available).

    The function assumes that checkpoint names contain ``epoch_<epoch>`` and
    ``iter_<iter>`` in the name, in which case it will return the path to the checkpoint
    with the highest epoch and/or iteration count.

    Args:
        root: the path where the checkpoints are stored.

    Returns:
        The path to the last checkpoint (if any).
    """
    if root is None:
        return None

    epoch = 0
    ite = 0
    last_chkpt = None
    regexp = re.compile(r"epoch_\d+_iter_\d+")
    for path in root.glob("*.pth"):
        if not regexp.search(path.stem):
            continue

        elems = str(path.stem).split("_")
        idx = elems.index("epoch") + 1
        e = int(elems[idx])
        idx = elems.index("iter") + 1
        i = int(elems[idx])

        if e > epoch or i > ite:
            epoch = e
            ite = i
            last_chkpt = path

    return last_chkpt
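# Example (illustrative): with the naming convention used by the trainer
# (``<save_name>_epoch_<epoch>_iter_<iter>.pth``), the helper picks the most recent file.
# The directory and file names below are made up for the example:
#
#   root = pathlib.Path("chkpt/my_net")
#   # e.g. contains: my_net_epoch_1_iter_100.pth, my_net_epoch_2_iter_200.pth
#   last = find_last_checkpoint(root)  # -> .../my_net_epoch_2_iter_200.pth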
def _spawn_handler(
    rank: int,
    world_size: int,
    trainer: Trainer,
    datamodule: data.DataModule,
    model: hlm.Model,
    mode: _TrainerMode,
    queue: mp.Queue,
    error_queue: mp.Queue,
) -> None:
    """
    Spawn handler for distributed training.

    Args:
        rank: the rank id of the current process within the work group.
        world_size: number of processes in the work group.
        trainer: the trainer to use.
        datamodule: the datamodule to use.
        model: the model to use.
        mode: determines which operation needs to be performed.
        queue: the multi-processing queue used to return results to the main process.
        error_queue: the queue used to forward error-handling state to the main process.
    """
    dist.init_dist(rank=rank, world_size=world_size)
    trainer.model = model
    trainer.datamodule = datamodule
    trainer.rank = rank
    trainer.local_rank = rank
    trainer.queue = queue
    trainer._distributed_error_queue = error_queue  # noqa: SLF001

    try:
        if mode == _TrainerMode.TRAIN:
            trainer._train()  # noqa: SLF001
        elif mode == _TrainerMode.TEST:
            trainer._test()  # noqa: SLF001
    except Exception as e:
        dist.shutdown_dist()
        raise e

    dist.safe_barrier()
    dist.shutdown_dist()
class Trainer:
    """
    Automates the training, validation, and testing code.

    The trainer handles all of the required steps to setup the correct environment
    (including handling distributed training), the training/validation/testing loops,
    and any clean up afterwards.

    Args:
        run_name: name of the current run. Defaults to empty.
        train_unit: the unit used for training. Defaults to
            :py:attr:`TrainingUnit.EPOCH`.
        total_steps: the total number of steps to train for. Defaults to 0.
        valid_frequency: (optional) frequency with which to perform validation.
        chkpt_frequency: (optional) frequency with which to save checkpoints.
        print_frequency: (optional) frequency with which to log.
        accumulation_steps: number of steps for gradient accumulation. Defaults to 1.
        enable_cudnn_benchmark: enable/disable CuDNN benchmark. Defaults to false.
        enable_deterministic: enable/disable PyTorch deterministic algorithms. Defaults
            to false.
        early_stop_cycles: (optional) number of cycles after which training will stop if
            no improvement is seen during validation.
        use_cpu: (optional) if true, CPU will be used.
        gpus: (optional) IDs of GPUs to use.
        random_seed: (optional) the seed to use for RNGs.
        enable_tensorboard: enable/disable Tensorboard logging. Defaults to false.
        enable_file_logging: enable/disable file logging. Defaults to false.
        enable_progress_bar: enable/disable the progress bar(s). Defaults to false.
        chkpt_root: (optional) root folder in which checkpoints will be placed.
        log_path: (optional) root folder in which logs will be saved.
        run_path: (optional) root folder in which Tensorboard runs will be saved.
        src_root: (optional) root folder where the code is located. This is used to
            automatically populate the registries using
            :py:func:`~helios.core.utils.update_all_registries`.
        import_prefix: prefix to use when importing modules. See
            :py:func:`~helios.core.utils.update_all_registries` for details.
        print_banner: if true, the Helios banner with system info will be printed.
            Defaults to true.
    """

    def __init__(
        self,
        run_name: str = "",
        train_unit: TrainingUnit | str = TrainingUnit.EPOCH,
        total_steps: int | float = 0,
        valid_frequency: int | None = None,
        chkpt_frequency: int | None = None,
        print_frequency: int | None = None,
        accumulation_steps: int = 1,
        enable_cudnn_benchmark: bool = False,
        enable_deterministic: bool = False,
        early_stop_cycles: int | None = None,
        use_cpu: bool | None = None,
        gpus: list[int] | None = None,
        random_seed: int | None = None,
        enable_tensorboard: bool = False,
        enable_file_logging: bool = False,
        enable_progress_bar: bool = False,
        chkpt_root: pathlib.Path | None = None,
        log_path: pathlib.Path | None = None,
        run_path: pathlib.Path | None = None,
        src_root: pathlib.Path | None = None,
        import_prefix: str = "",
        print_banner: bool = True,
    ):
        """Create the trainer."""
        self._model: hlm.Model | None = None
        self._datamodule: data.DataModule | None = None

        self._local_rank: int = 0
        self._rank: int = 0

        self._use_cpu: bool = False
        self._device: torch.device | None = None
        self._map_loc: str | dict[str, str] = ""
        self._gpu_ids: list[int] = [] if gpus is None else gpus
        self._active_gpu: int = 0
        self._is_distributed: bool = False
        self._is_torchrun: bool = dist.is_using_torchrun()

        if isinstance(train_unit, str):
            train_unit = TrainingUnit.from_str(train_unit)

        self._train_unit = train_unit
        self._total_steps = total_steps
        self._accumulation_steps = accumulation_steps
        self._valid_frequency = valid_frequency
        self._chkpt_frequency = chkpt_frequency
        self._print_frequency = print_frequency
        self._enable_cudnn_benchmark = enable_cudnn_benchmark
        self._enable_deterministic = enable_deterministic
        self._early_stop_cycles = early_stop_cycles
        self._enable_tensorboard = enable_tensorboard
        self._enable_file_logging = enable_file_logging
        self._random_seed = rng.get_default_seed() if random_seed is None else random_seed
        self._enable_progress_bar = enable_progress_bar
        self._chkpt_root = chkpt_root
        self._log_path = log_path
        self._run_path = run_path
        self._src_root = src_root
        self._import_prefix = import_prefix
        self._run_name = run_name
        self._print_banner = print_banner

        self._train_exceptions: list[type[Exception]] = []
        self._test_exceptions: list[type[Exception]] = []

        self._plugins: dict[str, hlp.Plugin] = {}

        self._queue: mp.Queue | None = None
        self._distributed_error_queue: mp.Queue | None = None

        self._validate_flags(use_cpu)
        self._setup_device_flags(use_cpu)

    @property
    def model(self) -> hlm.Model:
        """Return the model."""
        return core.get_from_optional(self._model)

    @model.setter
    def model(self, model: hlm.Model) -> None:
        self._model = model

    @property
    def datamodule(self) -> data.DataModule:
        """Return the datamodule."""
        return core.get_from_optional(self._datamodule)

    @datamodule.setter
    def datamodule(self, datamodule: data.DataModule) -> None:
        self._datamodule = datamodule

    @property
    def local_rank(self) -> int:
        """Return the local rank of the trainer."""
        return self._local_rank

    @local_rank.setter
    def local_rank(self, r) -> None:
        self._local_rank = r
        if not self._use_cpu:
            self._active_gpu = self._gpu_ids[r]

    @property
    def rank(self) -> int:
        """Return the global rank of the trainer."""
        return self._rank

    @rank.setter
    def rank(self, r) -> None:
        self._rank = r

    @property
    def gpu_ids(self) -> list[int]:
        """Return the list of GPU IDs to use for training."""
        return self._gpu_ids

    @property
    def train_exceptions(self) -> list[type[Exception]]:
        """Return the list of valid exceptions for training."""
        return self._train_exceptions

    @train_exceptions.setter
    def train_exceptions(self, exc: list[type[Exception]]) -> None:
        self._train_exceptions = exc

    @property
    def test_exceptions(self) -> list[type[Exception]]:
        """Return the list of valid exceptions for testing."""
        return self._test_exceptions

    @test_exceptions.setter
    def test_exceptions(self, exc: list[type[Exception]]) -> None:
        self._test_exceptions = exc

    @property
    def plugins(self) -> dict[str, hlp.Plugin]:
        """Return the list of plug-ins."""
        return self._plugins

    @plugins.setter
    def plugins(self, plugs: dict[str, hlp.Plugin]) -> None:
        self._plugins = plugs

    @property
    def queue(self) -> mp.Queue | None:
        """
        Return the multi-processing queue instance.

        .. note::
            If training isn't distributed or if ``torchrun`` is used, then ``None`` is
            returned instead.
        """
        return self._queue

    @queue.setter
    def queue(self, q: mp.Queue) -> None:
        self._queue = q
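    # Example (illustrative): the exception lists and plug-in table are plain properties
    # and can be configured before calling fit/test. Exception types registered here are
    # re-raised to the caller instead of being logged and swallowed by the trainer.
    #
    #   trainer = Trainer(run_name="demo")
    #   trainer.train_exceptions = [KeyboardInterrupt]
    #   trainer.plugins = {"my_plugin": MyPlugin()}  # MyPlugin: hypothetical hlp.Plugin subclass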
    def fit(self, model: hlm.Model, datamodule: data.DataModule) -> bool:
        """
        Run the full training routine.

        Args:
            model: the model to run on.
            datamodule: the datamodule to use.

        Returns:
            True if the training process completed successfully, false otherwise.
        """
        try:
            self._launch(model, datamodule, _TrainerMode.TRAIN)
        except Exception as e:
            if not self._handle_exception(e, _TrainerMode.TRAIN):
                raise e
            return False

        return True
    def test(self, model: hlm.Model, datamodule: data.DataModule) -> bool:
        """
        Run the full testing routine.

        Args:
            model: the model to run on.
            datamodule: the datamodule to use.

        Returns:
            True if the testing process completed successfully, false otherwise.
        """
        try:
            self._launch(model, datamodule, _TrainerMode.TEST)
        except Exception as e:
            if not self._handle_exception(e, _TrainerMode.TEST):
                raise e
            return False

        return True
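    # Example (illustrative): a minimal single-process run, assuming ``MyModel`` and
    # ``MyDataModule`` are hypothetical user-defined subclasses of hlm.Model and
    # data.DataModule.
    #
    #   trainer = Trainer(
    #       run_name="demo",
    #       train_unit="epoch",
    #       total_steps=10,
    #       chkpt_root=pathlib.Path("chkpt"),
    #   )
    #   model = MyModel()
    #   datamodule = MyDataModule()
    #   if trainer.fit(model, datamodule):
    #       trainer.test(model, datamodule)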
def _handle_exception(self, e: Exception, mode: _TrainerMode) -> bool: """ Exception handler. Args: e: the raised exception. Returns: False if the exception should be allowed to continue up the stack. If true, the exception has been handled and should not be re-raised. """ exc_list = ( self._train_exceptions if mode == _TrainerMode.TRAIN else self._test_exceptions ) if any(isinstance(e, exc) for exc in exc_list): logging.close_default_loggers() return False if logging.is_root_logger_active(): root_logger = logging.get_root_logger() root_logger.exception("error: uncaught exception") logging.close_default_loggers() return True def _configure_env_for_distributed_error_handling(self) -> None: assert self._distributed_error_queue is not None if self._distributed_error_queue.empty(): return state: _DistributedErrorState = self._distributed_error_queue.get_nowait() if self._enable_file_logging and state.log_path is not None: logging.create_default_loggers(enable_tensorboard=False) logging.restore_default_loggers(log_path=state.log_path) def _push_distributed_error_state(self, state: _DistributedErrorState) -> None: if self._distributed_error_queue is None: return self._distributed_error_queue.put(state) def _launch( self, model: hlm.Model, datamodule: data.DataModule, mode: _TrainerMode ) -> None: """ Launch the function corresponding to the given mode. If distributed training is used, this will spawn the processes and call the handler. Args: model: the model to use. datamodule: the datamodule to use. mode:: the operation to perform. """ datamodule.prepare_data() if self._is_distributed and not self._is_torchrun: if mp.get_start_method(allow_none=True) is None: mp.set_start_method("spawn") queue: mp.Queue = mp.Queue() error_queue: mp.Queue = mp.Queue() world_size = len(self._gpu_ids) try: mp.spawn( _spawn_handler, args=( world_size, self, datamodule, model, mode, queue, error_queue, ), nprocs=world_size, join=True, ) except Exception as e: self._distributed_error_queue = error_queue self._configure_env_for_distributed_error_handling() raise e self.queue = queue return if self._is_torchrun and self._is_distributed: dist.init_dist() self.model = model self.datamodule = datamodule self.rank = dist.get_global_rank() self.local_rank = dist.get_local_rank() if mode == _TrainerMode.TRAIN: self._train() elif mode == _TrainerMode.TEST: self._test() if self._is_torchrun and self._is_distributed: dist.shutdown_dist() logging.close_default_loggers() def _train(self) -> None: """ Train the model. This will ensure everything gets correctly initialised as well as select the appropriate training loop for the given training unit. """ self._configure_env() self._setup_plugins() self._setup_datamodule() self._setup_model() self._prepare_roots() chkpt_path = find_last_checkpoint(self._chkpt_root) training_state = self._load_checkpoint(chkpt_path) log_path = logging.get_root_logger().log_file self._push_distributed_error_state(_DistributedErrorState(log_path=log_path)) self._print_header(chkpt_path) self._execute_plugins("on_training_start") self.model.on_training_start() if self._train_unit == TrainingUnit.ITERATION: self._train_on_iteration(training_state) else: self._train_on_epoch(training_state) self._execute_plugins("on_training_end") self.model.on_training_end() logging.flush_default_loggers() logging.close_default_loggers() # If we're distributed, ensure that all processes are caught up before we exit. dist.safe_barrier() def _test(self) -> None: """ Test the model. 
This will ensure everything gets correctly initialised and run the testing loop on the dataset. It will automatically try to load the last saved checkpoint for testing provided there is one. If no checkpoints are available, it is assumed the model is loading the correct state internally. """ self._configure_env() self._setup_plugins() self._setup_datamodule() self._setup_model(fast_init=True) self._prepare_roots(mkdir=False) chkpt_path: pathlib.Path | None = None loaded: bool = False if self._chkpt_root is not None: chkpt_path = find_last_checkpoint(core.get_from_optional(self._chkpt_root)) loaded = ( self._load_checkpoint(chkpt_path, skip_rng=True, model_fast_init=True) != TrainingState() ) # We failed to load the last checkpoint, so tell the model to load its state. if self._chkpt_root is None or not loaded: self.model.load_for_testing() log_path = logging.get_root_logger().log_file self._push_distributed_error_state(_DistributedErrorState(log_path=log_path)) self._print_header(chkpt_path, for_training=False) if self.datamodule.test_dataloader() is None: return dataloader: tud.DataLoader sampler: ResumableSamplerType dataloader, sampler = core.get_from_optional(self.datamodule.test_dataloader()) self._execute_plugins("on_testing_start") self.model.on_testing_start() enable_progress_bar = self._enable_progress_bar pbar_disabled = ( self._is_distributed and self.rank != 0 ) or not enable_progress_bar pbar = tqdm.tqdm( total=len(dataloader), desc="Testing", unit="it", disable=pbar_disabled, leave=False, ) with core.cuda.DisableCuDNNBenchmarkContext(): self.model.eval() with torch.no_grad(): for idx, batch in enumerate(dataloader): batch = self._plugins_process_batch("testing", batch, step=idx) self.model.on_testing_batch_start(idx) self.model.test_step(batch, idx) self.model.on_testing_batch_end(idx) pbar.update() self._execute_plugins("on_testing_end") self.model.on_testing_end() dist.safe_barrier() def _configure_env(self) -> None: """ Configure the training environment. This will seed the RNGs as well as setup any CUDA state (if using). It will also set all of the registries provided the source root is not None. This is to prevent the registries from being empty if distributed training is launched through spawn (note that ``torchrun`` doesn't have this problem). """ register_trainer_types_for_safe_load() rng.seed_rngs(self._random_seed) torch.use_deterministic_algorithms(self._enable_deterministic) if not self._use_cpu: self._device = torch.device(f"cuda:{self._active_gpu}") self._map_loc = {"cuda:0": f"cuda:{self._active_gpu}"} torch.backends.cudnn.benchmark = self._enable_cudnn_benchmark torch.cuda.set_device(self._device) logging.create_default_loggers(self._enable_tensorboard) if self._src_root is not None: core.update_all_registries( self._src_root, recurse=True, import_prefix=self._import_prefix ) def _setup_datamodule(self) -> None: """Finish setting up the datamodule.""" self.datamodule.is_distributed = self._is_distributed self.datamodule.trainer = self self.datamodule.setup() def _setup_model(self, fast_init: bool = False) -> None: """ Finish setting up the model. Args: fast_init: whether the model should setup its full state or not. 
""" self.model.map_loc = self._map_loc self.model.is_distributed = self._is_distributed self.model.device = core.get_from_optional(self._device) self.model.rank = self.local_rank self.model.trainer = self self.model.setup(fast_init) def _setup_plugins(self) -> None: """Finish setting up the plug-ins.""" for plugin in self._plugins.values(): plugin.is_distributed = self._is_distributed plugin.map_loc = self._map_loc plugin.device = core.get_from_optional(self._device) plugin.rank = self.local_rank plugin.trainer = self plugin.setup() def _prepare_roots(self, mkdir=True) -> None: """Prepare the training roots.""" name = self.model.save_name self._chkpt_root = ( self._chkpt_root / name if self._chkpt_root is not None else None ) if mkdir: if self._chkpt_root is not None: self._chkpt_root.mkdir(parents=True, exist_ok=True) if self._log_path is not None: self._log_path.mkdir(parents=True, exist_ok=True) if self._run_path is not None: self._run_path.mkdir(parents=True, exist_ok=True) def _print_header( self, chkpt_path: pathlib.Path | None, for_training: bool = True ) -> None: """Print the Helios header with system info to the logs.""" root_logger = logging.get_root_logger() model = core.get_from_optional(self._model) banner = model.append_to_banner(core.get_env_info_str()) if self._print_banner: dist.global_print(banner) if for_training: if chkpt_path is not None: msg = f"Resuming training from checkpoint {str(chkpt_path)}" root_logger.info(msg) dist.global_print(f"{msg}\n") elif self._print_banner: root_logger.info(banner) else: root_logger.info(core.get_env_info_str()) msg = ( f"Testing using checkpoint {str(chkpt_path)}" if chkpt_path is not None else "Testing from loaded model" ) root_logger.info(msg) dist.global_print(f"{msg}\n") def _validate_flags(self, use_cpu: bool | None): """Ensure that all the settings and flags are valid.""" if isinstance(self._total_steps, float) and self._total_steps != float("inf"): raise ValueError( "error: expected 'total_steps' to be of type 'int' or 'infinity', but " f"received {self._total_steps}" ) if use_cpu is not None and use_cpu and len(self._gpu_ids) > 0: raise ValueError("error: cannot request CPU and GPU training") if self._chkpt_frequency is not None and self._chkpt_frequency == 0: raise ValueError("error: checkpoint frequency must be greater than 0 or None") if self._print_frequency is not None and self._print_frequency == 0: raise ValueError("error: print frequency must be greater than 0 or None") if self._valid_frequency is not None and self._valid_frequency == 0: raise ValueError("error: valid frequency must be greater than 0 or None") if self._enable_deterministic and self._enable_cudnn_benchmark: raise ValueError( "error: CUDA benchmark and deterministic flags are mutually exclusive" ) if self._total_steps == float("inf") and self._early_stop_cycles == 0: raise ValueError( f"error: given 'total_steps' with value {self._total_steps}, " "'early_stop_cycles' must be non-zero" ) if self._enable_tensorboard: if self._run_path is None: raise ValueError( "error: Tensorboard requested but no run directory was given" ) if self._run_path.exists() and not self._run_path.is_dir(): raise ValueError("error: run path must be a directory") if self._enable_file_logging: if self._log_path is None: raise ValueError( "error: file logging requested but no log directory was given" ) if self._log_path.exists() and not self._log_path.is_dir(): raise ValueError("error: log path must be a directory") if self._src_root is not None and not self._src_root.is_dir(): raise 
ValueError("error: source root must be a directory") def _setup_device_flags(self, use_cpu: bool | None): """ Configure the device state. If the CPU is being used, this will automatically set the correct settings. If the GPU will be used, then it will only verify that the GPU IDs are correct. The remaining state will be set afterwards. The ``use_cpu`` flag is used to determine whether the CPU will be used for training. If it is ``None``, then the value is determined by whether CUDA is available. Args: use_cpu: whether to use the CPU or not. """ if use_cpu is None: use_cpu = not torch.cuda.is_available() if use_cpu: self._use_cpu = True self._device = torch.device("cpu") self._map_loc = {"cuda:0": "cpu"} self._gpu_ids = [] self._is_distributed = False return if not torch.cuda.is_available(): raise RuntimeError( "error: CUDA usage is requested, but CUDA is not available" ) # At this point we know that CUDA exists and that we're supposed to use it. For # now, just verify that the GPU IDs are valid, but don't set the device or the map # location. Those need to be set after we launch distributed training (if using) # to ensure they get set to the correct thing. valid_ids = list(range(torch.cuda.device_count())) if len(self._gpu_ids) == 0: self._gpu_ids = valid_ids if len(self._gpu_ids) > len(valid_ids): raise ValueError( f"error: expected a maximum of {len(valid_ids)} GPU IDs but " f"received {len(self._gpu_ids)}" ) for gpu_id in self._gpu_ids: if gpu_id not in valid_ids: raise ValueError(f"error: {gpu_id} is not a valid GPU") self._is_distributed = ( len(self._gpu_ids) > 1 if not self._is_torchrun else int(os.environ["WORLD_SIZE"]) > 1 ) def _validate_state_dict(self, state_dict: dict[str, typing.Any]) -> bool: """ Ensure the state table is valid. This is to handle the eventuality that someone tries to load a checkpoint that wasn't generated by the trainer. It checks that all the required keys appear, and it also does a version check. Mainly: we guarantee support for checkpoints generated in versions less than or equal to the current version. Args: state_dict: the state dictionary. Returns: True if the checkpoint is valid, false otherwise. """ required_keys = ("version", "training_state", "model", "rng") if not all(key in state_dict for key in required_keys): return False # Now check the version to see if it's compatible with us. cur_ver = pv.Version(__version__) chkpt_ver = pv.Version(state_dict["version"]) return chkpt_ver <= cur_ver def _save_checkpoint(self, state: TrainingState) -> None: """ Save the current training state to a checkpoint. This will automatically save the training state, RNG state, as well as the model state. Args: state: the current training state. """ chkpt_root = core.get_from_optional(self._chkpt_root) epoch = state.global_epoch ite = state.current_iteration filename = f"{self.model.save_name}_epoch_{epoch}_iter_{ite}" filename = self.model.append_metadata_to_chkpt_name(filename) filename += ".pth" state_dict: dict[str, typing.Any] = {} state_dict["version"] = __version__ state_dict["training_state"] = state state_dict["model"] = self.model.state_dict() state_dict["rng"] = rng.get_rng_state_dict() if self._enable_file_logging: state_dict["log_path"] = logging.get_root_logger().log_file if self._enable_tensorboard: writer = core.get_from_optional(logging.get_tensorboard_writer()) state_dict["run_path"] = writer.run_path # Add the plug-ins (if using) for plug_id, plugin in self._plugins.items(): state_dict[plug_id] = plugin.state_dict() # Safety check. 
assert self._validate_state_dict(state_dict) torch.save(state_dict, chkpt_root / filename) def _load_checkpoint( self, chkpt_path: pathlib.Path | None, skip_rng: bool = False, model_fast_init: bool = False, ) -> TrainingState: """ Load the given checkpoint. Args: chkpt_path: path to the checkpoint to load. skip_rng: if True, skip the loading of the RNG states. model_fast_init: whether the model should setup its full state or not. Returns: Returns the loaded training state and ``True`` if the checkpoint was loaded successfully. Otherwise it returns an empty training state and ``False``. """ if chkpt_path is None: logging.setup_default_loggers(self._run_name, self._log_path, self._run_path) return TrainingState() state_dict = core.safe_torch_load(chkpt_path, map_location=self._map_loc) if not self._validate_state_dict(state_dict): raise RuntimeError( f"error: the checkpoint found at {str(chkpt_path)} is not a " "valid checkpoint generated by Helios" ) logging.restore_default_loggers( state_dict.get("log_path", None), state_dict.get("run_path", None) ) if not skip_rng: rng.load_rng_state_dict(state_dict["rng"]) self.model.load_state_dict(state_dict["model"], fast_init=model_fast_init) for plug_id, plugin in self._plugins.items(): if plug_id in state_dict: plugin.load_state_dict(state_dict[plug_id]) return state_dict["training_state"] def _train_on_iteration(self, state: TrainingState) -> None: """ Run the main loop for iteration-based training. Args: state: the training state. """ total_steps = self._total_steps save_freq = self._chkpt_frequency val_freq = self._valid_frequency print_freq = self._print_frequency accumulation_steps = self._accumulation_steps enable_progress_bar = self._enable_progress_bar early_stop_cycles = self._early_stop_cycles current_iteration_changed: bool = True training_done: bool = False root_logger = logging.get_root_logger() iter_timer = core.AverageTimer() pbar_disabled = ( self._is_distributed and self.rank != 0 ) or not enable_progress_bar pbar = tqdm.tqdm( total=total_steps if total_steps != float("inf") else None, desc="Training iterations", unit="it", disable=pbar_disabled, initial=state.current_iteration, ) dataloader: tud.DataLoader sampler: ResumableSamplerType dataloader, sampler = core.get_from_optional(self.datamodule.train_dataloader()) sampler.start_iter = state.dataset_iter self.model.train() for epoch in itertools.count(start=state.global_epoch): if training_done: break state.global_epoch += 1 root_logger.info(f"Starting epoch {epoch + 1}") sampler.set_epoch(epoch) epoch_start = time.time() self.model.on_training_epoch_start(state.global_epoch) iter_timer.start() for batch in dataloader: state.global_iteration += 1 if state.global_iteration % accumulation_steps == 0: state.current_iteration += 1 state.running_iter += 1 current_iteration_changed = True else: current_iteration_changed = False batch = self._plugins_process_batch("training", batch, state=state) self.model.on_training_batch_start(state) self.model.train_step(batch, state) iter_timer.record() state.average_iter_time = iter_timer.get_average_time() self.model.on_training_batch_end( state, should_log=( False if print_freq is None else state.current_iteration % print_freq == 0 and current_iteration_changed ), ) # Depending on how fast the iteration loop is, it is possible that the # progress bar isn't refreshed every tick, so make sure it gets re-drawn. 
if state.global_iteration % accumulation_steps == 0 and not pbar.update(): pbar.refresh() state.dataset_iter += 1 if ( self._plugins_should_training_stop() or self.model.should_training_stop() ): training_done = True break if ( val_freq is not None and state.current_iteration % val_freq == 0 and current_iteration_changed ): self._validate(state.validation_cycles) state.validation_cycles += 1 state.running_iter = 0 if not self.model.have_metrics_improved(): state.early_stop_count += 1 else: state.early_stop_count = 0 if ( save_freq is not None and state.current_iteration % save_freq == 0 and self.rank == 0 and current_iteration_changed ): self._save_checkpoint(state) if ( early_stop_cycles is not None and state.early_stop_count >= early_stop_cycles ): training_done = True break if ( self._plugins_should_training_stop() or self.model.should_training_stop() ): training_done = True break if state.current_iteration >= total_steps: training_done = True break state.dataset_iter = 0 self.model.on_training_epoch_end(state.global_epoch) root_logger.info( f"Epoch {epoch + 1} completed in {time.time() - epoch_start:.2f}s" ) def _train_on_epoch(self, state: TrainingState) -> None: """ Run the main loop for epoch-based training. Args: state: the training state. """ total_steps = self._total_steps save_freq = self._chkpt_frequency val_freq = self._valid_frequency print_freq = self._print_frequency enable_progress_bar = self._enable_progress_bar early_stop_cycles = self._early_stop_cycles training_done: bool = False root_logger = logging.get_root_logger() iter_timer = core.AverageTimer() pbar_disabled = ( self._is_distributed and self.rank != 0 ) or not enable_progress_bar pbar = tqdm.tqdm( total=total_steps if total_steps != float("inf") else None, desc="Training epochs", unit="epoch", disable=pbar_disabled, ) dataloader: tud.DataLoader sampler: ResumableSamplerType dataloader, sampler = core.get_from_optional(self.datamodule.train_dataloader()) sampler.start_iter = state.dataset_iter self.model.train() iterator = ( range(state.global_epoch, int(total_steps)) if total_steps != float("inf") else itertools.count(start=state.global_epoch) ) for epoch in iterator: if training_done: break state.global_epoch += 1 if state.global_epoch > total_steps: training_done = True break root_logger.info(f"Starting epoch {epoch + 1}") sampler.set_epoch(epoch) epoch_start = time.time() self.model.on_training_epoch_start(state.global_epoch) iter_timer.start() with tqdm.tqdm( total=len(dataloader), desc=f"Epoch {epoch + 1}", unit="iter", disable=pbar_disabled, leave=False, ) as ite_pbar: for batch in dataloader: state.global_iteration += 1 state.current_iteration += 1 state.running_iter += 1 batch = self._plugins_process_batch("training", batch, state=state) self.model.on_training_batch_start(state) self.model.train_step(batch, state) iter_timer.record() state.average_iter_time = iter_timer.get_average_time() self.model.on_training_batch_end( state, should_log=( False if print_freq is None else state.current_iteration % print_freq == 0 ), ) state.dataset_iter += 1 if not ite_pbar.update(): ite_pbar.refresh() if ( self._plugins_should_training_stop() or self.model.should_training_stop() ): training_done = True break state.dataset_iter = 0 if val_freq is not None and state.global_epoch % val_freq == 0: self._validate(state.validation_cycles) state.running_iter = 0 if not self.model.have_metrics_improved(): state.early_stop_count += 1 else: state.early_stop_count = 0 state.validation_cycles += 1 if ( save_freq is not None and 
state.global_epoch % save_freq == 0 and self.rank == 0 ): self._save_checkpoint(state) self.model.on_training_epoch_end(state.global_epoch) root_logger.info( f"Epoch {epoch + 1} completed in {time.time() - epoch_start:.2f}s" ) pbar.update() if ( early_stop_cycles is not None and state.early_stop_count >= early_stop_cycles ): training_done = True if self._plugins_should_training_stop() or self.model.should_training_stop(): training_done = True def _validate(self, val_cycle: int) -> None: """ Run the validation loop. Args: val_cycle: the current validation cycle number. """ if self.datamodule.valid_dataloader() is None: return dataloader: tud.DataLoader sampler: ResumableSamplerType dataloader, sampler = core.get_from_optional(self.datamodule.valid_dataloader()) enable_progress_bar = self._enable_progress_bar pbar_disabled = ( self._is_distributed and self.rank != 0 ) or not enable_progress_bar pbar = tqdm.tqdm( total=len(dataloader), desc="Validation", unit="it", disable=pbar_disabled, leave=False, ) with core.cuda.DisableCuDNNBenchmarkContext(): self.model.eval() self._execute_plugins("on_validation_start", validation_cycle=val_cycle) self.model.on_validation_start(val_cycle) with torch.no_grad(): for idx, batch in enumerate(dataloader): batch = self._plugins_process_batch("validation", batch, step=idx) self.model.on_validation_batch_start(idx) self.model.valid_step(batch, idx) self.model.on_validation_batch_end(idx) # Ensure the progress bar is updated in the event that the validation # loop runs faster than the refresh rate of the progress bar. if not pbar.update(): pbar.refresh() self.model.train() self._execute_plugins("on_validation_end", validation_cycle=val_cycle) self.model.on_validation_end(val_cycle) dist.safe_barrier() def _validate_plugins(self) -> None: seen_overrides: dict[str, str] = {} fields = dc.fields(hlp.UniquePluginOverrides) for _, plugin in self._plugins.items(): for field in fields: name = field.name if getattr(plugin.unique_overrides, name): if name not in seen_overrides: seen_overrides[name] = str(type(plugin)) else: raise ValueError( f"error: override field {name} has already been overridden " f"by {seen_overrides[name]}" ) def _execute_plugins(self, func_name: str, **kwargs: typing.Any) -> None: for plugin in self._plugins.values(): func = getattr(plugin, func_name) func(**kwargs) def _plugins_process_batch( self, mode: str, batch: typing.Any, **kwargs: typing.Any ) -> typing.Any: func_name = f"process_{mode}_batch" override_name = f"{mode}_batch" for plugin in self._plugins.values(): if getattr(plugin.unique_overrides, override_name): return getattr(plugin, func_name)(batch=batch, **kwargs) return batch def _plugins_should_training_stop(self) -> bool: return any(plugin.should_training_stop() for plugin in self._plugins.values())
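# Example (illustrative): the same script can drive both distributed launch styles,
# assuming a hypothetical ``train.py`` entry script. With several GPU IDs the trainer
# spawns the worker processes itself through mp.spawn; under torchrun the world size is
# read from the environment instead.
#
#   # spawn-based: Trainer(gpus=[0, 1]).fit(model, datamodule)
#   # torchrun:    torchrun --nproc_per_node=2 train.py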