Note
Go to the end to download the full example code.
Saving and loading models
Models can be saved and loaded in the same way as in PyTorch. In this example, we show how to define, load, and save a model. For the purpose of the example, we choose an unfolded Chambolle-Pock algorithm as the model. The architecture of the model and its training are described in the constrained unfolded demo.
import importlib.util
from pathlib import Path
import torch
import deepinv as dinv
from deepinv.optim.data_fidelity import IndicatorL2
from deepinv.optim.prior import PnP
from deepinv.unfolded import unfolded_builder
from deepinv.models.utils import get_weights_url
Set up paths for data loading and results.
Define a forward operator
We define a simple inpainting operator with 50% missing pixels.
# Problem dimensions: 3-channel images of size 32 x 32.
n_channels, img_size = 3, 32

# Use the GPU picked by deepinv's helper when CUDA is available, otherwise the CPU.
if torch.cuda.is_available():
    device = dinv.utils.get_freer_gpu()
else:
    device = "cpu"

# Forward operator: inpainting with mask=0.5 (the 50% missing-pixel setting
# of this example).
physics = dinv.physics.Inpainting(
    (n_channels, img_size, img_size),
    mask=0.5,
    device=device,
)
Define a model
For the purpose of this example, we define a rather complex model that consists of an unfolded Chambolle-Pock algorithm.
# Data-fidelity term: indicator of an L2 ball, here with radius 0.
data_fidelity = IndicatorL2(radius=0.0)

# Trainable denoising prior: a wavelet denoiser (the soft-threshold in a
# wavelet basis) wrapped in a PnP prior.
# Passing a list of max_iter priors trains a distinct weight for each CP
# iteration; pass a single prior instead to share it across iterations.
level = 3
max_iter = 20  # Number of unrolled iterations
prior = [
    PnP(
        denoiser=dinv.models.WaveletDenoiser(
            wv="db8",
            level=level,
            device=device,
        )
    )
    for _ in range(max_iter)
]
# Unrolled optimization algorithm parameters.
# Every per-iteration parameter is given as a list with one initial value per
# unrolled iteration.
lamb = [1.0] * max_iter  # initialization of the regularization parameter
stepsize = [1.0] * max_iter  # initialization of the stepsizes
# Build a separate tensor for every iteration. `[tensor] * max_iter` would
# repeat one and the same tensor object max_iter times, so the per-iteration
# entries would alias each other instead of being independent.
sigma_denoiser = [0.01 * torch.ones(level, 3) for _ in range(max_iter)]
sigma = 1.0  # stepsize for Chambolle-Pock
params_algo = {
    "stepsize": stepsize,
    "g_param": sigma_denoiser,
    "lambda": lamb,
    "sigma": sigma,
    "K": physics.A,
    "K_adjoint": physics.A_adjoint,
}
# Names of the entries of 'params_algo' that are trainable; the other entries
# (e.g. "lambda") are not listed here.
trainable_params = [
    "g_param",
    "stepsize",
]
# The CP algorithm iterates on more than 2 variables, so the default
# initialization is replaced by this custom one.
def custom_init_CP(y, physics):
    """Build the initial iterate of the unfolded Chambolle-Pock algorithm.

    :param y: measurements.
    :param physics: object exposing an ``A_adjoint`` method that is applied to ``y``.
    :return: dictionary whose ``"est"`` entry is the 3-tuple
        ``(A_adjoint(y), A_adjoint(y), y)``.
    """
    backprojection = physics.A_adjoint(y)
    # The first two variables share the same back-projected estimate; the
    # third one starts from the measurements themselves.
    return {"est": (backprojection, backprojection, y)}
# Build the unfolded, trainable Chambolle-Pock ("CP") model from the pieces
# defined above: data fidelity, per-iteration priors, algorithm parameters
# and the custom initialization.
model = unfolded_builder(
    "CP",
    params_algo=params_algo,
    trainable_params=trainable_params,
    data_fidelity=data_fidelity,
    prior=prior,
    max_iter=max_iter,
    custom_init=custom_init_CP,
    g_first=False,
)
# Save the model.
# Checkpoint directory used by this example (relative to the working directory).
CKPT_DIR = Path(".") / "ckpts"
ckpt_path = CKPT_DIR / "inpainting/model_nontrained.pth"
# torch.save does not create missing folders, so make sure the parent
# directory of the checkpoint exists before writing.
ckpt_path.parent.mkdir(parents=True, exist_ok=True)
torch.save(model.state_dict(), ckpt_path)
# Prior of the model that will receive the loaded weights: again a wavelet
# denoiser (the soft-threshold in a wavelet basis) wrapped in a PnP prior.
# Passing a list of max_iter priors trains a distinct weight for each CP
# iteration; pass a single prior instead to share it across iterations.
prior_new = [
    PnP(
        denoiser=dinv.models.WaveletDenoiser(
            wv="db8",
            level=level,
            device=device,
        )
    )
    for _ in range(max_iter)
]
# Unrolled optimization algorithm parameters for the second model.
# Every per-iteration parameter is given as a list with one initial value per
# unrolled iteration.
lamb = [1.0] * max_iter  # initialization of the regularization parameter
stepsize = [1.0] * max_iter  # initialization of the stepsizes
# Build a separate tensor for every iteration. `[tensor] * max_iter` would
# repeat one and the same tensor object max_iter times, so the per-iteration
# entries would alias each other instead of being independent.
sigma_denoiser = [0.01 * torch.ones(level, 3) for _ in range(max_iter)]
sigma = 1.0  # stepsize for Chambolle-Pock
params_algo_new = {
    "stepsize": stepsize,
    "g_param": sigma_denoiser,
    "lambda": lamb,
    "sigma": sigma,
    "K": physics.A,
    "K_adjoint": physics.A_adjoint,
}
# Build a second unfolded CP model with the same configuration as the first
# one; its weights are still at their initial values at this point.
model_new = unfolded_builder(
    "CP",
    params_algo=params_algo_new,
    trainable_params=trainable_params,
    data_fidelity=data_fidelity,
    prior=prior_new,
    max_iter=max_iter,
    custom_init=custom_init_CP,
    g_first=False,
)

# Show the first trainable denoiser parameter before any checkpoint is loaded.
print(
    "Parameter model_new.params_algo.g_param[0] at init: \n",
    model_new.params_algo.g_param[0],
)
# Choose the checkpoint file: the "_ptwt" variant when the optional `ptwt`
# package can be found, the default file otherwise.
if importlib.util.find_spec("ptwt"):
    file_name = "demo_unfolded_CP_ptwt.pth"
else:
    file_name = "demo_unfolded_CP.pth"

# Fetch the state_dict from the weights URL. The map_location callable returns
# each storage unchanged, i.e. tensors are not moved to another device.
url = get_weights_url(model_name="demo", file_name=file_name)
ckpt_state_dict = torch.hub.load_state_dict_from_url(
    url,
    file_name=file_name,
    map_location=lambda storage, loc: storage,
)

# Copy the checkpoint weights into the freshly built model and show the same
# parameter again to confirm that it changed.
model_new.load_state_dict(ckpt_state_dict)
print(
    "Parameter model_new.params_algo.g_param[0] after loading: \n",
    model_new.params_algo.g_param[0],
)
Parameter model_new.params_algo.g_param[0] at init:
Parameter containing:
tensor([[0.0100, 0.0100, 0.0100],
[0.0100, 0.0100, 0.0100],
[0.0100, 0.0100, 0.0100]], requires_grad=True)
Downloading: "https://huggingface.co/deepinv/demo/resolve/main/demo_unfolded_CP_ptwt.pth?download=true" to /home/runner/.cache/torch/hub/checkpoints/demo_unfolded_CP_ptwt.pth
0%| | 0.00/20.3k [00:00<?, ?B/s]
100%|██████████| 20.3k/20.3k [00:00<00:00, 44.7MB/s]
Parameter model_new.params_algo.g_param[0] after loading:
Parameter containing:
tensor([[ 0.0667, 0.0671, 0.0872],
[ 0.0306, 0.0271, 0.0588],
[-0.0146, -0.0157, 0.0144]], requires_grad=True)
Total running time of the script: (0 minutes 0.218 seconds)