BregmanL2
- class deepinv.optim.bregman.BregmanL2
Bases: Bregman
Module for the L2 norm as Bregman potential \(\phi(x) = \frac{1}{2} \|x\|_2^2\). The corresponding Bregman divergence is the squared Euclidean distance \(D(x,y) = \frac{1}{2} \|x-y\|_2^2\).
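The sketch below checks the definition numerically. It computes a Bregman divergence for an arbitrary differentiable potential using `torch.autograd` and compares the result with \(\frac{1}{2}\|x-y\|_2^2\). It does not call deepinv. The helpers `l2_potential` and `bregman_divergence` are hypothetical names. The potential here is summed over all tensor entries, which is an assumption: the library may reduce per batch element instead.

```python
import torch


def l2_potential(x: torch.Tensor) -> torch.Tensor:
    # phi(x) = 1/2 ||x||_2^2, summed over every entry of the tensor (assumption).
    return 0.5 * (x ** 2).sum()


def bregman_divergence(phi, x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
    # D(x, y) = phi(x) - phi(y) - <grad phi(y), x - y>, with grad phi(y) from autograd.
    y = y.detach().requires_grad_(True)
    phi_y = phi(y)
    (grad_y,) = torch.autograd.grad(phi_y, y)
    y = y.detach()
    return phi(x) - phi_y.detach() - torch.sum(grad_y * (x - y))


torch.manual_seed(0)
x = torch.randn(3, 4, dtype=torch.float64)
y = torch.randn(3, 4, dtype=torch.float64)

d = bregman_divergence(l2_potential, x, y)
expected = 0.5 * ((x - y) ** 2).sum()
print(torch.allclose(d, expected))  # True: the divergence is half the squared distance
```

Replacing `l2_potential` with another differentiable potential gives that potential's divergence without a hand-derived gradient.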
- conjugate(x)
Computes the convex conjugate potential \(\phi^*(x) = \frac{1}{2} \|x\|_2^2\). This potential is self-conjugate, so \(\phi^*\) coincides with \(\phi\).
- Parameters:
x (torch.Tensor) – Variable \(x\) at which the conjugate is computed.
- Returns:
(torch.Tensor) conjugate potential \(\phi^*(x)\).
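The closed form follows from the definition of the convex conjugate. The objective is concave in \(z\), so setting its gradient to zero gives the maximizer directly:

```latex
\phi^*(x) = \sup_{z} \Big( \langle z, x \rangle - \tfrac{1}{2}\|z\|_2^2 \Big),
\qquad x - z = 0 \;\Rightarrow\; z^\star = x,
\qquad \phi^*(x) = \|x\|_2^2 - \tfrac{1}{2}\|x\|_2^2 = \tfrac{1}{2}\|x\|_2^2 .
```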
- div(x, y, *args, **kwargs)
Computes the Bregman divergence associated with the potential \(\phi\). For this potential it reduces to half the squared Euclidean distance \(\frac{1}{2} \|x-y\|_2^2\).
- Parameters:
x (torch.Tensor) – Variable \(x\) at which the divergence is computed.
y (torch.Tensor) – Variable \(y\) at which the divergence is computed.
- Returns:
(torch.Tensor) divergence \(D(x,y) = \phi(x) - \phi(y) - \langle \nabla \phi(y), x-y \rangle\).
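Substituting \(\nabla \phi(y) = y\) into the general definition shows why the divergence reduces to half the squared Euclidean distance:

```latex
\begin{aligned}
D(x, y) &= \phi(x) - \phi(y) - \langle \nabla \phi(y), x - y \rangle \\
        &= \tfrac{1}{2}\|x\|_2^2 - \tfrac{1}{2}\|y\|_2^2 - \langle y, x - y \rangle \\
        &= \tfrac{1}{2}\|x\|_2^2 - \langle x, y \rangle + \tfrac{1}{2}\|y\|_2^2 \\
        &= \tfrac{1}{2}\|x - y\|_2^2 .
\end{aligned}
```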
- fn(x)
Computes the L2 norm potential \(\phi(x) = \frac{1}{2} \|x\|_2^2\).
- Parameters:
x (torch.Tensor) – Variable \(x\) at which the potential is computed.
- Returns:
(torch.Tensor) potential \(\phi(x)\).
- grad(x, *args, **kwargs)
Calculates the gradient of the L2 norm potential, \(\nabla \phi(x) = x\).
- Parameters:
x (torch.Tensor) – Variable \(x\) at which the gradient is computed.
- Returns:
(torch.Tensor) gradient \(\nabla \phi(x)\), evaluated at \(x\).
- grad_conj(x, *args, **kwargs)
Calculates the gradient of the conjugate potential, \(\nabla \phi^*(x) = x\).
- Parameters:
x (torch.Tensor) – Variable \(x\) at which the gradient is computed.
- Returns:
(torch.Tensor) gradient \(\nabla \phi^*(x)\), evaluated at \(x\).
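Here \(\nabla \phi\) and \(\nabla \phi^*\) are both the identity, so they are mutual inverses. One common use of such a pair is a mirror-descent step, and with this potential that step reduces to a plain gradient step. The sketch below assumes that usage and is written in plain torch. `grad_phi`, `grad_phi_conj`, `mirror_step` and the quadratic objective are hypothetical illustrations, not deepinv calls.

```python
import torch


def grad_phi(x: torch.Tensor) -> torch.Tensor:
    # Gradient of phi(x) = 1/2 ||x||_2^2 is the identity map.
    return x


def grad_phi_conj(u: torch.Tensor) -> torch.Tensor:
    # Gradient of phi*(u) = 1/2 ||u||_2^2 is also the identity map.
    return u


def mirror_step(x: torch.Tensor, grad_f, step_size: float) -> torch.Tensor:
    # Map to the dual space with grad_phi, step along -grad_f, map back with grad_phi_conj.
    return grad_phi_conj(grad_phi(x) - step_size * grad_f(x))


# Hypothetical objective f(x) = 1/2 ||x - b||_2^2, whose gradient is x - b.
b = torch.tensor([1.0, -2.0, 3.0])


def grad_f(x: torch.Tensor) -> torch.Tensor:
    return x - b


x0 = torch.zeros(3)
x1 = mirror_step(x0, grad_f, step_size=0.5)

# Equals the gradient step x0 - 0.5 * grad_f(x0), i.e. the values [0.5, -1.0, 1.5].
print(x1)
```

With a non-Euclidean potential, `grad_phi` and `grad_phi_conj` would no longer be the identity, and the same step would become a genuine mirror-descent update.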