.. DO NOT EDIT.
.. THIS FILE WAS AUTOMATICALLY GENERATED BY SPHINX-GALLERY.
.. TO MAKE CHANGES, EDIT THE SOURCE PYTHON FILE:
.. "auto_examples/basics/demo_ptychography.py"
.. LINE NUMBERS ARE GIVEN BELOW.

.. only:: html

    .. note::
        :class: sphx-glr-download-link-note

        :ref:`Go to the end <sphx_glr_download_auto_examples_basics_demo_ptychography.py>`
        to download the full example code.

.. rst-class:: sphx-glr-example-title

.. _sphx_glr_auto_examples_basics_demo_ptychography.py:


Ptychography phase retrieval
============================

This example shows how to create a ptychography phase retrieval operator and
generate phaseless measurements from a given image.

.. GENERATED FROM PYTHON SOURCE LINES 9-13

General setup
-------------
We import the necessary libraries and modules, including the ``Ptychography``
physics class from `deepinv`, and select a GPU if one is available, falling
back to the CPU otherwise.

.. GENERATED FROM PYTHON SOURCE LINES 13-26

.. code-block:: Python


    import matplotlib.pyplot as plt
    import torch
    import numpy as np
    import deepinv as dinv
    from deepinv.utils.demo import load_url_image, get_image_url
    from deepinv.utils.plotting import plot
    from deepinv.physics import Ptychography
    from deepinv.optim.data_fidelity import L1
    from deepinv.optim.phase_retrieval import correct_global_phase

    device = dinv.utils.get_freer_gpu() if torch.cuda.is_available() else "cpu"


.. GENERATED FROM PYTHON SOURCE LINES 27-30

Load image from the internet
----------------------------
We load a sample image from a URL, resize it to 128x128 pixels, and keep only
the first color channel.

.. GENERATED FROM PYTHON SOURCE LINES 30-40

.. code-block:: Python


    size = 128
    url = get_image_url("CBSD_0010.png")
    image = load_url_image(url, grayscale=False, img_size=(size, size))

    x = image[:, 0, ...].unsqueeze(1)  # Take only one channel
    print(x.shape)
    plot([x], figsize=(10, 10))


.. image-sg:: /auto_examples/basics/images/sphx_glr_demo_ptychography_001.png
   :alt: demo ptychography
   :srcset: /auto_examples/basics/images/sphx_glr_demo_ptychography_001.png
   :class: sphx-glr-single-img


.. rst-class:: sphx-glr-script-out

 .. code-block:: none

    torch.Size([1, 1, 128, 128])


.. GENERATED FROM PYTHON SOURCE LINES 41-44

Prepare phase input
-------------------
We use the original image as the phase of a complex signal. The image values
lie in [0, 1] and are mapped to the phase range [0, pi], so the resulting
signal has unit modulus everywhere.

.. GENERATED FROM PYTHON SOURCE LINES 44-49

.. code-block:: Python


    phase = x / x.max() * np.pi  # between 0 and pi
    input = torch.exp(1j * phase.to(torch.complex64)).to(device)


.. GENERATED FROM PYTHON SOURCE LINES 50-54

Set up ptychography physics model
---------------------------------
We initialize the ptychography physics model with a disk-shaped probe
(``probe_radius=30``) and ``n_img = 100`` probe shifts. This model is used
below to simulate the ptychography measurements.

.. GENERATED FROM PYTHON SOURCE LINES 54-69

.. code-block:: Python


    img_size = (1, size, size)
    n_img = 10**2
    probe = dinv.physics.phase_retrieval.build_probe(
        img_size, type="disk", probe_radius=30, device=device
    )
    shifts = dinv.physics.phase_retrieval.generate_shifts(img_size, n_img=n_img, fov=170)

    physics = Ptychography(
        in_shape=img_size,
        probe=probe,
        shifts=shifts,
        device=device,
    )


.. GENERATED FROM PYTHON SOURCE LINES 70-73

Display probe overlap
---------------------
We compute and display the overlap of the probe regions in the image, which
helps visualize the ptychography scanning pattern. The first panel shows the
overlap for two consecutive shifts (indices 55 and 56), the second panel the
overlap over all shifts.

.. GENERATED FROM PYTHON SOURCE LINES 73-82

.. code-block:: Python


    overlap_img = physics.B.get_overlap_img(physics.B.shifts).cpu()
    overlap2probe = physics.B.get_overlap_img(physics.B.shifts[55:57]).cpu()
    plot(
        [overlap2probe.unsqueeze(0), overlap_img.unsqueeze(0)],
        titles=["Overlap 2 probe", "Overlap images"],
    )


.. image-sg:: /auto_examples/basics/images/sphx_glr_demo_ptychography_002.png
   :alt: Overlap 2 probe, Overlap images
   :srcset: /auto_examples/basics/images/sphx_glr_demo_ptychography_002.png
   :class: sphx-glr-single-img


.. GENERATED FROM PYTHON SOURCE LINES 83-86

Generate and visualize probe and measurements
---------------------------------------------
We generate the measurements and display the magnitude of the probe at one
shift (index 55) next to the logarithm of the summed measurement data.

.. GENERATED FROM PYTHON SOURCE LINES 86-95

.. code-block:: Python


    probe = physics.probe[:, 55].cpu()
    y = physics(input)
    plot(
        [torch.abs(probe), y[0].sum(dim=0).log().unsqueeze(0)],
        titles=["Probe", "y"],
    )


.. image-sg:: /auto_examples/basics/images/sphx_glr_demo_ptychography_003.png
   :alt: Probe, y
   :srcset: /auto_examples/basics/images/sphx_glr_demo_ptychography_003.png
   :class: sphx-glr-single-img


.. GENERATED FROM PYTHON SOURCE LINES 96-99

Gradient descent for phase retrieval
------------------------------------
We run a simple gradient descent algorithm that minimizes the L1 data fidelity
loss for phase retrieval, using a step size of 0.1, 200 iterations, and a
random initialization. The loss is not monotone: in the run logged below it
decreases until about iteration 110 and then rises again, fluctuating between
roughly 2900 and 4300 for the remaining iterations.

.. GENERATED FROM PYTHON SOURCE LINES 99-117

.. code-block:: Python


    data_fidelity = L1()
    lr = 0.1
    n_iter = 200
    x_est = torch.randn_like(x).to(device)
    loss_hist = []

    for i in range(n_iter):
        x_est = x_est - lr * data_fidelity.grad(x_est, y, physics)
        loss_hist.append(data_fidelity(x_est, y, physics).cpu())
        if i % 10 == 0:
            print(f"Iter {i}, loss: {loss_hist[i]}")

    # Plot the loss curve
    plt.plot(loss_hist)
    plt.title("loss curve (gradient descent with random initialization)")
    plt.show()


.. image-sg:: /auto_examples/basics/images/sphx_glr_demo_ptychography_004.png
   :alt: loss curve (gradient descent with random initialization)
   :srcset: /auto_examples/basics/images/sphx_glr_demo_ptychography_004.png
   :class: sphx-glr-single-img


.. rst-class:: sphx-glr-script-out

 .. code-block:: none

    Iter 0, loss: tensor([24033.8477])
    Iter 10, loss: tensor([14610.7012])
    Iter 20, loss: tensor([12231.4805])
    Iter 30, loss: tensor([9531.4746])
    Iter 40, loss: tensor([8343.7061])
    Iter 50, loss: tensor([7400.0605])
    Iter 60, loss: tensor([6179.6592])
    Iter 70, loss: tensor([4710.7944])
    Iter 80, loss: tensor([3487.5955])
    Iter 90, loss: tensor([2489.2402])
    Iter 100, loss: tensor([1696.7488])
    Iter 110, loss: tensor([1269.8395])
    Iter 120, loss: tensor([2899.7510])
    Iter 130, loss: tensor([4312.7715])
    Iter 140, loss: tensor([3467.2544])
    Iter 150, loss: tensor([3429.9648])
    Iter 160, loss: tensor([3731.9387])
    Iter 170, loss: tensor([3542.7930])
    Iter 180, loss: tensor([4249.0420])
    Iter 190, loss: tensor([3322.3809])


.. GENERATED FROM PYTHON SOURCE LINES 118-122

Display final estimated phase retrieval
---------------------------------------
We correct the global phase of the estimate so that it matches the original
signal, then plot the result. The final figure shows the original image next
to the phase of the corrected estimate.

.. GENERATED FROM PYTHON SOURCE LINES 122-126

.. code-block:: Python


    x_est = x_est.detach().cpu()
    final_est = correct_global_phase(x_est, x)
    plot([x, torch.angle(final_est)], titles=["x", "Final estimate"])


.. image-sg:: /auto_examples/basics/images/sphx_glr_demo_ptychography_005.png
   :alt: x, Final estimate
   :srcset: /auto_examples/basics/images/sphx_glr_demo_ptychography_005.png
   :class: sphx-glr-single-img


.. rst-class:: sphx-glr-timing

   **Total running time of the script:** (0 minutes 7.218 seconds)


.. _sphx_glr_download_auto_examples_basics_demo_ptychography.py:

.. only:: html

  .. container:: sphx-glr-footer sphx-glr-footer-example

    .. container:: sphx-glr-download sphx-glr-download-jupyter

      :download:`Download Jupyter notebook: demo_ptychography.ipynb <demo_ptychography.ipynb>`

    .. container:: sphx-glr-download sphx-glr-download-python

      :download:`Download Python source code: demo_ptychography.py <demo_ptychography.py>`

    .. container:: sphx-glr-download sphx-glr-download-zip

      :download:`Download zipped: demo_ptychography.zip <demo_ptychography.zip>`


.. only:: html

 .. rst-class:: sphx-glr-signature

    `Gallery generated by Sphinx-Gallery <https://sphinx-gallery.github.io>`_
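
The global phase correction in the final step is needed because phaseless
measurements cannot distinguish a signal ``x`` from ``c * x`` for any
unit-modulus constant ``c``: every linear operator ``B`` satisfies
``B(c * x) = c * B(x)``, so ``|B(c * x)|**2`` equals ``|B(x)|**2``. The sketch
below illustrates one common way to remove this ambiguity, namely rotating the
estimate by the unit-modulus factor that minimizes its Euclidean distance to a
reference. It uses only the standard library on a short toy signal.
``align_global_phase`` is a hypothetical helper written for illustration; it
is an assumption, not a description of how ``correct_global_phase`` in
`deepinv` is implemented.

.. code-block:: Python

    import cmath


    def align_global_phase(x_est, x_ref):
        """Rotate x_est by the unit-modulus factor that best matches x_ref.

        Hypothetical helper for illustration only (not deepinv's API).
        """
        # Inner product <x_est, x_ref> = sum_k conj(x_est[k]) * x_ref[k].
        inner = sum(e.conjugate() * r for e, r in zip(x_est, x_ref))
        if abs(inner) == 0:
            # The signals are orthogonal, so no rotation is preferred.
            return list(x_est)
        # The minimizer of ||c * x_est - x_ref|| over |c| = 1 is inner / |inner|.
        rotation = inner / abs(inner)
        return [rotation * e for e in x_est]


    # Toy unit-modulus reference signal, analogous to exp(1j * phase) above.
    x_ref = [cmath.exp(1j * p) for p in (0.0, 0.5, 1.0, 2.0)]

    # The same signal with an unknown global phase offset of 0.7 rad.
    offset = cmath.exp(1j * 0.7)
    x_est = [offset * v for v in x_ref]

    # The offset leaves every magnitude unchanged.
    magnitudes_match = all(
        abs(abs(a) - abs(b)) < 1e-12 for a, b in zip(x_est, x_ref)
    )

    # After alignment, the estimate coincides with the reference.
    x_aligned = align_global_phase(x_est, x_ref)
    max_error = max(abs(a - b) for a, b in zip(x_aligned, x_ref))

In this toy case the estimate differs from the reference only by the global
offset, so ``max_error`` is zero up to floating-point rounding. For an
estimate that contains other errors, the same rotation only removes the
global-phase component of the mismatch.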