Flow Matching 推导.
- 任务: 还原一个数据分布 q(x) , 其中 x∈Rd. 我们有 q(x) 的 samples
从一个简单分布一步步变换到 q(x).
pt(x) 其中 t∈[0,1], 满足 p0(x)=N(0,I) 是一个易于 sample 的分布, p1(x)≈q(x).
Prerequisite·
多元微积分坐标变换:
∫Ag(x)dx=∫f(A)g(f−1(y))det∂y∂f−1dy
概率分布随机变量的变换:
PY(y)=PX(f−1(y))⋅det∂y∂f−1
proof:
Y=f(X)
对任意区域 A, 我们有
Pr(X∈A)∫APX(x)dx∫f(A)PX(f−1(y))⋅det∂y∂f−1dyPX(f−1(y))⋅det∂y∂f−1=Pr(Y∈f(A))=∫f(A)PY(y)dy=∫f(A)PY(y)dy=PY(y)
用 vector field 产生 pt(x): 考虑向量场 ϕ:[0,1]×Rd→Rd.
在时刻 t, ϕt 把 p0(x) 映射到 pt(x) (通过上面的概率分布随机变量的变换). 因此我们有
[ϕt]∗p0(x)=p0(ϕt−1(x))det[∂x∂ϕt−1(x)](1)
我们考虑一个"连续"的变换, 相对应的就有速度场 u:[0,1]×Rd→Rd. 速度场 u 通过下面的公式可以产生向量场 ϕ:
∂t∂ϕt(x)=ut(ϕt(x))(2)
并且初始化 ϕ0(x)=x.
速度场 u 和 pt(x) 之间的变换公式称为连续性方程, 感觉上就是 概率的变换量等于流量:
∂t∂pt(x)+∇⋅(pt(x)ut(x))=0(3)
(1), (2), (3) 完成了概率场 pt, 向量场 ϕt 和速度场 ut 的相互转换.
Flow Matching·
我们想要训练速度场, 用神经网络 vtθ 去近似 ut. 直接的 loss 形式:
LFM(θ)=Et,pt(x)[∥vtθ(x)−ut(x)∥2]
但是我们没法得到 ground truth ut(x), 我们只有 pt(x)=q(x) 的一些样本.
我们想办法构造满足要求的 ut.
Conditional Probability Paths·
不知道为什么想到用条件速度场.
针对每一个样本 x1 from q(x), 构造一组 ut(x∣x1), 并且以某种方式把所有的条件速度场 组合起来 得到 ut.
但是组合起来的方式无法计算, 所以用神经网络 vtθ 去训练做近似.
首先考察 ut(x∣x1) 和 ut(x) 之间的关系. 假设 ut(x∣x1) 可以产生 pt(x∣x1).
我们希望 pt(x∣x1) 满足 p0(x∣x1)=N(0,I) 从一个简单的分布出发, 并且 p1(x∣x1)=N(x1,σminI), 是一个 centered by x1 的方差非常小的分布, 几乎就是 x=x1.
此时自然地
pt(x)=Ex1pt(x∣x1)=∫x1pt(x∣x1)q(x1)dx1
在 t=1 时 pt(x)→q(x).
那么我们可以得到 ut(x) by (3)
∇⋅(pt(x)ut(x))=−∂t∂pt(x)=−∂t∂∫x1pt(x∣x1)q(x1)dx1=∫x1−∂t∂pt(x∣x1)q(x1)dx1=∫x1∇⋅(pt(x∣x1)ut(x∣x1))q(x1)dx1=∇⋅∫x1pt(x∣x1)ut(x∣x1)q(x1)dx1
于是
ut(x)=∫ut(x∣x1)pt(x)pt(x∣x1)q(x1)dx1.(4)
Loss·
我们试图把 loss 根据条件速度场重写.
仿照 score matching, 我们把 loss 右边展开并且只考虑与 θ 相关的项
Et,pt(x)[∥vtθ(x)−ut(x)∥2]Ept(x)[⟨vtθ(x),ut(x)⟩]=Et,pt(x)[∥vtθ(x)∥2+∥ut(x)∥2−2⟨vtθ(x),ut(x)⟩]=Et,pt(x)[∥vtθ(x)∥2]−2Et,pt(x)[⟨vtθ(x),ut(x)⟩]=∫pt(x)⟨vtθ(x),ut(x)⟩dx=∫pt(x)⟨vtθ(x),∫ut(x∣x1)pt(x)pt(x∣x1)q(x1)dx1⟩dx=∫∫⟨vtθ(x),ut(x∣x1)⟩pt(x∣x1)q(x1)dxdx1=Eq(x1)pt(x∣x1)[⟨vtθ(x),ut(x∣x1)⟩]
所以原始的损失函数可以改写成
LCFM(θ)=Et,q(x1),pt(x∣x1)∥vtθ(x)−ut(x∣x1)∥2.
设计 ut(x∣x1)·
Recap: 我们希望 pt(x∣x1) 满足 p0(x∣x1)=N(0,I) 从一个简单的分布出发, 并且 p1(x∣x1)=N(x1,σminI), 是一个 centered by x1 的方差非常小的分布, 几乎就是 x=x1.
考虑这个向量场
ψt(x)=σt(x1)x+μt(x1).
他能够把一个 N(0,I) 变换到 N(μt(x1),σt(x1)I).
根据 (1), 有
[ψt(x)]∗p0(x∣x1)=p0(ψt−1(x))det[∂x∂ψt−1(x)]=p0(σt(x1)x−μt(x1))⋅σt(x1)1=N(σt(x1)x−μt(x1);0,I)⋅σt(x1)1=N(x;μt(x1),σt2(x1)I)=pt(x∣x1)
我们只要令 μ1(x1)=x1 且 σt(x1) 充分小即可.
根据向量场我们可以得到对应的速度场 by (2):
dtdψt(x)=ut(ψt(x)∣x1)
解得
ut(x∣x1)=σt(x1)σt′(x1)(x−μt(x1))+μt′(x1).
最后的损失函数形式如下:
Lcra(θ)=Et,q(x1),p(x0)=N(0,I)vtθ(ψt(x0))−dtdψt(x0)2.
Analysis·
我们训练 vθ 需要做的就是从标准正态分布 sample x0, 从样本拿 x1, 然后计算 loss. μt 和 σt 都是我们可以自己定义的, 只要满足 μ0=0,μ1=x1,σ0=1,σ1≪1.
[TODO] relationship w/ diffusion·