CSGMGenerator
- class deepinv.models.CSGMGenerator(backbone_generator: nn.Module = DCGANGenerator(), inf_max_iter=2500, inf_tol=1e-4, inf_lr=1e-2, inf_progress_bar=False)
Bases: Reconstructor
Adapts a generator model backbone (e.g. DCGAN) for CSGM or AmbientGAN.
This approach was proposed in Compressed Sensing using Generative Models (Bora et al.) and in AmbientGAN: Generative models from lossy measurements (Bora et al.).
At train time, the generator samples a latent vector from Unif[-1, 1] and passes it through the backbone.
At test time, CSGM/AmbientGAN runs an optimisation to find the latent vector that best fits the input measurements y, then outputs the corresponding reconstruction.
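The two modes can be illustrated with a toy, self-contained NumPy sketch. It is not the deepinv implementation: it assumes a hypothetical linear generator G(z) = W z in place of the backbone, a Gaussian matrix A in place of `physics`, and plain gradient descent with a fixed number of steps in place of whatever optimiser the library uses.

```python
import numpy as np

rng = np.random.default_rng(0)

# Hypothetical toy setup (not deepinv): linear generator G(z) = W @ z with
# orthonormal columns, and a Gaussian forward operator A with fewer
# measurements than pixels (m < n), as in compressed sensing.
nz, n, m = 3, 20, 12
W, _ = np.linalg.qr(rng.standard_normal((n, nz)))   # shape (n, nz)
A = rng.standard_normal((m, n)) / np.sqrt(m)         # shape (m, n)


def generator(z):
    return W @ z


def forward_train(y):
    # Train time: the measurements y are ignored; z ~ Unif[-1, 1]^nz.
    z = rng.uniform(-1.0, 1.0, size=nz)
    return generator(z)


def forward_test(y, n_iter=2000):
    # Test time: start from a random latent, run gradient descent on
    # ||A G(z) - y||^2, then decode the fitted latent.
    AW = A @ W
    step = 1.0 / (2 * np.linalg.norm(AW, 2) ** 2)    # 1/L for this quadratic loss
    z = rng.uniform(-1.0, 1.0, size=nz)
    for _ in range(n_iter):
        z = z - step * 2 * AW.T @ (AW @ z - y)
    return generator(z)


# Usage: measure an image that lies in the generator's range.
z_true = rng.uniform(-1.0, 1.0, size=nz)
x_true = generator(z_true)
y = A @ x_true
x_train = forward_train(y)   # a random sample, unrelated to x_true
x_test = forward_test(y)     # fitted to y, so close to x_true
```

Because the train-time output never looks at y, comparing it with the ground truth says nothing about reconstruction quality; only the test-time path is a reconstruction.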
For more advanced optimisation algorithms, subclass this generator and override optimize_z.
See Imaging inverse problems with adversarial networks for how to use this for adversarial training.
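The override pattern can be sketched with hypothetical toy classes (`ToyCSGM` and `ToyCSGMRestarts` are illustrations, not part of deepinv). The base `optimize_z` runs plain gradient descent for a linear generator; the subclass replaces only `optimize_z` with random restarts, keeping the latent with the lowest measurement error, while `reconstruct` is inherited unchanged. In deepinv the analogous step would be to subclass `CSGMGenerator` and override `optimize_z(z, y, physics)`.

```python
import numpy as np


class ToyCSGM:
    """Hypothetical stand-in for a CSGM-style reconstructor with G(z) = W @ z."""

    def __init__(self, W, n_iter=2000, seed=0):
        self.W = W                    # linear toy generator
        self.nz = W.shape[1]          # latent dimension
        self.n_iter = n_iter
        self.rng = np.random.default_rng(seed)

    def random_latent(self):
        return self.rng.uniform(-1.0, 1.0, size=self.nz)

    def loss(self, z, y, A):
        r = A @ (self.W @ z) - y
        return float(r @ r)

    def optimize_z(self, z, y, A):
        # Plain gradient descent on ||A W z - y||^2 with step size 1/L.
        AW = A @ self.W
        step = 1.0 / (2 * np.linalg.norm(AW, 2) ** 2)
        for _ in range(self.n_iter):
            z = z - step * 2 * AW.T @ (AW @ z - y)
        return z

    def reconstruct(self, y, A):
        z = self.optimize_z(self.random_latent(), y, A)
        return self.W @ z


class ToyCSGMRestarts(ToyCSGM):
    """Overrides only optimize_z: several random restarts, keep the best fit."""

    def __init__(self, W, n_restarts=5, **kwargs):
        super().__init__(W, **kwargs)
        self.n_restarts = n_restarts

    def optimize_z(self, z, y, A):
        starts = [z] + [self.random_latent() for _ in range(self.n_restarts - 1)]
        best_z, best_loss = z, float("inf")
        for z0 in starts:
            z_fit = super().optimize_z(z0, y, A)
            fit_loss = self.loss(z_fit, y, A)
            if fit_loss < best_loss:
                best_z, best_loss = z_fit, fit_loss
        return best_z
```

In this convex toy every restart reaches the same minimiser, so restarts add nothing here; they only pay off when the latent objective is non-convex, as with a DCGAN backbone.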
Note
At train time, this generator discards the measurements y; they are only used at test time. This means that train PSNR is meaningless, but test PSNR is correct.
- Parameters:
backbone_generator (nn.Module) – any neural network that maps a latent vector of length nz to an image; it must have an nz attribute. Defaults to DCGANGenerator()
inf_max_iter (int) – maximum number of iterations of the inference-time optimisation, defaults to 2500
inf_tol (float) – tolerance of the inference-time optimisation, defaults to 1e-4
inf_lr (float) – learning rate of inference-time optimisation, defaults to 1e-2
inf_progress_bar (bool) – whether to display progress bar for inference-time optimisation, defaults to False
- forward(y: Tensor, physics: Physics, *args, **kwargs) → Tensor
Forward pass of generator model.
At train time, the generator samples a latent vector from Unif[-1, 1] and passes it through the backbone.
At test time, CSGM/AmbientGAN runs an optimisation to find the best latent vector that fits the input measurements y, then outputs the corresponding reconstruction.
- Parameters:
y (Tensor) – measurement to reconstruct
physics (Physics) – forward model
- optimize_z(z: Tensor, y: Tensor, physics: Physics) → Tensor
Run inference-time optimisation to find a latent z that is consistent with the input measurement y according to physics.
The optimisation uses simple stopping criteria (a maximum number of iterations and a tolerance). Override this method for more advanced optimisation.
- Parameters:
z (Tensor) – initial latent variable guess
y (Tensor) – measurement with which to compare reconstructed image
physics (Physics) – forward model
- Return Tensor:
optimized z
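A minimal sketch of inference-time optimisation with the two stopping criteria that the constructor arguments suggest (a maximum number of iterations and a tolerance). It assumes a toy linear generator G(z) = W z, plain gradient descent, and a relative-change-in-loss stopping rule; `W` and `A` stand in for the backbone and `physics`, and deepinv's actual optimiser and stopping test may differ.

```python
import numpy as np


def optimize_z(z, y, W, A, max_iter=2500, tol=1e-4, lr=1e-2):
    """Fit z so that A @ (W @ z) approximates y, for a toy linear generator.

    Stops after max_iter gradient steps, or earlier once the relative change
    in the squared-error loss between consecutive steps is at most tol.
    Returns the fitted latent and the number of steps taken.
    """
    AW = A @ W                                   # composite map z -> A G(z)
    loss = float(np.sum((AW @ z - y) ** 2))
    n_steps = 0
    for n_steps in range(1, max_iter + 1):
        z = z - lr * 2 * AW.T @ (AW @ z - y)     # gradient of ||AW z - y||^2
        new_loss = float(np.sum((AW @ z - y) ** 2))
        converged = abs(loss - new_loss) <= tol * max(loss, 1e-12)
        loss = new_loss
        if converged:
            break
    return z, n_steps
```

For instance, `optimize_z(z0, y, W, A, lr=0.1)` returns the fitted latent and the step count. With noisy measurements the loss plateaus at the noise floor, so the tolerance test typically ends the loop early; with exact, noise-free data the relative change can stay roughly constant, and max_iter then acts as the effective stop.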
- random_latent(device, requires_grad=True) → Tensor
Generate a latent sample to feed into generative model.
The model must have an attribute nz which is the latent dimension.
- Parameters:
device (torch.device) – torch device
requires_grad (bool) – whether to require gradient, defaults to True.
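A NumPy sketch of what `random_latent` is described as doing: read the latent dimension from the model's `nz` attribute and draw a sample from Unif[-1, 1]. The `ToyGenerator` class and the `(batch, nz)` output shape are hypothetical choices for illustration; the real method returns a torch tensor on `device` with the requested `requires_grad`, and its exact shape depends on the backbone (a DCGAN-style backbone may expect `(batch, nz, 1, 1)`).

```python
import numpy as np


class ToyGenerator:
    """Hypothetical backbone that exposes its latent dimension as `nz`."""

    def __init__(self, nz=100):
        self.nz = nz


def random_latent(model, batch_size=1, rng=None):
    # Fail early if the backbone does not advertise its latent size.
    if not hasattr(model, "nz"):
        raise AttributeError("backbone must define an `nz` attribute")
    rng = np.random.default_rng() if rng is None else rng
    # One Unif[-1, 1) latent vector per batch element.
    return rng.uniform(-1.0, 1.0, size=(batch_size, model.nz))
```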
Examples using CSGMGenerator
Imaging inverse problems with adversarial networks