
Exploring and Exploiting Stability in Latent Flow Matching

Conference: ICML 2026
arXiv: 2605.08398
Code: https://github.com/briqr/explo-r-it-ing_lfm_stability
Area: Diffusion Models / Flow Matching / Data Pruning
Keywords: Latent Flow Matching, Trajectory Stability, Data Pruning, Coarse-to-Fine, Inference Acceleration

TL;DR

This work systematically characterizes the "trajectory stability" of Latent Flow Matching (LFM)—under the same noise seed, pruning 75% of the data, changing model size, or altering training seeds still produces nearly identical images. This property is then translated into two practical algorithms: (1) Using balanced-clustering pruning, 50% of CelebA-HQ data can be pruned with a slight FID improvement, and 75% can be pruned on ImageNet; (2) A Coarse-to-Fine two-stage generation, combining DiT-XL/2 (675M) and DiT-S/2 (33M), achieves 2.15× faster inference.

Background & Motivation

Background: Diffusion models have become the mainstream paradigm for image/video/medical image generation. Flow Matching (FM), as an ODE alternative to DDPM, is increasingly popular due to fewer sampling steps. Latent FM (LFM) further moves FM into the VAE latent space and serves as the backbone for large models like SD3 and Flux.

Limitations of Prior Work: Training LFM is extremely expensive—requiring massive datasets, long durations, and huge computational resources. Conditional models also need extensive manual labeling. However, the community has never systematically asked: How large must the dataset be? How large must the model be? Some scattered observations suggest the existence of stability (Kadkhodaie observed model convergence across splits in score-based diffusion), but no practical pruning/acceleration solutions have been proposed, and prior work is limited to low-resolution pixel space.

Key Challenge: Theoretically, FM learns "transport between distributions" and should be sensitive to small perturbations in the sample distribution. Yet, empirical evidence shows that FM models, even under large perturbations (removing half the data, changing to a model 20× larger), still map the same \(x_0\) to nearly the same \(x_1\). If this "stability" is real, it implies that much of the training data is redundant and can be pruned.

Goal: (1) Rigorously measure this stability in LFM (using ArcFace similarity for faces and DINO similarity for ImageNet, comparing images generated from the same noise seed); (2) Explain it theoretically (based on the extremely peaked softmax in the closed-form FM solution from Bertrand 2025); (3) Translate stability into practical algorithms for data pruning and model pruning.

Key Insight: Bertrand 2025 proves that in the optimal velocity field of rectified FM, \(\hat{u}^*(x,t)=\sum_i \lambda_i(x,t)\frac{x^i-x}{1-t}\), the softmax weights \(\lambda_i\) become extremely peaked early on—a single training sample dominates the entire trajectory. Thus, as long as this "dominant sample" remains in the data, pruning other samples has minimal effect on the trajectory.
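
The peaking of the softmax weights can be checked numerically. The sketch below is a toy illustration, not the paper's code: it assumes the standard linear path \(x_t=(1-t)x_0+t\,x^i\) with \(x_0\sim\mathcal{N}(0,I)\), for which \(\lambda_i(x,t)\propto\exp\!\big(-\|x-t\,x^i\|^2/(2(1-t)^2)\big)\), and it uses 50 random Gaussian points in 16 dimensions as a stand-in for training latents. The function names are illustrative.

```python
import math
import random

def softmax_weights(x, t, data):
    """lambda_i(x, t) for the path x_t = (1 - t) * x_0 + t * x^i, x_0 ~ N(0, I)."""
    # Log-weight of sample i: -||x - t * x^i||^2 / (2 * (1 - t)^2)
    logits = [
        -sum((xj - t * dj) ** 2 for xj, dj in zip(x, d)) / (2.0 * (1.0 - t) ** 2)
        for d in data
    ]
    m = max(logits)  # subtract the max for numerical stability
    exps = [math.exp(l - m) for l in logits]
    z = sum(exps)
    return [e / z for e in exps]

def optimal_velocity(x, t, data):
    """u*(x, t) = sum_i lambda_i(x, t) * (x^i - x) / (1 - t)."""
    lam = softmax_weights(x, t, data)
    return [
        sum(l * (d[j] - x[j]) for l, d in zip(lam, data)) / (1.0 - t)
        for j in range(len(x))
    ]

# Toy "training set" (assumption: random Gaussian points, not real latents).
rng = random.Random(0)
dim, n = 16, 50
data = [[rng.gauss(0.0, 1.0) for _ in range(dim)] for _ in range(n)]
x0 = [rng.gauss(0.0, 1.0) for _ in range(dim)]
target = data[0]

for t in (0.1, 0.5, 0.9):
    # A point on the straight path from the noise x0 to the sample data[0].
    x_t = [(1.0 - t) * a + t * b for a, b in zip(x0, target)]
    lam = softmax_weights(x_t, t, data)
    print(f"t={t:.1f}  max weight={max(lam):.3f}")
```

Once one weight is close to 1, removing any other sample from `data` barely changes `optimal_velocity`, which is the intuition behind pruning. How early the weights peak depends on the dimension and on how the data are spread, so the toy numbers should not be read as the paper's.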

Core Idea: Leverage LFM's intrinsic stability to simultaneously improve "training efficiency (less data/labels)" and "inference efficiency (model size composition)", empirically validated with three pruning criteria and balanced clustering.

Method

Overall Architecture

The method proceeds along two axes:

  • Data Side (Pruning): Defines three sample-scoring criteria (gradient \(\mathcal{G}\), loss \(\mathcal{L}\), clustering \(\mathcal{C}\)) to prune the training set \(S\) into \(S'\subset S\). For each pruning ratio \(pr\), LFM is trained and FID is compared.
  • Inference Side (C2F): Trains two DiT models—a small DiT-S/2 (33M) for the early segment \(t\in[0,t_0)\), and a large DiT-XL/2 (675M) for the later segment \(t\in[t_0,1]\). ODE reverse integration and seam loss are used to stitch the two segments.
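
The inference-side split can be sketched as a plain Euler sampler that calls the small model before \(t_0\) and the large model afterwards. This is a minimal sketch under stated assumptions: `c2f_sample`, the step count, and the toy velocity field are illustrative, and both "models" here share one analytic field, which mimics the stability assumption rather than real DiT networks.

```python
def c2f_sample(x0, v_coarse, v_fine, t0=0.7, num_steps=20):
    """Euler-integrate dx/dt = v(x, t) from t=0 to t=1, switching models at t0."""
    x = list(x0)
    h = 1.0 / num_steps
    k_switch = round(t0 * num_steps)  # first step handled by the Fine model
    calls = {"coarse": 0, "fine": 0}
    for k in range(num_steps):
        t = k / num_steps
        if k < k_switch:
            v = v_coarse(x, t)
            calls["coarse"] += 1
        else:
            v = v_fine(x, t)
            calls["fine"] += 1
        x = [xi + h * vi for xi, vi in zip(x, v)]
    return x, calls

# Toy velocity field (assumption): straight-line transport toward one target.
target = [2.0, -1.0, 0.5]

def toward_target(x, t):
    return [(b - a) / (1.0 - t) for a, b in zip(x, target)]

x_final, calls = c2f_sample([0.3, 0.1, -0.4], toward_target, toward_target)
print(x_final, calls)
```

With \(t_0=0.7\) and uniform steps, 70% of the network evaluations go to the cheap model, which is where the speedup comes from.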

Key Designs

  1. Three FM-Compatible Pruning Criteria:

    • Function: Assigns an "importance score" to each training sample; the top-scoring \(1-pr\) fraction is retained.
    • Mechanism:
      • \(\mathcal{G}\) (Gradient): Trains a small proxy model for 7% of steps, fixes \(M=2\) noises and \(T=8\) timesteps, computes squared gradient norm per sample, normalizes by per-\(t\) mean to remove timestep bias, yielding \(s_i^{\mathcal{G}}\).
      • \(\mathcal{L}\) (Loss): Replaces the squared gradient norm above with the per-sample loss value, which is much cheaper to compute and is commonly used as the main criterion.
      • \(\mathcal{C}\) (Clustering): Performs k-means clustering in CLIP image embedding space, with proportional (sampling by cluster size to preserve distribution) and balanced (equal samples per cluster for dataset balancing) variants; within clusters, samples can be selected by proximity to center, distance, or kernel-mean matching.
    • Design Motivation: For discriminative models, \(\mathcal{L}\) selects hard examples effectively. In FM, however, most of the loss comes from noise variance that all samples share, so a stable per-sample signal requires scoring every sample on the same fixed noise draws and timesteps ("shared noise paths") together with EMA. This is the core adaptation from classification pruning to FM.
  2. Coarse-to-Fine Two-Stage Generation (C2F):

    • Function: Reduces inference cost by ≈2.15× while maintaining/improving FID.
    • Mechanism: First, a lightweight Coarse model \(v_C\) is trained on the pruned \(S'\) for \(t\in[0,t_0)\). The pretrained Fine model \(v_F\) (DiT-XL/2) covers \(t\in[t_0,1]\). To ensure smooth transition at \(t_0\), Fine performs ODE reverse integration \(x_{k+1}=x_k+h\,v_F(x_k,t_k),\,h<0\) from \(x_1\) back to \(x_{t_0}\), using this \(x_{t_0}\) as the Coarse training target, and adds seam loss \(\mathcal{L}_{\text{seam}}^v=\|v_F(x_{t_0},t_0)-v_C(x_{t_0},t_0)\|^2\).
    • Design Motivation: Stability indicates that large and small models produce similar trajectories—thus, a small model suffices for the noise-dominated early stage, avoiding unnecessary computation by the 675M-parameter model. Seam loss simply aligns the "seam" between the two segments, and a few epochs suffice to fine-tune a usable C2F.
  3. Balanced Clustering for Fairness (\(\mathcal{C}_b\)):

    • Function: Performs k-means on CLIP embeddings, then samples equally from each cluster to automatically balance dataset bias.
    • Mechanism: On CelebA-HQ, images generated by the unpruned model are gender-skewed (more female than male). After \(\mathcal{C}_b\) pruning, the gender KL divergence, with attributes labeled by PaliGemma, drops from 0.044 to 0.016, approaching the 0.005 of the explicitly label-aware variant \((\mathcal{C}_b)_{\text{gender}}\). Age, skin-tone, and hair-color attributes show similar reductions.
    • Design Motivation: Stability ensures clusters are "mutually non-interfering"—removing samples within one cluster does not affect trajectories in others. Thus, cluster-level balancing can safely correct data bias without harming FID.
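
The difference between the proportional and balanced variants comes down to how the keep budget is split across clusters. The sketch below assumes cluster ids are already available (for example from k-means on CLIP embeddings) and picks samples uniformly at random within each cluster as a stand-in for the proximity-to-center or kernel-mean-matching rules; all function names are illustrative.

```python
import random
from collections import defaultdict

def proportional_quotas(sizes, keep):
    """Quota per cluster proportional to its size (largest-remainder rounding)."""
    total = sum(sizes.values())
    exact = {c: keep * s / total for c, s in sizes.items()}
    quotas = {c: int(e) for c, e in exact.items()}
    leftover = keep - sum(quotas.values())
    by_remainder = sorted(exact, key=lambda c: exact[c] - quotas[c], reverse=True)
    for c in by_remainder[:leftover]:
        quotas[c] += 1
    return quotas

def balanced_quotas(sizes, keep):
    """Equal quota per cluster; small clusters are capped and the rest is shared."""
    if keep > sum(sizes.values()):
        raise ValueError("cannot keep more samples than exist")
    quotas = {c: 0 for c in sizes}
    remaining = keep
    while remaining > 0:
        # One round-robin pass over clusters that still have unused samples.
        for c in [c for c in sizes if quotas[c] < sizes[c]]:
            if remaining == 0:
                break
            quotas[c] += 1
            remaining -= 1
    return quotas

def prune(cluster_ids, pr, mode="balanced", seed=0):
    """Return sorted indices of the samples kept at pruning ratio pr."""
    members = defaultdict(list)
    for idx, c in enumerate(cluster_ids):
        members[c].append(idx)
    sizes = {c: len(m) for c, m in members.items()}
    keep = round((1.0 - pr) * len(cluster_ids))
    quota_fn = balanced_quotas if mode == "balanced" else proportional_quotas
    quotas = quota_fn(sizes, keep)
    rng = random.Random(seed)
    kept = []
    for c, m in members.items():
        kept.extend(rng.sample(m, quotas[c]))  # random pick: an assumption
    return sorted(kept)

# Toy imbalanced dataset: three clusters of size 60, 30 and 10.
cluster_ids = ["a"] * 60 + ["b"] * 30 + ["c"] * 10
sizes = {"a": 60, "b": 30, "c": 10}
print(proportional_quotas(sizes, 50), balanced_quotas(sizes, 50))
kept_balanced = prune(cluster_ids, pr=0.5, mode="balanced")
```

In the balanced case the smallest cluster is kept whole and the large clusters absorb the pruning, which is how cluster-level balancing reduces skew without labels.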

Loss & Training

The total loss for the Coarse model: \(\mathcal{L}_{\text{coarse}}=\mathbb{E}\,\mathcal{L}_{\text{FM}}^{t\in[0,t_0)}+\lambda_v\,\mathcal{L}_{\text{seam}}^v\). The seam coefficient \(\lambda_v\) is a hyperparameter; in the paper, \(t_0=0.7\) achieves the best FID/speed trade-off. Coarse uses DiT-S/2, Fine uses DiT-XL/2; on H100 with batch 128 and \(256^2\) resolution, C2F runs at 43.53 ms/img, Fine-only at 93.95 ms/img.
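
The seam construction can be sketched as follows: integrate the Fine model backwards from \(x_1\) to \(x_{t_0}\) with a negative Euler step, then penalize the velocity mismatch at that point. This is an illustrative sketch only; the constant toy velocity fields, the step count, and the way the FM term is passed in as a precomputed number are assumptions, not details from the paper.

```python
def reverse_integrate(x1, v_fine, t0, num_steps=50):
    """Euler steps x_{k+1} = x_k + h * v_F(x_k, t_k) with h < 0, from t=1 to t0."""
    h = (t0 - 1.0) / num_steps  # negative step size
    x, t = list(x1), 1.0
    for _ in range(num_steps):
        v = v_fine(x, t)
        x = [xi + h * vi for xi, vi in zip(x, v)]
        t += h
    return x

def seam_loss(x_t0, t0, v_fine, v_coarse):
    """||v_F(x_t0, t0) - v_C(x_t0, t0)||^2."""
    vf, vc = v_fine(x_t0, t0), v_coarse(x_t0, t0)
    return sum((a - b) ** 2 for a, b in zip(vf, vc))

def coarse_loss(fm_loss, x_t0, t0, v_fine, v_coarse, lambda_v=1.0):
    """L_coarse = L_FM on [0, t0) + lambda_v * L_seam."""
    return fm_loss + lambda_v * seam_loss(x_t0, t0, v_fine, v_coarse)

# Toy fields (assumption): Fine transports along a straight line with constant
# velocity; Coarse is slightly off in the first coordinate.
def v_fine(x, t):
    return [1.0, -2.0]

def v_coarse(x, t):
    return [0.9, -2.0]

t0 = 0.7
x_t0 = reverse_integrate([1.0, -2.0], v_fine, t0)
seam = seam_loss(x_t0, t0, v_fine, v_coarse)
total = coarse_loss(0.5, x_t0, t0, v_fine, v_coarse, lambda_v=2.0)
print(x_t0, seam, total)
```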

Key Experimental Results

Main Results

CelebA-HQ (\(pr=0.5\)) FID under different pruning criteria (lower is better):

| Method | FID | Notes |
| --- | --- | --- |
| Unpruned | 24.24 | Full data baseline |
| Random | 25.25±0.38 | Random pruning |
| \(\mathcal{G}\) (High Gradient) | 24.62 | Nearly unchanged |
| \(\mathcal{G}^{-1}\) (Low Gradient) | 29.75 | Significantly worse |
| \(\mathcal{L}\) (High Loss) | 33.92 | Worst (counterintuitive, opposite to classification) |
| \(\mathcal{L}^{-1}\) (Low Loss) | 23.49 | Slight improvement |
| \(\mathcal{C}_p\) | 25.19 | Proportional |
| \(\mathcal{C}_b\) | 22.80 | Balanced clustering best |
| \(\mathcal{C}_b^\kappa\) | 23.42 | Kernel variant |
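
The score-based rows above (\(\mathcal{G}\), \(\mathcal{L}\) and their inverses) all rank samples by a score that is normalized per timestep before averaging. A minimal sketch of that ranking step is given below; the exact normalization and averaging used in the paper are not spelled out here, so dividing by the per-timestep mean and then averaging over noises and timesteps is an assumption, as are the function names.

```python
def normalized_scores(raw):
    """raw[i][m][k]: loss (or squared gradient norm) of sample i for fixed
    noise draw m and fixed timestep k. Returns one score per sample."""
    n, M, T = len(raw), len(raw[0]), len(raw[0][0])
    # Mean over samples and noises at each timestep, used to remove timestep bias.
    t_mean = [
        sum(raw[i][m][k] for i in range(n) for m in range(M)) / (n * M)
        for k in range(T)
    ]
    return [
        sum(raw[i][m][k] / t_mean[k] for m in range(M) for k in range(T)) / (M * T)
        for i in range(n)
    ]

def keep_indices(scores, pr, keep_high=True):
    """Keep the (1 - pr) fraction with the highest (or lowest) scores."""
    n_keep = round((1.0 - pr) * len(scores))
    order = sorted(range(len(scores)), key=lambda i: scores[i], reverse=keep_high)
    return sorted(order[:n_keep])

# Toy values: 4 samples, 1 noise draw, 2 timesteps with very different scales.
raw = [[[1.0, 10.0]], [[2.0, 20.0]], [[3.0, 30.0]], [[2.0, 20.0]]]
scores = normalized_scores(raw)
kept_high = keep_indices(scores, pr=0.5, keep_high=True)   # high-score rule
kept_low = keep_indices(scores, pr=0.5, keep_high=False)   # inverse rule
print(scores, kept_high, kept_low)
```

Without the per-timestep division, the second timestep in the toy data would dominate every score simply because its raw values are ten times larger.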

ImageNet (DiT-XL/2 conditional, 200k iterations):

| Pruning Rate \(pr\) | FID Trend | Notes |
| --- | --- | --- |
| 0 (unpruned) | Baseline | |
| 0.75 | Slight increase until 600k, then converges | Most stable long-term gain |
| 0.9 | Fastest before 200k, drops after 590k | Strongest mid-term |
| 0.95 | Fastest before 170k, then collapses | Short-term sprint |

Ablation Study

C2F on CelebA-HQ: effect of seam position \(t_0\):

| Configuration | FID@\(t_0=0.7\) | Inference Speed (ms/img) | Notes |
| --- | --- | --- | --- |
| Fine-only | 24.24 | 93.95 | All DiT-XL/2 |
| C2F (unpruned Coarse) | Slightly better | 43.53 | 2.15× speedup |
| C2F + \(\mathcal{C}_b\) pruned Coarse | Best | 43.53 | Win-win for speed and FID |
| C2F_male (violates stability) | 44.92 | 43.53 | Seam loss cannot rescue |

Key Findings

  • \(\mathcal{L}\) behaves oppositely in FM compared to classification models: In classification, "high loss samples" are hard examples and useful to retain; in FM, keeping high-loss samples (\(\mathcal{L}\)) performs worst (FID 33.92), while keeping low-loss samples (\(\mathcal{L}^{-1}\)) improves on the unpruned baseline (23.49 vs. 24.24). The reason is that high loss in FM mostly comes from "low-density outliers", while FM relies on "dominant samples" to build trajectories, so outliers actually hinder training. This is a counterintuitive but practically valuable finding.
  • Different perturbations have vastly different effects on stability: Changing DiT-S/2 to DiT-XL/2 leaves outputs almost unchanged (s=0.81); switching to U-Net (s=0.55) or removing a gender mode (s=0.58) lowers similarity only slightly; changing the VAE seed (s=0.32) or flipping the signs of all latent feature maps (s=0.32) breaks stability completely. This indicates that the root of stability lies in the coupling of latent space geometry and the FM objective, not in the architecture itself.
  • Score-based diffusion does not exhibit the same stability: Replacing FM with score-based diffusion eliminates stability, indicating this is a property specific to rectified FM, not all diffusion models.
  • Balanced clustering reduces bias without harming FID: \(\mathcal{C}_b\) reduces gender KL from 0.044 to 0.016, with FID improving. This provides a simple, label-free solution for "dataset balancing".
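
The similarity values \(s\) quoted above compare images generated from the same noise seeds by two different models. One plausible way to compute such a score, assumed here rather than taken from the paper, is the mean cosine similarity between paired feature embeddings (ArcFace for faces, DINO for ImageNet); the embeddings below are toy vectors and the function names are illustrative.

```python
import math

def cosine(u, v):
    """Cosine similarity between two non-zero vectors."""
    dot = sum(a * b for a, b in zip(u, v))
    norm_u = math.sqrt(sum(a * a for a in u))
    norm_v = math.sqrt(sum(b * b for b in v))
    return dot / (norm_u * norm_v)

def stability_score(emb_ref, emb_perturbed):
    """Mean cosine similarity over pairs generated from the same noise seed.
    emb_ref[j] and emb_perturbed[j] must correspond to the same seed j."""
    sims = [cosine(u, v) for u, v in zip(emb_ref, emb_perturbed)]
    return sum(sims) / len(sims)

# Toy embeddings (assumption): two seeds, 2-dimensional features.
emb_ref = [[1.0, 0.0], [0.0, 1.0]]
emb_same = [[2.0, 0.0], [0.0, 3.0]]      # same directions, different scale
emb_shifted = [[0.0, 1.0], [0.0, 1.0]]   # first seed now maps elsewhere

s_same = stability_score(emb_ref, emb_same)
s_shifted = stability_score(emb_ref, emb_shifted)
print(s_same, s_shifted)
```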

Highlights & Insights

  • Elevating "stability" from phenomenon to theoretical explanation and practical algorithm is the paper's main contribution: directly leveraging Bertrand 2025's closed-form solution as the theoretical foundation, then translating it into data pruning and C2F, forming a complete theory-empirics-engineering loop.
  • C2F has significant engineering value: Without touching Fine model weights, training only a small Coarse model with seam loss achieves 2.15× acceleration in production; this "partial model distillation" is very deployment-friendly for DiT-XL/Flux-scale models.
  • Boundaries of stability (broken by VAE changes or latent sign flips) serve as a warning to the LFM community—any modification to the VAE (replacement, scaling, normalization) will invalidate existing LFM and require retraining.

Limitations & Future Work

  • Validation is mainly on medium-scale datasets (CelebA-HQ 28k, FFHQ 63k, ImageNet 1.2M) and DiT series; whether the findings hold for web-scale data (LAION-5B scale) and large Flux/SD3 models remains unanswered.
  • The \(\mathcal{G}\) gradient criterion is too expensive to compute and is only used for analysis, not deployed on large datasets; making it practical may require random projection or sketching.
  • C2F's seam loss only aligns at a single time point, without considering curvature matching between the two ODE segments; if the second derivatives of the velocity fields differ greatly, minor artifacts may still occur.
  • The relationship between stability and generalization is left as future work—intuitively, stronger stability approaches "replicating the training set", so balancing stability and diversity remains an open question.

Comparison with Related Work

  • vs Kadkhodaie 2024: They observed post-split convergence in score-based diffusion in pixel space; this work extends the phenomenon to latent FM, provides theoretical grounding, and translates it into practical tools.
  • vs Bertrand 2025: Bertrand provided the FM closed-form, but only for studying "when models generalize"; this work cleverly repurposes the softmax-peaked property to justify pruning.
  • vs Dataset Distillation / Coreset: This work demonstrates that simple cluster-balanced pruning on LFM can outperform more complex coreset methods, offering a concise baseline for data efficiency in generative models.

Rating

  • Novelty: ⭐⭐⭐⭐ Stability phenomenon and C2F are not entirely new, but this is the first systematic treatment on LFM with theoretical explanation.
  • Experimental Thoroughness: ⭐⭐⭐⭐⭐ Covers CelebA-HQ / FFHQ / ImageNet, six pruning criteria, five perturbation types—very comprehensive.
  • Writing Quality: ⭐⭐⭐⭐ Clear formula exposition, visually impactful perturbation classification in Figure 4.
  • Value: ⭐⭐⭐⭐⭐ High engineering value (direct 2.15× acceleration + 50% data pruning), and provides actionable guidance on LFM stability boundaries.