Ambiguity-aware Truncated Flow Matching for Ambiguous Medical Image Segmentation¶
Conference: AAAI 2026 arXiv: 2511.06857v2 Code: https://github.com/PerceptionComputingLab/ATFM Area: Medical Imaging Keywords: Ambiguous medical image segmentation, truncated diffusion model, Flow Matching, Gaussian truncation representation, semantic consistency
TL;DR¶
This paper proposes the ATFM framework, which decouples prediction accuracy and diversity into distribution-level and sample-level optimization through a data-hierarchical inference paradigm. By integrating two modules — Gaussian Truncation Representation (GTR) and Segmentation Flow Matching (SFM) — ATFM simultaneously improves prediction accuracy, fidelity, and diversity in ambiguous medical image segmentation.
Background & Motivation¶
Ambiguous Medical Image Segmentation (AMIS) aims to generate multiple plausible segmentation predictions for a single medical image, reflecting the inherent inter-annotator ambiguity. Clinically, high diversity captures the intrinsic ambiguity of the image, while high accuracy supports reliable diagnostic decisions — both are indispensable.
Existing methods face an inherent trade-off between accuracy and diversity:

- Stochastic methods (Prob. U-Net, PHiSeg): sacrifice accuracy for diversity, yielding low-confidence diagnoses
- Multi-annotator-aware methods: improve both by modeling annotator styles, but suppress low-frequency patterns, degrading segmentation quality
- cVAE/diffusion methods: inject randomness to enhance diversity, but single-stage inference couples accuracy and diversity optimization
Truncated Diffusion Probabilistic Models (TDPMs) show promise through an inference paradigm shift — truncating the diffusion process at \(T_{\text{trunc}}\) and estimating the truncated distribution with an auxiliary network. However, directly applying TDPMs introduces three issues: (1) a unified inference objective cannot decouple accuracy from diversity; (2) sampling-based approximation of the truncated distribution leads to insufficient fidelity; (3) lack of semantic guidance after truncation impairs prediction plausibility.
Core Problem¶
How to simultaneously improve prediction accuracy and diversity in ambiguous medical image segmentation, breaking the inherent trade-off observed in existing methods?
Method¶
Overall Architecture¶
ATFM consists of three core components, organized around a redefined inference paradigm:
- Data-Hierarchical Inference: Redefines the AMIS-specific inference paradigm, improving accuracy at the data distribution level and enhancing diversity at the data sample level
- GTR (Gaussian Truncation Representation): Explicitly models a Gaussian distribution at the truncation point to improve prediction fidelity and the reliability of the truncated distribution
- SFM (Segmentation Flow Matching): Introduces semantics-aware flow transformation to enhance diversity while ensuring prediction plausibility
Key Designs¶
Data-Hierarchical Inference¶
The core idea is to decouple accuracy and diversity by marginalizing randomness in the diffusion process:
- Distribution level (before truncation): Marginalizes randomness across multiple annotator samples \(\{s_1, s_2, \dots, s_n\}\); supervises the explicit distribution \(P \sim \mathcal{N}(\mu, \Sigma)\) at the truncation point to approximate the true annotation distribution \(Q \sim \mathcal{N}(\mu', \Sigma')\), optimizing overall accuracy
- Sample level (after truncation): Samples from the high-fidelity truncated distribution and generates diverse predictions \(\{\text{pred}_i\}_{i=1}^n\) through the diffusion process, supervised per-sample against ground truth \(\{\text{gt}_i\}_{i=1}^n\), enhancing diversity
This hierarchical design ensures that accuracy optimization does not compromise diversity, and diversity enhancement is built upon a globally aligned distribution.
Gaussian Truncation Representation (GTR)¶
GTR replaces the sampling-based approximation of conventional TDPMs with explicit Gaussian modeling:
- Extracts image semantic features \(Z\) via a segmentation backbone \(f_\theta\)
- Estimates mean \(\mu\) and covariance \(\Sigma\) using independent convolutional layers \(g_\phi\) and \(h_\psi\), respectively (low-rank parameterization with rank \(r=10\))
- Truncation-point distribution: \(X_{T_{\text{trunc}}} \sim \mathcal{N}(\mu, \Sigma)\)
Theoretical justification (Theorems 1 & 2):

- Theorem 1: The marginal distribution at any diffusion timestep can be parameterized as \(\mathcal{N}(\mu, DD^\top + L)\)
- Theorem 2: For any Gaussian distribution, there exists a timestep at which the diffusion process produces the same distribution
Therefore, the Gaussian distribution in GTR is a valid and optimal choice for the truncated distribution.
Network architecture: Standard encoder–decoder, 4 resolution levels, encoder filter sizes 32→64→128→192, decoder uses transposed convolution upsampling and skip connections.
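The low-rank covariance parameterization implied by Theorem 1's \(DD^\top + L\) form can be sketched via the standard reparameterization trick. This is a minimal NumPy illustration, not the authors' code: `sample_truncation_point`, `D`, and `diag` are hypothetical names for the rank-\(r\) factor and the diagonal term.

```python
import numpy as np

def sample_truncation_point(mu, D, diag, n_samples, rng):
    """Draw samples from N(mu, D @ D.T + diag(diag)) by reparameterization.

    mu:   (d,)  mean predicted by the mean head (g_phi in the paper)
    D:    (d, r) low-rank covariance factor (r = 10 in the paper)
    diag: (d,)  non-negative diagonal term
    """
    d, r = D.shape
    eps_lr = rng.standard_normal((n_samples, r))  # noise for the low-rank part
    eps_dg = rng.standard_normal((n_samples, d))  # noise for the diagonal part
    # cov(eps_lr @ D.T) = D @ D.T; cov(eps_dg * sqrt(diag)) = diag(diag)
    return mu + eps_lr @ D.T + eps_dg * np.sqrt(diag)
```

Sampling costs \(O(d r)\) per draw instead of the \(O(d^2)\) a dense covariance would require, which is why a low-rank factorization is the usual choice for pixel-level Gaussians.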
Segmentation Flow Matching (SFM)¶
SFM replaces DDPM with Flow Matching after truncation, and introduces semantic consistency modeling:
- Optimal Transport schedule: Linear interpolation along the shortest path from source \(X_{T_{\text{trunc}}}\) to target \(X_1\)
- \(X_t = t \times X_1 + (1-t) \times X_{T_{\text{trunc}}}\)
- ST-Net (Semantic-aware Transformation Network): Time-conditioned U-Net predicting the flow transformation direction \(g_\theta(X_t)\)
- 4-level encoder–decoder, 15 residual blocks
- Sinusoidal timestep embedding fused into each residual block via MLP
- Linear attention in all layers; full self-attention in the bottleneck
- Intermediate prediction: \(X_1^t = X_t + (1-t)\, g_\theta(X_t)\), projected in latent space via analytic geometry
- Semantic consistency: At each timestep \(t\), computes Dice loss between the intermediate prediction \(X_1^t\) and all ground-truth annotations, explicitly modeling semantic consistency among state, prediction, and ground truth
Advantage of FM over DDPM: avoids the interference of Gaussian constraints on fine-grained predictions.
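A single SFM training step under the OT schedule can be sketched as follows. This is a minimal NumPy sketch under stated assumptions, not the authors' implementation: `st_net` is a stand-in callable for ST-Net, masks are plain arrays, and a sigmoid maps the intermediate prediction to probabilities; these names are illustrative only.

```python
import numpy as np

def dice_loss(pred, gt, eps=1e-6):
    """Soft Dice loss between a probability map and a binary mask."""
    inter = (pred * gt).sum()
    return 1.0 - (2.0 * inter + eps) / (pred.sum() + gt.sum() + eps)

def sfm_training_losses(x0, x1, t, st_net, gts, alpha):
    """One SFM training step under the OT (linear) schedule.

    x0:     draw from the GTR truncated distribution N(mu, Sigma)
    x1:     target segmentation for one annotator
    t:      scalar timestep in (0, 1)
    st_net: callable (x_t, t) -> predicted velocity (stand-in for ST-Net)
    gts:    list of all ground-truth annotations (semantic consistency)
    alpha:  weight on the semantic consistency term
    """
    x_t = t * x1 + (1.0 - t) * x0                       # X_t = t*X_1 + (1-t)*X_trunc
    v_target = x1 - x0                                  # constant OT velocity
    v_pred = st_net(x_t, t)
    l_fm = np.mean((v_pred - v_target) ** 2)            # Flow Matching loss
    x1_hat = x_t + (1.0 - t) * v_pred                   # intermediate prediction X_1^t
    prob = 1.0 / (1.0 + np.exp(-x1_hat))                # squash to [0, 1]
    l_sf = np.mean([dice_loss(prob, g) for g in gts])   # Dice vs. all annotations
    return l_fm + alpha * l_sf, l_fm, l_sf
```

Note how the intermediate prediction reuses the same velocity as the FM loss: a one-step extrapolation from \(X_t\) to \(t=1\), which is exact under the linear schedule when the velocity is predicted perfectly.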
Loss & Training¶
Two-stage training:

1. GTR stage: Train GTR to convergence, then freeze its parameters
   - \(\mathcal{L}_{\text{Prior}} = -\log \int p(Y|X_{T_{\text{trunc}}}) p(X_{T_{\text{trunc}}}|X)\, dX_{T_{\text{trunc}}} \approx \frac{1}{M}\sum_{i=1}^M -\log p(Y|X_{T_{\text{trunc}}}^i)\)
   - Monte Carlo sampling with \(M=20\)
2. SFM stage: Train SFM on top of the frozen GTR
   - \(\mathcal{L}_{\text{FM}}\): standard Flow Matching loss
   - \(\mathcal{L}_{\text{SF}}\): semantic consistency loss (Dice loss)
   - \(\mathcal{L}_{\text{SFM}} = \mathcal{L}_{\text{FM}} + \alpha \cdot \mathcal{L}_{\text{SF}}\)
   - \(\alpha = 10^{-3}\) for LIDC, \(10^{-4}\) for ISIC3
Training configuration:

- Single RTX 3090 (24 GB)
- GTR: 1000 epochs (LIDC) / 400 epochs (ISIC3)
- SFM: 200 epochs (LIDC) / 120 epochs (ISIC3)
- Adam optimizer, learning rate \(10^{-4}\)
- \(\lambda = 10^{-3}\) (\(T=1000\)), linear schedule
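The GTR-stage prior loss is a Monte Carlo average of per-sample negative log-likelihoods. A minimal sketch of that estimator, with `sample_fn` and `log_likelihood_fn` as hypothetical stand-ins for the GTR sampler and the decoder likelihood:

```python
import numpy as np

def gtr_prior_loss(sample_fn, log_likelihood_fn, y, m=20):
    """Monte Carlo estimate of L_Prior with M samples (M = 20 in the paper).

    sample_fn:         () -> one draw X_trunc ~ N(mu, Sigma)
    log_likelihood_fn: (y, x) -> log p(Y | X_trunc)
    Returns (1/M) * sum_i -log p(Y | X_trunc^i).
    """
    return np.mean([-log_likelihood_fn(y, sample_fn()) for _ in range(m)])
```

By Jensen's inequality this averaged form upper-bounds \(-\log \frac{1}{M}\sum_i p(Y|X^i)\), so minimizing it also pushes down the integral form of \(\mathcal{L}_{\text{Prior}}\).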
Key Experimental Results¶
| Dataset | Metric | Ours (ATFM) | Prev. SOTA | Gain |
|---|---|---|---|---|
| LIDC | GED₁₆↓ | Best | CCDM / AB | Consistently superior |
| LIDC | GED₁₀₀↓ | Best | Runner-up | 11.5% |
| LIDC | HM-IoU₃₂↑ | Best | Runner-up | ≥7.3% |
| LIDC | MDM₃₂↑ | Best | — | Leading |
| ISIC3 | GED↓ | Best | Runner-up | 12% |
| ISIC3 | HM-IoU↑ | Best | — | Best |
| ISIC3 | MDM↑ | Best | — | Best |
Inference efficiency (generating 100 samples):
| Method | Steps | Time |
|---|---|---|
| CIMD | Multi-step | 420s |
| AB | Multi-step | 1050s |
| CCDM | Multi-step | 1100s |
| ATFM | GTR + 25 steps | 113s |
ATFM is substantially faster than other diffusion-based methods, requiring only 25 diffusion steps plus a single GTR estimation.
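The inference pipeline (one GTR estimate, then a short flow integration per sample) can be sketched as below. This assumes a plain forward-Euler integrator, a common choice for Flow Matching samplers; the paper reports 25 steps but does not name the integrator here, and `st_net` / `sigma_sample_fn` are hypothetical stand-ins.

```python
import numpy as np

def atfm_inference(mu, sigma_sample_fn, st_net, n_preds=100, n_steps=25):
    """Generate diverse predictions: one GTR estimate, then n_steps Euler steps each.

    mu:              GTR mean, estimated once per image
    sigma_sample_fn: () -> noise drawn from N(0, Sigma) (GTR covariance)
    st_net:          callable (x, t) -> velocity field (stand-in for ST-Net)
    """
    preds = []
    dt = 1.0 / n_steps
    for _ in range(n_preds):
        x = mu + sigma_sample_fn()      # sample-level draw from N(mu, Sigma)
        t = 0.0
        for _ in range(n_steps):        # forward Euler along the learned flow
            x = x + st_net(x, t) * dt
            t += dt
        preds.append(x)
    return preds
```

Because \(\mu\) and \(\Sigma\) are estimated only once per image, the per-prediction cost is just the 25 network evaluations, which is where the reported speedup over multi-step diffusion baselines comes from.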
Ablation Study¶
Five variants compared on the LIDC dataset:
| Variant | Description | Conclusion |
|---|---|---|
| Act. GTR | GTR + activation layer only | Baseline |
| SFM w/o \(\mathcal{L}_{\text{SF}}\) | SFM without semantic loss | Insufficient diversity |
| SFM | Full SFM | Strong standalone performance |
| ATFM w/o \(\mathcal{L}_{\text{SF}}\) | Full framework without semantic loss | Reduced plausibility |
| ATFM | Full framework | Best |
- ATFM outperforms Act. GTR and SFM alone by ≥10% and ≥6%, respectively, validating data-hierarchical inference and the synergy of both modules
- The average gap between models with and without \(\mathcal{L}_{\text{SF}}\) reaches 11%, highlighting the critical role of semantic consistency modeling
- Using 25 inference steps achieves the best balance between performance and efficiency
- An excessively small \(\alpha\) limits the effect of \(\mathcal{L}_{\text{SF}}\); an excessively large \(\alpha\) weakens \(\mathcal{L}_{\text{FM}}\)
Highlights & Insights¶
- Inference paradigm innovation: The first work to redefine the TDPM inference paradigm as an AMIS-specific data-hierarchical inference, fundamentally decoupling accuracy and diversity
- Theoretical completeness: Theorems 1 & 2 provide rigorous theoretical guarantees for GTR's Gaussian modeling
- Coherent three-module design: Data-Hierarchical Inference + GTR + SFM each address a specific problem with clear design logic
- Significant efficiency advantage: 113s vs. 420–1100s, representing a 3.7–9.7× speedup with superior performance
- FM as a replacement for DDPM: Avoids the interference of Gaussian constraints on fine-grained segmentation, representing a well-motivated technical choice
Limitations & Future Work¶
- Limited dataset scale: Validation is restricted to LIDC (lung CT) and a subset of ISIC3 (skin lesions, only 300 images), without covering additional modalities and organs
- Fixed annotation count: LIDC uses 4 annotations and ISIC3 uses 3; the effect of varying annotation counts on the method remains unexplored
- Two-stage training: GTR must be trained to convergence before freezing, making the pipeline complex; end-to-end joint training may be preferable
- Missing table values: Quantitative table values in the HTML paper were not fully displayed, limiting direct comparison
- 3D extension: Experiments are conducted only on 2D slices; applicability to 3D volumetric segmentation remains to be verified
Related Work & Insights¶
| Method Category | Representative Work | Limitations | ATFM Advantage |
|---|---|---|---|
| Model ensembles / multi-head | SSN, MoSE | Do not alter the inference process; constrained by model selection | Redefines the inference paradigm |
| cVAE methods | Prob. U-Net, PHiSeg | Single-stage inference; accuracy and diversity are coupled | Hierarchically decouples the two objectives |
| Diffusion methods | CIMD, CCDM | Gaussian constraints interfere with fine-grained predictions; slow inference | FM avoids Gaussian constraints; efficient 25-step inference |
| Multi-annotator-aware | c-Prob. U-Net, c-SSN | Suppress low-frequency patterns | GTR explicitly models and preserves low-frequency patterns |
| Conventional TDPMs | TDPM | Inaccurate sampling approximation; lack of semantic guidance | GTR explicit modeling + SFM semantic supervision |
Broader implications:
- Redefining inference paradigms: The idea of redefining the inference process without modifying the model architecture is transferable to other generative tasks requiring multi-objective balance
- Generality of truncated diffusion: The "truncation + replacement of intermediate inference path" strategy in TDPMs reveals structural flexibility in diffusion model inference, applicable to other conditional generation tasks
- Flow Matching for segmentation: Introducing FM into segmentation tasks is an emerging trend that avoids the mismatch between DDPM's Gaussian assumption and the one-hot label space
- Explicit vs. implicit distribution modeling: GTR's explicit Gaussian modeling approach can be generalized to other scenarios requiring distribution modeling at intermediate representations
Rating¶
- Novelty: ⭐⭐⭐⭐ — The data-hierarchical inference paradigm is a meaningful contribution; the three-module design logic is clear
- Technical Depth: ⭐⭐⭐⭐ — Backed by theoretical guarantees (two theorems) with complete methodological derivation
- Experimental Thoroughness: ⭐⭐⭐ — Only two datasets, but ablation and hyperparameter analyses are comprehensive
- Writing Quality: ⭐⭐⭐⭐ — Well-structured; the "Summarized Advantage" for each module aids understanding
- Value: ⭐⭐⭐⭐ — Provides a new paradigm for AMIS; the application of FM in medical segmentation offers a valuable reference