Note
New to DeepInverse? Learn the basics with the 5-minute quickstart tutorial.
Inference and fine-tuning of a foundation model
This example shows how to run inference with, and fine-tune, the Reconstruct Anything Model (RAM) foundation model [1] to solve inverse problems.
The Reconstruct Anything Model was trained on a large variety of linear image reconstruction tasks and datasets (deblurring, inpainting, denoising, tomography, MRI, etc.) and is robust to a wide variety of imaging domains.
Tip
Want to use your own dataset? See Bring your own dataset
Want to use your own physics? See Bring your own physics
import deepinv as dinv
import torch
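# Pick a free GPU if one is available, otherwise fall back to the CPU
device = dinv.utils.get_freer_gpu() if torch.cuda.is_available() else "cpu"
# Load RAM with its pretrained weights
model = dinv.models.RAM(device=device, pretrained=True)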
1. Zero-shot inference
First, let’s evaluate the zero-shot inference performance of the foundation model.
Accelerated medical imaging
Here, we reconstruct a brain MRI image from an accelerated, noisy MRI scan from the FastMRI dataset:
x = dinv.utils.load_example("demo_mini_subset_fastmri_brain_0.pt", device=device)
# Define physics
physics = dinv.physics.MRI(noise_model=dinv.physics.GaussianNoise(0.05), device=device)
physics_generator = dinv.physics.generator.GaussianMaskGenerator(
(320, 320), device=device
)
# Generate measurement
y = physics(x, **physics_generator.step())
# Perform inference
with torch.no_grad():
x_hat = model(y, physics)
x_lin = physics.A_adjoint(y)
psnr = dinv.metric.PSNR()
dinv.utils.plot(
{
"Ground truth": x,
f"Linear inverse\n PSNR {psnr(x_lin, x).item():.2f}dB": x_lin,
f"Pretrained RAM\n PSNR {psnr(x_hat, x).item():.2f}dB": x_hat,
}
)
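Conceptually, the MRI physics above is a subsampled Fourier transform: the scanner acquires only the k-space locations selected by the mask, and `A_adjoint` zero-fills the missing samples before an inverse FFT. The sketch below illustrates this in plain PyTorch under several assumptions: a single coil, native complex tensors (deepinv's `MRI` class may store complex images differently, for example as two real channels), and a hypothetical random column mask standing in for `GaussianMaskGenerator`. It is not deepinv's implementation.

```python
import torch


def mri_forward(x, mask):
    """Single-coil Cartesian MRI operator A x = M F x (noise omitted)."""
    # Orthonormal FFT, so that the inverse FFT is also the adjoint
    return mask * torch.fft.fft2(x, norm="ortho")


def mri_adjoint(y, mask):
    """Adjoint A^H y = F^H M y, i.e. the zero-filled reconstruction."""
    return torch.fft.ifft2(mask * y, norm="ortho")


torch.manual_seed(0)
x = torch.randn(1, 1, 32, 32, dtype=torch.complex64)
# Hypothetical mask keeping roughly 25% of k-space columns (about 4x acceleration)
mask = (torch.rand(1, 1, 1, 32) < 0.25).float()

y = mri_forward(x, mask)
x_zf = mri_adjoint(y, mask)  # analogue of physics.A_adjoint(y)

# Adjointness check: <A x, v> == <x, A^H v> for a random k-space array v
v = torch.randn(1, 1, 32, 32, dtype=torch.complex64)
lhs = torch.vdot(mri_forward(x, mask).flatten(), v.flatten())
rhs = torch.vdot(x.flatten(), mri_adjoint(v, mask).flatten())
print(torch.allclose(lhs, rhs, rtol=1e-4, atol=1e-4))
```

Because the mask is binary and the FFT is orthonormal, re-applying the forward operator to the zero-filled image returns the measured k-space unchanged.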

Computational photography
Joint random motion deblurring and denoising, using a cropped image from the color BSD dataset (CBSD):
x = dinv.utils.load_example("CBSD_0010.png", img_size=(200, 200), device=device)
physics = dinv.physics.BlurFFT(
img_size=x.shape[1:],
noise_model=dinv.physics.GaussianNoise(sigma=0.05),
device=device,
)
# fmt: off
physics_generator = (
dinv.physics.generator.MotionBlurGenerator((31, 31), l=2.0, sigma=2.4, device=device) +
dinv.physics.generator.SigmaGenerator(sigma_min=0.001, sigma_max=0.2, device=device)
)
# fmt: on
y = physics(x, **physics_generator.step())
with torch.no_grad():
x_hat = model(y, physics)
x_lin = physics.A_adjoint(y)
dinv.utils.plot(
{
"Ground truth": x,
f"Linear inverse\n PSNR {psnr(x_lin, x).item():.2f}dB": x_lin,
f"Pretrained RAM\n PSNR {psnr(x_hat, x).item():.2f}dB": x_hat,
}
)
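As its name suggests, `BlurFFT` applies the blur in the Fourier domain, which amounts to a circular convolution y = k * x + n. The sketch below shows the idea in plain PyTorch, assuming periodic boundaries; a hypothetical 5x5 box kernel stands in for the random kernel produced by `MotionBlurGenerator`, and the helper names are illustrative rather than deepinv's.

```python
import torch


def kernel_to_otf(kernel, img_size):
    """Zero-pad a (kh, kw) kernel to img_size, centre it at pixel (0, 0), and FFT it."""
    kh, kw = kernel.shape
    padded = torch.zeros(img_size, dtype=kernel.dtype)
    padded[:kh, :kw] = kernel
    # Circularly shift so that the kernel centre sits at the origin
    padded = torch.roll(padded, shifts=(-(kh // 2), -(kw // 2)), dims=(0, 1))
    return torch.fft.fft2(padded)


def blur(x, otf):
    """Circular convolution k * x, computed as IFFT(OTF * FFT(x))."""
    return torch.fft.ifft2(otf * torch.fft.fft2(x)).real


def blur_adjoint(y, otf):
    """Adjoint of circular convolution: multiply by the conjugate OTF."""
    return torch.fft.ifft2(otf.conj() * torch.fft.fft2(y)).real


torch.manual_seed(0)
img_size = (64, 64)
# Hypothetical 5x5 box kernel standing in for a random motion-blur kernel
kernel = torch.ones(5, 5) / 25
otf = kernel_to_otf(kernel, img_size)

x = torch.rand(1, 3, *img_size)
sigma = 0.05
y = blur(x, otf) + sigma * torch.randn_like(x)  # y = k * x + n
```

Since the kernel sums to one, a constant image passes through the blur unchanged, and multiplying by the conjugate OTF gives the exact adjoint used by `A_adjoint`-style linear reconstructions.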

Tomography
Limited-angle computed tomography, using lung CT data from The Cancer Imaging Archive:
x = dinv.utils.load_example("CT100_256x256_0.pt", device=device)
physics = dinv.physics.Tomography(
img_width=256,
angles=10,
normalize=True,
device=device,
)
y = physics(x)
with torch.no_grad():
x_hat = model(y, physics)
x_lin = physics.A_dagger(y)
dinv.utils.plot(
{
"Ground truth": x,
f"FBP pseudo-inverse\n PSNR {psnr(x_lin, x).item():.2f}dB": x_lin,
f"Pretrained RAM\n PSNR {psnr(x_hat, x).item():.2f}dB": x_hat,
}
)
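For tomography, the example uses `A_dagger` rather than `A_adjoint`. With only a few projection angles (`angles=10` here), there are far fewer measurements than unknown pixels. In that setting the adjoint (unfiltered back-projection) is not an inverse, whereas a pseudo-inverse returns the minimum-norm image consistent with the data. The sketch below illustrates this difference on a small random matrix, which is a hypothetical stand-in for the Radon transform; it does not reproduce how deepinv computes its FBP.

```python
import torch

torch.manual_seed(0)
# Hypothetical dense operator: 16 measurements of a 64-pixel signal,
# an underdetermined system like limited-angle CT
n_pixels, n_measurements = 64, 16
A = torch.randn(n_measurements, n_pixels, dtype=torch.float64)
x = torch.randn(n_pixels, dtype=torch.float64)
y = A @ x

x_adj = A.T @ y                    # adjoint (for CT: unfiltered back-projection)
x_pinv = torch.linalg.pinv(A) @ y  # Moore-Penrose pseudo-inverse

# The pseudo-inverse reproduces the measurements; the adjoint does not
print(torch.linalg.norm(A @ x_pinv - y).item())
print(torch.linalg.norm(A @ x_adj - y).item())
```

The pseudo-inverse solution is also never larger in norm than the true signal, because it discards every component the measurements cannot see. This missing information is what a learned prior such as RAM has to fill in.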

Remote sensing
Satellite image denoising under Poisson-Gaussian noise, using urban imagery from the WorldView-3 satellite over Jacksonville:
x = dinv.utils.load_example("JAX_018_011_RGB.tif", device=device)[..., :300, :300]
physics = dinv.physics.Denoising(
noise_model=dinv.physics.PoissonGaussianNoise(sigma=0.1, gain=0.1)
)
y = physics(x)
with torch.no_grad():
x_hat = model(y, physics)
# Alternatively, pass the noise parameters to the model directly instead of a physics object:
# x_hat = model(y, sigma=0.1, gain=0.1)
x_lin = physics.A_adjoint(y)
dinv.utils.plot(
{
"Ground truth": x,
f"Linear inverse\n PSNR {psnr(x_lin, x).item():.2f}dB": x_lin,
f"Pretrained RAM\n PSNR {psnr(x_hat, x).item():.2f}dB": x_hat,
}
)
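`PoissonGaussianNoise` combines signal-dependent shot noise with additive read-out noise. This sketch assumes a common parameterization, y = gain * z + e with z ~ Poisson(x / gain) and e ~ N(0, sigma^2), so that E[y] = x and Var[y] = gain * x + sigma^2. Check the deepinv API reference for the library's exact definition. The code samples this model with `gain=0.1` and `sigma=0.1`, as in the example, and the two moments can be checked empirically.

```python
import torch


def poisson_gaussian(x, gain, sigma, generator=None):
    """Sample y = gain * Poisson(x / gain) + sigma * N(0, 1), elementwise."""
    counts = torch.poisson(x / gain, generator=generator)
    noise = torch.randn(x.shape, generator=generator, dtype=x.dtype)
    return gain * counts + sigma * noise


g = torch.Generator().manual_seed(0)
x = torch.full((200_000,), 0.5)  # constant "image" so the moments are easy to check
y = poisson_gaussian(x, gain=0.1, sigma=0.1, generator=g)
print(y.mean().item(), y.var().item())  # mean near x, variance near gain * x + sigma**2
```

The variance grows with the signal, so bright regions are noisier in absolute terms than dark ones. This is why the model is told the noise parameters rather than a single noise level.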

2. Fine-tuning
As with any model, performance may drop when RAM is applied zero-shot to problems or data outside those seen during training.
For instance, RAM was not trained on image demosaicing:
x = dinv.utils.load_example("butterfly.png", img_size=(127, 129), device=device)
physics = dinv.physics.Demosaicing(
img_size=x.shape[1:], noise_model=dinv.physics.PoissonNoise(0.1), device=device
)
# Generate measurement
y = physics(x)
# Run inference
with torch.no_grad():
x_hat = model(y, physics)
# Show results
dinv.utils.plot(
{
"Original": x,
f"Measurement\n PSNR {psnr(y, x).item():.2f}dB": y,
f"Reconstruction\n PSNR {psnr(x_hat, x).item():.2f}dB": x_hat,
},
)
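Demosaicing recovers a full RGB image from a color filter array, where each pixel records only one color channel. The sketch below builds such a mask, assuming a Bayer layout with RGGB ordering. deepinv's `Demosaicing` may use a different arrangement, so treat this as an illustration of the forward model y = M * x (noise omitted), not of the library's code.

```python
import torch


def bayer_mask(height, width):
    """Return a (3, H, W) binary color-filter mask with an assumed RGGB Bayer layout."""
    mask = torch.zeros(3, height, width)
    mask[0, 0::2, 0::2] = 1  # red: even rows, even columns
    mask[1, 0::2, 1::2] = 1  # green: even rows, odd columns
    mask[1, 1::2, 0::2] = 1  # green: odd rows, even columns
    mask[2, 1::2, 1::2] = 1  # blue: odd rows, odd columns
    return mask


mask = bayer_mask(127, 129)  # same odd size as the butterfly image above
x = torch.rand(1, 3, 127, 129)
y = mask * x  # each pixel keeps exactly one color channel
```

Half of the pixels sample green and a quarter each sample red and blue, so two thirds of the color values are missing. This is a structured inpainting problem that differs from the tasks RAM saw in training.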

To improve results, we can fine-tune the model on our own problem and data using a self-supervised loss, which needs no ground truth and works even with a single measured image.
Since this example runs without a GPU, we train on a small patch of the image to speed things up; in practice, the full image can be used.
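The first self-supervised loss used in this section, `R2RLoss`, follows the Recorrupted-to-Recorrupted idea. The noisy measurement is split into two recorrupted copies whose noise realizations are independent, so one copy can serve as the training target for the other without a clean image. The sketch below shows the Gaussian-noise version in plain PyTorch with a hypothetical `alpha=0.5`. This example actually uses Poisson noise, for which deepinv's `R2RLoss` may recorrupt differently, so this is a conceptual illustration only.

```python
import torch


def r2r_pair(y, sigma, alpha, generator=None):
    """Split a Gaussian-noisy y into an (input, target) pair with independent noise."""
    w = sigma * torch.randn(y.shape, generator=generator, dtype=y.dtype)
    y_in = y + alpha * w      # slightly noisier network input
    y_target = y - w / alpha  # target whose noise is uncorrelated with that of y_in
    return y_in, y_target


def r2r_loss(model, y, sigma, alpha=0.5, generator=None):
    """Self-supervised MSE between model(y_in) and y_target; no clean image needed."""
    y_in, y_target = r2r_pair(y, sigma, alpha, generator)
    return torch.mean((model(y_in) - y_target) ** 2)


g = torch.Generator().manual_seed(0)
sigma = 0.1
x = torch.zeros(100_000)                           # clean signal (unknown in practice)
y = x + sigma * torch.randn(x.shape, generator=g)  # noisy measurement
y_in, y_target = r2r_pair(y, sigma, alpha=0.5, generator=g)

# Cov(n + a*w, n - w/a) = Var(n) - Var(w) = 0 when w has the same variance as n
noise_cov = torch.mean((y_in - x) * (y_target - x))
print(noise_cov.item())  # close to 0
```

Because the target noise is independent of the input, minimizing this loss in expectation matches supervised training against the clean signal, up to a constant.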
Note
You can also fine-tune on a larger dataset by replacing the dataset defined below.
# Take small patch
x_train = x[..., :64, :64]
physics_train = dinv.physics.Demosaicing(
img_size=x_train.shape[1:],
noise_model=dinv.physics.PoissonNoise(0.1, clip_positive=True),
device=device,
)
y_train = physics_train(x_train)
# Define training loss
losses = [
dinv.loss.R2RLoss(),
dinv.loss.EILoss(dinv.transform.Shift(shift_max=0.4), weight=0.1),
]
dataset = dinv.datasets.TensorDataset(y=y_train)
train_dataloader = torch.utils.data.DataLoader(dataset)
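The second loss, `EILoss`, enforces equivariant imaging. If the image distribution is invariant to a group of transforms (here, shifts), then reconstructing, transforming, re-measuring, and reconstructing again should return the transformed image. The sketch below writes this out for a generic `model` and `forward` operator. It uses a circular shift of up to `shift_max` times the image size, which is our assumed reading of `Shift(shift_max=0.4)`. The helper names are hypothetical, and details of deepinv's implementation (for example, gradient handling) may differ.

```python
import torch


def random_shift(x, shift_max=0.4, generator=None):
    """Circularly shift x by up to shift_max * (H, W) pixels in each direction."""
    H, W = x.shape[-2:]
    max_h, max_w = int(shift_max * H), int(shift_max * W)
    dh = torch.randint(-max_h, max_h + 1, (1,), generator=generator).item()
    dw = torch.randint(-max_w, max_w + 1, (1,), generator=generator).item()
    return torch.roll(x, shifts=(dh, dw), dims=(-2, -1))


def ei_loss(model, forward, y, generator=None):
    """Equivariant-imaging loss for one measurement y."""
    x1 = model(y)                               # reconstruct
    x2 = random_shift(x1, generator=generator)  # transform: T x1
    x3 = model(forward(x2))                     # re-measure and reconstruct
    return torch.mean((x3 - x2) ** 2)           # penalize f(A T x1) != T x1


g = torch.Generator().manual_seed(0)
mask = (torch.rand(1, 1, 16, 16, generator=g) > 0.5).float()


def masked_forward(x):
    """Hypothetical inpainting-style operator used only for this illustration."""
    return mask * x


def identity_model(y):
    """Trivial 'reconstructor' that returns its input unchanged."""
    return y


y = masked_forward(torch.rand(1, 3, 16, 16, generator=g))
print(ei_loss(identity_model, masked_forward, y, generator=g).item())
```

A model that merely copies the masked measurement is generally penalized, because shifting moves the holes to places the mask then re-samples. This pushes the network to fill in content that the measurements alone cannot determine.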
We fine-tune with early stopping based on a validation set, again without ground truth: the validation data are measurements of a small patch of a different image.
eval_dataloader = torch.utils.data.DataLoader(
dinv.datasets.TensorDataset(
y=physics_train(
dinv.utils.load_example("leaves.png", device=device)[..., :64, :64]
)
)
)
max_epochs = 20
trainer = dinv.Trainer(
model=model,
physics=physics_train,
eval_interval=5,
ckp_interval=max_epochs - 1,
metrics=losses[0],
early_stop=True,
device=device,
losses=losses,
epochs=max_epochs,
optimizer=torch.optim.Adam(model.parameters(), lr=5e-5),
train_dataloader=train_dataloader,
eval_dataloader=eval_dataloader,
)
finetuned_model = trainer.train()
The model has 35618953 trainable parameters
Train epoch 1/20: R2RLoss=0.127, EILoss=0.00117, TotalLoss=0.128
Eval epoch 1/20: R2RLoss=3.67
Best model saved at epoch 1
Train epoch 2/20: R2RLoss=0.122, EILoss=0.000238, TotalLoss=0.123
Train epoch 3/20: R2RLoss=0.128, EILoss=0.000214, TotalLoss=0.128
Train epoch 4/20: R2RLoss=0.129, EILoss=0.000209, TotalLoss=0.129
Train epoch 5/20: R2RLoss=0.126, EILoss=0.000215, TotalLoss=0.126
Train epoch 6/20: R2RLoss=0.121, EILoss=0.000266, TotalLoss=0.122
Eval epoch 6/20: R2RLoss=3.71
Train epoch 7/20: R2RLoss=0.13, EILoss=0.000209, TotalLoss=0.13
Train epoch 8/20: R2RLoss=0.121, EILoss=0.000281, TotalLoss=0.122
Train epoch 9/20: R2RLoss=0.126, EILoss=0.000263, TotalLoss=0.126
Train epoch 10/20: R2RLoss=0.122, EILoss=0.000253, TotalLoss=0.122
Train epoch 11/20: R2RLoss=0.129, EILoss=0.000225, TotalLoss=0.129
Eval epoch 11/20: R2RLoss=3.69
Train epoch 12/20: R2RLoss=0.128, EILoss=0.000226, TotalLoss=0.128
Train epoch 13/20: R2RLoss=0.121, EILoss=0.000183, TotalLoss=0.121
Train epoch 14/20: R2RLoss=0.125, EILoss=0.000218, TotalLoss=0.125
Train epoch 15/20: R2RLoss=0.126, EILoss=0.000207, TotalLoss=0.126
Train epoch 16/20: R2RLoss=0.126, EILoss=0.000236, TotalLoss=0.127
Eval epoch 16/20: R2RLoss=3.72
Early stopping triggered as validation metrics have not improved in the last 3 validation steps, disable it with early_stop=False
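The early-stopping message says that training halted because the validation metric had not improved over the last 3 validation steps. Here the metric is the R2R loss, for which lower is better. A minimal sketch of that patience rule is shown below, replayed on the validation losses from the log. `should_stop` is a hypothetical helper, and deepinv's `Trainer` may count patience differently internally.

```python
def should_stop(val_losses, patience=3):
    """True if none of the last `patience` validation losses beat the earlier best."""
    if len(val_losses) <= patience:
        return False
    best_before = min(val_losses[:-patience])
    return min(val_losses[-patience:]) >= best_before


# Validation R2R losses reported in the log (epochs 1, 6, 11 and 16)
history = []
stopped_at = None
for epoch, loss in zip([1, 6, 11, 16], [3.67, 3.71, 3.69, 3.72]):
    history.append(loss)
    if should_stop(history):
        stopped_at = epoch
        break
print(f"stop at epoch {stopped_at}")  # prints: stop at epoch 16
```

The best checkpoint (epoch 1 here) is the one kept, which is why the log reports "Best model saved at epoch 1" only once.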
We can now use the fine-tuned model to reconstruct the full image from the original measurement y.
with torch.no_grad():
x_hat_ft = finetuned_model(y, physics)
# Show results
dinv.utils.plot(
{
"Original": x,
f"Measurement\n PSNR {psnr(y, x).item():.2f}dB": y,
f"Zero-shot reconstruction\n PSNR {psnr(x_hat, x).item():.2f}dB": x_hat,
f"Fine-tuned reconstruction\n PSNR {psnr(x_hat_ft, x).item():.2f}dB": x_hat_ft,
},
)
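Every comparison in this example is scored with PSNR, 10 log10(MAX^2 / MSE), where MAX is the peak pixel value. The sketch below computes it per image, assuming pixel values in [0, 1] (MAX = 1). `dinv.metric.PSNR` has its own defaults and options, so consult its documentation rather than treating this as a drop-in replacement.

```python
import torch


def psnr(x_hat, x, max_pixel=1.0):
    """PSNR in dB, one value per image in the batch: 10 * log10(max^2 / MSE)."""
    dims = tuple(range(1, x.ndim))  # average over channels and pixels
    mse = torch.mean((x_hat - x) ** 2, dim=dims)
    return 10 * torch.log10(max_pixel**2 / mse)


x = torch.zeros(2, 3, 8, 8)
x_hat = x + 0.1  # uniform error of 0.1, so the MSE is 0.01
print(psnr(x_hat, x))  # about 20 dB for each image, since 10 * log10(1 / 0.01) = 20
```

Because PSNR is logarithmic in the MSE, each gain of 3 dB roughly halves the mean squared error.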

- References:
Total running time of the script: (0 minutes 54.152 seconds)