DiscriminatorMetric
- class deepinv.loss.adversarial.DiscriminatorMetric(metric=nn.MSELoss(), real_label=1.0, fake_label=0.0, no_grad=False, device='cpu')
Bases: object
Generic GAN discriminator metric building block.
Compares the discriminator output with a target label, chosen according to whether the image should be judged real or fake.
By default, the loss follows LSGAN from Mao et al. [1], i.e. the mean squared error between the output and the label.
This can be overridden to provide any flavour of discriminator metric, e.g. NSGAN, WGAN or LSGAN.
See Lucic et al. [2] for a comparison.
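The comparison described above can be sketched as follows, assuming the metric is called with the discriminator output `pred` and a boolean `real` that selects which label to compare against. `DiscriminatorMetricSketch`, `WassersteinMetricSketch` and the `__call__(pred, real)` signature are hypothetical illustrations, not the deepinv API. With the default `nn.MSELoss` the sketch gives the LSGAN objective; passing `nn.BCEWithLogitsLoss` gives the cross-entropy (NSGAN-style) objective on logits; and overriding `__call__` gives a WGAN-style critic score that ignores the labels.

```python
import torch
import torch.nn as nn


class DiscriminatorMetricSketch:
    """Hypothetical sketch: compare discriminator outputs with a constant real/fake label."""

    def __init__(self, metric=None, real_label=1.0, fake_label=0.0, device="cpu"):
        # Default to MSE, i.e. the least-squares (LSGAN) objective.
        self.metric = metric if metric is not None else nn.MSELoss()
        self.real = torch.tensor(real_label, device=device)
        self.fake = torch.tensor(fake_label, device=device)

    def __call__(self, pred, real):
        # Broadcast the scalar label to the shape of the discriminator output.
        target = (self.real if real else self.fake).expand_as(pred)
        return self.metric(pred, target)


class WassersteinMetricSketch(DiscriminatorMetricSketch):
    """Hypothetical override: WGAN-style critic score that ignores the labels."""

    def __call__(self, pred, real):
        # Minimising this pushes scores up for real images and down for fakes.
        return -pred.mean() if real else pred.mean()


# Random stand-ins for discriminator outputs on 4 real and 4 generated images.
torch.manual_seed(0)
d_real = torch.randn(4, 1)
d_fake = torch.randn(4, 1)

lsgan = DiscriminatorMetricSketch()  # MSE against labels 1.0 / 0.0
d_loss = lsgan(d_real, real=True) + lsgan(d_fake, real=False)
g_loss = lsgan(d_fake, real=True)  # the generator wants its fakes scored as real

nsgan = DiscriminatorMetricSketch(metric=nn.BCEWithLogitsLoss())
g_loss_ns = nsgan(d_fake, real=True)  # non-saturating generator loss on logits

wgan = WassersteinMetricSketch()
d_loss_w = wgan(d_real, real=True) + wgan(d_fake, real=False)  # mean D(fake) - mean D(real)
```

The sketch uses `metric=None` rather than an `nn.MSELoss()` default argument so that each instance gets its own loss module. The WGAN variant also omits the weight clipping or gradient penalty that WGAN training additionally requires.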
- Parameters:
metric (torch.nn.Module) – loss used to compare the discriminator output with the target label, defaults to torch.nn.MSELoss
real_label (float) – target value for an ideal real image, defaults to 1.0
fake_label (float) – target value for an ideal fake image, defaults to 0.0
no_grad (bool) – whether to compute the metric under torch.no_grad, defaults to False
device (str) – torch device, defaults to "cpu"
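The label and `no_grad` parameters can be illustrated with another small sketch. It assumes, without checking the deepinv source, that `no_grad=True` wraps the comparison in `torch.no_grad()` so the returned loss is detached from the autograd graph (useful when the value is only logged), and that the labels are stored as tensors on `device`. The class `LabelledMetricSketch` and its `__call__(pred, real)` signature are hypothetical.

```python
import contextlib

import torch
import torch.nn as nn


class LabelledMetricSketch:
    """Hypothetical sketch of the real_label / fake_label / no_grad / device parameters."""

    def __init__(self, metric=None, real_label=1.0, fake_label=0.0, no_grad=False, device="cpu"):
        self.metric = metric if metric is not None else nn.MSELoss()
        # Target labels live on the requested torch device.
        self.real = torch.tensor(real_label, device=device)
        self.fake = torch.tensor(fake_label, device=device)
        self.no_grad = no_grad

    def __call__(self, pred, real):
        target = (self.real if real else self.fake).expand_as(pred)
        # Optionally disable autograd so the loss is computed for monitoring only.
        ctx = torch.no_grad() if self.no_grad else contextlib.nullcontext()
        with ctx:
            return self.metric(pred, target)


pred = torch.full((2, 1), 0.5, requires_grad=True)  # stand-in discriminator output

softened = LabelledMetricSketch(real_label=0.9)  # non-default target for real images
loss = softened(pred, real=True)  # mean((0.5 - 0.9)^2), approx. 0.16
loss.backward()  # gradients reach pred because no_grad=False

monitor = LabelledMetricSketch(no_grad=True)
value = monitor(pred, real=False)  # mean(0.5^2) = 0.25, detached from the graph
```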
- References: