Random phase retrieval and reconstruction methods.#

This example shows how to create a random phase retrieval operator and generate phaseless measurements from a given image. The example showcases 4 different reconstruction methods to recover the image from the phaseless measurements:

  1. Gradient descent with random initialization;

  2. Spectral methods;

  3. Gradient descent with spectral methods initialization;

  4. Gradient descent with PnP denoisers.

General setup#

import deepinv as dinv
from pathlib import Path
import torch
import matplotlib.pyplot as plt
from deepinv.models import DRUNet
from deepinv.optim.data_fidelity import L2
from deepinv.optim.prior import PnP, Zero
from deepinv.optim.optimizers import optim_builder
from deepinv.utils.demo import load_url_image, get_image_url
from deepinv.utils.plotting import plot
from deepinv.optim.phase_retrieval import (
    correct_global_phase,
    cosine_similarity,
    spectral_methods,
)
from deepinv.models.complex import to_complex_denoiser

BASE_DIR = Path(".")
RESULTS_DIR = BASE_DIR / "results"
# Set global random seed to ensure reproducibility.
torch.manual_seed(0)

device = dinv.utils.get_freer_gpu() if torch.cuda.is_available() else "cpu"

Load image from the internet#

We use the standard test image “Shepp–Logan phantom”.

# Image size
img_size = 32
url = get_image_url("SheppLogan.png")
# The pixel values of the image are in the range [0, 1].
x = load_url_image(
    url=url, img_size=img_size, grayscale=True, resize_mode="resize", device=device
)
print(x.min(), x.max())
tensor(0.) tensor(0.7412)

Visualization#

We use the customized plot() function in deepinv to visualize the original image.

plot(x, titles="Original image")
[Figure: Original image]

Signal construction#

We use the original image as the phase information for the complex signal. The original value range is [0, 1], and we map it to the phase range [-pi/2, pi/2].

x_phase = torch.exp(1j * x * torch.pi - 0.5j * torch.pi)

# Every element of the signal should have unit norm.
assert torch.allclose(x_phase.real**2 + x_phase.imag**2, torch.tensor(1.0))
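
As a quick sanity check, the phases of the constructed signal should lie in the interval [-pi/2, pi/2]; a minimal sketch:

# Sanity check (sketch): the phases of x_phase should lie in [-pi/2, pi/2],
# i.e. roughly within [-1.5708, 1.5708].
phases = torch.angle(x_phase)
print(phases.min().item(), phases.max().item())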

Measurements generation#

Create a random phase retrieval operator with an oversampling ratio (measurements/pixels) of 5.0, and generate measurements from the signal with additive Gaussian noise.

# Define physics information
oversampling_ratio = 5.0
img_shape = x.shape[1:]
m = int(oversampling_ratio * torch.prod(torch.tensor(img_shape)))
n_channels = 1  # 3 for color images, 1 for gray-scale images

# Create the physics
physics = dinv.physics.RandomPhaseRetrieval(
    m=m,
    img_shape=img_shape,
    device=device,
)

# Generate measurements
y = physics(x_phase)
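
For intuition, here is a self-contained toy illustration of what the forward model computes (this is not the deepinv implementation, and the names below are made up for the sketch): phaseless measurements are the squared magnitudes of random complex projections of the signal.

# Toy sketch of the phase retrieval forward model (not the deepinv implementation):
# y = |A x|^2 with A a random complex Gaussian matrix.
n_toy = 16
m_toy = int(5.0 * n_toy)  # oversampling ratio of 5
A_toy = torch.randn(m_toy, n_toy, dtype=torch.cfloat) / n_toy**0.5
z_toy = torch.exp(1j * torch.rand(n_toy) * torch.pi)  # toy unit-modulus signal
y_toy = (A_toy @ z_toy).abs() ** 2  # phaseless measurements
print(y_toy.shape)  # torch.Size([80])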

Reconstruction with gradient descent and random initialization#

First, we use the function deepinv.optim.L2 as the data fidelity function, and the class deepinv.optim.optim_iterators.GDIteration as the optimizer to run a gradient descent algorithm. The initial guess is a random complex signal.

data_fidelity = L2()
prior = Zero()
iterator = dinv.optim.optim_iterators.GDIteration()
# Parameters for the optimizer, including stepsize and regularization coefficient.
optim_params = {"stepsize": 0.06, "lambda": 1.0, "g_param": []}
num_iter = 1000

# Initial guess
x_phase_gd_rand = torch.randn_like(x_phase)

loss_hist = []

for _ in range(num_iter):
    res = iterator(
        {"est": (x_phase_gd_rand,), "cost": 0},
        cur_data_fidelity=data_fidelity,
        cur_prior=prior,
        cur_params=optim_params,
        y=y,
        physics=physics,
    )
    x_phase_gd_rand = res["est"][0]
    loss_hist.append(data_fidelity(x_phase_gd_rand, y, physics).cpu())

print("initial loss:", loss_hist[0])
print("final loss:", loss_hist[-1])
# Plot the loss curve
plt.plot(loss_hist)
plt.yscale("log")
plt.title("loss curve (gradient descent with random initialization)")
plt.show()
[Figure: loss curve (gradient descent with random initialization)]
initial loss: tensor([190.4569])
final loss: tensor([28.0710])
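
The iterator above performs plain gradient descent on the data-fidelity term. For intuition, the same update can be written by hand; a minimal sketch, assuming data_fidelity.grad(x, y, physics) returns the gradient of the data-fidelity term:

# Hand-written gradient descent (sketch), equivalent to the GDIteration loop above.
# Assumption: data_fidelity.grad(x, y, physics) returns the data-fidelity gradient.
x_manual = torch.randn_like(x_phase)
for _ in range(100):
    x_manual = x_manual - optim_params["stepsize"] * data_fidelity.grad(x_manual, y, physics)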

Phase correction and signal reconstruction#

The solution x_est of the optimization algorithm may be a globally phase-shifted version of the original complex signal x_phase, i.e., x_est = a * x_phase where a is an arbitrary complex number of unit modulus. We therefore use the function deepinv.optim.phase_retrieval.correct_global_phase to correct the global phase shift of the estimate x_est so that it aligns with the original signal x_phase. We then use torch.angle to extract the phase information: since the constructed phases lie in [-pi/2, pi/2], dividing by pi and adding 0.5 maps them back to [0, 1]. The same post-processing is applied to all the reconstruction methods below.
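
The ambiguity is easy to verify: multiplying the signal by any unit-modulus scalar leaves the noiseless measurements unchanged. A small sketch, assuming physics.A applies the noiseless forward operator:

# Sanity check (sketch) of the global phase ambiguity, assuming physics.A is
# the noiseless forward operator: |B(e^{i*theta} x)|^2 = |B x|^2.
theta = torch.rand(1, device=device) * 2 * torch.pi
print(torch.allclose(physics.A(x_phase * torch.exp(1j * theta)), physics.A(x_phase), atol=1e-5))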

# correct possible global phase shifts
x_gd_rand = correct_global_phase(x_phase_gd_rand, x_phase)
# extract phase information and normalize to the range [0, 1]
x_gd_rand = torch.angle(x_gd_rand) / torch.pi + 0.5

plot([x, x_gd_rand], titles=["Signal", "Reconstruction"], rescale_mode="clip")
[Figure: Signal, Reconstruction]

Reconstruction with spectral methods#

The spectral method deepinv.optim.phase_retrieval.spectral_methods provides a good initial estimate of the original signal. Moreover, deepinv.physics.RandomPhaseRetrieval uses spectral methods as its default reconstruction method A_dagger, which we can call directly.

# Spectral methods return a tensor with unit norm.
x_phase_spec = physics.A_dagger(y, n_iter=300)
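
Since spectral_methods is imported above, the same estimate could also be obtained by calling it directly; a sketch, assuming it accepts the measurements, the physics, and an n_iter keyword mirroring the A_dagger call:

# Direct call (sketch): assuming spectral_methods(y, physics, n_iter=...) mirrors
# the physics.A_dagger call above.
x_phase_spec_direct = spectral_methods(y, physics, n_iter=300)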

Phase correction and signal reconstruction#

# correct possible global phase shifts
x_spec = correct_global_phase(x_phase_spec, x_phase)
# extract phase information and normalize to the range [0, 1]
x_spec = torch.angle(x_spec) / torch.pi + 0.5
plot([x, x_spec], titles=["Signal", "Reconstruction"], rescale_mode="clip")
[Figure: Signal, Reconstruction]

Reconstruction with gradient descent and spectral methods initialization#

The estimate from spectral methods can be directly used as the initial guess for the gradient descent algorithm.

# Initial guess from spectral methods
x_phase_gd_spec = physics.A_dagger(y, n_iter=300)

loss_hist = []
for _ in range(num_iter):
    res = iterator(
        {"est": (x_phase_gd_spec,), "cost": 0},
        cur_data_fidelity=data_fidelity,
        cur_prior=prior,
        cur_params=optim_params,
        y=y,
        physics=physics,
    )
    x_phase_gd_spec = res["est"][0]
    loss_hist.append(data_fidelity(x_phase_gd_spec, y, physics).cpu())

print("intial loss:", loss_hist[0])
print("final loss:", loss_hist[-1])
plt.plot(loss_hist)
plt.yscale("log")
plt.title("loss curve (gradient descent with spectral initialization)")
plt.show()
[Figure: loss curve (gradient descent with spectral initialization)]
initial loss: tensor([42.0413])
final loss: tensor([0.0034])

Phase correction and signal reconstruction#

# correct possible global phase shifts
x_gd_spec = correct_global_phase(x_phase_gd_spec, x_phase)
# extract phase information and normalize to the range [0, 1]
x_gd_spec = torch.angle(x_gd_spec) / torch.pi + 0.5
plot([x, x_gd_spec], titles=["Signal", "Reconstruction"], rescale_mode="clip")
[Figure: Signal, Reconstruction]

Reconstruction with gradient descent and PnP denoisers#

We can also use the Plug-and-Play (PnP) framework to incorporate denoisers as regularizers in the optimization algorithm. We use a deep denoiser as the prior, which is trained on a large dataset of natural images.
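
The function to_complex_denoiser with mode="abs_angle" wraps a real-valued denoiser so that it can act on complex images. A rough, hand-rolled illustration of the idea (not the deepinv implementation): denoise the magnitude and the phase separately, then recombine them into a complex image.

# Rough illustration (not the deepinv implementation) of the "abs_angle" idea:
# apply a real-valued denoiser to magnitude and phase separately, then recombine.
def complex_denoise_sketch(real_denoiser, z, sigma):
    magnitude = real_denoiser(z.abs(), sigma)
    angle = real_denoiser(z.angle(), sigma)
    return magnitude * torch.exp(1j * angle)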

# Load the pre-trained denoiser
denoiser = DRUNet(
    in_channels=n_channels,
    out_channels=n_channels,
    pretrained="download",  # automatically downloads the pretrained weights, set to a path to use custom weights.
    device=device,
)
# The original denoiser is designed for real-valued images, so we need to convert it to a complex-valued denoiser for phase retrieval problems.
denoiser_complex = to_complex_denoiser(denoiser, mode="abs_angle")

# Algorithm parameters
data_fidelity = L2()
prior = PnP(denoiser=denoiser_complex)
params_algo = {"stepsize": 0.30, "g_param": 0.04}
max_iter = 100
early_stop = True
verbose = True

# Instantiate the algorithm class to solve the IP problem.
model = optim_builder(
    iteration="PGD",
    prior=prior,
    data_fidelity=data_fidelity,
    early_stop=early_stop,
    max_iter=max_iter,
    verbose=verbose,
    params_algo=params_algo,
)

# Run the algorithm
x_phase_pnp, metrics = model(y, physics, x_gt=x_phase, compute_metrics=True)
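
Each "PGD" iteration alternates a gradient step on the data fidelity with a denoising step that plays the role of the proximal operator of the prior. A rough sketch of a single iteration, assuming data_fidelity.grad(x, y, physics) is available and that the complex denoiser is called with a noise-level argument:

# One proximal-gradient (PGD) step written out by hand (sketch):
# gradient step on the data fidelity, then a denoising step.
x_k = physics.A_dagger(y, n_iter=300)
z_k = x_k - params_algo["stepsize"] * data_fidelity.grad(x_k, y, physics)
x_k = denoiser_complex(z_k, params_algo["g_param"])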
Downloading: "https://huggingface.co/deepinv/drunet/resolve/main/drunet_deepinv_gray_finetune_26k.pth?download=true" to /home/runner/.cache/torch/hub/checkpoints/drunet_deepinv_gray_finetune_26k.pth

  0%|          | 0.00/125M [00:00<?, ?B/s]
  1%|▏         | 1.75M/125M [00:00<00:07, 18.3MB/s]
  4%|▍         | 4.75M/125M [00:00<00:04, 25.9MB/s]
  7%|▋         | 9.25M/125M [00:00<00:03, 35.2MB/s]
 11%|█         | 13.5M/125M [00:00<00:03, 38.2MB/s]
 14%|█▍        | 17.5M/125M [00:00<00:02, 39.4MB/s]
 17%|█▋        | 21.8M/125M [00:00<00:02, 40.4MB/s]
 21%|██        | 26.0M/125M [00:00<00:02, 41.4MB/s]
 24%|██▍       | 30.2M/125M [00:00<00:02, 41.7MB/s]
 28%|██▊       | 34.5M/125M [00:00<00:02, 42.0MB/s]
 31%|███       | 38.5M/125M [00:01<00:02, 41.8MB/s]
 34%|███▍      | 42.8M/125M [00:01<00:02, 42.2MB/s]
 38%|███▊      | 47.0M/125M [00:01<00:01, 42.3MB/s]
 41%|████      | 51.2M/125M [00:01<00:01, 42.5MB/s]
 45%|████▍     | 55.5M/125M [00:01<00:01, 42.7MB/s]
 48%|████▊     | 59.8M/125M [00:01<00:01, 42.6MB/s]
 51%|█████▏    | 64.0M/125M [00:01<00:01, 42.4MB/s]
 55%|█████▍    | 68.1M/125M [00:01<00:01, 42.5MB/s]
 58%|█████▊    | 72.2M/125M [00:01<00:01, 41.8MB/s]
 61%|██████▏   | 76.5M/125M [00:01<00:01, 42.0MB/s]
 65%|██████▍   | 80.8M/125M [00:02<00:01, 42.3MB/s]
 68%|██████▊   | 85.0M/125M [00:02<00:00, 42.2MB/s]
 72%|███████▏  | 89.2M/125M [00:02<00:00, 42.3MB/s]
 75%|███████▌  | 93.5M/125M [00:02<00:00, 42.4MB/s]
 78%|███████▊  | 97.8M/125M [00:02<00:00, 42.6MB/s]
 82%|████████▏ | 102M/125M [00:02<00:00, 42.5MB/s]
 85%|████████▌ | 106M/125M [00:02<00:00, 42.6MB/s]
 89%|████████▊ | 110M/125M [00:02<00:00, 42.7MB/s]
 92%|█████████▏| 115M/125M [00:02<00:00, 42.7MB/s]
 96%|█████████▌| 119M/125M [00:03<00:00, 42.7MB/s]
 99%|█████████▉| 123M/125M [00:03<00:00, 42.7MB/s]
100%|██████████| 125M/125M [00:03<00:00, 41.5MB/s]

Phase correction and signal reconstruction#

# correct possible global phase shifts
x_pnp = correct_global_phase(x_phase_pnp, x_phase)
# extract phase information and normalize to the range [0, 1]
x_pnp = torch.angle(x_pnp) / torch.pi + 0.5
plot([x, x_pnp], titles=["Signal", "Reconstruction"], rescale_mode="clip")
[Figure: Signal, Reconstruction]

Overall comparison#

We visualize the original image and the reconstructions from the four methods. We also compute the PSNR (Peak Signal-to-Noise Ratio, higher is better) of every reconstruction and its cosine similarity with the original image (in [0, 1], higher is better). In conclusion, gradient descent with random initialization yields a poor reconstruction, while spectral methods provide a good initial estimate that gradient descent can then refine to obtain the best reconstruction. The PnP framework with a deep denoiser as the prior also produces a very good reconstruction, as it exploits prior information about the set of natural images.
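
For reference, the cosine similarity reported below is assumed to be the magnitude of the normalized inner product between the flattened complex signals; a minimal sketch:

# Sketch of a cosine similarity between complex signals (assumed to match the
# metric used below): magnitude of the normalized inner product, in [0, 1].
def cosine_sim_sketch(a, b):
    a, b = a.flatten(), b.flatten()
    return torch.vdot(a, b).abs() / (a.norm() * b.norm())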

imgs = [x, x_gd_rand, x_spec, x_gd_spec, x_pnp]
plot(
    imgs,
    titles=["Original", "GD random", "Spectral", "GD spectral", "PnP"],
    save_dir=RESULTS_DIR / "images",
    show=True,
    rescale_mode="clip",
)

# Compute metrics
print(
    f"GD Random reconstruction, PSNR: {dinv.metric.cal_psnr(x, x_gd_rand).item():.2f} dB; cosine similarity: {cosine_similarity(x_phase_gd_rand, x_phase).item():.3f}."
)
print(
    f"Spectral reconstruction, PSNR: {dinv.metric.cal_psnr(x, x_spec).item():.2f} dB; cosine similarity: {cosine_similarity(x_phase_spec, x_phase).item():.3f}."
)
print(
    f"GD Spectral reconstruction, PSNR: {dinv.metric.cal_psnr(x, x_gd_spec).item():.2f} dB; cosine similarity: {cosine_similarity(x_phase_gd_spec, x_phase).item():.3f}."
)
print(
    f"PnP reconstruction, PSNR: {dinv.metric.cal_psnr(x, x_pnp).item():.2f} dB; cosine similarity: {cosine_similarity(x_phase_pnp, x_phase).item():.3f}."
)
[Figure: Original, GD random, Spectral, GD spectral, PnP]
GD Random reconstruction, PSNR: 3.95 dB; cosine similarity: 0.201.
Spectral reconstruction, PSNR: 18.88 dB; cosine similarity: 0.902.
GD Spectral reconstruction, PSNR: 50.80 dB; cosine similarity: 1.000.
PnP reconstruction, PSNR: 13.83 dB; cosine similarity: 0.999.

Total running time of the script: (0 minutes 36.003 seconds)
