One-step Diffusion with Distribution Matching Distillation (DMD)
Hozy Summary
Concept) Distribution Based Distillation
- Settings)
- \(\mu_{\text{base}}\) : a pre-trained diffusion denoiser
- Desc.)
- Mean-prediction: \(\mu_{\text{base}}(x_t, t)\) estimates the mean of the clean image \(x_0\) given \(x_t\).
- Denoise \(x_T\sim\mathcal{N}(\mathbf{0,I})\) to \(x_0\sim p_{\text{real}}\)
- where \(T=1000\).
- Architecture : EDM, Stable Diffusion
- \(G_\theta\) : “one-step” image generator
- Desc.)
- Outputs a fake image
- Trained with two objectives: distribution matching and regression
- Has the architecture of \(\mu_{\text{base}}\).
- Parameters \(\theta\) are initialized with the parameters of \(\mu_{\text{base}}\).
- i.e.) At initialization, \(G_\theta(z) = \mu_{\text{base}}(z, T-1),\;\forall z\)
- Loss)
- \(\mathcal{L} = \mathcal{D}_{\text{KL}} + \lambda_{\text{reg}} \mathcal{L}_{\text{reg}}\).
- where
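- \(\mathcal{D}_{\text{KL}} = \mathcal{D}_{\text{KL}}(p_{\text{fake}} \Vert p_{\text{real}})\) is minimized through its gradient \(\nabla_\theta \mathcal{D}_{\text{KL}}\) (the Distribution Matching Objective), which requires a fake denoiser \(\mu_{\text{fake}}^{\phi}\) trained by minimizing \(\mathcal{L}_{\text{denoise}}^{\phi}\).
- \(\lambda_{\text{reg}} = 0.25\),
- \(\mathcal{L}_{\text{reg}}\) is the regression loss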
- Details)
- Classifier-free guidance (CFG) is used.
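The generator initialization above can be sketched in Python. This is a minimal, hypothetical sketch: `ToyDenoiser`, `OneStepGenerator`, and the element-wise affine "network" inside are stand-ins for the real EDM / Stable Diffusion U-Net, chosen only so the code runs. It shows that copying the parameters of \(\mu_{\text{base}}\) and freezing the time step at \(T-1\) gives \(G_\theta(z) = \mu_{\text{base}}(z, T-1)\) at the start of training, while later updates to \(\theta\) leave \(\mu_{\text{base}}\) untouched.

```python
import copy
from dataclasses import dataclass
from typing import List

T = 1000  # number of diffusion time steps, as in the notes


@dataclass
class ToyDenoiser:
    """Hypothetical stand-in for a mean-prediction denoiser mu(x_t, t).

    A real mu_base is a U-Net; here it is an element-wise affine map whose
    scale depends on t, purely so that the example is runnable.
    """
    weight: float = 0.5
    bias: float = 0.1

    def __call__(self, x_t: List[float], t: int) -> List[float]:
        scale = self.weight * (1.0 - t / (2 * T))
        return [scale * v + self.bias for v in x_t]


class OneStepGenerator:
    """G_theta: same architecture as mu_base, with the time step fixed to T - 1."""

    def __init__(self, mu_base: ToyDenoiser) -> None:
        # Deep-copy the parameters: theta starts equal to mu_base's parameters,
        # but training theta later must not modify mu_base itself.
        self.net = copy.deepcopy(mu_base)

    def __call__(self, z: List[float]) -> List[float]:
        return self.net(z, T - 1)


mu_base = ToyDenoiser()
G = OneStepGenerator(mu_base)
# The fake denoiser mu_fake^phi is initialized the same way (a copy of mu_base).
mu_fake = copy.deepcopy(mu_base)

z = [1.0, -2.0, 0.5]
assert G(z) == mu_base(z, T - 1)  # identical outputs at initialization
G.net.weight = 0.9                # simulate a training update of theta
assert mu_base.weight == 0.5      # mu_base is unchanged
```

The deep copy is the design choice that matters here: \(G_\theta\) and \(\mu_{\text{fake}}^\phi\) start from the same weights as \(\mu_{\text{base}}\), but the real score must stay frozen throughout distillation.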
Tech.) Distribution Matching Loss
- Loss)
- \(\nabla_\theta \mathcal{D}_{\text{KL}} \simeq \mathbb{E}_{z,t,x,x_t} \left[ w_t \alpha_t \left( s_{\text{fake}}(x_t, t) - s_{\text{real}}(x_t, t) \right) \frac{\text{d}G}{\text{d}\theta} \right]\).
- where
- \(z\sim\mathcal{N}(\mathbf{0,I})\),
- \(x=G_\theta(z)\) : one-step sample
- \(t\sim\mathcal{U}(T_{\min}, T_{\max})\) : a time step
- Authors used \(T_{\min} = 0.02 T, T_{\max} = 0.98T\)
- \(x_t\sim q_t(x_t\mid x)\) : a noise-perturbed version of the generated sample \(x\)
- \(s_{\text{real}}(x) = \nabla_x \log p_{\text{real}}(x)\), \(s_{\text{fake}}(x) = \nabla_x \log p_{\text{fake}}(x)\) : scores from two distributions
- \(w_t\) : a weight schedule
- Authors used \(w_t = \displaystyle\frac{\sigma_t^2}{\alpha_t}\frac{CS}{\Vert\mu_{\text{base}}(x_t, t) - x\Vert_1}\)
- where \(S\) is the number of spatial locations and \(C\) is the number of channels
- Derivation)
- We want the fake distribution to match the real one, so we minimize the KL divergence between them:
\(\begin{aligned} \mathcal{D}_{\text{KL}}(p_{\text{fake}} \Vert p_{\text{real}}) &= \mathbb{E}_{x\sim p_{\text{fake}}} \left[ \log\frac{p_{\text{fake}}(x)}{p_{\text{real}}(x)} \right] \\ &= \mathbb{E}_{z\sim\mathcal{N}(\mathbf{0, I}), x=G_\theta(z)} \left[ - \left( \log p_{\text{real}}(x) - \log p_{\text{fake}}(x) \right) \right] \end{aligned}\)
- Computing this directly is intractable, but we only need its gradient w.r.t. \(\theta\), given by
- \(\nabla_\theta \mathcal{D}_{\text{KL}} = \mathbb{E}_{z\sim\mathcal{N}(\mathbf{0, I}), x=G_\theta(z)} \left[ - \left( s_{\text{real}}(x) - s_{\text{fake}}(x) \right) \frac{\text{d}G}{\text{d}\theta} \right]\).
- for
- \(s_{\text{real}}(x) = \nabla_x \log p_{\text{real}}(x)\) and \(s_{\text{fake}}(x) = \nabla_x \log p_{\text{fake}}(x)\)
- cf.)
- \(s_{\text{real}}\) moves \(x\) towards the modes of \(p_{\text{real}}\)
- \(-s_{\text{fake}}\) spreads them apart
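- Further perturb both distributions with random Gaussian noise
- Why?)
- To alleviate the vanishing gradient problem: \(p_{\text{real}}(x)\approx 0\) for most \(x\in\mathcal{X}\), so the real score gives little useful signal there.
- Also, a pre-trained diffusion model only provides scores of the noise-perturbed (diffused) distributions.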
- i.e.)
- \(x_t\sim q_t(x_t\mid x)\) where \(q_t(x_t\mid x) = \mathcal{N}(\alpha_t x, \sigma_t^2\mathbf{I})\).
- Then, the scores can be denoted as
- \(s_{\text{real}}(x_t, t) = - \displaystyle\frac{x_t-\alpha_t \mu_{\text{base}}(x_t, t)}{\sigma_t^2}\).
- \(s_{\text{fake}}(x_t, t) = - \displaystyle\frac{x_t-\alpha_t \mu_{\text{fake}}^\phi(x_t, t)}{\sigma_t^2}\).
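- where the parameters of \(\mu_{\text{fake}}^\phi\) are initialized with the parameters of \(\mu_{\text{base}}\).
- Then, \(\phi\) is trained by minimizing the denoising objective
- \(\mathcal{L}_{\text{denoise}}^{\phi} = \left\Vert \mu_{\text{fake}}^\phi(x_t, t) - x_0 \right\Vert_2^2\),
- where \(x_0 = G_\theta(z)\) is a generated image (treated as a constant, i.e., no gradient flows to \(\theta\)) and \(x_t\sim q_t(x_t\mid x_0)\).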
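The distribution matching term and the fake-denoiser objective can be sketched in Python by plugging the two denoisers into the score formulas above. This is a minimal sketch under stated assumptions: images are flattened lists of \(C\cdot S\) floats, the noise-schedule values `alpha_t` and `sigma_t` are supplied by the caller (the notes do not fix a schedule), and `mu_base` / `mu_fake` are any callables `mu(x_t, t)`. The function names (`perturb`, `score`, `dm_gradient`, `denoise_loss`) are hypothetical. `dm_gradient` returns the per-element factor \(w_t\alpha_t(s_{\text{fake}}-s_{\text{real}})\) that an autodiff framework would back-propagate through \(\frac{\text{d}G}{\text{d}\theta}\).

```python
import random
from typing import Callable, List

Vector = List[float]
Denoiser = Callable[[Vector, int], Vector]

T = 1000
T_MIN, T_MAX = round(0.02 * T), round(0.98 * T)  # 20 and 980


def sample_timestep(rng: random.Random) -> int:
    """t ~ U(T_min, T_max), drawn over integer time steps."""
    return rng.randint(T_MIN, T_MAX)


def perturb(x: Vector, eps: Vector, alpha_t: float, sigma_t: float) -> Vector:
    """x_t = alpha_t * x + sigma_t * eps, i.e. a sample of N(alpha_t x, sigma_t^2 I)."""
    return [alpha_t * xi + sigma_t * ei for xi, ei in zip(x, eps)]


def score(mu: Denoiser, x_t: Vector, t: int, alpha_t: float, sigma_t: float) -> Vector:
    """s(x_t, t) = -(x_t - alpha_t * mu(x_t, t)) / sigma_t^2."""
    return [-(xi - alpha_t * mi) / sigma_t**2 for xi, mi in zip(x_t, mu(x_t, t))]


def dm_gradient(x: Vector, eps: Vector, t: int, alpha_t: float, sigma_t: float,
                mu_base: Denoiser, mu_fake: Denoiser) -> Vector:
    """Per-element factor w_t * alpha_t * (s_fake - s_real) that multiplies dG/dtheta."""
    x_t = perturb(x, eps, alpha_t, sigma_t)
    s_real = score(mu_base, x_t, t, alpha_t, sigma_t)
    s_fake = score(mu_fake, x_t, t, alpha_t, sigma_t)
    # w_t = sigma_t^2 / alpha_t * CS / ||mu_base(x_t, t) - x||_1, with len(x) = C * S.
    l1 = sum(abs(mi - xi) for mi, xi in zip(mu_base(x_t, t), x))
    l1 = max(l1, 1e-12)  # assumption: guard against division by zero (not in the notes)
    w_t = sigma_t**2 / alpha_t * len(x) / l1
    return [w_t * alpha_t * (sf - sr) for sf, sr in zip(s_fake, s_real)]


def denoise_loss(mu_fake: Denoiser, x0: Vector, eps: Vector, t: int,
                 alpha_t: float, sigma_t: float) -> float:
    """L_denoise^phi = ||mu_fake(x_t, t) - x0||_2^2 for a detached generated image x0."""
    x_t = perturb(x0, eps, alpha_t, sigma_t)
    return sum((mi - xi) ** 2 for mi, xi in zip(mu_fake(x_t, t), x0))
```

Subtracting the two score formulas gives \(s_{\text{fake}}-s_{\text{real}} = \alpha_t\left(\mu_{\text{fake}}^\phi(x_t,t)-\mu_{\text{base}}(x_t,t)\right)/\sigma_t^2\), so in this sketch the gradient is exactly zero whenever the fake denoiser agrees with the base one.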
Tech.) Regression Loss
- Goal)
- Problem)
- When \(t\approx 0\) (i.e., \(x_t\) is corrupted with only a low level of noise), \(s_{\text{real}}(x_t, t)\) becomes unreliable as \(p_{\text{real}}(x_t)\approx0\).
- The optimization is susceptible to mode collapse/dropping, where the fake distribution assigns higher overall density to a subset of the modes.
- i.e.) The score is invariant to scaling of the probability density function, since \(\nabla_x\log\left(c\,p(x)\right) = \nabla_x\log p(x)\) for any constant \(c>0\).
- Sol.)
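- Precompute pairs of a random Gaussian noise \(z\) and the image \(y\) that the pre-trained model generates from it with a deterministic multi-step solver, then minimize the distance between the one-step output \(G_\theta(z)\) and \(y\).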
- Loss)
- \(\mathcal{L}_{\text{reg}} = \mathbb{E}_{(z,y)\sim\mathcal{D}} \ell(G_\theta(z), y)\).
- where
- \(z\sim\mathcal{N}(\mathbf{0,I})\),
- \(y\) : the image generated from \(z\) by \(\mu_{\text{base}}\) with a deterministic multi-step solver
- Authors used…
- Heun solver from EDM
- on CIFAR-10 with 18 steps
- on ImageNet with 256 steps
- PNDM solver on LAION with 50 steps
- \(\ell(\cdot, \cdot)\) : a distance metric
- Authors used LPIPS following InstaFlow and Consistency Models.
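Building the paired dataset \(\mathcal{D}\) and estimating \(\mathcal{L}_{\text{reg}}\) can be sketched as follows. Assumptions: `teacher_sampler` is a hypothetical callable standing in for the deterministic multi-step solver (Heun or PNDM) run with \(\mu_{\text{base}}\), and because LPIPS needs a pretrained feature network, this sketch swaps in plain mean squared error as \(\ell\). The helper names are illustrative only.

```python
import random
from typing import Callable, List, Sequence, Tuple

Vector = List[float]
Pair = Tuple[Vector, Vector]


def build_regression_pairs(teacher_sampler: Callable[[Vector], Vector],
                           num_pairs: int, dim: int, seed: int = 0) -> List[Pair]:
    """Precompute D = {(z, y)}: z ~ N(0, I), y = deterministic multi-step sample from z."""
    rng = random.Random(seed)
    pairs = []
    for _ in range(num_pairs):
        z = [rng.gauss(0.0, 1.0) for _ in range(dim)]
        pairs.append((z, teacher_sampler(z)))
    return pairs


def mse(a: Vector, b: Vector) -> float:
    """Stand-in for LPIPS: mean squared error between two flattened images."""
    return sum((ai - bi) ** 2 for ai, bi in zip(a, b)) / len(a)


def regression_loss(generator: Callable[[Vector], Vector], pairs: Sequence[Pair],
                    distance: Callable[[Vector, Vector], float] = mse) -> float:
    """L_reg = E_{(z, y) ~ D}[distance(G_theta(z), y)], averaged over the given pairs."""
    return sum(distance(generator(z), y) for z, y in pairs) / len(pairs)
```

Since the solver is deterministic, each \(z\) has a single target \(y\), so the pairs can be generated once before distillation and reused.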
Algorithm)
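The alternating update can be sketched as a single training iteration. This is a hypothetical skeleton, not the authors' code: every callable (`generator`, `dm_grad`, `reg_loss`, `update_generator`, `update_fake`) is an assumed interface that an autodiff framework would implement. The order shown follows the losses described above: update \(G_\theta\) with the distribution matching term plus \(\lambda_{\text{reg}}\mathcal{L}_{\text{reg}}\), then update \(\mu_{\text{fake}}^\phi\) on the detached generated images.

```python
from typing import Callable, List, Sequence, Tuple

Vector = List[float]
Pair = Tuple[Vector, Vector]


def dmd_training_step(
    z_batch: Sequence[Vector],
    pair_batch: Sequence[Pair],
    generator: Callable[[Vector], Vector],
    dm_grad: Callable[[Vector], Vector],
    reg_loss: Callable[[Sequence[Pair]], float],
    update_generator: Callable[[Sequence[Vector], List[Vector], float], None],
    update_fake: Callable[[List[Vector]], None],
    lambda_reg: float = 0.25,
) -> float:
    """One alternating DMD iteration over hypothetical, framework-agnostic callables.

    In a real implementation, dm_grad and reg_loss would produce tensors that an
    autodiff framework back-propagates into theta; here they are plain values.
    """
    # 1. One-step samples x = G_theta(z) from fresh Gaussian noise.
    fake_images = [generator(z) for z in z_batch]
    # 2. Distribution matching: per-sample factor w_t * alpha_t * (s_fake - s_real).
    dm_factors = [dm_grad(x) for x in fake_images]
    # 3. Regression loss on precomputed (z, y) pairs.
    loss_reg = reg_loss(pair_batch)
    # 4. Update theta with the distribution matching term plus lambda_reg * L_reg.
    update_generator(z_batch, dm_factors, lambda_reg * loss_reg)
    # 5. Update phi with L_denoise on copies of the generated images, standing in
    #    for "detach": this step must not send gradients back into theta.
    update_fake([list(x) for x in fake_images])
    return loss_reg
```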