Trainer
Training a reconstruction model can be done using the deepinv.Trainer class, which can easily be customized to fit your needs. A trainer can be used both to train (deepinv.Trainer.train()) and to test (deepinv.Trainer.test()) a model, and to save and load models.
See Training a reconstruction network for a simple example of how to use the trainer.
The class provides a flexible training loop that can be customized by the user. In particular, the user can override the deepinv.Trainer.compute_loss() method to define a custom training step without having to write all the training code from scratch:
```python
from deepinv import Trainer


class CustomTrainer(Trainer):
    def compute_loss(self, physics, x, y, train=True, epoch: int = None):
        logs = {}

        self.optimizer.zero_grad()  # Zero the gradients

        # Evaluate the reconstruction network
        x_net = self.model_inference(y=y, physics=physics)

        # Compute and accumulate the losses
        loss_total = 0
        for l in self.losses:
            loss = l(x=x, x_net=x_net, y=y, physics=physics, model=self.model, epoch=epoch)
            loss_total += loss.mean()

        # Track the running average of the total loss (training or evaluation meter)
        metric = self.logs_total_loss_train if train else self.logs_total_loss_eval
        metric.update(loss_total.item())
        logs["TotalLoss"] = metric.avg

        if train:
            loss_total.backward()  # Backpropagate the total loss
            self.optimizer.step()  # Optimizer step

        return x_net, logs
```
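The idea behind this design, a base class that owns the loop and calls an overridable step, can be illustrated without deepinv or PyTorch. The following is a toy, dependency-free sketch, not deepinv code: ToyTrainer, RegularisedTrainer and the helper mse_and_grad are hypothetical, the "model" is a single scalar w with x_net = w * y, and the gradient is written out by hand instead of using autograd. The subclass only redefines compute_loss (adding an L2 penalty on w) and reuses train() unchanged.

```python
# Toy illustration of the "override one step, reuse the loop" pattern.
# All class and method names here are hypothetical, not deepinv's API.


class ToyTrainer:
    """Owns the training loop; subclasses customise compute_loss only."""

    def __init__(self, w=0.0, lr=0.01, epochs=200):
        self.w = w  # single scalar parameter: x_net = w * y
        self.lr = lr
        self.epochs = epochs

    def model_inference(self, y):
        return [self.w * yi for yi in y]

    def mse_and_grad(self, x, x_net, y):
        # Mean squared error and its hand-derived gradient with respect to w
        n = len(x)
        mse = sum((xn - xi) ** 2 for xn, xi in zip(x_net, x)) / n
        grad = sum(2 * (xn - xi) * yi for xn, xi, yi in zip(x_net, x, y)) / n
        return mse, grad

    def compute_loss(self, x, y, train=True):
        # Default step: plain MSE followed by one gradient-descent update
        x_net = self.model_inference(y)
        loss, grad = self.mse_and_grad(x, x_net, y)
        if train:
            self.w -= self.lr * grad
        return x_net, {"TotalLoss": loss}

    def train(self, x, y):
        # The loop never changes; it just calls the (possibly overridden) step
        logs = {}
        for _ in range(self.epochs):
            _, logs = self.compute_loss(x, y, train=True)
        return logs


class RegularisedTrainer(ToyTrainer):
    """Custom step: MSE plus an L2 penalty lam * w**2."""

    def __init__(self, lam=1.0, **kwargs):
        super().__init__(**kwargs)
        self.lam = lam

    def compute_loss(self, x, y, train=True):
        x_net = self.model_inference(y)
        mse, grad = self.mse_and_grad(x, x_net, y)
        loss = mse + self.lam * self.w ** 2
        grad += 2 * self.lam * self.w  # gradient of the penalty term
        if train:
            self.w -= self.lr * grad
        return x_net, {"TotalLoss": loss}


x = [1.0, 2.0, 3.0]  # ground truth
y = [2.0, 4.0, 6.0]  # measurements, y = 2 * x
base = ToyTrainer()
base.train(x, y)
custom = RegularisedTrainer(lam=1.0)
custom.train(x, y)
print(round(base.w, 4), round(custom.w, 4))
```

With y = 2x, the base trainer converges to w = 0.5, while the penalty pulls the custom trainer's weight towards zero (w = 28/59 for these data and lam = 1), even though both share the same train() loop.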
To change the way metrics are computed, the user can override the deepinv.Trainer.compute_metrics() method.
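As a rough illustration of why overriding a single metrics hook is enough, here is a dependency-free sketch in which a hypothetical evaluation loop averages whatever compute_metrics returns over batches. ToyEvaluator and PSNREvaluator are not deepinv classes, signals are plain Python lists, and the PSNR assumes a signal range of [0, 1].

```python
import math

# Hypothetical stand-ins: the loop below calls compute_metrics once per batch.


class ToyEvaluator:
    def compute_metrics(self, x, x_net):
        mse = sum((a - b) ** 2 for a, b in zip(x, x_net)) / len(x)
        return {"MSE": mse}

    def evaluate(self, batches):
        # Average every metric returned by compute_metrics over all batches
        totals = {}
        for x, x_net in batches:
            for name, value in self.compute_metrics(x, x_net).items():
                totals[name] = totals.get(name, 0.0) + value
        return {name: total / len(batches) for name, total in totals.items()}


class PSNREvaluator(ToyEvaluator):
    """Override: keep the parent's metrics and add PSNR (range assumed [0, 1])."""

    def compute_metrics(self, x, x_net):
        metrics = super().compute_metrics(x, x_net)
        mse = metrics["MSE"]
        metrics["PSNR"] = math.inf if mse == 0 else 10 * math.log10(1.0 / mse)
        return metrics


batches = [([0.0, 1.0], [0.1, 0.9]), ([0.5, 0.5], [0.4, 0.6])]
results = PSNREvaluator().evaluate(batches)
print(results)
```

The evaluation loop is untouched; only the dictionary produced per batch changes, so the new PSNR entry is averaged alongside the existing MSE.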
The user can also change the way samples are generated by overriding one of the following methods:
- deepinv.Trainer.get_samples_online(), when measurements are simulated from a ground truth returned by the dataloader;
- deepinv.Trainer.get_samples_offline(), when both the ground truth and the measurements (and, optionally, physics generator parameters) are returned by the dataloader.
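The difference between the two cases can be sketched in plain Python. In this hypothetical example, forward_op stands in for the physics forward operator, get_samples_online simulates y from x as the data arrive, and get_samples_offline unpacks precomputed (x, y) or (x, y, params) tuples. These functions only mirror the idea; they are not deepinv's methods, and returning a params dictionary instead of a physics object is a simplification.

```python
import random


def forward_op(x):
    """Hypothetical physics: halve each entry (stands in for the forward operator)."""
    return [0.5 * v for v in x]


def get_samples_online(ground_truth_iter, noise_std=0.0, rng=None):
    # Online: the dataloader yields only x; y is simulated on the fly
    if rng is None:
        rng = random.Random(0)
    x = next(ground_truth_iter)
    y = [v + rng.gauss(0.0, noise_std) for v in forward_op(x)]
    return x, y, None  # no extra parameters come from the data


def get_samples_offline(data_iter):
    # Offline: the dataloader yields precomputed (x, y) or (x, y, params)
    data = next(data_iter)
    if len(data) == 2:
        x, y = data
        params = None
    elif len(data) == 3:
        x, y, params = data
    else:
        raise ValueError("expected (x, y) or (x, y, params)")
    return x, y, params


x, y, params = get_samples_online(iter([[2.0, 4.0]]))
x2, y2, params2 = get_samples_offline(iter([([2.0], [1.0], {"mask": [1]})]))
print(y, params2)
```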
For instance, in MRI, the dataloader often returns both the measurements and the mask associated with them. In this case, a possible implementation that updates the deepinv.physics.Physics parameters accordingly is:
```python
from deepinv import Trainer


class CustomTrainer(Trainer):
    def get_samples_offline(self, iterators, g):
        # Suppose the dataset returns per-sample masks, e.g. in MRI
        x, y, mask = next(iterators[g])

        # Suppose the physics has a mask parameter, e.g. DecomposablePhysics or MRI
        physics = self.physics[g]

        # Update the physics parameters deterministically (i.e. not using a random generator)
        physics.update_parameters(mask=mask.to(self.device))

        return x.to(self.device), y.to(self.device), physics
```
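To see the effect of such an override without deepinv, the following dependency-free sketch replaces tensors with lists and the MRI operator with a hypothetical ToyMaskedPhysics, whose forward operator keeps only the entries where the mask is 1. ToyOfflineTrainer is also hypothetical: it mimics the override above by reading (x, y, mask) from an iterator and updating the physics before returning it, so that each measurement is consistent with its own mask.

```python
# Hypothetical stand-ins for a masked physics and a trainer; not deepinv classes.


class ToyMaskedPhysics:
    def __init__(self, mask):
        self.mask = list(mask)

    def update_parameters(self, mask=None):
        # Deterministic update: replace the stored mask with the one provided
        if mask is not None:
            self.mask = list(mask)

    def A(self, x):
        # Masked forward operator: keep entries where mask == 1, zero elsewhere
        return [m * v for m, v in zip(self.mask, x)]


class ToyOfflineTrainer:
    def __init__(self, physics):
        self.physics = [physics]  # one physics per dataloader, indexed by g

    def get_samples_offline(self, iterators, g):
        x, y, mask = next(iterators[g])
        physics = self.physics[g]
        physics.update_parameters(mask=mask)
        return x, y, physics


dataset = [
    ([1.0, 2.0, 3.0], [1.0, 0.0, 3.0], [1, 0, 1]),
    ([4.0, 5.0, 6.0], [0.0, 5.0, 0.0], [0, 1, 0]),
]
trainer = ToyOfflineTrainer(ToyMaskedPhysics([1, 1, 1]))
iterators = [iter(dataset)]

consistent = []
for _ in dataset:
    x, y, physics = trainer.get_samples_offline(iterators, 0)
    consistent.append(physics.A(x) == y)  # forward model matches the stored y
print(consistent)
```

Without the call to update_parameters, the physics would keep its initial all-ones mask and A(x) would no longer match the stored measurements.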