adjoint_function
- deepinv.physics.adjoint_function(A, input_size, device='cpu', dtype=torch.float32)
Provides the adjoint function of a linear operator \(A\), i.e., \(A^{\top}\).
The generated function can be called simply as A_adjoint(y), for example:

    >>> import torch
    >>> from deepinv.physics.forward import adjoint_function
    >>> A = lambda x: torch.roll(x, shifts=(1, 1), dims=(2, 3))  # shift image by one pixel
    >>> x = torch.randn((4, 1, 5, 5))
    >>> y = A(x)
    >>> A_adjoint = adjoint_function(A, (4, 1, 5, 5))
    >>> torch.allclose(A_adjoint(y), x)  # the shift is orthogonal, so A^T(A(x)) = x
    True
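One way to obtain the adjoint of a linear operator without deriving it by hand is reverse-mode automatic differentiation: for a linear \(A\), the Jacobian at any point is \(A\) itself, so the vector-Jacobian product with cotangent \(y\) equals \(A^{\top} y\). The sketch below assumes this approach and uses `torch.func.vjp` (PyTorch 2.0 or later). The name `adjoint_via_vjp` is hypothetical, and this is not necessarily how deepinv implements `adjoint_function`.

```python
import torch
from torch.func import vjp


def adjoint_via_vjp(A, input_size, device="cpu", dtype=torch.float32):
    # Hypothetical helper (not part of the deepinv API): builds A^T for a linear A.
    # Because A is linear, its Jacobian does not depend on the linearization
    # point, so a zero tensor of the input shape is enough.
    x0 = torch.zeros(input_size, device=device, dtype=dtype)
    _, vjp_fn = vjp(A, x0)  # vjp returns (A(x0), vjp_fn)

    def A_adjoint(y):
        # vjp_fn returns one gradient per primal input; A has a single input.
        (x_adj,) = vjp_fn(y)
        return x_adj

    return A_adjoint


# Usage: the one-pixel circular shift from the example above.
A = lambda x: torch.roll(x, shifts=(1, 1), dims=(2, 3))
A_adjoint = adjoint_via_vjp(A, (4, 1, 5, 5))
x = torch.randn(4, 1, 5, 5)
print(torch.allclose(A_adjoint(A(x)), x))
```

The cotangent passed to `A_adjoint` must have the same shape as the output of \(A\) evaluated on a tensor of shape `input_size`.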
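- Parameters:
  - A (Callable) – linear operator \(A\).
  - input_size (tuple) – shape of the input of \(A\), e.g. (B, C, H, W).
  - device (str) – device on which the adjoint is computed.
  - dtype (torch.dtype) – data type of the input of \(A\).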
- Returns:
(Callable) function that computes the adjoint of \(A\).
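A returned adjoint can be sanity-checked with the dot-product test: for a linear \(A\) and its adjoint, \(\langle A x, y\rangle = \langle x, A^{\top} y\rangle\) for all \(x\) and \(y\). The helper `dot_product_test` below is a hypothetical sketch, not a deepinv function. To keep it runnable without deepinv, it checks the circular shift against its hand-written adjoint (a shift in the opposite direction); the callable returned by `adjoint_function` can be checked the same way, provided it was created with the same `dtype` as the test inputs (float64 here).

```python
import torch


def dot_product_test(A, A_adjoint, input_size, dtype=torch.float64, rtol=1e-6):
    # Hypothetical helper: compares <A x, y> with <x, A^T y> on random x and y.
    x = torch.randn(input_size, dtype=dtype)
    Ax = A(x)
    y = torch.randn_like(Ax)            # y lives in the output space of A
    lhs = torch.sum(Ax * y)             # <A x, y>
    rhs = torch.sum(x * A_adjoint(y))   # <x, A^T y>
    return torch.allclose(lhs, rhs, rtol=rtol)


# Usage: a one-pixel circular shift and its exact adjoint.
A = lambda x: torch.roll(x, shifts=(1, 1), dims=(2, 3))
A_adjoint = lambda y: torch.roll(y, shifts=(-1, -1), dims=(2, 3))
print(dot_product_test(A, A_adjoint, (4, 1, 5, 5)))
```

Random inputs make the test probabilistic, but a wrong adjoint almost never passes it, which is why it is a common check for hand-written or automatically generated adjoints.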