helios.data.datamodule
======================

.. py:module:: helios.data.datamodule


Attributes
----------

.. autoapisummary::

   helios.data.datamodule.DATASET_REGISTRY


Classes
-------

.. autoapisummary::

   helios.data.datamodule.DatasetSplit
   helios.data.datamodule.DataLoaderParams
   helios.data.datamodule.Dataset
   helios.data.datamodule.DataModule


Functions
---------

.. autoapisummary::

   helios.data.datamodule.create_dataset
   helios.data.datamodule.create_dataloader


Module Contents
---------------

.. py:data:: DATASET_REGISTRY

   Global instance of the registry for datasets.

   .. rubric:: Example

   .. code-block:: python

      import helios.data as hld

      # This automatically registers your dataset.
      @hld.DATASET_REGISTRY.register
      class MyDataset:
          ...

      # Alternatively you can manually register a dataset like this:
      hld.DATASET_REGISTRY.register(MyDataset)

.. py:function:: create_dataset(type_name: str, *args: Any, **kwargs: Any)

   Create a dataset of the given type.

   This uses ``DATASET_REGISTRY`` to look up dataset types, so ensure your datasets
   have been registered before using this function.

   :param type_name: the type of the dataset to create.
   :param args: positional arguments to pass into the dataset.
   :param kwargs: keyword arguments to pass into the dataset.
   :returns: The constructed dataset.

.. py:class:: DatasetSplit(*args, **kwds)

   Bases: :py:obj:`enum.Enum`

   The different dataset splits.

   .. py:attribute:: TRAIN
      :value: 0

   .. py:attribute:: VALID
      :value: 1

   .. py:attribute:: TEST
      :value: 2

   .. py:method:: from_str(label: str) -> DatasetSplit
      :staticmethod:

      Convert the given string to the corresponding enum value.

      The label must be one of "train", "test", or "valid".

      :param label: the label to convert.
      :returns: The corresponding enum value.
      :raises ValueError: if the given value is not one of "train", "test", or
          "valid".
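As a quick illustration, the sketch below registers a toy dataset and then constructs
it by name. It assumes the registry keys dataset types by their class name (as
``type_name`` suggests); the ``split`` constructor argument is purely hypothetical.

.. code-block:: python

   import helios.data as hld
   from helios.data.datamodule import DatasetSplit, create_dataset

   @hld.DATASET_REGISTRY.register
   class MyDataset:
       # The split argument is hypothetical; use whatever your dataset needs.
       def __init__(self, split: DatasetSplit) -> None:
           self.split = split

   # Look up the registered type by name and construct it.
   dataset = create_dataset("MyDataset", split=DatasetSplit.from_str("train"))
   assert dataset.split == DatasetSplit.TRAIN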
.. py:function:: create_dataloader(dataset: torch.utils.data.Dataset, random_seed: int = rng.get_default_seed(), batch_size: int = 1, shuffle: bool = False, num_workers: int = 0, pin_memory: bool = True, drop_last: bool = False, debug_mode: bool = False, is_distributed: bool = False, sampler: helios.data.samplers.ResumableSamplerType | None = None, collate_fn: Callable | None = None) -> tuple[torch.utils.data.DataLoader, helios.data.samplers.ResumableSamplerType]

   Create the dataloader for the given dataset.

   If no sampler is provided, the choice of sampler is determined by the values of
   ``is_distributed`` and ``shuffle``. Specifically, the following logic is used:

   * If ``is_distributed`` is true, the sampler is
     :py:class:`~helios.data.samplers.ResumableDistributedSampler`.
   * Otherwise, if ``shuffle`` is true, the sampler is
     :py:class:`~helios.data.samplers.ResumableRandomSampler`; else it is
     :py:class:`~helios.data.samplers.ResumableSequentialSampler`.

   You may override this behaviour by providing your own sampler instance.

   .. warning::

      If you provide a custom sampler, then it **must** be derived from one of
      :py:class:`helios.data.samplers.ResumableSampler` or
      :py:class:`helios.data.samplers.ResumableDistributedSampler`.

   :param dataset: the dataset to use.
   :param random_seed: value to use as seed for the worker processes. Defaults to
       the value returned by :py:func:`~helios.core.rng.get_default_seed`.
   :param batch_size: number of samples per batch. Defaults to 1.
   :param shuffle: if true, samples are randomly shuffled. Defaults to false.
   :param num_workers: number of worker processes for loading data. Defaults to 0.
   :param pin_memory: if true, use page-locked device memory. Defaults to true.
   :param drop_last: if true, drop the final incomplete batch. Defaults to false.
   :param debug_mode: if true, ``num_workers`` will be set to 0. Defaults to false.
   :param is_distributed: if true, create the distributed sampler. Defaults to
       false.
   :param sampler: (optional) sampler to use.
   :param collate_fn: (optional) function to merge samples into a batch.
   :returns: The dataloader and sampler.
   :raises TypeError: if ``sampler`` is not ``None`` and not derived from one of
       :py:class:`~helios.data.samplers.ResumableDistributedSampler` or
       :py:class:`~helios.data.samplers.ResumableSampler`.
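A minimal sketch of the default (non-distributed) path, using a throwaway
``TensorDataset`` as a stand-in; the argument values are illustrative. With
``shuffle=True`` and ``is_distributed`` left at its default, the returned sampler is
a :py:class:`~helios.data.samplers.ResumableRandomSampler` per the logic above.

.. code-block:: python

   import torch
   import torch.utils.data as tud

   from helios.data.datamodule import create_dataloader

   # Stand-in dataset; any map-style torch Dataset works here.
   dataset = tud.TensorDataset(torch.arange(16))

   loader, sampler = create_dataloader(dataset, batch_size=4, shuffle=True)

   for (batch,) in loader:
       print(batch.shape)  # torch.Size([4])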
.. py:class:: DataLoaderParams

   Params used to create the dataloader object.

   :param random_seed: value to use as seed for the worker processes.
   :param batch_size: number of samples per batch.
   :param shuffle: if true, samples are randomly shuffled.
   :param num_workers: number of worker processes for loading data.
   :param pin_memory: if true, use page-locked device memory.
   :param drop_last: if true, drop the final incomplete batch.
   :param debug_mode: if true, set the number of workers to 0.
   :param is_distributed: if true, create the distributed sampler.
   :param sampler: (optional) sampler to use.
   :param collate_fn: (optional) function to merge samples into a batch.

   .. py:attribute:: random_seed
      :type: int

   .. py:attribute:: batch_size
      :type: int
      :value: 1

   .. py:attribute:: shuffle
      :type: bool
      :value: False

   .. py:attribute:: num_workers
      :type: int
      :value: 0

   .. py:attribute:: pin_memory
      :type: bool
      :value: True

   .. py:attribute:: drop_last
      :type: bool
      :value: False

   .. py:attribute:: debug_mode
      :type: bool
      :value: False

   .. py:attribute:: is_distributed
      :type: bool | None
      :value: None

   .. py:attribute:: sampler
      :type: helios.data.samplers.ResumableSamplerType | None
      :value: None

   .. py:attribute:: collate_fn
      :type: Callable | None
      :value: None

   .. py:method:: to_dict() -> dict[str, Any]

      Convert the params object to a dictionary using shallow copies.

   .. py:method:: from_dict(table: dict[str, Any])
      :classmethod:

      Create a new params object from the given table.

.. py:class:: Dataset

   The dataset and corresponding data loader params.

   :param dataset: the dataset.
   :param params: the data loader params.

   .. py:attribute:: dataset
      :type: torch.utils.data.Dataset

   .. py:attribute:: params
      :type: DataLoaderParams

   .. py:method:: dict() -> dict[str, Any]

      Convert to a dictionary.

.. py:class:: DataModule

   Bases: :py:obj:`abc.ABC`

   Base class that groups together the creation of the main training datasets.

   This class standardizes the way datasets and their respective dataloaders are
   created, thereby allowing consistent settings across models.

   .. rubric:: Example

   .. code-block:: python

      from torchvision.datasets import CIFAR10

      from helios import data
      from helios.data.datamodule import DataLoaderParams

      class MyDataModule(data.DataModule):
          def prepare_data(self) -> None:
              # Use this function to prepare the data for your datasets. This
              # will be called before the distributed processes are created
              # (if using), so you should not set any state here.
              CIFAR10(download=True)  # download the dataset only.

          def setup(self) -> None:
              # Create the training dataset using a DataLoaderParams instance.
              # Note that you MUST assign it to self._train_dataset.
              self._train_dataset = self._create_dataset(
                  CIFAR10(train=True), DataLoaderParams(...)
              )

              # It is also possible to create a dataset using a table of
              # key-value pairs that was loaded from a config file or manually
              # created. Let's use one to create the validation split:
              settings = {"batch_size": 1, ...}

              # We can now use it to assign to self._valid_dataset like this:
              self._valid_dataset = self._create_dataset(
                  CIFAR10(train=False), settings
              )

              # Finally, if you need a testing split, you can create it like
              # this:
              self._test_dataset = self._create_dataset(
                  CIFAR10(train=False), settings
              )

          def teardown(self) -> None:
              # Use this function to clean up any state. It will be called
              # after training is done.
              ...

   .. py:property:: is_distributed
      :type: bool

      Flag controlling whether distributed training is being used or not.

   .. py:property:: trainer
      :type: helios.trainer.Trainer

      Reference to the trainer.

   .. py:property:: train_dataset
      :type: torch.utils.data.Dataset | None

      The training dataset (if available).

   .. py:property:: valid_dataset
      :type: torch.utils.data.Dataset | None

      The validation dataset (if available).

   .. py:property:: test_dataset
      :type: torch.utils.data.Dataset | None

      The testing dataset (if available).

   .. py:method:: prepare_data() -> None

      Prepare data for training.

      This can include downloading datasets, preparing caches, or streaming them
      from external services. When using distributed training, this function is
      called on the primary process prior to the initialization of the distributed
      processes, so don't store any state here.

   .. py:method:: setup() -> None
      :abstractmethod:

      Construct all required datasets.

   .. py:method:: train_dataloader() -> tuple[torch.utils.data.DataLoader, helios.data.samplers.ResumableSamplerType] | None

      Create the train dataloader (if available).

   .. py:method:: valid_dataloader() -> tuple[torch.utils.data.DataLoader, helios.data.samplers.ResumableSamplerType] | None

      Create the valid dataloader (if available).

   .. py:method:: test_dataloader() -> tuple[torch.utils.data.DataLoader, helios.data.samplers.ResumableSamplerType] | None

      Create the test dataloader (if available).

   .. py:method:: teardown() -> None

      Clean up any state after training is over.
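To tie the pieces together, here is a sketch of driving the ``MyDataModule`` from
the example above by hand. In a real run the :py:class:`~helios.trainer.Trainer`
invokes these hooks itself; the manual calls below are purely illustrative.

.. code-block:: python

   datamodule = MyDataModule()

   datamodule.prepare_data()  # download/cache data; sets no state.
   datamodule.setup()         # build the train/valid/test datasets.

   out = datamodule.train_dataloader()
   if out is not None:
       loader, sampler = out
       for batch in loader:
           ...  # run a training step on the batch.

   datamodule.teardown()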