JacobianSpectralNorm
- class deepinv.loss.JacobianSpectralNorm(max_iter=10, tol=1e-3, eval_mode=False, verbose=False, reduction='max', reduced_batchsize=None)
Bases: Loss
Computes the spectral norm of the Jacobian.
Given a function \(f:\mathbb{R}^n\to\mathbb{R}^n\), this module computes the spectral norm of the Jacobian of \(f\) at \(x\), i.e.
\[\left\|\frac{df}{du}(x)\right\|_2.\]
This spectral norm is computed with a power method that relies on Jacobian-vector products, as proposed in https://arxiv.org/abs/2012.13247v2.
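To make the power method concrete, here is a minimal sketch for a single, unbatched input. It is not the library's implementation: the helper name `spectral_norm_power_method` and its defaults are hypothetical, and it relies on `torch.autograd.functional.jvp` and `vjp` to apply \(J\) and \(J^\top\) without ever forming the Jacobian.

```python
import torch
from torch.autograd.functional import jvp, vjp


def spectral_norm_power_method(f, x, max_iter=200, tol=1e-8):
    """Estimate the spectral norm of the Jacobian of f at x (hypothetical helper, no batching)."""
    u = torch.randn_like(x)
    u = u / u.norm()  # start from a random unit vector
    sigma_sq_old = None
    for _ in range(max_iter):
        _, ju = jvp(f, x, u)      # J u: Jacobian-vector product
        _, jtju = vjp(f, x, ju)   # J^T (J u): vector-Jacobian product
        # Rayleigh quotient of J^T J at u (u has unit norm)
        sigma_sq = torch.dot(u.flatten(), jtju.flatten())
        if sigma_sq_old is not None and abs(sigma_sq - sigma_sq_old) < tol:
            break
        sigma_sq_old = sigma_sq
        u = jtju / jtju.norm()    # renormalise for the next iteration
    return sigma_sq.sqrt()        # largest singular value of J


torch.manual_seed(0)
A = torch.diag(torch.arange(1.0, 6.0))  # linear map with singular values 1, ..., 5
x = torch.randn(5)
estimate = spectral_norm_power_method(lambda z: A @ z, x)
```

For this linear map the Jacobian is `A` itself, so the estimate should approach 5, the largest singular value.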
Note
This implementation assumes that the input \(x\) is batched with shape (B, ...), where B is the batch size.
- Parameters:
max_iter (int) – maximum number of iterations of the power method.
tol (float) – tolerance for the convergence of the power method.
eval_mode (bool) – set to False if one does not want to backpropagate through the spectral norm (default), set to True otherwise.
verbose (bool) – whether to print computation details.
reduction (str) – reduction over the batch dimension, applied after all spectral norms have been computed. One of "mean", "sum", "max". If None, a vector of length batch_size is returned. Defaults to "max".
reduced_batchsize (int) – if not None, the batch size is reduced to this value for the computation of the spectral norm. This can reduce memory usage and computation time when the batch size is large.
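The effect of the reduction option can be illustrated on a vector of per-sample spectral norms. The sketch below only mirrors the behaviour described in the parameter list. The helper `reduce_norms` is hypothetical and the input values are made up for illustration.

```python
import torch


def reduce_norms(norms, reduction="max"):
    """Combine per-sample spectral norms of shape (B,) (hypothetical helper)."""
    if reduction is None:
        return norms  # keep one value per batch element
    if reduction == "mean":
        return norms.mean()
    if reduction == "sum":
        return norms.sum()
    if reduction == "max":
        return norms.max()
    raise ValueError(f"unknown reduction: {reduction!r}")


per_sample = torch.tensor([1.5, 0.5, 4.0])  # made-up per-sample norms
worst_case = reduce_norms(per_sample, "max")    # largest norm in the batch
average = reduce_norms(per_sample, "mean")      # mean over the batch
unreduced = reduce_norms(per_sample, None)      # vector of length B
```

The "max" reduction reports the worst-case norm over the batch, while "mean" and "sum" average or accumulate over all samples.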
- Examples:
>>> import torch
>>> from deepinv.loss.regularisers import JacobianSpectralNorm
>>> _ = torch.manual_seed(0)
>>>
>>> reg_l2 = JacobianSpectralNorm(max_iter=100, tol=1e-5, eval_mode=False, verbose=True)
>>> A = torch.diag(torch.Tensor(range(1, 51))).unsqueeze(0)  # creates a diagonal matrix with largest eigenvalue = 50
>>> x = torch.randn((1, A.shape[1])).unsqueeze(0).requires_grad_()
>>> out = x @ A
>>> regval = reg_l2(out, x)
>>> print(regval)  # returns approx 50
tensor(49.9999)
- forward(y, x, **kwargs)
Computes the spectral norm of the Jacobian of \(f\) in \(x\).
Warning
The input \(x\) must have requires_grad=True before evaluating \(f\).
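The ordering requirement can be checked with plain PyTorch. The following sketch uses a small diagonal matrix as a stand-in for \(f\) (an assumption for illustration only) and shows that the output carries an autograd graph only when gradients are enabled on the input before \(f\) is evaluated.

```python
import torch

A = torch.diag(torch.tensor([1.0, 2.0, 3.0]))  # stand-in for f: x -> x @ A

# Wrong order: y is computed before x tracks gradients, so y has no graph.
x_late = torch.randn(1, 3)
y_late = x_late @ A
x_late.requires_grad_()
late_has_graph = y_late.requires_grad  # False

# Right order: enable gradients on x first, then evaluate f.
x = torch.randn(1, 3).requires_grad_()
y = x @ A
has_graph = y.requires_grad  # True

# Only the second pair supports differentiating y with respect to x.
(vjp_value,) = torch.autograd.grad(y, x, torch.ones_like(y))
```

In the first case no graph connects `y_late` to `x_late`, so gradients of the output with respect to the input cannot be computed.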
- Parameters:
y (torch.Tensor) – output of the function \(f\) at \(x\), of dimension (B, ...).
x (torch.Tensor) – input of the function \(f\), of dimension (B, ...).
If x has multiple dimensions, the first one is assumed to be the batch dimension.