Improving Discrete Diffusion Unmasking Policies Beyond Explicit Reference Policies (UPO)¶
Conference: ICLR 2026 arXiv: 2510.05725 Code: GitHub Area: Discrete Diffusion Models / Language Modeling Keywords: Masked Diffusion Models, Unmasking Policy, Reinforcement Learning, KL-Regularized MDP, GRPO
TL;DR¶
This paper proposes Unmasking Policy Optimization (UPO), which formalizes the denoising process of Masked Diffusion Models (MDMs) as a KL-regularized Markov Decision Process and trains a lightweight unmasking policy model via reinforcement learning to replace heuristic schedulers such as max-confidence. Both theoretical analysis and empirical results demonstrate that the learned policy generates samples closer to the true data distribution.
Background & Motivation¶
Masked Diffusion Models (MDMs) achieve discrete-space generation through iterative unmasking and have demonstrated competitive performance against autoregressive models in language modeling (e.g., LLaDA, Dream-7B). During inference, the choice of which position to unmask at each step is critical to generation quality.
Theoretical Motivation: Kim et al. (2025) prove that no polynomial-time algorithm can exactly recover the data distribution over all masked sentences—certain "hard sub-problems" exist. Empirically, however, the max-confidence policy is shown to bypass these difficult instances.
Limitations of Prior Work: Large-scale MDMs (LLaDA, Dream-7B) rely on rule-based schedulers (max-confidence, max-margin, entropy), which are purely heuristic and carry no theoretical optimality guarantees. Pass@N experiments reveal that, when \(N\) trajectories are sampled, both random and Top-5 policies can surpass the single-path accuracy of max-confidence, indicating that better unmasking paths than max-confidence exist.
Core Problem: How to learn an unmasking policy that outperforms heuristics while maintaining training stability and theoretical guarantees?
Method¶
Overall Architecture¶
UPO formalizes the MDM denoising process as a finite-horizon Markov Decision Process (MDP). The base MDM \(\pi_\theta\) is frozen and left unmodified; only a lightweight unmasking policy model \(g_\phi\) is trained to determine the unmasking order. The policy model operates on intermediate features of the MDM and contains approximately 134M parameters, compared to 8B for the base MDM.
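As an illustration of this decoupling, a minimal sampling loop might look like the sketch below. All names (`mdm`, `policy`, `mask_id`) are hypothetical stand-ins rather than the authors' code: the frozen MDM supplies token predictions and intermediate features, and the small policy only decides which masked position to commit next.

```python
import torch

@torch.no_grad()
def sample_with_learned_policy(mdm, policy, x, mask_id, num_steps):
    """Greedy one-position-per-step sampler (illustrative sketch).

    mdm(x) is assumed to return (logits, features); policy(features, masked)
    is assumed to return a distribution over currently masked positions.
    """
    batch = torch.arange(x.size(0), device=x.device)
    for _ in range(num_steps):
        masked = (x == mask_id)                      # (B, L) still-masked positions
        if not masked.any():
            break
        logits, features = mdm(x)                    # frozen MDM forward pass
        probs = policy(features, masked)             # g_phi(a | x_n), zero off-mask
        pos = probs.argmax(dim=-1)                   # action: position to unmask
        token = logits[batch, pos].argmax(dim=-1)    # frozen MDM fills in the token
        x[batch, pos] = token                        # environment transition
    return x
```

In practice `pos` could also be sampled from `probs` rather than taken greedily; the essential point is that only `policy` receives gradients during training while `mdm` stays frozen.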
Key Designs¶
- MDP Formulation: The state is the current partially unmasked sequence \(\mathbf{x}_n\), and the action space is the index set of all masked positions \(\mathcal{A}_{\mathbf{x}_n}\). The policy \(g_\phi(a^i | \mathbf{x}_n)\) is parameterized via a softmax to select the position to unmask, with environment dynamics provided by the frozen MDM (a sketch of the transition appears after this list).
The policy model architecture consists of one Transformer layer followed by a 3-layer MLP, reusing the MDM's feature extraction for computational efficiency (see the code sketch after this list).
- KL-Regularized GRPO Objective: A strong reference policy \(g_{\mathrm{ref}}\) (e.g., max-confidence) is introduced, and an output-level GRPO loss with KL-divergence regularization is optimized (sketched after this list).
The regularization term keeps \(g_\phi\) close to \(g_{\mathrm{ref}}\), acting as a trust region that prevents training instability.
- Theoretical Guarantees (Theorems 1 & 2):
  - Convergence: Under iterative optimization, the expected reward of the policy converges to a fixed point \(r_{g^*}\) that strictly exceeds the reward of the reference policy, \(r_{g^*} > r_{g_{\mathrm{ref}}}\).
  - Distributional Proximity: The terminal distribution generated by the optimized policy has a strictly smaller KL divergence from the true data distribution than the reference policy, i.e., \(D_{\mathrm{KL}}(p_{\mathrm{data}} \| p_{g_{\phi^*}}) < D_{\mathrm{KL}}(p_{\mathrm{data}} \| p_{g_{\mathrm{ref}}})\).
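The transition and output-level objective referenced in the bullets above can be read, as a hedged reconstruction in standard notation (the paper's exact form may differ), as

\[
a \sim g_\phi(\cdot \mid \mathbf{x}_n), \qquad x^{a} \sim \pi_\theta(\cdot \mid \mathbf{x}_n), \qquad \mathbf{x}_{n+1} = \mathbf{x}_n \ \text{with position } a \text{ replaced by } x^{a},
\]

\[
\mathcal{J}(\phi) = \mathbb{E}\!\left[ \frac{1}{G}\sum_{i=1}^{G} \hat{A}_i \,\log p_{g_\phi}\!\big(\mathbf{x}_0^{(i)} \mid \mathbf{q}\big) \right] - \beta\, D_{\mathrm{KL}}\!\big( g_\phi \,\|\, g_{\mathrm{ref}} \big),
\qquad
\hat{A}_i = \frac{r_i - \operatorname{mean}(r_{1:G})}{\operatorname{std}(r_{1:G})},
\]

with \(\mathbf{q}\) the prompt, \(G\) the group size, and \(r_i\) the reward of the \(i\)-th sampled output.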
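The architecture bullet above (one Transformer layer plus a 3-layer MLP over MDM features) could be rendered roughly as follows; layer widths and names are assumptions for illustration, not the paper's exact configuration.

```python
import torch
import torch.nn as nn

class UnmaskingPolicyHead(nn.Module):
    """One Transformer layer + 3-layer MLP over frozen MDM features (sketch).

    Produces a distribution over currently masked positions; dimensions are
    illustrative and would follow the base MDM's hidden size in practice.
    """
    def __init__(self, d_model=4096, n_heads=8, d_hidden=2048):
        super().__init__()
        self.block = nn.TransformerEncoderLayer(
            d_model=d_model, nhead=n_heads, batch_first=True)
        self.mlp = nn.Sequential(
            nn.Linear(d_model, d_hidden), nn.GELU(),
            nn.Linear(d_hidden, d_hidden), nn.GELU(),
            nn.Linear(d_hidden, 1),        # scalar unmasking score per position
        )

    def forward(self, features, masked):
        # features: (B, L, d_model) intermediate activations of the frozen MDM
        # masked:   (B, L) boolean mask of still-masked positions
        h = self.block(features)
        scores = self.mlp(h).squeeze(-1)                      # (B, L)
        scores = scores.masked_fill(~masked, float("-inf"))   # restrict to masked
        return torch.softmax(scores, dim=-1)                  # g_phi(a | x_n)
```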
Loss & Training¶
Since \(p_{g_\phi}(\mathbf{x}_0|\mathbf{q})\) is intractable to compute directly (it requires marginalizing over all unmasking trajectories), Proposition 1 establishes that the gradient of a token-level surrogate loss approximates that of the output-level loss. The final, actionable UPO loss is a token-level GRPO objective with clipping.
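As a hedged sketch (standard GRPO notation assumed here, not the paper's exact expression), the token-level clipped objective takes roughly the form

\[
\mathcal{J}_{\mathrm{UPO}}(\phi) = \mathbb{E}\!\left[ \frac{1}{G}\sum_{i=1}^{G} \frac{1}{N}\sum_{n=1}^{N} \min\!\Big( \rho^{(i)}_{n}\,\hat{A}_i,\; \operatorname{clip}\!\big(\rho^{(i)}_{n},\, 1-\epsilon,\, 1+\epsilon\big)\,\hat{A}_i \Big) \right] - \beta\, D_{\mathrm{KL}}\!\big( g_\phi \,\|\, g_{\mathrm{ref}} \big),
\qquad
\rho^{(i)}_{n} = \frac{g_\phi\big(a^{(i)}_{n} \mid \mathbf{x}^{(i)}_{n}\big)}{g_{\phi_{\mathrm{old}}}\big(a^{(i)}_{n} \mid \mathbf{x}^{(i)}_{n}\big)},
\]

where \(\hat{A}_i\) is the group-normalized reward of trajectory \(i\), \(N\) the number of denoising steps, and \(G\) the group size.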
Three reference policy variants are provided: max-confidence (CE regularization), softmax-confidence (KL regularization), and Top-K (KL regularization). The Top-K variant supports random initialization without pretraining.
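To make the three variants concrete, here is a hedged sketch of how each reference distribution could be derived from the frozen MDM's per-position confidences; the function name, and the choice of a uniform distribution over the top-k positions, are assumptions of this sketch.

```python
import torch

def reference_policy(confidences, masked, kind="max_confidence", k=5, tau=1.0):
    """Reference unmasking distribution g_ref over masked positions (sketch).

    confidences: (B, L) top-1 probability of the frozen MDM at each position.
    masked:      (B, L) boolean mask of still-masked positions.
    Assumes at least k masked positions remain when kind == "top_k".
    """
    scores = confidences.masked_fill(~masked, float("-inf"))
    if kind == "max_confidence":
        # One-hot on the single most confident masked position (near-deterministic).
        ref = torch.zeros_like(scores)
        ref.scatter_(1, scores.argmax(dim=-1, keepdim=True), 1.0)
        return ref
    if kind == "softmax_confidence":
        # Temperature-softened distribution over masked positions.
        return torch.softmax(scores / tau, dim=-1)
    if kind == "top_k":
        # Assumed here: uniform over the k most confident masked positions.
        topk = scores.topk(k, dim=-1).indices
        ref = torch.zeros_like(scores)
        ref.scatter_(1, topk, 1.0 / k)
        return ref
    raise ValueError(f"unknown reference policy: {kind}")
```

The max-confidence variant is essentially deterministic, which is consistent with the paper pairing it with a CE regularizer rather than a KL term.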
Key Experimental Results¶
Main Results¶
| Dataset | Metric | UPO | Max-Confidence | Random | Gain vs. Max-Conf. (pp) |
|---|---|---|---|---|---|
| Sudoku | Accuracy | 0.817 | 0.705 | 0.616 | +11.2 |
| Zebra | Accuracy | 0.362 | 0.337 | 0.339 | +2.5 |
| GSM8K | Accuracy | 0.703 | 0.684 | 0.612 | +1.9 |
| Math500 | Accuracy | 0.284 | 0.272 | 0.196 | +1.2 |
Ablation Study¶
| Configuration | GSM8K Acc. | Note |
|---|---|---|
| diffu-GRPO + Random | 0.638 | Baseline MDM post-training |
| diffu-GRPO + Max-Confidence | 0.751 | Heuristic scheduler |
| diffu-GRPO + UPO | 0.764 | UPO on top of the post-trained MDM, +1.3 pp over max-confidence |
| KL regularization (Top-K, GSM8K) | 0.703 | With KL-divergence term |
| No KL regularization (GSM8K) | ~0.68 | Degraded performance; early path collapse |
| Random init, no regularization (≈DColT) | ~0.67 | 2–3 pp below the reference-policy variants |
Key Findings¶
- Unmasking order is critically important for structured tasks such as Sudoku, where UPO achieves its largest gain (+20.1 pp over random).
- The regularization term prevents early path collapse and maintains higher variance in group rewards.
- UPO and diffu-GRPO are complementary—the former optimizes the scheduling policy, while the latter fine-tunes the MDM itself.
- The optimal reference policy varies by task: max-confidence for Sudoku, Top-K for GSM8K.
Highlights & Insights¶
- UPO decouples unmasking policy learning from MDM training; the policy model contains only 134M parameters (1.7% of the 8B MDM), resulting in very low training cost.
- A complete pipeline from theory to practice is provided: theoretical guarantees (Theorems 1 & 2) → surrogate loss (Proposition 1) → actionable training objective.
- Pass@N experiments intuitively demonstrate the suboptimality of heuristic policies, providing strong motivation for learned policies.
- The trust-region design via KL regularization simultaneously ensures training stability and allows the policy to surpass the reference.
Limitations & Future Work¶
- Gains on mathematical reasoning tasks such as GSM8K and Math500 are relatively modest, possibly because ordering signals in long sequences are less salient than in Sudoku.
- The policy model relies on intermediate features from the MDM, creating architectural coupling.
- The current formulation unmasks one position per step; extending to parallel multi-position unmasking is an important future direction.
- Generalization experiments are limited—training and test sets are drawn from the same task distribution.
Related Work & Insights¶
- Compared to DColT (Huang et al., 2025): UPO introduces an explicit reference policy and KL regularization, yielding greater stability and improved performance.
- diffu-GRPO applies RL to fine-tune the MDM itself, whereas UPO trains the scheduling policy; the two approaches are complementary.
- This paradigm is generalizable to sampling schedule learning for continuous diffusion models (e.g., step-size selection for ODE solvers).
Rating¶
- Novelty: ⭐⭐⭐⭐⭐ — First work to model MDM unmasking as a KL-regularized MDP with complete theoretical guarantees.
- Experimental Thoroughness: ⭐⭐⭐⭐ — Four benchmarks covering logical and mathematical reasoning with thorough ablations; open-ended text generation evaluation is absent.
- Writing Quality: ⭐⭐⭐⭐⭐ — Rigorous theoretical derivations, clear experimental design, and compelling motivation.
- Value: ⭐⭐⭐⭐ — Introduces a new paradigm for discrete diffusion model inference, with significant practical improvements on structured tasks.