DistributedStackedPhysics#

class deepinv.distributed.DistributedStackedPhysics(ctx, num_operators, factory, *, factory_kwargs=None, dtype=None, gather_strategy='concatenated', **kwargs)[source]#

Bases: Physics

Holds only the physics operators local to the current process, while exposing both fast local APIs and global APIs compatible with the full operator stack.

This class distributes a collection of physics operators across multiple processes, where each process owns a subset of the operators.

Note

It is intended to parallelize models naturally expressed as a stack/list of operators (e.g., deepinv.physics.StackedPhysics or an explicit Python list of deepinv.physics.Physics objects) and is not meant to split a single monolithic physics operator across ranks.

If your forward model is a single operator that can be decomposed into multiple sub-operators, it is up to you to perform that decomposition (e.g., build a deepinv.physics.StackedPhysics) and then pass that collection to DistributedStackedPhysics via the factory argument.
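For concreteness, a minimal factory sketch is given below, assuming the forward model has already been decomposed into one pixel mask per operator index. The choice of deepinv.physics.Inpainting as the sub-operator and the "masks" key in factory_kwargs are assumptions made for this example, not part of the API; the factory signature itself follows the documentation below.

    import deepinv as dinv

    def factory(index, device, factory_kwargs):
        # One sub-operator per index; here each operator observes a different
        # pixel mask ("masks" is a hypothetical key chosen for this sketch).
        mask = factory_kwargs["masks"][index].to(device)
        return dinv.physics.Inpainting(
            tensor_size=mask.shape, mask=mask, device=device
        )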

Parameters:
  • ctx (DistributedContext) – distributed context manager.

  • num_operators (int) – total number of physics operators.

  • factory (Callable) – factory function that creates physics operators. Should have signature factory(index, device, factory_kwargs) -> Physics.

  • factory_kwargs (dict | None) – shared data dictionary passed to factory function. Default is None.

  • dtype (torch.dtype | None) – data type for operations. Default is None.

  • gather_strategy (str) –

    strategy for gathering distributed results. Options are:

      - 'naive': simple object serialization (best for small tensors)
      - 'concatenated': single concatenated tensor (best for medium/large tensors, minimal communication)
      - 'broadcast': per-operator broadcasts (best for heterogeneous sizes or streaming)

    Default is 'concatenated'.
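A hedged construction sketch, building on the factory above: it assumes ctx is an already initialised DistributedContext (its construction is not shown here) and uses small random masks purely as illustrative shared data.

    import torch

    # Shared data passed to every factory call; purely illustrative.
    masks = [(torch.rand(1, 32, 32) > 0.5).float() for _ in range(4)]

    physics = dinv.distributed.DistributedStackedPhysics(
        ctx,                              # existing DistributedContext
        num_operators=len(masks),
        factory=factory,                  # factory sketched in the note above
        factory_kwargs={"masks": masks},
        gather_strategy="concatenated",
    )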

A(x, gather=True, reduce_op=None, **kwargs)[source]#

Apply the forward operator of each distributed physics operator to the input signal, with automatic gathering of the results.

Applies the forward operator \(A(x)\) by computing local measurements and gathering results from all ranks using the configured gather strategy.

Parameters:
  • x (torch.Tensor) – input signal.

  • gather (bool) – whether to gather results across ranks. If False, returns local measurements. Default is True.

  • reduce_op (str | None) – reduction operation to apply across ranks. Default is None.

  • kwargs – optional parameters for the forward operator.

Returns:

complete list of measurements from all operators (or local list if gather=False).

Return type:

TensorList | list[Tensor]
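Continuing the illustrative setup above, a typical call might look as follows; gather=False keeps only the measurements computed by the operators owned by the current rank.

    x = torch.randn(1, 1, 32, 32)

    y_all = physics.A(x)                  # gathered: measurements of all operators
    y_local = physics.A(x, gather=False)  # local: only this rank's operators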

forward(x, gather=True, reduce_op=None, **kwargs)[source]#

Apply the full forward model, including sensor and noise models, to the input signal and gather the results.

\[y = N(A(x))\]
Parameters:
  • x (torch.Tensor) – input signal.

  • gather (bool) – whether to gather results across ranks. If False, returns local measurements. Default is True.

  • reduce_op (str | None) – reduction operation to apply across ranks. Default is None.

  • kwargs – optional parameters for the forward model.

Returns:

complete list of noisy measurements from all operators (or local list if gather=False).
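Continuing the same sketch, calling the object directly is equivalent to calling forward(), since Physics objects are torch modules; gather=False again restricts the result to the local operators.

    y_noisy = physics(x)                      # y = N(A(x)) from all operators, gathered
    y_noisy_local = physics(x, gather=False)  # only this rank's noisy measurements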