Learnable Sparsity for Vision Generative Models

Conference: ICLR 2026
OpenReview: https://openreview.net/forum?id=9pNWZLVZ4r
Code: To be confirmed
Area: Model Compression / Diffusion Model Pruning
Keywords: Structured Pruning, Differentiable Mask, Diffusion Model, Flow Matching, End-to-End Pruning, Gradient Checkpointing

TL;DR

EcoDiff utilizes an end-to-end differentiable masking objective spanning the entire denoising trajectory to perform structured pruning on diffusion and flow matching models. Combined with "Timestep Gradient Checkpointing," it reduces memory consumption from \(O(T)\) to \(O(1)\), enabling the pruning of 20% of parameters in SDXL/FLUX with minimal quality loss using only 100 samples and 10 A100 GPU hours.

Background & Motivation

  • Background: Vision generative models (diffusion, flow matching) have seen explosive growth in parameter counts—the latest FLUX reaches 12B, approximately 13 times that of SD2 from two years ago. Larger models lead to slower inference, higher deployment barriers, and increased carbon emissions. Pruning offers a more computationally efficient compression route than knowledge distillation.
  • Limitations of Prior Work: Existing diffusion pruning methods almost exclusively rely on expensive retraining to recover quality. Fang et al. estimated that diffusion model compression requires 10%–20% of the original training cost; compressing SD2 can consume 40,000 GPU hours. Even BK-SDM, which reduces costs to approximately 300 A100 hours, requires 0.22M calibration data. For larger models like SDXL/FLUX, the retraining burden is even more prohibitive.
  • Key Challenge: The root cause of expensive retraining is coarse pruning criteria—existing methods mostly use simple heuristics or one-shot pruning, where incorrect neuron selection can only be mitigated by subsequent retraining. While differentiable masks are a more refined selection criterion in classical or LLM pruning, direct application to vision generative models fails: diffusion/flow models are Markovian, and a minor change in an intermediate step creates a "ripple effect" along the denoising chain, ultimately distorting the image completely.
  • Goal: To learn a sparse mask that minimizes quality loss in a universal and scalable manner for various architectures (U-Net diffusion + DiT flow matching) without relying on large-scale retraining.
  • Core Idea: [End-to-End Differentiable Mask] Instead of calculating loss at every step, matching is performed only on the final denoising latent \(z_0\), allowing mask learning to directly align with the output of the entire generation trajectory; [Timestep Gradient Checkpointing] An engineering technique is used to reduce the massive memory overhead of end-to-end backpropagation, allowing large models to be processed on a single GPU.

Method

Overall Architecture

EcoDiff attaches a set of 0/1 structured masks \(M\) to attention heads and FFN neurons within Transformer blocks. These masks are learned by minimizing the discrepancy between the "original model final latent \(z_0\)" and the "masked model final latent \(\hat z_0\)" (along with an \(L_0\) sparsity penalty). Since \(L_0\) is non-differentiable, hard-concrete continuous relaxation is used to transform the masks into gradient-optimizable continuous variables \(\lambda\). After learning, a threshold is applied to obtain hard masks for physical neuron deletion. As end-to-end backpropagation requires storing intermediates from all timesteps, causing memory explosion, timestep gradient checkpointing is employed to store only the latents after each denoising step and recompute states during the backward pass, reducing memory complexity from \(O(T)\) to \(O(1)\). Pruning is optionally followed by lightweight recovery using LoRA or full-parameter fine-tuning.

flowchart LR
    A["Text Prompt + Initial Noise z_T"] --> B["Original Model ε_θ<br/>Full Denoising → z_0"]
    A --> C["Masked Model ε_θ^mask(M)<br/>Full Denoising → ẑ_0"]
    B & C --> D["End-to-End Loss<br/>‖z_0 − ẑ_0‖² + β‖M‖₀"]
    D -->|Timestep Gradient Checkpointing<br/>O(T)→O(1) Memory| E["Update Continuous Mask λ"]
    E -->|Threshold τ| F["Hard Mask M → Physical Deletion"]
    F --> G["Optional: LoRA / Full-Param Tuning Recovery"]

Key Designs

1. End-to-End Pruning Objective: Optimizing for the Terminal Latent
A naive approach is a step-wise reconstruction loss \(L=\sum_i\sum_t\|f(x_{i,t-1},M)-x_{i,t}\|_2^2\), but it has two major flaws: it treats all timesteps equally, underweighting critical steps, and it implicitly assumes that the input to each step is correct. As a result, pruning decisions prioritize "short-term accuracy" and ignore the long-term impact of each neuron. EcoDiff instead treats the full denoising process as a nested function \(z_0=F(z_T,y,T)\) and requires the masked model's terminal output to match the original model's:

\[\arg\min_{M}\;\mathbb{E}_{z_T,y\sim C}\big[\|F_{\epsilon_\theta}(z_T,y)-F_{\epsilon_\theta^{\text{mask}}}(z_T,y,M)\|^2\big]+\beta\|M\|_0\]

This makes the mask responsible for the final result of the entire trajectory, so the importance of each step and neuron is weighted automatically, without requiring manually designed re-weighting factors as in DiffPruning.
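
The difference between the two objectives can be illustrated with a toy chain. This is a minimal sketch under invented assumptions, not the paper's setup: the two-channel latent, the gains 1.2 and 0.5, and the names `step`, `rollout`, `per_step_loss`, and `end_to_end_loss` are all hypothetical. Neuron A contributes little per step but feeds an error-amplifying channel, while neuron B contributes more per step but feeds an error-damping channel.

```python
def step(state, keep_a, keep_b):
    """One toy denoising update on a 2-D latent (u, v).

    Neuron A adds a small term to an error-amplifying channel (gain 1.2);
    neuron B adds a larger term to an error-damping channel (gain 0.5).
    keep_a / keep_b are 0/1 masks for the two neurons.
    """
    u, v = state
    return (1.2 * u + 0.01 * keep_a, 0.5 * v + 0.05 * keep_b)

def rollout(state, keep_a, keep_b, T):
    """Apply the same step T times, mimicking z_0 = F(z_T, y, T)."""
    for _ in range(T):
        state = step(state, keep_a, keep_b)
    return state

def sq_dist(p, q):
    return sum((a - b) ** 2 for a, b in zip(p, q))

def per_step_loss(z_T, keep_a, keep_b, T):
    """Sum of one-step errors, each step fed the unpruned model's input."""
    state, total = z_T, 0.0
    for _ in range(T):
        target = step(state, 1, 1)
        total += sq_dist(step(state, keep_a, keep_b), target)
        state = target  # the next step sees the unpruned trajectory
    return total

def end_to_end_loss(z_T, keep_a, keep_b, T):
    """Squared distance between terminal latents of pruned and unpruned rollouts."""
    return sq_dist(rollout(z_T, keep_a, keep_b, T), rollout(z_T, 1, 1, T))

z_T, T = (1.0, 1.0), 10
for name, mask in (("prune A", (0, 1)), ("prune B", (1, 0))):
    print(name, per_step_loss(z_T, *mask, T), end_to_end_loss(z_T, *mask, T))
```

In this toy, the per-step loss ranks neuron A as the cheaper one to remove, while the terminal-latent loss ranks neuron B as cheaper, because A's small error is amplified at every subsequent step. That is the kind of disagreement the end-to-end objective is meant to resolve.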

2. Structured Neuron Masking for Transformer Blocks: Time-Invariant and Plug-and-Play
Pruning granularity is focused on two structures within Transformer blocks. MHA uses masks \(M_i\in\{0,1\}\) per attention head: \(\text{MHA}^{\text{mask}}=(M_1\cdot\text{attn}_1\|\dots\|M_h\cdot\text{attn}_h)W^o\). FFN applies masks to neurons after the activation layer: \(\text{FFN}^{\text{mask}}(x)=(\sigma(xW_1+b_1)\odot M_{\text{ffn}})W_2+b_2\). This design preserves module input/output dimensions, requiring almost no structural changes or specialized hardware during deployment. A key observation is that since diffusion/flow models reuse the same denoising network at every timestep, the learned masks are naturally time-invariant—once a neuron is structurally removed, it disappears for all forward passes.
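
The claim that a masked FFN can be replaced by a physically smaller one can be checked on a toy example. The sketch below is illustrative only: it uses plain Python lists, assumes ReLU as the activation \(\sigma\), and the helper names `ffn` and `prune_ffn` are hypothetical. Attention heads can be handled analogously by dropping the corresponding slices of the projection matrices.

```python
def relu(v):
    return max(0.0, v)

def ffn(x, W1, b1, W2, b2, mask=None):
    """Compute (act(x W1 + b1) * mask) W2 + b2 for one input vector x."""
    hidden = [relu(sum(xi * W1[i][j] for i, xi in enumerate(x)) + b1[j])
              for j in range(len(b1))]
    if mask is not None:
        hidden = [h * m for h, m in zip(hidden, mask)]  # zero out masked neurons
    return [sum(h * W2[j][k] for j, h in enumerate(hidden)) + b2[k]
            for k in range(len(b2))]

def prune_ffn(W1, b1, W2, mask):
    """Physically delete masked neurons: columns of W1, entries of b1, rows of W2."""
    keep = [j for j, m in enumerate(mask) if m == 1]
    return ([[row[j] for j in keep] for row in W1],
            [b1[j] for j in keep],
            [W2[j] for j in keep])

# Toy weights: 2 inputs, 3 hidden neurons, 2 outputs (values are arbitrary).
W1 = [[0.5, -1.0, 2.0], [1.5, 0.8, -0.7]]
b1 = [0.1, 0.2, 0.3]
W2 = [[1.0, -2.0], [0.4, 0.6], [-1.2, 0.8]]
b2 = [0.05, -0.05]
x = [1.0, 2.0]
mask = [1, 0, 1]  # drop the middle neuron

masked_out = ffn(x, W1, b1, W2, b2, mask)
W1p, b1p, W2p = prune_ffn(W1, b1, W2, mask)
pruned_out = ffn(x, W1p, b1p, W2p, b2)
print(masked_out, pruned_out)
```

Because the mask is applied after the activation, the equivalence holds even for activations where \(\sigma(0)\neq 0\); the input and output dimensions of the block are unchanged, only the hidden width shrinks.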

3. Continuous Relaxation of Discrete Masks: Hard-Concrete for Gradient Optimization
Since \(\|M\|_0\) is non-differentiable, EcoDiff utilizes hard-concrete sampling from Louizos et al.: \(s=\sigma\big((\log(u+\delta)-\log(1-u+\delta)+\lambda)/\alpha\big)\), \(\hat M=\min(1,\max(0,s(\zeta-\gamma)+\gamma))\), controlled by stretch parameters \(\gamma,\zeta\) and temperature \(\alpha\). This turns the learnable parameter into a continuous \(\lambda\in\mathbb{R}^{|M|}\). The total loss becomes \(L(\lambda)=L_E(\lambda)+\beta L_0(\lambda)\), where \(L_0(\lambda)\) has a closed-form expectation. After training, discretization occurs via \(M(\lambda)=\mathbb{I}(\lambda>\tau)\), where \(\tau\) determines the target sparsity, leading to physical deletion of neurons for actual acceleration.
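
The sampling, penalty, and thresholding steps above can be sketched in a few lines. This is a minimal illustration, not the authors' code: the constants \(\gamma=-0.1\), \(\zeta=1.1\), \(\alpha=2/3\) are the defaults from Louizos et al. and are assumed here, the closed form for the expected \(L_0\) follows that paper's parameterization, and `tau_for_sparsity` is a hypothetical helper that picks \(\tau\) as a quantile of \(\lambda\).

```python
import math
import random

GAMMA, ZETA, ALPHA = -0.1, 1.1, 2.0 / 3.0  # stretch limits and temperature (assumed values)
DELTA = 1e-8                               # numerical guard inside the logs

def sigmoid(v):
    return 1.0 / (1.0 + math.exp(-v))

def sample_mask(lam, rng):
    """Draw one relaxed gate in [0, 1] from the hard-concrete distribution."""
    u = rng.random()
    s = sigmoid((math.log(u + DELTA) - math.log(1.0 - u + DELTA) + lam) / ALPHA)
    return min(1.0, max(0.0, s * (ZETA - GAMMA) + GAMMA))

def expected_l0(lams):
    """Closed-form expected number of non-zero gates."""
    shift = ALPHA * math.log(-GAMMA / ZETA)
    return sum(sigmoid(lam - shift) for lam in lams)

def hard_mask(lams, tau):
    """Discretize after training: keep a neuron iff its lambda exceeds tau."""
    return [1 if lam > tau else 0 for lam in lams]

def tau_for_sparsity(lams, sparsity):
    """Hypothetical helper: choose tau so roughly `sparsity` of the gates are dropped."""
    k = int(round(sparsity * len(lams)))
    return float("-inf") if k == 0 else sorted(lams)[k - 1]

rng = random.Random(0)
lams = [-3.0, -1.0, 0.5, 2.0, 4.0]
print([round(sample_mask(lam, rng), 3) for lam in lams])
print(expected_l0(lams))
print(hard_mask(lams, tau_for_sparsity(lams, 0.4)))
```

The stretch to \((\gamma,\zeta)\) followed by clipping is what lets sampled gates land exactly on 0 or 1 during training, so the relaxed mask already behaves like a hard one for many neurons before thresholding.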

4. Timestep Gradient Checkpointing: Reducing Backpropagation Memory from O(T) to O(1)
Learning masks end-to-end creates a gradient chain across all denoising steps. In SDXL, a naive implementation requires ~1400GB of VRAM. Traditional gradient checkpointing (within a single forward pass) is insufficient for diffusion processes requiring multiple passes. EcoDiff's timestep gradient checkpointing stores only the latent \(\hat z_t\) after each denoising step during the forward pass. During the backward pass, it recomputes intermediate states step-by-step and accumulates gradients: \(\frac{dL}{d\lambda}\mathrel{+}=\frac{dL}{d\hat z_t}\frac{d\hat z_t}{d\lambda}\). This reduces memory complexity from \(O(T)\) to \(O(1)\) at the cost of one additional forward pass (approximately 2× runtime), bringing SDXL memory usage from 1400GB down to under 30GB.
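
The store-latents, recompute, and accumulate pattern can be shown on a scalar toy chain. This is a sketch under stated assumptions rather than the paper's implementation: the update rule in `step`, the sigmoid gate, and all function names are hypothetical, and the local derivatives in `step_grads` are written by hand where a real implementation would re-run one denoising step under automatic differentiation.

```python
import math

def sigmoid(v):
    return 1.0 / (1.0 + math.exp(-v))

def step(z, gate):
    """Toy denoising update; gate=1.0 plays the role of the unpruned model."""
    return z - 0.1 * gate * math.tanh(z)

def step_grads(z, gate):
    """Local derivatives of one step, recomputed from the stored input latent."""
    th = math.tanh(z)
    d_z = 1.0 - 0.1 * gate * (1.0 - th * th)  # d step / d z
    d_gate = -0.1 * th                         # d step / d gate
    return d_z, d_gate

def forward(z_T, gate, T):
    """Run the chain, keeping only the latent that enters each step."""
    latents, z = [], z_T
    for _ in range(T):
        latents.append(z)
        z = step(z, gate)
    return z, latents

def checkpointed_grad(z_T, lam, T, target):
    """dL/dlambda for L = (z_0 - target)^2, with one step's internals live at a time."""
    gate = sigmoid(lam)
    z_0, latents = forward(z_T, gate, T)
    grad_z = 2.0 * (z_0 - target)     # dL/dz_0
    grad_gate = 0.0
    for z_in in reversed(latents):    # walk the chain backwards
        d_z, d_gate = step_grads(z_in, gate)
        grad_gate += grad_z * d_gate  # accumulate this step's contribution
        grad_z *= d_z                 # propagate to the previous latent
    return grad_gate * gate * (1.0 - gate)  # chain rule through the sigmoid

z_T, T, lam = 1.5, 20, 0.3
target, _ = forward(z_T, 1.0, T)      # terminal latent of the unpruned chain
print(checkpointed_grad(z_T, lam, T, target))
```

The stored latents still grow with the number of steps, but each is a single small tensor; what stays constant is the per-step activation memory, which dominates in a real network. The second pass over each step is the source of the roughly 2× runtime.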

Key Experimental Results

Main Results

Evaluated on 5,000 MS COCO / Flickr30K images with a unified budget of 10 A100 hours (except FLUX-Lite, which used 1120 H200 hours):

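| Model | Method | Sparsity | Params | Speedup | COCO FID↓ | COCO CLIP↑ | Flickr FID↓ |
|---|---|---|---|---|---|---|---|
| SDXL (U-Net) | Original | 0% | 2.6B | | 27.43 | 0.33 | 33.95 |
| SDXL (U-Net) | BK-SDM | 20% | 2.1B | 1.25× | 42.87 | 0.30 | 56.17 |
| SDXL (U-Net) | DiffPruning | 20% | 2.1B | 1.25× | 83.81 | 0.25 | 96.53 |
| SDXL (U-Net) | Per-Step Loss | 20% | 2.1B | 1.25× | 97.36 | 0.22 | 110.53 |
| SDXL (U-Net) | EcoDiff | 20% | 2.1B | 1.25× | 32.19 | 0.33 | 40.91 |
| FLUX-Dev (DiT) | Dev | 0% | 11.9B | | 28.47 | 0.34 | 37.82 |
| FLUX-Dev (DiT) | DiffPruning | 20% | 9.6B | 1.25× | 40.84 | 0.33 | 48.02 |
| FLUX-Dev (DiT) | FLUX-Lite | 33% | 8B | 1.49× | 29.36 | 0.34 | 38.17 |
| FLUX-Dev (DiT) | EcoDiff | 20% | 9.6B | 1.25× | 30.81 | 0.32 | 42.58 |
| FLUX-Schnell | Schnell | 0% | 11.9B | | 30.99 | 0.33 | 39.70 |
| FLUX-Schnell | DiffPruning | 20% | 9.6B | 8.75× | 42.36 | 0.30 | 54.49 |
| FLUX-Schnell | EcoDiff | 20% | 9.6B | 8.75× | 31.76 | 0.30 | 43.25 |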

EcoDiff maintains a COCO FID of 32.19 for SDXL (close to the original 27.43), while DiffPruning and Per-Step Loss soar to 83.81 and 97.36 respectively. On FLUX-Dev, it achieves quality comparable to FLUX-Lite using only 10 A100 hours versus 1120 H200 hours.

Ablation Study

SDXL with different sparsity levels and post-pruning recovery strategies (Full Fine-tuning / LoRA / No recovery with 50-step mask learning):

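| Sparsity | Recovery | COCO FID↓ | COCO CLIP↑ | Iterations |
|---|---|---|---|---|
| 0% | | 27.43 | 0.33 | |
| 25% | No | 34.61 | 0.32 | 50 |
| 25% | Full | 31.64 | 0.34 | 10k |
| 40% | No | 43.19 | 0.30 | 50 |
| 40% | Full | 33.25 | 0.33 | 10k |
| 50% | No | 81.76 | 0.26 | 50 |
| 50% | LoRA | 53.89 | 0.28 | 10k |
| 50% | Full | 34.87 | 0.33 | 10k |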

Learning masks for only 50 iterations, with no recovery, keeps COCO FID at 34.61 at 25% sparsity and 43.19 at 40% (versus 27.43 for the original model). At 50% sparsity, LoRA recovery lacks the expressivity to close the gap (FID 53.89), but 10k iterations of full-parameter fine-tuning bring the FID back to 34.87.

Key Findings

  • Step-wise loss is the worst baseline (Per-Step Loss FID 97.36), confirming the necessity of end-to-end objectives—the myopia of step-wise loss accumulates errors along the denoising chain.
  • EcoDiff generally yields lower SSIM (<0.65). The authors attribute this to prioritizing semantic fidelity (reflected in FID and CLIP scores) over pixel-level replication: minor shifts in texture or detail are penalized by pixel metrics but do not harm subjective quality.
  • Capable of pruning step-distilled models: FLUX-Schnell loses only 0.77 COCO FID at 20% sparsity, achieving a cumulative 8.75× speedup relative to FLUX-Dev.
  • At 10%–20% sparsity, semantics for specific prompts ("a cat and a dog playing chess") actually improved, correlating with the FID decrease.
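
The figures quoted in these findings can be cross-checked against the main results table. Reading the 1.25× speedup as \(1/(1-s)\) for sparsity \(s\), i.e. runtime roughly proportional to the remaining parameters, is an assumption that happens to match the table rather than a statement from the paper. Likewise, the last line assumes the pruning speedup and the step-distillation speedup multiply.

\[
\begin{aligned}
\text{speedup at } s = 0.2 &\approx \frac{1}{1-0.2} = 1.25,\\
\Delta\text{FID}_{\text{Schnell}} &= 31.76 - 30.99 = 0.77,\\
\frac{8.75}{1.25} &= 7.
\end{aligned}
\]

Under that reading, a factor of 7 of the cumulative 8.75× would come from FLUX-Schnell's shorter sampling schedule relative to FLUX-Dev, and the remaining 1.25× from pruning.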

Highlights & Insights

  • Converting the "pruning criterion" problem into "end-to-end trajectory matching": By using an objective responsible only for the final point, the "ripple error" of step-wise pruning in diffusion models is bypassed. This is the core insight of the paper.
  • The observation of time-invariance of masks is critical—since weights are shared across steps, a single set of structural masks covers the entire trajectory, greatly simplifying the problem.
  • Timestep Gradient Checkpointing solves the engineering hurdle of "end-to-end memory explosion" (1400GB to 30GB), serving as the key enabling technology to make this approach viable for 12B models like FLUX.
  • High generality: The same framework covers U-Net diffusion, DiT flow matching, and step-distilled models, and is orthogonal to other acceleration methods like feature reuse.

Limitations & Future Work

  • Sparsity is limited: 20% is nearly lossless, but more aggressive levels (e.g., 50%) necessitate full-parameter fine-tuning. The paper notes seeking higher pruning rates as future work.
  • While low SSIM is explained by semantic prioritization, it may be a concern for downstream tasks requiring pixel-level consistency (e.g., controllable editing, video consistency).
  • Timestep Gradient Checkpointing introduces roughly 2× runtime overhead, a trade-off between memory and computation time.
  • The end-to-end objective only matches the final latent, lacking explicit constraints on intermediate steps, which may affect interactive multi-step generation.

Related Work & Insights

  • Differentiable Masking/L0 Pruning: Inherits ideas from hard-concrete relaxation (Louizos et al.) and structural neuron masks in LLM-Pruner (Ma et al.), systematically adapting them to vision generative models for the first time.
  • Diffusion Pruning Comparison: Unlike DiffPruning (gradient proxy importance), BK-SDM (block removal + feature distillation), or FLUX-Lite (post-training community pruning), EcoDiff’s differentiation lies in its "end-to-end objective + extremely low compute (10 A100h / 100 samples)."
  • Inspiration: For any "Markovian multi-step generation" process (video generation, autoregressive diffusion, world models), an "end-to-end compression objective focused on the terminal point + cross-step gradient checkpointing" represents a transferable paradigm.

Rating

  • Novelty: ⭐⭐⭐⭐ The combination of end-to-end differentiable masking and timestep gradient checkpointing brings low-cost differentiable pruning to SDXL/FLUX-level models for the first time.
  • Experimental Thoroughness: ⭐⭐⭐⭐ Covers U-Net/DiT/distilled models, multiple baselines, sparsity/recovery ablations, and memory/time complexity validation across FID, CLIP, and SSIM.
  • Writing Quality: ⭐⭐⭐⭐ Progresses logically from motivation to challenges (flaws in step-wise loss) to method and engineering enablers.
  • Value: ⭐⭐⭐⭐ Compresses massive generative models with minimal compute, addressing deployment cost and carbon footprint in a highly practical way.