Source code for helios.model.utils
from __future__ import annotations
import pathlib
import typing
from helios import core
if typing.TYPE_CHECKING:
from .model import Model
MODEL_REGISTRY = core.Registry("model")
"""
Global instance of the registry for models.
Example:
.. code-block:: python
import helios.model as hlm
# This automatically registers your model.
@hlm.MODEL_REGISTRY.register
class MyModel:
...
# Alternatively you can manually register a model like this:
hlm.MODEL_REGISTRY.register(MyModel)
"""
[docs]
def create_model(type_name: str, *args: typing.Any, **kwargs: typing.Any) -> Model:
"""
Create the model for the given type.
Args:
type_name: the type of the model to create.
args: positional arguments to pass into the model.
kwargs: keyword arguments to pass into the model.
Returns:
The model.
"""
return MODEL_REGISTRY.get(type_name)(*args, **kwargs)
[docs]
def find_pretrained_file(root: pathlib.Path, name: str) -> pathlib.Path:
"""
Find the pre-trained file in the given root.
The assumption is the following:
Given a root ``/models/cifar`` and a name ``resnet-50``, then the name of the
pre-trained file will contain ``cifar_resnet-50_`` as a prefix. If no file is
found, an exception is raised.
Args:
root: the root where the file is stored.
net_name: the save name of the file.
Returns:
The path to the file.
Raises:
RuntimeError: if no pre-trained network was found.
"""
for path in root.glob("*.pth"):
file_name = str(path.stem)
base_name = f"{str(root.stem)}_{name}_"
if base_name in file_name:
return path
raise RuntimeError(
f"error: unable to find a pretrained network named {name} at {str(root)}"
)