.. DO NOT EDIT.
.. THIS FILE WAS AUTOMATICALLY GENERATED BY SPHINX-GALLERY.
.. TO MAKE CHANGES, EDIT THE SOURCE PYTHON FILE:
.. "auto_examples/metrics/demo_custom_niqe.py"
.. LINE NUMBERS ARE GIVEN BELOW.

.. only:: html

    .. note::
        :class: sphx-glr-download-link-note

        New to DeepInverse? Get started with the basics with the
        :ref:`5 minute quickstart tutorial `.

.. rst-class:: sphx-glr-example-title

.. _sphx_glr_auto_examples_metrics_demo_custom_niqe.py:


Fitting NIQE on a custom dataset
====================================================================================================

This example shows how to fit :class:`deepinv.loss.metric.NIQE` on a new dataset,
and use it to evaluate denoiser performance.

NIQE is a no-reference image quality metric that compares local image statistics
against those of pristine (distortion-free) images. Fitting NIQE on a
domain-specific dataset can better capture the expected image characteristics.

In this example, we fit NIQE on DIV2K. DIV2K is also natural imaging data, but its
image quality and sharpness are higher than those of the dataset NIQE was
originally fitted on, so the resulting weights characterise a sharper,
higher-quality prior while remaining valid NIQE statistics. To apply this
procedure to your own data, any dataset returning RGB or single-channel tensors
will work.

The ``denominator`` constructor argument divides input pixels before computing
statistics. It serves two purposes: keeping pixel magnitudes from dominating the
local-statistics computation, and matching the input scale to the scale the
weights were fitted on. This has two consequences:

- When using the bundled original NIQE weights (which were fitted on [0, 255]
  data with ``denominator=1``), inputs must reach NIQE on a comparable [0, 255]
  scale. So for [0, 1] data, pass ``denominator=1/255``
  (``x / (1/255) = 255 * x``), or scale to [0, 255] before calling and leave
  ``denominator=1``.
- When fitting your own weights, the only requirement is that the *same*
  ``denominator`` is used at fit and evaluation time. The absolute scale is up
  to you.

In this example we want to compare the original and DIV2K-fitted weights on the
same inputs, so we keep both on the [0, 255] scale: the fitting transform
multiplies by 255 (default ``denominator=1`` at fit time), and at evaluation we
scale the denoised [0, 1] outputs to [0, 255] before passing them to either NIQE
instance.

We perform 5-fold cross-validation on the DIV2K validation set (80 fit / 20 test
per fold) and compare original NIQE weights against DIV2K-fitted weights at noise
level σ=0.05.

A key finding is that the DIV2K-fitted NIQE assigns systematically higher (worse)
scores to over-smoothed outputs (e.g. large median filters), reflecting that it is
more sensitive to the loss of fine texture detail captured in the DIV2K prior.

.. GENERATED FROM PYTHON SOURCE LINES 43-45

Setup
-----

.. GENERATED FROM PYTHON SOURCE LINES 45-57

.. code-block:: Python

    import deepinv as dinv
    from deepinv.utils import plot
    import torch
    import numpy as np
    import matplotlib.pyplot as plt
    from torch.utils.data import Subset
    from torchvision.transforms import Compose, ToTensor, CenterCrop, Lambda
    from natsort import natsorted

    device = dinv.utils.get_device()


.. rst-class:: sphx-glr-script-out

 .. code-block:: none

    Selected GPU 0 with 1606.25 MiB free memory


.. GENERATED FROM PYTHON SOURCE LINES 58-63

Define transforms and load DIV2K
---------------------------------

We create two instances of DIV2K with different transforms: one that scales pixel
values to [0, 255] for fitting NIQE weights, and one that keeps values in [0, 1]
for denoising.

.. GENERATED FROM PYTHON SOURCE LINES 63-95

.. code-block:: Python

    crop_size = 1024

    # Fitting pipeline: tensors scaled to [0, 255]
    fit_transform = Compose(
        [
            ToTensor(),
            CenterCrop(crop_size),
            Lambda(lambda x: x * 255),
        ]
    )

    # Denoising pipeline: tensors kept in [0, 1]
    test_transform = Compose(
        [
            ToTensor(),
            CenterCrop(crop_size),
        ]
    )

    div2k_fit = dinv.datasets.DIV2K(
        root=dinv.utils.get_data_home(), mode="val", download=True, transform=fit_transform
    )
    div2k_fit.x_paths = natsorted(div2k_fit.x_paths)

    div2k_test = dinv.datasets.DIV2K(
        root=dinv.utils.get_data_home(),
        mode="val",
        download=False,
        transform=test_transform,
    )
    div2k_test.x_paths = natsorted(div2k_test.x_paths)

    n_images = len(div2k_fit)
    all_indices = list(range(n_images))


.. rst-class:: sphx-glr-script-out

 .. code-block:: none

    /local/jtachell/deepinv/deepinv/examples/metrics/demo_custom_niqe.py:82: DeprecationWarning: Function 'get_data_home' is deprecated and will be removed in a future version.
      root=dinv.utils.get_data_home(), mode="val", download=True, transform=fit_transform
    0%| | 0/448993893 [00:00


.. code-block:: Python

    def fit_niqe(fit_subset) -> dinv.loss.metric.NIQE:
        print(f" Fitting NIQE on {len(fit_subset)} images...")
        niqe = dinv.loss.metric.NIQE(weights_path=None, device="cpu")
        niqe.create_weights(fit_subset)
        return niqe


.. GENERATED FROM PYTHON SOURCE LINES 140-145

Run 5-fold cross-validation
----------------------------

We split the 100 DIV2K validation images into 5 folds of 20 images each. For each
fold, we fit NIQE on the remaining 80 images and evaluate on the held-out 20.

.. GENERATED FROM PYTHON SOURCE LINES 145-183

.. code-block:: Python

    sigma = 0.05
    fold_size = n_images // 5

    results = {
        denoiser_name: {"original_niqe": [], "div2k_niqe": []}
        for denoiser_name in denoisers.keys()
    }

    torch.manual_seed(16 * 16)

    for fold in range(5):
        print(f"Fold {fold + 1} / 5")

        # Hold out one contiguous block of images, fit on the rest
        test_indices = all_indices[fold * fold_size : (fold + 1) * fold_size]
        fit_indices = (
            all_indices[: fold * fold_size] + all_indices[(fold + 1) * fold_size :]
        )

        fit_subset = Subset(div2k_fit, fit_indices)
        test_subset = Subset(div2k_test, test_indices)

        niqe_fitted = fit_niqe(fit_subset)

        for i, img in enumerate(test_subset):
            img = img.unsqueeze(0)
            noisy = img + sigma * torch.randn_like(img)

            images = {}
            with (
                torch.no_grad(),
                torch.autocast(device_type=device.type, dtype=torch.float16),
            ):
                for name, denoiser in denoisers.items():
                    images[name] = denoiser(noisy.to(device), sigma).cpu()

            for name, im in images.items():
                # Both NIQE instances expect inputs on the [0, 255] scale
                im_255 = im.to(torch.float32) * 255
                results[name]["original_niqe"].append(float(niqe_original(im_255)))
                results[name]["div2k_niqe"].append(float(niqe_fitted(im_255)))


.. rst-class:: sphx-glr-script-out

 .. code-block:: none

    Fold 1 / 5
     Fitting NIQE on 80 images...
    Fold 2 / 5
     Fitting NIQE on 80 images...
    Fold 3 / 5
     Fitting NIQE on 80 images...
    Fold 4 / 5
     Fitting NIQE on 80 images...
    Fold 5 / 5
     Fitting NIQE on 80 images...


.. GENERATED FROM PYTHON SOURCE LINES 184-196

Scatter plot: original vs DIV2K-fitted NIQE
--------------------------------------------

Each point represents one test image, coloured by method. The x-axis shows the
score under original weights and the y-axis shows the score under DIV2K-fitted
weights. Points above the identity line are penalised *more* by the DIV2K prior.

The median filters' NIQE scores show a systematic upward shift: the DIV2K prior,
fitted on higher-quality natural images, is more sensitive to over-smoothing and
penalises the blurring introduced by large median filters more strongly than the
original weights fitted on lower-quality natural images.
DRUNet introduces less smoothing and has a less systematic shift.

.. GENERATED FROM PYTHON SOURCE LINES 196-229

.. code-block:: Python

    fig, ax = plt.subplots(figsize=(9, 6))

    all_orig, all_div2k = [], []

    print(
        "Average relative change by utilizing DIV2K fitted NIQE instead of original NIQE:"
    )
    for name in denoisers.keys():
        x = np.array(results[name]["original_niqe"])
        y = np.array(results[name]["div2k_niqe"])

        # Drop non-finite scores first so they affect neither the average nor the plot
        mask = np.isfinite(x) & np.isfinite(y)
        x, y = x[mask], y[mask]

        avg_relative_shift = np.mean((y - x) / x)
        print(f"{name}: {float(avg_relative_shift) * 100:.3f} %")

        all_orig.append(x)
        all_div2k.append(y)
        ax.scatter(x, y, s=30, label=name, alpha=0.8)

    all_orig = np.concatenate(all_orig)
    all_div2k = np.concatenate(all_div2k)
    lim_min = min(all_orig.min(), all_div2k.min())
    lim_max = max(all_orig.max(), all_div2k.max())
    ax.plot([lim_min, lim_max], [lim_min, lim_max], "k--", linewidth=1, label="identity")

    ax.set_xlabel("NIQE with original weights")
    ax.set_ylabel("NIQE with DIV2K-fitted weights")
    ax.set_title(
        f"Per-image NIQE scores (σ = {sigma})\nPoints above the line are penalised more by the DIV2K prior"
    )
    ax.legend()
    plt.tight_layout()
    plt.show()


.. image-sg:: /auto_examples/metrics/images/sphx_glr_demo_custom_niqe_001.png
   :alt: Per-image NIQE scores (σ = 0.05) Points above the line are penalised more by the DIV2K prior
   :srcset: /auto_examples/metrics/images/sphx_glr_demo_custom_niqe_001.png
   :class: sphx-glr-single-img


.. rst-class:: sphx-glr-script-out

 .. code-block:: none

    Average relative change by utilizing DIV2K fitted NIQE instead of original NIQE:
    DRUNet: -0.273 %
    Median (k=6): 10.488 %
    Median (k=9): 14.334 %


.. GENERATED FROM PYTHON SOURCE LINES 230-233

Visual comparison between different denoisers
---------------------------------------------

Finally, we visually confirm the blurring introduced by the median filters, which
is absent in the ground-truth and DRUNet outputs.

.. GENERATED FROM PYTHON SOURCE LINES 233-249

.. code-block:: Python

    methods_all = ["gt", "noisy"] + list(denoisers.keys())

    # Take a 256 x 256 patch from the centre of one test image
    c = crop_size // 2
    sample_img = div2k_test[5][:, c - 128 : c + 128, c - 128 : c + 128].unsqueeze(0)
    sample_noisy = sample_img + sigma * torch.randn_like(sample_img)

    images = {"gt": sample_img, "noisy": sample_noisy}
    with torch.no_grad(), torch.autocast(device_type=device.type, dtype=torch.float16):
        for name, denoiser in denoisers.items():
            images[name] = denoiser(sample_noisy.to(device), sigma).cpu()

    plot(
        [images[m] for m in methods_all],
        titles=methods_all,
        vmin=0,
        vmax=1,
        rescale_mode="clip",
    )


.. image-sg:: /auto_examples/metrics/images/sphx_glr_demo_custom_niqe_002.png
   :alt: gt, noisy, DRUNet, Median (k=6), Median (k=9)
   :srcset: /auto_examples/metrics/images/sphx_glr_demo_custom_niqe_002.png
   :class: sphx-glr-single-img


.. rst-class:: sphx-glr-timing

   **Total running time of the script:** (2 minutes 10.413 seconds)


.. _sphx_glr_download_auto_examples_metrics_demo_custom_niqe.py:

.. only:: html

  .. container:: sphx-glr-footer sphx-glr-footer-example

    .. container:: sphx-glr-download sphx-glr-download-jupyter

      :download:`Download Jupyter notebook: demo_custom_niqe.ipynb `

    .. container:: sphx-glr-download sphx-glr-download-python

      :download:`Download Python source code: demo_custom_niqe.py `

    .. container:: sphx-glr-download sphx-glr-download-zip

      :download:`Download zipped: demo_custom_niqe.zip `


.. only:: html

 .. rst-class:: sphx-glr-signature

    `Gallery generated by Sphinx-Gallery `_
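
The two input-scaling routes described in the introduction for the bundled weights can be checked with plain arithmetic. The sketch below is only an illustration: ``apply_denominator`` is a hypothetical stand-in that mimics the documented "divide pixels by ``denominator``" step on a list of numbers. It is not part of DeepInverse and computes no NIQE statistics.

```python
import math


def apply_denominator(pixels, denominator=1.0):
    """Hypothetical stand-in: divide every pixel by ``denominator``."""
    return [p / denominator for p in pixels]


# Toy pixel values on the [0, 1] scale
pixels_01 = [0.0, 0.25, 0.5, 1.0]

# Route 1: keep [0, 1] inputs and pass denominator=1/255
via_denominator = apply_denominator(pixels_01, denominator=1 / 255)

# Route 2: scale to [0, 255] first and keep the default denominator=1
via_prescaling = apply_denominator([255 * p for p in pixels_01], denominator=1.0)

# Both routes should land on the same [0, 255] values, up to float rounding
same_scale = all(
    math.isclose(a, b, rel_tol=1e-12, abs_tol=1e-12)
    for a, b in zip(via_denominator, via_prescaling)
)
print(same_scale)
```

Either route is acceptable as long as the same choice is made at fit and evaluation time. This example uses the second route, multiplying by 255 before calling NIQE.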