Progressive Test Time Energy Adaptation for Medical Image Segmentation
Conference: ICCV 2025
arXiv: 2503.16616
Code: None
Area: Medical Imaging
Keywords: test-time adaptation, energy-based model, medical image segmentation, domain shift, shape prior
TL;DR
This paper proposes a progressive test-time adaptation method based on energy-based models. A shape energy model is trained as an in-distribution/out-of-distribution discriminator; at test time, energy minimization guides the segmentation model to adapt to the target domain. The method consistently outperforms baselines across 8 public datasets covering cardiac, spinal cord, and lung segmentation tasks.
Background & Motivation
Distribution shift in medical image segmentation:
- Inconsistent imaging protocols across hospitals (MRI sequences, scanner parameters)
- Patient population heterogeneity (age, pathological status, demographics)
- Models trained on the source domain exhibit significant performance degradation on target domains
Limitations of prior work:
- Domain adaptation methods: Require multiple passes over target data, which is impractical in clinical settings where patient data cannot be anticipated.
- Test-time training (TTT): Requires an auxiliary self-supervised task to be jointly trained with the main task.
- Entropy-based TTA (TENT/EATA/SAR): General-purpose regularization methods that do not exploit shape priors specific to segmentation.
- CoTTA/MEMO: Based on pseudo-labels or augmentation consistency, but lack sufficient granularity.
- TEA: Applies energy-based models to classification TTA, but produces only a single global energy value, which is insufficiently fine-grained.
Core motivation: Segmentation tasks possess strong shape priors (e.g., cardiac anatomy), which can be leveraged to assess whether a predicted shape is anatomically plausible. An energy-based model can serve as a patch-level shape discriminator, identifying erroneous regions and guiding the segmentation model to correct them.
Method
Overall Architecture
The method consists of two stages:
1. Preparation stage (source domain): Train a shape energy model \(g_\phi(\cdot)\).
2. Adaptation stage (target domain): Freeze the energy model and progressively update the BatchNorm layers of the segmentation model \(f_\theta(\cdot)\).
Key Designs
Region-based Energy Model:
- A fully convolutional network \(g_\phi: \mathbb{R}^{H\times W} \to \mathbb{R}^{K\times K}\) maps the segmentation map \(\hat{S}\) to a \(K \times K\) energy map.
- Each patch (of size \(h \times w\), where \(h=H/K,\ w=W/K\)) corresponds to one energy value.
- Low energy = in-distribution (correct shape); high energy = out-of-distribution (erroneous prediction).
- Formulated as a binary classification task, trained with a patchwise BCE loss over \(N_p\) patches \(s_s^i\) with binary labels \(y_s^i\) (\(\sigma\) denotes the sigmoid function):
\(\mathcal{L}_\phi = \frac{1}{N_p}\sum_{i=1}^{N_p} \left(-y_s^i\log\sigma(-g_\phi(s_s^i)) - (1-y_s^i)\log(1-\sigma(-g_\phi(s_s^i)))\right)\)
- Design motivation: A single global energy value lacks granularity; patch-level energy enables localization of specific erroneous regions.
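The patch layout and the patchwise BCE loss above can be sketched in NumPy. This is a minimal illustration, not the authors' implementation: the helper names `split_into_patches` and `patchwise_bce` are hypothetical, the energy map is passed in directly rather than produced by a trained network, and the loss is transcribed term by term from the formula as written.

```python
import numpy as np

def sigmoid(z):
    return 1.0 / (1.0 + np.exp(-z))

def split_into_patches(seg_map, patch):
    """Reshape an (H, W) map into (K_h, K_w, patch, patch) non-overlapping patches."""
    H, W = seg_map.shape
    assert H % patch == 0 and W % patch == 0, "H and W must be multiples of the patch size"
    # Index [kh, kw, i, j] addresses pixel (kh * patch + i, kw * patch + j).
    return seg_map.reshape(H // patch, patch, W // patch, patch).swapaxes(1, 2)

def patchwise_bce(energy_map, labels, eps=1e-12):
    """Patchwise BCE as written above: sigma(-g) plays the role of the predicted
    probability that a patch carries label 1."""
    p = sigmoid(-energy_map)
    loss = -labels * np.log(p + eps) - (1.0 - labels) * np.log(1.0 - p + eps)
    return float(loss.mean())

# A 64x64 map with 16x16 patches yields a 4x4 grid, one energy value per patch.
seg = np.arange(64 * 64, dtype=float).reshape(64, 64)
patches = split_into_patches(seg, 16)
print(patches.shape)                                      # (4, 4, 16, 16)
print(patchwise_bce(np.zeros((4, 4)), np.zeros((4, 4))))  # ln 2, about 0.693
```

Because each energy value is tied to one patch, the loss can be inspected per patch before averaging, which is what makes region-level diagnosis possible.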
Adversarial Perturbation for Negative Sample Generation:
- The source domain contains only correct segmentation results, lacking out-of-distribution (erroneous) samples.
- FGSM is applied to introduce adversarial perturbations to input images: \(\epsilon = \delta \cdot \text{sign}(\nabla_{I_s}\mathcal{L}(f_\theta(I_s), S_s))\)
- The perturbed input is passed through the segmentation network to produce erroneous segmentations \(\tilde{S}_s = f_\theta(I_s + \epsilon)\).
- Spatial affine transformations and pixel-level noise are additionally applied to increase diversity.
- Classification labels are generated by comparing perturbed segmentations against ground truth: \(y_s = 1 - \mathbf{1}(d(\tilde{s}_s, s_s) < \tau)\)
- Design motivation: Adversarial perturbations push data toward low-density regions (natural OOD regions), and the segmentation network's constraints ensure the generated errors remain anatomically plausible.
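The FGSM step and the patch labelling rule above can be sketched as follows. The function names are hypothetical, the loss gradient is assumed to be supplied by an autodiff framework (it is simply passed in here), and the 0 to 255 intensity scale in the usage example is an assumption chosen so that the threshold \(\tau=50\) is meaningful for a mean absolute difference.

```python
import numpy as np

def fgsm_perturbation(grad_wrt_image, delta):
    """FGSM step: epsilon = delta * sign(gradient of the loss w.r.t. the input image)."""
    return delta * np.sign(grad_wrt_image)

def patch_labels(perturbed_seg, gt_seg, patch=16, tau=50.0):
    """Per-patch labels y = 1 - 1(d < tau), where d is the mean absolute
    difference between the perturbed and ground-truth segmentation in each patch."""
    H, W = gt_seg.shape
    assert H % patch == 0 and W % patch == 0, "H and W must be multiples of the patch size"
    diff = np.abs(perturbed_seg.astype(float) - gt_seg.astype(float))
    d = diff.reshape(H // patch, patch, W // patch, patch).mean(axis=(1, 3))
    return 1 - (d < tau).astype(int)

# Toy example: a 32x32 mask on an assumed 0-255 scale, split into 2x2 patches.
gt = np.zeros((32, 32))
gt[8:24, 8:24] = 255.0
perturbed = gt.copy()
perturbed[0:16, 0:16] = 0.0          # erase the foreground inside the top-left patch
labels = patch_labels(perturbed, gt, patch=16, tau=50.0)
print(labels)                        # only the top-left patch exceeds tau
```

In the toy example the top-left patch loses 64 of its 256 pixels, giving a mean absolute difference of 63.75, which is above the threshold; the other patches are unchanged and receive label 0.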
Progressive Energy Adaptation:
- At test time, the energy model \(g_\phi\) is frozen; only the BatchNorm parameters of \(f_\theta\) are updated.
- Objective: align predicted energy values toward a reference low-energy target (an all-zero matrix \(\mathbf{0}_{K\times K}\)).
- Adaptation objective:
\(\theta^* = \arg\min_\theta -\sum_{i=1}^{B_t}\log(1-\sigma(-g_\phi(\hat{s}_t^i)))\)
- Adam optimizer is used with 10 iterations per sample; model weights are restored after each batch.
- Design motivation: Minimizing energy values encourages the segmentation model to produce predictions consistent with natural anatomical structures.
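The adaptation objective is the patchwise BCE with every label set to 0, so the loop reduces to a few Adam steps on the BatchNorm parameters followed by a weight restore. The sketch below shows that episodic structure on a toy problem. Everything in it is a stand-in: `adapt_then_restore`, the two scalar parameters `gamma` and `beta`, the hand-written `toy_energy`, the learning rate, and the finite-difference gradients (a real setup would use an autodiff framework and a trained, frozen energy network).

```python
import copy
import numpy as np

def sigmoid(z):
    return 1.0 / (1.0 + np.exp(-z))

def adapt_then_restore(bn_params, loss_and_grads, predict, steps=10, lr=1e-2,
                       beta1=0.9, beta2=0.999, eps=1e-8):
    """Run Adam on the BatchNorm parameters only, predict, then restore the weights."""
    snapshot = copy.deepcopy(bn_params)                 # source-trained weights
    first = {k: np.zeros_like(p) for k, p in bn_params.items()}
    second = {k: np.zeros_like(p) for k, p in bn_params.items()}
    losses = []
    for t in range(1, steps + 1):
        loss, grads = loss_and_grads(bn_params)
        losses.append(loss)
        for k in bn_params:
            first[k] = beta1 * first[k] + (1 - beta1) * grads[k]
            second[k] = beta2 * second[k] + (1 - beta2) * grads[k] ** 2
            m_hat = first[k] / (1 - beta1 ** t)         # bias-corrected moments
            v_hat = second[k] / (1 - beta2 ** t)
            bn_params[k] = bn_params[k] - lr * m_hat / (np.sqrt(v_hat) + eps)
    losses.append(loss_and_grads(bn_params)[0])
    prediction = predict(bn_params)                     # predict with adapted weights
    bn_params.update(snapshot)                          # restore before the next batch
    return prediction, losses

# --- Toy stand-ins (all hypothetical) ---
x = np.linspace(-1.0, 1.0, 16).reshape(4, 4)            # normalized features

def toy_predict(params):
    return sigmoid(params["gamma"] * x + params["beta"])

def toy_energy(seg):
    """Frozen toy energy model: one value per 2x2 patch of a 4x4 map."""
    return 4.0 * seg.reshape(2, 2, 2, 2).mean(axis=(1, 3)) - 2.0

def toy_loss(params):
    g = toy_energy(toy_predict(params))
    return float(np.mean(-np.log(1.0 - sigmoid(-g))))   # objective as written above

def toy_loss_and_grads(params, h=1e-5):
    grads = {}
    for k in params:                                    # central finite differences
        up, dn = dict(params), dict(params)
        up[k], dn[k] = params[k] + h, params[k] - h
        grads[k] = (toy_loss(up) - toy_loss(dn)) / (2 * h)
    return toy_loss(params), grads

params = {"gamma": np.array(1.0), "beta": np.array(0.0)}
pred, losses = adapt_then_restore(params, toy_loss_and_grads, toy_predict)
print(losses[0] > losses[-1], float(params["gamma"]), float(params["beta"]))
```

Restoring the snapshot after each call is what keeps one batch's adaptation from leaking into the next.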
Loss & Training
- Energy model training: BCE loss; patch size \(h=w=16\); mean absolute difference as the distance metric; threshold \(\tau=50\).
- Test-time adaptation: Adam optimizer, 10 iterations per sample.
- BatchNorm-only updates: Following standard TTA practice; weights are restored after each batch.
- Adversarial perturbation: Dice Loss is used as the objective for FGSM.
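For reference, the reported settings can be gathered into a small configuration object. This is only an organizational sketch: the class name `TTAConfig` and its field names are hypothetical, and the FGSM magnitude and the Adam learning rate are left as required arguments because their values are not given above. The 256 by 256 input size in the usage line is likewise an assumption.

```python
from dataclasses import dataclass

@dataclass(frozen=True)
class TTAConfig:
    delta: float                        # FGSM magnitude; not reported above, must be supplied
    learning_rate: float                # Adam step size; not reported above, must be supplied
    patch_size: int = 16                # h = w = 16
    tau: float = 50.0                   # threshold on the mean absolute difference
    adapt_steps: int = 10               # Adam iterations per sample
    update_batchnorm_only: bool = True  # all other weights stay frozen
    restore_after_batch: bool = True    # weights are reset after each batch
    fgsm_loss: str = "dice"             # objective used to compute the FGSM gradient

    def energy_map_size(self, height, width):
        """Return (K_h, K_w), the energy map size for an H x W segmentation map."""
        if height % self.patch_size or width % self.patch_size:
            raise ValueError("image size must be a multiple of the patch size")
        return height // self.patch_size, width // self.patch_size

# Placeholder values for the two unreported settings.
cfg = TTAConfig(delta=0.01, learning_rate=1e-3)
print(cfg.energy_map_size(256, 256))    # (16, 16)
```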
Key Experimental Results
Main Results (Tables)
Cardiac segmentation (ACDC → other datasets, UNet backbone):
| Method | LVQuant LV DSC↑ | LVQuant Myo DSC↑ | MyoPS LV DSC↑ | M&M LV DSC↑ | M&M Myo DSC↑ | Avg Rank |
|---|---|---|---|---|---|---|
| Pretrained | 58.98 | 42.52 | 85.69 | 47.69 | 41.19 | 4.33 |
| TENT | 65.78 | 51.57 | 85.63 | 57.01 | 48.26 | 2.92 |
| CoTTA | 64.58 | 50.52 | 85.64 | 52.98 | 46.72 | 3.67 |
| TEA | 67.96 | 54.10 | 85.88 | 52.83 | 48.06 | 2.92 |
| Ours | 76.93 | 59.43 | 86.06 | 61.84 | 53.13 | 1.08 |
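The size of the improvement is easier to judge as per-metric differences. The short script below recomputes them from the values in the table above; it covers only the five metrics shown and the variable names are illustrative.

```python
metrics = ["LVQuant LV", "LVQuant Myo", "MyoPS LV", "M&M LV", "M&M Myo"]
dsc = {
    "Pretrained": [58.98, 42.52, 85.69, 47.69, 41.19],
    "TENT":       [65.78, 51.57, 85.63, 57.01, 48.26],
    "CoTTA":      [64.58, 50.52, 85.64, 52.98, 46.72],
    "TEA":        [67.96, 54.10, 85.88, 52.83, 48.06],
    "Ours":       [76.93, 59.43, 86.06, 61.84, 53.13],
}

# Gain of the proposed method over the unadapted model, in DSC points.
gain_over_pretrained = [round(o - p, 2) for o, p in zip(dsc["Ours"], dsc["Pretrained"])]

# Margin over the strongest competing entry for each metric.
best_other = [max(v[i] for m, v in dsc.items() if m != "Ours") for i in range(len(metrics))]
margin_over_best = [round(o - b, 2) for o, b in zip(dsc["Ours"], best_other)]

for name, g, m in zip(metrics, gain_over_pretrained, margin_over_best):
    print(f"{name:12s} gain over pretrained: {g:5.2f}   margin over best other: {m:5.2f}")
```

The gains are large on LVQuant and M&M (roughly 12 to 18 points over the pretrained model) and small on MyoPS LV (0.37 points), where the pretrained model is already at 85.69.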
Spinal cord segmentation (GMSC, cross-site transfer from Sites 1 and 4, single class; columns are source → target site, values are DSC):
| Method | 1→2 | 1→3 | 1→4 | 4→1 | 4→2 | 4→3 | Avg DSC |
|---|---|---|---|---|---|---|---|
| TENT | 70.5 | 16.8 | 57.4 | 87.0 | 67.9 | 72.9 | 62.1 |
| CoTTA | 66.1 | 63.3 | 92.1 | 95.0 | 54.7 | 86.7 | 76.4 |
| TEA | 68.4 | 66.5 | 92.4 | 94.9 | 54.7 | 86.7 | 77.3 |
| InTENT | 86.6 | 28.7 | 71.4 | 83.3 | 79.2 | 75.0 | 70.7 |
| Ours | 73.6 | 77.7 | 95.3 | 95.1 | 56.2 | 87.2 | 80.9 |
Lung segmentation (CHN X-ray → others):
| Method | CHN→MCU DSC | CHN→JSRT DSC | Avg DSC |
|---|---|---|---|
| TENT | 86.2 | 95.2 | 90.7 |
| CoTTA | 95.8 | 95.2 | 95.5 |
| TEA | 95.7 | 95.5 | 95.6 |
| InTENT | 95.5 | 96.3 | 95.9 |
| Ours | 96.1 | 96.3 | 96.2 |
Ablation Study (Tables)
Adaptation performance across different segmentation backbones (ACDC → LVQuant LV DSC):
| Backbone | Pretrained | TENT | CoTTA | TEA | Ours | Avg Rank (Ours) |
|---|---|---|---|---|---|---|
| UNet | 58.98 | 65.78 | 64.58 | 67.96 | 76.93 | 1.08 |
| MedNeXt | 57.55 | 75.10 | 74.57 | 75.85 | 76.22 | 1.00 |
| SwinUNETR | 68.44 | 74.06 | 73.41 | 74.32 | 76.05 | 1.25 |
Adaptation performance across different source domains (M&M → others, UNet):
| Method | LVQuant LV DSC | MyoPS LV DSC | ACDC LV DSC | Avg Rank |
|---|---|---|---|---|
| Pretrained | 89.08 | 75.80 | 40.84 | 4.08 |
| TENT | 92.03 | 77.34 | 52.74 | 3.67 |
| TEA | 92.27 | 77.75 | 56.68 | 3.00 |
| Ours | 93.25 | 79.14 | 59.97 | 1.08 |
Key Findings
- Achieves the lowest average rank (1.00–1.25) across three segmentation backbones (UNet/MedNeXt/SwinUNETR), demonstrating backbone-agnostic effectiveness.
- On cardiac segmentation with the UNet backbone, LV DSC improves from 58.98% (pretrained) to 76.93% (adapted), a gain of nearly 18 percentage points.
- The energy model achieves OOD detection accuracy exceeding 92%, effectively identifying erroneous regions.
- The method remains effective for single-class segmentation tasks such as spinal cord and lung segmentation, achieving average DSC of 80.9% and 96.2%, respectively.
- Compared to entropy-based methods such as TENT, the shape-prior-guided energy approach shows a more pronounced advantage under large distribution shifts.
Highlights & Insights
- First energy-based model for TTA in medical segmentation: Uses an energy-based model as an implicit encoder of shape priors, replacing explicit shape parameterization.
- Adversarial perturbation for training data generation: Uses FGSM to explore the space of erroneous segmentations, so no additional OOD data needs to be collected.
- Region-level vs. global energy: Patch-level energy discrimination is more fine-grained than TEA's single global value, enabling localization of specific erroneous regions.
- Backbone-agnostic design: The method can be applied as a plug-and-play module to any segmentation network without requiring specific architectural modifications.
- Progressive adaptation: Each test batch is adapted independently and the model weights are restored afterwards, preventing error accumulation across batches.
Limitations & Future Work
- Ten optimization iterations per sample for BatchNorm updates reduce inference speed.
- Updating only BatchNorm layers may limit adaptation capacity, particularly for architectures with few BatchNorm parameters.
- The discriminative capability of the energy model depends on the diversity of the source domain and the quality of the adversarial perturbations.
- Hyperparameters such as the perturbation magnitude \(\delta\) and patch size require careful tuning.
- The method has not been validated on 3D volumetric segmentation tasks.
Related Work & Insights
- The concept of using energy-based models as shape priors is generalizable to other dense prediction tasks that require structural constraints.
- The strategy of generating negative samples via adversarial perturbations offers broader inspiration for general OOD detection.
- The idea of region-level energy discrimination can be combined with hierarchical energy models to enable more fine-grained adaptation.
Rating
- Novelty: ⭐⭐⭐⭐⭐ First application of energy-based models to TTA for medical segmentation; the adversarial perturbation strategy for negative sample generation is elegant.
- Experimental Thoroughness: ⭐⭐⭐⭐⭐ Evaluated on 8 datasets, 3 backbones, 3 organ types, and multiple imaging modalities, which makes the evaluation highly comprehensive.
- Writing Quality: ⭐⭐⭐⭐ Mathematical derivations are rigorous and method descriptions are clear, though notation density is occasionally high.
- Value: ⭐⭐⭐⭐⭐ High clinical utility; the backbone-agnostic design lowers the barrier to adoption, with significant gains under large distribution shifts.