ScoreModelWrapper
- class deepinv.models.ScoreModelWrapper(score_model=None, prediction_type='epsilon', clip_output=True, sigma_t=None, scale_t=None, sigma_inverse=None, variance_preserving=False, variance_exploding=False, T=1.0, takes_integer_time=False, n_timesteps=1000, _was_trained_on_minus_one_one=True, device='cpu')
Bases: Denoiser
Wraps a score model as a DeepInv Denoiser.
Given a noisy sample \(x_t = s_t(x_0 + \sigma_t \varepsilon)\), where \(\varepsilon \sim \mathcal{N}(0, I)\), the input score_model is trained to predict one of the following, depending on prediction_type:
- the noise \(\varepsilon\) (prediction_type = 'epsilon'), as is typically the case for DDPM models;
- the denoised sample \(x_0\) (prediction_type = 'sample');
- the v-prediction \(s_t (\varepsilon - \sigma_t \cdot x_0)\), as proposed by [1] (prediction_type = 'v_prediction').
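The three prediction types are tied together by the forward model \(x_t = s_t(x_0 + \sigma_t \varepsilon)\). The plain-PyTorch sketch below shows how each output could be turned into an estimate of \(x_0\); the function name denoised_from_prediction is hypothetical, and this is not necessarily how ScoreModelWrapper performs the conversion internally. For v-prediction, eliminating \(\varepsilon\) from \(x_t\) gives \(x_0 = (x_t - \sigma_t v) / (s_t (1 + \sigma_t^2))\).

```python
import torch


def denoised_from_prediction(pred, x_t, s_t, sigma_t, prediction_type="epsilon"):
    """Hypothetical helper: map a score-model output to an estimate of x_0.

    Assumes x_t = s_t * (x_0 + sigma_t * eps), with s_t and sigma_t already
    broadcastable to x_t (e.g. of shape [B, 1, 1, 1]).
    """
    if prediction_type == "epsilon":
        # x_t / s_t = x_0 + sigma_t * eps  =>  x_0 = x_t / s_t - sigma_t * eps
        return x_t / s_t - sigma_t * pred
    if prediction_type == "sample":
        # The model already predicts x_0.
        return pred
    if prediction_type == "v_prediction":
        # v = s_t * (eps - sigma_t * x_0); eliminating eps from x_t gives
        # x_0 = (x_t - sigma_t * v) / (s_t * (1 + sigma_t**2))
        return (x_t - sigma_t * pred) / (s_t * (1 + sigma_t**2))
    raise ValueError(f"unknown prediction_type: {prediction_type}")


# Sanity check with exact targets: every prediction type should recover x_0.
torch.manual_seed(0)
x0 = torch.randn(2, 3, 8, 8)
eps = torch.randn_like(x0)
sigma = torch.tensor([0.5, 2.0]).view(-1, 1, 1, 1)
s = 1.0 / torch.sqrt(1.0 + sigma**2)  # one possible (variance-preserving) choice of s_t
x_t = s * (x0 + sigma * eps)
v = s * (eps - sigma * x0)

for kind, target in [("epsilon", eps), ("sample", x0), ("v_prediction", v)]:
    x0_hat = denoised_from_prediction(target, x_t, s, sigma, kind)
    print(kind, torch.allclose(x0_hat, x0, atol=1e-5))
```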
- Parameters:
score_model (torch.nn.Module | Callable) – score model to be wrapped.
prediction_type (str) – type of prediction made by the score model: 'epsilon', 'sample' or 'v_prediction'. Default is 'epsilon'.
clip_output (bool) – whether to clip the output to the model range. Default is True.
sigma_t (Callable | torch.Tensor) – continuous function or tensor (of shape [N], with N the number of time steps) defining the noise schedule \(\sigma_t\).
scale_t (Callable | torch.Tensor) – function or tensor (of shape [N], with N the number of time steps) defining the scaling schedule \(s_t\).
sigma_inverse (Callable) – analytic inverse of sigma_t. If not provided, a numeric inversion is used.
variance_preserving (bool) – whether the schedule is variance-preserving. If True, scale_t is computed from sigma_t.
variance_exploding (bool) – whether the schedule is variance-exploding. If True, scale_t is set to 1.
T (float) – maximum time value for continuous schedules. Default is 1.0.
takes_integer_time (bool) – whether the model takes integer time steps (in [0, n_timesteps-1]) as input. Default is False.
n_timesteps (int) – number of time steps for discrete schedules. Default is 1000.
_was_trained_on_minus_one_one (bool) – whether the model was trained on images in the [-1, 1] range (True) or in the [0, 1] range (False). Default is True.
device (str) – device to load the model on. Default is 'cpu'.
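As an illustration of what tensor-valued sigma_t and scale_t might look like, the sketch below rewrites the linear β schedule of the original DDPM paper (β from 1e-4 to 0.02 over 1000 steps) in the \(x_t = s_t(x_0 + \sigma_t \varepsilon)\) form. Since DDPM uses \(x_t = \sqrt{\bar\alpha_t} x_0 + \sqrt{1 - \bar\alpha_t} \varepsilon\), this gives \(s_t = \sqrt{\bar\alpha_t}\) and \(\sigma_t = \sqrt{(1 - \bar\alpha_t)/\bar\alpha_t}\), which satisfy the variance-preserving relation \(s_t = 1/\sqrt{1 + \sigma_t^2}\). The helper name ddpm_schedules is hypothetical.

```python
import torch


def ddpm_schedules(n_timesteps=1000, beta_start=1e-4, beta_end=0.02):
    """Hypothetical helper: sigma_t and s_t tensors (shape [N]) for a DDPM-style
    linear beta schedule, written in the x_t = s_t * (x_0 + sigma_t * eps) form.
    """
    betas = torch.linspace(beta_start, beta_end, n_timesteps, dtype=torch.float64)
    alpha_bar = torch.cumprod(1.0 - betas, dim=0)
    scale_t = alpha_bar.sqrt()                       # s_t = sqrt(alpha_bar_t)
    sigma_t = ((1.0 - alpha_bar) / alpha_bar).sqrt()  # sigma_t = sqrt((1 - a) / a)
    return sigma_t, scale_t


sigma_t, scale_t = ddpm_schedules()
print(sigma_t.shape, sigma_t[0].item(), sigma_t[-1].item())
# Variance-preserving relation: s_t = 1 / sqrt(1 + sigma_t^2).
print(torch.allclose(scale_t, 1.0 / torch.sqrt(1.0 + sigma_t**2)))
```

Under the assumption that the wrapper indexes such tensors by integer time step, they could be passed as sigma_t and scale_t, or sigma_t alone with variance_preserving=True; this should be checked against the library before relying on it.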
- References:
- forward(x, sigma=None, input_in_minus_one_one=False, *args, **kwargs)
Applies the denoiser \(\operatorname{D}_{\sigma}(x)\). If input_in_minus_one_one is False (the default), the input x is expected to be in the [0, 1] range (up to random noise) and the output is also in the [0, 1] range. Otherwise, both input and output are expected to be in the [-1, 1] range.
- Parameters:
x (torch.Tensor) – noisy input, of shape [B, C, H, W].
sigma (torch.Tensor | float) – noise level. Can be a float or a torch.Tensor of shape [B]. If a single float is provided, the same noise level is used for all samples in the batch. Otherwise, batch-wise noise levels are used.
input_in_minus_one_one (bool) – whether the input x is in the [-1, 1] range. Default is False.
args – additional positional arguments to be passed to the model.
kwargs – additional keyword arguments to be passed to the model, for example a prompt for text-conditioned models or a class_label for class-conditioned models.
- Returns:
(torch.Tensor) the denoised output.
- Return type:
torch.Tensor
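The range convention follows from a short calculation: if \(y = x + \sigma n\) with \(x \in [0, 1]\), then \(2y - 1 = (2x - 1) + 2\sigma n\), so moving to [-1, 1] doubles the noise level. The helpers below (all hypothetical names) sketch this bookkeeping together with the float-versus-[B] handling of sigma. It is an illustration under these assumptions, not the wrapper's actual forward.

```python
import torch


def to_minus_one_one(x, sigma):
    """Hypothetical helper: map a noisy [0, 1] image and its noise level to [-1, 1].

    If y = x + sigma * n, then 2 * y - 1 = (2 * x - 1) + (2 * sigma) * n,
    so the noise level doubles under the affine map.
    """
    return 2.0 * x - 1.0, 2.0 * sigma


def to_zero_one(x):
    """Inverse map from [-1, 1] back to [0, 1]."""
    return (x + 1.0) / 2.0


def batch_sigma(sigma, batch_size):
    """Broadcast a float or a [B] tensor of noise levels to shape [B, 1, 1, 1]."""
    sigma = torch.as_tensor(sigma, dtype=torch.float32)
    if sigma.dim() == 0:
        # A single float: use the same noise level for every sample.
        sigma = torch.full((batch_size,), sigma.item())
    return sigma.view(-1, 1, 1, 1)


x = torch.rand(4, 3, 16, 16)
y = x + 0.1 * torch.randn_like(x)  # noisy input, in [0, 1] up to the noise
y_m, sigma_m = to_minus_one_one(y, batch_sigma(0.1, y.shape[0]))
print(y_m.shape, sigma_m.flatten())
```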
- get_schedule_value(schedule, t, target_size=None)
Get the value of a schedule (function or tensor) at given time steps.
- Parameters:
schedule (Callable | torch.Tensor) – schedule function or tensor.
t (torch.Tensor) – time steps, of shape [B] or [].
target_size (torch.Size) – target size to broadcast the output to. Default is None.
- Returns:
(torch.Tensor) schedule values at time steps t; if target_size is provided, their shape is broadcastable to target_size.
- Return type:
torch.Tensor
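A minimal sketch of such a lookup, assuming that tensor schedules are indexed by integer time steps and that callable schedules are evaluated directly on t. The name schedule_value is hypothetical, and the way the real method maps continuous t to table indices may differ.

```python
import torch


def schedule_value(schedule, t, target_size=None):
    """Hypothetical sketch: evaluate a schedule (callable or tensor) at time steps t."""
    t = torch.as_tensor(t)
    if callable(schedule):
        value = torch.as_tensor(schedule(t))
    else:
        value = schedule[t.long()]  # integer indexing; works for t of shape [] or [B]
    if target_size is not None:
        # Append singleton dims so that a [B] vector broadcasts against [B, C, H, W].
        extra = len(target_size) - value.dim()
        value = value.reshape(*value.shape, *([1] * extra))
    return value


table = torch.linspace(0.0, 1.0, 11)  # a tabulated schedule of shape [N]
v = schedule_value(table, torch.tensor([2, 5]), target_size=torch.Size([2, 3, 8, 8]))
print(v.shape)
```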
- score(x, t=None, *args, **kwargs)
Computes the score function \(\nabla_x \log p_t(x)\).
- Parameters:
x (torch.Tensor) – input tensor of shape [B, C, H, W].
t (torch.Tensor | float) – single time step or tensor of shape [B] or [].
args – additional positional arguments of the model.
kwargs – additional keyword arguments of the model.
- Returns:
(torch.Tensor) the score function, of shape [B, C, H, W].
- Return type:
torch.Tensor
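For the forward model above, \(x_t \mid x_0 \sim \mathcal{N}(s_t x_0, s_t^2 \sigma_t^2 I)\), so Tweedie's formula relates the score to a denoised estimate \(\hat{x}_0\): \(\nabla_x \log p_t(x) = (s_t \hat{x}_0(x) - x)/(s_t^2 \sigma_t^2) = -\hat{\varepsilon}(x)/(s_t \sigma_t)\). The sketch below checks this identity on Gaussian data, where both the score and the MMSE denoiser are known in closed form. Whether score uses exactly this expression internally is an assumption, and score_from_denoised is a hypothetical name.

```python
import torch


def score_from_denoised(x_t, x0_hat, s_t, sigma_t):
    """Tweedie's formula for x_t = s_t * (x_0 + sigma_t * eps):
    grad log p_t(x_t) = (s_t * E[x_0 | x_t] - x_t) / (s_t**2 * sigma_t**2)."""
    return (s_t * x0_hat - x_t) / (s_t**2 * sigma_t**2)


# Closed-form check with x_0 ~ N(0, I): then x_t ~ N(0, s_t^2 (1 + sigma_t^2) I),
# whose exact score is -x_t / (s_t^2 (1 + sigma_t^2)), and the exact MMSE
# denoiser is E[x_0 | x_t] = x_t / (s_t (1 + sigma_t^2)).
s_t, sigma_t = 0.8, 1.5
x_t = torch.randn(2, 3, 8, 8)
x0_hat = x_t / (s_t * (1 + sigma_t**2))
exact = -x_t / (s_t**2 * (1 + sigma_t**2))
print(torch.allclose(score_from_denoised(x_t, x0_hat, s_t, sigma_t), exact, atol=1e-5))
```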
- time_from_sigma(sigma)
Computes the time step t in [0, T] corresponding to a given noise level sigma.
If an analytic inverse of sigma_t is provided, it is used. Otherwise, a numeric inversion is performed (nearest neighbor for discrete schedules, binary search for continuous schedules).
- Parameters:
sigma (torch.Tensor | float) – noise level(s), either a scalar or a tensor of shape [B].
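Both numeric inversions are short to write. The sketch below (hypothetical function names) assumes that a continuous sigma_t is increasing on [0, T], which bisection needs, and that a discrete schedule is a 1-D tensor indexed by time step; the discrete version returns step indices, and mapping them back to t in [0, T] depends on the schedule's convention.

```python
import torch


def time_from_sigma_continuous(sigma_fn, sigma, T=1.0, n_iter=60):
    """Bisection on [0, T]; assumes sigma_fn is increasing in t."""
    sigma = torch.as_tensor(sigma, dtype=torch.float64)
    lo = torch.zeros_like(sigma)
    hi = torch.full_like(sigma, T)
    for _ in range(n_iter):
        mid = (lo + hi) / 2
        below = sigma_fn(mid) < sigma  # the solution lies in (mid, hi]
        lo = torch.where(below, mid, lo)
        hi = torch.where(below, hi, mid)
    return (lo + hi) / 2


def time_from_sigma_discrete(sigmas, sigma):
    """Nearest neighbor in a tabulated schedule of shape [N]; returns step indices."""
    sigma = torch.as_tensor(sigma, dtype=sigmas.dtype)
    dist = (sigmas.view(1, -1) - sigma.reshape(-1, 1)).abs()
    return dist.argmin(dim=1).reshape(sigma.shape)


t_cont = time_from_sigma_continuous(lambda t: 2.0 * t, torch.tensor([0.5, 1.0]))
t_disc = time_from_sigma_discrete(torch.linspace(0.0, 1.0, 11), torch.tensor([0.33, 0.96]))
print(t_cont, t_disc)
```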
Examples using ScoreModelWrapper:
Using state-of-the-art diffusion models from HuggingFace Diffusers with DeepInverse