SAMO: A Lightweight Sharpness-Aware Approach for Multi-Task Optimization with Joint Global-Local Perturbation
Conference: ICCV 2025 · arXiv: 2507.07883 · Code: GitHub · Area: Model Compression
Keywords: Multi-task learning, sharpness-aware minimization, gradient conflict, zeroth-order gradient estimation, layer-wise normalization
TL;DR
This paper proposes SAMO, a lightweight sharpness-aware multi-task optimization method that mitigates task gradient conflicts via joint global-local perturbation, while substantially reducing computational overhead through zeroth-order gradient approximation and layer-wise normalization.
Background & Motivation
Multi-task learning (MTL) aims to train a single model to simultaneously learn multiple tasks, leveraging shared knowledge to improve data efficiency and generalization. However, the central challenge in MTL is task conflict: gradients from different tasks may contradict each other in direction or magnitude, causing naive gradient averaging to be dominated by a single task and thereby degrading overall performance.
Existing approaches fall into two broad categories: (1) gradient manipulation methods (e.g., MGDA, CAGrad, FairGrad), which adjust gradient directions or weights to find a compromise update direction; and (2) architectural design methods (e.g., MoE, soft parameter sharing), which reduce conflict through model structure. Most of these methods, however, overlook the role of loss landscape geometry.
Sharpness-Aware Minimization (SAM) is widely applied in single-task settings, improving generalization by simultaneously minimizing the loss value and the "sharpness" of the loss landscape. Through empirical analysis, the authors identify a key insight: SAM effectively alleviates task conflicts in MTL. Specifically:
- SAM drives the model toward flatter regions, where changes in one task objective do not significantly affect others.
- SAM substantially increases the cosine similarity between task gradients (from negative to positive values).
- Sharpness metrics \(\lambda_{\max}\) and \(\lambda_{\max}/\lambda_5\) of the loss landscape both decrease markedly.
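These two diagnostics are easy to reproduce on a toy problem. The sketch below (hypothetical values, not the paper's measurement code) computes the cosine similarity between two conflicting task gradients and estimates \(\lambda_{\max}\) of a Hessian by power iteration:

```python
import numpy as np

def cosine_similarity(g1, g2):
    # Cosine of the angle between two task gradients; negative values indicate conflict.
    return float(g1 @ g2 / (np.linalg.norm(g1) * np.linalg.norm(g2)))

def lambda_max(hessian, iters=100, seed=0):
    # Power iteration: repeatedly apply H to a random vector to extract the top eigenvalue.
    rng = np.random.default_rng(seed)
    v = rng.standard_normal(hessian.shape[0])
    for _ in range(iters):
        v = hessian @ v
        v /= np.linalg.norm(v)
    return float(v @ hessian @ v)  # Rayleigh quotient of the converged direction

# Two conflicting task gradients and a simple diagonal Hessian (toy data).
g1 = np.array([1.0, 0.5])
g2 = np.array([-1.0, 0.5])
H = np.diag([3.0, 1.0])
print(cosine_similarity(g1, g2))  # negative: the tasks conflict
print(lambda_max(H))              # ~3.0, the top-eigenvalue sharpness proxy
```

On real networks, \(H \cdot v\) would be a Hessian-vector product computed by automatic differentiation rather than an explicit matrix.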
Nevertheless, integrating SAM into MTL faces two major challenges: (1) both global information (averaged gradients) and local information (per-task gradients) are beneficial for SAM, yet how to effectively combine them remains unclear—G-SAM and L-SAM each exhibit advantages depending on the method and dataset; (2) computing per-task gradients for local perturbation requires \(K\) additional backward passes, incurring substantial computational cost. The sole prior work, F-MTL, applies SAM to each individual task but introduces a \(K\)-fold increase in backward passes and doubles memory requirements.
Method
Overall Architecture
SAMO augments standard MTL gradient manipulation methods with a lightweight sharpness-aware module. The core idea is to first compute a joint global-local perturbation, evaluate gradients at the perturbed parameters, and then feed the perturbed gradients into any existing gradient manipulation method (e.g., FairGrad). The overall pipeline proceeds as: compute averaged loss gradient → approximate per-task local gradients → construct joint global-local perturbation → compute new gradients at perturbed parameters → pass into MTL method to obtain the update direction.
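The pipeline can be sketched end to end on a toy two-task quadratic problem. Everything below is a hypothetical illustration: exact per-task gradients stand in for the SPSA estimates, and plain averaging stands in for a gradient-manipulation method such as FairGrad:

```python
import numpy as np

# Toy two-task problem: L_i(theta) = 0.5 * ||theta - c_i||^2, minimized at c_i.
centers = [np.array([1.0, 0.0]), np.array([0.0, 1.0])]

def task_grad(theta, i):
    return theta - centers[i]

def samo_step(theta, rho=0.05, alpha=0.5, lr=0.1):
    K = len(centers)
    # 1. Gradient of the averaged loss (global information).
    g_global = np.mean([task_grad(theta, i) for i in range(K)], axis=0)
    perturbed_grads = []
    for i in range(K):
        # 2. Per-task local gradient (exact here; SPSA-estimated in SAMO).
        g_local = task_grad(theta, i)
        # 3. Joint global-local perturbation, rescaled to radius rho.
        mix = alpha * g_global + (1 - alpha) * g_local
        eps = rho * mix / (np.linalg.norm(mix) + 1e-12)
        # 4. Gradient at the perturbed point, to be fed to the MTL aggregator.
        perturbed_grads.append(task_grad(theta + eps, i))
    # 5. Plain averaging stands in for FairGrad / MGDA / CAGrad.
    return theta - lr * np.mean(perturbed_grads, axis=0)

theta = np.zeros(2)
for _ in range(50):
    theta = samo_step(theta)
print(theta)  # converges toward the compromise point [0.5, 0.5]
```

Swapping step 5 for an actual gradient-manipulation method recovers the plug-and-play usage described in the paper.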
Key Designs
- Joint Global-Local Perturbation: SAMO defines the perturbation for each task \(i\) as a weighted average of the global gradient \(g\) and the local gradient \(g_i\), rescaled to the perturbation radius \(\rho\):

\[
\epsilon_i = \rho \,\frac{\alpha g + (1-\alpha)\, g_i}{\left\|\alpha g + (1-\alpha)\, g_i\right\|},
\]
where \(\alpha \in [0,1]\) balances global and local information. The global perturbation captures positive transfer across tasks, while the local perturbation retains task-specific characteristics. Each task's perturbation direction thus accounts for both shared patterns and its own loss landscape. The design motivation stems from the empirical observation that G-SAM and L-SAM each have advantages, and their joint use achieves a better balance.
- Zeroth-Order Gradient Approximation (SPSA): To avoid \(K\) additional backward passes (one per task), SAMO employs Simultaneous Perturbation Stochastic Approximation (SPSA), which estimates each local gradient using only forward passes:

\[
\hat{g}_i = \frac{L_i(\theta + \mu z_i) - L_i(\theta - \mu z_i)}{2\mu}\, z_i,
\]
where \(z_i\) is a random vector sampled from a standard Gaussian distribution and \(\mu\) is a small perturbation factor. This approximation requires only \(2K\) forward passes (in place of \(K\) backward passes), and the forward pass cost \(C_f\) is substantially lower than the backward pass cost \(C_b\). The design is inspired by parameter-efficient fine-tuning (PEFT): the global perturbation serves as the "backbone" and the local perturbation acts as the "adapter."
- Layer-wise Normalization Strategy: Zeroth-order gradient estimates can have high variance, which destabilizes training. SAMO therefore normalizes layer by layer, rescaling the estimated local gradient at each layer to match the magnitude of the global gradient:

\[
\hat{g}_i(\theta^d) \leftarrow \frac{\left\|g(\theta^d)\right\|}{\left\|\hat{g}_i(\theta^d)\right\|}\, \hat{g}_i(\theta^d),
\]
where \(\theta^d\) denotes the parameters of layer \(d\). This preserves the directional information of the zeroth-order estimate while aligning its magnitude with that of the exact gradient, preventing optimization instability caused by variance fluctuations.
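The joint global-local perturbation reduces to a few lines; below is a minimal numpy sketch with toy gradient vectors (hypothetical values, not the authors' code):

```python
import numpy as np

def joint_perturbation(g_global, g_local, rho=0.05, alpha=0.5):
    # Mix global and local directions, then rescale to the perturbation radius rho.
    mix = alpha * g_global + (1 - alpha) * g_local
    return rho * mix / (np.linalg.norm(mix) + 1e-12)

g = np.array([1.0, 0.0])    # gradient of the averaged loss (global)
g_i = np.array([0.0, 1.0])  # gradient of task i's loss (local)
eps = joint_perturbation(g, g_i)
print(np.linalg.norm(eps))  # ~0.05: the perturbation lies on the radius-rho sphere
```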
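The SPSA estimator can likewise be sketched with a toy quadratic loss (a hypothetical illustration; a single-sample estimate is noisy, which is exactly what the normalization step is designed to absorb):

```python
import numpy as np

def spsa_grad(loss_fn, theta, mu=0.01, seed=0):
    # Two forward evaluations: (L(theta + mu*z) - L(theta - mu*z)) / (2*mu) * z.
    rng = np.random.default_rng(seed)
    z = rng.standard_normal(theta.shape)
    delta = (loss_fn(theta + mu * z) - loss_fn(theta - mu * z)) / (2 * mu)
    return delta * z

# Toy quadratic loss whose exact gradient is theta itself.
loss = lambda th: 0.5 * float(th @ th)
theta = np.array([1.0, -2.0, 3.0])
g_hat = spsa_grad(loss, theta)
# A single draw is noisy but unbiased in expectation;
# averaging over many z draws recovers the exact gradient theta.
```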
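The layer-wise normalization step, sketched on a hypothetical two-layer gradient (toy values, not the authors' code):

```python
import numpy as np

def layerwise_normalize(g_hat_layers, g_layers):
    # Rescale each layer's zeroth-order estimate to the exact global gradient's norm,
    # keeping the estimated direction but discarding its high-variance magnitude.
    out = []
    for g_hat, g in zip(g_hat_layers, g_layers):
        scale = np.linalg.norm(g) / (np.linalg.norm(g_hat) + 1e-12)
        out.append(scale * g_hat)
    return out

# Hypothetical two-layer model: noisy SPSA estimates vs. exact global gradients.
g_hat = [np.array([3.0, 4.0]), np.array([0.1, 0.0])]
g     = [np.array([1.0, 0.0]), np.array([0.0, 2.0])]
normed = layerwise_normalize(g_hat, g)
print([np.linalg.norm(v) for v in normed])  # ~[1.0, 2.0], matching each ||g^d||
```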
Loss & Training
SAMO introduces no additional loss functions; instead, it operates as a plug-and-play module that can be integrated into any existing MTL method, reusing the original training configuration entirely. Compared to F-MTL, SAMO's additional computational overhead is only \(C_b + 2KC_f\) (one backward pass plus \(2K\) forward passes), whereas F-MTL requires \(KC_b + C_{gm}\) (\(K\) backward passes plus the cost of gradient manipulation). Since \(C_f \ll C_b\), SAMO's runtime overhead is comparable to using global SAM alone (G-SAM).
For hyperparameters: the perturbation radius \(\rho\) is searched over \(\{0.01, 0.05, 0.1, 0.5\}\); the global-local weight \(\alpha\) is selected from \(\{0.3, 0.5, 0.7\}\); and the zeroth-order perturbation factor \(\mu\) is set to \(0.01\).
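The reported search space is small enough to enumerate exhaustively; a hypothetical sweep skeleton:

```python
from itertools import product

# Grids as reported in the paper; mu is fixed rather than searched.
rho_grid = [0.01, 0.05, 0.1, 0.5]
alpha_grid = [0.3, 0.5, 0.7]
mu = 0.01

configs = [{"rho": r, "alpha": a, "mu": mu} for r, a in product(rho_grid, alpha_grid)]
print(len(configs))  # 12 candidate (rho, alpha) settings per dataset
```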
Key Experimental Results
Main Results
Cityscapes (2 tasks: semantic segmentation + depth estimation)
| Method | mIoU ↑ | Pix Acc ↑ | Abs Err ↓ | Rel Err ↓ | Δm% ↓ |
|---|---|---|---|---|---|
| STL | 74.01 | 93.16 | 0.0125 | 27.77 | — |
| FairGrad | 74.10 | 93.03 | 0.0135 | 29.92 | 3.90 |
| F-MTL (best) | 73.77 | 93.12 | 0.0129 | 27.44 | 0.67 |
| SAMO-FairGrad | 74.37 | 93.14 | 0.0129 | 26.30 | -0.62 |
NYU-v2 (3 tasks: segmentation + depth + surface normals)
| Method | mIoU ↑ | Abs Err ↓ | Angle Dist ↓ | Δm% ↓ |
|---|---|---|---|---|
| FairGrad | 38.80 | 0.5572 | 24.55 | -4.96 |
| F-MTL (best) | 40.42 | 0.5389 | 25.03 | -4.77 |
| SAMO-FairGrad | 39.05 | 0.5359 | 24.43 | -6.55 |
Ablation Study
| Configuration | Cityscapes Δm% | NYU-v2 Δm% | CelebA Δm% | Note |
|---|---|---|---|---|
| G-SAM-MGDA | 7.51 | -0.23 | 11.78 | Global information only |
| L-SAM-MGDA | 11.94 | 0.01 | 8.47 | Local information only |
| SAMO-MGDA | 4.30 | -2.19 | 9.59 | Joint global-local (Ours) |
| G-SAM-FairGrad | 0.93 | -5.70 | 0.41 | Global information only |
| L-SAM-FairGrad | 1.01 | -5.42 | -0.42 | Local information only |
| SAMO-FairGrad | -0.62 | -6.55 | -0.74 | Joint global-local (Ours) |
Key Findings
- SAMO consistently improves all three baseline methods (LS, MGDA, FairGrad) across all datasets.
- The joint global-local perturbation consistently outperforms using either global or local perturbation alone.
- Consistent improvements are maintained on CelebA (40 tasks) and QM9 (11 tasks), demonstrating scalability.
- SAMO remains effective in the multi-input setting of Office-Home, confirming that the method is not restricted to shared-input configurations.
- Runtime comparison: SAMO incurs only 2–6% additional time over baselines on Cityscapes/NYU-v2, far below the 80%+ overhead of F-MTL.
Highlights & Insights
- The mechanistic discovery that SAM mitigates task conflicts is a significant contribution: through Hessian spectral analysis and cosine-similarity visualization, the paper quantitatively shows that SAM's pull toward flat regions goes hand in hand with reduced task conflict.
- The combination of zeroth-order gradient approximation and layer-wise normalization is elegant—directional information is preserved while variance is controlled.
- The plug-and-play nature of the method enables integration with any existing gradient manipulation approach, offering strong practical value.
- The visualization analysis on a two-dimensional synthetic problem intuitively illustrates how SAM alters optimization trajectories.
Limitations & Future Work
- Zeroth-order gradient estimation is inherently approximate, with limited directional accuracy in high-dimensional spaces; single-sample estimation may introduce substantial noise.
- \(\alpha\) is treated as a fixed hyperparameter requiring manual tuning; adaptive adjustment strategies merit exploration.
- Experiments do not cover large-scale pretrained models (e.g., ViT-Large); the effectiveness of zeroth-order estimation at larger parameter scales requires further validation.
- Combinations with advanced SAM variants (e.g., ASAM, LookSAM) remain unexplored.
Related Work & Insights
- F-MTL (Phan et al.) is the most closely related prior work, applying SAM directly to each individual task, but at prohibitive computational cost.
- SPSA gradient estimation originates from the stochastic optimization literature and is applied here for the first time to approximate local perturbations in an MTL setting.
- Insight: Zeroth-order optimization has found applications beyond MTL (e.g., prompt tuning), and SAMO's methodology may generalize to other scenarios requiring multi-objective optimization.
Rating
- Novelty: ⭐⭐⭐⭐ The joint global-local perturbation paradigm for integrating SAM into MTL is novel, though the zeroth-order approximation itself is an existing technique.
- Experimental Thoroughness: ⭐⭐⭐⭐⭐ Five datasets covering segmentation, classification, regression, and multi-input scenarios, with comprehensive ablation studies.
- Writing Quality: ⭐⭐⭐⭐ Clear logical flow, thorough analysis, and rich figures and tables.
- Value: ⭐⭐⭐⭐ Strong practical utility as a plug-and-play module; the lightweight design offers considerable engineering value.