CompressedSensing
- class deepinv.physics.CompressedSensing(m, img_shape, fast=False, channelwise=False, dtype=torch.float32, device='cpu', rng: Generator | None = None, **kwargs)
Bases: LinearPhysics
Compressed Sensing forward operator. Creates a random sampling \(m \times n\) matrix, where \(n\) is the number of elements of the signal, i.e., np.prod(img_shape), and \(m\) is the number of measurements.

This class generates a random i.i.d. Gaussian matrix if fast=False

\[A_{i,j} \sim \mathcal{N}\left(0,\frac{1}{m}\right)\]

or a Subsampled Orthogonal with Random Signs (SORS) matrix if fast=True (see https://arxiv.org/abs/1506.03521)

\[A = \text{diag}(m)D\text{diag}(s)\]

where \(s\in\{-1,1\}^{n}\) is a random sign-flip vector whose entries are \(\pm 1\) with probability 0.5, \(D\in\mathbb{R}^{n\times n}\) is a fast orthogonal transform (DST-1), and \(\text{diag}(m)\in\mathbb{R}^{m\times n}\) is a random subsampling matrix that keeps \(m\) out of the \(n\) entries.
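The two constructions above can be sketched in plain NumPy. This is an illustrative sketch, not deepinv's implementation: the helper names `gaussian_matrix`, `dst1_matrix` and `sors_matrix` are hypothetical, and the DST-I is built here as an explicit dense matrix for clarity, whereas a fast implementation would apply \(D\) with an \(O(n \log n)\) transform instead of storing it.

```python
import numpy as np


def gaussian_matrix(m, n, rng):
    # i.i.d. entries A_ij ~ N(0, 1/m), i.e. standard deviation 1/sqrt(m)
    return rng.standard_normal((m, n)) / np.sqrt(m)


def dst1_matrix(n):
    # Orthonormal DST-I: D[k, j] = sqrt(2/(n+1)) * sin(pi*(k+1)*(j+1)/(n+1))
    k = np.arange(1, n + 1)
    return np.sqrt(2.0 / (n + 1)) * np.sin(np.pi * np.outer(k, k) / (n + 1))


def sors_matrix(m, n, rng):
    # s: random signs in {-1, 1}, each with probability 0.5
    s = rng.choice([-1.0, 1.0], size=n)
    # keep m of the n rows of the transform, chosen at random without replacement
    rows = rng.choice(n, size=m, replace=False)
    # A = diag(m) D diag(s): scale the columns of D by s, then subsample its rows
    return (dst1_matrix(n) * s)[rows]


rng = np.random.default_rng(0)
m, n = 4, 9  # e.g. a 1 x 3 x 3 image has n = 9 elements
A_gauss = gaussian_matrix(m, n, rng)
A_sors = sors_matrix(m, n, rng)
print(A_gauss.shape, A_sors.shape)  # (4, 9) (4, 9)
# The SORS rows are orthonormal, so A A^T is the m x m identity
print(np.allclose(A_sors @ A_sors.T, np.eye(m)))  # True
```

Because \(D\) is orthogonal and the sign flips only change column signs, any subset of rows of \(D\text{diag}(s)\) is orthonormal, which is why the check on \(AA^{\top}\) succeeds.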
It is recommended to use fast=True for image sizes bigger than 32 x 32, since the forward computation with fast=False has an \(O(mn)\) complexity, whereas with fast=True it has an \(O(n \log n)\) complexity.

An existing operator can be loaded from a saved .pth file via self.load_state_dict(torch.load(save_path)), in a similar fashion to torch.nn.Module.

Note
If fast=False, the squared norm of the forward operator (the largest eigenvalue of \(A^{\top}A\)) tends to \((1+\sqrt{n/m})^2\) for large \(n\) and \(m\), due to the Marcenko-Pastur law. If fast=True, the forward operator has unit norm.

If dtype=torch.cfloat, the forward operator is generated as a random i.i.d. complex Gaussian matrix, to be used with fast=False:

\[A_{i,j} \sim \mathcal{N}\left(0, \frac{1}{2m}\right) + \mathrm{i}\,\mathcal{N}\left(0, \frac{1}{2m}\right).\]

- Parameters:
m (int) – number of measurements.
img_shape (tuple) – shape (C, H, W) of inputs.
fast (bool) – if False, the operator is an i.i.d. Gaussian matrix; otherwise, A is a SORS matrix built on the Discrete Sine Transform (type I).
channelwise (bool) – if True, channels are processed independently using the same random forward operator.
dtype (torch.dtype) – data type of the stored forward matrix. For complex matrices, use torch.cfloat. Default is torch.float32.
device (str) – Device to store the forward matrix.
rng (torch.Generator, optional) – a pseudorandom number generator used to generate the operator's random parameters. If None, the default PyTorch generator is used.
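To put the recommendation to use fast=True for images larger than 32 x 32 into rough numbers, the back-of-the-envelope sketch below compares the \(mn\) multiply-adds of a dense matrix-vector product with the \(n \log_2 n\) operations of a fast transform. Constant factors are ignored, so only orders of magnitude are meaningful, and the choice \(m = n/4\) is an arbitrary assumption for illustration.

```python
import math


def dense_cost(m, n):
    # A dense m x n matrix-vector product needs m*n multiply-adds
    return m * n


def fast_cost(n):
    # A fast transform needs on the order of n*log2(n) operations (constants ignored)
    return n * math.log2(n)


for side in (16, 32, 64, 128):
    n = side * side  # single-channel side x side image
    m = n // 4       # arbitrary choice: 25% of n measurements
    ratio = dense_cost(m, n) / fast_cost(n)
    print(f"{side}x{side}: n={n}, dense/fast ~ {ratio:.0f}")
```

Under these assumptions the ratio equals \(m / \log_2 n\): about 26 at 32 x 32, growing to about 293 at 128 x 128.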
- Examples:
Compressed sensing operator with 10 measurements for a 3x3 image:
>>> import torch
>>> from deepinv.physics import CompressedSensing
>>> seed = torch.manual_seed(0)  # Random seed for reproducibility
>>> x = torch.randn(1, 1, 3, 3)  # Define random 3x3 image
>>> physics = CompressedSensing(m=10, img_shape=(1, 3, 3), rng=torch.Generator('cpu'))
>>> physics(x)
tensor([[-1.7769, 0.6160, -0.8181, -0.5282, -1.2197, 0.9332, -0.1668, 1.5779, 0.6752, -1.5684]])
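The norm statements in the note above can be checked numerically with a small NumPy sketch (not deepinv code). Here the norm is measured as the largest eigenvalue of \(A^{\top}A\), i.e. the squared spectral norm, which is the quantity the Marcenko-Pastur limit \((1+\sqrt{n/m})^2\) describes; the sizes are arbitrary, and agreement is only approximate at finite \(n\) and \(m\). For the unit-norm case, an orthonormal-row matrix is built from a QR factorization rather than a DST, purely for brevity.

```python
import numpy as np

rng = np.random.default_rng(0)
m, n = 200, 800  # n/m = 4, so the limit is (1 + 2)^2 = 9

# Dense Gaussian operator with A_ij ~ N(0, 1/m)
A = rng.standard_normal((m, n)) / np.sqrt(m)
# Largest eigenvalue of A^T A equals the squared largest singular value of A
sq_norm = np.linalg.norm(A, ord=2) ** 2
mp_limit = (1 + np.sqrt(n / m)) ** 2
print(f"empirical {sq_norm:.2f} vs limit {mp_limit:.2f}")

# Any matrix with orthonormal rows (as in the SORS case) has unit spectral norm;
# here the rows are taken from a random orthogonal Q instead of a DST
Q, _ = np.linalg.qr(rng.standard_normal((n, n)))
A_orth = Q[:m]
print(np.isclose(np.linalg.norm(A_orth, ord=2), 1.0))  # True
```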
- A(x, **kwargs)
Computes the forward operator \(y = A(x)\) (without noise or sensor non-linearities).
- Parameters:
x (torch.Tensor, list[torch.Tensor]) – signal/image
- Returns:
(torch.Tensor) clean measurements
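For the dense case, applying \(A\) amounts to flattening each image and multiplying by the matrix. The NumPy sketch below uses a hypothetical helper `forward` (not deepinv's API) and assumes, as the `channelwise` parameter description suggests, that with channelwise=True each of the C channels (\(n = HW\)) is measured with the same matrix, while otherwise the whole C x H x W image (\(n = CHW\)) is measured at once.

```python
import numpy as np


def forward(x, A, channelwise=False):
    # x: batch of images with shape (B, C, H, W); A: dense matrix of shape (m, n)
    B, C = x.shape[:2]
    if channelwise:
        # each channel is flattened separately and shares the same A (n = H*W)
        return x.reshape(B, C, -1) @ A.T  # shape (B, C, m)
    # the whole image is flattened (n = C*H*W)
    return x.reshape(B, -1) @ A.T  # shape (B, m)


rng = np.random.default_rng(0)
x = rng.standard_normal((2, 3, 4, 4))
m = 10
A_full = rng.standard_normal((m, 3 * 4 * 4)) / np.sqrt(m)
A_chan = rng.standard_normal((m, 4 * 4)) / np.sqrt(m)
print(forward(x, A_full).shape)                    # (2, 10)
print(forward(x, A_chan, channelwise=True).shape)  # (2, 3, 10)
```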
- A_adjoint(y, **kwargs)
Computes the transpose (adjoint) of the forward operator, \(\tilde{x} = A^{\top}y\). If \(A\) is linear, this should be the exact transpose of the forward matrix.
Note
If the problem is non-linear, there is no well-defined transpose operation, but defining one can be useful for some reconstruction networks, such as deepinv.models.ArtifactRemoval.
- Parameters:
y (torch.Tensor) – measurements.
params (None, torch.Tensor) – optional additional parameters for the adjoint operator.
- Returns:
(torch.Tensor) linear reconstruction \(\tilde{x} = A^{\top}y\).
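A standard way to check that an adjoint is exact is the dot-product test \(\langle Ax, y\rangle = \langle x, A^{H}y\rangle\). The NumPy sketch below applies it to a real Gaussian matrix and to a complex Gaussian matrix drawn as in the complex model of the note above; for complex matrices the adjoint is the conjugate transpose \(A^{H}\) rather than the plain transpose. This is a generic numerical check, not a deepinv utility.

```python
import numpy as np

rng = np.random.default_rng(0)
m, n = 6, 9


def inner(a, b):
    # inner product <a, b> = sum(conj(a) * b); np.vdot conjugates its first argument
    return np.vdot(a, b)


# Real Gaussian matrix, A_ij ~ N(0, 1/m): the adjoint is the transpose
A = rng.standard_normal((m, n)) / np.sqrt(m)
x, y = rng.standard_normal(n), rng.standard_normal(m)
print(np.isclose(inner(A @ x, y), inner(x, A.T @ y)))  # True

# Complex Gaussian matrix, real and imaginary parts ~ N(0, 1/(2m)):
# the adjoint is the conjugate transpose
Ac = (rng.standard_normal((m, n)) + 1j * rng.standard_normal((m, n))) / np.sqrt(2 * m)
xc = rng.standard_normal(n) + 1j * rng.standard_normal(n)
yc = rng.standard_normal(m) + 1j * rng.standard_normal(m)
print(np.isclose(inner(Ac @ xc, yc), inner(xc, Ac.conj().T @ yc)))  # True
```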
- A_dagger(y, **kwargs)
Computes the solution in \(x\) to \(y = Ax\) using the conjugate gradient method; see deepinv.optim.utils.conjugate_gradient(). If \(y\) has more entries than \(x\) (overdetermined problem), it computes \((A^{\top} A)^{-1} A^{\top} y\); otherwise (underdetermined problem), it computes \(A^{\top} (A A^{\top})^{-1} y\).
This function can be overridden by a more efficient pseudoinverse in cases where closed-form formulas exist.
- Parameters:
y (torch.Tensor) – a measurement \(y\) to reconstruct via the pseudoinverse.
- Returns:
(torch.Tensor) The reconstructed image \(x\).
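The underdetermined branch, \(x = A^{\top} (A A^{\top})^{-1} y\), can be sketched as a conjugate-gradient solve of the \(m \times m\) system \((A A^{\top}) z = y\) followed by \(x = A^{\top} z\). The `conjugate_gradient` function below is a minimal hand-written NumPy version, not deepinv.optim.utils.conjugate_gradient(); its tolerance and iteration cap are arbitrary choices, and the result is compared against np.linalg.pinv.

```python
import numpy as np


def conjugate_gradient(apply_M, b, tol=1e-10, max_iter=1000):
    # Solve M z = b for a symmetric positive definite M, given only the map z -> M z
    z = np.zeros_like(b)
    r = b - apply_M(z)
    p = r.copy()
    rs = r @ r
    for _ in range(max_iter):
        Mp = apply_M(p)
        alpha = rs / (p @ Mp)
        z = z + alpha * p
        r = r - alpha * Mp
        rs_new = r @ r
        if np.sqrt(rs_new) < tol:
            break
        p = r + (rs_new / rs) * p
        rs = rs_new
    return z


rng = np.random.default_rng(0)
m, n = 20, 50  # fewer measurements than unknowns: underdetermined
A = rng.standard_normal((m, n)) / np.sqrt(m)
y = rng.standard_normal(m)

# x = A^T (A A^T)^{-1} y, with (A A^T)^{-1} y computed by CG
z = conjugate_gradient(lambda v: A @ (A.T @ v), y)
x_dagger = A.T @ z
print(np.allclose(x_dagger, np.linalg.pinv(A) @ y))  # True
print(np.allclose(A @ x_dagger, y))                   # True
```

Solving in the smaller \(m\)-dimensional space and only then applying \(A^{\top}\) keeps every CG iteration to one product with \(A\) and one with \(A^{\top}\), without ever forming \(AA^{\top}\) explicitly.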
Examples using CompressedSensing:
A tour of forward sensing operators
Learned Iterative Soft-Thresholding Algorithm (LISTA) for compressed sensing
Learned iterative custom prior