.. DO NOT EDIT.
.. THIS FILE WAS AUTOMATICALLY GENERATED BY SPHINX-GALLERY.
.. TO MAKE CHANGES, EDIT THE SOURCE PYTHON FILE:
.. "auto_examples/models/demo_super_resolution.py"
.. LINE NUMBERS ARE GIVEN BELOW.

.. only:: html

    .. note::
        :class: sphx-glr-download-link-note

        New to DeepInverse? Get started with the basics with the
        :ref:`5 minute quickstart tutorial `.

.. rst-class:: sphx-glr-example-title

.. _sphx_glr_auto_examples_models_demo_super_resolution.py:


Super-resolution with SRResNet
===============================

Single-image super-resolution (SISR) is the inverse problem of recovering a high-resolution (HR)
image :math:`x` from a low-resolution (LR) observation :math:`y = \downarrow_s(x)`,
where :math:`\downarrow_s` denotes downsampling by factor :math:`s`.

Unlike physics-aware methods in DeepInverse (iterative algorithms, unrolled networks,
diffusion models) that require the forward operator at inference,
:class:`SRResNet ` :footcite:p:`ledig2017photo` is a *direct*
feed-forward network: it maps LR images to HR estimates in a single forward pass
without needing the degradation model at test time. Inference is simply ``model(y)``.

This example demonstrates:

1. **Inference** with weights pretrained on DIV2K for 4× RGB bicubic super-resolution.
2. **Fine-tuning** with :class:`deepinv.Trainer` to show the model is fully trainable.

.. GENERATED FROM PYTHON SOURCE LINES 24-31

.. code-block:: Python

    import torch
    import matplotlib.pyplot as plt
    import deepinv as dinv

    device = dinv.utils.get_device()

.. rst-class:: sphx-glr-script-out

 .. code-block:: none

    Selected GPU 0 with 8559.25 MiB free memory

.. GENERATED FROM PYTHON SOURCE LINES 32-43

1.
Inference with pretrained weights
-------------------------------------

The default SRResNet was trained for RGB 4× super-resolution on DIV2K under the L1 loss,
using :class:`~deepinv.physics.DownsamplingMatlab` (MATLAB-style bicubic downsampling, factor 4),
Adam with learning rate 5e-4 and batch size 16, and random 128×128 HR crops for 400 epochs.

.. note::

    The pretrained checkpoint uses the default architecture with ``final_relu=True``.

.. GENERATED FROM PYTHON SOURCE LINES 43-46

.. code-block:: Python

    model = dinv.models.SRResNet(pretrained="download", final_relu=True).to(device)

.. GENERATED FROM PYTHON SOURCE LINES 47-54

We can visualise the training loss and validation PSNR on DIV2K stored inside the checkpoint.

.. note::

    The validation PSNR in the checkpoint was computed on the luminance (Y) channel only,
    as is standard practice in SISR benchmarking. The PSNR values shown later in this
    example are computed on all RGB channels and will therefore differ.

.. GENERATED FROM PYTHON SOURCE LINES 54-79

.. code-block:: Python

    ckpt = torch.hub.load_state_dict_from_url(
        "https://huggingface.co/deepinv/srresnet/resolve/main/srresnet_ckpt.pth.tar",
        file_name="srresnet_ckpt.pth.tar",
        map_location=device,
        weights_only=False,
    )

    loss_curve = ckpt["loss"]["SupLoss"]
    psnr_curve = ckpt["eval_metrics"]["PSNR"]

    fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(9, 3))
    ax1.plot(loss_curve)
    ax1.set_xlabel("Epoch")
    ax1.set_ylabel("L1 loss")
    ax1.set_title("Training loss")
    ax1.grid(True)

    ax2.plot(tuple(range(0, 401, 20)), psnr_curve, marker="o", markersize=3)
    ax2.set_xlabel("Epoch")
    ax2.set_ylabel("PSNR (dB)")
    ax2.set_title("DIV2K validation PSNR")
    ax2.grid(True)

    fig.tight_layout()
    plt.show()

.. image-sg:: /auto_examples/models/images/sphx_glr_demo_super_resolution_001.png
   :alt: Training loss, DIV2K validation PSNR
   :srcset: /auto_examples/models/images/sphx_glr_demo_super_resolution_001.png
   :class: sphx-glr-single-img

..
GENERATED FROM PYTHON SOURCE LINES 80-87

Reconstruction on a DIV2K validation image
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

We apply :class:`~deepinv.physics.DownsamplingMatlab` to match the training physics.
Note that only ``y`` is passed to the model; no physics is needed at inference.
We compare against a standard bicubic interpolation baseline.

.. GENERATED FROM PYTHON SOURCE LINES 87-111

.. code-block:: Python

    physics = dinv.physics.DownsamplingMatlab(factor=4, device=device)
    psnr = dinv.metric.PSNR()

    x = dinv.utils.load_example("div2k_valid_hr_0877.png", img_size=256, device=device)
    y = physics(x)

    with torch.no_grad():
        x_hat = model(y)

    x_bic = torch.nn.functional.interpolate(
        y, scale_factor=4, mode="bicubic", antialias=True
    )

    dinv.utils.plot(
        {"Ground truth": x, "Bicubic": x_bic, "SRResNet": x_hat},
        subtitles=[
            "PSNR (RGB):",
            f"{psnr(x, x_bic).item():.2f} dB",
            f"{psnr(x, x_hat).item():.2f} dB",
        ],
        figsize=(8, 4),
        rescale_mode="clip",
    )

.. image-sg:: /auto_examples/models/images/sphx_glr_demo_super_resolution_002.png
   :alt: Ground truth, Bicubic, SRResNet
   :srcset: /auto_examples/models/images/sphx_glr_demo_super_resolution_002.png
   :class: sphx-glr-single-img

.. GENERATED FROM PYTHON SOURCE LINES 112-118

2. Fine-tuning
--------------

SRResNet is fully trainable with :class:`~deepinv.Trainer`. To demonstrate this,
we fine-tune the pretrained model on a small subset of Urban100.

.. GENERATED FROM PYTHON SOURCE LINES 118-149

..
code-block:: Python

    from torchvision.transforms import CenterCrop, Compose, ToTensor

    torch.manual_seed(16)

    hr_size = 64
    dataset = dinv.datasets.Urban100HR(
        dinv.utils.get_cache_home() / "datasets" / "Urban100",
        download=True,
        transform=Compose([ToTensor(), CenterCrop(hr_size)]),
    )
    train_dataset, test_dataset = torch.utils.data.random_split(
        torch.utils.data.Subset(dataset, range(10)), (0.8, 0.2)
    )

    dataset_path = dinv.datasets.generate_dataset(
        train_dataset=train_dataset,
        test_dataset=test_dataset,
        physics=physics,
        device=device,
        save_dir=".",
        batch_size=1,
    )

    train_dataloader = torch.utils.data.DataLoader(
        dinv.datasets.HDF5Dataset(dataset_path, train=True), shuffle=True
    )
    test_dataloader = torch.utils.data.DataLoader(
        dinv.datasets.HDF5Dataset(dataset_path, train=False), shuffle=False
    )

.. rst-class:: sphx-glr-script-out

 .. code-block:: none

    /local/jtachell/deepinv/deepinv/deepinv/datasets/datagenerator.py:600: UserWarning: Dataset ./dinv_dataset0.h5 already exists, this will close and overwrite the previous dataset.
      warn(
    Dataset has been saved at ./dinv_dataset0.h5

.. GENERATED FROM PYTHON SOURCE LINES 150-152

Visualise a data sample:

.. GENERATED FROM PYTHON SOURCE LINES 152-156

.. code-block:: Python

    x, y = next(iter(test_dataloader))
    dinv.utils.plot({"Ground truth": x, "LR measurement": y}, rescale_mode="clip")

.. image-sg:: /auto_examples/models/images/sphx_glr_demo_super_resolution_003.png
   :alt: Ground truth, LR measurement
   :srcset: /auto_examples/models/images/sphx_glr_demo_super_resolution_003.png
   :class: sphx-glr-single-img

.. GENERATED FROM PYTHON SOURCE LINES 157-165

We fine-tune the pretrained model.

.. note::

    With only a handful of images and no regularisation, the model will overfit
    after the first few epochs and eval PSNR will decrease. We therefore load the
    best checkpoint after training. For a proper fine-tuning run, use a larger and
    more diverse dataset.

.. GENERATED FROM PYTHON SOURCE LINES 165-185

..
code-block:: Python

    epochs = 10 if torch.cuda.is_available() else 1

    trainer = dinv.Trainer(
        model=model,
        physics=physics,
        optimizer=torch.optim.Adam(model.parameters(), lr=1e-4),
        train_dataloader=train_dataloader,
        eval_dataloader=test_dataloader,
        epochs=epochs,
        losses=dinv.loss.SupLoss(metric=torch.nn.L1Loss()),
        metrics=dinv.metric.PSNR(),
        device=device,
        plot_images=False,
        show_progress_bar=False,
    )

    _ = trainer.train()
    best_model = trainer.load_best_model()

.. rst-class:: sphx-glr-script-out

 .. code-block:: none

    /local/jtachell/deepinv/deepinv/deepinv/training/trainer.py:1356: UserWarning: non_blocking_transfers=True but DataLoader.pin_memory=False; set pin_memory=True to overlap host-device copies with compute.
      self.setup_train()
    The model has 1549462 trainable parameters
    Train epoch 0: TotalLoss=0.061, PSNR=22.094
    Eval epoch 0: PSNR=22.142
    Best model saved at epoch 1
    Train epoch 1: TotalLoss=0.058, PSNR=22.842
    Eval epoch 1: PSNR=20.793
    Train epoch 2: TotalLoss=0.056, PSNR=23.223
    Eval epoch 2: PSNR=19.585
    Train epoch 3: TotalLoss=0.055, PSNR=23.577
    Eval epoch 3: PSNR=19.363
    Train epoch 4: TotalLoss=0.053, PSNR=23.828
    Eval epoch 4: PSNR=18.712
    Train epoch 5: TotalLoss=0.052, PSNR=23.962
    Eval epoch 5: PSNR=18.175
    Train epoch 6: TotalLoss=0.051, PSNR=24.201
    Eval epoch 6: PSNR=18.489
    Train epoch 7: TotalLoss=0.051, PSNR=24.363
    Eval epoch 7: PSNR=17.303
    Train epoch 8: TotalLoss=0.05, PSNR=24.43
    Eval epoch 8: PSNR=17.552
    Train epoch 9: TotalLoss=0.05, PSNR=24.406
    Eval epoch 9: PSNR=17.248
    Model, optimizer, epoch_start successfully loaded from checkpoint: 26-05-24-16:00:06/ckp_best.pth.tar

.. GENERATED FROM PYTHON SOURCE LINES 186-188

Plot a reconstruction from the best checkpoint:

.. GENERATED FROM PYTHON SOURCE LINES 188-207

..
code-block:: Python

    x, y = x.to(device), y.to(device)

    with torch.no_grad():
        x_ft = best_model(y)

    x_bic_ft = torch.nn.functional.interpolate(
        y, scale_factor=4, mode="bicubic", antialias=True
    )

    dinv.utils.plot(
        {"Ground truth": x, "Bicubic": x_bic_ft, "SRResNet": x_ft},
        subtitles=[
            "PSNR (RGB):",
            f"{psnr(x, x_bic_ft).item():.2f} dB",
            f"{psnr(x, x_ft).item():.2f} dB",
        ],
        figsize=(8, 4),
        rescale_mode="clip",
    )

.. image-sg:: /auto_examples/models/images/sphx_glr_demo_super_resolution_004.png
   :alt: Ground truth, Bicubic, SRResNet
   :srcset: /auto_examples/models/images/sphx_glr_demo_super_resolution_004.png
   :class: sphx-glr-single-img

.. GENERATED FROM PYTHON SOURCE LINES 208-211

:References:

.. footbibliography::

.. rst-class:: sphx-glr-timing

   **Total running time of the script:** (0 minutes 6.139 seconds)

.. _sphx_glr_download_auto_examples_models_demo_super_resolution.py:

.. only:: html

  .. container:: sphx-glr-footer sphx-glr-footer-example

    .. container:: sphx-glr-download sphx-glr-download-jupyter

      :download:`Download Jupyter notebook: demo_super_resolution.ipynb `

    .. container:: sphx-glr-download sphx-glr-download-python

      :download:`Download Python source code: demo_super_resolution.py `

    .. container:: sphx-glr-download sphx-glr-download-zip

      :download:`Download zipped: demo_super_resolution.zip `

.. only:: html

 .. rst-class:: sphx-glr-signature

    `Gallery generated by Sphinx-Gallery `_
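
The note in section 1 distinguishes PSNR computed on the luminance (Y) channel from PSNR
computed on all RGB channels. The following is a minimal NumPy sketch of why the two numbers
differ; it is not the evaluation code used for the checkpoint. It assumes the ITU-R BT.601
luma conversion (the one used by MATLAB's ``rgb2ycbcr``), images with values in ``[0, 1]``,
and a peak value of 255 for the Y channel. The helper names ``rgb_to_y`` and ``psnr_db`` and
the random stand-in images are hypothetical and only serve the illustration.

.. code-block:: Python

    import numpy as np


    def rgb_to_y(img):
        """Map an RGB image in [0, 1], shape (H, W, 3), to BT.601 luma in [16, 235]."""
        # Assumed coefficients: ITU-R BT.601, as in MATLAB's rgb2ycbcr.
        weights = np.array([65.481, 128.553, 24.966])
        return 16.0 + img @ weights


    def psnr_db(ref, est, max_val):
        """PSNR in dB between two arrays whose peak value is ``max_val``."""
        mse = np.mean((ref - est) ** 2)
        return 10.0 * np.log10(max_val**2 / mse)


    rng = np.random.default_rng(0)
    # Stand-ins for a ground-truth image and a reconstruction (not real data).
    x_ref = rng.random((32, 32, 3))
    x_est = np.clip(x_ref + 0.05 * rng.standard_normal(x_ref.shape), 0.0, 1.0)

    psnr_rgb = psnr_db(x_ref, x_est, max_val=1.0)
    psnr_y = psnr_db(rgb_to_y(x_ref), rgb_to_y(x_est), max_val=255.0)
    print(f"PSNR (RGB): {psnr_rgb:.2f} dB, PSNR (Y): {psnr_y:.2f} dB")

Because Y is a weighted combination of the three colour channels, errors that are independent
across channels partly average out, so the Y-channel PSNR is typically higher than the RGB PSNR
for the same reconstruction. This is one reason the two kinds of values should not be compared
directly.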