helios.core.utils

Attributes

Classes

ChdirContext

Allow switching between the current working directory and another within a scope.

AverageTimer

Compute elapsed times using a moving average.

Registry

Provide a name-to-object mapping that allows users to create custom types.

Functions

get_env_info_str(→ str)

Return a string with the Helios header and the environment information.

get_from_optional(→ T)

Ensure the given variable is not None and return it.

convert_to_list(→ list[T])

Convert the input into a list if it's not one already.

update_all_registries(→ None)

Ensure all registered types get added to their corresponding registries.

safe_torch_load(→ Any)

Wrap torch.load to handle safe loading.

Module Contents

helios.core.utils.T
helios.core.utils.T_Any
helios.core.utils.get_env_info_str() → str

Return a string with the Helios header and the environment information.

Returns:

The message string.

helios.core.utils.get_from_optional(opt_var: T | None, raise_on_empty: bool = False) → T

Ensure the given variable is not None and return it.

This is useful for variables that may be None at declaration but are set elsewhere. In such cases, mypy cannot determine that the variable has been set, so it issues a warning. The usual workaround is to add asserts, but that quickly becomes tedious. This function can be used as an alternative.

Example

from helios import core

var: int | None = None
# ... Set var to a valid value some place else.

assert var is not None
v = var

# Alternatively:
v = core.get_from_optional(var)

Parameters:
  • opt_var – the optional variable.

  • raise_on_empty – if True, an exception is raised when the optional is None.

Returns:

The variable without the optional.

Raises:

RuntimeError – if opt_var is None and raise_on_empty is True.
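A minimal sketch of the contract described above, written as a hypothetical stand-in (`get_from_optional_sketch`) rather than the Helios implementation. The documented parts are returning the value and raising RuntimeError when `raise_on_empty` is set; falling back to an `assert` otherwise is an assumption. Because the return type is `T` rather than `T | None`, mypy treats the result as non-optional at the call site.

```python
from __future__ import annotations

from typing import TypeVar

T = TypeVar("T")


def get_from_optional_sketch(opt_var: T | None, raise_on_empty: bool = False) -> T:
    """Hypothetical re-implementation of the documented contract."""
    if opt_var is None:
        if raise_on_empty:
            # Documented behaviour: RuntimeError when raise_on_empty is True.
            raise RuntimeError("optional variable is None")
        # Assumption: otherwise fall back to an assert (stripped under python -O).
        assert opt_var is not None
    # After the None check, type checkers narrow opt_var from T | None to T.
    return opt_var
```

Only `None` is rejected: falsy values such as `0` or `""` pass through unchanged.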

helios.core.utils.convert_to_list(var: T | list[T] | tuple[T, ...]) → list[T]

Convert the input into a list if it’s not one already.

Example

from helios.core.utils import convert_to_list

def some_fun(x: int | list[int]) -> None:
    if not isinstance(x, list):
        x = [x]
    for elem in x:
        ...

    # The above code can be replaced with this:
    for elem in convert_to_list(x):
        ...

Parameters:

var – a single object, a list, or a tuple.

Returns:

If the input is a list, it is returned unchanged. Otherwise, the object is converted to a list and returned.
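The conversion rules can be sketched as follows. This is a hypothetical helper (`convert_to_list_sketch`), not the library's code; in particular, expanding tuples element-wise is an assumption based on the `tuple[T, ...]` in the signature.

```python
from __future__ import annotations

from typing import TypeVar

T = TypeVar("T")


def convert_to_list_sketch(var: T | list[T] | tuple[T, ...]) -> list[T]:
    """Hypothetical sketch of convert_to_list."""
    if isinstance(var, list):
        # Lists are returned as-is (same object, no copy).
        return var
    if isinstance(var, tuple):
        # Assumption: tuples are expanded element-wise.
        return list(var)
    # Any other object is wrapped in a single-element list.
    return [var]
```

A string is treated as a single object and wrapped, not split into characters.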

class helios.core.utils.ChdirContext(target_path: pathlib.Path)

Allow switching between the current working directory and another within a scope.

The intention is to facilitate temporary switches of the current working directory (for example, when resolving relative paths) by creating a context in which the working directory is automatically switched to a new one. Upon exiting the context, the original working directory is restored.

Example

import pathlib

from helios.core.utils import ChdirContext

pathlib.Path.cwd()  # <- Starting working directory
with ChdirContext(pathlib.Path("/new/path")) as prev_cwd:
    # prev_cwd is the starting working directory
    pathlib.Path.cwd()  # <- This is /new/path now
    ...
pathlib.Path.cwd()  # <- Back to the starting working directory.

Parameters:

target_path – the path to switch to.

start_path
target_path
__enter__() → pathlib.Path

Perform the switch from the current working directory to the new one.

Returns:

The previous working directory.

__exit__(exc_type: type[Exception] | None, exc_value: Exception | None, exc_traceback: types.TracebackType | None) → None

Restore the previous working directory.
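A minimal sketch of how such a context manager can be built on `os.chdir`, using a hypothetical `ChdirContextSketch` class. Re-capturing the starting directory in `__enter__` is an assumption; the attribute names mirror the `start_path` and `target_path` attributes listed above.

```python
from __future__ import annotations

import os
import pathlib
import types


class ChdirContextSketch:
    """Hypothetical stand-in for ChdirContext; not the Helios implementation."""

    def __init__(self, target_path: pathlib.Path) -> None:
        self.target_path = target_path
        self.start_path = pathlib.Path.cwd()

    def __enter__(self) -> pathlib.Path:
        # Assumption: re-read the cwd on entry so we restore where we actually were.
        self.start_path = pathlib.Path.cwd()
        os.chdir(self.target_path)
        return self.start_path

    def __exit__(
        self,
        exc_type: type[BaseException] | None,
        exc_value: BaseException | None,
        exc_traceback: types.TracebackType | None,
    ) -> None:
        # Runs both on normal exit and when an exception leaves the with-block.
        os.chdir(self.start_path)
```

Because `__exit__` returns `None`, an exception raised inside the `with` block still propagates after the original directory is restored.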

class helios.core.utils.AverageTimer(sliding_window: int = 200)

Compute elapsed times using a moving average.

The timer determines the elapsed time between a series of points using a sliding-window moving average.

Parameters:

sliding_window – number of steps over which the moving average will be computed.

start() → None

Start the timer.

record() → None

Record a new step in the timer.

get_average_time() → float

Return the moving average over the current step count.
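One way to implement a sliding-window moving average is sketched below with a hypothetical `AverageTimerSketch`. It assumes each `record()` call measures the interval since the previous `start()` or `record()`, and that `start()` resets the window; neither detail is confirmed by the reference above.

```python
from __future__ import annotations

import collections
import time


class AverageTimerSketch:
    """Hypothetical sliding-window timer; not the Helios implementation."""

    def __init__(self, sliding_window: int = 200) -> None:
        # deque(maxlen=...) drops the oldest interval once the window is full.
        self._intervals: collections.deque[float] = collections.deque(maxlen=sliding_window)
        self._last: float | None = None

    def start(self) -> None:
        # Assumption: starting clears previous measurements.
        self._intervals.clear()
        self._last = time.perf_counter()

    def record(self) -> None:
        now = time.perf_counter()
        if self._last is not None:
            # Store the elapsed time since the previous point.
            self._intervals.append(now - self._last)
        self._last = now

    def get_average_time(self) -> float:
        if not self._intervals:
            return 0.0
        return sum(self._intervals) / len(self._intervals)
```

`collections.deque` with `maxlen` discards the oldest interval automatically, so memory stays bounded by `sliding_window`.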

class helios.core.utils.Registry(name: str)

Provide a name-to-object mapping that allows users to create custom types.

Example

from helios.core.utils import Registry

# Create a registry:
TEST_REGISTRY = Registry("test")

# Register as a decorator:
@TEST_REGISTRY.register
class TestClass:
    ...

# Register in code:
class OtherClass:
    ...

def test_function() -> None:
    ...

TEST_REGISTRY.register(OtherClass)
TEST_REGISTRY.register(test_function)

Parameters:

name – the name of the registry.

register(obj: T_Any, suffix: str | None = None) → T_Any

Register the given object.

Parameters:
  • obj – the type to add. Must have a __name__ attribute.

  • suffix – (optional) the suffix to add to the type name.

Returns:

The registered type.

get(name: str, suffix: str | None = None) → Any

Get the object that corresponds to the given name.

Parameters:
  • name – the name of the type.

  • suffix – (optional) the suffix to use if the type isn’t found with the given name.

Returns:

The requested type.

Raises:

KeyError – if no object with the given name is found in the registry.

__contains__(name: str) → bool

Check if the registry contains the given name.

Parameters:

name – the name to check.

Returns:

True if the name exists, False otherwise.

__iter__() → Iterable

Get an iterable over the registry items.

__str__() → str

Get the name of the registry.

keys() → Iterable

Return a set-like object providing a view into the registry’s keys.

Returns:

An iterable of the registry keys.
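The registry behaviour can be sketched as a thin wrapper around a `dict`. `RegistrySketch` is hypothetical; how the suffix is combined with `__name__`, rejecting duplicate names, and `__iter__` yielding `(name, object)` pairs are all assumptions.

```python
from __future__ import annotations

from collections.abc import Iterator, KeysView
from typing import Any, TypeVar

T_Any = TypeVar("T_Any")


class RegistrySketch:
    """Hypothetical name-to-object registry; not the Helios implementation."""

    def __init__(self, name: str) -> None:
        self._name = name
        self._obj_map: dict[str, Any] = {}

    def register(self, obj: T_Any, suffix: str | None = None) -> T_Any:
        name = obj.__name__  # type: ignore[attr-defined]
        if suffix is not None:
            # Assumption: the suffix is appended to the object's __name__.
            name += suffix
        if name in self._obj_map:
            # Assumption: duplicate names are rejected rather than overwritten.
            raise KeyError(f"'{name}' is already registered in '{self._name}'")
        self._obj_map[name] = obj
        # Returning obj is what lets register double as a decorator.
        return obj

    def get(self, name: str, suffix: str | None = None) -> Any:
        if name in self._obj_map:
            return self._obj_map[name]
        if suffix is not None and name + suffix in self._obj_map:
            # Fall back to the suffixed name when the plain name is missing.
            return self._obj_map[name + suffix]
        raise KeyError(f"no object named '{name}' in registry '{self._name}'")

    def __contains__(self, name: str) -> bool:
        return name in self._obj_map

    def __iter__(self) -> Iterator[tuple[str, Any]]:
        return iter(self._obj_map.items())

    def __str__(self) -> str:
        return self._name

    def keys(self) -> KeysView[str]:
        return self._obj_map.keys()
```

Returning `obj` from `register` is what allows the same method to be used both as a decorator and as a plain call.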

helios.core.utils.update_all_registries(root: pathlib.Path, recurse: bool = True, import_prefix: str = '') → None

Ensure all registered types get added to their corresponding registries.

This function automatically registers all types within a package into their corresponding registries. Normally, you would have to manually import each module that contains a registered type to ensure that it gets registered. This can easily cascade when modules are nested inside packages, since the top-level module then has to (somehow) import every child module for all registrations to take effect.

This function offers an alternative, whereby it will automatically scan all modules and sub-packages within a given package and import only those files that register a type. To do this, there are a few assumptions:

  1. Each package MUST contain an __init__.py (namespace packages are not supported)

  2. A module is included if and only if there is at least one line that contains the following pattern: @<any non-whitespace character(s)>.register.

Example

Suppose we have a project with the following structure:

main.py
my_package/
|---__init__.py
|---some_class.py <- This registers a type.
|---some_funcs.py <- Doesn't register anything.
|---sub_package/
|   |---__init__.py
|   |---another_type.py <- Registers
|   |---another_func.py <- Doesn't register.

We can then do the following inside main.py:

import pathlib

import helios.core as hlc
...
hlc.update_all_registries(pathlib.Path.cwd() / "my_package", recurse=True)

The function will recursively walk through my_package and import the following:

  • my_package.some_class

  • my_package.sub_package.another_type

After the function returns, the corresponding registries will have been populated with the types and they can be used elsewhere in the code.

Parameters:
  • root – the path to the root package.

  • recurse – if True, recursively search through sub-packages. Defaults to True.

  • import_prefix – (optional) prefix to prepend to the module names when importing. Defaults to an empty string.

Raises:

RuntimeError – if the given path isn’t a valid directory or if the directory is not a Python package with an __init__.py.
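The discovery step can be sketched as a directory walk plus a per-line regular-expression check for the `@<...>.register` pattern. `find_registering_modules` and `import_registering_modules` are hypothetical helpers, not the Helios implementation, and treating `import_prefix` as a literal dotted prefix (for example `"src."`) is an assumption.

```python
from __future__ import annotations

import importlib
import pathlib
import re

# "@", one or more non-whitespace characters, then ".register".
_REGISTER_PATTERN = re.compile(r"@\S+\.register")


def find_registering_modules(
    root: pathlib.Path, recurse: bool = True, import_prefix: str = ""
) -> list[str]:
    """Hypothetical helper: dotted names of modules that use @<...>.register."""
    if not root.is_dir():
        raise RuntimeError(f"{root} is not a valid directory")
    if not (root / "__init__.py").exists():
        raise RuntimeError(f"{root} is not a Python package (missing __init__.py)")

    package = import_prefix + root.name
    modules: list[str] = []
    for path in sorted(root.iterdir()):
        if path.is_file() and path.suffix == ".py" and path.name != "__init__.py":
            text = path.read_text(encoding="utf-8")
            if any(_REGISTER_PATTERN.search(line) for line in text.splitlines()):
                modules.append(f"{package}.{path.stem}")
        elif recurse and path.is_dir() and (path / "__init__.py").exists():
            # Directories without __init__.py are skipped (no namespace packages).
            modules.extend(
                find_registering_modules(path, recurse, import_prefix=package + ".")
            )
    return modules


def import_registering_modules(root: pathlib.Path, recurse: bool = True) -> None:
    # Importing each module runs its decorators, which fills the registries.
    for name in find_registering_modules(root, recurse):
        importlib.import_module(name)
```

Importing by dotted name only works if the parent directory of `root` is on `sys.path`, which is typically the case when `main.py` sits next to `my_package`.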

helios.core.utils.safe_torch_load(f: str | os.PathLike | BinaryIO | IO[bytes], **kwargs: Any) → Any

Wrap torch.load to handle safe loading.

This function automatically sets weights_only to True when calling torch.load. You are encouraged to use it instead of plain torch.load to ensure safe loading.

Warning

weights_only is set automatically by this function. Do not set it yourself when using this function.

Parameters:
  • f – a file-like object (has to implement read(), readline(), tell(), and seek()), or a string or os.PathLike object containing a file name.

  • **kwargs – keyword arguments to pass to torch.load.

Returns:

The result of calling torch.load.
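A minimal sketch of such a wrapper, using a hypothetical `safe_torch_load_sketch`. Raising `ValueError` when the caller passes `weights_only` is an assumption; the documentation above only says not to set it.

```python
from __future__ import annotations

import os
from typing import IO, Any, BinaryIO


def safe_torch_load_sketch(f: str | os.PathLike | BinaryIO | IO[bytes], **kwargs: Any) -> Any:
    """Hypothetical wrapper that always forces weights_only=True."""
    if "weights_only" in kwargs:
        # Assumption: reject the argument explicitly instead of letting it clash.
        raise ValueError("weights_only is set automatically; do not pass it")
    # torch is imported here so the argument check above runs without it.
    import torch

    # weights_only=True restricts unpickling to tensors and other allow-listed types.
    return torch.load(f, weights_only=True, **kwargs)
```

Other `torch.load` arguments, such as `map_location="cpu"`, are forwarded unchanged.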