RandomPhaseRetrieval

class deepinv.physics.RandomPhaseRetrieval(m, img_shape, channelwise=False, dtype=torch.complex64, device='cpu', rng: Generator | None = None, **kwargs)

Bases: PhaseRetrieval

Random Phase Retrieval forward operator. Creates a random \(m \times n\) sampling matrix \(B\) where \(n\) is the number of elements of the signal and \(m\) is the number of measurements.
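The NumPy sketch below illustrates the model. It flattens a (C, H, W) signal into n = C*H*W entries and applies an m x n matrix. It assumes, as in the PhaseRetrieval base class, that the measurements are the squared magnitudes |Bx|^2. `random_pr_forward` is a hypothetical name, not part of deepinv.

```python
import numpy as np

def random_pr_forward(B, x):
    # Hypothetical helper: flatten the (C, H, W) signal into a length-n vector,
    # apply B, and keep only squared magnitudes (the phase of each measurement is lost).
    return np.abs(B @ x.reshape(-1)) ** 2

rng = np.random.default_rng(0)
C, H, W, m = 1, 3, 3, 10
n = C * H * W  # number of elements of the signal
B = (rng.normal(size=(m, n)) + 1j * rng.normal(size=(m, n))) / np.sqrt(2 * m)
x = rng.normal(size=(C, H, W)) + 1j * rng.normal(size=(C, H, W))
y = random_pr_forward(B, x)

# Multiplying x by a global phase e^{i*theta} leaves the measurements unchanged.
y_shifted = random_pr_forward(B, x * np.exp(1j * 0.7))
```

The last line shows the inherent ambiguity of phase retrieval: x can only be recovered up to a global phase factor.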

This class generates a random i.i.d. complex Gaussian matrix whose entries satisfy

\[B_{i,j} \sim \mathcal{N} \left( 0, \frac{1}{2m} \right) + \mathrm{i} \mathcal{N} \left( 0, \frac{1}{2m} \right).\]
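The variance 1/(2m) per part gives E|B_{i,j}|^2 = 1/m. As a result, E||Bx||^2 = ||x||^2, so the measurements carry, on average, the same energy as the signal. The NumPy sketch below checks this numerically. `sample_random_B` is a hypothetical helper written for illustration, not deepinv's internal code.

```python
import numpy as np

def sample_random_B(m, n, seed=0):
    # Hypothetical helper (not deepinv's code): real and imaginary parts are i.i.d.
    # N(0, 1/(2m)), i.e. each has standard deviation sqrt(1/(2m)).
    rng = np.random.default_rng(seed)
    std = np.sqrt(1.0 / (2 * m))
    B = rng.normal(0.0, std, size=(m, n)) + 1j * rng.normal(0.0, std, size=(m, n))
    return B.astype(np.complex64)

m, n = 2000, 50
B = sample_random_B(m, n)

# E|B_ij|^2 = 1/(2m) + 1/(2m) = 1/m, so m * mean(|B_ij|^2) should be close to 1.
second_moment = m * np.mean(np.abs(B) ** 2)

# Consequently E||Bx||^2 = ||x||^2; for a unit-norm x the energy should be close to 1.
x = np.ones(n) / np.sqrt(n)
energy = np.sum(np.abs(B @ x) ** 2)
print(second_moment, energy)
```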

An existing operator can be loaded from a saved .pth file via self.load_state_dict(torch.load(save_path)), in the same way as for a torch.nn.Module.
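The sketch below shows the generic torch.nn.Module save/restore pattern: `torch.save` a `state_dict()`, then pass the result of `torch.load` to `load_state_dict`. It uses plain PyTorch. `TinyOperator` is a hypothetical stand-in that registers its random matrix as a buffer. Whether deepinv stores B as a buffer or as a parameter is an assumption not confirmed here.

```python
import os
import tempfile

import torch

class TinyOperator(torch.nn.Module):
    # Hypothetical stand-in for a physics operator: the random matrix is kept as a
    # buffer so it is included in state_dict() and therefore saved and restored.
    def __init__(self, m, n, generator=None):
        super().__init__()
        B = torch.randn(m, n, dtype=torch.complex64, generator=generator) / m ** 0.5
        self.register_buffer("B", B)

op = TinyOperator(10, 9, generator=torch.Generator().manual_seed(0))
path = os.path.join(tempfile.mkdtemp(), "operator.pth")
torch.save(op.state_dict(), path)

# A new instance draws a different matrix until the saved state is loaded into it.
restored = TinyOperator(10, 9, generator=torch.Generator().manual_seed(1))
restored.load_state_dict(torch.load(path))
print(torch.equal(op.B, restored.B))
```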

Parameters:
  • m (int) – number of measurements.

  • img_shape (tuple) – shape (C, H, W) of inputs.

  • channelwise (bool) – if True, channels are processed independently using the same random forward operator. Default is False.

  • dtype (torch.dtype) – data type used to store the forward matrix. Default is torch.complex64 (torch.cfloat).

  • device (str) – Device to store the forward matrix.

  • rng (torch.Generator (Optional)) – pseudorandom number generator used to generate the operator's parameters. If None, PyTorch's default generator is used.
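The channelwise option can be sketched in NumPy as follows. This is a sketch under stated assumptions. With channelwise=True, one m x (H*W) matrix is shared across channels, giving m measurements per channel. Otherwise, a single m x (C*H*W) matrix acts on the whole flattened signal. `apply_pr` is a hypothetical helper. deepinv's tensors also carry a batch dimension, and its internal layout may differ.

```python
import numpy as np

def apply_pr(B, x, channelwise):
    # Hypothetical helper: x has shape (C, H, W); returns squared magnitudes |Bx|^2.
    C = x.shape[0]
    if channelwise:
        # One m x (H*W) matrix shared by all channels: row c of the result is |B x_c|^2.
        return np.abs(x.reshape(C, -1) @ B.T) ** 2
    # Otherwise B is m x (C*H*W) and mixes all channels of the flattened signal.
    return np.abs(B @ x.reshape(-1)) ** 2

rng = np.random.default_rng(0)
C, H, W, m = 3, 4, 4, 20
x = rng.normal(size=(C, H, W)) + 1j * rng.normal(size=(C, H, W))
B_shared = (rng.normal(size=(m, H * W)) + 1j * rng.normal(size=(m, H * W))) / np.sqrt(2 * m)
B_full = (rng.normal(size=(m, C * H * W)) + 1j * rng.normal(size=(m, C * H * W))) / np.sqrt(2 * m)

y_cw = apply_pr(B_shared, x, channelwise=True)   # shape (C, m)
y_full = apply_pr(B_full, x, channelwise=False)  # shape (m,)
```

The channelwise variant needs a matrix only (H*W)/(C*H*W) = 1/C the size, at the cost of never mixing information across channels.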


Examples:

Random phase retrieval operator with 10 measurements for a 3x3 image:

>>> seed = torch.manual_seed(0) # Random seed for reproducibility
>>> x = torch.randn((1, 1, 3, 3), dtype=torch.cfloat) # Define random complex 3x3 image
>>> physics = RandomPhaseRetrieval(m=10, img_shape=(1, 3, 3), rng=torch.Generator('cpu'))
>>> physics(x)
tensor([[2.3043, 1.3553, 0.0087, 1.8518, 1.0845, 1.1561, 0.8668, 2.2031, 0.4542,
         0.0225]])
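In the example above, torch.manual_seed(0) fixes the random image x. The rng argument supplies the generator for the operator's own random draws. The plain-PyTorch sketch below, which does not use deepinv, shows why a dedicated, explicitly seeded generator yields a repeatable matrix regardless of the global seed. `draw_B` is a hypothetical helper.

```python
import torch

def draw_B(m, n, seed):
    # Hypothetical helper: a dedicated, explicitly seeded generator makes the draw
    # repeatable regardless of the global seed set by torch.manual_seed.
    g = torch.Generator("cpu").manual_seed(seed)
    # For complex dtypes, torch.randn draws real and imaginary parts with variance 1/2
    # each; dividing by sqrt(m) gives variance 1/(2m) per part, as in the formula for B_{i,j}.
    return torch.randn(m, n, dtype=torch.cfloat, generator=g) / m ** 0.5

torch.manual_seed(123)  # changing the global seed does not affect draw_B
B_a = draw_B(10, 9, seed=0)
torch.manual_seed(456)
B_b = draw_B(10, 9, seed=0)
B_c = draw_B(10, 9, seed=1)
print(torch.equal(B_a, B_b), torch.equal(B_a, B_c))
```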

Examples using RandomPhaseRetrieval:

Random phase retrieval and reconstruction methods.
