(DM Reconst.) Ch.6 A Unified and Systematic Lens on Diffusion Models

Diffusion Model Conceptual Reconstruction following The Principles of Diffusion Models



Hozy Summary




Concept) Four Prediction Types

  • Types)
    • $\epsilon$-Prediction
    • $\mathbf{x}$-Prediction
    • Score-Prediction
    • $\mathbf{v}$-Prediction


| | Variational View | Score-Based View | Flow-Based View |
|---|---|---|---|
| Intractable original training objective | $\mathcal{J}_{\text{KL}}(\phi) := \mathbb{E}_{p_t(\mathbf{x}_t)}\left[\mathcal{D}_{\text{KL}}\big(p(\mathbf{x}_{t-\Delta t}\mid\mathbf{x}_t)\,\Vert\,p_\phi(\mathbf{x}_{t-\Delta t}\mid\mathbf{x}_t)\big)\right]$ | $\mathcal{J}_{\text{SM}}(\phi) := \mathbb{E}_{p_t(\mathbf{x}_t)}\left[\Vert s_\phi(\mathbf{x}_t, t) - \nabla_{\mathbf{x}_t}\log p_t(\mathbf{x}_t)\Vert_2^2\right]$ | $\mathcal{J}_{\text{FM}}(\phi) := \mathbb{E}_{p_t(\mathbf{x}_t)}\left[\Vert \mathbf{v}_\phi(\mathbf{x}_t, t) - \mathbf{v}_t(\mathbf{x}_t)\Vert_2^2\right]$ |
| Tractable objective by conditioning on the data $\mathbf{x}_0\sim p_{\text{data}}$ | $\mathcal{J}_{\text{KL}}(\phi) = \underbrace{\mathbb{E}_{\mathbf{x}_0}\mathbb{E}_{p_t(\mathbf{x}_t\mid\mathbf{x}_0)}\left[\mathcal{D}_{\text{KL}}\big(p(\mathbf{x}_{t-\Delta t}\mid\mathbf{x}_t, \mathbf{x}_0)\,\Vert\,p_\phi(\mathbf{x}_{t-\Delta t}\mid\mathbf{x}_t)\big)\right]}_{\text{conditional KL}} + C$ | $\mathcal{J}_{\text{SM}}(\phi) = \underbrace{\mathbb{E}_{\mathbf{x}_0}\mathbb{E}_{p_t(\mathbf{x}_t\mid\mathbf{x}_0)}\left[\Vert s_\phi(\mathbf{x}_t, t) - \nabla_{\mathbf{x}_t}\log p_t(\mathbf{x}_t\mid\mathbf{x}_0)\Vert_2^2\right]}_{\mathcal{J}_{\text{DSM}}(\phi)} + C$ | $\mathcal{J}_{\text{FM}}(\phi) = \underbrace{\mathbb{E}_{\mathbf{x}_0}\mathbb{E}_{p_t(\mathbf{x}_t\mid\mathbf{x}_0)}\left[\Vert \mathbf{v}_\phi(\mathbf{x}_t, t) - \mathbf{v}_t(\mathbf{x}_t\mid\mathbf{x}_0)\Vert_2^2\right]}_{\mathcal{J}_{\text{CFM}}(\phi)} + C$ |

Common forward perturbation kernel (shared by all three views): $$p_t(\mathbf{x}_t\mid\mathbf{x}_0) = \mathcal{N}\big(\mathbf{x}_t;\; \alpha_t\mathbf{x}_0,\; \sigma_t^2\mathbf{I}\big)$$

| | $\epsilon$-Prediction | $\mathbf{x}$-Prediction | Score-Prediction | $\mathbf{v}$-Prediction |
|---|---|---|---|---|
| Parameterization target | $\epsilon_\phi(\mathbf{x}_t, t) \approx \mathbb{E}[\epsilon\mid\mathbf{x}_t] = \epsilon^*(\mathbf{x}_t, t)$ | $\mathbf{x}_\phi(\mathbf{x}_t, t) \approx \mathbb{E}[\mathbf{x}_0\mid\mathbf{x}_t] = \mathbf{x}_0^*(\mathbf{x}_t, t)$ | $s_\phi(\mathbf{x}_t, t) \approx \nabla_{\mathbf{x}_t}\log p_t(\mathbf{x}_t) = \mathbb{E}[\nabla_{\mathbf{x}_t}\log p_t(\mathbf{x}_t\mid\mathbf{x}_0)\mid\mathbf{x}_t] = s^*(\mathbf{x}_t, t)$ | $\mathbf{v}_\phi(\mathbf{x}_t, t) \approx \mathbb{E}\left[\frac{\mathrm{d}\mathbf{x}_t}{\mathrm{d}t}\mid\mathbf{x}_t\right] = \mathbf{v}^*(\mathbf{x}_t, t)$ |
| Training objective | $\mathcal{L}_{\text{noise}}(\phi) := \mathbb{E}_t\left[\omega(t)\,\mathbb{E}_{\mathbf{x}_0, \epsilon}\Vert \epsilon_\phi(\mathbf{x}_t, t) - \epsilon\Vert_2^2\right]$ | $\mathcal{L}_{\text{clean}}(\phi) := \mathbb{E}_t\left[\omega(t)\,\mathbb{E}_{\mathbf{x}_0, \epsilon}\Vert \mathbf{x}_\phi(\mathbf{x}_t, t) - \mathbf{x}_0\Vert_2^2\right]$ | $\mathcal{L}_{\text{score}}(\phi) := \mathbb{E}_t\left[\omega(t)\,\mathbb{E}_{\mathbf{x}_0, \epsilon}\Vert s_\phi(\mathbf{x}_t, t) - \nabla_{\mathbf{x}_t}\log p_t(\mathbf{x}_t\mid\mathbf{x}_0)\Vert_2^2\right]$ | $\mathcal{L}_{\text{velocity}}(\phi) := \mathbb{E}_t\left[\omega(t)\,\mathbb{E}_{\mathbf{x}_0, \epsilon}\Vert \mathbf{v}_\phi(\mathbf{x}_t, t) - \mathbf{v}_t(\mathbf{x}_t\mid\mathbf{x}_0, \epsilon)\Vert_2^2\right]$ |
| Conditional target, of the form $A_t\mathbf{x}_0 + B_t\epsilon$ | $\epsilon$ | $\mathbf{x}_0$ | $-\frac{1}{\sigma_t}\epsilon$ | $\alpha_t'\mathbf{x}_0 + \sigma_t'\epsilon$ |
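The last table row says every regression target is a fixed linear combination $A_t\mathbf{x}_0 + B_t\epsilon$. A minimal Python sketch of that idea, assuming the linear path $\alpha_t = 1-t$, $\sigma_t = t$ and a constant weight $\omega(t) = 1$ (both are illustrative choices, not fixed by the text). `conditional_target` and `per_sample_loss` are hypothetical helper names, and `model` stands in for any network mapping $(\mathbf{x}_t, t)$ to a prediction:

```python
import random

# Illustrative assumption: linear (rectified-flow) path
# alpha_t = 1 - t, sigma_t = t, hence alpha_t' = -1 and sigma_t' = 1.
def alpha(t): return 1.0 - t
def sigma(t): return t
def d_alpha(t): return -1.0
def d_sigma(t): return 1.0


def conditional_target(kind, x0, eps, t):
    """Return A_t * x0 + B_t * eps for one of the four prediction types."""
    if kind == "eps":
        A, B = 0.0, 1.0
    elif kind == "x0":
        A, B = 1.0, 0.0
    elif kind == "score":
        # grad_x log p_t(x_t | x0) = -(x_t - alpha_t x0) / sigma_t^2 = -eps / sigma_t
        A, B = 0.0, -1.0 / sigma(t)
    elif kind == "v":
        # d x_t / dt along the conditional path x_t = alpha_t x0 + sigma_t eps
        A, B = d_alpha(t), d_sigma(t)
    else:
        raise ValueError(f"unknown prediction type: {kind}")
    return [A * a + B * b for a, b in zip(x0, eps)]


def per_sample_loss(kind, model, x0, t, weight=lambda t: 1.0, rng=random):
    """One Monte Carlo term of the weighted regression loss for `kind`.

    `model(xt, t)` is a hypothetical network returning a list shaped like x0.
    """
    eps = [rng.gauss(0.0, 1.0) for _ in x0]
    xt = [alpha(t) * a + sigma(t) * b for a, b in zip(x0, eps)]
    pred = model(xt, t)
    target = conditional_target(kind, x0, eps, t)
    return weight(t) * sum((p - q) ** 2 for p, q in zip(pred, target))
```

Only $(A_t, B_t)$ changes with `kind`; the sampling of $(\mathbf{x}_0, \epsilon, t)$ and the squared-error form are shared by all four objectives.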



Prop.) Equivalence of the Four Parameterizations



Prop.) PF-ODE in each Parameterization

  • cf.) Although the four parameterizations are equivalent in principle, they differ in practice.
    • Why?)
      • Each parameterization has its own characteristics in terms of
        • the stiffness of the vector field
        • the behavior of discretization error
        • the ease of optimization
    • For fast sampling with advanced ODE solvers, $\epsilon$- or $\mathbf{x}$-prediction is preferred.
      • Why?) They align well with the semi-linear structure these solvers exploit (e.g., DPM-Solver with $\epsilon$-prediction, DPM-Solver++ with $\mathbf{x}$-prediction), which reduces error accumulation.
    • With a small number of function evaluations, $\mathbf{x}$- or $\mathbf{v}$-prediction is preferred.
      • Why?) They often yield smoother objectives and better step-to-step consistency.
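In principle all four parameterizations describe one and the same PF-ODE, $\mathrm{d}\mathbf{x}_t/\mathrm{d}t = \alpha_t'\mathbf{x}_0^* + \sigma_t'\epsilon^*$. A toy check under loudly stated assumptions: 1-D Gaussian data $\mathbf{x}_0\sim\mathcal{N}(0, 0.5^2)$, for which every oracle is closed-form, and the linear path $\alpha_t = 1-t$, $\sigma_t = t$. The hypothetical helper `pf_ode_velocity` rewrites that velocity through each prediction type, and `euler_sample` integrates it backward from $t=1$ to $t=0$:

```python
# Assumptions for illustration only: 1-D data x0 ~ N(0, S0**2), eps ~ N(0, 1),
# and the linear path alpha_t = 1 - t, sigma_t = t (alpha' = -1, sigma' = 1).
S0 = 0.5

def alpha(t): return 1.0 - t
def sigma(t): return t
D_ALPHA, D_SIGMA = -1.0, 1.0


def marginal_var(t):
    return alpha(t) ** 2 * S0 ** 2 + sigma(t) ** 2


# Closed-form oracles E[x0 | x_t], E[eps | x_t], and grad log p_t(x_t).
def x0_star(x, t): return alpha(t) * S0 ** 2 / marginal_var(t) * x
def eps_star(x, t): return sigma(t) / marginal_var(t) * x
def score_star(x, t): return -x / marginal_var(t)


def pf_ode_velocity(kind, x, t):
    """PF-ODE drift dx/dt written in terms of one prediction type."""
    a, s = alpha(t), sigma(t)
    if kind == "v":
        return D_ALPHA * x0_star(x, t) + D_SIGMA * eps_star(x, t)
    if kind == "eps":    # divides by alpha_t
        return D_ALPHA / a * x + (D_SIGMA - D_ALPHA * s / a) * eps_star(x, t)
    if kind == "x0":     # divides by sigma_t
        return D_SIGMA / s * x + (D_ALPHA - D_SIGMA * a / s) * x0_star(x, t)
    if kind == "score":  # divides by alpha_t
        return D_ALPHA / a * x + (D_ALPHA * s ** 2 / a - s * D_SIGMA) * score_star(x, t)
    raise ValueError(f"unknown prediction type: {kind}")


def euler_sample(x1, kind="x0", steps=1000):
    """Integrate the PF-ODE with explicit Euler from t = 1 (noise) to t = 0."""
    x, dt = x1, 1.0 / steps
    for i in range(steps):
        t = 1.0 - i * dt
        x = x - dt * pf_ode_velocity(kind, x, t)
    return x
```

With exact oracles the four velocities agree, and the sampler maps $x_1 = 1$ close to $S_0 \cdot 1 = 0.5$ because the exact flow rescales by $\sqrt{\mathrm{Var}[\mathbf{x}_0]/\mathrm{Var}[\mathbf{x}_1]}$. The "divides by" comments give one concrete face of the practical differences listed above: under this schedule the $\epsilon$- and score forms are undefined at $t=1$ ($\alpha_t = 0$), while the $\mathbf{x}$-form has a coefficient $1/\sigma_t$ that grows as $t \to 0$.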



Prop.) All Affine Flows are Equivalent

  • Settings)
    • \(\mathbf{x}_t^{\text{FM}} = (1-t)\mathbf{x}_0 + t\epsilon = \mathbf{x}_0 + t\underbrace{(\epsilon - \mathbf{x}_0)}_{\mathbf{v}}\).
      • i.e.) A canonical interpolation used in CFM and RF.
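One way to see the equivalence: for a general affine path $\mathbf{x}_t = \alpha_t\mathbf{x}_0 + \sigma_t\epsilon$ with $c_t := \alpha_t + \sigma_t > 0$ (an assumption), dividing by $c_t$ gives $\mathbf{x}_t / c_t = (1-\tau)\mathbf{x}_0 + \tau\epsilon$ with $\tau = \sigma_t / c_t$, i.e. a point on the canonical FM path after a time reparameterization and a rescaling. A minimal sketch; `to_fm` is a hypothetical helper, and the variance-preserving cosine schedule is used only as an example:

```python
import math


def to_fm(xt, alpha_t, sigma_t):
    """Map a sample of an affine path to the canonical FM/RF path.

    Assumes alpha_t + sigma_t > 0. Returns (tau, x_fm) such that
    x_fm = (1 - tau) * x0 + tau * eps for the same (x0, eps).
    """
    c = alpha_t + sigma_t
    if c <= 0:
        raise ValueError("requires alpha_t + sigma_t > 0")
    tau = sigma_t / c
    return tau, xt / c


def vp_cosine(t):
    """Example schedule (assumption): alpha_t = cos(pi t / 2), sigma_t = sin(pi t / 2)."""
    return math.cos(0.5 * math.pi * t), math.sin(0.5 * math.pi * t)


x0, eps, t = 2.0, -1.0, 0.3
a, s = vp_cosine(t)
tau, x_fm = to_fm(a * x0 + s * eps, a, s)
print(tau, x_fm, (1 - tau) * x0 + tau * eps)  # last two values coincide
```

Since the map is invertible for $c_t > 0$, a model trained on one affine path can be queried on another by converting time and scale.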



Prop.) Conversions between the Parameterizations

  • Settings)
    • \(\mathbf{x}_t = \alpha_t\mathbf{x}_0 + \sigma_t\epsilon\) s.t.
      • \(\sigma_t\gt0\).
      • \((\alpha_t'\sigma_t - \alpha_t\sigma_t')\ne0\).
    • Oracle targets are given by
      • \(\epsilon^*(\mathbf{x}_t, t) = \mathbb{E}[\epsilon\mid\mathbf{x}_t]\).
      • \(\mathbf{x}_0^*(\mathbf{x}_t, t) = \mathbb{E}[\mathbf{x}_0\mid\mathbf{x}_t]\).
      • \(\mathbf{v}^*(\mathbf{x}_t, t) = \mathbb{E}[\alpha_t'\mathbf{x}_0 + \sigma_t' \epsilon\mid\mathbf{x}_t]\).
  • Conversions
    • Score - Noise
      • Parameterization
        • \(s_\phi \equiv -\frac{1}{\sigma_t}\epsilon_\phi\).
      • Loss
        • \(\left\Vert s_\phi - \nabla_{\mathbf{x}_t}\log p_t(\mathbf{x}_t) \right\Vert_2^2 = \frac{1}{\sigma_t^2} \left\Vert \epsilon_{\phi} - \epsilon^* \right\Vert_2^2\).
    • Score - $\mathbf{x}$
      • Parameterization
        • \(s_\phi \equiv \frac{\alpha_t}{\sigma_t^2}\left(\mathbf{x}_\phi - \frac{\mathbf{x}_t}{\alpha_t}\right)\).
      • Loss
        • \(\left\Vert s_\phi - \nabla_{\mathbf{x}_t}\log p_t(\mathbf{x}_t) \right\Vert_2^2 = \frac{\alpha_t^2}{\sigma_t^4} \left\Vert \mathbf{x}_\phi - \mathbf{x}_0^* \right\Vert_2^2\).
    • Score - $\mathbf{v}$
      • Parameterization
        • \(s_\phi = \frac{\alpha_t}{\sigma_t(\alpha_t'\sigma_t - \alpha_t\sigma_t')}\mathbf{v}_\phi - \frac{\alpha_t'}{\sigma_t(\alpha_t'\sigma_t - \alpha_t\sigma_t')}\mathbf{x}_t\).
      • Loss
        • \(\left\Vert s_\phi - \nabla_{\mathbf{x}_t}\log p_t(\mathbf{x}_t) \right\Vert_2^2 = \left(\frac{\alpha_t}{\sigma_t(\alpha_t'\sigma_t - \alpha_t\sigma_t')}\right)^2 \left\Vert \mathbf{v}_\phi - \mathbf{v}^*\right\Vert_2^2\).
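The three conversions can be checked numerically. A minimal sketch assuming a variance-preserving cosine schedule $\alpha_t = \cos(\pi t/2)$, $\sigma_t = \sin(\pi t/2)$ (so $\alpha_t'\sigma_t - \alpha_t\sigma_t' = -\pi/2 \neq 0$) and 1-D Gaussian data $\mathbf{x}_0\sim\mathcal{N}(0, 2^2)$, where all oracles have closed forms; every function name is hypothetical:

```python
import math

# Assumed schedule for the demo: VP cosine path.
def alpha(t): return math.cos(0.5 * math.pi * t)
def sigma(t): return math.sin(0.5 * math.pi * t)
def d_alpha(t): return -0.5 * math.pi * math.sin(0.5 * math.pi * t)
def d_sigma(t): return 0.5 * math.pi * math.cos(0.5 * math.pi * t)


def score_from_eps(eps_hat, t):
    return -eps_hat / sigma(t)


def score_from_x0(x0_hat, xt, t):
    a, s = alpha(t), sigma(t)
    return a / s ** 2 * (x0_hat - xt / a)


def score_from_v(v_hat, xt, t):
    a, s, da, ds = alpha(t), sigma(t), d_alpha(t), d_sigma(t)
    k = s * (da * s - a * ds)  # sigma_t * (alpha_t' sigma_t - alpha_t sigma_t')
    return a / k * v_hat - da / k * xt


# Toy 1-D Gaussian data x0 ~ N(0, S0**2), so every oracle is closed-form.
S0 = 2.0


def oracles(xt, t):
    """Return (x0*, eps*, v*, true score) at (xt, t)."""
    a, s = alpha(t), sigma(t)
    var = a ** 2 * S0 ** 2 + s ** 2
    x0_star = a * S0 ** 2 / var * xt
    eps_star = s / var * xt
    v_star = d_alpha(t) * x0_star + d_sigma(t) * eps_star
    return x0_star, eps_star, v_star, -xt / var
```

With exact oracle inputs each conversion returns the true score $-x_t/\mathrm{Var}[\mathbf{x}_t]$. Because each conversion is affine in the prediction, an error $\delta$ in the $\mathbf{v}$ input moves the score by $\frac{\alpha_t}{\sigma_t(\alpha_t'\sigma_t - \alpha_t\sigma_t')}\delta$, so the squared score error is that coefficient squared times $\delta^2$.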



Prop.) Unified perspective connecting the variational, SDE, and ODE formulations through the continuity equation, under which all marginals $p_t(\mathbf{x})$ evolve by the same dynamics.
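For Gaussian data the continuity equation $\partial_t p_t + \nabla\cdot(p_t\,\mathbf{v}^*) = 0$ can be checked directly. A sketch under illustrative assumptions (1-D data $\mathbf{x}_0\sim\mathcal{N}(1.5, 0.5^2)$, linear path $\alpha_t = 1-t$, $\sigma_t = t$): `continuity_residual` is a hypothetical helper that estimates the left-hand side with central finite differences, using the oracle velocity $\mathbf{v}^* = \alpha_t'\mathbf{x}_0^* + \sigma_t'\epsilon^*$:

```python
import math

# Illustrative assumptions: 1-D Gaussian data x0 ~ N(M, S0**2) and the linear
# path alpha_t = 1 - t, sigma_t = t, so p_t = N(alpha_t * M, V_t) in closed form.
M, S0 = 1.5, 0.5

def alpha(t): return 1.0 - t
def sigma(t): return t
D_ALPHA, D_SIGMA = -1.0, 1.0


def var(t):
    return alpha(t) ** 2 * S0 ** 2 + sigma(t) ** 2


def density(x, t):
    mu, v = alpha(t) * M, var(t)
    return math.exp(-(x - mu) ** 2 / (2 * v)) / math.sqrt(2 * math.pi * v)


def velocity(x, t):
    """Oracle v*(x, t) = alpha_t' E[x0 | x_t = x] + sigma_t' E[eps | x_t = x]."""
    a, s, v = alpha(t), sigma(t), var(t)
    x0_star = M + a * S0 ** 2 / v * (x - a * M)
    eps_star = s / v * (x - a * M)
    return D_ALPHA * x0_star + D_SIGMA * eps_star


def continuity_residual(x, t, h=1e-4):
    """Central-difference estimate of d_t p + d_x (p v); should be close to 0."""
    dp_dt = (density(x, t + h) - density(x, t - h)) / (2 * h)

    def flux(y):
        return density(y, t) * velocity(y, t)

    dflux_dx = (flux(x + h) - flux(x - h)) / (2 * h)
    return dp_dt + dflux_dx
```

The residual stays at the size of the finite-difference error (order $h^2$), whereas replacing `velocity` by zero would leave $\partial_t p_t$ unbalanced.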



