BaseSDE
- class deepinv.sampling.BaseSDE(drift, diffusion, solver=None, dtype=torch.float32, device=torch.device('cpu'), *args, **kwargs)
Bases: Module
Base class for a Stochastic Differential Equation (SDE):
\[d x_{t} = f(x_t, t)\, dt + g(t)\, d w_{t}\]
where \(f\) is the drift term, \(g\) is the diffusion coefficient and \(w\) is a standard Brownian motion. The class defines the common interface for the drift and diffusion functions.
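As a worked illustration of what "solving" this SDE means numerically, a common generic scheme is Euler-Maruyama; this is only a sketch, and the scheme actually used depends on the chosen `solver`. On a grid of timesteps \(t_0, t_1, \dots, t_N\) with \(\Delta t_k = t_{k+1} - t_k\):
\[x_{t_{k+1}} = x_{t_k} + f(x_{t_k}, t_k)\, \Delta t_k + g(t_k) \sqrt{|\Delta t_k|}\, z_k, \qquad z_k \sim \mathcal{N}(0, I).\]
The absolute value keeps the square root well defined for decreasing timesteps (\(\Delta t_k < 0\)), as when a reverse-time SDE is integrated from noise towards data.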
- Parameters:
drift (Callable) – a time-dependent drift function \(f(x, t)\)
diffusion (Callable) – a time-dependent diffusion function \(g(t)\)
solver (deepinv.sampling.BaseSDESolver) – the solver used to integrate the SDE.
dtype (torch.dtype) – the data type of the computations.
device (torch.device) – the device for the computations.
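The `drift` and `diffusion` arguments are plain callables with signatures \(f(x, t)\) and \(g(t)\). The sketch below assumes an Ornstein-Uhlenbeck SDE with a hypothetical constant rate `BETA`, i.e. \(f(x, t) = -\tfrac{\beta}{2} x\) and \(g(t) = \sqrt{\beta}\). Following the class signature above, such functions could be passed as `BaseSDE(drift=drift, diffusion=diffusion)`; that call is not exercised here so the example runs with plain PyTorch.

```python
import math

import torch

BETA = 1.0  # hypothetical constant noise rate, chosen for illustration only


def drift(x: torch.Tensor, t: float) -> torch.Tensor:
    # f(x, t) = -beta / 2 * x: pulls the state towards zero (Ornstein-Uhlenbeck drift)
    return -0.5 * BETA * x


def diffusion(t: float) -> float:
    # g(t) = sqrt(beta): constant diffusion coefficient, independent of x
    return math.sqrt(BETA)


x = torch.ones(2, 3, 8, 8)  # a batch of shape (B, C, H, W)
f = drift(x, 0.5)
g = diffusion(0.5)
print(f.shape, g)  # torch.Size([2, 3, 8, 8]) 1.0
```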
- discretize(x, t, *args, **kwargs)
Discretize the SDE at the given time step.
- Parameters:
x (torch.Tensor) – current state.
t (float) – discretized time step.
*args – additional arguments for the drift.
**kwargs – additional keyword arguments for the drift.
- Returns:
discretized drift and diffusion.
- Return type:
Tuple[torch.Tensor, torch.Tensor]
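Per the description above, `discretize` returns the drift and diffusion evaluated at the current state and time step. The following is a minimal stand-in with that contract, not deepinv's implementation: the drift \(f(x, t) = -x/2\), the constant diffusion \(g(t) = 1\) and the `euler_maruyama_step` helper are all hypothetical, the latter showing how a solver might consume the returned pair.

```python
import math

import torch


def drift(x, t, *args, **kwargs):
    # Hypothetical drift f(x, t) = -x / 2; extra arguments are accepted and ignored
    return -0.5 * x


def diffusion(t):
    # Hypothetical constant diffusion g(t) = 1
    return 1.0


def discretize(x, t, *args, **kwargs):
    # Evaluate the drift (forwarding extra arguments to it) and the diffusion at
    # (x, t); the diffusion is cast to a tensor to match Tuple[Tensor, Tensor]
    f = drift(x, t, *args, **kwargs)
    g = torch.as_tensor(diffusion(t), dtype=x.dtype, device=x.device)
    return f, g


def euler_maruyama_step(x, t, dt, generator=None):
    # Hypothetical helper: one Euler-Maruyama update built from discretize()
    f, g = discretize(x, t)
    noise = torch.randn(x.shape, generator=generator, dtype=x.dtype, device=x.device)
    return x + f * dt + g * math.sqrt(abs(dt)) * noise


gen = torch.Generator().manual_seed(0)
x = torch.zeros(1, 1, 4, 4)
f, g = discretize(x, 0.0)
x_next = euler_maruyama_step(x, t=0.0, dt=0.01, generator=gen)
```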
- sample(x_init=None, seed=None, get_trajectory=False, *args, **kwargs)
Solve the SDE with the given timesteps.
- Parameters:
x_init (torch.Tensor) – initial value.
seed (int) – the seed for the pseudo-random number generator used in the solver.
get_trajectory (bool) – whether to return the full trajectory of the SDE or only the last sample, optional. Defaults to False.
*args – additional arguments for the solver.
**kwargs – additional keyword arguments for the solver.
- Returns:
the generated sample (torch.Tensor of shape (B, C, H, W)) if get_trajectory is False. Otherwise, a tuple (torch.Tensor, torch.Tensor) of shapes (B, C, H, W) and (N, B, C, H, W), where N is the number of steps.
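The seeding and `get_trajectory` behavior described above can be mimicked with a self-contained Euler-Maruyama loop. This is a sketch under assumptions: the explicit `timesteps` argument and the fixed Ornstein-Uhlenbeck drift and diffusion are hypothetical, and deepinv's own `sample` delegates the integration to its `solver` rather than to this loop.

```python
import math

import torch


def drift(x, t):
    # Hypothetical Ornstein-Uhlenbeck drift f(x, t) = -x / 2
    return -0.5 * x


def diffusion(t):
    # Hypothetical constant diffusion coefficient g(t) = 1
    return 1.0


def sample(x_init, timesteps, seed=None, get_trajectory=False):
    """Euler-Maruyama sketch; `timesteps` is a hypothetical explicit argument."""
    # A dedicated generator makes the run reproducible when a seed is given;
    # otherwise fall back to PyTorch's global RNG.
    generator = None
    if seed is not None:
        generator = torch.Generator(device=x_init.device).manual_seed(seed)

    x = x_init
    trajectory = []
    for t_cur, t_next in zip(timesteps[:-1], timesteps[1:]):
        t_cur, dt = float(t_cur), float(t_next - t_cur)
        noise = torch.randn(x.shape, generator=generator, dtype=x.dtype, device=x.device)
        # x_{k+1} = x_k + f(x_k, t_k) dt + g(t_k) sqrt(|dt|) z_k
        x = x + drift(x, t_cur) * dt + diffusion(t_cur) * math.sqrt(abs(dt)) * noise
        trajectory.append(x)

    if get_trajectory:
        # Final sample (B, C, H, W) and stacked states (N, B, C, H, W)
        return x, torch.stack(trajectory)
    return x


x0 = torch.randn(2, 3, 8, 8)
ts = torch.linspace(0.0, 1.0, 11)  # 11 time points -> N = 10 steps
x, traj = sample(x0, ts, seed=0, get_trajectory=True)
print(x.shape, traj.shape)  # torch.Size([2, 3, 8, 8]) torch.Size([10, 2, 3, 8, 8])
```

Calling the sketch twice with the same `seed` and `x_init` yields identical samples, and the trajectory stores one state per step, so its leading dimension is N = `len(timesteps) - 1`.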
Examples using BaseSDE:

Posterior Sampling for Inverse Problems with Stochastic Differential Equations modeling.