Common Video-Gen Distillation Recipes + Pseudocode (I): ReFlow, Progressive, Consistency, rCM

2026-02-22

A quick cheat-sheet of four widely used distillation recipes for diffusion/flow-style video generation models, with minimal PyTorch-like pseudocode.

1. ReFlow (Rectification / Trajectory Straightening)

Core intuition: Use the teacher to pair pure noise with a generated sample: run the teacher's ODE solver from the noise to obtain the data endpoint, then train the student on the constant velocity of the straight segment between the two endpoints. This "straightens" the generation trajectory, so far fewer Euler steps are needed at sampling time.

def train_reflow(teacher_model, student_model, dataloader):
    for x_real in dataloader:
        # 1) Sample pure noise (x_real is only used for its shape here)
        z_0 = torch.randn_like(x_real)

        # 2) Teacher generates the paired endpoint (usually the expensive step)
        with torch.no_grad():
            z_1_fake = ode_solve(teacher_model, z_0, steps=50)

        # 3) Sample a random time t on the straight segment
        t = torch.rand(1)  # in practice: one t per sample, broadcast over the batch

        # 4) Linear interpolation point + target constant velocity
        z_t = (1 - t) * z_0 + t * z_1_fake
        target_velocity = z_1_fake - z_0

        # 5) Student predicts the constant velocity
        pred_velocity = student_model(z_t, t)
        loss = F.mse_loss(pred_velocity, target_velocity)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
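As a sanity check on the straightening claim, here is a tiny self-contained numpy sketch (the names `z_0`, `z_1_fake`, `v` mirror the pseudocode above; the real endpoints would come from the teacher's ODE solve): if the student learns the constant velocity exactly, a single Euler step from any interpolation point z_t lands exactly on z_1.

```python
import numpy as np

rng = np.random.default_rng(0)

# Toy stand-ins for the paired endpoints: z_0 is pure noise,
# z_1_fake plays the role of the teacher's ODE-solved sample.
z_0 = rng.normal(size=(4, 8))
z_1_fake = rng.normal(size=(4, 8))

# The regression target: constant velocity of the straight segment.
v = z_1_fake - z_0

for t in (0.0, 0.3, 0.7, 1.0):
    z_t = (1 - t) * z_0 + t * z_1_fake      # interpolation point
    one_step = z_t + (1 - t) * v            # single Euler step to t = 1
    assert np.allclose(one_step, z_1_fake)  # lands exactly on the endpoint

print("one Euler step recovers z_1 at every t")
```

This is the whole point of the recipe: once the trajectory is a straight line, step count stops mattering.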

2. Progressive Distillation

Core intuition: Make one student step match two teacher steps (i.e., the step size doubles). The student is initialized from the teacher; after each round, the distilled student becomes the new teacher, halving the step count each time, e.g. 64 → 32 → 16 → 8.

def train_progressive_distillation(teacher_model, student_model, dataloader, current_steps):
    dt = 1.0 / current_steps

    for x_real in dataloader:
        # 1) Sample discrete time t and its noisy state z_t
        t = sample_discrete_time(current_steps)
        z_t = add_noise(x_real, t)

        # 2) Teacher takes two small steps
        with torch.no_grad():
            z_next_1 = teacher_model.euler_step(z_t, t, dt)
            z_next_2 = teacher_model.euler_step(z_next_1, t - dt, dt)

        # 3) Student takes one big step
        z_student = student_model.euler_step(z_t, t, 2 * dt)

        # 4) Match the endpoints
        loss = F.mse_loss(z_student, z_next_2)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
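To see why the student must be retrained rather than simply running the teacher with a doubled dt, here is a hedged toy check with linear dynamics dz/dt = a·z (pure numpy, not the video model): two Euler steps of size dt and one naive step of size 2·dt disagree by an O(dt²) term, which is exactly the gap the distilled student learns to absorb.

```python
import numpy as np

a = -1.0   # toy linear dynamics: dz/dt = a * z
z = 1.0
dt = 0.1

# Two small "teacher" Euler steps
z_two = z * (1 + a * dt) * (1 + a * dt)

# One naive big step with the same update rule
z_one = z * (1 + a * 2 * dt)

# (1 + a*dt)^2 = 1 + 2*a*dt + (a*dt)^2, so the gap is the quadratic term
gap = z_two - z_one
assert np.isclose(gap, (a * dt) ** 2)
print(f"two-step vs naive one-step gap: {gap:.4f}")
```

Progressive distillation bakes this correction term into the student's weights instead of leaving it as discretization error.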

3. Consistency Distillation

Core intuition: Break the fixed-step constraint. Require that for any two adjacent points on the trajectory, the predicted x0 (the denoised “origin”) is identical. Use an EMA shadow model to stabilize training.

def train_consistency_distillation(teacher_model, student_model, ema_student_model, dataloader):
    for x_real in dataloader:
        # 1) Sample adjacent timesteps
        t_n, t_n_minus_1 = sample_adjacent_timesteps()
        z_tn = add_noise(x_real, t_n)

        # 2) Teacher connects the two nearby times with a tiny step
        with torch.no_grad():
            z_tn_minus_1_hat = teacher_model.euler_step(z_tn, t_n, t_n - t_n_minus_1)

        # 3) Student predicts x0 at current time
        pred_x0_current = student_model.predict_x0(z_tn, t_n)

        # 4) EMA model predicts x0 at the next time (stable target)
        with torch.no_grad():
            target_x0 = ema_student_model.predict_x0(z_tn_minus_1_hat, t_n_minus_1)

        # 5) Enforce consistency
        loss = F.mse_loss(pred_x0_current, target_x0)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        update_ema(ema_student_model, student_model)
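The `update_ema` helper called above is not defined in the pseudocode; a minimal sketch over plain numpy arrays (a real implementation would iterate `model.parameters()` under `torch.no_grad()`) might look like:

```python
import numpy as np

def update_ema(ema_params, online_params, decay=0.99):
    """In-place EMA update: ema <- decay * ema + (1 - decay) * online."""
    for e, o in zip(ema_params, online_params):
        e *= decay
        e += (1 - decay) * o

# Toy demo: the EMA copy trails, then converges toward, the online weights.
ema = [np.zeros(3)]
online = [np.ones(3)]
for _ in range(1000):
    update_ema(ema, online, decay=0.99)

assert np.all(ema[0] > 0.99)
print(ema[0])
```

The decay (often 0.99-0.9999) controls how slowly the target moves; a slower EMA gives a more stable consistency target at the cost of lagging the online model.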

4. rCM (Score-Regularized Continuous-Time CM)

Core intuition: Combine continuous-time consistency (via a JVP-based tangent) with distribution matching / adversarial regularization (DMD-style score) as a long-skip regularizer to reduce detail collapse.

def train_rCM(teacher, student, fake_score, dataloader, config):
    # student contains online_model and ema_model
    for i, (x_0, prompt) in enumerate(dataloader):
        epsilon = torch.randn_like(x_0)
        t = sample_time_from_Pg()
        x_t = torch.cos(t) * x_0 + torch.sin(t) * epsilon  # TrigFlow schedule

        # ==========================================
        # Generator Step (optimize student)
        # ==========================================
        # 1. sCM loss (forward consistency)
        # JVP is the key to continuous-time consistency: dF_{theta-}/dt
        v_teacher = teacher(x_t, t, prompt)
        jvp_term = compute_jvp(student.ema_model, x_t, t, v_teacher)

        g = construct_tangent_signal(student.ema_model(x_t, t), v_teacher, jvp_term, t)
        loss_sCM = mse_loss(student.online_model(x_t, t) - student.ema_model(x_t, t), g / (norm(g) + c))

        # 2. DMD loss (reverse-divergence regularization)
        loss_DMD = 0
        if i > config.tangent_warmup_iterations:
            # Few-step sampling from the student to produce fake data
            x_0_theta = simulate_backward_trajectory(student.online_model, steps=random(1, N_max))
            t_D = sample_time_from_Pd()
            x_tD = add_noise_trigflow(x_0_theta, t_D)

            # Score difference (fake score minus teacher score), normalized by a weighting term
            weighting = mean(abs(x_0_theta - teacher(x_tD, t_D)))
            loss_DMD = mse_loss(x_0_theta, stop_gradient(x_0_theta - (fake_score(x_tD, t_D) - teacher(x_tD, t_D)) / weighting))

        # Jointly optimize the student
        total_loss = loss_sCM + config.lambda_weight * loss_DMD  # lambda default ~0.01
        optimizer_student.zero_grad()
        total_loss.backward()
        optimizer_student.step()
        update_power_ema(student.ema_model, student.online_model)

        # ==========================================
        # Critic Step (optimize fake score net)
        # ==========================================
        if i % config.update_freq == 0 and i > config.tangent_warmup_iterations:
            # Target: match the Flow Matching velocity on the student's samples
            t_D = sample_time_from_Pd()
            eps_new = torch.randn_like(x_0_theta)  # fresh noise for the critic step
            x_t_fake = torch.cos(t_D) * x_0_theta.detach() + torch.sin(t_D) * eps_new
            v_target = torch.cos(t_D) * eps_new - torch.sin(t_D) * x_0_theta.detach()

            loss_fake_score = F.mse_loss(fake_score(x_t_fake, t_D), v_target)
            optimizer_fake_score.zero_grad()
            loss_fake_score.backward()
            optimizer_fake_score.step()
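As a check on the TrigFlow algebra used in both the generator and critic steps, this small numpy snippet verifies that the regression target v = cos(t)·ε − sin(t)·x₀ really is the time derivative of the interpolation x_t = cos(t)·x₀ + sin(t)·ε, and that the schedule's coefficients stay on the unit circle (so the marginal variance is preserved when x₀ and ε are unit-variance):

```python
import numpy as np

rng = np.random.default_rng(0)
x_0 = rng.normal(size=5)
eps = rng.normal(size=5)

def x_t(t):
    return np.cos(t) * x_0 + np.sin(t) * eps   # TrigFlow interpolation

t = 0.6
v_target = np.cos(t) * eps - np.sin(t) * x_0   # closed-form velocity

# Compare against a central-difference numerical derivative of the path
h = 1e-6
v_numeric = (x_t(t + h) - x_t(t - h)) / (2 * h)
assert np.allclose(v_target, v_numeric, atol=1e-5)

# Unit-circle coefficients: cos^2(t) + sin^2(t) = 1 for every t
assert np.isclose(np.cos(t) ** 2 + np.sin(t) ** 2, 1.0)
print("v_target matches dx_t/dt")
```

This is why the critic step can regress the fake-score net directly onto v_target: it is the exact Flow Matching velocity under this schedule.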