UAIRGeneratorLoss
- class deepinv.loss.adversarial.UAIRGeneratorLoss(weight_adv: float = 0.5, weight_mc: float = 1, metric: Module = MSELoss(), D: Module | None = None, device='cpu')
Bases: GeneratorLoss
Reimplementation of the UAIR generator’s adversarial loss.
Pajot et al., “Unsupervised Adversarial Image Reconstruction”.
The loss is defined as follows, to be minimised by the generator:
\(\mathcal{L}=\mathcal{L}_\text{adv}(y,\hat y;D)+\lVert A(R(\hat y))-\hat y\rVert^2_2,\quad\hat y=A(\hat x),\quad\hat x=R(y)\)
where \(A\) is the forward operator (physics), \(R\) is the reconstruction network (model), and the standard adversarial loss is
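\(\mathcal{L}_\text{adv}(y,\hat y;D)=\mathbb{E}_{y\sim p_y}\left[q(D(y))\right]+\mathbb{E}_{\hat y\sim p_{\hat y}}\left[q(1-D(\hat y))\right]\)
where \(D\) is the discriminator and \(q\) is the GAN metric applied to the discriminator output. The adversarial and measurement-consistency terms are scaled by weight_adv and weight_mc respectively, and the measurement-consistency term is computed with metric (by default nn.MSELoss()).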
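The objective can be evaluated directly from its definition. The sketch below is illustrative only and is not deepinv's implementation: `uair_generator_loss` is a hypothetical helper, `A`, `R` and `D` are arbitrary callables standing in for the forward operator, the reconstruction network and the discriminator, and the default `q(t) = t**2` is a placeholder assumption because this page does not specify `q`. The weights play the roles of the weight_adv and weight_mc parameters.

```python
import torch
from torch import nn


def uair_generator_loss(y, A, R, D, q=lambda t: t.pow(2),
                        weight_adv=0.5, weight_mc=1.0, metric=nn.MSELoss()):
    """Hypothetical helper evaluating the UAIR generator objective literally.

    A: forward operator, R: reconstruction network, D: discriminator.
    q(t) = t**2 is a placeholder assumption, not deepinv's GAN metric.
    """
    x_hat = R(y)        # reconstruct from the real measurement
    y_hat = A(x_hat)    # re-measure the reconstruction
    # adversarial term: E[q(D(y))] + E[q(1 - D(y_hat))]
    adv = q(D(y)).mean() + q(1 - D(y_hat)).mean()
    # measurement consistency: compare A(R(y_hat)) with y_hat
    mc = metric(A(R(y_hat)), y_hat)
    return weight_adv * adv + weight_mc * mc


# Toy check: identity operators and a discriminator that always outputs 0.
y = torch.ones(2, 3)
D0 = lambda t: torch.zeros(t.shape[0], 1)
print(uair_generator_loss(y, nn.Identity(), nn.Identity(), D0).item())
```

With identity `A` and `R`, the measurement-consistency term vanishes and the loss reduces to weight_adv times the adversarial term.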
See Imaging inverse problems with adversarial networks for examples of training generator and discriminator models.
Simple example (assuming a pretrained discriminator):
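```python
from deepinv.models import DCGANDiscriminator
from deepinv.loss.adversarial import UAIRGeneratorLoss

D = DCGANDiscriminator()  # assume pretrained discriminator
loss = UAIRGeneratorLoss(D=D)
# y (measurements), physics and model (reconstruction network) are assumed to be defined
y_hat = physics.A(model(y))  # re-measured reconstruction
l = loss(y, y_hat, physics, model)
l.backward()
```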
- Parameters:
weight_adv (float) – weight for adversarial loss, defaults to 0.5 (from original paper)
weight_mc (float) – weight for measurement consistency, defaults to 1.0 (from original paper)
metric (nn.Module) – metric for measurement consistency, defaults to nn.MSELoss()
D (torch.nn.Module) – discriminator network. If not specified, D must be provided in forward(), defaults to None.
device (str) – torch device, defaults to “cpu”
- forward(y: Tensor, y_hat: Tensor, physics: Physics, model: Module, D: Module | None = None, **kwargs)
Forward pass for UAIR generator’s adversarial loss.
- Parameters:
y (Tensor) – input measurement
y_hat (Tensor) – re-measured reconstruction
physics (Physics) – forward physics
model (nn.Module) – reconstruction network
D (nn.Module) – discriminator model. If None, the D passed to __init__ is used. Defaults to None.
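The interplay between the D given to __init__ and the D given to forward() can be sketched as follows, assuming a toy physics object that only exposes an `A()` method. `UAIRGeneratorLossSketch`, `ScalePhysics` and the `adv_fn` argument are hypothetical names introduced for illustration, not deepinv API; the adversarial term is left pluggable, and raising `ValueError` when no discriminator is available is an assumption (the reference only states that D must then be provided in forward()). The measurement-consistency step follows the \(\lVert A(R(\hat y))-\hat y\rVert\) term: y_hat is passed through the reconstruction network, re-measured, and compared with y_hat via metric.

```python
import torch
from torch import nn


class UAIRGeneratorLossSketch(nn.Module):
    """Hypothetical mirror of the documented forward() contract (not deepinv code)."""

    def __init__(self, adv_fn, weight_adv=0.5, weight_mc=1.0,
                 metric=nn.MSELoss(), D=None):
        super().__init__()
        self.adv_fn = adv_fn  # assumed callable: (y, y_hat, D) -> scalar tensor
        self.weight_adv = weight_adv
        self.weight_mc = weight_mc
        self.metric = metric
        self.D = D

    def forward(self, y, y_hat, physics, model, D=None, **kwargs):
        # A discriminator passed at call time overrides the one from __init__.
        D = self.D if D is None else D
        if D is None:  # assumption: fail loudly when no discriminator exists
            raise ValueError("D must be given in __init__ or forward()")
        adv = self.adv_fn(y, y_hat, D)
        # Measurement consistency: reconstruct y_hat, re-measure, compare to y_hat.
        mc = self.metric(physics.A(model(y_hat)), y_hat)
        return self.weight_adv * adv + self.weight_mc * mc


class ScalePhysics:
    """Hypothetical toy physics whose forward operator is A(x) = 0.5 * x."""

    def A(self, x):
        return 0.5 * x
```

Keeping the stored discriminator as a fallback lets the same loss object be reused while a different discriminator is supplied per call.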
Examples using UAIRGeneratorLoss
Imaging inverse problems with adversarial networks