Prompting Language-Informed Distribution for Compositional Zero-Shot Learning
Conference: ECCV 2024
arXiv: 2305.14428
Code: https://github.com/Cogito2012/PLID
Area: LLM Pre-training
Keywords: Compositional Zero-Shot Learning, CLIP, LLM, Distribution Prompting, Primitive Decomposition
TL;DR
This paper proposes PLID, which uses LLM-generated, sentence-level class descriptions to build language-knowledge-driven Gaussian distributions. Combined with vision-language primitive decomposition and randomized logit fusion, PLID achieves state-of-the-art (SOTA) results on three Compositional Zero-Shot Learning (CZSL) benchmarks (MIT-States, UT-Zappos, C-GQA) in both the closed-world and open-world settings.
Background & Motivation
Background: Compositional Zero-Shot Learning (CZSL) requires generalizing from seen compositions (e.g., sliced potatoes, red tomatoes) to unseen compositions (e.g., sliced tomatoes). Recent works based on CLIP prompt tuning have substantially outperformed traditional vision-based methods.
Limitations of Prior Work:
- Lack of Diversity: Hard-prompt methods such as CSP use only a single prompt per class (e.g., "a photo of [state][object]"), which fails to capture intra-class visual variation.
- Lack of Informativeness: Distribution-based prompting methods such as ProDA introduce multiple prompts to increase diversity, but these prompts lack linguistic semantic information, which limits their ability to distinguish fine-grained compositional classes.
- Primitive Entanglement: Visual primitives (states and objects) are naturally coupled (e.g., tomatoes are inherently associated with red). Existing methods either ignore decoupling or decouple only on the textual side.
Key Challenge: The prompt must simultaneously satisfy diversity and informativeness—ProDA offers diversity but lacks informativeness, while CSP provides some informativeness but lacks diversity.
Goal: Make CLIP's class text representations both diverse and informative, while supporting primitive decomposition on both the visual and textual sides.
Key Insight: Leverage LLMs to generate multiple descriptive sentences for each compositional class \(\rightarrow\) construct Gaussian distributions for the categories \(\rightarrow\) align distributions in both the compositional and primitive spaces simultaneously.
Core Idea: Use the descriptive sentences generated by LLMs as distribution support points (DSPs) to learn language-knowledge-driven class distributions via soft prompts, enabling diverse and informative zero-shot compositional recognition.
Method
Overall Architecture
Input image \(\rightarrow\) CLIP vision encoder + VFE (visual feature enhancement) \(\rightarrow\) image feature \(\mathbf{v}\)
Class name \(\rightarrow\) CLIP text encoder + soft prompt \(\rightarrow\) class embedding \(\mathbf{q}_y\)
M LLM-generated descriptions \(\rightarrow\) CLIP text encoder \(\rightarrow\) description embeddings \(\mathbf{D}^{(y)}\) \(\rightarrow\) TFE (text feature enhancement) fuses \(\mathbf{D}^{(y)}\) into \(\mathbf{q}_y\) \(\rightarrow\) class mean \(\mathbf{t}_y\)
\(\rightarrow\) build Gaussian distributions at three levels (composition, state, object) \(\rightarrow\) distribution alignment losses
\(\rightarrow\) VLPD primitive decomposition \(\rightarrow\) randomized logit fusion \(\rightarrow\) final prediction
Key Designs
Language-Knowledge-Driven Distribution (LID):
- Function: Construct a Gaussian distribution for each compositional class \(y = (s, o)\) from LLM-generated descriptions.
- Mechanism: Generate M descriptions \(S^{(y)} = \{S_1^{(y)}, ..., S_M^{(y)}\}\) with an LLM and encode them with the CLIP text encoder to obtain \(\mathbf{D}^{(y)} \in \mathbb{R}^{M \times d}\). TFE (a single-layer cross-attention) fuses \(\mathbf{D}^{(y)}\) into the class embedding \(\mathbf{q}_y\), giving the enhanced mean \(\mathbf{t}_y = \Psi_{\text{TFE}}(\mathbf{q}_y, \mathbf{D}^{(y)})\). The points \(\mathbf{t}_y + \mathbf{D}^{(y)}\) serve as the distribution support points (DSPs) and are assumed to follow \(\mathcal{N}(\mathbf{t}_y, \boldsymbol{\Sigma}_y)\). The training objective minimizes an upper bound of the NLL:
  \(\mathcal{L}_y(\mathbf{x}, y) = -\log \frac{\exp(h_y / \tau)}{\sum_{k=1}^{C} \exp((h_k + h_{k,y}^{(m)}) / \tau)}\)
  where the pairwise margin \(h_{k,y}^{(m)} = \mathbf{v}^\top \mathbf{A}_{k,y} \mathbf{v} / (2\tau)\) is determined by the covariance difference \(\mathbf{A}_{k,y} = \boldsymbol{\Sigma}_{kk} + \boldsymbol{\Sigma}_{yy} - \boldsymbol{\Sigma}_{ky} - \boldsymbol{\Sigma}_{yk}\).
- Design Motivation: Minimizing this loss shrinks the margins \(h_{k,y}^{(m)}\), which lowers the intra-class variance \(\boldsymbol{\Sigma}_{yy}\) and favors a larger cross-class covariance \(\boldsymbol{\Sigma}_{ky}\); the class distributions are thus optimized automatically, without auxiliary losses.
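The LID objective can be sketched in NumPy as follows. This is a minimal illustration under stated assumptions, not the authors' implementation: `tfe` is a hypothetical, parameter-free single-head cross-attention with a residual connection (the actual TFE has learned weights); the description embeddings `D[y]` are treated as zero-mean deviations of the support points \(\mathbf{t}_y + \mathbf{D}^{(y)}\) from \(\mathbf{t}_y\); and the m-th description is assumed to be index-aligned across classes, so \(\mathbf{v}^\top \mathbf{A}_{k,y} \mathbf{v}\) can be estimated as the mean of \(((\mathbf{D}^{(k)}_m - \mathbf{D}^{(y)}_m)^\top \mathbf{v})^2\) without forming any \(d \times d\) matrix. Logits \(h_k\) are taken to be cosine similarities.

```python
import numpy as np


def tfe(q_y, D_y):
    """Simplified TFE (hypothetical): q_y (d,) attends over D_y (M, d).

    Single-head, parameter-free cross-attention with a residual connection;
    the real module uses learned projections.
    """
    scores = D_y @ q_y / np.sqrt(q_y.shape[-1])   # (M,) attention logits
    w = np.exp(scores - scores.max())
    w /= w.sum()                                  # softmax over the M descriptions
    return q_y + w @ D_y                          # enhanced class mean t_y, shape (d,)


def lid_loss(v, t, D, y, tau=0.01):
    """NLL upper bound with covariance-dependent margins (sketch).

    v: (d,) image feature; t: (C, d) class means; D: (C, M, d) description
    embeddings, used as deviations of the support points t_y + D[y] from t_y.
    Assumes D[:, m] is one joint sample across all classes.
    """
    v = v / np.linalg.norm(v)
    t_n = t / np.linalg.norm(t, axis=1, keepdims=True)
    h = t_n @ v                                   # (C,) cosine logits h_k
    diff = D - D[y]                               # (C, M, d) samples of (w_k - w_y) - (t_k - t_y)
    quad = np.mean((diff @ v) ** 2, axis=1)       # (C,) estimate of v^T A_{k,y} v
    margin = quad / (2 * tau)                     # h^{(m)}_{k,y}; exactly 0 for k == y
    z = (h + margin) / tau
    log_denom = z.max() + np.log(np.exp(z - z.max()).sum())  # stable log-sum-exp
    return log_denom - h[y] / tau


# Usage on random features: C classes, M descriptions each, dimension d.
rng = np.random.default_rng(0)
C, M, d = 5, 8, 16
v = rng.normal(size=d)                            # image feature
q = rng.normal(size=(C, d))                       # soft-prompted class embeddings q_y
D = 0.1 * rng.normal(size=(C, M, d))              # encoded LLM descriptions D^(y)
t = np.stack([tfe(q[c], D[c]) for c in range(C)])
loss_dist = lid_loss(v, t, D, y=2)
loss_plain = lid_loss(v, t, np.zeros_like(D), y=2)  # zero spread: plain cross-entropy
```

Because every margin is non-negative and vanishes for \(k = y\), the bound is never smaller than the plain cross-entropy computed from the same logits, which is what the last two lines compare.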
Vision-Language Primitive Decomposition (VLPD):
- Function: Decompose the image feature \(\mathbf{v}\) into state and object features via two parallel networks \(f_s\) and \(f_o\).
- Mechanism:
  \(h_s = \cos(f_s(\mathbf{v}), \frac{1}{|\mathcal{Y}_s|}\sum_{y \in \mathcal{Y}_s} \mathbf{t}_y), \quad h_o = \cos(f_o(\mathbf{v}), \frac{1}{|\mathcal{Y}_o|}\sum_{y \in \mathcal{Y}_o} \mathbf{t}_y)\)
  where \(\mathcal{Y}_s\) (\(\mathcal{Y}_o\)) is the set of compositions containing state \(s\) (object \(o\)). The text-side primitive embeddings are thus group averages of compositional embeddings: the state embedding is the mean over all compositions sharing that state, and likewise for objects.
- Design Motivation: Unlike DFSP, which decomposes only on the text side, VLPD decomposes on both the visual and textual sides, and it outperforms the text-only variant in the VLPD ablation.
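A minimal NumPy sketch of the VLPD scoring step, assuming hypothetical linear maps `W_s` and `W_o` in place of the networks \(f_s\) and \(f_o\), whose architecture is not specified here. It builds the text-side primitive embeddings by group-averaging composition embeddings, scores them by cosine similarity, and forms the re-composition logits \(h_y^{(rc)} = h_s + h_o\) consumed by the logit fusion step. All function names are illustrative.

```python
import numpy as np


def l2norm(x, axis=-1):
    """Scale vectors to unit L2 norm along `axis`."""
    return x / np.linalg.norm(x, axis=axis, keepdims=True)


def group_mean(t, keys, n_groups):
    """Average the rows of t (C, d) that share each primitive id in keys.

    Every id in range(n_groups) must occur at least once in keys.
    """
    keys = np.asarray(keys)
    return np.stack([t[keys == g].mean(axis=0) for g in range(n_groups)])


def vlpd_scores(v, t, pairs, n_states, n_objects, W_s, W_o):
    """Primitive and re-composition logits for one image feature v (d,)."""
    states, objects = zip(*pairs)                  # primitive ids per composition
    T_s = group_mean(t, states, n_states)          # (S, d) state text embeddings
    T_o = group_mean(t, objects, n_objects)        # (O, d) object text embeddings
    v_s, v_o = W_s @ v, W_o @ v                    # hypothetical linear f_s, f_o
    h_s = l2norm(T_s) @ l2norm(v_s)                # (S,) cosine scores h_s
    h_o = l2norm(T_o) @ l2norm(v_o)                # (O,) cosine scores h_o
    h_rc = h_s[list(states)] + h_o[list(objects)]  # (C,) h_y^(rc) = h_s + h_o
    return h_s, h_o, h_rc, T_s, T_o


rng = np.random.default_rng(0)
d = 8
pairs = [(0, 0), (0, 1), (1, 1), (1, 2)]           # (state id, object id) per composition
t = rng.normal(size=(len(pairs), d))               # composition class means t_y
v = rng.normal(size=d)                             # image feature
W_s, W_o = rng.normal(size=(d, d)), rng.normal(size=(d, d))
h_s, h_o, h_rc, T_s, T_o = vlpd_scores(v, t, pairs, 2, 3, W_s, W_o)
```

Here state 0 is shared by the first two compositions, so its text embedding is the mean of `t[0]` and `t[1]`; a composition's re-composed logit is simply the sum of its state and object scores.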
Stochastic Logit Mixup (SLM):
- Function: Fuse the direct composition prediction \(h_y\) and the re-composition prediction \(h_y^{(rc)} = h_s + h_o\) with a randomized weight.
- Mechanism:
  \(\tilde{h}_y = (1 - \lambda) h_y + \lambda h_y^{(rc)}, \quad \lambda \sim \text{Beta}(a, b)\)
  During training, \(\lambda\) is sampled from the Beta distribution; at inference, it is fixed to the expectation \(\lambda = a/(a+b)\).
- Design Motivation: The randomness of the Beta distribution acts as a regularizer. Fusing at the logit level also preserves the inter-class relationships that softmax probability mixing tends to lose in large category spaces.
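The fusion rule can be sketched as below, assuming one scalar \(\lambda\) is drawn per call (i.e., per batch); whether sampling happens per batch or per example is an assumption, not something stated above. `slm_fuse` is a hypothetical name, and the defaults \(a=1, b=9\) mirror the best-performing Beta(1,9) reported in the findings.

```python
import numpy as np


def slm_fuse(h_comp, h_rc, a=1.0, b=9.0, training=True, rng=None):
    """Stochastic logit mixup (sketch): (1 - lam) * h_y + lam * h_y^(rc).

    Training: lam ~ Beta(a, b), one scalar per call (assumed per batch).
    Inference: lam = a / (a + b), the mean of Beta(a, b).
    """
    if training:
        if rng is None:
            rng = np.random.default_rng()
        lam = rng.beta(a, b)
    else:
        lam = a / (a + b)
    # Mixing happens on raw logits, before any softmax normalization.
    return (1.0 - lam) * h_comp + lam * h_rc


h_comp = np.array([2.0, 0.5, -1.0])   # direct composition logits h_y
h_rc = np.array([0.0, 1.5, 1.0])      # re-composed logits h_s + h_o
fused_eval = slm_fuse(h_comp, h_rc, training=False)  # lam = 0.1, ~ [1.8, 0.6, -0.8]
fused_train = slm_fuse(h_comp, h_rc, rng=np.random.default_rng(0))
```

At inference the output is deterministic: with Beta(1, 9), \(\lambda = 0.1\), so the direct composition logits carry 90% of the weight, while training outputs always lie between the two logit vectors because \(\lambda \in (0, 1)\).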
Loss & Training
The total training loss is the sum of the composition-, state-, and object-level distribution alignment losses:
\[\mathcal{L} = \mathcal{L}_y + \mathcal{L}_s + \mathcal{L}_o\]
All three employ the same form of the NLL upper bound loss, and each constructs a Gaussian distribution in its corresponding semantic space. During training, only the soft prompts \(\mathbf{p}_{1:L}\), primitive embeddings \([\mathbf{s}][\mathbf{o}]\), TFE/VFE parameters, and \(f_s, f_o\) are optimized, while the CLIP encoders remain frozen.
Key Experimental Results
Main Results
Closed-World Setting (H / AUC)
| Dataset | Metric | PLID | DFSP (Prev. SOTA) | CSP | Gain (PLID vs. DFSP) |
|---|---|---|---|---|---|
| MIT-States | H / AUC | 39.0 / 22.1 | 37.3 / 20.6 | 36.3 / 19.4 | +1.7 / +1.5 |
| UT-Zappos | H / AUC | 52.4 / 38.7 | 47.2 / 36.0 | 46.6 / 33.0 | +5.2 / +2.7 |
| C-GQA | H / AUC | 27.9 / 11.0 | 27.1 / 10.5 | 20.5 / 6.2 | +0.8 / +0.5 |
Open-World Setting (H / AUC)
| Dataset | Metric | PLID | DFSP (Prev. SOTA) | Gain |
|---|---|---|---|---|
| MIT-States | H / AUC | 20.4 / 7.3 | 19.3 / 6.8 | +1.1 / +0.5 |
| UT-Zappos | H / AUC | 46.6 / 30.8 | 44.0 / 30.3 | +2.6 / +0.5 |
| C-GQA | H / AUC | 10.6 / 2.5 | 10.4 / 2.4 | +0.2 / +0.1 |
Ablation Study
Ablation of Main Components (MIT-States, closed-world AUC_cw and open-world AUC_ow)
| Config | AUC_cw | AUC_ow | Description |
|---|---|---|---|
| (a) Baseline (Mean Pooling) | 18.56 | 5.56 | No distribution modeling |
| (b) + LID | 20.43 | 6.50 | Distribution modeling brings +1.87 |
| (c) + LID + FE | 21.09 | 6.95 | Feature enhancement is effective |
| (d) + LID + FE + OPT-1.3B | 21.67 | 7.01 | Better LLM brings slight improvement |
| (e) + LID + FE + OPT + PDF | 22.12 | 7.34 | PDF (primitive decomposition & fusion) further improves |
Ablation of VLPD Decomposition Strategy
| Decomposition Modality | Fusion Type | AUC_cw | Description |
|---|---|---|---|
| Text-only decomposition | No fusion | 20.98 | DFSP strategy |
| Text + Vision | Deterministic fusion | 21.90 | Dual-side decomposition is better |
| Text + Vision | Randomized fusion | 22.12 | Beta randomness acts as regularization |
Key Findings
- ProDA (uninformative distribution) is far inferior to PLID (informative distribution), indicating that diversity alone is insufficient and linguistic informativeness is key.
- A smaller LLM (OPT-1.3B) is sufficient: larger models such as GPT-3.5 and Mistral-7B yield no further improvement, suggesting that beyond this size, LLM scale is not the deciding factor.
- \(M=64\) descriptions and \(N=8\) image-augmented views are the optimal hyperparameters; too few descriptions yield an insufficiently rich distribution, while too many show diminishing returns.
- Beta(1,9) performs best, indicating that directly learned composition predictions should dominate, while re-composition predictions serve an auxiliary calibration role.
- t-SNE visualization clearly shows that the within-class distribution becomes tighter and the between-class separation becomes larger after distribution learning.
Highlights & Insights
- LLM Description = Class Distribution Support Points: Turning diverse LLM-generated descriptions into support points of a Gaussian distribution introduces a novel "description-as-distribution" paradigm that could plausibly extend to other zero-shot recognition tasks.
- Implicit Variance Optimization via the NLL Upper Bound: The covariance-dependent margins in the NLL upper bound shrink intra-class variance during training, and the t-SNE plots show tighter, better-separated classes, so no auxiliary contrastive loss is needed.
- Parameter Efficiency: Requires only a single set of soft prompts, which is far fewer than the large collection of learned prompts in ProDA.
- Logit Fusion > Probability Fusion: In large category spaces (e.g., C-GQA with 278K classes), softmax probability normalization loses inter-class correlation information; thus, logit-level fusion is more appropriate.
Limitations & Future Work
- When the number of classes is extremely large (e.g., 278K classes in C-GQA), calculating the covariance matrix \(\mathbf{A} \in \mathbb{R}^{d \times C \times C}\) requires group-wise approximations to be feasible.
- The quality of LLM descriptions affects performance, but how to automatically evaluate and filter description quality has not yet been explored.
- Only the ViT-L/14 backbone was validated; whether smaller CLIP models or more recent VLMs are equally effective remains unknown.
- The image-side VFE utilizes multi-view augmentation, which tends to overfit when \(N\) is too large.
- The margin of performance improvement under the open-world setting is limited, suggesting room for improvement in handling extremely large search spaces.
Related Work & Insights
- vs. CSP: CSP uses hard prompts with learnable primitive embeddings. While simple and efficient, it lacks diversity and informativeness. PLID builds on this by introducing distribution modeling and LLM descriptions.
- vs. DFSP: DFSP performs textual primitive decomposition and probability fusion. PLID shifts to dual-side (vision and text) decomposition + logit fusion, consistently outperforming DFSP across all datasets.
- vs. ProDA: ProDA learns multiple sets of soft prompts to model distributions, but these prompts lack semantic information and add a large parameter overhead. PLID instead builds the distributions from LLM descriptions, achieving better performance with far fewer parameters.
Rating
- Novelty: ⭐⭐⭐⭐ The philosophy of using LLM descriptions to construct class distributions is novel and practical, backed by a solid theoretical derivation of distribution modeling.
- Experimental Thoroughness: ⭐⭐⭐⭐⭐ Three datasets, two settings, and exceptionally thorough ablation studies (LID levels, LLM selection, FE design, fusion strategy, hyperparameters, and visualizations).
- Writing Quality: ⭐⭐⭐⭐⭐ Clear mathematical derivations, a coherent motivation-method-experiment logic chain, and highly informative figures and tables.
- Value: ⭐⭐⭐⭐ Represents a substantial advancement in the CZSL domain; the "description-as-distribution" paradigm holds strong generalizable potential.