SRResNet

class deepinv.models.SRResNet(num_blocks=16, im_c=3, feats=64, upscale=4, actv=nn.PReLU, norm='batch_norm', final_kernel_size=9, final_relu=False, pretrained=None, device=torch.device('cpu'))

Bases: Reconstructor

SRResNet super-resolution network.

Convolutional super-resolution architecture introduced in Ledig et al.[1] as the generator of SRGAN. The network applies a feature-extraction convolution, a stack of residual blocks (Conv-Norm-Activation-Conv-Norm with an additive skip), a long skip connection from the feature-extraction output, and finally a sequence of torch.nn.PixelShuffle-based upsampling stages followed by a large-kernel output convolution (kernel size final_kernel_size).
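Each upsampling stage relies on the torch.nn.PixelShuffle rearrangement, which turns a tensor of shape (C·r², H, W) into (C, r·H, r·W): channel c·r² + i·r + j at position (h, w) moves to output position (h·r + i, w·r + j). The sketch below reproduces that index mapping on nested Python lists for a single image (no batch dimension) so the layout can be inspected without PyTorch. The function name pixel_shuffle is hypothetical and this is not the PyTorch implementation; since each SRResNet stage doubles the resolution, the factor there would be r = 2.

```python
def pixel_shuffle(x, r):
    """Rearrange a nested list of shape (C*r*r, H, W) into (C, H*r, W*r).

    Hypothetical illustration of the index mapping used by torch.nn.PixelShuffle:
        out[c][h*r + i][w*r + j] = x[c*r*r + i*r + j][h][w]
    """
    in_c = len(x)
    if in_c % (r * r) != 0:
        raise ValueError("number of channels must be divisible by r**2")
    out_c = in_c // (r * r)
    h_in, w_in = len(x[0]), len(x[0][0])
    # Allocate the (C, H*r, W*r) output filled with zeros.
    out = [[[0] * (w_in * r) for _ in range(h_in * r)] for _ in range(out_c)]
    for c in range(out_c):
        for i in range(r):          # row offset inside each r x r block
            for j in range(r):      # column offset inside each r x r block
                src = x[c * r * r + i * r + j]
                for h in range(h_in):
                    for w in range(w_in):
                        out[c][h * r + i][w * r + j] = src[h][w]
    return out


# Four 1x1 channels become one 2x2 channel (r = 2).
x = [[[1]], [[2]], [[3]], [[4]]]
print(pixel_shuffle(x, 2))  # [[[1, 2], [3, 4]]]
```

The channel dimension is traded for spatial resolution without any interpolation, which is why the learned convolutions before each shuffle determine the upsampled content.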

The total upsampling factor is upscale and must be a power of two; the network contains \(\log_2(\text{upscale})\) upsampling stages, each doubling the spatial resolution.
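As a worked example of this rule, upscale=4 gives log₂ 4 = 2 stages, so an input of shape (B, im_c, H, W) comes out as (B, im_c, 4H, 4W). The helpers below (upsampling_stages and output_shape) are hypothetical, not part of deepinv; they only mirror the arithmetic stated above, and deepinv's own argument checking may differ.

```python
def upsampling_stages(upscale):
    """Return log2(upscale), rejecting factors that are not powers of two."""
    # A positive integer is a power of two iff it has exactly one set bit.
    if upscale < 1 or (upscale & (upscale - 1)) != 0:
        raise ValueError(f"upscale must be a power of two, got {upscale}")
    return upscale.bit_length() - 1


def output_shape(input_shape, upscale):
    """Shape (B, C, upscale*H, upscale*W) after all upsampling stages."""
    b, c, h, w = input_shape
    for _ in range(upsampling_stages(upscale)):
        h, w = 2 * h, 2 * w  # each stage doubles both spatial dimensions
    return (b, c, h, w)


print(upsampling_stages(4))             # 2
print(output_shape((1, 3, 32, 48), 4))  # (1, 3, 128, 192)
```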

The model is a Reconstructor: its forward method takes a low-resolution measurement y and returns a high-resolution estimate. The physics argument is accepted for API compatibility but is ignored by this network.

Note

The defaults correspond to the network configuration in Ledig et al.[1].

Note

Pretrained weights are available for the default RGB 4× configuration, trained on DIV2K with an L1 loss on measurements generated by DownsamplingMatlab (bicubic, factor 4). These weights require final_relu=True. Load them with pretrained="download".

Parameters:
  • num_blocks (int) – number of residual blocks in the trunk. Default: 16

  • im_c (int) – number of image channels (used for both input and output). Default: 3

  • feats (int) – number of feature channels in the trunk. Default: 64

  • upscale (int) – upsampling factor. Must be a power of two. Default: 4

  • actv (type[torch.nn.Module]) – activation layer class, instantiated with no arguments. Default: torch.nn.PReLU.

  • norm (str) – normalization layer, can be one of (‘instance_norm’, ‘batch_norm’, ‘layer_norm’, None). Default: ‘batch_norm’.

  • final_kernel_size (int) – kernel size of the final output convolution. Must be odd. Default: 9.

  • final_relu (bool) – if True, apply a ReLU after the final convolution to enforce a non-negative output. Default: False

  • pretrained (str, None) – load pretrained weights. If "download", weights are downloaded from an online repository. If a file path string, weights are loaded from that path. If None, weights are randomly initialised. The available pretrained weights require the default architecture with final_relu=True.

  • device (torch.device, str) – Device to put the model on. Default: ‘cpu’
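A likely reason final_kernel_size must be odd: along one axis, a stride-1 convolution with padding p maps a size H to H + 2p − k + 1, so the size is preserved only when 2p = k − 1, which has an integer solution exactly when k is odd. That the output convolution uses this "same" padding is an assumption about the implementation, not something stated above. The helpers conv_output_size and same_padding are hypothetical, not deepinv functions.

```python
def conv_output_size(size, kernel_size, padding, stride=1):
    """Output length of a convolution along one spatial axis."""
    return (size + 2 * padding - kernel_size) // stride + 1


def same_padding(kernel_size):
    """Padding that keeps the spatial size unchanged for a stride-1 conv.

    Hypothetical helper: an integer solution of 2p = k - 1 exists only for odd k.
    """
    if kernel_size % 2 == 0:
        raise ValueError(f"final_kernel_size must be odd, got {kernel_size}")
    return (kernel_size - 1) // 2


p = same_padding(9)                    # default final_kernel_size
print(p, conv_output_size(128, 9, p))  # 4 128
```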


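References:

[1] Ledig, C. et al. Photo-Realistic Single Image Super-Resolution Using a Generative Adversarial Network. CVPR 2017.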

forward(y, physics=None, **kwargs)

Apply the super-resolution network to a low-resolution input.

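Parameters:
  • y (torch.Tensor) – low-resolution input of shape (B, im_c, H, W).

  • physics – unused; accepted for compatibility with the Reconstructor interface. Default: None

Returns: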

(torch.Tensor) high-resolution estimate, of shape (B, im_c, upscale * H, upscale * W).

Return type:

torch.Tensor

Examples using SRResNet:

Super-resolution with SRResNet
