
On the Surprising Effectiveness of Large Learning Rates under Standard Width Scaling

Conference: NeurIPS 2025 arXiv: 2505.22491 Authors: Moritz Haas, Sebastian Bordt, Ulrike von Luxburg, Leena Chennuru Vankadara (Tübingen, UCL Gatsby) Area: Other Keywords: infinite-width limit, standard parameterization, learning rate scaling, cross-entropy loss, feature learning, controlled divergence

TL;DR

This paper reveals that under standard parameterization (SP), the cross-entropy loss causes the previously monolithic "unstable" regime to split into two distinct sub-regimes: catastrophic instability and controlled divergence. In the controlled divergence regime (\(\eta_n = \Theta(n^{-1/2})\)), logits diverge while gradients and activations remain stable, thereby establishing the first practically useful infinite-width limit for SP that admits feature learning.

Background & Motivation

Root Cause

Infinite-width theory (Tensor Programs, etc.) predicts that under standard parameterization (He initialization + global learning rate), learning rates larger than \(\mathcal{O}(1/n)\) lead to training instability, while \(\mathcal{O}(1/n)\) rates cause feature learning to vanish (degenerating to the kernel regime). In practice, however:

  • The optimal learning rate decays far more slowly than \(\mathcal{O}(1/n)\), typically at \(\Omega(1/\sqrt{n})\);
  • Networks remain stably trainable at large widths and learn meaningful features;
  • The gap is already pronounced in a single gradient step of a 2-layer MLP, ruling out depth accumulation and multi-step finite-width effects as primary explanations.

Limitations of Prior Work

Finite-width effects: The gap appears already in a single gradient step of a 2-layer MLP (Figure F.15), ruling out accumulation over depth or training time.

Catapult mechanism: Analysis of linear models under SP shows that catapult dynamics require \(\eta_n = \mathcal{O}(n^{-1})\) to prevent sharpness divergence, and thus cannot explain stability at larger learning rates.

Misalignment hypothesis: Everett et al. (2024) hypothesize that the absence of alignment between weights and activations accounts for the discrepancy; however, the Refined Coordinate Check (RCC) proposed in this work confirms that the infinite-width alignment predictions hold already at moderate widths.

Core Problem

Why does SP remain stable and effective under large learning rates? Does there exist an infinite-width limit that more faithfully captures the behavior of finite-width networks in practice?

Method

Theoretical Framework: Controlled Divergence Regime

Consider an \((L+1)\)-layer MLP of width \(n\) under SP initialization, trained with SGD at global learning rate \(\eta_n = \eta \cdot n^{-\alpha}\).

Key observation: The choice of loss function determines the consequences of logit divergence.

  • MSE loss: The loss-logit gradient \(\chi_t = f_t(\xi_t) - y_t\) diverges directly when logits diverge, causing catastrophic instability.
  • CE loss: The loss-logit gradient \(\chi_t = \sigma(f_t(\xi_t)) - y_t\) is bounded, because the softmax \(\sigma\) maps logits to probabilities in \([0,1]\); logit divergence merely drives \(\sigma(f)\) toward a one-hot prediction, leaving gradients bounded.
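This boundedness can be checked numerically. The sketch below (a standalone NumPy illustration, not the paper's code) scales a fixed logit vector, mimicking logit divergence with width, and compares the two loss-logit gradients:

```python
import numpy as np

# Compare the loss-logit gradient chi as logits diverge.
# Under MSE, chi = f - y grows with the logits; under CE,
# chi = softmax(f) - y stays in [-1, 1] however large the logits get.

def softmax(f):
    e = np.exp(f - f.max())          # numerically stabilized softmax
    return e / e.sum()

y = np.array([1.0, 0.0, 0.0])        # one-hot target
for scale in [1, 10, 100, 1000]:
    f = scale * np.array([0.5, -0.2, -0.3])   # logits blowing up
    chi_mse = f - y                           # MSE loss-logit gradient
    chi_ce = softmax(f) - y                   # CE loss-logit gradient
    print(f"scale={scale}: |chi_mse|={np.abs(chi_mse).max():.1f}, "
          f"|chi_ce|={np.abs(chi_ce).max():.3f}")
```

The MSE gradient grows linearly with the logit scale, while the CE gradient remains in \([-1,1]\) and merely saturates toward a one-hot prediction.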

Proposition 2: Three Asymptotic Regimes under SP (CE Loss)

| Regime | Exponent \(\alpha\) | Logits \(\|f_t\|_{RMS}\) | Activations \(\|x_t^l\|_{RMS}\) | Gradients \(\|\chi_t\|_{RMS}\) |
|---|---|---|---|---|
| Stable | \(\alpha \geq 1\) | \(\mathcal{O}(1)\) | \(\Theta(1)\) | \(\mathcal{O}(1)\) |
| Controlled divergence | \(\frac{1}{2} \leq \alpha < 1\) | \(\Theta(n^{1-\alpha}) \to \infty\) | \(\Theta(1)\) | \(\mathcal{O}(1)\) |
| Catastrophic instability | \(\alpha < \frac{1}{2}\) | \(\to \infty\) | \(\to \infty\) | \(\to \infty\) |

Under MSE loss, any \(\alpha < 1\) leads directly to catastrophic instability; no controlled divergence regime exists.
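The regime boundaries for SP with CE loss can be summarized in a tiny helper (illustrative naming, not from the paper's code):

```python
# Classify the learning rate exponent alpha (eta_n = eta * n^{-alpha})
# into the three regimes of Proposition 2 (SP with CE loss).

def sp_ce_regime(alpha: float) -> str:
    if alpha >= 1.0:
        return "stable"                 # logits, activations, gradients bounded
    if alpha >= 0.5:
        return "controlled divergence"  # logits diverge, rest stays bounded
    return "catastrophic instability"   # logits, activations, gradients diverge
```

For example, `sp_ce_regime(0.5)` returns `"controlled divergence"`, the boundary case with width-independent hidden-layer feature learning.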

Proposition 4: Feature Learning in the Controlled Divergence Regime

Under CE loss, SP, and \(\eta_n = \eta \cdot n^{-\alpha}\) with \(\frac{1}{2} \leq \alpha < 1\):

  • Feature learning vanishes in the input layer: \(\|\Delta x_t^1\|_{RMS} = \Theta(n^{-1/2 - \alpha})\);
  • Feature learning in hidden layers scales as \(\|\Delta x_t^l\|_{RMS} = \Theta(n^{1/2 - \alpha})\) for \(l \in [2, L]\);
  • In particular, at \(\alpha = 1/2\), feature learning in every hidden layer is width-independent: \(\|\Delta x_t^l\|_{RMS} = \Theta(1)\).

This constitutes the first infinite-width limit for SP in a practically useful feature learning regime.
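Proposition 4's scaling can be probed empirically in one SGD step from initialization. The sketch below (a simplified NumPy experiment, not the paper's code) uses a 3-layer ReLU MLP under SP with CE loss at \(\alpha = 1/2\) and reports the RMS change of the input-layer and hidden-layer activations:

```python
import numpy as np

def one_step_dx(n, d=32, k=10, eta=0.5, alpha=0.5, rng=None):
    """One SGD step on a 3-layer ReLU MLP under SP (He init) with CE loss;
    return the RMS change of the first and second hidden activations."""
    rng = rng or np.random.default_rng(0)
    xi = rng.standard_normal(d)                       # one input sample
    y = np.zeros(k); y[0] = 1.0                       # one-hot target
    W1 = rng.standard_normal((n, d)) * np.sqrt(2.0 / d)
    W2 = rng.standard_normal((n, n)) * np.sqrt(2.0 / n)
    W3 = rng.standard_normal((k, n)) * np.sqrt(2.0 / n)

    # forward pass
    h1 = W1 @ xi; x1 = np.maximum(h1, 0)
    h2 = W2 @ x1; x2 = np.maximum(h2, 0)
    f = W3 @ x2
    p = np.exp(f - f.max()); p /= p.sum()
    chi = p - y                                       # CE loss-logit gradient

    # backward pass (manual backprop through the ReLUs)
    g_h2 = (W3.T @ chi) * (h2 > 0)
    g_h1 = (W2.T @ g_h2) * (h1 > 0)

    eta_n = eta * n ** (-alpha)                       # global LR eta * n^{-alpha}
    W1n = W1 - eta_n * np.outer(g_h1, xi)
    W2n = W2 - eta_n * np.outer(g_h2, x1)

    # recompute activations after the step
    x1n = np.maximum(W1n @ xi, 0)
    x2n = np.maximum(W2n @ x1n, 0)
    rms = lambda v: np.sqrt(np.mean(v ** 2))
    return rms(x1n - x1), rms(x2n - x2)

for n in [256, 1024, 4096]:
    dx1, dx2 = one_step_dx(n, rng=np.random.default_rng(3))
    print(f"n={n}: input-layer dx={dx1:.5f}, hidden-layer dx={dx2:.4f}")
```

Across widths, the hidden-layer change stays roughly constant while the input-layer change shrinks like \(n^{-1}\), matching \(\Theta(n^{1/2-\alpha})\) and \(\Theta(n^{-1/2-\alpha})\) at \(\alpha = 1/2\).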

Maximum Stable Learning Rate Exponents Across Optimizers and Architectures

The maximum stable learning rate exponent differs by layer type:

  • Output layer (logits): \(\eta_n = \mathcal{O}(n^{-1})\) (though CE loss permits exceeding this);
  • Hidden layers: \(\eta_n = \mathcal{O}(n^{-1/2})\);
  • Input layer / LayerNorm / Embedding: \(\eta_n = \mathcal{O}(1)\).
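These per-layer exponents can be packaged into a small helper for setting layer-wise SGD learning rates (the names and API here are illustrative, not prescribed by the paper):

```python
# Per-layer maximum stable exponents under SP with SGD, as quoted above:
# alpha = 1 for the output layer, 1/2 for hidden layers, 0 for
# input / LayerNorm / embedding parameters.

LAYER_ALPHA = {"output": 1.0, "hidden": 0.5, "input": 0.0,
               "layernorm": 0.0, "embedding": 0.0}

def layerwise_lr(base_lr: float, width: int, layer_type: str) -> float:
    """Largest width-stable learning rate eta * n^{-alpha} for a layer."""
    return base_lr * width ** (-LAYER_ALPHA[layer_type])
```

For example, at width 1024 a hidden layer gets `base_lr / 32` while the output layer gets `base_lr / 1024`; in a framework like PyTorch these values would go into per-layer optimizer parameter groups.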

For Adam, gradient normalization further stabilizes training, and \(\eta_n = \Theta(n^{-1})\) suffices for width-independent updates (analogous to \(\mu\)P). In Transformers, trainable LayerNorm parameters play a key stabilizing role, raising the maximum stable learning rate from \(n^{-1}\) to \(n^{-1/2}\).

Refined Coordinate Check (RCC)

The paper proposes a diagnostic tool that separately measures the effective update \((\Delta W_t^l) x_t^{l-1}\) and the propagated update \(W_0^l (\Delta x_t^{l-1})\), enabling accurate estimation of width-scaling exponents at moderate widths (\(n \leq 512\)).
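A minimal sketch of the RCC decomposition is below (random stand-in tensors; in practice \(W_t\) and \(x_t\) would come from two training checkpoints). It relies on the exact identity \(W_t x_t - W_0 x_0 = (\Delta W_t) x_t + W_0 (\Delta x_t)\):

```python
import numpy as np

# Split the change in pre-activations h^l = W^l x^{l-1} into the
# effective update (Delta W) x and the propagated update W_0 (Delta x).

rng = np.random.default_rng(1)
n = 512
W0 = rng.standard_normal((n, n)) / np.sqrt(n)               # layer l at init
Wt = W0 + 1e-2 * rng.standard_normal((n, n)) / np.sqrt(n)   # after t steps
x0 = rng.standard_normal(n)                                 # x^{l-1} at init
xt = x0 + 1e-2 * rng.standard_normal(n)                     # x^{l-1} after t steps

rms = lambda v: np.sqrt(np.mean(v ** 2))
effective = (Wt - W0) @ xt           # (Delta W_t^l) x_t^{l-1}
propagated = W0 @ (xt - x0)          # W_0^l (Delta x_t^{l-1})
total = Wt @ xt - W0 @ x0            # exact sum of the two terms
print(rms(effective), rms(propagated), rms(total))
```

Measuring each term's RMS separately across a few widths, rather than only their sum, is what lets the RCC fit width-scaling exponents reliably at \(n \leq 512\).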

Key Experimental Results

Experimental Setup

  • Architectures: MLPs (2–8 layers), Pythia-GPT (up to 1.4B parameters)
  • Datasets: CIFAR-10, MNIST, Fashion-MNIST, DCLM-Baseline language data
  • Optimizers: SGD, Adam, AdamW
  • Width range: up to 16384 (MLP), 4096 (GPT)
  • Parameterizations: SP, \(\mu\)P, SP-full-align

Table: Predicted vs. Observed Maximum Stable Learning Rate Exponents under CE vs. MSE Loss

| Parameterization | Loss | Predicted max-stable \(\alpha\) | Observed max-stable \(\alpha\) | Observed optimal \(\alpha\) |
|---|---|---|---|---|
| SP | MSE | \(-1\) | \(\approx -1\) | \(\approx -1\) |
| SP | CE | \(-0.5\) | \(\approx -0.5\) | \(\approx -0.5\) |
| SP-full-align | CE | \(0\) | \(\approx 0\) | \(\approx 0\) (language) / decreasing (vision) |

Theoretical predictions are in strong agreement with experiments. In deep nonlinear networks, the maximum stable learning rate typically coincides with the optimal learning rate.

Table: CE vs. MSE Performance under SP and \(\mu\)P (8-layer MLP, SGD, Best Training Accuracy)

| Dataset | SP + CE | SP + MSE | \(\mu\)P + CE | \(\mu\)P + MSE |
|---|---|---|---|---|
| MNIST | ~98% | ~85% | ~98% | ~97% |
| CIFAR-10 | ~50% | ~28% | ~52% | ~50% |
| Fashion-MNIST | ~87% | ~65% | ~88% | ~86% |

Key findings:

  • Under SP, CE substantially outperforms MSE because MSE loses feature learning at large widths;
  • Under \(\mu\)P, both losses perform comparably, since \(\mu\)P ensures balanced feature learning across layers;
  • This provides a width-scaling explanation for the practical dominance of CE loss in deep learning.

New Findings on SP-full-align

  • SP-full-align (SP initialization with \(\mu\)P learning rates), recommended by Everett et al. (2024), fails to transfer learning rates on vision datasets;
  • The failure stems from width-independent alignment between \(W_0^{L+1}\) and \(\Delta x_t^L\), causing logits to diverge with width;
  • Successful transfer on language data is attributed to the output dimension satisfying \(d_{\text{out}} \gg n\), which makes the initial operator norm approximately width-independent.

Highlights & Insights

  • First practically useful infinite-width limit for SP with feature learning: Under CE loss and \(\eta_n = \Theta(n^{-1/2})\), all hidden layers exhibit width-independent feature learning, resolving a long-standing theory–practice gap.
  • Fine-grained characterization of the instability regime: The previously undifferentiated "unstable" regime is decomposed into two qualitatively distinct sub-regimes; the controlled divergence induced by CE loss is identified as the key mechanism.
  • Scaling-theoretic explanation for the dominance of CE loss: This work is the first to explain from a width-scaling perspective why CE loss substantially outperforms MSE in deep learning—under SP, CE permits larger learning rates while preserving feature learning, whereas the gap vanishes under \(\mu\)P.
  • Practical diagnostic tool RCC: The Refined Coordinate Check separates effective and propagated updates, enabling accurate estimation of scaling exponents at moderate widths; code is publicly available.
  • Importance of trainable LayerNorm: The work provides a scaling-theoretic explanation for why modern architectures almost universally employ trainable LayerNorm parameters—they raise the maximum stable learning rate in Transformers from \(n^{-1}\) to \(n^{-1/2}\).

Limitations & Future Work

  • Analysis restricted to single-epoch training: Dynamics in multi-epoch settings (e.g., interactions with overfitting) are left for future work.
  • Precise prediction of optimal learning rate exponents remains difficult: The theory predicts only the maximum stable exponent; the optimal exponent further depends on strong assumptions about architecture and data distribution.
  • Numerical precision constraints: Under standard floating-point precision, logit divergence may exceed the numerical range at moderate widths, requiring specialized implementations to mitigate the accumulation of width-dependent factors.
  • Input-layer feature learning remains absent: SP at \(\alpha = 1/2\) still cannot recover feature learning in the input layer; fully width-independent training continues to require \(\mu\)P.
  • Logit divergence under CE + SP may cause overconfidence: Rapid growth of logits can degrade model calibration; models trained under \(\mu\)P may be better calibrated.

Related Work

  • Infinite-width limits: NTK (Jacot et al., 2018), mean-field (Mei et al., 2018; Chizat & Bach, 2018), Tensor Programs (Yang & Hu, 2021) → This work extends Tensor Programs theory to the controlled divergence regime of SP.
  • \(\mu\)P and hyperparameter transfer: Yang et al. (2022) propose \(\mu\)P for width-independent hyperparameters → This work explains why SP also approximately works in practice (attributing this to CE loss).
  • Edge of stability: Cohen et al. (2021) observe that training converges toward sharpness \(2/\eta\) → This work analyzes the limitations of the catapult mechanism under SP.
  • Benefits of large learning rates: Andriushchenko et al. (2023b) show that large steps learn sparse features; Cai et al. (2024) demonstrate margin improvements → This work provides complementary theory from a width-scaling perspective.
  • SP-full-align: Everett et al. (2024) recommend SP + \(\mu\)P learning rates as best practice → This work identifies its failure on vision data and proposes a corrected initialization scheme.

Rating

  • Novelty: ⭐⭐⭐⭐⭐ — First to identify the critical role of CE loss in width scaling; the discovery of the controlled divergence regime is highly insightful.
  • Experimental Thoroughness: ⭐⭐⭐⭐⭐ — Comprehensive cross-validation across MLP+GPT, SGD+Adam, CE+MSE, and SP+\(\mu\)P; theoretical predictions are in strong agreement with experiments.
  • Writing Quality: ⭐⭐⭐⭐ — Logically rigorous with clearly stated main conclusions; dense mathematical notation poses a barrier for readers without a theoretical background.
  • Value: ⭐⭐⭐⭐⭐ — Resolves a long-standing contradiction between infinite-width theory and practice, with direct implications for hyperparameter selection in large-scale model training.