.. DO NOT EDIT.
.. THIS FILE WAS AUTOMATICALLY GENERATED BY SPHINX-GALLERY.
.. TO MAKE CHANGES, EDIT THE SOURCE PYTHON FILE:
.. "auto_examples/basics/demo_phase_retrieval.py"
.. LINE NUMBERS ARE GIVEN BELOW.

.. rst-class:: sphx-glr-example-title

.. _sphx_glr_auto_examples_basics_demo_phase_retrieval.py:


Random phase retrieval and reconstruction methods.
===================================================

This example shows how to create a random phase retrieval operator and generate
phaseless measurements from a given image. It then compares four methods for
recovering the image from these phaseless measurements:

#. Gradient descent with random initialization;
#. Spectral methods;
#. Gradient descent with spectral methods initialization;
#. Gradient descent with PnP denoisers.

.. GENERATED FROM PYTHON SOURCE LINES 15-17

General setup
----------------------------

.. GENERATED FROM PYTHON SOURCE LINES 17-41

.. code-block:: Python

    import deepinv as dinv
    from pathlib import Path
    import torch
    import matplotlib.pyplot as plt
    from deepinv.models import DRUNet
    from deepinv.optim.data_fidelity import L2
    from deepinv.optim.prior import PnP, Zero
    from deepinv.optim.optimizers import optim_builder
    from deepinv.utils.demo import load_url_image, get_image_url
    from deepinv.utils.plotting import plot
    from deepinv.optim.phase_retrieval import (
        correct_global_phase,
        cosine_similarity,
        spectral_methods,
    )
    from deepinv.models.complex import to_complex_denoiser

    BASE_DIR = Path(".")
    RESULTS_DIR = BASE_DIR / "results"

    # Set global random seed to ensure reproducibility.
    torch.manual_seed(0)

    device = dinv.utils.get_freer_gpu() if torch.cuda.is_available() else "cpu"

.. GENERATED FROM PYTHON SOURCE LINES 42-46

Load image from the internet
----------------------------

We use the standard test image "Shepp–Logan phantom".

.. GENERATED FROM PYTHON SOURCE LINES 46-56

.. code-block:: Python

    # Image size
    img_size = 32
    url = get_image_url("SheppLogan.png")
    # The pixel values of the image are in the range [0, 1].
    x = load_url_image(
        url=url, img_size=img_size, grayscale=True, resize_mode="resize", device=device
    )
    print(x.min(), x.max())

.. rst-class:: sphx-glr-script-out

.. code-block:: none

    tensor(0.) tensor(0.7412)

.. GENERATED FROM PYTHON SOURCE LINES 57-61

Visualization
---------------------------------------

We use the ``plot()`` function from deepinv to visualize the original image.

.. GENERATED FROM PYTHON SOURCE LINES 61-63

.. code-block:: Python

    plot(x, titles="Original image")

.. image-sg:: /auto_examples/basics/images/sphx_glr_demo_phase_retrieval_001.png
   :alt: Original image
   :srcset: /auto_examples/basics/images/sphx_glr_demo_phase_retrieval_001.png
   :class: sphx-glr-single-img

.. GENERATED FROM PYTHON SOURCE LINES 64-67

Signal construction
---------------------------------------

We use the original image as the phase of a complex signal. The pixel values lie
in [0, 1], and we map each value x to the phase pi * x - pi/2, which lies in
[-pi/2, pi/2]. Every entry of the resulting signal, exp(i(pi * x - pi/2)), has
unit modulus.

.. GENERATED FROM PYTHON SOURCE LINES 67-72

.. code-block:: Python

    x_phase = torch.exp(1j * x * torch.pi - 0.5j * torch.pi)

    # Every element of the signal should have unit modulus.
    assert torch.allclose(x_phase.real**2 + x_phase.imag**2, torch.tensor(1.0))

.. GENERATED FROM PYTHON SOURCE LINES 73-78

Measurements generation
---------------------------------------

We create a random phase retrieval operator with an oversampling ratio
(measurements/pixels) of 5.0. The operator returns phaseless measurements of the
form ``y = |B x|^2``, where ``B`` is a random ``m x n`` matrix and ``n`` is the
number of pixels. We then generate measurements from the signal. No noise model
is passed to the operator, so these measurements are noiseless.

.. GENERATED FROM PYTHON SOURCE LINES 78-95

.. code-block:: Python

    # Define physics information
    oversampling_ratio = 5.0
    img_shape = x.shape[1:]
    m = int(oversampling_ratio * torch.prod(torch.tensor(img_shape)))
    n_channels = 1  # 3 for color images, 1 for gray-scale images

    # Create the physics
    physics = dinv.physics.RandomPhaseRetrieval(
        m=m,
        img_shape=img_shape,
        device=device,
    )

    # Generate measurements
    y = physics(x_phase)

.. GENERATED FROM PYTHON SOURCE LINES 96-99

Reconstruction with gradient descent and random initialization
---------------------------------------------------------------

First, we use the class :class:`deepinv.optim.L2` as the data fidelity term. We
use the class :class:`deepinv.optim.optim_iterators.GDIteration` as the iterator.
The prior is ``Zero()``, so this runs plain gradient descent on the data
fidelity. The initial guess is a random complex signal.

.. GENERATED FROM PYTHON SOURCE LINES 99-132

.. code-block:: Python

    data_fidelity = L2()
    prior = Zero()
    iterator = dinv.optim.optim_iterators.GDIteration()
    # Parameters for the optimizer, including stepsize and regularization coefficient.
    optim_params = {"stepsize": 0.06, "lambda": 1.0, "g_param": []}
    num_iter = 1000

    # Initial guess
    x_phase_gd_rand = torch.randn_like(x_phase)

    loss_hist = []

    for _ in range(num_iter):
        res = iterator(
            {"est": (x_phase_gd_rand,), "cost": 0},
            cur_data_fidelity=data_fidelity,
            cur_prior=prior,
            cur_params=optim_params,
            y=y,
            physics=physics,
        )
        x_phase_gd_rand = res["est"][0]
        loss_hist.append(data_fidelity(x_phase_gd_rand, y, physics).cpu())

    print("initial loss:", loss_hist[0])
    print("final loss:", loss_hist[-1])
    # Plot the loss curve
    plt.plot(loss_hist)
    plt.yscale("log")
    plt.title("loss curve (gradient descent with random initialization)")
    plt.show()

.. image-sg:: /auto_examples/basics/images/sphx_glr_demo_phase_retrieval_002.png
   :alt: loss curve (gradient descent with random initialization)
   :srcset: /auto_examples/basics/images/sphx_glr_demo_phase_retrieval_002.png
   :class: sphx-glr-single-img

.. rst-class:: sphx-glr-script-out

.. code-block:: none

    initial loss: tensor([190.4569])
    final loss: tensor([28.0710])

.. GENERATED FROM PYTHON SOURCE LINES 133-139

Phase correction and signal reconstruction
-----------------------------------------------------------

The solution ``x_est`` of the optimization algorithm can only be recovered up to
a global phase. The measurements depend only on the magnitudes of ``B x``, so
``x_est = a * x_phase`` yields exactly the same measurements as ``x_phase`` for
any complex number ``a`` of unit modulus. We therefore use
:func:`deepinv.optim.phase_retrieval.correct_global_phase` to remove the global
phase shift of ``x_est`` relative to ``x_phase``.

We then use ``torch.angle`` to extract the phase. ``torch.angle`` returns values
in [-pi, pi]. After the correction, the phase of a good estimate lies
approximately in [-pi/2, pi/2], the range of the original signal. Dividing by pi
and adding 0.5 therefore maps it back to [0, 1]. Values outside [0, 1] are
clipped when plotting (``rescale_mode="clip"``). We apply the same
post-processing to every reconstruction method below.

.. GENERATED FROM PYTHON SOURCE LINES 139-147

.. code-block:: Python

    # correct possible global phase shifts
    x_gd_rand = correct_global_phase(x_phase_gd_rand, x_phase)
    # extract phase information and normalize to the range [0, 1]
    x_gd_rand = torch.angle(x_gd_rand) / torch.pi + 0.5

    plot([x, x_gd_rand], titles=["Signal", "Reconstruction"], rescale_mode="clip")

.. image-sg:: /auto_examples/basics/images/sphx_glr_demo_phase_retrieval_003.png
   :alt: Signal, Reconstruction
   :srcset: /auto_examples/basics/images/sphx_glr_demo_phase_retrieval_003.png
   :class: sphx-glr-single-img

.. GENERATED FROM PYTHON SOURCE LINES 148-151

Reconstruction with spectral methods
---------------------------------------------------------------

Spectral methods (:func:`deepinv.optim.phase_retrieval.spectral_methods`) provide
a good initial estimate of the original signal. Moreover,
:class:`deepinv.physics.RandomPhaseRetrieval` uses spectral methods as its
default reconstruction method ``A_dagger``, which we can call directly.

.. GENERATED FROM PYTHON SOURCE LINES 151-155

.. code-block:: Python

    # Spectral methods return a tensor with unit norm.
    x_phase_spec = physics.A_dagger(y, n_iter=300)

.. GENERATED FROM PYTHON SOURCE LINES 156-158

Phase correction and signal reconstruction
-----------------------------------------------------------

.. GENERATED FROM PYTHON SOURCE LINES 158-165

.. code-block:: Python

    # correct possible global phase shifts
    x_spec = correct_global_phase(x_phase_spec, x_phase)
    # extract phase information and normalize to the range [0, 1]
    x_spec = torch.angle(x_spec) / torch.pi + 0.5
    plot([x, x_spec], titles=["Signal", "Reconstruction"], rescale_mode="clip")

.. image-sg:: /auto_examples/basics/images/sphx_glr_demo_phase_retrieval_004.png
   :alt: Signal, Reconstruction
   :srcset: /auto_examples/basics/images/sphx_glr_demo_phase_retrieval_004.png
   :class: sphx-glr-single-img

.. GENERATED FROM PYTHON SOURCE LINES 166-169

Reconstruction with gradient descent and spectral methods initialization
-------------------------------------------------------------------------

The estimate from spectral methods can be used directly as the initial guess for
the gradient descent algorithm.

.. GENERATED FROM PYTHON SOURCE LINES 169-193

.. code-block:: Python

    # Initial guess from spectral methods
    x_phase_gd_spec = physics.A_dagger(y, n_iter=300)

    loss_hist = []
    for _ in range(num_iter):
        res = iterator(
            {"est": (x_phase_gd_spec,), "cost": 0},
            cur_data_fidelity=data_fidelity,
            cur_prior=prior,
            cur_params=optim_params,
            y=y,
            physics=physics,
        )
        x_phase_gd_spec = res["est"][0]
        loss_hist.append(data_fidelity(x_phase_gd_spec, y, physics).cpu())

    print("initial loss:", loss_hist[0])
    print("final loss:", loss_hist[-1])
    plt.plot(loss_hist)
    plt.yscale("log")
    plt.title("loss curve (gradient descent with spectral initialization)")
    plt.show()

.. image-sg:: /auto_examples/basics/images/sphx_glr_demo_phase_retrieval_005.png
   :alt: loss curve (gradient descent with spectral initialization)
   :srcset: /auto_examples/basics/images/sphx_glr_demo_phase_retrieval_005.png
   :class: sphx-glr-single-img

.. rst-class:: sphx-glr-script-out

.. code-block:: none

    initial loss: tensor([42.0413])
    final loss: tensor([0.0034])

.. GENERATED FROM PYTHON SOURCE LINES 194-196

Phase correction and signal reconstruction
-----------------------------------------------------------

.. GENERATED FROM PYTHON SOURCE LINES 196-203

.. code-block:: Python

    # correct possible global phase shifts
    x_gd_spec = correct_global_phase(x_phase_gd_spec, x_phase)
    # extract phase information and normalize to the range [0, 1]
    x_gd_spec = torch.angle(x_gd_spec) / torch.pi + 0.5
    plot([x, x_gd_spec], titles=["Signal", "Reconstruction"], rescale_mode="clip")

.. image-sg:: /auto_examples/basics/images/sphx_glr_demo_phase_retrieval_006.png
   :alt: Signal, Reconstruction
   :srcset: /auto_examples/basics/images/sphx_glr_demo_phase_retrieval_006.png
   :class: sphx-glr-single-img

.. GENERATED FROM PYTHON SOURCE LINES 204-207

Reconstruction with gradient descent and PnP denoisers
---------------------------------------------------------------

We can also use the Plug-and-Play (PnP) framework to incorporate a denoiser as a
regularizer in the optimization algorithm. Here we run proximal gradient descent
(``iteration="PGD"``), in which the proximal step of the prior is replaced by a
pretrained deep denoiser (DRUNet) trained on a large dataset of natural images.

.. GENERATED FROM PYTHON SOURCE LINES 207-240

.. code-block:: Python

    # Load the pre-trained denoiser
    denoiser = DRUNet(
        in_channels=n_channels,
        out_channels=n_channels,
        pretrained="download",  # automatically downloads the pretrained weights, set to a path to use custom weights.
        device=device,
    )

    # The original denoiser is designed for real-valued images, so we need to convert
    # it to a complex-valued denoiser for phase retrieval problems.
    denoiser_complex = to_complex_denoiser(denoiser, mode="abs_angle")

    # Algorithm parameters
    data_fidelity = L2()
    prior = PnP(denoiser=denoiser_complex)
    # "g_param" is the noise level passed to the denoiser.
    params_algo = {"stepsize": 0.30, "g_param": 0.04}
    max_iter = 100
    early_stop = True
    verbose = True

    # Instantiate the algorithm class to solve the inverse problem.
    model = optim_builder(
        iteration="PGD",
        prior=prior,
        data_fidelity=data_fidelity,
        early_stop=early_stop,
        max_iter=max_iter,
        verbose=verbose,
        params_algo=params_algo,
    )

    # Run the algorithm
    x_phase_pnp, metrics = model(y, physics, x_gt=x_phase, compute_metrics=True)

.. rst-class:: sphx-glr-script-out

.. code-block:: none

    Downloading: "https://huggingface.co/deepinv/drunet/resolve/main/drunet_deepinv_gray_finetune_26k.pth?download=true" to /home/runner/.cache/torch/hub/checkpoints/drunet_deepinv_gray_finetune_26k.pth
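
The pipeline discussed above covers phaseless measurements, spectral initialization, gradient descent on a squared-error loss, and global phase correction. It can also be illustrated without deepinv. The following is a minimal NumPy sketch of those steps. It is not the deepinv implementation of ``RandomPhaseRetrieval``, ``spectral_methods`` or ``correct_global_phase``. It rests on these assumptions:

* noiseless measurements ``y = |B x|^2``, where ``B`` has unit-variance complex Gaussian entries;
* a Wirtinger-flow-style step size ``mu / ||z0||^2``;
* hypothetical helper names ``measure``, ``spectral_init``, ``gradient_descent`` and ``correct_phase``;
* hypothetical problem sizes, chosen for illustration only.

.. code-block:: python

    # Sketch only: a NumPy illustration, NOT the deepinv implementation.
    import numpy as np


    def measure(B, x):
        """Noiseless phaseless measurements y = |B x|^2."""
        return np.abs(B @ x) ** 2


    def spectral_init(B, y, n_iter=100, seed=0):
        """Power iteration on M = B^H diag(y) B / m, rescaled to the signal energy."""
        m, n = B.shape
        rng = np.random.default_rng(seed)
        z = rng.standard_normal(n) + 1j * rng.standard_normal(n)
        for _ in range(n_iter):
            z = B.conj().T @ (y * (B @ z)) / m  # apply M without forming it
            z /= np.linalg.norm(z)
        # With unit-variance complex Gaussian entries in B, E[y_i] = ||x||^2.
        return z * np.sqrt(y.mean())


    def gradient_descent(B, y, z, n_iter=1000, mu=0.1):
        """Gradient descent on f(z) = sum((|B z|^2 - y)^2) / (4 m)."""
        m = B.shape[0]
        step = mu / np.linalg.norm(z) ** 2  # step scaled by the initial energy
        for _ in range(n_iter):
            Bz = B @ z
            residual = np.abs(Bz) ** 2 - y
            z = z - step * (B.conj().T @ (residual * Bz)) / m
        return z


    def correct_phase(z, x):
        """Multiply z by the unit-modulus scalar that best aligns it with x."""
        inner = np.vdot(z, x)  # z^H x
        return z * (inner / np.abs(inner))


    # Toy problem mirroring the example: pixel values in [0, 1] encoded as phases.
    rng = np.random.default_rng(0)
    n, m = 32, 320  # hypothetical sizes, oversampling ratio 10
    pixels = rng.uniform(0.0, 1.0, n)
    x_phase = np.exp(1j * (np.pi * pixels - np.pi / 2))
    B = (rng.standard_normal((m, n)) + 1j * rng.standard_normal((m, n))) / np.sqrt(2)
    y = measure(B, x_phase)

    # A global phase shift leaves the measurements unchanged.
    same = np.allclose(measure(B, np.exp(0.7j) * x_phase), y)

    z = correct_phase(gradient_descent(B, y, spectral_init(B, y)), x_phase)
    rel_err = np.linalg.norm(z - x_phase) / np.linalg.norm(x_phase)
    pixels_rec = np.angle(z) / np.pi + 0.5  # phase back to [0, 1]
    print(same, rel_err, np.abs(pixels_rec - pixels).max())

The spectral estimate is the leading eigenvector of ``B^H diag(y) B / m``, computed here by power iteration. It only serves as a starting point, and the gradient steps then refine it. The measurements are invariant to a global phase, so the error is only meaningful after ``correct_phase``, which plays the same role as ``correct_global_phase`` in the example. The oversampling ratio of 10 (rather than 5.0) is an assumption intended to keep this small sketch well conditioned.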