DRUNet
- class deepinv.models.DRUNet(in_channels=3, out_channels=3, nc=[64, 128, 256, 512], nb=4, act_mode='R', downsample_mode='strideconv', upsample_mode='convtranspose', pretrained='download', device=None)
Bases: Module
DRUNet denoiser network.
The network architecture is based on the paper Plug-and-Play Image Restoration with Deep Denoiser Prior and has a U-Net-like structure, with convolutional blocks in both the encoder and the decoder.
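To make the U-Net structure concrete, the pure-Python sketch below traces the encoder feature-map shapes for the default nc=[64, 128, 256, 512]. It assumes one factor-2 downsampling between consecutive scales (as with the default strideconv mode), and the helper encoder_shapes is hypothetical, not part of deepinv.

```python
# Hypothetical helper (not part of deepinv): trace feature-map shapes through
# a U-Net encoder, assuming one factor-2 downsampling between consecutive scales.
def encoder_shapes(height, width, nc=(64, 128, 256, 512)):
    """Return a list of (channels, height, width), one tuple per encoder scale."""
    shapes = []
    h, w = height, width
    for i, channels in enumerate(nc):
        if i > 0:
            # each downsampling step halves the spatial resolution
            if h % 2 or w % 2:
                raise ValueError("spatial size must stay divisible by 2 at every scale")
            h, w = h // 2, w // 2
        shapes.append((channels, h, w))
    return shapes


for scale, shape in enumerate(encoder_shapes(64, 64)):
    print(scale, shape)
# 0 (64, 64, 64)
# 1 (128, 32, 32)
# 2 (256, 16, 16)
# 3 (512, 8, 8)
```

The decoder mirrors these shapes on the way back up. Under this assumption, four scales mean three downsamplings, so the input height and width need to be divisible by 2**3 = 8 for the shapes to round-trip exactly; how DRUNet handles other sizes internally is not covered by this sketch.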
The network takes into account the noise level of the input image, which is encoded as an additional input channel.
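A minimal PyTorch sketch of this conditioning, assuming the noise level is stored as a constant map (one value per image) concatenated to the image channels. The helper add_noise_level_map is hypothetical and only illustrates the idea; it is not deepinv's internal code. A network built this way sees in_channels + 1 channels at its first convolution.

```python
import torch


def add_noise_level_map(x, sigma):
    """Append a constant noise-level channel to a batch x of shape (B, C, H, W).

    Hypothetical helper. sigma may be a float (shared by the whole batch)
    or a tensor of shape (B,) (one noise level per image).
    """
    b, _, h, w = x.shape
    if isinstance(sigma, torch.Tensor):
        # one value per image, broadcast over the spatial dimensions
        level = sigma.to(dtype=x.dtype, device=x.device).view(b, 1, 1, 1)
    else:
        # a single scalar shared by every image in the batch
        level = torch.tensor(float(sigma), dtype=x.dtype, device=x.device)
    noise_map = level * torch.ones(b, 1, h, w, dtype=x.dtype, device=x.device)
    return torch.cat([x, noise_map], dim=1)  # shape (B, C + 1, H, W)


x = torch.rand(2, 3, 32, 32)
print(add_noise_level_map(x, 0.1).shape)  # torch.Size([2, 4, 32, 32])
print(add_noise_level_map(x, torch.tensor([0.05, 0.2])).shape)  # torch.Size([2, 4, 32, 32])
```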
A pretrained network for in_channels=out_channels=1 or in_channels=out_channels=3 can be downloaded by setting pretrained='download' (the default).
- Parameters:
in_channels (int) – number of channels of the input.
out_channels (int) – number of channels of the output.
nc (list) – number of channels of the convolutional layers at each scale of the U-Net (one entry per scale).
nb (int) – number of convolutional blocks at each scale.
act_mode (str) – activation mode: “R” for ReLU, “L” for LeakyReLU, “E” for ELU and “s” for Softplus.
downsample_mode (str) – downsampling mode: “avgpool” for average pooling, “maxpool” for max pooling and “strideconv” for convolution with stride 2.
upsample_mode (str) – upsampling mode: “convtranspose” for transposed convolution, “pixelshuffle” for pixel shuffling and “upconv” for nearest-neighbour upsampling followed by a convolution.
pretrained (str, None) – use a pretrained network. If pretrained=None, the weights are initialized at random using PyTorch’s default initialization. If pretrained='download', the weights are downloaded from an online repository (only available for the default architecture with 1 or 3 input/output channels). Finally, pretrained can also be set to a path to the user’s own pretrained weights. See pretrained-weights for more details.
device (str) – device on which the network is placed, e.g. 'cpu' or 'cuda'.
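The string options for act_mode, downsample_mode and upsample_mode select standard PyTorch layers. The sketch below shows one plausible mapping; the kernel sizes, the 1×1 convolutions after pooling and the helpers ACTIVATIONS, downsample and upsample are assumptions for illustration, not deepinv's exact layers. What every option shares is that each downsampling halves the spatial size and each upsampling doubles it, so encoder and decoder shapes line up.

```python
import torch
import torch.nn as nn

# Hypothetical mapping from the documented act_mode letters to PyTorch activations.
ACTIVATIONS = {"R": nn.ReLU, "L": nn.LeakyReLU, "E": nn.ELU, "s": nn.Softplus}


def downsample(mode, c_in, c_out):
    """Halve H and W while mapping c_in -> c_out channels (illustrative layers)."""
    if mode == "strideconv":
        # a learned 2x2 convolution with stride 2
        return nn.Conv2d(c_in, c_out, kernel_size=2, stride=2)
    if mode == "avgpool":
        # fixed pooling, then an assumed 1x1 convolution to change channels
        return nn.Sequential(nn.AvgPool2d(2), nn.Conv2d(c_in, c_out, kernel_size=1))
    if mode == "maxpool":
        return nn.Sequential(nn.MaxPool2d(2), nn.Conv2d(c_in, c_out, kernel_size=1))
    raise ValueError(f"unknown downsample_mode: {mode}")


def upsample(mode, c_in, c_out):
    """Double H and W while mapping c_in -> c_out channels (illustrative layers)."""
    if mode == "convtranspose":
        return nn.ConvTranspose2d(c_in, c_out, kernel_size=2, stride=2)
    if mode == "pixelshuffle":
        # produce 4 * c_out channels, then rearrange them into a 2x larger grid
        return nn.Sequential(
            nn.Conv2d(c_in, 4 * c_out, kernel_size=3, padding=1), nn.PixelShuffle(2)
        )
    if mode == "upconv":
        # nearest-neighbour upsampling followed by a convolution
        return nn.Sequential(
            nn.Upsample(scale_factor=2, mode="nearest"),
            nn.Conv2d(c_in, c_out, kernel_size=3, padding=1),
        )
    raise ValueError(f"unknown upsample_mode: {mode}")


x = torch.rand(1, 64, 32, 32)
act = ACTIVATIONS["R"]()
for d in ("strideconv", "avgpool", "maxpool"):
    for u in ("convtranspose", "pixelshuffle", "upconv"):
        # down to (1, 128, 16, 16), activation, then back up to (1, 64, 32, 32)
        y = upsample(u, 128, 64)(act(downsample(d, 64, 128)(x)))
        print(d, u, tuple(y.shape))
```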
- forward(x, sigma)
Run the denoiser on an image with noise level \(\sigma\).
- Parameters:
x (torch.Tensor) – batch of noisy images of shape (batch_size, in_channels, H, W).
sigma (float, torch.Tensor) – noise level. If sigma is a float, it is used for all images in the batch. If sigma is a tensor, it must be of shape (batch_size,).
Examples using DRUNet:
Random phase retrieval and reconstruction methods.
DPIR method for PnP image deblurring.
Image reconstruction with a diffusion model