Random phase retrieval and reconstruction methods.
This example shows how to create a random phase retrieval operator and generate phaseless measurements from a given image. It then compares four methods for recovering the image from these phaseless measurements:
Gradient descent with random initialization;
Spectral methods;
Gradient descent with spectral methods initialization;
Gradient descent with PnP denoisers.
General setup
import deepinv as dinv
from pathlib import Path
import torch
import matplotlib.pyplot as plt
from deepinv.models import DRUNet
from deepinv.optim.data_fidelity import L2
from deepinv.optim.prior import PnP
from deepinv.optim.optimizers import optim_builder
from deepinv.utils.demo import load_url_image, get_image_url
from deepinv.utils.plotting import plot
from deepinv.optim.phase_retrieval import correct_global_phase, cosine_similarity
from deepinv.models.complex import to_complex_denoiser
BASE_DIR = Path(".")
RESULTS_DIR = BASE_DIR / "results"
# Set global random seed to ensure reproducibility.
torch.manual_seed(0)
device = dinv.utils.get_freer_gpu() if torch.cuda.is_available() else "cpu"
Load image from the internet
We use the standard test image “Shepp–Logan phantom”.
tensor(0.) tensor(0.7412)
Visualization
We use the customized plot() function in deepinv to visualize the original image.
plot(x, titles="Original image")
Signal construction
We use the original image as the phase information for the complex signal. The original value range is [0, 1], and we map it to the phase range [-pi/2, pi/2].
x_phase = torch.exp(1j * x * torch.pi - 0.5j * torch.pi)
# Every element of the signal should have unit modulus (|x_i| = 1).
assert torch.allclose(x_phase.real**2 + x_phase.imag**2, torch.tensor(1.0))
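The encoding above and the decoding applied after phase correction (angle / pi + 0.5) are inverses of each other on [0, 1]. The minimal sketch below checks this round trip. It uses NumPy only to stay self-contained, and the helper names `encode_phase` and `decode_phase` are hypothetical, not part of deepinv.

```python
import numpy as np

def encode_phase(x):
    """Map intensities in [0, 1] to unit-modulus complex numbers with phase in [-pi/2, pi/2]."""
    return np.exp(1j * np.pi * x - 0.5j * np.pi)

def decode_phase(z):
    """Map a complex signal back to [0, 1]; valid when its phase lies in [-pi/2, pi/2]."""
    return np.angle(z) / np.pi + 0.5

x = np.linspace(0.0, 1.0, 5)
z = encode_phase(x)
print(np.abs(z))        # all ones (up to rounding): unit modulus
print(decode_phase(z))  # recovers x up to floating-point error
```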
Measurements generation
Create a random phase retrieval operator with an oversampling ratio (measurements/pixels) of 5.0, and generate measurements from the signal with additive Gaussian noise.
# Define physics information
oversampling_ratio = 5.0
img_shape = x.shape[1:]
m = int(oversampling_ratio * torch.prod(torch.tensor(img_shape)))
noise_level_img = 0.05 # Gaussian Noise standard deviation for the degradation
n_channels = 1 # 3 for color images, 1 for gray-scale images
# Create the physics
physics = dinv.physics.RandomPhaseRetrieval(
m=m,
img_shape=img_shape,
noise_model=dinv.physics.GaussianNoise(sigma=noise_level_img),
device=device,
)
# Generate measurements
y = physics(x_phase)
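Conceptually, a random phase retrieval operator returns the squared magnitudes of random complex projections, y = |Bx|^2 plus noise, so multiplying x by a global phase factor leaves y unchanged. The sketch below assumes that B has i.i.d. circularly-symmetric complex Gaussian entries with variance 1/m. That normalization, the helpers `make_operator` and `measure`, and the additive noise model are illustrative assumptions and may differ from deepinv's `RandomPhaseRetrieval`.

```python
import numpy as np

def make_operator(m, n, rng):
    # Assumed normalization: entries ~ CN(0, 1/m), so that E[sum(|Bx|^2)] = ||x||^2.
    scale = np.sqrt(0.5 / m)
    return scale * (rng.standard_normal((m, n)) + 1j * rng.standard_normal((m, n)))

def measure(B, x, sigma=0.0, rng=None):
    """Phaseless measurements y = |Bx|^2 + sigma * Gaussian noise."""
    y = np.abs(B @ x) ** 2
    if sigma > 0:
        y = y + sigma * rng.standard_normal(y.shape)
    return y

rng = np.random.default_rng(0)
n = 64
m = int(5.0 * n)  # oversampling ratio of 5, as in the example
B = make_operator(m, n, rng)
x = np.exp(1j * np.pi * rng.uniform(size=n) - 0.5j * np.pi)

y = measure(B, x)
y_shifted = measure(B, np.exp(1j * 0.7) * x)  # same signal, global phase shift
print(np.allclose(y, y_shifted))  # True: a global phase is invisible in y
```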
Reconstruction with gradient descent and random initialization
First, we use deepinv.optim.L2 as the data fidelity term and directly call its grad method to run a gradient descent algorithm. The initial guess is a random complex signal.
data_fidelity = L2()
# Step size for the gradient descent
stepsize = 0.10
num_iter = 1000
# Initial guess
x_phase_gd_rand = torch.randn_like(x_phase)
loss_hist = []
for _ in range(num_iter):
x_phase_gd_rand = x_phase_gd_rand - stepsize * data_fidelity.grad(
x_phase_gd_rand, y, physics
)
loss_hist.append(data_fidelity(x_phase_gd_rand, y, physics).cpu())
print("initial loss:", loss_hist[0])
print("final loss:", loss_hist[-1])
# Plot the loss curve
plt.plot(loss_hist)
plt.yscale("log")
plt.title("loss curve (gradient descent with random initialization)")
plt.show()
initial loss: tensor([179.9851])
final loss: tensor([25.8607])
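To make the gradient step concrete: for the least-squares objective f(x) = 0.5 * || |Bx|^2 - y ||^2, the derivatives with respect to the real and imaginary parts of x, packed into one complex vector as df/dRe(x) + i df/dIm(x), equal 2 B^H [(|Bx|^2 - y) * Bx], with * taken elementwise. The sketch below checks this formula against central finite differences. It assumes a dense matrix B and this exact objective; deepinv's `L2.grad` may use a different scaling or convention for complex gradients.

```python
import numpy as np

def loss(B, x, y):
    """f(x) = 0.5 * || |Bx|^2 - y ||^2"""
    r = np.abs(B @ x) ** 2 - y
    return 0.5 * np.sum(r ** 2)

def grad(B, x, y):
    """Gradient w.r.t. Re(x) and Im(x), returned as df/dRe(x) + 1j * df/dIm(x)."""
    z = B @ x
    r = np.abs(z) ** 2 - y
    return 2.0 * B.conj().T @ (r * z)

rng = np.random.default_rng(1)
m, n = 40, 8
B = (rng.standard_normal((m, n)) + 1j * rng.standard_normal((m, n))) / np.sqrt(2 * m)
x_true = rng.standard_normal(n) + 1j * rng.standard_normal(n)
y = np.abs(B @ x_true) ** 2
x = rng.standard_normal(n) + 1j * rng.standard_normal(n)

# Central finite differences along each real and each imaginary coordinate.
h = 1e-6
fd = np.zeros(n, dtype=complex)
for j in range(n):
    e = np.zeros(n)
    e[j] = 1.0
    d_re = (loss(B, x + h * e, y) - loss(B, x - h * e, y)) / (2 * h)
    d_im = (loss(B, x + 1j * h * e, y) - loss(B, x - 1j * h * e, y)) / (2 * h)
    fd[j] = d_re + 1j * d_im

print(np.max(np.abs(fd - grad(B, x, y))))  # small: formula matches finite differences
```

A gradient descent iteration then moves x by minus stepsize times this vector.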
Phase correction and signal reconstruction
The solution x_est of the optimization algorithm can be any globally phase-shifted version of the original complex signal x_phase, i.e., x_est = a * x_phase, where a is an arbitrary complex number with unit modulus (|a| = 1), because the phaseless measurements are insensitive to a global phase.
Therefore, we use deepinv.optim.phase_retrieval.correct_global_phase to remove the global phase shift from the estimate x_est so that it is aligned with the original signal x_phase.
We then use torch.angle to extract the phase. torch.angle returns values in (-pi, pi], but since the true phase lies in [-pi/2, pi/2], we map it back to [0, 1] by dividing by pi and adding 0.5.
This operation is applied to all the reconstruction methods below.
# correct possible global phase shifts
x_gd_rand = correct_global_phase(x_phase_gd_rand, x_phase)
# extract phase information and normalize to the range [0, 1]
x_gd_rand = torch.angle(x_gd_rand) / torch.pi + 0.5
plot([x, x_gd_rand], titles=["Signal", "Reconstruction"], rescale_mode="clip")
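One standard way to undo a global phase is to choose the unit-modulus factor that best aligns the estimate with the reference, which has a closed form through the complex inner product. The helpers below are a hypothetical NumPy sketch of that idea and of a phase-invariant cosine similarity; they are not taken from deepinv, whose `correct_global_phase` and `cosine_similarity` may differ in details such as batching.

```python
import numpy as np

def align_global_phase(x_est, x_ref):
    """Return exp(1j*theta) * x_est with theta minimizing ||exp(1j*theta) * x_est - x_ref||."""
    inner = np.vdot(x_est, x_ref)  # sum(conj(x_est) * x_ref)
    return x_est * np.exp(1j * np.angle(inner))

def phase_invariant_cosine(a, b):
    """|<a, b>| / (||a|| ||b||), which lies in [0, 1] and ignores a global phase."""
    return np.abs(np.vdot(a, b)) / (np.linalg.norm(a) * np.linalg.norm(b))

rng = np.random.default_rng(2)
x_ref = np.exp(1j * np.pi * (rng.uniform(size=16) - 0.5))
x_est = np.exp(1j * 2.3) * x_ref  # same signal up to a global phase

x_aligned = align_global_phase(x_est, x_ref)
print(np.allclose(x_aligned, x_ref))         # True
print(phase_invariant_cosine(x_est, x_ref))  # 1.0 up to rounding
```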
Reconstruction with spectral methods
Spectral methods (deepinv.optim.phase_retrieval.spectral_methods) offer a good initial guess of the original signal. Moreover, deepinv.physics.RandomPhaseRetrieval uses spectral methods as its default reconstruction method A_dagger, which we can call directly.
# Spectral methods return a tensor with unit norm.
x_phase_spec = physics.A_dagger(y, n_iter=300)
# Correct the norm of the estimated signal
x_phase_spec = x_phase_spec * torch.sqrt(y.sum())
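A basic spectral method estimates the direction of x as the leading eigenvector of Y = (1/m) B^H diag(y) B, computed here by power iteration, and then rescales it with a norm estimate taken from the measurements. This NumPy sketch assumes B has i.i.d. CN(0, 1) entries, for which mean(y) is approximately ||x||^2; if the entries instead had variance 1/m, the same estimate would become sqrt(sum(y)), which is the form of the rescaling step above. deepinv's `spectral_methods` may preprocess y differently and use another stopping rule; `spectral_init` is a hypothetical name.

```python
import numpy as np

def spectral_init(B, y, n_iter=300, rng=None):
    """Leading eigenvector of (1/m) B^H diag(y) B via power iteration, rescaled by sqrt(mean(y))."""
    rng = np.random.default_rng() if rng is None else rng
    m, n = B.shape
    v = rng.standard_normal(n) + 1j * rng.standard_normal(n)
    v /= np.linalg.norm(v)
    for _ in range(n_iter):
        v = B.conj().T @ (y * (B @ v)) / m  # apply Y without forming it
        v /= np.linalg.norm(v)
    # Norm estimate, assuming B has i.i.d. CN(0, 1) entries so that E[y_i] = ||x||^2.
    return np.sqrt(np.mean(y)) * v

rng = np.random.default_rng(3)
n, m = 8, 800
B = (rng.standard_normal((m, n)) + 1j * rng.standard_normal((m, n))) / np.sqrt(2)
x = np.exp(1j * np.pi * (rng.uniform(size=n) - 0.5))
y = np.abs(B @ x) ** 2

x0 = spectral_init(B, y, rng=rng)
cos = np.abs(np.vdot(x0, x)) / (np.linalg.norm(x0) * np.linalg.norm(x))
print(cos)  # typically close to 1 when m / n is large
```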
Phase correction and signal reconstruction
# correct possible global phase shifts
x_spec = correct_global_phase(x_phase_spec, x_phase)
# extract phase information and normalize to the range [0, 1]
x_spec = torch.angle(x_spec) / torch.pi + 0.5
plot([x, x_spec], titles=["Signal", "Reconstruction"], rescale_mode="clip")
Reconstruction with gradient descent and spectral methods initialization
The estimate from spectral methods can be directly used as the initial guess for the gradient descent algorithm.
# Initial guess from spectral methods
x_phase_gd_spec = physics.A_dagger(y, n_iter=300)
x_phase_gd_spec = x_phase_gd_spec * torch.sqrt(y.sum())
loss_hist = []
for _ in range(num_iter):
x_phase_gd_spec = x_phase_gd_spec - stepsize * data_fidelity.grad(
x_phase_gd_spec, y, physics
)
loss_hist.append(data_fidelity(x_phase_gd_spec, y, physics).cpu())
print("initial loss:", loss_hist[0])
print("final loss:", loss_hist[-1])
plt.plot(loss_hist)
plt.yscale("log")
plt.title("loss curve (gradient descent with spectral initialization)")
plt.show()
initial loss: tensor([35.3479])
final loss: tensor([0.0001])
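Gradient descent started from a spectral estimate is essentially the Wirtinger flow scheme of Candès, Li and Soltanolkotabi. The NumPy sketch below runs it on a small noiseless problem. The complex Gaussian model for B, the mean-squared loss, the step-size rule mu / ||x0||^2 with mu = 0.2, and all sizes are assumptions chosen for illustration, not deepinv's settings.

```python
import numpy as np

def spectral_init(B, y, n_iter=300, rng=None):
    rng = np.random.default_rng() if rng is None else rng
    m, n = B.shape
    v = rng.standard_normal(n) + 1j * rng.standard_normal(n)
    for _ in range(n_iter):
        v = B.conj().T @ (y * (B @ v)) / m  # power iteration on (1/m) B^H diag(y) B
        v /= np.linalg.norm(v)
    return np.sqrt(np.mean(y)) * v  # assumes E[y_i] = ||x||^2 (CN(0, 1) entries)

def wirtinger_grad(B, x, y):
    """d/d(conj x) of (1/(2m)) * sum((|Bx|^2 - y)^2)."""
    z = B @ x
    return B.conj().T @ ((np.abs(z) ** 2 - y) * z) / B.shape[0]

def loss(B, x, y):
    return 0.5 * np.mean((np.abs(B @ x) ** 2 - y) ** 2)

rng = np.random.default_rng(4)
n, m = 10, 300  # oversampling ratio 30
B = (rng.standard_normal((m, n)) + 1j * rng.standard_normal((m, n))) / np.sqrt(2)
x_true = np.exp(1j * np.pi * (rng.uniform(size=n) - 0.5))
y = np.abs(B @ x_true) ** 2  # noiseless measurements

x = spectral_init(B, y, rng=rng)
mu = 0.2 / np.linalg.norm(x) ** 2  # step size relative to the initial norm
losses = [loss(B, x, y)]
for _ in range(2000):
    x = x - mu * wirtinger_grad(B, x, y)
    losses.append(loss(B, x, y))

cos = np.abs(np.vdot(x, x_true)) / (np.linalg.norm(x) * np.linalg.norm(x_true))
print(losses[0], losses[-1], cos)
```

Scaling the step by 1 / ||x0||^2 keeps the update roughly invariant to the overall magnitude of the signal.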
Phase correction and signal reconstruction
# correct possible global phase shifts
x_gd_spec = correct_global_phase(x_phase_gd_spec, x_phase)
# extract phase information and normalize to the range [0, 1]
x_gd_spec = torch.angle(x_gd_spec) / torch.pi + 0.5
plot([x, x_gd_spec], titles=["Signal", "Reconstruction"], rescale_mode="clip")
Reconstruction with gradient descent and PnP denoisers
We can also use the Plug-and-Play (PnP) framework to incorporate denoisers as regularizers in the optimization algorithm. We use a deep denoiser as the prior, which is trained on a large dataset of natural images.
# Load the pre-trained denoiser
denoiser = DRUNet(
in_channels=n_channels,
out_channels=n_channels,
pretrained="download", # automatically downloads the pretrained weights, set to a path to use custom weights.
device=device,
)
# The original denoiser is designed for real-valued images, so we need to convert it to a complex-valued denoiser for phase retrieval problems.
denoiser_complex = to_complex_denoiser(denoiser, mode="abs_angle")
# Algorithm parameters
data_fidelity = L2()
prior = PnP(denoiser=denoiser_complex)
params_algo = {"stepsize": 0.30, "g_param": 0.04}
max_iter = 400
early_stop = True
verbose = True
# Instantiate the algorithm class to solve the inverse problem.
model = optim_builder(
iteration="PGD",
prior=prior,
data_fidelity=data_fidelity,
early_stop=early_stop,
max_iter=max_iter,
verbose=verbose,
params_algo=params_algo,
)
# Run the algorithm
x_phase_pnp, metrics = model(y, physics, x_gt=x_phase, compute_metrics=True)
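Each PnP-PGD iteration takes a gradient step on the data fidelity and then applies the denoiser in place of a proximal operator. The NumPy sketch below shows one such step with a hypothetical magnitude/phase wrapper around a real-valued denoiser. The moving-average stand-in denoiser, the helper names, and the exact treatment of magnitude and phase are assumptions for illustration and are not claimed to reproduce `to_complex_denoiser(..., mode="abs_angle")` or deepinv's PGD. Denoising the raw angle is only reasonable here because the phases stay within [-pi/2, pi/2], away from wrap-around.

```python
import numpy as np

def moving_average(v, width=3):
    """Stand-in real-valued 'denoiser': centered moving average (odd width) with edge padding."""
    pad = width // 2
    vp = np.pad(v, pad, mode="edge")
    return np.convolve(vp, np.ones(width) / width, mode="valid")

def complex_abs_angle(denoiser):
    """Wrap a real denoiser: denoise magnitude and phase separately, then recombine."""
    def wrapped(z):
        return denoiser(np.abs(z)) * np.exp(1j * denoiser(np.angle(z)))
    return wrapped

def pnp_pgd_step(x, B, y, stepsize, denoiser):
    """One PnP proximal gradient step: gradient step on 0.5*|| |Bx|^2 - y ||^2, then denoise."""
    z = B @ x
    grad = 2.0 * B.conj().T @ ((np.abs(z) ** 2 - y) * z)
    return denoiser(x - stepsize * grad)

rng = np.random.default_rng(5)
n, m = 32, 160
B = (rng.standard_normal((m, n)) + 1j * rng.standard_normal((m, n))) / np.sqrt(2 * m)
x_true = np.exp(1j * 0.4) * np.ones(n)  # constant magnitude and phase
y = np.abs(B @ x_true) ** 2             # noiseless
denoiser = complex_abs_angle(moving_average)

# The true signal is a fixed point: the data-fidelity gradient is zero, and the
# moving average leaves a constant magnitude and phase unchanged.
x_next = pnp_pgd_step(x_true, B, y, stepsize=0.3, denoiser=denoiser)
print(np.allclose(x_next, x_true))  # True
```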
Phase correction and signal reconstruction
# correct possible global phase shifts
x_pnp = correct_global_phase(x_phase_pnp, x_phase)
# extract phase information and normalize to the range [0, 1]
x_pnp = torch.angle(x_pnp) / torch.pi + 0.5
plot([x, x_pnp], titles=["Signal", "Reconstruction"], rescale_mode="clip")
Overall comparison
We visualize the original image and the reconstructions from the four methods. We further compute the PSNR (Peak Signal-to-Noise Ratio, higher is better) of every reconstruction and its cosine similarity with the original complex signal (range [0, 1], higher is better). In conclusion, gradient descent with random initialization provides a poor reconstruction, while spectral methods provide a good initial estimate that gradient descent can then refine to obtain the best reconstruction. The PnP framework with a deep denoiser as the prior also provides a very good reconstruction, as it exploits prior information learned from natural images.
imgs = [x, x_gd_rand, x_spec, x_gd_spec, x_pnp]
plot(
imgs,
titles=["Original", "GD random", "Spectral", "GD spectral", "PnP"],
save_dir=RESULTS_DIR / "images",
show=True,
rescale_mode="clip",
)
# Compute metrics
print(
f"GD Random reconstruction, PSNR: {dinv.metric.PSNR()(x, x_gd_rand).item():.2f} dB; cosine similarity: {cosine_similarity(x_phase_gd_rand, x_phase):.3f}."
)
print(
f"Spectral reconstruction, PSNR: {dinv.metric.PSNR()(x, x_spec).item():.2f} dB; cosine similarity: {cosine_similarity(x_phase_spec, x_phase):.3f}."
)
print(
f"GD Spectral reconstruction, PSNR: {dinv.metric.PSNR()(x, x_gd_spec).item():.2f} dB; cosine similarity: {cosine_similarity(x_phase_gd_spec, x_phase):.3f}."
)
print(
f"PnP reconstruction, PSNR: {dinv.metric.PSNR()(x, x_pnp).item():.2f} dB; cosine similarity: {cosine_similarity(x_phase_pnp, x_phase):.3f}."
)
GD Random reconstruction, PSNR: 4.14 dB; cosine similarity: 0.222.
Spectral reconstruction, PSNR: 18.88 dB; cosine similarity: 0.902.
GD Spectral reconstruction, PSNR: 63.81 dB; cosine similarity: 1.000.
PnP reconstruction, PSNR: 13.84 dB; cosine similarity: 0.999.
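For reference, PSNR compares a decoded image with the ground truth through the mean squared error, PSNR = 10 log10(MAX^2 / MSE). The minimal sketch below assumes a peak value MAX = 1, matching the [0, 1] range of the decoded images; `psnr` is a hypothetical helper, and deepinv's `dinv.metric.PSNR` may handle normalization and batching differently.

```python
import numpy as np

def psnr(x, x_hat, max_value=1.0):
    """Peak signal-to-noise ratio in dB."""
    mse = np.mean((np.asarray(x) - np.asarray(x_hat)) ** 2)
    return 10.0 * np.log10(max_value ** 2 / mse)

x = np.zeros(16)
print(psnr(x, x + 0.1))   # about 20 dB: MSE = 0.01
print(psnr(x, x + 0.25))  # about 12 dB: MSE = 0.0625
```

Because the cosine similarity is computed on the complex signals and ignores global phase and scale, while PSNR is computed on the decoded images, the two metrics can disagree: a constant offset in the decoded image lowers PSNR without changing the complex-domain cosine similarity.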