(DM Reconst.) Ch.4 Score SDE Framework
Diffusion Model Conceptual Reconstruction following The Principles of Diffusion Models
Hozy Summary
- Training Score SDE
- Forward-Time SDE : \(\text{d}\mathbf{x}(t) = \mathbf{f}(\mathbf{x}(t),t)\text{d}t + g(t)\text{d}\mathbf{w}(t)\)
- Loss : \(\mathcal{L}_{\text{DSM}}(\phi;\omega(\cdot)) := \displaystyle\frac{1}{2}\mathbb{E}_{t}\mathbb{E}_{\mathbf{x}_0}\mathbb{E}_{p_t(\mathbf{x}_{t}\mid\mathbf{x}_{0})}\left[ \omega(t) \Vert s_\phi(\mathbf{x}_{t}, t) - \nabla_{\mathbf{x}_{t}}\log p_t(\mathbf{x}_{t}\mid\mathbf{x}_{0}) \Vert_2^2 \right]\).
- Sampling Score SDE
- Reverse-Time SDE : \(\text{d}\mathbf{x}_{\phi^\times}^{\text{SDE}}(t) = \Big[ \mathbf{f}(\mathbf{x}_{\phi^\times}^{\text{SDE}}(t), t) - g^2(t) \underbrace{s_{\phi^\times}(\mathbf{x}_{\phi^\times}^{\text{SDE}}(t), t)}_{\text{plugged in}} \Big]\text{d}t + g(t)\text{d}\bar{\mathbf{w}}(t)\)
- Sampling Scheme : \(\mathbf{x}_{t-\Delta t} \;\leftarrow\; \mathbf{x}_{t} - \left[ \mathbf{f}(\mathbf{x}_{t}, t) - g^2(t)s_{\phi^\times}(\mathbf{x}_{t}, t) \right]\Delta t + g(t)\sqrt{\Delta t} \cdot \boldsymbol{\epsilon}\)
- PF-ODE : \(\displaystyle\frac{\text{d}}{\text{d}t} \mathbf{x}_{\phi^\times}^{\text{ODE}}(t) = \mathbf{f}(\mathbf{x}_{\phi^\times}^{\text{ODE}}(t), t) - \frac{1}{2}g^2(t) \underbrace{s_{\phi^\times}(\mathbf{x}_{\phi^\times}^{\text{ODE}}(t), t)}_{\text{plugged in}}\)
- Sampling Scheme : \(\mathbf{x}_{t-\Delta t} \;\leftarrow\; \mathbf{x}_{t} - \left[ \mathbf{f}(\mathbf{x}_{t}, t) - \frac{1}{2}g^2(t)s_{\phi^\times}(\mathbf{x}_{t}, t) \right]\Delta t\)
- Instantiations of SDE
4.1 Score SDE: Its Principles
Concept) Forward-Time SDE : Data to Noise
- Def.)
- \(\text{d}\mathbf{x}(t) = \mathbf{f}(\mathbf{x}(t),t)\text{d}t + g(t)\text{d}\mathbf{w}(t)\).
- where
- $t\in[0,T]$ : a continuous timestep
- $\mathbf{f}(\cdot, t) : \mathbb{R}^D\rightarrow\mathbb{R}^D$ : the drift
- $g(t)\in\mathbb{R}$ : the scalar diffusion coefficient
- Prop.)
- The SDE induces two time-dependent densities:
- \(p_t(\mathbf{x}_t\mid\mathbf{x}_0)\) : Perturbation Kernel
- \(p_t(\mathbf{x}_t)\) : Marginal Density
Analysis) Deriving the Forward SDE as the Common Form of NCSN and DDPM
| | NCSN | DDPM |
|---|---|---|
| Forward Noise Injection Schemes | $$\mathbf{x}_{\sigma_i} = \mathbf{x} + \sigma_i\epsilon_i$$ | $$\mathbf{x}_i = \sqrt{1-\beta_i^2}\mathbf{x}_{i-1} + \beta_i\epsilon_i$$ |
| Discrete Sequential Update $$t\rightarrow t+\Delta t$$ | $$\begin{aligned} \mathbf{x}_{t+\Delta t} &= \mathbf{x}_t + \sqrt{\sigma_{t+\Delta t}^2 - \sigma_t^2}\epsilon_t \\ &\approx \mathbf{x}_t + \underbrace{\sqrt{\frac{\text{d}\sigma_t^2}{\text{d}t}}}_{\text{diffusion}}\epsilon_t\sqrt{\Delta t} \end{aligned}$$ | $$\begin{aligned} \mathbf{x}_{t+\Delta t} &= \sqrt{1-\beta_t\Delta t}\,\mathbf{x}_t + \sqrt{\beta_t\Delta t}\,\epsilon_t \\ &\approx \mathbf{x}_t \underbrace{- \frac{1}{2}\beta_t\mathbf{x}_t}_{\text{drift}}\Delta t + \underbrace{\sqrt{\beta_t}}_{\text{diffusion}}\epsilon_t\sqrt{\Delta t} \end{aligned}$$ (identifying $\beta_i^2$ with $\beta_t\Delta t$) |

Common Form : Both updates are discretizations of $$\text{d}\mathbf{x}(t) = \underbrace{\mathbf{f}(\mathbf{x}(t),t)\text{d}t}_{\text{drift}} + \underbrace{g(t)}_{\text{diffusion}}\text{d}\mathbf{w}(t)$$ with the corresponding Gaussian transition $$p(\mathbf{x}_{t+\Delta t}\mid\mathbf{x}_t) := \mathcal{N}\left( \mathbf{x}_{t+\Delta t};\; \mathbf{x}_t + \mathbf{f}(\mathbf{x}_t, t)\Delta t, g^2(t)\Delta t\mathbf{I} \right)$$ $$\text{where }\begin{cases} \mathbf{x}(t+\Delta t) - \mathbf{x}(t) & \approx \text{d}\mathbf{x}(t) \\ \Delta t &\approx \text{d}t \\ \sqrt{\Delta t} \epsilon_t \sim \mathcal{N}(\mathbf{0}, \Delta t\mathbf{I}) &\approx \text{d}\mathbf{w}(t) \end{cases}$$ Here $\mathbf{w}(t)$ is a standard Wiener process, so its increment satisfies $\text{d}\mathbf{w}(t)\sim\mathcal{N}(\mathbf{0}, \text{d}t\,\mathbf{I})$. NCSN corresponds to $\mathbf{f}=\mathbf{0}$, $g(t)=\sqrt{\text{d}\sigma_t^2/\text{d}t}$, and DDPM to $\mathbf{f}(\mathbf{x},t)=-\frac{1}{2}\beta_t\mathbf{x}$, $g(t)=\sqrt{\beta_t}$.
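To make the common form concrete, here is a minimal NumPy sketch that simulates the Gaussian transition above with Euler-Maruyama steps for one NCSN-style and one DDPM-style instance. The function and variable names are hypothetical, and the schedules are illustrative assumptions chosen so the answers are easy to check: $\sigma^2(t)=t$ (so $g=1$, $\mathbf{f}=\mathbf{0}$) and a constant $\beta(t)=2$ on $t\in[0,1]$. The sample moments should land near the closed-form perturbation kernels listed in Section 4.3.

```python
import numpy as np

def forward_euler_maruyama(x0, f, g, T=1.0, n_steps=1000, rng=None):
    """Simulate dx = f(x, t) dt + g(t) dw from t=0 to t=T using
    x_{t+dt} = x_t + f(x_t, t) dt + g(t) sqrt(dt) * eps, eps ~ N(0, I)."""
    rng = np.random.default_rng(0) if rng is None else rng
    dt = T / n_steps
    x = np.array(x0, dtype=float)
    for i in range(n_steps):
        t = i * dt
        eps = rng.standard_normal(x.shape)
        x = x + f(x, t) * dt + g(t) * np.sqrt(dt) * eps
    return x

# NCSN-style (VE) instance, assuming sigma^2(t) = t: f = 0, g(t) = sqrt(d sigma^2 / dt) = 1.
ve_f = lambda x, t: np.zeros_like(x)
ve_g = lambda t: 1.0

# DDPM-style (VP) instance, assuming a constant beta(t) = 2: f = -beta x / 2, g = sqrt(beta).
beta = 2.0
vp_f = lambda x, t: -0.5 * beta * x
vp_g = lambda t: np.sqrt(beta)

x0 = np.full(20_000, 3.0)  # every chain starts from the same data point x0 = 3
x_ve = forward_euler_maruyama(x0, ve_f, ve_g)
x_vp = forward_euler_maruyama(x0, vp_f, vp_g)
print(x_ve.mean(), x_ve.var())  # near 3 and sigma^2(1) - sigma^2(0) = 1
print(x_vp.mean(), x_vp.var())  # near 3 * exp(-1) and 1 - exp(-2)
```

The VE chain keeps its mean and only accumulates variance, while the VP chain shrinks the signal and pushes the variance toward 1, matching the drift and diffusion split in the table.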
Concept) Reverse-Time Stochastic Process for Generation
- Def.)
- \(\text{d}\bar{\mathbf{x}}(t) = \left[ \mathbf{f}(\bar{\mathbf{x}}(t), t) - g^2(t)\nabla_\mathbf{x}\log p_t(\bar{\mathbf{x}}(t)) \right]\text{d}t + g(t)\text{d}\bar{\mathbf{w}}(t)\).
- where
- \(\bar{\mathbf{x}}(T)\sim p_{\text{prior}} \approx p_T\).
- \(\bar{\mathbf{w}}(t)\) denotes a standard Wiener process in reverse time s.t.
- \(\bar{\mathbf{w}}(t) := \mathbf{w}(T-t) - \mathbf{w}(T)\).
- Prop.)
- Balance between the stochastic diffusion term and the score-driven drift
- diffusion term : \(g(t)\text{d}\bar{\mathbf{w}}(t)\)
- This term
- provides controlled stochasticity for exploration.
- dominates at the early stage : $t\approx T$
- score-driven drift : \(g^2(t)\nabla_\mathbf{x}\log p_t(\bar{\mathbf{x}}(t))\)
- This term
- guides trajectories toward regions of higher density
- dominates at the later stage : $t\approx 0$
Concept) Probability Flow ODE (PF-ODE)
- Goal)
- A deterministic process (ODE) that evolves samples with the same marginal distributions as the forward SDE.
- Def.)
- \(\displaystyle\frac{\text{d}}{\text{d}t}\tilde{\mathbf{x}}(t) = \mathbf{f}(\tilde{\mathbf{x}}(t), t) - \frac{1}{2} g^2(t) \nabla_\mathbf{x} \log p_t(\tilde{\mathbf{x}}(t))\).
- Sample result)
- \(\tilde{\mathbf{x}}(0) = \tilde{\mathbf{x}}(T) + \displaystyle\int_T^0 \left[ \mathbf{f}(\tilde{\mathbf{x}}(\tau), \tau) - \frac{1}{2}g^2(\tau)\nabla_\mathbf{x}\log p_\tau(\tilde{\mathbf{x}}(\tau)) \right]\text{d}\tau\quad\) where $\tilde{\mathbf{x}}(T)\sim p_{\text{prior}}$
- Intractable in closed form in general
- Relies on numerical solvers
- e.g.) Euler Method
- Advantages over Reverse-Time SDE
- Bidirectionality of Integration
- i.e.) Can be integrated in either direction : $\int_0^T$ or $\int_T^0$
- Wide range of well-established numerical solvers for ODE.
- Prop.)
- Derived by choosing the drift of an ODE so that its evolution preserves the same marginal densities as the forward SDE.
Prop.)
- The Fokker-Planck equation ensures that the Reverse-Time SDE and the PF-ODE yield the same marginal distributions \(p_t(\mathbf{x}_t)\) as the given Forward SDE.
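For reference, the standard argument behind this statement, sketched only for the PF-ODE and in this document's notation ($\Delta_\mathbf{x}$ denotes the Laplacian): rewrite the diffusion term of the forward Fokker-Planck equation as a transport term using $\Delta_\mathbf{x} p_t = \nabla_\mathbf{x}\cdot(p_t\nabla_\mathbf{x}\log p_t)$.

$$
\begin{aligned}
\frac{\partial p_t(\mathbf{x})}{\partial t} &= -\nabla_\mathbf{x}\cdot\big(\mathbf{f}(\mathbf{x},t)\,p_t(\mathbf{x})\big) + \frac{1}{2}g^2(t)\,\Delta_\mathbf{x} p_t(\mathbf{x}) \\
&= -\nabla_\mathbf{x}\cdot\Big(\big[\mathbf{f}(\mathbf{x},t) - \tfrac{1}{2}g^2(t)\nabla_\mathbf{x}\log p_t(\mathbf{x})\big]\,p_t(\mathbf{x})\Big)
\end{aligned}
$$

The second line is the continuity equation of the deterministic flow $\frac{\text{d}}{\text{d}t}\tilde{\mathbf{x}}(t) = \mathbf{f}(\tilde{\mathbf{x}}(t),t) - \frac{1}{2}g^2(t)\nabla_\mathbf{x}\log p_t(\tilde{\mathbf{x}}(t))$, which is exactly the PF-ODE, so both processes transport the same $p_t$. This is also where the factor $\frac{1}{2}$ comes from.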
E.g.) Gaussian Closed Form Example
- Settings)
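- $p_{\text{data}}$ is Gaussian and the drift is linear in $x$ (i.e. $f(t)x$), so every marginal is Gaussian: $x_t\sim\mathcal{N}(m_t, s_t^2)$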
Reverse-Time SDE Derivation
Recall that the Forward SDE (here with a drift linear in $x$) is $$\text{d}x(t) = f(t)x(t)\text{d}t + g(t)\text{d}w(t)$$ Taking one small Euler step of size $\Delta t\gt0$ gives $$x_{t+\Delta t} = ax_t + r\epsilon \quad\text{where } \begin{cases} a := 1+f(t)\Delta t \\ r := g(t)\sqrt{\Delta t} \\ \epsilon\sim\mathcal{N}(0,1) \end{cases}$$ Equivalently, the forward one-step transition kernel is the Gaussian $$x_{t+\Delta t}\mid x_t\sim\mathcal{N}(a x_t, r^2)$$

The marginal at $t$ is also Gaussian because $p_{\text{data}}$ is Gaussian and the dynamics are linear, i.e. $$x_t\sim\mathcal{N}(m_t, s_t^2)\quad\text{for some } m_t, s_t\in\mathbb{R}$$ By Bayes' rule, the reverse conditional density satisfies $$\begin{aligned} p(x_t\mid x_{t+\Delta t}) &\varpropto p(x_{t+\Delta t}\mid x_t) \; p_t(x_t) \\ &\varpropto \exp\left(-\frac{(x_{t+\Delta t} - ax_t)^2}{2r^2}\right) \exp\left(-\frac{(x_t-m_t)^2}{2s_t^2}\right) \end{aligned}$$

Taking the log of both sides and completing the square in $x_t$, $$\begin{aligned} \log p(x_t\mid x_{t+\Delta t}) &= -\frac{(x_{t+\Delta t} - ax_t)^2}{2r^2} -\frac{(x_t-m_t)^2}{2s_t^2} + C \\ &=-\left(\frac{a^2}{2r^2}+\frac{1}{2s_t^2}\right)x_t^2 -\left(\frac{-2ax_{t+\Delta t}}{2r^2} + \frac{-2m_t}{2s_t^2}\right) x_t + C' \\ &=-\frac{A}{2}x_t^2 + B x_t + C'' & \left(\text{Put } A = \frac{a^2}{r^2}+\frac{1}{s_t^2},\; B=\frac{ax_{t+\Delta t}}{r^2} + \frac{m_t}{s_t^2}\right) \\ &= -\frac{A}{2}\left(x_t-\frac{B}{A}\right)^2 + \frac{B^2}{2A} + C'' \end{aligned}$$

Since a Gaussian $x\sim\mathcal{N}(\mu,\sigma^2)$ has $\log p(x) = -\frac{(x-\mu)^2}{2\sigma^2}+C$, matching terms gives $$\text{Var}(x_t\mid x_{t+\Delta t}) = \frac{1}{A} = \frac{1}{\frac{a^2}{r^2}+\frac{1}{s_t^2}},\quad \mathbb{E}[x_t\mid x_{t+\Delta t}] = \frac{B}{A} = \frac{\frac{ax_{t+\Delta t}}{r^2} + \frac{m_t}{s_t^2}}{\frac{a^2}{r^2}+\frac{1}{s_t^2}}$$
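The completed-square moments can be sanity-checked numerically. This sketch, with hypothetical function names and assumed scalar values, compares $B/A$ and $1/A$ against standard conditioning of the joint Gaussian of $(x_t, x_{t+\Delta t})$, and then compares the posterior mean with a first-order reverse-time SDE step $x_{t+\Delta t} - [f(t)x_{t+\Delta t} - g^2(t)\nabla\log p]\Delta t$ using the Gaussian score $-(x-m_t)/s_t^2$. The last comparison only agrees up to $\mathcal{O}(\Delta t^2)$, which is the link between this derivation and the Reverse-Time SDE.

```python
import numpy as np

def posterior_by_completing_square(y, a, r, m_t, s_t):
    """E[x_t | x_{t+dt} = y] = B/A and Var = 1/A, with A and B as defined above."""
    A = a**2 / r**2 + 1.0 / s_t**2
    B = a * y / r**2 + m_t / s_t**2
    return B / A, 1.0 / A

def posterior_by_joint_gaussian(y, a, r, m_t, s_t):
    """Same moments via standard Gaussian conditioning on (x_t, x_{t+dt})."""
    var_y = a**2 * s_t**2 + r**2          # Var(x_{t+dt})
    gain = a * s_t**2 / var_y             # Cov(x_t, x_{t+dt}) / Var(x_{t+dt})
    return m_t + gain * (y - a * m_t), s_t**2 - gain * a * s_t**2

# Illustrative (assumed) values: f(t) = -0.5, g(t) = 1, dt = 0.01, x_t ~ N(0.2, 1.3^2).
f_t, g_t, dt, m_t, s_t, y = -0.5, 1.0, 0.01, 0.2, 1.3, 0.7
a, r = 1.0 + f_t * dt, g_t * np.sqrt(dt)

mean1, var1 = posterior_by_completing_square(y, a, r, m_t, s_t)
mean2, var2 = posterior_by_joint_gaussian(y, a, r, m_t, s_t)

# First-order reverse-time SDE drift step from y. It uses the score of p_t rather
# than p_{t+dt}; that difference only affects O(dt^2) terms.
score = -(y - m_t) / s_t**2
mean_rev = y - (f_t * y - g_t**2 * score) * dt
print(mean1, mean2, mean_rev)   # first two identical, third agrees up to O(dt^2)
print(var1, g_t**2 * dt)        # posterior variance is g^2 dt to first order
```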
PF-ODE Derivation
Refer to p102-104
Idea : Write one step of size $\Delta t$ as a smooth deterministic map $$x_{t+\Delta t} = \Phi_{t,\Delta t}(x_t) = x_t + \Delta t\, v_t(x_t) + \mathcal{O}(\Delta t^2)$$ The goal is to find the form $v_t$ must take so that this map sends the Gaussian marginal at $t$ to the Gaussian marginal at $t+\Delta t$ prescribed by the forward SDE.
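A quick scalar check of this idea, plugging in the PF-ODE velocity $v_t(x) = f(t)x - \frac{1}{2}g^2(t)\nabla_x\log p_t(x)$ with the Gaussian score $-(x-m_t)/s_t^2$. Because $v_t$ is affine in $x$, the pushed-forward distribution is exactly Gaussian, and its mean and variance should match the forward-SDE moment equations $\dot m = f m$ and $\dot{(s^2)} = 2fs^2 + g^2$ up to $\mathcal{O}(\Delta t^2)$. The coefficient values and the function name are illustrative assumptions.

```python
import numpy as np

def pf_ode_velocity(x, f_t, g_t, m_t, s_t):
    """v_t(x) = f(t) x - 0.5 g(t)^2 * score, with Gaussian score -(x - m_t) / s_t^2."""
    score = -(x - m_t) / s_t**2
    return f_t * x - 0.5 * g_t**2 * score

# Assumed coefficients and current marginal N(m_t, s_t^2).
f_t, g_t, m_t, s_t, dt = -0.5, 1.0, 0.2, 1.3, 1e-3

# Push N(m_t, s_t^2) through x -> x + dt * v_t(x); v_t is affine with slope below.
slope = 1.0 + dt * (f_t + 0.5 * g_t**2 / s_t**2)
mean_out = m_t + dt * pf_ode_velocity(m_t, f_t, g_t, m_t, s_t)
var_out = slope**2 * s_t**2

# Moments prescribed by the forward SDE after one step of size dt.
mean_sde = m_t + dt * f_t * m_t
var_sde = s_t**2 + dt * (2 * f_t * s_t**2 + g_t**2)
print(mean_out - mean_sde, var_out - var_sde)   # both O(dt^2) or smaller
```

The variance mismatch is $\Delta t^2 (f + \tfrac{1}{2}g^2/s_t^2)^2 s_t^2$, so the factor $\frac{1}{2}$ in front of $g^2$ is exactly what makes the first-order terms cancel.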
4.2 Training and Sampling of Score SDE
Tech.) Training Score SDE
- Loss : DSM Loss
- \(\mathcal{L}_{\text{DSM}}(\phi;\omega(\cdot)) := \displaystyle\frac{1}{2}\mathbb{E}_{t}\mathbb{E}_{\mathbf{x}_0}\mathbb{E}_{p_t(\mathbf{x}_{t}\mid\mathbf{x}_{0})}\left[ \omega(t) \Vert s_\phi(\mathbf{x}_{t}, t) - \nabla_{\mathbf{x}_{t}}\log p_t(\mathbf{x}_{t}\mid\mathbf{x}_{0}) \Vert_2^2 \right]\).
- where \(\mathbf{x}_{0}\sim p_{\text{data}}\)
- cf.) This is the continuous-time counterpart of the NCSN loss
- \(\mathcal{L}_{\text{NCSN}}(\phi) := \displaystyle\sum_{i=1}^L \lambda(\sigma_i) \mathcal{L}_{\text{DSM}}(\phi;\sigma_i)\).
- Prop.)
- The minimizer $s^*$ satisfies
\(\begin{aligned} s^*(\mathbf{x}_{t}, t) &= \mathbb{E}_{\mathbf{x}_{0}\sim p(\mathbf{x}_{0}\mid\mathbf{x}_{t})} \left[ \nabla_{\mathbf{x}_{t}}\log p_t(\mathbf{x}_{t}\mid\mathbf{x}_{0}) \right] \\ &= \nabla_{\mathbf{x}_{t}}\log p_t(\mathbf{x}_{t}) \quad\quad \text{for almost every } \mathbf{x}_{t}\sim p_t \text{ and } t\in[0,T] \end{aligned}\)
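A minimal Monte Carlo sketch of the DSM objective at a single noise level, assuming the VE-type kernel $\mathbf{x}_t = \mathbf{x}_0 + \sigma\boldsymbol{\epsilon}$, weight $\omega \equiv 1$, and toy scalar data $p_{\text{data}} = \mathcal{N}(0,1)$. In that setting the marginal is $\mathcal{N}(0, 1+\sigma^2)$, so the true score is known and should beat a mis-specified one. The names `dsm_loss`, `true_score`, and `wrong_score` are hypothetical.

```python
import numpy as np

def dsm_loss(score_fn, x0, sigma, rng):
    """Estimate (1/2) E || s(x_t) - grad log p(x_t | x_0) ||^2 for x_t = x_0 + sigma * eps."""
    eps = rng.standard_normal(x0.shape)
    xt = x0 + sigma * eps
    target = -(xt - x0) / sigma**2          # conditional score, equal to -eps / sigma
    return 0.5 * np.mean((score_fn(xt) - target) ** 2)

rng = np.random.default_rng(0)
x0 = rng.standard_normal(200_000)           # toy data: p_data = N(0, 1)
sigma = 1.0
true_score = lambda x: -x / (1.0 + sigma**2)   # score of the marginal N(0, 1 + sigma^2)
wrong_score = lambda x: -x / sigma**2          # ignores the data variance
loss_true = dsm_loss(true_score, x0, sigma, rng)
loss_wrong = dsm_loss(wrong_score, x0, sigma, rng)
print(loss_true, loss_wrong)                   # roughly 0.25 and 0.5
```

Note that the minimum is not zero: the conditional target is noisy, and DSM only recovers the marginal score in expectation, as the proposition above states.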
Tech.) Sampling with Reverse-Time SDE
- Setting)
- \(s_{\phi^\times} := s_{\phi^\times}(\mathbf{x}, t) \approx \nabla_{\mathbf{x}} \log p_t(\mathbf{x})\).
- Trained and frozen parameterized score
- Sampling)
- Plugging $s_{\phi^\times}$ into the Reverse-Time SDE gives
- \(\text{d}\mathbf{x}_{\phi^\times}^{\text{SDE}}(t) = \Big[ \mathbf{f}(\mathbf{x}_{\phi^\times}^{\text{SDE}}(t), t) - g^2(t) \underbrace{s_{\phi^\times}(\mathbf{x}_{\phi^\times}^{\text{SDE}}(t), t)}_{\text{plugged in}} \Big]\text{d}t + g(t)\text{d}\bar{\mathbf{w}}(t)\).
- Draw an initial value $\mathbf{x}_T$ from $p_{\text{prior}}$.
- From $t=T$ to $t=0$, solve the above equation using a numerical solver.
- e.g.) Euler-Maruyama Method
- \(\mathbf{x}_{t-\Delta t} \;\leftarrow\; \mathbf{x}_{t} - \left[ \mathbf{f}(\mathbf{x}_{t}, t) - g^2(t)s_{\phi^\times}(\mathbf{x}_{t}, t) \right]\Delta t + g(t)\sqrt{\Delta t} \cdot \boldsymbol{\epsilon}\).
- where
- \(\boldsymbol{\epsilon}\sim\mathcal{N}(\mathbf{0}, \mathbf{I})\),
- $\Delta t\gt0$ is the step size.
- Prop.)
- DDPM’s sampling scheme is a discretization of this Reverse-Time SDE for a specific choice of $\mathbf{f}$ and $g$ (the VP SDE of Section 4.3).
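The Euler-Maruyama loop above can be sketched in NumPy as follows. This is an illustration under stated assumptions, not a reference implementation: it uses the VP SDE of Section 4.3 with the linear schedule $\beta_{\min}=0.1$, $\beta_{\max}=20$ on $t\in[0,1]$ (so $T=1$ and $p_{\text{prior}}=\mathcal{N}(0,1)$), toy data $\mathcal{N}(2, 0.5^2)$, and, in place of a trained network $s_{\phi^\times}$, the exact score of the Gaussian marginal. All names are hypothetical.

```python
import numpy as np

BETA_MIN, BETA_MAX = 0.1, 20.0   # linear VP schedule (assumed values)
MU0, S0 = 2.0, 0.5               # toy data distribution N(MU0, S0^2) (assumption)

def beta(t):
    return BETA_MIN + t * (BETA_MAX - BETA_MIN)

def alpha(t):
    # exp(-1/2 * int_0^t beta), in closed form for the linear schedule
    return np.exp(-0.5 * (BETA_MIN * t + 0.5 * t**2 * (BETA_MAX - BETA_MIN)))

def exact_score(x, t):
    # Stand-in for a trained s_phi: p_t = N(alpha * MU0, alpha^2 * S0^2 + 1 - alpha^2)
    a = alpha(t)
    return -(x - a * MU0) / (a**2 * S0**2 + 1.0 - a**2)

def reverse_sde_sample(score_fn, n_samples=20_000, n_steps=1000, seed=0):
    rng = np.random.default_rng(seed)
    dt = 1.0 / n_steps
    x = rng.standard_normal(n_samples)           # x_T ~ p_prior = N(0, 1), T = 1
    for i in range(n_steps):
        t = 1.0 - i * dt
        f = -0.5 * beta(t) * x                   # VP drift f(x, t)
        g2 = beta(t)                             # g(t)^2
        eps = rng.standard_normal(n_samples)
        # x_{t - dt} <- x_t - [f - g^2 s] dt + g sqrt(dt) eps
        x = x - (f - g2 * score_fn(x, t)) * dt + np.sqrt(g2 * dt) * eps
    return x

samples = reverse_sde_sample(exact_score)
print(samples.mean(), samples.std())             # should be near MU0 and S0
```

Swapping `exact_score` for a learned model is the only change needed in principle; with the exact score, any remaining error comes from the prior mismatch and the step size.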
Tech.) Sampling with PF-ODE
- Setting)
- \(s_{\phi^\times} := s_{\phi^\times}(\mathbf{x}, t) \approx \nabla_{\mathbf{x}} \log p_t(\mathbf{x})\).
- Trained and frozen parameterized score
- Sampling)
- Plugging $s_{\phi^\times}$ into the PF-ODE gives
- \(\displaystyle\frac{\text{d}}{\text{d}t} \mathbf{x}_{\phi^\times}^{\text{ODE}}(t) = \mathbf{f}(\mathbf{x}_{\phi^\times}^{\text{ODE}}(t), t) - \frac{1}{2}g^2(t) \underbrace{s_{\phi^\times}(\mathbf{x}_{\phi^\times}^{\text{ODE}}(t), t)}_{\text{plugged in}}\).
- Draw an initial value $\mathbf{x}_T$ from $p_{\text{prior}}$.
- From $t=T$ to $t=0$, solve the above equation.
- This is equivalent to approximating the below integral:
- \(\underbrace{\mathbf{x}_{\phi^\times}^{\text{ODE}}(0)}_{\text{final sample}} \;=\; \mathbf{x}_{\phi^\times}^{\text{ODE}}(T) + \displaystyle\int_T^0 \left[ \mathbf{f} \left(\mathbf{x}_{\phi^\times}^{\text{ODE}}(\tau), \tau \right) - \frac{1}{2}g^2(\tau)s_{\phi^\times}\left( \mathbf{x}_{\phi^\times}^{\text{ODE}}(\tau), \tau \right) \right]\text{d}\tau\).
- Or, with the Euler method and a step size $\Delta t\gt0$, the discrete update is
- \(\mathbf{x}_{t-\Delta t} \;\leftarrow\; \mathbf{x}_{t} - \left[ \mathbf{f}(\mathbf{x}_{t}, t) - \frac{1}{2}g^2(t)s_{\phi^\times}(\mathbf{x}_{t}, t) \right]\Delta t\).
- cf.) No noise is injected, so sampling is completely deterministic.
- Application)
- Controllable Generation
- e.g.) Mokady et al., 2023; Su et al., 2022
- Idea)
- Solving the PF-ODE forward in time (from $t=0$ to $t=T$) maps a data point to its latent representation $\mathbf{x}(T)$; solving it backward maps the latent back to data.
- Prop.)
- Exact log-likelihood computation via PF-ODE
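The deterministic Euler sampler and the inversion idea can be sketched together. Under the same assumptions as a toy setup (VP SDE with linear $\beta$ from $0.1$ to $20$ on $[0,1]$, data $\mathcal{N}(2, 0.5^2)$, and the exact Gaussian score standing in for $s_{\phi^\times}$), the sketch integrates the PF-ODE from $T=1$ to $0$ to generate samples, then integrates it from $0$ back to $1$ to recover the latents. All names are hypothetical; with plain Euler steps the round trip is only accurate to $\mathcal{O}(\Delta t)$, not exact.

```python
import numpy as np

BETA_MIN, BETA_MAX = 0.1, 20.0   # linear VP schedule (assumed values)
MU0, S0 = 2.0, 0.5               # toy data N(MU0, S0^2) (assumption)

def beta(t):
    return BETA_MIN + t * (BETA_MAX - BETA_MIN)

def exact_score(x, t):
    # Stand-in for a trained s_phi under the VP SDE with Gaussian data
    a = np.exp(-0.5 * (BETA_MIN * t + 0.5 * t**2 * (BETA_MAX - BETA_MIN)))
    return -(x - a * MU0) / (a**2 * S0**2 + 1.0 - a**2)

def pf_ode_velocity(x, t, score_fn):
    # f(x, t) - 1/2 g(t)^2 s(x, t), with f = -beta x / 2 and g^2 = beta
    return -0.5 * beta(t) * x - 0.5 * beta(t) * score_fn(x, t)

def pf_ode_euler(x, score_fn, t_start, t_end, n_steps=1000):
    """Euler integration of the PF-ODE from t_start to t_end (either direction)."""
    dt = (t_end - t_start) / n_steps             # negative when generating (T -> 0)
    for i in range(n_steps):
        t = t_start + i * dt
        x = x + pf_ode_velocity(x, t, score_fn) * dt
    return x

rng = np.random.default_rng(0)
x_T = rng.standard_normal(20_000)                # x(T) ~ p_prior = N(0, 1)
samples = pf_ode_euler(x_T, exact_score, 1.0, 0.0)       # generation (T -> 0)
latents = pf_ode_euler(samples, exact_score, 0.0, 1.0)   # inversion (0 -> T)
print(samples.mean(), samples.std())             # near MU0 and S0
print(np.max(np.abs(latents - x_T)))             # small round-trip error
```

Compared with the stochastic sampler, the only changes are the $\frac{1}{2}$ on the score term and the missing noise injection, which is what makes the map invertible up to discretization error.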
4.3 Instantiations of SDEs
Concept) Variance Exploding SDE (VE SDE)
- Def.)
- \(\text{d}\mathbf{x}(t) = \displaystyle\sqrt{\frac{\text{d}\sigma^2(t)}{\text{d}t}}\text{d}\mathbf{w}(t)\).
- i.e.)
- Drift Term : $\mathbf{f} = 0$
- Diffusion Term : \(g(t) = \sqrt{\frac{\text{d}\sigma^2(t)}{\text{d}t}}\)
- Props.)
- Perturbation Kernel : \(p_t(\mathbf{x}_{t}\mid\mathbf{x}_{0}) = \mathcal{N}\left(\mathbf{x}_{t};\; \mathbf{x}_{0}, \left( \sigma^2(t) - \sigma^2(0) \right)\mathbf{I} \right)\)
- Prior Distribution : \(p_{\text{prior}} := \mathcal{N}(\mathbf{0}, \sigma^2(T)\mathbf{I})\)
- where $\sigma(t)$ is an increasing function for $t\in[0,T]$ and that $\sigma^2(T)\gg\sigma^2(0)$
- e.g.)
- NCSN with \(\displaystyle\sigma(t) := \sigma_\min\left(\frac{\sigma_\max}{\sigma_\min}\right)^t,\quad \text{for } t\in(0,1]\)
- i.e.) Discretized version of VE SDE
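A small sketch of the VE SDE with the NCSN geometric schedule, assuming illustrative values $\sigma_{\min}=0.01$ and $\sigma_{\max}=50$. For this schedule $\frac{\text{d}\sigma^2}{\text{d}t} = 2\log(\sigma_{\max}/\sigma_{\min})\,\sigma^2(t)$, so $g(t)=\sigma(t)\sqrt{2\log(\sigma_{\max}/\sigma_{\min})}$; the sketch checks this against a finite difference and samples the perturbation kernel. Function names are hypothetical.

```python
import numpy as np

SIGMA_MIN, SIGMA_MAX = 0.01, 50.0   # illustrative values (assumption)

def sigma(t):
    # Geometric schedule sigma(t) = sigma_min * (sigma_max / sigma_min)^t
    return SIGMA_MIN * (SIGMA_MAX / SIGMA_MIN) ** t

def g(t):
    # g(t) = sqrt(d sigma^2 / dt) = sigma(t) * sqrt(2 log(sigma_max / sigma_min))
    return sigma(t) * np.sqrt(2.0 * np.log(SIGMA_MAX / SIGMA_MIN))

def sample_ve_kernel(x0, t, rng):
    # x_t ~ N(x0, (sigma^2(t) - sigma^2(0)) I)
    std = np.sqrt(sigma(t) ** 2 - sigma(0.0) ** 2)
    return x0 + std * rng.standard_normal(np.shape(x0))

# Check g(t)^2 against a central finite difference of sigma^2(t).
t, h = 0.3, 1e-6
fd = (sigma(t + h) ** 2 - sigma(t - h) ** 2) / (2 * h)
print(g(t) ** 2, fd)

# At t = 1 the kernel std is about sigma_max, so the data scale is swamped by noise.
xt = sample_ve_kernel(np.zeros(100_000), 1.0, np.random.default_rng(0))
print(xt.std())
```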
Concept) Variance Preserving SDE (VP SDE)
- Def.)
- \(\text{d}\mathbf{x}(t) = \displaystyle -\frac{1}{2}\beta(t)\mathbf{x}(t)\text{d}t + \sqrt{\beta(t)}\text{d}\mathbf{w}(t)\).
- i.e.)
- Drift Term : $\mathbf{f}(\mathbf{x}(t), t) = -\frac{1}{2}\beta(t)\mathbf{x}(t)$
- Diffusion Term : \(g(t) = \sqrt{\beta(t)}\)
- Props.)
- Perturbation Kernel : \(p_t(\mathbf{x}_{t}\mid\mathbf{x}_{0}) = \mathcal{N}\left(\mathbf{x}_{t};\; \mathbf{x}_0 \exp\left(-\frac{1}{2}\int_0^t \beta(\tau)\text{d}\tau \right), \mathbf{I} - \mathbf{I}\exp(-\int_0^t \beta(\tau)\text{d}\tau) \right)\)
- Prior Distribution : \(p_{\text{prior}} := \mathcal{N}(\mathbf{0}, \mathbf{I})\)
- e.g.)
- DDPM with noise schedule of \(\beta(t) := \beta_\min + t(\beta_\max - \beta_\min),\quad \forall t\in[0,1]\)
- i.e.) Discretized version of VP SDE
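To see the "discretized version" claim numerically, the sketch below compares the continuous VP signal scale $\exp(-\frac{1}{2}\int_0^t\beta)$ with the DDPM-style product $\prod_i\sqrt{1-\beta_i}$, assuming the common correspondence $\beta_i = \beta(t_i)\Delta t$ with $N=1000$ steps and illustrative $\beta_{\min}=0.1$, $\beta_{\max}=20$. It also checks the "variance preserving" property: if $\text{Var}(x_0)=1$, the kernel keeps $\text{Var}(x_t)=1$. Names are hypothetical.

```python
import numpy as np

BETA_MIN, BETA_MAX = 0.1, 20.0   # linear schedule on t in [0, 1] (assumed values)

def beta(t):
    return BETA_MIN + t * (BETA_MAX - BETA_MIN)

def vp_alpha(t):
    # exp(-1/2 * int_0^t beta(tau) dtau) for the linear schedule
    return np.exp(-0.5 * (BETA_MIN * t + 0.5 * t**2 * (BETA_MAX - BETA_MIN)))

def vp_kernel_moments(x0, t):
    # Mean and per-coordinate variance of p_t(x_t | x_0)
    a = vp_alpha(t)
    return a * x0, 1.0 - a**2

# DDPM-style steps x_i = sqrt(1 - beta_i) x_{i-1} + sqrt(beta_i) eps, with
# beta_i = beta(t_i) * dt, scale the signal by prod_i sqrt(1 - beta_i).
N, t = 1000, 0.5
dt = 1.0 / N
t_grid = np.arange(int(round(t * N))) * dt
discrete_alpha = np.prod(np.sqrt(1.0 - beta(t_grid) * dt))
print(discrete_alpha, vp_alpha(t))   # close for small dt

# Variance preservation: Var(x_0) = 1 gives Var(x_t) = alpha^2 * 1 + (1 - alpha^2) = 1.
_, var = vp_kernel_moments(x0=0.0, t=t)
print(vp_alpha(t) ** 2 * 1.0 + var)
```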