helios.nn.swa_utils

Classes

EMA

Implements Exponential Moving Average (EMA).

Module Contents

class helios.nn.swa_utils.EMA(net: torch.nn.Module, decay: float = 0.9997, device: torch.device | None = None)

Bases: torch.nn.Module

Implements Exponential Moving Average (EMA).

Parameters:
  • net – the bare network on which EMA will be performed.

  • decay – decay rate. Defaults to 0.9997.

  • device – (optional) the device to be used.
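The parameter list does not state the update rule. Assuming the standard EMA formulation, each update blends the network weights $\theta$ into the averaged weights $\theta_{\mathrm{EMA}}$ using the decay rate $d$. A common rule of thumb puts the effective averaging window at roughly $1/(1 - d)$ updates:

```latex
\theta_{\mathrm{EMA}} \leftarrow d\,\theta_{\mathrm{EMA}} + (1 - d)\,\theta,
\qquad
\text{window} \approx \frac{1}{1 - d} = \frac{1}{1 - 0.9997} \approx 3333 \text{ updates}.
```

Under this assumption, the default $d = 0.9997$ means the average is dominated by roughly the last few thousand updates. A decay closer to 1 gives a smoother but slower-moving average.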

property module: torch.nn.Module

Get the underlying network.

update(net: torch.nn.Module) → None

Update the weights using EMA from the given network.
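Assuming update follows the standard EMA rule, each call moves every averaged weight a fraction (1 − decay) of the way toward the corresponding weight of net. The plain-Python sketch below shows only that arithmetic, with dictionaries of floats standing in for the network parameters. `ema_update` is a hypothetical helper, not the helios implementation.

```python
from typing import Dict


def ema_update(ema_params: Dict[str, float], params: Dict[str, float], decay: float = 0.9997) -> None:
    """Hypothetical stand-in for update(): blend live weights into the EMA copy in place."""
    for name, value in params.items():
        # Standard EMA rule (assumed): theta_ema <- decay * theta_ema + (1 - decay) * theta
        ema_params[name] = decay * ema_params[name] + (1.0 - decay) * value


# Toy example with one scalar "weight" and an exaggerated decay of 0.5.
ema = {"w": 0.0}
ema_update(ema, {"w": 1.0}, decay=0.5)  # 0.5 * 0.0 + 0.5 * 1.0 = 0.5
ema_update(ema, {"w": 1.0}, decay=0.5)  # 0.5 * 0.5 + 0.5 * 1.0 = 0.75
print(ema["w"])  # 0.75
```

With the default decay of 0.9997, a single update would move the weight only 0.0003 of the way toward the live value.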

set(net: torch.nn.Module) → None

Reset the base weights from the given network.
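Judging only from its name and signature, set(net) appears to overwrite the averaged weights with the current weights of net instead of blending them. The sketch below models that assumption in plain Python, again with dictionaries of floats in place of module parameters. `ema_set` is a hypothetical name, not part of helios.

```python
from typing import Dict


def ema_set(ema_params: Dict[str, float], params: Dict[str, float]) -> None:
    """Hypothetical stand-in for set(): copy the live weights into the EMA copy verbatim."""
    # No blending with the decay rate: the previous running average is discarded.
    ema_params.clear()
    ema_params.update(params)


ema = {"w": 0.75}
ema_set(ema, {"w": 1.0})
print(ema["w"])  # 1.0
```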

forward(*args: Any, **kwargs: Any)

Evaluate the EMA wrapper on the network inputs.

Parameters:
  • args – positional arguments for your network’s forward function.

  • kwargs – keyword arguments for your network’s forward function.
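The forward signature suggests that the wrapper passes positional and keyword arguments through to the underlying network unchanged, so the EMA model is called exactly like the original one. The pure-Python sketch below illustrates that pattern. `ToyEMA` and `linear` are hypothetical stand-ins (a plain function plus a dictionary of weights instead of a torch.nn.Module), and the EMA rule is the standard one, assumed rather than taken from helios.

```python
from typing import Any, Callable, Dict


class ToyEMA:
    """Hypothetical pure-Python stand-in for an EMA wrapper (not the helios class)."""

    def __init__(self, forward_fn: Callable[..., Any], params: Dict[str, float], decay: float = 0.9997) -> None:
        self.forward_fn = forward_fn  # model function that takes the weights as its first argument
        self.decay = decay
        self.params = dict(params)  # averaged copy, initialised from the live weights

    def update(self, params: Dict[str, float]) -> None:
        # Standard EMA rule (assumed): theta_ema <- decay * theta_ema + (1 - decay) * theta
        for name, value in params.items():
            self.params[name] = self.decay * self.params[name] + (1.0 - self.decay) * value

    def __call__(self, *args: Any, **kwargs: Any) -> Any:
        # Pass positional and keyword arguments through unchanged, but use the averaged weights.
        return self.forward_fn(self.params, *args, **kwargs)


def linear(params: Dict[str, float], x: float, *, scale: float = 1.0) -> float:
    """Toy 'network': y = scale * (w * x + b)."""
    return scale * (params["w"] * x + params["b"])


ema = ToyEMA(linear, {"w": 0.0, "b": 0.0}, decay=0.5)
ema.update({"w": 2.0, "b": 1.0})  # averaged weights become w = 1.0, b = 0.5
print(ema(3.0, scale=2.0))  # 2.0 * (1.0 * 3.0 + 0.5) = 7.0
```

In a training loop, update would typically be called after each optimizer step, with the wrapper then used for evaluation in place of the live network.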