DistributedStackedPhysics
- class deepinv.distributed.DistributedStackedPhysics(ctx, num_operators, factory, *, factory_kwargs=None, dtype=None, gather_strategy='concatenated', **kwargs)
Bases: Physics
Each process holds only its local physics operators. The class exposes fast local methods as well as global methods compatible with the Physics API.
This class distributes a collection of physics operators across multiple processes, where each process owns a subset of the operators.
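How the num_operators indices are assigned to ranks is not specified here. The sketch below assumes a hypothetical contiguous block partition to show what "each process owns a subset" means in practice. The helper local_indices is illustrative and is not part of deepinv.

```python
def local_indices(num_operators: int, world_size: int, rank: int) -> list[int]:
    """Hypothetical contiguous block partition of operator indices.

    The first (num_operators % world_size) ranks receive one extra operator.
    This is an illustrative assumption, not necessarily deepinv's scheme.
    """
    base, extra = divmod(num_operators, world_size)
    start = rank * base + min(rank, extra)
    count = base + (1 if rank < extra else 0)
    return list(range(start, start + count))


# Example: 10 operators distributed over 3 processes.
for r in range(3):
    print(r, local_indices(10, 3, r))
# 0 [0, 1, 2, 3]
# 1 [4, 5, 6]
# 2 [7, 8, 9]
```

Every index appears on exactly one rank, so each operator is built and applied only once across the whole job.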
Note
It is intended to parallelize models naturally expressed as a stack or list of operators (e.g., deepinv.physics.StackedPhysics or an explicit Python list of deepinv.physics.Physics objects) and is not meant to split a single monolithic physics operator across ranks. If your forward model is a single operator that can be decomposed into multiple sub-operators, it is up to you to perform that decomposition (e.g., build a deepinv.physics.StackedPhysics) and then pass that collection to DistributedStackedPhysics via the factory argument.
- Parameters:
ctx (DistributedContext) – distributed context manager.
num_operators (int) – total number of physics operators.
factory (Callable) – factory function that creates physics operators. It must have the signature factory(index, device, factory_kwargs) -> Physics.
factory_kwargs (dict | None) – shared data dictionary passed to the factory function. Default is None.
dtype (torch.dtype | None) – data type for operations. Default is None.
gather_strategy (str) – strategy for gathering distributed results. Options are:
- 'naive': simple object serialization (best for small tensors).
- 'concatenated': a single concatenated tensor (best for medium/large tensors, minimal communication).
- 'broadcast': per-operator broadcasts (best for heterogeneous sizes or streaming).
Default is 'concatenated'.
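A minimal sketch of a factory with the documented signature factory(index, device, factory_kwargs) -> Physics, assuming a single linear operator that has been decomposed by hand into row blocks, as the note above recommends. RowBlockOperator and the "matrix"/"block_size" keys are hypothetical stand-ins; a real factory would return a deepinv.physics.Physics object placed on device.

```python
from typing import Optional


class RowBlockOperator:
    """Toy stand-in for a Physics object: applies a block of matrix rows to x."""

    def __init__(self, rows):
        self.rows = rows

    def A(self, x):
        # Matrix-vector product restricted to this block's rows.
        return [sum(r_j * x_j for r_j, x_j in zip(row, x)) for row in self.rows]


def factory(index: int, device: str, factory_kwargs: Optional[dict]) -> RowBlockOperator:
    # Operator `index` owns rows [index * block, (index + 1) * block) of the shared matrix.
    # `device` is unused by this toy operator; a real factory would place its tensors there.
    matrix = factory_kwargs["matrix"]
    block = factory_kwargs["block_size"]
    return RowBlockOperator(matrix[index * block:(index + 1) * block])


matrix = [[1, 0], [0, 1], [1, 1], [2, -1]]
shared = {"matrix": matrix, "block_size": 2}
ops = [factory(i, "cpu", shared) for i in range(2)]

x = [3, 5]
stacked = [y for op in ops for y in op.A(x)]  # concatenate the per-block outputs
full = RowBlockOperator(matrix).A(x)          # apply the undecomposed operator
print(stacked == full, full)  # True [3, 5, 8, 1]
```

Concatenating the sub-operator outputs reproduces the monolithic operator, which is what makes the decomposition safe to distribute: each rank only needs to call factory for the indices it owns.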
- A(x, gather=True, reduce_op=None, **kwargs)
Apply forward operator to all distributed physics operators with automatic gathering.
Applies the forward operator \(A(x)\) by computing local measurements and gathering results from all ranks using the configured gather strategy.
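The idea behind the 'concatenated' gather strategy can be simulated in a single process: each rank flattens its local measurements into one buffer plus a list of sizes, the buffers are exchanged, and every rank splits them back into an ordered global list. This is a hypothetical plain-Python sketch, not the deepinv implementation; pack, unpack and gather_concatenated are illustrative names, and a real version would exchange tensors through collectives such as torch.distributed.all_gather.

```python
def pack(local_meas):
    """Flatten one rank's local measurements into a single buffer plus their sizes."""
    sizes = [len(m) for m in local_meas]
    flat = [v for m in local_meas for v in m]
    return flat, sizes


def unpack(flat, sizes):
    """Split a flat buffer back into per-operator measurements."""
    out, offset = [], 0
    for s in sizes:
        out.append(flat[offset:offset + s])
        offset += s
    return out


def gather_concatenated(per_rank_meas, per_rank_indices, num_operators):
    # Simulated all-gather: each rank contributes exactly one packed buffer.
    gathered = [pack(m) for m in per_rank_meas]
    result = [None] * num_operators
    for (flat, sizes), indices in zip(gathered, per_rank_indices):
        # Place each unpacked measurement at its global operator index.
        for idx, meas in zip(indices, unpack(flat, sizes)):
            result[idx] = meas
    return result


per_rank_indices = [[0, 1], [2]]                           # rank 0 owns ops 0 and 1, rank 1 owns op 2
per_rank_meas = [[[1.0, 2.0], [3.0]], [[4.0, 5.0, 6.0]]]   # heterogeneous measurement sizes
print(gather_concatenated(per_rank_meas, per_rank_indices, 3))
# [[1.0, 2.0], [3.0], [4.0, 5.0, 6.0]]
```

Because only one buffer per rank is exchanged, the number of communication calls does not grow with the number of operators, which matches the "minimal communication" description of this strategy.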
- Parameters:
x (torch.Tensor) – input signal.
gather (bool) – whether to gather results across ranks. If False, returns only the local measurements. Default is True.
reduce_op (str | None) – reduction operation to apply across ranks. Default is None.
kwargs – optional parameters for the forward operator.
- Returns:
complete list of measurements from all operators (or the local list if gather=False).
- Return type:
TensorList | list[Tensor]
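The exact semantics of reduce_op are not detailed above. One plausible reading, assumed here, is an elementwise all-reduce that combines equally shaped per-rank results into a single result. The helper all_reduce and the op names "sum", "max" and "min" are hypothetical and only illustrate the pattern.

```python
import operator
from functools import reduce

# Hypothetical op names mapped to binary reduction functions.
REDUCE_OPS = {"sum": operator.add, "max": max, "min": min}


def all_reduce(per_rank_values, reduce_op="sum"):
    """Simulated all-reduce: combine equally shaped per-rank vectors elementwise."""
    fn = REDUCE_OPS[reduce_op]
    # zip(*...) groups the i-th entry of every rank's vector together.
    return [reduce(fn, column) for column in zip(*per_rank_values)]


print(all_reduce([[1, 2, 3], [10, 20, 30]], "sum"))  # [11, 22, 33]
print(all_reduce([[1, 5], [4, 2]], "max"))           # [4, 5]
```

In a real multi-process run, every rank would end up holding the same reduced result, in the same way as with torch.distributed.all_reduce.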
- forward(x, gather=True, reduce_op=None, **kwargs)
Apply the full forward model, including the sensor and noise models, to the input signal and gather the results.
\[y = N(A(x))\]
- Parameters:
x (torch.Tensor) – input signal.
gather (bool) – whether to gather results across ranks. If False, returns only the local measurements. Default is True.
reduce_op (str | None) – reduction operation to apply across ranks. Default is None.
kwargs – optional parameters for the forward model.
- Returns:
complete list of noisy measurements from all operators.
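A sketch of \(y = N(A(x))\) applied to every operator in the stack, assuming additive Gaussian noise as the noise model \(N\) and omitting any sensor model. The helper forward_stack and the toy operators are hypothetical. In the distributed class, each rank would presumably apply this only to its local operators before the results are gathered as in A.

```python
import random


def forward_stack(ops, x, sigma, rng):
    """Compute y_i = N(A_i(x)) for each operator, with N = additive Gaussian noise."""
    ys = []
    for A_i in ops:
        clean = A_i(x)                                          # noiseless measurement A_i(x)
        ys.append([v + rng.gauss(0.0, sigma) for v in clean])  # add independent noise per entry
    return ys


# Two toy operators with different output sizes.
ops = [lambda x: [2 * v for v in x], lambda x: [sum(x)]]

# With sigma = 0 the noise vanishes and the output equals the clean measurements.
print(forward_stack(ops, [1.0, 2.0], sigma=0.0, rng=random.Random(0)))  # [[2.0, 4.0], [3.0]]
```

Seeding the random generator per call keeps noisy runs reproducible; in a distributed setting, each rank would need its own generator to avoid drawing identical noise on every process.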