helios.plugins.optuna

Classes

OptunaPlugin

Plug-in to do hyper-parameter tuning with Optuna.

Module Contents

class helios.plugins.optuna.OptunaPlugin(trial: optuna.Trial, metric_name: str)

Bases: helios.plugins.Plugin

Plug-in to do hyper-parameter tuning with Optuna.

This plug-in integrates Optuna into the training system in order to provide hyper-parameter tuning. The plug-in provides the following functionality:

  1. Automatic handling of trial pruning.

  2. Automatic reporting of metrics.

  3. Exception registration for trial pruning.

  4. Easy integration with Helios’ checkpoint system to continue stopped trials.

Warning

This plug-in requires Optuna to be installed before being used. If it isn’t, then ImportError is raised.
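
Optuna is distributed on PyPI and can be installed with pip install optuna.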

Example

import helios.plugins as hlp
import optuna

def objective(trial: optuna.Trial) -> float:
    datamodule = ...
    model = ...
    plugin = hlp.optuna.OptunaPlugin(trial, "accuracy")

    trainer = ...

    # Automatically registers the plug-in with the trainer.
    plugin.configure_trainer(trainer)

    # This can be skipped if you don't want the auto-resume functionality or
    # if you wish to manage it yourself.
    plugin.configure_model(model)

    trainer.fit(model, datamodule)
    plugin.check_pruned()
    return model.metrics["accuracy"]

def main():
    # Note that the plug-in requires the storage to be persistent.
    study = optuna.create_study(storage="sqlite:///example.db", ...)
    study.optimize(objective, ...)
Parameters:
  • trial – the Optuna trial.

  • metric_name – the name of the metric to monitor. The name is assumed to be present in the model's metrics table (e.g. model.metrics["accuracy"] in the example above).

plugin_id = 'optuna'
property trial: optuna.Trial

Return the trial.

configure_trainer(trainer: helios.trainer.Trainer) → None

Configure the trainer with the required settings.

This will do two things:

  1. Register the plug-in itself with the trainer.

  2. Append the trial pruned exception to the trainer.

Parameters:

trainer – the trainer instance.

configure_model(model: helios.model.Model) → None

Configure the model to allow trials to resume.

This will alter the save_name property of the model by appending _trial-<trial-number>. In the event that a trial with that number has already been attempted, the suffix will be set to that trial's number instead. This allows the trainer's automatic checkpoint system to resume the trial.

Parameters:

model – the model instance.
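
For illustration, a hypothetical before-and-after, re-using the model and plugin from the class example (the name "resnet" and trial number 7 are made up for this sketch):

# Hypothetical sketch: "resnet" and trial number 7 are illustrative.
model.save_name                # "resnet"
plugin.configure_model(model)
model.save_name                # "resnet_trial-7"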

suggest(type_name: str, name: str, **kwargs: Any) → Any

Generically wrap the suggest_ family of functions of the Optuna trial.

This function can be used to invoke the corresponding suggest_ function from the Optuna trial held by the plug-in without having to call each function by name. This lets you write generic code that can be controlled by an external source, such as command-line arguments or a config table; see the sketch after the example below. The function wraps the following functions:

Suggestion Functions

Function                            Name
optuna.Trial.suggest_categorical    categorical
optuna.Trial.suggest_int            int
optuna.Trial.suggest_float          float

Warning

Functions that are marked as deprecated by Optuna are not included in this wrapper.

Note

You can find the exact arguments for each function in the Optuna documentation for optuna.Trial.

Example

import helios.plugins as hlp
import optuna

def objective(trial: optuna.Trial) -> float:
    plugin = hlp.optuna.OptunaPlugin(trial, "accuracy")
    # ... configure model and trainer.

    val1 = plugin.suggest("categorical", "val1", choices=[1, 2, 3])
    val2 = plugin.suggest("int", "val2", low=0, high=10)
    val3 = plugin.suggest("float", "val3", low=0, high=1)
    # ... run training and return the monitored metric.
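
Because the suggestion type is selected by a string, the same call site can be driven entirely by external configuration. A minimal sketch, in which the search_space dict and its entries are hypothetical:

import helios.plugins as hlp
import optuna

# Hypothetical search space, e.g. loaded from a config file.
search_space = {
    "lr": ("float", {"low": 1e-5, "high": 1e-1, "log": True}),
    "batch_size": ("categorical", {"choices": [16, 32, 64]}),
    "num_layers": ("int", {"low": 1, "high": 4}),
}

def objective(trial: optuna.Trial) -> float:
    plugin = hlp.optuna.OptunaPlugin(trial, "accuracy")
    params = {
        name: plugin.suggest(type_name, name, **kwargs)
        for name, (type_name, kwargs) in search_space.items()
    }
    # ... build the model from params, train, and return the metric.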
Parameters:
  • type_name – the name of the type to suggest from.

  • name – the name of the parameter.

  • **kwargs – keyword arguments to the corresponding suggest function.

Raises:

KeyError – if the value passed to type_name is not recognised.

setup() → None

Configure the plug-in.

Raises:

ValueError – if the study wasn’t created with persistent storage.
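
For reference, a study created without a storage argument is held in memory only and would be rejected here; a minimal sketch:

import optuna

# In-memory storage: setup() raises ValueError for this study's trials.
study = optuna.create_study()

# Persistent storage (as in the class example above): accepted.
study = optuna.create_study(storage="sqlite:///example.db")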

on_validation_end(validation_cycle: int) → None

Report metrics to the trial.

Note

In distributed training, only rank 0 will report the metrics to the trial.

Parameters:

validation_cycle – the current validation cycle.

should_training_stop() → bool

Handle trial pruning.

Returns:

True if the trial should be pruned, False otherwise.
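
Together with on_validation_end(), this automates Optuna's usual report-and-prune pattern. Written by hand against plain Optuna inside an objective(trial) function, it looks roughly like this (validate() and num_validation_cycles are hypothetical):

import optuna

# Rough sketch of the plain-Optuna pattern the plug-in automates.
for cycle in range(num_validation_cycles):
    accuracy = validate()
    trial.report(accuracy, cycle)   # cf. on_validation_end()
    if trial.should_prune():        # cf. should_training_stop()
        raise optuna.TrialPruned()  # cf. on_training_end() / check_pruned()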

on_training_end() → None

Clean-up on training end.

If training is non-distributed and the trial was pruned, then this function will do the following:

  1. Call on_training_end() to ensure metrics are correctly logged (if logging is in use).

  2. Raise optuna.TrialPruned exception to signal the trial was pruned.

If training is distributed, this function does nothing.

Raises:

TrialPruned – if the trial was pruned.

check_pruned() → None

Ensure pruned distributed trials are correctly handled.

Due to the way distributed training works, we can’t raise an exception within the distributed processes, so we have to do it after we return to the main process. If the trial was pruned, this function will raise optuna.TrialPruned. If distributed training wasn’t used, this function does nothing.

Warning

You must call this function after fit() for pruning to work correctly.

Raises:

TrialPruned – if the trial was pruned.