WaveletPrior

class deepinv.optim.WaveletPrior(level=3, wv='db8', p=1, device='cpu', wvdim=2, clamp_min=None, clamp_max=None, *args, **kwargs)

Bases: Prior

Wavelet prior \(\reg{x} = \|\Psi x\|_{p}\).

\(\Psi\) is an orthonormal wavelet transform, and \(\|\cdot\|_{p}\) is the \(p\)-norm, with \(p=0\), \(p=1\), or \(p=\infty\).

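If clamping parameters are provided, the prior becomes \(\reg{x} = \|\Psi x\|_{p} + \iota_{[c_{\text{min}}, c_{\text{max}}]}(x)\), where \(c_{\text{min}}\) and \(c_{\text{max}}\) are given by clamp_min and clamp_max, and \(\iota_{[c_{\text{min}}, c_{\text{max}}]}\) is the indicator function of the interval \([c_{\text{min}}, c_{\text{max}}]\): it equals \(0\) if every entry of \(x\) lies in the interval and \(+\infty\) otherwise.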

Note

Following common practice in signal processing, only detail coefficients are regularized, and the approximation coefficients are left untouched.
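The following sketch illustrates how the prior value could be computed for \(p=1\), including the optional clamping term. It is a minimal NumPy sketch, not the deepinv implementation: it assumes a hand-written orthonormal Haar (db1) transform instead of the configurable pywt wavelet used by the class, and `haar2d` and `wavelet_prior` are hypothetical names. Only detail coefficients contribute, and the indicator term returns \(+\infty\) as soon as an entry leaves \([c_{\text{min}}, c_{\text{max}}]\).

```python
import numpy as np

S = np.sqrt(2.0)


def haar2d(x):
    """One level of the orthonormal 2D Haar (db1) transform.

    Assumes x is a 2D array with even height and width.
    Returns the approximation band and a list of three detail bands.
    """
    lo, hi = (x[0::2] + x[1::2]) / S, (x[0::2] - x[1::2]) / S  # pair up rows
    approx = (lo[:, 0::2] + lo[:, 1::2]) / S  # pair up columns
    details = [
        (lo[:, 0::2] - lo[:, 1::2]) / S,
        (hi[:, 0::2] + hi[:, 1::2]) / S,
        (hi[:, 0::2] - hi[:, 1::2]) / S,
    ]
    return approx, details


def wavelet_prior(x, level=3, clamp_min=None, clamp_max=None):
    """Hypothetical l1 wavelet prior on detail coefficients plus an optional box indicator."""
    # Indicator term: +inf if any entry falls outside [clamp_min, clamp_max].
    if clamp_min is not None and np.any(x < clamp_min):
        return np.inf
    if clamp_max is not None and np.any(x > clamp_max):
        return np.inf
    total, approx = 0.0, x
    for _ in range(level):
        approx, details = haar2d(approx)
        # Detail coefficients are penalized; the approximation band is only decomposed further.
        total += sum(np.abs(d).sum() for d in details)
    return total


impulse = np.zeros((8, 8))
impulse[0, 0] = 1.0
print(wavelet_prior(impulse))                 # approximately 2.625
print(wavelet_prior(np.ones((8, 8))))         # 0.0: a constant image has no details
print(wavelet_prior(impulse, clamp_min=0.5))  # inf: the zero entries violate the lower bound
```

The constant image shows the effect of leaving the approximation coefficients untouched: all of its energy sits in the approximation band, so the prior is zero.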

Warning

For 3D data, the computational complexity of the wavelet transform scales cubically with the size of the wavelet support. For large 3D data, it is recommended to use wavelets with small support (e.g. db1 to db4).
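As a rough illustration of this warning, the snippet below takes the stated cubic scaling at face value and compares Daubechies wavelets by their filter length. It relies on the fact that the Daubechies wavelet dbN has 2N filter taps (pywt reports a decomposition filter length of 16 for db8); the helper names are hypothetical and the cost model is a simplification, not a measurement.

```python
def daubechies_filter_length(n):
    """Number of filter taps of the Daubechies wavelet dbN."""
    return 2 * n


def relative_cost_3d(n_large, n_small):
    """Cost ratio of dbN_large vs dbN_small, assuming cost grows with the cube of the filter length."""
    ratio = daubechies_filter_length(n_large) / daubechies_filter_length(n_small)
    return ratio ** 3


print(relative_cost_3d(8, 4))  # 8.0: db8 (the default) vs db4
print(relative_cost_3d(8, 1))  # 512.0: db8 (the default) vs db1
```

Under this model, switching from the default db8 to db4 reduces the cost by a factor of 8, and switching to db1 (Haar) by a factor of 512.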

Parameters:
  • level (int) – number of levels of the wavelet decomposition. Default is 3.

  • wv (str) – name of the wavelet, chosen among those available in pywt. Default is “db8”.

  • p (float) – order \(p\) of the norm used by the prior (0, 1, or \(\infty\)). Default is 1.

  • device (str) – device on which the wavelet transform is computed. Default is “cpu”.

  • wvdim (int) – dimension of the wavelet transform, can be either 2 or 3. Default is 2.

  • clamp_min (float) – minimum value for the clamping. Default is None.

  • clamp_max (float) – maximum value for the clamping. Default is None.

fn(x, *args, reduce=True, **kwargs)

Computes the regularizer

\[{\regname}_{i,j}(x) = \|(\Psi x)_{i,j}\|_{p}\]

where \(\Psi\) is an orthonormal wavelet transform, \(i\) and \(j\) are the indices of the wavelet sub-bands, and \(\|\cdot\|_{p}\) is the \(p\)-norm, with \(p=0\), \(p=1\), or \(p=\infty\). As mentioned in the class description, only detail coefficients are regularized, and the approximation coefficients are left untouched.

If reduce is set to True, the regularizer is summed over all detail coefficients, yielding

\[\regname(x) = \|\Psi x\|_{p}.\]

If reduce is set to False, the norms of the individual detail sub-bands are returned as a list instead of being summed.

Parameters:
  • x (torch.Tensor) – Variable \(x\) at which the prior is computed.

  • reduce (bool) – if True, the prior is summed over all detail coefficients. Default is True.

Returns:

(torch.Tensor) prior \(g(x)\).
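The role of reduce can be sketched as follows. This NumPy-only sketch assumes an orthonormal Haar transform in place of the class's configurable wavelet, and `subband_norms` and `prior_fn` are hypothetical names. Each detail sub-band \((i, j)\) (decomposition level \(i\), orientation \(j\)) contributes one value; reduce=True returns their sum, reduce=False the individual values.

```python
import numpy as np

S = np.sqrt(2.0)


def haar2d(x):
    """One level of the orthonormal 2D Haar transform (even height and width assumed)."""
    lo, hi = (x[0::2] + x[1::2]) / S, (x[0::2] - x[1::2]) / S
    approx = (lo[:, 0::2] + lo[:, 1::2]) / S
    details = [
        (lo[:, 0::2] - lo[:, 1::2]) / S,
        (hi[:, 0::2] + hi[:, 1::2]) / S,
        (hi[:, 0::2] - hi[:, 1::2]) / S,
    ]
    return approx, details


def subband_norms(x, level=3, p=1):
    """One p-norm per detail sub-band, finest level first in this sketch."""
    norms, approx = [], x
    for _ in range(level):
        approx, details = haar2d(approx)
        for d in details:
            if p == 0:
                norms.append(float(np.count_nonzero(d)))  # l0: number of nonzero coefficients
            elif p == np.inf:
                norms.append(float(np.abs(d).max()))  # l-infinity: largest magnitude
            else:
                norms.append(float(np.sum(np.abs(d) ** p) ** (1.0 / p)))
    return norms


def prior_fn(x, level=3, p=1, reduce=True):
    """Sum of the sub-band norms if reduce is True, otherwise the list of norms."""
    norms = subband_norms(x, level, p)
    return sum(norms) if reduce else norms


impulse = np.zeros((8, 8))
impulse[0, 0] = 1.0
print(len(prior_fn(impulse, reduce=False)))  # 9: three orientations at each of three levels
print(prior_fn(impulse, p=0))                # 9.0: one nonzero coefficient per sub-band
```

Note that, with this per-sub-band definition, reduce=True for \(p=\infty\) sums the maxima of the sub-bands rather than taking a single global maximum.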

prox(x, *args, gamma=1.0, **kwargs)

Computes the proximity operator of the wavelet prior using the denoiser WaveletDenoiser. Only detail coefficients are thresholded.

Parameters:
  • x (torch.Tensor) – Variable \(x\) at which the proximity operator is computed.

  • gamma (float) – step size of the proximity operator. Default is 1.0.

Returns:

(torch.Tensor) proximity operator at \(x\).
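For \(p=1\) and an orthonormal \(\Psi\), the proximity operator of \(\gamma\) times the detail-only \(\ell_1\) penalty is the standard construction: decompose, soft-threshold the detail coefficients with threshold \(\gamma\), keep the approximation coefficients, and invert. The NumPy sketch below shows that construction with a hand-written Haar transform; it is not the WaveletDenoiser code, and `haar2d`, `ihaar2d`, `soft`, and `prox_l1` are hypothetical names.

```python
import numpy as np

S = np.sqrt(2.0)


def haar2d(x):
    """One level of the orthonormal 2D Haar transform (even height and width assumed)."""
    lo, hi = (x[0::2] + x[1::2]) / S, (x[0::2] - x[1::2]) / S
    approx = (lo[:, 0::2] + lo[:, 1::2]) / S
    details = [
        (lo[:, 0::2] - lo[:, 1::2]) / S,
        (hi[:, 0::2] + hi[:, 1::2]) / S,
        (hi[:, 0::2] - hi[:, 1::2]) / S,
    ]
    return approx, details


def ihaar2d(approx, details):
    """Inverse of haar2d (equal to its transpose, since the transform is orthonormal)."""
    lh, hl, hh = details
    h, w = approx.shape
    lo, hi = np.empty((h, 2 * w)), np.empty((h, 2 * w))
    lo[:, 0::2], lo[:, 1::2] = (approx + lh) / S, (approx - lh) / S
    hi[:, 0::2], hi[:, 1::2] = (hl + hh) / S, (hl - hh) / S
    x = np.empty((2 * h, 2 * w))
    x[0::2], x[1::2] = (lo + hi) / S, (lo - hi) / S
    return x


def soft(c, t):
    """Soft-thresholding: the proximity operator of t * |.|, applied elementwise."""
    return np.sign(c) * np.maximum(np.abs(c) - t, 0.0)


def prox_l1(x, gamma=1.0, level=3):
    """Hypothetical prox of gamma * (l1 norm of the Haar detail coefficients)."""
    approx, stack = x, []
    for _ in range(level):
        approx, details = haar2d(approx)
        stack.append([soft(d, gamma) for d in details])  # threshold details only
    for details in reversed(stack):  # rebuild from the coarsest level
        approx = ihaar2d(approx, details)
    return approx


rng = np.random.default_rng(0)
x = rng.standard_normal((8, 8))
print(np.allclose(prox_l1(x, gamma=0.0), x))         # True: no thresholding
print(np.allclose(prox_l1(x, gamma=1e6), x.mean()))  # True: only the approximation survives
```

With a very large gamma every detail coefficient is set to zero, so for an 8×8 input decomposed over 3 levels the output is the constant image equal to the mean of x. For \(p=0\), the analogous operator would use hard thresholding with threshold \(\sqrt{2\gamma}\) instead of soft thresholding.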

psi(x, wavelet='db2', level=2, dimension=2)

Applies the wavelet decomposition to x and returns the coefficients in flattened form.
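One plausible reading of "flattening", sketched below under that assumption, is to decompose x and concatenate all detail coefficients into a single vector. The sketch uses a hand-written orthonormal Haar transform rather than the wavelet and dimension arguments of psi, and `psi_flat` is a hypothetical name; the actual output layout of psi may differ.

```python
import numpy as np

S = np.sqrt(2.0)


def haar2d(x):
    """One level of the orthonormal 2D Haar transform (even height and width assumed)."""
    lo, hi = (x[0::2] + x[1::2]) / S, (x[0::2] - x[1::2]) / S
    approx = (lo[:, 0::2] + lo[:, 1::2]) / S
    details = [
        (lo[:, 0::2] - lo[:, 1::2]) / S,
        (hi[:, 0::2] + hi[:, 1::2]) / S,
        (hi[:, 0::2] - hi[:, 1::2]) / S,
    ]
    return approx, details


def psi_flat(x, level=2):
    """Hypothetical flattened decomposition: (approximation band, 1D vector of all detail coefficients)."""
    approx, parts = x, []
    for _ in range(level):
        approx, details = haar2d(approx)
        parts.extend(d.ravel() for d in details)  # flatten each detail band
    return approx, np.concatenate(parts)


rng = np.random.default_rng(1)
x = rng.standard_normal((8, 8))
approx, flat = psi_flat(x)
print(approx.shape, flat.shape)  # (2, 2) (60,)
# Orthonormality: the energy of x is split between approximation and details.
print(np.isclose(np.sum(approx**2) + np.sum(flat**2), np.sum(x**2)))  # True
```

For an 8×8 input and level=2, the detail vector holds 3 × 16 + 3 × 4 = 60 coefficients, and the remaining 4 sit in the 2×2 approximation band.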

Examples using WaveletPrior:

Image inpainting with wavelet prior

Learned Iterative Soft-Thresholding Algorithm (LISTA) for compressed sensing

Radio interferometric imaging with deepinverse
