Exploiting Weight-Space Symmetries for Approximating Curvature

Conference: ICML 2026
arXiv: 2606.00442
Code: https://github.com/mtkresearch/symm_opt
Area: Optimization / Second-order Optimizers / Geometry & Algebra
Keywords: Hessian Approximation, Weight-Space Symmetries, Orbit Averaging, Shampoo, Muon

TL;DR

The paper shows that neural network losses are invariant under weight-space symmetry groups (such as parameter permutation or rescaling), and that orbit-averaging a single gradient under such a group analytically yields a highly structured Hessian approximation that is inexpensive to store and invert. Shampoo and Muon are shown to be special cases that correspond to assigning the identity group to certain layers, which places these empirically motivated optimizers in a unified symmetry-curvature framework.

Background & Motivation

Background: Efficient estimation of the (inverse) curvature of the loss is a core component in many subfields of machine learning: second-order optimization (preconditioning gradients for faster convergence), Bayesian deep learning (Laplace posteriors), continual learning (protecting important directions), and pruning/compression (curvature-based scoring). In practice, mainstream methods such as KFAC, Shampoo, and SOAP rely on block-diagonal plus Kronecker-factored approximations to keep storage and inversion costs feasible.

Limitations of Prior Work: Explanations for why these methods work have been pieced together post hoc: some claim Shampoo approximates Gauss-Newton, others argue it is equivalent to spectral descent. No unified principle dictates what structure a Hessian approximation should have, how many parameters can be saved, or why such savings are justified.

Key Challenge: Neural network losses are explicitly invariant to many weight transformations (arbitrary permutations of hidden-layer neurons, sign permutations in tanh networks, synchronized input-output permutations in autoencoders, etc.), yet this invariance has rarely been exploited in curvature estimation. Kunin (2020) and Ziyin (2023) proved that the Hessian at critical points inherits the symmetry, but this structure has not been applied to Hessian approximations at arbitrary points during training.

Goal: Starting from weight-space symmetry groups, construct a Hessian approximant that can be calculated using a single gradient, stored and inverted cheaply, and allows for a continuous precision-cost trade-off based on the size of the symmetry group. Use this framework to explain Shampoo/Muon as special cases.

Key Insight: Loss invariance \(\mathcal{L}(\bm w)=\mathcal{L}(A\bm w)\) directly implies gradient equivariance \(\nabla\mathcal{L}(A\bm w)=A\nabla\mathcal{L}(\bm w)\) when \(A\) is orthogonal, as it is for permutations and sign flips (in general the chain rule gives \(\nabla\mathcal{L}(A\bm w)=A^{-\top}\nabla\mathcal{L}(\bm w)\)). A single gradient evaluated at one point of a group orbit therefore determines the gradient at every other point of that orbit. Curvature information is embedded in the orbit and only needs to be extracted analytically.
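
A minimal NumPy sketch of this equivariance for a hidden-unit permutation in a two-layer tanh MLP. The network, the squared-error loss, and all names (`loss_and_grads`, `W1`, `W2`, `P`) are illustrative assumptions, not the paper's code; the permutation acts on the rows of `W1` and the columns of `W2`.

```python
import numpy as np

def loss_and_grads(W1, W2, X, T):
    """Mean squared error of y = W2 tanh(W1 x) and its analytic gradients."""
    hid = np.tanh(X @ W1.T)                 # (batch, hidden)
    Y = hid @ W2.T                          # (batch, out)
    R = (Y - T) / X.shape[0]                # dL/dY
    loss = 0.5 * np.sum((Y - T) ** 2) / X.shape[0]
    gW2 = R.T @ hid                         # (out, hidden)
    d_pre = (R @ W2) * (1.0 - hid ** 2)     # back through tanh
    gW1 = d_pre.T @ X                       # (hidden, in)
    return loss, gW1, gW2

rng = np.random.default_rng(0)
d_in, d_hid, d_out, batch = 3, 4, 2, 8
W1 = rng.standard_normal((d_hid, d_in))
W2 = rng.standard_normal((d_out, d_hid))
X = rng.standard_normal((batch, d_in))
T = rng.standard_normal((batch, d_out))

# Random permutation matrix acting on the hidden units.
P = np.eye(d_hid)[rng.permutation(d_hid)]

loss, gW1, gW2 = loss_and_grads(W1, W2, X, T)
loss_p, gW1_p, gW2_p = loss_and_grads(P @ W1, W2 @ P.T, X, T)

loss_invariant = bool(np.isclose(loss, loss_p))
grad_equivariant = bool(np.allclose(gW1_p, P @ gW1) and np.allclose(gW2_p, gW2 @ P.T))
print(loss_invariant, grad_equivariant)
```

The gradient at the permuted weights is obtained by permuting the original gradient, so no second backward pass is needed to know it.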

Core Idea: Combine the secant condition with a second-order Taylor expansion to perform averaging over group orbits, obtaining the structural equation \(S_{\bm g}\approx H^\star S_{\bm w}H^\star\). The solution is shown to be a linear combination of a low-dimensional basis in a commutant algebra, allowing "factors" to be stored instead of full matrices.
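
The step from the secant condition to the structural equation can be sketched as follows. This is a reading of the argument rather than the paper's own derivation: it assumes each \(A\in\mathcal{G}\) is orthogonal, writes \(\bm w^\star\) for the orbit average of \(\bm w\) and \(\bm g^\star\) for the gradient at \(\bm w^\star\), and takes \(S_{\bm w}\) and \(S_{\bm g}\) to be the orbit-averaged second moments of the centered weights and gradients. Since \(A\bm w^\star=\bm w^\star\) and, by equivariance, \(A\bm g^\star=\bm g^\star\), expanding the gradient to first order around \(\bm w^\star\) gives

\[
\begin{aligned}
A(\bm g-\bm g^\star) &\approx H^\star A(\bm w-\bm w^\star) && \text{for all } A\in\mathcal{G},\\
A(\bm g-\bm g^\star)(\bm g-\bm g^\star)^\top A^\top &\approx H^\star\,A(\bm w-\bm w^\star)(\bm w-\bm w^\star)^\top A^\top\,H^\star && \text{(outer product, } H^\star \text{ symmetric)},\\
S_{\bm g} &\approx H^\star S_{\bm w} H^\star && \text{(average over } \mathcal{G}\text{)}.
\end{aligned}
\]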

Method

The logical chain of the method is: "Loss symmetry group \(\Rightarrow\) Gradient equivariance \(\Rightarrow\) Orbit averaging for the structural equation \(\Rightarrow\) Commutant algebra provides sparse bases \(\Rightarrow\) Solve factors via least squares \(\Rightarrow\) Apply to secant condition for PSD Hessian approximation \(\Rightarrow\) Choosing different groups yields Symo / Shampoo / Muon optimizers." The theory uses Schur-Weyl duality to link "groups" and "algebras," with a JIT compiler automatically translating symbolic symmetry definitions into PyTorch computations.

Overall Architecture

Let the network parameters be \(\bm w=[\text{vec}(B);\text{vec}(C);\dots]\), where each tensor undergoes group actions along its axes: \(\bm v\to(\bigotimes_k A_{i(k)})\bm v\). The authors define first-order orbit averaging \(\mathcal{R}_1(\bm v, \mathcal{G})\equiv\mathbb{E}_{\mathcal{G}}[(\bigotimes_k A_{i(k)})\bm v]\) and second-order orbit averaging \(\mathcal{R}_2(\bm v, \bm v', \mathcal{G})\equiv\mathbb{E}_{\mathcal{G}}[(\bigotimes_k A_{i(k)})\bm v{\bm v'}^\top(\bigotimes_k A_{i'(k)})^\top]\). The former gives the center of the orbit of \(\bm w\), \(\bm w^\star\equiv\mathcal{R}_1(\bm w, \mathcal{G})\), while the latter yields an object expressed as a weighted sum of a small set of sparse basis tensors in the commutant algebra. A second-order Taylor expansion of the loss around \(\bm w^\star\), averaged over the entire orbit, upgrades the secant condition \(\bm g-\bm g^\star\approx H^\star(\bm w-\bm w^\star)\) to the structural equation \(S_{\bm g}\approx H^\star S_{\bm w}H^\star\). This yields the unique PSD solution \(H^\star_{\text{PD}}=S_{\bm w}^{-1/2}(S_{\bm w}^{1/2}S_{\bm g}S_{\bm w}^{1/2})^{1/2}S_{\bm w}^{-1/2}\), which under the simplification \(S_{\bm w}\propto I\) reduces, up to a scalar factor, to \(H^\star_{\bm g}=S_{\bm g}^{1/2}\). Finally, because larger groups yield longer orbits but lower-dimensional commutants, the choice of group acts as a knob for the precision-cost trade-off.
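
As a sanity check on the closed form, the sketch below builds a random symmetric positive-definite \(H\) and \(S_{\bm w}\), forms \(S_{\bm g}=HS_{\bm w}H\) exactly, and recovers \(H\) from the formula for \(H^\star_{\text{PD}}\). The helper names (`sqrt_psd`, `h_pd`) and the dense random matrices are assumptions for illustration; the paper works with structured factors rather than dense matrices.

```python
import numpy as np

def sqrt_psd(M):
    """Symmetric PSD square root via an eigendecomposition."""
    M = 0.5 * (M + M.T)                      # remove round-off asymmetry
    evals, evecs = np.linalg.eigh(M)
    evals = np.clip(evals, 0.0, None)
    return (evecs * np.sqrt(evals)) @ evecs.T

def h_pd(S_w, S_g):
    """PSD solution H of S_g = H S_w H (S_w assumed positive definite)."""
    S_w_half = sqrt_psd(S_w)
    S_w_inv_half = np.linalg.inv(S_w_half)
    middle = sqrt_psd(S_w_half @ S_g @ S_w_half)
    return S_w_inv_half @ middle @ S_w_inv_half

rng = np.random.default_rng(0)
p = 5

def random_spd(size):
    B = rng.standard_normal((size, size))
    return B @ B.T + size * np.eye(size)

H_true = random_spd(p)
S_w = random_spd(p)
S_g = H_true @ S_w @ H_true

H_rec = h_pd(S_w, S_g)
recovery_error = float(np.max(np.abs(H_rec - H_true)))

# With S_w = c I the solution is proportional to S_g^{1/2}.
c = 2.5
H_simple = sqrt_psd(c * H_true @ H_true)
simplified_ok = bool(np.allclose(H_simple, np.sqrt(c) * H_true))
print(recovery_error, simplified_ok)
```

The formula is the matrix geometric mean of \(S_{\bm w}^{-1}\) and \(S_{\bm g}\), which is why it needs an invertible \(S_{\bm w}\).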

Key Designs

  1. Orbit-Averaged Hessian Approximation \(H^\star_{\bm g}=S_{\bm g}^{1/2}\):

    • Function: Derives a highly structured, storable, and invertible curvature proxy from a single gradient calculation, eliminating the need for engineering specialized \(A\) and \(G\) block-diagonal terms as in KFAC.
    • Mechanism: Expanding the Taylor series around \(\bm w^\star\equiv\mathcal{R}_1(\bm w, \mathcal{G})\) gives \(A(\bm g-\bm g^\star)\approx H^\star A(\bm w-\bm w^\star)\) for all \(A\in\mathcal{G}\). Taking outer products and averaging over the orbit yields \(S_{\bm g}\approx H^\star S_{\bm w}H^\star\). The authors find empirically that replacing \(S_{\bm w}\) with \(cI\), which avoids the numerical ill-conditioning caused by an often rank-deficient \(S_{\bm w}\), maintains accuracy while requiring only \(S_{\bm g}\) to be estimated; this motivates \(H^\star_{\bm g}=S_{\bm g}^{1/2}\). Lemma C.1 bounds the secant error by \(\tfrac{M}{2}\|\bm w-\bm w^\star\|^2\), where \(\|\bm w-\bm w^\star\|\) increases monotonically with group size. This gives an analytical basis for the "larger group, longer orbit, coarser approximation" trade-off.
    • Design Motivation: To rewrite the engineering problem of "curvature estimation precision vs. cost" into an algebraic problem of "how to select the symmetry group."
  2. Commutant Algebra + JIT Compilation of Sparse Decompositions for \(S_{\bm v\bm v'}\):

    • Function: Represents \(S_{\bm w}\) and \(S_{\bm g}\) as linear combinations of extremely few sparse binary tensors. Only the combination coefficients (factors \(f\)) need to be estimated, drastically compressing storage/inversion complexity.
    • Mechanism: Per Lemma 3.1, \(AS_{\bm v\bm v'}A^\top=S_{\bm v\bm v'}\) holds for all \(A\in\mathcal{G}\), forcing \(S_{\bm v\bm v'}\) into the commutant algebra of the group. This algebra has a natural sparse basis (Kronecker products of identity and all-ones tensors). For example, in a toy MLP, \(S_{\bm{cc}}\) under the permutation group is \(S_{mnop}=\delta_{mo}f^{(1)}_{np}+\bm{1}_{mo}f^{(2)}_{np}\) (32 factors). With sign-flipping symmetry in tanh networks, the basis collapses to a single term \(\delta_{mo}\), halving factors. The JIT compiler translates symbolic symmetry declarations into PyTorch graphs.
    • Design Motivation: To replace heuristic decisions about "which Kronecker structure is reasonable" with a theorem-backed approach where the symmetry group determines the commutant algebra dimension.
  3. Symo Optimizer and its Reduction to Shampoo / Muon:

    • Function: Uses \(H^\star_{\bm g}\) as a preconditioner for the Symo update \(\bm w_{t+1}=\bm w_t-\eta (H^\star_{\bm g})^{-1}\bm g_t\), and proves it reduces exactly to Shampoo / Muon under specific group selections.
    • Mechanism: Per Lemma 3.2, the factors are obtained by least squares, \(\bm f^\star=\arg\min_{\bm f}\|S(\bm f)-\bm v{\bm v'}^\top\|_F^2\). For \(S_{mnop}=\delta_{mo}f_{np}\), the optimal solution is \(\mathcal{R}_2(\bm g,\bm g,\mathcal{G})=GG^\top\otimes I\). Lemma 3.3 proves that if the curvature is of the form \(I\otimes F\) or \(F\otimes I\), the Symo update simplifies to whitened Shampoo or Muon (e.g., \(\sqrt{n}G(G^\top G)^{-1/2}\)). This reduction is triggered by assigning identity groups to certain layers, such as even MLP layers or embedding-dimension axes in Transformers.
    • Design Motivation: To provide a unified derivation for widely used but poorly understood optimizers like Shampoo and Muon.
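
The reduction to a Muon-style update can be checked by brute force on a small example. The sketch below enumerates the signed permutation group acting on the rows of a gradient matrix, orbit-averages the outer product of the flattened gradient, and compares the resulting preconditioned step with the orthogonalized gradient. The row-major flattening, the choice of axis, the matrix sizes, and all function names are assumptions made here for illustration; under this convention the averaged matrix takes the form \(I\otimes F\).

```python
import itertools
import numpy as np

def signed_permutations(m):
    """Yield every m x m signed permutation matrix (2^m * m! in total)."""
    for perm in itertools.permutations(range(m)):
        for signs in itertools.product((1.0, -1.0), repeat=m):
            A = np.zeros((m, m))
            A[np.arange(m), list(perm)] = signs
            yield A

def orbit_second_moment(G, group):
    """Average vec(A G) vec(A G)^T over the group (row-major vec)."""
    total, count = 0.0, 0
    for A in group:
        v = (A @ G).reshape(-1)
        total = total + np.outer(v, v)
        count += 1
    return total / count

def inv_sqrt_spd(M):
    """Inverse square root of a symmetric positive-definite matrix."""
    evals, evecs = np.linalg.eigh(M)
    return (evecs / np.sqrt(evals)) @ evecs.T

rng = np.random.default_rng(0)
m, n = 4, 3                        # m >= n so that G^T G is invertible
G = rng.standard_normal((m, n))

S_g = orbit_second_moment(G, signed_permutations(m))
F = G.T @ G / m
matches_kron = bool(np.allclose(S_g, np.kron(np.eye(m), F)))

# Preconditioned step (S_g^{1/2})^{-1} g, reshaped back to a matrix.
symo_step = (inv_sqrt_spd(S_g) @ G.reshape(-1)).reshape(m, n)

# Orthogonalized gradient G (G^T G)^{-1/2} = U V^T, scaled by sqrt(m).
U, _, Vt = np.linalg.svd(G, full_matrices=False)
muon_step = np.sqrt(m) * U @ Vt
matches_muon = bool(np.allclose(symo_step, muon_step))
print(matches_kron, matches_muon)
```

In this sketch the scalar in front of the orthogonalized gradient is the square root of the size of the axis the group acts on, which presumably corresponds to the \(\sqrt{n}\) in the expression quoted above.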

Key Experimental Results

Hessian Approximation Accuracy (Secant Cosine Similarity)

The experiment measures the cosine similarity between \(\hat H(\bm w'-\bm w)\) and the true difference \(\bm g'-\bm g\) at various points along the training trajectory.

| Approximator | Analytical Form | Storage/Inversion Scale | Random Direction Cosine | Gradient Direction Cosine |
|---|---|---|---|---|
| Exact Hessian \(H\) | \(\nabla^2\mathcal{L}(\bm w)\) | \(O(P^2)\) | 1.0 (Ref) | 1.0 (Ref) |
| Centered Hessian \(H^\star\) | \(\nabla^2\mathcal{L}(\bm w^\star)\) | \(O(P^2)\) | Near \(H\) | Near \(H\) |
| \(H^\star_{\text{PD}}\) (Eq. 10) | \(S_{\bm w}^{-1/2}(S_{\bm w}^{1/2}S_{\bm g}S_{\bm w}^{1/2})^{1/2}S_{\bm w}^{-1/2}\) | \(O(\text{Factors})\) | Same tier as \(H^\star_{\bm g}\) | Better than BD |
| \(H^\star_{\bm g}\) (Eq. 11) | \(S_{\bm g}^{1/2}\) | \(O(\text{Factors})\) | Strong | Strong |
| \(H^\star_{\bm g}(\text{BD})\) | Block-Diagonal version | \(O(\text{Sum of blocks})\) | Moderate | Identical to Shampoo |
| Shampoo | \(L^{1/4}_t G_t R^{1/4}_t\) | \(O(n^2+m^2)\) per block | Moderate | Identical to \(H^\star_{\bm g}(\text{BD})\) |
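
The secant cosine metric used in this comparison can be sketched on a toy quadratic loss, where the exact Hessian must score 1 by construction. The quadratic, the diagonal approximation, and the name `secant_cosine` are illustrative assumptions and do not reproduce the paper's experimental setup.

```python
import numpy as np

def secant_cosine(H_hat, w, w_next, g, g_next):
    """Cosine between the predicted and the observed gradient change."""
    predicted = H_hat @ (w_next - w)
    observed = g_next - g
    denom = np.linalg.norm(predicted) * np.linalg.norm(observed)
    return float(predicted @ observed / denom)

rng = np.random.default_rng(0)
p = 6
B = rng.standard_normal((p, p))
H = B @ B.T + np.eye(p)            # Hessian of L(w) = 0.5 w^T H w

def grad(w):
    return H @ w

w = rng.standard_normal(p)
steps = {
    "random direction": rng.standard_normal(p),
    "gradient direction": -0.1 * grad(w),
}
approximations = {"exact": H, "diagonal": np.diag(np.diag(H))}

scores = {}
for step_name, delta in steps.items():
    w_next = w + delta
    for approx_name, H_hat in approximations.items():
        scores[(approx_name, step_name)] = secant_cosine(
            H_hat, w, w_next, grad(w), grad(w_next)
        )

for key, value in scores.items():
    print(key, round(value, 4))
```

A score near 1 means the approximation predicts the direction of the gradient change along that step; the metric ignores the overall scale of \(\hat H\).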

Group Size vs. Factor Count (Toy MLP \(C\in\mathbb{R}^{3\times 4}\))

The authors demonstrate the trade-off between group size and the number of factors needed for \(S_{\bm{cc}}\).

| Structure / Symmetry Group | \(S_{\bm{cc}}\) Analytical Form | Basis Terms | Total Factors |
|---|---|---|---|
| ReLU MLP, Hidden Permutation \(\mathcal{G}_1\) | \(S_{mnop}=\delta_{mo}f^{(1)}_{np}+\bm{1}_{mo}f^{(2)}_{np}\) | 2 | 32 |
| tanh MLP, Hidden Sign Permutation | \(S_{mnop}=\delta_{mo}f_{np}\) | 1 | 16 |
| Autoencoder MLP, Synced I/O Permutation | \(\mathbb{E}_{A_1,A_2}[(A_1\otimes A_2)\bm c\bm c^\top(A_1\otimes A_2)^\top]\) | Multiple Sparse | 4 |
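
The factor counts for the \(3\times 4\) example can be reproduced from a standard character-theory identity: the dimension of the commutant of a representation equals the group average of the squared trace. The sketch below applies it with plain Python. Treating the factor count as the commutant dimension, letting the group act on the size-3 axis, and modelling the autoencoder case as independent full permutation groups on both axes are assumptions made here, chosen because they are consistent with the table.

```python
import itertools

def permutation_traces(n):
    """Traces (fixed-point counts) of all n! permutation matrices."""
    return [sum(1 for i, j in enumerate(p) if i == j)
            for p in itertools.permutations(range(n))]

def signed_permutation_traces(n):
    """Traces of all 2^n * n! signed permutation matrices."""
    traces = []
    for p in itertools.permutations(range(n)):
        for signs in itertools.product((1, -1), repeat=n):
            traces.append(sum(s for i, (j, s) in enumerate(zip(p, signs)) if i == j))
    return traces

def identity_traces(n):
    """The trivial group has a single element with trace n."""
    return [n]

def commutant_dimension(axis1_traces, axis2_traces):
    """Average of tr(A1 kron A2)^2 = (tr A1 * tr A2)^2 over the product group."""
    total = sum((a * b) ** 2 for a in axis1_traces for b in axis2_traces)
    return total // (len(axis1_traces) * len(axis2_traces))

factor_counts = {
    "permutation on axis 1": commutant_dimension(permutation_traces(3), identity_traces(4)),
    "signed permutation on axis 1": commutant_dimension(signed_permutation_traces(3), identity_traces(4)),
    "permutations on both axes": commutant_dimension(permutation_traces(3), permutation_traces(4)),
}
print(factor_counts)
```

Enumerating the group is only practical for toy sizes; it is used here because it makes the link between group size and factor count explicit.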

Key Findings

  • The trade-off between "group size \(\leftrightarrow\) factor count" is continuous. Moving from 32 to 16 to 4 factors corresponds to increased symmetry and decreased storage, while Lemma C.1 confirms \(\|\bm w-\bm w^\star\|\) increases, indicating a coarser approximation.
  • In secant experiments, the block-diagonal \(H^\star_{\bm g}(\text{BD})\) perfectly aligns with Shampoo in the gradient direction (consistent with Lemma 3.3), suggesting Shampoo/Muon succeed by targeting a "sweet spot" of assigning identity groups to certain layers without breaking overall symmetric structure.
  • \(H^\star_{\bm g}\) is cheaper than and nearly as accurate as \(H^\star_{\text{PD}}\). Simplifying \(S_{\bm w}\approx cI\) avoids numerical issues from rank deficiency, representing a crucial bridge from theory to engineering.
  • Normalized \(\mathcal{R}_2(\bm g, \bm g, \mathcal{G})\) for Transformers exhibits visual "colored block" patterns (Fig. 3), directly confirming that curvature matrices remain highly structured after orbit averaging and can be characterized by few factors.
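
The first finding can be illustrated with a nested chain of groups acting on the rows of a \(3\times 4\) weight matrix: the identity, the permutations, and the signed permutations. Because first-order orbit averaging is an orthogonal projection onto the group-fixed subspace, enlarging the group can only move \(\bm w^\star\) further from \(\bm w\). The group chain, the sizes, and the function names below are illustrative assumptions.

```python
import itertools
import numpy as np

def group_matrices(m, signed=False):
    """All m x m permutation matrices, optionally with every sign pattern."""
    if signed:
        sign_choices = list(itertools.product((1.0, -1.0), repeat=m))
    else:
        sign_choices = [(1.0,) * m]
    mats = []
    for perm in itertools.permutations(range(m)):
        for signs in sign_choices:
            A = np.zeros((m, m))
            A[np.arange(m), list(perm)] = signs
            mats.append(A)
    return mats

def orbit_center(W, group):
    """First-order orbit average R_1(W, G) for a group acting on rows."""
    return sum(A @ W for A in group) / len(group)

rng = np.random.default_rng(0)
W = rng.standard_normal((3, 4))

groups = {
    "identity": [np.eye(3)],
    "permutations": group_matrices(3),
    "signed permutations": group_matrices(3, signed=True),
}
centers = {name: orbit_center(W, grp) for name, grp in groups.items()}
distances = {name: float(np.linalg.norm(W - centers[name])) for name in groups}

for name in groups:
    print(name, len(groups[name]), round(distances[name], 4))
```

Under permutations the center replaces every row by the column-wise mean, and under signed permutations it collapses to zero, so the distance grows from 0 to the norm of the centered matrix to the full norm of \(W\).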

Highlights & Insights

  • Geometrizing Engineering Problems: Rather than guessing Kronecker structures, the choice is transformed into selecting symmetry groups. This allows for principled enumeration of feasible structures via Schur-Weyl duality.
  • Symmetry-Based Explanation for Shampoo/Muon: These optimizers are not merely "approximating Gauss-Newton"; they effectively opt to ignore exploitable symmetries in certain layers (assigning identity groups) to achieve a sweet spot in the cost-accuracy trade-off.
  • Engineering Simplicity via JIT: Automatic generation of PyTorch computation graphs from symbolic symmetry declarations allows for near-zero engineering cost to implement Hessian approximations for new architectures.
  • Theory-Engineering Dual-Track: The derivation of both the "elegant theoretical solution" (\(H^\star_{\text{PD}}\)) and the "practically useful solution" (\(H^\star_{\bm g}\)) provides clear guidance on the trade-offs required for real-world implementation.

Limitations & Future Work

  • The framework assumes exact invariance of the loss \(\mathcal{L}\), whereas real-world networks (with biases, residuals, LayerNorm, etc.) often only satisfy these symmetries approximately. Robustness in deep/complex structures lacks extensive validation.
  • The simplification \(S_{\bm w}\propto I\) is primarily empirical. Quantitative bounds on the gap between \(H^\star_{\bm g}\) and \(H^\star_{\text{PD}}\) are not provided.
  • While small LMs were tested, comparisons on trillion-parameter models regarding wall-clock time, peak memory, and convergence curves are missing.
  • Orbit averaging from a single gradient assumes the gradient \(\bm g\) is accurate, but in mini-batch training, noise might propagate into the factor estimation.
  • The focus is on second-order optimization; performance in other curvature-dependent tasks (Laplace approximation, pruning) remains unexplored.

Related Work

  • vs. Bernacchia (2025): While both use averaging to find structural curvature, that work averages over "ensembles of randomly initialized networks," whereas this work derives structure from orbits of a "single model at any training point."
  • vs. (E)KFAC / Shampoo / Soap: These fix Kronecker structures heuristically; this work derives all valid structures from the commutant algebra of symmetry groups.
  • vs. Gauss-Newton / Spectral Descent Explanations: Instead of treating Shampoo as an approximation of another method, this work views it as an exact solution under a specific (restricted) symmetry group.
  • vs. Quasi-Newton (L-BFGS): Quasi-Newton methods require multiple historical gradients to approximate Hessian rank; this method uses a single gradient combined with symmetry to analytically extract the structured curvature.