Plug-Ins
############

Plug-in API
===========

Helios offers a plug-in system that allows users to override certain elements of the
training loops. All plug-ins *must* derive from the main
:py:class:`~helios.plugins.plugin.Plugin` interface. The available functions are similar
to those offered by the :py:class:`~helios.model.model.Model` class and follow the same
call order. For example, the training loop looks roughly like this:

.. code-block:: python

   plugin.on_training_start()
   model.on_training_start()
   model.train()
   for epoch in epochs:
       model.on_training_epoch_start()
       for batch in dataloader:
           # The plug-in returns the processed batch, which is what the model receives.
           batch = plugin.process_training_batch(batch)
           model.on_training_batch_start()
           model.train_step(batch)
           model.on_training_batch_end()
       model.on_training_epoch_end()
   model.on_training_end()
   plugin.on_training_end()

Notice that the plug-in functions are always called *before* the corresponding model
functions. This allows the plug-ins to override the model if necessary, or to set state
that can later be accessed by the model. The model (and the dataloader) can access the
plug-ins through the reference to the trainer:

.. code-block:: python

   def train_step(...):
       self.trainer.plugins["foo"]  # <- Access the plug-in with name "foo".

Batch Processing
----------------

The major difference between the functions of the model and those of the plug-ins is that
the plug-ins lack a :py:meth:`~helios.model.model.Model.train_step` function (and its
validation and testing counterparts). Instead, the plug-ins have three equivalent
functions:

* :py:meth:`~helios.plugins.plugin.Plugin.process_training_batch`
* :py:meth:`~helios.plugins.plugin.Plugin.process_validation_batch`
* :py:meth:`~helios.plugins.plugin.Plugin.process_testing_batch`

These functions receive the batch as an argument and return the processed batch. They can
be used for a variety of tasks, such as moving tensors to a given device, filtering batch
entries, or converting values.
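
The following toy sketch mirrors the loop above in plain Python so the ordering can be run
and inspected. ``ToyPlugin``, ``ToyModel`` and ``run_training`` are hypothetical stand-ins
written for this illustration, not Helios classes, and only the start hook and batch
processing are modelled: the plug-in's start hook runs before the model's, and
``train_step`` only ever receives the batch returned by ``process_training_batch`` (here,
with missing entries filtered out).

.. code-block:: python

   from typing import Optional


   class ToyPlugin:
       """Hypothetical plug-in that drops missing (None) entries from each batch."""

       def __init__(self, log: list[str]) -> None:
           self.log = log

       def on_training_start(self) -> None:
           self.log.append("plugin.on_training_start")

       def process_training_batch(self, batch: list[Optional[int]]) -> list[int]:
           self.log.append("plugin.process_training_batch")
           return [item for item in batch if item is not None]


   class ToyModel:
       """Hypothetical model that records every batch it is trained on."""

       def __init__(self, log: list[str]) -> None:
           self.log = log
           self.seen: list[list[int]] = []

       def on_training_start(self) -> None:
           self.log.append("model.on_training_start")

       def train_step(self, batch: list[int]) -> None:
           self.log.append("model.train_step")
           self.seen.append(batch)


   def run_training(plugin: ToyPlugin, model: ToyModel, epochs: int,
                    dataloader: list[list[Optional[int]]]) -> None:
       # The plug-in hook runs before the matching model hook.
       plugin.on_training_start()
       model.on_training_start()
       for _ in range(epochs):
           for batch in dataloader:
               # The model only ever sees the batch returned by the plug-in.
               processed = plugin.process_training_batch(batch)
               model.train_step(processed)


   log: list[str] = []
   model = ToyModel(log)
   run_training(ToyPlugin(log), model, epochs=1, dataloader=[[1, None, 3], [None, 5, 6]])
   print(model.seen)  # [[1, 3], [5, 6]]
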

For example, suppose we want to reduce the training batch size by removing elements. We
could do this as follows:

.. code-block:: python

   def process_training_batch(self, batch: list[torch.Tensor]) -> list[torch.Tensor]:
       return batch[:2]  # <- Take the first two elements of the batch.

When the model's :py:meth:`~helios.model.model.Model.train_step` function is called, it
will only receive the first two tensors of the original batch.

Plug-in Registration
--------------------

The trainer contains the :py:attr:`~helios.trainer.Trainer.plugins` table, in which all
plug-ins must be registered. To facilitate this, the plug-in base class requires a string
that acts as the key under which the plug-in is added to the table. In addition, it
provides a function that automatically registers the plug-in in the table. This function
can easily be invoked from :py:meth:`~helios.plugins.plugin.Plugin.configure_trainer` as
follows:

.. code-block:: python

   import helios.plugins as hlp
   import helios.trainer as hlt

   class MyPlugin(hlp.Plugin):
       def __init__(self):
           super().__init__("my_plugin")

       def configure_trainer(self, trainer: hlt.Trainer) -> None:
           self._register_in_trainer(trainer)  # <- Automatically registers the plug-in.

.. note::

   All plug-ins that ship with Helios contain a ``plugin_id`` class variable that can be
   used to easily access them from the trainer table. You are *encouraged* to always use
   it instead of typing in the key manually. For example, the
   :py:class:`~helios.plugins.plugin.CUDAPlugin` can be accessed like this:

   .. code-block:: python

      import helios.plugins as hlp
      import helios.trainer as hlt

      trainer = hlt.Trainer(...)
      plugin = hlp.CUDAPlugin()
      plugin.configure_trainer(trainer)
      trainer.plugins[hlp.CUDAPlugin.plugin_id]  # <- Access the plug-in like this.

Unique Traits
-------------

In order to avoid conflicts, the plug-in API designates certain functions as *unique*.
In this context, a plug-in with a *unique* override may appear only *once* in the
:py:attr:`~helios.trainer.Trainer.plugins` table of the trainer. If a second plug-in with
the same override is added, an exception is raised. The full list of overrides can be
found in the :py:class:`~helios.plugins.plugin.UniquePluginOverrides` struct. Each plug-in
holds a copy under :py:attr:`~helios.plugins.plugin.Plugin.unique_overrides`, which *must*
be filled in with the corresponding information for that plug-in. For example, suppose we
want to build a new plug-in that modifies the training batch and can cause training to
stop early. We would then set the structure as follows:

.. code-block:: python

   import helios.plugins as hlp

   class MyPlugin(hlp.Plugin):
       def __init__(self):
           super().__init__("my_plugin")
           self.unique_overrides.training_batch = True
           self.unique_overrides.should_training_stop = True

       def process_training_batch(...):
           ...

       def should_training_stop(...):
           ...

.. warning::

   Attempting to add two plug-ins with the same overrides **will** result in an exception
   being raised.

Built-in Plug-ins
=================

Helios ships with the following built-in plug-ins, which are discussed in the following
sections:

* :py:class:`~helios.plugins.plugin.CUDAPlugin`
* :py:class:`~helios.plugins.optuna.OptunaPlugin`

CUDA Plug-in
------------

The :py:class:`~helios.plugins.plugin.CUDAPlugin` is designed to move tensors from the
batches returned by the datasets to the current CUDA device. The device is determined by
the trainer when training starts, using the same logic that assigns the device to the
model. Specifically:

* If training isn't distributed, the device is the GPU that is used for training.
* If training is distributed, the device corresponds to the GPU assigned to the given
  process (i.e. the local rank).

.. warning::

   As its name implies, the :py:class:`~helios.plugins.plugin.CUDAPlugin` **requires**
   CUDA to be enabled to function. If it isn't, an exception is raised.

The plug-in is designed to handle the following types of batches:

* :py:class:`torch.Tensor`,
* Lists of :py:class:`torch.Tensor`,
* Tuples of :py:class:`torch.Tensor`, and
* Dictionaries whose values are :py:class:`torch.Tensor`.

.. note::

   The contents of the containers need not be homogeneous. In other words, it is perfectly
   valid for some entries in a dictionary to *not* be tensors. The plug-in will
   automatically recognise tensors and move them to the device.

.. warning::

   The plug-in is **not** designed to handle nested containers. For instance, if your
   batch is a dictionary containing lists of tensors, the plug-in will **not** recognise
   the tensors contained in the lists and will not move them.

In the event that your batch requires special handling, you can easily derive from the
class and override the function that moves the tensors to the device. For example, if our
batch consists of a dictionary of lists of tensors, we would do the following:

.. code-block:: python

   import helios.plugins as hlp
   import torch

   class MyCUDAPlugin(hlp.CUDAPlugin):
       # Only this function needs to be overridden. Everything else works automatically.
       def _move_collection_to_device(
           self, batch: dict[str, list[torch.Tensor]]
       ) -> dict[str, list[torch.Tensor]]:
           for key, value in batch.items():
               # Move every tensor in the list to the plug-in's device.
               batch[key] = [tensor.to(self.device) for tensor in value]
           return batch

.. note::

   The :py:class:`~helios.plugins.plugin.CUDAPlugin` is automatically registered in the
   plug-in registry and can therefore be created through the
   :py:func:`~helios.plugins.plugin.create_plugin` function.

Optuna Plug-in
--------------

In order to use the Optuna plug-in, we first need to install `optuna `__::

   pip install -U optuna

.. warning::

   Optuna is a **required** dependency for this plug-in. If it isn't installed, an
   exception is raised.
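
If you would rather detect a missing installation before the plug-in is ever constructed
(for example, to fail fast with a friendlier message in your own tuning script), a check
along the following lines works with only the standard library. This is a minimal sketch:
``require_module`` is a hypothetical helper, not part of Helios, and it makes no claim
about how Helios performs its own check.

.. code-block:: python

   import importlib.util


   def require_module(name: str) -> None:
       """Hypothetical helper: raise ImportError with an install hint if ``name`` is missing."""
       # For a top-level module name, find_spec returns None when it is not installed.
       if importlib.util.find_spec(name) is None:
           raise ImportError(
               f"The '{name}' package is required but is not installed. "
               f"Install it with: pip install -U {name}"
           )


   require_module("json")  # A standard-library module, so this passes silently.

Calling ``require_module("optuna")`` at the top of a tuning script would then stop
immediately, before any model or trainer is built, if Optuna is missing.
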

The plug-in automatically integrates with Optuna for hyper-parameter optimisation by
performing the following tasks:

* Registering the :py:class:`optuna.TrialPruned` exception type with the trainer so that
  trials are pruned correctly.
* Updating the :py:class:`~helios.model.model.Model` so that the save name is consistent,
  which allows trials to continue if they're interrupted.
* Handling reporting and pruning correctly for both regular and distributed training.

A full example of how to use this plug-in can be found `here `__, but we will discuss the
basics below. For the sake of simplicity, the code is identical to the `cifar10 `__
example, so we will only focus on the code necessary to use the plug-in.

Plug-in Registration
^^^^^^^^^^^^^^^^^^^^

After the creation of the :py:class:`~helios.model.model.Model`, :py:class:``, and the
:py:class:`~helios.trainer.Trainer`, we can create the plug-in and do the following:

.. code-block:: python

   import helios.plugins.optuna as hlpo
   import optuna

   def objective(trial: optuna.Trial) -> float:
       model = ...
       datamodule = ...
       trainer = ...

       plugin = hlpo.OptunaPlugin(trial, "accuracy")
       plugin.configure_trainer(trainer)
       plugin.configure_model(model)

The two ``configure_`` functions do the following:

#. Configure the trainer so that the plug-in is registered into the plug-in table and
   :py:class:`optuna.TrialPruned` is registered with the trainer.
#. Configure the name of the model to allow cancelled trials to continue. Specifically, it
   appends ``_trial-`` to the model name.

.. note::

   The call to :py:meth:`~helios.plugins.optuna.OptunaPlugin.configure_model` is
   completely optional and only impacts the ability to resume trials. You may choose to
   handle this yourself if it makes sense for your use-case.

Using the Trial
^^^^^^^^^^^^^^^

The trial instance is held by the plug-in and can easily be accessed through the trainer.
For example, we can use it to configure the layers of the classifier network within the
:py:meth:`~helios.model.model.Model.setup` function like this:

.. code-block:: python

   def setup(self, fast_init: bool = False) -> None:
       # The plug-ins table is keyed by name, so look the plug-in up by its ID.
       plugin = self.trainer.plugins[OptunaPlugin.plugin_id]
       assert isinstance(plugin, OptunaPlugin)

       # Assign the tunable parameters so we can log them as hyper-parameters when
       # training ends.
       self._tune_params["l1"] = plugin.trial.suggest_categorical(
           "l1", [2**i for i in range(9)]
       )
       self._tune_params["l2"] = plugin.trial.suggest_categorical(
           "l2", [2**i for i in range(9)]
       )
       self._tune_params["lr"] = plugin.trial.suggest_float("lr", 1e-4, 1e-1, log=True)

       self._net = Net(
           l1=self._tune_params["l1"],  # type: ignore[arg-type]
           l2=self._tune_params["l2"],  # type: ignore[arg-type]
       ).to(self.device)

Reporting Metrics
^^^^^^^^^^^^^^^^^

As the plug-in automatically handles the reporting of metrics to the trial, it needs to
know which metric should be reported. This is determined by two things:

#. The :py:attr:`~helios.model.model.Model.metrics` table, and
#. The value of ``metric_name`` in the constructor of
   :py:class:`~helios.plugins.optuna.OptunaPlugin`.

For the plug-in to work properly, the ``metric_name`` key must exist in the
:py:attr:`~helios.model.model.Model.metrics` table. If it doesn't, nothing is reported to
the trial. The plug-in handles distributed training automatically, so there's no need for
the model to do extra work.

.. warning::

   In distributed training, it is your responsibility to ensure that the value of the
   metric is correctly synced across processes (if applicable).

Trial Pruning and Returning Metrics
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^

The plug-in automatically detects if a trial is pruned by Optuna and gracefully requests
that training end. The exact behaviour depends on whether training is distributed or not.
Specifically:

* If training is not distributed, the plug-in raises an :py:class:`optuna.TrialPruned`
  exception *after* calling :py:meth:`~helios.model.model.Model.on_training_end` on the
  model. This ensures that any metrics logged when training ends are still logged if the
  trial is pruned.
* If training is distributed, the plug-in requests that training terminate early, and
  training then ends through the normal execution flow. Once the code exits the
  :py:meth:`` function, the user should call
  :py:meth:`~helios.plugins.optuna.OptunaPlugin.check_pruned` to ensure that the
  corresponding exception is correctly raised.

In code, this can be handled as follows:

.. code-block:: python

   def objective(trial: optuna.Trial) -> float:
       ...
       plugin.configure_trainer(trainer)
       plugin.configure_model(model)
       trainer.fit(model, datamodule)
       plugin.check_pruned()

To correctly return metrics, there are two cases that need to be handled. If training
isn't distributed, the metrics can be read directly from the
:py:attr:`~helios.model.model.Model.metrics` table. If training is distributed, the model
needs to do a bit more work to ensure the values are synchronised correctly. For our
example, we will place the synchronisation of the metrics in
:py:meth:`~helios.model.model.Model.on_training_end`, but you may place it elsewhere if it
is more convenient for you:

.. code-block:: python

   def on_training_end(self) -> None:
       ...
       # Push the metrics we want to save into the multi-processing queue.
       if self.is_distributed and self.rank == 0:
           assert self.trainer.queue is not None
           self.trainer.queue.put(
               {"accuracy": accuracy, "loss": self._loss_items["loss"].item()}
           )

The :py:attr:`~helios.trainer.Trainer.queue` ensures that the values are transferred to
the primary process. Once that's done, we just need to add the following to our
``objective`` function:

.. code-block:: python

   def objective(trial: optuna.Trial) -> float:
       ...
       plugin.configure_trainer(trainer)
       plugin.configure_model(model)
       trainer.fit(model, datamodule)
       plugin.check_pruned()

       # No queue means training was not distributed: read the metrics directly.
       if trainer.queue is None:
           return model.metrics["accuracy"]

       # Otherwise, fetch the metrics that rank 0 pushed in on_training_end.
       metrics = trainer.queue.get()
       return metrics["accuracy"]

Generic Suggestion of Parameters
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^

The plug-in comes equipped with a function that wraps the ``suggest_`` family of functions
from the :py:class:`optuna.Trial` instance it holds. This function is designed to allow
the suggestion of parameters to be controlled by an outside source (such as command-line
arguments or a config file). The goal is to improve code reusability by not having the
parameters hard-coded. The function is called
:py:meth:`~helios.plugins.optuna.OptunaPlugin.suggest` and can be used as follows:

.. code-block:: python

   def objective(trial: optuna.Trial) -> float:
       plugin = hlpo.OptunaPlugin(trial, "accuracy")
       val1 = plugin.suggest("categorical", "val1", choices=[1, 2, 3])
       val2 = plugin.suggest("float", "val2", low=1e-4, high=1, log=True)  # low must be > 0
       ...

The exact arguments for each ``suggest_`` function can be found `here `__.

.. warning::

   The plug-in does *not* provide wrappers for any function that is marked as deprecated.
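
The idea of driving parameter suggestion from an outside source can be sketched with a
small dispatcher that forwards a ``kind`` string to the trial method named
``suggest_<kind>``. This is a minimal illustration, not Helios's implementation:
``suggest``, ``StubTrial`` and ``search_space`` are hypothetical names, and ``StubTrial``
only stands in for :py:class:`optuna.Trial` (returning the first choice or the lower bound
instead of sampling) so the example runs without Optuna installed. Mirroring the warning
above, the sketch refuses Optuna's deprecated ``suggest_uniform``, ``suggest_loguniform``
and ``suggest_discrete_uniform``.

.. code-block:: python

   from typing import Any

   # Kinds whose Optuna suggest_<kind> methods are deprecated; the sketch refuses them.
   DEPRECATED_KINDS = {"uniform", "loguniform", "discrete_uniform"}


   class StubTrial:
       """Hypothetical stand-in for optuna.Trial with deterministic suggestions."""

       def suggest_categorical(self, name: str, choices: list[Any]) -> Any:
           return choices[0]

       def suggest_float(self, name: str, low: float, high: float, *, log: bool = False) -> float:
           return low

       def suggest_int(self, name: str, low: int, high: int, *, log: bool = False) -> int:
           return low


   def suggest(trial: Any, kind: str, name: str, **kwargs: Any) -> Any:
       """Dispatch to ``trial.suggest_<kind>(name, **kwargs)``."""
       if kind in DEPRECATED_KINDS:
           raise ValueError(f"deprecated suggestion type: {kind!r}")
       method = getattr(trial, f"suggest_{kind}", None)
       if method is None:
           raise ValueError(f"unsupported suggestion type: {kind!r}")
       return method(name, **kwargs)


   # Parameter specifications as they might be loaded from a config file.
   search_space = {
       "l1": {"kind": "categorical", "choices": [16, 32, 64]},
       "lr": {"kind": "float", "low": 1e-4, "high": 1e-1, "log": True},
       "layers": {"kind": "int", "low": 1, "high": 4},
   }

   trial = StubTrial()
   params = {
       name: suggest(trial, spec["kind"], name, **{k: v for k, v in spec.items() if k != "kind"})
       for name, spec in search_space.items()
   }
   print(params)  # {'l1': 16, 'lr': 0.0001, 'layers': 1}

Because the dispatcher relies only on the ``suggest_<kind>`` naming convention, passing a
real :py:class:`optuna.Trial` instead of the stub should work the same way, and
``search_space`` could just as well be parsed from a YAML file or command-line flags.
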