Discrete Diffusion Samplers and Bridges: Off-Policy Algorithms and Applications in Latent Spaces
Conference: ICML 2026
arXiv: 2602.05961
Code: https://github.com/mmacosha/offpolicy-discrete-diffusion-samplers-and-bridges
Area: image_generation / diffusion sampling / discrete diffusion / amortized sampling / off-policy RL
Keywords: discrete diffusion samplers, off-policy RL, trajectory balance, Schrödinger bridge, VQ-VAE posterior sampling
TL;DR
This paper is the first to systematically port the off-policy RL training techniques that are well established for continuous-space diffusion samplers (replay buffers, importance weighting, MCMC exploration) to discrete diffusion samplers. It further extends these methods to data-to-energy discrete Schrödinger bridges, significantly mitigating mode collapse on multimodal distributions such as Ising/Potts models and discretized GMMs. Finally, it applies the approach to data-free conditional image generation (posterior sampling) within the discrete latent space of VQ-VAEs.
Background & Motivation
Background: Sampling from unnormalized energies \(p(x) \propto e^{-E(x)}\) has long been dominated by MCMC, AIS, and SMC. In recent years, diffusion samplers have emerged as an amortized sampling approach: using a learned diffusion process to drive a simple prior \(p_0\) toward a target \(p_{\text{target}}\). Training objectives typically involve path-space divergences (trajectory balance, log-variance, etc.). In continuous spaces, off-policy RL techniques (replay buffer, importance weighting, MCMC exploration) have been shown to significantly improve mode coverage (Sendera et al. 2024; Choi et al. 2026).
Limitations of Prior Work: The discrete counterparts (masked / uniform discrete diffusion samplers, e.g., MDNS, Zhu et al. 2025; Sanokowski et al. 2025a) have lagged behind. They rely almost exclusively on on-policy training (computing losses directly from trajectories sampled by the current model), which is highly susceptible to mode collapse on strongly multimodal targets such as low-temperature Ising/Potts models or high-dimensional GMMs, and can end up covering only a single mode.
Key Challenge: On-policy training only observes trajectories generated by the model itself; once the model biases toward a specific mode, it reinforces that bias. Breaking this feedback loop requires training data to "exceed the current policy coverage"—precisely the role of off-policy techniques in continuous diffusion samplers, which have not yet been cleanly introduced to the discrete setting. Furthermore, while GFlowNet research (Zhang et al. 2022a) has long performed off-policy training equivalent to discrete diffusion sampling, this line of work has remained disconnected from recent masked diffusion samplers.
Goal: (i) Cleanly migrate TB / LV objectives along with replay buffer, importance weighting, and MCMC exploration to discrete diffusion samplers under a unified second-moment loss framework; (ii) extend this methodology to data-to-energy discrete Schrödinger bridges (where one end provides only energy and the other only samples); (iii) establish a new application—data-free posterior sampling for generative models in discrete latent spaces (VQ-VAE).
Key Insight: The authors observe that the unified theory for continuous space (Berner et al. 2026) can be translated to the discrete case. Trajectory balance conveniently absorbs the unknown normalization constant into a learnable scalar \(c\), implying that the off-policy concepts from GFlowNet can be seamlessly applied to the mask/uniform kernels of discrete diffusion samplers. The technical bridge consists of replacing the path-space loss of continuous SDEs with the trajectory ratio loss of discrete Markov chains.
Core Idea: The only requirement is a unified second-moment loss \(\mathcal{L}_{\mathcal{P}} = \mathbb{E}_{\mathcal{P}}\big[(\log \tfrac{p_0 \otimes \overrightarrow{p}_\theta^{\otimes N}}{p_{\text{target}} \otimes \overleftarrow{p}^{\otimes N}} - c)^2\big]\). Off-policy learning is then realized through carefully designed trajectory distributions \(\mathcal{P}\) (buffer / importance weighting / MCMC refinement), porting the framework from continuous diffusion samplers to discrete diffusion samplers and Schrödinger bridges.
Method
Overall Architecture
Input: An energy function \(E: \mathcal{S} \to \mathbb{R}\), where \(\mathcal{S} = \{1, \dots, C\}^d\) is a discrete sequence space; the objective is to sample from \(p_{\text{target}}(x) \propto e^{-E(x)}\). The model consists of a forward (denoising) kernel \(\overrightarrow{p}_\theta(X_{n+1} \mid X_n)\) paired with a pre-selected backward (noising) kernel \(\overleftarrow{p}\) (masking or uniform discrete diffusion). Starting from \(p_0\), an \(N\)-step Markov chain arrives at \(X_N \in \mathcal{S}\), with the goal that the marginal distribution of \(X_N\) matches \(p_{\text{target}}\).
Training involves sampling trajectories \(X_{0:N}\) using a distribution \(\mathcal{P}\) and minimizing \(\mathcal{L}_{\mathcal{P}}\) (Eq. 3). When \(c\) is a learnable scalar, this corresponds to trajectory balance (TB); when \(c\) is the empirical batch mean, it is log-variance (LV). The choice of \(\mathcal{P}\) defines the on-policy versus off-policy boundary.
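The TB/LV distinction above can be made concrete with a minimal sketch. It assumes per-step log-probabilities are already available for each trajectory; the function names `trajectory_log_ratio` and `second_moment_loss` and all numbers are illustrative and not taken from the paper or its code. The unnormalized target \(e^{-E}\) is used in place of \(p_{\text{target}}\), so the scalar \(c\) also has to absorb the log-normalizer.

```python
def trajectory_log_ratio(log_p0, log_fwd_steps, energy_xN, log_bwd_steps):
    """Log-ratio of forward path probability to (unnormalized) backward path probability.

    log_p0        : log p_0(X_0)
    log_fwd_steps : list of log p_theta(X_{n+1} | X_n), one entry per step
    energy_xN     : E(X_N); the unnormalized target is exp(-E)
    log_bwd_steps : list of log backward-kernel probabilities, one entry per step
    """
    log_forward = log_p0 + sum(log_fwd_steps)
    log_backward = -energy_xN + sum(log_bwd_steps)
    return log_forward - log_backward


def second_moment_loss(deltas, c=None):
    """Mean of (delta - c)^2 over a batch of trajectory log-ratios.

    c given  -> TB-style loss with an explicit scalar (learned in practice).
    c = None -> LV-style loss, where c is the empirical batch mean.
    """
    if c is None:
        c = sum(deltas) / len(deltas)
    return sum((d - c) ** 2 for d in deltas) / len(deltas)


# Two toy trajectories with made-up log-probabilities (assumed values).
deltas = [
    trajectory_log_ratio(-1.0, [-0.5, -0.5], 2.0, [-1.0, -1.0]),
    trajectory_log_ratio(-1.0, [-1.0, -1.0], 1.0, [-0.5, -0.5]),
]
tb_loss = second_moment_loss(deltas, c=0.0)  # TB with the scalar fixed at 0
lv_loss = second_moment_loss(deltas)         # LV: scalar replaced by the batch mean
```

Nothing in the loss depends on where the trajectories came from, which is what lets \(\mathcal{P}\) be swapped for buffer or MCMC-refined samples.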
Bridge setup (§3): The fixed \(p_0\) is replaced by an arbitrary distribution, and \(\overleftarrow{p}_\varphi\) is also parameterized. IPF iterations (Eq. 6a/6b) are used to alternately fit \(\overrightarrow{\mathcal{P}}\) and \(\overleftarrow{\mathcal{P}}\). For the data-to-energy case, (6b) is trained using the LV variant described above.
Application (§4): In the discrete latent space \(z \in \{1, \dots, 8\}^{16}\) of a VQ-VAE, given a pretrained autoregressive prior \(p_{\text{latent}}(z)\), a deterministic decoder \(f\), and a categorical likelihood \(p(y \mid f(z))\), posterior sampling \(p(z \mid y) \propto p_{\text{latent}}(z) \cdot p(y \mid f(z))\) is treated as a new discrete energy sampling problem and trained with the same sampler.
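A minimal sketch of how the posterior becomes an energy, \(E(z) = -\log p_{\text{latent}}(z) - \log p(y \mid f(z))\), is given below. The prior, decoder, and likelihood are toy stand-ins (a uniform prior, a sum "decoder", and a parity likelihood), assumed purely for illustration; the paper's actual components are a pretrained autoregressive prior, a VQ-VAE decoder, and a categorical likelihood. Codebook indices run from 0 to 7 here rather than 1 to 8.

```python
import math

NUM_CODES, LATENT_DIM = 8, 16  # sizes quoted in the text


def posterior_energy(z, y, log_p_latent, decode, log_likelihood):
    """E(z) such that p(z | y) is proportional to exp(-E(z))."""
    return -(log_p_latent(z) + log_likelihood(y, decode(z)))


# --- Hypothetical stand-ins, NOT the paper's models ---
def toy_log_prior(z):
    # Uniform prior over all NUM_CODES ** len(z) latent codes.
    return -len(z) * math.log(NUM_CODES)


def toy_decode(z):
    # Placeholder "decoder": collapses the latent code to one integer.
    return sum(z)


def toy_log_likelihood(y, x):
    # Placeholder parity likelihood: 0.9 if the label matches, else 0.1.
    parity = "even" if x % 2 == 0 else "odd"
    return math.log(0.9 if parity == y else 0.1)


z_even = [0] * LATENT_DIM
z_odd = [1] + [0] * (LATENT_DIM - 1)
e_even = posterior_energy(z_even, "even", toy_log_prior, toy_decode, toy_log_likelihood)
e_odd = posterior_energy(z_odd, "even", toy_log_prior, toy_decode, toy_log_likelihood)
```

Only forward evaluations of the prior, decoder, and likelihood are needed to compute the energy, so the same sampler training loop applies unchanged.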
Key Designs
- Unified Second-Moment Loss + On/Off-Policy Decoupling:
  - Function: Supports both TB and LV using the same objective \(\mathcal{L}_{\mathcal{P}}\) and allows for an arbitrary trajectory distribution \(\mathcal{P}\).
  - Mechanism: The loss \(\mathcal{L}_{\mathcal{P}} = \mathbb{E}_{\mathcal{P}}[(\log \tfrac{p_0 \otimes \overrightarrow{p}_\theta^{\otimes N}}{p_{\text{target}} \otimes \overleftarrow{p}^{\otimes N}} - c)^2]\) maintains its unique minimum at \(p_0 \otimes \overrightarrow{p}_\theta^{\otimes N} = p_{\text{target}} \otimes \overleftarrow{p}^{\otimes N}\) as long as \(\mathcal{P}\) has full support. Thus, \(\mathcal{P}\) can be freely substituted for exploration without altering the optimal solution; when \(\mathcal{P} = p_0 \otimes \overrightarrow{p}_\theta^{\otimes N}\), the gradient aligns with reverse KL, corresponding to traditional on-policy training.
  - Design Motivation: On-policy training suffers from high variance and can be locked into early modes. Formulating the loss as the "squared trajectory ratio" allows \(\mathcal{P}\) to function like a behavior policy in RL, which can be replaced by buffer samples or MCMC-refined trajectories to introduce exploration.
- Importance-Weighted Replay Buffer:
  - Function: Stores terminal states \(X_N\) from previous rollouts in a buffer and resamples training trajectories based on their "target importance."
  - Mechanism: Each buffer element is stored with an importance weight \(w = e^{-E(X_N)} \prod_n \overleftarrow{p}(X_n \mid X_{n+1}) / [p_0(X_0) \prod_n \overrightarrow{p}_\theta(X_{n+1} \mid X_n)]\) (Eq. 4). During training, terminal states are drawn with probability proportional to \(w\), and full trajectories are unrolled backward via \(\overleftarrow{p}^{\otimes N}\) to compute the loss.
  - Design Motivation: A simple uniform buffer stabilizes rapidly changing policies; importance weighting further biases training towards samples with high target probability but low model probability. These are the high-value regions that on-policy training fails to reach, which makes the weighting crucial for mode coverage.
- MCMC Exploration Refined Buffer:
  - Function: Iteratively refines buffer samples using an MCMC kernel (e.g., Metropolis-Hastings, Swendsen-Wang) with \(p_{\text{target}}\) as the stationary distribution before training.
  - Mechanism: MCMC only requires the energy function \(E\) and not the model, incurring negligible GPU overhead. For structured energies like Ising/Potts, proposals like Swendsen-Wang can be used to jump between modes. Algorithm 1 integrates "model rollout → buffer entry → MCMC refinement → weighted sampling" into a unified training loop.
  - Design Motivation: Buffers can only store modes the model has already discovered, whereas MCMC can move to unseen modes guided only by evaluations of the energy (the state space is discrete, so no energy gradients are involved). Their combination keeps training data close to the true target and breaks mode collapse in low-temperature/strongly multimodal settings.
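The buffer and MCMC components above can be sketched together on a small Ising lattice. This is a simplified illustration, not the paper's Algorithm 1: the log-weights are passed in directly rather than computed from Eq. 4, the MCMC kernel is plain single-spin-flip Metropolis-Hastings (not Swendsen-Wang), and the inverse temperature is assumed to be folded into \(E\). All function names are hypothetical.

```python
import math
import random


def ising_energy(x, L, beta):
    """E(x) = -beta * sum of x_i * x_j over nearest-neighbour bonds on an L x L torus.

    x is a flat list of +1/-1 spins; each bond is counted once (right and down).
    """
    e = 0.0
    for i in range(L):
        for j in range(L):
            s = x[i * L + j]
            right = x[i * L + (j + 1) % L]
            down = x[((i + 1) % L) * L + j]
            e -= beta * s * (right + down)
    return e


def resample(buffer, log_weights, k, rng):
    """Draw k buffer entries with probability proportional to exp(log_weight)."""
    m = max(log_weights)  # subtract the max for numerical stability
    weights = [math.exp(lw - m) for lw in log_weights]
    return rng.choices(buffer, weights=weights, k=k)


def mh_refine(x, L, beta, steps, rng):
    """Single-spin-flip Metropolis-Hastings with exp(-E) as stationary distribution."""
    x = list(x)  # work on a copy so the buffer entry is left untouched
    e = ising_energy(x, L, beta)
    for _ in range(steps):
        i = rng.randrange(L * L)
        x[i] = -x[i]  # propose flipping one spin
        e_new = ising_energy(x, L, beta)
        if rng.random() < math.exp(min(0.0, e - e_new)):
            e = e_new      # accept the flip
        else:
            x[i] = -x[i]   # reject: undo the flip
    return x


rng = random.Random(0)
L, beta = 4, 0.6
buffer = [[1] * (L * L), [rng.choice((-1, 1)) for _ in range(L * L)]]
log_weights = [0.0, -1000.0]  # made-up values; the second entry is effectively never drawn
picked = resample(buffer, log_weights, k=3, rng=rng)
refined = [mh_refine(x, L, beta, steps=50, rng=rng) for x in picked]
```

The refinement step touches only `ising_energy`, never the sampler network, which is the sense in which the MCMC overhead stays small. Recomputing the full energy per flip is done here for clarity; a local energy difference would be the usual optimization.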
Loss & Training
The primary objective uses trajectory balance: \(\mathcal{L}_{\text{TB}} = \mathbb{E}_{\mathcal{P}}[(\log \tfrac{p_0 \overrightarrow{p}_\theta^{\otimes N}}{p_{\text{target}} \overleftarrow{p}^{\otimes N}} - \log Z_\phi)^2]\), where \(\log Z_\phi\) is the learnable scalar (the \(c\) of the unified loss) that absorbs the unknown normalization term. The log-variance version replaces \(c\) with the empirical batch mean, eliminating the extra parameter. In bridge settings, IPF alternately updates the forward and backward kernels, with the data-to-energy step using the LV variant. The time discretization may differ between training and inference: several tokens can be unmasked per step during training, while inference can switch to "unmask one token per step" to balance memory and quality (§A.1). Target temperature annealing is enabled by default for fair comparison.
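As a short worked step (standard algebra with notation introduced here, not taken from the paper), write \(\delta_i\) for the trajectory log-ratio of the \(i\)-th sample in a batch of size \(B\), and \(\bar\delta = \tfrac{1}{B}\sum_i \delta_i\) for the batch mean. For any scalar \(c\), the batch loss splits as

\[
\frac{1}{B}\sum_{i=1}^{B} (\delta_i - c)^2
= \frac{1}{B}\sum_{i=1}^{B} (\delta_i - \bar\delta)^2 + (\bar\delta - c)^2 .
\]

The first term is the LV loss and does not depend on \(c\); the second term vanishes at \(c = \bar\delta\). In this sense LV can be read as TB with the scalar solved in closed form on each batch, which is why it needs no extra parameter, and on a fixed batch the TB value is never smaller than the LV value.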
Key Experimental Results
Main Results
Ising / Potts Models (\(16 \times 16\) toroidal lattice, averaged over 5 runs; MDNS = SOTA on-policy discrete diffusion sampler):
| Setting | Method | ELBO ↑ | EUBO ↓ | Sinkhorn ↓ | Magnetisation err ↓ |
|---|---|---|---|---|---|
| Ising \(\beta=0.6\) | MDNS (WDCE, on-policy) | 310.18 | 341.82 | 48.71 | 0.41 |
| Ising \(\beta=0.6\) | LV (on-policy) | 309.77 | 422.53 | 116.96 | 0.97 (collapse) |
| Ising \(\beta=0.6\) | TB + Buffer | 310.42 | 310.56 | 3.59 | 0.04 |
| Ising \(\beta=0.6\) | TB + Buffer + MCMC | 310.43 | 310.55 | 3.47 | 0.02 |
| Ising \(\beta=1.2\) | MDNS / on-policy | 614.42 | >1100 | ~127 | 1.00 (severe collapse) |
| Ising \(\beta=1.2\) | TB + Buffer + MCMC | 615.03 | 615.14 | 0.02 | 0.03 |
| Potts \(q=3, \beta=1.2\) | MDNS | 620.23 | 680.52 | 99.95 | 0.58 |
| Potts \(q=3, \beta=1.2\) | TB + Buffer + MCMC | 620.73 | 621.30 | 12.37 | 0.03 |
Discretized Synthetic Density (Gray-coded 8-bit per dimension, averaged over 5 runs):
| Target | Method | ELBO ↑ | MMD ↓ | Sinkhorn ↓ |
|---|---|---|---|---|
| 40GMM (\(d=32\)) | MDNS | -16.66 | 0.17 | 349.31 |
| 40GMM (\(d=32\)) | TB (on-policy) | -2.47 | 0.40 | 2142.65 (collapse) |
| 40GMM (\(d=32\)) | TB + Buffer | -5.97 | 0.07 | 114.11 |
| 40GMM (\(d=32\)) | TB + Buffer + MCMC | -7.13 | 0.04 | 4.25 |
| ManyWell (\(d=80\)) | MDNS | 41.52 | 0.04 | 1.82 |
| ManyWell (\(d=80\)) | TB + Buffer + MCMC | 48.74 | 0.04 | 1.36 |
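The Gray coding mentioned for the discretized synthetic densities can be illustrated with the standard binary-reflected Gray code. The uniform grid over an interval `[lo, hi]` and the helper names are assumptions made for this sketch; the paper's exact discretization may differ.

```python
def gray_encode(n):
    """Binary-reflected Gray code of a non-negative integer."""
    return n ^ (n >> 1)


def gray_decode(g):
    """Inverse of gray_encode."""
    n = 0
    while g:
        n ^= g
        g >>= 1
    return n


def quantize(x, lo, hi, bits=8):
    """Map a real coordinate in [lo, hi] to an integer grid index (uniform grid, assumed)."""
    levels = (1 << bits) - 1
    t = min(max((x - lo) / (hi - lo), 0.0), 1.0)  # clip to the interval
    return round(t * levels)


def to_bits(g, bits=8):
    """Most-significant-bit-first list of 0/1 values."""
    return [(g >> k) & 1 for k in reversed(range(bits))]


# One continuous coordinate becomes 8 binary tokens.
tokens = to_bits(gray_encode(quantize(0.25, -1.0, 1.0)))
```

Neighbouring grid cells differ in exactly one bit under this code, so a single-token change corresponds to a small move in the continuous coordinate, which is a common motivation for preferring it over plain binary.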
VQ-VAE Posterior Sampling (MNIST, 16-D latent, 8-word codebook, likelihood = odd/even/value): Both on-policy and off-policy LV correctly generate images of target categories. Off-policy shows superior diversity. This marks the first integration of discrete diffusion samplers into the latent space of pretrained generative models for data-free conditional generation.
Ablation Study
| Configuration | Key Findings | Explanation |
|---|---|---|
| on-policy (TB / LV) | Sinkhorn is 1–3 orders of magnitude higher than off-policy on \(\beta=1.2\) Ising and 40GMM \(d=32\). | Severe mode collapse. |
| TB + Buffer | Significant improvement at most temperatures, but occasional collapse on Ising \(\beta=1.2\). | Buffer alone is unstable at extreme low temperatures. |
| TB + Buffer + MCMC | Stable across all temperatures; Sinkhorn comparable to or better than long-run MH MCMC. | MCMC exploration is critical in low-temperature regions. |
| LV vs TB | TB slightly outperforms LV at most temperatures, but LV needs no extra learnable scalar. | Consistent with continuous space findings. |
| MCMC Proposals | Swendsen-Wang is significantly faster on structured energies. | §D.1.3 |
| Schrödinger Bridge | Both work for simple bridges; on-policy collapses on difficult bridges (10GMM↔40GMM). | Fig. 4 |
Key Findings
- Among the three off-policy components, MCMC exploration is the only critical factor for stability in low-temperature/strongly multimodal settings—the buffer alone still occasionally fails on Ising \(\beta=1.2\). Once MCMC is integrated, results are stable and comparable to long-run MH MCMC.
- On-policy training consistently collapses to a single mode on all difficult tasks (magnetisation near 1, Sinkhorn 1–3 orders of magnitude higher than off-policy), validating the necessity of migrating off-policy techniques to the discrete setting.
- TB generally performs slightly better than LV when an unknown \(\log Z\) can be modeled; however, LV is more natural for the data-to-energy step in Schrödinger bridges.
- VQ-VAE posterior sampling experiments demonstrate that discrete diffusion samplers can serve as "universal posterior plugins" for pretrained models without needing fine-tuning or backpropagation through the original model.
Highlights & Insights
- The paper clarifies the equivalence between "GFlowNet (Zhang et al. 2022a) == Discrete Diffusion Sampler (MDNS, Zhu et al. 2025)," making years of off-policy experience from the GFlowNet community directly applicable.
- Absorbing the unknown normalization constant \(Z\) into the learnable scalar \(c\) for TB is a clever engineering choice: the second-moment loss can reduce to KL when \(Z\) is known (bridge reference) or learn it directly when \(Z\) is unknown (sampling problem), avoiding bias correction issues in importance reweighting.
- The observation that "MCMC is almost free" is valuable: since MCMC only evaluates energy and not the model, it can be added to any amortized sampler for exploration during training. This is particularly useful when the energy is cheap and the model is expensive (e.g., protein folding, combinatorial optimization).
- The VQ-VAE posterior sampling framework treats controllable generation as a sampling problem in discrete latent spaces, which could eventually be applied to LLM token spaces as an alternative to RL fine-tuning that is entirely data-free.
Limitations & Future Work
- The experiments are conducted primarily on synthetic data (Ising / Potts / discretized GMM) and small-scale VQ-VAE/MNIST. Scaling to large discrete latent spaces (e.g., high-resolution VQ-GAN, large LM token spaces) and verifying MCMC exploration efficiency in high dimensions remains to be done.
- MCMC chains during training are not run until convergence, meaning MCMC acts more as a "beneficial perturbation" than as a "true posterior correction." Its effectiveness on more complex energy topologies is yet to be verified.
- The work relies strictly on trajectory balance / log-variance losses and lacks direct comparison with alternative amortized solutions based on SMC or adaptive importance sampling.
- The stability of data-to-energy IPF Schrödinger bridges in higher dimensions requires more systematic verification beyond the 16-bit Gray-coded experiments.
- Posterior sampling was tested only with categorical likelihoods; whether it holds for more complex conditions (OCR, ROI masks) remains an open question.
Related Work & Insights
- vs MDNS (Zhu et al. 2025): MDNS represents current on-policy discrete diffusion samplers using weighted denoising cross-entropy. This paper incorporates it into a unified second-moment framework and demonstrates that off-policy methods are systematically stronger in difficult multimodal settings.
- vs GFlowNet (Zhang et al. 2022a; Bengio et al. 2021/2023): GFlowNets used TB + off-policy for amortized sampling on discrete EBMs years ago, but the connection to masked discrete diffusion samplers was missing. This paper explicitly unifies the two.
- vs Continuous Diffusion Off-Policy Work (Sendera et al. 2024; Choi et al. 2026): While continuous settings use SDEs, this paper shows empirically that the same buffer + importance weighting + MCMC triplet is equally effective for discrete Markov chains.
- vs Discrete Schrödinger Bridge (Kim et al. 2025a; Ksenofontov & Korotin 2025): Prior discrete bridge work required samples at both ends (data-to-data). This paper extends the continuous data-to-energy IPF (Tamogashev & Malkin, 2026) to the discrete case.
- vs Outsourced Diffusion Samplers (Venkatraman et al. 2025): Following data-free posterior sampling in continuous latent spaces (VAE/GAN), this paper extends the idea to VQ-VAE discrete latent spaces.
Rating
- Novelty: ⭐⭐⭐⭐ While individual components (TB/LV/buffer/MCMC) originate from existing work, the contribution lies in the systematic migration, unified framework, and new applications (bridges/VQ-VAE).
- Experimental Thoroughness: ⭐⭐⭐⭐ Covers six scenarios including Ising, Potts, GMMs, ManyWell, Schrödinger bridges, and VQ-VAE. Ablations are detailed, though large-scale generative model verification is missing.
- Writing Quality: ⭐⭐⭐⭐⭐ Clearly explains the relationship between GFlowNet and discrete diffusion samplers. Formulas are clean and sections are well-connected.
- Value: ⭐⭐⭐⭐ Provides a ready-to-use toolbox for discrete amortized sampling. The VQ-VAE posterior sampling line is insightful for the controllable generation community.