(DM Reconst.) Ch.2 Variational Perspective - From VAEs to DDPM
Diffusion Model Conceptual Reconstruction following The Principles of Diffusion Models
Hozy Summary
- VAE (Naive Autoencoder $\rightarrow$ VAE $\rightarrow$ Gaussian VAE) cannot resolve the blurriness issue due to the averaging effect.
- HVAE gains expressive power from its progressive structure (a deep, layered hierarchical network).
- Still, HVAE causes the posterior collapse problem if the layers get too deep.
- i.e. Loss of control in generation.
- DDPM
- resembles…
- but
- has the fixed encoder (not the target of learning)
- has no per-level KL terms
- instead well conditioned denoising subproblems from large to small noise
- yields stable optimization
- has high sample quality while preserving a coarse-to-fine hierarchy over time (noise)
- yet
- samples slowly, because accuracy requires many small denoising steps
2.1 Variational Autoencoder (VAE)
Model) Naive Autoencoder
- Structure)
- Original Data $\rightarrow$ Encoder $\rightarrow$ Latent $\rightarrow$ Decoder $\rightarrow$ Reconstructed Data
- Training)
- Minimize the reconstruction error between the original input and its reconstruction.
- Drawback)
- Unstructured latent space
- i.e.) Latent codes produce meaningless outputs.
- Sol.) VAE
Model) Variational Autoencoder
Kingma et al. 2013
- Structure)
- Decoder (Generator)
- Settings)
- $\mathbf{x}$ : observed variable
- $\mathbf{z}\sim p_{\text{prior}}$ : latent variable
- where $p_{\text{prior}} := \mathcal{N}(\mathbf{0,I})$
- Goal)
- Map $\mathbf{z}$ back to data space.
- Sample $\mathbf{z}\sim p_{\text{prior}}$ and decode into $\mathbf{x}\sim p_\phi(\mathbf{x\mid z})$.
- Def.)
- $p_\phi(\mathbf{x\mid z})$ : a decoder distribution
- Marginal likelihood)
- $p_\phi(\mathbf{x}) = \displaystyle\int p_\phi(\mathbf{x\mid z}) p(\mathbf{z}) \text{d}\mathbf{z}$
- Problem)
- Integral over $\mathbf{z}$ is intractable!
- Sol.) Variational step
- Encoder (Inference Network)
- Goal)
- Given an observation $\mathbf{x}$, map to the latent code $\mathbf{z}$.
- i.e.) $p_\phi(\mathbf{z\mid x})$
- Problem)
- $p_\phi(\mathbf{z\mid x})$ is intractable.
- Why?)
- $p_\phi(\mathbf{z\mid x}) = \displaystyle\frac{p_\phi(\mathbf{x\mid z}) p(\mathbf{z})}{p_\phi(\mathbf{x})}\quad(\because\text{Bayes’ Rule})$
- $p_\phi(\mathbf{x})$ was intractable.
- Sol.) Variational step.
- Tech.)
- Variational Step
- Goal.)
- Find an encoder $q_\theta(\mathbf{z\mid x})$ s.t. $q_\theta(\mathbf{z\mid x}) \approx p_\phi(\mathbf{z\mid x})$
- How?)
- Parameterize with a neural network.
- Optimization)
- Maximize the ELBO (equivalently, minimize the negative ELBO as the loss).
Concept) Evidence Lower Bound (ELBO)
- Props.)
- Reconstruction Term
- Encourages accurate recovery of $\mathbf{x}$ from the latent $\mathbf{z}$.
- With a Gaussian VAE, this term reduces to the reconstruction loss of an autoencoder.
- (RISK) Optimizing this term alone risks memorizing the training data!
- Latent Regularization (KL)
- Encourages $q_\theta(\mathbf{z\mid x})$ (the encoder distribution) to stay close to $p_{\text{prior}}$ (the simple prior).
- Shapes the latent space into a smooth and continuous structure.
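For reference, the two terms above are the two pieces of the standard lower bound on the log-likelihood. The following is a sketch in this note's notation; the symbol \(\mathcal{L}_{\text{ELBO}}\) is introduced here only for illustration.

\(\begin{aligned} \log p_\phi(\mathbf{x}) &\ge \underbrace{\mathbb{E}_{q_\theta(\mathbf{z\mid x})}\left[ \log p_\phi(\mathbf{x\mid z}) \right]}_{\text{reconstruction}} - \underbrace{\mathcal{D}_{\text{KL}}\left( q_\theta(\mathbf{z\mid x}) \Vert p_{\text{prior}}(\mathbf{z}) \right)}_{\text{latent regularization}} =: \mathcal{L}_{\text{ELBO}}(\theta, \phi; \mathbf{x}) \\ \log p_\phi(\mathbf{x}) - \mathcal{L}_{\text{ELBO}}(\theta, \phi; \mathbf{x}) &= \mathcal{D}_{\text{KL}}\left( q_\theta(\mathbf{z\mid x}) \Vert p_\phi(\mathbf{z\mid x}) \right) \ge 0 \end{aligned}\)

The second line shows why maximizing the bound over $\theta$ pushes $q_\theta(\mathbf{z\mid x})$ toward the intractable posterior $p_\phi(\mathbf{z\mid x})$.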
Model) Gaussian VAE
- Goal)
- Employ Gaussian distributions for both the encoder and decoder.
- Def.)
- Encoder : $q_\theta(\mathbf{z\mid x}) := \mathcal{N}\left(\mathbf{z};\;\mu_\theta(\mathbf{x}), \text{diag}\left(\boldsymbol{\sigma}^2_\theta(\mathbf{x})\right)\right)$
- where
- $\mu_\theta(\mathbf{x}) : \mathbb{R}^D\rightarrow\mathbb{R}^d$
- \(\boldsymbol{\sigma}^2_\theta : \mathbb{R}^D\rightarrow\mathbb{R}^d_+\).
- Decoder : $p_\phi(\mathbf{x\mid z}) := \mathcal{N}\left(\mathbf{x};\;\mu_\phi(\mathbf{z}), \sigma^2\mathbf{I} \right)$
- where
- $\mu_\phi : \mathbb{R}^d\rightarrow\mathbb{R}^D$
- Fixed variance with $\sigma\gt0$
- Optimization) ELBO loss
- Drawback)
- Blurry Generation
- Why?)
- Averaging effect over conflicting modes
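As a rough illustration of the ELBO loss for the Gaussian VAE defined above, the sketch below computes a single-sample negative ELBO for one data point. It assumes a hypothetical `decode` callable standing in for $\mu_\phi$, takes the encoder outputs as plain lists, and drops additive constants of the Gaussian log-density; none of these names come from the source.

```python
import math
import random

def gaussian_vae_loss(x, enc_mu, enc_var, decode, sigma=1.0, rng=random):
    """Single-sample estimate of the negative ELBO (up to a constant).

    x: list of D floats; enc_mu / enc_var: encoder mean and variance (d floats each);
    decode: hypothetical callable mapping a latent list to a list of D floats.
    """
    # Reparameterization: z = mu + sqrt(var) * eps with eps ~ N(0, I)
    z = [m + math.sqrt(v) * rng.gauss(0.0, 1.0) for m, v in zip(enc_mu, enc_var)]
    x_hat = decode(z)
    # Reconstruction term: squared error scaled by 1 / (2 sigma^2)
    recon = sum((a - b) ** 2 for a, b in zip(x, x_hat)) / (2.0 * sigma ** 2)
    # Closed-form KL( N(mu, diag(var)) || N(0, I) )
    kl = 0.5 * sum(m * m + v - 1.0 - math.log(v) for m, v in zip(enc_mu, enc_var))
    return recon + kl, recon, kl

# Toy usage: an encoder output equal to the prior gives zero KL.
total, recon, kl = gaussian_vae_loss([1.0, 2.0], [0.0], [1.0], lambda z: [1.0, 0.0])
```

With a fixed decoder variance, the reconstruction term is just a scaled squared error, which is where the link to the plain autoencoder loss comes from.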
Desc of the averaging effect
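The averaging effect can be seen in a toy setting that is much simpler than a VAE. The sketch below assumes a single latent code that must explain two equally likely, conflicting targets, and searches for the prediction with the lowest mean squared error. The numbers are illustrative only.

```python
def mse(pred, targets):
    """Mean squared error of one constant prediction against all targets."""
    return sum((pred - t) ** 2 for t in targets) / len(targets)

# Two equally likely, conflicting modes that share the same latent code
targets = [-1.0, 1.0] * 50

# Grid search over candidate predictions in [-1.5, 1.5]
candidates = [k / 100.0 for k in range(-150, 151)]
best = min(candidates, key=lambda c: mse(c, targets))
```

The minimizer is the mean of the two modes, which matches neither of them. In image space, such an average of plausible outputs is what appears as blur.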
Model) Hierarchical VAE (HVAE)
Vahdat et al., 2020
- Structure)
- Decoding
- Forms chain of conditional priors
- Top-down Hierarchy of multiple layers of latent variables.
- Top-down factorization of the joint distributions
- \(\displaystyle p_\phi(\mathbf{x, z}_{1:L}) = p_\phi(\mathbf{x}\mid\mathbf{z}_{1}) \prod_{i=2}^L p_\phi(\mathbf{z}_{i-1}\mid\mathbf{z}_{i}) p(\mathbf{z}_{L})\).
- Marginal data distribution
- \(p_{\text{HVAE}}(\mathbf{x}) := \displaystyle\int p_\phi(\mathbf{x, z}_{1:L}) \text{d}\mathbf{z}_{1:L}\).
- Encoding)
- Bottom-up Markov Factorization
- \(q_\theta(\mathbf{z}_{1:L}\mid\mathbf{x}) = q_\theta(\mathbf{z}_{1}\mid\mathbf{x}) \prod_{i=2}^L q_\theta(\mathbf{z}_i\mid\mathbf{z}_{i-1})\).
- Optimization)
- ELBO loss
Derivation
- Desc.)
- Distributes the information penalty across levels and localizes learning signals through the KL terms between adjacent levels.
- Advantages)
- Captures data features at multiple levels of abstraction.
- Boosts expressive power
- Mirrors the compositional nature of real-world data
- Utilizing the deep, stacked layers to build data
- Deep networks!
- Limits)
- Single Gaussian variational family cannot match the multimodal $p_\phi(\mathbf{z\mid x})$.
- If the decoder is too expressive, the model may suffer from posterior collapse.
- i.e.) When the decoder can model the data without using $\mathbf{z}$, we lose control of the generation.
Desc. using the mutual information derived from the ELBO loss
2.2 Denoising Diffusion Probabilistic Models (DDPM)
Sohl-Dickstein et al., 2015; Ho et al., 2020
Concept) Forward Process (Fixed Encoder)
- Goal)
- Gradually corrupt data by injecting Gaussian noise over multiple steps via the transition kernel
\(\begin{aligned} p(\mathbf{x}_{i}\mid \mathbf{x}_{i-1}) &:= \mathcal{N}\left( \mathbf{x}_i;\; \sqrt{1-\beta_i^2} \mathbf{x}_{i-1},\; \beta_i^2\mathbf{I} \right) \\ &\triangleq \mathcal{N}\left( \mathbf{x}_i;\; \alpha_i \mathbf{x}_{i-1},\; (1-\alpha_i^2)\mathbf{I} \right) & \left( \alpha_i \stackrel{\text{put}}{=} \sqrt{1-\beta_i^2} \right) \\ \end{aligned}\).
- where
- \(\mathbf{x}_0\sim p_{\text{data}}\) : a sample drawn from the real data distribution
- \(\left\{ \beta_i \right\}_{i=1}^L\) : a pre-determined, monotonically increasing noise schedule with \(\beta_i\in(0,1), \forall i\)
- which is equivalent to updating $\mathbf{x}_i$ as
- \(\mathbf{x}_i = \alpha_i \mathbf{x}_{i-1} + \beta_i \epsilon_i\).
- for \(\epsilon_i\sim\mathcal{N}(\mathbf{0, I})\)
- Desc.)
- Closed-Form Expression for the distribution of noisy samples at step $i$
- \(p(\mathbf{x}_{i}\mid \mathbf{x}_{0}) = \mathcal{N}\left( \mathbf{x}_i;\; \bar{\alpha}_i\mathbf{x}_0, (1-\bar{\alpha}_i^2)\mathbf{I} \right)\).
- where \(\bar{\alpha}_i := \displaystyle\prod_{k=1}^i\sqrt{1-\beta_k^2} = \prod_{k=1}^i \alpha_k\)
- Sampled $\mathbf{x}_i$
- \(\mathbf{x}_i = \bar{\alpha}_i\mathbf{x}_0 + \sqrt{1-\bar{\alpha}_i^2}\epsilon\).
- for \(\epsilon\sim\mathcal{N}(\mathbf{0, I})\)
- Convergence of the forward process
- If \(\left\{ \beta_i \right\}_{i=1}^L\) is an increasing sequence
- \(p_L(\mathbf{x}_L\mid\mathbf{x}_0)\rightarrow\mathcal{N}(\mathbf{0,I})\) as $L\rightarrow\infty$
- Justifying $p_{\text{prior}} := \mathcal{N}(\mathbf{0,I})$
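The closed-form expression can be checked numerically by pushing the mean coefficient and the variance of \(\mathbf{x}_i\mid\mathbf{x}_0\) through the one-step kernel. The sketch below works in one dimension and assumes a hypothetical linear schedule (`linear_beta_schedule` and its endpoints are illustrative, not taken from the source), with $\alpha_i = \sqrt{1-\beta_i^2}$.

```python
import math

def linear_beta_schedule(L, beta_min=0.01, beta_max=0.3):
    """Hypothetical increasing schedule; beta_i acts as a standard deviation here."""
    return [beta_min + (beta_max - beta_min) * k / (L - 1) for k in range(L)]

def forward_moments(betas):
    """Iterate x_i = alpha_i x_{i-1} + beta_i eps_i and track the moments of x_i | x_0."""
    coef, var, out = 1.0, 0.0, []
    for b in betas:
        a = math.sqrt(1.0 - b * b)   # alpha_i
        coef = a * coef              # mean coefficient, equals alpha_bar_i
        var = a * a * var + b * b    # variance after one more noising step
        out.append((coef, var))
    return out

betas = linear_beta_schedule(200)
moments = forward_moments(betas)
alpha_bar_L, var_L = moments[-1]
```

At every step the tracked variance agrees with \(1-\bar{\alpha}_i^2\), and for this schedule \(\bar{\alpha}_L\) is already small, so \(\mathbf{x}_L\) is close to a standard Gaussian regardless of \(\mathbf{x}_0\).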
Concept) Reverse Process (Learnable Decoder)
- Goal)
- Reverse the noise corruption through a parameterized distribution \(p_\phi(\mathbf{x}_{i-1}\mid \mathbf{x}_{i})\)
- Method)
- To get the tractable marginal, we condition the intermediate noisy data on the clean data sample.
- Refer to the DDPM optimization problem below.
- This yields a closed-form (Gaussian) reverse conditional transition kernel \(p(\mathbf{x}_{i-1}\mid\mathbf{x}_i, \mathbf{x}_0)\).
- Parameterization)
- \(p_\phi(\mathbf{x}_{i-1}\mid\mathbf{x}_i) := \mathcal{N}\left( \mathbf{x}_{i-1};\; \boldsymbol{\mu}_{\phi}(\mathbf{x}_i, i), \sigma^2(i)\mathbf{I} \right)\).
- where
- \(\boldsymbol{\mu}_{\phi}(\cdot, i) : \mathbb{R}^D\rightarrow\mathbb{R}^D\) : a learnable mean function
- $\sigma^2(i)$ : a fixed variance function given from the reverse conditional transition kernel above.
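For concreteness, the reverse conditional transition kernel mentioned in the Method can be written out as follows. This is a reconstruction obtained by applying standard Gaussian conditioning to the forward kernels in this note's notation (\(\alpha_i = \sqrt{1-\beta_i^2}\), \(\bar{\alpha}_i = \prod_{k\le i}\alpha_k\), and the convention \(\bar{\alpha}_0 := 1\), which is an assumption here), not a quotation of the source's derivation.

\(\begin{aligned} p(\mathbf{x}_{i-1}\mid\mathbf{x}_i,\mathbf{x}_0) &= \mathcal{N}\left(\mathbf{x}_{i-1};\; \boldsymbol{\mu}(\mathbf{x}_i,\mathbf{x}_0,i),\; \sigma^2(i)\mathbf{I}\right) \\ \boldsymbol{\mu}(\mathbf{x}_i,\mathbf{x}_0,i) &= \frac{\bar{\alpha}_{i-1}\beta_i^2}{1-\bar{\alpha}_i^2}\mathbf{x}_0 + \frac{(1-\bar{\alpha}_{i-1}^2)\alpha_i}{1-\bar{\alpha}_i^2}\mathbf{x}_i \\ \sigma^2(i) &= \frac{1-\bar{\alpha}_{i-1}^2}{1-\bar{\alpha}_i^2}\beta_i^2 \end{aligned}\)

Under this form, \(\sigma^2(i)\) depends only on the noise schedule, which is why it can be treated as fixed while only the mean is learned.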
Tech) Optimization of DDPM
- Goal)
- Minimize the expected KL divergence of
- \(\mathbb{E}_{p_i(\mathbf{x}_i)}\left[ \mathcal{D}_{\text{KL}}\left( p(\mathbf{x}_{i-1}\mid \mathbf{x}_{i}) \Vert p_\phi(\mathbf{x}_{i-1}\mid \mathbf{x}_{i}) \right) \right]\).
- Problem)
- \(p(\mathbf{x}_{i-1}\mid \mathbf{x}_{i})\) is intractable.
- Why?)
- \(p_i(\mathbf{x}_i), p_{i-1}(\mathbf{x}_{i-1})\) are intractable where \(p(\mathbf{x}_{i-1}\mid \mathbf{x}_{i}) = p(\mathbf{x}_{i}\mid \mathbf{x}_{i-1}) \displaystyle\frac{p_{i-1}(\mathbf{x}_{i-1})}{p_i(\mathbf{x}_i)}\)
- We do not know $p_{\text{data}}$, and therefore cannot evaluate the marginal \(p_i(\mathbf{x}_i)=\displaystyle\int p_i(\mathbf{x}_i\mid\mathbf{x}_0) p_{\text{data}}(\mathbf{x}_0)\text{d}\mathbf{x}_0\)
- Sol.)
- Condition the reverse transition on a clean data sample $\mathbf{x}$ (written $\mathbf{x}_0$ elsewhere in these notes)
- i.e.) \(p(\mathbf{x}_{i-1}\mid \mathbf{x}_{i}, \mathbf{x}) = p(\mathbf{x}_{i}\mid \mathbf{x}_{i-1}) \displaystyle\frac{p_{i-1}(\mathbf{x}_{i-1} \mid \mathbf{x})}{p_i(\mathbf{x}_i \mid \mathbf{x})}\).
- How does this work?
- Markov forward process assumption enables \(p(\mathbf{x}_{i}\mid \mathbf{x}_{i-1}, \mathbf{x}) = p(\mathbf{x}_{i}\mid \mathbf{x}_{i-1})\)
- Gaussian assumption on all involved distributions.
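The two ingredients above (Markov forward process, Gaussian kernels) make the conditioned reverse transition a product of two Gaussians in \(\mathbf{x}_{i-1}\). The one-dimensional sketch below compares that Bayes-rule computation against the commonly used closed-form mean and variance, which are stated in the docstring as an assumption in this note's notation. The function names and numeric values are illustrative.

```python
import math

def posterior_closed_form(x_i, x_0, alpha_i, abar_prev, abar_i):
    """Assumed closed form of p(x_{i-1} | x_i, x_0) with beta_i^2 = 1 - alpha_i^2:
    mean = [abar_{i-1} beta_i^2 x_0 + (1 - abar_{i-1}^2) alpha_i x_i] / (1 - abar_i^2)
    var  = (1 - abar_{i-1}^2) / (1 - abar_i^2) * beta_i^2
    """
    beta_sq = 1.0 - alpha_i ** 2
    denom = 1.0 - abar_i ** 2
    mean = (abar_prev * beta_sq * x_0 + (1.0 - abar_prev ** 2) * alpha_i * x_i) / denom
    var = (1.0 - abar_prev ** 2) / denom * beta_sq
    return mean, var

def posterior_by_bayes(x_i, x_0, alpha_i, abar_prev):
    """Same posterior via conjugacy: prior p(x_{i-1} | x_0) times likelihood p(x_i | x_{i-1})."""
    beta_sq = 1.0 - alpha_i ** 2
    prior_var = 1.0 - abar_prev ** 2
    precision = 1.0 / prior_var + alpha_i ** 2 / beta_sq
    var = 1.0 / precision
    mean = var * (abar_prev * x_0 / prior_var + alpha_i * x_i / beta_sq)
    return mean, var

alpha_i, abar_prev = 0.95, 0.6
abar_i = alpha_i * abar_prev
m1, v1 = posterior_closed_form(0.3, -1.2, alpha_i, abar_prev, abar_i)
m2, v2 = posterior_by_bayes(0.3, -1.2, alpha_i, abar_prev)
```

Both routes give the same mean and variance, which is the sense in which conditioning on the clean sample makes the reverse kernel tractable.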
Tech.) DDPM’s Loss Function
- Def.)
- \(\mathcal{L}_{\text{DDPM}}(\phi) := \displaystyle\sum_{i=1}^L\frac{1}{2\sigma^2(i)}\mathbb{E}_{\mathbf{x}_0}\mathbb{E}_{p(\mathbf{x}_i\mid\mathbf{x}_0)}\left[ \left\Vert \boldsymbol{\mu}_\phi(\mathbf{x}_i, i) - \boldsymbol{\mu}(\mathbf{x}_i, \mathbf{x}_0, i) \right\Vert_2^2 \right]\).
Derivation of the Loss
- Parameterizations of \(\mu_\phi(\mathbf{x}_i, i)\)
- $\epsilon$-prediction
- $\mathbf{x}$-prediction
- cf.) Both are mathematically equivalent.
- While the true forward process uses ground-truth $\mathbf{x}_0$ and real noise $\epsilon$, our predictions must obey this exact same structural identity:
\(\begin{aligned} \mathbf{x}_i = \bar{\alpha}_i\mathbf{x}_{\phi}(\mathbf{x}_i, i) + \sqrt{1-\bar{\alpha}_i^2}\epsilon_\phi(\mathbf{x}_i, i) &\Leftrightarrow \mathbf{x}_{\phi}(\mathbf{x}_i, i) = \frac{1}{\bar{\alpha}_i} \left( \mathbf{x}_i - \sqrt{1-\bar{\alpha}_i^2}\epsilon_\phi(\mathbf{x}_i, i) \right) \\ &\Leftrightarrow \epsilon_\phi(\mathbf{x}_i, i) = \frac{1}{\sqrt{1-\bar{\alpha}_i^2}} \left( \mathbf{x}_i - \bar{\alpha}_i\mathbf{x}_{\phi}(\mathbf{x}_i, i) \right) \\ \end{aligned}\)
- ELBO
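The structural identity above means a predicted noise can always be converted into a predicted clean sample and back. A minimal scalar sketch, with hypothetical helper names `x0_from_eps` and `eps_from_x0` and arbitrary toy values:

```python
import math

def x0_from_eps(x_i, eps_hat, abar_i):
    """Clean-sample estimate implied by a noise estimate."""
    return (x_i - math.sqrt(1.0 - abar_i ** 2) * eps_hat) / abar_i

def eps_from_x0(x_i, x0_hat, abar_i):
    """Noise estimate implied by a clean-sample estimate."""
    return (x_i - abar_i * x0_hat) / math.sqrt(1.0 - abar_i ** 2)

# Build a noisy sample from known x_0 and eps, then recover each from the other.
abar_i, x_0, eps = 0.8, 1.5, -0.4
x_i = abar_i * x_0 + math.sqrt(1.0 - abar_i ** 2) * eps
```

Note that the division by \(\bar{\alpha}_i\) in `x0_from_eps` amplifies errors when \(\bar{\alpha}_i\) is small, i.e. at high noise levels.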
Tech.) ε-prediction
- Goal)
- Train a neural network to estimate the noise added.
- Def.)
- \(\mathcal{L}_{\text{simple}}(\phi) := \mathbb{E}_{i}\mathbb{E}_{\mathbf{x}_0\sim p_{\text{data}}}\mathbb{E}_{\epsilon\sim\mathcal{N}(\mathbf{0,I})} \left[ \Vert \epsilon_\phi(\mathbf{x}_i, i) - \epsilon \Vert_2^2 \right]\).
- where \(\mathbf{x}_i = \bar{\alpha}_i\mathbf{x}_0 + \sqrt{1-\bar{\alpha}_i^2} \epsilon\) with \(\mathbf{x}_0\sim p_{\text{data}}\).
Derivation of the Loss
- Prop.)
- \(\mathcal{L}_{\text{DDPM}}\) and \(\mathcal{L}_{\text{simple}}\) share the same optimal solution of
- \(\epsilon^*(\mathbf{x}_i, i) = \mathbb{E}[\epsilon\mid\mathbf{x}_i]\) for \(\mathbf{x}_i\sim p_i\)
- i.e.) Noise estimated by the ε-prediction coincides with the conditional expectation of the true noise, even though $\mathbf{x}_i$ does not uniquely determine the original clean sample.
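A Monte Carlo estimate of \(\mathcal{L}_{\text{simple}}\) can be sketched for scalar data as below. The predictor `eps_model(x_i, i)` is a hypothetical stand-in for \(\epsilon_\phi\), the step index is assumed to be drawn uniformly, and the toy data set is a single point so that the optimal predictor \(\mathbb{E}[\epsilon\mid\mathbf{x}_i]\) is known in closed form.

```python
import math
import random

def simple_loss(eps_model, data, alpha_bars, n_samples=1000, seed=0):
    """Monte Carlo estimate of L_simple; alpha_bars[i-1] holds alpha_bar_i for i = 1..L."""
    rng = random.Random(seed)
    total = 0.0
    for _ in range(n_samples):
        i = rng.randint(1, len(alpha_bars))   # uniform step index (assumed)
        x_0 = rng.choice(data)                # x_0 drawn from the empirical data
        eps = rng.gauss(0.0, 1.0)             # eps ~ N(0, 1)
        abar = alpha_bars[i - 1]
        x_i = abar * x_0 + math.sqrt(1.0 - abar ** 2) * eps
        total += (eps_model(x_i, i) - eps) ** 2
    return total / n_samples

alpha_bars = [0.99, 0.9, 0.5, 0.1]
data = [2.0]  # point-mass data: x_i determines eps exactly

def oracle(x_i, i):
    """E[eps | x_i] for the point-mass data above."""
    abar = alpha_bars[i - 1]
    return (x_i - abar * data[0]) / math.sqrt(1.0 - abar ** 2)

loss_oracle = simple_loss(oracle, data, alpha_bars)
loss_zero = simple_loss(lambda x_i, i: 0.0, data, alpha_bars)
```

For real data the optimal predictor cannot reach zero loss, because several clean samples are compatible with the same \(\mathbf{x}_i\); the point-mass case is used only because its optimum is easy to write down.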
Tech.) x-prediction
- Goal)
- Train a neural network to estimate the clean image from a given noisy input \(\mathbf{x}_i\sim p_i(\mathbf{x}_i)\).
- Def.)
- \(\mathcal{L}_{\text{x-pred}}(\phi) := \mathbb{E}_{i}\mathbb{E}_{\mathbf{x}_0 \sim p_{\text{data}}}\mathbb{E}_{\epsilon\sim\mathcal{N}(\mathbf{0,I})} \left[ \omega_i \Vert \mathbf{x}_\phi(\mathbf{x}_i, i) - \mathbf{x}_0 \Vert_2^2 \right]\).
- where
- \(\mathbf{x}_0\sim p_{\text{data}}\).
- \(\omega_i\) : some weighting function
Derivation of the Loss
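One way to relate the two losses is to substitute the parameterization identity into the noise error. The following worked step is a sketch in this note's notation, and the resulting weight is one particular choice of \(\omega_i\), not necessarily the one used in the source.

\(\begin{aligned} \epsilon_\phi(\mathbf{x}_i, i) - \epsilon &= \frac{\mathbf{x}_i - \bar{\alpha}_i\mathbf{x}_\phi(\mathbf{x}_i, i)}{\sqrt{1-\bar{\alpha}_i^2}} - \frac{\mathbf{x}_i - \bar{\alpha}_i\mathbf{x}_0}{\sqrt{1-\bar{\alpha}_i^2}} = -\frac{\bar{\alpha}_i}{\sqrt{1-\bar{\alpha}_i^2}}\left( \mathbf{x}_\phi(\mathbf{x}_i, i) - \mathbf{x}_0 \right) \\ \Rightarrow\quad \Vert \epsilon_\phi(\mathbf{x}_i, i) - \epsilon \Vert_2^2 &= \frac{\bar{\alpha}_i^2}{1-\bar{\alpha}_i^2} \Vert \mathbf{x}_\phi(\mathbf{x}_i, i) - \mathbf{x}_0 \Vert_2^2 \end{aligned}\)

So the unweighted $\epsilon$-prediction loss corresponds to an $\mathbf{x}$-prediction loss with \(\omega_i = \bar{\alpha}_i^2 / (1-\bar{\alpha}_i^2)\), which down-weights high-noise steps where \(\bar{\alpha}_i\) is small.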
Tech.) DDPM’s Sampling
- Settings)
- \(\epsilon_{\phi^\times}(\mathbf{x}_i, i)\) : a trained (frozen) noise predictor
- assuming the $\epsilon$-prediction
- Procedure)
- For $i=L,\ldots,1$,
- \(\mathbf{x}_{i-1} \leftarrow \displaystyle \underbrace{\frac{1}{\alpha_i} \left(\mathbf{x}_{i} - \frac{1-\alpha_i^2}{\sqrt{1-\bar{\alpha}_i^2}}\epsilon_{\phi^\times}(\mathbf{x}_{i}, i)\right)}_{\mu_{\phi^\times}(\mathbf{x}_{i}, i)} + \sigma(i)\epsilon_i\quad\) for \(\epsilon_i\sim\mathcal{N}(\mathbf{0,I})\)
- cf.) Equivalent to sampling \(\mathbf{x}_{i-1}\sim p_{\phi^\times}(\mathbf{x}_{i-1}\mid\mathbf{x}_i)\).
- Interpretation)
- Plugging in the parameterization identity \(\epsilon_{\phi^\times}(\mathbf{x}_i, i) = \displaystyle\frac{\mathbf{x}_i - \bar{\alpha}_i\mathbf{x}_{\phi^\times}(\mathbf{x}_i, i)}{\sqrt{1-\bar{\alpha}_i^2}}\), the sampling at step $i$ can be rewritten as
- \(\mathbf{x}_{i-1} \leftarrow \displaystyle \underbrace{c_1\mathbf{x}_{i} +c_2 \mathbf{x}_{\phi^\times}(\mathbf{x}_i, i)}_{\text{interpolation}} + \sigma(i)\epsilon_i\quad\).
- i.e.) Each step (\(\mathbf{x}_{i-1}\)) is centered around the predicted clean sample (\(\mathbf{x}_{\phi^\times}(\mathbf{x}_i, i)\)), with added Gaussian noise scaled by \(\sigma(i)\).
- Decomposed into two steps.
- Estimate the clean data \(\mathbf{x}_{\phi^\times}(\mathbf{x}_i, i)\) from the noisy input from the previous step \(\mathbf{x}_i\).
- Sample a less noisy latent \(\mathbf{x}_{i-1}\) using \(\mathbf{x}_{\phi^\times}(\mathbf{x}_i, i)\).
- Props.)
- Early sampling steps set the global structure, and later steps add fine detail.
- Even if \(\mathbf{x}_{\phi^\times}(\mathbf{x}_i, i)\) is trained to be the optimal denoiser, it can only predict the average clean sample given \(\mathbf{x}_i\).
- Thus, this leads to blurry samples, particularly at high noise levels (\(i\approx L\gg1\))
- cf.) Gaussian VAE
- However, as the sampling proceeds to low noise levels (\(i\rightarrow1\)), it progressively refines the estimate of the clean signal.
- Slow Sampling
- Why?)
- \(p_{\phi}(\mathbf{x}_{i-1}\mid\mathbf{x}_{i})\) is typically modeled as a Gaussian to approximate \(p(\mathbf{x}_{i-1}\mid\mathbf{x}_{i})\), limiting its expressiveness.
- For small forward noise scales $\beta_i$, the true reverse distribution is approximately Gaussian, enabling accurate approximation.
- Conversely, large $\beta_i$ induce multimodality or strong non-Gaussianity that a single Gaussian cannot capture.
- Thus, to maintain the accuracy, DDPM employs many small $\beta_i$ steps.
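The sampling procedure can be sketched end to end for scalar data. Several things here are assumptions for illustration: the predictor signature `eps_model(x_i, i, abar_i)`, the linear schedule, the variance \(\sigma^2(i) = \frac{1-\bar{\alpha}_{i-1}^2}{1-\bar{\alpha}_i^2}\beta_i^2\) with \(\bar{\alpha}_0 := 1\), and a toy data distribution that is a point mass at `c`, for which the optimal noise predictor is available in closed form instead of a trained network.

```python
import math
import random

def ddpm_sample(eps_model, betas, rng):
    """Ancestral sampling for scalar data; betas[i-1] is beta_i for i = 1..L."""
    alphas = [math.sqrt(1.0 - b * b) for b in betas]
    abars = [1.0]                      # abar_0 = 1 (no noise), assumed convention
    for a in alphas:
        abars.append(abars[-1] * a)
    x = rng.gauss(0.0, 1.0)            # start from the prior: x_L ~ N(0, 1)
    for i in range(len(betas), 0, -1):
        a, abar, abar_prev = alphas[i - 1], abars[i], abars[i - 1]
        beta_sq = 1.0 - a * a
        # Mean of the reverse kernel under the epsilon-parameterization
        mean = (x - beta_sq / math.sqrt(1.0 - abar ** 2) * eps_model(x, i, abar)) / a
        # Assumed fixed variance; it is exactly zero at the final step i = 1
        sigma = math.sqrt((1.0 - abar_prev ** 2) / (1.0 - abar ** 2) * beta_sq)
        x = mean + sigma * rng.gauss(0.0, 1.0)
    return x

c = 2.0  # toy data distribution: point mass at c

def oracle(x_i, i, abar):
    """Optimal noise predictor E[eps | x_i] for the point-mass data."""
    return (x_i - abar * c) / math.sqrt(1.0 - abar ** 2)

betas = [0.02 + 0.28 * k / 99 for k in range(100)]
sample = ddpm_sample(oracle, betas, random.Random(0))
```

The loop makes the cost visible: one predictor evaluation per step, and $L$ steps per sample, which is the slow-sampling trade-off described above.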