(DM Reconst.) Ch.4 Score SDE Framework

Diffusion Model Conceptual Reconstruction following The Principles of Diffusion Models



Hozy Summary

  • Training Score SDE
    • Forward-Time SDE : \(\text{d}\mathbf{x}(t) = \mathbf{f}(\mathbf{x}(t),t)\text{d}t + g(t)\text{d}\mathbf{w}(t)\)
    • Loss : \(\mathcal{L}_{\text{DSM}}(\phi;\omega(\cdot)) := \displaystyle\frac{1}{2}\mathbb{E}_{t}\mathbb{E}_{\mathbf{x}_0}\mathbb{E}_{p_t(\mathbf{x}_{t}\mid\mathbf{x}_{0})}\left[ \omega(t) \Vert s_\phi(\mathbf{x}_{t}, t) - \nabla_{\mathbf{x}_{t}}\log p_t(\mathbf{x}_{t}\mid\mathbf{x}_{0}) \Vert_2^2 \right]\).
  • Sampling Score SDE
    • Reverse-Time SDE : \(\text{d}\mathbf{x}_{\phi^\times}^{\text{SDE}}(t) = \Big[ \mathbf{f}(\mathbf{x}_{\phi^\times}^{\text{SDE}}(t), t) - g^2(t) \underbrace{s_{\phi^\times}(\mathbf{x}_{\phi^\times}^{\text{SDE}}(t), t)}_{\text{plugged in}} \Big]\text{d}t + g(t)\text{d}\bar{\mathbf{w}}(t)\)
      • Sampling Scheme : \(\mathbf{x}_{t-\Delta t} \;\leftarrow\; \mathbf{x}_{t} - \left[ \mathbf{f}(\mathbf{x}_{t}, t) - g^2(t)s_{\phi^\times}(\mathbf{x}_{t}, t) \right]\Delta t + g(t)\sqrt{\Delta t} \cdot \boldsymbol{\epsilon}\)
    • PF-ODE : \(\displaystyle\frac{\text{d}}{\text{d}t} \mathbf{x}_{\phi^\times}^{\text{ODE}}(t) = \mathbf{f}(\mathbf{x}_{\phi^\times}^{\text{ODE}}(t), t) - \frac{1}{2}g^2(t) \underbrace{s_{\phi^\times}(\mathbf{x}_{\phi^\times}^{\text{ODE}}(t), t)}_{\text{plugged in}}\)
      • Sampling Scheme : \(\mathbf{x}_{t-\Delta t} \;\leftarrow\; \mathbf{x}_{t} - \left[ \mathbf{f}(\mathbf{x}_{t}, t) - \frac{1}{2}g^2(t)s_{\phi^\times}(\mathbf{x}_{t}, t) \right]\Delta t\)
  • Instantiations of SDE



4.1 Score SDE: Its Principles

Concept) Forward-Time SDE : Data to Noise

  • Def.)
    • \(\text{d}\mathbf{x}(t) = \mathbf{f}(\mathbf{x}(t),t)\text{d}t + g(t)\text{d}\mathbf{w}(t)\).
      • where
        • $t\in[0,T]$ : a continuous timestep
        • $\mathbf{f}(\cdot, t) : \mathbb{R}^D\rightarrow\mathbb{R}^D$ : the drift
        • $g(t)\in\mathbb{R}$ : the scalar diffusion coefficient
  • Prop.)
    • The forward SDE induces the time-dependent densities
      • \(p_t(\mathbf{x}_t\mid\mathbf{x}_0)\) : Perturbation Kernel
      • \(p_t(\mathbf{x}_t)\) : Marginal Density
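As a rough illustration of the definition above, the forward SDE can be simulated with Euler-Maruyama steps. The sketch below is not from the source: the constant-$\beta$ drift and diffusion, the helper name `simulate_forward`, and all numeric values are assumptions chosen so that the perturbation kernel is known in closed form and can be compared against the simulation.

```python
import numpy as np

def simulate_forward(x0, f, g, T=1.0, n_steps=1000, seed=0):
    """Euler-Maruyama discretization of dx = f(x, t) dt + g(t) dw."""
    rng = np.random.default_rng(seed)
    dt = T / n_steps
    x = np.array(x0, dtype=float)
    for i in range(n_steps):
        t = i * dt
        eps = rng.standard_normal(x.shape)
        x = x + f(x, t) * dt + g(t) * np.sqrt(dt) * eps
    return x

# Hypothetical choice for illustration: constant beta (a VP-type SDE).
beta = 2.0
f = lambda x, t: -0.5 * beta * x
g = lambda t: np.sqrt(beta)

x0 = np.full(20000, 3.0)      # 20,000 particles that all start at x_0 = 3
xT = simulate_forward(x0, f, g)

# Closed-form kernel for this SDE at T = 1: N(x0 * exp(-beta/2), 1 - exp(-beta))
mean_exact = 3.0 * np.exp(-0.5 * beta)
var_exact = 1.0 - np.exp(-beta)
print(xT.mean(), mean_exact)
print(xT.var(), var_exact)
```

Because every particle shares the same $\mathbf{x}_0$, the empirical mean and variance of `xT` estimate the perturbation kernel $p_T(\mathbf{x}_T\mid\mathbf{x}_0)$ rather than the marginal $p_T(\mathbf{x}_T)$.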


Analysis) Deriving Forward SDE from NCSN & DDPM's Common Form
  • Forward Noise Injection Schemes
    • NCSN : \(\mathbf{x}_{\sigma_i} = \mathbf{x} + \sigma_i\epsilon_i\)
    • DDPM : \(\mathbf{x}_i = \sqrt{1-\beta_i^2}\mathbf{x}_{i-1} + \beta_i\epsilon_i\)
  • Discrete Sequential Update : $t\rightarrow t+\Delta t$
    • NCSN : $$\begin{aligned} \mathbf{x}_{t+\Delta t} &= \mathbf{x}_t + \sqrt{\sigma_{t+\Delta t}^2 - \sigma_t^2}\,\epsilon_t \\ &\approx \mathbf{x}_t + \underbrace{\sqrt{\frac{\text{d}\sigma_t^2}{\text{d}t}}}_{\text{diffusion}}\epsilon_t\sqrt{\Delta t} \end{aligned}$$
    • DDPM : writing the per-step noise variance as $\beta_t\Delta t$ and using $\sqrt{1-\beta_t\Delta t}\approx 1-\frac{1}{2}\beta_t\Delta t$, $$\begin{aligned} \mathbf{x}_{t+\Delta t} &= \sqrt{1-\beta_t\Delta t}\,\mathbf{x}_t + \sqrt{\beta_{t}\Delta t}\,\epsilon_t \\ &\approx \mathbf{x}_t \underbrace{- \frac{1}{2}\beta_t\mathbf{x}_t}_{\text{drift}}\Delta t + \underbrace{\sqrt{\beta_t}}_{\text{diffusion}}\epsilon_t\sqrt{\Delta t} \end{aligned}$$
  • Common Form
    • $$\text{d}\mathbf{x}(t) = \underbrace{\mathbf{f}(\mathbf{x}(t),t)}_{\text{drift}}\text{d}t + \underbrace{g(t)}_{\text{diffusion}}\text{d}\mathbf{w}(t)$$
    • where $\mathbf{w}(t)$ is a standard Wiener process, i.e. $\text{d}\mathbf{w}(t)\sim\mathcal{N}(\mathbf{0}, \text{d}t\mathbf{I})$
  • Corresponding Gaussian Transition
    • $$p(\mathbf{x}_{t+\Delta t}\mid\mathbf{x}_t) := \mathcal{N}\left( \mathbf{x}_{t+\Delta t};\; \mathbf{x}_t + \mathbf{f}(\mathbf{x}_t, t)\Delta t,\; g^2(t)\Delta t\mathbf{I} \right)$$
    • $$\text{where }\begin{cases} \mathbf{x}(t+\Delta t) - \mathbf{x}(t) & \approx \text{d}\mathbf{x}(t) \\ \Delta t &\approx \text{d}t \\ \sqrt{\Delta t}\, \epsilon_t \sim \mathcal{N}(\mathbf{0}, \Delta t\mathbf{I}) &\approx \text{d}\mathbf{w}(t) \end{cases}$$



Concept) Reverse-Time Stochastic Process for Generation

  • Def.)
    • \(\text{d}\bar{\mathbf{x}}(t) = \left[ \mathbf{f}(\bar{\mathbf{x}}(t), t) - g^2(t)\nabla_\mathbf{x}\log p_t(\bar{\mathbf{x}}(t)) \right]\text{d}t + g(t)\text{d}\bar{\mathbf{w}}(t)\).
      • where
        • \(\bar{\mathbf{x}}(T)\sim p_{\text{prior}} \approx p_T\).
        • \(\bar{\mathbf{w}}(t)\) denotes a standard Wiener process in reverse time s.t.
          • \(\bar{\mathbf{w}}(t) := \mathbf{w}(T-t) - \mathbf{w}(T)\).
  • Prop.)
    • Balance between the diffusion term and the score-driven drift
      • diffusion : \(g(t)\text{d}\bar{\mathbf{w}}(t)\)
        • This term
          • provides controlled stochasticity for exploration.
          • dominates at the early stage : $t\approx T$
      • score-driven drift : \(g^2(t)\nabla_\mathbf{x}\log p_t(\bar{\mathbf{x}}(t))\)
        • This term
          • guides trajectories toward regions of higher density
          • dominates at the later stage : $t\approx 0$
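To make the reverse-time SDE concrete, the following is a minimal sketch on a 1D toy problem where the true score is known. Everything here is an assumption for illustration: a constant-$\beta$ VP-type forward SDE, Gaussian data $\mathcal{N}(m_0, s_0^2)$, and the specific numbers. Under these assumptions $p_t$ stays Gaussian, so $\nabla_x\log p_t$ has a closed form and no learned model is needed.

```python
import numpy as np

beta, T = 2.0, 5.0            # hypothetical constant-beta VP SDE
m0, s0 = 2.0, 0.5             # hypothetical Gaussian data: p_0 = N(m0, s0^2)

def score(x, t):
    """Exact score of p_t = N(m_t, s_t^2) for Gaussian data under this SDE."""
    decay = np.exp(-beta * t)
    m_t = m0 * np.sqrt(decay)
    var_t = s0**2 * decay + 1.0 - decay
    return -(x - m_t) / var_t

rng = np.random.default_rng(0)
n_steps = 2000
dt = T / n_steps
x = rng.standard_normal(20000)     # x(T) ~ p_prior = N(0, 1), close to p_T

for i in range(n_steps):
    t = T - i * dt
    drift = -0.5 * beta * x - beta * score(x, t)   # f - g^2 * score
    noise = np.sqrt(beta * dt) * rng.standard_normal(x.shape)
    x = x - drift * dt + noise                     # one step from t to t - dt

print(x.mean(), x.std())   # should land near m0 and s0
```

Starting from pure noise, the samples end up distributed approximately like the data, which is the behavior the reverse-time SDE is designed to produce.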



Concept) Probability Flow ODE (PF-ODE)

  • Goal)
    • A deterministic process (ODE) that evolves samples with the same marginal distributions as the forward SDE.
  • Def.)
    • \(\displaystyle\frac{\text{d}}{\text{d}t}\tilde{\mathbf{x}}(t) = \mathbf{f}(\tilde{\mathbf{x}}(t), t) - \frac{1}{2} g^2(t) \nabla_\mathbf{x} \log p_t(\tilde{\mathbf{x}}(t))\).
  • Sample result)
    • \(\tilde{\mathbf{x}}(0) = \tilde{\mathbf{x}}(T) + \displaystyle\int_T^0 \left[ \mathbf{f}(\tilde{\mathbf{x}}(\tau), \tau) - \frac{1}{2}g^2(\tau)\nabla_\mathbf{x}\log p_\tau(\tilde{\mathbf{x}}(\tau)) \right]\text{d}\tau\quad\) where $\tilde{\mathbf{x}}(T)\sim p_{\text{prior}}$
      • Intractable
      • relies on numerical solvers
        • e.g.) Euler Method
  • Advantages over Reverse-Time SDE
    • Bidirectionality of Integration
      • i.e.) Can be integrated in either direction : $\int_0^T$ or $\int_T^0$
    • Wide range of well-established numerical solvers for ODE.
  • Prop.)
    • Derived by choosing a drift of an ODE s.t. its evolution preserves the same marginal densities as the forward SDE
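The bidirectionality noted above can be checked numerically. The sketch below assumes the same kind of toy setting as a Gaussian closed-form example (constant-$\beta$ VP-type SDE, Gaussian data; all values hypothetical) and uses a hand-written RK4 integrator, which is a choice made here and not something the source prescribes. For Gaussian marginals the PF-ODE flow is the affine map $x_t = m_t + \frac{s_t}{s_0}(x_0 - m_0)$, which gives a reference to compare against.

```python
import numpy as np

beta, T = 2.0, 1.0            # hypothetical constant-beta VP SDE
m0, s0 = 2.0, 0.5             # hypothetical Gaussian data N(m0, s0^2)

def moments(t):
    decay = np.exp(-beta * t)
    return m0 * np.sqrt(decay), s0**2 * decay + 1.0 - decay

def velocity(x, t):
    """PF-ODE right-hand side: f - (1/2) g^2 * score, with the exact score."""
    m_t, var_t = moments(t)
    score = -(x - m_t) / var_t
    return -0.5 * beta * x - 0.5 * beta * score

def integrate(x, t_start, t_end, n_steps=200):
    h = (t_end - t_start) / n_steps    # negative when integrating T -> 0
    t = t_start
    for _ in range(n_steps):
        k1 = velocity(x, t)
        k2 = velocity(x + 0.5 * h * k1, t + 0.5 * h)
        k3 = velocity(x + 0.5 * h * k2, t + 0.5 * h)
        k4 = velocity(x + h * k3, t + h)
        x = x + (h / 6.0) * (k1 + 2 * k2 + 2 * k3 + k4)
        t = t + h
    return x

x0 = np.array([1.0, 2.0, 3.5])
xT = integrate(x0, 0.0, T)          # data -> noise
x0_back = integrate(xT, T, 0.0)     # noise -> data

m_T, var_T = moments(T)
xT_closed = m_T + np.sqrt(var_T) / s0 * (x0 - m0)
print(np.abs(xT - xT_closed).max(), np.abs(x0_back - x0).max())
```

Both printed errors should be tiny: the forward pass matches the closed-form flow, and integrating back recovers the starting points, which a reverse-time SDE cannot do path by path.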


Prop.)

  • The Fokker-Planck equation ensures that the reverse-time SDE and the PF-ODE yield the same marginal distributions \(p_t(\mathbf{x}_t)\) as the forward SDE.


E.g.) Gaussian Closed Form Example

  • Settings)
    • Scalar linear SDE with Gaussian $p_{\text{data}}$, so that $x_t\sim\mathcal{N}(m_t, s_t^2)$
Reverse-Time SDE Derivation

Recall that the forward SDE is given by
$$\text{d}x(t) = f(t)x(t)\text{d}t + g(t)\text{d}w(t)$$
Taking one small Euler step of size $\Delta t\gt0$ gives
$$x_{t+\Delta t} = ax_t + r\epsilon \quad\text{where } \begin{cases} a := 1+f(t)\Delta t \\ r := g(t)\sqrt{\Delta t} \\ \epsilon\sim\mathcal{N}(0,1) \end{cases}$$
Equivalently, the forward one-step transition kernel is the Gaussian
$$x_{t+\Delta t}\mid x_t\sim\mathcal{N}(a x_t, r^2)$$
The marginal at $t$ is also Gaussian, because $p_{\text{data}}$ is Gaussian and the SDE is linear, i.e.
$$x_t\sim\mathcal{N}(m_t, s_t^2)\quad\text{for some } m_t\in\mathbb{R},\; s_t\gt0$$
By Bayes' rule, the reverse conditional density satisfies
$$\begin{aligned} p(x_t\mid x_{t+\Delta t}) &\varpropto p(x_{t+\Delta t}\mid x_t) \; p_t(x_t) \\ &\varpropto \exp\left(-\frac{(x_{t+\Delta t} - ax_t)^2}{2r^2}\right) \exp\left(-\frac{(x_t-m_t)^2}{2s_t^2}\right) \end{aligned}$$
Taking the log of both sides and collecting terms in $x_t$ (constants that do not depend on $x_t$ are absorbed into $C, C', C''$),
$$\begin{aligned} \log p(x_t\mid x_{t+\Delta t}) &= -\frac{(x_{t+\Delta t} - ax_t)^2}{2r^2} -\frac{(x_t-m_t)^2}{2s_t^2} + C \\ &=-\left(\frac{a^2}{2r^2}+\frac{1}{2s_t^2}\right)x_t^2 +\left(\frac{ax_{t+\Delta t}}{r^2} + \frac{m_t}{s_t^2}\right) x_t + C' \\ &=-\frac{A}{2}x_t^2 + B x_t + C' & \left(\text{Put } A = \frac{a^2}{r^2}+\frac{1}{s_t^2},\; B=\frac{ax_{t+\Delta t}}{r^2} + \frac{m_t}{s_t^2}\right) \\ &= -\frac{A}{2}\left(x_t-\frac{B}{A}\right)^2 + \frac{B^2}{2A} + C' \\ &= -\frac{A}{2}\left(x_t-\frac{B}{A}\right)^2 + C'' \end{aligned}$$
Since $\log p(x) = -\frac{(x-\mu)^2}{2\sigma^2}+C$ for $x\sim\mathcal{N}(\mu,\sigma^2)$, matching terms gives
$$\text{Var}(x_t\mid x_{t+\Delta t}) = \frac{1}{A} = \frac{1}{\frac{a^2}{r^2}+\frac{1}{s_t^2}},\quad \mathbb{E}[x_t\mid x_{t+\Delta t}] = \frac{B}{A} = \frac{\frac{ax_{t+\Delta t}}{r^2} + \frac{m_t}{s_t^2}}{\frac{a^2}{r^2}+\frac{1}{s_t^2}}$$
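As a sanity check on the completing-the-square result, the posterior mean $B/A$ and variance $1/A$ can be compared with the textbook conditioning formula for a jointly Gaussian pair $(x_t, x_{t+\Delta t})$, taking $a = 1 + f(t)\Delta t$ and $r = g(t)\sqrt{\Delta t}$. The numeric values below are arbitrary assumptions, not taken from the source.

```python
import numpy as np

# Hypothetical numbers for one Euler step of the scalar forward SDE.
f_t, g_t, dt = -1.0, 1.5, 0.01
m_t, s_t = 0.7, 1.2            # current marginal N(m_t, s_t^2)
y = 0.4                        # an observed value of x_{t + dt}

a = 1.0 + f_t * dt
r = g_t * np.sqrt(dt)

# Completing the square, as in the derivation above.
A = a**2 / r**2 + 1.0 / s_t**2
B = a * y / r**2 + m_t / s_t**2
post_mean, post_var = B / A, 1.0 / A

# Standard conditioning for jointly Gaussian (x_t, x_{t+dt}):
# Cov(x_t, x_{t+dt}) = a s_t^2 and Var(x_{t+dt}) = a^2 s_t^2 + r^2.
var_y = a**2 * s_t**2 + r**2
gain = a * s_t**2 / var_y
ref_mean = m_t + gain * (y - a * m_t)
ref_var = s_t**2 - gain * a * s_t**2

print(post_mean, ref_mean)
print(post_var, ref_var)
```

The two routes agree up to floating-point error, which is a quick way to catch sign or factor slips in the algebra.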

PF-ODE Derivation

Refer to p102-104
Idea : Write one deterministic step of size $\Delta t$ as a smooth map $$x_{t+\Delta t} = \Phi_{t,\Delta t}(x_t) = x_t + \Delta t\, v_t(x_t) + \mathcal{O}(\Delta t^2)$$ The goal is to find the form $v_t$ must take so that the output is Gaussian given a Gaussian input.



4.2 Training and Sampling of Score SDE

Tech.) Training Score SDE

  • Loss : DSM Loss
    • \(\mathcal{L}_{\text{DSM}}(\phi;\omega(\cdot)) := \displaystyle\frac{1}{2}\mathbb{E}_{t}\mathbb{E}_{\mathbf{x}_0}\mathbb{E}_{p_t(\mathbf{x}_{t}\mid\mathbf{x}_{0})}\left[ \omega(t) \Vert s_\phi(\mathbf{x}_{t}, t) - \nabla_{\mathbf{x}_{t}}\log p_t(\mathbf{x}_{t}\mid\mathbf{x}_{0}) \Vert_2^2 \right]\).
      • where \(\mathbf{x}_{0}\sim p_{\text{data}}\)
        • cf.) This is the continuous counterpart of the NCSN Loss of
          • \(\mathcal{L}_{\text{NCSN}}(\phi) := \displaystyle\sum_{i=1}^L \lambda(\sigma_i) \mathcal{L}_{\text{DSM}}(\phi;\sigma_i)\).
        • Prop.)
          • The minimizer $s^*$ satisfies
            \(\begin{aligned} s^*(\mathbf{x}_{t}, t) &= \mathbb{E}_{\mathbf{x}_{0}\sim p(\mathbf{x}_{0}\mid\mathbf{x}_{t})} \left[ \nabla_{\mathbf{x}_{t}}\log p_t(\mathbf{x}_{t}\mid\mathbf{x}_{0}) \right] \\ &= \nabla_{\mathbf{x}_{t}}\log p_t(\mathbf{x}_{t}) \quad\quad \text{for almost every } \mathbf{x}_{t}\sim p_t \text{ and } t\in[0,T] \end{aligned}\)
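The DSM objective can be estimated by plain Monte Carlo when the perturbation kernel is Gaussian, since then $\nabla_{\mathbf{x}_t}\log p_t(\mathbf{x}_t\mid\mathbf{x}_0) = -\boldsymbol{\epsilon}/\sigma_t$. The sketch below is an illustration under assumptions that are not in the source: a 1D constant-$\beta$ VP-type kernel, Gaussian data, the weighting $\omega(t)=\sigma_t^2$, and a small lower cutoff `t_min` to avoid dividing by $\sigma_0 = 0$. It compares the loss of the exact marginal score with that of a trivial zero score.

```python
import numpy as np

beta, T = 2.0, 1.0
m0, s0 = 2.0, 0.5      # hypothetical Gaussian p_data = N(m0, s0^2)

def kernel(t):
    """Mean coefficient and std of p_t(x_t | x_0) for the constant-beta VP SDE."""
    decay = np.exp(-beta * t)
    return np.sqrt(decay), np.sqrt(1.0 - decay)

def dsm_loss(score_fn, n=200_000, t_min=1e-3, seed=0):
    rng = np.random.default_rng(seed)
    t = rng.uniform(t_min, T, size=n)             # E_t
    x0 = m0 + s0 * rng.standard_normal(n)         # E_{x_0}
    eps = rng.standard_normal(n)
    alpha, sigma = kernel(t)
    xt = alpha * x0 + sigma * eps                 # x_t ~ p_t(. | x_0)
    target = -eps / sigma                         # grad log p_t(x_t | x_0)
    weight = sigma**2                             # omega(t), an assumed choice
    return 0.5 * np.mean(weight * (score_fn(xt, t) - target) ** 2)

def marginal_score(x, t):
    alpha, sigma = kernel(t)
    return -(x - alpha * m0) / (alpha**2 * s0**2 + sigma**2)

def zero_score(x, t):
    return np.zeros_like(x)

loss_marginal = dsm_loss(marginal_score)
loss_zero = dsm_loss(zero_score)
print(loss_marginal, loss_zero)
```

The marginal score attains the lower loss, in line with the minimizer property above. Its loss is still strictly positive, because the regression target is the noisy conditional score and not the marginal score itself.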

Tech.) Sampling with Reverse-Time SDE

  • Setting)
    • \(s_{\phi^\times} := s_{\phi^\times}(\mathbf{x}, t) \approx \nabla_{\mathbf{x}} \log p_t(\mathbf{x})\).
      • Trained and frozen parameterized score
  • Sampling)
    • Plugging in the $s_{\phi^\times}$ into the Reverse-Time SDE, we may get
      • \(\text{d}\mathbf{x}_{\phi^\times}^{\text{SDE}}(t) = \Big[ \mathbf{f}(\mathbf{x}_{\phi^\times}^{\text{SDE}}(t), t) - g^2(t) \underbrace{s_{\phi^\times}(\mathbf{x}_{\phi^\times}^{\text{SDE}}(t), t)}_{\text{plugged in}} \Big]\text{d}t + g(t)\text{d}\bar{\mathbf{w}}(t)\).
    • Draw an initial value $\mathbf{x}_T$ from $p_{\text{prior}}$.
    • From $t=T$ to $t=0$, solve the above equation using a numerical solver.
      • e.g.) Euler-Maruyama Method
        • \(\mathbf{x}_{t-\Delta t} \;\leftarrow\; \mathbf{x}_{t} - \left[ \mathbf{f}(\mathbf{x}_{t}, t) - g^2(t)s_{\phi^\times}(\mathbf{x}_{t}, t) \right]\Delta t + g(t)\sqrt{\Delta t} \cdot \boldsymbol{\epsilon}\).
          • where
            • \(\boldsymbol{\epsilon}\sim\mathcal{N}(\mathbf{0, I})\),
            • $\Delta t\gt0$ is the step size.
  • Prop.)
    • DDPM’s sampling scheme is a special case obtained with a particular choice of $\mathbf{f}$ and $g$.
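The Euler-Maruyama update above translates almost line by line into a generic sampler. In the sketch below, `euler_maruyama_sample` is a hypothetical helper name, and the trained score $s_{\phi^\times}$ is replaced by the exact score of a toy problem (zero drift, constant $g$, Gaussian data; all values assumed) so that the result can be checked without training anything.

```python
import numpy as np

def euler_maruyama_sample(score_fn, f, g, x_T, T=1.0, n_steps=2000, seed=0):
    """Integrate the reverse-time SDE from t = T down to t = 0."""
    rng = np.random.default_rng(seed)
    dt = T / n_steps
    x = np.array(x_T, dtype=float)
    for i in range(n_steps):
        t = T - i * dt
        eps = rng.standard_normal(x.shape)
        drift = f(x, t) - g(t) ** 2 * score_fn(x, t)
        x = x - drift * dt + g(t) * np.sqrt(dt) * eps
    return x

# Toy setup (all values hypothetical): f = 0, constant g, data N(m0, s0^2),
# so p_t = N(m0, s0^2 + g^2 t) and the score is available in closed form.
m0, s0, g_const, T = 1.0, 0.5, 4.0, 1.0
f = lambda x, t: np.zeros_like(x)
g = lambda t: g_const
score_fn = lambda x, t: -(x - m0) / (s0**2 + g_const**2 * t)

rng = np.random.default_rng(1)
x_T = g_const * np.sqrt(T) * rng.standard_normal(20000)   # p_prior = N(0, g^2 T)
samples = euler_maruyama_sample(score_fn, f, g, x_T, T=T)
print(samples.mean(), samples.std())   # should be close to m0 and s0
```

With a learned model, only `score_fn` changes; the loop itself stays the same for any choice of $\mathbf{f}$ and $g$.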


Tech.) Sampling with PF-ODE

  • Setting)
    • \(s_{\phi^\times} := s_{\phi^\times}(\mathbf{x}, t) \approx \nabla_{\mathbf{x}} \log p_t(\mathbf{x})\).
      • Trained and frozen parameterized score
  • Sampling)
    • Plugging in the $s_{\phi^\times}$ into the PF-ODE, we may get
      • \(\displaystyle\frac{\text{d}}{\text{d}t} \mathbf{x}_{\phi^\times}^{\text{ODE}}(t) = \mathbf{f}(\mathbf{x}_{\phi^\times}^{\text{ODE}}(t), t) - \frac{1}{2}g^2(t) \underbrace{s_{\phi^\times}(\mathbf{x}_{\phi^\times}^{\text{ODE}}(t), t)}_{\text{plugged in}}\).
    • Draw an initial value $\mathbf{x}_T$ from $p_{\text{prior}}$.
    • From $t=T$ to $t=0$, solve the above equation
      • This is equivalent to approximating the below integral:
        • \(\underbrace{\mathbf{x}_{\phi^\times}^{\text{ODE}}(0)}_{\text{final sample}} \;=\; \mathbf{x}_T + \displaystyle\int_T^0 \left[ \mathbf{f} \left(\mathbf{x}_{\phi^\times}^{\text{ODE}}(\tau), \tau \right) - \frac{1}{2}g^2(\tau)s_{\phi^\times}\left( \mathbf{x}_{\phi^\times}^{\text{ODE}}(\tau), \tau \right) \right]\text{d}\tau\).
      • Or, if we use the Euler method, we may update in discrete timestep $\Delta t$ as
        • \(\mathbf{x}_{t-\Delta t} \;\leftarrow\; \mathbf{x}_{t} - \left[ \mathbf{f}(\mathbf{x}_{t}, t) - \frac{1}{2}g^2(t)s_{\phi^\times}(\mathbf{x}_{t}, t) \right]\Delta t\).
          • cf.) Completely deterministic!
  • Application)
  • Prop.)
    • Exact log-likelihood computation via PF-ODE
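The Euler update for the PF-ODE, with the factor $\frac{1}{2}$ on $g^2$ from the PF-ODE definition, can be sketched as a small deterministic sampler. The helper name `euler_ode_sample`, the constant-$\beta$ VP-type SDE, the Gaussian toy data, and the use of the exact score in place of $s_{\phi^\times}$ are all assumptions made for illustration.

```python
import numpy as np

def euler_ode_sample(score_fn, f, g, x_T, T, n_steps=1000):
    """Integrate the PF-ODE from t = T down to t = 0 with explicit Euler steps."""
    dt = T / n_steps
    x = np.array(x_T, dtype=float)
    for i in range(n_steps):
        t = T - i * dt
        x = x - (f(x, t) - 0.5 * g(t) ** 2 * score_fn(x, t)) * dt
    return x

beta, T = 2.0, 5.0            # hypothetical constant-beta VP SDE
m0, s0 = 2.0, 0.5             # hypothetical Gaussian data N(m0, s0^2)
f = lambda x, t: -0.5 * beta * x
g = lambda t: np.sqrt(beta)

def score_fn(x, t):
    """Exact score of p_t, standing in for a trained s_phi."""
    decay = np.exp(-beta * t)
    return -(x - m0 * np.sqrt(decay)) / (s0**2 * decay + 1.0 - decay)

x_T = np.random.default_rng(0).standard_normal(20000)   # x_T ~ p_prior = N(0, 1)
run_a = euler_ode_sample(score_fn, f, g, x_T, T)
run_b = euler_ode_sample(score_fn, f, g, x_T, T)
print(np.array_equal(run_a, run_b), run_a.mean(), run_a.std())
```

The only randomness is the draw of $\mathbf{x}_T$: two runs from the same initial noise give identical outputs, while the population of outputs still approximates the data distribution.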



4.3 Instantiations of SDEs

Concept) Variance Exploding SDE (VE SDE)

  • Def.)
    • \(\text{d}\mathbf{x}(t) = \displaystyle\sqrt{\frac{\text{d}\sigma^2(t)}{\text{d}t}}\text{d}\mathbf{w}(t)\).
      • i.e.)
        • Drift Term : $\mathbf{f} = 0$
        • Diffusion Term : \(g(t) = \displaystyle\sqrt{\frac{\text{d}\sigma^2(t)}{\text{d}t}}\)
  • Props.)
    • Perturbation Kernel : \(p_t(\mathbf{x}_{t}\mid\mathbf{x}_{0}) = \mathcal{N}\left(\mathbf{x}_{t};\; \mathbf{x}_{0}, \left( \sigma^2(t) - \sigma^2(0) \right)\mathbf{I} \right)\)
    • Prior Distribution : \(p_{\text{prior}} := \mathcal{N}(\mathbf{0}, \sigma^2(T)\mathbf{I})\)
      • where $\sigma(t)$ is an increasing function on $t\in[0,T]$ with $\sigma^2(T)\gg\sigma^2(0)$
  • e.g.)
    • NCSN with \(\displaystyle\sigma(t) := \sigma_\min\left(\frac{\sigma_\max}{\sigma_\min}\right)^t,\quad \text{for } t\in(0,1]\)
      • i.e.) Discretized version of VE SDE
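For the geometric schedule above, $g(t)=\sqrt{\text{d}\sigma^2(t)/\text{d}t} = \sigma(t)\sqrt{2\ln(\sigma_\max/\sigma_\min)}$, and integrating $g^2$ over $[0,t]$ should reproduce the kernel variance $\sigma^2(t)-\sigma^2(0)$. The short check below uses assumed values for $\sigma_\min$ and $\sigma_\max$ and a hand-written trapezoid rule.

```python
import numpy as np

sigma_min, sigma_max = 0.01, 50.0   # hypothetical NCSN-style noise range

def sigma(t):
    return sigma_min * (sigma_max / sigma_min) ** t

def g(t):
    # g(t) = sqrt(d sigma^2 / dt) = sigma(t) * sqrt(2 * log(sigma_max / sigma_min))
    return sigma(t) * np.sqrt(2.0 * np.log(sigma_max / sigma_min))

# Integrating g^2 from 0 to t should recover sigma^2(t) - sigma^2(0).
t_end = 0.8
grid = np.linspace(0.0, t_end, 20001)
vals = g(grid) ** 2
integral = np.sum(0.5 * (vals[1:] + vals[:-1]) * np.diff(grid))   # trapezoid rule
kernel_var = sigma(t_end) ** 2 - sigma(0.0) ** 2
print(integral, kernel_var)
```

The two numbers agree closely, which ties the diffusion coefficient of the VE SDE to the variance of its perturbation kernel.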


Concept) Variance Preserving SDE (VP SDE)

  • Def.)
    • \(\text{d}\mathbf{x}(t) = \displaystyle -\frac{1}{2}\beta(t)\mathbf{x}(t)\text{d}t + \sqrt{\beta(t)}\text{d}\mathbf{w}(t)\).
      • i.e.)
        • Drift Term : $\mathbf{f}(\mathbf{x}(t), t) = -\frac{1}{2}\beta(t)\mathbf{x}(t)$
        • Diffusion Term : \(g(t) = \sqrt{\beta(t)}\)
  • Props.)
    • Perturbation Kernel : \(p_t(\mathbf{x}_{t}\mid\mathbf{x}_{0}) = \mathcal{N}\left(\mathbf{x}_{t};\; \mathbf{x}_0 \exp\left(-\frac{1}{2}\int_0^t \beta(\tau)\text{d}\tau \right), \left(1 - \exp\left(-\int_0^t \beta(\tau)\text{d}\tau\right)\right)\mathbf{I} \right)\)
    • Prior Distribution : \(p_{\text{prior}} := \mathcal{N}(\mathbf{0, I})\)
  • e.g.)
    • DDPM with noise schedule of \(\beta(t) := \beta_\min + t(\beta_\max - \beta_\min),\quad \forall t\in[0,1]\)
      • i.e.) Discretized version of VP SDE
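With the linear schedule, the integral in the perturbation kernel has the closed form $\int_0^t\beta(\tau)\text{d}\tau = \beta_\min t + \frac{1}{2}(\beta_\max-\beta_\min)t^2$. The sketch below evaluates the resulting mean coefficient and variance at a few times; the helper name `vp_kernel` and the endpoint values of $\beta$ are assumptions for illustration.

```python
import numpy as np

beta_min, beta_max = 0.1, 20.0   # hypothetical schedule endpoints

def vp_kernel(t):
    """Mean coefficient and variance of p_t(x_t | x_0) for the linear schedule."""
    # Closed form of the integral of beta(tau) over [0, t].
    int_beta = beta_min * t + 0.5 * (beta_max - beta_min) * t**2
    mean_coef = np.exp(-0.5 * int_beta)
    var = 1.0 - np.exp(-int_beta)
    return mean_coef, var

for t in (0.0, 0.5, 1.0):
    mean_coef, var = vp_kernel(t)
    # The last column, mean_coef^2 + var, stays at 1 for every t.
    print(t, mean_coef, var, mean_coef**2 + var)
```

The identity $\text{mean\_coef}^2 + \text{var} = 1$ is the "variance preserving" property for unit-variance data, and at $t=1$ the kernel is already very close to the prior $\mathcal{N}(\mathbf{0},\mathbf{I})$.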


