HATA: Trainable and Hardware-Efficient Hash-Aware Top-k Attention for Scalable Large Model Inference¶
Conference: ACL 2025
arXiv: 2506.02572
Code: Yes (https://github.com/gpzlx1/HATA)
Area: Others
Keywords: Top-k Attention, Learning-to-Hash, KVCache, LLM Inference, Hardware-Efficient
TL;DR¶
HATA integrates learning-to-hash into the top-k attention mechanism. It maps queries and keys to binary hash codes whose Hamming distances preserve the relative ranking of qk scores, instead of estimating the absolute scores. This makes top-k selection cheap and yields up to a 7.2x speedup over full attention while maintaining model accuracy.
Background & Motivation¶
Even with KVCache acceleration, the attention module remains a critical bottleneck in LLM inference: on 32K-token sequences, attention accounts for over 70% of Llama2-7B's execution time. The bottleneck has two sources:
- Computational Complexity: Attention computation scales quadratically with sequence length.
- Memory Bandwidth: Every decoding step must load all cached Key and Value vectors.
Top-k attention exploits the sparsity of attention distributions, keeping only the \(k\) most relevant tokens to compute attention. However, existing methods suffer from an efficiency-accuracy trade-off:
- Low-rank methods (Loki, InfiniGen): Dot products are computed on a projected dimension subset, but channel extraction is expensive, and maintaining accuracy requires a sufficient number of dimensions.
- Block-based methods (Quest, InfLLM): Divide KV into blocks and estimate the upper bound of block-level scores, but the granularity is too coarse—essential tokens are scattered across different blocks, and selecting whole blocks unnecessarily loads irrelevant KV pairs.
Core Insight: Existing methods assume that the absolute qk scores must be estimated precisely, but top-k selection only needs their relative ordering. Deciding whether \(s_{qk_i} > s_{qk_j}\) is much cheaper than computing the scores exactly, which turns the problem from numerical regression into ordinal comparison.
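A toy illustration of why ordering is what matters for top-k selection (the scores and both estimators below are made up for illustration, not taken from the paper): an estimator with a large absolute error that preserves order selects the exact top-k, while one with a much smaller absolute error that swaps two boundary candidates does not.

```python
def top_k_indices(scores, k):
    """Return the set of indices of the k largest scores."""
    order = sorted(range(len(scores)), key=lambda i: scores[i], reverse=True)
    return set(order[:k])


def max_abs_error(estimate, exact):
    """Largest absolute deviation between an estimate and the exact scores."""
    return max(abs(e - x) for e, x in zip(estimate, exact))


exact = [0.90, 0.52, 0.50, 0.10, -0.30]  # hypothetical exact qk scores
k = 2

# Estimator A: off by several units everywhere, but strictly increasing in the exact score.
rank_preserving = [3.0 * s + 5.0 for s in exact]
# Estimator B: within about 0.03 of every exact score, but swaps the two boundary candidates.
value_close = [0.91, 0.49, 0.51, 0.11, -0.29]

print(top_k_indices(exact, k))            # {0, 1}
print(top_k_indices(rank_preserving, k))  # {0, 1}: same selection despite the large error
print(top_k_indices(value_close, k))      # {0, 2}: wrong selection despite the small error
```

Any strictly increasing transform of the scores leaves the top-k set unchanged, so an estimator only has to get the comparisons right near the selection boundary.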
Method¶
Overall Architecture¶
HATA operates in three phases:
1. Offline Training: Learns hash functions from real data so that similar qk pairs remain close in Hamming space.
2. Prefill Phase: Computes standard attention while encoding Keys into binary hash codes and caching them.
3. Decode Phase: Encodes the new query into a hash code \(\rightarrow\) selects the top-k Keys by Hamming distance \(\rightarrow\) computes exact attention only over the selected KV pairs.
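Why a Hamming distance can stand in for a score ranking: the standard identity below (added here for illustration, not quoted from the paper) writes sign codes such as \(b_q = \operatorname{sign}(q W_H)\) as \(\pm 1\) vectors of length \(r\). Each agreeing bit contributes \(+1\) and each disagreeing bit \(-1\) to their inner product, so

\[
b_q^\top b_k = \bigl(r - d_H(b_q, b_k)\bigr) - d_H(b_q, b_k) = r - 2\,d_H(b_q, b_k), \qquad b_q, b_k \in \{-1, +1\}^r .
\]

Ranking cached keys by ascending Hamming distance is therefore exactly ranking them by descending code inner product, and with bit-packed codes \(d_H(b_q, b_k) = \operatorname{popcount}\bigl(\operatorname{pack}(b_q) \oplus \operatorname{pack}(b_k)\bigr)\).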
Key Designs¶
- Learning-to-Hash Modeling:
The hash function is defined as \(h(x) = 2 \cdot \text{Sigmoid}(\sigma \cdot x W_H) - 1\), where \(W_H\) is a trainable hash weight matrix.
The optimization objective comprises three terms:
- Similarity Preservation: \(\epsilon \sum_j \sum_i s_{j,i} \|h(q_j) - h(k_{j,i})\|^2\), which pulls the codes of positive pairs (\(s_{j,i} > 0\)) together and pushes those of negative pairs (\(s_{j,i} = -1\)) apart, so similar qk pairs receive neighboring hash codes.
- Bit Balance: \(\eta \sum_j \|\sum_i h(k_{j,i})\|^2\), which keeps each bit roughly balanced between \(+1\) and \(-1\).
- Decorrelation: \(\lambda \|W_H^T W_H - I_r\|\), which encourages different hash bits to encode diverse information.
The Sigmoid function approximates the non-differentiable sign function; \(\sigma\) controls the smoothness of the approximation, and \(h(x) \to \operatorname{sign}(x W_H)\) as \(\sigma \to \infty\). A separate \(W_H\) is trained for each attention head.
- Training Data Construction:
- Q and K are collected from each attention head during the prefill phase.
- Sample a query \(q_j\) and compute its qk scores against the keys \([k_1, \dots, k_j]\).
- The top 10% of qk pairs are labeled as positive samples (linearly decaying labels \(s \in [1,20]\)), while the remaining 90% are negative samples (\(s = -1\)).
- Each sequence yields thousands to millions of training pairs, and dozens of sequences are used to increase diversity.
- Decode Phase Algorithm:
- Perform `HashEncode` (matrix multiplication + Sign + BitPack) on the new query and key.
- Append the key hash codes to the hash code cache.
- Calculate the Hamming distance between the query hash code and all cached key hash codes: `bitcount(bitwise_xor(Q_H, K_H_cache))`.
- In Grouped-Query Attention (GQA) scenarios, aggregate scores across the multiple queries that share the same KVCache.
- Select the top-k, gather the corresponding KV pairs, and execute sparse attention.
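The training-side designs above (relaxed hash, three-term objective, and positive/negative pair labels) can be sketched in pure Python as follows. This is a minimal sketch under stated assumptions, not the repository's code: the function names are hypothetical, labels are assumed to decay linearly by rank from 20 (best) to 1, the positive count is rounded up, and the decorrelation norm is assumed to be Frobenius. It also uses the exact identity \(2\,\text{Sigmoid}(z) - 1 = \tanh(z/2)\) to avoid overflow.

```python
import math


def make_labels(q, keys, top_percent=10, s_max=20.0, s_min=1.0):
    """Label the pairs (q, k_i): the top `top_percent`% by qk score are positives whose
    labels decay linearly by rank from s_max to s_min (assumed); the rest get -1."""
    scores = [sum(a * b for a, b in zip(q, k)) for k in keys]
    order = sorted(range(len(keys)), key=lambda i: scores[i], reverse=True)
    n_pos = max(1, (len(keys) * top_percent + 99) // 100)  # ceiling; rounding rule assumed
    labels = [-1.0] * len(keys)
    for rank, i in enumerate(order[:n_pos]):
        frac = rank / (n_pos - 1) if n_pos > 1 else 0.0
        labels[i] = s_max - (s_max - s_min) * frac
    return labels


def relaxed_hash(x, W, sigma):
    """h(x) = 2 * Sigmoid(sigma * x W_H) - 1 per bit (computed as tanh); W is d x r."""
    proj = [sum(x[t] * W[t][c] for t in range(len(x))) for c in range(len(W[0]))]
    return [math.tanh(0.5 * sigma * p) for p in proj]


def sq_norm(v):
    return sum(a * a for a in v)


def hash_loss(queries, keys, labels, W, sigma, eps=1.0, eta=1.0, lam=1.0):
    """Three-term objective; keys[j] and labels[j] hold the keys and s_{j,i} of queries[j]."""
    r = len(W[0])
    sim = bal = 0.0
    for q, ks, ss in zip(queries, keys, labels):
        hq = relaxed_hash(q, W, sigma)
        hks = [relaxed_hash(k, W, sigma) for k in ks]
        # Similarity preservation: s > 0 pulls codes together, s = -1 pushes them apart.
        sim += sum(s * sq_norm([a - b for a, b in zip(hq, hk)]) for s, hk in zip(ss, hks))
        # Bit balance: each bit should sum to roughly zero over this query's keys.
        bal += sq_norm([sum(hk[c] for hk in hks) for c in range(r)])
    # Decorrelation: W_H^T W_H should be close to I_r (Frobenius norm assumed).
    gram = [[sum(row[a] * row[b] for row in W) for b in range(r)] for a in range(r)]
    dec = math.sqrt(sum((gram[a][b] - (1.0 if a == b else 0.0)) ** 2
                        for a in range(r) for b in range(r)))
    return eps * sim + eta * bal + lam * dec, (sim, bal, dec)


# Toy usage with hypothetical values: 2-dim vectors, r = 2 bits, orthonormal W_H.
W = [[1.0, 0.0], [0.0, 1.0]]
q = [1.0, -1.0]
keys = [[1.0, -1.0], [-1.0, 1.0]]
labels = make_labels(q, keys)  # one positive (20.0) and one negative (-1.0)
total, (sim, bal, dec) = hash_loss([q], [keys], [labels], W, sigma=4.0)
```

In a real setup \(W_H\) would be optimized with a gradient-based framework over many sampled queries per head; the sketch only shows how each term is evaluated.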
Hardware-Efficient Optimizations¶
- Kernel Fusion: Fuses linear projection \(\rightarrow\) Sign \(\rightarrow\) BitPack \(\rightarrow\) cache update into a single CUDA kernel, avoiding frequent CPU-GPU synchronization.
- High-Performance Hamming Score Operator: Implements bit counting with GPU `XOR` + `popc`/`popcll` instructions, optimizing bandwidth through coalesced memory access.
- Gather + FlashAttention Fusion: Integrates the Gather operation into the FlashAttention kernel, reducing redundant data transfer between HBM and SRAM.
Implementation Code: 1,470 lines of C++/CUDA + 940 lines of Python, integrated into existing inference frameworks as a plug-in.
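To make the decode-phase algorithm concrete, here is an unfused, pure-Python reference sketch of the data flow that the fused kernels implement. It is not the repository's code: the function names, the `>= 0` sign convention, and summation as the GQA aggregation rule are assumptions. Python's arbitrary-precision `int` stands in for the packed bit words, and `bin(...).count("1")` stands in for the `popc`/`popcll` instructions.

```python
import math


def hash_encode(x, W):
    """Sign(x W_H) bit-packed into an int: bit c is 1 when the c-th projection is >= 0.
    (A GPU kernel would store r = 128 bits as two 64-bit words.)"""
    code = 0
    for c in range(len(W[0])):
        if sum(x[t] * W[t][c] for t in range(len(x))) >= 0:
            code |= 1 << c
    return code


def hamming(a, b):
    """Hamming distance between packed codes: popcount(a XOR b)."""
    return bin(a ^ b).count("1")


def decode_step(queries, new_key, new_value, W, key_codes, keys, values, k):
    """One decode step for a group of query heads that share one KV cache (GQA).
    key_codes, keys and values are the caches; they are updated in place."""
    # 1. HashEncode the new key and append it, with its KV pair, to the caches.
    key_codes.append(hash_encode(new_key, W))
    keys.append(new_key)
    values.append(new_value)
    # 2. Hamming distance from each query code to every cached key code,
    #    aggregated over the group by summation (aggregation rule assumed).
    q_codes = [hash_encode(q, W) for q in queries]
    dist = [sum(hamming(qc, kc) for qc in q_codes) for kc in key_codes]
    # 3. Top-k = the k smallest aggregated distances (ties keep cache order).
    selected = sorted(range(len(key_codes)), key=lambda i: dist[i])[:k]
    # 4. Gather the selected KV pairs and run exact softmax attention on them only.
    outputs = []
    for q in queries:
        scale = 1.0 / math.sqrt(len(q))
        logits = [scale * sum(a * b for a, b in zip(q, keys[i])) for i in selected]
        m = max(logits)
        weights = [math.exp(x - m) for x in logits]
        z = sum(weights)
        outputs.append([sum(w * values[i][c] for w, i in zip(weights, selected)) / z
                        for c in range(len(values[0]))])
    return selected, outputs


# Toy usage with hypothetical values: "prefill" three keys, then decode one token.
W = [[1.0, 0.0], [0.0, 1.0]]
key_codes, keys, values = [], [], []
for k_vec, v_vec in [([1.0, 1.0], [1.0]), ([-1.0, -1.0], [2.0]), ([1.0, -1.0], [3.0])]:
    key_codes.append(hash_encode(k_vec, W))
    keys.append(k_vec)
    values.append(v_vec)
selected, outputs = decode_step([[2.0, 2.0]], [-1.0, 1.0], [4.0], W,
                                key_codes, keys, values, k=1)
print(selected, outputs)  # [0] [[1.0]]
```

In HATA these stages run as fused CUDA kernels; the sketch only mirrors what is computed, not how memory traffic is organized.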
Key Experimental Results¶
Main Results: LongBench-e Accuracy (512 token budget)¶
| Model (average) | Dense | Loki | Quest | MagicPIG | H2O | SnapKV | HATA |
|---|---|---|---|---|---|---|---|
| Llama-2-7B AVG | 34.47 | 32.78 | 32.64 | 34.09 | 9.57 | 24.96 | 34.60 |
| Llama-3.1-8B AVG | 54.10 | 53.23 | 52.19 | 47.61 | 49.89 | 51.00 | 53.94 |
HATA closely matches Dense (full attention) on both models (34.60 vs. 34.47 on Llama-2-7B, 53.94 vs. 54.10 on Llama-3.1-8B) and even outperforms Dense on certain tasks (e.g., HQA 15.65 vs. 15.30).
NIAH (Needle in a Haystack) Test¶
| Task | Dense | Loki | Quest | HATA |
|---|---|---|---|---|
| NS1 (Llama-2) | 93.75 | 25.00 | 100.0 | 100.0 |
| NS2 (Llama-2) | 100.0 | 2.08 | 95.83 | 98.96 |
| NS3 (Llama-2) | 91.67 | 0.00 | 52.08 | 83.33 |
| NS1 (Llama-3.1) | 100.0 | 98.96 | 100.0 | 98.96 |
| NS3 (Llama-3.1) | 100.0 | 96.88 | 47.92 | 100.0 |
HATA matches or exceeds Dense on most NIAH settings, with its largest gap on NS3 for Llama-2 (83.33 vs. 91.67). Loki collapses on Llama-2, and Quest falls to about 50% on NS3 for both models.
Inference Speedup over Dense (token/s relative to Dense, Llama-3.1-8B, A800)¶
| Sequence Length | Dense | Loki | Quest | HATA |
|---|---|---|---|---|
| 32K | Baseline | ~1.2× | ~1.5× | ~2× |
| 64K | Baseline | ~1.5× | ~2.5× | ~4× |
| 128K | Baseline | ~2× | ~4× | ~7.2× |
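HATA's speedup grows with sequence length, reaching 7.2x at 128K. Under a fixed token budget, exact attention always touches the same number of KV pairs, so the only part of HATA's per-step cost that grows with context is the Hamming scan over compact hash codes, whereas full attention must read every cached Key and Value.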
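A back-of-the-envelope count of bytes read per KV head in one decode step illustrates this trend. It is a rough sketch under assumed parameters (head dimension 128, FP16 K/V, 128-bit codes, a 512-token budget) that ignores compute, kernel overheads, and the rest of the model, so it is not a prediction of the measured speedups.

```python
def attention_bytes_per_step(n_ctx, head_dim=128, elem_bytes=2, r_bits=128, budget=512):
    """Rough bytes read per KV head in one decode step (query and outputs ignored)."""
    kv_bytes_per_token = 2 * head_dim * elem_bytes  # one K and one V vector
    dense = n_ctx * kv_bytes_per_token               # full attention reads every K and V
    # Hash-aware: scan all packed key codes, then read only the selected KV pairs.
    hashed = n_ctx * r_bits // 8 + min(budget, n_ctx) * kv_bytes_per_token
    return dense, hashed


for n in (32 * 1024, 64 * 1024, 128 * 1024):
    dense, hashed = attention_bytes_per_step(n)
    print(f"{n // 1024}K: dense {dense / 2**20:.2f} MiB, "
          f"hash-aware {hashed / 2**20:.2f} MiB, ratio {dense / hashed:.1f}x")
```

Under these assumptions the traffic ratio rises from about 21x at 32K to about 28x at 128K and approaches \(512 / 16 = 32\times\) (bytes of K+V per token over bytes of hash code per token) as the context grows. The speedups in the table are smaller because attention is only part of each decoding step.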
Key Findings¶
- Optimal Accuracy-Efficiency Balance: Among all top-k methods, HATA comes closest to full-attention accuracy while also delivering larger speedups than the other methods.
- Ordinal Comparison is Sufficient: Experiments confirm that only the relative ordering of qk scores is needed (rather than absolute value estimation) to achieve high-quality top-k selection.
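- Negligible Prefill Overhead: The extra overhead of hash encoding is less than 1% of the total computation. Encoding \(n\) keys of head dimension \(d\) costs \(O(n d r)\), while prefill attention costs \(O(n^2 d)\), so the relative overhead is on the order of \(r / n\), which is small because \(r \ll n\) (e.g., \(128 / 32768 \approx 0.4\%\)).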
- Impact of Hash Code Length: For \(r\)-bit hash codes, too small an \(r\) degrades accuracy, while too large an \(r\) reduces efficiency; 128 bits is typically a good trade-off.
- GQA Compatibility: By aggregating scores from multiple queries sharing the same KVCache, HATA naturally supports GQA architectures.
Highlights & Insights¶
- Profound Reformulation of the Problem: Reconceptualizing top-k attention from "precise score estimation" to "ordinal comparison" is the most significant contribution of this work. Absolute scores are indeed unnecessary for top-k selection, and this insight unlocks an entirely new optimization space.
- Perfect Synergy with Hash Technology: Learning-to-hash fits the ordinal comparison problem naturally (the codes are trained so that Hamming distance preserves the qk ordering), and binary operations (`XOR` + popcount) are extremely efficient on GPUs.
- High Engineering Completeness: This work features not only algorithm design but also comprehensive system-level optimizations, including kernel fusion, highly optimized Hamming operators, and FlashAttention integration.
- Plug-and-Play Design: Users can easily adopt it by replacing the standard attention mechanism with HATA attention, lowering the barrier to deployment.
Limitations & Future Work¶
- Requirement of Offline Weight Training: Each layer and head of every model requires training its own \(W_H\). Although training overhead is low, it adds steps to the deployment pipeline.
- Fixed Token Budget Assumption: Current experiments rely on a fixed budget of 512 tokens; an adaptive budget could yield better results.
- Acceleration Only in Decode Phase: The prefill phase still uses full attention, leaving the latency of long-context prefill unoptimized.
- Limited Model Coverage: Evaluation is mainly conducted on Llama-2 and Llama-3.1, leaving MoE models or larger-scale models unvalidated.
Related Work & Insights¶
- SparQ's dimension subsetting approach tackles the same problem from a different angle (numerical precision vs. ranking precision), while HATA's ordinal comparison paradigm is more fundamental.
- Block-based methods like Quest and InfLLM complement token-level methods. A two-level hierarchy could be considered (coarse block-level filtering followed by fine-grained hash-based selection).
- Comparison with Hash-RAG (2505.16133): Both represent applications of learning-to-hash in NLP. However, HATA focuses on intra-attention qk matching, while Hash-RAG targets external knowledge base retrieval.
- Insights for KVCache compression approaches (e.g., H2O, SnapKV): Rather than permanently discarding KV pairs, HATA keeps all of them and selects among them more efficiently, so a token skipped at one decoding step can still be selected at a later one. This avoids the irreversible information loss that eviction causes.
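The two-level hierarchy suggested above could look roughly like the following hypothetical sketch, which is not from the paper. Blocks are first filtered with a Quest-style per-channel min/max upper bound on \(q \cdot k\), and only tokens in the surviving blocks are ranked by Hamming distance between packed hash codes. All function names, parameters, and toy values are illustrative assumptions.

```python
def block_upper_bound(q, block_keys):
    """Quest-style bound: sum_c max(q_c * min_c, q_c * max_c) >= q . k for every k in the block.
    (A real system would precompute the per-channel min/max once per block.)"""
    d = len(q)
    lo = [min(k[c] for k in block_keys) for c in range(d)]
    hi = [max(k[c] for k in block_keys) for c in range(d)]
    return sum(max(q[c] * lo[c], q[c] * hi[c]) for c in range(d))


def two_level_select(q, keys, q_code, key_codes, block_size, n_blocks, k):
    """Hypothetical two-level top-k: keep the n_blocks blocks with the largest bounds,
    then rank only their tokens by Hamming distance between packed hash codes."""
    blocks = [list(range(s, min(s + block_size, len(keys))))
              for s in range(0, len(keys), block_size)]
    bounds = [block_upper_bound(q, [keys[i] for i in b]) for b in blocks]
    kept = sorted(range(len(blocks)), key=lambda b: bounds[b], reverse=True)[:n_blocks]
    candidates = [i for b in kept for i in blocks[b]]
    candidates.sort(key=lambda i: bin(q_code ^ key_codes[i]).count("1"))
    return candidates[:k]


# Toy usage: three blocks of two keys; codes are the sign bits (bit c = key[c] >= 0).
q = [1.0, 1.0]
keys = [[1.0, 1.0], [0.9, 0.8], [-1.0, -1.0], [-0.5, -1.0], [1.0, -1.0], [0.2, 0.1]]
key_codes = [0b11, 0b11, 0b00, 0b00, 0b01, 0b11]
print(two_level_select(q, keys, 0b11, key_codes, block_size=2, n_blocks=2, k=3))  # [0, 1, 5]
```

The block filter bounds how many hash codes must be scanned, while the hash ranking keeps the final selection at token granularity.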
Rating¶
- Novelty: ⭐⭐⭐⭐⭐ — The core insight that "only ordering is needed, not absolute scores" is highly profound. The application of learning-to-hash in LLM attention is pioneering.
- Experimental Thoroughness: ⭐⭐⭐⭐⭐ — Features comprehensive evaluations across 13 tasks in LongBench-e, multi-level difficulty in NIAH, two distinct models, comparisons against multiple baselines, detailed speedups, and thorough ablation analyses.
- Writing Quality: ⭐⭐⭐⭐ — The motivation is clearly articulated, the algorithm pseudocode is comprehensive, and Figures 1/2/3 present high information density. However, equations in the learning-to-hash background section are somewhat dense.
- Value: ⭐⭐⭐⭐⭐ — Delivers up to a 7.2x speedup without accuracy degradation, offering direct and significant practical value for deploying long-context LLMs. The open-source code operates out-of-the-box in a plug-and-play manner.