GaussianMixtureModel
- class deepinv.optim.utils.GaussianMixtureModel(n_components, dimension, device='cpu')
Bases: torch.nn.Module
Gaussian mixture model including parameter estimation.
Implements a Gaussian mixture model, its negative log-likelihood function, and an expectation-maximization (EM) algorithm for parameter estimation.
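For reference, the quantities such a model works with can be written as follows. The notation (K for n_components, d for dimension, weights w_k, means \mu_k, covariances \Sigma_k) is chosen here for illustration and is not taken from the deepinv source.

```latex
p(x) = \sum_{k=1}^{K} w_k \, \mathcal{N}(x; \mu_k, \Sigma_k), \qquad w_k \ge 0, \quad \sum_{k=1}^{K} w_k = 1,

\log \mathcal{N}(x; \mu_k, \Sigma_k) = -\tfrac{1}{2} (x - \mu_k)^\top \Sigma_k^{-1} (x - \mu_k) - \tfrac{1}{2} \log\det \Sigma_k - \tfrac{d}{2} \log(2\pi),

-\log p(x) = -\log \sum_{k=1}^{K} \exp\bigl(\log w_k + \log \mathcal{N}(x; \mu_k, \Sigma_k)\bigr).
```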
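- Parameters:
n_components (int) – number of components of the mixture
dimension (int) – dimension of the data
device (str) – device on which the parameters are stored, e.g. 'cpu' or 'cuda'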
- classify(x, cov_regularization=False)
Returns, for each sample of x, the index of the most likely component.
- Parameters:
x (torch.Tensor) – input data of shape batch_dimension x dimension
cov_regularization (bool) – whether to use the regularized covariance matrices (see set_cov_reg)
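A minimal sketch of what "most likely component" can mean, written in plain PyTorch rather than against deepinv: the hypothetical function classify_sketch picks, for each sample, the component that maximizes log w_k + log N(x; mu_k, Sigma_k). Including the log-weight term is an assumption made here; this documentation does not say whether classify uses the weights.

```python
import torch

def classify_sketch(x, weights, mu, cov):
    """Hypothetical stand-in for classify (not the deepinv API).

    x: (batch, d), weights: (K,), mu: (K, d), cov: (K, d, d).
    """
    # One Gaussian per component: batch_shape (K,), event_shape (d,)
    dist = torch.distributions.MultivariateNormal(mu, covariance_matrix=cov)
    # x[:, None, :] broadcasts against the K components -> (batch, K)
    log_comp = dist.log_prob(x[:, None, :])
    # Assumption: "most likely" = largest log w_k + log N(x; mu_k, cov_k)
    scores = torch.log(weights)[None, :] + log_comp
    return torch.argmax(scores, dim=1)

# Two unit-covariance components centred at (0, 0) and (4, 4)
weights = torch.tensor([0.5, 0.5])
mu = torch.tensor([[0.0, 0.0], [4.0, 4.0]])
cov = torch.eye(2).repeat(2, 1, 1)
x = torch.tensor([[0.1, -0.2], [3.9, 4.2]])
labels = classify_sketch(x, weights, mu, cov)
```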
- component_log_likelihoods(x, cov_regularization=False)
Returns a tensor containing the log-likelihood of each sample of x under each individual component.
- Parameters:
x (torch.Tensor) – input data of shape batch_dimension x dimension
cov_regularization (bool) – whether to use the regularized covariance matrices (see set_cov_reg)
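The per-component log-likelihoods can be sketched with one batched matrix product, assuming the covariances are passed in directly instead of being read from cached inverses. component_log_likelihoods_sketch is a hypothetical name, and the (batch, n_components) output layout is a choice made for this sketch.

```python
import math
import torch

def component_log_likelihoods_sketch(x, mu, cov):
    """Hypothetical helper: log N(x_b; mu_k, cov_k) for every sample b and component k.

    x: (batch, d), mu: (K, d), cov: (K, d, d) symmetric positive definite.
    Returns a (batch, K) tensor.
    """
    d = x.shape[1]
    cov_inv = torch.linalg.inv(cov)                       # (K, d, d)
    logdet = torch.logdet(cov)                            # (K,)
    centered = x[None, :, :] - mu[:, None, :]             # (K, batch, d)
    # Mahalanobis term (x - mu_k)^T cov_k^{-1} (x - mu_k) via one batched matmul
    maha = torch.sum(torch.bmm(centered, cov_inv) * centered, dim=2)  # (K, batch)
    log_lik = -0.5 * (maha + logdet[:, None] + d * math.log(2 * math.pi))
    return log_lik.T
```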
- fit(dataloader, max_iters=100, stopping_criterion=None, data_init=True, cov_regularization=1e-05, verbose=False)
Batched Expectation Maximization algorithm for parameter estimation.
- Parameters:
dataloader (torch.utils.data.DataLoader) – data loader providing the training data
max_iters (int) – maximum number of iterations
stopping_criterion (float) – stop when the decrease of the objective is smaller than this value; if None, exactly max_iters iterations are performed
data_init (bool) – if True, initialize the means mu with the first data points; if False, use the current values as initialization
cov_regularization (float) – regularization parameter for the estimated covariance matrices
verbose (bool) – whether to print progress information to the console
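The following is a minimal sketch of a batched EM loop under stated assumptions, not deepinv's implementation. The E-step runs batch by batch and only accumulates sufficient statistics (soft counts, first and second moments); the M-step then updates the weights, means and covariances once per pass over the data. The names em_fit_sketch and log_joint are hypothetical, a plain list of tensors stands in for the DataLoader, and adding cov_regularization to the covariance diagonal is an assumed interpretation of that parameter.

```python
import torch

def log_joint(x, weights, mu, cov):
    """log w_k + log N(x_b; mu_k, cov_k), shape (batch, K)."""
    dist = torch.distributions.MultivariateNormal(mu, covariance_matrix=cov)
    return torch.log(weights)[None, :] + dist.log_prob(x[:, None, :])

def em_fit_sketch(batches, weights, mu, cov, max_iters=100,
                  stopping_criterion=None, data_init=True, cov_regularization=1e-5):
    """Hypothetical batched EM; `batches` is any re-iterable of (batch, d) tensors."""
    K, d = mu.shape
    if data_init:
        # Initialize the means with the first K data points
        mu = next(iter(batches))[:K].clone()
    eye = torch.eye(d, dtype=mu.dtype)
    prev_nll = None
    for _ in range(max_iters):
        n_k = torch.zeros(K, dtype=mu.dtype)        # soft counts
        s1 = torch.zeros(K, d, dtype=mu.dtype)      # sum_b r_bk * x_b
        s2 = torch.zeros(K, d, d, dtype=mu.dtype)   # sum_b r_bk * x_b x_b^T
        nll_sum, n = 0.0, 0
        for x in batches:
            # E-step on this batch: responsibilities r_bk
            lj = log_joint(x, weights, mu, cov)
            lse = torch.logsumexp(lj, dim=1, keepdim=True)
            resp = torch.exp(lj - lse)
            n_k += resp.sum(dim=0)
            s1 += resp.T @ x
            s2 += torch.einsum("bk,bi,bj->kij", resp, x, x)
            nll_sum -= lse.sum().item()
            n += x.shape[0]
        nll = nll_sum / n  # mean NLL under the parameters used in this E-step
        # M-step from the accumulated statistics
        n_safe = n_k.clamp_min(1e-10)                # guard against empty components
        weights = n_k / n_k.sum()
        mu = s1 / n_safe[:, None]
        cov = s2 / n_safe[:, None, None] - mu[:, :, None] * mu[:, None, :]
        cov = 0.5 * (cov + cov.transpose(1, 2))      # remove round-off asymmetry
        cov = cov + cov_regularization * eye         # assumed: ridge on the diagonal
        if (stopping_criterion is not None and prev_nll is not None
                and prev_nll - nll < stopping_criterion):
            break
        prev_nll = nll
    return weights, mu, cov

# Usage: two well-separated 2-D clusters, interleaved and split into batches
torch.manual_seed(0)
a = 0.5 * torch.randn(200, 2, dtype=torch.float64) + 5.0
b = 0.5 * torch.randn(200, 2, dtype=torch.float64) - 5.0
data = torch.stack([a, b], dim=1).reshape(400, 2)
batches = list(data.split(100))
w0 = torch.full((2,), 0.5, dtype=torch.float64)
mu0 = torch.zeros(2, 2, dtype=torch.float64)
cov0 = torch.eye(2, dtype=torch.float64).repeat(2, 1, 1)
w, mu, cov = em_fit_sketch(batches, w0, mu0, cov0, stopping_criterion=1e-8)
```

Accumulating statistics instead of storing all responsibilities keeps memory proportional to a single batch, which is the usual motivation for running EM over a data loader.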
- forward(x)
Evaluates the negative log-likelihood of x under the mixture model.
- Parameters:
x (torch.Tensor) – input data of shape batch_dimension x dimension
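The negative log-likelihood of a mixture is usually evaluated with a log-sum-exp over the components for numerical stability. A sketch in plain PyTorch, with the hypothetical name gmm_nll_sketch and assuming one value is returned per sample (whether the library reduces over the batch is not stated here):

```python
import torch

def gmm_nll_sketch(x, weights, mu, cov):
    """Hypothetical helper: -log sum_k w_k N(x_b; mu_k, cov_k) for each sample b."""
    dist = torch.distributions.MultivariateNormal(mu, covariance_matrix=cov)
    log_comp = dist.log_prob(x[:, None, :])                  # (batch, K)
    # log-sum-exp avoids underflow when every component density is tiny
    return -torch.logsumexp(torch.log(weights)[None, :] + log_comp, dim=1)
```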
- set_cov(cov)
Sets the covariance parameters to cov and keeps the corresponding log-determinants and inverses up to date.
- Parameters:
cov (torch.Tensor) – new covariance matrices in a n_components x dimension x dimension tensor
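Keeping the inverses and log-determinants alongside the covariances avoids recomputing them at every likelihood evaluation. One way to maintain them, sketched with a hypothetical CovCacheSketch class whose attribute names are illustrative only, is a single Cholesky factorization per component, which yields both quantities and raises an error if a matrix is not positive definite.

```python
import torch

class CovCacheSketch:
    """Hypothetical container illustrating the bookkeeping behind set_cov."""

    def __init__(self, cov):
        self.set_cov(cov)

    def set_cov(self, cov):
        # cov: (n_components, dimension, dimension), assumed symmetric positive definite
        self.cov = cov.clone()
        # One Cholesky factor per component; raises if a matrix is not positive definite
        chol = torch.linalg.cholesky(self.cov)
        # log det(cov) = 2 * sum(log(diag(L)))
        self.logdet_cov = 2.0 * torch.log(torch.diagonal(chol, dim1=-2, dim2=-1)).sum(-1)
        # cov^{-1} = (L L^T)^{-1}, computed from the same factor
        self.cov_inv = torch.cholesky_inverse(chol)
```

Calling torch.linalg.inv and torch.logdet separately would also work; the Cholesky route simply factors each matrix once.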
- set_cov_reg(reg)
Sets the covariance regularization parameter used when the model is evaluated with regularized covariance matrices (cov_regularization=True). Needed for EPLL.
- Parameters:
reg (float) – covariance regularization parameter
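Assuming the regularization parameter is added to the diagonal of each covariance matrix (for example, EPLL-style patch priors add a noise variance to each component covariance), the regularized quantities could be computed as below. regularized_cov_stats is a hypothetical helper; the example shows that a singular covariance becomes invertible after regularization.

```python
import torch

def regularized_cov_stats(cov, reg):
    """Hypothetical helper: inverse and log-determinant of cov + reg * I."""
    d = cov.shape[-1]
    # Assumption: the regularization is added to the diagonal of every component
    cov_reg = cov + reg * torch.eye(d, dtype=cov.dtype, device=cov.device)[None, :, :]
    return torch.linalg.inv(cov_reg), torch.logdet(cov_reg)

# A singular covariance (zero variance in the second coordinate)
cov = torch.tensor([[[1.0, 0.0], [0.0, 0.0]]], dtype=torch.float64)
# With reg = 0.5, cov + reg * I = diag(1.5, 0.5), so its inverse is diag(2/3, 2)
inv, logdet = regularized_cov_stats(cov, reg=0.5)
```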
- set_weights(weights)
Sets the weight parameter, ensuring that the weights are non-negative and sum to one.
- Parameters:
weights (torch.Tensor) – weight tensor of size n_components with non-negative entries that are not all zero
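A sketch of the checks and normalization described above, using the hypothetical helper normalize_weights. Dividing by the sum is one natural way to enforce summation to one and is an assumption here, not necessarily what set_weights does internally.

```python
import torch

def normalize_weights(weights):
    """Hypothetical helper: validate mixture weights and rescale them to sum to one."""
    if torch.any(weights < 0):
        raise ValueError("weights must be non-negative")
    total = weights.sum()
    if total <= 0:
        raise ValueError("weights must not all be zero")
    return weights / total

w = normalize_weights(torch.tensor([1.0, 1.0, 2.0]))
```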