
PFT: Phonon Fine-tuning for Machine Learned Interatomic Potentials

Conference: ICML 2026
arXiv: 2601.07742
Code: None
Area: Scientific Computing / Materials Simulation / Machine Learned Interatomic Potentials
Keywords: MLIP, Phonon, Hessian, Force Constants, Fine-tuning

TL;DR

This paper proposes PFT (Phonon Fine-tuning), which stochastically samples Hessian columns via Hessian-vector products and directly supervises the energy Hessian to align with DFT force constants during MLIP fine-tuning. Combined with co-training to alleviate catastrophic forgetting, it reduces thermodynamic phonon errors of Nequix MP on the MDR Phonon benchmark by an average of 55% and lowers thermal conductivity \(\kappa_{\text{SRME}}\) from 0.446 to 0.307, achieving SOTA among models trained on MPtrj.

Background & Motivation

Background: Machine Learned Interatomic Potentials (MLIPs) have become cost-effective surrogates for DFT in large-scale materials screening. Leading universal MLIPs (e.g., MACE-MP-0, SevenNet, Nequix) learn the Born-Oppenheimer Potential Energy Surface (PES) by regressing Energy \(E\), Force \(\mathbf{F}=-\nabla E\), and Stress \(\sigma\) (the EFS loss) on relaxation trajectory datasets like MPtrj and OMat24.

Limitations of Prior Work: Many critical physical properties, such as phonon dispersion, vibrational entropy \(S\), Helmholtz free energy \(F\), constant-volume heat capacity \(C_V\), and thermal conductivity \(\kappa\), depend not on the zeroth- or first-order quantities of the PES but on the second-order force constants \(\Phi_{aibj}=\partial^{2}E/\partial r_{a,i}\partial r_{b,j}\), or even higher-order derivatives. The EFS loss constrains second derivatives only indirectly, so many MPtrj-trained MLIPs exhibit an "over-softened" PES curvature near equilibrium configurations. This leads to systematically underestimated phonon frequencies, spurious imaginary frequencies, and significant distortion in predicted phase stability and thermal conductivity.
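To make the link between curvature and frequencies concrete, here is a minimal toy sketch (not an MLIP): harmonic frequencies are the square roots of the eigenvalues of the mass-weighted force-constant (dynamical) matrix, so softening \(\Phi\) uniformly by a factor \(s\) lowers every frequency by \(\sqrt{s}\), and negative curvature shows up as an imaginary mode. The function names and the periodic 1D spring ring are hypothetical; imaginary frequencies are returned as negative numbers, a common plotting convention.

```python
import jax.numpy as jnp

def phonon_frequencies(phi, masses):
    """Harmonic frequencies (Gamma point) from force constants phi of shape (n, n).

    Builds the mass-weighted dynamical matrix D = phi / sqrt(m_i m_j) and returns
    sign(lam) * sqrt(|lam|) for its eigenvalues lam, so modes with negative
    curvature (imaginary frequencies) come out as negative numbers.
    """
    inv_sqrt_m = 1.0 / jnp.sqrt(masses)
    dyn = phi * inv_sqrt_m[:, None] * inv_sqrt_m[None, :]
    eigvals = jnp.linalg.eigh(dyn)[0]
    return jnp.sign(eigvals) * jnp.sqrt(jnp.abs(eigvals))

def ring_force_constants(n, k):
    """Hypothetical toy: force constants of a periodic 1D ring of n atoms joined by springs k."""
    eye = jnp.eye(n)
    return k * (2.0 * eye - jnp.roll(eye, 1, axis=1) - jnp.roll(eye, -1, axis=1))

masses = jnp.ones(8)
w_ref = phonon_frequencies(ring_force_constants(8, 1.0), masses)    # reference curvature
w_soft = phonon_frequencies(ring_force_constants(8, 0.8), masses)   # curvature softened by 20%
w_unstable = phonon_frequencies(ring_force_constants(8, 1.0) - 0.5 * jnp.eye(8), masses)
```

In this toy the highest frequency drops from 2.0 to \(\sqrt{3.2}\approx 1.79\) when the springs are softened, and the shifted matrix acquires a negative (imaginary) mode, mirroring in miniature the symptoms described above.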

Key Challenge: Direct supervision of force constants requires computing the Hessian. However, crystal phonon calculations need supercells large enough that a displaced atom does not interact with its own periodic images, often involving thousands of atoms. Memory for the full \(3N\times 3N\) Hessian scales as \(O(N^2)\), and so does the compute (it takes \(3N\) Hessian-vector products, each costing \(O(N)\)), which makes training against the full Hessian infeasible. Furthermore, phonon data consist exclusively of equilibrium configurations, so naive fine-tuning risks destroying the model's pre-trained capabilities on non-equilibrium configurations.

Goal: (1) Introduce the second-order PES curvature as a differentiable training signal into MLIPs; (2) Ensure this signal remains trainable on supercells with thousands of atoms; (3) Complete fine-tuning without compromising existing performance on MPtrj.

Key Insight: The authors first plotted Hessian error against phonon thermodynamic error for multiple MPtrj-trained base models (Fig. 2) and observed a very strong positive correlation. This suggests that reducing Hessian error should improve downstream phonon properties, effectively reducing the task of "improving phonon properties" to "aligning the Hessian."

Core Idea: A Hessian alignment loss term \(\mathcal{L}_\Phi\) is added to the EFS loss. For each structure, a single Hessian column is randomly sampled and gradients are calculated using a Hessian-vector product (HVP), reducing single-step training complexity from \(O(N^2)\) to \(O(N)\). Upstream EFS data co-training is utilized to prevent forgetting.
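A back-of-the-envelope count makes the scaling concrete. It assumes an MLIP with a finite cutoff, so that one energy/force evaluation (and hence one HVP, up to a small constant factor) costs \(O(N)\), and takes \(N=1000\) atoms purely as an illustrative supercell size:

```latex
\begin{aligned}
\text{full Hessian:}\quad & (3N)^2 = 3000^2 = 9\times 10^{6}\ \text{entries},\quad 3N = 3000\ \text{HVPs} \;\Rightarrow\; O(N^2)\\
\text{one PFT step:}\quad & 1\ \text{HVP} \;\Rightarrow\; 1\ \text{column of}\ 3N = 3000\ \text{entries} \;\Rightarrow\; O(N)
\end{aligned}
```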

Method

Overall Architecture

The objective of PFT is straightforward: take an MLIP \(\hat{E}_\theta(\mathbf{r})\) pre-trained on large-scale trajectory data (primarily Nequix MP, replicated on MACE-MP-0 and Nequix OAM), and correct the "over-softened" PES curvature near equilibrium without losing original capabilities. It utilizes an additional phonon dataset (MDR Phonon, ~8.5k training materials, 300k finite-displacement DFT calculations) for second-order force constant labels and retains a portion of upstream MPtrj data to prevent forgetting.

During training, for each phonon supercell structure, an atom \(b\) and Cartesian direction \(j\) are randomly sampled to construct a unit vector \(\mathbf{v}\) that is 1 only at \((b,j)\). A single Hessian column \(\nabla^2_\mathbf{r}\hat{E}\,\mathbf{v}\) is computed via a Hessian-vector product and compared against the corresponding DFT force constant column using MAE. This, combined with the three EFS terms, forms \(\mathcal{L}_\text{PFT}\). For every step of PFT, \(K=4\) steps of standard EFS fine-tuning on upstream MPtrj data are performed (Algorithm 1). The resulting \(\hat{E}_\theta\) maintains PES curvature alignment with DFT; downstream inference using either finite displacement or analytical AD yields nearly identical results.

%%{init: {'flowchart': {'rankSpacing': 24, 'nodeSpacing': 28, 'padding': 6, 'wrappingWidth': 400, 'subGraphTitleMargin': {'top': 8, 'bottom': 16}}}}%%
flowchart TD
    A["Pre-trained MLIP Ê_θ<br/>(Nequix MP, MPtrj trained)"]
    P["Phonon Supercell Structure<br/>(MDR Phonon, DFT Force Constants Φ)"]
    subgraph STEP["Single-step PFT"]
        direction TB
        S["Random Column Sampling<br/>Atom b, Direction j → Unit Vector v"]
        H["Hessian-Vector Product<br/>∇²Ê·v to get Hessian column, O(N)"]
        L["Hessian Column Alignment Loss L_Φ<br/>MAE with DFT Force Constant column"]
        C["L_PFT = EFS terms + λ_Φ·L_Φ"]
        S --> H --> L --> C
    end
    U["Upstream EFS Co-training<br/>Per 1 PFT step, K=4 MPtrj EFS steps"]
    O["Curvature-aligned Ê_θ<br/>SOTA in Phonon/Thermal Conductivity"]
    A --> S
    P --> L
    C --> U
    U -->|"Loop 1:K to prevent forgetting"| S
    U --> O

Key Designs

1. Hessian Column Alignment Loss \(\mathcal{L}_\Phi\) + Random Column Sampling: Turning curvature into a supervised target

Properties like phonon dispersion, vibrational entropy, and thermal conductivity depend on the second-order force constants \(\Phi\). The EFS loss teaches the model "forces" but imposes almost no direct constraint on "how forces change with position." The authors found that direct EFS fine-tuning on phonon displacement data performed worse than the base model (Table 1: the \(\omega_\text{max}\) error surged from 24 to 182), indicating that EFS signals from displaced configurations cannot replace curvature supervision. PFT therefore treats the second derivative of the energy with respect to coordinates as a first-class supervision target: \(\mathcal{L}_\Phi = \frac{1}{3N_a}\sum_{a,i}\mathbb{E}_{b,j}\,|\partial^2\hat{E}/\partial r_{a,i}\partial r_{b,j} - \Phi_{aibj}|\).
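A minimal sketch of the per-structure loss term, assuming the DFT force constants are stored as a dense array indexed \(\Phi[a, i, b, j]\) of shape (N, 3, N, 3) and the predicted column arrives as an (N, 3) array from a Hessian-vector product. The helper names and the data layout are hypothetical, not the authors' code.

```python
import jax
import jax.numpy as jnp

def sample_column(key, n_atoms):
    """Uniformly sample an (atom b, direction j) pair and the matching one-hot v of shape (N, 3)."""
    key_b, key_j = jax.random.split(key)
    b = jax.random.randint(key_b, (), 0, n_atoms)
    j = jax.random.randint(key_j, (), 0, 3)
    v = jnp.zeros((n_atoms, 3)).at[b, j].set(1.0)
    return b, j, v

def hessian_column_loss(pred_column, dft_force_constants, b, j):
    """MAE between a predicted Hessian column and the DFT column Phi[:, :, b, j].

    pred_column: (N, 3) array of d^2 E / (d r_{a,i} d r_{b,j}) over all (a, i).
    dft_force_constants: (N, 3, N, 3) array indexed Phi[a, i, b, j] (assumed layout).
    """
    target = dft_force_constants[:, :, b, j]
    return jnp.mean(jnp.abs(pred_column - target))

b, j, v = sample_column(jax.random.PRNGKey(0), n_atoms=4)
phi = jax.random.normal(jax.random.PRNGKey(1), (4, 3, 4, 3))  # stand-in for DFT labels
loss = hessian_column_loss(phi[:, :, b, j] + 0.5, phi, b, j)   # every entry off by 0.5
```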

A full Hessian is \(3N\times 3N\), which is computationally prohibitive for large supercells. Consequently, one \((b,j)\) pair is sampled uniformly per structure per step, which amounts to supervising a single Hessian column. Because the expectation of the single-column MAE under this sampling equals the MAE of the full Hessian, the resulting gradient is an unbiased estimate of the full-Hessian gradient. In addition, force constants are highly redundant under symmetry (e.g., \(\Phi_{aibj}=\Phi_{bjai}\) and the crystal's point-group symmetries), which E(3)-equivariant architectures respect by construction, so single-column sampling statistically covers most of the independent degrees of freedom while saving computation.
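Writing out the expectation over a uniformly sampled pair \((b,j)\), with \(3N_a\) equally likely columns, shows why the single-column estimate is unbiased. Since the expectation is a finite average, the same identity carries over to the parameter gradient wherever the absolute value is differentiable:

```latex
\begin{aligned}
\mathbb{E}_{(b,j)\sim\mathcal{U}}\Big[\frac{1}{3N_a}\sum_{a,i}\big|\hat{\Phi}_{aibj}-\Phi_{aibj}\big|\Big]
&= \frac{1}{3N_a}\sum_{b,j}\frac{1}{3N_a}\sum_{a,i}\big|\hat{\Phi}_{aibj}-\Phi_{aibj}\big| \\
&= \frac{1}{(3N_a)^2}\sum_{a,i,b,j}\big|\hat{\Phi}_{aibj}-\Phi_{aibj}\big|,
\qquad \hat{\Phi}_{aibj}=\frac{\partial^2\hat{E}}{\partial r_{a,i}\,\partial r_{b,j}}
\end{aligned}
```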

2. Hessian-Vector Product Reducing Complexity from \(O(N^2)\) to \(O(N)\): Enabling curvature supervision on large supercells

Explicitly constructing the full Hessian on large supercells exhausts GPU memory. PFT circumvents this with the HVP technique (Pearlmutter 1994): \(\nabla^2_\mathbf{r}\hat{E}\,\mathbf{v} = \nabla_\mathbf{r}((\nabla_\mathbf{r}\hat{E})^\top \mathbf{v})\). In JAX this is jax.jvp(jax.grad(energy), (pos,), (v,))[1]: reverse mode first computes the energy gradient (the negative forces), and a forward-mode JVP then differentiates it along \(\mathbf{v}\). One Hessian column therefore costs a small constant multiple of a single force evaluation, without ever materializing the \(3N\times 3N\) matrix. In the implementation, all structures in a batch are merged into a single graph and one combined HVP is performed; because the structures do not interact, the batch Hessian is block-diagonal, so a \(\mathbf{v}\) with one nonzero entry per structure returns every structure's sampled column at once. Since the optimizer also needs the parameter gradient of a loss built on the HVP, each update differentiates three times in total (the "triple-backward").
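A self-contained sketch of the HVP trick, using a hypothetical smooth pair potential (toy_energy) in place of an MLIP. It checks that the forward-over-reverse product returns exactly one column of the full Hessian, and that a loss on that column can itself be differentiated with respect to a model parameter (here a single hypothetical scale factor k), which is the extra level of differentiation referred to as "triple-backward".

```python
import jax
import jax.numpy as jnp

def toy_energy(pos):
    """Hypothetical stand-in for an MLIP: smooth pair term (r_ab^2 - 1)^2 over all atom pairs."""
    diff = pos[:, None, :] - pos[None, :, :]
    r2 = jnp.sum(diff ** 2, axis=-1)            # squared pair distances, shape (N, N)
    off_diag = 1.0 - jnp.eye(pos.shape[0])      # drop self-pairs
    return 0.25 * jnp.sum(off_diag * (r2 - 1.0) ** 2)

def hvp(energy_fn, pos, v):
    """Hessian-vector product H v: forward-mode JVP of the reverse-mode energy gradient."""
    return jax.jvp(jax.grad(energy_fn), (pos,), (v,))[1]

pos = jax.random.normal(jax.random.PRNGKey(0), (5, 3))
b, j = 2, 1
v = jnp.zeros_like(pos).at[b, j].set(1.0)       # one-hot direction (atom b, axis j)
column = hvp(toy_energy, pos, v)                # shape (5, 3); no 15 x 15 matrix is built
full_hessian = jax.hessian(toy_energy)(pos)     # shape (5, 3, 5, 3), used only for checking

def scaled_energy(k, pos):
    """Toy model with one trainable parameter k."""
    return k * toy_energy(pos)

def column_loss(k, pos, v, target):
    """MAE of the HVP column against a target; differentiable w.r.t. k."""
    grad_pos = lambda p: jax.grad(scaled_energy, argnums=1)(k, p)
    col = jax.jvp(grad_pos, (pos,), (v,))[1]
    return jnp.mean(jnp.abs(col - target))

# Parameter gradient through the HVP: reverse (forces), forward (JVP), reverse (w.r.t. k)
dloss_dk = jax.grad(column_loss)(1.0, pos, v, 0.5 * column)
```

jax.hessian appears only to verify the column; calling it during training would bring back the \(O(N^2)\) cost that the HVP avoids.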

This reduction makes Hessian supervision feasible for supercells with hundreds of atoms on a single A100 GPU. The entire PFT process takes only 35 A100 hours with co-training, or 15 A100 hours without, i.e. roughly a third or less of the 100 A100 hours required for pre-training.

3. Upstream EFS Co-training: Minimizing catastrophic forgetting

Phonon data consist entirely of equilibrium configurations. Training exclusively on these data causes the PES to drift in non-equilibrium regions, as shown by a significant increase in energy/force/stress errors on the MPtrj validation set (Fig. 3). PFT addresses this by performing \(K=4\) steps of standard EFS updates on the upstream dataset \(\mathcal{D}_\text{up}\) (MPtrj) for every PFT step (Algorithm 1, lines 4-7). \(K\) is selected by monitoring both the phonon validation set and the upstream MPtrj validation set.
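The 1:K schedule can be sketched as a plain generator that interleaves one phonon (PFT) batch with K upstream EFS batches, cycling through the upstream data. The function name, the batch objects, and the ordering within each cycle are assumptions for illustration; the source fixes only the 1:4 ratio (Algorithm 1).

```python
from itertools import cycle, islice

def cotrain_schedule(phonon_batches, upstream_batches, k=4):
    """Yield ("pft", batch) once, then ("efs", batch) k times, for every phonon batch.

    Upstream batches are cycled so that MPtrj data is revisited as often as needed.
    """
    upstream = cycle(upstream_batches)
    for phonon_batch in phonon_batches:
        yield "pft", phonon_batch            # step on L_PFT (EFS terms + lambda_Phi * L_Phi)
        for batch in islice(upstream, k):    # k standard EFS steps on upstream data
            yield "efs", batch

steps = list(cotrain_schedule(["p0", "p1"], ["u0", "u1", "u2"], k=4))
```

Each "pft" entry would drive an update on the phonon loss and each "efs" entry a standard upstream EFS update; with k=4, four out of every five optimizer steps see non-equilibrium MPtrj data.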

Empirically, PFT without co-training degrades EFS performance on MPtrj. With 1:4 (PFT:EFS) co-training, this degradation is almost entirely eliminated, while the Hessian MAE increases only slightly. On Matbench Discovery stability classification, co-training keeps the performance drop within 1%. Compared with methods such as LoRA or EWC, this fixed-ratio data mixing is simpler to implement while still largely preventing forgetting.

Loss & Training

Total loss: \(\mathcal{L}_\text{PFT} = \lambda_E\mathcal{L}_E + \lambda_F\mathcal{L}_F + \lambda_\sigma\mathcal{L}_\sigma + \lambda_\Phi\mathcal{L}_\Phi\). The first three are the standard EFS terms (MAE for energy and stress, \(\ell_2\) for forces). For phonon structures, the force and stress labels are set to zero (equilibrium configurations), and the supercell energy label is the unit-cell energy multiplied by the number of repetitions. The co-training ratio is \(K=4\), and PFT runs for 200 epochs. The same hyperparameter configuration was reused for MACE-MP-0 and Nequix OAM without further tuning.
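A minimal sketch of how the phonon-structure labels and the total loss could be assembled. The helper names, the per-atom energy normalisation, the reading of "\(\ell_2\)" as the mean per-atom Euclidean norm of the force error, and the unit weights are all assumptions, since the source does not give these details.

```python
import math
import jax.numpy as jnp

def phonon_efs_labels(unit_cell_energy, repeats, n_atoms):
    """EFS labels for an equilibrium phonon supercell (hypothetical helper).

    Forces and stress are zero at equilibrium; the supercell energy is the
    unit-cell energy times the number of repetitions, e.g. repeats=(2, 2, 2) -> 8.
    """
    return {
        "energy": unit_cell_energy * math.prod(repeats),
        "forces": jnp.zeros((n_atoms, 3)),
        "stress": jnp.zeros((3, 3)),
    }

def pft_loss(pred, label, loss_phi, lam):
    """lam_E L_E + lam_F L_F + lam_sigma L_sigma + lam_Phi L_Phi (weights are placeholders)."""
    n_atoms = pred["forces"].shape[0]
    # Energy MAE, normalised per atom (assumed normalisation)
    loss_e = jnp.abs(pred["energy"] - label["energy"]) / n_atoms
    # l2 force error as mean per-atom Euclidean norm; eps keeps sqrt differentiable at zero
    diff_f = pred["forces"] - label["forces"]
    loss_f = jnp.mean(jnp.sqrt(jnp.sum(diff_f ** 2, axis=-1) + 1e-12))
    # Stress MAE
    loss_s = jnp.mean(jnp.abs(pred["stress"] - label["stress"]))
    return lam["E"] * loss_e + lam["F"] * loss_f + lam["sigma"] * loss_s + lam["Phi"] * loss_phi

label = phonon_efs_labels(-10.0, (2, 2, 2), n_atoms=16)
lam = {"E": 1.0, "F": 1.0, "sigma": 1.0, "Phi": 1.0}
pred = dict(label, energy=label["energy"] + 16.0)   # energy off by 1 eV/atom, forces and stress exact
total = pft_loss(pred, label, loss_phi=0.0, lam=lam)
```

Here loss_phi would be the column MAE from the Hessian-vector product; it is passed in as a number to keep the sketch independent of any particular model.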

Key Experimental Results

Main Results

| Dataset / Metric | Nequix MP base | Nequix MP PFT | Nequix MP PFT (no co-training) | Prev. SOTA (MPtrj) |
|---|---|---|---|---|
| MDR Phonon \(\omega_\text{max}\) MAE (K) | 24 | 12 | 10 | 24 (eSEN-MP) |
| MDR Phonon \(S\) MAE (J/K/mol) | 32 | 14 | 11 | 14 (eSEN-MP) |
| MDR Phonon \(F\) MAE (kJ/mol) | 12 | 5 | 4 | 4 (eSEN-MP) |
| MDR Phonon \(C_V\) MAE (J/K/mol) | 6 | 3 | 2 | 5 (eSEN-MP) |
| Third-order force constants \(\Phi^{(3)}\) MAE (meV/Å³) | 10.52 | 8.35 | 7.46 | – |
| Matbench Discovery thermal conductivity \(\kappa_\text{SRME}\) | 0.446 | 0.307 | 0.281 | 0.340 (eSEN-30M) |

Errors in four phonon thermodynamic quantities were reduced by an average of 55%. The same recipe on MACE-MP-0 reduced \(\omega_\text{max}\) from 61 to 19 and \(S\) from 60 to 14. On the stronger Nequix OAM base model, PFT still achieved a 50% reduction. Notably, Nequix MP PFT (708K parameters) outperformed the Nequix OAM base model, suggesting Hessian supervision is more efficient than simply increasing upstream data volume.
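As a quick consistency check, the arithmetic below uses only the rounded entries of the main-results table. They give a mean reduction of about 54% for the four thermodynamic quantities, close to the reported 55% (presumably the paper averages unrounded errors), and reductions of about 21% and 31% for third-order force constants and \(\kappa_\text{SRME}\).

```python
# Rounded MAEs copied from the main-results table: base vs. PFT (with co-training)
base = {"omega_max": 24, "S": 32, "F": 12, "C_V": 6}
pft = {"omega_max": 12, "S": 14, "F": 5, "C_V": 3}

# Relative error reduction per quantity: 1 - PFT / base
reductions = {name: 1 - pft[name] / base[name] for name in base}
mean_reduction = sum(reductions.values()) / len(reductions)   # about 0.54

# Quantities that PFT does not supervise directly
fc3_reduction = 1 - 8.35 / 10.52      # third-order force constants, about 0.21
kappa_reduction = 1 - 0.307 / 0.446   # kappa_SRME, about 0.31
```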

Ablation Study

| Configuration | Hessian MAE | MPtrj EFS degradation | Notes |
|---|---|---|---|
| Nequix MP base | High | None (reference) | EFS training only |
| EFS fine-tuning on phonon displacement data | Higher than base | – | \(\omega_\text{max}\) error spiked from 24 to 182; the EFS signal cannot replace Hessian supervision |
| PFT (no co-training) | Lowest | Significant (Fig. 3) | Best phonon accuracy, but catastrophic forgetting |
| PFT (co-training, \(K=4\)) | Slightly higher than without co-training | Minimal | Best overall balance |
| Force-constant inference: FD vs. AD | Nearly identical results (Table 1) | – | The analytical HVP route removes the displacement-distance hyperparameter |
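The FD-vs-AD row can be illustrated on a toy problem. The sketch below, again with a hypothetical pair potential rather than an MLIP, compares a central finite difference of forces against the analytical HVP column: the finite-difference error depends on the displacement size h, while the HVP has no such hyperparameter. Double precision is enabled so that round-off does not mask the trend.

```python
import jax
import jax.numpy as jnp

jax.config.update("jax_enable_x64", True)  # float64 so FD round-off stays negligible

def toy_energy(pos):
    """Hypothetical stand-in for an MLIP: smooth pair term (r_ab^2 - 1)^2 over all atom pairs."""
    diff = pos[:, None, :] - pos[None, :, :]
    r2 = jnp.sum(diff ** 2, axis=-1)
    off_diag = 1.0 - jnp.eye(pos.shape[0])
    return 0.25 * jnp.sum(off_diag * (r2 - 1.0) ** 2)

grad_e = jax.grad(toy_energy)  # energy gradient = -forces

def fd_column(pos, v, h):
    """Central finite difference: Phi[:, :, b, j] ~ -(F(r + h v) - F(r - h v)) / (2 h)."""
    return (grad_e(pos + h * v) - grad_e(pos - h * v)) / (2.0 * h)

def ad_column(pos, v):
    """Analytical column via a Hessian-vector product."""
    return jax.jvp(grad_e, (pos,), (v,))[1]

pos = jax.random.normal(jax.random.PRNGKey(0), (5, 3))
v = jnp.zeros((5, 3)).at[1, 2].set(1.0)   # displace atom 1 along z
errs = {h: float(jnp.max(jnp.abs(fd_column(pos, v, h) - ad_column(pos, v))))
        for h in (1e-1, 1e-2, 1e-3)}
```

For this quartic toy energy the central-difference error shrinks as \(h^2\), so errs[1e-3] is far smaller than errs[1e-1]; with a real MLIP the choice of h would likewise matter for FD, but not for the HVP.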

Key Findings

  • "Hessian error vs. phonon property error" shows a strong positive correlation across models (Fig. 2 / Fig. 5). Correcting the curvature therefore tends to improve downstream properties, recasting phonon prediction as a Hessian regression problem.
  • Even though only second-order derivatives are supervised, models show 20–30% improvement in third-order force constants and thermal conductivity, indicating that Hessian supervision implicitly constrains higher-order PES smoothness.
  • EFS training on rattled/perturbed structures does not substitute for Hessian supervision. This is an important caution for training paradigms such as OMat24 that rely on noise-perturbed equilibrium configurations.

Highlights & Insights

  • HVP as a first-class training primitive: The authors adapted the HVP technique (common in PINNs and implicit differentiation) to MLIP training, enabling \(O(N)\) second-order derivative supervision on GPUs. The core operation takes only a few lines of JAX and applies to any energy model that is twice differentiable with grad.
  • Deriving the training target from correlation analysis: Fig. 2 supports treating Hessian error as a proxy for phonon error, which makes supervising the Hessian a principled way to "find the right loss." The same paradigm could transfer to other tasks sensitive to high-order derivatives (electron-phonon coupling, thermoelectricity, elastic moduli).
  • Minimalist Co-training: Unlike complex methods like LoRA or EWC, the \(1:K\) mixing of upstream data is engineering-simple yet successfully mitigates catastrophic forgetting.
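  • Small Model + Good Loss > Large Model + More Data: The 708K-parameter Nequix MP PFT matched or beat the 30M-parameter eSEN-MP on three of the four phonon metrics (all except \(F\)) and on \(\kappa_\text{SRME}\), and also outperformed the Nequix OAM base model trained on more data, suggesting that inductive biases in the loss function can be worth more than data scale in physical ML.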

Limitations & Future Work

  • Only validated on energy-conserving, \(E(3)\)-equivariant MLIPs. Non-equivariant models might introduce bias with single-column sampling, requiring additional data augmentation as noted by the authors.
  • Phonon data remain biased toward dynamically stable systems; effectiveness on systems with strong anharmonicity, long-range polar (dipole) interactions, or strong electron-phonon coupling (e.g., ferroelectrics, superconductors) has not been fully validated.
  • The force-constant labels (MDR Phonon) come from finite-displacement calculations with the PBE functional. Performance with other functionals (e.g., meta-GGAs such as SCAN, or hybrids) or with vdW corrections remains to be evaluated.
  • No public code (as of v4); reproduction requires rebuilding the training pipeline from Nequix/JAX.
  • Potential improvements: Upgrading single-column sampling to block-Hessian sampling, incorporating acoustic sum rules as hard constraints, or making \(K\) adaptive based on gradient magnitudes.
  • vs. eSEN-MP / SevenNet / MACE-MP-0 (pure EFS models): These rely on data and model capacity to approximate curvature implicitly, whereas PFT supervises it explicitly. In Table 1, the 708K-parameter PFT model matches or beats the 30M-parameter eSEN-MP on most phonon metrics, supporting "better loss > bigger model."
  • vs. EFS augmentation with rattled/perturbed configurations (OMat24 style): The EFS fine-tuning experiment on phonon displacement data suggests that such samples do not improve Hessian accuracy and can even be detrimental; "noisy samples" cannot replace true second-order supervision.
  • vs. Hessian-aware Optimization / Physics-Informed Neural Networks (PINNs): Shares the philosophy of incorporating higher-order derivatives into the loss. This work successfully scales the mechanism to real material settings with large supercells.
  • Insight: Any task whose output is a higher-order derivative of a scalar function (vibrational spectroscopy, elastic constants, Navier-Stokes residuals in fluid simulation, stability analysis of Neural ODEs) may benefit from the "HVP column sampling + co-training" recipe.

Rating

  • Novelty: ⭐⭐⭐⭐. Hessian supervision and HVP are existing tools, but this is the first scalable MLIP training paradigm to systematically demonstrate their downstream benefits.
  • Experimental Thoroughness: ⭐⭐⭐⭐⭐. 3 base models × 2 upstream datasets × MDR Phonon / 3rd-order FC / Thermal Conductivity / Matbench Discovery. Ablations clearly explain why EFS alone fails and why co-training is necessary.
  • Writing Quality: ⭐⭐⭐⭐. Clear formulas and algorithms. The correlation analysis in Fig. 2 is very persuasive.
  • Value: ⭐⭐⭐⭐⭐. Achieves SOTA in phonons and thermal conductivity for MPtrj-trained MLIPs at roughly a third of the pre-training compute or less; an "out-of-the-box" upgrade for the materials ML community.