Adversarial Networks#

There are two types of adversarial networks for imaging: conditional and unconditional. See Imaging inverse problems with adversarial networks for examples. Adversarial training can be done using the deepinv.training.AdversarialTrainer class, which is a subclass of deepinv.Trainer.

Conditional GAN#

Conditional generative adversarial networks (cGANs) learn a reconstruction network \(\hat{x}=R(y,A,z)\) that maps the measurements \(y\) to the signal \(x\), possibly conditioned on a random variable \(z\) and the forward operator \(A\). The goal is a good trade-off between distortion \(\|x-\hat{x}\|^2\) and perception, i.e. how close the distribution of reconstructed images \(p_{\hat{x}}\) is to the clean image distribution \(p_x\).

They are trained by adding an adversarial loss \(\mathcal{L}_\text{adv}\) to the standard reconstruction loss:

\[\mathcal{L}_\text{total}=\mathcal{L}_\text{rec}+\lambda\mathcal{L}_\text{adv}\]

where \(\lambda\) is a hyperparameter that balances the two losses. The reconstruction loss is often a mean squared error (MSE) \(\mathcal{L}_\text{rec}(x,\hat{x})=\|x-\hat{x}\|^2\) (or a self-supervised alternative), while the adversarial loss is

\[\mathcal{L}_\text{adv}(x,\hat x;D)=\mathbb{E}_{x\sim p_x}\left[q(D(x))\right]+\mathbb{E}_{\hat x\sim p_{\hat x}}\left[q(1-D(\hat x))\right]\]

where \(D(\cdot)\) is the discriminator model, \(x\) is the reference image, \(\hat{x}\) is the estimated reconstruction, and \(q(\cdot)\) is a quality function (e.g. \(q(x)=x\) for WGAN). Training alternates between the generator \(R\) and the discriminator \(D\) in a minimax game. When no ground truth is available (i.e. in the self-supervised setting), the adversarial loss may be defined on the measurements \(y\) instead. See the list of available adversarial losses in Adversarial Learning.
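As a minimal numpy sketch of the loss above (not the deepinv API; the toy linear discriminator and all names here are illustrative), the following evaluates \(\mathcal{L}_\text{adv}\) with the WGAN-style choice \(q(x)=x\) and combines it with an MSE reconstruction loss:

```python
import numpy as np

rng = np.random.default_rng(0)

def discriminator(x, w):
    """Toy linear discriminator D(x) = <w, x>; stands in for a neural network."""
    return x @ w

def adversarial_loss(x_real, x_fake, w, q=lambda s: s):
    """L_adv = E[q(D(x))] + E[q(1 - D(x_hat))], here with WGAN-style q(s) = s."""
    return q(discriminator(x_real, w)).mean() + q(1.0 - discriminator(x_fake, w)).mean()

# Toy batch of "clean" images and reconstructions, flattened to vectors.
x = rng.normal(size=(8, 16))
x_hat = x + 0.1 * rng.normal(size=(8, 16))
w = rng.normal(size=16)

lam = 0.01                            # hyperparameter balancing the two losses
rec = np.mean((x - x_hat) ** 2)       # L_rec: mean squared error
total = rec + lam * adversarial_loss(x, x_hat, w)
print(total)
```

In actual training the discriminator parameters would be updated to *maximize* this quantity while the generator minimizes it, alternating between the two.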

The reconstruction network (i.e. the “generator”) \(R\) can be any architecture that maps the measurements \(y\) to the signal \(x\), including artifact removal or unfolded networks.

The discriminator network \(D\) can be implemented with one of the following architectures:

Table 13 Discriminator Networks#

| Discriminator | Description |
|---|---|
| DCGANDiscriminator | Deep Convolutional GAN discriminator model |
| ESRGANDiscriminator | Enhanced Super-Resolution GAN discriminator model |
| PatchGANDiscriminator | PatchGAN discriminator model |

Unconditional GAN#

Unconditional generative adversarial networks train a generator network \(\hat{x}=G(z)\) to map a simple distribution \(p_z\) (e.g., Gaussian) to the signal distribution \(p_x\). The generator is trained with an adversarial loss:

\[\mathcal{L}_\text{total}=\mathcal{L}_\text{adv}(x,\hat x;D)\]

See the list of available adversarial losses in Adversarial Learning, including CSGM and AmbientGAN training.

Once the generator is trained, we can solve inverse problems by seeking a latent \(\hat z\) whose generated image matches the observed measurements, \(A(G(\hat z))\approx y\):

\[\hat x = G(\hat z)\quad\text{s.t.}\quad\hat z=\operatorname*{argmin}_z \lVert A(G(z))-y\rVert _2^2\]
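This latent optimization can be sketched in numpy with toy linear stand-ins for the generator and forward operator (illustrative only, not the deepinv API), solving \(\operatorname{argmin}_z \lVert A G z - y\rVert_2^2\) by gradient descent:

```python
import numpy as np

rng = np.random.default_rng(0)

# Toy linear stand-ins: generator G maps latents R^4 -> images R^16,
# forward operator A maps images R^16 -> measurements R^8.
G = rng.normal(size=(16, 4))
A = rng.normal(size=(8, 16))

# Measurements generated from an (unknown) ground-truth latent.
z_true = rng.normal(size=4)
y = A @ (G @ z_true)

# Gradient descent on f(z) = ||A G z - y||^2, with gradient 2 (A G)^T (A G z - y).
M = A @ G
step = 0.5 / np.linalg.norm(M, 2) ** 2   # safe step size from the spectral norm
z = np.zeros(4)
for _ in range(10000):
    z -= step * 2.0 * M.T @ (M @ z - y)

x_hat = G @ z                            # reconstruction x_hat = G(z_hat)
print(np.linalg.norm(M @ z - y) / np.linalg.norm(y))
```

With a neural generator the same objective is minimized with a first-order optimizer over \(z\) (and, for methods like CSGM, possibly further fine-tuning), but the structure of the problem is the same.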

We can adapt any latent generator model to train an unconditional GAN and perform conditional inference:

Table 14 Unconditional GANs#

| Generator | Description |
|---|---|
| DCGANGenerator | DCGAN unconditional generator model |
| CSGMGenerator | Adapts an unconditional generator model for CSGM or AmbientGAN training |

Deep Image Prior#

The deep image prior fits the weights of an untrained network \(R\), applied to a fixed random input \(z\), directly to the measurements, without any training data. The choice of architecture for \(R\) is crucial to the success of the method: we provide the deepinv.models.ConvDecoder architecture, a convolutional decoder that has shown good inductive bias for image reconstruction tasks.
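As a toy illustration of the idea (not using deepinv; a two-layer numpy "decoder" with manual backpropagation stands in for ConvDecoder, and all names are illustrative), the following fits the weights of an untrained decoder with a fixed random input to the measurements:

```python
import numpy as np

rng = np.random.default_rng(0)

# Problem setup: forward operator A and measurements y (no ground truth is used for fitting).
A = rng.normal(size=(8, 16))
x_true = rng.normal(size=16)
y = A @ x_true

# Untrained two-layer "decoder" R(z; W1, W2) = W2 @ relu(W1 @ z) with a fixed random input z.
z = rng.normal(size=8)
W1 = 0.1 * rng.normal(size=(32, 8))
W2 = 0.1 * rng.normal(size=(16, 32))

def forward(W1, W2):
    h = W1 @ z
    a = np.maximum(h, 0.0)               # ReLU
    return h, a, W2 @ a

lr = 1e-4
_, _, x0 = forward(W1, W2)
loss_init = np.sum((A @ x0 - y) ** 2)

# Fit the decoder weights to the measurements: min_W ||A R(z; W) - y||^2 (manual backprop).
for _ in range(2000):
    h, a, x = forward(W1, W2)
    r = A @ x - y
    g_x = 2.0 * A.T @ r                  # d loss / d x
    g_W2 = np.outer(g_x, a)              # d loss / d W2
    g_h = (W2.T @ g_x) * (h > 0)         # backprop through the ReLU
    g_W1 = np.outer(g_h, z)              # d loss / d W1
    W1 -= lr * g_W1
    W2 -= lr * g_W2

_, _, x_hat = forward(W1, W2)
loss_final = np.sum((A @ x_hat - y) ** 2)
print(loss_init, loss_final)
```

The regularization comes entirely from the architecture and early stopping: a convolutional decoder such as ConvDecoder biases the fit toward natural images far more strongly than this toy dense network does.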