helios.nn.swa_utils
Classes
EMA – Implements Exponential Moving Average (EMA).
Module Contents
- class helios.nn.swa_utils.EMA(net: torch.nn.Module, decay: float = 0.9997, device: torch.device | None = None)
Bases: torch.nn.Module
Implements Exponential Moving Average (EMA).
- Parameters:
net – the bare network on which EMA will be performed.
decay – decay rate. Defaults to 0.9997.
device – (optional) the device to be used.
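To give a feel for the default decay of 0.9997, the sketch below computes two standard EMA quantities. It assumes helios applies the conventional update ema ← decay · ema + (1 − decay) · param once per `update` call; the helper names `ema_half_life` and `ema_time_constant` are illustrative and not part of helios.

```python
import math


def ema_half_life(decay: float) -> float:
    """Number of further updates after which one update's weight in the EMA halves.

    Assumes the conventional rule: ema <- decay * ema + (1 - decay) * x.
    """
    return math.log(0.5) / math.log(decay)


def ema_time_constant(decay: float) -> float:
    """Rough width of the averaging window, in updates: 1 / (1 - decay)."""
    return 1.0 / (1.0 - decay)


decay = 0.9997  # default decay documented for helios.nn.swa_utils.EMA
print(f"half-life:     {ema_half_life(decay):.0f} updates")
print(f"time constant: {ema_time_constant(decay):.0f} updates")
```

Under that assumption, a single update's contribution loses half its weight after roughly 2,300 further updates, and the average effectively spans a few thousand steps.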
- property module: torch.nn.Module
Get the underlying network.
- update(net: torch.nn.Module) → None
Update the weights using EMA from the given network.
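The sketch below shows what one EMA update step conventionally computes. It assumes helios uses the standard rule ema ← decay · ema + (1 − decay) · net. To keep it runnable without PyTorch, parameters are modelled as a plain dict of floats, and `ema_update` is a hypothetical helper, not the helios implementation.

```python
from typing import Dict


def ema_update(
    ema_params: Dict[str, float],
    net_params: Dict[str, float],
    decay: float = 0.9997,
) -> None:
    """One in-place EMA step: ema <- decay * ema + (1 - decay) * net."""
    for name, value in net_params.items():
        ema_params[name] = decay * ema_params[name] + (1.0 - decay) * value


ema = {"w": 0.0}
net = {"w": 1.0}
for _ in range(10):
    ema_update(ema, net, decay=0.9)
# ema["w"] is now 1 - 0.9**10, roughly 0.651
```

In PyTorch the same arithmetic would typically be applied tensor by tensor under `torch.no_grad()`.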
- set(net: torch.nn.Module) → None
Reset the base weights from the given network.
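Unlike `update`, which blends the new weights into the running average, `set` re-bases the EMA copy on the given network. The following is a minimal sketch of that reading. It assumes `set` simply copies the network's current weights into the EMA copy. `ema_set` and the dict-of-floats representation are hypothetical stand-ins for the module's parameters.

```python
from typing import Dict


def ema_set(ema_params: Dict[str, float], net_params: Dict[str, float]) -> None:
    """Overwrite each EMA value with the network's current value (no blending)."""
    for name, value in net_params.items():
        ema_params[name] = value


net = {"w": 5.0, "b": -1.0}
ema = {"w": 0.25, "b": 0.0}  # some previously averaged state
ema_set(ema, net)            # ema now holds exactly the network's values
net["w"] = 7.0               # a later training step on the live network...
# ...does not leak into the EMA copy until the next update/set call
```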
- forward(*args: Any, **kwargs: Any)
Evaluate the EMA wrapper on the network inputs.
- Parameters:
args – positional arguments for your network’s forward function.
kwargs – keyword arguments for your network’s forward function.
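The `forward` signature accepts arbitrary positional and keyword arguments, so the wrapper presumably passes them through to its internal EMA copy of the network. The following is a minimal sketch of that delegation pattern, not the helios implementation. It uses a plain Python class instead of `torch.nn.Module` so it runs without PyTorch, and `__call__` plays the role of `forward`. `TinyNet` and `EMAForwardSketch` are hypothetical names.

```python
import copy
from typing import Any, Dict


class TinyNet:
    """Hypothetical stand-in for a network: y = scale * (w * x + b)."""

    def __init__(self, w: float, b: float) -> None:
        self.params: Dict[str, float] = {"w": w, "b": b}

    def __call__(self, x: float, *, scale: float = 1.0) -> float:
        return scale * (self.params["w"] * x + self.params["b"])


class EMAForwardSketch:
    """Hypothetical wrapper that evaluates its own copy of the network."""

    def __init__(self, net: TinyNet) -> None:
        # Deep copy so later changes to `net` do not affect the wrapped copy.
        self.shadow = copy.deepcopy(net)

    def __call__(self, *args: Any, **kwargs: Any) -> float:
        # Pass positional and keyword arguments through unchanged.
        return self.shadow(*args, **kwargs)


net = TinyNet(w=1.0, b=0.5)
ema = EMAForwardSketch(net)
net.params["w"] = 3.0           # simulate a training step on the live network
live = net(2.0, scale=2.0)      # uses w = 3.0
averaged = ema(2.0, scale=2.0)  # uses the copy, still w = 1.0
```

Keeping a separate copy means evaluation through the wrapper never sees the raw, noisier training weights.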