helios.plugins.plugin¶
Attributes¶
PLUGIN_REGISTRY – Global instance of the registry for plug-ins.
Classes¶
UniquePluginOverrides – Set of flags that determine the unique overrides a plug-in can have.
Plugin – Base class for plug-ins that extend the functionality of the Helios trainer.
CUDAPlugin – Plug-in to move elements of a training batch to a GPU.
Functions¶
create_plugin – Create the plug-in for the given type.
Module Contents¶
- class helios.plugins.plugin.UniquePluginOverrides[source]¶
Set of flags that determine the unique overrides a plug-in can have.
In order to avoid conflicts, two plug-ins should not be able to perform the same action. For example, it shouldn't be possible to have two distinct plug-ins perform processing on the training batch, as that would cause undefined behaviour. This structure therefore holds all the possible overrides a plug-in might have that must remain unique.
- Parameters:
training_batch – if true, the plug-in performs processing on the training batch.
validation_batch – if true, the plug-in performs processing on the validation batch.
testing_batch – if true, the plug-in performs processing on the testing batch.
should_training_stop – if true, the plug-in can arbitrarily stop training.
- training_batch: bool = False¶
- validation_batch: bool = False¶
- testing_batch: bool = False¶
- should_training_stop: bool = False¶
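As an illustration, here is a minimal sketch of constructing the flag set, assuming the flags can be passed as keyword arguments (any flag left unset keeps its documented default of False):
from helios.plugins.plugin import UniquePluginOverrides

# Assumption: the fields accept keyword arguments at construction time.
overrides = UniquePluginOverrides(training_batch=True, should_training_stop=True)
assert overrides.training_batch and not overrides.validation_batch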
- helios.plugins.plugin.PLUGIN_REGISTRY¶
Global instance of the registry for plug-ins.
By default, the registry contains the following plug-ins:
By default, the registry contains the following plug-ins: CUDAPlugin and OptunaPlugin.
Note
The OptunaPlugin is only registered if its module is imported somewhere in the code. Otherwise it won't be registered.
Example
import helios.plugins as hlp

# This automatically registers your plug-in.
@hlp.PLUGIN_REGISTRY
class MyPlugin(hlp.Plugin):
    ...

# Alternatively, you can manually register a plug-in like this:
hlp.PLUGIN_REGISTRY.register(MyPlugin)
- helios.plugins.plugin.create_plugin(type_name: str, *args: Any, **kwargs: Any) → Plugin [source]¶
Create the plug-in for the given type.
- Parameters:
type_name – the type of the plug-in to create.
args – positional arguments to pass into the plug-in.
kwargs – keyword arguments to pass into the plug-in.
- Returns:
The plug-in.
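A hedged usage sketch; the type name passed here is assumed to be the name under which the plug-in was registered (e.g. its class name):
from helios.plugins.plugin import create_plugin

# Assumption: the CUDA plug-in is registered under its class name. Extra
# positional and keyword arguments are forwarded to the plug-in constructor.
plugin = create_plugin("CUDAPlugin")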
- class helios.plugins.plugin.Plugin(plug_id: str)[source]¶
Bases: abc.ABC
Base class for plug-ins that extend the functionality of the Helios trainer.
You can use this class to customize the behaviour of training to achieve a variety of objectives. The plug-ins have a similar API to the Model class. The only major difference is that the plug-in functions are called before the corresponding model functions, providing the ability to override the model if necessary.
- Parameters:
plug_id – the string with which the plug-in will be registered in the trainer plug-in table.
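To make the API concrete, here is a rough sketch of a custom plug-in. The class name, the plug_id, and the strategy of overriding the unique_overrides property are illustrative assumptions, not part of the documented API:
import helios.plugins as hlp

class EchoPlugin(hlp.Plugin):
    """Hypothetical plug-in that inspects each training batch."""

    def __init__(self) -> None:
        # The plug_id becomes this plug-in's key in the trainer plug-in table.
        super().__init__("echo")

    @property
    def unique_overrides(self) -> hlp.UniquePluginOverrides:
        # Declare the single unique action this plug-in performs.
        return hlp.UniquePluginOverrides(training_batch=True)

    def process_training_batch(self, batch, state):
        print(f"received a batch of type {type(batch).__name__}")
        return batch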
- property unique_overrides: UniquePluginOverrides¶
The set of unique overrides the plug-in uses.
- property is_distributed: bool¶
Flag controlling whether distributed training is being used or not.
- property map_loc: str | dict[str, str]¶
The location to map loaded weights from a checkpoint or pre-trained file.
- property device: torch.device¶
The device on which the plug-in is running.
- property rank: int¶
The local rank (device id) that the plug-in is running on.
- property trainer: helios.trainer.Trainer¶
Reference to the trainer.
- configure_trainer(trainer: helios.trainer.Trainer) → None [source]¶
Configure the trainer before training or testing.
This function can be used to set certain properties of the trainer. For example, it can be used to assign valid exceptions that the plug-in requires or to register the plug-in itself in the trainer.
- Parameters:
trainer – the trainer instance.
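A hypothetical sketch only; the trainer attribute shown is an assumption for illustration, not the documented trainer API:
import helios.plugins as hlp

class RegisteringPlugin(hlp.Plugin):
    def configure_trainer(self, trainer):
        # Hypothetical: expose this plug-in through a plug-in table on the
        # trainer, mirroring what CUDAPlugin's documentation describes for
        # the name "cuda". The attribute name "plugins" is an assumption.
        trainer.plugins["registering"] = self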
- configure_model(model: helios.model.Model) → None [source]¶
Configure the model before training or testing.
This function can be used to set certain properties of the model. For example, it can be used to override the save name of the model.
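For instance, a hedged sketch of adjusting the model's save name; the save_name attribute is an assumption based on the description above:
import helios.plugins as hlp

class RenamingPlugin(hlp.Plugin):
    def configure_model(self, model):
        # Hypothetical attribute: give checkpoints from this run a distinct
        # save name.
        model.save_name = model.save_name + "_with_plugin"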
- process_training_batch(batch: Any, state: helios.trainer.TrainingState) → Any [source]¶
Process the training batch.
This function can be used to perform any processing on the training batch prior to the call to train_step(). For example, this can be used to filter out elements in a batch to reduce its size, or it can be used to move all elements in the batch to a set device (see the sketch after the parameter list).
- Parameters:
batch – the batch data returned from the dataset.
state – the current training state.
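A filtering override might look like this; purely illustrative, and it assumes the dataset yields a dictionary whose values may or may not be tensors:
import torch
import helios.plugins as hlp

class TensorOnlyPlugin(hlp.Plugin):
    def process_training_batch(self, batch, state):
        # Keep only tensor entries; non-tensor metadata is dropped to
        # reduce the batch size before train_step() runs.
        return {k: v for k, v in batch.items() if isinstance(v, torch.Tensor)}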
- on_validation_start(validation_cycle: int) → None [source]¶
Perform any necessary actions when validation starts.
- Parameters:
validation_cycle – the validation cycle number.
- process_validation_batch(batch: Any, step: int) → Any [source]¶
Process the validation batch.
This function can be used to perform any processing on the validation batch prior to the call to valid_step(). For example, this can be used to filter out elements in a batch to reduce its size, or it can be used to move all elements in the batch to a set device.
- Parameters:
batch – the batch data returned from the dataset.
step – the current validation batch number.
- on_validation_end(validation_cycle: int) → None [source]¶
Perform any necessary actions when validation ends.
- Parameters:
validation_cycle – the validation cycle number.
- should_training_stop() → bool [source]¶
Determine whether training should stop or continue.
- Returns:
False if training should continue, True otherwise.
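This hook pairs naturally with state gathered in the other callbacks. A minimal early-stopping sketch, assuming the plug-in maintains its own stall counter (a made-up attribute updated elsewhere, e.g. in on_validation_end()):
import helios.plugins as hlp

class PatiencePlugin(hlp.Plugin):
    """Hypothetical early-stopping plug-in."""

    def __init__(self, patience: int = 5) -> None:
        super().__init__("patience")
        self._patience = patience
        self._stalled_cycles = 0  # incremented when validation stops improving

    @property
    def unique_overrides(self) -> hlp.UniquePluginOverrides:
        return hlp.UniquePluginOverrides(should_training_stop=True)

    def should_training_stop(self) -> bool:
        # Request a stop once validation has stalled for too long.
        return self._stalled_cycles > self._patience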
- load_state_dict(state_dict: dict[str, Any]) → None [source]¶
Load the plug-in state from the given state dictionary.
Use this function to restore any state from a checkpoint.
- Parameters:
state_dict – the state dictionary to load from.
- state_dict() → dict[str, Any] [source]¶
Get the state dictionary of the plug-in.
Use this function to save any state that you require for checkpoints.
- Returns:
The state dictionary of the plug-in.
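Continuing the PatiencePlugin sketch from should_training_stop() above, this pair of methods could checkpoint the stall counter (the attribute is a made-up example):
from typing import Any

import helios.plugins as hlp

class CheckpointedPlugin(hlp.Plugin):
    def state_dict(self) -> dict[str, Any]:
        # Persist the (hypothetical) counter so it survives checkpointing.
        return {"stalled_cycles": self._stalled_cycles}

    def load_state_dict(self, state_dict: dict[str, Any]) -> None:
        # Restore the counter when resuming from a checkpoint.
        self._stalled_cycles = state_dict["stalled_cycles"]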
- process_testing_batch(batch: Any, step: int) → Any [source]¶
Process the testing batch.
This function can be used to perform any processing on the testing batch prior to the call to test_step(). For example, this can be used to filter out elements in a batch to reduce its size, or it can be used to move all elements in the batch to a set device.
- Parameters:
batch – the batch data returned from the dataset.
step – the current testing batch number.
- class helios.plugins.plugin.CUDAPlugin[source]¶
Bases: Plugin
Plug-in to move elements of a training batch to a GPU.
This plug-in can be used to move the elements of a training batch to the currently selected device automatically prior to the call to train_step(). The device is automatically assigned by the helios.trainer.Trainer when training or testing starts.
In order to cover the largest possible number of structures, the plug-in can handle the following containers:
- Single tensors.
- Lists. Note that the elements of the list need not all be tensors. If any tensors are present, they are automatically moved to the device.
- Dictionaries. Similar to lists, not all the elements of the dictionary have to be tensors. Any tensors are detected automatically.
Warning
The plug-in is not designed to work with nested structures. In other words, if a list of dictionaries is passed in, the plug-in will not recognise any tensors contained inside the dictionary. Similarly, if a dictionary contains nested dictionaries (or any other container), the plug-in won’t recognise them.
Warning
The use of this plug-in requires CUDA to be enabled. If CUDA is not available, an exception is raised.
Note
If you require custom handling for your specific data types, you can override the behaviour of the plug-in by deriving from it. See the example below for details.
- Example:
import helios.plugins as hlp

class MyCUDAPlugin(hlp.CUDAPlugin):
    def _move_collection_to_device(self, batch):  # annotate with your batch type
        # Suppose our batch is a list:
        for i in range(len(batch)):
            batch[i] = batch[i].to(self.device)
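Once defined, the subclass can presumably be registered and created like any other plug-in; the registry type name used here, and the re-export of create_plugin under helios.plugins, are assumptions:
hlp.PLUGIN_REGISTRY.register(MyCUDAPlugin)
my_plugin = hlp.create_plugin("MyCUDAPlugin")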
- plugin_id = 'cuda'¶
- configure_trainer(trainer: helios.trainer.Trainer) → None [source]¶
Register the plug-in instance into the trainer.
The plug-in will be registered under the name cuda.
- Parameters:
trainer – the trainer instance.
- process_training_batch(batch: Any, state: helios.trainer.TrainingState) → Any [source]¶
Move the training batch to the GPU.
- Parameters:
batch – the batch returned by the training dataset.
state – the current training state.