.. DO NOT EDIT.
.. THIS FILE WAS AUTOMATICALLY GENERATED BY SPHINX-GALLERY.
.. TO MAKE CHANGES, EDIT THE SOURCE PYTHON FILE:
.. "auto_examples/sampling/demo_flow_matching.py"
.. LINE NUMBERS ARE GIVEN BELOW.

.. only:: html

    .. note::
        :class: sphx-glr-download-link-note

        New to DeepInverse? Get started with the basics with the :ref:`5 minute quickstart tutorial `.

.. rst-class:: sphx-glr-example-title

.. _sphx_glr_auto_examples_sampling_demo_flow_matching.py:


Flow-Matching for posterior sampling and unconditional generation
==================================================================

This demo shows how to perform unconditional image generation and posterior sampling using Flow Matching (FM).

Flow matching builds a continuous transport between a reference distribution :math:`p_1` that is easy to sample from (e.g., a Gaussian distribution) and the data distribution :math:`p_0`.
Sampling is done by drawing :math:`x_1 \sim p_1` and solving the following ordinary differential equation (ODE), defined by a time-dependent velocity field :math:`v_\theta(x,t)`, backward in time from :math:`t=1` to :math:`t=0`:

.. math::

    \frac{dx_t}{dt} = v_\theta(x_t,t), \quad x_1 \sim p_1, \quad t \in [0,1]

The velocity field :math:`v_\theta(x,t)` is typically trained to approximate the conditional expectation:

.. math::

    v_\theta(x_t,t) \approx \mathbb{E}_{x_0 \sim p_0, x_1 \sim p_1}\Big[ \frac{d}{dt} x_t \,\Big|\, x_t = a(t) x_0 + b(t) x_1 \Big]

where :math:`a(t)` and :math:`b(t)` are interpolation coefficients such that :math:`x_t` goes from :math:`x_0` at :math:`t=0` to :math:`x_1` at :math:`t=1`.

When the reference distribution :math:`p_1` is the standard Gaussian, the velocity field can be expressed as a function of a Gaussian denoiser :math:`D(x, \sigma)`.
Indeed, :math:`\frac{d}{dt} x_t = a'(t) x_0 + b'(t) x_1` and :math:`\mathbb{E}[x_1 | x_t] = (x_t - a(t)\,\mathbb{E}[x_0 | x_t]) / b(t)`.
Moreover, :math:`\mathbb{E}[x_0 | x_t] = D\left(\frac{x_t}{a(t)}, \frac{b(t)}{a(t)}\right)` because :math:`\frac{x_t}{a(t)} = x_0 + \frac{b(t)}{a(t)} x_1`. This gives:

.. math::

    v_\theta(x_t,t) = \frac{b'(t)}{b(t)} x_t + \frac{a'(t) b(t) - a(t) b'(t)}{b(t)} D\left(\frac{x_t}{a(t)}, \frac{b(t)}{a(t)} \right)

The most common choice of time schedule is the linear schedule :math:`a(t) = 1 - t` and :math:`b(t) = t`. With this schedule, the velocity simplifies to :math:`v_\theta(x_t,t) = \frac{1}{t}\left(x_t - D\left(\frac{x_t}{1-t}, \frac{t}{1-t}\right)\right)`.
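
The following minimal NumPy sketch makes the denoiser-based velocity concrete by running Euler sampling with the linear schedule. It is not the deepinv implementation. The helper names (``velocity``, ``euler_sample``, ``gaussian_denoiser``, ``empirical_mmse``) and the two toy data distributions are hypothetical choices made for illustration.

The sketch computes the velocity as :math:`v(x_t, t) = \frac{b'(t)}{b(t)} x_t + \frac{a'(t) b(t) - a(t) b'(t)}{b(t)} D\left(\frac{x_t}{a(t)}, \frac{b(t)}{a(t)}\right)`. This expression follows from :math:`\mathbb{E}[x_1 | x_t] = (x_t - a(t)\,\mathbb{E}[x_0 | x_t]) / b(t)` when :math:`p_1` is a standard Gaussian.

Both toy denoisers are exact MMSE denoisers:

- For a Gaussian :math:`p_0 = \mathcal{N}(0, s^2 I)`, the samples should end up with a standard deviation close to :math:`s`.
- For the empirical distribution of four 2-D points, every sample should land next to one of the points. This mirrors the behaviour expected from the closed-form MMSE denoiser on MNIST in this demo.

The time interval is truncated to :math:`[10^{-3}, 1 - 10^{-3}]` to avoid dividing by :math:`a(1) = 0` and :math:`b(0) = 0`.

.. code-block:: Python

    import numpy as np


    # Linear schedule a(t) = 1 - t, b(t) = t and their time derivatives.
    def a(t):
        return 1.0 - t


    def da(t):
        return -1.0


    def b(t):
        return t


    def db(t):
        return 1.0


    def velocity(x, t, denoiser):
        """v(x_t, t) = (b'/b) x_t + ((a' b - a b') / b) * D(x_t / a, b / a)."""
        at, bt = a(t), b(t)
        x0_hat = denoiser(x / at, bt / at)  # MMSE estimate of x_0 given x_t
        return (db(t) / bt) * x + ((da(t) * bt - at * db(t)) / bt) * x0_hat


    def euler_sample(x1, denoiser, n_steps=2000, t_max=1 - 1e-3, t_min=1e-3):
        """Integrate dx/dt = v(x, t) with explicit Euler steps, backward from t_max to t_min."""
        ts = np.linspace(t_max, t_min, n_steps + 1)
        x = x1.copy()
        for t, t_next in zip(ts[:-1], ts[1:]):
            x = x + (t_next - t) * velocity(x, t, denoiser)  # t_next - t < 0
        return x


    # Toy data 1 (hypothetical): p_0 = N(0, s^2 I), whose MMSE denoiser is a linear shrinkage.
    s = 0.5


    def gaussian_denoiser(y, sigma):
        return s**2 / (s**2 + sigma**2) * y


    # Toy data 2 (hypothetical): empirical distribution of four 2-D points ("atoms").
    atoms = np.array([[1.0, 1.0], [1.0, -1.0], [-1.0, 1.0], [-1.0, -1.0]])


    def empirical_mmse(y, sigma):
        # Posterior mean: atoms averaged with weights softmax(-||y - x_i||^2 / (2 sigma^2)).
        d2 = ((y[:, None, :] - atoms[None, :, :]) ** 2).sum(axis=-1)  # (n, 4)
        logits = -d2 / (2 * sigma**2)
        logits -= logits.max(axis=1, keepdims=True)  # numerical stability
        w = np.exp(logits)
        w /= w.sum(axis=1, keepdims=True)
        return w @ atoms


    rng = np.random.default_rng(0)
    x1 = rng.standard_normal((20_000, 2))  # samples from p_1 = N(0, I)

    x_gauss = euler_sample(x1, gaussian_denoiser)
    print("std of Gaussian samples:", x_gauss.std())  # should be close to s = 0.5

    x_mem = euler_sample(x1[:256], empirical_mmse)
    dist = np.linalg.norm(x_mem[:, None, :] - atoms[None, :, :], axis=-1).min(axis=1)
    print("max distance to nearest atom:", dist.max())  # small: samples land on the atoms

To try other schedules, redefine ``a``, ``b`` and their derivatives, keeping :math:`a(0) = b(1) = 1` and :math:`a(1) = b(0) = 0`. With an exact denoiser, the sampling path changes but the distribution reached near :math:`t = 0` should not.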

In this demo, we show how to:

- Perform unconditional generation using the closed-form MMSE denoiser instead of a trained denoiser:

  .. math::

      D(x, \sigma) = \mathbb{E}_{x_0 \sim p_0, \epsilon \sim \mathcal{N}(0, I)} \Big[ x_0 \,\Big|\, x = x_0 + \sigma \epsilon \Big]

  Let :math:`p_0` be the empirical distribution of a dataset of clean images :math:`\{x_i\}_{i=1}^N`. In that case, this expectation has a closed form: a weighted average of the :math:`x_i`, with weights proportional to :math:`\exp\left(-\|x - x_i\|^2 / (2\sigma^2)\right)`. Computing it therefore requires the distance between the input image and every image of the dataset (see :class:`deepinv.models.MMSE`).

- Perform posterior sampling using Flow-Matching combined with a DPS data fidelity term (see :ref:`sphx_glr_auto_examples_sampling_demo_diffusion_sde.py` for more details).

- Explore different choices of time schedules :math:`a(t)` and :math:`b(t)`.

.. GENERATED FROM PYTHON SOURCE LINES 42-54

.. code-block:: Python


    import torch
    import deepinv as dinv
    from deepinv.sampling import (
        PosteriorDiffusion,
        DPSDataFidelity,
        EulerSolver,
        FlowMatching,
    )
    import numpy as np
    from torchvision import datasets, transforms
    from deepinv.models import MMSE

.. GENERATED FROM PYTHON SOURCE LINES 55-60

-----------------------------

We start with the closed-form MMSE denoiser. It requires computing the distance between the input image and every image of the dataset, so it becomes slow for large images and large datasets.
In this toy example, we therefore use only the first 1000 images of the MNIST test set.
This denoiser is the exact MMSE denoiser of the empirical distribution of these images, so the flow transports the Gaussian onto that empirical distribution. Up to discretization errors, the generated samples are images of the dataset rather than new images.

.. GENERATED FROM PYTHON SOURCE LINES 60-80

.. code-block:: Python


    device = dinv.utils.get_device()
    dtype = torch.float32
    figsize = 2.5

    # We use the closed-form MMSE denoiser, with images of the MNIST test set as atoms.
    # The deepinv MMSE denoiser takes as input a dataloader or, as done here, a tensor of images.
    dataset = datasets.MNIST(
        root=".", train=False, download=True, transform=transforms.ToTensor()
    )
    dataloader = torch.utils.data.DataLoader(dataset, batch_size=1000, shuffle=False)
    n_max = (
        1000  # limit the number of images to speed up the computation of the MMSE denoiser
    )
    tensors = torch.cat([data[0] for data in iter(dataloader)], dim=0)  # (N,1,28,28)
    tensors = tensors[:n_max].to(device)
    denoiser = MMSE(dataloader=tensors, device=device, dtype=dtype)


.. rst-class:: sphx-glr-script-out

 .. code-block:: none

    Selected GPU 0 with 4989.25 MiB free memory


.. container:: sphx-glr-download sphx-glr-download-python

  :download:`Download Python source code: demo_flow_matching.py `

.. container:: sphx-glr-download sphx-glr-download-zip

  :download:`Download zipped: demo_flow_matching.zip `


.. only:: html

 .. rst-class:: sphx-glr-signature

    `Gallery generated by Sphinx-Gallery `_