Soft Prompt Generation for Domain Generalization
Conference: ECCV 2024
arXiv: 2404.19286
Code: https://github.com/renytek13/Soft-Prompt-Generation-with-CGAN
Area: Image Generation
Keywords: domain generalization, prompt learning, CLIP, CGAN, generative model
TL;DR
This paper proposes SPG (Soft Prompt Generation), which introduces generative models to VLM prompt learning for the first time. By dynamically generating instance-specific soft prompts from images via CGAN, it stores domain knowledge in the generative model rather than prompt vectors, achieving superior domain generalization performance.
Background & Motivation
- VLMs such as CLIP achieve remarkable adaptation performance on downstream tasks through soft prompts.
- However, their generalization performance drops significantly under domain shifts.
- Limitations of prior prompt learning methods:
  - CoOp: Learns fixed prompts, which overfit the training distribution.
  - CoCoOp/DPL: Use MLPs to generate residual vectors that adjust fixed prompts, but simple MLPs struggle to capture complex image-prompt relationships.
  - CAE: Introduces a domain bank, but the prompts lack diversity.
- Shift in Core Idea: Domain knowledge is stored in the generative model rather than in the prompts, so the model can dynamically generate a domain-adaptive prompt for each image.
Method
Overall Architecture
SPG consists of a two-stage training scheme followed by an inference stage:
1. Stage I: Learn a domain prompt label for each source domain.
2. Stage II: Train a CGAN to generate the corresponding domain prompts from images.
3. Inference: The CGAN generator directly produces instance-specific soft prompts for target-domain images.
Key Designs
1. Domain Prompt Label Learning (Training Stage I)
- Independently optimize one soft prompt \(v^{d_i}\) for each source domain \(d_i\).
- Formulated under the CoOp framework with a context length of 4.
- Optimized via cross-entropy loss over the samples of domain \(d_i\): \(v^{d_i*} = \operatorname{argmin}_{v^{d_i}} \mathbb{E}_{(x, y) \sim d_i}[-\log p(y|x, v^{d_i})]\).
- The domain prompt labels encapsulate rich domain information and serve as the training targets for Stage II.
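For reference, the likelihood inside the Stage I objective can be written out under the CoOp formulation, reusing the notation of the inference step (text encoder \(g\), image embedding \(f(x)\), temperature \(\tau\), class token \(c_j\)). The class count \(K\) is a symbol introduced here for illustration, and the paper's exact notation may differ:

\[
p(y \mid x, v^{d_i}) = \frac{\exp\left(\langle g([v^{d_i}, c_y]), f(x) \rangle / \tau\right)}{\sum_{j=1}^{K} \exp\left(\langle g([v^{d_i}, c_j]), f(x) \rangle / \tau\right)}
\]

Only \(v^{d_i}\) is trainable in this expression; the encoders stay frozen, as in CoOp.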
2. CGAN Pre-training (Training Stage II)
- Generator \(G\): Receives noise \(z\) and the image embedding \(f(x)\), and generates a soft prompt.
  - Input: \([z, f(x)]\) (concatenation of noise and CLIP image features).
  - Output: A prompt vector with the same shape as the domain prompt labels.
- Discriminator \(D\): Receives a prompt (either a domain prompt label or a generated prompt) together with the image embedding, and classifies the prompt as real or fake.
- Adversarial training objective: \(\min_G \max_D V(G, D)\).
- To enhance stability, a gradient clipping strategy is incorporated.
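As a point of reference, the value function \(V(G, D)\) can be sketched by substituting the quantities above into the standard conditional GAN objective. This is the textbook CGAN form with the domain prompt label \(v^{d_i*}\) as the real sample and \(f(x)\) as the condition; it is an assumption that the paper uses exactly this form:

\[
\min_G \max_D V(G, D) = \mathbb{E}_{x}\left[\log D\left(v^{d_i*} \mid f(x)\right)\right] + \mathbb{E}_{x, z}\left[\log\left(1 - D\left(G(z \mid f(x)) \mid f(x)\right)\right)\right]
\]

Here \(x\) is drawn from source domain \(d_i\), so every image is paired with the prompt label of its own domain.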
3. Inference
- Only the generator of CGAN is utilized.
- Given a target-domain image \(x\), compute its embedding \(f(x)\), sample noise \(z\), and generate the instance-specific soft prompt \(G(z|f(x))\).
- Predict with \(p(y=i|x) = \operatorname{softmax}_i(\langle w_i, f(x) \rangle / \tau)\), where the softmax runs over all classes,
- \(w_i = g([G(z|f(x)), c_i])\), \(g\) denotes the text encoder, \(c_i\) is the token of the \(i\)-th class name, and \(\tau\) is the temperature.
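The prediction rule above can be illustrated with a minimal pure-Python sketch. It assumes that \(\langle \cdot, \cdot \rangle\) denotes cosine similarity, as in CLIP, and that the text features \(w_i\) have already been produced by the text encoder from the generated prompt. The function names, the toy vectors, and the temperature value are hypothetical and not taken from the paper or its repository.

```python
import math


def l2_normalize(vec):
    """Scale a vector to unit length so dot products become cosine similarities."""
    norm = math.sqrt(sum(v * v for v in vec))
    return [v / norm for v in vec]


def class_probabilities(image_feature, text_features, temperature):
    """Softmax over temperature-scaled cosine similarities <w_i, f(x)> / tau."""
    f = l2_normalize(image_feature)
    logits = []
    for w in text_features:
        w = l2_normalize(w)
        logits.append(sum(a * b for a, b in zip(w, f)) / temperature)
    # Subtract the maximum logit for numerical stability before exponentiating.
    top = max(logits)
    exps = [math.exp(l - top) for l in logits]
    total = sum(exps)
    return [e / total for e in exps]


# Toy stand-ins: f(x) for one image and w_i for three classes (illustrative only).
image_feature = [1.0, 0.0]
text_features = [[1.0, 0.0], [0.0, 1.0], [-1.0, 0.0]]
probs = class_probabilities(image_feature, text_features, temperature=0.5)
predicted_class = probs.index(max(probs))
```

In the real pipeline each \(w_i\) depends on the image through \(G(z|f(x))\), so the text features are recomputed per image rather than cached once per class.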
Loss & Training
- Stage I: SGD optimizer, batch size 32, context length of 4.
- Stage II: AdamW optimizer, weight decay \(1\mathrm{e}{-4}\).
- Learning rate: \(2\mathrm{e}{-3}\) (PACS/VLCS/TerraInc) / \(2\mathrm{e}{-4}\) (OfficeHome/DomainNet).
- Gradient clipping is applied to stabilize CGAN training.
- Backbones: ResNet-50 and ViT-B/16.
- Model selection: The checkpoint with the highest accuracy on the training-domain validation set is chosen.
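The summary does not say whether clipping is done by norm or by value, nor which threshold is used. As one common option, the sketch below shows global-norm clipping in pure Python; the function name and the `max_norm` value are hypothetical. In PyTorch the comparable utility is `torch.nn.utils.clip_grad_norm_`.

```python
import math


def clip_by_global_norm(grads, max_norm):
    """Rescale all gradients so their joint L2 norm does not exceed max_norm.

    grads is a list of per-parameter gradient lists. Returns the (possibly
    rescaled) gradients and the norm measured before clipping.
    """
    total_norm = math.sqrt(sum(g * g for grad in grads for g in grad))
    if total_norm <= max_norm:
        return [list(grad) for grad in grads], total_norm
    scale = max_norm / total_norm
    return [[g * scale for g in grad] for grad in grads], total_norm


# Toy gradients for two parameter tensors (illustrative values only).
toy_grads = [[3.0, 0.0], [0.0, 4.0]]
clipped, norm_before = clip_by_global_norm(toy_grads, max_norm=1.0)
norm_after = math.sqrt(sum(g * g for grad in clipped for g in grad))
```

Because every gradient is multiplied by the same factor, the update direction is preserved and only its magnitude is bounded, which is why this form is often used to tame adversarial training.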
Key Experimental Results
Main Results (Multi-source DG, ViT-B/16)
| Method | PACS | VLCS | OfficeHome | TerraInc | DomainNet | Average |
|---|---|---|---|---|---|---|
| ZS-CLIP | 95.7 | 82.6 | 80.4 | 28.0 | 57.6 | 68.9 |
| CoOp | 95.4 | 82.5 | 82.0 | 33.0 | 56.2 | 69.8 |
| CoCoOp | 96.0 | 81.7 | 81.1 | 33.8 | 56.9 | 69.9 |
| MaPLe | 96.3 | 82.7 | 82.6 | 34.5 | 57.7 | 70.8 |
| SPG | 96.8 | 83.1 | 83.0 | 37.8 | 58.7 | 71.9 |
Ablation Study
| Variant | PACS | VLCS | Average trend |
|---|---|---|---|
| w/o Domain prompt labels (using unified prompt) | 95.2 | 82.0 | Lower than full SPG |
| w/o CGAN (directly using domain prompt) | 95.8 | 82.4 | Lower than full SPG |
| Replacing CGAN with MLP | 96.0 | 82.5 | Lower than full SPG |
| Full SPG | 96.8 | 83.1 | Best |
Key Findings
- SPG achieves the best performance on all five DG benchmarks, with an average gain of 1.1 percentage points over MaPLe (71.9 vs. 70.8).
- Domain prompt labels are crucial, providing high-quality training targets for the CGAN.
- CGAN outperforms MLP, capturing more complex image-prompt mapping relations.
- The largest gain (+3.3 percentage points over MaPLe, 37.8 vs. 34.5) is observed on TerraIncognita, the benchmark with the largest domain gap.
- SPG is also effective in Single-source DG and Multi-target DG settings.
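The averages and gains quoted above can be re-derived from the main results table. The short check below copies the table values verbatim and recomputes the Average column and the per-benchmark gains of SPG over MaPLe; the variable names are chosen here for illustration.

```python
benchmarks = ["PACS", "VLCS", "OfficeHome", "TerraInc", "DomainNet"]

# Accuracy values copied from the main results table (ViT-B/16).
results = {
    "ZS-CLIP": [95.7, 82.6, 80.4, 28.0, 57.6],
    "CoOp": [95.4, 82.5, 82.0, 33.0, 56.2],
    "CoCoOp": [96.0, 81.7, 81.1, 33.8, 56.9],
    "MaPLe": [96.3, 82.7, 82.6, 34.5, 57.7],
    "SPG": [96.8, 83.1, 83.0, 37.8, 58.7],
}

# Mean over the five benchmarks, rounded to one decimal as in the table.
averages = {name: round(sum(acc) / len(acc), 1) for name, acc in results.items()}

# Per-benchmark gain of SPG over MaPLe, in percentage points.
gains = {
    bench: round(results["SPG"][i] - results["MaPLe"][i], 1)
    for i, bench in enumerate(benchmarks)
}
average_gain = round(averages["SPG"] - averages["MaPLe"], 1)
largest_gain_benchmark = max(gains, key=gains.get)

print(averages)
print(gains, average_gain, largest_gain_benchmark)
```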
Highlights & Insights
- Paradigm Innovation: Integrates generative models into VLM prompt learning for the first time, pioneering a new "prompt generation" paradigm.
- Shift of Domain Knowledge Storage: Transitions from prompt vectors to generative model parameters, allowing for greater flexibility.
- Diversity of Instance-Specific Prompts: CGAN naturally supports diverse prompt generation, with noise input introducing beneficial stochasticity.
- Clever Two-Stage Training Strategy: Domain prompt labels act as a "teacher" to guide CGAN in acquiring domain knowledge.
- Simple and Effective: SOTA results are achieved using the classic CGAN architecture, which is simple and easy to implement.
Limitations & Future Work
- CGAN training instability necessitates techniques such as gradient clipping.
- Only CGAN is utilized as the generative model; stronger generative models (e.g., diffusion models) might yield better performance.
- The quality of domain prompt labels directly impacts CGAN training, making optimization in Stage I crucial.
- Visual-side prompt generation has not been explored (currently restricted to the textual side).
Related Work & Insights
- CoOp: Pioneering work in fixed soft prompts.
- CoCoOp: Pioneer of image-conditional residual prompting.
- CGAN: Backbone of the generative approach.
- DAPL: A reference for prompt learning in domain adaptation.
- Insight: Shifting domain knowledge from "static parameter storage" to "dynamic generation" is an effective mechanism to improve generalization capabilities.
Rating
| Dimension | Score (1-10) |
|---|---|
| Novelty | 8 |
| Technical Depth | 7 |
| Experimental Thoroughness | 8 |
| Value | 8 |
| Writing Quality | 7 |
| Overall Rating | 7.6 |