SAEmnesia: Erasing Concepts in Diffusion Models with Supervised Sparse Autoencoders
Conference: ICML 2026
arXiv: 2509.21379
Code: https://github.com/EIDOSLAB/SAEmnesia
Area: AI Safety / Concept Unlearning / Diffusion Model Interpretability
Keywords: Concept Erasing, Sparse Autoencoders, Supervised Training, Feature Centralization, Diffusion Models
TL;DR
By adding a supervised "concept-latent" assignment loss during Sparse Autoencoder (SAE) training, each target concept is forced to concentrate on a single neuron (latent), which the authors call feature centralization. This reduces concept erasing in diffusion models from a 2D hyperparameter search ("which latents" × "what strength") to tuning a single multiplier. On UnlearnCanvas, it achieves a 9.22-point average improvement over SAeUron, the previous state-of-the-art SAE-based method. It also cuts hyperparameter search cost by 96.67% (7 vs. 210 evaluations) and is more robust to adversarial attacks.
Background & Motivation
Background: The safe deployment of text-to-image diffusion models (SD series) requires "concept unlearning"—selectively erasing unwanted concepts like nudity, copyrighted characters, or specific objects while maintaining other generative capabilities. Current approaches fall into two categories: (i) fine-tuning model weights (ESD, UCE, SalUn, etc.); (ii) frozen weights with mechanistic intervention on cross-attention activations using sparse autoencoders (SAE) (Concept Steerers, SAeUron). The latter is reversible (removing the SAE restores the model fully) and interpretable.
Limitations of Prior Work: SAeUron, the strongest representative of the SAE route, trains its SAE in an unsupervised manner. This leads to feature splitting: one concept (e.g., "Bears") is scattered across multiple latents. Erasing it therefore requires two things. (1) Searching thousands of latents for the combination that represents "Bears" (30 latent subsets × 7 strengths = 210 trials in the reported experiments). (2) Managing overlap between the latents of related concepts, which risks damaging nearby concepts such as "Cats."
Key Challenge: Unsupervised SAE training guarantees neither monosemanticity (one neuron responds to one concept) nor a one-to-one mapping (one concept corresponds to one neuron), and the latter has been consistently missing. Without a one-to-one mapping, mechanistic intervention requires combinatorial search, and interpretability remains post hoc.
Goal: Bind each target concept to a unique latent during training while maintaining SAE reconstruction quality, simplifying inference-time erasure to a single scalar negation.
Key Insight: The data used to train the SAE already provides supervision signals, because anchor prompts (e.g., An image of Bears) naturally carry concept labels. Feeding these labels into SAE training is more direct than aligning latents to concepts post hoc with score functions.
Core Idea: Add two supervised losses to the standard TopK SAE loss—a Concept Assignment loss to push each concept's activation to a designated latent, and a Decorrelation loss to suppress activation correlation between distinct macro-classes (objects vs. styles).
Method
Overall Architecture¶
The SAEmnesia pipeline is attached after the frozen Stable Diffusion v1.5 cross-attention block up.1.1:
- Activation Collection: For each target concept \(c\), 80 anchor prompts (e.g., An image of Bears) are used to generate 50-step denoising trajectories. Cross-attention feature maps \(\mathbf{F}_t \in \mathbb{R}^{h\times w\times d}\) are extracted from up.1.1 at each step \(t\). The \(d\)-dimensional vector at each spatial position serves as one training sample, labeled with the corresponding object/style.
- Two-stage Training: (i) standard unsupervised TopK SAE training (reconstruction loss + auxiliary loss for dead latents); (ii) finetuning with the composite SAEmnesia loss to reinforce the concept-latent binding.
- Concept-Latent Assignment \(\Phi\): Before supervised finetuning, a score function computed from the SAE's latent activations assigns to each concept \(c\) the latent index \(i_c = \Phi(c)\) with the highest score.
- Inference-time Erasure: To delete concept \(c\), the activation of latent \(i_c\) is multiplied by a negative scalar \(\gamma_c < 0\), then decoded back to the activation space and fed back to the diffusion backbone.
The process leaves the diffusion backbone unchanged. The SAE is hot-swappable, allowing switching between "forget" and "recover" with one line of code.
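The activation-collection step can be sketched as follows. This is a minimal plain-Python sketch. It assumes the \(h\times w\times d\) feature maps have already been captured from up.1.1 (for instance via a forward hook, not shown) and stored in a dict keyed by (prompt, timestep). The container layout and the helper names flatten_feature_map and build_dataset are hypothetical.

```python
from typing import Dict, List, Tuple

# Hypothetical layout: captures[(prompt, t)] holds the h x w x d feature map
# taken from the up.1.1 cross-attention block at denoising step t.
FeatureMap = List[List[List[float]]]
Sample = Tuple[List[float], str, int]  # (d-dim vector, concept label, timestep)


def flatten_feature_map(fmap: FeatureMap, label: str, t: int) -> List[Sample]:
    """Every spatial position (row, col) becomes one d-dimensional training sample."""
    return [(list(vec), label, t) for row in fmap for vec in row]


def build_dataset(captures: Dict[Tuple[str, int], FeatureMap],
                  prompt_to_concept: Dict[str, str]) -> List[Sample]:
    """Label each sample with the concept named by the anchor prompt that produced it."""
    dataset: List[Sample] = []
    for (prompt, t), fmap in captures.items():
        dataset.extend(flatten_feature_map(fmap, prompt_to_concept[prompt], t))
    return dataset


# Toy usage: one 2 x 3 map with d = 4 at step t = 0 yields 2 * 3 = 6 labeled samples.
toy_map = [[[0.1 * (r + c + k) for k in range(4)] for c in range(3)] for r in range(2)]
data = build_dataset({("An image of Bears", 0): toy_map}, {"An image of Bears": "Bears"})
print(len(data), data[0][1], len(data[0][0]))  # -> 6 Bears 4
```

Keeping the timestep on each sample matters because the score function and the inference-time threshold are both defined per timestep \(t\).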
Key Designs¶
- Supervised Concept-Latent Assignment and Concept Assignment Loss:
  - Function: Turns the mapping of "which latent encodes which concept" from a post-hoc attribution into a hard training constraint.
  - Mechanism: The score function \(\text{score}(i,t,c,D) = \frac{\mu(i,t,D_c)}{\sum_j \mu(j,t,D_c)+\delta} - \frac{\mu(i,t,D_{\neg c})}{\sum_j \mu(j,t,D_{\neg c})+\delta}\) measures how exclusively latent \(i\) responds to concept \(c\) at timestep \(t\). Here \(\mu(i,t,D)\) is the mean activation of latent \(i\) at timestep \(t\) over dataset \(D\), \(D_c\) and \(D_{\neg c}\) are the samples with and without concept \(c\), and \(\delta\) is a small constant. The latent with the maximum score becomes the designated slot \(i_c\). During training, \(\mathcal{L}_{\text{CA}} = -\frac{1}{B}\sum_b \frac{1}{|\mathcal{T}^{(b)}|}\sum_{c \in \mathcal{T}^{(b)}} \log \sigma(v^{(b)}_{i_c})\) is applied, where \(\mathcal{T}^{(b)}\) is the set of target concepts present in sample \(b\) and \(\sigma\) is the sigmoid. This pushes the pre-activation \(v_{i_c}\) to activate strongly whenever the concept is present.
  - Design Motivation: The CA loss acts only on specific concept-latent pairs, providing local, sparse supervision that does not disrupt unsupervised feature learning.
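The following is a minimal plain-Python sketch of the score-based assignment \(\Phi\) and the Concept Assignment loss, written directly from the formulas above. It makes three assumptions:

- The mean activations \(\mu\) have already been computed for one fixed timestep \(t\). How scores are aggregated across timesteps is not specified here.
- Samples with no labeled target concept are skipped. This is an assumption, since the formula leaves \(|\mathcal{T}^{(b)}| = 0\) undefined.
- All function names are hypothetical.

```python
import math
from typing import Dict, List, Sequence


def score(mu_c: Sequence[float], mu_not_c: Sequence[float], i: int, delta: float = 1e-8) -> float:
    """score(i, t, c, D) for one fixed timestep t.

    mu_c[j]     : mean activation of latent j over D_c (samples containing c)
    mu_not_c[j] : mean activation of latent j over D_not_c (all other samples)
    """
    return mu_c[i] / (sum(mu_c) + delta) - mu_not_c[i] / (sum(mu_not_c) + delta)


def assign_latent(mu_c: Sequence[float], mu_not_c: Sequence[float]) -> int:
    """Phi(c): index of the highest-scoring latent, used as the designated slot i_c."""
    return max(range(len(mu_c)), key=lambda i: score(mu_c, mu_not_c, i))


def log_sigmoid(x: float) -> float:
    """Numerically stable log(sigmoid(x))."""
    if x >= 0:
        return -math.log1p(math.exp(-x))
    return x - math.log1p(math.exp(x))


def concept_assignment_loss(pre_acts: List[List[float]],
                            present: List[List[str]],
                            slot: Dict[str, int]) -> float:
    """L_CA = -(1/B) * sum_b (1/|T_b|) * sum_{c in T_b} log sigmoid(v_b[i_c]).

    pre_acts[b] : SAE pre-activations v^(b) of sample b
    present[b]  : target concepts T^(b) labeled for sample b
    slot[c]     : designated latent i_c for concept c
    """
    total = 0.0
    for v, concepts in zip(pre_acts, present):
        if concepts:  # assumption: samples without a target concept contribute 0
            total += sum(log_sigmoid(v[slot[c]]) for c in concepts) / len(concepts)
    return -total / len(pre_acts)


i_bears = assign_latent(mu_c=[0.1, 0.9, 0.3], mu_not_c=[0.2, 0.1, 0.5])
print(i_bears)  # -> 1
weak = concept_assignment_loss([[0.0, -2.0, 0.0]], [["Bears"]], {"Bears": i_bears})
strong = concept_assignment_loss([[0.0, 4.0, 0.0]], [["Bears"]], {"Bears": i_bears})
print(weak > strong)  # -> True: the loss falls as v_{i_c} grows
```

Because \(-\log\sigma(v)\) decreases monotonically in \(v\), minimizing \(\mathcal{L}_{\text{CA}}\) only ever pushes the designated pre-activation upward. Other latents receive no direct supervision from this term.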
- Cross-Macro-Group Decorrelation Constraint:
  - Function: Prevents designated latents from different macro-classes (e.g., objects vs. styles) from co-activating, which would cause interference during intervention.
  - Mechanism: Concepts are grouped into disjoint macro-classes \(\mathcal{C} = \bigcup_m \mathcal{C}_m\). Let \(\mathcal{I}_m = \{i_c : c \in \mathcal{C}_m\}\) be the designated latents of group \(m\). For each designated latent \(i\), let \(\mathbf{a}_i = [v_{i}^{(1)}, \dots, v_{i}^{(B)}]^\top\) be its pre-activations over a batch of size \(B\). The Pearson correlation \(\rho\) between latents from different groups is penalized via \(\mathcal{L}_{\text{DC}} = \frac{\sum_{m<m'} \sum_{i\in\mathcal{I}_m, j\in\mathcal{I}_{m'}} \rho(\mathbf{a}_i, \mathbf{a}_j)}{\sum_{m<m'}|\mathcal{I}_m||\mathcal{I}_{m'}|}\).
  - Design Motivation: Full decorrelation might destroy natural semantic relationships (e.g., "Cats" vs. "Dogs"), so the penalty is applied only across macro-groups. This ensures that erasing an object does not affect a style.
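The decorrelation penalty can be sketched the same way. This sketch follows the formula as written, using the signed Pearson \(\rho\). Whether the reference implementation uses \(\rho\), \(|\rho|\), or \(\rho^2\) is not stated here, so treat that choice as an assumption. The group names, latent indices, and function names are hypothetical.

```python
import math
from itertools import combinations
from typing import Dict, List, Sequence


def pearson(a: Sequence[float], b: Sequence[float], eps: float = 1e-8) -> float:
    """Pearson correlation; eps guards against a constant (zero-variance) latent."""
    n = len(a)
    ma, mb = sum(a) / n, sum(b) / n
    cov = sum((x - ma) * (y - mb) for x, y in zip(a, b))
    var_a = sum((x - ma) ** 2 for x in a)
    var_b = sum((y - mb) ** 2 for y in b)
    return cov / (math.sqrt(var_a * var_b) + eps)


def decorrelation_loss(pre_acts: List[List[float]], groups: Dict[str, List[int]]) -> float:
    """L_DC: mean signed correlation over all cross-group pairs of designated latents.

    pre_acts[b][i] : pre-activation v_i of sample b in the batch
    groups[m]      : designated latent indices I_m of macro-group m
    """
    total, n_pairs = 0.0, 0
    for g1, g2 in combinations(groups, 2):      # every group pair m < m'
        for i in groups[g1]:
            for j in groups[g2]:
                a_i = [v[i] for v in pre_acts]  # activation vector a_i over the batch
                a_j = [v[j] for v in pre_acts]
                total += pearson(a_i, a_j)
                n_pairs += 1                    # accumulates sum |I_m| * |I_m'|
    return total / n_pairs if n_pairs else 0.0


# Toy batch of 3 samples: latent 0 (an object slot) and latent 1 (a style slot)
# rise together, so the penalty is close to +1.
batch = [[1.0, 2.0], [2.0, 4.0], [3.0, 6.0]]
print(round(decorrelation_loss(batch, {"objects": [0], "styles": [1]}), 4))  # -> 1.0
```

Pairs inside the same macro-group never enter the sum. This is exactly why related concepts such as "Cats" and "Dogs" keep their natural correlation under this penalty.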
- Single Latent Thresholding (Inference-time steering):
  - Function: Reduces erasure to a scalar multiplication that is active only when the latent actually fires.
  - Mechanism: \(z_{i_c} \leftarrow \gamma_c\, \mu(i_c, t, D_c)\, z_{i_c}\), applied only when \(z_{i_c} > \mu(i_c, t, D)\), i.e., when the activation exceeds the latent's mean activation over the full dataset \(D\) at timestep \(t\). \(\gamma_c < 0\) is the only tunable hyperparameter.
  - Design Motivation: Thresholding acts as a guardrail. It ensures that intervention occurs only when the model is actually generating the target concept, minimizing collateral damage.
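Here is a sketch of the thresholded single-latent edit on one SAE latent code, following the Mechanism line above. Decoding the edited code back to activation space and re-injecting it into the U-Net are not shown. The function steer_latent and its argument names are hypothetical.

```python
from typing import List


def steer_latent(z: List[float], i_c: int, gamma_c: float,
                 mu_c: float, mu_all: float) -> List[float]:
    """Thresholded single-latent edit for one SAE code z at timestep t.

    gamma_c : negative multiplier, the only tuned hyperparameter
    mu_c    : mu(i_c, t, D_c), mean activation of i_c on samples of concept c
    mu_all  : mu(i_c, t, D), mean activation of i_c over the whole dataset
    """
    if gamma_c >= 0:
        raise ValueError("gamma_c must be negative for erasure")
    out = list(z)                      # leave the caller's code untouched
    if out[i_c] > mu_all:              # guardrail: edit only when the latent truly fires
        out[i_c] = gamma_c * mu_c * out[i_c]
    return out


firing = steer_latent([0.0, 3.0, 0.5], i_c=1, gamma_c=-2.0, mu_c=1.5, mu_all=0.4)
quiet = steer_latent([0.0, 0.2, 0.5], i_c=1, gamma_c=-2.0, mu_c=1.5, mu_all=0.4)
print(firing, quiet)  # -> [0.0, -9.0, 0.5] [0.0, 0.2, 0.5]
```

Note that the threshold uses the dataset-wide mean \(\mu(i_c, t, D)\), while the rescaling uses the concept mean \(\mu(i_c, t, D_c)\). All other latents pass through unchanged.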
Loss & Training
The complete objective is \(\mathcal{L}_{\text{SAEmnesia}} = \mathcal{L}_{\text{unsupSAE}} + \beta(\mathcal{L}_{\text{CA}} + \eta \mathcal{L}_{\text{DC}}) + \lambda \mathcal{L}_{L_1}\). Training proceeds in two stages: pure unsupervised pre-training, followed by supervised finetuning. Activations are sourced from SD v1.5 up.1.1 across all timesteps.
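The following sketch shows how the composite objective could be assembled. It makes these assumptions:

- A common TopK SAE formulation: pre-activation \(v = W_{\text{enc}}(x - b_{\text{dec}}) + b_{\text{enc}}\), keep the \(k\) largest entries, apply ReLU, then decode linearly.
- Plain MSE stands in for \(\mathcal{L}_{\text{unsupSAE}}\), and the dead-latent auxiliary loss is omitted.
- The weights \(\beta, \eta, \lambda\) in the usage lines are placeholders, not the paper's values.
- Stage (i) corresponds to dropping the supervised terms (\(\beta = 0\) here). Whether the \(L_1\) term is already active in stage (i) is not stated.
- All function names are hypothetical.

```python
from typing import List, Sequence, Tuple


def topk_encode(x: Sequence[float], W_enc: List[List[float]], b_enc: Sequence[float],
                b_dec: Sequence[float], k: int) -> Tuple[List[float], List[float]]:
    """Return (pre-activations v, sparse code z); W_enc has one row of length d per latent."""
    centered = [xi - bi for xi, bi in zip(x, b_dec)]
    v = [sum(w * c for w, c in zip(row, centered)) + b for row, b in zip(W_enc, b_enc)]
    top = sorted(range(len(v)), key=lambda i: v[i], reverse=True)[:k]
    z = [0.0] * len(v)
    for i in top:                      # keep the k largest entries, ReLU'd
        z[i] = max(v[i], 0.0)
    return v, z


def decode(z: Sequence[float], W_dec: List[List[float]], b_dec: Sequence[float]) -> List[float]:
    """x_hat = b_dec + sum_i z_i * W_dec[i]; W_dec has one row of length d per latent."""
    x_hat = list(b_dec)
    for zi, row in zip(z, W_dec):
        if zi:
            x_hat = [xh + zi * w for xh, w in zip(x_hat, row)]
    return x_hat


def mse(x: Sequence[float], x_hat: Sequence[float]) -> float:
    return sum((a - b) ** 2 for a, b in zip(x, x_hat)) / len(x)


def saemnesia_objective(l_unsup: float, l_ca: float, l_dc: float, l_l1: float,
                        beta: float, eta: float, lam: float) -> float:
    """L = L_unsupSAE + beta * (L_CA + eta * L_DC) + lambda * L_L1."""
    return l_unsup + beta * (l_ca + eta * l_dc) + lam * l_l1


x = [1.0, 2.0]
W_enc = [[1.0, 0.0], [0.0, 1.0], [-1.0, -1.0]]   # 3 latents, d = 2
v, z = topk_encode(x, W_enc, b_enc=[0.0] * 3, b_dec=[0.0, 0.0], k=1)
x_hat = decode(z, W_dec=[[1.0, 0.0], [0.0, 1.0], [0.0, 0.0]], b_dec=[0.0, 0.0])
loss = saemnesia_objective(mse(x, x_hat), l_ca=1.0, l_dc=0.2,
                           l_l1=sum(abs(zi) for zi in z), beta=0.1, eta=2.0, lam=0.01)
print(v, z, x_hat)  # -> [1.0, 2.0, -3.0] [0.0, 2.0, 0.0] [0.0, 2.0]
```

In this sketch, \(\mathcal{L}_{\text{CA}}\) and \(\mathcal{L}_{\text{DC}}\) would be computed from the same pre-activations \(v\) that topk_encode returns, so the supervised terms read the latents before TopK sparsification.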
Key Experimental Results
Main Results
Object erasure on UnlearnCanvas (percentage; UA = Unlearning Accuracy, IRA/CRA = In-/Cross-domain Retain Accuracy):
| Method Group | Method | UA ↑ | IRA ↑ | CRA ↑ | Avg ↑ |
|---|---|---|---|---|---|
| Fine-tune | ESD | 92.15 | 55.78 | 44.23 | 64.05 |
| Fine-tune | SalUn | 86.91 | 96.35 | 99.59 | 94.28 |
| Adapter | SPM | 71.25 | 90.79 | 81.65 | 81.23 |
| Unsupervised SAE | SAeUron | 87.16 | 85.57 | 74.14 | 82.29 |
| Supervised SAE | Ours | 94.65 | 91.39 | 88.48 | 91.51 |
Compared to SAeUron: UA +7.49, IRA +5.82, CRA +14.34, Avg +9.22.
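The Avg column and the gains quoted above can be re-derived from the table. This small check uses plain Python, with values copied from the table and nothing else assumed:

```python
# UA, IRA, CRA (%) copied from the UnlearnCanvas object-erasure table above.
results = {
    "ESD": (92.15, 55.78, 44.23),
    "SalUn": (86.91, 96.35, 99.59),
    "SPM": (71.25, 90.79, 81.65),
    "SAeUron": (87.16, 85.57, 74.14),
    "Ours": (94.65, 91.39, 88.48),
}

# Avg column = arithmetic mean of the three metrics, rounded to 2 decimals.
averages = {name: round(sum(m) / 3, 2) for name, m in results.items()}

# Per-metric gains of Ours over SAeUron, plus the gain in the (unrounded) average.
gains = [round(o - s, 2) for o, s in zip(results["Ours"], results["SAeUron"])]
avg_gain = round(sum(results["Ours"]) / 3 - sum(results["SAeUron"]) / 3, 2)

print(averages)
print(gains, avg_gain)  # -> [7.49, 5.82, 14.34] 9.22
```

The check also makes explicit that SalUn has the highest Avg (94.28) in this table. The 9.22-point gain is measured against SAeUron specifically.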
Ablation Study
| Configuration / Scenario | Key Metric (SAEmnesia vs. baseline) | Description |
|---|---|---|
| Hyperparam Search Cost | 7 vs. 210 evals | SAEmnesia only tunes the multiplier \(\gamma_c\); cost reduced by 96.67% |
| Sequential Erasure (9 objs) | UA 92.4% vs. 64.0% | +28.4 points over baseline; RA 60.9% vs. 48.4% |
| White-box Attack (UnlearnDiffAtk) | UA 57.50% vs. 34.20% | Drop of 40.1 pts vs. SAeUron drop of 49.5 pts |
| Black-box Attack (Ring-A-Bell) | UA 97.0% vs. 79.5% | Robustness holds across threat models |
| NSFW Mitigation (I2P) | 9 detected vs. 18 | Using only 2 latents ("naked man"/"naked woman") |
| Feature score distribution | 0.0404 vs. 0.0166 | 2.43× increase in score peak after supervised training |
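The derived figures in this table follow from the raw numbers. The one interpretive assumption is that SAEmnesia's 7 evaluations are the same 7 strength values SAeUron sweeps for each latent subset.

```python
# SAeUron-style search (as reported above): 30 latent subsets x 7 strengths.
saeuron_evals = 30 * 7
# SAEmnesia: one designated latent, so only the 7 multiplier values remain (assumed reading).
saemnesia_evals = 7

reduction = 1 - saemnesia_evals / saeuron_evals
score_ratio = 0.0404 / 0.0166          # score peak after vs. before supervised training
seq_gain = 92.4 - 64.0                 # sequential-erasure UA gain, in points

print(saeuron_evals, f"{reduction:.2%}")          # -> 210 96.67%
print(round(score_ratio, 2), round(seq_gain, 1))  # -> 2.43 28.4
```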
Key Findings
- CA loss is the primary contributor: The 2.43× score peak is the root of downstream gains in sequential erasure, adversarial robustness, and efficiency.
- Decorrelation works at the macro-group level: The authors state that interference within the same macro-group (e.g., Dogs vs. Cats) remains an open issue.
- Emergent adversarial robustness: One-to-one mapping narrows the attack surface; adversarial prompts must hit a specific latent precisely.
- Uniform multiplier stability: SAEmnesia with a single multiplier shared across all concepts outperforms SAeUron, even when SAeUron's hyperparameters are searched per concept.
Highlights & Insights
- Post-hoc attribution to training-time constraint: Monosemanticity is treated as a training objective rather than just an evaluation metric, transforming mechanistic interpretability from an observation tool to a control tool.
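- Search space reduction: Shifting from a joint search over \(l\) latent subsets and \(m\) multiplier values (\(m \times l\) trials) to a search over the \(m\) multiplier values alone avoids combinatorial explosion, which is vital for sequential concept erasure.
- Activation thresholding guardrail: Decoupling "which latent is chosen" (a high static score) from "when it is edited" (active in the current forward pass) is an easily overlooked key to precision.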
- Macro-group decorrelation: Penalizing correlations based on user-centric semantic groups rather than raw data distributions is a transferable strategy for multi-task SAEs.
Limitations & Future Work
- Structural Limitations: Only validated on U-Net; transformer-based models like FLUX require adaptation.
- Closed vocabulary: Concept-latent bindings are fixed at training time. Adding a new concept requires either retraining or recomputing scores and binding the concept post hoc.
- Within-group interference: Interference between similar concepts (e.g., Dogs vs. Cats) is not fully resolved.
- Storage Cost: Maintaining multiple SAE checkpoints (unsupervised vs. supervised) involves higher engineering overhead.
Related Work & Insights
- vs. SAeUron: Both use SAEs on cross-attention activations; the difference is SAEmnesia introduces supervised assignment to solve feature splitting.
- vs. Concept Steerers: Concept Steerers intervene on text embeddings, while SAEmnesia intervenes on the visual path (U-Net cross-attention activations). This makes it more robust to adversarial prompts crafted to evade text-side defenses (e.g., Ring-A-Bell).
- vs. Fine-tuning (SalUn/ESD): Fine-tuning is irreversible and risks damaging downstream capabilities; SAEs are modular and auditable.
- vs. ScaPre: ScaPre uses spectral trace regularization for sequential unlearning; SAEmnesia provides an alternative by ensuring representation additivity through one-to-one mapping.
Rating
- Novelty: ⭐⭐⭐⭐ Introducing supervision to SAEs for feature centralization is a clear leap, though the innovation is primarily in the loss design.
- Experimental Thoroughness: ⭐⭐⭐⭐ Covers UnlearnCanvas, I2P, adversarial attacks, and sequential erasure, but lacks validation on latest-generation architectures.
- Writing Quality: ⭐⭐⭐⭐ Clear motivation-method-evidence chain with honest discussion of limitations.
- Value: ⭐⭐⭐⭐⭐ Advances mechanistic erasure toward engineering viability (single latent, scalar multiplication, modularity).