DiscriminatorMetric

class deepinv.loss.adversarial.DiscriminatorMetric(metric=nn.MSELoss(), real_label=1.0, fake_label=0.0, no_grad=False, device='cpu')

Bases: object

Generic GAN discriminator metric building block.

Compares the discriminator output with a target label, chosen according to whether the input image should be judged real or fake.
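A minimal sketch of this comparison in plain PyTorch, assuming the discriminator emits one score per image. The scalar label is broadcast to the shape of the prediction and passed to the metric together with it. The helper `compare_with_label` is hypothetical and not part of deepinv.

```python
import torch
import torch.nn as nn


def compare_with_label(pred: torch.Tensor, real: bool,
                       metric: nn.Module = nn.MSELoss(),
                       real_label: float = 1.0,
                       fake_label: float = 0.0) -> torch.Tensor:
    # Hypothetical helper: choose the target label depending on whether
    # the discriminator input should be judged real or fake.
    label = real_label if real else fake_label
    # Broadcast the scalar label to the shape of the discriminator output.
    target = torch.full_like(pred, label)
    return metric(pred, target)


pred = torch.tensor([0.9, 0.8, 0.1])   # discriminator scores for 3 images
loss_if_real = compare_with_label(pred, real=True)
loss_if_fake = compare_with_label(pred, real=False)
```

The same scores give a small loss when they are expected to be real and a larger one when they are expected to be fake.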

By default, the loss follows the least-squares GAN (LSGAN) formulation of Mao et al. [1].
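For reference, the LSGAN objectives of Mao et al. [1] are given below. Here a is the target for fake images, b the target for real images, and c the value the generator wants the discriminator to assign to its fakes:

```latex
\min_D \; \tfrac{1}{2}\,\mathbb{E}_{x \sim p_{\text{data}}}\big[(D(x) - b)^2\big]
      + \tfrac{1}{2}\,\mathbb{E}_{z \sim p_z}\big[(D(G(z)) - a)^2\big]

\min_G \; \tfrac{1}{2}\,\mathbb{E}_{z \sim p_z}\big[(D(G(z)) - c)^2\big]
```

The defaults real_label=1.0 and fake_label=0.0 with an MSE metric correspond to the paper's choice a = 0, b = c = 1. This section does not state how the two discriminator terms are weighted or summed by the surrounding loss.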

This class can be overridden to provide other flavours of discriminator metric, e.g. the non-saturating GAN (NSGAN) or the Wasserstein GAN (WGAN).
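The following sketch shows how other flavours can fit the same label-comparison pattern, assuming a flavour can be expressed as a metric plus a pair of labels. `LabelMetric` and `WassersteinMetric` are hypothetical names, not deepinv classes, and this is not how deepinv itself implements these variants. NSGAN uses binary cross-entropy on raw logits. The WGAN critic loss is written as a signed mean with labels +1 and -1.

```python
import torch
import torch.nn as nn


class LabelMetric:
    """Hypothetical stand-in: compare discriminator output with a broadcast label."""

    def __init__(self, metric: nn.Module, real_label: float = 1.0, fake_label: float = 0.0):
        self.metric = metric
        self.real_label = real_label
        self.fake_label = fake_label

    def __call__(self, pred: torch.Tensor, real: bool) -> torch.Tensor:
        target = torch.full_like(pred, self.real_label if real else self.fake_label)
        return self.metric(pred, target)


class WassersteinMetric(nn.Module):
    # Critic loss: -E[D(x)] when target is +1 (real), +E[D(x)] when target is -1 (fake).
    def forward(self, pred: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
        return -(pred * target).mean()


# NSGAN-style: binary cross-entropy on raw logits with labels 1 / 0.
nsgan = LabelMetric(nn.BCEWithLogitsLoss(), real_label=1.0, fake_label=0.0)
# WGAN-style: signed mean with labels +1 / -1.
wgan = LabelMetric(WassersteinMetric(), real_label=1.0, fake_label=-1.0)

logits = torch.tensor([2.0, -1.0])
ns_real = nsgan(logits, real=True)
w_real = wgan(logits, real=True)
w_fake = wgan(logits, real=False)
```

Scoring the generator's fakes against the real label with BCE gives the non-saturating generator loss -log D(G(z)). A practical WGAN also needs a Lipschitz constraint on the critic, such as weight clipping or a gradient penalty. That constraint lies outside the metric itself.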

See Lucic et al. [2] for a large-scale comparison of these GAN variants.

Parameters:
  • metric (torch.nn.Module) – loss used to compare the discriminator output with the target label, defaults to torch.nn.MSELoss()

  • real_label (float) – target value for an ideal real image, defaults to 1.0

  • fake_label (float) – target value for an ideal fake image, defaults to 0.0

  • no_grad (bool) – whether to compute the metric inside torch.no_grad(), so that no gradients flow through it, defaults to False

  • device (str) – torch device, defaults to "cpu"
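The following sketch shows how the constructor parameters could interact. It is an illustrative re-implementation based only on the parameter descriptions above, not the deepinv source. The class name `DiscriminatorMetricSketch` is hypothetical, and the call signature `(pred, real)` is an assumption.

```python
import torch
import torch.nn as nn


class DiscriminatorMetricSketch:
    """Hypothetical re-implementation following the documented parameters."""

    def __init__(self, metric=nn.MSELoss(), real_label=1.0, fake_label=0.0,
                 no_grad=False, device="cpu"):
        self.metric = metric
        # Store the labels as 0-d tensors on the requested device.
        self.real = torch.tensor(real_label, device=device)
        self.fake = torch.tensor(fake_label, device=device)
        self.no_grad = no_grad

    def __call__(self, pred: torch.Tensor, real: bool) -> torch.Tensor:
        # Broadcast the chosen label to the discriminator output's shape.
        target = (self.real if real else self.fake).expand_as(pred)
        if self.no_grad:
            # Evaluate without recording an autograd graph.
            with torch.no_grad():
                return self.metric(pred, target)
        return self.metric(pred, target)


pred = torch.tensor([0.7, 0.3], requires_grad=True)
with_grad = DiscriminatorMetricSketch()(pred, real=True)
without_grad = DiscriminatorMetricSketch(no_grad=True)(pred, real=True)
smoothed = DiscriminatorMetricSketch(real_label=0.9)(pred, real=True)
```

With no_grad=True, the returned loss carries no gradient, which can be useful when the metric is only logged. Setting real_label below 1, as in `smoothed`, is one way to apply one-sided label smoothing.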


References: