# MUST: Modality-Specific Representation-Aware Transformer for Diffusion-Enhanced Survival Prediction with Missing Modality
Conference: CVPR 2026 · arXiv: 2603.26071 · Code: Project Page · Area: Medical Imaging / Multimodal Fusion · Keywords: Survival Prediction, Missing Modality, Algebraic Decomposition, Latent Diffusion Model, Multimodal Fusion
## TL;DR
This paper proposes MUST, a framework that explicitly decomposes multimodal representations into modality-specific and cross-modal shared components via algebraic constraints, and employs a conditional latent diffusion model to generate modality-specific information under missing-modality scenarios. MUST achieves state-of-the-art performance with an overall C-index of 0.742 across five TCGA cancer datasets, degrading by only ~0.4%–3.5% under missing-modality conditions.
## Background & Motivation
- Background: Multimodal survival prediction (pathology WSI + genomics) significantly improves prognostic accuracy; methods such as SurvPath and CMTA achieve multimodal fusion via cross-attention mechanisms.
- Limitations of Prior Work: Modality absence is common in clinical settings—genomic profiling is costly and time-consuming, and historical records often contain only pathology without molecular data. Existing multimodal models assume complete data and suffer severe performance degradation under missing modalities.
- Key Challenge: Existing missing-modality methods fall into three categories—feature alignment (agnostic to what is missing), interpolation (noisy in high-dimensional spaces), and joint distribution learning (without disentangling modality-specific vs. shared information). The fundamental issue is the lack of explicit modeling of each modality's unique contribution.
- Goal: To precisely identify "what information is lost" under missing-modality conditions and to recover it in a targeted manner.
- Key Insight: Modality representations are algebraically decomposed within a learned low-rank shared subspace into modality-specific and shared components. The shared component can be deterministically recovered from any available modality, while the specific component is generated by a conditional diffusion model.
- Core Idea: An algebraically invertible decomposition enables a precise "recover what is missing" reconstruction strategy, formalized in the display below.
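Written in the notation introduced in the Method section below, the scheme for a missing modality \(m\) given an observed modality \(m'\) is:

\[
g_m = \hat{u}_m + \hat{c}, \qquad
\hat{c} = P_\cap\, g_{m'} \ \ \text{(deterministic recovery)}, \qquad
\hat{u}_m \sim p_\theta\!\left(\cdot \mid \hat{c},\, [\text{CLS}_u]\right) \ \ \text{(diffusion generation)}.
\]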
## Method
### Overall Architecture
Inputs: a set of patch features \(P\) from pathology WSIs and a set of genomic tokens \(G\). Each modality is encoded by its own encoder to obtain global representations \(g_P, g_G\). Bidirectional cross-attention extracts the information each modality carries about the other, \(c_{P\leftarrow G}\) and \(c_{G\leftarrow P}\), and self-attention extracts the modality-specific components \(u_P, u_G\). All components are projected into a learned low-rank shared subspace, where the algebraic decomposition \(g_P = \hat{u}_P + \hat{c}_{G\leftarrow P}\) (and, symmetrically, \(g_G = \hat{u}_G + \hat{c}_{P\leftarrow G}\)) is enforced. Under complete data, the three components \([\hat{u}_P; \hat{c}; \hat{u}_G]\) are concatenated and fed into a prediction head that outputs discrete risk probabilities; under missing modalities, the shared component is recovered deterministically via the algebraic relations, and the missing modality-specific component is generated by the LDM. A module-level sketch follows.
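The following PyTorch sketch traces the complete-data path at module level. All names (`MUSTSketch`, `enc_path`, `cross_pg`, ...) are illustrative stand-ins, not the authors' code, and the subspace projection step is deferred to the decomposition sketch further below:

```python
import torch
import torch.nn as nn

class MUSTSketch(nn.Module):
    """Complete-data forward pass, heavily simplified."""

    def __init__(self, d=256, heads=4, n_bins=4):
        super().__init__()
        self.enc_path = nn.TransformerEncoderLayer(d, heads, batch_first=True)
        self.enc_gene = nn.TransformerEncoderLayer(d, heads, batch_first=True)
        self.cross_pg = nn.MultiheadAttention(d, heads, batch_first=True)
        self.cross_gp = nn.MultiheadAttention(d, heads, batch_first=True)
        self.self_p = nn.MultiheadAttention(d, heads, batch_first=True)
        self.self_g = nn.MultiheadAttention(d, heads, batch_first=True)
        self.head = nn.Linear(3 * d, n_bins)  # discrete risk probabilities

    def forward(self, P, G):  # P: (B, N_p, d) patch feats, G: (B, N_g, d) gene tokens
        h_p, h_g = self.enc_path(P), self.enc_gene(G)
        # bidirectional cross-attention: what each modality carries about the other
        c_p_from_g, _ = self.cross_pg(h_p, h_g, h_g)   # pathology queries genomics
        c_g_from_p, _ = self.cross_gp(h_g, h_p, h_p)   # genomics queries pathology
        # self-attention distills the modality-specific components
        u_p, _ = self.self_p(h_p, h_p, h_p)
        u_g, _ = self.self_g(h_g, h_g, h_g)
        u_p, u_g = u_p.mean(1), u_g.mean(1)            # pool to global vectors
        # shared consistency says both directions should agree, so average them
        c = 0.5 * (c_p_from_g.mean(1) + c_g_from_p.mean(1))
        return self.head(torch.cat([u_p, c, u_g], dim=-1))

model = MUSTSketch()
risk_logits = model(torch.randn(2, 100, 256), torch.randn(2, 6, 256))  # (2, 4)
```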
### Key Designs
- Low-Rank Shared Subspace Algebraic Decomposition:
    - Function: Decomposes global representations into modality-specific and shared components.
    - Mechanism: A learnable low-rank projection matrix \(P_\cap = B_\cap B_\cap^T\) (\(B_\cap \in \mathbb{R}^{D\times r}\), \(r\ll D\), with orthonormal columns so that the idempotency \(P_\cap^2 = P_\cap\) holds) is constructed. Shared components are projected into the subspace, while specific components are projected into its orthogonal complement. Three constraints are imposed: shared consistency (the cross-attention outputs from both directions agree), inter-modality orthogonality (\(\hat{u}_P \perp \hat{u}_G\)), and intra-modality orthogonality (\(\hat{u}_m \perp \hat{c}_m\)). A minimal sketch of the projector follows this item.
    - Design Motivation: Unlike ShaSpec's implicit distribution alignment, the algebraic constraints guarantee that the shared component can be recovered deterministically from any available modality, providing a mathematical guarantee for missing-modality reconstruction.
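A minimal sketch of the projector and the residual-based decomposition. It uses the pseudo-inverse form of the projector, which is idempotent for any full-column-rank basis (the paper's \(B_\cap B_\cap^T\) form additionally requires orthonormal columns); the squared-residual penalty is an assumption about how the decomposition constraint is enforced:

```python
import torch

D, r = 256, 64
B = torch.randn(D, r, requires_grad=True)      # learnable low-rank basis B_cap

def shared_projector(B):
    """P_cap = B (B^T B)^{-1} B^T. Idempotency (P_cap @ P_cap == P_cap) means
    projecting twice changes nothing; I - P_cap projects onto the complement."""
    return B @ torch.linalg.solve(B.T @ B, B.T)

def decompose(g, c, u, P_cap):
    """Project the shared part into the subspace and the specific part into
    the orthogonal complement; the residual of g = u_hat + c_hat is penalized."""
    I = torch.eye(P_cap.shape[0])
    c_hat = c @ P_cap                          # shared component, inside the subspace
    u_hat = u @ (I - P_cap)                    # specific component, in the complement
    L_decomp = (g - (u_hat + c_hat)).pow(2).mean()
    return u_hat, c_hat, L_decomp

g_P = torch.randn(2, D)                            # global pathology representation
c_GP, u_P = torch.randn(2, D), torch.randn(2, D)   # cross- / self-attention outputs
u_hat_P, c_hat_P, L_dec = decompose(g_P, c_GP, u_P, shared_projector(B))
assert torch.allclose(c_hat_P @ shared_projector(B), c_hat_P, atol=1e-3)  # idempotent
```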
- Conditional Latent Diffusion Model (LDM) for Generating Missing Specific Components:
    - Function: Provides high-quality generation of the modality-specific information that genuinely cannot be recovered from the other modality.
    - Mechanism: With the main network frozen, a 4-layer Transformer denoising network is trained. The recovered shared component \(\hat{c}\) and a learned modality-specific CLS token \([\text{CLS}_{u}]\) serve as conditions; the missing \(\hat{u}\) is generated by 50-step DDIM sampling. At inference, 5 samples are drawn and averaged to reduce stochasticity. A sampling sketch follows this item.
    - Design Motivation: Confining stochastic generation to the truly modality-specific residual, rather than the entire representation space, substantially reduces the difficulty of generation.
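A sketch of the missing-\(\hat{u}\) generation under the stated settings (50 DDIM steps, 5 averaged samples). The `denoiser` argument stands in for the paper's 4-layer Transformer \(\epsilon\)-predictor, and the \(\bar{\alpha}\) schedule here is purely illustrative:

```python
import torch

@torch.no_grad()
def generate_missing_u(denoiser, c_hat, cls_u, d=256, n_steps=50, n_samples=5):
    """Deterministic DDIM (eta = 0) sampling of the missing specific component,
    conditioned on the recovered shared component and the modality CLS token."""
    alpha_bar = torch.linspace(0.9999, 0.01, n_steps)        # illustrative schedule
    draws = []
    for _ in range(n_samples):                               # 5 draws, then average
        z = torch.randn(c_hat.shape[0], d)                   # start from pure noise
        for i in reversed(range(n_steps)):
            eps = denoiser(z, torch.tensor([i]), c_hat, cls_u)  # predicted noise
            ab = alpha_bar[i]
            x0 = (z - (1 - ab).sqrt() * eps) / ab.sqrt()     # implied clean sample
            ab_prev = alpha_bar[i - 1] if i > 0 else torch.tensor(1.0)
            z = ab_prev.sqrt() * x0 + (1 - ab_prev).sqrt() * eps  # eta=0 DDIM step
        draws.append(z)
    return torch.stack(draws).mean(0)                        # averaged u_hat

# Usage with a trivial stand-in denoiser (the real one is a 4-layer Transformer):
u_hat = generate_missing_u(lambda z, t, c, s: torch.zeros_like(z),
                           c_hat=torch.randn(2, 256), cls_u=torch.randn(2, 256))
```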
- Progressive Two-Stage Training:
    - Function: Ensures stable and convergent training.
    - Mechanism: In the first stage, each modality encoder is trained with the survival loss plus Gaussian noise injection, so that it first learns meaningful task-relevant features. In the second stage, the decomposition loss \(\mathcal{L}_{\text{decomp}}\), shared-consistency loss \(\mathcal{L}_{\text{shared}}\), and orthogonality loss \(\mathcal{L}_{\text{orth}}\) are added; a runnable schedule skeleton follows this list.
    - Design Motivation: Training the decomposition framework end to end from scratch is prone to degenerate solutions; staged training ensures the encoders acquire semantic representations before the structured decomposition is imposed.
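A runnable skeleton of the progressive schedule, with `nn.Linear` stand-ins for every module and `CrossEntropyLoss` standing in for the discrete survival loss (the real \(\mathcal{L}_{\text{surv}}\) additionally handles censoring):

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
enc_p, enc_g = nn.Linear(256, 256), nn.Linear(256, 256)   # stand-in encoders
phi = nn.Linear(512, 4)               # warm-up survival head over 4 risk bins
denoiser = nn.Linear(256, 256)        # stand-in for the 4-layer Transformer
surv = nn.CrossEntropyLoss()          # placeholder for the censored survival NLL

# Stage 1: per-modality warm-up. Note the Gaussian noise concatenated to each
# global representation, matching L_warm = L_surv(phi([g_m; eps_m])).
opt = torch.optim.Adam([*enc_p.parameters(), *enc_g.parameters(), *phi.parameters()])
for _ in range(20):
    x_p, x_g, y = torch.randn(8, 256), torch.randn(8, 256), torch.randint(0, 4, (8,))
    loss = surv(phi(torch.cat([enc_p(x_p), torch.randn(8, 256)], -1)), y) \
         + surv(phi(torch.cat([enc_g(x_g), torch.randn(8, 256)], -1)), y)
    opt.zero_grad(); loss.backward(); opt.step()

# Stage 2 would add L_decomp / L_shared / L_orth on top of L_surv
# (see the loss-assembly sketch in the next subsection).

# LDM stage: freeze everything trained so far; only the denoiser learns.
for m in (enc_p, enc_g, phi):
    m.requires_grad_(False)
opt = torch.optim.Adam(denoiser.parameters())
```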
### Loss & Training
- Stage 1: \(\mathcal{L}_{\text{warm}} = \mathcal{L}_{\text{surv}}(\phi([g_P; \epsilon_P])) + \mathcal{L}_{\text{surv}}(\phi([g_G; \epsilon_G]))\)
- Stage 2: \(\mathcal{L}_{\text{main}} = \mathcal{L}_{\text{surv}} + \lambda_{\text{dec}}\mathcal{L}_{\text{decomp}} + \lambda_{\text{sh}}\mathcal{L}_{\text{shared}} + \lambda_{\text{orth}}\mathcal{L}_{\text{orth}}\)
- LDM stage: Standard diffusion denoising loss \(\mathcal{L}_{\text{LDM}} = \mathbb{E}[\|\epsilon - \epsilon_\theta(z_t, t, \text{cond})\|^2]\)
- Hyperparameters: \(\lambda_{\text{dec}}=1.0,\ \lambda_{\text{sh}}=1.0,\ \lambda_{\text{orth}}=0.5\); shared subspace rank \(r=64\); feature dimension \(D=256\).
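How the Stage-2 objective and the LDM loss could be assembled with the stated weights. `L_surv` and `L_decomp` are assumed to be computed as in the sketches above, and the squared-inner-product form of the orthogonality penalty is an assumption (the paper's exact penalty is not reproduced here):

```python
import torch
import torch.nn.functional as F

lam_dec, lam_sh, lam_orth = 1.0, 1.0, 0.5   # the paper's loss weights

def orth_penalty(a, b):
    """Drive the inner product of two components toward zero."""
    return (a * b).sum(-1).pow(2).mean()

def main_loss(L_surv, L_decomp, c_p_from_g, c_g_from_p,
              u_hat_p, u_hat_g, c_hat_p, c_hat_g):
    L_shared = F.mse_loss(c_p_from_g, c_g_from_p)      # both directions agree
    L_orth = (orth_penalty(u_hat_p, u_hat_g)           # u_P ⊥ u_G
              + orth_penalty(u_hat_p, c_hat_p)         # u_m ⊥ c_m
              + orth_penalty(u_hat_g, c_hat_g))
    return L_surv + lam_dec * L_decomp + lam_sh * L_shared + lam_orth * L_orth

def ldm_loss(eps_pred, eps):
    """Standard epsilon-prediction MSE for the diffusion stage."""
    return F.mse_loss(eps_pred, eps)
```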
## Key Experimental Results
### Main Results
C-index comparison across 5 TCGA cancer datasets (BLCA/BRCA/GBMLGG/LUAD/UCEC):
| Method | Setting | BLCA | BRCA | GBMLGG | LUAD | UCEC | Overall |
|---|---|---|---|---|---|---|---|
| CMTA | Both modalities | 0.691 | 0.648 | 0.857 | 0.667 | 0.755 | 0.724 |
| MUST | Both modalities | 0.703 | 0.690 | 0.864 | 0.686 | 0.768 | 0.742 |
| LD-CVAE | Missing genomics | 0.651 | 0.649 | 0.831 | 0.629 | 0.726 | 0.697 |
| MUST | Missing genomics | 0.673 | 0.651 | 0.864 | 0.637 | 0.755 | 0.716 |
| ShaSpec | Missing pathology | 0.636 | 0.629 | 0.823 | 0.610 | 0.682 | 0.676 |
| MUST | Missing pathology | 0.702 | 0.692 | 0.865 | 0.690 | 0.748 | 0.739 |
### Ablation Study
| Configuration | Overall C-index (absolute or change) | Note |
|---|---|---|
| Without warm-up | drop of 0.6%–3.5% | Varies by dataset; UCEC most affected |
| LDM conditioned on \(\hat{c}\) only | Missing G: 0.712, Missing P: 0.732 | Lacks structural prior |
| LDM conditioned on \([\hat{c}; \text{CLS}]\) | Missing G: 0.716, Missing P: 0.739 | CLS token provides modality structural prior |
### Key Findings
- Performance drops only 0.4% when pathology is missing (0.742→0.739), versus 3.5% when genomics is missing (0.742→0.716)—indicating that the LDM exerts a "regularizing denoising" effect on high-dimensional noisy patch features.
- On BRCA/GBMLGG/LUAD, performance marginally improves when pathology is missing, as the diffusion generation process filters high-frequency noise from WSIs.
- Decomposition fidelity (cosine similarity) ranges from 0.75 to 0.94, validating the effectiveness of algebraic decomposition.
- On an A6000 GPU, full-data inference takes ≤70 ms; missing-modality inference takes 879 ms (50-step DDIM × 5 samples), which is clinically acceptable.
## Highlights & Insights
- The algebraically invertible design is particularly elegant: Unlike ShaSpec's distribution alignment, MUST uses low-rank projection with orthogonality constraints to enable precise recovery of the shared component, strictly confining uncertainty to the modality-specific component. This reduces missing-modality handling to "deterministic recovery + bounded stochastic generation."
- The "missing improves performance" phenomenon warrants attention: LDM-generated pathology-specific components naturally filter high-dimensional WSI noise through the diffusion denoising process, suggesting a potential "augmentation-style inference" strategy.
- The combination of progressive training and noise injection is transferable to other multimodal decomposition settings.
## Limitations & Future Work
- Only two modalities (pathology + genomics) are handled; extending to \(N\) modalities incurs quadratic growth in pairwise cross-attention complexity.
- LDM inference takes 879 ms (5-sample averaging), which is marginally acceptable in clinical settings but remains relatively slow.
- Decomposition fidelity of 0.75–0.94 indicates that algebraic decomposition is imperfect; recovered shared components may introduce errors in low-fidelity cases.
- Lighter-weight generative models (e.g., Flow Matching) could be explored to replace DDIM and reduce the number of sampling steps.
## Related Work & Insights
- vs. ShaSpec: Both attempt to separate shared and specific information, but ShaSpec relies on distribution alignment (head distillation) without algebraic invertibility guarantees, resulting in larger degradation under missing modalities (4.7% vs. 3.5%).
- vs. LD-CVAE: Performs joint distribution learning without disentangling contributions, and cannot handle missing-pathology scenarios (unidirectional architecture); MUST is bidirectionally symmetric.
- vs. CMTA: Also uses cross-attention but lacks a missing-modality mechanism. MUST demonstrates that "cross-attention alone is insufficient—an algebraic framework is needed to prevent modality collapse."
## Rating
- Novelty: ⭐⭐⭐⭐ The combination of algebraic decomposition and conditional diffusion is creative, though the overall paradigm of decompose-then-generate is not entirely novel.
- Experimental Thoroughness: ⭐⭐⭐⭐⭐ Five datasets, three evaluation settings, comprehensive ablation, Kaplan–Meier curve analysis, and inference latency analysis.
- Writing Quality: ⭐⭐⭐⭐ Mathematical formulations are clear, though the dense notation creates a high barrier for first-time readers.
- Value: ⭐⭐⭐⭐ Missing modalities in clinical settings represent a genuine and prevalent problem; the proposed method demonstrates strong practical utility.