BaseSDE

class deepinv.sampling.BaseSDE(drift, diffusion, solver=None, dtype=torch.float32, device=torch.device('cpu'), *args, **kwargs)

Bases: Module

Base class for Stochastic Differential Equations (SDEs) of the form:

\[d x_{t} = f(x_t, t) dt + g(t) d w_{t}\]

where \(f\) is the drift term, \(g\) is the diffusion coefficient and \(w\) is the standard Brownian motion. It defines the common interface for drift and diffusion functions.
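A common way to turn the continuous SDE into an iterative update is the Euler–Maruyama scheme. It is shown here only as an illustration of how \(f\) and \(g\) enter a discretization. The solvers in deepinv may use this or a different scheme. The Brownian increment over a step of length \(\Delta t_k\) is replaced by \(\sqrt{|\Delta t_k|}\, z_k\) with standard Gaussian noise \(z_k\):

```latex
x_{k+1} = x_k + f(x_k, t_k)\,\Delta t_k + g(t_k)\,\sqrt{|\Delta t_k|}\; z_k,
\qquad z_k \sim \mathcal{N}(0, I),
\qquad \Delta t_k = t_{k+1} - t_k .
```

The absolute value keeps the square root well defined when the time grid runs backwards (\(\Delta t_k < 0\)), as in reverse-time sampling.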

Parameters:
  • drift (Callable) – a time-dependent drift function \(f(x, t)\)

  • diffusion (Callable) – a time-dependent diffusion function \(g(t)\)

  • solver (deepinv.sampling.BaseSDESolver) – the solver used to integrate the SDE.

  • dtype (torch.dtype) – the data type of the computations.

  • device (torch.device) – the device for the computations.
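The drift and diffusion arguments are plain callables of the forms \(f(x, t)\) and \(g(t)\). The sketch below defines such a pair for an Ornstein–Uhlenbeck process with a hypothetical constant rate BETA. The process, the constant, and the test tensor are illustrative assumptions, not part of deepinv. Only the call shapes follow the parameter descriptions above.

```python
import math
import torch

BETA = 2.0  # hypothetical constant noise rate, chosen for illustration only

def drift(x: torch.Tensor, t: float) -> torch.Tensor:
    # Ornstein-Uhlenbeck drift f(x, t) = -0.5 * beta * x (time-independent here)
    return -0.5 * BETA * x

def diffusion(t: float) -> float:
    # Constant diffusion coefficient g(t) = sqrt(beta)
    return math.sqrt(BETA)

# Evaluate both on a batch of shape (B, C, H, W) with the chosen dtype and device
dtype, device = torch.float32, torch.device("cpu")
x = torch.ones(2, 1, 4, 4, dtype=dtype, device=device)
f = drift(x, 0.5)
g = diffusion(0.5)
```

The drift acts elementwise on the state and keeps its shape, while the diffusion depends on time only, matching the \(g(t)\) term in the SDE.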

discretize(x, t, *args, **kwargs)

Discretize the SDE at the given time step.

Parameters:
  • x (torch.Tensor) – current state.

  • t (float) – time step at which the SDE is discretized.

  • *args – additional arguments for the drift.

  • **kwargs – additional keyword arguments for the drift.

Returns:

the discretized drift and diffusion.

Return type:

Tuple[Tensor, Tensor]
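As an illustration of how a (drift, diffusion) pair can be consumed, the sketch below uses a hypothetical free function `discretize` that simply evaluates \(f(x, t)\) and \(g(t)\). It then plugs the pair into one Euler–Maruyama step. Both `discretize` and `euler_maruyama_step` are stand-ins written for this example. They assume the returned pair is just the drift and diffusion evaluated at the current state and time, which may not match what `BaseSDE.discretize` computes for every subclass.

```python
import math
import torch

def drift(x: torch.Tensor, t: float) -> torch.Tensor:
    # Hypothetical Ornstein-Uhlenbeck drift f(x, t) = -0.5 * x
    return -0.5 * x

def diffusion(t: float) -> float:
    # Hypothetical constant diffusion coefficient g(t) = 1
    return 1.0

def discretize(x: torch.Tensor, t: float):
    # Hypothetical stand-in: evaluate drift and diffusion at the current (x, t)
    return drift(x, t), diffusion(t)

def euler_maruyama_step(x, t, dt, generator=None):
    # One Euler-Maruyama update built from the discretized pair (f, g)
    f, g = discretize(x, t)
    noise = torch.randn(x.shape, generator=generator, dtype=x.dtype, device=x.device)
    return x + f * dt + g * math.sqrt(abs(dt)) * noise

x = torch.ones(1, 1, 2, 2)
x_next = euler_maruyama_step(x, t=0.0, dt=0.01, generator=torch.Generator().manual_seed(0))
```

Passing an explicit `torch.Generator` makes the noise draw, and therefore the step, reproducible.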

sample(x_init=None, seed=None, get_trajectory=False, *args, **kwargs)

Solve the SDE with the given timesteps.

Parameters:
  • x_init (torch.Tensor) – initial value.

  • seed (int) – the seed for the pseudo-random number generator used in the solver.

  • get_trajectory (bool) – whether to return the full trajectory of the SDE or only the last sample, optional. Defaults to False.

  • *args – additional arguments for the solver.

  • **kwargs – additional keyword arguments for the solver.

Returns:

the generated sample (torch.Tensor of shape (B, C, H, W)) if get_trajectory is False. Otherwise, a tuple (torch.Tensor, torch.Tensor) of shapes (B, C, H, W) and (N, B, C, H, W), where N is the number of steps.
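The interplay of `seed`, `x_init`, and `get_trajectory` can be illustrated with a self-contained loop. `sample_sketch` below is a hypothetical helper, not deepinv's implementation. It takes the drift, diffusion, and time grid explicitly, uses Euler–Maruyama as the integrator, and draws a standard Gaussian start when no `x_init` is given. All three choices are assumptions made for this example. It records the state after each of the N steps, so the stacked trajectory has shape (N, B, C, H, W).

```python
import math
from typing import Callable, Optional, Sequence
import torch

def sample_sketch(
    drift: Callable, diffusion: Callable, timesteps: Sequence[float],
    x_init: Optional[torch.Tensor] = None, shape=(1, 1, 8, 8),
    seed: Optional[int] = None, get_trajectory: bool = False,
):
    # Dedicated generator: reproducible when seed is given, random otherwise
    gen = torch.Generator()
    if seed is None:
        gen.seed()  # non-deterministic seed
    else:
        gen.manual_seed(seed)
    # Start from x_init, or from a standard Gaussian draw if none is given (assumption)
    x = torch.randn(shape, generator=gen) if x_init is None else x_init.clone()
    trajectory = []
    for t_cur, t_next in zip(timesteps[:-1], timesteps[1:]):
        dt = t_next - t_cur
        noise = torch.randn(x.shape, generator=gen, dtype=x.dtype)
        x = x + drift(x, t_cur) * dt + diffusion(t_cur) * math.sqrt(abs(dt)) * noise
        trajectory.append(x)
    if get_trajectory:
        # Stack N states of shape (B, C, H, W) into (N, B, C, H, W)
        return x, torch.stack(trajectory)
    return x

ts = torch.linspace(0.0, 1.0, 11).tolist()  # 11 time points -> N = 10 steps
x_last, traj = sample_sketch(lambda x, t: -0.5 * x, lambda t: 1.0, ts,
                             shape=(2, 3, 8, 8), seed=0, get_trajectory=True)
```

With the same seed, the final state returned without the trajectory equals the last entry of the stacked trajectory.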

sample_init(shape, rng=None)

Sample from the end-time distribution of the forward diffusion.

Parameters:

shape (List | Tuple | Size) – The shape of the sample, of the form (B, C, H, W).
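The end-time distribution depends on the concrete SDE. The sketch below assumes it is an isotropic Gaussian \(\mathcal{N}(0, \sigma^2 I)\), as for a variance-exploding SDE with a large final noise level. `sample_init_sketch` and `SIGMA_MAX` are hypothetical names introduced for illustration. The optional `rng` is treated as a `torch.Generator` for reproducibility, which is also an assumption.

```python
from typing import Optional, Sequence
import torch

SIGMA_MAX = 50.0  # hypothetical end-time noise level, chosen for illustration

def sample_init_sketch(shape: Sequence[int], rng: Optional[torch.Generator] = None,
                       sigma: float = SIGMA_MAX) -> torch.Tensor:
    # Assumes the end-time distribution is N(0, sigma^2 I)
    return sigma * torch.randn(tuple(shape), generator=rng)

x_T = sample_init_sketch((4, 3, 32, 32), rng=torch.Generator().manual_seed(0))
```

For a variance-preserving SDE, the same sketch with `sigma=1.0` would correspond to a standard Gaussian end-time distribution.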

Examples using BaseSDE:

Posterior Sampling for Inverse Problems with Stochastic Differential Equations modeling.