Source code for deepinv.optim.epll

import torch.nn as nn
import torch
from deepinv.utils import patch_extractor
from deepinv.optim.utils import conjugate_gradient
from deepinv.models.utils import get_weights_url
from deepinv.optim.utils import GaussianMixtureModel


[docs] class EPLL(nn.Module): r""" Expected Patch Log Likelihood reconstruction method. Reconstruction method based on the minimization problem .. math:: \underset{x}{\arg\min} \; \|y-Ax\|^2 - \sum_i \log p(P_ix) where the first term is a standard :math:`\ell_2` data-fidelity, and the second term represents a patch prior via Gaussian mixture models, where :math:`P_i` is a patch operator that extracts the ith (overlapping) patch from the image. The reconstruction function is based on the approximated half-quadratic splitting method as in Zoran, D., and Weiss, Y. "From learning models of natural image patches to whole image restoration." (ICCV 2011). :param None, deepinv.optim.utils.GaussianMixtureModel GMM: Gaussian mixture defining the distribution on the patch space. ``None`` creates a GMM with n_components components of dimension accordingly to the arguments patch_size and channels. :param int n_components: number of components of the generated GMM if GMM is ``None``. :param str, None pretrained: Path to pretrained weights of the GMM with file ending ``.pt``. None for no pretrained weights, ``"download"`` for pretrained weights on the BSDS500 dataset, ``"GMM_lodopab_small"`` for the weights from the limited-angle CT example. See :ref:`pretrained-weights <pretrained-weights>` for more details. :param int patch_size: patch size. :param int channels: number of color channels (e.g. 1 for gray-valued images and 3 for RGB images) :param str device: defines device (``cpu`` or ``cuda``) """ def __init__( self, GMM=None, n_components=200, pretrained="download", patch_size=6, channels=1, device="cpu", ): super(EPLL, self).__init__() if GMM is None: self.GMM = GaussianMixtureModel( n_components, patch_size**2 * channels, device=device ) else: self.GMM = GMM self.patch_size = patch_size if pretrained: if pretrained[-3:] == ".pt": ckpt = torch.load(pretrained) else: if pretrained.startswith("GMM_lodopab_small"): assert patch_size == 3 assert channels == 1 file_name = pretrained + ".pt" elif ( (pretrained == "GMM_BSDS_gray" or pretrained == "download") and patch_size == 6 and channels == 1 ): file_name = "GMM_BSDS_gray2.pt" elif ( (pretrained == "GMM_BSDS_color" or pretrained == "download") and patch_size == 6 and channels == 3 ): file_name = "GMM_BSDS_color2.pt" else: raise ValueError( "No pretrained weights found for this configuration!" ) url = get_weights_url(model_name="EPLL", file_name=file_name) ckpt = torch.hub.load_state_dict_from_url( url, map_location=lambda storage, loc: storage, file_name=file_name ) self.load_state_dict(ckpt)
[docs] def forward(self, y, physics, sigma=None, x_init=None, betas=None, batch_size=-1): r""" Approximated half-quadratic splitting method for image reconstruction as proposed by Zoran and Weiss. :param torch.Tensor y: tensor of observations. Shape: batch size x ... :param torch.Tensor, None x_init: tensor of initializations. If ``None`` uses initializes with the adjoint of the forward operator. Shape: batch size x channels x height x width :param deepinv.physics.LinearPhysics physics: Forward linear operator. :param list[float] betas: parameters from the half-quadratic splitting. ``None`` uses the standard choice ``[1,4,8,16,32]/sigma_sq`` :param int batch_size: batching the patch estimations for large images. No effect on the output, but a small value reduces the memory consumption but might increase the computation time. -1 for considering all patches at once. """ x_init = physics.A_adjoint(y) if x_init is None else x_init if sigma is None: if hasattr(physics.noise_model, "sigma"): sigma = physics.noise_model.sigma else: raise ValueError( "Noise level sigma has to be provided if not present in the physics model." ) if betas is None: # default choice as suggested in Parameswaran et al. "Accelerating GMM-Based Patch Priors for Image Restoration: Three Ingredients for a 100× Speed-Up" betas = [beta / sigma**2 for beta in [1.0, 4.0, 8.0, 16.0, 32.0]] if y.shape[0] > 1: # vectorization over a batch of images not implemented.... return torch.cat( [ self.reconstruction( y[i : i + 1], x_init[i : i + 1], betas=betas, batch_size=batch_size, ) for i in range(y.shape[0]) ], 0, ) x = x_init Aty = physics.A_adjoint(y) for beta in betas: x = self._reconstruction_step(Aty, x, sigma**2, beta, physics, batch_size) return x
[docs] def negative_log_likelihood(self, x): r""" Takes patches and returns the negative log likelihood of the GMM for each patch. :param torch.Tensor x: tensor of patches of shape batch_size x number of patches per batch x patch_dimensions """ B, n_patches = x.shape[0:2] logpz = self.GMM(x.view(B * n_patches, -1)) return logpz.view(B, n_patches)
def _reconstruction_step(self, Aty, x, sigma_sq, beta, physics, batch_size): # precomputations for GMM with covariance regularization self.GMM.set_cov_reg(1.0 / beta) N, M = x.shape[2:4] total_patch_number = (N - self.patch_size + 1) * (M - self.patch_size + 1) if batch_size == -1 or batch_size > total_patch_number: batch_size = total_patch_number # compute sum P_i^T z and sum P_i^T P_i on the fly with batching x_tilde_flattened = torch.zeros_like(x).reshape(-1) patch_multiplicities = torch.zeros_like(x).reshape(-1) # batching loop over all patches in the image ind = 0 while ind < total_patch_number: # extract patches n_patches = min(batch_size, total_patch_number - ind) patch_inds = torch.LongTensor(range(ind, ind + n_patches)).to(x.device) patches, linear_inds = patch_extractor( x, n_patches, self.patch_size, position_inds_linear=patch_inds ) patches = patches.reshape(patches.shape[0] * patches.shape[1], -1) linear_inds = linear_inds.reshape(patches.shape[0], -1) # Gaussian selection k_star = self.GMM.classify(patches, cov_regularization=True) # Patch estimation estimation_matrices = torch.bmm( self.GMM.get_cov_inv_reg(), self.GMM.get_cov() ) estimation_matrices_k_star = estimation_matrices[k_star] patch_estimates = torch.bmm( estimation_matrices_k_star, patches[:, :, None] ).reshape(patches.shape[0], patches.shape[1]) # update on-the-fly parameters # the following two lines are the same like # patch_multiplicities[linear_inds] += 1.0 # x_tilde_flattened[linear_inds] += patch_estimates # where values of multiple indices are accumulated. patch_multiplicities.index_put_( (linear_inds,), torch.ones_like(patch_estimates), accumulate=True ) x_tilde_flattened.index_put_( (linear_inds,), patch_estimates, accumulate=True ) ind = ind + n_patches # compute x_tilde x_tilde_flattened /= patch_multiplicities # Image estimation by CG method rhs = Aty + beta * sigma_sq * x_tilde_flattened.view(x.shape) op = lambda im: physics.A_adjoint(physics.A(im)) + beta * sigma_sq * im hat_x = conjugate_gradient(op, rhs, max_iter=1e2, tol=1e-5) return hat_x