Solving blind inverse problems / estimating physics parameters#

This demo shows you how to use deepinv.physics.Physics together with automatic differentiation to optimize your operator.

Consider the forward model

\[y = N(A(x, \theta))\]

where \(N\) is the noise model, \(A(\cdot, \theta)\) is the forward operator, and the goal is to learn the parameter \(\theta\) (e.g., the filter in deepinv.physics.Blur).

In a typical blind inverse problem, given a measurement \(y\), we would like to recover both the underlying image \(x\) and the operator parameter \(\theta\), resulting in a highly ill-posed inverse problem.

In this example, we focus on a much simpler problem: given the measurement \(y\) and the ground truth \(x\), find the parameter \(\theta\). This can be formulated as the following optimization problem:

\[\min_{\theta} \frac{1}{2} \|A(x, \theta) - y \|^2\]

This problem can be addressed by first-order optimization if we can compute the gradient of the above function with respect to \(\theta\). The dependence of the operator \(A\) on the parameter \(\theta\) can be complicated. DeepInverse provides a wide range of physics operators, implemented as differentiable classes. We can therefore leverage the automatic differentiation engine provided in PyTorch to compute the gradient of the above loss function w.r.t. the physics parameters \(\theta\).
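
As a minimal sketch of this idea, the snippet below uses a hypothetical toy operator \(A(x, \theta) = \theta x\) (a scalar gain chosen for illustration, not a DeepInverse class) and checks that the gradient returned by PyTorch's autograd agrees with the analytic gradient \(\langle \theta x - y, x \rangle\). The names `forward_op`, `theta_true` and `analytic_grad` are assumptions introduced here.

```python
import torch

torch.manual_seed(0)


# Toy forward operator (an assumption for illustration): A(x, theta) = theta * x
def forward_op(x, theta):
    return theta * x


x = torch.randn(16)
theta_true = torch.tensor(2.0)
y = forward_op(x, theta_true)

# Parameter to estimate; requires_grad=True tells autograd to track it
theta = torch.tensor(0.5, requires_grad=True)

# Loss 0.5 * ||A(x, theta) - y||^2 and its gradient via backpropagation
loss = 0.5 * (forward_op(x, theta) - y).pow(2).sum()
loss.backward()

# Analytic gradient of the same loss with respect to theta
analytic_grad = torch.dot(theta.detach() * x - y, x)

print(theta.grad.item(), analytic_grad.item())
```

For operators whose dependence on \(\theta\) is more involved, the analytic expression is harder to derive by hand, which is where automatic differentiation becomes convenient.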

The purpose of this demo is to show how to use the physics classes in DeepInverse, together with automatic differentiation, to estimate the physics parameters. We show 3 different ways to do this: manually implementing the projected gradient descent algorithm, using a PyTorch optimizer, and optimizing the physics as a usual neural network.

Import required packages

import deepinv as dinv
import torch
from tqdm import tqdm
import matplotlib.pyplot as plt

device = dinv.utils.get_freer_gpu() if torch.cuda.is_available() else "cpu"
dtype = torch.float32

Define the physics#

In this first example, we use the convolution operator, defined in the deepinv.physics.Blur class. We also generate a random motion blur kernel.

x = dinv.utils.load_url_image(
    dinv.utils.demo.get_image_url("celeba_example.jpg"),
    img_size=256,
    resize_mode="resize",
).to(device)

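# NOTE: the definitions of `true_kernel` and `physics` are missing from this listing.
# The three lines below are an assumed reconstruction: a random motion blur kernel
# (the kernel size 31 is a guess) and a Blur operator whose filter is set at call time.
generator = dinv.physics.generator.MotionBlurGenerator((31, 31), device=device, dtype=dtype)
true_kernel = generator.step(batch_size=1)["filter"]
physics = dinv.physics.Blur(device=device)

y = physics(x, filter=true_kernel)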

dinv.utils.plot([x, y, true_kernel], titles=["Sharp", "Blurry", "True kernel"])
[Figure: Sharp, Blurry, True kernel]

Define an optimization algorithm#

The convolution kernel lives in the simplex, i.e., the kernel must have non-negative entries summing to 1. We can use a simple optimization algorithm, projected gradient descent, to enforce this constraint. The following function computes the orthogonal projection onto the simplex with a sorting-based algorithm (reference: Mathieu Blondel, Akinori Fujino, and Naonori Ueda, "Large-scale Multiclass Support Vector Machine Training via Euclidean Projection onto the Simplex").

@torch.no_grad()
def projection_simplex_sort(v: torch.Tensor) -> torch.Tensor:
    r"""
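    Projects each batch element of a tensor onto the simplex using a sorting algorithm.

    The first dimension is treated as the batch dimension; the remaining
    dimensions are flattened, projected, and reshaped back.
    """
    shape = v.shape
    B = shape[0]
    v = v.reshape(B, -1)
    n_features = v.size(1)
    # sort entries in decreasing order and form the shifted cumulative sums
    u = torch.sort(v, descending=True, dim=-1).values
    cssv = torch.cumsum(u, dim=-1) - 1.0
    ind = torch.arange(n_features, device=v.device)[None, :].expand(B, -1) + 1.0
    cond = u - cssv / ind > 0
    # rho: largest (1-based) index satisfying the condition, per batch element
    rho = (cond * ind).amax(dim=-1, keepdim=True)
    theta = cssv.gather(-1, rho.long() - 1) / rho
    w = torch.maximum(v - theta, torch.zeros_like(v))
    return w.reshape(shape)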


# We also define a data fidelity term
data_fidelity = dinv.optim.L2()
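
To make the sorting-based projection more concrete, here is a small worked example in plain Python. It is a sketch of the same algorithm for a single vector; `project_simplex` is a hypothetical helper written for this illustration and is not part of DeepInverse.

```python
def project_simplex(v):
    """Euclidean projection of a list of floats onto the probability simplex."""
    u = sorted(v, reverse=True)  # entries in decreasing order
    cumsum = 0.0
    theta = 0.0
    for j, u_j in enumerate(u, start=1):
        cumsum += u_j
        t = (cumsum - 1.0) / j
        if u_j - t > 0:
            # keep the threshold of the largest index j satisfying the condition
            theta = t
    # shift by the threshold and clip negative entries to zero
    return [max(v_i - theta, 0.0) for v_i in v]


w = project_simplex([0.5, 1.2, -0.3])
print(w, sum(w))
```

For the input `[0.5, 1.2, -0.3]`, the sorted values are `1.2, 0.5, -0.3`, the condition holds for the first two indices, and the threshold is `(1.2 + 0.5 - 1) / 2 = 0.35`. The projection is therefore approximately `[0.15, 0.85, 0.0]`, which is non-negative and sums to 1.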

Run the algorithm

Initialize a constant kernel

kernel_init = torch.zeros_like(true_kernel)
kernel_init[..., 5:-5, 5:-5] = 1.0
kernel_init = projection_simplex_sort(kernel_init)
n_iter = 1000
stepsize = 0.7

# clone so that enabling gradients below does not modify `kernel_init`
kernel_hat = kernel_init.clone()
losses = []
for i in tqdm(range(n_iter)):
    # compute the gradient
    with torch.enable_grad():
        kernel_hat.requires_grad_(True)
        physics.update(filter=kernel_hat)
        loss = data_fidelity(y=y, x=x, physics=physics) / y.numel()
        loss.backward()
    grad = kernel_hat.grad

    # gradient step and projection step
    with torch.no_grad():
        kernel_hat = kernel_hat - stepsize * grad
        kernel_hat = projection_simplex_sort(kernel_hat)

    losses.append(loss.item())

dinv.utils.plot(
    [true_kernel, kernel_init, kernel_hat],
    titles=["True kernel", "Init. kernel", "Estimated kernel"],
    suptitle="Result with Projected Gradient Descent",
)
[Figure: Result with Projected Gradient Descent. Panels: True kernel, Init. kernel, Estimated kernel]
100%|██████████| 1000/1000 [00:39<00:00, 25.09it/s]

We can plot the loss to check that it decreases.

plt.figure()
plt.plot(range(n_iter), losses)
plt.title("Loss evolution")
plt.yscale("log")
plt.xlabel("Iteration")
plt.tight_layout()
plt.show()
[Figure: Loss evolution]
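
The same gradient-then-project loop can be reproduced on a small problem that runs in a fraction of a second. The sketch below is an illustration under simplifying assumptions, not the DeepInverse implementation: it replaces the 2D blur with a 1D convolution from `torch.nn.functional.conv1d`, and the names `blur`, `project_simplex`, `k_true` and `k_hat` are introduced here.

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)


def project_simplex(v):
    """Sorting-based projection of a 1D tensor onto the simplex."""
    u, _ = torch.sort(v, descending=True)
    cssv = torch.cumsum(u, dim=0) - 1.0
    ind = torch.arange(1, v.numel() + 1, dtype=v.dtype)
    cond = u - cssv / ind > 0
    rho = ind[cond][-1]
    theta = cssv[cond][-1] / rho
    return torch.clamp(v - theta, min=0.0)


def blur(x, k):
    """Toy 1D forward operator: correlation of the signal x with the kernel k."""
    return F.conv1d(x, k.view(1, 1, -1))


x = torch.randn(1, 1, 256)  # known ground-truth signal
k_true = torch.tensor([0.1, 0.2, 0.4, 0.2, 0.1])  # kernel to recover
y = blur(x, k_true)  # noiseless measurement

k_hat = torch.full((5,), 0.2)  # constant initialization on the simplex
stepsize = 0.5
for _ in range(200):
    # gradient of the data-fidelity term with respect to the kernel
    k_hat.requires_grad_(True)
    loss = 0.5 * (blur(x, k_hat) - y).pow(2).mean()
    (grad,) = torch.autograd.grad(loss, k_hat)

    # gradient step followed by projection onto the simplex
    with torch.no_grad():
        k_hat = project_simplex(k_hat - stepsize * grad)

print(k_hat)
```

Because the true kernel lies in the simplex and the toy problem is well conditioned, the estimate should end up very close to `k_true`.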

Combine with arbitrary optimizer#

PyTorch provides a wide range of optimizers for training neural networks. We can also pick one of those to optimize our parameter.

kernel_init = torch.zeros_like(true_kernel)
kernel_init[..., 5:-5, 5:-5] = 1.0
kernel_init = projection_simplex_sort(kernel_init)

kernel_hat = kernel_init.clone()
optimizer = torch.optim.Adam([kernel_hat], lr=0.1)

# We will alternate a gradient step and a projection step
losses = []
n_iter = 200
for i in tqdm(range(n_iter)):
    optimizer.zero_grad()
    # compute the loss and backpropagate; this populates `kernel_hat.grad`
    with torch.enable_grad():
        kernel_hat.requires_grad_(True)
        physics.update(filter=kernel_hat)
        loss = data_fidelity(y=y, x=x, physics=physics) / y.numel()
        loss.backward()

    # a gradient step
    optimizer.step()
    # projection step, when doing additional steps, it's important to change only
    # the tensor data to avoid breaking the gradient computation
    kernel_hat.data = projection_simplex_sort(kernel_hat.data)
    # loss
    losses.append(loss.item())

dinv.utils.plot(
    [true_kernel, kernel_init, kernel_hat],
    titles=["True kernel", "Init. kernel", "Estimated kernel"],
    suptitle="Result with ADAM",
)
[Figure: Result with ADAM. Panels: True kernel, Init. kernel, Estimated kernel]
100%|██████████| 200/200 [00:08<00:00, 24.94it/s]

We can plot the loss to check that it decreases.

plt.figure()
plt.semilogy(range(n_iter), losses)
plt.title("Loss evolution")
plt.xlabel("Iteration")
plt.tight_layout()
plt.show()
[Figure: Loss evolution]
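
The code above writes the projected values through the `.data` attribute. One common alternative, sketched below on a toy problem, is to perform the projection inside `torch.no_grad()` and write it back with an in-place `copy_`, so that the optimizer keeps tracking the same tensor. The toy loss, the `target` vector and the helper `project_simplex` are assumptions made for this illustration.

```python
import torch


def project_simplex(v):
    """Sorting-based projection of a 1D tensor onto the simplex."""
    u, _ = torch.sort(v, descending=True)
    cssv = torch.cumsum(u, dim=0) - 1.0
    ind = torch.arange(1, v.numel() + 1, dtype=v.dtype)
    cond = u - cssv / ind > 0
    rho = ind[cond][-1]
    theta = cssv[cond][-1] / rho
    return torch.clamp(v - theta, min=0.0)


# Toy problem: fit a point of the simplex to a target that also lies in it
target = torch.tensor([0.7, 0.2, 0.1, 0.0])
param = torch.full((4,), 0.25, requires_grad=True)
optimizer = torch.optim.Adam([param], lr=0.05)

losses = []
for _ in range(300):
    optimizer.zero_grad()
    loss = 0.5 * (param - target).pow(2).sum()
    loss.backward()

    # unconstrained optimizer step
    optimizer.step()

    # projection step: in-place update of the tracked tensor, outside autograd
    with torch.no_grad():
        param.copy_(project_simplex(param))

    losses.append(loss.item())

print(param.detach(), losses[0], losses[-1])
```

Note that the projection is applied after the optimizer step, so the optimizer's internal state (here Adam's moment estimates) is unaware of the constraint; this is a heuristic combination rather than a projected method with convergence guarantees.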

Optimizing the physics as a usual neural network#

Below we show another way to optimize the physics parameter, following the usual workflow for training a neural network.

kernel_init = torch.zeros_like(true_kernel)
kernel_init[..., 5:-5, 5:-5] = 1.0
kernel_init = projection_simplex_sort(kernel_init)

# Gradients are disabled by default, so we enable them for the filter parameter
physics = dinv.physics.Blur(
    filter=kernel_init.clone().requires_grad_(True), device=device
)

# Set up the optimizer by passing it the parameter to optimize
# Feel free to swap in your favorite optimizer
optimizer = torch.optim.AdamW([physics.filter], lr=0.1)


# Feel free to try another loss function
# loss_fn = torch.nn.MSELoss()
loss_fn = torch.nn.L1Loss()

# We will alternate a gradient step and a projection step
losses = []
n_iter = 100
for i in tqdm(range(n_iter)):
    # update the gradient
    optimizer.zero_grad()
    y_hat = physics.A(x)
    loss = loss_fn(y_hat, y)
    loss.backward()

    # a gradient step
    optimizer.step()

    # projection step.
    # Note: when doing additional steps, it's important to change only
    # the tensor data to avoid breaking the gradient computation
    physics.filter.data = projection_simplex_sort(physics.filter.data)

    # loss
    losses.append(loss.item())

kernel_hat = physics.filter.data
dinv.utils.plot(
    [true_kernel, kernel_init, kernel_hat],
    titles=["True kernel", "Init. kernel", "Estimated kernel"],
    suptitle="Result with AdamW and L1 Loss",
)
[Figure: Result with AdamW and L1 Loss. Panels: True kernel, Init. kernel, Estimated kernel]
100%|██████████| 100/100 [00:04<00:00, 24.85it/s]

We can plot the loss to check that it decreases.

plt.figure()
plt.semilogy(range(n_iter), losses)
plt.title("Loss evolution")
plt.xlabel("Iteration")
plt.tight_layout()
plt.show()
[Figure: Loss evolution]
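
The "physics as a neural network" pattern can also be sketched without DeepInverse. In the example below, `ToyBlur1d` is a hypothetical `torch.nn.Module` written for this illustration (it is not a DeepInverse class): its filter is registered as a parameter, so `parameters()` can be handed to any PyTorch optimizer, exactly as for a network.

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)


class ToyBlur1d(torch.nn.Module):
    """Hypothetical 1D blur operator with a learnable filter."""

    def __init__(self, kernel):
        super().__init__()
        self.filter = torch.nn.Parameter(kernel.clone())

    def A(self, x):
        return F.conv1d(x, self.filter.view(1, 1, -1))


def project_simplex(v):
    """Sorting-based projection of a 1D tensor onto the simplex."""
    u, _ = torch.sort(v, descending=True)
    cssv = torch.cumsum(u, dim=0) - 1.0
    ind = torch.arange(1, v.numel() + 1, dtype=v.dtype)
    cond = u - cssv / ind > 0
    rho = ind[cond][-1]
    theta = cssv[cond][-1] / rho
    return torch.clamp(v - theta, min=0.0)


x = torch.randn(1, 1, 256)
k_true = torch.tensor([0.1, 0.2, 0.4, 0.2, 0.1])
y = F.conv1d(x, k_true.view(1, 1, -1))

toy_physics = ToyBlur1d(torch.full((5,), 0.2))
optimizer = torch.optim.AdamW(toy_physics.parameters(), lr=0.01)
loss_fn = torch.nn.L1Loss()

losses = []
for _ in range(200):
    optimizer.zero_grad()
    loss = loss_fn(toy_physics.A(x), y)
    loss.backward()
    optimizer.step()

    # keep the filter on the simplex after each optimizer step
    with torch.no_grad():
        toy_physics.filter.copy_(project_simplex(toy_physics.filter))

    losses.append(loss.item())

print(toy_physics.filter.detach(), losses[0], losses[-1])
```

Wrapping the operator in a module makes it straightforward to swap the optimizer or the loss function, at the cost of handling the simplex constraint separately after each step.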

Total running time of the script: (0 minutes 52.797 seconds)
