GALAX: Graph-Augmented Language Model for Explainable Reinforcement-Guided Subgraph Reasoning in Precision Medicine

Conference: ICLR 2026
OpenReview: https://openreview.net/forum?id=ADFXCeYXvR
Code: TBD
Area: Precision Medicine / Graph-Augmented LLM / Reinforcement Learning
Keywords: Subgraph Reasoning, Process Reward Model, Multi-omics, CRISPR Target Discovery, GNN-LLM Fusion

TL;DR

GALAX treats a pre-trained GNN as a "process judge," using reinforcement learning to guide an LLM in incrementally constructing disease-related subgraphs. This enables explainable, patient-specific cancer target prediction without the need for step-by-step annotations.

Background & Motivation

  • Background: Precision medicine aims to identify key signaling pathways and therapeutic targets driving diseases. This requires the simultaneous utilization of three types of information: quantitative multi-omics features (genomics/transcriptomics/proteomics), the topological structure of biological networks, and literature-scale textual knowledge. CRISPR large-scale knockout screens provide a reliable experimental "gold standard" for targets.
  • Limitations of Prior Work: Traditional differential expression or essentiality scoring fails to model the hierarchy and cross-modal dependencies of molecular networks. Graph models excel at outcome prediction but lack the structured supervision required for target prioritization. Knowledge Graph-based RAG (RoG, SubgraphRAG, GNN-RAG, G-Retriever) typically focuses only on the accuracy of the final answer; the retrieved subgraphs are often noisy, oversized, and lack ground-truth mechanistic structures. Crucially, they generally lose numerical omics signals, leading to a lack of cell-line-specific information.
  • Key Challenge: Process Reward Models (PRM) could theoretically provide fine-grained supervision for intermediate reasoning steps, but they face three major bottlenecks: coarse step definitions, difficulty in verifying intermediate correctness, and vulnerability to reward hacking. In biomedicine, these issues are amplified: Text-Omic Numerical Graphs (TNG) lack step-by-step ground-truth labels, and the combinatorial explosion of biological pathways makes exhaustive planning or retrieval impractical.
  • Goal: To unify numerical omics, textual knowledge, and biological topology into a reinforcement learning framework without explicit step-wise labels, using subgraph reasoning as a bridge to connect "numerical evidence, topological knowledge, and linguistic context" to output accurate and explainable patient-specific targets.
  • Core Idea: Utilize a pre-trained GNN as a source of process rewards—allowing the GNN to score the "biological plausibility and cancer relevance" of partial signal subgraphs generated step-by-step by the LLM, supplemented by schema-based rule validation. This achieves process-level supervision without annotations, translating linguistic reasoning into explainable graph construction.

Method

Overall Architecture

GALAX wraps a reinforcement-guided graph generator between two LLM stages. First, an initial LLM \(f_{init}\) performs rough reasoning over the multi-omics profile to extract candidate targets. Next, a reinforcement learning graph generator \(\pi\), guided by step-wise rewards from a pre-trained graph foundation model \(g\), incrementally builds an explainable subgraph \(G^\dagger\). Finally, a second-stage LLM \(f_{final}\) reads both the initial output and the generated subgraph to make refined predictions, so \(G^\dagger\) itself serves as a transparent basis for the reasoning. The underlying data structure is the TOSG (Text-Omic Signaling Graph), and the Target-QA benchmark supplies CRISPR-based ground-truth supervision.

flowchart LR
    A[Multi-omics Profile<br/>Top-K Gene/Trans/Prot] --> B[Retrieve Disease-related Proteins<br/>+ h-hop neighbors]
    A --> C[f_init Initial LLM<br/>Rough reasoning extracted candidates]
    B --> D[π Graph Generator<br/>Edge-by-edge subgraph construction]
    C --> D
    E[g Pre-trained GNN<br/>Process Reward GPRM] -.Scoring.-> D
    F[Schema Rule Validation] -.Penalty.-> D
    D --> G[Optimal Subgraph G†]
    G --> H[f_final Stage-2 LLM<br/>Expert mode description + Prediction]
    H --> I[Top-γ CRISPR Targets]

Key Designs

1. TOSG welds three modalities into one graph: placing omics values, text, and topology in the same coordinate system. The authors construct the Text-Omic Signaling Graph \(G = \{\mathcal{X}^{(0)}, \mathcal{T}, \mathcal{V}, \mathcal{E}\}\), where nodes are categorized into four types: promoters, genes, transcripts, and proteins (\(|\mathcal{V}| = m^{(pm)} + m^{(g)} + m^{(t)} + m^{(p)} = M\)). Omics features for each sample are concatenated across categories \(X_n^{(0)} = [x_n^{(pm)} \oplus x_n^{(g)} \oplus x_n^{(t)} \oplus x_n^{(p)}]\). The graph is split into an internal signaling subgraph \(G^{(in)}\) (propagating along the central dogma: promoter→gene→transcript→protein) and a protein interaction subgraph \(G^{(PPI)}\). This preserves numerical evidence from DepMap while attaching textual descriptions and topological relationships from BioMedGraphica, providing an aligned input for both GNN and LLM.

2. Dual foundation model pre-training: Training the GNN as a "judge of cancer relevance" and the LLM as a "proponent of biological terminology." The graph side follows a two-stage process: The first stage applies Bernoulli masking to PPI edges \(\mathcal{E}_{mask} \sim \text{Bernoulli}(p)\), followed by cross-modal encoding and internal/global masked propagation to learn gene regulatory patterns. The second stage transfers pre-trained parameters \(\theta_G^{pre}\) to a downstream classifier, predicting disease types using \(\hat{Y}^{(0)} = \arg\max_o \text{Softmax}[\text{MLP}_G(f_G(\cdot))]\) to minimize cross-entropy. This pre-trained GNN becomes the "authoritative judge" for process rewards (achieving 99.46%/96.15% train/test accuracy in disease classification). The language side involves pre-training LLaMA3-8B on biomedical corpora to master terminology and structures related to protein interactions and disease-target relationships.

3. Modeling subgraph construction as RL with edge-by-edge addition: State-Action-Policy revolving around "adding one edge." The state is the current graph \(G_n^{(i)} = (V_n^{(i)}, E_n^{(i)})\). The action \(\Delta_n^{(i)} = (v_{src}^i, v_{tgt}^i)\) adds an edge under a feasibility mask. The policy first uses a message propagation module to obtain node embeddings \(X_n^{(i)} = \pi_{MSG}(G_n^{(i)}, X_n^{(cand)})\), then samples source and target nodes using two masked probability functions: \(v_{src}^i \sim \pi_{SRC}(X_n^{(i)}, M_{SRC})\) and \(v_{tgt}^i \sim \pi_{TGT}(X_n^{(i)}, M_{TGT}; v_{src}^i)\). The mask \(M_{SRC}\) restricts source nodes to those already in the graph, while \(M_{TGT}\) excludes the source node itself. The starting node set \(V_n^{(start)}\) is selected based on priority: Top-\(\eta\) disease-related proteins, then Top-\(\eta\) initial candidates, or random sampling of \(\eta\) omics nodes. The candidate set is \(V_n^{(cand)} = V_n^{(init)} \cup V_n^{(sub)} \cup V_n^{(omic)}\).

4. GPRM Process Reward = GNN immediate feedback + rollout future simulation + rule penalty, replacing step-wise labels. This is the core mechanism of the paper. Each step's reward first checks the pre-trained classifier's probability for the target class \(o^\star\) on the current subgraph, then uses \(L\) rollout simulations for look-ahead:

\[R_n^{(i)} = g_{o^\star}(G_n^{(i+1)}) - \frac{1}{|O|} + \lambda \cdot \frac{1}{L}\sum_{\ell=1}^{L}\left(g_{o^\star}(\text{Rollout}_\ell(G_n^{(i+1)})) - \frac{1}{|O|}\right)\]

A rule-based term \(R_{rule}\) based on BioMedGraphica relations penalizes illegal edges violating the schema, yielding \(R_{total}^{(i)} = R_n^{(i)} + \lambda_{rule} \cdot R_{rule}(G_n^{(i+1)})\). A greedy acceptance strategy is used: an edge is accepted only if \(R_{total}^{(i)} > 0\). The generator is optimized using reward-weighted cross-entropy \(L_{step} = -R_{total}^{(i)}[\text{CE}(v_{src}^i, \pi_{SRC}) + \text{CE}(v_{tgt}^i, \pi_{TGT})]\). After multiple samplings, the optimal subgraph \(G_n^\dagger\) is retained. This design cleverly externalizes the "correctness of intermediate steps" to an independently trained GNN, bypassing the need for manual labels in PRM.

5. Subgraph reinjection for final answer: Translating the graph into expert-level text prompts to fine-tune the LLM for target output. The optimal subgraph \(G_n^\dagger\) is verbalized into structured text using an "expert mode," which is appended to the original query to form the final prompt \(P_n^{(final)}\). The second-stage LLM generates Top-\(\gamma\) targets autoregressively via \(\xi_{\theta_{final}}(\hat{A}_n | Q_n, G_n^\dagger)\), aligned with ground truth using token-level cross-entropy \(L_{final}\). Finally, NER is used to extract the predicted proteins for comparison with reference targets. The subgraph serves as both a supervision signal and a readable mechanistic explanation.
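
The edge-by-edge loop from designs 3 and 4 can be sketched in a few lines. This is a toy illustration under stated assumptions, not the authors' implementation: `masked_sample`, `step_reward`, `build_subgraph`, `toy_judge`, `toy_rollout`, `toy_rule`, `DRIVERS`, and `FORBIDDEN` are all hypothetical names. The learned policies \(\pi_{SRC}\) and \(\pi_{TGT}\) are replaced by uniform logits, the pre-trained GNN \(g\) by a hand-written scoring function, and no policy update is performed.

```python
import numpy as np

def masked_sample(logits, mask, rng):
    """Sample one index from softmax(logits), restricted to entries where mask is True."""
    masked = np.where(mask, logits, -np.inf)
    probs = np.exp(masked - masked[mask].max())  # masked-out entries become exp(-inf) = 0
    return int(rng.choice(len(logits), p=probs / probs.sum()))

def step_reward(judge, graph, rollout, n_classes, lam=1.0, n_rollouts=4):
    """R = g(G') - 1/|O| + lam * mean over rollouts of (g(Rollout(G')) - 1/|O|)."""
    baseline = 1.0 / n_classes
    immediate = judge(graph) - baseline
    future = np.mean([judge(rollout(graph)) - baseline for _ in range(n_rollouts)])
    return float(immediate + lam * future)

def build_subgraph(start, n_nodes, judge, rollout, rule, rng,
                   n_steps=50, n_classes=2, lam=1.0, lam_rule=1.0):
    """Grow a subgraph one edge at a time, keeping an edge only if R_total > 0."""
    nodes, edges = set(start), set()
    logits = np.zeros(n_nodes)  # assumption: uniform stand-in for learned policy scores
    ids = np.arange(n_nodes)
    for _ in range(n_steps):
        src = masked_sample(logits, np.isin(ids, list(nodes)), rng)  # M_SRC: node already in graph
        tgt = masked_sample(logits, ids != src, rng)                 # M_TGT: exclude the source
        grown = edges | {(src, tgt)}
        r_total = step_reward(judge, grown, rollout, n_classes, lam) + lam_rule * rule(grown)
        if r_total > 0:  # greedy acceptance
            edges, nodes = grown, nodes | {tgt}
    return nodes, edges

# Toy stand-ins (assumptions): a real judge would be the pre-trained GNN classifier.
DRIVERS, FORBIDDEN = {3, 4}, 0

def toy_judge(edges):
    """Pretend target-class probability: rises with edges that reach a 'driver' node."""
    hits = sum(1 for _, t in edges if t in DRIVERS)
    return min(0.95, 0.4 + 0.2 * hits)

def toy_rollout(edges):
    """Identity rollout: the simulated future graph is the current graph."""
    return edges

def toy_rule(edges):
    """Schema penalty: -1 if any edge points at the forbidden node, else 0."""
    return -1.0 if any(t == FORBIDDEN for _, t in edges) else 0.0

nodes, edges = build_subgraph({1}, 6, toy_judge, toy_rollout, toy_rule,
                              np.random.default_rng(0))
print(sorted(edges))
```

In this sketch the reward is measured against the chance level \(1/|O|\) rather than against the previous step's score, so once the toy judge is above chance a neutral edge can also pass the \(R_{total} > 0\) test. Whether that matters in practice depends on the real judge and on \(\lambda_{rule}\).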

Key Experimental Results

Main Results (Precision / Recall, Excerpt)

| Model | Overall Precision ↑ | Overall Recall ↑ | LUAD Recall | BRCA Recall |
|---|---|---|---|---|
| M2T (Traditional Multi-omics) | 0.0016 | 0.0011 | 0.0014 | 0.0000 |
| L3-FT(QA)+Omics | 0.5250 | 0.4959 | 0.4905 | 0.4856 |
| G-Retriever+pre-GAT | 0.4763 | 0.3929 | 0.3881 | 0.3772 |
| RoG | 0.5248 | 0.4726 | 0.4562 | 0.4311 |
| SubgraphRAG | 0.5280 | 0.4617 | 0.4448 | 0.3917 |
| GNN-RAG | 0.5258 | 0.4735 | 0.5052 | 0.4389 |
| Ours (GALAX) | 0.5472 | 0.5332 | 0.5157 | 0.5533 |
| GALAX (Qwen2.5-7B) | 0.5445 | 0.5405 | 0.5462 | 0.5206 |

  • Dataset Target-QA: 363 cancer cell line QA pairs, 80/20 split, 4 random seeds. Each answer consists of Top-100 CRISPR prioritized targets.
  • GALAX's largest gain is in Recall: 0.5332 overall versus 0.4735 for the strongest RAG baseline (GNN-RAG), roughly +6 points. The gap is widest on BRCA (0.5533 versus 0.4389).
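
For readers who want to reproduce the style of these metrics, the following is a minimal sketch of per-sample precision and recall over a predicted target list. It assumes exact-match set overlap between predicted and reference protein names; the summary does not state the paper's exact matching or averaging protocol, and `precision_recall` and the gene lists below are hypothetical illustrations rather than benchmark data.

```python
def precision_recall(predicted, reference):
    """Set-overlap precision and recall for one sample (assumed exact name matching)."""
    pred, ref = set(predicted), set(reference)
    hits = len(pred & ref)
    precision = hits / len(pred) if pred else 0.0
    recall = hits / len(ref) if ref else 0.0
    return precision, recall

# Hypothetical example: 3 of 4 predictions appear among 5 reference targets.
predicted = ["KRAS", "TP53", "EGFR", "MYC"]
reference = ["KRAS", "EGFR", "BRAF", "PIK3CA", "MYC"]
p, r = precision_recall(predicted, reference)
print(f"precision={p:.2f} recall={r:.2f}")  # precision=0.75 recall=0.60
```

Averaging these per-sample values over the test split would give table-style numbers, under the stated assumption about matching.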

Hit@10 / Hit@5 (Excerpt)

| Model | Hit@10 ↑ | Hit@5 ↑ |
|---|---|---|
| L3-FT(QA)+Omics | 0.8693 | 0.8889 |
| RoG | 0.8450 | 0.8593 |
| SubgraphRAG | 0.8476 | 0.8624 |
| GNN-RAG | 0.8323 | 0.8656 |
| GALAX | 0.8815 | 0.9249 |
| GALAX (Qwen2.5-7B) | 0.8841 | 0.9079 |

Ablation Study (Language Axis × Graph Axis)

| Configuration | Function | Result |
|---|---|---|
| L3+Omics | LLaMA3 without task fine-tuning | Very poor |
| L3-FT(Med)+Omics | Biomedical text domain adaptation | Slight improvement |
| L3-FT(QA)+Omics | Target-QA task fine-tuning | Performance leap |
| +KG (Static KG) | Concatenating KG directly | Negligible gain or decrease |
| G-Retriever+pre-GAT | Subgraph retrieval via pre-trained GAT | Unstable (difficult to extract from millions of nodes) |
| +RL (Full GALAX) | Reinforcement-guided construction | Further ~2%–5% gain across metrics |

Key Findings

  • Task-adaptive fine-tuning is the performance watershed: moving from domain fine-tuning, L3-FT(Med), to Target-QA fine-tuning, L3-FT(QA), lifts performance from ~1% to ~52%, indicating that QA-style supervision matters more than raw biomedical text ingestion.
  • Static KG concatenation can be harmful: Feeding the Knowledge Graph directly to the LLM yielded little gain or even a loss, confirming that the "noisy subgraph + no mechanism ground-truth" retrieval paradigm is unreliable.
  • Reinforcement-guided process-level subgraph construction is the source of incremental gain: Adding RL on top of QA fine-tuning provided a stable +2%–5% across datasets, validating that the "GNN as process judge" route outperforms all graph-augmented RAG baselines.
  • Controllable Complexity: The training/inference complexity of GALAX is \(O(\kappa + M\varepsilon + M^2\varepsilon)\), comparable to RoG/GNN-RAG. Since candidates \(\eta \ll M\), the graph embedding term dominates the RL reward costs.

Highlights & Insights

  • Resolving the PRM bottleneck: The three pain points of PRM (step definition, verification, and reward hacking) stem from a lack of step-wise labels. GALAX uses an independently pre-trained, objective-oriented (disease classification) GNN as a judge. This provides verifiable intermediate signals and, because the judge is fixed, makes the reward harder for the policy to game. It is an elegant engineering solution.
  • Explainability as a "By-product" rather than "Post-processing": Subgraph \(G^\dagger\) serves simultaneously as the reasoning process, the supervision object, and the human-readable mechanistic explanation. It is inherently tied to the decision, unlike post-hoc attribution methods.
  • Practical Multi-modal Fusion: Numerical omics are not discarded (a point of criticism for RAG systems), but instead enter the TOSG alongside text and topology, preserving cell-line-specific information.
  • Rollout Foresight + Greedy Acceptance: Each step considers both immediate classification gain and simulated future trajectories to avoid nearsighted edge addition, while greedy acceptance discards any edge whose total reward is not positive.

Limitations & Future Work

  • Small Data Scale: Target-QA contains only 363 QA pairs and 336 pre-training samples. Furthermore, the task is simplified to binary classification (\(|O|=2\)) using only 1-hop neighbors. Generalization to more cancer types and deeper topologies requires further validation.
  • Judge Performance Ceiling: Process rewards rely entirely on the judgment of the pre-trained GNN (edge prediction AUC is only 64.4%). Biases in the judge may be amplified by reinforcement learning.
  • High Dependency on External Components: The pipeline involves GPT-4o-mini for NER, BioMedGraphica for integration, and BioBERT for text embedding. The long chain results in high reproducibility costs, and drift in any segment could affect results.
  • Sensitivity of Reward Hyperparameters: \(\lambda\), \(\lambda_{rule}\), the number of rollouts \(L\), and \(\eta\) are set empirically (often equal weight or fixed), lacking a systematic sensitivity analysis.
  • Clinical Gap: There is still a gap between CRISPR cell-line targets and real patient treatment response. The biological accuracy of the "explainable subgraphs" requires validation via wet-lab experiments.

Related Work & Insights

  • Graph-Augmented RAG (RoG, SubgraphRAG, GNN-RAG, G-Retriever): These are direct competitors. GALAX differs by actively generating subgraphs and scoring the process, rather than just optimizing final answers via subgraph retrieval.
  • Process Reward Models / Large Reasoning Models (PRM, StepGRPO, RLHF/PPO/GRPO): GALAX inherits the idea of "process-level supervision" but replaces manual or model-based annotations with a pre-trained GNN + rules to avoid reward hacking.
  • Multi-omics GNN (MOGONET, MoGCN): These provide precedents for graph-structure reasoning in cancer subtyping but typically focus on outcome prediction rather than explainable target prioritization.
  • Insight: For any reasoning task "lacking step-wise labels but having outcome labels," one can consider training an independent outcome classifier as a process judge. This transforms expensive step-wise labeling into a reward design problem—a paradigm applicable beyond biomedicine.

Rating

  • Novelty: ⭐⭐⭐⭐⭐ The idea of using a pre-trained GNN as a source for process rewards to bypass the PRM labeling dilemma is novel and self-consistent. It is an organic fusion of GNN, LLM, and RL.
  • Experimental Thoroughness: ⭐⭐⭐ Ablations across language and graph axes are clear, and complexity analysis is thorough. However, the small dataset size, task simplification, and lack of reward sensitivity analysis are drawbacks.
  • Writing Quality: ⭐⭐⭐⭐ The motivation is progressively developed, and notation is rigorous. Figure 3 provides a clear workflow; however, high notation density and heavy reliance on the appendix make the initial reading threshold high.
  • Value: ⭐⭐⭐⭐ This provides a new "generative + process supervision" path for explainable target discovery in precision medicine. The accompanying Target-QA benchmark also has reuse value.