helios.nn.swa_utils
===================

.. py:module:: helios.nn.swa_utils


Classes
-------

.. autoapisummary::

   helios.nn.swa_utils.EMA


Module Contents
---------------

.. py:class:: EMA(net: torch.nn.Module, decay: float = 0.9997, device: torch.device | None = None)

   Bases: :py:obj:`torch.nn.Module`

   Implements an exponential moving average (EMA) of a network's weights.

   :param net: the bare (unwrapped) network whose weights will be averaged.
   :param decay: the decay rate of the moving average. Defaults to 0.9997.
   :param device: (optional) the device to be used. Defaults to ``None``.


   .. py:property:: module
      :type: torch.nn.Module

      Get the underlying network.


   .. py:method:: update(net: torch.nn.Module) -> None

      Update the averaged weights using EMA from the given network.


   .. py:method:: set(net: torch.nn.Module) -> None

      Reset the base weights from the given network.


   .. py:method:: forward(*args: Any, **kwargs: Any)

      Evaluate the EMA wrapper on the network inputs.

      :param args: positional arguments for your network's forward function.
      :param kwargs: keyword arguments for your network's forward function.
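The page does not spell out the arithmetic behind ``update``. The sketch below assumes the common EMA rule, in which each averaged weight is blended as ``avg = decay * avg + (1 - decay) * current``. It uses a plain dict of named floats as a stand-in for ``torch.nn.Module`` parameters, and the class ``SimpleEMA`` is hypothetical: it is not part of ``helios``, and its ``set`` behaviour (copying the given weights) is an assumption about what "reset the base weights" means.

```python
from typing import Dict


class SimpleEMA:
    """Hypothetical EMA over named float weights (a stand-in for module parameters)."""

    def __init__(self, weights: Dict[str, float], decay: float = 0.9997) -> None:
        if not 0.0 <= decay <= 1.0:
            raise ValueError("decay must be in [0, 1]")
        self.decay = decay
        # Start the running average from a copy of the initial weights.
        self.averaged = dict(weights)

    def update(self, weights: Dict[str, float]) -> None:
        # Blend each averaged weight toward the current value of that weight.
        for name, value in weights.items():
            self.averaged[name] = self.decay * self.averaged[name] + (1.0 - self.decay) * value

    def set(self, weights: Dict[str, float]) -> None:
        # Assumed behaviour: discard the running average and copy the given weights.
        self.averaged = dict(weights)


# Usage with a small decay so the effect is visible after two steps.
ema = SimpleEMA({"w": 0.0}, decay=0.5)
ema.update({"w": 1.0})  # 0.5 * 0.0 + 0.5 * 1.0 = 0.5
ema.update({"w": 1.0})  # 0.5 * 0.5 + 0.5 * 1.0 = 0.75
print(ema.averaged["w"])  # 0.75
```

As a rough rule of thumb, a decay of ``d`` averages over about ``1 / (1 - d)`` recent updates, so the default of 0.9997 corresponds to roughly 3,333 steps. In PyTorch the same rule is usually applied to each parameter tensor under ``torch.no_grad()``; whether ``EMA`` does exactly this is not stated on this page.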