FEDTAIL: Federated Long-Tailed Domain Generalization with Sharpness-Guided Gradient Matching¶
Conference: ICML 2025
arXiv: 2506.08518
Code: https://github.com/sunnyinAI/FedTail
Area: Federated Learning / Domain Generalization
Keywords: Federated Learning, Domain Generalization, Long-Tailed, Sharpness-Aware Minimization, Gradient Coherence
TL;DR¶
FedTAIL is a federated domain generalization framework that addresses domain shift and long-tailed class imbalance together through three modules: gradient coherence regularization, class-wise sharpness-aware minimization, and curvature-aware dynamic weighting. It reports the best average accuracy among the compared methods on PACS, OfficeHome, Digits-DG, and mini-DomainNet.
Background & Motivation¶
Background: Domain Generalization (DG) aims to train models that can generalize to unseen target domains. Sharpness-Aware Minimization (SAM) improves generalization by seeking flat minima.
Limitations of Prior Work: Standard SAM applies a single global perturbation and ignores curvature differences across classes, which may cause tail classes to converge to saddle points in long-tailed scenarios. In addition, the gradients of the classification loss and the adversarial domain-alignment loss may conflict.
Key Challenge: In federated settings, client data is naturally non-i.i.d. and long-tailed, so models face domain shift and class imbalance at the same time.
Key Insight: Unifying gradient coordination, class-aware regularization, and conditional distribution alignment into a scalable framework.
Core Idea: Compute class-wise SAM perturbations \(\epsilon_c\) and dynamically weight them by the inverse of the maximum eigenvalue of the class Hessian.
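As a minimal sketch of the class-wise perturbation, the snippet below computes \(\epsilon_c = \rho \cdot g_c / \|g_c\|_2\) for each class from a per-class gradient \(g_c\), following the formula given in the Method section. The gradient values, the class names, the radius `rho`, and the helper name `sam_perturbation` are illustrative assumptions, not taken from the paper or its code; returning a zero perturbation for a zero gradient is also an assumption.

```python
import math

def sam_perturbation(grad, rho):
    """Return rho * grad / ||grad||_2 for one class's gradient vector."""
    norm = math.sqrt(sum(g * g for g in grad))
    if norm == 0.0:
        # Assumption: no perturbation when the class gradient vanishes.
        return [0.0 for _ in grad]
    return [rho * g / norm for g in grad]

# Hypothetical per-class gradients w.r.t. two shared parameters.
class_grads = {
    "head": [3.0, 4.0],
    "tail": [0.0, -0.5],
}
rho = 0.05

# One perturbation per class instead of a single global one.
class_eps = {c: sam_perturbation(g, rho) for c, g in class_grads.items()}

# Every perturbation has the same length rho but its own direction.
eps_norms = {c: math.sqrt(sum(e * e for e in eps)) for c, eps in class_eps.items()}
```

Each class is perturbed along its own gradient direction, so a tail class with a small gradient is still moved by the full radius `rho`.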
Method¶
Overall Architecture¶
Feature extractor \(F_\theta\) + classifier \(T_\phi\) + domain discriminator \(D_\psi\), trained across multiple clients in a federated manner. The total loss is: \(\mathcal{L}_{\text{FedTAIL}} = \mathcal{L}_{\text{cls}} + \mathcal{L}_{\text{adv}} + \mathcal{L}_{\text{sharp-er}} + \sum_c \gamma_c \mathcal{L}_c + \mathcal{L}_{\text{coh}}\), where \(\gamma_c\) is the curvature-aware weight on the class-\(c\) loss \(\mathcal{L}_c\).
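To make the weighted term \(\sum_c \gamma_c \mathcal{L}_c\) concrete, the following sketch assembles the total loss from scalar component values, using the weight form \(\gamma_c = 1/(1 + \sigma_{\max})\) given under Key Designs. All numeric values and the names `curvature_weight` and `fedtail_total` are hypothetical and chosen for illustration only.

```python
def curvature_weight(sigma_max):
    """gamma_c = 1 / (1 + sigma_max); assumes sigma_max >= 0."""
    return 1.0 / (1.0 + sigma_max)

def fedtail_total(l_cls, l_adv, l_sharp_er, class_losses, class_sigmas, l_coh):
    """Sum the five terms of the total loss for scalar inputs."""
    weighted = sum(
        curvature_weight(class_sigmas[c]) * class_losses[c]
        for c in class_losses
    )
    return l_cls + l_adv + l_sharp_er + weighted + l_coh

# Hypothetical per-class losses and top Hessian eigenvalues.
class_losses = {"head": 0.4, "tail": 1.2}
class_sigmas = {"head": 1.0, "tail": 3.0}

gammas = {c: curvature_weight(s) for c, s in class_sigmas.items()}
total = fedtail_total(0.9, 0.3, 0.1, class_losses, class_sigmas, -0.05)
```

With these inputs the weights are 0.5 and 0.25, so the class with the larger eigenvalue receives the smaller weight under this formula.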
Key Designs¶
- Gradient Coherence Regularization:
    - Function: Alleviate gradient conflicts between classification and adversarial domain alignment.
    - Mechanism: \(\mathcal{L}_{\text{coh}} = -\alpha \langle \nabla_\theta \mathcal{L}_{\text{cls}}, \nabla_\theta \mathcal{L}_{\text{adv}} \rangle\). Minimizing this term rewards a positive inner product between the two gradients and penalizes a negative one, i.e., conflicting update directions.
    - Design Motivation: Ensure that domain alignment does not compromise classification performance.
- Class-wise Sharpness-Aware Minimization (Class-wise SAM):
    - Function: Compute SAM perturbations separately for each class.
    - Mechanism: \(\epsilon_c = \rho \cdot \nabla_\theta \mathcal{L}_c / \|\nabla_\theta \mathcal{L}_c\|_2\), then \(\mathcal{L}_{\text{sharp}} = \sum_c \mathbb{E}_{(x,y=c)}[\ell(h_{\theta+\epsilon_c}(x), y)]\).
    - Curvature-aware weights: \(\gamma_c = 1/(1 + \sigma_{\max}(\nabla^2 \mathcal{L}_c))\), where \(\sigma_{\max}\) is the largest eigenvalue of the class-\(c\) Hessian. As written, \(\gamma_c\) decreases as curvature grows, matching the inverse-eigenvalue weighting stated in the Core Idea. High curvature is associated with the low-frequency tail classes, so \(\sigma_{\max}\) serves as the per-class signal for the weighting.
    - Design Motivation: Global SAM fails to capture differences between classes, and tail classes require more attention.
- Sharpness-Aware Conditional Distribution Alignment (Sharpness-Aware ER):
    - Function: Inject SAM perturbations into entropy regularization.
    - Mechanism: \(\mathcal{L}_{\text{sharp-er}} = \sum_i \text{KL}(P_i(Y|F(X)) \| Q_T(Y|F(X+\epsilon)))\).
    - Design Motivation: Traditional entropy regularization amplifies the gradients of easily transferable samples while neglecting hard samples.
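The gradient coherence term from the first design can be checked numerically with a small sketch. It evaluates \(-\alpha \langle g_{\text{cls}}, g_{\text{adv}} \rangle\) for hand-picked gradient vectors; the vectors, the value of `alpha`, and the helper name `coherence_loss` are assumptions for illustration, not the authors' implementation.

```python
def coherence_loss(g_cls, g_adv, alpha):
    """Return -alpha * <g_cls, g_adv> for two flat gradient vectors."""
    if len(g_cls) != len(g_adv):
        raise ValueError("gradients must have the same length")
    inner = sum(a * b for a, b in zip(g_cls, g_adv))
    return -alpha * inner

alpha = 0.5

# Opposed gradients: the term is positive, so minimizing it pushes back.
conflicting = coherence_loss([1.0, 0.0], [-1.0, 0.0], alpha)

# Aligned gradients: the term is negative, so alignment is rewarded.
aligned = coherence_loss([1.0, 0.0], [1.0, 0.0], alpha)

# Orthogonal gradients: the term contributes nothing.
orthogonal = coherence_loss([1.0, 0.0], [0.0, 1.0], alpha)
```

Note that the term uses the raw inner product rather than a cosine, so its magnitude also scales with the gradient norms.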
Loss & Training¶
Each client locally computes gradients, class-wise perturbations, and sharpness-aware updates; the server then aggregates the client updates with Federated Averaging (FedAvg).
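The aggregation step can be sketched as a weighted average of client parameter vectors. Weighting by local sample count is the standard FedAvg choice and is assumed here; the summary does not state which weighting FedTAIL uses. The function name `fedavg` and all values are hypothetical.

```python
def fedavg(client_params, client_sizes):
    """Average flat parameter vectors, weighted by client sample counts."""
    if len(client_params) != len(client_sizes):
        raise ValueError("one size per client is required")
    total = sum(client_sizes)
    dim = len(client_params[0])
    return [
        sum(n * p[i] for p, n in zip(client_params, client_sizes)) / total
        for i in range(dim)
    ]

# Two hypothetical clients; the second holds three times as much data.
client_params = [[1.0, 2.0], [3.0, 4.0]]
client_sizes = [1, 3]

global_params = fedavg(client_params, client_sizes)
```

In a full round, each entry of `client_params` would be the result of that client's local sharpness-aware training.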
Key Experimental Results¶
Main Results¶
| Dataset | Metric | FedTAIL | Prev. SOTA | Gain (pp) |
|---|---|---|---|---|
| PACS | Avg Acc | 88.9% | 87.6% (SAMALTDG) | +1.3 |
| OfficeHome | Avg Acc | 71.4% | 69.8% | +1.6 |
| Digits-DG | Avg Acc | 88.5% | 86.9% | +1.6 |
| mini-DomainNet | Avg Acc | 73.2% | 71.5% | +1.7 |
Ablation Study¶
| Configuration | PACS Avg | Description |
|---|---|---|
| Full FedTAIL | 88.9% | Complete model |
| w/o Gradient Coherence | 87.1% | Removes gradient coherence regularization; drop of 1.8 pp |
| w/o Class-wise SAM | 87.5% | Removes class-wise SAM; drop of 1.4 pp |
| w/o Curvature Weighting | 88.0% | Removes curvature-aware weighting; drop of 0.9 pp |
Key Findings¶
- Removing gradient coherence regularization causes the largest drop (1.8 points on PACS), suggesting that the classification-adversarial gradient conflict is the primary bottleneck.
- Class-wise SAM yields more pronounced effects under severe long-tailed imbalance.
- The approach is effective in both federated and centralized settings.
Highlights & Insights¶
- Connects the gradient flow analysis of entropy regularization with the long-tailed distribution problem, revealing the mechanism where high-confidence samples dominate gradients.
- The curvature-aware weight \(\gamma_c\) utilizes the maximum eigenvalue of the Hessian to automatically identify undertrained tail classes.
- The modules of the framework are decoupled and can be flexibly combined.
Limitations & Future Work¶
- Computing the maximum Hessian eigenvalue for each class is costly, and the efficiency overhead is not discussed in detail.
- Experiments are mainly conducted on medium and small-scale datasets; large-scale federated scenarios remain to be validated.
- It assumes that all clients share the same model architecture, leaving heterogeneous scenarios unaddressed.
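On the cost of the Hessian eigenvalue: a common way to avoid forming the Hessian is power iteration on Hessian-vector products, sketched below with finite differences of the gradient. This is a generic technique and not necessarily what the paper uses; the toy quadratic loss, the step size `h`, the iteration count, and the names `hvp` and `top_eigenvalue` are all assumptions.

```python
import math

def hvp(grad_fn, theta, v, h=1e-4):
    """Approximate H @ v by a central difference of gradients."""
    plus = grad_fn([t + h * vi for t, vi in zip(theta, v)])
    minus = grad_fn([t - h * vi for t, vi in zip(theta, v)])
    return [(p - m) / (2.0 * h) for p, m in zip(plus, minus)]

def top_eigenvalue(grad_fn, theta, iters=50):
    """Power iteration; finds the eigenvalue of largest magnitude."""
    v = [1.0 / math.sqrt(len(theta))] * len(theta)  # unit start vector
    lam = 0.0
    for _ in range(iters):
        hv = hvp(grad_fn, theta, v)  # two gradient evaluations
        norm = math.sqrt(sum(x * x for x in hv))
        if norm == 0.0:
            return 0.0
        lam = sum(a * b for a, b in zip(v, hv))  # Rayleigh quotient
        v = [x / norm for x in hv]
    return lam

def toy_class_grad(theta):
    """Gradient of the toy loss 0.5 * (4 * x0**2 + x1**2)."""
    return [4.0 * theta[0], theta[1]]

sigma_max = top_eigenvalue(toy_class_grad, [0.3, -0.2])
gamma = 1.0 / (1.0 + sigma_max)
```

Each iteration costs two gradient evaluations per class, so the total overhead grows with both the number of classes and the number of iterations. For the toy loss the Hessian is diag(4, 1), so the estimate converges to 4 and the weight to 0.2.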
Rating¶
- Novelty: ⭐⭐⭐⭐ Innovation lies in the multi-module combination, though individual modules are incremental.
- Experimental Thoroughness: ⭐⭐⭐⭐ Comprehensive benchmarks and detailed ablation studies.
- Writing Quality: ⭐⭐⭐⭐ Clear description of the methodology.
- Value: ⭐⭐⭐⭐ High practical value at the intersection of federated learning, long-tailed recognition, and DG.