.. DO NOT EDIT.
.. THIS FILE WAS AUTOMATICALLY GENERATED BY SPHINX-GALLERY.
.. TO MAKE CHANGES, EDIT THE SOURCE PYTHON FILE:
.. "auto_examples/basics/demo_hf_dataset.py"
.. LINE NUMBERS ARE GIVEN BELOW.

.. only:: html

    .. note::
        :class: sphx-glr-download-link-note

        :ref:`Go to the end <sphx_glr_download_auto_examples_basics_demo_hf_dataset.py>` to download the full example code.

.. rst-class:: sphx-glr-example-title

.. _sphx_glr_auto_examples_basics_demo_hf_dataset.py:

Using a Hugging Face dataset
====================================================================================================

| This simple example shows how to load and properly prepare a Hugging Face dataset.
| Context: quick access to the many datasets available in the Hugging Face format.
| Available datasets: https://huggingface.co/datasets?search=deepinv
| Here we use `drunet_dataset <https://huggingface.co/datasets/deepinv/drunet_dataset>`_.

.. GENERATED FROM PYTHON SOURCE LINES 13-16

Load libraries
----------------------------------------------------------------------------------------

.. GENERATED FROM PYTHON SOURCE LINES 16-24

.. code-block:: Python

    from datasets import load_dataset as load_dataset_hf
    from torch.utils.data import IterableDataset, DataLoader
    from torchvision import transforms

    import deepinv as dinv

.. GENERATED FROM PYTHON SOURCE LINES 25-31

Stream data from the Internet
----------------------------------------------------------------------------------------

Stream data from the Hugging Face servers: only a limited number of samples is held in
memory at any time, which avoids saving the dataset to disk and keeps memory usage low.

.. GENERATED FROM PYTHON SOURCE LINES 31-39

.. code-block:: Python

    # source : https://huggingface.co/datasets/deepinv/drunet_dataset
    # type : datasets.iterable_dataset.IterableDataset
    raw_hf_train_dataset = load_dataset_hf(
        "deepinv/drunet_dataset", split="train", streaming=True
    )

    print("Number of data files used to store raw data: ", raw_hf_train_dataset.n_shards)

.. rst-class:: sphx-glr-script-out
 .. code-block:: none

    Number of data files used to store raw data:  1

.. GENERATED FROM PYTHON SOURCE LINES 40-46

Shuffle data with buffer shuffling
----------------------------------------------------------------------------------------

| In streaming mode, samples can only be read sequentially in a fixed order, so exact shuffling is not possible.
| The alternative is buffer shuffling, which keeps a fixed number of samples in memory and picks the next sample at random from this buffer.

.. GENERATED FROM PYTHON SOURCE LINES 46-51

.. code-block:: Python

    # https://huggingface.co/docs/datasets/about_mapstyle_vs_iterable
    raw_hf_train_dataset = raw_hf_train_dataset.shuffle(seed=42, buffer_size=100)

.. GENERATED FROM PYTHON SOURCE LINES 52-57

Apply a transform to the dataset
----------------------------------------------------------------------------------------

We define the transform with the ``torchvision.transforms`` module, but it can be any
other function.

.. GENERATED FROM PYTHON SOURCE LINES 57-90

.. code-block:: Python

    # Function that should be applied to a PIL Image
    img_transforms = transforms.Compose(
        [
            transforms.Resize((224, 224)),  # Resize all images to 224x224
            transforms.ToTensor(),
        ]
    )


    # Class that applies `transform` to the samples of a datasets.iterable_dataset.IterableDataset
    class HFDataset(IterableDataset):
        r"""
        Creates an iterable dataset from a Hugging Face dataset to enable streaming.
        """

        def __init__(self, hf_dataset, transforms=None, key="png"):
            self.hf_dataset = hf_dataset
            self.transform = transforms
            self.key = key

        def __iter__(self):
            for sample in self.hf_dataset:
                if self.transform:
                    out = self.transform(sample[self.key])
                else:
                    out = sample[self.key]
                yield out


    hf_train_dataset = HFDataset(raw_hf_train_dataset, transforms=img_transforms)
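
The buffer-shuffling idea described above can be sketched in plain Python. The
``buffer_shuffle`` helper below is hypothetical (it is not part of the ``datasets``
API; ``IterableDataset.shuffle(seed=..., buffer_size=...)`` implements this logic
internally), but it shows why only ``buffer_size`` samples ever sit in memory:

.. code-block:: Python

    import random


    def buffer_shuffle(samples, buffer_size, seed=42):
        # Hypothetical sketch of buffer shuffling, not the `datasets` API:
        # hold at most `buffer_size` samples in memory and, once the buffer
        # is full, yield a randomly chosen one for each incoming sample.
        rng = random.Random(seed)
        buffer = []
        for sample in samples:
            if len(buffer) < buffer_size:
                buffer.append(sample)  # fill the buffer first
                continue
            i = rng.randrange(buffer_size)
            yield buffer[i]  # emit a random buffered sample...
            buffer[i] = sample  # ...and store the new one in its slot
        rng.shuffle(buffer)  # drain the remaining buffered samples
        yield from buffer


    # Every input sample is emitted exactly once, in a pseudo-random order.
    print(list(buffer_shuffle(range(10), buffer_size=4)))

Note that a small ``buffer_size`` only shuffles locally: a sample can never be emitted
before the samples that streamed in more than ``buffer_size`` positions ahead of it.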
.. GENERATED FROM PYTHON SOURCE LINES 91-97

Create a dataloader
--------------------------------------------------------------------

| With a ``datasets.iterable_dataset.IterableDataset``, the samples are stored in one or a few data files (shards).
| When there are several shards, the ``num_workers`` argument can be used to load samples in parallel.

.. GENERATED FROM PYTHON SOURCE LINES 97-111

.. code-block:: Python

    if raw_hf_train_dataset.n_shards > 1:
        # num_workers <= raw_hf_train_dataset.n_shards (number of data files)
        # num_workers <= number of available cpu cores
        num_workers = ...
        train_dataloader = DataLoader(
            hf_train_dataset, batch_size=2, num_workers=num_workers
        )
    else:
        train_dataloader = DataLoader(hf_train_dataset, batch_size=2)

    # display a batch
    batch = next(iter(train_dataloader))
    dinv.utils.plot([batch[0], batch[1]])

.. image-sg:: /auto_examples/basics/images/sphx_glr_demo_hf_dataset_001.png
   :alt: demo hf dataset
   :srcset: /auto_examples/basics/images/sphx_glr_demo_hf_dataset_001.png
   :class: sphx-glr-single-img

.. rst-class:: sphx-glr-timing

   **Total running time of the script:** (0 minutes 3.781 seconds)

.. _sphx_glr_download_auto_examples_basics_demo_hf_dataset.py:

.. only:: html

  .. container:: sphx-glr-footer sphx-glr-footer-example

    .. container:: sphx-glr-download sphx-glr-download-jupyter

      :download:`Download Jupyter notebook: demo_hf_dataset.ipynb <demo_hf_dataset.ipynb>`

    .. container:: sphx-glr-download sphx-glr-download-python

      :download:`Download Python source code: demo_hf_dataset.py <demo_hf_dataset.py>`

    .. container:: sphx-glr-download sphx-glr-download-zip

      :download:`Download zipped: demo_hf_dataset.zip <demo_hf_dataset.zip>`

.. only:: html

 .. rst-class:: sphx-glr-signature

    `Gallery generated by Sphinx-Gallery <https://sphinx-gallery.github.io>`_