KLASS: KL-Guided Fast Inference in Masked Diffusion Models¶
Conference: NeurIPS 2025 arXiv: 2511.05664 Code: GitHub Area: Image Generation / Discrete Diffusion Model Sampling Keywords: Masked Diffusion Models, KL Divergence, Accelerated Sampling, Token Stability, Parallel Decoding
TL;DR¶
This paper proposes KLASS (KL-Adaptive Stability Sampling), a training-free sampling method that leverages token-level KL divergence and confidence scores to identify stable tokens for parallel decoding, achieving up to 2.78× speedup on masked diffusion models without sacrificing—and in many cases improving—generation quality.
Background & Motivation¶
Background: Masked diffusion models (MDMs) have demonstrated competitive performance across language generation, image generation, and molecular generation tasks, with large-scale models such as LLaDA and DREAM also exhibiting reasoning capabilities.
Limitations of Prior Work: The sampling process of MDMs relies on iterative unmasking, typically employing fixed Top-k or random sampling strategies that unmask only a small number of tokens per step, resulting in slow inference.
Key Challenge: A fundamental trade-off exists between accelerated sampling (unmasking more tokens per step) and generation quality—prematurely unmasking unstable tokens leads to degraded accuracy.
Goal: To enable safe parallel unmasking of multiple tokens to accelerate generation, without requiring additional training or external planners.
Key Insight: The model's own internal signals—KL divergence and confidence scores—are used to determine whether a token is "stable" and thus safe to unmask early.
Core Idea: Tokens exhibiting low KL divergence and high confidence are considered "stable" and can be unmasked in parallel; incorrectly predicted tokens cannot maintain dynamic stability throughout the diffusion process.
Method¶
Overall Architecture¶
At each step of the reverse sampling process in a standard masked diffusion model, KLASS computes two metrics for all masked tokens—a confidence score and a KL score—and then adaptively selects tokens satisfying the stability criteria for batch unmasking.
Key Designs¶
- Confidence Score: Defined as the maximum probability in the model's predicted distribution: \(\text{conf}_t^i = \max_v p_t^i(v)\). High confidence indicates greater certainty in the model's estimate for that token.
- KL Score: Defined as the KL divergence between predicted distributions at adjacent time steps: \(d_t^i = D_{\text{KL}}(p_t^i \| p_{t+1}^i)\). A low KL score indicates that the model's estimate for the token is temporally consistent and stable; empirically, correctly predicted tokens consistently exhibit lower KL scores.
- Stable Token Selection: Given history length \(n\), KL threshold \(\epsilon_{\text{KL}}\), and confidence threshold \(\tau\), the stable token set is defined as \(S_t = \{i \mid \forall k \in \{1,\dots,n\},\ D_{\text{KL}}(p_{t+k-1}^i \| p_{t+k}^i) < \epsilon_{\text{KL}} \ \wedge\ \text{conf}_t^i > \tau\}\).
- Unmasking Rule: If stable tokens exist (\(S_t \neq \emptyset\)), all are unmasked simultaneously; otherwise, the method falls back to Top-\(u\) confidence-based unmasking.
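The selection and fallback logic above can be sketched in a few lines of NumPy. This is an illustrative reconstruction, not the authors' code: the function name `klass_select`, its argument names, and the history representation (a list of per-step, per-position probability arrays) are assumptions made for the sketch.

```python
import numpy as np

def kl_div(p, q, eps=1e-12):
    """Per-position KL divergence D_KL(p || q) over the vocabulary axis."""
    p = np.clip(p, eps, 1.0)
    q = np.clip(q, eps, 1.0)
    return np.sum(p * np.log(p / q), axis=-1)

def klass_select(prob_history, masked, tau=0.9, eps_kl=0.01, n=2, u=1):
    """Select positions to unmask at the current reverse-diffusion step.

    prob_history: list of [L, V] arrays, oldest first; the last entry is the
    current step's predicted distribution for each of the L positions.
    masked: boolean [L] array, True where the token is still masked.
    Returns an array of position indices to unmask this step.
    """
    p_now = prob_history[-1]
    conf = p_now.max(axis=-1)              # confidence score per position

    # Stable = KL score below eps_kl for the last n consecutive step pairs
    # AND current confidence above tau (the criterion S_t above).
    stable = conf > tau
    for k in range(1, n + 1):
        d = kl_div(prob_history[-k], prob_history[-k - 1])
        stable &= d < eps_kl
    stable &= masked

    if stable.any():                       # unmask all stable tokens in parallel
        return np.flatnonzero(stable)
    # Fallback: Top-u confidence-based unmasking among masked positions
    cand = np.flatnonzero(masked)
    return cand[np.argsort(conf[cand])[::-1][:u]]
```

Because `prob_history` reuses the per-step distributions the sampler already produces, the extra work is only the KL sums, consistent with the paper's claim that no additional forward passes are needed.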
Theoretical Support¶
Proposition 5.3: For a well-trained model, if a token's current prediction is incorrect, its predicted distribution will necessarily undergo significant change as context is progressively revealed (the average KL divergence is lower-bounded). This implies that incorrectly predicted tokens cannot maintain dynamic stability, and KLASS therefore avoids errors by deferring the unmasking of unstable tokens.
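The intuition behind Proposition 5.3 is easy to see numerically: a token whose argmax flips between adjacent steps incurs a large KL score, while a token whose distribution barely moves stays far below any reasonable threshold. The distributions below are invented purely for illustration:

```python
import numpy as np

def kl(p, q, eps=1e-12):
    """Scalar KL divergence D_KL(p || q) between two categorical distributions."""
    p, q = np.clip(p, eps, 1.0), np.clip(q, eps, 1.0)
    return float(np.sum(p * np.log(p / q)))

# A "stable" token: the distribution barely moves between adjacent steps.
stable_prev = np.array([0.90, 0.05, 0.03, 0.02])
stable_now  = np.array([0.92, 0.04, 0.02, 0.02])

# An "unstable" token: newly revealed context flips the argmax this step.
unstable_prev = np.array([0.55, 0.30, 0.10, 0.05])
unstable_now  = np.array([0.15, 0.70, 0.10, 0.05])

print(kl(stable_now, stable_prev))      # small (well under a 0.01 threshold)
print(kl(unstable_now, unstable_prev))  # large (an order of magnitude bigger)
```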
Computational Overhead¶
KL computation is a lightweight post-processing operation requiring no additional forward passes. Experiments show memory overhead of less than 1.57% and latency overhead of less than 0.21%.
Key Experimental Results¶
Main Results: Reasoning Tasks¶
| Method | Model | MATH Acc↑ | Steps↓ | GSM8K Acc↑ | Steps↓ | HumanEval Acc↑ | Steps↓ |
|---|---|---|---|---|---|---|---|
| Top-1 | LLaDA | 31.4 | 256 | 75.13 | 256 | 39.63 | 256 |
| Confidence | LLaDA | 31.6 | 96.5 | 75.21 | 74.4 | 37.80 | 54.4 |
| KLASS | LLaDA | 33.8 | 128.6 | 76.50 | 98.6 | 40.85 | 92.0 |
| Top-1 | DREAM | 37.97 | 256 | 79.55 | 256 | 58.53 | 256 |
| Confidence | DREAM | 41.80 | 95.1 | 73.67 | 74.8 | 50.00 | 52.5 |
| KLASS | DREAM | 43.20 | 149.7 | 79.43 | 155.7 | 59.35 | 74.9 |
KLASS outperforms standard Top-1 decoding on nearly all tasks while reducing the number of steps by 40–70%, achieving a wall-clock speedup of up to 2.78×.
Text Generation¶
| Method | MAUVE↑ | LLaMA2 PPL↓ | LLaMA3 PPL↓ | GPT-2 PPL↓ |
|---|---|---|---|---|
| MDLM | 0.115 | 30.88 | 54.15 | 51.78 |
| KLASS | 0.179 | 26.94 | 49.19 | 45.50 |
Image Generation (MMaDA)¶
| Method | Steps | FID↓ | IS↑ |
|---|---|---|---|
| Confidence | 16 | 34.48 | 75.72 |
| KLASS | 16 | 30.48 | 93.07 |
Ablation Study¶
| Unmasking Strategy | MATH Acc↑ | Steps↓ |
|---|---|---|
| Single (conf) | 31.2 | 256 |
| Single (KL) | 29.0 | 256 |
| Parallel (KLASS) | 33.8 | 128.6 |
Parallel unmasking of stable tokens outperforms single-token selection strategies while requiring fewer steps.
Key Findings¶
- The combination of KL score and confidence is essential—neither criterion alone is sufficient.
- Correctly predicted tokens consistently exhibit lower KL scores (empirically verified in Figure 1b).
- KLASS proves effective across text, image, and molecular modalities, demonstrating its generality.
Highlights & Insights¶
- Training-free universal sampler: KLASS does not modify model parameters and relies purely on statistical signals derived from the model's own inference process.
- Theory-practice alignment: Proposition 5.3 provides a theoretical justification for why KL divergence can discriminate between correct and incorrect tokens.
- Cross-modal generality: A single method applies to language, image, and molecular generation—a rare property among diffusion model samplers.
- Simultaneous acceleration and quality improvement: KLASS not only speeds up inference but also improves accuracy on multiple benchmarks, challenging the conventional speed–quality trade-off.
Limitations & Future Work¶
- Hyperparameters (KL threshold, confidence threshold) require tuning across different models and tasks; although the authors claim low sensitivity, grid search is still necessary.
- Validation is currently limited to masked diffusion models and has not been extended to other discrete diffusion formulations (e.g., uniform noise schedules).
- The choice of history length \(n\) affects performance, with longer histories incurring greater caching overhead.
- Performance on larger-scale models remains to be verified.
Related Work & Insights¶
- Compared to concurrent work such as Fast-dLLM, Dimple, and EB-Sampler, KLASS distinguishes itself by incorporating KL divergence as a measure of "dynamic stability" in addition to confidence scores.
- The fine-grained, token-level monitoring of inference quality can inspire sampling strategy design in other discrete generative models.
- The KL-based "prior verification" mechanism has potential for extension to speculative decoding in autoregressive models.
Rating¶
- Novelty: ⭐⭐⭐⭐ The dual-criterion approach of combining KL divergence and confidence for stable token selection is novel, though the core idea is relatively straightforward.
- Experimental Thoroughness: ⭐⭐⭐⭐⭐ Covers four modalities—text reasoning, text generation, image generation, and molecular generation—with comprehensive ablations.
- Writing Quality: ⭐⭐⭐⭐⭐ Well-structured with concise and compelling theoretical sections and intuitive figures.
- Value: ⭐⭐⭐⭐ A plug-and-play acceleration method with high practical utility, though applicability is currently limited to the masked diffusion model ecosystem.