Flow Matching

Flow Matching 推导.

  • 任务: 还原一个数据分布 q(x)q(x) , 其中 xRdx \in \mathbb{R}^d. 我们有 q(x)q(x) 的 samples

从一个简单分布一步步变换到 q(x)q(x).

pt(x)p_t(x) 其中 t[0,1]t \in [0,1], 满足 p0(x)=N(0,I)p_0(x) = \mathcal{N}(0, I) 是一个易于 sample 的分布, p1(x)q(x)p_1(x) \approx q(x).

Prerequisite·

多元微积分坐标变换:

Ag(x)dx=f(A)g(f1(y))detf1ydy\int_Ag(x)dx=\int_{f(A)}g(f^{-1}(y))\left|\det\frac{\partial f^{-1}}{\partial y}\right|dy

概率分布随机变量的变换:

PY(y)=PX(f1(y))detf1yP_Y(y)=P_X(f^{-1}(y))\cdot\left|\det\frac{\partial f^{-1}}{\partial y}\right|

proof:

Y=f(X)Y=f(X)

对任意区域 AA, 我们有

Pr(XA)=Pr(Yf(A))APX(x)dx=f(A)PY(y)dyf(A)PX(f1(y))detf1ydy=f(A)PY(y)dyPX(f1(y))detf1y=PY(y)\begin{align*} Pr(X \in A) &= Pr(Y \in f(A)) \\ \int_A P_X(x) dx &= \int_{f(A)} P_Y(y) dy \\ \int_{f(A)} P_X(f^{-1}(y)) \cdot \left | \det \frac {\partial f^{-1}} {\partial y} \right| dy &= \int_{f(A)} P_Y(y) dy \\ P_X(f^{-1}(y))\cdot\left|\det\frac{\partial f^{-1}}{\partial y}\right| &=P_Y(y) \end{align*}

变换·

用 vector field 产生 pt(x)p_t(x): 考虑向量场 ϕ:[0,1]×RdRd\phi: [0,1]\times \mathbb{R}^d \rightarrow \mathbb{R}^d.

在时刻 tt, ϕt\phi_tp0(x)p_0(x) 映射到 pt(x)p_t(x) (通过上面的概率分布随机变量的变换). 因此我们有

[ϕt]p0(x)=p0(ϕt1(x))det[ϕt1x(x)](1)[\phi_t]_*p_0(x)=p_0(\phi_t^{-1}(x))\det\left[\frac{\partial\phi_t^{-1}}{\partial x}(x)\right] \tag{1}

我们考虑一个"连续"的变换, 相对应的就有速度场 u:[0,1]×RdRdu: [0, 1] \times \mathbb{R}^d \rightarrow \mathbb{R}^d. 速度场 uu 通过下面的公式可以产生向量场 ϕ\phi:

tϕt(x)=ut(ϕt(x))(2)\frac \partial {\partial t} \phi_t(x) = u_t(\phi_t(x)) \tag{2}

并且初始化 ϕ0(x)=x\phi_0(x) =x.

速度场 uupt(x)p_t(x) 之间的变换公式称为连续性方程, 感觉上就是 概率的变换量等于流量:

tpt(x)+(pt(x)ut(x))=0(3)\frac{\partial}{\partial t}p_t(x)+\nabla\cdot(p_t(x)u_t(x))=0 \tag 3

(1), (2), (3) 完成了概率场 ptp_t, 向量场 ϕt\phi_t 和速度场 utu_t 的相互转换.

Flow Matching·

我们想要训练速度场, 用神经网络 vtθv^{\theta}_t 去近似 utu_t. 直接的 loss 形式:

LFM(θ)=Et,pt(x)[vtθ(x)ut(x)2]\mathcal{L}_{\mathrm{FM}}(\theta)=\mathbb{E}_{t,p_t(\mathbf{x})}\left[\|\mathbf{v}_t^\theta(\mathbf{x})-\mathbf{u}_t(\mathbf{x})\|^2\right]

但是我们没法得到 ground truth ut(x)u_t(x), 我们只有 pt(x)=q(x)p_t(x) = q(x) 的一些样本.

我们想办法构造满足要求的 utu_t.

Conditional Probability Paths·

不知道为什么想到用条件速度场. 针对每一个样本 x1x_1 from q(x)q(x), 构造一组 ut(xx1)u_t(x |x_1), 并且以某种方式把所有的条件速度场 组合起来 得到 utu_t. 但是组合起来的方式无法计算, 所以用神经网络 vtθv_t^\theta 去训练做近似.

首先考察 ut(xx1)u_t(x | x_1)ut(x)u_t(x) 之间的关系. 假设 ut(xx1)u_t(x|x_1) 可以产生 pt(xx1)p_t(x|x_1).

我们希望 pt(xx1)p_t(x| x_1) 满足 p0(xx1)=N(0,I)p_0(x|x_1) = \mathcal{N}(0, I) 从一个简单的分布出发, 并且 p1(xx1)=N(x1,σminI)p_1(x|x_1) = \mathcal{N}(x_1, \sigma_{min} I), 是一个 centered by x1x_1 的方差非常小的分布, 几乎就是 x=x1x=x_1.

此时自然地

pt(x)=Ex1pt(xx1)=x1pt(xx1)q(x1)dx1p_t(x) = \mathbb{E}_{x_1} p_t(x | x_1)= \int_{x_1} p_t(x | x_1) q(x_1)dx_1

t=1t=1pt(x)q(x)p_t(x) \rightarrow q(x).

那么我们可以得到 ut(x)u_t(x) by (3)

(pt(x)ut(x))=tpt(x)=tx1pt(xx1)q(x1)dx1=x1tpt(xx1)q(x1)dx1=x1(pt(xx1)ut(xx1))q(x1)dx1=x1pt(xx1)ut(xx1)q(x1)dx1\begin{align*} \nabla \cdot (p_t(x) u_t(x)) &= - \frac {\partial} {\partial t} p_t(x) \\ &= -\frac {\partial} {\partial t}\int_{x_1} p_t(x | x_1) q(x_1) dx_1 \\ &= \int_{x_1} -\frac {\partial} {\partial t}p_t(x | x_1) q(x_1) dx_1 \\ &= \int_{x_1} \nabla \cdot (p_t(x | x_1) u_t(x | x_1)) q(x_1)dx_1 \\ &= \nabla \cdot \int_{x_1} p_t(x | x_1) u_t(x | x_1) q(x_1)dx_1 \end{align*}

于是

ut(x)=ut(xx1)pt(xx1)q(x1)pt(x)dx1.(4)u_t(x)=\int u_t(x|x_1)\frac{p_t(x|x_1)q(x_1)}{p_t(x)}dx_1. \tag 4

Loss·

我们试图把 loss 根据条件速度场重写.

仿照 score matching, 我们把 loss 右边展开并且只考虑与 θ\theta 相关的项

Et,pt(x)[vtθ(x)ut(x)2]=Et,pt(x)[vtθ(x)2+ut(x)22<vtθ(x),ut(x)>]=Et,pt(x)[vtθ(x)2]2Et,pt(x)[<vtθ(x),ut(x)>]Ept(x)[<vtθ(x),ut(x)>]=pt(x)<vtθ(x),ut(x)>dx=pt(x)<vtθ(x),ut(xx1)pt(xx1)q(x1)pt(x)dx1>dx=<vtθ(x),ut(xx1)>pt(xx1)q(x1)dxdx1=Eq(x1)pt(xx1)[<vtθ(x),ut(xx1)>]\begin{align*} \mathbb{E}_{t, p_t(x)} [\| v_t^\theta (x) - u_t(x)\| ^2 ] &= \mathbb{E}_{t, p_t(x)} \left [ \|v_t^\theta(x)\| ^ 2 + \|u_t(x)\|^2 - 2 \left <v_t^\theta(x), u_t(x)\right>\right] \\ &= \mathbb{E}_{t, p_t(x)}\left[\|v_t^\theta(x)\|^2\right] - 2 \mathbb{E}_{t, p_t(x)}\left[\left <v_t^\theta(x), u_t(x)\right>\right] \\ \mathbb{E}_{p_t(x)}\left[\left <v_t^\theta(x), u_t(x)\right>\right] &= \int p_t(x) \left <v_t^\theta(x), u_t(x)\right> dx \\ &= \int p_t(x) \left < v_t^\theta (x), \int u_t(x | x_1) \frac{p_t(x|x_1)q(x_1)}{p_t(x)}dx_1 \right > dx \\ &= \int \int \left < v_t^\theta(x), u_t(x|x_1)\right> p_t(x | x_1) q(x_1)dx dx_1 \\ &= \mathbb{E}_{q(x_1)p_t(x | x_1)} \left[\left<v_t^\theta (x), u_t(x | x_1)\right>\right] \end{align*}

所以原始的损失函数可以改写成

LCFM(θ)=Et,q(x1),pt(xx1)vtθ(x)ut(xx1)2.\mathcal{L}_{\mathrm{CFM}}(\theta)=\mathbb{E}_{t,q(x_1),p_t(x|x_1)}\|v_t^\theta(x)-u_t(x|x_1)\|^2.

设计 ut(xx1)u_t(x|x_1)·

Recap: 我们希望 pt(xx1)p_t(x| x_1) 满足 p0(xx1)=N(0,I)p_0(x|x_1) = \mathcal{N}(0, I) 从一个简单的分布出发, 并且 p1(xx1)=N(x1,σminI)p_1(x|x_1) = \mathcal{N}(x_1, \sigma_{min} I), 是一个 centered by x1x_1 的方差非常小的分布, 几乎就是 x=x1x=x_1.

考虑这个向量场

ψt(x)=σt(x1)x+μt(x1).\psi_t(x) = \sigma_t (x_1) x + \mu_t(x_1).

他能够把一个 N(0,I)\mathcal{N}(0,I) 变换到 N(μt(x1),σt(x1)I)\mathcal{N}(\mu_t(x_1), \sigma_t(x_1) I).

根据 (1), 有

[ψt(x)]p0(xx1)=p0(ψt1(x))det[ψt1x(x)]=p0(xμt(x1)σt(x1))1σt(x1)=N(xμt(x1)σt(x1);0,I)1σt(x1)=N(x;μt(x1),σt2(x1)I)=pt(xx1)\begin{align*} [\psi_t(x)]_* p_0(x | x_1) &= p_0(\psi_t^{-1}(x)) \det \left[\frac{\partial\psi_t^{-1}}{\partial x}(x)\right] \\ &= p_0\left (\frac {x - \mu_t(x_1)} {\sigma_t(x_1)}\right) \cdot \left |\frac 1 {\sigma_t(x_1)}\right| \\ &= \mathcal{N}\left ( \frac {x - \mu_t(x_1)} {\sigma_t(x_1)}; 0, I\right) \cdot \left |\frac 1 {\sigma_t(x_1)}\right| \\ &= \mathcal{N} \left(x; \mu_t(x_1), \sigma_t^2(x_1)I\right) \\ &= p_t(x|x_1) \end{align*}

我们只要令 μ1(x1)=x1\mu_1(x_1) = x_1σt(x1)\sigma_t(x_1) 充分小即可.

根据向量场我们可以得到对应的速度场 by (2):

ddtψt(x)=ut(ψt(x)x1)\frac{d}{dt}\psi_t(x)=u_t(\psi_t(x)|x_1)

解得

ut(xx1)=σt(x1)σt(x1)(xμt(x1))+μt(x1).u_t(x|x_1)=\frac{\sigma_t^{\prime}(x_1)}{\sigma_t(x_1)}\left(x-\mu_t(x_1)\right)+\mu_t^{\prime}(x_1).

最后的损失函数形式如下:

Lcra(θ)=Et,q(x1),p(x0)=N(0,I)vtθ(ψt(x0))ddtψt(x0)2.\mathcal{L}_{\mathrm{cra}}(\theta)=\mathbb{E}_{t,q(x_1),p(x_0)=\mathcal{N}(0,I)}\left\|v_t^\theta(\psi_t(x_0))-\frac{d}{dt}\psi_t(x_0)\right\|^2.

Analysis·

我们训练 vθv^\theta 需要做的就是从标准正态分布 sample x0x_0, 从样本拿 x1x_1, 然后计算 loss. μt\mu_tσt\sigma_t 都是我们可以自己定义的, 只要满足 μ0=0,μ1=x1,σ0=1,σ11\mu_0 = 0, \mu_1 = x_1, \sigma_0 = 1, \sigma_1 \ll 1.

[TODO] relationship w/ diffusion·