Random phase retrieval and reconstruction methods.

This example shows how to create a random phase retrieval operator and generate phaseless measurements from a given image. The example showcases 4 different reconstruction methods to recover the image from the phaseless measurements:

  1. Gradient descent with random initialization;

  2. Spectral methods;

  3. Gradient descent with spectral methods initialization;

  4. Gradient descent with PnP denoisers.

General setup

import deepinv as dinv
from pathlib import Path
import torch
import matplotlib.pyplot as plt
from deepinv.models import DRUNet
from deepinv.optim.data_fidelity import L2
from deepinv.optim.prior import PnP, Zero
from deepinv.optim.optimizers import optim_builder
from deepinv.utils.demo import load_url_image, get_image_url
from deepinv.utils.plotting import plot
from deepinv.optim.phase_retrieval import (
    correct_global_phase,
    cosine_similarity,
    spectral_methods,
)
from deepinv.models.complex import to_complex_denoiser

BASE_DIR = Path(".")
RESULTS_DIR = BASE_DIR / "results"
# Set global random seed to ensure reproducibility.
torch.manual_seed(0)

device = dinv.utils.get_freer_gpu() if torch.cuda.is_available() else "cpu"

Load image from the internet

We use the standard test image “Shepp–Logan phantom”.

# Image size
img_size = 32
url = get_image_url("SheppLogan.png")
# The pixel values of the image are in the range [0, 1].
x = load_url_image(
    url=url, img_size=img_size, grayscale=True, resize_mode="resize", device=device
)
print(x.min(), x.max())
tensor(0.) tensor(0.7412)

Visualization

We use the customized plot() function in deepinv to visualize the original image.

plot(x, titles="Original image")
Original image

Signal construction

We use the original image as the phase of a complex signal whose entries have unit modulus. Each pixel value x in [0, 1] is mapped to the phase pi * x - pi/2, which lies in [-pi/2, pi/2].

x_phase = torch.exp(1j * x * torch.pi - 0.5j * torch.pi)

# Every element of the signal should have unit modulus.
assert torch.allclose(x_phase.real**2 + x_phase.imag**2, torch.tensor(1.0))
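As a quick sanity check of this encoding, the NumPy sketch below maps values in [0, 1] to the phases pi * x - pi/2 and back with angle / pi + 0.5. This is the same normalization applied after every reconstruction further down. It stands in for the torch code above and is not part of deepinv.

```python
import numpy as np

def encode_phase(x):
    """Map real values in [0, 1] to unit-modulus complex numbers with phase in [-pi/2, pi/2]."""
    return np.exp(1j * (np.pi * x - np.pi / 2))

def decode_phase(z):
    """Invert encode_phase: np.angle returns values in (-pi, pi], and phases in
    [-pi/2, pi/2] are mapped back to [0, 1]."""
    return np.angle(z) / np.pi + 0.5

x = np.linspace(0.0, 1.0, 11)
z = encode_phase(x)
print(np.allclose(np.abs(z), 1.0))        # True: every entry has unit modulus
print(np.allclose(decode_phase(z), x))    # True: the round trip recovers x
```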

Measurements generation

Create a random phase retrieval operator with an oversampling ratio (measurements/pixels) of 5.0, and generate measurements from the signal. No noise model is passed to the operator below, so the measurements are noiseless.

# Define physics information
oversampling_ratio = 5.0
img_shape = x.shape[1:]
m = int(oversampling_ratio * torch.prod(torch.tensor(img_shape)))
n_channels = 1  # 3 for color images, 1 for gray-scale images

# Create the physics
physics = dinv.physics.RandomPhaseRetrieval(
    m=m,
    img_shape=img_shape,
    device=device,
)

# Generate measurements
y = physics(x_phase)
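The measurement model can be sketched in a few lines of NumPy. The sketch assumes that RandomPhaseRetrieval computes squared magnitudes y = |Bx|^2 of a random complex Gaussian matrix B with i.i.d. CN(0, 1) entries. The exact scaling deepinv uses may differ, and random_phase_retrieval_operator is a hypothetical helper. The sketch also checks the key ambiguity of phase retrieval: multiplying x by a global phase leaves y unchanged.

```python
import numpy as np

rng = np.random.default_rng(0)

def random_phase_retrieval_operator(m, n, rng):
    """Hypothetical helper: m x n complex Gaussian matrix with i.i.d. CN(0, 1) entries
    (the normalization is an assumption)."""
    return (rng.standard_normal((m, n)) + 1j * rng.standard_normal((m, n))) / np.sqrt(2)

def forward(A, z):
    """Phaseless measurements: squared magnitudes y = |A z|^2."""
    return np.abs(A @ z) ** 2

n = 16                        # number of unknowns (pixels)
m = int(5.0 * n)              # oversampling ratio of 5.0, as in the example
A = random_phase_retrieval_operator(m, n, rng)
x = np.exp(1j * rng.uniform(-np.pi / 2, np.pi / 2, n))   # unit-modulus signal

y = forward(A, x)
y_shifted = forward(A, np.exp(1j * 1.234) * x)           # same signal, global phase shift
print(y.shape)                      # (80,)
print(np.allclose(y, y_shifted))    # True: the global phase is invisible in y
```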

Reconstruction with gradient descent and random initialization

First, we use the class deepinv.optim.data_fidelity.L2 as the data fidelity term and the Zero prior (no regularization). The class deepinv.optim.optim_iterators.GDIteration runs plain gradient descent. The initial guess is a random complex signal.

data_fidelity = L2()
prior = Zero()
iterator = dinv.optim.optim_iterators.GDIteration()
# Parameters for the optimizer, including stepsize and regularization coefficient.
optim_params = {"stepsize": 0.06, "lambda": 1.0, "g_param": []}
num_iter = 1000

# Initial guess
x_phase_gd_rand = torch.randn_like(x_phase)

loss_hist = []

for _ in range(num_iter):
    res = iterator(
        {"est": (x_phase_gd_rand,), "cost": 0},
        cur_data_fidelity=data_fidelity,
        cur_prior=prior,
        cur_params=optim_params,
        y=y,
        physics=physics,
    )
    x_phase_gd_rand = res["est"][0]
    loss_hist.append(data_fidelity(x_phase_gd_rand, y, physics).cpu())

print("initial loss:", loss_hist[0])
print("final loss:", loss_hist[-1])
# Plot the loss curve
plt.plot(loss_hist)
plt.yscale("log")
plt.title("loss curve (gradient descent with random initialization)")
plt.show()
loss curve (gradient descent with random initialization)
initial loss: tensor([171.3826])
final loss: tensor([25.3845])
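To give some intuition for a single GDIteration step, here is a minimal NumPy sketch of gradient descent on the intensity least-squares loss (1/(2m)) * sum((|Az|^2 - y)^2). It uses the Wirtinger gradient (1/m) * A^H[(|Az|^2 - y) * Az] and starts from a random initial guess. Several choices are assumptions: the averaging over m, the step size 0.1 / ||z0||^2 (borrowed from the Wirtinger flow literature) and the problem sizes. deepinv's L2 data fidelity and its gradient computation may scale things differently. From a random start the loss decreases, but nothing guarantees that it reaches the global minimum.

```python
import numpy as np

rng = np.random.default_rng(0)
n, m = 4, 40
A = (rng.standard_normal((m, n)) + 1j * rng.standard_normal((m, n))) / np.sqrt(2)
x = np.exp(1j * rng.uniform(-np.pi / 2, np.pi / 2, n)) / np.sqrt(n)   # unit-norm signal
y = np.abs(A @ x) ** 2

def loss(z):
    """Intensity least squares: 1/(2m) * sum((|Az|^2 - y)^2)."""
    return 0.5 * np.mean((np.abs(A @ z) ** 2 - y) ** 2)

def wirtinger_grad(z):
    """Gradient with respect to conj(z): 1/m * A^H [(|Az|^2 - y) * Az]."""
    Az = A @ z
    return A.conj().T @ ((np.abs(Az) ** 2 - y) * Az) / m

z = (rng.standard_normal(n) + 1j * rng.standard_normal(n)) / np.sqrt(2)  # random init
step = 0.1 / np.linalg.norm(z) ** 2   # Wirtinger-flow-style step size (an assumption)
losses = [loss(z)]
for _ in range(1000):
    z = z - step * wirtinger_grad(z)
    losses.append(loss(z))
print(f"initial loss: {losses[0]:.4f}, final loss: {losses[-1]:.4f}")
```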

Phase correction and signal reconstruction

The solution x_est of the optimization algorithm may be any phase-shifted version of the original complex signal x_phase, i.e., x_est = a * x_phase where a is an arbitrary complex number with unit modulus. This ambiguity arises because the phaseless measurements do not change under such a shift. We therefore use the function deepinv.optim.phase_retrieval.correct_global_phase to remove the global phase shift of x_est with respect to x_phase. We then use torch.angle to extract the phase. torch.angle returns values in [-pi, pi], but the phase of the corrected signal lies in [-pi/2, pi/2], so dividing by pi and adding 0.5 maps it back to [0, 1]. The same post-processing is applied to every reconstruction method below.

# correct possible global phase shifts
x_gd_rand = correct_global_phase(x_phase_gd_rand, x_phase)
# extract phase information and normalize to the range [0, 1]
x_gd_rand = torch.angle(x_gd_rand) / torch.pi + 0.5

plot([x, x_gd_rand], titles=["Signal", "Reconstruction"], rescale_mode="clip")
Signal, Reconstruction
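The phase correction and the cosine similarity used at the end can both be written in terms of the inner product <z, x> = sum(conj(z) * x). The sketch below assumes that correct_global_phase picks the unit-modulus scalar e^{i theta} = <z, x> / |<z, x>|, which minimizes ||e^{i theta} z - x||. deepinv's implementation may differ in details, and both function names are hypothetical. The correction is undefined when <z, x> = 0.

```python
import numpy as np

def correct_global_phase_sketch(z, x_ref):
    """Rotate z by the unit-modulus scalar that best aligns it with x_ref."""
    inner = np.vdot(z, x_ref)            # sum(conj(z) * x_ref)
    return z * inner / np.abs(inner)     # e^{i theta} = <z, x_ref> / |<z, x_ref>|

def cosine_similarity_sketch(z, x_ref):
    """Phase-invariant cosine similarity |<z, x_ref>| / (||z|| ||x_ref||), in [0, 1]."""
    return np.abs(np.vdot(z, x_ref)) / (np.linalg.norm(z) * np.linalg.norm(x_ref))

rng = np.random.default_rng(0)
x_ref = np.exp(1j * rng.uniform(-np.pi / 2, np.pi / 2, 8))
z = np.exp(1j * 2.0) * x_ref                  # same signal with a global phase shift of 2 rad
print(np.allclose(correct_global_phase_sketch(z, x_ref), x_ref))   # True
print(np.isclose(cosine_similarity_sketch(z, x_ref), 1.0))         # True
```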

Reconstruction with spectral methods

Spectral methods (deepinv.optim.phase_retrieval.spectral_methods) provide a good initial estimate of the original signal. Moreover, deepinv.physics.RandomPhaseRetrieval uses spectral methods as its default reconstruction method A_dagger, which we can call directly.

# Spectral methods return a tensor with unit norm.
x_phase_spec = physics.A_dagger(y, n_iter=300)
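Spectral methods estimate the signal as the leading eigenvector of Y = (1/m) * sum_i y_i a_i a_i^H. When the rows a_i^H have i.i.d. CN(0, 1) entries, E[Y] = ||x||^2 I + x x^H, so the top eigenvector of E[Y] is aligned with x. The NumPy sketch below finds this eigenvector by power iteration and applies Y as A^H (y * (A z)) without forming the matrix. It illustrates the technique under these assumptions and is not deepinv's spectral_methods. spectral_init is a hypothetical name, and the large oversampling ratio is chosen so that the tiny example is well conditioned.

```python
import numpy as np

rng = np.random.default_rng(0)
n, m = 4, 4000
A = (rng.standard_normal((m, n)) + 1j * rng.standard_normal((m, n))) / np.sqrt(2)
x = np.exp(1j * rng.uniform(-np.pi / 2, np.pi / 2, n))
y = np.abs(A @ x) ** 2

def spectral_init(A, y, n_iter=300, rng=None):
    """Leading eigenvector of Y = (1/m) * A^H diag(y) A, computed by power iteration."""
    rng = np.random.default_rng() if rng is None else rng
    m, n = A.shape
    z = rng.standard_normal(n) + 1j * rng.standard_normal(n)
    z /= np.linalg.norm(z)
    for _ in range(n_iter):
        z = A.conj().T @ (y * (A @ z)) / m   # apply Y without forming it
        z /= np.linalg.norm(z)               # return a unit-norm estimate
    return z

z = spectral_init(A, y, rng=rng)
cos = np.abs(np.vdot(z, x)) / (np.linalg.norm(z) * np.linalg.norm(x))
print(f"cosine similarity: {cos:.3f}")
```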

Phase correction and signal reconstruction

# correct possible global phase shifts
x_spec = correct_global_phase(x_phase_spec, x_phase)
# extract phase information and normalize to the range [0, 1]
x_spec = torch.angle(x_spec) / torch.pi + 0.5
plot([x, x_spec], titles=["Signal", "Reconstruction"], rescale_mode="clip")
Signal, Reconstruction

Reconstruction with gradient descent and spectral methods initialization

The estimate from spectral methods can be directly used as the initial guess for the gradient descent algorithm.

# Initial guess from spectral methods
x_phase_gd_spec = physics.A_dagger(y, n_iter=300)

loss_hist = []
for _ in range(num_iter):
    res = iterator(
        {"est": (x_phase_gd_spec,), "cost": 0},
        cur_data_fidelity=data_fidelity,
        cur_prior=prior,
        cur_params=optim_params,
        y=y,
        physics=physics,
    )
    x_phase_gd_spec = res["est"][0]
    loss_hist.append(data_fidelity(x_phase_gd_spec, y, physics).cpu())

print("initial loss:", loss_hist[0])
print("final loss:", loss_hist[-1])
plt.plot(loss_hist)
plt.yscale("log")
plt.title("loss curve (gradient descent with spectral initialization)")
plt.show()
loss curve (gradient descent with spectral initialization)
initial loss: tensor([44.9017])
final loss: tensor([0.0049])

Phase correction and signal reconstruction

# correct possible global phase shifts
x_gd_spec = correct_global_phase(x_phase_gd_spec, x_phase)
# extract phase information and normalize to the range [0, 1]
x_gd_spec = torch.angle(x_gd_spec) / torch.pi + 0.5
plot([x, x_gd_spec], titles=["Signal", "Reconstruction"], rescale_mode="clip")
Signal, Reconstruction
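The two-stage pipeline of this section can be sketched end to end in NumPy. The steps are: power iteration for the spectral direction, rescaling by sqrt(mean(y)) (since E[y_i] = ||x||^2 under the CN(0, 1) assumption), and then Wirtinger gradient descent. The sizes, seed, step size and iteration counts are illustrative assumptions. Because the spectral estimate already lies close to the true signal, gradient descent typically converges to it up to a global phase, which the cosine similarity ignores.

```python
import numpy as np

rng = np.random.default_rng(1)
n, m = 4, 4000
A = (rng.standard_normal((m, n)) + 1j * rng.standard_normal((m, n))) / np.sqrt(2)
x = np.exp(1j * rng.uniform(-np.pi / 2, np.pi / 2, n)) / np.sqrt(n)   # ||x|| = 1
y = np.abs(A @ x) ** 2

def loss(z):
    """Intensity least squares: 1/(2m) * sum((|Az|^2 - y)^2)."""
    return 0.5 * np.mean((np.abs(A @ z) ** 2 - y) ** 2)

def wirtinger_grad(z):
    """Gradient with respect to conj(z): 1/m * A^H [(|Az|^2 - y) * Az]."""
    Az = A @ z
    return A.conj().T @ ((np.abs(Az) ** 2 - y) * Az) / m

# 1) Spectral initialization: leading eigenvector of (1/m) A^H diag(y) A ...
v = rng.standard_normal(n) + 1j * rng.standard_normal(n)
for _ in range(300):
    v = A.conj().T @ (y * (A @ v)) / m
    v /= np.linalg.norm(v)
# ... rescaled by sqrt(mean(y)), since E[y_i] = ||x||^2 for CN(0, 1) rows.
z = np.sqrt(np.mean(y)) * v

# 2) Gradient descent started from the spectral estimate.
step = 0.1 / np.linalg.norm(z) ** 2
initial_loss = loss(z)
for _ in range(1000):
    z = z - step * wirtinger_grad(z)
final_loss = loss(z)

cos = np.abs(np.vdot(z, x)) / (np.linalg.norm(z) * np.linalg.norm(x))
print(f"initial loss: {initial_loss:.2e}, final loss: {final_loss:.2e}, cosine similarity: {cos:.6f}")
```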

Reconstruction with gradient descent and PnP denoisers

We can also use the Plug-and-Play (PnP) framework to incorporate a denoiser as the regularizer of the optimization algorithm. Here the algorithm is proximal gradient descent (PGD), and a deep denoiser trained on a large dataset of natural images replaces the proximal step of the prior.
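A PnP proximal gradient iteration alternates a gradient step on the data fidelity with a denoising step that takes the place of the prior's proximal operator: z_{k+1} = D(z_k - stepsize * grad f(z_k)). The sketch below illustrates the idea with a projection onto [0, 1] as the "denoiser". pnp_pgd and the toy data term are hypothetical. Because a projection is the proximal operator of an indicator function, PnP-PGD reduces here to projected gradient descent. In the deepinv code that follows, g_param is, as we understand it, the noise level passed to the denoiser.

```python
import numpy as np

def pnp_pgd(grad_f, denoiser, z0, step, max_iter=100):
    """Plug-and-Play proximal gradient descent:
    z_{k+1} = D(z_k - step * grad_f(z_k)), where D replaces the prior's proximal operator."""
    z = z0
    for _ in range(max_iter):
        z = denoiser(z - step * grad_f(z))
    return z

# Toy data term f(z) = 0.5 * ||z - b||^2 (a stand-in for the phase retrieval loss).
b = np.array([-0.5, 0.2, 0.7, 1.3])
grad_f = lambda z: z - b

# "Denoiser" = projection onto [0, 1]. This is the prox of an indicator function,
# so PnP-PGD reduces to projected gradient descent here.
clip_denoiser = lambda z: np.clip(z, 0.0, 1.0)

z_hat = pnp_pgd(grad_f, clip_denoiser, z0=np.zeros(4), step=0.5)
print(z_hat)   # close to np.clip(b, 0, 1)
```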

# Load the pre-trained denoiser
denoiser = DRUNet(
    in_channels=n_channels,
    out_channels=n_channels,
    pretrained="download",  # automatically downloads the pretrained weights, set to a path to use custom weights.
    device=device,
)
# The original denoiser is designed for real-valued images, so we need to convert it to a complex-valued denoiser for phase retrieval problems.
denoiser_complex = to_complex_denoiser(denoiser, mode="abs_angle")

# Algorithm parameters
data_fidelity = L2()
prior = PnP(denoiser=denoiser_complex)
params_algo = {"stepsize": 0.30, "g_param": 0.04}
max_iter = 100
early_stop = True
verbose = True

# Instantiate the algorithm class to solve the inverse problem.
model = optim_builder(
    iteration="PGD",
    prior=prior,
    data_fidelity=data_fidelity,
    early_stop=early_stop,
    max_iter=max_iter,
    verbose=verbose,
    params_algo=params_algo,
)

# Run the algorithm
x_phase_pnp, metrics = model(y, physics, x_gt=x_phase, compute_metrics=True)
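The name mode="abs_angle" suggests that the real-valued denoiser is applied separately to the magnitude and to the phase of the complex estimate. The sketch below implements that reading. It is an assumption and not deepinv's to_complex_denoiser, and the moving-average "denoiser" is a hypothetical stand-in for DRUNet. Denoising the raw angle in this way ignores the wrap-around between -pi and pi.

```python
import numpy as np

def to_complex_denoiser_sketch(real_denoiser):
    """Wrap a real-valued denoiser for complex inputs by denoising |z| and angle(z)
    separately (assumed behaviour of mode="abs_angle"; deepinv may rescale differently)."""
    def complex_denoiser(z):
        magnitude = real_denoiser(np.abs(z))
        phase = real_denoiser(np.angle(z))
        return magnitude * np.exp(1j * phase)
    return complex_denoiser

def moving_average(v):
    """Toy real 'denoiser': 3-tap moving average with edge padding (hypothetical stand-in)."""
    padded = np.pad(v, 1, mode="edge")
    return (padded[:-2] + padded[1:-1] + padded[2:]) / 3.0

denoise = to_complex_denoiser_sketch(moving_average)
z = np.exp(1j * np.array([0.1, 0.3, -0.2, 0.0]))
print(np.abs(denoise(z)))   # all ones: the moving average of a constant magnitude is unchanged
```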

Phase correction and signal reconstruction

# correct possible global phase shifts
x_pnp = correct_global_phase(x_phase_pnp, x_phase)
# extract phase information and normalize to the range [0, 1]
x_pnp = torch.angle(x_pnp) / torch.pi + 0.5
plot([x, x_pnp], titles=["Signal", "Reconstruction"], rescale_mode="clip")
Signal, Reconstruction

Overall comparison

We visualize the original image and the reconstructions from the four methods. We also compute the PSNR (peak signal-to-noise ratio, higher is better) of every reconstruction. In addition, we compute the cosine similarity between each complex estimate and the original complex signal x_phase (in [0, 1], higher is better). Gradient descent with random initialization yields a poor reconstruction. Spectral methods provide a good initial estimate, which gradient descent then refines into the best reconstruction. The PnP framework with a deep denoiser as the prior also gives a very good reconstruction, as it exploits prior information about natural images.

imgs = [x, x_gd_rand, x_spec, x_gd_spec, x_pnp]
plot(
    imgs,
    titles=["Original", "GD random", "Spectral", "GD spectral", "PnP"],
    save_dir=RESULTS_DIR / "images",
    show=True,
    rescale_mode="clip",
)

# Compute metrics
print(
    f"GD Random reconstruction, PSNR: {dinv.metric.cal_psnr(x, x_gd_rand).item():.2f} dB; cosine similarity: {cosine_similarity(x_phase_gd_rand, x_phase).item():.3f}."
)
print(
    f"Spectral reconstruction, PSNR: {dinv.metric.cal_psnr(x, x_spec).item():.2f} dB; cosine similarity: {cosine_similarity(x_phase_spec, x_phase).item():.3f}."
)
print(
    f"GD Spectral reconstruction, PSNR: {dinv.metric.cal_psnr(x, x_gd_spec).item():.2f} dB; cosine similarity: {cosine_similarity(x_phase_gd_spec, x_phase).item():.3f}."
)
print(
    f"PnP reconstruction, PSNR: {dinv.metric.cal_psnr(x, x_pnp).item():.2f} dB; cosine similarity: {cosine_similarity(x_phase_pnp, x_phase).item():.3f}."
)
Original, GD random, Spectral, GD spectral, PnP
GD Random reconstruction, PSNR: 4.13 dB; cosine similarity: 0.195.
Spectral reconstruction, PSNR: 17.45 dB; cosine similarity: 0.884.
GD Spectral reconstruction, PSNR: 49.18 dB; cosine similarity: 1.000.
PnP reconstruction, PSNR: 13.83 dB; cosine similarity: 0.999.
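For reference, PSNR in dB is 10 * log10(MAX^2 / MSE). The sketch below assumes a peak value MAX = 1, which matches images in [0, 1]; check dinv.metric.cal_psnr for its exact defaults. Under that assumption, the 49.18 dB of the GD spectral reconstruction corresponds to a root-mean-square error of about 10^(-49.18/20), or roughly 0.0035.

```python
import numpy as np

def psnr(x_ref, x_est, max_pixel=1.0):
    """PSNR in dB: 10 * log10(max_pixel^2 / MSE)."""
    mse = np.mean((x_ref - x_est) ** 2)
    return 10.0 * np.log10(max_pixel ** 2 / mse)

x_ref = np.zeros(16)
x_est = np.full(16, 0.1)    # uniform error of 0.1, so MSE = 0.01
print(f"{psnr(x_ref, x_est):.2f} dB")   # 20.00 dB
```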

Total running time of the script: (0 minutes 44.937 seconds)
