# Joint Diffusion Models in Continual Learning
| Info | Content |
|---|---|
| Conference | ICCV 2025 |
| arXiv | 2411.08224 |
| Code | GitHub |
| Area | Continual Learning · Diffusion Models · Generative Replay |
| Keywords | continual learning, generative replay, joint diffusion, knowledge distillation, catastrophic forgetting |
## TL;DR
This paper proposes JDCL, which unifies a classifier and a diffusion generative model into a single jointly parameterized network. Combined with knowledge distillation and a two-stage training strategy, JDCL substantially alleviates catastrophic forgetting in generative replay-based continual learning, surpassing existing generative replay methods.
## Background & Motivation

### Catastrophic Forgetting and Generative Replay
Neural networks tend to suffer a dramatic performance drop on previously learned tasks when trained on new task data, a phenomenon known as catastrophic forgetting. Generative Replay (GR) methods address this issue by synthesizing past data using generative models. However, conventional approaches have fundamental limitations:

- **Decoupling of generative model and classifier**: In the standard pipeline, the generative model and classifier are trained independently. The classifier's performance then hinges on the quality of generated samples, creating a knowledge-transfer bottleneck between the two components.
- **Degradation of generated sample quality**: Even state-of-the-art diffusion models cannot perfectly model the data distribution, so repeatedly retraining the classifier on generated data compounds errors and steadily degrades performance.
- **Plasticity–stability trade-off**: Existing methods either sacrifice plasticity (failing to learn new tasks) or stability (rapidly forgetting old tasks).
### Core Observation
The authors conduct a key experiment (Fig. 1): a classifier and a joint diffusion model are first trained to convergence on CIFAR-10, then training is continued using only data generated by the diffusion model. The results show that:

- The performance of the standalone classifier drops sharply
- The joint diffusion model exhibits significantly smaller degradation
- Adding knowledge distillation further reduces the degradation
This demonstrates that joint modeling combined with knowledge distillation is critical for knowledge retention in generative replay-based continual learning.
## Method

### Overall Architecture
JDCL comprises three core components: joint diffusion modeling, two-stage local-global training, and knowledge distillation.
### 1. Joint Diffusion Model
The UNet denoising network and the classifier are unified under the same parameterization. The UNet encoder \(e_\nu\) extracts a set of features \(\mathcal{Z}_t = \{z_t^1, z_t^2, \ldots, z_t^n\}\) from different layers, which are aggregated into a vector \(z_t = f(\mathcal{Z}_t)\) via average pooling and fed into the classifier \(g_\omega\) for class prediction.
The joint distribution over images and labels is factorized as

\[
p_{\nu,\omega}(x_0, y) = p_\nu(x_0)\, p_{\nu,\omega}(y \mid x_0),
\]

where \(p_\nu(x_0)\) is the marginal modeled by the diffusion process and \(p_{\nu,\omega}(y \mid x_0)\) is defined by the classifier head on the pooled encoder features.

The joint training loss is the sum of the two objectives:

\[
L_t = L_{t,\text{diff}} + L_{t,\text{class}},
\]

where the diffusion loss adopts the simplified DDPM objective \(L_{t,\text{diff}} = \mathbb{E}[\|\epsilon - \hat{\epsilon}\|^2]\) and the classification loss \(L_{t,\text{class}}\) is standard cross-entropy.
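To make the shared parameterization concrete, here is a minimal PyTorch-style sketch of the joint forward pass and loss. The module names, constructor signature, and the average-pool-then-concatenate choice for \(f(\mathcal{Z}_t)\) are illustrative assumptions, not the authors' implementation.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

class JointDiffusionModel(nn.Module):
    """Sketch: UNet denoiser and classifier under one parameterization."""

    def __init__(self, encoder, decoder, feat_dim, num_classes):
        super().__init__()
        self.encoder = encoder            # e_nu: multi-scale features Z_t
        self.decoder = decoder            # UNet decoder: predicts eps_hat
        self.classifier = nn.Linear(feat_dim, num_classes)  # g_omega

    def pool(self, feats):
        # f(Z_t): average-pool each 4D feature map, then concatenate
        return torch.cat([z.mean(dim=(2, 3)) for z in feats], dim=1)

    def forward(self, x_t, t):
        feats = self.encoder(x_t, t)                 # Z_t = {z_t^1, ..., z_t^n}
        eps_hat = self.decoder(feats, t)             # generative branch
        logits = self.classifier(self.pool(feats))   # discriminative branch
        return eps_hat, logits


def joint_loss(model, x_t, t, eps, y):
    # L_t = L_{t,diff} + L_{t,class}
    eps_hat, logits = model(x_t, t)
    return F.mse_loss(eps_hat, eps) + F.cross_entropy(logits, y)
```

Both branches backpropagate into the same encoder, which is what couples the generative and discriminative objectives.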
### 2. Two-Stage Local-Global Training
To balance plasticity and stability, a two-stage scheme is employed (a minimal code sketch follows the list):
- Local stage: A copy of the current global model is trained exclusively on new task data \(D_\tau\), ensuring full plasticity.
- Global stage: The global model generates data \(S_{1\ldots\tau-1}\) for old tasks, while the local model generates data \(S_\tau\) for the current task; the global model is then fine-tuned on the combined data.
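A compact sketch of one continual-learning step under this scheme. The callables `fit` (ordinary joint training, optionally with a distillation teacher) and `sample` (class-conditional generation from the diffusion model) are placeholders; none of these names come from the paper's code.

```python
import copy

def train_task(global_model, new_data, old_classes, new_classes, fit, sample):
    """One continual-learning step of the two-stage local-global scheme.

    fit(model, data, teacher=None) and sample(model, classes) are assumed
    helpers for joint training and class-conditional generation.
    """
    # Local stage: a copy of the global model trains only on D_tau,
    # so plasticity is not constrained by replay or distillation.
    local_model = copy.deepcopy(global_model)
    fit(local_model, new_data)

    # Global stage: replay S_{1..tau-1} from the global model and S_tau
    # from the local model, then fine-tune the global model on the union,
    # distilling from a frozen copy of itself.
    replay_old = sample(global_model, old_classes)   # S_{1..tau-1}
    replay_new = sample(local_model, new_classes)    # S_tau
    teacher = copy.deepcopy(global_model)            # frozen KD target
    fit(global_model, replay_old + replay_new, teacher=teacher)
    return global_model
```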
### 3. Knowledge Distillation
Knowledge distillation is applied to both the diffusion and classification components:
Diffusion distillation loss, where the current denoiser is regressed onto the noise predictions of a frozen copy of the previous global model:

\[
L_{\text{KD-diff}} = \mathbb{E}_{x, t}\!\left[\big\|\hat{\epsilon}(x_t, t) - \hat{\epsilon}_{\text{old}}(x_t, t)\big\|^2\right]
\]

Classification distillation loss, where the classifier's soft predictions are matched to those of the previous model:

\[
L_{\text{KD-class}} = \mathrm{KL}\big(p_{\text{old}}(y \mid x_t)\,\|\,p(y \mid x_t)\big)
\]

The final continual-learning objective combines the joint loss with both distillation terms (with weighting coefficients \(\lambda\)):

\[
L = L_{t,\text{diff}} + L_{t,\text{class}} + \lambda_{\text{diff}}\, L_{\text{KD-diff}} + \lambda_{\text{class}}\, L_{\text{KD-class}}
\]
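Putting the pieces together, a sketch of how the full objective could be computed during the global stage, reusing the joint model from the earlier sketch. The λ weights, their defaults, and the KL direction are assumptions.

```python
import torch
import torch.nn.functional as F

def continual_loss(model, teacher, x_t, t, eps, y,
                   lam_diff=1.0, lam_cls=1.0):
    """Joint loss plus both distillation terms (weights are assumptions)."""
    eps_hat, logits = model(x_t, t)
    with torch.no_grad():                       # teacher = frozen previous model
        eps_old, logits_old = teacher(x_t, t)

    # L_{t,diff} + L_{t,class}
    loss = F.mse_loss(eps_hat, eps) + F.cross_entropy(logits, y)

    # Diffusion distillation: match the frozen model's noise prediction
    kd_diff = F.mse_loss(eps_hat, eps_old)

    # Classification distillation: KL between soft predictions
    kd_cls = F.kl_div(F.log_softmax(logits, dim=1),
                      F.softmax(logits_old, dim=1),
                      reduction="batchmean")

    return loss + lam_diff * kd_diff + lam_cls * kd_cls
```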
### 4. Semi-Supervised Extension
Leveraging the flexibility of joint modeling, all unlabeled data contributes to the generative objective, while pseudo-labeling with consistency regularization supplies the classification signal: weakly augmented samples generate pseudo-labels, and strongly augmented samples are used to compute the semi-supervised classification loss.
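A FixMatch-style sketch of what the unlabeled branch could look like. The wrappers `model.diffusion_loss` and `model.classify`, the augmentation callables, and the confidence threshold `tau` are assumed names and values, not the paper's API.

```python
import torch
import torch.nn.functional as F

def unlabeled_step(model, x_u, weak_aug, strong_aug, tau=0.95):
    """Pseudo-labeling + consistency term for a batch of unlabeled images."""
    # Unlabeled images always contribute to the generative objective.
    gen_loss = model.diffusion_loss(x_u)

    # Weak augmentation -> pseudo-label (no gradient through the target).
    with torch.no_grad():
        probs = F.softmax(model.classify(weak_aug(x_u)), dim=1)
        conf, pseudo_y = probs.max(dim=1)
        mask = (conf >= tau).float()          # keep only confident labels

    # Strong augmentation -> consistency loss against the pseudo-label.
    logits_s = model.classify(strong_aug(x_u))
    cls_loss = (F.cross_entropy(logits_s, pseudo_y, reduction="none")
                * mask).mean()

    return gen_loss + cls_loss
```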
## Key Experimental Results

### Main Results: Fully Supervised Continual Learning
| Method | CIFAR-10 (T=5) | CIFAR-100 (T=5) | CIFAR-100 (T=10) | ImageNet100 (T=5) |
|---|---|---|---|---|
| Continual Joint (upper bound) | 86.41 | 73.07 | 64.15 | 50.59 |
| GUIDE (Prev. SOTA) | 64.47 | 41.66 | 26.13 | 39.07 |
| DGR diffusion | 59.00 | 28.25 | 15.90 | 23.92 |
| GFR | 26.70 | 34.80 | 21.90 | 32.95 |
| JDCL (Ours) | 83.69 | 47.95 | 29.04 | 54.53 |
Key findings: JDCL surpasses the previous SOTA (GUIDE) by 19+ points on CIFAR-10 (~30% relative gain) and by 15+ points on ImageNet100 (~40% relative gain). On CIFAR-10 it approaches the Continual Joint upper bound, and on ImageNet100 it even exceeds it.
### Semi-Supervised Continual Learning
| Method | CIFAR-10 (0.8% labels) | CIFAR-10 (5% labels) | CIFAR-100 (0.8% labels) | CIFAR-100 (5% labels) |
|---|---|---|---|---|
| NNCSL (buffer 5120) | 73.7 | 79.3 | 27.5 | 46.0 |
| JDCL (no buffer) | 78.93 | 79.96 | 22.19 | 26.39 |
On CIFAR-10, JDCL without any memory buffer outperforms NNCSL, which relies on a replay buffer of 5,120 samples.
### Ablation Study
| Joint Modeling | Knowledge Distillation | Two-Stage Training | Accuracy |
|---|---|---|---|
| ✓ | ✓ | ✓ | 83.7 |
| ✗ | ✓ | ✓ | 68.4 |
| ✓ | ✗ | ✓ | 48.2 |
Both joint modeling and knowledge distillation are indispensable on CIFAR-10: removing knowledge distillation alone causes a 35.5-point drop, while removing joint modeling costs 15.3 points.
## Highlights & Insights
- Core innovation: Unifying the generative and discriminative models under a shared parameterization removes the knowledge-transfer bottleneck of the usual two-model replay pipeline.
- Self-supervised replay: The joint model uses its own generative component for replay; because the generative and discriminative parts share the same encoder, the distribution shift between replayed and real data is substantially mitigated.
- Flexibility: The local and global training stages are fully decoupled, making the semi-supervised extension straightforward.
- Computational efficiency: Compared to training a separate generative model and classifier, joint training reduces overall computational overhead.
## Limitations & Future Work
- Performance under the CIFAR-100 semi-supervised setting is suboptimal, likely because weak supervision leads to an imbalance between the classification and generation objectives during joint training.
- Experiments are conducted only on small-scale datasets (CIFAR, ImageNet100), and evaluation on large-scale scenarios is lacking.
- The model is built on a UNet backbone; integration with Transformer-based architectures remains unexplored.
## Related Work & Insights
- Generative replay methods: DGR, RTF, DDGR, GUIDE, and others using GANs, VAEs, or diffusion models for replay.
- Regularization methods: EWC, LwF, and others that constrain updates to important parameters.
- H-space utilization: Leveraging intermediate UNet features for downstream tasks such as segmentation and classification.
- Semi-supervised continual learning: ORDisCo, CCIC, NNCSL, and related approaches.
## Rating
| Dimension | Score |
|---|---|
| Novelty | ⭐⭐⭐⭐ |
| Experimental Thoroughness | ⭐⭐⭐⭐⭐ |
| Writing Quality | ⭐⭐⭐⭐ |
| Value | ⭐⭐⭐⭐ |
| Overall Recommendation | ⭐⭐⭐⭐ |