ENSURELoss
- class deepinv.loss.mri.ENSURELoss(sigma, physics_generator, tau=None, rng=None)
Bases: SureGaussianLoss
ENSURE loss for image reconstruction in Gaussian noise.
The loss function is a special case of deepinv.loss.SureGaussianLoss for MRI/inpainting with varying masks, and is designed for the following noise model:
\[y \sim \mathcal{N}(u, \sigma^2 I) \quad \text{with} \quad u = A_i(x),\]
where \(A_i \sim \mathcal{A}\) is assumed to be drawn from a set of measurement operators. The loss is computed as
\[\frac{1}{m}\left\|B\left(A^{\dagger}y - R(A^{\top}y)\right)\right\|_2^2 + \frac{2\sigma^2}{m\tau}\, b^{\top}\left(R(A^{\top}y + \tau b) - R(A^{\top}y)\right),\]
where \(R\) is the trainable network (which takes \(A^\top y\) as input), \(A\) is the forward operator, \(y\) is the noisy measurement vector of size \(m\), \(b \sim \mathcal{N}(0, I)\), \(\tau > 0\) is a hyperparameter controlling the Monte Carlo approximation of the divergence, and \(B = W^{-1}P\), where \(P\) is the projection operator onto the range space of \(A^\top\) and \(W\) is a weighting determined by the set of measurement operators such that \(W^2 = \mathbb{E}[P]\).
The ENSURE loss was proposed in Aggarwal et al. [1] for MRI.
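As a rough illustration, the loss formula can be transcribed for inpainting with random masks in plain NumPy. This is a standalone sketch under stated assumptions, not deepinv's implementation: it assumes each pixel is sampled independently with probability \(p\), so that \(A = A^\top = A^\dagger = P = \mathrm{diag}(\text{mask})\) and \(W^2 = \mathbb{E}[P] = pI\). The function name `ensure_loss_inpainting`, the toy network and all numerical values are hypothetical.

```python
import numpy as np


def ensure_loss_inpainting(y, mask, p, r, sigma, tau, rng):
    """Hypothetical helper transcribing the ENSURE formula for masking operators.

    y     : measurements stored at image size, zero outside the mask
    mask  : binary sampling mask, so A = A^T = A^dagger = P = diag(mask)
    p     : sampling probability E[mask] (assumed Bernoulli), so W^2 = p
    r     : reconstruction network acting on A^T y (hypothetical)
    """
    m = mask.sum()                      # number of measured entries
    u = mask * y                        # A^T y, which equals A^dagger y here
    out = r(u)

    def B(v):
        return mask * v / np.sqrt(p)    # B = W^{-1} P

    data_term = np.sum(B(u - out) ** 2) / m
    b = rng.standard_normal(u.shape)    # b ~ N(0, I)
    div_term = 2 * sigma**2 / (m * tau) * np.sum(b * (r(u + tau * b) - out))
    return data_term + div_term


# Usage on one random Bernoulli(p) mask; all values are illustrative.
rng = np.random.default_rng(0)
n, p, sigma = 256, 0.5, 0.1
x = rng.standard_normal(n)
mask = (rng.random(n) < p).astype(float)
y = mask * (x + sigma * rng.standard_normal(n))  # unmeasured entries stored as zero
loss = ensure_loss_inpainting(y, mask, p, lambda v: 0.9 * v, sigma, 0.1 * sigma, rng)
print(loss)
```

Two sanity checks follow directly from the formula: a network that outputs zeros has no divergence term, and the identity network has a zero data term.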
Warning
This loss was originally proposed only for use with artifact removal models, which can be written in the form \(y \mapsto r(A^\top y)\) for a trainable network \(r\). If an artifact removal model is not used, then we evaluate the network directly instead.
We currently only provide an implementation for single-coil MRI and inpainting, where \(A^\top = A^\dagger\), so that \(P = A^\top A\) and \(W\) is a weighted average over sampling masks.
- Parameters:
sigma (float) – Standard deviation of the Gaussian noise.
physics_generator (deepinv.physics.generator.PhysicsGenerator) – Random physics generator used to compute the weighting \(W\).
tau (float) – Perturbation scale \(\tau\) used in the Monte Carlo approximation of the divergence. Defaults to \(0.1\sigma\).
rng (torch.Generator) – Optional random number generator. Default is None.
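The weighting \(W\) with \(W^2 = \mathbb{E}[P]\) can be approximated by averaging masks drawn from a generator. The sketch below assumes masking operators (so \(P\) is the diagonal of the mask) and uses a hypothetical `sample_mask` function in place of a deepinv `PhysicsGenerator`; `estimate_weight` and `apply_B` are also illustrative names, and this is not necessarily how deepinv computes \(W\).

```python
import numpy as np


def sample_mask(rng, prob):
    """Hypothetical stand-in for a physics generator: pixel j is kept with probability prob[j]."""
    return (rng.random(prob.shape) < prob).astype(float)


def estimate_weight(rng, prob, n_samples=2000, eps=1e-8):
    """Monte Carlo estimate of W = sqrt(E[P]), assuming P = diag(mask)."""
    masks = np.stack([sample_mask(rng, prob) for _ in range(n_samples)])
    mean_mask = masks.mean(axis=0)              # approximates the diagonal of E[P]
    return np.sqrt(np.maximum(mean_mask, eps))  # eps guards pixels that are never sampled


def apply_B(v, mask, W):
    """B v = W^{-1} P v: keep the sampled entries and divide by the weighting."""
    return mask * v / W


rng = np.random.default_rng(0)
prob = np.linspace(0.2, 0.9, 8)  # illustrative variable-density sampling probabilities
W = estimate_weight(rng, prob)
print(W**2)
```

As `n_samples` grows, `W**2` approaches `prob`; pixels that are sampled less often receive a smaller \(W\) and are therefore up-weighted by \(B\).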
- References:
- div(x_net, y, f, physics)
Monte Carlo estimate of the divergence of the network \(f\) with respect to its input.
- Parameters:
x_net (torch.Tensor) – Reconstructions.
y (torch.Tensor) – Measurements.
f (torch.nn.Module) – Reconstruction network.
physics (deepinv.physics.Physics) – Forward operator associated with the measurements.
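The divergence term in the loss is \(\frac{2\sigma^2}{m}\) times a one-sample estimate \(\frac{1}{\tau} b^\top (f(u + \tau b) - f(u))\) with \(u = A^\top y\). The following standalone NumPy sketch shows only that estimator on a plain function `f` and input `u`; it does not reproduce the signature of `div`, which also receives `x_net`, `y` and `physics`. The name `mc_divergence` and the toy linear network are hypothetical.

```python
import numpy as np


def mc_divergence(f, u, tau, rng):
    """Hypothetical one-sample Monte Carlo estimate of div f(u) = sum_j df_j/du_j.

    Draws b ~ N(0, I) and returns b^T (f(u + tau b) - f(u)) / tau. For linear f the
    expectation over b equals the trace of the Jacobian exactly; for smooth nonlinear
    f it approaches the trace as tau -> 0.
    """
    b = rng.standard_normal(u.shape)
    return np.sum(b * (f(u + tau * b) - f(u))) / tau


rng = np.random.default_rng(0)
d = np.linspace(0.0, 1.0, 16)
f = lambda v: d * v  # toy linear "network" with Jacobian diag(d), so div f = d.sum()
u = rng.standard_normal(16)
estimates = [mc_divergence(f, u, 1e-3, rng) for _ in range(5000)]
print(np.mean(estimates), d.sum())
```

A single draw is noisy; averaging many draws (as above) shows that the estimator is centred on the true divergence, while in training one draw per batch is typically enough because the noise averages out across iterations.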