helios.data.datamodule
Attributes
DATASET_REGISTRY – Global instance of the registry for datasets.
Classes
DatasetSplit – The different dataset splits.
DataLoaderParams – Params used to create the dataloader object.
Dataset – The dataset and corresponding data loader params.
DataModule – Base class that groups together the creation of the main training datasets.
Functions
create_dataset – Create a dataset of the given type.
create_dataloader – Create the dataloader for the given dataset.
Module Contents
- helios.data.datamodule.DATASET_REGISTRY
Global instance of the registry for datasets.
Example
```python
import helios.data as hld


# This automatically registers your dataset.
@hld.DATASET_REGISTRY.register
class MyDataset:
    ...


# Alternatively, register a dataset manually instead of using the decorator:
class MyOtherDataset:
    ...


hld.DATASET_REGISTRY.register(MyOtherDataset)
```
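To make the two registration styles above concrete, here is a minimal sketch of a name-keyed registry whose `register` method works both as a class decorator and as a plain call. The `Registry` class, its `get` method, and the duplicate check are hypothetical choices made for illustration; this is not the helios implementation.

```python
class Registry:
    """Hypothetical name -> type registry (illustration only, not the helios class)."""

    def __init__(self, name: str) -> None:
        self._name = name
        self._types: dict[str, type] = {}

    def register(self, cls: type) -> type:
        # Key the type by its class name and refuse duplicate registrations.
        key = cls.__name__
        if key in self._types:
            raise KeyError(f"{key} is already registered in {self._name}")
        self._types[key] = cls
        # Returning the class unchanged lets this method double as a decorator.
        return cls

    def get(self, key: str) -> type:
        if key not in self._types:
            raise KeyError(f"{key} is not registered in {self._name}")
        return self._types[key]


DATASETS = Registry("datasets")


@DATASETS.register  # decorator form
class MyDataset:
    pass


class OtherDataset:
    pass


DATASETS.register(OtherDataset)  # explicit call form
```

Returning the class from `register` is what allows the same method to serve both forms: the decorator simply rebinds the class name to the (unchanged) returned class.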
- helios.data.datamodule.create_dataset(type_name: str, *args: Any, **kwargs: Any)
Create a dataset of the given type.
This uses DATASET_REGISTRY to look up dataset types, so ensure your datasets have been registered before calling this function.
- Parameters:
type_name – the type of the dataset to create.
args – positional arguments to pass into the dataset.
kwargs – keyword arguments to pass into the dataset.
- Returns:
The constructed dataset.
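The look-up described above amounts to resolving `type_name` in the registry and forwarding the remaining arguments to the constructor. A self-contained sketch, assuming a plain dict stands in for `DATASET_REGISTRY`; `_REGISTRY`, `register`, `create_dataset_sketch`, and `RangeDataset` are hypothetical names used only here.

```python
from typing import Any

# Hypothetical stand-in for DATASET_REGISTRY: maps type names to classes.
_REGISTRY: dict[str, type] = {}


def register(cls: type) -> type:
    _REGISTRY[cls.__name__] = cls
    return cls


def create_dataset_sketch(type_name: str, *args: Any, **kwargs: Any) -> Any:
    # Fail early with a clear message if the dataset was never registered.
    if type_name not in _REGISTRY:
        raise KeyError(f"dataset type {type_name!r} has not been registered")
    # Forward positional and keyword arguments straight to the constructor.
    return _REGISTRY[type_name](*args, **kwargs)


@register
class RangeDataset:
    """Toy map-style dataset holding the integers offset .. offset + n - 1."""

    def __init__(self, n: int, offset: int = 0) -> None:
        self.items = [i + offset for i in range(n)]

    def __len__(self) -> int:
        return len(self.items)

    def __getitem__(self, idx: int) -> int:
        return self.items[idx]


ds = create_dataset_sketch("RangeDataset", 3, offset=10)
print(ds[0], len(ds))  # → 10 3
```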
- class helios.data.datamodule.DatasetSplit(*args, **kwds)
Bases: enum.Enum
The different dataset splits.
- TRAIN = 0
- VALID = 1
- TEST = 2
- static from_str(label: str) -> DatasetSplit
Convert the given string to the corresponding enum value.
Must be one of “train”, “test”, or “valid”.
- Parameters:
label – the label to convert.
- Returns:
The corresponding enum value.
- Raises:
ValueError – if the given value is not one of “train”, “test”, or “valid”.
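A minimal sketch of an enum with a string-parsing helper that follows the documented contract: it accepts “train”, “valid”, or “test” and raises `ValueError` for anything else. `Split` is a hypothetical mirror of `DatasetSplit`; whether the real method ignores case or whitespace is not stated, so this sketch only accepts exact matches.

```python
import enum


class Split(enum.Enum):
    """Hypothetical mirror of DatasetSplit for illustration."""

    TRAIN = 0
    VALID = 1
    TEST = 2

    @staticmethod
    def from_str(label: str) -> "Split":
        # Map each accepted label to its enum member; anything else is an error.
        table = {"train": Split.TRAIN, "valid": Split.VALID, "test": Split.TEST}
        if label not in table:
            raise ValueError(
                f"invalid split {label!r}; expected 'train', 'valid', or 'test'"
            )
        return table[label]


print(Split.from_str("valid"))  # → Split.VALID
```

Because `staticmethod` objects are descriptors, `enum.Enum` does not turn `from_str` into a member, so the class still has exactly three members.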
- helios.data.datamodule.create_dataloader(dataset: torch.utils.data.Dataset, random_seed: int = rng.get_default_seed(), batch_size: int = 1, shuffle: bool = False, num_workers: int = 0, pin_memory: bool = True, drop_last: bool = False, debug_mode: bool = False, is_distributed: bool = False, sampler: helios.data.samplers.ResumableSamplerType | None = None, collate_fn: Callable | None = None) -> tuple[torch.utils.data.DataLoader, helios.data.samplers.ResumableSamplerType]
Create the dataloader for the given dataset.
If no sampler is provided, the choice of sampler is determined by the values of is_distributed and shuffle as follows:
  - If is_distributed, the sampler is ResumableDistributedSampler.
  - Otherwise, if shuffle, the sampler is ResumableRandomSampler; else it is ResumableSequentialSampler.
You may override this behaviour by providing your own sampler instance.
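The selection rules above, together with the documented requirement that a custom sampler derive from ResumableSampler or ResumableDistributedSampler, can be sketched in plain Python. The classes below are empty stand-ins that only borrow the documented names so the logic runs without helios or PyTorch, and `choose_sampler` is a hypothetical helper, not part of the library.

```python
from typing import Optional


class ResumableSampler:  # stand-in base class (illustration only)
    pass


class ResumableDistributedSampler:  # stand-in base class (illustration only)
    pass


class ResumableRandomSampler(ResumableSampler):  # stand-in
    pass


class ResumableSequentialSampler(ResumableSampler):  # stand-in
    pass


def choose_sampler(
    is_distributed: bool, shuffle: bool, sampler: Optional[object] = None
) -> type:
    """Return the sampler class the documented rules would pick."""
    if sampler is not None:
        # A user-supplied sampler wins, but it must derive from a resumable base.
        if not isinstance(sampler, (ResumableSampler, ResumableDistributedSampler)):
            raise TypeError(f"unsupported sampler type: {type(sampler).__name__}")
        return type(sampler)
    if is_distributed:
        # shuffle is presumably forwarded to the distributed sampler itself;
        # that detail is not modelled here.
        return ResumableDistributedSampler
    return ResumableRandomSampler if shuffle else ResumableSequentialSampler


print(choose_sampler(is_distributed=False, shuffle=True).__name__)  # → ResumableRandomSampler
```

Note that `is_distributed` takes precedence over `shuffle`, matching the order of the rules listed above.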
Warning
If you provide a custom sampler, then it must be derived from one of helios.data.samplers.ResumableSampler or helios.data.samplers.ResumableDistributedSampler.
- Parameters:
dataset – the dataset to use.
random_seed – value to use as seed for the worker processes. Defaults to the value returned by get_default_seed().
batch_size – number of samples per batch. Defaults to 1.
shuffle – if true, samples are randomly shuffled. Defaults to false.
num_workers – number of worker processes for loading data. Defaults to 0.
pin_memory – if true, copy loaded tensors into page-locked (pinned) host memory, which speeds up transfers to the GPU. Defaults to true.
drop_last – if true, drop the last batch if it is incomplete (i.e. the dataset size is not divisible by the batch size). Defaults to false.
debug_mode – if true, then num_workers will be set to 0. Defaults to false.
is_distributed – if true, create the distributed sampler. Defaults to false.
sampler – (optional) sampler to use.
collate_fn – (optional) function that merges a list of samples into a batch.
- Returns:
The dataloader and sampler.
- Raises:
TypeError – if sampler is not None and not derived from one of ResumableDistributedSampler or ResumableSampler.
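Because batch_size and drop_last together determine how many batches an epoch yields, a small worked example may help. Assuming samples are batched the way `torch.utils.data.DataLoader` batches them with a fixed batch size on a single process, a dataset of `n` samples gives `n // batch_size` batches when drop_last is true and `ceil(n / batch_size)` otherwise. `num_batches` is a hypothetical helper written for this illustration.

```python
def num_batches(num_samples: int, batch_size: int, drop_last: bool) -> int:
    """Batches per epoch for a fixed batch size on a single process."""
    if batch_size < 1:
        raise ValueError("batch_size must be at least 1")
    if drop_last:
        # The trailing, incomplete batch is discarded.
        return num_samples // batch_size
    # Integer ceiling division: the trailing batch is kept even if it is short.
    return (num_samples + batch_size - 1) // batch_size


print(num_batches(10, 3, drop_last=False))  # → 4 (sizes 3, 3, 3, 1)
print(num_batches(10, 3, drop_last=True))  # → 3 (the final sample is dropped)
```

With a distributed sampler, `num_samples` would be the number of samples assigned to each process rather than the full dataset size.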
- class helios.data.datamodule.DataLoaderParams
Params used to create the dataloader object.
- Parameters:
random_seed – value to use as seed for the worker processes.
batch_size – number of samples per batch.
shuffle – if true, samples are randomly shuffled.
num_workers – number of worker processes for loading data.
pin_memory – if true, copy loaded tensors into page-locked (pinned) host memory.
drop_last – if true, drop the last batch if it is incomplete.
debug_mode – if true, set the number of workers to 0.
is_distributed – if true, create the distributed sampler.
sampler – (optional) sampler to use.
collate_fn – (optional) function that merges a list of samples into a batch.
- random_seed: int
- batch_size: int = 1
- shuffle: bool = False
- num_workers: int = 0
- pin_memory: bool = True
- drop_last: bool = False
- debug_mode: bool = False
- is_distributed: bool | None = None
- sampler: helios.data.samplers.ResumableSamplerType | None = None
- collate_fn: Callable | None = None
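The DataModule example below builds datasets from either a DataLoaderParams instance or a table of key-value settings. A minimal sketch of how such a table could map onto a params dataclass, assuming unknown keys should be rejected and that debug_mode forces the worker count to 0 as documented for create_dataloader. `LoaderParams`, `from_settings`, and `effective_num_workers` are hypothetical names; the real class may convert settings differently, and the sampler field is omitted to keep the sketch dependency-free.

```python
import dataclasses
from typing import Any, Callable, Optional


@dataclasses.dataclass
class LoaderParams:
    """Hypothetical mirror of DataLoaderParams (sampler field omitted)."""

    random_seed: int
    batch_size: int = 1
    shuffle: bool = False
    num_workers: int = 0
    pin_memory: bool = True
    drop_last: bool = False
    debug_mode: bool = False
    is_distributed: Optional[bool] = None
    collate_fn: Optional[Callable] = None

    @classmethod
    def from_settings(cls, settings: dict[str, Any]) -> "LoaderParams":
        # Reject keys that do not correspond to a field, e.g. typos in a config.
        known = {f.name for f in dataclasses.fields(cls)}
        unknown = set(settings) - known
        if unknown:
            raise KeyError(f"unknown dataloader settings: {sorted(unknown)}")
        return cls(**settings)

    def effective_num_workers(self) -> int:
        # Debug mode loads data in the main process.
        return 0 if self.debug_mode else self.num_workers


params = LoaderParams.from_settings(
    {"random_seed": 42, "batch_size": 8, "num_workers": 4, "debug_mode": True}
)
print(params.batch_size, params.effective_num_workers())  # → 8 0
```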
- class helios.data.datamodule.Dataset
The dataset and corresponding data loader params.
- Parameters:
dataset – the dataset.
params – the data loader params.
- dataset: torch.utils.data.Dataset
- params: DataLoaderParams
- class helios.data.datamodule.DataModule
Bases: abc.ABC
Base class that groups together the creation of the main training datasets.
The purpose of this class is to standardize the way datasets and their respective dataloaders are created, which allows consistent settings across models.
Example
```python
from torchvision.datasets import CIFAR10

from helios import data

DATA_ROOT = "data/cifar10"  # placeholder download directory


class MyDataModule(data.DataModule):
    def prepare_data(self) -> None:
        # Use this function to prepare the data for your datasets. This will
        # be called before the distributed processes are created (if using)
        # so you should not set any state here.
        CIFAR10(DATA_ROOT, download=True)  # download the dataset only.

    def setup(self) -> None:
        # Create the training dataset using a DataLoaderParams instance. Note
        # that you MUST assign it to self._train_dataset.
        self._train_dataset = self._create_dataset(
            CIFAR10(DATA_ROOT, train=True), data.DataLoaderParams(...)
        )

        # It is also possible to create a dataset using a table of key-value
        # pairs that was loaded from a config file or manually created. Let's
        # use one to create the validation split:
        settings = {"batch_size": 1}  # ...plus any other DataLoaderParams fields

        # We can now use it to assign to self._valid_dataset like this:
        self._valid_dataset = self._create_dataset(
            CIFAR10(DATA_ROOT, train=False), settings
        )

        # Finally, if you need a testing split, you can create it like this:
        self._test_dataset = self._create_dataset(
            CIFAR10(DATA_ROOT, train=False), settings
        )

    def teardown(self) -> None:
        # Use this function to clean up any state. It will be called after
        # training is done.
        pass
```
- property is_distributed: bool
Flag controlling whether distributed training is being used or not.
- property trainer: helios.trainer.Trainer
Reference to the trainer.
- property train_dataset: torch.utils.data.Dataset | None
The training dataset (if available).
- property valid_dataset: torch.utils.data.Dataset | None
The validation dataset (if available).
- property test_dataset: torch.utils.data.Dataset | None
The testing dataset (if available).
- prepare_data() -> None
Prepare data for training.
This can include downloading datasets, preparing caches, or streaming them from external services. When using distributed training, this function is called on the primary process before the processes are initialized, so don't store any state here.
- train_dataloader() -> tuple[torch.utils.data.DataLoader, helios.data.samplers.ResumableSamplerType] | None
Create the training dataloader (if available).
- valid_dataloader() -> tuple[torch.utils.data.DataLoader, helios.data.samplers.ResumableSamplerType] | None
Create the validation dataloader (if available).