Gated Tree Cross-Attention for Checkpoint-Compatible Syntax Injection in Decoder-Only LLMs
Conference: ACL 2026
arXiv: 2602.15846
Code: https://github.com/Pineandgrass/GatedTreeCrossAttention
Area: LLM Architecture / Syntax Injection / Checkpoint-Compatible
Keywords: GTCA, syntax injection, checkpoint-compatible, constituency chunk memory, token update mask
TL;DR
The authors attach a Gated Tree Cross-Attention (GTCA) side branch to pretrained decoder-only LLMs (Qwen-2.5-7B, Llama-3-8B). An offline Berkeley Neural Parser precomputes constituency trees, which are grouped by height into chunk memory. Token hidden states read this memory through head-wise gated cross-attention and receive residual updates. Combined with a token update mask and three-stage training to limit interference, GTCA raises BLiMP accuracy from 78.58 to 83.12 (Qwen) and from 79.95 to 84.61 (Llama), while performance on MCQA, HellaSwag, and WinoGrande remains stable.
Background & Motivation
Background: Decoder-only LLMs achieve high scores on aggregate benchmarks but often fail fine-grained syntactic stress tests (BLiMP, HANS, CoLA). In practice, rephrasing the same meaning can yield opposite answers, and this fragility cascades into downstream reasoning. Probing work has repeatedly shown that dependency geometry can be recovered from the hidden states of pretrained language models (Hewitt & Manning 2019), i.e., syntax is "encoded."
Limitations of Prior Work: ① "Encoding ≠ usage": recoverability does not imply effective use; GPT-2, for example, remains far from human performance on BLiMP. ② Prevailing syntax injection methods (attention-bias modification, tree-RNNs, dependency-aware attention) usually require architectural rewrites or full retraining, which is impractical for pretrained LLMs and risks catastrophic forgetting. ③ Parameter-efficient methods such as LoRA/QLoRA leave the attention structure unchanged and therefore cannot introduce inductive biases such as explicit tree structure.
Key Challenge: The goal is to introduce explicit hierarchical syntactic signals into a pretrained checkpoint without modifying the backbone or interfering with pretrained competence. Crucially, this must not disturb likelihood-based MCQA scoring: modifying the hidden states of option tokens would distort the relative likelihoods of the options.
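Likelihood-based MCQA scoring, as commonly implemented, ranks each option by the summed log-probability of its tokens given the question. A minimal NumPy sketch, assuming per-position logits from a causal LM are already available (the function names are illustrative, not taken from the GTCA code):

```python
import numpy as np

def log_softmax(logits: np.ndarray) -> np.ndarray:
    """Numerically stable log-softmax over the last (vocabulary) axis."""
    shifted = logits - logits.max(axis=-1, keepdims=True)
    return shifted - np.log(np.exp(shifted).sum(axis=-1, keepdims=True))

def option_loglik(logits: np.ndarray, token_ids: np.ndarray, start: int) -> float:
    """Sum of log p(token_i | tokens_<i) over the option span token_ids[start:].

    logits[i] is the causal LM output at position i, i.e. the prediction for
    token i + 1, so the option token at position i is scored by logits[i - 1].
    Requires start >= 1 (the option follows a non-empty question prefix).
    """
    logp = log_softmax(logits)
    positions = np.arange(start, len(token_ids))
    return float(logp[positions - 1, token_ids[positions]].sum())

def pick_answer(candidates):
    """candidates: one (logits, token_ids, option_start) triple per answer option."""
    scores = [option_loglik(l, t, s) for l, t, s in candidates]
    return int(np.argmax(scores)), scores
```

Note the one-position shift: the option token at position i is scored from the output at position i − 1, so hidden states at option positions feed the predictions of the later option tokens.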
Goal: Construct a "pluggable and bypassable" syntax injection path that lets the model learn when and how much to trust syntactic information, yielding stable gains on syntactic benchmarks without harming general capabilities and without rewriting the backbone architecture.
Key Insight: The authors precompute constituency parse trees offline and cache them by hash, which removes parser overhead during training. Each tree is sliced by height into chunk memory and fed to the corresponding Transformer layer: higher layers receive higher-level chunks and lower layers receive leaf chunks, aligning the hierarchical inductive bias with the natural stratification of Transformer layers.
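The offline parse-and-cache step can be sketched as a content-addressed store. This minimal Python sketch assumes SHA-256 over the raw input text as the key and a JSON file as storage (both are assumptions; the summary only says trees are cached via hashes); `ParseCache` and `parse_fn` are hypothetical names, and `parse_fn` stands in for the actual parser call:

```python
import hashlib
import json
from typing import Callable, Dict

class ParseCache:
    """Parse each distinct input once, then serve the cached tree by hash."""

    def __init__(self, parse_fn: Callable[[str], str]):
        self.parse_fn = parse_fn            # hypothetical stand-in for the parser
        self.store: Dict[str, str] = {}     # hash -> bracketed tree string
        self.misses = 0

    @staticmethod
    def key(text: str) -> str:
        # Hash the exact input text; any change in the text yields a new key.
        return hashlib.sha256(text.encode("utf-8")).hexdigest()

    def get(self, text: str) -> str:
        k = self.key(text)
        if k not in self.store:
            self.misses += 1
            self.store[k] = self.parse_fn(text)
        return self.store[k]

    def dump(self, path: str) -> None:
        # Persist the cache so training runs never invoke the parser.
        with open(path, "w", encoding="utf-8") as f:
            json.dump(self.store, f)
```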
Core Idea: Treat the tree structure as an external cache + gated attention source for the decoder-only LLM—similar to RAG, but retrieving syntactic chunks rather than documents—then use head-wise gates to let the model autonomously learn whether to utilize the information.
Method
Overall Architecture¶
GTCA is a forward wrapper side-branch attached to a decoder-only Transformer. For each input: ① An offline Berkeley Neural Parser computes constituency trees (cached via hash index). ② The parse-tree encoder slices trees by height, mean-pools span token embeddings, and applies height-specific projections \(W_{h(u)}\) + LayerNorm to obtain chunk memory \(C^\ell\). ③ At the \(\ell\)-th layer, the pre-update hidden state \(H_{\text{pre}}^\ell\) acts as a query to read \(C^\ell\) through head-wise gated cross-attention, outputting a residual \(\Delta H^\ell\). ④ A token update mask applies \(\Delta H^\ell\) only to question and answer fields (forcing mask=0 for option tokens) to produce \(H_{\text{post}}^\ell\) for the next layer. Backbone parameters remain frozen or use low-rank adaptation, and the GTCA branch can be bypassed at any time.
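The per-layer data flow in steps ③ and ④ can be sketched as a thin wrapper around one backbone block. This NumPy sketch assumes the update is applied to block ℓ's output before it enters block ℓ+1, and treats the backbone block and the GTCA read as opaque callables; `gtca_wrapped_layer` and `build_update_mask` are hypothetical names, not the released API:

```python
import numpy as np
from typing import Callable, Iterable, Optional, Tuple

def build_update_mask(n_tokens: int, option_spans: Iterable[Tuple[int, int]]) -> np.ndarray:
    """m_tok: 1 for question/answer-field tokens, 0 for MCQA option tokens."""
    m = np.ones(n_tokens, dtype=np.float32)
    for start, end in option_spans:  # half-open token ranges [start, end)
        m[start:end] = 0.0
    return m

def gtca_wrapped_layer(
    h_in: np.ndarray,
    backbone_layer: Callable[[np.ndarray], np.ndarray],
    gtca_delta: Optional[Callable[[np.ndarray], np.ndarray]],
    m_tok: np.ndarray,
    alpha_struct: float = 1.0,
) -> np.ndarray:
    """Run backbone block l, then add the masked structural residual.

    h_in: (n, d) input to block l. Returns H_post^l, the input to block l + 1.
    gtca_delta=None bypasses the branch and reproduces the backbone exactly.
    """
    h_pre = backbone_layer(h_in)                      # H_pre^l (backbone untouched)
    if gtca_delta is None:
        return h_pre
    delta = gtca_delta(h_pre)                         # Delta H^l, shape (n, d)
    return h_pre + alpha_struct * (m_tok[:, None] * delta)  # mask broadcast over d
```

Because the branch only adds a residual, setting `gtca_delta=None` recovers the original checkpoint's computation, which is what makes the branch bypassable.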
Key Designs¶
- Height-aligned Chunk Memory:
  - Function: Slices each constituency tree into chunks by height and maps each Transformer layer to one tree height, feeding local structure to lower layers and global structure to higher layers.
  - Mechanism: Chunk height is \(h(u)=D-\text{depth}(u)\), where \(D\) is the maximum depth of the tree, so the deepest nodes (leaf tokens) have height 0 and the root has height \(D\). Each chunk is span-mean-pooled, \(p_u=\text{MeanPool}(E^i, i\in S(u))\), then passed through a height-specific projection \(W_{h(u)}\in\mathbb{R}^{d\times d}\) and LayerNorm to give \(c_u\). The chunk memory \(C^\ell\) at layer \(\ell\) contains only chunks with \(h(u)=\min(\ell, D)\), keeping at most \(K=64\) chunks in BFS order.
  - Design Motivation: Probing studies suggest that lower Transformer layers tend to capture local syntax and higher layers more global structure and semantics. Aligning chunk memory with this hierarchy lets the model integrate external tree information naturally and keeps higher layers from being disturbed by low-level noise.
- Head-wise Gated Cross-Attention:
  - Function: Lets token states read chunk memory via cross-attention, with a per-head gate deciding how much to trust it.
  - Mechanism: Standard cross-attention with \(Q=H_{\text{pre}}^\ell W_Q^\ell\), \(K=C^\ell W_K^\ell\), \(V=C^\ell W_V^\ell\), plus a head-wise gate logit \(G^\ell = H_{\text{pre}}^\ell W_G^\ell\) (one scalar per head per token). The attention output is scaled by a sigmoid gate, \(\text{Gated\_Attn}^\ell = \text{Attn}^\ell \odot \sigma(G^\ell)\), with each head's scalar broadcast over that head's dimensions. A scalar gate per head, rather than an element-wise gate, needs fewer parameters and helps keep training stable. A causal mask prevents each token from attending to chunks whose right boundary lies beyond it. Finally, \(\Delta H^\ell = \text{Merge}(\text{Gated\_Attn}^\ell)W_O^{ca,\ell}\) is added as a residual.
  - Design Motivation: Hard injection (no gate) forces the backbone to absorb chunk signals at every layer and can destroy pretrained representations. Head-wise gates let each head learn whether its current query needs syntax, turning the explicit tree from a hard constraint into an optional prior.
- Token Update Mask + Three-stage Training:
  - Function: Limits where the branch can interfere (spatially) and when (temporally), to prevent catastrophic forgetting.
  - Mechanism: ① Spatial: a binary mask \(m_{\text{tok}}\in\{0,1\}^n\) gates the update, \(H_{\text{post}}^\ell \leftarrow H_{\text{pre}}^\ell + \alpha_{\text{struct}}\,(m_{\text{tok}} \odot \Delta H^\ell)\), with \(m_{\text{tok}}\) broadcast across the hidden dimension. MCQA option tokens are forced to \(m_{\text{tok}}=0\) because likelihood-based scoring depends on their log-probabilities, and modifying them would distort the answer distribution. ② Temporal: a three-stage schedule first trains the GTCA branch alone, then jointly fine-tunes the submodules it affects, and finally opens more parameters. This ordering prevents a large, still-untrained \(\Delta H\) from disrupting pretrained states early in training.
  - Design Motivation: This addresses the checkpoint-compatibility constraint, since any change to hidden states can degrade pretrained capabilities. The mask shields sensitive positions and the staged schedule controls volatility over time, which is what allows MCQA performance to stay stable after adding GTCA.
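A minimal NumPy sketch of the height-aligned chunk memory, assuming trees are given as nested lists of token indices (an int is a leaf), \(D\) is the tree's maximum depth, layer indices start at 0, and LayerNorm has no learned affine parameters. All of these, and the function names, are illustrative assumptions:

```python
import numpy as np
from collections import deque

def leaves(node):
    """Token indices covered by a node (a leaf is an int, an internal node a list)."""
    if isinstance(node, int):
        return [node]
    return [i for child in node for i in leaves(child)]

def bfs_chunks(tree):
    """Return (span_token_ids, depth) for every node, in BFS order."""
    out, queue = [], deque([(tree, 0)])
    while queue:
        node, depth = queue.popleft()
        out.append((leaves(node), depth))
        if not isinstance(node, int):
            queue.extend((child, depth + 1) for child in node)
    return out

def layer_norm(x, eps=1e-5):
    """LayerNorm without learned scale/shift (a simplifying assumption)."""
    mu = x.mean(axis=-1, keepdims=True)
    var = x.var(axis=-1, keepdims=True)
    return (x - mu) / np.sqrt(var + eps)

def chunk_memory(tree, E, W_by_height, layer, K=64):
    """C^l: chunks with h(u) = min(layer, D), mean-pooled, projected, normalized.

    E: (n, d) token embeddings; W_by_height[h]: (d, d) projection for height h.
    Returns (C, spans): C is (<= K, d); spans are inclusive (start, end) pairs.
    With h(u) = D - depth(u), a leaf above the deepest level gets h > 0 and is
    pooled together with internal nodes of the same height.
    """
    chunks = bfs_chunks(tree)
    D = max(depth for _, depth in chunks)
    target = min(layer, D)
    rows, spans = [], []
    for span, depth in chunks:
        h = D - depth
        if h != target:
            continue
        p = E[span].mean(axis=0)                  # span mean-pooling p_u
        rows.append(layer_norm(p @ W_by_height[h]))
        spans.append((min(span), max(span)))
        if len(rows) == K:                        # keep at most K chunks, BFS order
            break
    C = np.stack(rows) if rows else np.zeros((0, E.shape[1]))
    return C, spans
```

Recomputing `leaves` per node is quadratic in sentence length, which is acceptable for a sketch; a real pipeline would compute spans once and store them with the cached tree.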
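The head-wise gated read can be sketched in NumPy as below. The sketch assumes \(d\) is divisible by the number of heads, chunks are identified by their inclusive right boundaries, and a token with no visible chunk receives a zero update (a choice made here for the otherwise undefined all-masked softmax). Names are illustrative, not the released code:

```python
import numpy as np

def softmax_masked(scores, allowed):
    """Softmax over the last axis restricted to `allowed`; rows with no allowed
    entry return all zeros (that token receives no structural update)."""
    scores = np.where(allowed, scores, -np.inf)
    row_max = scores.max(axis=-1, keepdims=True)
    row_max = np.where(np.isfinite(row_max), row_max, 0.0)
    w = np.where(allowed, np.exp(scores - row_max), 0.0)
    denom = w.sum(axis=-1, keepdims=True)
    return np.divide(w, denom, out=np.zeros_like(w), where=denom > 0)

def gated_cross_attention(H, C, chunk_ends, params, n_heads):
    """Delta H^l = Merge(Attn^l * sigmoid(G^l)) W_O, with a causal chunk mask.

    H: (n, d) pre-update hidden states; C: (m, d) chunk memory;
    chunk_ends: (m,) inclusive right boundary (token index) of each chunk.
    params: W_Q, W_K, W_V, W_O of shape (d, d); W_G of shape (d, n_heads),
    i.e. one scalar gate logit per head (d * n_heads params vs d * d element-wise).
    """
    n, d = H.shape
    assert d % n_heads == 0
    dh = d // n_heads
    if C.shape[0] == 0:                               # no chunk at this height
        return np.zeros((n, params["W_O"].shape[1]))

    def split(X):  # (rows, d) -> (heads, rows, dh)
        return X.reshape(X.shape[0], n_heads, dh).transpose(1, 0, 2)

    Q = split(H @ params["W_Q"])
    K = split(C @ params["W_K"])
    V = split(C @ params["W_V"])
    scores = Q @ K.transpose(0, 2, 1) / np.sqrt(dh)   # (heads, n, m)
    # Causal mask: token i may read chunk u only if the chunk ends at or before i.
    allowed = np.asarray(chunk_ends)[None, None, :] <= np.arange(n)[None, :, None]
    attn = softmax_masked(scores, allowed) @ V        # (heads, n, dh)
    gate = 1.0 / (1.0 + np.exp(-(H @ params["W_G"])))  # (n, heads), sigmoid
    gated = attn * gate.T[:, :, None]                 # head scalar over its dh dims
    merged = gated.transpose(1, 0, 2).reshape(n, d)   # Merge heads
    return merged @ params["W_O"]
```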
Loss & Training
Continued training employs standard language modeling loss in an MCQA-friendly format. The three-stage schedule freezes the backbone in Stage 1 to train GTCA projections and gates, opens submodules interacting with GTCA in Stage 2, and allows full convergence at a low learning rate in Stage 3. The chunk limit is set to \(K=64\), and a scaling factor \(\alpha_{\text{struct}}\) controls residual magnitude. Offline parsing is performed via the Berkeley Neural Parser with hash-based caching.
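One way to express the three-stage schedule is as a per-stage list of trainable parameter groups. The group prefixes, the use of LoRA adapters in Stage 3, and all learning rates below are hypothetical; the summary only fixes the order (GTCA alone, then interacting submodules, then more parameters at a low learning rate):

```python
from dataclasses import dataclass
from typing import Dict, List, Tuple

@dataclass(frozen=True)
class Stage:
    name: str
    trainable: Tuple[str, ...]   # parameter-name prefixes that receive gradients
    lr: float

# Hypothetical group names and learning rates, for illustration only.
SCHEDULE = [
    Stage("stage1_gtca_only", ("gtca.",), 1e-4),
    Stage("stage2_interacting", ("gtca.", "backbone.interacting."), 3e-5),
    Stage("stage3_low_lr", ("gtca.", "backbone.interacting.", "backbone.lora."), 5e-6),
]

def trainable_mask(param_names: List[str], stage: Stage) -> Dict[str, bool]:
    """Map each parameter name to whether it is updated in this stage."""
    return {p: any(p.startswith(prefix) for prefix in stage.trainable) for p in param_names}
```

In a PyTorch training loop, this mask would typically be applied by setting each parameter's `requires_grad` flag at every stage boundary.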
Key Experimental Results
Main Results (BLiMP Syntactic Capability)
| Model | Baseline BLiMP | + GTCA | \(\Delta\) |
|---|---|---|---|
| Qwen-2.5-7B | 78.58 | 83.12 | +4.54 |
| Llama-3-8B | 79.95 | 84.61 | +4.66 |

Results across task categories:

| Category | Task | Baseline | + GTCA | Note |
|---|---|---|---|---|
| Syntax | BLiMP | 78.58-79.95 | 83.12-84.61 | +4.5 to +4.7 pp |
| Syntax | CoLA (GLUE) | — | Consistent Improvement | Grammaticality |
| MCQA | CLOTH | — | Stable/Slight Rise | Cloze test |
| MCQA | MMLU | — | Stable/Slight Rise | Knowledge QA |
| Common Sense | HellaSwag | — | Stable | Continuation |
| Common Sense | WinoGrande | — | Stable | Coreference |
Ablation Study
| Configuration | Key Metric | Interpretation |
|---|---|---|
| Full GTCA (height-aligned + gated + mask + stage) | BLiMP 83.12 | Complete model |
| w/o head-wise gate (hard injection) | Significant Drop | Contaminated backbone representations |
| w/o token update mask (modifying options) | MCQA Regression | Option likelihood drift |
| w/o staged training | Unstable / Performance Drop | Large \(\Delta H\) destroyed pre-trained states |
| Single shared projection (instead of per-height \(W_{h(u)}\)) | BLiMP Slight Drop | Height-specific projections help |
Key Findings
- Gating is the key to successful injection: Head-wise gates let the model decide whether to trust chunk memory. Compared with hard injection, this preserves pretrained capabilities while using structure selectively, turning the "explicit syntax vs. implicit LLM capability" trade-off into something the model learns.
- Option tokens must be read-only: In MCQA tasks, applying syntactic updates to option tokens alters their log-probabilities, shifting answer selection. This serves as a vital engineering lesson for any continued training method on likelihood-based MCQA.
- Layer-aligned chunk memory yields interpretable utilization: UUAS probes show that GTCA increases unlabeled undirected attachment consistency in hidden states, with higher layers relying on higher chunks and lower layers on leaf chunks, matching the Transformer's own syntactic stratification.
- Syntactic gains do not come at the cost of general ability: All MCQA and common-sense tasks remained stable or improved slightly, suggesting that the three safety mechanisms (gate + mask + staged training) effectively contain interference.
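UUAS, in the sense of Hewitt & Manning (2019), is typically computed by taking a minimum spanning tree over a probe's predicted pairwise distances and counting how many gold undirected edges it recovers. A minimal sketch, assuming the probe distances for one sentence are already computed (training the structural probe itself is out of scope; names are illustrative):

```python
import numpy as np

def mst_edges(dist: np.ndarray):
    """Prim's algorithm on a dense symmetric distance matrix; returns undirected edges."""
    n = dist.shape[0]
    in_tree = np.zeros(n, dtype=bool)
    in_tree[0] = True
    best = dist[0].copy()              # cheapest known link from the tree to each node
    parent = np.zeros(n, dtype=int)    # tree endpoint of that cheapest link
    edges = set()
    for _ in range(n - 1):
        cand = np.where(in_tree, np.inf, best)
        j = int(np.argmin(cand))
        edges.add(frozenset((int(parent[j]), j)))
        in_tree[j] = True
        closer = dist[j] < best
        best = np.where(closer, dist[j], best)
        parent = np.where(closer, j, parent)
    return edges

def uuas(pred_dist: np.ndarray, gold_edges) -> float:
    """Unlabeled undirected attachment score: share of gold edges recovered by the MST."""
    gold = {frozenset(e) for e in gold_edges}
    return len(mst_edges(pred_dist) & gold) / len(gold)
```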
Highlights & Insights
- "Syntax as an External Retrieval Source" Paradigm: Treating the parse tree as a cacheable, bypassable, and gated external memory aligns syntax injection with the engineering philosophy of RAG. The backbone architecture is never rewritten, and the external source remains hot-swappable. This paradigm could be extended to morphological, logical, or ontological plug-ins.
- Checkpoint-compatibility is a primary industrial concern: Training modern LLMs is extremely expensive, so methods that require backbone rewrites or full retraining are rarely feasible to deploy. GTCA's forward-wrapper design lets the syntax branch be attached to, or bypassed in, an existing checkpoint without retraining it from scratch.
- Token update mask is a crucial, often overlooked detail: Many continued training methods fail to distinguish which tokens should be modified. This work demonstrates how hidden state changes leak into likelihood-based scoring, providing a valuable case study.
- Height-specific projection + layer alignment: Hard-binding tree height to Transformer layers is a natural yet rarely validated choice. This paper provides ablation evidence and UUAS probe analysis to support this intuition.
Limitations & Future Work
- Scalability: Tests were limited to ~7-8B decoder-only models; performance on 70B+, MoE, or hybrid architectures remains untested.
- Dependence on external constituency parsers: Parser errors are cached and injected; the robustness to parser noise was not extensively discussed.
- Engineering overhead: Storing chunk memory (hash indexing, span alignment) requires specific pipeline support; solutions for online/streaming scenarios were not provided.
- Evaluation: While BLiMP improvements are significant, they remain below human levels (~95+). The effect on specific long-range dependencies (e.g., islands, binding) was not disentangled.
- Comparison: A parameter-budget-equivalent comparison with LoRA/Adapters is missing—specifically, whether GTCA provides higher syntactic gains than LoRA given the same number of trainable parameters.
Related Work & Insights
- vs. Strubell et al. 2018 / Bugliarello & Okazaki 2020: These methods inject dependency info via self-attention bias, requiring training from scratch or full fine-tuning. GTCA leaves the backbone self-attention untouched.
- vs. Bai et al. 2021 (plug-in syntax): Similar "plug-in" philosophy but focused on encoder-only PLMs. GTCA adapts this to decoder-only LLMs and addresses the specific issue of MCQA-likelihood interference.
- vs. Iwamoto et al. 2023: They discuss catastrophic forgetting; GTCA's token update mask and staged training serve as concrete engineering solutions to this problem.
- vs. LoRA / PEFT: These are structure-agnostic, parameter-efficient fine-tuning methods. GTCA is a structure-aware, checkpoint-compatible injection path that is orthogonal to and potentially stackable with LoRA.
- vs. Hewitt & Manning 2019 probing: Probing shows that syntax is present; this work compares UUAS probes before and after adding GTCA, providing interventional evidence that explicit injection makes internal attachments more consistent.
Rating
- Novelty: ⭐⭐⭐⭐ Combining tree plug-ins with head-wise gating and two further safety mechanisms (token update mask, staged training) is novel and well engineered.
- Experimental Thoroughness: ⭐⭐⭐⭐ Two backbones, 6 benchmarks, plus UUAS probes and ablations, though only 7-8B backbones were tested.
- Writing Quality: ⭐⭐⭐⭐ Clear formulas and notation; safety mechanisms are well-explained.
- Value: ⭐⭐⭐⭐ Checkpoint-compatibility is highly practical for industry; the syntax plug-in paradigm has migration potential.