HeunSolver
- class deepinv.sampling.HeunSolver(timesteps, rng=None)
Bases: BaseSDESolver
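Heun solver for SDEs.
This solver uses the second-order Heun (predictor-corrector) method to numerically integrate SDEs of the form \(dx_t = f(x_t,t)\,dt + g(t)\,dW_t\). Each step from \(t\) to \(t+dt\) is:
\[\begin{aligned}\tilde{x}_{t+dt} &= x_t + f(x_t,t)\,dt + g(t)\,W_{dt} \\ x_{t+dt} &= x_t + \tfrac{1}{2}\left[f(x_t,t) + f(\tilde{x}_{t+dt},t+dt)\right]dt + \tfrac{1}{2}\left[g(t) + g(t+dt)\right]W_{dt}\end{aligned}\]
where \(W_{dt}\) is a Gaussian random variable with mean 0 and variance \(dt\) (\(|dt|\) when integrating backward in time). The same draw of \(W_{dt}\) is used in both the predictor and the corrector.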
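The two-stage update above can be sketched as a single step in plain PyTorch. This is a minimal illustration, not the deepinv implementation: `heun_sde_step` is a hypothetical helper, the drift `f(x, t)` and scalar diffusion `g(t)` are assumed to be plain callables, and the caller is assumed to supply one pre-drawn increment `noise` (a sample of \(W_{dt}\)) that both stages reuse.

```python
import torch

def heun_sde_step(f, g, x, t, dt, noise):
    """One stochastic Heun step (hypothetical helper, not the deepinv API).

    f(x, t): drift, g(t): scalar diffusion coefficient,
    noise:   one draw of W_dt ~ N(0, |dt|), shared by both stages.
    """
    # Predictor: an Euler-Maruyama step gives the provisional state x_tilde
    x_tilde = x + f(x, t) * dt + g(t) * noise
    # Corrector: average drift and diffusion at both ends of the step,
    # reusing the same noise increment as the predictor
    drift_avg = 0.5 * (f(x, t) + f(x_tilde, t + dt))
    diff_avg = 0.5 * (g(t) + g(t + dt))
    return x + drift_avg * dt + diff_avg * noise

# Deterministic check (g = 0, f(x, t) = -x, dt = 0.1, x = 1):
# x_tilde = 0.9, so x_next = 1 + 0.5 * (-1 - 0.9) * 0.1 = 0.905
x_next = heun_sde_step(lambda x, t: -x, lambda t: 0.0,
                       torch.tensor([1.0]), 0.0, 0.1, torch.zeros(1))
```

With \(g = 0\) the step reduces to the classical Heun (explicit trapezoidal) rule for ODEs, which is why the deterministic check above matches the hand computation.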
- Parameters:
timesteps (torch.Tensor) – The time steps at which to evaluate the solution.
rng (torch.Generator, optional) – A random number generator used to draw the Gaussian increments, for reproducibility. Defaults to None.
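To show how `timesteps` and `rng` interact, here is a sketch of a full integration loop that mirrors these two parameters. It is an assumption-laden illustration, not the deepinv API: `heun_sde_solve` is a hypothetical function, and `f(x, t)` and `g(t)` are assumed to be plain callables. One increment \(W_{dt} \sim \mathcal{N}(0, |dt|)\) is drawn per step from `rng` and shared by the predictor and corrector.

```python
import torch

def heun_sde_solve(f, g, x0, timesteps, rng=None):
    """Integrate dx = f(x, t) dt + g(t) dW over `timesteps` with stochastic Heun.

    Hypothetical helper mirroring the `timesteps` / `rng` parameters; not deepinv code.
    """
    x = x0
    for t, t_next in zip(timesteps[:-1], timesteps[1:]):
        dt = (t_next - t).item()  # may be negative when integrating backward in time
        # One shared increment W_dt ~ N(0, |dt|), drawn from the seeded generator
        noise = torch.randn(x.shape, generator=rng, dtype=x.dtype,
                            device=x.device) * abs(dt) ** 0.5
        x_tilde = x + f(x, t) * dt + g(t) * noise  # predictor
        x = (x + 0.5 * (f(x, t) + f(x_tilde, t_next)) * dt
             + 0.5 * (g(t) + g(t_next)) * noise)   # corrector
    return x

timesteps = torch.linspace(0.0, 1.0, 101)
f = lambda x, t: -x   # linear drift
g = lambda t: 0.5     # constant diffusion

# Two runs seeded identically produce identical trajectories
a = heun_sde_solve(f, g, torch.ones(3), timesteps, rng=torch.Generator().manual_seed(0))
b = heun_sde_solve(f, g, torch.ones(3), timesteps, rng=torch.Generator().manual_seed(0))
```

Passing a freshly seeded `torch.Generator` makes the sequence of noise draws, and therefore the whole trajectory, reproducible; leaving `rng=None` falls back to PyTorch's global random state.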