helios.nn.swa_utils
Classes
- EMA: Implements Exponential Moving Average (EMA).
Module Contents
- class helios.nn.swa_utils.EMA(net: torch.nn.Module, decay: float = 0.9997, device: torch.device | None = None)
Bases: torch.nn.Module
Implements Exponential Moving Average (EMA).
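For reference, the conventional EMA update is shown below, with d the decay argument, θ_t the live network parameters at update t, and θ^EMA the averaged copy. This is the textbook form; whether helios applies exactly this rule (for example, whether buffers are averaged as well as parameters) is an assumption not confirmed by this page.

```latex
\theta^{\mathrm{EMA}}_{t} = d\,\theta^{\mathrm{EMA}}_{t-1} + (1 - d)\,\theta_{t},
\qquad
\theta^{\mathrm{EMA}}_{t} = (1 - d)\sum_{k=0}^{t-1} d^{k}\,\theta_{t-k} + d^{t}\,\theta^{\mathrm{EMA}}_{0}
```

Unrolling the recursion shows that parameters from k updates ago carry weight (1 - d)d^k, so the average effectively spans about 1/(1 - d) updates.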
- Parameters:
net – the bare network on which EMA will be performed.
decay – decay rate of the moving average; values closer to 1 make the average change more slowly. Defaults to 0.9997.
device – (optional) the device to be used.
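As a rough illustration of how the decay argument behaves, the following stdlib-only sketch applies the conventional EMA rule to plain floats instead of a network's tensors. The helpers `ema_update` and `effective_horizon` are hypothetical and not part of helios; the sketch assumes helios uses the standard form in which decay weights the previous average.

```python
# Hypothetical sketch (not helios code): the conventional EMA rule on floats.

def ema_update(average: float, value: float, decay: float = 0.9997) -> float:
    """Blend a new value into the running average.

    Assumes the standard convention: `decay` is the weight kept by the
    previous average, and (1 - decay) is the weight given to the new value.
    """
    return decay * average + (1.0 - decay) * value


def effective_horizon(decay: float) -> float:
    """Approximate number of recent updates the average reflects: 1 / (1 - decay)."""
    return 1.0 / (1.0 - decay)


if __name__ == "__main__":
    # Track a constant signal of 1.0 starting from 0.0 with a fast decay of 0.5.
    avg = 0.0
    for _ in range(10):
        avg = ema_update(avg, 1.0, decay=0.5)
    print(f"{avg:.6f}")  # 1 - 0.5**10, prints 0.999023

    # With the default decay, the average spans roughly 3,333 updates.
    print(round(effective_horizon(0.9997)))  # prints 3333
```

With decay=0.9997, changes to the live weights therefore show up in the averaged copy only gradually, over a few thousand updates.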
- property module: torch.nn.Module
Get the underlying network.