Using multiple GPUs
Since all deepinv building blocks inherit from torch.nn.Module, they are compatible with PyTorch's data-parallel wrappers, either torch.nn.DataParallel or torch.nn.parallel.DistributedDataParallel.
For instance, one can simply write:
>>> import torch
>>> import deepinv as dinv
>>>
>>> backbone = dinv.models.DRUNet(pretrained=None, device=torch.device("cuda"))
>>> model = dinv.models.ArtifactRemoval(backbone)
>>> gpu_number = torch.cuda.device_count() # number of GPUs to use
>>> model = torch.nn.DataParallel(model, device_ids=list(range(gpu_number)))
which combines seamlessly with the default training loop deepinv.train().
Note, however, that torch.nn.parallel.DistributedDataParallel is recommended over torch.nn.DataParallel when training on multiple GPUs. Among other drawbacks of DataParallel, attributes set on the model during the forward pass are lost (each replica only modifies its own copy), and some deepinv models rely on setting such attributes. With DistributedDataParallel, the training loop itself needs to be modified: one process is spawned per GPU and each runs its own copy of the loop. We refer the reader to the PyTorch documentation on distributed training to extend training code to the multi-GPU case.
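As a reference point, the distributed setup described in the PyTorch documentation can be sketched as follows. This is an illustrative minimal example, not deepinv's own training loop: it runs a single process on CPU with the "gloo" backend and wraps a plain torch.nn.Linear as a stand-in for a deepinv model. In real multi-GPU training one would launch one process per GPU (e.g. with torchrun), use the "nccl" backend, and pass device_ids=[rank] to DistributedDataParallel.

```python
import os
import torch
import torch.distributed as dist
from torch.nn.parallel import DistributedDataParallel as DDP


def run(rank: int, world_size: int) -> float:
    # Rendezvous settings; torchrun sets these environment variables
    # automatically in practice.
    os.environ.setdefault("MASTER_ADDR", "127.0.0.1")
    os.environ.setdefault("MASTER_PORT", "29501")
    dist.init_process_group("gloo", rank=rank, world_size=world_size)

    # Any torch.nn.Module works here; a deepinv model (such as the
    # ArtifactRemoval model above) could be wrapped the same way.
    model = torch.nn.Linear(8, 8)
    ddp_model = DDP(model)  # pass device_ids=[rank] when using CUDA

    optimizer = torch.optim.SGD(ddp_model.parameters(), lr=1e-3)
    x = torch.randn(4, 8)
    loss = ddp_model(x).pow(2).mean()
    loss.backward()  # gradients are all-reduced across processes here
    optimizer.step()

    dist.destroy_process_group()
    return loss.item()


if __name__ == "__main__":
    # One process per GPU is launched with torchrun in real use; a single
    # CPU process (world_size=1) is enough to illustrate the structure.
    run(rank=0, world_size=1)
```

Each process then builds its own dataloader (typically with a torch.utils.data.distributed.DistributedSampler) and runs the loop above over its shard of the data.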