Align-then-Unlearn: Embedding Alignment for LLM Unlearning

Conference: ICML 2025
arXiv: 2506.13181
Code: https://github.com/ExplainableML/align-then-unlearn
Area: AI Safety
Keywords: LLM Unlearning, Embedding Space, Semantic Unlearning, Privacy Protection, Concept-level Unlearning

TL;DR

The Align-then-Unlearn framework performs unlearning in the semantic embedding space rather than at the token level. It first pre-trains an embedding prediction module to predict the semantic embedding of upcoming text, and then fine-tunes the LLM to push these predicted embeddings away from the target concept's embedding. The result is concept-level knowledge unlearning that is robust to prompt rephrasings.

Background & Motivation

Background: LLMs inevitably retain sensitive information (personal privacy, copyrighted content, etc.) after training on massive datasets. Machine unlearning aims to selectively remove the influence of specific data from trained models.

Limitations of Prior Work: Existing SOTA methods (e.g., Gradient Ascent, DPO, NPO) operate at the token level, defining the unlearning target via specific text sequences in the forget set. This leads to two limitations: (a) the unlearning scope is difficult to control precisely because the forget set can be extremely large; (b) they suffer from a lack of robustness to prompt rephrasings, as a simple change in phrasing can easily bypass the unlearning.

Key Challenge: Token-level unlearning only "masks" information at the output level without truly removing the target knowledge from the model's semantic representations. Consequently, related concepts can still be extracted through alternate activation paths.

Goal: How can concept-level knowledge unlearning that is robust to paraphrasing be achieved?

Key Insight: Since token granularity is too fine, it is more effective to operate in the semantic embedding space. By using a single embedding vector to represent the "concept to be forgotten" as a whole, the model's internal representations can be pushed away from this target concept.

Core Idea: First pre-train an embedding prediction head to align with the semantic space, and then utilize this prediction head as a "probe" to guide the LLM's hidden states away from the target concept embedding.

Method

Overall Architecture

Align-then-Unlearn operates in two phases:

  • Phase 1, Alignment Pre-training: A small embedding prediction module \(E\) is attached to the LLM and trained to map the LLM's hidden states into the semantic embedding space of a frozen pre-trained text encoder (MPNet), predicting the embedding of the upcoming tokens.
  • Phase 2, Unlearning: The prediction head \(E\) is frozen. Using the embedding \(e_{\text{unlearn}}\) of the target concept (e.g., "Stephen King") as an anchor, the LLM is fine-tuned to minimize the similarity between the predicted embedding \(\hat{e}_t\) and \(e_{\text{unlearn}}\).

Given an input token sequence \((x_1, \dots, x_T)\), the LLM produces hidden states \((h_1, \dots, h_T)\). At each position \(t\), the embedding prediction head reads the hidden states up to \(t\) and outputs \(\hat{e}_t = E(h_1, \dots, h_t)\), a prediction of the joint semantics of the next \(k\) tokens.
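To make the training targets concrete, the sketch below pairs each prefix \((x_1, \dots, x_t)\) with its future window \((x_{t+1}, \dots, x_{t+k})\): the prefix is what the LLM and the head see, and the window is what the frozen text encoder embeds to produce \(e_t\). It is a minimal illustration over plain token lists; the function name `make_alignment_pairs` is hypothetical, \(k\) is left as a parameter, and skipping positions whose window would run past the end of the sequence is an assumption rather than a detail taken from the paper.

```python
from typing import List, Sequence, Tuple


def make_alignment_pairs(tokens: Sequence[str], k: int) -> List[Tuple[List[str], List[str]]]:
    """Hypothetical helper: pair each prefix x_1..x_t with its future window x_{t+1}..x_{t+k}.

    Assumption: positions whose window would extend past the end of the
    sequence are skipped, so a sequence of length T yields T - k pairs.
    """
    if k < 1:
        raise ValueError("k must be >= 1")
    pairs = []
    # t is the 1-indexed prefix length; the window needs k tokens after it.
    for t in range(1, len(tokens) - k + 1):
        prefix = list(tokens[:t])        # x_1..x_t: input whose hidden states h_1..h_t feed E
        window = list(tokens[t:t + k])   # x_{t+1}..x_{t+k}: text the frozen encoder turns into e_t
        pairs.append((prefix, window))
    return pairs


if __name__ == "__main__":
    toks = ["Stephen", "King", "wrote", "the", "novel"]
    for prefix, window in make_alignment_pairs(toks, k=2):
        print(prefix, "->", window)
```

With \(T = 5\) and \(k = 2\) this produces three pairs, the first being `["Stephen"] -> ["King", "wrote"]`.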

Key Designs

  1. Embedding Prediction Head:

    • Function: Maps LLM hidden states to the semantic embedding space, predicting the overall semantics of future \(k\) tokens.
    • Mechanism: Utilizes a 6-layer network with a hidden dimension of 768. The alignment loss is defined by the cosine distance \(\mathcal{L}_{\text{align}} = 1 - \text{sim}(\hat{e}_t, e_t)\), where \(e_t\) is obtained by encoding the future window \((x_{t+1}, \dots, x_{t+k})\) using a frozen MPNet.
    • Design Motivation: Unlike token-by-token prediction, an embedding of a multi-token window captures global semantics, so the head reflects concept-level rather than literal (surface-form) information.
  2. Unlearning in Embedding Space:

    • Function: Fine-tunes LLM parameters to push predicted embeddings away from the target concept.
    • Mechanism: The unlearning loss is defined as \(\mathcal{L}_{\text{unlearn}} = \max(0, \text{sim}(\hat{e}_t, e_{\text{unlearn}}) - \tau)\), which penalizes the model only when the similarity exceeds a threshold \(\tau\).
    • Design Motivation: The threshold \(\tau\) provides fine-grained control, preventing over-forgetting and protecting the model's remaining capabilities. Additionally, the forget target can be defined by a single text description (e.g., "Stephen King"), eliminating the need for large forget sets.
  3. Iterative Alignment-Unlearning Alternating Training:

    • Function: Alternately executes embedding head realignment and LLM unlearning updates.
    • Mechanism: Unlearning updates shift the distribution of the LLM's hidden states, so the frozen embedding head gradually loses its ability to detect the target concept. Re-aligning the head restores this diagnostic capability, allowing it to drive further unlearning.
    • Design Motivation: This creates adversarial dynamics: the LLM tries to hide the target concept from \(E\), while \(E\) repeatedly recovers its detection ability, pushing the LLM to eliminate the knowledge in deeper representations rather than merely masking it.
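The two losses defined in the list above can be written per position in a few lines. The sketch below uses plain Python lists as embeddings and a hand-written cosine similarity so that it runs without any ML framework; in a real implementation these would be batched tensor operations averaged over positions and sequences. The function names are hypothetical.

```python
import math
from typing import Sequence


def cosine_sim(a: Sequence[float], b: Sequence[float]) -> float:
    """Cosine similarity sim(a, b) between two equal-length vectors."""
    if len(a) != len(b):
        raise ValueError("vectors must have the same length")
    dot = sum(x * y for x, y in zip(a, b))
    norm_a = math.sqrt(sum(x * x for x in a))
    norm_b = math.sqrt(sum(y * y for y in b))
    if norm_a == 0.0 or norm_b == 0.0:
        raise ValueError("cosine similarity is undefined for zero vectors")
    return dot / (norm_a * norm_b)


def align_loss(e_hat: Sequence[float], e_true: Sequence[float]) -> float:
    """L_align = 1 - sim(e_hat_t, e_t), where e_t comes from the frozen text encoder."""
    return 1.0 - cosine_sim(e_hat, e_true)


def unlearn_loss(e_hat: Sequence[float], e_unlearn: Sequence[float], tau: float) -> float:
    """L_unlearn = max(0, sim(e_hat_t, e_unlearn) - tau); zero once sim <= tau."""
    return max(0.0, cosine_sim(e_hat, e_unlearn) - tau)


if __name__ == "__main__":
    # Toy 2-D example: e_hat has cosine similarity 0.6 with the target embedding.
    e_unlearn = [1.0, 0.0]
    e_hat = [3.0, 4.0]
    print(unlearn_loss(e_hat, e_unlearn, tau=0.2))  # about 0.4: similarity exceeds the threshold
    print(unlearn_loss(e_hat, e_unlearn, tau=0.7))  # 0.0: hinge inactive
```

With \(\tau = 0.2\), a prediction at similarity 0.6 to the target is penalized by about 0.4, while with \(\tau = 0.7\) the hinge is inactive and contributes no gradient. This is how the threshold keeps the model from being pushed further than necessary.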

Loss & Training

  • Alignment phase: \(\theta_E^* = \arg\min_{\theta_E} \mathbb{E}[\mathcal{L}_{\text{align}}]\), training only the embedding head.
  • Unlearning phase: \(\theta_M^* = \arg\min_{\theta_M} \mathbb{E}[\mathcal{L}_{\text{unlearn}}]\), freezing the embedding head and fine-tuning the LLM.
  • Threshold decay: \(\tau\) is gradually decreased over training to achieve progressive unlearning.
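The alternating procedure can be outlined as a loop. This is a hedged sketch, not the authors' training code: `align_step` and `unlearn_step` are hypothetical callbacks standing in for one optimizer step on \(\mathcal{L}_{\text{align}}\) (head only) and on \(\mathcal{L}_{\text{unlearn}}\) (LLM only, head frozen), and the linear decay of \(\tau\) from `tau_start` to `tau_end` is an assumed schedule, since the exact decay rule is not specified here.

```python
from typing import Callable, List, Optional, Tuple


def tau_at(round_idx: int, num_rounds: int, tau_start: float, tau_end: float) -> float:
    """Assumed linear decay: tau_start in round 0, tau_end in the last round."""
    if num_rounds == 1:
        return tau_end
    frac = round_idx / (num_rounds - 1)
    return tau_start + frac * (tau_end - tau_start)


def run_align_then_unlearn(
    align_step: Callable[[], None],       # hypothetical: one step on L_align, updates theta_E only
    unlearn_step: Callable[[float], None],  # hypothetical: one step on L_unlearn with given tau, updates theta_M only
    num_rounds: int,
    steps_per_phase: int,
    tau_start: float,
    tau_end: float,
) -> List[Tuple[str, int, Optional[float]]]:
    """Alternate (re-)alignment of the head with unlearning updates of the LLM."""
    if num_rounds < 1 or steps_per_phase < 1:
        raise ValueError("num_rounds and steps_per_phase must be >= 1")
    log: List[Tuple[str, int, Optional[float]]] = []
    for r in range(num_rounds):
        # (Re-)alignment: the LLM is fixed, only the embedding head E is trained.
        for _ in range(steps_per_phase):
            align_step()
        log.append(("align", r, None))
        # Unlearning: E is frozen, the LLM is trained against the current threshold.
        tau = tau_at(r, num_rounds, tau_start, tau_end)
        for _ in range(steps_per_phase):
            unlearn_step(tau)
        log.append(("unlearn", r, tau))
    return log


if __name__ == "__main__":
    seen_taus: List[float] = []
    schedule = run_align_then_unlearn(
        align_step=lambda: None,          # placeholder: no real model here
        unlearn_step=seen_taus.append,    # placeholder: records tau instead of training
        num_rounds=4, steps_per_phase=1, tau_start=0.5, tau_end=0.2,
    )
    for entry in schedule:
        print(entry)
```

Each round begins with alignment, so the first round plays the role of Phase 1 and every later alignment phase is the re-alignment step that restores the head's detection ability after the LLM has moved.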

Key Experimental Results

Main Results

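Results on the RWKU benchmark with Phi-3-mini-4k-instruct as the base model, averaged over 15 unlearning targets and compared with SOTA baselines (FB = fill-in-the-blank, QA = question answering, AA = adversarial-attack probes; ↓ lower is better, ↑ higher is better):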

Method | Forget FB ↓ | Forget QA ↓ | Forget AA ↓ | Neighbor QA ↑ | MMLU ↑
Before Unlearning | 47.1 | 47.4 | 55.8 | 61.4 | 64.4
GA (Full) | 17.8 | 14.3 | 26.3 | 51.7 | 64.3
DPO (Full) | 25.0 | 19.1 | 29.9 | 39.6 | 63.0
NPO (Full) | 22.5 | 16.9 | 27.3 | 53.6 | 64.2
ATU (20%) | 13.5 | 15.3 | 25.9 | 52.3 | 64.5
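A quick way to read the table is to compute each method's change relative to Before Unlearning. The snippet below does that arithmetic on the reported averages; the values are copied from the table, and the dictionary keys and the `deltas` helper are illustrative names only.

```python
# Reported RWKU averages from the table above (Phi-3-mini-4k-instruct, 15 targets).
RESULTS = {
    "Before Unlearning": {"forget_fb": 47.1, "forget_qa": 47.4, "forget_aa": 55.8, "neighbor_qa": 61.4, "mmlu": 64.4},
    "GA (Full)": {"forget_fb": 17.8, "forget_qa": 14.3, "forget_aa": 26.3, "neighbor_qa": 51.7, "mmlu": 64.3},
    "DPO (Full)": {"forget_fb": 25.0, "forget_qa": 19.1, "forget_aa": 29.9, "neighbor_qa": 39.6, "mmlu": 63.0},
    "NPO (Full)": {"forget_fb": 22.5, "forget_qa": 16.9, "forget_aa": 27.3, "neighbor_qa": 53.6, "mmlu": 64.2},
    "ATU (20%)": {"forget_fb": 13.5, "forget_qa": 15.3, "forget_aa": 25.9, "neighbor_qa": 52.3, "mmlu": 64.5},
}


def deltas(results, baseline="Before Unlearning"):
    """Per method, the change of each metric versus the baseline (method minus baseline), in points."""
    base = results[baseline]
    out = {}
    for name, metrics in results.items():
        if name == baseline:
            continue
        out[name] = {m: round(v - base[m], 1) for m, v in metrics.items()}
    return out


if __name__ == "__main__":
    for name, m in deltas(RESULTS).items():
        print(f"{name:10s} forget FB {m['forget_fb']:+.1f}  neighbor QA {m['neighbor_qa']:+.1f}  MMLU {m['mmlu']:+.1f}")
```

By this measure, ATU (20%) shows the largest Forget FB reduction (33.6 points) at a Neighbor QA cost of 9.1 points, while NPO loses less neighbor knowledge (7.8 points) with a smaller Forget FB reduction (24.6 points).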

Ablation Study

Config | Forget QA ↓ | Neighbor QA ↑ | MMLU ↑ | Description
ATU (50% threshold) | 40.5 | 64.4 | 64.2 | Mild unlearning, best retention of neighbor knowledge
ATU (35% threshold) | 24.8 | 56.4 | 64.8 | Moderate unlearning
ATU (20% threshold) | 15.3 | 52.3 | 64.5 | Deep unlearning
Layer 10 | 54.32* | - | - | Unlearning effect varies across layers
Layer 20 | 12.40* | - | - | Some targets have the best effect in middle layers

*Results for a single target (Warren Buffett).

Key Findings

  • At the 20% threshold, ATU achieves the lowest Forget FB (13.5%) and Forget AA (25.9%) among all methods while keeping MMLU at 64.5% (marginally above the original model's 64.4%); GA reaches a slightly lower Forget QA (14.3% vs. 15.3%).
  • The unlearning effect varies widely depending on the layer used, and the most effective layer differs between targets, implying that concept knowledge is unevenly distributed within the network.
  • A persistent trade-off exists between unlearning performance and neighbor knowledge preservation.

Highlights & Insights

  • The shift from token-level to concept-level unlearning is an elegant change of perspective: defining the forget target via a single embedding vector is highly data-efficient and removes the need to construct large-scale forget sets.
  • The adversarial alternating training is cleverly designed: the embedding head continually "catches up" with changes in the LLM, which forces unlearning into deeper representations instead of leaving a shallow output mask. The setup is conceptually similar to GAN training, applied here to unlearning.
  • The threshold \(\tau\) acts as a tunable knob for the trade-off between unlearning strength and retained performance, offering more controllability than most alternatives.

Limitations & Future Work

  • The threshold \(\tau\) has no adaptive adjustment mechanism, so a value that works well for one target may transfer poorly to others.
  • The loss of neighbor knowledge remains significant, suggesting that entanglement among concepts in the embedding space is difficult to avoid entirely.
  • Evaluated only on Phi-3-mini, leaving the performance on larger models (e.g., 70B+) untested.
  • Currently focuses solely on entity-level unlearning (names); the efficacy of unlearning for more complex concepts (e.g., technical knowledge, reasoning patterns) remains unknown.
  • Robustness against membership inference attacks is not discussed.

Comparison with Related Work

  • vs. GA/NPO: Token-level gradient-ascent-style methods can rapidly lower forget scores but are vulnerable to rephrasings; ATU operates in the embedding space and is expected to be more robust, although this robustness is not quantified.
  • vs. DPO: DPO requires positive and negative sample pairs, whereas ATU only requires a single concept description.
  • vs. ICU: ICU is considerably weaker (its Forget QA only drops to 34.6%, versus 15.3% for ATU at the 20% threshold), and ATU clearly outperforms it.

Rating

  • Novelty: ⭐⭐⭐⭐ The perspective of embedding-space unlearning is novel, and the adversarial alternating training design is clever.
  • Experimental Thoroughness: ⭐⭐⭐ Only evaluated on a single benchmark and model; lacks quantitative comparisons regarding rephrasing robustness.
  • Writing Quality: ⭐⭐⭐⭐ Clear, concise, and highly intuitive diagrams.
  • Value: ⭐⭐⭐⭐ Proposes a promising paradigm for concept-level unlearning.