.. DO NOT EDIT.
.. THIS FILE WAS AUTOMATICALLY GENERATED BY SPHINX-GALLERY.
.. TO MAKE CHANGES, EDIT THE SOURCE PYTHON FILE:
.. "auto_examples/adversarial-learning/demo_gan_imaging.py"
.. LINE NUMBERS ARE GIVEN BELOW.

.. only:: html

    .. note::
        :class: sphx-glr-download-link-note

        New to DeepInverse? Get started with the basics with the :ref:`5 minute quickstart tutorial `.

.. rst-class:: sphx-glr-example-title

.. _sphx_glr_auto_examples_adversarial-learning_demo_gan_imaging.py:

Imaging inverse problems with adversarial networks
==================================================

This example shows you how to train various networks using adversarial training for deblurring
problems. We demonstrate running training and inference using a conditional GAN (i.e. DeblurGAN),
CSGM, AmbientGAN and UAIR implemented in the library, and how to simply train your own GAN by
using :class:`deepinv.training.AdversarialTrainer`. These examples can also be easily extended to
train more complicated GANs such as CycleGAN.

This example is based on the papers DeblurGAN :footcite:p:`kupyn2018deblurgan`, Compressed Sensing
using Generative Models (CSGM) :footcite:p:`bora2017compressed`, AmbientGAN
:footcite:p:`bora2018ambientgan`, and Unsupervised Adversarial Image Reconstruction (UAIR)
:footcite:p:`pajot2019unsupervised`.

Adversarial networks are characterized by the addition of an adversarial loss
:math:`\mathcal{L}_\text{adv}` to the standard reconstruction loss:

.. math:: \mathcal{L}_\text{adv}(x,\hat x;D)=\mathbb{E}_{x\sim p_x}\left[q(D(x))\right]+\mathbb{E}_{\hat x\sim p_{\hat x}}\left[q(1-D(\hat x))\right]

where :math:`D(\cdot)` is the discriminator model, :math:`x` is the reference image,
:math:`\hat x` is the estimated reconstruction, and :math:`q(\cdot)` is a quality function
(e.g. :math:`q(x)=x` for WGAN). Training alternates between the generator :math:`G` and the
discriminator :math:`D` in a minimax game. When there are no ground truths (i.e. unsupervised),
the adversarial loss may be defined on the measurements :math:`y` instead.
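To make the minimax game concrete, here is a minimal sketch of one alternating update with
:math:`q(x)=x` (WGAN-style) and a supervised pixel loss. The names ``G``, ``D``, ``x``, ``y``,
``opt_G`` and ``opt_D`` are placeholders for illustration only; the
:class:`deepinv.training.AdversarialTrainer` used throughout this example automates this loop
for you.

.. code-block:: Python

    import torch
    import torch.nn.functional as F

    def adversarial_step(G, D, x, y, opt_G, opt_D):
        # Discriminator update: push D(x) up and D(x_hat) down.
        x_hat = G(y).detach()  # detach so the generator is not updated here
        loss_D = -(D(x).mean() - D(x_hat).mean())
        opt_D.zero_grad()
        loss_D.backward()
        opt_D.step()

        # Generator update: fool the discriminator while staying close to x.
        x_hat = G(y)
        loss_G = F.mse_loss(x_hat, x) - D(x_hat).mean()
        opt_G.zero_grad()
        loss_G.backward()
        opt_G.step()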
.. GENERATED FROM PYTHON SOURCE LINES 31-50

.. code-block:: Python

    from pathlib import Path

    import torch
    from torch.utils.data import DataLoader, random_split
    from torchvision.transforms import Compose, ToTensor, CenterCrop, Resize

    import deepinv as dinv
    from deepinv.loss import adversarial
    from deepinv.utils import get_data_home
    from deepinv.physics.generator import MotionBlurGenerator

    device = dinv.utils.get_freer_gpu() if torch.cuda.is_available() else "cpu"

    BASE_DIR = Path(".")
    DATA_DIR = BASE_DIR / "measurements"
    ORIGINAL_DATA_DIR = get_data_home() / "Urban100"

.. rst-class:: sphx-glr-script-out

.. code-block:: none

    Selected GPU 0 with 1765.25 MiB free memory

.. GENERATED FROM PYTHON SOURCE LINES 51-58

Generate dataset
~~~~~~~~~~~~~~~~

In this example we use the Urban100 dataset, resized and center-cropped to 128x128. We apply
random motion blur physics using :class:`deepinv.physics.generator.MotionBlurGenerator`, and save
the data using :func:`deepinv.datasets.generate_dataset`.

.. GENERATED FROM PYTHON SOURCE LINES 58-95

.. code-block:: Python

    physics = dinv.physics.Blur(padding="circular", device=device)
    blur_generator = MotionBlurGenerator((11, 11), device=device)

    dataset = dinv.datasets.Urban100HR(
        root=ORIGINAL_DATA_DIR,
        download=True,
        transform=Compose([ToTensor(), Resize(256), CenterCrop(128)]),
    )

    train_dataset, test_dataset = random_split(dataset, (0.8, 0.2))

    # Generate data pairs x, y offline using a physics generator
    dataset_path = dinv.datasets.generate_dataset(
        train_dataset=train_dataset,
        test_dataset=test_dataset,
        physics=physics,
        physics_generator=blur_generator,
        device=device,
        save_dir=DATA_DIR,
        batch_size=1,
    )

    train_dataloader = DataLoader(
        dinv.datasets.HDF5Dataset(
            dataset_path, train=True, load_physics_generator_params=True
        ),
        shuffle=True,
    )
    test_dataloader = DataLoader(
        dinv.datasets.HDF5Dataset(
            dataset_path, train=False, load_physics_generator_params=True
        ),
        shuffle=False,
    )

.. rst-class:: sphx-glr-script-out

.. code-block:: none

    0%|          | 0/135388067 [00:00<?, ?B/s]
    (remaining download and dataset generation progress output truncated)

Define models
~~~~~~~~~~~~~

We next define the generator (i.e. reconstruction network) and discriminator models, along with
their optimizers and learning-rate schedulers, using
:class:`deepinv.training.adversarial.AdversarialOptimizer` and
:class:`deepinv.training.adversarial.AdversarialScheduler`. Here we use a simple U-Net as the
generator and a PatchGAN discriminator.

.. GENERATED FROM PYTHON SOURCE LINES 106-133

.. code-block:: Python

    def get_models(model=None, D=None, lr_g=1e-4, lr_d=1e-4, device=device):
        if model is None:
            model = dinv.models.UNet(
                in_channels=3,
                out_channels=3,
                scales=2,
                circular_padding=True,
                batch_norm=False,
            ).to(device)

        if D is None:
            D = dinv.models.PatchGANDiscriminator(n_layers=2, batch_norm=False).to(device)

        optimizer = dinv.training.adversarial.AdversarialOptimizer(
            torch.optim.Adam(model.parameters(), lr=lr_g, weight_decay=1e-8),
            torch.optim.Adam(D.parameters(), lr=lr_d, weight_decay=1e-8),
        )
        scheduler = dinv.training.adversarial.AdversarialScheduler(
            torch.optim.lr_scheduler.StepLR(optimizer.G, step_size=5, gamma=0.9),
            torch.optim.lr_scheduler.StepLR(optimizer.D, step_size=5, gamma=0.9),
        )
        return model, D, optimizer, scheduler

.. GENERATED FROM PYTHON SOURCE LINES 134-153

Conditional GAN training
~~~~~~~~~~~~~~~~~~~~~~~~

Conditional GANs :footcite:p:`kupyn2018deblurgan` are a type of GAN where the generator is
conditioned on a label or input. In the context of imaging, this can be used to generate images
from a given measurement. In this example, we use a simple U-Net as the generator and a PatchGAN
discriminator. The forward pass of the generator is given by:

**Conditional GAN** forward pass:

.. math:: \hat x = G(y)

**Conditional GAN** loss:

.. math:: \mathcal{L}=\mathcal{L}_\text{sup}(\hat x, x)+\mathcal{L}_\text{adv}(\hat x, x;D)

where :math:`\mathcal{L}_\text{sup}` is a supervised loss such as pixel-wise MSE or a VGG
perceptual loss.

.. GENERATED FROM PYTHON SOURCE LINES 153-157

.. code-block:: Python

    G, D, optimizer, scheduler = get_models()

.. GENERATED FROM PYTHON SOURCE LINES 158-162

We next construct the pixel-wise and adversarial losses defined above. We use MSE as the
supervised pixel-wise metric for simplicity, but this can easily be replaced with a perceptual
loss if desired (see the sketch after the code below).

.. GENERATED FROM PYTHON SOURCE LINES 162-170

.. code-block:: Python

    loss_g = [
        dinv.loss.SupLoss(metric=torch.nn.MSELoss()),
        adversarial.SupAdversarialGeneratorLoss(device=device),
    ]
    loss_d = adversarial.SupAdversarialDiscriminatorLoss(device=device)
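For example, a perceptual metric can be swapped in where ``torch.nn.MSELoss()`` is used above.
The following sketch defines a hypothetical ``VGGPerceptualMetric`` (not part of the library)
built on pretrained ``torchvision`` VGG16 features, assuming such a callable
``metric(x_hat, x)`` can be passed to :class:`deepinv.loss.SupLoss` in the same way as the MSE
metric:

.. code-block:: Python

    import torch
    import torch.nn.functional as F
    import torchvision

    class VGGPerceptualMetric(torch.nn.Module):
        """Hypothetical perceptual metric: MSE between frozen VGG16 features."""

        def __init__(self, n_layers=16):
            super().__init__()
            vgg = torchvision.models.vgg16(weights="IMAGENET1K_V1")
            self.features = vgg.features[:n_layers].eval().requires_grad_(False)

        def forward(self, x_hat, x):
            return F.mse_loss(self.features(x_hat), self.features(x))

    # e.g. dinv.loss.SupLoss(metric=VGGPerceptualMetric().to(device))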
.. GENERATED FROM PYTHON SOURCE LINES 171-177

We are now ready to train the networks using :class:`deepinv.training.AdversarialTrainer`. We
load pretrained models that were trained in exactly the same way for 50 epochs, and fine-tune
them for 1 epoch for a quick demo. You can find the pretrained models on HuggingFace:
https://huggingface.co/deepinv/adversarial-demo. To train from scratch, simply comment out the
model loading code and increase the number of epochs.

.. GENERATED FROM PYTHON SOURCE LINES 177-206

.. code-block:: Python

    ckpt = torch.hub.load_state_dict_from_url(
        dinv.models.utils.get_weights_url("adversarial-demo", "deblurgan_model.pth"),
        map_location=lambda s, _: s,
    )

    G.load_state_dict(ckpt["state_dict"])
    D.load_state_dict(ckpt["state_dict_D"])
    optimizer.load_state_dict(ckpt["optimizer"])

    trainer = dinv.training.AdversarialTrainer(
        model=G,
        D=D,
        physics=physics,
        train_dataloader=train_dataloader,
        eval_dataloader=test_dataloader,
        epochs=1,
        losses=loss_g,
        losses_d=loss_d,
        optimizer=optimizer,
        scheduler=scheduler,
        verbose=True,
        show_progress_bar=False,
        save_path=None,
        device=device,
    )

    G = trainer.train()

.. rst-class:: sphx-glr-script-out

.. code-block:: none

    /local/jtachell/deepinv/deepinv/deepinv/training/adversarial.py:165: UserWarning: non_blocking_transfers=True but DataLoader.pin_memory=False; set pin_memory=True to overlap host-device copies with compute.
      super().setup_train(**kwargs)
    The model has 444867 trainable parameters
    /local/jtachell/deepinv/deepinv/deepinv/training/adversarial.py:168: UserWarning: optimizer_step_multi_dataset parameter of Trainer is should be set to `False` when using adversarial trainer. Automatically setting it to `False`.
      warnings.warn(
    Train epoch 0: SupLoss=0.005, SupAdversarialGeneratorLoss=0.003, TotalLoss=0.008, PSNR=24.614
    Eval epoch 0: PSNR=25.977
    Best model saved at epoch 1

.. GENERATED FROM PYTHON SOURCE LINES 207-209

Test the trained model and plot the results. We compare to the pseudo-inverse as a baseline.

.. GENERATED FROM PYTHON SOURCE LINES 209-214

.. code-block:: Python

    trainer.plot_images = True
    trainer.test(test_dataloader)

.. image-sg:: /auto_examples/adversarial-learning/images/sphx_glr_demo_gan_imaging_001.png
   :alt: Ground truth, Measurement, No learning, Reconstruction
   :srcset: /auto_examples/adversarial-learning/images/sphx_glr_demo_gan_imaging_001.png
   :class: sphx-glr-single-img

.. rst-class:: sphx-glr-script-out

.. code-block:: none

    Eval epoch 0: PSNR=25.977, PSNR no learning=24.797
    Test results:
    PSNR no learning: 24.797 +- 5.002
    PSNR: 25.977 +- 3.971

    {'PSNR no learning': 24.79730405807495, 'PSNR no learning_std': 5.0017257362374234, 'PSNR': 25.976995182037353, 'PSNR_std': 3.9707933192112295}

.. GENERATED FROM PYTHON SOURCE LINES 215-232

UAIR training
~~~~~~~~~~~~~

Unsupervised Adversarial Image Reconstruction (UAIR) :footcite:p:`pajot2019unsupervised` is a
method for solving inverse problems using generative models without ground truth. As before, we
use a simple U-Net as the generator and a PatchGAN discriminator, and train using the adversarial
loss. The forward pass of the generator is defined as:

**UAIR** forward pass:

.. math:: \hat x = G(y),

**UAIR** loss:

.. math:: \mathcal{L}=\mathcal{L}_\text{adv}(\hat y, y;D)+\lVert \forw{\inverse{\hat y}}-\hat y\rVert^2_2,\quad\hat y=\forw{\hat x}.

We next load the models and construct the losses as defined above.

.. GENERATED FROM PYTHON SOURCE LINES 232-241

.. code-block:: Python

    G, D, optimizer, scheduler = get_models(
        lr_g=1e-4, lr_d=4e-4
    )  # learning rates from original paper

    loss_g = adversarial.UAIRGeneratorLoss(device=device)
    loss_d = adversarial.UnsupAdversarialDiscriminatorLoss(device=device)
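To unpack the UAIR loss, the sketch below shows conceptually what the generator objective
computes, assuming :math:`q(x)=x` and omitting the weighting between terms; ``physics.A`` denotes
the forward operator :math:`\forw{\cdot}`, and the
:class:`deepinv.loss.adversarial.UAIRGeneratorLoss` used above implements the full version.

.. code-block:: Python

    def uair_generator_objective(G, D, y, physics):
        x_hat = G(y)              # reconstruction from the measurement
        y_hat = physics.A(x_hat)  # re-measured reconstruction
        adv = -D(y_hat).mean()    # adversarial term: y_hat should look like a real y
        mc = ((physics.A(G(y_hat)) - y_hat) ** 2).mean()  # measurement consistency
        return adv + mc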
.. GENERATED FROM PYTHON SOURCE LINES 242-246

We are now ready to train the networks using :class:`deepinv.training.AdversarialTrainer`. As
above, we load a pretrained model trained in exactly the same way for 50 epochs, and fine-tune it
for 1 epoch for a quick demo.

.. GENERATED FROM PYTHON SOURCE LINES 246-274

.. code-block:: Python

    ckpt = torch.hub.load_state_dict_from_url(
        dinv.models.utils.get_weights_url("adversarial-demo", "uair_model.pth"),
        map_location=lambda s, _: s,
    )

    G.load_state_dict(ckpt["state_dict"])
    D.load_state_dict(ckpt["state_dict_D"])
    optimizer.load_state_dict(ckpt["optimizer"])

    trainer = dinv.training.AdversarialTrainer(
        model=G,
        D=D,
        physics=physics,
        train_dataloader=train_dataloader,
        eval_dataloader=test_dataloader,
        epochs=1,
        losses=loss_g,
        losses_d=loss_d,
        optimizer=optimizer,
        scheduler=scheduler,
        verbose=True,
        show_progress_bar=False,
        save_path=None,
        device=device,
    )

    G = trainer.train()

.. rst-class:: sphx-glr-script-out

.. code-block:: none

    The model has 444867 trainable parameters
    Train epoch 0: TotalLoss=0.137, PSNR=24.009
    Eval epoch 0: PSNR=25.274
    Best model saved at epoch 1

.. GENERATED FROM PYTHON SOURCE LINES 275-277

Test the trained model and plot the results:

.. GENERATED FROM PYTHON SOURCE LINES 277-282

.. code-block:: Python

    trainer.plot_images = True
    trainer.test(test_dataloader)

.. image-sg:: /auto_examples/adversarial-learning/images/sphx_glr_demo_gan_imaging_002.png
   :alt: Ground truth, Measurement, No learning, Reconstruction
   :srcset: /auto_examples/adversarial-learning/images/sphx_glr_demo_gan_imaging_002.png
   :class: sphx-glr-single-img

.. rst-class:: sphx-glr-script-out

.. code-block:: none

    Eval epoch 0: PSNR=25.274, PSNR no learning=24.797
    Test results:
    PSNR no learning: 24.797 +- 5.002
    PSNR: 25.274 +- 3.547

    {'PSNR no learning': 24.79730405807495, 'PSNR no learning_std': 5.0017257362374234, 'PSNR': 25.273712825775146, 'PSNR_std': 3.5472013802552587}

.. GENERATED FROM PYTHON SOURCE LINES 283-312

CSGM / AmbientGAN training
~~~~~~~~~~~~~~~~~~~~~~~~~~

Compressed Sensing using Generative Models (CSGM) :footcite:p:`bora2017compressed` and AmbientGAN
:footcite:p:`bora2018ambientgan` are two methods for solving inverse problems using generative
models. CSGM trains the generator adversarially against ground-truth images, and solves the
inverse problem at test time by optimising over the generator's latent space. AmbientGAN instead
trains the generator directly from measurements, by comparing simulated measurements of generated
images against real measurements. Both methods are trained using an adversarial loss; the main
difference is that CSGM requires a ground-truth dataset (supervised loss), while AmbientGAN does
not (unsupervised loss).

In this example, we use a DCGAN as the generator and discriminator, and train using the
adversarial loss. The forward pass of the generator is given by:

**CSGM** forward pass at train time:

.. math:: \hat x = \inverse{z},\quad z\sim \mathcal{N}(\mathbf{0},\mathbf{I}_k)

**CSGM**/**AmbientGAN** forward pass at eval time:

.. math:: \hat x = \inverse{\hat z}\quad\text{s.t.}\quad\hat z=\operatorname*{argmin}_z \lVert \forw{\inverse{z}}-y\rVert _2^2

**CSGM** loss:

.. math:: \mathcal{L}=\mathcal{L}_\text{adv}(\hat x, x;D)

**AmbientGAN** loss (where :math:`\forw{\cdot}` is the physics):

.. math:: \mathcal{L}=\mathcal{L}_\text{adv}(\forw{\hat x}, y;D)
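The eval-time latent optimisation is handled internally by :class:`deepinv.models.CSGMGenerator`
(used below). As a rough sketch of what it involves, assuming a DCGAN-style generator taking
latents of shape ``(B, nz, 1, 1)`` (all names and hyperparameters here are illustrative
placeholders):

.. code-block:: Python

    import torch

    def latent_optimize(generator, physics, y, nz=100, steps=1000, lr=1e-2):
        # z_hat = argmin_z || A(G(z)) - y ||_2^2, starting from a random latent
        z = torch.randn(y.shape[0], nz, 1, 1, device=y.device, requires_grad=True)
        optim = torch.optim.Adam([z], lr=lr)
        for _ in range(steps):
            loss = ((physics.A(generator(z)) - y) ** 2).mean()
            optim.zero_grad()
            loss.backward()
            optim.step()
        return generator(z.detach())  # final reconstruction x_hat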
We next load the models and construct the losses as defined above.

.. GENERATED FROM PYTHON SOURCE LINES 312-330

.. code-block:: Python

    G = dinv.models.CSGMGenerator(
        dinv.models.DCGANGenerator(output_size=128, nz=100, ngf=32), inf_tol=1e-2
    ).to(device)
    D = dinv.models.DCGANDiscriminator(ndf=32).to(device)

    _, _, optimizer, scheduler = get_models(
        model=G, D=D, lr_g=2e-4, lr_d=2e-4
    )  # learning rates from original paper

    # For AmbientGAN:
    loss_g = adversarial.UnsupAdversarialGeneratorLoss(device=device)
    loss_d = adversarial.UnsupAdversarialDiscriminatorLoss(device=device)

    # For CSGM:
    loss_g = adversarial.SupAdversarialGeneratorLoss(device=device)
    loss_d = adversarial.SupAdversarialDiscriminatorLoss(device=device)

.. GENERATED FROM PYTHON SOURCE LINES 331-338

As before, we can now train our models. Since inference is very slow for CSGM/AmbientGAN, as it
requires an optimisation at test time, we only evaluate once at the end. Note that the train PSNR
is meaningless, as this generative model is trained on random latents. As above, we load a
pretrained model trained in exactly the same way for 50 epochs, and fine-tune it for 1 epoch for
a quick demo.

.. GENERATED FROM PYTHON SOURCE LINES 338-366

.. code-block:: Python

    ckpt = torch.hub.load_state_dict_from_url(
        dinv.models.utils.get_weights_url("adversarial-demo", "csgm_model.pth"),
        map_location=lambda s, _: s,
    )

    G.load_state_dict(ckpt["state_dict"])
    D.load_state_dict(ckpt["state_dict_D"])
    optimizer.load_state_dict(ckpt["optimizer"])

    trainer = dinv.training.AdversarialTrainer(
        model=G,
        D=D,
        physics=physics,
        train_dataloader=train_dataloader,
        epochs=1,
        losses=loss_g,
        losses_d=loss_d,
        optimizer=optimizer,
        scheduler=scheduler,
        verbose=True,
        show_progress_bar=False,
        save_path=None,
        device=device,
    )

    G = trainer.train()

.. rst-class:: sphx-glr-script-out

.. code-block:: none

    The model has 3608000 trainable parameters
    Train epoch 0: TotalLoss=0.007, PSNR=9.108

.. GENERATED FROM PYTHON SOURCE LINES 367-373

Finally, we evaluate the generative model by running the test-time optimisation on the test
measurements. Note that the results are not great, as CSGM/AmbientGAN relies on large datasets of
diverse samples, and we run the optimisation to a relatively loose tolerance for speed. Running
the optimisation for longer improves the results.

.. GENERATED FROM PYTHON SOURCE LINES 373-376

.. code-block:: Python

    trainer.test(test_dataloader)

.. rst-class:: sphx-glr-script-out

.. code-block:: none

    Eval epoch 0: PSNR=10.313, PSNR no learning=24.797
    Test results:
    PSNR no learning: 24.797 +- 5.002
    PSNR: 10.313 +- 0.854

    {'PSNR no learning': 24.79730405807495, 'PSNR no learning_std': 5.0017257362374234, 'PSNR': 10.31317310333252, 'PSNR_std': 0.8543547880178533}

.. GENERATED FROM PYTHON SOURCE LINES 377-380

:References:

.. footbibliography::

.. rst-class:: sphx-glr-timing

   **Total running time of the script:** (0 minutes 17.729 seconds)

.. _sphx_glr_download_auto_examples_adversarial-learning_demo_gan_imaging.py:

.. only:: html

  .. container:: sphx-glr-footer sphx-glr-footer-example

    .. container:: sphx-glr-download sphx-glr-download-jupyter

      :download:`Download Jupyter notebook: demo_gan_imaging.ipynb <demo_gan_imaging.ipynb>`

    .. container:: sphx-glr-download sphx-glr-download-python

      :download:`Download Python source code: demo_gan_imaging.py <demo_gan_imaging.py>`

    .. container:: sphx-glr-download sphx-glr-download-zip

      :download:`Download zipped: demo_gan_imaging.zip <demo_gan_imaging.zip>`

.. only:: html

 .. rst-class:: sphx-glr-signature

    `Gallery generated by Sphinx-Gallery <https://sphinx-gallery.github.io>`_