
# Imaging inverse problems with adversarial networks

This example shows you how to train various networks using adversarial
training for deblurring problems. We demonstrate running training and
inference using a conditional GAN (i.e. DeblurGAN), CSGM, AmbientGAN and
UAIR, all implemented in the library, and how to train
your own GAN using :meth:`deepinv.training.AdversarialTrainer`. These
examples can also be easily extended to train more complicated GANs such
as CycleGAN.

This example is based on the following papers:

-  Kupyn et al., [DeblurGAN: Blind Motion Deblurring Using Conditional Adversarial Networks](https://openaccess.thecvf.com/content_cvpr_2018/papers/Kupyn_DeblurGAN_Blind_Motion_CVPR_2018_paper.pdf)
-  Bora et al., [Compressed Sensing using Generative
   Models](https://arxiv.org/abs/1703.03208) (CSGM)
-  Bora et al., [AmbientGAN: Generative models from lossy
   measurements](https://openreview.net/forum?id=Hy7fDog0b)
-  Pajot et al., [Unsupervised Adversarial Image
   Reconstruction](https://openreview.net/forum?id=BJg4Z3RqF7)

Adversarial networks are characterised by the addition of an adversarial
loss $\mathcal{L}_\text{adv}$ to the standard reconstruction loss:

\begin{align}\mathcal{L}_\text{adv}(\hat x,x;D)=\mathbb{E}_{x\sim p_x}\left[q(D(x))\right]+\mathbb{E}_{\hat x\sim p_{\hat x}}\left[q(1-D(\hat x))\right]\end{align}

where $D(\cdot)$ is the discriminator model, $x$ is the
reference image, $\hat x$ is the estimated reconstruction,
$q(\cdot)$ is a quality function (e.g. $q(x)=x$ for WGAN).
Training alternates between generator $G$ and discriminator
$D$ in a minimax game. When there are no ground truths (i.e.
unsupervised), this may be defined on the measurements $y$
instead.
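To make the definition concrete, here is a framework-free sketch of the empirical estimate of $\mathcal{L}_\text{adv}$, assuming plain Python lists stand in for batches of discriminator outputs. With $q(t)=\log t$ it is the original GAN objective; with $q(t)=t$ it equals $\mathbb{E}[D(x)]-\mathbb{E}[D(\hat x)]+1$, i.e. the WGAN critic objective up to a constant. The discriminator is updated to increase this value and the generator to decrease it.

```python
import math


def adversarial_loss(d_real, d_fake, q=lambda t: t):
    """Empirical estimate of L_adv = E[q(D(x))] + E[q(1 - D(x_hat))].

    d_real: discriminator outputs D(x) on a batch of reference images
    d_fake: discriminator outputs D(x_hat) on a batch of reconstructions
    q:      quality function (identity for WGAN-style, math.log for the original GAN)
    """
    real_term = sum(q(d) for d in d_real) / len(d_real)
    fake_term = sum(q(1.0 - d) for d in d_fake) / len(d_fake)
    return real_term + fake_term


# A discriminator that separates real from fake well gets a higher value,
# which is what D tries to achieve; G tries to push the value back down.
confident = adversarial_loss([0.99, 0.95], [0.01, 0.05], q=math.log)
confused = adversarial_loss([0.5, 0.5], [0.5, 0.5], q=math.log)
print(confident, confused)
```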


In [None]:
import deepinv as dinv
from deepinv.loss import adversarial
from deepinv.physics.generator import MotionBlurGenerator
import torch
from torch.utils.data import DataLoader, random_split
from torchvision.datasets import ImageFolder
from torchvision.transforms import Compose, ToTensor, CenterCrop, Resize
from torchvision.datasets.utils import download_and_extract_archive

device = dinv.utils.get_freer_gpu() if torch.cuda.is_available() else "cpu"

## Generate dataset
In this example we use the Urban100 dataset, resized and center-cropped to 128x128. We apply random
motion blur physics using
:meth:`deepinv.physics.generator.MotionBlurGenerator`, and save the data
using :meth:`deepinv.datasets.generate_dataset`.
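Conceptually, offline generation draws fresh physics parameters for each image and stores the resulting $(x, y)$ pair. The sketch below is a simplification under stated assumptions: 1D signals as plain Python lists, a random non-negative kernel as a stand-in for a motion-blur PSF, circular convolution to mirror `padding="circular"`, and an in-memory list instead of the HDF5 file written by `generate_dataset`. All helper names are hypothetical.

```python
import random


def random_kernel(size, rng):
    """Random non-negative kernel normalised to sum to 1 (stand-in for a motion-blur PSF)."""
    weights = [rng.random() for _ in range(size)]
    total = sum(weights)
    return [w / total for w in weights]


def circular_blur(x, k):
    """Circular convolution y[i] = sum_j k[j] * x[(i - j) mod n]."""
    n = len(x)
    return [sum(k[j] * x[(i - j) % n] for j in range(len(k))) for i in range(n)]


def generate_pairs(images, kernel_size, seed=0):
    """Return (x, y) pairs, drawing a fresh kernel for every image."""
    rng = random.Random(seed)  # seeded so the generated dataset is reproducible
    pairs = []
    for x in images:
        k = random_kernel(kernel_size, rng)
        pairs.append((x, circular_blur(x, k)))
    return pairs


signal = [0.0, 1.0, 0.0, 0.0, 2.0, 0.0]
for x, y in generate_pairs([signal, signal], kernel_size=3):
    print(y)  # same x, a different blur for each pair
```

Because each kernel sums to 1, the circular blur preserves the total intensity of each signal, a handy sanity check on generated pairs.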




In [None]:
physics = dinv.physics.Blur(padding="circular", device=device)
blur_generator = MotionBlurGenerator((11, 11), device=device)

dataset = dinv.datasets.Urban100HR(
    root="Urban100",
    download=True,
    transform=Compose([ToTensor(), Resize(256), CenterCrop(128)]),
)

train_dataset, test_dataset = random_split(dataset, (0.8, 0.2))

# Generate data pairs x,y offline using a physics generator
dataset_path = dinv.datasets.generate_dataset(
    train_dataset=train_dataset,
    test_dataset=test_dataset,
    physics=physics,
    physics_generator=blur_generator,
    device=device,
    save_dir="Urban100",
    batch_size=1,
)

train_dataloader = DataLoader(
    dinv.datasets.HDF5Dataset(dataset_path, train=True), shuffle=True
)
test_dataloader = DataLoader(
    dinv.datasets.HDF5Dataset(dataset_path, train=False), shuffle=False
)

## Define models

We first define the reconstruction network (i.e. the conditional generator) and the
discriminator network to use for adversarial training. For demonstration
we use a simple U-Net as the reconstruction network and the
discriminator from [PatchGAN](https://arxiv.org/abs/1611.07004), but
these can be replaced with any architecture, e.g. transformers or unrolled
networks. Further discriminator models are in `adversarial models <adversarial-networks>`.
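A PatchGAN discriminator outputs a grid of real/fake scores, one per image patch, rather than a single score per image. The 1D sketch below makes the patches explicit with a sliding window and a hypothetical `score_patch` function; in the actual network the patches are implicit in the receptive field of its convolutional layers, so this only illustrates the idea.

```python
def patch_discriminator(signal, patch_size, stride, score_patch):
    """Return one score per (possibly overlapping) window of a 1D signal."""
    last_start = len(signal) - patch_size
    return [
        score_patch(signal[start:start + patch_size])
        for start in range(0, last_start + 1, stride)
    ]


def mean(values):
    return sum(values) / len(values)


# Toy signal whose right half is "realistic" (ones) and left half is not (zeros);
# the mean value of a patch is used as a stand-in realism score.
scores = patch_discriminator([0, 0, 0, 0, 1, 1, 1, 1], patch_size=4, stride=2, score_patch=mean)
print(scores)        # [0.0, 0.5, 1.0]
print(mean(scores))  # 0.5, a scalar an adversarial loss could average over
```

Scoring patches lets the discriminator penalise local artefacts wherever they appear, instead of judging the image only as a whole.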




In [None]:
def get_models(model=None, D=None, lr_g=1e-4, lr_d=1e-4, device=device):
    if model is None:
        model = dinv.models.UNet(
            in_channels=3,
            out_channels=3,
            scales=2,
            circular_padding=True,
            batch_norm=False,
        ).to(device)

    if D is None:
        D = dinv.models.PatchGANDiscriminator(n_layers=2, batch_norm=False).to(device)

    optimizer = dinv.training.adversarial.AdversarialOptimizer(
        torch.optim.Adam(model.parameters(), lr=lr_g, weight_decay=1e-8),
        torch.optim.Adam(D.parameters(), lr=lr_d, weight_decay=1e-8),
    )
    scheduler = dinv.training.adversarial.AdversarialScheduler(
        torch.optim.lr_scheduler.StepLR(optimizer.G, step_size=5, gamma=0.9),
        torch.optim.lr_scheduler.StepLR(optimizer.D, step_size=5, gamma=0.9),
    )

    return model, D, optimizer, scheduler
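Both schedulers above use `StepLR(step_size=5, gamma=0.9)`, which multiplies the learning rate by 0.9 every 5 scheduler steps. Assuming the scheduler is stepped once per epoch, the closed form below (a plain-Python sketch, not deepinv code) shows how far the rates decay over the 50-epoch pretraining used later.

```python
def step_lr(epoch, base_lr, step_size=5, gamma=0.9):
    """Learning rate after `epoch` scheduler steps under StepLR-style decay."""
    return base_lr * gamma ** (epoch // step_size)


for epoch in (0, 4, 5, 10, 50):
    print(epoch, step_lr(epoch, base_lr=1e-4))
```

After 50 epochs the rate is $0.9^{10}\approx 0.35$ times its initial value, so fine-tuning from a checkpoint starts at a noticeably smaller step size if the scheduler state is also restored.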

## Conditional GAN training

Conditional GANs (Kupyn et al., [DeblurGAN: Blind Motion Deblurring Using Conditional Adversarial Networks](https://openaccess.thecvf.com/content_cvpr_2018/papers/Kupyn_DeblurGAN_Blind_Motion_CVPR_2018_paper.pdf))
are a type of GAN where the generator is conditioned on a label or input. In the context of imaging,
this can be used to generate images from a given measurement. In this example, we use a simple U-Net as the generator
and a PatchGAN discriminator. The forward pass of the generator is given by:

**Conditional GAN** forward pass:

\begin{align}\hat x = G(y)\end{align}

**Conditional GAN** loss:

\begin{align}\mathcal{L}=\mathcal{L}_\text{sup}(\hat x, x)+\mathcal{L}_\text{adv}(\hat x, x;D)\end{align}

where $\mathcal{L}_\text{sup}$ is a supervised loss such as
pixel-wise MSE or a VGG perceptual loss.
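When updating $G$, only the $\hat x$ term of $\mathcal{L}_\text{adv}$ depends on the generator, so the generator objective reduces to the supervised loss plus $\mathbb{E}[q(1-D(\hat x))]$. Below is a plain-Python sketch with MSE as $\mathcal{L}_\text{sup}$. The `adv_weight` parameter is a hypothetical addition (the formula above corresponds to `adv_weight=1.0`), and this is not necessarily how `SupAdversarialGeneratorLoss` is computed internally.

```python
def mse(a, b):
    """Pixel-wise mean squared error, used here as L_sup."""
    return sum((u - v) ** 2 for u, v in zip(a, b)) / len(a)


def cgan_generator_loss(x_hat, x, d_fake, q=lambda t: t, adv_weight=1.0):
    """Generator objective L_sup(x_hat, x) + adv_weight * mean q(1 - D(x_hat)).

    x_hat:  reconstruction G(y), flattened
    x:      ground truth image, flattened
    d_fake: discriminator outputs D(x_hat)
    The E[q(D(x))] term of L_adv does not depend on G, so it is dropped here.
    """
    supervised = mse(x_hat, x)
    adversarial_term = sum(q(1.0 - d) for d in d_fake) / len(d_fake)
    return supervised + adv_weight * adversarial_term


# Same reconstruction error, but the second reconstruction fools the discriminator more
print(cgan_generator_loss([1.0, 2.0], [1.0, 0.0], d_fake=[0.5, 0.5]))
print(cgan_generator_loss([1.0, 2.0], [1.0, 0.0], d_fake=[0.9, 0.9]))
```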




In [None]:
G, D, optimizer, scheduler = get_models()

We next define the pixel-wise and adversarial losses described above. We use the
MSE as the supervised pixel-wise metric for simplicity, but this can easily be
replaced with a perceptual loss if desired.




In [None]:
loss_g = [
    dinv.loss.SupLoss(metric=torch.nn.MSELoss()),
    adversarial.SupAdversarialGeneratorLoss(device=device),
]
loss_d = adversarial.SupAdversarialDiscriminatorLoss(device=device)

We are now ready to train the networks using :meth:`deepinv.training.AdversarialTrainer`.
We load pretrained models that were trained in exactly the same way for 50 epochs,
and fine-tune the model for 1 epoch for a quick demo.
You can find the pretrained models on HuggingFace https://huggingface.co/deepinv/adversarial-demo.
To train from scratch, simply comment out the model loading code and increase the number of epochs.




In [None]:
ckpt = torch.hub.load_state_dict_from_url(
    dinv.models.utils.get_weights_url("adversarial-demo", "deblurgan_model.pth"),
    map_location=lambda s, _: s,
)

G.load_state_dict(ckpt["state_dict"])
D.load_state_dict(ckpt["state_dict_D"])
optimizer.load_state_dict(ckpt["optimizer"])

trainer = dinv.training.AdversarialTrainer(
    model=G,
    D=D,
    physics=physics,
    train_dataloader=train_dataloader,
    eval_dataloader=test_dataloader,
    epochs=1,
    losses=loss_g,
    losses_d=loss_d,
    optimizer=optimizer,
    scheduler=scheduler,
    verbose=True,
    show_progress_bar=False,
    save_path=None,
    device=device,
)

G = trainer.train()

Test the trained model and plot the results. We compare to the pseudo-inverse as a baseline.
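With circular boundary conditions, a blur is a circular convolution, which the discrete Fourier transform diagonalises; a pseudo-inverse therefore divides by the kernel's spectrum wherever it is non-zero and discards the remaining frequencies. The 1D plain-Python sketch below illustrates this baseline under that assumption; it is not necessarily how deepinv computes the pseudo-inverse of `Blur`, and the `eps` threshold is an assumed value.

```python
import cmath


def circular_blur(x, k):
    """Circular convolution y[i] = sum_j k[j] * x[(i - j) mod n] (blur with circular padding)."""
    n = len(x)
    return [sum(k[j] * x[(i - j) % n] for j in range(len(k))) for i in range(n)]


def dft(v):
    n = len(v)
    return [sum(v[t] * cmath.exp(-2j * cmath.pi * f * t / n) for t in range(n)) for f in range(n)]


def idft(spectrum):
    n = len(spectrum)
    return [
        (sum(spectrum[f] * cmath.exp(2j * cmath.pi * f * t / n) for f in range(n)) / n).real
        for t in range(n)
    ]


def pseudo_inverse_blur(y, k, eps=1e-8):
    """Apply the pseudo-inverse of circular blur by kernel k to measurements y."""
    n = len(y)
    K = dft(list(k) + [0.0] * (n - len(k)))  # kernel spectrum, zero-padded to the signal length
    Y = dft(y)
    # Invert the frequencies the blur keeps; frequencies it destroys are set to zero.
    X = [Yf / Kf if abs(Kf) > eps else 0.0 for Yf, Kf in zip(Y, K)]
    return idft(X)


x = [1.0, 2.0, 0.5, 3.0, -1.0, 0.0, 2.0, 1.0]
k = [0.6, 0.3, 0.1]  # spectrum never vanishes, so this blur is invertible
x_rec = pseudo_inverse_blur(circular_blur(x, k), k)
print(max(abs(a - b) for a, b in zip(x_rec, x)))  # only floating-point error remains
```

For kernels whose spectrum has zeros (e.g. `[0.5, 0.5]` on an even-length signal), the discarded frequencies cannot be recovered, which is one reason the pseudo-inverse is only a weak baseline for deblurring.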




In [None]:
trainer.plot_images = True
trainer.test(test_dataloader)

## UAIR training

Unsupervised Adversarial Image Reconstruction (UAIR) (Pajot et al.,
[Unsupervised Adversarial Image Reconstruction](https://openreview.net/forum?id=BJg4Z3RqF7))
is a method for learning to solve inverse problems without ground truth, by applying
an adversarial loss to the measurements. In this example, we use a simple U-Net as the generator
and a PatchGAN discriminator. The forward pass of the generator is defined as:

**UAIR** forward pass:

\begin{align}\hat x = G(y),\end{align}

**UAIR** loss:

\begin{align}\mathcal{L}=\mathcal{L}_\text{adv}(\hat y, y;D)+\lVert A(G(\hat y))- \hat y\rVert^2_2,\quad\hat y=A(\hat x),\end{align}

where $A(\cdot)$ is the physics.
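The sketch below evaluates the generator side of this objective on toy data. It assumes plain Python lists for measurements and hypothetical callables `G`, `A` and `D` standing in for the U-Net, the blur physics and the discriminator. The adversarial term asks $D$ to accept $\hat y$ as a real measurement, and the second term checks that reconstructing $\hat y$ and measuring it again gives $\hat y$ back. We take $\mathbb{E}[q(1-D(\hat y))]$ as the generator part of $\mathcal{L}_\text{adv}$, following the general definition at the top; deepinv's `UAIRGeneratorLoss` may weight or implement the terms differently.

```python
def uair_generator_loss(y_batch, G, A, D, q=lambda t: t):
    """Toy UAIR generator objective averaged over a batch of measurements."""
    adv_total, cycle_total = 0.0, 0.0
    for y in y_batch:
        y_hat = A(G(y))                 # re-simulated measurement of the reconstruction
        adv_total += q(1.0 - D(y_hat))  # D should mistake y_hat for a real measurement
        y_cycle = A(G(y_hat))           # reconstruct y_hat and measure again
        cycle_total += sum((c - h) ** 2 for c, h in zip(y_cycle, y_hat))
    n = len(y_batch)
    return adv_total / n + cycle_total / n


def A(v):
    return [0.5 * t for t in v]  # toy physics: halve every value


def G_inverse(v):
    return [2.0 * t for t in v]  # reconstruction that exactly undoes A


def G_identity(v):
    return list(v)  # poor reconstruction that ignores A


def D(v):
    return 0.5  # undecided discriminator


print(uair_generator_loss([[2.0, -4.0]], G_inverse, A, D))
print(uair_generator_loss([[2.0, -4.0]], G_identity, A, D))
```

The cycle term vanishes when `G` inverts `A` on measurements, so it penalises reconstructions that are inconsistent with the physics even when no ground truth is available.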

We next load the models and construct losses as defined above.



In [None]:
G, D, optimizer, scheduler = get_models(
    lr_g=1e-4, lr_d=4e-4
)  # learning rates from original paper

loss_g = adversarial.UAIRGeneratorLoss(device=device)
loss_d = adversarial.UnsupAdversarialDiscriminatorLoss(device=device)

We are now ready to train the networks using :meth:`deepinv.training.AdversarialTrainer`.
As above, we load a pretrained model trained in exactly the same way for 50 epochs,
and fine-tune here for a quick demo with 1 epoch.




In [None]:
ckpt = torch.hub.load_state_dict_from_url(
    dinv.models.utils.get_weights_url("adversarial-demo", "uair_model.pth"),
    map_location=lambda s, _: s,
)

G.load_state_dict(ckpt["state_dict"])
D.load_state_dict(ckpt["state_dict_D"])
optimizer.load_state_dict(ckpt["optimizer"])

trainer = dinv.training.AdversarialTrainer(
    model=G,
    D=D,
    physics=physics,
    train_dataloader=train_dataloader,
    eval_dataloader=test_dataloader,
    epochs=1,
    losses=loss_g,
    losses_d=loss_d,
    optimizer=optimizer,
    scheduler=scheduler,
    verbose=True,
    show_progress_bar=False,
    save_path=None,
    device=device,
)
G = trainer.train()

Test the trained model and plot the results:




In [None]:
trainer.plot_images = True
trainer.test(test_dataloader)

## CSGM / AmbientGAN training

Compressed Sensing using Generative Models (CSGM) and AmbientGAN are two methods for solving inverse problems
using generative models. CSGM (Bora et al., [Compressed Sensing using Generative Models](https://arxiv.org/abs/1703.03208)) solves the inverse problem by optimising over the latent
space of a trained generator. AmbientGAN (Bora et al., [AmbientGAN: Generative models from lossy measurements](https://openreview.net/forum?id=Hy7fDog0b)) trains the generator from measurements alone, by passing generated images
through the physics and applying the adversarial loss in measurement space. Both methods are trained using an adversarial loss; the main difference is that CSGM requires
a ground truth dataset (supervised loss), while AmbientGAN does not (unsupervised loss).

In this example, we use DCGAN architectures for the
generator and discriminator, and train using the adversarial loss. The forward pass of the generator is given by:

**CSGM** forward pass at train time:

\begin{align}\hat x = G(z),\quad z\sim \mathcal{N}(\mathbf{0},\mathbf{I}_k)\end{align}

**CSGM**/**AmbientGAN** forward pass at eval time, where $A(\cdot)$ is the physics:

\begin{align}\hat x = G(\hat z)\quad\text{s.t.}\quad\hat z=\operatorname*{argmin}_z \lVert A(G(z))-y\rVert _2^2\end{align}
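At evaluation time the generator weights are fixed and only the latent code is optimised. The sketch below assumes a toy *linear* generator $G(z)=Wz$ and a pixel-selecting physics $A$, both stored as plain nested lists, and runs gradient descent on $\lVert A(G(z))-y\rVert_2^2$ until it falls below a tolerance `tol`. In this example the real `CSGMGenerator` wraps a nonlinear DCGAN and relies on automatic differentiation, so its optimiser and stopping rule may differ.

```python
def matvec(M, v):
    """Matrix-vector product for a matrix stored as a list of rows."""
    return [sum(m * x for m, x in zip(row, v)) for row in M]


def transpose(M):
    return [list(col) for col in zip(*M)]


def csgm_invert(y, W, A, z0, lr=0.1, tol=1e-10, max_iter=10_000):
    """Find z minimising ||A W z - y||^2 by gradient descent (toy linear CSGM)."""
    # Row i of the composed operator A @ W is W^T applied to row i of A.
    M = [matvec(transpose(W), a_row) for a_row in A]
    Mt = transpose(M)
    z = list(z0)
    loss = float("inf")
    for _ in range(max_iter):
        residual = [p - t for p, t in zip(matvec(M, z), y)]
        loss = sum(r * r for r in residual)
        if loss < tol:  # stop once the measurement misfit is small enough
            break
        grad = [2.0 * g for g in matvec(Mt, residual)]  # gradient of the squared norm
        z = [zi - lr * gi for zi, gi in zip(z, grad)]
    return z, loss


W = [[1.0, 0.0], [0.0, 1.0], [1.0, 1.0]]  # toy generator: 2-dim latent -> 3-pixel image
A = [[1.0, 0.0, 0.0], [0.0, 0.0, 1.0]]    # toy physics: observe pixels 0 and 2 only
z_true = [1.0, -2.0]
y = matvec(A, matvec(W, z_true))           # measurements of the image G(z_true)

z_hat, final_loss = csgm_invert(y, W, A, z0=[0.0, 0.0])
x_hat = matvec(W, z_hat)                   # G(z_hat), including the unobserved pixel 1
print(z_hat, final_loss, x_hat)
```

The unobserved pixel is recovered only because the reconstruction is constrained to the range of the generator; a looser `tol` stops earlier and trades accuracy for speed, as in the demo.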

**CSGM** loss:

\begin{align}\mathcal{L}=\mathcal{L}_\text{adv}(\hat x, x;D)\end{align}

**AmbientGAN** loss (where $A(\cdot)$ is the physics):

\begin{align}\mathcal{L}=\mathcal{L}_\text{adv}(A(\hat x), y;D)\end{align}
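The key difference from CSGM is what the discriminator sees: in AmbientGAN, generated images are passed through the physics $A$ before reaching $D$, so $D$ only ever compares simulated measurements $A(\hat x)$ with real measurements $y$, and no clean images are needed. The plain-Python sketch below uses a hypothetical random-masking physics and hypothetical callables `G` and `D`, with $\mathbb{E}[q(1-D(A(\hat x)))]$ as the generator part of the loss; it is not deepinv's `UnsupAdversarialGeneratorLoss`.

```python
import random


def random_mask_physics(x, p_keep, rng):
    """Toy inpainting physics: keep each pixel with probability p_keep, zero it otherwise."""
    return [v if rng.random() < p_keep else 0.0 for v in x]


def ambientgan_generator_loss(G, D, physics, latents, q=lambda t: t):
    """Generator part of the AmbientGAN loss, estimated on a batch of latent codes."""
    simulated = [physics(G(z)) for z in latents]  # A(x_hat): D never sees x_hat itself
    return sum(q(1.0 - D(y_sim)) for y_sim in simulated) / len(latents)


rng = random.Random(0)


def G(z):
    return [z] * 4  # hypothetical generator: constant 4-pixel image


def D(y):
    return sum(1.0 for v in y if v != 0.0) / len(y)  # hypothetical score: fraction of observed pixels


loss = ambientgan_generator_loss(G, D, lambda x: random_mask_physics(x, 0.5, rng), latents=[1.0, 2.0, 3.0])
print(loss)
```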

We next load the models and construct losses as defined above.



In [None]:
G = dinv.models.CSGMGenerator(
    dinv.models.DCGANGenerator(output_size=128, nz=100, ngf=32), inf_tol=1e-2
).to(device)
D = dinv.models.DCGANDiscriminator(ndf=32).to(device)
_, _, optimizer, scheduler = get_models(
    model=G, D=D, lr_g=2e-4, lr_d=2e-4
)  # learning rates from original paper

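# For AmbientGAN (unsupervised: D compares simulated and real measurements):
loss_g = adversarial.UnsupAdversarialGeneratorLoss(device=device)
loss_d = adversarial.UnsupAdversarialDiscriminatorLoss(device=device)

# For CSGM (supervised: D compares generated and ground truth images).
# These lines overwrite the AmbientGAN losses above, so CSGM is what is trained
# below; comment them out to train AmbientGAN instead.
loss_g = adversarial.SupAdversarialGeneratorLoss(device=device)
loss_d = adversarial.SupAdversarialDiscriminatorLoss(device=device)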

As before, we can now train our models. Since inference is very
slow for CSGM/AmbientGAN, as it requires an optimisation, we do not pass an
evaluation dataloader to the trainer and only evaluate once at the end. Note that the train PSNR is
meaningless, as this generative model is trained on random latents.
As above, we load a pretrained model trained in exactly the same way for 50 epochs,
and fine-tune here for a quick demo with 1 epoch.
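PSNR compares a reconstruction with its own reference image, $\mathrm{PSNR}=10\log_{10}(\mathrm{MAX}^2/\mathrm{MSE})$. During CSGM/AmbientGAN training, $\hat x=G(z)$ is generated from a random latent and is unrelated to the reference images in the batch, so the reported value says nothing about reconstruction quality. A plain-Python sketch, assuming images scaled to $[0,1]$ as produced by `ToTensor`:

```python
import math
import random


def psnr(x_hat, x, max_val=1.0):
    """Peak signal-to-noise ratio in dB between two equally sized flat images."""
    mse = sum((a - b) ** 2 for a, b in zip(x_hat, x)) / len(x)
    return float("inf") if mse == 0 else 10.0 * math.log10(max_val ** 2 / mse)


rng = random.Random(0)
reference = [rng.random() for _ in range(64)]
good_reconstruction = [min(1.0, v + 0.01) for v in reference]  # close to the reference
random_sample = [rng.random() for _ in range(64)]               # unrelated, like G(z) at train time
print(psnr(good_reconstruction, reference), psnr(random_sample, reference))
```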




In [None]:
ckpt = torch.hub.load_state_dict_from_url(
    dinv.models.utils.get_weights_url("adversarial-demo", "csgm_model.pth"),
    map_location=lambda s, _: s,
)

G.load_state_dict(ckpt["state_dict"])
D.load_state_dict(ckpt["state_dict_D"])
optimizer.load_state_dict(ckpt["optimizer"])

trainer = dinv.training.AdversarialTrainer(
    model=G,
    D=D,
    physics=physics,
    train_dataloader=train_dataloader,
    epochs=1,
    losses=loss_g,
    losses_d=loss_d,
    optimizer=optimizer,
    scheduler=scheduler,
    verbose=True,
    show_progress_bar=False,
    save_path=None,
    device=device,
)
G = trainer.train()

Finally, we evaluate the generative model by running test-time optimisation
on the test measurements. Note that the results are not great, as CSGM /
AmbientGAN rely on large datasets of diverse samples, and we run the
optimisation to a relatively high tolerance for speed. The results can be improved by
running the optimisation for longer, e.g. with a lower `inf_tol` when constructing the `CSGMGenerator`.




In [None]:
trainer.test(test_dataloader)