Imaging inverse problems with adversarial networks

This example shows how to train various networks with adversarial training for deblurring problems. We demonstrate training and inference with a conditional GAN (i.e. DeblurGAN), CSGM, AmbientGAN and UAIR as implemented in the library, and how to train your own GAN using deepinv.training.AdversarialTrainer(). These examples can easily be extended to train more complicated GANs such as CycleGAN.

This example is based on the following papers:

- Kupyn et al., DeblurGAN: Blind Motion Deblurring Using Conditional Adversarial Networks
- Pajot et al., Unsupervised Adversarial Image Reconstruction
- Bora et al., Compressed Sensing using Generative Models
- Bora et al., AmbientGAN: Generative models from lossy measurements

Adversarial networks are characterised by the addition of an adversarial loss \(\mathcal{L}_\text{adv}\) to the standard reconstruction loss:

\[\mathcal{L}_\text{adv}(x,\hat x;D)=\mathbb{E}_{x\sim p_x}\left[q(D(x))\right]+\mathbb{E}_{\hat x\sim p_{\hat x}}\left[q(1-D(\hat x))\right]\]

where \(D(\cdot)\) is the discriminator model, \(x\) is the reference image, \(\hat x\) is the estimated reconstruction, and \(q(\cdot)\) is a quality function (e.g. \(q(x)=x\) for WGAN). Training alternates between the generator \(G\) and the discriminator \(D\) in a minimax game. When no ground truth is available (i.e. unsupervised), the loss may be defined on the measurements \(y\) instead.
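As a minimal sketch of this alternation (not the library's implementation), one training step with the WGAN choice \(q(x)=x\) might look as follows, assuming a generator G, discriminator D, their optimisers and a batch (x, y) are given:

def adversarial_step(G, D, optim_G, optim_D, x, y):
    # Discriminator update: raise D(x) on real images, lower D(G(y)) on reconstructions
    x_hat = G(y).detach()  # detach so no generator gradients are computed
    loss_d = -(D(x).mean() - D(x_hat).mean())  # maximise E[D(x)] - E[D(x_hat)]
    optim_D.zero_grad()
    loss_d.backward()
    optim_D.step()

    # Generator update: fool the discriminator
    x_hat = G(y)
    loss_g = -D(x_hat).mean()  # maximise E[D(x_hat)]
    optim_G.zero_grad()
    loss_g.backward()
    optim_G.step()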

import deepinv as dinv
from deepinv.loss import adversarial
from deepinv.physics.generator import MotionBlurGenerator
import torch
from torch.utils.data import DataLoader, random_split
from torchvision.datasets import ImageFolder
from torchvision.transforms import Compose, ToTensor, CenterCrop, Resize
from torchvision.datasets.utils import download_and_extract_archive

device = dinv.utils.get_freer_gpu() if torch.cuda.is_available() else "cpu"

Generate dataset

In this example we use the Urban100 dataset, resized and center-cropped to 128x128. We apply random motion blur physics using deepinv.physics.generator.MotionBlurGenerator(), and save the data using deepinv.datasets.generate_dataset().

physics = dinv.physics.Blur(padding="circular", device=device)
blur_generator = MotionBlurGenerator((11, 11))

dataset = dinv.datasets.Urban100HR(
    root="Urban100",
    download=True,
    transform=Compose([ToTensor(), Resize(256), CenterCrop(128)]),
)

train_dataset, test_dataset = random_split(dataset, (0.8, 0.2))

# Generate data pairs x,y offline using a physics generator
dataset_path = dinv.datasets.generate_dataset(
    train_dataset=train_dataset,
    test_dataset=test_dataset,
    physics=physics,
    physics_generator=blur_generator,
    device=device,
    save_dir="Urban100",
    batch_size=1,
)

train_dataloader = DataLoader(
    dinv.datasets.HDF5Dataset(dataset_path, train=True), shuffle=True
)
test_dataloader = DataLoader(
    dinv.datasets.HDF5Dataset(dataset_path, train=False), shuffle=False
)
Dataset has been successfully downloaded.
Dataset has been saved at Urban100/dinv_dataset0.h5

Define models

We first define the reconstruction network (i.e. conditional generator) and the discriminator network to use for adversarial training. For demonstration we use a simple U-Net as the reconstruction network and the PatchGAN discriminator, but these can be replaced with any architecture (e.g. transformers, unrolled networks, etc.). Further discriminator models are available in the library's adversarial models.

def get_models(model=None, D=None, lr_g=1e-4, lr_d=1e-4, device=device):
    if model is None:
        model = dinv.models.UNet(
            in_channels=3,
            out_channels=3,
            scales=2,
            circular_padding=True,
            batch_norm=False,
        ).to(device)

    if D is None:
        D = dinv.models.PatchGANDiscriminator(n_layers=2, batch_norm=False).to(device)

    optimizer = dinv.training.adversarial.AdversarialOptimizer(
        torch.optim.Adam(model.parameters(), lr=lr_g, weight_decay=1e-8),
        torch.optim.Adam(D.parameters(), lr=lr_d, weight_decay=1e-8),
    )
    scheduler = dinv.training.adversarial.AdversarialScheduler(
        torch.optim.lr_scheduler.StepLR(optimizer.G, step_size=5, gamma=0.9),
        torch.optim.lr_scheduler.StepLR(optimizer.D, step_size=5, gamma=0.9),
    )

    return model, D, optimizer, scheduler
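
The architectures passed to get_models are interchangeable; for instance, a sketch swapping in the DCGAN discriminator that is used in the CSGM section below:

# illustrative only: any discriminator module can be passed in
model, D, optimizer, scheduler = get_models(
    D=dinv.models.DCGANDiscriminator(ndf=32).to(device)
)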

Conditional GAN training

Conditional GANs (Kupyn et al., DeblurGAN: Blind Motion Deblurring Using Conditional Adversarial Networks) are a type of GAN where the generator is conditioned on a label or input. In the context of imaging, this can be used to generate images from a given measurement. In this example, we use a simple U-Net as the generator and a PatchGAN discriminator.

Conditional GAN forward pass:

\[\hat x = G(y)\]

Conditional GAN loss:

\[\mathcal{L}=\mathcal{L}_\text{sup}(\hat x, x)+\mathcal{L}_\text{adv}(\hat x, x;D)\]

where \(\mathcal{L}_\text{sup}\) is a supervised loss such as pixel-wise MSE or VGG Perceptual Loss.

G, D, optimizer, scheduler = get_models()

We next define the pixel-wise and adversarial losses as described above. We use MSE as the supervised pixel-wise metric for simplicity, but this can easily be replaced with a perceptual loss if desired.
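
A construction consistent with the loss names in the training log below (SupLoss and SupAdversarialGeneratorLoss):

# supervised pixel-wise loss + supervised adversarial loss for the generator
loss_g = [
    dinv.loss.SupLoss(metric=torch.nn.MSELoss()),
    adversarial.SupAdversarialGeneratorLoss(device=device),
]
# supervised adversarial loss for the discriminator
loss_d = adversarial.SupAdversarialDiscriminatorLoss(device=device)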

We are now ready to train the networks using deepinv.training.AdversarialTrainer(). For a quick demo, we load pretrained models that were trained in exactly this way for 50 epochs, and fine-tune them for 1 epoch. You can find the pretrained models on HuggingFace: https://huggingface.co/deepinv/adversarial-demo. To train from scratch, simply comment out the model loading code and increase the number of epochs.

ckpt = torch.hub.load_state_dict_from_url(
    dinv.models.utils.get_weights_url("adversarial-demo", "deblurgan_model.pth"),
    map_location=lambda s, _: s,
)

G.load_state_dict(ckpt["state_dict"])
D.load_state_dict(ckpt["state_dict_D"])
optimizer.load_state_dict(ckpt["optimizer"])

trainer = dinv.training.AdversarialTrainer(
    model=G,
    D=D,
    physics=physics,
    train_dataloader=train_dataloader,
    eval_dataloader=test_dataloader,
    epochs=1,
    losses=loss_g,
    losses_d=loss_d,
    optimizer=optimizer,
    scheduler=scheduler,
    verbose=True,
    show_progress_bar=False,
    save_path=None,
    device=device,
)

G = trainer.train()
Downloading: "https://huggingface.co/deepinv/adversarial-demo/resolve/main/deblurgan_model.pth?download=true" to /home/runner/.cache/torch/hub/checkpoints/deblurgan_model.pth

The model has 444867 trainable parameters
Train epoch 0: SupLoss=0.004, SupAdversarialGeneratorLoss=0.003, TotalLoss=0.006, PSNR=25.826
Eval epoch 0: PSNR=25.339

Test the trained model and plot the results. We compare to the pseudo-inverse as a baseline.

trainer.plot_images = True
trainer.test(test_dataloader)
[Plot: Ground truth, Measurement, No learning, Reconstruction]
Eval epoch 0: PSNR=25.339, PSNR no learning=22.129
Test results:
PSNR no learning: 22.129 +- 2.703
PSNR: 25.339 +- 3.741

{'PSNR no learning': np.float64(22.128802490234374), 'PSNR no learning_std': np.float64(2.703303237720839), 'PSNR': np.float64(25.33935546875), 'PSNR_std': np.float64(3.7408121787192714)}

UAIR training

Unsupervised Adversarial Image Reconstruction (UAIR) (Pajot et al., Unsupervised Adversarial Image Reconstruction) is a method for solving inverse problems using generative models without ground-truth data. In this example, we use a simple U-Net as the generator with a PatchGAN discriminator, as defined above, and train using the adversarial loss.

UAIR forward pass:

\[\hat x = G(y),\]

UAIR loss:

\[\mathcal{L}=\mathcal{L}_\text{adv}(\hat y, y;D)+\lVert A(G(\hat y))- \hat y\rVert^2_2,\quad\hat y=A(\hat x),\]

where \(A(\cdot)\) is the physics forward operator.

We next load the models and construct losses as defined above.

G, D, optimizer, scheduler = get_models(
    lr_g=1e-4, lr_d=4e-4
)  # learning rates from original paper

loss_g = adversarial.UAIRGeneratorLoss(device=device)
loss_d = adversarial.UnsupAdversarialDiscriminatorLoss(device=device)

We are now ready to train the networks using deepinv.training.AdversarialTrainer(). As above, we load a pretrained model trained in exactly the same way for 50 epochs, and fine-tune it for 1 epoch for a quick demo.

ckpt = torch.hub.load_state_dict_from_url(
    dinv.models.utils.get_weights_url("adversarial-demo", "uair_model.pth"),
    map_location=lambda s, _: s,
)

G.load_state_dict(ckpt["state_dict"])
D.load_state_dict(ckpt["state_dict_D"])
optimizer.load_state_dict(ckpt["optimizer"])

trainer = dinv.training.AdversarialTrainer(
    model=G,
    D=D,
    physics=physics,
    train_dataloader=train_dataloader,
    eval_dataloader=test_dataloader,
    epochs=1,
    losses=loss_g,
    losses_d=loss_d,
    optimizer=optimizer,
    scheduler=scheduler,
    verbose=True,
    show_progress_bar=False,
    save_path=None,
    device=device,
)
G = trainer.train()
Downloading: "https://huggingface.co/deepinv/adversarial-demo/resolve/main/uair_model.pth?download=true" to /home/runner/.cache/torch/hub/checkpoints/uair_model.pth

The model has 444867 trainable parameters
Train epoch 0: TotalLoss=0.143, PSNR=24.828
Eval epoch 0: PSNR=24.388

Test the trained model and plot the results:

trainer.plot_images = True
trainer.test(test_dataloader)
[Plot: Ground truth, Measurement, No learning, Reconstruction]
Eval epoch 0: PSNR=24.388, PSNR no learning=22.129
Test results:
PSNR no learning: 22.129 +- 2.703
PSNR: 24.388 +- 3.427

{'PSNR no learning': np.float64(22.128802490234374), 'PSNR no learning_std': np.float64(2.703303237720839), 'PSNR': np.float64(24.388037109375), 'PSNR_std': np.float64(3.4268960671244293)}

CSGM / AmbientGAN training

Compressed Sensing using Generative Models (CSGM) and AmbientGAN are two methods for solving inverse problems using generative models. CSGM (Bora et al., Compressed Sensing using Generative Models) trains a generator on ground-truth images and solves the inverse problem by optimising over the generator's latent space at test time. AmbientGAN (Bora et al., AmbientGAN: Generative models from lossy measurements) instead trains the generator directly from measurements, by passing the generated images through the physics before the discriminator. Both methods are trained using an adversarial loss; the main difference is that CSGM requires a ground-truth dataset (supervised loss), while AmbientGAN does not (unsupervised loss).

In this example, we use a DCGAN as the generator and discriminator, and train using the adversarial loss. The forward pass of the generator differs at train and eval time:

CSGM forward pass at train time:

\[\hat x = G(z),\quad z\sim \mathcal{N}(\mathbf{0},\mathbf{I}_k)\]

CSGM/AmbientGAN forward pass at eval time:

\[\hat x = G(\hat z)\quad\text{s.t.}\quad\hat z=\operatorname*{argmin}_z \lVert A(G(z))-y\rVert _2^2\]

CSGM loss:

\[\mathcal{L}=\mathcal{L}_\text{adv}(\hat x, x;D)\]

AmbientGAN loss (where \(A(\cdot)\) is the physics forward operator):

\[\mathcal{L}=\mathcal{L}_\text{adv}(A(\hat x), y;D)\]
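
To make the eval-time forward pass concrete, here is a rough standalone sketch of the latent optimisation (the CSGMGenerator used below handles this internally); G_net denotes a trained DCGAN backbone, and physics and y are assumed to be the physics and a test measurement:

z = torch.randn(1, 100, 1, 1, device=device, requires_grad=True)  # nz=100 latent
optim_z = torch.optim.Adam([z], lr=1e-2)
for _ in range(1000):  # run longer for better results
    optim_z.zero_grad()
    loss = (physics.A(G_net(z)) - y).pow(2).sum()  # measurement-space data fidelity
    loss.backward()
    optim_z.step()
x_hat = G_net(z).detach()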

We next load the models and construct losses as defined above.

G = dinv.models.CSGMGenerator(
    dinv.models.DCGANGenerator(output_size=128, nz=100, ngf=32), inf_tol=1e-2
).to(device)
D = dinv.models.DCGANDiscriminator(ndf=32).to(device)
_, _, optimizer, scheduler = get_models(
    model=G, D=D, lr_g=2e-4, lr_d=2e-4
)  # learning rates from original paper

# For AmbientGAN (unsupervised), use:
loss_g = adversarial.UnsupAdversarialGeneratorLoss(device=device)
loss_d = adversarial.UnsupAdversarialDiscriminatorLoss(device=device)

# For CSGM (supervised), use the following (this overrides the above;
# keep only the pair you want to train with):
loss_g = adversarial.SupAdversarialGeneratorLoss(device=device)
loss_d = adversarial.SupAdversarialDiscriminatorLoss(device=device)

As before, we can now train our models. Since inference is very slow for CSGM/AmbientGAN, as it requires an optimisation over the latent at test time, we only run one evaluation at the end. Note that the train PSNR is meaningless, since during training the generator is fed random latents rather than the measurements. As above, we load a pretrained model trained in exactly the same way for 50 epochs, and fine-tune it for 1 epoch for a quick demo.

ckpt = torch.hub.load_state_dict_from_url(
    dinv.models.utils.get_weights_url("adversarial-demo", "csgm_model.pth"),
    map_location=lambda s, _: s,
)

G.load_state_dict(ckpt["state_dict"])
D.load_state_dict(ckpt["state_dict_D"])
optimizer.load_state_dict(ckpt["optimizer"])

trainer = dinv.training.AdversarialTrainer(
    model=G,
    D=D,
    physics=physics,
    train_dataloader=train_dataloader,
    epochs=1,
    losses=loss_g,
    losses_d=loss_d,
    optimizer=optimizer,
    scheduler=scheduler,
    verbose=True,
    show_progress_bar=False,
    save_path=None,
    device=device,
)
G = trainer.train()
Downloading: "https://huggingface.co/deepinv/adversarial-demo/resolve/main/csgm_model.pth?download=true" to /home/runner/.cache/torch/hub/checkpoints/csgm_model.pth

The model has 3608000 trainable parameters
Train epoch 0: TotalLoss=0.007, PSNR=9.163

Finally, we evaluate the generative model by running the test-time optimisation on the test measurements. Note that we do not get great results, since CSGM/AmbientGAN relies on large datasets of diverse samples, and we run the optimisation to a relatively loose tolerance for speed. You can improve the results by running the optimisation for longer (a sketch follows the results below).

trainer.test(test_dataloader)
Eval epoch 0: PSNR=9.528, PSNR no learning=22.129
Test results:
PSNR no learning: 22.129 +- 2.703
PSNR: 9.528 +- 1.301

{'PSNR no learning': np.float64(22.128802490234374), 'PSNR no learning_std': np.float64(2.703303237720839), 'PSNR': np.float64(9.527952575683594), 'PSNR_std': np.float64(1.301279076545696)}
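
For instance, a sketch of trading speed for reconstruction quality by rebuilding the generator wrapper with a stricter inference tolerance (the inf_tol argument used above), reusing the trained weights:

# stricter tolerance makes the test-time optimisation run for longer
G_slow = dinv.models.CSGMGenerator(
    dinv.models.DCGANGenerator(output_size=128, nz=100, ngf=32), inf_tol=1e-4
).to(device)
G_slow.load_state_dict(G.state_dict())  # reuse the weights trained above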

Total running time of the script: (1 minute 29.218 seconds)
