Bregman
- class deepinv.optim.bregman.Bregman(phi=None)
Bases: Potential
Module for the Bregman framework with a convex Bregman potential \(\phi\). Provides methods to compute the potential, its gradient, its convex conjugate \(\phi^*\), the gradient of the conjugate, and the associated Bregman divergence.
- Parameters:
phi (callable) – Potential function \(\phi(x)\) to be used in the Bregman framework.
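To make the four quantities in the class description concrete, here is a framework-free sketch for one common choice of potential, the negative entropy \(\phi(x) = \sum_i x_i \log x_i\) on strictly positive vectors. It assumes plain Python lists in place of torch.Tensor, and the class and method names (NegEntropyPotential, fn, grad, fn_conj, grad_conj) are hypothetical illustrations, not the deepinv API. It checks that \(\nabla \phi^*\) inverts \(\nabla \phi\) and that the Fenchel-Young equality holds.

```python
import math
from typing import List

Vector = List[float]


class NegEntropyPotential:
    """Hypothetical stand-in for a Bregman potential (not the deepinv class).

    phi(x) = sum_i x_i * log(x_i), defined for strictly positive x.
    Method names are illustrative assumptions.
    """

    def fn(self, x: Vector) -> float:
        # phi(x) = sum_i x_i log x_i
        return sum(xi * math.log(xi) for xi in x)

    def grad(self, x: Vector) -> Vector:
        # (grad phi)(x)_i = log x_i + 1
        return [math.log(xi) + 1.0 for xi in x]

    def fn_conj(self, u: Vector) -> float:
        # phi*(u) = sup_x <u, x> - phi(x) = sum_i exp(u_i - 1)
        return sum(math.exp(ui - 1.0) for ui in u)

    def grad_conj(self, u: Vector) -> Vector:
        # (grad phi*)(u)_i = exp(u_i - 1), the inverse map of grad phi
        return [math.exp(ui - 1.0) for ui in u]


phi = NegEntropyPotential()
x = [0.2, 0.3, 0.5]

# grad phi* undoes grad phi, so this recovers x up to rounding
x_back = phi.grad_conj(phi.grad(x))

# Fenchel-Young equality: phi(x) + phi*(grad phi(x)) = <x, grad phi(x)>
lhs = phi.fn(x) + phi.fn_conj(phi.grad(x))
rhs = sum(xi * gi for xi, gi in zip(x, phi.grad(x)))
```

The Euclidean potential \(\phi(x) = \tfrac{1}{2}\|x\|^2\) is the degenerate case in which both gradient maps are the identity; the negative entropy is shown here because its maps are non-trivial.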
- MD_step(x, grad, *args, gamma=1.0, **kwargs)
Performs a Mirror Descent step \(x^{+} = \nabla \phi^*(\nabla \phi(x) - \gamma \nabla f(x))\), where \(\nabla f(x)\) is the gradient passed as grad.
- Parameters:
x (torch.Tensor) – Variable \(x\) at which the step is performed.
grad (torch.Tensor) – Gradient \(\nabla f(x)\) of the function being minimized, evaluated at \(x\).
gamma (float) – Step size \(\gamma\).
- Returns:
(torch.Tensor) Updated variable \(x\).
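The mirror descent update maps \(x\) to the dual space with \(\nabla \phi\), takes a gradient step there, and maps back with \(\nabla \phi^*\). The sketch below is an illustrative helper, assuming plain Python lists instead of tensors and a hypothetical md_step function that receives the two gradient maps explicitly; it is not the deepinv implementation. With the Euclidean potential the step reduces to plain gradient descent, and with the negative entropy it becomes the multiplicative update \(x_i e^{-\gamma g_i}\).

```python
import math
from typing import Callable, List

Vector = List[float]


def md_step(
    x: Vector,
    grad: Vector,
    grad_phi: Callable[[Vector], Vector],
    grad_phi_conj: Callable[[Vector], Vector],
    gamma: float = 1.0,
) -> Vector:
    """Hypothetical mirror descent step: grad_phi_conj(grad_phi(x) - gamma * grad)."""
    # Map x to the dual (mirror) space
    u = grad_phi(x)
    # Gradient step in the dual space
    u = [ui - gamma * gi for ui, gi in zip(u, grad)]
    # Map back to the primal space through the conjugate's gradient
    return grad_phi_conj(u)


def identity(v: Vector) -> Vector:
    # Both gradient maps of phi(x) = 0.5 * ||x||^2
    return list(v)


def ent_grad(v: Vector) -> Vector:
    # Gradient of the negative entropy sum_i v_i log v_i (needs v_i > 0)
    return [math.log(vi) + 1.0 for vi in v]


def ent_grad_conj(u: Vector) -> Vector:
    # Gradient of its conjugate, the inverse of ent_grad
    return [math.exp(ui - 1.0) for ui in u]


x = [1.0, 2.0]
g = [0.5, -1.0]

x_gd = md_step(x, g, identity, identity, gamma=0.1)       # plain gradient descent
x_eg = md_step(x, g, ent_grad, ent_grad_conj, gamma=0.1)  # multiplicative update
```

Here x_gd is approximately [0.95, 2.1], while x_eg stays strictly positive, which is why the entropic potential is a common choice for problems constrained to the positive orthant.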
- div(x, y, *args, **kwargs)
Computes the Bregman divergence \(D_\phi(x,y)\) with Bregman potential \(\phi\).
- Parameters:
x (torch.Tensor) – Left variable \(x\) at which the divergence is computed.
y (torch.Tensor) – Right variable \(y\) at which the divergence is computed.
- Returns:
(torch.Tensor) Divergence \(\phi(x) - \phi(y) - \langle \nabla \phi(y), x-y \rangle\).
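The divergence formula can be checked directly against two familiar cases. The sketch below assumes plain Python lists and a hypothetical helper bregman_div that takes \(\phi\) and \(\nabla \phi\) as callables; it is an illustration, not the deepinv method. With \(\phi(x) = \tfrac{1}{2}\|x\|^2\) the divergence is \(\tfrac{1}{2}\|x - y\|^2\), and with the negative entropy it is the generalized Kullback-Leibler divergence \(\sum_i x_i \log(x_i / y_i) - x_i + y_i\).

```python
import math
from typing import Callable, List

Vector = List[float]


def bregman_div(
    x: Vector,
    y: Vector,
    phi: Callable[[Vector], float],
    grad_phi: Callable[[Vector], Vector],
) -> float:
    """Illustrative helper: D_phi(x, y) = phi(x) - phi(y) - <grad phi(y), x - y>."""
    inner = sum(gi * (xi - yi) for gi, xi, yi in zip(grad_phi(y), x, y))
    return phi(x) - phi(y) - inner


def sq_half(v: Vector) -> float:
    # Euclidean potential 0.5 * ||v||^2
    return 0.5 * sum(vi * vi for vi in v)


def sq_half_grad(v: Vector) -> Vector:
    return list(v)


def neg_entropy(v: Vector) -> float:
    # sum_i v_i log v_i, defined for v_i > 0
    return sum(vi * math.log(vi) for vi in v)


def neg_entropy_grad(v: Vector) -> Vector:
    return [math.log(vi) + 1.0 for vi in v]


x = [0.2, 0.3, 0.5]
y = [0.4, 0.4, 0.2]

d_euc = bregman_div(x, y, sq_half, sq_half_grad)         # equals 0.5 * ||x - y||^2
d_kl = bregman_div(x, y, neg_entropy, neg_entropy_grad)  # generalized KL divergence

# Closed form of the generalized KL divergence, for comparison
kl_closed = sum(xi * math.log(xi / yi) - xi + yi for xi, yi in zip(x, y))
```

Note that the divergence is generally not symmetric: swapping x and y in the entropic case gives a different value, which is why the parameters are described as the left and right variables.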