Using Multiple GPUs#

Since all deepinv building blocks inherit from torch.nn.Module , they are compatible with torch data parallel modules, either via torch.nn.DataParallel or torch.nn.parallel.DistributedDataParallel.

For instance, one can simply write:

import torch
import deepinv as dinv

backbone = dinv.models.DRUNet(pretrained=None, device=torch.device("cuda"))
model = dinv.models.ArtifactRemoval(backbone)
gpu_number = torch.cuda.device_count()  # number of GPUs to use
model = torch.nn.DataParallel(model, device_ids=list(range(gpu_number)))

which can seamlessly be combined with the default Trainer deepinv.Trainer.

Note however that it is recommended to use torch.nn.parallel.DistributedDataParallel instead of the former when training on multiple GPUs. Among other drawbacks of the previous approach, it is not possible to set attributes of a model within the forward pass, which is required for some deepinv models. In this case, the training loop needs to be modified. We point the reader to the PyTorch documentation to extend their training codes to the multi-gpu case.