DistributedStackedLinearPhysics
- class deepinv.distributed.DistributedStackedLinearPhysics(ctx, num_operators, factory, *, factory_kwargs=None, reduction='sum', dtype=None, gather_strategy='concatenated', **kwargs)[source]
Bases: DistributedStackedPhysics, LinearPhysics

Distributed linear physics operators.

This class extends DistributedStackedPhysics for linear operators. It provides distributed operations that automatically handle communication and reductions.

Note

This class is intended to distribute a collection of linear operators (e.g., a deepinv.physics.StackedLinearPhysics or an explicit Python list of deepinv.physics.LinearPhysics objects) across ranks. It is not a mechanism to shard a single linear operator internally. If you have one linear physics operator that can naturally be split into multiple operators, you must perform that split yourself (build a stacked/list representation) and provide the resulting operators through the factory.

All linear operations (A_adjoint, A_vjp, etc.) support a reduce parameter:

- If reduce=True (default): the method computes the global result by performing a single all-reduce across all ranks.
- If reduce=False: the method computes only the local contribution from the operators owned by this rank, without any inter-rank communication. This is useful for deferring reductions in custom algorithms.
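The reduce semantics can be sketched numerically. This is a minimal stand-in, not the deepinv API: each toy "operator" is a scalar map \(A_i(x) = a_i x\) (so its adjoint is \(A_i^T y_i = a_i y_i\)), and the rank-to-operator assignment is simulated with a plain dictionary.

```python
# Toy sketch (not the deepinv API): reduce semantics for a stacked operator
# A = [A_1, ..., A_n] whose adjoint is A^T y = sum_i A_i^T y_i.
# Each operator here is a scalar map A_i(x) = a_i * x, so A_i^T y_i = a_i * y_i.

coeffs = [2.0, 3.0, 5.0]           # one coefficient per operator
y = [1.0, 1.0, 1.0]                # one measurement per operator

# Pretend operators 0-1 live on rank 0 and operator 2 lives on rank 1.
ownership = {0: [0, 1], 1: [2]}

def local_adjoint(rank):
    """reduce=False: only this rank's contribution, no communication."""
    return sum(coeffs[i] * y[i] for i in ownership[rank])

def global_adjoint():
    """reduce=True: all-reduce (sum) of the local contributions."""
    return sum(local_adjoint(r) for r in ownership)

print(local_adjoint(0))   # 5.0  (2*1 + 3*1)
print(local_adjoint(1))   # 5.0  (5*1)
print(global_adjoint())   # 10.0 (full A^T y)
```

In the real class, the two `local_adjoint` calls would execute concurrently on their own ranks and the final sum would be a single all-reduce.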
- Parameters:
  - ctx (DistributedContext) – distributed context manager.
  - num_operators (int) – total number of physics operators to distribute.
  - factory (Callable) – factory function that creates linear physics operators. Should have signature factory(index: int, device: torch.device, factory_kwargs: dict | None) -> LinearPhysics.
  - factory_kwargs (dict | None) – shared data dictionary passed to the factory function for all operators. Default is None.
  - reduction (str) – reduction mode for distributed operations. Options are 'sum' (stack operators) or 'mean' (average operators). Default is 'sum'.
  - dtype (torch.dtype | None) – data type for operations. Default is None.
  - gather_strategy (str) – strategy for gathering distributed results in forward operations. Options are 'naive', 'concatenated', or 'broadcast'. Default is 'concatenated'.
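The factory contract above can be illustrated with a toy stand-in. `ToyScaling` below is hypothetical; a real factory would return deepinv.physics.LinearPhysics instances and place them on `device`.

```python
# Sketch of the factory pattern. ToyScaling is a hypothetical stand-in for a
# LinearPhysics operator implementing A(x) = scale * x.

class ToyScaling:
    """Stand-in for a LinearPhysics operator: A(x) = scale * x."""
    def __init__(self, scale):
        self.scale = scale
    def A(self, x):
        return self.scale * x
    def A_adjoint(self, y):
        return self.scale * y   # the adjoint of a real scalar map is itself

def factory(index, device, factory_kwargs):
    # `index` selects which of the num_operators operators to build;
    # shared settings travel in factory_kwargs (the same dict for every index).
    scales = factory_kwargs["scales"]
    return ToyScaling(scales[index])

ops = [factory(i, device=None, factory_kwargs={"scales": [1.0, 2.0, 4.0]})
       for i in range(3)]
print([op.A(10.0) for op in ops])   # [10.0, 20.0, 40.0]
```

In the distributed class, each rank calls the factory only for the indices it owns, so operators are constructed directly on their owning rank's device.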
- A_A_adjoint(y, gather=True, reduce_op=None, **kwargs)[source]
Compute the global \(A A^T\) operation with automatic reduction.
For stacked operators, this computes \(A A^T y\), where \(A^T y = \sum_i A_i^T y_i\), and then applies the forward operator to obtain \([A_1(A^T y), A_2(A^T y), \ldots, A_n(A^T y)]\).
Note
Unlike other operations, the adjoint step \(A^T y\) is always computed globally (with full reduction across ranks), even when gather=False, because computing the correct A_A_adjoint requires the full adjoint \(\sum_i A_i^T y_i\). The gather parameter only controls whether the final forward operation A(...) is gathered across ranks.
- Parameters:
  - y (TensorList | list[torch.Tensor]) – full list of measurements from all operators.
  - gather (bool) – whether to gather the final results across ranks. If False, returns only the local operators' contributions (but still uses the global adjoint). Default is True.
  - reduce_op (str | None) – reduction operation to apply across ranks for the final forward step. Default is None.
  - kwargs – optional parameters for the operation.
- Returns:
  TensorList with entries \(A_i A^T y\) for all operators (or a local list if gather=False).
- Return type:
  TensorList | list[Tensor]
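The two-step structure of \(A A^T\) can be checked with the same toy scalar operators \(A_i(x) = a_i x\) (a numerical sketch, not the deepinv API): first a global adjoint, then a per-operator forward pass.

```python
# Sketch of A_A_adjoint for toy scalar operators A_i(x) = a_i * x:
# step 1: global adjoint z = A^T y = sum_i a_i * y_i (always fully reduced);
# step 2: per-operator forward [A_1(z), ..., A_n(z)] = [a_1 * z, ..., a_n * z].

coeffs = [1.0, 2.0, 3.0]
y = [1.0, 1.0, 1.0]

z = sum(a * yi for a, yi in zip(coeffs, y))   # global adjoint A^T y = 6.0
result = [a * z for a in coeffs]              # entries A_i (A^T y)
print(result)   # [6.0, 12.0, 18.0]
```

With gather=False, a rank would return only its own slice of `result`, but `z` would still be the fully reduced adjoint, matching the note above.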
- A_adjoint(y, gather=True, reduce_op=None, **kwargs)[source]
Compute the global adjoint operation with automatic reduction.
Extracts local measurements, computes local adjoint contributions, and reduces across all ranks to obtain the complete \(A^T y\), where \(A = [A_1, A_2, \ldots, A_n]\) is the stacked operator and the \(A_i\) are the individual linear operators.
- Parameters:
  - y (TensorList | list[torch.Tensor]) – full list of measurements from all operators.
  - gather (bool) – whether to gather results across ranks. If False, returns the local contribution. Default is True.
  - reduce_op (str | None) – reduction operation to apply across ranks. If None, uses the class default. Default is None.
  - kwargs – optional parameters for the adjoint operation.
- Returns:
  complete adjoint result \(A^T y\) (or local contribution if gather=False).
- Return type:
- A_adjoint_A(x, gather=True, reduce_op=None, **kwargs)[source]
Compute the global \(A^T A\) operation with automatic reduction.
Computes the complete normal operator \(A^T A x = \sum_i A_i^T A_i x\) by combining local contributions from all ranks.
- Parameters:
  - x (torch.Tensor) – input tensor.
  - gather (bool) – whether to gather results across ranks. If False, returns the local contribution. Default is True.
  - reduce_op (str | None) – reduction operation to apply across ranks. If None, uses the class default. Default is None.
  - kwargs – optional parameters for the operation.
- Returns:
  complete \(A^T A x\) result (or local contribution if gather=False).
- Return type:
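The decomposition \(A^T A x = \sum_i A_i^T A_i x\) is easy to verify with toy scalar operators \(A_i(x) = a_i x\) (a numerical sketch, not the deepinv API), where each term is \(a_i^2 x\).

```python
# Sketch of the normal operator for toy scalar operators A_i(x) = a_i * x:
# A^T A x = sum_i A_i^T (A_i x) = sum_i a_i^2 * x, assembled from local terms
# exactly as the distributed reduction assembles per-rank contributions.

coeffs = [2.0, 3.0]
x = 1.5

local_terms = [a * (a * x) for a in coeffs]   # A_i^T (A_i x) per operator
print(sum(local_terms))   # 19.5 == (4 + 9) * 1.5
```

Each rank computes the `local_terms` for its own operators; the final sum is the single all-reduce performed when gather=True.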
- A_dagger(y, solver='CG', max_iter=None, tol=None, verbose=False, *, local_only=True, gather=True, **kwargs)[source]
Distributed pseudoinverse computation. This method provides two strategies:
1. Local approximation (local_only=True, default): each rank computes the pseudoinverse of its local operators independently, then the results are averaged with a single reduction. This is efficient (minimal communication) but provides only an approximation. In other words, for stacked operators this computes
\[A^\dagger y = \frac{1}{n} \sum_i A_i^\dagger y_i\]
2. Global computation (local_only=False): uses the full least squares solver with the distributed A_adjoint_A() and A_A_adjoint() operations. This computes the exact pseudoinverse but requires communication at every iteration.
- Parameters:
  - y (TensorList | list[torch.Tensor]) – measurements to invert.
  - solver (str) – least squares solver to use (only for local_only=False). Choose between 'CG', 'lsqr', 'BiCGStab' and 'minres'. Default is 'CG'.
  - max_iter (int | None) – maximum number of iterations for the least squares solver. Default is None.
  - tol (float | None) – relative tolerance for the least squares solver. Default is None.
  - verbose (bool) – print information (only on rank 0). Default is False.
  - local_only (bool) – if True (default), compute local daggers and sum-reduce (efficient). If False, compute the exact global pseudoinverse with full communication (expensive). Default is True.
  - gather (bool) – whether to gather results across ranks (only applies if local_only=True). Default is True.
  - kwargs – optional parameters for the forward operator.
- Returns:
  pseudoinverse solution. If local_only=True, returns an approximation. If local_only=False, returns the exact least squares solution.
- Return type:
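The gap between the two strategies can be shown with toy scalar operators \(A_i(x) = a_i x\) (a numerical sketch, not the deepinv API). For scalars, \(A_i^\dagger y_i = y_i / a_i\), and the exact least squares solution of the stacked system is \(x^* = (\sum_i a_i y_i) / (\sum_i a_i^2)\).

```python
# Sketch contrasting the two A_dagger strategies on toy scalar operators
# A_i(x) = a_i * x with measurements y_i.

coeffs = [1.0, 2.0]
y = [1.0, 4.0]

# local_only=True: average of per-operator pseudoinverses, (1/n) sum_i y_i / a_i
approx = sum(yi / a for a, yi in zip(coeffs, y)) / len(coeffs)

# local_only=False: exact least squares argmin_x sum_i (a_i x - y_i)^2,
# i.e. x* = (sum_i a_i y_i) / (sum_i a_i^2)
exact = sum(a * yi for a, yi in zip(coeffs, y)) / sum(a * a for a in coeffs)

print(approx)   # 1.5
print(exact)    # 1.8  (= 9/5)
```

The two results coincide only in special cases (e.g., consistent measurements or identical operators); in general the cheap local strategy is a genuine approximation, as the docstring states.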
- A_vjp(x, v, gather=True, reduce_op=None, **kwargs)[source]
Compute the global vector-Jacobian product with automatic reduction.
Extracts local cotangent vectors, computes local VJP contributions, and reduces across all ranks to obtain the complete VJP.
- Parameters:
  - x (torch.Tensor) – input tensor.
  - v (TensorList | list[torch.Tensor]) – full list of cotangent vectors from all operators.
  - gather (bool) – whether to gather results across ranks. If False, returns the local contribution. Default is True.
  - reduce_op (str | None) – reduction operation to apply across ranks. If None, uses the class default. Default is None.
  - kwargs – optional parameters for the VJP operation.
- Returns:
  complete VJP result (or local contribution if gather=False).
- Return type:
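For a linear operator, the vector-Jacobian product reduces to the adjoint applied to the cotangents, independently of the linearization point \(x\). A toy scalar sketch (not the deepinv API):

```python
# Sketch: for a linear operator, A_vjp(x, v) = A^T v = sum_i a_i * v_i,
# independent of x, using toy scalar operators A_i(x) = a_i * x.

coeffs = [1.0, 2.0, 3.0]
v = [0.5, 0.5, 0.5]        # one cotangent per operator

vjp = sum(a * vi for a, vi in zip(coeffs, v))
print(vjp)   # 3.0
```

This is why the distributed VJP uses the same extract-local / reduce-across-ranks pattern as A_adjoint.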
- compute_sqnorm(x0, *, max_iter=50, tol=1e-3, verbose=True, local_only=True, gather=True, **kwargs)[source]
Computes the squared spectral \(\ell_2\) norm of the distributed operator.
This method provides two strategies:
1. Local approximation (local_only=True, default): each rank computes the norm of its local operators independently, then a single max-reduction provides a conservative estimate. This is efficient (minimal communication). For stacked operators \(A = [A_1; A_2; \ldots; A_n]\), we have \(\|A\|^2 \leq \sum_i \|A_i\|^2\), and \(\max_i \|A_i\|^2\) is used as the conservative estimate.
2. Global computation (local_only=False): uses the full distributed A_adjoint_A() with communication at every power iteration. This computes the exact norm but is communication-intensive.
- Parameters:
  - x0 (torch.Tensor) – an unbatched tensor that shares its shape, dtype, and device with the initial iterate.
  - max_iter (int) – maximum number of iterations for the power method. Default is 50.
  - tol (float) – relative variation criterion for convergence. Default is 1e-3.
  - verbose (bool) – print information (only on rank 0). Default is True.
  - local_only (bool) – if True (default), compute local norms and max-reduce (efficient). If False, compute the exact global norm with full communication (expensive). Default is True.
  - gather (bool) – whether to gather results across ranks (only applies if local_only=True). Default is True.
  - kwargs – optional parameters for the forward operator.
- Returns:
  squared spectral norm. If local_only=True, returns the estimate. If local_only=False, returns the exact value.
- Return type:
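The global strategy is a power iteration on the normal operator. A minimal sketch (stdlib only, not the deepinv implementation) on a tiny 2x2 matrix; in the distributed class, each application of the normal operator would be a distributed A_adjoint_A call instead of the local matvecs below.

```python
import math

# Power-method sketch of the global strategy: estimate the squared spectral
# norm ||A||^2, i.e. the largest eigenvalue of A^T A, for a tiny 2x2 matrix.

A = [[3.0, 0.0],
     [0.0, 1.0]]

def matvec(M, x):
    return [sum(m * xi for m, xi in zip(row, x)) for row in M]

def transpose(M):
    return [list(col) for col in zip(*M)]

def power_sqnorm(A, x0, max_iter=50, tol=1e-6):
    At = transpose(A)
    x = x0[:]
    lam_prev = 0.0
    for _ in range(max_iter):
        z = matvec(At, matvec(A, x))                 # normal operator A^T A x
        lam = math.sqrt(sum(zi * zi for zi in z))    # eigenvalue estimate ||z||
        x = [zi / lam for zi in z]                   # renormalize the iterate
        if abs(lam - lam_prev) < tol * lam:          # relative variation criterion
            break
        lam_prev = lam
    return lam                                       # ~ ||A||^2

print(power_sqnorm(A, [1.0, 1.0]))   # converges to 9.0 (= 3.0 ** 2)
```

Here `x0` plays the same role as in the docstring: it seeds the initial iterate, and the relative-variation stopping rule mirrors the `tol` parameter.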