DistributedDataFidelity

class deepinv.distributed.DistributedDataFidelity(ctx, data_fidelity, num_operators=None, *, factory_kwargs=None, reduction='sum')[source]

Bases: object

Distributed data fidelity term for use with distributed physics operators.

This class wraps a standard DataFidelity object and makes it compatible with DistributedStackedLinearPhysics by implementing efficient distributed computation patterns. It computes data fidelity terms and gradients using local operations followed by a single reduction, avoiding redundant communication.

The key operations are:

  • fn(x, y, physics): Computes the data fidelity \(\sum_i d(A_i(x), y_i)\)

  • grad(x, y, physics): Computes the gradient \(\sum_i A_i^T \nabla d(A_i(x), y_i)\)

Both operations follow the same efficient pattern (illustrated in the sketch after this list):

  1. Each rank computes the forward operations of its local operators (A_local)

  2. Each rank applies the distance function and computes gradients locally

  3. A single reduction combines the partial results across ranks
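
The sketch below illustrates this pattern with plain torch.distributed calls. It is conceptual only: the actual reduction is handled internally by the DistributedContext, and local_As, local_ys, and distance are hypothetical stand-ins for the locally owned operators, their measurements, and the wrapped distance function.

```python
import torch.distributed as dist

def fidelity_sketch(x, local_As, local_ys, distance):
    # Steps 1-2: each rank evaluates only its locally owned operators
    # and accumulates its partial contribution to the objective.
    local_val = sum(distance(A(x), y) for A, y in zip(local_As, local_ys))
    # Step 3: a single all-reduce sums the partial values across ranks.
    dist.all_reduce(local_val, op=dist.ReduceOp.SUM)
    return local_val
```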

Parameters:
  • ctx (DistributedContext) – distributed context manager.

  • data_fidelity (DataFidelity | Callable) – either a DataFidelity instance or a factory function that creates DataFidelity instances for each operator. The factory should have signature factory(index: int, device: torch.device, factory_kwargs: dict | None) -> DataFidelity.

  • num_operators (int | None) – number of operators (required if data_fidelity is a factory). Default is None.

  • factory_kwargs (dict | None) – shared data dictionary passed to the factory function for all operators. Default is None.

  • reduction (str) – reduction mode matching the distributed physics. Options are 'sum' or 'mean'. Default is 'sum'.
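
A minimal construction sketch is shown below, using deepinv.optim.L2 as the wrapped data fidelity. The DistributedContext instantiation is an assumption and may differ from the actual entry point in your deepinv version.

```python
import deepinv as dinv

# Assumed: a distributed context for this process group (hypothetical
# construction; see the DistributedContext documentation).
ctx = dinv.distributed.DistributedContext()

# Option 1: one DataFidelity instance shared by all operators.
df = dinv.distributed.DistributedDataFidelity(ctx, dinv.optim.L2(), reduction="sum")

# Option 2: a factory that builds one DataFidelity per operator,
# following the documented signature factory(index, device, factory_kwargs).
def l2_factory(index, device, factory_kwargs):
    return dinv.optim.L2()

df = dinv.distributed.DistributedDataFidelity(
    ctx, l2_factory, num_operators=4, reduction="sum"
)
```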

fn(x, y, physics, gather=True, *args, **kwargs)[source]

Compute the distributed data fidelity term.

For distributed physics with operators \(\{A_i\}\) and measurements \(\{y_i\}\), computes:

\[f(x) = \sum_i d(A_i(x), y_i)\]

This is computed efficiently by:

  1. Each rank computes \(A_i(x)\) for its local operators

  2. Each rank computes \(\sum_{i \in \text{local}} d(A_i(x), y_i)\)

  3. Results are reduced across all ranks

Parameters:
  • x (torch.Tensor) – input signal at which to evaluate the data fidelity.

  • y (list[torch.Tensor]) – measurements (TensorList or list of tensors).

  • physics (DistributedStackedLinearPhysics) – distributed physics operator.

  • gather (bool) – whether to gather (reduce) results across ranks. Default is True.

  • args – additional positional arguments passed to the distance function.

  • kwargs – additional keyword arguments passed to the distance function.

Returns:

scalar data fidelity value.

Return type:

Tensor
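
A short usage sketch, assuming df, x, y, and a DistributedStackedLinearPhysics instance physics set up as in the construction example above:

```python
# Every rank obtains the same scalar f(x), since the local partial
# sums are reduced across ranks (gather=True by default).
loss = df.fn(x, y, physics)

# Keep only this rank's local partial sum (no cross-rank reduction).
local_loss = df.fn(x, y, physics, gather=False)
```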

grad(x, y, physics, gather=True, *args, **kwargs)[source]

Compute the gradient of the distributed data fidelity term.

For distributed physics with operators \(\{A_i\}\) and measurements \(\{y_i\}\), computes:

\[\nabla_x f(x) = \sum_i A_i^T \nabla d(A_i(x), y_i)\]

This is computed efficiently (see the sketch after this list) by:

  1. Each rank computes \(A_i(x)\) for its local operators

  2. Each rank computes \(\nabla d(A_i(x), y_i)\) for its local operators

  3. Each rank computes \(\sum_{i \in \text{local}} A_i^T \nabla d(A_i(x), y_i)\) using A_vjp_local

  4. Results are reduced across all ranks
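
The sketch below mirrors these steps with hypothetical stand-ins: local_As for the local operators (with their A_adjoint methods), local_ys for their measurements, and distance_grad for the gradient of the wrapped distance. In the real implementation, steps 1-3 are delegated to A_vjp_local.

```python
import torch.distributed as dist

def grad_sketch(x, local_As, local_ys, distance_grad):
    # Steps 1-2: forward pass and distance gradient for each local operator.
    us = [distance_grad(A(x), y) for A, y in zip(local_As, local_ys)]
    # Step 3: apply the local adjoints and sum (the role of A_vjp_local).
    g = sum(A.A_adjoint(u) for A, u in zip(local_As, us))
    # Step 4: one all-reduce sums the partial gradients across ranks.
    dist.all_reduce(g, op=dist.ReduceOp.SUM)
    return g
```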

Parameters:
  • x (torch.Tensor) – input signal at which to compute the gradient.

  • y (list[torch.Tensor]) – measurements (TensorList or list of tensors).

  • physics (DistributedStackedLinearPhysics) – distributed physics operator.

  • gather (bool) – whether to gather (reduce) results across ranks. Default is True.

  • args – additional positional arguments passed to the distance function gradient.

  • kwargs – additional keyword arguments passed to the distance function gradient.

Returns:

gradient with the same shape as x.

Return type:

Tensor
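
For illustration, a plain gradient-descent sketch built on grad; the step size eta and iteration count are hypothetical hyperparameters:

```python
eta = 1e-2  # hypothetical step size
for _ in range(100):
    # Each rank receives the fully reduced gradient, so the iterates
    # stay identical across ranks without extra synchronization.
    x = x - eta * df.grad(x, y, physics)
```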

Examples using DistributedDataFidelity:

Distributed Plug-and-Play (PnP) Reconstruction