
Improving Knowledge Distillation via Regularizing Feature Direction and Norm

Conference: ECCV 2024
arXiv: 2305.17007
Code: GitHub
Area: Model Compression
Keywords: Knowledge Distillation, Feature Direction, Feature Norm, Class Mean Alignment, ND Loss

TL;DR

A novel ND loss function is proposed. By simultaneously aligning the feature direction of the student to the class mean direction of the teacher and encouraging the student to generate large-norm features, it significantly improves the performance of existing knowledge distillation methods on ImageNet, CIFAR100, and COCO.

Background & Motivation

Background: Knowledge distillation (KD) is a core technique in model compression, leveraging large pre-trained teacher networks to guide the training of small student networks. Mainstream methods treat teacher features as "knowledge", transferring it by minimizing the KL divergence (at the logits level) or the L2 distance (at the intermediate feature level) between the teacher and the student.

Limitations of Prior Work: Directly forcing student features to align with teacher features does not necessarily translate to a direct improvement in the student's actual performance (e.g., classification accuracy). For instance, minimizing the L2 distance of the penultimate layer features does not guarantee training a better student classifier. Point-wise feature alignment imposes overly strict constraints and overlooks the structural properties that truly matter in the feature space: direction and norm.

Key Challenge: Existing KD methods pursue "feature value matching", whereas classification performance relies more heavily on the "relative positions of features in space" (direction) and the "discriminative strength of features" (norm). Point-wise alignment is neither efficient nor necessary, and may even introduce noise.

Goal: (1) How to more effectively utilize teacher features to guide the student in learning directional knowledge? (2) Student feature norms are significantly smaller than the teacher's; can this phenomenon be exploited to improve distillation performance?

Key Insight: The authors observe two key phenomena. First, the teacher's class mean features naturally form a strong classifier (resembling a Nearest Class Mean classifier). Second, the feature norms produced by the student network are significantly smaller than those of the teacher. Drawing on model pruning (where small-norm weights matter less) and domain adaptation (where larger-norm features are more discriminative), the authors argue that regularizing direction and norm together transfers classification knowledge in a more fundamental way.
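The claim that class means "naturally form a classifier" can be made concrete with a small NumPy sketch of a Nearest Class Mean (NCM) classifier. It uses synthetic, well-separated clusters as a stand-in for real teacher penultimate features (an assumption for illustration only), and the helper names `class_means` and `ncm_predict` are hypothetical, not taken from the paper's code.

```python
import numpy as np

def class_means(features, labels, num_classes):
    """Per-class mean feature: mu_c = (1/|S_c|) * sum of f_T(x) over x in class c.

    Assumes every class has at least one sample.
    """
    return np.stack([features[labels == c].mean(axis=0) for c in range(num_classes)])

def ncm_predict(features, means, metric="cosine"):
    """Nearest Class Mean: assign each feature to the closest class mean."""
    if metric == "cosine":
        f = features / np.linalg.norm(features, axis=1, keepdims=True)
        m = means / np.linalg.norm(means, axis=1, keepdims=True)
        return np.argmax(f @ m.T, axis=1)  # highest cosine similarity
    if metric == "euclidean":
        dists = np.linalg.norm(features[:, None, :] - means[None, :, :], axis=2)
        return np.argmin(dists, axis=1)  # smallest Euclidean distance
    raise ValueError(f"unknown metric: {metric}")

# Toy stand-in for teacher penultimate features: 3 well-separated clusters.
rng = np.random.default_rng(0)
num_classes, dim, per_class = 3, 8, 50
centers = 10.0 * np.eye(num_classes, dim)
labels = np.repeat(np.arange(num_classes), per_class)
feats = centers[labels] + 0.5 * rng.normal(size=(num_classes * per_class, dim))

mu = class_means(feats, labels, num_classes)  # shape (3, 8)
acc_cos = np.mean(ncm_predict(feats, mu) == labels)
acc_l2 = np.mean(ncm_predict(feats, mu, metric="euclidean") == labels)
print(f"NCM accuracy (cosine): {acc_cos:.2f}, (euclidean): {acc_l2:.2f}")
```

Only \(C\) mean vectors are needed at inference, which is also why class means are cheap to store as alignment anchors.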

Core Idea: Use the teacher's class mean features as directional anchors to align student feature directions, and encourage students to generate large-norm features, replacing traditional point-wise feature matching with a concise ND loss.

Method

Overall Architecture

The method is simple: it keeps the standard knowledge distillation framework and the network architecture unchanged, and adds a single new regularization term, the ND loss, on the penultimate-layer features. Each input image passes through the student network to produce penultimate features \(f_s\), which are compared against pre-computed teacher class mean features \(\mu_c\) (one mean vector per class \(c\)). The ND loss constrains both the direction of the student features (toward the corresponding class mean direction) and their norm (encouraging large norms). The final loss is a weighted combination of the task loss, the original KD loss, and the ND loss.

Key Designs

  1. Teacher Class Mean Feature Extraction:

    • Function: Construct anchors for directional alignment
    • Mechanism: Prior to training, a forward pass is run on the entire training set using the teacher network to collect features of all samples in each class at the penultimate layer, computing the mean feature for each class as \(\mu_c = \frac{1}{|S_c|}\sum_{x \in S_c} f_T(x)\). These class mean vectors serve as fixed directional anchors throughout the entire training process.
    • Design Motivation: Class mean features are more stable and representative than single-sample features, naturally filtering individual noise. Utilizing class means as alignment targets is more robust than sample-wise alignment, and the computational overhead is negligible (requiring only a one-time pre-computation).
  2. Feature Direction Alignment (Direction Regularization):

    • Function: Guide student features toward the correct class directions
    • Mechanism: For a sample belonging to class \(c\), calculate the cosine similarity between the student feature \(f_s\) and the teacher class mean \(\mu_c\), and maximize this similarity. The direction alignment loss is \(L_{dir} = 1 - \cos(f_s, \mu_c)\). This essentially pulls the student features toward the teacher prototype direction of the corresponding class on a unit hypersphere.
    • Design Motivation: Cosine distance focuses solely on direction rather than magnitude, avoiding the interference caused by norm differences in L2 distance. Directional alignment allows the student to learn "which direction the feature should point to" rather than "what the absolute value of this feature should be", which is more aligned with the essence of classification tasks.
  3. Feature Norm Regularization (Norm Regularization):

    • Function: Encourage students to generate large-norm features to enhance discriminativeness
    • Mechanism: The authors' experiments reveal that student feature norms are significantly smaller than the teacher's (e.g., teacher's average norm is around 50, whereas the student's is only around 15-20). By introducing a norm regularization term \(L_{norm} = -\|f_s\|\) (or equivalently encouraging the norm to increase), the student features become more "confident". The final ND loss unifies direction and norm into a concise form: \(L_{ND} = -f_s^T \cdot \hat{\mu}_c\), where \(\hat{\mu}_c\) represents the normalized class mean.
    • Design Motivation: Inspired by model pruning (where weights with small norms are less important) and domain adaptation (where features with larger norms are more reliable), a large feature norm indicates that the network is more confident in its prediction. ND loss cleverly unifies the optimization of direction and norm into an inner product form, since \(f_s^T \cdot \hat{\mu}_c = \|f_s\| \cdot \cos(f_s, \mu_c)\), simultaneously incentivizing direction alignment and norm enhancement.
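The three designs above can be sketched in NumPy, implementing the formulas exactly as written in this note. This is a minimal illustration under that assumption, not the authors' code; the official repository may differ (for example, in how the norm term is bounded), and the helper names are hypothetical. The sketch checks the identity \(f_s^T \hat{\mu}_c = \|f_s\| \cos(f_s, \mu_c)\) and shows that scaling a feature leaves \(L_{dir}\) unchanged while lowering \(L_{ND}\).

```python
import numpy as np

def normalize(x, eps=1e-12):
    """L2-normalize rows; eps guards against division by zero."""
    return x / np.maximum(np.linalg.norm(x, axis=-1, keepdims=True), eps)

def direction_loss(f_s, mu, labels):
    """L_dir = 1 - cos(f_s, mu_c), averaged over the batch."""
    cos = np.sum(normalize(f_s) * normalize(mu)[labels], axis=1)
    return np.mean(1.0 - cos)

def nd_loss(f_s, mu, labels):
    """L_ND = -f_s^T mu_hat_c, averaged over the batch (formula as written in this note)."""
    mu_hat = normalize(mu)[labels]  # (B, D): normalized class mean for each sample
    return -np.mean(np.sum(f_s * mu_hat, axis=1))

rng = np.random.default_rng(1)
num_classes, dim = 4, 16
mu = rng.normal(size=(num_classes, dim))  # stand-in for teacher class means
labels = np.array([0, 1, 2, 3, 0])
# Student features roughly aligned with their class means, plus small noise
f_s = 3.0 * normalize(mu)[labels] + 0.1 * rng.normal(size=(len(labels), dim))

# Identity behind the unification: f_s^T mu_hat_c = ||f_s|| * cos(f_s, mu_c)
inner = np.sum(f_s * normalize(mu)[labels], axis=1)
cos = np.sum(normalize(f_s) * normalize(mu)[labels], axis=1)
assert np.allclose(inner, np.linalg.norm(f_s, axis=1) * cos)

# Doubling aligned features: L_dir is unchanged, L_ND decreases (larger norms rewarded)
print(direction_loss(f_s, mu, labels), direction_loss(2 * f_s, mu, labels))
print(nd_loss(f_s, mu, labels), nd_loss(2 * f_s, mu, labels))
```

Because \(L_{ND}\) is linear in \(f_s\), it keeps rewarding larger norms for any feature that already points toward its class mean.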

Loss & Training

The final training loss is formulated as: \(L = L_{task} + \alpha L_{KD} + \beta L_{ND}\)

where \(L_{task}\) is the cross-entropy classification loss, \(L_{KD}\) is any existing distillation loss (e.g., KL divergence on logits, FitNet, etc.), and \(L_{ND}\) is the proposed direction-norm regularization loss. \(\alpha\) and \(\beta\) are hyperparameters that weight the KD and ND terms, respectively. Because the ND loss only needs the student's penultimate features and the pre-computed teacher class means, it is plug-and-play and can be combined with any existing KD method.
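The total objective can be sketched in NumPy as follows. As an assumption, \(L_{KD}\) is instantiated as the classic Hinton-style temperature-scaled KL divergence; the values of \(\alpha\), \(\beta\), and the temperature `T` are placeholders, not the paper's settings, and all function names are hypothetical.

```python
import numpy as np

def log_softmax(z):
    """Numerically stable log-softmax over the class axis."""
    z = z - z.max(axis=1, keepdims=True)
    return z - np.log(np.exp(z).sum(axis=1, keepdims=True))

def cross_entropy(logits, labels):
    """L_task: mean cross-entropy against integer labels."""
    return -np.mean(log_softmax(logits)[np.arange(len(labels)), labels])

def kd_kl(student_logits, teacher_logits, T=4.0):
    """Assumed L_KD: Hinton-style T^2 * KL(p_teacher || p_student) on softened logits."""
    log_p_s = log_softmax(student_logits / T)
    log_p_t = log_softmax(teacher_logits / T)
    return T * T * np.mean(np.sum(np.exp(log_p_t) * (log_p_t - log_p_s), axis=1))

def nd_loss(f_s, mu, labels):
    """L_ND = -f_s^T mu_hat_c, averaged over the batch."""
    mu_hat = mu / np.linalg.norm(mu, axis=1, keepdims=True)
    return -np.mean(np.sum(f_s * mu_hat[labels], axis=1))

def total_loss(s_logits, t_logits, f_s, mu, labels, alpha=1.0, beta=1.0, T=4.0):
    """L = L_task + alpha * L_KD + beta * L_ND."""
    return (cross_entropy(s_logits, labels)
            + alpha * kd_kl(s_logits, t_logits, T)
            + beta * nd_loss(f_s, mu, labels))

# Toy batch: B samples, C classes, D-dimensional penultimate features.
rng = np.random.default_rng(0)
B, C, D = 8, 10, 32
labels = rng.integers(0, C, size=B)
s_logits, t_logits = rng.normal(size=(B, C)), rng.normal(size=(B, C))
f_s = rng.normal(size=(B, D))  # student penultimate features
mu = rng.normal(size=(C, D))   # stand-in for pre-computed teacher class means
loss = total_loss(s_logits, t_logits, f_s, mu, labels, alpha=1.0, beta=0.5)
print(f"total loss: {loss:.4f}")
```

In an actual training loop these would be framework tensors (e.g., PyTorch) so that gradients reach the student; NumPy is used here only to keep the sketch dependency-light. Swapping `kd_kl` for another distillation loss leaves the ND term untouched, which is what makes it plug-and-play.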

Key Experimental Results

Main Results

| Dataset | Teacher → Student | Baseline (Top-1 % / mAP) | + ND loss | Gain |
|---|---|---|---|---|
| CIFAR-100 | ResNet32x4 → ResNet8x4 | KD: 73.33 | 75.48 | +2.15 |
| CIFAR-100 | WRN-40-2 → WRN-16-2 | KD: 74.92 | 76.31 | +1.39 |
| CIFAR-100 | VGG13 → VGG8 | KD: 72.98 | 74.15 | +1.17 |
| ImageNet | ResNet34 → ResNet18 | KD: 71.03 | 72.15 | +1.12 |
| COCO | R101 → R50 (Faster R-CNN) | FGD: 40.7 | 41.3 | +0.6 mAP |

Ablation Study

| Configuration | CIFAR-100 Top-1 (%) | Description |
|---|---|---|
| Baseline KD | 73.33 | Distillation with KL divergence only |
| + Direction only | 74.82 | Direction alignment only; already a significant improvement |
| + Norm only | 74.15 | Norm regularization only |
| + ND (Direction + Norm) | 75.48 | Combining both yields the best result |
| Class means replaced by sample-wise teacher features | 74.20 | Class means outperform single-sample features |

Key Findings

  • Directional alignment is the most critical component. Direction alignment alone improves accuracy by about 1.5 percentage points (73.33 → 74.82), suggesting that feature direction is a key piece of information neglected in knowledge distillation.
  • Norm regularization adds about 0.7 points on top of direction alignment (74.82 → 75.48). The combined gain (+2.15) is close to the sum of the individual gains (+1.49 and +0.82), so the two terms are largely complementary.
  • Utilizing class mean features yields better results than using sample-wise teacher features, confirming the denoising and smoothing effects of class means.
  • The ND loss is insensitive to the hyperparameter \(\beta\), bringing consistent improvements over a wide range of values.
  • The method also improves object detection on COCO (+0.6 mAP over FGD), suggesting that it generalizes beyond classification.

Highlights & Insights

  • Sleek Unification of the ND Loss: A single inner product \(-f_s^T \cdot \hat{\mu}_c\) optimizes both the direction and norm objectives at once, with no need for separate hyperparameters to weight each component. Unifying two seemingly independent optimization goals into a single expression is an instructive design pattern.
  • The Cleverness of Class Means as Alignment Anchors: Class mean features are naturally the prototype centers of a Nearest Class Mean (NCM) classifier. Used as alignment targets in place of sample-wise teacher features, they reduce per-sample noise and lower storage requirements (only \(C\) class vectors instead of \(N\) per-sample features).
  • Plug-and-Play Design Philosophy: The ND loss can be added directly on top of any existing KD method without altering network architectures or training pipelines. This "performance booster" may also transfer to other settings that involve feature alignment, such as domain adaptation, federated learning, and continual learning.
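Why a single inner product serves both objectives can be seen from its gradient. The following short derivation uses only the ND formula given in this note, for a single sample and ignoring \(\beta\) and batch averaging; \(\theta\) denotes the angle between \(f_s\) and \(\mu_c\).

\[ \mathcal{L}_{ND} = -f_s^\top \hat{\mu}_c \;\Rightarrow\; \nabla_{f_s}\mathcal{L}_{ND} = -\hat{\mu}_c \]

Decomposing \(\hat{\mu}_c\) relative to the current student direction \(\hat{f}_s = f_s/\|f_s\|\), with \(u\) a unit vector orthogonal to \(\hat{f}_s\):

\[ \hat{\mu}_c = \cos\theta\,\hat{f}_s + \sin\theta\,u \]

A gradient-descent step with learning rate \(\eta\) therefore gives

\[ f_s' = f_s - \eta\,\nabla_{f_s}\mathcal{L}_{ND} = f_s + \eta\,\hat{\mu}_c = (\|f_s\| + \eta\cos\theta)\,\hat{f}_s + \eta\sin\theta\,u \]

The first term lengthens the feature along its current direction whenever \(\cos\theta > 0\) (the norm incentive), and the second rotates it toward \(\hat{\mu}_c\) (the direction incentive). Because the gradient has unit norm regardless of \(\|f_s\|\), the push on the norm never switches off, which is why the question of how large the norm should be remains open.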

Limitations & Future Work

  • The class means are computed once before training and then kept fixed, but the optimal alignment target may need to adapt as the student network learns. Future work could consider momentum-based updates of the class means or progressive, curriculum-style adjustment strategies.
  • The method is primarily evaluated on classification and detection; experiments on other tasks such as segmentation and generation are missing.
  • On long-tailed (highly imbalanced) datasets, the class means of tail classes are estimated from few samples and may be unreliable, so additional handling may be needed.
  • While norm regularization encourages large norms, formal analysis is lacking regarding "how large is appropriate." Overly large norms might lead to numerical instability.

Comparison with Related Work

  • vs FitNet / AT: FitNet directly matches intermediate features under an L2 distance, and AT matches attention maps; both are point-wise or spatial alignment. The directional alignment in this paper operates at the semantic (class) level, which is coarser-grained but more effective.
  • vs CRD (Contrastive Representation Distillation): CRD uses contrastive learning for distillation and also focuses on the structural topology of the feature space. The ND loss is simpler and more direct, and can be combined with CRD for extra gains.
  • vs DKD (Decoupled KD): DKD decouples logit distillation into target and non-target classes. ND loss provides a complementary perspective from the feature space, and the two can be employed jointly.

Rating

  • Novelty: ⭐⭐⭐⭐ The perspective of analyzing knowledge distillation via feature direction and norm is novel, though the technical implementation is relatively straightforward.
  • Experimental Thoroughness: ⭐⭐⭐⭐ Covers both classification and detection tasks with diverse teacher-student pairs and extensive ablation studies.
  • Writing Quality: ⭐⭐⭐⭐ Clear motivation, concise methodology, and well-structured arguments.
  • Value: ⭐⭐⭐⭐⭐ Significant improvements as a plug-and-play component, holding high practical value.