# LycheeDecode: Accelerating Long-Context LLM Inference via Hybrid-Head Sparse Decoding
Conference: ICLR 2026 | arXiv: 2602.04541 | Code: Not released | Area: LLM Efficiency | Keywords: sparse attention, long-context inference acceleration, attention head specialization, HardKuma distribution, KV cache optimization
## TL;DR
This paper proposes LycheeDecode, a fine-grained hybrid-head sparse decoding method that partitions attention heads into a small number of "retrieval heads" and a large number of "sparse heads," using the HardKuma distribution for differentiable head-type identification. The method achieves a 2.7× decoding speedup at 128K context length while matching or surpassing full-attention baselines.
## Background & Motivation
Long-context LLMs (supporting million-scale token windows) face a severe decoding bottleneck: the KV cache grows linearly with sequence length, driving up both memory consumption and per-token latency. Existing sparse attention methods fall into two categories:
- Eviction-based methods (StreamingLLM, SnapKV, H₂O): permanently discard tokens, causing irreversible information loss.
- Selection-based methods (SeerAttention, TidalDecode, RetrievalAttention): retain the full KV cache and dynamically select subsets.
Core Observation: Recent works (TidalDecode, OmniKV) find that critical tokens are highly similar across adjacent layers, and thus adopt a layer-level sharing strategy in which all heads of a layer share the same set of critical tokens. However, this granularity is too coarse:
- As shown in Figure 2, the top-k overlap rate varies dramatically across heads within the same layer (0% for head 14, 100% for head 24).
- A uniform layer-level sharing strategy forces all heads to perform identical functions, ignoring the functional diversity of heads.
Furthermore, methods such as DuoAttention learn continuous variables to classify head types, but must round to binary values at inference time, introducing train-inference inconsistency.
## Method
### Overall Architecture
LycheeDecode partitions attention heads into two categories:
- Retrieval Heads (\(\mathcal{H}_R\)): compute dense attention over the full sequence to dynamically identify the most important tokens.
- Sparse Heads (\(\mathcal{H}_S\)): reuse the critical token subset identified by retrieval heads for efficient sparse attention.
### Key Designs
Critical Token Identification in Retrieval Heads:
Retrieval heads perform standard dense attention:
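A sketch in standard notation (ours, not verbatim from the paper), where \(\mathbf{q}_h^{(l)}\) is the current token's query, \(\mathbf{K}_h^{(l)}, \mathbf{V}_h^{(l)}\) are the full cached keys and values of head \(h\) at layer \(l\), and \(d\) is the head dimension:

\[
\mathbf{a}_h^{(l)} = \mathrm{softmax}\!\left(\frac{\mathbf{q}_h^{(l)} {\mathbf{K}_h^{(l)}}^{\top}}{\sqrt{d}}\right), \qquad
\mathbf{o}_h^{(l)} = \mathbf{a}_h^{(l)} \mathbf{V}_h^{(l)}, \qquad h \in \mathcal{H}_R
\]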
The top-k critical token indices are selected and propagated to the corresponding heads in the next layer:
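A plausible form in the same notation (our sketch; how scores are aggregated and how a retrieval head's indices map to heads of the next layer are left schematic):

\[
\mathcal{S}_h^{(l)} = \operatorname{TopK}_k\big(\mathbf{a}_h^{(l)}\big), \quad h \in \mathcal{H}_R, \qquad
\mathcal{S}_{h'}^{(l+1)} \leftarrow \mathcal{S}_h^{(l)}
\]

for the sparse heads \(h'\) at layer \(l+1\) that correspond to retrieval head \(h\).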
Efficient Computation in Sparse Heads:
Sparse heads compute attention only over the inherited token subset \(\mathcal{S}_h^{(l)}\):
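Again as a sketch in our notation, with \(\mathbf{K}_h^{(l)}[\mathcal{S}_h^{(l)}]\) and \(\mathbf{V}_h^{(l)}[\mathcal{S}_h^{(l)}]\) denoting the cached keys and values gathered at the inherited indices:

\[
\mathbf{o}_h^{(l)} = \mathrm{softmax}\!\left(\frac{\mathbf{q}_h^{(l)}\, {\mathbf{K}_h^{(l)}[\mathcal{S}_h^{(l)}]}^{\top}}{\sqrt{d}}\right) \mathbf{V}_h^{(l)}[\mathcal{S}_h^{(l)}], \qquad h \in \mathcal{H}_S
\]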
Sparse heads do not update the token set: \(\mathcal{S}_h^{(l+1)} = \mathcal{S}_h^{(l)}\).
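To make the head-level split concrete, below is a minimal PyTorch sketch of one decode step for a single layer. Function and variable names, tensor shapes, and the averaging of retrieval-head scores into one shared top-k set are our illustrative assumptions, not the paper's implementation (which uses fused TileLang kernels).

```python
import torch
import torch.nn.functional as F

def hybrid_head_decode_step(q, k_cache, v_cache, retrieval_mask, inherited_idx, top_k=4096):
    """One decode step for one layer, sketching the retrieval/sparse head split.

    q:               [H, d]     current token's per-head queries
    k_cache/v_cache: [H, T, d]  full KV cache (kept for every head)
    retrieval_mask:  [H] bool   True -> retrieval head, False -> sparse head
    inherited_idx:   [k] long   critical-token indices inherited from the previous layer
    """
    H, T, d = k_cache.shape
    out = torch.empty(H, d, dtype=q.dtype)

    # Retrieval heads: dense attention over the full cache, then top-k token selection.
    r = retrieval_mask
    scores_r = torch.einsum("hd,htd->ht", q[r], k_cache[r]) / d ** 0.5
    attn_r = F.softmax(scores_r, dim=-1)
    out[r] = torch.einsum("ht,htd->hd", attn_r, v_cache[r])
    new_idx = attn_r.mean(dim=0).topk(min(top_k, T)).indices  # propagated to the next layer

    # Sparse heads: score and attend only over the inherited critical-token subset.
    s = ~retrieval_mask
    scores_s = torch.einsum("hd,hkd->hk", q[s], k_cache[s][:, inherited_idx]) / d ** 0.5
    attn_s = F.softmax(scores_s, dim=-1)
    out[s] = torch.einsum("hk,hkd->hd", attn_s, v_cache[s][:, inherited_idx])
    return out, new_idx
```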
Head Specialization via the HardKuma Distribution:
Head-type assignment is fundamentally a discrete optimization problem. LycheeDecode employs the Hard Kumaraswamy (HardKuma) distribution as a differentiable surrogate for binary variables:
- Sampling: \(s = (1-u^{1/\beta})^{1/\alpha}\), where \(u \sim \mathcal{U}(0,1)\)
- Stretching: \(s' = s \cdot (q-p) + p\), where \(p < 0, q > 1\)
- Clipping: \(z = \min(1, \max(0, s'))\)
For each head \(h\) at layer \(l\), parameters \(\alpha_h^{(l)}, \beta_h^{(l)}\) are learned, and the head's gate is drawn as \(z_h^{(l)} \sim \mathrm{HardKuma}\big(\alpha_h^{(l)}, \beta_h^{(l)}\big)\) via the sample-stretch-clip procedure above.
During training, the two attention maps are interpolated:
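One plausible form (our sketch), treating the sparse map as a full-length attention map that is zero outside \(\mathcal{S}_h^{(l)}\):

\[
\tilde{\mathbf{a}}_h^{(l)} = z_h^{(l)}\, \mathbf{a}_{h,\text{dense}}^{(l)} + \big(1 - z_h^{(l)}\big)\, \mathbf{a}_{h,\text{sparse}}^{(l)}
\]

Because the head output is linear in the attention weights, this is equivalent to interpolating the dense and sparse head outputs with the same gate.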
At inference, heads are deterministically assigned: heads with \(\mathbb{E}[z_h^{(l)}] > 0.5\) are designated retrieval heads, and the rest are sparse heads.
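A minimal PyTorch sketch of the HardKuma gate, following the three steps above. The stretch limits p, q and the Monte Carlo estimate of \(\mathbb{E}[z]\) are illustrative choices for this sketch, not the paper's settings.

```python
import torch

def hardkuma_sample(alpha, beta, p=-0.1, q=1.1):
    """Sample a stretched-and-rectified Kumaraswamy ("HardKuma") gate in [0, 1].

    alpha, beta : positive tensors of learned shape parameters (one pair per head)
    p < 0, q > 1: stretch limits so that exact 0s and 1s receive non-zero probability
    (the values here are illustrative defaults, not the paper's).
    """
    u = torch.rand_like(alpha).clamp(1e-6, 1 - 1e-6)   # u ~ Uniform(0, 1)
    s = (1.0 - u.pow(1.0 / beta)).pow(1.0 / alpha)     # inverse-CDF Kumaraswamy sample
    s_stretched = s * (q - p) + p                       # stretch to the interval (p, q)
    return s_stretched.clamp(0.0, 1.0)                  # clip -> hard 0/1 mass at the ends

def assign_head_types(alpha, beta, n_samples=4096):
    """Deterministic inference-time assignment: E[z] > 0.5 -> retrieval head."""
    with torch.no_grad():
        z = torch.stack([hardkuma_sample(alpha, beta) for _ in range(n_samples)])
        return z.mean(dim=0) > 0.5   # True -> retrieval head, False -> sparse head
```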
### Loss & Training
Distillation Loss: aligns the logits of the hybrid-head student model with those of the full-attention teacher:
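As a hedged sketch (the exact divergence and any temperature follow the paper), a typical logit-distillation objective over next-token distributions is:

\[
\mathcal{L}_{\text{distill}} = \mathrm{KL}\Big(\,\mathrm{softmax}\big(\mathbf{y}_{\text{teacher}}\big)\ \big\|\ \mathrm{softmax}\big(\mathbf{y}_{\text{student}}\big)\Big)
\]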
Sparsity Constraint (Lagrangian Relaxation):
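A standard Lagrangian form consistent with the description (our sketch; \(K\) denotes the target retrieval-head budget):

\[
\mathcal{L} = \mathcal{L}_{\text{distill}} + \lambda\,\big(\mathbb{E}[\|\mathbf{z}\|_0] - K\big)
\]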
\(\mathbb{E}[\|\mathbf{z}\|_0]\) admits a closed-form solution, and \(\lambda\) is adaptively updated via gradient ascent. Training requires only a few hours on a single A100 GPU for 3,000 steps.
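Concretely, the closed form follows from the stretch-and-clip construction: a gate is exactly zero only when the stretched sample falls below 0, so (with the \(p, q\) from the stretching step)

\[
\mathbb{E}[\|\mathbf{z}\|_0] = \sum_{l,h} \Pr\big[z_h^{(l)} \neq 0\big] = \sum_{l,h} \Big(1 - \big(\tfrac{-p}{\,q-p\,}\big)^{\alpha_h^{(l)}}\Big)^{\beta_h^{(l)}},
\]

which is differentiable in \(\alpha, \beta\) (our derivation from the definitions above; the paper's expression may be written differently).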
## Key Experimental Results
### Main Results: Long-Context Understanding (LongBench)
| Method (Budget) | MFQA | NrtQA | TrQA | PRe | Avg |
|---|---|---|---|---|---|
| Llama-3-8B Full Attn | 30.76 | 5.52 | 86.56 | 77.00 | 32.33 |
| TidalDecode (4096) | 30.94 | 6.19 | 86.30 | 78.00 | 32.86 |
| LycheeDecode (4096) | 30.11 | 5.85 | 86.78 | 82.58 | 33.07 |
| Qwen3-8B Full Attn | 25.84 | 3.43 | 90.21 | 89.08 | 33.02 |
| SeerAttention-R (4096) | 24.85 | 3.30 | 90.19 | 93.17 | 33.38 |
| LycheeDecode (4096) | 24.90 | 3.32 | 90.34 | 93.25 | 33.48 |
LycheeDecode achieves the best average score across all settings, even surpassing the full-attention model.
### Mathematical Reasoning Tasks
| Method | AIME24 | OlympiadBench | Avg |
|---|---|---|---|
| DeepSeek-R1-Distill-Qwen-7B Full Attn | 40.0 | 10.2 | 43.0 |
| TidalDecode | 16.7 | 7.0 | 30.2 |
| TidalDecode + Cache Correction | 26.7 | 8.6 | 35.0 |
| LycheeDecode | 43.3 | 10.9 | 44.2 |
| LycheeDecode + Cache Correction | 46.7 | 12.5 | 44.9 |
On reasoning tasks, LycheeDecode surpasses the full-attention baseline, presumably because head specialization filters irrelevant contextual noise.
### Ablation Study
Comparison of Head Identification Methods:
| Method | Passkey Retrieval | HotpotQA |
|---|---|---|
| Direct Optimize | 32.06 | 31.02 |
| Hard Concrete | 32.13 | 30.25 |
| HardKuma (ours) | 33.07 | 31.11 |
Acceleration Results:
- 128K context, single batch: 2.7× end-to-end speedup over FlashAttention-2.
- 128K context, batch=8: peak kernel-level speedup reaches 7×.
- Compared to TidalDecode: 1.73× faster.
## Key Findings
- Head-level strategy outperforms layer-level: LycheeDecode surpasses TidalDecode at all budgets, validating the superiority of fine-grained head-level sharing.
- Reasoning performance exceeds full attention: The head specialization mechanism filters noise and focuses on critical information.
- Ratio-based sparsity is the most robust: Under equivalent sparsity, the Ratio scheme (the token budget scales with sequence length) generally performs best.
- Cache Correction provides further gains: Applying dense attention correction every 32 tokens effectively mitigates error accumulation.
## Highlights & Insights
- Treating attention heads as functionally specialized units, rather than as a uniform block, is a promising direction for long-context inference optimization.
- The judicious choice of the HardKuma distribution naturally produces near-binary differentiable samples, avoiding the train-inference inconsistency caused by rounding continuous variables.
- Lightweight training: requires only a single A100 GPU for a few hours, with no auxiliary gating networks.
- Efficient kernels implemented in TileLang: auto-tuning search is used to find the optimal configuration for each layer.
## Limitations & Future Work
- Validation is currently limited to Llama-3-8B and Qwen3-8B; generalizability to larger models (70B+) remains to be confirmed.
- On tasks with sparse supervision signals (e.g., HotpotQA with short answers), HardKuma shows a slight performance degradation.
- The number of retrieval heads (32) is a fixed hyperparameter; adaptively determining the retrieval head budget may be preferable.
- Cache Correction introduces additional computational overhead, requiring a trade-off between accuracy and efficiency.
- The current approach does not support prefill-stage acceleration; it targets the decode stage only.
## Related Work & Insights
- Improvement over DuoAttention: DuoAttention determines each head's role independently without a coordination mechanism; in LycheeDecode, retrieval heads actively propagate critical tokens to sparse heads.
- Distinction from TidalDecode: TidalDecode shares at the layer level (2 full-attention layers × 8 KV heads = 16 retrieval heads), whereas LycheeDecode shares at the head level, enabling finer granularity.
- Implications for long-context inference: Head functional specialization may be key to improving inference quality — not merely speed — as filtering irrelevant context demonstrably benefits reasoning.
## Rating
- Novelty: ⭐⭐⭐⭐ — The combination of head-level sparse decoding and HardKuma-based head identification is a novel design.
- Technical Depth: ⭐⭐⭐⭐⭐ — The theoretical derivation of HardKuma is complete, and the kernel implementation is meticulous.
- Experimental Thoroughness: ⭐⭐⭐⭐⭐ — Covers long-context understanding, mathematical reasoning, efficiency evaluation, and multi-dimensional ablations.
- Writing Quality: ⭐⭐⭐⭐ — Well-structured with clearly articulated motivation.
- Value: ⭐⭐⭐⭐⭐ — Low training cost, significant speedup, suitable for practical deployment.
- Overall Score: ⭐⭐⭐⭐⭐ (9/10)