GeneratorLoss

class deepinv.loss.adversarial.GeneratorLoss(weight_adv=1.0, D=None, device='cpu', **kwargs)

Bases: Loss

Base generator adversarial loss.

Override the forward function so that it calls adversarial_loss with the quantities specific to your GAN model. For examples, see SupAdversarialGeneratorLoss and UnsupAdversarialGeneratorLoss.

See the example "Imaging inverse problems with adversarial networks" for the formulae.
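For orientation, one common generator objective is the least-squares GAN (LSGAN) form below. It is written here as an assumption, not as the formula deepinv uses; the linked example remains the reference. Here \lambda_{\text{adv}} plays the role of weight_adv, D is the discriminator (assumed to output one score per image), and \hat{x}_1, \dots, \hat{x}_B are the fake (for example, reconstructed) images in a batch of size B.

```latex
\mathcal{L}_{\text{adv}}(\hat{x}; D)
  = \lambda_{\text{adv}} \, \frac{1}{B} \sum_{i=1}^{B} \bigl( D(\hat{x}_i) - 1 \bigr)^2
```

The generator minimises this quantity by pushing the discriminator's scores on its outputs towards the "real" label 1.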

Parameters:
  • weight_adv (float) – weight of the adversarial loss term, defaults to 1.0.

  • D (torch.nn.Module) – discriminator network. If not specified, D must be provided in forward(), defaults to None.

  • device (str) – torch device, defaults to "cpu".

adversarial_loss(real, fake, D=None)

Typical adversarial loss in GAN generators.

Parameters:
  • real (torch.Tensor) – image labelled as real, typically one from the training set.

  • fake (torch.Tensor) – image labelled as fake, typically a reconstructed image.

  • D (torch.nn.Module) – discriminator/critic/classifier model. If None, the D passed to __init__ is used. Defaults to None.

Returns:

generator adversarial loss

Return type:

Tensor
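The following is a minimal PyTorch sketch of what a typical generator adversarial loss computes, assuming a least-squares (LSGAN-style) objective. The function name generator_adversarial_loss and its default_D argument are hypothetical: default_D stands in for the D given to __init__. This is not deepinv's implementation, and the metric it actually uses may differ. The sketch illustrates the documented D fallback and the weight_adv scaling.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


def generator_adversarial_loss(real, fake, D=None, default_D=None, weight_adv=1.0):
    """Hypothetical sketch of a least-squares (LSGAN-style) generator loss.

    `real` is kept only to mirror the documented signature; this variant
    scores `fake` alone. `default_D` stands in for the D given to __init__.
    """
    # A discriminator passed at call time takes precedence over the default one.
    D = default_D if D is None else D
    if D is None:
        raise ValueError("no discriminator: pass D or default_D")
    pred_fake = D(fake)
    # The generator is rewarded when D scores its output as real (label 1).
    return weight_adv * F.mse_loss(pred_fake, torch.ones_like(pred_fake))


# Usage with a tiny linear discriminator on 1x4x4 images.
torch.manual_seed(0)
D = nn.Sequential(nn.Flatten(), nn.Linear(16, 1))
real = torch.randn(4, 1, 4, 4)
fake = torch.randn(4, 1, 4, 4, requires_grad=True)  # stands in for a reconstruction
loss = generator_adversarial_loss(real, fake, default_D=D, weight_adv=0.1)
loss.backward()  # 0-dim tensor; gradients reach `fake` through D
```

Because F.mse_loss averages over all elements by default, this variant returns a single 0-dim value for the whole batch.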

forward(*args, D=None, **kwargs)

Computes the loss.

Parameters:
Returns:

(torch.Tensor) loss tensor of shape (1,) or (batch size,).

Return type:

Tensor
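Because forward() is meant to be overridden, each subclass decides which tensors play the roles of real and fake. The sketch below uses hypothetical stand-in classes (GeneratorLossSketch and SupervisedGeneratorLossSketch), not deepinv's own classes. It assumes a least-squares objective averaged per sample, so the loss has shape (batch size,), one of the two documented shapes; the (x_net, x) argument order is likewise an assumption for illustration.

```python
import torch
import torch.nn as nn


class GeneratorLossSketch(nn.Module):
    """Hypothetical stand-in mirroring the documented GeneratorLoss interface."""

    def __init__(self, weight_adv=1.0, D=None):
        super().__init__()
        self.weight_adv = weight_adv
        self.D = D  # default discriminator, used when no D is passed at call time

    def adversarial_loss(self, real, fake, D=None):
        D = self.D if D is None else D
        pred_fake = D(fake)
        # Least-squares objective (assumption): push D(fake) towards the "real" label 1,
        # averaged over every non-batch dimension so one value per sample remains.
        return self.weight_adv * (pred_fake - 1).pow(2).flatten(1).mean(dim=1)

    def forward(self, *args, D=None, **kwargs):
        # Subclasses decide which tensors count as "real" and "fake".
        raise NotImplementedError


class SupervisedGeneratorLossSketch(GeneratorLossSketch):
    """Supervised case (assumed): ground truth x is "real", reconstruction x_net is "fake"."""

    def forward(self, x_net, x, D=None, **kwargs):
        return self.adversarial_loss(x, x_net, D)


# Usage: a tiny linear discriminator scoring 1x4x4 images.
torch.manual_seed(0)
D = nn.Sequential(nn.Flatten(), nn.Linear(16, 1))
loss_fn = SupervisedGeneratorLossSketch(weight_adv=0.01, D=D)
x = torch.randn(2, 1, 4, 4)                          # ground-truth batch
x_net = torch.randn(2, 1, 4, 4, requires_grad=True)  # stands in for a reconstruction
loss = loss_fn(x_net, x)                             # shape: (2,)
loss.mean().backward()                               # gradients reach x_net
```

Returning one value per sample lets a training loop inspect or reweight individual images before reducing the batch to a single scalar.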

Examples using GeneratorLoss:

Imaging inverse problems with adversarial networks