
Joint Diffusion Models in Continual Learning

| Info | Content |
|---|---|
| Conference | ICCV 2025 |
| arXiv | 2411.08224 |
| Code | GitHub |
| Area | Continual Learning · Diffusion Models · Generative Replay |
| Keywords | continual learning, generative replay, joint diffusion, knowledge distillation, catastrophic forgetting |

TL;DR

This paper proposes JDCL, which unifies a classifier and a diffusion generative model into a single jointly parameterized network. Combined with knowledge distillation and a two-stage training strategy, JDCL substantially alleviates catastrophic forgetting in generative replay-based continual learning, surpassing existing generative replay methods.

Background & Motivation

Catastrophic Forgetting and Generative Replay

Neural networks tend to suffer a dramatic performance drop on previously learned tasks when trained on new task data — a phenomenon known as catastrophic forgetting. Generative Replay (GR) methods address this issue by synthesizing past data using generative models. However, conventional approaches suffer from fundamental limitations:

Decoupling of generative model and classifier: In the standard pipeline, the generative model and classifier are trained independently. The classifier's performance is highly dependent on the quality of generated samples, creating a knowledge transfer bottleneck between the two components.

Degradation of generated sample quality: Even state-of-the-art diffusion models cannot perfectly model the data distribution. Repeatedly training the classifier on generated data leads to continuous performance degradation.

Plasticity–stability trade-off: Existing methods either sacrifice plasticity (failing to learn new tasks) or stability (rapidly forgetting old tasks).

Core Observation

The authors conduct a key experiment (Fig. 1): a classifier and a joint diffusion model are first trained to convergence on CIFAR-10, then training is continued using only data generated by the diffusion model. The results show that:

  • The performance of the standalone classifier drops sharply.
  • The joint diffusion model exhibits significantly smaller degradation.
  • Adding knowledge distillation reduces the degradation further.

This indicates that joint modeling, combined with knowledge distillation, is key to retaining knowledge in generative replay-based continual learning.

Method

Overall Architecture

JDCL comprises three core components: joint diffusion modeling, two-stage local-global training, and knowledge distillation.

1. Joint Diffusion Model

The UNet denoising network and the classifier are unified under the same parameterization. The UNet encoder \(e_\nu\) extracts a set of features \(\mathcal{Z}_t = \{z_t^1, z_t^2, \ldots, z_t^n\}\) from different layers, which are aggregated into a vector \(z_t = f(\mathcal{Z}_t)\) via average pooling and fed into the classifier \(g_\omega\) for class prediction.
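A minimal pure-Python sketch of the aggregation step follows. It assumes each layer's feature map is average-pooled over its spatial dimensions and the pooled vectors are then concatenated; the source only states that the features are average-pooled into \(z_t\), so the concatenation is an assumption. `aggregate_features` and `linear_classifier` are hypothetical stand-ins for \(f\) and \(g_\omega\), and nested lists stand in for tensors.

```python
import math


def global_avg_pool(feature_map):
    """Average a [channels][height][width] nested list over its spatial dimensions."""
    pooled = []
    for channel in feature_map:
        values = [v for row in channel for v in row]
        pooled.append(sum(values) / len(values))
    return pooled


def aggregate_features(feature_maps):
    """Hypothetical f(Z_t): pool every encoder layer, then concatenate (assumed)."""
    z = []
    for fmap in feature_maps:
        z.extend(global_avg_pool(fmap))
    return z


def linear_classifier(z, weights, bias):
    """Hypothetical stand-in for g_omega: one logit per class."""
    return [sum(w * x for w, x in zip(row, z)) + b for row, b in zip(weights, bias)]


def softmax(logits):
    """Numerically stable softmax over a list of logits."""
    m = max(logits)
    exps = [math.exp(l - m) for l in logits]
    total = sum(exps)
    return [e / total for e in exps]


# Two "encoder layers" with different channel counts and resolutions.
layer_a = [[[1.0, 3.0], [5.0, 7.0]]]   # 1 channel, 2x2
layer_b = [[[2.0]], [[-2.0]]]          # 2 channels, 1x1
z_t = aggregate_features([layer_a, layer_b])
logits = linear_classifier(z_t, weights=[[1.0, 0.0, 0.0], [0.0, 1.0, 1.0]], bias=[0.0, 0.0])
print(z_t, logits, softmax(logits))
```

Concatenating the pooled vectors keeps per-layer information even when layers have different channel counts; averaging them instead would require matching widths.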

The joint probability is modeled as:

\[p_{\nu,\psi,\omega}(x_{0:T}, y) = p_{\nu,\omega}(y|x_0) \cdot p_{\nu,\psi}(x_{0:T})\]
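Taking logarithms shows how the factorization splits into a discriminative term and a generative term, both depending on the shared encoder parameters \(\nu\). The second line is the standard DDPM variational bound, through which the generative term is trained in practice:

\[\log p_{\nu,\psi,\omega}(x_{0:T}, y) = \log p_{\nu,\omega}(y \mid x_0) + \log p_{\nu,\psi}(x_{0:T})\]

\[\log p_{\nu,\psi}(x_0) \ge \mathbb{E}_{q(x_{1:T} \mid x_0)}\left[\log \frac{p_{\nu,\psi}(x_{0:T})}{q(x_{1:T} \mid x_0)}\right]\]

Because both terms depend on \(\nu\), gradients from the classification and denoising objectives update the same encoder features.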

The joint training loss is:

\[L_{JD}(\nu,\psi,\omega) = \alpha \cdot L_{\text{class}}(\nu,\omega) - \sum_{t=2}^{T} L_{t,\text{diff}}(\nu,\psi) - L_0 - L_T\]

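where \(\alpha\) weights the classification loss (standard cross-entropy), each diffusion term adopts the simplified DDPM objective \(L_{t,\text{diff}} = \mathbb{E}[\|\epsilon - \hat{\epsilon}\|^2]\) with \(\hat{\epsilon}\) the noise predicted by the UNet, and \(L_0\) and \(L_T\) are the remaining first-step and final-step terms of the diffusion variational bound.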
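A minimal pure-Python sketch of \(L_{JD}\) for a single sample and a single sampled timestep follows. It assumes the minus signs in the equation follow the ELBO (log-likelihood) convention, so that in practice one minimizes \(\alpha\) times the cross-entropy plus the noise-prediction MSE. The schedule value `alpha_bar_t`, the helper names, and the toy numbers are hypothetical; lists stand in for tensors.

```python
import math


def forward_noise(x0, alpha_bar_t, eps):
    """DDPM forward process: x_t = sqrt(alpha_bar_t) * x0 + sqrt(1 - alpha_bar_t) * eps."""
    return [math.sqrt(alpha_bar_t) * a + math.sqrt(1.0 - alpha_bar_t) * e
            for a, e in zip(x0, eps)]


def mse(a, b):
    """Mean squared error, i.e. the simplified DDPM objective ||eps - eps_hat||^2 / dim."""
    return sum((u - v) ** 2 for u, v in zip(a, b)) / len(a)


def cross_entropy(logits, label):
    """Standard cross-entropy for one sample, computed with log-sum-exp for stability."""
    m = max(logits)
    log_z = m + math.log(sum(math.exp(l - m) for l in logits))
    return log_z - logits[label]


def joint_loss(logits, label, eps, eps_hat, alpha=1.0):
    """alpha * L_class + L_diff for one sample and one timestep (minimization form, assumed)."""
    return alpha * cross_entropy(logits, label) + mse(eps, eps_hat)


x0, eps = [0.5, -0.5], [1.0, -1.0]
x_t = forward_noise(x0, alpha_bar_t=0.9, eps=eps)   # input the UNet would denoise
eps_hat = [0.0, -1.0]                               # pretend UNet noise prediction
print(x_t, joint_loss([0.0, 0.0], 0, eps, eps_hat, alpha=2.0))
```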

2. Two-Stage Local-Global Training

To balance plasticity and stability, a two-stage scheme is employed:

  • Local stage: A copy of the current global model is trained exclusively on new task data \(D_\tau\), ensuring full plasticity.
  • Global stage: The global model generates data \(S_{1\ldots\tau-1}\) for old tasks, while the local model generates data \(S_\tau\) for the current task; the global model is then fine-tuned on the combined data.
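The data flow of the two stages can be sketched as follows. This assumes a toy `ToyJointModel` that merely records its training data and returns placeholder samples in place of real diffusion sampling; the class, `train_task`, and `n_per_task` are hypothetical.

```python
import copy


class ToyJointModel:
    """Hypothetical stand-in for the joint diffusion model.

    It records what it was trained on and "generates" labelled placeholders,
    so only the data flow between the two stages is illustrated.
    """

    def __init__(self):
        self.seen = []            # (task_id, source) pairs used for training
        self.known_tasks = set()  # tasks this model can generate samples for

    def train(self, data):
        self.seen.extend(data)
        self.known_tasks.update(task for task, _ in data)

    def generate(self, tasks, n_per_task):
        # Placeholder for sampling from the diffusion component.
        return [(task, "synthetic") for task in sorted(tasks) for _ in range(n_per_task)]


def train_task(global_model, new_data, task_id, n_per_task=2):
    # Local stage: a copy of the global model sees only the new task data D_tau.
    local_model = copy.deepcopy(global_model)
    local_model.train(new_data)

    # Global stage: old tasks are replayed by the global model (S_1..tau-1),
    # the current task by the local model (S_tau); the global model is then
    # fine-tuned on the combined generated data.
    replay_old = global_model.generate(global_model.known_tasks, n_per_task)
    replay_new = local_model.generate({task_id}, n_per_task)
    global_model.train(replay_old + replay_new)
    return global_model, local_model


model = ToyJointModel()
model, local = train_task(model, [(0, "real"), (0, "real")], task_id=0)
model, local = train_task(model, [(1, "real")], task_id=1)
print(sorted(model.known_tasks), len(model.seen))
```

In this reading the global model is updated only on generated samples, so real data from the current task reaches it indirectly, through the local model's generator.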

3. Knowledge Distillation

Knowledge distillation is applied to both the diffusion and classification components:

Diffusion distillation loss:

\[L_{t,\text{diff}}^{KD}(\nu,\psi) = \mathbb{E}[\|\epsilon_f - \hat{\epsilon}\|^2]\]

Classification distillation loss:

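\[L_{\text{class}}^{KD}(\nu,\omega) = -\mathbb{E}\left[\sum_{k} \log \frac{\exp(\varphi_k)}{\sum_c \exp(\varphi_c)} \varphi_k^f\right]\]

where \(\hat{\epsilon}\) and \(\varphi\) are the student's noise prediction and class logits, and \(\epsilon_f\) and \(\varphi^f\) are the corresponding outputs of the teacher model \(f\).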
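Both distillation terms can be sketched in pure Python. The sketch assumes teacher and student are evaluated on the same noisy input, and it reads \(\varphi^f_k\) as the teacher's class probabilities (soft targets), the usual form of soft-label distillation. Temperature scaling, common in other distillation setups, is omitted, and all names are hypothetical.

```python
import math


def log_softmax(logits):
    """Log-softmax via log-sum-exp for numerical stability."""
    m = max(logits)
    log_z = m + math.log(sum(math.exp(l - m) for l in logits))
    return [l - log_z for l in logits]


def diffusion_kd_loss(eps_teacher, eps_student):
    """Mean squared gap between teacher (eps_f) and student (eps_hat) noise predictions."""
    return sum((t - s) ** 2 for t, s in zip(eps_teacher, eps_student)) / len(eps_teacher)


def class_kd_loss(student_logits, teacher_probs):
    """Soft-target cross-entropy: -sum_k teacher_k * log softmax(student_logits)_k."""
    return -sum(p * lp for p, lp in zip(teacher_probs, log_softmax(student_logits)))


print(diffusion_kd_loss([1.0, 2.0], [1.0, 0.0]))   # (0 + 4) / 2
print(class_kd_loss([0.0, 0.0], [0.5, 0.5]))        # log 2 for a uniform student
```

With a one-hot teacher distribution, `class_kd_loss` reduces to ordinary cross-entropy, which is why soft-target distillation is often described as cross-entropy against the teacher's predictions.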

The final continual learning objective is:

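\[L_{CL} = \mathbb{E}_{S_{1\ldots\tau-1}}[L_{JDKD}(\cdot; p^f)] + \mathbb{E}_{S_\tau}[L_{JDKD}(\cdot; p_n)] + \beta \cdot \mathbb{E}_{S_{1\ldots\tau}}[L_{JD}]\]

where \(\tau\) indexes the current task (as in the two-stage scheme), \(L_{JDKD}\) is the joint distillation loss combining the diffusion and classification distillation terms, \(p^f\) is the frozen global model serving as teacher on replayed old-task data, \(p_n\) is the local model serving as teacher on current-task data, and \(\beta\) weights the joint loss on all generated data.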
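Given per-sample loss values, the empirical version of \(L_{CL}\) is a sum of three batch means. This sketch assumes each distillation value has already been computed against the appropriate teacher (the frozen global model \(p^f\) for replayed old-task samples, the local model \(p_n\) for current-task samples); `continual_loss` and the default `beta` are hypothetical.

```python
def batch_mean(values):
    """Empirical expectation; an empty batch (e.g. no old tasks yet) contributes 0."""
    return sum(values) / len(values) if values else 0.0


def continual_loss(kd_old, kd_new, jd_all, beta=1.0):
    """L_CL = E_old[L_JDKD; p^f] + E_new[L_JDKD; p_n] + beta * E_all[L_JD].

    kd_old: distillation losses on generated old-task samples (teacher: p^f)
    kd_new: distillation losses on generated current-task samples (teacher: p_n)
    jd_all: joint losses L_JD on all generated samples
    """
    return batch_mean(kd_old) + batch_mean(kd_new) + beta * batch_mean(jd_all)


print(continual_loss([1.0, 3.0], [2.0], [0.5, 1.5], beta=2.0))
print(continual_loss([], [1.0], [1.0]))  # first task: no old-task replay
```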

4. Semi-Supervised Extension

The flexibility of joint modeling makes a semi-supervised extension straightforward. Unlabeled data can always train the generative component, since the diffusion loss needs no labels; for the classifier, training is combined with pseudo-labeling and consistency regularization: weakly augmented samples generate pseudo-labels, while strongly augmented samples are used to compute the semi-supervised loss.
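The pseudo-labeling step can be sketched in the style of FixMatch, which is an assumption here: the confidence threshold of 0.95, the use of hard pseudo-labels, and the helper names are not taken from the source. The weak view's prediction supplies the label, and the strong view is trained toward it only when that prediction is confident.

```python
import math


def softmax(logits):
    """Numerically stable softmax."""
    m = max(logits)
    exps = [math.exp(l - m) for l in logits]
    total = sum(exps)
    return [e / total for e in exps]


def pseudo_label_loss(weak_logits, strong_logits, threshold=0.95):
    """Return (loss, pseudo_label) for one unlabeled sample.

    The weakly augmented view yields a hard pseudo-label; if its confidence is
    below `threshold`, the sample is masked out (loss 0, label None). Otherwise
    the loss is the cross-entropy of the strongly augmented view against it.
    """
    probs = softmax(weak_logits)
    confidence = max(probs)
    if confidence < threshold:
        return 0.0, None
    label = probs.index(confidence)
    m = max(strong_logits)
    log_z = m + math.log(sum(math.exp(l - m) for l in strong_logits))
    return log_z - strong_logits[label], label


print(pseudo_label_loss([10.0, 0.0], [0.0, 0.0]))  # confident: label 0
print(pseudo_label_loss([0.0, 0.0], [3.0, 0.0]))   # uncertain: masked out
```

In JDCL the same unlabeled images also feed the denoising loss of the generative component, which needs no labels at all.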

Key Experimental Results

Main Results: Fully Supervised Continual Learning

| Method | CIFAR-10 (T=5) | CIFAR-100 (T=5) | CIFAR-100 (T=10) | ImageNet100 (T=5) |
|---|---|---|---|---|
| Continual Joint (upper bound) | 86.41 | 73.07 | 64.15 | 50.59 |
| GUIDE (prev. SOTA) | 64.47 | 41.66 | 26.13 | 39.07 |
| DGR diffusion | 59.00 | 28.25 | 15.90 | 23.92 |
| GFR | 26.70 | 34.80 | 21.90 | 32.95 |
| JDCL (ours) | 83.69 | 47.95 | 29.04 | 54.53 |

T is the number of tasks.

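Key findings: JDCL surpasses the previous SOTA (GUIDE) by 19+ points on CIFAR-10 (~30% relative gain) and by 15+ points on ImageNet100 (~40% relative gain). On CIFAR-10 it comes within about 3 points of the Continual Joint upper bound, which corresponds to an unlimited replay buffer, and on ImageNet100 it even exceeds that reference (54.53 vs. 50.59).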
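The headline numbers can be re-derived from the table. The short script below recomputes the absolute gain over GUIDE, the relative gain, and the remaining gap to the Continual Joint reference for the two datasets cited.

```python
# Accuracies copied from the main results table.
results = {
    "CIFAR-10 (T=5)": {"GUIDE": 64.47, "JDCL": 83.69, "Continual Joint": 86.41},
    "ImageNet100 (T=5)": {"GUIDE": 39.07, "JDCL": 54.53, "Continual Joint": 50.59},
}


def gains(row):
    """Absolute gain over GUIDE (points), relative gain (%), gap to Continual Joint (points)."""
    absolute = row["JDCL"] - row["GUIDE"]
    relative = 100.0 * absolute / row["GUIDE"]
    gap = row["Continual Joint"] - row["JDCL"]
    return round(absolute, 2), round(relative, 1), round(gap, 2)


for name, row in results.items():
    print(name, gains(row))
```

A negative gap means JDCL exceeds the reported Continual Joint number, as it does on ImageNet100.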

Semi-Supervised Continual Learning

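| Method | CIFAR-10 (0.8% labeled) | CIFAR-10 (5% labeled) | CIFAR-100 (0.8% labeled) | CIFAR-100 (5% labeled) |
|---|---|---|---|---|
| NNCSL (buffer of 5,120) | 73.7 | 79.3 | 27.5 | 46.0 |
| JDCL (no buffer) | 78.93 | 79.96 | 22.19 | 26.39 |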

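On CIFAR-10, JDCL without any memory buffer outperforms NNCSL, which relies on a replay buffer of 5,120 samples, at both label ratios. On CIFAR-100, however, JDCL trails NNCSL by about 5 points at 0.8% labels and about 20 points at 5% labels.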

Ablation Study

Joint Modeling Knowledge Distillation Two-Stage Training Accuracy
83.7
68.4
48.2

Both joint modeling and knowledge distillation are indispensable. Knowledge distillation in particular contributes substantially: removing it lowers accuracy from 83.7 to 48.2, a 35.5-point drop.

Highlights & Insights

  1. Core innovation: Unifying the generative and discriminative models under a shared parameterization removes the knowledge transfer bottleneck between a separately trained generator and classifier.
  2. Self-replay: The joint model replays data with its own generative component; because the generative and discriminative parts share the same encoder, the classifier learns from features of the very model that produces the replay data, which mitigates the distribution shift seen with a separate generator.
  3. Flexibility: The local and global training stages are fully decoupled, making the semi-supervised extension straightforward.
  4. Computational efficiency: A single shared network replaces a separate generative model and classifier, which reduces overall training overhead.

Limitations & Future Work

  • Performance under the CIFAR-100 semi-supervised setting is suboptimal, likely because weak supervision leads to an imbalance between the classification and generation objectives during joint training.
  • Experiments are conducted only on small-scale datasets (CIFAR, ImageNet100), and evaluation on large-scale scenarios is lacking.
  • The model is built on a UNet backbone; integration with Transformer-based architectures remains unexplored.

Related Work

  • Generative replay methods: DGR, RTF, DDGR, GUIDE, and others using GANs, VAEs, or diffusion models for replay.
  • Regularization methods: EWC, LwF, and others that constrain updates to important parameters.
  • H-space utilization: Leveraging intermediate UNet features for downstream tasks such as segmentation and classification.
  • Semi-supervised continual learning: ORDisCo, CCIC, NNCSL, and related approaches.

Rating

| Dimension | Score |
|---|---|
| Novelty | ⭐⭐⭐⭐ |
| Experimental Thoroughness | ⭐⭐⭐⭐⭐ |
| Writing Quality | ⭐⭐⭐⭐ |
| Value | ⭐⭐⭐⭐ |
| Overall Recommendation | ⭐⭐⭐⭐ |