
PyTorch Adapter

torchadapter

Attributes

ImageTargetPair module-attribute

ImageTargetPair: TypeAlias = tuple[Any, Target]

Type alias for a tuple of (image, target).

The image can be either a PIL Image or a torch.Tensor depending on whether transforms have been applied.

Classes

Target

Bases: TypedDict

Target dictionary containing annotation information for an image.

This TypedDict defines the structure of target data returned by the dataset adapter, following torchvision's object detection format conventions.

Attributes:

Name Type Description
boxes Tensor

Bounding boxes tensor of shape (N, 4) where N is the number of objects. Format depends on the adapter's return_format setting.

labels Tensor

Class label tensor of shape (N,) containing integer category IDs.

image_id Tensor

Image identifier tensor of shape (1,).

area Tensor

Area values tensor of shape (N,) for each bounding box.

iscrowd Tensor

Crowd flag tensor of shape (N,) indicating if object is a crowd.

Attributes
boxes instance-attribute
boxes: Tensor

Bounding boxes in the specified format.

labels instance-attribute
labels: Tensor

Class labels for each bounding box.

image_id instance-attribute
image_id: Tensor

Image identifier.

area instance-attribute
area: Tensor

Area of each bounding box.

iscrowd instance-attribute
iscrowd: Tensor

Crowd flag for each bounding box.

TorchDatasetAdapter

TorchDatasetAdapter(dataset: Dataset, transform: Callable[..., Tensor] | None = None, target_transform: Callable[[Target], Target] | None = None, return_format: Literal['xyxy', 'xywh', 'cxcywh'] = 'xyxy')

Bases: TorchDataset[ImageTargetPair]

Adapter to convert BoxLab datasets to PyTorch-compatible format.

This adapter wraps Dataset instances and provides a PyTorch Dataset interface suitable for use with DataLoader and torchvision transforms. It handles image loading, annotation formatting, and coordinate conversion.

The adapter follows torchvision's object detection conventions, making it compatible with models like Faster R-CNN, Mask R-CNN, and other detection architectures.

Parameters:

Name Type Description Default
dataset Dataset

Source BoxLab Dataset instance.

required
transform Callable[..., Tensor] | None

Optional torchvision transforms pipeline for images. Applied to PIL Images before returning.

None
target_transform Callable[[Target], Target] | None

Optional transforms for targets/annotations. Applied to the target dictionary.

None
return_format Literal['xyxy', 'xywh', 'cxcywh']

Format for bounding boxes. Options:

  • "xyxy": [x_min, y_min, x_max, y_max]
  • "xywh": [x_min, y_min, width, height]
  • "cxcywh": [center_x, center_y, width, height]

'xyxy'

Attributes:

Name Type Description
dataset

The wrapped Dataset instance.

transform

Image transformation pipeline.

target_transform

Target transformation pipeline.

return_format

Bounding box format string.

image_ids

List of image IDs for indexing.

Note

This adapter requires torch, torchvision, and pillow to be installed. Install with: pip install torch torchvision pillow

Raises:

Type Description
RequiredModuleNotFoundError

If torch, torchvision, or PIL are not installed.

Example
from boxlab.dataset import Dataset
from boxlab.dataset.torchadapter import TorchDatasetAdapter
from torchvision import transforms as T

# Create dataset
dataset = Dataset(name="my_dataset")
# ... populate dataset ...

# Create adapter with transforms
transform = T.Compose([
    T.Resize((640, 640)),
    T.ToTensor(),
])

torch_dataset = TorchDatasetAdapter(
    dataset, transform=transform, return_format="xyxy"
)

# Use with DataLoader
from torch.utils.data import DataLoader

loader = DataLoader(
    torch_dataset,
    batch_size=4,
    collate_fn=torch_dataset.collate,
)
Example
# Iterate over dataset
for image, target in torch_dataset:
    print(f"Image shape: {image.shape}")
    print(f"Boxes: {target['boxes'].shape}")
    print(f"Labels: {target['labels']}")
Source code in boxlab/dataset/torchadapter.py
def __init__(
    self,
    dataset: Dataset,
    transform: t.Callable[..., torch.Tensor] | None = None,  # type: ignore[valid-type]
    target_transform: t.Callable[[Target], Target] | None = None,
    return_format: t.Literal["xyxy", "xywh", "cxcywh"] = "xyxy",
) -> None:
    self.dataset = dataset
    self.transform = transform
    self.target_transform = target_transform
    self.return_format = return_format

    # Create index to image_id mapping
    self.image_ids = list(dataset.images.keys())

    logger.info(
        f"Created TorchDatasetAdapter with {len(self.image_ids)} images, "
        f"bbox format: {return_format}"
    )
Functions
__len__
__len__() -> int

Return the total number of samples in the dataset.

Returns:

Type Description
int

Number of images in the dataset.

Source code in boxlab/dataset/torchadapter.py
def __len__(self) -> int:
    """Return the total number of samples in the dataset.

    Returns:
        Number of images in the dataset.
    """
    return len(self.image_ids)
__getitem__
__getitem__(idx: int) -> ImageTargetPair

Get a sample by index.

Loads the image and its annotations, applies transforms, and returns them in PyTorch-compatible format.

Parameters:

Name Type Description Default
idx int

Sample index (0-based integer).

required

Returns:

Type Description
ImageTargetPair

Tuple of (image, target) where:

  • image: PIL Image or torch.Tensor (if transform applied)
  • target: Dictionary containing:
    • boxes: Tensor of shape (N, 4) with bounding boxes
    • labels: Tensor of shape (N,) with class labels (1-indexed)
    • image_id: Tensor with image identifier
    • area: Tensor of shape (N,) with box areas
    • iscrowd: Tensor of shape (N,) with crowd flags

Raises:

Type Description
DatasetError

If image is not found in dataset.

DatasetNotFoundError

If image file does not exist on disk.

Example
# Get first sample
image, target = torch_dataset[0]

# Access target components
boxes = target["boxes"]  # Shape: (N, 4)
labels = target["labels"]  # Shape: (N,)
image_id = target["image_id"]
Source code in boxlab/dataset/torchadapter.py
def __getitem__(self, idx: int) -> ImageTargetPair:
    """Get a sample by index.

    Loads the image and its annotations, applies transforms, and returns them
    in PyTorch-compatible format.

    Args:
        idx: Sample index (0-based integer).

    Returns:
        Tuple of (image, target) where:
            - image: PIL Image or torch.Tensor (if transform applied)
            - target: Dictionary containing:
                - boxes: Tensor of shape (N, 4) with bounding boxes
                - labels: Tensor of shape (N,) with class labels (1-indexed)
                - image_id: Tensor with image identifier
                - area: Tensor of shape (N,) with box areas
                - iscrowd: Tensor of shape (N,) with crowd flags

    Raises:
        DatasetError: If image is not found in dataset.
        DatasetNotFoundError: If image file does not exist on disk.

    Example:
        ```python
        # Get first sample
        image, target = torch_dataset[0]

        # Access target components
        boxes = target["boxes"]  # Shape: (N, 4)
        labels = target["labels"]  # Shape: (N,)
        image_id = target["image_id"]
        ```
    """
    image_id = self.image_ids[idx]
    img_info = self.dataset.get_image(image_id)

    if img_info is None:
        raise DatasetError(f"Image {image_id} not found in dataset")

    if img_info.path is None or not img_info.path.exists():
        raise DatasetNotFoundError(str(img_info.path), f"Image file not found: {img_info.path}")

    # Load image
    img: torch.Tensor | Image.Image
    img = Image.open(img_info.path).convert("RGB")

    # Get annotations
    annotations = self.dataset.get_annotations(image_id)

    # Prepare target dict
    target = self._prepare_target(annotations, img_info.image_id)

    # Apply transforms
    if self.transform is not None:
        img = self.transform(img)

    if self.target_transform is not None:
        target = self.target_transform(target)

    return img, target
collate
collate(batch: list[ImageTargetPair]) -> tuple[list[Tensor], list[Target]]

Custom collate function for DataLoader.

This collate function is useful when images have different numbers of objects, which is common in object detection. Instead of stacking tensors (which requires all samples to have the same shape), it returns lists of images and targets.

Parameters:

Name Type Description Default
batch list[ImageTargetPair]

List of (image, target) tuples from __getitem__.

required

Returns:

Type Description
tuple[list[Tensor], list[Target]]

Tuple of (images, targets) where:

  • images: List of image tensors
  • targets: List of target dictionaries
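The list-based approach can be illustrated without PyTorch. In the sketch below, plain Python objects stand in for images and targets, and `collate_as_lists` is a hypothetical free function that mirrors what the `collate` source on this page does; it is not part of BoxLab.

```python
from typing import Any, Dict, List, Tuple

Sample = Tuple[Any, Dict[str, Any]]


def collate_as_lists(batch: List[Sample]) -> Tuple[List[Any], List[Dict[str, Any]]]:
    """Split a batch of (image, target) pairs into two parallel lists."""
    images = [image for image, _ in batch]
    targets = [target for _, target in batch]
    return images, targets


# Two samples with different numbers of boxes (2 and 1), which could not be
# stacked into a single rectangular array without padding.
batch = [
    ("image_0", {"boxes": [[0, 0, 10, 10], [5, 5, 20, 20]], "labels": [1, 2]}),
    ("image_1", {"boxes": [[3, 3, 8, 8]], "labels": [1]}),
]

images, targets = collate_as_lists(batch)
# Per-sample box counts are preserved because nothing is stacked.
box_counts = [len(target["boxes"]) for target in targets]
```

Keeping each target separate is also the input shape that torchvision detection models accept: a list of images and a list of target dictionaries.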

Example
from torch.utils.data import DataLoader

loader = DataLoader(
    torch_dataset,
    batch_size=4,
    collate_fn=torch_dataset.collate,
    shuffle=True,
)

for images, targets in loader:
    # images: list of 4 tensors
    # targets: list of 4 target dicts
    for img, tgt in zip(images, targets):
        print(f"Image: {img.shape}")
        print(f"Objects: {len(tgt['boxes'])}")
Source code in boxlab/dataset/torchadapter.py
def collate(self, batch: list[ImageTargetPair]) -> tuple[list[torch.Tensor], list[Target]]:  # type: ignore[valid-type]
    """Custom collate function for DataLoader.

    This collate function is useful when images have different numbers of objects,
    which is common in object detection. Instead of stacking tensors (which requires
    same dimensions), it returns lists of tensors and targets.

    Args:
        batch: List of (image, target) tuples from __getitem__.

    Returns:
        Tuple of (images, targets) where:
            - images: List of image tensors
            - targets: List of target dictionaries

    Example:
        ```python
        from torch.utils.data import DataLoader

        loader = DataLoader(
            torch_dataset,
            batch_size=4,
            collate_fn=torch_dataset.collate,
            shuffle=True,
        )

        for images, targets in loader:
            # images: list of 4 tensors
            # targets: list of 4 target dicts
            for img, tgt in zip(images, targets):
                print(f"Image: {img.shape}")
                print(f"Objects: {len(tgt['boxes'])}")
        ```
    """
    images = [item[0] for item in batch]
    targets = [item[1] for item in batch]
    return images, targets

Functions

build_torchdataset

build_torchdataset(dataset: Dataset, image_size: int | tuple[int, int] | None = None, augment: bool = False, normalize: bool = False, *transforms: Callable[..., Any], return_format: Literal['xyxy', 'xywh', 'cxcywh'] = 'xyxy') -> TorchDatasetAdapter

Create a PyTorch-compatible dataset with standard transforms.

This convenience function builds a TorchDatasetAdapter with commonly used transforms for object detection, including resizing, augmentation, and normalization.

Parameters:

Name Type Description Default
dataset Dataset

Source BoxLab Dataset instance.

required
image_size int | tuple[int, int] | None

Target image size. Can be:

  • int: Square resize (size, size)
  • tuple: (height, width)
  • None: No resizing

None
augment bool

Whether to apply data augmentation. Includes:

  • Random horizontal flip (p=0.5)
  • Color jitter (brightness, contrast, saturation, hue)
  • Random affine (rotation, translation, scale)

False
normalize bool

Whether to normalize images using ImageNet statistics:

  • mean=[0.485, 0.456, 0.406]
  • std=[0.229, 0.224, 0.225]

False
*transforms Callable[..., Any]

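Additional user-defined transforms to append. These are positional varargs, so image_size, augment, and normalize must also be passed positionally whenever extra transforms are supplied.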

()
return_format Literal['xyxy', 'xywh', 'cxcywh']

Bounding box format ("xyxy", "xywh", or "cxcywh").

'xyxy'

Returns:

Type Description
TorchDatasetAdapter

TorchDatasetAdapter instance with configured transforms.

Note

This function requires torch, torchvision, and pillow to be installed. Install with: pip install torch torchvision pillow

Raises:

Type Description
RequiredModuleNotFoundError

If required packages are not installed.

Example
from boxlab.dataset import Dataset
from boxlab.dataset.torchadapter import build_torchdataset
from torch.utils.data import DataLoader

# Create dataset
dataset = Dataset(name="my_dataset")
# ... populate dataset ...

# Build training dataset with augmentation
train_dataset = build_torchdataset(
    dataset,
    image_size=640,
    augment=True,
    normalize=True,
    return_format="xyxy",
)

# Build validation dataset without augmentation
val_dataset = build_torchdataset(
    dataset,
    image_size=640,
    augment=False,
    normalize=True,
    return_format="xyxy",
)

# Create DataLoaders
train_loader = DataLoader(
    train_dataset,
    batch_size=16,
    shuffle=True,
    collate_fn=train_dataset.collate,
)

val_loader = DataLoader(
    val_dataset,
    batch_size=16,
    shuffle=False,
    collate_fn=val_dataset.collate,
)
Example
# Add custom transforms
from torchvision import transforms as T

custom_transform = T.GaussianBlur(kernel_size=3)

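# Extra transforms are positional varargs, so the preceding
# parameters must be passed positionally as well
torch_dataset = build_torchdataset(
    dataset,
    640,  # image_size
    True,  # augment
    True,  # normalize
    custom_transform,  # Added after normalization
    return_format="cxcywh",
)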
Example
# Different image sizes
torch_dataset = build_torchdataset(
    dataset,
    image_size=(800, 600),  # height x width
    augment=False,
    normalize=False,
)
Source code in boxlab/dataset/torchadapter.py
def build_torchdataset(
    dataset: Dataset,
    image_size: int | tuple[int, int] | None = None,
    augment: bool = False,
    normalize: bool = False,
    *transforms: t.Callable[..., t.Any],
    return_format: t.Literal["xyxy", "xywh", "cxcywh"] = "xyxy",
) -> TorchDatasetAdapter:
    """Create a PyTorch-compatible dataset with standard transforms.

    This convenience function builds a TorchDatasetAdapter with commonly used
    transforms for object detection, including resizing, augmentation, and
    normalization.

    Args:
        dataset: Source BoxLab Dataset instance.
        image_size: Target image size. Can be:
            - int: Square resize (size, size)
            - tuple: (height, width)
            - None: No resizing
        augment: Whether to apply data augmentation. Includes:
            - Random horizontal flip (p=0.5)
            - Color jitter (brightness, contrast, saturation, hue)
            - Random affine (rotation, translation, scale)
        normalize: Whether to normalize images using ImageNet statistics:
            - mean=[0.485, 0.456, 0.406]
            - std=[0.229, 0.224, 0.225]
        *transforms: Additional user-defined transforms to append.
        return_format: Bounding box format ("xyxy", "xywh", or "cxcywh").

    Returns:
        TorchDatasetAdapter instance with configured transforms.

    Note:
        This function requires torch, torchvision, and pillow to be installed.
        Install with: `pip install torch torchvision pillow`

    Raises:
        RequiredModuleNotFoundError: If required packages are not installed.

    Example:
        ```python
        from boxlab.dataset import Dataset
        from boxlab.dataset.torchadapter import build_torchdataset
        from torch.utils.data import DataLoader

        # Create dataset
        dataset = Dataset(name="my_dataset")
        # ... populate dataset ...

        # Build training dataset with augmentation
        train_dataset = build_torchdataset(
            dataset,
            image_size=640,
            augment=True,
            normalize=True,
            return_format="xyxy",
        )

        # Build validation dataset without augmentation
        val_dataset = build_torchdataset(
            dataset,
            image_size=640,
            augment=False,
            normalize=True,
            return_format="xyxy",
        )

        # Create DataLoaders
        train_loader = DataLoader(
            train_dataset,
            batch_size=16,
            shuffle=True,
            collate_fn=train_dataset.collate,
        )

        val_loader = DataLoader(
            val_dataset,
            batch_size=16,
            shuffle=False,
            collate_fn=val_dataset.collate,
        )
        ```

    Example:
        ```python
        # Add custom transforms
        from torchvision import transforms as T

        custom_transform = T.GaussianBlur(kernel_size=3)

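        # Extra transforms are positional varargs, so the preceding
        # parameters must be passed positionally as well
        torch_dataset = build_torchdataset(
            dataset,
            640,  # image_size
            True,  # augment
            True,  # normalize
            custom_transform,  # Added after normalization
            return_format="cxcywh",
        )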
        ```

    Example:
        ```python
        # Different image sizes
        torch_dataset = build_torchdataset(
            dataset,
            image_size=(800, 600),  # height x width
            augment=False,
            normalize=False,
        )
        ```
    """
    logger.info(
        f"Building PyTorch dataset: size={image_size}, augment={augment}, "
        f"normalize={normalize}, format={return_format}"
    )

    transform_list: list[t.Callable[..., t.Any]] = []

    # Resize
    if image_size is not None:
        if isinstance(image_size, int):
            image_size = (image_size, image_size)
        transform_list.append(T.Resize(image_size))

    # Augmentation
    if augment:
        transform_list.extend([
            T.RandomHorizontalFlip(p=0.5),
            T.ColorJitter(brightness=0.2, contrast=0.2, saturation=0.2, hue=0.1),
            T.RandomAffine(degrees=10, translate=(0.1, 0.1), scale=(0.9, 1.1)),
        ])

    # Convert to tensor (must be before normalization)
    transform_list.append(T.ToTensor())

    # Normalization (must be after ToTensor)
    if normalize:
        transform_list.append(T.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]))

    # Additional user-defined transforms
    transform_list.extend(transforms)

    # Compose all transforms
    transform = T.Compose(transform_list)

    return TorchDatasetAdapter(
        dataset=dataset,
        transform=transform,
        return_format=return_format,
    )


Overview

The PyTorch adapter module connects BoxLab datasets to PyTorch. It wraps a BoxLab Dataset behind the PyTorch Dataset interface so that it can be used directly with DataLoader, torchvision transforms, and torchvision detection models such as Faster R-CNN and Mask R-CNN.

Installation Requirements

This module requires additional dependencies:

pip install torch torchvision pillow

If any of these packages is missing, a RequiredModuleNotFoundError is raised with installation instructions.

Key Components

TorchDatasetAdapter

Wraps a BoxLab Dataset instance to provide the PyTorch Dataset interface. It handles:

  • Image loading and format conversion
  • Annotation format conversion
  • Transform pipeline application
  • Batch collation for images with varying numbers of objects

Target Dictionary

The adapter returns targets in torchvision's standard object detection format:

{
    'boxes': Tensor,  # Shape: (N, 4) - bounding boxes
    'labels': Tensor,  # Shape: (N,) - class labels
    'image_id': Tensor,  # Shape: (1,) - image identifier
    'area': Tensor,  # Shape: (N,) - box areas
    'iscrowd': Tensor  # Shape: (N,) - crowd flags
}

Bounding Box Formats

Three formats are supported:

  • xyxy: [x_min, y_min, x_max, y_max] - Top-left and bottom-right corners
  • xywh: [x_min, y_min, width, height] - COCO format
  • cxcywh: [center_x, center_y, width, height] - YOLO format
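All three layouts describe the same rectangle, so converting between them is simple arithmetic. The sketch below is illustrative only: the helper names `xywh_to_xyxy`, `xyxy_to_cxcywh`, and `cxcywh_to_xyxy` are hypothetical and not part of BoxLab, and plain tuples stand in for tensors.

```python
from typing import Tuple

Box = Tuple[float, float, float, float]


def xywh_to_xyxy(box: Box) -> Box:
    """[x_min, y_min, width, height] -> [x_min, y_min, x_max, y_max]."""
    x_min, y_min, width, height = box
    return (x_min, y_min, x_min + width, y_min + height)


def xyxy_to_cxcywh(box: Box) -> Box:
    """[x_min, y_min, x_max, y_max] -> [center_x, center_y, width, height]."""
    x_min, y_min, x_max, y_max = box
    width, height = x_max - x_min, y_max - y_min
    return (x_min + width / 2, y_min + height / 2, width, height)


def cxcywh_to_xyxy(box: Box) -> Box:
    """[center_x, center_y, width, height] -> [x_min, y_min, x_max, y_max]."""
    cx, cy, width, height = box
    return (cx - width / 2, cy - height / 2, cx + width / 2, cy + height / 2)


coco_box = (10.0, 20.0, 30.0, 40.0)   # xywh
corners = xywh_to_xyxy(coco_box)      # (10.0, 20.0, 40.0, 60.0)
centered = xyxy_to_cxcywh(corners)    # (25.0, 40.0, 30.0, 40.0)
round_trip = cxcywh_to_xyxy(centered)  # back to the xyxy corners
```

In practice the adapter performs the conversion itself according to `return_format`; the sketch only shows how the three representations relate.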

Common Usage Patterns

Basic Training Setup

from boxlab.dataset import Dataset
from boxlab.dataset.torchadapter import build_torchdataset
from torch.utils.data import DataLoader

# Load dataset
dataset = Dataset(name="my_dataset")
# ... populate dataset ...

# Create training dataset with augmentation
train_ds = build_torchdataset(
    dataset,
    image_size=640,
    augment=True,
    normalize=True
)

# Create DataLoader
train_loader = DataLoader(
    train_ds,
    batch_size=16,
    shuffle=True,
    num_workers=4,
    collate_fn=train_ds.collate
)

# Training loop
for images, targets in train_loader:
    # images: list of tensors
    # targets: list of dicts
    ...

Train/Val Split

from boxlab.dataset import Dataset
from boxlab.dataset.types import SplitRatio
from boxlab.dataset.torchadapter import build_torchdataset

# Split dataset
dataset = Dataset(name="full_dataset")
splits = dataset.split(SplitRatio(train=0.8, val=0.2, test=0.0), seed=42)

# Create separate Dataset instances
train_dataset = Dataset(name="train")
val_dataset = Dataset(name="val")

# Populate split datasets
for img_id in splits['train']:
    img_info = dataset.get_image(img_id)
    train_dataset.add_image(img_info)
    for ann in dataset.get_annotations(img_id):
        train_dataset.add_annotation(ann)

for img_id in splits['val']:
    img_info = dataset.get_image(img_id)
    val_dataset.add_image(img_info)
    for ann in dataset.get_annotations(img_id):
        val_dataset.add_annotation(ann)

# Create PyTorch datasets
train_torch = build_torchdataset(train_dataset, image_size=640, augment=True, normalize=True)
val_torch = build_torchdataset(val_dataset, image_size=640, augment=False, normalize=True)

Custom Transforms

from torchvision import transforms as T
from boxlab.dataset.torchadapter import TorchDatasetAdapter

# Define custom transform pipeline
transform = T.Compose(
    [
        T.Resize((640, 640)),
        T.RandomRotation(10),
        T.ToTensor(),
        T.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
    ]
)

# Create adapter with custom transforms
adapter = TorchDatasetAdapter(
    dataset,
    transform=transform,
    return_format="xyxy"
)
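
In the `__getitem__` source shown on this page, `transform` receives only the image and `target_transform` receives only the target. Geometric image transforms such as `T.Resize` or `T.RandomRotation` therefore change the pixels without moving the boxes. As a rough illustration of the arithmetic needed to keep boxes aligned after a plain resize, here is a sketch using a hypothetical helper `scale_xyxy` (not part of BoxLab) on plain tuples:

```python
from typing import Tuple

Box = Tuple[float, float, float, float]


def scale_xyxy(box: Box, orig_hw: Tuple[int, int], new_hw: Tuple[int, int]) -> Box:
    """Rescale an xyxy box from an image of size orig_hw to one of size new_hw.

    Sizes are (height, width), matching the tuple order used by T.Resize.
    """
    orig_h, orig_w = orig_hw
    new_h, new_w = new_hw
    scale_x, scale_y = new_w / orig_w, new_h / orig_h
    x_min, y_min, x_max, y_max = box
    return (x_min * scale_x, y_min * scale_y, x_max * scale_x, y_max * scale_y)


# An image 2560 wide and 1280 high, resized to 640x640:
# x coordinates shrink by 0.25, y coordinates by 0.5.
scaled = scale_xyxy(
    (100.0, 200.0, 300.0, 400.0),
    orig_hw=(1280, 2560),
    new_hw=(640, 640),
)
```

Because the target dictionary does not carry the original image size, a `target_transform` alone cannot apply this scaling; the original size has to come from elsewhere, for example from a wrapper around the adapter.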

Using with Detection Models

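import torch
from torch.utils.data import DataLoader
from torchvision.models.detection import fasterrcnn_resnet50_fpn
from boxlab.dataset.torchadapter import build_torchdataset

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Prepare dataset (assumes `dataset` is a populated BoxLab Dataset)
torch_dataset = build_torchdataset(
    dataset,
    image_size=800,
    augment=True,
    normalize=False,  # torchvision detection models normalize inputs internally
    return_format="xyxy"  # Faster R-CNN expects xyxy
)

loader = DataLoader(
    torch_dataset,
    batch_size=4,
    collate_fn=torch_dataset.collate
)

# Load model with pretrained weights (torchvision >= 0.13 API;
# older releases use pretrained=True instead)
model = fasterrcnn_resnet50_fpn(weights="DEFAULT")
model.to(device)
model.train()

# Training
optimizer = torch.optim.SGD(model.parameters(), lr=0.005)

for images, targets in loader:
    # Move each image and every tensor in each target dict to the device
    images = [img.to(device) for img in images]
    targets = [{k: v.to(device) for k, v in t.items()} for t in targets]

    # In train mode the model returns a dict of losses
    loss_dict = model(images, targets)
    losses = sum(loss for loss in loss_dict.values())

    optimizer.zero_grad()
    losses.backward()
    optimizer.step()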

Transform Pipeline Order

When using build_torchdataset(), transforms are applied in this order:

  1. Resize (if image_size specified)
  2. Augmentation (if augment=True):
    • Random horizontal flip
    • Color jitter
    • Random affine transformations
  3. ToTensor (always applied)
  4. Normalization (if normalize=True)
  5. Custom transforms (additional args)
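
The ordering above can be traced with a small stand-in that mirrors the branching in the `build_torchdataset` source on this page. `planned_steps` is a hypothetical helper that records step names as strings instead of constructing torchvision transforms, so it runs without PyTorch installed.

```python
from typing import List, Sequence, Tuple, Union


def planned_steps(
    image_size: Union[int, Tuple[int, int], None] = None,
    augment: bool = False,
    normalize: bool = False,
    extra: Sequence[str] = (),
) -> List[str]:
    """Return the names of the pipeline steps in the order they would run."""
    steps: List[str] = []
    if image_size is not None:
        if isinstance(image_size, int):
            image_size = (image_size, image_size)  # int means a square resize
        steps.append(f"Resize{image_size}")
    if augment:
        steps += ["RandomHorizontalFlip", "ColorJitter", "RandomAffine"]
    steps.append("ToTensor")  # always present
    if normalize:
        steps.append("Normalize")  # must follow ToTensor
    steps.extend(extra)  # user-defined transforms come last
    return steps


full = planned_steps(640, augment=True, normalize=True, extra=["GaussianBlur"])
minimal = planned_steps()
```

Because custom transforms are appended after ToTensor, they receive a tensor rather than a PIL Image.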

Error Handling

Missing Dependencies

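# RequiredModuleNotFoundError is assumed to be exported from boxlab.exceptions,
# alongside DatasetNotFoundError
from boxlab.exceptions import RequiredModuleNotFoundError

try:
    from boxlab.dataset.torchadapter import build_torchdataset
except RequiredModuleNotFoundError as e:
    print(f"Missing dependency: {e}")
    print("Install with: pip install torch torchvision pillow")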

Missing Images

from boxlab.exceptions import DatasetNotFoundError

try:
    image, target = torch_dataset[0]
except DatasetNotFoundError as e:
    print(f"Image file not found: {e}")
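
`__getitem__` raises only when a missing file is actually reached, which may be well into an epoch. One option is to check paths up front. The sketch below is a minimal illustration that assumes a mapping from image ID to an optional `pathlib.Path`, mirroring the `img_info.path is None or not img_info.path.exists()` check in the source on this page; `find_missing` is a hypothetical helper, not a BoxLab API.

```python
import tempfile
from pathlib import Path
from typing import Dict, List, Optional


def find_missing(paths: Dict[str, Optional[Path]]) -> List[str]:
    """Return the IDs whose path is unset or does not exist on disk."""
    return [
        image_id
        for image_id, path in paths.items()
        if path is None or not path.exists()
    ]


# Demonstration with a temporary directory: "a" exists, "b" does not,
# and "c" has no path at all.
with tempfile.TemporaryDirectory() as tmp:
    present = Path(tmp) / "a.jpg"
    present.write_bytes(b"")  # empty placeholder file
    paths = {"a": present, "b": Path(tmp) / "b.jpg", "c": None}
    missing = find_missing(paths)
```

With a real dataset, the mapping could be built from `dataset.images` and each entry's `path` attribute, assuming `dataset.images` maps image IDs to the same objects that `get_image` returns.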

See Also