Source code for deepinv.sampling.langevin

import torch.nn as nn
import torch
import numpy as np
import time

import deepinv.optim
from tqdm import tqdm
from deepinv.optim.utils import check_conv
from deepinv.sampling.utils import Welford, projbox, refl_projbox


class MonteCarlo(nn.Module):
    r"""
    Base class for Monte Carlo sampling.

    This class can be used to create new Monte Carlo samplers, by only defining their kernel inside a
    ``torch.nn.Module``:

    ::

        # define custom sampling kernel (possibly a Markov kernel which depends on the previous sample).
        class MyKernel(torch.nn.Module):
            def __init__(self, iterator_params):
                super().__init__()
                self.iterator_params = iterator_params

            def forward(self, x, y, physics, likelihood, prior):
                # run one sampling kernel iteration
                new_x = f(x, y, physics, likelihood, prior, self.iterator_params)
                return new_x

        class MySampler(MonteCarlo):
            def __init__(self, prior, data_fidelity, iterator_params,
                         max_iter=1e3, burnin_ratio=.1, clip=(-1, 2), verbose=True):
                # generate an iterator
                iterator = MyKernel(iterator_params)
                # set the params of the base class
                super().__init__(iterator, prior, data_fidelity, max_iter=max_iter,
                                 burnin_ratio=burnin_ratio, clip=clip, verbose=verbose)

        # create the sampler
        sampler = MySampler(prior, data_fidelity, iterator_params)

        # compute posterior mean and variance of reconstruction of measurement y
        mean, var = sampler(y, physics)

    This class computes the mean and variance of the chain using Welford's algorithm, which avoids storing the
    whole Monte Carlo samples (an illustrative sketch of the update is given after this class).

    :param deepinv.optim.ScorePrior prior: negative log-prior based on a trained or model-based denoiser.
    :param deepinv.optim.DataFidelity data_fidelity: negative log-likelihood function linked with the noise
        distribution in the acquisition physics.
    :param int max_iter: number of Monte Carlo iterations.
    :param int thinning: thins the Monte Carlo samples by an integer :math:`\geq 1` (i.e., keeping one out of
        ``thinning`` samples to compute posterior statistics).
    :param float burnin_ratio: percentage of iterations used for the burn-in period, should be set between 0 and 1.
        The burn-in samples are discarded.
    :param tuple clip: Tuple containing the box-constraints :math:`[a,b]`.
        If ``None``, the algorithm will not project the samples.
    :param float thresh_conv: Threshold for verifying the convergence of the mean and variance estimates.
    :param function_handle g_statistic: The sampler will compute the posterior mean and variance of the function
        ``g_statistic``. By default, it is the identity function (``lambda x: x``), and thus the sampler computes
        the posterior mean and variance.
    :param bool verbose: prints progress of the algorithm.
    """

    def __init__(
        self,
        iterator: torch.nn.Module,
        prior: deepinv.optim.ScorePrior,
        data_fidelity: deepinv.optim.DataFidelity,
        max_iter=1e3,
        burnin_ratio=0.2,
        thinning=10,
        clip=(-1.0, 2.0),
        thresh_conv=1e-3,
        crit_conv="residual",
        save_chain=False,
        g_statistic=lambda x: x,
        verbose=False,
    ):
        super(MonteCarlo, self).__init__()

        self.iterator = iterator
        self.prior = prior
        self.likelihood = data_fidelity
        self.C_set = clip
        self.thinning = thinning
        self.max_iter = int(max_iter)
        self.thresh_conv = thresh_conv
        self.crit_conv = crit_conv
        self.burnin_iter = int(burnin_ratio * max_iter)
        self.verbose = verbose
        self.mean_convergence = False
        self.var_convergence = False
        self.g_function = g_statistic
        self.save_chain = save_chain
        self.chain = []
    def forward(self, y, physics, seed=None, x_init=None):
        r"""
        Runs a Monte Carlo chain to obtain the posterior mean and variance of the reconstruction of the
        measurements y.

        :param torch.Tensor y: Measurements
        :param deepinv.physics.Physics physics: Forward operator associated with the measurements
        :param int seed: Random seed for generating the Monte Carlo samples
        :param torch.Tensor x_init: (optional) initial state of the chain. If ``None``, the chain is initialized
            with ``physics.A_adjoint(y)``.
        :return: (tuple of torch.Tensor) containing the posterior mean and variance.
        """
        with torch.no_grad():
            if seed is not None:
                np.random.seed(seed)
                torch.manual_seed(seed)

            # Algorithm parameters
            if self.C_set:
                C_lower_lim = self.C_set[0]
                C_upper_lim = self.C_set[1]

            # Initialization
            if x_init is None:
                x = physics.A_adjoint(y)
            else:
                x = x_init

            # Monte Carlo loop
            start_time = time.time()
            statistics = Welford(self.g_function(x))

            self.mean_convergence = False
            self.var_convergence = False
            for it in tqdm(range(self.max_iter), disable=(not self.verbose)):
                x = self.iterator(
                    x, y, physics, likelihood=self.likelihood, prior=self.prior
                )

                if self.C_set:
                    x = projbox(x, C_lower_lim, C_upper_lim)

                if it >= self.burnin_iter and (it % self.thinning) == 0:
                    # keep the statistics of the previous kept sample to check convergence at the end of the chain
                    if it >= (self.max_iter - self.thinning):
                        mean_prev = statistics.mean().clone()
                        var_prev = statistics.var().clone()
                    statistics.update(self.g_function(x))

                    if self.save_chain:
                        self.chain.append(x.clone())

            if self.verbose:
                if torch.cuda.is_available():
                    torch.cuda.synchronize()
                end_time = time.time()
                elapsed = end_time - start_time
                print(
                    f"Monte Carlo sampling finished! elapsed time={elapsed:.2f} seconds"
                )

            if (
                check_conv(
                    {"est": (mean_prev,)},
                    {"est": (statistics.mean(),)},
                    it,
                    self.crit_conv,
                    self.thresh_conv,
                    self.verbose,
                )
                and it > 1
            ):
                self.mean_convergence = True

            if (
                check_conv(
                    {"est": (var_prev,)},
                    {"est": (statistics.var(),)},
                    it,
                    self.crit_conv,
                    self.thresh_conv,
                    self.verbose,
                )
                and it > 1
            ):
                self.var_convergence = True

        return statistics.mean(), statistics.var()
    def get_chain(self):
        r"""
        Returns the thinned Monte Carlo samples (after burn-in iterations).
        Requires ``save_chain=True``.
        """
        return self.chain
    def reset(self):
        r"""
        Resets the Markov chain.
        """
        self.chain = []
        self.mean_convergence = False
        self.var_convergence = False
    def mean_has_converged(self):
        r"""
        Returns a boolean indicating if the posterior mean verifies the convergence criteria.
        """
        return self.mean_convergence
    def var_has_converged(self):
        r"""
        Returns a boolean indicating if the posterior variance verifies the convergence criteria.
        """
        return self.var_convergence
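# Illustrative sketch (not part of the module API): the streaming update behind
# Welford's algorithm used by MonteCarlo.forward. The actual implementation is the
# `Welford` class imported from deepinv.sampling.utils; this standalone version only
# shows the recurrence
#     mean_k = mean_{k-1} + (x_k - mean_{k-1}) / k
#     M2_k   = M2_{k-1} + (x_k - mean_{k-1}) * (x_k - mean_k)
# with var_k = M2_k / (k - 1), so posterior statistics are computed without storing
# the whole chain. The class name is hypothetical.
class _WelfordSketch:
    def __init__(self, x):
        self.k = 1
        self.running_mean = x.clone()
        self.m2 = torch.zeros_like(x)

    def update(self, x):
        self.k += 1
        delta = x - self.running_mean
        self.running_mean = self.running_mean + delta / self.k
        self.m2 = self.m2 + delta * (x - self.running_mean)

    def mean(self):
        return self.running_mean

    def var(self):
        return self.m2 / (self.k - 1)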
class ULAIterator(nn.Module):
    r"""
    Single iteration of the Unadjusted Langevin Algorithm.

    :param float step_size: step size :math:`\eta>0` of the algorithm.
    :param float alpha: regularization parameter :math:`\alpha`.
    :param float sigma: noise level used in the plug-and-play prior denoiser.
    """

    def __init__(self, step_size, alpha, sigma):
        super().__init__()
        self.step_size = step_size
        self.alpha = alpha
        self.noise_std = np.sqrt(2 * step_size)
        self.sigma = sigma

    def forward(self, x, y, physics, likelihood, prior):
        noise = torch.randn_like(x) * self.noise_std
        lhood = -likelihood.grad(x, y, physics)
        lprior = -prior.grad(x, self.sigma) * self.alpha
        return x + self.step_size * (lhood + lprior) + noise
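# Self-contained sanity check (illustrative sketch, independent of deepinv): the same
# Langevin update as ULAIterator, written in plain PyTorch for the scalar Gaussian model
# y = x + n with prior x ~ N(0, 1) and noise n ~ N(0, 1). The posterior is N(y/2, 1/2),
# so the empirical mean/variance of the chain should approach these values for a small
# step size. The function name and default values are hypothetical.
def _ula_gaussian_sanity_check(y=1.0, step_size=0.01, n_iter=20000, burnin=2000):
    torch.manual_seed(0)
    y = torch.tensor(y)
    x = torch.zeros(())
    samples = []
    for it in range(n_iter):
        grad_log_lik = y - x  # gradient of log p(y|x) for unit noise variance
        grad_log_prior = -x  # gradient of log p(x) for a standard Gaussian prior
        noise = torch.randn(()) * np.sqrt(2 * step_size)
        x = x + step_size * (grad_log_lik + grad_log_prior) + noise
        if it >= burnin:
            samples.append(x)
    samples = torch.stack(samples)
    return samples.mean(), samples.var()  # approximately (y / 2, 0.5)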
class ULA(MonteCarlo):
    r"""
    Projected Plug-and-Play Unadjusted Langevin Algorithm.

    The algorithm runs the following Markov chain iteration (Algorithm 2 from https://arxiv.org/abs/2103.04715):

    .. math::

        x_{k+1} = \Pi_{[a,b]} \left(x_{k} + \eta \nabla \log p(y|A,x_k)
        + \eta \alpha \nabla \log p(x_{k}) + \sqrt{2\eta}z_{k+1} \right),

    where :math:`x_{k}` is the :math:`k` th sample of the Markov chain,
    :math:`\log p(y|x)` is the log-likelihood function, :math:`\log p(x)` is the log-prior,
    :math:`\eta>0` is the step size, :math:`\alpha>0` controls the amount of regularization,
    :math:`\Pi_{[a,b]}(x)` projects the entries of :math:`x` to the interval :math:`[a,b]`
    and :math:`z\sim \mathcal{N}(0,I)` is a standard Gaussian vector.

    - Projected PnP-ULA assumes that the denoiser is :math:`L`-Lipschitz differentiable
    - For convergence, ULA requires a ``step_size`` smaller than :math:`\frac{1}{L+\|A\|_2^2}`

    :param deepinv.optim.ScorePrior, torch.nn.Module prior: negative log-prior based on a trained or model-based
        denoiser.
    :param deepinv.optim.DataFidelity, torch.nn.Module data_fidelity: negative log-likelihood function linked with
        the noise distribution in the acquisition physics.
    :param float step_size: step size :math:`\eta>0` of the algorithm.
        Tip: use :meth:`deepinv.physics.Physics.compute_norm()` to compute the Lipschitz constant of the forward
        operator.
    :param float sigma: noise level used in the plug-and-play prior denoiser. A larger value of sigma will result
        in a more regularized reconstruction.
    :param float alpha: regularization parameter :math:`\alpha`.
    :param int max_iter: number of Monte Carlo iterations.
    :param int thinning: Thins the Markov chain by an integer :math:`\geq 1` (i.e., keeping one out of ``thinning``
        samples to compute posterior statistics).
    :param float burnin_ratio: percentage of iterations used for the burn-in period, should be set between 0 and 1.
        The burn-in samples are discarded.
    :param tuple clip: Tuple containing the box-constraints :math:`[a,b]`.
        If ``None``, the algorithm will not project the samples.
    :param float thresh_conv: Threshold for verifying the convergence of the mean and variance estimates.
    :param function_handle g_statistic: The sampler will compute the posterior mean and variance of the function
        ``g_statistic``. By default, it is the identity function (``lambda x: x``), and thus the sampler computes
        the posterior mean and variance.
    :param bool verbose: prints progress of the algorithm.
    """

    def __init__(
        self,
        prior,
        data_fidelity,
        step_size=1.0,
        sigma=0.05,
        alpha=1.0,
        max_iter=1e3,
        thinning=5,
        burnin_ratio=0.2,
        clip=(-1.0, 2.0),
        thresh_conv=1e-3,
        save_chain=False,
        g_statistic=lambda x: x,
        verbose=False,
    ):
        iterator = ULAIterator(step_size=step_size, alpha=alpha, sigma=sigma)
        super().__init__(
            iterator,
            prior,
            data_fidelity,
            max_iter=max_iter,
            thresh_conv=thresh_conv,
            g_statistic=g_statistic,
            burnin_ratio=burnin_ratio,
            clip=clip,
            thinning=thinning,
            save_chain=save_chain,
            verbose=verbose,
        )
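# Usage sketch (illustrative, not executed at import): posterior sampling for a denoising
# problem with ULA. The physics, data fidelity and prior constructors follow the deepinv
# API, but their exact signatures may differ between versions, so treat every call below
# as an assumption. `x` is a ground-truth image tensor and `denoiser` a (Lipschitz)
# denoising network; the function name is hypothetical.
def _ula_usage_sketch(x, denoiser, sigma=0.1, sigma_den=2 / 255):
    import deepinv as dinv
    from deepinv.optim.data_fidelity import L2

    physics = dinv.physics.Denoising()  # identity forward operator
    physics.noise_model = dinv.physics.GaussianNoise(sigma)  # Gaussian measurement noise
    y = physics(x)

    likelihood = L2(sigma=sigma)  # negative log-likelihood
    prior = dinv.optim.ScorePrior(denoiser)  # plug-and-play score prior (assumed signature)

    # step size chosen below the 1/(L + ||A||_2^2) bound from the docstring, with the
    # denoiser Lipschitz constant taken as 1/sigma_den**2 and ||A||^2/sigma**2 for the
    # Gaussian likelihood of a denoising operator
    step_size = 0.5 / (1 / sigma**2 + 1 / sigma_den**2)

    sampler = ULA(
        prior,
        likelihood,
        step_size=step_size,
        sigma=sigma_den,
        alpha=0.3,
        max_iter=5000,
        burnin_ratio=0.3,
        clip=(-1, 2),
        verbose=True,
    )
    return sampler(y, physics)  # posterior mean and variance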
class SKRockIterator(nn.Module):
    r"""
    Single iteration of the SK-ROCK algorithm.

    :param float step_size: step size of the algorithm.
    :param float alpha: regularization parameter :math:`\alpha`.
    :param int inner_iter: number of internal SK-ROCK stages per iteration.
    :param float eta: SK-ROCK damping parameter :math:`\eta`.
    :param float sigma: noise level used in the plug-and-play prior denoiser.
    """

    def __init__(self, step_size, alpha, inner_iter, eta, sigma):
        super().__init__()
        self.step_size = step_size
        self.alpha = alpha
        self.eta = eta
        self.inner_iter = inner_iter
        self.noise_std = np.sqrt(2 * step_size)
        self.sigma = sigma

    def forward(self, x, y, physics, likelihood, prior):
        # gradient of the negative log-posterior
        posterior = lambda u: likelihood.grad(u, y, physics) + self.alpha * prior.grad(
            u, self.sigma
        )

        # Chebyshev function of the first kind
        T_s = lambda s, u: np.cosh(s * np.arccosh(u))
        # first derivative of the Chebyshev polynomial of the first kind
        T_prime_s = lambda s, u: s * np.sinh(s * np.arccosh(u)) / np.sqrt(u**2 - 1)

        w0 = 1 + self.eta / (self.inner_iter**2)  # parameter \omega_0
        w1 = T_s(self.inner_iter, w0) / T_prime_s(
            self.inner_iter, w0
        )  # parameter \omega_1
        mu1 = w1 / w0  # parameter \mu_1
        nu1 = self.inner_iter * w1 / 2  # parameter \nu_1
        kappa1 = self.inner_iter * (w1 / w0)  # parameter \kappa_1

        # sampling the variable x
        noise = np.sqrt(2 * self.step_size) * torch.randn_like(x)  # diffusion term

        # first internal iteration (s=1)
        xts_2 = x.clone()
        xts = (
            x.clone()
            - mu1 * self.step_size * posterior(x + nu1 * noise)
            + kappa1 * noise
        )

        for js in range(
            2, self.inner_iter + 1
        ):  # s=2,...,self.inner_iter SK-ROCK internal iterations
            xts_1 = xts.clone()
            mu = 2 * w1 * T_s(js - 1, w0) / T_s(js, w0)  # parameter \mu_js
            nu = 2 * w0 * T_s(js - 1, w0) / T_s(js, w0)  # parameter \nu_js
            kappa = 1 - nu  # parameter \kappa_js
            xts = -mu * self.step_size * posterior(xts) + nu * xts + kappa * xts_2
            xts_2 = xts_1

        return xts  # new sample produced by the SK-ROCK algorithm
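# Small numerical check (illustrative, not part of the module): for u > 1 the hyperbolic
# form cosh(s * arccosh(u)) used by SKRockIterator coincides with the Chebyshev polynomial
# of the first kind T_s(u), e.g. T_2(u) = 2*u**2 - 1. The helper name is hypothetical.
def _check_chebyshev_identity(u=1.5):
    lhs = np.cosh(2 * np.arccosh(u))  # hyperbolic evaluation of T_2 for u > 1
    rhs = 2 * u**2 - 1  # closed-form Chebyshev polynomial T_2
    return np.isclose(lhs, rhs)  # True for any u > 1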
class SKRock(MonteCarlo):
    r"""
    Plug-and-Play SKROCK algorithm.

    Obtains samples of the posterior distribution using an orthogonal Runge-Kutta-Chebyshev stochastic
    approximation to accelerate the standard Unadjusted Langevin Algorithm.

    The algorithm was introduced in "Accelerating proximal Markov chain Monte Carlo by using an explicit
    stabilised method" by L. Vargas, M. Pereyra and K. Zygalakis (https://arxiv.org/abs/1908.08845).

    - SKROCK assumes that the denoiser is :math:`L`-Lipschitz differentiable
    - For convergence, SKROCK requires a ``step_size`` smaller than :math:`\frac{1}{L+\|A\|_2^2}`

    :param deepinv.optim.ScorePrior, torch.nn.Module prior: negative log-prior based on a trained or model-based
        denoiser.
    :param deepinv.optim.DataFidelity, torch.nn.Module data_fidelity: negative log-likelihood function linked with
        the noise distribution in the acquisition physics.
    :param float step_size: Step size of the algorithm.
        Tip: use :meth:`deepinv.physics.Physics.compute_norm()` to compute the Lipschitz constant of the forward
        operator.
    :param float eta: :math:`\eta` SKROCK damping parameter.
    :param float alpha: regularization parameter :math:`\alpha`.
    :param int inner_iter: Number of inner SKROCK iterations.
    :param int max_iter: Number of outer iterations.
    :param int thinning: Thins the Markov chain by an integer :math:`\geq 1` (i.e., keeping one out of ``thinning``
        samples to compute posterior statistics).
    :param float burnin_ratio: percentage of iterations used for the burn-in period. The burn-in samples are
        discarded.
    :param tuple clip: Tuple containing the box-constraints :math:`[a,b]`.
        If ``None``, the algorithm will not project the samples.
    :param bool verbose: prints progress of the algorithm.
    :param float sigma: noise level used in the plug-and-play prior denoiser. A larger value of sigma will result
        in a more regularized reconstruction.
    :param function_handle g_statistic: The sampler will compute the posterior mean and variance of the function
        ``g_statistic``. By default, it is the identity function (``lambda x: x``), and thus the sampler computes
        the posterior mean and variance.
    """

    def __init__(
        self,
        prior: deepinv.optim.ScorePrior,
        data_fidelity,
        step_size=1.0,
        inner_iter=10,
        eta=0.05,
        alpha=1.0,
        max_iter=1e3,
        burnin_ratio=0.2,
        thinning=10,
        clip=(-1.0, 2.0),
        thresh_conv=1e-3,
        save_chain=False,
        g_statistic=lambda x: x,
        verbose=False,
        sigma=0.05,
    ):
        iterator = SKRockIterator(
            step_size=step_size,
            alpha=alpha,
            inner_iter=inner_iter,
            eta=eta,
            sigma=sigma,
        )
        super().__init__(
            iterator,
            prior,
            data_fidelity,
            max_iter=max_iter,
            thresh_conv=thresh_conv,
            thinning=thinning,
            burnin_ratio=burnin_ratio,
            clip=clip,
            g_statistic=g_statistic,
            save_chain=save_chain,
            verbose=verbose,
        )
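# Usage sketch (illustrative, not executed at import): SKRock exposes the same interface
# as ULA, so the prior, data fidelity and physics objects from the ULA sketch above can
# be reused; only the step size and the number of inner Chebyshev stages change. The
# function name and parameter values are indicative assumptions, not tuned settings.
def _skrock_usage_sketch(y, physics, prior, likelihood, sigma=0.1):
    sampler = SKRock(
        prior,
        likelihood,
        step_size=0.1 * sigma**2,  # SK-ROCK typically allows a larger step size than plain ULA
        inner_iter=10,  # internal Runge-Kutta-Chebyshev stages per iteration
        alpha=0.9,
        max_iter=1000,
        burnin_ratio=0.3,
        clip=(-1, 2),
        verbose=True,
    )
    return sampler(y, physics)  # posterior mean and variance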