import abc
import typing

from torch import nn


class WeightedLoss(nn.Module, metaclass=abc.ABCMeta):
    """
    Defines a base class for weighted losses.

    The value of the final loss is determined by the following formula:

    .. math::
        L_w = w * L

    where :math:`w` is the weight and :math:`L` is the loss function.

    Example:
        .. code-block:: python

            class MyLoss(WeightedLoss):
                ...

                def _eval(self, ...):
                    return my_loss_function(...)

    Args:
        loss_weight: the weight of the loss function. Defaults to 1.
    """

    def __init__(self, loss_weight: float = 1.0):
        """Create the weighted loss."""
        super().__init__()
        # Scalar multiplier applied to the raw loss in ``forward``.
        self._loss_weight = loss_weight

    @abc.abstractmethod
    def _eval(self, *args: typing.Any, **kwargs: typing.Any) -> typing.Any:
        """
        Evaluate the unweighted loss function.

        Subclasses must implement this method.

        Args:
            *args: arguments to the loss function.
            **kwargs: keyword arguments to the loss function.

        Returns:
            The result of the loss function.
        """

    def forward(self, *args: typing.Any, **kwargs: typing.Any) -> typing.Any:
        """
        Forward wrapper function.

        Calls ``_eval`` with the given arguments and scales the result by the
        loss weight, i.e. :math:`L_w = w * L`.

        Args:
            *args: arguments to the loss function.
            **kwargs: keyword arguments to the loss function.

        Returns:
            The weighted value of the loss function.
        """
        return self._loss_weight * self._eval(*args, **kwargs)