Preventing Catastrophic Overfitting in Fast Adversarial Training: A Bi-level Optimization Perspective
Conference: ECCV2024
arXiv: 2407.12443
Code: HandingWangXDGroup/FGSM-PCO
Area: AI Safety
Keywords: fast adversarial training, catastrophic overfitting, bilevel optimization, FGSM, adversarial examples
TL;DR
Analyzes the causes of catastrophic overfitting in fast adversarial training from a bi-level optimization perspective and proposes FGSM-PCO. By adaptively fusing historical and current adversarial examples and adding a custom regularization loss, it effectively prevents, and corrects, the collapse of the inner optimization.
Background & Motivation
Adversarial Training (AT) is an effective defense against adversarial examples and can be modeled as a bi-level optimization problem: the inner level maximizes the loss over a bounded perturbation to generate adversarial examples, and the outer level minimizes the model's loss on those examples. Standard PGD-AT solves the inner problem with a multi-step attack, which is computationally expensive. Fast Adversarial Training (FAT) replaces PGD with single-step FGSM and greatly reduces training cost, but it suffers from catastrophic overfitting: the model's accuracy under FGSM attack suddenly soars, while its accuracy under multi-step PGD attack plummets to 0%.
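For reference, the standard min-max form of AT under an \(\ell_\infty\) budget \(\epsilon\) is shown below, together with the single-step FGSM perturbation that FAT uses to approximate the inner maximization. The notation follows the method section. PGD-AT instead iterates several smaller signed-gradient steps and projects back onto the \(\epsilon\)-ball after each one.

\[
\min_{\theta}\; \mathbb{E}_{(\boldsymbol{x},\mathbf{y})}\Big[\max_{\|\boldsymbol{\delta}\|_\infty \le \epsilon} \mathcal{L}\big(f_\theta(\boldsymbol{x}+\boldsymbol{\delta}), \mathbf{y}\big)\Big],
\qquad
\boldsymbol{\delta}_{\mathrm{FGSM}} = \epsilon\,\mathrm{sign}\big(\nabla_{\boldsymbol{x}} \mathcal{L}(f_\theta(\boldsymbol{x}), \mathbf{y})\big).
\]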
Existing FAT methods (such as FGSM-RS with random initialization, FGSM-GA with gradient alignment regularization, FGSM-MEP with momentum-perturbed initialization) can delay the occurrence of overfitting, but still cannot completely avoid catastrophic overfitting on complex tasks (e.g., Tiny-ImageNet) or large-parameter models (e.g., WideResNet34-10). More importantly, once overfitting occurs, these methods lack a correction mechanism to restore effective training.
Core Problem
- Root cause of catastrophic overfitting: The coupling of FGSM's large single-step attack with the alternating optimization mechanism easily leads to inner optimization collapse—the generated adversarial examples become ineffective against the current model, thereby invalidating the entire bi-level optimization.
- Limitations of Prior Work: Existing FAT methods can only delay the occurrence of overfitting and cannot bring the training back on track after overfitting has already occurred.
- Goal: Design a FAT framework that can both prevent catastrophic overfitting and automatically correct it when an overfitting trend emerges.
Method
Overall Architecture: FGSM-PCO
The core idea of FGSM-PCO (Preventing Catastrophic Overfitting) is to avoid training directly on the adversarial examples produced by the current FGSM step. Instead, it adaptively fuses historical and current adversarial examples and trains on the fused result.
1. Adversarial Example Generation and Fusion
Given the previous-epoch adversarial examples \(\boldsymbol{x}_{t-1}^*\), the current stage proceeds as follows:
- Calculate gradient direction: \(\mathbf{g}_t = \text{sign}(\nabla_{\mathbf{x}} \mathcal{L}(f_\theta(\boldsymbol{x}_{t-1}^*), \mathbf{y}))\)
- Generate amplified adversarial examples: \(\boldsymbol{x}_{am}^* = \boldsymbol{x}_{t-1}^* + \gamma \epsilon \mathbf{g}_t\), where \(\gamma\) is the amplification factor (default \(\gamma=2\)) to compensate for the perturbation decay caused by fusion.
- Adaptive fusion: \(\boldsymbol{x}_{train} = \lambda_t \boldsymbol{x}_{t-1}^* + (1-\lambda_t) \boldsymbol{x}_{am}^*\)
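Substituting the amplified example into the fusion rule makes the role of \(\gamma\) explicit. The derivation ignores any clipping or projection an implementation may apply:

\[
\boldsymbol{x}_{train} = \lambda_t \boldsymbol{x}_{t-1}^* + (1-\lambda_t)\big(\boldsymbol{x}_{t-1}^* + \gamma\epsilon\mathbf{g}_t\big) = \boldsymbol{x}_{t-1}^* + (1-\lambda_t)\,\gamma\epsilon\,\mathbf{g}_t .
\]

The effective step along \(\mathbf{g}_t\) is therefore \((1-\lambda_t)\gamma\epsilon\). With the default \(\gamma=2\), this step behaves as follows:
- It equals a full FGSM step \(\epsilon\) when \(\lambda_t = 1/2\).
- It approaches \(2\epsilon\) as \(\lambda_t \to 0\).
- It shrinks toward zero as \(\lambda_t \to 1\), so the training input stays close to the previous-epoch example.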
2. Adaptive Fusion Ratio
The fusion factor \(\lambda_t\) is determined by the model's classification confidence on the current adversarial example, specifically the confidence it assigns to the true class, whose index is \(k\). The key intuition:
- During normal training: Adversarial examples are effective, model confidence on the true class is low \(\to\) \(\lambda_t\) is small \(\to\) more current adversarial examples are used.
- When an overfitting trend emerges: Adversarial examples become ineffective, model confidence on the true class is high \(\to\) \(\lambda_t\) is large \(\to\) more historical adversarial examples are retained to avoid reliance on the ineffective current adversarial examples.
This mechanism ensures that when overfitting occurs, the training samples automatically lean toward historical valid samples, thereby correcting the training direction.
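The generation and fusion steps can be sketched in NumPy on a toy linear softmax classifier, where the input gradient has a closed form. This is a minimal illustration under stated assumptions, not the authors' implementation:
- The function names (`pco_fuse`, `input_grad`) are hypothetical.
- Projecting onto the \(\ell_\infty\) ball and the [0, 1] pixel range is an assumption.
- \(\lambda_t\) is assumed to equal the softmax probability of the true class on the amplified example, because the exact formula is not reproduced here.

```python
import numpy as np


def softmax(z):
    """Numerically stable softmax over a 1-D logit vector."""
    e = np.exp(z - z.max())
    return e / e.sum()


def input_grad(W, b, x, y):
    """Gradient of cross-entropy w.r.t. the input x for the toy model f(x) = W @ x + b."""
    p = softmax(W @ x + b)
    p[y] -= 1.0  # p - onehot(y)
    return W.T @ p


def pco_fuse(W, b, x_clean, x_prev, y, eps, gamma=2.0):
    """Hypothetical sketch of one FGSM-PCO generation + fusion step for a single input."""
    # g_t = sign(grad_x L(f(x_{t-1}^*), y)), evaluated at the previous-epoch adversarial example
    g = np.sign(input_grad(W, b, x_prev, y))
    # Amplified example: x_am = x_{t-1}^* + gamma * eps * g_t
    x_am = x_prev + gamma * eps * g
    # Assumption: project onto the l_inf ball around the clean input, then onto [0, 1]
    x_am = np.clip(np.clip(x_am, x_clean - eps, x_clean + eps), 0.0, 1.0)
    # Assumption: lambda_t = model's probability for the true class on the current example
    lam = float(softmax(W @ x_am + b)[y])
    # Adaptive fusion: x_train = lambda_t * x_{t-1}^* + (1 - lambda_t) * x_am
    x_train = lam * x_prev + (1.0 - lam) * x_am
    return x_train, x_am, lam


# Usage on random toy data: 3 classes, 5 input features
rng = np.random.default_rng(0)
W, b = rng.normal(size=(3, 5)), np.zeros(3)
x_clean = rng.uniform(0.0, 1.0, size=5)
eps = 8 / 255
x_prev = np.clip(x_clean + rng.uniform(-eps, eps, size=5), 0.0, 1.0)
x_train, x_am, lam = pco_fuse(W, b, x_clean, x_prev, y=1, eps=eps)
```

Because \(\boldsymbol{x}_{train}\) is a convex combination of two points inside the \(\epsilon\)-ball, it also stays inside the ball. When the model is very confident on the true class, as in an overfitting-like state, \(\lambda_t \to 1\). The training input then collapses onto the previous-epoch example, matching the correction behaviour described above.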
3. Custom Regularization Loss
To complement the fusion framework, the authors propose the PCO loss, which has two terms:
- First term: cross-entropy loss on the fused samples, which ensures the model's discriminative ability on adversarial examples.
- Second term (regularization): requires the predictions on the fused samples to stay consistent with the predictions on the adversarial examples from the previous and current stages, preventing inner optimization collapse. Its weight defaults to \(\beta=10\).
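A two-term loss of this shape can be sketched as follows. The exact distance used in the paper's consistency term is not reproduced here, so this sketch assumes a squared L2 distance between logits. The function names are hypothetical, and only the default \(\beta=10\) comes from the text.

```python
import numpy as np


def cross_entropy(logits, y):
    """Cross-entropy of integer label y under softmax(logits), via log-sum-exp."""
    z = logits - logits.max()
    return float(np.log(np.exp(z).sum()) - z[y])


def pco_style_loss(z_train, z_prev, z_am, y, beta=10.0):
    """Hypothetical two-term loss: CE on the fused sample + beta * consistency term."""
    ce = cross_entropy(z_train, y)
    # Assumed consistency term: squared L2 distance between the fused sample's logits
    # and the logits of the previous- and current-stage adversarial examples
    reg = float(np.sum((z_train - z_prev) ** 2) + np.sum((z_train - z_am) ** 2))
    return ce + beta * reg


# When all three predictions agree, the regularizer vanishes and only the CE term remains
z = np.array([1.0, 2.0, 0.5])
loss_consistent = pco_style_loss(z, z, z, y=1)
loss_drifted = pco_style_loss(z, z + 1.0, z, y=1)  # previous-stage logits disagree
```

Another divergence could replace the squared L2 distance, for example a KL divergence between softmax outputs. The structure of the loss stays the same either way: cross-entropy on the fused sample plus a \(\beta\)-weighted penalty for drifting away from the two source predictions.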
4. Correction Capability
Unlike other FAT methods, FGSM-PCO can also correct overfitting that has already occurred. In the experiments, training is switched to FGSM-PCO either after FGSM-AT overfits at epoch 16 or after FGSM-MEP overfits at epoch 50. In both cases the model recovers effective training.
Key Experimental Results
CIFAR-10 + ResNet18
| Method | Clean Acc | PGD10 | PGD50 | AA | Training Time |
|---|---|---|---|---|---|
| PGD-AT (best) | 82.57 | 53.19 | 52.21 | 48.77 | 199 min |
| TRADES (best) | 82.03 | 54.06 | 53.16 | 49.47 | 241 min |
| FGSM-MEP (best) | 81.72 | 55.13 | 54.29 | 48.23 | 57 min |
| FGSM-PCO (best) | 82.05 | 56.32 | 55.67 | 48.04 | 60 min |
- PGD10 accuracy is 3.1 percentage points higher than PGD-AT and 1.2 points higher than FGSM-MEP.
- Results at the last-epoch checkpoint are consistent with those of the best checkpoint, indicating that no catastrophic overfitting occurred.
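The percentage-point gaps quoted for CIFAR-10 + ResNet18 can be checked directly against the table. The small helpers below also compute how many times longer PGD-AT takes to train. The helper names are hypothetical, and the numbers are copied from the table.

```python
# (PGD10 accuracy in %, training time in minutes), from the CIFAR-10 + ResNet18 table
CIFAR10 = {
    "PGD-AT": (53.19, 199),
    "TRADES": (54.06, 241),
    "FGSM-MEP": (55.13, 57),
    "FGSM-PCO": (56.32, 60),
}


def pgd10_gap(method, baseline, table=CIFAR10):
    """PGD10 accuracy difference in percentage points, rounded to 2 decimals."""
    return round(table[method][0] - table[baseline][0], 2)


def time_ratio(method, baseline, table=CIFAR10):
    """How many times longer `baseline` takes to train than `method`."""
    return table[baseline][1] / table[method][1]


print(pgd10_gap("FGSM-PCO", "PGD-AT"))             # 3.13
print(pgd10_gap("FGSM-PCO", "FGSM-MEP"))           # 1.19
print(round(time_ratio("FGSM-PCO", "PGD-AT"), 2))  # 3.32
```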
CIFAR-100 + WideResNet34-10
| Method | Clean Acc | PGD10 | Training Time |
|---|---|---|---|
| PGD-AT | 62.45 | 32.36 | 1397 min |
| FGSM-MEP | 43.42 | 23.77 | 407 min |
| FGSM-PCO | 65.80 | 29.80 | 421 min |
- Across 10 independent runs, FGSM-PCO never suffered catastrophic overfitting (0/10). In contrast, FGSM-AT and FGSM-RS overfit in all 10 runs (10/10), and FGSM-MEP overfit in 6/10.
Tiny-ImageNet + PreActResNet18
| Method | Clean Acc | PGD10 | PGD50 |
|---|---|---|---|
| PGD-AT (best) | 33.99 | 15.35 | 15.16 |
| FGSM-MEP (best) | 31.70 | 16.81 | 16.69 |
| FGSM-PCO (best) | 34.96 | 18.17 | 17.99 |
Ablation Study (CIFAR-10 + ResNet18)
- Fusion only (no adaptation, no regularization): PGD10 = 39.91%, overfitting occurs
- Fusion + Regularization loss: PGD10 = 54.27%, significant improvement
- Fusion + Adaptation: PGD10 = 50.67%
- All components: PGD10 = 56.12%, all three components are indispensable
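The ablation numbers can be turned into gains over the fusion-only baseline. The sketch below does this with the values copied from the list. Variable names are illustrative.

```python
# PGD10 accuracy (%) of each ablation variant on CIFAR-10 + ResNet18
ABLATION = {
    "fusion only": 39.91,
    "fusion + regularization": 54.27,
    "fusion + adaptation": 50.67,
    "all components": 56.12,
}

base = ABLATION["fusion only"]
gains = {name: round(acc - base, 2) for name, acc in ABLATION.items() if name != "fusion only"}
print(gains)
# {'fusion + regularization': 14.36, 'fusion + adaptation': 10.76, 'all components': 16.21}

# Sum of the two single-component gains, for comparison with the combined gain
print(round(gains["fusion + regularization"] + gains["fusion + adaptation"], 2))  # 25.12
```

The combined gain (16.21 points) is smaller than the sum of the individual gains (25.12 points). This suggests that regularization and adaptation partly address the same failure mode, even though each one still adds accuracy on top of the other.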
Highlights & Insights
- Clearly explains the essence of catastrophic overfitting from a bi-level optimization perspective: it is a chain reaction triggered by the collapse of inner-level optimization.
- Elegant adaptive fusion mechanism: it uses the model's own classification confidence as the signal, so the fusion ratio needs no manual tuning.
- The first FAT method with correction capability: it not only prevents overfitting but also restores effective training after overfitting has occurred.
- No catastrophic overfitting in any of the 10 runs under the widely recognized difficult setting of WideResNet34-10 + CIFAR-100.
- Takes only 3 minutes more training time than FGSM-MEP on CIFAR-10 + ResNet18 (60 vs 57 min), while using about one third less GPU memory.
Limitations & Future Work
- Training overhead is still higher than the simplest FAT: Requires storing the previous-epoch adversarial examples and performing an additional forward pass, making it about 50% slower than FGSM-RS.
- Amplification factor \(\gamma\) is fixed at 2: Dynamic adjustment strategies have not been explored; different datasets/models might require different configurations.
- Does not surpass TRADES on AutoAttack (AA): robustness under the strongest attack still lags behind (48.04% vs 49.47% for TRADES; PGD-AT reaches 48.77%).
- Only validated under \(l_\infty\) norm constraint: The applicability to other norm constraints like \(l_2\) has not been discussed.
- Limited dataset scale: The largest validated dataset is Tiny-ImageNet (64×64); it has not been tested on full-scale ImageNet.
Related Work & Insights
| Method | Core Strategy | Can Prevent Overfitting? | Can Correct Overfitting? | Extra Overhead |
|---|---|---|---|---|
| FGSM-RS | Random initialization + Large step size | Partially | No | None |
| FGSM-GA | Gradient alignment regularization | Partially | No | Medium |
| FGSM-MEP | Momentum perturbation initialization | Mostly | No | High GPU Memory |
| FGSM-PCO | Adaptive fusion + Regularization | Completely | Yes | Low GPU Memory |
Key difference from FGSM-MEP: MEP initializes perturbations by accumulating gradient momentum, reducing the risk of inner optimization failure but without the ability to correct it; PCO directly involves historical samples in training through fusion, automatically increasing the proportion of historical samples to correct the direction when an overfitting trend emerges.
Inspirations & Connections
- Adaptive fusion concept can be extended to other training scenarios prone to collapse (e.g., GAN training, policy collapse in reinforcement learning).
- Utilizing the model's own confidence as a monitoring signal for training status is a lightweight and universal diagnostic mechanism.
- The idea of reusing historical samples shares similarities with experience replay; future work can explore introducing richer historical information in adversarial training.
- The regularization term requires "consistency of predictions before and after fusion", which is connected to consistency regularization in knowledge distillation.
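The idea of using confidence as a lightweight monitoring signal can be illustrated with a small NumPy check. The check flags a batch whose adversarial examples the model still classifies with high true-class confidence, which suggests the attack has stopped working. This is a hypothetical sketch: the function names and the 0.9 threshold are illustrative choices, not taken from the paper.

```python
import numpy as np


def true_class_confidence(logits, labels):
    """Softmax probability assigned to each sample's true label (rows = samples)."""
    z = logits - logits.max(axis=1, keepdims=True)
    p = np.exp(z) / np.exp(z).sum(axis=1, keepdims=True)
    return p[np.arange(len(labels)), labels]


def attack_looks_ineffective(adv_logits, labels, threshold=0.9):
    """Hypothetical check: high mean true-class confidence on adversarial inputs
    suggests the attack no longer fools the model (a possible overfitting sign)."""
    return float(true_class_confidence(adv_logits, labels).mean()) > threshold


labels = np.array([0, 1])
print(attack_looks_ineffective(np.array([[10.0, 0.0, 0.0], [0.0, 10.0, 0.0]]), labels))  # True
print(attack_looks_ineffective(np.zeros((2, 3)), labels))  # False
```

In practice, such a signal would be tracked per epoch and compared against its recent history rather than against a fixed threshold.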
Rating
- Novelty: ⭐⭐⭐⭐ — The analysis from the bi-level optimization perspective is profound, and the adaptive fusion mechanism is simple and effective.
- Experimental Thoroughness: ⭐⭐⭐⭐ — Three datasets and three models, including ablation study, sensitivity analysis, and correction capability validation.
- Writing Quality: ⭐⭐⭐⭐ — Clear motivation of the problem, complete derivation of the method.
- Value: ⭐⭐⭐⭐ — A practical improvement in the FAT field, resolving the long-standing problem of catastrophic overfitting.