RandomPhaseRetrieval
- class deepinv.physics.RandomPhaseRetrieval(m, img_shape, channelwise=False, dtype=torch.complex64, device='cpu', unitary=False, compute_inverse=False, rng: Generator | None = None, **kwargs)
Bases: PhaseRetrieval
Random Phase Retrieval forward operator. Creates a random \(m \times n\) sampling matrix \(B\) where \(n\) is the number of elements of the signal and \(m\) is the number of measurements.
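A minimal plain-PyTorch sketch of this measurement model, assuming that the base PhaseRetrieval class returns squared magnitudes \(y = |Bx|^2\) and that the entries of \(B\) follow the complex Gaussian distribution given in this section. With real and imaginary parts each of variance \(1/(2m)\), \(\mathbb{E}|B_{i,j}|^2 = 1/m\), so \(\mathbb{E}\|Bx\|^2 = \|x\|^2\) for a fixed signal \(x\). The helpers sample_gaussian_B and forward are hypothetical names used only for illustration; they are not part of deepinv.

```python
import torch


def sample_gaussian_B(m: int, n: int, generator: torch.Generator) -> torch.Tensor:
    """Hypothetical helper: m x n complex matrix, real and imaginary parts i.i.d. N(0, 1/(2m))."""
    std = (1.0 / (2 * m)) ** 0.5
    real = torch.randn(m, n, generator=generator) * std
    imag = torch.randn(m, n, generator=generator) * std
    return torch.complex(real, imag)  # complex64, since real/imag are float32


def forward(B: torch.Tensor, x: torch.Tensor) -> torch.Tensor:
    """Hypothetical helper: squared magnitudes |Bx|^2 for flattened signals x of shape (batch, n)."""
    # B.T is a plain (non-conjugate) transpose, so row k of the result is B @ x[k].
    return (x @ B.T).abs() ** 2


g = torch.Generator().manual_seed(0)
B = sample_gaussian_B(m=6, n=9, generator=g)
x = torch.randn(1, 9, dtype=torch.complex64, generator=g)
y = forward(B, x)  # shape (1, 6), real and non-negative

# The measurements are blind to a global phase: x and e^{i*theta} x give the same y.
phase = torch.polar(torch.tensor(1.0), torch.tensor(0.7))
print(torch.allclose(y, forward(B, x * phase), rtol=1e-4, atol=1e-6))  # True
```

The global-phase invariance shown at the end is why phase retrieval can only recover the signal up to a constant phase factor.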
This class generates a random i.i.d. Gaussian matrix
\[B_{i,j} \sim \mathcal{N} \left( 0, \frac{1}{2m} \right) + \mathrm{i} \mathcal{N} \left( 0, \frac{1}{2m} \right).\]
An existing operator can be loaded from a saved .pth file via self.load_state_dict(torch.load(save_path)), in a similar fashion to torch.nn.Module.
- Parameters:
m (int) – number of measurements.
img_shape (tuple) – shape (C, H, W) of inputs.
channelwise (bool) – Channels are processed independently using the same random forward operator. Default is False.
unitary (bool) – Use a random unitary matrix instead of a Gaussian matrix. Default is False.
compute_inverse (bool) – Compute the pseudo-inverse of the forward matrix. Default is False.
dtype (torch.dtype) – Data type of the stored forward matrix. Default is torch.cfloat (equivalently torch.complex64).
device (str) – Device on which to store the forward matrix. Default is 'cpu'.
rng (torch.Generator (Optional)) – A pseudorandom number generator used to generate the operator's parameters. If None, the default PyTorch generator is used.
- Examples:
Random phase retrieval operator with 6 measurements for a 3x3 image:
>>> import torch
>>> from deepinv.physics import RandomPhaseRetrieval
>>> seed = torch.manual_seed(0)  # Random seed for reproducibility
>>> x = torch.randn((1, 1, 3, 3), dtype=torch.cfloat)  # Random complex 3x3 image
>>> physics = RandomPhaseRetrieval(m=6, img_shape=(1, 3, 3), rng=torch.Generator('cpu'))
>>> physics(x)
tensor([[3.8405, 2.2588, 0.0146, 3.0864, 1.8075, 0.1518]])
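The channelwise flag changes what \(n\) refers to. The plain-PyTorch sketch below illustrates one plausible reading of the parameter description: with channelwise=True each channel of a (batch, C, H, W) input is flattened to \(H \cdot W\) elements and measured with the same \(m \times HW\) matrix, while with channelwise=False the whole \(C \cdot H \cdot W\) signal is measured by a single \(m \times CHW\) matrix. The function apply_phase_retrieval is hypothetical, and the output shapes are illustrative and may not match what deepinv returns.

```python
import torch


def apply_phase_retrieval(B: torch.Tensor, x: torch.Tensor, channelwise: bool) -> torch.Tensor:
    """Hypothetical helper: squared-magnitude measurements of a (batch, C, H, W) complex image."""
    batch, C, H, W = x.shape
    if channelwise:
        # Same m x (H*W) matrix for every channel -> output shape (batch, C, m).
        z = x.reshape(batch, C, H * W) @ B.T
    else:
        # One m x (C*H*W) matrix over the whole flattened image -> output shape (batch, m).
        z = x.reshape(batch, C * H * W) @ B.T
    return z.abs() ** 2


g = torch.Generator().manual_seed(0)
x = torch.randn(2, 3, 4, 4, dtype=torch.complex64, generator=g)  # batch of two 3-channel 4x4 images
# Complex randn has E|z|^2 = 1; dividing by sqrt(m) gives each part variance 1/(2m), with m = 10.
B_shared = torch.randn(10, 16, dtype=torch.complex64, generator=g) / 10 ** 0.5  # n = H*W = 16
B_full = torch.randn(10, 48, dtype=torch.complex64, generator=g) / 10 ** 0.5    # n = C*H*W = 48

print(apply_phase_retrieval(B_shared, x, channelwise=True).shape)  # torch.Size([2, 3, 10])
print(apply_phase_retrieval(B_full, x, channelwise=False).shape)   # torch.Size([2, 10])
```

Sharing one matrix across channels keeps the operator small (\(m \times HW\) instead of \(m \times CHW\)) at the cost of measuring each channel separately.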
Examples using RandomPhaseRetrieval:
Random phase retrieval and reconstruction methods.