helios.nn.utils

Attributes

NETWORK_REGISTRY

Global instance of the registry for networks.

Functions

create_network(→ torch.nn.Module)

Create the network for the given type.

default_init_weights(→ None)

Initialize network weights.

Module Contents

helios.nn.utils.NETWORK_REGISTRY

Global instance of the registry for networks.

Example

import torch
import helios.nn as hln

# Option 1: the decorator automatically registers your network.
@hln.NETWORK_REGISTRY.register
class MyNetwork(torch.nn.Module):
    ...

# Option 2: alternatively, you can register a network manually.
class MyOtherNetwork(torch.nn.Module):
    ...

hln.NETWORK_REGISTRY.register(MyOtherNetwork)
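
To make the registration pattern concrete, here is a minimal, self-contained sketch of a name-to-class registry. `SimpleRegistry` and `NETWORKS` are hypothetical names, and keying entries by class name (and rejecting duplicates) is an assumption for illustration; this is not Helios's actual implementation.

```python
from typing import Any, TypeVar

T = TypeVar("T")


class SimpleRegistry:
    """Hypothetical minimal registry mapping class names to classes."""

    def __init__(self, name: str) -> None:
        self._name = name
        self._obj_map: dict[str, Any] = {}

    def register(self, obj: type[T]) -> type[T]:
        # Assumption: the class name is the lookup key.
        key = obj.__name__
        if key in self._obj_map:
            raise KeyError(f"{key} is already registered in {self._name}")
        self._obj_map[key] = obj
        # Returning the class lets register() double as a class decorator.
        return obj

    def get(self, key: str) -> Any:
        if key not in self._obj_map:
            raise KeyError(f"{key} is not registered in {self._name}")
        return self._obj_map[key]

    def __contains__(self, key: str) -> bool:
        return key in self._obj_map


NETWORKS = SimpleRegistry("network")


@NETWORKS.register  # decorator form
class MyNetwork:
    pass


class MyOtherNetwork:
    pass


NETWORKS.register(MyOtherNetwork)  # manual form
```

Because `register` returns its argument unchanged, the same method serves both as a decorator and as a plain function call.
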

helios.nn.utils.create_network(type_name: str, *args: Any, **kwargs: Any) → torch.nn.Module

Create the network for the given type.

Parameters:
  • type_name – the registered type name of the network to create.

  • args – positional arguments forwarded to the network's constructor.

  • kwargs – keyword arguments forwarded to the network's constructor.

Returns:

The newly constructed network.
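
Conceptually, creating a network amounts to looking up the type name in the registry and calling the class with the supplied arguments. The sketch below assumes a plain dict keyed by class name in place of `NETWORK_REGISTRY`; `_REGISTRY`, `register`, `create_network_sketch`, and `TinyNet` are hypothetical, and `TinyNet` is a plain class (rather than a `torch.nn.Module`) only to keep the example dependency-free.

```python
from typing import Any

# Hypothetical stand-in for NETWORK_REGISTRY: a plain dict keyed by class name.
_REGISTRY: dict[str, type] = {}


def register(cls: type) -> type:
    _REGISTRY[cls.__name__] = cls
    return cls


def create_network_sketch(type_name: str, *args: Any, **kwargs: Any) -> Any:
    """Look up `type_name` and instantiate it with the given arguments."""
    if type_name not in _REGISTRY:
        raise KeyError(f"no network registered under '{type_name}'")
    # Positional and keyword arguments go straight to the constructor.
    return _REGISTRY[type_name](*args, **kwargs)


@register
class TinyNet:
    def __init__(self, in_channels: int, hidden: int = 16) -> None:
        self.in_channels = in_channels
        self.hidden = hidden


net = create_network_sketch("TinyNet", 3, hidden=32)
```

Here `3` reaches `in_channels` positionally and `hidden=32` is passed by keyword, mirroring how `args` and `kwargs` are described above.
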

helios.nn.utils.default_init_weights(module_list: list[torch.nn.Module] | torch.nn.Module, scale: float = 1, bias_fill: float = 0, **kwargs: Any) → None

Initialize network weights.

Specifically, this function applies default initialization to modules of the following types:

  • torch.nn.Conv2d

  • torch.nn.Linear

  • torch.nn.modules.batchnorm._BatchNorm

Parameters:
  • module_list – a single module or a list of modules to initialize.

  • scale – factor by which the initialized weights are multiplied; a value below 1 is useful, for example, for residual blocks. Defaults to 1.

  • bias_fill – value used to fill the biases. Defaults to 0.

  • kwargs – keyword arguments for the torch.nn.init.kaiming_normal_ function.
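
The behavior described above can be sketched as follows, assuming a scheme in the style of BasicSR's function of the same name: Kaiming-normal weights multiplied by `scale` for convolution and linear layers, batch-norm weights set to 1, and all biases filled with `bias_fill`. The name `init_weights_sketch`, the recursive traversal of submodules, and the batch-norm weight value of 1 are assumptions; Helios's actual implementation may differ in details.

```python
from __future__ import annotations

from typing import Any

import torch
from torch import nn
from torch.nn import init
from torch.nn.modules.batchnorm import _BatchNorm


@torch.no_grad()  # in-place edits of parameters should not be tracked by autograd
def init_weights_sketch(
    module_list: list[nn.Module] | nn.Module,
    scale: float = 1,
    bias_fill: float = 0,
    **kwargs: Any,
) -> None:
    if isinstance(module_list, nn.Module):
        module_list = [module_list]
    for module in module_list:
        # Assumption: modules() walks the module itself and all submodules.
        for m in module.modules():
            if isinstance(m, (nn.Conv2d, nn.Linear)):
                # Kaiming (He) normal init; kwargs such as `mode` are forwarded.
                init.kaiming_normal_(m.weight, **kwargs)
                m.weight.mul_(scale)
                if m.bias is not None:
                    m.bias.fill_(bias_fill)
            elif isinstance(m, _BatchNorm):
                # Assumption: affine weights start at 1 (identity scaling).
                if m.weight is not None:
                    init.constant_(m.weight, 1)
                if m.bias is not None:
                    m.bias.fill_(bias_fill)


model = nn.Sequential(nn.Conv2d(3, 8, 3), nn.BatchNorm2d(8), nn.Linear(8, 4))
init_weights_sketch(model, scale=0.1, nonlinearity="relu")
```

Passing `nonlinearity="relu"` illustrates how extra keyword arguments reach `torch.nn.init.kaiming_normal_`; a small `scale` such as 0.1 shrinks the initial weights, which keeps residual branches close to the identity at the start of training.
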