Trainer
Training a reconstruction model can be done using the deepinv.Trainer
class, which can be easily customized to fit your needs. A trainer can be used
both for training (deepinv.Trainer.train()) and testing (deepinv.Trainer.test())
a model, and can also save and load models.
See Training a reconstruction network for a simple example of how to use the trainer.
The class provides a flexible training loop that can be customized by the user. In particular, the user can
rewrite the deepinv.Trainer.compute_loss()
method to define their custom training step without having
to write all the training code from scratch:
class CustomTrainer(Trainer):
    def compute_loss(self, physics, x, y, train=True, epoch: int = None):
        logs = {}

        self.optimizer.zero_grad()  # Zero the gradients

        # Evaluate the reconstruction network
        x_net = self.model_inference(y=y, physics=physics)

        # Compute the total loss as the sum of all losses
        loss_total = 0
        for k, l in enumerate(self.losses):
            loss = l(x=x, x_net=x_net, y=y, physics=physics, model=self.model, epoch=epoch)
            loss_total += loss.mean()

        # Log the running average of the total loss
        metric = self.logs_total_loss_train if train else self.logs_total_loss_eval
        metric.update(loss_total.item())
        logs["TotalLoss"] = metric.avg

        if train:
            loss_total.backward()  # Backpropagate the total loss
            self.optimizer.step()  # Optimizer step

        return x_net, logs
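The metric objects used above (self.logs_total_loss_train and self.logs_total_loss_eval) expose an update()/avg running-average interface. As a self-contained illustration of that pattern, here is a minimal, hypothetical meter (a sketch, not the actual deepinv class):

```python
# Hypothetical running-average meter illustrating the update()/avg
# interface used by the metric objects in compute_loss() above.
# This is a sketch, not the deepinv implementation.
class AverageMeter:
    def __init__(self):
        self.sum = 0.0   # running sum of all values seen
        self.count = 0   # number of values seen
        self.avg = 0.0   # current running average

    def update(self, value, n=1):
        # Accumulate `value` (optionally weighted by batch size `n`)
        self.sum += value * n
        self.count += n
        self.avg = self.sum / self.count


meter = AverageMeter()
for loss in [2.0, 4.0, 6.0]:  # e.g. per-batch total losses
    meter.update(loss)
print(meter.avg)  # running average over the three batches -> 4.0
```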
If the user wants to change the way the metrics are computed, they can rewrite the
deepinv.Trainer.compute_metrics()
method.
The user can also change the way samples are generated by overriding
deepinv.Trainer.get_samples_online()
when measurements are simulated from a ground truth returned by the dataloader, or deepinv.Trainer.get_samples_offline()
when both the ground truth and measurements (and, optionally, physics generator params) are returned by the dataloader.
For instance, in MRI, the dataloader often returns both the measurements and the mask associated with the measurements.
In this case, to update the deepinv.physics.Physics
parameters accordingly, a potential implementation would be:
class CustomTrainer(Trainer):
    def get_samples_offline(self, iterators, g):
        # Suppose your dataset returns per-sample masks, e.g. in MRI
        x, y, mask = next(iterators[g])

        # Suppose physics has class params, such as DecomposablePhysics or MRI
        physics = self.physics[g]

        # Update physics parameters deterministically (i.e. not using a random generator)
        physics.update(mask=mask.to(self.device))

        return x.to(self.device), y.to(self.device), physics
Note
When using a dataset that loads data as a 3-tuple, this is assumed to be (x, y, params),
where params
is a dict of parameters, e.g. generated from deepinv.datasets.generate_dataset.
Trainer will automatically load these parameters into the physics at each iteration.
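As an illustration of the expected 3-tuple format, here is a minimal, hypothetical map-style dataset (plain Python lists stand in for tensors; the class name and fields are illustrative, not deepinv API):

```python
# Minimal, hypothetical dataset returning (x, y, params) 3-tuples,
# mimicking the format produced by deepinv.datasets.generate_dataset.
# Plain Python lists stand in for tensors to keep the sketch self-contained.
class TupleDataset:
    def __init__(self, signals, measurements, masks):
        self.signals = signals            # ground-truth images x
        self.measurements = measurements  # measurements y
        self.masks = masks                # per-sample physics parameters

    def __len__(self):
        return len(self.signals)

    def __getitem__(self, i):
        # params is a dict of physics parameters, keyed by parameter name
        params = {"mask": self.masks[i]}
        return self.signals[i], self.measurements[i], params


dataset = TupleDataset(signals=[[1.0, 2.0]], measurements=[[0.5, 0.0]], masks=[[1, 0]])
x, y, params = dataset[0]
print(params)  # {'mask': [1, 0]}
```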
Warning
When using the trainer for unsupervised training, be careful that each measurement stays constant across epochs.
Generally, it is preferable to train offline by setting online_measurements=False
and generating a dataset using deepinv.datasets.generate_dataset().
If you want to use online measurements and your physics is random (i.e. you are either using a physics_generator
or a noise model),
you must use loop_random_online_physics=True
to reset the randomness every epoch, and a DataLoader
with shuffle=False
so the measurements arrive in the same order every epoch.
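The reasoning behind resetting the randomness and disabling shuffling can be sketched without deepinv: if the random state is reset at the start of every epoch and the samples arrive in the same order, each sample receives the same random measurement in every epoch. A minimal sketch with Python's random module (the function and names here are illustrative, not deepinv API):

```python
import random


def run_epoch(signals, seed):
    # Reset the randomness at the start of each epoch
    # (the role played by loop_random_online_physics=True).
    rng = random.Random(seed)
    measurements = []
    for x in signals:  # fixed sample order, i.e. shuffle=False
        noise = rng.gauss(0.0, 0.1)  # stand-in for a random noise model / physics
        measurements.append(x + noise)
    return measurements


signals = [1.0, 2.0, 3.0]
epoch1 = run_epoch(signals, seed=0)
epoch2 = run_epoch(signals, seed=0)
print(epoch1 == epoch2)  # True: each measurement is constant across epochs
```

If the order were shuffled between epochs, each signal would be paired with a different draw from the generator, and the measurements would no longer be constant.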