# Source code for deepinv.sampling.diffusion

import torch
import numpy as np
from tqdm import tqdm
from deepinv.models import Reconstructor

from typing import Dict, Any

import deepinv.physics
from deepinv.utils.plotting import plot
from deepinv.sampling import BaseSampling
from deepinv.sampling.sampling_iterators import DiffusionIterator


class DiffusionSampler(BaseSampling):
    r"""
    Turns a diffusion method into a Monte Carlo sampler.

    Unlike diffusion methods, the resulting sampler computes the mean and variance of the
    distribution by running the diffusion multiple times.

    See the docs for :class:`deepinv.sampling.BaseSampling` for more information.
    It uses the helper class :class:`deepinv.sampling.DiffusionIterator`.

    :param torch.nn.Module diffusion: a diffusion model
    :param int max_iter: the number of samples to generate
    :param tuple clip: the clip range
    :param Callable g_statistic: the algorithm computes mean and variance of the g function,
        by default :math:`g(x) = x`.
    :param float thres_conv: the convergence threshold for the mean and variance
    :param bool verbose: whether to print the progress
    :param bool save_chain: whether to save the chain
    :param int thinning: the thinning factor
    :param float burnin_ratio: the burnin ratio
    """

    def __init__(
        self,
        diffusion,
        max_iter=1e2,
        clip=(-1, 2),
        thres_conv=1e-1,
        g_statistic=lambda x: x,
        verbose=True,
        save_chain=False,
    ):
        # The diffusion model acts as the prior of the base sampler; there is no
        # explicit data-fidelity term since the diffusion conditions on y itself.
        data_fidelity = None
        # Silence the inner diffusion: progress is reported by this sampler instead.
        diffusion.verbose = False
        prior = diffusion
        iterator = DiffusionIterator(clip=clip)
        super().__init__(
            iterator,
            data_fidelity,
            prior,
            max_iter=max_iter,
            thinning=1,  # every diffusion run is an independent sample, no thinning needed
            thresh_conv=thres_conv,
            history_size=save_chain,
            burnin_ratio=0.0,  # no burn-in: samples are i.i.d. runs of the diffusion
            verbose=verbose,
        )
        # Statistics are computed on the "x" entry of the sampler's state dict.
        self.g_statistics = [lambda d: g_statistic(d["x"])]

    def forward(self, y, physics, seed=None):
        r"""
        Runs the diffusion model to obtain the posterior mean and variance of the
        reconstruction of the measurements y.

        :param torch.Tensor y: Measurements
        :param deepinv.physics.Physics physics: Forward operator associated with the measurements
        :param float seed: Random seed for generating the samples
        :return: (tuple of torch.tensor) containing the posterior mean and variance.
        """
        return self.sample(y, physics, seed=seed, g_statistics=self.g_statistics)
class DDRM(Reconstructor):
    r"""DDRM(self, denoiser, sigmas=np.linspace(1, 0, 100), eta=0.85, etab=1.0, verbose=False)
    Denoising Diffusion Restoration Models (DDRM).

    This class implements the denoising diffusion restoration model (DDRM) described in
    https://arxiv.org/abs/2201.11793.

    The DDRM is a sampling method that uses a denoiser to sample from the posterior
    distribution of the inverse problem. It requires that the physics operator has a
    singular value decomposition, i.e., it is a
    :class:`deepinv.physics.DecomposablePhysics` class.

    :param torch.nn.Module denoiser: a denoiser model that can handle different noise levels.
    :param list[int] sigmas: a list of noise levels to use in the diffusion, they should be
        in decreasing order from 1 to 0.
    :param float eta: hyperparameter
    :param float etab: hyperparameter
    :param bool verbose: if True, print progress

    |sep|

    :Examples:

        Denoising diffusion restoration model using a pretrained DRUNet denoiser:

        >>> import deepinv as dinv
        >>> device = dinv.utils.get_freer_gpu(verbose=False) if torch.cuda.is_available() else 'cpu'
        >>> seed = torch.manual_seed(0)  # Random seed for reproducibility
        >>> seed = torch.cuda.manual_seed(0)  # Random seed for reproducibility on GPU
        >>> x = 0.5 * torch.ones(1, 3, 32, 32, device=device)  # Define plain gray 32x32 image
        >>> physics = dinv.physics.Inpainting(
        ...     mask=0.5, img_size=(3, 32, 32),
        ...     noise_model=dinv.physics.GaussianNoise(0.1),
        ...     device=device,
        ... )
        >>> y = physics(x)  # measurements
        >>> denoiser = dinv.models.DRUNet(pretrained="download").to(device)  # doctest: +IGNORE_RESULT
        >>> model = dinv.sampling.DDRM(denoiser=denoiser, sigmas=np.linspace(1, 0, 10), verbose=True)  # define the DDRM model
        >>> xhat = model(y, physics)  # sample from the posterior distribution
        >>> (dinv.metric.PSNR()(xhat, x) > dinv.metric.PSNR()(y, x)).cpu()  # Should be closer to the original
        tensor([True])
    """

    def __init__(
        self,
        denoiser,
        sigmas=np.linspace(1, 0, 100),
        eta=0.85,
        etab=1.0,
        verbose=False,
        eps=1e-6,
    ):
        super(DDRM, self).__init__()
        self.denoiser = denoiser
        self.sigmas = sigmas
        self.max_iter = len(sigmas)  # one diffusion step per noise level
        self.eta = eta
        self.verbose = verbose
        self.etab = etab
        self.eps = eps  # numerical guard against division by tiny singular values

    def forward(self, y, physics: deepinv.physics.DecomposablePhysics, seed=None):
        r"""
        Runs the diffusion to obtain a random sample of the posterior distribution.

        :param torch.Tensor y: the measurements.
        :param deepinv.physics.DecomposablePhysics physics: the physics operator, which must
            have a singular value decomposition.
        :param int seed: the seed for the random number generator.
        """
        with torch.no_grad():
            # NOTE: `if seed:` would silently skip seeding for seed=0; check for None instead.
            if seed is not None:
                np.random.seed(seed)
                torch.manual_seed(seed)

            if hasattr(physics.noise_model, "sigma"):
                sigma_noise = physics.noise_model.sigma
            else:
                sigma_noise = 0.01  # fallback noise level when physics has no noise model sigma

            if physics.__class__ == deepinv.physics.Denoising:
                mask = torch.ones_like(
                    y
                )  # TODO: fix for economic SVD decompositions (eg. Decolorize)
            else:
                # Replicate the singular values across the batch dimension.
                mask = torch.cat([physics.mask.abs()] * y.shape[0], dim=0)

            c = np.sqrt(1 - self.eta**2)
            y_bar = physics.U_adjoint(y)  # measurements in the SVD (spectral) domain
            # Entries where the singular value dominates the noise are informative.
            case = mask > sigma_noise
            y_bar[case] = y_bar[case] / (mask[case] + self.eps)
            # Noise-to-signal ratio per spectral coefficient.
            nsr = torch.zeros_like(mask)
            nsr[case] = sigma_noise / (mask[case] + self.eps)

            # iteration 1
            # compute init noise
            mean = torch.zeros_like(y_bar)
            std = torch.ones_like(y_bar) * self.sigmas[0]
            mean[case] = y_bar[case]
            std[case] = (self.sigmas[0] ** 2 - nsr[case].pow(2)).sqrt()
            x_bar = mean + std * torch.randn_like(y_bar) / np.sqrt(2.0)
            x_bar_prev = x_bar.clone()

            # denoise
            x = self.denoiser(physics.V(x_bar), self.sigmas[0])

            for t in tqdm(range(1, self.max_iter), disable=(not self.verbose)):
                # add noise in transformed domain
                x_bar = physics.V_adjoint(x)

                # case2: current noise level below the NSR; case3: at or above it.
                case2 = torch.logical_and(case, (self.sigmas[t] < nsr))
                case3 = torch.logical_and(case, (self.sigmas[t] >= nsr))

                mean = (
                    x_bar
                    + c * self.sigmas[t] * (x_bar_prev - x_bar) / self.sigmas[t - 1]
                )
                mean[case2] = x_bar[case2] + c * self.sigmas[t] * (
                    y_bar[case2] - x_bar[case2]
                ) / (nsr[case2] + self.eps)
                mean[case3] = (1.0 - self.etab) * x_bar[case3] + self.etab * y_bar[
                    case3
                ]

                std = torch.ones_like(x_bar) * self.eta * self.sigmas[t]
                std[case3] = (
                    (self.sigmas[t] ** 2 - (nsr[case3] * self.etab).pow(2))
                    .clamp(min=0)  # guard against negative values before sqrt
                    .sqrt()
                )

                x_bar = mean + std * torch.randn_like(x_bar) / np.sqrt(2.0)
                x_bar_prev = x_bar.clone()

                # denoise
                x = self.denoiser(physics.V(x_bar), self.sigmas[t])

        return x
class DiffPIR(Reconstructor):
    r"""
    Diffusion PnP Image Restoration (DiffPIR).

    This class implements the Diffusion PnP image restoration algorithm (DiffPIR) described
    in https://arxiv.org/abs/2305.08995.

    The DiffPIR algorithm is inspired on a half-quadratic splitting (HQS) plug-and-play
    algorithm, where the denoiser is a conditional diffusion denoiser, combined with a
    diffusion process. The algorithm writes as follows, for :math:`t` decreasing from
    :math:`T` to :math:`1`:

    .. math::
            \begin{equation*}
            \begin{aligned}
            x_{0}^{t} &= D_{\theta}(x_t, \frac{\sqrt{1-\overline{\alpha}_t}}{\sqrt{\overline{\alpha}_t}}) \\
            \widehat{x}_{0}^{t} &= \operatorname{prox}_{2 f(y, \cdot) /{\rho_t}}(x_{0}^{t}) \\
            \widehat{\varepsilon} &= \left(x_t - \sqrt{\overline{\alpha}_t} \,\, \widehat{x}_{0}^t\right)/\sqrt{1-\overline{\alpha}_t} \\
            \varepsilon_t &= \mathcal{N}(0, \mathbf{I}) \\
            x_{t-1} &= \sqrt{\overline{\alpha}_t} \,\, \widehat{x}_{0}^t + \sqrt{1-\overline{\alpha}_t} \left(\sqrt{1-\zeta} \,\, \widehat{\varepsilon} + \sqrt{\zeta} \,\, \varepsilon_t\right),
            \end{aligned}
            \end{equation*}

    where :math:`D_\theta(\cdot,\sigma)` is a Gaussian denoiser network with noise level
    :math:`\sigma` and :math:`f(y, \cdot)` is the data fidelity term.

    .. note::
            The algorithm might require careful tunning of the hyperparameters
            :math:`\lambda` and :math:`\zeta` to obtain optimal results.

    :param torch.nn.Module model: a conditional noise estimation model
    :param float sigma: the noise level of the data
    :param deepinv.optim.DataFidelity data_fidelity: the data fidelity operator
    :param int max_iter: the number of iterations to run the algorithm (default: 100)
    :param float zeta: hyperparameter :math:`\zeta` for the sampling step (must be between
        0 and 1). Default: 1.0.
    :param float lambda_: hyperparameter :math:`\lambda` for the data fidelity step
        (:math:`\rho_t = \lambda \frac{\sigma_n^2}{\bar{\sigma}_t^2}` in the paper where the
        optimal value range between 3.0 and 25.0 depending on the problem). Default: ``7.0``.
    :param bool verbose: if ``True``, print progress
    :param str device: the device to use for the computations

    |sep|

    :Examples:

        Denoising diffusion restoration model using a pretrained DRUNet denoiser:

        >>> import deepinv as dinv
        >>> device = dinv.utils.get_freer_gpu(verbose=False) if torch.cuda.is_available() else 'cpu'
        >>> x = 0.5 * torch.ones(1, 3, 32, 32, device=device)  # Define a plain gray 32x32 image
        >>> physics = dinv.physics.Inpainting(
        ...     mask=0.5, img_size=(3, 32, 32),
        ...     noise_model=dinv.physics.GaussianNoise(0.1),
        ...     device=device
        ... )
        >>> y = physics(x)  # Measurements
        >>> denoiser = dinv.models.DRUNet(pretrained="download").to(device)
        >>> model = dinv.sampling.DiffPIR(
        ...     model=denoiser,
        ...     data_fidelity=dinv.optim.data_fidelity.L2(),
        ...     device=device,
        ... )  # Define the DiffPIR model
        >>> xhat = model(y, physics)  # Run the DiffPIR algorithm
        >>> (dinv.metric.PSNR()(xhat, x) > dinv.metric.PSNR()(y, x)).cpu()  # Should be closer to the original
        tensor([True])
    """

    def __init__(
        self,
        model,
        data_fidelity,
        sigma=0.05,
        max_iter=100,
        zeta=0.1,
        lambda_=7.0,
        verbose=False,
        device="cpu",
    ):
        super().__init__()
        self.model = model
        self.lambda_ = lambda_
        self.data_fidelity = data_fidelity
        self.max_iter = max_iter
        self.zeta = zeta
        self.verbose = verbose
        self.device = device
        # Standard DDPM linear beta schedule over 1000 training timesteps.
        self.beta_start, self.beta_end = 0.1 / 1000, 20 / 1000
        self.num_train_timesteps = 1000
        self.sigma = sigma
        (
            self.sqrt_1m_alphas_cumprod,
            self.reduced_alpha_cumprod,
            self.sqrt_alphas_cumprod,
            self.sqrt_recip_alphas_cumprod,
            self.sqrt_recipm1_alphas_cumprod,
            self.betas,
        ) = self.get_alpha_beta()
        self.rhos, self.sigmas, self.seq = self.get_noise_schedule(sigma=sigma)

    def get_alpha_beta(self):
        """
        Get the alpha and beta sequences for the algorithm.

        This is necessary for mapping noise levels to timesteps.

        :return: tuple of tensors ``(sqrt_1m_alphas_cumprod, reduced_alpha_cumprod,
            sqrt_alphas_cumprod, sqrt_recip_alphas_cumprod, sqrt_recipm1_alphas_cumprod,
            betas)``, each of length ``num_train_timesteps``.
        """
        betas = torch.linspace(
            self.beta_start,
            self.beta_end,
            self.num_train_timesteps,
            dtype=torch.float32,
            device=self.device,
        )
        alphas = 1.0 - betas
        alphas_cumprod = torch.cumprod(alphas, dim=0)  # This is \overline{\alpha}_t

        # Useful sequences deriving from alphas_cumprod
        sqrt_alphas_cumprod = torch.sqrt(alphas_cumprod)
        sqrt_1m_alphas_cumprod = torch.sqrt(1.0 - alphas_cumprod)
        reduced_alpha_cumprod = torch.div(
            sqrt_1m_alphas_cumprod, sqrt_alphas_cumprod
        )  # equivalent noise sigma on image
        sqrt_recip_alphas_cumprod = torch.sqrt(1.0 / alphas_cumprod)
        sqrt_recipm1_alphas_cumprod = torch.sqrt(1.0 / alphas_cumprod - 1)

        return (
            sqrt_1m_alphas_cumprod,
            reduced_alpha_cumprod,
            sqrt_alphas_cumprod,
            sqrt_recip_alphas_cumprod,
            sqrt_recipm1_alphas_cumprod,
            betas,
        )

    def get_noise_schedule(self, sigma):
        """
        Get the noise schedule for the algorithm.

        :param float sigma: the noise level of the data.
        :return: tuple ``(rhos, sigmas, seq)`` where ``rhos`` are the HQS data-fidelity
            weights, ``sigmas`` the noise levels (in decreasing order), and ``seq`` the
            quadratically-spaced timestep subsequence of length ``max_iter``.
        """
        lambda_ = self.lambda_
        sigmas = []
        sigma_ks = []
        rhos = []
        for i in range(self.num_train_timesteps):
            # sigmas are read backwards so they decrease along the iterations.
            sigmas.append(self.reduced_alpha_cumprod[self.num_train_timesteps - 1 - i])
            sigma_ks.append(
                (self.sqrt_1m_alphas_cumprod[i] / self.sqrt_alphas_cumprod[i])
            )
            # rho_t = lambda * sigma_n^2 / sigma_t^2 (cf. the DiffPIR paper).
            rhos.append(lambda_ * (sigma**2) / (sigma_ks[i] ** 2))
        rhos, sigmas = (
            torch.tensor(rhos).to(self.device),
            torch.tensor(sigmas).to(self.device),
        )

        # Quadratic spacing of the max_iter sampled timesteps over [0, T).
        seq = torch.sqrt(
            torch.linspace(
                0.0, self.num_train_timesteps**2, self.max_iter, device=self.device
            )
        ).type(torch.int32)
        seq[-1] = seq[-1] - 1  # keep the last index within bounds

        return rhos, sigmas, seq

    def find_nearest(self, array, value):
        """
        Find the argmin of the nearest value in an array.

        :param torch.Tensor array: 1D tensor to search.
        :param value: target value.
        :return: index of the entry of ``array`` closest to ``value``.
        """
        idx = torch.abs(array - value).argmin()
        return idx

    def compute_alpha(self, betas, t):
        """
        Compute the alpha sequence from the beta sequence.

        :param torch.Tensor betas: the beta schedule.
        :param t: timestep index (or indices).
        :return: :math:`\\overline{\\alpha}_t` at the requested timestep(s).
        """
        alphas = 1.0 - betas
        alphas_cumprod = torch.cumprod(alphas, dim=0)
        at = alphas_cumprod[t]
        return at

    def get_alpha_prod(
        self, beta_start=0.1 / 1000, beta_end=20 / 1000, num_train_timesteps=1000
    ):
        """
        Get the alpha sequences; this is necessary for mapping noise levels to timesteps
        when performing pure denoising.

        :param float beta_start: first value of the linear beta schedule.
        :param float beta_end: last value of the linear beta schedule.
        :param int num_train_timesteps: number of training timesteps.
        :return: tuple ``(sqrt_recip_alphas_cumprod, sqrt_recipm1_alphas_cumprod)``.
        """
        betas = torch.linspace(
            beta_start,
            beta_end,
            num_train_timesteps,
            dtype=torch.float32,
            device=self.device,
        )
        alphas = 1.0 - betas
        alphas_cumprod = torch.cumprod(alphas, dim=0)  # This is \overline{\alpha}_t

        # Useful sequences deriving from alphas_cumprod
        sqrt_recip_alphas_cumprod = torch.sqrt(1.0 / alphas_cumprod)
        sqrt_recipm1_alphas_cumprod = torch.sqrt(1.0 / alphas_cumprod - 1)
        return (
            sqrt_recip_alphas_cumprod,
            sqrt_recipm1_alphas_cumprod,
        )

    def forward(
        self,
        y,
        physics: deepinv.physics.LinearPhysics,
        seed=None,
        x_init=None,
    ):
        r"""
        Runs the diffusion to obtain a random sample of the posterior distribution.

        :param torch.Tensor y: the measurements.
        :param deepinv.physics.LinearPhysics physics: the physics operator.
        :param int seed: the seed for the random number generator.
        :param torch.Tensor x_init: the initial guess for the reconstruction.
        """
        # NOTE: `if seed:` would silently skip seeding for seed=0; check for None instead.
        if seed is not None:
            torch.manual_seed(seed)

        if hasattr(physics.noise_model, "sigma"):
            sigma = physics.noise_model.sigma
            # Then we overwrite the default values
            self.rhos, self.sigmas, self.seq = self.get_noise_schedule(sigma=sigma)

        # Initialization in the [-1, 1] range expected by the diffusion model.
        if x_init is None:  # Necessary when x and y don't live in the same space
            x = 2 * physics.A_adjoint(y) - 1
        else:
            x = 2 * x_init - 1

        sqrt_recip_alphas_cumprod, sqrt_recipm1_alphas_cumprod = self.get_alpha_prod()

        with torch.no_grad():
            for i in tqdm(range(len(self.seq)), disable=(not self.verbose)):
                # Current noise level
                curr_sigma = self.sigmas[self.seq[i]]

                # time step associated with the noise level sigmas[i]
                t_i = self.find_nearest(self.reduced_alpha_cumprod, curr_sigma)
                at = 1 / sqrt_recip_alphas_cumprod[t_i] ** 2

                if (
                    i == 0
                ):  # Initialization (simpler than the original code, may be suboptimal)
                    x = (
                        x
                        + (curr_sigma**2 - 4.0 * self.sigma**2).sqrt()
                        * torch.randn_like(x)
                    ) / sqrt_recip_alphas_cumprod[-1]

                sigma_cur = curr_sigma

                # Denoising step
                x_aux = x / (2 * at.sqrt()) + 0.5  # renormalize in [0, 1]
                out = self.model(x_aux, sigma_cur / 2)
                denoised = 2 * out - 1
                x0 = denoised.clamp(-1, 1)

                if not self.seq[i] == self.seq[-1]:
                    # Data fidelity step
                    x0_p = x0 / 2 + 0.5
                    x0_p = self.data_fidelity.prox(
                        x0_p, y, physics, gamma=1.0 / (2 * self.rhos[t_i])
                    )
                    x0 = x0_p * 2 - 1

                    # Sampling step
                    t_im1 = self.find_nearest(
                        self.reduced_alpha_cumprod,
                        self.sigmas[self.seq[i + 1]],
                    )  # time step associated with the next noise level

                    eps = (
                        x - self.sqrt_alphas_cumprod[t_i] * x0
                    ) / self.sqrt_1m_alphas_cumprod[
                        t_i
                    ]  # effective noise

                    x = (
                        self.sqrt_alphas_cumprod[t_im1] * x0
                        + self.sqrt_1m_alphas_cumprod[t_im1]
                        * (1 - self.zeta) ** 0.5
                        * eps
                        + self.sqrt_1m_alphas_cumprod[t_im1]
                        * self.zeta**0.5
                        * torch.randn_like(x)
                    )  # sampling

        out = x / 2 + 0.5  # back to [0, 1] range

        return out
class DPS(Reconstructor):
    r"""
    Diffusion Posterior Sampling (DPS).

    This class implements the Diffusion Posterior Sampling algorithm (DPS) described in
    https://arxiv.org/abs/2209.14687.

    DPS is an approximation of a gradient-based posterior sampling algorithm, which has
    minimal assumptions on the forward model. The only restriction is that the measurement
    model has to be differentiable, which is generally the case.

    The algorithm writes as follows, for :math:`t` decreasing from :math:`T` to :math:`1`:

    .. math::
            \begin{equation*}
            \begin{aligned}
            \widehat{\mathbf{x}}_{t} &= D_{\theta}(\mathbf{x}_t, \sqrt{1-\overline{\alpha}_t}/\sqrt{\overline{\alpha}_t}) \\
            \mathbf{g}_t &= \nabla_{\mathbf{x}_t} \log p( \widehat{\mathbf{x}}_{t}(\mathbf{x}_t) | \mathbf{y} ) \\
            \mathbf{\varepsilon}_t &= \mathcal{N}(0, \mathbf{I}) \\
            \mathbf{x}_{t-1} &= a_t \,\, \mathbf{x}_t + b_t \, \, \widehat{\mathbf{x}}_t + \tilde{\sigma}_t \, \, \mathbf{\varepsilon}_t + \mathbf{g}_t,
            \end{aligned}
            \end{equation*}

    where :math:`\denoiser{\cdot}{\sigma}` is a denoising network for noise level
    :math:`\sigma`, :math:`\eta` is a hyperparameter, and the constants
    :math:`\tilde{\sigma}_t, a_t, b_t` are defined as

    .. math::
            \begin{equation*}
            \begin{aligned}
            \tilde{\sigma}_t &= \eta \sqrt{ (1 - \frac{\overline{\alpha}_t}{\overline{\alpha}_{t-1}}) \frac{1 - \overline{\alpha}_{t-1}}{1 - \overline{\alpha}_t}} \\
            a_t &= \sqrt{1 - \overline{\alpha}_{t-1} - \tilde{\sigma}_t^2}/\sqrt{1-\overline{\alpha}_t} \\
            b_t &= \sqrt{\overline{\alpha}_{t-1}} - \sqrt{1 - \overline{\alpha}_{t-1} - \tilde{\sigma}_t^2} \frac{\sqrt{\overline{\alpha}_{t}}}{\sqrt{1 - \overline{\alpha}_{t}}}.
            \end{aligned}
            \end{equation*}

    :param torch.nn.Module model: a denoiser network that can handle different noise levels
    :param deepinv.optim.DataFidelity data_fidelity: the data fidelity operator
    :param int max_iter: the number of diffusion iterations to run the algorithm (default: 1000)
    :param float eta: DDIM hyperparameter which controls the stochasticity
    :param bool verbose: if True, print progress
    :param str device: the device to use for the computations
    """

    def __init__(
        self,
        model,
        data_fidelity,
        max_iter=1000,
        eta=1.0,
        verbose=False,
        device="cpu",
        save_iterates=False,
    ):
        super(DPS, self).__init__()
        self.model = model
        # Gradients must flow through the denoiser for the likelihood-gradient step.
        self.model.requires_grad_(True)
        self.data_fidelity = data_fidelity
        self.max_iter = max_iter
        self.eta = eta
        self.verbose = verbose
        self.device = device
        # Standard DDPM linear beta schedule over 1000 training timesteps.
        self.beta_start, self.beta_end = 0.1 / 1000, 20 / 1000
        self.num_train_timesteps = 1000
        self.save_iterates = save_iterates
        self.betas, self.alpha_cumprod = self.compute_alpha_betas()

    def compute_alpha_betas(self):
        r"""
        Get the beta and alpha sequences for the algorithm.

        This is necessary for mapping noise levels to timesteps.

        :return: tuple ``(betas, alpha_cumprod)``; ``alpha_cumprod`` has a leading 1 so
            index ``t + 1`` corresponds to timestep ``t`` (index 0 is the :math:`t=-1` case).
        """
        betas = torch.linspace(
            self.beta_start,
            self.beta_end,
            self.num_train_timesteps,
            dtype=torch.float32,
            device=self.device,
        )
        # Prepend a zero beta so that alpha_cumprod[0] == 1 (the t = -1 boundary).
        alpha_cumprod = (
            1 - torch.cat([torch.zeros(1, device=self.device), betas], dim=0)
        ).cumprod(dim=0)
        return betas, alpha_cumprod

    def get_alpha(self, alpha_cumprod, t):
        # `t + 1` accounts for the leading 1 prepended in compute_alpha_betas;
        # the result is reshaped for broadcasting over (B, C, H, W) images.
        a = alpha_cumprod.index_select(0, t + 1).view(-1, 1, 1, 1)
        return a

    def forward(self, y, physics: deepinv.physics.Physics, seed=None, x_init=None):
        r"""
        Runs the diffusion to obtain a random sample of the posterior distribution.

        :param torch.Tensor y: the measurements.
        :param deepinv.physics.Physics physics: the (differentiable) physics operator.
        :param int seed: the seed for the random number generator.
        :param torch.Tensor x_init: the initial guess for the reconstruction
            (in [0, 1]; sampled from :math:`\mathcal{N}(0, I)` if ``None``).
        :return: the final sample, or the list of iterates if ``save_iterates`` is set.
        """
        # NOTE: `if seed:` would silently skip seeding for seed=0; check for None instead.
        if seed is not None:
            torch.manual_seed(seed)

        skip = self.num_train_timesteps // self.max_iter
        batch_size = y.shape[0]

        # Pairs (t, t-1) of consecutive sampled timesteps, from T down to 0.
        seq = range(0, self.num_train_timesteps, skip)
        seq_next = [-1] + list(seq[:-1])
        time_pairs = list(zip(reversed(seq), reversed(seq_next)))

        # Initial sample from x_T
        x = torch.randn_like(y) if x_init is None else (2 * x_init - 1)

        if self.save_iterates:
            xs = [x]

        xt = x.to(self.device)
        for i, j in tqdm(time_pairs, disable=(not self.verbose)):
            t = torch.ones(batch_size, dtype=y.dtype, device=self.device) * i
            next_t = torch.ones(batch_size, dtype=y.dtype, device=self.device) * j

            at = self.get_alpha(self.alpha_cumprod, t.long())
            at_next = self.get_alpha(self.alpha_cumprod, next_t.long())

            with torch.enable_grad():
                xt.requires_grad_(True)

                # 1. Denoising step
                aux_x = xt / (2 * at.sqrt()) + 0.5  # renormalize in [0, 1]
                sigma_cur = (1 - at).sqrt() / at.sqrt()  # sigma_t
                x0_t = 2 * self.model(aux_x, sigma_cur / 2) - 1
                x0_t = torch.clip(x0_t, -1.0, 1.0)

                # 2. Likelihood gradient approximation
                l2_loss = self.data_fidelity(x0_t, y, physics).sqrt().sum()

                norm_grad = torch.autograd.grad(outputs=l2_loss, inputs=xt)[0]
                norm_grad = norm_grad.detach()

            sigma_tilde = (
                (1 - at / at_next) * (1 - at_next) / (1 - at)
            ).sqrt() * self.eta
            c2 = ((1 - at_next) - sigma_tilde**2).sqrt()

            # 3. Noise step
            epsilon = torch.randn_like(xt)

            # 4. DDPM(IM) step
            xt_next = (
                (at_next.sqrt() - c2 * at.sqrt() / (1 - at).sqrt()) * x0_t
                + sigma_tilde * epsilon
                + c2 * xt / (1 - at).sqrt()
                - norm_grad
            )
            if self.save_iterates:
                xs.append(xt_next.to("cpu"))
            xt = xt_next.clone()

        if self.save_iterates:
            return xs
        else:
            return xt