.. _trainer:

Trainer
=======

Training a reconstruction model can be done using the :class:`deepinv.Trainer` class, which can be easily customized to fit your needs.
A trainer can be used both for training (:meth:`deepinv.Trainer.train`) and for testing (:meth:`deepinv.Trainer.test`) a model,
and it can also be used to save and load models.

See :ref:`sphx_glr_auto_examples_basics_demo_train_inpainting.py` for a simple example of how to use the trainer.

The class provides a flexible training loop that can be customized by the user. In particular, the user can override
the :meth:`deepinv.Trainer.compute_loss` method to define a custom training step without having to write all the training code from scratch:

::

    from deepinv import Trainer


    class CustomTrainer(Trainer):
        def compute_loss(self, physics, x, y, train=True, epoch: int = None):
            logs = {}

            self.optimizer.zero_grad()  # Zero the gradients

            # Evaluate the reconstruction network
            x_net = self.model_inference(y=y, physics=physics)

            # Compute the losses
            loss_total = 0
            for l in self.losses:
                loss = l(x=x, x_net=x_net, y=y, physics=physics, model=self.model, epoch=epoch)
                loss_total += loss.mean()

            # Track the running average of the total loss
            metric = self.logs_total_loss_train if train else self.logs_total_loss_eval
            metric.update(loss_total.item())
            logs["TotalLoss"] = metric.avg

            if train:
                loss_total.backward()  # Backpropagate the total loss
                self.optimizer.step()  # Optimizer step

            return x_net, logs

If the user wants to change the way the metrics are computed, they can override the :meth:`deepinv.Trainer.compute_metrics` method.

The user can also change the way samples are generated by overriding:

- :meth:`deepinv.Trainer.get_samples_online` when measurements are simulated from a ground truth returned by the dataloader.
- :meth:`deepinv.Trainer.get_samples_offline` when both the ground truth and the measurements (and optionally the physics generator parameters) are returned by the dataloader.

For instance, in MRI, the dataloader often returns both the measurements and the mask associated with the measurements.
In this case, to update the :class:`deepinv.physics.Physics` parameters accordingly, a possible implementation is:

::

    from deepinv import Trainer


    class CustomTrainer(Trainer):
        def get_samples_offline(self, iterators, g):
            # Suppose your dataset returns per-sample masks, e.g. in MRI
            x, y, mask = next(iterators[g])

            # Suppose the physics accepts a mask parameter, e.g. DecomposablePhysics or MRI
            physics = self.physics[g]

            # Update physics parameters deterministically (i.e. not using a random generator)
            physics.update_parameters(mask=mask.to(self.device))

            return x.to(self.device), y.to(self.device), physics
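To illustrate why overriding a single hook is enough, the following self-contained sketch mimics the dispatch logic in plain Python.
``ToyTrainer``, ``MaskTrainer``, ``ToyMaskPhysics`` and ``run_epoch`` are hypothetical simplifications introduced only for this
illustration and are not part of deepinv: there is no model or optimizer, and the loop merely calls ``get_samples_online`` or
``get_samples_offline`` followed by ``compute_loss``. The ``MaskTrainer`` override mirrors the MRI case, where each sample carries
its own mask that must be written into the physics before the loss is computed:

::

    class ToyMaskPhysics:
        """Hypothetical stand-in for a physics operator with a per-sample mask."""

        def __init__(self, mask=None):
            self.mask = mask

        def update_parameters(self, mask=None):
            # Deterministic update: store the mask provided by the dataloader
            if mask is not None:
                self.mask = mask

        def A(self, x):
            # Masked forward operator: zero out entries where the mask is 0
            return [m * xi for m, xi in zip(self.mask, x)]


    class ToyTrainer:
        """Hypothetical, heavily simplified loop (not deepinv.Trainer): hooks only, no model."""

        def __init__(self, physics, dataloaders, online=True):
            self.physics = physics          # one physics per dataloader
            self.dataloaders = dataloaders  # one iterable of samples per physics
            self.online = online

        def get_samples_online(self, iterators, g):
            # The dataloader returns only x; y is simulated with the physics
            x = next(iterators[g])
            physics = self.physics[g]
            return x, physics.A(x), physics

        def get_samples_offline(self, iterators, g):
            # The dataloader returns (x, y); the physics is left unchanged
            x, y = next(iterators[g])
            return x, y, self.physics[g]

        def compute_loss(self, physics, x, y):
            # Toy loss: squared error between A(x) and y (no network involved)
            return sum((a - b) ** 2 for a, b in zip(physics.A(x), y))

        def run_epoch(self):
            # Pick the sampling hook once, then call it for every sample
            get_samples = self.get_samples_online if self.online else self.get_samples_offline
            iterators = [iter(dl) for dl in self.dataloaders]
            losses = []
            for g, dl in enumerate(self.dataloaders):
                for _ in range(len(dl)):
                    x, y, physics = get_samples(iterators, g)
                    losses.append(self.compute_loss(physics, x, y))
            return losses


    class MaskTrainer(ToyTrainer):
        def get_samples_offline(self, iterators, g):
            # The dataset yields (x, y, mask) triples, as in the MRI case
            x, y, mask = next(iterators[g])
            physics = self.physics[g]
            physics.update_parameters(mask=mask)
            return x, y, physics


    # Two ground-truth/measurement/mask triples, each with its own mask
    data = [
        ([1.0, 2.0, 3.0], [1.0, 0.0, 3.0], [1, 0, 1]),
        ([4.0, 5.0, 6.0], [0.0, 5.0, 6.0], [0, 1, 1]),
    ]
    trainer = MaskTrainer(physics=[ToyMaskPhysics()], dataloaders=[data], online=False)
    losses = trainer.run_epoch()
    print(losses)  # [0.0, 0.0]

Because the mask is updated before each loss evaluation, both samples are consistent with their measurements. With the base
``get_samples_offline``, the same ``(x, y, mask)`` triples would fail to unpack into ``(x, y)``, which is why the override is needed.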