ToMA: Token Merge with Attention for Diffusion Models¶
Conference: ICML2025
arXiv: 2509.10918
Code: github.com/WenboLuu/ToMA
Area: Image Generation
Keywords: token merging, diffusion model acceleration, submodular optimization, GPU-aligned efficiency, training-free
TL;DR¶
This paper proposes ToMA, which reformulates token merging as a submodular optimization problem and implements merge/unmerge via attention-like linear transformations. This makes it compatible with GPU-optimized schemes like FlashAttention, achieving actual end-to-end speedups of 24% and 23% on SDXL and Flux, respectively, with negligible image quality degradation (DINO \(\Delta < 0.07\)).
Background & Motivation¶
Background¶
Diffusion models have made breakthrough progress in high-fidelity image generation. However, the self-attention in their Transformer backbones has \(O(N^2)\) cost in the number of tokens \(N\), so inference latency grows quadratically with the token count, and this cost is paid again at every step of the multi-step denoising process.
Limitations of Prior Work¶
Existing acceleration methods can be divided into two main categories:
Attention Optimization: Methods such as FlashAttention and xFormers optimize attention computation through hardware-aware memory access patterns, which are already approaching hardware efficiency limits.
Token Reduction: Plug-and-play methods like ToMeSD and ToFu reduce FLOPs by merging redundant tokens. However, their merge/unmerge operations rely on GPU-unfriendly primitives such as sorting and scatter-writes. When these methods are combined with highly efficient attention implementations like FlashAttention, the added overhead offsets the theoretical speedup.
Key Insight¶
When attention computation is optimized close to hardware limits by FlashAttention, the bottleneck shifts from attention computation to the token merging operation itself. At this stage, the merging overhead of prior methods (e.g., ToMeSD) becomes dominant, making actual acceleration unachievable. This constitutes a gap between theoretical FLOPs reduction and actual wall-clock speedup.
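The gap between FLOPs reduction and wall-clock speedup can be illustrated with a toy cost model. This is only a sketch under stated assumptions: the function name `end_to_end_speedup`, the split into attention time and other per-token time, and every number in the usage note are hypothetical and are not measurements from the paper.

```python
def end_to_end_speedup(t_attn, t_other, keep_ratio, t_merge):
    """Toy model: baseline time divided by time with token reduction.

    Assumptions (hypothetical, not from the paper):
      - attention time scales with keep_ratio ** 2,
      - other per-token work (e.g. MLP) scales linearly with keep_ratio,
      - t_merge is the extra time spent on merge + unmerge.
    """
    t_base = t_attn + t_other
    t_reduced = t_attn * keep_ratio ** 2 + t_other * keep_ratio + t_merge
    return t_base / t_reduced
```

With illustrative inputs `t_attn=4`, `t_other=6`, and `keep_ratio=0.5`, a free merge (`t_merge=0`) gives a 2.5x speedup in this model, while `t_merge=6` brings it back to exactly 1.0x. The faster the attention kernel already is, the smaller the merge overhead that is enough to erase the gain.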
Method¶
Overall Architecture¶
ToMA (Token Merge with Attention) is a training-free plug-and-play framework that consists of three core stages:
- Facility Location Algorithm (Destination Token Selection): Selects a representative subset of \(D\) destination tokens from all \(N\) tokens (\(D < N\)) via submodular optimization to maximize representation diversity.
- Attention-based Merge: Constructs a low-rank attention matrix to transform \(N \rightarrow D\) through linear operations, executing Self-Attention, Cross-Attention, and MLP within the reduced space.
- Inverse Unmerge: Recovers full-resolution features by transforming \(D \rightarrow N\) via a pseudo-inverse operation.
The entire workflow is efficiently executed through localized processing (local windows) and parallel batch optimization.
Key Designs¶
Design 1: Submodular Optimization-driven Token Selection¶
ToMA models destination-token selection as a Facility Location Problem, a classic submodular maximization problem. The objective is monotone and satisfies diminishing returns, so greedy selection under a budget of \(D\) tokens achieves at least a \((1 - 1/e)\) fraction of the optimal objective value. The selected destination tokens therefore represent the full token set near-optimally with respect to this objective. Compared to the heuristic matching of ToMeSD, this design provides theoretical guarantees.
A submodular function is defined on the subsets of a ground set \(V\), satisfying: for any \(A \subseteq B \subseteq V\) and element \(v \in V \setminus B\), we have \(f(v|A) \ge f(v|B)\), where \(f(v|A) = f(A \cup \{v\}) - f(A)\) is the marginal gain.
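As a concrete instance, the standard facility location function is \(f(S) = \sum_{i \in V} \max_{j \in S} \mathrm{sim}(i, j)\), which is monotone submodular when similarities are non-negative. The following is a minimal greedy sketch, not the authors' implementation: the function name `facility_location_greedy`, the choice of cosine similarity shifted into \([0, 1]\), and the dense \(N \times N\) similarity matrix are all assumptions made for illustration.

```python
import numpy as np

def facility_location_greedy(X, k, eps=1e-12):
    """Greedily pick k row indices of X (shape (N, C)) that maximize
    f(S) = sum_i max_{j in S} sim(i, j). Assumes k <= N."""
    # Assumed similarity: cosine similarity shifted into [0, 1],
    # which keeps f monotone with f(empty set) = 0.
    norms = np.linalg.norm(X, axis=1, keepdims=True)
    Xn = X / np.maximum(norms, eps)
    sim = (Xn @ Xn.T + 1.0) / 2.0            # shape (N, N)

    n = sim.shape[0]
    best = np.zeros(n)                        # best[i] = max_{j in S} sim(i, j)
    chosen = np.zeros(n, dtype=bool)
    selected = []
    for _ in range(k):
        # gains[j] = f(S + {j}) - f(S), computed for every candidate j at once
        gains = np.maximum(sim, best[:, None]).sum(axis=0) - best.sum()
        gains[chosen] = -np.inf               # never pick the same token twice
        j = int(np.argmax(gains))
        selected.append(j)
        chosen[j] = True
        best = np.maximum(best, sim[:, j])
    return selected
```

For example, on four tokens forming two tight clusters, `facility_location_greedy(X, 2)` returns one index from each cluster, because a second pick from an already covered cluster has almost no marginal gain.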
Design 2: Attention-like Linear Transformation for Merge/Unmerge¶
- Merge: Constructs an attention weight matrix \(M \in \mathbb{R}^{D \times N}\) to aggregate tokens via matrix multiplication, which is equivalent to the linear transformation of the attention mechanism.
- Unmerge: Recovers the original token dimension using the pseudo-inverse \(M^+\) of \(M\).
This design makes full use of batched matrix multiplication (batched matmul) on the GPU, avoiding GPU-unfriendly operations like sorting and scatter-writes, and is fully compatible with FlashAttention.
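The merge and unmerge steps can be sketched with plain matrix products. This is a hedged illustration only: the helper names, the cosine-similarity scores, the temperature `tau`, and the row normalization are assumptions, and the paper's exact construction of \(M\) may differ. The sketch uses `np.linalg.pinv` for \(M^+\).

```python
import numpy as np

def softmax(z, axis):
    z = z - z.max(axis=axis, keepdims=True)   # subtract max for numerical stability
    e = np.exp(z)
    return e / e.sum(axis=axis, keepdims=True)

def build_merge_matrix(X, dest_idx, tau=0.1):
    """Return M of shape (D, N) from tokens X of shape (N, C).

    Assumed construction: cosine-similarity scores between destination
    tokens and all tokens, softmax over destinations, then row
    normalization so each merged token is a weighted average."""
    Xn = X / np.linalg.norm(X, axis=1, keepdims=True)
    scores = Xn[dest_idx] @ Xn.T / tau         # (D, N)
    A = softmax(scores, axis=0)                # each source token spreads weight over D
    return A / A.sum(axis=1, keepdims=True)    # rows sum to 1

def merge(M, X):
    return M @ X                               # (D, N) @ (N, C) -> (D, C)

def unmerge(M, Y):
    return np.linalg.pinv(M) @ Y               # (N, D) @ (D, C) -> (N, C)
```

Both directions are dense matmuls with no sorting or scatter-writes, which is the property the design relies on. When \(M\) has full row rank, \(M M^+ = I\), so merging an unmerged result returns the same \(D\) tokens; the unmerged \(N\) tokens are in general only an approximation of the originals.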
Design 3: Exploiting DMs' Inherent Properties to Reduce Overhead¶
ToMA leverages two inherent properties of diffusion models to further reduce computational overhead:
| Property | Description | Acceleration Strategy |
|---|---|---|
| Latent Space Locality | Tokens in the latent space exhibit spatial local correlation | Parallelize merge operations within non-overlapping local windows (e.g., \(8 \times 8\) patches) to reduce optimization scale |
| Sequential Redundancy | Merge patterns are highly similar across adjacent denoising steps and consecutive Transformer layers | Reuse merge patterns across time steps and layers to amortize optimization overhead |
These two strategies significantly decrease the invocation frequency of the Facility Location algorithm, which, combined with local window partitioning, keeps the token scale of each optimization controllable.
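Both strategies can be sketched in a few lines. This is a minimal illustration under assumptions: tokens are taken to be stored in row-major order over an \(H \times W\) latent grid with \(H\) and \(W\) divisible by the window size, and the `MergeCache` class with its `refresh_every` parameter is a hypothetical stand-in for whatever reuse schedule the paper actually uses.

```python
import numpy as np

def window_partition(x, H, W, w):
    """(H*W, C) row-major tokens -> (num_windows, w*w, C) non-overlapping windows."""
    C = x.shape[-1]
    x = x.reshape(H // w, w, W // w, w, C)
    x = x.transpose(0, 2, 1, 3, 4)             # group the two window indices first
    return x.reshape(-1, w * w, C)

def window_reverse(xw, H, W, w):
    """Inverse of window_partition: (num_windows, w*w, C) -> (H*W, C)."""
    C = xw.shape[-1]
    x = xw.reshape(H // w, W // w, w, w, C)
    x = x.transpose(0, 2, 1, 3, 4)
    return x.reshape(H * W, C)

class MergeCache:
    """Recompute a merge pattern only every `refresh_every` steps (hypothetical schedule)."""
    def __init__(self, refresh_every):
        self.refresh_every = refresh_every
        self.value = None

    def get(self, step, compute_fn):
        if self.value is None or step % self.refresh_every == 0:
            self.value = compute_fn()          # expensive selection runs here
        return self.value                      # otherwise reuse the cached pattern
```

After partitioning, the window axis acts as a batch dimension, so selection and merging run over many small \(w^2\)-token problems in parallel instead of one \(N\)-token problem. With `refresh_every=5`, ten denoising steps trigger only two selection calls.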
Core Differences from Prior Work¶
| Dimension | ToMeSD | ToFu | ToMA |
|---|---|---|---|
| Token Selection | Heuristic greedy matching | Switch merge/prune based on linearity test | Submodular optimization (theoretical guarantee) |
| Merge Operation | Unweighted average + sorting | Same as ToMeSD | Attention-like matrix multiplication |
| Unmerge Operation | Copy destination embedding | Same as above | Pseudo-inverse linear transformation |
| GPU Friendliness | Low (sorting, scatter-writes) | Low | High (batched matmul) |
| Compatibility with FlashAttention | Overhead cancels speedup benefit | Similar | Fully compatible, achieving actual acceleration |
| Theoretical Guarantee | None | None | Submodular approximation ratio \((1-1/e)\) |
Key Experimental Results¶
Main Results¶
| Model | Method | Speedup | DINO \(\Delta\) | Note |
|---|---|---|---|---|
| SDXL-base | Original | 1.00\(\times\) | — | baseline |
| SDXL-base | +FlashAttention2 | — | — | Attention optimized |
| SDXL-base | +ToMeSD (on FA2) | \(\approx\)1.00\(\times\) | — | Overhead cancels speedup, no actual gain |
| SDXL-base | +ToMA (ratio=0.5, on FA2) | 1.24\(\times\) | <0.07 | 24% actual end-to-end speedup |
| Flux.1-dev | +ToMA | 1.23\(\times\) | <0.07 | 23% actual speedup |
Cross-GPU Architecture Validation¶
| GPU Architecture | Supported | Note |
|---|---|---|
| NVIDIA RTX 6000 | ✅ | Main experimental platform |
| NVIDIA V100 | ✅ | Evaluates cross-generational compatibility |
| NVIDIA RTX 8000 | ✅ | Evaluates cross-generational compatibility |
ToMA is reported to deliver speedups on all three GPUs listed above, suggesting that its GPU-aligned design generalizes well across hardware generations.
Key Findings¶
- ToMeSD Failure Case: When paired with FlashAttention2, the sorting and scatter-write overhead of ToMeSD dominates the computation time, resulting in actual speeds that can be even lower than the baseline FA2 without token reduction.
- Quality Retention: At a merge ratio of 0.5 (i.e., reducing 50% of tokens), ToMA changes the DINO score by \(< 0.07\), yielding almost imperceptible degradation in image quality.
- Actual vs. Theoretical Acceleration: A key finding emphasized in this work is that theoretical FLOPs reduction does not equate to actual wall-clock speedup. The coordination between system-level optimization and algorithm design is of paramount importance.
Highlights & Insights¶
- Profound Problem Insight: The authors accurately identify the overlooked key issue of "theoretical FLOPs reduction \(\ne\) actual speedup", pointing out that once attention is optimized by FlashAttention, the GPU-unfriendliness of merging operations becomes the new bottleneck.
- Algorithm-System Co-design: This work tightly couples the theoretical guarantees of submodular optimization with GPU execution paradigms (batched matmul, local window parallelism), achieving a complete loop from theory to practice.
- Training-free, Plug-and-play: It requires no retraining and can be directly embedded into existing diffusion model inference pipelines, demonstrating high practicality.
- Solid Theoretical Foundation: Based on the \((1 - 1/e)\) approximation guarantee of submodular optimization, the method boasts a more solid theoretical foundation compared to heuristic methods.
- Dual Redundancy Exploitation: The design leverages two inherent properties of diffusion models, latent space locality and sequential redundancy, to amortize and thereby minimize merge overhead.
Limitations & Future Work¶
- Limited Cached Content: The full-text cache of the paper only covers up to the Preliminaries section. Detailed ablation studies and broader quantitative comparisons were unavailable, which might omit some experimental details.
- Fixed Merge Ratio: The ratio of \(0.5\) shown in the paper is static; whether it supports adaptive dynamic scaling is not fully discussed.
- Single Evaluation Metric: It primarily relies on the DINO score to evaluate image quality, lacking more comprehensive generation quality evaluations such as FID or CLIP scores (which could be in the full paper but were missing from the cache).
- Local Window Partitioning: It remains to be explored whether the fixed \(8 \times 8\) window size is optimal, or if adaptive window size adjustment is needed under different resolutions.
- Video Diffusion Models: The method is currently validated only on image diffusion models. Its extensibility to video diffusion models (e.g., Sora-like architectures) is worth exploring.
- Combination with Quantization/Distillation: While the paper mentions that ToMA is compatible with orthogonal approaches, the combined effects with compression techniques like quantization and knowledge distillation are not deeply explored.
Related Work & Insights¶
Efficient Vision Transformers¶
- Compact Architectures: Architectures like Swin Transformer and PVT reduce complexity through structural design, but require retraining.
- Pruning Strategies: Methods like X-Pruner accelerate inference by removing redundant structures, but require post-training fine-tuning.
- Knowledge Distillation: Methods like DeiT transfer knowledge from large models to smaller ones.
- Post-Training Quantization: Lowers weight/activation precision to reduce computation.
Token Reduction Methods¶
- Learned Methods: DynamicViT (generating pruning masks via MLP) and A-ViT (halting probability) require extra training.
- Heuristic Methods: ATS (reliant on class tokens, not applicable to generative tasks), ToDo (only downsampling KV, with limited effect), ToMeSD (GPU-unfriendly greedy matching), and ToFu (dynamically switching between merge and prune).
Insights¶
The core insight of ToMA is that the design of acceleration methods must align with the underlying hardware execution model. If theoretical FLOPs reduction relies on GPU-unfriendly operations, it can actually backfire in engineering practice. This approach holds significant reference value for token reduction design in other domains, such as LLM inference acceleration and video model acceleration.
Rating¶
- Novelty: ⭐⭐⭐⭐ Reformulates token merging as a submodular optimization and implements it with attention-like transformations, presenting a fresh perspective.
- Experimental Thoroughness: ⭐⭐⭐⭐ Validated across multiple models and GPU architectures, though complete ablation data is missing in the cache.
- Writing Quality: ⭐⭐⭐⭐ Motivation is clearly articulated, and the theory, system design, and experiments form a cohesive whole.
- Value: ⭐⭐⭐⭐⭐ Resolves the practical engineering pain point of incompatibility between token reduction and highly efficient attention, offering high practical value.