.. DO NOT EDIT.
.. THIS FILE WAS AUTOMATICALLY GENERATED BY SPHINX-GALLERY.
.. TO MAKE CHANGES, EDIT THE SOURCE PYTHON FILE:
.. "auto_examples/external-libraries/demo_connect_spyrit.py"

.. only:: html

    .. note::
        :class: sphx-glr-download-link-note

        New to DeepInverse? Get started with the basics with the :ref:`5 minute quickstart tutorial `.

.. rst-class:: sphx-glr-example-title

.. _sphx_glr_auto_examples_external-libraries_demo_connect_spyrit.py:


Single-pixel imaging with Spyrit
====================================================================================================

This example shows how to use Spyrit linear models and measurements with
DeepInverse. Here we use the ``HadamSplit2d`` linear model from Spyrit.

Load images
-----------------------------------------------------------------------------

We start by loading the butterfly image using :func:`deepinv.utils.load_example`:

.. code-block:: Python

    import torch

    import deepinv as dinv
    from deepinv.utils import plot

    # use the least-loaded GPU if one is available, otherwise the CPU
    device = dinv.utils.get_freer_gpu() if torch.cuda.is_available() else "cpu"

    im_size = 64
    x = dinv.utils.load_example(
        "butterfly.png", device=device, img_size=(im_size, im_size), grayscale=True
    )
    print(f"Ground-truth image: {x.shape}")

.. rst-class:: sphx-glr-script-out

.. code-block:: none

    Ground-truth image: torch.Size([1, 1, 64, 64])

Then we plot it:

.. code-block:: Python

    plot(x, r"$64\times 64$ image $X$")

.. image-sg:: /auto_examples/external-libraries/images/sphx_glr_demo_connect_spyrit_001.png
   :alt: $64\times 64$ image $X$
   :srcset: /auto_examples/external-libraries/images/sphx_glr_demo_connect_spyrit_001.png
   :class: sphx-glr-single-img
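
A single-pixel camera can only display nonnegative patterns, which is why Spyrit's ``HadamSplit2d`` works with *split* measurements: each Hadamard row :math:`h` is split into :math:`h^+ = \max(h, 0)` and :math:`h^- = \max(-h, 0)`, both patterns are measured, and the Hadamard coefficient is recovered as their difference. The following is a minimal sketch of that idea in plain PyTorch on a toy 8-pixel 1D signal. It is an illustration under these assumptions, not Spyrit's implementation; in particular, Spyrit's 2D transform and its ordering of the split measurements are not reproduced, and all names are hypothetical.

.. code-block:: Python

    import torch

    # 8 x 8 Hadamard matrix with +/-1 entries (Sylvester construction)
    base = torch.tensor([[1.0, 1.0], [1.0, -1.0]], dtype=torch.float64)
    H8 = torch.ones(1, 1, dtype=torch.float64)
    for _ in range(3):
        H8 = torch.kron(base, H8)

    # split into two nonnegative pattern matrices such that H8 = H_pos - H_neg
    H_pos = torch.clamp(H8, min=0)
    H_neg = torch.clamp(-H8, min=0)

    # toy nonnegative 1D "image" with 8 pixels
    x_toy = torch.rand(8, dtype=torch.float64, generator=torch.Generator().manual_seed(0))

    # the camera can only display nonnegative patterns, so both halves are measured
    y_pos = H_pos @ x_toy
    y_neg = H_neg @ x_toy

    # "unsplitting": the difference gives back the Hadamard coefficients H8 @ x_toy
    y_toy = y_pos - y_neg
    print(torch.allclose(y_toy, H8 @ x_toy))  # True

    # with all 8 coefficients, H8^T H8 = 8 I gives an exact inverse
    x_rec = H8.T @ y_toy / 8
    print(torch.allclose(x_rec, x_toy))  # True

In the Spyrit example, the unsplitting step corresponds to the ``UnsplitRescale`` preprocessing applied to the raw measurements.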

Basic example
-----------------------------------------------------------------------------

We instantiate a ``HadamSplit2d`` object and simulate 512 measurements of the
2D Hadamard transform of the input image, out of the
:math:`64 \times 64 = 4096` available coefficients. Setting
``reshape_output=True`` is necessary for the output to be compatible with
DeepInverse. The measurements are noiseless in this example.

.. code-block:: Python

    from spyrit.core.meas import HadamSplit2d
    from spyrit.core.prep import UnsplitRescale

    # Hadamard split measurements of the 64 x 64 image, keeping 512 coefficients
    physics_spyrit = HadamSplit2d(im_size, 512, device=device, reshape_output=True)
    y_spyrit = physics_spyrit(x)

    # preprocessing: combine the positive and negative split measurements
    # into Hadamard coefficients (alpha=1.0 means no extra rescaling)
    prep = UnsplitRescale(alpha=1.0)
    y_spyrit = prep(y_spyrit)

    print(y_spyrit.shape)

.. rst-class:: sphx-glr-script-out

.. code-block:: none

    torch.Size([1, 1, 512])

Before wrapping the operator for DeepInverse, we compute its norm, i.e. the
largest singular value (spectral norm) of the measurement matrix
``physics_spyrit.H``; dividing by this value gives an operator with unit norm.
Since the rows of a Hadamard matrix are orthogonal, each with squared norm equal
to the number of pixels, the expected value is :math:`\sqrt{4096} = 64`, which
matches the printed result up to floating-point error.

.. code-block:: Python

    # spectral norm (largest singular value) of the 512 x 4096 measurement matrix
    norm = torch.linalg.norm(physics_spyrit.H, ord=2)
    print(norm)

.. rst-class:: sphx-glr-script-out

.. code-block:: none

    tensor(64.0006)

Forward operator
----------------------------------------------------------------------

The Spyrit forward operator and its adjoint can be passed directly to
:class:`deepinv.physics.LinearPhysics`, both divided by the norm computed
above. Noise can be added either with a DeepInverse noise model or with a
Spyrit noise model.

.. code-block:: Python

    # wrap the normalized Spyrit operator as a DeepInverse linear physics
    physics_deepinv = dinv.physics.LinearPhysics(
        lambda x: physics_spyrit.measure_H(x) / norm,
        # unvectorize reshapes the adjoint's output back into an image
        A_adjoint=lambda y: physics_spyrit.unvectorize(physics_spyrit.adjoint_H(y) / norm),
    )
    y_deepinv = physics_deepinv(x)

    # compare with the (normalized) Spyrit measurements computed earlier
    print("diff:", torch.linalg.norm(y_spyrit / norm - y_deepinv))

.. rst-class:: sphx-glr-script-out

.. code-block:: none

    diff: tensor(2.4486e-05)

The difference is negligible (floating-point rounding), so both operators
produce the same measurements.

Computing the reconstructions
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
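
The solvers below use a step size of 1, which relies on the operator having unit spectral norm. Here ``physics_spyrit.H`` was small enough to compute its norm exactly with ``torch.linalg.norm``; for larger operators that are only available as functions, the norm can instead be estimated by power iteration on :math:`A^\top A` (DeepInverse's ``LinearPhysics.compute_norm`` method is based on the same idea). The following is a minimal plain-PyTorch sketch, assuming a subsampled Sylvester Hadamard matrix as a small stand-in for the Spyrit operator; ``sylvester_hadamard`` and ``estimate_spectral_norm`` are hypothetical helpers, not DeepInverse or Spyrit functions.

.. code-block:: Python

    import torch


    def sylvester_hadamard(n):
        """Hypothetical helper: Hadamard matrix of order n (a power of 2), Sylvester construction."""
        base = torch.tensor([[1.0, 1.0], [1.0, -1.0]], dtype=torch.float64)
        H = torch.ones(1, 1, dtype=torch.float64)
        while H.shape[0] < n:
            H = torch.kron(base, H)
        return H


    def estimate_spectral_norm(A, A_adjoint, x_shape, n_iter=50, seed=0):
        """Hypothetical helper: estimate the largest singular value of A.

        Runs power iteration on A^T A using only the callables A and A_adjoint,
        so the matrix never needs to be formed explicitly.
        """
        gen = torch.Generator().manual_seed(seed)
        v = torch.randn(x_shape, generator=gen, dtype=torch.float64)
        v = v / torch.linalg.norm(v)
        for _ in range(n_iter):
            w = A_adjoint(A(v))  # apply A^T A
            v = w / torch.linalg.norm(w)
        # Rayleigh quotient of the unit iterate approximates the top eigenvalue of A^T A
        top_eig = torch.dot(v, A_adjoint(A(v)))
        # the largest singular value of A is its square root
        return torch.sqrt(top_eig)


    # stand-in operator: the first 64 rows (an arbitrary choice) of a 256 x 256
    # Hadamard matrix, i.e. a subsampled 2D Hadamard transform of a 16 x 16 image
    H_sub = sylvester_hadamard(256)[:64]

    norm_est = estimate_spectral_norm(lambda v: H_sub @ v, lambda y: H_sub.T @ y, (256,))
    norm_exact = torch.linalg.norm(H_sub, ord=2)
    print(norm_est.item(), norm_exact.item())  # both close to sqrt(256) = 16

As in the Spyrit example, the orthogonal Hadamard rows make every nonzero singular value equal to the square root of the number of pixels, so power iteration converges after very few iterations here; general operators may need more.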

All of the usual solvers work out of the box, and we showcase some of them
here, starting with simple linear reconstructions using
:meth:`deepinv.physics.LinearPhysics.A_adjoint` and
:meth:`deepinv.physics.LinearPhysics.A_dagger`:

.. code-block:: Python

    # the measurements are divided by the same norm as the operator
    x_adj = physics_deepinv.A_adjoint(y_spyrit / norm)
    x_pinv = physics_deepinv.A_dagger(y_spyrit / norm)

You can also use optimization-based methods from DeepInverse. Here, we use
Total Variation (TV) regularization with a proximal gradient descent (PGD)
algorithm. The ``custom_init`` parameter initializes the algorithm with the
pseudo-inverse reconstruction :math:`A^\dagger y` rather than the default
initialization.

.. code-block:: Python

    model_tv = dinv.optim.optim_builder(
        iteration="PGD",
        prior=dinv.optim.TVPrior(),
        data_fidelity=dinv.optim.L2(),
        # a step size of 1 is suitable because the operator has unit norm;
        # lambda weights the TV regularization
        params_algo={"stepsize": 1, "lambda": 5e-2},
        max_iter=10,
        # start from the pseudo-inverse instead of the default initialization
        custom_init=lambda y, Physics: {"est": (Physics.A_dagger(y),)},
    )
    # compute_metrics=True also returns metrics computed against x_gt
    x_tv, metrics_TV = model_tv(
        y_spyrit / norm, physics_deepinv, compute_metrics=True, x_gt=x
    )

Deep learning methods work as well. Here we use DPIR, a plug-and-play method
with a DRUNet denoiser, again initialized with the pseudo-inverse:

.. code-block:: Python

    # DRUNet denoiser for single-channel (grayscale) images
    denoiser = dinv.models.DRUNet(in_channels=1, out_channels=1, device=device)
    model_dpir = dinv.optim.DPIR(sigma=1e-1, device=device, denoiser=denoiser)
    # start from the pseudo-inverse, as for TV
    model_dpir.custom_init = lambda y, Physics: {"est": (Physics.A_dagger(y),)}
    with torch.no_grad():
        x_dpir = model_dpir(y_spyrit / norm, physics_deepinv)

Finally, we reconstruct with the pretrained :class:`deepinv.models.RAM` model
and compare all reconstructions using PSNR:

.. code-block:: Python

    model_ram = dinv.models.RAM(pretrained=True, device=device)
    model_ram.sigma_threshold = 1e-1
    with torch.no_grad():
        x_ram = model_ram(y_spyrit / norm, physics_deepinv)

    # PSNR of each reconstruction with respect to the ground truth
    metric = dinv.metric.PSNR()
    psnr_pinv = metric(x_pinv, x).item()
    psnr_tv = metric(x_tv, x).item()
    psnr_dpir = metric(x_dpir, x).item()
    psnr_ram = metric(x_ram, x).item()

    dinv.utils.plot(
        {
            "Ground Truth": x,
            "Pseudo-Inverse": x_pinv,
            "TV": x_tv,
            "DPIR": x_dpir,
            "RAM": x_ram,
        },
        subtitles=[
            "PSNR (dB):",
            f"{psnr_pinv:.2f}",
            f"{psnr_tv:.2f}",
            f"{psnr_dpir:.2f}",
            f"{psnr_ram:.2f}",
        ],
    )

.. image-sg:: /auto_examples/external-libraries/images/sphx_glr_demo_connect_spyrit_002.png
   :alt: Ground Truth, Pseudo-Inverse, TV, DPIR, RAM
   :srcset: /auto_examples/external-libraries/images/sphx_glr_demo_connect_spyrit_002.png
   :class: sphx-glr-single-img
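
Because ``physics_deepinv`` is assembled from two independent lambdas, nothing guarantees that ``A_adjoint`` is really the adjoint of the forward map; forgetting the ``/ norm`` in one of them, for instance, would make the gradient-based solvers above inconsistent. A standard sanity check is the dot-product test :math:`\langle A x, y \rangle = \langle x, A^\top y \rangle` for random :math:`x` and :math:`y`. DeepInverse's ``LinearPhysics`` provides an ``adjointness_test`` method for this purpose; the following is a minimal plain-PyTorch sketch of the idea, using a random matrix as a stand-in for the Spyrit operator. ``adjointness_error`` is a hypothetical helper, not a DeepInverse or Spyrit function.

.. code-block:: Python

    import torch


    def adjointness_error(A, A_adjoint, x_shape, y_shape, seed=0):
        """Hypothetical helper: relative error of the dot-product test.

        For a correct adjoint, <A x, y> equals <x, A^T y> up to rounding.
        """
        gen = torch.Generator().manual_seed(seed)
        x = torch.randn(x_shape, generator=gen, dtype=torch.float64)
        y = torch.randn(y_shape, generator=gen, dtype=torch.float64)
        lhs = torch.sum(A(x) * y)  # <A x, y>
        rhs = torch.sum(x * A_adjoint(y))  # <x, A^T y>
        return (torch.abs(lhs - rhs) / torch.abs(lhs)).item()


    # stand-in operator: a random 32 x 128 matrix, normalized as in this example
    gen = torch.Generator().manual_seed(1)
    M_toy = torch.randn(32, 128, generator=gen, dtype=torch.float64)
    scale = torch.linalg.norm(M_toy, ord=2)


    def forward(v):
        return M_toy @ v / scale


    def good_adjoint(w):
        return M_toy.T @ w / scale


    def bad_adjoint(w):
        return M_toy.T @ w  # bug: forgot to divide by the norm


    err_good = adjointness_error(forward, good_adjoint, (128,), (32,))
    err_bad = adjointness_error(forward, bad_adjoint, (128,), (32,))
    print(err_good)  # at the level of float64 rounding
    print(err_bad)  # equals scale - 1, far from zero

For the wrapped operator of this example, the same check could be run by passing ``physics_deepinv.A`` and ``physics_deepinv.A_adjoint`` with the image and measurement shapes printed above, after matching the dtype and device of the random inputs to those of ``x``.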