AdversarialTrainer
- class deepinv.training.AdversarialTrainer(model: torch.nn.modules.module.Module, physics: deepinv.physics.forward.Physics | typing.List[deepinv.physics.forward.Physics], optimizer: AdversarialOptimizer, train_dataloader: torch.utils.data.dataloader.DataLoader, epochs: int = 100, losses: deepinv.loss.loss.Loss | deepinv.loss.scheduler.BaseLossScheduler | typing.List[deepinv.loss.loss.Loss] | typing.List[deepinv.loss.scheduler.BaseLossScheduler] = SupLoss((metric): MSELoss()), eval_dataloader: torch.utils.data.dataloader.DataLoader = None, scheduler: torch.optim.lr_scheduler = None, metrics: deepinv.loss.metric.metric.Metric | typing.List[deepinv.loss.metric.metric.Metric] = PSNR(), online_measurements: bool = False, physics_generator: deepinv.physics.generator.base.PhysicsGenerator | typing.List[deepinv.physics.generator.base.PhysicsGenerator] = None, grad_clip: float = None, ckp_interval: int = 1, device: str | torch.device = 'cpu', eval_interval: int = 1, save_path: str | pathlib.Path = '.', verbose: bool = True, show_progress_bar: bool = True, plot_images: bool = False, plot_convergence_metrics: bool = False, wandb_vis: bool = False, wandb_setup: dict = <factory>, plot_measurements: bool = True, check_grad: bool = False, ckpt_pretrained: str | None = None, freq_plot: int = 1, verbose_individual_losses: bool = True, display_losses_eval: bool = False, rescale_mode: str = 'clip', compare_no_learning: bool = False, no_learning_method: str = 'A_adjoint', loop_physics_generator: bool = False, losses_d: Union[Loss, List[Loss]] = None, D: Module = None, step_ratio_D: int = 1)
Bases: Trainer
Trainer class for training a reconstruction network using adversarial learning.
It overrides the deepinv.Trainer class to provide the same functionality, whilst supporting training using adversarial losses. Note that the forward pass remains the same.
The usual reconstruction model corresponds to the generator model in an adversarial framework, which is trained using losses specified in the losses argument. Additionally, a discriminator model D is also jointly trained using the losses provided in losses_d. The adversarial losses themselves are defined in the deepinv.loss.adversarial module. Examples of discriminators are in deepinv.models.gan.
See Imaging inverse problems with adversarial networks for usage.
- Examples:
A very basic example:
>>> import torch
>>> from deepinv.training import AdversarialTrainer, AdversarialOptimizer
>>> from deepinv.loss.adversarial import SupAdversarialGeneratorLoss, SupAdversarialDiscriminatorLoss
>>> from deepinv.models import UNet, PatchGANDiscriminator
>>> from deepinv.physics import LinearPhysics
>>> from deepinv.datasets.utils import PlaceholderDataset
>>>
>>> generator = UNet(scales=2)
>>> discrimin = PatchGANDiscriminator(1, 2, 1)
>>>
>>> optimizer = AdversarialOptimizer(
...     torch.optim.Adam(generator.parameters()),
...     torch.optim.Adam(discrimin.parameters()),
... )
>>>
>>> trainer = AdversarialTrainer(
...     model = generator,
...     D = discrimin,
...     physics = LinearPhysics(),
...     train_dataloader = torch.utils.data.DataLoader(PlaceholderDataset()),
...     epochs = 1,
...     losses = SupAdversarialGeneratorLoss(),
...     losses_d = SupAdversarialDiscriminatorLoss(),
...     optimizer = optimizer,
...     verbose = False
... )
>>>
>>> generator = trainer.train()
Note that this forward pass also computes y_hat ahead of time to avoid having to compute it multiple times, but this is completely optional.
See deepinv.Trainer for additional parameters.
- Parameters:
optimizer (AdversarialOptimizer) – optimizer encapsulating both generator and discriminator optimizers
losses_d (Loss, list) – losses to train the discriminator, e.g. adversarial losses
D (Module) – discriminator/critic/classification model, which must take in an image and return a scalar
step_ratio_D (int) – every iteration, train D this many times, allowing for imbalanced generator/discriminator training. Defaults to 1.
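To make the effect of step_ratio_D concrete, the following framework-free sketch counts optimizer steps over a training run. It assumes that each iteration performs one generator update followed by step_ratio_D discriminator updates; that ordering, and the helper name count_updates, are illustrative assumptions rather than deepinv code.

```python
def count_updates(num_iterations: int, step_ratio_D: int = 1) -> dict:
    """Count generator (G) and discriminator (D) updates over a run.

    Assumption (hypothetical ordering): every iteration does one
    generator step, then `step_ratio_D` discriminator steps.
    """
    if step_ratio_D < 1:
        raise ValueError("step_ratio_D must be a positive integer")
    counts = {"G": 0, "D": 0}
    for _ in range(num_iterations):
        counts["G"] += 1  # single generator update per iteration
        for _ in range(step_ratio_D):
            counts["D"] += 1  # discriminator updated step_ratio_D times
    return counts


# With the default ratio both networks are updated equally often;
# a ratio of 5 gives the discriminator five times as many updates.
balanced = count_updates(num_iterations=10)
imbalanced = count_updates(num_iterations=10, step_ratio_D=5)
print(balanced, imbalanced)
```

Under these assumptions a larger step_ratio_D lets the discriminator see more updates per batch, which is one common way to keep a critic ahead of the generator.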
- check_clip_grad_D()
Check the discriminator’s gradient norm and perform gradient clipping if necessary.
Analogous to check_clip_grad for the generator.
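As an illustration of what checking and clipping a gradient norm involves, here is a minimal pure-Python sketch of global L2-norm clipping. It assumes norm-based clipping controlled by a threshold such as grad_clip; the function clip_grad_norm below is a hypothetical stand-in operating on nested lists, not the method's actual implementation. PyTorch's torch.nn.utils.clip_grad_norm_ applies the same idea to parameter tensors.

```python
import math


def clip_grad_norm(grads, max_norm):
    """Rescale gradients in place so their global L2 norm is at most
    `max_norm`, and return the norm measured before clipping.

    `grads` is a list of per-parameter gradient lists (a stand-in for
    tensors). If `max_norm` is None, the norm is only measured.
    """
    total_norm = math.sqrt(sum(g * g for group in grads for g in group))
    if max_norm is not None and total_norm > max_norm:
        scale = max_norm / total_norm
        for group in grads:
            for i in range(len(group)):
                group[i] *= scale  # same scale for every entry keeps the direction
    return total_norm


# Toy discriminator gradients with global norm sqrt(3^2 + 4^2) = 5.
grads_D = [[3.0, 0.0], [0.0, 4.0]]
norm_before = clip_grad_norm(grads_D, max_norm=1.0)
norm_after = math.sqrt(sum(g * g for group in grads_D for g in group))
print(norm_before, norm_after)
```

Scaling every entry by the same factor preserves the gradient direction while bounding the step size, which is why norm clipping is often preferred over clamping individual values.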
- compute_loss(physics, x, y, train=True, epoch: int | None = None)
Compute losses and perform backward passes for both generator and discriminator networks.
- Parameters:
physics (deepinv.physics.Physics) – Current physics operator.
x (torch.Tensor) – Ground truth.
y (torch.Tensor) – Measurement.
train (bool) – If True, the model is trained, otherwise it is evaluated.
epoch (int) – current epoch.
- Returns:
(tuple) The network reconstruction x_net (for plotting and computing metrics) and the logs (for printing the training progress).
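The shape of this return value can be pictured with a framework-free toy. The sketch below assumes a least-squares adversarial objective (target 1 for real images, 0 for reconstructions) and uses hypothetical stand-ins toy_generator, toy_discriminator and toy_compute_loss, plus made-up log keys. It only evaluates loss values and omits the backward passes and optimizer steps that the real method performs.

```python
def toy_generator(y):
    # Hypothetical reconstruction network: identity on the measurement.
    return list(y)


def toy_discriminator(x):
    # Hypothetical critic: maps an "image" (list of floats) to one scalar.
    return sum(x) / len(x)


def toy_compute_loss(x, y):
    """Return (x_net, logs) for ground truth x and measurement y."""
    x_net = toy_generator(y)
    d_real = toy_discriminator(x)
    d_fake = toy_discriminator(x_net)
    # Generator term: push the critic's score on x_net towards 1.
    loss_g = (d_fake - 1.0) ** 2
    # Discriminator term: score real images as 1 and reconstructions as 0.
    loss_d = (d_real - 1.0) ** 2 + d_fake ** 2
    logs = {"loss_G": loss_g, "loss_D": loss_d}  # key names are illustrative
    return x_net, logs


x_net, logs = toy_compute_loss(x=[1.0, 1.0], y=[0.5, 0.5])
print(x_net, logs)
```

Here the reconstruction is passed back for plotting and metrics, while the dictionary plays the role of the logs used to report training progress.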
Examples using AdversarialTrainer:
Imaging inverse problems with adversarial networks