Bregman

class deepinv.optim.bregman.Bregman(phi=None)

Bases: Potential

Module for the Bregman framework with convex Bregman potential \(\phi\). It provides methods to compute the potential, its gradient, its convex conjugate \(\phi^*\), the gradient of the conjugate, and the associated Bregman divergence.

Parameters:

phi (callable) – Potential function \(\phi(x)\) to be used in the Bregman framework.
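To make the list of methods above concrete, here is a minimal stand-alone sketch (NumPy, not the deepinv class itself) for one common choice of potential, the negative entropy \(\phi(x) = \sum_i x_i \log x_i\) on the positive orthant. The class name `NegEntropyPotential` and its method names are hypothetical; the point is how \(\phi\), \(\nabla \phi\), \(\phi^*\) and \(\nabla \phi^*\) fit together, in particular that \(\nabla \phi^*\) inverts \(\nabla \phi\).

```python
import numpy as np


class NegEntropyPotential:
    """Hypothetical sketch (not deepinv's API): phi(x) = sum_i x_i log x_i, x > 0."""

    def fn(self, x):
        # phi(x)
        return np.sum(x * np.log(x))

    def grad(self, x):
        # grad phi(x) = log(x) + 1, maps the primal space to the dual (mirror) space
        return np.log(x) + 1.0

    def conjugate(self, y):
        # phi*(y) = sup_x <x, y> - phi(x), attained at x = exp(y - 1),
        # which gives phi*(y) = sum_i exp(y_i - 1)
        return np.sum(np.exp(y - 1.0))

    def grad_conjugate(self, y):
        # grad phi*(y) = exp(y - 1), the inverse of grad phi
        return np.exp(y - 1.0)


pot = NegEntropyPotential()
x = np.array([0.2, 0.3, 0.5])

# grad phi* undoes grad phi
x_back = pot.grad_conjugate(pot.grad(x))

# Fenchel-Young equality: phi(x) + phi*(grad phi(x)) = <x, grad phi(x)>
lhs = pot.fn(x) + pot.conjugate(pot.grad(x))
rhs = np.sum(x * pot.grad(x))
```

Both checks hold for any strictly positive `x`; the mirror descent step and the divergence documented below are built from exactly these maps.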

MD_step(x, grad, *args, gamma=1.0, **kwargs)

Performs a mirror descent step, returning the updated point \(x^+ = \nabla \phi^*(\nabla \phi(x) - \gamma \nabla f(x))\), where \(f\) is the function being minimized.

Parameters:
  • x (torch.Tensor) – Variable \(x\) at which the step is performed.

  • grad (torch.Tensor) – Gradient \(\nabla f(x)\) of the function being minimized, evaluated at \(x\).

  • gamma (float) – Step size \(\gamma\).

Returns:

(torch.Tensor) Updated variable \(x\).
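The mirror descent update can be sketched as a small stand-alone function: map \(x\) to the dual space with \(\nabla \phi\), take a gradient step there, and map back with \(\nabla \phi^*\). This is a NumPy illustration under the assumption that both maps are passed in explicitly; the function `md_step` and its arguments are hypothetical and do not reproduce deepinv's `MD_step` signature, which uses the potential stored on the object.

```python
import numpy as np


def md_step(x, grad, grad_phi, grad_phi_conj, gamma=1.0):
    """Hypothetical sketch: x+ = grad phi*(grad phi(x) - gamma * grad)."""
    # Mirror x into the dual space, step along -grad, then map back to the primal space.
    return grad_phi_conj(grad_phi(x) - gamma * grad)


x = np.array([1.0, 2.0, 0.5])
g = np.array([0.3, -0.1, 0.2])  # an arbitrary gradient of f at x, for illustration

# Negative entropy phi(x) = sum x log x: grad phi = log x + 1, grad phi* = exp(y - 1).
# The step becomes the multiplicative update x * exp(-gamma * g), which stays positive.
x_ent = md_step(x, g, lambda u: np.log(u) + 1.0, lambda v: np.exp(v - 1.0), gamma=0.5)

# Euclidean phi(x) = 0.5 * ||x||^2: both maps are the identity,
# so mirror descent reduces to plain gradient descent x - gamma * g.
x_euc = md_step(x, g, lambda u: u, lambda v: v, gamma=0.5)
```

The Euclidean case is a useful sanity check: any implementation of the step should reproduce ordinary gradient descent when \(\phi = \tfrac{1}{2}\|\cdot\|^2\).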

div(x, y, *args, **kwargs)

Computes the Bregman divergence \(D_\phi(x,y)\) with Bregman potential \(\phi\).

Parameters:
  • x (torch.Tensor) – Left variable \(x\) at which the divergence is computed.

  • y (torch.Tensor) – Right variable \(y\) at which the divergence is computed.

Returns:

(torch.Tensor) Divergence \(D_\phi(x,y) = \phi(x) - \phi(y) - \langle \nabla \phi(y), x-y \rangle\).
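For intuition, the formula \(D_\phi(x,y) = \phi(x) - \phi(y) - \langle \nabla \phi(y), x-y \rangle\) can be evaluated directly for two standard potentials. The NumPy sketch below uses a hypothetical helper `bregman_div` (not part of deepinv). With \(\phi = \tfrac{1}{2}\|\cdot\|^2\) it gives \(\tfrac{1}{2}\|x-y\|^2\); with the negative entropy \(\phi(x) = \sum_i x_i \log x_i\) it gives the generalized Kullback-Leibler divergence \(\sum_i x_i \log(x_i/y_i) - x_i + y_i\).

```python
import numpy as np


def bregman_div(x, y, phi, grad_phi):
    """Hypothetical sketch: D_phi(x, y) = phi(x) - phi(y) - <grad phi(y), x - y>."""
    return phi(x) - phi(y) - np.sum(grad_phi(y) * (x - y))


x = np.array([0.2, 0.3, 0.5])
y = np.array([0.4, 0.4, 0.2])

# Euclidean potential: D_phi(x, y) = 0.5 * ||x - y||^2
d_euc = bregman_div(x, y, lambda u: 0.5 * np.sum(u**2), lambda u: u)

# Negative entropy: D_phi(x, y) = sum x log(x / y) - x + y (generalized KL)
neg_ent = lambda u: np.sum(u * np.log(u))
grad_neg_ent = lambda u: np.log(u) + 1.0
d_kl = bregman_div(x, y, neg_ent, grad_neg_ent)
d_kl_rev = bregman_div(y, x, neg_ent, grad_neg_ent)
```

Unlike a metric, \(D_\phi\) is generally not symmetric (`d_kl` and `d_kl_rev` differ here), which is why the documentation distinguishes the left variable \(x\) from the right variable \(y\).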