Distributed Computing#
For large-scale inverse problems, a single device may not provide enough memory or compute. The distributed computing framework enables efficient parallel processing across multiple GPUs by splitting physics operators and computations across multiple processes.
The framework provides an API centered around two key functions:
deepinv.distributed.DistributedContext - manages distributed execution
deepinv.distributed.distribute() - converts regular objects to distributed versions
Note
The distributed framework is particularly useful when:
Multiple physics operators with individual measurements need to be processed in parallel
Images are too large to fit in a single device's memory
Denoising priors need to be applied to large images using spatial tiling
You want to accelerate reconstruction by leveraging multiple devices
Warning
This module is in beta and may undergo significant changes in future releases. Some features are experimental and only supported for specific use cases. Please report any issues you encounter on our GitHub repository.
Quick Start#
Here’s a minimal example that shows the complete workflow:
from deepinv.physics import Blur, stack
from deepinv.physics.blur import gaussian_blur
from deepinv.optim.data_fidelity import L2
from deepinv.models import DRUNet
from deepinv.distributed import DistributedContext, distribute
from deepinv.utils.demo import load_example
# Step 1: Create distributed context
with DistributedContext() as ctx:
    # Load an example image (make sure it is on the correct device)
    x = load_example("CBSD_0010.png", grayscale=False, device=str(ctx.device))

    # Step 2: Create and stack your physics operators
    physics_list = [
        Blur(filter=gaussian_blur(sigma=1.0), padding="circular"),
        Blur(filter=gaussian_blur(sigma=2.0), padding="circular"),
        Blur(filter=gaussian_blur(sigma=3.0), padding="circular"),
    ]
    stacked_physics = stack(*physics_list)

    # Step 3: Distribute physics (transfers operators to the correct devices)
    distributed_physics = distribute(stacked_physics, ctx)

    # Use it like regular physics
    y = distributed_physics(x)  # Forward operation
    x_adj = distributed_physics.A_adjoint(y)  # Adjoint

    # Step 4: Distribute a denoiser for large images
    denoiser = DRUNet()
    distributed_denoiser = distribute(
        denoiser,
        ctx,
        patch_size=256,  # Split image into patches
        overlap=64,  # Overlap for smooth blending
    )

    # Use it like a regular denoiser
    denoised = distributed_denoiser(x_adj, sigma=0.1)

    # Step 5: Distribute a data fidelity term
    data_fidelity = L2()
    distributed_data_fidelity = distribute(data_fidelity, ctx)

    # Use it like a regular data fidelity
    loss = distributed_data_fidelity.fn(denoised, y, distributed_physics)

    # Step 6: Debug and print on rank 0 only
    if ctx.rank == 0:
        print("Distributed physics output shape:", y.shape)
        print("Distributed physics adjoint output shape:", x_adj.shape)
        print("Distributed denoiser output shape:", denoised.shape)
        print(f"Distributed data fidelity loss: {loss.item():.6f}")
Distributed physics output shape: [torch.Size([1, 3, 481, 321]), torch.Size([1, 3, 481, 321]), torch.Size([1, 3, 481, 321])]
Distributed physics adjoint output shape: torch.Size([1, 3, 481, 321])
Distributed denoiser output shape: torch.Size([1, 3, 481, 321])
Distributed data fidelity loss: ...
That’s the entire API! The deepinv.distributed.distribute() function handles all the complexity of distributed computing.
You can choose to distribute some components and not others, depending on your needs.
For instance, you might only want to distribute the denoiser for large images, while keeping the physics and data fidelity local.
When to Use Distributed Computing#
Multi-Operator Problems: many inverse problems involve multiple physics operators with corresponding measurements:
Multi-view imaging: Different camera angles or viewpoints
Multi-frequency acquisitions: Different measurement frequencies or channels
Multi-blur deconvolution: Different blur kernels applied to the same scene
Tomography: Different projection angles
The distributed framework automatically splits these operators across processes, computing forward operations, adjoints, and data fidelity gradients in parallel.
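As a concrete illustration, the distributed objects from the Quick Start can be dropped into a plain gradient-descent loop. The sketch below is illustrative only: the step size and iteration count are arbitrary, and it assumes the distributed data fidelity exposes the same grad(x, y, physics) interface as a standard deepinv.optim.DataFidelity.

# Minimal sketch (not a tuned algorithm): gradient descent on the data fidelity,
# reusing distributed_physics, distributed_data_fidelity and y from the Quick Start.
x_hat = distributed_physics.A_adjoint(y)  # simple initialization
step_size = 0.5  # illustrative value

for _ in range(20):
    # Each process handles the operators it owns; the combined gradient is
    # available on every rank after the internal communication.
    grad = distributed_data_fidelity.grad(x_hat, y, distributed_physics)
    x_hat = x_hat - step_size * grad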
Large-Scale Images: for very large images (e.g., high-resolution medical scans, satellite imagery, radio interferometry), the distributed framework uses spatial tiling to:
Split the image into overlapping patches
Process each patch independently across multiple devices
Reconstruct the full image with smooth blending at boundaries
This enables handling arbitrarily large images that wouldn’t fit in a single device’s memory.
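As a rough back-of-the-envelope check (purely illustrative numbers), even storing a single very large image can exceed typical GPU memory once network activations are added on top:

# A single 20000 x 20000 RGB image in float32 already takes ~4.8 GB to store,
# before counting network activations, which are usually several times larger.
h = w = 20_000
channels, bytes_per_float32 = 3, 4
print(f"{h * w * channels * bytes_per_float32 / 1e9:.1f} GB")  # -> 4.8 GB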
Simple Two-Step Pattern#
Step 1: Create a distributed context
from deepinv.distributed import DistributedContext
with DistributedContext() as ctx:
    # All distributed operations go here
    pass
The context:
Works seamlessly in both single-process and multi-process modes
Automatically initializes process groups when running with torchrun or on a SLURM cluster with one task per GPU
Assigns devices based on available GPUs
Cleans up resources on exit
Step 2: Distribute your objects
# Distribute physics operators
distributed_physics = distribute(physics, ctx)
# Distribute denoisers with tiling parameters
distributed_denoiser = distribute(denoiser, ctx, patch_size=256, overlap=64)
# Distribute data fidelity
distributed_data_fidelity = distribute(data_fidelity, ctx)
The deepinv.distributed.distribute() function:
Auto-detects the object type (physics, denoiser, prior, data fidelity)
Creates the appropriate distributed version
Handles all parallelization logic internally
Distributed Physics#
Large-scale physics operators can sometimes be separated into blocks:

\[A(x) = \begin{bmatrix} A_1(x) \\ A_2(x) \\ \vdots \\ A_N(x) \end{bmatrix}\]

for sub-operators \(A_i\).
The distributed framework allows you to compute each sub-operator in parallel to speed up the computation of the global forward or adjoint operator.
Note
Check out the distributed physics example for a complete demo.
Basic Usage#
from deepinv.physics import Blur, stack
from deepinv.distributed import DistributedContext, distribute
with DistributedContext() as ctx:
    # Create multiple operators
    physics_list = [operator1, operator2, operator3, ...]
    stacked_physics = stack(*physics_list)

    # Distribute them
    dist_physics = distribute(stacked_physics, ctx)

    # Use like regular physics
    y = dist_physics(x)  # Forward (parallel)
    x_adj = dist_physics.A_adjoint(y)  # Adjoint (parallel)
    x_ata = dist_physics.A_adjoint_A(x)  # Composition (parallel)
How It Works#
Operator Sharding: Operators are divided across processes using round-robin assignment
Parallel Forward: Each process computes \(A_i(x)\) for its local operators
Parallel Adjoint: Each process computes local adjoints, then results are summed via all_reduce (see the sketch below)
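The sketch below illustrates the idea behind these three steps. It is not the library's implementation; it assumes each rank owns at least one operator and that torch.distributed has been initialized when running multi-process.

import torch.distributed as dist

def local_operator_indices(num_operators, rank, world_size):
    # Round-robin assignment: rank r owns operators r, r + world_size, r + 2*world_size, ...
    return list(range(rank, num_operators, world_size))

def adjoint_sketch(physics_list, y_list, rank, world_size):
    # Each process sums the adjoints of its own operators ...
    idx = local_operator_indices(len(physics_list), rank, world_size)
    x_partial = sum(physics_list[i].A_adjoint(y_list[i]) for i in idx)
    # ... then the per-process partial sums are combined so that every rank
    # ends up with the full adjoint.
    if dist.is_available() and dist.is_initialized():
        dist.all_reduce(x_partial, op=dist.ReduceOp.SUM)
    return x_partial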
Input Formats#
The distribute() function accepts multiple formats:
physics_list = [operator1, operator2, operator3, ...]
# From StackedPhysics
stacked = stack(*physics_list)
dist_physics = distribute(stacked, ctx)
# From list of physics
dist_physics = distribute(physics_list, ctx)
# From factory function
def physics_factory(idx, device, shared):
    # idx is the index of the operator to create
    # device is the assigned device for this process
    # shared is a dict for sharing parameters across operators (optional)
    return create_physics(idx, device)

dist_physics = distribute(physics_factory, ctx, num_operators=10)

# With shared parameters
shared_params = {"common_param": value}
dist_physics = distribute(
    physics_factory, ctx, num_operators=10, shared=shared_params
)
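For instance, a factory for the multi-blur setting of the Quick Start could look like the sketch below (the kernel widths and the shared padding parameter are illustrative, and it assumes Blur accepts a device argument as elsewhere in deepinv):

from deepinv.physics import Blur
from deepinv.physics.blur import gaussian_blur

def blur_factory(idx, device, shared):
    # Operator idx gets a progressively wider Gaussian kernel (illustrative choice)
    padding = shared.get("padding", "circular") if shared else "circular"
    return Blur(filter=gaussian_blur(sigma=1.0 + idx), padding=padding, device=device)

dist_physics = distribute(blur_factory, ctx, num_operators=4, shared={"padding": "circular"})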
Gather Strategies#
You can control how results are gathered from different processes:
# Concatenated (default): most efficient for similar-sized tensors
dist_physics = distribute(physics, ctx, gather_strategy="concatenated")
# Naive: simple serialization, good for small tensors
dist_physics = distribute(physics, ctx, gather_strategy="naive")
# Broadcast: good for heterogeneous sizes
dist_physics = distribute(physics, ctx, gather_strategy="broadcast")
Distributed Denoisers#
Denoisers can be distributed using spatial tiling to handle large images.
Note
Check out the distributed denoiser example for a complete demo.
Basic Usage#
import torch

from deepinv.models import DRUNet
from deepinv.distributed import DistributedContext, distribute

with DistributedContext() as ctx:
    # Load your denoiser
    denoiser = DRUNet()

    # Distribute with tiling parameters
    dist_denoiser = distribute(
        denoiser,
        ctx,
        patch_size=256,  # Size of each patch
        overlap=64,  # Overlap for smooth boundaries
    )

    # Process an image
    image = torch.randn(1, 3, 512, 512, device=ctx.device)
    with torch.no_grad():
        denoised = dist_denoiser(image, sigma=0.05)

    if ctx.rank == 0:
        print("Denoised image shape:", denoised.shape)
Denoised image shape: torch.Size([1, 3, 512, 512])
How It Works#
Patch Extraction: Image is split into overlapping patches
Distributed Processing: Patches are distributed across processes
Parallel Denoising: Each process denoises its local patches
Reconstruction: Patches are blended back into the full image, and every rank ends up with the full output (see the sketch below)
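To make the patch geometry concrete, the sketch below counts the overlapping patches for the 481 x 321 image used in the Quick Start, assuming a stride of patch_size - overlap (the library's actual tiling logic may differ):

def patch_starts(length, patch_size=256, overlap=64):
    # Top-left coordinates of the patches along one dimension (assumed layout).
    stride = patch_size - overlap
    starts = list(range(0, max(length - patch_size, 0) + 1, stride))
    if starts[-1] + patch_size < length:
        starts.append(length - patch_size)  # last patch flush with the border
    return starts

rows, cols = patch_starts(481), patch_starts(321)
print(len(rows) * len(cols), "patches")  # these patches are then split across ranks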
Tiling Parameters#
| Parameter | Description |
|---|---|
| patch_size | Size of each patch (default: 256). Larger patches = less communication, more memory |
| overlap | Overlap radius for smooth blending (default: 64). |
| tiling_strategy | Strategy for tiling: "overlap_tiling" (default) or "basic" |
| max_batch_size | Max patches per batch (default: all). Set to 1 for sequential processing (lowest memory) |
Tiling Strategies#
# Tiling with overlap (default)
dist_denoiser = distribute(denoiser, ctx, tiling_strategy="overlap_tiling")
# Basic (no overlap blending)
dist_denoiser = distribute(denoiser, ctx, tiling_strategy="basic")
Running Multi-Process#
Use torchrun to launch multiple processes. Examples:
4 GPUs on one machine:
torchrun --nproc_per_node=4 my_script.py
2 machines with 2 GPUs each:
# On machine 1 (rank 0)
torchrun --nproc_per_node=4 --nnodes=2 --node_rank=0 \
--master_addr="192.168.1.1" --master_port=29500 my_script.py
# On machine 2 (rank 1):
torchrun --nproc_per_node=4 --nnodes=2 --node_rank=1 \
--master_addr="192.168.1.1" --master_port=29500 my_script.py
Alternatively, use the -m torch.distributed.run syntax to run as a module:
python -m torch.distributed.run --nproc_per_node=4 my_script.py
The DistributedContext automatically detects the settings from environment variables.
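For reference, these are the standard launch variables that torchrun sets; the snippet below only reads them and is not how DistributedContext is implemented internally:

import os

rank = int(os.environ.get("RANK", 0))               # global process index
world_size = int(os.environ.get("WORLD_SIZE", 1))   # total number of processes
local_rank = int(os.environ.get("LOCAL_RANK", 0))   # process index on this node
master_addr = os.environ.get("MASTER_ADDR", "localhost")  # rendezvous address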
Advanced Features#
Local vs Reduced Operations#
By default, distributed methods return fully gathered and reduced results (combined from all processes).
You can get local-only results with gather=False, and skip the reduction with reduce_op=None:
# Get local results without local reduction
y_local = dist_physics.A(x, gather=False, reduce_op=None)
# Get local results with reduction (sum by default)
y_local = dist_physics.A(x, gather=False)
# Get gathered results without reduction
y_gathered = dist_physics.A(x, reduce_op=None)
# Get gathered results (default)
y_all = dist_physics.A(x)
This is useful for:
Custom reduction strategies
Debugging distributed execution
Optimizing communication patterns
Custom Tiling Strategies#
You can implement custom tiling strategies by subclassing
deepinv.distributed.strategies.DistributedSignalStrategy:
from deepinv.distributed.strategies import DistributedSignalStrategy
class MyCustomStrategy(DistributedSignalStrategy):
    def get_local_patches(self, X, local_indices):
        # Your patch extraction logic
        pass

    def reduce_patches(self, out_tensor, local_pairs):
        # Your patch reduction logic
        pass

    def get_num_patches(self):
        # Total number of patches
        pass

# Use it
dist_denoiser = distribute(
    denoiser,
    ctx,
    tiling_strategy=MyCustomStrategy(img_size),
)
Performance Tips#
Choosing the Right Number of Processes
Multi-operator problems: Use as many processes as operators (up to available devices)
Spatial tiling: Balance parallelism vs communication overhead
Rule of thumb: Start with number of GPUs, experiment from there
Optimizing Patch Size
Larger patches (512+): Less communication, more memory per process
Smaller patches (128-256): More parallelism, more communication
Recommendation: 256-512 pixels for deep denoisers on natural images
Receptive Field Padding
Set overlap to match your denoiser's receptive field
Ensures smooth blending at patch boundaries
Typical values: 32-64 pixels for U-Net style denoisers
Gather Strategies
Concatenated (default): Best for most cases, minimal communication
Naive: Use for small tensors or debugging
Broadcast: Use when operator outputs have very different sizes
Key Classes#
| Class | Description |
|---|---|
| DistributedContext | Manages distributed execution, process groups, and devices |
| DistributedStackedPhysics | Distributes physics operators across processes (auto-created by distribute()) |
| LinearDistributedPhysics | Extends DistributedStackedPhysics for linear operators with adjoint operations |
| | Distributes denoisers/priors using spatial tiling (auto-created by distribute()) |
| | Distributes data fidelity terms (auto-created by distribute()) |
You typically won’t need to instantiate these classes directly. Use the deepinv.distributed.distribute() function instead.
Troubleshooting#
Out of memory errors
Reduce patch_size for distributed denoisers
Set max_batch_size=1 for sequential patch processing
Results differ slightly from non-distributed
This is normal for tiling strategies due to boundary blending
Differences are typically very small
The distributed implementations of A_dagger and compute_norm in LinearDistributedPhysics use approximations that lead to differences compared to the non-distributed versions.
See Also#
API Reference: deepinv.distributed
Examples: see the distributed physics and distributed denoiser examples referenced above
Related:
deepinv.physics.StackedPhysics for multi-operator physics
Optimization algorithms for reconstruction methods