PromptIR

class deepinv.models.PromptIR(in_channels=3, out_channels=3, dim=48, num_blocks=(4, 6, 6, 8), num_refinement_blocks=4, heads=(1, 2, 4, 8), ffn_expansion_factor=2.66, bias=False, LayerNorm_type='WithBias', decoder=True, device=None, pretrained=None)

Bases: Reconstructor, Denoiser

PromptIR restoration model.

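PromptIR is an all-in-one blind image restoration model proposed by Potlapalli et al. [1]. "Blind" means the degradation type is not given as input: a single network handles several degradations, using prompts produced by prompt generation blocks in the decoder to adapt to the degradation present in the input.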

The authors’ pretrained weights, which are trained for in_channels=out_channels=3, can be downloaded by setting pretrained='download'.
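Because the released weights only match three-channel input and output, a configuration such as pretrained='download' with in_channels=1 cannot load them. The minimal sketch below is a hypothetical pre-flight check (`check_pretrained_config` is not part of deepinv) that mirrors the documented meaning of `pretrained`; treating `pretrained=None` as random initialization is an assumption based on the default value.

```python
# Hypothetical helper, NOT part of the deepinv API.
# Mirrors the documented semantics of `pretrained`:
#   "download"  -> the authors' weights (only valid for 3 -> 3 channels)
#   other str   -> path to a local checkpoint
#   None        -> assumed: no weights loaded (random initialization)

def check_pretrained_config(in_channels, out_channels, pretrained):
    """Describe where the weights would come from, or raise if incompatible."""
    if pretrained is None:
        return "random initialization"
    if pretrained == "download":
        # The authors' released weights expect RGB input and RGB output.
        if in_channels != 3 or out_channels != 3:
            raise ValueError(
                "the authors' weights expect in_channels=out_channels=3, "
                f"got in_channels={in_channels}, out_channels={out_channels}"
            )
        return "authors' weights (download)"
    # Any other string is treated as a path to a checkpoint file.
    return f"local checkpoint: {pretrained}"


print(check_pretrained_config(3, 3, "download"))
```

A grayscale model (in_channels=out_channels=1) would have to be trained from scratch or loaded from a compatible local checkpoint rather than from the authors' weights.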

Parameters:
  • in_channels (int) – number of channels of the input.

  • out_channels (int) – number of channels of the output.

  • dim (int) – base dimension of the model.

  • num_blocks (tuple) – number of transformer blocks at each level of the encoder/decoder.

  • num_refinement_blocks (int) – number of transformer blocks in the refinement module.

  • heads (tuple) – number of attention heads at each level of the encoder/decoder.

  • ffn_expansion_factor (float) – expansion factor of the feed-forward networks.

  • bias (bool) – whether to use bias in the convolutional layers.

  • LayerNorm_type (str) – type of layer normalization to use (‘BiasFree’ or ‘WithBias’).

  • decoder (bool) – whether to use the decoder with prompt generation blocks.

  • device (torch.device, str) – device to load the model on.

  • pretrained (str) – path to the pretrained weights or ‘download’ to download the authors’ weights.
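The parameters dim, num_blocks, heads and ffn_expansion_factor jointly determine the size of each encoder level. The sketch below is a hypothetical helper (`level_widths`, not part of deepinv) that tabulates these sizes. It assumes the Restormer-style backbone PromptIR builds on, where the channel count doubles at each of the four levels, attention heads split the channels evenly, and the feed-forward hidden width is int(channels * ffn_expansion_factor). The decoder levels in PromptIR also carry prompt channels, which this sketch does not model.

```python
# Hypothetical helper, NOT part of the deepinv API.
# ASSUMPTION: Restormer-style encoder, channels = dim * 2**level,
# heads split channels evenly, FFN hidden width = int(channels * factor).

def level_widths(dim=48, num_blocks=(4, 6, 6, 8), heads=(1, 2, 4, 8),
                 ffn_expansion_factor=2.66):
    """Return one dict per encoder level describing its assumed widths."""
    if len(num_blocks) != len(heads):
        raise ValueError("num_blocks and heads need one entry per level")
    levels = []
    for level, (n_blocks, n_heads) in enumerate(zip(num_blocks, heads)):
        channels = dim * 2 ** level  # 48, 96, 192, 384 with the defaults
        if channels % n_heads != 0:
            raise ValueError(
                f"level {level}: {channels} channels not divisible by {n_heads} heads"
            )
        levels.append({
            "level": level,
            "blocks": n_blocks,
            "channels": channels,
            "head_channels": channels // n_heads,
            "ffn_hidden": int(channels * ffn_expansion_factor),
        })
    return levels


for row in level_widths():
    print(row)
```

Under these assumptions the default heads=(1, 2, 4, 8) keeps every attention head at 48 channels, since the number of heads doubles together with the channel count.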


References:

load_pretrained(checkpoint_path)

Load pretrained weights.

Parameters:

checkpoint_path (str) – path to the checkpoint or ‘download’ to download the authors’ weights.