helios.nn.utils
Attributes
NETWORK_REGISTRY | Global instance of the registry for networks.
Functions
create_network(type_name, *args, **kwargs) | Create the network for the given type.
default_init_weights(module_list[, scale, bias_fill]) | Initialize network weights.
Module Contents
- helios.nn.utils.NETWORK_REGISTRY
Global instance of the registry for networks.
Example
    import helios.nn as hln

    # This automatically registers your network.
    @hln.NETWORK_REGISTRY.register
    class MyNetwork:
        ...

    # Alternatively, you can manually register a network like this:
    hln.NETWORK_REGISTRY.register(MyNetwork)
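A decorator-based registry typically works because the decorator receives the class object when the `class` statement executes, stores it under a key, and returns the class unchanged so the name is still bound as usual. The following is a minimal sketch of that pattern, assuming a registry keyed by each class's `__name__`; `SimpleRegistry` and `NETWORK_REGISTRY_SKETCH` are hypothetical names, and this is not Helios's implementation. Rejecting duplicate names is a design choice of the sketch only.

```python
from typing import Dict, TypeVar

T = TypeVar("T", bound=type)


class SimpleRegistry:
    """Hypothetical name -> class registry (illustration only, not Helios's class)."""

    def __init__(self, name: str) -> None:
        self._name = name
        self._entries: Dict[str, type] = {}

    def register(self, cls: T) -> T:
        # Key each class by its __name__ and refuse duplicate names.
        key = cls.__name__
        if key in self._entries:
            raise KeyError(f"{key!r} is already registered in {self._name}")
        self._entries[key] = cls
        # Return the class unchanged so register() also works as a decorator.
        return cls

    def get(self, key: str) -> type:
        if key not in self._entries:
            raise KeyError(f"no entry named {key!r} in {self._name}")
        return self._entries[key]

    def __contains__(self, key: str) -> bool:
        return key in self._entries


NETWORK_REGISTRY_SKETCH = SimpleRegistry("network")


# Decorator form: registration runs when the class statement executes.
@NETWORK_REGISTRY_SKETCH.register
class MyNetwork:
    pass


# Manual form, e.g. for a class defined in code you cannot decorate.
class OtherNetwork:
    pass


NETWORK_REGISTRY_SKETCH.register(OtherNetwork)
```

Because registration in this pattern happens at class-definition time, a network defined in a module that is never imported cannot be found by name.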
- helios.nn.utils.create_network(type_name: str, *args: Any, **kwargs: Any) -> torch.nn.Module
Create the network for the given type.
- Parameters:
type_name – the name of the registered network type to create.
args – positional arguments to pass to the network's constructor.
kwargs – keyword arguments to pass to the network's constructor.
- Returns:
The newly constructed network.
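Assuming `create_network` looks `type_name` up in `NETWORK_REGISTRY` and calls the resulting class with `args` and `kwargs` (this page implies but does not state it), a call reduces to a name lookup followed by a constructor call. The sketch below uses plain-Python stand-ins so it runs without PyTorch or Helios; `TinyNet`, `_REGISTRY`, and `create_network_sketch` are hypothetical names, and the real function returns a `torch.nn.Module`.

```python
from typing import Any, Dict


class TinyNet:
    """Hypothetical stand-in for a registered network class."""

    def __init__(self, in_features: int, hidden: int = 8) -> None:
        self.in_features = in_features
        self.hidden = hidden


# Hypothetical registry contents: type name -> class.
_REGISTRY: Dict[str, type] = {TinyNet.__name__: TinyNet}


def create_network_sketch(type_name: str, *args: Any, **kwargs: Any) -> Any:
    """Look up type_name and forward all remaining arguments to its constructor."""
    try:
        cls = _REGISTRY[type_name]
    except KeyError:
        raise KeyError(f"unknown network type {type_name!r}") from None
    return cls(*args, **kwargs)


# Equivalent to TinyNet(16, hidden=32), but the type is chosen by a string,
# for example one read from a configuration file.
net = create_network_sketch("TinyNet", 16, hidden=32)
```

Selecting the class by string is what lets a configuration name its network without importing the class directly.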
- helios.nn.utils.default_init_weights(module_list: list[torch.nn.Module] | torch.nn.Module, scale: float = 1, bias_fill: float = 0, **kwargs: Any) -> None
Initialize network weights.
Specifically, this function applies default initialization to modules of the following types: torch.nn.Conv2d, torch.nn.Linear, and torch.nn.modules.batchnorm._BatchNorm (the base class of the batch-normalization layers).
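`_BatchNorm` is a private PyTorch base class, so it may not be obvious which public layers it covers. Assuming the function selects blocks with `isinstance` checks against the three listed types (an assumption; the matching logic is not documented here), the sketch below shows which common layers would and would not match. `COVERED`, `candidates`, and `matched` are illustrative names only.

```python
import torch.nn as nn
from torch.nn.modules.batchnorm import _BatchNorm

# The three block types listed in the documentation.
COVERED = (nn.Conv2d, nn.Linear, _BatchNorm)

candidates = {
    "Conv2d": nn.Conv2d(3, 8, kernel_size=3),
    "Linear": nn.Linear(4, 2),
    "BatchNorm1d": nn.BatchNorm1d(4),
    "BatchNorm2d": nn.BatchNorm2d(8),
    "Conv1d": nn.Conv1d(3, 8, kernel_size=3),
    "ConvTranspose2d": nn.ConvTranspose2d(3, 8, kernel_size=3),
    "LayerNorm": nn.LayerNorm(4),
    "GroupNorm": nn.GroupNorm(2, 8),
}

# Which layers an isinstance(module, COVERED) check would match.
matched = {name: isinstance(m, COVERED) for name, m in candidates.items()}
for name, hit in matched.items():
    print(f"{name:16s} {'covered' if hit else 'not covered'}")
```

Under this assumption, batch-norm layers of every dimensionality are covered through their shared base class, while `Conv1d`, `ConvTranspose2d`, `LayerNorm`, and `GroupNorm` are not subclasses of any listed type and would keep PyTorch's own defaults.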
- Parameters:
module_list – a module, or list of modules, to initialize.
scale – factor by which the initialized weights are scaled, mainly useful for residual blocks. Defaults to 1.
bias_fill – value used to fill the biases. Defaults to 0.
kwargs – keyword arguments forwarded to torch.nn.init.kaiming_normal_.
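A minimal sketch of what such an initializer might look like, assuming (not confirmed by this page) that it walks each module's submodules, applies `kaiming_normal_` to `Conv2d` and `Linear` weights and then multiplies them by `scale`, resets batch-norm weights to 1, and fills every bias with `bias_fill`. With `kaiming_normal_`'s defaults (`a=0`, `mode="fan_in"`, `nonlinearity="leaky_relu"`), weights are drawn from a normal distribution with standard deviation sqrt(2 / fan_in), so after scaling the standard deviation becomes scale · sqrt(2 / fan_in). For a 3×3 convolution with 64 input channels, fan_in = 64 · 3 · 3 = 576. `init_weights_sketch` is a hypothetical name.

```python
import math

import torch
import torch.nn as nn
from torch.nn.modules.batchnorm import _BatchNorm


@torch.no_grad()
def init_weights_sketch(module_list, scale=1.0, bias_fill=0.0, **kwargs):
    """Hypothetical initializer mirroring the documented parameters (not Helios's source)."""
    # Accept a single module as well as a list, matching the documented signature.
    if not isinstance(module_list, list):
        module_list = [module_list]
    for module in module_list:
        # modules() yields the module itself and every nested submodule.
        for m in module.modules():
            if isinstance(m, (nn.Conv2d, nn.Linear)):
                nn.init.kaiming_normal_(m.weight, **kwargs)
                m.weight.mul_(scale)  # assumed: scale multiplies the weights
                if m.bias is not None:
                    m.bias.fill_(bias_fill)
            elif isinstance(m, _BatchNorm):
                # assumed: the batch-norm scale (weight) is reset to 1
                if m.weight is not None:
                    nn.init.constant_(m.weight, 1.0)
                if m.bias is not None:
                    m.bias.fill_(bias_fill)


torch.manual_seed(0)
body = nn.Sequential(nn.Conv2d(64, 64, 3, padding=1), nn.BatchNorm2d(64))
init_weights_sketch(body, scale=0.1)

conv = body[0]
fan_in = 64 * 3 * 3  # in_channels * kernel_height * kernel_width
expected_std = 0.1 * math.sqrt(2.0 / fan_in)
print(f"empirical std {conv.weight.std().item():.5f}, expected {expected_std:.5f}")
```

A small `scale` such as 0.1 shrinks the initial output of a residual branch, so each residual block tends to start close to the identity mapping.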