MosaicDiff: Training-free Structural Pruning for Diffusion Model Acceleration Reflecting Pretraining Dynamics
Conference: ICCV 2025
arXiv: 2510.11962
Code: https://github.com/bwguo105/MosaicDiff
Area: Diffusion Models / Model Acceleration
Keywords: Structural Pruning, Training-free Acceleration, Pretraining Dynamics, SNR-aware, Second-order Pruning
TL;DR
This paper proposes MosaicDiff, a training-free structural pruning method for diffusion models that dynamically partitions the inference trajectory into three stages based on pretraining learning-speed dynamics and applies stage-specific subnetworks with varying sparsity, achieving significant acceleration on DiT and SDXL without sacrificing generation quality.
Background & Motivation
Background: Diffusion models offer strong generative capabilities but incur substantial inference cost. The community has addressed acceleration primarily through reduced sampling steps (DDIM, DPM-Solver), knowledge distillation, structural pruning, quantization, and feature caching. However, these methods overlook the inherent variation in learning speed during diffusion-model pretraining.
Core Observation: Diffusion model pretraining exhibits a "slow–fast–slow" three-phase learning pattern: the early phase learns slowly (dominated by high-noise inputs), the middle phase shows a sharp rise in learning speed (rapid capture of coarse-grained features), and the late phase slows again (refinement of fine-grained details). Prior acceleration work does not exploit this pattern.
Limitations of Prior Work:
- Diff-Pruning applies a single pruned model across all timesteps, ignoring step-level importance differences.
- EcoDiff employs learnable masks but similarly lacks step-adaptive behavior.
- Existing methods suffer severe performance degradation at high sparsity ratios.
Method
Overall Architecture
MosaicDiff follows a "Divide → Prune → Conquer" three-stage pipeline (a code sketch follows the list):
- Divide: The inference trajectory is partitioned into three stages via quantitative analysis of pretraining dynamics.
- Prune: Second-order structural pruning is applied to each stage using SNR-aware calibration data.
- Conquer: Subnetworks with different sparsity levels are assembled to complete the final sampling.
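To make the flow concrete, here is a minimal sketch of how the three steps could be orchestrated, assuming a diffusers-style scheduler (`.step(...).prev_sample`); `prune_model` and the stage representation are illustrative placeholders, not the released API:

```python
def mosaicdiff_sample(model, scheduler, timesteps, x_T, stages, sparsities, prune_model):
    """Illustrative Divide -> Prune -> Conquer loop (names are hypothetical)."""
    # Prune: build one subnetwork per stage, each at its allocated sparsity.
    subnets = [prune_model(model, stage, s) for stage, s in zip(stages, sparsities)]

    # Conquer: run the sampler, switching subnetworks at stage boundaries.
    x = x_T
    for t in timesteps:
        idx = next(i for i, stage in enumerate(stages) if t in stage)
        eps = subnets[idx](x, t)                    # stage-specific pruned denoiser
        x = scheduler.step(eps, t, x).prev_sample   # one reverse-diffusion step
    return x
```

Because every subnetwork is derived from the same pretrained weights, switching between them adds no training cost, only the one-off pruning passes.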
Key Design 1: Stage Partitioning and Importance Scoring
Learning dynamics are characterized by monitoring the MSE between the intermediate latent representation \(\hat{x}_t\) and the final output \(\hat{x}_0\) along the sampling trajectory.
The authors derive closed-form expressions for the expected MSE and its gradient (Theorem 1), so the learning-speed curve can be computed analytically rather than measured from training logs.
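As an illustration of how such a closed form can arise (a sketch, not the paper's exact Theorem 1), suppose \(\hat{x}_t\) is approximated by the forward marginal \(\hat{x}_t = \sqrt{\bar{\alpha}_t}\,\hat{x}_0 + \sqrt{1-\bar{\alpha}_t}\,\epsilon\) with \(\epsilon \sim \mathcal{N}(\mathbf{0}, \mathbf{I}_d)\) independent of \(\hat{x}_0\). The cross term vanishes in expectation, giving

\[
\mathbb{E}\big\|\hat{x}_t - \hat{x}_0\big\|^2
= \big(1 - \sqrt{\bar{\alpha}_t}\big)^2\, \mathbb{E}\|\hat{x}_0\|^2 + \big(1 - \bar{\alpha}_t\big)\, d,
\]

and its gradient in \(t\) follows directly once the noise schedule \(\bar{\alpha}_t\) is fixed.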
The final importance score integrates SNR information, and sampling steps are automatically partitioned into three stages via the threshold \(\mathrm{threshold} = M \cdot \max_t \mathrm{score}(t)\), where \(M\) is a hyperparameter.
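Given the per-step score array, the partition itself is mechanical. A minimal sketch (assuming the fast region is contiguous, as the slow–fast–slow pattern implies; `partition_stages` is a hypothetical name):

```python
import numpy as np

def partition_stages(score, M):
    """Split the sampling trajectory into slow / fast / slow stages by
    thresholding the per-step importance score at M * max(score)."""
    threshold = M * score.max()
    fast = np.flatnonzero(score >= threshold)   # steps above the threshold
    lo, hi = fast.min(), fast.max()
    early = np.arange(0, lo)                    # slow stage before the peak
    middle = np.arange(lo, hi + 1)              # fast-learning middle stage
    late = np.arange(hi + 1, len(score))        # slow stage after the peak
    return early, middle, late
```

These index ranges then drive both the calibration sampling (Key Design 2) and the per-stage sparsity allocation below.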
Key Design 2: SNR-aware Calibration Dataset
Stage-specific calibration data are constructed by encoding images from standard datasets (e.g., ImageNet-1K) into latent representations, randomly sampling \(t\) within the corresponding stage's timestep range, and adding noise according to \(x_t = \sqrt{\bar{\alpha}_t}\,x_0 + \sqrt{1-\bar{\alpha}_t}\,\epsilon\). This ensures the calibration data accurately reflects the SNR characteristics of each stage; for models employing CFG, unconditional (null-label) calibration samples are additionally provided.
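A minimal sketch of this construction, assuming `x0` holds VAE-encoded latents and `alpha_bar` is the cumulative schedule tensor (the function name is illustrative):

```python
import torch

def stage_calibration_batch(x0, alpha_bar, t_lo, t_hi):
    """SNR-aware calibration batch: noise clean latents at timesteps drawn
    uniformly from one stage's range [t_lo, t_hi)."""
    b = x0.shape[0]
    t = torch.randint(t_lo, t_hi, (b,), device=x0.device)
    eps = torch.randn_like(x0)
    abar = alpha_bar[t].view(b, 1, 1, 1)                # abar_t per sample
    x_t = abar.sqrt() * x0 + (1 - abar).sqrt() * eps    # forward-process noising
    return x_t, t
```

For CFG models, the same batch would be duplicated with null labels so the unconditional branch is calibrated as well.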
Key Design 3: Second-order Structural Pruning
Given a linear layer weight \(\mathbf{W} \in \mathbb{R}^{m \times n}\) and calibration input \(\mathbf{X} \in \mathbb{R}^{b \times n}\), the objective is the standard layer-wise reconstruction error, \(\min_{\hat{\mathbf{W}}} \|\mathbf{X}\mathbf{W}^\top - \mathbf{X}\hat{\mathbf{W}}^\top\|_F^2\), under a structured (column) sparsity constraint.
Column-level saliency scores are obtained by computing the Hessian matrix \(\mathbf{H} = \mathbf{X}^\top\mathbf{X}\) and extending the OBS formulation, which in its standard column form scores column \(j\) as \(s_j = \|\mathbf{W}_{:,j}\|_2^2 / [\mathbf{H}^{-1}]_{jj}\).
After pruning, a compensatory weight update \(\boldsymbol{\delta} = -\mathbf{W}_{:,\mathbf{M}} \big([\mathbf{H}^{-1}]_{\mathbf{M},\mathbf{M}}\big)^{-1} [\mathbf{H}^{-1}]_{\mathbf{M},:}\) is applied to the surviving weights, where \(\mathbf{M}\) indexes the pruned columns, further reducing reconstruction error.
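A compact sketch of the whole pruning step in the standard OBS/SparseGPT column form (an assumption-level reimplementation, not the authors' code; the damping term is a common trick for numerical invertibility):

```python
import torch

def prune_columns(W, X, sparsity, damp=1e-2):
    """Second-order structural (column) pruning with OBS-style compensation.
    W: (m, n) weight, X: (b, n) calibration inputs for this stage."""
    n = W.shape[1]
    H = X.T @ X                                          # Hessian of the quadratic objective
    H += damp * H.diag().mean() * torch.eye(n, device=W.device)
    Hinv = torch.linalg.inv(H)

    # Column saliency: ||W_:,j||^2 / [H^-1]_jj -- cheapest columns pruned first.
    saliency = W.pow(2).sum(dim=0) / Hinv.diag()
    M = torch.argsort(saliency)[: int(sparsity * n)]     # pruned column index set

    # Compensation: delta = -W_:,M ([H^-1]_M,M)^-1 [H^-1]_M,:
    delta = -W[:, M] @ torch.linalg.inv(Hinv[M][:, M]) @ Hinv[M, :]
    W = W + delta
    W[:, M] = 0.0                                        # enforce structural sparsity
    return W
```

Because the update redistributes the pruned columns' contribution onto the survivors, no fine-tuning pass is needed afterward.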
Sparsity Allocation Strategy
Per-stage sparsity is inversely proportional to the average importance score: \(s_i \propto 1 - \overline{score}_i\). The fast-learning middle stage retains more parameters, while the slow-learning early and late stages permit more aggressive pruning.
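A one-liner in spirit; the sketch below additionally rescales so the average sparsity matches a global budget, which is an assumption not spelled out above:

```python
import numpy as np

def allocate_sparsity(mean_scores, target):
    """Per-stage sparsity inversely proportional to mean importance.
    mean_scores: average normalized importance per stage, in [0, 1]."""
    inv = 1.0 - np.asarray(mean_scores)     # s_i proportional to 1 - mean score
    s = target * inv / inv.mean()           # rescale so the mean hits the budget
    return np.clip(s, 0.0, 0.95)            # cap to keep every stage functional

# e.g. allocate_sparsity([0.2, 0.8, 0.3], target=0.33)
# -> the fast-learning middle stage gets the lowest sparsity
```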
Key Experimental Results
Main Results
| Method | Steps | MACs(T) | Speedup | IS↑ | FID↓ | Precision↑ |
|---|---|---|---|---|---|---|
| Vanilla DiT-XL/2 | 50 | 5.72 | 1.00× | 238.6 | 2.26 | 80.16 |
| Diff-Pruning-0.3 | 50 | 4.10 | 1.29× | 4.68 | 180.76 | 7.24 |
| Learning-to-Cache | 50 | 4.36 | 1.27× | 244.1 | 2.27 | 80.94 |
| MosaicDiff-0.33 | 50 | 3.92 | 1.32× | 267.8 | 2.24 | 82.01 |
| Vanilla DiT-XL/2 | 20 | 2.29 | 1.00× | 223.5 | 3.48 | 78.76 |
| MosaicDiff-0.30 | 20 | 1.64 | 1.28× | 266.7 | 3.20 | 81.13 |
Pruning Comparison on SDXL
| Method | Sparsity | FID↓ | CLIP↑ | SSIM↑ |
|---|---|---|---|---|
| Diff-Pruning | 10% | 108.96 | 0.22 | 0.31 |
| EcoDiff | 10% | 33.75 | 0.31 | 0.53 |
| MosaicDiff | 10% | 23.18 | 0.32 | 0.67 |
| Diff-Pruning | 20% | 404.87 | 0.05 | 0.26 |
| MosaicDiff | 20% | 23.79 | 0.32 | 0.64 |
Key Findings
- MosaicDiff achieves FID 2.24 under 50-step DDIM with 33% sparsity, outperforming the uncompressed baseline (2.26) while reducing MACs by 31%.
- The advantage is more pronounced at high sparsity: Diff-Pruning at 30% sparsity yields FID 180.76, whereas MosaicDiff attains 2.24.
- The method generalizes effectively to SDXL: at 10% sparsity, FID improves to 23.18 (vs. EcoDiff's 33.75).
- Closed-form MSE/gradient curves closely match empirical observations, validating the theoretical analysis.
Highlights & Insights
- First alignment of pretraining learning dynamics with post-training acceleration: The perspective is novel and theoretically grounded; the closed-form solution eliminates any dependency on actual pretraining runs.
- Training-free and fine-tuning-free: Second-order pruning leveraging Hessian information requires no retraining whatsoever.
- Strong generality: Applicable to both Transformer (DiT) and U-Net (SDXL) architectures.
- SNR-aware calibration ensures precision in stage-specific pruning.
Limitations & Future Work
- The hyperparameters \(M\) and \(\lambda\) require tuning (though the authors provide recommended values).
- Whether the three-stage partition is optimal for all noise schedules remains unclear.
- Hessian computation incurs non-trivial overhead in calibration data and compute.
Related Work & Insights
- Sampling Acceleration: DDIM, DPM-Solver, Consistency Models
- Structural Pruning: Diff-Pruning, EcoDiff
- Caching Methods: DeepCache, Learning-to-Cache
- Training-time Compression: DiP-GO, Knowledge Distillation
Rating
- Novelty: ⭐⭐⭐⭐ — The pretraining dynamics perspective is original.
- Technical Depth: ⭐⭐⭐⭐ — Closed-form theoretical analysis is rigorous.
- Experimental Thoroughness: ⭐⭐⭐⭐ — Comprehensive coverage across DiT and SDXL.
- Practical Value: ⭐⭐⭐⭐⭐ — Training-free and plug-and-play.