Balancing Learning Rates Across Layers: Exact Two-Step Dynamics and Optimal Scaling in Linear Neural Networks

Conference: ICML 2026
arXiv: 2606.00340
Code: https://github.com/TDCSZ327/Layer-Balancing
Area: Optimization Theory
Keywords: Layer-wise learning rate, linear networks, gradient decomposition, training dynamics, learning rate balancing

TL;DR

This work derives exact closed-form expressions for the test loss after one and two steps of gradient descent in two-layer and three-layer linear neural networks. It reveals a phase transition: asymmetric learning rates are optimal for the first step update, while symmetric (balanced) learning rates become locally optimal after the second step, providing a theoretical foundation for layer-wise learning rate scheduling.

Background & Motivation

Background: In deep network training, layer-wise learning rate scheduling (e.g., LARS, LAMB, TempBalance, Adam-mini) has been widely adopted to accelerate convergence and improve generalization. These methods adapt to differences in gradient characteristics across layers by assigning different learning rates.

Limitations of Prior Work: Existing layer-wise learning rate strategies are primarily based on heuristic designs or asymptotic analysis, lacking exact formulas that directly link learning rate selection to test loss. Neither continuous-time gradient flow analysis nor NTK approximations capture the signal-residual coupling effect in discrete, finite-step settings. Inter-layer interactions make it difficult to quantify the impact of learning rate allocation on generalization.

Key Challenge: In multi-layer networks, the norms of the signal component and the self-interaction component of gradients vary across layers, and the learning rate plays a critical role in determining training dynamics. Existing theoretical frameworks either assume infinitesimal step sizes (mean-field / \(\mu\)P) or analyze each layer independently (ignoring cross-layer coupling), failing to precisely characterize the impact of layer-wise learning rates on generalization in finite-step training.

Goal: To derive exact closed-form expressions for the test loss with respect to layer-wise learning rates after one and two steps of gradient descent in linear networks, thereby precisely characterizing when asymmetric vs. symmetric learning rates should be used.

Key Insight: By utilizing the algebraic structure of orthogonal initialization, the gradient is decomposed into a signal-alignment term \(A_\ell^t\) (dominating the learning signal) and a self-interaction term \(B_\ell^t\) (coupling between weights). It is proven that the self-interaction term is negligible below a critical learning rate threshold, enabling the derivation of an analytical proxy loss.

Core Idea: The optimal allocation of layer-wise learning rates is dynamic: early training benefits from asymmetric allocation that exploits layer-specific signal propagation, while later training favors balanced allocation that facilitates cross-layer coordination.

Method

Overall Architecture

Linear networks with orthogonal initialization are considered: a two-layer model \(f(x) = \frac{1}{h}x^\top W_1 W_2\) (input/output in \(\mathbb{R}^h\)) and a three-layer model \(f^*(x) = \frac{1}{\sqrt{h}}x^\top W_1 W_2 a\) (scalar output, fixed \(a\)). Training data is generated by a linear teacher model, and the optimization objective is MSE loss. The analysis workflow consists of: (1) decomposing the gradient into signal and residual terms; (2) proving the residual terms are negligible in wide networks; (3) deriving closed-form solutions for test loss using signal-only trajectories; (4) analyzing the symmetry of the optimal learning rate allocation.

Key Designs

  1. Gradient Decomposition and Signal Dominance Proof:

    • Function: Decomposes the exact gradient into a signal-alignment term and a self-interaction term, proving the former dominates.
    • Mechanism: For the \(\ell\)-th layer of a two-layer network, the gradient decomposes as \(G_\ell^t = B_\ell^t - A_\ell^t\), where \(A_1^t = \frac{1}{h}M W_2^{t\top}\) captures the label signal and \(B_1^t = \frac{1}{h^2}W_1^t W_2^t W_2^{t\top}\) represents weight self-interaction. Proposition 5.1 proves that when \(\eta_1, \eta_2 = O(h\sqrt{h})\), the residual satisfies \(\|G_\ell^t + A_\ell^t\| = \|B_\ell^t\| \leq \|G_\ell^t\| / (\sqrt{h}-1)\), meaning the self-interaction term is suppressed by a factor of order \(1/\sqrt{h}\) relative to the full gradient.
    • Design Motivation: This decomposition replaces the analytically intractable exact gradient with a structurally simple signal term, making a closed-form test loss derivation possible while identifying the critical learning rate scales \(\eta = \Theta(h\sqrt{h})\) (two-layer) and \(\eta = \Theta(h)\) (three-layer).
  2. Exact Two-Step Test Loss Formula:

    • Function: Provides an exact polynomial expression for the test loss with respect to \(\eta_1, \eta_2\).
    • Mechanism: The one-step test loss for a two-layer network is \(L^{(1)} = \frac{\eta_1^2}{h^4} + \frac{\eta_2^2}{h^4} + \frac{2\eta_1\eta_2}{h^4} + \frac{\eta_1^2\eta_2^2}{h^7} - \frac{2\eta_1}{h^2} - \frac{2\eta_2}{h^2} + \frac{1}{h} + \frac{2\eta_1\eta_2}{h^5} + 1\), which includes linear improvement terms, quadratic interaction terms, and residual variance terms. The two-step loss contains higher powers of \((1+\eta_1\eta_2/h^3)\), reflecting multiplicative representation learning across layers.
    • Design Motivation: The exact formula allows for rigorous analysis of whether the learning rate symmetry point \(\eta_1 = \eta_2\) is a local minimum, thereby revealing the phase transition phenomenon.
  3. Phase Transition Theorem from Asymmetry to Balance:

    • Function: Proves that the optimal learning rate allocation undergoes a qualitative change as training steps progress.
    • Mechanism: Under the constraint \(\eta_1 + \eta_2 = 2h^\alpha\), Corollary 5.4 proves: (a) for any \(0 < \alpha \leq 3/2\), the symmetry point \(\eta_1 = \eta_2\) is not a local minimum of the one-step loss; (b) for \(1 < \alpha \leq 3/2\) and sufficiently large \(h\), \(\eta_1 = \eta_2\) is a local minimum of the two-step loss. Similar conclusions hold for three-layer networks but with the critical scale reduced to \(O(h)\).
    • Design Motivation: This provides the first theoretical explanation for why layer-wise learning rate scheduling should be asymmetric in early training (exploiting role differences between representation and readout layers) and move toward balance in later stages (promoting cross-layer coordination and alignment).
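
As a worked check of part (a), take the one-step formula exactly as quoted in Key Design 2 and parametrize the constraint by \(\eta_1 = s + \delta\), \(\eta_2 = s - \delta\) with \(s = h^\alpha\). The offset \(\delta\) is notation introduced here for illustration and does not come from the paper. Since \(\eta_1 + \eta_2 = 2s\) and \(\eta_1\eta_2 = s^2 - \delta^2\), the loss restricted to the constraint becomes

\[
L^{(1)}(\delta) = \frac{4s^2}{h^4} - \frac{4s}{h^2} + \frac{1}{h} + 1 + \frac{(s^2-\delta^2)^2}{h^7} + \frac{2(s^2-\delta^2)}{h^5},
\qquad
\left.\frac{d^2 L^{(1)}}{d\delta^2}\right|_{\delta=0} = -\frac{4s^2}{h^7} - \frac{4}{h^5} < 0 .
\]

Under this reading, only the last two terms depend on \(\delta\), and both shrink as \(|\delta|\) grows, so \(\delta = 0\) is a strict local maximum along the constraint for every \(s > 0\). This is consistent with the statement that the symmetric point is not a local minimum of the one-step loss.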

Key Experimental Results

Main Results: Theoretical Predictions vs. Actual Test Loss

| Setting | Network | Steps | Theory-Experiment Deviation | Symmetry Conclusion |
| --- | --- | --- | --- | --- |
| \(h=1000\), Orthogonal Init | 2-layer | 1 step | Closely matched | Symmetric LR not optimal |
| \(h=1000\), Orthogonal Init | 2-layer | 2 steps | Closely matched | Symmetric LR locally optimal |
| \(h=1000\), Orthogonal Init | 3-layer | 1 step | Closely matched | Symmetric LR not optimal |
| \(h=1000\), Orthogonal Init | 3-layer | 2 steps | Closely matched | Symmetric LR locally optimal |
| \(h=1000\), Gaussian Init | 2/3-layer | Multi-step | Consistent trend | Same as above, transition persists |
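
The one-step, two-layer row can be probed numerically with the closed-form loss quoted in the Method section. The sketch below is an illustration rather than the authors' code: `one_step_loss` and `symmetric_gap` are hypothetical names, the polynomial is transcribed from this summary, and the 10% shift along the constraint \(\eta_1 + \eta_2 = 2h^\alpha\) is an arbitrary choice. Exact rational arithmetic is used because the gap can be many orders of magnitude smaller than the loss itself.

```python
from fractions import Fraction


def one_step_loss(eta1, eta2, h):
    """One-step test loss of the two-layer model, transcribed term by term
    from the formula quoted in this summary (not re-derived here)."""
    return (
        eta1**2 / h**4
        + eta2**2 / h**4
        + 2 * eta1 * eta2 / h**4
        + eta1**2 * eta2**2 / h**7
        - 2 * eta1 / h**2
        - 2 * eta2 / h**2
        + 1 / h
        + 2 * eta1 * eta2 / h**5
        + 1
    )


def symmetric_gap(h, alpha, rel_shift=Fraction(1, 10)):
    """Loss at eta1 = eta2 = h**alpha minus the loss after moving a fraction
    `rel_shift` of that budget from layer 2 to layer 1 (the sum is unchanged).
    A positive value means the asymmetric point has the lower loss."""
    width = Fraction(h)
    s = Fraction(float(h) ** alpha)  # exact rational value of the float h**alpha
    d = Fraction(rel_shift) * s
    gap = one_step_loss(s, s, width) - one_step_loss(s + d, s - d, width)
    return float(gap)


if __name__ == "__main__":
    for alpha in (0.5, 1.0, 1.25, 1.5):
        print(f"alpha={alpha}: symmetric minus asymmetric loss = {symmetric_gap(1000, alpha):.3e}")
```

A positive gap at each tested \(\alpha\) means a nearby asymmetric allocation beats the symmetric one, matching the "Symmetric LR not optimal" entry. It says nothing about the two-step loss, whose full expression is not reproduced in this summary.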

Ablation Study: Generalization Verification

| Extension Condition | Key Findings |
| --- | --- |
| Label noise \(\xi \sim \mathcal{N}(0,\rho)\) | Asymmetric \(\to\) Balance transition still holds |
| 4-layer/8-layer Deep Linear | One-step asymmetric, two-step balanced transition maintained |
| 3-layer Nonlinear (ReLU) | Curve symmetry slightly weaker but transition trend is consistent |
| Multi-step Training (to 512 steps) | Balanced LR remains locally optimal in subsequent steps |
| Frobenius Norm-driven LR Scheduler | Achieves lower training/test loss than uniform baseline |

Key Findings

  • Critical learning rate thresholds: \(O(h\sqrt{h})\) for two-layer networks and \(O(h)\) for three-layer networks; gradient approximations fail beyond these thresholds.
  • Two-step loss in three-layer networks shows stronger dependence on higher-order \(\eta_1\eta_2\) terms (e.g., \(\eta_1^4\eta_2^4/h^6\)), reflecting deeper cross-layer coupling.
  • The adaptive layer-wise LR scheduler built on \(\|W_1\|_F\) and \(\|W_2\|_F\) validates the theoretical predictions: as the norm difference \(\|W_1\|_F - \|W_2\|_F \to 0\), the layer-wise learning rates tend toward balance and training converges to a flatter minimum.

Highlights & Insights

  • Signal-residual gradient decomposition is the central technical tool of this work: by proving that the self-interaction term \(B_\ell^t\) is suppressed by a factor of order \(1/\sqrt{h}\), the analytically intractable exact dynamics reduce to signal-only trajectories, which enables closed-form test loss derivations. The same decomposition should transfer to other theoretical analyses of finite-step gradient descent.
  • The "Asymmetric then Balanced" phase transition perspective provides a unified theoretical explanation for practical layer-wise learning rate schedulers (LARS, LAMB, TempBalance): early stages treat layers differently due to distinct roles (representation vs. readout), while later stages favor balance as cross-layer coordination dominates.
  • The norm-driven LR scheduler design \(\eta_{W_i}^{(t)} = \frac{2\|W_j^t\|_F}{\|W_1^t\|_F + \|W_2^t\|_F} \cdot lr\) is a direct mapping from theory to practice, being both simple and effective.
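
The scheduler formula above can be sketched in a few lines. This is a minimal illustration under stated assumptions, not the authors' implementation: the function names are hypothetical, weights are plain nested lists, and the formula leaves the relation between the indices \(i\) and \(j\) unstated, so the sketch assumes \(j\) denotes the other layer (each layer's rate scales with its partner's norm). Under the opposite reading the two returned values would simply swap.

```python
import math


def frobenius_norm(matrix):
    """Frobenius norm of a matrix given as a list of rows."""
    return math.sqrt(sum(x * x for row in matrix for x in row))


def norm_balanced_lrs(w1, w2, base_lr):
    """Per-layer learning rates eta = 2 * ||W_j||_F / (||W_1||_F + ||W_2||_F) * base_lr.

    ASSUMPTION: j is taken to be the *other* layer, which the summary does
    not state explicitly.
    """
    n1 = frobenius_norm(w1)
    n2 = frobenius_norm(w2)
    total = n1 + n2
    if total == 0.0:
        raise ValueError("both weight matrices are zero; the ratio is undefined")
    eta_w1 = 2.0 * n2 / total * base_lr  # layer 1 scaled by layer 2's norm
    eta_w2 = 2.0 * n1 / total * base_lr  # layer 2 scaled by layer 1's norm
    return eta_w1, eta_w2


if __name__ == "__main__":
    w1 = [[3.0, 0.0], [0.0, 4.0]]  # Frobenius norm 5
    w2 = [[1.0, 0.0], [0.0, 0.0]]  # Frobenius norm 1
    print(norm_balanced_lrs(w1, w2, base_lr=0.1))
```

Two properties hold under either reading of the indices: the two rates always sum to `2 * base_lr`, and equal norms give both layers exactly `base_lr`, which is the balanced regime the theory describes.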

Limitations & Future Work

  • Theoretical analysis is limited to linear networks and orthogonal/Gaussian initializations; although nonlinear ReLU experiments show similar trends, theoretical guarantees are lacking.
  • Only one-step and two-step dynamics are analyzed; conclusions for optimal LR allocation during long-term training rely on experimental observations.
  • Assumes \(n = h = d\), without considering different regimes like over-parameterization or under-parameterization.
  • Does not address Stochastic Gradient Descent (SGD) or mini-batch settings, leaving a gap with practical training.
  • Future work could extend analysis to practical architectures like Transformers, specifically exploring optimal LR allocation between attention layers and FFN layers.

Related Work

  • LARS/LAMB (You et al., 2017; 2018): Layer-wise LR based on "trust ratios"; this work provides theoretical support for their early-stage asymmetric allocation.
  • TempBalance (Zhou et al., 2023; Liu et al., 2024): Layer-wise LR based on heavy-tailedness of weight spectra; this work's norm-balancing perspective is complementary.
  • Adam-mini / Blockwise-LR (Zhang et al., 2024; Wang et al., 2025): Layer-wise LR based on Hessian block structures.
  • Du et al. (2018): Proved automatic balancing of layer norms in deep homogeneous models; this work further reveals the role of learning rates in driving this balance.
  • Kunin et al. (2024): Studied feature learning under unbalanced initialization, finding that layer-balanced learning rates facilitate rapid feature learning.