{
  "cells": [
    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "\n# Augmenting samples with vision3d transforms\n\nThis example showcases :mod:`vision3d.transforms` on the nuScenes dataset.\n\nTransforms automatically dispatch on the tensor types from\n:mod:`vision3d.tensors` carried by the sample, so a geometric transform\nlike :class:`~vision3d.transforms.RandomRotate3D` updates points, boxes,\nand extrinsics together without any special handling.\n\nThe transforms in :mod:`vision3d.transforms` are input- and\ndataset-agnostic by design. They support every\n:class:`~vision3d.tensors.BoundingBox3DFormat`, run on lidar-only,\ncamera-only, and fusion samples, and make no assumptions about scene\ncomposition such as camera count, sensor layout, or axis convention.\n"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "## Construct the dataset\nGrab a single ``(inputs, targets)`` sample to act as the baseline that\nevery transform is applied to.\n\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "collapsed": false
      },
      "outputs": [],
      "source": [
        "from pathlib import Path\n\nimport torch\n\nfrom vision3d.datasets import NuScenes3D\n\nNUSCENES_ROOT = Path(\"~/.cache/vision3d/nuscenes-mini\").expanduser()\nFRAME_INDEX = 100\n\ntorch.manual_seed(42)\n\ndataset = NuScenes3D(NUSCENES_ROOT, version=\"v1.0-mini\", split=\"train\", download=True)\ninputs, targets = dataset[FRAME_INDEX]\nprint(f\"num boxes: {targets['boxes'].shape[0]}\")"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "## Visualize the baseline sample\nRender the original sample in an embedded Rerun viewer to use as\na reference for the transformed scenes shown later in this example.\n\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "collapsed": false
      },
      "outputs": [],
      "source": [
        "import rerun as rr\nimport rerun.blueprint as rrb\n\nfrom vision3d.viz import fusion_layout, log_sample\n\nlabel_to_id = dataset.class_to_idx\n\nrr.init(\"vision3d_original\", spawn=True)\nrr.send_blueprint(\n    rrb.Blueprint(\n        fusion_layout(\n            NuScenes3D.camera_names,\n            NuScenes3D.camera_grid,\n            entity_prefix=\"original\",\n            name=\"Original\",\n        ),\n        rrb.TimePanel(state=\"hidden\"),\n    )\n)\nrr.log(\"original\", rr.ViewCoordinates.RIGHT_HAND_Z_UP, static=True)\nlog_sample(\n    inputs,\n    targets,\n    entity_prefix=\"original\",\n    label_to_id=label_to_id,\n    jpeg_quality=75,\n)"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "## Apply a single transform\nEvery transform is a :class:`torchvision.transforms.v2.Transform` that\naccepts ``(inputs, targets)`` and returns the transformed pair. The\n:class:`~vision3d.tensors.PointCloud3D`,\n:class:`~vision3d.tensors.CameraImages`, and\n:class:`~vision3d.tensors.BoundingBoxes3D` semantic tensor types\nsteer dispatch, so a single call updates geometry, imagery, and box\nannotations consistently with one another.\n\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "collapsed": false
      },
      "outputs": [],
      "source": [
        "import math\n\nfrom vision3d.transforms import RandomRotate3D\n\nrotate = RandomRotate3D(angle_range=math.pi / 4, p=1.0)\nr_inputs, r_targets = rotate(inputs, targets)\nprint(f\"rotated boxes shape: {tuple(r_targets['boxes'].shape)}\")"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "## Geometric safety\nvision3d transforms mirror the torchvision v2 dispatch model: each\ntransform declares the input types it operates on via the class-level\n``_transformed_types`` tuple, and any input whose type is not listed\npasses through unchanged. Transforms whose operation would only be\ncorrect on a subset of scene types additionally override\n:meth:`~vision3d.transforms.Transform.check_inputs` to raise\n:class:`TypeError` for input combinations they cannot handle. Together\nthese guard against silently producing geometrically inconsistent\nscenes (e.g. flipping the lidar but not the camera image alongside\nit).\n\nFor example, :class:`~vision3d.transforms.RandomFlip3D` operates on\n:class:`~vision3d.tensors.PointCloud3D` and\n:class:`~vision3d.tensors.BoundingBoxes3D`, and its ``check_inputs``\nrefuses samples that also carry camera tensors (images, extrinsics,\nintrinsics): flipping the 3D scene without coordinated changes to the\ncamera side would break geometric consistency. Running it on a fusion\ndataset sample therefore raises a :class:`TypeError`.\n\nThe error signals that the transform is not compatible with a fusion\npipeline. :class:`~vision3d.transforms.RandomFlip3D` is intended for\nlidar-only training pipelines, where there are no camera tensors to\nfall out of correspondence.\n\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "collapsed": false
      },
      "outputs": [],
      "source": [
        "from vision3d.transforms import RandomFlip3D\n\nflip = RandomFlip3D(axis=\"x\", p=1.0)\ntry:\n    flip(inputs, targets)\nexcept TypeError as e:\n    print(e)"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "## Composing transforms\nvision3d transforms are designed to be chained into data-augmentation\npipelines for training. The standard\n:class:`torchvision.transforms.v2.Compose` can run 3D transforms\nalongside any tensor-aware torchvision image transform. torchvision\nimage transforms see only the :class:`~vision3d.tensors.CameraImages`\ntensor and leave 3D geometry untouched, so mixing them is safe.\n\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "collapsed": false
      },
      "outputs": [],
      "source": [
        "from torchvision.transforms import v2\n\nfrom vision3d.transforms import (\n    PointJitter,\n    RandomScale3D,\n    RandomTranslate3D,\n    RangeFilter3D,\n)\n\ncompose = v2.Compose(\n    [\n        RandomRotate3D(angle_range=math.pi / 4, p=1.0),\n        RandomScale3D(scale_range=(0.7, 1.3), p=1.0),\n        RandomTranslate3D(translation_range=5.0, p=1.0),\n        PointJitter(sigma=0.1, p=1.0),\n        RangeFilter3D(point_cloud_range=(-50, -50, -5, 50, 50, 3)),\n        v2.Resize(size=[450, 800]),\n        v2.CenterCrop(size=[400, 700]),\n        v2.ColorJitter(brightness=0.6, contrast=0.6, saturation=0.6, hue=0.3),\n    ]\n)\n\nc_inputs, c_targets = compose(inputs, targets)\nprint(f\"composed points: {tuple(c_inputs['points'].shape)}\")\nprint(f\"composed images: {tuple(c_inputs['images'].shape)}\")\nprint(f\"composed boxes:  {tuple(c_targets['boxes'].shape)}\")"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "## Cross-sample augmentation with CopyPaste3D\n:class:`~vision3d.transforms.CopyPaste3D` is an advanced augmentation\nmethod based on the ground-truth sampling technique introduced in\n[SECOND](https://www.mdpi.com/1424-8220/18/10/3337). It improves\nscene diversity by injecting instances from other scenes into the\ncurrent one. Unlike single-sample transforms,\n:class:`~vision3d.transforms.CopyPaste3D` operates on collated batches\nand reads from an internal object database that grows lazily with each\nbatch it sees.\n\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "collapsed": false
      },
      "outputs": [],
      "source": [
        "from torch.utils.data import DataLoader, Subset\n\nfrom vision3d.datasets import collate_fn\nfrom vision3d.transforms import CopyPaste3D\n\ntarget_counts = {\n    dataset.class_to_idx[\"car\"]: 30,\n    dataset.class_to_idx[\"pedestrian\"]: 20,\n    dataset.class_to_idx[\"traffic_cone\"]: 15,\n}\ncopy_paste = CopyPaste3D(target_counts=target_counts, min_points=5)\n\ndataset_range = list(range(max(0, FRAME_INDEX - 10), FRAME_INDEX))\ndataset_loader = DataLoader(\n    Subset(dataset, dataset_range),\n    batch_size=2,\n    collate_fn=collate_fn,\n)\n# Two passes over the preceding frames grow the internal object database.\nfor _ in range(2):\n    for batch_inputs, batch_targets in dataset_loader:\n        copy_paste(batch_inputs, batch_targets)\n\n# Augment the baseline sample, wrapped as a batch of one.\ncp_inputs, cp_targets = copy_paste((inputs,), (targets,))\nprint(f\"boxes before: {targets['boxes'].shape[0]}\")\nprint(f\"boxes after CopyPaste3D:  {cp_targets[0]['boxes'].shape[0]}\")\n\nrr.init(\"vision3d_copy_paste\", spawn=True)\nrr.send_blueprint(\n    rrb.Blueprint(\n        fusion_layout(\n            NuScenes3D.camera_names,\n            NuScenes3D.camera_grid,\n            entity_prefix=\"copy_paste\",\n            name=\"CopyPaste3D\",\n        ),\n        rrb.TimePanel(state=\"hidden\"),\n    )\n)\nrr.log(\"copy_paste\", rr.ViewCoordinates.RIGHT_HAND_Z_UP, static=True)\nlog_sample(\n    cp_inputs[0],\n    cp_targets[0],\n    entity_prefix=\"copy_paste\",\n    label_to_id=label_to_id,\n    jpeg_quality=75,\n)"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "## Transforms showcase\nView every transform in the embedded Rerun viewer, each on its own\ntab. Compare each tab against the baseline viewer at the top of the\npage.\n\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "collapsed": false
      },
      "outputs": [],
      "source": [
        "from vision3d.transforms import PointSample, PointShuffle\n\ntransforms = [\n    (\n        \"translate\",\n        \"RandomTranslate3D(5.0)\",\n        RandomTranslate3D(translation_range=5.0, p=1.0),\n    ),\n    (\"rotate\", \"RandomRotate3D(pi/4)\", RandomRotate3D(angle_range=math.pi / 4, p=1.0)),\n    (\n        \"scale\",\n        \"RandomScale3D(0.25, 4.0)\",\n        RandomScale3D(scale_range=(0.25, 4.0), p=1.0),\n    ),\n    (\n        \"color_jitter\",\n        \"ColorJitter\",\n        v2.ColorJitter(brightness=0.8, contrast=0.8, saturation=0.8, hue=0.4),\n    ),\n    (\"gaussian_blur\", \"GaussianBlur\", v2.GaussianBlur(kernel_size=31, sigma=10.0)),\n    (\"solarize\", \"Solarize\", v2.RandomSolarize(threshold=0.5, p=1.0)),\n    (\"resize_half\", \"Resize(half)\", v2.Resize(size=[450, 800])),\n    (\"center_crop\", \"CenterCrop()\", v2.CenterCrop(size=[600, 800])),\n    (\"pad\", \"Pad(100)\", v2.Pad(padding=100)),\n    (\"point_shuffle\", \"PointShuffle\", PointShuffle(p=1.0)),\n    (\"point_sample\", \"PointSample(4096)\", PointSample(n=4096)),\n    (\"point_jitter\", \"PointJitter(sigma=0.1)\", PointJitter(sigma=0.1, p=1.0)),\n    (\n        \"range_filter\",\n        \"RangeFilter3D()\",\n        RangeFilter3D(point_cloud_range=(-30, -30, -5, 30, 30, 3)),\n    ),\n    (\"compose\", \"Compose\", compose),\n]\n\nrr.init(\"vision3d_transforms\", spawn=True)\nrr.send_blueprint(\n    rrb.Blueprint(\n        rrb.Tabs(\n            *(\n                fusion_layout(\n                    NuScenes3D.camera_names,\n                    NuScenes3D.camera_grid,\n                    entity_prefix=prefix,\n                    name=name,\n                )\n                for prefix, name, _ in transforms\n            )\n        ),\n        rrb.TimePanel(state=\"hidden\"),\n    )\n)\n\nfor prefix, name, pipeline in transforms:\n    rr.log(prefix, rr.ViewCoordinates.RIGHT_HAND_Z_UP, static=True)\n    t_inputs, t_targets = pipeline(inputs, targets)\n    log_sample(\n        t_inputs,\n        t_targets,\n        entity_prefix=prefix,\n        label_to_id=label_to_id,\n        jpeg_quality=75,\n    )"
      ]
    }
  ],
  "metadata": {
    "kernelspec": {
      "display_name": "Python 3",
      "language": "python",
      "name": "python3"
    },
    "language_info": {
      "codemirror_mode": {
        "name": "ipython",
        "version": 3
      },
      "file_extension": ".py",
      "mimetype": "text/x-python",
      "name": "python",
      "nbconvert_exporter": "python",
      "pygments_lexer": "ipython3",
      "version": "3.14.5"
    }
  },
  "nbformat": 4,
  "nbformat_minor": 0
}