SRResNet
- class deepinv.models.SRResNet(num_blocks=16, im_c=3, feats=64, upscale=4, actv=nn.PReLU, norm='batch_norm', final_kernel_size=9, final_relu=False, pretrained=None, device=torch.device('cpu'))
Bases: Reconstructor
SRResNet super-resolution network.
Convolutional super-resolution architecture introduced in Ledig et al. [1] as the generator of SRGAN. The network applies a feature-extraction convolution, a stack of residual blocks (Conv-Norm-Activation-Conv-Norm with an additive skip), a long skip connection from the feature-extraction output, and finally a sequence of torch.nn.PixelShuffle-based upsampling stages followed by a wide output convolution.
The total upsampling factor is upscale and must be a power of two; the network contains \(\log_2(\text{upscale})\) upsampling stages, each doubling the spatial resolution.
The model is registered as a Reconstructor: its forward takes a low-resolution measurement y and returns a high-resolution estimate. The physics argument is accepted for API compatibility but is not used by this network.
Note
The defaults correspond to the network configuration in Ledig et al.[1].
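Each upsampling stage relies on a pixel shuffle, which trades channels for spatial resolution. The following minimal NumPy sketch reproduces the element rearrangement documented for torch.nn.PixelShuffle; it is not deepinv's code, the helper name pixel_shuffle is hypothetical, and the convolution that precedes the shuffle in each stage is omitted.

```python
import numpy as np


def pixel_shuffle(x: np.ndarray, r: int) -> np.ndarray:
    """Hypothetical helper: rearrange (B, C*r*r, H, W) into (B, C, H*r, W*r).

    Mapping (as documented for torch.nn.PixelShuffle):
    out[b, c, h*r + i, w*r + j] = x[b, c*r*r + i*r + j, h, w]
    """
    b, c_in, h, w = x.shape
    if c_in % (r * r) != 0:
        raise ValueError("channel count must be divisible by r**2")
    c = c_in // (r * r)
    # Split the channel axis into (C, i, j); i, j index the r x r sub-pixel grid.
    x = x.reshape(b, c, r, r, h, w)
    # Place each sub-pixel offset next to its spatial axis: (B, C, H, i, W, j).
    x = x.transpose(0, 1, 4, 2, 5, 3)
    # Merge (H, i) into rows and (W, j) into columns.
    return x.reshape(b, c, h * r, w * r)


# One stage uses r = 2: a convolution would first expand the channels
# fourfold (omitted here), then the shuffle doubles both spatial dimensions.
x = np.arange(16).reshape(1, 4, 2, 2)
y = pixel_shuffle(x, 2)
print(y.shape)  # (1, 1, 4, 4)
print(y[0, 0])
```

With the default upscale=4 this happens twice, so an H x W input reaches 4H x 4W.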
Note
Pretrained weights are available for the default RGB 4× configuration, trained on DIV2K with an L1 loss and the DownsamplingMatlab operator (bicubic, factor 4). These weights require final_relu=True. Load them with pretrained="download".
- Parameters:
num_blocks (int) – number of residual blocks in the trunk. Default: 16
im_c (int) – number of image channels (used for both input and output). Default: 3
feats (int) – number of feature channels in the trunk. Default: 64
upscale (int) – upsampling factor. Must be a power of two. Default: 4
actv (type[torch.nn.Module]) – activation layer class, instantiated with no arguments. Default: torch.nn.PReLU.
norm (str) – normalization layer; one of ‘instance_norm’, ‘batch_norm’, ‘layer_norm’, or None. Default: ‘batch_norm’.
final_kernel_size (int) – kernel size of the final output convolution. Must be odd. Default: 9.
final_relu (bool) – if True, enforce non-negativity of the output by applying a ReLU after the final convolution. Default: False
pretrained (str, None) – load pretrained weights. If "download", weights are downloaded from an online repository. If a file path, weights are loaded from that path. If None, weights are randomly initialised. The available pretrained weights require the default architecture with final_relu=True.
device (torch.device, str) – device to put the model on. Default: ‘cpu’
- References:
- forward(y, physics=None, **kwargs)
Apply the super-resolution network to a low-resolution input.
- Parameters:
y (torch.Tensor) – low-resolution input image, of shape (B, im_c, H, W).
physics (deepinv.physics.Physics) – forward operator (not used).
- Returns:
(torch.Tensor) high-resolution estimate, of shape (B, im_c, upscale * H, upscale * W).
- Return type:
torch.Tensor
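The shape contract of forward can be written down as a small check, for example to validate data-loader output before training. srresnet_output_shape is a hypothetical helper that only restates the documented input and output shapes; it does not run the network.

```python
def srresnet_output_shape(y_shape: tuple, im_c: int = 3, upscale: int = 4) -> tuple:
    """Hypothetical helper: the (B, C, H, W) shape contract of forward()."""
    if len(y_shape) != 4:
        raise ValueError("expected a 4-D shape (B, im_c, H, W)")
    b, c, h, w = y_shape
    if c != im_c:
        raise ValueError(f"expected {im_c} input channels, got {c}")
    # Channels are preserved; both spatial dimensions grow by `upscale`.
    return (b, im_c, upscale * h, upscale * w)


# A batch of two 64x64 RGB patches with the default 4x configuration.
print(srresnet_output_shape((2, 3, 64, 64)))  # (2, 3, 256, 256)
```

With torch and deepinv installed, the same contract should be observable on the real model: according to this page, deepinv.models.SRResNet()(torch.randn(2, 3, 64, 64)) returns a tensor of shape (2, 3, 256, 256). With pretrained=None the weights are random, so only the shape is meaningful.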