.. DO NOT EDIT.
.. THIS FILE WAS AUTOMATICALLY GENERATED BY SPHINX-GALLERY.
.. TO MAKE CHANGES, EDIT THE SOURCE PYTHON FILE:
.. "auto_examples/models/demo_foundation_model.py"
.. LINE NUMBERS ARE GIVEN BELOW.

.. only:: html

    .. note::
        :class: sphx-glr-download-link-note

        New to DeepInverse? Get started with the basics with the
        :ref:`5 minute quickstart tutorial `.

.. rst-class:: sphx-glr-example-title

.. _sphx_glr_auto_examples_models_demo_foundation_model.py:

Inference and fine-tuning of a foundation model
===============================================

This example shows how to run inference with, and fine-tune, the Reconstruct Anything Model (RAM)
foundation model :footcite:p:`terris2025reconstruct` to solve inverse problems.

The :class:`Reconstruct Anything Model ` is trained on a wide range of linear image reconstruction
tasks and datasets (deblurring, inpainting, denoising, tomography, MRI, etc.) and is robust across
many imaging domains.

.. tip::

    * Want to use your own dataset? See :ref:`sphx_glr_auto_examples_basics_demo_custom_dataset.py`
    * Want to use your own physics? See :ref:`sphx_glr_auto_examples_basics_demo_custom_physics.py`

.. GENERATED FROM PYTHON SOURCE LINES 17-25

.. code-block:: Python

    import deepinv as dinv
    import torch

    device = dinv.utils.get_freer_gpu() if torch.cuda.is_available() else "cpu"

    model = dinv.models.RAM(device=device, pretrained=True)

.. GENERATED FROM PYTHON SOURCE LINES 26-35

1. Zero-shot inference
----------------------

First, let's evaluate the zero-shot inference performance of the foundation model.

Accelerated medical imaging
~~~~~~~~~~~~~~~~~~~~~~~~~~~

Here, we reconstruct a brain MRI image from an accelerated, noisy MRI scan from `FastMRI `_:

.. GENERATED FROM PYTHON SOURCE LINES 35-63

.. code-block:: Python

    x = dinv.utils.load_example("demo_mini_subset_fastmri_brain_0.pt", device=device)

    # Define physics
    physics = dinv.physics.MRI(noise_model=dinv.physics.GaussianNoise(0.05), device=device)

    physics_generator = dinv.physics.generator.GaussianMaskGenerator(
        (320, 320), device=device
    )

    # Generate measurement
    y = physics(x, **physics_generator.step())

    # Perform inference
    with torch.no_grad():
        x_hat = model(y, physics)
        x_lin = physics.A_adjoint(y)

    psnr = dinv.metric.PSNR()

    dinv.utils.plot(
        {
            "Ground truth": x,
            f"Linear inverse\n PSNR {psnr(x_lin, x).item():.2f}dB": x_lin,
            f"Pretrained RAM\n PSNR {psnr(x_hat, x).item():.2f}dB": x_hat,
        }
    )

.. image-sg:: /auto_examples/models/images/sphx_glr_demo_foundation_model_001.png
   :alt: Ground truth, Linear inverse PSNR 29.30dB, Pretrained RAM PSNR 37.11dB
   :srcset: /auto_examples/models/images/sphx_glr_demo_foundation_model_001.png
   :class: sphx-glr-single-img

.. GENERATED FROM PYTHON SOURCE LINES 64-68

Computational photography
~~~~~~~~~~~~~~~~~~~~~~~~~

Joint random motion deblurring and denoising, using a cropped image from color BSD:

.. GENERATED FROM PYTHON SOURCE LINES 68-98

.. code-block:: Python

    x = dinv.utils.load_example("CBSD_0010.png", img_size=(200, 200), device=device)

    physics = dinv.physics.BlurFFT(
        img_size=x.shape[1:],
        noise_model=dinv.physics.GaussianNoise(sigma=0.05),
        device=device,
    )

    # fmt: off
    physics_generator = (
        dinv.physics.generator.MotionBlurGenerator((31, 31), l=2.0, sigma=2.4, device=device) +
        dinv.physics.generator.SigmaGenerator(sigma_min=0.001, sigma_max=0.2, device=device)
    )
    # fmt: on

    y = physics(x, **physics_generator.step())

    with torch.no_grad():
        x_hat = model(y, physics)
        x_lin = physics.A_adjoint(y)

    dinv.utils.plot(
        {
            "Ground truth": x,
            f"Linear inverse\n PSNR {psnr(x_lin, x).item():.2f}dB": x_lin,
            f"Pretrained RAM\n PSNR {psnr(x_hat, x).item():.2f}dB": x_hat,
        }
    )

.. image-sg:: /auto_examples/models/images/sphx_glr_demo_foundation_model_002.png
   :alt: Ground truth, Linear inverse PSNR 16.74dB, Pretrained RAM PSNR 22.52dB
   :srcset: /auto_examples/models/images/sphx_glr_demo_foundation_model_002.png
   :class: sphx-glr-single-img

.. GENERATED FROM PYTHON SOURCE LINES 99-104

Tomography
~~~~~~~~~~

Limited-angle computed tomography of lungs, using data from `The Cancer Imaging Archive `_:

.. GENERATED FROM PYTHON SOURCE LINES 104-128

.. code-block:: Python

    x = dinv.utils.load_example("CT100_256x256_0.pt", device=device)

    physics = dinv.physics.Tomography(
        img_width=256,
        angles=10,
        normalize=True,
        device=device,
    )

    y = physics(x)

    with torch.no_grad():
        x_hat = model(y, physics)
        x_lin = physics.A_dagger(y)

    dinv.utils.plot(
        {
            "Ground truth": x,
            f"FBP pseudo-inverse\n PSNR {psnr(x_lin, x).item():.2f}dB": x_lin,
            f"Pretrained RAM\n PSNR {psnr(x_hat, x).item():.2f}dB": x_hat,
        }
    )

.. image-sg:: /auto_examples/models/images/sphx_glr_demo_foundation_model_003.png
   :alt: Ground truth, FBP pseudo-inverse PSNR 12.87dB, Pretrained RAM PSNR 24.02dB
   :srcset: /auto_examples/models/images/sphx_glr_demo_foundation_model_003.png
   :class: sphx-glr-single-img

.. GENERATED FROM PYTHON SOURCE LINES 129-134

Remote sensing
~~~~~~~~~~~~~~

Satellite denoising with Poisson-Gaussian noise, using urban data from the `WorldView-3 satellite `_
over Jacksonville:

.. GENERATED FROM PYTHON SOURCE LINES 134-158

.. code-block:: Python

    x = dinv.utils.load_example("JAX_018_011_RGB.tif", device=device)[..., :300, :300]

    physics = dinv.physics.Denoising(
        noise_model=dinv.physics.PoissonGaussianNoise(sigma=0.1, gain=0.1)
    )

    y = physics(x)

    with torch.no_grad():
        x_hat = model(y, physics)
        # Alternatively, use the model without physics:
        # x_hat = model(y, sigma=0.1, gain=0.1)
        x_lin = physics.A_adjoint(y)

    dinv.utils.plot(
        {
            "Ground truth": x,
            f"Linear inverse\n PSNR {psnr(x_lin, x).item():.2f}dB": x_lin,
            f"Pretrained RAM\n PSNR {psnr(x_hat, x).item():.2f}dB": x_hat,
        }
    )

.. image-sg:: /auto_examples/models/images/sphx_glr_demo_foundation_model_004.png
   :alt: Ground truth, Linear inverse PSNR 12.48dB, Pretrained RAM PSNR 27.51dB
   :srcset: /auto_examples/models/images/sphx_glr_demo_foundation_model_004.png
   :class: sphx-glr-single-img

.. GENERATED FROM PYTHON SOURCE LINES 159-164

2. Fine-tuning
--------------

As with any model, performance may drop when RAM is used zero-shot on problems or data outside
those seen during training.

For instance, RAM is not trained on image demosaicing:

.. GENERATED FROM PYTHON SOURCE LINES 164-187

.. code-block:: Python

    x = dinv.utils.load_example("butterfly.png", img_size=(127, 129), device=device)

    physics = dinv.physics.Demosaicing(
        img_size=x.shape[1:], noise_model=dinv.physics.PoissonNoise(0.1), device=device
    )

    # Generate measurement
    y = physics(x)

    # Run inference
    with torch.no_grad():
        x_hat = model(y, physics)

    # Show results
    dinv.utils.plot(
        {
            "Original": x,
            f"Measurement\n PSNR {psnr(y, x).item():.2f}dB": y,
            f"Reconstruction\n PSNR {psnr(x_hat, x).item():.2f}dB": x_hat,
        },
    )

.. image-sg:: /auto_examples/models/images/sphx_glr_demo_foundation_model_005.png
   :alt: Original, Measurement PSNR 5.99dB, Reconstruction PSNR 21.37dB
   :srcset: /auto_examples/models/images/sphx_glr_demo_foundation_model_005.png
   :class: sphx-glr-single-img

.. GENERATED FROM PYTHON SOURCE LINES 188-197

To improve results, we can fine-tune the model on our problem and data,
**even in the absence of ground truth data**, using a :ref:`self-supervised loss `,
and **even on a single image only**.

Because this example runs in an environment without a GPU, we use a small patch of the image to
speed up training. In practice, the full image can be used.

.. note::

    You can also fine-tune on larger datasets by replacing the :ref:`dataset `.

.. GENERATED FROM PYTHON SOURCE LINES 197-218

.. code-block:: Python

    # Take a small patch
    x_train = x[..., :64, :64]

    physics_train = dinv.physics.Demosaicing(
        img_size=x_train.shape[1:],
        noise_model=dinv.physics.PoissonNoise(0.1, clip_positive=True),
        device=device,
    )

    y_train = physics_train(x_train)

    # Define training loss
    losses = [
        dinv.loss.R2RLoss(),
        dinv.loss.EILoss(dinv.transform.Shift(shift_max=0.4), weight=0.1),
    ]

    dataset = dinv.datasets.TensorDataset(y=y_train)
    train_dataloader = torch.utils.data.DataLoader(dataset)

.. GENERATED FROM PYTHON SOURCE LINES 219-221

We fine-tune with early stopping on a validation set, again without ground truth.
The validation set is a small patch of another set of measurements.

.. GENERATED FROM PYTHON SOURCE LINES 221-248

.. code-block:: Python

    eval_dataloader = torch.utils.data.DataLoader(
        dinv.datasets.TensorDataset(
            y=physics_train(
                dinv.utils.load_example("leaves.png", device=device)[..., :64, :64]
            )
        )
    )

    max_epochs = 20
    trainer = dinv.Trainer(
        model=model,
        physics=physics_train,
        eval_interval=5,
        ckp_interval=max_epochs - 1,
        metrics=losses[0],
        early_stop=True,
        device=device,
        losses=losses,
        epochs=max_epochs,
        optimizer=torch.optim.Adam(model.parameters(), lr=5e-5),
        train_dataloader=train_dataloader,
        eval_dataloader=eval_dataloader,
    )

    finetuned_model = trainer.train()

.. rst-class:: sphx-glr-script-out

 .. code-block:: none

    The model has 35618953 trainable parameters
    0%| | 0/1 [00:00
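
The figure titles in this example report PSNR in dB between a reconstruction and the ground truth.
As a rough illustration of what that number measures, here is a minimal pure-Python sketch. It
assumes pixel intensities in ``[0, 1]`` (so the peak value is 1) and images given as flat lists of
floats. The helper name ``psnr_db`` is hypothetical and is not part of DeepInverse;
``dinv.metric.PSNR`` may differ in details such as batching and how the peak value is chosen.

.. code-block:: Python

    import math


    def psnr_db(x_hat, x, max_pixel=1.0):
        """Peak signal-to-noise ratio in dB (hypothetical helper, not the library's)."""
        if len(x_hat) != len(x):
            raise ValueError("x_hat and x must have the same number of pixels")
        # Mean squared error between reconstruction and reference
        mse = sum((a - b) ** 2 for a, b in zip(x_hat, x)) / len(x)
        if mse == 0:
            return math.inf  # identical images
        # PSNR = 10 * log10(peak^2 / MSE)
        return 10 * math.log10(max_pixel**2 / mse)


    # Toy reference "image" and a reconstruction that is off by 0.1 everywhere
    reference = [0.0, 0.5, 1.0, 0.25]
    reconstruction = [v + 0.1 for v in reference]

    # The MSE is 0.01, so the PSNR is about 20 dB
    print(f"PSNR: {psnr_db(reconstruction, reference):.2f} dB")

Because PSNR is logarithmic in the mean squared error, every 10 dB gained corresponds to a tenfold
reduction in MSE, which helps when comparing the pretrained and linear-inverse results above.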