import typing

import optuna
import torch

import helios.core.distributed as dist
import helios.model as hlm
import helios.plugins as hlp
import helios.trainer as hlt

_PRUNED_KEY = "ddp_hl:pruned"
_CYCLE_KEY = "ddp_hl:cycle"

# Ignore private member access
# ruff: noqa: SLF001

@hlp.PLUGIN_REGISTRY.register
class OptunaPlugin(hlp.Plugin):
    """
    Plug-in to do hyper-parameter tuning with Optuna.

    This plug-in integrates `Optuna <https://optuna.readthedocs.io/en/stable/>`__ into
    the training system in order to provide hyper-parameter tuning. The plug-in
    provides the following functionality:

    #. Automatic handling of trial pruning.
    #. Automatic reporting of metrics.
    #. Exception registration for trial pruning.
    #. Easy integration with Helios' checkpoint system to continue stopped trials.

    Example:
        .. code-block:: python

            import helios.plugins as hlp
            import optuna

            def objective(trial: optuna.Trial) -> float:
                datamodule = ...
                model = ...

                plugin = hlp.optuna.OptunaPlugin(trial, "accuracy")
                trainer = ...

                # Automatically registers the plug-in with the trainer.
                plugin.configure_trainer(trainer)

                # This can be skipped if you don't want the auto-resume functionality
                # or if you wish to manage it yourself.
                plugin.configure_model(model)

                trainer.fit(model, datamodule)
                plugin.check_pruned()
                return model.metrics["accuracy"]

            def main():
                # Note that the plug-in requires the storage to be persistent.
                study = optuna.create_study(storage="sqlite:///example.db", ...)
                study.optimize(objective, ...)

    Args:
        trial: the Optuna trial.
        metric_name: the name of the metric to monitor. This assumes the name will be
            present in the :py:attr:`~helios.model.model.Model.metrics` table.
    """

    plugin_id = "optuna"

    def __init__(self, trial: optuna.Trial, metric_name: str) -> None:
        """Create the plug-in."""
        super().__init__(self.plugin_id)
        self._trial = trial
        self._metric_name = metric_name
        self._last_cycle: int = 0

        self.unique_overrides.should_training_stop = True

    @property
    def trial(self) -> optuna.Trial:
        """Return the trial."""
        return self._trial

    @trial.setter
    def trial(self, t: optuna.Trial) -> None:
        self._trial = t

    def configure_trainer(self, trainer: hlt.Trainer) -> None:
        """
        Configure the trainer with the required settings.

        This will do two things:

        #. Register the plug-in itself with the trainer.
        #. Append the trial pruned exception to the trainer.

        Args:
            trainer: the trainer instance.
        """
        self._register_in_trainer(trainer)
        self._append_train_exceptions(optuna.TrialPruned, trainer)

    def configure_model(self, model: hlm.Model) -> None:
        """
        Configure the model to set the trial number into the save name.

        This will alter the :py:attr:`~helios.model.model.Model.save_name` property of
        the model by appending :code:`_trial-<trial-number>`.

        Args:
            model: the model instance.
        """
        n_trial = self.trial.number
        model._save_name = model._save_name + f"_trial-{n_trial}"

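    # As an illustration (names are hypothetical): a model whose save name is
    # "resnet50" paired with trial number 3 ends up with the save name
    # "resnet50_trial-3", so checkpoints from different trials do not collide.
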
    def suggest(self, type_name: str, name: str, **kwargs: typing.Any) -> typing.Any:
        """
        Generically wrap the ``suggest_`` family of functions of the optuna trial.

        This function can be used to easily invoke the corresponding ``suggest_``
        function from the Optuna trial held by the plug-in without having to manually
        type each individual function. This lets you write generic code that can be
        controlled by an external source (such as command line arguments or a config
        table).

        The function wraps the following functions:

        .. list-table:: Suggestion Functions
            :header-rows: 1

            * - Function
              - Name
            * - ``optuna.Trial.suggest_categorical``
              - categorical
            * - ``optuna.Trial.suggest_int``
              - int
            * - ``optuna.Trial.suggest_float``
              - float

        .. warning::
            Functions that are marked as deprecated by Optuna are *not* included in
            this wrapper.

        .. note::
            You can find the exact arguments for each function `here
            <https://optuna.readthedocs.io/en/stable/reference/generated/optuna.trial.Trial.html>`__.

        Example:
            .. code-block:: python

                import helios.plugins as hlp
                import optuna

                def objective(trial: optuna.Trial) -> float:
                    plugin = hlp.optuna.OptunaPlugin(trial, "accuracy")

                    # ... configure model and trainer.

                    val1 = plugin.suggest("categorical", "val1", choices=[1, 2, 3])
                    val2 = plugin.suggest("int", "val2", low=0, high=10)
                    val3 = plugin.suggest("float", "val3", low=0, high=1)

        Args:
            type_name: the name of the type to suggest from.
            name: a parameter name.
            **kwargs: keyword arguments to the corresponding suggest function.

        Raises:
            KeyError: if the value passed in to ``type_name`` is not recognised.
        """
        if type_name not in ("categorical", "float", "int"):
            raise KeyError(f"error: {type_name} is not a valid suggestion type.")

        fn = getattr(self._trial, f"suggest_{type_name}")
        return fn(name, **kwargs)

    def setup(self) -> None:
        """
        Configure the plug-in.

        Raises:
            ValueError: if the study wasn't created with persistent storage.
        """
        if self.is_distributed and not (
            isinstance(self.trial.study._storage, optuna.storages._CachedStorage)
            and isinstance(
                self.trial.study._storage._backend, optuna.storages.RDBStorage
            )
        ):
            raise ValueError(
                "error: optuna integration supports only optuna.storages.RDBStorage "
                "in distributed mode"
            )

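    # For reference, a study that satisfies the check above can be created with any
    # RDB-backed storage URL (the path below is illustrative):
    #
    #   study = optuna.create_study(
    #       study_name="my_study",
    #       storage="sqlite:///example.db",
    #       load_if_exists=True,
    #   )
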
    def report_metrics(self, validation_cycle: int) -> None:
        """
        Report metrics to the trial.

        This function should be called from the model once the corresponding metrics
        have been saved into the :py:attr:`~helios.model.model.Model.metrics` table.

        Example:
            .. code-block:: python

                import helios.model as hlm
                import helios.plugins.optuna as hlpo

                class MyModel(hlm.Model):
                    ...

                    def on_validation_end(self, validation_cycle: int) -> None:
                        # Compute metrics
                        self.metrics["accuracy"] = 10

                        plugin = self.trainer.plugins[hlpo.OptunaPlugin.plugin_id]
                        assert isinstance(plugin, hlpo.OptunaPlugin)
                        plugin.report_metrics(validation_cycle)

        .. note::
            In distributed training, only rank 0 will report the metrics to the trial.

        Args:
            validation_cycle: the current validation cycle.
        """
        model = self.trainer.model
        if not model.metrics or self._metric_name not in model.metrics:
            return

        if self.rank == 0:
            self.trial.report(model.metrics[self._metric_name], validation_cycle)
        self._last_cycle = validation_cycle

    def should_training_stop(self) -> bool:
        """
        Handle trial pruning.

        Returns:
            True if the trial should be pruned, false otherwise.
        """
        should_stop = False
        if self.rank == 0:
            should_stop = self.trial.should_prune()

        # Sync the value across all processes (if using distributed training).
        if self.is_distributed:
            t = dist.all_reduce_tensors(torch.tensor(should_stop).to(self.device))
            should_stop = t.item()  # type: ignore[assignment]

        if should_stop and self.rank == 0:
            self.trial.set_user_attr(_PRUNED_KEY, True)
            self.trial.set_user_attr(_CYCLE_KEY, self._last_cycle)

        return should_stop

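    # Design note: only rank 0 asks Optuna whether to prune; the all-reduce above
    # shares that decision with every rank so they all stop together (this assumes a
    # sum/max style reduction, which holds for a single truthy contribution).
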
    def on_training_end(self) -> None:
        """
        Clean-up on training end.

        If training is non-distributed and the trial was pruned, then this function
        will do the following:

        #. Call :py:meth:`~helios.model.model.Model.on_training_end` to ensure metrics
           are correctly logged (if using).
        #. Raise :py:exc:`optuna.TrialPruned` exception to signal the trial was pruned.

        If training is distributed, this function does nothing.

        Raises:
            TrialPruned: if the trial was pruned.
        """
        if not self.is_distributed and self.trial.should_prune():
            self.trainer.model.on_training_end()
            raise optuna.TrialPruned(f"Pruned on validation cycle {self._last_cycle}")

    def check_pruned(self) -> None:
        """
        Ensure pruned distributed trials are correctly handled.

        Due to the way distributed training works, we can't raise an exception within
        the distributed processes, so we have to do it after we return to the main
        process. If the trial was pruned, this function will raise
        :py:exc:`optuna.TrialPruned`. If distributed training wasn't used, this
        function does nothing.

        .. warning::
            You *must* ensure this function is called after
            :py:meth:`~helios.trainer.Trainer.fit` to ensure pruning works correctly.

        Raises:
            TrialPruned: if the trial was pruned.
        """
        trial_id = self.trial._trial_id
        study = self.trial.study
        trial = study._storage._backend.get_trial(trial_id)  # type: ignore[attr-defined]
        is_pruned = trial.user_attrs.get(_PRUNED_KEY)
        val_cycle = trial.user_attrs.get(_CYCLE_KEY)

        if is_pruned is None or val_cycle is None:
            return

        if is_pruned:
            raise optuna.TrialPruned(f"Pruned on validation cycle {val_cycle}")

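    # Typical call order in an objective (mirrors the class-level example):
    #
    #   trainer.fit(model, datamodule)
    #   plugin.check_pruned()  # re-raises optuna.TrialPruned in the main process
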
    def state_dict(self) -> dict[str, typing.Any]:
        """
        Get the state of the current trial.

        This will return the parameters to be optimised for the current trial.

        Returns:
            The parameters of the trial.
        """
        return self._trial.params
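
    # As a hypothetical illustration: an objective that called
    #   plugin.suggest("int", "val2", low=0, high=10)
    # would produce a state dictionary such as {"val2": 7}, i.e. Optuna's
    # ``Trial.params`` mapping of suggested values for this trial.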