helios.losses.weighted_loss

Classes

WeightedLoss

Defines a base class for weighted losses.

Module Contents

class helios.losses.weighted_loss.WeightedLoss(loss_weight: float = 1.0)

Bases: torch.nn.Module

Defines a base class for weighted losses.

The final loss value is computed as:

\[L_w = w \cdot L\]

where \(w\) is the weight and \(L\) is the value of the unweighted loss function.
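As a worked example with illustrative numbers (not taken from the library), a weight of 0.5 halves the loss value. Because \(w\) is a constant, it scales the gradient by the same factor, where \(\theta\) is introduced here to denote the model parameters:

\[w = 0.5,\; L = 0.8 \;\Rightarrow\; L_w = 0.5 \cdot 0.8 = 0.4\]

\[\frac{\partial L_w}{\partial \theta} = w \, \frac{\partial L}{\partial \theta}\]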

Example

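import torch
from helios.losses.weighted_loss import WeightedLoss

class MyLoss(WeightedLoss):
    # Argument names are illustrative; forward() passes its arguments to _eval.
    def _eval(self, pred: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
        # Return the unweighted loss; forward() applies loss_weight.
        return torch.nn.functional.mse_loss(pred, target)

Parameters: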

loss_weight – the weight \(w\) that scales the loss value. Defaults to 1.0.

forward(*args: Any, **kwargs: Any) → Any

Evaluate the loss and apply the weight.

The arguments are passed to the loss function, and its result is scaled by loss_weight to give the weighted loss \(L_w\) defined above.

Parameters:
  • *args – arguments to the loss function.

  • **kwargs – keyword arguments to the loss function.

Returns:

The weighted value of the loss function.
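The sketch below illustrates the pattern described above: a base module whose forward() passes all arguments to a subclass-defined _eval and multiplies the result by the weight. It is a hypothetical stand-in, not the helios implementation; the class names WeightedLossSketch and HalfMSE and the attribute _loss_weight are assumptions made for illustration.

```python
from typing import Any

import torch
from torch import nn


class WeightedLossSketch(nn.Module):
    """Hypothetical stand-in for WeightedLoss: forward() returns w * _eval(...)."""

    def __init__(self, loss_weight: float = 1.0) -> None:
        super().__init__()
        # Assumed attribute name; the real class may store the weight differently.
        self._loss_weight = loss_weight

    def _eval(self, *args: Any, **kwargs: Any) -> Any:
        # Subclasses return the unweighted loss L.
        raise NotImplementedError

    def forward(self, *args: Any, **kwargs: Any) -> Any:
        # L_w = w * L, with every argument passed through to _eval.
        return self._loss_weight * self._eval(*args, **kwargs)


class HalfMSE(WeightedLossSketch):
    """Illustrative subclass: mean squared error as the unweighted loss."""

    def _eval(self, pred: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
        return nn.functional.mse_loss(pred, target)


loss_fn = HalfMSE(loss_weight=0.5)
pred = torch.tensor([1.0, 3.0])
target = torch.tensor([3.0, 1.0])
# Squared errors are 4 and 4, so L = 4.0 and L_w = 0.5 * 4.0 = 2.0.
weighted = loss_fn(pred, target)
```

Keeping the weighting in the base class's forward() means subclasses only implement the raw loss, and every subclass gets the same scaling behaviour.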