Super-resolution with SRResNet

Single-image super-resolution (SISR) is the inverse problem of recovering a high-resolution (HR) image \(x\) from a low-resolution (LR) observation \(y = \downarrow_s(x)\), where \(\downarrow_s\) denotes downsampling by a factor \(s\).
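
As a concrete illustration of \(y = \downarrow_s(x)\), the sketch below approximates MATLAB-style bicubic downsampling with PyTorch's antialiased bicubic interpolation. This is an assumption made for illustration only: it roughly approximates, but is not identical to, the operator used later in this example (for instance, border handling differs).

```python
import torch
import torch.nn.functional as F


def downsample(x: torch.Tensor, s: int = 4) -> torch.Tensor:
    """Approximate bicubic downsampling y = down_s(x) of a (B, C, H, W) batch.

    With antialias=True the bicubic kernel is widened by the downscaling factor
    before resampling, which suppresses aliasing (as MATLAB's imresize does).
    """
    return F.interpolate(x, scale_factor=1 / s, mode="bicubic", antialias=True)


x = torch.rand(1, 3, 256, 256)  # stand-in HR image with values in [0, 1]
y = downsample(x, s=4)
print(tuple(x.shape), "->", tuple(y.shape))  # (1, 3, 256, 256) -> (1, 3, 64, 64)
```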

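Unlike the physics-aware methods in DeepInverse (iterative algorithms, unrolled networks, diffusion models), which require the forward operator at inference time, SRResNet [1] is a direct feed-forward network: it maps an LR image to an HR estimate in a single forward pass, without access to the degradation model. Inference is simply model(y). The trade-off is that the network implicitly learns the degradation it was trained on (here, 4× bicubic downsampling), so it is best suited to inputs produced by a similar degradation.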
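
To make "a single forward pass" concrete, here is a heavily reduced sketch of an SRResNet-style network in plain PyTorch: a head convolution, residual blocks with batch normalisation, a global skip connection, and two ×2 PixelShuffle stages for ×4 output. The channel and block counts are illustrative assumptions, and the class names are hypothetical; this is not the dinv.models.SRResNet implementation.

```python
import torch
import torch.nn as nn


class ResidualBlock(nn.Module):
    def __init__(self, ch: int):
        super().__init__()
        self.body = nn.Sequential(
            nn.Conv2d(ch, ch, 3, padding=1), nn.BatchNorm2d(ch), nn.PReLU(),
            nn.Conv2d(ch, ch, 3, padding=1), nn.BatchNorm2d(ch),
        )

    def forward(self, x):
        return x + self.body(x)  # local residual connection


class TinySRResNet(nn.Module):
    """Toy SRResNet-style x4 upsampler (illustrative sizes, not the pretrained model)."""

    def __init__(self, ch: int = 16, n_blocks: int = 2, in_ch: int = 3):
        super().__init__()
        self.head = nn.Sequential(nn.Conv2d(in_ch, ch, 9, padding=4), nn.PReLU())
        self.blocks = nn.Sequential(*[ResidualBlock(ch) for _ in range(n_blocks)])
        self.mid = nn.Sequential(nn.Conv2d(ch, ch, 3, padding=1), nn.BatchNorm2d(ch))
        # Each stage: conv to 4*ch channels, then PixelShuffle(2) trades channels for 2x resolution.
        self.up = nn.Sequential(
            *[
                nn.Sequential(nn.Conv2d(ch, 4 * ch, 3, padding=1), nn.PixelShuffle(2), nn.PReLU())
                for _ in range(2)
            ]
        )
        self.tail = nn.Conv2d(ch, in_ch, 9, padding=4)

    def forward(self, y):
        feat = self.head(y)
        feat = feat + self.mid(self.blocks(feat))  # global skip connection
        return self.tail(self.up(feat))


model = TinySRResNet().eval()
with torch.no_grad():
    # LR batch in, HR estimate out: no forward operator is involved.
    x_hat = model(torch.rand(2, 3, 32, 32))
print(tuple(x_hat.shape))  # (2, 3, 128, 128)
```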

This example demonstrates:

  1. Inference with weights pretrained on DIV2K for 4× RGB bicubic super-resolution.

  2. Fine-tuning with deepinv.Trainer to show the model is fully trainable.

import torch
import matplotlib.pyplot as plt

import deepinv as dinv

device = dinv.utils.get_device()
Selected GPU 0 with 8559.25 MiB free memory

1. Inference with pretrained weights

The default SRResNet weights were trained for 4× RGB super-resolution on DIV2K with an L1 loss, using DownsamplingMatlab (MATLAB-style bicubic downsampling, factor 4) as the forward operator, the Adam optimizer with learning rate 5e-4 and batch size 16, and random 128×128 HR crops, for 400 epochs.
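
The training recipe above can be sketched as a single optimisation step. This is a minimal illustration under stated assumptions: random tensors stand in for DIV2K images, PyTorch's antialiased bicubic interpolation stands in for DownsamplingMatlab, and a one-layer PixelShuffle network stands in for SRResNet. Only the crop size, batch size, L1 loss, optimizer and learning rate come from the text; the helper random_hr_crops is hypothetical.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

torch.manual_seed(0)


def random_hr_crops(images, crop=128, batch=16):
    """Sample `batch` random crop x crop patches from a list of (C, H, W) images."""
    patches = []
    for _ in range(batch):
        img = images[torch.randint(len(images), (1,)).item()]
        _, h, w = img.shape
        top = torch.randint(h - crop + 1, (1,)).item()
        left = torch.randint(w - crop + 1, (1,)).item()
        patches.append(img[:, top:top + crop, left:left + crop])
    return torch.stack(patches)


# Stand-ins (assumptions): random "HR images" and a minimal x4 model.
hr_images = [torch.rand(3, 200, 240) for _ in range(4)]
model = nn.Sequential(nn.Conv2d(3, 3 * 16, 3, padding=1), nn.PixelShuffle(4))
optimizer = torch.optim.Adam(model.parameters(), lr=5e-4)
l1 = nn.L1Loss()

x = random_hr_crops(hr_images)  # (16, 3, 128, 128) HR targets
# Synthesise LR inputs with approximate bicubic x4 downsampling: (16, 3, 32, 32).
y = F.interpolate(x, scale_factor=0.25, mode="bicubic", antialias=True)

optimizer.zero_grad()
loss = l1(model(y), x)  # L1 between the HR estimate and the HR crop
loss.backward()
optimizer.step()
print(tuple(x.shape), tuple(y.shape), loss.item())
```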

Note

The pretrained checkpoint uses the default architecture with final_relu=True.

model = dinv.models.SRResNet(pretrained="download", final_relu=True).to(device)

The checkpoint also stores the training loss and the DIV2K validation PSNR, which we can plot.

Note

The validation PSNR in the checkpoint was computed on the luminance (Y) channel only, as is standard practice in SISR benchmarking. The PSNR values shown later in this example are computed on all RGB channels and will therefore differ.
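
For reference, a common way to compute Y-channel PSNR in SISR benchmarks is to convert RGB in [0, 1] to the ITU-R BT.601 luma used by MATLAB's rgb2ycbcr, optionally shave a border of \(s\) pixels, and apply \(\mathrm{PSNR} = 10\log_{10}(1/\mathrm{MSE})\) for a peak value of 1. The sketch below follows that convention; whether the checkpoint's metric used exactly this conversion and border shave is an assumption, so its numbers may not match the stored curve.

```python
import torch


def rgb_to_y(img: torch.Tensor) -> torch.Tensor:
    """BT.601 luma (MATLAB rgb2ycbcr convention) for (B, 3, H, W) RGB in [0, 1].

    Output lies in [16/255, 235/255] and has shape (B, 1, H, W).
    """
    r, g, b = img[:, 0:1], img[:, 1:2], img[:, 2:3]
    return (16.0 + 65.481 * r + 128.553 * g + 24.966 * b) / 255.0


def psnr(x: torch.Tensor, x_hat: torch.Tensor, shave: int = 0) -> torch.Tensor:
    """Per-image PSNR in dB for signals with peak value 1, after cropping `shave` border pixels."""
    if shave > 0:
        x = x[..., shave:-shave, shave:-shave]
        x_hat = x_hat[..., shave:-shave, shave:-shave]
    mse = ((x - x_hat) ** 2).flatten(1).mean(dim=1)
    return 10 * torch.log10(1.0 / mse)


torch.manual_seed(0)
x = torch.rand(1, 3, 64, 64)
x_hat = (x + 0.05 * torch.randn_like(x)).clamp(0, 1)
print("RGB PSNR:", psnr(x, x_hat).item())
print("Y PSNR:  ", psnr(rgb_to_y(x), rgb_to_y(x_hat), shave=4).item())
```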

ckpt = torch.hub.load_state_dict_from_url(
    "https://huggingface.co/deepinv/srresnet/resolve/main/srresnet_ckpt.pth.tar",
    file_name="srresnet_ckpt.pth.tar",
    map_location=device,
    weights_only=False,
)

loss_curve = ckpt["loss"]["SupLoss"]
psnr_curve = ckpt["eval_metrics"]["PSNR"]

fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(9, 3))
ax1.plot(loss_curve)
ax1.set_xlabel("Epoch")
ax1.set_ylabel("L1 loss")
ax1.set_title("Training loss")
ax1.grid(True)
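# Validation PSNR is assumed to have been logged every 20 epochs (0, 20, ..., 400).
eval_epochs = range(0, 20 * len(psnr_curve), 20)
ax2.plot(eval_epochs, psnr_curve, marker="o", markersize=3)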
ax2.set_xlabel("Epoch")
ax2.set_ylabel("PSNR (dB)")
ax2.set_title("DIV2K validation PSNR")
ax2.grid(True)
fig.tight_layout()
plt.show()
[Figure: Training loss (left), DIV2K validation PSNR (right)]

Reconstruction on a DIV2K validation image

We generate the LR input with DownsamplingMatlab to match the physics used in training. Note that only y is passed to the model; no physics is needed at inference. As a baseline, we compare against standard bicubic interpolation.

physics = dinv.physics.DownsamplingMatlab(factor=4, device=device)
psnr = dinv.metric.PSNR()

x = dinv.utils.load_example("div2k_valid_hr_0877.png", img_size=256, device=device)
y = physics(x)

with torch.no_grad():
    x_hat = model(y)
    x_bic = torch.nn.functional.interpolate(
        y, scale_factor=4, mode="bicubic", antialias=True
    )

dinv.utils.plot(
    {"Ground truth": x, "Bicubic": x_bic, "SRResNet": x_hat},
    subtitles=[
        "PSNR (RGB):",
        f"{psnr(x, x_bic).item():.2f} dB",
        f"{psnr(x, x_hat).item():.2f} dB",
    ],
    figsize=(8, 4),
    rescale_mode="clip",
)
[Figure: Ground truth, Bicubic, SRResNet, with RGB PSNR in the subtitles]

2. Fine-tuning

SRResNet is fully trainable with Trainer. To demonstrate this, we fine-tune the pretrained model on a small subset of Urban100.

Dataset has been saved at ./dinv_dataset0.h5
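
The code that builds train_dataloader and test_dataloader is not shown here; the log line above only indicates that the pairs were saved to ./dinv_dataset0.h5. As a self-contained stand-in, the sketch below builds paired loaders in plain PyTorch. The random tensors replace Urban100 images, antialiased bicubic interpolation replaces DownsamplingMatlab, and the split sizes and batch sizes are hypothetical. The property the code that follows relies on is that each batch unpacks as an (x, y) pair.

```python
import torch
import torch.nn.functional as F
from torch.utils.data import DataLoader, TensorDataset, random_split

torch.manual_seed(0)

# Hypothetical stand-in for a few Urban100 HR crops, values in [0, 1].
hr = torch.rand(8, 3, 128, 128)

# Precompute the LR measurements once, as an offline dataset generator would.
with torch.no_grad():
    lr = F.interpolate(hr, scale_factor=0.25, mode="bicubic", antialias=True)

pairs = TensorDataset(hr, lr)  # each item is an (x, y) tuple
train_set, test_set = random_split(pairs, [6, 2])

train_dataloader = DataLoader(train_set, batch_size=2, shuffle=True)
test_dataloader = DataLoader(test_set, batch_size=1, shuffle=False)

x, y = next(iter(test_dataloader))
print(tuple(x.shape), tuple(y.shape))  # (1, 3, 128, 128) (1, 3, 32, 32)
```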

Visualise a data sample:

x, y = next(iter(test_dataloader))
dinv.utils.plot({"Ground truth": x, "LR measurement": y}, rescale_mode="clip")
[Figure: Ground truth, LR measurement]

We fine-tune the pretrained model with the same L1 loss and a smaller learning rate (1e-4) than was used for pretraining (5e-4).

Note

With only a handful of images and no regularisation, the model overfits almost immediately: in the log below, eval PSNR is highest after the first epoch and then falls while training PSNR keeps rising. We therefore reload the best checkpoint (highest eval PSNR) after training. For a proper fine-tuning run, use a larger and more diverse dataset.
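
Selecting the best checkpoint amounts to keeping a copy of the weights from the epoch with the highest eval PSNR. The sketch below shows that bookkeeping in plain PyTorch; it illustrates the idea and is not how deepinv.Trainer implements it (per the log, Trainer writes a ckp_best.pth.tar file to disk). The class BestCheckpoint is hypothetical, the eval PSNR values are those logged by the run below, and a one-parameter model whose weight is set to the epoch index stands in for actual training.

```python
import copy

import torch


class BestCheckpoint:
    """Keep an in-memory copy of the weights with the highest eval metric (higher is better)."""

    def __init__(self):
        self.best_metric = float("-inf")
        self.best_epoch = None
        self.best_state = None

    def update(self, model: torch.nn.Module, metric: float, epoch: int) -> None:
        if metric > self.best_metric:
            self.best_metric = metric
            self.best_epoch = epoch
            # deepcopy so later optimizer steps cannot mutate the saved tensors
            self.best_state = copy.deepcopy(model.state_dict())

    def restore(self, model: torch.nn.Module) -> torch.nn.Module:
        model.load_state_dict(self.best_state)
        return model


eval_psnr = [22.142, 20.793, 19.585, 19.363, 18.712, 18.175, 18.489, 17.303, 17.552, 17.248]
model = torch.nn.Linear(1, 1, bias=False)
tracker = BestCheckpoint()
for epoch, psnr_value in enumerate(eval_psnr):
    with torch.no_grad():
        model.weight.fill_(float(epoch))  # stand-in for one epoch of training
    tracker.update(model, psnr_value, epoch)

tracker.restore(model)
print(tracker.best_epoch, model.weight.item())  # 0 0.0
```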

# Train for 10 epochs on GPU; fall back to a single epoch on CPU to keep the example fast.
epochs = 10 if torch.cuda.is_available() else 1

trainer = dinv.Trainer(
    model=model,
    physics=physics,
    optimizer=torch.optim.Adam(model.parameters(), lr=1e-4),
    train_dataloader=train_dataloader,
    eval_dataloader=test_dataloader,
    epochs=epochs,
    losses=dinv.loss.SupLoss(metric=torch.nn.L1Loss()),
    metrics=dinv.metric.PSNR(),
    device=device,
    plot_images=False,
    show_progress_bar=False,
)

_ = trainer.train()
best_model = trainer.load_best_model()
The model has 1549462 trainable parameters
Train epoch 0: TotalLoss=0.061, PSNR=22.094
Eval epoch 0: PSNR=22.142
Best model saved at epoch 1
Train epoch 1: TotalLoss=0.058, PSNR=22.842
Eval epoch 1: PSNR=20.793
Train epoch 2: TotalLoss=0.056, PSNR=23.223
Eval epoch 2: PSNR=19.585
Train epoch 3: TotalLoss=0.055, PSNR=23.577
Eval epoch 3: PSNR=19.363
Train epoch 4: TotalLoss=0.053, PSNR=23.828
Eval epoch 4: PSNR=18.712
Train epoch 5: TotalLoss=0.052, PSNR=23.962
Eval epoch 5: PSNR=18.175
Train epoch 6: TotalLoss=0.051, PSNR=24.201
Eval epoch 6: PSNR=18.489
Train epoch 7: TotalLoss=0.051, PSNR=24.363
Eval epoch 7: PSNR=17.303
Train epoch 8: TotalLoss=0.05, PSNR=24.43
Eval epoch 8: PSNR=17.552
Train epoch 9: TotalLoss=0.05, PSNR=24.406
Eval epoch 9: PSNR=17.248
Model, optimizer, epoch_start successfully loaded from checkpoint: 26-05-24-16:00:06/ckp_best.pth.tar

Plot a reconstruction from the best checkpoint:

x, y = x.to(device), y.to(device)
with torch.no_grad():
    x_ft = best_model(y)
    x_bic_ft = torch.nn.functional.interpolate(
        y, scale_factor=4, mode="bicubic", antialias=True
    )

dinv.utils.plot(
    {"Ground truth": x, "Bicubic": x_bic_ft, "SRResNet": x_ft},
    subtitles=[
        "PSNR (RGB):",
        f"{psnr(x, x_bic_ft).item():.2f} dB",
        f"{psnr(x, x_ft).item():.2f} dB",
    ],
    figsize=(8, 4),
    rescale_mode="clip",
)
[Figure: Ground truth, Bicubic, SRResNet (best fine-tuned checkpoint), with RGB PSNR in the subtitles]
References:

Total running time of the script: (0 minutes 6.139 seconds)
