WILTing Trees: Interpreting the Distance Between MPNN Embeddings
Conference: ICML 2025
arXiv: 2505.24642
Code: None
Area: Graph Learning/GNN Interpretability
Keywords: MPNN, Graph Distance, Interpretability, Optimal Transport, Weisfeiler-Leman, Graph Kernels
TL;DR
This paper finds that the embedding distances learned by MPNNs align with task-related functional distances rather than structural distances. Building on this, it proposes an optimal transport distance on the Weisfeiler-Leman Labeling Tree (WILT) to distill and interpret MPNN distances; the learned edge weights reveal that a small number of key subgraphs dominate the metric structure of the embedding space.
Background & Motivation
Although MPNNs perform very well on graph prediction tasks, the metric structure of their embedding space (i.e., the distances between graph embeddings) remains poorly understood. Prior studies (Chuang & Jegelka 2022; Böker et al. 2024) analyzed the generalization of MPNNs through structural distances. However, these structural distances ignore task information, and the resulting analyses require strong assumptions (such as Lipschitz-constant constraints or margin assumptions between classes).
The authors pose two core questions:
1. Does the distance learned by the MPNN, \(d_{\text{MPNN}}\), align with the task-related functional distance \(d_{\text{func}}\), and is this alignment key to high performance?
2. How do MPNNs learn such metric structures, and can they be distilled in an interpretable manner?
The authors find that the alignment between MPNN distances and structural distances does not consistently improve with training, nor does it correlate with performance; instead, alignment with functional distances is key. This finding motivates the authors to design WILT to distill and interpret the distance functions of MPNNs.
Method
Overall Architecture
The overall workflow consists of two phases:
1. Analysis Phase: Define a functional pseudometric \(d_{\text{func}}\) and an alignment metric \(\text{ALI}_k\), then verify that the trained MPNN distance indeed aligns with \(d_{\text{func}}\) and that the degree of alignment correlates strongly with predictive performance.
2. Distillation Phase: Construct the Weisfeiler-Leman Labeling Tree (WILT), define a trainable optimal transport distance \(d_{\text{WILT}}\) on it, and learn its edge weights by minimizing the mean squared error (MSE) to \(d_{\text{MPNN}}\). The learned edge weights directly reveal which WL colors (subgraph patterns) have the greatest impact on MPNN embedding distances.
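Because \(d_{\text{WILT}}\) is linear in the edge weights once the color counts are fixed, the distillation step can be viewed as a non-negative regression. The sketch below is a minimal illustration under simple assumptions: each graph pair is represented by its vector of absolute count differences \(|\nu_c^G - \nu_c^H|\), and the weights are fitted by projected (proximal) gradient descent on the MSE with an optional L1 penalty. The function name, solver, and step size are hypothetical, not the authors' training setup.

```python
from typing import List


def fit_wilt_weights(
    diffs: List[List[float]],  # diffs[p][c] = |nu_c^G - nu_c^H| for graph pair p and color c
    targets: List[float],      # d_MPNN(G, H) for the same graph pairs
    l1: float = 0.0,           # optional L1 penalty that pushes weights to exactly zero
    lr: float = 0.01,          # step size (illustrative, not tuned)
    steps: int = 5000,
) -> List[float]:
    """Minimize mean((sum_c w_c * diff_c - target)^2) + l1 * sum(w) subject to w >= 0."""
    n_pairs, n_colors = len(diffs), len(diffs[0])
    w = [0.0] * n_colors
    for _ in range(steps):
        # Gradient of the mean squared error with respect to each edge weight
        grad = [0.0] * n_colors
        for x, y in zip(diffs, targets):
            residual = sum(wc * xc for wc, xc in zip(w, x)) - y
            for c, xc in enumerate(x):
                grad[c] += 2.0 * residual * xc / n_pairs
        # Proximal step: soft-threshold by lr * l1 and project onto w >= 0
        w = [max(0.0, wc - lr * g - lr * l1) for wc, g in zip(w, grad)]
    return w


# Toy data whose targets are generated by weights (0.5, 0.0, 0.2)
pairs = [[1, 0, 0], [0, 1, 0], [0, 0, 1], [1, 1, 0], [0, 1, 1], [1, 0, 1]]
dists = [0.5, 0.0, 0.2, 0.5, 0.2, 0.7]
print(fit_wilt_weights(pairs, dists))           # close to [0.5, 0.0, 0.2]
print(fit_wilt_weights(pairs, dists, l1=0.05))  # middle weight is driven to exactly 0
```

Projection keeps every weight non-negative so the result stays a valid tree metric, and the L1 term is what lets a sparsified fit zero out most colors.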
Key Designs
Key Design 1: Functional Pseudometric and Alignment Metric
The authors define a functional pseudometric \(d_{\text{func}}\) that measures how far apart two graphs are in terms of their targets on the task at hand.
To measure how well the MPNN distance aligns with the functional distance, the \(\text{ALI}_k\) metric is defined as follows. For each graph \(G\), the average functional distance to its \(k\) nearest neighbors in the MPNN embedding space, \(A_k(G)\), is compared with the average functional distance to its non-neighbors, \(B_k(G)\). If \(B_k(G) > A_k(G)\), graphs that are close in the MPNN space are also functionally close, i.e., the MPNN distance aligns with the functional distance; \(\text{ALI}_k\) aggregates this comparison over all graphs in the dataset.
Experiments show that this metric increases significantly during training and correlates strongly with predictive performance (on Mutagenicity with \(k=1\), the Spearman correlation reaches 0.66 on the training set and 0.70 on the test set; see Table 2).
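A minimal sketch of the alignment score, under two assumptions that the text above does not pin down: \(d_{\text{func}}\) is taken as the label mismatch \(\mathbb{1}[y_G \neq y_H]\) for classification and \(|y_G - y_H|\) for regression, and \(\text{ALI}_k\) is taken as the dataset average of \(B_k(G) - A_k(G)\). The function names are hypothetical.

```python
from typing import List, Sequence


def d_func(y_g: float, y_h: float, task: str) -> float:
    # ASSUMED functional pseudometric: label mismatch for classification,
    # absolute target difference for regression
    if task == "classification":
        return 0.0 if y_g == y_h else 1.0
    return abs(y_g - y_h)


def ali_k(emb_dist: List[List[float]], labels: Sequence[float], k: int, task: str) -> float:
    """ASSUMED aggregation: mean over graphs of B_k(G) - A_k(G).

    emb_dist[g][h] is the MPNN embedding distance; requires len(labels) > k + 1.
    Positive values mean MPNN neighbours are functionally closer than non-neighbours.
    """
    n = len(labels)
    total = 0.0
    for g in range(n):
        # Rank all other graphs by MPNN distance (ties broken by index)
        others = sorted((h for h in range(n) if h != g), key=lambda h: (emb_dist[g][h], h))
        neighbours, non_neighbours = others[:k], others[k:]
        a_k = sum(d_func(labels[g], labels[h], task) for h in neighbours) / k
        b_k = sum(d_func(labels[g], labels[h], task) for h in non_neighbours) / len(non_neighbours)
        total += b_k - a_k
    return total / n


xs = [0.0, 0.1, 1.0, 1.1]                      # toy 1-D "embeddings"
dist = [[abs(a - b) for b in xs] for a in xs]  # distances between them
print(ali_k(dist, [0, 0, 1, 1], k=1, task="classification"))  # 1.0 (aligned labels)
print(ali_k(dist, [0, 1, 0, 1], k=1, task="classification"))  # -0.5 (misaligned labels)
```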
Key Design 2: WILT Distance and Linear-Time Computation
WILT is a weighted rooted tree built from the color hierarchy produced by the WL algorithm. Its nodes are all colors that appear during the WL iterations, and each edge links the colors that the same graph node takes at two adjacent iterations. Because a WL color is computed from the node's previous color together with its neighbors' colors, every color has a unique parent, so the hierarchy is indeed a tree. The WILT distance is the optimal transport distance between the WL colors of the nodes of two graphs, with the weighted path length on the tree as the ground metric.
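To make the construction concrete, here is a minimal sketch of 1-WL refinement that records, for every color, its parent color from the previous iteration (the WILT edges) together with the per-graph color counts. Assumptions: graphs are given as integer node labels plus adjacency lists, all graphs share one color table so that colors are comparable across graphs, and a virtual root (here `-1`) sits above the iteration-0 colors. All names are hypothetical.

```python
from collections import Counter
from typing import Dict, List, Tuple

Graph = Tuple[List[int], Dict[int, List[int]]]  # (initial node labels, adjacency lists)
ROOT = -1  # virtual root of the WILT


def build_wilt(graphs: List[Graph], iterations: int):
    """Run 1-WL jointly on all graphs with one shared color table.

    Returns per-graph colors[l][v] (l = 0..iterations) and the WILT as a child -> parent map.
    """
    table: Dict[tuple, int] = {}
    parent: Dict[int, int] = {}

    def color_of(signature: tuple, parent_color: int) -> int:
        # A new signature gets a fresh integer color whose WILT parent is parent_color
        if signature not in table:
            table[signature] = len(table)
            parent[table[signature]] = parent_color
        return table[signature]

    per_graph = [[[color_of(("init", x), ROOT) for x in labels]] for labels, _ in graphs]
    for _ in range(iterations):
        for colors, (labels, adj) in zip(per_graph, graphs):
            prev = colors[-1]
            # WL update: (own previous color, sorted multiset of neighbour colors)
            colors.append([
                color_of((prev[v], tuple(sorted(prev[u] for u in adj[v]))), prev[v])
                for v in range(len(labels))
            ])
    return per_graph, parent


def color_counts(colors: List[List[int]]) -> Counter:
    # nu_c: number of nodes carrying color c at the iteration where c appears
    return Counter(c for level in colors for c in level)


path = ([0, 0, 0], {0: [1], 1: [0, 2], 2: [1]})
triangle = ([0, 0, 0], {0: [1, 2], 1: [0, 2], 2: [0, 1]})
(path_cols, tri_cols), wilt = build_wilt([path, triangle], iterations=1)
print(wilt)                     # {0: -1, 1: 0, 2: 0}
print(color_counts(path_cols))  # Counter({0: 3, 1: 2, 2: 1})
print(color_counts(tri_cols))   # Counter({0: 3, 2: 3})
```

Including the node's own previous color in the signature is what guarantees that each new color has exactly one parent.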
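The key technical point is an equivalent closed form: because the ground metric is a path length on a tree, this optimal transport problem reduces to a weighted Manhattan distance
\[
d_{\text{WILT}}(G,H) = \sum_{c} w(e_c)\,\bigl|\nu_c^G - \nu_c^H\bigr|,
\]
where the sum runs over the non-root colors \(c\) of the WILT, \(w(e_c)\) is the weight of the edge \(e_c\) joining \(c\) to its parent, and \(\nu_c^G\) is the number of nodes of \(G\) with color \(c\). Since only colors occurring in \(G\) or \(H\) contribute, this reduces the computational complexity from \(O(n^3)\) for a generic optimal transport solver to \(O(|V_G|+|V_H|)\), i.e., linear time for a fixed number of WL iterations.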
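The equivalence between the optimal transport form and the weighted Manhattan form can be sanity-checked numerically. The sketch below assumes two graphs with the same number of nodes and unit mass per node; in that case optimal transport is attained by a one-to-one assignment of final-iteration colors (brute-forced here over permutations). It compares that cost with the Manhattan form, where \(\nu_c\) counts the nodes whose color path passes through \(c\). The tree, weights, and helper names are made up for illustration.

```python
from collections import Counter
from itertools import permutations
from typing import List

# Hypothetical WILT: child color -> parent color, the root is "r"; leaves are final-iteration colors
PARENT = {"a": "r", "b": "r", "a1": "a", "a2": "a", "b1": "b"}
WEIGHT = {"a": 0.3, "b": 0.5, "a1": 0.2, "a2": 0.7, "b1": 0.1}  # weight of edge (c, parent(c))


def path_to_root(c: str) -> List[str]:
    # Colors on the way up to the root, excluding the root itself
    path = []
    while c in PARENT:
        path.append(c)
        c = PARENT[c]
    return path


def tree_dist(c1: str, c2: str) -> float:
    # Weighted path length: the edges lying on exactly one of the two root paths
    p1, p2 = set(path_to_root(c1)), set(path_to_root(c2))
    return sum(WEIGHT[c] for c in p1 ^ p2)


def ot_by_assignment(leaves_g: List[str], leaves_h: List[str]) -> float:
    # Equal sizes and unit masses: an optimal plan is a permutation (Birkhoff-von Neumann)
    return min(sum(tree_dist(x, y) for x, y in zip(leaves_g, perm))
               for perm in permutations(leaves_h))


def manhattan_form(leaves_g: List[str], leaves_h: List[str]) -> float:
    # nu_c = number of nodes whose color path passes through c
    nu_g = Counter(c for leaf in leaves_g for c in path_to_root(leaf))
    nu_h = Counter(c for leaf in leaves_h for c in path_to_root(leaf))
    return sum(w * abs(nu_g[c] - nu_h[c]) for c, w in WEIGHT.items())


g_leaves, h_leaves = ["a1", "a1", "b1"], ["a2", "b1", "b1"]
print(ot_by_assignment(g_leaves, h_leaves), manhattan_form(g_leaves, h_leaves))  # both about 2.0
```

Brute force is only viable for a handful of nodes; the closed form needs one pass over the colors regardless of size.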
Key Design 3: Two Normalizations and Their Relationship to Graph Kernels
To handle graphs of different sizes, two normalizations are proposed (\(L\) denotes the number of WL iterations):
- Size Normalization \(\dot{d}_{\text{WILT}}\): Replaces the counts \(\nu_c^G\) with frequencies \(\nu_c^G / |V_G|\). When all edge weights equal \(\frac{1}{2(L+1)}\), it reduces to the Wasserstein WL (WWL) graph kernel distance.
- Dummy Node Normalization \(\bar{d}_{\text{WILT}}\): Adds isolated dummy nodes so that all graphs have the same number of nodes. When all edge weights equal \(\frac{1}{2}\), it reduces to the WL Optimal Assignment (WLOA) kernel distance.
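Why the constant weight \(\frac{1}{2(L+1)}\) recovers WWL can be seen from the ground metric alone. A short derivation, assuming discrete node labels, for which WWL uses as ground distance the normalized Hamming distance between the color sequences \((c_0(u),\dots,c_L(u))\) of two nodes: because WL colors only refine, two nodes that first disagree at iteration \(j\) disagree at every later iteration, so the tree path between their final colors climbs \(L+1-j\) edges on each side of their common ancestor.

\[
d_T(u,v) \;=\; 2\,(L+1-j)\cdot\frac{1}{2(L+1)} \;=\; \frac{L+1-j}{L+1} \;=\; \frac{1}{L+1}\sum_{l=0}^{L}\mathbb{1}\bigl[c_l(u)\neq c_l(v)\bigr].
\]

With the frequencies \(\nu_c^G/|V_G|\) as masses, the optimal transport problem on this ground metric is the one WWL solves.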
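The authors prove the expressiveness relation \(\dot{d}_{\text{WILT}} < \bar{d}_{\text{WILT}} \cong d_{\text{WL}}\): size normalization is strictly less expressive than dummy node normalization, which is as expressive as the WL distance \(d_{\text{WL}}\). Accordingly, \(\dot{d}_{\text{WILT}}\) is better suited to approximating mean-pooling MPNNs and \(\bar{d}_{\text{WILT}}\) to approximating sum-pooling MPNNs, mirroring how mean pooling, like size normalization, assigns a graph and two disjoint copies of it the same representation, whereas sum pooling does not.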
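The two normalizations differ in how they treat graph size. The sketch below assumes color counts are already available as dictionaries, uses a single uniform edge weight for brevity, and models the dummy nodes as sharing one reserved color per WL level that no real node uses; the color names, weights, and helper functions are hypothetical. It illustrates the expressiveness gap: a graph and two disjoint copies of it have identical normalized color frequencies, so the size-normalized distance is zero while the dummy-normalized one is not.

```python
from typing import Dict, Hashable

Counts = Dict[Hashable, float]


def weighted_l1(nu_g: Counts, nu_h: Counts, weight: float) -> float:
    # Manhattan form with one uniform edge weight (a simplification for illustration)
    colors = set(nu_g) | set(nu_h)
    return weight * sum(abs(nu_g.get(c, 0) - nu_h.get(c, 0)) for c in colors)


def size_normalized(nu_g: Counts, n_g: int, nu_h: Counts, n_h: int, weight: float = 1.0) -> float:
    # d_dot: replace counts by frequencies nu_c / |V|
    return weighted_l1({c: v / n_g for c, v in nu_g.items()},
                       {c: v / n_h for c, v in nu_h.items()}, weight)


def dummy_normalized(nu_g: Counts, n_g: int, nu_h: Counts, n_h: int,
                     n_max: int, levels: int, weight: float = 1.0) -> float:
    # d_bar: pad each graph to n_max nodes with isolated dummy nodes.
    # ASSUMPTION: dummy nodes carry a reserved label, so at each of the `levels` (= L + 1)
    # WL levels they all share one color ("dummy", l) that no real node uses.
    def pad(nu: Counts, n: int) -> Counts:
        padded = dict(nu)
        for l in range(levels):
            padded[("dummy", l)] = n_max - n
        return padded
    return weighted_l1(pad(nu_g, n_g), pad(nu_h, n_h), weight)


# Hypothetical counts for a triangle (L = 1) and for two disjoint copies of it
tri = {"c0": 3, "c2": 3}
two_tri = {"c0": 6, "c2": 6}
print(size_normalized(tri, 3, two_tri, 6))                      # 0.0
print(dummy_normalized(tri, 3, two_tri, 6, n_max=6, levels=2))  # 12.0
```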
Key Experimental Results
Table 1: RMSE (×10⁻²) Between Each Distance and the MPNN Distance (Lower Is Better)
| Dataset/Pooling | \(d_{\text{WWL}}\) | \(d_{\text{WLOA}}\) | \(\dot{d}_{\text{WILT}}\) | \(\bar{d}_{\text{WILT}}\) |
|---|---|---|---|---|
| Mutagenicity/mean | 9.25±0.87 | 18.74±3.36 | 1.74±0.52 | 3.34±1.01 |
| Mutagenicity/sum | 12.25±0.54 | 5.98±1.60 | 1.22±0.31 | 0.82±0.17 |
| ENZYMES/mean | 12.18±0.23 | 16.79±2.33 | 2.71±0.38 | 4.64±0.67 |
| ENZYMES/sum | 11.28±0.65 | 6.83±0.41 | 9.15±0.47 | 1.43±0.10 |
| Lipophilicity/mean | 10.92±0.42 | 13.97±0.97 | 3.11±0.54 | 6.35±1.22 |
| Lipophilicity/sum | 10.83±0.73 | 10.00±1.34 | 2.50±0.67 | 2.64±0.74 |
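The fitted WILT distances approximate \(d_{\text{MPNN}}\) far better than the fixed-weight baselines (WWL and WLOA): in every setting, the better WILT variant's RMSE is roughly 3.5 to 7 times lower than that of the better baseline. Size normalization fits mean pooling better and dummy normalization fits sum pooling better, with the exception of Lipophilicity/sum, where the two variants are within one standard deviation of each other. A mismatched normalization can even underperform a baseline (ENZYMES/sum: 9.15 for \(\dot{d}_{\text{WILT}}\) vs. 6.83 for \(d_{\text{WLOA}}\)).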
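The comparison can be reproduced from the numbers in Table 1. The snippet below copies the mean RMSE values (standard deviations omitted) and, for each setting, divides the better fixed-weight baseline by the better WILT variant and reports which normalization wins; it uses no data beyond the table.

```python
# Mean RMSE values (x10^-2) copied from Table 1:
# (WWL, WLOA, size-normalized WILT, dummy-normalized WILT)
TABLE1 = {
    "Mutagenicity/mean": (9.25, 18.74, 1.74, 3.34),
    "Mutagenicity/sum": (12.25, 5.98, 1.22, 0.82),
    "ENZYMES/mean": (12.18, 16.79, 2.71, 4.64),
    "ENZYMES/sum": (11.28, 6.83, 9.15, 1.43),
    "Lipophilicity/mean": (10.92, 13.97, 3.11, 6.35),
    "Lipophilicity/sum": (10.83, 10.00, 2.50, 2.64),
}


def improvement_ratio(row):
    # Better fixed-weight baseline divided by better WILT variant
    wwl, wloa, dot, bar = row
    return min(wwl, wloa) / min(dot, bar)


def better_variant(row):
    # Which normalization gives the lower mean RMSE in this setting
    _, _, dot, bar = row
    return "size" if dot < bar else "dummy"


for setting, row in TABLE1.items():
    print(f"{setting}: {improvement_ratio(row):.1f}x, better variant: {better_variant(row)}")
```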
Table 2: Spearman Correlation Between \(\text{ALI}_k\) and Predictive Performance (Accuracy for Classification, RMSE for Regression)
| k | Mutagenicity (train/test) | ENZYMES (train/test) | Lipophilicity (train/test) |
|---|---|---|---|
| 1 | 0.66/0.70 | 0.89/0.49 | -0.65/-0.57 |
| 5 | 0.65/0.69 | 0.87/0.50 | -0.63/-0.56 |
| 10 | 0.63/0.68 | 0.87/0.47 | -0.62/-0.56 |
| 20 | 0.61/0.67 | 0.85/0.46 | -0.60/-0.55 |
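For the classification tasks (Mutagenicity, ENZYMES), a higher ALI correlates with higher accuracy; for the regression task (Lipophilicity), a higher ALI correlates with lower RMSE, hence the negative coefficients. In both cases better alignment goes with better performance, although the test-set correlation on ENZYMES is only moderate (about 0.5).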
Key Findings¶
- MPNN Distances Align with Functional Distances, Not Structural Distances: After training, MPNN embedding distances align with task-relevant functional distances, and the degree of alignment is strongly correlated with performance. Alignment with structural distances, the focus of previous work, neither improves with training nor correlates with performance.
- Only a Few Key Subgraphs Determine MPNN Distances: The learned WILT edge weights are heavily concentrated near zero. Even when L1 regularization drives 95% of the edge weights to zero, \(d_{\text{WILT}}\) still outperforms the fixed-weight baselines, indicating that MPNNs rely on a very small number of WL colors to define embedding distances.
- Identified Key Subgraphs Align with Domain Knowledge: On the Mutagenicity dataset, the WL colors corresponding to the largest weights represent known mutagenic functional groups such as epoxide and aliphatic halide (Kazius et al., 2005), validating that the MPNN has learned meaningful chemical knowledge.
Highlights & Insights
- Elegant Theoretical Framework: Moves the analysis of MPNN distances from binary expressivity (whether two graphs can be distinguished at all) to the metric structure of the embedding space, using optimal transport on WILT as an interpretable bridge. The closed-form tree representation reduces a Wasserstein distance computation, which generally requires cubic time, to linear time.
- Unifying Two Classical Graph Kernels: As a parameterized framework, the WILT distance unifies the Wasserstein WL kernel and the WL optimal assignment kernel, both of which are special cases under specific weight configurations, revealing the intrinsic connection between different graph kernels.
- Global Interpretability: Unlike most instance-level GNN explanation methods, WILT provides a global metric space explanation, identifying subgraph patterns that play a decisive role in the distances between all graph pairs. It is applicable to both classification and regression tasks.
Limitations & Future Work
- Only GCN and GIN architectures with fixed hyperparameter configurations were evaluated; generalization to other GNN architectures or different hyperparameters remains unverified.
- Only the embedding distances of the final layer of the MPNN were analyzed; the evolution of distances in intermediate layers was not explored.
- WILT is constrained by the 1-WL color hierarchy; its applicability to higher-order GNNs warrants further research.
- The experimental datasets are relatively small (Mutagenicity has 4,337 graphs, ENZYMES 600); scalability to larger datasets and larger graphs remains to be validated.
- The current method applies only to graph-level tasks; extending it to node-level tasks such as node classification remains open.
Related Work & Insights
- GNN Metric Analysis: Tree Mover's Distance by Chuang & Jegelka (2022) and the fine-grained expressivity work by Böker et al. (2024) focus on structural distances. This work challenges this paradigm and pivots toward functional distances.
- GNN Interpretability: Unlike instance-level methods such as GNNExplainer, this paper takes a global distillation approach (similar to GraphChef) but focuses on the metric structure rather than the decision boundaries.
- Graph Kernel Methods: WILT unifies the WWL kernel and the WLOA kernel, suggesting that trainable graph kernel designs can potentially achieve both interpretability and high performance.
- Insights: The paradigm of first training a black-box model and then distilling it into an interpretable structure is worth extending to other domains, and the contrast between functional and structural distances offers a useful lens for understanding what representation learning actually captures.
Rating
| Dimension | Score | Description |
|---|---|---|
| Novelty | 8/10 | Studies MPNNs from a metric-space perspective; the WILT concept is novel and unifies classical graph kernels. |
| Technical Depth | 9/10 | Solid theoretical analysis, rigorous expressivity theorems, and an elegant linear-time algorithm. |
| Experimental Thoroughness | 7/10 | Limited datasets and architectures, but the quantitative and qualitative results support each other well. |
| Practical Value | 7/10 | Provides a global perspective for GNN explanation, with convincing validation in the chemistry domain. |
| Writing Quality | 8/10 | Clear structure, problem-driven, and intuitive illustrations. |
| Overall Rating | 8/10 | Convincingly reveals the metric structure of the MPNN embedding space from both theoretical and empirical angles. |