Towards an Optimal Control Perspective of ResNet Training
Conference: ICML 2025
arXiv: 2506.21453
Code: -
Area: Theory / Neural Network Training
Keywords: ResNet, optimal control, stage cost, layer pruning, self-regularization
TL;DR
Formulates ResNet training as an optimal control problem, achieves self-regularization by adding stage-cost losses at intermediate layers, and proves that redundant deep weights asymptotically vanish, laying the foundation for theory-driven layer pruning.
Background & Motivation
- Residual connections in ResNet can be viewed as the forward Euler discretization of continuous-time Neural ODEs.
- From an optimal control perspective: forward data propagation corresponds to the state trajectory of a dynamical system, while trainable parameters act as control signals.
- Existing optimal control training methods are only applicable to architectures with fixed hidden dimensions and make restrictive assumptions about the loss functions.
- Goal: Generalize the stage cost formulation to standard ResNets (with non-trivial skip connections) and general loss functions.
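The first bullet can be made concrete with a short derivation. Assuming a residual function \(f\), parameters \(\theta\), and a step size \(h\) (notation chosen here for illustration, not taken from the paper), one forward Euler step of the continuous-time dynamics gives the residual update:

\[
\dot{x}(t) = f\bigl(x(t), \theta(t)\bigr),
\qquad
x(t+h) \approx x(t) + h\, f\bigl(x(t), \theta(t)\bigr),
\qquad
x_{k+1} = x_k + h\, f(x_k, \theta_k)
\]

With \(h = 1\) this is the familiar residual block \(x_{k+1} = x_k + f(x_k, \theta_k)\): the states \(x_k\) form the trajectory and the parameters \(\theta_k\) play the role of the control inputs.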
Method
1. ResNet as a Dynamical System
Forward propagation of standard ResNet-N:
where \(\mathcal{K}_\mathcal{S}\) is the set of layer indices with non-trivial skip connections (1×1 convolutions).
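A minimal Python sketch of this forward pass as a state trajectory, assuming toy fully connected blocks in place of convolutions. The names `forward`, `residual`, and `matvec`, the dict layout of a block, and the `tanh` residual branch are illustrative assumptions, not the paper's implementation; a block carrying a `skip` matrix stands in for an index in \(\mathcal{K}_\mathcal{S}\).

```python
import math

def matvec(W, x):
    """Multiply matrix W (a list of rows) by vector x."""
    return [sum(w * v for w, v in zip(row, x)) for row in W]

def residual(W, x):
    """Toy residual branch f(x, theta) = tanh(W x) (assumed form)."""
    return [math.tanh(z) for z in matvec(W, x)]

def forward(x, blocks):
    """Propagate x through the residual blocks and return every state.

    A block with a "skip" matrix has a non-trivial skip connection
    (a stand-in for a 1x1 convolution); otherwise the skip is the identity.
    """
    trajectory = [x]
    for block in blocks:
        skip = block.get("skip")
        shortcut = x if skip is None else matvec(skip, x)
        x = [s + r for s, r in zip(shortcut, residual(block["W"], x))]
        trajectory.append(x)
    return trajectory

blocks = [
    # identity skip, width 2
    {"W": [[0.5, 0.0], [0.0, 0.5]]},
    # non-trivial skip, width changes from 2 to 3
    {"W": [[0.1, 0.2], [0.3, 0.4], [0.5, 0.6]],
     "skip": [[1.0, 0.0], [0.0, 1.0], [1.0, 1.0]]},
    # identity skip with zero residual weights
    {"W": [[0.0] * 3 for _ in range(3)]},
]
states = forward([1.0, -1.0], blocks)
print([len(s) for s in states])
```

The last block has zero residual weights, so it leaves the state unchanged; a block in that condition can be removed without altering the output.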
2. Designing Intermediate Output Heads
Key Innovation: Construct intermediate outputs by utilizing the skip connections of subsequent layers and the final output head:
Intermediate predictions therefore reuse existing parameters of the backbone network (the skip-connection weights and the output head) and require no additional parameters.
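One possible reading of this construction, sketched in Python: the state entering block `k` is pushed through the non-trivial skip connections of the remaining blocks only (their residual branches are bypassed) and then through the final output head. The function name `intermediate_output`, the index convention, and the matrix stand-ins for 1×1 convolutions are assumptions made for illustration.

```python
def matvec(W, x):
    """Multiply matrix W (a list of rows) by vector x."""
    return [sum(w * v for w, v in zip(row, x)) for row in W]

def intermediate_output(x_k, k, skips, head):
    """Map the state entering block k to a prediction with no new parameters.

    skips: dict {block index with a non-trivial skip: skip matrix}
    head:  matrix of the final output head
    Assumed convention: block j maps state x_j to x_{j+1}.
    """
    z = x_k
    for j in sorted(skips):
        if j >= k:  # only skips of blocks that x_k has not passed yet
            z = matvec(skips[j], z)
    return matvec(head, z)

# Hypothetical toy network: block 1 widens the state from 2 to 3.
skips = {1: [[1.0, 0.0], [0.0, 1.0], [1.0, 1.0]]}
head = [[1.0, 0.0, 0.0], [0.0, 1.0, 1.0]]

y0 = intermediate_output([2.0, 3.0], 0, skips, head)       # skip of block 1, then head
y2 = intermediate_output([1.0, 2.0, 3.0], 2, skips, head)  # head only
print(y0, y2)
```

Because only existing skip matrices and the existing head are applied, the parameter count is the same as that of the backbone.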
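3. Stage Cost Training Objective
The objective adds a stage cost on the intermediate outputs, weighted by \(\gamma\), to the usual terminal loss; it reduces to standard training when \(\gamma = 0\).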
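A hedged Python sketch of such an objective for a single sample, assuming the stage cost is a plain sum of cross-entropy losses over the intermediate outputs. The exact weighting and normalization used in the paper are not given here, so the function `stage_cost_objective` and its signature are illustrative only.

```python
import math

def cross_entropy(logits, label):
    """Cross-entropy of one sample from raw logits (stabilized log-sum-exp)."""
    m = max(logits)
    log_z = m + math.log(sum(math.exp(v - m) for v in logits))
    return log_z - logits[label]

def stage_cost_objective(intermediate_logits, final_logits, label, gamma):
    """Terminal loss plus gamma times the summed intermediate losses (assumed form)."""
    terminal = cross_entropy(final_logits, label)
    stage = sum(cross_entropy(z, label) for z in intermediate_logits)
    return terminal + gamma * stage

# Hypothetical logits from two intermediate heads and the final head.
intermediate = [[0.0, 0.0], [0.5, -0.5]]
final = [2.0, -1.0]

standard = stage_cost_objective(intermediate, final, 0, gamma=0.0)
regularized = stage_cost_objective(intermediate, final, 0, gamma=0.5)
print(standard, regularized)
```

Setting `gamma=0.0` drops the intermediate terms, which recovers the standard training loss on the final output.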
4. Theoretical Result: Asymptotic Loss Bound
Theorem 3.1: Let ResNet-N be trained with stage cost and weight decay, and let ResNet-M (\(M<N\)) be its SubResNet, then:
Implication: When a shallower SubResNet can already solve the learning task (small \(\bar{\mathcal{L}}\)), the weights of the deeper residual blocks approach zero, so the network in effect discovers the depth it needs.
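One way to act on this implication is to inspect per-block weight norms after training and keep only the blocks up to the last one whose norm is non-negligible. The Python sketch below assumes that reading; `block_norm`, `effective_depth`, and the tolerance `tol` are hypothetical names and values chosen for illustration, not the paper's pruning procedure.

```python
import math

def block_norm(weights):
    """Frobenius norm of one block's weight matrix (a list of rows)."""
    return math.sqrt(sum(w * w for row in weights for w in row))

def effective_depth(block_weights, tol=1e-3):
    """Smallest M such that every block with index >= M has norm below tol."""
    depth = len(block_weights)
    while depth > 0 and block_norm(block_weights[depth - 1]) < tol:
        depth -= 1
    return depth

# Hypothetical trained weights: the last two blocks have almost vanished.
trained = [
    [[0.5, 0.1]],
    [[0.3, 0.2]],
    [[1e-5, 0.0]],
    [[0.0, 2e-5]],
]
print(effective_depth(trained))
```

Only trailing blocks are counted: a near-zero block followed by a non-negligible one is kept, since the result concerns the deepest layers.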
Simplified result without weight decay:
Key Experimental Results
CIFAR-10 Loss Trajectory
- The ResNet trained with stage cost achieves a good fit after only 12 residual blocks (78.74% test accuracy).
- Standard training only achieves a good fit at the final output.
- Performance plateaus within the same stage (same number of filters) and jumps at the start of a new stage (increased number of filters).
Layer Pruning Comparison
| Model | MNIST | CIFAR-10 | CIFAR-100 |
|---|---|---|---|
| ResNet-54 Standard Training | 99.64 | 93.05 | 71.94 |
| SubResNet-12 Standard Training | 99.63 | 91.43 | 68.77 |
| ResNet-54 Stage Cost | 99.62 | 91.59 | 69.97 |
| SubResNet-12 Pruned | 99.54 | 91.02 | 66.55 |
- In homogeneous models, the pruned SubResNet-12 trails the standard-trained SubResNet-12 by at most 2.22 percentage points (CIFAR-100) and the stage-cost-trained ResNet-54 by at most 3.42 points.
- ResNets trained in the standard way struggle with lossless layer pruning.
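The gaps behind this comparison can be recomputed from the table. The short Python check below copies the table values verbatim; the dictionary layout and the helper `gap` are illustrative.

```python
# Test accuracies (%) copied from the table above.
acc = {
    "ResNet-54 Standard Training": {"MNIST": 99.64, "CIFAR-10": 93.05, "CIFAR-100": 71.94},
    "SubResNet-12 Standard Training": {"MNIST": 99.63, "CIFAR-10": 91.43, "CIFAR-100": 68.77},
    "ResNet-54 Stage Cost": {"MNIST": 99.62, "CIFAR-10": 91.59, "CIFAR-100": 69.97},
    "SubResNet-12 Pruned": {"MNIST": 99.54, "CIFAR-10": 91.02, "CIFAR-100": 66.55},
}

def gap(reference, model):
    """Accuracy of `reference` minus accuracy of `model`, in percentage points."""
    return {d: round(acc[reference][d] - acc[model][d], 2) for d in acc[reference]}

vs_small = gap("SubResNet-12 Standard Training", "SubResNet-12 Pruned")
vs_full = gap("ResNet-54 Stage Cost", "SubResNet-12 Pruned")
print(vs_small)  # {'MNIST': 0.09, 'CIFAR-10': 0.41, 'CIFAR-100': 2.22}
print(vs_full)   # {'MNIST': 0.08, 'CIFAR-10': 0.57, 'CIFAR-100': 3.42}
```

The differences are absolute percentage points, not relative changes.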
Tightness of the Theoretical Bound
Experiments verify that the bound provided by Theorem 3.1 is relatively tight.
Highlights & Insights
- Elegantly maps the concept of stage cost from optimal control to standard ResNet architectures.
- Requires no additional parameters—intermediate output heads reuse the skip connections and the output layer.
- Theoretically proves that deeper weights asymptotically approach zero, laying the foundation for theory-driven layer pruning.
- Deep convergence analysis of the training dynamics provides new insights into the optimal depth of ResNets.
- Applicable to general loss functions (including standard cross-entropy).
Limitations & Future Work
- Experiments are limited to small-scale datasets such as MNIST and CIFAR-10/100.
- Layer pruning for standard ResNets (with non-homogeneous dimensions) is not yet straightforward.
- The choice of the stage cost weight \(\gamma\) lacks theoretical guidance.
- The connection with early-exit methods is not discussed in depth.
- Not validated in large-scale settings (e.g., ResNet-152 on ImageNet).
Rating
⭐⭐⭐⭐ — Theoretically elegant, systematically applying the control theory perspective to standard ResNets, but the scale of the experiments limits its practical impact.