What Happens During the Loss Plateau? Understanding Abrupt Learning in Transformers¶
Conference: NeurIPS 2025 arXiv: 2506.13688 Code: github.com/pulkitgopalani/tf-loss-plateau Area: Interpretability Keywords: abrupt learning, loss plateau, Transformer training dynamics, representation collapse, repetition bias, attention map
TL;DR¶
This paper systematically investigates the phenomenon of "abrupt learning" in Transformer training, revealing that during the loss plateau the model has already learned partial solutions while simultaneously exhibiting output repetition bias and representation collapse. It further demonstrates that the slow learning of attention maps constitutes the key bottleneck, with findings validated in the early pretraining stages of LLMs such as Pythia and OLMo.
Background & Motivation¶
Abrupt learning is a common phenomenon in Transformer training: model performance stagnates during a prolonged loss plateau before suddenly improving dramatically. This phenomenon is regarded as a manifestation of "emergence."
Although prior work has studied this phenomenon in specific settings (e.g., in-context learning, sparse parity, syntactic learning), it lacks:
A unified understanding: Common patterns and mechanisms across tasks remain unclear.
Generality: Most prior work relies on task-specific assumptions, limiting the generalizability of conclusions.
Connection to LLMs: Whether findings from small-scale experiments transfer to large-model pretraining has not been verified.
Three core questions addressed in this paper:

- What common patterns exist in the input–output behavior and internal representations of models during the plateau?
- What hidden progress accumulates prior to the abrupt improvement?
- Are these findings universal?
Method¶
Overall Architecture¶
The primary research platform consists of shallow Transformers (1–2 layers) trained on algorithmic tasks, supplemented by validation on LLM pretraining. The core algorithmic task is the Moving Window Sum (MWS):
Given input sequence \(x_1, x_2, \ldots, x_n\), the output \(y_i\) is:

\[
y_1 = x_1, \qquad y_i = (x_{i-1} + x_i) \bmod p \quad \text{for } i > 1.
\]
Experimental setup: \(n=16, p=17\), using a 1-layer 1-head linear-attention Transformer.
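To make the task concrete, here is a minimal data-generation sketch for MWS, assuming the window-of-two form \(y_1 = x_1\), \(y_i = (x_{i-1} + x_i) \bmod p\) described above (function name and sampling details are illustrative, not from the paper's code):

```python
import numpy as np

def mws_batch(batch_size, n=16, p=17, seed=0):
    """Sample (x, y) pairs for the Moving Window Sum task.

    Assumed form: y_1 = x_1 (a pure copy, matching the paper's
    'first token is easy' observation) and y_i = (x_{i-1} + x_i) mod p.
    """
    rng = np.random.default_rng(seed)
    x = rng.integers(0, p, size=(batch_size, n))
    y = x.copy()                          # y_1 = x_1
    y[:, 1:] = (x[:, :-1] + x[:, 1:]) % p  # windowed sum mod p
    return x, y

x, y = mws_batch(4)
```

The online-training setup in the paper draws a fresh batch like this at every step, so there is no fixed training set to memorize.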
Key Finding 1: Partial Solutions During the Plateau¶
During the loss plateau, the model has already learned to correctly predict a subset of simpler tokens. For instance, in the MWS task, \(y_1 = x_1\) requires only copying; the model learns to predict the first output token correctly at an early stage while subsequent tokens remain incorrect. This pattern appears consistently across multiple algorithmic tasks.
Key Finding 2: Repetition Bias¶
During the plateau, the model tends to generate repeated tokens, e.g., \(x, x, x, \ldots\). The repetition frequency is defined as:

\[
\rho = \frac{1}{n-1} \sum_{i=2}^{n} \mathbf{1}[\hat{y}_i = \hat{y}_{i-1}]
\]
In early training (approximately the first 50 steps), \(\rho\) rapidly rises from near 0 to approximately 0.8, indicating that this is an implicit bias introduced by gradient-based training.
For tasks without consecutive repetitions (e.g., prefix sums), the repetition bias manifests differently — only a small number of distinct tokens appear in the output. This is quantified using sequence entropy:

\[
H(\hat{y}) = -\sum_{t} p_t \log p_t,
\]

where \(p_t\) is the empirical frequency of token \(t\) in the output sequence; low entropy indicates that few distinct tokens appear.
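Such a sequence-entropy measurement (assuming Shannon entropy of the within-sequence empirical token distribution, in nats) could look like:

```python
import numpy as np

def sequence_entropy(seq):
    """Shannon entropy (nats) of the empirical token distribution
    within one output sequence; 0 means a single repeated token."""
    _, counts = np.unique(np.asarray(seq), return_counts=True)
    probs = counts / counts.sum()
    return float(-np.sum(probs * np.log(probs)))

sequence_entropy([3, 3, 3, 3])  # single token -> 0.0
sequence_entropy([0, 1, 2, 3])  # uniform over 4 tokens -> log(4)
```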
Key Finding 3: Representation Collapse¶
Hidden representations at different output positions become nearly parallel. This is measured via the average pairwise cosine similarity over the \(m\) output positions:

\[
\mathrm{COS} = \frac{2}{m(m-1)} \sum_{i < j} \frac{\langle h_i, h_j \rangle}{\|h_i\| \, \|h_j\|}
\]
In early training, the average cosine similarity across output positions (excluding the already-learned first position) rapidly rises to approximately 0.95, then drops significantly as performance improves.
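A collapse measurement of this kind, averaging cosine similarity over all pairs of position-wise hidden states (a sketch, not the paper's exact code), can be written as:

```python
import numpy as np

def mean_pairwise_cosine(H):
    """Average cosine similarity over all pairs of rows of H,
    where row i is the hidden state at output position i.
    Values near 1 indicate representation collapse."""
    Hn = H / np.linalg.norm(H, axis=1, keepdims=True)  # unit-normalize rows
    C = Hn @ Hn.T                                      # all pairwise cosines
    iu = np.triu_indices(len(H), k=1)                  # strict upper triangle
    return float(C[iu].mean())
```

Applied to, e.g., the residual stream after the attention layer, this metric rising toward 0.95 during the plateau and then dropping is exactly the trajectory described above.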
An important distinction: this differs from the initialization-time rank collapse observed in deep softmax Transformers — the representation collapse studied here occurs several training steps after initialization, not at initialization, and uses shallow linear attention.
Key Finding 4: Attention Map Learning as the Key Bottleneck¶
Attention Progress Metric (APM), the fraction of attention mass placed on the optimal positions:

\[
\mathrm{APM} = \frac{\sum_{(i,j) \in \Omega} A_{i,j}}{\sum_{i,j} A_{i,j}},
\]
where \(\Omega\) denotes the set of positions in the optimal attention map. APM monotonically increases during the plateau, from near 0 to approximately 0.8, and its growth is smoother than changes in loss or accuracy — indicating that the attention map makes steady progress even when the loss appears stagnant.
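One plausible reading of this metric, attention mass on \(\Omega\) as a fraction of total mass, can be computed as follows (the exact normalization in the paper may differ):

```python
import numpy as np

def attention_progress(A, omega):
    """Fraction of total attention mass on the optimal positions.

    A:     (n, n) attention map (nonnegative entries).
    omega: iterable of (i, j) index pairs in the optimal map.
    """
    mask = np.zeros(A.shape, dtype=bool)
    for i, j in omega:
        mask[i, j] = True
    return float(A[mask].sum() / A.sum())
```

Tracked over training, this quantity rising smoothly while the loss sits flat is what reveals the hidden progress during the plateau.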
Intervention: Biasing the Attention Map¶
By applying a multiplicative mask \(M\) to the attention map during training — scaling entries at optimal positions \((i,j) \in \Omega\) by a factor \(c\):

\[
M_{i,j} = \begin{cases} c & (i,j) \in \Omega \\ 1 & \text{otherwise} \end{cases},
\qquad \tilde{A} = M \odot A
\]
- \(c > 1\) (biased toward optimal): Representation collapse is reduced (peak cosine drops from 0.95 to approximately 0.6), repetition frequency is lower, and convergence is faster.
- \(0 < c < 1\) (biased away from optimal): The plateau is prolonged, and both representation collapse and repetition bias worsen.
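The intervention itself is a one-line elementwise operation; a sketch (assuming the mask multiplies raw attention entries, with no renormalization, which may differ from the paper's implementation):

```python
import numpy as np

def biased_attention(A, omega, c):
    """Scale attention entries at optimal positions (i, j) in omega
    by c, leaving all other entries unchanged.
    c > 1 biases toward the optimal map; 0 < c < 1 biases away."""
    M = np.ones(A.shape)
    for i, j in omega:
        M[i, j] = c
    return M * A
```

During training this would be applied to the attention map at every forward pass, with gradients flowing through the masked map.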
Optimal Initialization Experiments¶
- Fixing attention layer weights to their optimal values → almost no plateau, rapid convergence, no representation collapse.
- Fixing the MLP or Embedding to optimal values → training dynamics are nearly unchanged; plateau and collapse still occur.
Conclusion: The attention map is the primary bottleneck responsible for representation collapse and the loss plateau.
Loss & Training¶
- Adam optimizer, learning rate \(10^{-4}\), no weight decay.
- Online training (256 new samples drawn per step), no fixed training set.
- Consistency of findings verified across different optimizers (SGD, Muon) and hyperparameter settings.
Key Experimental Results¶
Main Results: Multi-Task Validation¶
| Task | Partial Solution | Repetition Bias | Representation Collapse |
|---|---|---|---|
| MWS (Moving Window Sum) | \(y_1 = x_1\) copy | \(\rho \approx 0.8\) | COS \(\approx 0.95\) |
| Prefix Sum | First token | Low sequence entropy | COS \(\approx 0.80\) |
| Multi-digit Addition | Final carry | Present | Present |
| Permutation | Partial positions | Present | Present |
| Histogram | Partial bins | Present | Present |
All three phenomena are consistently observed across all tasks, confirming the universality of the findings.
Ablation Study: Attention Intervention¶
| Intervention Condition | Peak COS | Convergence Time | Peak Repetition Frequency |
|---|---|---|---|
| \(c = 1\) (baseline) | \(\approx 0.95\) | Normal | \(\approx 0.8\) |
| \(c = 10\) (strongly biased toward optimal) | \(\approx 0.6\) | Significantly faster | Low |
| \(c = 0.2\) (biased away from optimal) | \(> 0.95\) | Significantly delayed | High and persistent |
| KQ fixed at optimal | \(\approx 0.45\) | Very fast | Low |
| KQOV fixed at optimal | \(\approx 0.15\) | Fastest | Very low |
| MLP fixed at optimal | \(\approx 0.95\) | Essentially unchanged | Essentially unchanged |
LLM Validation¶
| Model | Early Peak COS | Repetition Bias |
|---|---|---|
| Pythia-14M | \(> 0.9\) | Present |
| Pythia-1B | \(> 0.9\) | Present |
| Pythia-1.4B | \(> 0.9\) | Present |
| Pythia-2.8B | \(> 0.9\) | Present |
| OLMo-2 7B (step 150) | \(\approx 0.93\) | Present |
| OLMo-2 7B (step 600) | \(\approx 0.43\) | Dissipated |
Key Findings¶
- Repetitive sequences are easier to learn: The REPEAT1 task (\(y_i = x_1\)) exhibits almost no loss plateau; a single gradient step drives COS to approximately 0.5.
- The \(\alpha_1\) metric (\(\frac{1}{n}\sum \mathbf{1}[y_i = y_1]\)) rapidly approaches 1.0 in early training — the model tends to copy the first token to all positions.
- The attention layer, not the MLP, causes collapse: Comparison of residual streams before and after the attention layer confirms that collapse occurs after the attention layer.
- APM shows progress before the loss drops: Hidden progress accumulates during the plateau before any visible improvement in loss.
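The \(\alpha_1\) metric from the list above is straightforward to compute:

```python
import numpy as np

def alpha_1(seq):
    """alpha_1 = (1/n) * sum_i 1[y_i == y_1]: the fraction of output
    positions that copy the first token. Near 1.0 means the model is
    repeating its first prediction everywhere."""
    seq = np.asarray(seq)
    return float(np.mean(seq == seq[0]))

alpha_1([3, 3, 3, 3])  # all positions copy token 1 -> 1.0
alpha_1([3, 1, 2, 4])  # only position 1 matches -> 0.25
```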
Highlights & Insights¶
- Unified perspective: This work is the first to systematically understand Transformer training dynamics by treating partial solutions, repetition bias, representation collapse, and attention learning as an integrated whole.
- Causal validation: Through attention map intervention experiments, the paper establishes not merely correlation but a causal chain: attention learning → collapse/repetition → loss plateau.
- "Search" hypothesis: The plateau is likened to a search process by the attention layer over the token space, analogous to the first layer of a two-layer network aligning with the support of the target function.
- Bridging small models and LLMs: Validation of the same phenomena in the early pretraining of Pythia and OLMo substantially strengthens the generalizability of the conclusions.
- Methodological clarity: Multiple quantitative metrics — APM, repetition frequency, sequence entropy — render abstract training dynamics measurable and comparable.
Limitations & Future Work¶
- Insufficient theoretical explanation: Why does gradient training introduce representation collapse bias? Only a "search" hypothesis is proposed, without rigorous theoretical verification.
- Dominated by small models: Primary conclusions are based on 1–2 layer Transformers; how layers interact in deeper multi-layer models remains unclear.
- Specificity of algorithmic tasks: The tasks used have well-defined optimal solutions; how findings generalize to settings without clear optima, such as natural language, requires further investigation.
- Linear attention limitations: Although some conclusions are partially verified on softmax attention, properties specific to linear attention may limit the generalizability of certain findings.
- Limited LLM validation: Only early checkpoints are examined for representation collapse; partial solutions and attention progress in LLMs are not analyzed in depth.
Related Work & Insights¶
- Distinction from grokking: Grokking refers to post-memorization generalization on a fixed dataset, whereas this paper studies the breakthrough from a plateau in online training — the mechanisms differ.
- Distinction from rank collapse: Rank collapse occurs at initialization in deep softmax Transformers, while the representation collapse studied here is training-induced and occurs in shallow linear-attention models.
- Implications for LLM training: Understanding early training dynamics can inform better learning rate schedules, attention initialization strategies, and curriculum learning designs.
- Attention transfer: The findings are consistent with the effectiveness of pretrained attention transfer in vision Transformers, suggesting that attention patterns constitute transferable knowledge across tasks.
Rating¶
- Novelty: ⭐⭐⭐⭐⭐ — First work to unify the multi-dimensional characteristics of the loss plateau and establish the central role of attention learning.
- Experimental Thoroughness: ⭐⭐⭐⭐⭐ — Multi-task validation, intervention experiments, diverse metrics, and LLM verification; extremely comprehensive.
- Writing Quality: ⭐⭐⭐⭐⭐ — Clear logic, polished figures, and cleverly designed quantitative metrics.
- Value: ⭐⭐⭐⭐⭐ — Provides a foundational contribution to understanding Transformer training dynamics with practical implications for LLM training.