Diverse Prototypical Ensembles Improve Robustness to Subpopulation Shift

Conference: ICML 2025
arXiv: 2505.23027
Code: minhto2802/dpe4subpop
Area: others (Robustness / Distribution Shift)
Keywords: subpopulation shift, prototype classifier, ensemble diversity, worst-group accuracy, distribution robustness

TL;DR

The paper proposes Diversified Prototypical Ensemble (DPE), which replaces the standard linear classification head with multiple diverse prototype classifiers. By utilizing both explicit (inter-prototype similarity loss) and implicit (bootstrap sampling) diversification strategies, DPE adaptively discovers subpopulation decision boundaries without requiring subpopulation annotations, significantly improving worst-group accuracy.

Background & Motivation

Problem Definition

Subpopulation shift, a common form of distribution shift, refers to a difference in subpopulation distributions between the training and test sets. Yang et al. (2023) categorize it into four types:

Spurious Correlations: Non-causal features mislead predictions (e.g., background water \(\to\) waterbirds).

Attribute Imbalance: Certain attribute values appear far more frequently than others.

Class Imbalance: Some labels are severely underrepresented.

Attribute Generalization: Attribute values unseen during training appear at test time.

Limitations of Prior Work

  • Classifiers trained via ERM tend to learn features of majority subpopulations, performing poorly on minority subpopulations.
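  • Methods like gDRO rely on subpopulation annotations during training, and JTT, although it trains without them, still needs group-annotated validation data for model selection; such annotations are often unavailable in real-world scenarios.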
  • Methods that explicitly identify minority groups increase complexity and struggle to generalize to unseen subpopulations.
  • Simple resampling/reweighting (e.g., CRT, DFR) is effective but still limited by known subpopulation structures.

Design Motivation

A single classifier can only learn a single decision boundary, which is easily dominated by majority subpopulations. If an ensemble is employed and diversity among members is explicitly encouraged, different members can capture decision boundaries of different subpopulations—even without subpopulation labels. Prototypical classifiers naturally preserve the geometric structure of the feature space, making them suitable for discovering subpopulations with limited data.

Method

Overall Architecture: Two-Stage Training

Stage 1: Train a feature extractor \(f: \mathbb{R}^n \to \mathbb{R}^d\) using standard ERM on the entire training set, then freeze \(f\). This follows the findings of Kirichenko et al. (2022) and Izmailov et al. (2022) showing that even in the presence of spurious correlations, the feature representations learned by ERM still contain core discriminative information.

Stage 2: Train the DPE classification head (replacing the original linear layer) on a class-balanced subset of the validation set, without requiring subpopulation annotations.

Prototypical Classifier

Given \(K\) classes, each class defines a learnable prototype \(p^{(i)} \in \mathbb{R}^d\). The classification probability is based on the distance from the feature to each class prototype:

\[P(y|X) = \frac{\exp(-D(f(X), p^{(y)}))}{\sum_{i=1}^{K} \exp(-D(f(X), p^{(i)}))}\]

where the distance function \(D\) is the scaled Euclidean distance (on normalized vectors):

\[D(x, y) = |d_s| \cdot \left\| \frac{x}{\|x\|} - \frac{y}{\|y\|} \right\|_2\]

\(d_s\) is a learnable scaling factor. The loss function introduces a temperature \(\tau\):

\[\mathcal{L}(X, y) = -\log \frac{\exp(-D(f_\theta(X), p^{(y)}) / \tau)}{\sum_{i=1}^{K} \exp(-D(f_\theta(X), p^{(i)}) / \tau)}\]
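
The three formulas above can be sketched in a few lines of plain Python. This is only an illustration under stated assumptions, not the authors' implementation: the function names, the toy 2-D feature, and the values of `d_s` and `tau` are all hypothetical, and the feature is assumed to come from an already frozen extractor \(f\).

```python
import math

def l2_normalize(v):
    """Return v scaled to unit Euclidean norm (assumes v is not the zero vector)."""
    norm = math.sqrt(sum(x * x for x in v))
    return [x / norm for x in v]

def scaled_distance(x, y, d_s=1.0):
    """D(x, y) = |d_s| * || x/||x|| - y/||y|| ||_2."""
    xn, yn = l2_normalize(x), l2_normalize(y)
    return abs(d_s) * math.sqrt(sum((a - b) ** 2 for a, b in zip(xn, yn)))

def class_probs(feature, prototypes, d_s=1.0, tau=1.0):
    """Softmax over negative scaled distances to each class prototype."""
    logits = [-scaled_distance(feature, p, d_s) / tau for p in prototypes]
    m = max(logits)  # subtract the max for numerical stability
    exps = [math.exp(l - m) for l in logits]
    total = sum(exps)
    return [e / total for e in exps]

def prototype_loss(feature, label, prototypes, d_s=1.0, tau=1.0):
    """Negative log-likelihood of the true class."""
    return -math.log(class_probs(feature, prototypes, d_s, tau)[label])

# Toy example: a hypothetical 2-D feature and one prototype per class (K = 2).
feature = [2.0, 0.1]
prototypes = [[1.0, 0.0], [0.0, 1.0]]
probs = class_probs(feature, prototypes, d_s=5.0, tau=1.0)
loss = prototype_loss(feature, 0, prototypes, d_s=5.0, tau=1.0)
```

Because both vectors are normalized before the distance is taken, only the direction of the feature matters; the magnitude of `d_s` (together with `tau`) controls how sharply the softmax separates the classes.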

Prototypical Ensemble

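Each class uses \(N\) prototypes, one per ensemble member, resulting in the set \(\{p_j^{(i)}\}_{i=1,...,K,\; j=1,...,N}\). Member \(j\) classifies with its own prototypes \(\{p_j^{(i)}\}_{i=1,...,K}\), giving a probability \(P_j^{(k)}(y|X)\) for each class \(k\). The final prediction averages these probabilities across the \(N\) members: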

\[\hat{y} = \arg\max_{k \in \{1,...,K\}} \frac{1}{N} \sum_{j=1}^{N} P_j^{(k)}(y|X)\]
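
As a small illustration of the averaging rule above, the sketch below takes per-member class probabilities as given and returns the ensemble prediction. The function name and the example probabilities are hypothetical.

```python
def ensemble_predict(member_probs):
    """Average class probabilities over N members; return (argmax class, averaged probs).

    member_probs: list of N lists, each holding K class probabilities.
    """
    n_members = len(member_probs)
    n_classes = len(member_probs[0])
    avg = [sum(p[k] for p in member_probs) / n_members for k in range(n_classes)]
    pred = max(range(n_classes), key=lambda k: avg[k])
    return pred, avg

# Hypothetical outputs of N = 3 members on one sample with K = 2 classes.
member_probs = [[0.9, 0.1], [0.4, 0.6], [0.45, 0.55]]
pred, avg = ensemble_predict(member_probs)
```

In this example two of the three members individually favor class 1, yet the averaged probability favors class 0, so probability averaging and majority voting can disagree; the formula above uses averaging.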

Key Designs

1. Explicit Diversification: Inter-Prototype Similarity (IPS) Loss

For the \(n\)-th ensemble member, the IPS loss penalizes the similarity between prototypes of the same class:

\[\mathcal{L}_{\text{IPS}} = \sum_{k=1}^{K} \sum_{i=1}^{n} \sum_{j=1}^{n} \mathbb{1}_{\{i \neq j\}} \frac{|\langle p_i^{(k)}, p_j^{(k)} \rangle|}{n \cdot d}\]
  • Scaled by \(n\) (the current number of members) and \(d\) (embedding dimension).
  • When training the \(n\)-th member, the prototypes of the first \(n-1\) members are frozen, and only the current member's \(\{p_n^{(k)}\}_{k=1,...,K}\) are optimized.
  • Total Loss = \(\mathcal{L}(X, y) + \mathcal{L}_{\text{IPS}}\)
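
To make the IPS formula concrete, here is a minimal sketch that evaluates it for a nested list of prototypes. It only computes the loss value; the freezing of earlier members' prototypes happens in the optimizer and is not shown. The function name and the data layout `prototypes[k][i]` are assumptions made for illustration.

```python
def ips_loss(prototypes):
    """Inter-prototype similarity loss.

    prototypes[k][i] is the d-dimensional prototype of class k held by member i.
    Sums |<p_i, p_j>| over ordered pairs i != j within each class,
    with each term scaled by 1 / (n * d).
    """
    total = 0.0
    for class_protos in prototypes:
        n = len(class_protos)
        d = len(class_protos[0])
        for i in range(n):
            for j in range(n):
                if i != j:
                    dot = sum(a * b for a, b in zip(class_protos[i], class_protos[j]))
                    total += abs(dot) / (n * d)
    return total

# K = 1 class, n = 2 members, d = 2.
aligned = [[[1.0, 0.0], [1.0, 0.0]]]      # identical prototypes
orthogonal = [[[1.0, 0.0], [0.0, 1.0]]]   # orthogonal prototypes
loss_aligned = ips_loss(aligned)          # 2 ordered pairs * 1 / (2 * 2) = 0.5
loss_orthogonal = ips_loss(orthogonal)    # 0.0
```

Because the formula takes the absolute value of the inner product, the penalty is lowest for orthogonal prototypes, and antiparallel prototypes are penalized as much as parallel ones.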

2. Implicit Diversification: Bootstrap Aggregation

Each ensemble member is trained on a different class-balanced subset of the validation set; the differences between random subsets implicitly encourage different members to learn distinct decision boundaries.
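
One simple way to draw such subsets is sketched below. It is an assumption-laden illustration: every class contributes as many samples as the rarest class, drawn without replacement, and the sampling details of the reference implementation may differ. The function name and toy labels are hypothetical.

```python
import random
from collections import defaultdict

def class_balanced_subset(labels, rng):
    """Return indices of a random subset with equally many samples per class.

    Assumption: each class contributes as many samples as the rarest class,
    drawn without replacement.
    """
    by_class = defaultdict(list)
    for idx, y in enumerate(labels):
        by_class[y].append(idx)
    per_class = min(len(idxs) for idxs in by_class.values())
    subset = []
    for y in sorted(by_class):
        subset.extend(rng.sample(by_class[y], per_class))
    return subset

# Hypothetical validation labels: class 0 is three times as frequent as class 1.
labels = [0] * 30 + [1] * 10
# One subset per ensemble member, each drawn with its own seed.
subsets = [class_balanced_subset(labels, random.Random(seed)) for seed in range(3)]
```

Using a separate seed per member is what makes the members see different majority-class samples, which is the source of the implicit diversity described above.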

Training Protocol

Ensemble members are trained sequentially (not in parallel), with each new member maintaining diversity relative to the frozen preceding members via the IPS loss.

Key Experimental Results

Datasets

9 real-world datasets across vision and NLP covering four types of subpopulation shift:

| Dataset | Area | Shift Type |
| --- | --- | --- |
| Waterbirds | Vision | Spurious Correlation |
| CelebA | Vision | Spurious Correlation |
| MetaShift | Vision | Spurious Correlation |
| ImageNetBG | Vision | Spurious Correlation |
| NICO++ | Vision | Attribute Generalization |
| Living17 | Vision | Attribute Generalization |
| CheXpert | Medical Imaging | Attribute Imbalance |
| CivilComments | NLP | Attribute Imbalance |
| MultiNLI | NLP | Spurious Correlation |

Main Results: Worst-Group Accuracy (WGA)

Under conditions without subpopulation annotations (ERM backbone):

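| Method | Waterbirds | CelebA | CivilComments | MultiNLI | MetaShift | ImageNetBG | NICO++ | Living17 |
| --- | --- | --- | --- | --- | --- | --- | --- | --- |
| ERM | 69.1 | 57.6 | 63.2 | 66.4 | 82.1 | 76.8 | 35.0 | 48.0 |
| CRT | 76.3 | 69.6 | 67.8 | 65.4 | 83.1 | 78.2 | 33.3 | - |
| DFR | 89.0 | 73.7 | 64.4 | 63.8 | 81.4 | 74.4 | 38.0 | - |
| ERM+DPE | 91.0 | 81.9 | 69.9 | 69.3 | 84.1 | 87.9 | 50.0 | 54.0 |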

Key Observations:

  • DPE achieves the best WGA on 8 out of 9 datasets.
  • It outperforms DFR by 2.0 points on Waterbirds (91.0 vs. 89.0) and by 8.2 points on CelebA (81.9 vs. 73.7).
  • The improvement on ImageNetBG is significant: 87.9 vs. 78.2 for CRT (+9.7 points).
  • The most notable improvement is on NICO++ (Attribute Generalization): 50.0 vs. 38.0 for DFR (+12.0 points).
  • It reaches 54.0 on Living17, for which no CRT or DFR results are reported.

Comparison with other methods using a stronger ERM* backbone:

| Method | Waterbirds | CelebA | CivilComments | MultiNLI |
| --- | --- | --- | --- | --- |
| ERM* | 77.9 | 66.5 | 69.4 | 66.5 |
| RWY | 86.1 | 82.9 | 67.5 | 68.0 |
| AFR | 90.4 | 82.0 | 68.7 | - |

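Even with a standard ERM backbone, DPE (91.0 / 81.9 / 69.9 / 69.3) outperforms these stronger-backbone methods on Waterbirds, CivilComments, and MultiNLI, and comes within about one point of the best of them on CelebA (81.9 vs. 82.9 for RWY).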
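
For reference, worst-group accuracy is the minimum per-group accuracy over all subpopulations. The sketch below computes it from predictions, labels, and group identifiers; the function name and the toy data are hypothetical, and group identifiers are needed only for evaluation, not for training DPE.

```python
from collections import defaultdict

def worst_group_accuracy(preds, labels, groups):
    """Return (worst per-group accuracy, dict of accuracy per group)."""
    correct = defaultdict(int)
    count = defaultdict(int)
    for p, y, g in zip(preds, labels, groups):
        count[g] += 1
        correct[g] += int(p == y)
    per_group = {g: correct[g] / count[g] for g in count}
    return min(per_group.values()), per_group

# Toy data with four hypothetical groups of two samples each.
preds  = [0, 0, 0, 1, 1, 1, 1, 0]
labels = [0, 0, 1, 1, 1, 1, 0, 0]
groups = ["a", "a", "b", "b", "c", "c", "d", "d"]
wga, per_group = worst_group_accuracy(preds, labels, groups)
overall = sum(int(p == y) for p, y in zip(preds, labels)) / len(labels)
```

In this toy case the overall accuracy is 0.75 while the worst-group accuracy is 0.5, which is why average accuracy alone can hide poor performance on a minority subpopulation.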

Highlights & Insights

  1. No Subpopulation Annotations Required: Unlike most competing methods, which require explicit subpopulation labels or a known number of subpopulations, DPE discovers the subpopulation structure automatically.
  2. Plug-and-Play: Simply replacing the final linear classification layer and freezing the feature extractor incurs minimal training cost.
  3. Visual Verification: On the Waterbirds dataset, different prototypes indeed capture different semantic subpopulations (e.g., "birds on land" vs. "birds in water"), verifying the effectiveness of the diversification strategy.
  4. Simple IPS Loss Design: Normalizing and summing the absolute values of inner products offers an elegant yet highly effective way to scatter prototypes.
  5. Broad Applicability: Covers vision + NLP, spurious correlation + attribute imbalance + attribute generalization, providing comprehensive coverage.

Limitations & Future Work

  1. Choice of Ensemble Size \(N\): The paper does not thoroughly discuss how to select the optimal \(N\), only showing it in parts of the ablation studies; tuning \(N\) is challenging when validation subpopulation labels are unavailable.
  2. Missing CheXpert Results: In Table 1, the CheXpert column for the ERM+DPE row is "-", with no explanation provided. This is unfortunate given that medical imaging is an important application scenario.
  3. Sequential Training: Ensemble members must be trained sequentially, preventing parallelization, so training time grows linearly as \(N\) increases.
  4. Only Replacing the Linear Layer: Freezing the feature extractor implies that if the ERM features themselves are of poor quality, the upper bound of DPE is capped.
  5. Choice of Distance Function: Only scaled Euclidean distance is used, without exploring more flexible metrics such as Mahalanobis distance.
  6. Insufficient Theoretical Analysis: The paper offers no theoretical account of whether, or under what conditions, a diversified prototypical ensemble covers all subpopulations.

Related Work

  • Kirichenko et al. (2022) & Izmailov et al. (2022): Showed that ERM features retain core discriminative information even under spurious correlations; DPE builds on this by freezing the features and retraining only the classification head.
  • Snell et al. (2017): Groundbreaking work on Prototypical Networks; DPE extends them from few-shot learning to the subpopulation robustness setting.
  • DivDis (Lee et al., 2022): Promotes ensemble diversity through disambiguation; DPE borrows the idea of explicit diversification.
  • D-BAT (Pagliardini et al., 2023): Ensemble learning with agreement on the source distribution and disagreement on OOD data; inspired DPE's diversification direction.
  • SubpopBench (Yang et al., 2023): A unified evaluation benchmark on which DPE was comprehensively validated.

Rating

  • Novelty: ⭐⭐⭐⭐ — Combining prototype classifiers with explicit/implicit diversification for subpopulation robustness is novel, though individual components are not entirely new.
  • Experimental Thoroughness: ⭐⭐⭐⭐ — 9 datasets, 4 shift types, and multiple baselines provide extensive coverage and robust results, but points are docked for the missing CheXpert results.
  • Writing Quality: ⭐⭐⭐⭐ — Well-written and clear, with intuitive motivational figures (Fig 1-2) and a rigorous method description.
  • Value: ⭐⭐⭐⭐ — Simple, practical, plug-and-play, and requiring no subpopulation annotations, offering high value for real-world deployment.