# Source code for helios.core.rng

from __future__ import annotations

import random
import typing
from collections import abc

import torch
from numpy import random as npr

_DEFAULT_RNG_SEED = 6691


def get_default_seed() -> int:
    """
    Get the seed Helios uses when the caller does not supply one.

    Returns:
        The module-wide default RNG seed.
    """
    default_seed: int = _DEFAULT_RNG_SEED
    return default_seed


class DefaultNumpyRNG:
    """
    Default RNG from Numpy.

    This is intended to serve as a replacement for the legacy random API from Numpy.
    The class wraps the new Generator instance, which is set to use PCG64.
    Functionality similar to the modules for PyTorch is provided for easy
    serialization and restoring.

    Args:
        seed: (optional) the initial seed to use.
    """

    def __init__(self, seed: int | list[int] | tuple[int, ...] | None = None):
        """Create the default RNG from ``seed`` via ``numpy.random.default_rng``."""
        self._generator = npr.default_rng(seed)

    @property
    def generator(self) -> npr.Generator:
        """Return the Numpy Generator instance."""
        return self._generator

    def state_dict(self) -> abc.Mapping[str, typing.Any]:
        """
        Create a dictionary containing the RNG state.

        Returns:
            The state of the underlying bit generator.
        """
        return self._generator.bit_generator.state

    def load_state_dict(self, state_dict: abc.Mapping[str, typing.Any]) -> None:
        """
        Restore the RNG from the given state dictionary.

        Args:
            state_dict: the state dictionary, as produced by :py:meth:`state_dict`.
                Any mapping is accepted.
        """
        # Numpy's bit-generator ``state`` setter only accepts a plain ``dict`` and
        # raises ``TypeError`` otherwise. Copy the top level so that every Mapping
        # allowed by the annotation (e.g. a read-only proxy) can be restored.
        self._generator.bit_generator.state = dict(state_dict)


_DEFAULT_RNG: DefaultNumpyRNG | None = None


def _get_safe_default_rng() -> DefaultNumpyRNG:
    """
    Return the module-level default Numpy RNG, failing loudly if it is missing.

    Returns:
        The default RNG stored in ``_DEFAULT_RNG``.

    Raises:
        RuntimeError: if the default RNG has not been created yet.
    """
    # The module global is only read here, so no ``global`` statement is needed.
    if _DEFAULT_RNG is None:
        raise RuntimeError(
            "error: default RNG has not been created. Did you forget to call "
            "create_default_numpy_rng?"
        )
    return _DEFAULT_RNG


def create_default_numpy_rng(
    seed: int | list[int] | tuple[int, ...] | None = None,
) -> None:
    """
    Initialize the default RNG with the given seed.

    Any previously created default RNG is replaced by a new instance.

    Args:
        seed: (optional) the seed to use, forwarded to :py:class:`DefaultNumpyRNG`.
    """
    global _DEFAULT_RNG
    _DEFAULT_RNG = DefaultNumpyRNG(seed=seed)


def get_default_numpy_rng() -> DefaultNumpyRNG:
    """
    Fetch the default Numpy RNG.

    Returns:
        The random generator.

    Raises:
        RuntimeError: if the default Numpy RNG hasn't been created.
    """
    default_rng = _get_safe_default_rng()
    return default_rng


def seed_rngs(seed: int | None = None, skip_torch: bool = False) -> None:
    """
    Seed the default RNGs with the given seed.

    When ``seed`` is None, the value from :py:func:`get_default_seed` is used. The
    seeded RNGs are: PyTorch (+ CUDA if available), stdlib random, and the default
    Numpy generator (which is re-created).

    The ``skip_torch`` flag is intended to be used when seeding worker processes for
    dataloaders. In those cases, the RNGs for PyTorch have already been seeded, so
    we shouldn't be re-seeding them.

    Args:
        seed: optional value to seed the random generators with.
        skip_torch: if True, torch RNGs won't be seeded.
    """
    if seed is None:
        chosen_seed = get_default_seed()
    else:
        chosen_seed = seed

    # PyTorch goes first; worker processes opt out because torch is already seeded.
    if not skip_torch:
        torch.manual_seed(chosen_seed)
        if torch.cuda.is_available():
            torch.cuda.manual_seed_all(chosen_seed)

    random.seed(chosen_seed)
    create_default_numpy_rng(chosen_seed)


def get_rng_state_dict() -> dict[str, typing.Any]:
    """
    Get the state dict for the default RNGs.

    Default RNGs are: PyTorch (+ CUDA if available), stdlib random, and the default
    Numpy generator.

    Returns:
        The state of all RNGs, keyed by ``"torch"``, ``"rand"``, ``"numpy"`` and,
        when CUDA is available, ``"cuda"``.

    Raises:
        RuntimeError: if the default Numpy RNG hasn't been created.
    """
    state: dict[str, typing.Any] = {
        "torch": torch.get_rng_state(),
        "rand": random.getstate(),
        "numpy": get_default_numpy_rng().state_dict(),
    }
    if torch.cuda.is_available():
        # NOTE(review): called without a device argument, this reads only the
        # current CUDA device's generator; other GPUs' states are not captured.
        state["cuda"] = torch.cuda.get_rng_state()
    return state


def load_rng_state_dict(state_dict: dict[str, typing.Any]) -> None:
    """
    Restore the default RNGs from the given state dict.

    See :py:func:`.get_rng_state_dict` for the list of default RNGs. The CUDA state
    is restored only when CUDA is available and the dict contains a ``"cuda"`` entry.

    Args:
        state_dict: the state of the RNGs.

    Raises:
        KeyError: if ``"torch"``, ``"rand"`` or ``"numpy"`` is missing.
        RuntimeError: if the default Numpy RNG hasn't been created.
    """
    torch.set_rng_state(state_dict["torch"])
    random.setstate(state_dict["rand"])

    numpy_rng = get_default_numpy_rng()
    numpy_rng.load_state_dict(state_dict["numpy"])

    if torch.cuda.is_available():
        if "cuda" in state_dict:
            torch.cuda.set_rng_state(state_dict["cuda"])