(DM Reconst.) Ch.2 Variational Perspective - From VAEs to DDPM

Diffusion Model Conceptual Reconstruction following The Principles of Diffusion Models



Hozy Summary

  • VAE (Naive Autoencoder $\rightarrow$ VAE $\rightarrow$ Gaussian VAE) cannot resolve the blurriness issue due to the averaging effect.
  • HVAE enables more expressive power with the progressive structure (deep, layered hierarchical network).
    • Still, HVAE causes the posterior collapse problem if the layers get too deep.
      • i.e. Loss of control in generation.
  • DDPM
    • resembles…
      • the encoder-decoder structure of the VAEs, with a forward-reverse structure
      • the deep layered structure of HVAE employing the timestep $t$
    • but
      • has the fixed encoder (not the target of learning)
      • has no per-level KL terms
        • instead well conditioned denoising subproblems from large to small noise
      • yields stable optimization
      • has high sample quality while preserving a coarse-to-fine hierarchy over time (noise)
    • yet
      • has slow sampling speed, since accuracy requires many small denoising steps



2.1 Variational Autoencoder (VAE)


Model) Naive Autoencoder

  • Structure)
    • Original Data $\rightarrow$ Encoder $\rightarrow$ Latent $\rightarrow$ Decoder $\rightarrow$ Reconstructed Data
  • Training)
    • Minimize the reconstruction error between the original input and its reconstruction.
  • Drawback)
    • Unstructured latent space
      • i.e.) Latent codes produce meaningless outputs.
      • Sol.) VAE


Model) Variational Autoencoder

Kingma et al. 2013

  • Structure)
    • Decoder (Generator)
      • Settings)
        • $\mathbf{x}$ : observed variable
        • $\mathbf{z}\sim p_{\text{prior}}$ : latent variable
          • where $p_{\text{prior}} := \mathcal{N}(\mathbf{0,I})$
      • Goal)
        • Map $\mathbf{z}$ back to data space.
        • Sample $\mathbf{z}\sim p_{\text{prior}}$ and decode into $\mathbf{x}\sim p_\phi(\mathbf{x\mid z})$.
      • Def.)
        • $p_\phi(\mathbf{x\mid z})$ : a decoder distribution
      • Marginal likelihood)
        • $p_\phi(\mathbf{x}) = \displaystyle\int p_\phi(\mathbf{x\mid z}) p(\mathbf{z}) \text{d}\mathbf{z}$
          • Problem)
            • Integral over $\mathbf{z}$ is intractable!
            • Sol.) Variational step
    • Encoder (Inference Network)
      • Goal)
        • Given an observation $\mathbf{x}$, map to the latent code $\mathbf{z}$.
          • i.e.) $p_\phi(\mathbf{z\mid x})$
      • Problem)
        • $p_\phi(\mathbf{z\mid x})$ is intractable.
          • Why?)
            • $p_\phi(\mathbf{z\mid x}) = \displaystyle\frac{p_\phi(\mathbf{x\mid z}) p(\mathbf{z})}{p_\phi(\mathbf{x})}\quad(\because\text{Bayes’ Rule})$
            • $p_\phi(\mathbf{x})$ was intractable.
        • Sol.) Variational step.
  • Tech.)
    • Variational Step
      • Goal.)
        • Find an encoder $q_\theta(\mathbf{z\mid x})$ s.t. $q_\theta(\mathbf{z\mid x}) \approx p_\phi(\mathbf{z\mid x})$
        • How?)
          • Parameterize with a neural network.
      • Optimization)


Concept) Evidence Lower Bound (ELBO)

  • Props.)
    • Reconstruction Term
      • Encourages accurate recovery of $\mathbf{x}$ from the latent $\mathbf{z}$.
      • With a Gaussian VAE, this reduces to the squared-error reconstruction loss of an autoencoder (up to scaling and an additive constant).
      • (RISK) Optimizing this term alone risks memorizing the training data!
    • Latent Regularization (KL)
      • Encourages $q_\theta(\mathbf{z\mid x})$ (the encoder distribution) to stay close to $p_{\text{prior}}$ (the simple prior).
      • Shapes the latent space into a smooth and continuous structure.
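
For reference, the two terms above are the two pieces of the standard ELBO. The following is a sketch in the notation of this section, not a reproduction of the book's derivation; the name \(\mathcal{L}_{\text{ELBO}}\) is introduced here only for illustration.

$$
\log p_\phi(\mathbf{x}) \;\ge\; \mathcal{L}_{\text{ELBO}}(\theta,\phi;\mathbf{x}) := \underbrace{\mathbb{E}_{q_\theta(\mathbf{z}\mid\mathbf{x})}\left[\log p_\phi(\mathbf{x}\mid\mathbf{z})\right]}_{\text{reconstruction term}} \;-\; \underbrace{\mathcal{D}_{\text{KL}}\left(q_\theta(\mathbf{z}\mid\mathbf{x}) \,\Vert\, p_{\text{prior}}(\mathbf{z})\right)}_{\text{latent regularization}}
$$

$$
\log p_\phi(\mathbf{x}) - \mathcal{L}_{\text{ELBO}}(\theta,\phi;\mathbf{x}) = \mathcal{D}_{\text{KL}}\left(q_\theta(\mathbf{z}\mid\mathbf{x}) \,\Vert\, p_\phi(\mathbf{z}\mid\mathbf{x})\right) \ge 0
$$

The second line shows why maximizing the ELBO over \(\theta\) pushes \(q_\theta(\mathbf{z}\mid\mathbf{x})\) toward the intractable posterior \(p_\phi(\mathbf{z}\mid\mathbf{x})\).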



Model) Gaussian VAE

  • Goal)
    • Employ Gaussian distributions for both the encoder and decoder.
  • Def.)
    • Encoder : $q_\theta(\mathbf{z\mid x}) := \mathcal{N}\left(\mathbf{z};\;\mu_\theta(\mathbf{x}), \text{diag}\left(\boldsymbol{\sigma}^2_\theta(\mathbf{x})\right)\right)$
      • where
        • $\mu_\theta(\mathbf{x}) : \mathbb{R}^D\rightarrow\mathbb{R}^d$
        • \(\boldsymbol{\sigma}^2_\theta : \mathbb{R}^D\rightarrow\mathbb{R}^d_+\).
    • Decoder : $p_\phi(\mathbf{x\mid z}) := \mathcal{N}\left(\mathbf{x};\;\mu_\phi(\mathbf{z}), \sigma^2\mathbf{I} \right)$
      • where
        • $\mu_\phi(\mathbf{z}) : \mathbb{R}^d\rightarrow\mathbb{R}^D$
        • Fixed variance with $\sigma\gt0$
  • Optimization) ELBO loss
  • Drawback)
    • Blurry Generation
      • Why?)
        • Averaging effect over conflicting modes
Desc of the averaging effect
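
The averaging effect can be seen in a toy setting. The sketch below assumes a hypothetical case in which one latent code is paired with two conflicting scalar targets. With a fixed-variance Gaussian decoder, maximizing the likelihood amounts to minimizing squared error, so the best single output is the mean of the modes rather than either mode. The data and the grid search are illustrative assumptions, not part of the original notes.

```python
# Toy illustration (hypothetical data): one latent code z is paired with two
# conflicting targets, standing in for two sharp images that share the same z.
targets = [-1.0, 1.0]


def mse(prediction, samples):
    """Mean squared error of a single prediction against all samples."""
    return sum((prediction - s) ** 2 for s in samples) / len(samples)


# Scan candidate decoder outputs on a grid and keep the one with the lowest MSE.
candidates = [k / 100.0 for k in range(-150, 151)]
best = min(candidates, key=lambda c: mse(c, targets))

mean_target = sum(targets) / len(targets)
print("MSE-optimal output:", best)
print("mean of the modes :", mean_target)
print("MSE at optimum vs. at a true mode:", mse(best, targets), mse(1.0, targets))
```

The optimum lies between the two modes, which in image space corresponds to a blend of sharp images, i.e. a blurry output.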



Model) Hierarchical VAE (HVAE)

Vahdat et al., 2020

  • Structure)
    • Decoding
      • Forms chain of conditional priors
      • Top-down Hierarchy of multiple layers of latent variables.
        • Top-down factorization of the joint distributions
          • \(\displaystyle p_\phi(\mathbf{x, z}_{1:L}) = p_\phi(\mathbf{x}\mid\mathbf{z}_{1}) \prod_{i=2}^L p_\phi(\mathbf{z}_{i-1}\mid\mathbf{z}_{i}) p(\mathbf{z}_{L})\).
      • Marginal data distribution
        • \(p_{\text{HVAE}}(\mathbf{x}) := \displaystyle\int p_\phi(\mathbf{x, z}_{1:L}) \text{d}\mathbf{z}_{1:L}\).
    • Encoding)
      • Bottom-up Markov Factorization
        • \(\displaystyle q_\theta(\mathbf{z}_{1:L}\mid\mathbf{x}) = q_\theta(\mathbf{z}_{1}\mid\mathbf{x}) \prod_{i=2}^L q_\theta(\mathbf{z}_i\mid\mathbf{z}_{i-1})\).
  • Optimization)
    • ELBO loss
Derivation

  • Desc.)
    • Distributes the information penalty across levels and localizes learning signals through the KL terms between adjacent levels.
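
As a sketch of where the level-wise terms come from, the HVAE ELBO can be expanded as follows. This assumes \(L\ge2\), the top-down decoder factorization \(p_\phi(\mathbf{x}\mid\mathbf{z}_1)\prod_{i=2}^L p_\phi(\mathbf{z}_{i-1}\mid\mathbf{z}_i)\,p(\mathbf{z}_L)\), and a bottom-up encoder in which each \(\mathbf{z}_i\) is conditioned on \(\mathbf{z}_{i-1}\); it is not the book's full derivation.

$$
\begin{aligned}
\log p_{\text{HVAE}}(\mathbf{x}) &\ge \mathbb{E}_{q_\theta(\mathbf{z}_{1:L}\mid\mathbf{x})}\left[\log\frac{p_\phi(\mathbf{x},\mathbf{z}_{1:L})}{q_\theta(\mathbf{z}_{1:L}\mid\mathbf{x})}\right] \\
&= \mathbb{E}_{q_\theta(\mathbf{z}_{1:L}\mid\mathbf{x})}\left[\log p_\phi(\mathbf{x}\mid\mathbf{z}_1) + \log\frac{p_\phi(\mathbf{z}_1\mid\mathbf{z}_2)}{q_\theta(\mathbf{z}_1\mid\mathbf{x})} + \sum_{i=2}^{L-1}\log\frac{p_\phi(\mathbf{z}_i\mid\mathbf{z}_{i+1})}{q_\theta(\mathbf{z}_i\mid\mathbf{z}_{i-1})} + \log\frac{p(\mathbf{z}_L)}{q_\theta(\mathbf{z}_L\mid\mathbf{z}_{L-1})}\right]
\end{aligned}
$$

Each log-ratio compares the encoder and decoder distributions of a single level given its neighbors, which is the sense in which the penalty is spread across adjacent levels.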


  • Advantages)
    • Captures data features at multiple levels of abstraction.
    • Boosts expressive power
    • Mirrors the compositional nature of real-world data
    • Utilizing the deep, stacked layers to build data
      • Deep networks!
  • Limits)
    • Single Gaussian variational family cannot match the multimodal $p_\phi(\mathbf{z\mid x})$.
    • If the decoder is too expressive, the model may suffer from posterior collapse.
      • i.e.) When the decoder can model the data without using $\mathbf{z}$, we lose control of the generation.
Desc. using the mutual information derived from the ELBO loss




2.2 Denoising Diffusion Probabilistic Models (DDPM)

Sohl-Dickstein et al., 2015; Ho et al., 2020


Concept) Forward Process (Fixed Encoder)

  • Goal)
    • Gradually corrupt data by injecting Gaussian noise over multiple steps via a transition kernel of
      \(\begin{aligned} p(\mathbf{x}_{i}\mid \mathbf{x}_{i-1}) &:= \mathcal{N}\left( \mathbf{x}_i;\; \sqrt{1-\beta_i^2} \mathbf{x}_{i-1},\; \beta_i^2\mathbf{I} \right) \\ &\triangleq \mathcal{N}\left( \mathbf{x}_i;\; \alpha_i \mathbf{x}_{i-1},\; (1-\alpha_i^2)\mathbf{I} \right) & \left( \alpha_i \stackrel{\text{put}}{=} \sqrt{1-\beta_i^2} \right) \\ \end{aligned}\).
      • where
        • \(\mathbf{x}_0\sim p_{\text{data}}\) : a sample drawn from the real data distribution
        • \(\left\{ \beta_i \right\}_{i=1}^L\) : a pre-determined monotonically increasing noise schedule with \(\beta_i\in(0,1), \forall i\)
      • which is equivalent to updating $\mathbf{x}_i$ as
        • \(\mathbf{x}_i = \alpha_i \mathbf{x}_{i-1} + \beta_i \epsilon_i\).
          • for \(\epsilon_i\sim\mathcal{N}(\mathbf{0, I})\)
  • Desc.)
    • Closed-Form Expression for the distribution of noisy samples at step $i$
      • \(p(\mathbf{x}_{i}\mid \mathbf{x}_{0}) = \mathcal{N}\left( \mathbf{x}_i;\; \bar{\alpha}_i\mathbf{x}_0, (1-\bar{\alpha}_i^2)\mathbf{I} \right)\).
        • where \(\bar{\alpha}_i := \displaystyle\prod_{k=1}^i\sqrt{1-\beta_k^2} = \prod_{k=1}^i \alpha_k\)
    • Sampled $\mathbf{x}_i$
      • \(\bar{\alpha}_i\mathbf{x}_0 + \sqrt{1-\bar{\alpha}_i^2}\epsilon\).
        • for \(\epsilon\sim\mathcal{N}(\mathbf{0, I})\)
    • Convergence of the forward process
      • If \(\left\{ \beta_i \right\}_{i=1}^L\) is an increasing sequence
        • \(p_L(\mathbf{x}_L\mid\mathbf{x}_0)\rightarrow\mathcal{N}(\mathbf{0,I})\) as $L\rightarrow\infty$
      • Justifying $p_{\text{prior}} := \mathcal{N}(\mathbf{0,I})$
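
The forward process above can be sketched in a few lines. This is a minimal scalar illustration, assuming a hypothetical linear schedule (`make_schedule` and its default values are not from the notes). It checks that the variance accumulated by iterating the one-step kernel equals the closed-form \(1-\bar{\alpha}_i^2\), and that \(\bar{\alpha}_L\) becomes small so that \(\mathbf{x}_L\) is close to standard Gaussian noise.

```python
import math
import random


def make_schedule(num_steps, beta_min=0.02, beta_max=0.3):
    """Hypothetical linearly increasing beta_i values in (0, 1); needs num_steps >= 2."""
    step = (beta_max - beta_min) / (num_steps - 1)
    return [beta_min + step * k for k in range(num_steps)]


def alphas_from_betas(betas):
    """alpha_i = sqrt(1 - beta_i^2) and alpha_bar_i = alpha_1 * ... * alpha_i."""
    alphas = [math.sqrt(1.0 - b * b) for b in betas]
    alpha_bars, running = [], 1.0
    for a in alphas:
        running *= a
        alpha_bars.append(running)
    return alphas, alpha_bars


def forward_step(x_prev, alpha, beta, rng):
    """One-step kernel: x_i = alpha_i * x_{i-1} + beta_i * eps_i."""
    return alpha * x_prev + beta * rng.gauss(0.0, 1.0)


def forward_jump(x0, alpha_bar, rng):
    """Closed-form jump: x_i = alpha_bar_i * x_0 + sqrt(1 - alpha_bar_i^2) * eps."""
    return alpha_bar * x0 + math.sqrt(1.0 - alpha_bar ** 2) * rng.gauss(0.0, 1.0)


betas = make_schedule(200)
alphas, alpha_bars = alphas_from_betas(betas)

# Variance of x_i given a fixed x_0, accumulated step by step:
# Var_i = alpha_i^2 * Var_{i-1} + beta_i^2, starting from Var_0 = 0.
variance = 0.0
stepwise_variances = []
for a, b in zip(alphas, betas):
    variance = a * a * variance + b * b
    stepwise_variances.append(variance)

rng = random.Random(0)
x0 = 3.0
x_iter = x0
for a, b in zip(alphas, betas):
    x_iter = forward_step(x_iter, a, b, rng)
x_jump = forward_jump(x0, alpha_bars[-1], rng)

print("alpha_bar_L:", alpha_bars[-1])
print("samples of x_L (iterated, closed form):", x_iter, x_jump)
```

The two samples differ because they use different noise draws; only their distributions agree.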


Concept) Reverse Process (Learnable Decoder)

  • Goal)
    • Reverse the noise corruption through a parameterized distribution \(p_\phi(\mathbf{x}_{i-1}\mid \mathbf{x}_{i})\)
  • Method)
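    • To get a tractable target, we condition the intermediate noisy data on the clean data sample $\mathbf{x}$.
    • The resulting reverse conditional transition kernel \(p(\mathbf{x}_{i-1}\mid\mathbf{x}_i,\mathbf{x})\) is then Gaussian and available in closed form, with mean \(\boldsymbol{\mu}(\mathbf{x}_i,\mathbf{x},i)\) and variance \(\sigma^2(i)\mathbf{I}\).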
  • Parameterization)
    • \(p_\phi(\mathbf{x}_{i-1}\mid\mathbf{x}_i) := \mathcal{N}\left( \mathbf{x}_{i-1};\; \boldsymbol{\mu}_{\phi}(\mathbf{x}_i, i), \sigma^2(i)\mathbf{I} \right)\).
      • where
        • \(\boldsymbol{\mu}_{\phi}(\cdot, i) : \mathbb{R}^D\rightarrow\mathbb{R}^D\) : a learnable mean function
        • $\sigma^2(i)$ : a fixed variance function given by the reverse conditional transition kernel above.
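
For reference, the closed form of the reverse conditional transition kernel can be written in this note's notation as below. This is the standard DDPM Gaussian posterior rewritten with \(\alpha_i=\sqrt{1-\beta_i^2}\); the convention \(\bar{\alpha}_0 := 1\) is an assumption added here so that the formula also covers \(i=1\).

$$
\begin{aligned}
p(\mathbf{x}_{i-1}\mid\mathbf{x}_i,\mathbf{x}) &= \mathcal{N}\left(\mathbf{x}_{i-1};\; \boldsymbol{\mu}(\mathbf{x}_i,\mathbf{x},i),\; \sigma^2(i)\mathbf{I}\right) \\
\boldsymbol{\mu}(\mathbf{x}_i,\mathbf{x},i) &= \frac{\bar{\alpha}_{i-1}\,\beta_i^2}{1-\bar{\alpha}_i^2}\,\mathbf{x} + \frac{\alpha_i\left(1-\bar{\alpha}_{i-1}^2\right)}{1-\bar{\alpha}_i^2}\,\mathbf{x}_i \\
\sigma^2(i) &= \frac{1-\bar{\alpha}_{i-1}^2}{1-\bar{\alpha}_i^2}\,\beta_i^2
\end{aligned}
$$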


Tech) Optimization of DDPM

  • Goal)
    • Minimize the expected KL divergence of
      • \(\mathbb{E}_{p_i(\mathbf{x}_i)}\left[ \mathcal{D}_{\text{KL}}\left( p(\mathbf{x}_{i-1}\mid \mathbf{x}_{i}) \Vert p_\phi(\mathbf{x}_{i-1}\mid \mathbf{x}_{i}) \right) \right]\).
  • Problem)
    • \(p(\mathbf{x}_{i-1}\mid \mathbf{x}_{i})\) is intractable.
      • Why?)
        • \(p_i(\mathbf{x}_i), p_{i-1}(\mathbf{x}_{i-1})\) are intractable where \(p(\mathbf{x}_{i-1}\mid \mathbf{x}_{i}) = p(\mathbf{x}_{i}\mid \mathbf{x}_{i-1}) \displaystyle\frac{p_{i-1}(\mathbf{x}_{i-1})}{p_i(\mathbf{x}_i)}\)
        • We do not know $p_{\text{data}}$, and hence cannot evaluate \(p_i(\mathbf{x}_i)=\displaystyle\int p_i(\mathbf{x}_i\mid\mathbf{x}_0) p_{\text{data}}(\mathbf{x}_0)\text{d}\mathbf{x}_0\)
  • Sol.)
    • Condition the reverse transition on a clean data sample $\mathbf{x}$
      • i.e.) \(p(\mathbf{x}_{i-1}\mid \mathbf{x}_{i}, \mathbf{x}) = p(\mathbf{x}_{i}\mid \mathbf{x}_{i-1}) \displaystyle\frac{p_{i-1}(\mathbf{x}_{i-1} \mid \mathbf{x})}{p_i(\mathbf{x}_i \mid \mathbf{x})}\).
      • How does this work?
        1. Markov forward process assumption enables \(p(\mathbf{x}_{i}\mid \mathbf{x}_{i-1}, \mathbf{x}) = p(\mathbf{x}_{i}\mid \mathbf{x}_{i-1})\)
        2. Gaussian assumption on all involved distributions.
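
The two ingredients above (Markov property and Gaussianity) make the conditioned reverse kernel a product of two Gaussians. The scalar sketch below compares that Bayes-rule product with a closed form. The closed form used is an assumption here: it is the standard DDPM posterior rewritten in this note's \(\alpha_i, \beta_i\) notation, and the numeric values are arbitrary.

```python
import math


def posterior_closed_form(x_i, x0, alpha_i, beta_i, alpha_bar_prev, alpha_bar_i):
    """Assumed closed form of p(x_{i-1} | x_i, x0) in this note's notation."""
    denom = 1.0 - alpha_bar_i ** 2
    mean = (alpha_bar_prev * beta_i ** 2 / denom) * x0 \
        + (alpha_i * (1.0 - alpha_bar_prev ** 2) / denom) * x_i
    var = (1.0 - alpha_bar_prev ** 2) / denom * beta_i ** 2
    return mean, var


def posterior_by_bayes(x_i, x0, alpha_i, beta_i, alpha_bar_prev):
    """Combine the prior p(x_{i-1} | x0) with the likelihood p(x_i | x_{i-1})."""
    prior_mean = alpha_bar_prev * x0
    prior_var = 1.0 - alpha_bar_prev ** 2      # requires i >= 2 so that this is > 0
    precision = 1.0 / prior_var + alpha_i ** 2 / beta_i ** 2
    var = 1.0 / precision
    mean = var * (prior_mean / prior_var + alpha_i * x_i / beta_i ** 2)
    return mean, var


# Arbitrary two-step example (i = 2): beta_1 = 0.1, beta_2 = 0.2.
beta_prev, beta_i = 0.1, 0.2
alpha_prev = math.sqrt(1.0 - beta_prev ** 2)
alpha_i = math.sqrt(1.0 - beta_i ** 2)
alpha_bar_prev = alpha_prev
alpha_bar_i = alpha_prev * alpha_i

closed = posterior_closed_form(0.7, 1.5, alpha_i, beta_i, alpha_bar_prev, alpha_bar_i)
bayes = posterior_by_bayes(0.7, 1.5, alpha_i, beta_i, alpha_bar_prev)
print("closed form (mean, var):", closed)
print("Bayes rule  (mean, var):", bayes)
```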


Tech.) DDPM’s Loss Function

  • Def.)
    • \(\mathcal{L}_{\text{DDPM}}(\phi) := \displaystyle\sum_{i=1}^L\frac{1}{2\sigma^2(i)}\mathbb{E}_{\mathbf{x}_0}\mathbb{E}_{p(\mathbf{x}_i\mid\mathbf{x}_0)}\left[ \left\Vert \boldsymbol{\mu}_\phi(\mathbf{x}_i, i) - \boldsymbol{\mu}(\mathbf{x}_i, \mathbf{x}_0, i) \right\Vert_2^2 \right]\).
Derivation of the Loss


  • Parameterizations of \(\mu_\phi(\mathbf{x}_i, i)\)
    • $\epsilon$-prediction
    • $\mathbf{x}$-prediction
      • cf.) Both are mathematically equivalent.
        • While the true forward process uses ground-truth $\mathbf{x}_0$ and real noise $\epsilon$, our predictions must obey this exact same structural identity:
          \(\begin{aligned} \mathbf{x}_i = \bar{\alpha}_i\mathbf{x}_{\phi}(\mathbf{x}_i, i) + \sqrt{1-\bar{\alpha}_i^2}\epsilon_\phi(\mathbf{x}_i, i) &\Leftrightarrow \mathbf{x}_{\phi}(\mathbf{x}_i, i) = \frac{1}{\bar{\alpha}_i} \left( \mathbf{x}_i - \sqrt{1-\bar{\alpha}_i^2}\epsilon_\phi(\mathbf{x}_i, i) \right) \\ &\Leftrightarrow \epsilon_\phi(\mathbf{x}_i, i) = \frac{1}{\sqrt{1-\bar{\alpha}_i^2}} \left( \mathbf{x}_i - \bar{\alpha}_i\mathbf{x}_{\phi}(\mathbf{x}_i, i) \right) \\ \end{aligned}\)
  • ELBO
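
The equivalence between the two parameterizations is just the identity above solved for either unknown. A minimal scalar sketch follows; the helper names `x_from_eps` and `eps_from_x` are hypothetical.

```python
import math


def x_from_eps(x_i, eps_hat, alpha_bar):
    """Clean-sample estimate implied by a noise estimate."""
    return (x_i - math.sqrt(1.0 - alpha_bar ** 2) * eps_hat) / alpha_bar


def eps_from_x(x_i, x_hat, alpha_bar):
    """Noise estimate implied by a clean-sample estimate."""
    return (x_i - alpha_bar * x_hat) / math.sqrt(1.0 - alpha_bar ** 2)


# Build a noisy sample from known x_0 and eps, then convert back and forth.
alpha_bar = 0.6
x0_true, eps_true = 1.25, -0.4
x_i = alpha_bar * x0_true + math.sqrt(1.0 - alpha_bar ** 2) * eps_true

x_recovered = x_from_eps(x_i, eps_true, alpha_bar)
eps_recovered = eps_from_x(x_i, x0_true, alpha_bar)
print(x_recovered, eps_recovered)
```

Note that the conversion divides by \(\bar{\alpha}_i\) or \(\sqrt{1-\bar{\alpha}_i^2}\), so it becomes numerically delicate when either quantity is close to zero.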


Tech.) ε-prediction

  • Goal)
    • Train a neural network to estimate the noise added.
  • Def.)
    • \(\mathcal{L}_{\text{simple}}(\phi) := \mathbb{E}_{i}\mathbb{E}_{\mathbf{x}_0\sim p_{\text{data}}}\mathbb{E}_{\epsilon\sim\mathcal{N}(\mathbf{0,I})} \left[ \Vert \epsilon_\phi(\mathbf{x}_i, i) - \epsilon \Vert_2^2 \right]\).
      • where \(\mathbf{x}_i = \bar{\alpha}_i\mathbf{x}_0 + \sqrt{1-\bar{\alpha}_i^2} \epsilon\) with \(\mathbf{x}_0\sim p_{\text{data}}\).
Derivation of the Loss


  • Prop.)
    • \(\mathcal{L}_{\text{DDPM}}\) and \(\mathcal{L}_{\text{simple}}\) share the same optimal solution of
      • \(\epsilon^*(\mathbf{x}_i, i) = \mathbb{E}[\epsilon\mid\mathbf{x}_i]\) for \(\mathbf{x}_i\sim p_i\)
        • i.e.) Noise estimated by the ε-prediction coincides with the conditional expectation of the true noise, even though $\mathbf{x}_i$ does not uniquely determine the original clean sample.
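
A Monte Carlo estimate of \(\mathcal{L}_{\text{simple}}\) can be sketched as follows for scalar data. The assumptions are loud ones: the step index is drawn uniformly, the schedule is a hypothetical linear one, and `eps_model` stands in for a neural network. The toy dataset has a single point, so the conditional expectation \(\mathbb{E}[\epsilon\mid\mathbf{x}_i]\) is known exactly and can serve as an oracle.

```python
import math
import random


def simple_loss(eps_model, data, alpha_bars, num_draws, rng):
    """Monte Carlo estimate of L_simple with i drawn uniformly from 1..L (assumed)."""
    total = 0.0
    for _ in range(num_draws):
        i = rng.randrange(1, len(alpha_bars) + 1)   # step index in 1..L
        x0 = rng.choice(data)                        # x_0 ~ p_data (empirical)
        eps = rng.gauss(0.0, 1.0)                    # eps ~ N(0, 1)
        a_bar = alpha_bars[i - 1]
        x_i = a_bar * x0 + math.sqrt(1.0 - a_bar ** 2) * eps
        total += (eps_model(x_i, i) - eps) ** 2
    return total / num_draws


# Hypothetical schedule: beta_i rises linearly from 0.05 to 0.30 over 100 steps.
betas = [0.05 + 0.25 * k / 99 for k in range(100)]
alpha_bars, running = [], 1.0
for b in betas:
    running *= math.sqrt(1.0 - b * b)
    alpha_bars.append(running)

data = [2.0]   # single-point dataset, so x_i determines eps exactly


def oracle_eps(x_i, i):
    """E[eps | x_i] for the single-point dataset above."""
    a_bar = alpha_bars[i - 1]
    return (x_i - a_bar * data[0]) / math.sqrt(1.0 - a_bar ** 2)


def zero_eps(x_i, i):
    """Baseline that always predicts zero noise."""
    return 0.0


loss_oracle = simple_loss(oracle_eps, data, alpha_bars, 2000, random.Random(0))
loss_zero = simple_loss(zero_eps, data, alpha_bars, 2000, random.Random(0))
print("oracle loss:", loss_oracle)
print("zero-predictor loss:", loss_zero)
```

With more than one data point the oracle loss would no longer be zero, because \(\mathbf{x}_i\) would not determine the clean sample uniquely.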


Tech.) x-prediction

  • Goal)
    • Train a neural network to estimate the clean image from a given noisy input \(\mathbf{x}_i\sim p_i(\mathbf{x}_i)\).
  • Def.)
    • \(\mathcal{L}_{\text{x-pred}}(\phi) := \mathbb{E}_{i}\mathbb{E}_{\mathbf{x}_0 \sim p_{\text{data}}}\mathbb{E}_{\epsilon\sim\mathcal{N}(\mathbf{0,I})} \left[ \omega_i \Vert \mathbf{x}_\phi(\mathbf{x}_i, i) - \mathbf{x}_0 \Vert_2^2 \right]\).
      • where
        • \(\mathbf{x}_0\sim p_{\text{data}}\).
        • \(\omega_i\) : some weighting function
Derivation of the Loss
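
One way to see where a weighting \(\omega_i\) can come from is to substitute the parameterization identity \(\epsilon_\phi(\mathbf{x}_i, i) = \left(\mathbf{x}_i - \bar{\alpha}_i\mathbf{x}_\phi(\mathbf{x}_i, i)\right)/\sqrt{1-\bar{\alpha}_i^2}\) together with \(\epsilon = \left(\mathbf{x}_i - \bar{\alpha}_i\mathbf{x}_0\right)/\sqrt{1-\bar{\alpha}_i^2}\) into the ε-prediction error. This is a sketch of that substitution only, not the full derivation.

$$
\epsilon_\phi(\mathbf{x}_i, i) - \epsilon = -\frac{\bar{\alpha}_i}{\sqrt{1-\bar{\alpha}_i^2}}\left(\mathbf{x}_\phi(\mathbf{x}_i, i) - \mathbf{x}_0\right)
\quad\Longrightarrow\quad
\left\Vert \epsilon_\phi(\mathbf{x}_i, i) - \epsilon \right\Vert_2^2 = \frac{\bar{\alpha}_i^2}{1-\bar{\alpha}_i^2}\left\Vert \mathbf{x}_\phi(\mathbf{x}_i, i) - \mathbf{x}_0 \right\Vert_2^2
$$

Under this substitution, the choice \(\omega_i = \bar{\alpha}_i^2/(1-\bar{\alpha}_i^2)\) makes the x-prediction loss coincide with \(\mathcal{L}_{\text{simple}}\); other choices of \(\omega_i\) reweight the noise levels.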



Tech.) DDPM’s Sampling

  • Settings)
    • \(\epsilon_{\phi^\times}(\mathbf{x}_i, i)\) : a trained (frozen) noise predictor
  • Procedure)
    • For $i=L,\ldots,1$,
      • \(\mathbf{x}_{i-1} \leftarrow \displaystyle \underbrace{\frac{1}{\alpha_i} \left(\mathbf{x}_{i} - \frac{1-\alpha_i^2}{\sqrt{1-\bar{\alpha}_i^2}}\epsilon_{\phi^\times}(\mathbf{x}_{i}, i)\right)}_{\mu_{\phi^\times}(\mathbf{x}_{i}, i)} + \sigma(i)\epsilon_i\quad\) for \(\epsilon_i\sim\mathcal{N}(\mathbf{0,I})\)
    • cf.) Equivalent to sampling \(\mathbf{x}_{i-1}\sim p_{\phi^\times}(\mathbf{x}_{i-1}\mid\mathbf{x}_i)\).
  • Interpretation)
    • Plugging in the parameterization identity \(\epsilon_{\phi^\times}(\mathbf{x}_i, i) = \displaystyle\frac{\mathbf{x}_i - \bar{\alpha}_i\mathbf{x}_{\phi^\times}(\mathbf{x}_i, i)}{\sqrt{1-\bar{\alpha}_i^2}}\), the sampling at step $i$ can be rewritten as
      • \(\mathbf{x}_{i-1} \leftarrow \displaystyle \underbrace{c_1\mathbf{x}_{i} +c_2 \mathbf{x}_{\phi^\times}(\mathbf{x}_i, i)}_{\text{interpolation}} + \sigma(i)\epsilon_i\).
        • where \(c_1 = \displaystyle\frac{\alpha_i\left(1-\bar{\alpha}_{i-1}^2\right)}{1-\bar{\alpha}_i^2}\) and \(c_2 = \displaystyle\frac{\bar{\alpha}_{i-1}\left(1-\alpha_i^2\right)}{1-\bar{\alpha}_i^2}\), with \(\bar{\alpha}_0 := 1\)
        • i.e.) Each step (\(\mathbf{x}_{i-1}\)) is centered on an interpolation between the current noisy sample (\(\mathbf{x}_i\)) and the predicted clean sample (\(\mathbf{x}_{\phi^\times}(\mathbf{x}_i, i)\)), with added Gaussian noise scaled by \(\sigma(i)\).
        • Decomposed into two steps.
          1. Estimate the clean data \(\mathbf{x}_{\phi^\times}(\mathbf{x}_i, i)\) from the noisy input from the previous step \(\mathbf{x}_i\).
          2. Sample a less noisy latent \(\mathbf{x}_{i-1}\) using \(\mathbf{x}_{\phi^\times}(\mathbf{x}_i, i)\).
  • Props.)
    • Early sampling steps set the global structure, and later steps add fine detail.
      • Even if \(\mathbf{x}_{\phi^\times}(\mathbf{x}_i, i)\) is trained as the optimal denoiser, it can only predict the average clean sample given \(\mathbf{x}_i\).
        • Thus, this leads to blurry predictions, particularly at high noise levels (\(i\approx L\gg1\))
      • However, as the sampling proceeds to the low noise level (\(i\rightarrow1\)), it progressively refines an estimate of the clean signal.
    • Slow Sampling
      • Why?)
        • \(p_{\phi}(\mathbf{x}_{i-1}\mid\mathbf{x}_{i})\) is typically modeled as a Gaussian to approximate \(p(\mathbf{x}_{i-1}\mid\mathbf{x}_{i})\), limiting its expressiveness.
        • For small forward noise scales $\beta_i$, the true reverse distribution is approximately Gaussian, enabling accurate approximation.
        • Conversely, large $\beta_i$ induce multimodality or strong non-Gaussianity that a single Gaussian cannot capture.
        • Thus, to maintain the accuracy, DDPM employs many small $\beta_i$ steps.
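
The sampling procedure above can be sketched as a plain loop over \(i=L,\ldots,1\). Assumptions, stated plainly: data is scalar, the schedule is a hypothetical linear one, \(\sigma^2(i) = \frac{1-\bar{\alpha}_{i-1}^2}{1-\bar{\alpha}_i^2}\beta_i^2\) with \(\bar{\alpha}_0 := 1\) is taken as the fixed variance, and `oracle_eps` replaces the trained network with the exact noise predictor for a single-point dataset.

```python
import math
import random


def build_schedule(num_steps, beta_min=0.05, beta_max=0.3):
    """Hypothetical linear schedule; returns betas, alphas, alpha_bars for steps 1..L."""
    betas = [beta_min + (beta_max - beta_min) * k / (num_steps - 1)
             for k in range(num_steps)]
    alphas = [math.sqrt(1.0 - b * b) for b in betas]
    alpha_bars, running = [], 1.0
    for a in alphas:
        running *= a
        alpha_bars.append(running)
    return betas, alphas, alpha_bars


def ddpm_sample(eps_model, betas, alphas, alpha_bars, rng):
    """Ancestral sampling from x_L ~ N(0, 1) down to x_0 (scalar case)."""
    x = rng.gauss(0.0, 1.0)
    for i in range(len(betas), 0, -1):
        a, b, a_bar = alphas[i - 1], betas[i - 1], alpha_bars[i - 1]
        a_bar_prev = alpha_bars[i - 2] if i > 1 else 1.0   # alpha_bar_0 := 1
        # Mean of the learned reverse kernel in epsilon-prediction form.
        mean = (x - (1.0 - a * a) / math.sqrt(1.0 - a_bar ** 2) * eps_model(x, i)) / a
        # Assumed fixed standard deviation; it is exactly zero at i = 1.
        sigma = math.sqrt((1.0 - a_bar_prev ** 2) / (1.0 - a_bar ** 2)) * b
        x = mean + sigma * rng.gauss(0.0, 1.0)
    return x


betas, alphas, alpha_bars = build_schedule(100)
target = 2.0   # single-point "dataset"


def oracle_eps(x_i, i):
    """Exact noise predictor when all data equals `target`."""
    a_bar = alpha_bars[i - 1]
    return (x_i - a_bar * target) / math.sqrt(1.0 - a_bar ** 2)


samples = [ddpm_sample(oracle_eps, betas, alphas, alpha_bars, random.Random(seed))
           for seed in range(5)]
print(samples)
```

Every sample requires \(L\) sequential network evaluations, which is the cost behind the slow sampling noted above.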


