DistributedDataFidelity#
- class deepinv.distributed.DistributedDataFidelity(ctx, data_fidelity, num_operators=None, *, factory_kwargs=None, reduction='sum')[source]#
Bases: object
Distributed data fidelity term for use with distributed physics operators.
This class wraps a standard DataFidelity object and makes it compatible with DistributedStackedLinearPhysics by implementing efficient distributed computation patterns. It computes data fidelity terms and gradients using local operations followed by a single reduction, avoiding redundant communication.
The key operations are:
- fn(x, y, physics): Computes the data fidelity \(\sum_i d(A_i(x), y_i)\)
- grad(x, y, physics): Computes the gradient \(\sum_i A_i^T \nabla d(A_i(x), y_i)\)
Both operations use an efficient pattern (sketched below):
- Compute local forward operations (A_local)
- Apply the distance function and compute gradients locally
- Perform a single reduction across ranks
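This pattern can be illustrated with plain torch.distributed primitives. The following is a minimal sketch of the idea, not the library's implementation; the function name and arguments here are hypothetical:

```python
import torch
import torch.distributed as dist

def distributed_fidelity_sketch(x, y_local, A_local, d):
    # Each rank evaluates the distance only for its own operators...
    local = torch.stack([d(A(x), y) for A, y in zip(A_local, y_local)]).sum()
    # ...then a single all-reduce sums the partial results across ranks,
    # so communication happens exactly once per evaluation.
    dist.all_reduce(local, op=dist.ReduceOp.SUM)
    return local
```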
- Parameters:
ctx (DistributedContext) – distributed context manager.
data_fidelity (DataFidelity | Callable) – either a DataFidelity instance or a factory function that creates a DataFidelity instance for each operator. The factory should have the signature factory(index: int, device: torch.device, factory_kwargs: dict | None) -> DataFidelity.
num_operators (int | None) – number of operators (required if data_fidelity is a factory). Default is None.
factory_kwargs (dict | None) – shared data dictionary passed to the factory function for all operators. Default is None.
reduction (str) – reduction mode matching the distributed physics. Options are 'sum' or 'mean'. Default is 'sum'.
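For example, a factory matching the documented signature might build an independent L2 term per operator. This is a sketch assuming deepinv's standard L2 data fidelity (dinv.optim.L2) and an already-constructed distributed context `ctx`:

```python
import torch
import deepinv as dinv

def l2_factory(index: int, device: torch.device, factory_kwargs: dict | None):
    # One L2 data fidelity per operator; index, device, and factory_kwargs
    # are available if operators need different configurations.
    return dinv.optim.L2()

# Assuming `ctx` is an existing DistributedContext:
# df = dinv.distributed.DistributedDataFidelity(
#     ctx, l2_factory, num_operators=4, reduction='sum'
# )
```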
- fn(x, y, physics, gather=True, *args, **kwargs)[source]#
Compute the distributed data fidelity term.
For distributed physics with operators \(\{A_i\}\) and measurements \(\{y_i\}\), computes:
\[f(x) = \sum_i d(A_i(x), y_i)\]
This is computed efficiently by:
- Each rank computes \(A_i(x)\) for its local operators
- Each rank computes \(\sum_{i \in \text{local}} d(A_i(x), y_i)\)
- Results are reduced across all ranks
- Parameters:
x (torch.Tensor) – input signal at which to evaluate the data fidelity.
y (list[torch.Tensor]) – measurements (TensorList or list of tensors).
physics (DistributedStackedLinearPhysics) – distributed physics operator.
gather (bool) – whether to gather (reduce) results across ranks. Default is True.
args – additional positional arguments passed to the distance function.
kwargs – additional keyword arguments passed to the distance function.
- Returns:
scalar data fidelity value.
- Return type:
torch.Tensor
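A brief usage sketch, continuing from the factory example above (`df`, `x`, `y`, and `physics` are assumed to have been set up there or elsewhere):

```python
# Reduced fidelity value, identical on every rank:
loss = df.fn(x, y, physics)
# Rank-local partial sum, useful when the reduction is deferred:
local_loss = df.fn(x, y, physics, gather=False)
```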
- grad(x, y, physics, gather=True, *args, **kwargs)[source]#
Compute the gradient of the distributed data fidelity term.
For distributed physics with operators \(\{A_i\}\) and measurements \(\{y_i\}\), computes:
\[\nabla_x f(x) = \sum_i A_i^T \nabla d(A_i(x), y_i)\]
This is computed efficiently by:
- Each rank computes \(A_i(x)\) for its local operators
- Each rank computes \(\nabla d(A_i(x), y_i)\) for its local operators
- Each rank computes \(\sum_{i \in \text{local}} A_i^T \nabla d(A_i(x), y_i)\) using A_vjp_local
- Results are reduced across all ranks
- Parameters:
x (torch.Tensor) – input signal at which to compute the gradient.
y (list[torch.Tensor]) – measurements (TensorList or list of tensors).
physics (DistributedStackedLinearPhysics) – distributed physics operator.
gather (bool) – whether to gather (reduce) results across ranks. Default is True.
args – additional positional arguments passed to the distance function gradient.
kwargs – additional keyword arguments passed to the distance function gradient.
- Returns:
gradient with same shape as x.
- Return type:
torch.Tensor
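As a usage sketch, grad drops directly into a plain gradient descent loop. Here `df`, `y`, and `physics` are assumed from the examples above, `x0` is an assumed initial guess with the signal's shape, and the step size is illustrative:

```python
x = x0.clone()  # x0: assumed initial guess
step = 0.1      # illustrative step size
for _ in range(100):
    # Each iteration: local forwards + local VJPs + one reduction.
    x = x - step * df.grad(x, y, physics)
```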