DiscriminatorMetric
- class deepinv.loss.adversarial.DiscriminatorMetric(metric: Module = MSELoss(), real_label: float = 1.0, fake_label: float = 0.0, no_grad: bool = False, device='cpu')
Bases: object
Generic GAN discriminator metric building block.
Compares the discriminator output with a target label, chosen according to whether the image should be judged real or fake.
By default, the loss follows LSGAN: Least Squares Generative Adversarial Networks, i.e. a mean squared error between the discriminator output and the label.
This can be overridden to provide any flavour of discriminator metric, e.g. NSGAN, WGAN, LSGAN etc.
See Are GANs Created Equal? for a comparison.
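The label-comparison idea above can be sketched in plain Python. This is a simplified, hypothetical analogue, not the deepinv API: the class name `LabelMetricSketch`, the `mse` helper and the `(pred, real)` call convention are assumptions made for illustration. The `no_grad` and `device` options are left out because plain Python has no autograd or devices. The usage part shows the LSGAN pattern, with the constant 1/2 factors from the paper dropped.

```python
from typing import Callable, List

def mse(pred: List[float], target: List[float]) -> float:
    """Mean squared error; a plain-Python stand-in for torch.nn.MSELoss()."""
    return sum((p - t) ** 2 for p, t in zip(pred, target)) / len(pred)

class LabelMetricSketch:
    """Hypothetical, simplified analogue of DiscriminatorMetric.

    Compares discriminator outputs with a constant target label, chosen by
    whether the input should be judged real or fake.
    """

    def __init__(self,
                 metric: Callable[[List[float], List[float]], float] = mse,
                 real_label: float = 1.0,
                 fake_label: float = 0.0):
        self.metric = metric
        self.real_label = real_label
        self.fake_label = fake_label

    def __call__(self, pred: List[float], real: bool) -> float:
        label = self.real_label if real else self.fake_label
        # Repeat the scalar label so the target has the same length as pred.
        target = [label] * len(pred)
        return self.metric(pred, target)

# LSGAN-style usage (hypothetical discriminator outputs).
metric = LabelMetricSketch()
d_real = [0.5, 1.0]  # outputs on real images
d_fake = [0.0, 0.5]  # outputs on generated images

# Discriminator: push real outputs toward 1 and fake outputs toward 0.
d_loss = metric(d_real, real=True) + metric(d_fake, real=False)
# Generator: push outputs on generated images toward the real label.
g_loss = metric(d_fake, real=True)

print(d_loss, g_loss)  # 0.25 0.625
```

The same object serves both players: only the requested label changes, which is what lets one metric class express the discriminator and generator sides of the objective.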
- Parameters:
metric (torch.nn.Module) – loss with which to compare outputs, defaults to torch.nn.MSELoss()
real_label (float) – target discriminator output for a real image, defaults to 1.0
fake_label (float) – target discriminator output for a fake image, defaults to 0.0
no_grad (bool) – whether to compute the metric under torch.no_grad(), defaults to False
device (str) – torch device, defaults to "cpu"
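The `metric` parameter is what selects the GAN flavour. As a hedged illustration, swapping the default mean squared error for binary cross-entropy on probabilities gives the standard non-saturating (NSGAN-style) losses: the discriminator compares its outputs with `real_label` and `fake_label`, and the generator compares its outputs on generated images with `real_label`, which equals -log D(G(z)). In PyTorch the analogous choice would presumably be passing `torch.nn.BCELoss()` (or `torch.nn.BCEWithLogitsLoss()` for a discriminator that outputs logits) as `metric`. The helpers `bce` and `label_metric` below are hypothetical plain-Python stand-ins, not part of deepinv.

```python
import math
from typing import Callable, List

def bce(pred: List[float], target: List[float]) -> float:
    """Binary cross-entropy on probabilities in (0, 1); stand-in for torch.nn.BCELoss()."""
    return -sum(t * math.log(p) + (1 - t) * math.log(1 - p)
                for p, t in zip(pred, target)) / len(pred)

def label_metric(pred: List[float], real: bool,
                 metric: Callable[[List[float], List[float]], float],
                 real_label: float = 1.0, fake_label: float = 0.0) -> float:
    """Hypothetical helper: compare pred with a constant real or fake label."""
    label = real_label if real else fake_label
    return metric(pred, [label] * len(pred))

d_real = [0.5, 0.5]  # hypothetical discriminator probabilities on real images
d_fake = [0.5, 0.5]  # hypothetical discriminator probabilities on generated images

# Discriminator: BCE(D(x), 1) + BCE(D(G(z)), 0)
d_loss = label_metric(d_real, True, bce) + label_metric(d_fake, False, bce)
# Non-saturating generator loss: BCE(D(G(z)), 1) = -log D(G(z))
g_loss = label_metric(d_fake, True, bce)

print(round(d_loss, 4), round(g_loss, 4))  # 1.3863 0.6931
```

With a maximally uncertain discriminator (every output 0.5), each term equals log 2, so the discriminator loss is 2 log 2 and the generator loss is log 2.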