Tree-Sliced Wasserstein Distance: A Geometric Perspective
Conference: ICML 2025
arXiv: 2406.13725
Code: GitHub
Area: Image Generation
Keywords: Sliced Wasserstein, Tree Metric, Radon Transform, Optimal Transport, Generative Models
TL;DR
Proposes the Tree-Sliced Wasserstein distance on Systems of Lines (TSW-SL), which replaces the one-dimensional lines that SW uses as projection domains with tree-shaped systems of lines. This retains more topological structure while keeping the efficient closed-form computation, and TSW-SL outperforms SW and its variants in gradient flows, style transfer, and generative models.
Background & Motivation
Background
OT Computational Bottleneck: Computing the Optimal Transport (Wasserstein) distance exactly has super-cubic complexity in the number of support points \(n\), which makes direct application impractical at scale.
Limitations of Prior Work
Limitations of SW: Sliced Wasserstein (SW) speeds up computation by projecting onto one-dimensional lines, where OT has a closed-form solution. However, one-dimensional projections discard topological information about the input distributions, so SW performs poorly, especially in high dimensions.
Key Challenge
Core Motivation: Find a projection domain that is richer than a one-dimensional line but still admits closed-form OT solutions. Tree metric spaces satisfy this condition, since the Wasserstein distance on them has a closed-form expression.
Method
Overall Architecture
- Tree System: Connects \(k\) lines in \(\mathbb{R}^d\) according to a tree structure to form a tree system \((\mathcal{L}, \mathcal{T})\).
- Splitting Map: \(\alpha: \mathbb{R}^d \to \Delta_{k-1}\) splits the mass of each high-dimensional point across the \(k\) lines, with the share for line \(l\) given by \(\alpha(y)_l\).
- Radon Transform on Tree Systems: Pushes forward the high-dimensional distribution onto the tree system, and then computes the OT under the tree metric.
Key Designs
- Metrization of Tree Systems (Theorem 3.2): The topological space \(\Omega_{\mathcal{L}}\) on the tree system can be metrized by a tree metric \(d_{\mathcal{L}}\), where the distance between two points equals the measure of the unique path between them: $$d_{\mathcal{L}}(a,b) = \mu_{\mathcal{L}}(P_{a,b})$$
- System Radon Transform (Definition 4.1): $$\mathcal{R}^\alpha_{\mathcal{L}} f(x,l) = \int_{\mathbb{R}^d} f(y) \cdot \alpha(y)_l \cdot \delta(t_x - \langle y - x_l, \theta_l \rangle) \, dy$$ Here \(x_l\) and \(\theta_l\) are the source point and direction of line \(l\), and \(t_x\) is the coordinate of \(x\) along that line. The injectivity of this transform is proven (Theorem 4.2).
- TSW-SL Distance (Definition 5.1): $$\text{TSW-SL}(\mu,\nu) = \int_{\mathbb{T}} W_{d_{\mathcal{L}},1}(\mathcal{R}^\alpha_{\mathcal{L}} f_\mu, \mathcal{R}^\alpha_{\mathcal{L}} f_\nu) \, d\sigma(\mathcal{L})$$ In practice the integral is approximated by Monte Carlo sampling of \(L\) tree systems, each containing \(k\) lines.
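For orientation, the closed form that makes the inner term \(W_{d_{\mathcal{L}},1}\) cheap can be written in its standard finite-tree version (Le et al., 2019). The symbols \(w_e\), \(v_e\), and \(\Gamma\) below are introduced here for illustration and are not a transcription of the paper's Equation 5: for a rooted tree \(\mathcal{T}\) with edge lengths \(w_e\), let \(v_e\) be the endpoint of edge \(e\) farther from the root and \(\Gamma(v_e)\) the subtree hanging below it. Then

$$W_{d_{\mathcal{T}},1}(\mu,\nu) = \sum_{e \in \mathcal{T}} w_e \,\bigl|\mu(\Gamma(v_e)) - \nu(\Gamma(v_e))\bigr|$$

On a single line this becomes the familiar one-dimensional formula in terms of cumulative distribution functions:

$$W_1(\mu,\nu) = \int_{\mathbb{R}} \bigl|F_\mu(t) - F_\nu(t)\bigr| \, dt$$

A tree system can be read as the continuous analogue: the sum over edges turns into an integral along the lines, with the mass of everything "below" each point playing the role of \(\mu(\Gamma(v_e))\).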
Loss & Complexity
- Time complexity is \(O(Lkn\log n + Lkdn)\), which matches SW when SW uses the same total number of projection directions (\(Lk\)).
- TSW-SL is a metric on \(\mathcal{P}(\mathbb{R}^d)\) (Theorem 5.2), and reduces to standard SW when \(k=1\).
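As a point of comparison for the two complexity terms, here is a minimal NumPy sketch of Monte Carlo SW with \(p=1\), the baseline that TSW-SL reduces to at \(k=1\). It assumes two point clouds of equal size with uniform weights. The function name `sliced_w1` and its parameters are illustrative, not taken from the paper's code.

```python
import numpy as np

def sliced_w1(X, Y, n_proj=200, seed=0):
    """Monte Carlo sliced W1 between equal-size point clouds with uniform weights."""
    rng = np.random.default_rng(seed)
    # Directions uniform on the unit sphere: normalise Gaussian samples.
    theta = rng.normal(size=(n_proj, X.shape[1]))
    theta /= np.linalg.norm(theta, axis=1, keepdims=True)
    # Projection costs O(n_proj * d * n); sorting costs O(n_proj * n log n).
    px = np.sort(X @ theta.T, axis=0)
    py = np.sort(Y @ theta.T, axis=0)
    # In 1D with equal sizes and uniform weights, W1 is the mean gap
    # between sorted samples; average that over all directions.
    return float(np.mean(np.abs(px - py)))
```

The projection and sort steps correspond to the \(dn\) and \(n\log n\) factors; TSW-SL performs the same two steps once per line, for \(Lk\) lines in total.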
Key Experimental Results
Main Results: Gradient Flows
| Method | Swiss Roll (iter 2500) | 25 Gaussians (iter 2500) |
|---|---|---|
| SW | 1.05e-3 | 2.20e-2 |
| MaxSW | 3.45e-3 | - |
| TSW-SL | Significantly outperforms SW | Significantly outperforms SW |
On the Swiss Roll and 25 Gaussians datasets, TSW-SL achieves lower Wasserstein distances than SW at every reported iteration.
Image Style Transfer & Generative Models
- Style Transfer: TSW-SL produces higher-quality stylized images than SW and MaxSW.
- Generative Models (MNIST/CIFAR-10): TSW-SL consistently achieves better FID scores than SW and its variants, at comparable computational time.
Ablation Study
- Impact of the Number of Lines \(k\) in the Tree System: Increasing \(k\) generally improves performance, but with diminishing marginal returns.
- Choice of Splitting Map: The uniform splitting map (\(\alpha(y)_l = 1/k\) for every line \(l\)) already yields good results, and the distance-weighted splitting map can further improve performance.
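The two kinds of splitting map can be sketched as follows. The uniform map follows directly from \(\alpha = 1/k\). The distance-weighted variant is an assumption made for illustration: a softmax over negative point-to-line distances with a hypothetical temperature `delta`. The summary does not state the exact weighting the authors use.

```python
import numpy as np

def uniform_split(points, sources, thetas):
    """Give every line the same share 1/k of each point's mass."""
    k = thetas.shape[0]
    return np.full((points.shape[0], k), 1.0 / k)

def distance_split(points, sources, thetas, delta=1.0):
    """Assumed variant: lines closer to a point receive more of its mass."""
    diff = points[:, None, :] - sources[None, :, :]        # (n, k, d)
    coord = np.einsum("nkd,kd->nk", diff, thetas)          # coordinate along each line
    perp = diff - coord[:, :, None] * thetas[None, :, :]   # component orthogonal to the line
    dist = np.linalg.norm(perp, axis=2)                    # (n, k) point-to-line distances
    logits = -delta * dist
    logits -= logits.max(axis=1, keepdims=True)            # stabilise the softmax
    w = np.exp(logits)
    return w / w.sum(axis=1, keepdims=True)                # each row lies in the simplex
```

Both functions return an `(n, k)` array whose rows sum to one, matching \(\alpha: \mathbb{R}^d \to \Delta_{k-1}\). Setting `delta=0` in the assumed variant recovers the uniform map.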
Highlights & Insights
- Theoretical Completeness: The paper builds a complete mathematical chain: topological structure \(\to\) metric space \(\to\) Radon transform \(\to\) injectivity \(\to\) metric property.
- Elegant Generalization: TSW-SL contains SW as the special case \(k=1\). Its richer topological structure can capture more information about high-dimensional distributions.
- Strong Practicality: With the same computational complexity as SW, it can serve as a drop-in replacement for the distance in existing SW pipelines.
- Clever Use of Tree Metrics: Exploiting the closed-form solution of OT on trees avoids the computational cost of general high-dimensional OT.
Limitations & Future Work
- Currently, a chain-like tree structure is used. The performance of more general tree sampling strategies (e.g., star-like, random trees) remains to be explored.
- The splitting map \(\alpha\) is currently predefined; learning the optimal splitting map could be considered.
- Only the case of \(p=1\) is verified (due to the limitation of closed-form solutions on tree metrics). Generalization to \(p>1\) requires additional work.
- Sampling efficiency and variance control in high-dimensional and high-\(k\) scenarios still require in-depth analysis.
Related Work & Insights
- Sliced Wasserstein Family: SW, MaxSW, Subspace SW, etc., improve OT efficiency through different projection strategies.
- Tree Wasserstein (TW): Le et al. (2019) proposed closed-form solutions for OT on tree metrics, which this work generalizes to continuous tree systems.
- Nonlinear Projection: Kolouri et al. (2019) parameterized projection directions using neural networks. This work is complementary, approaching from a geometric domain perspective.
- Subspace Methods: Alvarez-Melis et al. (2018) and Bonet et al. (2023) project onto low-dimensional subspaces instead of one-dimensional lines.
- Insights:
    - Sliced OT variants on richer geometric domains (e.g., graphs, manifolds) can be further explored.
    - Stochastic sampling strategies for tree systems can be combined with attention mechanisms to adaptively adjust the tree structure.
    - Learning the splitting map can be viewed as a formulation of soft clustering, which has potential connections to mixture models.
Technical Details Supplement
- Sampling of Tree Systems (Algorithm 1): The starting point \(x_1\) is sampled from the uniform distribution \(\mathcal{U}([-1,1]^d)\), and the direction \(\theta_1\) is sampled from \(\mathcal{U}(\mathbb{S}^{d-1})\). The starting point of each subsequent line is randomly selected on the previous line, with directions sampled independently.
- Monte Carlo Estimation: \(\widehat{\text{TSW-SL}}(\mu,\nu) = \frac{1}{L}\sum_{i=1}^L W_{d_{\mathcal{L}_i},1}(\mathcal{R}^\alpha_{\mathcal{L}_i}f_\mu, \mathcal{R}^\alpha_{\mathcal{L}_i}f_\nu)\), where \(\mathcal{L}_1, \dots, \mathcal{L}_L\) are the sampled tree systems and the OT on each tree system is computed directly using the closed-form solution (Equation 5).
- Handling Discrete Distributions: Support points are projected onto each line of the tree system, with weights allocated by the splitting map. The projected coordinates are then sorted to evaluate the tree Wasserstein (TW) closed form.
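The three steps above can be sketched end to end in NumPy. This is a minimal illustration under stated assumptions, not the authors' implementation: the tree is a chain; each junction coordinate is drawn from \(\mathcal{U}([-1,1])\) (the summary only says "randomly selected"); the tree is rooted at the source of the first line; and all function names are hypothetical.

```python
import numpy as np

def sample_chain_tree_system(d, k, rng):
    """Sample k chained lines: sources (k, d), unit directions (k, d), junctions (k,)."""
    thetas = rng.normal(size=(k, d))
    thetas /= np.linalg.norm(thetas, axis=1, keepdims=True)
    sources = np.empty((k, d))
    junctions = np.zeros(k)  # junctions[i]: coordinate on line i-1 where line i starts
    sources[0] = rng.uniform(-1.0, 1.0, size=d)
    for i in range(1, k):
        junctions[i] = rng.uniform(-1.0, 1.0)  # assumed sampling range
        sources[i] = sources[i - 1] + junctions[i] * thetas[i - 1]
    return sources, thetas, junctions

def uniform_split(points, sources, thetas):
    k = thetas.shape[0]
    return np.full((points.shape[0], k), 1.0 / k)

def _ray_cost(t, w):
    """Integrate |signed mass lying beyond r| over r >= 0 along one ray."""
    order = np.argsort(t)
    t_sorted, w_sorted = t[order], w[order]
    tail = np.cumsum(w_sorted[::-1])[::-1]             # mass at or beyond each atom
    gaps = np.diff(np.concatenate(([0.0], t_sorted)))  # segment lengths from the ray origin
    return float(np.sum(gaps * np.abs(tail)))

def tree_w1(X, a, Y, b, sources, thetas, junctions, split=uniform_split):
    """Closed-form W1 between discrete measures pushed onto one chain tree system."""
    k = thetas.shape[0]
    pts = np.vstack([X, Y])
    alpha = split(pts, sources, thetas)                # (n + m, k) mass shares
    signed = np.concatenate([a, -b])[:, None] * alpha  # signed mass mu - nu per line
    total, child_mass = 0.0, 0.0
    for i in range(k - 1, -1, -1):                     # leaf line first, root line last
        t = (pts - sources[i]) @ thetas[i]             # coordinates along line i
        w = signed[:, i]
        if i < k - 1:
            # The whole child subtree hangs off line i at its junction coordinate.
            t = np.append(t, junctions[i + 1])
            w = np.append(w, child_mass)
        pos, neg = t > 0, t < 0                        # two rays leaving the root-ward point
        total += _ray_cost(t[pos], w[pos]) + _ray_cost(-t[neg], w[neg])
        child_mass = w.sum()
    return total

def tsw_sl(X, Y, L=50, k=4, seed=0, split=uniform_split):
    """Monte Carlo estimate averaged over L sampled tree systems (uniform point weights)."""
    rng = np.random.default_rng(seed)
    a = np.full(X.shape[0], 1.0 / X.shape[0])
    b = np.full(Y.shape[0], 1.0 / Y.shape[0])
    vals = [tree_w1(X, a, Y, b, *sample_chain_tree_system(X.shape[1], k, rng), split=split)
            for _ in range(L)]
    return float(np.mean(vals))
```

A typical call is `tsw_sl(X, Y, L=50, k=4)` with `X` and `Y` of shape `(n, d)` and `(m, d)`. Each line needs one projection and one sort, so the sketch follows the \(O(Lkn\log n + Lkdn)\) cost pattern, and with `k=1` the inner term is the ordinary one-dimensional \(W_1\) of the projections.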
Rating
- Novelty: ⭐⭐⭐⭐⭐
- Experimental Thoroughness: ⭐⭐⭐⭐
- Writing Quality: ⭐⭐⭐⭐⭐
- Value: ⭐⭐⭐⭐