helios.plugins.optuna

Classes

| Class | Description |
|---|---|
| `OptunaPlugin` | Plug-in to do hyper-parameter tuning with Optuna. |

Functions

| Function | Description |
|---|---|
| `resume_study` | Resume a study that stopped because of a failure. |
| `checkpoint_sampler` | Create a checkpoint with the state of the sampler. |
| `restore_sampler` | Restore the sampler from a previously saved checkpoint. |

Module Contents
- helios.plugins.optuna.resume_study(study_args: dict[str, Any], failed_states: Sequence = (optuna.trial.TrialState.FAIL,), backup_study: bool = True) → optuna.Study
Resume a study that stopped because of a failure.
The goal of this function is to allow studies that failed due to an error (an exception, a system error, etc.) to continue utilising the built-in checkpoint system from Helios. To accomplish this, the function will do the following:

- Grab all the trials from the study created by the given `study_args`, splitting them into three groups: completed/pruned, failed, and failed but completed.
- Create a new study with the same name and storage. This new study will get all of the completed trials of the original, and will have the failed trials re-enqueued.
Warning

This function requires the following conditions to be true:

1. It is called before the trials are started.
2. The study uses `RDBStorage` as the storage argument for `optuna.create_study`.
3. `load_if_exists` is set to True in `study_args`.
4. `TrialState.PRUNED` cannot be in the list of `failed_states`.
The `failed_states` argument can be used to set additional trial states to be considered as "failures". This can be useful when dealing with special cases where trials were either completed or pruned but need to be re-run.

By default, the original study (assuming there is one) will be backed up with the name `<study-name>_backup-#`, where `<study-name>` is the name of the database of the original study and `#` is an incremental number starting at 0. This behaviour can be disabled by setting `backup_study` to False.

This function works in tandem with `configure_model()` to ensure that when the failed trial is re-run, the original save name is restored, so any saved checkpoints can be re-used and the trial can continue instead of starting from scratch.

Note
Only trials that fail but haven’t been completed will be enqueued by this function. If a trial fails and is completed later on, it will be treated as if it had finished successfully.
- Parameters:
  - study_args – dictionary of arguments for `optuna.create_study`.
  - failed_states – the trial states that are considered to be failures and should be re-enqueued.
  - backup_study – if True, the original study is backed up so it can be re-used later on.
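A minimal sketch of resuming a study; the study name, database path, and objective used here are hypothetical:

```python
import optuna

import helios.plugins.optuna as hlpo


def objective(trial: optuna.Trial) -> float:
    ...  # the actual objective is assumed to be defined elsewhere
    return 0.0


# Arguments that would normally be passed to optuna.create_study. The
# storage must be persistent (RDBStorage) and load_if_exists must be True.
study_args = {
    "study_name": "example",            # hypothetical name
    "storage": "sqlite:///example.db",  # hypothetical database
    "load_if_exists": True,
}

# Re-creates the study, backs up the original database, and re-enqueues any
# failed trials. This must be called before the trials are started.
study = hlpo.resume_study(study_args)
study.optimize(objective, n_trials=10)
```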
- helios.plugins.optuna.checkpoint_sampler(trial: optuna.Trial, chkpt_root: pathlib.Path) → None
Create a checkpoint with the state of the sampler.
This function can be used to ensure that if a study is restarted, the state of the sampler is recovered so trials can be reproducible. The function will automatically create a checkpoint using `torch.save`.

Note

It is recommended that this function be called at the start of the objective function to ensure the checkpoint is made correctly, but it can be called at any time. See the sketch under `restore_sampler()` below for a usage example.
- Parameters:
  - trial – the current trial.
  - chkpt_root – the root where the checkpoints will be saved.
- helios.plugins.optuna.restore_sampler(chkpt_root: pathlib.Path) → optuna.samplers.BaseSampler | None
Restore the sampler from a previously saved checkpoint.
This function can be used in tandem with `checkpoint_sampler()` to ensure that the last checkpoint is loaded and the correct state is restored for the sampler. This function needs to be called before `optuna.create_study` is called.

- Parameters:
  - chkpt_root – the root where the checkpoints are stored.
- Returns:
  The restored sampler.
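A minimal sketch pairing `restore_sampler()` with `checkpoint_sampler()`; the checkpoint directory, database path, and objective are hypothetical:

```python
import pathlib

import optuna

import helios.plugins.optuna as hlpo

chkpt_root = pathlib.Path("sampler-chkpts")  # hypothetical directory


def objective(trial: optuna.Trial) -> float:
    # Snapshot the sampler state at the start of every trial so a
    # restarted study can reproduce the same suggestions.
    hlpo.checkpoint_sampler(trial, chkpt_root)
    ...  # the rest of the objective is assumed to be defined elsewhere
    return 0.0


# Restore the sampler before the study is created. On the very first run
# there is no checkpoint, so None is returned and Optuna falls back to its
# default sampler.
sampler = hlpo.restore_sampler(chkpt_root)
study = optuna.create_study(storage="sqlite:///example.db", sampler=sampler)
study.optimize(objective, n_trials=10)
```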
- class helios.plugins.optuna.OptunaPlugin(trial: optuna.Trial, metric_name: str)

Bases: `helios.plugins.Plugin`
Plug-in to do hyper-parameter tuning with Optuna.
This plug-in integrates Optuna into the training system in order to provide hyper-parameter tuning. The plug-in provides the following functionality:

- Automatic handling of trial pruning.
- Automatic reporting of metrics.
- Exception registration for trial pruning.
- Easy integration with Helios’ checkpoint system to continue stopped trials.
Warning

This plug-in requires Optuna to be installed before being used. If it isn’t, then `ImportError` is raised.

Example
```python
import helios.plugins as hlp
import optuna


def objective(trial: optuna.Trial) -> float:
    datamodule = ...
    model = ...

    plugin = hlp.optuna.OptunaPlugin(trial, "accuracy")

    trainer = ...

    # Automatically registers the plug-in with the trainer.
    plugin.configure_trainer(trainer)

    # This can be skipped if you don't want the auto-resume functionality or
    # if you wish to manage it yourself.
    plugin.configure_model(model)

    trainer.fit(model, datamodule)
    plugin.check_pruned()
    return model.metrics["accuracy"]


def main():
    # Note that the plug-in requires the storage to be persistent.
    study = optuna.create_study(storage="sqlite:///example.db", ...)
    study.optimize(objective, ...)
```
- Parameters:
  - trial – the Optuna trial.
  - metric_name – the name of the metric to monitor. This assumes the name will be present in the `metrics` table.
- plugin_id = 'optuna'

- property trial: optuna.Trial

Return the trial.
- configure_trainer(trainer: helios.trainer.Trainer) → None

Configure the trainer with the required settings.

This will do two things:

- Register the plug-in itself with the trainer.
- Append the trial pruned exception to the trainer.

- Parameters:
  - trainer – the trainer instance.
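As a quick illustration, registration makes the plug-in reachable through the trainer's plugin table, keyed by `plugin_id` (the same lookup used in the `report_metrics` example further below). A hypothetical sketch:

```python
import helios.plugins as hlp

trial = ...    # an optuna.Trial
trainer = ...  # a helios.trainer.Trainer

plugin = hlp.optuna.OptunaPlugin(trial, "accuracy")

# Registers the plug-in and appends the trial pruned exception to the
# trainer (as described above).
plugin.configure_trainer(trainer)

# The plug-in is now reachable through the trainer's plugin table.
assert hlp.optuna.OptunaPlugin.plugin_id in trainer.plugins
```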
- configure_model(model: helios.model.Model) → None

Configure the model to set the trial number into the save name.

This will alter the `save_name` property of the model by appending `_trial-<trial-number>`.

- Parameters:
  - model – the model instance.
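For instance, the rename could be observed as below (a hypothetical sketch assuming a model whose save name is "my_model" and a trial numbered 3):

```python
import helios.plugins as hlp

model = ...  # a helios.model.Model whose save_name is "my_model"
trial = ...  # an optuna.Trial whose number is 3

plugin = hlp.optuna.OptunaPlugin(trial, "accuracy")
plugin.configure_model(model)

# The save name now carries the trial number, so checkpoints from
# different trials don't collide:
print(model.save_name)  # my_model_trial-3
```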
- suggest(type_name: str, name: str, **kwargs: Any) → Any

Generically wrap the `suggest_` family of functions of the Optuna trial.

This function can be used to easily invoke the corresponding `suggest_` function from the Optuna trial held by the plug-in without having to manually type each individual function. This lets you write generic code that can be controlled by an external source (such as command line arguments or a config table). The function wraps the following functions:

| Function | Name |
|---|---|
| `optuna.Trial.suggest_categorical` | categorical |
| `optuna.Trial.suggest_int` | int |
| `optuna.Trial.suggest_float` | float |

Warning

Functions that are marked as deprecated by Optuna are not included in this wrapper.

Note

You can find the exact arguments for each function in the Optuna documentation.
Example
```python
import helios.plugins as hlp
import optuna


def objective(trial: optuna.Trial) -> float:
    plugin = hlp.optuna.OptunaPlugin(trial, "accuracy")

    # ... configure model and trainer.

    val1 = plugin.suggest("categorical", "val1", choices=[1, 2, 3])
    val2 = plugin.suggest("int", "val2", low=0, high=10)
    val3 = plugin.suggest("float", "val3", low=0, high=1)
```
- Parameters:
  - type_name – the name of the type to suggest from.
  - name – a parameter name.
  - **kwargs – keyword arguments to the corresponding suggest function.
- Raises:
  - KeyError – if the value passed in to `type_name` is not recognised.
- setup() → None

Configure the plug-in.

- Raises:
  - ValueError – if the study wasn't created with persistent storage.
- report_metrics(validation_cycle: int) → None

Report metrics to the trial.

This function should be called from the model once the corresponding metrics have been saved into the `metrics` table.

Example
```python
import helios.model as hlm
import helios.plugins.optuna as hlpo


class MyModel(hlm.Model):
    ...

    def on_validation_end(self, validation_cycle: int) -> None:
        # Compute metrics
        self.metrics["accuracy"] = 10

        plugin = self.trainer.plugins[hlpo.OptunaPlugin.plugin_id]
        assert isinstance(plugin, hlpo.OptunaPlugin)
        plugin.report_metrics(validation_cycle)
```
Note
In distributed training, only rank 0 will report the metrics to the trial.
- Parameters:
  - validation_cycle – the current validation cycle.
- should_training_stop() → bool

Handle trial pruning.

- Returns:
  True if the trial should be pruned, False otherwise.
- on_training_end() → None

Clean-up on training end.

If training is non-distributed and the trial was pruned, then this function will do the following:

- Call `on_training_end()` to ensure metrics are correctly logged (if logging is used).
- Raise the `optuna.TrialPruned` exception to signal that the trial was pruned.

If training is distributed, this function does nothing.

- Raises:
  - TrialPruned – if the trial was pruned.
- check_pruned() → None

Ensure pruned distributed trials are correctly handled.

Due to the way distributed training works, we can't raise an exception within the distributed processes, so we have to do it after we return to the main process. If the trial was pruned, this function will raise `optuna.TrialPruned`. If distributed training wasn't used, this function does nothing.

Warning

You must ensure this function is called after `fit()` to ensure pruning works correctly.

- Raises:
  - TrialPruned – if the trial was pruned.