(DM Reconst.) [WIP] Ch.10 Distillation-Based Methods for Fast Sampling
Diffusion Model Conceptual Reconstruction following The Principles of Diffusion Models
Hozy Summary
Concept) Distillation
- Problem)
- Diffusion model’s trade-off between quality and efficiency.
- DDPM’s \(\mathbf{x}\)-prediction perspective.
- We use \(\mathbf{x}_{\phi^\times}(\mathbf{x}_t, t)\approx\mathbb{E}[\mathbf{x}_0\mid\mathbf{x}_t]\) for one-step generation, where this denoiser averages over many plausible outcomes.
- Thus, few denoising steps leads overly smooth prediction and the blurry, low quality samples.
- ODE/SDE Trajectory Perspective)
- Reducing the NFE speeds up the generation but reduces the fidelity.
- Why?)
- Each solver step introduces an integration error of order \(\mathcal{O}(h^n)\)
- where
- \(h=\max_i\left\vert t_i-t_{i-1}\right\vert\) : the maximum step size
- \(n\) : the solver order.
- where
- Thus, \((\text{Fewer Steps})\Rightarrow h\uparrow \Rightarrow \mathcal{O}(h^n) \uparrow\)
- Each solver step introduces an integration error of order \(\mathcal{O}(h^n)\)
- Why?)
- Reducing the NFE speeds up the generation but reduces the fidelity.
- DDPM’s \(\mathbf{x}\)-prediction perspective.
- Diffusion model’s trade-off between quality and efficiency.
- Idea)
- Accelerate diffusion model sampling by teaching new generators to produce samples in only one or a few steps.
- How?)
- Teacher Sampler : Slow, pre-trained diffusion model
- Student model : Learn from the teacher sampler
- How?)
- Accelerate diffusion model sampling by teaching new generators to produce samples in only one or a few steps.
- Paradigms)
Concept) Distribution Level Distillation
- Model)
- Given
- \(p_{\text{prior}}, p_{\text{data}}\) : a prior noise distribution and the data distribution
- We want to train a one-step generator \(G_\theta(\mathbf{z})\)
- where
- \(\mathbf{z}\sim p_{\text{prior}}\) : a noisy latent
- by optimizing as
- \(\min_{\theta} \mathcal{D}\left(p_\theta(\hat{\mathbf{x}}), p_{\text{data}}(\hat{\mathbf{x}})\right)\).
- where
- \(\hat{\mathbf{x}} = G_\theta(\mathbf{z})\) : a sample from \(G_\theta(\mathbf{z})\)
- \(\mathcal{D}\) denotes a suitable divergence measurement such as KL.
- where
- \(\min_{\theta} \mathcal{D}\left(p_\theta(\hat{\mathbf{x}}), p_{\text{data}}(\hat{\mathbf{x}})\right)\).
- where
- In practice, we align the generator’s distribution with the empirical distrbution \(p_{\phi^\times}(\mathbf{x})\)
- i.e.) \(\min_{\theta} \mathcal{D}\left(p_\theta(\hat{\mathbf{x}}), p_{\phi^\times}(\hat{\mathbf{x}})\right)\)
- Given
- Approaches)
- Distributional Matching Distillation (DMD) (Yin et al. 2024)
- Variational Score Distillation (VSD)
- Score Identity Distillation (SiD)
Concept) Flow Map Level Distillation
- Model)
- Consider the PF-ODE of
- \(\displaystyle\frac{\text{d}\mathbf{x}(\tau)}{\text{d}\tau} = f(\tau)\mathbf{x}(\tau) - \frac{1}{2}g^2(\tau)\nabla_{\mathbf{x}} \log p_\tau(\mathbf{x}(\tau))=: \mathbf{v}^*(\mathbf{x}(\tau), \tau)\).
- Its solution map is denoted as
- \(\displaystyle \Psi_{s\rightarrow t}(\mathbf{x}_s) := \mathbf{x}_s + \int_s^t v^*(\mathbf{x}(\tau), \tau)\text{d}\tau\).
- where
- \(t\le s\).
- \(\mathbf{x}_T\sim p_{\text{prior}}, \mathbf{x}_0\sim p_{\text{data}}\).
- where
- \(\displaystyle \Psi_{s\rightarrow t}(\mathbf{x}_s) := \mathbf{x}_s + \int_s^t v^*(\mathbf{x}(\tau), \tau)\text{d}\tau\).
- We want to learn a map \(\Psi_{T\rightarrow0}(\mathbf{x}_T)\) that enables the one-step generation.
- However, the PF-ODE rarely admits a closed form. Thus, we approximate it numerically during training, and denote it as the general solver of
- \(\text{Solver}_{s\rightarrow t}(\mathbf{x}_s;\;\phi^\times)\) or simply \(\text{Solver}_{s\rightarrow t}(\mathbf{x}_s)\)
- i.e.) the numerical integration from \(s\) to \(t\) starting at \(\mathbf{x}_s\), with teacher parameters \(\phi^\times\)
- \(\text{Solver}_{s\rightarrow t}(\mathbf{x}_s;\;\phi^\times)\) or simply \(\text{Solver}_{s\rightarrow t}(\mathbf{x}_s)\)
- The loss can be denoted as
- \(\mathcal{L}_{\text{oracle}}(\theta) := \mathbb{E}_{s,t}\mathbb{E}_{\mathbf{x}_s\sim p_s}\left[ w(s, t) \; d(G_\theta(\mathbf{x}_s, s, t), \Psi_{s\rightarrow t}(\mathbf{x}_s)) \right]\).
- where
- \(w(s, t)\) : the weight on the time pairs \((s,t)\)
- \(d(\cdot,\;\cdot)\): : a distance metric
- where
- \(\mathcal{L}_{\text{oracle}}(\theta) := \mathbb{E}_{s,t}\mathbb{E}_{\mathbf{x}_s\sim p_s}\left[ w(s, t) \; d(G_\theta(\mathbf{x}_s, s, t), \Psi_{s\rightarrow t}(\mathbf{x}_s)) \right]\).
- Consider the PF-ODE of
- Approaches)
Concept) Knowledge Distillation
Luhman and Luhman, 2021
- Goal)
- Train a generator to imitate the output of a numerical solver evaluated along the full trajectory : \(0\rightarrow T\).
- i.e.) \(G_\phi(\mathbf{x}_T, T, 0) \approx\text{Solver}_{T\rightarrow0}(\mathbf{x}_T)\) for \(\mathbf{x}_T\sim p_{\text{prior}}\)
- Train a generator to imitate the output of a numerical solver evaluated along the full trajectory : \(0\rightarrow T\).
- Loss)
- \(\mathcal{L}_{\text{KD}} := \mathbb{E}_{\mathbf{x}_T\sim p_{\text{prior}}}\left\Vert G_\phi(\mathbf{x}_T, T, 0) - \text{Solver}_{T\rightarrow0}(\mathbf{x}_T) \right\Vert_2^2\).
- Advantage)
- Direct supervision from the pre-trained teacher \(\text{Solver}\)
- Drawback)
- Cannot leverage the strong supervision from the original training data.
- Computationally expensive.
- Why?) ODE integration on the Solver is required during the \(G_\phi\)’s training.
- The only option of the global mapping of \(T\rightarrow0\), does not provide controllability.
- Thus, controllable generation techniques cannot be applied.
Concept) Progressive Distillation
Salimans and Ho, 2021
- Goal)
- Train a time-conditional
Studentusing local supervision fromTeacherfragments.
- Train a time-conditional
- Idea)
- \(t_0=T\gt t_1\gt\cdots\gt t_N = 0\) : the fixed time grid
- \(\text{Teacher}_{t_k\rightarrow t_{k+1}}\) : Teacher model that provides time-stepping maps for \(k=0,\ldots,N-1\)
- Repeat
- \(\text{Student}_{t_k\rightarrow t_{k+2}} \approx \text{Teacher}_{t_k\rightarrow t_{k+1}} \circ \text{Teacher}_{t_{k+1}\rightarrow t_{k+2}}\).
- i.e.) Student learns the two-step skip map for \(k=0,2,4,\ldots, N-1\).
- \(\text{Teacher}\leftarrow\text{Student}\).
- \(\text{Student}_{t_k\rightarrow t_{k+2}} \approx \text{Teacher}_{t_k\rightarrow t_{k+1}} \circ \text{Teacher}_{t_{k+1}\rightarrow t_{k+2}}\).
Enjoy Reading This Article?
Here are some more articles you might like to read next: