Trainer

Training a reconstruction model can be done using the deepinv.Trainer class, which can be easily customized to fit your needs. A trainer handles both training, via deepinv.Trainer.train(), and testing, via deepinv.Trainer.test(), and can also be used to save and load models.

See Training a reconstruction network for a simple example of how to use the trainer.
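To illustrate the overall workflow at a glance, here is a minimal, self-contained toy sketch (this is NOT the deepinv API; ToyTrainer and its methods are purely illustrative) of how a trainer couples a model and a list of losses, trains over data, and then evaluates:

```python
# Toy sketch of the train/test workflow (hypothetical names, not deepinv).
class ToyTrainer:
    def __init__(self, model, losses):
        self.model = model      # callable: measurement y -> reconstruction x_net
        self.losses = losses    # list of callables: (x, x_net) -> float

    def train(self, data):
        # One pass over (x, y) pairs, accumulating the total loss.
        total = 0.0
        for x, y in data:
            x_net = self.model(y)
            total += sum(l(x, x_net) for l in self.losses)
        return total / len(data)

    def test(self, data):
        # Evaluate a simple reconstruction error on held-out pairs.
        return sum(abs(x - self.model(y)) for x, y in data) / len(data)

# Here y = 2x, and the model inverts it exactly, so loss and error are zero.
trainer = ToyTrainer(model=lambda y: y / 2, losses=[lambda x, x_net: (x - x_net) ** 2])
data = [(1.0, 2.0), (3.0, 6.0)]
avg_loss = trainer.train(data)  # 0.0
err = trainer.test(data)        # 0.0
```

The real deepinv.Trainer additionally manages the physics, optimizer, metrics and logging, as detailed below.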

The class provides a flexible training loop that can be customized by the user. In particular, the user can override the deepinv.Trainer.compute_loss() method to define a custom training step without having to write all the training code from scratch:

class CustomTrainer(Trainer):
    def compute_loss(self, physics, x, y, train=True, epoch: int = None):
        logs = {}

        self.optimizer.zero_grad()  # Zero the gradients

        # Evaluate the reconstruction network
        x_net = self.model_inference(y=y, physics=physics)

        # Compute the losses
        loss_total = 0
        for l in self.losses:
            loss = l(x=x, x_net=x_net, y=y, physics=physics, model=self.model, epoch=epoch)
            loss_total += loss.mean()

        # Log the running average of the total loss
        metric = self.logs_total_loss_train if train else self.logs_total_loss_eval
        metric.update(loss_total.item())
        logs["TotalLoss"] = metric.avg

        if train:
            loss_total.backward()  # Backpropagate the total loss
            self.optimizer.step()  # Optimizer step

        return x_net, logs

If the user wants to change the way the metrics are computed, they can override the deepinv.Trainer.compute_metrics() method.
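As a rough sketch of the pattern such an override typically follows (generic Python; the function name and arguments here are illustrative, not the real deepinv.Trainer.compute_metrics signature), each metric is evaluated per sample and averaged over the batch into a dictionary of logs:

```python
# Generic sketch of averaging per-sample metrics over a batch
# (illustrative only; the real deepinv signature may differ).
def compute_metrics(x, x_net, metrics):
    """Return a dict mapping each metric name to its batch average."""
    logs = {}
    for name, metric in metrics.items():
        values = [metric(xi, xi_net) for xi, xi_net in zip(x, x_net)]
        logs[name] = sum(values) / len(values)
    return logs

mse = lambda a, b: (a - b) ** 2
logs = compute_metrics(x=[1.0, 2.0], x_net=[1.0, 4.0], metrics={"MSE": mse})
# logs == {"MSE": 2.0}
```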

The user can also change the way samples are generated by overriding the deepinv.Trainer.get_samples_offline() method.

For instance, in MRI, the dataloader often returns both the measurements and the mask associated with the measurements. In this case, to update the deepinv.physics.Physics parameters accordingly, a potential implementation would be:

class CustomTrainer(Trainer):
    def get_samples_offline(self, iterators, g):
        # Suppose your dataset returns per-sample masks, e.g. in MRI
        x, y, mask = next(iterators[g])

        # Suppose the physics accepts parameter updates, e.g. DecomposablePhysics or MRI
        physics = self.physics[g]

        # Update physics parameters deterministically (i.e. not using a random generator)
        physics.update(mask=mask.to(self.device))

        return x.to(self.device), y.to(self.device), physics

Note

When using a dataset that loads data as a 3-tuple, it is assumed to be (x, y, params), where params is a dict of physics parameters, e.g. generated from deepinv.datasets.generate_dataset. The Trainer will automatically load these parameters into the physics at each iteration.
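To make the expected 3-tuple shape concrete, here is a minimal hypothetical dataset (TripletDataset is an illustrative name, not a deepinv class) whose items follow the (x, y, params) convention described above:

```python
# Hypothetical dataset returning (x, y, params) triplets, where params
# is a dict of physics parameters (e.g. a sampling mask).
class TripletDataset:
    def __init__(self, xs, ys, masks):
        self.xs, self.ys, self.masks = xs, ys, masks

    def __len__(self):
        return len(self.xs)

    def __getitem__(self, i):
        # The third element is a dict of physics parameters.
        return self.xs[i], self.ys[i], {"mask": self.masks[i]}

ds = TripletDataset(xs=[0.5], ys=[1.0], masks=[[1, 0, 1]])
x, y, params = ds[0]
# params == {"mask": [1, 0, 1]}
```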

Warning

When using the trainer for unsupervised training, make sure that each measurement stays constant across epochs. Generally, it is preferable to train offline by setting online_measurements=False and generating a dataset using deepinv.datasets.generate_dataset().

If you want to use online measurements and your physics is random (i.e. you are using a physics_generator or a noise model), you must set loop_random_online_physics=True to reset the randomness at every epoch, and use a DataLoader with shuffle=False so the measurements arrive in the same order every epoch.
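The reason resetting the randomness works can be sketched with a plain random generator (standard-library Python, independent of deepinv): re-seeding at the start of each epoch reproduces the same "random" draws in the same order, so each sample sees the same measurement every epoch, provided the samples also arrive in the same order (hence shuffle=False):

```python
import random

# Re-seeding the generator at the start of each epoch reproduces the
# same noise/mask realizations in the same order across epochs.
def epoch_measurements(seed, n):
    rng = random.Random(seed)  # reset the generator for this epoch
    return [rng.gauss(0.0, 1.0) for _ in range(n)]

epoch1 = epoch_measurements(seed=0, n=4)
epoch2 = epoch_measurements(seed=0, n=4)
# epoch1 == epoch2: the measurements are constant across epochs
```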