Multi-Head Low-Rank Attention (MLRA)¶
Conference: ICLR 2026 arXiv: 2603.02188 Code: GitHub / HuggingFace Area: Efficient LLM Inference Keywords: KV Cache, Tensor Parallelism, Low-Rank Attention, Decoding Efficiency, Multi-Head Latent Attention
TL;DR¶
This paper proposes Multi-Head Low-Rank Attention (MLRA), which decomposes the single latent head in MLA into multiple independently shardable latent heads and sums the attention outputs across branches, enabling native 4-way tensor parallelism. The method achieves a 2.8× long-context decoding speedup over MLA while delivering the best average perplexity and zero-shot accuracy among the compared attention variants.
Background & Motivation¶
Long-context inference bottleneck: During LLM decoding, each step requires repeatedly transferring the KV cache between HBM and SRAM, making data movement—rather than computation—the dominant source of latency in long-context inference.
Strengths and limitations of MLA: DeepSeek's Multi-Head Latent Attention (MLA) compresses the KV cache into a single latent head (only \(4.5 d_h\) per token), substantially reducing total KV cache size. However, the single latent head is not shardable—under tensor-parallel (TP) decoding, each device is forced to redundantly load the full KV cache.
Performance stagnation of MLA under TP: Regardless of the TP degree, the per-device KV cache load in MLA remains fixed at \(4.5 d_h\), completely negating the weight-sharding benefits of tensor parallelism.
Partial solution by GLA-2: GLA-2 bisects the latent head into two smaller latent heads, reducing the per-device load to \(2.5 d_h\) under 2-way TP, but further reduction is not achievable for TP > 2.
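To make these per-device figures concrete, here is a minimal sketch in plain Python (the helper name is mine, not the paper's) that tabulates the per-token KV-cache read per device in units of \(d_h\), assuming a \(4 d_h\) latent cache plus a shared \(0.5 d_h\) RoPE key that every device must load; the MLRA-4 figure anticipates the \(1.5 d_h\) reported in the TP-Friendly Decoding section below.

```python
# Sketch (not the paper's code): per-token KV-cache read per device, in units of d_h.
# Assumes a 4*d_h latent cache plus a 0.5*d_h RoPE key replicated on every device.

def per_device_kv_load(method: str, tp: int) -> float:
    """Per-token KV-cache load per device, in multiples of d_h."""
    rope = 0.5                       # shared decoupled RoPE key, loaded by every device
    if method == "MLA":              # single latent head: cannot be sharded
        return 4.0 + rope
    if method == "GLA-2":            # two latent heads: shardable up to TP = 2
        return 4.0 / min(tp, 2) + rope
    if method == "MLRA-4":           # four latent blocks: shardable up to TP = 4
        return 4.0 / min(tp, 4) + rope
    raise ValueError(method)

for tp in (1, 2, 4):
    print(tp, {m: per_device_kv_load(m, tp) for m in ("MLA", "GLA-2", "MLRA-4")})
# TP = 4 -> MLA: 4.5, GLA-2: 2.5, MLRA-4: 1.5  (times d_h per token)
```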
Importance of arithmetic intensity: The key metric for decoding efficiency is arithmetic intensity (FLOPs per byte moved). Improving it requires keeping computational density high while reducing memory loads, which shifts the workload from the memory-bound toward the compute-bound regime.
Variance mismatch: In MLA, RoPE keys and NoPE keys exhibit a significant variance mismatch (\(\text{Var}(K^{\text{RoPE}}) / \text{Var}(K^{\text{NoPE}}) \approx d/d_c\)), which is particularly pronounced when the latent dimension is much smaller than the hidden dimension.
Method¶
Overall Architecture¶
The core idea of MLRA is to explicitly decompose the non-shardable single latent head in MLA into multiple independent latent heads. Each latent head independently performs an up-projection to produce NoPE KVs, and the attention outputs across branches are summed. This design natively supports 4-way tensor parallelism.
MLRA provides two variants:

- MLRA-2: 2 latent heads, each serving half of the attention heads, yielding a 2-branch summed output.
- MLRA-4: 4 latent heads, each serving all attention heads via up-projection, yielding a 4-branch summed output.
Key Designs¶
1. Block Decomposition¶
Function: Partitions the KV latent matrix \(C^{KV} \in \mathbb{R}^{n \times d_c}\) into 4 channel-wise blocks \(C_{:,(b)}^{KV}\), and correspondingly partitions the up-projection matrices \(W^{UK}\) and \(W^{UV}\) into 4 row-wise sub-blocks.
Design Motivation: In MLA, the NoPE key/value for each head is mathematically equivalent to the sum of 4 sub-block products (Eq. 2). Moving this summation from the KV computation to the attention output level enables each sub-block to compute attention independently.
Mechanism (MLRA-4 as an example):

$$
O_{:,i,:} = \sum_{b=0}^{3} \text{Softmax}\left(\tau\, Q_{:,i,:}^{\text{NoPE}} \left(C_{:,(b)}^{KV} W_{(b),(i)}^{UK}\right)^\top + \tau\, Q_{:,i,:}^{\text{RoPE}} \left(K^{\text{RoPE}}\right)^\top\right) C_{:,(b)}^{KV} W_{(b),(i)}^{UV}
$$
Each branch \(b\) depends only on \(C_{:,(b)}^{KV}\) (of size \(d_h\)) and can be assigned to separate TP devices for independent computation.
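To make the branch summation concrete, here is a minimal single-sequence sketch of how I read the equation above, assuming PyTorch; tensor names, illustrative sizes, and the temperature \(\tau = 1/\sqrt{d_h + d_r}\) are my assumptions rather than the authors' code, and batching and causal masking are omitted.

```python
# Sketch of the MLRA-4 branch summation (not the released implementation).
import torch
import torch.nn.functional as F

n, h, d_h, d_r = 16, 8, 64, 32           # tokens, heads, head dim, RoPE dim (d_r = 0.5 * d_h)
d_c = 4 * d_h                            # KV latent dim, split into 4 blocks of size d_h

q_nope = torch.randn(n, h, d_h)          # per-head NoPE queries
q_rope = torch.randn(n, h, d_r)          # per-head RoPE queries
k_rope = torch.randn(n, d_r)             # shared RoPE key (one per token, MQA-style)
c_kv   = torch.randn(n, d_c)             # KV latent states C^{KV}
w_uk   = torch.randn(4, d_h, h, d_h)     # W^{UK}_{(b),(i)}: [block, block dim, head, head dim]
w_uv   = torch.randn(4, d_h, h, d_h)     # W^{UV}_{(b),(i)}: same layout
tau    = (d_h + d_r) ** -0.5             # softmax temperature (assumed)

rope_scores = torch.einsum("qhr,kr->hqk", q_rope, k_rope)    # shared across all branches
out = torch.zeros(n, h, d_h)
for b in range(4):                                           # one branch per latent block
    c_b = c_kv[:, b * d_h:(b + 1) * d_h]                     # C^{KV}_{:,(b)}, shape (n, d_h)
    k_b = torch.einsum("kc,chd->khd", c_b, w_uk[b])          # branch NoPE keys
    v_b = torch.einsum("kc,chd->khd", c_b, w_uv[b])          # branch values
    scores = torch.einsum("qhd,khd->hqk", q_nope, k_b) + rope_scores
    out += torch.einsum("hqk,khd->qhd", F.softmax(tau * scores, dim=-1), v_b)
out = out / 2                                                # MLRA-4 output scaling (see Variance Calibration)
```

Each loop iteration touches only its own latent block \(C_{:,(b)}^{KV}\), which is what allows the four branches to live on four separate TP devices.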
2. Grouped Mapping in MLRA-2¶
Function: Adopts the grouping strategy from GLA-2, bisecting the latent head so that the first half of attention heads uses the first latent group and the second half uses the second.
Design Motivation: Provides a lightweight alternative that balances model capacity and efficiency within a 2-branch summation.
Mechanism: A group mapping function \(\gamma(i)\) determines which latent group the \(i\)-th attention head uses; the output of each head is the sum of attention outputs from 2 branches.
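A minimal sketch of such a mapping, assuming the contiguous half-and-half split described above (the function name and signature are mine):

```python
# Sketch of the GLA-2-style group mapping gamma(i) reused by MLRA-2:
# the first half of the attention heads reads latent group 0, the second half group 1.

def gamma(i: int, num_heads: int, num_groups: int = 2) -> int:
    """Return the index of the latent group used by attention head i."""
    heads_per_group = num_heads // num_groups
    return i // heads_per_group

num_heads = 8
print({i: gamma(i, num_heads) for i in range(num_heads)})
# heads 0-3 -> group 0, heads 4-7 -> group 1
```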
3. Variance Calibration¶
Function: Applies scaling factors to the query and KV latent states, and normalizes the multi-branch attention outputs.
Design Motivation: Theoretical analysis shows that the variance of NoPE keys is \(d_c \sigma_w^2\) while that of RoPE keys is \(d \sigma_w^2\); branch splitting further exacerbates this mismatch. Multi-branch summation also alters the variance of the output.
Mechanism:

- Latent state scaling: \(C^Q \leftarrow \sqrt{d/d_c'} \cdot C^Q\), \(C^{KV} \leftarrow \sqrt{4d/d_c} \cdot C^{KV}\)
- Output scaling: MLRA-2 outputs are divided by \(\sqrt{2}\); MLRA-4 outputs are divided by \(2\).
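A small sketch of both calibration steps, assuming PyTorch; the variable names are mine, with `d` the hidden dimension, `d_c` the KV latent dimension, and `d_c_q` standing in for the query latent dimension written \(d_c'\) above.

```python
# Sketch of the variance-calibration scalings (not the authors' code).
import torch

def calibrate_latents(c_q: torch.Tensor, c_kv: torch.Tensor, d: int, d_c_q: int, d_c: int):
    """Rescale latent states so NoPE and RoPE score variances match."""
    c_q  = (d / d_c_q) ** 0.5 * c_q       # C^Q  <- sqrt(d / d_c') * C^Q
    c_kv = (4 * d / d_c) ** 0.5 * c_kv    # C^KV <- sqrt(4d / d_c) * C^KV
    return c_q, c_kv

def rescale_output(out: torch.Tensor, num_branches: int) -> torch.Tensor:
    """Divide the summed multi-branch output: by sqrt(2) for MLRA-2, by 2 for MLRA-4."""
    return out / {2: 2 ** 0.5, 4: 2.0}[num_branches]
```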
4. TP-Friendly Decoding¶
Function: Enables native 4-way TP decoding, with each device loading only \(1.5 d_h\) of KV cache per token.
Design Motivation: The 4 latent blocks can be naturally distributed across 4 TP devices; each device loads only one latent block (\(d_h\)) plus the shared RoPE key (\(0.5 d_h\)).
Mechanism: Adopts MLA's weight absorption technique to absorb up-projections into the query side, then independently executes MQA-style decoding on each device, followed by an all-reduce summation.
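The decoding path can be sketched as follows in PyTorch, run in a single process for clarity: the loop over `rank` stands in for the four TP devices and the final sum stands in for the all-reduce; shapes, names, and the temperature are my assumptions rather than the released kernel.

```python
# Sketch of MLRA-4 TP decoding with weight absorption (single decode step).
import torch
import torch.nn.functional as F

n, h, d_h, d_r = 128, 8, 64, 32                  # cached tokens, heads, head dim, RoPE dim
tau = (d_h + d_r) ** -0.5                        # softmax temperature (assumed)

q_nope = torch.randn(1, h, d_h)                  # current-step NoPE queries
q_rope = torch.randn(1, h, d_r)                  # current-step RoPE queries
k_rope = torch.randn(n, d_r)                     # shared RoPE key cache (0.5 d_h per token)
c_kv_blocks = [torch.randn(n, d_h) for _ in range(4)]   # one latent block per device
w_uk = torch.randn(4, d_h, h, d_h)               # per-block up-projections W^{UK}
w_uv = torch.randn(4, d_h, h, d_h)               # per-block up-projections W^{UV}

rope_scores = torch.einsum("qhr,kr->hqk", q_rope, k_rope)

partials = []
for rank in range(4):                            # each TP rank sees only its own latent block
    c_b = c_kv_blocks[rank]                                           # (n, d_h) latent cache
    q_abs = torch.einsum("qhd,chd->qhc", q_nope, w_uk[rank])          # absorb W^UK into the query
    scores = torch.einsum("qhc,kc->hqk", q_abs, c_b) + rope_scores    # MQA-style: the latent cache acts as keys
    ctx = torch.einsum("hqk,kc->qhc", F.softmax(tau * scores, dim=-1), c_b)   # latent-space context
    partials.append(torch.einsum("qhc,chd->qhd", ctx, w_uv[rank]))    # apply W^UV on the way out
out = sum(partials) / 2                          # stands in for all-reduce(sum), then MLRA-4 output scaling
```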
Loss & Training¶
- Initialization: Output projection parameters use zero initialization (outperforming standard \(\mathcal{N}(0, 0.02)\)); remaining parameters use standard initialization.
- Optimizer: AdamW with \((\beta_1, \beta_2)=(0.9, 0.95)\), weight decay 0.1, gradient clipping 1.0.
- Learning rate: Peak \(1.6 \times 10^{-4}\), linear warmup for the first 2000 steps, followed by cosine annealing to 10% of the peak.
- Training scale: 2.9B parameters, FineWeb-Edu-100B dataset, 98.3B tokens, context length 2048, 8 × H100 GPUs.
- Optional gating: Adding a gating mechanism before the attention output projection can further reduce perplexity.
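For reference, a minimal sketch of this training setup assuming PyTorch; the stand-in model and the total step count are placeholders, not values from the paper.

```python
# Sketch of the optimizer and learning-rate schedule described above.
import math
import torch

PEAK_LR, WARMUP, MIN_FRAC, TOTAL_STEPS = 1.6e-4, 2000, 0.10, 50_000   # TOTAL_STEPS is illustrative

def lr_at(step: int) -> float:
    """Linear warmup to the peak, then cosine annealing down to 10% of the peak."""
    if step < WARMUP:
        return PEAK_LR * step / WARMUP
    progress = (step - WARMUP) / max(1, TOTAL_STEPS - WARMUP)
    return PEAK_LR * (MIN_FRAC + (1 - MIN_FRAC) * 0.5 * (1 + math.cos(math.pi * progress)))

model = torch.nn.Linear(8, 8)                    # stand-in for the 2.9B-parameter model
optimizer = torch.optim.AdamW(model.parameters(), lr=PEAK_LR,
                              betas=(0.9, 0.95), weight_decay=0.1)
# Inside the training loop:
#   for g in optimizer.param_groups: g["lr"] = lr_at(step)
#   torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)   # gradient clipping at 1.0
```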
Key Experimental Results¶
Main Results¶
Validation perplexity (average over 7 datasets, lower is better):
| Method | Wikipedia | C4 | Pile | RefinedWeb | Cosmopedia | FineWeb | FineWeb-Edu | Avg. |
|---|---|---|---|---|---|---|---|---|
| MHA | 14.624 | 16.575 | 12.929 | 18.698 | 9.102 | 15.656 | 9.434 | 13.860 |
| GQA | 15.057 | 16.628 | 13.758 | 18.885 | 9.504 | 15.713 | 9.427 | 14.139 |
| MLA | 14.567 | 16.345 | 12.965 | 18.523 | 8.966 | 15.440 | 9.284 | 13.727 |
| GLA-2 | 14.605 | 16.323 | 13.225 | 18.509 | 9.118 | 15.424 | 9.249 | 13.779 |
| MLRA-4 | 14.407 | 16.286 | 13.124 | 18.398 | 8.937 | 15.361 | 9.193 | 13.672 |
Zero-shot commonsense reasoning accuracy (%):
| Method | ARC-E | ARC-C | OBQA | BoolQ | HellaSwag | Winogrande | PIQA | Avg. |
|---|---|---|---|---|---|---|---|---|
| MHA | 69.11 | 39.16 | 40.80 | 62.26 | 60.82 | 57.62 | 74.86 | 57.81 |
| GQA | 67.13 | 39.42 | 42.00 | 63.39 | 61.29 | 56.91 | 75.08 | 57.89 |
| MLA | 68.22 | 39.16 | 42.60 | 64.10 | 61.39 | 60.06 | 75.68 | 58.75 |
| MLRA-4 | 67.63 | 41.38 | 43.00 | 61.74 | 62.16 | 61.48 | 74.48 | 58.84 |
Ablation Study¶
Effect of gating on perplexity (average over 7 datasets):
| Method | w/o Gating | w/ Gating | Gain |
|---|---|---|---|
| GQA | 14.139 | 13.806 | -0.333 |
| MLA | 13.727 | 13.642 | -0.085 |
| GLA-2 | 13.779 | 13.701 | -0.078 |
| MLRA-2 | 13.804 | 13.651 | -0.153 |
| MLRA-4 | 13.672 | 13.621 | -0.051 |
Additional ablation findings:

- Zero initialization vs. \(\mathcal{N}(0, 0.02)\): Zero initialization consistently outperforms random initialization across all models.
- Variance scaling: MLA and GLA-2 benefit substantially; MLRA-2 benefits marginally, as the branch decomposition naturally mitigates variance mismatch.
- Doubling attention heads: Doubling the number of heads in GQA/MLA/GLA-2 while keeping total parameters fixed does not improve and in fact degrades performance.
Key Findings¶
- MLRA-4 is the strongest overall: It achieves the best average perplexity (13.672) and the best zero-shot reasoning accuracy (58.84%), outperforming all baselines, including MLA.
- 2.8× decoding speedup: MLRA-4 achieves a stable 2.8× speedup over MLA in long-context decoding ranging from 128K to 2M tokens.
- 1.05–1.26× over GQA: MLRA-4 outperforms GQA in long-context decoding, with the gap widening as context length increases.
- TP=4 suffices to reach \(1.5 d_h\): GQA and GTA require 8-way TP to achieve a comparable per-device KV load, whereas MLRA requires only 4-way TP.
- Gating further improves results: With gating, MLRA-4 perplexity drops to 13.621, remaining best-in-class.
Highlights & Insights¶
- Elegant mathematical motivation: Starting from the block decomposition of MLA, the method moves the summation from the KV level to the attention output level—mathematically concise and intuitively clear.
- Rigorous variance analysis: The paper theoretically derives the variance of each component and provides explicit calibration strategies, avoiding empirical hyperparameter search.
- Strong practical utility: MLRA shares the same total KV cache size as MLA (\(4.5 d_h\) per token); the only difference is shardability during distributed decoding, making migration from existing MLA systems low-cost.
- Arithmetic intensity analysis: MLRA-4 achieves an arithmetic intensity of approximately \(2h\) (same as MLA), demonstrating that reduced memory loads do not sacrifice computational efficiency.
- Complete engineering implementation: An MLRA-4 kernel is implemented on top of FlashAttention-3 and validated on a real H100 cluster.
Limitations & Future Work¶
- Validated only at 2.9B scale: No experiments at 7B or larger scales; it remains to be verified whether MLRA's advantages hold for larger models.
- Single pretraining dataset: Only FineWeb-Edu-100B is used; generalization to multilingual or code-mixed data is not validated.
- Limitations of Assumption 1: The variance analysis assumes i.i.d. weight distributions independent of inputs, an assumption that does not hold strictly during training.
- Fixed 4-way decomposition: MLRA-4 is tied to 4-way TP; 2-way or 8-way TP scenarios require MLRA-2 or further extension, respectively.
- Approximation in multi-branch summation: Moving the summation outside the softmax fundamentally changes the attention distribution; while empirically effective, this is not theoretically equivalent to the original formulation.
- Instruction tuning and alignment not evaluated: The study focuses solely on pretraining and does not cover SFT or RLHF stages.
Related Work & Insights¶
- MLA (DeepSeek-V2/V3): The direct predecessor of MLRA, which compresses KV cache via latent compression but does not support TP sharding.
- GQA: Grouped Query Attention achieves efficiency gains by reducing the number of KV heads, but KV cache still grows linearly with head count.
- GLA-2 (Zadouri et al., 2025): The first work to attempt splitting MLA's latent head, but supports only 2-way TP.
- TPA (Zhang et al., 2025): Tensor Product Attention constructs KVs as linear combinations of shared heads, with limited TP support.
- FlashMLA / FlashAttention-3: Efficient attention kernels; MLRA's kernel is implemented on top of FlashAttention-3.
- LongCat (2025): First to observe the RoPE key variance mismatch issue; MLRA adopts and extends its scaling strategy.
- Inspiration: The paradigm of "lifting an internal summation to independent multi-branch computation" may generalize to other settings requiring latent compression combined with TP, such as KV cache quantization and sparse attention.
Rating¶
- Novelty: ⭐⭐⭐⭐ — Precise insight; elegantly resolves the non-shardability of MLA through branch decomposition.
- Experimental Thoroughness: ⭐⭐⭐⭐ — Multiple ablations, multi-dataset evaluation, and decoding speed/throughput benchmarks, though limited to the 2.9B scale.
- Writing Quality: ⭐⭐⭐⭐⭐ — Rigorous mathematical derivations, clear notation, and a complete logical chain from background to methodology.
- Value: ⭐⭐⭐⭐ — Directly addresses a practical pain point in MLA deployment, with significant utility for the DeepSeek model family and large-scale inference systems.