PosteriorDiffusion
- class deepinv.sampling.PosteriorDiffusion(data_fidelity=None, denoiser=None, sde=None, solver=None, dtype=torch.float64, device=torch.device('cpu'), verbose=False, minus_one_one=True, *args, **kwargs)
- Bases: Reconstructor
- Posterior distribution sampling for inverse problems using diffusion models, based on a reverse-time stochastic differential equation (SDE).
- Consider the acquisition model:
\[y = N(A(x)),\]
- where \(A\) is the forward operator and \(N\) is the noise model. This class defines the reverse-time SDE for the posterior distribution \(p(x|y)\) given the data \(y\):
\[d\, x_t = \left( f(x_t, t) - \frac{1 + \alpha}{2} g(t)^2 \nabla_{x_t} \log p_t(x_t | y) \right) d\,t + g(t) \sqrt{\alpha}\, d\, w_{t},\]
- where \(f\) is the drift term, \(g\) is the diffusion coefficient, \(w\) is a standard Brownian motion, and \(\alpha \geq 0\) weights the stochastic part (\(\alpha = 0\) gives the probability-flow ODE, \(\alpha = 1\) the standard reverse-time SDE). The drift term and the diffusion coefficient are defined by the underlying (unconditional) forward-time SDE sde. By Bayes' rule, the (conditional) score function \(\nabla_{x_t} \log p_t(x_t | y)\) decomposes as
\[\nabla_{x_t} \log p_t(x_t | y) = \nabla_{x_t} \log p_t(x_t) + \nabla_{x_t} \log p_t(y | x_t).\]
- The first term is the score function of the unconditional SDE, which is typically approximated by an MMSE denoiser via Tweedie's formula, while the second term is approximated by the (noisy) data-fidelity term. Various data-fidelity terms are implemented in deepinv.sampling.NoisyDataFidelity.
- Parameters:
- data_fidelity (deepinv.sampling.NoisyDataFidelity) – the noisy data-fidelity term, used to approximate the score \(\nabla_{x_t} \log p_t(y \vert x_t)\). Defaults to deepinv.optim.ZeroFidelity, which corresponds to a zero data-fidelity term; in that case, the sampling process reduces to unconditional SDE sampling.
- denoiser (deepinv.models.Denoiser) – a denoiser used to approximate the (unconditional) score \(\nabla_{x_t} \log p_t(x_t)\) at time \(t\).
- sde (deepinv.sampling.DiffusionSDE) – the forward-time SDE, which defines the drift and diffusion terms of the reverse-time SDE. 
- solver (deepinv.sampling.BaseSDESolver) – the solver for the SDE. If not specified, the solver of sde is used.
- dtype (torch.dtype) – the data type of the sampling solver. The denoiser is the exception: it always uses torch.float32. We recommend torch.float64 for better stability and lower numerical error when solving the SDE in discrete time. Since most of the computational cost comes from evaluating the denoiser, which always runs in torch.float32, the higher precision of the solver adds little overhead.
- device (torch.device) – the device for the computations. 
- verbose (bool) – whether to display a progress bar during the sampling process, optional. Defaults to False.
- minus_one_one (bool) – if True, wrap the denoiser so that SDE states x in [-1, 1] are converted to [0, 1] before denoising and mapped back afterward. Set True for denoisers trained on [0, 1] (all denoisers in deepinv.models.Denoiser); set False only if the denoiser natively expects [-1, 1]. This option only affects the denoiser interface, and matching it to the denoiser's training range usually improves quality. Defaults to True.
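To make the minus_one_one conversion concrete, here is a minimal pure-Python sketch of such a range-conversion wrapper. It is not the deepinv implementation: the names wrap_minus_one_one and gaussian_mmse_denoiser, and the call convention denoiser(x, sigma) on plain lists of floats, are hypothetical. It also assumes that the wrapper rescales the noise level: with \(x = 2u - 1\), noise of standard deviation \(\sigma\) on \(x\) corresponds to \(\sigma / 2\) on \(u\).

```python
from typing import Callable, List

# Hypothetical denoiser signature: denoiser(x, sigma) -> denoised x (lists of floats).
Denoiser = Callable[[List[float], float], List[float]]


def wrap_minus_one_one(denoiser: Denoiser) -> Denoiser:
    """Expose a denoiser trained on [0, 1] to SDE states living in [-1, 1]."""

    def wrapped(x: List[float], sigma: float) -> List[float]:
        u = [(xi + 1.0) / 2.0 for xi in x]  # map [-1, 1] -> [0, 1]
        # Assumption: the noise level is rescaled too, since u = (x + 1) / 2 halves it.
        u_hat = denoiser(u, sigma / 2.0)
        return [2.0 * ui - 1.0 for ui in u_hat]  # map back [0, 1] -> [-1, 1]

    return wrapped


def gaussian_mmse_denoiser(mean: float, var: float) -> Denoiser:
    """Toy exact MMSE denoiser for a scalar Gaussian prior N(mean, var) (illustration only)."""

    def denoise(x: List[float], sigma: float) -> List[float]:
        gain = var / (var + sigma**2)
        return [mean + gain * (xi - mean) for xi in x]

    return denoise


# A prior N(0.5, 0.04) on [0, 1] becomes N(0, 0.16) on [-1, 1], so the wrapped
# denoiser should match the exact MMSE denoiser of that induced prior.
d11 = wrap_minus_one_one(gaussian_mmse_denoiser(0.5, 0.04))
reference = gaussian_mmse_denoiser(0.0, 0.16)
print(d11([0.3, -0.2], 0.4), reference([0.3, -0.2], 0.4))
```

Dropping the sigma / 2 rescaling in this toy setting would make the wrapped denoiser disagree with the reference, which is why the sketch treats the noise level as part of the range conversion.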
 
 - forward(y, physics, x_init=None, seed=None, timesteps=None, get_trajectory=False, *args, **kwargs)
- Sample the posterior distribution \(p(x|y)\) given the measurement \(y\).
- Parameters:
- y (torch.Tensor) – the data measurement. 
- physics (deepinv.physics.Physics) – the forward operator. 
- x_init (torch.Tensor, tuple) – the initial value for the sampling. It can be a torch.Tensor, or a tuple (B, C, H, W) giving the shape of the initial point, which must match the shapes expected by physics and y. In the tuple case, the initial value is drawn randomly from the end-point distribution of the sde.
- seed (int) – the random seed. 
- timesteps (torch.Tensor) – the time steps for the solver. If None, the default time steps of the solver are used.
- get_trajectory (bool) – whether to return the full trajectory of the SDE or only the last sample, optional. Defaults to False.
- *args – the additional arguments for the solver. 
- **kwargs – the additional keyword arguments for the solver. 
 
- Returns:
- the generated sample (torch.Tensor of shape (B, C, H, W)) if get_trajectory is False. Otherwise, a tuple (torch.Tensor, torch.Tensor) of shapes (B, C, H, W) and (N, B, C, H, W), where N is the number of steps.
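The loop behind forward can be pictured with a minimal pure-Python Euler-Maruyama sketch of the reverse-time SDE in the class description. It rests on assumptions that are not part of deepinv: a scalar toy problem (prior \(x \sim \mathcal{N}(0, 1)\), measurement \(y = a x + \sigma_y e\)) whose conditional score is known in closed form, and a variance-exploding forward SDE with \(f = 0\) and \(\sigma(t) = \sigma_{\min} (\sigma_{\max}/\sigma_{\min})^t\). All names (sample_posterior, conditional_score, posterior_stats, ...) are hypothetical; the actual class uses the denoiser and the data-fidelity term instead of an exact score, and its solver and trajectory conventions may differ.

```python
import math
import random
from typing import List, Optional, Tuple

# Toy problem (assumption): prior x ~ N(0, 1), measurement y = A * x + SIGMA_Y * e.
A, SIGMA_Y = 2.0, 0.5
# Variance-exploding schedule (assumption): sigma(t) = S_MIN * (S_MAX / S_MIN) ** t, f = 0.
S_MIN, S_MAX = 0.01, 10.0
LOG_RATIO = math.log(S_MAX / S_MIN)


def sigma(t: float) -> float:
    return S_MIN * (S_MAX / S_MIN) ** t


def g2(t: float) -> float:
    # Squared diffusion coefficient g(t)^2 = d sigma(t)^2 / dt for this schedule.
    return 2.0 * LOG_RATIO * sigma(t) ** 2


def posterior_stats(y: float) -> Tuple[float, float]:
    # Mean m and variance v of the exact Gaussian posterior p(x | y).
    v = 1.0 / (1.0 + A**2 / SIGMA_Y**2)
    return v * A * y / SIGMA_Y**2, v


def conditional_score(x: float, t: float, y: float) -> float:
    # Exact grad log p_t(x_t | y), since x_t | y ~ N(m, v + sigma(t)^2) in this toy model.
    m, v = posterior_stats(y)
    return -(x - m) / (v + sigma(t) ** 2)


def sample_posterior(y: float, batch_size: int, num_steps: int = 200, alpha: float = 1.0,
                     seed: Optional[int] = None, get_trajectory: bool = False):
    """Euler-Maruyama steps of dx = (f - (1 + alpha)/2 g^2 score) dt + g sqrt(alpha) dw."""
    rng = random.Random(seed)
    timesteps = [1.0 - k / num_steps for k in range(num_steps + 1)]  # from t = 1 down to 0
    # Initial point drawn from the (approximate) end-point distribution N(0, sigma(1)^2).
    x = [rng.gauss(0.0, sigma(timesteps[0])) for _ in range(batch_size)]
    trajectory: List[List[float]] = [list(x)]
    for t, t_next in zip(timesteps[:-1], timesteps[1:]):
        h = t - t_next  # positive step size: time runs backwards
        x = [
            xi
            + 0.5 * (1.0 + alpha) * g2(t) * conditional_score(xi, t, y) * h
            + math.sqrt(alpha * g2(t) * h) * rng.gauss(0.0, 1.0)
            for xi in x
        ]
        if get_trajectory:
            trajectory.append(list(x))
    return (x, trajectory) if get_trajectory else x
```

For example, sample_posterior(1.0, batch_size=1000, seed=0) should return samples whose empirical mean and variance are close to posterior_stats(1.0) (about 0.47 and 0.059). With get_trajectory=True, this sketch also returns num_steps + 1 stored states, and alpha=0.0 turns the loop into an Euler discretization of the probability-flow ODE.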
 
 - score(y, physics, x, t, *args, **kwargs)
- Approximate the conditional score \(\nabla_{x_t} \log p_t(x_t \vert y)\).
- Parameters:
- y (torch.Tensor) – the data measurement. 
- physics (deepinv.physics.Physics) – the forward operator. 
- x (torch.Tensor) – the current state. 
- t (torch.Tensor, float) – the current time step. 
- *args – additional arguments for the score function of the unconditional SDE. 
- **kwargs – additional keyword arguments for the score function of the unconditional SDE. 
 
- Returns:
- the score function \(\nabla_{x_t} \log p_t(x_t \vert y)\). 
- Return type:
- torch.Tensor
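The Bayes' rule decomposition that score relies on can be checked on a toy problem where every term is available in closed form. The pure-Python sketch below is hypothetical throughout: it assumes a scalar prior \(x_0 \sim \mathcal{N}(0, 1)\), a variance-exploding perturbation \(x_t = x_0 + \sigma z\), and a measurement \(y = a x_0 + \sigma_y e\). The unconditional score comes from an exact MMSE denoiser through Tweedie's formula, \(\nabla_{x_t} \log p_t(x_t) = (\mathbb{E}[x_0 | x_t] - x_t) / \sigma^2\), and the likelihood score is computed exactly, whereas the real class relies on a learned denoiser and a NoisyDataFidelity approximation.

```python
import math


def mmse_denoiser(x_t: float, sigma: float) -> float:
    # Exact E[x_0 | x_t] for the N(0, 1) prior perturbed by N(0, sigma^2) noise.
    return x_t / (1.0 + sigma**2)


def prior_score(x_t: float, sigma: float) -> float:
    # Tweedie's formula (variance-exploding case): (E[x_0 | x_t] - x_t) / sigma^2.
    return (mmse_denoiser(x_t, sigma) - x_t) / sigma**2


def likelihood_score(y: float, x_t: float, sigma: float, a: float, sigma_y: float) -> float:
    # Exact grad_{x_t} log p_t(y | x_t), using y | x_t ~ N(a * mu, a^2 * s2 + sigma_y^2)
    # where mu = E[x_0 | x_t] and s2 = Var[x_0 | x_t].
    mu = mmse_denoiser(x_t, sigma)
    s2 = sigma**2 / (1.0 + sigma**2)
    dmu_dx = 1.0 / (1.0 + sigma**2)
    return a * dmu_dx * (y - a * mu) / (a**2 * s2 + sigma_y**2)


def conditional_score(y: float, x_t: float, sigma: float, a: float, sigma_y: float) -> float:
    # Bayes' rule: grad log p_t(x_t | y) = grad log p_t(x_t) + grad log p_t(y | x_t).
    return prior_score(x_t, sigma) + likelihood_score(y, x_t, sigma, a, sigma_y)


def direct_posterior_score(y: float, x_t: float, sigma: float, a: float, sigma_y: float) -> float:
    # Reference: x_0 | y ~ N(m, v), hence x_t | y ~ N(m, v + sigma^2).
    v = 1.0 / (1.0 + a**2 / sigma_y**2)
    m = v * a * y / sigma_y**2
    return -(x_t - m) / (v + sigma**2)


# Both routes agree, e.g. at y = 1, x_t = 0.3, sigma = 1, a = 2, sigma_y = 0.5 (about 0.1611).
assert math.isclose(conditional_score(1.0, 0.3, 1.0, 2.0, 0.5),
                    direct_posterior_score(1.0, 0.3, 1.0, 2.0, 0.5))
```

The point of the comparison is that the sum of the two separately computed terms equals the score of the exact posterior, which is what justifies approximating each term independently (denoiser for the first, data fidelity for the second).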
 
 
Examples using PosteriorDiffusion:
 
Building your diffusion posterior sampling method using SDEs
 
    