AdversarialTrainer

class deepinv.training.AdversarialTrainer(model: torch.nn.Module, physics: deepinv.physics.Physics | List[deepinv.physics.Physics], optimizer: AdversarialOptimizer, train_dataloader: torch.utils.data.DataLoader, epochs: int = 100, losses: deepinv.loss.Loss | deepinv.loss.scheduler.BaseLossScheduler | List[deepinv.loss.Loss] | List[deepinv.loss.scheduler.BaseLossScheduler] = SupLoss(metric=MSELoss()), eval_dataloader: torch.utils.data.DataLoader = None, scheduler: torch.optim.lr_scheduler.LRScheduler = None, metrics: deepinv.loss.metric.Metric | List[deepinv.loss.metric.Metric] = PSNR(), online_measurements: bool = False, physics_generator: deepinv.physics.generator.PhysicsGenerator | List[deepinv.physics.generator.PhysicsGenerator] = None, grad_clip: float = None, ckp_interval: int = 1, device: str | torch.device = 'cpu', eval_interval: int = 1, save_path: str | pathlib.Path = '.', verbose: bool = True, show_progress_bar: bool = True, plot_images: bool = False, plot_convergence_metrics: bool = False, wandb_vis: bool = False, wandb_setup: dict = <factory>, plot_measurements: bool = True, check_grad: bool = False, ckpt_pretrained: str | None = None, freq_plot: int = 1, verbose_individual_losses: bool = True, display_losses_eval: bool = False, rescale_mode: str = 'clip', compare_no_learning: bool = False, no_learning_method: str = 'A_adjoint', loop_physics_generator: bool = False, losses_d: Loss | List[Loss] = None, D: torch.nn.Module = None, step_ratio_D: int = 1)[source]

Bases: Trainer

Trainer class for training a reconstruction network using adversarial learning.

It overrides the deepinv.Trainer class to provide the same functionality while additionally supporting training with adversarial losses. Note that the forward pass remains the same.

The usual reconstruction model corresponds to the generator model in an adversarial framework, which is trained using losses specified in the losses argument. Additionally, a discriminator model D is also jointly trained using the losses provided in losses_d. The adversarial losses themselves are defined in the deepinv.loss.adversarial module. Examples of discriminators are in deepinv.models.gan.

See Imaging inverse problems with adversarial networks for usage.


Examples:

A very basic example:

>>> import torch
>>> from deepinv.training import AdversarialTrainer, AdversarialOptimizer
>>> from deepinv.loss.adversarial import SupAdversarialGeneratorLoss, SupAdversarialDiscriminatorLoss
>>> from deepinv.models import UNet, PatchGANDiscriminator
>>> from deepinv.physics import LinearPhysics
>>> from deepinv.datasets.utils import PlaceholderDataset
>>>
>>> generator = UNet(scales=2)
>>> discrimin = PatchGANDiscriminator(1, 2, 1)
>>>
>>> optimizer = AdversarialOptimizer(
...     torch.optim.Adam(generator.parameters()),
...     torch.optim.Adam(discrimin.parameters()),
... )
>>>
>>> trainer = AdversarialTrainer(
...     model = generator,
...     D = discrimin,
...     physics = LinearPhysics(),
...     train_dataloader = torch.utils.data.DataLoader(PlaceholderDataset()),
...     epochs = 1,
...     losses = SupAdversarialGeneratorLoss(),
...     losses_d = SupAdversarialDiscriminatorLoss(),
...     optimizer = optimizer,
...     verbose = False
... )
>>>
>>> generator = trainer.train()

Note that the forward pass also computes y_hat ahead of time, to avoid recomputing it once per loss, but this is completely optional.
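The y_hat caching mentioned above can be sketched as follows. This is an illustrative, framework-free sketch, not the actual deepinv implementation; the names forward_pass and physics_A are assumptions for the example:

```python
# Hedged sketch (illustrative names, not deepinv internals): the forward
# pass caches y_hat = A(x_net) once, so that every loss needing the
# re-measured reconstruction can reuse it instead of reapplying A.

def forward_pass(model, physics_A, y):
    x_net = model(y)          # generator reconstruction from measurement y
    y_hat = physics_A(x_net)  # measurement of the reconstruction, computed once
    return x_net, y_hat

# toy model and forward operator acting on plain floats
x_net, y_hat = forward_pass(lambda y: y + 1.0, lambda x: 2.0 * x, 3.0)
print(x_net, y_hat)  # 4.0 8.0
```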

See deepinv.Trainer for additional parameters.

Parameters:
  • optimizer (AdversarialOptimizer) – optimizer encapsulating both generator and discriminator optimizers

  • losses_d (Loss, list) – losses to train the discriminator, e.g. adversarial losses

  • D (Module) – discriminator/critic/classification model, which must take in an image and return a scalar

  • step_ratio_D (int) – every iteration, train D this many times, allowing for imbalanced generator/discriminator training. Defaults to 1.
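The alternating schedule implied by step_ratio_D can be sketched framework-free. This is a hedged illustration of the update ratio only (training_schedule is a hypothetical helper, not part of deepinv):

```python
# Hedged sketch (not deepinv internals): with step_ratio_D = k, each
# training iteration performs one generator update and k discriminator
# updates, allowing imbalanced generator/discriminator training.

def training_schedule(n_iters: int, step_ratio_D: int = 1):
    """Return the total (generator_steps, discriminator_steps)
    accumulated over n_iters training iterations."""
    g_steps = d_steps = 0
    for _ in range(n_iters):
        g_steps += 1                # one generator update per iteration
        d_steps += step_ratio_D     # D is trained step_ratio_D times
    return g_steps, d_steps

print(training_schedule(100, step_ratio_D=5))  # (100, 500)
```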

check_clip_grad_D()[source]

Check the discriminator’s gradient norm and perform gradient clipping if necessary.

Analogous to check_clip_grad for the generator.

compute_loss(physics, x, y, train=True, epoch: int | None = None)[source]

Compute losses and perform backward passes for both generator and discriminator networks.

Parameters:
  • physics (Physics) – forward physics operator
  • x (Tensor) – ground-truth image
  • y (Tensor) – measurement
  • train (bool) – whether in training mode; defaults to True
  • epoch (int, None) – current epoch

Returns:

(tuple) The network reconstruction x_net (for plotting and computing metrics) and the logs (for printing the training progress).

save_model(epoch, eval_psnr=None)[source]

Save discriminator model parameters alongside other models.

setup_train(**kwargs)[source]

After the usual Trainer setup, also set up the losses for the discriminator.

Examples using AdversarialTrainer:

Imaging inverse problems with adversarial networks