Note
New to DeepInverse? Get started with the basics in the 5-minute quickstart tutorial.
Fitting NIQE on a custom dataset
This example shows how to fit deepinv.loss.metric.NIQE on a new dataset, and use it
to evaluate denoiser performance.
NIQE is a no-reference image quality metric that compares local image statistics against those of pristine (distortion-free) images. Fitting NIQE on a domain-specific dataset can better capture the expected image characteristics.
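For reference, the original NIQE formulation (Mittal et al., 2013) summarises local image features with a multivariate Gaussian (MVG) model: (\nu_1, \Sigma_1) is fitted on the pristine images and (\nu_2, \Sigma_2) on the image being scored. The score is the distance between the two models. We assume DeepInverse's NIQE follows this definition; implementation details, such as how the matrix inverse is computed, may differ.

```latex
D(\nu_1, \nu_2, \Sigma_1, \Sigma_2) = \sqrt{(\nu_1 - \nu_2)^\top \left( \frac{\Sigma_1 + \Sigma_2}{2} \right)^{-1} (\nu_1 - \nu_2)}
```

Lower values mean the image statistics are closer to the pristine prior, which is why a higher NIQE score indicates worse quality.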
In this example, we fit NIQE on DIV2K. DIV2K also contains natural images, but its image quality and sharpness are higher than those of the dataset NIQE was originally fitted on, so the resulting weights characterise a sharper, higher-quality prior while remaining valid NIQE statistics.
To apply this procedure to your own data, any dataset returning RGB or single-channel
tensors will work. The denominator constructor argument divides input pixels before
computing statistics; it serves two purposes: keeping pixel magnitudes from dominating
the local-statistics computation, and matching the input scale to the scale the weights
were fitted on. Two consequences:
- When using the bundled original NIQE weights (which were fitted on [0, 255] data with denominator=1), inputs must reach NIQE on a comparable [0, 255] scale. For [0, 1] data, either pass denominator=1/255 (since x / (1/255) = 255 * x), or scale the data to [0, 255] before calling and leave denominator=1.
- When fitting your own weights, the only requirement is that the same denominator is used at fit and evaluation time. The absolute scale is up to you.
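Both routes to the [0, 255] scale are numerically equivalent, and the scale matters because local normalisation typically includes a fixed stabilising constant. The NumPy sketch below illustrates both points under stated assumptions: `apply_denominator` and `toy_normalise` are hypothetical helpers, the latter a simplified, global version of the mean-subtracted contrast-normalised (MSCN) step of the original NIQE/BRISQUE papers, which use the constant C = 1. It does not reproduce DeepInverse's internal feature extraction.

```python
import numpy as np


def apply_denominator(x: np.ndarray, denominator: float) -> np.ndarray:
    """Hypothetical helper: divide pixels by `denominator`, as described above."""
    return x / denominator


def toy_normalise(patch: np.ndarray, c: float = 1.0) -> np.ndarray:
    """Simplified MSCN-style normalisation using one global mean/std per patch.

    The constant `c` is not rescaled with the data, so the output depends on
    the absolute pixel scale.
    """
    return (patch - patch.mean()) / (patch.std() + c)


rng = np.random.default_rng(0)
x01 = rng.random((8, 8))  # toy patch with values in [0, 1)

# Route 1: [0, 1] data with denominator=1/255
a = apply_denominator(x01, 1 / 255)
# Route 2: rescale to [0, 255] first and keep denominator=1
b = apply_denominator(x01 * 255, 1.0)
print(np.allclose(a, b))  # True: both routes feed the same [0, 255] values

# The same content normalised at the two scales ends up with different spreads
n01 = toy_normalise(x01)
n255 = toy_normalise(x01 * 255)
print(f"std at [0, 1]: {n01.std():.3f}, std at [0, 255]: {n255.std():.3f}")
```

Because `c` stays at 1 while the pixel spread changes by a factor of 255, the normalised patch has a very different spread at the two scales, so statistics fitted at one scale do not transfer to inputs at another.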
In this example we want to compare the original and DIV2K-fitted weights on the same
inputs, so we keep both on the [0, 255] scale: the fitting transform multiplies by 255
(default denominator=1 at fit time), and at evaluation we scale the denoised [0, 1]
outputs to [0, 255] before passing them to either NIQE instance.
We perform 5-fold cross-validation on the DIV2K validation set (80 fit / 20 test images per fold) and compare the original NIQE weights against DIV2K-fitted weights at noise level σ = 0.05. A key finding is that the DIV2K-fitted NIQE assigns systematically higher (worse) scores to over-smoothed outputs (e.g. those of large median filters), reflecting its greater sensitivity to the loss of the fine texture detail captured in the DIV2K prior.
Setup
Selected GPU 0 with 1606.25 MiB free memory
Define transforms and load DIV2K
We create two instances of DIV2K with different transforms: one that scales pixel values to [0, 255] for fitting NIQE weights, and one that keeps values in [0, 1] for denoising.
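import matplotlib.pyplot as plt
import numpy as np
import torch
from natsort import natsorted
from torch.utils.data import Subset
from torchvision.transforms import CenterCrop, Compose, Lambda, ToTensor

import deepinv as dinv
from deepinv.utils import plot

# Use the GPU with the most free memory when available, otherwise the CPU
device = (
    dinv.utils.get_freer_gpu() if torch.cuda.is_available() else torch.device("cpu")
)

crop_size = 1024
# Fitting images: scaled to [0, 255], the scale of the bundled original NIQE weights
fit_transform = Compose(
    [
        ToTensor(),
        CenterCrop(crop_size),
        Lambda(lambda x: x * 255),
    ]
)
# Test images: kept in [0, 1] for denoising
test_transform = Compose(
    [
        ToTensor(),
        CenterCrop(crop_size),
    ]
)
# Note: get_data_home emits a DeprecationWarning in recent DeepInverse versions
div2k_fit = dinv.datasets.DIV2K(
    root=dinv.utils.get_data_home(), mode="val", download=True, transform=fit_transform
)
# Sort paths so that the fold assignment below is deterministic
div2k_fit.x_paths = natsorted(div2k_fit.x_paths)
div2k_test = dinv.datasets.DIV2K(
    root=dinv.utils.get_data_home(),
    mode="val",
    download=False,  # already downloaded by the previous call
    transform=test_transform,
)
div2k_test.x_paths = natsorted(div2k_test.x_paths)
n_images = len(div2k_fit)
all_indices = list(range(n_images))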
Dataset has been successfully downloaded.
Define denoisers
We compare DRUNet against MedianFilter at two kernel sizes. At inference time, we wrap
each call in torch.autocast(..., dtype=torch.float16) so that DRUNet fits in GPU memory
at the 1024×1024 crop size used here.
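The memory argument for float16 is simple: half precision stores each value in 2 bytes instead of 4. The NumPy sketch below checks that and bounds the storage rounding error for pixel values in [0, 1]. It is only an approximation of what autocast does: autocast also runs selected operations in reduced precision, which this sketch does not model.

```python
import numpy as np

# Pixel values on the [0, 1] scale the denoisers work with
x = np.linspace(0.0, 1.0, 100_001, dtype=np.float32)

# Round-trip through float16, as a half-precision tensor would store them
x_half = x.astype(np.float16).astype(np.float32)
max_err = float(np.abs(x_half - x).max())

print(np.dtype(np.float16).itemsize, np.dtype(np.float32).itemsize)  # 2 4
print(f"max error: {max_err:.2e} on [0, 1], {max_err * 255:.3f} grey levels on [0, 255]")
```

The worst-case rounding error on [0, 1] is at most 2^-12, i.e. at most about 0.06 grey levels once outputs are rescaled to [0, 255] for NIQE.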
denoisers = {
    "DRUNet": dinv.models.DRUNet(pretrained="download", device=device),
    "Median (k=6)": dinv.models.MedianFilter(kernel_size=6),
    "Median (k=9)": dinv.models.MedianFilter(kernel_size=9),
}
Load original NIQE weights
Constructing deepinv.loss.metric.NIQE without an explicit weights_path
loads the original published NIQE weights bundled with the package. We use this
instance as the baseline for comparison against our custom-fitted weights.
niqe_original = dinv.loss.metric.NIQE(device="cpu")
Fit NIQE and save/load weights
To fit NIQE on a custom dataset, we construct a NIQE instance with weights_path=None
(which skips loading the bundled weights, as they would be overwritten anyway) and call
deepinv.loss.metric.NIQE.create_weights() with a dataset of pristine images.
This populates the instance's statistics in-place. The same object can then be called
directly to score new images.
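To make the "fit in place, then call" pattern concrete, here is a toy, hypothetical `ToyNIQE` class; it is not DeepInverse's implementation, and its `create_weights` and `__call__` names only mirror deepinv.loss.metric.NIQE for readability. It stores the mean and covariance of pristine feature vectors and scores new feature vectors with the MVG distance of the original NIQE formulation. Real NIQE extracts its features from patches of each image; here random feature vectors stand in for them, purely as an illustrative assumption.

```python
import numpy as np


class ToyNIQE:
    """Hypothetical stand-in for a NIQE-like metric (illustration only)."""

    def __init__(self):
        self.mu = None  # mean of pristine features (nu_1)
        self.cov = None  # covariance of pristine features (Sigma_1)

    def create_weights(self, pristine_feats: np.ndarray) -> None:
        # pristine_feats: (n_patches, n_features); statistics are stored in-place
        self.mu = pristine_feats.mean(axis=0)
        self.cov = np.cov(pristine_feats, rowvar=False)

    def __call__(self, test_feats: np.ndarray) -> float:
        # Fit a second Gaussian to the test features and compare the two models
        mu2 = test_feats.mean(axis=0)
        cov2 = np.cov(test_feats, rowvar=False)
        diff = self.mu - mu2
        pooled = (self.cov + cov2) / 2
        # pinv guards against a singular pooled covariance
        return float(np.sqrt(diff @ np.linalg.pinv(pooled) @ diff))


rng = np.random.default_rng(0)
pristine = rng.normal(0.0, 1.0, size=(500, 4))
close = rng.normal(0.1, 1.0, size=(200, 4))  # statistics near the prior
far = rng.normal(2.0, 0.3, size=(200, 4))  # statistics far from the prior

metric = ToyNIQE()
metric.create_weights(pristine)  # fit once ...
print(metric(close), metric(far))  # ... then score: lower means closer to pristine
```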
To persist the fitted weights for reuse, pass save_path="my_weights.pt" to
create_weights. In a later session, load them back via
NIQE(weights_path="my_weights.pt"). Here we do not save them, since the weights are
refitted on each fold of the cross-validation below.
def fit_niqe(fit_subset: Subset) -> dinv.loss.metric.NIQE:
    print(f" Fitting NIQE on {len(fit_subset)} images...")
    # weights_path=None skips loading the bundled weights
    niqe = dinv.loss.metric.NIQE(weights_path=None, device="cpu")
    niqe.create_weights(fit_subset)  # fits the statistics in-place
    return niqe
Run 5-fold cross-validation
We split the 100 DIV2K validation images into 5 folds of 20 images each. For each fold, we fit NIQE on the remaining 80 images and evaluate on the held-out 20.
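The fold construction is plain list slicing. The pure-Python sketch below (`contiguous_kfold` is a hypothetical helper) reproduces the same slicing as the cross-validation loop and checks that each fold holds out 20 images, fits on the other 80, and never mixes the two.

```python
def contiguous_kfold(n_items: int, n_folds: int = 5):
    """Yield (fit_indices, test_indices) for contiguous, non-shuffled folds."""
    all_indices = list(range(n_items))
    fold_size = n_items // n_folds
    for fold in range(n_folds):
        test = all_indices[fold * fold_size : (fold + 1) * fold_size]
        fit = all_indices[: fold * fold_size] + all_indices[(fold + 1) * fold_size :]
        yield fit, test


folds = list(contiguous_kfold(100, 5))
for fit, test in folds:
    # Every image is either fitted on or tested on within a fold, never both
    assert not set(fit) & set(test)
    assert sorted(fit + test) == list(range(100))
print([(len(fit), len(test)) for fit, test in folds])
# [(80, 20), (80, 20), (80, 20), (80, 20), (80, 20)]
```

The folds are contiguous and unshuffled, so they depend on the file order, which is presumably why the paths are sorted with natsorted. If the number of images were not divisible by 5, the leftover n_images % 5 images would land in every fit set and never be tested; the 100 DIV2K validation images divide evenly.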
sigma = 0.05
fold_size = n_images // 5
# results[denoiser]["original_niqe" | "div2k_niqe"] collects one score per test image
results = {
    denoiser_name: {"original_niqe": [], "div2k_niqe": []}
    for denoiser_name in denoisers.keys()
}
torch.manual_seed(16 * 16)  # fixed seed so the noise realisations are reproducible
for fold in range(5):
    print(f"Fold {fold + 1} / 5")
    # Hold out one contiguous block of images for testing; fit on the rest
    test_indices = all_indices[fold * fold_size : (fold + 1) * fold_size]
    fit_indices = (
        all_indices[: fold * fold_size] + all_indices[(fold + 1) * fold_size :]
    )
    fit_subset = Subset(div2k_fit, fit_indices)  # [0, 255] images for fitting
    test_subset = Subset(div2k_test, test_indices)  # [0, 1] images for denoising
    niqe_fitted = fit_niqe(fit_subset)
    for img in test_subset:
        img = img.unsqueeze(0)  # add batch dimension: (1, C, H, W)
        noisy = img + sigma * torch.randn_like(img)
        images = {}
        # Half-precision autocast keeps DRUNet within GPU memory at 1024x1024
        with (
            torch.no_grad(),
            torch.autocast(device_type=device.type, dtype=torch.float16),
        ):
            for name, denoiser in denoisers.items():
                images[name] = denoiser(noisy.to(device), sigma).cpu()
        for name, im in images.items():
            # Back to float32 and rescaled to [0, 255] so both NIQE instances see the same scale
            im_255 = im.to(torch.float32) * 255
            results[name]["original_niqe"].append(float(niqe_original(im_255)))
            results[name]["div2k_niqe"].append(float(niqe_fitted(im_255)))
Fold 1 / 5
Fitting NIQE on 80 images...
Fold 2 / 5
Fitting NIQE on 80 images...
Fold 3 / 5
Fitting NIQE on 80 images...
Fold 4 / 5
Fitting NIQE on 80 images...
Fold 5 / 5
Fitting NIQE on 80 images...
Scatter plot: original vs DIV2K-fitted NIQE
Each point represents one test image, coloured by method. The x-axis shows the score under original weights and the y-axis shows the score under DIV2K-fitted weights. Points above the identity line are penalised more by the DIV2K prior.
The median filters' NIQE scores show a systematic upward shift (about +10% for k=6 and +14% for k=9 on average): the DIV2K prior, fitted on higher-quality natural images, is more sensitive to over-smoothing and penalises the blurring introduced by large median filters more strongly than the original weights, which were fitted on lower-quality natural images.
DRUNet introduces less smoothing and shows no systematic shift (its average relative change is below 0.3% in magnitude).
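The two summaries used in this section, the average relative change (y - x) / x and the share of points lying above the identity line, can be computed together as in the sketch below. `summarise_shift` is a hypothetical helper; it drops non-finite scores before averaging so that a single NaN or inf NIQE value cannot corrupt the mean. The toy arrays are made up for illustration and are not results from this example.

```python
import numpy as np


def summarise_shift(orig: np.ndarray, fitted: np.ndarray) -> tuple[float, float]:
    """Return (mean relative change in %, fraction of points above y = x)."""
    orig = np.asarray(orig, dtype=float)
    fitted = np.asarray(fitted, dtype=float)
    mask = np.isfinite(orig) & np.isfinite(fitted)  # ignore NaN / inf scores
    x, y = orig[mask], fitted[mask]
    rel_change = float(np.mean((y - x) / x)) * 100
    frac_above = float(np.mean(y > x))  # y > x: penalised more by the fitted prior
    return rel_change, frac_above


# Toy scores (illustrative only); the NaN pair is dropped
orig_scores = np.array([4.0, 5.0, np.nan, 6.0])
fitted_scores = np.array([4.4, 5.5, 3.0, 6.6])
print(summarise_shift(orig_scores, fitted_scores))
```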
fig, ax = plt.subplots(figsize=(9, 6))
all_orig, all_div2k = [], []
print(
    "Average relative change by utilizing DIV2K fitted NIQE instead of original NIQE:"
)
for name in denoisers.keys():
    x = np.array(results[name]["original_niqe"])
    y = np.array(results[name]["div2k_niqe"])
    # Drop non-finite scores before averaging so a single NaN/inf cannot corrupt the mean
    mask = np.isfinite(x) & np.isfinite(y)
    x, y = x[mask], y[mask]
    avg_relative_shift = np.mean((y - x) / x)
    print(f"{name}: {float(avg_relative_shift) * 100:.3f} %")
    all_orig.append(x)
    all_div2k.append(y)
    ax.scatter(x, y, s=30, label=name, alpha=0.8)
all_orig = np.concatenate(all_orig)
all_div2k = np.concatenate(all_div2k)
# Identity line spanning the range of all plotted scores
lim_min = min(all_orig.min(), all_div2k.min())
lim_max = max(all_orig.max(), all_div2k.max())
ax.plot([lim_min, lim_max], [lim_min, lim_max], "k--", linewidth=1, label="identity")
ax.set_xlabel("NIQE with original weights")
ax.set_ylabel("NIQE with DIV2K-fitted weights")
ax.set_title(
    f"Per-image NIQE scores (σ = {sigma})\nPoints above the line are penalised more by the DIV2K prior"
)
ax.legend()
plt.tight_layout()
plt.show()

Average relative change by utilizing DIV2K fitted NIQE instead of original NIQE:
DRUNet: -0.273 %
Median (k=6): 10.488 %
Median (k=9): 14.334 %
Visual comparison between different denoisers
Finally, we visually confirm the blurring introduced by the median filters, which is absent in the ground-truth and DRUNet outputs.
methods_all = ["gt", "noisy"] + list(denoisers.keys())
c = crop_size // 2
# 256x256 centre crop of one test image, so the blur is visible at display size
sample_img = div2k_test[5][:, c - 128 : c + 128, c - 128 : c + 128].unsqueeze(0)
sample_noisy = sample_img + sigma * torch.randn_like(sample_img)
images = {"gt": sample_img, "noisy": sample_noisy}
with torch.no_grad(), torch.autocast(device_type=device.type, dtype=torch.float16):
    for name, denoiser in denoisers.items():
        # Use the same noise level as in the cross-validation above
        images[name] = denoiser(sample_noisy.to(device), sigma).cpu()
plot(
    [images[m] for m in methods_all],
    titles=methods_all,
    vmin=0,
    vmax=1,
    rescale_mode="clip",
)

Total running time of the script: (2 minutes 10.413 seconds)