Distributed Denoiser with Image Tiling

In many imaging problems, the images are so large that the denoiser cannot process them within the memory of a single device. Medical and satellite imaging, for instance, can involve gigapixel images that cannot be denoised in one pass.

The distributed framework enables you to parallelize the denoising of large images across multiple devices using image tiling. Each device processes different image patches independently, and the results are merged to produce the final denoised image.
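To make "each device processes different image patches" concrete, here is a minimal sketch of one possible way to split patch indices between processes. The round-robin rule and the helper name `assign_patches` are assumptions made for illustration; the scheduling that deepinv actually uses may differ.

```python
def assign_patches(num_patches: int, world_size: int) -> dict[int, list[int]]:
    """Map each rank to the patch indices it would process (round-robin).

    Hypothetical helper for illustration only, not part of deepinv.
    """
    if world_size < 1:
        raise ValueError("world_size must be at least 1")
    return {
        rank: list(range(rank, num_patches, world_size))
        for rank in range(world_size)
    }


# Six patches shared between two processes
print(assign_patches(6, 2))  # {0: [0, 2, 4], 1: [1, 3, 5]}
```

With a single process, every patch lands on rank 0, so the same script also covers the non-distributed case.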

This example demonstrates how to use the deepinv.distributed.distribute() function to create a distributed denoiser that automatically handles patch extraction, processing, and merging.
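The extract, process, and merge cycle can be illustrated on a 1D signal in plain Python. This is a sketch under stated assumptions: tiles advance by `patch - overlap`, the last tile is shifted back to end at the border, and overlapping outputs are blended by simple averaging. The helper names are hypothetical and the blending used inside deepinv may be different.

```python
def tile_starts(length, patch, overlap):
    """Start indices of overlapping tiles that cover [0, length)."""
    if not 0 <= overlap < patch:
        raise ValueError("overlap must satisfy 0 <= overlap < patch")
    if length <= patch:
        return [0]
    stride = patch - overlap
    starts = list(range(0, length - patch, stride))
    starts.append(length - patch)  # shift the final tile back to end at the border
    return starts


def process_tiled(signal, fn, patch, overlap):
    """Apply fn to each tile and average the outputs where tiles overlap."""
    total = [0.0] * len(signal)
    count = [0] * len(signal)
    for start in tile_starts(len(signal), patch, overlap):
        out = fn(signal[start:start + patch])
        for offset, value in enumerate(out):
            total[start + offset] += value
            count[start + offset] += 1
    return [t / c for t, c in zip(total, count)]


signal = [float(i) for i in range(10)]
doubled = process_tiled(signal, lambda tile: [2 * v for v in tile], patch=4, overlap=2)
print(doubled == [2 * v for v in signal])  # True
```

For a pointwise operation such as doubling, tiling changes nothing. A denoiser has a spatial receptive field, so its output near a tile edge depends on context that the tile does not contain, which is why tiled and full-image results can differ slightly near patch boundaries.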

Usage:

# Single process
python examples/distributed/demo_denoiser_distributed.py
# Multi-process launch via torch.distributed.run, the module behind torchrun (2 GPUs/processes)
python -m torch.distributed.run --nproc_per_node=2 examples/distributed/demo_denoiser_distributed.py
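
The two launch modes differ in the environment the script sees: `torchrun` exports `RANK`, `WORLD_SIZE`, and `LOCAL_RANK` to every worker, while a plain `python` launch sets none of them. The sketch below shows how a script could detect either case. The helper `read_launch_env` is hypothetical, and it is only an assumption that `DistributedContext` does something along these lines internally.

```python
import os


def read_launch_env(env=None):
    """Return (rank, world_size, local_rank), defaulting to a single process.

    Hypothetical helper for illustration only, not part of deepinv.
    """
    env = os.environ if env is None else env
    rank = int(env.get("RANK", 0))
    world_size = int(env.get("WORLD_SIZE", 1))
    local_rank = int(env.get("LOCAL_RANK", 0))
    return rank, world_size, local_rank


# Plain `python` launch: no launcher variables are set
print(read_launch_env({}))  # (0, 1, 0)

# Second worker of a two-process torchrun launch
print(read_launch_env({"RANK": "1", "WORLD_SIZE": "2", "LOCAL_RANK": "1"}))  # (1, 2, 1)
```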

Key Features:

  • Distribute denoising across processes/devices using image tiling

  • Automatic patch extraction and reassembly

  • Memory-efficient processing of large images

Key Steps:

  1. Load a large test image

  2. Add noise to create a noisy observation

  3. Initialize distributed context

  4. Configure tiling parameters

  5. Distribute denoiser with deepinv.distributed.distribute()

  6. Apply distributed denoising

  7. Visualize results and compute metrics

Import modules and define noisy image generation

We start by importing torch and the modules of deepinv that we use in this example. We also define a function that generates noisy images to evaluate the distributed framework.

import torch
from deepinv.models import DRUNet
from deepinv.utils.demo import load_example
from deepinv.utils.plotting import plot
from deepinv.loss.metric import PSNR

# Import distributed framework
from deepinv.distributed import DistributedContext, distribute


def create_noisy_image(device, img_size=1024, noise_sigma=0.1, seed=42):
    """
    Create a noisy test image.

    :param device: Device to create image on
    :param int img_size: Target image size forwarded to load_example (the demo passes a single int)
    :param float noise_sigma: Standard deviation of Gaussian noise
    :param int seed: Random seed for reproducible noise
    :returns: Tuple of (clean_image, noisy_image, noise_sigma)
    """
    # Load the example image and resize it according to img_size
    clean_image = load_example(
        "CBSD_0010.png",
        grayscale=False,
        device=device,
        img_size=img_size,
        resize_mode="resize",
    )

    # Set seed for reproducible noise
    torch.manual_seed(seed)

    # Add Gaussian noise
    noise = torch.randn_like(clean_image) * noise_sigma
    noisy_image = clean_image + noise

    # Clip to valid range
    noisy_image = torch.clamp(noisy_image, 0, 1)

    return clean_image, noisy_image, noise_sigma
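
As a sanity check on the noise level, the expected input PSNR can be estimated before running anything. The sketch below assumes pixel values in [0, 1] and a PSNR computed with a peak value of 1; both are assumptions about the defaults in use, and `gaussian_noise_psnr` is a hypothetical helper. For zero-mean noise of standard deviation sigma, the mean squared error is sigma squared.

```python
import math


def gaussian_noise_psnr(sigma: float, max_pixel: float = 1.0) -> float:
    """PSNR in dB when the error is zero-mean noise with std sigma (MSE = sigma**2)."""
    return 10.0 * math.log10(max_pixel**2 / sigma**2)


print(f"{gaussian_noise_psnr(0.1):.2f} dB")  # 20.00 dB
```

The demo output reports an input PSNR of 20.34 dB, slightly above this estimate. That is consistent with the clamping step: if the clean pixels lie in [0, 1], clamping a noisy pixel to that range can only move it closer to the clean value, so the error shrinks a little.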

Configuration of parallel denoising

img_size = 512  # Large image for demonstrating tiling
noise_sigma = 0.1
patch_size = 256  # Size of each patch
overlap = 64  # Overlap for smooth boundaries
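
To get a feel for these parameters, the sketch below counts the tiles that one generic overlapping scheme would produce for the 767 x 512 image reported in the demo output. It assumes a stride of `patch_size - overlap` with the last tile shifted back to fit inside the image. This is an illustration only: the demo labels `overlap` as a receptive field radius, so the tiling strategy in deepinv may well lay out its patches differently, and `axis_starts` is a hypothetical helper.

```python
def axis_starts(length: int, patch: int, overlap: int) -> list[int]:
    """Tile start indices along one axis (hypothetical helper, not part of deepinv)."""
    if not 0 <= overlap < patch:
        raise ValueError("overlap must satisfy 0 <= overlap < patch")
    if length <= patch:
        return [0]
    stride = patch - overlap
    starts = list(range(0, length - patch, stride))
    starts.append(length - patch)  # last tile ends exactly at the image border
    return starts


height, width = 767, 512  # spatial size reported in the demo output
patch_size, overlap = 256, 64

rows = axis_starts(height, patch_size, overlap)
cols = axis_starts(width, patch_size, overlap)
print(rows)  # [0, 192, 384, 511]
print(cols)  # [0, 192, 256]
print(len(rows) * len(cols))  # 12
```

Under this scheme a larger overlap gives smoother transitions between patches but increases the number of tiles, and therefore the total compute.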

Define distributed context and run algorithm

# Initialize distributed context (handles single and multi-process automatically)
with DistributedContext(seed=42) as ctx:

    if ctx.rank == 0:
        print("=" * 70)
        print("Distributed Denoiser Demo")
        print("=" * 70)
        print(f"\nRunning on {ctx.world_size} process(es)")
        print(f"   Device: {ctx.device}")

    # ---------------------------------------------------------------------------
    # Step 1: Create test image with noise
    # ---------------------------------------------------------------------------

    clean_image, noisy_image, sigma = create_noisy_image(
        ctx.device, img_size=img_size, noise_sigma=noise_sigma
    )

    # Compute input PSNR (create metric on all ranks for consistency)
    psnr_metric = PSNR()
    input_psnr = psnr_metric(noisy_image, clean_image).item()

    if ctx.rank == 0:
        print("\nCreated test image")
        print(f"   Image shape: {clean_image.shape}")
        print(f"   Noise sigma: {sigma}")
        print(f"   Input PSNR: {input_psnr:.2f} dB")

    # ---------------------------------------------------------------------------
    # Step 2: Load denoiser model
    # ---------------------------------------------------------------------------

    if ctx.rank == 0:
        print("\nLoading DRUNet denoiser...")

    denoiser = DRUNet(pretrained="download").to(ctx.device)

    if ctx.rank == 0:
        print("   Denoiser loaded")

    # ---------------------------------------------------------------------------
    # Step 3: Distribute denoiser with tiling configuration
    # ---------------------------------------------------------------------------

    if ctx.rank == 0:
        print("\nConfiguring distributed denoiser")
        print(f"   Patch size: {patch_size}x{patch_size}")
        print(f"   Receptive field radius: {overlap}")
        print("   Tiling strategy: overlap_tiling")

    distributed_denoiser = distribute(
        denoiser,
        ctx,
        patch_size=patch_size,
        overlap=overlap,
    )

    if ctx.rank == 0:
        print("   Distributed denoiser created")

    # ---------------------------------------------------------------------------
    # Step 4: Apply distributed denoising
    # ---------------------------------------------------------------------------

    if ctx.rank == 0:
        print("\nApplying distributed denoising...")

    with torch.no_grad():
        denoised_image = distributed_denoiser(noisy_image, sigma=sigma)

    if ctx.rank == 0:
        print("   Denoising completed")
        print(f"   Output shape: {denoised_image.shape}")

    # Compare with non-distributed result (only on rank 0)
    if ctx.rank == 0:
        print("\nComparing with non-distributed denoising...")
        with torch.no_grad():
            denoised_ref = denoiser(noisy_image, sigma=sigma)

        diff = torch.abs(denoised_image - denoised_ref)
        mean_diff = diff.mean().item()
        max_diff = diff.max().item()

        print(f"   Mean absolute difference: {mean_diff:.2e}")
        print(f"   Max absolute difference:  {max_diff:.2e}")

        # Check that the differences stay within loose tolerances.
        # The distributed version tiles the image into overlapping patches and
        # blends them, so its output can differ slightly from the full-image
        # result near patch boundaries. The thresholds below are generous sanity
        # bounds on a [0, 1] intensity scale: mean below 0.01, max below 0.5.
        tolerance_mean = 0.01
        tolerance_max = 0.5
        assert (
            mean_diff < tolerance_mean
        ), f"Mean difference too large: {mean_diff:.4f} (tolerance: {tolerance_mean})"
        assert (
            max_diff < tolerance_max
        ), f"Max difference too large: {max_diff:.4f} (tolerance: {tolerance_max})"
        print("   Results are very close (within tolerance)!")

    # ---------------------------------------------------------------------------
    # Step 5: Compute metrics and visualize results (only on rank 0)
    # ---------------------------------------------------------------------------

    if ctx.rank == 0:
        # Compute output PSNR
        output_psnr = psnr_metric(denoised_image, clean_image).item()
        psnr_improvement = output_psnr - input_psnr

        print("\nResults:")
        print(f"   Input PSNR:  {input_psnr:.2f} dB")
        print(f"   Output PSNR: {output_psnr:.2f} dB")
        print(f"   Improvement: {psnr_improvement:.2f} dB")

        # Plot results
        plot(
            [clean_image, noisy_image, denoised_image],
            titles=[
                "Clean Image",
                f"Noisy (PSNR: {input_psnr:.2f} dB)",
                f"Denoised (PSNR: {output_psnr:.2f} dB)",
            ],
            save_fn="distributed_denoiser_result.png",
            figsize=(15, 4),
        )

        # Plot zoom on a region to see details
        # Extract a 256x256 patch from center
        h, w = clean_image.shape[-2:]
        y_start, x_start = h // 2 - 128, w // 2 - 128
        y_end, x_end = y_start + 256, x_start + 256

        clean_patch = clean_image[..., y_start:y_end, x_start:x_end]
        noisy_patch = noisy_image[..., y_start:y_end, x_start:x_end]
        denoised_patch = denoised_image[..., y_start:y_end, x_start:x_end]

        plot(
            [clean_patch, noisy_patch, denoised_patch],
            titles=["Clean (zoom)", "Noisy (zoom)", "Denoised (zoom)"],
            save_fn="distributed_denoiser_zoom.png",
            figsize=(15, 4),
        )

        print("\nDemo completed successfully!")
        print("   Results saved to:")
        print("   - distributed_denoiser_result.png")
        print("   - distributed_denoiser_zoom.png")
        print("\n" + "=" * 70)

[Figure: Clean Image | Noisy (PSNR: 20.34 dB) | Denoised (PSNR: 34.01 dB)]
[Figure: Clean (zoom) | Noisy (zoom) | Denoised (zoom)]

======================================================================
Distributed Denoiser Demo
======================================================================

Running on 1 process(es)
   Device: cuda:0

Created test image
   Image shape: torch.Size([1, 3, 767, 512])
   Noise sigma: 0.1
   Input PSNR: 20.34 dB

Loading DRUNet denoiser...
   Denoiser loaded

Configuring distributed denoiser
   Patch size: 256x256
   Receptive field radius: 64
   Tiling strategy: overlap_tiling
   Distributed denoiser created

Applying distributed denoising...
/local/jtachell/deepinv/deepinv/deepinv/distributed/strategies.py:476: UserWarning: No tiling_dims provided. Assuming last 2 dimensions: (-2, -1). If your layout is different, please provide tiling_dims explicitly.
  warnings.warn(
   Denoising completed
   Output shape: torch.Size([1, 3, 767, 512])

Comparing with non-distributed denoising...
   Mean absolute difference: 6.39e-04
   Max absolute difference:  9.65e-02
   Results are very close (within tolerance)!

Results:
   Input PSNR:  20.34 dB
   Output PSNR: 34.01 dB
   Improvement: 13.67 dB

Demo completed successfully!
   Results saved to:
   - distributed_denoiser_result.png
   - distributed_denoiser_zoom.png

======================================================================

Total running time of the script: (0 minutes 10.415 seconds)
