Zero-Shot Generalization of Vision-Based RL Without Data Augmentation¶
Conference: ICML 2025
arXiv: 2410.07441
Code: To be confirmed
Area: Reinforcement Learning
Keywords: Disentangled Representation Learning, Associative Memory, Hopfield Networks, Zero-Shot Generalization, Visual Reinforcement Learning, Data Augmentation
TL;DR¶
Proposes ALDA (Associative Latent DisentAnglement), which combines disentangled representation learning with an associative memory mechanism to achieve zero-shot generalization of visual RL to unseen environments. Without any data augmentation, it performs comparably to methods that use tens of millions of external data samples.
Background & Motivation¶
Generalization of visual RL agents to new environments is a long-standing unsolved challenge. Current mainstream methods rely on data augmentation (e.g., random crop, random convolution, image overlay) to prevent overfitting by expanding the coverage of training data. However, this strategy faces fundamental issues:
- Computational overhead grows exponentially with the number of variations: the augmentations must cover every combination of environmental changes.
- Unstable training: Data augmentation can disrupt the stability of RL training.
- Essentially "weak disentanglement": The paper proves mathematically that data augmentation implicitly performs weak disentanglement (separating task-relevant from task-irrelevant variables) but does not achieve complete factorization.
Biological inspiration: Neurons in the hippocampal-entorhinal system, such as grid cells and object-vector cells, each encode a single factor of variation (e.g., distance, direction). This combination of disentangled representation and memory association helps organisms generalize rapidly. The paper argues that disentangled representation alone is insufficient for OOD generalization (as shown by Schott et al., 2022); the crucial missing component is associative memory, which allows OOD inputs to be mapped back to known values dimension by dimension in the disentangled space.
Method¶
Overall Architecture: SAC + ALDA¶
ALDA adds two modules to the standard SAC (Soft Actor-Critic) framework:
- Disentangled Representation Learning (building on QLAE)
- Associative Memory (implicit Hopfield networks)
Disentangled Representation Learning¶
Adopts the discretized latent space design of QLAE (Quantized Latent Autoencoder):
- The encoder \(f_\theta\) maps observations to a continuous latent space.
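- Each latent dimension \(j\) has its own scalar codebook \(V_j\), so the discrete latent space is the product \(Z = V_1 \times \cdots \times V_{n_z}\).
- Each continuous output is discretized by nearest-neighbor quantization, i.e., snapped to the closest value in that dimension's codebook.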
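As a rough illustration of the quantization step, the sketch below snaps each latent dimension to the nearest entry of its own scalar codebook, i.e. \(z_j = \arg\min_{v \in V_j} |f_\theta(x)_j - v|\). It is a minimal NumPy sketch, not the authors' implementation: the function name `quantize`, the array layout, and the example codebook values are assumptions made for illustration.

```python
import numpy as np

def quantize(z_cont, codebooks):
    """Snap each latent dimension to the nearest value in its own scalar codebook.

    z_cont:    (batch, n_z) continuous encoder outputs
    codebooks: (n_z, n_values) one row of scalar values per latent dimension
               (hypothetical layout; every dimension has the same number of values)
    """
    # Distance from every output to every candidate value: (batch, n_z, n_values)
    dist = np.abs(z_cont[:, :, None] - codebooks[None, :, :])
    # Index of the closest codebook value, per sample and per dimension
    idx = dist.argmin(axis=-1)                       # (batch, n_z)
    # Look up codebooks[j, idx[b, j]] for every sample b and dimension j
    n_z = codebooks.shape[0]
    z_q = codebooks[np.arange(n_z)[None, :], idx]    # (batch, n_z)
    return z_q, idx

# Two latent dimensions, three illustrative values each
example_codebooks = np.array([[-1.0, 0.0, 1.0],
                              [0.0, 0.5, 1.0]])
example_z = np.array([[0.4, 0.8],
                      [-0.9, 0.2]])
z_q, idx = quantize(example_z, example_codebooks)
print(z_q)   # [[ 0.  1.] [-1.  0.]]
```

Because every dimension is quantized against its own codebook, an unusual value in one dimension does not affect how the other dimensions are resolved.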
Frame Stacking Processing¶
Visual RL commonly uses frame stacking to encode temporal information, but disentanglement models perform poorly on stacked images. ALDA's solution:
- Collapses \(k\) frames into the batch dimension and encodes them separately into independent disentangled vectors \(z_d \in \mathbb{R}^{Bk \times n_{s_i}}\).
- Reshapes them back to \(\mathbb{R}^{B \times kn_{s_i}}\) and fuses temporal information through a 1D-CNN to obtain the final representation \(z \in \mathbb{R}^{B \times e}\).
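The two reshapes above can be traced with a small shape-only sketch. Everything below is illustrative: the linear `encode` function is a hypothetical stand-in for the disentangling encoder, the sizes are arbitrary, and the temporal fusion is reduced to a single 1D convolution whose kernel spans all \(k\) frames (which is the same as one matrix product). The actual 1D-CNN may well be deeper.

```python
import numpy as np

rng = np.random.default_rng(0)
B, k, C, H, W = 2, 3, 3, 8, 8   # batch, stacked frames, channels, height, width
n, e = 12, 16                   # per-frame latent size, fused representation size

# Hypothetical stand-in for the per-frame encoder: a fixed linear map.
W_enc = rng.normal(size=(C * H * W, n))

def encode(frames):
    """(N, C, H, W) -> (N, n): encode every frame independently."""
    return frames.reshape(frames.shape[0], -1) @ W_enc

obs = rng.normal(size=(B, k, C, H, W))

# 1) Fold the k frames into the batch axis and encode each frame on its own.
z_d = encode(obs.reshape(B * k, C, H, W))     # (B*k, n)

# 2) Unfold so that each sample carries its k per-frame latents side by side.
z_stacked = z_d.reshape(B, k * n)             # (B, k*n)

# 3) Toy temporal fusion: a 1D convolution with n input channels, e output
#    channels and kernel size k over a length-k sequence yields one output
#    position, i.e. a single linear map from k*n values to e values.
W_fuse = rng.normal(size=(k * n, e))
z = z_stacked @ W_fuse                        # (B, e)

print(z_d.shape, z_stacked.shape, z.shape)    # (6, 12) (2, 36) (2, 16)
```

Encoding frames one at a time keeps each per-frame vector disentangled; only the fusion step mixes information across time.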
Associative Memory Mechanism¶
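Core insight of the paper: The quantization operation of QLAE is essentially already a Hopfield network. Under the generalized Hopfield framework, memory retrieval decomposes into three steps: a similarity function that scores the query against the stored memories, a separation function that sharpens those scores, and a projection that maps the result to an output.
In QLAE: similarity = L1 distance, separation = argmin, projection = identity function.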
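For reference, the generalized framework is commonly written as a composition of the three operations. The notation below (\(M\) for the stored memories, \(q\) for the query, \(P\) for the projection) is assumed here and may differ from the paper's:

\[
z = P \cdot \operatorname{sep}\big(\operatorname{sim}(M, q)\big)
\]

Read per latent dimension \(j\), QLAE's quantization then corresponds to \(\operatorname{sim}(V_j, q_j) = |V_j - q_j|\) with \(\operatorname{sep} = \arg\min\), so the retrieved value is the single codebook entry closest to the query.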
Improvement: Replacing argmin with a Softmax separation function, scaled by an inverse temperature \(\beta\), yields the retrieval dynamics of modern Hopfield networks.
Here \(\beta\) controls the degree of memory separation. As \(\beta \to \infty\), the retrieval reduces to the original argmin.
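The effect of \(\beta\) can be checked numerically. The sketch below is a minimal NumPy illustration under stated assumptions, not the paper's code: the function name `hopfield_retrieve` is hypothetical, the similarity is the negative L1 distance to each scalar codebook value, and the output is the Softmax-weighted average of those values.

```python
import numpy as np

def hopfield_retrieve(z_cont, codebooks, beta):
    """Softmax retrieval over per-dimension scalar codebooks.

    z_cont:    (batch, n_z) continuous encoder outputs (the queries)
    codebooks: (n_z, n_values) stored scalar memories, one row per dimension
    beta:      inverse temperature; larger values separate memories more sharply
    """
    dist = np.abs(z_cont[:, :, None] - codebooks[None, :, :])  # (batch, n_z, n_values)
    logits = -beta * dist                                      # closer value -> larger logit
    logits -= logits.max(axis=-1, keepdims=True)               # numerical stability
    weights = np.exp(logits)
    weights /= weights.sum(axis=-1, keepdims=True)             # Softmax over codebook values
    # Weighted average of the stored values, computed independently per dimension
    return (weights * codebooks[None, :, :]).sum(axis=-1)      # (batch, n_z)

codebooks = np.array([[-1.0, 0.0, 1.0],
                      [0.0, 0.5, 1.0]])
query = np.array([[0.4, 0.8]])

soft = hopfield_retrieve(query, codebooks, beta=5.0)    # blend of nearby values
hard = hopfield_retrieve(query, codebooks, beta=1e4)    # approx [[0.0, 1.0]], the argmin result
flat = hopfield_retrieve(query, codebooks, beta=0.0)    # uniform weights: the codebook means
```

Unlike argmin, the Softmax weights are differentiable with respect to both the query and the codebook values, so gradients can reach the codebook through the retrieval step.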
Loss & Training¶
The final objective combines a reconstruction loss, a commitment loss, and weight decay.
Key Design Choices:
- Retain the commitment loss (encoder \(\to\) codebook direction) and remove the quantization loss (codebook \(\to\) encoder direction).
- Explanation: The codebook stays stable as "task-optimized memories," and the encoder learns to map its outputs onto these memories.
- Apply very strong weight decay, \(\lambda_\theta = \lambda_\phi = 0.1\).
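One plausible way to write the objective described above, in the style of standard vector-quantization losses, is sketched below. This is an assumed form rather than the paper's exact equation: \(g_\phi\) is taken to be the decoder, \(z_q\) the retrieved (quantized) latent, \(\operatorname{sg}[\cdot]\) the stop-gradient operator, and \(\lambda_c\) a hypothetical commitment weight. The weight-decay terms may equally be applied through the optimizer rather than as explicit penalties.

\[
\mathcal{L} = \underbrace{\lVert x - g_\phi(z_q) \rVert_2^2}_{\text{reconstruction}} + \lambda_c \underbrace{\lVert f_\theta(x) - \operatorname{sg}[z_q] \rVert_2^2}_{\text{commitment}} + \lambda_\theta \lVert \theta \rVert_2^2 + \lambda_\phi \lVert \phi \rVert_2^2
\]

In this notation, the term that is dropped would be the quantization loss \(\lVert \operatorname{sg}[f_\theta(x)] - z_q \rVert_2^2\), which would otherwise pull the codebook values toward the encoder outputs.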
Data Augmentation = Weak Disentanglement Theorem¶
Theorem 1: If \(Q^*(z, a)\) is invariant to distractor variables, then the dimensions in the latent space encoding task-relevant variables \(D\) and task-irrelevant variables \(E\) must satisfy:
This implies that data augmentation methods are intrinsically performing partial factorization (weak disentanglement) but cannot guarantee complete disentanglement. In contrast, complete disentanglement allows OOD inputs to be mapped back to known values independently on a dimension-by-dimension basis.
Key Experimental Results¶
Experimental Setup¶
- Training environments: 4 tasks from DeepMind Control Suite (walker walk, cartpole balance, finger spin, etc.)
- Evaluation environments: Color Hard (extreme RGB color randomization), DistractingCS (camera shake + random background videos)
- Latent variable dimension \(|z_d| = 12\) (unified across all tasks)
Main Results¶
| Method | Extra Data/Augmentation | Color Hard | DistractingCS | Training Performance |
|---|---|---|---|---|
| ALDA | None | Second best (behind SVEA) | Second best (behind SVEA) | Stable |
| SVEA | 1.8M real scene image overlay | Best | Best | Stable |
| DARLA | None (Two-stage) | Poor | Poor | Unstable |
| SAC+AE | None | Average | Average | Stable |
| RePo | None (Model-based) | Average | Average | Stable |
Key Findings¶
- Without using any external data, ALDA approaches or even matches SVEA (which uses the tens of millions-scale Places dataset) on multiple tasks.
- When SVEA uses other augmentations instead of image overlay, ALDA can outperform SVEA.
- BioAE (another disentanglement method) performs well initially but degrades later, showing that the associative memory mechanism is crucial for maintaining generalization.
- Latent traversal visualization indicates that each dimension indeed encodes a single factor (e.g., torso orientation, hip joint angle, scene color, etc.).
Highlights & Insights¶
- Solid Theoretical Contribution: Mathematically proves that data augmentation \(\equiv\) weak disentanglement, establishing a connection between two seemingly unrelated fields.
- Neuroscience-Inspired System Design: Disentanglement (single-factor neurons in the entorhinal cortex) + Association (memory association in the hippocampus) constitute a complete generalization pipeline.
- QLAE as a Hopfield Network: Reveals the equivalence between quantization operations and associative memory retrieval, achieving superior gradient properties via the Softmax separation function.
- No Discarding of Irrelevant Information: Unlike task-centric representation methods, ALDA retains all variables but encodes them in a disentangled manner. These "irrelevant" details may become useful when the task changes.
- Minimal Changes, Significant Impact: Simply replacing the separation function (argmin \(\to\) Softmax) and removing the quantization loss significantly improves generalization performance.
Limitations & Future Work¶
- Temporal Information Not Disentangled: \(z_d\) only models factors of the image distribution, while temporal information is handled by the downstream 1D-CNN. How to learn disentangled representations that cover both image and temporal factors remains an open question.
- Limited Performance on DistractingCS: Camera shake affects the implicitly learned dynamics, causing significant degradation across all methods.
- Manual Setting of Latent Dimensions: \(|z_d|=12\) is selected empirically, lacking a method to automatically determine the actual number of factors.
- Simple Associative Memory Model: Does not utilize learnable attention-based Hopfield networks; a stronger memory model could potentially yield better performance.
- Limited Evaluation Environments: Only validated on DMControl, lacking high-dimensional manipulation tasks or real-world robot experiments.
- No Quantitative Disentanglement Evaluation: The ground-truth factors are unknown in real tasks, so evaluation is limited to qualitative latent traversal visualization.
Related Work & Insights¶
- DARLA (Higgins et al., 2017b): The first disentanglement-based generalization method for RL; however, it uses two-stage training, and its random-action data collection does not provide sufficient coverage.
- SAC+AE (Yarats et al., 2021b): Deterministic autoencoder + reconstruction loss, showing some generalization capability but with no focus on disentanglement.
- SVEA (Hansen et al., 2021): SOTA among data augmentation methods, which uses the Places dataset and incurs high computational overhead.
- QLAE (Hsu et al., 2023): Current SOTA disentanglement method, which this paper extends with associative memory.
- Modern Hopfield Networks (Ramsauer et al., 2021): Associative memory on continuous representations, demonstrating equivalence with attention mechanisms.
Rating¶
- Novelty: ⭐⭐⭐⭐ — The combination of disentanglement + associative memory is highly novel, and the theoretical proof is valuable.
- Experimental Thoroughness: ⭐⭐⭐ — Thoroughly validated on DMControl, but lacks more complex/real-world scenarios.
- Writing Quality: ⭐⭐⭐⭐ — Clear motivation with a smooth logical transition between theory and methodology.
- Value: ⭐⭐⭐⭐ — Provides a new paradigm for visual RL generalization beyond data augmentation.