GeneratorLoss

class deepinv.loss.adversarial.GeneratorLoss(weight_adv: float = 1.0, D: Module | None = None, device='cpu', **kwargs)

Bases: Loss

Base generator adversarial loss. Override the forward function so that it calls adversarial_loss with the quantities specific to your GAN model. For examples, see deepinv.loss.adversarial.SupAdversarialGeneratorLoss and deepinv.loss.adversarial.UnsupAdversarialGeneratorLoss.

See Imaging inverse problems with adversarial networks for formulae.
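The override pattern described above can be sketched as follows. To stay runnable without deepinv installed, the sketch uses a hypothetical stand-in base class, StandInGeneratorLoss, that mirrors only the documented constructor arguments and the documented rule for choosing D. The least-squares objective inside it is an assumption for illustration, not necessarily what deepinv computes. In real code you would subclass deepinv.loss.adversarial.GeneratorLoss instead. The subclass name MySupervisedGeneratorLoss and its forward(x, x_net, ...) signature are likewise illustrative.

```python
from typing import Optional

import torch
from torch import Tensor, nn


class StandInGeneratorLoss(nn.Module):
    """Hypothetical stand-in mirroring the documented GeneratorLoss interface.

    NOT deepinv's implementation: it only reproduces the documented
    constructor arguments and the rule for choosing the discriminator D.
    """

    def __init__(self, weight_adv: float = 1.0, D: Optional[nn.Module] = None):
        super().__init__()
        self.weight_adv = weight_adv
        self.D = D

    def adversarial_loss(self, real: Tensor, fake: Tensor, D: Optional[nn.Module] = None) -> Tensor:
        # A D passed at call time takes precedence; otherwise use the one from __init__.
        D = self.D if D is None else D
        if D is None:
            raise ValueError("D must be given either in __init__ or at call time")
        # Assumed least-squares objective: push D(fake) towards the "real" target 1.
        # `real` is unused by this simple form but kept to match the documented signature.
        pred_fake = D(fake)
        return self.weight_adv * ((pred_fake - 1.0) ** 2).mean()


class MySupervisedGeneratorLoss(StandInGeneratorLoss):
    """Hypothetical subclass: forward() decides which tensors count as real and fake."""

    def forward(self, x: Tensor, x_net: Tensor, D: Optional[nn.Module] = None, **kwargs) -> Tensor:
        # Supervised setting: ground truth x is "real", reconstruction x_net is "fake".
        return self.adversarial_loss(real=x, fake=x_net, D=D)


# Usage with a toy discriminator on 1x8x8 images (all shapes are illustrative).
torch.manual_seed(0)
D = nn.Sequential(nn.Flatten(), nn.Linear(64, 1))
loss_fn = MySupervisedGeneratorLoss(weight_adv=0.5, D=D)

x = torch.randn(4, 1, 8, 8)                          # stand-in ground-truth batch
x_net = torch.randn(4, 1, 8, 8, requires_grad=True)  # stand-in reconstructions
loss = loss_fn(x, x_net)
loss.backward()  # gradients reach x_net (in practice, the generator's parameters)
```

Keeping the choice of which tensors are "real" and "fake" inside forward() lets one adversarial_loss serve both supervised and unsupervised variants.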

Parameters:
  • weight_adv (float) – weight of the adversarial loss, defaults to 1.0

  • D (torch.nn.Module) – discriminator network. If not specified, D must be provided in forward(). Defaults to None.

  • device (str) – torch device, defaults to “cpu”

adversarial_loss(real: Tensor, fake: Tensor, D: Module | None = None) → Tensor

Typical adversarial loss in GAN generators.

Parameters:
  • real (Tensor) – image labelled as real, typically one from the training set

  • fake (Tensor) – image labelled as fake, typically a reconstructed image

  • D (nn.Module) – discriminator/critic/classifier model. If None, the D passed to __init__ is used. Defaults to None.

Returns:

(Tensor) generator adversarial loss
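For orientation, two common forms of a "typical" generator adversarial loss are given below. They are assumed here for illustration and are not confirmed as the exact objective deepinv uses (see the formulae in the linked example for that). With w = weight_adv, D the discriminator, and \hat{x} the fake (reconstructed) image, the least-squares form and the non-saturating cross-entropy form are

$$
\mathcal{L}_{\mathrm{LS}}(\hat{x}) = w\,\mathbb{E}\big[(D(\hat{x}) - 1)^2\big],
\qquad
\mathcal{L}_{\mathrm{NS}}(\hat{x}) = -w\,\mathbb{E}\big[\log \sigma(D(\hat{x}))\big],
$$

where σ is the logistic sigmoid and, in the second form, D outputs logits. Both depend only on D(\hat{x}); variants such as relativistic GANs also evaluate D on the real image.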

forward(*args, D: Module | None = None, **kwargs) → Tensor

Computes the loss.

Parameters:
  • D (nn.Module) – discriminator model. If None, the D passed to __init__ is used. Defaults to None.

Returns:

(torch.Tensor) loss; the tensor size may be (1,) or (batch size,).
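Code that consumes the returned loss should cope with either shape. The minimal sketch below shows how each shape can arise from an elementwise loss map; the helper reduce_loss is hypothetical and not part of deepinv. Because both reductions average the same per-sample values, a final .mean() gives the same scalar either way, which is what backward() needs.

```python
import torch
from torch import Tensor


def reduce_loss(loss_map: Tensor, per_sample: bool) -> Tensor:
    """Hypothetical helper: reduce an elementwise loss map of shape (B, ...)."""
    # Average over every non-batch dimension -> shape (B,).
    per_sample_loss = loss_map.flatten(start_dim=1).mean(dim=1)
    if per_sample:
        return per_sample_loss
    # Also average over the batch, keeping a length-1 dimension -> shape (1,).
    return per_sample_loss.mean().reshape(1)


loss_map = torch.rand(4, 1, 8, 8)  # e.g. squared errors for a batch of 4 images
print(reduce_loss(loss_map, per_sample=True).shape)   # torch.Size([4])
print(reduce_loss(loss_map, per_sample=False).shape)  # torch.Size([1])
```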

Examples using GeneratorLoss:

Imaging inverse problems with adversarial networks