DRUNet

class deepinv.models.DRUNet(in_channels=3, out_channels=3, nc=[64, 128, 256, 512], nb=4, act_mode='R', downsample_mode='strideconv', upsample_mode='convtranspose', pretrained='download', device=None)

Bases: Denoiser

DRUNet denoiser network.

The architecture is based on the paper Plug-and-Play Image Restoration with Deep Denoiser Prior and has a U-Net-like structure, with residual convolutional blocks in the encoder and decoder.

The network is conditioned on the noise level of the input image: \(\sigma\) is encoded as a constant noise-level map that is concatenated to the input as an additional channel.
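As a minimal sketch of how a noise level can be encoded as an extra input channel, the function below builds a constant map from \(\sigma\) and concatenates it to a batch of shape (B, C, H, W). NumPy arrays stand in for the torch tensors deepinv uses, and the name concat_noise_map is hypothetical; this illustrates the idea rather than reproducing deepinv's implementation.

```python
import numpy as np


def concat_noise_map(x, sigma):
    """Append a constant noise-level channel to a batch of images.

    x: array of shape (B, C, H, W); sigma: float or array of shape (B,).
    Returns an array of shape (B, C + 1, H, W).
    """
    b, _, h, w = x.shape
    sigma = np.asarray(sigma, dtype=x.dtype)
    if sigma.ndim == 0:
        # A single float is shared by every image in the batch.
        sigma = np.full(b, sigma, dtype=x.dtype)
    if sigma.shape != (b,):
        raise ValueError(f"sigma must be a float or have shape ({b},), got {sigma.shape}")
    # Spread each image's sigma over its spatial grid: (B,) -> (B, 1, H, W).
    noise_map = np.broadcast_to(sigma[:, None, None, None], (b, 1, h, w))
    # Concatenate along the channel axis.
    return np.concatenate([x, noise_map], axis=1)


x = np.random.rand(2, 3, 8, 8).astype(np.float32)
print(concat_noise_map(x, np.array([0.05, 0.1])).shape)  # (2, 4, 8, 8)
```

With the defaults in_channels=3, the first convolution therefore sees 4 input channels under this scheme.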

Pretrained weights for in_channels=out_channels=1 (grayscale) or in_channels=out_channels=3 (color) can be downloaded by setting pretrained='download'.

Parameters:
  • in_channels (int) – number of channels of the input.

  • out_channels (int) – number of channels of the output.

  • nc (list) – number of channels at each of the four scales of the U-Net.

  • nb (int) – number of residual blocks per scale.

  • act_mode (str) – activation mode: “R” for ReLU, “L” for LeakyReLU, “E” for ELU, and “s” for Softplus.

  • downsample_mode (str) – Downsampling mode, “avgpool” for average pooling, “maxpool” for max pooling, and “strideconv” for convolution with stride 2.

  • upsample_mode (str) – Upsampling mode, “convtranspose” for transposed convolution, “pixelshuffle” for pixel shuffling, and “upconv” for nearest-neighbour upsampling followed by a convolution.

  • pretrained (str, None) – use a pretrained network. If pretrained=None, the weights are initialized at random using PyTorch’s default initialization. If pretrained='download', the weights are downloaded from an online repository (only available for the default architecture with in_channels=out_channels=1 or 3). Finally, pretrained can also be set to a path to the user’s own pretrained weights. See pretrained-weights for more details.

  • device (str) – device on which the network is placed, e.g. 'cpu' or 'cuda'.
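To make the role of nc concrete, the shape bookkeeping below is a sketch assuming the default four-scale layout: the input gains one noise-level channel, each of the three downsampling stages halves the height and width, and the decoder mirrors the encoder. The function drunet_feature_shapes and its stage labels are hypothetical; it only traces shapes and builds no layers.

```python
def drunet_feature_shapes(in_channels, out_channels, nc, height, width):
    """Trace (stage, channels, height, width) through a 4-scale U-Net.

    Assumptions: one extra input channel for the noise-level map, and each of
    the three downsampling stages halves height and width.
    """
    if len(nc) != 4:
        raise ValueError("this sketch assumes four scales, i.e. len(nc) == 4")
    if height % 8 or width % 8:
        raise ValueError("this sketch requires height and width divisible by 8")
    h, w = height, width
    shapes = [("input + noise map", in_channels + 1, h, w), ("head", nc[0], h, w)]
    for level in range(1, 4):  # encoder: three downsampling stages
        h, w = h // 2, w // 2
        shapes.append((f"down{level}", nc[level], h, w))
    shapes.append(("body", nc[3], h, w))  # bottleneck at the coarsest scale
    for level in range(3, 0, -1):  # decoder mirrors the encoder
        h, w = h * 2, w * 2
        shapes.append((f"up{level}", nc[level - 1], h, w))
    shapes.append(("tail", out_channels, h, w))
    return shapes


for stage in drunet_feature_shapes(3, 3, [64, 128, 256, 512], 64, 64):
    print(stage)
```

Since residual blocks preserve the shape of their input, nb changes the depth of each scale but not these shapes.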

forward(x, sigma)

Run the denoiser on an image with noise level \(\sigma\).

Parameters:
  • x (torch.Tensor) – noisy image batch of shape (B, C, H, W).

  • sigma (float, torch.Tensor) – noise level. If sigma is a float, it is used for all images in the batch. If sigma is a tensor, it must be of shape (batch_size,).
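To pair an image batch with the sigma argument, the sketch below corrupts each image with additive white Gaussian noise, assuming (as is standard for DRUNet) that \(\sigma\) is the noise standard deviation on the same intensity scale as x. A float applies one level to the whole batch, while a length-B array gives one level per image, matching the (batch_size,) shape above. The helper make_noisy_batch is hypothetical and uses NumPy in place of torch.

```python
import numpy as np


def make_noisy_batch(x, sigma, seed=0):
    """Corrupt each image with additive white Gaussian noise.

    x: clean batch of shape (B, C, H, W); sigma: float (same level for every
    image) or array of shape (B,) (one level per image).
    Returns (y, sigma_per_image), where sigma_per_image has shape (B,).
    """
    rng = np.random.default_rng(seed)
    # A float broadcasts to every image; a (B,) array must match the batch size.
    sigma = np.broadcast_to(np.asarray(sigma, dtype=x.dtype), (x.shape[0],))
    noise = rng.standard_normal(x.shape).astype(x.dtype)
    # Scale the unit-variance noise of image i by sigma[i].
    y = x + sigma[:, None, None, None] * noise
    return y, sigma


x = np.zeros((2, 1, 128, 128))
y, sigma = make_noisy_batch(x, np.array([0.05, 0.2]))
print(y.shape, sigma.shape)  # (2, 1, 128, 128) (2,)
```

After conversion to torch tensors, y and sigma_per_image correspond to the x and sigma arguments of forward.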

Examples using DRUNet:

Random phase retrieval and reconstruction methods.

DPIR method for PnP image deblurring.

Image reconstruction with a diffusion model