Distribution Transformers: Fast Approximate Bayesian Inference With On-The-Fly Prior Adaptation

Conference: ICML 2026
arXiv: 2502.02463
Code: https://github.com/GWhittle110/distribution-transformers
Area: Scientific Computing / Bayesian Inference / Amortized Inference with Transformers
Keywords: Amortized Bayesian Inference, Prior Adaptation, Gaussian Mixture Models (GMM), Sequential Filtering, Transformer

TL;DR

The Distribution Transformer (DT) explicitly tokenizes the "prior distribution" into a set of Gaussian Mixture Model (GMM) components and injects "observations" via cross-attention into a decoder. It learns an end-to-end mapping of "prior + data \(\to\) posterior." While maintaining the same distribution family as the prior (GMM \(\to\) GMM) to support sequential filtering, DT compresses inference time from minutes to milliseconds and allows arbitrary prior modification at test time without retraining.

Background & Motivation

Background: Amortized Bayesian Inference (ABI) shifts the expensive step of solving for the posterior on every new dataset into an offline training phase. A model \(z \mapsto q(x|z)\) is learned once, and online inference then requires only a single forward pass. Transformer-based representatives like PFN, TabPFN, and ACE achieve single-forward-pass posterior estimation in small-sample scenarios with performance comparable to SVI/MCMC.

Limitations of Prior Work: (1) These ABI models "hard-code" the prior during training; changing the prior requires retraining or regenerating data. (2) Even for the few methods that support "prior flexibility," the output distribution family (e.g., the Riemann bucket distribution in PFN) is inconsistent with the prior family, so the output posterior cannot be fed back as the next-round prior, which makes sequential filtering (e.g., Kalman or Particle Filter scenarios) impossible. (3) Classical sequential methods (EKF, PF) are flexible but either rely on strong Gaussian assumptions or have computational costs that explode with the number of particles, and they do not support cross-task amortization.

Key Challenge: Amortization, prior flexibility, and conjugacy (prior and posterior belonging to the same family) must be satisfied simultaneously for sequential Bayesian filtering—previous works inevitably sacrifice at least one.

Goal: (i) Achieve single-forward posterior estimation (Amortization); (ii) Allow arbitrary prior swapping at test time without retraining (Prior Amortization); (iii) Ensure the prior and posterior belong to the same GMM family for recursive filtering; (iv) Match or exceed PFN/TabPFN/ACE on static benchmarks while catching up to Particle Filters on sequential tasks with a \(10^1\times\) to \(10^3\times\) speedup.

Key Insight: Utilize a "universal approximator distribution family" and operate within this family using a Transformer. The authors select Gaussian Mixture Models (GMM): any compactly supported smooth density can be approximated to arbitrary precision by a \(k\)-component GMM, provided \(k\) is large enough. The GMM parameters \(\{(w_i,\boldsymbol{\mu}_i,\boldsymbol{\Sigma}_i)\}_{i=1}^{k}\) naturally form an unordered set of tokens, which matches the permutation equivariance of a Transformer without positional encodings.

Core Idea: Rewrite Bayesian inference as a GMM-sequence to GMM-sequence mapping implemented by a Transformer decoder. Both the prior and observations are embedded as tokens, and the output returns to the GMM family—this homomorphism is the key to sequential filtering.

Method

Overall Architecture

The framework strings together four modules: Prior Embedding \(\to\) Observation Embedding \(\to\) Transformer Decoder \(\to\) GMM De-embedding. Given prior parameters \(\phi\), a learnable embedding network maps them to an unordered sequence of \(k\) tokens (the GMM representation in latent space). Given observations \(z\) (datasets, sensor readings, or query points), customized encoder networks produce another set of tokens. The Transformer decoder (without positional encodings to preserve permutation equivariance) performs self-attention among prior tokens and global cross-attention with observation tokens to output the posterior token sequence. Finally, a component-wise learnable de-embedding maps each token to \((\text{logit}, \boldsymbol{\mu}_i, \boldsymbol{\Sigma}_i)\). A cross-token softmax yields weights \(w_i\), assembling the GMM posterior \(q_\theta(x|z,\phi) = \sum_i w_i \mathcal{N}(x;\boldsymbol{\mu}_i,\boldsymbol{\Sigma}_i)\). Optionally, a sample-space transform \(f(\cdot)\) is introduced to modify support (e.g., log-warping for Inverse-Gamma priors with positive support).
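
The de-embedding and mixture assembly described above can be sketched in NumPy. This is a minimal illustration, not the authors' implementation: the linear de-embedding (`W`, `b`), the Cholesky parameterization of \(\boldsymbol{\Sigma}_i\), the softplus on the diagonal, and all shapes are assumptions made for the example. Only the de-embedding and density evaluation are modeled; the Transformer decoder itself is not.

```python
import numpy as np

def softplus(a):
    # Numerically stable log(1 + exp(a)).
    return np.logaddexp(0.0, a)

def deembed(tokens, W, b, d):
    """Map k tokens (k, d_model) to GMM parameters (w, mu, L), Sigma_i = L_i L_i^T.

    Hypothetical component-wise linear de-embedding; the same map is applied
    to every token, so permuting tokens only permutes components.
    """
    k = tokens.shape[0]
    out = tokens @ W + b                      # (k, 1 + d + d*d)
    logits = out[:, 0]
    mu = out[:, 1:1 + d]
    raw = out[:, 1 + d:].reshape(k, d, d)
    L = np.tril(raw, -1)                      # strictly lower-triangular part
    idx = np.arange(d)
    L[:, idx, idx] = softplus(raw[:, idx, idx]) + 1e-6  # positive diagonal
    w = np.exp(logits - logits.max())         # cross-token softmax
    w /= w.sum()
    return w, mu, L

def gmm_logpdf(x, w, mu, L):
    """log sum_i w_i N(x; mu_i, L_i L_i^T) for a single point x of shape (d,)."""
    d = x.shape[0]
    comp = np.empty(len(w))
    for i in range(len(w)):
        y = np.linalg.solve(L[i], x - mu[i])  # whitened residual
        logdet = 2.0 * np.sum(np.log(np.diag(L[i])))
        comp[i] = -0.5 * (d * np.log(2.0 * np.pi) + logdet + y @ y)
    a = np.log(w) + comp
    return a.max() + np.log(np.sum(np.exp(a - a.max())))  # log-sum-exp

rng = np.random.default_rng(0)
k, d_model, d = 5, 16, 2                      # illustrative sizes
tokens = rng.normal(size=(k, d_model))        # stand-in for decoder output
W = 0.1 * rng.normal(size=(d_model, 1 + d + d * d))
b = np.zeros(1 + d + d * d)

w, mu, L = deembed(tokens, W, b, d)
x = np.array([0.3, -0.2])
logp = gmm_logpdf(x, w, mu, L)

# Reordering the tokens leaves the mixture density unchanged.
perm = rng.permutation(k)
logp_perm = gmm_logpdf(x, *deembed(tokens[perm], W, b, d))
```

Because the de-embedding acts on each token independently and the mixture is a sum over components, the density does not depend on token order, which is the property the architecture relies on when it drops positional encodings.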

Key Designs

  1. GMM-as-token Representation + Transformer Decoder:

    • Function: Represents "distributions" as token sequences digestible by Transformers and ensures input/output consistency within the GMM family.
    • Mechanism: Prior parameters \(\phi\) are embedded into \(k\) tokens; observations are embedded according to their source and concatenated as the context sequence. The Transformer decoder uses cross-attention between prior and context tokens. Positional encodings are omitted so that the decoder respects the permutation invariance of the GMM components.
    • Design Motivation: GMMs match the set-to-set nature of Transformers. Homomorphism (matching input/output families) means the posterior at \(t\) can serve as the prior at \(t+1\), which is the algebraic prerequisite for sequential filtering—a feat PFN's Riemann bucket distribution cannot achieve.
  2. Meta-prior + KL-based Dual Training Objective:

    • Function: Enables the model to see a family of priors during training to support test-time prior swapping.
    • Mechanism: A "distribution over priors"—the meta-prior \(p(\phi)\)—is introduced. The joint distribution is \(p(\phi,x,z) = p(\phi)p(x|\phi)p(z|x)\). During training, \(\phi_i \sim p(\phi)\) is sampled per batch, followed by \(x_i \sim p(x|\phi_i)\) and \(z_i \sim p(z|x_i)\). The main loss is \(\ell_\theta = \mathbb{E}_{p(\phi,x,z)}[-\log q_\theta(f(x)|z,\phi)]\). Prop 3.1 proves this is equivalent to \(\mathbb{E}_{p(\phi,z)}[\mathrm{KL}(p(\cdot|z,\phi) \,\|\, q_\theta(\cdot|z,\phi))]\) plus a constant.
    • Design Motivation: Treating the prior as a random variable in the joint distribution amortizes the mapping over the larger space \(\Phi \times \mathcal{Z} \to \mathcal{Q}\). The KL form ensures the model directly approximates the true posterior.
  3. Prior Consistency Regularization (prior loss) for Latent Conjugacy:

    • Function: Applies de-embedding to "prior tokens" to get a GMM approximation of the prior \(q_\theta(x|\phi)\), forcing the prior and posterior to share the same latent representation.
    • Mechanism: Defined as \(\ell_\theta^{\mathrm{prior}} = \mathbb{E}_{p(\phi,x)}[-\log q_\theta(x|\phi)]\), with the combined loss \(\ell_\theta' = \ell_\theta^{\mathrm{prior}} + \ell_\theta\). Prior tokens are decoded before the Transformer, while posterior tokens are decoded after. Both must approximate their respective GMM distributions using the same de-embedding.
    • Design Motivation: Prior/posterior conjugacy requires them to be decoded via the same de-embedding. Without this, prior and posterior tokens might drift into different latent regions, causing sequential recursion to fail numerically.
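
The equivalence attributed to Prop 3.1 above can be seen with a short calculation. The following is a sketch under simplifying assumptions (identity transform \(f\), finite posterior entropy), and the paper's own proof may proceed differently. Writing the joint as \(p(\phi,x,z) = p(\phi,z)\,p(x|z,\phi)\):

\[
\begin{aligned}
\ell_\theta
&= \mathbb{E}_{p(\phi,x,z)}\big[-\log q_\theta(x|z,\phi)\big] \\
&= \mathbb{E}_{p(\phi,z)}\,\mathbb{E}_{p(x|z,\phi)}\big[-\log q_\theta(x|z,\phi)\big] \\
&= \mathbb{E}_{p(\phi,z)}\Big[\mathbb{E}_{p(x|z,\phi)}\Big[\log\frac{p(x|z,\phi)}{q_\theta(x|z,\phi)}\Big] - \mathbb{E}_{p(x|z,\phi)}\big[\log p(x|z,\phi)\big]\Big] \\
&= \mathbb{E}_{p(\phi,z)}\big[\mathrm{KL}\big(p(\cdot|z,\phi)\,\|\,q_\theta(\cdot|z,\phi)\big)\big] + \underbrace{\mathbb{E}_{p(\phi,z)}\big[\mathrm{H}\big(p(\cdot|z,\phi)\big)\big]}_{\text{independent of } \theta}
\end{aligned}
\]

The same steps applied to \(\ell_\theta^{\mathrm{prior}}\) give \(\mathbb{E}_{p(\phi)}[\mathrm{KL}(p(\cdot|\phi)\,\|\,q_\theta(\cdot|\phi))]\) plus a constant, so minimizing the combined loss \(\ell_\theta'\) pushes both the decoded prior and the decoded posterior toward their targets.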

Key Experimental Results

Main Results

Experiment 4.1: Comparison against the analytic conjugate posterior for an Inverse-Gamma prior on the variance of a Normal likelihood, using narrow and wide meta-priors across 1,000 unseen problems.

| Method | KL (Narrow Meta-prior) | KL (Wide Meta-prior) | Inference Time per 1k Problems (s) |
|---|---|---|---|
| SVI | 0.0425 ± 0.0003 | 0.0558 ± 0.0016 | 148 |
| PFN-15 | 0.517 ± 1.009* | 331.5 ± 646.6* | 0.003 |
| PFN-5000 | 0.0038 ± 0.0789 | 0.2935 ± 0.0237 | 0.003 |
| TabPFNv2 | 0.0112 ± 0.0013 | 0.1513 ± 0.0168 | 1.52 |
| ACE-5 | 0.0094 ± 0.0000 | 0.0048 ± 0.0014 | 0.037 |
| DT-2 | 0.0044 ± 0.0001 | 0.0058 ± 0.0002 | 0.014 |
| DT-5 | 0.0004 ± 0.0000 | 0.0003 ± 0.0000 | 0.016 |

DT-5 achieves KL divergence nearly an order of magnitude lower than PFN-5000 for narrow meta-priors and 3 orders of magnitude lower for wide meta-priors. Inference is \(\approx 10^4\) times faster than SVI.
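
To make the conjugate setup concrete, the sketch below draws one training triple \((\phi, x, z)\) following the hierarchy \(\phi \sim p(\phi)\), \(x \sim p(x|\phi)\), \(z \sim p(z|x)\) and computes the closed-form posterior that serves as ground truth. The uniform meta-prior ranges, the known zero mean, and the number of observations are assumptions chosen for illustration and are not taken from the paper.

```python
import math
import numpy as np

rng = np.random.default_rng(0)

def sample_problem(n_obs=10):
    """Draw (phi, x, z): prior parameters, latent variance, observations."""
    # Hypothetical meta-prior p(phi): uniform ranges over shape and scale.
    alpha = rng.uniform(2.0, 5.0)
    beta = rng.uniform(1.0, 4.0)
    # x = sigma^2 ~ Inverse-Gamma(alpha, beta), drawn as 1 / Gamma(alpha, rate=beta).
    x = 1.0 / rng.gamma(shape=alpha, scale=1.0 / beta)
    # z_i ~ Normal(0, x); the known zero mean is an assumption of this sketch.
    z = rng.normal(0.0, math.sqrt(x), size=n_obs)
    return (alpha, beta), x, z

def conjugate_posterior(alpha, beta, z):
    """Inverse-Gamma posterior for the variance of a zero-mean Normal."""
    return alpha + 0.5 * len(z), beta + 0.5 * float(np.sum(z ** 2))

def ig_logpdf(x, alpha, beta):
    """Log-density of Inverse-Gamma(alpha, beta) at x > 0."""
    return (alpha * math.log(beta) - math.lgamma(alpha)
            - (alpha + 1.0) * math.log(x) - beta / x)

def normal_loglik(z, var):
    """Log-likelihood of zero-mean Normal observations with variance var."""
    return float(np.sum(-0.5 * np.log(2.0 * np.pi * var) - 0.5 * z ** 2 / var))

(alpha, beta), x, z = sample_problem()
alpha_post, beta_post = conjugate_posterior(alpha, beta, z)

# Sanity check: log posterior minus (log prior + log likelihood) should be
# the same constant (the negative log evidence) at every variance value.
grid = [0.3, 1.0, 2.5]
diffs = [ig_logpdf(v, alpha_post, beta_post)
         - ig_logpdf(v, alpha, beta)
         - normal_loglik(z, v) for v in grid]
```

An amortized model would be trained on many such triples and evaluated by the KL divergence between its output and the Inverse-Gamma posterior returned by `conjugate_posterior`.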

Experiment 4.2.1 (5D GP Predictive Posterior + Hyper-posterior): DT outperforms PFN/TabPFNv2/ACE on both PPD NLL (0.81) and hyper-posterior NLL (0.31), with the fastest time (9.5 s).

Experiment 4.3.1 (4D State-Space Bayesian Sensor Fusion):

| Method | Expected NLL | Per-step Time, Batch of 100 Sequences (s) |
|---|---|---|
| EKF | 95.9 ± 4.40 | 0.010 |
| Particle Filter | -0.244 ± 0.047 | 0.818 |
| DT-4 | -0.197 ± 0.040 | 0.017 |

DT matches the "quasi-ground truth" PF while being 50× faster. EKF fails completely due to linearization assumptions.
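
The recursion that makes this possible (the posterior at step \(t\) becomes the prior at step \(t+1\), with the number of components fixed) can be illustrated with a stand-in update. The sketch below does not use a Distribution Transformer: it replaces the learned update with the exact closed-form posterior of a 1-D GMM prior under a linear-Gaussian observation \(z = x + \varepsilon\), \(\varepsilon \sim \mathcal{N}(0, r^2)\), which is also GMM-to-GMM. The model, noise levels, and initial mixture are assumptions chosen for the example.

```python
import numpy as np

def gmm_update(w, mu, var, z, r2):
    """Exact posterior of a 1-D GMM prior given z = x + N(0, r2) noise."""
    s = var + r2                               # predictive variance per component
    gain = var / s                             # per-component Kalman gain
    log_ev = -0.5 * (np.log(2.0 * np.pi * s) + (z - mu) ** 2 / s)
    logw = np.log(w) + log_ev                  # reweight by component evidence
    w_new = np.exp(logw - logw.max())
    w_new /= w_new.sum()
    return w_new, mu + gain * (z - mu), (1.0 - gain) * var

def run_filter(w, mu, var, observations, r2, q=0.0):
    """Recursive filtering: each posterior is fed back as the next prior."""
    for z in observations:
        var = var + q                          # predict step (random-walk dynamics)
        w, mu, var = gmm_update(w, mu, var, z, r2)
    return w, mu, var

rng = np.random.default_rng(0)
x_true, r2 = 2.0, 0.25                         # illustrative state and noise variance
obs = x_true + rng.normal(0.0, np.sqrt(r2), size=50)

w0 = np.array([0.5, 0.5])                      # bimodal initial prior
mu0 = np.array([-3.0, 3.0])
var0 = np.array([1.0, 1.0])

w, mu, var = run_filter(w0, mu0, var0, obs, r2)
post_mean = float(np.sum(w * mu))
```

Each step costs the same regardless of how many observations came before, because the whole history is summarized by the current \(k\)-component mixture. In DT the role of `gmm_update` is played by the learned prior-plus-observation map, which is what allows non-linear, non-Gaussian models where no closed form exists.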

Ablation Study

| Dimension / Method | Observation | Implication |
|---|---|---|
| GMM components \(k = 2\) vs \(5\) | KL drops from 0.0044 to 0.0004 | Component count acts as an "expressivity knob" decoupled from model parameters. |
| Riemann Output (PFN) vs GMM (DT/ACE) | Riemann KL spikes to 331 under wide meta-prior | Limited expressivity of bucketed distributions is a bottleneck for PFN. |
| With vs Without prior loss | Performance gain is minor, but sequential recursion requires it | Latent space conjugacy is the algebraic prerequisite for sequential capabilities. |
| Sequential PFN (concatenating obs) | Inference time grows linearly or \(\mathcal{O}(T^2)\) | DT's constant-time recurrence is a major engineering advantage. |
| Exp 4.3.2 (10D Stochastic Volatility) | PF requires \(10^3\times\) more compute to match DT | DT excels in high-dimensional sparse information scenarios. |

Key Findings

  • Homomorphism is the key to sequential capability: GMM \(\to\) GMM mapping means post-inference posteriors can serve immediately as priors, decoupling single-step time from sequence length \(T\).
  • GMM expressivity has a high ceiling: Compared to Riemann bucketed distributions, a 5-component GMM is visually indistinguishable from the true posterior in conjugate experiments.
  • Wider meta-priors highlight prior flexibility: PFN-5000 degrades severely under wide meta-priors, while DT remains robust.
  • Prior loss is functional, not just performance-driven: Removing it has little effect on static KL but breaks the latent conjugacy, causing sequential filtering to fail.

Highlights & Insights

  • "Distribution as Input" is an undervalued design degree of freedom: Unlike previous ABI methods that treat priors as constants or hyperparameters, DT tokenizes prior parameters \(\phi\). This allows "prior swapping" to be generalized to Bayesian Optimization, ABC, or sensor fusion.
  • Architectural Symmetry \(\leftrightarrow\) Probabilistic Symmetry: The use of a Transformer without positional encodings matches the unordered nature of GMM components, and cross-attention matches the conditional independence of observations. This correspondence between probabilistic and neural structures is a powerful paradigm for scientific ML.
  • From learning posteriors to learning operators: DT learns the operator "prior + data \(\to\) posterior." It moves amortization to a higher level of abstraction (cross-task + cross-prior-family).
  • Stackable Real-time Bayesian Filtering: DT enables non-Gaussian, non-linear SSMs at millisecond throughput, matching PF accuracy. This is highly valuable for autonomous sensing and industrial control.

Limitations & Future Work

  • Training cost scales with prior space dimension: Covering a wider \(\Phi\) requires significantly more offline samples and time.
  • Meta-prior reliance: Performance may decay if the test prior falls far outside the meta-prior range.
  • High-dimensional GMM bottlenecks: Self-attention is quadratic in component count, and full-covariance decoding is quadratic in latent dimension. High dimensions require sparse or low-rank covariances.
  • Error accumulation in long sequences: Approximation errors may drift over very long sequences, though they remain controlled in medium-depth tests.

Comparison with Related Work

  • vs PFN / TabPFN / TabPFNv2: PFN uses fixed priors and Riemann outputs. DT tokenizes the prior, uses GMM, and supports sequential filtering, which is a qualitative advance.
  • vs ACE: ACE supports prior flexibility and GMM output. DT's primary distinction is its flexible embedding design and explicit homomorphism via prior loss, enabling sequential applications.
  • vs Classical Kalman / Particle Filters: EKF fails on non-linear observations; PF suffers from the curse of dimensionality. DT addresses both via non-linear expressivity and amortized constant throughput.
  • vs Variational Inference / Neural Processes: VI requires per-problem optimization. NPs typically amortize predictive distributions in data space rather than latent posterior space. DT amortizes, outputs latent posteriors, and allows prior swapping.

Rating

  • Novelty: ⭐⭐⭐⭐⭐ Tokenizing the distribution and enforcing prior-posterior homomorphism for sequential filtering are significant qualitative breakthroughs in ABI.
  • Experimental Thoroughness: ⭐⭐⭐⭐ Broad coverage (conjugate priors, GP, sensor fusion). Lacks a real-world robotics or autonomous driving end-to-end demo.
  • Writing Quality: ⭐⭐⭐⭐ Clear chain from motivation to theory, though the "functional" role of the prior loss could be further elaborated for intuition.
  • Value: ⭐⭐⭐⭐⭐ Achieving millisecond throughput, arbitrary prior swapping, and sequential cascading simultaneously is a major step for real-time Bayesian applications.