RandomPhaseRetrieval

class deepinv.physics.RandomPhaseRetrieval(m, img_shape, channelwise=False, dtype=torch.complex64, device='cpu', unitary=False, compute_inverse=False, rng: Generator | None = None, **kwargs)

Bases: PhaseRetrieval

Random Phase Retrieval forward operator. Creates a random \(m \times n\) sampling matrix \(B\) where \(n\) is the number of elements of the signal and \(m\) is the number of measurements.

This class generates a random i.i.d. Gaussian matrix

\[B_{i,j} \sim \mathcal{N} \left( 0, \frac{1}{2m} \right) + \mathrm{i} \mathcal{N} \left( 0, \frac{1}{2m} \right).\]
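As a rough illustration of the sampling rule above, the following standard-library sketch draws such a matrix and applies it to a flattened signal. It does not use deepinv: the helper names `sample_gaussian_matrix` and `forward` are hypothetical, and the measurement model \(y = |Bx|^2\) is the usual phase retrieval convention, assumed here rather than taken from this page.

```python
import cmath
import math
import random


def sample_gaussian_matrix(m, n, seed=0):
    """Draw an m x n complex matrix with i.i.d. entries N(0, 1/(2m)) + i N(0, 1/(2m))."""
    rng = random.Random(seed)
    std = math.sqrt(1.0 / (2 * m))  # standard deviation of each real and imaginary part
    return [
        [complex(rng.gauss(0.0, std), rng.gauss(0.0, std)) for _ in range(n)]
        for _ in range(m)
    ]


def forward(B, x):
    """Assumed measurement model: y_i = |(B x)_i|^2, so the phase of B x is discarded."""
    return [abs(sum(b * x_j for b, x_j in zip(row, x))) ** 2 for row in B]


m, n = 6, 9  # 6 measurements of a flattened 3x3 single-channel image
B = sample_gaussian_matrix(m, n)
x = [complex(1.0, 0.0)] * n
y = forward(B, x)

# Multiplying x by a global phase factor leaves the measurements unchanged.
phase = cmath.exp(1j * 0.7)
y_shifted = forward(B, [phase * x_j for x_j in x])
```

With this scaling, each entry satisfies \(\mathbb{E}|B_{i,j}|^2 = 1/m\), since the real and imaginary parts each contribute variance \(1/(2m)\). Under the assumed model the measurements are real and non-negative, and `y` matches `y_shifted` up to floating-point error.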

An existing operator can be loaded from a saved .pth file via self.load_state_dict(save_path), in a similar fashion to torch.nn.Module.

Parameters:
  • m (int) – number of measurements.

  • img_shape (tuple) – shape (C, H, W) of inputs.

  • channelwise (bool) – If True, channels are processed independently using the same random forward operator. Default is False.

  • unitary (bool) – Use a random unitary matrix instead of Gaussian matrix. Default is False.

  • compute_inverse (bool) – Compute the pseudo-inverse of the forward matrix. Default is False.

  • dtype (torch.dtype) – Data type in which the forward matrix is stored. Default is torch.complex64 (alias torch.cfloat).

  • device (str) – Device on which the forward matrix is stored. Default is 'cpu'.

  • rng (torch.Generator, optional) – Pseudorandom number generator used to generate the operator's parameters. If None, the default PyTorch generator is used.


Examples:

Random phase retrieval operator with 6 measurements for a 3x3 image:

>>> import torch
>>> from deepinv.physics import RandomPhaseRetrieval
>>> seed = torch.manual_seed(0) # Random seed for reproducibility
>>> x = torch.randn((1, 1, 3, 3), dtype=torch.cfloat) # Define random complex 3x3 image
>>> physics = RandomPhaseRetrieval(m=6, img_shape=(1, 3, 3), rng=torch.Generator('cpu'))
>>> physics(x)
tensor([[3.8405, 2.2588, 0.0146, 3.0864, 1.8075, 0.1518]])

Examples using RandomPhaseRetrieval:

Random phase retrieval and reconstruction methods.
