Singular Bayesian Neural Networks

Conference: ICML 2026
arXiv: 2602.00387
Code: None
Area: Bayesian Neural Networks / Variational Inference / Model Compression / Uncertainty Quantification
Keywords: Low-rank decomposition, singular posterior, PAC-Bayes, OOD detection, Mean-field Variational Inference

TL;DR

This paper parameterizes each weight matrix as \(W=AB^\top\) and places mean-field distributions on the factors \(A, B\) rather than on \(W\) itself, which induces a low-rank posterior that is singular with respect to the Lebesgue measure. This reduces the parameter count from \(O(mn)\) to \(O(r(m+n))\) and tightens the PAC-Bayes complexity term from \(\sqrt{mn}\) to \(\sqrt{r(m+n)}\). Across MLP, LSTM, and Transformer architectures, its OOD detection is competitive with a 5-member Deep Ensemble (and better on MIMIC-III) while using up to \(33\times\) fewer parameters.

Background & Motivation

Background: Bayesian Neural Networks (BNNs) provide principled uncertainty quantification by maintaining weight distributions rather than point estimates, which is critical for high-stakes scenarios like healthcare and autonomous driving. The dominant approximation method is Mean-Field Variational Inference (MFVI), where each weight \(w_{ij}\) is modeled by an independent Gaussian \(\mathcal{N}(\mu_{ij}, \sigma_{ij}^2)\), requiring twice the parameters (mean + variance) of deterministic models.

Limitations of Prior Work: (1) Parameter Explosion—MFVI requires \(O(mn)\) variational parameters, restricting BNNs to small models; (2) Excessively Strong Independence Assumption—the fully factorized posterior erases structural correlations between weights, damaging expressivity; (3) Cinquin et al. (2021) identified fundamental pathologies in weight-space inference for Transformers (difficulty in prior setting and mapping weight-space to function-space); (4) Existing low-rank approaches have flaws: post-hoc low-rank perturbations (Rank-1 Mult.) rely on pre-trained backbones and lose end-to-end uncertainty, low-rank covariance approximations still parameterize full-rank \(W\) means, and Bayesian variants of LoRA are limited to fine-tuning.

Key Challenge: Modern neural networks empirically exhibit low intrinsic dimensionality (Aghajanyan et al., 2021; weight matrix singular values decay rapidly), yet the full-rank, independent parameterization of BNNs structurally ignores this fact, wasting parameters and losing correlations.

Goal: (1) Parameterize weight matrices directly as low-rank products to ensure the posterior naturally lies on a low-rank manifold; (2) Establish tight PAC-Bayes theoretical guarantees, reducing generalization complexity from \(\sqrt{mn}\) to \(\sqrt{r(m+n)}\); (3) Enable end-to-end training across MLP, LSTM, and Transformer architectures; (4) Learn uncertainty from scratch without relying on pre-trained backbones.

Key Insight: The authors observe that applying mean-field inference to the factors \(A, B\) rather than \(W\) induces a weight posterior \(q_W\) that is supported on the rank-\(r\) manifold \(\mathcal{R}_r\)—a manifold which has zero volume under the Lebesgue measure. In other words, this produces a posterior that is strictly singular with respect to the Lebesgue measure, rather than just "approximately low-rank." This geometric property serves as a strong inductive bias: all \(W_{ij}\) are coupled through shared factors \(A_{ik}\) and \(B_{jk}\), automatically generating structured correlations.

Core Idea: Place the Bayesian distribution on low-rank factors instead of weights. Treat "singularity" as a quantifiable inductive bias and use the Eckart-Young-Mirsky theorem to strictly characterize approximation errors via tail singular values \(\sum_{i>r} \sigma_i^2\).
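The Eckart-Young-Mirsky statement used here can be checked numerically. The sketch below is only an illustration on a synthetic matrix (the matrix, its sizes, and the helper name `truncated_svd` are assumptions, not taken from the paper): it confirms that the Frobenius error of the best rank-\(r\) approximation equals \(\sqrt{\sum_{i>r} \sigma_i^2}\).

```python
import numpy as np

def truncated_svd(W, r):
    """Best rank-r approximation of W in Frobenius norm (Eckart-Young-Mirsky)."""
    U, s, Vt = np.linalg.svd(W, full_matrices=False)
    W_r = (U[:, :r] * s[:r]) @ Vt[:r, :]
    return W_r, s

rng = np.random.default_rng(0)
m, n, r = 40, 30, 5

# Synthetic matrix with fast singular value decay:
# a rank-r signal plus a small full-rank perturbation.
W = rng.normal(size=(m, r)) @ rng.normal(size=(r, n)) + 0.01 * rng.normal(size=(m, n))

W_r, s = truncated_svd(W, r)
frob_error = np.linalg.norm(W - W_r, "fro")   # ||W - W_r||_F
tail = np.sqrt(np.sum(s[r:] ** 2))            # sqrt(sum_{i>r} sigma_i^2)

print(f"Frobenius error: {frob_error:.6f}, tail singular values: {tail:.6f}")
```

When the singular values decay quickly, as in this synthetic case, the tail term is small and the rank-\(r\) restriction costs little.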

Method

Overall Architecture

Each weight matrix \(W \in \mathbb{R}^{m \times n}\) is parameterized as \(W = AB^\top\), with \(A \in \mathbb{R}^{m \times r}\) and \(B \in \mathbb{R}^{n \times r}\). A scale-mixture Gaussian prior is applied elementwise to the factors, \(p_A(A) = \prod_{i,k} [\pi \mathcal{N}(A_{ik}; 0, \sigma_1^2) + (1-\pi)\mathcal{N}(A_{ik}; 0, \sigma_2^2)]\), and likewise for \(p_B\) (heavy tails promote sparsity). The variational posteriors \(q_A, q_B\) are mean-field Gaussians, sampled with the reparameterization trick \(A = \mu_A + \log(1+\exp(\rho_A)) \circ \epsilon_A\) for differentiability. The ELBO decomposes into a data-fitting term \(\mathbb{E}_{q_A q_B}[\log p(\mathcal{D}|AB^\top)]\) and a regularization term \(\beta(\text{KL}(q_A \| p_A) + \text{KL}(q_B \| p_B))\). Implementation across architectures: MLPs factorize fully connected layers; Transformers factorize Q/K/V projections and FFNs, with embeddings using batch-sparsity to sample only the rows corresponding to active tokens; LSTMs factorize \(W_{ih}\) and \(W_{hh}\), sampling \(A, B\) once per batch and caching \(W\) across timesteps.
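As a rough illustration of the parameterization described above, the following NumPy sketch draws the factors with the softplus reparameterization and forms \(W = AB^\top\) for a single linear layer. The layer sizes, initial values, and the names `sample_factor`, `sample_weight`, and `params` are hypothetical; this is not the authors' implementation, and gradient computation is omitted.

```python
import numpy as np

def softplus(x):
    # Numerically stable log(1 + exp(x)); keeps the standard deviation positive.
    return np.logaddexp(0.0, x)

def sample_factor(mu, rho, rng):
    """Reparameterized draw: mu + softplus(rho) * eps, with eps ~ N(0, I)."""
    eps = rng.standard_normal(mu.shape)
    return mu + softplus(rho) * eps

def sample_weight(params, rng):
    """Draw A and B from their mean-field Gaussians and form W = A B^T."""
    A = sample_factor(params["mu_A"], params["rho_A"], rng)
    B = sample_factor(params["mu_B"], params["rho_B"], rng)
    return A @ B.T

rng = np.random.default_rng(0)
m, n, r = 64, 32, 4   # output dim, input dim, rank (illustrative values)

params = {
    "mu_A": 0.1 * rng.standard_normal((m, r)),
    "rho_A": np.full((m, r), -3.0),
    "mu_B": 0.1 * rng.standard_normal((n, r)),
    "rho_B": np.full((n, r), -3.0),
}

W = sample_weight(params, rng)        # one posterior sample, shape (m, n)
x = rng.standard_normal((8, n))       # batch of 8 inputs
y = x @ W.T                           # layer output, shape (8, m)

n_low_rank = sum(p.size for p in params.values())   # 2 * r * (m + n)
n_full_rank_mfvi = 2 * m * n                        # mean + scale per weight

print("rank of sampled W:", np.linalg.matrix_rank(W))
print("variational parameters:", n_low_rank, "vs full-rank MFVI:", n_full_rank_mfvi)
```

Every sampled \(W\) has rank at most \(r\) by construction, which is the property the singular-posterior argument relies on.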

Key Designs

  1. Induced Singular Posterior and Geometric Inductive Bias:

    • Function: Places Bayesian uncertainty directly on the low-rank manifold, avoiding full-space diffusion seen in MFVI.
    • Mechanism: Perform variational inference on factors \((A, B)\); the distribution of weights \(W = AB^\top\) is obtained via a pushforward. Lemma 3.2 proves \(q_W(\mathcal{R}_r) = 1\) (supported on the set of rank-\(r\) matrices); Lemma 3.3 proves that for \(r < \min(m, n)\), the Lebesgue measure of \(\mathcal{R}_r\) is zero; Theorem 3.4 concludes \(q_W\) is singular with respect to the Lebesgue measure. This means \(q_W\) lacks a Lebesgue density—a fundamental geometric contrast to the "everywhere positive density" of MFVI.
    • Design Motivation: Wilson & Izmailov (2020) noted that Bayesian generalization depends on posterior support and inductive bias. While MFVI biases toward "independently tunable weights," the proposed method biases toward "weights coupled via shared factors," which aligns with the low-rank nature of modern networks and provides implicit regularization: updating \(W_{ij}\) requires modifying factors that affect entire rows or columns, preventing local memorization.
  2. Structured Weight Correlations (Lemma 3.5):

    • Function: Captures global correlations between weights within a low parameter budget, compensating for losses from the MFVI independence assumption.
    • Mechanism: Although \(A\) and \(B\) are mean-field, the elements of \(W\) are not independent: \(\text{Cov}(W_{ij}, W_{i'j'}) = \sum_k \text{Cov}(A_{ik}B_{jk}, A_{i'k}B_{j'k})\). The \(k\)-th term can be nonzero only when the two weights share a factor entry, i.e. lie in the same row (\(i = i'\), shared \(A_{ik}\)) or the same column (\(j = j'\), shared \(B_{jk}\)). The rank \(r\) controls the richness of the correlation structure: higher rank allows for complex block correlations while maintaining \(O(r(m+n))\) parameters. Figure 1 in the paper shows that full-rank Bayes by Backprop (BBB) yields diagonal correlations, while the low-rank parameterization yields block structures.
    • Design Motivation: Filters out high-frequency noise inconsistent with the dominant low-rank structure and captures "shared subspace" uncertainty propagation invisible to standard MFVI.
  3. Theoretical Guarantees: EYM Loss Decomposition + PAC-Bayes Tightening:

    • Function: Formally demonstrates that "low-rank \(\neq\) degenerate" and quantifies complexity gains.
    • Mechanism: Theorem 3.6 (EYM Loss Bound): Under an \(L\)-Lipschitz loss, the gap between the full-rank optimum \(W^*\) and its optimal rank-\(r\) truncated SVD \(W^*_r\) is bounded by the tail singular values: \(|\mathbb{E}\ell(W^*x,y) - \mathbb{E}\ell(W^*_r x, y)| \le LR \sqrt{\sum_{i>r} \sigma_i^2(W^*)}\). Theorem 3.7 decomposes the error of the learned \(W = AB^\top\) into learning error \(\|W - W^*_r\|_F\) + rank bias \(\sigma_{>r}\). Theorem 3.8 provides a PAC-Bayes complexity ratio of \(\sqrt{r(m+n)/(mn)} \ll 1\), which tightens significantly when \(r \ll \min(m, n)\). Theorem 3.9 provides a non-vacuous generalization bound based on low-rank Gaussian complexity.
    • Design Motivation: Provides theoretical guidance for choosing \(r\), which can be determined via singular value decay analysis or ablation, allowing for predicted loss upper bounds.
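Two of the claims in the list above lend themselves to a quick numerical check: the covariance induced between weights by shared factors, and the size of the complexity ratio \(\sqrt{r(m+n)/(mn)}\). The sketch below uses a toy \(2 \times 2\) weight matrix with unit factor means and a common standard deviation of 0.5; these values and the helper names are assumptions chosen only for illustration. For two weights in the same row, independence of the factor entries gives \(\text{Cov}(W_{00}, W_{01}) = \sum_k \text{Var}(A_{0k})\, \mu_{B,0k}\, \mu_{B,1k}\), which the Monte Carlo estimate should approach.

```python
import numpy as np

rng = np.random.default_rng(0)
m, n, r, n_samples = 2, 2, 3, 50_000

# Toy mean-field posterior over the factors (illustrative values).
mu_A, mu_B, sigma = np.ones((m, r)), np.ones((n, r)), 0.5

# Draw all samples at once: A has shape (S, m, r), B has shape (S, n, r).
A = mu_A + sigma * rng.standard_normal((n_samples, m, r))
B = mu_B + sigma * rng.standard_normal((n_samples, n, r))
W = np.einsum("sik,sjk->sij", A, B)   # W[s] = A[s] @ B[s].T

def cov(u, v):
    """Empirical covariance of two 1-D sample arrays."""
    return np.mean((u - u.mean()) * (v - v.mean()))

cov_same_row = cov(W[:, 0, 0], W[:, 0, 1])   # share the entries A[0, k]
cov_disjoint = cov(W[:, 0, 0], W[:, 1, 1])   # share no factor entry

# Closed form for the same-row case under independent factor entries.
cov_same_row_exact = np.sum(sigma**2 * mu_B[0] * mu_B[1])

def complexity_ratio(m, n, r):
    """sqrt(r (m + n) / (m n)): low-rank vs full-rank complexity scaling."""
    return np.sqrt(r * (m + n) / (m * n))

print(f"same row: {cov_same_row:.3f} (exact {cov_same_row_exact:.3f})")
print(f"disjoint: {cov_disjoint:.3f}")
print(f"complexity ratio for m = n = 512, r = 16: {complexity_ratio(512, 512, 16):.2f}")
```

In this toy setting the same-row covariance is clearly nonzero while the disjoint pair is uncorrelated up to sampling noise, and a \(512 \times 512\) layer at rank 16 has a complexity ratio of 0.25.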

Loss & Training

All three ELBO terms are estimated via Monte Carlo (scale-mixture priors lack closed-form KL); Adam optimizer is used; \(\sigma = \log(1+\exp(\rho))\) ensures positivity; \(\beta\) serves as the KL temperature. The rank \(r_\ell\) is independently tunable for each layer. Predictions use Monte Carlo averaging over multiple weight samples.
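Because the scale-mixture prior has no closed-form KL against a Gaussian, the KL terms can be estimated by sampling, as described above. A minimal sketch of such an estimator for one factor follows; the function names, sample count, and prior hyperparameters (`pi`, `sigma1`, `sigma2`) are placeholder assumptions rather than values from the paper, and a real training loop would use an autodiff framework instead of plain NumPy.

```python
import numpy as np

def softplus(x):
    return np.logaddexp(0.0, x)

def log_normal(x, mean, std):
    """Elementwise log density of N(mean, std^2)."""
    return -0.5 * np.log(2 * np.pi) - np.log(std) - 0.5 * ((x - mean) / std) ** 2

def log_scale_mixture(x, pi, sigma1, sigma2):
    """Elementwise log of pi * N(x; 0, sigma1^2) + (1 - pi) * N(x; 0, sigma2^2)."""
    return np.logaddexp(
        np.log(pi) + log_normal(x, 0.0, sigma1),
        np.log1p(-pi) + log_normal(x, 0.0, sigma2),
    )

def mc_kl(mu, rho, rng, n_samples=64, pi=0.5, sigma1=1.0, sigma2=0.1):
    """Monte Carlo estimate of KL(q || p) = E_q[log q(z) - log p(z)] for one factor."""
    std = softplus(rho)
    total = 0.0
    for _ in range(n_samples):
        z = mu + std * rng.standard_normal(mu.shape)   # reparameterized draw
        total += np.sum(log_normal(z, mu, std) - log_scale_mixture(z, pi, sigma1, sigma2))
    return total / n_samples

def neg_elbo(nll, kl_A, kl_B, beta):
    """Training loss: data-fit term plus beta-weighted KL terms for both factors."""
    return nll + beta * (kl_A + kl_B)

rng = np.random.default_rng(0)
mu_A, rho_A = 0.3 * np.ones((10, 4)), np.zeros((10, 4))
kl_A = mc_kl(mu_A, rho_A, rng)
print(f"MC KL estimate for A: {kl_A:.3f}")
```

A useful sanity check is to set `sigma1 == sigma2`, which collapses the mixture to a single Gaussian whose KL is available in closed form.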

Key Experimental Results

Main Results

Evaluated on MIMIC-III (ICU mortality, MLP), Beijing Air Quality (PM2.5 prediction, LSTM), and SST-2 (Sentiment, Transformer), comparing Deterministic / Deep Ensembles (5) / Full-Rank BBB / Low-Rank (Ours) / LR-SVD init / Rank-1 Mult.

| Dataset (Architecture) | Metric | Low-Rank (Ours) | Full-Rank BBB | Deep Ens. (5) | Parameters (Ours vs BBB / Ens.) |
| --- | --- | --- | --- | --- | --- |
| MIMIC-III (MLP) | AUC-OOD↑ | 0.802 | 0.770 | 0.738 | 13.6k vs 44.8k / 112k |
| MIMIC-III (MLP) | AUPR-In↑ | 0.824 | 0.807 | 0.721 | |
| Beijing AQ (LSTM) | PICP↑ | 0.790 | 0.788 | 0.310 | 47k vs 132k / 330k |
| Beijing AQ (LSTM) | AUROC-OOD↑ | 0.710 | 0.492 | 0.730 | |
| SST-2 (Transformer) | Acc↑ | 0.806 | 0.752 | 0.825 | 1.5M vs 19.8M / 49.6M |
| SST-2 (Transformer) | AUROC-OOD↑ | 0.640 (2nd) | 0.622 | 0.657 | |
| SST-2 (Transformer) | Training time (min) | 8.2 | 23.1 | 64.7 | |
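The parameter reductions implied by the table can be worked out with a few lines of arithmetic. The sketch below assumes the Parameters column reads as Ours vs Full-Rank BBB / Deep Ensemble; the dictionary keys are illustrative names, not identifiers from the paper.

```python
# Parameter counts as listed in the results table (assumed order: ours, BBB, ensemble).
param_counts = {
    "MIMIC-III (MLP)": {"low_rank": 13.6e3, "full_rank_bbb": 44.8e3, "deep_ensemble": 112e3},
    "Beijing AQ (LSTM)": {"low_rank": 47e3, "full_rank_bbb": 132e3, "deep_ensemble": 330e3},
    "SST-2 (Transformer)": {"low_rank": 1.5e6, "full_rank_bbb": 19.8e6, "deep_ensemble": 49.6e6},
}

# Reduction factor of the low-rank model relative to each baseline.
ratios = {
    name: {
        "vs_bbb": counts["full_rank_bbb"] / counts["low_rank"],
        "vs_ensemble": counts["deep_ensemble"] / counts["low_rank"],
    }
    for name, counts in param_counts.items()
}

for name, ratio in ratios.items():
    print(f"{name}: {ratio['vs_bbb']:.1f}x vs BBB, {ratio['vs_ensemble']:.1f}x vs ensemble")
```

Under this reading, the \(33\times\) figure corresponds to the Transformer compared against the ensemble, while the MLP and LSTM reductions are smaller.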

Ablation Study

| Configuration | Key Metric | Description |
| --- | --- | --- |
| Low-Rank (random init, r=15) | OOD AUC=0.802 (best) | Full model |
| LR-SVD init | OOD AUC=0.713 | SVD initialization degrades performance (prematurely locks rank) |
| Rank-1 Mult. (post-hoc) | OOD AUC=0.705 | Confirms end-to-end low-rank > post-hoc perturbation |
| Full-Rank BBB | OOD AUC=0.770 | Confirms contribution of singular posterior |
| Rank sweep (Fig 3) | Critical rank \(r^* \approx 11\) | PAC-Bayes bound becomes vacuous beyond the critical value |

Key Findings

  • OOD Detection vs. Likelihood Calibration Trade-off: Low-rank models outperform Deep Ensembles in OOD detection and uncertainty metrics (PICP/AUPR-Err), but fall slightly behind in in-distribution NLL/ECE. Structured correlations prioritize epistemic uncertainty, while ensembles prioritize likelihood calibration.
  • Modern architectures indeed display rapid singular value decay (especially in embeddings), providing strong empirical support for low-rank parameterization.
  • Full-Rank BBB performed worst on Transformers (0.752 acc), corroborating Cinquin et al.’s findings on pathologies in Transformer weight-space inference; low-rank constraints conversely stabilized training.
  • A single rank-\(r\) BNN can match the predictive performance of a 5-member Deep Ensemble with a \(33\times\) parameter reduction.

Highlights & Insights

  • "Singularity" is a feature, not a bug: Unlike traditional Bayesian methods that avoid singular posteriors, this work actively constructs them and quantifies their inductive bias—an elegant paradigm for geometrizing "prior beliefs."
  • EYM Theorem + Pushforward: Introducing classic matrix analysis tools to Bayesian deep learning complexity analysis provides clear loss upper bounds for selecting \(r\), which is highly practical.
  • Architecture-agnostic drop-in replacement: Low-rank variational layers can directly replace standard layers, making BNNs significantly easier to deploy in industry.
  • The approach of "placing Bayes on low-rank factors" can be transferred to LoRA fine-tuning, diffusion model weight uncertainty, and neural field parameters.

Limitations & Future Work

  • The rank \(r\) still requires manual selection or ablation; while singular value decay analysis helps, it requires a pre-trained backbone for SVD.
  • Deep Ensembles maintain an advantage in in-distribution likelihood (NLL=0.300 vs Ours 0.433 on MIMIC-III), showing structured correlations are not a "universal improvement."
  • Experimental scale remains relatively small (max is 4-layer BERT-mini), without validation on billion-scale models; the paper acknowledges this as foundational work.
  • Scale-mixture priors + Monte Carlo KL introduce additional sampling costs and sensitive hyperparameters (\(\pi, \sigma_1, \sigma_2\)).
  • Future directions: Combining with function-space methods like SNGP/Laplace, extending to SSM/Mamba architectures, and exploring "safe generation" via weight uncertainty in generative models.

Related Work & Insights

  • vs Rank-1 Multiplicative (Dusenberry 2020a): They add rank-1 multiplicative noise to deterministic backbones post hoc; Ours learns the low-rank factors end-to-end from initialization and performs clearly better on OOD detection (AUC 0.802 vs 0.705 in the ablation).
  • vs Low-Rank Covariance (Tomczak 2020): They use low-rank + diagonal for covariance, but the weight mean is still full-rank; Ours applies low-rank directly to \(W\).
  • vs LoRA Bayesian (Yang 2024): LoRA requires fine-tuning a pre-trained backbone; Ours trains from scratch.
  • vs Deep Ensemble: Ensembles are "sampling-based poor man's Bayes" with \(5\times\) parameters; Ours is a single model with \(5\times\) to \(33\times\) fewer parameters and better OOD detection, though slightly worse in-distribution likelihood.
  • vs SNGP / Linearized Laplace: They focus on function-space or last-layer uncertainty; Ours is end-to-end weight-space inference, making them complementary.
  • vs Watanabe’s Singular Learning Theory: The "singularity" here refers to geometric singularity of the induced posterior relative to the Lebesgue measure, distinct from Watanabe’s concept of asymptotic model singularity.

Rating

  • Novelty: ⭐⭐⭐⭐⭐ The geometric perspective of "singular posteriors" and the EYM loss bound framework are truly original BNN paradigms.
  • Experimental Thoroughness: ⭐⭐⭐⭐ Covers MLP/LSTM/Transformer across multiple OOD metrics; lacks large-scale LLM validation.
  • Writing Quality: ⭐⭐⭐⭐⭐ Theoretical derivations are rigorous and self-contained; Definition-Lemma-Theorem structure is very clear.
  • Value: ⭐⭐⭐⭐ Makes BNNs truly scalable to modern architectures with easy drop-in implementation; however, in-distribution calibration still lags behind Ensembles.