Note
Go to the end to download the full example code.
Solving blind inverse problems / estimating physics parameters#
This demo shows you how to use
deepinv.physics.Physics
together with automatic differentiation to optimize your operator.
Consider the forward model
where \(N\) is the noise model, \(\forw{\cdot, \theta}\) is the forward operator, and the goal is to learn the parameter \(\theta\) (e.g., the filter in deepinv.physics.Blur
).
In a typical blind inverse problem, given a measurement \(y\), we would like to recover both the underlying image \(x\) and the operator parameter \(\theta\), resulting in a highly ill-posed inverse problem.
In this example, we only focus on a much more simpler problem: given the measurement \(y\) and the ground truth \(x\), find the parameter \(\theta\). This can be reformulated as the following optimization problem:
This problem can be addressed by first-order optimization if we can compute the gradient of the above function with respect to \(\theta\). The dependence between the operator \(A\) and the parameter \(\theta\) can be complicated. DeepInverse provides a wide range of physics operators, implemented as differentiable classes. We can leverage the automatic differentiation engine provided in Pytorch to compute the gradient of the above loss function w.r.t. the physics parameters \(\theta\).
The purpose of this demo is to show how to use the physics classes in DeepInverse to estimate the physics parameters, together with the automatic differentiation. We show 3 different ways to do this: manually implementing the projected gradient descent algorithm, using a Pytorch optimizer and optimizing the physics as a usual neural network.
Import required packages
import deepinv as dinv
import torch
from tqdm import tqdm
import matplotlib.pyplot as plt
device = dinv.utils.get_freer_gpu() if torch.cuda.is_available() else "cpu"
dtype = torch.float32
Define the physics#
In this first example, we use the convolution operator, defined in the deepinv.physics.Blur
class.
We also generate a random convolution kernel of motion blur
generator = dinv.physics.generator.MotionBlurGenerator(
psf_size=(25, 25), rng=torch.Generator(device), device=device
)
true_kernel = generator.step(1, seed=123)["filter"]
physics = dinv.physics.Blur(noise_model=dinv.physics.GaussianNoise(0.02), device=device)
x = dinv.utils.load_url_image(
dinv.utils.demo.get_image_url("celeba_example.jpg"),
img_size=256,
resize_mode="resize",
).to(device)
y = physics(x, filter=true_kernel)
dinv.utils.plot([x, y, true_kernel], titles=["Sharp", "Blurry", "True kernel"])

Define an optimization algorithm#
The convolution kernel lives in the simplex, ie the kernel must have positive entries summing to 1. We can use a simple optimization algorithm - Projected Gradient Descent - to enforce this constraint. The following function allows one to compute the orthogonal projection onto the simplex, by a sorting algorithm (Reference: Large-scale Multiclass Support Vector Machine Training via Euclidean Projection onto the Simplex – Mathieu Blondel, Akinori Fujino, and Naonori Ueda)
@torch.no_grad()
def projection_simplex_sort(v: torch.Tensor) -> torch.Tensor:
r"""
Projects a tensor onto the simplex using a sorting algorithm.
"""
shape = v.shape
B = shape[0]
v = v.view(B, -1)
n_features = v.size(1)
u = torch.sort(v, descending=True, dim=-1).values
cssv = torch.cumsum(u, dim=-1) - 1.0
ind = torch.arange(n_features, device=v.device)[None, :].expand(B, -1) + 1.0
cond = u - cssv / ind > 0
rho = ind[cond][-1]
theta = cssv[cond][-1] / rho
w = torch.maximum(v - theta, torch.zeros_like(v))
return w.reshape(shape)
# We also define a data fidelity term
data_fidelity = dinv.optim.L2()
Run the algorithm
Initialize a constant kernel
kernel_init = torch.zeros_like(true_kernel)
kernel_init[..., 5:-5, 5:-5] = 1.0
kernel_init = projection_simplex_sort(kernel_init)
n_iter = 1000
stepsize = 0.7
kernel_hat = kernel_init
losses = []
for i in tqdm(range(n_iter)):
# compute the gradient
with torch.enable_grad():
kernel_hat.requires_grad_(True)
physics.update(filter=kernel_hat)
loss = data_fidelity(y=y, x=x, physics=physics) / y.numel()
loss.backward()
grad = kernel_hat.grad
# gradient step and projection step
with torch.no_grad():
kernel_hat = kernel_hat - stepsize * grad
kernel_hat = projection_simplex_sort(kernel_hat)
losses.append(loss.item())
dinv.utils.plot(
[true_kernel, kernel_init, kernel_hat],
titles=["True kernel", "Init. kernel", "Estimated kernel"],
suptitle="Result with Projected Gradient Descent",
)

0%| | 0/1000 [00:00<?, ?it/s]
0%| | 3/1000 [00:00<00:43, 22.83it/s]
1%| | 6/1000 [00:00<00:42, 23.60it/s]
1%| | 9/1000 [00:00<00:41, 24.03it/s]
1%| | 12/1000 [00:00<00:40, 24.39it/s]
2%|▏ | 15/1000 [00:00<00:39, 24.64it/s]
2%|▏ | 18/1000 [00:00<00:39, 24.80it/s]
2%|▏ | 21/1000 [00:00<00:39, 24.56it/s]
2%|▏ | 24/1000 [00:00<00:39, 24.83it/s]
3%|▎ | 27/1000 [00:01<00:38, 24.99it/s]
3%|▎ | 30/1000 [00:01<00:38, 25.05it/s]
3%|▎ | 33/1000 [00:01<00:38, 25.12it/s]
4%|▎ | 36/1000 [00:01<00:38, 25.17it/s]
4%|▍ | 39/1000 [00:01<00:38, 25.19it/s]
4%|▍ | 42/1000 [00:01<00:38, 25.19it/s]
4%|▍ | 45/1000 [00:01<00:37, 25.20it/s]
5%|▍ | 48/1000 [00:01<00:38, 24.79it/s]
5%|▌ | 51/1000 [00:02<00:38, 24.97it/s]
5%|▌ | 54/1000 [00:02<00:37, 25.05it/s]
6%|▌ | 57/1000 [00:02<00:37, 25.13it/s]
6%|▌ | 60/1000 [00:02<00:37, 25.18it/s]
6%|▋ | 63/1000 [00:02<00:37, 25.25it/s]
7%|▋ | 66/1000 [00:02<00:36, 25.28it/s]
7%|▋ | 69/1000 [00:02<00:36, 25.25it/s]
7%|▋ | 72/1000 [00:02<00:37, 24.98it/s]
8%|▊ | 75/1000 [00:03<00:36, 25.07it/s]
8%|▊ | 78/1000 [00:03<00:36, 25.11it/s]
8%|▊ | 81/1000 [00:03<00:36, 25.17it/s]
8%|▊ | 84/1000 [00:03<00:36, 25.22it/s]
9%|▊ | 87/1000 [00:03<00:36, 25.26it/s]
9%|▉ | 90/1000 [00:03<00:37, 24.18it/s]
9%|▉ | 93/1000 [00:03<00:36, 24.52it/s]
10%|▉ | 96/1000 [00:03<00:36, 24.62it/s]
10%|▉ | 99/1000 [00:03<00:36, 24.45it/s]
10%|█ | 102/1000 [00:04<00:36, 24.75it/s]
10%|█ | 105/1000 [00:04<00:35, 24.94it/s]
11%|█ | 108/1000 [00:04<00:35, 25.06it/s]
11%|█ | 111/1000 [00:04<00:35, 25.14it/s]
11%|█▏ | 114/1000 [00:04<00:35, 25.22it/s]
12%|█▏ | 117/1000 [00:04<00:34, 25.25it/s]
12%|█▏ | 120/1000 [00:04<00:34, 25.19it/s]
12%|█▏ | 123/1000 [00:04<00:35, 24.88it/s]
13%|█▎ | 126/1000 [00:05<00:34, 25.02it/s]
13%|█▎ | 129/1000 [00:05<00:34, 25.12it/s]
13%|█▎ | 132/1000 [00:05<00:34, 25.23it/s]
14%|█▎ | 135/1000 [00:05<00:34, 25.28it/s]
14%|█▍ | 138/1000 [00:05<00:34, 25.34it/s]
14%|█▍ | 141/1000 [00:05<00:33, 25.40it/s]
14%|█▍ | 144/1000 [00:05<00:33, 25.44it/s]
15%|█▍ | 147/1000 [00:05<00:33, 25.35it/s]
15%|█▌ | 150/1000 [00:06<00:34, 24.87it/s]
15%|█▌ | 153/1000 [00:06<00:33, 24.94it/s]
16%|█▌ | 156/1000 [00:06<00:33, 25.03it/s]
16%|█▌ | 159/1000 [00:06<00:33, 25.07it/s]
16%|█▌ | 162/1000 [00:06<00:33, 25.14it/s]
16%|█▋ | 165/1000 [00:06<00:33, 25.20it/s]
17%|█▋ | 168/1000 [00:06<00:33, 25.17it/s]
17%|█▋ | 171/1000 [00:06<00:32, 25.19it/s]
17%|█▋ | 174/1000 [00:06<00:33, 24.78it/s]
18%|█▊ | 177/1000 [00:07<00:33, 24.86it/s]
18%|█▊ | 180/1000 [00:07<00:32, 24.96it/s]
18%|█▊ | 183/1000 [00:07<00:32, 25.05it/s]
19%|█▊ | 186/1000 [00:07<00:32, 25.13it/s]
19%|█▉ | 189/1000 [00:07<00:32, 25.25it/s]
19%|█▉ | 192/1000 [00:07<00:31, 25.29it/s]
20%|█▉ | 195/1000 [00:07<00:31, 25.28it/s]
20%|█▉ | 198/1000 [00:07<00:31, 25.30it/s]
20%|██ | 201/1000 [00:08<00:31, 25.05it/s]
20%|██ | 204/1000 [00:08<00:31, 25.10it/s]
21%|██ | 207/1000 [00:08<00:31, 25.19it/s]
21%|██ | 210/1000 [00:08<00:31, 25.20it/s]
21%|██▏ | 213/1000 [00:08<00:31, 25.24it/s]
22%|██▏ | 216/1000 [00:08<00:30, 25.32it/s]
22%|██▏ | 219/1000 [00:08<00:30, 25.21it/s]
22%|██▏ | 222/1000 [00:08<00:30, 25.27it/s]
22%|██▎ | 225/1000 [00:08<00:31, 24.91it/s]
23%|██▎ | 228/1000 [00:09<00:30, 25.03it/s]
23%|██▎ | 231/1000 [00:09<00:30, 25.15it/s]
23%|██▎ | 234/1000 [00:09<00:30, 25.25it/s]
24%|██▎ | 237/1000 [00:09<00:30, 25.34it/s]
24%|██▍ | 240/1000 [00:09<00:29, 25.40it/s]
24%|██▍ | 243/1000 [00:09<00:29, 25.46it/s]
25%|██▍ | 246/1000 [00:09<00:29, 25.46it/s]
25%|██▍ | 249/1000 [00:09<00:29, 25.44it/s]
25%|██▌ | 252/1000 [00:10<00:29, 25.01it/s]
26%|██▌ | 255/1000 [00:10<00:29, 25.13it/s]
26%|██▌ | 258/1000 [00:10<00:29, 25.23it/s]
26%|██▌ | 261/1000 [00:10<00:29, 25.29it/s]
26%|██▋ | 264/1000 [00:10<00:29, 25.30it/s]
27%|██▋ | 267/1000 [00:10<00:28, 25.34it/s]
27%|██▋ | 270/1000 [00:10<00:28, 25.39it/s]
27%|██▋ | 273/1000 [00:10<00:28, 25.41it/s]
28%|██▊ | 276/1000 [00:11<00:28, 25.44it/s]
28%|██▊ | 279/1000 [00:11<00:28, 24.99it/s]
28%|██▊ | 282/1000 [00:11<00:28, 25.10it/s]
28%|██▊ | 285/1000 [00:11<00:28, 25.17it/s]
29%|██▉ | 288/1000 [00:11<00:28, 25.20it/s]
29%|██▉ | 291/1000 [00:11<00:28, 25.23it/s]
29%|██▉ | 294/1000 [00:11<00:27, 25.28it/s]
30%|██▉ | 297/1000 [00:11<00:27, 25.28it/s]
30%|███ | 300/1000 [00:11<00:27, 25.32it/s]
30%|███ | 303/1000 [00:12<00:28, 24.64it/s]
31%|███ | 306/1000 [00:12<00:27, 24.79it/s]
31%|███ | 309/1000 [00:12<00:27, 24.91it/s]
31%|███ | 312/1000 [00:12<00:27, 25.03it/s]
32%|███▏ | 315/1000 [00:12<00:27, 25.16it/s]
32%|███▏ | 318/1000 [00:12<00:26, 25.28it/s]
32%|███▏ | 321/1000 [00:12<00:26, 25.35it/s]
32%|███▏ | 324/1000 [00:12<00:26, 25.37it/s]
33%|███▎ | 327/1000 [00:13<00:26, 25.25it/s]
33%|███▎ | 330/1000 [00:13<00:28, 23.17it/s]
33%|███▎ | 333/1000 [00:13<00:28, 23.71it/s]
34%|███▎ | 336/1000 [00:13<00:27, 24.18it/s]
34%|███▍ | 339/1000 [00:13<00:26, 24.54it/s]
34%|███▍ | 342/1000 [00:13<00:26, 24.79it/s]
34%|███▍ | 345/1000 [00:13<00:26, 24.96it/s]
35%|███▍ | 348/1000 [00:13<00:26, 25.06it/s]
35%|███▌ | 351/1000 [00:14<00:25, 25.13it/s]
35%|███▌ | 354/1000 [00:14<00:26, 24.59it/s]
36%|███▌ | 357/1000 [00:14<00:25, 24.80it/s]
36%|███▌ | 360/1000 [00:14<00:25, 24.97it/s]
36%|███▋ | 363/1000 [00:14<00:25, 25.08it/s]
37%|███▋ | 366/1000 [00:14<00:25, 25.17it/s]
37%|███▋ | 369/1000 [00:14<00:25, 25.02it/s]
37%|███▋ | 372/1000 [00:14<00:25, 25.02it/s]
38%|███▊ | 375/1000 [00:14<00:24, 25.01it/s]
38%|███▊ | 378/1000 [00:15<00:25, 24.68it/s]
38%|███▊ | 381/1000 [00:15<00:24, 24.84it/s]
38%|███▊ | 384/1000 [00:15<00:24, 24.95it/s]
39%|███▊ | 387/1000 [00:15<00:24, 25.08it/s]
39%|███▉ | 390/1000 [00:15<00:24, 25.19it/s]
39%|███▉ | 393/1000 [00:15<00:24, 25.24it/s]
40%|███▉ | 396/1000 [00:15<00:23, 25.30it/s]
40%|███▉ | 399/1000 [00:15<00:23, 25.36it/s]
40%|████ | 402/1000 [00:16<00:23, 25.37it/s]
40%|████ | 405/1000 [00:16<00:23, 24.87it/s]
41%|████ | 408/1000 [00:16<00:23, 25.02it/s]
41%|████ | 411/1000 [00:16<00:23, 25.08it/s]
41%|████▏ | 414/1000 [00:16<00:23, 25.20it/s]
42%|████▏ | 417/1000 [00:16<00:23, 25.30it/s]
42%|████▏ | 420/1000 [00:16<00:22, 25.35it/s]
42%|████▏ | 423/1000 [00:16<00:22, 25.41it/s]
43%|████▎ | 426/1000 [00:17<00:22, 25.45it/s]
43%|████▎ | 429/1000 [00:17<00:22, 25.15it/s]
43%|████▎ | 432/1000 [00:17<00:22, 24.84it/s]
44%|████▎ | 435/1000 [00:17<00:22, 25.01it/s]
44%|████▍ | 438/1000 [00:17<00:22, 25.08it/s]
44%|████▍ | 441/1000 [00:17<00:22, 25.16it/s]
44%|████▍ | 444/1000 [00:17<00:22, 25.21it/s]
45%|████▍ | 447/1000 [00:17<00:21, 25.23it/s]
45%|████▌ | 450/1000 [00:17<00:21, 25.21it/s]
45%|████▌ | 453/1000 [00:18<00:21, 25.21it/s]
46%|████▌ | 456/1000 [00:18<00:21, 24.79it/s]
46%|████▌ | 459/1000 [00:18<00:21, 24.93it/s]
46%|████▌ | 462/1000 [00:18<00:21, 25.01it/s]
46%|████▋ | 465/1000 [00:18<00:21, 25.12it/s]
47%|████▋ | 468/1000 [00:18<00:21, 25.20it/s]
47%|████▋ | 471/1000 [00:18<00:21, 25.02it/s]
47%|████▋ | 474/1000 [00:18<00:21, 25.00it/s]
48%|████▊ | 477/1000 [00:19<00:20, 25.07it/s]
48%|████▊ | 480/1000 [00:19<00:20, 25.06it/s]
48%|████▊ | 483/1000 [00:19<00:20, 24.75it/s]
49%|████▊ | 486/1000 [00:19<00:20, 24.93it/s]
49%|████▉ | 489/1000 [00:19<00:20, 25.03it/s]
49%|████▉ | 492/1000 [00:19<00:20, 25.11it/s]
50%|████▉ | 495/1000 [00:19<00:20, 25.16it/s]
50%|████▉ | 498/1000 [00:19<00:19, 25.14it/s]
50%|█████ | 501/1000 [00:19<00:19, 25.20it/s]
50%|█████ | 504/1000 [00:20<00:19, 25.24it/s]
51%|█████ | 507/1000 [00:20<00:19, 24.88it/s]
51%|█████ | 510/1000 [00:20<00:19, 25.01it/s]
51%|█████▏ | 513/1000 [00:20<00:19, 25.11it/s]
52%|█████▏ | 516/1000 [00:20<00:19, 25.16it/s]
52%|█████▏ | 519/1000 [00:20<00:19, 25.08it/s]
52%|█████▏ | 522/1000 [00:20<00:19, 25.16it/s]
52%|█████▎ | 525/1000 [00:20<00:18, 25.25it/s]
53%|█████▎ | 528/1000 [00:21<00:18, 25.29it/s]
53%|█████▎ | 531/1000 [00:21<00:18, 25.28it/s]
53%|█████▎ | 534/1000 [00:21<00:18, 24.82it/s]
54%|█████▎ | 537/1000 [00:21<00:18, 24.94it/s]
54%|█████▍ | 540/1000 [00:21<00:18, 25.04it/s]
54%|█████▍ | 543/1000 [00:21<00:18, 25.08it/s]
55%|█████▍ | 546/1000 [00:21<00:18, 25.17it/s]
55%|█████▍ | 549/1000 [00:21<00:17, 25.24it/s]
55%|█████▌ | 552/1000 [00:22<00:17, 25.32it/s]
56%|█████▌ | 555/1000 [00:22<00:17, 25.34it/s]
56%|█████▌ | 558/1000 [00:22<00:17, 24.92it/s]
56%|█████▌ | 561/1000 [00:22<00:17, 25.00it/s]
56%|█████▋ | 564/1000 [00:22<00:17, 25.11it/s]
57%|█████▋ | 567/1000 [00:22<00:17, 25.18it/s]
57%|█████▋ | 570/1000 [00:22<00:17, 25.16it/s]
57%|█████▋ | 573/1000 [00:22<00:17, 25.11it/s]
58%|█████▊ | 576/1000 [00:22<00:16, 25.17it/s]
58%|█████▊ | 579/1000 [00:23<00:16, 25.18it/s]
58%|█████▊ | 582/1000 [00:23<00:16, 25.16it/s]
58%|█████▊ | 585/1000 [00:23<00:16, 24.73it/s]
59%|█████▉ | 588/1000 [00:23<00:16, 24.90it/s]
59%|█████▉ | 591/1000 [00:23<00:16, 25.04it/s]
59%|█████▉ | 594/1000 [00:23<00:16, 25.11it/s]
60%|█████▉ | 597/1000 [00:23<00:16, 25.14it/s]
60%|██████ | 600/1000 [00:23<00:15, 25.17it/s]
60%|██████ | 603/1000 [00:24<00:15, 25.19it/s]
61%|██████ | 606/1000 [00:24<00:15, 25.19it/s]
61%|██████ | 609/1000 [00:24<00:15, 24.62it/s]
61%|██████ | 612/1000 [00:24<00:15, 24.82it/s]
62%|██████▏ | 615/1000 [00:24<00:15, 24.96it/s]
62%|██████▏ | 618/1000 [00:24<00:15, 25.08it/s]
62%|██████▏ | 621/1000 [00:24<00:15, 25.16it/s]
62%|██████▏ | 624/1000 [00:24<00:14, 25.25it/s]
63%|██████▎ | 627/1000 [00:25<00:14, 25.28it/s]
63%|██████▎ | 630/1000 [00:25<00:14, 25.32it/s]
63%|██████▎ | 633/1000 [00:25<00:14, 25.39it/s]
64%|██████▎ | 636/1000 [00:25<00:14, 25.03it/s]
64%|██████▍ | 639/1000 [00:25<00:14, 25.10it/s]
64%|██████▍ | 642/1000 [00:25<00:14, 25.10it/s]
64%|██████▍ | 645/1000 [00:25<00:14, 25.09it/s]
65%|██████▍ | 648/1000 [00:25<00:14, 25.13it/s]
65%|██████▌ | 651/1000 [00:25<00:13, 25.14it/s]
65%|██████▌ | 654/1000 [00:26<00:13, 25.12it/s]
66%|██████▌ | 657/1000 [00:26<00:13, 25.12it/s]
66%|██████▌ | 660/1000 [00:26<00:13, 24.75it/s]
66%|██████▋ | 663/1000 [00:26<00:13, 24.89it/s]
67%|██████▋ | 666/1000 [00:26<00:13, 25.01it/s]
67%|██████▋ | 669/1000 [00:26<00:13, 25.12it/s]
67%|██████▋ | 672/1000 [00:26<00:13, 25.11it/s]
68%|██████▊ | 675/1000 [00:26<00:12, 25.13it/s]
68%|██████▊ | 678/1000 [00:27<00:12, 24.84it/s]
68%|██████▊ | 681/1000 [00:27<00:12, 24.88it/s]
68%|██████▊ | 684/1000 [00:27<00:12, 25.02it/s]
69%|██████▊ | 687/1000 [00:27<00:12, 24.74it/s]
69%|██████▉ | 690/1000 [00:27<00:12, 24.95it/s]
69%|██████▉ | 693/1000 [00:27<00:12, 25.12it/s]
70%|██████▉ | 696/1000 [00:27<00:12, 25.18it/s]
70%|██████▉ | 699/1000 [00:27<00:11, 25.18it/s]
70%|███████ | 702/1000 [00:28<00:11, 25.24it/s]
70%|███████ | 705/1000 [00:28<00:11, 25.13it/s]
71%|███████ | 708/1000 [00:28<00:11, 25.17it/s]
71%|███████ | 711/1000 [00:28<00:11, 24.82it/s]
71%|███████▏ | 714/1000 [00:28<00:11, 24.93it/s]
72%|███████▏ | 717/1000 [00:28<00:11, 25.04it/s]
72%|███████▏ | 720/1000 [00:28<00:11, 25.14it/s]
72%|███████▏ | 723/1000 [00:28<00:10, 25.21it/s]
73%|███████▎ | 726/1000 [00:28<00:10, 25.27it/s]
73%|███████▎ | 729/1000 [00:29<00:10, 25.28it/s]
73%|███████▎ | 732/1000 [00:29<00:10, 25.30it/s]
74%|███████▎ | 735/1000 [00:29<00:10, 25.16it/s]
74%|███████▍ | 738/1000 [00:29<00:10, 24.93it/s]
74%|███████▍ | 741/1000 [00:29<00:10, 25.05it/s]
74%|███████▍ | 744/1000 [00:29<00:10, 25.14it/s]
75%|███████▍ | 747/1000 [00:29<00:10, 25.22it/s]
75%|███████▌ | 750/1000 [00:29<00:09, 25.24it/s]
75%|███████▌ | 753/1000 [00:30<00:09, 25.28it/s]
76%|███████▌ | 756/1000 [00:30<00:09, 25.24it/s]
76%|███████▌ | 759/1000 [00:30<00:09, 25.23it/s]
76%|███████▌ | 762/1000 [00:30<00:09, 24.84it/s]
76%|███████▋ | 765/1000 [00:30<00:09, 25.02it/s]
77%|███████▋ | 768/1000 [00:30<00:09, 25.18it/s]
77%|███████▋ | 771/1000 [00:30<00:09, 25.25it/s]
77%|███████▋ | 774/1000 [00:30<00:08, 25.27it/s]
78%|███████▊ | 777/1000 [00:30<00:08, 25.29it/s]
78%|███████▊ | 780/1000 [00:31<00:08, 25.26it/s]
78%|███████▊ | 783/1000 [00:31<00:08, 25.25it/s]
79%|███████▊ | 786/1000 [00:31<00:08, 25.30it/s]
79%|███████▉ | 789/1000 [00:31<00:08, 24.91it/s]
79%|███████▉ | 792/1000 [00:31<00:08, 25.04it/s]
80%|███████▉ | 795/1000 [00:31<00:08, 25.09it/s]
80%|███████▉ | 798/1000 [00:31<00:08, 25.18it/s]
80%|████████ | 801/1000 [00:31<00:07, 25.23it/s]
80%|████████ | 804/1000 [00:32<00:07, 25.26it/s]
81%|████████ | 807/1000 [00:32<00:07, 25.24it/s]
81%|████████ | 810/1000 [00:32<00:07, 25.24it/s]
81%|████████▏ | 813/1000 [00:32<00:07, 25.26it/s]
82%|████████▏ | 816/1000 [00:32<00:07, 24.96it/s]
82%|████████▏ | 819/1000 [00:32<00:07, 25.07it/s]
82%|████████▏ | 822/1000 [00:32<00:07, 25.06it/s]
82%|████████▎ | 825/1000 [00:32<00:06, 25.12it/s]
83%|████████▎ | 828/1000 [00:33<00:06, 25.20it/s]
83%|████████▎ | 831/1000 [00:33<00:06, 25.21it/s]
83%|████████▎ | 834/1000 [00:33<00:06, 25.25it/s]
84%|████████▎ | 837/1000 [00:33<00:06, 25.32it/s]
84%|████████▍ | 840/1000 [00:33<00:06, 24.95it/s]
84%|████████▍ | 843/1000 [00:33<00:06, 25.12it/s]
85%|████████▍ | 846/1000 [00:33<00:06, 25.22it/s]
85%|████████▍ | 849/1000 [00:33<00:05, 25.27it/s]
85%|████████▌ | 852/1000 [00:33<00:05, 25.30it/s]
86%|████████▌ | 855/1000 [00:34<00:05, 25.32it/s]
86%|████████▌ | 858/1000 [00:34<00:05, 25.07it/s]
86%|████████▌ | 861/1000 [00:34<00:05, 25.13it/s]
86%|████████▋ | 864/1000 [00:34<00:05, 25.20it/s]
87%|████████▋ | 867/1000 [00:34<00:05, 24.69it/s]
87%|████████▋ | 870/1000 [00:34<00:05, 24.89it/s]
87%|████████▋ | 873/1000 [00:34<00:05, 25.01it/s]
88%|████████▊ | 876/1000 [00:34<00:04, 25.10it/s]
88%|████████▊ | 879/1000 [00:35<00:04, 25.19it/s]
88%|████████▊ | 882/1000 [00:35<00:04, 25.18it/s]
88%|████████▊ | 885/1000 [00:35<00:04, 25.22it/s]
89%|████████▉ | 888/1000 [00:35<00:04, 25.21it/s]
89%|████████▉ | 891/1000 [00:35<00:04, 24.94it/s]
89%|████████▉ | 894/1000 [00:35<00:04, 25.05it/s]
90%|████████▉ | 897/1000 [00:35<00:04, 25.14it/s]
90%|█████████ | 900/1000 [00:35<00:03, 25.20it/s]
90%|█████████ | 903/1000 [00:36<00:03, 25.24it/s]
91%|█████████ | 906/1000 [00:36<00:03, 25.22it/s]
91%|█████████ | 909/1000 [00:36<00:03, 25.26it/s]
91%|█████████ | 912/1000 [00:36<00:03, 25.25it/s]
92%|█████████▏| 915/1000 [00:36<00:03, 25.28it/s]
92%|█████████▏| 918/1000 [00:36<00:03, 24.84it/s]
92%|█████████▏| 921/1000 [00:36<00:03, 25.00it/s]
92%|█████████▏| 924/1000 [00:36<00:03, 25.08it/s]
93%|█████████▎| 927/1000 [00:36<00:02, 25.19it/s]
93%|█████████▎| 930/1000 [00:37<00:02, 25.20it/s]
93%|█████████▎| 933/1000 [00:37<00:02, 25.25it/s]
94%|█████████▎| 936/1000 [00:37<00:02, 25.24it/s]
94%|█████████▍| 939/1000 [00:37<00:02, 25.26it/s]
94%|█████████▍| 942/1000 [00:37<00:02, 24.87it/s]
94%|█████████▍| 945/1000 [00:37<00:02, 24.93it/s]
95%|█████████▍| 948/1000 [00:37<00:02, 25.03it/s]
95%|█████████▌| 951/1000 [00:37<00:01, 25.13it/s]
95%|█████████▌| 954/1000 [00:38<00:01, 25.24it/s]
96%|█████████▌| 957/1000 [00:38<00:01, 25.10it/s]
96%|█████████▌| 960/1000 [00:38<00:01, 25.18it/s]
96%|█████████▋| 963/1000 [00:38<00:01, 25.21it/s]
97%|█████████▋| 966/1000 [00:38<00:01, 25.26it/s]
97%|█████████▋| 969/1000 [00:38<00:01, 25.06it/s]
97%|█████████▋| 972/1000 [00:38<00:01, 25.12it/s]
98%|█████████▊| 975/1000 [00:38<00:00, 25.19it/s]
98%|█████████▊| 978/1000 [00:38<00:00, 25.24it/s]
98%|█████████▊| 981/1000 [00:39<00:00, 25.27it/s]
98%|█████████▊| 984/1000 [00:39<00:00, 25.28it/s]
99%|█████████▊| 987/1000 [00:39<00:00, 25.28it/s]
99%|█████████▉| 990/1000 [00:39<00:00, 25.31it/s]
99%|█████████▉| 993/1000 [00:39<00:00, 24.87it/s]
100%|█████████▉| 996/1000 [00:39<00:00, 25.00it/s]
100%|█████████▉| 999/1000 [00:39<00:00, 25.09it/s]
100%|██████████| 1000/1000 [00:39<00:00, 25.09it/s]
We can plot the loss to make sure that it decreases

Combine with arbitrary optimizer#
Pytorch provides a wide range of optimizers for training neural networks. We can also pick one of those to optimizer our parameter
kernel_init = torch.zeros_like(true_kernel)
kernel_init[..., 5:-5, 5:-5] = 1.0
kernel_init = projection_simplex_sort(kernel_init)
kernel_hat = kernel_init.clone()
optimizer = torch.optim.Adam([kernel_hat], lr=0.1)
# We will alternate a gradient step and a projection step
losses = []
n_iter = 200
for i in tqdm(range(n_iter)):
optimizer.zero_grad()
# compute the gradient, this will directly change the gradient of `kernel_hat`
with torch.enable_grad():
kernel_hat.requires_grad_(True)
physics.update(filter=kernel_hat)
loss = data_fidelity(y=y, x=x, physics=physics) / y.numel()
loss.backward()
# a gradient step
optimizer.step()
# projection step, when doing additional steps, it's important to change only
# the tensor data to avoid breaking the gradient computation
kernel_hat.data = projection_simplex_sort(kernel_hat.data)
# loss
losses.append(loss.item())
dinv.utils.plot(
[true_kernel, kernel_init, kernel_hat],
titles=["True kernel", "Init. kernel", "Estimated kernel"],
suptitle="Result with ADAM",
)

0%| | 0/200 [00:00<?, ?it/s]
2%|▏ | 3/200 [00:00<00:08, 24.29it/s]
3%|▎ | 6/200 [00:00<00:07, 24.51it/s]
4%|▍ | 9/200 [00:00<00:07, 24.60it/s]
6%|▌ | 12/200 [00:00<00:07, 24.75it/s]
8%|▊ | 15/200 [00:00<00:07, 24.44it/s]
9%|▉ | 18/200 [00:00<00:07, 24.65it/s]
10%|█ | 21/200 [00:00<00:07, 24.83it/s]
12%|█▏ | 24/200 [00:00<00:07, 25.00it/s]
14%|█▎ | 27/200 [00:01<00:06, 25.10it/s]
15%|█▌ | 30/200 [00:01<00:06, 25.13it/s]
16%|█▋ | 33/200 [00:01<00:06, 25.08it/s]
18%|█▊ | 36/200 [00:01<00:06, 25.09it/s]
20%|█▉ | 39/200 [00:01<00:06, 24.76it/s]
21%|██ | 42/200 [00:01<00:06, 24.85it/s]
22%|██▎ | 45/200 [00:01<00:06, 24.94it/s]
24%|██▍ | 48/200 [00:01<00:06, 25.02it/s]
26%|██▌ | 51/200 [00:02<00:05, 25.08it/s]
27%|██▋ | 54/200 [00:02<00:05, 25.10it/s]
28%|██▊ | 57/200 [00:02<00:05, 25.11it/s]
30%|███ | 60/200 [00:02<00:05, 25.12it/s]
32%|███▏ | 63/200 [00:02<00:05, 25.14it/s]
33%|███▎ | 66/200 [00:02<00:05, 24.74it/s]
34%|███▍ | 69/200 [00:02<00:05, 24.91it/s]
36%|███▌ | 72/200 [00:02<00:05, 25.01it/s]
38%|███▊ | 75/200 [00:03<00:04, 25.07it/s]
39%|███▉ | 78/200 [00:03<00:04, 25.04it/s]
40%|████ | 81/200 [00:03<00:04, 25.08it/s]
42%|████▏ | 84/200 [00:03<00:04, 25.12it/s]
44%|████▎ | 87/200 [00:03<00:04, 25.19it/s]
45%|████▌ | 90/200 [00:03<00:04, 24.84it/s]
46%|████▋ | 93/200 [00:03<00:04, 24.97it/s]
48%|████▊ | 96/200 [00:03<00:04, 25.01it/s]
50%|████▉ | 99/200 [00:03<00:04, 25.08it/s]
51%|█████ | 102/200 [00:04<00:03, 25.11it/s]
52%|█████▎ | 105/200 [00:04<00:03, 25.16it/s]
54%|█████▍ | 108/200 [00:04<00:03, 25.15it/s]
56%|█████▌ | 111/200 [00:04<00:03, 25.18it/s]
57%|█████▋ | 114/200 [00:04<00:03, 25.21it/s]
58%|█████▊ | 117/200 [00:04<00:03, 24.70it/s]
60%|██████ | 120/200 [00:04<00:03, 24.76it/s]
62%|██████▏ | 123/200 [00:04<00:03, 24.76it/s]
63%|██████▎ | 126/200 [00:05<00:02, 24.78it/s]
64%|██████▍ | 129/200 [00:05<00:02, 24.81it/s]
66%|██████▌ | 132/200 [00:05<00:02, 24.86it/s]
68%|██████▊ | 135/200 [00:05<00:02, 24.89it/s]
69%|██████▉ | 138/200 [00:05<00:02, 24.92it/s]
70%|███████ | 141/200 [00:05<00:02, 24.03it/s]
72%|███████▏ | 144/200 [00:05<00:02, 24.38it/s]
74%|███████▎ | 147/200 [00:05<00:02, 24.67it/s]
75%|███████▌ | 150/200 [00:06<00:02, 24.87it/s]
76%|███████▋ | 153/200 [00:06<00:01, 24.94it/s]
78%|███████▊ | 156/200 [00:06<00:01, 25.03it/s]
80%|███████▉ | 159/200 [00:06<00:01, 25.10it/s]
81%|████████ | 162/200 [00:06<00:01, 25.19it/s]
82%|████████▎ | 165/200 [00:06<00:01, 25.23it/s]
84%|████████▍ | 168/200 [00:06<00:01, 24.80it/s]
86%|████████▌ | 171/200 [00:06<00:01, 24.92it/s]
87%|████████▋ | 174/200 [00:06<00:01, 25.02it/s]
88%|████████▊ | 177/200 [00:07<00:00, 25.00it/s]
90%|█████████ | 180/200 [00:07<00:00, 25.06it/s]
92%|█████████▏| 183/200 [00:07<00:00, 25.08it/s]
93%|█████████▎| 186/200 [00:07<00:00, 25.12it/s]
94%|█████████▍| 189/200 [00:07<00:00, 25.13it/s]
96%|█████████▌| 192/200 [00:07<00:00, 24.79it/s]
98%|█████████▊| 195/200 [00:07<00:00, 24.82it/s]
99%|█████████▉| 198/200 [00:07<00:00, 24.88it/s]
100%|██████████| 200/200 [00:08<00:00, 24.94it/s]
We can plot the loss to make sure that it decreases

Optimizing the physics as a usual neural network#
Below we show another way to optimize the parameter of the physics, as we usually do for neural networks
kernel_init = torch.zeros_like(true_kernel)
kernel_init[..., 5:-5, 5:-5] = 1.0
kernel_init = projection_simplex_sort(kernel_init)
# The gradient is off by default, we need to enable the gradient of the parameter
physics = dinv.physics.Blur(
filter=kernel_init.clone().requires_grad_(True), device=device
)
# Set up the optimizer by giving the parameter to an optimizer
# Try to change your favorite optimizer
optimizer = torch.optim.AdamW([physics.filter], lr=0.1)
# Try to change another loss function
# loss_fn = torch.nn.MSELoss()
loss_fn = torch.nn.L1Loss()
# We will alternate a gradient step and a projection step
losses = []
n_iter = 100
for i in tqdm(range(n_iter)):
# update the gradient
optimizer.zero_grad()
y_hat = physics.A(x)
loss = loss_fn(y_hat, y)
loss.backward()
# a gradient step
optimizer.step()
# projection step.
# Note: when doing additional steps, it's important to change only
# the tensor data to avoid breaking the gradient computation
physics.filter.data = projection_simplex_sort(physics.filter.data)
# loss
losses.append(loss.item())
kernel_hat = physics.filter.data
dinv.utils.plot(
[true_kernel, kernel_init, kernel_hat],
titles=["True kernel", "Init. kernel", "Estimated kernel"],
suptitle="Result with AdamW and L1 Loss",
)

0%| | 0/100 [00:00<?, ?it/s]
3%|▎ | 3/100 [00:00<00:04, 23.38it/s]
6%|▌ | 6/100 [00:00<00:03, 24.08it/s]
9%|▉ | 9/100 [00:00<00:03, 24.38it/s]
12%|█▏ | 12/100 [00:00<00:03, 24.13it/s]
15%|█▌ | 15/100 [00:00<00:03, 24.47it/s]
18%|█▊ | 18/100 [00:00<00:03, 24.71it/s]
21%|██ | 21/100 [00:00<00:03, 24.84it/s]
24%|██▍ | 24/100 [00:00<00:03, 24.94it/s]
27%|██▋ | 27/100 [00:01<00:02, 24.99it/s]
30%|███ | 30/100 [00:01<00:02, 24.94it/s]
33%|███▎ | 33/100 [00:01<00:02, 25.05it/s]
36%|███▌ | 36/100 [00:01<00:02, 25.14it/s]
39%|███▉ | 39/100 [00:01<00:02, 24.79it/s]
42%|████▏ | 42/100 [00:01<00:02, 24.92it/s]
45%|████▌ | 45/100 [00:01<00:02, 24.98it/s]
48%|████▊ | 48/100 [00:01<00:02, 24.98it/s]
51%|█████ | 51/100 [00:02<00:01, 25.05it/s]
54%|█████▍ | 54/100 [00:02<00:01, 25.10it/s]
57%|█████▋ | 57/100 [00:02<00:01, 25.10it/s]
60%|██████ | 60/100 [00:02<00:01, 25.14it/s]
63%|██████▎ | 63/100 [00:02<00:01, 24.62it/s]
66%|██████▌ | 66/100 [00:02<00:01, 24.77it/s]
69%|██████▉ | 69/100 [00:02<00:01, 24.89it/s]
72%|███████▏ | 72/100 [00:02<00:01, 24.91it/s]
75%|███████▌ | 75/100 [00:03<00:01, 24.95it/s]
78%|███████▊ | 78/100 [00:03<00:00, 24.97it/s]
81%|████████ | 81/100 [00:03<00:00, 25.00it/s]
84%|████████▍ | 84/100 [00:03<00:00, 25.04it/s]
87%|████████▋ | 87/100 [00:03<00:00, 25.03it/s]
90%|█████████ | 90/100 [00:03<00:00, 24.61it/s]
93%|█████████▎| 93/100 [00:03<00:00, 24.69it/s]
96%|█████████▌| 96/100 [00:03<00:00, 24.77it/s]
99%|█████████▉| 99/100 [00:03<00:00, 24.88it/s]
100%|██████████| 100/100 [00:04<00:00, 24.85it/s]
We can plot the loss to make sure that it decreases

Total running time of the script: (0 minutes 52.797 seconds)