RandomPatchSampler
class deepinv.datasets.RandomPatchSampler(x_dir=None, y_dir=None, patch_size=32, file_format='.npy', ch_axis=None, dtype=torch.float32, loader=None)
Bases: ImageDataset
Dataset for nD images that samples one random patch per image.
This dataset is built from one or two directories of nD images (which must be in .npy, .nii(.gz), or .b2nd format if loader is not specified). On each epoch, it returns a randomly sampled patch of fixed size from each volume.

Warning
This loader uses torch's random functionality. To ensure reproducibility, set the DataLoader's generator with a fixed seed.

Supported use cases:
- Single-directory: provide only the ground-truth folder x_dir or the measurement folder y_dir (returns patches from that directory).
- Paired-directory: provide both x_dir and y_dir (returns matched patches from both).

Channel handling:
- If ch_axis=None: a singleton channel dimension is added at axis 0.
- If ch_axis=0: images are assumed to be channel-first.
- If ch_axis=-1: images are assumed to be channel-last and are transposed to channel-first.
- Patches are never extracted along the channel axis (the patch size for that axis is ignored).

Patch size handling:
- Accepts either an integer (applied to all spatial dimensions) or a tuple.
- If patch_size is a tuple and patch_size[i] == 1, this is equivalent to slicing across axis i (the singleton at axis i is squeezed). This can be used, e.g., to extract 2D slices from a 3D volume.
- If the tuple length is one less than the image ndim, the channel axis is auto-filled with None.
Randomness & reproducibility:
- Patch coordinates are drawn with Python's random module.
- To ensure deterministic behavior across workers, set the DataLoader's worker_init_fn or generator according to the PyTorch reproducibility guidelines.

Notes
- All images must have the same dimensionality.
- When both directories are provided, only files present in both are used.
- The shape of each file is checked for consistency: spatial dimensions must be no smaller than patch_size, and the number of channels must remain the same across files.
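The file-matching and shape checks in the notes can be sketched as follows. This is an illustrative reconstruction, not deepinv's code: the helper paired_files is hypothetical, it only handles .npy files (deepinv also accepts .nii(.gz) and .b2nd), and it assumes that paired files are matched by identical file names, that arrays are stored channel-first as (C, *spatial), and that channel counts only need to agree within each folder.

```python
import tempfile
from pathlib import Path
from typing import List, Sequence

import numpy as np


def paired_files(x_dir: str, y_dir: str, patch_size: Sequence[int], file_format: str = ".npy") -> List[str]:
    """Hypothetical sketch of the documented file matching and shape checks."""
    x_names = {p.name for p in Path(x_dir).glob(f"*{file_format}")}
    y_names = {p.name for p in Path(y_dir).glob(f"*{file_format}")}
    common = sorted(x_names & y_names)  # assumption: files are paired by name
    ndim = None
    channels: dict = {}  # reference channel count per folder
    for name in common:
        for folder in (x_dir, y_dir):
            # mmap_mode="r" exposes the shape without reading the whole array
            shape = np.load(Path(folder) / name, mmap_mode="r").shape
            if ndim is None:
                ndim = len(shape)
            if len(shape) != ndim:
                raise ValueError(f"{name}: expected {ndim} dims, got {len(shape)}")
            ref = channels.setdefault(folder, shape[0])
            if shape[0] != ref:
                raise ValueError(f"{name}: expected {ref} channels, got {shape[0]}")
            if any(s < p for s, p in zip(shape[1:], patch_size)):
                raise ValueError(f"{name}: spatial shape {shape[1:]} is smaller than {tuple(patch_size)}")
    return common


with tempfile.TemporaryDirectory() as xd, tempfile.TemporaryDirectory() as yd:
    for name in ("a.npy", "b.npy"):
        np.save(Path(xd) / name, np.zeros((1, 40, 40, 40), dtype=np.float32))
    np.save(Path(yd) / "a.npy", np.zeros((1, 40, 40, 40), dtype=np.float32))
    matched = paired_files(xd, yd, patch_size=(32, 32, 32))
    print(matched)  # ['a.npy']
```

Because b.npy exists only in the first folder, it is skipped rather than raising an error, which matches the note that only files present in both directories are used.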