Source code for vision3d.transforms.functional._registry
"""Kernel dispatch for vision3d transforms.
Minimal reimplementation of torchvision's kernel registry, since the public
``register_kernel`` only allows registering kernels for torchvision's own
functionals.
"""
import functools
from collections.abc import Callable
from typing import Any
from torch import Tensor
from torchvision.tv_tensors import TVTensor
# {functional: {input_type: kernel}}
_KERNEL_REGISTRY: dict[Callable[..., Any], dict[type, Callable[..., Any]]] = {}
[docs]
def register_kernel(
    functional: Callable[..., Any],
    tv_tensor_cls: type[TVTensor],
    *,
    tv_tensor_wrapper: bool = True,
) -> Callable[[Callable[..., Any]], Callable[..., Any]]:
    """Register a kernel for a functional and TVTensor type.

    Args:
        functional: The functional to register a kernel for.
        tv_tensor_cls: The TVTensor subclass this kernel handles.
        tv_tensor_wrapper: If True (default), the kernel receives an unwrapped
            pure tensor and the output is automatically re-wrapped. If False,
            the kernel receives the full TVTensor and must handle wrapping
            itself. The unwrap/re-wrap shim is only applied when
            ``tv_tensor_cls`` is a :class:`TVTensor` subclass; a kernel
            registered for any other type is stored as-is, since there is
            nothing to unwrap or re-wrap.

    Returns:
        Decorator that registers the kernel and returns the kernel itself
        (not the wrapper), so the decorated function can still be called
        directly.

    Raises:
        ValueError: Raised by the returned decorator if ``functional``
            already has a kernel registered for ``tv_tensor_cls``.
    """
    # Created eagerly: the per-functional table exists as soon as a
    # registration is requested, even before the decorator is applied.
    registry = _KERNEL_REGISTRY.setdefault(functional, {})

    # Mirror torchvision's internal registry: only TVTensor subclasses get
    # the unwrap/re-wrap shim. The isinstance check keeps issubclass from
    # raising TypeError on a non-class argument.
    wrap_output = (
        tv_tensor_wrapper
        and isinstance(tv_tensor_cls, type)
        and issubclass(tv_tensor_cls, TVTensor)
    )

    def decorator(kernel: Callable[..., Any]) -> Callable[..., Any]:
        if tv_tensor_cls in registry:
            # Callables such as functools.partial have no __name__; fall back
            # to repr so the intended ValueError is raised, not AttributeError.
            functional_name = getattr(functional, "__name__", repr(functional))
            cls_name = getattr(tv_tensor_cls, "__name__", repr(tv_tensor_cls))
            msg = (
                f"{functional_name} already has a kernel "
                f"registered for {cls_name}."
            )
            raise ValueError(msg)

        if wrap_output:

            @functools.wraps(kernel)
            def wrapper(inpt: TVTensor, *args: Any, **kwargs: Any) -> TVTensor:
                # Imported at call time rather than module import time;
                # presumably to avoid a circular import with vision3d.tensors
                # (NOTE(review): confirm).
                from vision3d.tensors import wrap

                # Run the kernel on a plain Tensor view of the input, then
                # re-wrap the result so it takes the input's type via ``wrap``.
                output = kernel(inpt.as_subclass(Tensor), *args, **kwargs)
                return wrap(output, like=inpt)

            registry[tv_tensor_cls] = wrapper
        else:
            registry[tv_tensor_cls] = kernel
        return kernel

    return decorator
def _get_kernel(
    functional: Callable[..., Any],
    input_type: type,
    *,
    allow_passthrough: bool = False,
) -> Callable[..., Any]:
    """Look up the registered kernel for a functional and input type.

    The lookup walks ``input_type.__mro__``, so a subclass of a registered
    type dispatches to that type's kernel. The walk stops at
    :class:`TVTensor`, so TVTensor subclasses never fall through to kernels
    registered for classes above it in the MRO.

    Args:
        functional: The functional to look up.
        input_type: The type of the input.
        allow_passthrough: If True, return an identity kernel when the
            functional has no registered kernel for ``input_type``. If
            False (default), raise :class:`TypeError`.

    Returns:
        The kernel function, or an identity lambda when
        ``allow_passthrough`` is True and no kernel is registered for
        ``input_type``.

    Raises:
        ValueError: If the functional has no kernels registered at all
            (regardless of ``allow_passthrough``).
        TypeError: If the functional has kernels but none for
            ``input_type`` and ``allow_passthrough`` is False.
    """
    registry = _KERNEL_REGISTRY.get(functional)
    # ``None`` (never registered) and ``{}`` (created but never filled) are
    # both treated as "no kernels".
    if not registry:
        # Safe name lookup: callables like functools.partial lack __name__.
        functional_name = getattr(functional, "__name__", repr(functional))
        msg = f"No kernel registered for functional `{functional_name}`."
        raise ValueError(msg)

    for cls in input_type.__mro__:
        if cls in registry:
            return registry[cls]
        if cls is TVTensor:
            # Do not continue past TVTensor: TVTensor subclasses must not
            # dispatch to kernels registered for its bases (e.g. torch.Tensor).
            break

    if allow_passthrough:
        return lambda inpt, *args, **kwargs: inpt

    functional_name = getattr(functional, "__name__", repr(functional))
    msg = (
        f"Functional `{functional_name}` supports inputs of type "
        f"{sorted(c.__name__ for c in registry)}, but got "
        f"`{input_type.__name__}` instead."
    )
    raise TypeError(msg)