Using multiple GPUs

Since all deepinv building blocks inherit from torch.nn.Module, they are compatible with PyTorch's data-parallel wrappers, namely torch.nn.DataParallel and torch.nn.parallel.DistributedDataParallel.

For instance, one can simply write:

>>> import torch
>>> import deepinv as dinv
>>>
>>> backbone = dinv.models.DRUNet(pretrained=None, device=torch.device("cuda"))
>>> model = dinv.models.ArtifactRemoval(backbone)
>>> gpu_number = torch.cuda.device_count()  # number of GPUs to use
>>> model = torch.nn.DataParallel(model, device_ids=list(range(gpu_number)))

which can seamlessly be combined with the default training loop deepinv.train().

Note, however, that torch.nn.parallel.DistributedDataParallel is recommended over torch.nn.DataParallel when training on multiple GPUs. Among other drawbacks, torch.nn.DataParallel replicates the model on every forward pass, so attributes set on the model within the forward pass are lost, yet some deepinv models rely on setting them. With torch.nn.parallel.DistributedDataParallel, the training loop needs to be modified. We refer the reader to the PyTorch documentation for extending their training code to the multi-GPU case.