Inference and fine-tuning with a foundation model#

This example shows how to perform inference on and fine-tune the Reconstruct Anything Model (RAM) foundation model [1] to solve inverse problems.

The Reconstruct Anything Model has been trained on a wide range of linear image reconstruction tasks and datasets (deblurring, inpainting, denoising, tomography, MRI, etc.) and is robust across many imaging domains.

import deepinv as dinv
import torch

device = dinv.utils.get_freer_gpu() if torch.cuda.is_available() else "cpu"

model = dinv.models.RAM(device=device, pretrained=True)

1. Zero-shot inference#

First, let’s evaluate the zero-shot inference performance of the foundation model.

Accelerated medical imaging#

Here, we reconstruct a brain MRI image from an accelerated, noisy MRI scan from FastMRI:

x = dinv.utils.load_example("demo_mini_subset_fastmri_brain_0.pt", device=device)

# Define physics
physics = dinv.physics.MRI(noise_model=dinv.physics.GaussianNoise(0.05), device=device)

physics_generator = dinv.physics.generator.GaussianMaskGenerator(
    (320, 320), device=device
)

# Generate measurement
y = physics(x, **physics_generator.step())

# Perform inference
with torch.no_grad():
    x_hat = model(y, physics)
    x_lin = physics.A_adjoint(y)

psnr = dinv.metric.PSNR()

dinv.utils.plot(
    {
        "Ground truth": x,
        f"Linear inverse": x_lin,
        f"Pretrained RAM": x_hat,
    },
    subtitles=[
        "PSNR:",
        f"{psnr(x, x_lin).item():.2f} dB",
        f"{psnr(x, x_hat).item():.2f} dB",
    ],
    figsize=(6, 4),
)
[Figure: Ground truth, Linear inverse, Pretrained RAM]
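
The PSNR values shown in the figure subtitles measure how close each image is to the ground truth. As a rough illustration of the usual definition, PSNR = 10 log10(MAX^2 / MSE), here is a minimal pure-Python sketch. The helper name `psnr_db` and the assumption that pixel values lie in [0, 1] (so MAX = 1) are ours; `dinv.metric.PSNR` may handle batching and normalization differently.

```python
import math


def psnr_db(reference, estimate, max_pixel=1.0):
    """PSNR in dB between two equally sized flat lists of pixel values.

    Hypothetical helper for illustration; assumes pixels lie in [0, max_pixel].
    """
    if len(reference) != len(estimate):
        raise ValueError("inputs must have the same number of pixels")
    # Mean squared error over all pixels
    mse = sum((r - e) ** 2 for r, e in zip(reference, estimate)) / len(reference)
    if mse == 0:
        return math.inf  # identical images
    return 10.0 * math.log10(max_pixel**2 / mse)


# An error of 0.1 on every pixel gives an MSE of about 0.01, i.e. about 20 dB
reference = [0.2, 0.5, 0.8, 1.0]
estimate = [0.3, 0.6, 0.9, 0.9]
print(f"{psnr_db(reference, estimate):.2f} dB")
```

Higher is better: halving the root-mean-square error raises the PSNR by about 6 dB.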

Computational photography#

Joint random motion deblurring and denoising, using a cropped image from color BSD:

x = dinv.utils.load_example("CBSD_0010.png", img_size=(200, 200), device=device)

physics = dinv.physics.BlurFFT(
    img_size=x.shape[1:],
    noise_model=dinv.physics.GaussianNoise(sigma=0.05),
    device=device,
)

# fmt: off
physics_generator = (
    dinv.physics.generator.MotionBlurGenerator((31, 31), l=2.0, sigma=2.4, device=device) +
    dinv.physics.generator.SigmaGenerator(sigma_min=0.001, sigma_max=0.2, device=device)
)
# fmt: on

y = physics(x, **physics_generator.step())

with torch.no_grad():
    x_hat = model(y, physics)
    x_lin = physics.A_adjoint(y)

dinv.utils.plot(
    {
        "Ground truth": x,
        f"Linear inverse": x_lin,
        f"Pretrained RAM": x_hat,
    },
    subtitles=[
        "PSNR:",
        f"{psnr(x, x_lin).item():.2f} dB",
        f"{psnr(x, x_hat).item():.2f} dB",
    ],
    figsize=(6, 4),
)
[Figure: Ground truth, Linear inverse, Pretrained RAM]
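
The "Linear inverse" baseline in these examples is obtained with `physics.A_adjoint(y)`, i.e. by applying the adjoint of the forward operator to the measurement. To make the notion of an adjoint concrete, the sketch below uses a toy subsampling operator (our own stand-in, not the blur or MRI operator from deepinv) and checks the defining identity <A x, y> = <x, A^T y> numerically.

```python
import random


def forward(x):
    """Toy forward operator A: keep every second sample (subsampling)."""
    return x[::2]


def adjoint(y, n):
    """Adjoint A^T: put the samples back in place and fill the gaps with zeros."""
    out = [0.0] * n
    out[::2] = y
    return out


def dot(u, v):
    """Euclidean inner product of two equally sized lists."""
    return sum(a * b for a, b in zip(u, v))


rng = random.Random(0)
n = 8
x = [rng.gauss(0.0, 1.0) for _ in range(n)]
y = [rng.gauss(0.0, 1.0) for _ in range(n // 2)]

# Dot-product test: <A x, y> should equal <x, A^T y> up to rounding error
lhs = dot(forward(x), y)
rhs = dot(x, adjoint(y, n))
print(abs(lhs - rhs) < 1e-12)
```

The zero-filled output of `adjoint` also hints at why the linear baseline looks poor: it leaves the missing information empty rather than estimating it.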

Tomography#

Computed tomography with limited angles, using lung data from The Cancer Imaging Archive:

x = dinv.utils.load_example("CT100_256x256_0.pt", device=device)

physics = dinv.physics.Tomography(
    img_width=256,
    angles=10,
    normalize=True,
    device=device,
)

y = physics(x)

with torch.no_grad():
    x_hat = model(y, physics)
    x_lin = physics.A_dagger(y)

dinv.utils.plot(
    {
        "Ground truth": x,
        f"FBP pseudo-inverse": x_lin,
        f"Pretrained RAM": x_hat,
    },
    subtitles=[
        "PSNR:",
        f"{psnr(x, x_lin).item():.2f} dB",
        f"{psnr(x, x_hat).item():.2f} dB",
    ],
    figsize=(6, 4),
)
[Figure: Ground truth, FBP pseudo-inverse, Pretrained RAM]
Power iteration converged at iteration 7, ||A^T A||_2=2476.22
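
The output line above reports a power-iteration estimate of the spectral norm of A^T A for the tomography operator. The sketch below illustrates the same idea on a small dense matrix; the function names and the stopping rule are our own choices and are not taken from deepinv.

```python
import math


def matvec(mat, vec):
    """Matrix-vector product for a matrix stored as a list of rows."""
    return [sum(a * v for a, v in zip(row, vec)) for row in mat]


def spectral_norm_ata(mat, max_iter=100, tol=1e-10):
    """Estimate ||A^T A||_2, the largest eigenvalue of A^T A, by power iteration."""
    mat_t = [list(col) for col in zip(*mat)]  # transpose of A
    n = len(mat[0])
    x = [1.0 / math.sqrt(n)] * n  # unit-norm starting vector
    estimate = 0.0
    for _ in range(max_iter):
        ax = matvec(mat, x)
        new_estimate = sum(v * v for v in ax)  # Rayleigh quotient x^T A^T A x
        z = matvec(mat_t, ax)  # z = A^T A x
        norm_z = math.sqrt(sum(v * v for v in z))
        if norm_z == 0.0:
            return 0.0  # x lies in the null space of A
        x = [v / norm_z for v in z]
        if abs(new_estimate - estimate) <= tol * new_estimate:
            return new_estimate
        estimate = new_estimate
    return estimate


A = [[1.0, 2.0], [3.0, 4.0]]
# For this A, A^T A = [[10, 14], [14, 20]], whose largest eigenvalue is 15 + sqrt(221)
print(spectral_norm_ata(A), 15 + math.sqrt(221))
```

Only products with A and A^T are needed, which is why the method suits operators such as a Radon transform that are never stored as explicit matrices.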

Remote sensing#

Satellite denoising with Poisson-Gaussian noise using urban data from the WorldView-3 satellite over Jacksonville:

x = dinv.utils.load_example("JAX_018_011_RGB.tif", device=device)[..., :300, :300]

physics = dinv.physics.Denoising(
    noise_model=dinv.physics.PoissonGaussianNoise(sigma=0.1, gain=0.1)
)

y = physics(x)

with torch.no_grad():
    x_hat = model(y, physics)
    # Alternatively, use the model without physics:
    # x_hat = model(y, sigma=0.1, gain=0.1)
    x_lin = physics.A_adjoint(y)

dinv.utils.plot(
    {
        "Ground truth": x,
        f"Linear inverse": x_lin,
        f"Pretrained RAM": x_hat,
    },
    subtitles=[
        "PSNR:",
        f"{psnr(x, x_lin).item():.2f} dB",
        f"{psnr(x, x_hat).item():.2f} dB",
    ],
    figsize=(6, 4),
)
[Figure: Ground truth, Linear inverse, Pretrained RAM]
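
To give a feel for the noise in this example, the sketch below simulates Poisson-Gaussian noise on a single pixel using only the standard library. It assumes the common model y = gain * z + e with z ~ Poisson(x / gain) and e ~ N(0, sigma^2); this is our reading of the model, and `dinv.physics.PoissonGaussianNoise` may parametrize it differently. The function names are hypothetical.

```python
import math
import random


def sample_poisson(lam, rng):
    """Draw one Poisson(lam) sample with Knuth's method (fine for small lam)."""
    threshold = math.exp(-lam)
    count, product = 0, rng.random()
    while product > threshold:
        count += 1
        product *= rng.random()
    return count


def poisson_gaussian(pixel, gain, sigma, rng):
    """Noisy measurement of one pixel: gain * Poisson(pixel / gain) + N(0, sigma^2)."""
    photon_term = gain * sample_poisson(pixel / gain, rng)
    return photon_term + rng.gauss(0.0, sigma)


rng = random.Random(0)
clean = 0.5
samples = [poisson_gaussian(clean, gain=0.1, sigma=0.1, rng=rng) for _ in range(20000)]
mean = sum(samples) / len(samples)
var = sum((s - mean) ** 2 for s in samples) / len(samples)
# Under this model the mean is the clean value and the variance is gain * clean + sigma**2
print(f"mean={mean:.3f}, variance={var:.3f}")
```

Under this model the noise variance grows with the pixel intensity, so bright regions are noisier than dark ones.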

2. Fine-tuning#

As with all models, there may be a drop in performance when used zero-shot on problems or data outside those seen during training.

For instance, RAM is not trained on image demosaicing:

x = dinv.utils.load_example("butterfly.png", img_size=(127, 129), device=device)

physics = dinv.physics.Demosaicing(
    img_size=x.shape[1:], noise_model=dinv.physics.PoissonNoise(0.1), device=device
)

# Generate measurement
y = physics(x)

# Run inference
with torch.no_grad():
    x_hat = model(y, physics)

# Show results
dinv.utils.plot(
    {
        "Original": x,
        f"Measurement": y,
        f"Reconstruction": x_hat,
    },
    subtitles=[
        "PSNR:",
        f"{psnr(x, y).item():.2f} dB",
        f"{psnr(x, x_hat).item():.2f} dB",
    ],
    figsize=(6, 4),
)
[Figure: Original, Measurement, Reconstruction]
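
In demosaicing, each pixel of the measurement retains only one of the three color values. The sketch below applies a Bayer mask to an image stored as nested lists, keeping the 3 x H x W shape and zeroing the unsampled colors. The RGGB layout and the function name are assumptions made for illustration; the layout used by `dinv.physics.Demosaicing` may differ.

```python
def bayer_mosaic(image):
    """Apply an RGGB Bayer mask to an image stored as image[channel][row][col].

    Colors that are not sampled at a location are set to zero, so the output
    keeps the 3 x H x W shape. The RGGB layout is an assumption.
    """
    height, width = len(image[0]), len(image[0][0])
    out = [[[0.0] * width for _ in range(height)] for _ in range(3)]
    for row in range(height):
        for col in range(width):
            if row % 2 == 0 and col % 2 == 0:
                channel = 0  # red
            elif row % 2 == 1 and col % 2 == 1:
                channel = 2  # blue
            else:
                channel = 1  # green
            out[channel][row][col] = image[channel][row][col]
    return out


image = [[[1.0] * 4 for _ in range(4)] for _ in range(3)]  # all-white 4x4 RGB image
mosaic = bayer_mosaic(image)
kept = [sum(map(sum, channel)) for channel in mosaic]
print(kept)  # samples kept per channel (R, G, B): [4.0, 8.0, 4.0]
```

Two thirds of the color values are discarded, so the reconstruction has to fill them in from neighboring pixels and from what the model has learned about natural images.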

To improve results, we can fine-tune the model on our own problem and data. This works even without ground-truth data, by using a self-supervised loss, and even on a single image.

Because this example runs in an environment without a GPU, we train on a small patch of the image to speed things up; in practice, the full image can be used.

Note

You can also fine-tune on a larger dataset by replacing the dataset.

# Take small patch
x_train = x[..., :64, :64]

physics_train = dinv.physics.Demosaicing(
    img_size=x_train.shape[1:],
    noise_model=dinv.physics.PoissonNoise(0.1, clip_positive=True),
    device=device,
)

y_train = physics_train(x_train)

# Define training loss
losses = [
    dinv.loss.R2RLoss(),
    dinv.loss.EILoss(dinv.transform.Shift(shift_max=0.4), weight=0.1),
]

dataset = dinv.datasets.TensorDataset(y=y_train)
train_dataloader = torch.utils.data.DataLoader(dataset)
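
The `EILoss` term above relies on the idea of equivariant imaging: a reconstruction that is shifted, re-measured and reconstructed again should come back unchanged. The sketch below illustrates that idea on a 1D toy problem with the loss || T x1 - f(A T x1) ||^2 / n, where x1 = f(y). The callables `reconstruct` and `forward_op` are hypothetical stand-ins, and this is not deepinv's implementation.

```python
def circular_shift(signal, offset):
    """Shift a 1D signal circularly to the right by `offset` samples."""
    offset %= len(signal)
    return signal[-offset:] + signal[:-offset] if offset else list(signal)


def ei_loss(y, reconstruct, forward_op, offset):
    """Equivariance loss || T x1 - f(A T x1) ||^2 / n with x1 = f(y)."""
    x1 = reconstruct(y)
    x2 = circular_shift(x1, offset)  # transformed reconstruction T x1
    x3 = reconstruct(forward_op(x2))  # re-measure, then reconstruct again
    return sum((a - b) ** 2 for a, b in zip(x2, x3)) / len(x2)


def forward_op(x):
    """Toy operator A: zero out every odd-indexed sample."""
    return [v if i % 2 == 0 else 0.0 for i, v in enumerate(x)]


def reconstruct(y):
    """Naive reconstruction f: return the measurement unchanged."""
    return list(y)


y = forward_op([0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8])
# A shift by one sample moves the content into the samples that A discards,
# so the naive reconstruction is penalized; a shift by two is invisible to A.
print(ei_loss(y, reconstruct, forward_op, offset=1))
print(ei_loss(y, reconstruct, forward_op, offset=2))
```

The penalty at an odd shift is what drives a trainable model to fill in the samples the operator never observes.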

We fine-tune with early stopping, monitored on a validation set that also has no ground truth: a small patch from another set of measurements.

The model has 35618953 trainable parameters
UserWarning: Update progress bar frequency of 1 may slow down training on GPU. Consider increasing this.

Train epoch 1/20: R2RLoss=0.124, EILoss=0.0014, TotalLoss=0.125
Eval epoch 1/20: R2RLoss=3.81
Best model saved at epoch 1
Train epoch 2/20: R2RLoss=0.118, EILoss=0.000364, TotalLoss=0.119
Train epoch 3/20: R2RLoss=0.131, EILoss=0.000236, TotalLoss=0.131
Train epoch 4/20: R2RLoss=0.124, EILoss=0.000263, TotalLoss=0.124
Train epoch 5/20: R2RLoss=0.128, EILoss=0.000179, TotalLoss=0.128
Train epoch 6/20: R2RLoss=0.13, EILoss=0.0002, TotalLoss=0.13
Eval epoch 6/20: R2RLoss=3.84
Train epoch 7/20: R2RLoss=0.126, EILoss=0.000205, TotalLoss=0.126
Train epoch 8/20: R2RLoss=0.128, EILoss=0.000197, TotalLoss=0.128
Train epoch 9/20: R2RLoss=0.127, EILoss=0.000245, TotalLoss=0.127
Train epoch 10/20: R2RLoss=0.123, EILoss=0.000253, TotalLoss=0.123
Train epoch 11/20: R2RLoss=0.125, EILoss=0.000311, TotalLoss=0.125
Eval epoch 11/20: R2RLoss=3.87
Train epoch 12/20: R2RLoss=0.129, EILoss=0.000248, TotalLoss=0.129
Train epoch 13/20: R2RLoss=0.124, EILoss=0.000239, TotalLoss=0.125
Train epoch 14/20: R2RLoss=0.129, EILoss=0.000236, TotalLoss=0.129
Train epoch 15/20: R2RLoss=0.132, EILoss=0.000249, TotalLoss=0.132
Train epoch 16/20: R2RLoss=0.13, EILoss=0.000229, TotalLoss=0.13
Eval epoch 16/20: R2RLoss=3.84
Early stopping triggered as validation metrics have not improved in the last 3 validation steps, disable it with early_stop=False
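
In the log above, the validation loss is evaluated at epochs 1, 6, 11 and 16, and training stops once three validation steps in a row fail to improve on the best value. The sketch below reproduces that patience-based rule in plain Python. The function name and the `patience` parameter are hypothetical; this shows the general idea and not the trainer's actual implementation.

```python
def early_stop_epoch(eval_history, patience=3):
    """Return the epoch at which training stops, or None if it never does.

    eval_history: list of (epoch, validation_loss) pairs in chronological order.
    Training stops once `patience` consecutive validations fail to improve
    on the best loss seen so far.
    """
    best_loss = float("inf")
    stale = 0
    for epoch, loss in eval_history:
        if loss < best_loss:
            best_loss, stale = loss, 0
        else:
            stale += 1
            if stale >= patience:
                return epoch
    return None


# Validation R2RLoss values taken from the log above
history = [(1, 3.81), (6, 3.84), (11, 3.87), (16, 3.84)]
print(early_stop_epoch(history))  # 16
```

Since the best validation loss is the one from epoch 1, the checkpoint saved at that point is the one the log marks as the best model.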

We can now use the fine-tuned model to reconstruct the image from the measurement y.

with torch.no_grad():
    x_hat_ft = finetuned_model(y, physics)

# Show results
dinv.utils.plot(
    {
        "Original": x,
        f"Measurement": y,
        f"Zero-shot \nReconstruction": x_hat,
        f"Fine-tuned \nReconstruction": x_hat_ft,
    },
    subtitles=[
        "PSNR:",
        f"{psnr(x, y).item():.2f} dB",
        f"{psnr(x, x_hat).item():.2f} dB",
        f"{psnr(x, x_hat_ft).item():.2f} dB",
    ],
    figsize=(6, 4),
)
[Figure: Original, Measurement, Zero-shot Reconstruction, Fine-tuned Reconstruction]
References:
