helios.data.samplers

Attributes

SAMPLER_REGISTRY

Global instance of the registry for samplers.

ResumableSamplerType

Defines the resumable sampler type.

Classes

ResumableSampler

Base class for samplers that are resumable.

ResumableRandomSampler

Random sampler with resumable state.

ResumableSequentialSampler

Sequential sampler with resumable state.

ResumableDistributedSampler

Distributed sampler with resumable state.

Functions

create_sampler(→ ResumableSamplerType)

Create a sampler of the given type.

Module Contents

helios.data.samplers.SAMPLER_REGISTRY

Global instance of the registry for samplers.

Example

import helios.data.samplers as hlds

# This automatically registers your sampler.
@hlds.SAMPLER_REGISTRY.register
class MySampler:
    ...

# Alternatively, you can register a sampler manually like this:
hlds.SAMPLER_REGISTRY.register(MySampler)

helios.data.samplers.create_sampler(type_name: str, *args: Any, **kwargs: Any) → ResumableSamplerType

Create a sampler of the given type.

This uses the SAMPLER_REGISTRY to look up sampler types, so make sure your samplers have been registered before calling this function.

Parameters:
  • type_name – the type of the sampler to create.

  • args – positional arguments to pass into the sampler.

  • kwargs – keyword arguments to pass into the sampler.

Returns:

The constructed sampler.
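
For example, a sampler could be created by name like this. A minimal sketch: it assumes the built-in samplers are registered under their class names (the exact key depends on how the registry was populated).

import helios.data.samplers as hlds

data = list(range(100))  # any Sized container can act as the data source

# "ResumableRandomSampler" is assumed to be the registered name here; the
# remaining arguments are forwarded to the sampler's constructor.
sampler = hlds.create_sampler("ResumableRandomSampler", data, seed=0, batch_size=32)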

class helios.data.samplers.ResumableSampler(batch_size: int)

Bases: torch.utils.data.Sampler

Base class for samplers that are resumable.

Let \(b_i\) be the \(i\)-th batch of a given epoch \(e\), and let the sequence of batches that follow it be \(b_{i + 1}, b_{i + 2}, \ldots\). Suppose that on iteration \(i\), batch \(b_i\) is loaded and training is stopped immediately afterwards. A sampler is defined to be resumable if and only if:

  1. Upon re-starting training on epoch \(e\), the next batch the sampler loads is \(b_{i + 1}\).

  2. The order of the subsequent batches \(b_{i + 2}, \ldots\) must be identical to the order that the sampler would’ve produced for the epoch \(e\) had training not stopped.

Parameters:

batch_size – the number of samples per batch.

property start_iter: int

The starting iteration for the sampler.

set_epoch(epoch: int) → None

Set the current epoch for seeding.
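
As an illustration of the contract above, here is a minimal sketch of a custom sampler derived from ResumableSampler. EveryOtherSampler is hypothetical; only the batch_size constructor argument, the start_iter property, and set_epoch are taken from the documented base class, so the sketch keeps its own copies of everything else.

from typing import Iterator, Sized

import helios.data.samplers as hlds

@hlds.SAMPLER_REGISTRY.register
class EveryOtherSampler(hlds.ResumableSampler):
    """Illustrative sampler that yields every other index."""

    def __init__(self, data_source: Sized, batch_size: int = 1) -> None:
        super().__init__(batch_size)
        self._data_source = data_source
        self._batch_size = batch_size  # local copy; base internals are not documented

    def __len__(self) -> int:
        return len(self._data_source)

    def __iter__(self) -> Iterator[int]:
        # Honour the resume contract: start_iter batches were already
        # consumed before the restart, so skip start_iter * batch_size
        # samples and continue in the original order.
        skip = self.start_iter * self._batch_size
        indices = list(range(0, len(self._data_source), 2))
        yield from indices[skip:]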

class helios.data.samplers.ResumableRandomSampler(data_source: Sized, seed: int = 0, batch_size: int = 1)

Bases: ResumableSampler

Random sampler with resumable state.

This allows training to stop and resume while guaranteeing that the order in which batches are returned remains consistent. It is effectively a drop-in replacement for PyTorch’s default RandomSampler.

Parameters:
  • data_source – the dataset to sample from.

  • seed – the seed to use for setting up the random generator.

  • batch_size – the number of samples per batch.

__len__() → int

Return the length of the dataset.

__iter__() → Iterator[int]

Retrieve the index of the next sample.
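
A minimal sketch of how this sampler might be paired with a standard DataLoader. Passing the same batch_size to both objects is an assumption of this example; the interaction between the two values is not documented here.

import torch
from torch.utils.data import DataLoader, TensorDataset

import helios.data.samplers as hlds

dataset = TensorDataset(torch.arange(100))
sampler = hlds.ResumableRandomSampler(dataset, seed=42, batch_size=10)
loader = DataLoader(dataset, batch_size=10, sampler=sampler)

for epoch in range(2):
    sampler.set_epoch(epoch)  # seed the shuffle deterministically per epoch
    for batch in loader:
        pass  # training step goes here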

class helios.data.samplers.ResumableSequentialSampler(data_source: Sized, batch_size: int = 1)

Bases: ResumableSampler

Sequential sampler with resumable state.

This allows training to stop and resume while guaranteeing that the order in which batches are returned remains consistent. It is effectively a drop-in replacement for PyTorch’s default SequentialSampler.

Parameters:
  • data_source – the dataset to sample from.

  • batch_size – the number of samples per batch.

__len__() → int

Return the length of the dataset.

__iter__() → Iterator[int]

Retrieve the index of the next sample.
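
For instance, on a small container the sequential contract implies the following behaviour (the printed values are what that contract implies, not verified output):

import helios.data.samplers as hlds

data = list(range(6))  # any Sized container
sampler = hlds.ResumableSequentialSampler(data, batch_size=2)

print(len(sampler))   # expected: 6
print(list(sampler))  # expected: [0, 1, 2, 3, 4, 5], in order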

class helios.data.samplers.ResumableDistributedSampler(dataset: torch.utils.data.Dataset, num_replicas: int | None = None, rank: int | None = None, shuffle: bool = True, seed: int = 0, drop_last: bool = False, batch_size: int = 1)

Bases: torch.utils.data.DistributedSampler

Distributed sampler with resumable state.

This allows training to stop and resume while guaranteeing that the order in which batches are returned remains consistent. It is effectively a drop-in replacement for PyTorch’s default DistributedSampler.

Parameters:
  • dataset – the dataset to sample from.

  • num_replicas – number of processes for distributed training.

  • rank – the rank of the current process within num_replicas.

  • shuffle – if true, shuffle the indices. Defaults to true.

  • seed – random seed used to shuffle the sampler. Defaults to 0.

  • drop_last – if true, drop the tail of the data to make the number of samples evenly divisible across the replicas. Defaults to false.

  • batch_size – the number of samples per batch. Defaults to 1.

property start_iter: int

The starting iteration for the sampler.

__iter__() → Iterator[int]

Retrieve the index of the next sample.
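
A minimal sketch of distributed usage, assuming torch.distributed has already been initialised (e.g. by torchrun) so that num_replicas and rank can be inferred automatically. As above, matching batch_size values on the sampler and the DataLoader is an assumption of this example.

import torch
from torch.utils.data import DataLoader, TensorDataset

import helios.data.samplers as hlds

dataset = TensorDataset(torch.arange(1000))
sampler = hlds.ResumableDistributedSampler(
    dataset, shuffle=True, seed=0, drop_last=True, batch_size=16
)
loader = DataLoader(dataset, batch_size=16, sampler=sampler)

for epoch in range(10):
    sampler.set_epoch(epoch)  # keep the shuffle order consistent across ranks
    for batch in loader:
        pass  # training step goes here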

helios.data.samplers.ResumableSamplerType

Defines the resumable sampler type.

A resumable sampler must be derived from either ResumableSampler or ResumableDistributedSampler.
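
As an illustration, the alias can be used to annotate helper functions that accept any resumable sampler; prepare_loader below is hypothetical and not part of helios.

from torch.utils.data import DataLoader, Dataset

import helios.data.samplers as hlds

def prepare_loader(
    dataset: Dataset,
    sampler: hlds.ResumableSamplerType,
    batch_size: int,
) -> DataLoader:
    # Accepts any sampler derived from ResumableSampler or
    # ResumableDistributedSampler.
    return DataLoader(dataset, batch_size=batch_size, sampler=sampler)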