One-step Diffusion with Distribution Matching Distillation (DMD)

Yin et al. 2024



Hozy Summary




Concept) Distribution Based Distillation

  • Settings)
    • \(\mu_{\text{base}}\) : a pre-trained diffusion denoiser
      • Desc.)
        • Mean-prediction.
        • Denoise \(x_T\sim\mathcal{N}(\mathbf{0,I})\) to \(x_0\sim p_{\text{real}}\)
          • where \(T=1000\).
        • Architecture : EDM, Stable Diffusion
    • \(G_\theta\) : “one-step” image generator
      • Desc.)
        • Outputs a fake image
        • Trained with two objectives: the distribution matching loss and the regression loss.
        • Has the architecture of \(\mu_{\text{base}}\).
          • Parameters \(\theta\) are initialized from \(\mu_{\text{base}}\).
            • i.e.) \(G_\theta(z) = \mu_{\text{base}}(z, T-1),\;\forall z\)
  • Loss)
    • \(\mathcal{L} = \mathcal{D}_{\text{KL}} + \lambda_{\text{reg}} \mathcal{L}_{\text{reg}}\).
      • where
        • \(\mathcal{D}_{\text{KL}}\) is the distribution matching objective. It is never evaluated directly; only its gradient \(\nabla_\theta \mathcal{D}_{\text{KL}}\) is approximated, using a fake denoiser trained by minimizing \(\mathcal{L}_{\text{denoise}}^{\phi}\).
        • \(\lambda_{\text{reg}} = 0.25\),
        • \(\mathcal{L}_{\text{reg}}\) is the regression loss.
  • Details)
    • Classifier-free guidance (CFG) is used.
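
The initialization \(G_\theta(z) = \mu_{\text{base}}(z, T-1)\) can be sketched as follows. This is a minimal illustration, not the authors' code: `ToyDenoiser` and `init_generator` are hypothetical names, and the toy denoiser only stands in for a real EDM or Stable Diffusion network.

```python
import copy
import numpy as np

T = 1000  # number of diffusion time steps, as in the notes

class ToyDenoiser:
    """Hypothetical stand-in for mu_base: a mean-prediction denoiser mu(x_t, t)."""
    def __init__(self, scale):
        self.scale = scale

    def __call__(self, x_t, t):
        return self.scale * x_t / (1.0 + t / T)

def init_generator(mu_base):
    """Build G_theta as a copy of mu_base evaluated at the fixed time step T - 1."""
    theta = copy.deepcopy(mu_base)       # theta starts from the teacher's parameters

    def generator(z):
        return theta(z, T - 1)           # one network call, no iterative sampling

    return generator, theta

mu_base = ToyDenoiser(scale=0.5)
G, theta = init_generator(mu_base)
z = np.random.default_rng(0).standard_normal((4, 3))
```

Because `theta` is a separate copy, later updates to the generator leave the frozen teacher `mu_base` unchanged.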


Tech.) Distribution Matching Loss

  • Loss)
    • \(\nabla_\theta \mathcal{D}_{\text{KL}} \simeq \mathbb{E}_{z,t,x,x_t} \left[ w_t \alpha_t \left( s_{\text{fake}}(x_t, t) - s_{\text{real}}(x_t, t) \right) \frac{\text{d}G}{\text{d}\theta} \right]\).
      • where
        • \(z\sim\mathcal{N}(\mathbf{0,I})\),
        • \(x=G_\theta(z)\) : one-step sample
        • \(t\sim\mathcal{U}(T_{\min}, T_{\max})\) : a time step
          • Authors used \(T_{\min} = 0.02 T, T_{\max} = 0.98T\)
        • \(x_t\sim q_t(x_t\mid x)\) : a perturbed training sample
        • \(s_{\text{real}}(x) = \nabla_x \log p_{\text{real}}(x)\), \(s_{\text{fake}}(x) = \nabla_x \log p_{\text{fake}}(x)\) : scores from two distributions
        • \(w_t\) : a weight schedule
          • Authors used \(w_t = \displaystyle\frac{\sigma_t^2}{\alpha_t}\frac{CS}{\Vert\mu_{\text{base}}(x_t, t) - x\Vert_1}\)
            • for the number of spatial locations \(S\) and the number of channels \(C\)
  • Derivation)
    • Start from the KL divergence between the fake distribution and the real one.
      \(\begin{aligned} \mathcal{D}_{\text{KL}}(p_{\text{fake}} \Vert p_{\text{real}}) &= \mathbb{E}_{x\sim p_{\text{fake}}} \left[ \log\frac{p_{\text{fake}}(x)}{p_{\text{real}}(x)} \right] \\ &= \mathbb{E}_{z\sim\mathcal{N}(\mathbf{0, I}), x=G_\theta(z)} \left[ - \left( \log p_{\text{real}}(x) - \log p_{\text{fake}}(x) \right) \right] \end{aligned}\).
    • This is intractable, but we only need the gradient w.r.t. \(\theta\) as
      • \(\nabla_\theta \mathcal{D}_{\text{KL}} = \mathbb{E}_{z\sim\mathcal{N}(\mathbf{0, I}), x=G_\theta(z)} \left[ - \left( s_{\text{real}}(x) - s_{\text{fake}}(x) \right) \frac{\text{d}G}{\text{d}\theta} \right]\).
        • for
          • \(s_{\text{real}}(x) = \nabla_x \log p_{\text{real}}(x)\) and \(s_{\text{fake}}(x) = \nabla_x \log p_{\text{fake}}(x)\)
            • cf.)
              • \(s_{\text{real}}\) moves \(x\) towards the modes of \(p_{\text{real}}\)
              • \(-s_{\text{fake}}\) spreads them apart
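    • Further perturb the data distribution with random Gaussian noise
      • Why?)
        • \(p_{\text{real}}(x)\approx 0\) for most fake samples \(x\in\mathcal{X}\), so \(s_{\text{real}}(x)\) diverges or is unreliable there. Adding noise makes the perturbed real and fake distributions overlap, so both scores are well defined.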
      • i.e.)
        • \(x_t\sim q_t(x_t\mid x)\) where \(q_t(x_t\mid x) = \mathcal{N}(x_t;\, \alpha_t x, \sigma_t^2\mathbf{I})\).
        • Then, the scores can be denoted as
          • \(s_{\text{real}}(x_t, t) = - \displaystyle\frac{x_t-\alpha_t \mu_{\text{base}}(x_t, t)}{\sigma_t^2}\).
          • \(s_{\text{fake}}(x_t, t) = - \displaystyle\frac{x_t-\alpha_t \mu_{\text{fake}}^\phi(x_t, t)}{\sigma_t^2}\).
            • where \(\mu_{\text{fake}}^\phi\)’s parameters are first initialized with the parameters from \(\mu_{\text{base}}\).
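            • Then, \(\phi\) is trained by minimizing the denoising objective below.
              • \(\mathcal{L}_{\text{denoise}}^{\phi} = \left\Vert \mu_{\text{fake}}^\phi(x_t, t) - x_0 \right\Vert_2^2\).
                • where \(x_0 = G_\theta(z)\) is a generated (fake) image and \(x_t\) is its noised version.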
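
The per-sample factor \(w_t \alpha_t (s_{\text{fake}} - s_{\text{real}})\) from the loss above can be sketched in NumPy. This is a toy under stated assumptions, not the authors' implementation: both denoisers are closed-form Gaussian denoisers (`gaussian_denoiser` is a hypothetical helper), images are flattened to rows of length \(CS\), and a single noise level \((\alpha_t, \sigma_t) = (0.8, 0.6)\) is chosen arbitrarily.

```python
import numpy as np

def gaussian_denoiser(mean, std):
    """Closed-form mean-prediction denoiser E[x | x_t] when x ~ N(mean, std^2 I)."""
    def mu(x_t, alpha_t, sigma_t):
        gain = alpha_t * std**2 / (alpha_t**2 * std**2 + sigma_t**2)
        return mean + gain * (x_t - alpha_t * mean)
    return mu

def score_from_mean(mu_pred, x_t, alpha_t, sigma_t):
    """s(x_t, t) = -(x_t - alpha_t * mu(x_t, t)) / sigma_t^2."""
    return -(x_t - alpha_t * mu_pred) / sigma_t**2

def dm_direction(x, mu_base, mu_fake, alpha_t, sigma_t, rng):
    """Factor w_t * alpha_t * (s_fake - s_real) that multiplies dG/dtheta.

    x has shape (batch, C*S): one flattened generated image per row.
    """
    x_t = alpha_t * x + sigma_t * rng.standard_normal(x.shape)  # x_t ~ q_t(x_t | x)
    pred_real = mu_base(x_t, alpha_t, sigma_t)
    pred_fake = mu_fake(x_t, alpha_t, sigma_t)
    s_real = score_from_mean(pred_real, x_t, alpha_t, sigma_t)
    s_fake = score_from_mean(pred_fake, x_t, alpha_t, sigma_t)
    # w_t = sigma_t^2 / alpha_t * C*S / ||mu_base(x_t, t) - x||_1, one weight per sample
    l1 = np.abs(pred_real - x).sum(axis=1, keepdims=True)
    w_t = (sigma_t**2 / alpha_t) * x.shape[1] / l1
    return w_t * alpha_t * (s_fake - s_real), (pred_real, pred_fake, l1)

rng = np.random.default_rng(0)
alpha_t, sigma_t = 0.8, 0.6
mu_base = gaussian_denoiser(mean=0.0, std=1.0)  # toy "real" data: N(0, I)
mu_fake = gaussian_denoiser(mean=3.0, std=1.0)  # toy "fake" data: N(3, I)
x = 3.0 + rng.standard_normal((8, 16))          # samples from the fake distribution
direction, (pred_real, pred_fake, l1) = dm_direction(
    x, mu_base, mu_fake, alpha_t, sigma_t, rng
)
```

In this toy every entry of `direction` is positive, so a gradient-descent step moves the fake samples from their mean at 3 toward the real mean at 0. Substituting the two score expressions shows that \(\sigma_t^2\) cancels: the factor equals \(\alpha_t \, CS \, (\mu_{\text{fake}}^\phi - \mu_{\text{base}}) / \Vert\mu_{\text{base}}(x_t, t) - x\Vert_1\).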


Tech.) Regression Loss

  • Goal)
    • Problem)
      • When \(t\approx 0\) (samples corrupted with a low level of noise), \(s_{\text{real}}(x_t, t)\) becomes unreliable because \(p_{\text{real}}(x_t, t)\) goes to zero.
      • The optimization is susceptible to mode collapse/dropping, where the fake distribution assigns higher overall density to a subset of the modes.
        • Why?) The score \(\nabla_x \log p\) is invariant to scaling of the probability density function \(p\).
    • Sol.)
      • Minimize the distance between the one-step output \(G_\theta(z)\) and the image \(y\) that the pre-trained model generates from the same random Gaussian noise \(z\).
  • Loss)
    • \(\mathcal{L}_{\text{reg}} = \mathbb{E}_{(z,y)\sim\mathcal{D}} \ell(G_\theta(z), y)\).
      • where
        • \(z\sim\mathcal{N}(\mathbf{0,I})\),
        • \(y\sim\mu_{\text{base}}(z)\) : the image obtained by sampling the pre-trained model from \(z\) with a deterministic ODE solver
          • Authors used…
            • Heun solver from EDM
              • on CIFAR-10 with 18 steps
              • on ImageNet with 256 steps
            • PNDM solver on LAION with 50 steps
        • \(\ell(\cdot, \cdot)\) : a distance metric
          • Authors used LPIPS following InstaFlow and Consistency Model.
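
A minimal sketch of the regression loss over a paired dataset follows. Two things here are assumptions made for illustration: `teacher_sample` is a hypothetical deterministic map standing in for multi-step sampling from \(\mu_{\text{base}}\), and a plain squared error replaces LPIPS, which would require a pretrained perceptual network.

```python
import numpy as np

def teacher_sample(z):
    """Hypothetical deterministic teacher: maps noise z to an 'image' y."""
    return 2.0 * z + 1.0

def squared_error(a, b):
    """Stand-in distance per sample; the notes use LPIPS instead."""
    return np.mean((a - b) ** 2, axis=1)

def regression_loss(generator, z, y, distance=squared_error):
    """Monte Carlo estimate of E_{(z, y) ~ D} distance(G(z), y)."""
    return float(np.mean(distance(generator(z), y)))

rng = np.random.default_rng(0)
z = rng.standard_normal((32, 16))  # paired dataset D = {(z, y)}, built once offline
y = teacher_sample(z)

loss_matched = regression_loss(lambda z: 2.0 * z + 1.0, z, y)  # reproduces the teacher
loss_shifted = regression_loss(lambda z: 2.0 * z, z, y)        # misses the offset of 1
```

The pairs are generated once, so each training step only pays for one generator call and one distance evaluation rather than a full multi-step sampling run.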



Algorithm)
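
The pieces above can be combined into a toy training loop. This is a sketch under heavy simplifications, not the paper's algorithm: the data are 1-D, the generator is \(G_\theta(z) = az + b\), a single fixed noise level replaces \(t\sim\mathcal{U}(T_{\min}, T_{\max})\), \(w_t\) is set to 1, the fake denoiser is a linear model refit by least squares each step (instead of gradient steps on \(\phi\)), squared error stands in for LPIPS, and `teacher_sample` is a hypothetical deterministic teacher. The order of the two updates is also a choice of this sketch.

```python
import numpy as np

rng = np.random.default_rng(0)
ALPHA, SIGMA = 0.8, 0.6          # one fixed noise level (assumption of this sketch)
REAL_MEAN, REAL_STD = 2.0, 0.5   # toy "real" distribution N(2, 0.5^2)
LAMBDA_REG = 0.25

def mu_real(x_t):
    """Closed-form mean-prediction denoiser of the toy real distribution."""
    gain = ALPHA * REAL_STD**2 / (ALPHA**2 * REAL_STD**2 + SIGMA**2)
    return REAL_MEAN + gain * (x_t - ALPHA * REAL_MEAN)

def teacher_sample(z):
    """Hypothetical deterministic teacher sampler: noise z -> sample y."""
    return REAL_MEAN + REAL_STD * z

def train(steps=300, batch=4096, lr=0.2):
    a, b = 1.0, 0.0                                  # generator G(z) = a * z + b
    for _ in range(steps):
        z = rng.standard_normal(batch)
        x = a * z + b                                # one-step samples
        x_t = ALPHA * x + SIGMA * rng.standard_normal(batch)

        # Fake denoiser: minimize ||mu_fake(x_t) - x||^2 over linear models.
        slope, intercept = np.polyfit(x_t, x, 1)
        mu_fake = slope * x_t + intercept

        # Distribution matching term with w_t = 1: alpha_t * (s_fake - s_real).
        s_real = -(x_t - ALPHA * mu_real(x_t)) / SIGMA**2
        s_fake = -(x_t - ALPHA * mu_fake) / SIGMA**2
        d = ALPHA * (s_fake - s_real)

        # Regression term on paired (z, y) samples, squared-error distance.
        z_ref = rng.standard_normal(batch)
        resid = (a * z_ref + b) - teacher_sample(z_ref)

        # Chain rule through G: dG/da = z, dG/db = 1.
        grad_a = np.mean(d * z) + LAMBDA_REG * np.mean(2.0 * resid * z_ref)
        grad_b = np.mean(d) + LAMBDA_REG * np.mean(2.0 * resid)
        a -= lr * grad_a
        b -= lr * grad_b
    return a, b

a, b = train()
```

After training, `a` and `b` should sit close to the real standard deviation 0.5 and mean 2.0, since both the score difference and the regression residual vanish there.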



