helios.plugins.optuna
=====================

.. py:module:: helios.plugins.optuna


Classes
-------

.. autoapisummary::

   helios.plugins.optuna.OptunaPlugin


Module Contents
---------------
.. py:class:: OptunaPlugin(trial: optuna.Trial, metric_name: str)

   Bases: :py:obj:`helios.plugins.Plugin`

   Plug-in to do hyper-parameter tuning with Optuna.

   This plug-in integrates `Optuna <https://optuna.org/>`__ into the
   training system in order to provide hyper-parameter tuning. The plug-in provides the
   following functionality:

   #. Automatic handling of trial pruning.
   #. Automatic reporting of metrics.
   #. Exception registration for trial pruning.
   #. Easy integration with Helios' checkpoint system to continue stopped trials.

   .. warning::
      This plug-in **requires** Optuna to be installed before being used. If it isn't,
      then :py:exc:`ImportError` is raised.

   .. rubric:: Example

   .. code-block:: python

      import helios.plugins as hlp
      import optuna

      def objective(trial: optuna.Trial) -> float:
          datamodule = ...
          model = ...
          plugin = hlp.optuna.OptunaPlugin(trial, "accuracy")

          trainer = ...

          # Automatically registers the plug-in with the trainer.
          plugin.configure_trainer(trainer)

          # This can be skipped if you don't want the auto-resume functionality or
          # if you wish to manage it yourself.
          plugin.configure_model(model)

          trainer.fit(model, datamodule)
          plugin.check_pruned()

          return model.metrics["accuracy"]

      def main():
          # Note that the plug-in requires the storage to be persistent.
          study = optuna.create_study(storage="sqlite:///example.db", ...)
          study.optimize(objective, ...)

   :param trial: the Optuna trial.
   :param metric_name: the name of the metric to monitor. This assumes the name will be
      present in the :py:attr:`~helios.model.model.Model.metrics` table.
   .. py:attribute:: plugin_id
      :value: 'optuna'

   .. py:property:: trial
      :type: optuna.Trial

      Return the trial.
   .. py:method:: configure_trainer(trainer: helios.trainer.Trainer) -> None

      Configure the trainer with the required settings.

      This will do two things:

      #. Register the plug-in itself with the trainer.
      #. Append the trial pruned exception to the trainer.

      :param trainer: the trainer instance.
   .. py:method:: configure_model(model: helios.model.Model) -> None

      Configure the model to allow trials to resume.

      This will alter the :py:attr:`~helios.model.model.Model.save_name` property of the
      model by appending :code:`_trial-` followed by the trial number. In the event that
      a trial with that number has already been attempted, it will be set to that number
      instead. This will allow the automatic checkpoint system of the trainer to resume
      the trial.

      :param model: the model instance.
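The save-name suffix described above can be illustrated with a small, self-contained sketch. This is plain Python with no Helios dependency; the helper name and the ``"resnet50"`` save name are assumptions made for illustration, not Helios' actual implementation:

```python
def with_trial_suffix(save_name: str, trial_number: int) -> str:
    """Append an Optuna-style trial suffix so checkpoints are unique per trial."""
    return f"{save_name}_trial-{trial_number}"

# A checkpoint saved under this name can later be matched by the trainer's
# auto-resume logic when the same trial number is attempted again.
print(with_trial_suffix("resnet50", 3))  # resnet50_trial-3
```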
   .. py:method:: suggest(type_name: str, name: str, **kwargs: Any) -> Any

      Generically wrap the ``suggest_`` family of functions of the Optuna trial.

      This function can be used to easily invoke the corresponding ``suggest_`` function
      from the Optuna trial held by the plug-in without having to manually type each
      individual function. This lets you write generic code that can be controlled by an
      external source (such as command line arguments or a config table). The function
      wraps the following functions:

      .. list-table:: Suggestion Functions
         :header-rows: 1

         * - Function
           - Name
         * - ``optuna.Trial.suggest_categorical``
           - categorical
         * - ``optuna.Trial.suggest_int``
           - int
         * - ``optuna.Trial.suggest_float``
           - float

      .. warning::
         Functions that are marked as deprecated by Optuna are *not* included in this
         wrapper.

      .. note::
         You can find the exact arguments for each function `here
         <https://optuna.readthedocs.io/en/stable/reference/generated/optuna.trial.Trial.html>`__.

      .. rubric:: Example

      .. code-block:: python

         import helios.plugins as hlp
         import optuna

         def objective(trial: optuna.Trial) -> float:
             plugin = hlp.optuna.OptunaPlugin(trial, "accuracy")

             # ... configure model and trainer.

             val1 = plugin.suggest("categorical", "val1", choices=[1, 2, 3])
             val2 = plugin.suggest("int", "val2", low=0, high=10)
             val3 = plugin.suggest("float", "val3", low=0, high=1)

      :param type_name: the name of the type to suggest from.
      :param name: the name of the parameter.
      :param \*\*kwargs: keyword arguments to the corresponding suggest function.
      :raises KeyError: if the value passed in to ``type_name`` is not recognised.
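The name-based dispatch that ``suggest`` performs can be sketched without Optuna installed. The ``_Trial`` stub below and the ``search_space`` table are placeholders invented for this sketch; they only illustrate the pattern of driving suggestions from a config table, not the plug-in's internals:

```python
import random

class _Trial:
    """Minimal stand-in for optuna.Trial, for illustration only."""

    def suggest_categorical(self, name, choices):
        return random.choice(choices)

    def suggest_int(self, name, low, high):
        return random.randint(low, high)

    def suggest_float(self, name, low, high):
        return random.uniform(low, high)

def suggest(trial, type_name, name, **kwargs):
    """Dispatch to the matching suggest_* method, raising KeyError if unknown."""
    table = {
        "categorical": trial.suggest_categorical,
        "int": trial.suggest_int,
        "float": trial.suggest_float,
    }
    if type_name not in table:
        raise KeyError(f"unknown suggestion type: {type_name}")
    return table[type_name](name, **kwargs)

# Hyper-parameters described by a config table instead of hard-coded calls.
search_space = {
    "batch_size": ("categorical", {"choices": [16, 32, 64]}),
    "num_layers": ("int", {"low": 1, "high": 8}),
    "lr": ("float", {"low": 1e-5, "high": 1e-1}),
}

trial = _Trial()
params = {k: suggest(trial, t, k, **kw) for k, (t, kw) in search_space.items()}
```

Because the search space is plain data, it can be loaded from a command-line argument or a config file without changing the objective function.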
   .. py:method:: setup() -> None

      Configure the plug-in.

      :raises ValueError: if the study wasn't created with persistent storage.
   .. py:method:: on_validation_end(validation_cycle: int) -> None

      Report metrics to the trial.

      .. note::
         In distributed training, only rank 0 will report the metrics to the trial.

      :param validation_cycle: the current validation cycle.
   .. py:method:: should_training_stop() -> bool

      Handle trial pruning.

      :returns: True if the trial should be pruned, False otherwise.
   .. py:method:: on_training_end() -> None

      Clean-up on training end.

      If training is non-distributed and the trial was pruned, then this function will
      do the following:

      #. Call :py:meth:`~helios.model.model.Model.on_training_end` to ensure metrics are
         correctly logged (if using).
      #. Raise the :py:exc:`optuna.TrialPruned` exception to signal the trial was pruned.

      If training is distributed, this function does nothing.

      :raises TrialPruned: if the trial was pruned.
   .. py:method:: check_pruned() -> None

      Ensure pruned distributed trials are correctly handled.

      Due to the way distributed training works, we can't raise an exception within the
      distributed processes, so we have to do it after we return to the main process.
      If the trial was pruned, this function will raise :py:exc:`optuna.TrialPruned`. If
      distributed training wasn't used, this function does nothing.

      .. warning::
         You *must* ensure this function is called after
         :py:meth:`~helios.trainer.Trainer.fit` to ensure pruning works correctly.

      :raises TrialPruned: if the trial was pruned.
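The deferred-raise pattern behind ``check_pruned`` can be sketched in plain Python. Everything here, including the ``TrialPruned`` stand-in and the ``plugin_state`` dictionary, is a simplified assumption used to show the idea of recording pruning in the worker and re-raising in the main process, not Helios' or Optuna's actual machinery:

```python
class TrialPruned(Exception):
    """Stand-in for optuna.TrialPruned, for illustration only."""

def run_distributed_fit(plugin_state: dict) -> None:
    # Inside spawned worker processes an exception cannot propagate back to
    # the launcher, so the worker only records that the trial was pruned.
    plugin_state["pruned"] = True

def check_pruned(plugin_state: dict) -> None:
    # Back in the main process, re-raise so the Optuna study sees the pruning.
    if plugin_state.get("pruned"):
        raise TrialPruned()

state: dict = {}
run_distributed_fit(state)
try:
    check_pruned(state)  # must run after fit, mirroring the warning above
    raised = False
except TrialPruned:
    raised = True
```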