MetaAug: Meta-Data Augmentation for Post-Training Quantization
Conference: ECCV 2024
arXiv: 2407.14726
Code: Yes
Area: Model Compression
Keywords: Post-Training Quantization, Meta-Learning, Bi-Level Optimization, Data Augmentation, Overfitting Mitigation
TL;DR
This paper proposes MetaAug, a meta-learning-based post-training quantization (PTQ) method. It employs a learnable transformation network to augment calibration data and concurrently optimizes both the transformation network and the quantized model within a bi-level optimization framework, thereby effectively mitigating the overfitting of PTQ on small calibration sets.
Background & Motivation
The Core Dilemma of PTQ: Overfitting
When deploying deep neural networks on resource-constrained devices, quantization is a critical technique for reducing computational and storage overhead. Quantization methods fall into two broad categories:
- Quantization-Aware Training (QAT): Retrains the model on large-scale training data. It yields high accuracy but is often infeasible in practice because access to training data is restricted.
- Post-Training Quantization (PTQ): Quantizes a pre-trained model using only a small calibration set (e.g., 1024 images), which is far more practical.
However, the core issue of PTQ is that the calibration dataset is too small, making the quantized model highly susceptible to overfitting.
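To make the bit-width notation concrete (W4A4 means 4-bit weights and 4-bit activations), here is a minimal min-max uniform "fake" quantizer in NumPy: it rounds values to \(2^b\) evenly spaced levels and maps them back to floats. This is only an illustration of what a b-bit grid is; the function name is hypothetical, and the PTQ methods discussed here (QDrop, PD-Quant, MetaAug) use more elaborate block-wise reconstruction schemes rather than this plain rounding.

```python
import numpy as np

def fake_quantize(x: np.ndarray, n_bits: int) -> np.ndarray:
    """Asymmetric min-max uniform quantization to 2**n_bits levels, then dequantization."""
    qmin, qmax = 0, 2 ** n_bits - 1
    lo, hi = float(x.min()), float(x.max())
    scale = max(hi - lo, 1e-8) / (qmax - qmin)   # step size between adjacent levels
    zero_point = np.round(-lo / scale)           # integer code that represents 0.0
    q = np.clip(np.round(x / scale) + zero_point, qmin, qmax)
    return (q - zero_point) * scale

rng = np.random.default_rng(0)
weights = rng.normal(size=4096)
for bits in (8, 4, 2):
    mse = np.mean((weights - fake_quantize(weights, bits)) ** 2)
    print(f"{bits}-bit: {2 ** bits:3d} levels, reconstruction MSE = {mse:.5f}")
```

Every bit removed halves the number of levels, so the rounding error grows quickly at 2 bits, which is why W2A2 leaves so little slack for fitting a tiny calibration set well.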
Limitations of Prior Work
Existing works have attempted to mitigate overfitting in PTQ:
- QDrop: Randomly drops activation quantization during reconstruction, keeping a random subset of activations in full precision.
- PD-Quant: Uses the BN statistics of the full-precision model to correct activation distributions.
- Activation Regularization: Minimizes the discrepancy of intermediate features between the full-precision and quantized models.
However, these methods share the same fundamental flaw: they all rely on the original calibration data to train the quantized model, without a validation set to monitor overfitting during quantization. Using the same data for training and evaluation makes overfitting almost inevitable.
Mechanism
If the calibration data could be "split into two"—using the transformed calibration data as the training set and the original calibration data as the validation set—overfitting could be detected and prevented during quantization. The key challenge lies in making the data generated by the transformation network retain the semantic information of the original data while being sufficiently different from it to avoid degenerating into an identity mapping.
Method
Overall Architecture
The core of MetaAug is a bi-level optimization framework:
- Inner optimization: Trains the quantized model \(\theta_Q\) on the transformed data \(T(x_i)\).
- Outer optimization: Validates the quantized model on the original calibration data \(x_i\) and updates the transformation network \(T\) according to the validation loss.
This forms a meta-learning paradigm: the transformation network learns how to modify the data such that the quantized model trained on the transformed data generalizes better to the original data.
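The bi-level loop can be illustrated with a toy PyTorch sketch. Everything here is an assumption for illustration, not the authors' code: a single linear layer stands in for a quantized block, a residual linear map stands in for the UNet \(T\), the inner problem is approximated by one unrolled SGD step instead of being solved to convergence, and 2-bit weights use a straight-through estimator. The paper computes higher-order gradients with Facebook's `higher` library; this sketch differentiates the single unrolled step directly with `torch.autograd.grad(..., create_graph=True)`.

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)

def fake_quant_ste(w: torch.Tensor, n_bits: int = 2) -> torch.Tensor:
    """Symmetric uniform fake quantization with a straight-through estimator (STE)."""
    qmax = 2 ** (n_bits - 1) - 1
    scale = w.detach().abs().max() / qmax
    w_q = torch.clamp(torch.round(w / scale), -qmax - 1, qmax) * scale
    return w + (w_q - w).detach()  # forward: w_q; backward: identity w.r.t. w

d_in, d_out, n = 8, 4, 32
x = torch.randn(n, d_in)                       # original calibration batch (validation data)
w_fp = torch.randn(d_in, d_out)                # frozen full-precision "block"
w_q = w_fp.clone().requires_grad_(True)        # latent weights of the quantized block (theta_Q)
transform = torch.nn.Linear(d_in, d_in)        # hypothetical stand-in for the UNet T
opt_t = torch.optim.Adam(transform.parameters(), lr=1e-3)
opt_q = torch.optim.SGD([w_q], lr=0.1)
inner_lr = 0.1

def recon_loss(inp: torch.Tensor, w: torch.Tensor) -> torch.Tensor:
    # Inner objective L_Q: match the full-precision block output on the same input.
    return F.mse_loss(inp @ fake_quant_ste(w), inp @ w_fp)

for step in range(5):
    # Outer step: unroll one differentiable inner update of theta_Q on T(x) ...
    x_aug = x + transform(x)
    (g,) = torch.autograd.grad(recon_loss(x_aug, w_q), w_q, create_graph=True)
    w_q_hat = w_q - inner_lr * g
    # ... then validate the updated weights on the ORIGINAL data (KL to FP predictions).
    val_loss = F.kl_div(F.log_softmax(x @ fake_quant_ste(w_q_hat), dim=-1),
                        F.softmax(x @ w_fp, dim=-1), reduction="batchmean")
    opt_t.zero_grad()
    val_loss.backward()        # second-order gradient flows through g into T's parameters
    opt_t.step()

    # Alternating step: ordinary update of the quantized block on data from the updated T.
    opt_q.zero_grad()
    recon_loss((x + transform(x)).detach(), w_q).backward()
    opt_q.step()
    print(f"step {step}: val_loss = {val_loss.item():.4f}")
```

The key detail is `create_graph=True`: without it, `w_q_hat` would not depend on \(T\), and the validation loss on the original data could not tell the transformation network how to change.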
Key Designs
- Meta-learning bi-level optimization:
  - Function: Jointly optimizes the transformation network \(T\) and the quantized model \(\theta_Q\).
  - Mechanism: The inner loop trains the quantized model on the transformed data, while the outer loop validates it on the original data and updates the transformation network:

    $$T^* = \arg\min_T \frac{1}{N}\sum_{i=1}^N \mathcal{L}_{\text{val}}(\hat{\theta}_Q, x_i^v)$$

    $$\text{s.t.}\quad \hat{\theta}_Q = \arg\min_{\theta_Q} \frac{1}{N}\sum_{i=1}^N \mathcal{L}_Q(\theta_Q, T(x_i))$$

    where \(x_i^v\) is an original (untransformed) calibration sample used for validation.
  - Design Motivation: It follows the meta-learning idea of MAML: a good data transformation is one that lets a model trained on the transformed data generalize well to the original data. The pipeline is optimized end to end with second-order gradients, computed with Facebook's `higher` library.
- Transformation network \(T\) (UNet-based):
  - Function: Transforms original calibration images into augmented images that preserve semantics but differ in appearance.
  - Mechanism: Uses a UNet as the image-to-image transformation network; its encoder-decoder structure and skip connections retain fine-grained information from the input.
  - Design Motivation: The skip connections in UNet naturally preserve details of the original image, while the encoder-decoder structure offers enough flexibility to alter its appearance.
- Distribution Preservation Loss:
  - Function: Ensures that the transformed data keeps the same distribution structure as the original data in the feature space.
  - Mechanism: Based on Probabilistic Knowledge Transfer (PKT), it estimates the conditional probability density between any two points in the feature space so that the transformed images share the same probability distribution as the original images:

    $$\mathcal{L}_{DP}(T, S) = \frac{1}{N}\sum_{i=1}^N \mathrm{KL}\left[\mathcal{P}_i \,\|\, \mathcal{P}_i^{(g)}\right]$$

    where \(\mathcal{P}_{i|j} = \frac{K(f_{\theta_{FP}}(x_i), f_{\theta_{FP}}(x_j))}{\sum_{k \neq j} K(f_{\theta_{FP}}(x_k), f_{\theta_{FP}}(x_j))}\), \(f_{\theta_{FP}}\) is the feature extractor of the full-precision model, \(\mathcal{P}^{(g)}\) is the same quantity computed on the transformed images \(T(x_i)\), and \(K\) is a cosine-similarity kernel.
  - Design Motivation: MSE and KL divergence only compare features point to point and ignore the global structural relationships among samples; the distribution preservation loss captures the global distribution of the calibration set. Ablation studies (ResNet-18, W2A2) confirm that it outperforms MSE (+0.45%) and KL (+0.20%).
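A minimal PyTorch sketch of the distribution preservation loss, written from the formula above. Assumptions: the cosine kernel is mapped to \([0, 1]\) as \(K(a, b) = \tfrac{1}{2}(\cos(a, b) + 1)\), which is the convention of the original PKT method (the summary only says "cosine similarity kernel"); features are flattened per image; and the per-anchor KL terms are averaged. Function names are hypothetical; this is not the authors' implementation.

```python
import torch
import torch.nn.functional as F

def conditional_probs(feats: torch.Tensor) -> torch.Tensor:
    """Row j holds P(i | j) = K(f_i, f_j) / sum_{k != j} K(f_k, f_j), self-pairs excluded."""
    f = F.normalize(feats.flatten(1), dim=1)
    kernel = 0.5 * (f @ f.T + 1.0)                                    # cosine mapped to [0, 1]
    kernel = kernel * (1.0 - torch.eye(f.shape[0], device=f.device))  # drop k == j terms
    return kernel / kernel.sum(dim=1, keepdim=True)

def distribution_preservation_loss(feats_orig: torch.Tensor, feats_aug: torch.Tensor,
                                   eps: float = 1e-8) -> torch.Tensor:
    p = conditional_probs(feats_orig)   # neighbourhood structure of original-image features
    q = conditional_probs(feats_aug)    # same structure for transformed-image features
    kl_per_anchor = (p * torch.log((p + eps) / (q + eps))).sum(dim=1)  # KL(p_j || q_j)
    return kl_per_anchor.mean()

torch.manual_seed(0)
feats = torch.randn(16, 64)            # e.g. FP-model features of 16 calibration images
print(distribution_preservation_loss(feats, feats))                # identical structure
print(distribution_preservation_loss(feats, 3.0 * feats))          # rescaled features
print(distribution_preservation_loss(feats, torch.randn(16, 64)))  # unrelated features
```

Because the features are L2-normalized, rescaling them leaves the loss at (numerically) zero, whereas a point-wise MSE would penalize the change: the loss only cares about how samples relate to one another.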
Loss & Training
Quantization Loss (inner loop): Block-wise reconstruction that minimizes the MSE between the outputs of the \(l\)-th block of the full-precision model and of the quantized model.
Validation Loss: The KL divergence between the outputs of the quantized model and the full-precision model, measuring how consistent they are.
Margin Loss (prevents collapse into an identity mapping): Penalizes the transformation network whenever the pixel difference between the transformed image and the original image falls below a threshold \(\epsilon\), so that \(T\) cannot degenerate into an identity mapping.
Total Loss: \(\mathcal{L}_T = \lambda_1 \mathcal{L}_{\text{val}} + \lambda_2 \mathcal{L}_{\text{margin}} + \lambda_3 \mathcal{L}_{DP}\), with \(\lambda_1=5\), \(\lambda_2=0.5\), \(\lambda_3=3\times 10^4\).
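The weighting above, together with the margin term, can be sketched as follows. The summary does not give the exact margin formula, so the hinge on the per-image mean absolute pixel difference below is an assumption, as are the function names; \(\epsilon = 0.1\) matches the value reported for most architectures (ResNet-50 uses 0.3).

```python
import torch
import torch.nn.functional as F

LAMBDA_VAL, LAMBDA_MARGIN, LAMBDA_DP = 5.0, 0.5, 3e4   # lambda_1, lambda_2, lambda_3 as reported

def margin_loss(x_aug: torch.Tensor, x: torch.Tensor, eps: float = 0.1) -> torch.Tensor:
    """Hypothetical hinge form: zero once an image's mean |T(x) - x| reaches eps."""
    per_image_diff = (x_aug - x).abs().flatten(1).mean(dim=1)
    return F.relu(eps - per_image_diff).mean()

def transform_total_loss(l_val: torch.Tensor, l_margin: torch.Tensor,
                         l_dp: torch.Tensor) -> torch.Tensor:
    # L_T = lambda_1 * L_val + lambda_2 * L_margin + lambda_3 * L_DP
    return LAMBDA_VAL * l_val + LAMBDA_MARGIN * l_margin + LAMBDA_DP * l_dp

x = torch.rand(4, 3, 32, 32)          # toy batch of images in [0, 1]
print(margin_loss(x, x))              # identity mapping: full penalty of eps
print(margin_loss(x + 0.2, x))        # pixels moved by 0.2 on average: no penalty
```

With a hinge form, the margin term stops contributing as soon as \(T\) moves pixels by at least \(\epsilon\) on average, so it only rules out the identity solution rather than pushing the transformation arbitrarily far; the distribution preservation term then keeps the change semantically faithful.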
Training Process: For each block, the transformation network and the quantized model are first updated alternately for 500 iterations; the block is then quantized using both the original and the transformed data for \(2\times 10^4\) iterations.
Key Experimental Results
Main Results: ImageNet Top-1 Accuracy
| Method | Bit Width (W/A) | ResNet-18 | ResNet-50 | MobileNetV2 |
|---|---|---|---|---|
| Full Precision | 32/32 | 71.01 | 76.63 | 72.20 |
| QDrop | 4/4 | 69.10 | 75.03 | 67.89 |
| PD-Quant | 4/4 | 69.23 | 75.16 | 68.19 |
| Genie-M | 4/4 | 69.35 | 75.21 | 68.65 |
| MetaAug (Ours) | 4/4 | 69.48 | 75.29 | 68.76 |
| QDrop | 2/2 | 51.14 | 54.74 | 8.46 |
| PD-Quant | 2/2 | 53.14 | 57.16 | 13.76 |
| Genie-M | 2/2 | 53.71 | 56.71 | 16.25 |
| Bit-Shrinking* | 2/2 | 57.33 | 59.03 | 18.23 |
| MetaAug* (Ours) | 2/2 | 57.89 | 60.50 | 19.61 |
The improvement is most pronounced at ultra-low bit widths (W2A2): MetaAug* outperforms Bit-Shrinking* by 1.47% on ResNet-50 and by 1.38% on MobileNetV2.
Ablation Study
| Configuration | Loss Combination | ResNet-18 W2A2 |
|---|---|---|
| Genie-M Baseline | — | 53.71 |
| (a) | \(\mathcal{L}_{\text{val}}\) | 53.45 |
| (b) | \(\mathcal{L}_{\text{val}} + \mathcal{L}_{\text{MSE}}\) | 53.64 |
| (c) | \(\mathcal{L}_{\text{val}} + \mathcal{L}_{\text{KL}}\) | 53.89 |
| (d) | \(\mathcal{L}_{\text{val}} + \mathcal{L}_{\text{DP}}\) | 54.09 |
| (e) | \(\mathcal{L}_{\text{val}} + \mathcal{L}_{\text{DP}} + \mathcal{L}_{\text{margin}}\) | 54.22 |
The distribution preservation loss \(\mathcal{L}_{DP}\) consistently outperforms MSE and KL; introducing the margin loss yields an additional 0.13% improvement.
Overfitting Analysis
| Method | W/A | Test Set Acc | Calibration Set Acc | Train-Test Gap |
|---|---|---|---|---|
| QDrop | 2/2 | 51.14 | 77.53 | 26.39 |
| PD-Quant | 2/2 | 53.14 | 83.30 | 30.16 |
| Genie-M | 2/2 | 53.77 | 80.18 | 27.01 |
| MetaAug | 2/2 | 54.22 | 77.64 | 23.42 |
MetaAug achieves both the highest test accuracy and the smallest train-test gap (23.42, versus 26.39 to 30.16 for the baselines), confirming its effectiveness in mitigating overfitting.
Key Findings
- Comparison with Traditional Augmentation: MetaAug (54.22%) outperforms Random Flip (53.93%), CutMix (54.15%), and Mixup (54.05%), and can be stacked with them (MetaAug + CutMix reaches 54.63%).
- Benefits are More Pronounced at Low Bit Widths: The improvement at 4/4 is limited, but at 2/2 it is substantial (+0.5% to +1.5%), because overfitting is more severe at lower bit widths, which makes the generalization gains from meta-learning more valuable.
- Visualizations of the transformation network's outputs show that the generated images change in appearance while preserving semantic structure.
Highlights & Insights
- Viewing PTQ through a validation lens: Introduces, for the first time, a train/validation split into PTQ, tackling overfitting by optimizing the data rather than by regularizing the model.
- Well-designed safeguards against degeneration: The margin loss prevents the transformation network from collapsing into an identity mapping, and the distribution preservation loss prevents the loss of semantic information. Both are indispensable.
- Complementary to traditional augmentation: MetaAug can be combined with Mixup, CutMix, etc. for larger gains, indicating that the learned transformation is fundamentally different from random augmentation.
- Measuring overfitting directly: Uses the train-test accuracy gap to quantify how much PTQ overfits, providing a clear evaluation benchmark for research on PTQ overfitting.
Limitations & Future Work
- The transformation network only performs photometric transformations and does not introduce geometric transformations (the authors suggest integrating Spatial Transformer).
- Using UNet as the transformation network introduces additional training overhead (500 iterations/block).
- Only verified on ImageNet classification, and not yet extended to detection, segmentation, etc.
- Hyperparameters (\(\lambda_1, \lambda_2, \lambda_3, \epsilon\)) require different configurations for different architectures (ResNet-50 uses \(\epsilon=0.3\), while others use \(\epsilon=0.1\)).
Related Work & Insights
- Connection to MAML: The bi-level optimization framework originates from MAML, but the objective shifts from "learning a good initialization" to "learning a good data transformation".
- Difference from MetaMix/MetaQuantNet: These works use meta-learning to optimize the quantization strategy itself, whereas MetaAug is the first to use meta-learning from a data perspective to resolve PTQ overfitting.
- Inspirations: The proposed meta-learning framework for PTQ can be extended to other model compression scenarios that require few-shot calibration (e.g., pruning, distillation).
Rating
- Novelty: ⭐⭐⭐⭐ First to tackle PTQ overfitting from the data side via meta-learning bi-level optimization, a distinctive perspective.
- Experimental Thoroughness: ⭐⭐⭐⭐ Comprehensive ablation and comparison experiments across multiple architectures and bit widths, with intuitive overfitting analysis.
- Writing Quality: ⭐⭐⭐⭐ Clear formula derivations and well-explained motivations.
- Value: ⭐⭐⭐⭐ Offers a new paradigm for PTQ research, orthogonal to and combinable with existing methods.