import copy
import typing

import torch
from torch import nn


class EMA(nn.Module):
    """
    Implements Exponential Moving Average (EMA).

    Args:
        net: the bare network on which EMA will be performed.
        decay: decay rate. Defaults to 0.9997.
        device: (optional) the device to be used.
    """

    def __init__(
        self,
        net: nn.Module,
        decay: float = 0.9997,
        device: torch.device | None = None,
    ):
        """Create the EMA wrapper."""
        super().__init__()
        self._module = copy.deepcopy(net)
        self._module = self._module.eval()
        self._decay = decay
        self._device = device
        if self._device is not None:
            self._module.to(device=device)

    @property
    def module(self) -> nn.Module:
        """Get the underlying network."""
        return self._module

    @torch.no_grad()
    def _update(self, net: nn.Module, update_fn: typing.Callable) -> None:
        for ema_v, net_v in zip(
            self._module.state_dict().values(),
            net.state_dict().values(),
            strict=True,
        ):
            if self._device:
                net_v = net_v.to(device=self._device)
            ema_v.copy_(update_fn(ema_v, net_v))
    def update(self, net: nn.Module) -> None:
        """Update the weights using EMA from the given network."""
        self._update(net, update_fn=lambda e, m: self._decay * e + (1.0 - self._decay) * m)
    def set(self, net: nn.Module) -> None:
        """Reset the EMA weights to those of the given network."""
        self._update(net, update_fn=lambda e, m: m)
    def forward(self, *args: typing.Any, **kwargs: typing.Any):
        """
        Evaluate the EMA wrapper on the network inputs.

        Args:
            args: positional arguments for your network's forward function.
            kwargs: keyword arguments for your network's forward function.
        """
        return self._module(*args, **kwargs)
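
# A minimal usage sketch, assuming standard PyTorch; the model, optimizer,
# dummy data, and training loop below are illustrative placeholders rather
# than part of the documented API.
if __name__ == "__main__":
    model = nn.Linear(16, 4)                  # placeholder network
    ema = EMA(model, decay=0.9997)            # EMA copy tracks the live weights
    optimizer = torch.optim.SGD(model.parameters(), lr=1e-2)

    for _ in range(100):                      # placeholder training loop
        x = torch.randn(8, 16)
        loss = model(x).pow(2).mean()         # dummy objective
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        ema.update(model)                     # ema <- decay * ema + (1 - decay) * model

    # Evaluate with the smoothed weights; forward() delegates to the EMA copy.
    with torch.no_grad():
        preds = ema(torch.randn(8, 16))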