Using state-of-the-art diffusion models from HuggingFace Diffusers with DeepInverse
This demo shows how to use our wrapper deepinv.models.DiffusersDenoiserWrapper to turn state-of-the-art diffusion models from the HuggingFace Hub into image denoisers. The wrapped model can also be used for unconditional image generation or for posterior sampling.
See the diffusers pipeline documentation and our posterior sampling user guide for more details.
Note
This example requires the diffusers and transformers packages. You can install them via pip install diffusers transformers.
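If you are unsure whether these optional dependencies are installed, a small standard-library check such as the one below can report what is missing before any model is downloaded. The helper name `missing_packages` is our own, for illustration only, and is not part of DeepInverse.

```python
import importlib.util


def missing_packages(names=("diffusers", "transformers")):
    """Return the subset of top-level package `names` that cannot be imported here."""
    return [name for name in names if importlib.util.find_spec(name) is None]


missing = missing_packages()
if missing:
    print("Missing packages, install them with: pip install " + " ".join(missing))
else:
    print("All optional dependencies are available.")
```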
import torch
import deepinv as dinv
from deepinv.models.wrapper import DiffusersDenoiserWrapper
from deepinv.optim import ZeroFidelity
from deepinv.sampling import (
    PosteriorDiffusion,
    EulerSolver,
    VarianceExplodingDiffusion,
    VariancePreservingDiffusion,
)

device = "cuda" if torch.cuda.is_available() else "cpu"
dtype = torch.float32
figsize = 2.5  # base panel size used in the plots below
Let us first load a pretrained diffusion model from the HuggingFace Hub. Here, we use the google/ddpm-ema-celebahq-256 model, which was trained on 256×256 images from the CelebA-HQ dataset using the DDPM scheduler.
# We can wrap any diffusers model as a DeepInv denoiser using one line of code:
denoiser = DiffusersDenoiserWrapper(
    mode_id="google/ddpm-ema-celebahq-256", device=device
)

# Load an example image
x = dinv.utils.load_example(
    "celeba_example2.jpg",
    img_size=256,
    resize_mode="resize",
).to(device)

# Add noise and test the denoiser
sigma = 0.1
x_noisy = x + sigma * torch.randn_like(x)
with torch.no_grad():
    x_denoised = denoiser(x_noisy, sigma=sigma)

dinv.utils.plot(
    [x, x_noisy, x_denoised],
    figsize=(figsize * 3, figsize),
    titles=["Original image", "Noisy image", "Denoised image"],
)

Fetching 4 files: 100%|██████████| 4/4 [00:03<00:00, 1.05it/s]
An error occurred while trying to fetch /home/runner/.cache/huggingface/hub/models--google--ddpm-ema-celebahq-256/snapshots/4cb6117472e6e4f45c5afe606b101858c27c3802: Error no file named diffusion_pytorch_model.safetensors found in directory /home/runner/.cache/huggingface/hub/models--google--ddpm-ema-celebahq-256/snapshots/4cb6117472e6e4f45c5afe606b101858c27c3802.
Defaulting to unsafe serialization. Pass `allow_pickle=False` to raise an error instead.
Loading pipeline components...: 100%|██████████| 2/2 [00:00<00:00, 23.11it/s]
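A DDPM checkpoint predicts the noise added at a discrete timestep t, where x_t = sqrt(ᾱ_t) x_0 + sqrt(1 - ᾱ_t) ε, whereas a DeepInv denoiser is called with a noise level σ. Dividing x_t by sqrt(ᾱ_t) gives x_0 + σ_t ε with σ_t = sqrt((1 - ᾱ_t) / ᾱ_t), so a noisy image x + σε can be fed to the network after multiplying it by sqrt(ᾱ_t) for the timestep whose σ_t matches σ. The sketch below computes that match in plain Python. It assumes the linear β schedule of the original DDPM (β from 1e-4 to 0.02 over 1000 steps) and that [0, 1] images are rescaled to [-1, 1], which doubles σ; we have not checked either assumption against this model's scheduler config or against how DiffusersDenoiserWrapper is actually implemented. The function names are hypothetical.

```python
import math


def ddpm_sigmas(num_steps=1000, beta_start=1e-4, beta_end=0.02):
    """Equivalent noise levels sigma_t = sqrt((1 - abar_t) / abar_t) for a linear beta schedule."""
    sigmas = []
    abar = 1.0
    for i in range(num_steps):
        beta = beta_start + (beta_end - beta_start) * i / (num_steps - 1)
        abar *= 1.0 - beta  # cumulative product of (1 - beta)
        sigmas.append(math.sqrt((1.0 - abar) / abar))
    return sigmas


def sigma_to_timestep(sigma, sigmas, pixel_scale=2.0):
    """Closest DDPM timestep for a noise level given on [0, 1] images.

    pixel_scale=2.0 accounts for mapping [0, 1] images to [-1, 1] (an assumption).
    """
    target = pixel_scale * sigma
    t = min(range(len(sigmas)), key=lambda i: abs(sigmas[i] - target))
    abar = 1.0 / (1.0 + sigmas[t] ** 2)  # inverts sigma_t^2 = (1 - abar) / abar
    return t, math.sqrt(abar)  # timestep and the input scaling sqrt(abar_t)


sigmas = ddpm_sigmas()
t, scale = sigma_to_timestep(0.1, sigmas)
print(f"sigma=0.1 -> timestep {t}, input scaling {scale:.3f}")
```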
The wrapped model can also be used for unconditional image generation. Although the model was trained with the DDPM scheduler, we can use it with any SDE provided in DeepInv. Here, we use the variance-exploding (VE) SDE with the Euler solver for sampling.
num_steps = 125
rng = torch.Generator(device)
# 125 time points from t=1 down to t=0.001, i.e. 124 solver steps
timesteps = torch.linspace(1, 0.001, num_steps)
solver = EulerSolver(timesteps=timesteps, rng=rng)

sigma_min = 0.001
sigma_max = 80
sde = VarianceExplodingDiffusion(
    sigma_max=sigma_max,
    sigma_min=sigma_min,
    alpha=0.5,
    device=device,
    dtype=dtype,
)

# ZeroFidelity: there is no measurement, so we sample from the prior (unconditional generation)
model = PosteriorDiffusion(
    data_fidelity=ZeroFidelity(),
    sde=sde,
    denoiser=denoiser,
    solver=solver,
    dtype=dtype,
    device=device,
    verbose=True,
)
sample, trajectory = model(
    y=None,
    physics=None,
    x_init=(1, 3, 256, 256),  # shape of the initial noise: one 3-channel 256x256 image
    seed=42,
    get_trajectory=True,
)
dinv.utils.plot(
    sample,
    titles="Unconditional generation",
    figsize=(figsize, figsize),
)

100%|██████████| 124/124 [04:12<00:00, 2.04s/it]
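To see why a denoiser is all the sampler needs, here is a self-contained toy version of the VE reverse SDE solved with the Euler–Maruyama method. It uses the standard VE schedule σ(t) = σ_min (σ_max / σ_min)^t, for which g(t)² = dσ(t)²/dt = 2 σ(t)² ln(σ_max / σ_min), and Tweedie's identity ∇ log p_σ(x) = (D(x, σ) - x) / σ². A scalar Gaussian prior N(0, s²), whose exact MMSE denoiser is D(x, σ) = s² x / (s² + σ²), stands in for the diffusers model, and DeepInv's `alpha` parameter is ignored, so this illustrates the principle rather than reproducing `EulerSolver` or `VarianceExplodingDiffusion`. The function names are hypothetical.

```python
import math
import random


def toy_denoiser(x, sigma, s=1.0):
    """Exact MMSE denoiser for a scalar N(0, s^2) prior observed with N(0, sigma^2) noise."""
    return s**2 * x / (s**2 + sigma**2)


def ve_sample(denoiser, num_steps=400, sigma_min=0.001, sigma_max=80.0, rng=None):
    """Euler-Maruyama integration of the VE reverse SDE from t = 1 down to t = 0."""
    if rng is None:
        rng = random.Random()
    log_ratio = math.log(sigma_max / sigma_min)
    dt = 1.0 / num_steps  # positive step size; time runs backwards
    x = sigma_max * rng.gauss(0.0, 1.0)  # approximately a draw from p_1 = N(0, s^2 + sigma_max^2)
    for i in range(num_steps):
        sigma = sigma_min * (sigma_max / sigma_min) ** (1.0 - i * dt)  # sigma(t), t = 1 - i*dt
        g2 = 2.0 * sigma**2 * log_ratio  # g(t)^2 = d[sigma(t)^2]/dt
        score = (denoiser(x, sigma) - x) / sigma**2  # Tweedie's identity
        x = x + g2 * score * dt + math.sqrt(g2 * dt) * rng.gauss(0.0, 1.0)
    return denoiser(x, sigma_min)  # final denoising step at the smallest noise level


rng = random.Random(0)
samples = [ve_sample(toy_denoiser, rng=rng) for _ in range(2000)]
mean = sum(samples) / len(samples)
var = sum((v - mean) ** 2 for v in samples) / len(samples)
print(f"sample mean {mean:.3f}, sample variance {var:.3f} (prior mean 0, variance 1)")
```

With these settings the sample statistics should land close to the prior's mean 0 and variance 1; a coarser time grid introduces a larger discretization bias, which is one reason the number of solver steps matters in the real example.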
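Like other denoisers in DeepInv, the wrapped diffusers model can be used for posterior sampling. Below, we use the variance-preserving SDE (VP-SDE) for posterior sampling in an inpainting problem, with diffusion posterior sampling (DPS) enforcing consistency with the measurement.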
from deepinv.sampling import DPSDataFidelity

# Initialize the physics: inpainting of an 80x60-pixel hole, with Gaussian noise of standard deviation 0.05
mask = torch.ones_like(x)
mask[..., 70:150, 120:180] = 0  # 0 marks missing pixels
physics = dinv.physics.Inpainting(
    mask=mask,
    img_size=x.shape[1:],
    device=device,
    noise_model=dinv.physics.GaussianNoise(0.05),
)
y = physics(x)

# Initialize the VP-SDE and reuse the Euler solver defined above
sde = VariancePreservingDiffusion(device=device, dtype=dtype)
model = PosteriorDiffusion(
    data_fidelity=DPSDataFidelity(denoiser=denoiser, weight=0.3),
    denoiser=denoiser,
    sde=sde,
    solver=solver,
    dtype=dtype,
    device=device,
    verbose=True,
)
posterior_sample = model(
    y=y,
    physics=physics,
    x_init=(1, 3, 256, 256),
    seed=15,
)
dinv.utils.plot(
    [x, y, posterior_sample],
    titles=["Original image", "Measurement", "Posterior sample"],
    figsize=(figsize * 3, figsize),
)

100%|██████████| 124/124 [13:17<00:00, 6.43s/it]
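The role of the DPS data-fidelity term can be illustrated in the same toy setting. DPS adds to each reverse step a gradient step on the measurement residual, evaluated at the denoised estimate D(x, σ). Below, the signal is a short vector with i.i.d. N(0, 1) entries and the "inpainting" operator keeps only the entries where a mask equals 1. Because the toy denoiser is linear, D(x, σ) = k x with k = 1 / (1 + σ²), the gradient of ½ (y - m D(x, σ))² is available in closed form as -k m (y - m k x). The squared residual, the constant weight, and the use of the VE schedule instead of the VP-SDE are our own simplifications for illustration; this is not the exact update performed by DeepInv's DPSDataFidelity, and the function name is hypothetical.

```python
import math
import random


def dps_inpaint(y, mask, weight=0.5, num_steps=400, sigma_min=0.001, sigma_max=80.0, seed=0):
    """Toy DPS-style sampler for y = mask * x with an N(0, 1) prior on each entry.

    Each step = one Euler-Maruyama step of the VE reverse SDE (prior) plus a
    gradient step on 0.5 * (y - mask * D(x, sigma))^2 (data consistency).
    """
    rng = random.Random(seed)
    log_ratio = math.log(sigma_max / sigma_min)
    dt = 1.0 / num_steps
    x = [sigma_max * rng.gauss(0.0, 1.0) for _ in y]
    for i in range(num_steps):
        sigma = sigma_min * (sigma_max / sigma_min) ** (1.0 - i * dt)
        k = 1.0 / (1.0 + sigma**2)  # toy MMSE denoiser: D(x, sigma) = k * x
        g2 = 2.0 * sigma**2 * log_ratio  # g(t)^2 = d[sigma(t)^2]/dt
        new_x = []
        for xi, yi, mi in zip(x, y, mask):
            score = (k * xi - xi) / sigma**2  # prior score via Tweedie's identity
            grad = -k * mi * (yi - mi * k * xi)  # residual gradient at the current x
            xi_prior = xi + g2 * score * dt + math.sqrt(g2 * dt) * rng.gauss(0.0, 1.0)
            new_x.append(xi_prior - weight * grad)
        x = new_x
    return [xi / (1.0 + sigma_min**2) for xi in x]  # final denoised estimate D(x, sigma_min)


y = [0.8, -0.3, 0.0, 0.0]
mask = [1, 1, 0, 0]  # first two entries observed, last two missing
sample = dps_inpaint(y, mask)
print([round(v, 3) for v in sample])
```

The observed entries end up close to y while the missing ones are drawn from the prior, which mirrors what the inpainting example does at image scale. Because this toy version ignores the measurement noise level, it pulls the observed entries all the way to y rather than to the exact Gaussian posterior mean.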
Total running time of the script: (17 minutes 39.601 seconds)