Using KITTI with vision3d

This example demonstrates how to use the KITTI 3D Object Detection dataset through vision3d.datasets.Kitti3D. It covers inspecting the (FusionInputs, SampleTargets) tuple returned by the dataset, batching with vision3d.datasets.collate_fn() for training, and visualizing frames with vision3d.viz.log_sample().

Construct the dataset

Kitti3D yields one sample per frame of the 3D scene. Each sample carries the lidar points, the left color camera image with its intrinsics and extrinsics, and 3D bounding-box annotations for the objects in the scene.

from pathlib import Path

from vision3d.datasets import Kitti3D

KITTI_ROOT = Path("~/.cache/vision3d/kitti-mini").expanduser()

dataset = Kitti3D(KITTI_ROOT, train=True, mini=True, frames=range(10), download=True)
print(f"len(dataset) = {len(dataset)}")
print(f"classes ({len(dataset.classes)}): {dataset.classes}")
len(dataset) = 10
classes (8): ('Car', 'Pedestrian', 'Cyclist', 'Van', 'Truck', 'Person_sitting', 'Tram', 'Misc')
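
To make the role of the intrinsics and extrinsics concrete, the following is a minimal sketch of a pinhole projection from a 3D point to a pixel. It is not vision3d code: project_point is a hypothetical helper, the matrices are made-up values rather than KITTI calibration, and it assumes the extrinsics map the point's frame into the camera frame. Plain Python lists keep it dependency-free.

```python
def project_point(point_xyz, extrinsics, intrinsics):
    """Project one 3D point to pixel coordinates with a pinhole model.

    Assumption: `extrinsics` is a 4x4 matrix mapping the point's frame into
    the camera frame, and `intrinsics` is the 3x3 camera matrix K.
    Returns (u, v), or None when the point is not in front of the camera.
    """
    x, y, z = point_xyz
    homogeneous = (x, y, z, 1.0)
    # Camera-frame coordinates: first three rows of extrinsics @ [x, y, z, 1].
    cam = [sum(row[k] * homogeneous[k] for k in range(4)) for row in extrinsics[:3]]
    if cam[2] <= 0.0:
        return None
    # Apply K, then divide by the homogeneous coordinate (the depth).
    u, v, w = (sum(row[k] * cam[k] for k in range(3)) for row in intrinsics)
    return u / w, v / w


# Illustrative values only, not real KITTI calibration.
K = [[700.0, 0.0, 600.0], [0.0, 700.0, 180.0], [0.0, 0.0, 1.0]]
IDENTITY = [[1.0, 0.0, 0.0, 0.0],
            [0.0, 1.0, 0.0, 0.0],
            [0.0, 0.0, 1.0, 0.0],
            [0.0, 0.0, 0.0, 1.0]]

pixel = project_point((1.0, 0.5, 10.0), IDENTITY, K)
print(pixel)  # (670.0, 215.0)
```

With real data, the leading dimension of 1 on the intrinsics and extrinsics would be indexed away first, and the whole point cloud would be projected in one matrix product rather than point by point.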

Inspect a sample

Indexing the dataset returns an (inputs, targets) tuple, where inputs is a FusionInputs dict and targets is a SampleTargets dict. Most values are semantic tensor types from vision3d.tensors (PointCloud3D, CameraImages, BoundingBoxes3D, …), so vision3d.transforms can dispatch to the right operation for each input.

inputs, targets = dataset[0]

print("inputs:")
print(
    f"  points: type={type(inputs['points']).__name__} "
    f"shape={tuple(inputs['points'].shape)} dtype={inputs['points'].dtype}"
)
print(
    f"  images: type={type(inputs['images']).__name__} "
    f"shape={tuple(inputs['images'].shape)} dtype={inputs['images'].dtype}"
)
print(
    f"  intrinsics: type={type(inputs['intrinsics']).__name__} "
    f"shape={tuple(inputs['intrinsics'].shape)} dtype={inputs['intrinsics'].dtype}"
)
print(
    f"  extrinsics: type={type(inputs['extrinsics']).__name__} "
    f"shape={tuple(inputs['extrinsics'].shape)} dtype={inputs['extrinsics'].dtype}"
)

assert targets is not None  # train split always has targets
print("targets:")
print(
    f"  boxes: type={type(targets['boxes']).__name__} "
    f"shape={tuple(targets['boxes'].shape)} dtype={targets['boxes'].dtype} "
    f"format={targets['boxes'].format.name}"
)
print(
    f"  labels: type={type(targets['labels']).__name__} "
    f"shape={tuple(targets['labels'].shape)} dtype={targets['labels'].dtype}"
)
inputs:
  points: type=PointCloud3D shape=(20285, 4) dtype=torch.float32
  images: type=CameraImages shape=(1, 3, 370, 1224) dtype=torch.float32
  intrinsics: type=CameraIntrinsics shape=(1, 3, 3) dtype=torch.float32
  extrinsics: type=CameraExtrinsics shape=(1, 4, 4) dtype=torch.float32
targets:
  boxes: type=BoundingBoxes3D shape=(1, 7) dtype=torch.float32 format=XYZLWHY
  labels: type=Tensor shape=(1,) dtype=torch.int64
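
The idea of dispatching on semantic types can be illustrated with a small sketch built on functools.singledispatch from the standard library. This is not how vision3d.transforms is implemented: Points, Boxes, and scale are hypothetical stand-ins, and the box layout [x, y, z, l, w, h, yaw] is an assumption read off the XYZLWHY format name. The point is that the same operation has to treat each type differently.

```python
from functools import singledispatch


class Points(list):
    """Hypothetical point cloud: rows of [x, y, z, intensity]."""


class Boxes(list):
    """Hypothetical 3D boxes: rows assumed to be [x, y, z, l, w, h, yaw]."""


@singledispatch
def scale(value, factor):
    # Fallback: values with no spatial meaning (e.g. labels) pass through.
    return value


@scale.register(Points)
def _scale_points(value, factor):
    # Scale coordinates, leave the intensity channel untouched.
    return Points([[x * factor, y * factor, z * factor, i] for x, y, z, i in value])


@scale.register(Boxes)
def _scale_boxes(value, factor):
    # Scale center and size, leave the yaw angle untouched.
    return Boxes([[*(v * factor for v in row[:6]), row[6]] for row in value])


points = scale(Points([[1.0, 2.0, 3.0, 0.5]]), 2.0)
boxes = scale(Boxes([[1.0, 2.0, 3.0, 4.0, 2.0, 1.5, 0.3]]), 2.0)
labels = scale([0, 2], 2.0)
print(points)  # [[2.0, 4.0, 6.0, 0.5]]
print(boxes)   # [[2.0, 4.0, 6.0, 8.0, 4.0, 3.0, 0.3]]
print(labels)  # [0, 2]
```

Without type information, a transform could not tell that the fourth column of a point cloud is intensity while the fourth column of a box is a length.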

Batch with vision3d.datasets.collate_fn()

Variable-size tensors (point clouds, per-frame boxes) cannot be stacked along a batch dimension, so vision3d.datasets.collate_fn() leaves them unstacked. As the loop below shows, it returns a tuple of per-sample inputs and a tuple of per-sample targets, each element keyed the same as the single-sample dicts. Pass it as the collate_fn argument to DataLoader whenever you train or evaluate on a vision3d dataset.

from torch.utils.data import DataLoader

from vision3d.datasets import collate_fn

loader = DataLoader(dataset, batch_size=2, collate_fn=collate_fn)
batch_inputs, batch_targets = next(iter(loader))

print(f"batch size: {len(batch_inputs)}")
for i, (inp, tgt) in enumerate(zip(batch_inputs, batch_targets)):
    assert tgt is not None
    print(
        f"  sample {i}: "
        f"points={tuple(inp['points'].shape)} "
        f"boxes={tuple(tgt['boxes'].shape)}"
    )
batch size: 2
  sample 0: points=(20285, 4) boxes=(1, 7)
  sample 1: points=(18630, 4) boxes=(3, 7)
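
For intuition, a collate function with this behavior can be sketched in a few lines. The tuple_collate helper below is hypothetical and is not the vision3d implementation, which may do more (validation, handling of missing targets). The fake samples use plain lists in place of tensors so the sketch runs without any dependencies.

```python
def tuple_collate(batch):
    """Group (inputs, targets) pairs into two tuples without stacking.

    `batch` is a list of per-sample (inputs, targets) pairs, as a DataLoader
    passes to its collate_fn. Each sample keeps its own size.
    """
    inputs, targets = zip(*batch)
    return inputs, targets


# Two fake samples with different numbers of points and boxes.
sample_a = ({"points": [[0.0, 0.0, 0.0, 1.0]] * 3}, {"boxes": [[0.0] * 7] * 1})
sample_b = ({"points": [[0.0, 0.0, 0.0, 1.0]] * 5}, {"boxes": [[0.0] * 7] * 3})

demo_inputs, demo_targets = tuple_collate([sample_a, sample_b])
print(len(demo_inputs))               # 2
print(len(demo_inputs[1]["points"]))  # 5
print(len(demo_targets[1]["boxes"]))  # 3
```

The trade-off is that downstream code has to loop over samples, or pad and mask explicitly, instead of operating on a single batched tensor.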

Visualize the dataset

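vision3d.viz.log_sample() logs a FusionInputs / SampleTargets pair to Rerun for interactive visualization. The loop below logs every frame of the dataset on a "frame" timeline, so the viewer can step through them in order.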

import rerun as rr
import rerun.blueprint as rrb

from vision3d.viz import fusion_layout, log_sample

rr.init("vision3d_kitti", spawn=True)
rr.send_blueprint(
    rrb.Blueprint(
        fusion_layout(Kitti3D.camera_names, Kitti3D.camera_grid),
        rrb.TimePanel(state="collapsed"),
    )
)
rr.log("world", rr.ViewCoordinates.RIGHT_HAND_Z_UP, static=True)

for frame_idx in range(len(dataset)):
    f_inputs, f_targets = dataset[frame_idx]
    rr.set_time("frame", sequence=frame_idx)
    log_sample(f_inputs, f_targets, label_to_id=dataset.class_to_idx, jpeg_quality=75)
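
The label_to_id argument receives a mapping from class name to integer label. A minimal sketch of such a mapping, assuming the indices simply follow the order of the classes tuple printed earlier (the actual dataset.class_to_idx may be built differently):

```python
# Class names as printed by dataset.classes earlier in this example.
classes = ("Car", "Pedestrian", "Cyclist", "Van", "Truck", "Person_sitting", "Tram", "Misc")

# Assumption: label indices follow tuple order.
class_to_idx = {name: index for index, name in enumerate(classes)}

# The inverse mapping is handy for turning predicted labels back into names.
idx_to_class = {index: name for name, index in class_to_idx.items()}

print(class_to_idx["Cyclist"])  # 2
print(idx_to_class[7])          # Misc
```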
