GaussianMixtureModel

class deepinv.optim.utils.GaussianMixtureModel(n_components, dimension, device='cpu')[source]

Bases: Module

Gaussian mixture model including parameter estimation.

Implements a Gaussian Mixture Model, its negative log likelihood function and an EM algorithm for parameter estimation.

Parameters:
  • n_components (int) – number of components of the GMM

  • dimension (int) – data dimension

  • device (str) – device on which the model parameters are stored: 'cpu' or a GPU device such as 'cuda'.

classify(x, cov_regularization=False)[source]

Returns the index of the most likely component for each input sample.

Parameters:
  • x (torch.Tensor) – input data of shape batch_dimension x dimension

  • cov_regularization (bool) – whether to use the regularized covariance matrices
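As a rough illustration of what "most likely component" means, the sketch below picks the index of the largest per-component score for each sample. It is plain Python, not the library's code; the function name and the nested-list input are hypothetical, and whether the library's scores include the mixture weights is not stated here.

```python
def classify(component_scores):
    """Return, for each sample, the index of the highest-scoring component.

    component_scores: batch_dimension x n_components nested list
    (hypothetical input; e.g. per-component log-likelihood values).
    """
    return [max(range(len(row)), key=row.__getitem__) for row in component_scores]


labels = classify([[-1.0, -0.5], [-0.2, -3.0]])
```

Ties resolve to the first maximal index in this sketch.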

component_log_likelihoods(x, cov_regularization=False)[source]

Returns a tensor containing the log-likelihood values of x for each component.

Parameters:
  • x (torch.Tensor) – input data of shape batch_dimension x dimension

  • cov_regularization (bool) – whether to use the regularized covariance matrices
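For intuition, here is a minimal sketch of a per-component Gaussian log density. It assumes diagonal covariances for brevity (the class itself works with full n_components x dimension x dimension matrices), uses plain Python lists instead of tensors, and all names are illustrative rather than part of the library.

```python
import math


def component_log_likelihoods(x, means, variances):
    """Log density of each point under each diagonal Gaussian component.

    x: batch_dimension x dimension nested list
    means, variances: n_components x dimension nested lists (assumed diagonal covariances)
    Returns a batch_dimension x n_components nested list.
    """
    out = []
    for point in x:
        d = len(point)
        row = []
        for mu, var in zip(means, variances):
            log_det = sum(math.log(v) for v in var)
            mahalanobis = sum((p - m) ** 2 / v for p, m, v in zip(point, mu, var))
            row.append(-0.5 * (d * math.log(2 * math.pi) + log_det + mahalanobis))
        out.append(row)
    return out


values = component_log_likelihoods(
    [[0.0, 0.0]],
    means=[[0.0, 0.0], [1.0, 1.0]],
    variances=[[1.0, 1.0], [1.0, 1.0]],
)
```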

fit(dataloader, max_iters=100, stopping_criterion=None, data_init=True, cov_regularization=1e-05, verbose=False)[source]

Batched Expectation Maximization algorithm for parameter estimation.

Parameters:
  • dataloader (torch.utils.data.DataLoader) – containing the data

  • max_iters (int) – maximum number of iterations

  • stopping_criterion (float) – stop when the decrease of the objective between iterations is smaller than this number. Set to None to perform exactly max_iters iterations

  • data_init (bool) – True to initialize mu with the first data points, False to use the current values as initialization

  • verbose (bool) – whether to print progress information to the console
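The following sketch shows how a batched EM iteration can be organized: statistics are accumulated over all batches (E-step), then the parameters are updated in closed form (M-step). It is a simplified 1-D illustration in plain Python, not the library's implementation. In particular, adding cov_regularization to the updated variance is an assumption about the role of that parameter, and em_step is a hypothetical helper.

```python
import math


def em_step(batches, weights, means, variances, cov_regularization=1e-5):
    """One EM iteration over a re-iterable collection of batches of scalars."""
    k = len(weights)
    resp_sum = [0.0] * k  # sum of responsibilities per component
    x_sum = [0.0] * k     # responsibility-weighted sum of x
    xx_sum = [0.0] * k    # responsibility-weighted sum of x**2
    n = 0
    objective = 0.0       # accumulated negative log likelihood
    for batch in batches:
        for x in batch:
            # E-step: log(weight * density) for every component
            logs = [
                math.log(w) - 0.5 * (math.log(2 * math.pi * v) + (x - m) ** 2 / v)
                for w, m, v in zip(weights, means, variances)
            ]
            top = max(logs)
            log_norm = top + math.log(sum(math.exp(l - top) for l in logs))
            objective -= log_norm
            for j in range(k):
                r = math.exp(logs[j] - log_norm)
                resp_sum[j] += r
                x_sum[j] += r * x
                xx_sum[j] += r * x * x
            n += 1
    # M-step: closed-form updates from the accumulated statistics
    new_weights = [s / n for s in resp_sum]
    new_means = [x_sum[j] / resp_sum[j] for j in range(k)]
    new_vars = [
        xx_sum[j] / resp_sum[j] - new_means[j] ** 2 + cov_regularization  # assumed regularization
        for j in range(k)
    ]
    return new_weights, new_means, new_vars, objective / n


def fit(batches, weights, means, variances, max_iters=100, stopping_criterion=None):
    """Repeat EM steps; stop early once the objective decrease is small enough."""
    previous = None
    for _ in range(max_iters):
        weights, means, variances, objective = em_step(batches, weights, means, variances)
        if stopping_criterion is not None and previous is not None:
            if previous - objective < stopping_criterion:
                break
        previous = objective
    return weights, means, variances


batches = [[-2.1, -1.9, -2.0], [2.0, 2.1, 1.9]]
weights, means, variances = fit(
    batches, [0.5, 0.5], [-1.0, 1.0], [1.0, 1.0], max_iters=50, stopping_criterion=1e-8
)
```

With stopping_criterion=None the loop in this sketch runs for exactly max_iters iterations, mirroring the parameter description above.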

forward(x)[source]

Evaluates the negative log-likelihood function of the mixture at x.

Parameters:

x (torch.Tensor) – input data of shape batch_dimension x dimension
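As a sketch of how a mixture's negative log likelihood can be evaluated stably, the snippet below combines per-component log-likelihood values with the mixture weights through a log-sum-exp. It is plain Python with illustrative names; returning one value per sample (rather than an aggregate) is an assumption, since the return shape is not stated here.

```python
import math


def negative_log_likelihood(component_log_likelihoods, weights):
    """Per-sample negative log likelihood of a mixture.

    component_log_likelihoods: batch_dimension x n_components nested list
    weights: positive mixture weights summing to one
    """
    result = []
    for row in component_log_likelihoods:
        terms = [math.log(w) + ll for w, ll in zip(weights, row)]
        top = max(terms)  # subtract the maximum before exponentiating for stability
        result.append(-(top + math.log(sum(math.exp(t - top) for t in terms))))
    return result


nll = negative_log_likelihood([[math.log(0.2), math.log(0.4)]], [0.5, 0.5])
far = negative_log_likelihood([[-1000.0, -1001.0]], [0.5, 0.5])
```

The second call shows why the maximum is subtracted first: exponentiating -1000 directly would underflow to zero and the logarithm would fail.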

get_cov()[source]

get method for covariances

get_cov_inv_reg()[source]

get method for the inverses of the regularized covariances

get_weights()[source]

get method for weights

load_state_dict(*args, **kwargs)[source]

Override load_state_dict to maintain internal parameters.

set_cov(cov)[source]

Sets the covariance parameters to cov and updates the stored log-determinants and inverses accordingly.

Parameters:

cov (torch.Tensor) – new covariance matrices in an n_components x dimension x dimension tensor
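The idea of keeping derived quantities in sync with the covariances can be sketched as follows. This is a hypothetical helper restricted to 2 x 2 matrices in plain Python so that the inverse and determinant can be written out explicitly. The attribute names are illustrative and do not describe the library's internals.

```python
import math


class CovarianceCache:
    """Hypothetical helper: stores 2x2 covariances with cached inverses and log-determinants."""

    def set_cov(self, cov):
        self.cov = [[list(row) for row in matrix] for matrix in cov]
        self.cov_inv = []
        self.logdet_cov = []
        for (a, b), (c, d) in self.cov:
            det = a * d - b * c
            if a <= 0 or det <= 0:
                raise ValueError("each covariance must be positive definite")
            # Closed-form inverse of a 2x2 matrix
            self.cov_inv.append([[d / det, -b / det], [-c / det, a / det]])
            self.logdet_cov.append(math.log(det))


cache = CovarianceCache()
cache.set_cov([[[2.0, 0.0], [0.0, 0.5]]])
```

Recomputing the inverse and log-determinant once per update avoids repeating that work on every likelihood evaluation.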

set_cov_reg(reg)[source]

Sets the covariance regularization parameter used when evaluating the model. Needed for EPLL.

Parameters:

reg (float) – covariance regularization parameter

set_weights(weights)[source]

Sets the weight parameter while ensuring that the weights are non-negative and sum to one.

Parameters:

weights (torch.Tensor) – non-zero weight tensor of size n_components with non-negative entries
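One way to enforce non-negativity and summation to one is shown below. This is a sketch under assumptions: the clamping and the error on an all-zero input are illustrative choices, and how the library handles these cases is not stated here.

```python
def normalize_weights(weights):
    """Clamp negative entries to zero and rescale so the weights sum to one."""
    clipped = [max(w, 0.0) for w in weights]
    total = sum(clipped)
    if total <= 0:
        raise ValueError("weights must contain at least one positive entry")
    return [w / total for w in clipped]


normalized = normalize_weights([2.0, 1.0, 1.0])
```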