Reconstruct Anything Model (RAM) for solving inverse problems.

This example shows how to use the RAM foundation model to solve inverse problems. The RAM model, described in the following paper, is a modified DRUNet architecture that is trained on a large number of inverse problems.

import torch
import deepinv as dinv
from deepinv.models import RAM

device = "cuda" if torch.cuda.is_available() else "cpu"

# Load the pretrained model
model = RAM(device=device)

# load image
x = dinv.utils.load_example("butterfly.png", img_size=(127, 129)).to(device)

# create forward operator
physics = dinv.physics.Inpainting(
    img_size=(3, 127, 129),
    mask=0.3,
    noise_model=dinv.physics.GaussianNoise(0.05),
    device=device,
)

# generate measurement
y = physics(x)

# run inference
with torch.no_grad():
    x_hat = model(y, physics=physics)

# compute PSNR
in_psnr = dinv.metric.PSNR()(x, y).item()
out_psnr = dinv.metric.PSNR()(x, x_hat).item()

# plot
dinv.utils.plot(
    [x, y, x_hat],
    [
        "Original",
        "Measurement\n PSNR = {:.2f}dB".format(in_psnr),
        "Reconstruction\n PSNR = {:.2f}dB".format(out_psnr),
    ],
    figsize=(8, 3),
)
Figure panels: Original | Measurement, PSNR = 6.03 dB | Reconstruction, PSNR = 25.80 dB
Downloading: "https://huggingface.co/mterris/ram/resolve/main/ram.pth.tar" to /home/runner/.cache/torch/hub/checkpoints/ram.pth.tar
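The inpainting measurement and the PSNR values in the figure can be reproduced in spirit with a few lines of NumPy. This sketch assumes that `mask=0.3` keeps each pixel location with probability 0.3 (one mask shared by all color channels) and that Gaussian noise is added to every entry after masking; deepinv's `Inpainting` may differ in these details, and `inpainting_forward` and `psnr` are hypothetical helpers. PSNR uses the usual definition 10·log10(MAX²/MSE) with MAX = 1.

```python
import numpy as np


def inpainting_forward(x, keep_prob, sigma, rng):
    """Hypothetical random-mask inpainting: y = mask * x + Gaussian noise.

    x has shape (channels, height, width); one mask is shared by all channels.
    """
    _, h, w = x.shape
    mask = (rng.random((1, h, w)) < keep_prob).astype(x.dtype)
    y = mask * x + sigma * rng.standard_normal(x.shape)
    return y, mask


def psnr(x, x_hat, max_pixel=1.0):
    """Peak signal-to-noise ratio in dB: 10 * log10(max_pixel**2 / MSE)."""
    mse = np.mean((x - x_hat) ** 2)
    return 10.0 * np.log10(max_pixel**2 / mse)


rng = np.random.default_rng(0)
x = rng.random((3, 64, 64))  # stand-in image with values in [0, 1]
y, mask = inpainting_forward(x, keep_prob=0.3, sigma=0.05, rng=rng)
print(f"kept fraction: {mask.mean():.2f}")  # close to 0.30
print(f"measurement PSNR: {psnr(x, y):.1f} dB")
```

With roughly 70% of the pixels set to zero, the MSE is dominated by the missing pixels, which is why the measurement PSNR sits around 6 dB for this random stand-in image as well.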

This model was also trained on various denoising problems, in particular on Poisson-Gaussian denoising.

sigma, gain = 0.2, 0.5
physics = dinv.physics.Denoising(
    noise_model=dinv.physics.PoissonGaussianNoise(sigma=sigma, gain=gain),
)

# generate measurement
y = physics(x)

# run inference
with torch.no_grad():
    x_hat = model(y, physics=physics)
    # or alternatively, we can use the model without physics:
    # x_hat = model(y, sigma=sigma, gain=gain)

# compute PSNR
in_psnr = dinv.metric.PSNR()(x, y).item()
out_psnr = dinv.metric.PSNR()(x, x_hat).item()

# plot
dinv.utils.plot(
    [x, y, x_hat],
    [
        "Original",
        "Measurement\n PSNR = {:.2f}dB".format(in_psnr),
        "Reconstruction\n PSNR = {:.2f}dB".format(out_psnr),
    ],
    figsize=(8, 3),
)
Figure panels: Original | Measurement, PSNR = 5.20 dB | Reconstruction, PSNR = 24.42 dB
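The Poisson-Gaussian model can be written y = gain·z + ε with z ~ Poisson(x / gain) and ε ~ N(0, sigma²), so that E[y] = x and Var[y] = gain·x + sigma². The NumPy sketch below samples from that model as we understand `dinv.physics.PoissonGaussianNoise` to define it (the function name is hypothetical; check the deepinv documentation for the library's exact conventions) and checks both moments empirically.

```python
import numpy as np


def poisson_gaussian_noise(x, gain, sigma, rng):
    """Sample y = gain * z + eps, z ~ Poisson(x / gain), eps ~ N(0, sigma**2)."""
    z = rng.poisson(x / gain)
    return gain * z + sigma * rng.standard_normal(np.shape(x))


rng = np.random.default_rng(0)
sigma, gain = 0.2, 0.5     # same levels as in the example above
x = np.full(200_000, 0.6)  # many samples of a single pixel value
y = poisson_gaussian_noise(x, gain, sigma, rng)

print(f"mean: {y.mean():.3f}  (expected 0.6)")
print(f"variance: {y.var():.3f}  (expected {gain * 0.6 + sigma**2:.3f})")
```

At these settings the noise standard deviation at a pixel value of 0.6 is about 0.58, comparable to the signal itself, which is consistent with the very low measurement PSNR reported above.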

RAM was not trained on every degradation, so it may not perform well on every inverse problem out of the box. For instance, it was not trained on image demosaicing, and applying it directly to a demosaicing problem yields noticeably weaker results than in the previous examples, as shown in the following example:

# Define the Demosaicing physics
physics = dinv.physics.Demosaicing(
    img_size=x.shape[1:], noise_model=dinv.physics.PoissonNoise(0.1), device=device
)

# generate measurement
y = physics(x)

# run inference
with torch.no_grad():
    x_hat = model(y, physics=physics)

# compute PSNR
in_psnr = dinv.metric.PSNR()(x, y).item()
out_psnr = dinv.metric.PSNR()(x, x_hat).item()

# plot
dinv.utils.plot(
    [x, y, x_hat],
    [
        "Original",
        "Measurement\n PSNR = {:.2f}dB".format(in_psnr),
        "0 shot reconstruction\n PSNR = {:.2f}dB".format(out_psnr),
    ],
    figsize=(8, 3),
)
Figure panels: Original | Measurement, PSNR = 5.99 dB | 0 shot reconstruction, PSNR = 21.78 dB
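Demosaicing keeps a single color channel per pixel according to a fixed color filter array. The sketch below builds a Bayer mask with an RGGB layout, which is an assumption (deepinv's `Demosaicing` may order the pattern differently), and applies scaled Poisson noise in the form y = gain·Poisson(mask·x / gain), our reading of `PoissonNoise(0.1)`. Both helper names are hypothetical.

```python
import numpy as np


def bayer_mask(h, w):
    """Hypothetical RGGB Bayer mask of shape (3, h, w): one channel per pixel."""
    mask = np.zeros((3, h, w))
    mask[0, 0::2, 0::2] = 1  # red: even rows, even columns
    mask[1, 0::2, 1::2] = 1  # green: even rows, odd columns
    mask[1, 1::2, 0::2] = 1  # green: odd rows, even columns
    mask[2, 1::2, 1::2] = 1  # blue: odd rows, odd columns
    return mask


def demosaicing_forward(x, gain, rng):
    """Mask the image with the Bayer pattern, then apply scaled Poisson noise."""
    mask = bayer_mask(*x.shape[-2:])
    return gain * rng.poisson(mask * x / gain), mask


rng = np.random.default_rng(0)
x = rng.random((3, 4, 6))
y, mask = demosaicing_forward(x, gain=0.1, rng=rng)
print(mask.sum(axis=0))  # all ones: exactly one channel per pixel
print(mask[0].mean(), mask[1].mean(), mask[2].mean())  # 0.25 0.5 0.25
```

Unlike the random inpainting mask, this pattern is regular and removes two of the three color values at every pixel, a degradation RAM did not see during training.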

To improve results, we can fine-tune the model on the specific degradation for the sample of interest. This works even without ground-truth data, using unsupervised training: the model is fine-tuned on measurements alone, here a demosaicing measurement of a patch of the same image. Since this example runs in a CPU-only environment, we train on a small patch to speed up training; in practice, the full image can be used.
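Unsupervised fine-tuning below relies on `dinv.loss.R2RLoss` (Recorrupted-to-Recorrupted). Its core idea is easiest to see in the Gaussian-noise case: from a single noisy y, draw fresh noise z and form an input y + αz and a target y - z/α whose noises are uncorrelated. The NumPy sketch shows only that Gaussian case, with hypothetical names; it does not reproduce whatever recorruption deepinv applies for the Poisson noise used in this example.

```python
import numpy as np


def r2r_pair(y, sigma, alpha, rng):
    """Recorrupted-to-Recorrupted pair for additive Gaussian noise of std sigma."""
    z = sigma * rng.standard_normal(y.shape)  # fresh noise at the same level as in y
    y_in = y + alpha * z      # network input
    y_target = y - z / alpha  # training target
    return y_in, y_target


rng = np.random.default_rng(0)
sigma, alpha = 0.1, 0.5
x = rng.random(500_000)                       # stand-in clean signal
y = x + sigma * rng.standard_normal(x.shape)  # single noisy measurement
y_in, y_target = r2r_pair(y, sigma, alpha, rng)

# E[(n + alpha*z)(n - z/alpha)] = sigma**2 - sigma**2 = 0: the residual noises
# of input and target are uncorrelated.
cross = np.mean((y_in - x) * (y_target - x))
print(abs(cross) < 1e-3)
```

Because the two noises are independent, the expected loss mean((f(y_in) - y_target)²) equals the supervised loss on y_in plus a constant, so the network is not rewarded for reproducing the noise.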

First, we create a dataset for unsupervised training that returns only measurements, with torch.nan in place of the missing ground truth:

class UnsupDataset(torch.utils.data.Dataset):
    r"""
    Dataset for unsupervised learning tasks.

    This dataset is used to return only the data without any labels.

    :param torch.Tensor data: Input data tensor of shape (N, ...), where N is the number of samples and ... represents the data dimensions.
    """

    def __init__(self, data):
        self.data = data

    def __len__(self):
        return self.data.size(0)

    def __getitem__(self, idx):
        return torch.nan, self.data[idx]


physics_train = dinv.physics.Demosaicing(
    img_size=(3, 64, 64),
    noise_model=dinv.physics.PoissonNoise(0.1, clip_positive=True),
    device=device,
)
x_train = x[..., :64, :64]  # take a small patch of the image
y_train = physics_train(x_train)

losses = [
    dinv.loss.R2RLoss(),
    dinv.loss.EILoss(dinv.transform.Shift(shift_max=0.4), weight=0.1),
]

dataset = UnsupDataset(y_train)

train_dataloader = torch.utils.data.DataLoader(dataset)
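The `EILoss` term enforces equivariant imaging: reconstructions should commute with the chosen transform (here `dinv.transform.Shift`). Given x1 = f(y), it transforms x2 = T(x1), re-measures and reconstructs x3 = f(A(x2)), and penalizes the gap between x3 and x2. A minimal noise-free NumPy sketch, assuming circular shifts via `np.roll` and a reconstruction `f` and operator `A` passed as plain callables (all names hypothetical); deepinv's `EILoss` may additionally re-apply measurement noise and draws random shifts.

```python
import numpy as np


def ei_loss(f, A, y, shift):
    """Noise-free equivariant-imaging loss for one fixed circular shift.

    f: reconstruction function (measurement -> image), hypothetical.
    A: forward operator (image -> measurement), hypothetical.
    shift: (rows, cols) offsets applied with np.roll, standing in for T.
    """
    x1 = f(y)                               # reconstruct from the data
    x2 = np.roll(x1, shift, axis=(-2, -1))  # transformed reconstruction
    x3 = f(A(x2))                           # re-measure, then reconstruct
    return float(np.mean((x3 - x2) ** 2))   # zero when f is shift-equivariant


def identity(v):
    return v


ramp = np.linspace(0.5, 1.5, 8)


def not_equivariant(v):
    return v * ramp  # position-dependent gain breaks shift equivariance


rng = np.random.default_rng(0)
y = rng.random((3, 8, 8))
print(ei_loss(identity, identity, y, (2, 3)))             # 0.0
print(ei_loss(not_equivariant, identity, y, (2, 3)) > 0)  # True
```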

To monitor the performance of the fine-tuned model, we use a validation set consisting of a small patch of another image. This validation is also unsupervised, so the ground-truth validation image is never used.

The model has 35618953 trainable parameters

Train epoch 1/20: 100% 1/1 [00:01<00:00, 1.60s/it, R2RLoss=0.134, EILoss=0.00132, TotalLoss=0.135]
Eval epoch 1/20: 100% 1/1 [00:01<00:00, 1.40s/it, R2RLoss=3.68]
Best model saved at epoch 1
Train epoch 2/20: 100% 1/1 [00:01<00:00, 1.57s/it, R2RLoss=0.125, EILoss=0.000299, TotalLoss=0.125]
Train epoch 3/20: 100% 1/1 [00:01<00:00, 1.57s/it, R2RLoss=0.131, EILoss=0.000259, TotalLoss=0.132]
Train epoch 4/20: 100% 1/1 [00:01<00:00, 1.56s/it, R2RLoss=0.129, EILoss=0.000192, TotalLoss=0.129]
Train epoch 5/20: 100% 1/1 [00:01<00:00, 1.57s/it, R2RLoss=0.128, EILoss=0.000207, TotalLoss=0.128]
Train epoch 6/20: 100% 1/1 [00:01<00:00, 1.56s/it, R2RLoss=0.135, EILoss=0.000213, TotalLoss=0.136]
Eval epoch 6/20: 100% 1/1 [00:01<00:00, 1.40s/it, R2RLoss=3.71]
Train epoch 7/20: 100% 1/1 [00:01<00:00, 1.55s/it, R2RLoss=0.128, EILoss=0.000234, TotalLoss=0.128]
Train epoch 8/20: 100% 1/1 [00:01<00:00, 1.55s/it, R2RLoss=0.129, EILoss=0.000261, TotalLoss=0.129]
Train epoch 9/20: 100% 1/1 [00:01<00:00, 1.55s/it, R2RLoss=0.131, EILoss=0.000208, TotalLoss=0.131]
Train epoch 10/20: 100% 1/1 [00:01<00:00, 1.67s/it, R2RLoss=0.131, EILoss=0.000222, TotalLoss=0.131]
Train epoch 11/20: 100% 1/1 [00:01<00:00, 1.56s/it, R2RLoss=0.124, EILoss=0.000252, TotalLoss=0.125]
Eval epoch 11/20: 100% 1/1 [00:01<00:00, 1.40s/it, R2RLoss=3.76]
Train epoch 12/20: 100% 1/1 [00:01<00:00, 1.55s/it, R2RLoss=0.132, EILoss=0.000215, TotalLoss=0.132]
Train epoch 13/20: 100% 1/1 [00:01<00:00, 1.55s/it, R2RLoss=0.132, EILoss=0.00022, TotalLoss=0.132]
Train epoch 14/20: 100% 1/1 [00:01<00:00, 1.55s/it, R2RLoss=0.138, EILoss=0.000265, TotalLoss=0.138]
Train epoch 15/20: 100% 1/1 [00:01<00:00, 1.55s/it, R2RLoss=0.127, EILoss=0.000266, TotalLoss=0.127]
Train epoch 16/20: 100% 1/1 [00:01<00:00, 1.56s/it, R2RLoss=0.128, EILoss=0.000205, TotalLoss=0.128]
Eval epoch 16/20: 100% 1/1 [00:01<00:00, 1.42s/it, R2RLoss=3.76]
Early stopping triggered as validation metrics have not improved in the last 3 validation steps, disable it with early_stop=False
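The validation log evaluates every five epochs (epochs 1, 6, 11 and 16) and stops once three consecutive evaluations fail to beat the best value seen so far. The patience rule stated in the log message can be sketched in plain Python; the function name, the strict-improvement criterion, and the assumption that the tracked quantity is the lower-is-better validation R2R loss are ours, not deepinv's internals.

```python
def early_stop_epoch(eval_epochs, eval_losses, patience=3):
    """Hypothetical patience rule for a lower-is-better validation loss.

    Returns (best_epoch, stop_epoch); stop_epoch is None if training would
    run to completion.
    """
    best_loss = float("inf")
    best_epoch = None
    bad_steps = 0  # consecutive evaluations without improvement
    for epoch, loss in zip(eval_epochs, eval_losses):
        if loss < best_loss:
            best_loss, best_epoch, bad_steps = loss, epoch, 0
        else:
            bad_steps += 1
            if bad_steps >= patience:
                return best_epoch, epoch
    return best_epoch, None


# Validation R2R losses from the log (evaluations at epochs 1, 6, 11, 16).
print(early_stop_epoch([1, 6, 11, 16], [3.68, 3.71, 3.76, 3.76]))  # (1, 16)
```

This matches the log: the only "Best model saved" message is at epoch 1, and training halts at epoch 16 rather than running all 20 epochs.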

We can now use the fine-tuned model to reconstruct the image from the measurement vector y.

with torch.no_grad():
    x_hat = finetuned_model(y, physics=physics)

# compute PSNR
in_psnr = dinv.metric.PSNR()(x, y).item()
out_psnr = dinv.metric.PSNR()(x, x_hat).item()

# plot
dinv.utils.plot(
    [x, y, x_hat],
    [
        "Original",
        "Measurement\n PSNR = {:.2f}dB".format(in_psnr),
        "Finetuned reconstruction\n PSNR = {:.2f}dB".format(out_psnr),
    ],
    figsize=(8, 3),
)
Figure panels: Original | Measurement, PSNR = 5.99 dB | Finetuned reconstruction, PSNR = 24.60 dB

Total running time of the script: (0 minutes 39.778 seconds)
