ICNN
- class deepinv.models.ICNN(in_channels=3, dim_hidden=256, beta_softplus=100, alpha=0.0, pos_weights=False, rectifier_fn=ReLU(), device='cpu')
Bases: torch.nn.Module
Input Convex Neural Network.
Mostly based on the implementation from the paper What’s in a Prior? Learned Proximal Networks for Inverse Problems, and on the implementation in the OTT library.
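The defining property is that the scalar output f(x) is convex in the input x. A common recipe for guaranteeing this is to keep the weights that act on the hidden state nonnegative, use a convex and nondecreasing activation such as softplus, and leave the weights that act directly on x unconstrained. The sketch below is a hypothetical fully connected `TinyICNN` written only for illustration; it is not the convolutional architecture of `deepinv.models.ICNN`. Its `beta_softplus` and `alpha` arguments merely mirror the documented names, and adding `alpha * ||x||^2` to obtain strong convexity is an assumption about how such a parameter can be used.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


class TinyICNN(nn.Module):
    """Hypothetical fully connected ICNN (illustration only).

    z_1     = act(W_0 x + b_0)
    z_{k+1} = act(relu(U_k) z_k + W_k x + b_k)
    f(x)    = relu(u)^T z_K + w^T x + c + alpha * ||x||^2
    """

    def __init__(self, dim_in, dim_hidden=32, n_layers=3, beta_softplus=100.0, alpha=0.0):
        super().__init__()
        # Softplus is convex and nondecreasing, so composing it with convex inputs stays convex.
        self.act = nn.Softplus(beta=beta_softplus)
        self.alpha = alpha
        # Input ("skip") layers acting on x: weights may have any sign.
        self.W = nn.ModuleList([nn.Linear(dim_in, dim_hidden) for _ in range(n_layers)])
        # Hidden-to-hidden layers: weights are rectified to be nonnegative in forward().
        self.U = nn.ModuleList(
            [nn.Linear(dim_hidden, dim_hidden, bias=False) for _ in range(n_layers - 1)]
        )
        self.out_z = nn.Linear(dim_hidden, 1, bias=False)  # also rectified in forward()
        self.out_x = nn.Linear(dim_in, 1)

    def forward(self, x):
        z = self.act(self.W[0](x))
        for U, W in zip(self.U, self.W[1:]):
            # Nonnegative combination of convex functions plus an affine term stays convex.
            z = self.act(F.linear(z, F.relu(U.weight)) + W(x))
        out = F.linear(z, F.relu(self.out_z.weight)) + self.out_x(x)
        # Quadratic term: for alpha > 0 the output is (2 * alpha)-strongly convex.
        return out + self.alpha * x.pow(2).sum(dim=1, keepdim=True)


# Usage: check midpoint convexity f((a + b) / 2) <= (f(a) + f(b)) / 2 on random points.
torch.manual_seed(0)
model = TinyICNN(dim_in=4, alpha=0.1)
a, b = torch.randn(8, 4), torch.randn(8, 4)
with torch.no_grad():
    gap = 0.5 * (model(a) + model(b)) - model(0.5 * (a + b))
print(gap.min().item())  # nonnegative up to floating-point rounding
```

Rectifying the weights with `F.relu` inside `forward()` keeps the raw parameters unconstrained for the optimizer; clamping them in place after each update is a common alternative.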
- Parameters:
in_channels (int) – Number of input channels.
dim_hidden (int) – Number of hidden units.
beta_softplus (float) – Beta parameter for the softplus activation function.
alpha (float) – Strong convexity parameter.
pos_weights (bool) – Whether to force positive weights in the forward pass.
rectifier_fn (torch.nn.Module) – Activation function used to force the weights to be positive when pos_weights is True.
device (str) – Device to use for the model.
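Assuming, as the parameter descriptions suggest, that `beta_softplus` is the `beta` argument of a `torch.nn.Softplus` activation and that `rectifier_fn` is applied elementwise to weight tensors, the two defaults can be examined in isolation with plain PyTorch modules (this snippet does not call deepinv):

```python
import math

import torch
import torch.nn as nn

beta = 100.0  # same value as the documented default of beta_softplus
softplus = nn.Softplus(beta=beta)
relu = nn.ReLU()  # same module as the documented default of rectifier_fn

# Softplus with parameter beta is log(1 + exp(beta * x)) / beta: a smooth, convex,
# nondecreasing upper bound on ReLU whose largest gap, at x = 0, is log(2) / beta.
x = torch.linspace(-1.0, 1.0, steps=2001)
gap = softplus(x) - relu(x)
print(gap.max().item(), math.log(2) / beta)

# Applied to a weight tensor, ReLU zeroes the negative entries, so the
# rectified weights are elementwise nonnegative.
torch.manual_seed(0)
w = torch.randn(4, 4)
w_rect = relu(w)
print(bool((w_rect >= 0).all()))
```

A larger beta_softplus therefore brings the activation closer to ReLU while keeping it smooth.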
- forward(x)
Define the computation performed at every call.
Should be overridden by all subclasses.
Note
Although the recipe for the forward pass needs to be defined within this function, one should call the Module instance afterwards instead of this, since the former takes care of running the registered hooks while the latter silently ignores them.
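This note applies to every torch.nn.Module, an ICNN instance included. In the sketch below a small `torch.nn.Linear` layer stands in for the model and `record_call` is a made-up hook name; calling the instance runs the hook, while calling `forward()` directly does not:

```python
import torch
import torch.nn as nn

calls = []


def record_call(module, inputs, output):
    # Forward hooks receive the module, its positional inputs, and its output.
    calls.append(type(module).__name__)


layer = nn.Linear(3, 1)
handle = layer.register_forward_hook(record_call)

x = torch.randn(2, 3)
layer(x)          # __call__ runs forward() and then the registered hook
layer.forward(x)  # calling forward() directly bypasses the hook

print(calls)  # ['Linear']
handle.remove()
```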