NeuroCircuitry-Inspired Hierarchical Graph Causal Attention Networks for Explainable Depression Identification
Conference: ICLR 2026
arXiv: 2511.17622
Code: GitHub
Area: Medical Imaging
Keywords: Brain network GNN, depression identification, causal attention, hierarchical circuit encoding, explainable AI
TL;DR
This paper proposes the NH-GCAT framework, which explicitly incorporates neuroscience priors on depression-related neural circuits into a GNN, modeling brain activity at three spatial scales—region, circuit, and network. The method achieves state-of-the-art classification on the REST-meta-MDD dataset (AUC 78.5%, ACC 73.8%) and provides interpretable analyses consistent with established neuroscientific findings.
Background & Motivation
Major depressive disorder (MDD) affects millions worldwide, and the brain network topology revealed by fMRI is naturally suited to graph-based modeling. Existing GNN approaches face two key bottlenecks:
- Limited accuracy: Brain regions are treated uniformly, and methods rely solely on static functional connectivity (FC), ignoring temporal dynamics and MDD-specific alterations.
- Poor interpretability: The hierarchical organization of brain networks and the causal relationships among neural circuits are not modeled explicitly, and post-hoc explanation methods cannot be aligned with neuroscientific knowledge.
Neuroscience has established that MDD pathology manifests differently across three spatial scales:
- Local regions: Abnormal low-frequency BOLD oscillations (associated with rumination)
- Circuit level: Specific dysregulation in DMN, SN, FPN, LN, and RN
- Whole-brain network: Abnormal causal information flow between circuits
Prior methods either perform purely data-driven black-box classification or introduce priors at only a single scale, without systematically integrating neuroscientific knowledge across scales.
Method
Overall Architecture
NH-GCAT consists of three hierarchical modules corresponding to the three spatial scales:
- RG-Fusion (Residual Gated Fusion): Local brain region level — fuses temporal BOLD dynamics with static FC
- HC-Pooling (Hierarchical Circuit Pooling): Multi-region circuit level — aggregates nodes according to the five depression-relevant circuits (DMN/SN/FPN/LN/RN)
- VLCA (Variational Latent Causal Attention): Multi-circuit network level — infers directed information flow between circuits and supports counterfactual reasoning
Problem formulation: For each of the \(N\) subjects, rs-fMRI is used to extract static features \(\mathbf{X}^{(1)} \in \mathbb{R}^{n \times m}\) (FC + demographic variables) and BOLD time series \(\mathbf{X}^{(2)} \in \mathbb{R}^{n \times T}\), where \(n\) is the number of brain regions (nodes), \(m\) the static feature dimension, and \(T\) the number of time points. These define a graph \(\mathcal{G} = (\mathcal{V}, \mathcal{E}, \mathbf{X}^{(1)}, \mathbf{X}^{(2)})\), and the model learns a classifier \(f: \mathcal{G} \to \{0, 1\}\) (MDD vs. HC).
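To make the formulation concrete, here is a minimal NumPy sketch of how one subject's inputs could be assembled. The specifics are assumptions rather than the paper's pipeline: FC is taken as the Pearson correlation between regional BOLD series, each node's static features are its FC row concatenated with the subject's demographic vector, and the edge set keeps region pairs whose absolute FC exceeds a hypothetical `edge_threshold`.

```python
import numpy as np


def build_subject_graph(bold, demographics, edge_threshold=0.3):
    """Build one subject's graph inputs from an (n, T) BOLD matrix.

    Assumptions (not from the paper): static features are FC rows concatenated
    with the demographic vector repeated for every node, and edges keep |FC|
    above a hypothetical threshold.
    """
    bold = np.asarray(bold, dtype=float)
    n = bold.shape[0]

    # Static functional connectivity: Pearson correlation between region time series.
    fc = np.corrcoef(bold)                                          # (n, n)

    # X^(1): each region's FC profile plus the demographics broadcast to every node.
    demo = np.tile(np.asarray(demographics, dtype=float), (n, 1))   # (n, d)
    x_static = np.concatenate([fc, demo], axis=1)                   # (n, n + d), so m = n + d

    # X^(2): the BOLD series itself, one row per region.
    x_temporal = bold                                               # (n, T)

    # Edge set E: off-diagonal pairs whose absolute correlation exceeds the threshold.
    mask = (np.abs(fc) > edge_threshold) & ~np.eye(n, dtype=bool)
    edges = np.argwhere(mask)                                       # (|E|, 2) index pairs

    return x_static, x_temporal, edges
```

With the AAL-116 parcellation and, say, two demographic variables, `x_static` would have shape (116, 118).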
Key Designs
RG-Fusion: Residual Gated Fusion Module
A dual-pathway parallel processing scheme is adopted:
- Temporal pathway: BOLD signal → Transformer Encoder producing \(\mathbf{H}_{\text{temp}}\) → concatenated with static features → dual-branch graph convolution (SAGEConv + GATConv) producing \(\mathbf{Z}_{\text{temp}}\)
- Static pathway: Static features → MLP → Gate module for adaptive fusion with temporal features → GATConv producing \(\mathbf{Z}_{\text{static}}\)
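- Gate mechanism: For two feature streams \(\mathbf{Z}_1\) and \(\mathbf{Z}_2\), \(\mathbf{G} = \sigma(\mathbf{W}_g[\mathbf{Z}_1 \,\|\, \mathbf{Z}_2] + \mathbf{b}_g)\) and \(\mathbf{Z}_{\text{fused}} = \mathbf{G} \odot \mathbf{Z}_1 + (1-\mathbf{G}) \odot \mathbf{Z}_2\), where \([\cdot \,\|\, \cdot]\) denotes concatenation and \(\sigma\) the sigmoid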
- Residual connection: FeatureAttention → NodeAttention produces \(\mathbf{H}_{\text{attn}}\), which is then residual-gated with \(\mathbf{Z}_{\text{temp}}\); the result is concatenated with \(\mathbf{Z}_{\text{static}}\) and passed through a variational encoder to yield \(\mathbf{Z}_{\text{ve}}\)
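The gate above is a per-node, per-channel convex interpolation between two feature streams. A minimal NumPy sketch, assuming a row-vector convention (features multiply the weights on the right) and that \(\mathbf{Z}_1\) and \(\mathbf{Z}_2\) share dimension \(d\); the parameter shapes are our assumptions:

```python
import numpy as np


def sigmoid(x):
    return 1.0 / (1.0 + np.exp(-x))


def gated_fusion(z1, z2, w_g, b_g):
    """G = sigmoid(W_g [Z1 || Z2] + b_g);  Z_fused = G * Z1 + (1 - G) * Z2.

    z1, z2: (n, d) node features from the two streams being fused.
    w_g:    (2d, d) gate weights (row-vector convention).
    b_g:    (d,) gate bias.
    """
    concat = np.concatenate([z1, z2], axis=-1)   # [Z1 || Z2], shape (n, 2d)
    gate = sigmoid(concat @ w_g + b_g)           # each entry lies in (0, 1)
    return gate * z1 + (1.0 - gate) * z2, gate
```

Because \(\mathbf{G}\) is bounded in (0, 1), the fused feature always lies between the two inputs channel by channel; a zero-initialized gate starts as a plain average, and the model learns where to lean on one stream over the other.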
Design Motivation: The temporal pathway targets low-frequency BOLD oscillations (0.01–0.08 Hz), which the authors regard as the most diagnostically informative band in depression fMRI.
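The note does not say how the 0.01–0.08 Hz band is isolated. As one illustration, here is a minimal ideal (brick-wall) FFT band-pass in NumPy, assuming a TR of 2 s (the value quoted under Limitations); real preprocessing pipelines typically use smoother filters to limit ringing, so treat this as a sketch only.

```python
import numpy as np


def ideal_bandpass(bold, tr=2.0, low=0.01, high=0.08):
    """Zero all Fourier components of each region's series outside [low, high] Hz.

    bold: (n, T) array, one row per region; tr: repetition time in seconds.
    """
    bold = np.asarray(bold, dtype=float)
    t = bold.shape[-1]
    freqs = np.fft.rfftfreq(t, d=tr)          # frequency of each rFFT bin, in Hz
    spectrum = np.fft.rfft(bold, axis=-1)
    keep = (freqs >= low) & (freqs <= high)
    spectrum[..., ~keep] = 0.0                 # drop DC, slow drift and high-frequency bins
    return np.fft.irfft(spectrum, n=t, axis=-1)
```

With TR = 2 s the sampling rate is 0.5 Hz, so the Nyquist frequency is 0.25 Hz; this is why the high-frequency comparison band reported under Key Findings ends at 0.25 Hz.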
HC-Pooling: Hierarchical Circuit Encoding
- Circuit assignment: 116 AAL regions are assigned to five circuits (DMN/SN/FPN/LN/RN) based on neuroanatomical knowledge
- Adjacency matrix reconstruction: Subject-level FC, MDD group-mean FC, and HC group-mean FC (\(\mathbf{A}_1, \mathbf{A}_2, \mathbf{A}_3\)) are fused via learnable gating: \(\mathbf{A}^{(c_j)} = \sum_{k=1}^{3} \alpha_k \mathbf{A}_k\), with \(\boldsymbol{\alpha} = \text{softmax}(\text{MLP}(\mathbf{Z}))\)
- Top-down hierarchical assignment: Gumbel-Softmax assigns nodes within each circuit to three levels (high-level integration / intermediate processing / primary processing)
- Bottom-up hierarchical aggregation: ChildSumTreeLSTM aggregates representations bottom-up to obtain per-circuit embeddings \(\mathbf{H}_{\text{DMN}}, \ldots, \mathbf{H}_{\text{RN}}\)
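The first two HC-Pooling steps can be sketched in NumPy as follows. This is an illustration under stated assumptions, not the authors' code: `logits` in `fuse_adjacency` stands in for \(\text{MLP}(\mathbf{Z})\), the level order [high-level, intermediate, primary] is our choice, and `tau` is the usual Gumbel-Softmax temperature. In a real model this would live in an autodiff framework so that gradients flow through the relaxed sample.

```python
import numpy as np


def softmax(x, axis=-1):
    x = x - x.max(axis=axis, keepdims=True)   # stabilise before exponentiating
    e = np.exp(x)
    return e / e.sum(axis=axis, keepdims=True)


def fuse_adjacency(logits, a_subject, a_mdd, a_hc):
    """A^(c_j) = sum_k alpha_k A_k with alpha = softmax(logits).

    `logits` (3 scores, one per FC source) is a stand-in for MLP(Z).
    """
    alpha = softmax(np.asarray(logits, dtype=float))
    return alpha[0] * a_subject + alpha[1] * a_mdd + alpha[2] * a_hc


def gumbel_softmax(logits, tau=1.0, rng=None, hard=False):
    """Sample a (relaxed) one-hot level assignment for every node.

    logits: (n_nodes, 3) scores for the three hierarchy levels.
    """
    rng = np.random.default_rng() if rng is None else rng
    u = rng.uniform(low=1e-10, high=1.0, size=logits.shape)
    gumbel = -np.log(-np.log(u))                   # Gumbel(0, 1) noise
    y = softmax((logits + gumbel) / tau, axis=-1)  # relaxed sample on the simplex
    if hard:
        # Forward pass uses the argmax one-hot; with autodiff, a straight-through
        # estimator would route gradients through the relaxed `y` instead.
        y = np.eye(logits.shape[-1])[y.argmax(axis=-1)]
    return y
```

Lower `tau` pushes the relaxed samples toward one-hot vectors, which makes the node-to-level assignments more discrete at the cost of noisier gradients.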
VLCA: Variational Latent Causal Attention
- Standard attention: Q/K/V are computed over the five circuit embeddings to obtain real attention weights \(\mathbf{A}^{\text{real}}\)
- Variational encoding: Attention-weighted representations are encoded into a continuous latent space: \(\mathbf{z}^{\text{real}} = \boldsymbol{\mu}^{\text{real}} + \boldsymbol{\sigma}^{\text{real}} \odot \boldsymbol{\epsilon}\)
- Counterfactual reasoning: The attention matrix is replaced with an identity matrix (severing inter-circuit interactions) to obtain \(\mathbf{z}^{\text{cf}}\)
- Causal effect estimation: \(\mathbf{y}^{\text{effect}} = f_{\text{pred}}(\mathbf{z}^{\text{real}}) - f_{\text{pred}}(\mathbf{z}^{\text{cf}})\)
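The factual/counterfactual comparison can be sketched as below. Assumptions beyond the note: all weights in `params` are hypothetical, the predictor \(f_{\text{pred}}\) is a mean-pool over circuits followed by a logistic readout, and the same noise \(\boldsymbol{\epsilon}\) is reused in both branches so that the attention matrix is the only thing that differs. The attention dimension of 64 follows the training details below.

```python
import numpy as np


def softmax(x, axis=-1):
    x = x - x.max(axis=axis, keepdims=True)
    e = np.exp(x)
    return e / e.sum(axis=axis, keepdims=True)


def vlca_effect(h, params, eps):
    """Factual vs. counterfactual prediction over circuit embeddings.

    h:      (5, d) embeddings for DMN, SN, FPN, LN, RN.
    params: hypothetical weights w_q, w_k, w_v (d, d_k), w_mu, w_logvar (d_k, d_z),
            w_pred (d_z,).
    eps:    (5, d_z) standard-normal noise shared by both branches.
    """
    q, k, v = h @ params["w_q"], h @ params["w_k"], h @ params["w_v"]
    a_real = softmax(q @ k.T / np.sqrt(k.shape[-1]))   # (5, 5) inter-circuit attention
    a_cf = np.eye(h.shape[0])                           # counterfactual: no cross-circuit flow

    def encode_and_predict(attn):
        r = attn @ v                                     # attention-weighted representations
        mu, logvar = r @ params["w_mu"], r @ params["w_logvar"]
        z = mu + np.exp(0.5 * logvar) * eps              # reparameterisation: mu + sigma * eps
        logit = z.mean(axis=0) @ params["w_pred"]        # hypothetical readout over circuits
        return 1.0 / (1.0 + np.exp(-logit))              # predicted probability of MDD

    y_real, y_cf = encode_and_predict(a_real), encode_and_predict(a_cf)
    return y_real - y_cf, a_real
```

A useful sanity check on this design: if all five circuit embeddings are identical, every row of the attention output equals \(\mathbf{V}\)'s single row, so severing interactions changes nothing and the estimated effect is zero.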
Loss & Training
The total loss is a weighted sum of four terms:
- \(\mathcal{L}_{\text{cls}}\): Binary cross-entropy for MDD/HC classification
- \(\mathcal{L}_{\text{kl}}\): KL regularization for variational encoding
- \(\mathcal{L}_{\text{VLCA}}\): Classification loss + KL divergence on causal effects
- \(\mathcal{L}_{\text{mse}}\): Constrains the learned adjacency matrix to be consistent with group-level FC priors
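A minimal sketch of the weighted objective \(\mathcal{L} = \mathcal{L}_{\text{cls}} + \lambda_{\text{kl}}\mathcal{L}_{\text{kl}} + \lambda_{\text{VLCA}}\mathcal{L}_{\text{VLCA}} + \lambda_{\text{mse}}\mathcal{L}_{\text{mse}}\). The \(\lambda\) values are placeholders (the note gives no weights), the KL term assumes a standard-normal prior, and \(\mathcal{L}_{\text{VLCA}}\) is passed in precomputed because its exact form is not specified here.

```python
import numpy as np


def bce(p, y, eps=1e-7):
    """Binary cross-entropy averaged over subjects."""
    p = np.clip(p, eps, 1 - eps)   # avoid log(0)
    return float(-np.mean(y * np.log(p) + (1 - y) * np.log(1 - p)))


def kl_standard_normal(mu, logvar):
    """KL(N(mu, sigma^2) || N(0, I)), summed over latent dims, averaged over rows."""
    return float(np.mean(0.5 * np.sum(np.exp(logvar) + mu**2 - 1.0 - logvar, axis=-1)))


def total_loss(p, y, mu, logvar, l_vlca, a_learned, a_prior,
               lam_kl=1e-3, lam_vlca=0.5, lam_mse=0.1):
    """L = L_cls + lam_kl * L_kl + lam_vlca * L_VLCA + lam_mse * L_mse (placeholder lambdas)."""
    l_cls = bce(p, y)
    l_kl = kl_standard_normal(mu, logvar)
    l_mse = float(np.mean((a_learned - a_prior) ** 2))   # learned adjacency vs. group-level FC prior
    return l_cls + lam_kl * l_kl + lam_vlca * l_vlca + lam_mse * l_mse
```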
Training uses the Adam optimizer with gradient clipping and dynamic scheduling of regularization weights. Hidden dimension is 128; causal attention uses a single head of dimension 64. All experiments are conducted on an NVIDIA RTX 4090.
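The exact clipping threshold and regularization schedule are not given. The sketch below shows global-norm gradient clipping (essentially the rule `torch.nn.utils.clip_grad_norm_` applies) and a hypothetical linear warm-up that could drive a weight such as \(\lambda_{\text{kl}}\); the warm-up form and `warmup_steps` are assumptions.

```python
import numpy as np


def clip_by_global_norm(grads, max_norm=1.0):
    """Rescale a list of gradient arrays so their joint L2 norm is at most max_norm."""
    total = np.sqrt(sum(float(np.sum(g * g)) for g in grads))
    scale = min(1.0, max_norm / (total + 1e-12))   # never scale gradients up
    return [g * scale for g in grads], total


def linear_warmup(step, target, warmup_steps=1000):
    """Hypothetical schedule: ramp a regularisation weight linearly from 0 to `target`."""
    return target * min(1.0, step / warmup_steps)
```

Warming up the KL-type weights is a common way to keep the variational terms from collapsing the latent space before the classifier has learned anything useful.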
Key Experimental Results
Main Results
Dataset: REST-meta-MDD, 16 acquisition sites, 1,601 subjects (830 MDD + 771 HC), AAL-116 parcellation.
| Model | AUC (%) | ACC (%) | SEN (%) | SPE (%) | F1 (%) |
|---|---|---|---|---|---|
| LCCAF | 75.6 (1.0) | 70.2 (8.3) | 69.7 (2.7) | 70.7 (2.1) | — |
| BPI-GNN | — | 73.0 (1.0) | — | — | 72.0 (1.0) |
| GAT-Baseline | 71.5 (3.2) | 67.7 (2.7) | 77.5 (9.1) | 57.2 (9.4) | 71.2 (3.3) |
| NH-GCAT | 78.5 (1.7) | 73.8 (1.4) | 76.4 (5.8) | 71.0 (6.6) | 75.0 (1.8) |
| Gain | +2.9 | +0.8 | −1.1 | +0.3 | +2.4 |
Leave-one-site-out cross-validation (LOSO-CV) for cross-site generalization: the weighted average ACC across the 16 held-out sites is 73.3%, surpassing CI-GNN (69.2%) and BrainIB (68.8%).
| Method | Weighted Avg. ACC |
|---|---|
| CI-GNN | 69.2% |
| BrainIB | 68.8% |
| NH-GCAT | 73.3% |
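A minimal sketch of the LOSO protocol and the site-weighted accuracy. Weighting each site by its number of held-out subjects is an assumption (the note does not state the weights); under that choice the weighted average equals the pooled accuracy over all held-out subjects.

```python
import numpy as np


def leave_one_site_out(site_ids):
    """Yield (held_out_site, train_idx, test_idx) for each acquisition site."""
    site_ids = np.asarray(site_ids)
    for site in np.unique(site_ids):
        test = np.flatnonzero(site_ids == site)
        train = np.flatnonzero(site_ids != site)
        yield site, train, test


def weighted_site_accuracy(per_site_acc, per_site_n):
    """Average per-site accuracies weighted by the number of test subjects per site."""
    acc = np.asarray(per_site_acc, dtype=float)
    n = np.asarray(per_site_n, dtype=float)
    return float(np.sum(acc * n) / np.sum(n))
```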
Ablation Study
| Model Variant | AUC (%) | ACC (%) | SPE (%) | F1 (%) |
|---|---|---|---|---|
| GAT-Baseline | 71.5 | 67.7 | 57.2 | 71.2 |
| + RG-Fusion | 74.8 (+3.3) | 70.2 (+2.5) | 70.6 (+13.4) | 70.5 |
| + VLCA | 75.9 (+1.1) | 72.0 (+1.8) | 68.2 | 73.6 (+3.1) |
| + HC-Pooling (full) | 78.5 (+7.0) | 73.8 (+6.1) | 71.0 (+13.8) | 75.0 (+3.8) |
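Parenthesized gains in the "+ RG-Fusion" and "+ VLCA" rows are relative to the preceding row, whereas those in the full-model row are cumulative relative to GAT-Baseline. Incremental addition of the three modules yields a cumulative gain of 7.0 percentage points in AUC and 13.8 percentage points in specificity (statistically significant, p < 0.05).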
Key Findings
- Low-frequency oscillation validation: RG-Fusion achieves AUC = 0.742 on low-frequency inputs (0.01–0.08 Hz), significantly higher than 0.679 on high-frequency inputs (0.1–0.25 Hz) (p = 0.0037).
- Hierarchical assignment: HC-Pooling finds that DMN regions (e.g., medial superior frontal gyrus) in MDD patients are over-assigned to the high level, consistent with pathological rumination; FPN regions show reduced high-level assignment, suggesting impaired cognitive control.
- Causal circuit analysis: VLCA reveals MDD-related patterns including weakened DMN→SN regulation and abnormally enhanced RN→DMN input, corresponding to core depressive symptoms.
Highlights & Insights
- Three-scale modeling paradigm: The authors present this as the first framework to incorporate neuroscientific priors simultaneously at the region, circuit, and network levels, whereas prior work applies constraints at a single scale only.
- Counterfactual causal reasoning: VLCA estimates causal effects by "severing inter-circuit interactions" in a counterfactual manner, which is more principled than post-hoc attention visualization.
- Strong clinical interpretability: The outputs of all three modules (frequency sensitivity, hierarchical level assignments, and inter-circuit causal effects) can be mapped to established neuroscientific findings rather than remaining opaque.
Limitations & Future Work
- Validation is limited to the REST-meta-MDD dataset; generalizability to other psychiatric conditions (e.g., autism via ABIDE) has not been tested.
- The five-circuit parcellation is based on prior knowledge; different disorders may require different circuit definitions.
- Temporal resolution is constrained by the fMRI TR (~2 s), precluding the capture of faster neural dynamics.
- Performance is unstable across some sites (e.g., Site 8 drops by 11.2%); cross-site domain shift remains an open challenge.
Related Work & Insights
- NH-GCAT belongs to the same brain-graph GNN family as BrainIB (information bottleneck) and CI-GNN (causal inference), but differs in integrating priors across all three spatial scales.
- The use of ChildSumTreeLSTM for hierarchical aggregation is a transferable idea applicable to other graph problems with known hierarchical structure.
- The framework can be extended to other neuroimaging modalities (EEG, DTI) and other psychiatric or neurological conditions such as Parkinson's disease and PTSD.
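To illustrate the transferable ChildSumTreeLSTM idea, here is a minimal NumPy sketch of one Child-Sum Tree-LSTM node update as introduced by Tai et al. (2015). How NH-GCAT builds its trees from the Gumbel-Softmax level assignments is not described here, so the tree structure (children as lower-level nodes, parent as a higher-level node) and all weights are assumptions.

```python
import numpy as np


def sigmoid(x):
    return 1.0 / (1.0 + np.exp(-x))


def child_sum_tree_lstm_node(x, child_h, child_c, W, U, b):
    """One Child-Sum Tree-LSTM update.

    x:        (d_in,) input features of this node.
    child_h:  (k, d) hidden states of the k children (k may be 0 for a leaf).
    child_c:  (k, d) memory cells of the children.
    W, U, b:  dicts keyed by gate 'i', 'f', 'o', 'u' with shapes (d_in, d), (d, d), (d,).
    """
    d = b["i"].shape[0]
    child_h = child_h.reshape(-1, d)
    child_c = child_c.reshape(-1, d)
    h_sum = child_h.sum(axis=0)                              # order-invariant child summary

    i = sigmoid(x @ W["i"] + h_sum @ U["i"] + b["i"])        # input gate
    o = sigmoid(x @ W["o"] + h_sum @ U["o"] + b["o"])        # output gate
    u = np.tanh(x @ W["u"] + h_sum @ U["u"] + b["u"])        # candidate update
    f = sigmoid(x @ W["f"] + child_h @ U["f"] + b["f"])      # one forget gate per child, (k, d)

    c = i * u + np.sum(f * child_c, axis=0)                  # gated sum of child memories
    h = o * np.tanh(c)
    return h, c
```

Summing over children makes the update independent of child order and count, which suits brain-region hierarchies where a level can contain any number of unordered regions.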
Rating
- Novelty: ★★★★☆ — Three-scale integration of neuroscientific priors represents a significant methodological contribution
- Technical Depth: ★★★★☆ — All three modules are carefully designed with rigorous mathematical formulation
- Experimental Thoroughness: ★★★☆☆ — Only one dataset, but ablation and interpretability analyses are very thorough
- Writing Quality: ★★★★☆ — Well-structured with good coordination between text and figures