
M3-JEPA: Multimodal Alignment via Multi-gate MoE based on JEPA

Conference: ICML 2025
arXiv: 2409.05929
Code: GitHub - M3-JEPA
Area: Multimodal VLM
Keywords: JEPA, Mixture-of-Experts, multimodal alignment, energy-based model, alternating gradient descent

TL;DR

M3-JEPA generalizes JEPA (Joint-Embedding Predictive Architecture) to multimodal alignment across arbitrary modality combinations. It uses a Multi-gate MoE as the cross-modal predictor and performs alignment in the latent space rather than the token space, with the gating function decoupling modality-specific from shared information. Alternating gradient descent is used to avoid gradient conflicts between multi-directional tasks. With only 140M trainable parameters, it outperforms state-of-the-art models such as BLIP-2 (1.2B) on multiple retrieval and classification tasks.

Background & Motivation

Background: Mainstream modern multimodal learning adopts generative architectures, which fall into two categories: one involves training from scratch (such as OFA, BEiT-3), requiring massive data and computation; the other utilizes pretrained LLMs as backbones and fine-tunes lightweight connectors (such as BLIP-2, LLaVA), which is more computationally efficient. Both categories of methods optimize cross-modal alignment in the raw token space.

Limitations of Prior Work: Aligning in the token space is prone to modality collapse, where the signals of one modality dominate another. Key sources of this issue include: (1) gradient conflicts in multi-directional tasks; (2) distribution mismatch between continuous domains (images/videos) and discrete domains (text); and (3) informational uncertainty and redundancy (e.g., a single image can have multiple semantically equivalent but textually different descriptions). These factors make cross-modal alignment difficult to converge and can lead to missing key information.

Key Challenge: Alignment in the token space is too "superficial" as it requires predicting exact token sequences, whereas cross-modal information naturally exhibits uncertainty and one-to-many mappings. A scheme is needed to align in a more abstract latent space, preserving only the core semantic information shared across modalities.

Goal: (1) How to avoid modality collapse caused by token space alignment? (2) How to design a general any-to-any multimodal alignment framework? (3) How to avoid gradient conflicts between multi-directional tasks (such as image→text and text→image)?

Key Insight: JEPA (Joint-Embedding Predictive Architecture) approaches learning from the perspective of energy-based models. Instead of performing generative prediction in the token space, it uses a predictor to project input embeddings into the output embedding space, conducting alignment in the latent space. While I-JEPA and V-JEPA have validated the effectiveness of this paradigm in visual self-supervised learning, a general multimodal version has not yet been introduced.

Core Idea: Implement the cross-modal predictor of JEPA using a Multi-gate MoE. The gating function automatically decouples modality-specific and shared information. Combined with alternating gradient descent to prevent multi-task conflicts, this achieves the first any-to-any multimodal JEPA alignment framework.

Method

Overall Architecture

Given \(M\) modalities and \(T\) tasks, each modality is embedded by a pretrained unimodal encoder (Llama3-8B for text, DINOv2-Large for images, LanguageBind for audio). For task \(t\), the input embedding \(e_x^t\) and output embedding \(e_y^t\) are generated by their corresponding modality encoders. A Multi-gate MoE predictor \(\mathcal{P}\) projects \(e_x^t\) into the latent space of \(e_y^t\), forming \(e_{x \to y}^t = \mathcal{P}(e_x^t)\). Training then minimizes the energy function \(\mathcal{F}^t(x,y)\) between \(e_{x \to y}^t\) and \(e_y^t\) in the latent space. The encoders are kept frozen apart from a 3-layer LoRA adaptation (rank=64), while the MoE predictor is randomly initialized and all of its parameters are trained.
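
The role of the predictor \(\mathcal{P}\) can be illustrated with a minimal NumPy sketch of a Top-\(K\) gated mixture of experts. It is not the authors' implementation: the class name `ToyMoEPredictor`, the linear experts, the embedding sizes, the renormalization of the kept gate weights, and the restriction to one modality and one gate are all simplifying assumptions. Only the gate form (softmax over the concatenation of \(e_x\) and a modality embedding \(e_m\)), \(N=12\), and \(K=4\) come from the description under Key Designs.

```python
import numpy as np

def softmax(z, axis=-1):
    z = z - z.max(axis=axis, keepdims=True)  # subtract max for numerical stability
    e = np.exp(z)
    return e / e.sum(axis=axis, keepdims=True)

class ToyMoEPredictor:
    """Toy Top-K gated mixture of linear experts (illustrative assumption)."""

    def __init__(self, d_in, d_out, d_mod=8, n_experts=12, top_k=4, seed=0):
        rng = np.random.default_rng(seed)
        self.top_k = top_k
        # One linear map per expert; real experts would be deeper networks.
        self.experts = rng.normal(0.0, 0.02, size=(n_experts, d_in, d_out))
        # Gate projection g acts on the concatenation [e_x ; e_m].
        self.g = rng.normal(0.0, 0.02, size=(d_in + d_mod, n_experts))
        # Modality embedding e_m (a single modality in this sketch).
        self.e_m = rng.normal(0.0, 0.02, size=(d_mod,))

    def gate(self, e_x):
        batch = e_x.shape[0]
        gate_in = np.concatenate([e_x, np.tile(self.e_m, (batch, 1))], axis=1)
        probs = softmax(gate_in @ self.g)                 # (batch, n_experts)
        # Keep the K largest gate values per row and zero out the rest.
        idx = np.argsort(probs, axis=1)[:, -self.top_k:]
        mask = np.zeros_like(probs)
        np.put_along_axis(mask, idx, 1.0, axis=1)
        sparse = probs * mask
        # Renormalizing the kept weights is an assumption of this sketch.
        return sparse / sparse.sum(axis=1, keepdims=True)

    def __call__(self, e_x):
        weights = self.gate(e_x)                              # (batch, n_experts)
        outs = np.einsum("bi,nio->bno", e_x, self.experts)    # per-expert outputs
        return np.einsum("bn,bno->bo", weights, outs)         # e_{x->y}

predictor = ToyMoEPredictor(d_in=16, d_out=32)
e_x = np.random.default_rng(1).normal(size=(5, 16))
e_x_to_y = predictor(e_x)  # shape (5, 32): inputs mapped into the output space
```

For clarity the sketch evaluates every expert and then masks; a practical implementation would dispatch each input only to its selected experts.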

Key Designs

  1. Multi-gate MoE Cross-Modal Predictor:

    • Function: Acts as a lightweight cross-modal connector, projecting input modality embeddings into the latent space of the output modality.
    • Mechanism: Implements \(N=12\) expert networks for each modality, totaling \(M \times N\) experts, utilizing Top-\(K\) (\(K=4\)) sparse activation. The gating function takes the concatenation of the input embedding \(e_x\) and a learnable modality embedding \(e_m\) as input: \(\mathbb{G} = \text{softmax}(g \cdot [e_x \oplus e_m])\), where \(g\) is a shared projection matrix. Two parallel gates (\(L=2\)) are used to independently serve the contrastive loss and regularization loss. The total number of trainable parameters is only 140M.
    • Design Motivation: (1) Modality-specific paths (modality experts + modality embedding \(e_m\)) capture information unique to each modality; (2) the shared projection matrix \(g\) establishes a shared cross-modal subspace; (3) the lightweight MoE is significantly more efficient than fully fine-tuning the encoders.
  2. Energy Function with Dual Contrastive and Regularization Loss:

    • Function: Optimizes cross-modal alignment from two complementary perspectives.
    • Mechanism: The regularization loss \(\mathcal{L}_{\text{reg}} = \|e_{x \to y} - e_y\|_2^2\) directly pulls the embeddings of positive pairs closer (minimizing conditional entropy \(\mathcal{H}(y|x)\)). The contrastive loss \(\mathcal{L}_{\text{cl}}\) uses in-batch negative samples through InfoNCE to pull positive pairs together while pushing negative pairs apart (maximizing mutual information \(\mathcal{I}(x;y)\)). The total loss is defined as \(\mathcal{L} = \alpha \mathcal{L}_{\text{reg}} + (1-\alpha)\mathcal{L}_{\text{cl}}\), where theoretical analysis proves that the optimal \(\alpha = 0.5\) (corresponding to the critical temperature of free energy minimization), which is also validated by experiments.
    • Design Motivation: Utilizing only the contrastive loss can lead to representation collapse (where all embeddings converge), while using only the regularization loss fails to distinguish negative pairs. The dual loss forms a complete energy function, concurrently maximizing mutual information and minimizing conditional entropy from an information theory perspective.
  3. Alternating Gradient Descent (AGD):

    • Function: Resolves gradient conflicts among multi-directional multimodal tasks.
    • Mechanism: Alternates among \(T\) tasks across training steps. In each step, forward and backward propagation are performed only for the current task: \(\theta(i+1) \leftarrow \theta(i) - \eta \nabla_\theta \mathcal{L}^t\), where \(\text{mod}(i, T) = t\). Unlike joint optimization, AGD decouples parameter updates for each task, avoiding gradient conflicts caused by tasks like image→text and text→image competing for the same connector weights.
    • Design Motivation: Traditional joint optimization is prone to gradient conflicts (seesaw effect) in multi-directional tasks. AGD draws inspiration from the successful experiences of alternating training in multi-task learning.
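
The alternating schedule can be sketched on a toy problem. The two quadratic "tasks" below are hypothetical stand-ins for directions such as image→text and text→image that share the same parameters; the function names, learning rate, and step count are assumptions chosen only to make the update rule \(\theta(i+1) \leftarrow \theta(i) - \eta \nabla_\theta \mathcal{L}^t\) with \(t = \text{mod}(i, T)\) concrete.

```python
import numpy as np

def make_quadratic_task(target):
    """Toy task: loss 0.5 * ||theta - target||^2, gradient (theta - target)."""
    def loss_and_grad(theta):
        diff = theta - target
        return 0.5 * float(diff @ diff), diff
    return loss_and_grad

def agd_train(tasks, theta, lr=0.1, steps=200):
    """Alternating gradient descent: step i updates on task (i mod T) only."""
    num_tasks = len(tasks)
    schedule = []
    for i in range(steps):
        t = i % num_tasks              # pick the task for this step
        _, grad = tasks[t](theta)      # forward/backward for that task alone
        theta = theta - lr * grad      # parameters are shared across tasks
        schedule.append(t)
    return theta, schedule

# Two tasks that pull the shared parameters toward different targets.
tasks = [make_quadratic_task(np.array([1.0, 0.0])),
         make_quadratic_task(np.array([0.0, 1.0]))]
theta, schedule = agd_train(tasks, theta=np.zeros(2))
```

Each step uses a single task's gradient, so the per-task gradients are never summed into one update as they would be in joint optimization.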

Loss & Training

The total loss is \(\mathcal{L} = 0.5 \cdot \mathcal{L}_{\text{reg}} + 0.5 \cdot \mathcal{L}_{\text{cl}}\), where the regularization loss is the L2 distance and the contrastive loss is InfoNCE. Training utilizes the Adam optimizer with a batch size of 128, a cosine learning rate schedule, a warmup of 0.1, and a weight decay of 0.005.
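
A minimal NumPy sketch of this combined objective follows. Several details are assumptions not stated above: the temperature of 0.07, cosine-normalized similarities, a one-directional InfoNCE, and averaging over the batch. The function names are hypothetical.

```python
import numpy as np

def l2_regularization(pred, target):
    """Mean squared L2 distance between predicted and target embeddings."""
    return float(np.mean(np.sum((pred - target) ** 2, axis=1)))

def info_nce(pred, target, temperature=0.07):
    """In-batch InfoNCE: row i of pred is the positive for row i of target."""
    p = pred / np.linalg.norm(pred, axis=1, keepdims=True)
    t = target / np.linalg.norm(target, axis=1, keepdims=True)
    logits = p @ t.T / temperature                       # (batch, batch)
    logits = logits - logits.max(axis=1, keepdims=True)  # numerical stability
    log_probs = logits - np.log(np.exp(logits).sum(axis=1, keepdims=True))
    return float(-np.mean(np.diag(log_probs)))           # positives on the diagonal

def total_loss(pred, target, alpha=0.5):
    """alpha * L_reg + (1 - alpha) * L_cl, with alpha = 0.5 as in the text."""
    return alpha * l2_regularization(pred, target) + (1 - alpha) * info_nce(pred, target)

rng = np.random.default_rng(0)
e_y = rng.normal(size=(8, 16))
aligned = total_loss(e_y, e_y)                # predictions equal to the targets
mismatched = total_loss(e_y[::-1].copy(), e_y)  # predictions paired with wrong targets
```

Correctly paired embeddings give a lower loss than mismatched ones because both terms reward agreement on the diagonal of the batch.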

Key Experimental Results

Main Results: Vision-Language Retrieval (Flickr30K / COCO)

| Method | Trainable Parameters | Flickr30K I→T R@1 | Flickr30K T→I R@1 | COCO I→T R@1 | COCO T→I R@1 |
|---|---|---|---|---|---|
| CLIP | 428M | 88.0 | 68.7 | - | - |
| BLIP-2 (ViT-g) | 1.2B | 97.6 | 89.7 | 85.4 | 68.3 |
| BEiT-3 | 1.9B | 94.9 | 81.5 | 84.8 | 67.2 |
| M3-JEPA | 140M | 97.8 | 97.8 | 87.7 | 89.7 |

M3-JEPA achieves 89.7% on COCO T→I R@1 with only 140M parameters, significantly outperforming BLIP-2's 68.3% (+21.4 pt), and reaches 97.8% in both directions on Flickr30K.

Ablation Study: Impact of Method Components on COCO Retrieval

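| MoE | AGD | I→T R@1 | I→T R@5 | I→T R@10 | T→I R@1 | T→I R@5 | T→I R@10 |
|---|---|---|---|---|---|---|---|
| ✗ | ✓ | 74.4 | 86.0 | 92.2 | 82.3 | 89.5 | 92.6 |
| ✓ | ✗ | 68.2 | 68.7 | 81.1 | 74.2 | 88.7 | 92.4 |
| ✓ | ✓ | 87.7 | 99.6 | 99.9 | 89.7 | 99.7 | 99.9 |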

Both MoE and AGD are indispensable: removing MoE (replacing with an MLP) drops I→T R@1 to 74.4%, and removing AGD drops it to 68.2%, while combining both achieves 87.7%.

Key Findings

  • M3-JEPA achieves 86.6% accuracy on ImageNet-1K classification, outperforming CLIP-ViT (82.1%) and DINOv2 (83.2%), demonstrating that the JEPA paradigm can handle classification tasks.
  • It also outperforms methods like LanguageBind in zero-shot audio-text retrieval (Clotho/Audiocaps), demonstrating the framework's modular extensibility.
  • It approaches the performance of BLIP-2 on VQA tasks (VQAv2 test-dev 82.3%), indicating that it adapts to multimodal input scenarios.
  • The theoretically predicted optimum \(\alpha=0.5\) matches the value that performs best in the experiments.

Highlights & Insights

  • The first any-to-any multimodal JEPA framework, extending I-JEPA/V-JEPA from unimodal self-supervised learning to cross-modal alignment.
  • Outperforms the 1.2B parameter BLIP-2 with only 140M parameters: the efficiency advantage of the lightweight MoE predictor + frozen encoders is clear.
  • The modality decoupling of the MoE gating is supported by information theory: the shared matrix \(g\) corresponds to mutual information, and the modality embedding \(e_m\) corresponds to conditional entropy.
  • The agreement between theoretical analysis and experiments on \(\alpha=0.5\) lends credibility to the framework.

Limitations & Future Work

  • The quality of gating information decoupling depends on data quality and the representation spaces of modality encoders.
  • Frozen encoders limit fine-grained adaptation performance (performance on VQA is slightly below the fully fine-tuned BEiT-3).
  • Task switching in AGD increases training complexity and the difficulty of hyperparameter tuning.
  • Validation on additional modalities, such as 3D or tactile signals, is currently missing.

Related Work & Insights

  • vs I-JEPA/V-JEPA: This work is the first generalization of JEPA to multimodality, transitioning from unimodal self-supervised learning to cross-modal alignment.
  • vs CLIP/ALIGN: CLIP and ALIGN apply contrastive learning directly to the outputs of their two encoders, while M3-JEPA inserts an MoE predictor that maps the input embedding into the latent space of the output modality; this predictor filters out irrelevant information.
  • Insights: The JEPA + MoE paradigm may become a new foundation for self-supervised multimodal learning, particularly suitable for scenarios with high informational uncertainty.

Rating

⭐⭐⭐⭐⭐ Pioneering work generalizing JEPA to multimodality. It offers complete theoretical analysis (information theory + optimal hyperparameters + convergence guarantees), covers text/image/audio modalities across various tasks including retrieval, classification, and VQA, and presents a striking efficiency advantage with only 140M parameters.