(DM Reconst.) Ch.6 A Unified and Systemic Lens on Diffusion Models
Diffusion Model Conceptual Reconstruction following The Principles of Diffusion Models
Hozy Summary
Concept) Four Prediction Types
- Types)
  - $\epsilon$-Prediction
  - $\mathbf{x}$-Prediction
  - Score-Prediction
  - $\mathbf{v}$-Prediction
| | Variational View | | Score-Based View | Flow-Based View |
|---|---|---|---|---|
| Intractable original training objective | $$\mathcal{J}_{\text{KL}}(\phi) := \mathbb{E}_{p_t(\mathbf{x}_t)} \left[ \mathcal{D}_{\text{KL}}\big( p(\mathbf{x}_{t-\Delta t}\mid\mathbf{x}_t) \;\Vert\; p_\phi(\mathbf{x}_{t-\Delta t}\mid\mathbf{x}_t) \big) \right]$$ | | $$\mathcal{J}_{\text{SM}}(\phi) := \mathbb{E}_{p_t(\mathbf{x}_t)} \left[ \Vert s_\phi(\mathbf{x}_t, t) - \nabla_{\mathbf{x}}\log p_t(\mathbf{x}_t) \Vert_2^2 \right]$$ | $$\mathcal{J}_{\text{FM}}(\phi) := \mathbb{E}_{p_t(\mathbf{x}_t)} \left[ \Vert \mathbf{v}_\phi(\mathbf{x}_t, t) - \mathbf{v}_t(\mathbf{x}_t) \Vert_2^2 \right]$$ |
| Tractable objectives by conditioning on the data $$\mathbf{x}_0\sim p_{\text{data}}$$ | $$\mathcal{J}_{\text{KL}}(\phi) = \underbrace{\mathbb{E}_{\mathbf{x}_0} \mathbb{E}_{p_t(\mathbf{x}_t\mid\mathbf{x}_0)} \left[ \mathcal{D}_{\text{KL}}\big( p(\mathbf{x}_{t-\Delta t}\mid\mathbf{x}_t, \mathbf{x}_0) \;\Vert\; p_\phi(\mathbf{x}_{t-\Delta t}\mid\mathbf{x}_t) \big) \right]}_{\text{Conditional KL}} + C$$ | | $$\mathcal{J}_{\text{SM}}(\phi) = \underbrace{\mathbb{E}_{\mathbf{x}_0} \mathbb{E}_{p_t(\mathbf{x}_t\mid\mathbf{x}_0)} \left[ \Vert s_\phi(\mathbf{x}_t, t) - \nabla_{\mathbf{x}}\log p_t(\mathbf{x}_t \mid \mathbf{x}_0) \Vert_2^2 \right]}_{\mathcal{J}_{\text{DSM}}(\phi)} + C$$ | $$\mathcal{J}_{\text{FM}}(\phi) = \underbrace{\mathbb{E}_{\mathbf{x}_0} \mathbb{E}_{p_t(\mathbf{x}_t\mid\mathbf{x}_0)} \left[ \Vert \mathbf{v}_\phi(\mathbf{x}_t, t) - \mathbf{v}_t(\mathbf{x}_t \mid \mathbf{x}_0) \Vert_2^2 \right]}_{\mathcal{J}_{\text{CFM}}(\phi)} + C$$ |
| Common forward perturbation kernel | $$p_t(\mathbf{x}_t\mid\mathbf{x}_0) = \mathcal{N} \bigg( \mathbf{x}_t;\; \alpha_t\mathbf{x}_0, \sigma_t^2\mathbf{I} \bigg)$$ | |||
| Parameterization Target | $\epsilon$-Prediction $$\epsilon_\phi(\mathbf{x}_t, t) \approx \mathbb{E}[\epsilon\mid\mathbf{x}_t] = \epsilon^*(\mathbf{x}_t, t)$$ | $\mathbf{x}$-Prediction $$\mathbf{x}_\phi(\mathbf{x}_t, t) \approx \mathbb{E}[\mathbf{x}_0\mid\mathbf{x}_t] = \mathbf{x}^*(\mathbf{x}_t, t)$$ | Score-Prediction $$\begin{aligned} s_\phi(\mathbf{x}_t, t) &\approx \nabla_{\mathbf{x}_t}\log p_t(\mathbf{x}_t) \\ &= \mathbb{E}[\nabla_{\mathbf{x}_t}\log p_t(\mathbf{x}_t\mid\mathbf{x}_0)\mid\mathbf{x}_t] \\ &= s^*(\mathbf{x}_t, t) \end{aligned}$$ | $\mathbf{v}$-Prediction $$\mathbf{v}_\phi(\mathbf{x}_t, t) \approx \mathbb{E}\left[\frac{\text{d}\mathbf{x}_t}{\text{d}t}\mid\mathbf{x}_t\right] = \mathbf{v}^*(\mathbf{x}_t, t)$$ |
| Training objective | $$\mathcal{L}_{\text{noise}}(\phi) := \mathbb{E}_t\left[ \omega(t)\;\mathbb{E}_{\mathbf{x}_0, \epsilon}\Vert \epsilon_\phi(\mathbf{x}_t, t) - \epsilon \Vert_2^2 \right]$$ | $$\mathcal{L}_{\text{clean}}(\phi) := \mathbb{E}_t\left[ \omega(t)\;\mathbb{E}_{\mathbf{x}_0, \epsilon}\Vert \mathbf{x}_\phi(\mathbf{x}_t, t) - \mathbf{x}_0 \Vert_2^2 \right]$$ | $$\mathcal{L}_{\text{score}}(\phi) := \mathbb{E}_t\left[ \omega(t)\;\mathbb{E}_{\mathbf{x}_0, \epsilon}\Vert s_\phi(\mathbf{x}_t, t) - \nabla_{\mathbf{x}_t}\log p_t(\mathbf{x}_t\mid\mathbf{x}_0) \Vert_2^2 \right]$$ | $$\mathcal{L}_{\text{velocity}}(\phi) := \mathbb{E}_t\left[ \omega(t)\;\mathbb{E}_{\mathbf{x}_0, \epsilon}\Vert \mathbf{v}_\phi(\mathbf{x}_t, t) - \mathbf{v}_t(\mathbf{x}_t\mid\mathbf{x}_0,\epsilon) \Vert_2^2 \right]$$ |
| Conditional Parameterization Target $$A_t\mathbf{x}_0+B_t\epsilon$$ | $\epsilon$ | $\mathbf{x}_0$ | $$-\displaystyle\frac{1}{\sigma_t}\epsilon$$ | $$\alpha_t'\mathbf{x}_0 + \sigma_t'\epsilon$$ |
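
The last row of the table can be read as a recipe: every regression target is an affine combination $A_t\mathbf{x}_0+B_t\epsilon$ of the same pair $(\mathbf{x}_0, \epsilon)$. The following is a minimal NumPy sketch of that row, not a reference implementation. The trigonometric schedule and the names `schedule` and `conditional_targets` are assumptions made for illustration only.

```python
import numpy as np

def schedule(t):
    """Hypothetical schedule: alpha_t = cos(pi t / 2), sigma_t = sin(pi t / 2).

    Returns alpha_t, sigma_t and their time derivatives alpha_t', sigma_t'.
    """
    h = 0.5 * np.pi
    alpha, sigma = np.cos(h * t), np.sin(h * t)
    d_alpha, d_sigma = -h * np.sin(h * t), h * np.cos(h * t)
    return alpha, sigma, d_alpha, d_sigma

def conditional_targets(x0, eps, t):
    """Build x_t and the four conditional targets A_t x0 + B_t eps."""
    alpha, sigma, d_alpha, d_sigma = schedule(t)
    x_t = alpha * x0 + sigma * eps              # sample from the forward kernel
    targets = {
        "eps": eps,                             # A_t = 0,        B_t = 1
        "x0": x0,                               # A_t = 1,        B_t = 0
        "score": -eps / sigma,                  # A_t = 0,        B_t = -1 / sigma_t
        "v": d_alpha * x0 + d_sigma * eps,      # A_t = alpha_t', B_t = sigma_t'
    }
    return x_t, targets

rng = np.random.default_rng(0)
x0 = rng.normal(size=4)    # stand-in for a data sample
eps = rng.normal(size=4)   # Gaussian noise
t = 0.3
x_t, targets = conditional_targets(x0, eps, t)
```

Under these assumptions, `targets["score"]` coincides with the gradient of the log Gaussian kernel, $-(\mathbf{x}_t-\alpha_t\mathbf{x}_0)/\sigma_t^2$, and `targets["v"]` is the time derivative of $\mathbf{x}_t$ with $(\mathbf{x}_0,\epsilon)$ held fixed.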
Prop.) Equivalence of the Four Parameterizations
Prop.) PF-ODE in each Parameterization
- cf.) Although the four parameterizations are equivalent in principle, they differ in practice.
  - Why?)
    - Each parameterization changes
      - the stiffness of the vector field
      - the behavior of the discretization error
      - the ease of optimization
  - For fast sampling with advanced ODE solvers, $\epsilon$- or $\mathbf{x}$-prediction is preferred.
    - Why?) They align well with the solver inputs and reduce error accumulation.
  - With a limited number of function evaluations, $\mathbf{x}$- or $\mathbf{v}$-prediction is preferred.
    - Why?) They often yield smoother objectives and better step-to-step consistency.
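
To make the role of the PF-ODE concrete, here is a toy sketch in which the oracle targets are available in closed form, so no network is needed. Everything specific in it is an assumption for illustration: 1-D Gaussian data $\mathbf{x}_0\sim\mathcal{N}(M, S_0^2)$, the linear schedule $\alpha_t = 1-t$, $\sigma_t = t$, and a plain Euler solver. The velocity is assembled from the $\mathbf{x}$- and $\epsilon$-oracles as $\mathbf{v}^* = \alpha_t'\mathbf{x}_0^* + \sigma_t'\epsilon^*$.

```python
import numpy as np

M, S0 = 2.0, 0.5  # hypothetical data distribution: x0 ~ N(M, S0^2)

def oracle_velocity(x, t):
    """Closed-form v*(x, t) for Gaussian data under the assumed linear schedule."""
    alpha, sigma = 1.0 - t, t
    d_alpha, d_sigma = -1.0, 1.0
    var_t = alpha**2 * S0**2 + sigma**2                      # Var[x_t]
    x0_star = M + alpha * S0**2 / var_t * (x - alpha * M)    # E[x0 | x_t]
    eps_star = sigma / var_t * (x - alpha * M)               # E[eps | x_t]
    return d_alpha * x0_star + d_sigma * eps_star

def sample_pf_ode(x_init, n_steps=1000):
    """Euler integration of dx/dt = v*(x, t) from t = 1 (noise) down to t = 0 (data)."""
    ts = np.linspace(1.0, 0.0, n_steps + 1)
    x = x_init.copy()
    for t_cur, t_next in zip(ts[:-1], ts[1:]):
        x = x + (t_next - t_cur) * oracle_velocity(x, t_cur)
    return x

rng = np.random.default_rng(0)
x_init = rng.normal(size=20_000)   # x_1 ~ N(0, 1), since alpha_1 = 0 and sigma_1 = 1
samples = sample_pf_ode(x_init)
print(samples.mean(), samples.std())
```

In this toy setting the samples should land close to the data mean and standard deviation. With a learned model, the same loop would call the network in whichever parameterization it was trained, which is where the practical differences listed above start to matter.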
Prop.) All Affine Flows are Equivalent
- Settings)
Prop.) Conversions between the Parameterizations
- Settings)
  - \(\mathbf{x}_t = \alpha_t\mathbf{x}_0 + \sigma_t\epsilon\) s.t.
    - \(\sigma_t\gt0\).
    - \((\alpha_t'\sigma_t - \alpha_t\sigma_t')\ne0\).
  - Oracle targets are given by
    - \(\epsilon^*(\mathbf{x}_t, t) = \mathbb{E}[\epsilon\mid\mathbf{x}_t]\).
    - \(\mathbf{x}_0^*(\mathbf{x}_t, t) = \mathbb{E}[\mathbf{x}_0\mid\mathbf{x}_t]\).
    - \(\mathbf{v}^*(\mathbf{x}_t, t) = \mathbb{E}[\alpha_t'\mathbf{x}_0 + \sigma_t' \epsilon\mid\mathbf{x}_t]\).
- Conversions
  - Score - Noise
    - Parameterization
      - \(s_\phi \equiv -\frac{1}{\sigma_t}\epsilon_\phi\).
    - Loss
      - \(\left\Vert s_\phi - \nabla_{\mathbf{x}_t}\log p_t(\mathbf{x}_t) \right\Vert_2^2 = \frac{1}{\sigma_t^2} \left\Vert \epsilon_{\phi} - \epsilon^* \right\Vert_2^2\).
  - Score - $\mathbf{x}$
    - Parameterization
      - \(s_\phi \equiv \frac{\alpha_t}{\sigma_t^2}\left(\mathbf{x}_\phi - \frac{\mathbf{x}_t}{\alpha_t}\right)\).
    - Loss
      - \(\left\Vert s_\phi - \nabla_{\mathbf{x}_t}\log p_t(\mathbf{x}_t) \right\Vert_2^2 = \frac{\alpha_t^2}{\sigma_t^4} \left\Vert \mathbf{x}_\phi - \mathbf{x}_0^* \right\Vert_2^2\).
  - Score - $\mathbf{v}$
    - Parameterization
      - \(s_\phi \equiv \frac{\alpha_t}{\sigma_t(\alpha_t'\sigma_t - \alpha_t\sigma_t')}\mathbf{v}_\phi - \frac{\alpha_t'}{\sigma_t(\alpha_t'\sigma_t - \alpha_t\sigma_t')}\mathbf{x}_t\).
    - Loss
      - \(\left\Vert s_\phi - \nabla_{\mathbf{x}_t}\log p_t(\mathbf{x}_t) \right\Vert_2^2 = \left(\frac{\alpha_t}{\sigma_t(\alpha_t'\sigma_t - \alpha_t\sigma_t')}\right)^2 \left\Vert \mathbf{v}_\phi - \mathbf{v}^*\right\Vert_2^2\).
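
The conversions are plain linear algebra at a fixed $t$, so they can be checked numerically. The sketch below is only an illustration. The schedule values, the helper names `score_from_eps`, `score_from_x`, `score_from_v`, and the perturbed "model" output are all hypothetical. A consistent sample $(\mathbf{x}_0, \epsilon)$ stands in for the oracle targets, which is enough because the identities are algebraic.

```python
import numpy as np

def score_from_eps(eps_pred, sigma):
    """Score - Noise: s = -eps / sigma."""
    return -eps_pred / sigma

def score_from_x(x_pred, x_t, alpha, sigma):
    """Score - x: s = (alpha / sigma^2) * (x_pred - x_t / alpha)."""
    return alpha / sigma**2 * (x_pred - x_t / alpha)

def score_from_v(v_pred, x_t, alpha, sigma, d_alpha, d_sigma):
    """Score - v: s = (alpha v - alpha' x_t) / (sigma (alpha' sigma - alpha sigma'))."""
    denom = sigma * (d_alpha * sigma - alpha * d_sigma)
    return (alpha * v_pred - d_alpha * x_t) / denom

# Hypothetical schedule values at one fixed time t; alpha' sigma - alpha sigma' = -1.5 != 0.
alpha, sigma, d_alpha, d_sigma = 0.8, 0.6, -0.9, 1.2

rng = np.random.default_rng(0)
x0, eps = rng.normal(size=5), rng.normal(size=5)
x_t = alpha * x0 + sigma * eps
v = d_alpha * x0 + d_sigma * eps
s = -eps / sigma

# A made-up prediction: the true noise plus a small error, mapped to x- and v-form.
eps_pred = eps + 0.1 * rng.normal(size=5)
x_pred = (x_t - sigma * eps_pred) / alpha
v_pred = d_alpha * x_pred + d_sigma * eps_pred
s_pred = score_from_eps(eps_pred, sigma)

# Squared score error versus the rescaled errors in the other three parameterizations.
score_err = np.sum((s_pred - s) ** 2)
k_v = alpha / (sigma * (d_alpha * sigma - alpha * d_sigma))
rescaled = {
    "eps": np.sum((eps_pred - eps) ** 2) / sigma**2,
    "x": alpha**2 / sigma**4 * np.sum((x_pred - x0) ** 2),
    "v": k_v**2 * np.sum((v_pred - v) ** 2),
}
print(score_err, rescaled)
```

All four numbers agree up to floating-point error, so the losses differ only by a time-dependent weight. Note that the $\mathbf{v}$ weight enters squared, like the other two.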
Prop.) Unified perspective connecting the variational, SDE, and ODE formulations through the continuity equation, where every marginal \(p_t(\mathbf{x})\) evolves under the same shared dynamics.
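
For orientation, the standard form of this statement can be written out as follows. The drift and diffusion symbols $f(t)$ and $g(t)$ are notation assumed here and do not appear in the summary above. For the kernel $\mathbf{x}_t = \alpha_t\mathbf{x}_0 + \sigma_t\epsilon$ they are commonly taken as $f(t) = \alpha_t'/\alpha_t$ and $g^2(t) = 2\sigma_t\sigma_t' - 2\frac{\alpha_t'}{\alpha_t}\sigma_t^2$.

$$\begin{aligned} \partial_t p_t(\mathbf{x}) &= -\nabla_{\mathbf{x}}\cdot\big(f(t)\mathbf{x}\,p_t(\mathbf{x})\big) + \tfrac{1}{2}g^2(t)\,\Delta_{\mathbf{x}} p_t(\mathbf{x}) \\ &= -\nabla_{\mathbf{x}}\cdot\Big(p_t(\mathbf{x})\,\underbrace{\big[f(t)\mathbf{x} - \tfrac{1}{2}g^2(t)\nabla_{\mathbf{x}}\log p_t(\mathbf{x})\big]}_{\mathbf{v}^*(\mathbf{x}, t)}\Big) \end{aligned}$$

The first line is the Fokker-Planck equation of the forward SDE. The second line rewrites it as a continuity equation whose velocity field is the PF-ODE drift, which is why the SDE and ODE views share the same marginals \(p_t(\mathbf{x})\).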