Distillation of Discrete Diffusion through Dimensional Correlations (Di4C)
Conference: ICML 2025
arXiv: 2410.08709
Code: sony/di4c
Area: Diffusion Models / Discrete Generative Models
Keywords: Discrete Diffusion Models, Knowledge Distillation, Dimensional Correlation, Mixture Models, Few-step Sampling
TL;DR
This paper proposes the Di4C method, which captures correlations between dimensions through a "mixture" model. Combined with a consistency loss function, it distills multi-step discrete diffusion models into few-step models, demonstrating effectiveness across both image and language tasks.
Background & Motivation
Discrete diffusion models face unique challenges in reducing sampling steps that do not exist in continuous models:
Dimensional Independence Limitation: Traditional discrete diffusion models use a "product model", in which each dimension is sampled independently given \(x_t\), i.e., \(p_{s|t}(x_s|x_t) = \prod_{d=1}^D p_{s|t}^d(x_s^d|x_t)\). This keeps the model scalable (the output size drops from \(O(|S|^D)\) for the full joint distribution to \(O(D|S|)\)), but it neglects the dependencies between dimensions.
Fundamental Obstacle to Few-step Sampling: In extreme cases (such as single-step denoising in masked diffusion), product models are completely incapable of approximating complex joint distributions. Theorem 1 of this paper proves that the total variation (TV) distance of an \(N\)-step product model has an \(\Omega(1/N)\) lower bound, indicating that modeling dimensional correlation is essential to reducing the number of steps.
Limitations of Training Losses: The loss functions of continuous-time score-based discrete diffusion depend only on the marginal distributions of individual dimensions. Consequently, even a model with the capacity to represent dimensional correlations cannot learn them from these losses.
Key Insight: The composition of multi-step product models can implicitly capture dimensional correlation, even though each individual step is dimension-independent.
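The contrast between a single product step and a composition of product steps can be checked on a toy case. The sketch below is an illustration only, not the paper's construction: it assumes two binary dimensions, a perfectly correlated target, and a masked-diffusion-style schedule that unmasks one dimension per step. All names (`TARGET`, `one_step_product`, `two_step_product`) are hypothetical.

```python
from itertools import product

# Toy target over two binary dimensions: perfectly correlated, x^1 == x^2.
TARGET = {(0, 0): 0.5, (1, 1): 0.5}

def tv_distance(p, q):
    """Total variation distance between two pmfs given as dicts."""
    keys = set(p) | set(q)
    return 0.5 * sum(abs(p.get(k, 0.0) - q.get(k, 0.0)) for k in keys)

def marginal(joint, dim):
    """Marginal pmf of one dimension of a joint pmf."""
    out = {0: 0.0, 1: 0.0}
    for x, p in joint.items():
        out[x[dim]] += p
    return out

def one_step_product(joint):
    """Unmask both dimensions at once; each dimension only sees (MASK, MASK)."""
    m0, m1 = marginal(joint, 0), marginal(joint, 1)
    return {(a, b): m0[a] * m1[b] for a, b in product((0, 1), repeat=2)}

def two_step_product(joint):
    """Step 1 unmasks dim 0; step 2 unmasks dim 1 given x_t = (a, MASK)."""
    m0 = marginal(joint, 0)
    out = {}
    for a, b in product((0, 1), repeat=2):
        # The step-2 factor for dim 1 is the exact conditional p(x^2 | x^1 = a).
        cond = joint.get((a, b), 0.0) / m0[a] if m0[a] > 0 else 0.0
        out[(a, b)] = m0[a] * cond
    return out

print(tv_distance(TARGET, one_step_product(TARGET)))  # 0.5
print(tv_distance(TARGET, two_step_product(TARGET)))  # 0.0
```

Each of the two steps is still a product over dimensions (one factor samples a value, the other copies or keeps the mask), yet their composition reproduces the correlated target exactly, while the single product step is stuck at a TV distance of 0.5.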
Method
Overall Architecture
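The core idea of Di4C is to distill a multi-step teacher (a product model, parameters \(\psi\)) into a few-step student (a mixture model, parameters \(\theta\)), so that one student step matches a composition of teacher steps:
$$p_{0|t_n}^\theta \approx p_{0|t_1}^\psi \circ \cdots \circ p_{t_{n-1}|t_n}^\psi$$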
Key Designs
1. Mixture Model (Section 3.2)
To capture dimensional correlations, a mixture model is proposed:
$$p_{s|t}^\theta(x_s|x_t) = \mathbb{E}_\lambda\left[p_{s|t}^\theta(x_s|x_t;\lambda)\right], \quad \text{where } p_{s|t}^\theta(x_s|x_t;\lambda) = \prod_{d=1}^D p_{s|t}^{\theta,d}(x_s^d|x_t;\lambda)$$
- Each component conditioned on \(\lambda\) remains a product model (dimension-independent), but taking the expectation over \(\lambda\) introduces dimensional correlation.
- Proposition 1 proves the universal approximation property of the mixture model: any discrete distribution can be represented as a mixture of product distributions.
- It introduces almost no additional computational overhead during inference, requiring only sampling and injecting \(\lambda\).
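A small numerical illustration of the points above, as a sketch under stated assumptions: \(\lambda\) is taken to be a discrete index with two values, and the component tables are made up for the example (the paper's \(\lambda\) and network parameterization may differ). The helper names are hypothetical.

```python
from itertools import product
import random

def product_pmf(per_dim):
    """Joint pmf of a product model; per_dim[d][v] = P(x^d = v)."""
    joint = {}
    for x in product(range(len(per_dim[0])), repeat=len(per_dim)):
        p = 1.0
        for d, v in enumerate(x):
            p *= per_dim[d][v]
        joint[x] = p
    return joint

def mixture_pmf(lam_weights, components):
    """E_lambda[product model]: weighted sum of the per-lambda product pmfs."""
    joint = {}
    for w, per_dim in zip(lam_weights, components):
        for x, p in product_pmf(per_dim).items():
            joint[x] = joint.get(x, 0.0) + w * p
    return joint

def sample_mixture(lam_weights, components, rng):
    """Draw lambda once, then sample every dimension independently."""
    k = rng.choices(range(len(lam_weights)), weights=lam_weights)[0]
    return tuple(rng.choices(range(len(row)), weights=row)[0]
                 for row in components[k])

D = 3
LAM_WEIGHTS = [0.5, 0.5]
COMPONENTS = [
    [[0.9, 0.1]] * D,  # lambda = 0: every dimension prefers value 0
    [[0.1, 0.9]] * D,  # lambda = 1: every dimension prefers value 1
]
joint = mixture_pmf(LAM_WEIGHTS, COMPONENTS)
p_all_zero = joint[(0, 0, 0)]   # probability of (0, 0, 0) under the mixture
p_if_independent = 0.5 ** D     # product of the (uniform) per-dimension marginals
print(p_all_zero, p_if_independent)
print(sample_mixture(LAM_WEIGHTS, COMPONENTS, random.Random(0)))
```

Every per-dimension marginal here is uniform, so any single product model matching those marginals assigns 0.125 to `(0, 0, 0)`; the mixture assigns it a much larger probability because the dimensions share the same \(\lambda\). Sampling costs one extra draw of \(\lambda\) on top of the usual per-dimension sampling.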
2. Distillation Loss Functions (Section 3.3)
Distillation Loss: approximating the teacher at a small timestep \(\delta\):
$$\mathcal{L}_{\text{distil}}(\theta;\psi,r_\delta,\delta) = \mathbb{E}_{x_\delta \sim r_\delta}\left[D_{\text{KL}}\left(p_{0|\delta}^\psi(\cdot|x_\delta) \,\|\, p_{0|\delta}^\theta(\cdot|x_\delta)\right)\right]$$
Consistency Loss: learning the dimensional correlations manifested by the composition of the teacher:
$$\mathcal{L}_{\text{consis}}(\theta;\psi,r_t,s,u,t) = \mathbb{E}_{x_t \sim r_t}\left[D_{\text{KL}}\left(p_{s|u}^\theta \circ p_{u|t}^\psi(\cdot|x_t) \,\|\, p_{s|t}^\theta(\cdot|x_t)\right)\right]$$
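- The key to the consistency loss is that its target combines both denoisers: a teacher step from \(t\) to \(u\) followed by a student step from \(u\) to \(s\), which the direct student step from \(t\) to \(s\) must match.
- In practice the loss is estimated by Monte Carlo sampling, optionally with control variates to reduce variance.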
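The two losses can be written out exactly when the state space is small enough to enumerate. The following is a minimal sketch under that assumption, with each kernel stored as a row-stochastic matrix indexed `[x_from][x_to]`; the function names and the example matrices are hypothetical, and realistic state spaces require the sampling-based estimates mentioned above instead.

```python
import math

def compose(p_su, p_ut):
    """(p_su o p_ut)[x_t][x_s] = sum over x_u of p_ut[x_t][x_u] * p_su[x_u][x_s]."""
    return [
        [sum(row_t[k] * p_su[k][j] for k in range(len(p_su)))
         for j in range(len(p_su[0]))]
        for row_t in p_ut
    ]

def kl(p, q, floor=1e-12):
    """KL(p || q) for two pmfs given as lists; q is floored to avoid log(0)."""
    return sum(pi * math.log(pi / max(qi, floor)) for pi, qi in zip(p, q) if pi > 0)

def distillation_loss(teacher_0d, student_0d, r_delta):
    """E_{x_delta ~ r_delta} KL(teacher(.|x_delta) || student(.|x_delta))."""
    return sum(r * kl(teacher_0d[i], student_0d[i]) for i, r in enumerate(r_delta))

def consistency_loss(student_su, teacher_ut, student_st, r_t):
    """E_{x_t ~ r_t} KL((student_su o teacher_ut)(.|x_t) || student_st(.|x_t))."""
    target = compose(student_su, teacher_ut)
    return sum(r * kl(target[i], student_st[i]) for i, r in enumerate(r_t))

# Two joint states only, to keep the example readable.
TEACHER_UT = [[0.9, 0.1], [0.2, 0.8]]   # teacher step t -> u
STUDENT_SU = [[0.7, 0.3], [0.4, 0.6]]   # student step u -> s
R_T = [0.5, 0.5]                        # reference distribution over x_t

consistent_st = compose(STUDENT_SU, TEACHER_UT)
uniform_st = [[0.5, 0.5], [0.5, 0.5]]
print(consistency_loss(STUDENT_SU, TEACHER_UT, consistent_st, R_T))
print(consistency_loss(STUDENT_SU, TEACHER_UT, uniform_st, R_T))
```

The loss vanishes when the direct student kernel equals the composed target and is positive otherwise. In an actual training loop the target side would typically be treated as fixed when taking gradients; that detail is an assumption here, not something stated in this summary.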
Loss & Training
- Distillation loss is used at small \(\delta\) (where a single step of the teacher suffices), while dimensional correlation is primarily introduced via consistency loss.
- The reference distribution \(r_t\) can be either \(q_t\) generated from data or the distribution obtained from multi-step teacher denoising.
- In practice, \(\lambda\) can be implemented by feeding an additional random noise vector into the network.
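The last point, feeding \(\lambda\) to the network as extra noise, can be sketched as follows. This is a toy stand-in, not the paper's architecture: the "network" is a random linear map, \(\lambda\) is assumed to be a standard Gaussian vector, and `toy_denoiser`, `sample_step`, and the sizes are all hypothetical.

```python
import math
import random

def softmax(logits):
    """Numerically stable softmax over a list of floats."""
    m = max(logits)
    exps = [math.exp(z - m) for z in logits]
    total = sum(exps)
    return [e / total for e in exps]

def toy_denoiser(x_t, lam, weights):
    """Per-dimension categorical probabilities given (x_t, lam).

    weights[d][v] is a weight row over the concatenated features [x_t, lam].
    """
    feats = [float(v) for v in x_t] + list(lam)
    return [
        softmax([sum(w * f for w, f in zip(row, feats)) for row in dim_rows])
        for dim_rows in weights
    ]

def sample_step(x_t, weights, lam_dim, rng):
    """One mixture-model step: draw lam once, then sample each dimension."""
    lam = [rng.gauss(0.0, 1.0) for _ in range(lam_dim)]  # assumed noise law
    probs = toy_denoiser(x_t, lam, weights)
    return tuple(rng.choices(range(len(p)), weights=p)[0] for p in probs)

rng = random.Random(0)
D, VOCAB, LAM_DIM = 4, 3, 2
WEIGHTS = [
    [[rng.gauss(0.0, 1.0) for _ in range(D + LAM_DIM)] for _ in range(VOCAB)]
    for _ in range(D)
]
x_t = (2, 2, 2, 2)
x_s = sample_step(x_t, WEIGHTS, LAM_DIM, rng)
print(x_s)
```

Conditioned on one draw of `lam`, the dimensions are sampled independently, so a single forward pass suffices; the correlation between dimensions comes only from the fact that all of them see the same `lam`.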
Key Experimental Results
Main Results
1. CIFAR-10 Pixel-level Discrete Gaussian Diffusion
   - Uses Campbell et al. (2022) as the teacher.
   - Significantly improves on the teacher's FID under few-step sampling (2-4 steps).
2. ImageNet Class-Conditional Generation (Masked Diffusion)
   - Teacher: Besnier & Chen (2023).
   - Achieves a 2x speedup while maintaining generation quality comparable to the teacher.
3. Language Modeling (OpenWebText, Masked Diffusion LM)
   - Further distills on top of an already distilled model (Deschenaux & Gulcehre, 2025).
   - Further improves performance by capturing dimensional correlations, without significantly compromising sampling diversity.
Ablation Study
- Impact of the number of mixture components in the mixture model on quality.
- Selection of distillation steps and consistency loss timesteps.
- Performance comparison between consistency loss and pure distillation loss.
Key Findings
| Theoretical Result | Content |
|---|---|
| Theorem 1 Upper Bound | TV distance of the \(N\)-step product model = \(O(1/N)\) |
| Theorem 1 Lower Bound | Even in a simple case, the TV distance of the \(N\)-step product model is \(\Omega(1/N)\) |
| Theorem 2 | The Di4C loss upper-bounds the distance between teacher and student output distributions |
| Lemma 1 | TV distance of a single-step product model = \(O(\epsilon^2)\) |
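One way to read the Lemma 1 row together with the Theorem 1 upper bound is the following heuristic. It assumes that \(\epsilon\) is the length of one sampling step, that the total time horizon \(T\) (a symbol introduced here for illustration) is split into \(N\) equal steps, and that per-step TV errors add up across the composition. It is a sketch of the scaling, not a restatement of the paper's proof.

```latex
\begin{aligned}
\epsilon &= \frac{T}{N}, \\
d_{\mathrm{TV}}\bigl(\text{$N$-step product sampler},\ \text{target}\bigr)
  &\le \sum_{n=1}^{N} O(\epsilon^2)
   = N \cdot O\!\left(\frac{T^2}{N^2}\right)
   = O\!\left(\frac{1}{N}\right).
\end{aligned}
```

Under these assumptions, halving the per-step interval quarters the per-step error but doubles the number of steps, which is why the overall rate is \(1/N\) rather than \(1/N^2\).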
Highlights & Insights
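- Elegant Combination of Theory and Practice: Theorem 1 shows that the error of an \(N\)-step product model can be as large as \(\Omega(1/N)\); equivalently, reaching a TV error of \(\epsilon\) can require \(\Omega(1/\epsilon)\) steps. This establishes the theoretical necessity of modeling dimensional correlation.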
- Elegance and Simplicity of the Mixture Model: By indexing a family of product models via a simple random variable \(\lambda\), it achieves a universal representation of dimensional correlation with minimal inference overhead.
- Cross-Modal Generality: The method is effective for both images (pixel-level and masked diffusion) and language (masked diffusion LM).
- Connection to Continuous Consistency Models: The design of consistency loss is conceptually similar to consistency models in the continuous domain, but operates on the composition of conditional probabilities.
- Stepwise Approximation Error of \(O(\epsilon^2)\): Because the forward process factorizes across dimensions, Lemma 1 shows that a single product-model step over a short time interval of length \(\epsilon\) incurs a TV error of only \(O(\epsilon^2)\).
Limitations & Future Work
- Computation of Consistency Loss: Exact computation requires summing over high-dimensional discrete spaces, and Monte Carlo approximation introduces variance.
- Mixture Model Capacity: The distribution and dimension choices of \(\lambda\) require empirical tuning, and the theoretically required number of mixtures can be quite large.
- Dependency on Teacher Quality: The upper bound of the distillation performance is constrained by the multi-step sampling quality of the teacher.
- Continuous-Time Setting: The theoretical analysis is conducted in a continuous-time framework, but practical training employs discrete timesteps.
- Scaling to Larger Scales: The model scales used in the ImageNet experiments are relatively limited, and the effectiveness on large-scale text generation remains to be further validated.
Related Work & Insights
- Continuous-Domain Distillation: Progressive distillation (Salimans & Ho, 2022) and consistency trajectory models (Kim et al., 2024).
- Foundations of Discrete Diffusion: Austin et al. (2021) D3PM, Campbell et al. (2022) continuous-time discrete diffusion.
- Concurrent Work: Park et al. (2025), Liu et al. (2024), Xu et al. (2025) also highlight the issue of product models neglecting dimensional correlation.
- Insights: The mixture model concept could inspire other scenarios requiring efficient modeling of high-dimensional discrete joint distributions.
Rating
- Novelty: ⭐⭐⭐⭐⭐ — First to systematically address few-step distillation in discrete diffusion, introducing innovations in both theory and methodology.
- Experimental Thoroughness: ⭐⭐⭐⭐ — Cross-modal validation (image and language) across multiple discrete diffusion variants.
- Writing Quality: ⭐⭐⭐⭐ — Clear theoretical derivations, with Figure 1 intuitively showcasing the core problem.
- Value: ⭐⭐⭐⭐⭐ — Holds significant theoretical and practical importance for accelerating discrete diffusion models.