Large Scale Diffusion Distillation via Score-Regularized Continuous-Time Consistency

Conference: ICLR 2026
arXiv: 2510.08431
Code: Project Page
Area: Image Generation
Keywords: Continuous-time consistency models, score distillation, large-scale distillation, JVP, few-step generation

TL;DR

This paper proposes rCM (score-regularized continuous-time consistency model), which for the first time scales continuous-time consistency distillation to 14B-parameter text-to-image/video models. By combining forward KL divergence (consistency) with reverse KL divergence (score distillation), rCM matches DMD2 in quality while preserving diversity, achieving 15–50× inference speedup.

Background & Motivation

  • Background: sCM (continuous-time consistency models) is theoretically elegant, but its applicability to large-scale text-to-image/video models remains unclear — JVP computation is incompatible with FlashAttention and distributed training.
  • Limitations of Prior Work: sCM suffers from quality degradation in fine-grained generation (error accumulation and blurred details caused by the mode-covering nature of the forward KL divergence). Score/adversarial distillation methods (e.g., DMD2) achieve superior quality but are prone to mode collapse and lack diversity.
  • Key Insight: Forward KL divergence (consistency models) and reverse KL divergence (score distillation) have complementary strengths: the former covers modes, the latter seeks them.
  • Goal: Combine the two divergence types to achieve both high quality and high diversity at large scale.

Method

Overall Architecture

rCM = sCM (forward KL consistency distillation) + DMD (reverse KL score distillation) + infrastructure optimization. Training alternates between optimizing the student model (rCM loss) and the fake score network (flow matching loss).
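The alternating optimization can be sketched as a toy training loop. This is a minimal illustration with tiny linear "networks" and placeholder loss terms standing in for the paper's actual sCM, DMD, and flow-matching objectives; only the alternation structure and the \(\lambda = 0.01\) weight come from the paper.

```python
import torch
import torch.nn as nn

# Toy alternation between the student (rCM loss) and the fake score network
# (flow-matching loss). Models and loss terms are simplified stand-ins.
student = nn.Linear(4, 4)
fake_score = nn.Linear(4, 4)
opt_s = torch.optim.Adam(student.parameters(), lr=1e-3)
opt_f = torch.optim.Adam(fake_score.parameters(), lr=1e-3)
lam = 0.01  # the universal DMD weight reported in the paper

for step in range(5):
    x = torch.randn(8, 4)
    # Phase 1: student update on L_sCM + lam * L_DMD (placeholder terms).
    scm_term = (student(x) - x).pow(2).mean()
    dmd_term = (student(x) - fake_score(x).detach()).pow(2).mean()
    (scm_term + lam * dmd_term).backward()
    opt_s.step(); opt_s.zero_grad()
    # Phase 2: fake score update on student-generated samples
    # (flow-matching stand-in), with gradients blocked into the student.
    gen = student(x).detach()
    fm_term = (fake_score(gen) - gen).pow(2).mean()
    fm_term.backward()
    opt_f.step(); opt_f.zero_grad()
```

The key structural point is that each phase sees the other network only through detached tensors, so the two optimizers never interfere.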

Key Designs

  1. FlashAttention-2 JVP Kernel: A Triton kernel is developed to integrate JVP into the forward pass of FlashAttention-2, supporting both self-attention and cross-attention, and compatible with FSDP and Context Parallelism, enabling sCM training to scale to models with 10B+ parameters.

  2. Score Regularization: The DMD loss is introduced as a long-jump regularizer to complement sCM. The final objective is \(\mathcal{L}_{\text{rCM}}(\theta) = \mathcal{L}_{\text{sCM}}(\theta) + \lambda \mathcal{L}_{\text{DMD}}(\theta)\), with \(\lambda = 0.01\) universal across models and tasks. sCM provides mode-covering behavior (high diversity), while DMD provides mode-seeking behavior (high quality).

  3. Stable Time-Derivative Computation: To address JVP training instability in large models, two strategies are proposed:

     • Semi-continuous time: JVP is used for the spatial components, while a finite difference (\(\Delta t = 10^{-4}\)) approximates the temporal component.
     • High-precision time: FP32 precision is enforced on the time-embedding layers.

  4. Rollout Strategy: The student can perform arbitrary-step sampling; the number of steps \(N \in [1, N_{\max}]\) is randomly selected. The DMD loss is backpropagated only through the final step, with random timesteps ensuring coverage of the full time range.
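The semi-continuous time trick can be illustrated with `torch.func.jvp`: an exact JVP supplies the spatial directional derivative, and a finite difference in \(t\) supplies the temporal part. `F` below is a toy stand-in function, not the paper's network, and the central difference is one reasonable choice of finite-difference scheme.

```python
import torch

# Toy model: F(x, t) = sin(t) * x, so the exact total derivative along
# direction v is sin(t) * v + cos(t) * x.
def F(x, t):
    return torch.sin(t) * x

def total_derivative(x, t, v, dt=1e-4):
    # Spatial part: exact JVP along direction v at fixed t.
    _, jvp_x = torch.func.jvp(lambda x_: F(x_, t), (x,), (v,))
    # Temporal part: central finite difference in t (the paper uses dt = 1e-4).
    dF_dt = (F(x, t + dt) - F(x, t - dt)) / (2 * dt)
    return jvp_x + dF_dt

x = torch.randn(3)
t = torch.tensor(0.5)
v = torch.randn(3)
out = total_derivative(x, t, v)
```

Splitting the derivative this way avoids differentiating through the time-embedding path with JVP, which is where the instability reportedly arises.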

Loss & Training

  • sCM loss (tangent normalization): \(\mathcal{L}_{\text{sCM}} = \mathbb{E}\left[\left\|\mathbf{F}_\theta - \mathbf{F}_{\theta^-} - \frac{\mathbf{g}}{\|\mathbf{g}\|_2 + c}\right\|_2^2\right]\), where \(\mathbf{g}\) is the tangent (time-derivative) term and \(c\) a small constant.
  • DMD loss: guides the student using the discrepancy between fake score and teacher score.
  • The fake score network is trained on student-generated data using a flow matching loss, with alternating optimization.
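The rollout strategy above can be sketched as follows: a random number of sampling steps is drawn, intermediate states are detached, and the DMD-style loss backpropagates only through the final step. `student_step` and the loss are hypothetical stand-ins for illustration.

```python
import random
import torch

# Hypothetical one-step student update; w is the trainable parameter.
def student_step(x, t, w):
    return x * torch.sigmoid(w) + t

w = torch.randn((), requires_grad=True)
N_max = 4
N = random.randint(1, N_max)  # random number of rollout steps in [1, N_max]
x = torch.randn(8)
for i in range(N):
    t = torch.tensor(1.0 - i / N)
    x = student_step(x, t, w)
    if i < N - 1:
        x = x.detach()  # block gradients through all but the final step

dmd_loss = x.pow(2).mean()  # stand-in for the DMD objective on the final sample
dmd_loss.backward()
```

Detaching every step but the last keeps memory and gradient variance bounded regardless of how many rollout steps were sampled.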

Key Experimental Results

Main Results (GenEval T2I)

| Model                         | Params | NFE | Overall | Counting | Position |
|-------------------------------|--------|-----|---------|----------|----------|
| FLUX.1-dev                    | 12B    | 50  | 0.66    | 0.74     | 0.22     |
| Cosmos-Predict2 14B (teacher) | 14B    | 70  | 0.84    | 0.79     | 0.64     |
| Cosmos-Predict2 + DMD2        | 2B     | 4   | 0.80    | 0.70     | 0.57     |
| Cosmos-Predict2 + rCM         | 2B     | 4   | 0.81    | 0.73     | 0.58     |
| Cosmos-Predict2 + rCM         | 14B    | 4   | 0.83    | 0.80     | 0.59     |
| Cosmos-Predict2 + rCM         | 14B    | 1   | 0.82    | 0.84     | 0.49     |

VBench Video Experiments

| Model                  | Params | NFE | Total Score | Throughput (FPS) |
|------------------------|--------|-----|-------------|------------------|
| Wan2.1 14B (teacher)   | 14B    | 100 | 83.58       | 0.18             |
| Wan2.1 + DMD2          | 1.3B   | 4   | 84.56       | 14.6             |
| Wan2.1 + rCM           | 1.3B   | 4   | 84.43       | 14.6             |
| Wan2.1 + rCM           | 14B    | 2   | 85.05       | 8.3              |

Key Findings

  • rCM matches or exceeds DMD2 in quality while significantly outperforming DMD2 in diversity (Figure 1 shows DMD2 outputs converging in object position and pose).
  • The 14B rCM achieves a GenEval score of 0.83 in 4 steps, approaching the teacher's 0.84 in 70 steps.
  • In video generation, rCM achieves near-teacher VBench scores in just 2 steps.
  • \(\lambda=0.01\) achieves the best balance between quality and diversity.
  • Pure sCM exhibits notable quality defects in fine-grained scenarios such as text rendering, which rCM successfully resolves.

Highlights & Insights

  • This is the first work to scale JVP-based continuous-time consistency distillation to 14B parameters and 5-second video generation.
  • The paper offers a unified framework for understanding distillation methods through the complementarity of forward and reverse KL divergences.
  • No GAN tuning or extensive hyperparameter search is required; \(\lambda=0.01\) generalizes across tasks.
  • The diversity advantage of rCM is particularly valuable for scenarios requiring diverse outputs, such as interactive world models.

Limitations & Future Work

  • An additional fake score network introduces extra memory overhead.
  • JVP computation remains slower than standard forward passes, resulting in high training costs.
  • Single-step video generation still shows noticeable quality degradation (VBench drops from 85.05 to 83.02).
  • Extension to autoregressive video diffusion is mentioned only as future work.
Relation to Prior Work

  • sCM and MeanFlow provide the theoretical foundation.
  • DMD/DMD2 provide practical solutions for reverse-KL distillation.
  • The philosophy of jointly combining forward and reverse KL divergences in DDO and DDRL underpins the design of rCM.
  • rCM provides a practical acceleration solution for deploying large-scale visual generation models.

Technical Details

  • TrigFlow noise schedule: \(\alpha_t = \cos(t), \sigma_t = \sin(t)\), interconvertible with rectified flow via SNR matching.
  • The fake score network is trained on student-generated data with a flow matching loss via alternating optimization.
  • Selective Activation Checkpointing (SAC) is employed to reduce memory consumption.
  • The teacher uses CFG, which is jointly distilled into the student.
  • Full-parameter fine-tuning is used (no LoRA), emphasizing the stability of rCM.
  • Experiments cover Cosmos-Predict2 (0.6B/2B/14B T2I) and Wan2.1 (1.3B/14B T2V).
  • Wan2.1 14B with 2-step rCM achieves 8.3 FPS versus the teacher's 0.18 FPS (~46× speedup).
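The TrigFlow schedule and its SNR-matching relation to rectified flow can be checked numerically. The mapping \(t' = \sin t / (\sin t + \cos t)\) below is derived from the two stated schedules (rectified flow taken as \(\alpha' = 1 - t'\), \(\sigma' = t'\)); the paper's exact parameterization may differ.

```python
import math

# TrigFlow schedule: alpha_t = cos(t), sigma_t = sin(t) for t in [0, pi/2].
def trigflow(t):
    return math.cos(t), math.sin(t)

# Matching SNR = (alpha/sigma)^2 against rectified flow (alpha' = 1 - t',
# sigma' = t') gives cot(t) = (1 - t')/t', i.e. t' = sin(t)/(sin(t) + cos(t)).
# Derived here from the stated schedules, not quoted from the paper.
def rf_time(t):
    a, s = trigflow(t)
    return s / (a + s)

for t in (0.1, 0.7, 1.3):
    a, s = trigflow(t)
    assert abs(a * a + s * s - 1.0) < 1e-12  # variance-preserving by construction
    tp = rf_time(t)
    assert abs(a / s - (1 - tp) / tp) < 1e-9  # SNRs agree under the mapping
```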

Rating

  • Novelty: ⭐⭐⭐⭐ The theoretical insight of combining forward and reverse KL divergences is valuable, though individual components are known.
  • Experimental Thoroughness: ⭐⭐⭐⭐⭐ Validation at an unprecedented scale (14B parameters, T2I+T2V, multi-step ablations).
  • Writing Quality: ⭐⭐⭐⭐⭐ Theoretical analysis is clear and engineering details are thorough.
  • Value: ⭐⭐⭐⭐⭐ Addresses the core challenge of accelerating large-scale diffusion models with strong practical utility.