TensorDataset
class deepinv.datasets.TensorDataset(*, x=None, y=None, params=None)
Bases: ImageDataset
Dataset wrapping data explicitly passed as tensors.
This dataset can be used to return ground truth x, ground truth and measurements (x, y), or measurements only y (returned as the pair (nan, y), so the ground-truth slot holds nan). All input tensors must be of shape (N, ...) and share the same N, where N is the number of samples and … represents the data dimensions. Optionally, params are returned too.
Parameters:
x (torch.Tensor, None) – optional input ground truth tensor x
y (torch.Tensor, None) – optional input measurement tensor y
params (dict[str, torch.Tensor], None) – optional input physics parameters params, of format {"str": Tensor}
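The shared-N rule can be checked before building the dataset. The sketch below is a hypothetical helper, not part of deepinv: check_tensor_inputs and params_at are illustrative names, and the code assumes each params entry is batched along the same leading dimension N as x and y (the parameter description above only specifies the {"str": Tensor} format). The "sigma" key is likewise a made-up example.

```python
from typing import Dict, Optional

import torch


def check_tensor_inputs(
    x: Optional[torch.Tensor] = None,
    y: Optional[torch.Tensor] = None,
    params: Optional[Dict[str, torch.Tensor]] = None,
) -> int:
    """Hypothetical helper: return N if every tensor shares leading dimension N."""
    if x is None and y is None:
        raise ValueError("provide at least one of x or y")
    sizes = {}
    if x is not None:
        sizes["x"] = x.shape[0]
    if y is not None:
        sizes["y"] = y.shape[0]
    # Assumption: every params entry is also batched along dimension 0.
    for key, value in (params or {}).items():
        sizes[f"params[{key!r}]"] = value.shape[0]
    if len(set(sizes.values())) != 1:
        raise ValueError(f"leading dimensions differ: {sizes}")
    return sizes["x"] if "x" in sizes else sizes["y"]


def params_at(params: Dict[str, torch.Tensor], i: int) -> Dict[str, torch.Tensor]:
    """Hypothetical helper: slice sample i out of a {"str": Tensor} dict."""
    return {key: value[i] for key, value in params.items()}


x = torch.rand(4, 3, 8, 8)
y = torch.rand(4, 3, 8, 8)
params = {"sigma": torch.rand(4, 1)}  # made-up key: one noise level per sample
print(check_tensor_inputs(x=x, y=y, params=params))  # 4
print(params_at(params, 0)["sigma"].shape)  # torch.Size([1])
```

Checking up front gives a single error that names every mismatched input, rather than an indexing failure that only surfaces when a late sample is drawn.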
Examples:
Construct a dataset from a single measurement only:
>>> import torch
>>> from deepinv.datasets import TensorDataset
>>> y = torch.rand(1, 3, 8, 8)  # B,C,H,W
>>> dataset = TensorDataset(y=y)
>>> x, y = dataset[0]
>>> x
nan
>>> y.shape
torch.Size([3, 8, 8])
Construct a dataset from a ground truth batch:
>>> x = torch.rand(4, 3, 8, 8)  # 4 samples of 3-channel 8x8 images
>>> dataset = TensorDataset(x=x)
>>> dataset[0].shape
torch.Size([3, 8, 8])
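To see the three return modes side by side, here is a minimal stand-in built on torch.utils.data.Dataset. SketchTensorDataset is a hypothetical class written for illustration, not deepinv's implementation: it reproduces only the behaviour shown in the examples above (a single tensor for ground truth only, (nan, y) for measurements only, (x, y) for both) and deliberately ignores params and shape validation.

```python
import torch
from torch.utils.data import DataLoader, Dataset


class SketchTensorDataset(Dataset):
    """Hypothetical stand-in mirroring the documented return modes
    (not deepinv's implementation; params and validation are omitted)."""

    def __init__(self, *, x=None, y=None):
        if x is None and y is None:
            raise ValueError("provide at least one of x or y")
        self.x, self.y = x, y

    def __len__(self):
        # N is the shared leading dimension of the provided tensors.
        return len(self.x) if self.x is not None else len(self.y)

    def __getitem__(self, i):
        if self.y is None:
            return self.x[i]  # ground truth only: a single tensor
        # If x is missing (measurements only), nan fills the ground-truth slot.
        x_i = self.x[i] if self.x is not None else torch.nan
        return x_i, self.y[i]


y = torch.rand(4, 3, 8, 8)
x = torch.rand(4, 3, 8, 8)

print(SketchTensorDataset(x=x)[0].shape)  # torch.Size([3, 8, 8])
x_i, y_i = SketchTensorDataset(x=x, y=y)[1]
print(x_i.shape, y_i.shape)  # torch.Size([3, 8, 8]) torch.Size([3, 8, 8])

loader = DataLoader(SketchTensorDataset(y=y), batch_size=2)
xb, yb = next(iter(loader))
print(yb.shape)  # torch.Size([2, 3, 8, 8])
print(bool(torch.isnan(xb).all()))  # True
```

In this sketch, PyTorch's default collate function turns the per-sample nan placeholders of a measurement-only batch into a float tensor of NaNs, so a training loop can still unpack every batch as (x, y).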