import typing

import numpy.typing as npt
import PIL.Image
import torch
import torchvision.transforms.v2 as T
from torch import nn

from helios import core

TRANSFORM_REGISTRY = core.Registry("transform")
"""
Global instance of the registry for transforms.

Example:
    .. code-block:: python

        import helios.data.transforms as hldt

        # This automatically registers your transform.
        @hldt.TRANSFORM_REGISTRY.register()
        class MyTransform:
            ...

        # Alternatively you can manually register a transform like this:
        hldt.TRANSFORM_REGISTRY.register(MyTransform)
"""


def create_transform(
    type_name: str, *args: typing.Any, **kwargs: typing.Any
) -> nn.Module:
    """
    Create a transform of the given type.

    This uses TRANSFORM_REGISTRY to look up transform types, so ensure your transforms
    have been registered before using this function.

    Args:
        type_name: the type of the transform to create.
        args: positional arguments to pass into the transform.
        kwargs: keyword arguments to pass into the transform.

    Returns:
        The constructed transform.
    """
    return TRANSFORM_REGISTRY.get(type_name)(*args, **kwargs)
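

# Illustrative sketch (not part of the original module): a hypothetical helper showing
# the intended use of ``create_transform``. It assumes the registry keys transforms by
# their class name, so the registered ``ToImageTensor`` (defined below) can be looked
# up as "ToImageTensor"; adjust the key if the registry uses a different convention.
def _example_build_to_image_tensor() -> nn.Module:
    # Look up the transform type in TRANSFORM_REGISTRY and forward keyword arguments
    # to its constructor.
    return create_transform("ToImageTensor", dtype=torch.float32, scale=True)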


@TRANSFORM_REGISTRY.register
class ToImageTensor(nn.Module):
    """
    Convert an image (or list of images) to tensor(s).

    An image can be a tensor, ndarray, or PIL image. The shape is expected to be either
    [H, W, C] or [C, H, W].

    Args:
        dtype: the output type of the tensors.
        scale: if true, scale the values to the valid range. Defaults to true.
    """

    def __init__(self, dtype: torch.dtype = torch.float32, scale: bool = True):
        """Create the transform."""
        super().__init__()
        self._transform = T.Compose(
            [T.ToImage(), T.ToDtype(dtype, scale=scale), T.ToPureTensor()]
        )

    def forward(
        self,
        img: npt.NDArray
        | list[npt.NDArray]
        | tuple[npt.NDArray, ...]
        | PIL.Image.Image
        | list[PIL.Image.Image]
        | tuple[PIL.Image.Image, ...],
    ) -> torch.Tensor | list[torch.Tensor] | tuple[torch.Tensor, ...]:
        """
        Convert the input image(s) into tensor(s).

        The return type will match the type of the input: a single image yields a
        single tensor, and a list or tuple of images yields a list or tuple of
        tensors. Note that a single-element list or tuple is returned as a single
        tensor.

        Args:
            img: image(s) to convert.

        Returns:
            The converted image(s).
        """
        out_tens: list[torch.Tensor] = []
        for elem in core.convert_to_list(img):
            out_tens.append(self._transform(elem))

        if len(out_tens) == 1:
            return out_tens[0]
        if isinstance(img, tuple):
            return tuple(out_tens)
        return out_tens
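

# Illustrative usage sketch (an assumption, not part of the original module): exercise
# ToImageTensor on a random [H, W, C] uint8 ndarray and on a tuple of such arrays.
if __name__ == "__main__":
    import numpy as np

    to_tensor = ToImageTensor(dtype=torch.float32, scale=True)

    # A single [H, W, C] uint8 image becomes a single float tensor in [0, 1] with
    # shape [C, H, W].
    img = np.random.randint(0, 256, size=(32, 32, 3), dtype=np.uint8)
    tens = to_tensor(img)
    assert isinstance(tens, torch.Tensor) and tens.shape == (3, 32, 32)

    # A tuple of images becomes a tuple of tensors (matching the input container).
    pair = to_tensor((img, img))
    assert isinstance(pair, tuple) and len(pair) == 2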