Source code for helios.data.datamodule

from __future__ import annotations

import abc
import copy
import dataclasses as dc
import enum
import typing

import torch
import torch.utils.data as tud

from helios import core
from helios.core import rng

if typing.TYPE_CHECKING:
    from ..trainer import Trainer

from .samplers import (
    ResumableDistributedSampler,
    ResumableRandomSampler,
    ResumableSamplerType,
    ResumableSequentialSampler,
)

DATASET_REGISTRY = core.Registry("dataset")
"""
Global instance of the registry for datasets.

Example:
    .. code-block:: python

        import helios.data as hld

        # This automatically registers your dataset.
        @hld.DATASET_REGISTRY.register
        class MyDataset:
            ...

        # Alternatively you can manually register a dataset like this:
        hld.DATASET_REGISTRY.register(MyDataset)
"""


[docs] def create_dataset(type_name: str, *args: typing.Any, **kwargs: typing.Any): """ Create a dataset of the given type. This uses ``DATASET_REGISTRY`` to look-up dataset types, so ensure your datasets have been registered before using this function. Args: type_name: the type of the dataset to create. args: positional arguments to pass into the dataset. kwargs: keyword arguments to pass into the dataset. Returns: The constructed dataset. """ return DATASET_REGISTRY.get(type_name)(*args, **kwargs)
[docs] class DatasetSplit(enum.Enum): """The different dataset splits.""" TRAIN = 0 VALID = 1 TEST = 2
[docs] @staticmethod def from_str(label: str) -> DatasetSplit: """ Convert the given string to the corresponding enum value. Must be one of "train", "test", or "valid" Args: label: the label to convert. Returns: The corresponding enum value. Raises: ValueError: if the given value is not one of "train", "test", or "valid". """ if label == "train": return DatasetSplit.TRAIN if label == "valid": return DatasetSplit.VALID if label == "test": return DatasetSplit.TEST raise ValueError( "invalid dataset split. Expected one of 'train', 'test', or " f"'valid', but received '{label}'" )
def _seed_worker(worker_id: int) -> None: worker_seed = torch.initial_seed() % 2**32 rng.seed_rngs(worker_seed, skip_torch=True)
[docs] def create_dataloader( dataset: tud.Dataset, random_seed: int = rng.get_default_seed(), batch_size: int = 1, shuffle: bool = False, num_workers: int = 0, pin_memory: bool = True, drop_last: bool = False, debug_mode: bool = False, is_distributed: bool = False, sampler: ResumableSamplerType | None = None, collate_fn: typing.Callable | None = None, ) -> tuple[tud.DataLoader, ResumableSamplerType]: """ Create the dataloader for the given dataset. If no sampler is provided, the choice of sampler will be determined based on the values of is_distributed and shuffle. Specifically, the following logic is used: * If is_distributed, then sampler is :py:class:`~helios.data.samplers.ResumableDistributedSampler`. * Otherwise, if shuffle then sampler is :py:class:`~helios.data.samplers.ResumableRandomSampler`, else :py:class:`~helios.data.samplers.ResumableSequentialSampler`. You may override this behaviour by providing your own sampler instance. .. warning:: If you provide a custom sampler, then it **must** be derived from one of :py:class:`helios.data.samplers.ResumableSampler` or :py:class:`helios.data.samplers.ResumableDistributedSampler`. Args: dataset: the dataset to use. random_seed: value to use as seed for the worker processes. Defaults to the value returned by :py:func:`~helios.core.rng.get_default_seed`. batch_size: number of samplers per batch. Defaults to 1. shuffle: if true, samples are randomly shuffled. Defaults to false. num_workers: number of worker processes for loading data. Defaults to 0. pin_memory: if true, use page-locked device memory. Defaults to true. drop_last: if true, remove the final batch. Defaults to false. debug_mode: if true, then ``num_workers`` will be set to 0. Defaults to false. is_distributed: if true, create the distributed sampler. Defaults to false. sampler: (optional) sampler to use. collate_fn: (optional) function to merge batches. Returns: The dataloader and sampler. Raises: TypeError: if ``sampler`` is not ``None`` and not derived from one of :py:class:`~helios.data.samplers.ResumableDistributedSampler` or :py:class:`~helios.data.samplers.ResumableSampler`. """ assert len(typing.cast(typing.Sized, dataset)) if sampler is None: if is_distributed: sampler = ResumableDistributedSampler( dataset, shuffle=shuffle, drop_last=drop_last, seed=random_seed ) else: if shuffle: sampler = ResumableRandomSampler( dataset, # type: ignore[arg-type] seed=random_seed, batch_size=batch_size, ) else: sampler = ResumableSequentialSampler( dataset, # type: ignore[arg-type] batch_size=batch_size, ) elif not isinstance(sampler, typing.get_args(ResumableSamplerType)): raise TypeError( "error: expected sampler to derive from one of ResumableSampler or " f"ResumableDistributedSampler, but received {type(sampler)}" ) return ( tud.DataLoader( dataset, batch_size=batch_size, num_workers=num_workers if not debug_mode else 0, pin_memory=pin_memory, drop_last=drop_last, sampler=sampler, worker_init_fn=_seed_worker, collate_fn=collate_fn, ), sampler, )
[docs] @dc.dataclass class DataLoaderParams: """ Params used to create the dataloader object. Args: random_seed: value to use as seed for the worker processes. batch_size: number of samplers per batch. shuffle: if true, samples are randomly shuffled. num_workers: number of worker processes for loading data. pin_memory: if true, use page-locked device memory. drop_last: if true, remove the final batch. debug_mode: if true, set number of workers to 0. is_distributed: if true, create the distributed sampler. sampler: (optional) sampler to use. collate_fn: (optional) function to merge batches. """ random_seed: int = rng.get_default_seed() batch_size: int = 1 shuffle: bool = False num_workers: int = 0 pin_memory: bool = True drop_last: bool = False debug_mode: bool = False is_distributed: bool | None = None sampler: ResumableSamplerType | None = None collate_fn: typing.Callable | None = None
[docs] def to_dict(self) -> dict[str, typing.Any]: """Convert the params object to a dictionary using shallow copies.""" return {field.name: getattr(self, field.name) for field in dc.fields(self)}
[docs] @classmethod def from_dict(cls, table: dict[str, typing.Any]): """Create a new params object from the given table.""" params = cls() keys = [field.name for field in dc.fields(params)] # Skip the sampler key, since that one needs to be created differently. for key in keys: if key in table and key != "sampler": setattr(params, key, table[key]) return params
[docs] @dc.dataclass class Dataset: """ The dataset and corresponding data loader params. Args: dataset: the dataset. params: the data loader params. """ dataset: tud.Dataset params: DataLoaderParams
[docs] def dict(self) -> dict[str, typing.Any]: """Convert to a dictionary.""" res = self.params.to_dict() res["dataset"] = self.dataset return res
[docs] class DataModule(abc.ABC): """ Base class that groups together the creation of the main training datasets. The use of this class is to standardize the way datasets and their respective dataloaders are created, thereby allowing consistent settings across models. Example: .. code-block:: python from torchvision.datasets import CIFAR10 from helios import data class MyDataModule(data.DataModule): def prepare_data(self) -> None: # Use this function to prepare the data for your datasets. This will # be called before the distributed processes are created (if using) # so you should not set any state here. CIFAR10(download=True) # download the dataset only. def setup(self) -> None: # Create the training dataset using a DataLoaderParams instance. Note # that you MUST assign it to self._train_dataset. self._train_dataset = self._create_dataset(CIFAR10(train=True), DataLoaderParams(...)) # It is also possible to create a dataset using a table of key-value # pairs that was loaded from a config file or manually created. Let's # use one to create the validation split: settings = {"batch_size": 1, ...} # We can now use it to assign to self._valid_dataset like this: self._valid_dataset = self._create_dataset( CIFAR10(train=False), settings) # Finally, if you need a testing split, you can create it like this: self._test_dataset = self._create_dataset( CIFAR10(train=False), settings) def teardown(self) -> None: # Use this function to clean up any state. It will be called after # training is done. """ def __init__(self) -> None: """Create the data module.""" self._is_distributed: bool = False self._train_dataset: Dataset | None = None self._valid_dataset: Dataset | None = None self._test_dataset: Dataset | None = None self._trainer: Trainer | None = None @property def is_distributed(self) -> bool: """Flag controlling whether distributed training is being used or not.""" return self._is_distributed @is_distributed.setter def is_distributed(self, val: bool) -> None: self._is_distributed = val @property def trainer(self) -> Trainer: """Reference to the trainer.""" return core.get_from_optional(self._trainer) @trainer.setter def trainer(self, t) -> None: self._trainer = t @property def train_dataset(self) -> tud.Dataset | None: """The training dataset (if available).""" if self._train_dataset is not None: return self._train_dataset.dataset return None @property def valid_dataset(self) -> tud.Dataset | None: """The validation dataset (if available).""" if self._valid_dataset is not None: return self._valid_dataset.dataset return None @property def test_dataset(self) -> tud.Dataset | None: """The testing dataset (if available).""" if self._test_dataset is not None: return self._test_dataset.dataset return None
[docs] def prepare_data(self) -> None: # noqa: B027 """ Prepare data for training. This can include downloading datasets, preparing caches, or streaming them from external services. This function will be called on the primary process when using distributed training (will be called prior to initialization of the processes) so don't store any state here. """
[docs] @abc.abstractmethod def setup(self) -> None: """Construct all required datasets."""
[docs] def train_dataloader(self) -> tuple[tud.DataLoader, ResumableSamplerType] | None: """Create the train dataloader (if available).""" if self._train_dataset is None: return None return self._create_dataloader(self._train_dataset)
[docs] def valid_dataloader(self) -> tuple[tud.DataLoader, ResumableSamplerType] | None: """Create the valid dataloader (if available).""" if self._valid_dataset is None: return None return self._create_dataloader(self._valid_dataset)
[docs] def test_dataloader(self) -> tuple[tud.DataLoader, ResumableSamplerType] | None: """Create the test dataloader (if available).""" if self._test_dataset is None: return None return self._create_dataloader(self._test_dataset)
[docs] def teardown(self) -> None: # noqa: B027 """Clean up any state after training is over."""
def _create_dataset( self, dataset: tud.Dataset, params: DataLoaderParams | dict[str, typing.Any] ) -> Dataset: """ Create a dataset object from a Torch dataset and a params. This is a convenience function to help you create the dataset objects. The params object can either be the fully filled in DataLoaderParams instance, or it can be a dict containing all of the necessary settings. If a dict is passed in, it will be used to populate the corresponding DataLoaderParams object. Args: dataset: the dataset. params: either a params object or a dict. Returns: The new dataset object. """ return Dataset( dataset, copy.deepcopy(params) if isinstance(params, DataLoaderParams) else DataLoaderParams.from_dict(params), ) def _create_dataloader( self, dataset: Dataset ) -> tuple[tud.DataLoader, ResumableSamplerType]: # Only override the distributed flag if it hasn't been set by the user. if dataset.params.is_distributed is None: dataset.params.is_distributed = self._is_distributed return create_dataloader(**dataset.dict())