torch

Attributes

__doctitle__ module-attribute

__doctitle__ = 'TorchScript Runtime'

logger module-attribute

logger = getLogger('inferflow.runtime.torch')

__all__ module-attribute

__all__ = ['TorchRuntimeMixin', 'TorchScriptRuntime']

Classes

TorchRuntimeMixin

Shared TorchScript runtime logic for sync and async implementations.

This mixin provides common TorchScript-specific logic that is shared between synchronous and asynchronous runtime implementations. It handles:

- Device setup (CUDA, CPU, MPS)
- Precision conversion (FP32, FP16)
- Input preparation and validation
- Batch dimension management
- Output post-processing

This mixin is pure logic with no I/O operations, making it safe to reuse across sync and async implementations.

Attributes:

Name Type Description
device Any

Device configuration (provided by subclass).

precision Precision

Precision configuration (provided by subclass).

Example
# In sync runtime
class TorchScriptRuntime(
    TorchRuntimeMixin, RuntimeConfigMixin, BatchableRuntime
):
    def load(self):
        self._torch_device = (
            self._setup_torch_device()
        )  # Use mixin
        # ...


# In async runtime
class TorchScriptRuntime(
    TorchRuntimeMixin, RuntimeConfigMixin, BatchableRuntime
):
    async def load(self):
        self._torch_device = (
            self._setup_torch_device()
        )  # Same mixin!
        # ...
Attributes
device instance-attribute
device: Any
precision instance-attribute
precision: Precision
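The "batch dimension management" listed above can be sketched in plain Python. This is an inferred sketch, not the mixin's actual implementation: tuples stand in for tensor shapes, and the helper names mirror `_add_batch_dim_if_needed` / `_remove_batch_dim_if_added` as used by `infer()` below (the real code would call `tensor.unsqueeze(0)` / `tensor.squeeze(0)`).

```python
# Inferred sketch of the mixin's batch-dimension logic. Plain tuples
# stand in for tensor shapes; the real code operates on torch.Tensor.

def add_batch_dim_if_needed(shape, auto_add):
    """Prepend a batch dimension to a 3D (C, H, W) shape when enabled."""
    if auto_add and len(shape) == 3:
        return (1, *shape), True  # (C, H, W) -> (1, C, H, W)
    return shape, False


def remove_batch_dim_if_added(shape, added):
    """Drop the leading batch dimension only if we added it."""
    if added and shape[0] == 1:
        return shape[1:]
    return shape


shape, added = add_batch_dim_if_needed((3, 224, 224), auto_add=True)
print(shape)  # (1, 3, 224, 224)
print(remove_batch_dim_if_added(shape, added))  # (3, 224, 224)
```

Tracking the `added` flag is what makes the round-trip safe: a 4D input is passed through untouched in both directions.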

TorchScriptRuntime

TorchScriptRuntime(model_path: str | PathLike[str], device: str | Device, precision: Precision = FP32, warmup_iterations: int = 3, warmup_shape: tuple[int, ...] = (1, 3, 224, 224), auto_add_batch_dim: bool = False)

Bases: RuntimeConfigMixin, TorchRuntimeMixin, BatchableRuntime[Tensor, R]

TorchScript model runtime (sync version).

Supports
  • TorchScript (.pt, .pth) models
  • CUDA, CPU, MPS devices
  • FP32, FP16 precision
  • Batch inference
  • Automatic warmup
  • Optional automatic batch dimension handling

Attributes:

Name Type Description
model ScriptModule | None

Loaded TorchScript model (None before load()).

auto_add_batch_dim bool

Whether to auto-add batch dimension for 3D inputs.

Parameters:

Name Type Description Default
model_path str | PathLike[str]

Path to TorchScript model file.

required
device str | Device

Device to run inference on (e.g. "cpu" or "cuda:0").

required
precision Precision

Model precision (default: FP32).

FP32
warmup_iterations int

Number of warmup iterations (default: 3).

3
warmup_shape tuple[int, ...]

Input shape for warmup (default: (1, 3, 224, 224)).

(1, 3, 224, 224)
auto_add_batch_dim bool

Whether to automatically add batch dimension if input is 3D (default: False).

False

Raises:

Type Description
FileNotFoundError

If model file does not exist.

RuntimeError

If CUDA/MPS is requested but not available.

ImportError

If torch is not installed.

Example
import inferflow as iff
import torch

# Initialize runtime
runtime = iff.TorchScriptRuntime(
    model_path="model.pt",
    device="cuda:0",
    precision=iff.Precision.FP16,
    auto_add_batch_dim=True,
)

# Single inference
with runtime:
    input_tensor = torch.randn(3, 224, 224)  # 3D input
    output = runtime.infer(
        input_tensor
    )  # Batch dim auto-added

# Batch inference
with runtime:
    batch = [
        torch.randn(1, 3, 224, 224),
        torch.randn(1, 3, 224, 224),
    ]
    outputs = runtime.infer_batch(batch)
Source code in inferflow/runtime/torch.py
def __init__(
    self,
    model_path: str | os.PathLike[str],
    device: str | Device,
    precision: Precision = Precision.FP32,
    warmup_iterations: int = 3,
    warmup_shape: tuple[int, ...] = (1, 3, 224, 224),
    auto_add_batch_dim: bool = False,
):
    super().__init__(
        model_path=model_path,
        device=device,
        precision=precision,
        warmup_iterations=warmup_iterations,
        warmup_shape=warmup_shape,
    )

    self.auto_add_batch_dim = auto_add_batch_dim

    self.model: torch.jit.ScriptModule | None = None
    self._torch_device: torch.device | None = None

    logger.info(
        f"TorchScriptRuntime initialized: "
        f"model={self.model_path}, device={self.device}, "
        f"precision={self.precision.value}"
    )
Attributes
auto_add_batch_dim instance-attribute
auto_add_batch_dim = auto_add_batch_dim
model instance-attribute
model: ScriptModule | None = None
Functions
load
load() -> None

Load TorchScript model and prepare for inference.

Performs:

- Load model from disk
- Setup device
- Move model to device
- Set evaluation mode
- Apply precision
- Warmup inference

Raises:

Type Description
FileNotFoundError

If model file does not exist.

RuntimeError

If device is not available.

Source code in inferflow/runtime/torch.py
def load(self) -> None:
    """Load TorchScript model and prepare for inference.

    Performs:
    - Load model from disk
    - Setup device
    - Move model to device
    - Set evaluation mode
    - Apply precision
    - Warmup inference

    Raises:
        FileNotFoundError: If model file does not exist.
        RuntimeError: If device is not available.
    """
    logger.info(f"Loading model from {self.model_path}")

    # Load model
    self.model = torch.jit.load(str(self.model_path))

    # Setup device (reuse mixin)
    self._torch_device = self._setup_torch_device()

    # Configure model
    self.model.to(self._torch_device)
    self.model.eval()
    self.model = self._apply_precision_to_model(self.model)

    logger.info(f"Model loaded on {self._torch_device}")

    # Warmup
    self._warmup()
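The examples in this page use `with runtime:` rather than calling `load()` and `unload()` directly, which suggests the base runtime maps the context-manager protocol onto these two methods. The actual `BatchableRuntime` implementation is not shown here; the sketch below is an assumed minimal wiring of that pattern, with a stub class standing in for the real runtime.

```python
class RuntimeLifecycleSketch:
    """Assumed context-manager wiring: enter loads, exit unloads."""

    def __init__(self):
        self.loaded = False

    def load(self):
        self.loaded = True

    def unload(self):
        self.loaded = False

    def __enter__(self):
        self.load()  # `with runtime:` triggers load()
        return self

    def __exit__(self, exc_type, exc, tb):
        self.unload()  # unload even if inference raised
        return False  # do not swallow exceptions


rt = RuntimeLifecycleSketch()
with rt:
    assert rt.loaded  # model available inside the block
assert not rt.loaded  # resources released on exit
```

Returning `False` from `__exit__` lets inference errors propagate while still guaranteeing `unload()` runs.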
infer
infer(input: Tensor) -> R

Run inference on a single input.

Automatically handles:

- Moving input to correct device
- Converting to correct precision
- Adding batch dimension (if configured)
- Removing batch dimension (if added)

Parameters:

Name Type Description Default
input Tensor

Input tensor. Can be 3D (C, H, W) if auto_add_batch_dim=True, or 4D (1, C, H, W) otherwise.

required

Returns:

Type Description
R

Model output. Type depends on model architecture (tensor or tuple).

Raises:

Type Description
RuntimeError

If model is not loaded.

Example
with runtime:
    # 3D input (auto_add_batch_dim=True)
    input = torch.randn(3, 224, 224)
    output = runtime.infer(input)

    # 4D input
    input = torch.randn(1, 3, 224, 224)
    output = runtime.infer(input)
Source code in inferflow/runtime/torch.py
def infer(self, input: torch.Tensor) -> R:
    """Run inference on a single input.

    Automatically handles:
    - Moving input to correct device
    - Converting to correct precision
    - Adding batch dimension (if configured)
    - Removing batch dimension (if added)

    Args:
        input: Input tensor. Can be 3D (C, H, W) if auto_add_batch_dim=True,
            or 4D (1, C, H, W) otherwise.

    Returns:
        Model output. Type depends on model architecture (tensor or tuple).

    Raises:
        RuntimeError: If model is not loaded.

    Example:
        ```python
        with runtime:
            # 3D input (auto_add_batch_dim=True)
            input = torch.randn(3, 224, 224)
            output = runtime.infer(input)

            # 4D input
            input = torch.randn(1, 3, 224, 224)
            output = runtime.infer(input)
        ```
    """
    if self.model is None:
        raise RuntimeError("Model not loaded. Call load() first.")

    # Prepare input (reuse mixin)
    input = self._prepare_input(input, self._torch_device)
    input, added_batch = self._add_batch_dim_if_needed(input, self.auto_add_batch_dim)

    # Inference
    with torch.no_grad():
        output = self.model(input)

    # Post-process (reuse mixin)
    return self._remove_batch_dim_if_added(output, added_batch)
infer_batch
infer_batch(inputs: list[Tensor]) -> list[R]

Run inference on a batch of inputs.

Concatenates inputs into a single batch tensor for efficient processing, then splits the output back into individual results.

Parameters:

Name Type Description Default
inputs list[Tensor]

List of input tensors. Each should have shape (1, C, H, W).

required

Returns:

Type Description
list[R]

List of outputs, one per input. Each maintains batch dimension.

Raises:

Type Description
RuntimeError

If model is not loaded.

Example
with runtime:
    batch = [
        torch.randn(1, 3, 224, 224),
        torch.randn(1, 3, 224, 224),
        torch.randn(1, 3, 224, 224),
    ]

    # Efficient batch processing
    outputs = runtime.infer_batch(batch)

    # outputs[0], outputs[1], outputs[2]
Source code in inferflow/runtime/torch.py
def infer_batch(self, inputs: list[torch.Tensor]) -> list[R]:
    """Run inference on a batch of inputs.

    Concatenates inputs into a single batch tensor for efficient
    processing, then splits the output back into individual results.

    Args:
        inputs: List of input tensors. Each should have shape (1, C, H, W).

    Returns:
        List of outputs, one per input. Each maintains batch dimension.

    Raises:
        RuntimeError: If model is not loaded.

    Example:
        ```python
        with runtime:
            batch = [
                torch.randn(1, 3, 224, 224),
                torch.randn(1, 3, 224, 224),
                torch.randn(1, 3, 224, 224),
            ]

            # Efficient batch processing
            outputs = runtime.infer_batch(batch)

            # outputs[0], outputs[1], outputs[2]
        ```
    """
    if self.model is None:
        raise RuntimeError("Model not loaded. Call load() first.")

    # Concatenate
    batch = torch.cat(inputs, dim=0).to(self._torch_device)

    if self.precision == Precision.FP16:
        batch = batch.half()

    # Inference
    with torch.no_grad():
        batch_output = self.model(batch)

    # Split (reuse mixin)
    return self._split_batch_output(batch_output, len(inputs))
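`_split_batch_output` itself is not shown on this page. Its contract, per the docstring above, is that each returned output "maintains batch dimension", which a plain slicing loop satisfies. The following is an inferred sketch of that behavior, with nested lists standing in for tensors (the real code would slice a `torch.Tensor` the same way):

```python
def split_batch_output(batch_output, num_inputs):
    """Split a batched result into one slice per input.

    Slicing with [i : i + 1] (rather than indexing with [i]) keeps a
    leading batch dimension of 1 on each result, mirroring the
    documented contract. Inferred sketch; lists stand in for tensors.
    """
    return [batch_output[i : i + 1] for i in range(num_inputs)]


batch = [["img0"], ["img1"], ["img2"]]  # stand-in for a (3, ...) tensor
outputs = split_batch_output(batch, 3)
print(outputs[0])  # [['img0']]
```

This round-trips cleanly with `torch.cat(inputs, dim=0)` in `infer_batch`: concatenate N inputs of shape (1, C, H, W) into (N, C, H, W), run once, then slice back into N outputs each with batch size 1.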
unload
unload() -> None

Unload model and free resources.

Performs:

- Release model from memory
- Clear CUDA cache (if using CUDA)

Safe to call multiple times.

Source code in inferflow/runtime/torch.py
def unload(self) -> None:
    """Unload model and free resources.

    Performs:
    - Release model from memory
    - Clear CUDA cache (if using CUDA)

    Safe to call multiple times.
    """
    logger.info("Unloading model")
    self.model = None

    if self.device.type.value == "cuda":
        torch.cuda.empty_cache()

    logger.info("Model unloaded")