WaveletDenoiser
- class deepinv.models.WaveletDenoiser(level=3, wv='db8', device='cpu', non_linearity='soft', mode='zero', wvdim=2, is_complex=False)
Bases: Denoiser
Orthogonal wavelet denoising with the \(\ell_1\) norm.
This denoiser is defined as the solution to the optimization problem:
\[\underset{x}{\arg\min} \; \frac{1}{2}\|x-y\|^2 + \gamma \|\Psi x\|_n\]
where \(\Psi\) is an orthonormal wavelet transform, \(\gamma>0\) is a hyperparameter, and \(\|\cdot\|_n\) is either the \(\ell_1\) norm (non_linearity="soft") or the \(\ell_0\) norm (non_linearity="hard"). A variant of the \(\ell_0\) norm is also available (non_linearity="topk"), where the thresholding is done by keeping the \(k\) largest coefficients in each wavelet subband and setting the others to zero.
Since \(\Psi\) is orthonormal, the problem reduces to thresholding the wavelet coefficients \(\Psi y\). The solution is therefore available in closed form, and the denoiser is cheap to compute.
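For the \(\ell_1\) case, the closed form follows in a few lines. The derivation below is a sketch: it writes the data-fidelity term with a factor \(\frac{1}{2}\) (which only rescales \(\gamma\)) and applies the penalty to every coefficient, whereas the implementation leaves the approximation coefficients unthresholded. Since \(\Psi\) is orthonormal, \(\|x-y\|^2 = \|\Psi x-\Psi y\|^2\), so substituting \(z = \Psi x\) yields a problem that separates over the coefficients:
\[\hat z = \underset{z}{\arg\min} \; \frac{1}{2}\|z-\Psi y\|^2 + \gamma \|z\|_1, \qquad \hat z_i = \operatorname{sign}\big((\Psi y)_i\big) \max\big(|(\Psi y)_i| - \gamma, 0\big), \qquad \hat x = \Psi^\top \hat z.\]
The \(\ell_0\) case reduces in the same way to element-wise hard thresholding of \(\Psi y\).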
Warning
This model requires the PyTorch Wavelet Toolbox (ptwt) to be installed. It can be installed with pip install ptwt.
- Parameters:
level (int) – decomposition level of the wavelet transform (default: 3).
wv (str) – mother wavelet (follows the PyWavelets convention) (default: "db8").
non_linearity (str) – "soft", "hard" or "topk" thresholding (default: "soft"). If "topk", only the top-k wavelet coefficients are kept.
mode (str) – padding mode for the wavelet transform (default: "zero").
wvdim (int) – dimension of the wavelet transform (either 2 or 3) (default: 2).
is_complex (bool) – whether the input is complex-valued (default: False).
device (str) – device on which computations are performed (default: "cpu"); use e.g. "cuda" to run on a GPU.
- flatten_coeffs(dec)
Flattens the wavelet coefficients and returns them in a single torch vector of shape (n_coeffs,).
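A minimal NumPy sketch of the flattening step, assuming the decomposition follows the PyWavelets wavedec2 layout [approx, (detail bands of the coarsest level), ..., (detail bands of the finest level)]. The function name mirrors the method above, but the body is illustrative only and is not the deepinv implementation, which works on torch tensors.

```python
import numpy as np


def flatten_coeffs(dec):
    """Concatenate all wavelet coefficients into a single 1-D vector.

    Assumes a PyWavelets-style layout: dec[0] is the approximation band and
    each later entry is a tuple of detail bands for one level.
    """
    parts = [np.ravel(dec[0])]
    for level_bands in dec[1:]:
        parts.extend(np.ravel(band) for band in level_bands)
    return np.concatenate(parts)


# Toy two-level decomposition of an 8x8 image: a 2x2 approximation band,
# three 2x2 coarse detail bands and three 4x4 fine detail bands.
dec = [np.ones((2, 2)), (np.zeros((2, 2)),) * 3, (np.zeros((4, 4)),) * 3]
flat = flatten_coeffs(dec)
print(flat.shape)  # (64,)
```

The length is simply the total number of coefficients (4 + 3·4 + 3·16 = 64), which is what makes a single vector convenient for global operations such as top-k selection.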
- forward(x, ths=0.1, **kwargs)
Run the model on a noisy image.
- Parameters:
x (torch.Tensor) – noisy image. Assumes a tensor of shape (B, C, H, W) (2D data) or (B, C, D, H, W) (3D data).
ths (int, float, torch.Tensor) – thresholding parameter \(\gamma\). If ths is a tensor, it should be of shape (B,) (same coefficient for all levels), (B, n_levels-1) (one coefficient per level), or (B, n_levels-1, 3) (one coefficient per subband and per level). B should be the same as the batch size of the input or 1. If non_linearity equals "soft" or "hard", ths serves as a (soft or hard) thresholding parameter for the wavelet coefficients. If non_linearity equals "topk", ths indicates the number of wavelet coefficients that are kept (if int) or the proportion of coefficients that are kept (if float).
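To make the pipeline concrete, here is a minimal NumPy sketch of what a forward pass computes for non_linearity="soft": analysis, soft thresholding of the detail coefficients, synthesis. It substitutes a hand-written one-level orthonormal Haar transform for the multi-level ptwt transform, processes a single 2-D image instead of a (B, C, H, W) batch, uses a scalar threshold, and leaves the approximation band untouched. The helper names (haar2d, ihaar2d, soft_threshold, wavelet_denoise) are hypothetical.

```python
import numpy as np


def haar2d(x):
    """One level of the orthonormal 2-D Haar transform (H and W must be even)."""
    x00, x01 = x[0::2, 0::2], x[0::2, 1::2]
    x10, x11 = x[1::2, 0::2], x[1::2, 1::2]
    approx = (x00 + x01 + x10 + x11) / 2
    details = ((x00 - x01 + x10 - x11) / 2,
               (x00 + x01 - x10 - x11) / 2,
               (x00 - x01 - x10 + x11) / 2)
    return approx, details


def ihaar2d(approx, details):
    """Inverse of haar2d; the transform is orthonormal, so this is its transpose."""
    h, v, d = details
    x = np.empty((2 * approx.shape[0], 2 * approx.shape[1]), dtype=approx.dtype)
    x[0::2, 0::2] = (approx + h + v + d) / 2
    x[0::2, 1::2] = (approx - h + v - d) / 2
    x[1::2, 0::2] = (approx + h - v - d) / 2
    x[1::2, 1::2] = (approx - h - v + d) / 2
    return x


def soft_threshold(c, ths):
    """Shrink every coefficient towards zero by ths."""
    return np.sign(c) * np.maximum(np.abs(c) - ths, 0.0)


def wavelet_denoise(y, ths=0.1):
    """Analysis -> soft-threshold the detail bands -> synthesis.
    The approximation band is passed through unchanged."""
    approx, details = haar2d(y)
    return ihaar2d(approx, tuple(soft_threshold(c, ths) for c in details))


rng = np.random.default_rng(0)
y = rng.normal(size=(8, 8))  # stand-in for a noisy single-channel image
x_hat = wavelet_denoise(y, ths=0.5)
```

With ths=0 the sketch returns y exactly, since the transform is orthonormal and nothing is thresholded. As ths grows, all detail coefficients vanish and every 2x2 block is replaced by its mean, because only the approximation band survives.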
- hard_threshold_topk(x, ths=0.1)
Hard thresholding of the wavelet coefficients by keeping only the top-k coefficients and setting the others to 0.
- Parameters:
x (torch.Tensor) – wavelet coefficients.
ths (float, int) – top k coefficients to keep. If float, it is interpreted as a proportion of the total number of coefficients. If int, it is interpreted as the number of coefficients to keep.
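A minimal NumPy sketch of top-k hard thresholding as described above: an int ths is the number of coefficients kept, and a float ths is the proportion of all coefficients kept. Rounding the proportion down is an assumption of this sketch. It also ranks the whole array at once, whereas the implementation may rank per subband or per batch element.

```python
import numpy as np


def hard_threshold_topk(x, ths=0.1):
    """Keep the k largest-magnitude entries of x and set all others to 0.

    int ths   -> k = ths
    float ths -> k = floor(ths * x.size)  (rounding rule is an assumption)
    """
    flat = np.ravel(x)
    k = int(ths * flat.size) if isinstance(ths, float) else int(ths)
    k = min(max(k, 0), flat.size)  # clamp k to [0, number of coefficients]
    out = np.zeros_like(flat)
    if k > 0:
        # indices of the k entries with the largest |x| (their order is irrelevant)
        idx = np.argpartition(np.abs(flat), -k)[-k:]
        out[idx] = flat[idx]
    return out.reshape(np.shape(x))


x = np.array([0.1, -5.0, 2.0, -0.3])
print(hard_threshold_topk(x, 2))     # keeps -5.0 and 2.0
print(hard_threshold_topk(x, 0.25))  # keeps only -5.0
```

np.argpartition avoids a full sort, which matters when k is small relative to the number of coefficients.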
- prox_l0(x, ths=0.1)
Hard thresholding of the wavelet coefficients.
- Parameters:
x (torch.Tensor) – wavelet coefficients of shape (B, C, H, W) or (B, C, D, H, W).
ths (float, torch.Tensor) – threshold of shape (B,) or scalar. If scalar, the same threshold is used for all elements in the batch.
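A minimal NumPy sketch of hard thresholding with the threshold shapes listed above. A scalar applies to every element. A (B,) array is reshaped to (B, 1, ..., 1) so that each batch element gets its own threshold. Entries with magnitude strictly below the threshold are zeroed; keeping the boundary case |x| = ths is an assumption of this sketch.

```python
import numpy as np


def prox_l0(x, ths=0.1):
    """Hard thresholding: zero every entry whose magnitude is below the threshold.

    x: array of shape (B, ...); ths: scalar or array of shape (B,).
    """
    ths = np.asarray(ths, dtype=x.dtype)
    if ths.ndim == 1:
        # (B,) -> (B, 1, ..., 1) so it broadcasts over all non-batch axes
        ths = ths.reshape((-1,) + (1,) * (x.ndim - 1))
    return np.where(np.abs(x) >= ths, x, 0.0)


x = np.array([[0.5, -2.0],
              [0.5, -2.0]])              # batch of B=2 coefficient vectors
print(prox_l0(x, 1.0))                   # same threshold for both rows
print(prox_l0(x, np.array([1.0, 0.1])))  # one threshold per batch element
```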
- prox_l1(x, ths=0.1)
Soft thresholding of the wavelet coefficients.
- Parameters:
x (torch.Tensor) – wavelet coefficients.
ths (float, torch.Tensor) – threshold.
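Soft thresholding is the proximal operator of \(\gamma\|\cdot\|_1\). It shrinks each coefficient towards zero by the threshold and zeroes those whose magnitude does not exceed it. Below is a minimal NumPy sketch (NumPy standing in for torch), followed by a brute-force check of the prox property on a single scalar.

```python
import numpy as np


def prox_l1(x, ths=0.1):
    """Soft thresholding: prox of ths * ||.||_1 for the data term 1/2 ||z - x||^2.

    Each entry is shrunk towards zero by ths; entries with |x| <= ths become 0.
    """
    return np.sign(x) * np.maximum(np.abs(x) - ths, 0.0)


print(prox_l1(np.array([3.0, -0.05, -1.0]), ths=0.5))  # each entry shrunk by 0.5, small ones set to 0

# Brute-force check of the prox property for one scalar v on a fine grid:
# argmin_z 1/2 (z - v)^2 + t |z| should match prox_l1(v, t).
z = np.linspace(-5.0, 5.0, 100001)
v, t = 1.3, 0.5
z_star = z[np.argmin(0.5 * (z - v) ** 2 + t * np.abs(z))]
print(z_star, prox_l1(np.array(v), t))  # both approximately 0.8
```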
- static psi(x, wavelet='db2', level=2, dimension=2, mode='zero')
Returns a flattened list containing the wavelet coefficients.
- Parameters:
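x (torch.Tensor) – input image.
wavelet (str) – mother wavelet (default: "db2").
level (int) – decomposition level (default: 2).
dimension (int) – dimension of the wavelet transform (either 2 or 3) (default: 2).
mode (str) – padding mode for the wavelet transform (default: "zero").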
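A minimal NumPy sketch of a multi-level decomposition returned as a flat list. It assumes the PyWavelets wavedec2 ordering: approximation band first, then the detail bands from the coarsest to the finest level. The exact ordering used by the implementation is not stated here, so that ordering is an assumption. The sketch substitutes a hand-written orthonormal Haar (db1) transform for the default db2 wavelet so that no wavelet library is needed, handles only the 2-D case, and ignores the padding mode.

```python
import numpy as np


def haar2d(x):
    """One level of the orthonormal 2-D Haar transform (H and W must be even)."""
    x00, x01 = x[0::2, 0::2], x[0::2, 1::2]
    x10, x11 = x[1::2, 0::2], x[1::2, 1::2]
    approx = (x00 + x01 + x10 + x11) / 2
    details = ((x00 - x01 + x10 - x11) / 2,
               (x00 + x01 - x10 - x11) / 2,
               (x00 - x01 - x10 + x11) / 2)
    return approx, details


def psi(x, level=2):
    """Multi-level Haar analysis returned as a flat list of subbands:
    [approx, coarsest-level details..., finest-level details...]."""
    approx, per_level = x, []
    for _ in range(level):
        approx, details = haar2d(approx)
        per_level.insert(0, details)  # prepend so the coarsest level comes first
    return [approx] + [band for lev in per_level for band in lev]


x = np.random.default_rng(1).normal(size=(8, 8))
coeffs = psi(x, level=2)
print(len(coeffs))  # 7: one approximation band plus 3 detail bands per level
```

Because each level is orthonormal, the total energy of the coefficients equals that of the input, which is a quick sanity check for any such decomposition.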
- reshape_ths(ths, level)
Reshape the thresholding parameter into the appropriate format, i.e. either:
a list of 3 elements, or
a tensor of 3 elements.
Since the approximation coefficients are not thresholded, no thresholding parameter needs to be provided for them; ths therefore has shape (n_levels-1, 3), with one value per detail subband and per level.
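The threshold shapes accepted by forward can all be broadcast to a single per-subband layout. The hypothetical helper below (not part of deepinv) sketches this in NumPy for the 2-D case with three detail subbands per level. It keeps a leading batch axis of size B (or 1), matching the (B, ...) shapes documented for forward, and uses n_detail_levels for n_levels-1.

```python
import numpy as np


def expand_ths(ths, batch_size, n_detail_levels, n_subbands=3):
    """Broadcast a threshold to shape (B, n_detail_levels, n_subbands).

    Accepts a scalar, or arrays of shape (B,), (B, L) or (B, L, n_subbands),
    where B is batch_size or 1 and L = n_detail_levels (= n_levels - 1).
    Hypothetical helper for illustration only.
    """
    ths = np.asarray(ths, dtype=float)
    if ths.ndim == 0:    # scalar: same value everywhere
        ths = ths.reshape(1, 1, 1)
    elif ths.ndim == 1:  # (B,): same value for every level and subband
        ths = ths[:, None, None]
    elif ths.ndim == 2:  # (B, L): one value per level, shared by its subbands
        ths = ths[:, :, None]
    # (B, L, n_subbands) passes through; mismatched shapes raise ValueError
    return np.broadcast_to(ths, (batch_size, n_detail_levels, n_subbands))


per_level = expand_ths(np.array([[1.0, 2.0, 3.0]]), batch_size=4, n_detail_levels=3)
print(per_level.shape)  # (4, 3, 3)
```

np.broadcast_to returns a read-only view, so no memory is copied when a single value is shared across levels and subbands.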
Examples using WaveletDenoiser:
Unfolded Chambolle-Pock for constrained image inpainting