helios.data.datamodule

Attributes

DATASET_REGISTRY

Global instance of the registry for datasets.

Classes

DatasetSplit

The different dataset splits.

DataLoaderParams

Params used to create the dataloader object.

Dataset

The dataset and corresponding data loader params.

DataModule

Base class that groups together the creation of the main training datasets.

Functions

create_dataset(type_name, *args, **kwargs)

Create a dataset of the given type.

create_dataloader(dataset, random_seed, batch_size, shuffle, num_workers, ...)

Create the dataloader for the given dataset.

Module Contents

helios.data.datamodule.DATASET_REGISTRY

Global instance of the registry for datasets.

Example

import helios.data as hld

# This automatically registers your dataset.
@hld.DATASET_REGISTRY.register
class MyDataset:
    ...

# Alternatively you can manually register a dataset like this:
hld.DATASET_REGISTRY.register(MyDataset)

helios.data.datamodule.create_dataset(type_name: str, *args: Any, **kwargs: Any)

Create a dataset of the given type.

This uses DATASET_REGISTRY to look-up dataset types, so ensure your datasets have been registered before using this function.

Parameters:
  • type_name – the type of the dataset to create.

  • args – positional arguments to pass into the dataset.

  • kwargs – keyword arguments to pass into the dataset.

Returns:

The constructed dataset.
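
Example

A minimal sketch of registering a dataset and then constructing it through the registry. The class body, the root constructor argument, and the assumption that the registry keys datasets by their class name are illustrative, not part of the documented API:

from helios.data.datamodule import DATASET_REGISTRY, create_dataset

@DATASET_REGISTRY.register
class MyDataset:
    def __init__(self, root: str) -> None:
        self.root = root

# Look up "MyDataset" in DATASET_REGISTRY and forward the remaining
# arguments to its constructor.
dataset = create_dataset("MyDataset", root="/path/to/data")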

class helios.data.datamodule.DatasetSplit(*args, **kwds)

Bases: enum.Enum

The different dataset splits.

TRAIN = 0
VALID = 1
TEST = 2
static from_str(label: str) → DatasetSplit

Convert the given string to the corresponding enum value.

The label must be one of “train”, “test”, or “valid”.

Parameters:

label – the label to convert.

Returns:

The corresponding enum value.

Raises:

ValueError – if the given value is not one of “train”, “test”, or “valid”.
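
Example

A short sketch of the expected behaviour, using the three accepted labels:

from helios.data.datamodule import DatasetSplit

split = DatasetSplit.from_str("valid")
assert split is DatasetSplit.VALID

# Any label outside "train", "valid", or "test" raises ValueError.
try:
    DatasetSplit.from_str("validation")
except ValueError:
    pass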

helios.data.datamodule.create_dataloader(dataset: torch.utils.data.Dataset, random_seed: int = rng.get_default_seed(), batch_size: int = 1, shuffle: bool = False, num_workers: int = 0, pin_memory: bool = True, drop_last: bool = False, debug_mode: bool = False, is_distributed: bool = False, sampler: helios.data.samplers.ResumableSamplerType | None = None, collate_fn: Callable | None = None) → tuple[torch.utils.data.DataLoader, helios.data.samplers.ResumableSamplerType]

Create the dataloader for the given dataset.

If no sampler is provided, the sampler is chosen automatically based on the values of is_distributed and shuffle. You may override this behaviour by providing your own sampler instance.

Warning

If you provide a custom sampler, then it must be derived from one of helios.data.samplers.ResumableSampler or helios.data.samplers.ResumableDistributedSampler.

Parameters:
  • dataset – the dataset to use.

  • random_seed – value to use as seed for the worker processes. Defaults to the value returned by get_default_seed().

  • batch_size – number of samples per batch. Defaults to 1.

  • shuffle – if true, samples are randomly shuffled. Defaults to false.

  • num_workers – number of worker processes for loading data. Defaults to 0.

  • pin_memory – if true, use page-locked device memory. Defaults to true.

  • drop_last – if true, drop the last incomplete batch. Defaults to false.

  • debug_mode – if true, then num_workers will be set to 0. Defaults to false.

  • is_distributed – if true, create the distributed sampler. Defaults to false.

  • sampler – (optional) sampler to use.

  • collate_fn – (optional) function used to merge samples into a batch.

Returns:

The dataloader and sampler.

Raises:

TypeError – if sampler is not None and not derived from one of ResumableDistributedSampler or ResumableSampler.
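
Example

A minimal single-process sketch. ToyDataset is a stand-in for any map-style torch dataset and is not part of helios:

import torch
from torch.utils.data import Dataset

from helios.data.datamodule import create_dataloader


class ToyDataset(Dataset):
    def __len__(self) -> int:
        return 8

    def __getitem__(self, idx: int) -> torch.Tensor:
        return torch.tensor([idx], dtype=torch.float32)


# Both the dataloader and its sampler are returned so the sampler state
# can be saved and restored when resuming training.
loader, sampler = create_dataloader(ToyDataset(), batch_size=2, shuffle=True)
for batch in loader:
    print(batch.shape)  # torch.Size([2, 1])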

class helios.data.datamodule.DataLoaderParams

Params used to create the dataloader object.

Parameters:
  • random_seed – value to use as seed for the worker processes.

  • batch_size – number of samples per batch.

  • shuffle – if true, samples are randomly shuffled.

  • num_workers – number of worker processes for loading data.

  • pin_memory – if true, use page-locked device memory.

  • drop_last – if true, drop the last incomplete batch.

  • debug_mode – if true, set number of workers to 0.

  • is_distributed – if true, create the distributed sampler.

  • sampler – (optional) sampler to use.

  • collate_fn – (optional) function used to merge samples into a batch.

random_seed: int
batch_size: int = 1
shuffle: bool = False
num_workers: int = 0
pin_memory: bool = True
drop_last: bool = False
debug_mode: bool = False
is_distributed: bool | None = None
sampler: helios.data.samplers.ResumableSamplerType | None = None
collate_fn: Callable | None = None
to_dict() → dict[str, Any]

Convert the params object to a dictionary using shallow copies.

classmethod from_dict(table: dict[str, Any])

Create a new params object from the given table.
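
Example

A sketch of round-tripping the params through a plain table, assuming the class accepts its documented fields as keyword arguments:

from helios.data.datamodule import DataLoaderParams

params = DataLoaderParams(random_seed=42, batch_size=16, shuffle=True, num_workers=4)

# to_dict/from_dict make it straightforward to store the settings in a
# config file and rebuild them later.
table = params.to_dict()
restored = DataLoaderParams.from_dict(table)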

class helios.data.datamodule.Dataset

The dataset and corresponding data loader params.

Parameters:
  • dataset – the dataset.

  • params – the data loader params.

dataset: torch.utils.data.Dataset
params: DataLoaderParams
dict() → dict[str, Any]

Convert to a dictionary.
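
Example

A sketch of pairing a torch dataset with its loader params, assuming both classes accept their documented fields as keyword arguments:

import torch
from torch.utils.data import TensorDataset

from helios.data.datamodule import DataLoaderParams, Dataset

wrapped = Dataset(
    dataset=TensorDataset(torch.zeros(4, 3)),
    params=DataLoaderParams(random_seed=0, batch_size=2),
)
table = wrapped.dict()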

class helios.data.datamodule.DataModule

Bases: abc.ABC

Base class that groups together the creation of the main training datasets.

This class standardizes the way datasets and their respective dataloaders are created, thereby allowing consistent settings across models.

Example

from torchvision.datasets import CIFAR10
from helios import data

class MyDataModule(data.DataModule):
    def prepare_data(self) -> None:
        # Use this function to prepare the data for your datasets. This will
        # be called before the distributed processes are created (if using)
        # so you should not set any state here.
        CIFAR10(root="data", download=True)  # Download the dataset only; root is an arbitrary cache path.

    def setup(self) -> None:
        # Create the training dataset using a DataLoaderParams instance. Note
        # that you MUST assign it to self._train_dataset.
        self._train_dataset = self._create_dataset(
            CIFAR10(root="data", train=True), DataLoaderParams(...))

        # It is also possible to create a dataset using a table of key-value
        # pairs that was loaded from a config file or manually created. Let's
        # use one to create the validation split:
        settings = {"batch_size": 1, ...}

        # We can now use it to assign to self._valid_dataset like this:
        self._valid_dataset = self._create_dataset(
            CIFAR10(root="data", train=False), settings)

        # Finally, if you need a testing split, you can create it like this:
        self._test_dataset = self._create_dataset(
            CIFAR10(root="data", train=False), settings)

    def teardown(self) -> None:
        # Use this function to clean up any state. It will be called after
        # training is done.
        pass

property is_distributed: bool

Flag controlling whether distributed training is being used or not.

property trainer: helios.trainer.Trainer

Reference to the trainer.

property train_dataset: torch.utils.data.Dataset | None

The training dataset (if available).

property valid_dataset: torch.utils.data.Dataset | None

The validation dataset (if available).

property test_dataset: torch.utils.data.Dataset | None

The testing dataset (if available).

prepare_data() → None

Prepare data for training.

This can include downloading datasets, preparing caches, or streaming them from external services. When using distributed training, this function is called on the primary process before the worker processes are initialized, so don’t store any state here.

abstract setup() → None

Construct all required datasets.

train_dataloader() → tuple[torch.utils.data.DataLoader, helios.data.samplers.ResumableSamplerType] | None

Create the train dataloader (if available).

valid_dataloader() → tuple[torch.utils.data.DataLoader, helios.data.samplers.ResumableSamplerType] | None

Create the valid dataloader (if available).

test_dataloader() → tuple[torch.utils.data.DataLoader, helios.data.samplers.ResumableSamplerType] | None

Create the test dataloader (if available).

teardown() → None

Clean up any state after training is over.
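
Example

Outside of the trainer, the lifecycle of a data module can be sketched as below. MyDataModule refers to the example class shown earlier; when training through helios.trainer.Trainer these hooks are expected to be invoked for you:

dm = MyDataModule()

dm.prepare_data()  # download or cache data; no state is stored here
dm.setup()         # build the train/valid/test datasets

result = dm.train_dataloader()
if result is not None:
    loader, sampler = result
    for batch in loader:
        ...

dm.teardown()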