.. DO NOT EDIT.
.. THIS FILE WAS AUTOMATICALLY GENERATED BY SPHINX-GALLERY.
.. TO MAKE CHANGES, EDIT THE SOURCE PYTHON FILE:
.. "auto_examples/basics/demo_denoiser_tour.py"
.. LINE NUMBERS ARE GIVEN BELOW.

.. only:: html

    .. note::
        :class: sphx-glr-download-link-note

        :ref:`Go to the end <sphx_glr_download_auto_examples_basics_demo_denoiser_tour.py>`
        to download the full example code.

.. rst-class:: sphx-glr-example-title

.. _sphx_glr_auto_examples_basics_demo_denoiser_tour.py:


A tour of DeepInv's denoisers
===================================================

This example provides a tour of the denoisers in DeepInv.
A denoiser is a model that takes in a noisy image and outputs a denoised version of it.
Basically, it solves the following problem:

.. math::

    \underset{x}{\min}\|x - \denoiser{x + \sigma \epsilon}{\sigma}\|_2^2.

The denoisers in DeepInv come in different flavors, depending on whether they are derived from
analytical image processing techniques or learned from data.
This example shows how to use the different denoisers in DeepInv, compares their performance,
and highlights the tradeoffs they offer.

.. GENERATED FROM PYTHON SOURCE LINES 18-29

.. code-block:: Python


    import time

    import torch
    import pandas as pd
    import matplotlib.pyplot as plt

    import deepinv as dinv
    from deepinv.utils.plotting import plot_inset
    from deepinv.utils.demo import load_url_image, get_image_url


.. GENERATED FROM PYTHON SOURCE LINES 30-34

Load test images
----------------

First, let's load a test image to illustrate the denoisers.

.. GENERATED FROM PYTHON SOURCE LINES 34-52

.. code-block:: Python


    dtype = torch.float32
    device = "cpu"
    img_size = (173, 125)

    url = get_image_url("CBSD_0010.png")
    image = load_url_image(
        url, grayscale=False, device=device, dtype=dtype, img_size=img_size
    )

    # Next, set the global random seed from PyTorch to ensure reproducibility of the example.
    torch.manual_seed(0)
    torch.cuda.manual_seed(0)

    # Finally, create a noisy version of the image with a fixed noise level sigma.
    sigma = 0.2
    noisy_image = image + sigma * torch.randn_like(image)


.. GENERATED FROM PYTHON SOURCE LINES 53-55

For this tour, we define a helper function that displays a comparison of several restored images,
with their PSNR values and a zoom-in on a region of interest.

.. GENERATED FROM PYTHON SOURCE LINES 55-89

.. code-block:: Python



    def show_image_comparison(images, suptitle=None, ref=None):
        """Display restored images with their PSNR and a zoom-in."""

        titles = list(images.keys())
        if "Original" in images or ref is not None:
            # If a reference image is available, add the PSNR to the titles.
            image = images["Original"] if "Original" in images else ref
            psnr_values = [
                dinv.metric.cal_psnr(image, im).item() for im in images.values()
            ]
            titles = [
                f"{name} \n (PSNR: {value:.2f})" if name != "Original" else name
                for name, value in zip(images.keys(), psnr_values)
            ]
        # Plot the images with zoom-in
        fig = plot_inset(
            list(images.values()),
            titles=titles,
            extract_size=0.2,
            extract_loc=(0.5, 0.0),
            inset_size=0.5,
            return_fig=True,
            show=False,
            figsize=(len(images) * 1.5, 2.5),
        )

        # Add a suptitle if it is provided
        if suptitle:
            plt.suptitle(suptitle, size=12)
        plt.tight_layout()
        fig.subplots_adjust(top=0.85, bottom=0.02, left=0.02, right=0.95)
        plt.show()


.. GENERATED FROM PYTHON SOURCE LINES 90-101

We are now ready to explore the different denoisers.

Classical Denoisers
-------------------

DeepInv provides a set of classical denoisers such as :class:`deepinv.models.BM3D`,
:class:`deepinv.models.TGVDenoiser`, or :class:`deepinv.models.WaveletDictDenoiser`.

They are used by instantiating the corresponding class and calling it with the noisy image
and the noise level.

.. GENERATED FROM PYTHON SOURCE LINES 101-114

.. code-block:: Python

    bm3d = dinv.models.BM3D()
    tgv = dinv.models.TGVDenoiser()
    wavelet = dinv.models.WaveletDictDenoiser()

    denoiser_results = {
        "Original": image,
        "Noisy": noisy_image,
        "BM3D": bm3d(noisy_image, sigma),
        "TGV": tgv(noisy_image, sigma),
        "Wavelet": wavelet(noisy_image, sigma),
    }
    show_image_comparison(denoiser_results, suptitle=rf"Noise level $\sigma={sigma:.2f}$")



.. image-sg:: /auto_examples/basics/images/sphx_glr_demo_denoiser_tour_001.png
   :alt: Noise level $\sigma=0.20$, Original, Noisy (PSNR: 13.99), BM3D (PSNR: 23.86), TGV (PSNR: 15.75), Wavelet (PSNR: 20.38)
   :srcset: /auto_examples/basics/images/sphx_glr_demo_denoiser_tour_001.png
   :class: sphx-glr-single-img


.. GENERATED FROM PYTHON SOURCE LINES 115-123

Deep Denoisers
--------------

DeepInv also provides a set of deep denoisers.
Most of these denoisers are available with pretrained weights, so they can be used readily.
To instantiate them, you can simply call their corresponding class with default
parameters and ``pretrained="download"`` to load their weights.
You can then apply them by calling the model with the noisy image and the noise level.

.. GENERATED FROM PYTHON SOURCE LINES 123-138

.. code-block:: Python

    dncnn = dinv.models.DnCNN()
    drunet = dinv.models.DRUNet()
    swinir = dinv.models.SwinIR()
    scunet = dinv.models.SCUNet()

    denoiser_results = {
        "Original": image,
        "Noisy": noisy_image,
        "DnCNN": dncnn(noisy_image, sigma),
        "DRUNet": drunet(noisy_image, sigma),
        "SCUNet": scunet(noisy_image, sigma),
        "SwinIR": swinir(noisy_image, sigma),
    }
    show_image_comparison(denoiser_results, suptitle=rf"Noise level $\sigma={sigma:.2f}$")



.. image-sg:: /auto_examples/basics/images/sphx_glr_demo_denoiser_tour_002.png
   :alt: Noise level $\sigma=0.20$, Original, Noisy (PSNR: 13.99), DnCNN (PSNR: 13.99), DRUNet (PSNR: 27.13), SCUNet (PSNR: 24.37), SwinIR (PSNR: 15.53)
   :srcset: /auto_examples/basics/images/sphx_glr_demo_denoiser_tour_002.png
   :class: sphx-glr-single-img


.. rst-class:: sphx-glr-script-out

 .. code-block:: none

    Downloading: "https://huggingface.co/deepinv/dncnn/resolve/main/dncnn_sigma2_color.pth?download=true" to /home/runner/.cache/torch/hub/checkpoints/dncnn_sigma2_color.pth
    0%| | 0.00/2.56M [00:00

` for more details on the chosen noise level.
A way to improve the performance of these models is to artificially rescale the input image to match
the training noise level.
We can define a wrapper that automatically applies this rescaling.

.. GENERATED FROM PYTHON SOURCE LINES 320-373

.. code-block:: Python



    class AdaptedDenoiser:
        r"""
        This wrapper rescales the input image to match the noise level of the model,
        applies the denoiser, and then rescales the output to the original noise level.
        """

        def __init__(self, model, sigma_train):
            self.model = model
            self.sigma_train = sigma_train

        def __call__(self, image, sigma):
            if isinstance(sigma, torch.Tensor):
                # If sigma is a tensor, we assume it is one value per element in the batch
                assert len(sigma) == image.shape[0]
                sigma = sigma[:, None, None, None]

            # Rescale the input so that its noise level matches the training noise level
            rescaled_image = image / sigma * self.sigma_train
            with torch.no_grad():
                output = self.model(rescaled_image, self.sigma_train)
            # Rescale the output back to the original intensity range
            output = output * sigma / self.sigma_train
            return output


    # Apply to DnCNN
    sigma_train_dncnn = 2.0 / 255.0
    adapted_dncnn = AdaptedDenoiser(dncnn, sigma_train_dncnn)

    # Apply SwinIR
    # sigma_train_swinir = 15.0 / 255.0
    # adapted_swinir = AdaptedDenoiser(swinir, sigma_train_swinir)

    # sphinx_gallery_multi_image = "single"

    denoiser_results = {
        "Original": image,
        "Noisy": noisy_image,
        "DnCNN": dncnn(noisy_image, sigma),
        "DnCNN (adapted)": adapted_dncnn(noisy_image, sigma),
    }
    show_image_comparison(denoiser_results, suptitle=rf"Noise level $\sigma={sigma:.2f}$")

    denoiser_results = {
        # Skipping SwinIR on CI due to high memory usage
        # "SwinIR": swinir(noisy_image, sigma),
        # "SwinIR (adapted)": adapted_swinir(noisy_image, sigma),
        "DRUNet": drunet(noisy_image, sigma),
        "SCUNet": scunet(noisy_image, sigma),
    }
    show_image_comparison(
        denoiser_results, ref=image, suptitle=rf"Noise level $\sigma={sigma:.2f}$"
    )



.. rst-class:: sphx-glr-horizontal


    *

      .. image-sg:: /auto_examples/basics/images/sphx_glr_demo_denoiser_tour_006.png
         :alt: Noise level $\sigma=0.20$, Original, Noisy (PSNR: 13.99), DnCNN (PSNR: 13.99), DnCNN (adapted) (PSNR: 24.79)
         :srcset: /auto_examples/basics/images/sphx_glr_demo_denoiser_tour_006.png
         :class: sphx-glr-single-img

    *

      .. image-sg:: /auto_examples/basics/images/sphx_glr_demo_denoiser_tour_007.png
         :alt: Noise level $\sigma=0.20$, DRUNet (PSNR: 27.13), SCUNet (PSNR: 24.37)
         :srcset: /auto_examples/basics/images/sphx_glr_demo_denoiser_tour_007.png
         :class: sphx-glr-single-img


.. GENERATED FROM PYTHON SOURCE LINES 374-375

We can finally update our comparison with the adapted DnCNN (the adapted SwinIR is skipped here because it is slow).

.. GENERATED FROM PYTHON SOURCE LINES 375-410

.. code-block:: Python

    adapted_denoisers = {
        # "SwinIR": adapted_swinir,  # SwinIR is slow for this example, skipping it in the doc
        "DnCNN (adapted)": adapted_dncnn,
    }
    res = []
    for name, d in adapted_denoisers.items():
        print(f"Denoiser {name}...", end="", flush=True)
        t_start = time.perf_counter()
        with torch.no_grad():
            clean_images = d(noisy_images, noise_levels)
        psnr_x = psnr(clean_images, image)
        runtime = time.perf_counter() - t_start
        res.extend(
            [
                {"sigma": sig.item(), "denoiser": name, "psnr": v.item(), "time": runtime}
                for sig, v in zip(noise_levels, psnr_x)
            ]
        )
        print(f" done ({runtime:.2f}s)")
    df_adapted = pd.DataFrame(res)
    merge_df = pd.concat(
        [merge_df.query("~denoiser.isin(['DnCNN', 'SwinIR'])"), df_adapted]
    )

    _, ax = plt.subplots(figsize=(6, 4))
    for name, g in merge_df.groupby("denoiser"):
        g.plot(x=r"sigma", y="psnr", label=name, ax=ax, **styles.get(name, {}))
    ax.set_xscale("log")
    ax.set_xlabel(r"$\sigma$")
    ax.set_ylabel("PSNR")
    plt.legend(bbox_to_anchor=(1.02, 1), loc="upper left", borderaxespad=0)
    plt.tight_layout()
    plt.show()



.. image-sg:: /auto_examples/basics/images/sphx_glr_demo_denoiser_tour_008.png
   :alt: demo denoiser tour
   :srcset: /auto_examples/basics/images/sphx_glr_demo_denoiser_tour_008.png
   :class: sphx-glr-single-img


.. rst-class:: sphx-glr-script-out

 .. code-block:: none

    Denoiser DnCNN (adapted)... done (2.24s)


.. GENERATED FROM PYTHON SOURCE LINES 411-415

We can see that the adapted denoisers perform better than the original ones,
but they are still not as good as DRUNet, which is trained for a wide range of noise levels.
Finally, we can also compare the tradeoff between computation time and performance
of the different denoisers.

.. GENERATED FROM PYTHON SOURCE LINES 415-432

.. code-block:: Python


    fig = plt.figure(figsize=(12, 6))
    grid = plt.GridSpec(2, 2, height_ratios=[0.25, 0.75])
    for i, sig in enumerate(noise_levels[[0, 4]]):
        ax = fig.add_subplot(grid[1, i])
        to_plot = merge_df.query(f"sigma == {sig}")
        handles = []
        for name, g in to_plot.groupby("denoiser"):
            handles.append(ax.scatter(g["time"], g["psnr"], label=name))
        ax.set_title(rf"$\sigma={sig:.2f}$")
        ax.set_xscale("log")
        ax.set_xlabel("Time (s)")
        ax.set_ylabel("PSNR")
    ax_legend = fig.add_subplot(grid[0, :])
    ax_legend.legend(handles=handles, ncol=3, loc="center")
    ax_legend.set_axis_off()



.. image-sg:: /auto_examples/basics/images/sphx_glr_demo_denoiser_tour_009.png
   :alt: $\sigma=0.01$, $\sigma=0.10$
   :srcset: /auto_examples/basics/images/sphx_glr_demo_denoiser_tour_009.png
   :class: sphx-glr-single-img


.. GENERATED FROM PYTHON SOURCE LINES 433-435

We see that the tradeoff between computation time and performance depends on the noise level,
with the deep denoisers performing best.


.. rst-class:: sphx-glr-timing

   **Total running time of the script:** (1 minutes 31.962 seconds)


.. _sphx_glr_download_auto_examples_basics_demo_denoiser_tour.py:

.. only:: html

  .. container:: sphx-glr-footer sphx-glr-footer-example

    .. container:: sphx-glr-download sphx-glr-download-jupyter

      :download:`Download Jupyter notebook: demo_denoiser_tour.ipynb <demo_denoiser_tour.ipynb>`

    .. container:: sphx-glr-download sphx-glr-download-python

      :download:`Download Python source code: demo_denoiser_tour.py <demo_denoiser_tour.py>`

    .. container:: sphx-glr-download sphx-glr-download-zip

      :download:`Download zipped: demo_denoiser_tour.zip <demo_denoiser_tour.zip>`


.. only:: html

 .. rst-class:: sphx-glr-signature

    `Gallery generated by Sphinx-Gallery <https://sphinx-gallery.github.io>`_
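
As a complement to the ``AdaptedDenoiser`` wrapper above, the following minimal sketch
checks the arithmetic behind the rescaling on a toy signal. It uses plain Python and no
DeepInv call. The helper ``rescale_to_training_level`` and the constant test signal are
hypothetical and introduced only for illustration. The sketch assumes additive Gaussian noise
with standard deviation ``sigma``: multiplying the noisy signal by ``sigma_train / sigma``
should bring the noise standard deviation to ``sigma_train``, and dividing by the same factor
afterwards restores the original intensity range.

.. code-block:: Python

    import random
    import statistics


    def rescale_to_training_level(noisy, sigma, sigma_train):
        """Scale a noisy signal so that its noise std becomes sigma_train."""
        factor = sigma_train / sigma
        return [v * factor for v in noisy], factor


    random.seed(0)
    sigma, sigma_train = 0.2, 2.0 / 255.0

    # Hypothetical constant "image", flattened to a list of pixel values.
    clean = [0.5] * 50_000
    noisy = [x + sigma * random.gauss(0.0, 1.0) for x in clean]

    rescaled, factor = rescale_to_training_level(noisy, sigma, sigma_train)

    # Noise left in the rescaled signal: its std should be close to sigma_train.
    residual = [r - factor * x for r, x in zip(rescaled, clean)]
    noise_std = statistics.pstdev(residual)

    # Undoing the scaling (as the wrapper does on the model output) restores the range.
    restored_mean = statistics.fmean([r / factor for r in rescaled])

    print(f"noise std after rescaling: {noise_std:.5f} (target {sigma_train:.5f})")
    print(f"mean after undoing the scaling: {restored_mean:.3f}")

Note that the rescaling also shrinks the clean signal by the same factor, so the model sees
an image with an unusual intensity range. This may be one reason why an adapted fixed-level
denoiser does not match a model trained over a wide range of noise levels.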