import torch
import torch.nn as nn
from torch import Tensor
import warnings
from typing import Optional, Union, Any
from numpy import ndarray
from tqdm import tqdm
[docs]
class SDEOutput(dict):
r"""
A container for storing the output of an SDE solver, that behaves like a `dict` but allows access with the attribute syntax.
Attributes:
:attr torch.Tensor sample: the final samples of the sampling process, of shape ``(B, C, H, W)``.
:attr torch.Tensor trajectory: the trajectory of the sampling process, of shape ``(num_steps, B, C, H, W)`` if ``full_trajectory`` is ``True``, otherwise of shape ``(B, C, H, W)``.
:attr torch.Tensor timesteps: the time steps at which the samples were taken, of shape ``(num_steps,)``.
:attr int nfe: the number of function evaluations performed during the integration.
"""
def __init__(self, sample: Tensor, trajectory: Tensor, timesteps: Tensor, nfe: int):
sol = {
"sample": sample,
"trajectory": trajectory,
"timesteps": timesteps,
"nfe": nfe,
}
super().__init__(sol)
def __getattr__(self, name: str) -> Any:
try:
return self[name]
except KeyError:
raise AttributeError(name)
def __setattr__(self, name: str, value: Any) -> None:
self[name] = value
def __delattr__(self, name: str) -> None:
del self[name]
[docs]
class BaseSDESolver(nn.Module):
r"""
Base class for solving Stochastic Differential Equations (SDEs) from :class:`deepinv.sampling.BaseSDE` of the form:
.. math::
d x_{t} = f(x_t, t) dt + g(t) d w_{t}
where :math:`f` is the drift term, :math:`g` is the diffusion coefficient, and :math:`w_t` is a standard Brownian process.
Currently only supported for fixed time steps for numerical integration.
:param torch.Tensor, numpy.ndarray, list timesteps: time steps at which the SDE will be discretized.e.
:param torch.Generator rng: a random number generator for reproducibility, optional.
:param bool verbose: whether to display a progress bar during the sampling process, optional. Default to False.
"""
def __init__(
self,
timesteps: Union[Tensor, ndarray],
rng: Optional[torch.Generator] = None,
):
super().__init__()
if isinstance(timesteps, ndarray):
self.timesteps = torch.from_numpy(timesteps.copy())
elif isinstance(timesteps, Tensor):
self.timesteps = timesteps
self.rng = rng
if rng is not None:
self.initial_random_state = rng.get_state()
self.timesteps = self.timesteps.to(rng.device)
[docs]
def step(self, sde, t0: float, t1: float, x0: Tensor, *args, **kwargs) -> Tensor:
r"""
Perform a single step with step size from time `t0` to time `t1`, with current state `x0`.
:param deepinv.sampling.BaseSDE sde: the SDE to solve.
:param float or torch.Tensor t0: Time at the start of the step, of size (,).
:param float or torch.Tensor t1: Time at the end of the step, of size (,).
:param torch.Tensor x0: Current state of the system, of size (batch_size, d).
:return: Updated state of the system after the step.
:rtype: torch.Tensor
"""
raise NotImplementedError
[docs]
@torch.no_grad()
def sample(
self,
sde,
x_init: Tensor,
seed: int = None,
*args,
timesteps: Union[Tensor, ndarray] = None,
get_trajectory: bool = False,
verbose: bool = False,
**kwargs,
) -> SDEOutput:
r"""
Solve the Stochastic Differential Equation (SDE) with given time steps.
This function iteratively applies the SDE solver step for each time interval
defined by the provided timesteps.
:param deepinv.sampling.BaseSDE sde: the SDE to solve.
:param torch.Tensor x_init: The initial state of the system.
:param int seed: The seed for the random number generator, if `rng` is provided.
:param torch.Tensor, numpy.ndarray, list timesteps: A sequence of time points at which to solve the SDE. If None, default timesteps will be used.
:param bool get_trajectory: whether to return the full trajectory of the SDE or only the last sample, optional. Default to False.
:param bool verbose: whether to display a progress bar during the sampling process, optional. Default to False.
:param \*args: Variable length argument list to be passed to the step function.
:param \*\*kwargs: Arbitrary keyword arguments to be passed to the step function.
:return: SDEOutput
"""
self.rng_manual_seed(seed)
x = x_init
nfe = 0
trajectory = [x_init.clone()] if get_trajectory else []
if timesteps is None:
timesteps = self.timesteps.to(sde.device, sde.dtype)
else:
if isinstance(timesteps, ndarray):
timesteps = torch.from_numpy(timesteps.copy())
timesteps = timesteps.to(sde.device, sde.dtype)
for t_cur, t_next in tqdm(
zip(timesteps[:-1], timesteps[1:], strict=True),
total=len(timesteps) - 1,
disable=not verbose,
):
x, cur_nfe = self.step(sde, t_cur, t_next, x, *args, **kwargs)
nfe += cur_nfe
if get_trajectory:
trajectory.append(x.clone())
if get_trajectory:
trajectory = torch.stack(trajectory, dim=0)
else:
trajectory = x
output = SDEOutput(
sample=x, trajectory=trajectory, timesteps=timesteps, nfe=nfe
)
return output
[docs]
def rng_manual_seed(self, seed: int = None):
r"""
Sets the seed for the random number generator.
:param int seed: the seed to set for the random number generator. If not provided, the current state of the random number generator is used.
Note: it will be ignored if the random number generator is not initialized.
"""
if seed is not None:
if self.rng is not None:
self.rng = self.rng.manual_seed(seed)
else:
warnings.warn(
"Cannot set seed for random number generator because it is not initialized. The `seed` parameter is ignored."
)
[docs]
def reset_rng(self):
r"""
Reset the random number generator to its initial state.
"""
self.rng.set_state(self.initial_random_state)
[docs]
def randn_like(self, input: torch.Tensor, seed: int = None):
r"""
Equivalent to :func:`torch.randn_like` but supports a pseudorandom number generator argument.
:param torch.Tensor input: The input tensor whose size will be used.
:param int seed: The seed for the random number generator, if `rng` is provided.
:return: A tensor of the same size as input filled with random numbers from a normal distribution.
:rtype: torch.Tensor
This method uses the `rng` attribute of the class, which is a pseudo-random number generator
for reproducibility. If a seed is provided, it will be used to set the state of `rng` before
generating the random numbers.
.. note::
The `rng` attribute must be initialized for this method to work properly.
"""
self.rng_manual_seed(seed)
return torch.empty_like(input).normal_(generator=self.rng)
[docs]
class EulerSolver(BaseSDESolver):
r"""
Euler-Maruyama solver for SDEs.
This solver uses the Euler-Maruyama method to numerically integrate SDEs. It is a first-order method that
approximates the solution using the following update rule:
.. math::
x_{t+dt} = x_t + f(x_t,t)dt + g(t) W_{dt}
where :math:`W_t` is a Gaussian random variable with mean 0 and variance dt.
:param torch.Tensor timesteps: The time steps at which to evaluate the solution.
:param torch.Generator rng: A random number generator for reproducibility.
"""
def __init__(self, timesteps, rng: torch.Generator = None):
super().__init__(timesteps, rng=rng)
def step(self, sde, t0, t1, x0: Tensor, *args, **kwargs):
dt = abs(t1 - t0)
dW = self.randn_like(x0) * dt**0.5
drift, diffusion = sde.discretize(x0, t0, *args, **kwargs)
return x0 + drift * dt + diffusion * dW, 1
[docs]
class HeunSolver(BaseSDESolver):
r"""
Heun solver for SDEs.
This solver uses the second-order Heun method to numerically integrate SDEs, defined as:
.. math::
\tilde{x}_{t+dt} &= x_t + f(x_t,t)dt + g(t) W_{dt} \\
x_{t+dt} &= x_t + \frac{1}{2}[f(x_t,t) + f(\tilde{x}_{t+dt},t+dt)]dt + \frac{1}{2}[g(t) + g(t+dt)] W_{dt}
where :math:`W_t` is a Gaussian random variable with mean 0 and variance dt.
:param torch.Tensor timesteps: The time steps at which to evaluate the solution.
:param torch.Generator rng: A random number generator for reproducibility.
"""
def __init__(
self,
timesteps,
rng: torch.Generator = None,
):
super().__init__(timesteps, rng=rng)
def step(self, sde, t0, t1, x0: Tensor, *args, **kwargs):
dt = abs(t1 - t0)
dW = self.randn_like(x0) * dt**0.5
drift_0, diffusion_0 = sde.discretize(x0, t0, *args, **kwargs)
x_euler = x0 + drift_0 * dt + diffusion_0 * dW
drift_1, diffusion_1 = sde.discretize(x_euler, t1, *args, **kwargs)
return (
x0
+ 0.5 * (drift_0 + drift_1) * dt
+ 0.5 * (diffusion_0 + diffusion_1) * dW,
2,
)