helios.losses.weighted_loss
Classes
WeightedLoss | Defines a base class for weighted losses.
Module Contents
- class helios.losses.weighted_loss.WeightedLoss(loss_weight: float = 1.0)
Bases: torch.nn.Module
Defines a base class for weighted losses.
The value of the final loss is determined by the following formula:
\[L_w = w \cdot L\]
where \(w\) is the weight and \(L\) is the unweighted loss function.
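As a minimal numeric illustration of the formula, the sketch below computes a plain loss value \(L\) and then scales it by \(w\). It does not use Helios or PyTorch; `mse` and `apply_weight` are hypothetical helper names chosen for this example, with mean-squared error standing in for an arbitrary loss.

```python
from typing import Sequence


def mse(pred: Sequence[float], target: Sequence[float]) -> float:
    """Mean-squared error, used here only as an example unweighted loss L."""
    if len(pred) != len(target) or not pred:
        raise ValueError("pred and target must be non-empty and equal length")
    return sum((p - t) ** 2 for p, t in zip(pred, target)) / len(pred)


def apply_weight(loss: float, loss_weight: float = 1.0) -> float:
    """Hypothetical helper: return L_w = w * L."""
    return loss_weight * loss


raw = mse([0.0, 0.0], [1.0, 3.0])  # (1 + 9) / 2 = 5.0
print(apply_weight(raw))           # default w = 1.0 leaves L unchanged: 5.0
print(apply_weight(raw, 0.5))      # 0.5 * 5.0 = 2.5
```

With the default weight of 1.0 the weighted loss equals the unweighted one, so the weight only matters when it is set explicitly.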
Example
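    from helios.losses.weighted_loss import WeightedLoss

    class MyLoss(WeightedLoss):
        # Replace the "..." placeholders with the loss's actual inputs.
        def _eval(self, ...):
            return my_loss_function(...)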
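The example above only shows that a subclass overrides `_eval`. The sketch below fills in the surrounding pattern under a stated assumption: that the base class calls the subclass's `_eval` and multiplies the result by `loss_weight`, as the formula describes. To stay runnable without PyTorch or Helios, it uses a hypothetical torch-free stand-in, `WeightedLossSketch`, rather than the real `WeightedLoss`; the actual Helios implementation may be structured differently. The mean-absolute-error body of `MyLoss._eval` is likewise only illustrative.

```python
from typing import Sequence


class WeightedLossSketch:
    """Hypothetical, torch-free stand-in for WeightedLoss.

    Assumption: calling the object evaluates the subclass's _eval() and scales
    the result by loss_weight. The real Helios class may differ.
    """

    def __init__(self, loss_weight: float = 1.0) -> None:
        self.loss_weight = loss_weight

    def _eval(self, *args, **kwargs) -> float:
        # Subclasses must supply the unweighted loss L.
        raise NotImplementedError

    def __call__(self, *args, **kwargs) -> float:
        # L_w = w * L
        return self.loss_weight * self._eval(*args, **kwargs)


class MyLoss(WeightedLossSketch):
    """Example subclass: mean absolute error as the unweighted loss."""

    def _eval(self, pred: Sequence[float], target: Sequence[float]) -> float:
        if len(pred) != len(target) or not pred:
            raise ValueError("pred and target must be non-empty and equal length")
        return sum(abs(p - t) for p, t in zip(pred, target)) / len(pred)


loss_fn = MyLoss(loss_weight=2.0)
print(loss_fn([1.0, 2.0], [2.0, 4.0]))  # (1 + 2) / 2 = 1.5, times 2.0 -> 3.0
```

Keeping the weighting in the base class means each subclass only has to define how \(L\) is computed, and the scaling by \(w\) is applied in one place.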
- Parameters:
loss_weight (float) – the weight \(w\) applied to the loss function. Defaults to 1.0.