Improving Generalization with Flat Hilbert Bayesian Inference
Conference: ICML2025
arXiv: 2410.04196
Code: To be confirmed
Area: LLM Evaluation
Keywords: Bayesian Inference, Sharpness-Aware Minimization, RKHS, SVGD, LoRA, Generalization
TL;DR
The paper proposes Flat Hilbert Bayesian Inference (FHBI), which extends the flatness notion of sharpness-aware minimization (SAM) from finite-dimensional Euclidean space to an infinite-dimensional reproducing kernel Hilbert space (RKHS) and combines it with particle-based Bayesian inference. On the VTAB-1K benchmark, FHBI reaches an average Top-1 accuracy of 73.7% and outperforms nine baselines.
Background & Motivation
Bayesian inference models uncertainty through the posterior distribution. However, existing methods (such as SVGD) only approximate the empirical posterior \(p(\theta|\mathcal{S})\), which is prone to overfitting the training set. On the other hand, SAM improves generalization by finding flat minima through minimizing loss sharpness, but is restricted to single-model optimization in finite-dimensional Euclidean space.
Core Motivation: Can SAM's flatness-seeking principle be combined with particle-based Bayesian inference in function space (an RKHS), so that the particles approximate the population posterior \(p(\theta|\mathcal{D})\) rather than the empirical posterior? Such a method would simultaneously provide:
- Uncertainty quantification capabilities of Bayesian methods
- Generalization improvements brought by flatness minimization
- Diversity among interacting particles
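Method
Population Posterior vs. Empirical Posterior
The paper distinguishes two targets: the empirical posterior \(p(\theta|\mathcal{S})\), conditioned on the finite training set \(\mathcal{S}\), and the population posterior \(p(\theta|\mathcal{D})\), conditioned on the underlying data distribution \(\mathcal{D}\).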
Proposition 4.1 shows that the population posterior, denoted \(\mathbb{Q}^*\), is the solution of a variational optimization problem defined on the population loss.
That is, a particle ensemble sampled from \(\mathbb{Q}^*\) targets the population loss rather than the training loss, which is what guards against overfitting.
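As a hedged reconstruction: the displayed definitions are not reproduced in this note, so the forms below assume the standard Gibbs-posterior convention with unit temperature, a prior density \(p(\theta)\) with distribution \(\mathbb{P}\), and empirical and population losses written \(\mathcal{L}_{\mathcal{S}}\) and \(\mathcal{L}_{\mathcal{D}}\) (this notation is an assumption). Under those assumptions the two posteriors and the variational problem would read:

\[
p(\theta|\mathcal{S}) \propto \exp\big(-\mathcal{L}_{\mathcal{S}}(\theta)\big)\, p(\theta), \qquad
p(\theta|\mathcal{D}) \propto \exp\big(-\mathcal{L}_{\mathcal{D}}(\theta)\big)\, p(\theta),
\]

\[
\mathbb{Q}^* = \arg\min_{\mathbb{Q} \ll \mathbb{P}} \Big\{ \mathbb{E}_{\theta \sim \mathbb{Q}}\big[\mathcal{L}_{\mathcal{D}}(\theta)\big] + D_{\mathrm{KL}}(\mathbb{Q} \,\|\, \mathbb{P}) \Big\}.
\]

The second display is the usual Gibbs variational principle: the minimizer of expected loss plus KL divergence to the prior has density proportional to \(\exp(-\mathcal{L}_{\mathcal{D}}(\theta))\, p(\theta)\), which is the population posterior in the first display.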
Function Space Generalization Bound (Theorem 4.2)
Theorem 4.2 extends existing Euclidean-space generalization bounds to the RKHS \(\mathcal{H}^d\).
The key difficulty is that an RKHS can be infinite-dimensional (as with the RBF kernel), so existing dimension-dependent results cannot be applied directly.
Bayesian Inference Generalization Bound (Theorem 4.3)
Theorem 4.3 connects function-space sharpness to Bayesian inference: the KL divergence to the population posterior is upper-bounded by the worst-case KL divergence to the empirical posterior.
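Schematically, and with notation that is assumed here rather than taken from the text (\(q_f\) for the particle distribution induced by a function \(f \in \mathcal{H}^d\), \(\rho\) for the perturbation radius), a bound of this kind has the SAM-like shape below. The additive remainder terms and the high-probability statement of the actual theorem are deliberately left unspecified:

\[
D_{\mathrm{KL}}\big(q_f \,\|\, p(\theta|\mathcal{D})\big) \;\le\; \max_{\|f' - f\|_{\mathcal{H}^d} \le \rho} D_{\mathrm{KL}}\big(q_{f'} \,\|\, p(\theta|\mathcal{S})\big) \;+\; \text{(remainder)}.
\]

The right-hand side involves only the empirical posterior, which is what makes a bound of this shape usable as a training objective.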
FHBI Algorithm
Based on the above theory, FHBI adopts a two-step iterative update:
Step 1 - Adversarial Perturbation (Ascent Step): Find the worst-case perturbation in the RKHS along the direction of the functional gradient.
Step 2 - Functional Descent Step: Compute the functional gradient at the perturbed position and use it to update.
In practice, \(m\) particles \(\{\theta_i\}_{i=1}^m\) are maintained. The update of each particle involves information interaction among all particles:
- Sharpness Minimization: Each particle searches for flat regions (similar to SAM)
- Angle Repulsion Force: Encourages diverse gradient directions across particles by minimizing the inner product \(\nabla_{\theta_j}\mathcal{L}(\theta_j) \cdot \nabla_{\theta_k}\mathcal{L}(\theta_k)\)
- Spatial Repulsion Force: The kernel gradient term \(\nabla_\theta k(\theta, \theta_j)\) keeps particles from collapsing onto one another
Unified Perspective: FHBI generalizes both SAM and SVGD. It reduces to SVGD when the perturbation radius \(\rho=0\), and to SAM when the number of particles \(m=1\).
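The two-step update and the two special cases can be illustrated with a small NumPy sketch. It is a simplified reading of the description above, not the authors' implementation: the function names (`rbf_kernel`, `svgd_direction`, `fhbi_step`), the fixed bandwidth `h`, the joint Euclidean normalization of the perturbation, and the toy Gaussian target are all assumptions made for illustration.

```python
import numpy as np


def rbf_kernel(X, h=1.0):
    """RBF kernel matrix and its gradient for particles X of shape (m, d).

    K[j, i] = k(x_j, x_i); G[j, i] = gradient of k(x_j, x_i) w.r.t. x_j.
    """
    diff = X[:, None, :] - X[None, :, :]            # diff[j, i] = x_j - x_i
    K = np.exp(-np.sum(diff ** 2, axis=-1) / (2.0 * h ** 2))
    G = -diff / h ** 2 * K[..., None]
    return K, G


def svgd_direction(X, grad_log_p, h=1.0):
    """Standard SVGD direction: kernel-weighted scores plus spatial repulsion."""
    K, G = rbf_kernel(X, h)
    scores = grad_log_p(X)                          # shape (m, d)
    return (K.T @ scores + G.sum(axis=0)) / X.shape[0]


def fhbi_step(X, grad_log_p, rho=0.05, lr=0.1, h=1.0):
    """One ascent-then-descent update on all particles (illustrative only)."""
    phi = svgd_direction(X, grad_log_p, h)
    # Step 1 (ascent): move against the SVGD direction, i.e. towards higher KL,
    # by a total distance rho. The joint normalization is an assumption.
    X_adv = X - rho * phi / (np.linalg.norm(phi) + 1e-12)
    # Step 2 (descent): evaluate the direction at the perturbed particles and
    # apply it to the original particles.
    phi_adv = svgd_direction(X_adv, grad_log_p, h)
    return X + lr * phi_adv


def grad_log_gaussian(X):
    """Score of a standard Gaussian target, used here as a toy posterior."""
    return -X


rng = np.random.default_rng(0)
particles = rng.normal(size=(5, 2)) + 3.0           # start away from the mode
for _ in range(200):
    particles = fhbi_step(particles, grad_log_gaussian)
print(particles.mean(axis=0))                       # particle mean after 200 steps
```

In this sketch, `rho=0.0` removes the perturbation so that `fhbi_step` coincides with a plain SVGD update, and with a single particle the kernel terms drop out so that the update matches SAM's ascent-then-descent step on the loss \(-\log p\).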
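Key Experimental Results
FHBI is evaluated on the VTAB-1K benchmark (19 datasets spanning the Natural, Specialized, and Structured categories) using ViT-B/16 with LoRA fine-tuning. The table reports Top-1 accuracy (%) for FHBI and six of the baselines; the number of datasets in each category is given in parentheses: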
| Method | Natural (7) | Specialized (4) | Structured (8) | Average |
|---|---|---|---|---|
| AdamW | 79.1 | 84.3 | 59.0 | 72.0 |
| SAM | 80.1 | 83.2 | 56.0 | 70.5 |
| DeepEns | 79.3 | 83.9 | 42.8 | 67.0 |
| BayesTune | 80.5 | 84.9 | 59.3 | 72.2 |
| SVGD | 79.8 | 84.6 | 56.3 | 70.9 |
| SADA-JEM | 80.3 | 84.7 | 58.6 | 72.1 |
| FHBI | 82.4 | 86.9 | 61.6 | 73.7 |
- FHBI achieves the best performance across all three domains (Natural, Specialized, and Structured)
- Achieves an average Top-1 accuracy of 73.7%, 1.5 percentage points above BayesTune, the strongest baseline in the table
- The margin over BayesTune is largest on the Structured datasets (2.3 points, versus 1.9 on Natural and 2.0 on Specialized)
Expected Calibration Error (ECE): FHBI also achieves the lowest ECE across multiple datasets, indicating better model confidence calibration.
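For reference, ECE is commonly estimated by binning predictions by confidence and averaging the gap between accuracy and mean confidence in each bin, weighted by bin size. The sketch below follows that common equal-width-bin recipe; the function name `expected_calibration_error` and the default of 15 bins are assumptions, and the binning used in the paper is not stated here.

```python
import numpy as np


def expected_calibration_error(confidences, predictions, labels, n_bins=15):
    """Equal-width-bin ECE: sum over bins of (bin weight) * |accuracy - confidence|."""
    confidences = np.asarray(confidences, dtype=float)
    correct = (np.asarray(predictions) == np.asarray(labels)).astype(float)
    edges = np.linspace(0.0, 1.0, n_bins + 1)
    ece = 0.0
    for lo, hi in zip(edges[:-1], edges[1:]):
        # Half-open bins (lo, hi]; a confidence of exactly 0 falls in no bin.
        mask = (confidences > lo) & (confidences <= hi)
        if mask.any():
            gap = abs(correct[mask].mean() - confidences[mask].mean())
            ece += mask.mean() * gap
    return ece


# Toy usage: four predictions with their top-class confidences.
print(expected_calibration_error([0.95, 0.9, 0.7, 0.6], [1, 0, 2, 1], [1, 0, 1, 1]))
```

A lower value means that the predicted confidences track the observed accuracy more closely.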
Highlights & Insights
- Solid Theoretical Contribution: Extends SAM's generalization bound from finite-dimensional Euclidean space to an infinite-dimensional RKHS for the first time, which is non-trivial because existing bounds depend on the dimension.
- Elegant Unified Perspective: FHBI unifies SAM (single-particle flatness) and SVGD (multi-particle posterior approximation), revealing their intrinsic relationship.
- Three Complementary Forces: Sharpness minimization, angle repulsion, and spatial repulsion act together, with the two repulsion terms promoting particle diversity.
- Comprehensive Experiments: 19 datasets × 9 baselines, validating effectiveness on both Top-1 accuracy and ECE metrics.
- Natural Integration with LoRA: Multi-particle inference is run only on lightweight LoRA parameters, keeping computational overhead manageable.
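To make the "lightweight" point concrete, the sketch below counts trainable parameters when each particle is a separate set of LoRA factors. The rank (4), the choice of adapting the query and value projections in each block, and the particle count (4) are illustrative assumptions rather than the paper's settings; the hidden size of 768 and the 12 transformer blocks are the standard ViT-B/16 configuration.

```python
def lora_params(d_in, d_out, rank):
    """Trainable parameters of one LoRA adapter: A is (rank, d_in), B is (d_out, rank)."""
    return rank * (d_in + d_out)


hidden = 768            # ViT-B/16 hidden size
blocks = 12             # ViT-B/16 transformer blocks
adapted_per_block = 2   # assumption: query and value projections only
rank = 4                # assumption
num_particles = 4       # assumption

per_particle = blocks * adapted_per_block * lora_params(hidden, hidden, rank)
all_particles = num_particles * per_particle
full_matrices = blocks * adapted_per_block * hidden * hidden

print(per_particle, all_particles, full_matrices)  # 147456 589824 14155776
```

Under these assumed settings, four particles together hold about 0.59M trainable parameters, roughly 4% of the 14.2M parameters in the adapted weight matrices themselves.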
Limitations & Future Work
- Computational Cost: \(m\) particles require \(m\) times as many forward/backward passes plus an \(m \times m\) kernel matrix computation, a non-negligible overhead when scaling to large models.
- Evaluated Only on ViT-B/16 + LoRA: Not verified on larger models (such as ViT-L or LLMs) or other PEFT methods.
- Sensitivity to Kernel Choice: The paper employs the RBF kernel without fully exploring the impact of other kernel functions.
- Theory-to-Practice Gap: Lemma 4.4 relies on an approximation that holds when \(\|f\|\) is sufficiently small; how well this condition is met in practice is not thoroughly discussed.
- Focus Only on Image Classification: Generalization has not been validated on NLP, object detection, or other tasks.
Related Work & Insights
- SAM (Foret et al., 2021): Single-particle, Euclidean-space special case of FHBI (\(m=1\))
- SVGD (Liu & Wang, 2016): Perturbation-free special case of FHBI (\(\rho=0\))
- BayesTune (Kim et al., 2023): Bayesian fine-tuning baseline
- SA-BNN / SADA-JEM: Sharpness-aware Bayesian methods
- Insight: Systematically elevating optimization techniques (sharpness awareness) to the function space level is an important direction for Bayesian deep learning.
Rating
- Novelty: ⭐⭐⭐⭐ (RKHS generalization bound + unified SAM/SVGD perspective)
- Experimental Thoroughness: ⭐⭐⭐⭐ (19 datasets × 9 baselines, multiple metrics)
- Writing Quality: ⭐⭐⭐⭐ (Clear theoretical derivations and intuitive diagrams)
- Value: ⭐⭐⭐⭐ (Meaningful fusion of Bayesian inference and generalization theory)