Capture the Key in Reasoning to Enhance CoT Distillation Generalization
Conference: ACL 2025 (Long Paper)
arXiv: 2405.19737
Code: None (The paper states it will be open-sourced in the future)
Area: LLM / NLP — Knowledge Distillation, Chain-of-Thought Reasoning
Keywords: CoT Distillation, Key Reasoning Steps, Minimum Edit Distance, Dual CoTs, Smaller Language Model Reasoning Enhancement
TL;DR
The paper proposes EDIT (mistakE-Driven key reasonIng step distillaTion), which constructs paired correct/incorrect "dual CoTs", uses the minimum edit distance algorithm to locate the key reasoning steps where the two versions diverge, and trains the smaller model with a token-level weighted loss on those steps, rather than having it simply mimic the teacher's reasoning format.
Background & Motivation
LLMs possess powerful CoT reasoning capabilities, but their deployment cost is high, necessitating distillation into smaller language models (SLMs). Existing CoT distillation methods primarily perform SFT (Supervised Fine-Tuning) on the correct CoT data generated by a teacher LLM.
The authors' central observation is that the vast majority of CoT content consists of simple reasoning formats (templated language), while the key reasoning steps that actually determine the conclusion account for only about 4.7% of it. Plain SFT therefore teaches the student to mimic the teacher's reasoning format (e.g., fixed phrases, transitional sentences) while leaving it prone to errors or omissions in the key steps: it copies the style without mastering the substance.
The motivation is analogous to human learning: comparing a mistake against the correct answer often exposes the critical steps that decide success or failure more clearly than studying the correct answer alone.
Core Problem
How can smaller language models truly learn key reasoning steps in CoT distillation, instead of merely mimicking the teacher's reasoning format?
Method
Overall Architecture
EDIT consists of three stages:
1. Data Preparation: Keep all CoTs generated by the teacher LLM (including both correct and incorrect ones).
2. Dual CoTs Construction: Generate paired correct-incorrect CoT data through carefully designed prompts.
3. Two-Stage Training: First perform SFT to establish fundamental reasoning capabilities, then apply KRSL (Key Reasoning Steps Learning) to focus on critical steps.
Key Designs
- Dual CoTs Data Generation:
  - Rectify Incorrect CoTs (Rectify): Design an Answer Hint Prompt (AHP) that provides the correct answer beforehand in few-shot exemplars, guiding the teacher LLM to generate a new version with a similar reasoning path but the correct conclusion for originally incorrect CoTs \(\rightarrow\) yields \(D^{-+}\).
  - Corrupt Correct CoTs (Corrupt): Design a Contrastive CoTs Prompt (CCP) utilizing in-context learning with correct-incorrect pairs to induce the LLM to generate a variant with a similar reasoning path but an incorrect conclusion for originally correct CoTs \(\rightarrow\) yields \(D^{+-}\).
  - The resulting dual CoTs feature "highly similar paths but completely different conclusions."
- Locating Key Steps via Minimum Edit Distance:
  - Apply the minimum edit distance algorithm to the correct and incorrect versions in the dual CoTs.
  - Tokens labeled as `insert/replace` are treated as key steps in correct reasoning.
  - Tokens labeled as `delete/replace` are treated as key steps in incorrect reasoning.
  - Passages of exact match between the two are ignored.
- Key Reasoning Steps Learning (KRSL):
  - Assign token-level weights to key steps: correct key step weight \(\alpha=1.0\), incorrect key step weight \(\beta=0.025\).
  - Optimization objective: Maximize the log-likelihood of key steps in correct CoTs, while minimizing the log-likelihood of key steps in incorrect CoTs.
  - Loss function: \(\max_{\pi_{sft}} \mathbb{E}[\mathcal{L}(\pi, q, CoT^+, \omega^+) - \mathcal{L}(\pi, q, CoT^-, \omega^-)]\)
  - Core difference from DPO: DPO sums over all tokens, which fails on highly similar dual CoTs because identical tokens dominate the loss; KRSL avoids this by optimizing only the key step tokens.
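The key-step localization described above can be sketched in a few lines of Python. This is a minimal illustration, not the authors' implementation: the function names `edit_ops` and `key_step_weights` are hypothetical, tokens are obtained by whitespace splitting (a real pipeline would presumably use the student model's tokenizer), and the edit script is computed by rewriting the incorrect CoT into the correct one so that `insert/replace` land on correct tokens and `delete/replace` on incorrect tokens. The weights default to the \(\alpha=1.0\) and \(\beta=0.025\) values quoted above.

```python
from typing import List, Tuple


def edit_ops(src: List[str], tgt: List[str]) -> List[Tuple[str, int, int]]:
    """Return one minimum edit script that rewrites src into tgt.

    Each op is (label, i, j): i indexes src, j indexes tgt, -1 means "no token".
    """
    n, m = len(src), len(tgt)
    # dist[i][j] = edit distance between src[:i] and tgt[:j]
    dist = [[0] * (m + 1) for _ in range(n + 1)]
    for i in range(n + 1):
        dist[i][0] = i
    for j in range(m + 1):
        dist[0][j] = j
    for i in range(1, n + 1):
        for j in range(1, m + 1):
            sub = 0 if src[i - 1] == tgt[j - 1] else 1
            dist[i][j] = min(dist[i - 1][j] + 1,        # delete from src
                             dist[i][j - 1] + 1,        # insert from tgt
                             dist[i - 1][j - 1] + sub)  # match or replace
    # Walk back from the bottom-right corner to recover the operations.
    ops: List[Tuple[str, int, int]] = []
    i, j = n, m
    while i > 0 or j > 0:
        if i > 0 and j > 0 and src[i - 1] == tgt[j - 1] and dist[i][j] == dist[i - 1][j - 1]:
            ops.append(("equal", i - 1, j - 1))
            i, j = i - 1, j - 1
        elif i > 0 and j > 0 and dist[i][j] == dist[i - 1][j - 1] + 1:
            ops.append(("replace", i - 1, j - 1))
            i, j = i - 1, j - 1
        elif i > 0 and dist[i][j] == dist[i - 1][j] + 1:
            ops.append(("delete", i - 1, -1))
            i -= 1
        else:
            ops.append(("insert", -1, j - 1))
            j -= 1
    ops.reverse()
    return ops


def key_step_weights(neg_tokens: List[str], pos_tokens: List[str],
                     alpha: float = 1.0, beta: float = 0.025):
    """Per-token weights: alpha on correct key tokens, beta on incorrect ones, 0 elsewhere."""
    w_pos = [0.0] * len(pos_tokens)
    w_neg = [0.0] * len(neg_tokens)
    for label, i, j in edit_ops(neg_tokens, pos_tokens):
        if label in ("insert", "replace"):
            w_pos[j] = alpha   # key step in the correct CoT
        if label in ("delete", "replace"):
            w_neg[i] = beta    # key step in the incorrect CoT
    return w_pos, w_neg


# Illustrative dual CoT pair (made up for this sketch, not taken from the paper).
neg = "so the total is 5 + 3 = 9".split()
pos = "so the total is 5 + 3 = 8".split()
w_pos, w_neg = key_step_weights(neg, pos)
print(w_pos)  # only the final token ("8") carries weight 1.0
print(w_neg)  # only the final token ("9") carries weight 0.025
```

In this toy pair, eight of the nine tokens match exactly and receive zero weight, which mirrors the observation that key steps are a small fraction of the CoT.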
Key Experimental Results
| Dataset | Metric | EDIT | Std-CoT (Baseline) | Gain |
|---|---|---|---|---|
| BBH-test (IND) | Accuracy | 60.9% | 54.2% | +6.7% |
| BB-sub (OOD) | Accuracy | 31.1% | 28.7% | +2.4% |
| AGIEval (OOD) | Accuracy | 25.9% | 21.6% | +4.3% |
| ARC-E (OOD) | Accuracy | 64.1% | 59.6% | +4.5% |
| ARC-C (OOD) | Accuracy | 50.5% | 45.1% | +5.4% |
| Average | Accuracy | 46.5% | 41.8% | +4.7% |
The teacher model (ChatGPT, zero-shot CoT) averages 61.9%. LLaMA2-7B distilled via EDIT remains below that on average (46.5%) but comes close to the teacher's level on ARC-C.
Compared to fair baselines:
- Std-CoT w/ Repeat Sampling (aligned data quantity): 42.8% \(\rightarrow\) EDIT still leads by 3.7%.
- Std-CoT w/ Dual CoTs (SFT directly on dual data): 43.8% \(\rightarrow\) EDIT still leads by 2.7%.
- DPO almost collapses in this scenario, yielding only 8.1% (due to generation mode collapse).
Ablation Study
- w/o RWC (removing rectified incorrect CoTs): Average drops to 42.7% (-3.8%), showing that leveraging the teacher's incorrect CoTs provides diverse thinking patterns.
- w/o KRSL (removing key reasoning steps learning): Average drops to 44.3% (-2.2%), confirming that focused learning of key steps is the core mechanism.
- Model Scale: Consistently effective across TinyLLaMA-1.1B / LLaMA2-7B / LLaMA2-13B; the gains are more pronounced on harder tasks (BB-sub, AGIEval).
- Model Architecture: Consistently effective for CodeLLaMA-7B / LLaMA3-8B / Mistral-7B; stronger base models (Mistral) benefit even more.
- Correct vs. Incorrect Key Steps: The impact of correct key steps > incorrect key steps, but learning both jointly yields the best performance.
- Data Quality vs. Quantity: \(D_{dual}^-\) (teacher's native error pairs) possesses higher quality, yielding better results with less data than the larger \(D_{dual}^+\).
- Error Types: Logical errors (LEs) > Knowledge errors (KEs) \(\approx\) Calculation errors (MCEs); logical errors provide more generalizable reasoning patterns.
Highlights & Insights
- Precise Insight: Key steps make up only about 4.7% of CoT content yet dictate reasoning success or failure, and plain SFT cannot distinguish key from non-key steps. This finding is valuable in its own right.
- Ingenious Method: Bridging classical NLP algorithms (minimum edit distance) with modern LLM distillation to find key steps is intuitive and effective.
- Thorough Analysis against DPO: Explains why DPO fails on highly similar dual data (identical tokens dominate the loss), and how KRSL bypasses this issue through precise token selection.
- Exploration of Error Types: Discovers that logical errors benefit reasoning distillation more than knowledge/calculation errors, offering guidance for future "error data engineering."
- Significant OOD Generalization: Demonstrates solid improvements across 4 OOD datasets, indicating that the model learns general reasoning rather than task-specific patterns.
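The token-selection point in the DPO comparison above can be made concrete with a small numeric sketch. It assumes that \(\mathcal{L}(\pi, q, CoT, \omega)\) is a weighted sum of per-token log-probabilities, which this summary does not spell out, so the functions `weighted_log_likelihood` and `krsl_objective` below are hypothetical illustrations rather than the paper's code. The log-probabilities are made-up numbers.

```python
from typing import Sequence


def weighted_log_likelihood(token_logprobs: Sequence[float],
                            weights: Sequence[float]) -> float:
    """Sum of per-token log-probabilities, each scaled by its weight (assumed form of L)."""
    if len(token_logprobs) != len(weights):
        raise ValueError("token_logprobs and weights must have the same length")
    return sum(w * lp for w, lp in zip(weights, token_logprobs))


def krsl_objective(pos_logprobs: Sequence[float], pos_weights: Sequence[float],
                   neg_logprobs: Sequence[float], neg_weights: Sequence[float]) -> float:
    """Quantity to maximize: weighted LL of the correct CoT minus that of the incorrect CoT."""
    return (weighted_log_likelihood(pos_logprobs, pos_weights)
            - weighted_log_likelihood(neg_logprobs, neg_weights))


# Three-token toy pair in which only the last token is a key step.
pos_weights = [0.0, 0.0, 1.0]      # alpha = 1.0 on the correct key token
neg_weights = [0.0, 0.0, 0.025]    # beta = 0.025 on the incorrect key token

base = krsl_objective([-0.1, -0.2, -2.0], pos_weights,
                      [-0.1, -0.2, -0.5], neg_weights)

# Changing the log-probabilities of the shared, zero-weight tokens leaves the
# objective unchanged: only the key tokens contribute.
shifted = krsl_objective([-3.0, -4.0, -2.0], pos_weights,
                         [-5.0, -6.0, -0.5], neg_weights)

print(base, shifted)  # both are roughly -1.9875
```

Because the shared tokens carry zero weight, they cannot dominate the objective, which is the contrast with a loss summed over all tokens.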
Limitations & Future Work
- Teacher Model Limitations: Evaluated only with ChatGPT (gpt-3.5-turbo) as the teacher, without verifying stronger models like GPT-4.
- Student Model Limitations: Experiments were mostly conducted on the LLaMA2 series, without covering modern models like Qwen or Phi.
- Task Type Bias toward Multiple-Choice: BBH is heavily multiple-choice oriented; the method has not been verified on open-ended generation tasks (such as GSM8K or MATH).
- Dual CoTs Generation Cost: Relies on additional teacher API calls to generate dual data, increasing data preparation overhead.
- Sensitivity of KRSL Learning Rate: The second stage uses a learning rate of 5e-6, which is far smaller than the 2e-4 used in the SFT stage; hyperparameter sensitivity is not fully discussed.
- Lack of Objective CoT Evaluation: The paper acknowledges that current CoT quality is mainly evaluated via GPT-4 scoring, which lacks objective metrics.
Related Work & Insights
| Method | Core Idea | Difference from EDIT |
|---|---|---|
| Std-CoT (Magister 2023) | Directly performs SFT on teacher's correct CoTs | Does not distinguish between key/non-key steps |
| MT-CoT (Li 2022) | Jointly optimizes answer prediction and CoT via multi-task learning | Additional objectives but still relies on global SFT |
| SCOTT (Wang 2023) | Enhances reasoning consistency using counterfactual data | Focuses on consistency rather than locating key steps |
| Distilling Step-by-Step (Hsieh 2023) | Extracts rationales as additional supervision | Multi-task framework; no fine-grained step optimization |
| LEMA (An 2023) | Fine-tunes on corrected error data | Corrects the entire CoT without isolating key steps |
| DPO (Rafailov 2023) | Preference alignment learning | Whole-token loss fails on highly similar pairs |
Inspirations & Connections
- Generality of Key Step Localization: The minimum edit distance method for locating key tokens can be extended to other scenarios requiring fine-grained feedback, such as code generation and mathematical proofs.
- "Error Data Engineering" Direction: Different types of errors offer varying value to learning (logic > knowledge \(\approx\) calculation, per the ablation), inspiring more efficient ways to construct contrastive data.
- Relationship with RLHF/DPO: KRSL can be viewed as an improvement of DPO for extremely high-similarity preference pair scenarios, offering potential to combine with modern alignment methods.
- The 4.7% Proportion of Key Steps: Implies a large amount of redundant information in CoT, potentially inspiring research on CoT compression or implicit CoT.
Rating
- Novelty: ⭐⭐⭐⭐ The core insight (key steps account for only 4.7%) and the min-edit-distance localization method are novel, though the overall framework of dual data + contrastive learning is not entirely unprecedented.
- Experimental Thoroughness: ⭐⭐⭐⭐ Solid ablations across model scales/architectures, DPO comparison, and error type analysis, but lacks evaluations on mainstream mathematical reasoning benchmarks such as GSM8K or MATH.
- Writing Quality: ⭐⭐⭐⭐ Clear structure, intuitive diagrams, well-articulated motivation, and fully disclosed prompt templates.
- Value to Me: ⭐⭐⭐⭐ The key step localization concept and the DPO failure analysis are highly valuable for understanding CoT distillation and preference learning.