helios.data.samplers
Attributes
SAMPLER_REGISTRY – Global instance of the registry for samplers.
ResumableSamplerType – Defines the resumable sampler type.
Classes
ResumableSampler – Base class for samplers that are resumable.
ResumableRandomSampler – Random sampler with resumable state.
ResumableSequentialSampler – Sequential sampler with resumable state.
ResumableDistributedSampler – Distributed sampler with resumable state.
Functions
create_sampler – Create a sampler of the given type.
Module Contents
- helios.data.samplers.SAMPLER_REGISTRY
Global instance of the registry for samplers.
Example
    import helios.data.samplers as hlds

    # This automatically registers your sampler.
    @hlds.SAMPLER_REGISTRY.register
    class MySampler:
        ...

    # Alternatively you can manually register a sampler like this:
    hlds.SAMPLER_REGISTRY.register(MySampler)
- helios.data.samplers.create_sampler(type_name: str, *args: Any, **kwargs: Any) → ResumableSamplerType
Create a sampler of the given type.
This uses the SAMPLER_REGISTRY to look up sampler types, so ensure your samplers have been registered before calling this function.
- Parameters:
type_name – the type of the sampler to create.
args – positional arguments to pass into the sampler.
kwargs – keyword arguments to pass into the sampler.
- Returns:
The constructed sampler.
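The register-then-create flow described above can be illustrated with a small, self-contained sketch. It does not import helios: SketchRegistry, SKETCH_REGISTRY and create_sampler_sketch are hypothetical stand-ins, and keying the registry by class name is an assumption made for the illustration rather than a statement about how SAMPLER_REGISTRY stores its entries.

```python
from typing import Any, Callable


class SketchRegistry:
    """Hypothetical name-to-type registry; not the helios implementation."""

    def __init__(self) -> None:
        self._types: dict[str, Callable[..., Any]] = {}

    def register(self, cls: Callable[..., Any]) -> Callable[..., Any]:
        # Assumption: types are keyed by their class name.
        # Returning cls lets register() double as a class decorator.
        self._types[cls.__name__] = cls
        return cls

    def get(self, type_name: str) -> Callable[..., Any]:
        if type_name not in self._types:
            raise KeyError(f"sampler type '{type_name}' has not been registered")
        return self._types[type_name]


SKETCH_REGISTRY = SketchRegistry()


def create_sampler_sketch(type_name: str, *args: Any, **kwargs: Any) -> Any:
    # Look up the type by name, then forward every argument to its constructor.
    return SKETCH_REGISTRY.get(type_name)(*args, **kwargs)


@SKETCH_REGISTRY.register
class MySampler:
    def __init__(self, data_source: list[int], batch_size: int = 1) -> None:
        self.data_source = data_source
        self.batch_size = batch_size


sampler = create_sampler_sketch("MySampler", [0, 1, 2, 3], batch_size=2)
```

In this sketch, asking for a name that was never registered raises a KeyError, which is why registration has to happen before creation.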
- class helios.data.samplers.ResumableSampler(batch_size: int)
Bases: torch.utils.data.Sampler
Base class for samplers that are resumable.
Let \(b_i\) be the \(i\)-th batch for a given epoch \(e\), and let the sequence of batches that follow be \(b_{i + 1}, b_{i + 2}, \ldots\). Suppose that on iteration \(i\), batch \(b_i\) is loaded and training is stopped immediately after. A sampler is defined to be resumable if and only if:
1. Upon restarting training on epoch \(e\), the next batch the sampler loads is \(b_{i + 1}\).
2. The order of the subsequent batches \(b_{i + 2}, \ldots\) is identical to the order that the sampler would have produced for epoch \(e\) had training not stopped.
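The definition above can be checked on a toy example. The sketch below is plain Python and does not use helios. The helper batches is hypothetical, and it assumes that start_iter is the index of the first batch to produce after a restart.

```python
from typing import Iterator


def batches(order: list[int], batch_size: int, start_iter: int = 0) -> Iterator[list[int]]:
    """Yield batches of `order`, skipping the first `start_iter` batches."""
    for first in range(start_iter * batch_size, len(order), batch_size):
        yield order[first : first + batch_size]


order = list(range(10))  # fixed sample order for epoch e
full_epoch = list(batches(order, batch_size=3))

# Training stops right after batch b_1 is loaded, so the last iteration is i = 1.
i = 1
resumed = list(batches(order, batch_size=3, start_iter=i + 1))

# Both conditions of the definition hold for this toy sampler:
# the first resumed batch is b_{i+1}, and the rest follow in the original order.
print(resumed == full_epoch[i + 1 :])
```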
- Parameters:
batch_size – the number of samples per batch.
- property start_iter: int
The starting iteration for the sampler.
- class helios.data.samplers.ResumableRandomSampler(data_source: Sized, seed: int = 0, batch_size: int = 1)
Bases: ResumableSampler
Random sampler with resumable state.
This allows training to stop and resume while guaranteeing that the order in which the batches are returned stays consistent. It is effectively a replacement for the default RandomSampler from PyTorch.
- Parameters:
data_source – the dataset to sample from.
seed – the seed to use for setting up the random generator.
batch_size – the number of samples per batch.
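One way a random sampler can keep its order consistent across restarts is to make the shuffle depend only on values that survive a restart. The sketch below is an assumption about the general technique, not a description of how ResumableRandomSampler is implemented: it uses the standard library's random module instead of a PyTorch generator, and the helpers epoch_order and remaining_indices, as well as the choice of seed + epoch as the effective seed, are hypothetical.

```python
import random


def epoch_order(num_samples: int, seed: int, epoch: int) -> list[int]:
    """Return a shuffled index order that depends only on (seed, epoch)."""
    # Assumption: mixing the epoch into the seed gives each epoch its own order.
    rng = random.Random(seed + epoch)
    order = list(range(num_samples))
    rng.shuffle(order)
    return order


def remaining_indices(
    num_samples: int, seed: int, epoch: int, start_iter: int, batch_size: int
) -> list[int]:
    """Rebuild the epoch's order and skip the samples already consumed."""
    return epoch_order(num_samples, seed, epoch)[start_iter * batch_size :]


uninterrupted = epoch_order(num_samples=8, seed=0, epoch=3)
# Restart on the same epoch after two batches of two samples were consumed.
after_restart = remaining_indices(
    num_samples=8, seed=0, epoch=3, start_iter=2, batch_size=2
)
```

Because the order is rebuilt from the seed and the epoch alone, only those two values and the iteration count would need to be saved in a checkpoint under this scheme.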
- class helios.data.samplers.ResumableSequentialSampler(data_source: Sized, batch_size: int = 1)
Bases: ResumableSampler
Sequential sampler with resumable state.
This allows training to stop and resume while guaranteeing that the order in which the batches are returned stays consistent. It is effectively a replacement for the default SequentialSampler from PyTorch.
- Parameters:
data_source – the dataset to sample from.
batch_size – the number of samples per batch.
- class helios.data.samplers.ResumableDistributedSampler(dataset: torch.utils.data.Dataset, num_replicas: int | None = None, rank: int | None = None, shuffle: bool = True, seed: int = 0, drop_last: bool = False, batch_size: int = 1)
Bases: torch.utils.data.DistributedSampler
Distributed sampler with resumable state.
This allows training to stop and resume while guaranteeing that the order in which the batches are returned stays consistent. It is effectively a replacement for the default DistributedSampler from PyTorch.
- Parameters:
dataset – the dataset to sample from.
num_replicas – number of processes for distributed training.
rank – (optional) rank of the current process.
shuffle – if true, shuffle the indices. Defaults to true.
seed – random seed used to shuffle the sampler. Defaults to 0.
drop_last – if true, drop the tail of the data so that it is evenly divisible across replicas. Defaults to false.
batch_size – the number of samples per batch. Defaults to 1.
- property start_iter: int
The starting iteration for the sampler.
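To make the roles of num_replicas, rank and drop_last concrete, the following plain-Python sketch splits a range of indices across replicas without shuffling. It follows the partitioning that PyTorch's DistributedSampler performs. Whether ResumableDistributedSampler changes any of this is not stated here, so treat that as an assumption. The helper shard_indices is hypothetical.

```python
import math


def shard_indices(
    num_samples: int, num_replicas: int, rank: int, drop_last: bool = False
) -> list[int]:
    """Return the indices one replica would see, without shuffling."""
    indices = list(range(num_samples))
    if drop_last and num_samples % num_replicas != 0:
        # Drop the tail so that every replica gets the same number of samples.
        per_replica = math.ceil((num_samples - num_replicas) / num_replicas)
    else:
        per_replica = math.ceil(num_samples / num_replicas)
    total = per_replica * num_replicas

    if drop_last:
        indices = indices[:total]
    else:
        # Pad by repeating indices from the start until the split is even.
        padding = total - len(indices)
        indices += [indices[k % len(indices)] for k in range(padding)]

    # Replica `rank` takes every num_replicas-th index, starting at `rank`.
    return indices[rank:total:num_replicas]


padded = [shard_indices(10, num_replicas=4, rank=r) for r in range(4)]
dropped = [shard_indices(10, num_replicas=4, rank=r, drop_last=True) for r in range(4)]
```

With 10 samples and 4 replicas, the padded split gives each replica 3 indices (two of them repeated from the start), while the drop_last split gives each replica 2 indices and leaves the last two samples unused.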
- helios.data.samplers.ResumableSamplerType
Defines the resumable sampler type.
A resumable sampler must be derived from either ResumableSampler or ResumableDistributedSampler.