helios.losses.weighted_loss
Classes
WeightedLoss | Defines a base class for weighted losses.
Module Contents
- class helios.losses.weighted_loss.WeightedLoss(loss_weight: float = 1.0)
Bases: torch.nn.Module
Defines a base class for weighted losses.
The value of the final loss is determined by the following formula:
\[L_w = w \cdot L\]
where \(w\) is the weight and \(L\) is the value of the unweighted loss function.
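As a worked example, with \(w = 0.5\) and \(L = 2.0\) the weighted loss is \(L_w = 1.0\). Because \(w\) is a constant, the same factor also scales the gradients, so weighting a loss rescales its contribution to training. Here \(\theta\) is notation introduced for illustration and stands for any parameter the loss depends on:
\[\frac{\partial L_w}{\partial \theta} = w \, \frac{\partial L}{\partial \theta}\]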
Example
    class MyLoss(WeightedLoss):
        ...
        def _eval(self, ...):
            return my_loss_function(...)
- Parameters:
loss_weight – the weight \(w\) applied to the loss function. Defaults to 1.0.
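The following is a minimal, dependency-free sketch of the subclassing pattern. As the example above suggests, it assumes that `forward` calls a subclass-defined `_eval` hook and multiplies the result by `loss_weight`. `WeightedLossSketch` and `MeanAbsoluteError` are hypothetical names. The sketch deliberately does not subclass `torch.nn.Module` and is not the helios implementation.

```python
from typing import Any, Sequence


class WeightedLossSketch:
    """Hypothetical, torch-free stand-in for WeightedLoss (not the helios code)."""

    def __init__(self, loss_weight: float = 1.0) -> None:
        self.loss_weight = loss_weight

    def _eval(self, *args: Any, **kwargs: Any) -> Any:
        # Assumed hook: subclasses return the unweighted loss L here.
        raise NotImplementedError

    def forward(self, *args: Any, **kwargs: Any) -> Any:
        # L_w = w * L: evaluate the raw loss, then scale it by the weight.
        return self.loss_weight * self._eval(*args, **kwargs)

    # Mimic torch.nn.Module, where calling the module runs forward().
    __call__ = forward


class MeanAbsoluteError(WeightedLossSketch):
    """Hypothetical subclass: mean absolute error between two equal-length sequences."""

    def _eval(self, pred: Sequence[float], target: Sequence[float]) -> float:
        return sum(abs(p - t) for p, t in zip(pred, target)) / len(pred)


loss = MeanAbsoluteError(loss_weight=0.5)
print(loss([1.0, 2.0, 3.0], [2.0, 4.0, 6.0]))  # 0.5 * mean(1, 2, 3) = 1.0
```

Because the weighting lives in the base class, each subclass only has to implement the unweighted loss \(L\).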
- forward(*args: Any, **kwargs: Any) → Any
Forward wrapper function.
The final loss value will be computed as described above.
- Parameters:
*args – positional arguments to the loss function.
**kwargs – keyword arguments to the loss function.
- Returns:
The weighted value of the loss function.
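The sketch below shows the same scaling inside a `torch.nn.Module`, with positional and keyword arguments passed unchanged to the loss function. `WeightedMSE` is a hypothetical class. It is written directly against `torch.nn.functional.mse_loss` so that it runs without helios, and it assumes only the documented behaviour: `forward` returns \(w\) times the loss value.

```python
import torch
import torch.nn.functional as F
from torch import nn


class WeightedMSE(nn.Module):
    """Hypothetical illustration of the weighted-loss pattern; not part of helios."""

    def __init__(self, loss_weight: float = 1.0) -> None:
        super().__init__()
        self.loss_weight = loss_weight

    def forward(self, *args, **kwargs) -> torch.Tensor:
        # All positional and keyword arguments go straight to the loss function;
        # the result is then scaled by the weight: L_w = w * L.
        return self.loss_weight * F.mse_loss(*args, **kwargs)


pred = torch.tensor([1.0, 2.0, 3.0], requires_grad=True)
target = torch.tensor([2.0, 2.0, 5.0])
loss_fn = WeightedMSE(loss_weight=0.5)
loss = loss_fn(pred, target, reduction="sum")  # the keyword argument reaches F.mse_loss
loss.backward()
print(loss.item())  # 0.5 * (1 + 0 + 4) = 2.5
print(pred.grad)    # 0.5 * 2 * (pred - target), i.e. the values -1, 0, -2
```

Scaling in `forward` rather than inside each loss keeps the weight in one place, and autograd carries the factor \(w\) into the gradients automatically.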