CSGMGenerator

class deepinv.models.CSGMGenerator(backbone_generator: nn.Module = DCGANGenerator(), inf_max_iter=2500, inf_tol=1e-4, inf_lr=1e-2, inf_progress_bar=False)

Bases: Reconstructor

Adapts a generator model backbone (e.g. DCGAN) for CSGM or AmbientGAN.

This approach was proposed in Compressed Sensing using Generative Models and AmbientGAN: Generative models from lossy measurements (both by Bora et al.).

At train time, the generator samples a latent vector from Unif[-1, 1] and passes it through the backbone.

At test time, CSGM/AmbientGAN instead runs an optimisation to find the latent vector that best fits the input measurements y, then outputs the corresponding reconstruction.
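For reference, the two modes can be summarised as follows, writing G for the backbone generator, A for the forward operator of physics and n_z for the latent dimension. The squared l2 data-fit term follows Bora et al. and is an assumption about what "best fits" means in this implementation.

```latex
\begin{aligned}
\text{train:}\quad & \hat{x} = G(z), \qquad z \sim \mathcal{U}[-1, 1]^{n_z}, \\
\text{test:}\quad  & \hat{z} = \arg\min_{z} \, \lVert y - A(G(z)) \rVert_2^2, \qquad \hat{x} = G(\hat{z}).
\end{aligned}
```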

To use a more advanced optimisation algorithm, subclass this generator and override optimize_z.
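As an illustration, the function below could serve as the body of an overridden optimize_z. It swaps in L-BFGS (torch.optim.LBFGS with a strong-Wolfe line search) as a hypothetical "more advanced" optimiser. It assumes physics exposes an A method and that the data-fit term is mean squared error, and it takes the generator as an explicit argument, whereas a real subclass would use its own stored backbone. MatrixPhysics and the linear generator are toy stand-ins, not deepinv classes.

```python
import torch
import torch.nn as nn


def optimize_z_lbfgs(z, y, physics, generator, max_iter=50):
    """Hypothetical replacement for optimize_z using L-BFGS.

    Assumes physics.A applies the forward operator and that the data-fit
    term is mean squared error.
    """
    z = z.detach().clone().requires_grad_(True)
    optimizer = torch.optim.LBFGS(
        [z], lr=1.0, max_iter=max_iter, line_search_fn="strong_wolfe"
    )
    mse = nn.MSELoss()

    def closure():
        # L-BFGS may re-evaluate the objective several times per step
        optimizer.zero_grad()
        loss = mse(physics.A(generator(z)), y)
        loss.backward()
        return loss

    optimizer.step(closure)  # runs up to max_iter inner iterations
    return z.detach()


# Toy usage: a linear "generator" and a random-matrix "physics" (hypothetical)
class MatrixPhysics:
    def __init__(self, M):
        self.M = M

    def A(self, x):
        return x @ self.M.T


torch.manual_seed(0)
generator = nn.Linear(4, 16, bias=False)  # latent length 4 -> 16 "pixels"
physics = MatrixPhysics(torch.randn(8, 16))  # 8 measurements
with torch.no_grad():
    y = physics.A(generator(torch.rand(1, 4) * 2 - 1))
z0 = torch.rand(1, 4) * 2 - 1
z_hat = optimize_z_lbfgs(z0, y, physics, generator)
with torch.no_grad():
    err0 = nn.functional.mse_loss(physics.A(generator(z0)), y).item()
    err1 = nn.functional.mse_loss(physics.A(generator(z_hat)), y).item()
```

In a subclass, the same body would sit inside optimize_z(self, z, y, physics), keeping the documented signature. Note that the line search evaluates the generator and physics several times per step, so each iteration is more expensive than a plain gradient step.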

See the example Imaging inverse problems with adversarial networks for how to use this in adversarial training.

Note

At train time, this generator discards the measurements y and simply returns a sample from the backbone; y is only used at test time. As a result, PSNR computed during training is meaningless, whereas PSNR computed at test time correctly reflects reconstruction quality.

Parameters:
  • backbone_generator (nn.Module) – any neural network that maps a latent vector of length nz to an image; it must have an nz attribute. Defaults to DCGANGenerator()

  • inf_max_iter (int) – maximum number of iterations of inference-time optimisation, defaults to 2500

  • inf_tol (float) – tolerance of inference-time optimisation, defaults to 1e-4

  • inf_lr (float) – learning rate of inference-time optimisation, defaults to 1e-2

  • inf_progress_bar (bool) – whether to display a progress bar during inference-time optimisation, defaults to False
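Any module that has an nz attribute and maps latents to images can serve as the backbone. TinyGenerator below is a hypothetical example, not part of deepinv. The (B, nz, 1, 1) latent layout follows the DCGAN convention and is an assumption, so an nn.Flatten at the input lets it accept (B, nz) latents as well.

```python
import torch
import torch.nn as nn


class TinyGenerator(nn.Module):
    """Hypothetical minimal backbone: latent of length nz -> image (C, H, W)."""

    def __init__(self, nz=100, channels=1, size=28):
        super().__init__()
        self.nz = nz  # required attribute: the latent dimension
        self.out_shape = (channels, size, size)
        self.net = nn.Sequential(
            nn.Flatten(),  # (B, nz, 1, 1) or (B, nz) -> (B, nz)
            nn.Linear(nz, 256),
            nn.ReLU(),
            nn.Linear(256, channels * size * size),
            nn.Tanh(),  # bounded output in [-1, 1], as in DCGAN-style generators
        )

    def forward(self, z):
        return self.net(z).view(-1, *self.out_shape)


g = TinyGenerator(nz=8, channels=3, size=16)
x = g(torch.rand(2, 8, 1, 1) * 2 - 1)  # batch of 2 latents
```

Under those assumptions it could be plugged in as CSGMGenerator(backbone_generator=TinyGenerator(nz=100), inf_max_iter=500), using only the constructor arguments documented above.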

forward(y: Tensor, physics: Physics, *args, **kwargs) → Tensor

Forward pass of generator model.

At train time, the generator samples a latent vector from Unif[-1, 1] and passes it through the backbone.

At test time, CSGM/AmbientGAN instead runs an optimisation to find the latent vector that best fits the input measurements y, then outputs the corresponding reconstruction.

Parameters:
  • y (Tensor) – measurement to reconstruct

  • physics (Physics) – forward model
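The train/test switch can be sketched in plain PyTorch using nn.Module's training flag. ToyCSGM, LinearGenerator and MatrixPhysics are hypothetical stand-ins, not deepinv classes. physics is assumed to expose an A method, and a fixed number of plain gradient-descent steps replaces the library's optimize_z, whose exact update rule is not described here.

```python
import torch
import torch.nn as nn


class LinearGenerator(nn.Module):
    """Hypothetical backbone: maps a latent of length nz to a flat 'image'."""

    def __init__(self, nz, n):
        super().__init__()
        self.nz = nz  # latent dimension, read by random_latent below
        self.lin = nn.Linear(nz, n, bias=False)

    def forward(self, z):
        return self.lin(z)


class MatrixPhysics:
    """Hypothetical forward model A(x) = M x."""

    def __init__(self, M):
        self.M = M

    def A(self, x):
        return x @ self.M.T


class ToyCSGM(nn.Module):
    """Sketch of the documented forward behaviour (not the deepinv code)."""

    def __init__(self, generator, steps=1000, lr=0.1):
        super().__init__()
        self.generator = generator
        self.steps = steps
        self.lr = lr

    def random_latent(self, device):
        # z ~ Unif[-1, 1): affine map of torch.rand's Unif[0, 1)
        return torch.rand(1, self.generator.nz, device=device) * 2 - 1

    def forward(self, y, physics):
        z = self.random_latent(y.device)
        if not self.training:
            # test time: fit z to y (stand-in for optimize_z)
            z.requires_grad_(True)
            opt = torch.optim.SGD([z], lr=self.lr)
            for _ in range(self.steps):
                opt.zero_grad()
                loss = ((physics.A(self.generator(z)) - y) ** 2).mean()
                loss.backward()
                opt.step()
            z = z.detach()
        # in train mode z is still random here, so y does not affect the output
        return self.generator(z)


torch.manual_seed(0)
gen = LinearGenerator(nz=4, n=16)
physics = MatrixPhysics(torch.randn(8, 16))
with torch.no_grad():
    y = physics.A(gen(torch.rand(1, 4) * 2 - 1))

model = ToyCSGM(gen)
model.train()
x_train = model(y, physics)  # random sample; y only supplies the device
model.eval()
x_test = model(y, physics)  # reconstruction fitted to y
```

Keying the behaviour on self.training means the usual model.train() / model.eval() calls are enough to switch between sampling and reconstruction, which is why train-time PSNR says nothing about reconstruction quality.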

optimize_z(z: Tensor, y: Tensor, physics: Physics) → Tensor

Run inference-time optimisation of the latent z so that the generated image is consistent with the input measurement y under the forward model physics.

The default optimisation uses simple stopping criteria, governed by inf_max_iter and inf_tol. Override this method for more advanced optimisation.

Parameters:
  • z (Tensor) – initial latent variable guess

  • y (Tensor) – measurement with which to compare reconstructed image

  • physics (Physics) – forward model

Returns:

optimised z (Tensor)
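A loop with "simple stopping criteria" might look like the sketch below: stop after max_iter steps, or earlier once the relative change in the loss falls below tol. Its defaults mirror the documented constructor defaults of inf_max_iter, inf_tol and inf_lr, but the choice of Adam, the MSE data-fit term and the relative-change test are assumptions, not the library's confirmed implementation. The function name shadows the method only for illustration; generator and MatrixPhysics are hypothetical toy inputs.

```python
import torch
import torch.nn as nn


def optimize_z(z, y, physics, generator, max_iter=2500, tol=1e-4, lr=1e-2):
    """Hypothetical latent optimisation with simple stopping criteria."""
    z = z.detach().clone().requires_grad_(True)
    optimizer = torch.optim.Adam([z], lr=lr)  # assumed optimiser
    mse = nn.MSELoss()
    prev = None
    for _ in range(max_iter):
        optimizer.zero_grad()
        loss = mse(physics.A(generator(z)), y)
        loss.backward()
        optimizer.step()
        cur = loss.item()
        if prev is not None and abs(prev - cur) <= tol * max(prev, 1e-12):
            break  # relative change in loss below tolerance
        prev = cur
    return z.detach()


class MatrixPhysics:
    """Hypothetical forward model A(x) = M x."""

    def __init__(self, M):
        self.M = M

    def A(self, x):
        return x @ self.M.T


torch.manual_seed(0)
generator = nn.Linear(4, 16, bias=False)  # toy backbone, latent length 4
physics = MatrixPhysics(torch.randn(8, 16))
with torch.no_grad():
    y = physics.A(generator(torch.rand(1, 4) * 2 - 1))
z0 = torch.rand(1, 4) * 2 - 1
z_hat = optimize_z(z0, y, physics, generator)
```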

random_latent(device, requires_grad=True) → Tensor

Generate a latent sample to feed into the generative model.

The backbone generator must have an attribute nz, which is the latent dimension.

Parameters:
  • device (torch.device) – torch device

  • requires_grad (bool) – whether to require gradient, defaults to True.
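A minimal sketch of drawing a latent from Unif[-1, 1] in PyTorch. Unlike the documented method, it is a standalone function that takes nz directly instead of reading it from the backbone; the batch_size parameter and the DCGAN-style (B, nz, 1, 1) shape are assumptions.

```python
import torch


def random_latent(nz, device="cpu", requires_grad=True, batch_size=1):
    """Hypothetical sketch: draw z ~ Unif[-1, 1) with shape (B, nz, 1, 1)."""
    # torch.rand samples Unif[0, 1); the affine map 2u - 1 gives Unif[-1, 1)
    z = torch.rand(batch_size, nz, 1, 1, device=device) * 2 - 1
    # z is a leaf tensor, so gradients can flow to it during latent optimisation
    return z.requires_grad_(requires_grad)


z = random_latent(100)
```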

Examples using CSGMGenerator:

Imaging inverse problems with adversarial networks