Fitting NIQE on a custom dataset#

This example shows how to fit deepinv.loss.metric.NIQE on a new dataset, and use it to evaluate denoiser performance.

NIQE is a no-reference image quality metric that compares local image statistics against those of pristine (distortion-free) images. Fitting NIQE on a domain-specific dataset can better capture the expected image characteristics.
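For orientation, the comparison of local statistics can be written out as follows. This is a sketch of the formulation as it is usually stated for the original NIQE metric, not a description of the deepinv implementation, which may differ in details. The symbols are assumptions introduced here: $I$ is the grayscale image, $\mu$ and $\sigma$ are local (Gaussian-weighted) mean and standard deviation maps, $C$ is a stabilising constant (commonly taken as 1 for [0, 255] images), and $(\nu_1, \Sigma_1)$, $(\nu_2, \Sigma_2)$ are the mean and covariance of the multivariate Gaussian models fitted to features of the pristine images and of the test image.

```latex
% Locally normalised (MSCN) coefficients from which the features are computed
\hat{I}(i,j) = \frac{I(i,j) - \mu(i,j)}{\sigma(i,j) + C}

% NIQE score: distance between the pristine model and the test-image model
D(\nu_1, \nu_2, \Sigma_1, \Sigma_2)
  = \sqrt{(\nu_1 - \nu_2)^{\top}
          \left(\frac{\Sigma_1 + \Sigma_2}{2}\right)^{-1}
          (\nu_1 - \nu_2)}
```

Under this formulation, "fitting NIQE" amounts to estimating $(\nu_1, \Sigma_1)$ from a pristine dataset. Because $C$ is a fixed constant, the normalisation is not invariant to the pixel scale, which is one reason the input range matters.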

In this example, we fit NIQE on DIV2K. DIV2K also consists of natural images, but their quality and sharpness are higher than those of the dataset NIQE was originally fitted on, so the resulting weights characterise a sharper, higher-quality prior while remaining valid NIQE statistics.

This procedure applies to your own data as long as the dataset returns RGB or single-channel tensors. The denominator constructor argument divides input pixels before the statistics are computed. It serves two purposes: keeping pixel magnitudes from dominating the local-statistics computation, and matching the input scale to the scale the weights were fitted on. Two consequences:

  • When using the bundled original NIQE weights (which were fitted on [0, 255] data with denominator=1), inputs must reach NIQE on a comparable [0, 255] scale. So for [0, 1] data, pass denominator=1/255 (x / (1/255) = 255 * x), or scale to [0, 255] before calling and leave denominator=1.

  • When fitting your own weights, the only requirement is that the same denominator is used at fit and evaluation time. The absolute scale is up to you.
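The scaling arithmetic in the two bullets above can be checked with a few lines of plain Python. This is a minimal sketch: `apply_denominator` is a hypothetical helper that only mirrors the division described above and is not a deepinv function.

```python
def apply_denominator(pixel: float, denominator: float = 1.0) -> float:
    """Hypothetical stand-in for the division applied to inputs before statistics."""
    return pixel / denominator


x_unit = 0.5  # a pixel value from [0, 1] data

# Option 1: keep [0, 1] inputs and pass denominator=1/255.
option_1 = apply_denominator(x_unit, denominator=1 / 255)

# Option 2: scale to [0, 255] beforehand and leave denominator=1.
option_2 = apply_denominator(x_unit * 255, denominator=1.0)

# Both routes present the same [0, 255]-scale value to the statistics
# (up to floating-point rounding).
print(option_1, option_2)
```

Either route is fine for the bundled weights. What matters is that the value reaching the statistics is on the scale the weights were fitted on.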

In this example we want to compare the original and DIV2K-fitted weights on the same inputs, so we keep both on the [0, 255] scale: the fitting transform multiplies by 255 (default denominator=1 at fit time), and at evaluation we scale the denoised [0, 1] outputs to [0, 255] before passing them to either NIQE instance.

We perform 5-fold cross-validation on the DIV2K validation set (80 fit / 20 test per fold) and compare original NIQE weights against DIV2K-fitted weights at noise level σ=0.05. A key finding is that the DIV2K-fitted NIQE assigns systematically higher (worse) scores to over-smoothed outputs (e.g. large median filters), reflecting that it is more sensitive to the loss of fine texture detail captured in the DIV2K prior.

Setup#

import deepinv as dinv
from deepinv.utils import plot
import torch
import numpy as np
import matplotlib.pyplot as plt
from torch.utils.data import Subset
from torchvision.transforms import Compose, ToTensor, CenterCrop, Lambda
from natsort import natsorted

device = dinv.utils.get_device()
Selected GPU 0 with 1606.25 MiB free memory

Define transforms and load DIV2K#

We create two instances of DIV2K with different transforms: one that scales pixel values to [0, 255] for fitting NIQE weights, and one that keeps values in [0, 1] for denoising.

crop_size = 1024

fit_transform = Compose(
    [
        ToTensor(),
        CenterCrop(crop_size),
        Lambda(lambda x: x * 255),
    ]
)

test_transform = Compose(
    [
        ToTensor(),
        CenterCrop(crop_size),
    ]
)

div2k_fit = dinv.datasets.DIV2K(
    root=dinv.utils.get_data_home(), mode="val", download=True, transform=fit_transform
)
div2k_fit.x_paths = natsorted(div2k_fit.x_paths)
div2k_test = dinv.datasets.DIV2K(
    root=dinv.utils.get_data_home(),
    mode="val",
    download=False,
    transform=test_transform,
)
div2k_test.x_paths = natsorted(div2k_test.x_paths)
n_images = len(div2k_fit)
all_indices = list(range(n_images))
/local/jtachell/deepinv/deepinv/examples/metrics/demo_custom_niqe.py:82: DeprecationWarning: Function 'get_data_home' is deprecated and will be removed in a future version.
  root=dinv.utils.get_data_home(), mode="val", download=True, transform=fit_transform

Dataset has been successfully downloaded.
/local/jtachell/deepinv/deepinv/examples/metrics/demo_custom_niqe.py:86: DeprecationWarning: Function 'get_data_home' is deprecated and will be removed in a future version.
  root=dinv.utils.get_data_home(),

Define denoisers#

We compare DRUNet against MedianFilter at two kernel sizes. At inference time we wrap each call in torch.autocast(..., dtype=torch.float16) so DRUNet fits in GPU memory at the 1024×1024 crop size used here.

denoisers = {
    "DRUNet": dinv.models.DRUNet(pretrained="download", device=device),
    "Median (k=6)": dinv.models.MedianFilter(kernel_size=6),
    "Median (k=9)": dinv.models.MedianFilter(kernel_size=9),
}

Load original NIQE weights#

Constructing deepinv.loss.metric.NIQE without an explicit weights_path loads the original published NIQE weights bundled with the package. We use this instance, which the cross-validation loop below refers to as niqe_original, as the baseline for comparison against our custom-fitted weights.

Fit NIQE and save/load weights#

To fit NIQE on a custom dataset, we construct a NIQE instance with weights_path=None (which skips loading the bundled weights, as they would be overwritten anyway) and call deepinv.loss.metric.NIQE.create_weights() with a dataset of pristine images. This populates the instance's statistics in-place. The same object can then be called directly to score new images.

To persist the fitted weights for reuse, pass save_path="my_weights.pt" to create_weights. In a later session, load them back via NIQE(weights_path="my_weights.pt"). Here we do not save: weights are computed on each fold of the cross-validation below.

def fit_niqe(fit_subset: Subset) -> dinv.loss.metric.NIQE:
    print(f"  Fitting NIQE on {len(fit_subset)} images...")
    niqe = dinv.loss.metric.NIQE(weights_path=None, device="cpu")
    niqe.create_weights(fit_subset)
    return niqe

Run 5-fold cross-validation#

We split the 100 DIV2K validation images into 5 folds of 20 images each. For each fold, we fit NIQE on the remaining 80 images and evaluate on the held-out 20.
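The split described above can be sanity-checked in isolation, without loading any images. The sketch below assumes contiguous, equally sized folds; `contiguous_folds` is a hypothetical helper written for this check and is not part of deepinv.

```python
def contiguous_folds(n_images: int, n_folds: int):
    """Yield (fit_indices, test_indices) for contiguous, equally sized folds."""
    all_indices = list(range(n_images))
    fold_size = n_images // n_folds
    for fold in range(n_folds):
        start, stop = fold * fold_size, (fold + 1) * fold_size
        test_indices = all_indices[start:stop]
        # Everything outside the held-out block is used for fitting.
        fit_indices = all_indices[:start] + all_indices[stop:]
        yield fit_indices, test_indices


folds = list(contiguous_folds(n_images=100, n_folds=5))
for fit_indices, test_indices in folds:
    # 80 fit / 20 test, with no image shared between the two sets.
    assert len(fit_indices) == 80 and len(test_indices) == 20
    assert set(fit_indices).isdisjoint(test_indices)
```

One design consequence of this scheme: if the number of images is not a multiple of the number of folds, the leftover images at the end are never held out and always land in the fit set.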

sigma = 0.05
fold_size = n_images // 5
results = {
    denoiser_name: {"original_niqe": [], "div2k_niqe": []}
    for denoiser_name in denoisers.keys()
}
torch.manual_seed(16 * 16)
for fold in range(5):
    print(f"Fold {fold + 1} / 5")

    test_indices = all_indices[fold * fold_size : (fold + 1) * fold_size]
    fit_indices = (
        all_indices[: fold * fold_size] + all_indices[(fold + 1) * fold_size :]
    )

    fit_subset = Subset(div2k_fit, fit_indices)
    test_subset = Subset(div2k_test, test_indices)

    niqe_fitted = fit_niqe(fit_subset)

    for i, img in enumerate(test_subset):
        img = img.unsqueeze(0)
        noisy = img + sigma * torch.randn_like(img)

        images = {}
        with (
            torch.no_grad(),
            torch.autocast(device_type=device.type, dtype=torch.float16),
        ):
            for name, denoiser in denoisers.items():
                images[name] = denoiser(noisy.to(device), sigma).cpu()

        for name, im in images.items():
            im_255 = im.to(torch.float32) * 255
            results[name]["original_niqe"].append(float(niqe_original(im_255)))
            results[name]["div2k_niqe"].append(float(niqe_fitted(im_255)))
Fold 1 / 5
  Fitting NIQE on 80 images...
Fold 2 / 5
  Fitting NIQE on 80 images...
Fold 3 / 5
  Fitting NIQE on 80 images...
Fold 4 / 5
  Fitting NIQE on 80 images...
Fold 5 / 5
  Fitting NIQE on 80 images...

Scatter plot: original vs DIV2K-fitted NIQE#

Each point represents one test image, coloured by method. The x-axis shows the score under original weights and the y-axis shows the score under DIV2K-fitted weights. Points above the identity line are penalised more by the DIV2K prior.

The median filters' NIQE scores show a systematic upward shift: the DIV2K prior, fitted on higher-quality natural images, is more sensitive to over-smoothing and penalises the blurring introduced by large median filters more strongly than the original weights, which were fitted on lower-quality natural images.

DRUNet introduces less smoothing, and its scores show no comparable systematic shift (about -0.3 % on average in the run shown here).
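A complementary way to quantify "points above the identity line" is the fraction of images whose fitted-weights score exceeds the original-weights score. The sketch below uses plain Python; `fraction_above_identity` is a hypothetical helper, and the scores are made-up values for illustration only, not results from this example.

```python
import math


def fraction_above_identity(original_scores, fitted_scores):
    """Fraction of finite score pairs (x, y) with y > x, i.e. above the identity line."""
    pairs = [
        (x, y)
        for x, y in zip(original_scores, fitted_scores)
        if math.isfinite(x) and math.isfinite(y)
    ]
    if not pairs:
        return float("nan")
    return sum(y > x for x, y in pairs) / len(pairs)


# Made-up scores; the last pair is dropped because its x value is not finite.
original = [4.0, 5.0, 6.0, float("nan")]
fitted = [4.4, 4.9, 6.9, 7.0]
print(fraction_above_identity(original, fitted))
```

Applied per denoiser to the lists stored in results, a value near 1 would indicate a consistent upward shift, while a value near 0.5 would indicate no systematic direction.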

fig, ax = plt.subplots(figsize=(9, 6))

all_orig, all_div2k = [], []
print(
    "Average relative change by utilizing DIV2K fitted NIQE instead of original NIQE:"
)
for name in denoisers.keys():
    x = np.array(results[name]["original_niqe"])
    y = np.array(results[name]["div2k_niqe"])
    # Drop non-finite scores first so a single NaN or inf cannot corrupt the mean.
    mask = np.isfinite(x) & np.isfinite(y)
    x, y = x[mask], y[mask]
    avg_relative_shift = np.mean((y - x) / x)
    print(f"{name}: {float(avg_relative_shift) * 100:.3f} %")
    all_orig.append(x)
    all_div2k.append(y)
    ax.scatter(x, y, s=30, label=name, alpha=0.8)

all_orig = np.concatenate(all_orig)
all_div2k = np.concatenate(all_div2k)
lim_min = min(all_orig.min(), all_div2k.min())
lim_max = max(all_orig.max(), all_div2k.max())
ax.plot([lim_min, lim_max], [lim_min, lim_max], "k--", linewidth=1, label="identity")

ax.set_xlabel("NIQE with original weights")
ax.set_ylabel("NIQE with DIV2K-fitted weights")
ax.set_title(
    f"Per-image NIQE scores (σ = {sigma})\nPoints above the line are penalised more by the DIV2K prior"
)
ax.legend()
plt.tight_layout()
plt.show()
Per-image NIQE scores (σ = 0.05) Points above the line are penalised more by the DIV2K prior
Average relative change by utilizing DIV2K fitted NIQE instead of original NIQE:
DRUNet: -0.273 %
Median (k=6): 10.488 %
Median (k=9): 14.334 %

Visual comparison between different denoisers#

Finally, we visually confirm the blurring introduced by the median filters, which is absent in the ground-truth and DRUNet outputs.

methods_all = ["gt", "noisy"] + list(denoisers.keys())
c = crop_size // 2
sample_img = div2k_test[5][:, c - 128 : c + 128, c - 128 : c + 128].unsqueeze(0)
sample_noisy = sample_img + sigma * torch.randn_like(sample_img)
images = {"gt": sample_img, "noisy": sample_noisy}
with torch.no_grad(), torch.autocast(device_type=device.type, dtype=torch.float16):
    for name, denoiser in denoisers.items():
        images[name] = denoiser(sample_noisy.to(device), sigma).cpu()
plot(
    [images[m] for m in methods_all],
    titles=methods_all,
    vmin=0,
    vmax=1,
    rescale_mode="clip",
)
gt, noisy, DRUNet, Median (k=6), Median (k=9)

Total running time of the script: (2 minutes 10.413 seconds)
