def get_default_seed() -> int:
    """
    Get the seed that is used when the caller does not supply one.

    Returns:
        The value of the module-level ``_DEFAULT_RNG_SEED`` constant.
    """
    default_seed: int = _DEFAULT_RNG_SEED
    return default_seed
[docs]classDefaultNumpyRNG:""" Default RNG from Numpy. This is intended to serve as a replacement for the legacy random API from Numpy. The class wraps the new Generator instance, which is set to use PCG64. Functionality similar to the modules for PyTorch is provided for easy serialization and restoring. Args: seed: (optional) the initial seed to use. """def__init__(self,seed:int|list[int]|tuple[int]|None=None):"""Create the default RNG."""self._generator=npr.default_rng(seed)@propertydefgenerator(self)->npr.Generator:"""Return the Numpy Generator instance."""returnself._generator
[docs]defstate_dict(self)->abc.Mapping[str,typing.Any]:""" Create a dictionary containing the RNG state. Returns: The state of the RNG. """returnself._generator.bit_generator.state
[docs]defload_state_dict(self,state_dict:abc.Mapping[str,typing.Any])->None:""" Restore the RNG from the given state dictionary. Args: state_dict: the state dictionary. """self._generator.bit_generator.state=state_dict
# Module-wide default Numpy RNG. Stays None until create_default_numpy_rng() assigns it.
_DEFAULT_RNG: DefaultNumpyRNG | None = None


def _get_safe_default_rng() -> DefaultNumpyRNG:
    """
    Return the module-wide default RNG, failing loudly if it does not exist yet.

    Returns:
        The current default RNG instance.

    Raises:
        RuntimeError: if the default RNG has not been created.
    """
    # Only reading the global, so no ``global`` declaration is required.
    if _DEFAULT_RNG is None:
        # Point the user at the function that actually exists in this module.
        raise RuntimeError(
            "error: default RNG has not been created. Did you forget to call "
            "create_default_numpy_rng?"
        )
    return _DEFAULT_RNG
def create_default_numpy_rng(seed: int | list[int] | tuple[int] | None = None):
    """
    (Re)create the module-wide default Numpy RNG.

    Any previously created default RNG is replaced by a new ``DefaultNumpyRNG``.

    Args:
        seed: (optional) the seed passed to the new ``DefaultNumpyRNG``.
    """
    global _DEFAULT_RNG

    fresh_rng = DefaultNumpyRNG(seed=seed)
    _DEFAULT_RNG = fresh_rng
def get_default_numpy_rng() -> DefaultNumpyRNG:
    """
    Fetch the module-wide default Numpy RNG.

    Returns:
        The default random generator wrapper.

    Raises:
        RuntimeError: if the default Numpy RNG hasn't been created.
    """
    default_rng = _get_safe_default_rng()
    return default_rng
def seed_rngs(seed: int | None = None, skip_torch: bool = False) -> None:
    """
    Seed every default RNG with a single value.

    When ``seed`` is omitted, the value returned by ``get_default_seed`` is used. The
    generators that get seeded are: PyTorch (plus CUDA when available), the stdlib
    ``random`` module, and the default Numpy generator (which is re-created).

    ``skip_torch`` exists for seeding dataloader worker processes: there the PyTorch
    RNGs have already been seeded, so they must not be seeded again.

    Args:
        seed: optional value to seed the random generators with.
        skip_torch: if True, the torch (and CUDA) RNGs are left untouched.
    """
    if seed is None:
        seed = get_default_seed()

    seed_torch = not skip_torch
    if seed_torch:
        torch.manual_seed(seed)
        if torch.cuda.is_available():
            torch.cuda.manual_seed_all(seed)

    # These two are always seeded, regardless of ``skip_torch``.
    random.seed(seed)
    create_default_numpy_rng(seed)
def get_rng_state_dict() -> dict[str, typing.Any]:
    """
    Collect the state of all the default RNGs into one dictionary.

    The dictionary holds the PyTorch state under ``"torch"``, the stdlib ``random``
    state under ``"rand"``, and the default Numpy generator state under ``"numpy"``.
    When CUDA is available, the result of ``torch.cuda.get_rng_state()`` is added
    under ``"cuda"``.

    Returns:
        The state of all RNGs.

    Raises:
        RuntimeError: if the default Numpy RNG hasn't been created.
    """
    rng_states: dict[str, typing.Any] = {}
    rng_states["torch"] = torch.get_rng_state()
    rng_states["rand"] = random.getstate()
    rng_states["numpy"] = get_default_numpy_rng().state_dict()

    # The CUDA entry is optional and only present on machines where CUDA is usable.
    if torch.cuda.is_available():
        rng_states["cuda"] = torch.cuda.get_rng_state()

    return rng_states
def load_rng_state_dict(state_dict: dict[str, typing.Any]) -> None:
    """
    Restore the default RNGs from a dictionary of saved states.

    The ``"torch"``, ``"rand"`` and ``"numpy"`` entries are required. The ``"cuda"``
    entry is applied only when CUDA is available and the entry is present. See
    :py:func:`.get_rng_state_dict` for how the dictionary is produced.

    Args:
        state_dict: the state of the RNGs.

    Raises:
        RuntimeError: if the default Numpy RNG hasn't been created.
    """
    torch.set_rng_state(state_dict["torch"])
    random.setstate(state_dict["rand"])

    numpy_rng = get_default_numpy_rng()
    numpy_rng.load_state_dict(state_dict["numpy"])

    # CUDA state is restored only if this machine has CUDA and a state was saved.
    if torch.cuda.is_available():
        if "cuda" in state_dict:
            torch.cuda.set_rng_state(state_dict["cuda"])