distribute
- deepinv.distributed.distribute(object, ctx, *, num_operators=None, type_object='auto', dtype=torch.float32, gather_strategy='concatenated', tiling_strategy='overlap_tiling', tiling_dims=None, patch_size=256, overlap=64, max_batch_size=None, **kwargs)
Distribute a DeepInverse object across multiple devices.
This function takes a DeepInverse object and distributes it using the provided DistributedContext.
The supported DeepInverse objects are:
- Physics operators: a list of deepinv.physics.Physics objects, a deepinv.physics.StackedPhysics, or a deepinv.physics.StackedLinearPhysics.
- Data fidelity terms: a deepinv.optim.DataFidelity (or a list of them), or a deepinv.optim.StackedPhysicsDataFidelity.
- Priors/Denoisers: deepinv.models.Denoiser or deepinv.optim.Prior objects.
- Parameters:
object (StackedPhysics | list[Physics] | Callable | Denoiser | DataFidelity | StackedPhysicsDataFidelity | list[DataFidelity]) – DeepInverse object to distribute. The supported types are listed above.
ctx (DistributedContext) – distributed context manager.
num_operators (int | None) – number of physics operators when a factory is used to build the physics; otherwise inferred. Default is None.
type_object (str | None) – type of object to distribute. Options are 'physics', 'linear_physics', 'data_fidelity', 'denoiser', or 'auto' for automatic detection. Default is 'auto'.
dtype (torch.dtype | None) – data type of the distributed object. Default is torch.float32.
gather_strategy (str) – strategy for gathering distributed results. Options are:
  - 'naive': simple object serialization (best for small tensors).
  - 'concatenated': a single concatenated tensor (best for medium/large tensors; minimal communication).
  - 'broadcast': per-operator broadcasts (best for heterogeneous sizes or streaming).
  Default is 'concatenated'.
tiling_strategy (str | DistributedSignalStrategy | None) – strategy for tiling the signal (for Denoiser). Options are 'basic', 'overlap_tiling', or a custom strategy instance. Default is 'overlap_tiling'.
tiling_dims (int | tuple[int, ...] | None) – dimensions to tile over (for Denoiser). Can be one of the following:
  - None (default): tiles the last N-2 dimensions of an N-dimensional input tensor, i.e. every dimension after batch and channel.
  - An int d: tiles only over dimension d.
  - A tuple: specifies the exact dimensions to tile.
  Examples:
  - For a (B, C, H, W) image, tiling_dims=(2, 3) tiles over H and W.
  - For a (B, C, D, H, W) volume, tiling_dims=(2, 3, 4) tiles over D, H and W.
  - For a (B, C, H, W) image, tiling_dims=2 tiles only over H.
  - For a (B, C, D, H, W) volume, tiling_dims=None tiles over D, H and W.
patch_size (int | tuple[int, ...]) – size of the patches used by the tiling strategies (for Denoiser). Can be an int (same size for all tiled dimensions) or a tuple (one size per tiled dimension). Default is 256.
overlap (int | tuple[int, ...]) – receptive field size used as the overlap between patches in the tiling strategies (for Denoiser). Can be an int (same size for all tiled dimensions) or a tuple (one size per tiled dimension). Default is 64.
max_batch_size (int | None) – maximum number of patches processed in a single batch (for Denoiser). If None, all patches are batched together; set to 1 for sequential processing. Default is None.
kwargs – additional keyword arguments passed to the specific distributed classes.
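The 'concatenated' gather strategy sends one packed tensor instead of one message per operator. The sketch below illustrates only the packing step behind that idea, under stated assumptions: pack_concatenated and unpack_concatenated are hypothetical helpers (not part of deepinv.distributed), all tensors share one dtype and device, and the actual communication (for example a single torch.distributed collective on the packed tensor) is omitted.

```python
import math

import torch


def pack_concatenated(tensors):
    """Flatten tensors into one 1-D tensor and record their original shapes.

    Hypothetical helper illustrating gather_strategy='concatenated'.
    Assumes every tensor has the same dtype and device (required by torch.cat).
    """
    shapes = [tuple(t.shape) for t in tensors]
    flat = torch.cat([t.reshape(-1) for t in tensors])
    return flat, shapes


def unpack_concatenated(flat, shapes):
    """Split a packed 1-D tensor back into tensors with the recorded shapes."""
    sizes = [math.prod(shape) for shape in shapes]
    chunks = torch.split(flat, sizes)
    return [chunk.reshape(shape) for chunk, shape in zip(chunks, shapes)]


# Two "operator outputs" of different sizes, packed into one message.
y1 = torch.randn(1, 3, 8, 8)
y2 = torch.randn(1, 3, 4, 4)
flat, shapes = pack_concatenated([y1, y2])
restored = unpack_concatenated(flat, shapes)
print(flat.numel())  # 192 + 48 = 240
```

Recording the shapes is what lets the receiving side undo the concatenation, so per-operator outputs of different sizes can still travel in a single tensor.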
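The documented rules for tiling_dims can be written as a small resolution function. This is a hypothetical helper (resolve_tiling_dims is not a deepinv function) that assumes non-negative dimension indices; it reproduces the four tiling_dims examples given in the parameter description.

```python
def resolve_tiling_dims(ndim, tiling_dims=None):
    """Return the tuple of dimensions to tile for an ndim-dimensional input.

    Hypothetical helper following the documented rules:
    None -> the last ndim - 2 dimensions (everything after batch and channel),
    int d -> only dimension d, tuple -> exactly the given dimensions.
    Assumes non-negative dimension indices.
    """
    if tiling_dims is None:
        dims = tuple(range(2, ndim))
    elif isinstance(tiling_dims, int):
        dims = (tiling_dims,)
    else:
        dims = tuple(tiling_dims)
    for d in dims:
        if not 0 <= d < ndim:
            raise ValueError(f"dimension {d} is out of range for a {ndim}-D input")
    return dims


print(resolve_tiling_dims(4, (2, 3)))     # (2, 3): H, W of (B, C, H, W)
print(resolve_tiling_dims(5, (2, 3, 4)))  # (2, 3, 4): D, H, W of (B, C, D, H, W)
print(resolve_tiling_dims(4, 2))          # (2,): only H
print(resolve_tiling_dims(5))             # (2, 3, 4): default, D, H, W
```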
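One common way to realize overlap tiling along a single dimension is to cut the signal into non-overlapping core tiles of length patch_size, extend each tile by overlap samples on both sides so the model sees enough context, and keep only the core of each output. The sketch below assumes exactly this scheme, which may differ from what 'overlap_tiling' does internally; overlap_tiles_1d and tiled_apply_1d are hypothetical helpers, and a 3-tap moving average stands in for the denoiser. Because that filter's receptive field reaches one sample per side, any overlap of at least 1 reproduces the untiled result.

```python
import torch
import torch.nn.functional as F


def overlap_tiles_1d(length, patch_size, overlap):
    """Return (core_start, core_end, win_start, win_end) for each tile.

    Core tiles of length patch_size partition range(length); each window
    extends its core by `overlap` samples per side, clipped to the signal.
    """
    if patch_size <= 0 or overlap < 0:
        raise ValueError("patch_size must be positive and overlap non-negative")
    tiles = []
    for core_start in range(0, length, patch_size):
        core_end = min(core_start + patch_size, length)
        win_start = max(core_start - overlap, 0)
        win_end = min(core_end + overlap, length)
        tiles.append((core_start, core_end, win_start, win_end))
    return tiles


def tiled_apply_1d(x, fn, patch_size, overlap):
    """Apply fn window by window along the last dimension, keeping only cores."""
    out = torch.empty_like(x)
    for cs, ce, ws, we in overlap_tiles_1d(x.shape[-1], patch_size, overlap):
        y = fn(x[..., ws:we])  # process the extended window
        out[..., cs:ce] = y[..., cs - ws:ce - ws]  # crop back to the core
    return out


# Stand-in "denoiser": a 3-tap moving average with zero padding.
weight = torch.ones(1, 1, 3) / 3


def smooth(t):
    return F.conv1d(t, weight, padding=1)


x = torch.randn(1, 1, 600)
print(overlap_tiles_1d(600, patch_size=256, overlap=64))
# [(0, 256, 0, 320), (256, 512, 192, 576), (512, 600, 448, 600)]
print(torch.allclose(tiled_apply_1d(x, smooth, 256, 64), smooth(x), atol=1e-6))  # True
```

Windows at the ends of the signal are clipped rather than padded, so the last core tile may be shorter than patch_size.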
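max_batch_size bounds how many patches go through the denoiser in one call. A minimal sketch, assuming the patches are already stacked along dimension 0 of a single tensor; batched_apply is a hypothetical helper, not the deepinv implementation.

```python
import torch


def batched_apply(patches, fn, max_batch_size=None):
    """Apply fn to stacked patches, at most max_batch_size patches per call.

    patches: tensor of shape (P, ...), one patch per entry along dim 0.
    None processes all P patches in a single call; 1 processes them one by one.
    """
    if max_batch_size is None:
        return fn(patches)
    if max_batch_size < 1:
        raise ValueError("max_batch_size must be a positive integer or None")
    # torch.split with an int yields chunks of that size (the last may be smaller).
    outputs = [fn(chunk) for chunk in torch.split(patches, max_batch_size)]
    return torch.cat(outputs, dim=0)


calls = []


def double(t):
    calls.append(t.shape[0])  # record how many patches each call receives
    return 2 * t


patches = torch.randn(10, 1, 32, 32)
out = batched_apply(patches, double, max_batch_size=4)
print(calls)  # [4, 4, 2]
```

Smaller batches lower peak memory at the cost of more model calls; max_batch_size=None makes a single call on all patches.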
- Returns:
Distributed version of the input object.
- Return type:
DistributedStackedPhysics | DistributedStackedLinearPhysics | DistributedProcessing | DistributedDataFidelity
- Examples:
Distribute a Physics object:
>>> from deepinv.physics import Blur, StackedLinearPhysics
>>> from deepinv.distributed import DistributedContext, distribute
>>> with DistributedContext() as ctx:
...     physics = StackedLinearPhysics([Blur(kernel_size=5), Blur(kernel_size=9)])
...     dphysics = distribute(physics, ctx)
Distribute a DataFidelity object:
>>> from deepinv.optim.data_fidelity import L2
>>> from deepinv.distributed import DistributedContext, distribute
>>> with DistributedContext() as ctx:
...     data_fidelity = L2()
...     ddata_fidelity = distribute(data_fidelity, ctx)
Distribute a Denoiser object:
>>> from deepinv.models import DnCNN
>>> from deepinv.distributed import DistributedContext, distribute
>>> with DistributedContext() as ctx:
...     denoiser = DnCNN()
...     ddenoiser = distribute(denoiser, ctx)