M3-JEPA: Multimodal Alignment via Multi-gate MoE based on JEPA
Conference: ICML 2025
arXiv: 2409.05929
Code: GitHub - M3-JEPA
Area: Multimodal VLM
Keywords: JEPA, Mixture-of-Experts, multimodal alignment, energy-based model, alternating gradient descent
TL;DR
Generalizes JEPA (Joint-Embedding Predictive Architecture) to multimodal alignment of arbitrary modality combinations. It utilizes a Multi-gate MoE as a cross-modal predictor to perform alignment in the latent space (rather than the token space), where the gating function decouples modality-specific and shared information. Alternating gradient descent is employed to avoid gradient conflict between multi-directional tasks. With only 140M trainable parameters, it outperforms state-of-the-art models like BLIP-2 (1.2B) on multiple retrieval and classification tasks.
Background & Motivation
Background: Mainstream modern multimodal learning adopts generative architectures, which fall into two categories: one involves training from scratch (such as OFA, BEiT-3), requiring massive data and computation; the other utilizes pretrained LLMs as backbones and fine-tunes lightweight connectors (such as BLIP-2, LLaVA), which is more computationally efficient. Both categories of methods optimize cross-modal alignment in the raw token space.
Limitations of Prior Work: Aligning in the token space is prone to modality collapse, where the signal of one modality dominates the other. Key sources of this issue include: (1) gradient conflicts in multi-directional tasks; (2) distribution mismatch between continuous domains (images/videos) and discrete domains (text); and (3) informational uncertainty and redundancy (e.g., a single image can have multiple semantically equivalent but textually different descriptions). These factors make cross-modal alignment hard to converge and can cause key information to be lost.
Key Challenge: Alignment in the token space is too "superficial" as it requires predicting exact token sequences, whereas cross-modal information naturally exhibits uncertainty and one-to-many mappings. A scheme is needed to align in a more abstract latent space, preserving only the core semantic information shared across modalities.
Goal: (1) How to avoid modality collapse caused by token space alignment? (2) How to design a general any-to-any multimodal alignment framework? (3) How to avoid gradient conflicts between multi-directional tasks (such as image→text and text→image)?
Key Insight: JEPA (Joint-Embedding Predictive Architecture) approaches learning from the perspective of energy-based models. Instead of performing generative prediction in the token space, it uses a predictor to project input embeddings into the output embedding space, conducting alignment in the latent space. While I-JEPA and V-JEPA have validated the effectiveness of this paradigm in visual self-supervised learning, a general multimodal version has not yet been introduced.
Core Idea: Implement the cross-modal predictor of JEPA as a Multi-gate MoE, whose gating function automatically decouples modality-specific and shared information. Combined with alternating gradient descent to prevent multi-task conflicts, this yields the first any-to-any multimodal JEPA alignment framework.
Method
Overall Architecture
Given \(M\) modalities and \(T\) tasks, each modality is embedded by a pretrained unimodal encoder (Llama3-8B for text, DINOv2-Large for images, LanguageBind for audio). For task \(t\), the input embedding \(e_x^t\) and output embedding \(e_y^t\) are produced by the corresponding modality encoders. A Multi-gate MoE predictor \(\mathcal{P}\) projects \(e_x^t\) into the latent space of \(e_y^t\), giving \(e_{x \to y}^t = \mathcal{P}(e_x^t)\), and training minimizes an energy function \(\mathcal{F}^t(x,y)\) between \(e_{x \to y}^t\) and \(e_y^t\) in that latent space. The encoder weights stay frozen apart from 3-layer LoRA adapters (rank 64); the MoE predictor is randomly initialized and trained with all of its parameters.
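The paragraph above keeps the encoders frozen apart from LoRA adapters. As a minimal sketch, assuming the standard LoRA formulation (a frozen weight \(W\) plus a trainable low-rank update \(BA\) scaled by `lora_alpha / r`, with \(B\) initialized to zero), a LoRA-adapted linear layer can look like the following. The class name `LoRALinear`, the `lora_alpha` default, and the initialization scale are illustrative assumptions; the summary only specifies rank 64 on 3 layers.

```python
import random


def matvec(W, x):
    """Multiply a matrix (list of rows) by a vector."""
    return [sum(w * v for w, v in zip(row, x)) for row in W]


class LoRALinear:
    """Hypothetical LoRA layer: y = W x + (lora_alpha / r) * B (A x).

    W stays frozen; only A (r x d_in) and B (d_out x r) would be trained.
    """

    def __init__(self, W, rank, lora_alpha=None, seed=0):
        rng = random.Random(seed)
        d_out, d_in = len(W), len(W[0])
        self.W = W  # frozen pretrained weight
        self.rank = rank
        # Assumption: lora_alpha defaults to the rank, giving a scale of 1.
        self.scale = (lora_alpha if lora_alpha is not None else rank) / rank
        # A starts small and random, B starts at zero, so the adapter is a no-op at init.
        self.A = [[rng.gauss(0.0, 0.01) for _ in range(d_in)] for _ in range(rank)]
        self.B = [[0.0] * rank for _ in range(d_out)]

    def __call__(self, x):
        base = matvec(self.W, x)                    # frozen path
        delta = matvec(self.B, matvec(self.A, x))   # low-rank trainable path
        return [b + self.scale * d for b, d in zip(base, delta)]

    def trainable_params(self):
        """Number of adapter weights: r * (d_in + d_out)."""
        return self.rank * (len(self.W[0]) + len(self.W))
```

Because `B` starts at zero, the adapted layer initially reproduces the frozen encoder exactly. For a hypothetical 1024×1024 projection at rank 64, the adapter trains 64 × (1024 + 1024) = 131,072 weights while the 1,048,576 base weights stay frozen.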
Key Designs
- Multi-gate MoE Cross-Modal Predictor:
  - Function: Acts as a lightweight cross-modal connector, projecting input modality embeddings into the latent space of the output modality.
  - Mechanism: Implements \(N=12\) expert networks for each modality, totaling \(M \times N\) experts, utilizing Top-\(K\) (\(K=4\)) sparse activation. The gating function takes the concatenation of the input embedding \(e_x\) and a learnable modality embedding \(e_m\) as input: \(\mathbb{G} = \text{softmax}(g \cdot [e_x \oplus e_m])\), where \(g\) is a shared projection matrix. Two parallel gates (\(L=2\)) are used to independently serve the contrastive loss and regularization loss. The total number of trainable parameters is only 140M.
  - Design Motivation: (1) Modality-specific paths (modality experts + modality embedding \(e_m\)) capture information unique to each modality; (2) the shared projection matrix \(g\) establishes a shared cross-modal subspace; (3) the lightweight MoE is significantly more efficient than fully fine-tuning the encoders.
- Energy Function with Dual Contrastive and Regularization Loss:
  - Function: Optimizes cross-modal alignment from two complementary perspectives.
  - Mechanism: The regularization loss \(\mathcal{L}_{\text{reg}} = \|e_{x \to y} - e_y\|_2^2\) directly pulls the embeddings of positive pairs together (minimizing the conditional entropy \(\mathcal{H}(y|x)\)). The contrastive loss \(\mathcal{L}_{\text{cl}}\) applies InfoNCE with in-batch negatives, pulling positive pairs together while pushing negative pairs apart (maximizing a lower bound on the mutual information \(\mathcal{I}(x;y)\)). The total loss is \(\mathcal{L} = \alpha \mathcal{L}_{\text{reg}} + (1-\alpha)\mathcal{L}_{\text{cl}}\); the paper's theoretical analysis derives the optimal value \(\alpha = 0.5\) (corresponding to the critical temperature of free-energy minimization), and experiments confirm it.
  - Design Motivation: Using only the regularization loss risks representation collapse (all embeddings converging to the same point), because nothing pushes negative pairs apart; using only the contrastive loss constrains relative similarities but does not directly minimize the distance between positive pairs. Together, the two losses form a complete energy function that, from an information-theoretic perspective, simultaneously maximizes mutual information and minimizes conditional entropy.
- Alternating Gradient Descent (AGD):
  - Function: Resolves gradient conflicts among multi-directional multimodal tasks.
  - Mechanism: Training cycles through the \(T\) tasks step by step. At step \(i\), forward and backward propagation are run only for task \(t = i \bmod T\): \(\theta_{i+1} \leftarrow \theta_i - \eta \nabla_\theta \mathcal{L}^{t}(\theta_i)\). Unlike joint optimization, which combines all task losses into a single update, AGD decouples the parameter updates of each task, avoiding gradient conflicts caused by tasks such as image→text and text→image competing for the same connector weights.
  - Design Motivation: Traditional joint optimization is prone to gradient conflicts in multi-directional tasks (the "seesaw" effect, where improving one task degrades another). AGD is inspired by the success of alternating training schedules in multi-task learning.
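The Multi-gate MoE predictor described above can be sketched in plain Python. This is a toy illustration under stated assumptions, not the authors' implementation: each expert is a single linear map (the expert architecture is not given here), each of the \(L=2\) gates has its own projection \(g\) shared across modalities, and the top-\(K\) gate probabilities are renormalized to sum to 1. The class `MultiGateMoE` and its parameter names are hypothetical.

```python
import math
import random


def matvec(W, x):
    """Multiply a matrix (list of rows) by a vector."""
    return [sum(w * v for w, v in zip(row, x)) for row in W]


def softmax(z):
    m = max(z)
    exps = [math.exp(v - m) for v in z]
    s = sum(exps)
    return [e / s for e in exps]


def rand_matrix(rows, cols, rng, std=0.1):
    return [[rng.gauss(0.0, std) for _ in range(cols)] for _ in range(rows)]


class MultiGateMoE:
    """Hypothetical multi-gate MoE predictor with linear experts."""

    def __init__(self, d_in, d_out, num_modalities, experts_per_modality=12,
                 top_k=4, num_gates=2, d_mod=8, seed=0):
        rng = random.Random(seed)
        self.top_k = top_k
        n_experts = num_modalities * experts_per_modality  # M x N experts
        self.experts = [rand_matrix(d_out, d_in, rng) for _ in range(n_experts)]
        # Learnable modality embeddings e_m, one per modality.
        self.mod_emb = [[rng.gauss(0.0, 0.1) for _ in range(d_mod)]
                        for _ in range(num_modalities)]
        # One gating projection g per gate, shared across modalities (assumption).
        self.gates = [rand_matrix(n_experts, d_in + d_mod, rng)
                      for _ in range(num_gates)]

    def gate_weights(self, gate_idx, e_x, modality):
        """softmax(g @ [e_x ; e_m]), keep the top-K experts, renormalize."""
        concat = e_x + self.mod_emb[modality]  # list concatenation = e_x ⊕ e_m
        probs = softmax(matvec(self.gates[gate_idx], concat))
        top = sorted(range(len(probs)), key=lambda i: probs[i], reverse=True)[: self.top_k]
        total = sum(probs[i] for i in top)
        return {i: probs[i] / total for i in top}

    def __call__(self, e_x, modality):
        """Return one predicted embedding per gate (e.g. contrastive, regularization)."""
        outputs = []
        for gate_idx in range(len(self.gates)):
            out = [0.0] * len(self.experts[0])
            for i, w in self.gate_weights(gate_idx, e_x, modality).items():
                expert_out = matvec(self.experts[i], e_x)
                out = [o + w * y for o, y in zip(out, expert_out)]
            outputs.append(out)
        return outputs
```

For \(M=3\) modalities this builds 36 experts, and each gate activates only 4 of them per input, e.g. `pred_cl, pred_reg = MultiGateMoE(16, 8, 3)(e_x, modality=1)`. Renormalizing the top-\(K\) probabilities is equivalent to applying the softmax over the selected logits only.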
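The AGD update rule can also be illustrated with a small plain-Python sketch. The function `alternating_gradient_descent` and the quadratic toy losses are hypothetical stand-ins: in M3-JEPA each task would be a direction such as image→text trained with the dual energy loss on the shared MoE predictor, whereas here each task simply pulls a shared parameter vector toward its own target.

```python
from typing import Callable, List, Sequence, Tuple

Vector = List[float]
GradFn = Callable[[Vector], Vector]


def alternating_gradient_descent(theta: Vector, task_grads: Sequence[GradFn],
                                 lr: float, steps: int) -> Tuple[Vector, List[int]]:
    """theta_{i+1} = theta_i - lr * grad L^t(theta_i), with t = i mod T.

    Only the current task's gradient is computed and applied at each step;
    the task gradients are never summed into a single joint update.
    """
    T = len(task_grads)
    schedule = []
    for i in range(steps):
        t = i % T                    # task chosen for this step
        grad = task_grads[t](theta)  # forward/backward for task t only
        theta = [p - lr * g for p, g in zip(theta, grad)]
        schedule.append(t)
    return theta, schedule


def quadratic_grad(target: Vector) -> GradFn:
    """Gradient of the hypothetical toy loss ||theta - target||^2."""
    return lambda theta: [2.0 * (p - c) for p, c in zip(theta, target)]
```

With two toy tasks standing in for image→text and text→image, `alternating_gradient_descent([0.0, 0.0], [quadratic_grad([1.0, 0.0]), quadratic_grad([0.0, 1.0])], lr=0.1, steps=4)` visits the tasks in the order `[0, 1, 0, 1]`, applying exactly one task's gradient per step.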
Loss & Training
The total loss is \(\mathcal{L} = 0.5 \cdot \mathcal{L}_{\text{reg}} + 0.5 \cdot \mathcal{L}_{\text{cl}}\), where the regularization loss is the squared L2 distance and the contrastive loss is InfoNCE. Training uses the Adam optimizer with a batch size of 128, a cosine learning-rate schedule, a warmup ratio of 0.1, and a weight decay of 0.005.
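A minimal plain-Python sketch of this dual energy loss follows, assuming one predicted embedding per gate (`pred_reg` for the regularization loss, `pred_cl` for the contrastive loss), cosine similarity inside InfoNCE, and a temperature of 0.07. The normalization and temperature are assumptions borrowed from common CLIP-style practice and are not stated in the summary; all function names are hypothetical.

```python
import math
from typing import List, Sequence

Vector = List[float]


def l2_normalize(v: Vector) -> Vector:
    n = math.sqrt(sum(x * x for x in v)) or 1.0  # guard against zero vectors
    return [x / n for x in v]


def regularization_loss(preds: Sequence[Vector], targets: Sequence[Vector]) -> float:
    """Batch mean of the squared L2 distance ||e_{x->y} - e_y||_2^2."""
    total = 0.0
    for p_vec, y_vec in zip(preds, targets):
        total += sum((p - y) ** 2 for p, y in zip(p_vec, y_vec))
    return total / len(preds)


def info_nce_loss(preds: Sequence[Vector], targets: Sequence[Vector],
                  temperature: float = 0.07) -> float:
    """InfoNCE with in-batch negatives: pair (i, i) is positive, (i, j != i) negative."""
    P = [l2_normalize(p) for p in preds]
    Y = [l2_normalize(y) for y in targets]
    total = 0.0
    for i, p in enumerate(P):
        logits = [sum(a * b for a, b in zip(p, y)) / temperature for y in Y]
        m = max(logits)
        log_denom = m + math.log(sum(math.exp(l - m) for l in logits))
        total += log_denom - logits[i]  # -log softmax(logits)[i]
    return total / len(P)


def dual_energy_loss(pred_reg, pred_cl, targets, alpha=0.5, temperature=0.07):
    """alpha * L_reg + (1 - alpha) * L_cl."""
    return (alpha * regularization_loss(pred_reg, targets)
            + (1 - alpha) * info_nce_loss(pred_cl, targets, temperature))
```

With `alpha=0.5` the two terms are weighted equally. Predictions that match their targets drive both terms toward zero, while a shuffled batch leaves the squared distances large and pushes the InfoNCE term up by roughly `1 / temperature`.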
Key Experimental Results
Main Results: Vision-Language Retrieval (Flickr30K / COCO)
| Method | Trainable Parameters | Flickr30K I→T R@1 | Flickr30K T→I R@1 | COCO I→T R@1 | COCO T→I R@1 |
|---|---|---|---|---|---|
| CLIP | 428M | 88.0 | 68.7 | - | - |
| BLIP-2 (ViT-g) | 1.2B | 97.6 | 89.7 | 85.4 | 68.3 |
| BEiT-3 | 1.9B | 94.9 | 81.5 | 84.8 | 67.2 |
| M3-JEPA | 140M | 97.8 | 97.8 | 87.7 | 89.7 |
M3-JEPA achieves 89.7% on COCO T→I R@1 with only 140M parameters, significantly outperforming BLIP-2's 68.3% (+21.4 pt), and reaches 97.8% in both directions on Flickr30K.
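The retrieval columns report Recall@K (R@K): the percentage of queries whose ground-truth match appears among the top K retrieved candidates. A minimal sketch follows, assuming a precomputed query-by-candidate similarity matrix with exactly one ground-truth candidate per query at the same index. The standard Flickr30K/COCO protocol pairs each image with five captions and counts an image→text hit if any of them is retrieved, which this simplified, hypothetical `recall_at_k` does not model.

```python
from typing import Sequence


def recall_at_k(sim: Sequence[Sequence[float]], k: int) -> float:
    """Percentage of queries (rows) whose ground truth (same index) ranks in the top k.

    sim[i][j] is the similarity between query i and candidate j.
    """
    hits = 0
    for i, row in enumerate(sim):
        ranked = sorted(range(len(row)), key=lambda j: row[j], reverse=True)
        if i in ranked[:k]:
            hits += 1
    return 100.0 * hits / len(sim)
```

For the opposite retrieval direction (e.g. text→image instead of image→text), pass the transposed similarity matrix.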
Ablation Study: Impact of Method Components on COCO Retrieval
| MoE | AGD | I→T R@1 | I→T R@5 | I→T R@10 | T→I R@1 | T→I R@5 | T→I R@10 |
|---|---|---|---|---|---|---|---|
| ✗ | ✓ | 74.4 | 86.0 | 92.2 | 82.3 | 89.5 | 92.6 |
| ✓ | ✗ | 68.2 | 68.7 | 81.1 | 74.2 | 88.7 | 92.4 |
| ✓ | ✓ | 87.7 | 99.6 | 99.9 | 89.7 | 99.7 | 99.9 |
Both MoE and AGD are indispensable: removing MoE (replacing it with an MLP) drops I→T R@1 to 74.4%, and removing AGD drops it to 68.2%, while combining both achieves 87.7%.
Key Findings
- M3-JEPA achieves 86.6% accuracy on ImageNet-1K classification, outperforming CLIP-ViT (82.1%) and DINOv2 (83.2%), demonstrating that the JEPA paradigm can handle classification tasks.
- It also outperforms methods such as LanguageBind on zero-shot audio-text retrieval (Clotho/AudioCaps), demonstrating that the framework extends modularly to new modalities.
- It approaches BLIP-2's performance on VQA (82.3% on VQAv2 test-dev), indicating that the framework also adapts to tasks with multimodal inputs.
- The theoretically predicted optimum \(\alpha=0.5\) matches the best value found in experiments.
Highlights & Insights
- The first any-to-any multimodal JEPA framework, extending I-JEPA/V-JEPA from unimodal self-supervised learning to cross-modal alignment.
- Outperforms the 1.2B parameter BLIP-2 with only 140M parameters: the efficiency advantage of the lightweight MoE predictor + frozen encoders is clear.
- The modality decoupling performed by the MoE gating has an information-theoretic interpretation: the shared matrix \(g\) captures information shared across modalities (mutual information), while the modality embedding \(e_m\) captures modality-specific information (conditional entropy).
- The agreement between the theoretically derived and the experimentally best \(\alpha=0.5\) strengthens the framework's credibility.
Limitations & Future Work
- The quality of gating information decoupling depends on data quality and the representation spaces of modality encoders.
- Frozen encoders limit fine-grained adaptation performance (performance on VQA is slightly below the fully fine-tuned BEiT-3).
- Task switching in AGD increases training complexity and the difficulty of hyperparameter tuning.
- Validation on additional modalities, such as 3D or tactile signals, is currently missing.
Related Work & Insights
- vs I-JEPA/V-JEPA: This work is the first generalization of JEPA to multimodality, transitioning from unimodal self-supervised learning to cross-modal alignment.
- vs CLIP/ALIGN: CLIP and ALIGN align the two encoders' output embeddings directly with a contrastive loss, whereas M3-JEPA inserts an MoE predictor that maps the input embedding into the target modality's latent space, which the authors argue filters out irrelevant information.
- Insights: The JEPA + MoE paradigm may become a new foundation for self-supervised multimodal learning, particularly suitable for scenarios with high informational uncertainty.
Rating
⭐⭐⭐⭐⭐ Pioneering work generalizing JEPA to multimodality. It offers complete theoretical analysis (information theory + optimal hyperparameters + convergence guarantees), covers text/image/audio modalities across various tasks including retrieval, classification, and VQA, and presents a striking efficiency advantage with only 140M parameters.