Balancing Learning Rates Across Layers: Exact Two-Step Dynamics and Optimal Scaling in Linear Neural Networks
Conference: ICML 2026
arXiv: 2606.00340
Code: https://github.com/TDCSZ327/Layer-Balancing
Area: Optimization Theory
Keywords: Layer-wise learning rate, linear networks, gradient decomposition, training dynamics, learning rate balancing
TL;DR
This work derives exact closed-form expressions for the test loss after one and two steps of gradient descent in two-layer and three-layer linear neural networks. It reveals a phase transition: asymmetric layer-wise learning rates are optimal for the first update, whereas symmetric (balanced) learning rates become locally optimal after the second step. This gives a theoretical foundation for layer-wise learning rate scheduling.
Background & Motivation
Background: In deep network training, layer-wise learning rate scheduling (e.g., LARS, LAMB, TempBalance, Adam-mini) has been widely adopted to accelerate convergence and improve generalization. These methods adapt to differences in gradient characteristics across layers by assigning different learning rates.
Limitations of Prior Work: Existing layer-wise learning rate strategies are primarily based on heuristic designs or asymptotic analysis, lacking exact formulas that directly link learning rate selection to test loss. Neither continuous-time gradient flow analysis nor NTK approximations capture the signal-residual coupling effect in discrete, finite-step settings. Inter-layer interactions make it difficult to quantify the impact of learning rate allocation on generalization.
Key Challenge: In multi-layer networks, the norms of the signal and self-interaction components of the gradient differ across layers, so the learning rate assigned to each layer strongly shapes the training dynamics. Existing theoretical frameworks either assume infinitesimal step sizes (mean-field / \(\mu\)P) or analyze each layer independently (ignoring cross-layer coupling). They therefore cannot precisely characterize how layer-wise learning rates affect generalization in finite-step training.
Goal: To derive exact closed-form expressions for the test loss with respect to layer-wise learning rates after one and two steps of gradient descent in linear networks, thereby precisely characterizing when asymmetric vs. symmetric learning rates should be used.
Key Insight: Exploiting the algebraic structure of orthogonal initialization, the authors decompose the gradient into two terms. The signal-alignment term \(A_\ell^t\) carries the learning signal, and the self-interaction term \(B_\ell^t\) captures coupling among the weights themselves. They prove that the self-interaction term is negligible below a critical learning rate threshold, which yields an analytically tractable proxy loss.
Core Idea: The optimal allocation of layer-wise learning rates is dynamic. Early training benefits from asymmetric allocation, which exploits layer-specific signal propagation. Later training favors balanced allocation, which facilitates cross-layer coordination.
Method
Overall Architecture
Linear networks with orthogonal initialization are considered: a two-layer model \(f(x) = \frac{1}{h}x^\top W_1 W_2\) (input/output in \(\mathbb{R}^h\)) and a three-layer model \(f^*(x) = \frac{1}{\sqrt{h}}x^\top W_1 W_2 a\) (scalar output, fixed \(a\)). Training data is generated by a linear teacher model, and the optimization objective is MSE loss. The analysis workflow consists of: (1) decomposing the gradient into signal and residual terms; (2) proving the residual terms are negligible in wide networks; (3) deriving closed-form solutions for test loss using signal-only trajectories; (4) analyzing the symmetry of the optimal learning rate allocation.
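The first workflow step can be sketched in NumPy. The following assumptions are ours, not the paper's:

- The loss is the unnormalized squared error \(\frac{1}{2}\|X W_1 W_2/h - Y\|_F^2\).
- The inputs \(X\) form an orthogonal \(h \times h\) matrix (\(n = h = d\), so \(X^\top X = I\)).
- The teacher `W_star` is a hypothetical Gaussian matrix.
- \(M\) is taken to be \(X^\top Y\).

Under these assumptions, the exact gradient with respect to \(W_1\) equals \(B_1 - A_1\), with \(A_1\) and \(B_1\) as defined in the Key Designs list. All helper names (`orthogonal`, `f_two`, `f_three`, `grad_W1_exact`, `grad_W1_split`) are hypothetical.

```python
import numpy as np

def orthogonal(h, rng):
    """Random h x h orthogonal matrix via QR of a Gaussian matrix."""
    q, r = np.linalg.qr(rng.standard_normal((h, h)))
    return q * np.sign(np.diag(r))  # fix column signs so the draw is uniform (Haar)

def f_two(X, W1, W2):
    """Two-layer model: row i is x_i^T W1 W2 / h (vector output in R^h)."""
    h = W1.shape[0]
    return X @ W1 @ W2 / h

def f_three(X, W1, W2, a):
    """Three-layer model with fixed readout a: x_i^T W1 W2 a / sqrt(h) (scalar output)."""
    h = W1.shape[0]
    return X @ W1 @ W2 @ a / np.sqrt(h)

def grad_W1_exact(X, Y, W1, W2):
    """Gradient of the assumed loss 0.5 * ||f_two(X) - Y||_F^2 with respect to W1."""
    h = W1.shape[0]
    residual = f_two(X, W1, W2) - Y
    return X.T @ residual @ W2.T / h

def grad_W1_split(X, Y, W1, W2):
    """Signal term A1 and self-interaction term B1; G1 = B1 - A1 holds when X^T X = I."""
    h = W1.shape[0]
    M = X.T @ Y                      # label-signal matrix (assumed definition of M)
    A1 = M @ W2.T / h                # signal-alignment term
    B1 = W1 @ W2 @ W2.T / h**2       # self-interaction term
    return A1, B1

rng = np.random.default_rng(0)
h = 64                               # n = h = d, as in the analysis
X = orthogonal(h, rng)               # orthogonal inputs, so X^T X = I
W_star = rng.standard_normal((h, h)) # hypothetical linear teacher
Y = X @ W_star
W1, W2 = orthogonal(h, rng), orthogonal(h, rng)
a = rng.standard_normal(h)           # fixed readout for the three-layer model

G1 = grad_W1_exact(X, Y, W1, W2)
A1, B1 = grad_W1_split(X, Y, W1, W2)
print(np.allclose(G1, B1 - A1))                 # True: the split is exact here
print(np.linalg.norm(B1) / np.linalg.norm(G1))  # relative size of the residual at init
print(f_three(X, W1, W2, a).shape)              # (64,)
```

The printed ratio \(\|B_1\|/\|G_1\|\) is only a snapshot at initialization with no step size involved. It is not a check of Proposition 5.1.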
Key Designs
- Gradient Decomposition and Signal Dominance Proof:
  - Function: Decomposes the exact gradient into a signal-alignment term and a self-interaction term, and proves that the former dominates.
  - Mechanism: For layer \(\ell\) of the two-layer network, the gradient is \(G_\ell^t = B_\ell^t - A_\ell^t\). Here \(A_1^t = \frac{1}{h}M W_2^{t\top}\) captures the label signal and \(B_1^t = \frac{1}{h^2}W_1^t W_2^t W_2^{t\top}\) represents weight self-interaction. Proposition 5.1 proves that when \(\eta_1, \eta_2 \leq O(h\sqrt{h})\), the residual satisfies \(\|B_\ell^t\| = \|G_\ell^t + A_\ell^t\| \leq \|G_\ell^t\| / (\sqrt{h}-1)\). In other words, the residual is suppressed by a factor of order \(1/\sqrt{h}\).
  - Design Motivation: The decomposition replaces the analytically intractable exact gradient with a structurally simple signal term. This makes a closed-form test loss derivation possible and identifies the critical learning rate scales \(\eta = \Theta(h\sqrt{h})\) (two-layer) and \(\eta = \Theta(h)\) (three-layer).
- Exact Two-Step Test Loss Formula:
  - Function: Provides an exact polynomial expression for the test loss as a function of \(\eta_1, \eta_2\).
  - Mechanism: The one-step test loss for a two-layer network is \(L^{(1)} = \frac{\eta_1^2}{h^4} + \frac{\eta_2^2}{h^4} + \frac{2\eta_1\eta_2}{h^4} + \frac{\eta_1^2\eta_2^2}{h^7} - \frac{2\eta_1}{h^2} - \frac{2\eta_2}{h^2} + \frac{1}{h} + \frac{2\eta_1\eta_2}{h^5} + 1\). It combines linear improvement terms, quadratic interaction terms, and residual variance terms. The two-step loss contains higher powers of \((1+\eta_1\eta_2/h^3)\), reflecting multiplicative representation learning across layers.
  - Design Motivation: The exact formula makes it possible to check rigorously whether the symmetric point \(\eta_1 = \eta_2\) is a local minimum, and this check is what reveals the phase transition.
- Phase Transition Theorem from Asymmetry to Balance:
  - Function: Proves that the optimal learning rate allocation changes qualitatively as training proceeds.
  - Mechanism: Under the constraint \(\eta_1 + \eta_2 = 2h^\alpha\), Corollary 5.4 proves two results:
    - (a) For any \(0 < \alpha \leq 3/2\), the symmetric point \(\eta_1 = \eta_2\) is not a local minimum of the one-step loss.
    - (b) For \(1 < \alpha \leq 3/2\) and sufficiently large \(h\), \(\eta_1 = \eta_2\) is a local minimum of the two-step loss.
    Analogous conclusions hold for three-layer networks, with the critical scale reduced to \(O(h)\).
  - Design Motivation: This provides the first theoretical explanation for two practices. Layer-wise learning rates should be asymmetric early in training, exploiting the different roles of the representation and readout layers. They should move toward balance later, promoting cross-layer coordination and alignment.
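The one-step formula can be evaluated along the constraint line as a quick numerical sanity check of Corollary 5.4(a). Write \(\eta_1 = s + \delta\) and \(\eta_2 = s - \delta\) with \(s = h^\alpha\). Then \((\eta_1^2 + \eta_2^2 + 2\eta_1\eta_2)/h^4 = (\eta_1+\eta_2)^2/h^4\) and \(-2(\eta_1+\eta_2)/h^2\) are constant. So \(L^{(1)}\) depends on \(\delta\) only through \(\eta_1\eta_2 = s^2 - \delta^2\), which enters with positive coefficients in \(\eta_1^2\eta_2^2/h^7\) and \(2\eta_1\eta_2/h^5\). The symmetric point \(\delta = 0\) is therefore a local maximum, not a local minimum, of the one-step loss along the constraint.

The sketch below is plain Python, and `one_step_loss` is a hypothetical helper. The values \(h = 1000\) and \(\alpha = 1.2\) are illustrative choices. It only exercises the formula as transcribed in this summary; it does not reproduce the paper's derivation or its two-step result.

```python
def one_step_loss(eta1, eta2, h):
    """Closed-form one-step test loss L^(1) of the two-layer network, as transcribed above."""
    return (eta1**2 / h**4 + eta2**2 / h**4 + 2 * eta1 * eta2 / h**4
            + eta1**2 * eta2**2 / h**7
            - 2 * eta1 / h**2 - 2 * eta2 / h**2
            + 1 / h + 2 * eta1 * eta2 / h**5 + 1)

h = 1000.0
alpha = 1.2          # illustrative; Corollary 5.4(a) covers any 0 < alpha <= 3/2
s = h**alpha         # constraint: eta1 + eta2 = 2 * h**alpha = 2 * s

L_sym = one_step_loss(s, s, h)
# Move along the constraint line eta1 = s + delta, eta2 = s - delta
for frac in (0.01, 0.05, 0.10):
    delta = frac * s
    diff = one_step_loss(s + delta, s - delta, h) - L_sym
    print(f"delta/s = {frac:.2f}: L1 - L1_sym = {diff:.3e}")  # negative: asymmetry lowers L1
```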
Key Experimental Results
Main Results: Theoretical Predictions vs. Actual Test Loss
| Setting | Network | Steps | Theory vs. Experiment | Symmetry Conclusion |
|---|---|---|---|---|
| \(h=1000\), Orthogonal Init | 2-layer | 1 step | Closely matched | Symmetric LR not optimal |
| \(h=1000\), Orthogonal Init | 2-layer | 2 steps | Closely matched | Symmetric LR locally optimal |
| \(h=1000\), Orthogonal Init | 3-layer | 1 step | Closely matched | Symmetric LR not optimal |
| \(h=1000\), Orthogonal Init | 3-layer | 2 steps | Closely matched | Symmetric LR locally optimal |
| \(h=1000\), Gaussian Init | 2/3-layer | Multi-step | Consistent trend | Same conclusions; transition persists |
Ablation Study: Generality of the Findings
| Extension Condition | Key Findings |
|---|---|
| Label noise \(\xi \sim \mathcal{N}(0,\rho)\) | Asymmetric-to-balanced transition still holds |
| 4-layer / 8-layer deep linear networks | Transition from asymmetric (one step) to balanced (two steps) is maintained |
| 3-layer nonlinear (ReLU) network | Loss curves are slightly less symmetric, but the transition trend is consistent |
| Multi-step training (up to 512 steps) | Balanced LR remains locally optimal in subsequent steps |
| Frobenius-norm-driven LR scheduler | Achieves lower training and test loss than the uniform-LR baseline |
Key Findings
- Critical learning rate thresholds are \(O(h\sqrt{h})\) for two-layer networks and \(O(h)\) for three-layer networks. Beyond these thresholds, the signal-only gradient approximation breaks down.
- The two-step loss of three-layer networks depends more strongly on higher-order \(\eta_1\eta_2\) terms (e.g., \(\eta_1^4\eta_2^4/h^6\)), reflecting deeper cross-layer coupling.
- An adaptive layer-wise LR scheduler driven by \(\|W_1\|_F\) and \(\|W_2\|_F\) supports the theoretical predictions. As the norm gap \(\|W_1\|_F - \|W_2\|_F\) shrinks to 0, the learning rates approach balance and training converges to a flatter minimum.
Highlights & Insights
- The signal-residual gradient decomposition is the key technical tool of this work. By proving that the self-interaction term \(B_\ell^t\) is suppressed by a factor of \(1/\sqrt{h}\), the authors reduce the intractable exact dynamics to signal-only trajectories, which enables closed-form test loss derivations. The decomposition could transfer to other theoretical analyses of finite-step gradient descent.
- The "Asymmetric then Balanced" phase transition perspective provides a unified theoretical explanation for practical layer-wise learning rate schedulers (LARS, LAMB, TempBalance): early stages treat layers differently due to distinct roles (representation vs. readout), while later stages favor balance as cross-layer coordination dominates.
- The norm-driven LR scheduler \(\eta_{W_i}^{(t)} = \frac{2\|W_j^t\|_F}{\|W_1^t\|_F + \|W_2^t\|_F} \cdot lr\) translates the theory directly into practice and is both simple and effective.
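The scheduler formula can be sketched as a small NumPy function. Two points are assumptions rather than facts from the source:

- Reading \(W_j\) as the other layer (\(j \neq i\)) is our interpretation of the formula.
- The name `norm_balanced_lrs` is hypothetical; this is not the code in the linked repository.

```python
import numpy as np

def norm_balanced_lrs(W1, W2, lr):
    """Layer-wise LRs: eta_i = 2 * ||W_j||_F / (||W_1||_F + ||W_2||_F) * lr, with j != i (assumed)."""
    n1 = np.linalg.norm(W1)          # Frobenius norm (default for a 2-D array)
    n2 = np.linalg.norm(W2)
    total = n1 + n2
    eta1 = 2.0 * n2 / total * lr     # layer 1 is scaled by the other layer's norm
    eta2 = 2.0 * n1 / total * lr     # layer 2 is scaled by layer 1's norm
    return eta1, eta2

# Example: ||W1||_F = 4 and ||W2||_F = 2
W1 = 2.0 * np.eye(4)
W2 = np.eye(4)
eta1, eta2 = norm_balanced_lrs(W1, W2, lr=0.3)
print(eta1, eta2)  # approximately 0.2 and 0.4
```

By construction the two multipliers always sum to 2, so the mean learning rate stays at `lr`. When the two norms are equal, both layers receive exactly `lr`, which matches the balancing behavior described under Key Findings.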
Limitations & Future Work
- Theoretical analysis is limited to linear networks and orthogonal/Gaussian initializations; although nonlinear ReLU experiments show similar trends, theoretical guarantees are lacking.
- Only one-step and two-step dynamics are analyzed; conclusions for optimal LR allocation during long-term training rely on experimental observations.
- Assumes \(n = h = d\) and does not consider over-parameterized or under-parameterized regimes.
- Does not address stochastic gradient descent (SGD) or mini-batch settings, leaving a gap with practical training.
- Future work could extend analysis to practical architectures like Transformers, specifically exploring optimal LR allocation between attention layers and FFN layers.
Related Work & Insights
- LARS/LAMB (You et al., 2017; 2018): Layer-wise LR based on "trust ratios"; this work provides theoretical support for their early-stage asymmetric allocation.
- TempBalance (Zhou et al., 2023; Liu et al., 2024): Layer-wise LR based on heavy-tailedness of weight spectra; this work's norm-balancing perspective is complementary.
- Adam-mini / Blockwise-LR (Zhang et al., 2024; Wang et al., 2025): Layer-wise LR based on Hessian block structures.
- Du et al. (2018): Proved automatic balancing of layer norms in deep homogeneous models; this work further reveals the role of learning rates in driving this balance.
- Kunin et al. (2024): Studied feature learning under unbalanced initialization, finding that unbalanced initializations can promote rapid feature learning.