
Autoregressive Distillation of Diffusion Transformers

Conference: CVPR 2025 (Oral)
arXiv: 2504.11295
Code: None
Area: Diffusion Models
Keywords: Diffusion model distillation, autoregressive distillation, exposure bias, ODE trajectory history, few-step generation

TL;DR

Proposes Autoregressive Distillation (ARD), which predicts future steps from the history of the ODE trajectory rather than from the current denoised sample alone. By modifying the teacher transformer with token-wise time embeddings and a block-wise causal attention mask, it reaches an FID of 1.84 with 4 steps on ImageNet-256 at only 1.1% extra FLOPs.

Background & Motivation

Background: Diffusion Transformers (DiTs) have shown powerful capabilities in high-fidelity image generation and high-resolution scaling, but their iterative sampling process is extremely computationally expensive (typically requiring tens to hundreds of steps). Distillation methods compress the solution of the probability flow ODE into a few-step student model, serving as a mainstream approach to accelerate sampling.

Limitations of Prior Work: Existing distillation methods (such as Progressive Distillation and Consistency Models) use the most recently denoised sample as the input for the student. This introduces the exposure bias problem—during training, the student observes "clean" trajectories generated by the teacher, whereas during inference, it must generate sequentially based on its own previous (potentially erroneous) outputs, leading to accumulated errors.

Key Challenge: Relying solely on the current single-step denoising result as input has two fundamental limitations: (1) It is susceptible to accumulated errors, causing exposure bias; (2) It discards coarse-grained historical information within the ODE trajectory, which is highly valuable for predicting future steps.

Goal: Design a distillation method that utilizes the history of the ODE trajectory, simultaneously mitigating exposure bias and providing a richer source of coarse-grained information.

Key Insight: Reframe distillation from an autoregressive perspective—the ODE trajectory is an ordered sequence where predictions at each step can "look back" at past history. This is analogous to leveraging context history in language models to reduce generation errors.

Core Idea: Inject the ODE trajectory history into the teacher transformer, distinguishing inputs from different timesteps via token-wise time embeddings and ensuring the correct direction of information flow using block-wise causal attention masks, thereby letting historical steps provide additional coarse-grained guidance for the current prediction.
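The difference between the two input conventions can be written out explicitly. The notation below is ours, not necessarily the paper's: \(t_0 > t_1 > \dots > t_N\) are the few sampling timesteps (\(x_{t_0}\) is pure noise), \(G_\theta\) is the student, and \(x_{t_i}\) is the current trajectory point.

```latex
\begin{aligned}
\text{single-input distillation:}\quad
  & \hat{x}_{t_{i+1}} = G_\theta\big(x_{t_i};\ t_i\big) \\
\text{ARD:}\quad
  & \hat{x}_{t_{i+1}} = G_\theta\big(x_{t_0}, x_{t_1}, \dots, x_{t_i};\ t_0, t_1, \dots, t_i\big)
\end{aligned}
```

At inference every \(x_{t_j}\) with \(j \ge 1\) is the student's own earlier output, so the single-input form depends entirely on the latest (possibly erroneous) point, while the ARD form can still draw on the earlier ones.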

Method

Overall Architecture

ARD follows the standard distillation workflow but expands the student's input from a single denoised sample to the sequence of ODE trajectory points produced so far. Starting from a pre-trained diffusion transformer as the teacher, ARD modifies its architecture so it accepts inputs from multiple timesteps, and trains the resulting student to exploit this history for more accurate few-step predictions. The input is the current noisy sample together with the earlier trajectory points; the output is the denoised prediction for the next step.
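A minimal sketch of the few-step sampling loop this implies, assuming a hypothetical `student(history, times)` callable that returns the next trajectory point. Vectors are plain Python lists so the control flow stays visible. This is not the authors' code: the real student is the modified DiT described under Key Designs.

```python
from typing import Callable, List, Sequence

Vector = List[float]
# Hypothetical interface: the student receives every trajectory point so far together with
# the timestep of each point, and returns its prediction for the next point. With a fixed
# sampling schedule the target timestep is implied by len(history).
Student = Callable[[Sequence[Vector], Sequence[float]], Vector]


def ard_sample(student: Student, noise: Vector, timesteps: Sequence[float]) -> Vector:
    """Run len(timesteps) - 1 student calls, each conditioned on the full history."""
    history: List[Vector] = [noise]  # x_{t_0}: the starting noise
    times: List[float] = [timesteps[0]]
    for t_next in timesteps[1:]:
        x_next = student(history, times)  # at inference, history holds the student's own outputs
        history.append(x_next)
        times.append(t_next)
    return history[-1]


# Toy stand-in for the network: halves the latest point and logs how much history it saw.
seen: List[int] = []


def toy_student(history: Sequence[Vector], times: Sequence[float]) -> Vector:
    seen.append(len(history))
    return [0.5 * v for v in history[-1]]


sample = ard_sample(toy_student, [8.0, -4.0], [1.0, 0.75, 0.5, 0.25, 0.0])
print(sample, seen)  # [0.5, -0.25] [1, 2, 3, 4]
```

The `seen` log shows the key difference from single-input distillation: the k-th call sees k trajectory points instead of one.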

Key Designs

  1. Token-wise Time Embedding:

    • Function: Allows the model to distinguish input tokens originating from different timesteps.
    • Mechanism: In standard DiT, a single global time embedding conditions all tokens of a sample. ARD instead gives every token the time embedding of the step it comes from: tokens of the current step \(x_{t_i}\) and of the historical steps \(x_{t_{i-1}}, x_{t_{i-2}}, \dots\) are each conditioned on their own timestep and then processed jointly by the transformer. The model therefore knows not only the current timestep but also which timestep each input token belongs to.
    • Design Motivation: Without distinguishing between timesteps, the model cannot differentiate the semantic variations between current and historical inputs. Token-wise time embeddings enable the model to automatically learn how to process information from different timesteps distinctively—extracting details from recent steps and structures from earlier steps.
  2. Block-wise Causal Attention Mask:

    • Function: Controls the direction of historical information flow and prevents information leakage.
    • Mechanism: In the transformer attention computation, a causal mask is introduced to ensure that tokens of the current timestep can only attend to themselves and tokens of earlier timesteps, preventing them from "peeking" into the future. The masking operates at a block level—all tokens within the same timestep can attend to each other (standard self-attention), but a strict causal order is maintained across different timesteps. This is similar to causal attention in language models, but the granularity of operation is the token block instead of individual tokens.
    • Design Motivation: The ODE trajectory possesses a natural temporal order—earlier steps contain coarse-grained structures, while later steps contain fine-grained details. Causal attention ensures that coarse-grained information is only used to guide refinement, avoiding training-inference inconsistency caused by reverse information leakage.
  3. Lower-Layer History Injection:

    • Function: Injects historical information only in the lower layers of the transformer, while higher layers process only the current step.
    • Mechanism: Historical trajectories are not used in every transformer layer. The authors find that keeping history tokens only in the bottom few layers, and only the current-step tokens in the top layers, preserves the performance gains while keeping the model efficient. The lower layers extract coarse-grained structural information from the history and merge it into the current-step representation, while the upper layers focus on fine-grained generation for the current step.
    • Design Motivation: (1) Injecting historical tokens into all layers significantly increases computational cost (as attention complexity scales quadratically with the number of tokens), whereas lower-layer injection increases FLOPs by only 1.1%; (2) Intuitively, coarse-grained global structure should be processed in early layers, while fine-grained generation occurs in subsequent layers—consistent with the paradigm in CNNs where lower layers capture general features and upper layers capture task-specific features.
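The three designs above can be sketched as index bookkeeping, independent of any deep-learning framework. Assumptions: tokens are laid out block by block in trajectory order (oldest step first); `True` in the mask means "query may attend to key"; the per-token timestep list stands in for what a real DiT would pass through its time-embedding and modulation layers per token; and `inject_layers` is a hypothetical name for the number of lower layers that keep history tokens. These helpers are illustrations, not the paper's implementation.

```python
from typing import List, Sequence


def token_timesteps(tokens_per_step: int, step_times: Sequence[float]) -> List[float]:
    """Token-wise time labels: each token carries the timestep of the block it came from."""
    return [t for t in step_times for _ in range(tokens_per_step)]


def block_causal_mask(tokens_per_step: int, num_steps: int) -> List[List[bool]]:
    """mask[q][k] is True when query token q may attend to key token k.

    Full attention inside a block (same timestep); across blocks a query only
    sees blocks from its own or earlier timesteps.
    """
    n = tokens_per_step * num_steps
    block = [i // tokens_per_step for i in range(n)]  # block (step) index of each token
    return [[block[k] <= block[q] for k in range(n)] for q in range(n)]


def blocks_at_layer(layer: int, inject_layers: int, num_steps: int) -> List[int]:
    """Lower-layer history injection: the first `inject_layers` layers process every
    block; deeper layers keep only the current (last) block."""
    if layer < inject_layers:
        return list(range(num_steps))
    return [num_steps - 1]


# 3 trajectory steps with 2 tokens each (a real DiT would have hundreds per step).
print(token_timesteps(2, [1.0, 0.5, 0.25]))  # [1.0, 1.0, 0.5, 0.5, 0.25, 0.25]
for row in block_causal_mask(2, 3):
    print("".join("x" if allowed else "." for allowed in row))
print(blocks_at_layer(0, 2, 3), blocks_at_layer(5, 2, 3))  # [0, 1, 2] [2]
```

The printed mask shows square blocks on the diagonal (self-attention within a step) and a lower-triangular pattern between blocks. Because the last block is the current step, dropping the other blocks in the upper layers leaves exactly the current-step tokens, which have already absorbed the history through attention in the lower layers.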

Loss & Training

ARD adopts the standard distillation loss: the distance between the student's prediction and the teacher's output at the corresponding ODE step. During training, the history inputs come from the teacher's ODE trajectory (analogous to teacher forcing); during inference, the student conditions on its own previously predicted trajectory, and the extra history is what makes it less sensitive to an error in any single step. Key training details: the student is initialized from the pre-trained DiT, and the newly added time-embedding parameters are trained from scratch; the causal attention mask is a fixed pattern with no learnable parameters.
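A toy sketch of the teacher-forced training objective, under stated assumptions: mean squared error stands in for the unspecified "distance", every step of one teacher trajectory contributes equally, and `student(history, times)` is a hypothetical callable that returns the next trajectory point. The only point it is meant to show is where the history comes from: the teacher's states during training, rather than the student's own outputs as at inference.

```python
from typing import Callable, List, Sequence

Vector = List[float]
Student = Callable[[Sequence[Vector], Sequence[float]], Vector]  # hypothetical interface


def mse(a: Vector, b: Vector) -> float:
    return sum((x - y) ** 2 for x, y in zip(a, b)) / len(a)


def teacher_forced_loss(
    student: Student, teacher_traj: Sequence[Vector], times: Sequence[float]
) -> float:
    """Average per-step loss over one teacher ODE trajectory.

    teacher_traj[i] is the teacher's state at times[i] (index 0 is the noise). At step i the
    student is given the teacher's states 0..i as history (teacher forcing) and is regressed
    onto the teacher's state i + 1.
    """
    losses = []
    for i in range(len(teacher_traj) - 1):
        pred = student(teacher_traj[: i + 1], times[: i + 1])
        losses.append(mse(pred, teacher_traj[i + 1]))
    return sum(losses) / len(losses)


def halve_latest(history: Sequence[Vector], times: Sequence[float]) -> Vector:
    return [0.5 * v for v in history[-1]]  # toy student, ignores all but the latest point


print(teacher_forced_loss(halve_latest, [[8.0], [4.0], [2.0]], [1.0, 0.5, 0.0]))  # 0.0
print(teacher_forced_loss(halve_latest, [[8.0], [4.0], [3.0]], [1.0, 0.5, 0.0]))  # 0.5
```

In the second call the step-1 input is the teacher's state `[4.0]`, regardless of what the student predicted at step 0; at inference that slot would hold the student's own prediction.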

Key Experimental Results

Main Results (ImageNet-256 class-conditional generation)

Method                              | Steps | FID ↓ | FID degradation (vs. teacher)
Teacher (DDPM, 250 steps)           | 250   | ~1.5  | -
Progressive Distillation            | 4     | ~3.0  | ~1.5
Consistency Distillation            | 4     | ~2.5  | ~1.0
ARD                                 | 4     | 1.84  | ~0.3
Baseline distillation methods (avg) | 4     | -     | ~1.5
ARD vs. baselines                   | 4     | -     | 5× less FID degradation

Ablation Study

Configuration                     | FID ↓ | Extra FLOPs | Description
Baseline (no history)             | ~2.5  | 0%          | Standard distillation
Full-layer history injection      | ~1.82 | ~8%         | Best performance but heavy overhead
Lower-layer injection (ARD)       | 1.84  | 1.1%        | Best performance-efficiency trade-off
Without causal mask               | ~2.1  | 1.1%        | Information leakage leads to degradation
Without token-wise time embedding | ~2.2  | 1.1%        | Unable to distinguish historical steps
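The relative numbers quoted in the findings can be read directly off this ablation table. A small sketch that recomputes each configuration's FID and extra-FLOPs difference relative to ARD's lower-layer setting; the values are copied from the table, and entries marked "~" there are approximate, so the deltas are too.

```python
from typing import Dict, Tuple

# (FID, extra FLOPs in %) per configuration, copied from the ablation table.
ABLATION: Dict[str, Tuple[float, float]] = {
    "baseline (no history)": (2.5, 0.0),
    "full-layer history injection": (1.82, 8.0),
    "lower-layer injection (ARD)": (1.84, 1.1),
    "without causal mask": (2.1, 1.1),
    "without token-wise time embedding": (2.2, 1.1),
}


def deltas_vs(reference: str) -> Dict[str, Tuple[float, float]]:
    """FID and FLOPs differences of every configuration relative to `reference`."""
    ref_fid, ref_flops = ABLATION[reference]
    return {
        name: (round(fid - ref_fid, 2), round(flops - ref_flops, 1))
        for name, (fid, flops) in ABLATION.items()
    }


for name, (d_fid, d_flops) in deltas_vs("lower-layer injection (ARD)").items():
    print(f"{name:34s} FID {d_fid:+.2f}   FLOPs {d_flops:+.1f}%")
```

Removing the causal mask costs about +0.26 FID and removing token-wise time embeddings about +0.36, while full-layer injection gains 0.02 FID for roughly 6.9 extra percentage points of FLOPs.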

Key Findings

  • 5x Reduction in FID Degradation: Compared with baseline distillation methods, ARD reduces the FID degradation (the gap to the teacher) roughly fivefold under 4-step sampling, suggesting that historical information effectively mitigates exposure bias.
  • Minimal Extra Overhead: The lower-layer injection strategy increases FLOPs by only 1.1%, representing an almost "free" performance boost.
  • Equally Effective on T2I: On the 1024-resolution text-to-image task, ARD outperforms publicly available distilled models in prompt adherence score with minimal FID degradation.
  • Causal Mask and Time Embedding Are Both Indispensable: Removing either component clearly degrades performance (FID rises by about 0.26 without the causal mask and about 0.36 without token-wise time embeddings), indicating that proper control over the information flow is key to ARD's success.
  • Lower-layer Injection ≈ Full-layer Injection: The two differ by only 0.02 FID, while lower-layer injection costs about 7 percentage points fewer FLOPs (1.1% vs. ~8%), indicating that the value of historical information is captured mainly during lower-layer feature extraction.

Highlights & Insights

  • Redefining Distillation from an Autoregressive Perspective: Treating the ODE trajectory as a sequence and processing it autoregressively is a clever shift in perspective. Using the teacher's trajectory during training (akin to teacher forcing) and the student's own history during inference (akin to autoregressive decoding) closely mirrors the sequence-generation paradigm. The idea could plausibly transfer to other generative models based on iterative refinement.
  • Highly Efficient Lower-layer Injection Design: The insight that coarse-grained information only needs to enter the lower layers is highly practical, and may apply beyond diffusion distillation to any transformer that must fuse multi-resolution conditioning information.
  • A New Approach to Exposure Bias: Traditional methods addressing exposure bias (e.g., scheduled sampling, noise injection) often require modifying training objectives or data distributions. ARD fundamentally reduces dependence on single-step prediction accuracy by introducing historical inputs, representing a more elegant solution.

Limitations & Future Work

  • Modification of Teacher Architecture Required: ARD requires altering the transformer structure (injecting time embeddings and attention masks) and cannot be directly applied plug-and-play to arbitrary pre-trained models.
  • Limited History Length: So far, only short histories (mainly 1-2 previous steps) are used. Longer histories might yield further improvements but would add complexity.
  • Validated Only on Class-Conditional and T2I: The effectiveness in other diffusion model applications like video and 3D generation remains unexplored.
  • Future Directions: Exploring adaptive history length selection; integrating ARD with other distillation paradigms such as consistency models; extending to video and 3D generation tasks.

Related Work & Comparison

  • vs. Progressive Distillation: PD progressively halves the number of sampling steps but feeds the student only the current denoised sample, which leads to significant FID degradation. ARD cuts this degradation roughly fivefold by adding historical information, suggesting that missing information is a core bottleneck of PD.
  • vs. Consistency Models: CM achieves single-step generation via consistency constraints, but the gap to the teacher remains substantial. ARD retains the flexibility of multi-step sampling and achieves superior quality with 4 steps.
  • vs. InstaFlow/Rectified Flow Distillation: These methods perform few-step distillation without exploiting trajectory history. ARD's autoregressive idea could in principle be layered on top of them, although it requires the architectural changes described above rather than being plug-and-play.

Rating

  • Novelty: ⭐⭐⭐⭐⭐ Bringing an autoregressive view to diffusion distillation is a fresh perspective, and the CVPR Oral is well deserved.
  • Experimental Thoroughness: ⭐⭐⭐⭐ Validated on both ImageNet-256 and T2I with thorough ablation, but lacks comparison with the latest consistency model variants.
  • Writing Quality: ⭐⭐⭐⭐ Clearly motivated, with a concise and elegant method description.
  • Value: ⭐⭐⭐⭐⭐ Achieving a 5x reduction in FID degradation for only 1.1% extra FLOPs is highly practical; the autoregressive distillation framework offers extensive extensibility.