Inference and fine-tuning of a foundation model

This example shows how to perform inference on and fine-tune the Reconstruct Anything Model (RAM) foundation model [1] to solve inverse problems.

RAM is a single model trained on a large variety of linear image reconstruction tasks and datasets (deblurring, inpainting, denoising, tomography, MRI, etc.), and it is robust to a wide variety of imaging domains.
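Each of these tasks observes an image x through a linear operator A, with added noise, and RAM receives both the measurement y and the operator (the `physics` object in deepinv). As a library-independent sketch of what such an operator provides, here is a pixel-subsampling operator in plain PyTorch together with its adjoint, which is what most "linear inverse" baselines below compute via `physics.A_adjoint`. The `SubsampleOperator` class is hypothetical and is not deepinv's implementation.

```python
import torch


class SubsampleOperator:
    """Hypothetical linear forward operator: keep every `factor`-th pixel along H and W."""

    def __init__(self, img_shape, factor=2):
        self.img_shape = img_shape
        self.factor = factor

    def A(self, x: torch.Tensor) -> torch.Tensor:
        # Forward operator: keep one pixel in each factor x factor block
        return x[..., :: self.factor, :: self.factor]

    def A_adjoint(self, y: torch.Tensor) -> torch.Tensor:
        # Adjoint of subsampling is zero-filling: put each measurement back at its pixel
        x = torch.zeros(self.img_shape, dtype=y.dtype)
        x[..., :: self.factor, :: self.factor] = y
        return x


torch.manual_seed(0)
img_shape = (1, 3, 32, 32)
op = SubsampleOperator(img_shape, factor=2)

x = torch.rand(img_shape, dtype=torch.float64)
y = op.A(x) + 0.05 * torch.randn(1, 3, 16, 16, dtype=torch.float64)  # y = A x + noise
x_lin = op.A_adjoint(y)  # adjoint "linear inverse" baseline

# Adjoint identity <A u, v> = <u, A^T v>, checked on random u and v
u = torch.rand(img_shape, dtype=torch.float64)
v = torch.rand(1, 3, 16, 16, dtype=torch.float64)
lhs = torch.sum(op.A(u) * v)
rhs = torch.sum(u * op.A_adjoint(v))
print(torch.allclose(lhs, rhs))  # True
```

The adjoint is cheap and always defined, which is why it serves as a baseline, but it only places the measurements back on the image grid; it does not fill in what is missing.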


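import deepinv as dinv
import torch

# Use the GPU with the most free memory if one is available, otherwise the CPU
device = dinv.utils.get_freer_gpu() if torch.cuda.is_available() else "cpu"

# Load RAM with its pretrained weights
model = dinv.models.RAM(device=device, pretrained=True)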

1. Zero-shot inference

First, let’s evaluate the zero-shot inference performance of the foundation model.

Accelerated medical imaging

Here, we reconstruct a brain MRI image from an accelerated, noisy MRI scan from the FastMRI dataset:

x = dinv.utils.load_example("demo_mini_subset_fastmri_brain_0.pt", device=device)

# Define physics
physics = dinv.physics.MRI(noise_model=dinv.physics.GaussianNoise(0.05), device=device)

physics_generator = dinv.physics.generator.GaussianMaskGenerator(
    (320, 320), device=device
)

# Generate measurement
y = physics(x, **physics_generator.step())

# Perform inference
with torch.no_grad():
    x_hat = model(y, physics)
    x_lin = physics.A_adjoint(y)

psnr = dinv.metric.PSNR()

dinv.utils.plot(
    {
        "Ground truth": x,
        f"Linear inverse\n PSNR {psnr(x_lin, x).item():.2f}dB": x_lin,
        f"Pretrained RAM\n PSNR {psnr(x_hat, x).item():.2f}dB": x_hat,
    }
)
[Figure: Ground truth | Linear inverse, PSNR 29.30 dB | Pretrained RAM, PSNR 37.11 dB]
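In single-coil Cartesian MRI, the forward operator is usually modelled as a masked 2D Fourier transform, so the adjoint used above as the linear baseline amounts to zero-filling the missing k-space samples before an inverse FFT. The sketch below illustrates that, together with a PSNR computation, in plain PyTorch. Assumptions: an orthonormal FFT, a hypothetical random column mask (not the pattern drawn by `GaussianMaskGenerator`), complex Gaussian noise, and a PSNR peak value of 1 for images in [0, 1]; none of this is deepinv's code.

```python
import torch


def mri_forward(x, mask, sigma=0.0):
    """Single-coil Cartesian MRI: orthonormal 2D FFT, complex Gaussian noise, then masking."""
    kspace = torch.fft.fft2(x, norm="ortho")
    noise = sigma * (torch.randn_like(x) + 1j * torch.randn_like(x)) / 2**0.5
    return mask * (kspace + noise)


def mri_adjoint(y, mask):
    """Zero-filled reconstruction: unsampled k-space stays zero, then inverse FFT."""
    return torch.fft.ifft2(mask * y, norm="ortho").real


def psnr(x_hat, x, max_pixel=1.0):
    """PSNR in dB, assuming images scaled to [0, max_pixel]."""
    mse = torch.mean((x_hat - x) ** 2)
    return 10 * torch.log10(max_pixel**2 / mse)


torch.manual_seed(0)
H = W = 64
x = torch.rand(1, 1, H, W)

# Hypothetical mask: the lowest-frequency columns plus about 20% of the others at random
freqs = torch.fft.fftfreq(W)
cols = (freqs.abs() < 4 / W) | (torch.rand(W) < 0.2)
mask = cols.float().view(1, 1, 1, W)  # broadcasts over rows: whole k-space lines are kept

y = mri_forward(x, mask, sigma=0.05)
x_zf = mri_adjoint(y, mask)
print(f"zero-filled PSNR: {psnr(x_zf, x).item():.2f} dB")
```

With a fully sampled mask and no noise, `mri_adjoint` returns the image exactly, because the orthonormal FFT is unitary; the loss of quality comes from the discarded k-space lines and the noise.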

Computational photography

Joint random motion deblurring and denoising on a cropped image from the color BSD dataset (CBSD):

x = dinv.utils.load_example("CBSD_0010.png", img_size=(200, 200), device=device)

physics = dinv.physics.BlurFFT(
    img_size=x.shape[1:],
    noise_model=dinv.physics.GaussianNoise(sigma=0.05),
    device=device,
)

# fmt: off
physics_generator = (
    dinv.physics.generator.MotionBlurGenerator((31, 31), l=2.0, sigma=2.4, device=device) +
    dinv.physics.generator.SigmaGenerator(sigma_min=0.001, sigma_max=0.2, device=device)
)
# fmt: on

y = physics(x, **physics_generator.step())

with torch.no_grad():
    x_hat = model(y, physics)
    x_lin = physics.A_adjoint(y)

dinv.utils.plot(
    {
        "Ground truth": x,
        f"Linear inverse\n PSNR {psnr(x_lin, x).item():.2f}dB": x_lin,
        f"Pretrained RAM\n PSNR {psnr(x_hat, x).item():.2f}dB": x_hat,
    }
)
[Figure: Ground truth | Linear inverse, PSNR 16.74 dB | Pretrained RAM, PSNR 22.52 dB]
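`BlurFFT` models blur as a convolution evaluated in the Fourier domain. The sketch below shows that idea in plain PyTorch, assuming periodic (circular) boundaries and a kernel zero-padded to the image size and centred at the origin; `kernel_to_otf`, `blur` and `blur_adjoint` are hypothetical helpers, not deepinv functions. It also shows why the adjoint baseline above stays blurry: the adjoint multiplies by the conjugate transfer function, i.e. it re-blurs with the flipped kernel instead of undoing the blur.

```python
import torch


def kernel_to_otf(kernel, img_hw):
    """Zero-pad a (kh, kw) kernel to the image size, centre it at pixel (0, 0), and FFT it."""
    kh, kw = kernel.shape
    padded = torch.zeros(img_hw, dtype=kernel.dtype)
    padded[:kh, :kw] = kernel
    padded = torch.roll(padded, shifts=(-(kh // 2), -(kw // 2)), dims=(0, 1))
    return torch.fft.fft2(padded)


def blur(x, otf):
    """Circular convolution with the kernel: pointwise product in the Fourier domain."""
    return torch.fft.ifft2(torch.fft.fft2(x) * otf).real


def blur_adjoint(y, otf):
    """Adjoint = circular correlation: multiply by the conjugate transfer function."""
    return torch.fft.ifft2(torch.fft.fft2(y) * otf.conj()).real


# Hypothetical 7-pixel horizontal motion kernel (sums to 1, so flat regions are preserved)
kernel = torch.zeros(7, 7, dtype=torch.float64)
kernel[3, :] = 1 / 7
otf = kernel_to_otf(kernel, (64, 64))

torch.manual_seed(0)
x = torch.rand(1, 3, 64, 64, dtype=torch.float64)
y = blur(x, otf)
x_lin = blur_adjoint(y, otf)  # re-blurs with the flipped kernel rather than deblurring
```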

Tomography

Computed tomography with a limited number of angles, using lung data from The Cancer Imaging Archive:

x = dinv.utils.load_example("CT100_256x256_0.pt", device=device)

physics = dinv.physics.Tomography(
    img_width=256,
    angles=10,
    normalize=True,
    device=device,
)

y = physics(x)

with torch.no_grad():
    x_hat = model(y, physics)
    x_lin = physics.A_dagger(y)

dinv.utils.plot(
    {
        "Ground truth": x,
        f"FBP pseudo-inverse\n PSNR {psnr(x_lin, x).item():.2f}dB": x_lin,
        f"Pretrained RAM\n PSNR {psnr(x_hat, x).item():.2f}dB": x_hat,
    }
)
[Figure: Ground truth | FBP pseudo-inverse, PSNR 12.87 dB | Pretrained RAM, PSNR 24.02 dB]
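The `Tomography` operator computes a parallel-beam Radon transform, and `angles=10` keeps only ten projection angles (assumed here to be equally spaced over [0°, 180°)). The crude rotate-and-sum Radon transform below, written in plain PyTorch, makes the resulting sinogram sizes concrete. `radon_parallel` is a hypothetical helper that ignores detector geometry and normalization; it is not deepinv's implementation.

```python
import math

import torch
import torch.nn.functional as F


def radon_parallel(x, n_angles):
    """Crude parallel-beam Radon transform of x (B, C, H, W).

    Angles are equally spaced in [0, 180) degrees. For each angle the image is
    rotated (bilinear interpolation, zero padding) and integrated along the rows.
    Returns a sinogram of shape (B, C, n_angles, W).
    """
    b = x.shape[0]
    projections = []
    for k in range(n_angles):
        theta = math.pi * k / n_angles
        c, s = math.cos(theta), math.sin(theta)
        affine = torch.tensor([[c, -s, 0.0], [s, c, 0.0]], dtype=x.dtype)
        grid = F.affine_grid(affine.unsqueeze(0).repeat(b, 1, 1), list(x.shape), align_corners=False)
        rotated = F.grid_sample(x, grid, mode="bilinear", padding_mode="zeros", align_corners=False)
        projections.append(rotated.sum(dim=-2))  # line integrals -> (B, C, W)
    return torch.stack(projections, dim=-2)


# Centred Gaussian blob on a 64 x 64 grid (stays inside the field of view under rotation)
coords = torch.arange(64, dtype=torch.float32) - 31.5
r2 = coords.view(-1, 1) ** 2 + coords.view(1, -1) ** 2
x = torch.exp(-r2 / (2 * 6.0**2)).view(1, 1, 64, 64)

sparse = radon_parallel(x, n_angles=10)  # ten views
dense = radon_parallel(x, n_angles=180)  # one view per degree, for comparison
print(sparse.shape, dense.shape)  # torch.Size([1, 1, 10, 64]) torch.Size([1, 1, 180, 64])
```

With only ten rows in the sinogram instead of one per degree, a filtered back-projection has far less data to work with, which is consistent with the low FBP PSNR above.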

Remote sensing

Satellite image denoising with Poisson-Gaussian noise, using urban data from the WorldView-3 satellite over Jacksonville:

x = dinv.utils.load_example("JAX_018_011_RGB.tif", device=device)[..., :300, :300]

physics = dinv.physics.Denoising(
    noise_model=dinv.physics.PoissonGaussianNoise(sigma=0.1, gain=0.1)
)

y = physics(x)

with torch.no_grad():
    x_hat = model(y, physics)
    # Alternatively, use the model without physics:
    # x_hat = model(y, sigma=0.1, gain=0.1)
    x_lin = physics.A_adjoint(y)

dinv.utils.plot(
    {
        "Ground truth": x,
        f"Linear inverse\n PSNR {psnr(x_lin, x).item():.2f}dB": x_lin,
        f"Pretrained RAM\n PSNR {psnr(x_hat, x).item():.2f}dB": x_hat,
    }
)
[Figure: Ground truth | Linear inverse, PSNR 12.48 dB | Pretrained RAM, PSNR 27.51 dB]
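Poisson-Gaussian noise makes the noise level depend on the signal. A common parametrization, which we assume here for `PoissonGaussianNoise(sigma, gain)`, is y = gain · z + sigma · ε with z ~ Poisson(x / gain) and ε ~ N(0, 1), so that E[y] = x and Var[y] = gain · x + sigma². The plain-PyTorch sketch below uses a hypothetical `poisson_gaussian` helper to check these moments empirically at several intensities, with the values used above.

```python
import torch


def poisson_gaussian(x, sigma=0.1, gain=0.1):
    """Assumed model: y = gain * Poisson(x / gain) + sigma * N(0, 1)."""
    return gain * torch.poisson(x / gain) + sigma * torch.randn_like(x)


torch.manual_seed(0)
for level in (0.1, 0.5, 0.9):
    x = torch.full((200_000,), level)  # many pixels with the same intensity
    y = poisson_gaussian(x, sigma=0.1, gain=0.1)
    predicted_var = 0.1 * level + 0.1**2  # gain * x + sigma^2
    print(
        f"x={level}: mean={y.mean().item():.3f}, "
        f"var={y.var().item():.4f}, predicted var={predicted_var:.4f}"
    )
```

Brighter pixels are noisier in absolute terms, so a single Gaussian noise level cannot describe this measurement.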

2. Fine-tuning

As with any model, performance may drop when RAM is applied zero-shot to problems or data outside those seen during training.

For instance, RAM was not trained on image demosaicing:

x = dinv.utils.load_example("butterfly.png", img_size=(127, 129), device=device)

physics = dinv.physics.Demosaicing(
    img_size=x.shape[1:], noise_model=dinv.physics.PoissonNoise(0.1), device=device
)

# Generate measurement
y = physics(x)

# Run inference
with torch.no_grad():
    x_hat = model(y, physics)

# Show results
dinv.utils.plot(
    {
        "Original": x,
        f"Measurement\n PSNR {psnr(y, x).item():.2f}dB": y,
        f"Reconstruction\n PSNR {psnr(x_hat, x).item():.2f}dB": x_hat,
    },
)
[Figure: Original | Measurement, PSNR 5.99 dB | Reconstruction, PSNR 21.37 dB]
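Demosaicing is a masking problem: a color filter array keeps a single color channel at each pixel, so two thirds of the color samples are missing, which is consistent with the very low PSNR of the raw measurement above. The sketch below builds an RGGB Bayer mask in plain PyTorch; the RGGB layout is an assumption for illustration, and deepinv's `Demosaicing` operator may use a different pattern.

```python
import torch


def bayer_mask(h, w):
    """RGGB Bayer color-filter mask of shape (1, 3, h, w) with exactly one 1 per pixel."""
    mask = torch.zeros(1, 3, h, w)
    mask[:, 0, 0::2, 0::2] = 1  # red: even rows, even columns
    mask[:, 1, 0::2, 1::2] = 1  # green: even rows, odd columns
    mask[:, 1, 1::2, 0::2] = 1  # green: odd rows, even columns
    mask[:, 2, 1::2, 1::2] = 1  # blue: odd rows, odd columns
    return mask


mask = bayer_mask(127, 129)  # same odd-sized image as in this example
torch.manual_seed(0)
x = torch.rand(1, 3, 127, 129)
y = mask * x  # noiseless mosaic: missing channels are zero

print(mask.sum(dim=1).unique())  # tensor([1.]): one channel kept per pixel
print(mask.mean(dim=(0, 2, 3)))  # fraction of pixels per channel, about 0.25, 0.5, 0.25
```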

To improve results, we can fine-tune the model on our own problem and data using a self-supervised loss. This requires no ground truth and works even with a single image.

Because this example runs without a GPU, we train on a small patch of the image to speed things up; in practice, you can use the full image.

Note

You can also fine-tune on a larger dataset by replacing the dataset defined below.
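Recorrupted-to-Recorrupted (R2R) turns a single noisy measurement into a training pair. In its Gaussian form, it draws extra noise z with the same level sigma as the measurement noise and builds an input y1 = y + alpha * z and a target y2 = y - z / alpha; the noise in y1 and in y2 is then uncorrelated, so regressing A f(y1) onto y2 needs no clean data. Below is a sketch of this Gaussian variant in plain PyTorch with hypothetical helper names; deepinv's `R2RLoss` may recorrupt differently, in particular for the Poisson noise used in this demosaicing problem.

```python
import torch


def r2r_pair(y, sigma, alpha=0.5):
    """Gaussian R2R recorruption of one noisy measurement into (input, target)."""
    z = sigma * torch.randn_like(y)  # fresh noise with the same level as the measurement noise
    return y + alpha * z, y - z / alpha


def r2r_loss(model, A, y, sigma, alpha=0.5):
    """Self-supervised loss ||A(model(y1)) - y2||^2: only noisy measurements are needed."""
    y1, y2 = r2r_pair(y, sigma, alpha)
    return torch.mean((A(model(y1)) - y2) ** 2)


# Key property, checked on pure noise (x = 0, A = identity): y1 and y2 come from the
# same y, yet E[(n + alpha z)(n - z / alpha)] = sigma^2 - sigma^2 = 0
torch.manual_seed(0)
sigma = 0.1
y = sigma * torch.randn(500_000)
y1, y2 = r2r_pair(y, sigma)
print(torch.mean(y * y).item())  # about sigma^2 = 0.01: y is correlated with itself
print(torch.mean(y1 * y2).item())  # about 0: the recorrupted pair is not

# Usage with a toy "model" that halves its input and an identity operator
loss = r2r_loss(lambda t: 0.5 * t, lambda t: t, y, sigma)
```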

# Take a small patch of the image to keep CPU training fast
x_train = x[..., :64, :64]

physics_train = dinv.physics.Demosaicing(
    img_size=x_train.shape[1:],
    noise_model=dinv.physics.PoissonNoise(0.1, clip_positive=True),
    device=device,
)

y_train = physics_train(x_train)

# Define self-supervised training losses (no ground truth required)
losses = [
    dinv.loss.R2RLoss(),
    dinv.loss.EILoss(dinv.transform.Shift(shift_max=0.4), weight=0.1),
]

# The training set contains measurements only, no ground truth
dataset = dinv.datasets.TensorDataset(y=y_train)
train_dataloader = torch.utils.data.DataLoader(dataset)
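The second loss, equivariant imaging (EI), exploits the assumption that the set of images is invariant to a group of transformations, here shifts: reconstruct x1 = f(y), transform it to x2 = T x1, simulate new measurements A x2, and require f(A x2) ≈ x2. The sketch below assumes integer circular shifts via `torch.roll` as the transformation; the helper is hypothetical, and deepinv's `EILoss` with `dinv.transform.Shift` configures its own shifts and may differ in details.

```python
import torch


def ei_loss(model, A, y, max_shift=8):
    """Equivariant-imaging loss using a random circular shift as the transformation T."""
    x1 = model(y)  # reconstruction from the real measurements
    dh, dw = torch.randint(1, max_shift + 1, (2,)).tolist()  # non-zero shift in pixels
    x2 = torch.roll(x1, shifts=(dh, dw), dims=(-2, -1))  # x2 = T x1
    x3 = model(A(x2))  # reconstruction from simulated measurements of x2
    return torch.mean((x3 - x2) ** 2)


def identity(t):
    return t


torch.manual_seed(0)
x = torch.rand(1, 3, 32, 32)
mask = (torch.rand(1, 3, 32, 32) > 0.5).float()


def inpaint(t):
    # Toy operator that is not shift-equivariant: a fixed random pixel mask
    return mask * t


# With A = identity, the identity "model" is trivially consistent, so the loss is zero.
print(ei_loss(identity, identity, x).item())  # 0.0
# With masking, shifted pixels land on unobserved positions that the identity
# "model" cannot fill in, so EI penalizes it.
print(ei_loss(identity, inpaint, inpaint(x)).item() > 0)  # True
```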

We fine-tune with early stopping on a validation set, again without ground truth: the validation data are a small patch of another set of measurements.

The model has 35618953 trainable parameters

Train epoch 1/20: 100% 1/1 [00:01<00:00, 1.63s/it, R2RLoss=0.127, EILoss=0.00117, TotalLoss=0.128]
Eval epoch 1/20: 100% 1/1 [00:01<00:00, 1.42s/it, R2RLoss=3.67]
Best model saved at epoch 1
Train epoch 2/20: 100% 1/1 [00:01<00:00, 1.56s/it, R2RLoss=0.122, EILoss=0.000238, TotalLoss=0.123]
Train epoch 3/20: 100% 1/1 [00:01<00:00, 1.56s/it, R2RLoss=0.128, EILoss=0.000214, TotalLoss=0.128]
Train epoch 4/20: 100% 1/1 [00:01<00:00, 1.57s/it, R2RLoss=0.129, EILoss=0.000209, TotalLoss=0.129]
Train epoch 5/20: 100% 1/1 [00:01<00:00, 1.57s/it, R2RLoss=0.126, EILoss=0.000215, TotalLoss=0.126]
Train epoch 6/20: 100% 1/1 [00:01<00:00, 1.56s/it, R2RLoss=0.121, EILoss=0.000266, TotalLoss=0.122]
Eval epoch 6/20: 100% 1/1 [00:01<00:00, 1.43s/it, R2RLoss=3.71]
Train epoch 7/20: 100% 1/1 [00:01<00:00, 1.56s/it, R2RLoss=0.13, EILoss=0.000209, TotalLoss=0.13]
Train epoch 8/20: 100% 1/1 [00:01<00:00, 1.56s/it, R2RLoss=0.121, EILoss=0.000281, TotalLoss=0.122]
Train epoch 9/20: 100% 1/1 [00:01<00:00, 1.56s/it, R2RLoss=0.126, EILoss=0.000263, TotalLoss=0.126]
Train epoch 10/20: 100% 1/1 [00:01<00:00, 1.56s/it, R2RLoss=0.122, EILoss=0.000253, TotalLoss=0.122]
Train epoch 11/20: 100% 1/1 [00:01<00:00, 1.56s/it, R2RLoss=0.129, EILoss=0.000225, TotalLoss=0.129]
Eval epoch 11/20: 100% 1/1 [00:01<00:00, 1.43s/it, R2RLoss=3.69]
Train epoch 12/20: 100% 1/1 [00:01<00:00, 1.56s/it, R2RLoss=0.128, EILoss=0.000226, TotalLoss=0.128]
Train epoch 13/20: 100% 1/1 [00:01<00:00, 1.55s/it, R2RLoss=0.121, EILoss=0.000183, TotalLoss=0.121]
Train epoch 14/20: 100% 1/1 [00:01<00:00, 1.56s/it, R2RLoss=0.125, EILoss=0.000218, TotalLoss=0.125]
Train epoch 15/20: 100% 1/1 [00:01<00:00, 1.57s/it, R2RLoss=0.126, EILoss=0.000207, TotalLoss=0.126]
Train epoch 16/20: 100% 1/1 [00:01<00:00, 1.56s/it, R2RLoss=0.126, EILoss=0.000236, TotalLoss=0.127]
Eval epoch 16/20: 100% 1/1 [00:01<00:00, 1.42s/it, R2RLoss=3.72]
Early stopping triggered as validation metrics have not improved in the last 3 validation steps, disable it with early_stop=False
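The log above shows validation every five epochs (epochs 1, 6, 11 and 16), the best model kept at epoch 1, and training halted after three validation steps without improvement. The sketch below reproduces that decision rule on the logged validation values. It is a hypothetical helper written to mirror the log, not deepinv's `Trainer` code, and it assumes that a lower validation R2R loss is better.

```python
def early_stopping(val_history, patience=3):
    """Return (best_epoch, stop_epoch) for a list of (epoch, val_loss); lower is better.

    stop_epoch is None if training would run to the end.
    """
    best_epoch, best_loss, bad_steps = None, float("inf"), 0
    for epoch, loss in val_history:
        if loss < best_loss:
            best_epoch, best_loss, bad_steps = epoch, loss, 0  # new best: keep this model
        else:
            bad_steps += 1  # one more validation step without improvement
            if bad_steps >= patience:
                return best_epoch, epoch
    return best_epoch, None


# Validation R2RLoss values copied from the training log above
val_history = [(1, 3.67), (6, 3.71), (11, 3.69), (16, 3.72)]
print(early_stopping(val_history))  # (1, 16)
```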

We can now use the fine-tuned model to reconstruct the image from the measurement y.

with torch.no_grad():
    x_hat_ft = finetuned_model(y, physics)

# Show results
dinv.utils.plot(
    {
        "Original": x,
        f"Measurement\n PSNR {psnr(y, x).item():.2f}dB": y,
        f"Zero-shot reconstruction\n PSNR {psnr(x_hat, x).item():.2f}dB": x_hat,
        f"Fine-tuned reconstruction\n PSNR {psnr(x_hat_ft, x).item():.2f}dB": x_hat_ft,
    },
)
[Figure: Original | Measurement, PSNR 5.99 dB | Zero-shot reconstruction, PSNR 21.37 dB | Fine-tuned reconstruction, PSNR 24.45 dB]
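To put the fine-tuning gain in perspective: since PSNR = 10 log10(MAX² / MSE), a PSNR difference converts directly into an MSE ratio. For the single image above, and assuming both PSNR values use the same peak value, going from 21.37 dB to 24.45 dB gives:

```latex
\frac{\mathrm{MSE}_{\text{zero-shot}}}{\mathrm{MSE}_{\text{fine-tuned}}}
  = 10^{(24.45 - 21.37)/10}
  = 10^{0.308}
  \approx 2.03
```

In other words, fine-tuning roughly halves the mean squared error on this image.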
References:

Total running time of the script: (0 minutes 54.152 seconds)

Gallery generated by Sphinx-Gallery