Common Video-Gen Distillation Recipes + Pseudocode (I): ReFlow, Progressive, Consistency, rCM
A quick cheat-sheet of four widely used distillation recipes for diffusion/flow-style video generation models, with minimal PyTorch-like pseudocode.
1. ReFlow (Rectification / Trajectory Straightening)
Core intuition: Use the teacher to pair pure noise with data. The student learns the constant velocity of the straight line segment between the two endpoints, which “straightens” the generation trajectory.
def train_reflow(teacher_model, student_model, dataloader):
    for x_real in dataloader:
        # 1) Sample pure noise
        z_0 = torch.randn_like(x_real)
        # 2) Teacher generates the paired endpoint (usually the expensive step)
        with torch.no_grad():
            z_1_fake = ode_solve(teacher_model, z_0, steps=50)
        # 3) Sample a random time t per example on the straight line
        t = torch.rand(x_real.size(0), 1, 1, 1, device=x_real.device)
        # 4) Linear interpolation point + target constant velocity
        z_t = (1 - t) * z_0 + t * z_1_fake
        target_velocity = z_1_fake - z_0
        # 5) Student predicts the constant velocity
        pred_velocity = student_model(z_t, t)
        loss = F.mse_loss(pred_velocity, target_velocity)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
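To make "straightening" concrete, here is a tiny self-contained 1-D sanity check. Everything in it is an illustrative assumption, not part of any real model: the "teacher" is a stand-in linear map `z1 = 2*z0 + 1`, for which the ideal rectified velocity field can be written in closed form. Because the pairing makes every trajectory a straight line, a single Euler step lands on the same endpoint as a fine-grained ODE solve:

```python
# Toy 1-D ReFlow sanity check. For pairs (z0, z1 = 2*z0 + 1), the
# interpolant is z_t = (1 - t)*z0 + t*z1 = (1 + t)*z0 + t, so the
# ideal constant velocity v = z1 - z0 = z0 + 1 can be recovered from
# (z, t) as v(z, t) = (z - t)/(1 + t) + 1.
def v(z, t):
    return (z - t) / (1 + t) + 1.0

def euler_integrate(z0, steps):
    """Integrate dz/dt = v(z, t) from t=0 to t=1 with many small steps."""
    z, dt = z0, 1.0 / steps
    for i in range(steps):
        z += dt * v(z, i * dt)
    return z

z0 = 0.7
one_step = z0 + 1.0 * v(z0, 0.0)       # single big Euler step
fifty_steps = euler_integrate(z0, 50)  # fine-grained ODE solve
teacher_out = 2 * z0 + 1
# Straight trajectories: one step and fifty steps agree (up to float error).
print(one_step, fifty_steps, teacher_out)
```

This is exactly the payoff of rectification: once trajectories are straight, the step count stops mattering.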
2. Progressive Distillation
Core intuition: Make one student step match two teacher steps (i.e., step size doubles). Iterate to halve steps, e.g. 64 → 32 → 16 → 8.
def train_progressive_distillation(teacher_model, student_model, dataloader, current_steps):
    dt = 1.0 / current_steps
    for x_real in dataloader:
        # 1) Sample a discrete time t and its noisy state z_t
        t = sample_discrete_time(current_steps)
        z_t = add_noise(x_real, t)
        # 2) Teacher takes two small steps: t -> t - dt -> t - 2*dt
        with torch.no_grad():
            z_next_1 = teacher_model.euler_step(z_t, t, dt)
            z_next_2 = teacher_model.euler_step(z_next_1, t - dt, dt)
        # 3) Student takes one big step of size 2*dt
        z_student = student_model.euler_step(z_t, t, 2 * dt)
        # 4) Match the endpoints
        loss = F.mse_loss(z_student, z_next_2)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
3. Consistency Distillation
Core intuition: Break the fixed-step constraint. Require that for any two adjacent points on the trajectory, the predicted x0 (the denoised “origin”) is identical. Use an EMA shadow model to stabilize training.
def train_consistency_distillation(teacher_model, student_model, ema_student_model, dataloader):
    for x_real in dataloader:
        # 1) Sample adjacent timesteps (t_n > t_{n-1})
        t_n, t_n_minus_1 = sample_adjacent_timesteps()
        z_tn = add_noise(x_real, t_n)
        # 2) Teacher connects the two nearby times with one tiny ODE step
        with torch.no_grad():
            z_tn_minus_1_hat = teacher_model.euler_step(z_tn, t_n, t_n - t_n_minus_1)
        # 3) Student predicts x0 at the current time
        pred_x0_current = student_model.predict_x0(z_tn, t_n)
        # 4) EMA model predicts x0 at the earlier time (a stable target)
        with torch.no_grad():
            target_x0 = ema_student_model.predict_x0(z_tn_minus_1_hat, t_n_minus_1)
        # 5) Enforce consistency between the two x0 predictions
        loss = F.mse_loss(pred_x0_current, target_x0)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        update_ema(ema_student_model, student_model)
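The self-consistency property itself is easy to check on a toy ODE where the consistency function is known in closed form. Everything here is an illustrative assumption: trajectories of `dz/dt = z` are `z(t) = x0 * exp(t)`, so the ideal consistency function `f(z, t) = z * exp(-t)` maps every point on a trajectory back to the same `x0`; the teacher's Euler step only introduces a small discretization error:

```python
import math

# Ideal consistency function for the stand-in ODE dz/dt = z, whose
# trajectories are z(t) = x0 * exp(t): f collapses any (z, t) on a
# trajectory to its origin x0.
def f(z, t):
    return z * math.exp(-t)

x0 = 0.4
t_n, t_prev = 0.8, 0.7
z_tn = x0 * math.exp(t_n)                 # exact point at t_n
z_prev_hat = z_tn + (t_prev - t_n) * z_tn # one tiny teacher Euler step
pred_current = f(z_tn, t_n)               # exactly x0
pred_prev = f(z_prev_hat, t_prev)         # x0 up to Euler error
print(pred_current, abs(pred_current - pred_prev))
```

Distillation trains `predict_x0` to satisfy this invariance everywhere, which is why the student can then jump from any noise level straight to `x0` in one call.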
4. rCM (Score-Regularized Continuous-Time CM)
Core intuition: Combine continuous-time consistency (via a JVP-based tangent) with distribution-matching/adversarial regularization (a DMD-style score term). The former preserves temporal coherence and mode coverage; the latter acts as a long-skip regularizer that repairs detail collapse.
def train_rCM(teacher, student, fake_score, dataloader, config):
    # student holds both an online_model and an ema_model
    for i, (x_0, prompt) in enumerate(dataloader):
        epsilon = torch.randn_like(x_0)
        t = sample_time_from_Pg()
        x_t = torch.cos(t) * x_0 + torch.sin(t) * epsilon  # TrigFlow schedule

        # ==========================================
        # Generator step (optimize the student)
        # ==========================================
        # 1) sCM loss (forward consistency).
        #    The JVP is the core of continuous-time consistency: dF_{theta-}/dt
        v_teacher = teacher(x_t, t, prompt)
        jvp_term = compute_jvp(student.ema_model, x_t, t, v_teacher)
        g = construct_tangent_signal(student.ema_model(x_t, t), v_teacher, jvp_term, t)
        loss_sCM = F.mse_loss(student.online_model(x_t, t) - student.ema_model(x_t, t),
                              g / (norm(g) + c))

        # 2) DMD loss (reverse-divergence regularization)
        loss_DMD = 0
        if i > config.tangent_warmup_iterations:
            # Few-step sampling from the student to produce fake data
            x_0_theta = simulate_backward_trajectory(student.online_model,
                                                     steps=randint(1, config.N_max))
            t_D = sample_time_from_Pd()
            x_tD = add_noise_trigflow(x_0_theta, t_D)
            # Normalizer for the score-difference gradient
            weighting = mean(abs(x_0_theta - teacher(x_tD, t_D)))
            # Fake-score net vs. teacher score gives the DMD update direction
            loss_DMD = F.mse_loss(
                x_0_theta,
                stop_gradient(x_0_theta - (fake_score(x_tD, t_D) - teacher(x_tD, t_D)) / weighting))

        # Jointly optimize the student
        total_loss = loss_sCM + config.lambda_weight * loss_DMD  # lambda default ~0.01
        optimizer_student.zero_grad()
        total_loss.backward()
        optimizer_student.step()
        update_power_ema(student.ema_model, student.online_model)

        # ==========================================
        # Critic step (optimize the fake score net)
        # ==========================================
        if i % config.update_freq == 0 and i > config.tangent_warmup_iterations:
            # Target: match the Flow Matching velocity on student samples
            t_D = sample_time_from_Pd()
            x_t_fake = torch.cos(t_D) * x_0_theta.detach() + torch.sin(t_D) * epsilon
            v_target = torch.cos(t_D) * epsilon - torch.sin(t_D) * x_0_theta.detach()
            loss_fake_score = F.mse_loss(fake_score(x_t_fake, t_D), v_target)
            optimizer_fake_score.zero_grad()
            loss_fake_score.backward()
            optimizer_fake_score.step()
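The critic's velocity target is just the time derivative of the TrigFlow interpolant, which is easy to verify numerically. The scalar values below are illustrative stand-ins for tensors:

```python
import math

# TrigFlow sanity check: x_t = cos(t)*x0 + sin(t)*eps, and the
# flow-matching velocity target used in the critic step is its time
# derivative, v(t) = cos(t)*eps - sin(t)*x0.
x0, eps, t = 0.3, -1.2, 0.9

def x_of(t_):
    return math.cos(t_) * x0 + math.sin(t_) * eps

v_target = math.cos(t) * eps - math.sin(t) * x0
h = 1e-6
v_fd = (x_of(t + h) - x_of(t - h)) / (2 * h)  # central finite difference
print(v_target, v_fd)
```

So `v_target` in the critic step is not an extra modeling choice: it is forced by the choice of the trigonometric noise schedule.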