DiscriminatorLoss
- class deepinv.loss.adversarial.DiscriminatorLoss(weight_adv=1.0, D=None, device='cpu', **kwargs)
Bases: Loss
Base discriminator adversarial loss.
Override the forward function to call adversarial_loss with quantities depending on your specific GAN model.
For examples, see deepinv.loss.adversarial.SupAdversarialDiscriminatorLoss and deepinv.loss.adversarial.UnsupAdversarialDiscriminatorLoss.
See Imaging inverse problems with adversarial networks for the formula.
- Parameters:
weight_adv (float) – weight for adversarial loss, defaults to 1.
D (torch.nn.Module) – discriminator network. If not specified, D must be provided in forward, defaults to None.
device (str) – torch device, defaults to "cpu".
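The override pattern described above can be sketched in plain PyTorch. This is a minimal illustration under stated assumptions, not the library's implementation: `SketchDiscriminatorLoss` is a hypothetical stand-in that only mirrors the constructor parameters documented here, the `forward(x, x_net, D=None)` signature is assumed, and the least-squares targets (1 for real, 0 for fake) are one possible choice of criterion.

```python
import torch
import torch.nn as nn


class SketchDiscriminatorLoss(nn.Module):
    """Hypothetical stand-in for the base class (NOT the library's code)."""

    def __init__(self, weight_adv=1.0, D=None, device="cpu"):
        super().__init__()
        self.weight_adv = weight_adv
        self.D = D
        self.device = device

    def adversarial_loss(self, real, fake, D=None):
        # Fall back to the discriminator given at construction time.
        D = self.D if D is None else D
        pred_real = D(real)
        # Detach the fake image so this loss only trains the discriminator.
        pred_fake = D(fake.detach())
        # Assumption: least-squares criterion with targets 1 (real) and 0 (fake).
        loss_real = torch.mean((pred_real - 1.0) ** 2)
        loss_fake = torch.mean(pred_fake ** 2)
        return self.weight_adv * (loss_real + loss_fake)


class SupervisedSketchLoss(SketchDiscriminatorLoss):
    """Supervised case: ground truth x is 'real', reconstruction x_net is 'fake'."""

    def forward(self, x, x_net, D=None, **kwargs):
        return self.adversarial_loss(x, x_net, D)


# Toy discriminator producing one score per image (hypothetical architecture).
D = nn.Sequential(nn.Flatten(), nn.Linear(16, 1))
loss_fn = SupervisedSketchLoss(weight_adv=0.5, D=D)

x = torch.rand(2, 1, 4, 4)                          # ground-truth images
x_net = torch.rand(2, 1, 4, 4, requires_grad=True)  # stand-in reconstructions
loss = loss_fn(x, x_net)
loss.backward()  # gradients reach D's parameters, not x_net
```

A subclass for a different GAN model would change only which tensors are passed as `real` and `fake` in `forward`; the shared computation stays in `adversarial_loss`.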
- adversarial_loss(real, fake, D=None)
Typical adversarial loss in GAN discriminators.
- Parameters:
real (torch.Tensor) – image labelled as real, typically one originating from training set
fake (torch.Tensor) – image labelled as fake, typically a reconstructed image
D (torch.nn.Module) – discriminator/critic/classifier model. If None, the D passed to __init__ is used. Defaults to None.
- Returns:
(torch.Tensor) discriminator adversarial loss
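As a rough guide to what a "typical" discriminator loss looks like, one common generic form is written out below. The notation is introduced here for illustration only: $x$ is the real image, $\hat{x}$ the fake image, $w_{\mathrm{adv}}$ corresponds to weight_adv, and $\ell$ is a per-sample criterion comparing the discriminator output with a target label. This section does not state which criterion the library uses, so treat the choice of $\ell$ as an assumption and refer to the linked example for the actual formula.

```latex
\mathcal{L}_D(x, \hat{x})
  = w_{\mathrm{adv}} \left[ \ell\big(D(x),\, 1\big) + \ell\big(D(\hat{x}),\, 0\big) \right]
```

For instance, taking $\ell(p, t) = (p - t)^2$ gives a least-squares GAN discriminator loss, while taking $\ell$ to be binary cross-entropy gives the original GAN discriminator objective.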
Examples using DiscriminatorLoss:
Imaging inverse problems with adversarial networks