Trainer

A reconstruction model can be trained with the deepinv.Trainer class, which is easy to customize. A trainer can be used both to train a model, with deepinv.Trainer.train(), and to test it, with deepinv.Trainer.test(). It can also save and load models.

See the example Training a reconstruction network for a simple demonstration of how to use the trainer.

The class provides a flexible training loop that the user can customize. In particular, the user can override the deepinv.Trainer.compute_loss() method to define a custom training step without rewriting the whole training loop from scratch:

from deepinv import Trainer


class CustomTrainer(Trainer):
    def compute_loss(self, physics, x, y, train=True, epoch: int = None):
        logs = {}

        self.optimizer.zero_grad()  # Zero the gradients

        # Evaluate the reconstruction network
        x_net = self.model_inference(y=y, physics=physics)

        # Compute and sum the losses
        loss_total = 0
        for loss_fn in self.losses:
            loss = loss_fn(x=x, x_net=x_net, y=y, physics=physics, model=self.model, epoch=epoch)
            loss_total += loss.mean()

        # Log the average total loss (training or evaluation)
        metric = self.logs_total_loss_train if train else self.logs_total_loss_eval
        metric.update(loss_total.item())
        logs["TotalLoss"] = metric.avg

        if train:
            loss_total.backward()  # Backpropagate the total loss
            self.optimizer.step()  # Optimizer step

        return x_net, logs
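Overriding compute_loss() works because the base class owns the iteration and the subclass only supplies the per-batch step. The sketch below illustrates this pattern in plain Python, without any dependency on deepinv. ToyTrainer, SquaredErrorTrainer, RunningAverage and the single scalar "model" are hypothetical stand-ins chosen to keep the example self-contained; they are not part of the deepinv API. Since there is no autograd here, the gradient is written out by hand in place of loss_total.backward() and self.optimizer.step().

```python
class RunningAverage:
    """Hypothetical stand-in for the `metric` logger updated in compute_loss."""

    def __init__(self):
        self.total = 0.0
        self.count = 0

    def update(self, value):
        self.total += value
        self.count += 1

    @property
    def avg(self):
        return self.total / self.count if self.count else 0.0


class ToyTrainer:
    """Hypothetical base class: owns the loop and delegates each step to compute_loss()."""

    def __init__(self, weight, data, lr=0.1):
        self.weight = weight  # a single scalar "model" parameter
        self.data = data      # list of (x, y) pairs
        self.lr = lr
        self.loss_meter = RunningAverage()

    def compute_loss(self, x, y):
        raise NotImplementedError  # subclasses define the training step

    def train(self, epochs):
        logs = {}
        for _ in range(epochs):
            self.loss_meter = RunningAverage()  # this sketch resets the average every epoch
            for x, y in self.data:
                logs = self.compute_loss(x, y)
        return logs


class SquaredErrorTrainer(ToyTrainer):
    """Only the per-sample step is written here; the loop is inherited."""

    def compute_loss(self, x, y):
        x_net = self.weight * y              # "reconstruct" x from the measurement y
        loss = (x_net - x) ** 2
        grad = 2 * (x_net - x) * y           # d(loss)/d(weight), computed by hand
        self.weight -= self.lr * grad        # plain gradient-descent step
        self.loss_meter.update(loss)
        return {"TotalLoss": self.loss_meter.avg}
```

For example, SquaredErrorTrainer(weight=0.0, data=[(2.0, 1.0), (4.0, 2.0)], lr=0.05).train(epochs=200) drives the weight towards 2, since every pair satisfies x = 2y, while ToyTrainer.train() itself is never modified.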

To change how metrics are computed, the user can override the deepinv.Trainer.compute_metrics() method.
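To illustrate the kind of computation an overridden compute_metrics() might perform, the sketch below computes the peak signal-to-noise ratio (PSNR), 10 log10(max² / MSE), between reconstructions and ground truths stored as flat lists of pixel values. It is plain Python and independent of deepinv. The function names, the max_pixel default of 1.0 and the dictionary-of-averages return format are assumptions made for this example, not the signature of deepinv.Trainer.compute_metrics().

```python
import math


def psnr(x_net, x, max_pixel=1.0):
    """Peak signal-to-noise ratio in dB between two equal-length pixel lists."""
    mse = sum((a - b) ** 2 for a, b in zip(x_net, x)) / len(x)
    if mse == 0:
        return math.inf  # identical signals have unbounded PSNR
    return 10 * math.log10(max_pixel ** 2 / mse)


def compute_metrics_sketch(batch_x_net, batch_x):
    """Hypothetical metric hook: average PSNR over a batch, returned as a logs dict."""
    values = [psnr(xn, xt) for xn, xt in zip(batch_x_net, batch_x)]
    return {"PSNR": sum(values) / len(values)}
```

For instance, an MSE of 0.01 with max_pixel=1.0 gives 10 log10(100) = 20 dB.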

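The user can also change how samples are generated by overriding the sample-loading methods, such as deepinv.Trainer.get_samples_offline(), which reads samples from the dataloader.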

For instance, in MRI the dataloader often returns the ground truth and the measurements together with the sampling mask used to acquire them. In this case, the deepinv.physics.Physics parameters must be updated for each batch to match that mask. One possible implementation is:

from deepinv import Trainer


class CustomTrainer(Trainer):
    def get_samples_offline(self, iterators, g):
        # Suppose your dataset returns per-sample masks, e.g. in MRI
        x, y, mask = next(iterators[g])

        # Assumes the physics has updatable parameters, e.g. DecomposablePhysics or MRI
        physics = self.physics[g]

        # Update the physics parameters deterministically (i.e. without a random generator)
        physics.update_parameters(mask=mask.to(self.device))

        return x.to(self.device), y.to(self.device), physics
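The physics has to be updated per batch because each y was produced with its own mask, so the forward operator used in the loss must apply that same mask. The plain-Python sketch below mimics this with a toy masking operator y = mask * x. MaskPhysics, its update_parameters() method, make_dataset() and get_samples_offline_sketch() are hypothetical stand-ins that only mirror the shape of the deepinv override above. They are not deepinv classes.

```python
class MaskPhysics:
    """Hypothetical stand-in for a masked forward operator, y = mask * x."""

    def __init__(self, mask):
        self.mask = list(mask)

    def update_parameters(self, mask=None):
        # Overwrite the stored mask in place, as the override above does
        if mask is not None:
            self.mask = list(mask)

    def A(self, x):
        return [m * v for m, v in zip(self.mask, x)]


def make_dataset():
    """Each sample carries its own mask, as an MRI dataset might."""
    samples = []
    for x, mask in [([1.0, 2.0, 3.0], [1, 0, 1]), ([4.0, 5.0, 6.0], [0, 1, 1])]:
        y = [m * v for m, v in zip(mask, x)]  # measurements acquired with this mask
        samples.append((x, y, mask))
    return samples


def get_samples_offline_sketch(iterator, physics):
    """Mirror of the override above: read (x, y, mask), then sync physics to mask."""
    x, y, mask = next(iterator)
    physics.update_parameters(mask=mask)
    return x, y, physics
```

Starting from MaskPhysics(mask=[1, 1, 1]) and iterating over make_dataset(), physics.A(x) reproduces y for every sample only because of the update_parameters() call; without it, the initial all-ones mask would return x unchanged, which differs from y for both samples.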