Improving Discrete Diffusion Unmasking Policies Beyond Explicit Reference Policies (UPO)

Conference: ICLR 2026 · arXiv: 2510.05725 · Code: GitHub · Area: Discrete Diffusion Models / Language Modeling · Keywords: Masked Diffusion Models, Unmasking Policy, Reinforcement Learning, KL-regularized MDP, GRPO

TL;DR

This paper proposes Unmasking Policy Optimization (UPO), which formalizes the denoising process of Masked Diffusion Models (MDMs) as a KL-regularized Markov Decision Process and trains a lightweight unmasking policy model via reinforcement learning to replace heuristic schedulers such as max-confidence. Both theoretical analysis and empirical results demonstrate that the learned policy generates samples closer to the true data distribution.

Background & Motivation

Masked Diffusion Models (MDMs) achieve discrete-space generation through iterative unmasking, and have demonstrated competitive performance against autoregressive models in language modeling (e.g., LLaDA, Dream-7B). During inference, the order in which positions are unmasked is critical to generation quality.

Theoretical Motivation: Kim et al. (2025) prove that no polynomial-time algorithm can exactly recover the data distribution for all masked inputs; certain "hard sub-problems" are unavoidable. Empirically, however, the max-confidence policy is shown to bypass these difficult instances.

Limitations of Prior Work: Large-scale MDMs (LLaDA, Dream-7B) rely on rule-based schedulers (max-confidence, max-margin, entropy), which are purely heuristic and lack theoretical optimality guarantees. Pass@N experiments reveal that both random and Top-5 policies can surpass the single-path accuracy of max-confidence when \(N\) trajectories are sampled, indicating that better unmasking paths than max-confidence exist.

Core Problem: how can we learn an unmasking policy that outperforms heuristics while preserving training stability and theoretical guarantees?

Method

Overall Architecture

UPO formalizes the MDM denoising process as a finite-horizon Markov Decision Process (MDP). The base MDM \(\pi_\theta\) is kept frozen; only a lightweight unmasking policy model \(g_\phi\) is trained to determine the unmasking order. The policy model operates on intermediate features of the MDM and contains approximately 134M parameters, versus 8B for the base MDM.

Key Designs

  1. MDP Formulation: The state is the current partially unmasked sequence \(\mathbf{x}_n\), and the action space is the index set of all masked positions \(\mathcal{A}_{\mathbf{x}_n}\). The policy \(g_\phi(a^i | \mathbf{x}_n)\) is parameterized via softmax to select the position to unmask, with environment dynamics provided by the frozen MDM:
\[p_{g_\phi}(\mathbf{x}_{n-1} | \mathbf{x}_n) = g_\phi(a_n | \mathbf{x}_n) \cdot \pi_\theta(\mathbf{x}_{n-1} | \mathbf{x}_n, a_n)\]

The policy model architecture consists of a single Transformer layer followed by a 3-layer MLP, reusing the MDM's feature extraction for computational efficiency (a rollout sketch follows this list).

  2. KL-Regularized GRPO Objective: A strong reference policy \(g_{\mathrm{ref}}\) (e.g., max-confidence) is introduced, and an output-level GRPO loss with KL divergence regularization is optimized:
\[\max_\phi \mathbb{E}\left[\frac{p_{g_\phi}(\mathbf{x}_0|\mathbf{q})}{p_{g_{\phi_{\mathrm{old}}}}(\mathbf{x}_0|\mathbf{q})} A(\mathbf{q}, \mathbf{x}_0) - \beta D_{\mathrm{KL}}(p_{g_\phi} \| p_{g_{\mathrm{ref}}})\right]\]

The regularization term keeps \(g_\phi\) close to \(g_{\mathrm{ref}}\), functioning as a trust region to prevent instability.

  3. Theoretical Guarantees (Theorems 1 & 2):

     • Convergence: Under iterative optimization, the expected reward of the policy converges to a fixed point \(r_{g^*} > r_{g_{\mathrm{ref}}}\) that strictly exceeds that of the reference policy.

     • Distributional Proximity: The terminal distribution generated by the optimized policy has a strictly smaller KL divergence from the true data distribution than the reference policy, i.e., \(D_{\mathrm{KL}}(p_{\mathrm{data}} \| p_{g_{\phi^*}}) < D_{\mathrm{KL}}(p_{\mathrm{data}} \| p_{g_{\mathrm{ref}}})\).
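To make the rollout concrete, here is a minimal PyTorch sketch of one denoising step of this MDP. It is illustrative only: `mdm` and `policy` are hypothetical callables standing in for the frozen MDM (returning hidden features and per-position token logits) and the lightweight policy head; the paper's actual interfaces may differ.

```python
import torch

def denoise_step(x, mdm, policy, mask_id):
    """One step of the UPO MDP: the policy picks a masked position,
    the frozen MDM fills in the token at that position.

    `mdm` and `policy` are hypothetical stand-ins for the frozen MDM
    and the 134M unmasking policy head.
    """
    masked = (x == mask_id).nonzero(as_tuple=True)[0]  # action space A_{x_n}
    with torch.no_grad():
        hidden, token_logits = mdm(x)                  # frozen MDM features + predictions

    # Policy head scores every masked position; softmax over A_{x_n}.
    scores = policy(hidden)[masked]                    # shape (|A|,)
    a = masked[torch.distributions.Categorical(logits=scores).sample()]

    # Environment dynamics: the frozen MDM samples the token at position a.
    token = torch.distributions.Categorical(logits=token_logits[a]).sample()
    x = x.clone()
    x[a] = token
    return x, a
```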

Loss & Training

Since \(p_{g_\phi}(\mathbf{x}_0|\mathbf{q})\) is intractable to compute directly (it requires marginalizing over all unmasking trajectories), Proposition 1 establishes that the gradient of a token-level surrogate loss approximates that of the output-level loss. The final, tractable UPO loss is a token-level GRPO objective with clipping:

\[\mathcal{L}_{\mathrm{UPO}} = \frac{1}{G}\sum_{g=1}^{G} \left(\frac{1}{L}\sum_{n} \min\left(\rho_n^{(g)} A_g,\; \mathrm{clip}\left(\rho_n^{(g)}, 1-\epsilon, 1+\epsilon\right) A_g\right) - \beta D_{\mathrm{KL}}\left(p_{g_\phi} \,\|\, p_{g_{\mathrm{ref}}}\right)\right), \quad \rho_n^{(g)} = \frac{g_\phi(a_n^{(g)} \mid \mathbf{x}_n^{(g)})}{g_{\phi_{\mathrm{old}}}(a_n^{(g)} \mid \mathbf{x}_n^{(g)})}\]
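A minimal PyTorch sketch of this token-level objective follows. It assumes per-step log-probabilities \(\log g(a_n \mid \mathbf{x}_n)\) have already been gathered into (G, L) tensors for the current, old, and reference policies; the k3-style KL estimator used here is one common GRPO choice, not necessarily the paper's exact form.

```python
import torch

def upo_loss(logp_new, logp_old, logp_ref, advantages, beta=0.01, eps=0.2):
    """Token-level UPO objective: clipped GRPO surrogate plus a KL penalty
    toward the reference policy.

    logp_new / logp_old / logp_ref: (G, L) per-step action log-probs.
    advantages: (G,) group-normalized advantage A_g per trajectory.
    """
    ratio = (logp_new - logp_old).exp()                    # rho_n^(g), shape (G, L)
    adv = advantages.unsqueeze(1)                          # (G, 1), broadcast over steps
    unclipped = ratio * adv
    clipped = ratio.clamp(1 - eps, 1 + eps) * adv
    surrogate = torch.min(unclipped, clipped).mean(dim=1)  # 1/L sum over steps

    # k3-style per-token estimate of KL(g_phi || g_ref) (one common choice).
    log_r = logp_ref - logp_new
    kl = (log_r.exp() - 1 - log_r).mean(dim=1)

    # Maximize (surrogate - beta * KL)  ->  minimize the negative.
    return -(surrogate - beta * kl).mean()
```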

Three reference policy variants are provided: max-confidence (CE regularization), softmax-confidence (KL regularization), and Top-K (KL regularization). The Top-K variant supports random initialization without pretraining.
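The sketch below illustrates how the three reference distributions over masked positions could be built from the frozen MDM's per-position confidences; the function and its arguments are illustrative assumptions, not the paper's code. The one-hot max-confidence reference plausibly explains why that variant pairs with a CE penalty rather than KL: a KL term against a degenerate one-hot reference is unbounded off its support.

```python
import torch
import torch.nn.functional as F

def reference_policy(token_logits, masked_idx, variant="topk", k=5, tau=1.0):
    """Reference distributions over masked positions, built from the frozen
    MDM's per-position confidence (max predicted token probability).
    Illustrative sketch of the three variants described above.
    """
    conf = token_logits[masked_idx].softmax(dim=-1).max(dim=-1).values  # (|A|,)

    if variant == "max_confidence":        # one-hot on the most confident position
        ref = F.one_hot(conf.argmax(), num_classes=conf.numel()).float()
    elif variant == "softmax_confidence":  # temperature-softened confidences
        ref = (conf / tau).softmax(dim=0)
    else:                                  # "topk": uniform over the k most confident
        kk = min(k, conf.numel())
        ref = torch.zeros_like(conf)
        ref[conf.topk(kk).indices] = 1.0 / kk
    return ref                             # distribution over A_{x_n}
```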

Key Experimental Results

Main Results

| Dataset | Metric | UPO | Max-Confidence | Random | Gain (vs. max-conf.) |
| --- | --- | --- | --- | --- | --- |
| Sudoku | Accuracy | 0.817 | 0.705 | 0.616 | +11.2% |
| Zebra | Accuracy | 0.362 | 0.337 | 0.339 | +2.5% |
| GSM8K | Accuracy | 0.703 | 0.684 | 0.612 | +1.9% |
| Math500 | Accuracy | 0.284 | 0.272 | 0.196 | +1.2% |

Ablation Study

| Configuration | GSM8K Acc. | Note |
| --- | --- | --- |
| diffu-GRPO + Random | 0.638 | Baseline MDM post-training |
| diffu-GRPO + Max-Confidence | 0.751 | Heuristic scheduler |
| diffu-GRPO + UPO | 0.764 | UPO applied on top of the post-trained MDM (+1.3%) |
| KL regularization (Top-K) | 0.703 | With KL divergence term |
| No KL regularization | ~0.68 | Degraded performance; early path collapse |
| Random init, no regularization (≈ DColT) | ~0.67 | 2–3% below reference-policy variants |

Key Findings

  • Unmasking order is critically important for structured tasks such as Sudoku, where UPO achieves the largest gain (+20.1% vs. random).
  • The regularization term prevents early path collapse and maintains higher variance in group rewards.
  • UPO and diffu-GRPO are complementary: the former optimizes the scheduling policy, while the latter fine-tunes the MDM itself.
  • The optimal reference policy varies by task: max-confidence for Sudoku, Top-K for GSM8K.

Highlights & Insights

  • UPO decouples unmasking policy learning from MDM training; the policy model contains only 134M parameters (1.7% of the 8B MDM), resulting in very low training cost.
  • A complete pipeline from theory to practice is provided: theoretical guarantees (Theorems 1 & 2) → surrogate loss (Proposition 1) → actionable training objective.
  • Pass@N experiments intuitively demonstrate the suboptimality of heuristic policies, providing strong motivation for learned policies.
  • The trust-region design via KL regularization simultaneously ensures training stability and allows the policy to surpass the reference.

Limitations & Future Work

  • Gains on mathematical reasoning tasks such as GSM8K and Math500 are relatively modest, possibly because ordering signals in long sequences are less salient than in Sudoku.
  • The policy model relies on intermediate features from the MDM, creating architectural coupling.
  • The current formulation unmasks one position per step; extending to parallel multi-position unmasking is an important future direction.
  • Generalization experiments are limited—training and test sets are drawn from the same task distribution.
  • Compared to DColT (Huang et al., 2025): UPO introduces an explicit reference policy and KL regularization, yielding greater stability and improved performance.
  • diffu-GRPO applies RL to fine-tune the MDM itself, whereas UPO trains the scheduling policy; the two approaches are complementary.
  • This paradigm is generalizable to sampling schedule learning for continuous diffusion models (e.g., step-size selection for ODE solvers).

Rating

  • Novelty: ⭐⭐⭐⭐⭐ — First work to model MDM unmasking as a KL-regularized MDP with complete theoretical guarantees.
  • Experimental Thoroughness: ⭐⭐⭐⭐ — Four benchmarks covering logical and mathematical reasoning with thorough ablations; open-ended text generation evaluation is absent.
  • Writing Quality: ⭐⭐⭐⭐⭐ — Rigorous theoretical derivations, clear experimental design, and compelling motivation.
  • Value: ⭐⭐⭐⭐ — Introduces a new paradigm for discrete diffusion model inference, with significant practical improvements on structured tasks.