Source code for helios.data.transforms
import typing
import numpy.typing as npt
import PIL
import torch
import torchvision.transforms.v2 as T
from torch import nn
from helios import core
TRANSFORM_REGISTRY = core.Registry("transform")
"""
Global instance of the registry for transforms.

Example:
    .. code-block:: python

        import helios.data.transforms as hldt

        # This automatically registers your transform.
        @hldt.TRANSFORM_REGISTRY.register
        class MyTransform:
            ...

        # Alternatively you can manually register a transform like this:
        hldt.TRANSFORM_REGISTRY.register(MyTransform)
"""
[docs]
@TRANSFORM_REGISTRY.register
class ToImageTensor(nn.Module):
    """
    Convert an image (or list of images) to tensor(s).

    An image is meant to be a tensor, ndarray, or PIL image. The shape is expected to
    be either [H, W, C] or [C, H, W].

    Args:
        dtype: (optional) the output type of the tensors.
        scale: (optional) if true, scale the values to the valid range.
    """

    def __init__(self, dtype: torch.dtype = torch.float32, scale: bool = True):
        """Create the transform."""
        super().__init__()

        # Pipeline: wrap the input as an image tensor, cast it to ``dtype``
        # (optionally scaling the values), then unwrap to a plain tensor.
        self._transform = T.Compose(
            [T.ToImage(), T.ToDtype(dtype, scale=scale), T.ToPureTensor()]
        )

    def forward(
        self,
        img: npt.NDArray
        | list[npt.NDArray]
        | tuple[npt.NDArray, ...]
        | PIL.Image.Image
        | list[PIL.Image.Image]
        | tuple[PIL.Image.Image, ...],
    ) -> torch.Tensor | list[torch.Tensor] | tuple[torch.Tensor, ...]:
        """
        Convert the input image(s) into tensor(s).

        The return type will match the type of the input. So, if the input is a single
        image, then the output will be a single tensor. If the input is a list or a tuple
        of images, the output will be a list or tuple of tensors, including when the
        sequence holds exactly one image.

        Args:
            img: image(s) to convert.

        Returns:
            The converted images.
        """
        tensors: list[torch.Tensor] = [
            self._transform(elem) for elem in core.convert_to_list(img)
        ]

        # Decide the container from the *input* type rather than from the number
        # of results, so a one-element list/tuple is not silently unwrapped.
        if isinstance(img, tuple):
            return tuple(tensors)
        if isinstance(img, list):
            return tensors

        # Non-sequence input: hand back the lone tensor. The length check keeps
        # the previous fallback in case the helper produced a different count.
        return tensors[0] if len(tensors) == 1 else tensors