importtypingimporttorchfromtorchimportnnfromtorch.nnimportinitfromtorch.nn.modulesimportbatchnormasbnfromheliosimportcoreNETWORK_REGISTRY=core.Registry("network")"""Global instance of the registry for networks.Example: .. code-block:: python import helios.nn as hln # This automatically registers your network. @hln.NETWORK_REGISTRY.register class MyNetwork: ... # Alternatively you can manually register a network like this: hln.NETWORK_REGISTRY.register(MyNetwork)"""
[docs]defcreate_network(type_name:str,*args:typing.Any,**kwargs:typing.Any)->nn.Module:""" Create the network for the given type. Args: type_name: the type of the network to create. args: positional arguments to pass into the network. kwargs: keyword arguments to pass into the network. Returns: The network. """returnNETWORK_REGISTRY.get(type_name)(*args,**kwargs)
[docs]@torch.no_grad()defdefault_init_weights(module_list:list[nn.Module]|nn.Module,scale:float=1,bias_fill:float=0,**kwargs:typing.Any,)->None:""" Initialize network weights. Specifically, this function will default initialize the following types of blocks: * ``torch.nn.Conv2d``, * ``torch.nn.Linear``, * ``torch.nn.modules.batchnorm._BatchNorm`` Args: module_list: the list of modules to initialize. scale: scale initialized weights, especially for residual blocks. Defaults to 1. bias_fill: bias fill value. Defaults to 0. kwargs: keyword arguments for the ``torch.nn.init.kaiming_normal_`` function. """ifnotisinstance(module_list,list):module_list=[module_list]formoduleinmodule_list:forminmodule.modules():ifisinstance(m,nn.Conv2d|nn.Linear):init.kaiming_normal_(m.weight,**kwargs)m.weight.data*=scaleifm.biasisnotNone:m.bias.data.fill_(bias_fill)elifisinstance(m,bn._BatchNorm):# noqa: SLF001init.constant_(m.weight,1)ifm.biasisnotNone:m.bias.data.fill_(bias_fill)