# Source code for vision3d.tensors._wrap

from typing import Any, cast

from torch import Tensor
from torchvision.tv_tensors import TVTensor

# Mirrors torchvision 0.28's dispatch
# (https://github.com/pytorch/vision/pull/9490). Required while vision3d is
# compatible with torchvision 0.25-0.27, whose wrap() hardcodes its own
# tv_tensor types. Replace with ``from torchvision.tv_tensors import wrap``
# once vision3d requires ``torchvision>=0.28``.


[docs] def wrap[T: TVTensor]( wrappee: Tensor, *, like: T, **kwargs: Any, ) -> T: """Convert a :class:`~torch.Tensor` into the same :class:`~torchvision.tv_tensors.TVTensor` subclass as ``like``. If ``like`` carries metadata (e.g. ``format``, ``image_size``), that metadata is copied to the output. Individual fields can be overridden via ``kwargs``. Subclass authors can define a ``wrap`` classmethod on their subclass of :class:`~torchvision.tv_tensors.TVTensor` to control how metadata propagates. Args: wrappee (:class:`~torch.Tensor`): The tensor to convert. like (:class:`~torchvision.tv_tensors.TVTensor`): The reference. ``wrappee`` will be converted into the same subclass as ``like``. kwargs: Metadata overrides forwarded to the subclass's ``wrap`` classmethod. Returns: :class:`~torchvision.tv_tensors.TVTensor`: A TVTensor of the same subclass as ``like``. """ if (wrap_method := getattr(type(like), "wrap", None)) is not None: return cast("T", wrap_method(wrappee, like, **kwargs)) return wrappee.as_subclass(type(like))