Source code for helios.metrics.metrics

import typing

import numpy as np
import numpy.typing as npt
import torch
from torch import nn

from helios import core

from .functional import (
    calculate_mae,
    calculate_mae_torch,
    calculate_mAP,
    calculate_psnr,
    calculate_psnr_torch,
    calculate_ssim,
    calculate_ssim_torch,
)

# Registry named "metrics". The metric classes in this module add themselves to it via
# the ``@METRICS_REGISTRY.register`` decorator, and ``create_metric`` resolves type
# names through ``METRICS_REGISTRY.get``.
METRICS_REGISTRY = core.Registry("metrics")
"""
Global instance of the registry for metric functions.

Example:
    .. code-block:: python

        import helios.metrics as hlm

        # This automatically registers your metric function.
        @hlm.METRICS_REGISTRY.register
        class MyMetric:
            ...

        # Alternatively you can manually register a metric function like this:
        hlm.METRICS_REGISTRY.register(MyMetric)
"""


def create_metric(
    type_name: str, *args: typing.Any, **kwargs: typing.Any
) -> nn.Module:
    """
    Instantiate the metric registered under the given type name.

    The entry is fetched from ``METRICS_REGISTRY`` and then called with the supplied
    arguments.

    Args:
        type_name: the name of the metric type to create.
        args: positional arguments forwarded to the metric.
        kwargs: keyword arguments forwarded to the metric.

    Returns:
        The metric function
    """
    metric_type = METRICS_REGISTRY.get(type_name)
    return metric_type(*args, **kwargs)
@METRICS_REGISTRY.register
class CalculatePSNR(nn.Module):
    """
    Calculate PSNR (Peak Signal-to-Noise Ratio).

    Implementation follows:
    `<https://en.wikipedia.org/wiki/Peak_signal-to-noise_ratio>`__.
    Note that the input_order is only needed if you plan to evaluate Numpy images. It
    can be left as default otherwise.

    Args:
        crop_border: Cropped pixels in each edge of an image. These pixels are not
            involved in the calculation.
        input_order: Whether the input order is "HWC" or "CHW". Defaults to "HWC".
        test_y_channel: Test on Y channel of YCbCr. Defaults to false.
    """

    def __init__(
        self, crop_border: int, input_order: str = "HWC", test_y_channel: bool = False
    ):
        """Construct the PSNR metric."""
        super().__init__()

        self._crop_border = crop_border
        self._input_order = input_order
        self._test_y_channel = test_y_channel

    def forward(
        self, img: npt.NDArray | torch.Tensor, img2: npt.NDArray | torch.Tensor
    ) -> float:
        """
        Calculate the PSNR metric.

        Both inputs must be of the same kind: either two tensors or two Numpy arrays.

        Args:
            img: Images with range :math:`[0, 255]`.
            img2: Images with range :math:`[0, 255]`.

        Returns:
            PSNR value.

        Raises:
            AssertionError: if the inputs are neither both tensors nor both Numpy
                arrays (only checked when assertions are enabled).
        """
        # Tensor path: the input order is not forwarded to the torch implementation.
        if isinstance(img, torch.Tensor) and isinstance(img2, torch.Tensor):
            return calculate_psnr_torch(
                img, img2, self._crop_border, self._test_y_channel
            )

        # Anything else has to be a pair of Numpy arrays; mixed pairs are rejected
        # with a message that names the offending types.
        assert isinstance(img, np.ndarray) and isinstance(img2, np.ndarray), (
            "expected both inputs to be torch.Tensor or both to be numpy.ndarray, "
            f"got {type(img).__name__} and {type(img2).__name__}"
        )
        return calculate_psnr(
            img, img2, self._crop_border, self._input_order, self._test_y_channel
        )
@METRICS_REGISTRY.register
class CalculateSSIM(nn.Module):
    """
    Calculate SSIM (structural similarity).

    Implementation follows: 'Image quality assessment: From error visibility to
    structural similarity'. Results are identical to those of the official MATLAB code
    in `<https://ece.uwaterloo.ca/~z70wang/research/ssim/>`__.
    For three-channel images, SSIM is calculated for each channel and then averaged.

    The input_order is only forwarded to the Numpy implementation; it is not used when
    both inputs are tensors.

    Args:
        crop_border: Cropped pixels in each edge of an image. These pixels are not
            involved in the calculation.
        input_order: Whether the input order is "HWC" or "CHW". Defaults to "HWC".
        test_y_channel: Test on Y channel of YCbCr. Defaults to false.
    """

    def __init__(
        self, crop_border: int, input_order: str = "HWC", test_y_channel: bool = False
    ):
        """Construct the SSIM metric."""
        super().__init__()

        self._crop_border = crop_border
        self._input_order = input_order
        self._test_y_channel = test_y_channel

    def forward(
        self, img: npt.NDArray | torch.Tensor, img2: npt.NDArray | torch.Tensor
    ) -> float:
        """
        Calculate the SSIM metric.

        Both inputs must be of the same kind: either two tensors or two Numpy arrays.

        Args:
            img: Images with range :math:`[0, 255]`.
            img2: Images with range :math:`[0, 255]`.

        Returns:
            SSIM value.

        Raises:
            AssertionError: if the inputs are neither both tensors nor both Numpy
                arrays (only checked when assertions are enabled).
        """
        # Tensor path: the input order is not forwarded to the torch implementation.
        if isinstance(img, torch.Tensor) and isinstance(img2, torch.Tensor):
            return calculate_ssim_torch(
                img, img2, self._crop_border, self._test_y_channel
            )

        # Anything else has to be a pair of Numpy arrays; mixed pairs are rejected
        # with a message that names the offending types.
        assert isinstance(img, np.ndarray) and isinstance(img2, np.ndarray), (
            "expected both inputs to be torch.Tensor or both to be numpy.ndarray, "
            f"got {type(img).__name__} and {type(img2).__name__}"
        )
        return calculate_ssim(
            img, img2, self._crop_border, self._input_order, self._test_y_channel
        )
@METRICS_REGISTRY.register
class CalculateMAP(nn.Module):
    """
    mAP (Mean Average Precision) metric.

    Implementation follows:
    `<https://en.wikipedia.org/wiki/Evaluation_measures_(information_retrieval)#Mean_average_precision>`__.
    """

    def forward(self, targs: npt.NDArray, preds: npt.NDArray) -> float:
        """
        Compute the mAP score by delegating to ``calculate_mAP``.

        Args:
            targs: target labels in range :math:`[0, 1]`.
            preds: predicted labels in range :math:`[0, 1]`.

        Returns:
            The mAP score
        """
        score = calculate_mAP(targs, preds)
        return score
@METRICS_REGISTRY.register
class CalculateMAE(nn.Module):
    """
    Compute the MAE (Mean Absolute Error) score.

    Implementation follows: `<https://en.wikipedia.org/wiki/Mean_absolute_error>`__.
    The scale argument is used in the event that the input arrays are not in the range
    :math:`[0, 1]` but instead have been scaled to be in the range :math:`[0, N]` where
    :math:`N` is the factor. For example, if the arrays are images in the range
    :math:`[0, 255]`, then the scaling factor should be set to 255. If the arrays are
    already in the range :math:`[0, 1]`, then the scale can be omitted.

    Args:
        scale: scaling factor that was used on the input tensors. Defaults to 1.
    """

    def __init__(self, scale: float = 1):
        """Construct the MAE metric."""
        super().__init__()

        self._scale = scale

    def forward(
        self, pred: npt.NDArray | torch.Tensor, gt: npt.NDArray | torch.Tensor
    ) -> float:
        """
        Calculate the MAE metric.

        Both inputs must be of the same kind: either two tensors or two Numpy arrays.

        Args:
            pred: predicted (inferred) data.
            gt: ground-truth data.

        Returns:
            The MAE score

        Raises:
            AssertionError: if the inputs are neither both tensors nor both Numpy
                arrays (only checked when assertions are enabled).
        """
        if isinstance(pred, torch.Tensor) and isinstance(gt, torch.Tensor):
            return calculate_mae_torch(pred, gt, self._scale)

        # Anything else has to be a pair of Numpy arrays; mixed pairs are rejected
        # with a message that names the offending types.
        assert isinstance(pred, np.ndarray) and isinstance(gt, np.ndarray), (
            "expected both inputs to be torch.Tensor or both to be numpy.ndarray, "
            f"got {type(pred).__name__} and {type(gt).__name__}"
        )
        return calculate_mae(pred, gt, self._scale)