DiscriminatorMetric

class deepinv.loss.adversarial.DiscriminatorMetric(metric: Module = MSELoss(), real_label: float = 1.0, fake_label: float = 0.0, no_grad: bool = False, device='cpu')

Bases: object

Generic GAN discriminator metric building block.

Compares the discriminator output with a target label, chosen according to whether the image should be considered real or fake.
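The comparison described above can be sketched as follows. This is a minimal, hypothetical re-implementation: the class name `DiscriminatorMetricSketch` is made up and is not part of deepinv, and the call signature `(pred, real)` with a boolean `real` flag is an assumption. It only mirrors the documented constructor parameters and is not the library's source code.

```python
from contextlib import nullcontext

import torch
import torch.nn as nn


class DiscriminatorMetricSketch:
    """Hypothetical stand-in mirroring the documented DiscriminatorMetric parameters."""

    def __init__(self, metric=nn.MSELoss(), real_label=1.0, fake_label=0.0,
                 no_grad=False, device="cpu"):
        # Store the ideal labels as one-element tensors on the requested device.
        self.real = torch.tensor([real_label], device=device)
        self.fake = torch.tensor([fake_label], device=device)
        self.metric = metric
        self.no_grad = no_grad

    def __call__(self, pred, real):
        # Pick the label for this case and broadcast it to the output shape.
        target = (self.real if real else self.fake).expand_as(pred)
        # Optionally compute the metric outside the autograd graph.
        with torch.no_grad() if self.no_grad else nullcontext():
            return self.metric(pred, target)


# Made-up discriminator outputs for a batch of 4 images.
pred = torch.tensor([[0.5], [1.0], [0.0], [1.5]])
metric = DiscriminatorMetricSketch()
print(metric(pred, real=True))   # mean((pred - 1)^2) = 0.375
print(metric(pred, real=False))  # mean(pred^2) = 0.875
```

Broadcasting a one-element label with `expand_as` lets the same metric object handle discriminator outputs of any shape, such as per-image scores or patch-wise maps.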

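By default (with metric=torch.nn.MSELoss()), the loss follows LSGAN: Least Squares Generative Adversarial Networks, i.e. it penalises the squared distance between the discriminator output and the target label.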
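Written out, the default configuration corresponds to the objectives below, with $r$ = `real_label` and $f$ = `fake_label`. This assumes that the caller sums the real and fake terms for the discriminator and that the generator loss reuses the metric with the real label on generated images. The factors of 1/2 used in the original LSGAN formulation are dropped here.

```latex
\mathcal{L}_D = \mathbb{E}_{x}\!\left[(D(x) - r)^2\right] + \mathbb{E}_{z}\!\left[(D(G(z)) - f)^2\right],
\qquad
\mathcal{L}_G = \mathbb{E}_{z}\!\left[(D(G(z)) - r)^2\right]
```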

This class can be overridden to provide any flavour of discriminator metric, e.g. NSGAN, WGAN or LSGAN.
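As an illustration of other flavours, here is a sketch that assumes label-based variants differ only in the metric while WGAN needs a different call entirely. NSGAN-style losses can be obtained by swapping the metric for binary cross-entropy on logits. A WGAN-style critic has no target labels, so it would require overriding the comparison itself. The helper names `label_metric` and `wgan_metric` are hypothetical and are not part of deepinv.

```python
import torch
import torch.nn as nn


def label_metric(metric, pred, label):
    # Label-based flavours (LSGAN, NSGAN) differ only in the metric used.
    return metric(pred, torch.full_like(pred, label))


def wgan_metric(pred, real):
    # WGAN critic: no labels; minimise -D(real) and +D(fake).
    return -pred.mean() if real else pred.mean()


# Made-up raw discriminator outputs (logits / critic scores).
pred_real = torch.tensor([[2.0], [1.0]])
pred_fake = torch.tensor([[-1.0], [0.0]])

# NSGAN-style discriminator loss: BCE on logits with labels 1 (real) and 0 (fake).
bce = nn.BCEWithLogitsLoss()
d_ns = label_metric(bce, pred_real, 1.0) + label_metric(bce, pred_fake, 0.0)

# WGAN-style critic loss: mean(D(fake)) - mean(D(real)).
d_w = wgan_metric(pred_real, True) + wgan_metric(pred_fake, False)
print(d_w)  # -1.5 + (-0.5) = -2.0
```

Using `BCEWithLogitsLoss` instead of a sigmoid followed by `BCELoss` is the numerically safer choice when the discriminator returns raw logits.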

See "Are GANs Created Equal?" for a comparison of these GAN variants.

Parameters:
  • metric (torch.nn.Module) – loss used to compare the discriminator output with the target label, defaults to torch.nn.MSELoss()

  • real_label (float) – value for ideal real image, defaults to 1.

  • fake_label (float) – value for ideal fake image, defaults to 0.

  • no_grad (bool) – whether to compute the metric under torch.no_grad(), defaults to False

  • device (str) – torch device, defaults to "cpu"
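The sketch below shows how the parameters above might be used, with a hypothetical functional analogue named `discriminator_metric` (not a deepinv function). Setting `real_label=0.9` illustrates one-sided label smoothing, which is an example use chosen here and not something the class prescribes. `no_grad=True` produces a value that can be logged but not backpropagated.

```python
from contextlib import nullcontext

import torch
import torch.nn as nn


def discriminator_metric(pred, real, metric=nn.MSELoss(), real_label=1.0,
                         fake_label=0.0, no_grad=False):
    """Hypothetical functional analogue of the documented parameters."""
    label = real_label if real else fake_label
    target = torch.full_like(pred, label)
    # no_grad=True computes the metric without recording an autograd graph.
    with torch.no_grad() if no_grad else nullcontext():
        return metric(pred, target)


pred = torch.tensor([[0.9], [0.5]], requires_grad=True)

# Label smoothing: target 0.9 instead of 1 for real images.
smoothed = discriminator_metric(pred, real=True, real_label=0.9)
# With no_grad=True the result is detached from the graph.
logged = discriminator_metric(pred, real=True, no_grad=True)
print(smoothed.requires_grad, logged.requires_grad)  # True False
```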