AdaNAT: Exploring Adaptive Policy for Token-Based Image Generation

Conference: ECCV 2024
arXiv: 2409.00342
Code: https://github.com/LeapLabTHU/AdaNAT
Area: Image Generation
Keywords: Non-autoregressive Transformer, Reinforcement Learning, Adaptive Generation Policy, Adversarial Reward, Token Generation

TL;DR

This work proposes AdaNAT, which formulates the generation-policy configuration of Non-Autoregressive Transformers (NATs) as a Markov Decision Process (MDP). AdaNAT trains a lightweight policy network with PPO against an adversarial reward model, and uses it to automatically customize the generation policy (re-masking ratio, sampling temperature, CFG weight, etc.) for each sample. It achieves an FID of 2.86 on ImageNet-256 with only 8 steps, and its ablations show a roughly 40% relative FID improvement over hand-crafted policies.

Background & Motivation

Background

Token-based image generation is an important paradigm alongside diffusion models. Non-Autoregressive Transformers (NATs, e.g., MaskGIT, Muse, MAGE) generate images in a few steps through parallel decoding, offering a promising balance between efficiency and quality. However, NAT generation requires configuring a complex policy: at each step one must set the re-masking ratio \(m^{(t)}\), the sampling temperature \(\tau_1^{(t)}\), the re-masking temperature \(\tau_2^{(t)}\), and the classifier-free guidance (CFG) scale \(w^{(t)}\), i.e., \(4 \times T\) hyperparameters for a \(T\)-step generation. Existing methods rely entirely on hand-crafted scheduling functions (e.g., a cosine schedule), which are labor-intensive to design and sub-optimal. These schedules are also globally shared: a "one-size-fits-all" policy cannot adapt to the varying complexity of different samples.
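To make the size of this configuration space concrete, the sketch below builds a MaskGIT-style hand-crafted policy for T = 8 steps: the masked fraction after step t follows the cosine schedule \(\cos(\frac{\pi}{2} \cdot t/T)\), while the two temperatures and the CFG scale are held constant. The constant values (1.0, 4.5, 2.0) and the function names are hypothetical placeholders, not settings taken from the paper. The same list is returned regardless of the sample, which is exactly the "one-size-fits-all" behavior AdaNAT sets out to replace.

```python
import math
from typing import List, Tuple

# (m, tau1, tau2, w) for one decoding step
StepConfig = Tuple[float, float, float, float]


def cosine_mask_ratio(t: int, T: int) -> float:
    """MaskGIT-style cosine schedule: fraction of tokens still masked after step t (1..T)."""
    return math.cos(math.pi / 2 * t / T)


def handcrafted_policy(T: int, tau1: float = 1.0, tau2: float = 4.5,
                       w: float = 2.0) -> List[StepConfig]:
    """One (m, tau1, tau2, w) tuple per step; identical for every sample.

    The constant tau1, tau2 and w defaults are hypothetical placeholders.
    """
    return [(cosine_mask_ratio(t, T), tau1, tau2, w) for t in range(1, T + 1)]


T = 8
policy = handcrafted_policy(T)
num_hyperparameters = sum(len(step) for step in policy)
print(num_hyperparameters)  # 32, i.e. 4 * T
for t, (m, tau1, tau2, w) in enumerate(policy, start=1):
    print(f"step {t}: m={m:.3f} tau1={tau1} tau2={tau2} w={w}")
```

Even in this simplified form, tuning the schedule by hand means searching over 32 coupled values, and every sample still receives the same ones.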

Limitations of Prior Work

Goal: automatically and adaptively configure the optimal generation policy for each individual sample generated by a NAT. Three specific challenges arise: (1) the joint search space of \(4T\) hyperparameters is extremely large, making manual tuning impractical; (2) samples of different complexity should ideally use different generation policies (simple patterns converge early with few adjustments, while complex structures require more careful steps), yet existing methods apply the same policy to all samples; (3) discrete token generation is non-differentiable, which prevents direct end-to-end optimization of a policy network.

Method

Overall Architecture

AdaNAT trains a lightweight policy network on top of a pre-trained, frozen NAT model. The input is the generation state of the current step (time step \(t\) and current token sequence \(\mathbf{v}^{(t)}\)), and the output is the current step's policy configuration (\(m^{(t)}, \tau_1^{(t)}, \tau_2^{(t)}, w^{(t)}\)). The entire pipeline does not modify the parameters of the underlying NAT model, learning only "how to utilize it better."

Key Designs

  1. MDP Modeling: The \(T\)-step generation process of a NAT is modeled as a Markov Decision Process (MDP). The state \(s_t = (t, \mathbf{v}^{(t)})\) consists of the current time step and the current token sequence; the action \(a_t\) consists of the four policy parameters; the state transition is determined by the frozen NAT model (one parallel decoding + re-masking step executed with the given action); and the reward is given only at the final step, scoring the quality of the final generated image. Because RL needs only reward values rather than gradients through the generation process, this formulation sidesteps the non-differentiability of discrete token sampling.

  2. Lightweight Policy Network: The policy network takes the NAT model's existing output features \(f_\theta(\mathbf{v}^{(t)})\) as input (so no extra encoder is needed) and consists only of a depthwise convolutional layer, a pointwise convolutional layer, and an MLP, with AdaLN injecting time-step information. Its inference overhead is only 0.03% of the total NAT inference cost, which is practically negligible. The policy outputs a Gaussian distribution \(\pi_\phi(a_t|s_t) = \mathcal{N}(\eta_\phi(s_t), \sigma I)\): actions are sampled from it for exploration during training, and the mean \(\eta_\phi(s_t)\) is used at inference.

  3. Adversarial Reward Model: This is the core contribution of the paper. The authors systematically compare three reward designs:

    • FID Reward: Because FID is a distribution-level statistic, it cannot provide sample-level feedback. Empirically, the policy network either failed to train or produced visually poor images (low FID but blurry or distorted images), indicating that the FID metric can be "hacked."
    • Pre-trained Reward Model (ImageReward): Although it provides sample-level feedback, it drives the policy network toward images with collapsed styles and a severe lack of diversity, as the policy "overfits" the static reward.
    • Adversarial Reward (Ours): A GAN-like discriminator is introduced as the reward model \(r_\psi\), forming a minimax game with the policy network. The policy network maximizes the reward, while the reward model simultaneously updates to distinguish real from generated images. Since the reward is dynamically updated, the policy network cannot easily overfit, achieving a balance between fidelity and diversity.
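The MDP formulation and the Gaussian policy head can be illustrated with a toy rollout. This is a sketch under stated assumptions, not the authors' implementation: `policy_mean` is a hypothetical stand-in for \(\eta_\phi(s_t)\) (AdaNAT computes it from NAT features), `frozen_nat_step` is a toy transition that uses only the re-masking ratio and re-masks at random rather than by confidence, `toy_reward` is a hypothetical replacement for the learned reward model, and \(\sigma\) is treated as the per-dimension standard deviation. The point is structural: each step records \((s_t, a_t, r_t)\), actions are sampled during training and set to the mean at inference, and only the final step receives a non-zero reward, so no gradient has to flow through the discrete token sampling.

```python
import random
from typing import Callable, List, Tuple

MASK = -1  # hypothetical id marking a masked token
Action = Tuple[float, ...]  # (m, tau1, tau2, w)


def policy_mean(t: int, tokens: List[int]) -> Action:
    """Hypothetical stand-in for eta_phi(s_t); AdaNAT computes it from NAT features."""
    masked_frac = sum(tok == MASK for tok in tokens) / len(tokens)
    return (0.5 * masked_frac, 1.0, 4.5, 2.0)


def select_action(mean: Action, sigma: float, rng: random.Random, explore: bool) -> Action:
    """Sample from N(mean, sigma^2 I) when exploring (training); return the mean otherwise."""
    if not explore:
        return mean
    return tuple(mu + sigma * rng.gauss(0.0, 1.0) for mu in mean)


def frozen_nat_step(tokens: List[int], action: Action, last: bool,
                    rng: random.Random, vocab_size: int = 16) -> List[int]:
    """Toy transition: predict every masked slot, then re-mask a fraction m of the new predictions.

    A real NAT samples tokens with temperature tau1, re-masks low-confidence ones
    using tau2, and applies CFG weight w; this stub only uses m.
    """
    was_masked = [i for i, tok in enumerate(tokens) if tok == MASK]
    new_tokens = [rng.randrange(vocab_size) if tok == MASK else tok for tok in tokens]
    m = 0.0 if last else min(max(action[0], 0.0), 1.0)  # nothing stays masked at the end
    num_remask = min(int(m * len(tokens)), len(was_masked))
    for i in rng.sample(was_masked, num_remask):
        new_tokens[i] = MASK
    return new_tokens


def rollout(T: int, num_tokens: int, sigma: float, explore: bool,
            reward_fn: Callable[[List[int]], float], seed: int = 0):
    """One episode: states s_t = (t, v_t), actions a_t, reward only at the final step."""
    rng = random.Random(seed)
    tokens = [MASK] * num_tokens
    trajectory = []
    for t in range(T):
        state = (t, list(tokens))
        action = select_action(policy_mean(t, tokens), sigma, rng, explore)
        last = t == T - 1
        tokens = frozen_nat_step(tokens, action, last, rng)
        reward = reward_fn(tokens) if last else 0.0
        trajectory.append((state, action, reward))
    return tokens, trajectory


def toy_reward(tokens: List[int]) -> float:
    """Hypothetical reward: fraction of even token ids (a real system scores the decoded image)."""
    return sum(tok % 2 == 0 for tok in tokens) / len(tokens)


final_tokens, traj = rollout(T=8, num_tokens=64, sigma=0.6, explore=True, reward_fn=toy_reward)
print(MASK in final_tokens, [round(r, 3) for _, _, r in traj])
```

With `explore=False`, every recorded action equals the policy mean, which mirrors how the trained policy is used at inference time.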

Loss & Training

  • The policy network is optimized using the PPO algorithm, adopting a clipped surrogate objective and a value function loss.
  • The adversarial reward model is trained using standard GAN discriminator loss (binary cross-entropy for classifying real and fake images).
  • The two models are updated alternately, each receiving 5 gradient updates per round, to stabilize the minimax game.
  • The overall training converges within 1000 iterations with a batch size of 4096. The exploration parameter \(\sigma\) decays from 0.6 to 0.3 after 500 rounds.
  • Key: Throughout the entire process, the NAT model parameters remain completely frozen, requiring no gradient backpropagation through the NAT.
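The training losses named above can be written per batch as follows. This is a generic sketch of the standard PPO clipped surrogate, a squared-error value loss, and the binary cross-entropy GAN discriminator loss, written on plain Python lists for clarity. `clip_eps = 0.2` is a common PPO default rather than a value reported in the paper, and `adversarial_reward` (using the discriminator's \(\log D(x)\) as the reward) is a hypothetical choice, since the exact reward transform is not specified here.

```python
import math
from typing import Sequence


def softplus(x: float) -> float:
    """Numerically stable log(1 + exp(x))."""
    return max(x, 0.0) + math.log1p(math.exp(-abs(x)))


def ppo_clipped_loss(logp_new: Sequence[float], logp_old: Sequence[float],
                     advantages: Sequence[float], clip_eps: float = 0.2) -> float:
    """Negative PPO clipped surrogate objective averaged over samples (minimize this)."""
    total = 0.0
    for new, old, adv in zip(logp_new, logp_old, advantages):
        ratio = math.exp(new - old)  # pi_new(a|s) / pi_old(a|s)
        clipped = min(max(ratio, 1.0 - clip_eps), 1.0 + clip_eps)
        total += min(ratio * adv, clipped * adv)  # pessimistic (clipped) bound
    return -total / len(advantages)


def value_loss(values: Sequence[float], returns: Sequence[float]) -> float:
    """Mean squared error between the critic's predictions and observed returns."""
    return sum((v - r) ** 2 for v, r in zip(values, returns)) / len(values)


def discriminator_bce_loss(real_logits: Sequence[float],
                           fake_logits: Sequence[float]) -> float:
    """BCE with real images labeled 1 and generated images labeled 0.

    -log(sigmoid(x)) == softplus(-x) and -log(1 - sigmoid(x)) == softplus(x).
    """
    loss_real = sum(softplus(-x) for x in real_logits) / len(real_logits)
    loss_fake = sum(softplus(x) for x in fake_logits) / len(fake_logits)
    return loss_real + loss_fake


def adversarial_reward(fake_logit: float) -> float:
    """Hypothetical reward: log D(x) = -softplus(-logit), higher when D judges x as real."""
    return -softplus(-fake_logit)
```

In an alternating schedule like the one described, each round would take 5 gradient steps on the PPO and value losses for the policy network, then 5 steps on the discriminator loss for the reward model; because the discriminator keeps moving, the reward the policy chases keeps changing as well.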

Key Experimental Results

Dataset | Model | Steps | TFLOPs | FID-50K | Baseline for comparison (FID-50K)
ImageNet-256 | AdaNAT-S | 4 | 0.2 | 4.54 | MaskGIT (8 steps): 6.18
ImageNet-256 | AdaNAT-S | 8 | 0.3 | 3.71 | MaskGIT-RS (8 steps): 4.02
ImageNet-256 | AdaNAT-L | 4 | 0.5 | 3.63 | U-ViT-H† (4 steps): 8.45
ImageNet-256 | AdaNAT-L | 8 | 0.9 | 2.86 | DiT-XL† (8 steps): 5.18
ImageNet-512 | AdaNAT-L | 8 | 1.2 | 3.66 | ADM-G: 7.72
MS-COCO | AdaNAT-S | 8 | 0.3 | 5.75 | U-ViT† (8 steps): 6.37
CC3M | AdaNAT-Muse | 8 | 2.8 | 6.83 | Muse (8 steps): 7.67

Ablation Study

  • Contribution of Learnability and Adaptivity: Replacing the hand-crafted policy with a learnable (non-adaptive) policy lowers FID from 7.65 to 5.40 (-30%); adding adaptivity reduces it further to 4.54 (-16%). The total relative improvement is approximately 40%.
  • Adversarial Reward vs. Alternatives: The FID reward causes adaptive policy training to collapse (FID 55.4); the pre-trained reward model suffers from poor image diversity; the adversarial reward obtains the best trade-off between quality and diversity.
  • Policy Network Overhead: Accounts for only 0.03% of the total inference cost, which is completely negligible.
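The ablation percentages follow directly from the three reported FIDs. The small sketch below recomputes them; it also shows that the two gains compound multiplicatively (\(1 - (1 - 0.294)(1 - 0.159) \approx 0.407\)), which is why roughly 30% and 16% combine to about 40% rather than 46%.

```python
def relative_reduction(before: float, after: float) -> float:
    """Relative FID reduction; positive means the second setting is better."""
    return (before - after) / before


# FID values reported in the ablation above
handcrafted, learnable, adaptive = 7.65, 5.40, 4.54

step1 = relative_reduction(handcrafted, learnable)  # hand-crafted -> learnable
step2 = relative_reduction(learnable, adaptive)     # learnable -> adaptive
total = relative_reduction(handcrafted, adaptive)   # hand-crafted -> adaptive

print(f"{step1:.1%} {step2:.1%} {total:.1%}")  # 29.4% 15.9% 40.7%
# The step-wise gains compound: (1 - step1) * (1 - step2) == 1 - total
print(abs((1 - step1) * (1 - step2) - (1 - total)) < 1e-12)  # True
```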

Highlights & Insights

  • "Optimize the User, Not the Generator" Paradigm: Instead of modifying the pre-trained NAT model, this approach learns how to use it better. This efficient post-hoc optimization paradigm could generalize to other generative models that require complex runtime policy configuration.
  • Combination of RL and Adversarial Reward: Countering policy overfitting with a dynamic reward signal offers useful insights for RL-based optimization of generative models in general (e.g., RLHF).
  • Visualization of Adaptive Policies: The paper clearly illustrates how the policy network behaves differently on samples of varying complexity (converging early with only fine adjustments for simple images, while continuing to make substantial adjustments for complex ones), supporting the rationale for adaptive policies.
  • Extremely Lightweight: The policy network reuses NAT features and adds only 0.03% extra overhead, making it an essentially free, plug-and-play enhancement.

Limitations & Future Work

  • The authors note that scalability has not yet been validated on ultra-large datasets (such as LAION-5B) and very large models (>1B parameters).
  • The training stability of the adversarial reward model might vary across different datasets/model scales, and a deeper analysis is lacking.
  • The validation is restricted to class-conditional and text-to-image generation, without exploring other generative tasks (e.g., image editing, video generation).
  • The adversarial reward model adopts StyleGAN-T's discriminator architecture, leaving the impact of more diverse discriminator designs unexplored.
  • There is a lack of comprehensive and fair comparison with the contemporary AutoNAT (CVPR 2024)—the core ideas are similar, but the reward designs differ.

Related Work & Insights

  • vs MaskGIT/MAGE: These works use hand-crafted scheduling functions such as the cosine schedule. AdaNAT shows that these designs are severely sub-optimal (replacing them yields roughly a 40% relative FID reduction) and cannot adapt at the sample level. AdaNAT can be applied directly on top of these models as a post-processing tool.
  • vs AutoNAT (CVPR 2024): Both optimize NAT generation policies, but AutoNAT uses FID as the optimization target. AdaNAT shows that optimizing FID leads to poor visual quality despite low FID scores, and proposes the adversarial reward as a better alternative. Quantitatively, AdaNAT trained with the FID reward reaches a lower FID than AutoNAT (2.56 vs. 2.68), while AdaNAT with the adversarial reward (FID 2.86) is superior in visual quality and diversity.
  • vs RL Optimization for Diffusion Models (DPOK, DDPO, etc.): These methods directly fine-tune diffusion models to align with human preferences. In contrast, AdaNAT does not modify the generator but only optimizes the policy configuration—these two paradigms are orthogonal and applicable to different generative models.
  • Comparison Insights with EVATok: EVATok adopts a two-step paradigm of "offline optimal estimation followed by training a router to mimic it" to allocate adaptive lengths for video tokens, whereas AdaNAT uses RL to learn the optimal policy end-to-end. Each paradigm has trade-offs: EVATok is more stable but requires an offline search, whereas AdaNAT is more flexible but depends heavily on reward design.
  • Unified Framework of RL-optimized Generation Policies: The core framework of AdaNAT (frozen generator + RL policy network + adversarial reward) could be extended to other generative models that require complex inference-time configuration adjustments, such as adaptive steps, CFG, or sampler selection in diffusion models.

Rating

  • Novelty: ⭐⭐⭐⭐ Modeling the NAT policy configuration as an MDP and optimizing it with RL is not entirely new, but the insights on adversarial reward design and the systematic comparison of reward designs are highly valuable.
  • Experimental Thoroughness: ⭐⭐⭐⭐ Four datasets cover class-conditional and text-to-image generation with detailed ablations and analyses, although large-scale validation is lacking.
  • Writing Quality: ⭐⭐⭐⭐⭐ The problem motivation is clear, and the progressive analysis of the three reward designs reads like an engaging story.
  • Value: ⭐⭐⭐⭐ The post-processing optimization paradigm of "modifying the policy without changing the model" and the adversarial reward design both exhibit good generalizability/transferability.