Beyond Random: Automatic Inner-Loop Optimization in Dataset Distillation¶
Conference: NeurIPS 2025 arXiv: 2510.04838 Code: Available Area: Dataset Distillation / Efficient Training Keywords: Dataset Distillation, BPTT Truncation, Adaptive Truncation, Low-Rank Hessian Approximation, Patch-wise Semantic Preservation
TL;DR¶
This paper proposes AT-BPTT (Adaptive Truncation BPTT), which partitions DNN training into early/middle/late stages and adaptively adjusts truncation strategies and window sizes accordingly. The method achieves average accuracy gains of 3–17% on CIFAR-10/100/Tiny-ImageNet/ImageNet-1K, while delivering 3.9× speedup and 63% memory reduction.
Background & Motivation¶
Background: Dataset Distillation (DD) aims to compress a large dataset into a compact synthetic dataset such that models trained on the synthetic set approximate the performance of those trained on the full set. Mainstream methods include gradient matching (DC/DSA), trajectory matching (MTT), and distribution matching (DM).
Limitations of Prior Work: Trajectory matching methods (MTT/FTD/DATM) require unrolling \(T\) training steps via BPTT to optimize the synthetic data. Full unrolling incurs prohibitive computational and memory costs, so in practice random truncation (RaT-BPTT) is employed, unrolling only a randomly selected subset of \(S\) steps. However, random selection disregards the heterogeneous learning dynamics across different training stages.
Key Challenge: DNN training exhibits distinct phases—coarse-grained feature learning in the early stage, discriminative feature learning in the middle stage, and fine-grained adjustment in the late stage. Random truncation treats all stages uniformly and cannot align with this non-uniform learning dynamics.
Goal: Design a truncation strategy that adaptively aligns with the learning stages of DNN training, improving distillation quality while maintaining computational efficiency.
Key Insight: Gradient norms vary non-uniformly throughout training—large in the early stage (when basic features are acquired) and small in the late stage (during fine-tuning). This observation motivates adaptive allocation of truncation positions and window sizes.
Core Idea: Adaptively select BPTT truncation positions based on gradient norms (prioritizing high-gradient steps early and low-gradient steps late), dynamically adjust the unrolling window width, and employ low-rank Hessian approximation to reduce computational cost.
Method¶
Overall Architecture¶
The training process \([0, T]\) is divided into three stages, each applying a distinct truncation strategy to select unrolling positions. Window sizes are adaptively adjusted based on gradient variation between adjacent steps. Low-rank Hessian approximation replaces exact second-order derivatives, and patch-wise semantic preservation is applied for high-resolution images.
Key Designs¶
- Three-Stage Adaptive Truncation Strategy:
- Function: Partitions training into early/middle/late stages, each using a different probability distribution to select truncation positions.
- Mechanism:
- Early stage: sampling proportional to gradient norm (\(P(t) \propto \exp(\|\nabla_\theta \mathcal{L}_t\| / \tau)\)) → prioritizes steps with large gradients (basic feature learning phase).
- Middle stage: uniform random sampling (standard RaT-BPTT) → covers discriminative feature learning.
- Late stage: sampling inversely proportional to gradient norm → prioritizes steps with small gradients (fine-grained adjustment phase).
- Design Motivation: Different training stages impose different requirements on the synthetic data—the early stage demands matching coarse-grained statistics, while the late stage demands matching fine-grained discriminative signals (a sketch of this sampling, together with the adaptive window, follows this list).
- Adaptive Window Size:
- Function: Dynamically adjusts the unrolling window width based on gradient differences between adjacent time steps.
- Formula: \(W^*(t) = W - d + 2d \cdot \eta(t)\), where \(\eta(t)\) is proportional to \(\exp(|\|\nabla \mathcal{L}_t\| - \|\nabla \mathcal{L}_{t-1}\||/\tau)\).
- Design Motivation: Regions with sharp gradient changes require longer unrolling windows to capture cross-step dependencies.
- Low-Rank Hessian Approximation:
- Function: Approximates the Hessian using randomized SVD combined with Hessian-vector products.
- Complexity: Reduced from \(O(p^2)\) to \(O(pk + k^3)\), where \(k\) is the rank and \(p\) is the number of parameters.
- Design Motivation: Exact Hessian computation is the primary bottleneck in DD; low-rank approximation achieves substantial memory savings with negligible accuracy loss (see the Hessian-vector-product sketch after this list).
- Patch-wise Semantic Preservation:
- Function: For high-resolution images, the synthetic image is divided into \(n \times n\) patches, each subject to both local and global prototype centroid matching.
- Design Motivation: Global matching loses local semantics at high resolutions; patch-level matching preserves spatial structure (a patch-matching sketch follows below).
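The sketch below shows, in plain NumPy, how the stage-dependent position sampling and the adaptive window size could be implemented. It is a minimal illustration, not the paper's code: the stage boundaries are assumed to be equal thirds of \([0, T]\), the late-stage distribution simply negates the early-stage logits, and the squashing of \(\eta(t)\) into \([0, 1]\) is an assumption. All function names are made up for the example.

```python
import numpy as np

def sample_truncation_position(grad_norms, stage, tau=1.0, rng=None):
    """Sample the inner step at which the truncated BPTT window is placed.

    grad_norms : per-step gradient norms of the inner model over [0, T)
    stage      : 'early', 'middle', or 'late' (assumed equal thirds of [0, T))
    """
    rng = rng if rng is not None else np.random.default_rng()
    g = np.asarray(grad_norms, dtype=np.float64)
    T = len(g)
    lo, hi = {"early": (0, T // 3),
              "middle": (T // 3, 2 * T // 3),
              "late": (2 * T // 3, T)}[stage]
    g_stage = g[lo:hi]

    if stage == "early":        # favor steps with large gradients
        logits = g_stage / tau
    elif stage == "middle":     # plain uniform sampling, as in RaT-BPTT
        logits = np.zeros_like(g_stage)
    else:                       # late stage: favor steps with small gradients
        logits = -g_stage / tau

    p = np.exp(logits - logits.max())
    p /= p.sum()
    return lo + int(rng.choice(hi - lo, p=p))


def adaptive_window(grad_norms, t, W=30, d=10, tau=1.0):
    """W*(t) = W - d + 2d * eta(t); eta(t) grows with the gradient change
    between adjacent steps (the squashing into [0, 1] is an assumption)."""
    delta = abs(grad_norms[t] - grad_norms[t - 1]) if t > 0 else 0.0
    eta = 1.0 - float(np.exp(-delta / tau))
    return int(round(W - d + 2 * d * eta))
```

With these two pieces, each outer iteration would sample a start step for the current stage and unroll only `adaptive_window(grad_norms, t)` inner steps from that position.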
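The low-rank Hessian design can be made concrete with Hessian-vector products, which never materialize the \(p \times p\) matrix. The sketch below is a generic randomized range finder built on double backprop, not the paper's implementation; it returns factors \(Q\) and \(B\) with \(H \approx QB\), stored in \(O(pk)\) memory, from \(2k\) Hessian-vector products.

```python
import torch
import torch.nn as nn

def flat_hvp(loss_fn, params, v):
    """Hessian-vector product H @ v via double backprop; the full Hessian
    is never formed."""
    grads = torch.autograd.grad(loss_fn(), params, create_graph=True)
    flat_g = torch.cat([g.reshape(-1) for g in grads])
    gv = (flat_g * v).sum()
    hv = torch.autograd.grad(gv, params)
    return torch.cat([h.reshape(-1) for h in hv]).detach()

def low_rank_hessian(loss_fn, params, k=8):
    """Rank-k factorization H ~= Q @ B: probe H with k Gaussian vectors,
    orthonormalize the result, probe again (randomized range finding)."""
    p = sum(pi.numel() for pi in params)
    omega = torch.randn(p, k)
    Y = torch.stack([flat_hvp(loss_fn, params, omega[:, i]) for i in range(k)], dim=1)
    Q, _ = torch.linalg.qr(Y)                                        # p x k orthonormal basis
    B = torch.stack([flat_hvp(loss_fn, params, Q[:, i]) for i in range(k)], dim=1).T
    return Q, B                                                      # B is k x p, so H ~= Q @ B

# Toy usage on a tiny model (illustrative only).
model = nn.Linear(4, 3)
x, y = torch.randn(16, 4), torch.randint(0, 3, (16,))
loss_fn = lambda: nn.functional.cross_entropy(model(x), y)
Q, B = low_rank_hessian(loss_fn, list(model.parameters()), k=4)
```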
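For the patch-wise term, a pixel-space sketch is given below. The paper matches prototype centroids, presumably in a feature space; the raw-pixel means and MSE used here, the grid size \(n\), and the assumption that \(H\) and \(W\) divide evenly by \(n\) are all simplifications for illustration.

```python
import torch
import torch.nn.functional as F

def patch_semantic_loss(syn_imgs, real_imgs, n=4):
    """Split images into an n x n grid and match per-patch means (local)
    plus whole-image means (global). Pixel-space means stand in for the
    prototype centroids; H and W are assumed divisible by n."""
    def to_patches(x):                                   # (B, C, H, W) -> (B, n*n, C, h, w)
        B, C, H, W = x.shape
        h, w = H // n, W // n
        x = x.unfold(2, h, h).unfold(3, w, w)            # (B, C, n, n, h, w)
        return x.permute(0, 2, 3, 1, 4, 5).reshape(B, n * n, C, h, w)

    syn_p, real_p = to_patches(syn_imgs), to_patches(real_imgs)
    local = F.mse_loss(syn_p.mean(dim=(0, 3, 4)), real_p.mean(dim=(0, 3, 4)))    # per-patch prototypes
    global_ = F.mse_loss(syn_imgs.mean(dim=(0, 2, 3)), real_imgs.mean(dim=(0, 2, 3)))  # whole-image prototypes
    return local + global_

# Example: 64x64 synthetic vs. real batches of one class.
loss = patch_semantic_loss(torch.randn(10, 3, 64, 64), torch.randn(128, 3, 64, 64))
```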
Loss & Training¶
Outer loop: loss is computed on the real validation set and backpropagated to the synthetic data. Inner loop: \(S\)-step BPTT unrolling on the synthetic data. \(\mathcal{L} = \mathcal{L}_{match} + \lambda \mathcal{L}_{semantic}\).
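To make the inner/outer structure concrete, here is a toy differentiable-unrolling loop with a linear model. It is a sketch under simplifying assumptions (arbitrary shapes and hyperparameters, a full \(S\)-step unroll instead of the adaptive truncation window, no semantic term), not the paper's training code.

```python
import torch
import torch.nn.functional as F

x_syn = torch.randn(100, 32, requires_grad=True)      # learnable synthetic inputs
y_syn = torch.randint(0, 10, (100,))                  # fixed synthetic labels
opt_syn = torch.optim.Adam([x_syn], lr=1e-3)
x_real, y_real = torch.randn(512, 32), torch.randint(0, 10, (512,))
S, inner_lr = 20, 0.1

for outer_step in range(200):
    w = torch.zeros(32, 10, requires_grad=True)       # fresh inner model each outer step
    for t in range(S):                                 # inner loop: S-step BPTT unroll
        inner_loss = F.cross_entropy(x_syn @ w, y_syn)
        (g,) = torch.autograd.grad(inner_loss, w, create_graph=True)
        w = w - inner_lr * g                           # differentiable SGD update
    outer_loss = F.cross_entropy(x_real @ w, y_real)   # matching loss on real data
    opt_syn.zero_grad()
    outer_loss.backward()                              # gradients flow through the unroll to x_syn
    opt_syn.step()
```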
Key Experimental Results¶
Main Results¶
| Dataset | IPC | AT-BPTT | Prev. SOTA | Gain |
|---|---|---|---|---|
| CIFAR-10 | 10 | 72.4% | 69.4% | +3.0% |
| CIFAR-100 | 10 | 49.0% | 47.5% | +1.5% |
| Tiny-ImageNet | 10 | 32.7% | 24.4% | +8.3% |
| ImageNet-1K | 10 | 30.6% | 13.0% | +17.6% |
Ablation Study¶
| Configuration | Key Findings | Notes |
|---|---|---|
| Random vs. Adaptive Truncation | Adaptive consistently outperforms on all datasets | Validates core contribution |
| Fixed vs. Adaptive Window | Adaptive window +1.2% on CIFAR-100 | High-gradient-variation regions require longer unrolling |
| With/Without Low-Rank Hessian | Negligible accuracy change, −63% memory | Low-rank approximation is effective |
| With/Without Patch Semantic Preservation | Significant gain on high-resolution (ImageNet) | Local structure matters |
| Computational Efficiency | 3.9× speedup vs. RaT-BPTT | Primarily attributed to low-rank Hessian |
Key Findings¶
- The +17.6% gain on ImageNet-1K is remarkably large, indicating that large-scale datasets benefit substantially more from adaptive truncation.
- Low-rank Hessian approximation achieves 63% memory reduction and 3.9× speedup without sacrificing accuracy.
- Among the three stages, the early stage contributes most—suggesting that aligning with basic feature learning is the most critical factor in distillation.
Highlights & Insights¶
- The learning dynamics intuition is well-founded: DNN training is inherently stage-wise, and random truncation wastes the unrolling budget by ignoring this structure—a simple yet previously overlooked insight.
- Dramatic gains on ImageNet-1K: The +17.6% improvement reveals that truncation strategy has a far greater impact on large-scale, high-resolution data than on small datasets. Prior methods may have been "lucky" near the optimum on small datasets, while the inefficiency of random truncation is fully exposed at scale.
- Low-rank Hessian as a standalone practical contribution: Even setting aside adaptive truncation, low-rank Hessian approximation alone can substantially accelerate existing DD methods.
Limitations & Future Work¶
- The stage boundary ratios (defining "early/middle/late") require hyperparameter tuning.
- Sensitivity analysis of the gradient norm temperature \(\tau\) and window parameter \(d\) is insufficient.
- Validation is limited to the trajectory matching framework; applicability to gradient matching and distribution matching has not been explored.
- Scaling behavior under large IPC settings (IPC=50+) remains untested.
Related Work & Insights¶
- vs. MTT (Cazenavette et al., 2022): MTT employs fixed-length trajectory matching; this work improves the truncation strategy.
- vs. RaT-BPTT (Deng & Russakovsky, 2022): RaT-BPTT serves as the random truncation baseline; AT-BPTT consistently outperforms it via adaptive truncation.
- vs. FTD (Du et al., 2023): FTD improves the matching objective while AT-BPTT improves the unrolling strategy—the two approaches are orthogonal and can be combined.
Rating¶
- Novelty: ⭐⭐⭐⭐ — The adaptive truncation idea is simple yet overlooked; the substantial gains on ImageNet underscore its significance.
- Experimental Thoroughness: ⭐⭐⭐⭐⭐ — Four datasets, comprehensive ablations, and computational efficiency analysis.
- Writing Quality: ⭐⭐⭐⭐ — The three-stage framework is presented clearly.
- Value: ⭐⭐⭐⭐ — Directly useful to the dataset distillation community; the low-rank Hessian contribution is independently valuable.