import torch
import torch.nn as nn
import numpy as np
import deepinv.physics
from deepinv.loss.loss import Loss
def hutch_div(y, physics, f, mc_iter=1, rng=None):
r"""
    Hutchinson Monte Carlo estimator of the divergence of :math:`A(f(y))`.
:param torch.Tensor y: Measurements.
:param deepinv.physics.Physics physics: Forward operator associated with the measurements.
:param torch.nn.Module f: Reconstruction network.
    :param int mc_iter: Number of Monte Carlo iterations. Default=1.
    :param torch.Generator rng: Random number generator. Default is None.
    :return: (torch.Tensor) divergence estimate of size (batch_size,).
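
    A minimal illustrative sketch (hypothetical usage, assuming a simple
    :class:`deepinv.physics.Denoising` operator and an identity placeholder network,
    for which the normalized divergence is approximately 1)::

        import torch
        import deepinv

        physics = deepinv.physics.Denoising()  # A is the identity
        f = lambda y, physics: y  # placeholder with the expected (y, physics) signature
        y = torch.randn(1, 1, 8, 8)
        div = hutch_div(y, physics, f, mc_iter=10)  # approx. 1 for the identity map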
"""
input = y.requires_grad_(True)
output = physics.A(f(input, physics))
out = 0
for i in range(mc_iter):
b = torch.empty_like(y).normal_(generator=rng)
x = torch.autograd.grad(output, input, b, retain_graph=True, create_graph=True)[
0
]
out += (b * x).reshape(y.size(0), -1).mean(1)
return out / mc_iter
def exact_div(y, physics, model):
r"""
    Exact divergence of :math:`A(f(y))`, computed with one backward pass per entry of the measurement.
:param torch.Tensor y: Measurements.
:param deepinv.physics.Physics physics: Forward operator associated with the measurements.
:param torch.nn.Module model: Reconstruction network.
    :return: (torch.Tensor) exact divergence, averaged over the entries of :math:`y`.
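
    A minimal sketch contrasting with :func:`hutch_div` (identity placeholder network,
    so the normalized divergence is exactly 1; note the cost grows with the number of pixels)::

        import torch
        import deepinv

        physics = deepinv.physics.Denoising()  # A is the identity
        model = lambda y, physics: y  # placeholder reconstruction network
        y = torch.randn(1, 1, 4, 4)
        div = exact_div(y, physics, model)  # exactly 1 for the identity map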
"""
input = y.requires_grad_(True)
output = physics.A(model(input, physics))
out = 0
_, c, h, w = input.shape
for i in range(c):
for j in range(h):
for k in range(w):
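                # one-hot probe: each backward pass recovers one diagonal entry of the Jacobian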
b = torch.zeros_like(input)
b[:, i, j, k] = 1
x = torch.autograd.grad(
output, input, b, retain_graph=True, create_graph=True
)[0]
out += (b * x).sum()
return out / (c * h * w)
def mc_div(y1, y, f, physics, tau, precond=lambda x: x, rng: torch.Generator = None):
r"""
    Monte Carlo (Ramani) estimation of the divergence of :math:`A(f(y))`.
    :param torch.Tensor y1: Output :math:`A(f(y))` evaluated at the measurements :math:`y`.
    :param torch.Tensor y: Measurements.
    :param torch.nn.Module f: Reconstruction network.
    :param deepinv.physics.Physics physics: Forward operator associated with the measurements.
    :param float tau: Approximation constant for the Monte Carlo approximation of the divergence.
    :param Callable precond: Preconditioner. Default is the identity.
    :param torch.Generator rng: Random number generator. Default is None.
    :return: (torch.Tensor) divergence estimate of size (batch_size,).
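
    A minimal sketch (identity placeholder network, for which the estimate is approximately 1)::

        import torch
        import deepinv

        physics = deepinv.physics.Denoising()  # A is the identity
        f = lambda y, physics: y  # placeholder reconstruction network
        y = torch.randn(1, 1, 8, 8)
        y1 = physics.A(f(y, physics))
        div = mc_div(y1, y, f, physics, tau=0.01)  # approx. 1 for the identity map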
"""
b = torch.empty_like(y).normal_(generator=rng)
y2 = physics.A(f(y + b * tau, physics))
return (precond(b) * precond(y2 - y1) / tau).reshape(y.size(0), -1).mean(1)
def unsure_gradient_step(loss, param, saved_grad, init_flag, step_size, momentum):
r"""
Gradient step for estimating the noise level in the UNSURE loss.
:param torch.Tensor loss: Loss value.
:param torch.Tensor param: Parameter to optimize.
:param torch.Tensor saved_grad: Saved gradient w.r.t. the parameter.
:param bool init_flag: Initialization flag (first gradient step).
:param float step_size: Step size.
    :param float momentum: Momentum.
    :return: (tuple) updated parameter, updated gradient estimate, and updated initialization flag.
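
    A minimal sketch of a single ascent step on a scalar parameter (illustrative values)::

        import torch

        p = torch.tensor(1.0, requires_grad=True)
        loss = (p - 2.0) ** 2
        p, saved_grad, init_flag = unsure_gradient_step(loss, p, 0.0, True, 1e-1, 0.9)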
"""
grad = torch.autograd.grad(loss, param, retain_graph=True)[0]
if init_flag:
init_flag = False
saved_grad = grad
else:
saved_grad = momentum * saved_grad + (1.0 - momentum) * grad
return param + step_size * grad, saved_grad, init_flag
class SureGaussianLoss(Loss):
r"""
SURE loss for Gaussian noise
The loss is designed for the following noise model:
.. math::
y \sim\mathcal{N}(u,\sigma^2 I) \quad \text{with}\quad u= A(x).
The loss is computed as
.. math::
        \frac{1}{m}\|B(y - A\inverse{y})\|_2^2 - \sigma^2 + \frac{2\sigma^2}{m\tau}b^{\top} B^{\top}B \left(A\inverse{y+\tau b} -
        A\inverse{y}\right)
    where :math:`R` is the trainable network, :math:`A` is the forward operator,
    :math:`y` is the noisy measurement vector of size :math:`m`,
:math:`B` is an optional linear mapping which should be approximately :math:`A^{\dagger}` (or any stable approximation),
:math:`b\sim\mathcal{N}(0,I)` and :math:`\tau\geq 0` is a hyperparameter controlling the
Monte Carlo approximation of the divergence.
This loss approximates the divergence of :math:`A\inverse{y}` (in the original SURE loss)
using the Monte Carlo approximation in
https://ieeexplore.ieee.org/abstract/document/4099398/
If the measurement data is truly Gaussian with standard deviation :math:`\sigma`,
    this loss is an unbiased estimator of the mean squared error :math:`\frac{1}{m}\|u-A\inverse{y}\|_2^2`,
    where :math:`u` is the noiseless measurement.
.. warning::
        The loss can be sensitive to the choice of :math:`\tau`, which should be small compared to the magnitude of the entries of :math:`y`.
The default value of 0.01 is adapted to :math:`y` vectors with entries in :math:`[0,1]`.
.. note::
If the noise level is unknown, the loss can be adapted to the UNSURE loss introduced in https://arxiv.org/abs/2409.01985,
which also learns the noise level.
:param float sigma: Standard deviation of the Gaussian noise.
:param float tau: Approximation constant for the Monte Carlo approximation of the divergence.
:param Callable, str B: Optional linear metric :math:`B`, which can be used to improve
the performance of the loss. If 'A_dagger', the pseudo-inverse of the forward operator is used.
Otherwise the metric should be a linear operator that approximates the pseudo-inverse of the forward operator
such as :func:`deepinv.physics.LinearPhysics.prox_l2` with large :math:`\gamma`. By default, the identity is used.
:param bool unsure: If ``True``, the loss is adapted to the UNSURE loss introduced in https://arxiv.org/abs/2409.01985
where the noise level :math:`\sigma` is also learned (the input value is used as initialization).
:param float step_size: Step size for the gradient ascent of the noise level if unsure is ``True``.
:param float momentum: Momentum for the gradient ascent of the noise level if unsure is ``True``.
:param torch.Generator rng: Optional random number generator. Default is None.
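
    A minimal illustrative sketch, assuming a :class:`deepinv.physics.Denoising` operator
    and using an identity mapping as a placeholder for the reconstruction network::

        import torch
        import deepinv

        physics = deepinv.physics.Denoising()
        model = lambda y, physics: y  # placeholder reconstruction network
        loss = SureGaussianLoss(sigma=0.1)
        x = torch.rand(1, 1, 8, 8)
        y = physics.A(x) + 0.1 * torch.randn(1, 1, 8, 8)
        l = loss(y=y, x_net=model(y, physics), physics=physics, model=model)  # shape (1,)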
"""
def __init__(
self,
sigma,
tau=1e-2,
B=lambda x: x,
unsure=False,
step_size=1e-4,
momentum=0.9,
rng: torch.Generator = None,
):
super(SureGaussianLoss, self).__init__()
self.name = "SureGaussian"
self.sigma2 = sigma**2
self.tau = tau
self.metric = B
self.unsure = unsure
        self.init_flag = True
self.step_size = step_size
self.momentum = momentum
self.grad_sigma = 0.0
self.rng = rng
if unsure:
self.sigma2 = torch.tensor(self.sigma2, requires_grad=True)
def forward(self, y, x_net, physics, model, **kwargs):
r"""
        Computes the SURE loss.
:param torch.Tensor y: Measurements.
:param torch.Tensor x_net: reconstructed image :math:`\inverse{y}`.
:param deepinv.physics.Physics physics: Forward operator associated with the measurements.
:param torch.nn.Module model: Reconstruction network.
        :return: torch.Tensor loss of size (batch_size,)
"""
if self.metric == "A_dagger":
metric = lambda x: physics.A_dagger(x)
else:
metric = self.metric
y1 = physics.A(x_net)
div = (
2 * self.sigma2 * mc_div(y1, y, model, physics, self.tau, metric, self.rng)
)
mse = metric(y1 - y).pow(2).reshape(y.size(0), -1).mean(1)
loss_sure = mse + div - self.sigma2
if self.unsure: # update the estimate of the noise level
self.sigma2, self.grad_sigma, self.init_flag = unsure_gradient_step(
div.mean(),
self.sigma2,
self.grad_sigma,
self.init_flag,
self.step_size,
self.momentum,
)
return loss_sure
class SurePoissonLoss(Loss):
r"""
SURE loss for Poisson noise
The loss is designed for the following noise model:
.. math::
y = \gamma z \quad \text{with}\quad z\sim \mathcal{P}(\frac{u}{\gamma}), \quad u=A(x).
The loss is computed as
.. math::
\frac{1}{m}\|y-A\inverse{y}\|_2^2-\frac{\gamma}{m} 1^{\top}y
+\frac{2\gamma}{m\tau}(b\odot y)^{\top} \left(A\inverse{y+\tau b}-A\inverse{y}\right)
where :math:`R` is the trainable network, :math:`y` is the noisy measurement vector of size :math:`m`,
:math:`b` is a Bernoulli random variable taking values of -1 and 1 each with a probability of 0.5,
:math:`\tau` is a small positive number, and :math:`\odot` is an elementwise multiplication.
See https://ieeexplore.ieee.org/abstract/document/6714502/ for details.
    If the measurement data is truly Poisson,
    this loss is an unbiased estimator of the mean squared error :math:`\frac{1}{m}\|u-A\inverse{y}\|_2^2`,
    where :math:`u` is the noiseless measurement.
.. warning::
        The loss can be sensitive to the choice of :math:`\tau`, which should be small compared to the magnitude of the entries of :math:`y`.
        The default value of 0.001 is adapted to :math:`y` vectors with entries in :math:`[0,1]`.
    :param float gain: Gain of the Poisson noise.
:param float tau: Approximation constant for the Monte Carlo approximation of the divergence.
:param torch.Generator rng: Optional random number generator. Default is None.
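
    A minimal illustrative sketch, assuming a :class:`deepinv.physics.Denoising` operator
    and an identity placeholder for the reconstruction network::

        import torch
        import deepinv

        physics = deepinv.physics.Denoising()
        model = lambda y, physics: y  # placeholder reconstruction network
        loss = SurePoissonLoss(gain=0.1)
        x = torch.rand(1, 1, 8, 8)
        y = 0.1 * torch.poisson(x / 0.1)  # Poisson noise with gain 0.1
        l = loss(y=y, x_net=model(y, physics), physics=physics, model=model)  # shape (1,)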
"""
def __init__(self, gain, tau=1e-3, rng: torch.Generator = None):
super(SurePoissonLoss, self).__init__()
self.name = "SurePoisson"
self.gain = gain
self.tau = tau
self.rng = rng
def forward(self, y, x_net, physics, model, **kwargs):
r"""
Computes the SURE loss.
:param torch.Tensor y: measurements.
:param torch.Tensor x_net: reconstructed image :math:`\inverse{y}`.
        :param deepinv.physics.Physics physics: Forward operator associated with the measurements.
        :param torch.nn.Module model: Reconstruction network.
        :return: torch.Tensor loss of size (batch_size,)
"""
        # generate a random Rademacher vector b with entries -1 or 1, each with probability 0.5
        b = torch.empty_like(y).uniform_(generator=self.rng)
        b = b > 0.5
        b = (2 * b - 1) * 1.0  # map {0, 1} to {-1, 1}
y1 = physics.A(x_net)
y2 = physics.A(model(y + self.tau * b, physics))
loss_sure = (
(y1 - y).pow(2)
- self.gain * y
+ (2.0 / self.tau) * self.gain * (b * y * (y2 - y1))
)
loss_sure = loss_sure.reshape(y.size(0), -1).mean(1)
return loss_sure
class SurePGLoss(Loss):
r"""
SURE loss for Poisson-Gaussian noise
The loss is designed for the following noise model:
.. math::
y = \gamma z + \epsilon
where :math:`u = A(x)`, :math:`z \sim \mathcal{P}\left(\frac{u}{\gamma}\right)`,
and :math:`\epsilon \sim \mathcal{N}(0, \sigma^2 I)`.
The loss is computed as
.. math::
        & \frac{1}{m}\|y-A\inverse{y}\|_2^2-\frac{\gamma}{m} 1^{\top}y-\sigma^2
        +\frac{2}{m\tau_1}(b\odot (\gamma y + \sigma^2 1))^{\top} \left(A\inverse{y+\tau_1 b}-A\inverse{y} \right) \\\\
        & -\frac{2\gamma \sigma^2}{m\tau_2^2}c^{\top} \left( A\inverse{y+\tau_2 c} + A\inverse{y-\tau_2 c} - 2A\inverse{y} \right)
    where :math:`R` is the trainable network, :math:`y` is the noisy measurement vector of size :math:`m`,
    :math:`b` is a Bernoulli random variable taking values of -1 and 1 each with a probability of 0.5,
    :math:`c` is a zero-mean two-point random vector, :math:`\tau_1` and :math:`\tau_2` are small positive numbers,
    and :math:`\odot` is an elementwise multiplication.
    If the measurement data is truly Poisson-Gaussian,
    this loss is an unbiased estimator of the mean squared error :math:`\frac{1}{m}\|u-A\inverse{y}\|_2^2`,
    where :math:`u` is the noiseless measurement.
See https://ieeexplore.ieee.org/abstract/document/6714502/ for details.
.. warning::
        The loss can be sensitive to the choices of :math:`\tau_1` and :math:`\tau_2`, which should be small compared to the magnitude of the entries of :math:`y`.
        The default values of 0.001 and 0.01 are adapted to :math:`y` vectors with entries in :math:`[0,1]`.
.. note::
If the noise levels are unknown, the loss can be adapted to the UNSURE loss introduced in https://arxiv.org/abs/2409.01985,
which also learns the noise levels.
:param float sigma: Standard deviation of the Gaussian noise.
    :param float gain: Gain of the Poisson noise.
    :param float tau1: Approximation constant for the Monte Carlo approximation of the divergence.
    :param float tau2: Approximation constant for the second derivative.
    :param bool second_derivative: If ``False``, the last term in the loss (approximating the second derivative) is removed
        to speed up computations, at the cost of a possibly inexact loss. Default ``False``.
:param bool unsure: If ``True``, the loss is adapted to the UNSURE loss introduced in https://arxiv.org/abs/2409.01985
where :math:`\gamma` and :math:`\sigma^2` are also learned (their input value is used as initialization).
:param tuple[float] step_size: Step size for the gradient ascent of the noise levels if unsure is ``True``.
:param tuple[float] momentum: Momentum for the gradient ascent of the noise levels if unsure is ``True``.
:param torch.Generator rng: Optional random number generator. Default is None.
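
    A minimal illustrative sketch, assuming a :class:`deepinv.physics.Denoising` operator
    and an identity placeholder for the reconstruction network::

        import torch
        import deepinv

        physics = deepinv.physics.Denoising()
        model = lambda y, physics: y  # placeholder reconstruction network
        loss = SurePGLoss(sigma=0.1, gain=0.1)
        x = torch.rand(1, 1, 8, 8)
        y = 0.1 * torch.poisson(x / 0.1) + 0.1 * torch.randn(1, 1, 8, 8)
        l = loss(y=y, x_net=model(y, physics), physics=physics, model=model)  # shape (1,)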
"""
def __init__(
self,
sigma,
gain,
tau1=1e-3,
tau2=1e-2,
second_derivative=False,
unsure=False,
step_size=(1e-4, 1e-4),
momentum=(0.9, 0.9),
rng=None,
):
super(SurePGLoss, self).__init__()
self.name = "SurePG"
self.sigma2 = sigma**2
self.gain = gain
self.tau1 = tau1
self.tau2 = tau2
self.second_derivative = second_derivative
self.step_size = step_size
self.grad_sigma = 0.0
self.grad_gain = 0.0
self.momentum = momentum
self.init_flag_sigma = True
self.init_flag_gain = True
self.unsure = unsure
self.rng = rng
if unsure:
self.sigma2 = torch.tensor(self.sigma2, requires_grad=True)
self.gain = torch.tensor(self.gain, requires_grad=True)
def forward(self, y, x_net, physics, model, **kwargs):
r"""
Computes the SURE loss.
:param torch.Tensor y: measurements.
:param torch.Tensor x_net: reconstructed image :math:`\inverse{y}`.
        :param deepinv.physics.Physics physics: Forward operator associated with the measurements.
        :param torch.nn.Module model: Reconstruction network.
        :return: torch.Tensor loss of size (batch_size,)
"""
        # generate a random Rademacher vector b1 with entries -1 or 1, each with probability 0.5
        b1 = torch.empty_like(y).uniform_(generator=self.rng)
        b1 = b1 > 0.5
        b1 = (2 * b1 - 1) * 1.0  # map {0, 1} to {-1, 1}
        p = 0.7236  # 0.5 + 0.5 * np.sqrt(1 / 5)
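        # b2 follows a zero-mean two-point distribution: -sqrt((1-p)/p) with probability p,
        # sqrt(p/(1-p)) otherwise, as used by the second-derivative estimator below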
b2 = torch.ones_like(b1) * np.sqrt(p / (1 - p))
b2[torch.empty_like(y).uniform_(generator=self.rng) < p] = -np.sqrt((1 - p) / p)
meas1 = physics.A(x_net)
meas2 = physics.A(model(y + self.tau1 * b1, physics))
loss_mc = (meas1 - y).pow(2).reshape(y.size(0), -1).mean(1)
loss_div1 = (
2
/ self.tau1
* ((b1 * (self.gain * y + self.sigma2)) * (meas2 - meas1))
.reshape(y.size(0), -1)
.mean(1)
)
offset = -self.gain * y.reshape(y.size(0), -1).mean(1) - self.sigma2
if self.unsure: # update the estimate of the noise levels
div = loss_div1.mean()
self.sigma2, self.grad_sigma, self.init_flag_sigma = unsure_gradient_step(
div,
self.sigma2,
self.grad_sigma,
self.init_flag_sigma,
self.step_size[0],
self.momentum[0],
)
self.gain, self.grad_gain, self.init_flag_gain = unsure_gradient_step(
div,
self.gain,
self.grad_gain,
self.init_flag_gain,
self.step_size[1],
self.momentum[1],
)
if self.second_derivative:
meas2p = physics.A(model(y + self.tau2 * b2, physics))
meas2n = physics.A(model(y - self.tau2 * b2, physics))
loss_div2 = (
-2
* self.sigma2
* self.gain
/ (self.tau2**2)
* (b2 * (meas2p + meas2n - 2 * meas1)).reshape(y.size(0), -1).mean(1)
)
else:
loss_div2 = torch.zeros_like(loss_div1)
loss_sure = loss_mc + loss_div1 + loss_div2 + offset
return loss_sure