DiscriminatorLoss
- class deepinv.loss.adversarial.DiscriminatorLoss(weight_adv: float = 1.0, D: Module | None = None, device='cpu', **kwargs)
Bases: Loss
Base discriminator adversarial loss. Override the forward function to call adversarial_loss with quantities depending on your specific GAN model. For examples, see deepinv.loss.adversarial.SupAdversarialDiscriminatorLoss and deepinv.loss.adversarial.UnsupAdversarialDiscriminatorLoss. See Imaging inverse problems with adversarial networks for formulae.
- Parameters:
weight_adv (float) – weight for adversarial loss, defaults to 1.0
D (torch.nn.Module) – discriminator network. If not specified, D must be provided in forward(), defaults to None.
device (str) – torch device, defaults to "cpu"
- adversarial_loss(real: Tensor, fake: Tensor, D: Module | None = None)
Typical adversarial loss in GAN discriminators.
- Parameters:
real (Tensor) – image labelled as real, typically one originating from training set
fake (Tensor) – image labelled as fake, typically a reconstructed image
D (nn.Module) – discriminator/critic/classifier model. If None, the D passed to __init__ is used. Defaults to None.
- Return Tensor:
discriminator adversarial loss
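For intuition, one common form of a discriminator adversarial loss is the least-squares variant, which pushes D(real) towards 1 and D(fake) towards 0. The sketch below uses plain PyTorch and is an assumption about what "typical" means here, not a copy of this method's implementation. The function name discriminator_loss_sketch and the toy critic toy_D are hypothetical.

```python
import torch
import torch.nn as nn


def discriminator_loss_sketch(real, fake, D, weight_adv=1.0):
    """Least-squares discriminator loss (illustrative only)."""
    mse = nn.MSELoss()
    pred_real = D(real)
    # Detach the fake image so no gradient flows back into the reconstruction network.
    pred_fake = D(fake.detach())
    # Real predictions are compared with 1, fake predictions with 0.
    loss_real = mse(pred_real, torch.ones_like(pred_real))
    loss_fake = mse(pred_fake, torch.zeros_like(pred_fake))
    return weight_adv * (loss_real + loss_fake)


# Hypothetical toy critic: one score per image.
toy_D = nn.Sequential(nn.Flatten(), nn.Linear(3 * 8 * 8, 1))

real = torch.rand(4, 3, 8, 8)  # stand-in for training-set images
fake = torch.rand(4, 3, 8, 8)  # stand-in for reconstructed images

loss = discriminator_loss_sketch(real, fake, toy_D, weight_adv=1.0)
loss.backward()  # gradients reach only the critic's parameters
```

Other GAN formulations (for example binary cross-entropy or Wasserstein critics) change only the two target terms; the real/fake structure stays the same.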
- forward(*args, D: Module | None = None, **kwargs) Tensor
Computes the loss.
- Parameters:
x_net (torch.Tensor) – Reconstructed image \(\inverse{y}\).
x (torch.Tensor) – Reference image.
y (torch.Tensor) – Measurement.
physics (deepinv.physics.Physics) – Forward operator associated with the measurements.
model (torch.nn.Module) – Reconstruction function.
- Returns:
(torch.Tensor) loss; the tensor shape may be (1,) or (batch_size,).
Examples using DiscriminatorLoss: Imaging inverse problems with adversarial networks