TensorDataset
class deepinv.datasets.TensorDataset(*, x=None, y=None, params=None)
Bases: ImageDataset

Dataset wrapping data explicitly passed as tensors.

This dataset can be used to return ground truth only (x), ground truth and measurements (x, y), or measurements only (y). All input tensors must be of shape (N, ...) with the same N, where N is the number of samples and ... represents the data dimensions.

Optionally, params are returned too.

Parameters:
- x (torch.Tensor, None) – optional input ground truth tensor x
- y (torch.Tensor, None) – optional input measurement tensor y
- params (dict[str, torch.Tensor], None) – optional input physics parameters params, of format {"str": Tensor}
Examples:
Construct a dataset from a single measurement only:
>>> import torch
>>> from deepinv.datasets import TensorDataset
>>> y = torch.rand(1, 3, 8, 8)  # B,C,H,W
>>> dataset = TensorDataset(y=y)
>>> x, y = dataset[0]
>>> x
nan
>>> y.shape
torch.Size([3, 8, 8])
Construct a dataset from a ground truth batch:
>>> x = torch.rand(4, 3, 8, 8)  # 4 samples of 3-channel 8x8 images
>>> dataset = TensorDataset(x=x)
>>> dataset[0].shape
torch.Size([3, 8, 8])
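The indexing behaviour described above can be made concrete with a small stand-in written in plain PyTorch. This is only a sketch and not deepinv's implementation: the class name SketchTensorDataset is hypothetical, and how params is returned (sliced along the first dimension and appended as the last tuple element) is an assumption rather than something stated on this page.

```python
import math

import torch
from torch.utils.data import Dataset


class SketchTensorDataset(Dataset):
    """Hypothetical stand-in illustrating the documented behaviour."""

    def __init__(self, *, x=None, y=None, params=None):
        if x is None and y is None:
            raise ValueError("Provide at least one of x or y.")
        tensors = [t for t in (x, y) if t is not None]
        if params is not None:
            # Assumption: params tensors are also of shape (N, ...).
            tensors.extend(params.values())
        n = tensors[0].shape[0]
        if any(t.shape[0] != n for t in tensors):
            raise ValueError("All tensors must share the same first dimension N.")
        self.x, self.y, self.params, self.n = x, y, params, n

    def __len__(self):
        return self.n

    def __getitem__(self, idx):
        # A missing ground truth is represented by a float NaN placeholder.
        x = self.x[idx] if self.x is not None else float("nan")
        if self.y is None and self.params is None:
            return x  # ground truth only
        out = [x]
        if self.y is not None:
            out.append(self.y[idx])
        if self.params is not None:
            # Assumption: per-sample params are returned last, as a dict.
            out.append({k: v[idx] for k, v in self.params.items()})
        return tuple(out)


# Measurements only: x comes back as NaN, y as a single sample.
measurements_only = SketchTensorDataset(y=torch.rand(1, 3, 8, 8))
x0, y0 = measurements_only[0]
print(math.isnan(x0), tuple(y0.shape))  # True (3, 8, 8)

# Ground truth only: indexing returns the sample itself, not a tuple.
ground_truth_only = SketchTensorDataset(x=torch.rand(4, 3, 8, 8))
print(len(ground_truth_only), tuple(ground_truth_only[0].shape))  # 4 (3, 8, 8)
```

Checking the first dimension in the constructor means a mismatch between x, y and params fails immediately instead of surfacing later as an indexing error during training.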